
Commit f988a0b

conversion kernels for unpacking
1 parent ea8b2fd commit f988a0b

File tree

1 file changed: +212 -31 lines


test/test_gpu.cpp

Lines changed: 212 additions & 31 deletions
@@ -11,6 +11,115 @@
 using namespace gpu;
 using namespace std::chrono;
 
+// WGSL Kernels
+
+// Kernel to unpack 4x int8 (packed in i32) to 4x int32
+const char *kPackedInt8ToInt32Kernel = R"(
+@group(0) @binding(0) var<storage, read_write> packed_input: array<i32>;
+@group(0) @binding(1) var<storage, read_write> unpacked_output: array<i32>;
+
+// Function to sign-extend an 8-bit value (represented in the lower bits of an i32)
+fn sign_extend_i8(val: i32) -> i32 {
+  return (val << 24) >> 24;
+}
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let packed_idx: u32 = gid.x;
+
+  // Check bounds for the PACKED input array
+  if (packed_idx >= arrayLength(&packed_input)) {
+    return;
+  }
+
+  let packed_val = packed_input[packed_idx];
+
+  // Unpack and write 4 separate i32 values
+  // Ensure the output buffer is large enough (4x the packed size)
+  let base_output_idx = packed_idx * 4u;
+
+  // Check bounds for the UNPACKED output array (optional but safer)
+  // This assumes arrayLength(&unpacked_output) is at least 4 * arrayLength(&packed_input)
+  if ((base_output_idx + 3u) >= arrayLength(&unpacked_output)) {
+    return; // Avoid out-of-bounds write if something is wrong
+  }
+
+  unpacked_output[base_output_idx + 0u] = sign_extend_i8((packed_val >> 0u) & 0xFF);
+  unpacked_output[base_output_idx + 1u] = sign_extend_i8((packed_val >> 8u) & 0xFF);
+  unpacked_output[base_output_idx + 2u] = sign_extend_i8((packed_val >> 16u) & 0xFF);
+  unpacked_output[base_output_idx + 3u] = sign_extend_i8((packed_val >> 24u) & 0xFF);
+}
+)";
+
+// Kernel to pack 4x int32 back into 1x int32 (taking lower 8 bits)
+const char *kInt32ToPackedInt8Kernel = R"(
+@group(0) @binding(0) var<storage, read_write> unpacked_input: array<i32>;
+@group(0) @binding(1) var<storage, read_write> packed_output: array<i32>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let packed_idx: u32 = gid.x; // Index for the PACKED output array
+
+  // Check bounds for the PACKED output array
+  if (packed_idx >= arrayLength(&packed_output)) {
+    return;
+  }
+
+  let base_input_idx = packed_idx * 4u;
+
+  // Check bounds for the UNPACKED input array (optional but safer)
+  // Assumes arrayLength(&unpacked_input) is at least 4 * arrayLength(&packed_output)
+  if ((base_input_idx + 3u) >= arrayLength(&unpacked_input)) {
+    // Handle potential error or incomplete data - maybe write 0?
+    packed_output[packed_idx] = 0;
+    return;
+  }
+
+  // Read 4 separate i32 values
+  let val0 = unpacked_input[base_input_idx + 0u];
+  let val1 = unpacked_input[base_input_idx + 1u];
+  let val2 = unpacked_input[base_input_idx + 2u];
+  let val3 = unpacked_input[base_input_idx + 3u];
+
+  // Pack the lower 8 bits of each into one i32
+  var packed_result: i32 = 0;
+  packed_result = packed_result | ((val0 & 0xFF) << 0u);
+  packed_result = packed_result | ((val1 & 0xFF) << 8u);
+  packed_result = packed_result | ((val2 & 0xFF) << 16u);
+  packed_result = packed_result | ((val3 & 0xFF) << 24u);
+
+  packed_output[packed_idx] = packed_result;
+}
+)";
+
+// Simple addition kernel for i32
+const char *kSimpleAddKernelI32 = R"(
+@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> c: array<{{precision}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let i: u32 = gid.x;
+  if (i < arrayLength(&a)) {
+    c[i] = a[i] + b[i];
+  }
+}
+)";
+
+// A simple WGSL copy kernel that copies input to output.
+static const char *kCopyKernel = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let i: u32 = gid.x;
+  if (i < arrayLength(&inp)) {
+    out[i] = inp[i];
+  }
+}
+)";
+
 // Forward declarations:
 void testToCPUWithTensor();
 void testToCPUWithBuffer();
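For reference, the byte-level round trip that the two conversion kernels above implement can be sanity-checked on the CPU. Below is a minimal, standalone C++ sketch (not part of this commit; it substitutes a cast to int8_t for the WGSL shift-based sign extension):

#include <cassert>
#include <cstdint>

int main() {
  // Four sample int8 values, including negatives and the minimum value.
  const int8_t vals[4] = {-10, 0, 37, -128};

  // Pack: keep the low 8 bits of each value, one byte per lane of a 32-bit word
  // (same layout as kInt32ToPackedInt8Kernel).
  uint32_t packed = 0;
  for (int lane = 0; lane < 4; ++lane) {
    packed |= (static_cast<uint32_t>(vals[lane]) & 0xFFu) << (8 * lane);
  }

  // Unpack: shift the lane down, mask the low byte, and sign-extend
  // (same recovery as sign_extend_i8 in kPackedInt8ToInt32Kernel).
  for (int lane = 0; lane < 4; ++lane) {
    int32_t unpacked = static_cast<int8_t>((packed >> (8 * lane)) & 0xFFu);
    assert(unpacked == vals[lane]);
  }
  return 0;
}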
@@ -31,46 +140,118 @@ void testToCPUWithUint64();
 void testNumTypeSizes();
 void testToCPUUnpack();
 void testCopyShaderPackedUnpack_int8();
+void testAddKernelInt8();
 
 int main() {
   LOG(kDefLog, kInfo, "Running GPU integration tests...");
-  testCopyShaderPackedUnpack_int8();
-  testToCPUUnpack();
-  testToCPUWithTensor();
-  testToCPUWithBuffer();
-  testToCPUWithTensorSourceOffset();
-  testToCPUWithBufferSourceOffset();
-  testToCPUWithHalf();
-  testToCPUWithFloat();
-  testToCPUWithDouble();
-  testToCPUWithint8();
-  testToCPUWithint16();
-  testToCPUWithint();
-  testToCPUWithint64();
-  testToCPUWithUint8();
-  testToCPUWithUint16();
-  testToCPUWithUint32();
-  testToCPUWithUint64();
-  testNumTypeSizes();
-  stressTestToCPU();
-  testHalf();
+  testAddKernelInt8();
+  // testCopyShaderPackedUnpack_int8();
+  // testToCPUUnpack();
+  // testToCPUWithTensor();
+  // testToCPUWithBuffer();
+  // testToCPUWithTensorSourceOffset();
+  // testToCPUWithBufferSourceOffset();
+  // testToCPUWithHalf();
+  // testToCPUWithFloat();
+  // testToCPUWithDouble();
+  // testToCPUWithint8();
+  // testToCPUWithint16();
+  // testToCPUWithint();
+  // testToCPUWithint64();
+  // testToCPUWithUint8();
+  // testToCPUWithUint16();
+  // testToCPUWithUint32();
+  // testToCPUWithUint64();
+  // testNumTypeSizes();
+  // stressTestToCPU();
+  // testHalf();
   LOG(kDefLog, kInfo, "All tests passed.");
   return 0;
 }
 
-// A simple WGSL copy kernel that copies input to output.
-static const char *kCopyKernel = R"(
-@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read_write> dummy: array<{{precision}}>;
-@compute @workgroup_size({{workgroupSize}})
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-  let i: u32 = gid.x;
-  if (i < arrayLength(&inp)) {
-    out[i] = inp[i];
+void testAddKernelInt8() {
+  LOG(kDefLog, kInfo, "Running testAddKernelInt8 (with conversion kernels)...");
+
+#ifdef USE_DAWN_API
+  Context ctx = createContextByGpuIdx(0);
+#else
+  Context ctx = createContext();
+#endif
+
+  constexpr size_t N = 1024; // Logical number of int8 elements
+  std::vector<int8_t> aInput(N), bInput(N), result(N);
+  std::vector<int8_t> expected(N);
+
+  // CPU Data Setup
+  for (size_t i = 0; i < N; ++i) {
+    // Values in range [-10, 9]
+    aInput[i] = static_cast<int8_t>((i % 20) - 10);
+    bInput[i] = static_cast<int8_t>(((2 * i) % 20) - 10);
+    // Compute expected as int then cast back.
+    int temp = static_cast<int>(aInput[i]) + static_cast<int>(bInput[i]);
+    expected[i] = static_cast<int8_t>(temp);
+    result[i] = 0;
+  }
+
+  // These store the int8 data packed into i32 format on the GPU
+  Tensor aTensorPacked = createTensor(ctx, Shape{N}, ki8, aInput.data());
+  Tensor bTensorPacked = createTensor(ctx, Shape{N}, ki8, bInput.data());
+  // Final output tensor, also in packed format
+  Tensor outputTensorPacked = createTensor(ctx, Shape{N}, ki8);
+
+  // These will hold the data converted to one i32 per original int8 element
+  Tensor aTensorUnpacked = createTensor(ctx, Shape{N}, ki32);
+  Tensor bTensorUnpacked = createTensor(ctx, Shape{N}, ki32);
+  Tensor outputTensorUnpacked =
+      createTensor(ctx, Shape{N}, ki32); // For the simple add result
+
+  constexpr uint32_t workgroupSize = 256;
+  size_t packedCount = (N + 3) / 4; // Number of i32 elements in packed buffers
+  size_t unpackedCount = N;         // Number of i32 elements in unpacked buffers
+
+  // Convert Packed Inputs to Unpacked i32
+  Kernel unpackKernelA =
+      createKernel(ctx, {kPackedInt8ToInt32Kernel, workgroupSize, ki32},
+                   Bindings{aTensorPacked, aTensorUnpacked},
+                   {cdiv(packedCount, workgroupSize), 1,
+                    1}); // Dispatch based on packed size
+  Kernel unpackKernelB =
+      createKernel(ctx, {kPackedInt8ToInt32Kernel, workgroupSize, ki32},
+                   Bindings{bTensorPacked, bTensorUnpacked},
+                   {cdiv(packedCount, workgroupSize), 1,
+                    1});
+  // Dispatch based on packed size
+  dispatchKernel(ctx, unpackKernelA);
+  dispatchKernel(ctx, unpackKernelB);
+
+  // Perform Simple Addition on Unpacked i32
+  Kernel simpleAddKernel = createKernel(
+      ctx, {kSimpleAddKernelI32, workgroupSize, ki32},
+      Bindings{aTensorUnpacked, bTensorUnpacked, outputTensorUnpacked},
+      {cdiv(unpackedCount, workgroupSize), 1,
+       1}); // Dispatch based on unpacked size
+  dispatchKernel(ctx, simpleAddKernel);
+
+  // Convert Unpacked i32 Result back to Packed
+  Kernel packKernel =
+      createKernel(ctx, {kInt32ToPackedInt8Kernel, workgroupSize, ki32},
+                   Bindings{outputTensorUnpacked, outputTensorPacked},
+                   {cdiv(packedCount, workgroupSize), 1,
+                    1}); // Dispatch based on packed size
+  dispatchKernel(ctx, packKernel);
+
+  // Copy Final Packed Result to CPU and Unpack
+  // Use the original toCPU for ki8, which handles the final CPU-side unpacking
+  toCPU(ctx, outputTensorPacked, ki8, result.data(), 0);
+
+  for (size_t i = 0; i < N; ++i) {
+    LOG(kDefLog, kInfo, "result[%zu] = %d, expected[%zu] = %d", i, result[i], i,
+        expected[i]);
+    assert(result[i] == expected[i]);
   }
+
+  LOG(kDefLog, kInfo, "testAddKernelInt8 (with conversion kernels) passed.");
 }
-)";
 
 void testCopyShaderPackedUnpack_int8() {
   LOG(kDefLog, kInfo, "Running testCopyShaderPackedUnpack_int8...");
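As a quick cross-check of the dispatch sizes chosen in testAddKernelInt8, the sketch below reproduces the workgroup-count arithmetic for N = 1024 and workgroupSize = 256 (assuming gpu.cpp's cdiv is ceiling division, which is how it is used in the test; cdivCeil here is a stand-in name, not the library function):

#include <cstddef>
#include <cstdio>

// Stand-in for gpu.cpp's cdiv, assumed to be ceiling division.
static size_t cdivCeil(size_t n, size_t d) { return (n + d - 1) / d; }

int main() {
  constexpr size_t N = 1024;              // logical int8 elements
  constexpr size_t workgroupSize = 256;   // threads per workgroup
  const size_t packedCount = (N + 3) / 4; // i32 words, 4 bytes each -> 256
  const size_t unpackedCount = N;         // one i32 per int8 after unpacking

  // Pack/unpack kernels are dispatched over the packed buffer: 256 / 256 = 1 workgroup.
  std::printf("pack/unpack workgroups: %zu\n", cdivCeil(packedCount, workgroupSize));
  // The add kernel is dispatched over the unpacked buffer: 1024 / 256 = 4 workgroups.
  std::printf("add workgroups: %zu\n", cdivCeil(unpackedCount, workgroupSize));
  return 0;
}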
