Commit 36fe730

need to pack and unpack unsupported types
1 parent 81bfe07 commit 36fe730

2 files changed: +337, -86 lines

gpu.hpp

Lines changed: 198 additions & 66 deletions
@@ -757,21 +757,21 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }

+// Overload for double: pack each double into a float (losing precision)
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
                            const double *data) {
-  assert(dtype == kf64);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+  assert(dtype == kf64); // unsupported: convert to kf32
+  size_t numElements = size(shape);
+  std::vector<float> packed(numElements);
+  for (size_t i = 0; i < numElements; ++i) {
+    packed[i] = static_cast<float>(data[i]);
+  }
+  return createTensor(ctx, shape, kf32, packed.data());
 }

 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint8_t *data) {
-  assert(dtype == ku8);
+                           const int32_t *data) {
+  assert(dtype == ki32);
   Tensor tensor =
       createTensor(ctx.pool, ctx.device, shape, dtype,
                    WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
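
Note on the kf64 overload above: narrowing each double to a float is lossy for values that need more than roughly 7 significant decimal digits, so a kf64 tensor round-trips only to float precision. A minimal standalone sketch of that narrowing (plain C++, independent of gpu.hpp):

#include <cstdio>

int main() {
  double d = 0.1234567890123456789;     // more precision than a float carries
  float f = static_cast<float>(d);      // what the kf64 overload stores per element
  double back = static_cast<double>(f); // what the kf64 branch of toCPU gives back
  std::printf("original:   %.17g\nround-trip: %.17g\n", d, back);
  return 0;
}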
@@ -781,45 +781,55 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }

+// Overload for int8_t: pack four 8-bit ints into one 32-bit integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint16_t *data) {
-  assert(dtype == ku16);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int8_t *data) {
+  assert(dtype == ki8); // unsupported: pack into ki32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 3) / 4;
+  std::vector<int32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 4;
+    size_t shift = (i % 4) * 8;
+    // pack as unsigned bits then reinterpret; shader is then responsible for
+    // unpacking
+    packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }

+// Overload for int16_t: pack two 16-bit ints into one 32-bit integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint32_t *data) {
-  assert(dtype == ku32);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int16_t *data) {
+  assert(dtype == ki16); // unsupported: pack into ki32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 1) / 2;
+  std::vector<int32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 2;
+    size_t shift = (i % 2) * 16;
+    packed[idx] |= (static_cast<uint16_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }

+// Overload for int64_t: pack each 64-bit int into two 32-bit integers
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint64_t *data) {
-  assert(dtype == ku64);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int64_t *data) {
+  assert(dtype == ki64); // unsupported: pack into two ki32s
+  size_t numElements = size(shape);
+  std::vector<int32_t> packed(numElements * 2);
+  for (size_t i = 0; i < numElements; ++i) {
+    int64_t val = data[i];
+    packed[2 * i] = static_cast<int32_t>(val & 0xFFFFFFFF);
+    packed[2 * i + 1] = static_cast<int32_t>((val >> 32) & 0xFFFFFFFF);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }

 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int64_t *data) {
-  assert(dtype == ki64);
+                           const uint32_t *data) {
+  assert(dtype == ku32);
   Tensor tensor =
       createTensor(ctx.pool, ctx.device, shape, dtype,
                    WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
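
The signed packing overloads above store element i in byte (i % 4), or half-word (i % 2), of word (i / 4), low byte first, keeping the raw two's-complement bits; the shader side, and the new dtype-aware toCPU overload further down, is expected to mask and sign-extend when unpacking. A standalone sketch of that byte layout and its inverse, using hypothetical helpers pack_bytes/unpack_bytes that are not part of gpu.hpp (shown with uint32_t words; the ki8 overload stores the same bit pattern in int32_t words):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack: element i goes into byte (i % 4) of word (i / 4), low byte first.
std::vector<uint32_t> pack_bytes(const std::vector<int8_t> &in) {
  std::vector<uint32_t> out((in.size() + 3) / 4, 0u);
  for (size_t i = 0; i < in.size(); ++i) {
    out[i / 4] |= static_cast<uint32_t>(static_cast<uint8_t>(in[i])) << ((i % 4) * 8);
  }
  return out;
}

// Unpack: mask out the byte, then cast back to int8_t to restore the sign.
std::vector<int8_t> unpack_bytes(const std::vector<uint32_t> &in, size_t n) {
  std::vector<int8_t> out(n);
  for (size_t i = 0; i < n; ++i) {
    out[i] = static_cast<int8_t>((in[i / 4] >> ((i % 4) * 8)) & 0xFFu);
  }
  return out;
}

int main() {
  std::vector<int8_t> v{-1, 2, -3, 4, 5};
  assert(unpack_bytes(pack_bytes(v), v.size()) == v); // round-trips, negatives included
  return 0;
}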
@@ -829,40 +839,51 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }

+// Overload for uint8_t: pack four 8-bit integers into one 32-bit unsigned
+// integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int8_t *data) {
-  assert(dtype == ki8);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint8_t *data) {
+  assert(dtype == ku8); // unsupported: pack into ku32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 3) / 4;
+  std::vector<uint32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 4;
+    size_t shift = (i % 4) * 8;
+    packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }

+// Overload for uint16_t: pack two 16-bit integers into one 32-bit unsigned
+// integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int16_t *data) {
-  assert(dtype == ki16);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint16_t *data) {
+  assert(dtype == ku16); // unsupported: pack into ku32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 1) / 2;
+  std::vector<uint32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 2;
+    size_t shift = (i % 2) * 16;
+    packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }

+// Overload for uint64_t: pack each 64-bit integer into two 32-bit unsigned
+// integers
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int32_t *data) {
-  assert(dtype == ki32);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint64_t *data) {
+  assert(dtype == ku64); // unsupported: pack into two ku32s
+  size_t numElements = size(shape);
+  std::vector<uint32_t> packed(numElements * 2);
+  for (size_t i = 0; i < numElements; ++i) {
+    uint64_t val = data[i];
+    packed[2 * i] = static_cast<uint32_t>(val & 0xFFFFFFFF);
+    packed[2 * i + 1] = static_cast<uint32_t>(val >> 32);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }

 /**
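
The 64-bit overloads (ki64 earlier, ku64 here) split each value into a low and a high 32-bit word, low word first; the new toCPU overload in the next hunk reassembles them in the same order. A minimal standalone sketch of that split and reassembly for the unsigned case (the signed path uses the same layout and relies on two's-complement conversions):

#include <cassert>
#include <cstdint>

int main() {
  // Split one 64-bit value into two 32-bit words, low word first,
  // mirroring the ku64 overload above.
  uint64_t val = 0x1122334455667788ull;
  uint32_t low = static_cast<uint32_t>(val & 0xFFFFFFFF);
  uint32_t high = static_cast<uint32_t>(val >> 32);
  assert(low == 0x55667788u && high == 0x11223344u);

  // Reassemble the way the ku64 branch of the new toCPU overload does.
  uint64_t back = (static_cast<uint64_t>(high) << 32) | low;
  assert(back == val);
  return 0;
}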
@@ -1759,6 +1780,117 @@ inline void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data,
   wait(ctx, future);
 }

+inline void toCPU(Context &ctx, Tensor &tensor, NumType dtype, void *output, size_t sourceOffset = 0) {
+  size_t numElements = size(tensor.shape);
+  switch (dtype) {
+  // These types are directly supported.
+  case kf16:
+  case kf32:
+  case ku32:
+  case ki32:
+    toCPU(ctx, tensor, output, tensor.data.size, sourceOffset);
+    break;
+
+  // For double, the tensor was created by packing doubles into floats.
+  case kf64: {
+    std::vector<float> tmp(numElements);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(float), sourceOffset);
+    double *dst = static_cast<double*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      dst[i] = static_cast<double>(tmp[i]);
+    }
+    break;
+  }
+
+  // For int8_t: four 8-bit ints packed into one int32_t.
+  case ki8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int8_t *dst = static_cast<int8_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For int16_t: two 16-bit ints packed into one int32_t.
+  case ki16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int16_t *dst = static_cast<int16_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For int64_t: each 64-bit int was packed into two int32_t.
+  case ki64: {
+    std::vector<int32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int64_t *dst = static_cast<int64_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      int32_t low = tmp[2 * i];
+      int32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<int64_t>(high) << 32) |
+               (static_cast<uint32_t>(low));
+    }
+    break;
+  }
+
+  // For uint8_t: four 8-bit uints packed into one uint32_t.
+  case ku8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint8_t *dst = static_cast<uint8_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For uint16_t: two 16-bit uints packed into one uint32_t.
+  case ku16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint16_t *dst = static_cast<uint16_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For uint64_t: each 64-bit unsigned int was packed into two uint32_t.
+  case ku64: {
+    std::vector<uint32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint64_t *dst = static_cast<uint64_t*>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      uint32_t low = tmp[2 * i];
+      uint32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
+    }
+    break;
+  }
+
+  default:
+    LOG(kDefLog, kError, "Unsupported dtype in toCPUUnpack");
+    break;
+  }
+}
+
 /**
  * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are
  * effectively a convenience wrapper around the WebGPU API call
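
Putting the two halves together, data in an unsupported type can now be uploaded with the matching createTensor overload and read back with the dtype-aware toCPU overload. A hedged usage sketch, assuming gpu.hpp's usual createContext() entry point and that its symbols live in namespace gpu; the element count is a multiple of four so the packed words exactly fill the ki32 storage the overload allocates:

#include <array>
#include <cstdint>
#include <cstdio>

#include "gpu.hpp"

using namespace gpu; // assumption: gpu.hpp's symbols live in namespace gpu

int main() {
  Context ctx = createContext(); // assumption: the usual gpu.hpp entry point

  // Upload int8 data; the new ki8 overload packs it into ki32 storage.
  std::array<int8_t, 8> in{-1, 2, -3, 4, -5, 6, -7, 8};
  Tensor t = createTensor(ctx, Shape{in.size()}, ki8, in.data());

  // Read it back; the new toCPU overload unpacks the ki8 elements again.
  std::array<int8_t, 8> out{};
  toCPU(ctx, t, ki8, out.data());

  for (int8_t v : out) {
    std::printf("%d ", static_cast<int>(v));
  }
  std::printf("\n");
  return 0;
}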
