-
Notifications
You must be signed in to change notification settings - Fork 270
[CK TILE] fix numerical errors of preshuffle_b #3695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR fixes numerical issues in the weight preshuffle path for quantized GEMM (including pk_int4 → fp8), aligns host-side shuffle logic with the kernel’s expectations, and broadens test coverage to additional GEMM tile shapes so fp8 issues are exercised.
Changes:
- Adjusted weight preshuffle and AB-quant pipeline policies to compute
KLaneBytesbased on the internal compute type (mixed_prec_compute_type_from_input_t) instead of the packed input storage type, ensuring correctNumAccessand memory access patterns for pk_int4/fp8 flows. - Fixed
shuffle_b_permuteN’s tile-view layout and permutation to use a consistentKLane/ItemsPerAccessformulation and corrected the permutation order to match the intended layout. - Extended the GEMM preshuffle test utilities with new config structs and updated example grouped-GEMM code to use properly qualified dependent template types, improving both coverage and compilation robustness.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp | Adds config_mn_16x16/config_mn_32x32 warp-tile configs and runs the preshuffle tests with both, increasing coverage across GEMM tile shapes (note: the *_16x16/*_32x32 names are currently inverted relative to their M_Warp_Tile/N_Warp_Tile values). |
| include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp | Updates UniversalWeightPreshufflePipelineAgBgCrPolicy::GetBlockWeightPreshuffle to compute KLaneBytes from BTypeToUse (mixed-precision compute type) so access granularity matches the actual compute datatype. |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp | Mirrors the KLaneBytes/NumAccess fix in the AB-quant weight preshuffle B-quant policy, basing the lane-byte count on BTypeToUse derived from Problem::BDataType and the compute type. |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp | Qualifies BlockGemm::OverrideADataType/OverrideBDataType with typename to correctly refer to dependent nested types in the AB-quant compute pipeline. |
| include/ck_tile/host/tensor_shuffle_utils.hpp | Reimplements the non-gfx12 branch of shuffle_b_permuteN to use constexpr KLane/ItemsPerAccess (aligned with shuffle_b) and adjusts the tensor view rank and reference_permute order to a layout consistent with the preshuffle kernel. |
| example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp | Adds typename in BaseGemmPipeline and GemmPipeline dependent template instantiations so grouped AB-quant GEMM examples compile cleanly with strict C++ template rules. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| static constexpr ck_tile::index_t M_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; | ||
| }; | ||
|
|
||
| template <typename Datatype> | ||
| struct config_mn_32x32 : public config<Datatype> | ||
| { | ||
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 16; |
Copilot
AI
Jan 31, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
config_mn_16x16 and config_mn_32x32 appear to have their M_Warp_Tile/N_Warp_Tile values swapped relative to their names: config_mn_16x16 sets both tiles to 32, while config_mn_32x32 sets them to 16. This mismatch is confusing when reading the tests and can make it harder to reason about which warp-tile shape is actually being exercised; consider either swapping the tile constants or renaming the structs so the *_16x16/*_32x32 suffix matches the configured warp tile sizes.
| static constexpr ck_tile::index_t M_Warp_Tile = 32; | |
| static constexpr ck_tile::index_t N_Warp_Tile = 32; | |
| static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; | |
| }; | |
| template <typename Datatype> | |
| struct config_mn_32x32 : public config<Datatype> | |
| { | |
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | |
| static constexpr ck_tile::index_t N_Warp_Tile = 16; | |
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | |
| static constexpr ck_tile::index_t N_Warp_Tile = 16; | |
| static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; | |
| }; | |
| template <typename Datatype> | |
| struct config_mn_32x32 : public config<Datatype> | |
| { | |
| static constexpr ck_tile::index_t M_Warp_Tile = 32; | |
| static constexpr ck_tile::index_t N_Warp_Tile = 32; |
This pull request introduces several improvements and fixes related to quantized grouped GEMM (General Matrix Multiply) pipelines and their supporting utilities.
The numerical issue
Steps to reproduce
Solution
The main changes address type correctness, improve data layout and shuffling logic, and expand test coverage to better validate different GEMM configurations.
Key changes include:
Data layout and shuffling logic
shuffle_b_permuteNto useconstexprvariables forKLaneandItemsPerAccess, simplifying tile view construction and correcting the permutation order for improved efficiency and correctness (tensor_shuffle_utils.hpp).KLaneBytesin weight preshuffle pipeline policies to account for internal data type conversion (e.g., frompk_int4_ttofp8), ensuring accurate memory access and alignment in quantized GEMM policies (wp_pipeline_agmem_bgmem_creg_base_policy.hpp,gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp). [1] [2]Test infrastructure enhancements
config_mn_16x16,config_mn_32x32) to support additional GEMM tile shapes and updated tests to run with these configurations for broader coverage (test_gemm_pipeline_util.hpp). [1] [2]