#include "utils/array_utils.h" // show, isclose, randn, randint
#include "utils/logging.h"     // LOG
#include "experimental/wgsl.h" // loopUnrolling
+#include "numeric_types/half.h"

using namespace gpu;

const std::string versionToStr(int version);

+void matmulf16_forward_cpu(half* out,
+                           const half* inp, const half* weight, const half* bias,
+                           int B, int T, int C, int OC) {
+  // OC is short for "output channels"
+  // inp is (B,T,C), weight is (OC, C)
+  // out will be (B,T,OC)
+  #pragma omp parallel for collapse(2)
+  for (int b = 0; b < B; b++) {
+    for (int t = 0; t < T; t++) {
+      half* out_bt = out + b * T * OC + t * OC;
+      const half* inp_bt = inp + b * T * C + t * C;
+      for (int o = 0; o < OC; o++) {
+        float val = (bias != NULL) ? halfToFloat(bias[o]) : 0.0f;
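+        // accumulate in f32 (operands widened with halfToFloat below) to limit fp16 rounding error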
+        const half* wrow = weight + o * C;
+        for (int i = 0; i < C; i++) {
+          val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]);
+        }
+        out_bt[o] = val;
+      }
+    }
+  }
+}
+
static const char *kShaderMatmul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M,
4771 {" {{M}}" , toString (M)},
4872 {" {{K}}" , toString (K)},
4973 {" {{N}}" , toString (N)}});
50- return {codeString, workgroupSize};
74+ return {codeString, workgroupSize, precision };
5175}
5276
5377// Shared memory cache-blocking
@@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M,
108132 {" {{N}}" , toString (N)},
109133 {" {{tileSize}}" ,
110134 toString (static_cast <size_t >(sqrt (workgroupSize[0 ])))}});
111- return {codeString, workgroupSize};
135+ return {codeString, workgroupSize, precision };
112136}
113137
114138/* 1D block-tiling
@@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
  } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
  }
}

@@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
  } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
  }
}

@@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
  if (unrolling) {
    std::string unrolledCode = loopUnrolling(codeString);
    // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-    return {unrolledCode, workgroupSize};
+    return {unrolledCode, workgroupSize, precision};
  } else {
-    return {codeString, workgroupSize};
+    return {codeString, workgroupSize, precision};
  }
}

@@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
  });
  std::string unrolledCode = loopUnrolling(codeString);
  // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
-  return {unrolledCode, workgroupSize};
+  return {unrolledCode, workgroupSize, precision};
}

/**
@@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate,
  std::string codeString(shaderTemplate);
  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
                          {"{{precision}}", toString(precision)}});
-  return {codeString, workgroupSize};
+  return {codeString, workgroupSize, precision};
}

void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
@@ -619,23 +643,41 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
      show<float>(weightsPtr.get(), N, K, "Weights").c_str());
}

-void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
-              std::unique_ptr<float[]> &weightsPtr,
-              std::unique_ptr<float[]> &outputPtr) {
+void initData(size_t M, size_t K, size_t N, std::unique_ptr<half[]> &inputPtr,
+              std::unique_ptr<half[]> &weightsPtr) {
+  std::mt19937 gen(314159);
+  randn(inputPtr.get(), M * K, gen);
+  randn(weightsPtr.get(), N * K, gen);
+  // randint(inputPtr.get(), M * K, gen, 1, 2);
+  // randint(weightsPtr.get(), N * K, gen, 1, 2);
+  LOG(kDefLog, kInfo, "%s", show<half>(inputPtr.get(), M, K, "Input").c_str());
+  LOG(kDefLog, kInfo, "%s",
+      show<half>(weightsPtr.get(), N, K, "Weights").c_str());
+}
+
+template <class precision = float>
+void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<precision[]> &inputPtr,
+              std::unique_ptr<precision[]> &weightsPtr,
+              std::unique_ptr<precision[]> &outputPtr) {
  LOG(kDefLog, kInfo, "Computing CPU reference implementation");
-  std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
-  ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
-                          nullptr, 1, M, K, N);
+  std::unique_ptr<precision[]> outputRefPtr = std::make_unique<precision[]>(M * N);
+  if constexpr (std::is_same<precision, float>::value) {
+    ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                            nullptr, 1, M, K, N);
+  } else if constexpr (std::is_same<precision, half>::value) {
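+    // the f16 path checks against matmulf16_forward_cpu defined at the top of this file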
+    matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
+                          nullptr, 1, M, K, N);
+  }
  LOG(kDefLog, kInfo, "Reference Output: %s",
-      show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
+      show<precision>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
  LOG(kDefLog, kInfo,
      isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
                                                          : "CPU Check: FAIL");
}

Kernel selectMatmul(Context &ctx, int version,
                    const Bindings</* input, weights, output */ 3> &bindings,
-                   size_t M, size_t K, size_t N) {
+                   size_t M, size_t K, size_t N, NumType numtype) {
  Kernel kernel;
  if (version == 1) {
    Shape wgSize = {256, 1, 1};
@@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version,
    Shape wgSize = {16, 16, 1};
    LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
    KernelCode matmul =
-        createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
+        createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
  } else if (version == 3) {
    static constexpr size_t tileSize = 16;
    KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
-                                      /*wgSize*/ {tileSize * tileSize, 1, 1});
+                                      /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
    kernel =
        createKernel(ctx, matmul, bindings,
                     /*nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
@@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
                                      /*wgSize*/ wgSize,
-                                      kf32,
+                                      numtype,
                                      /*Loop unrolling*/ version == 6 ? true : false);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
@@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
                                      /*wgSize*/ wgSize,
-                                      kf32,
+                                      numtype,
                                      /*Loop unrolling*/ version == 7 ? true : false);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
-  } else if (version == 8) {
+  } else if (version == 8 || version == 10) {
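+    // versions 10 and 11 reuse the f32 kernels (8 and 9); only the {{precision}} supplied via numtype differs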
    static constexpr size_t BM = 64;
    static constexpr size_t BK = 8;
    static constexpr size_t BN = 64;
@@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
                                                      /*wgSize*/ wgSize,
-                                                      kf32,
+                                                      numtype,
                                                      /*Loop unrolling*/ true);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
-  } else if (version == 9) {
+  } else if (version == 9 || version == 11) {
    static constexpr size_t BM = 64;
    static constexpr size_t BK = 8;
    static constexpr size_t BN = 64;
@@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version,
    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
    KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
                                                  /*wgSize*/ wgSize,
-                                                  kf32);
+                                                  numtype);
    kernel = createKernel(ctx, matmul, bindings,
                          /*nWorkgroups*/ nWorkgroups);
  }
  return kernel;
}

+template <class precision = float>
void runTest(int version, size_t M, size_t K, size_t N,
-             std::unique_ptr<float[]> &inputPtr,
-             std::unique_ptr<float[]> &weightsPtr,
-             std::unique_ptr<float[]> &outputPtr) {
+             std::unique_ptr<precision[]> &inputPtr,
+             std::unique_ptr<precision[]> &weightsPtr,
+             std::unique_ptr<precision[]> &outputPtr,
+             NumType numtype) {
+  if constexpr (std::is_same<precision, float>::value) {
+    assert(numtype == kf32);
+  } else if constexpr (std::is_same<precision, half>::value) {
+    assert(numtype == kf16);
+  }

  // Allocate GPU buffers and copy data
-  Context ctx = createContext();
-  Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get());
-  Tensor weights =
-      createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major
+  Context ctx = createContext(
+      {}, {},
+      /*device descriptor, enabling f16 in WGSL*/
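+      // (requests the WebGPU "shader-f16" feature so the WGSL kernels may use f16 types)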
+      {
+          .requiredFeatureCount = 1,
+          .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
+      });
+
+  Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
+  Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major

  constexpr size_t nIter = 30;

@@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
  std::array<Tensor, nIter> outputs;
  for (int i = 0; i < nIter; i++) {
    futures[i] = promises[i].get_future();
-    outputs[i] = createTensor(ctx, Shape{M, N}, kf32);
-    kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N);
+    outputs[i] = createTensor(ctx, Shape{M, N}, numtype);
+    kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
  }

  printf("[ Press enter to start tests ... ]\n");
@@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N,
                 1000000000.0 * static_cast<float>(nIter);

  LOG(kDefLog, kInfo, "Copying result to CPU");
-  toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float));
+  toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision));
  LOG(kDefLog, kInfo, "%s",
-      show<float>(outputPtr.get(), M, N, "Output[0]").c_str());
+      show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

  LOG(kDefLog, kInfo, "\n\n===================================================================="
      "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
@@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N,
      M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
}

+template <class precision = float>
+void runTestWithCheck(int version, size_t M, size_t K, size_t N,
+                      bool transposedInput, int kTestSize, NumType numtype) {
+  std::unique_ptr<precision[]> inputPtr = std::make_unique<precision[]>(M * K);
+  std::unique_ptr<precision[]> weightsPtr = std::make_unique<precision[]>(N * K);
+  std::unique_ptr<precision[]> outputPtr = std::make_unique<precision[]>(M * N);
+
+  initData(M, K, N, inputPtr, weightsPtr);
+  if (transposedInput) {
+    std::unique_ptr<precision[]> transposedWeightPtr = std::make_unique<precision[]>(K * N);
+    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
+    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype);
+  } else {
+    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype);
+  }
+
+  if (kTestSize <= 1) {
+    // Check result with CPU reference implementation for tiny/small tests
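+    // (checkCPU always receives the untransposed weights; the transpose above only changes the GPU-side layout)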
+    checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+  }
+}
+
const std::string versionToStr(int version){
  switch (version) {
-  case 1: return "No-Op";
-  case 2: return "naive matmul";
-  case 3: return "tiling";
-  case 4: return "1D blocktiling";
-  case 5: return "2D blocktiling";
-  case 6: return "1D blocktiling with loop unrolling";
-  case 7: return "2D blocktiling with loop unrolling";
-  case 8: return "2D blocktiling with loop unrolling and vectorization";
-  case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
+  case 1: return "f32: No-Op";
+  case 2: return "f32: naive matmul";
+  case 3: return "f32: tiling";
+  case 4: return "f32: 1D blocktiling";
+  case 5: return "f32: 2D blocktiling";
+  case 6: return "f32: 1D blocktiling with loop unrolling";
+  case 7: return "f32: 2D blocktiling with loop unrolling";
+  case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
+  case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
+  case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
+  case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
  default: return "Not specified";
  }
}

int main() {
  char* version_str = getenv("MATMUL_VERSION");
-  int version = version_str == NULL ? 9 : atoi(version_str);
-  // 1 == No-Op
-  // 2 == naive matmul
-  // 3 == tiling
-  // 4 == 1D blocktiling
-  // 5 == 2D blocktiling
-  // 6 == 1D blocktiling with loop unrolling
-  // 7 == 2D blocktiling with loop unrolling
-  // 8 == 2D blocktiling with loop unrolling and vectorization
-  // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
+  int version = version_str == NULL ? 10 : atoi(version_str);
+  // 1 == f32: No-Op
+  // 2 == f32: naive matmul
+  // 3 == f32: tiling
+  // 4 == f32: 1D blocktiling
+  // 5 == f32: 2D blocktiling
+  // 6 == f32: 1D blocktiling with loop unrolling
+  // 7 == f32: 2D blocktiling with loop unrolling
+  // 8 == f32: 2D blocktiling with loop unrolling and vectorization
+  // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
+  // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
+  // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
+  bool enableF16 = version == 10 || version == 11;
+  bool transposedInput = version == 9 || version == 11;
+  NumType numtype = enableF16 ? kf16 : kf32;

  size_t M, K, N; // Matrix dimensions
  char* kTestSize_str = getenv("MATMUL_SIZE");
@@ -846,24 +930,10 @@ int main() {
    N = 2 * 4096;
  }

-  std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
-  std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
-  std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
-  bool transposedInput = version == 9;
-
-  initData(M, K, N, inputPtr, weightsPtr);
-  if (transposedInput) {
-    std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
-    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
-    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
+  if (enableF16) {
+    runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize, numtype);
  } else {
-    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
-  }
-
-
-  if (kTestSize <= 1) {
-    // Check result with CPU reference implementation for tiny/small tests
-    checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
+    runTestWithCheck<float>(version, M, K, N, transposedInput, kTestSize, numtype);
  }

  LOG(kDefLog, kInfo, "Done.");