Skip to content

Commit f5f6414

Browse files
committed
fix types and emscripten race condition
1 parent a581f72 commit f5f6414

File tree

1 file changed

+83
-24
lines changed

1 file changed

+83
-24
lines changed

gpu.hpp

Lines changed: 83 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -625,28 +625,46 @@ struct Context {
625625

626626
~Context() {
627627
LOG(kDefLog, kTrace, "Destroying context");
628+
629+
#ifdef __EMSCRIPTEN__
630+
// For WebAssembly, do NOT call processEvents during destruction
631+
// This prevents "Asyncify cannot be done during or after runtime exits"
632+
LOG(kDefLog, kTrace,
633+
"WebAssembly context destruction - skipping processEvents");
634+
#endif
635+
628636
if (queue) {
629637
wgpuQueueRelease(queue);
638+
queue = nullptr;
630639
} else {
631-
LOG(kDefLog, kTrace, "Queue is null");
640+
LOG(kDefLog, kTrace, "Queue already null");
632641
}
642+
633643
if (device) {
634644
wgpuDeviceRelease(device);
635-
processEvents(instance);
645+
device = nullptr;
636646
} else {
637-
LOG(kDefLog, kTrace, "Device is null");
647+
LOG(kDefLog, kTrace, "Device already null");
638648
}
649+
639650
if (adapter) {
640651
wgpuAdapterRelease(adapter);
641-
processEvents(instance);
652+
adapter = nullptr;
642653
} else {
643-
LOG(kDefLog, kTrace, "Adapter is null");
654+
LOG(kDefLog, kTrace, "Adapter already null");
644655
}
656+
645657
if (instance) {
658+
#ifndef __EMSCRIPTEN__
659+
// Only call processEvents on native platforms during cleanup
660+
processEvents(instance);
661+
#endif
646662
wgpuInstanceRelease(instance);
663+
instance = nullptr;
647664
} else {
648-
LOG(kDefLog, kTrace, "Instance is null");
665+
LOG(kDefLog, kTrace, "Instance already null");
649666
}
667+
650668
LOG(kDefLog, kTrace, "Context destroyed");
651669
}
652670
};
@@ -983,21 +1001,63 @@ inline void check(bool condition, const char *message,
9831001
* devDescriptor); WGPUDevice device = wait(ctx, deviceFuture);
9841002
* @endcode
9851003
*/
1004+
#ifdef __EMSCRIPTEN__
1005+
// Per-translation-unit flag (it is `static` in this header) that wait() uses to skip overlapping async waits in WebAssembly builds
1006+
static std::atomic<bool> asyncOperationInProgress{false};
1007+
#endif
1008+
9861009
template <typename T> T wait(Context &ctx, std::future<T> &f) {
9871010
#ifdef __EMSCRIPTEN__
988-
// Poll until the future is ready.
989-
while (f.wait_for(std::chrono::milliseconds(0)) !=
990-
std::future_status::ready) {
991-
// Yield control to the JS event loop.
992-
emscripten_sleep(1);
1011+
// Check if another async operation is in progress
1012+
if (asyncOperationInProgress.load()) {
1013+
LOG(kDefLog, kWarn,
1014+
"wait(): Another async operation in progress, skipping wait");
1015+
if constexpr (std::is_void_v<T>) {
1016+
return; // For void functions, just return
1017+
} else {
1018+
return T{}; // Return default-constructed value for non-void types
1019+
}
1020+
}
1021+
1022+
// Set the flag before starting async operation
1023+
asyncOperationInProgress.store(true);
1024+
1025+
try {
1026+
// Poll until the future is ready
1027+
while (f.wait_for(std::chrono::milliseconds(0)) !=
1028+
std::future_status::ready) {
1029+
emscripten_sleep(1);
1030+
}
1031+
1032+
// Handle void vs non-void return types
1033+
if constexpr (std::is_void_v<T>) {
1034+
f.get(); // Just call get() without storing result
1035+
asyncOperationInProgress.store(false);
1036+
return; // void return
1037+
} else {
1038+
T result = f.get();
1039+
asyncOperationInProgress.store(false);
1040+
return result;
1041+
}
1042+
1043+
} catch (...) {
1044+
asyncOperationInProgress.store(false);
1045+
throw;
9931046
}
994-
return f.get();
9951047
#else
1048+
// Native: pump WebGPU instance events until the future is ready
9961049
while (f.wait_for(std::chrono::milliseconds(0)) !=
9971050
std::future_status::ready) {
9981051
wgpuInstanceProcessEvents(ctx.instance);
9991052
}
1000-
return f.get();
1053+
1054+
// Handle void vs non-void for native too
1055+
if constexpr (std::is_void_v<T>) {
1056+
f.get();
1057+
return;
1058+
} else {
1059+
return f.get();
1060+
}
10011061
#endif
10021062
}
10031063

@@ -2097,10 +2157,10 @@ inline void toGPU(Context &ctx, const int8_t *data, WGPUBuffer buffer,
20972157
size_t packedCount = (numElements + 3) / 4;
20982158
std::vector<int32_t> packed(packedCount, 0);
20992159
for (size_t i = 0; i < numElements; ++i) {
2100-
size_t idx = i / 4;
2101-
size_t shift = (i % 4) * 8;
2102-
packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
2103-
// LOG(kDefLog, kInfo, "toGPU: %d %d %d", data[i], packed[idx], idx);
2160+
size_t idx = i / 4;
2161+
size_t shift = (i % 4) * 8;
2162+
packed[idx] |= (static_cast<uint8_t>(data[i]) << shift);
2163+
// LOG(kDefLog, kInfo, "toGPU: %d %d %d", data[i], packed[idx], idx);
21042164
}
21052165
toGPU(ctx, packed.data(), buffer, packedCount * sizeof(int32_t));
21062166
}
@@ -2139,9 +2199,9 @@ inline void toGPU(Context &ctx, const uint8_t *data, WGPUBuffer buffer,
21392199
size_t packedCount = (numElements + 3) / 4;
21402200
std::vector<uint32_t> packed(packedCount, 0);
21412201
for (size_t i = 0; i < numElements; ++i) {
2142-
size_t idx = i / 4;
2143-
size_t shift = (i % 4) * 8;
2144-
packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
2202+
size_t idx = i / 4;
2203+
size_t shift = (i % 4) * 8;
2204+
packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
21452205
}
21462206
toGPU(ctx, packed.data(), buffer, packedCount * sizeof(uint32_t));
21472207
}
@@ -2154,9 +2214,9 @@ inline void toGPU(Context &ctx, const uint16_t *data, WGPUBuffer buffer,
21542214
size_t packedCount = (numElements + 1) / 2;
21552215
std::vector<uint32_t> packed(packedCount, 0);
21562216
for (size_t i = 0; i < numElements; ++i) {
2157-
size_t idx = i / 2;
2158-
size_t shift = (i % 2) * 16;
2159-
packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
2217+
size_t idx = i / 2;
2218+
size_t shift = (i % 2) * 16;
2219+
packed[idx] |= (static_cast<uint32_t>(data[i]) << shift);
21602220
}
21612221
toGPU(ctx, packed.data(), buffer, packedCount * sizeof(uint32_t));
21622222
}
@@ -2798,7 +2858,6 @@ inline std::future<void> dispatchKernelAsync(Context &ctx, Kernel &kernel) {
27982858
workDoneCallbackInfo.userdata1 = reinterpret_cast<void *>(promise);
27992859
workDoneCallbackInfo.userdata2 = nullptr;
28002860

2801-
// IMPORTANT: Pass the address of the callback info structure.
28022861
wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo);
28032862

28042863
return future;

0 commit comments

Comments
 (0)