2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
@@ -78,12 +78,14 @@ def get_mask_cpp_check_expr(mask: str) -> str:
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
"kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE",
}

QSCALE_CHECK_MAP = {
"no": "quant_scale_enum::no_scale",
"pertensor": "quant_scale_enum::pertensor",
"blockscale": "quant_scale_enum::blockscale",
"kv_blockscale": "quant_scale_enum::kv_blockscale",
}

BIAS_MAP = {
5 changes: 4 additions & 1 deletion example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
@@ -677,7 +677,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
["pertensor", "kv_blockscale"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
@@ -740,6 +740,9 @@ def get_fwd_blobs(
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
# kv_blockscale only supports page_size=1024
if pipeline.F_qscale == "kv_blockscale" and page_size != 1024:
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
20 changes: 18 additions & 2 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -602,6 +602,14 @@ struct fmha_batch_prefill_args

std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;

// KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page)
// Layout: [num_block, num_kv_head, 2] where 2 = (k_descale, v_descale)
// Mutually exclusive with per-tensor k_descale_ptr/v_descale_ptr
const void* kv_block_descale_ptr = nullptr;
ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension
ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension
ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index (last dim)
};

template <typename FmhaKernel>
@@ -1225,7 +1233,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.kv_block_descale_ptr,
args.kv_block_descale_stride_block,
args.kv_block_descale_stride_head,
args.kv_block_descale_stride_kv);
}
else
{ // create batch mode kernel arguments
@@ -1278,7 +1290,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.p_drop,
args.s_randval,
args.drop_seed_offset,
args.sink_ptr);
args.sink_ptr,
args.kv_block_descale_ptr,
args.kv_block_descale_stride_block,
args.kv_block_descale_stride_head,
args.kv_block_descale_stride_kv);
}
}();

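For reference, a minimal sketch of how a caller could populate the new `kv_block_descale_*` fields, assuming the per-page descales are stored as one contiguous float tensor of shape [num_block, num_kv_head, 2] (the helper name below is illustrative and not part of this change):

```cpp
#include "fmha_fwd.hpp" // declares fmha_batch_prefill_args (example project header)

// Sketch only: for a contiguous [num_block, num_kv_head, 2] float tensor the strides
// reduce to num_kv_head * 2 (block), 2 (head), and 1 (k/v index).
inline void set_kv_block_descale(fmha_batch_prefill_args& args,
                                 const float* descales, // [num_block, num_kv_head, 2]
                                 ck_tile::index_t num_kv_head)
{
    args.kv_block_descale_ptr          = descales;
    args.kv_block_descale_stride_block = num_kv_head * 2; // step between pages/blocks
    args.kv_block_descale_stride_head  = 2;               // step between KV heads
    args.kv_block_descale_stride_kv    = 1;               // k_descale at +0, v_descale at +1
}
```

Non-contiguous layouts can be expressed through the three strides instead of requiring a packed tensor, which appears to be why all three are carried explicitly.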
13 changes: 10 additions & 3 deletions example/ck_tile/01_fmha/quant.hpp
@@ -11,9 +11,10 @@
// keep sync with BlockAttentionQuantScaleEnum
enum class quant_scale_enum
{
no_scale = 0,
pertensor = 1,
blockscale,
no_scale = 0,
pertensor = 1,
blockscale = 2,
kv_blockscale = 3, // Q per-tensor, K/V per-page block scale
};

struct quant_scale_info
@@ -28,6 +29,8 @@ struct quant_scale_info
os << "pt";
else if(type == quant_scale_enum::blockscale)
os << "bs";
else if(type == quant_scale_enum::kv_blockscale)
os << "kvbs";
}

static quant_scale_info decode(std::string str)
@@ -45,6 +48,10 @@ struct quant_scale_info
{
info.type = quant_scale_enum::blockscale;
}
else if(str == "kvbs" || str == "3")
{
info.type = quant_scale_enum::kv_blockscale;
}
else
{
throw std::invalid_argument("invalid quant scale value: " + str);
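A quick usage sketch of the extended string mapping, for example when the example's CLI argument is passed straight through to `decode` (an assumption about how the harness uses it):

```cpp
#include <cassert>
#include "quant.hpp" // example header shown above

int main()
{
    // "kvbs" (or "3") now selects the new mode; unknown strings still throw.
    quant_scale_info info = quant_scale_info::decode("kvbs");
    assert(info.type == quant_scale_enum::kv_blockscale);
    return 0;
}
```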
@@ -10,9 +10,10 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionQuantScaleEnum
{
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE,
NO_SCALE = 0,
PERTENSOR = 1,
BLOCKSCALE = 2,
KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
};

template <BlockAttentionQuantScaleEnum>
@@ -33,5 +34,10 @@ struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::BLOCKSCAL
{
static constexpr const char* name = "blockscale";
};
template <>
struct BlockAttentionQuantScaleEnumToStr<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE>
{
static constexpr const char* name = "kv_blockscale";
};

} // namespace ck_tile
106 changes: 97 additions & 9 deletions include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -185,13 +185,44 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_lse = 0;
};

struct FmhaFwdCommonQScaleKargs
// PERTENSOR: Q/K/V all use per-tensor descales
struct FmhaFwdPerTensorQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};

// KV_BLOCKSCALE: Q per-tensor, K/V per-page descales
struct FmhaFwdKVBlockScaleKargs
{
const void* q_descale_ptr = nullptr; // Per-tensor Q descale
const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2]
ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension
ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension
ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index
};

// Helper template to select QScale Kargs type based on QScaleEnum
// EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>)
template <BlockAttentionQuantScaleEnum QScale, typename EmptyType>
struct QScaleKargsSelector
{
using type = EmptyType;
};

template <typename EmptyType>
struct QScaleKargsSelector<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
{
using type = FmhaFwdPerTensorQScaleKargs;
};

template <typename EmptyType>
struct QScaleKargsSelector<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
{
using type = FmhaFwdKVBlockScaleKargs;
};

struct FmhaFwdDropoutSeedOffset
{
template <typename T>
@@ -255,9 +286,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
QScaleKargsSelector<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -276,9 +305,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<3>>,
QScaleKargsSelector<QScaleEnum, FmhaFwdEmptyKargs<3>>::type,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
{
@@ -348,7 +375,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
const void* kv_block_descale_ptr = nullptr,
ck_tile::index_t kv_block_descale_stride_block = 0,
ck_tile::index_t kv_block_descale_stride_head = 0,
ck_tile::index_t kv_block_descale_stride_kv = 1)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.kv_block_descale_ptr = kv_block_descale_ptr;
kargs.kv_block_descale_stride_block = kv_block_descale_stride_block;
kargs.kv_block_descale_stride_head = kv_block_descale_stride_head;
kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -495,7 +534,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset,
const void* sink_ptr = nullptr)
const void* sink_ptr = nullptr,
const void* kv_block_descale_ptr = nullptr,
ck_tile::index_t kv_block_descale_stride_block = 0,
ck_tile::index_t kv_block_descale_stride_head = 0,
ck_tile::index_t kv_block_descale_stride_kv = 1)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -563,6 +606,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.kv_block_descale_ptr = kv_block_descale_ptr;
kargs.kv_block_descale_stride_block = kv_block_descale_stride_block;
kargs.kv_block_descale_stride_head = kv_block_descale_stride_head;
kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv;
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
@@ -1162,6 +1213,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel

return kargs.scale_s * q_descale * k_descale;
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// Q descale is per-tensor; the per-page K/V descales are applied in the pipeline
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
return kargs.scale_s * q_descale;
}
else
{
return kargs.scale_s;
@@ -1237,6 +1294,37 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
dropout,
sink_value);
}
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
// KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline
const float* kv_block_descale_ptr =
reinterpret_cast<const float*>(kargs.kv_block_descale_ptr);

return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout,
sink_value,
kv_block_descale_ptr,
kargs.kv_block_descale_stride_block,
kargs.kv_block_descale_stride_head,
kargs.kv_block_descale_stride_kv);
}
else
{
return FmhaPipeline{}(q_dram_window,
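The per-page K/V descale lookup itself happens inside the pipeline, which is not part of this diff. Going by the documented `[num_block, num_kv_head, 2]` layout and the strides forwarded above, the address computation would look roughly like the sketch below; the function and parameter names are illustrative, not the pipeline's actual API:

```cpp
#include <cstdint>

// Sketch: fetch the K and V descales for one physical page/block and one KV head.
// Layout: [num_block, num_kv_head, 2], where the last dimension is (k_descale, v_descale).
inline void load_page_kv_descale(const float* kv_block_descale_ptr,
                                 std::int64_t page_idx,     // physical page/block index
                                 std::int64_t kv_head_idx,  // KV head index
                                 std::int64_t stride_block, // kv_block_descale_stride_block
                                 std::int64_t stride_head,  // kv_block_descale_stride_head
                                 std::int64_t stride_kv,    // kv_block_descale_stride_kv
                                 float& k_descale,
                                 float& v_descale)
{
    const float* base =
        kv_block_descale_ptr + page_idx * stride_block + kv_head_idx * stride_head;
    k_descale = base[0 * stride_kv]; // index 0 of the last dimension
    v_descale = base[1 * stride_kv]; // index 1 of the last dimension
}
```

Note the asymmetry with PERTENSOR: in KV_BLOCKSCALE only the per-tensor Q descale is folded into `scale_s` up front (see the scale computation change above), while the per-page K/V descales have to be applied tile by tile as the pipeline walks the pages.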