@@ -613,6 +613,66 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
   return {unrolledCode, workgroupSize, precision};
 }
 
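+// Substitutes the template placeholders for the subgroup-matrix kernel and
+// pins the workgroup size to {256, 1, 1}, matching the @workgroup_size
+// declared in the shader below.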
+inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
+                                 const size_t K, const size_t N,
+                                 NumType precision = kf32) {
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)}});
+  return {codeString, {256, 1, 1}, precision};
+}
+
+
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
+// and subgroupMatrixMultiplyAccumulate
+// ─────────────────────────────────────────────────────────────────────────────
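+// Intended to run one subgroup per workgroup (see selectMatmul below): each
+// workgroup owns one 16x16 tile of C and walks the K dimension in 16-wide
+// steps, accumulating with subgroupMatrixMultiplyAccumulate. Requires a
+// WebGPU implementation exposing chromium_experimental_subgroup_matrix.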
+const char *kShaderSubgroupMatrixMultiply = R"(
+enable chromium_experimental_subgroup_matrix;
+
+@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+
+// Each workgroup computes one 16x16 tile of C.
+@compute @workgroup_size(256, 1, 1)
+fn main(@builtin(workgroup_id) groupID: vec3<u32>) {
+
+    let tileRow = groupID.y;
+    let tileCol = groupID.x;
+
+    let outRowStart = tileRow * 16u;
+    let outColStart = tileCol * 16u;
+
+    if (outRowStart >= {{M}} || outColStart >= {{N}}) {
+        return;
+    }
+
+    var acc: subgroup_matrix_result<{{precision}}, 16, 16>;
+
+    let kTiles = ({{K}} + 15u) / 16u;
+
+    // Load the first tile and multiply to initialize accumulator
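+    // subgroupMatrixLoad arguments here: source array, element offset,
+    // column-major flag, and stride in elements (meaning assumed from the
+    // chromium_experimental_subgroup_matrix proposal).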
+    let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
+    let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
+    acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);
+
+    // Loop over the rest of the K-dimension
+    for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
+        let k = kTile * 16u;
+        let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
+        let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
+        acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
+    }
+
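+    // Write the accumulated 16x16 tile of C starting at element
+    // (outRowStart, outColStart).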
+    subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
+}
+)";
+
+
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -775,6 +835,16 @@ Kernel selectMatmul(Context &ctx, int version,
                              numtype);
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
+  } else if (version == 12) {
+    // f32: Subgroup matrix multiply
+    Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
+    Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
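+    // Dispatch one workgroup per 16x16 output tile: x indexes column tiles,
+    // y indexes row tiles, matching tileCol/tileRow in the shader.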
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul =
+        createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
+    kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
   }
   return kernel;
 }
@@ -859,7 +929,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
   // Use microsecond for more accurate time measurement
   auto duration =
       std::chrono::duration_cast<std::chrono::microseconds>(end - start);
-  float gflops = 2 * M * N *
+  float gflops = 2.0f * M * N *
                  K / // factor of 2 for multiplication & accumulation
                  (static_cast<double>(duration.count()) / 1000000.0) /
                  1000000000.0 * static_cast<float>(nIter);
@@ -870,7 +940,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
       show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());
 
   LOG(kDefLog, kInfo, "\n\n===================================================================="
-      "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
+      "============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
       ":\n%.1f "
       "milliseconds / dispatch ~ %.2f "
       "GFLOPS\n================================================================"
@@ -911,15 +981,16 @@ const std::string versionToStr(int version){
   case 7: return "f32: 2D blocktiling with loop unrolling";
   case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
   case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
-  case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
+  case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
   case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 12: return "f32: Subgroup matrix multiply";
   default: return "Not specified";
   }
 }
 
 int main() {
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 10 : atoi(version_str);
+  int version = version_str == NULL ? 12 : atoi(version_str);
   // 1 == f32: No-Op
   // 2 == f32: naive matmul
   // 3 == f32: tiling
@@ -931,8 +1002,9 @@ int main() {
   // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
   // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
   // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
+  // 12 == f32: Subgroup matrix multiply
   bool enableF16 = version == 10 || version == 11;
-  bool transposedInput = version == 9 || version == 11;
+  bool transposedInput = version == 9 || version == 11 || version == 12;
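+  // Version 12 reuses the transposed-input path; its shader loads A and B
+  // tiles with the column-major flag set.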
   NumType numtype = enableF16 ? kf16 : kf32;
 
   size_t M, K, N; // Matrix dimensions