@@ -1780,114 +1780,230 @@ inline void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data,
   wait(ctx, future);
 }
 
-inline void toCPU(Context &ctx, Tensor &tensor, NumType dtype, void *output, size_t sourceOffset = 0) {
+inline void toCPU(Context &ctx, Tensor &tensor, NumType dtype, void *output,
+                  size_t sourceOffset = 0) {
   size_t numElements = size(tensor.shape);
   switch (dtype) {
   // These types are directly supported.
   case kf16:
   case kf32:
   case ku32:
   case ki32:
-    toCPU(ctx, tensor, output, tensor.data.size, sourceOffset);
-    break;
+    toCPU(ctx, tensor, output, tensor.data.size, sourceOffset);
+    break;
 
   // For double, the tensor was created by packing doubles into floats.
   case kf64: {
-    std::vector<float> tmp(numElements);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(float), sourceOffset);
-    double *dst = static_cast<double *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      dst[i] = static_cast<double>(tmp[i]);
-    }
-    break;
+    std::vector<float> tmp(numElements);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(float), sourceOffset);
+    double *dst = static_cast<double *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      dst[i] = static_cast<double>(tmp[i]);
+    }
+    break;
   }
 
   // For int8_t: four 8-bit ints packed into one int32_t.
   case ki8: {
-    size_t packedCount = (numElements + 3) / 4;
-    std::vector<int32_t> tmp(packedCount);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
-    int8_t *dst = static_cast<int8_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      size_t idx = i / 4;
-      size_t shift = (i % 4) * 8;
-      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
-    }
-    break;
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int8_t *dst = static_cast<int8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
   }
 
   // For int16_t: two 16-bit ints packed into one int32_t.
   case ki16: {
-    size_t packedCount = (numElements + 1) / 2;
-    std::vector<int32_t> tmp(packedCount);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
-    int16_t *dst = static_cast<int16_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      size_t idx = i / 2;
-      size_t shift = (i % 2) * 16;
-      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
-    }
-    break;
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int16_t *dst = static_cast<int16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
   }
 
   // For int64_t: each 64-bit int was packed into two int32_t.
   case ki64: {
-    std::vector<int32_t> tmp(numElements * 2);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
-    int64_t *dst = static_cast<int64_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      int32_t low = tmp[2 * i];
-      int32_t high = tmp[2 * i + 1];
-      dst[i] = (static_cast<int64_t>(high) << 32) |
-               (static_cast<uint32_t>(low));
-    }
-    break;
+    std::vector<int32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int64_t *dst = static_cast<int64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      int32_t low = tmp[2 * i];
+      int32_t high = tmp[2 * i + 1];
+      dst[i] =
+          (static_cast<int64_t>(high) << 32) | (static_cast<uint32_t>(low));
+    }
+    break;
   }
 
   // For uint8_t: four 8-bit uints packed into one uint32_t.
   case ku8: {
-    size_t packedCount = (numElements + 3) / 4;
-    std::vector<uint32_t> tmp(packedCount);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
-    uint8_t *dst = static_cast<uint8_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      size_t idx = i / 4;
-      size_t shift = (i % 4) * 8;
-      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
-    }
-    break;
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint8_t *dst = static_cast<uint8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
   }
 
   // For uint16_t: two 16-bit uints packed into one uint32_t.
   case ku16: {
-    size_t packedCount = (numElements + 1) / 2;
-    std::vector<uint32_t> tmp(packedCount);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
-    uint16_t *dst = static_cast<uint16_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      size_t idx = i / 2;
-      size_t shift = (i % 2) * 16;
-      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
-    }
-    break;
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint16_t *dst = static_cast<uint16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
   }
 
   // For uint64_t: each 64-bit unsigned int was packed into two uint32_t.
   case ku64: {
-    std::vector<uint32_t> tmp(numElements * 2);
-    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
-    uint64_t *dst = static_cast<uint64_t *>(output);
-    for (size_t i = 0; i < numElements; ++i) {
-      uint32_t low = tmp[2 * i];
-      uint32_t high = tmp[2 * i + 1];
-      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
-    }
-    break;
+    std::vector<uint32_t> tmp(numElements * 2);
+    toCPU(ctx, tensor, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint64_t *dst = static_cast<uint64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      uint32_t low = tmp[2 * i];
+      uint32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
+    }
+    break;
+  }
+
+  default:
+    LOG(kDefLog, kError, "Unsupported dtype in toCPUUnpack");
+    break;
+  }
+}
+
+inline void toCPU(Context &ctx, WGPUBuffer buffer, NumType dtype, void *output,
+                  size_t numElements, size_t sourceOffset = 0) {
+  switch (dtype) {
+  // Directly supported types.
+  case kf16:
+  case kf32:
+  case ku32:
+  case ki32: {
+    size_t byteSize = numElements * sizeBytes(dtype);
+    toCPU(ctx, buffer, output, byteSize, sourceOffset);
+    break;
+  }
+
+  // For double, the buffer was written as floats.
+  case kf64: {
+    std::vector<float> tmp(numElements);
+    toCPU(ctx, buffer, tmp.data(), numElements * sizeof(float), sourceOffset);
+    double *dst = static_cast<double *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      dst[i] = static_cast<double>(tmp[i]);
+    }
+    break;
+  }
+
+  // For int8_t: four 8-bit ints packed into one int32_t.
+  case ki8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(int32_t), sourceOffset);
+    int8_t *dst = static_cast<int8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<int8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For int16_t: two 16-bit ints packed into one int32_t.
+  case ki16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<int32_t> tmp(packedCount);
+    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(int32_t), sourceOffset);
+    int16_t *dst = static_cast<int16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<int16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For int64_t: each 64-bit int is packed into two int32_t.
+  case ki64: {
+    std::vector<int32_t> tmp(numElements * 2);
+    toCPU(ctx, buffer, tmp.data(), tmp.size() * sizeof(int32_t), sourceOffset);
+    int64_t *dst = static_cast<int64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      int32_t low = tmp[2 * i];
+      int32_t high = tmp[2 * i + 1];
+      dst[i] =
+          (static_cast<int64_t>(high) << 32) | (static_cast<uint32_t>(low));
+    }
+    break;
+  }
+
+  // For uint8_t: four 8-bit uints packed into one uint32_t.
+  case ku8: {
+    size_t packedCount = (numElements + 3) / 4;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(uint32_t),
+          sourceOffset);
+    uint8_t *dst = static_cast<uint8_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 4;
+      size_t shift = (i % 4) * 8;
+      dst[i] = static_cast<uint8_t>((tmp[idx] >> shift) & 0xFF);
+    }
+    break;
+  }
+
+  // For uint16_t: two 16-bit uints packed into one uint32_t.
+  case ku16: {
+    size_t packedCount = (numElements + 1) / 2;
+    std::vector<uint32_t> tmp(packedCount);
+    toCPU(ctx, buffer, tmp.data(), packedCount * sizeof(uint32_t),
+          sourceOffset);
+    uint16_t *dst = static_cast<uint16_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      size_t idx = i / 2;
+      size_t shift = (i % 2) * 16;
+      dst[i] = static_cast<uint16_t>((tmp[idx] >> shift) & 0xFFFF);
+    }
+    break;
+  }
+
+  // For uint64_t: each 64-bit unsigned int packed into two uint32_t.
+  case ku64: {
+    std::vector<uint32_t> tmp(numElements * 2);
+    toCPU(ctx, buffer, tmp.data(), tmp.size() * sizeof(uint32_t), sourceOffset);
+    uint64_t *dst = static_cast<uint64_t *>(output);
+    for (size_t i = 0; i < numElements; ++i) {
+      uint32_t low = tmp[2 * i];
+      uint32_t high = tmp[2 * i + 1];
+      dst[i] = (static_cast<uint64_t>(high) << 32) | low;
+    }
+    break;
   }
 
   default:
-    LOG(kDefLog, kError, "Unsupported dtype in toCPUUnpack");
-    break;
+    LOG(kDefLog, kError, "Unsupported dtype in toCPU (raw buffer override)");
+    break;
   }
 }
 
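Usage note: a minimal sketch, not part of the commit, of how the two overloads above might be called. It assumes the gpu.cpp-style API visible in the diff (createContext, Shape, NumType) plus two assumptions beyond it: that createTensor accepts the packed dtypes such as ki64, and that Tensor's .data member exposes the underlying WGPUBuffer as .buffer.

#include <cstdint>
#include <vector>

#include "gpu.h" // header name assumed

using namespace gpu;

int main() {
  Context ctx = createContext();

  // Hypothetical: a 4-element i64 tensor whose creation path packed each
  // value into two int32 words, as the ki64 case above expects.
  Tensor t = createTensor(ctx, Shape{4}, ki64);

  std::vector<int64_t> host(4);

  // Tensor overload: the element count is derived from t.shape.
  toCPU(ctx, t, ki64, host.data());

  // Raw-buffer overload: the element count must be passed explicitly,
  // since a bare WGPUBuffer carries no shape information.
  toCPU(ctx, t.data.buffer, ki64, host.data(), host.size());
  return 0;
}

One caveat worth knowing: the kf64 path is lossy by construction, since values round-trip through 32-bit floats and keep only about seven significant decimal digits.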
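A further note on the 64-bit reassembly in the ki64 cases: the cast of the low word through uint32_t before the bitwise OR is load-bearing. A standalone snippet, independent of the library, showing what the cast prevents:

#include <cassert>
#include <cstdint>

int main() {
  // The two halves of 0x180000000: high word 1, low word 0x80000000, which
  // reads back as a negative int32_t from the staging vector.
  int32_t high = 1;
  int32_t low = static_cast<int32_t>(0x80000000u);

  // As written in the diff: the low word is zero-extended.
  int64_t ok = (static_cast<int64_t>(high) << 32) | static_cast<uint32_t>(low);
  assert(ok == 0x180000000LL);

  // Without the unsigned cast the low word sign-extends, and copies of its
  // sign bit clobber the entire high half of the result.
  int64_t bad = (static_cast<int64_t>(high) << 32) | static_cast<int64_t>(low);
  assert(bad != ok); // bad == 0xFFFFFFFF80000000 here
  return 0;
}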