Skip to content

Commit dbaeef0

Browse files
committed
pure wasm (no emscripten) end-to-end test of controlling gpu.js
1 parent 9d2f109 commit dbaeef0

File tree

8 files changed

+274
-34
lines changed

8 files changed

+274
-34
lines changed

experimental/wasm/Makefile

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,46 @@
1+
FLAGS=--target=wasm32 -nostdlib -Wl,--no-entry -Wl,--export-all -Wl,--import-memory -Wl,--allow-undefined -fexceptions -std=c++17 -O3
2+
# FLAGS=--target=wasm32-unknown-unknown -stdlib=libc++ -nostdlib++ -Wl,--no-entry -Wl,--export-all -Wl,--import-memory -Wl,--allow-undefined -fexceptions -std=c++17
3+
14
.PHONY: all clean dump-obj dump-wasm dependencies server
25

3-
all: run.wasm dump-obj dump-wasm
6+
all: build/hello.wasm dump-obj dump-wasm
7+
8+
watch:
9+
ls *.cpp *.h | entr make build/run.wasm
410

5-
# Compile the C++ source file to LLVM IR
6-
run.ll: run.cpp
7-
clang --target=wasm32 -emit-llvm -c -S run.cpp
11+
build/run.wasm: run.cpp Makefile
12+
clang++ $(FLAGS) -o build/run.wasm run.cpp
13+
14+
# cpp -> llvm ir
15+
build/hello.ll: hello.cpp
16+
clang --target=wasm32 -emit-llvm -c -S hello.cpp -o build/hello.ll
817

9-
# Assemble the LLVM IR to a WebAssembly object file
10-
run.o: run.ll
11-
llc -march=wasm32 -filetype=obj run.ll
18+
# llvm ir -> wasm object file
19+
build/hello.o: build/hello.ll
20+
llc -march=wasm32 -filetype=obj build/hello.ll -o build/hello.o
1221

1322
# Disassemble the WebAssembly object file
1423
dump-obj:
15-
wasm-objdump -x run.o
24+
wasm-objdump -x build/hello.o
1625

1726
# Link the WebAssembly object file to a WebAssembly module
18-
# no entry point function
19-
# export all functions
20-
run.wasm: run.o
27+
build/hello.wasm: build/hello.o
2128
wasm-ld \
2229
--no-entry \
2330
--export-all \
24-
-o run.wasm \
25-
run.o
31+
-o build/hello.wasm \
32+
build/hello.o
2633

2734
dump-wasm:
28-
wasm-objdump -x run.wasm
35+
wasm-objdump -x build/hello.wasm
2936

3037
# TODO(avh): this is just a reminder note for now - remove it later
3138
dependencies:
3239
brew install llvm
3340
brew install wabt
3441

3542
server:
36-
python3 -m http.server
43+
python3 -m http.server 8000
3744

3845
clean:
39-
rm -f run.ll run.o run.wasm
46+
rm -f build/hello.ll build/hello.o build/hello.wasm build/run.wasm

experimental/wasm/build/.gitkeep

Whitespace-only changes.

experimental/wasm/favicon.ico

15 KB
Binary file not shown.

experimental/wasm/gpu.js

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
// gpu.js
22

3+
const gpujs = (function() {
4+
35
class Shape {
46
static kMaxRank = 8;
57

68
constructor(...dims) {
79
if (dims.length > Shape.kMaxRank) {
810
throw new Error(`Shape can have at most ${Shape.kMaxRank} dimensions`);
911
}
10-
1112
this.rank = dims.length;
1213

1314
// Initialize data with the provided dimensions
@@ -19,6 +20,7 @@ class Shape {
1920
}
2021
}
2122
}
23+
2224
class Array {
2325
constructor(buffer, usage, size) {
2426
this.buffer = buffer;
@@ -187,6 +189,7 @@ async function createContext() {
187189
}
188190
context.device = await context.adapter.requestDevice();
189191
context.queue = context.device.queue;
192+
console.log("Context created");
190193
return context;
191194
}
192195

@@ -299,7 +302,7 @@ function dispatchKernel(ctx, kernel) {
299302
return ctx.device.queue.onSubmittedWorkDone();
300303
}
301304

302-
async function main() {
305+
async function simpleTest() {
303306
console.log("Starting main");
304307
const ctx = await createContext();
305308

@@ -338,4 +341,36 @@ async function main() {
338341
destroyContext(ctx);
339342
}
340343

341-
main().catch(console.error);
344+
// At the end of the file, return an object with all your exports
345+
return {
346+
Shape,
347+
Array,
348+
Tensor,
349+
TensorView,
350+
Bindings,
351+
Context,
352+
TensorPool,
353+
KernelPool,
354+
KernelCode,
355+
Kernel,
356+
NumType,
357+
size,
358+
sizeBytes,
359+
toString,
360+
replaceAll,
361+
cdiv,
362+
cdivShape,
363+
createContext,
364+
destroyContext,
365+
resetCommandBuffer,
366+
createKernel,
367+
createTensor,
368+
toGPU,
369+
toCPU,
370+
dispatchKernel,
371+
simpleTest,
372+
};
373+
})();
374+
375+
376+
export default gpujs;

experimental/wasm/hello.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Hello world llvm wasm test
2+
3+
extern "C" {
4+
int add(int a, int b) { return a + b; }
5+
int mul(int a, int b) { return a * b; }
6+
int foo(int a, int b) { return a * a + b + 4; }
7+
}

experimental/wasm/index.html

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,77 @@
1-
<!DOCTYPE html>
1+
<!doctype html>
22
<html lang="en">
3-
<head>
4-
<meta charset="UTF-8">
5-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6-
<title>WebGPU Context Creation</title>
7-
</head>
8-
<body>
3+
<head>
4+
<meta charset="UTF-8" />
5+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
6+
<title>gpu.cpp wasm test</title>
7+
</head>
8+
<body>
99
<h1>gpu.js test</h1>
10-
<div id="status">Initializing WebGPU...</div>
11-
<script src="gpu.js"></script>
12-
</body>
10+
<div id="status">gpu.cpp -> wasm test</div>
11+
12+
<script type="module">
13+
import gpujs from "./gpu.js";
14+
15+
let wasmInstance = null;
16+
const memory = new WebAssembly.Memory({ initial: 8192, maximum: 8192 });
17+
18+
function memset(ptr, value, num) {
19+
const view = new Uint8Array(memory.buffer);
20+
view.fill(value, ptr, ptr + num);
21+
}
22+
23+
async function loadWasm() {
24+
const response = await fetch("build/run.wasm");
25+
const bytes = await response.arrayBuffer();
26+
27+
// Create the WebAssembly environment
28+
const env = Object.keys(gpujs).reduce(
29+
(env, key) => {
30+
env[key] = (...args) => {
31+
console.log(`Calling ${key} from WebAssembly`);
32+
return gpujs[key](...args);
33+
};
34+
return env;
35+
},
36+
{
37+
memory: memory,
38+
jsLOG: (messagePtr) => {
39+
console.log("jsLOG called from WebAssembly");
40+
console.log("memory ", memory);
41+
const view = new Uint8Array(memory.buffer);
42+
console.log(
43+
"Memory Buffer Slice: ",
44+
view.slice(messagePtr, messagePtr + 100),
45+
); // Check buffer content
46+
47+
let message = "";
48+
console.log("messagePtr ", messagePtr);
49+
for (let i = messagePtr; view[i] !== 0; i++) {
50+
message += String.fromCharCode(view[i]);
51+
console.log(view[i]);
52+
}
53+
console.log(message);
54+
},
55+
memset,
56+
},
57+
);
58+
59+
const { instance } = await WebAssembly.instantiate(bytes, { env });
60+
// instance.exports.setMemory(memory.buffer.byteOffset); // Pass the memory buffer to the wasm module
61+
return instance;
62+
}
63+
64+
loadWasm()
65+
.then((instance) => {
66+
console.log("WebAssembly module loaded");
67+
instance.exports.main();
68+
})
69+
.catch((error) => {
70+
console.error("Failed to load WebAssembly module:", error);
71+
});
72+
73+
// Make gpujs globally available if needed
74+
window.gpujs = gpujs;
75+
</script>
76+
</body>
1377
</html>

experimental/wasm/run.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
// Hello world llvm wasm test
1+
#include "wasm.h"
22

3-
extern "C" {
4-
int add(int a, int b) { return a + b; }
5-
int mul(int a, int b) { return a * b; }
6-
int foo(int a, int b) { return a * a + b + 4; }
3+
int main() {
4+
// Note: This calls createContext but this doesn't work to obtain the return value
5+
// due to async
6+
// Context* ctx = createContext();
7+
// destroyContext(ctx);
8+
9+
LOG("Hello, World!");
10+
11+
return 0;
712
}

experimental/wasm/wasm.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#ifndef WASM_H
2+
#define WASM_H
3+
4+
// #define WASM_IMPORT __attribute__((import_module("env"),
5+
// import_name("memory"))) #define WASM_IMPORT __attribute__((used))
6+
// __attribute__((visibility("default")))
7+
8+
extern "C" {
9+
10+
// these are normally defined in stdint.h, but we can't include that in wasm
11+
typedef signed char int8_t;
12+
typedef short int16_t;
13+
typedef int int32_t;
14+
typedef long long int64_t;
15+
typedef unsigned char uint8_t;
16+
typedef unsigned short uint16_t;
17+
typedef unsigned int uint32_t;
18+
typedef unsigned long long uint64_t;
19+
typedef unsigned long size_t;
20+
21+
// Opaque handles to js shim objects
22+
typedef struct Shape Shape;
23+
typedef struct Array Array;
24+
typedef struct Tensor Tensor;
25+
typedef struct TensorView TensorView;
26+
typedef struct Bindings Bindings;
27+
typedef struct Context Context;
28+
typedef struct KernelCode KernelCode;
29+
typedef struct Kernel Kernel;
30+
31+
// Enum to match JavaScript NumType
32+
typedef enum { kf16, kf32 } NumType;
33+
34+
// Function declarations that will be implemented in JavaScript
35+
36+
Shape *createShape(int32_t *dims, int32_t rank);
37+
void destroyShape(Shape *shape);
38+
39+
Array *createArray(uint64_t bufferPtr, uint32_t usage, uint64_t size);
40+
void destroyArray(Array *array);
41+
42+
Tensor *createTensor(Array *data, Shape *shape);
43+
void destroyTensor(Tensor *tensor);
44+
45+
TensorView *createTensorView(Tensor *data, uint64_t offset, uint64_t span);
46+
void destroyTensorView(TensorView *view);
47+
48+
Bindings *createBindings(Tensor **tensors, int32_t count);
49+
void destroyBindings(Bindings *bindings);
50+
51+
Context *createContext();
52+
void destroyContext(Context *ctx);
53+
54+
KernelCode *createKernelCode(const char *data, Shape *workgroupSize,
55+
NumType precision);
56+
void destroyKernelCode(KernelCode *code);
57+
58+
Kernel *createKernel(Context *ctx, KernelCode *code, Bindings *dataBindings,
59+
Shape *nWorkgroups, void *params);
60+
void destroyKernel(Kernel *kernel);
61+
62+
uint64_t size(Shape *shape);
63+
uint64_t sizeBytes(NumType type);
64+
65+
char *toString(Shape *shape);
66+
char *toStringInt(int32_t value);
67+
char *toStringNumType(NumType type);
68+
69+
void replaceAll(char *str, const char *from, const char *to);
70+
71+
int32_t cdiv(int32_t n, int32_t d);
72+
Shape *cdivShape(Shape *total, Shape *group);
73+
74+
Tensor *createTensorImpl(Context *ctx, Shape *shape, NumType dtype);
75+
76+
void toGPU(Context *ctx, float *data, Tensor *tensor);
77+
void toCPU(Context *ctx, Tensor *tensor, float *data);
78+
79+
void dispatchKernel(Context *ctx, Kernel *kernel);
80+
81+
void resetCommandBuffer(Context *ctx, Kernel *kernel);
82+
83+
uint8_t *memory;
84+
85+
void jsLOG(uint8_t *messagePtr);
86+
87+
int simpleTest();
88+
89+
} // extern "C"
90+
91+
// Simple bump allocator for now
92+
93+
uint32_t kMemPtr = 0;
94+
95+
uint8_t* wasmMalloc(size_t size) {
96+
uint8_t* ptr = &memory[kMemPtr];
97+
kMemPtr += size;
98+
return ptr;
99+
}
100+
101+
size_t strlen(const char* str) {
102+
size_t len = 0;
103+
while (str[len]) {
104+
len++;
105+
}
106+
return len;
107+
}
108+
109+
void LOG(const char* message) {
110+
size_t len = strlen(message);
111+
uint8_t* start = (wasmMalloc(len));
112+
uint8_t* dest = start;
113+
size_t index = 0;
114+
while (*message) {
115+
*dest = *message;
116+
dest++;
117+
message++;
118+
}
119+
jsLOG(start);
120+
}
121+
122+
#endif // WASM_H

0 commit comments

Comments
 (0)