Commit a61dfc3

adds override for buffer
1 parent 36fe730 commit a61dfc3

2 files changed: +231 -71 lines changed

gpu.hpp

Lines changed: 187 additions & 71 deletions
@@ -1780,114 +1780,230 @@ inline void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data,
  wait(ctx, future);
}

The existing dtype-aware toCPU overload for Tensor changes only in formatting in this hunk (the signature wraps onto two lines and pointer-cast spacing is normalized); after the commit it reads:

inline void toCPU(Context &ctx, Tensor &tensor, NumType dtype, void *output,
                  size_t sourceOffset = 0) {
  size_t numElements = size(tensor.shape);
  switch (dtype) {
  // These types are directly supported.
  case kf16:
  case kf32:
  case ku32:
  case ki32:
    toCPU(ctx, tensor, output, tensor.data.size, sourceOffset);
    break;

  // For double, the tensor was created by packing doubles into floats.
  case kf64: {
    std::vector<float> tmp(numElements);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(float), sourceOffset);
    double *dst = static_cast<double *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      dst[i] = static_cast<double>(tmp[i]);
    }
    break;
  }

  // For int8_t: four 8-bit ints packed into one int32_t.
  case ki8: {
    size_t packedCount = (numElements + 3) / 4;
    std::vector<int32_t> tmp(packedCount);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
    int8_t *dst = static_cast<int8_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 4;
      size_t shift = (i % 4) * 8;
      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
    }
    break;
  }

  // For int16_t: two 16-bit ints packed into one int32_t.
  case ki16: {
    size_t packedCount = (numElements + 1) / 2;
    std::vector<int32_t> tmp(packedCount);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
    int16_t *dst = static_cast<int16_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 2;
      size_t shift = (i % 2) * 16;
      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
    }
    break;
  }

  // For int64_t: each 64-bit int was packed into two int32_t.
  case ki64: {
    std::vector<int32_t> tmp(numElements * 2);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
    int64_t *dst = static_cast<int64_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      int32_t low = tmp[2 * i];
      int32_t high = tmp[2 * i + 1];
      dst[i] =
          (static_cast<int64_t>(high) << 32) | (static_cast<uint32_t>(low));
    }
    break;
  }

  // For uint8_t: four 8-bit uints packed into one uint32_t.
  case ku8: {
    size_t packedCount = (numElements + 3) / 4;
    std::vector<uint32_t> tmp(packedCount);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
    uint8_t *dst = static_cast<uint8_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 4;
      size_t shift = (i % 4) * 8;
      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
    }
    break;
  }

  // For uint16_t: two 16-bit uints packed into one uint32_t.
  case ku16: {
    size_t packedCount = (numElements + 1) / 2;
    std::vector<uint32_t> tmp(packedCount);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
    uint16_t *dst = static_cast<uint16_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 2;
      size_t shift = (i % 2) * 16;
      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
    }
    break;
  }

  // For uint64_t: each 64-bit unsigned int was packed into two uint32_t.
  case ku64: {
    std::vector<uint32_t> tmp(numElements * 2);
    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
    uint64_t *dst = static_cast<uint64_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      uint32_t low = tmp[2 * i];
      uint32_t high = tmp[2 * i + 1];
      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
    }
    break;
  }

  default:
    LOG(kDefLog, kError, "Unsupported dtype in toCPUUnpack");
    break;
  }
}
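The ki64 and ku64 branches cast the low word to an unsigned 32-bit type before the OR so that it is zero-extended, not sign-extended, when widened to 64 bits. A short standalone check (plain C++, not part of the diff) showing why that cast matters:

#include <cassert>
#include <cstdint>

int main() {
  // One 64-bit element, 0x00000000FFFFFFFE, split into the two 32-bit words
  // the packed layout stores: low word first, then high word.
  int32_t low = -2; // bit pattern 0xFFFFFFFE
  int32_t high = 0;

  // Reassembly as in the ki64 branch: the uint32_t cast zero-extends low.
  int64_t ok = (static_cast<int64_t>(high) << 32) | static_cast<uint32_t>(low);
  assert(ok == 4294967294LL); // 0x00000000FFFFFFFE

  // Without the cast, low sign-extends to 0xFFFFFFFFFFFFFFFE and clobbers the
  // high bits whenever they are not all ones.
  int64_t bad = (static_cast<int64_t>(high) << 32) | low;
  assert(bad == -2);
  return 0;
}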
The new addition is an overload that reads back through a raw WGPUBuffer rather than a Tensor; because a bare buffer carries no shape, the caller passes the element count explicitly:

inline void toCPU(Context &ctx, WGPUBuffer buffer, NumType dtype, void *output,
                  size_t numElements, size_t sourceOffset = 0) {
  switch (dtype) {
  // Directly supported types.
  case kf16:
  case kf32:
  case ku32:
  case ki32: {
    size_t byteSize = numElements * sizeBytes(dtype);
    toCPU(ctx, buffer, output, byteSize, sourceOffset);
    break;
  }

  // For double, the buffer was written as floats.
  case kf64: {
    std::vector<float> tmp(numElements);
    toCPU(ctx, buffer, tmp.data(), numElements * sizeof(float), sourceOffset);
    double *dst = static_cast<double *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      dst[i] = static_cast<double>(tmp[i]);
    }
    break;
  }

  // For int8_t: four 8-bit ints packed into one int32_t.
  case ki8: {
    size_t packedCount = (numElements + 3) / 4;
    std::vector<int32_t> tmp(packedCount);
    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(int32_t), sourceOffset);
    int8_t *dst = static_cast<int8_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 4;
      size_t shift = (i % 4) * 8;
      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
    }
    break;
  }

  // For int16_t: two 16-bit ints packed into one int32_t.
  case ki16: {
    size_t packedCount = (numElements + 1) / 2;
    std::vector<int32_t> tmp(packedCount);
    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(int32_t), sourceOffset);
    int16_t *dst = static_cast<int16_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 2;
      size_t shift = (i % 2) * 16;
      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
    }
    break;
  }

  // For int64_t: each 64-bit int is packed into two int32_t.
  case ki64: {
    std::vector<int32_t> tmp(numElements * 2);
    toCPU(ctx, buffer, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
    int64_t *dst = static_cast<int64_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      int32_t low = tmp[2 * i];
      int32_t high = tmp[2 * i + 1];
      dst[i] =
          (static_cast<int64_t>(high) << 32) | (static_cast<uint32_t>(low));
    }
    break;
  }

  // For uint8_t: four 8-bit uints packed into one uint32_t.
  case ku8: {
    size_t packedCount = (numElements + 3) / 4;
    std::vector<uint32_t> tmp(packedCount);
    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(uint32_t),
          sourceOffset);
    uint8_t *dst = static_cast<uint8_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 4;
      size_t shift = (i % 4) * 8;
      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
    }
    break;
  }

  // For uint16_t: two 16-bit uints packed into one uint32_t.
  case ku16: {
    size_t packedCount = (numElements + 1) / 2;
    std::vector<uint32_t> tmp(packedCount);
    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(uint32_t),
          sourceOffset);
    uint16_t *dst = static_cast<uint16_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      size_t idx = i / 2;
      size_t shift = (i % 2) * 16;
      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
    }
    break;
  }

  // For uint64_t: each 64-bit unsigned int is packed into two uint32_t.
  case ku64: {
    std::vector<uint32_t> tmp(numElements * 2);
    toCPU(ctx, buffer, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
    uint64_t *dst = static_cast<uint64_t *>(output);
    for (size_t i = 0; i < numElements; ++i) {
      uint32_t low = tmp[2 * i];
      uint32_t high = tmp[2 * i + 1];
      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
    }
    break;
  }

  default:
    LOG(kDefLog, kError, "Unsupported dtype in toCPU (raw buffer override)");
    break;
  }
}
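A minimal usage sketch for the new raw-buffer override, not taken from the diff: it assumes a Context named ctx, a tensor t holding N packed int8 values, and that the tensor's underlying WGPUBuffer is reachable as t.data.buffer (the diff itself only shows tensor.data.size, so that member name is an assumption here).

// Hypothetical read-back of N packed int8 values through the raw-buffer
// override; ctx, t, N, and t.data.buffer are assumptions, not diff code.
std::vector<int8_t> host(N);
toCPU(ctx, t.data.buffer, ki8, host.data(), /*numElements=*/N);

The Tensor overload can derive the count from size(tensor.shape); a raw WGPUBuffer cannot, which is why numElements is an explicit parameter in this override.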

test/test_gpu.cpp

Lines changed: 44 additions & 0 deletions
@@ -30,9 +30,11 @@ void testToCPUWithUint32();
void testToCPUWithUint64();
void testNumTypeSizes();
void testToCPUUnpack();
void testCopyShaderPackedUnpack_int8();

int main() {
  LOG(kDefLog, kInfo, "Running GPU integration tests...");
  testCopyShaderPackedUnpack_int8();
  testToCPUUnpack();
  testToCPUWithTensor();
  testToCPUWithBuffer();

@@ -70,6 +72,48 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
}
)";

void testCopyShaderPackedUnpack_int8() {
  LOG(kDefLog, kInfo, "Running testCopyShaderPackedUnpack_int8...");

#ifdef USE_DAWN_API
  Context ctx = createContextByGpuIdx(0);
#else
  Context ctx = createContext();
#endif

  constexpr size_t N = 1024;
  std::vector<int8_t> inputData(N), outputData(N);
  for (size_t i = 0; i < N; ++i) {
    // Values between -128 and 127.
    inputData[i] = static_cast<int8_t>((i % 256) - 128);
  }

  // Create an input tensor using the int8_t overload.
  // Under the hood the data is packed into int32_t.
  Tensor inputTensor = createTensor(ctx, Shape{N}, ki8, inputData.data());

  // Create an output tensor of the same shape and the same packed type.
  Tensor outputTensor = createTensor(ctx, Shape{N}, ki8);

  // The copy shader (kCopyKernel) works on directly supported types.
  // Since int8_t is packed into int32_t, pass ki32 as the shader precision.
  Kernel copyKernel =
      createKernel(ctx, {kCopyKernel, 256, ki32},
                   Bindings{inputTensor, outputTensor}, {cdiv(N, 256), 1, 1});
  dispatchKernel(ctx, copyKernel);

  // Retrieve the output from the GPU and unpack it from packed int32_t back
  // to int8_t.
  toCPU(ctx, outputTensor, ki8, outputData.data(), 0);

  // Verify the unpacked data matches the original input.
  for (size_t i = 0; i < N; ++i) {
    assert(inputData[i] == outputData[i]);
  }
  LOG(kDefLog, kInfo, "testCopyShaderPackedUnpack_int8 passed.");
}

void testToCPUUnpack() {
  LOG(kDefLog, kInfo, "Running testToCPUUnpack...");
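The new test's comment notes that int8 data is packed into int32_t under the hood when the tensor is created. A small standalone sketch of that packing, written here as the mirror image of the ki8 unpack loop in gpu.hpp rather than copied from the repository:

#include <cstddef>
#include <cstdint>
#include <vector>

// Pack four int8 values into each int32 in little-endian lane order; the ki8
// branch of toCPU undoes this with (tmp[i / 4] >> ((i % 4) * 8)) & 0xFF.
std::vector<int32_t> packInt8(const std::vector<int8_t> &src) {
  std::vector<int32_t> packed((src.size() + 3) / 4, 0);
  for (size_t i = 0; i < src.size(); ++i) {
    uint32_t byte = static_cast<uint8_t>(src[i]); // 0..255
    uint32_t word = static_cast<uint32_t>(packed[i / 4]);
    word |= byte << ((i % 4) * 8); // place the byte in its lane
    packed[i / 4] = static_cast<int32_t>(word);
  }
  return packed;
}

Round-tripping through a packing like this and the ki8 branch of toCPU reproduces the original bytes, which is what the copy-shader test asserts end to end.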

0 commit comments
