@@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv {
806806 void initializeKernelCallerMap (Function *F);
807807
808808private:
809+ friend struct MemorySanitizerVisitor ;
810+
809811 Module &M;
810812 LLVMContext &C;
811813 const DataLayout &DL;
@@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv {
833835 FunctionCallee MsanBarrierFunc;
834836 FunctionCallee MsanUnpoisonStackFunc;
835837 FunctionCallee MsanSetPrivateBaseFunc;
838+ FunctionCallee MsanUnpoisonStridedCopyFunc;
836839};
837840
838841} // end anonymous namespace
@@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
899902 M.getOrInsertFunction (" __msan_unpoison_shadow_static_local" ,
900903 IRB.getVoidTy (), IntptrTy, IntptrTy);
901904
902- // __asan_poison_shadow_dynamic_local (
905+ // __msan_poison_shadow_dynamic_local (
903906 // uptr ptr,
904907 // uint32_t num_args
905908 // )
906909 MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction (
907910 " __msan_poison_shadow_dynamic_local" , IRB.getVoidTy (), IntptrTy, Int32Ty);
908911
909- // __asan_unpoison_shadow_dynamic_local (
912+ // __msan_unpoison_shadow_dynamic_local (
910913 // uptr ptr,
911914 // uint32_t num_args
912915 // )
@@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
930933 MsanSetPrivateBaseFunc =
931934 M.getOrInsertFunction (" __msan_set_private_base" , IRB.getVoidTy (),
932935 PointerType::get (C, kSpirOffloadPrivateAS ));
936+
937+ // __msan_unpoison_strided_copy(
938+ // uptr dest, uint32_t dest_as,
939+ // uptr src, uint32_t src_as,
940+ // uint32_t element_size,
941+ // uptr counts,
942+ // uptr stride
943+ // )
944+ MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction (
945+ " __msan_unpoison_strided_copy" , IRB.getVoidTy (), IntptrTy,
946+ IRB.getInt32Ty (), IntptrTy, IRB.getInt32Ty (), IRB.getInt32Ty (),
947+ IRB.getInt64Ty (), IRB.getInt64Ty ());
933948}
934949
935950// Handle global variables:
@@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
18331848 }
18341849 } else {
18351850 auto FuncName = Func->getName ();
1836- if (FuncName.contains (" __spirv_" ))
1851+ if (FuncName.contains (" __spirv_" ) &&
1852+ !FuncName.contains (" __spirv_GroupAsyncCopy" ))
18371853 I.setNoSanitizeMetadata ();
18381854 }
18391855 }
@@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
18431859 I.setNoSanitizeMetadata ();
18441860}
18451861
1862+ // This is not a general-purpose function, but a helper for demangling
1863+ // "__spirv_GroupAsyncCopy" function name
1864+ static int getTypeSizeFromManglingName (StringRef Name) {
1865+ auto GetTypeSize = [](const char C) {
1866+ switch (C) {
1867+ case ' a' : // signed char
1868+ case ' c' : // char
1869+ return 1 ;
1870+ case ' s' : // short
1871+ return 2 ;
1872+ case ' f' : // float
1873+ case ' i' : // int
1874+ return 4 ;
1875+ case ' d' : // double
1876+ case ' l' : // long
1877+ return 8 ;
1878+ default :
1879+ return 0 ;
1880+ }
1881+ };
1882+
1883+ // Name should always be long enough since it has other unmeaningful chars,
1884+ // it should have at least 6 chars, such as "Dv16_d"
1885+ if (Name.size () < 6 )
1886+ return 0 ;
1887+
1888+ // 1. Basic type
1889+ if (Name[0 ] != ' D' )
1890+ return GetTypeSize (Name[0 ]);
1891+
1892+ // 2. Vector type
1893+
1894+ // Drop "Dv"
1895+ assert (Name[0 ] == ' D' && Name[1 ] == ' v' &&
1896+ " Invalid mangling name for vector type" );
1897+ Name = Name.drop_front (2 );
1898+
1899+ // Vector length
1900+ assert (isDigit (Name[0 ]) && " Invalid mangling name for vector type" );
1901+ int Len = std::stoi (Name.str ());
1902+ Name = Name.drop_front (Len >= 10 ? 2 : 1 );
1903+
1904+ assert (Name[0 ] == ' _' && " Invalid mangling name for vector type" );
1905+ Name = Name.drop_front (1 );
1906+
1907+ int Size = GetTypeSize (Name[0 ]);
1908+ return Len * Size;
1909+ }
1910+
18461911namespace {
18471912
18481913// / Helper class to attach debug information of the given instruction onto new
@@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
63956460 VAHelper->visitCallBase (CB, IRB);
63966461 }
63976462
6463+ if (SpirOrSpirv) {
6464+ auto *Func = CB.getCalledFunction ();
6465+ if (Func) {
6466+ auto FuncName = Func->getName ();
6467+ if (FuncName.contains (" __spirv_GroupAsyncCopy" )) {
6468+ // clang-format off
6469+ // Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event",
6470+ // its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)"
6471+ // The type of "src" and "dst" should always be same.
6472+ // clang-format on
6473+
6474+ auto *Dest = CB.getArgOperand (1 );
6475+ auto *Src = CB.getArgOperand (2 );
6476+ auto *NumElements = CB.getArgOperand (3 );
6477+ auto *Stride = CB.getArgOperand (4 );
6478+
6479+ // Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of
6480+ // parameter type directly
6481+ const size_t kManglingPrefixLength = 33 ;
6482+ int ElementSize = getTypeSizeFromManglingName (
6483+ FuncName.substr (kManglingPrefixLength ));
6484+ assert (ElementSize != 0 &&
6485+ " Unsupported __spirv_GroupAsyncCopy element type" );
6486+
6487+ IRB.CreateCall (
6488+ MS.Spirv .MsanUnpoisonStridedCopyFunc ,
6489+ {IRB.CreatePointerCast (Dest, MS.Spirv .IntptrTy ),
6490+ IRB.getInt32 (Dest->getType ()->getPointerAddressSpace ()),
6491+ IRB.CreatePointerCast (Src, MS.Spirv .IntptrTy ),
6492+ IRB.getInt32 (Src->getType ()->getPointerAddressSpace ()),
6493+ IRB.getInt32 (ElementSize), NumElements, Stride});
6494+ }
6495+ }
6496+ }
6497+
63986498 // Now, get the shadow for the RetVal.
63996499 if (!CB.getType ()->isSized ())
64006500 return ;
0 commit comments