@@ -757,21 +757,21 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }
 
+// Overload for double: pack each double into a float (losing precision)
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
                            const double *data) {
-  assert(dtype == kf64);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+  assert(dtype == kf64); // unsupported: convert to kf32
+  size_t numElements = size(shape);
+  std::vector<float> packed(numElements);
+  for (size_t i = 0; i < numElements; ++i) {
+    packed[i] = static_cast<float>(data[i]);
+  }
+  return createTensor(ctx, shape, kf32, packed.data());
 }
 
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint8_t *data) {
-  assert(dtype == ku8);
+                           const int32_t *data) {
+  assert(dtype == ki32);
   Tensor tensor =
       createTensor(ctx.pool, ctx.device, shape, dtype,
                    WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
@@ -781,45 +781,55 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }
 
+// Overload for int8_t: pack four 8-bit ints into one 32-bit integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint16_t *data) {
-  assert(dtype == ku16);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int8_t *data) {
+  assert(dtype == ki8); // unsupported: pack into ki32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 3) / 4;
+  std::vector<int32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 4;
+    size_t shift = (i % 4) * 8;
+    // pack as unsigned bits then reinterpret; shader is then responsible for
+    // unpacking
+    packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }
 
+// Overload for int16_t: pack two 16-bit ints into one 32-bit integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint32_t *data) {
-  assert(dtype == ku32);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int16_t *data) {
+  assert(dtype == ki16); // unsupported: pack into ki32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 1) / 2;
+  std::vector<int32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 2;
+    size_t shift = (i % 2) * 16;
+    packed[idx] |= (static_cast<uint16_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }
 
+// Overload for int64_t: pack each 64-bit int into two 32-bit integers
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const uint64_t *data) {
-  assert(dtype == ku64);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const int64_t *data) {
+  assert(dtype == ki64); // unsupported: pack into two ki32s
+  size_t numElements = size(shape);
+  std::vector<int32_t> packed(numElements * 2);
+  for (size_t i = 0; i < numElements; ++i) {
+    int64_t val = data[i];
+    packed[2 * i] = static_cast<int32_t>(val & 0xFFFFFFFF);
+    packed[2 * i + 1] = static_cast<int32_t>((val >> 32) & 0xFFFFFFFF);
+  }
+  return createTensor(ctx, shape, ki32, packed.data());
 }
 
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int64_t *data) {
-  assert(dtype == ki64);
+                           const uint32_t *data) {
+  assert(dtype == ku32);
   Tensor tensor =
       createTensor(ctx.pool, ctx.device, shape, dtype,
                    WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
@@ -829,40 +839,51 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
   return tensor;
 }
 
+// Overload for uint8_t: pack four 8-bit integers into one 32-bit unsigned
+// integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int8_t *data) {
-  assert(dtype == ki8);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint8_t *data) {
+  assert(dtype == ku8); // unsupported: pack into ku32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 3) / 4;
+  std::vector<uint32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 4;
+    size_t shift = (i % 4) * 8;
+    packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }
 
+// Overload for uint16_t: pack two 16-bit integers into one 32-bit unsigned
+// integer
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int16_t *data) {
-  assert(dtype == ki16);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint16_t *data) {
+  assert(dtype == ku16); // unsupported: pack into ku32
+  size_t numElements = size(shape);
+  size_t packedCount = (numElements + 1) / 2;
+  std::vector<uint32_t> packed(packedCount, 0);
+  for (size_t i = 0; i < numElements; ++i) {
+    size_t idx = i / 2;
+    size_t shift = (i % 2) * 16;
+    packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }
 
+// Overload for uint64_t: pack each 64-bit integer into two 32-bit unsigned
+// integers
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
-                           const int32_t *data) {
-  assert(dtype == ki32);
-  Tensor tensor =
-      createTensor(ctx.pool, ctx.device, shape, dtype,
-                   WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
-                       WGPUBufferUsage_CopySrc);
-  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
-                       tensor.data.size);
-  return tensor;
+                           const uint64_t *data) {
+  assert(dtype == ku64); // unsupported: pack into two ku32s
+  size_t numElements = size(shape);
+  std::vector<uint32_t> packed(numElements * 2);
+  for (size_t i = 0; i < numElements; ++i) {
+    uint64_t val = data[i];
+    packed[2 * i] = static_cast<uint32_t>(val & 0xFFFFFFFF);
+    packed[2 * i + 1] = static_cast<uint32_t>(val >> 32);
+  }
+  return createTensor(ctx, shape, ku32, packed.data());
 }
 
 /**
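The packing convention is the same on both sides of the transfer: elements fill each 32-bit word from the low bits upward, and 64-bit values store their low word first. A small standalone sketch (plain C++, not part of this commit) of the same bit arithmetic shows why the signed unpacking in the dtype-aware toCPU overload further down recovers negative values:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Pack two int16_t values into one 32-bit word, low element in the low
  // bits, mirroring the int16_t createTensor overload above.
  int16_t a = -5;
  int16_t b = 300;
  int32_t word = 0;
  word |= (static_cast<uint16_t>(a) << 0);  // bits 0..15
  word |= (static_cast<uint16_t>(b) << 16); // bits 16..31

  // Unpack, mirroring the ki16 branch of the toCPU overload below: masking to
  // 16 bits and casting back to int16_t restores the sign of negative values.
  int16_t a2 = static_cast<int16_t>((word >> 0) & 0xFFFF);
  int16_t b2 = static_cast<int16_t>((word >> 16) & 0xFFFF);
  assert(a2 == -5 && b2 == 300);
  return 0;
}
```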
@@ -1759,6 +1780,117 @@ inline void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data,
   wait(ctx, future);
 }
 
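+/**
+ * @brief Overload of toCPU that converts from the packed device-side
+ * representation produced by the createTensor overloads above back into the
+ * requested dtype (for example, widening kf32 data to double, or unpacking
+ * four 8-bit values from each 32-bit word).
+ */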
+inline void toCPU(Context &ctx, Tensor &tensor, NumType dtype, void *output, size_t sourceOffset = 0) {
+  size_t numElements = size(tensor.shape);
+  switch (dtype) {
+  // These types are directly supported.
+  case kf16:
+  case kf32:
+  case ku32:
+  case ki32:
+    toCPU(ctx, tensor, output, tensor.data.size, sourceOffset);
+    break;
+
+  // For double, the tensor was created by packing doubles into floats.
+  case kf64: {
+    std::vector<float> tmp(numElements);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(float), sourceOffset);
+    double *dst = static_cast<double *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      dst[i] = static_cast<double>(tmp[i]);
+    }
+    break;
+  }
+
+  // For int8_t: four 8-bit ints packed into one int32_t.
+  case ki8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int8_t *dst = static_cast<int8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For int16_t: two 16-bit ints packed into one int32_t.
+  case ki16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int16_t *dst = static_cast<int16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For int64_t: each 64-bit int was packed into two int32_t.
+  case ki64: {
+    std::vector<int32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int64_t *dst = static_cast<int64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      int32_t low = tmp[2 * i];
+      int32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<int64_t>(high) << 32) |
+               (static_cast<uint32_t>(low));
+    }
+    break;
+  }
+
+  // For uint8_t: four 8-bit uints packed into one uint32_t.
+  case ku8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint8_t *dst = static_cast<uint8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For uint16_t: two 16-bit uints packed into one uint32_t.
+  case ku16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint16_t *dst = static_cast<uint16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For uint64_t: each 64-bit unsigned int was packed into two uint32_t.
+  case ku64: {
+    std::vector<uint32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint64_t *dst = static_cast<uint64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      uint32_t low = tmp[2 * i];
+      uint32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
+    }
+    break;
+  }
+
+  default:
+    LOG(kDefLog, kError, "Unsupported dtype in toCPU");
+    break;
+  }
+}
+
 /**
  * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are
  * effectively a convenience wrapper around the WebGPU API call
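For the round trip itself, here is a minimal usage sketch. It is not part of this commit and assumes the usual gpu.hpp entry points (createContext, Shape, Tensor) are available; it exercises the double overloads, where the packed representation is simply one float per element.

```cpp
#include <array>
#include <cstdio>

#include "gpu.hpp"

using namespace gpu;

int main() {
  Context ctx = createContext();

  // Upload: the double overload narrows each value to float and stores the
  // tensor as kf32 on the device.
  std::array<double, 4> in = {1.0, 2.5, -3.25, 4.0};
  Tensor t = createTensor(ctx, Shape{in.size()}, kf64, in.data());

  // Download: the dtype-aware toCPU overload reads the kf32 data back and
  // widens it into the caller's double buffer.
  std::array<double, 4> out{};
  toCPU(ctx, t, kf64, out.data());

  for (double v : out) {
    printf("%f ", v);
  }
  printf("\n");
  return 0;
}
```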