 
 using namespace gpu;
 
+const std::string versionToStr(int version);
+
 static const char *kShaderMatmul1 = R"(
 @group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
 @group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -466,6 +468,123 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
   }
 }
 
+/* 2D block-tiling with transpose:
+ * the same vectorized 2D block-tiling scheme as above, but B is expected
+ * pre-transposed to K x N so that tile loads from B read contiguous memory.
+ */
+static const char *kShaderMatmulWithTranspose = R"(
+@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
+var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
+var<workgroup> tileB: array<{{precision}}, {{BK}} * {{BN}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>) {
+
+    var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
+    var localM: array<{{precision}}, {{TM}}>;
+    var localN: array<vec4<{{precision}}>, {{TN4}}>;
+
+    let cRow: u32 = groupid.x;
+    let cCol: u32 = groupid.y;
+    let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
+
+    // position of the first c element computed by the thread
+    let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
+    let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
+
+    // aPtr and bPtr are the starting positions of the tiles in a and b,
+    // incremented in the bkidx loop.
+    // cPtr is the starting position of the tile in c, which is fixed.
+
+    var aPtr: u32 = cRow * {{BM}} * {{K}};
+    var bPtr: u32 = cCol * {{BN}};
+    let cPtr: u32 = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
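+    // c is stored as vec4 values, so column indexing into c uses N4 = N / 4
+    // and BN4 = BN / 4.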
+
+    for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
+
+      // Load the BM x BK tile of A using numThread (= BM * BN / (TM * TN)) threads;
+      // each thread performs BM * BK / numThread iterations.
+      for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
+        tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
+      }
+      // Load the BK x BN tile of B using the same numThread threads;
+      // each thread performs BK * BN / numThread iterations.
+      for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
+        tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BN}}) * {{N}} + ((localID.x + idx * numThread) % {{BN}})];
+      }
+
+      aPtr += {{BK}};
+      bPtr += {{BK}} * {{N}};
+
+      // Wait until both tiles are fully loaded before computing with them.
+      workgroupBarrier();
+      // Compute tile
+      for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
+        for (var idx: u32 = 0; idx < {{TM}}; idx++) {
+          localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
+        }
+        for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
+          localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4)     + dotIdx * {{BN}}],
+                                            tileB[(threadCol + idx*4 + 1) + dotIdx * {{BN}}],
+                                            tileB[(threadCol + idx*4 + 2) + dotIdx * {{BN}}],
+                                            tileB[(threadCol + idx*4 + 3) + dotIdx * {{BN}}]);
+        }
+        for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+          for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
+            threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
+          }
+        }
+      }
+      // Wait before the next iteration overwrites the tiles.
+      workgroupBarrier();
+    }
+
+    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+      for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
+        c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol / 4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
+      }
+    }
+}
+)";
+
+inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const size_t M,
+                                            const size_t K, const size_t N, const size_t BM,
+                                            const size_t BK, const size_t BN,
+                                            const size_t TM, const size_t TN,
+                                            const Shape &workgroupSize = {256, 1, 1},
+                                            NumType precision = kf32) {
+  assert(BM % TM == 0);
+  assert(BN % TN == 0);
+  assert(K % BK == 0);
+  assert(M % BM == 0);
+  assert(N % BN == 0);
+  // The same number of threads is used to load tile A, to load tile B,
+  // and to compute the C tile.
+  int num_threads = BM * BN / (TM * TN);
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
+                          {"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{BM}}", toString(BM)},
+                          {"{{BK}}", toString(BK)},
+                          {"{{BN}}", toString(BN)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)},
+                          {"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
+                          {"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
+                          {"{{TN4}}", toString(TN / 4)},
+                          {"{{N4}}", toString(N / 4)},
+                          {"{{BN4}}", toString(BN / 4)},
+                          });
+  std::string unrolledCode = loopUnrolling(codeString);
+  // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
+  return {unrolledCode, workgroupSize};
+}
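+
+// Worked example, using the configuration selectMatmul passes for version 9
+// below (BM = BN = 64, BK = 8, TM = TN = 8): num_threads = 64*64/(8*8) = 64,
+// NUM_TILEA = 64*8/64 = 8, NUM_TILEB = 8*64/64 = 8, and TN4 = 2, so each
+// thread copies 8 elements of each tile per bkidx step and accumulates an
+// 8 x 8 patch of C held as 8 x 2 vec4 accumulators.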
+
 /**
  * @brief No-Op shader with matmul bindings for performance testing
  */
@@ -519,20 +638,26 @@ Kernel selectMatmul(Context &ctx, int version,
                      size_t M, size_t K, size_t N) {
   Kernel kernel;
   if (version == 1) {
+    Shape wgSize = {256, 1, 1};
+    Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
+    KernelCode matmul = createNoOp(kShaderNoOp, /* wgsize */ wgSize);
+    kernel = createKernel(ctx, matmul, bindings,
+                          /* nWorkgroups */ nWorkgroups);
+  } else if (version == 2) {
     Shape wgSize = {16, 16, 1};
     LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
     KernelCode matmul =
         createMatmul1(kShaderMatmul1, M, K, N, /* wgsize */ wgSize);
     kernel = createKernel(ctx, matmul, bindings,
                           /* nWorkgroups */ cdiv({M, N, 1}, wgSize));
-  } else if (version == 2) {
+  } else if (version == 3) {
     static constexpr size_t tileSize = 16;
     KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
                                       /* wgSize */ {tileSize * tileSize, 1, 1});
     kernel =
         createKernel(ctx, matmul, bindings,
                      /* nWorkgroups */ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
-  } else if (version == 3 || version == 5) {
+  } else if (version == 4 || version == 6) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 4;
     static constexpr size_t BN = BM;
@@ -548,10 +673,10 @@ Kernel selectMatmul(Context &ctx, int version,
     KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
                                       /* wgSize */ wgSize,
                                       kf32,
-                                      /* Loop unrolling */ version == 5 ? true : false);
+                                      /* Loop unrolling */ version == 6 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
                           /* nWorkgroups */ nWorkgroups);
-  } else if (version == 4 || version == 6) {
+  } else if (version == 5 || version == 7) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
     static constexpr size_t BN = 64;
@@ -566,10 +691,10 @@ Kernel selectMatmul(Context &ctx, int version,
     KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
                                       /* wgSize */ wgSize,
                                       kf32,
-                                      /* Loop unrolling */ version == 6 ? true : false);
+                                      /* Loop unrolling */ version == 7 ? true : false);
     kernel = createKernel(ctx, matmul, bindings,
                           /* nWorkgroups */ nWorkgroups);
-  } else if (version == 7) {
+  } else if (version == 8) {
     static constexpr size_t BM = 64;
     static constexpr size_t BK = 8;
     static constexpr size_t BN = 64;
@@ -587,10 +712,21 @@ Kernel selectMatmul(Context &ctx, int version,
                                       /* Loop unrolling */ true);
     kernel = createKernel(ctx, matmul, bindings,
                           /* nWorkgroups */ nWorkgroups);
-  } else if (version == 8) {
-    Shape wgSize = {256, 1, 1};
-    Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
-    KernelCode matmul = createNoOp(kShaderNoOp, /* wgsize */ wgSize);
+  } else if (version == 9) {
+    static constexpr size_t BM = 64;
+    static constexpr size_t BK = 8;
+    static constexpr size_t BN = 64;
+    static constexpr size_t TM = BM / BK;
+    static constexpr size_t TN = BN / BK;
+    Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
+    Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "BM: %zu, BK: %zu, BN: %zu, TM: %zu, TN: %zu", BM, BK, BN, TM, TN);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
+                                                  /* wgSize */ wgSize,
+                                                  kf32);
     kernel = createKernel(ctx, matmul, bindings,
                           /* nWorkgroups */ nWorkgroups);
   }
@@ -626,8 +762,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
 
   printf("[ Press enter to start tests ... ]\n");
   getchar();
-  LOG(kDefLog, kInfo, "Dispatching Kernel version %d, %d iterations ...",
-      version, nIter);
+  LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...",
+      version, versionToStr(version).c_str(), nIter);
 
   // Dispatch kernel nIter times
   auto start = std::chrono::high_resolution_clock::now();
@@ -662,26 +798,43 @@ void runTest(int version, size_t M, size_t K, size_t N,
       M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
 }
 
+const std::string versionToStr(int version) {
+  switch (version) {
+  case 1: return "No-Op";
+  case 2: return "naive matmul";
+  case 3: return "tiling";
+  case 4: return "1D blocktiling";
+  case 5: return "2D blocktiling";
+  case 6: return "1D blocktiling with loop unrolling";
+  case 7: return "2D blocktiling with loop unrolling";
+  case 8: return "2D blocktiling with loop unrolling and vectorization";
+  case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
+  default: return "Not specified";
+  }
+}
+
 int main() {
   char *version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 7 : atoi(version_str);
-  // 1 == naive matmul
-  // 2 == tiling
-  // 3 == 1D blocktiling
-  // 4 == 2D blocktiling
-  // 5 == 1D blocktiling with loop unrolling
-  // 6 == 2D blocktiling with loop unrolling
-  // 7 == 2D blocktiling with loop unrolling and vectorization
-  // 8 == No-Op
+  int version = version_str == NULL ? 9 : atoi(version_str);
+  // 1 == No-Op
+  // 2 == naive matmul
+  // 3 == tiling
+  // 4 == 1D blocktiling
+  // 5 == 2D blocktiling
+  // 6 == 1D blocktiling with loop unrolling
+  // 7 == 2D blocktiling with loop unrolling
+  // 8 == 2D blocktiling with loop unrolling and vectorization
+  // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
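+  // The kernel version and test size are read from the environment, e.g.
+  // MATMUL_VERSION=9 MATMUL_SIZE=1 ./<matmul_binary>
+  // (MATMUL_SIZE is read below; the binary name depends on the build setup).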
 
   size_t M, K, N; // Matrix dimensions
-  static constexpr int kTestSize = 2;
-  if constexpr (kTestSize == 0) {
+  char *kTestSize_str = getenv("MATMUL_SIZE");
+  int kTestSize = kTestSize_str == NULL ? 2 : atoi(kTestSize_str);
+  if (kTestSize == 0) {
     // Tiny test
     M = 32;
     K = 32;
     N = 32;
-  } else if constexpr (kTestSize == 1) {
+  } else if (kTestSize == 1) {
     // Small test
     M = 256;
     K = 128;
@@ -696,11 +849,19 @@ int main() {
   std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
   std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
   std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
+  bool transposedInput = version == 9;
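+  // Version 9 expects the second operand in K x N layout, so the N x K
+  // weights buffer is transposed on the host before running the test.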
 
   initData(M, K, N, inputPtr, weightsPtr);
-  runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
+  if (transposedInput) {
+    std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
+    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
+    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
+  } else {
+    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
+  }
+
 
-  if constexpr (kTestSize <= 1) {
+  if (kTestSize <= 1) {
     // Check result with CPU reference implementation for tiny/small tests
     checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
   }