diff --git a/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html
new file mode 100644
index 000000000000..83c13009529e
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html
@@ -0,0 +1,37 @@
+
+
+
+
+ BY apply_matrix WebGPU dispatch test
+
+
+
+ BY apply_matrix WebGPU dispatch test
+ Query params: ?n=N&validate-n=N&reps=R
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts
new file mode 100644
index 000000000000..e05603a24dfb
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts
@@ -0,0 +1,464 @@
+///
+// BY apply_matrix WebGPU dispatch test. Mounted standalone (no SRS, no MSM
+// pipeline). Reads `?n=N&validate-n=N&reps=R`, generates seeded random
+// (Mat, f, g, d, e) records by first sampling random u64 (f_lo, g_lo, delta)
+// and running the TS `Wasm9x29.divsteps` to get a valid Mat (which
+// guarantees matrix-entry bounds), then synthesising plausible random f, g,
+// d, e BigIntBY states, runs `by_apply_matrix_fg` and `by_apply_matrix_de`
+// on the GPU per input record, and validates each output against the TS
+// `Wasm9x29.applyMatrix` reference.
+//
+// Safety. The shader's only loops are inherited from `by_apply_matrix_fg`
+// and `by_apply_matrix_de`, both bounded by the WGSL `const BY_NUM_LIMBS = 9u`.
+// One thread per input record, `n` capped at 2^20 to keep memory pressure
+// manageable.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import {
+ Wasm9x29,
+ fromBigint as byFromBigint,
+ makeZero as byMakeZero,
+ normalise as byNormalise,
+ N as BY_N,
+} from '../../src/msm_webgpu/cuzk/bernstein_yang.js';
+
+interface SampleSummary {
+ reps: number;
+ msSamples: number[];
+ msMedian: number;
+ msMin: number;
+ msMax: number;
+ applyMatrixPerSec: number;
+}
+interface BenchResult {
+ validateOk: boolean;
+ mismatches: string[];
+ timing: SampleSummary | null;
+}
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error';
+ params: {
+ n: number;
+ validateN: number;
+ reps: number;
+ } | null;
+ result: BenchResult | null;
+ error: string | null;
+ log: string[];
+}
+
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ result: null,
+ error: null,
+ log: [],
+};
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+const $log = document.getElementById('log') as HTMLDivElement;
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+ const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
+ const span = document.createElement('div');
+ span.className = cls;
+ span.textContent = msg;
+ $log.appendChild(span);
+ benchState.log.push(`[${level}] ${msg}`);
+ console.log(`[bench-apply-matrix] ${msg}`);
+}
+
+const N_MAX = 1 << 20;
+const INPUT_STRIDE_U32 = 44;
+const OUTPUT_STRIDE_I32 = 36;
+
+function makeRng(seed: number): () => number {
+  // 32-bit LCG step: x ← (1664525·x + 1013904223) mod 2^32; a zero seed is remapped to 1.
+  const LCG_MUL = 1664525;
+  const LCG_INC = 1013904223;
+  let x = seed >>> 0 === 0 ? 1 : seed >>> 0;
+  return () => (x = (Math.imul(x, LCG_MUL) + LCG_INC) >>> 0);
+}
+
+function randomU64(rng: () => number): { lo: number; hi: number } {
+  // Two consecutive draws: the first is the low word, the second the high word.
+  const words = [rng(), rng()].map(w => w >>> 0);
+  return { lo: words[0], hi: words[1] };
+}
+
+// Draw a signed limb from [-2^28, 2^28): keep the low 29 bits of one RNG
+// output (a value in [0, 2^29)) and re-centre it by subtracting 2^28. Every
+// result therefore satisfies |limb| < 2^29.
+function randomSignedLimb29(rng: () => number): bigint {
+  const LOW_29_BITS = 0x1fffffff;
+  const HALF_RANGE = 1n << 28n;
+  const unsigned = BigInt(rng() & LOW_29_BITS);
+  return unsigned - HALF_RANGE;
+}
+
+// Fill a fresh BigIntBY with `randomSignedLimb29` draws (each limb in
+// [-2^28, 2^28)), pass it through `byNormalise`, and return it.
+function randomNormalisedBigIntBY(rng: () => number): bigint[] {
+  const limbs = byMakeZero();
+  let idx = 0;
+  while (idx < BY_N) {
+    limbs[idx++] = randomSignedLimb29(rng);
+  }
+  byNormalise(limbs);
+  return limbs;
+}
+
+function median(xs: number[]): number {
+  const sorted = [...xs].sort((lhs, rhs) => lhs - rhs);
+  // Upper-middle element; an empty sample set has no median (→ NaN).
+  return sorted.length > 0 ? sorted[sorted.length >> 1] : NaN;
+}
+
+// Combine a GPU (lo, hi) i32 pair into one two's-complement signed 64-bit bigint.
+function pairToSignedBig(lo: number, hi: number): bigint {
+  // `>>> 0` reinterprets each i32 as its unsigned 32-bit bit pattern.
+  const lowWord = BigInt(lo >>> 0);
+  const highWord = BigInt(hi >>> 0);
+  const raw = (highWord << 32n) | lowWord;
+  return raw >= 1n << 63n ? raw - (1n << 64n) : raw;
+}
+
+async function createPipeline(
+ device: GPUDevice,
+ code: string,
+ cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+ const module = device.createShaderModule({ code });
+ const info = await module.getCompilationInfo();
+ let hasError = false;
+ for (const msg of info.messages) {
+ const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
+ if (msg.type === 'error') {
+ console.error(line);
+ hasError = true;
+ } else {
+ console.warn(line);
+ }
+ }
+ if (hasError) {
+ throw new Error(`WGSL compile failed for ${cacheKey}`);
+ }
+ const layout = device.createBindGroupLayout({
+ entries: [
+ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+ { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
+ ],
+ });
+ const pipeline = await device.createComputePipelineAsync({
+ layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+ compute: { module, entryPoint: 'main' },
+ });
+ return { pipeline, layout };
+}
+
+// Serialise a signed i64 bigint into (lo: u32, hi: u32) — the on-wire
+// representation of Mat fields in the input buffer. `BigInt.asUintN(64, v)`
+// yields the two's-complement bit pattern as a non-negative bigint, so no
+// sign fix-up is needed and both halves already fit in an unsigned 32-bit word.
+function i64ToLoHi(v: bigint): { lo: number; hi: number } {
+  const u = BigInt.asUintN(64, v);
+  const lo = Number(u & 0xffffffffn);
+  const hi = Number(u >> 32n);
+  return { lo, hi };
+}
+
+// Generate `n` seeded (Mat, f, g, d, e) records, run the apply_matrix bench
+// shader over them, and check the first `validateN` outputs against the TS
+// `Wasm9x29.applyMatrix` reference. Only when validation passes does it time
+// `reps` submit→onSubmittedWorkDone round trips of the same dispatch.
+// Returns the validation verdict, up to five formatted mismatch reports, and
+// the timing summary (`null` when validation failed). The GPU buffers created
+// here are destroyed before either return.
+async function runBench(
+  device: GPUDevice,
+  sm: ShaderManager,
+  n: number,
+  validateN: number,
+  reps: number,
+): Promise<BenchResult> {
+  log('info', `building inputs (n=${n}, validate-n=${validateN}, reps=${reps})`);
+
+  // Seeded inputs.
+  const rng = makeRng(0xa57e1234);
+  const inputsBuf = new Uint32Array(n * INPUT_STRIDE_U32);
+
+  // Per-input ground truth: f', g', d', e' after applyMatrix.
+  const expectedOutputs = new Int32Array(validateN * OUTPUT_STRIDE_I32);
+
+  for (let i = 0; i < n; i++) {
+    // 1. Sample random u64 f_lo, g_lo, and a small random delta; run the TS
+    //    Wasm9x29.divsteps to produce a valid Mat (matrix entries satisfy
+    //    the |entry| <= 2^58 bound that apply_matrix downstream expects).
+    const f64 = randomU64(rng);
+    const g64 = randomU64(rng);
+    const delta = (rng() % 1025) - 512;
+    const fBig = BigInt(f64.lo >>> 0) | (BigInt(f64.hi >>> 0) << 32n);
+    const gBig = BigInt(g64.lo >>> 0) | (BigInt(g64.hi >>> 0) << 32n);
+    const { mat } = Wasm9x29.divsteps(BigInt(delta), fBig, gBig);
+
+    // 2. Synthesise random signed-canonical f, g, d, e BigIntBY states.
+    const fState = randomNormalisedBigIntBY(rng);
+    const gState = randomNormalisedBigIntBY(rng);
+    const dState = randomNormalisedBigIntBY(rng);
+    const eState = randomNormalisedBigIntBY(rng);
+
+    // 3. Encode into the input buffer. Mat fields go (u, v, q, r, u_hi, v_hi,
+    //    q_hi, r_hi). Each is the low 32 / high 32 of the i64 stored in
+    //    bigint form.
+    const base = i * INPUT_STRIDE_U32;
+    const matPairs = [mat.u, mat.v, mat.q, mat.r].map(i64ToLoHi);
+    matPairs.forEach((pair, k) => {
+      inputsBuf[base + k] = pair.lo;
+      inputsBuf[base + 4 + k] = pair.hi;
+    });
+    for (let j = 0; j < BY_N; j++) {
+      // Limbs are bigints; `BigInt.asIntN(32, …)` keeps the low 32 bits as a
+      // signed value, and the Uint32Array store wraps negatives to their
+      // two's-complement bit pattern, which the shader is expected to read as i32.
+      inputsBuf[base + 8 + j] = Number(BigInt.asIntN(32, fState[j])) | 0;
+      inputsBuf[base + 17 + j] = Number(BigInt.asIntN(32, gState[j])) | 0;
+      inputsBuf[base + 26 + j] = Number(BigInt.asIntN(32, dState[j])) | 0;
+      inputsBuf[base + 35 + j] = Number(BigInt.asIntN(32, eState[j])) | 0;
+    }
+
+    // 4. Compute reference for validate-n.
+    if (i < validateN) {
+      const fRef = fState.slice();
+      const gRef = gState.slice();
+      const dRef = dState.slice();
+      const eRef = eState.slice();
+      const pRef = byFromBigint(Wasm9x29.P);
+      Wasm9x29.applyMatrix(mat, fRef, gRef, dRef, eRef, pRef, Wasm9x29.P_INV);
+      // applyMatrix output is signed; serialise as i32.
+      const outBase = i * OUTPUT_STRIDE_I32;
+      for (let j = 0; j < BY_N; j++) {
+        expectedOutputs[outBase + 0 + j] = Number(BigInt.asIntN(32, fRef[j])) | 0;
+        expectedOutputs[outBase + 9 + j] = Number(BigInt.asIntN(32, gRef[j])) | 0;
+        expectedOutputs[outBase + 18 + j] = Number(BigInt.asIntN(32, dRef[j])) | 0;
+        expectedOutputs[outBase + 27 + j] = Number(BigInt.asIntN(32, eRef[j])) | 0;
+      }
+    }
+  }
+
+  // Shader & pipeline.
+  const WORKGROUP_SIZE = 64;
+  const code = sm.gen_apply_matrix_bench_shader(WORKGROUP_SIZE);
+  const cacheKey = `apply-matrix-bench-wg${WORKGROUP_SIZE}`;
+  log('info', `compiling shader (${code.length} chars)`);
+  (window as unknown as Record<string, unknown>)[`__shader`] = code;
+  const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+  // Buffers.
+  const inputsGpu = device.createBuffer({
+    size: inputsBuf.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(inputsGpu, 0, inputsBuf);
+
+  const outBytes = n * OUTPUT_STRIDE_I32 * 4;
+  const outBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+  });
+
+  const uniformBytes = new ArrayBuffer(16);
+  const uniformView = new Uint32Array(uniformBytes);
+  uniformView[0] = n;
+  uniformView[1] = 0;
+  const uniformBuf = device.createBuffer({
+    size: 16,
+    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(uniformBuf, 0, uniformBytes);
+
+  const bindGroup = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: inputsGpu } },
+      { binding: 1, resource: { buffer: outBuf } },
+      { binding: 2, resource: { buffer: uniformBuf } },
+    ],
+  });
+
+  const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE);
+  log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`);
+
+  // Warmup pass.
+  {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+  }
+  log('info', 'warmup OK');
+
+  // Validation pass — read back outputs.
+  const stagingBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+  });
+  {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes);
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await stagingBuf.mapAsync(GPUMapMode.READ);
+  }
+  const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0);
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+
+  const outI32 = new Int32Array(outBytesCopy);
+  const mismatches: string[] = [];
+  let validateOk = true;
+  for (let i = 0; i < validateN; i++) {
+    const base = i * OUTPUT_STRIDE_I32;
+    let pairOk = true;
+    for (let j = 0; j < OUTPUT_STRIDE_I32; j++) {
+      if (outI32[base + j] !== expectedOutputs[base + j]) {
+        pairOk = false;
+        break;
+      }
+    }
+    if (!pairOk) {
+      validateOk = false;
+      if (mismatches.length < 5) {
+        const got = Array.from(outI32.slice(base, base + OUTPUT_STRIDE_I32));
+        const want = Array.from(expectedOutputs.slice(base, base + OUTPUT_STRIDE_I32));
+        const labelLimbs = (arr: number[], offset: number, name: string) =>
+          `${name}: [${arr.slice(offset, offset + 9).join(', ')}]`;
+        mismatches.push(
+          `pair[${i}]:\n` +
+            ` expected:\n` +
+            ` ${labelLimbs(want, 0, "f'")}\n` +
+            ` ${labelLimbs(want, 9, "g'")}\n` +
+            ` ${labelLimbs(want, 18, "d'")}\n` +
+            ` ${labelLimbs(want, 27, "e'")}\n` +
+            ` actual:\n` +
+            ` ${labelLimbs(got, 0, "f'")}\n` +
+            ` ${labelLimbs(got, 9, "g'")}\n` +
+            ` ${labelLimbs(got, 18, "d'")}\n` +
+            ` ${labelLimbs(got, 27, "e'")}`,
+        );
+      }
+    }
+  }
+
+  if (!validateOk) {
+    log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`);
+    for (const m of mismatches) log('err', m);
+    inputsGpu.destroy();
+    outBuf.destroy();
+    uniformBuf.destroy();
+    return { validateOk: false, mismatches, timing: null };
+  }
+  log('ok', `VALIDATION OK (${validateN} pairs)`);
+
+  // Timed reps.
+  const msSamples: number[] = [];
+  for (let rep = 0; rep < reps; rep++) {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    const t0 = performance.now();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    const t1 = performance.now();
+    msSamples.push(t1 - t0);
+  }
+  const msMed = median(msSamples);
+  const msMin = Math.min(...msSamples);
+  const msMax = Math.max(...msSamples);
+  const applyMatrixPerSec = n / (msMed / 1000);
+  log(
+    'ok',
+    `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms apply_matrix_calls/s=${applyMatrixPerSec.toExponential(3)} (n=${n})`,
+  );
+
+  inputsGpu.destroy();
+  outBuf.destroy();
+  uniformBuf.destroy();
+
+  return {
+    validateOk: true,
+    mismatches: [],
+    timing: { reps, msSamples, msMedian: msMed, msMin, msMax, applyMatrixPerSec },
+  };
+}
+
+function parseParams(): { n: number; validateN: number; reps: number } {
+ const qp = new URLSearchParams(window.location.search);
+ const n = parseInt(qp.get('n') ?? '1024', 10);
+ const validateN = parseInt(qp.get('validate-n') ?? String(Math.min(64, n)), 10);
+ const reps = parseInt(qp.get('reps') ?? '3', 10);
+ if (!Number.isFinite(n) || n <= 0 || n > N_MAX) {
+ throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`);
+ }
+ if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+ throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`);
+ }
+ if (!Number.isFinite(reps) || reps <= 0 || reps > 100) {
+ throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`);
+ }
+ return { n, validateN, reps };
+}
+
+async function main() {
+ try {
+ if (!('gpu' in navigator)) {
+ throw new Error('navigator.gpu missing — WebGPU not available');
+ }
+ const params = parseParams();
+ benchState.params = params;
+ log('info', `params: n=${params.n} validate-n=${params.validateN} reps=${params.reps}`);
+
+ benchState.state = 'running';
+ const device = await get_device();
+ log('info', 'WebGPU device acquired');
+
+ // ShaderManager is keyed on chunk_size / input_size for the MSM pipeline;
+ // for this bench we only need its constants table so values are arbitrary.
+ const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false);
+
+ const result = await runBench(device, sm, params.n, params.validateN, params.reps);
+ benchState.result = result;
+
+ benchState.state = 'done';
+ if (result.validateOk) {
+ log('ok', 'bench done');
+ } else {
+ log('err', 'bench done with validation failures');
+ }
+ } catch (e) {
+ const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+ log('err', `FATAL: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ }
+}
+
+main().catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html
new file mode 100644
index 000000000000..dcb1507d12ef
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html
@@ -0,0 +1,37 @@
+
+
+
+
+ Batch-affine EC add amortisation bench (WebGPU)
+
+
+
+ Batch-affine EC add amortisation bench (WebGPU)
+ Query params: ?reps=R
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts
new file mode 100644
index 000000000000..1667518702fa
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts
@@ -0,0 +1,423 @@
+///
+// Standalone single-dispatch WebGPU benchmark for batch-affine EC addition.
+// Sweeps BATCH_SIZE in {32, 64, 128, 256, 512, 1024, 2048} at fixed
+// TOTAL_PAIRS = 65536 to find the sweet spot where amortising the single
+// `fr_inv_by_a` per batch stops beating the loss of GPU thread occupancy.
+//
+// SAFETY: NO MSM pipeline touched. The shader has only compile-time-const
+// loop bounds (BS, TPB, NUM_WORDS). Single dispatch per measurement, no
+// recursion. Total VRAM under 40 MB.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+
+// TOTAL_PAIRS defaults to 65536 (2^16). Overridable via ?total=N for
+// smoke-tests on small slices (NOT used in the headline sweep, but
+// useful for "does this dispatch return at all" runs at N=1024).
+const DEFAULT_TOTAL_PAIRS = 1 << 16; // 65536
+let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS;
+// BATCH_SIZES is also overridable via ?sizes=A,B,C for the smoke-test
+// where we only want to run one size.
+const DEFAULT_BATCH_SIZES = [32, 64, 128, 256, 512, 1024, 2048] as const;
+let BATCH_SIZES: readonly number[] = DEFAULT_BATCH_SIZES;
+
+// (batch_size, tpb, per_thread_count) table. Per the bench spec:
+//   B=32: TPB=32, BS=1 · B=64: TPB=64, BS=1 · B=128+ (multiples of 64): TPB=64, BS=B/64.
+// Any other size (reachable via `?sizes=`) would give a fractional BS, so it throws.
+function tpbFor(batchSize: number): { tpb: number; bs: number } {
+  if (batchSize === 32) return { tpb: 32, bs: 1 };
+  const isMultipleOf64 = Number.isInteger(batchSize) && batchSize >= 64 && batchSize % 64 === 0;
+  if (!isMultipleOf64) throw new Error(`unsupported batch_size=${batchSize}: expected 32 or a multiple of 64`);
+  return { tpb: 64, bs: batchSize / 64 };
+}
+
+const NUM_LIMBS_U32 = 20;
+const WORD_SIZE_U32 = 13;
+const W_U32 = 1n << BigInt(WORD_SIZE_U32);
+const MASK_U32 = W_U32 - 1n;
+
+function bigintToLimbsU32(v: bigint): number[] {
+ const limbs: number[] = new Array(NUM_LIMBS_U32);
+ let x = v;
+ for (let i = 0; i < NUM_LIMBS_U32; i++) {
+ limbs[i] = Number(x & MASK_U32);
+ x >>= BigInt(WORD_SIZE_U32);
+ }
+ return limbs;
+}
+
+// Recombine NUM_LIMBS_U32 limbs into a bigint: limb i is shifted left by WORD_SIZE_U32·i bits and OR-ed in.
+function limbsU32ToBigint(limbs: ArrayLike<number>): bigint {
+  const shift = BigInt(WORD_SIZE_U32);
+  let v = 0n;
+  for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) v = (v << shift) | BigInt(limbs[i] >>> 0);
+  return v;
+}
+
+// Seeded 32-bit LCG (Numerical Recipes constants), same seeding convention as the other dev-page benches.
+function makeRng(seed: number): () => number {
+  let current = seed >>> 0;
+  if (current === 0) current = 1;
+  return function next(): number {
+    current = (Math.imul(current, 1664525) + 1013904223) >>> 0;
+    return current;
+  };
+}
+
+// Rejection-sample a value in [0, p): assemble ceil(bitlen/8) bytes, mask to
+// bitlen bits, retry until the candidate is below p. Each byte is the TOP byte
+// of an RNG word — the low byte of a mod-2^32 LCG repeats every 256 draws.
+function randomBelow(p: bigint, rng: () => number): bigint {
+  const bitlen = p.toString(2).length;
+  while (true) {
+    let v = 0n;
+    for (let i = 0; i < bitlen; i += 8) v = (v << 8n) | BigInt(rng() >>> 24);
+    v &= (1n << BigInt(bitlen)) - 1n;
+    if (v < p) return v;
+  }
+}
+
+function median(xs: number[]): number {
+  // Upper-middle order statistic of a sorted copy; NaN when there are no samples.
+  const ordered = Array.from(xs).sort((p, q) => p - q);
+  return ordered.length === 0 ? NaN : ordered[Math.trunc(ordered.length / 2)];
+}
+
+interface PerSizeResult {
+ batch_size: number;
+ tpb: number;
+ bs: number;
+ num_wgs: number;
+ total_threads: number;
+ median_ms: number;
+ min_ms: number;
+ max_ms: number;
+ ns_per_pair: number;
+ samples_ms: number[];
+ validated: boolean;
+}
+
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error';
+ params: { reps: number } | null;
+ results: PerSizeResult[];
+ error: string | null;
+ log: string[];
+}
+
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+};
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+const $log = document.getElementById('log') as HTMLDivElement;
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+ const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
+ const span = document.createElement('div');
+ span.className = cls;
+ span.textContent = msg;
+ $log.appendChild(span);
+ benchState.log.push(`[${level}] ${msg}`);
+ console.log(`[bench-batch-affine] ${msg}`);
+}
+
+async function createPipeline(
+ device: GPUDevice,
+ code: string,
+ cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+ const module = device.createShaderModule({ code });
+ const info = await module.getCompilationInfo();
+ let hasError = false;
+ for (const msg of info.messages) {
+ const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`;
+ if (msg.type === 'error') {
+ console.error(line);
+ hasError = true;
+ } else {
+ console.warn(line);
+ }
+ }
+ if (hasError) {
+ throw new Error(`WGSL compile failed for ${cacheKey}`);
+ }
+ const layout = device.createBindGroupLayout({
+ entries: [
+ { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+ { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+ { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+ ],
+ });
+ const pipeline = await device.createComputePipelineAsync({
+ layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+ compute: { module, entryPoint: 'main' },
+ });
+ return { pipeline, layout };
+}
+
+// Build inputs ONCE per batch_size (deterministic via LCG seeded with
+// 0xdeadbeef + batch_size). Generates `TOTAL_PAIRS * 4` BigInts holding
+// random Mont-form values: pair k = [P_k.x, P_k.y, Q_k.x, Q_k.y]. The
+// algebra is garbage (these aren't on-curve points), but the shader
+// produces SOMETHING and the dispatch shape is what we're timing.
+function buildInputs(batchSize: number, R: bigint, p: bigint): ArrayBuffer {
+ const rng = makeRng(0xdeadbeef + batchSize);
+ const bytes = new ArrayBuffer(TOTAL_PAIRS * 4 * NUM_LIMBS_U32 * 4);
+ const buf = new Uint32Array(bytes);
+ for (let k = 0; k < TOTAL_PAIRS; k++) {
+ // Generate four field elements per pair, then Montgomery-form them so
+ // the shader's `montgomery_product` calls produce in-Mont-form
+ // outputs. We must also ensure P.x != Q.x (delta != 0) per pair, since
+ // batch-inverse on a zero delta is undefined and would NaN the run.
+ let pxMont: bigint, qxMont: bigint;
+ do {
+ const pxCan = randomBelow(p, rng);
+ pxMont = (pxCan * R) % p;
+ const qxCan = randomBelow(p, rng);
+ qxMont = (qxCan * R) % p;
+ } while (pxMont === qxMont);
+ const pyMont = (randomBelow(p, rng) * R) % p;
+ const qyMont = (randomBelow(p, rng) * R) % p;
+ const coords = [pxMont, pyMont, qxMont, qyMont];
+ const base = k * 4 * NUM_LIMBS_U32;
+ for (let c = 0; c < 4; c++) {
+ const limbs = bigintToLimbsU32(coords[c]);
+ const off = base + c * NUM_LIMBS_U32;
+ for (let j = 0; j < NUM_LIMBS_U32; j++) buf[off + j] = limbs[j];
+ }
+ }
+ return bytes;
+}
+
+// Time one batch size: compile its shader, upload freshly built inputs, run a
+// warmup dispatch plus a non-zero sanity readback of pair 0, then time `reps`
+// submit→idle dispatches. Throws if the sanity readback is all zeros.
+async function runOne(
+  device: GPUDevice,
+  sm: ShaderManager,
+  batchSize: number,
+  reps: number,
+  R: bigint,
+  p: bigint,
+): Promise<PerSizeResult> {
+  const { tpb, bs } = tpbFor(batchSize);
+  const numWgs = TOTAL_PAIRS / batchSize;
+  const totalThreads = numWgs * tpb;
+  log(
+    'info',
+    `=== batch_size=${batchSize}: TPB=${tpb} BS=${bs} num_WGs=${numWgs} total_threads=${totalThreads}`,
+  );
+
+  // Compile shader.
+  const code = sm.gen_bench_batch_affine_shader(batchSize, tpb);
+  const cacheKey = `bench-batch-affine-B${batchSize}-T${tpb}`;
+  log('info', `compiling shader (${code.length} chars)`);
+  (window as unknown as Record<string, unknown>)[`__shader_${batchSize}`] = code;
+  const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+  // Build inputs.
+  const inputsAB = buildInputs(batchSize, R, p);
+
+  // Buffers.
+  const inputBytes = inputsAB.byteLength; // TOTAL_PAIRS * 4 * 80 = 21 MB
+  const inputsBuf = device.createBuffer({
+    size: inputBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(inputsBuf, 0, inputsAB);
+
+  const prefixBytes = TOTAL_PAIRS * NUM_LIMBS_U32 * 4; // ~5.2 MB
+  const prefixBuf = device.createBuffer({
+    size: prefixBytes,
+    usage: GPUBufferUsage.STORAGE,
+  });
+
+  const outputBytes = TOTAL_PAIRS * 2 * NUM_LIMBS_U32 * 4; // ~10.5 MB
+  const outputsBuf = device.createBuffer({
+    size: outputBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
+  });
+  const gpuBuffers = [inputsBuf, prefixBuf, outputsBuf];
+
+  const bindGroup = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: inputsBuf } },
+      { binding: 1, resource: { buffer: prefixBuf } },
+      { binding: 2, resource: { buffer: outputsBuf } },
+    ],
+  });
+
+  // Warmup + small validation: confirm the first pair's R.x/R.y are
+  // nonzero. We DON'T check algebraic correctness — the inputs aren't on
+  // an elliptic curve. We just need the shader to do SOMETHING and not
+  // hang or produce all zeros.
+  {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWgs, 1, 1);
+    pass.end();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+  }
+  log('info', 'warmup dispatch returned');
+
+  // Sanity readback: copy first pair's R.x to host and confirm non-zero.
+  const sanityBytes = 2 * NUM_LIMBS_U32 * 4;
+  const stagingBuf = device.createBuffer({
+    size: sanityBytes,
+    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+  });
+  {
+    const encoder = device.createCommandEncoder();
+    encoder.copyBufferToBuffer(outputsBuf, 0, stagingBuf, 0, sanityBytes);
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await stagingBuf.mapAsync(GPUMapMode.READ);
+  }
+  const sanityBytesCopy = stagingBuf.getMappedRange(0, sanityBytes).slice(0);
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+  const sanityU32 = new Uint32Array(sanityBytesCopy);
+  const rx = limbsU32ToBigint(sanityU32.subarray(0, NUM_LIMBS_U32));
+  const ry = limbsU32ToBigint(sanityU32.subarray(NUM_LIMBS_U32, 2 * NUM_LIMBS_U32));
+  if (rx === 0n && ry === 0n) {
+    log('err', `SANITY FAIL @ B=${batchSize}: first pair R.x and R.y both zero`);
+    gpuBuffers.forEach(b => b.destroy());
+    throw new Error(`sanity fail at batch_size=${batchSize}`);
+  }
+  log('ok', `sanity OK: pair[0].R.x=0x${rx.toString(16).slice(0, 16)}... R.y=0x${ry.toString(16).slice(0, 16)}...`);
+
+  // Timed reps. Each rep includes a fresh encoder so we measure
+  // submit→idle (the shape the user cares about). Match bench-fr-inv.ts.
+  const samples: number[] = [];
+  for (let r = 0; r < reps; r++) {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWgs, 1, 1);
+    pass.end();
+    const t0 = performance.now();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    const t1 = performance.now();
+    samples.push(t1 - t0);
+  }
+  const med = median(samples);
+  const mn = Math.min(...samples);
+  const mx = Math.max(...samples);
+  const nsPerPair = (med * 1e6) / TOTAL_PAIRS;
+
+  log(
+    'ok',
+    `B=${batchSize}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)}`,
+  );
+
+  gpuBuffers.forEach(b => b.destroy());
+
+  return {
+    batch_size: batchSize,
+    tpb,
+    bs,
+    num_wgs: numWgs,
+    total_threads: totalThreads,
+    median_ms: med,
+    min_ms: mn,
+    max_ms: mx,
+    ns_per_pair: nsPerPair,
+    samples_ms: samples,
+    validated: true,
+  };
+}
+
+// Parse ?reps, ?total and ?sizes. Side effect: overwrites the module-level
+// TOTAL_PAIRS / BATCH_SIZES. The range and divisibility checks also cover the
+// default batch sizes, so `?total=1000` alone cannot yield fractional workgroups.
+function parseParams(): { reps: number } {
+  const qp = new URLSearchParams(window.location.search);
+  const reps = parseInt(qp.get('reps') ?? '5', 10);
+  if (!Number.isFinite(reps) || reps <= 0 || reps > 50) {
+    throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`);
+  }
+  const totalStr = qp.get('total');
+  if (totalStr !== null) {
+    const total = parseInt(totalStr, 10);
+    if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) {
+      throw new Error(`?total must be in (0, 2^20], got ${totalStr}`);
+    }
+    TOTAL_PAIRS = total;
+  }
+  const sizes = qp.get('sizes')?.split(',').map(s => parseInt(s, 10)) ?? BATCH_SIZES;
+  for (const s of sizes) {
+    if (!Number.isFinite(s) || s <= 0 || s > 4096) {
+      throw new Error(`?sizes entries must be in (0, 4096], got ${s}`);
+    }
+    if (TOTAL_PAIRS % s !== 0) {
+      throw new Error(`?sizes entry ${s} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`);
+    }
+  }
+  BATCH_SIZES = sizes;
+  return { reps };
+}
+
+async function main() {
+ try {
+ if (!('gpu' in navigator)) {
+ throw new Error('navigator.gpu missing — WebGPU not available');
+ }
+ const params = parseParams();
+ benchState.params = params;
+ log('info', `params: reps=${params.reps} TOTAL_PAIRS=${TOTAL_PAIRS}`);
+
+ benchState.state = 'running';
+ const device = await get_device();
+ log('info', 'WebGPU device acquired');
+
+ const p = BN254_BASE_FIELD;
+ const miscParams = compute_misc_params(p, WORD_SIZE_U32);
+ if (miscParams.num_words !== NUM_LIMBS_U32) {
+ throw new Error(`expected num_words=${NUM_LIMBS_U32}, got ${miscParams.num_words}`);
+ }
+ const R = miscParams.r;
+
+ const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false);
+
+ for (const B of BATCH_SIZES) {
+ try {
+ const r = await runOne(device, sm, B, params.reps, R, p);
+ benchState.results.push(r);
+ } catch (e) {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `B=${B} failed: ${msg} — STOPPING sweep at first failure`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ return;
+ }
+ }
+
+ benchState.state = 'done';
+ log('ok', 'all batches done');
+ } catch (e) {
+ const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+ log('err', `FATAL: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+ }
+}
+
+main().catch(e => {
+ const msg = e instanceof Error ? e.message : String(e);
+ log('err', `unhandled: ${msg}`);
+ benchState.state = 'error';
+ benchState.error = msg;
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html
new file mode 100644
index 000000000000..71ee4b64ff3d
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html
@@ -0,0 +1,37 @@
+
+
+
+
+ BY divsteps WebGPU dispatch test
+
+
+
+ BY divsteps WebGPU dispatch test
+ Query params: ?n=N&validate-n=N&reps=R
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts
new file mode 100644
index 000000000000..92e2f15cbf33
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts
@@ -0,0 +1,398 @@
+///
+// BY divsteps WebGPU dispatch test. Mounted standalone (no SRS, no MSM
+// pipeline). Reads `?n=N&validate-n=N&reps=R`, generates seeded random
+// (f_lo, g_lo, delta) tuples, runs the `by_divsteps` shader once per
+// thread, validates each output against the TS `Wasm9x29.divsteps`
+// reference, and reports timing via `window.__bench`.
+//
+// Safety. The shader's only loop is bounded by the WGSL `const BY_BATCH = 58u`
+// (in bigint_by.template.wgsl). The dispatch is one thread per input tuple,
+// `n` capped at 2^23. Inputs are u64 (for f_lo, g_lo) and i32 (for delta)
+// — by_divsteps is variable-time over branches but always exits after exactly
+// 58 iterations.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { Wasm9x29 } from '../../src/msm_webgpu/cuzk/bernstein_yang.js';
+
+ // Timing statistics for the timed dispatch loop of the divsteps bench.
+ interface SampleSummary {
+ // Number of timed repetitions.
+ reps: number;
+ // Wall-clock duration of each repetition, in milliseconds.
+ msSamples: number[];
+ // Median, minimum and maximum of `msSamples`.
+ msMedian: number;
+ msMin: number;
+ msMax: number;
+ // Throughput: `n` divided by the median duration expressed in seconds.
+ divstepsPerSec: number;
+ }
+ // Outcome of one `runBench` call.
+ interface BenchResult {
+ // True when every validated tuple matched the TS reference.
+ validateOk: boolean;
+ // Human-readable descriptions of the first few mismatches (empty on success).
+ mismatches: string[];
+ // Timing statistics; null when validation failed and no timed reps were run.
+ timing: SampleSummary | null;
+ }
+ // Bench status object published on `window.__bench`.
+ interface BenchState {
+ // Lifecycle: starts at 'boot'; `main` moves it to 'running', then 'done' or 'error'.
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Parsed query parameters; null until `main` has parsed them.
+ params: {
+ n: number;
+ validateN: number;
+ reps: number;
+ } | null;
+ // Value returned by `runBench`; null until it returns.
+ result: BenchResult | null;
+ // Message of the failure that put the bench in the 'error' state, if any.
+ error: string | null;
+ // Every message passed to `log`, prefixed with its level.
+ log: string[];
+ }
+
+ // Single mutable status object for this page; `main`, `log` and the error
+ // handlers update it in place.
+ const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ result: null,
+ error: null,
+ log: [],
+ };
+ // Expose the same object as `window.__bench` so it can be read from outside
+ // this module.
+ (window as unknown as { __bench: BenchState }).__bench = benchState;
+
+ // Log container, looked up once at module load. The cast hides the fact that
+ // `getElementById` returns null when the page has no `#log` element.
+ const $log = document.getElementById('log') as HTMLDivElement;
+
+ /**
+  * Record one message in the `#log` element (when present), in
+  * `benchState.log`, and on the console.
+  *
+  * @param level Severity; `ok`, `err` and `warn` are also used as the row's CSS class.
+  * @param msg   Text to record.
+  */
+ function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+   // Guarded: `log` is called from the error handlers, so a throw on a missing
+   // container would prevent `benchState.state` from ever reaching 'error'.
+   if ($log) {
+     const row = document.createElement('div');
+     row.className = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
+     row.textContent = msg;
+     $log.appendChild(row);
+   }
+   benchState.log.push(`[${level}] ${msg}`);
+   console.log(`[bench-divsteps] ${msg}`);
+ }
+
+ // Largest value `parseParams` accepts for `?n` (2^23 input tuples).
+ const N_MAX = 1 << 23;
+
+ // Deterministic 32-bit linear congruential generator using the Numerical
+ // Recipes multiplier and increment, the same pattern as bench-field-mul.ts.
+ // The spec mandates seed 0xb33fb33f for this bench. A seed of zero is
+ // replaced by 1. Every call advances the state and returns it as a u32.
+ function makeRng(seed: number): () => number {
+   const MULTIPLIER = 1664525;
+   const INCREMENT = 1013904223;
+   const start = seed >>> 0;
+   let current = start === 0 ? 1 : start;
+   return function next(): number {
+     current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+     return current;
+   };
+ }
+
+ // Draw a random u64 as two u32 halves. The generator is called twice; the
+ // first draw becomes `lo` and the second `hi` (object-literal properties are
+ // evaluated left to right).
+ function randomU64(rng: () => number): { lo: number; hi: number } {
+   return { lo: rng() >>> 0, hi: rng() >>> 0 };
+ }
+
+ // Draw a delta from the 1025 integers in [-512, 512]: one generator output
+ // reduced modulo 1025, then shifted down by 512.
+ function randomDelta(rng: () => number): number {
+   const SPAN = 1025;
+   const OFFSET = 512;
+   const draw = rng();
+   return (draw % SPAN) - OFFSET;
+ }
+
+ // Middle element of `xs` in ascending order. For an even count this is the
+ // upper of the two middle elements (no averaging). Returns NaN for an empty
+ // array. The input array is not modified.
+ function median(xs: number[]): number {
+   const count = xs.length;
+   if (count === 0) {
+     return NaN;
+   }
+   const ordered = [...xs];
+   ordered.sort((lhs, rhs) => lhs - rhs);
+   return ordered[Math.floor(count / 2)];
+ }
+
+ // Combine a GPU output pair (low word, high word) into one signed 64-bit
+ // bigint. Each word is first taken as an unsigned 32-bit value, the two are
+ // joined as hi:lo, and the 64-bit pattern is read as two's complement so it
+ // can be compared with the TS reference's bigint matrix entries.
+ function pairToSignedBig(lo: number, hi: number): bigint {
+   const lowWord = BigInt(lo >>> 0);
+   const highWord = BigInt(hi >>> 0);
+   return BigInt.asIntN(64, (highWord << 32n) | lowWord);
+ }
+
+ /**
+  * Compile `code` and build the compute pipeline used by this bench.
+  *
+  * Every compilation message is echoed to the console (errors via
+  * `console.error`, everything else via `console.warn`). The bind group layout
+  * has four compute-visible buffer bindings: 0 and 1 read-only storage,
+  * 2 storage, 3 uniform.
+  *
+  * @param device   Device that owns the shader module, layout and pipeline.
+  * @param code     WGSL source; its compute entry point must be named `main`.
+  * @param cacheKey Label used only in console messages and the thrown error.
+  * @returns The pipeline together with its bind group layout.
+  * @throws Error when the compilation info contains at least one error message.
+  */
+ async function createPipeline(
+   device: GPUDevice,
+   code: string,
+   cacheKey: string,
+ ): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+   const module = device.createShaderModule({ code });
+   const { messages } = await module.getCompilationInfo();
+   let errorCount = 0;
+   for (const diag of messages) {
+     const text = `[shader ${cacheKey}] ${diag.type}: ${diag.message} (line ${diag.lineNum}, col ${diag.linePos})`;
+     if (diag.type !== 'error') {
+       console.warn(text);
+       continue;
+     }
+     console.error(text);
+     errorCount += 1;
+   }
+   if (errorCount > 0) {
+     throw new Error(`WGSL compile failed for ${cacheKey}`);
+   }
+   // Binding index is the position in this list.
+   const bufferKinds = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+   const layout = device.createBindGroupLayout({
+     entries: bufferKinds.map((type, binding) => ({
+       binding,
+       visibility: GPUShaderStage.COMPUTE,
+       buffer: { type },
+     })),
+   });
+   const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] });
+   const pipeline = await device.createComputePipelineAsync({
+     layout: pipelineLayout,
+     compute: { module, entryPoint: 'main' },
+   });
+   return { pipeline, layout };
+ }
+
+ /**
+  * Run the `by_divsteps` bench: build seeded inputs, validate the first
+  * `validateN` GPU outputs against `Wasm9x29.divsteps`, then time `reps`
+  * dispatches.
+  *
+  * @param device    WebGPU device used for every buffer, pipeline and submission.
+  * @param sm        Shader manager that renders the bench shader source.
+  * @param n         Number of input tuples; the dispatch covers `n` threads.
+  * @param validateN Number of leading tuples checked against the TS reference.
+  * @param reps      Number of timed dispatches.
+  * @returns Validation outcome plus timing; `timing` is null when validation fails.
+  * @throws Error when the shader fails to compile, or when `n` needs more
+  *         workgroups than the device allows in one dimension.
+  */
+ async function runBench(
+   device: GPUDevice,
+   sm: ShaderManager,
+   n: number,
+   validateN: number,
+   reps: number,
+ ): Promise<BenchResult> {
+   log('info', `building inputs (n=${n}, validate-n=${validateN}, reps=${reps})`);
+
+   // The dispatch is one-dimensional, so the workgroup count must fit the
+   // device's per-dimension limit. Checked before anything is allocated: an
+   // oversized dispatch is reported as a WebGPU validation error rather than a
+   // thrown exception, so the awaits below would not surface it.
+   const WORKGROUP_SIZE = 64;
+   const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE);
+   const maxWorkgroups = device.limits.maxComputeWorkgroupsPerDimension;
+   if (numWorkgroups > maxWorkgroups) {
+     throw new Error(
+       `n=${n} needs ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads, but the device allows at most ${maxWorkgroups} per dimension`,
+     );
+   }
+
+   // Generate seeded inputs and compute the TS reference in lockstep. The
+   // reference arrays only ever hold the first `validateN` entries, so they
+   // are sized by `validateN` rather than `n`.
+   const rng = makeRng(0xb33fb33f);
+   const inputsFg = new Uint32Array(n * 4);
+   const inputsDelta = new Int32Array(n);
+   const expectedU = new BigInt64Array(validateN);
+   const expectedV = new BigInt64Array(validateN);
+   const expectedQ = new BigInt64Array(validateN);
+   const expectedR = new BigInt64Array(validateN);
+   const expectedDelta = new Int32Array(validateN);
+   for (let i = 0; i < n; i++) {
+     // Draw order (f, g, delta) fixes the seeded stream; do not reorder.
+     const f = randomU64(rng);
+     const g = randomU64(rng);
+     const d = randomDelta(rng);
+     inputsFg[i * 4 + 0] = f.lo;
+     inputsFg[i * 4 + 1] = f.hi;
+     inputsFg[i * 4 + 2] = g.lo;
+     inputsFg[i * 4 + 3] = g.hi;
+     inputsDelta[i] = d;
+     if (i < validateN) {
+       // Reassemble u64 bigints for the TS reference; delta is signed i32.
+       const fBig = BigInt(f.lo >>> 0) | (BigInt(f.hi >>> 0) << 32n);
+       const gBig = BigInt(g.lo >>> 0) | (BigInt(g.hi >>> 0) << 32n);
+       const { mat, delta: deltaOut } = Wasm9x29.divsteps(BigInt(d), fBig, gBig);
+       expectedU[i] = mat.u;
+       expectedV[i] = mat.v;
+       expectedQ[i] = mat.q;
+       expectedR[i] = mat.r;
+       // The TS port returns delta as bigint (the C++ i64 view). For our
+       // i32 carrier on the GPU this is fine — under BATCH=58 inner ops,
+       // delta changes by at most 58 per call, so for |delta_in| <= 512 the
+       // result fits well inside i32.
+       expectedDelta[i] = Number(deltaOut);
+     }
+   }
+
+   // Shader & pipeline.
+   const code = sm.gen_divsteps_bench_shader(WORKGROUP_SIZE);
+   const cacheKey = `divsteps-bench-wg${WORKGROUP_SIZE}`;
+   log('info', `compiling shader (${code.length} chars)`);
+   // Keep the rendered shader reachable from the page as `window.__shader`.
+   (window as unknown as Record<string, unknown>)[`__shader`] = code;
+   const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+   // Buffers.
+   const inputsFgBuf = device.createBuffer({
+     size: inputsFg.byteLength,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(inputsFgBuf, 0, inputsFg);
+   const inputsDeltaBuf = device.createBuffer({
+     size: inputsDelta.byteLength,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(inputsDeltaBuf, 0, inputsDelta);
+   // Nine 32-bit words per tuple: four low words, four high words, then delta
+   // (see the read-back indexing below).
+   const outBytes = n * 9 * 4;
+   const outBuf = device.createBuffer({
+     size: outBytes,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+   });
+   const uniformBytes = new ArrayBuffer(16);
+   const uniformView = new Uint32Array(uniformBytes);
+   uniformView[0] = n;
+   uniformView[1] = 0;
+   const uniformBuf = device.createBuffer({
+     size: 16,
+     usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(uniformBuf, 0, uniformBytes);
+
+   const bindGroup = device.createBindGroup({
+     layout,
+     entries: [
+       { binding: 0, resource: { buffer: inputsFgBuf } },
+       { binding: 1, resource: { buffer: inputsDeltaBuf } },
+       { binding: 2, resource: { buffer: outBuf } },
+       { binding: 3, resource: { buffer: uniformBuf } },
+     ],
+   });
+
+   log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`);
+
+   // Warmup pass.
+   {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+   }
+   log('info', 'warmup OK');
+
+   // Validation pass — read back outputs.
+   const stagingBuf = device.createBuffer({
+     size: outBytes,
+     usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+   });
+   {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes);
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+     await stagingBuf.mapAsync(GPUMapMode.READ);
+   }
+   // Copy out of the mapped range before unmapping; the mapped view is not
+   // usable afterwards.
+   const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0);
+   stagingBuf.unmap();
+   stagingBuf.destroy();
+
+   const outI32 = new Int32Array(outBytesCopy);
+   const mismatches: string[] = [];
+   let validateOk = true;
+   for (let i = 0; i < validateN; i++) {
+     const base = i * 9;
+     const u = pairToSignedBig(outI32[base + 0], outI32[base + 4]);
+     const v = pairToSignedBig(outI32[base + 1], outI32[base + 5]);
+     const q = pairToSignedBig(outI32[base + 2], outI32[base + 6]);
+     const r = pairToSignedBig(outI32[base + 3], outI32[base + 7]);
+     const dOut = outI32[base + 8];
+     const okU = u === expectedU[i];
+     const okV = v === expectedV[i];
+     const okQ = q === expectedQ[i];
+     const okR = r === expectedR[i];
+     const okD = dOut === expectedDelta[i];
+     if (!(okU && okV && okQ && okR && okD)) {
+       validateOk = false;
+       // Only the first five mismatches are described.
+       if (mismatches.length < 5) {
+         const fLo = BigInt(inputsFg[i * 4] >>> 0) | (BigInt(inputsFg[i * 4 + 1] >>> 0) << 32n);
+         const gLo = BigInt(inputsFg[i * 4 + 2] >>> 0) | (BigInt(inputsFg[i * 4 + 3] >>> 0) << 32n);
+         mismatches.push(
+           `pair[${i}]: delta_in=${inputsDelta[i]} f_lo=0x${fLo.toString(16)} g_lo=0x${gLo.toString(16)}\n` +
+             ` expected: u=${expectedU[i]} v=${expectedV[i]} q=${expectedQ[i]} r=${expectedR[i]} delta=${expectedDelta[i]}\n` +
+             ` actual: u=${u} v=${v} q=${q} r=${r} delta=${dOut}`,
+         );
+       }
+     }
+   }
+
+   if (!validateOk) {
+     log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`);
+     for (const m of mismatches) log('err', m);
+     inputsFgBuf.destroy();
+     inputsDeltaBuf.destroy();
+     outBuf.destroy();
+     uniformBuf.destroy();
+     return { validateOk: false, mismatches, timing: null };
+   }
+   log('ok', `VALIDATION OK (${validateN} pairs)`);
+
+   // Timed reps. Each sample covers submit plus the wait for completion.
+   const msSamples: number[] = [];
+   for (let rep = 0; rep < reps; rep++) {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     const t0 = performance.now();
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+     const t1 = performance.now();
+     msSamples.push(t1 - t0);
+   }
+   const msMed = median(msSamples);
+   const msMin = Math.min(...msSamples);
+   const msMax = Math.max(...msSamples);
+   const divstepsPerSec = n / (msMed / 1000);
+   log(
+     'ok',
+     `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms divsteps_calls/s=${divstepsPerSec.toExponential(3)} (n=${n})`,
+   );
+
+   inputsFgBuf.destroy();
+   inputsDeltaBuf.destroy();
+   outBuf.destroy();
+   uniformBuf.destroy();
+
+   return {
+     validateOk: true,
+     mismatches: [],
+     timing: { reps, msSamples, msMedian: msMed, msMin, msMax, divstepsPerSec },
+   };
+ }
+
+ /**
+  * Read and validate `?n`, `?validate-n` and `?reps` from the page URL.
+  *
+  * Defaults: `n=1024`, `validate-n=min(64, n)`, `reps=3`. A value must be a
+  * plain base-10 integer. Anything else — for example `1e6` or `12abc`, which
+  * `parseInt` alone would silently truncate to 1 and 12 — is rejected with
+  * the same range error as an out-of-range value.
+  *
+  * @returns The validated parameters.
+  * @throws Error when a parameter is not an integer inside its allowed range.
+  */
+ function parseParams(): { n: number; validateN: number; reps: number } {
+   const qp = new URLSearchParams(window.location.search);
+   // Yields NaN for anything that is not an optionally signed run of digits,
+   // so the range checks below reject it.
+   const readInt = (key: string, fallback: string): number => {
+     const raw = (qp.get(key) ?? fallback).trim();
+     return /^[+-]?\d+$/.test(raw) ? parseInt(raw, 10) : NaN;
+   };
+   const n = readInt('n', '1024');
+   const validateN = readInt('validate-n', String(Math.min(64, n)));
+   const reps = readInt('reps', '3');
+   if (!Number.isFinite(n) || n <= 0 || n > N_MAX) {
+     throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`);
+   }
+   if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+     throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`);
+   }
+   if (!Number.isFinite(reps) || reps <= 0 || reps > 100) {
+     throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`);
+   }
+   return { n, validateN, reps };
+ }
+
+ /**
+  * Page entry point: parse the query parameters, acquire a device, run the
+  * bench, and publish progress and the result through `benchState`.
+  *
+  * Every failure is caught here, logged as `FATAL`, and recorded as the
+  * 'error' state. A validation failure is not an error: the state still ends
+  * as 'done' and the failure is carried in `benchState.result`.
+  */
+ async function main() {
+   try {
+     const hasWebGpu = 'gpu' in navigator;
+     if (!hasWebGpu) {
+       throw new Error('navigator.gpu missing — WebGPU not available');
+     }
+     const params = parseParams();
+     const { n, validateN, reps } = params;
+     benchState.params = params;
+     log('info', `params: n=${n} validate-n=${validateN} reps=${reps}`);
+
+     benchState.state = 'running';
+     const device = await get_device();
+     log('info', 'WebGPU device acquired');
+
+     // ShaderManager is keyed on chunk_size / input_size for the MSM pipeline;
+     // this bench only needs its constants table (num_words, etc.), so the
+     // values passed here are arbitrary.
+     const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false);
+
+     const result = await runBench(device, sm, n, validateN, reps);
+     benchState.result = result;
+     benchState.state = 'done';
+
+     const passed = result.validateOk;
+     log(passed ? 'ok' : 'err', passed ? 'bench done' : 'bench done with validation failures');
+   } catch (e) {
+     let msg: string;
+     if (e instanceof Error) {
+       msg = `${e.message}\n${e.stack}`;
+     } else {
+       msg = String(e);
+     }
+     log('err', `FATAL: ${msg}`);
+     benchState.state = 'error';
+     benchState.error = msg;
+   }
+ }
+
+ // Last-resort handler: `main` traps its own failures, so this only fires if
+ // that error path itself throws.
+ main().catch((reason: unknown) => {
+   let text: string;
+   if (reason instanceof Error) {
+     text = reason.message;
+   } else {
+     text = String(reason);
+   }
+   log('err', `unhandled: ${text}`);
+   benchState.error = text;
+   benchState.state = 'error';
+ });
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html
new file mode 100644
index 000000000000..df7202bc6b07
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html
@@ -0,0 +1,37 @@
+
+
+
+
+ BN254 field-mul micro-benchmark
+
+
+
+ BN254 field-mul micro-benchmark
+ Query params: ?path=u32|f32&n=N&k=K&validate-n=N&reps=R
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts
new file mode 100644
index 000000000000..d4bab32d56a2
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts
@@ -0,0 +1,921 @@
+///
+// Field-mul micro-benchmark page. Mounted standalone (no SRS, no MSM
+// pipeline). Reads `?path=u32|f32&n=N&k=K&validate-n=N&reps=R`, generates
+// random BN254 base-field pairs, runs `k` chained Montgomery products
+// per thread, validates the first `validate-n` outputs against a host
+// BigInt reference, and reports timing via `window.__bench`.
+//
+// Safety: `k` is capped at 100, `n` at 2^23, both checked before any
+// dispatch. The only loop in either shader is the k-loop with this
+// bound (see field_mul_bench_*.template.wgsl).
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+
+ // Limb representation being benched: 'u32' or 'f32' (see the limb layout
+ // constants below).
+ type Path = 'u32' | 'f32';
+ // Timing statistics for the timed dispatch loop of one path.
+ interface SampleSummary {
+ // Number of timed repetitions.
+ reps: number;
+ // Wall-clock duration of each repetition, in milliseconds.
+ msSamples: number[];
+ // Median, minimum and maximum of `msSamples`.
+ msMedian: number;
+ msMin: number;
+ msMax: number;
+ // Throughput: `n * k` divided by the median duration expressed in seconds.
+ multsPerSec: number;
+ }
+ // Outcome of one `runPath` call.
+ interface PathResult {
+ // Which limb representation this result belongs to.
+ path: Path;
+ // True when every validated output matched the host reference.
+ validateOk: boolean;
+ // Human-readable descriptions of the first few mismatches (empty on success).
+ mismatches: string[];
+ // Timing statistics; null when validation failed and no timed reps were run.
+ timing: SampleSummary | null;
+ }
+ // Bench status object published on `window.__bench`.
+ interface BenchState {
+ // Lifecycle marker; the initial value is 'boot'.
+ // NOTE(review): the code that advances it is outside this view.
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Query parameters in the shape `parseParams` returns (minus `debug`);
+ // null initially.
+ params: {
+ path: Path | 'both';
+ n: number;
+ k: number;
+ validateN: number;
+ reps: number;
+ } | null;
+ // Per-path results; presumably one entry per `runPath` call — confirm
+ // against the caller, which is not visible here.
+ results: PathResult[];
+ // Failure message, or null when none has been recorded.
+ error: string | null;
+ // Every message passed to `log`, prefixed with its level.
+ log: string[];
+ }
+
+ // Single mutable status object for this page, starting in the 'boot' state
+ // with no parameters, results or error.
+ const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ results: [],
+ error: null,
+ log: [],
+ };
+ // Expose the same object as `window.__bench` so it can be read from outside
+ // this module.
+ (window as unknown as { __bench: BenchState }).__bench = benchState;
+
+ // Log container, looked up once at module load. The cast hides the fact that
+ // `getElementById` returns null when the page has no `#log` element.
+ const $log = document.getElementById('log') as HTMLDivElement;
+
+ /**
+  * Record one message in the `#log` element (when present), in
+  * `benchState.log`, and on the console.
+  *
+  * @param level Severity; `ok`, `err` and `warn` are also used as the row's CSS class.
+  * @param msg   Text to record.
+  */
+ function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+   // Guarded: a missing container must not make logging itself throw, or a
+   // failure reported through `log` would be replaced by a null dereference.
+   if ($log) {
+     const row = document.createElement('div');
+     row.className = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : '';
+     row.textContent = msg;
+     $log.appendChild(row);
+   }
+   benchState.log.push(`[${level}] ${msg}`);
+   console.log(`[bench-field-mul] ${msg}`);
+ }
+
+ // Upper bounds `parseParams` enforces on `?n` and `?k`.
+ const N_MAX = 1 << 23;
+ const K_MAX = 100;
+
+ // Limb layouts are fixed for this bench: u32 = 20×13-bit, f32 = 12×22-bit.
+ const NUM_LIMBS_F32 = 12;
+ const WORD_SIZE_F32 = 22;
+ const NUM_LIMBS_U32 = 20;
+ const WORD_SIZE_U32 = 13;
+
+ // Per-layout limb radix (2^word_size) and the matching low-bits mask, as
+ // bigints for the host-side limb conversions.
+ const W_U32 = 1n << BigInt(WORD_SIZE_U32);
+ const MASK_U32 = W_U32 - 1n;
+ const W_F32 = 1n << BigInt(WORD_SIZE_F32);
+ const MASK_F32 = W_F32 - 1n;
+
+ // Split `v` into NUM_LIMBS_U32 limbs of WORD_SIZE_U32 bits each, least
+ // significant limb first. Bits above the last limb are dropped.
+ function bigintToLimbsU32(v: bigint): number[] {
+   const shift = BigInt(WORD_SIZE_U32);
+   const out: number[] = [];
+   let rest = v;
+   while (out.length < NUM_LIMBS_U32) {
+     out.push(Number(rest & MASK_U32));
+     rest >>= shift;
+   }
+   return out;
+ }
+ // Reassemble NUM_LIMBS_U32 little-endian limbs into one bigint. Each limb is
+ // taken as an unsigned 32-bit value and OR-ed in at its WORD_SIZE_U32-bit
+ // position, so a limb wider than WORD_SIZE_U32 bits overlaps its neighbour
+ // instead of carrying into it.
+ function limbsU32ToBigint(limbs: ArrayLike<number>): bigint {
+   const shift = BigInt(WORD_SIZE_U32);
+   let v = 0n;
+   for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) {
+     v = (v << shift) | BigInt(limbs[i] >>> 0);
+   }
+   return v;
+ }
+ // Split `v` into NUM_LIMBS_F32 limbs of WORD_SIZE_F32 bits each, least
+ // significant limb first. Bits above the last limb are dropped.
+ function bigintToLimbsF32(v: bigint): number[] {
+   const shift = BigInt(WORD_SIZE_F32);
+   const out: number[] = [];
+   let rest = v;
+   while (out.length < NUM_LIMBS_F32) {
+     out.push(Number(rest & MASK_F32));
+     rest >>= shift;
+   }
+   return out;
+ }
+ // Reassemble NUM_LIMBS_F32 little-endian limbs into one bigint. Each limb is
+ // rounded to the nearest integer and OR-ed in at its WORD_SIZE_F32-bit
+ // position. A NaN or infinite limb makes `BigInt(...)` throw a RangeError.
+ function limbsF32ToBigint(limbs: ArrayLike<number>): bigint {
+   const shift = BigInt(WORD_SIZE_F32);
+   let v = 0n;
+   for (let i = NUM_LIMBS_F32 - 1; i >= 0; i--) {
+     v = (v << shift) | BigInt(Math.round(limbs[i]));
+   }
+   return v;
+ }
+
+ // Seeded linear congruential generator (Numerical Recipes constants) for
+ // reproducible pair generation. RNG quality is not what this bench tests;
+ // the point of a deterministic stream is that failures are repeatable.
+ // A seed of zero is replaced by 1; each call returns the next state as a u32.
+ function makeRng(seed: number): () => number {
+   const unsignedSeed = seed >>> 0;
+   let current = unsignedSeed !== 0 ? unsignedSeed : 1;
+   const advance = (): number => {
+     const product = Math.imul(current, 1664525);
+     current = (product + 1013904223) >>> 0;
+     return current;
+   };
+   return advance;
+ }
+
+ /**
+  * Draw a value in [0, p) by rejection sampling.
+  *
+  * One byte is taken per generator call until `ceil(bitlen(p) / 8)` bytes are
+  * collected; the result is masked to `bitlen(p)` bits and redrawn if it is
+  * not below `p`.
+  *
+  * The byte is the TOP byte of each 32-bit output. With the power-of-two
+  * modulus LCG this file uses, the low byte repeats with period 256, so
+  * sampling it would make 32-byte draws cycle through at most 8 candidates.
+  *
+  * @param p   Exclusive upper bound; must be positive.
+  * @param rng Source of 32-bit unsigned values.
+  * @returns A bigint `v` with `0 <= v < p`.
+  * @throws RangeError when `p` is not positive (no candidate could ever be accepted).
+  */
+ function randomBelow(p: bigint, rng: () => number): bigint {
+   if (p <= 0n) {
+     throw new RangeError(`randomBelow: p must be positive, got ${p}`);
+   }
+   const bitlen = p.toString(2).length;
+   const byteLen = Math.ceil(bitlen / 8);
+   const mask = (1n << BigInt(bitlen)) - 1n;
+   while (true) {
+     let v = 0n;
+     for (let i = 0; i < byteLen; i++) {
+       v = (v << 8n) | BigInt(rng() >>> 24);
+     }
+     v &= mask;
+     if (v < p) return v;
+   }
+ }
+
+ // Middle element of `xs` in ascending order; for an even count, the upper of
+ // the two middle elements. NaN for an empty array. `xs` is left untouched.
+ function median(xs: number[]): number {
+   if (xs.length === 0) {
+     return NaN;
+   }
+   const ascending = Array.from(xs);
+   ascending.sort((left, right) => left - right);
+   const middle = Math.floor(ascending.length / 2);
+   return ascending[middle];
+ }
+
+ // Host-side BigInt reference for `k` chained Montgomery products.
+ // Operands are in Montgomery form (x_m = x * R mod p), and one Montgomery
+ // multiply yields (x_m * y_m * R^-1) mod p. Starting from a_m and multiplying
+ // by b_m `k` times therefore gives a_m * b_m^k * (R^-1)^k mod p, still in
+ // Montgomery form. With k = 0 the input `aMont` is returned as is.
+ function chainedMontReference(
+   aMont: bigint,
+   bMont: bigint,
+   k: number,
+   Rinv: bigint,
+   p: bigint,
+ ): bigint {
+   let product = aMont;
+   let remaining = k;
+   while (remaining > 0) {
+     product = (product * bMont * Rinv) % p;
+     remaining -= 1;
+   }
+   return product;
+ }
+
+ /**
+  * Compile `code` and build the compute pipeline used by this bench.
+  *
+  * Every compilation message is echoed to the console (errors via
+  * `console.error`, everything else via `console.warn`). The bind group layout
+  * has four compute-visible buffer bindings: 0 and 1 read-only storage,
+  * 2 storage, 3 uniform.
+  *
+  * @param device   Device that owns the shader module, layout and pipeline.
+  * @param code     WGSL source; its compute entry point must be named `main`.
+  * @param cacheKey Label used only in console messages and the thrown error.
+  * @returns The pipeline together with its bind group layout.
+  * @throws Error when the compilation info contains at least one error message.
+  */
+ async function createPipeline(
+   device: GPUDevice,
+   code: string,
+   cacheKey: string,
+ ): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+   const module = device.createShaderModule({ code });
+   const { messages } = await module.getCompilationInfo();
+   let errorCount = 0;
+   for (const diag of messages) {
+     const text = `[shader ${cacheKey}] ${diag.type}: ${diag.message} (line ${diag.lineNum}, col ${diag.linePos})`;
+     if (diag.type !== 'error') {
+       console.warn(text);
+       continue;
+     }
+     console.error(text);
+     errorCount += 1;
+   }
+   if (errorCount > 0) {
+     throw new Error(`WGSL compile failed for ${cacheKey}`);
+   }
+   // Binding index is the position in this list.
+   const bufferKinds = ['read-only-storage', 'read-only-storage', 'storage', 'uniform'] as const;
+   const layout = device.createBindGroupLayout({
+     entries: bufferKinds.map((type, binding) => ({
+       binding,
+       visibility: GPUShaderStage.COMPUTE,
+       buffer: { type },
+     })),
+   });
+   const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] });
+   const pipeline = await device.createComputePipelineAsync({
+     layout: pipelineLayout,
+     compute: { module, entryPoint: 'main' },
+   });
+   return { pipeline, layout };
+ }
+
+ /**
+  * Bench one limb representation: build seeded random Montgomery-form pairs,
+  * run the field-mul shader, validate the first `validateN` outputs against
+  * `chainedMontReference`, then time `reps` dispatches.
+  *
+  * The Montgomery algorithm is chosen by the `?variant` query parameter, read
+  * here directly from the page URL.
+  *
+  * @param device    WebGPU device used for every buffer, pipeline and submission.
+  * @param sm        Shader manager that renders the bench shader source.
+  * @param path      Limb representation to bench (`u32` or `f32`).
+  * @param n         Number of input pairs; the dispatch covers `n` threads.
+  * @param k         Number of chained Montgomery products per pair.
+  * @param validateN Number of leading pairs checked against the host reference.
+  * @param reps      Number of timed dispatches.
+  * @returns Validation outcome plus timing; `timing` is null when validation fails.
+  * @throws Error when `n` needs more workgroups than the device allows in one
+  *         dimension, the limb count or `R * Rinv` check fails, the variant
+  *         is not supported by `path`, or the shader fails to compile.
+  */
+ async function runPath(
+   device: GPUDevice,
+   sm: ShaderManager,
+   path: Path,
+   n: number,
+   k: number,
+   validateN: number,
+   reps: number,
+ ): Promise<PathResult> {
+   // Variant selects which Mont algorithm to bench:
+   // u32 path: 'cios' (default, mitschabaude runtime-loop CIOS) or 'karat'
+   // (recursive Karatsuba + Yuval reduction).
+   // f32 path: 'sos3uv3' only (separate per-slot tlo/thi f32 accumulators).
+   // The cast is unchecked; the value is validated further down, before it is
+   // handed to the shader generator.
+   const qp = new URLSearchParams(window.location.search);
+   const variant = (qp.get('variant') ?? (path === 'u32' ? 'cios' : 'sos3uv3')) as
+     | 'cios'
+     | 'karat'
+     | 'sos3uv3';
+
+   log('info', `path=${path}: building pairs (n=${n}, k=${k}, validate-n=${validateN}, reps=${reps})`);
+
+   // The dispatch is one-dimensional, so the workgroup count must fit the
+   // device's per-dimension limit. Checked before anything is allocated: an
+   // oversized dispatch is reported as a WebGPU validation error rather than a
+   // thrown exception, so the awaits below would not surface it.
+   const WORKGROUP_SIZE = 64;
+   const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE);
+   const maxWorkgroups = device.limits.maxComputeWorkgroupsPerDimension;
+   if (numWorkgroups > maxWorkgroups) {
+     throw new Error(
+       `path=${path}: n=${n} needs ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads, but the device allows at most ${maxWorkgroups} per dimension`,
+     );
+   }
+
+   const p = BN254_CURVE_CONFIG.baseFieldModulus;
+   const wordSize = path === 'u32' ? WORD_SIZE_U32 : WORD_SIZE_F32;
+   const numLimbs = path === 'u32' ? NUM_LIMBS_U32 : NUM_LIMBS_F32;
+   const params = compute_misc_params(p, wordSize);
+   if (params.num_words !== numLimbs) {
+     throw new Error(`expected num_words=${numLimbs} for path=${path}, got ${params.num_words}`);
+   }
+   const R = params.r;
+   const Rinv = params.rinv;
+   if ((R * Rinv) % p !== 1n) {
+     throw new Error(`R * Rinv mod p != 1 for path=${path}`);
+   }
+
+   // Generate random a_canonical, b_canonical pairs (CPU side BigInts).
+   // The two paths use different seeds.
+   const rng = makeRng(0xc0ffee + (path === 'u32' ? 0 : 1));
+   const aCanonical: bigint[] = new Array(n);
+   const bCanonical: bigint[] = new Array(n);
+   for (let i = 0; i < n; i++) {
+     aCanonical[i] = randomBelow(p, rng);
+     bCanonical[i] = randomBelow(p, rng);
+   }
+
+   // Encode in the appropriate Mont ring.
+   const aMont: bigint[] = aCanonical.map(x => (x * R) % p);
+   const bMont: bigint[] = bCanonical.map(x => (x * R) % p);
+
+   // CPU reference for the first validate-n pairs.
+   log('info', `path=${path}: computing host reference for ${validateN} pairs`);
+   const expected: bigint[] = new Array(validateN);
+   for (let i = 0; i < validateN; i++) {
+     expected[i] = chainedMontReference(aMont[i], bMont[i], k, Rinv, p);
+   }
+
+   // Pack separate `xs` / `ys` buffers — one BigInt per thread per buffer.
+   const bytesPerLimbArray = numLimbs * 4;
+   const xsBytes = new ArrayBuffer(n * bytesPerLimbArray);
+   const ysBytes = new ArrayBuffer(n * bytesPerLimbArray);
+
+   if (path === 'u32') {
+     const xv = new Uint32Array(xsBytes);
+     const yv = new Uint32Array(ysBytes);
+     for (let i = 0; i < n; i++) {
+       const aLimbs = bigintToLimbsU32(aMont[i]);
+       const bLimbs = bigintToLimbsU32(bMont[i]);
+       const off = i * NUM_LIMBS_U32;
+       for (let j = 0; j < NUM_LIMBS_U32; j++) xv[off + j] = aLimbs[j];
+       for (let j = 0; j < NUM_LIMBS_U32; j++) yv[off + j] = bLimbs[j];
+     }
+   } else {
+     const xv = new Float32Array(xsBytes);
+     const yv = new Float32Array(ysBytes);
+     for (let i = 0; i < n; i++) {
+       const aLimbs = bigintToLimbsF32(aMont[i]);
+       const bLimbs = bigintToLimbsF32(bMont[i]);
+       const off = i * NUM_LIMBS_F32;
+       for (let j = 0; j < NUM_LIMBS_F32; j++) xv[off + j] = aLimbs[j];
+       for (let j = 0; j < NUM_LIMBS_F32; j++) yv[off + j] = bLimbs[j];
+     }
+   }
+
+   let code: string;
+   if (path === 'u32') {
+     if (variant !== 'cios' && variant !== 'karat') {
+       throw new Error(`u32 path supports variant 'cios' or 'karat', got '${variant}'`);
+     }
+     code = sm.gen_field_mul_bench_u32_shader(WORKGROUP_SIZE, variant);
+   } else {
+     if (variant !== 'sos3uv3') {
+       throw new Error(`f32 path supports variant 'sos3uv3', got '${variant}'`);
+     }
+     code = sm.gen_field_mul_bench_f32_shader(WORKGROUP_SIZE, variant);
+   }
+   const cacheKey = `field-mul-bench-${path}-wg${WORKGROUP_SIZE}`;
+   log('info', `path=${path}: compiling shader (${code.length} chars)`);
+   // Stash the rendered shader on window so external tooling can dump it
+   // for post-mortem analysis when validation fails.
+   (window as unknown as Record<string, unknown>)[`__shader_${path}`] = code;
+   const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+   // Buffers.
+   const xsBuf = device.createBuffer({
+     size: xsBytes.byteLength,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(xsBuf, 0, xsBytes);
+   const ysBuf = device.createBuffer({
+     size: ysBytes.byteLength,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(ysBuf, 0, ysBytes);
+   const outBytes = n * numLimbs * 4;
+   const outBuf = device.createBuffer({
+     size: outBytes,
+     usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+   });
+   // Uniform block: word 0 = n, word 1 = k; the remaining two words stay zero.
+   const uniformBytes = new ArrayBuffer(16);
+   const uniformView = new Uint32Array(uniformBytes);
+   uniformView[0] = n;
+   uniformView[1] = k;
+   const uniformBuf = device.createBuffer({
+     size: 16,
+     usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+   });
+   device.queue.writeBuffer(uniformBuf, 0, uniformBytes);
+
+   const bindGroup = device.createBindGroup({
+     layout,
+     entries: [
+       { binding: 0, resource: { buffer: xsBuf } },
+       { binding: 1, resource: { buffer: ysBuf } },
+       { binding: 2, resource: { buffer: outBuf } },
+       { binding: 3, resource: { buffer: uniformBuf } },
+     ],
+   });
+
+   log('info', `path=${path}: dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`);
+
+   // Warmup pass — issued and awaited before the timed reps.
+   {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+   }
+   log('info', `path=${path}: warmup OK`);
+
+   // Validation pass — read back outputs.
+   const stagingBuf = device.createBuffer({
+     size: outBytes,
+     usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+   });
+   {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes);
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+     await stagingBuf.mapAsync(GPUMapMode.READ);
+   }
+   // Copy out of the mapped range before unmapping; the mapped view is not
+   // usable afterwards.
+   const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0);
+   stagingBuf.unmap();
+   stagingBuf.destroy();
+
+   // Compare first validate-n outputs. Only the first five mismatches are
+   // described.
+   const mismatches: string[] = [];
+   let validateOk = true;
+   if (path === 'u32') {
+     const outU32 = new Uint32Array(outBytesCopy);
+     for (let i = 0; i < validateN; i++) {
+       const limbs = outU32.subarray(i * NUM_LIMBS_U32, (i + 1) * NUM_LIMBS_U32);
+       const got = limbsU32ToBigint(limbs);
+       if (got !== expected[i]) {
+         validateOk = false;
+         if (mismatches.length < 5) {
+           mismatches.push(
+             `pair[${i}]: a_can=0x${aCanonical[i].toString(16)} b_can=0x${bCanonical[i].toString(16)}\n` +
+               ` expected: 0x${expected[i].toString(16)}\n` +
+               ` actual: 0x${got.toString(16)}\n` +
+               ` expected_limbs: [${bigintToLimbsU32(expected[i]).join(', ')}]\n` +
+               ` actual_limbs: [${Array.from(limbs).join(', ')}]`,
+           );
+         }
+       }
+     }
+   } else {
+     const outF32 = new Float32Array(outBytesCopy);
+     for (let i = 0; i < validateN; i++) {
+       const limbs = outF32.subarray(i * NUM_LIMBS_F32, (i + 1) * NUM_LIMBS_F32);
+       const got = limbsF32ToBigint(limbs);
+       if (got !== expected[i]) {
+         validateOk = false;
+         if (mismatches.length < 5) {
+           mismatches.push(
+             `pair[${i}]: a_can=0x${aCanonical[i].toString(16)} b_can=0x${bCanonical[i].toString(16)}\n` +
+               ` expected: 0x${expected[i].toString(16)}\n` +
+               ` actual: 0x${got.toString(16)}\n` +
+               ` expected_limbs: [${bigintToLimbsF32(expected[i]).join(', ')}]\n` +
+               ` actual_limbs: [${Array.from(limbs).map(x => Math.round(x)).join(', ')}]`,
+           );
+         }
+       }
+     }
+   }
+
+   if (!validateOk) {
+     log('err', `path=${path}: VALIDATION FAILED (${mismatches.length}/${validateN} mismatches shown of total)`);
+     for (const m of mismatches) log('err', m);
+     // Diagnostic dump: log the first pair's input limbs as they were
+     // packed into the GPU buffer. If all outputs are zero this confirms
+     // the shader is not writing to the output buffer (vs. writing the
+     // wrong value).
+     const inLimbsA: number[] = [];
+     const inLimbsB: number[] = [];
+     if (path === 'u32') {
+       const xv = new Uint32Array(xsBytes, 0, NUM_LIMBS_U32);
+       const yv = new Uint32Array(ysBytes, 0, NUM_LIMBS_U32);
+       for (let j = 0; j < NUM_LIMBS_U32; j++) inLimbsA.push(xv[j]);
+       for (let j = 0; j < NUM_LIMBS_U32; j++) inLimbsB.push(yv[j]);
+     } else {
+       const xv = new Float32Array(xsBytes, 0, NUM_LIMBS_F32);
+       const yv = new Float32Array(ysBytes, 0, NUM_LIMBS_F32);
+       for (let j = 0; j < NUM_LIMBS_F32; j++) inLimbsA.push(xv[j]);
+       for (let j = 0; j < NUM_LIMBS_F32; j++) inLimbsB.push(yv[j]);
+     }
+     log('err', `path=${path}: pair[0] input limbs as packed: a=[${inLimbsA.join(', ')}] b=[${inLimbsB.join(', ')}]`);
+     log('err', `path=${path}: pair[0] input limbs expected from canonical: a=[${(path === 'u32' ? bigintToLimbsU32(aMont[0]) : bigintToLimbsF32(aMont[0])).join(', ')}]`);
+     xsBuf.destroy();
+     ysBuf.destroy();
+     outBuf.destroy();
+     uniformBuf.destroy();
+     return { path, validateOk: false, mismatches, timing: null };
+   }
+   log('ok', `path=${path}: VALIDATION OK (${validateN} pairs)`);
+
+   // Timed reps. Each rep = one dispatch + queue.onSubmittedWorkDone wait.
+   const msSamples: number[] = [];
+   for (let rep = 0; rep < reps; rep++) {
+     const encoder = device.createCommandEncoder();
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+     pass.end();
+     const t0 = performance.now();
+     device.queue.submit([encoder.finish()]);
+     await device.queue.onSubmittedWorkDone();
+     const t1 = performance.now();
+     msSamples.push(t1 - t0);
+   }
+   const msMed = median(msSamples);
+   const msMin = Math.min(...msSamples);
+   const msMax = Math.max(...msSamples);
+   const totalMults = n * k;
+   const multsPerSec = totalMults / (msMed / 1000);
+   log(
+     'ok',
+     `path=${path}: timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms mults/s=${multsPerSec.toExponential(3)} (n*k=${totalMults.toLocaleString()})`,
+   );
+
+   xsBuf.destroy();
+   ysBuf.destroy();
+   outBuf.destroy();
+   uniformBuf.destroy();
+
+   return {
+     path,
+     validateOk: true,
+     mismatches: [],
+     timing: { reps, msSamples, msMedian: msMed, msMin, msMax, multsPerSec },
+   };
+ }
+
+ /**
+  * Read and validate the bench parameters from the page URL.
+  *
+  * Defaults: `path=both`, `n=64`, `k=1`, `validate-n=min(64, n)`, `reps=3`;
+  * `debug` is passed through unvalidated (null when absent). Numeric values
+  * must be plain base-10 integers. Anything else — for example `1e6` or
+  * `12abc`, which `parseInt` alone would silently truncate to 1 and 12 — is
+  * rejected with the same range error as an out-of-range value.
+  *
+  * @returns The validated parameters.
+  * @throws Error when `path` is not `u32`, `f32` or `both`, or when a numeric
+  *         parameter is not an integer inside its allowed range.
+  */
+ function parseParams(): {
+   path: Path | 'both';
+   n: number;
+   k: number;
+   validateN: number;
+   reps: number;
+   debug: string | null;
+ } {
+   const qp = new URLSearchParams(window.location.search);
+   const pathStr = qp.get('path') ?? 'both';
+   if (pathStr !== 'u32' && pathStr !== 'f32' && pathStr !== 'both') {
+     throw new Error(`?path must be u32|f32|both, got ${pathStr}`);
+   }
+   // Yields NaN for anything that is not an optionally signed run of digits,
+   // so the range checks below reject it.
+   const readInt = (key: string, fallback: string): number => {
+     const raw = (qp.get(key) ?? fallback).trim();
+     return /^[+-]?\d+$/.test(raw) ? parseInt(raw, 10) : NaN;
+   };
+   const n = readInt('n', '64');
+   const k = readInt('k', '1');
+   const validateN = readInt('validate-n', String(Math.min(64, n)));
+   const reps = readInt('reps', '3');
+   const debug = qp.get('debug');
+   if (!Number.isFinite(n) || n <= 0 || n > N_MAX) {
+     throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`);
+   }
+   if (!Number.isFinite(k) || k <= 0 || k > K_MAX) {
+     throw new Error(`?k must be in (0, ${K_MAX}], got ${qp.get('k')}`);
+   }
+   if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+     throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`);
+   }
+   if (!Number.isFinite(reps) || reps <= 0 || reps > 100) {
+     throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`);
+   }
+   return { path: pathStr, n, k, validateN, reps, debug };
+ }
+
+/**
+ * Debug mode (`?debug=mulhilo`): runs a self-contained WGSL kernel that
+ * evaluates `mulhilo(a, b)` for a handful of hard-coded operand pairs and
+ * compares each (hi, lo) result against the exact BigInt product split at
+ * W = 2^23 (hi = floor(a*b / W), lo = a*b mod W).
+ *
+ * @param device GPU device used to compile and dispatch the kernel.
+ * @param _sm    Unused; kept so both debug entry points share a call shape.
+ * @throws Error if the inline shader reports a compilation error.
+ */
+async function runDebugMulhiloF32(device: GPUDevice, _sm: ShaderManager): Promise<void> {
+  log('info', `[debug=mulhilo] testing mulhilo on hard-coded values`);
+  // Test cases: hand-picked products that exercise specific bit patterns.
+  const cases = [
+    { a: 1443728, b: 418697 },
+    { a: 1, b: 1 },
+    { a: 8388607, b: 8388607 },
+    { a: 4194304, b: 2 },
+    { a: 100, b: 200 },
+  ];
+  // One vec2<f32> (8 bytes) per case on both the input and output side.
+  const inputBytes = new ArrayBuffer(cases.length * 8);
+  const iv = new Float32Array(inputBytes);
+  for (let i = 0; i < cases.length; i++) {
+    iv[i * 2] = cases[i].a;
+    iv[i * 2 + 1] = cases[i].b;
+  }
+  const outBytes = cases.length * 2 * 4;
+  const inBuf = device.createBuffer({ size: inputBytes.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST });
+  device.queue.writeBuffer(inBuf, 0, inputBytes);
+  const outBuf = device.createBuffer({ size: outBytes, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC });
+
+  // Just inline the mulhilo constants and function directly.
+  // BIAS = 2^46, W = 2^23, W_INV = 2^-23.
+  const code = `
+const BIAS: f32 = 70368744177664.0;
+const W: f32 = 8388608.0;
+const W_INV: f32 = 1.1920928955078125e-7;
+
+fn mulhilo(a: f32, b: f32) -> vec2<f32> {
+  let q = fma(a, b, BIAS) - BIAS;
+  let lo0 = fma(a, b, -q);
+  let underflow = step(lo0, -0.5);
+  let hi = q * W_INV - underflow;
+  let lo = lo0 + underflow * W;
+  return vec2<f32>(hi, lo);
+}
+
+@group(0) @binding(0) var<storage, read> ins: array<vec2<f32>>;
+@group(0) @binding(1) var<storage, read_write> outs: array<vec2<f32>>;
+
+@compute @workgroup_size(1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let n = arrayLength(&ins);
+  if (gid.x >= n) { return; }
+  let v = ins[gid.x];
+  outs[gid.x] = mulhilo(v.x, v.y);
+}
+`;
+  const module = device.createShaderModule({ code });
+  // Surface compiler diagnostics instead of discarding them; a broken shader
+  // would otherwise only show up as an opaque pipeline-creation failure.
+  const ci = await module.getCompilationInfo();
+  let hasErr = false;
+  for (const m of ci.messages) {
+    if (m.type === 'error') {
+      console.error(`[debug mulhilo shader] error: ${m.message} (line ${m.lineNum})`);
+      hasErr = true;
+    } else {
+      console.warn(`[debug mulhilo shader] ${m.type}: ${m.message} (line ${m.lineNum})`);
+    }
+  }
+  if (hasErr) {
+    inBuf.destroy();
+    outBuf.destroy();
+    throw new Error('debug mulhilo shader compile failed');
+  }
+  const layout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+    ],
+  });
+  const pipeline = await device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+  const bg = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: inBuf } },
+      { binding: 1, resource: { buffer: outBuf } },
+    ],
+  });
+  // One single-thread workgroup per test case.
+  const enc = device.createCommandEncoder();
+  const ps = enc.beginComputePass();
+  ps.setPipeline(pipeline);
+  ps.setBindGroup(0, bg);
+  ps.dispatchWorkgroups(cases.length, 1, 1);
+  ps.end();
+  const stagingBuf = device.createBuffer({ size: outBytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST });
+  enc.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes);
+  device.queue.submit([enc.finish()]);
+  await device.queue.onSubmittedWorkDone();
+  await stagingBuf.mapAsync(GPUMapMode.READ);
+  // Copy out of the mapped range before unmapping invalidates it.
+  const out = new Float32Array(stagingBuf.getMappedRange().slice(0));
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+  inBuf.destroy();
+  outBuf.destroy();
+
+  // Host reference: exact integer product split at W = 2^23.
+  const W = 8388608n;
+  for (let i = 0; i < cases.length; i++) {
+    const { a, b } = cases[i];
+    const prod = BigInt(a) * BigInt(b);
+    const expHi = prod / W;
+    const expLo = prod % W;
+    const gotHi = Math.round(out[i * 2]);
+    const gotLo = Math.round(out[i * 2 + 1]);
+    const ok = (gotHi === Number(expHi)) && (gotLo === Number(expLo));
+    log(ok ? 'ok' : 'err', `[debug] mulhilo(${a}, ${b}) = (hi=${gotHi}, lo=${gotLo}); expected (hi=${expHi}, lo=${expLo}) ${ok ? 'OK' : 'WRONG'}`);
+  }
+}
+
+/**
+ * Debug mode (`?debug=<tag>` other than `mulhilo`): runs a single-thread
+ * shader that inlines the first outer iteration of the f32 Montgomery
+ * product on the pair (R mod p, R mod p), records intermediate values into
+ * a 64-slot f32 buffer, then also runs the full `montgomery_product_f32`
+ * and logs everything for manual comparison.
+ *
+ * @param device   GPU device used to compile and dispatch the shader.
+ * @param sm       Supplies the helper WGSL bundle the debug entry point is appended to.
+ * @param debugTag Value of the `debug` query parameter; only echoed in the log.
+ * @throws Error if the assembled shader reports a compilation error.
+ */
+async function runDebugF32(device: GPUDevice, sm: ShaderManager, debugTag: string): Promise<void> {
+  log('info', `[debug=${debugTag}] running f32 Mont debug shader`);
+  const p = BN254_CURVE_CONFIG.baseFieldModulus;
+  const params_f32 = compute_misc_params(p, WORD_SIZE_F32);
+  const R = params_f32.r;
+  log('info', `[debug] n0_f32=${params_f32.n0.toString()} num_words=${params_f32.num_words}`);
+
+  // Pair: a = 1, b = 1 (canonical). aMont = R mod p, bMont = R mod p.
+  // Mont(R, R, R^-1) = R^2 * R^-1 = R = aMont. So expected output limbs == aMont limbs.
+  const aMont = R % p;
+  const bMont = R % p;
+  const aLimbs = bigintToLimbsF32(aMont);
+  const bLimbs = bigintToLimbsF32(bMont);
+  const xsBytes = new ArrayBuffer(NUM_LIMBS_F32 * 4);
+  const ysBytes = new ArrayBuffer(NUM_LIMBS_F32 * 4);
+  const xv = new Float32Array(xsBytes);
+  const yv = new Float32Array(ysBytes);
+  for (let j = 0; j < NUM_LIMBS_F32; j++) xv[j] = aLimbs[j];
+  for (let j = 0; j < NUM_LIMBS_F32; j++) yv[j] = bLimbs[j];
+
+  // Capture intermediates: 64 f32 slots, as written by the shader below.
+  // [0..11]  = s[0..11] produced by the inlined i=0 outer iteration
+  // [12..15] = (xy0_lo, xy0_hi, sum0, qi) at i=0
+  // [16..19] = (qp0_lo, qp0_hi, c_lo_init, c_hi_init) at i=0
+  // [20..31] = output limbs (after full montgomery_product)
+  // [32..33] = direct mulhilo(1443728.0, 418697.0).{x,y}
+  // [34..35] = direct mulhilo(sum0_s.y, N0).{x,y}
+  // [36..39] = (sum0_s.x, sum0_s.y, N0 echo, x.limbs[0] echo)
+  // [40..41] = bias_split_f32(1443728.0).{x,y}
+  // [42..43] = bias_split_f32(sum0).{x,y} // sum0 is the actual variable
+  // [44..47] = (bias_split_f32(rv).{x,y}, bias_split_f32(xy0.y).{x,y})
+  // [48..51] = j=11: xyj.{x,y}, qpj.{x,y}
+  // [52..58] = j=11: t1, t1_s.y, t2, t2_s.y, t3, t3_s.x, t3_s.y
+  // [59]     = j=11: c_lo after the carry update
+  // [60..61] = j=11: c_lo before the update, t3_s.y re-check
+  // [62..63] = j=11: bias_split_f32(5290618.0).{x,y}
+  const DEBUG_SLOTS = 64;
+  const xsBuf = device.createBuffer({
+    size: xsBytes.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(xsBuf, 0, xsBytes);
+  const ysBuf = device.createBuffer({
+    size: ysBytes.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(ysBuf, 0, ysBytes);
+  const dbgBuf = device.createBuffer({
+    size: DEBUG_SLOTS * 4,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+  });
+
+  // Use ShaderManager's helper bundle (mulhilo + bigint_f32 + montgomery_product_f32).
+  const helpers = sm.gen_montgomery_product_f32_shader();
+  // Append a debug entry point that mirrors the Mont algorithm but captures
+  // per-position state for i=0 and finally writes the full Mont output.
+  const debugEntry = `
+@group(0) @binding(0) var<storage, read> xs: array<BigIntF32>;
+@group(0) @binding(1) var<storage, read> ys: array<BigIntF32>;
+@group(0) @binding(2) var<storage, read_write> dbg: array<f32>;
+
+@compute @workgroup_size(1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  if (gid.x != 0u) { return; }
+  var x = xs[0];
+  var y = ys[0];
+
+  // Run Mont. Capture intermediates by inlining the first outer iteration.
+  var s: BigIntF32;
+  for (var k = 0u; k < NUM_LIMBS; k ++) { s.limbs[k] = 0.0; }
+  var pp = get_p_f32();
+
+  // i = 0
+  let xy0 = mulhilo(x.limbs[0], y.limbs[0]);
+  let sum0 = s.limbs[0] + xy0.y;
+  let sum0_s = bias_split_f32(sum0);
+  let qi = mulhilo(sum0_s.y, N0).y;
+  let qp0 = mulhilo(qi, pp.limbs[0]);
+  let lo_cancel = sum0_s.y + qp0.y;
+  let c_small = lo_cancel * W_INV + sum0_s.x;
+  let hi_pair = xy0.x + qp0.x;
+  let carry_full = hi_pair + c_small;
+  let carry_s = bias_split_f32(carry_full);
+  var c_hi = carry_s.x;
+  var c_lo = carry_s.y;
+
+  dbg[12] = xy0.y;
+  dbg[13] = xy0.x;
+  dbg[14] = sum0;
+  dbg[15] = qi;
+  dbg[16] = qp0.y;
+  dbg[17] = qp0.x;
+  dbg[18] = c_lo;
+  dbg[19] = c_hi;
+
+  // Direct test: mulhilo(1443728.0, 418697.0).y. Should be 1489936.
+  let direct_test = mulhilo(1443728.0, 418697.0);
+  dbg[32] = direct_test.x; // hi
+  dbg[33] = direct_test.y; // lo
+  // Also test with sum0_s.y and N0 directly (no Mont context).
+  let direct_test2 = mulhilo(sum0_s.y, N0);
+  dbg[34] = direct_test2.x;
+  dbg[35] = direct_test2.y;
+  dbg[36] = sum0_s.x;
+  dbg[37] = sum0_s.y;
+  dbg[38] = N0;
+  dbg[39] = x.limbs[0];
+
+  // Direct call to bias_split_f32 with a hard-coded constant.
+  let bs_const = bias_split_f32(1443728.0);
+  dbg[40] = bs_const.x;
+  dbg[41] = bs_const.y;
+  // Direct call to bias_split_f32 with the actual sum0 value.
+  let bs_var = bias_split_f32(sum0);
+  dbg[42] = bs_var.x;
+  dbg[43] = bs_var.y;
+  // Direct call with a runtime-derived variable that equals 1443728.
+  var rv: f32 = 1443728.0;
+  let bs_rv = bias_split_f32(rv);
+  dbg[44] = bs_rv.x;
+  dbg[45] = bs_rv.y;
+  // Direct call with xy0.y (the mulhilo result, runtime variable).
+  let bs_xy = bias_split_f32(xy0.y);
+  dbg[46] = bs_xy.x;
+  dbg[47] = bs_xy.y;
+
+
+  for (var j = 1u; j < NUM_LIMBS; j ++) {
+    let xyj = mulhilo(x.limbs[0], y.limbs[j]);
+    let qpj = mulhilo(qi, pp.limbs[j]);
+    let t1 = s.limbs[j] + xyj.y;
+    let t1_s = bias_split_f32(t1);
+    let t2 = t1_s.y + qpj.y;
+    let t2_s = bias_split_f32(t2);
+    let c_lo_before = c_lo;
+    let t3 = t2_s.y + c_lo;
+    let t3_s = bias_split_f32(t3);
+    if (j == 11u) {
+      dbg[60] = c_lo_before;
+      dbg[61] = t3_s.y;
+      // Direct re-check: what does bias_split_f32(5290618.0) return here?
+      let dt = bias_split_f32(5290618.0);
+      dbg[62] = dt.x;
+      dbg[63] = dt.y;
+    }
+    s.limbs[j - 1u] = t3_s.y;
+    let sum_overflow = t1_s.x + t2_s.x + t3_s.x + c_hi;
+    let nc1 = xyj.x + qpj.x;
+    let nc1_s = bias_split_f32(nc1);
+    let nc2 = nc1_s.y + sum_overflow;
+    let nc2_s = bias_split_f32(nc2);
+    c_hi = nc1_s.x + nc2_s.x;
+    c_lo = nc2_s.y;
+    if (j == 11u) {
+      // Capture j=11 intermediates.
+      dbg[48] = xyj.x;
+      dbg[49] = xyj.y;
+      dbg[50] = qpj.x;
+      dbg[51] = qpj.y;
+      dbg[52] = t1;
+      dbg[53] = t1_s.y;
+      dbg[54] = t2;
+      dbg[55] = t2_s.y;
+      dbg[56] = t3;
+      dbg[57] = t3_s.x;
+      dbg[58] = t3_s.y;
+      dbg[59] = c_lo; // this is c_lo AFTER update (since we capture after assignment)
+    }
+  }
+  s.limbs[NUM_LIMBS - 1u] = fma(c_hi, W, c_lo);
+
+  for (var k = 0u; k < NUM_LIMBS; k ++) { dbg[k] = s.limbs[k]; }
+
+  // Now call the full Mont function to compare.
+  var x2 = xs[0];
+  var y2 = ys[0];
+  let full = montgomery_product_f32(&x2, &y2);
+  for (var k = 0u; k < NUM_LIMBS; k ++) { dbg[20u + k] = full.limbs[k]; }
+}
+`;
+  const code = `${helpers}\n${debugEntry}`;
+  const module = device.createShaderModule({ code });
+  const ci = await module.getCompilationInfo();
+  let hasErr = false;
+  for (const m of ci.messages) {
+    if (m.type === 'error') {
+      console.error(`[debug shader] error: ${m.message} (line ${m.lineNum})`);
+      hasErr = true;
+    } else {
+      console.warn(`[debug shader] ${m.type}: ${m.message} (line ${m.lineNum})`);
+    }
+  }
+  if (hasErr) throw new Error('debug shader compile failed');
+
+  const layout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
+      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
+    ],
+  });
+  const pipeline = await device.createComputePipelineAsync({
+    layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }),
+    compute: { module, entryPoint: 'main' },
+  });
+  const bg = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: xsBuf } },
+      { binding: 1, resource: { buffer: ysBuf } },
+      { binding: 2, resource: { buffer: dbgBuf } },
+    ],
+  });
+
+  const enc = device.createCommandEncoder();
+  const ps = enc.beginComputePass();
+  ps.setPipeline(pipeline);
+  ps.setBindGroup(0, bg);
+  ps.dispatchWorkgroups(1, 1, 1);
+  ps.end();
+  const stagingBuf = device.createBuffer({
+    size: DEBUG_SLOTS * 4,
+    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+  });
+  enc.copyBufferToBuffer(dbgBuf, 0, stagingBuf, 0, DEBUG_SLOTS * 4);
+  device.queue.submit([enc.finish()]);
+  await device.queue.onSubmittedWorkDone();
+  await stagingBuf.mapAsync(GPUMapMode.READ);
+  // Copy out of the mapped range before unmapping invalidates it.
+  const view = new Float32Array(stagingBuf.getMappedRange().slice(0));
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+
+  // Render a half-open slot range as rounded integers.
+  const fmtArr = (arr: Float32Array, start: number, end: number) =>
+    Array.from(arr.subarray(start, end))
+      .map(v => Math.round(v))
+      .join(', ');
+
+  log('info', `[debug] input a (Mont(1)): [${aLimbs.join(', ')}]`);
+  log('info', `[debug] input b (Mont(1)): [${bLimbs.join(', ')}]`);
+  log('info', `[debug] inline-i0 s[0..11] = [${fmtArr(view, 0, 12)}]`);
+  log('info', `[debug] inline-i0 xy0=(lo=${Math.round(view[12])}, hi=${Math.round(view[13])}) sum0=${Math.round(view[14])} qi=${Math.round(view[15])} qp0=(lo=${Math.round(view[16])}, hi=${Math.round(view[17])}) c_lo=${Math.round(view[18])} c_hi=${Math.round(view[19])}`);
+  log('info', `[debug] mont_full output = [${fmtArr(view, 20, 32)}]`);
+  log('info', `[debug] direct mulhilo(1443728, 418697) = (hi=${Math.round(view[32])}, lo=${Math.round(view[33])})`);
+  log('info', `[debug] direct mulhilo(sum0_s.y, N0) = (hi=${Math.round(view[34])}, lo=${Math.round(view[35])})`);
+  log('info', `[debug] sum0_s=(hi=${Math.round(view[36])}, lo=${Math.round(view[37])}) N0_echo=${Math.round(view[38])} x.limbs[0]_echo=${Math.round(view[39])}`);
+  log('info', `[debug] bias_split_f32(1443728.0) = (hi=${Math.round(view[40])}, lo=${Math.round(view[41])})`);
+  log('info', `[debug] bias_split_f32(sum0) = (hi=${Math.round(view[42])}, lo=${Math.round(view[43])})`);
+  log('info', `[debug] bias_split_f32(var rv=1443728.0) = (hi=${Math.round(view[44])}, lo=${Math.round(view[45])})`);
+  log('info', `[debug] bias_split_f32(xy0.y) = (hi=${Math.round(view[46])}, lo=${Math.round(view[47])})`);
+  log('info', `[debug] j=11: xyj=(hi=${Math.round(view[48])}, lo=${Math.round(view[49])}) qpj=(hi=${Math.round(view[50])}, lo=${Math.round(view[51])})`);
+  log('info', `[debug] j=11: t1=${Math.round(view[52])} t1_s.y=${Math.round(view[53])} t2=${Math.round(view[54])} t2_s.y=${Math.round(view[55])}`);
+  log('info', `[debug] j=11: t3=${Math.round(view[56])} t3_s.x=${Math.round(view[57])} t3_s.y=${Math.round(view[58])}`);
+  log('info', `[debug] j=11: c_lo_before=${Math.round(view[60])} t3_s.y_recheck=${Math.round(view[61])}`);
+  log('info', `[debug] j=11: bias_split_f32(5290618.0) direct = (hi=${Math.round(view[62])}, lo=${Math.round(view[63])})`);
+  log('info', `[debug] j=11: c_lo_after = ${Math.round(view[59])}`);
+
+  xsBuf.destroy();
+  ysBuf.destroy();
+  dbgBuf.destroy();
+}
+
+/**
+ * Page entry point: parse query params, acquire the GPU device, then either
+ * run one of the debug shaders (`?debug=...`) or benchmark each requested
+ * path. Progress and results are published through `benchState`; any
+ * exception is caught here and recorded as `state = 'error'`.
+ */
+async function main(): Promise<void> {
+  try {
+    if (!('gpu' in navigator)) {
+      throw new Error('navigator.gpu missing — WebGPU not available');
+    }
+    const params = parseParams();
+    benchState.params = params;
+    log('info', `params: path=${params.path} n=${params.n} k=${params.k} validate-n=${params.validateN} reps=${params.reps}`);
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', `WebGPU device acquired`);
+
+    // ShaderManager is keyed on chunk_size / input_size for the MSM
+    // pipeline; for the micro-bench we only need its Mont-constant
+    // pre-computation, so values are arbitrary.
+    const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false);
+
+    if (params.debug) {
+      if (params.debug === 'mulhilo') {
+        await runDebugMulhiloF32(device, sm);
+      } else {
+        await runDebugF32(device, sm, params.debug);
+      }
+      benchState.state = 'done';
+      log('ok', `[debug] done`);
+      return;
+    }
+
+    const paths: Path[] = params.path === 'both' ? ['u32', 'f32'] : [params.path];
+    let allValid = true;
+    for (const path of paths) {
+      const result = await runPath(device, sm, path, params.n, params.k, params.validateN, params.reps);
+      benchState.results.push(result);
+      if (!result.validateOk) {
+        // The failing result is already recorded above; remaining paths are
+        // skipped, so a caller sees results only up to the first failure.
+        log('err', `path=${path} failed validation — stopping path traversal`);
+        allValid = false;
+        break;
+      }
+    }
+
+    benchState.state = 'done';
+    // Do not report a green "done" line when a path failed validation.
+    if (allValid) {
+      log('ok', `bench done: ${benchState.results.length} paths`);
+    } else {
+      log('err', `bench done with validation failures: ${benchState.results.length} paths`);
+    }
+  } catch (e) {
+    const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e);
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Kick off the bench. main() has its own try/catch, so this handler only
+// fires if that catch path itself rejects.
+main().catch(reason => {
+  let text: string;
+  if (reason instanceof Error) {
+    text = reason.message;
+  } else {
+    text = String(reason);
+  }
+  log('err', `unhandled: ${text}`);
+  benchState.state = 'error';
+  benchState.error = text;
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html
new file mode 100644
index 000000000000..72d615c5f776
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html
@@ -0,0 +1,37 @@
+
+
+
+
+ BY fr_inv_by WebGPU dispatch test
+
+
+
+ BY fr_inv_by WebGPU dispatch test
+ Query params: ?n=N&k=K&validate-n=N&reps=R
+
+
+
+
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts
new file mode 100644
index 000000000000..3db8bb7944a8
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts
@@ -0,0 +1,472 @@
+///
+// fr_inv WebGPU dispatch test. Mounted standalone (no SRS, no MSM
+// pipeline). Reads `?n=N&k=K&validate-n=N&reps=R&variant=fr_inv|fr_inv_by|fr_inv_by_a|fr_pow_inv`,
+// generates seeded random BN254 base-field values in Montgomery form, runs
+// `k` chained inverse calls per thread, validates each output against the
+// host `modInverse` + Mont-correction reference (same reference for both
+// variants — the algorithms must produce identical outputs), and reports
+// timing via `window.__bench`.
+//
+// Safety: `n` is capped at 2^20, `k` at 100. The only data-dependent loop
+// in the entry shader is `for (var i = 0u; i < k; ...)` with `k` capped.
+// Every loop inside `fr_inv` / `fr_inv_by` (and its callees) is bounded by
+// a compile-time `const` — see by_inverse.template.wgsl, bigint_by,
+// fr_pow.template.wgsl, and montgomery_product.
+
+import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js';
+import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js';
+import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js';
+import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js';
+import { modInverse, BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js';
+
+/** Timing statistics for the timed repetitions of one bench run. */
+interface SampleSummary {
+ reps: number;
+ // Wall-clock milliseconds per repetition (submit -> onSubmittedWorkDone).
+ msSamples: number[];
+ msMedian: number;
+ msMin: number;
+ msMax: number;
+ // n * k divided by the median time in seconds.
+ inversionsPerSec: number;
+}
+/** Outcome of one bench run: validation verdict plus timing when validation passed. */
+interface BenchResult {
+ validateOk: boolean;
+ // Human-readable mismatch reports; runBench stores at most five.
+ mismatches: string[];
+ // null when validation failed and the timed reps were skipped.
+ timing: SampleSummary | null;
+}
+/** Page-level state object published on `window.__bench`. */
+interface BenchState {
+ state: 'boot' | 'running' | 'done' | 'error';
+ // Parsed query parameters; null until parseParams() has succeeded.
+ params: {
+ n: number;
+ k: number;
+ validateN: number;
+ reps: number;
+ variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv';
+ } | null;
+ result: BenchResult | null;
+ // Error text recorded by main()'s catch handlers; null otherwise.
+ error: string | null;
+ // Every message passed to log(), prefixed with its level.
+ log: string[];
+}
+
+// Single mutable state record for this page, starting in the 'boot' state.
+const benchState: BenchState = {
+ state: 'boot',
+ params: null,
+ result: null,
+ error: null,
+ log: [],
+};
+// Publish the same object on window.__bench so it can be read from outside
+// this module (the header comment says timing is reported via window.__bench).
+(window as unknown as { __bench: BenchState }).__bench = benchState;
+
+const $log = document.getElementById('log') as HTMLDivElement;
+
+/**
+ * Emit one message to three sinks: a new <div> appended to the #log element
+ * (CSS class = level, except 'info' which gets no class), the
+ * `benchState.log` array, and the browser console.
+ */
+function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) {
+  const row = document.createElement('div');
+  // 'info' rows are unstyled; every other level doubles as its class name.
+  row.className = level === 'info' ? '' : level;
+  row.textContent = msg;
+  $log.appendChild(row);
+  benchState.log.push(`[${level}] ${msg}`);
+  console.log(`[bench-fr-inv] ${msg}`);
+}
+
+// Upper bounds enforced by parseParams(): n <= 2^20 inputs, k <= 100 chained
+// inversions per thread.
+const N_MAX = 1 << 20;
+const K_MAX = 100;
+
+// 20 x 13-bit limb layout matches the u32 MSM path (the only path that
+// fr_inv_by is wired into in this bench).
+const NUM_LIMBS_U32 = 20;
+const WORD_SIZE_U32 = 13;
+// W_U32 = 2^13 is the limb radix; MASK_U32 selects the low 13 bits.
+const W_U32 = 1n << BigInt(WORD_SIZE_U32);
+const MASK_U32 = W_U32 - 1n;
+
+/**
+ * Split `v` into NUM_LIMBS_U32 little-endian limbs of WORD_SIZE_U32 bits
+ * each. Bits above NUM_LIMBS_U32 * WORD_SIZE_U32 are dropped.
+ */
+function bigintToLimbsU32(v: bigint): number[] {
+  const shift = BigInt(WORD_SIZE_U32);
+  const out: number[] = [];
+  for (let rest = v; out.length < NUM_LIMBS_U32; rest >>= shift) {
+    out.push(Number(rest & MASK_U32));
+  }
+  return out;
+}
+
+/**
+ * Reassemble a bigint from NUM_LIMBS_U32 little-endian limbs of
+ * WORD_SIZE_U32 bits each (inverse of bigintToLimbsU32 for in-range limbs).
+ *
+ * @param limbs Indexable sequence of limb values; each is coerced to an
+ *   unsigned 32-bit integer before being OR-ed into the accumulator.
+ */
+function limbsU32ToBigint(limbs: ArrayLike<number>): bigint {
+  const shift = BigInt(WORD_SIZE_U32);
+  let v = 0n;
+  // Walk from the most significant limb down so each step is shift-then-OR.
+  for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) {
+    v = (v << shift) | BigInt(limbs[i] >>> 0);
+  }
+  return v;
+}
+
+// Seeded LCG (Numerical Recipes constants) for reproducible input gen.
+// Matches the seeding convention in bench-field-mul.ts.
+//
+// The seed is coerced to u32; a seed of 0 is replaced by 1. Each call
+// advances the state and returns it as an unsigned 32-bit integer.
+function makeRng(seed: number): () => number {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const seedU32 = seed >>> 0;
+  let current = seedU32 === 0 ? 1 : seedU32;
+  const next = (): number => {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+  return next;
+}
+
+/**
+ * Rejection-sample a value in [0, p): build a candidate from the low byte of
+ * successive `rng()` outputs, mask it to p's bit length, and retry until the
+ * candidate is below `p`.
+ */
+function randomBelow(p: bigint, rng: () => number): bigint {
+  const bits = p.toString(2).length;
+  const bytesNeeded = Math.ceil(bits / 8);
+  const bitMask = (1n << BigInt(bits)) - 1n;
+  for (;;) {
+    let candidate = 0n;
+    let remaining = bytesNeeded;
+    while (remaining-- > 0) {
+      candidate = (candidate << 8n) | BigInt(rng() & 0xff);
+    }
+    candidate &= bitMask;
+    if (candidate < p) {
+      return candidate;
+    }
+  }
+}
+
+/**
+ * Middle element of the ascending-sorted samples (the upper middle for an
+ * even count). Returns NaN for an empty list; the input is not mutated.
+ */
+function median(xs: number[]): number {
+  const count = xs.length;
+  if (count === 0) {
+    return NaN;
+  }
+  const ascending = [...xs].sort((lhs, rhs) => lhs - rhs);
+  return ascending[Math.floor(count / 2)];
+}
+
+// Host reference for a single Montgomery-form inversion.
+//
+// Contract: Mont(A) -> Mont(A^(-1)). Given a_m = A * R mod p, the result is
+// A^(-1) * R mod p. A zero input maps to zero.
+//
+// Computed as unmont -> modInverse -> mont:
+//   A      = a_m * Rinv mod p
+//   A_inv  = modInverse(A, p)        (canonical A^(-1) mod p)
+//   out_m  = A_inv * R mod p
+//
+// An equivalent closed form is out_m = modInverse(a_m, p) * R^2 mod p, since
+// modInverse(a_m) = (A*R)^(-1) = A^(-1) * R^(-1) and multiplying by R^2
+// lands at A^(-1) * R. The three-step chain is used because it mirrors how
+// Mont arithmetic is usually reasoned about.
+function fr_inv_by_host_once(a_mont: bigint, R: bigint, Rinv: bigint, p: bigint): bigint {
+  if (a_mont === 0n) {
+    return 0n;
+  }
+  const canonical = (a_mont * Rinv) % p;
+  const inverted = modInverse(canonical, p);
+  const backInMont = (inverted * R) % p;
+  return backInMont;
+}
+
+/**
+ * Apply `fr_inv_by_host_once` to `a_mont` `k` times in sequence, feeding
+ * each output into the next call. With k <= 0 the input is returned as is.
+ */
+function fr_inv_by_host_chained(
+  a_mont: bigint,
+  k: number,
+  R: bigint,
+  Rinv: bigint,
+  p: bigint,
+): bigint {
+  let current = a_mont;
+  let done = 0;
+  while (done < k) {
+    current = fr_inv_by_host_once(current, R, Rinv, p);
+    done += 1;
+  }
+  return current;
+}
+
+/**
+ * Compile `code` into a compute pipeline with entry point `main` and a
+ * three-binding layout: 0 = read-only storage, 1 = storage, 2 = uniform.
+ *
+ * Every compiler message is echoed to the console (errors via
+ * console.error, everything else via console.warn), tagged with `cacheKey`.
+ *
+ * @throws Error if the compiler reported at least one error message.
+ */
+async function createPipeline(
+  device: GPUDevice,
+  code: string,
+  cacheKey: string,
+): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> {
+  const module = device.createShaderModule({ code });
+  const { messages } = await module.getCompilationInfo();
+
+  let sawError = false;
+  messages.forEach(entry => {
+    const text = `[shader ${cacheKey}] ${entry.type}: ${entry.message} (line ${entry.lineNum}, col ${entry.linePos})`;
+    if (entry.type !== 'error') {
+      console.warn(text);
+      return;
+    }
+    console.error(text);
+    sawError = true;
+  });
+  if (sawError) {
+    throw new Error(`WGSL compile failed for ${cacheKey}`);
+  }
+
+  const visibility = GPUShaderStage.COMPUTE;
+  const layout = device.createBindGroupLayout({
+    entries: [
+      { binding: 0, visibility, buffer: { type: 'read-only-storage' } },
+      { binding: 1, visibility, buffer: { type: 'storage' } },
+      { binding: 2, visibility, buffer: { type: 'uniform' } },
+    ],
+  });
+  const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] });
+  const pipeline = await device.createComputePipelineAsync({
+    layout: pipelineLayout,
+    compute: { module, entryPoint: 'main' },
+  });
+  return { pipeline, layout };
+}
+
+/**
+ * Run one inversion benchmark end to end.
+ *
+ * Steps: generate `n` seeded nonzero field elements in Montgomery form,
+ * compute the host reference for the first `validateN` of them, compile the
+ * bench shader for `variant`, do a warmup dispatch, do a validation dispatch
+ * with readback, and — only if validation passed — run `reps` timed
+ * dispatches.
+ *
+ * @param device    GPU device to run on.
+ * @param sm        Shader generator providing `gen_fr_inv_bench_shader`.
+ * @param n         Number of inputs (one GPU thread each).
+ * @param k         Chained inversions per input.
+ * @param validateN How many leading outputs to compare with the host reference.
+ * @param reps      Number of timed dispatches.
+ * @param variant   Which inverse implementation the shader should use.
+ * @returns Validation verdict; `timing` is null when validation failed.
+ * @throws Error on a limb-count mismatch, a bad R/Rinv pair, or a shader
+ *   compile failure.
+ */
+async function runBench(
+  device: GPUDevice,
+  sm: ShaderManager,
+  n: number,
+  k: number,
+  validateN: number,
+  reps: number,
+  variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv',
+): Promise<BenchResult> {
+  log('info', `building inputs (n=${n}, k=${k}, validate-n=${validateN}, reps=${reps}, variant=${variant})`);
+  const p = BN254_BASE_FIELD;
+  const params = compute_misc_params(p, WORD_SIZE_U32);
+  if (params.num_words !== NUM_LIMBS_U32) {
+    throw new Error(`expected num_words=${NUM_LIMBS_U32}, got ${params.num_words}`);
+  }
+  const R = params.r;
+  const Rinv = params.rinv;
+  if ((R * Rinv) % p !== 1n) {
+    throw new Error(`R * Rinv mod p != 1`);
+  }
+
+  // Generate random canonical inputs.
+  const rng = makeRng(0xc0ffee);
+  const aCanonical: bigint[] = new Array(n);
+  for (let i = 0; i < n; i++) {
+    // Ensure nonzero — fr_inv_by of 0 is undefined / 0 by convention but
+    // the chained case requires invertibility every step. Reject 0.
+    let v = randomBelow(p, rng);
+    while (v === 0n) v = randomBelow(p, rng);
+    aCanonical[i] = v;
+  }
+  const aMont: bigint[] = aCanonical.map(x => (x * R) % p);
+
+  // Host reference for the first validate-n inputs.
+  log('info', `computing host reference for ${validateN} pairs`);
+  const expected: bigint[] = new Array(validateN);
+  for (let i = 0; i < validateN; i++) {
+    expected[i] = fr_inv_by_host_chained(aMont[i], k, R, Rinv, p);
+  }
+
+  // Pack input buffer: n BigInts (20 x u32 each).
+  const bytesPerLimbArray = NUM_LIMBS_U32 * 4;
+  const xsBytes = new ArrayBuffer(n * bytesPerLimbArray);
+  const xv = new Uint32Array(xsBytes);
+  for (let i = 0; i < n; i++) {
+    const limbs = bigintToLimbsU32(aMont[i]);
+    const off = i * NUM_LIMBS_U32;
+    for (let j = 0; j < NUM_LIMBS_U32; j++) xv[off + j] = limbs[j];
+  }
+
+  // Shader & pipeline.
+  const WORKGROUP_SIZE = 64;
+  const code = sm.gen_fr_inv_bench_shader(WORKGROUP_SIZE, variant);
+  const cacheKey = `fr-inv-bench-${variant}-wg${WORKGROUP_SIZE}`;
+  log('info', `compiling shader (${code.length} chars)`);
+  // Expose the generated WGSL on window.__shader for inspection.
+  (window as unknown as Record<string, unknown>)[`__shader`] = code;
+  const { pipeline, layout } = await createPipeline(device, code, cacheKey);
+
+  // Buffers.
+  const xsBuf = device.createBuffer({
+    size: xsBytes.byteLength,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(xsBuf, 0, xsBytes);
+  const outBytes = n * NUM_LIMBS_U32 * 4;
+  const outBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
+  });
+  // Uniform block: [n, k, unused, unused] as u32, 16 bytes in total.
+  const uniformBytes = new ArrayBuffer(16);
+  const uniformView = new Uint32Array(uniformBytes);
+  uniformView[0] = n;
+  uniformView[1] = k;
+  const uniformBuf = device.createBuffer({
+    size: 16,
+    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+  });
+  device.queue.writeBuffer(uniformBuf, 0, uniformBytes);
+
+  // Shared teardown for both the validation-failure and success exits.
+  const destroyBuffers = (): void => {
+    xsBuf.destroy();
+    outBuf.destroy();
+    uniformBuf.destroy();
+  };
+
+  const bindGroup = device.createBindGroup({
+    layout,
+    entries: [
+      { binding: 0, resource: { buffer: xsBuf } },
+      { binding: 1, resource: { buffer: outBuf } },
+      { binding: 2, resource: { buffer: uniformBuf } },
+    ],
+  });
+
+  const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE);
+  log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads (${n} threads total)`);
+
+  // Warmup pass.
+  {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+  }
+  log('info', 'warmup OK');
+
+  // Validation pass — read back outputs.
+  const stagingBuf = device.createBuffer({
+    size: outBytes,
+    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+  });
+  {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes);
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    await stagingBuf.mapAsync(GPUMapMode.READ);
+  }
+  // Copy out of the mapped range before unmapping invalidates it.
+  const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0);
+  stagingBuf.unmap();
+  stagingBuf.destroy();
+
+  const outU32 = new Uint32Array(outBytesCopy);
+  const mismatches: string[] = [];
+  let validateOk = true;
+  for (let i = 0; i < validateN; i++) {
+    const limbs = outU32.subarray(i * NUM_LIMBS_U32, (i + 1) * NUM_LIMBS_U32);
+    const got = limbsU32ToBigint(limbs);
+    if (got !== expected[i]) {
+      validateOk = false;
+      // Keep the report short: only the first five mismatches are detailed.
+      if (mismatches.length < 5) {
+        mismatches.push(
+          `pair[${i}]: a_canonical=0x${aCanonical[i].toString(16)} a_mont=0x${aMont[i].toString(16)}\n` +
+            ` expected: 0x${expected[i].toString(16)}\n` +
+            ` actual: 0x${got.toString(16)}\n` +
+            ` expected_limbs: [${bigintToLimbsU32(expected[i]).join(', ')}]\n` +
+            ` actual_limbs: [${Array.from(limbs).join(', ')}]`,
+        );
+      }
+    }
+  }
+
+  if (!validateOk) {
+    log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`);
+    for (const m of mismatches) log('err', m);
+    destroyBuffers();
+    return { validateOk: false, mismatches, timing: null };
+  }
+  log('ok', `VALIDATION OK (${validateN} pairs)`);
+
+  // Timed reps. The clock covers submit through onSubmittedWorkDone only;
+  // command encoding happens before t0.
+  const msSamples: number[] = [];
+  for (let rep = 0; rep < reps; rep++) {
+    const encoder = device.createCommandEncoder();
+    const pass = encoder.beginComputePass();
+    pass.setPipeline(pipeline);
+    pass.setBindGroup(0, bindGroup);
+    pass.dispatchWorkgroups(numWorkgroups, 1, 1);
+    pass.end();
+    const t0 = performance.now();
+    device.queue.submit([encoder.finish()]);
+    await device.queue.onSubmittedWorkDone();
+    const t1 = performance.now();
+    msSamples.push(t1 - t0);
+  }
+  const msMed = median(msSamples);
+  const msMin = Math.min(...msSamples);
+  const msMax = Math.max(...msSamples);
+  const totalInv = n * k;
+  const inversionsPerSec = totalInv / (msMed / 1000);
+  log(
+    'ok',
+    `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms inversions/s=${inversionsPerSec.toExponential(3)} (n*k=${totalInv.toLocaleString()})`,
+  );
+
+  destroyBuffers();
+
+  return {
+    validateOk: true,
+    mismatches: [],
+    timing: { reps, msSamples, msMedian: msMed, msMin, msMax, inversionsPerSec },
+  };
+}
+
+/**
+ * Read the fr_inv benchmark configuration from the page's query string.
+ *
+ * Defaults: `n=64`, `k=1`, `validate-n=min(64, n)`, `reps=1`,
+ * `variant=fr_inv_by`.
+ *
+ * @throws Error when a numeric parameter is not a finite number inside its
+ *   allowed range, or when `variant` is not one of the four known names.
+ */
+function parseParams(): {
+  n: number;
+  k: number;
+  validateN: number;
+  reps: number;
+  variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv';
+} {
+  const query = new URLSearchParams(window.location.search);
+  // Base-10 integer read with a textual fallback for missing params.
+  const readInt = (name: string, fallback: string): number => parseInt(query.get(name) ?? fallback, 10);
+
+  const n = readInt('n', '64');
+  const k = readInt('k', '1');
+  const validateN = readInt('validate-n', String(Math.min(64, n)));
+  const reps = readInt('reps', '1');
+  const variantStr = query.get('variant') ?? 'fr_inv_by';
+
+  // Numeric checks run in the order n, k, validate-n, reps; the variant is
+  // checked last.
+  if (!(Number.isFinite(n) && n > 0 && n <= N_MAX)) {
+    throw new Error(`?n must be in (0, ${N_MAX}], got ${query.get('n')}`);
+  }
+  if (!(Number.isFinite(k) && k > 0 && k <= K_MAX)) {
+    throw new Error(`?k must be in (0, ${K_MAX}], got ${query.get('k')}`);
+  }
+  if (!(Number.isFinite(validateN) && validateN >= 0 && validateN <= n)) {
+    throw new Error(`?validate-n must be in [0, n], got ${query.get('validate-n')}`);
+  }
+  if (!(Number.isFinite(reps) && reps > 0 && reps <= 100)) {
+    throw new Error(`?reps must be in (0, 100], got ${query.get('reps')}`);
+  }
+
+  switch (variantStr) {
+    case 'fr_inv':
+    case 'fr_inv_by':
+    case 'fr_inv_by_a':
+    case 'fr_pow_inv':
+      return { n, k, validateN, reps, variant: variantStr };
+    default:
+      throw new Error(
+        `?variant must be 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv', got '${variantStr}'`,
+      );
+  }
+}
+
+/**
+ * Page entry point: parse query params, acquire the GPU device, run the
+ * bench once, and publish the outcome through `benchState`. Any exception
+ * is caught here and recorded as `state = 'error'`.
+ */
+async function main() {
+  try {
+    const hasWebGpu = 'gpu' in navigator;
+    if (!hasWebGpu) {
+      throw new Error('navigator.gpu missing — WebGPU not available');
+    }
+
+    const params = parseParams();
+    benchState.params = params;
+    const { n, k, validateN, reps, variant } = params;
+    log('info', `params: n=${n} k=${k} validate-n=${validateN} reps=${reps} variant=${variant}`);
+
+    benchState.state = 'running';
+    const device = await get_device();
+    log('info', 'WebGPU device acquired');
+
+    const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false);
+
+    const result = await runBench(device, sm, n, k, validateN, reps, variant);
+    benchState.result = result;
+    benchState.state = 'done';
+
+    // The state is 'done' either way; only the closing log line differs.
+    if (!result.validateOk) {
+      log('err', 'bench done with validation failures');
+    } else {
+      log('ok', 'bench done');
+    }
+  } catch (e) {
+    let msg: string;
+    if (e instanceof Error) {
+      msg = `${e.message}\n${e.stack}`;
+    } else {
+      msg = String(e);
+    }
+    log('err', `FATAL: ${msg}`);
+    benchState.state = 'error';
+    benchState.error = msg;
+  }
+}
+
+// Kick off the bench. main() has its own try/catch, so this handler only
+// fires if that catch path itself rejects.
+main().catch(reason => {
+  let text: string;
+  if (reason instanceof Error) {
+    text = reason.message;
+  } else {
+    text = String(reason);
+  }
+  log('err', `unhandled: ${text}`);
+  benchState.state = 'error';
+  benchState.error = text;
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/main.ts b/barretenberg/ts/dev/msm-webgpu/main.ts
index 1d9ce2f1cfaa..6eb897f041a3 100644
--- a/barretenberg/ts/dev/msm-webgpu/main.ts
+++ b/barretenberg/ts/dev/msm-webgpu/main.ts
@@ -1190,6 +1190,7 @@ $probeGpu?.addEventListener('click', async () => {
$runSanity.addEventListener('click', async () => {
$log.innerHTML = '';
abortRequested = false;
+ (window as unknown as { __sanity?: unknown }).__sanity = { state: 'running' };
setBusy(true, 'sanity check…');
try {
log('info', '[sanity] WebGPU-only smoke test, log₂(n)=16, no WASM, no noble');
@@ -1252,9 +1253,23 @@ $runSanity.addEventListener('click', async () => {
// view after a fresh page reload.
$results.innerHTML = renderBreakdownTable([{ logN: 16, captures: [gpu.capture] }]);
$results.classList.add('visible');
+ // Expose the raw capture so Playwright-driven profile scripts can
+ // pull per-stage GPU times without scraping the rendered table.
+ // Reset to `{ state: 'running' }` at the start of every click, so a
+ // stale value from a previous run never bleeds into the next read.
+ (window as unknown as { __sanity?: unknown }).__sanity = {
+ state: 'done',
+ logN: 16,
+ ms: gpu.ms,
+ capture: JSON.parse(JSON.stringify(gpu.capture)),
+ };
} catch (err) {
log(abortRequested ? 'warn' : 'err', `[sanity] ${err instanceof Error ? err.message : String(err)}`);
if (!abortRequested && err instanceof Error && err.stack) log('err', err.stack);
+ (window as unknown as { __sanity?: unknown }).__sanity = {
+ state: 'error',
+ error: err instanceof Error ? err.message : String(err),
+ };
} finally {
setBusy(false);
}
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs
new file mode 100755
index 000000000000..0badddd1c4b1
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs
@@ -0,0 +1,219 @@
+#!/usr/bin/env node
+// BY apply_matrix WebGPU dispatch test driver. Launches Playwright Chromium
+// with WebGPU enabled, navigates to the standalone bench-apply-matrix.html
+// page, waits for `window.__bench.state === 'done' | 'error'`, prints
+// results. Mirrors the bench-divsteps.mjs structure.
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-apply-matrix.html';
+const N_MAX = 1 << 20;
+
+const { values: argv } = parseArgs({
+ options: {
+ n: { type: 'string', default: '1024' },
+ 'validate-n': { type: 'string' },
+ reps: { type: 'string', default: '1' },
+ url: { type: 'string', default: DEFAULT_URL_BASE },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '60' },
+ json: { type: 'boolean', default: false },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) {
+ process.stdout.write(
+ `BY apply_matrix WebGPU dispatch test driver
+
+Usage:
+ node dev/msm-webgpu/scripts/bench-apply-matrix.mjs [options]
+
+Options:
+ --n N (default 1024, max ${N_MAX})
+ --validate-n N (default min(64, n))
+ --reps R (default 1)
+ --url URL (default ${DEFAULT_URL_BASE})
+ --headed Run with visible browser window
+ --timeout SECS Bench page completion timeout (default 60)
+ --json Machine-readable JSON only output
+ --help Show this help
+`,
+ );
+ process.exit(0);
+}
+
+// Print a flag-validation error and exit with status 2, before any browser launch.
+function usageError(msg) {
+  process.stderr.write(`error: ${msg}\n`);
+  process.exit(2);
+}
+const n = parseInt(String(argv.n), 10);
+if (!Number.isFinite(n) || n <= 0 || n > N_MAX) usageError(`--n must be in (0, ${N_MAX}], got ${argv.n}`);
+const validateN =
+  argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n);
+if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+  usageError(`--validate-n must be in [0, n], got ${argv['validate-n']}`);
+}
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 100) usageError(`--reps must be in (0, 100], got ${argv.reps}`);
+const headed = Boolean(argv.headed);
+// A non-numeric --timeout parses to NaN; reject it here. 0 stays valid (Playwright: no timeout).
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+if (!Number.isFinite(timeoutMs) || timeoutMs < 0) usageError(`--timeout must be a number >= 0, got ${argv.timeout}`);
+const jsonOnly = Boolean(argv.json);
+const baseUrl = String(argv.url);
+
+function err(msg) {
+ process.stderr.write(`[bench-apply-matrix] ${msg}\n`);
+}
+function out(msg) {
+ if (!jsonOnly) process.stdout.write(`${msg}\n`);
+}
+
+async function reachable(targetUrl) {
+ try {
+ const res = await fetch(targetUrl, { method: 'GET' });
+ return res.status >= 200 && res.status < 500;
+ } catch {
+ return false;
+ }
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+async function tryLaunch(headless) {
+ return chromium.launch({ headless, args: CHROMIUM_ARGS });
+}
+
+function buildUrl() {
+ const u = new URL(baseUrl);
+ u.searchParams.set('n', String(n));
+ u.searchParams.set('validate-n', String(validateN));
+ u.searchParams.set('reps', String(reps));
+ return u.toString();
+}
+
+async function runOnce(headless) {
+ let browser;
+ try {
+ browser = await tryLaunch(headless);
+ const context = await browser.newContext({
+ viewport: { width: 900, height: 600 },
+ permissions: [],
+ bypassCSP: false,
+ });
+ const page = await context.newPage();
+ page.on('console', msg => {
+ const txt = msg.text();
+ if (!txt.startsWith('[vite]')) {
+ err(`[page:${msg.type()}] ${txt}`);
+ }
+ });
+ page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+ const navUrl = buildUrl();
+ err(`navigating to ${navUrl} (headless=${headless})`);
+ await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+ const hasWebGpu = await page.evaluate(() => 'gpu' in navigator);
+ if (!hasWebGpu) {
+ throw new Error('navigator.gpu missing in this Chromium instance');
+ }
+
+ err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`);
+ const t0 = Date.now();
+ await page.waitForFunction(
+ () => window.__bench?.state === 'done' || window.__bench?.state === 'error',
+ { timeout: timeoutMs, polling: 250 },
+ );
+ const elapsed = (Date.now() - t0) / 1000;
+ err(`bench reached terminal state in ${elapsed.toFixed(1)}s`);
+
+ const result = await page.evaluate(() => {
+ const b = window.__bench;
+ return {
+ state: b.state,
+ params: b.params,
+ result: JSON.parse(JSON.stringify(b.result)),
+ error: b.error,
+ log: b.log.slice(),
+ };
+ });
+ result.elapsedSec = elapsed;
+ return result;
+ } finally {
+ try {
+ if (browser) await browser.close();
+ } catch (e) {
+ err(`browser.close failed: ${e.message}`);
+ }
+ }
+}
+
+(async () => {
+ if (!(await reachable(baseUrl))) {
+ err(`dev server not reachable at ${new URL(baseUrl).origin}`);
+ err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`);
+ process.exit(3);
+ }
+
+ let result;
+ try {
+ result = await runOnce(!headed);
+ } catch (e) {
+ err(`run failed: ${e.message}`);
+ if (!headed) {
+ err('retrying in headed mode');
+ try {
+ result = await runOnce(false);
+ } catch (e2) {
+ err(`headed retry also failed: ${e2.message}`);
+ process.exit(4);
+ }
+ } else {
+ process.exit(4);
+ }
+ }
+
+ if (result.state === 'error') {
+ err(`page reported error: ${result.error}`);
+ if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+ process.exit(5);
+ }
+
+ if (!jsonOnly) {
+ out(`## BY apply_matrix WebGPU dispatch test`);
+ out('');
+ out(`- params: ${JSON.stringify(result.params)}`);
+ out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`);
+ out('');
+ const r = result.result;
+ if (r && r.timing) {
+ const t = r.timing;
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`);
+ out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms apply_matrix_calls/s=${t.applyMatrixPerSec.toExponential(3)}`);
+ } else if (r) {
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`);
+ for (const m of r.mismatches) {
+ out(` ${m.replace(/\n/g, '\n ')}`);
+ }
+ }
+ out('');
+ }
+ process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+
+ const failed = !result.result || !result.result.validateOk;
+ process.exit(failed ? 1 : 0);
+})().catch(e => {
+ err(`unexpected: ${e.stack ?? e.message ?? String(e)}`);
+ process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs
new file mode 100644
index 000000000000..811949065ebe
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs
@@ -0,0 +1,212 @@
+#!/usr/bin/env node
+// Batch-affine EC add amortisation bench driver.
+// Mirrors bench-fr-inv.mjs structure: Playwright Chromium with WebGPU
+// enabled, navigates to the standalone bench-batch-affine.html page,
+// waits for `window.__bench.state === 'done' | 'error'`, dumps results
+// as a table.
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL_BASE = 'http://localhost:5197/dev/msm-webgpu/bench-batch-affine.html';
+
+const { values: argv } = parseArgs({
+ options: {
+ reps: { type: 'string', default: '5' },
+ total: { type: 'string' },
+ sizes: { type: 'string' },
+ url: { type: 'string', default: DEFAULT_URL_BASE },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '180' },
+ json: { type: 'boolean', default: false },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) { // every option declared in parseArgs() above should be listed in this text
+  process.stdout.write(`batch-affine WebGPU amortisation bench driver
+
+Usage:
+ node dev/msm-webgpu/scripts/bench-batch-affine.mjs [options]
+
+Options:
+ --reps R (default 5)
+ --total T Forwarded to the bench page as ?total=T (optional)
+ --sizes S Forwarded to the bench page as ?sizes=S (optional)
+ --url URL (default ${DEFAULT_URL_BASE})
+ --headed Run with visible browser window
+ --timeout SECS Bench page completion timeout (default 180)
+ --json Machine-readable JSON only output
+ --help Show this help
+`);
+  process.exit(0);
+}
+
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 50) {
+ process.stderr.write(`error: --reps must be in (0, 50], got ${argv.reps}\n`);
+ process.exit(2);
+}
+const headed = Boolean(argv.headed);
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+const jsonOnly = Boolean(argv.json);
+const baseUrl = String(argv.url);
+
+function err(msg) {
+ process.stderr.write(`[bench-batch-affine] ${msg}\n`);
+}
+function out(msg) {
+ if (!jsonOnly) process.stdout.write(`${msg}\n`);
+}
+
+async function reachable(targetUrl) {
+ try {
+ const res = await fetch(targetUrl, { method: 'GET' });
+ return res.status >= 200 && res.status < 500;
+ } catch {
+ return false;
+ }
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+async function tryLaunch(headless) {
+ return chromium.launch({ headless, args: CHROMIUM_ARGS });
+}
+
+function buildUrl() {
+ const u = new URL(baseUrl);
+ u.searchParams.set('reps', String(reps));
+ if (argv.total !== undefined) u.searchParams.set('total', String(argv.total));
+ if (argv.sizes !== undefined) u.searchParams.set('sizes', String(argv.sizes));
+ return u.toString();
+}
+
+async function runOnce(headless) {
+ let browser;
+ try {
+ browser = await tryLaunch(headless);
+ const context = await browser.newContext({
+ viewport: { width: 900, height: 600 },
+ permissions: [],
+ bypassCSP: false,
+ });
+ const page = await context.newPage();
+ page.on('console', msg => {
+ const txt = msg.text();
+ if (!txt.startsWith('[vite]')) {
+ err(`[page:${msg.type()}] ${txt}`);
+ }
+ });
+ page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+ const navUrl = buildUrl();
+ err(`navigating to ${navUrl} (headless=${headless})`);
+ await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+ const hasWebGpu = await page.evaluate(() => 'gpu' in navigator);
+ if (!hasWebGpu) {
+ throw new Error('navigator.gpu missing in this Chromium instance');
+ }
+
+ err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`);
+ const t0 = Date.now();
+ await page.waitForFunction(
+ () => window.__bench?.state === 'done' || window.__bench?.state === 'error',
+ { timeout: timeoutMs, polling: 250 },
+ );
+ const elapsed = (Date.now() - t0) / 1000;
+ err(`bench reached terminal state in ${elapsed.toFixed(1)}s`);
+
+ const result = await page.evaluate(() => {
+ const b = window.__bench;
+ return {
+ state: b.state,
+ params: b.params,
+ results: JSON.parse(JSON.stringify(b.results)),
+ error: b.error,
+ log: b.log.slice(),
+ };
+ });
+ result.elapsedSec = elapsed;
+ return result;
+ } finally {
+ try {
+ if (browser) await browser.close();
+ } catch (e) {
+ err(`browser.close failed: ${e.message}`);
+ }
+ }
+}
+
+// Driver entry point: check the dev server, run the bench (retrying once in headed mode if a
+// headless run throws), print a table unless --json, then write the JSON result exactly once.
+(async () => {
+  if (!(await reachable(baseUrl))) {
+    err(`dev server not reachable at ${new URL(baseUrl).origin}`);
+    err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --port 5197 --strictPort --no-open`);
+    process.exit(3);
+  }
+
+  let result;
+  try {
+    result = await runOnce(!headed);
+  } catch (e) {
+    err(`run failed: ${e.message}`);
+    if (!headed) {
+      err('retrying in headed mode');
+      try {
+        result = await runOnce(false);
+      } catch (e2) {
+        err(`headed retry also failed: ${e2.message}`);
+        process.exit(4);
+      }
+    } else {
+      process.exit(4);
+    }
+  }
+
+  if (result.state === 'error') {
+    err(`page reported error: ${result.error}`);
+    // Don't exit or dump JSON here — partial results may still be useful
+    // (e.g. first batch_size failed). The single JSON dump below keeps --json
+    // output parseable; the exit code at the end still reports the failure.
+  }
+
+  if (!jsonOnly) {
+    out(`## Batch-affine WebGPU amortisation bench (reps=${reps})`);
+    out('');
+    out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`);
+    out(`- state: ${result.state}`);
+    if (result.error) out(`- error: ${result.error}`);
+    out('');
+    if (result.results && result.results.length > 0) {
+      const base = result.results[result.results.length - 1].ns_per_pair;
+      out('batch_size | num_WGs | TPB | total_threads | median_ms | ns/pair | inv_amort_ratio');
+      out('---------- | ------- | --- | ------------- | --------- | ------- | --------------');
+      for (const r of result.results) {
+        const ratio = (r.ns_per_pair / base).toFixed(3);
+        out(
+          `${String(r.batch_size).padEnd(10)} | ${String(r.num_wgs).padEnd(7)} | ${String(r.tpb).padEnd(3)} | ${String(r.total_threads).padEnd(13)} | ${r.median_ms.toFixed(3).padStart(9)} | ${r.ns_per_pair.toFixed(1).padStart(7)} | ${String(ratio).padStart(14)}`,
+        );
+      }
+    } else {
+      out('(no results)');
+    }
+    out('');
+  }
+  process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+
+  const failed = result.state === 'error';
+  process.exit(failed ? 1 : 0);
+})().catch(e => {
+  err(`unexpected: ${e.stack ?? e.message ?? String(e)}`);
+  process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs
new file mode 100755
index 000000000000..54dc7259c0d7
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs
@@ -0,0 +1,219 @@
+#!/usr/bin/env node
+// BY divsteps WebGPU dispatch test driver. Launches Playwright Chromium
+// with WebGPU enabled, navigates to the standalone bench-divsteps.html
+// page, waits for `window.__bench.state === 'done' | 'error'`, prints
+// results. Mirrors the bench-field-mul.mjs structure.
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-divsteps.html';
+const N_MAX = 1 << 23;
+
+const { values: argv } = parseArgs({
+ options: {
+ n: { type: 'string', default: '1024' },
+ 'validate-n': { type: 'string' },
+ reps: { type: 'string', default: '1' },
+ url: { type: 'string', default: DEFAULT_URL_BASE },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '60' },
+ json: { type: 'boolean', default: false },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) {
+ process.stdout.write(
+ `BY divsteps WebGPU dispatch test driver
+
+Usage:
+ node dev/msm-webgpu/scripts/bench-divsteps.mjs [options]
+
+Options:
+ --n N (default 1024, max ${N_MAX})
+ --validate-n N (default min(64, n))
+ --reps R (default 1)
+ --url URL (default ${DEFAULT_URL_BASE})
+ --headed Run with visible browser window
+ --timeout SECS Bench page completion timeout (default 60)
+ --json Machine-readable JSON only output
+ --help Show this help
+`,
+ );
+ process.exit(0);
+}
+
+// Print a flag-validation error and exit with status 2, before any browser launch.
+function usageError(msg) {
+  process.stderr.write(`error: ${msg}\n`);
+  process.exit(2);
+}
+const n = parseInt(String(argv.n), 10);
+if (!Number.isFinite(n) || n <= 0 || n > N_MAX) usageError(`--n must be in (0, ${N_MAX}], got ${argv.n}`);
+const validateN =
+  argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n);
+if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+  usageError(`--validate-n must be in [0, n], got ${argv['validate-n']}`);
+}
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 100) usageError(`--reps must be in (0, 100], got ${argv.reps}`);
+const headed = Boolean(argv.headed);
+// A non-numeric --timeout parses to NaN; reject it here. 0 stays valid (Playwright: no timeout).
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+if (!Number.isFinite(timeoutMs) || timeoutMs < 0) usageError(`--timeout must be a number >= 0, got ${argv.timeout}`);
+const jsonOnly = Boolean(argv.json);
+const baseUrl = String(argv.url);
+
+function err(msg) {
+ process.stderr.write(`[bench-divsteps] ${msg}\n`);
+}
+function out(msg) {
+ if (!jsonOnly) process.stdout.write(`${msg}\n`);
+}
+
+async function reachable(targetUrl) {
+ try {
+ const res = await fetch(targetUrl, { method: 'GET' });
+ return res.status >= 200 && res.status < 500;
+ } catch {
+ return false;
+ }
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+async function tryLaunch(headless) {
+ return chromium.launch({ headless, args: CHROMIUM_ARGS });
+}
+
+function buildUrl() {
+ const u = new URL(baseUrl);
+ u.searchParams.set('n', String(n));
+ u.searchParams.set('validate-n', String(validateN));
+ u.searchParams.set('reps', String(reps));
+ return u.toString();
+}
+
+async function runOnce(headless) {
+ let browser;
+ try {
+ browser = await tryLaunch(headless);
+ const context = await browser.newContext({
+ viewport: { width: 900, height: 600 },
+ permissions: [],
+ bypassCSP: false,
+ });
+ const page = await context.newPage();
+ page.on('console', msg => {
+ const txt = msg.text();
+ if (!txt.startsWith('[vite]')) {
+ err(`[page:${msg.type()}] ${txt}`);
+ }
+ });
+ page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+ const navUrl = buildUrl();
+ err(`navigating to ${navUrl} (headless=${headless})`);
+ await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+ const hasWebGpu = await page.evaluate(() => 'gpu' in navigator);
+ if (!hasWebGpu) {
+ throw new Error('navigator.gpu missing in this Chromium instance');
+ }
+
+ err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`);
+ const t0 = Date.now();
+ await page.waitForFunction(
+ () => window.__bench?.state === 'done' || window.__bench?.state === 'error',
+ { timeout: timeoutMs, polling: 250 },
+ );
+ const elapsed = (Date.now() - t0) / 1000;
+ err(`bench reached terminal state in ${elapsed.toFixed(1)}s`);
+
+ const result = await page.evaluate(() => {
+ const b = window.__bench;
+ return {
+ state: b.state,
+ params: b.params,
+ result: JSON.parse(JSON.stringify(b.result)),
+ error: b.error,
+ log: b.log.slice(),
+ };
+ });
+ result.elapsedSec = elapsed;
+ return result;
+ } finally {
+ try {
+ if (browser) await browser.close();
+ } catch (e) {
+ err(`browser.close failed: ${e.message}`);
+ }
+ }
+}
+
+(async () => {
+ if (!(await reachable(baseUrl))) {
+ err(`dev server not reachable at ${new URL(baseUrl).origin}`);
+ err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`);
+ process.exit(3);
+ }
+
+ let result;
+ try {
+ result = await runOnce(!headed);
+ } catch (e) {
+ err(`run failed: ${e.message}`);
+ if (!headed) {
+ err('retrying in headed mode');
+ try {
+ result = await runOnce(false);
+ } catch (e2) {
+ err(`headed retry also failed: ${e2.message}`);
+ process.exit(4);
+ }
+ } else {
+ process.exit(4);
+ }
+ }
+
+ if (result.state === 'error') {
+ err(`page reported error: ${result.error}`);
+ if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+ process.exit(5);
+ }
+
+ if (!jsonOnly) {
+ out(`## BY divsteps WebGPU dispatch test`);
+ out('');
+ out(`- params: ${JSON.stringify(result.params)}`);
+ out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`);
+ out('');
+ const r = result.result;
+ if (r && r.timing) {
+ const t = r.timing;
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`);
+ out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms divsteps_calls/s=${t.divstepsPerSec.toExponential(3)}`);
+ } else if (r) {
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`);
+ for (const m of r.mismatches) {
+ out(` ${m.replace(/\n/g, '\n ')}`);
+ }
+ }
+ out('');
+ }
+ process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+
+ const failed = !result.result || !result.result.validateOk;
+ process.exit(failed ? 1 : 0);
+})().catch(e => {
+ err(`unexpected: ${e.stack ?? e.message ?? String(e)}`);
+ process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs
new file mode 100644
index 000000000000..d10b377d4255
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs
@@ -0,0 +1,256 @@
+#!/usr/bin/env node
+// Field-mul micro-benchmark driver. Launches Playwright Chromium with
+// WebGPU enabled, navigates to the standalone bench-field-mul.html page,
+// waits for `window.__bench.state === 'done' | 'error'`, prints results.
+//
+// Host-side safety: every flag value is validated before any browser
+// launch. The page itself also caps `n` at 2^23 and `k` at 100; this
+// script rejects anything outside those bounds before navigation.
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-field-mul.html';
+const N_MAX = 1 << 23;
+const K_MAX = 100;
+
+const { values: argv } = parseArgs({
+ options: {
+ path: { type: 'string', default: 'both' },
+ n: { type: 'string', default: '64' },
+ k: { type: 'string', default: '1' },
+ 'validate-n': { type: 'string' },
+ reps: { type: 'string', default: '3' },
+ url: { type: 'string', default: DEFAULT_URL_BASE },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '60' },
+ json: { type: 'boolean', default: false },
+ debug: { type: 'string' },
+ variant: { type: 'string', default: 'cios' },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) {
+ process.stdout.write(
+ `Field-mul WebGPU micro-benchmark driver
+
+Usage:
+ node dev/msm-webgpu/scripts/bench-field-mul.mjs [options]
+
+Options:
+ --path u32|f32|both (default both)
+ --n N (default 64, max ${N_MAX})
+ --k K (default 1, max ${K_MAX})
+ --validate-n N (default min(64, n))
+ --reps R (default 3)
+ --url URL (default ${DEFAULT_URL_BASE})
+ --headed Run with visible browser window
+ --timeout SECS Bench page completion timeout (default 60)
+ --json Machine-readable JSON only output
+ --variant V u32: cios|karat. f32: sos3uv3. (default cios)
+ --help Show this help
+`,
+ );
+ process.exit(0);
+}
+
+// Print a flag-validation error and exit with status 2. Every check below runs
+// before any browser launch.
+function usageError(msg) {
+  process.stderr.write(`error: ${msg}\n`);
+  process.exit(2);
+}
+const path = String(argv.path);
+if (path !== 'u32' && path !== 'f32' && path !== 'both') {
+  usageError(`--path must be u32|f32|both, got "${path}"`);
+}
+const n = parseInt(String(argv.n), 10);
+if (!Number.isFinite(n) || n <= 0 || n > N_MAX) usageError(`--n must be in (0, ${N_MAX}], got ${argv.n}`);
+const k = parseInt(String(argv.k), 10);
+if (!Number.isFinite(k) || k <= 0 || k > K_MAX) usageError(`--k must be in (0, ${K_MAX}], got ${argv.k}`);
+const validateN =
+  argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n);
+if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+  usageError(`--validate-n must be in [0, n], got ${argv['validate-n']}`);
+}
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 100) usageError(`--reps must be in (0, 100], got ${argv.reps}`);
+const headed = Boolean(argv.headed);
+// A non-numeric --timeout parses to NaN, which would otherwise be handed to
+// Playwright as-is. 0 stays valid: Playwright treats it as "no timeout".
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+if (!Number.isFinite(timeoutMs) || timeoutMs < 0) {
+  usageError(`--timeout must be a number >= 0, got ${argv.timeout}`);
+}
+const jsonOnly = Boolean(argv.json);
+const baseUrl = String(argv.url);
+
+function err(msg) {
+ process.stderr.write(`[bench-field-mul] ${msg}\n`);
+}
+function out(msg) {
+ if (!jsonOnly) process.stdout.write(`${msg}\n`);
+}
+
+async function reachable(targetUrl) {
+ // Vite returns 404 for the bare `/` when there's no index at the root,
+ // so hit the actual bench page URL (GET; HEAD is sometimes blocked by
+ // middleware too). Any status in [200, 500) counts as "server is up";
+ // a 5xx response or a thrown network/connect-refused error does not.
+ try {
+ const res = await fetch(targetUrl, { method: 'GET' });
+ return res.status >= 200 && res.status < 500;
+ } catch {
+ return false;
+ }
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+async function tryLaunch(headless) {
+ return chromium.launch({ headless, args: CHROMIUM_ARGS });
+}
+
+function buildUrl() {
+ const u = new URL(baseUrl);
+ u.searchParams.set('path', path);
+ u.searchParams.set('n', String(n));
+ u.searchParams.set('k', String(k));
+ u.searchParams.set('validate-n', String(validateN));
+ u.searchParams.set('reps', String(reps));
+ u.searchParams.set('variant', String(argv.variant));
+ if (argv.debug !== undefined) {
+ u.searchParams.set('debug', String(argv.debug));
+ }
+ return u.toString();
+}
+
+async function runOnce(headless) {
+ let browser;
+ try {
+ browser = await tryLaunch(headless);
+ const context = await browser.newContext({
+ viewport: { width: 900, height: 600 },
+ permissions: [],
+ bypassCSP: false,
+ });
+ const page = await context.newPage();
+ page.on('console', msg => {
+ const txt = msg.text();
+ if (!txt.startsWith('[vite]')) {
+ err(`[page:${msg.type()}] ${txt}`);
+ }
+ });
+ page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+ const navUrl = buildUrl();
+ err(`navigating to ${navUrl} (headless=${headless})`);
+ await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+ const hasWebGpu = await page.evaluate(() => 'gpu' in navigator);
+ if (!hasWebGpu) {
+ throw new Error('navigator.gpu missing in this Chromium instance');
+ }
+
+ err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`);
+ const t0 = Date.now();
+ await page.waitForFunction(
+ () => window.__bench?.state === 'done' || window.__bench?.state === 'error',
+ { timeout: timeoutMs, polling: 250 },
+ );
+ const elapsed = (Date.now() - t0) / 1000;
+ err(`bench reached terminal state in ${elapsed.toFixed(1)}s`);
+
+ const result = await page.evaluate(() => {
+ const b = window.__bench;
+ return {
+ state: b.state,
+ params: b.params,
+ results: JSON.parse(JSON.stringify(b.results)),
+ error: b.error,
+ log: b.log.slice(),
+ };
+ });
+ result.elapsedSec = elapsed;
+ return result;
+ } finally {
+ try {
+ if (browser) await browser.close();
+ } catch (e) {
+ err(`browser.close failed: ${e.message}`);
+ }
+ }
+}
+
+(async () => {
+ if (!(await reachable(baseUrl))) {
+ err(`dev server not reachable at ${new URL(baseUrl).origin}`);
+ err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`);
+ process.exit(3);
+ }
+
+ let result;
+ try {
+ result = await runOnce(!headed);
+ } catch (e) {
+ err(`run failed: ${e.message}`);
+ if (!headed) {
+ err('retrying in headed mode');
+ try {
+ result = await runOnce(false);
+ } catch (e2) {
+ err(`headed retry also failed: ${e2.message}`);
+ process.exit(4);
+ }
+ } else {
+ process.exit(4);
+ }
+ }
+
+ if (result.state === 'error') {
+ err(`page reported error: ${result.error}`);
+ if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+ process.exit(5);
+ }
+
+ if (!jsonOnly) {
+ out(`## field-mul micro-benchmark`);
+ out('');
+ out(`- params: ${JSON.stringify(result.params)}`);
+ out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`);
+ out('');
+ out('| path | validate ok | reps | median ms | min ms | max ms | mults/s |');
+ out('|------|-------------|-----:|----------:|-------:|-------:|------------:|');
+ for (const r of result.results) {
+ const t = r.timing;
+ const okStr = r.validateOk ? 'OK' : 'FAIL';
+ if (t) {
+ out(
+ `| ${r.path.padEnd(4)} | ${okStr.padEnd(11)} | ${String(t.reps).padStart(4)} | ` +
+ `${t.msMedian.toFixed(3).padStart(9)} | ${t.msMin.toFixed(3).padStart(6)} | ${t.msMax.toFixed(3).padStart(6)} | ${t.multsPerSec.toExponential(3).padStart(11)} |`,
+ );
+ } else {
+ out(`| ${r.path.padEnd(4)} | ${okStr.padEnd(11)} | (no timing — validation failed) |`);
+ for (const m of r.mismatches) {
+ out(` ${m.replace(/\n/g, '\n ')}`);
+ }
+ }
+ }
+ out('');
+ }
+ process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+
+ const anyFailure = result.results.some(r => !r.validateOk);
+ process.exit(anyFailure ? 1 : 0);
+})().catch(e => {
+ err(`unexpected: ${e.stack ?? e.message ?? String(e)}`);
+ process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs
new file mode 100755
index 000000000000..7e36eb607b0b
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs
@@ -0,0 +1,239 @@
+#!/usr/bin/env node
+// fr_inv WebGPU dispatch test driver. Launches Playwright Chromium with
+// WebGPU enabled, navigates to the standalone bench-fr-inv.html page,
+// waits for `window.__bench.state === 'done' | 'error'`, prints results.
+// Mirrors the bench-divsteps.mjs / bench-apply-matrix.mjs structure.
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-fr-inv.html';
+const N_MAX = 1 << 20;
+const K_MAX = 100;
+const VARIANTS = new Set(['fr_inv', 'fr_inv_by', 'fr_inv_by_a', 'fr_pow_inv']);
+
+const { values: argv } = parseArgs({
+ options: {
+ n: { type: 'string', default: '1024' },
+ k: { type: 'string', default: '1' },
+ 'validate-n': { type: 'string' },
+ reps: { type: 'string', default: '1' },
+ variant: { type: 'string', default: 'fr_inv_by' },
+ url: { type: 'string', default: DEFAULT_URL_BASE },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '120' },
+ json: { type: 'boolean', default: false },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) {
+ process.stdout.write(
+ `fr_inv WebGPU dispatch test driver
+
+Usage:
+ node dev/msm-webgpu/scripts/bench-fr-inv.mjs [options]
+
+Options:
+ --n N (default 1024, max ${N_MAX})
+ --k K (default 1, max ${K_MAX})
+ --validate-n N (default min(64, n))
+ --reps R (default 1)
+ --variant V 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv' (default fr_inv_by)
+ --url URL (default ${DEFAULT_URL_BASE})
+ --headed Run with visible browser window
+ --timeout SECS Bench page completion timeout (default 120)
+ --json Machine-readable JSON only output
+ --help Show this help
+`,
+ );
+ process.exit(0);
+}
+
+// ---- CLI argument validation --------------------------------------------
+// Every numeric flag arrives as a string from parseArgs. Each one is parsed
+// and range-checked here; a bad value prints a one-line error to stderr and
+// exits with status 2 before any browser is launched.
+
+// --n: must lie in (0, N_MAX].
+const n = parseInt(String(argv.n), 10);
+if (!Number.isFinite(n) || n <= 0 || n > N_MAX) {
+  process.stderr.write(`error: --n must be in (0, ${N_MAX}], got ${argv.n}\n`);
+  process.exit(2);
+}
+// --k: must lie in (0, K_MAX].
+const k = parseInt(String(argv.k), 10);
+if (!Number.isFinite(k) || k <= 0 || k > K_MAX) {
+  process.stderr.write(`error: --k must be in (0, ${K_MAX}], got ${argv.k}\n`);
+  process.exit(2);
+}
+// --validate-n: defaults to min(64, n) and may never exceed n.
+const validateN =
+  argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n);
+if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) {
+  process.stderr.write(`error: --validate-n must be in [0, n], got ${argv['validate-n']}\n`);
+  process.exit(2);
+}
+// --reps: must lie in (0, 100].
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 100) {
+  process.stderr.write(`error: --reps must be in (0, 100], got ${argv.reps}\n`);
+  process.exit(2);
+}
+// --variant: must be one of the names listed in VARIANTS.
+const variant = String(argv.variant);
+if (!VARIANTS.has(variant)) {
+  process.stderr.write(
+    `error: --variant must be one of ${[...VARIANTS].join(', ')}, got '${variant}'\n`,
+  );
+  process.exit(2);
+}
+const headed = Boolean(argv.headed);
+// --timeout is given in seconds and handed to Playwright in milliseconds.
+// Reject NaN / negative values up front so a typo such as `--timeout 2m`
+// cannot reach the wait call as NaN. Zero stays legal.
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+if (!Number.isFinite(timeoutMs) || timeoutMs < 0) {
+  process.stderr.write(`error: --timeout must be a non-negative number of seconds, got ${argv.timeout}\n`);
+  process.exit(2);
+}
+const jsonOnly = Boolean(argv.json);
+const baseUrl = String(argv.url);
+
+// Diagnostic channel: one tagged line on stderr, so stdout stays reserved
+// for the report / JSON payload.
+function err(msg) {
+  const line = `[bench-fr-inv] ${msg}\n`;
+  process.stderr.write(line);
+}
+// Human-readable report line on stdout; silent when --json was requested.
+function out(msg) {
+  if (jsonOnly) return;
+  process.stdout.write(`${msg}\n`);
+}
+
+/**
+ * Probe whether something answers HTTP at `targetUrl`.
+ *
+ * Resolves to true for any response with status in [200, 500). Resolves to
+ * false when the request fails (bad URL, connection refused, ...) or when no
+ * response arrives within the probe deadline. Never rejects.
+ */
+async function reachable(targetUrl) {
+  try {
+    // Bound the probe: without a deadline, a listener that accepts the
+    // connection but never responds would stall the driver indefinitely.
+    const res = await fetch(targetUrl, { method: 'GET', signal: AbortSignal.timeout(10_000) });
+    return res.status >= 200 && res.status < 500;
+  } catch {
+    return false;
+  }
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+// Start Chromium with the CHROMIUM_ARGS flag set; `headless` selects
+// windowless (true) or visible (false) mode.
+async function tryLaunch(headless) {
+  const launchOptions = { headless, args: CHROMIUM_ARGS };
+  return chromium.launch(launchOptions);
+}
+
+// Compose the bench page URL: the base URL plus the validated CLI values as
+// query parameters (n, k, validate-n, reps, variant — set in that order).
+function buildUrl() {
+  const target = new URL(baseUrl);
+  const query = [
+    ['n', n],
+    ['k', k],
+    ['validate-n', validateN],
+    ['reps', reps],
+    ['variant', variant],
+  ];
+  for (const [key, value] of query) {
+    target.searchParams.set(key, String(value));
+  }
+  return target.toString();
+}
+
+/**
+ * Run the bench page once in a fresh Chromium instance.
+ *
+ * Navigates to the URL from buildUrl(), waits until `window.__bench.state`
+ * is 'done' or 'error', and returns a snapshot
+ * `{ state, params, result, error, log, elapsedSec }`.
+ *
+ * @param headless  true to launch without a window, false for a visible one.
+ * @throws when launch or navigation fails, `navigator.gpu` is absent, or the
+ *         page does not reach a terminal state within `timeoutMs`.
+ *
+ * The browser is closed on the way out, whether the run succeeded or threw.
+ */
+async function runOnce(headless) {
+  let browser;
+  try {
+    browser = await tryLaunch(headless);
+    const context = await browser.newContext({
+      viewport: { width: 900, height: 600 },
+      permissions: [],
+      bypassCSP: false,
+    });
+    const page = await context.newPage();
+    // Mirror page console output to stderr, dropping Vite's own chatter.
+    page.on('console', msg => {
+      const txt = msg.text();
+      if (!txt.startsWith('[vite]')) {
+        err(`[page:${msg.type()}] ${txt}`);
+      }
+    });
+    page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+    const navUrl = buildUrl();
+    err(`navigating to ${navUrl} (headless=${headless})`);
+    await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+    const hasWebGpu = await page.evaluate(() => 'gpu' in navigator);
+    if (!hasWebGpu) {
+      throw new Error('navigator.gpu missing in this Chromium instance');
+    }
+
+    err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`);
+    const t0 = Date.now();
+    await page.waitForFunction(
+      () => window.__bench?.state === 'done' || window.__bench?.state === 'error',
+      { timeout: timeoutMs, polling: 250 },
+    );
+    const elapsed = (Date.now() - t0) / 1000;
+    err(`bench reached terminal state in ${elapsed.toFixed(1)}s`);
+
+    const result = await page.evaluate(() => {
+      const b = window.__bench;
+      return {
+        state: b.state,
+        params: b.params,
+        // JSON.stringify(undefined) is undefined and JSON.parse(undefined)
+        // throws, so a page that ends in 'error' without ever setting
+        // `result` would surface as a run failure instead of a page error.
+        // The caller already copes with a missing result, so pass null.
+        result: b.result == null ? null : JSON.parse(JSON.stringify(b.result)),
+        error: b.error,
+        log: Array.isArray(b.log) ? b.log.slice() : [],
+      };
+    });
+    result.elapsedSec = elapsed;
+    return result;
+  } finally {
+    // Best-effort shutdown: a close failure is logged, never rethrown, so it
+    // cannot mask the outcome of the run itself.
+    try {
+      if (browser) await browser.close();
+    } catch (e) {
+      err(`browser.close failed: ${e.message}`);
+    }
+  }
+}
+
+// Entry point. Exit statuses used below:
+//   0  bench ran and validation passed
+//   1  bench ran but validation failed (or no result object was reported)
+//   3  dev server not reachable
+//   4  the browser run threw (after the headed retry, when one applies)
+//   5  the page itself reported state === 'error'
+//   99 any other unexpected exception
+(async () => {
+ if (!(await reachable(baseUrl))) {
+ err(`dev server not reachable at ${new URL(baseUrl).origin}`);
+ err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`);
+ process.exit(3);
+ }
+
+ // First attempt honours --headed; a failed headless attempt is retried
+ // once with a visible window before giving up.
+ let result;
+ try {
+ result = await runOnce(!headed);
+ } catch (e) {
+ err(`run failed: ${e.message}`);
+ if (!headed) {
+ err('retrying in headed mode');
+ try {
+ result = await runOnce(false);
+ } catch (e2) {
+ err(`headed retry also failed: ${e2.message}`);
+ process.exit(4);
+ }
+ } else {
+ process.exit(4);
+ }
+ }
+
+ // Page-level failure: the JSON snapshot is only echoed in --json mode here.
+ if (result.state === 'error') {
+ err(`page reported error: ${result.error}`);
+ if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+ process.exit(5);
+ }
+
+ // Human-readable summary (skipped entirely with --json).
+ if (!jsonOnly) {
+ out(`## fr_inv WebGPU dispatch test (variant=${variant})`);
+ out('');
+ out(`- params: ${JSON.stringify(result.params)}`);
+ out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`);
+ out('');
+ const r = result.result;
+ if (r && r.timing) {
+ const t = r.timing;
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`);
+ out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms inversions/s=${t.inversionsPerSec.toExponential(3)}`);
+ } else if (r) {
+ // A result without timing is reported together with its mismatch list.
+ out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`);
+ for (const m of r.mismatches) {
+ out(` ${m.replace(/\n/g, '\n ')}`);
+ }
+ }
+ out('');
+ }
+ // The full JSON snapshot is always written to stdout on this path.
+ process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+
+ const failed = !result.result || !result.result.validateOk;
+ process.exit(failed ? 1 : 0);
+})().catch(e => {
+ err(`unexpected: ${e.stack ?? e.message ?? String(e)}`);
+ process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs
new file mode 100644
index 000000000000..c2c9eb94e83b
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs
@@ -0,0 +1,189 @@
+#!/usr/bin/env node
+// Drive the MSM dev page's Quick Sanity Check button via Playwright,
+// scrape window.__sanity.capture for per-stage GPU times, and print a
+// compact roll-up. Used to bench WPB / BPR-multi-window variants
+// without rebuilding the full bench/sweep UI flow.
+//
+// Outputs JSON like (plus "adapter", "gpu_wall" and "profiled_sum" entries):
+// {
+// "reps": 5,
+// "ms_total": { "samples": [...], "median": ... },
+// "per_stage": { "bpr_1": { "samples": [...], "median": ... }, ... }
+// }
+
+import { chromium } from 'playwright-core';
+import { parseArgs } from 'node:util';
+
+const DEFAULT_URL = 'http://localhost:5195/dev/msm-webgpu/index.html';
+
+const { values: argv } = parseArgs({
+ options: {
+ url: { type: 'string', default: DEFAULT_URL },
+ reps: { type: 'string', default: '5' },
+ headed: { type: 'boolean', default: false },
+ timeout: { type: 'string', default: '120' },
+ help: { type: 'boolean', default: false },
+ },
+ allowPositionals: false,
+});
+
+if (argv.help) {
+ process.stdout.write(
+ `Profile-sanity WebGPU MSM driver
+Usage:
+ node dev/msm-webgpu/scripts/profile-sanity.mjs [options]
+Options:
+ --url URL (default ${DEFAULT_URL})
+ --reps R (default 5)
+ --headed Show browser window
+ --timeout SECS Per-sanity-click timeout (default 120)
+ --help Show this help
+`,
+ );
+ process.exit(0);
+}
+
+// ---- CLI argument validation --------------------------------------------
+// A bad value prints a one-line error to stderr and exits with status 2
+// before any browser is launched.
+
+// --reps: number of sanity-check clicks, must lie in (0, 20].
+const reps = parseInt(String(argv.reps), 10);
+if (!Number.isFinite(reps) || reps <= 0 || reps > 20) {
+  process.stderr.write(`error: --reps must be in (0,20], got ${argv.reps}\n`);
+  process.exit(2);
+}
+const headed = Boolean(argv.headed);
+// --timeout is given in seconds and handed to Playwright in milliseconds.
+// Reject NaN / negative values up front so a typo such as `--timeout 2m`
+// cannot reach the wait call as NaN. Zero stays legal.
+const timeoutMs = parseFloat(String(argv.timeout)) * 1000;
+if (!Number.isFinite(timeoutMs) || timeoutMs < 0) {
+  process.stderr.write(`error: --timeout must be a non-negative number of seconds, got ${argv.timeout}\n`);
+  process.exit(2);
+}
+const baseUrl = String(argv.url);
+
+function err(msg) {
+ process.stderr.write(`[profile-sanity] ${msg}\n`);
+}
+
+const CHROMIUM_ARGS = [
+ '--enable-unsafe-webgpu',
+ '--enable-features=Vulkan,WebGPU',
+ '--use-angle=metal',
+ '--disable-features=ServiceWorker',
+ '--ignore-gpu-blocklist',
+];
+
+// Strip everything from the first '[' onwards, so labels that differ only
+// in a bracketed suffix collapse onto one stage name. Labels without a
+// bracket are returned unchanged.
+function rollupLabel(label) {
+  const bracketAt = label.indexOf('[');
+  if (bracketAt < 0) return label;
+  return label.substring(0, bracketAt);
+}
+
+// Reduce one sanity-run capture to per-stage totals plus the GPU readback
+// figures. Profile entries are summed under their rolled-up label; entries
+// of kind 'region' are left out because they overlap other entries and
+// would double-count. A capture without a profile yields empty stages and
+// null readback fields.
+function aggregateCapture(capture) {
+  if (!capture?.profile) {
+    return { stages: {}, gpuWall: null, profiledSum: null, untimestamped: null };
+  }
+
+  const perStage = new Map();
+  for (const entry of capture.profile) {
+    if (entry.kind !== 'region') {
+      const stage = rollupLabel(entry.label);
+      const soFar = perStage.get(stage) ?? 0;
+      perStage.set(stage, soFar + entry.ms);
+    }
+  }
+
+  const readback = capture.gpu_readback;
+  return {
+    stages: Object.fromEntries(perStage),
+    gpuWall: readback?.gpu_compute_wall ?? null,
+    profiledSum: readback?.profiled_passes_sum ?? null,
+    untimestamped: readback?.untimestamped ?? null,
+  };
+}
+
+// Median of a numeric array: null for an empty array, the middle element
+// for odd lengths, the mean of the two middle elements for even lengths.
+// Works on a sorted copy, so the caller's array is left untouched.
+function median(arr) {
+  const count = arr.length;
+  if (!count) return null;
+  const sorted = Array.from(arr).sort((lo, hi) => lo - hi);
+  const mid = Math.floor(count / 2);
+  if (count % 2 !== 0) return sorted[mid];
+  return (sorted[mid - 1] + sorted[mid]) / 2;
+}
+
+/**
+ * Drive the MSM dev page: launch Chromium, wait for the sanity button to
+ * become enabled, click it `reps` times, and write a JSON roll-up
+ * (`reps`, `adapter`, `ms_total`, `gpu_wall`, `profiled_sum`, `per_stage`)
+ * to stdout.
+ *
+ * If the page reports `state === 'error'` for any rep, the run stops, the
+ * process exit code is set to 5 and no JSON is printed. Other failures
+ * (launch, navigation, timeouts) propagate as rejections to the caller.
+ * The browser is closed on every path out of this function.
+ */
+async function run() {
+  const browser = await chromium.launch({ headless: !headed, args: CHROMIUM_ARGS });
+  try {
+    const context = await browser.newContext({ viewport: { width: 1100, height: 800 }, bypassCSP: false });
+    const page = await context.newPage();
+    // Mirror page console output to stderr, dropping Vite's own chatter.
+    page.on('console', msg => {
+      const t = msg.text();
+      if (!t.startsWith('[vite]')) err(`[page:${msg.type()}] ${t}`);
+    });
+    page.on('pageerror', e => err(`[page:pageerror] ${e.message}`));
+
+    err(`navigating to ${baseUrl} (headless=${!headed})`);
+    await page.goto(baseUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 });
+
+    // Probe adapter info so the operator can verify Metal/Apple.
+    const adapter = await page.evaluate(async () => {
+      if (!('gpu' in navigator)) return null;
+      const a = await navigator.gpu.requestAdapter();
+      if (!a) return null;
+      // adapter.info is the modern API; fallback to requestAdapterInfo() for older.
+      // eslint-disable-next-line no-undef
+      const info = a.info ?? (await a.requestAdapterInfo?.());
+      return info ? { vendor: info.vendor, architecture: info.architecture, device: info.device, description: info.description } : null;
+    });
+    err(`adapter: ${JSON.stringify(adapter)}`);
+
+    // Wait for sanity button to be enabled (SRS loaded).
+    await page.waitForFunction(
+      () => {
+        const b = document.getElementById('run-sanity');
+        return b && !b.disabled;
+      },
+      { timeout: 60_000, polling: 200 },
+    );
+    err('sanity button enabled (SRS loaded)');
+
+    const samples = []; // per-rep aggregates
+    for (let r = 0; r < reps; r++) {
+      err(`rep ${r + 1}/${reps}: clicking Quick sanity check`);
+      // Reset state and click.
+      await page.evaluate(() => {
+        window.__sanity = { state: 'running' };
+      });
+      await page.click('#run-sanity');
+      await page.waitForFunction(
+        () => window.__sanity?.state === 'done' || window.__sanity?.state === 'error',
+        { timeout: timeoutMs, polling: 250 },
+      );
+      const out = await page.evaluate(() => JSON.parse(JSON.stringify(window.__sanity)));
+      if (out.state === 'error') {
+        err(`rep ${r + 1} ERROR: ${out.error}`);
+        // Set the exit code and return rather than calling process.exit(5):
+        // process.exit terminates immediately and would skip the `finally`
+        // below, leaving the browser shutdown to chance.
+        process.exitCode = 5;
+        return;
+      }
+      const agg = aggregateCapture(out.capture);
+      agg.ms = out.ms;
+      samples.push(agg);
+      const bpr1 = agg.stages?.bpr_1;
+      err(`rep ${r + 1} ok: ms=${out.ms?.toFixed(1)} bpr_1=${bpr1?.toFixed?.(2) ?? '?'} gpuWall=${agg.gpuWall?.toFixed?.(2) ?? '?'}`);
+    }
+
+    // Aggregate per-stage medians + ms_total median.
+    const stageNames = new Set();
+    for (const s of samples) for (const k of Object.keys(s.stages)) stageNames.add(k);
+    const per_stage = {};
+    for (const k of stageNames) {
+      const xs = samples.map(s => s.stages[k]).filter(v => Number.isFinite(v));
+      per_stage[k] = { samples: xs, median: median(xs) };
+    }
+    // Keep only finite values, exactly as for gpu_wall / profiled_sum: a rep
+    // whose `ms` is missing would otherwise feed `undefined` into the sort
+    // comparator and make the median meaningless.
+    const ms_total = samples.map(s => s.ms).filter(v => Number.isFinite(v));
+    const gpu_wall = samples.map(s => s.gpuWall).filter(v => Number.isFinite(v));
+    const profiled_sum = samples.map(s => s.profiledSum).filter(v => Number.isFinite(v));
+    const result = {
+      reps,
+      adapter,
+      ms_total: { samples: ms_total, median: median(ms_total) },
+      gpu_wall: { samples: gpu_wall, median: median(gpu_wall) },
+      profiled_sum: { samples: profiled_sum, median: median(profiled_sum) },
+      per_stage,
+    };
+    process.stdout.write(JSON.stringify(result, null, 2) + '\n');
+  } finally {
+    // Best-effort shutdown: a close failure is logged, never rethrown.
+    try {
+      await browser.close();
+    } catch (e) {
+      err(`browser.close failed: ${e.message}`);
+    }
+  }
+}
+
+run().catch(e => {
+ err(`fatal: ${e.stack ?? e.message ?? String(e)}`);
+ process.exit(99);
+});
diff --git a/barretenberg/ts/dev/msm-webgpu/tsconfig.json b/barretenberg/ts/dev/msm-webgpu/tsconfig.json
new file mode 100644
index 000000000000..edf718d1c43b
--- /dev/null
+++ b/barretenberg/ts/dev/msm-webgpu/tsconfig.json
@@ -0,0 +1,10 @@
+{
+ "extends": "../../tsconfig.json",
+ "compilerOptions": {
+ "rootDir": ".",
+ "composite": false,
+ "ignoreDeprecations": "6.0",
+ "types": ["@webgpu/types", "vite/client"]
+ },
+ "include": ["**/*.ts"]
+}
diff --git a/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts b/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts
index 821081266812..de4229968626 100644
--- a/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts
+++ b/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts
@@ -515,5 +515,6 @@ export async function runAllWgslUnitTests(): Promise {
results.push(await testTransposeAtChunkSize(15, 256));
results.push(await testTransposeAtChunkSize(4, 256));
results.push(await testTransposeAtChunkSize(16, 256));
+
return results;
}
diff --git a/barretenberg/ts/package.json b/barretenberg/ts/package.json
index f3f3d4a1e5bf..5fd3c8371a14 100644
--- a/barretenberg/ts/package.json
+++ b/barretenberg/ts/package.json
@@ -99,6 +99,7 @@
"eslint": "^9.26.0",
"eslint-config-prettier": "^10.1.5",
"jest": "^30.0.0",
+ "playwright-core": "^1.59.1",
"prettier": "^3.5.3",
"ts-jest": "^29.4.0",
"ts-loader": "^9.4.2",
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts
index 93b3738d7c41..6f926ba6f514 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts
@@ -140,6 +140,26 @@ export const smvp_batch_affine_gpu = async (
const num_words = shaderManager.num_words;
const limb_byte_length = num_words * 4;
+ // WINDOWS_PER_BATCH pools the pair pools of WPB consecutive subtasks
+ // into ONE fr_inv_by_a per (batch, sub_wg) — see
+ // batch_inverse_parallel.template.wgsl. With T=17 subtasks at logN=16
+ // and WPB=4, num_batches = ceil(17/4) = 5, so a round runs 5×W
+ // workgroups in the inverse pass instead of 17×W (~3.4× fewer fr_inv
+ // calls overall). The pair-pool layout is per-subtask, unchanged; the
+ // pooling only changes how the inverse kernel reads/writes scratch.
+ // WPB=1 falls back to byte-identical pre-pooling behaviour.
+ //
+ // Default is 1: the per-pass profiler at logN=16 on Apple Silicon
+ // showed ba_inverse Σ ≈ 14% of MSM, so pooling buys little relative
+ // to BPR (51%). The plumbing stays in place so the knob is one edit
+ // away once BPR is no longer the dominant cost.
+ const WINDOWS_PER_BATCH = 1;
+ const num_batches = Math.ceil(num_subtasks / WINDOWS_PER_BATCH);
+ // count_buf (round_count + finalize_count) must be safely loadable
+ // for every slot the inverse kernel reads — i.e. through the tail of
+ // the last batch.
+ const count_slots_padded = num_batches * WINDOWS_PER_BATCH;
+
// ----- Workspace buffers -----
// When a persistent context is provided, all per-MSM workspace buffers
// are pulled from `context.acquirePersistentBuffer` and survive across
@@ -190,7 +210,12 @@ export const smvp_batch_affine_gpu = async (
// THIS round's inverse + apply. Decouples the schedule writer from
// the inverse/apply readers so we can zero pair_counter in dispatch_args
// without losing the values inverse/apply still need.
- const round_count_sb = acquire_ws('round_count', num_subtasks * 4);
+ //
+ // Padded to num_batches × WPB so the inverse kernel's per-batch WPB-
+ // wide atomicLoad scan can safely touch the tail of the last batch
+ // when num_subtasks isn't a multiple of WPB. Persistent buffers come
+ // zero-initialised; nothing else writes those tail slots.
+ const round_count_sb = acquire_ws('round_count', count_slots_padded * 4);
// Indirect-dispatch arg buffer. Layout: 9 u32s (36 B):
// args[0..3] = schedule (X, Y, Z) — for the *next* round
// args[3..6] = inverse (X, Y, Z) — for *this* round
@@ -216,9 +241,15 @@ export const smvp_batch_affine_gpu = async (
// Contents are constant across MSM calls, so we write once on first
// acquisition and skip the writeBuffer thereafter.
let finalize_count_sb: GPUBuffer;
+ // Sized to count_slots_padded so the inverse's per-batch WPB-wide scan
+ // sees zeros for tail slots past num_subtasks (which the kernel
+ // skips by treating wg_counts[w]=0 as an empty subtask). The first
+ // num_subtasks slots carry half_num_columns each.
+ const finalize_counts_arr = new Uint32Array(count_slots_padded);
+ for (let i = 0; i < num_subtasks; i++) finalize_counts_arr[i] = half_num_columns;
if (context !== undefined) {
- finalize_count_sb = context.acquirePersistentBuffer(`${ws_key}:finalize_count`, num_subtasks * 4);
- // Re-write the constants every time (cheap — T u32s = 64 B at T=16)
+ finalize_count_sb = context.acquirePersistentBuffer(`${ws_key}:finalize_count`, count_slots_padded * 4);
+ // Re-write the constants every time (cheap — at most ~20 u32s = 80 B)
// rather than tracking a "first call" flag. Avoids a subtle bug where
// a bigger N then smaller N would leave stale values in the cached
// buffer (acquirePersistentBuffer resets identity on size change but
@@ -226,12 +257,12 @@ export const smvp_batch_affine_gpu = async (
device.queue.writeBuffer(
finalize_count_sb,
0,
- new Uint8Array(new Uint32Array(new Array(num_subtasks).fill(half_num_columns)).buffer),
+ new Uint8Array(finalize_counts_arr.buffer, finalize_counts_arr.byteOffset, finalize_counts_arr.byteLength),
);
} else {
finalize_count_sb = create_and_write_sb(
device,
- new Uint8Array(new Uint32Array(new Array(num_subtasks).fill(half_num_columns)).buffer),
+ new Uint8Array(finalize_counts_arr.buffer, finalize_counts_arr.byteOffset, finalize_counts_arr.byteLength),
);
}
@@ -256,21 +287,20 @@ export const smvp_batch_affine_gpu = async (
const apply_workgroup_size = 64;
const finalize_workgroup_size_default = 256;
- // Inverse pass parallelism: each subtask's pair pool is split across
- // NUM_SUB_WGS_PER_SUBTASK contiguous slices, with one workgroup of
- // TPB=64 threads inverting each slice independently (its own fr_inv).
+ // Inverse pass parallelism: each (batch, sub_wg) inverts a contiguous
+ // slice of the merged pool. NUM_SUB_WGS_PER_SUBTASK workgroups of
+ // TPB=64 threads each handle one slice independently (own fr_inv).
// Drops per-thread sequential bs by W× → W× faster Phase A/D at large
// N. The W extra fr_invs run concurrently across SMs (zero wall-time
- // cost). W=8 keeps SM occupancy high (T=16 × W=8 = 128 in-flight
- // workgroups, matching RTX-class GPU SM counts).
+ // cost).
const NUM_SUB_WGS_PER_SUBTASK = 8;
const init_shader = shaderManager.gen_batch_affine_init_shader(init_workgroup_size);
const schedule_shader = shaderManager.gen_batch_affine_schedule_shader(schedule_workgroup_size);
- const dispatch_args_shader = shaderManager.gen_batch_affine_dispatch_args_shader();
- // Multi-workgroup batch-inverse (W sub-WGs per subtask). See
+ const dispatch_args_shader = shaderManager.gen_batch_affine_dispatch_args_shader(WINDOWS_PER_BATCH);
+ // Multi-workgroup batch-inverse (W sub-WGs per batch). See
// batch_inverse_parallel.template.wgsl for the algorithm.
- const inverse_shader = shaderManager.gen_batch_inverse_parallel_shader(NUM_SUB_WGS_PER_SUBTASK);
+ const inverse_shader = shaderManager.gen_batch_inverse_parallel_shader(NUM_SUB_WGS_PER_SUBTASK, WINDOWS_PER_BATCH);
const apply_shader = shaderManager.gen_batch_affine_apply_scatter_shader(apply_workgroup_size);
// Finalize dispatch geometry — single dispatch covers all T·h IDs;
// z workgroups = num_subtasks, threading inside the kernel mirrors
@@ -350,7 +380,7 @@ export const smvp_batch_affine_gpu = async (
],
inverse_shader,
context,
- `bn254:batch_affine_inverse:parallel-v3-pitch:${num_columns}:${input_size}`,
+ `bn254:batch_affine_inverse:parallel-v4-wpb${WINDOWS_PER_BATCH}:${num_columns}:${input_size}`,
);
const apply_pipe = await compile_pipeline_for(
@@ -381,7 +411,7 @@ export const smvp_batch_affine_gpu = async (
],
dispatch_args_shader,
context,
- `bn254:batch_affine_dispatch_args:v2-fwd:${num_subtasks}:${apply_workgroup_size}`,
+ `bn254:batch_affine_dispatch_args:v3-wpb${WINDOWS_PER_BATCH}:${num_subtasks}:${apply_workgroup_size}`,
);
const finalize_collect_pipe = await compile_pipeline_for(
@@ -532,7 +562,7 @@ export const smvp_batch_affine_gpu = async (
// Cache suffix bumped to "v2" so we don't accidentally reuse a cached
// bind group from the pre-fold v1 layout (same buffer count, different
// semantics — the v2 suffix tags the new wiring explicitly).
- const inverse_bg = acquire_bg('inverse_bg:v2', inverse_pipe.layout, [
+ const inverse_bg = acquire_bg(`inverse_bg:v3-wpb${WINDOWS_PER_BATCH}`, inverse_pipe.layout, [
pair_delta_sb,
pair_prefix_sb,
pair_inv_sb,
@@ -555,7 +585,7 @@ export const smvp_batch_affine_gpu = async (
// Suffix bumped to v2 by P2-clear: bind group now binds 4 buffers
// (pair_counter + round_count + dispatch_args + ub) instead of 3.
// W suffix dropped along with the P1 revert.
- const dispatch_args_bg = acquire_bg('dispatch_args_bg:v2', dispatch_args_pipe.layout, [
+ const dispatch_args_bg = acquire_bg(`dispatch_args_bg:v3-wpb${WINDOWS_PER_BATCH}`, dispatch_args_pipe.layout, [
pair_counter_sb,
round_count_sb,
dispatch_args_sb,
@@ -576,7 +606,7 @@ export const smvp_batch_affine_gpu = async (
finalize_ub,
]);
- const finalize_inverse_bg = acquire_bg('finalize_inverse_bg', inverse_pipe.layout, [
+ const finalize_inverse_bg = acquire_bg(`finalize_inverse_bg:v2-wpb${WINDOWS_PER_BATCH}`, inverse_pipe.layout, [
pair_delta_sb,
pair_prefix_sb,
pair_inv_sb,
@@ -786,17 +816,18 @@ export const smvp_batch_affine_gpu = async (
profiler?.stage('ba_finalize_collect'),
);
- // Pass B (batch_inverse): (W, 1, T) workgroups in parallel — each
- // subtask's h-element slice is split across W sub-WGs, each
- // independently inverting its sub-slice. Same kernel as the round
- // loop's inverse pass.
+ // Pass B (batch_inverse): (W, 1, num_batches) workgroups in parallel
+ // — each (batch, sub_wg) inverts the merged h-element slices of WPB
+ // consecutive subtasks with one fr_inv_by_a. Same kernel as the
+ // round loop's inverse pass; finalize_count_sb is pre-populated with
+ // half_num_columns for valid subtasks and zero for the padded tail.
await execute_pipeline(
commandEncoder,
inverse_pipe.pipeline,
finalize_inverse_bg,
NUM_SUB_WGS_PER_SUBTASK,
1,
- num_subtasks,
+ num_batches,
profiler?.stage('ba_finalize_inverse'),
);
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts
new file mode 100644
index 000000000000..8df1ebbfc6a0
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts
@@ -0,0 +1,366 @@
+// Jest unit tests for the multi-window batched-Pippenger host
+// reference (sub-step 2.1 of the WebGPU MSM rewrite plan).
+//
+// What this validates (the WASM cross-window batched-inverse pattern
+// the GPU shaders will emulate in Phase 2):
+//
+// 1. windowsPerBatch=1 with the new signed-Booth path equals the
+// existing unsigned-digit `batchAffineMSM`. Asserts that the
+// signed-digit recoding plus 32-bit schedule encoding produces a
+// mathematically identical MSM result.
+//
+// 2. windowsPerBatch=2 and =4 agree with windowsPerBatch=1 on the
+// same inputs. Confirms the cross-window pooled batched-inverse
+// is associativity-correct (changing inversion batch size never
+// changes the result).
+//
+// 3. Edge cases: all-zero scalars, single non-zero scalar, scalars
+// at p-1, scalars at c-bit boundary values (2^k - 1, 2^k, 2^k + 1).
+//
+// 4. Schedule encoding round-trip: encode → decode is identity;
+// dedup bits (29|30) round-trip as zero (the consumer rejects
+// non-zero dedup bits).
+//
+// 5. Signed-Booth recoding identity: reassembling the recoded digits
+// via Σ_w (sign[w] ? -mag[w] : mag[w]) * 2^(c*w) == scalar.
+
+import {
+ batchAffineMSM,
+ batchAffineMSMMultiWindow,
+ decodeScheduleEntry,
+ encodeScheduleEntry,
+ SCHEDULE_INDEX_MASK,
+ SCHEDULE_SIGN_BIT,
+ _testOnly,
+} from "./batch_affine_bn254.js";
+import {
+ BN254_SCALAR_FIELD,
+ BN254_ZERO,
+ bn254ScalarField,
+ doubleBn254Point,
+ msmBn254,
+ scalarMultBn254Point,
+ type Bn254Point,
+} from "./bn254.js";
+
+// Reproducible RNG (mirrors bench-field-mul.ts / bernstein_yang.test.ts).
+// Seeded 32-bit linear congruential generator. The seed is coerced to an
+// unsigned 32-bit value and a zero seed is replaced by 1. Each call of the
+// returned function advances the state and returns it as a uint32.
+const makeRng = (seed: number): (() => number) => {
+  const MULTIPLIER = 1664525;
+  const INCREMENT = 1013904223;
+  const seed32 = seed >>> 0;
+  let current = seed32 === 0 ? 1 : seed32;
+  const next = (): number => {
+    current = (Math.imul(current, MULTIPLIER) + INCREMENT) >>> 0;
+    return current;
+  };
+  return next;
+};
+
+// Random scalar in [0, p) via 8 × 32-bit limbs (>= 254 bits of entropy)
+// then reduced mod p. Plenty above the scalar field size.
+// Draw eight 32-bit outputs from `rng`, concatenate them most-significant
+// first into a 256-bit integer, and reduce it modulo BN254_SCALAR_FIELD.
+const randomScalar = (rng: () => number): bigint => {
+  let acc = 0n;
+  let limbsLeft = 8;
+  while (limbsLeft > 0) {
+    const limb = BigInt(rng() >>> 0);
+    acc = (acc << 32n) | limb;
+    limbsLeft -= 1;
+  }
+  return acc % BN254_SCALAR_FIELD;
+};
+
+// BN254 generator (matches the canonical hex used by other tests in
+// this directory: G_x = 1, G_y = 2).
+const BN254_G: Bn254Point = { x: 1n, y: 2n };
+
+// Build N test points as small multiples of G — cheap to construct,
+// and the random scalars we sample never collide with these multiples
+// modulo p (sampling space is ~2^254 vs N up to a few hundred).
+// Produce `n` points of the form k·G, consuming one `rng` output per point.
+// Each multiplier k is the low 31 bits of that output plus one, so it is
+// always at least 1.
+const buildPoints = (n: number, rng: () => number): Bn254Point[] => {
+  const pts = new Array<Bn254Point>(n);
+  for (const slot of pts.keys()) {
+    const multiplier = BigInt((rng() & 0x7fffffff) + 1);
+    pts[slot] = scalarMultBn254Point(BN254_G, multiplier);
+  }
+  return pts;
+};
+
+// Point equality. Two points at infinity are equal, infinity never equals
+// a finite point, and finite points are compared coordinate-wise.
+const pointEq = (a: Bn254Point, b: Bn254Point): boolean => {
+  const aInf = a.infinity === true;
+  const bInf = b.infinity === true;
+  if (aInf || bInf) {
+    return aInf && bInf;
+  }
+  return a.x === b.x && a.y === b.y;
+};
+
+describe("Schedule entry encoding (32-bit layout)", () => {
+ it("encode/decode identity for positive sign", () => {
+ const entry = encodeScheduleEntry(42, 0);
+ expect(entry).toBe(42);
+ const dec = decodeScheduleEntry(entry);
+ expect(dec.scalarIdx).toBe(42);
+ expect(dec.sign).toBe(0);
+ });
+
+ it("encode/decode identity for negative sign", () => {
+ const entry = encodeScheduleEntry(42, 1);
+ // bit 31 set + scalarIdx = 42. As unsigned 32-bit:
+ expect(entry).toBe((SCHEDULE_SIGN_BIT | 42) >>> 0);
+ const dec = decodeScheduleEntry(entry);
+ expect(dec.scalarIdx).toBe(42);
+ expect(dec.sign).toBe(1);
+ });
+
+ it("encode/decode at 29-bit index boundary", () => {
+ const entry = encodeScheduleEntry(SCHEDULE_INDEX_MASK, 1);
+ const dec = decodeScheduleEntry(entry);
+ expect(dec.scalarIdx).toBe(SCHEDULE_INDEX_MASK);
+ expect(dec.sign).toBe(1);
+ });
+
+ it("encode throws on out-of-range scalarIdx", () => {
+ expect(() =>
+ encodeScheduleEntry(SCHEDULE_INDEX_MASK + 1, 0),
+ ).toThrow();
+ });
+
+ it("decode throws when dedup bit 29 is set", () => {
+ const entry = (1 << 29) | 7;
+ expect(() => decodeScheduleEntry(entry)).toThrow();
+ });
+
+ it("decode throws when dedup bit 30 is set", () => {
+ const entry = (1 << 30) | 7;
+ expect(() => decodeScheduleEntry(entry)).toThrow();
+ });
+
+ it("zero entry decodes to (0, 0)", () => {
+ const dec = decodeScheduleEntry(0);
+ expect(dec.scalarIdx).toBe(0);
+ expect(dec.sign).toBe(0);
+ });
+});
+
+describe("Signed-Booth digit recoding identity", () => {
+ const CS = [2, 4, 8, 15, 16];
+ const SCALAR_BITS = 254;
+
+ for (const c of CS) {
+ it(`c=${c}: reassembled digits equal scalar (random + edge cases)`, () => {
+ // Match `batchAffineMSMMultiWindow`'s formula (+2 headroom for
+ // the signed-Booth carry). Without this the top window's sign
+ // bit can fall inside the scalar's MSB region at small c values
+ // (e.g. c=2 with scalar = p-1, MSB = bit 253 ⇒ top window sign
+ // bit = bit 253 ⇒ phantom carry past num_windows).
+ const numWindows = Math.ceil((SCALAR_BITS + 2) / c);
+ const rng = makeRng(0xc001 + c);
+ const samples: bigint[] = [
+ 0n,
+ 1n,
+ 2n,
+ (1n << BigInt(c)) - 1n,
+ 1n << BigInt(c),
+ (1n << BigInt(c)) + 1n,
+ (1n << BigInt(c - 1)) - 1n,
+ 1n << BigInt(c - 1),
+ BN254_SCALAR_FIELD - 1n,
+ BN254_SCALAR_FIELD - 7n,
+ ];
+ for (let i = 0; i < 32; i++) samples.push(randomScalar(rng));
+
+ for (const s of samples) {
+ const { magnitudes, signs } = _testOnly.recodeScalarBooth(
+ s,
+ c,
+ numWindows,
+ );
+ // Reassemble digit_w * 2^(c*w). All magnitudes fit in [0, 2^(c-1)].
+ let recon = 0n;
+ const shifts: bigint[] = new Array(numWindows);
+ for (let w = 0; w < numWindows; w++) {
+ shifts[w] = BigInt(c * w);
+ }
+ for (let w = 0; w < numWindows; w++) {
+ const mag = BigInt(magnitudes[w]);
+ const signed = signs[w] === 1 ? -mag : mag;
+ recon += signed << shifts[w];
+ // Magnitude bound check.
+ expect(magnitudes[w]).toBeGreaterThanOrEqual(0);
+ expect(magnitudes[w]).toBeLessThanOrEqual(1 << (c - 1));
+ }
+ expect(recon).toBe(s);
+ }
+ });
+ }
+});
+
+describe("batchAffineMSMMultiWindow vs batchAffineMSM (windowsPerBatch=1)", () => {
+ // c=16 is the production chunk_size, but the unsigned-digit batchAffineMSM
+ // baseline then allocates 2^c = 65536 buckets per window, which works but is
+ // slow in pure bigint. These cases use c=8 so each one stays under a second
+ // while still comparing the signed-Booth path against the unsigned baseline.
+ const C = 8;
+ const NS = [16, 64, 256];
+
+ for (const n of NS) {
+ it(`n=${n}, c=${C}: signed-Booth (wpb=1) == unsigned baseline`, () => {
+ const rng = makeRng(0xdead + n);
+ const points = buildPoints(n, rng);
+ const scalars: bigint[] = new Array(n);
+ for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng);
+
+ const expected = batchAffineMSM(points, scalars, C);
+ const actual = batchAffineMSMMultiWindow(points, scalars, C, 1);
+ expect(pointEq(actual, expected)).toBe(true);
+
+ // Defence-in-depth: also cross-check against the brute-force MSM
+ // in bn254.ts. The two helpers above could share a mistake (the
+ // signed-Booth recoder + bucket math is bespoke), so a third,
+ // independent path is used to pin down the expected point.
+ const truth = msmBn254(points, scalars);
+ expect(pointEq(actual, truth)).toBe(true);
+ });
+ }
+
+ it("5 random seeds for n=64 all agree", () => {
+ for (let seed = 1; seed <= 5; seed++) {
+ const rng = makeRng(0xa11ce + seed); // distinct deterministic stream per seed
+ const points = buildPoints(64, rng);
+ const scalars: bigint[] = new Array(64);
+ for (let i = 0; i < 64; i++) scalars[i] = randomScalar(rng);
+
+ const expected = batchAffineMSM(points, scalars, C);
+ const actual = batchAffineMSMMultiWindow(points, scalars, C, 1);
+ expect(pointEq(actual, expected)).toBe(true);
+ }
+ });
+});
+
+describe("batchAffineMSMMultiWindow: windowsPerBatch invariance", () => {
+ const C = 8; // window width in bits used by every case in this suite
+ const NS = [16, 64, 256]; // input sizes; each gets a wpb=2 and a wpb=4 case
+
+ for (const n of NS) {
+ it(`n=${n}, c=${C}: wpb=2 == wpb=1`, () => {
+ const rng = makeRng(0xb0b + n);
+ const points = buildPoints(n, rng);
+ const scalars: bigint[] = new Array(n);
+ for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng);
+
+ const ref1 = batchAffineMSMMultiWindow(points, scalars, C, 1); // reference: one window per batch
+ const ref2 = batchAffineMSMMultiWindow(points, scalars, C, 2); // same inputs, two windows pooled per batch
+ expect(pointEq(ref2, ref1)).toBe(true);
+ });
+
+ it(`n=${n}, c=${C}: wpb=4 == wpb=1`, () => {
+ const rng = makeRng(0xfee1 + n);
+ const points = buildPoints(n, rng);
+ const scalars: bigint[] = new Array(n);
+ for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng);
+
+ const ref1 = batchAffineMSMMultiWindow(points, scalars, C, 1); // reference: one window per batch
+ const ref4 = batchAffineMSMMultiWindow(points, scalars, C, 4); // same inputs, four windows pooled per batch
+ expect(pointEq(ref4, ref1)).toBe(true);
+ });
+ }
+});
+
+describe("batchAffineMSMMultiWindow: edge cases", () => {
+ const C = 8;
+ const N = 32;
+
+ it("all-zero scalars → identity", () => {
+ const rng = makeRng(0xc0de);
+ const points = buildPoints(N, rng);
+ const scalars: bigint[] = new Array(N).fill(0n);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, C, wpb);
+ expect(r.infinity).toBe(true); // the identity is reported through the `infinity` flag
+ }
+ });
+
+ it("single non-zero scalar agrees with scalar mult", () => {
+ const rng = makeRng(0x5eed);
+ const points = buildPoints(N, rng);
+ const idx = 7;
+ const k = 123456789n;
+ const scalars: bigint[] = new Array(N).fill(0n);
+ scalars[idx] = k; // every other scalar stays 0, so the MSM must equal k * points[idx]
+ const expected = scalarMultBn254Point(points[idx], k);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, C, wpb);
+ expect(pointEq(r, expected)).toBe(true);
+ }
+ });
+
+ it("scalar = p - 1 (maximum valid scalar)", () => {
+ const rng = makeRng(0xbeef);
+ const points = buildPoints(N, rng);
+ const scalars: bigint[] = new Array(N);
+ for (let i = 0; i < N; i++) scalars[i] = BN254_SCALAR_FIELD - 1n;
+ const truth = msmBn254(points, scalars);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, C, wpb);
+ expect(pointEq(r, truth)).toBe(true);
+ }
+ });
+
+ it("scalars at c-bit boundary values (2^k - 1, 2^k, 2^k + 1)", () => {
+ const rng = makeRng(0xface);
+ const points = buildPoints(8, rng);
+ // 8 scalars straddling 2^(c-1) and the window boundaries 2^c and 2^(2c):
+ // 2^(c-1) - 1, 2^(c-1), 2^(c-1) + 1, 2^c - 1, 2^c, 2^c + 1, 2^(2c) - 1, 2^(2c).
+ const cBI = BigInt(C);
+ const scalars: bigint[] = [
+ (1n << (cBI - 1n)) - 1n,
+ 1n << (cBI - 1n),
+ (1n << (cBI - 1n)) + 1n,
+ (1n << cBI) - 1n,
+ 1n << cBI,
+ (1n << cBI) + 1n,
+ (1n << (cBI * 2n)) - 1n,
+ 1n << (cBI * 2n),
+ ];
+ const truth = msmBn254(points, scalars);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, C, wpb);
+ expect(pointEq(r, truth)).toBe(true);
+ }
+ });
+
+ it("mixed (zero + nonzero + p-1) and wpb=1/2/4 agree", () => {
+ const rng = makeRng(0xface5);
+ const N2 = 16;
+ const points = buildPoints(N2, rng);
+ const scalars: bigint[] = new Array(N2);
+ for (let i = 0; i < N2; i++) {
+ if (i % 4 === 0) scalars[i] = 0n;
+ else if (i % 4 === 1) scalars[i] = BN254_SCALAR_FIELD - 1n;
+ else scalars[i] = randomScalar(rng); // remaining two residues mod 4 get random scalars
+ }
+ const truth = msmBn254(points, scalars);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, C, wpb);
+ expect(pointEq(r, truth)).toBe(true);
+ }
+ });
+});
+
+describe("batchAffineMSMMultiWindow: c=16 (production chunk size)", () => {
+ // Smoke test at c=16 to confirm the production chunk_size works.
+ // n is limited to 16 so that, with 2^15 + 1 buckets per window, the
+ // pure-bigint cost of each batch stays under a second.
+ it("c=16, n=16, wpb=1/2/4 agree with msmBn254 truth", () => {
+ const rng = makeRng(0x16);
+ const points = buildPoints(16, rng);
+ const scalars: bigint[] = new Array(16);
+ for (let i = 0; i < 16; i++) scalars[i] = randomScalar(rng);
+ const truth = msmBn254(points, scalars);
+ for (const wpb of [1, 2, 4]) {
+ const r = batchAffineMSMMultiWindow(points, scalars, 16, wpb);
+ expect(pointEq(r, truth)).toBe(true);
+ }
+ });
+});
+
+// Silence unused-import warnings: bn254ScalarField, doubleBn254Point and
+// BN254_ZERO are imported ahead of future tests that exercise the
+// pooled-inverse boundary cases; `void x` marks each as intentionally unused.
+void bn254ScalarField;
+void doubleBn254Point;
+void BN254_ZERO;
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts
index 0168270f98c1..3611ee30e6a9 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts
@@ -32,9 +32,23 @@ import {
doubleBn254Point,
isBn254Zero,
modInverse,
+ negateBn254Point,
type Bn254Point,
} from "./bn254.js";
+// 32-bit schedule-entry encoding (mirrors WASM
+// scalar_multiplication.cpp lines 552-567). Stage 4 stores only the
+// point sign + scalar index; bucket magnitude is recovered later from
+// bucket_start ranges because the schedule is bucket-contiguous.
+// bit 31: sign bit (1 = subtract the point)
+// bit 30: dedup redirect (always zero in this path)
+// bit 29: dedup skip (always zero in this path)
+// bits 0..28: scalar_idx (29-bit payload, values 0 .. 2^29 - 1)
+// `>>> 0` keeps the sign-bit mask an unsigned 32-bit value; JS's
+// `1 << 31` on its own evaluates to the negative number -2147483648.
+export const SCHEDULE_SIGN_BIT = (1 << 31) >>> 0; // 0x8000_0000
+export const SCHEDULE_INDEX_MASK = (1 << 29) - 1; // 0x1FFF_FFFF
+
/**
* Montgomery's batch-inverse trick.
*
@@ -282,3 +296,470 @@ export const _testOnlyInvViaFermat = (a: bigint): bigint => {
// a^(p-2) mod p
return modInverse(a);
};
+
+// =====================================================================
+// Multi-window batched Pippenger (host reference for Phase 2 GPU port)
+// =====================================================================
+//
+// Mirrors the WASM `pippenger_round_parallel` multi-window structure
+// (`scalar_multiplication.cpp` lines 3780-4080 for Stages 1-4, lines
+// 4210-4550 for Stage 6, lines 4550-4610 for the per-region driver).
+// Specifically un-dedup, un-GLV path.
+//
+// Key shape change vs `batchAffineMSM` above:
+// - Signed-digit (Booth) recoding per window: each c-bit digit is in
+// [-2^(c-1), 2^(c-1)], halving the bucket count to 2^(c-1) + 1 and
+// pairing positive/negative magnitudes via the schedule's sign bit.
+// - A "batch" of `windowsPerBatch` consecutive windows shares a
+// single batched inversion in the bucket-accumulation phase: pairs
+// pending from EVERY window in the batch are pooled into one
+// batchInverse call (WASM Stage 6b
+// `recursive_affine_bucket_reduce_strided`). With
+// windowsPerBatch = 4 this amortises the inversion ~4× harder than
+// the single-window path.
+
+// Encode one 32-bit schedule entry: `sign` goes to bit 31, `scalarIdx` to bits 0..28, and bits 29/30 stay 0. Throws if `scalarIdx` is outside [0, SCHEDULE_INDEX_MASK].
+export const encodeScheduleEntry = (
+ scalarIdx: number,
+ sign: 0 | 1,
+): number => {
+ if (scalarIdx < 0 || scalarIdx > SCHEDULE_INDEX_MASK) {
+ throw new Error(
+ `encodeScheduleEntry: scalarIdx ${scalarIdx} out of 29-bit range`,
+ );
+ }
+ // `sign << 31` is negative as a signed 32-bit value when sign = 1; the
+ // trailing `>>> 0` converts the OR-ed result back to a non-negative number.
+ return ((sign << 31) | scalarIdx) >>> 0;
+};
+
+// Decode a 32-bit schedule entry into { scalarIdx, sign }. Throws if either
+// dedup bit (29 or 30) is set, since both must be zero in this path.
+export const decodeScheduleEntry = (
+ entry: number,
+): { scalarIdx: number; sign: 0 | 1 } => {
+ const u = entry >>> 0; // reinterpret as unsigned 32-bit before testing bits
+ // Bits 29 and 30 must always be zero in the non-dedup path. Strict
+ // check here so consumers can rely on the invariant rather than
+ // silently masking them off.
+ if ((u & ((1 << 30) | (1 << 29))) !== 0) {
+ throw new Error(
+ `decodeScheduleEntry: dedup bits (29|30) set in entry 0x${u.toString(16)}`,
+ );
+ }
+ return {
+ scalarIdx: u & SCHEDULE_INDEX_MASK,
+ sign: (u >>> 31) as 0 | 1,
+ };
+};
+
+/**
+ * Constantine signed-Booth recoding of one scalar into `numWindows`
+ * c-bit signed digits. Throws if `c` is outside [1, 28].
+ *
+ * Returns digit magnitudes (always non-negative, in [0, 2^(c-1)]) and
+ * signs (0 = +, 1 = -) per window. The reconstruction identity:
+ *
+ *   scalar = sum_w (sign[w] ? -magnitude[w] : magnitude[w]) * 2^(c*w)
+ * holds when 0 <= scalar < 2^(c*numWindows - 1), i.e. the top window's sign bit is clear.
+ * Mirrors `get_constantine_packed_digit`
+ * (`scalar_multiplication.cpp` lines 217-271). Carry between windows
+ * is implicit via the c+1-bit read that overlaps the previous window's
+ * high bit — no explicit propagation needed.
+ *
+ * The bottom window reads bits [0, c) with a synthetic 0 as the
+ * lookback bit. Non-bottom windows read [c*w - 1, c*(w+1)), with the
+ * shared boundary bit at c*w - 1 acting as the lookback.
+ */
+const recodeScalarBooth = (
+  scalar: bigint,
+  c: number,
+  numWindows: number,
+): { magnitudes: number[]; signs: (0 | 1)[] } => {
+  if (c < 1 || c > 28) {
+    // c+1 must fit in the 32-bit `raw_wide` intermediate, and digit
+    // magnitudes (up to 2^(c-1)) must stay within plain 32-bit integer
+    // range for `Number(...)` and `1 << (c - 1)` bucket sizing. The chunk
+    // sizes we care about are c = 15, 16; the rest is test margin.
+    throw new Error(`recodeScalarBooth: c=${c} out of supported range`);
+  }
+  const cBI = BigInt(c);
+  const magMask = (1n << cBI) - 1n; // val_mask in WASM
+
+  const magnitudes: number[] = new Array(numWindows);
+  const signs: (0 | 1)[] = new Array(numWindows);
+
+  for (let w = 0; w < numWindows; w++) {
+    let raw: bigint;
+    if (w === 0) {
+      // Bottom window: synthetic-zero lookback at bit -1.
+      // raw == (scalar[0:c]) << 1
+      raw = (scalar & magMask) << 1n;
+    } else {
+      // raw == c+1 bits at [c*w - 1, c*w + c).
+      const shift = BigInt(w * c - 1);
+      raw = (scalar >> shift) & ((1n << (cBI + 1n)) - 1n);
+    }
+    const neg = Number((raw >> cBI) & 1n) as 0 | 1;
+    // encode = (raw + 1) >> 1 — the inner "(encode - neg) ^ neg_mask"
+    // dance is the WASM's branchless negate-if-sign idiom. Below is
+    // the equivalent: when neg=0, mag = encode; when neg=1, mag =
+    // (-encode) & ((1 << c) - 1).
+    const encode = (raw + 1n) >> 1n;
+    let magBI: bigint;
+    if (neg === 0) {
+      magBI = encode & magMask;
+    } else {
+      // Two's-complement negate, masked to c bits.
+      // (-encode) mod 2^c == (2^c - encode) mod 2^c.
+      // BigInt's bit-AND with a positive mask returns a non-negative
+      // result on negative inputs (effectively reduces mod 2^c), so
+      // `(-encode) & magMask` is exactly what we want.
+      magBI = (-encode) & magMask;
+    }
+    magnitudes[w] = Number(magBI);
+    signs[w] = neg;
+  }
+
+  return { magnitudes, signs };
+};
+
+/**
+ * Cross-window pooled batch-affine bucket reducer.
+ *
+ * Each window has its own `buckets[w][b]` stack of points. Per round,
+ * we pop ONE pair from EVERY (window, bucket) whose stack has ≥ 2
+ * points, into a single global pool, and run one batched inversion
+ * across the entire pool. Then we push each sum back to its
+ * originating (window, bucket). Loop until all stacks ≤ 1.
+ *
+ * This is the host analogue of WASM Stage 6b's
+ * `recursive_affine_bucket_reduce_strided` (lines 1276-1500): the
+ * inversion amortises across `windows_in_batch * num_buckets` pairs
+ * rather than just `num_buckets`. With windowsPerBatch=4 and ~2^(c-1)
+ * non-empty buckets per window, this is a ~4× larger inversion batch.
+ *
+ * Mutates the input bucket stacks (each ends with at most one point) and returns, per window, one point per bucket (BN254_ZERO where the stack is empty).
+ */
+const pooledBatchSumBuckets = (
+ bucketsPerWindow: Bn254Point[][][],
+): Bn254Point[][] => {
+ const numWindows = bucketsPerWindow.length;
+ const numBuckets = numWindows > 0 ? bucketsPerWindow[0].length : 0; // bucket count is taken from window 0 and applied to every window
+
+ while (true) {
+ // Collect one pair per (window, bucket) stack with ≥ 2 entries.
+ const ps: Bn254Point[] = [];
+ const qs: Bn254Point[] = [];
+ const targetW: number[] = [];
+ const targetB: number[] = [];
+
+ for (let w = 0; w < numWindows; w++) {
+ const buckets = bucketsPerWindow[w];
+ for (let b = 0; b < numBuckets; b++) {
+ if (buckets[b].length >= 2) {
+ const q = buckets[b].pop()!;
+ const p = buckets[b].pop()!;
+ ps.push(p);
+ qs.push(q);
+ targetW.push(w);
+ targetB.push(b);
+ }
+ }
+ }
+
+ if (ps.length === 0) break;
+
+ // One batchAffineAdd call covers the pairs from every window in the
+ // batch — this is the structural win over per-window reduction.
+ const sums = batchAffineAdd(ps, qs);
+ for (let i = 0; i < sums.length; i++) {
+ bucketsPerWindow[targetW[i]][targetB[i]].push(sums[i]);
+ }
+ }
+
+ // Each (window, bucket) stack now has 0 or 1 point.
+ const out: Bn254Point[][] = new Array(numWindows);
+ for (let w = 0; w < numWindows; w++) {
+ const buckets = bucketsPerWindow[w];
+ const row: Bn254Point[] = new Array(numBuckets);
+ for (let b = 0; b < numBuckets; b++) {
+ row[b] = buckets[b].length === 1 ? buckets[b][0] : BN254_ZERO;
+ }
+ out[w] = row;
+ }
+ return out;
+};
+
+/**
+ * Build per-window 32-bit schedule for the given window range, bucket-
+ * sorted (Stages 1+2+3+4 of the WASM compressed into one pass).
+ *
+ * Returns (with `w` indexing windows within the batch, 0 .. windowsInBatch - 1):
+ * - `schedule[w]`: array of 32-bit entries laid out in bucket order
+ * (all entries for bucket 1 first, then bucket 2, ...).
+ * - `bucketStart[w]`: exclusive prefix of per-bucket counts;
+ * `bucketStart[w][b+1] - bucketStart[w][b]` is the # of entries
+ * pinned to bucket b in window w. `bucketStart[w][0]` and
+ * `bucketStart[w][1]` are both 0 because bucket 0 is dropped
+ * (digit zero contributes nothing).
+ */
+const buildScheduleForBatch = (
+ scalars: bigint[],
+ c: number,
+ windowStart: number,
+ windowsInBatch: number,
+ numBuckets: number,
+ numWindowsTotal: number,
+): { schedule: number[][]; bucketStart: number[][] } => {
+ const n = scalars.length;
+ const schedule: number[][] = new Array(windowsInBatch);
+ const bucketStart: number[][] = new Array(windowsInBatch);
+
+ // Stage 1: histogram per (window, bucket), filled in one pass over the scalars.
+ const counts: number[][] = new Array(windowsInBatch);
+ // Per-scalar per-window recoded digit and sign, cached so Stage 4
+ // doesn't recompute. Matches the WASM's `fill_packed_digit_buffer`
+ // call pattern (one shared buffer reused across Stages 1 and 4).
+ const recoded: { mag: number; sign: 0 | 1 }[][] = new Array(
+ windowsInBatch,
+ );
+
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ counts[wInBatch] = new Array(numBuckets).fill(0);
+ recoded[wInBatch] = new Array(n);
+ }
+
+ for (let i = 0; i < n; i++) {
+ const all = recodeScalarBooth(scalars[i], c, numWindowsTotal); // recodes every window; only this batch's slice is kept
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ const wAbs = windowStart + wInBatch;
+ const mag = all.magnitudes[wAbs];
+ const sign = all.signs[wAbs];
+ recoded[wInBatch][i] = { mag, sign };
+ if (mag !== 0) {
+ counts[wInBatch][mag]++;
+ }
+ }
+ }
+
+ // Stage 2+3: per-window exclusive prefix-sum of counts → bucketStart.
+ // bucketStart[w][b] = sum_{1 <= b' < b} counts[w][b'], with [0] = [1] = 0
+ // because bucket 0 is dropped (digit==0 contributes nothing).
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ const bs: number[] = new Array(numBuckets + 1);
+ bs[0] = 0;
+ bs[1] = 0;
+ let running = 0;
+ for (let b = 1; b < numBuckets; b++) {
+ bs[b] = running;
+ running += counts[wInBatch][b];
+ }
+ bs[numBuckets] = running; // sentinel: total entry count, closes the last bucket's range
+ bucketStart[wInBatch] = bs;
+ }
+
+ // Stage 4: scatter into bucket-sorted schedule. Allocate per-window
+ // schedule arrays sized to the total entry count.
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ schedule[wInBatch] = new Array(bucketStart[wInBatch][numBuckets]);
+ }
+ const cursor: number[][] = new Array(windowsInBatch); // per (window, bucket) count of entries already placed
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ cursor[wInBatch] = new Array(numBuckets).fill(0);
+ }
+
+ for (let i = 0; i < n; i++) {
+ if (i > SCHEDULE_INDEX_MASK) {
+ throw new Error(
+ `buildScheduleForBatch: scalar index ${i} exceeds 29-bit schedule payload`,
+ );
+ }
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ const { mag, sign } = recoded[wInBatch][i];
+ if (mag === 0) continue;
+ const slot = bucketStart[wInBatch][mag] + cursor[wInBatch][mag]++;
+ schedule[wInBatch][slot] = encodeScheduleEntry(i, sign);
+ }
+ }
+
+ return { schedule, bucketStart };
+};
+
+/**
+ * Consume a per-(window, bucket) sorted 32-bit schedule and accumulate
+ * each window's MSM contribution via cross-window pooled batch-affine.
+ *
+ * Returns one Bn254Point per window in the batch (not yet scaled by 2^(c*w)).
+ */
+const reduceBatchSchedule = (
+ points: Bn254Point[],
+ schedule: number[][],
+ bucketStart: number[][],
+ numBuckets: number,
+ c: number,
+): Bn254Point[] => {
+ const windowsInBatch = schedule.length;
+
+ // Build per-window per-bucket point stacks by decoding schedule
+ // entries. Sign bit ⇒ negate the loaded point before pushing.
+ const bucketsPerWindow: Bn254Point[][][] = new Array(windowsInBatch);
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ const buckets: Bn254Point[][] = new Array(numBuckets);
+ for (let b = 0; b < numBuckets; b++) buckets[b] = [];
+ const sched = schedule[wInBatch];
+ const bs = bucketStart[wInBatch];
+ // Walk each bucket's contiguous run [bs[b], bs[b+1]) of entries; bucket 0 is never populated.
+ for (let b = 1; b < numBuckets; b++) {
+ const lo = bs[b];
+ const hi = bs[b + 1];
+ for (let k = lo; k < hi; k++) {
+ const { scalarIdx, sign } = decodeScheduleEntry(sched[k]);
+ const p = sign === 1
+ ? negateBn254Point(points[scalarIdx])
+ : points[scalarIdx];
+ buckets[b].push(p);
+ }
+ }
+ bucketsPerWindow[wInBatch] = buckets;
+ }
+
+ // The pooled reducer is the structural win: each round's single
+ // batched add spans pairs from EVERY window in the batch.
+ const bucketSumsPerWindow = pooledBatchSumBuckets(bucketsPerWindow);
+
+ // Per-window running-sum reduction
+ // S_w = sum_{b=1}^{B-1} b * bucketSums[w][b]
+ // is identical to the per-window path. Bucket reduction is sequential
+ // by nature (data-dependency through `running`); the cross-window win
+ // happens in the bucket-accumulation phase above, not here.
+ const out: Bn254Point[] = new Array(windowsInBatch);
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ const sums = bucketSumsPerWindow[wInBatch];
+ let running: Bn254Point = BN254_ZERO;
+ let result: Bn254Point = BN254_ZERO;
+ for (let b = numBuckets - 1; b >= 1; b--) {
+ running = addBn254Points(running, sums[b]);
+ result = addBn254Points(result, running);
+ }
+ out[wInBatch] = result;
+ }
+ // `c` is not used here: the 2^(c*w) weighting is applied by the Horner
+ // combine at the call site. `void c` presumably just quiets unused-parameter lint.
+ void c;
+ return out;
+};
+
+/**
+ * Multi-window batched Pippenger MSM with cross-window pooled
+ * batched-affine bucket reduction.
+ *
+ * This is the host ground-truth for the WebGPU multi-window Pippenger
+ * rewrite (Phase 2 of the plan). It implements the WASM structure with
+ * the un-GLV, un-dedup path:
+ *
+ * - Signed-Booth recoding splits each scalar into ceil((254 + 2)/c)
+ * c-bit signed digits (the +2 is headroom for the recoding carry).
+ * - Per region (single c here), iterate windows in batches of
+ * `windowsPerBatch`. Per batch:
+ * 1. Stages 1+2+3+4 — histogram → prefix-sum → bucket-sorted
+ * 32-bit schedule (`buildScheduleForBatch`).
+ * 2. Stage 6 — pooled batch-affine bucket accumulation across
+ * all windows in the batch (`reduceBatchSchedule`). ONE
+ * inversion per round pools pairs from every window.
+ * 3. Stage 7 — per-window bucket running-sum reduction.
+ * - Final Horner combine (`c` doublings per window) folds the windows
+ * into one point.
+ *
+ * Cross-checks against `batchAffineMSM`:
+ * - windowsPerBatch=1 with this signed-Booth path equals the
+ * unsigned-digit `batchAffineMSM`.
+ * - windowsPerBatch=2 and =4 equal windowsPerBatch=1 (the cross-
+ * window batched inversion is associativity-only — output is
+ * identical, only the inversion amortisation factor changes).
+ *
+ * Throws if `points` and `scalars` differ in length or `windowsPerBatch` < 1.
+ * Used by Jest tests as ground truth for the GPU shader pipeline; not tuned for speed.
+ */
+export const batchAffineMSMMultiWindow = (
+ points: Bn254Point[],
+ scalars: bigint[],
+ c: number,
+ windowsPerBatch: number,
+): Bn254Point => {
+ if (points.length !== scalars.length) {
+ throw new Error(
+ "batchAffineMSMMultiWindow: points and scalars length mismatch",
+ );
+ }
+ if (windowsPerBatch < 1) {
+ throw new Error(
+ `batchAffineMSMMultiWindow: windowsPerBatch=${windowsPerBatch} must be ≥ 1`,
+ );
+ }
+
+ // BN254 scalar field is 254 bits wide.
+ const scalarBits = 254;
+ // `+ 2` headroom for the signed-Booth carry — matches WASM's
+ // `(NUM_BITS + 2 + window_bits - 1) / window_bits`
+ // (`scalar_multiplication.cpp` lines 514, 2484). Without this, scalars
+ // whose MSB lands on the top window's sign-bit position (e.g. p-1
+ // with small c) would generate a phantom carry past the topmost
+ // window and the reconstruction would drop the high contribution.
+ const numWindows = Math.ceil((scalarBits + 2) / c);
+ // Signed-digit recoding halves the bucket count: digits live in
+ // [-2^(c-1), 2^(c-1)], with magnitude 0 dropped. The bucket array
+ // is dimensioned 2^(c-1) + 1 so bucket index can equal 2^(c-1).
+ const numBuckets = (1 << (c - 1)) + 1;
+
+ const windowSums: Bn254Point[] = new Array(numWindows);
+
+ for (
+ let batchStart = 0;
+ batchStart < numWindows;
+ batchStart += windowsPerBatch
+ ) {
+ const windowsInBatch = Math.min(
+ windowsPerBatch,
+ numWindows - batchStart,
+ ); // the final batch may hold fewer than windowsPerBatch windows
+
+ const { schedule, bucketStart } = buildScheduleForBatch(
+ scalars,
+ c,
+ batchStart,
+ windowsInBatch,
+ numBuckets,
+ numWindows,
+ );
+
+ const batchWindowSums = reduceBatchSchedule(
+ points,
+ schedule,
+ bucketStart,
+ numBuckets,
+ c,
+ );
+
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) {
+ windowSums[batchStart + wInBatch] = batchWindowSums[wInBatch];
+ }
+ }
+
+ // Horner combine across windows: msm = Σ_w 2^(c*w) * windowSums[w].
+ let acc: Bn254Point = BN254_ZERO;
+ for (let w = numWindows - 1; w >= 0; w--) {
+ for (let bit = 0; bit < c; bit++) {
+ acc = doubleBn254Point(acc);
+ }
+ acc = addBn254Points(acc, windowSums[w]);
+ }
+ return acc;
+};
+
+// Test-only export: exposes the signed-Booth recoder and the schedule
+// builder so tests can exercise them in isolation.
+export const _testOnly = {
+ recodeScalarBooth,
+ buildScheduleForBatch,
+};
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts
new file mode 100644
index 000000000000..a8c8e5fabf83
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts
@@ -0,0 +1,818 @@
+// Jest unit tests for the Wasm9x29 Bernstein-Yang TS port (sub-step 1.1 of
+// the WebGPU MSM rewrite plan). Asserts:
+// 1. invert(a, p) matches modInverse(a, p) on 1000 seeded-LCG random
+// BN254 base-field inputs.
+// 2. Edge cases: a == 1, a == p-1, a == 2, a == (p-1)/2.
+// 3. (a * invert(a)) mod p === 1 for every successful invert.
+// 4. The internal step counter never exceeds the static bound implied
+// by the WASM source (NUM_OUTER * (BATCH + 2 * RTC_MAX_ITERS) +
+// 2 * RTC_MAX_ITERS for the final reduce). This enforces the
+// "strictly bounded loops" constraint from the plan.
+
+import {
+ invert,
+ Wasm9x29,
+ P_BIGINT,
+ BATCH_NUM,
+ NUM_OUTER,
+ RTC_MAX_ITERS,
+} from "./bernstein_yang.js";
+import { modInverse } from "./bn254.js";
+
+// Seeded LCG (Numerical Recipes constants) mirrored from
+// dev/msm-webgpu/bench-field-mul.ts. Reproducible across runs; a seed of 0 is replaced by 1, and each call returns the next unsigned 32-bit state.
+function makeRng(seed: number): () => number {
+ let state = (seed >>> 0) || 1;
+ return () => {
+ state = (Math.imul(state, 1664525) + 1013904223) >>> 0; // Math.imul gives the 32-bit wrapping product; `>>> 0` keeps the state unsigned
+ return state;
+ };
+}
+
+function randomBelow(p: bigint, rng: () => number): bigint {
+ const bitlen = p.toString(2).length;
+ const byteLen = Math.ceil(bitlen / 8);
+ // Bounded: max retries cap to keep the test loop hard-bounded.
+ // For the BN254 base field, p ≈ 0.756 * 2^254, so each draw is rejected with probability ~24%;
+ // 64 consecutive rejections then have probability ~0.244^64 ≈ 10^-39.
+ const MAX_TRIES = 64;
+ for (let attempt = 0; attempt < MAX_TRIES; attempt++) {
+ let v = 0n;
+ for (let i = 0; i < byteLen; i++) {
+ v = (v << 8n) | BigInt(rng() & 0xff); // only the low byte of each rng() output is used
+ }
+ v &= (1n << BigInt(bitlen)) - 1n; // trim the candidate to bitlen bits before comparing with p
+ if (v < p) return v;
+ }
+ throw new Error("randomBelow: rejection sampling exceeded MAX_TRIES");
+}
+
+// Hard upper bound on the bounded-loop step counter incremented by
+// divsteps inner steps + reduce_to_canonical inner steps. Tighter than
+// the absolute max so a regression that loosens the loop bound trips a
+// test failure. Components:
+// - NUM_OUTER * BATCH = 13 * 58 = 754 divstep inner steps.
+// - In-loop reduces: ⌊NUM_OUTER / REDUCE_INTERVAL⌋ = 3 boundary hits (REDUCE_INTERVAL appears as the literal 4 below),
+// each calling reduce_to_canonical twice (d, e) of ≤ RTC_MAX_ITERS = 36.
+// → 3 * 2 * 36 = 216.
+// - Final reduce_to_canonical(d), plus optional reduce after neg(d) on
+// f-negative branch. → 2 * 36 = 72.
+// Total upper bound: 754 + 216 + 72 = 1042.
+const STEP_BOUND =
+ NUM_OUTER * BATCH_NUM +
+ 2 * Math.floor(NUM_OUTER / 4) * RTC_MAX_ITERS +
+ 2 * RTC_MAX_ITERS;
+
+describe("Bernstein-Yang Wasm9x29 TS port", () => {
+ const p = P_BIGINT;
+
+ it("invert(1) == 1", () => {
+ const stepRef = { steps: 0 }; // passed to invert() so the test can read back the step count
+ const r = invert(1n, p, Wasm9x29.P_INV, stepRef);
+ expect(r).toBe(1n);
+ expect(stepRef.steps).toBeLessThan(STEP_BOUND);
+ });
+
+ it("invert(2) * 2 mod p == 1", () => {
+ const stepRef = { steps: 0 };
+ const inv2 = invert(2n, p, Wasm9x29.P_INV, stepRef);
+ expect((inv2 * 2n) % p).toBe(1n);
+ expect(inv2).toBe(modInverse(2n, p));
+ expect(stepRef.steps).toBeLessThan(STEP_BOUND);
+ });
+
+ it("invert(p - 1) * (p - 1) mod p == 1", () => {
+ const a = p - 1n;
+ const stepRef = { steps: 0 };
+ const r = invert(a, p, Wasm9x29.P_INV, stepRef);
+ expect((a * r) % p).toBe(1n);
+ expect(r).toBe(modInverse(a, p));
+ // p - 1 is its own inverse mod p (since (p-1)^2 = p^2 - 2p + 1 ≡ 1 mod p).
+ expect(r).toBe(p - 1n);
+ expect(stepRef.steps).toBeLessThan(STEP_BOUND);
+ });
+
+ it("invert((p - 1) / 2) * ((p - 1) / 2) mod p == 1", () => {
+ const a = (p - 1n) / 2n;
+ const stepRef = { steps: 0 };
+ const r = invert(a, p, Wasm9x29.P_INV, stepRef);
+ expect((a * r) % p).toBe(1n);
+ expect(r).toBe(modInverse(a, p));
+ expect(stepRef.steps).toBeLessThan(STEP_BOUND);
+ });
+
+ it("invert(0) == 0", () => {
+ expect(invert(0n, p)).toBe(0n);
+ });
+
+ it("matches modInverse on 1000 seeded-LCG random inputs", () => {
+ const rng = makeRng(0xdeadbeef);
+ const N = 1000;
+ let passes = 0;
+ let mismatches = 0;
+ let maxSteps = 0;
+ for (let i = 0; i < N; i++) {
+ const a = randomBelow(p, rng);
+ if (a === 0n) {
+ // invert() must return 0 for 0 (asserted here); modInverse throws, so skip the comparison.
+ expect(invert(a, p)).toBe(0n);
+ passes++;
+ continue;
+ }
+ const stepRef = { steps: 0 };
+ const byInv = invert(a, p, Wasm9x29.P_INV, stepRef);
+ const refInv = modInverse(a, p);
+ const ok = byInv === refInv && (a * byInv) % p === 1n;
+ if (!ok) {
+ mismatches++;
+ if (mismatches <= 5) {
+ // Log only the first 5 mismatches so failure output stays readable.
+ // eslint-disable-next-line no-console
+ console.error(
+ `mismatch at i=${i}: a=${a}\n by =${byInv}\n ref=${refInv}`,
+ );
+ }
+ } else {
+ passes++;
+ }
+ if (stepRef.steps > maxSteps) maxSteps = stepRef.steps;
+ }
+ expect(mismatches).toBe(0);
+ expect(passes).toBe(N);
+ expect(maxSteps).toBeLessThan(STEP_BOUND);
+ // eslint-disable-next-line no-console
+ console.log(
+ `[BY] ${passes}/${N} random inputs matched modInverse. max bounded-loop steps observed: ${maxSteps} (bound ${STEP_BOUND})`,
+ );
+ });
+
+ it("invert(a) * a mod p == 1 (1000 random, separate seed)", () => {
+ const rng = makeRng(0xa5a5a5a5);
+ const N = 1000;
+ let bad = 0;
+ for (let i = 0; i < N; i++) {
+ const a = randomBelow(p, rng);
+ if (a === 0n) continue;
+ const r = invert(a, p);
+ if ((a * r) % p !== 1n) bad++;
+ }
+ expect(bad).toBe(0);
+ });
+
+ // --- Inner-piece sanity tests: lock down individual primitives so the
+ // --- WGSL port can diff against the same intermediate values.
+
+ it("fromBigint/toBigint round-trip on random values", () => {
+ const rng = makeRng(0x1234);
+ for (let i = 0; i < 100; i++) {
+ const x = randomBelow(p, rng);
+ const back = Wasm9x29.toBigint(Wasm9x29.fromBigint(x));
+ expect(back).toBe(x);
+ }
+ });
+
+ it("divsteps: gcd(p, a) ends with g == 0 after NUM_OUTER iters", () => {
+ // Spot-check: drive the outer loop by hand and assert that g reaches
+ // zero (the early-break condition) before NUM_OUTER iterations are used up.
+ const rng = makeRng(7);
+ let everEarlyExit = false;
+ for (let i = 0; i < 50; i++) {
+ const a = randomBelow(p, rng);
+ if (a === 0n) continue;
+ // Re-run the driver loop; `outer` is the iteration index at which g became 0.
+ const P = Wasm9x29.fromBigint(p);
+ const f = Wasm9x29.fromBigint(p);
+ const g = Wasm9x29.fromBigint(a);
+ const d = Wasm9x29.makeZero();
+ const e = Wasm9x29.makeOne();
+ let delta = 1n;
+ let outer = 0;
+ for (; outer < NUM_OUTER; outer++) {
+ const { mat, delta: nd } = Wasm9x29.divsteps(
+ delta,
+ Wasm9x29.low64(f),
+ Wasm9x29.low64(g),
+ );
+ delta = nd;
+ Wasm9x29.applyMatrix(mat, f, g, d, e, P, Wasm9x29.P_INV);
+ if (Wasm9x29.isZero(g)) break;
+ }
+ expect(outer).toBeLessThan(NUM_OUTER); // outer === NUM_OUTER would mean the loop ran out without g hitting 0
+ if (outer < NUM_OUTER - 1) everEarlyExit = true;
+ }
+ expect(everEarlyExit).toBe(true);
+ });
+});
+
+// ============================================================
+// WGSL helper TS mirror tests (sub-step 1.2)
+// ============================================================
+//
+// The TS helpers below mirror, byte-for-byte, the algorithms in
+// wgsl/bigint/bigint_by.template.wgsl. WGSL operates on 32-bit u32/i32
+// types that wrap on overflow; TS `number` plus `bigint` masking lets us
+// emulate the same bit semantics. These tests validate the algorithms
+// independent of any GPU dispatch — sub-step 1.3 will run the WGSL
+// helpers end-to-end in a divsteps shader.
+//
+// Each mirror function takes the same input shape as its WGSL counterpart
+// (typically a two-lane WGSL `vec2` of 32-bit words, represented as a
+// 2-element JS tuple) and returns the same shape. Ground-truth checks use BigInt.
+
+const U32_MASK = 0xFFFFFFFFn; // low 32 bits of a bigint
+const I32_MIN = -(2 ** 31); // smallest signed 32-bit value
+const I32_MAX = 2 ** 31 - 1; // largest signed 32-bit value
+const POW_29 = 1 << 29; // 2^29 as a number
+const POW_29N = 1n << 29n; // 2^29 as a bigint
+const POW_32N = 1n << 32n; // 2^32
+const POW_64N = 1n << 64n; // 2^64
+const I64_SIGN = 1n << 63n; // bit 63, the sign position of a 64-bit two's-complement value
+const BY_LIMB_MASK_N = (1n << 29n) - 1n; // low 29 bits; presumably one 29-bit limb of the 9x29 layout — confirm
+
+// ---- 64-bit BigInt helpers used for ground-truth comparison ----
+function pairToU64(p: [number, number]): bigint { // [lo, hi] 32-bit halves -> unsigned 64-bit bigint
+ const lo = BigInt(p[0] >>> 0); // `>>> 0` reinterprets each half as unsigned
+ const hi = BigInt(p[1] >>> 0);
+ return lo | (hi << 32n);
+}
+function u64ToPair(x: bigint): [number, number] { // low 64 bits of x -> [lo, hi] unsigned 32-bit halves
+ const v = x & ((1n << 64n) - 1n); // reduce mod 2^64 (negative inputs become their two's-complement bit pattern)
+ return [Number(v & U32_MASK) >>> 0, Number((v >> 32n) & U32_MASK) >>> 0];
+}
+function i64PairToBigint(p: [number, number]): bigint { // [lo, hi] pair -> signed bigint value
+ // The low half contributes as unsigned 32 bits; the high half carries the sign.
+ const lo = BigInt(p[0] >>> 0);
+ const hi = BigInt(p[1] | 0); // `| 0` coerces to a signed 32-bit integer
+ return lo + (hi << 32n);
+}
+function bigintToI64Pair(x: bigint): [number, number] { // bigint -> [lo unsigned, hi signed] from its low 64 bits
+ const v = x & ((1n << 64n) - 1n);
+ const lo = Number(v & U32_MASK) >>> 0;
+ // hi is the upper 32 bits reinterpreted as i32: values >= 2^31 map to hiU - 2^32 (negative).
+ const hiU = Number((v >> 32n) & U32_MASK);
+ const hi = hiU >= 2 ** 31 ? hiU - 2 ** 32 : hiU;
+ return [lo, hi];
+}
+
+// ---- TS mirrors of the WGSL helpers ----
+function u64_add_ts(a: [number, number], b: [number, number]): [number, number] { // a + b mod 2^64 on [lo, hi] pairs
+ const lo = (a[0] + b[0]) >>> 0;
+ const carry = lo < (a[0] >>> 0) ? 1 : 0; // the wrapped low word is smaller than an operand exactly when it overflowed
+ // (a[1] + b[1] + carry) is wrapped to u32 by `>>> 0`, like a WGSL u32 add.
+ const hi = (a[1] + b[1] + carry) >>> 0;
+ return [lo, hi];
+}
+function u64_sub_ts(a: [number, number], b: [number, number]): [number, number] { // a - b mod 2^64 on [lo, hi] pairs
+ const lo = (a[0] - b[0]) >>> 0;
+ const borrow = (a[0] >>> 0) < (b[0] >>> 0) ? 1 : 0; // low word underflowed, so take one from the high word
+ const hi = (a[1] - b[1] - borrow) >>> 0;
+ return [lo, hi];
+}
+/**
+ * Logical right shift by one of a 64-bit value held as a (low, high) u32
+ * pair; bit 0 of the high word moves into bit 31 of the low word.
+ */
+function u64_shr1_ts(x: [number, number]): [number, number] {
+  const [lowWord, highWord] = x;
+  const carriedIn = (highWord & 1) !== 0 ? 0x80000000 : 0;
+  return [((lowWord >>> 1) | carriedIn) >>> 0, highWord >>> 1];
+}
+/** Returns bit 0 (0 or 1) of a 64-bit value held as a (low, high) pair. */
+function u64_low_bit_ts(x: [number, number]): number {
+  const [lowWord] = x;
+  return lowWord & 1;
+}
+/**
+ * Two's-complement negation mod 2^64 of a (low, high) u32 pair, computed as
+ * bitwise complement plus one.
+ */
+function u64_neg_ts(x: [number, number]): [number, number] {
+  // XOR with all-ones is the bitwise complement of each 32-bit word.
+  const complement: [number, number] = [(x[0] ^ 0xffffffff) >>> 0, (x[1] ^ 0xffffffff) >>> 0];
+  const one: [number, number] = [1, 0];
+  return u64_add_ts(complement, one);
+}
+/** Word-wise bitwise AND of two 64-bit values held as (low, high) u32 pairs. */
+function u64_and_ts(
+  a: [number, number],
+  mask: [number, number],
+): [number, number] {
+  const lowWord = (a[0] & mask[0]) >>> 0;
+  const highWord = (a[1] & mask[1]) >>> 0;
+  return [lowWord, highWord];
+}
+
+/**
+ * Signed 64-bit wrapping add on (u32 low, i32 high) pairs. The carry out of
+ * the low word is added into the high word, which wraps as an i32.
+ */
+function i64_add_pair_ts(
+  a: [number, number],
+  b: [number, number],
+): [number, number] {
+  // Exact in a JS number: two u32 values sum to less than 2^33.
+  const lowSum = (a[0] >>> 0) + (b[0] >>> 0);
+  const carryOut = lowSum >= 2 ** 32 ? 1 : 0;
+  return [lowSum >>> 0, (a[1] + b[1] + carryOut) | 0];
+}
+/**
+ * Signed 64-bit wrapping subtract on (u32 low, i32 high) pairs. A borrow out
+ * of the low word is taken from the high word, which wraps as an i32.
+ */
+function i64_sub_pair_ts(
+  a: [number, number],
+  b: [number, number],
+): [number, number] {
+  const aLow = a[0] >>> 0;
+  const bLow = b[0] >>> 0;
+  const borrowOut = aLow < bLow ? 1 : 0;
+  return [(aLow - bLow) >>> 0, (a[1] - b[1] - borrowOut) | 0];
+}
+/**
+ * Left shift by one of a signed 64-bit value held as (u32 low, i32 high);
+ * bit 31 of the low word moves into bit 0 of the high word.
+ */
+function i64_shl1_pair_ts(a: [number, number]): [number, number] {
+  const [lowWord, highWord] = a;
+  const shiftedLow = (lowWord << 1) >>> 0;
+  // JS `<<` and `|` already produce an i32, so no explicit sign fix-up is needed.
+  const shiftedHigh = (highWord << 1) | (lowWord >>> 31);
+  return [shiftedLow, shiftedHigh];
+}
+/**
+ * Two's-complement negation of a signed 64-bit value held as
+ * (u32 low, i32 high), wrapping mod 2^64.
+ */
+function i64_neg_pair_ts(a: [number, number]): [number, number] {
+  // ~lo + 1 equals -lo mod 2^32.
+  const negatedLow = (-a[0]) >>> 0;
+  // The +1 carries into the high word exactly when the negated low word wraps to 0.
+  const carryOut = negatedLow === 0 ? 1 : 0;
+  const negatedHigh = (~a[1] + carryOut) | 0;
+  return [negatedLow, negatedHigh];
+}
+
+// Largest |partial product| seen across all signed_mul_split_ts calls,
+// measured on the exact (unwrapped) mathematical value. Tests reset this and
+// assert it stays under 2^31 to validate the width-choice claim.
+let signedMulSplitMaxPartial = 0n;
+
+/**
+ * TS mirror of the WGSL signed multiply: multiplies two i32 values and
+ * returns `[lo29, hi]` such that `a * b = lo29 + hi * 2^29`, with `lo29` in
+ * [0, 2^29) and `hi` reinterpreted as an i32. The identity only holds when
+ * every partial product and the high part fit in i32; the audit below exists
+ * to check the former.
+ *
+ * Side effect: raises `signedMulSplitMaxPartial` to the largest exact
+ * partial-product magnitude observed.
+ */
+function signed_mul_split_ts(a: number, b: number): [number, number] {
+  // Signed split x = x_lo + x_hi * 2^15 with x_lo in [-2^14, 2^14).
+  // `(x << 17) >> 17` sign-extends the low 15 bits; JS shifts use i32 semantics.
+  const a_lo = (a << 17) >> 17;
+  const a_hi = ((a - a_lo) >> 15) | 0;
+  const b_lo = (b << 17) >> 17;
+  const b_hi = ((b - b_lo) >> 15) | 0;
+
+  // i32-wrapping partial products.
+  const pll = Math.imul(a_lo, b_lo);
+  const plh = Math.imul(a_lo, b_hi);
+  const phl = Math.imul(a_hi, b_lo);
+  const phh = Math.imul(a_hi, b_hi);
+  const mid = (plh + phl) | 0;
+
+  // Audit on the EXACT products. The wrapped values above always fit in i32
+  // by construction, so inspecting them could never reveal an overflow.
+  const exactPlh = BigInt(a_lo) * BigInt(b_hi);
+  const exactPhl = BigInt(a_hi) * BigInt(b_lo);
+  const exactPartials = [
+    BigInt(a_lo) * BigInt(b_lo),
+    exactPlh,
+    exactPhl,
+    BigInt(a_hi) * BigInt(b_hi),
+    exactPlh + exactPhl,
+  ];
+  for (const exact of exactPartials) {
+    const magnitude = exact < 0n ? -exact : exact;
+    if (magnitude > signedMulSplitMaxPartial) signedMulSplitMaxPartial = magnitude;
+  }
+
+  // Compose the low 64 bits of pll + (mid << 15) + (phh << 30) from u32
+  // halves with explicit carries. Every intermediate stays far below 2^53.
+  const pll_lo = pll >>> 0;
+  const pll_hi = (pll >> 31) >>> 0; // sign mask
+  const mid_lo = (mid << 15) >>> 0;
+  // Arithmetic shift keeps the sign; `>>> 0` converts to the u32 bit pattern.
+  const mid_hi = (mid >> 17) >>> 0;
+  const phh_lo = (phh << 30) >>> 0;
+  const phh_hi = (phh >> 2) >>> 0;
+
+  // (pll_lo, pll_hi) + (mid_lo, mid_hi).
+  const s1_lo_raw = pll_lo + mid_lo;
+  const s1_lo = s1_lo_raw >>> 0;
+  const c1 = s1_lo_raw >= 2 ** 32 ? 1 : 0;
+  const s1_hi = (pll_hi + mid_hi + c1) >>> 0;
+
+  // + (phh_lo, phh_hi).
+  const sum_lo_raw = s1_lo + phh_lo;
+  const sum_lo = sum_lo_raw >>> 0;
+  const c2 = sum_lo_raw >= 2 ** 32 ? 1 : 0;
+  const sum_hi = (s1_hi + phh_hi + c2) >>> 0;
+
+  // Split the 64-bit sum at bit 29.
+  const lo29 = sum_lo & 0x1FFFFFFF;
+  const hi_u = ((sum_lo >>> 29) | (sum_hi << 3)) >>> 0;
+  const hi = hi_u >= 2 ** 31 ? hi_u - 2 ** 32 : hi_u;
+  return [lo29, hi];
+}
+
+/**
+ * Adds the product `m_lo * x_limb` into a two-part accumulator
+ * `[lo29, hi]` (value = lo29 + hi * 2^29) and returns the updated pair with
+ * the low part renormalised into [0, 2^29).
+ */
+function by_accumulate_ts(
+  acc: [number, number],
+  m_lo: number,
+  x_limb: number,
+): [number, number] {
+  const [prodLow, prodHigh] = signed_mul_split_ts(m_lo, x_limb);
+  const lowTotal = (acc[0] + prodLow) | 0;
+  // `>>` is an arithmetic shift on i32, so the overflow keeps its sign.
+  const spill = lowTotal >> 29;
+  const highTotal = (acc[1] + prodHigh + spill) | 0;
+  return [lowTotal & 0x1fffffff, highTotal];
+}
+
+const BY_NUM_LIMBS_N = 9;
+/**
+ * Returns a carry-normalised copy of a 9-limb value: limbs 0..7 end up in
+ * [0, 2^29) and the top limb absorbs the final (possibly negative) carry.
+ * The carry chain runs in BigInt so the masking and arithmetic shift are
+ * exact for negative limbs. The input array is not modified.
+ */
+function by_normalise_ts(x: number[]): number[] {
+  const topIndex = BY_NUM_LIMBS_N - 1;
+  const limbs = x.slice();
+  let carry = 0n;
+  let idx = 0;
+  while (idx < topIndex) {
+    const total = BigInt(limbs[idx]) + carry;
+    carry = total >> 29n;
+    limbs[idx] = Number(total & BY_LIMB_MASK_N);
+    idx += 1;
+  }
+  limbs[topIndex] = Number(BigInt(limbs[topIndex]) + carry);
+  return limbs;
+}
+
+// ---- 20×13-bit BigInt <-> 9×29-bit BigIntBY (TS mirror) ----
+// Layout of the 20 x 13-bit limb representation: limb count, bits per limb,
+// and the mask selecting one 13-bit limb.
+const NUM_WORDS = 20;
+const WORD_SIZE = 13;
+const WORD_MASK = (1 << 13) - 1;
+
+/**
+ * Extracts the 29-bit window starting at bit `bit_lo` from an array of
+ * 13-bit limbs. A window can straddle up to four source limbs (worst case
+ * when it starts at bit 12 of a limb); limbs past the end read as 0.
+ */
+function read29Window(limbs: number[], bit_lo: number): number {
+  const firstLimb = Math.floor(bit_lo / WORD_SIZE);
+  const bitOffset = bit_lo % WORD_SIZE;
+  // BigInt avoids 32-bit overflow: the last limb is shifted up to 39 bits.
+  let gathered = 0n;
+  for (let k = 0; k < 4; k++) {
+    const limbIndex = firstLimb + k;
+    const limb = BigInt(limbIndex < NUM_WORDS ? limbs[limbIndex] : 0);
+    const shift = WORD_SIZE * k - bitOffset;
+    gathered |= shift >= 0 ? limb << BigInt(shift) : limb >> BigInt(-shift);
+  }
+  return Number(gathered & BY_LIMB_MASK_N);
+}
+/** Repacks 13-bit limbs into nine 29-bit limbs by reading consecutive 29-bit windows. */
+function by_from_bigint_ts(limbs: number[]): number[] {
+  return Array.from({ length: 9 }, (_, limbIndex) => read29Window(limbs, limbIndex * 29));
+}
+/**
+ * Extracts the 13-bit window starting at bit `bit_lo` from an array of nine
+ * 29-bit limbs. A window touches at most two limbs; limbs past index 8 read
+ * as 0, and each limb is reinterpreted as u32 before use.
+ */
+function read13Window(by_l: number[], bit_lo: number): number {
+  const limbAt = (idx: number): bigint => BigInt(idx < 9 ? by_l[idx] >>> 0 : 0);
+  const firstLimb = Math.floor(bit_lo / 29);
+  const bitOffset = BigInt(bit_lo % 29);
+  const gathered = (limbAt(firstLimb) >> bitOffset) | (limbAt(firstLimb + 1) << (29n - bitOffset));
+  return Number(gathered & BigInt(WORD_MASK));
+}
+/** Repacks nine 29-bit limbs into NUM_WORDS 13-bit limbs by reading consecutive 13-bit windows. */
+function by_to_bigint_ts(by_l: number[]): number[] {
+  return Array.from({ length: NUM_WORDS }, (_, wordIndex) => read13Window(by_l, wordIndex * WORD_SIZE));
+}
+
+// Split a bigint into NUM_WORDS 13-bit limbs, least-significant first. Every
+// limb is in [0, 2^13); bits beyond NUM_WORDS * 13 are dropped.
+function bigintToLimbs(x: bigint): number[] {
+  const limbMask = BigInt(WORD_MASK);
+  return Array.from({ length: NUM_WORDS }, (_, i) => Number((x >> BigInt(13 * i)) & limbMask));
+}
+
+describe("WGSL helper TS mirror", () => {
+ it("u64_add matches BigInt ground truth (500 random)", () => {
+   const nextWord = makeRng(0x11111111);
+   const wrap64 = (1n << 64n) - 1n;
+   let mismatches = 0;
+   for (let trial = 0; trial < 500; trial++) {
+     const lhs: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const rhs: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     // Ground truth: BigInt sum reduced mod 2^64.
+     const want = (pairToU64(lhs) + pairToU64(rhs)) & wrap64;
+     if (pairToU64(u64_add_ts(lhs, rhs)) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("u64_sub matches BigInt ground truth (500 random)", () => {
+   const nextWord = makeRng(0x22222222);
+   const wrap64 = (1n << 64n) - 1n;
+   let mismatches = 0;
+   for (let trial = 0; trial < 500; trial++) {
+     const lhs: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const rhs: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     // Masking a (possibly negative) BigInt difference yields its value mod 2^64.
+     const want = (pairToU64(lhs) - pairToU64(rhs)) & wrap64;
+     if (pairToU64(u64_sub_ts(lhs, rhs)) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("u64_shr1 matches BigInt ground truth (100 random)", () => {
+   const nextWord = makeRng(0x33333333);
+   let mismatches = 0;
+   for (let trial = 0; trial < 100; trial++) {
+     const value: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const want = pairToU64(value) >> 1n;
+     if (pairToU64(u64_shr1_ts(value)) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("u64_low_bit matches BigInt ground truth (100 random)", () => {
+   const nextWord = makeRng(0x44444444);
+   let mismatches = 0;
+   for (let trial = 0; trial < 100; trial++) {
+     const value: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const want = Number(pairToU64(value) & 1n);
+     if (u64_low_bit_ts(value) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("u64_neg matches BigInt ground truth (100 random)", () => {
+   const nextWord = makeRng(0x55555555);
+   const wrap64 = (1n << 64n) - 1n;
+   let mismatches = 0;
+   for (let trial = 0; trial < 100; trial++) {
+     const value: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     // -v mod 2^64, which maps 0 to 0.
+     const want = (-pairToU64(value)) & wrap64;
+     if (pairToU64(u64_neg_ts(value)) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("u64_and matches BigInt ground truth (100 random)", () => {
+   const nextWord = makeRng(0x66666666);
+   let mismatches = 0;
+   for (let trial = 0; trial < 100; trial++) {
+     const value: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const mask: [number, number] = [nextWord() >>> 0, nextWord() >>> 0];
+     const want = pairToU64(value) & pairToU64(mask);
+     if (pairToU64(u64_and_ts(value, mask)) !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ // ---- i64 signed pair helpers ----
+ // Draws a signed integer in [-2^magBits, 2^magBits) from `rng` and returns
+ // it as a (u32 low, i32 high) pair. One byte more than magBits requires is
+ // consumed before reducing into the target range.
+ function makeI64Pair(rng: () => number, magBits: number): [number, number] {
+   const byteCount = Math.ceil(magBits / 8) + 1;
+   let raw = 0n;
+   for (let k = 0; k < byteCount; k++) {
+     raw = raw * 256n + BigInt(rng() & 0xff);
+   }
+   const half = 1n << BigInt(magBits);
+   // Reduce into [0, 2 * half), then recentre on zero.
+   return bigintToI64Pair((raw % (half * 2n)) - half);
+ }
+
+ it("i64_add_pair matches BigInt ground truth (200 random)", () => {
+   const rng = makeRng(0x77777777);
+   let mismatches = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     const lhs = makeI64Pair(rng, 60);
+     const rhs = makeI64Pair(rng, 60);
+     const actual = i64PairToBigint(i64_add_pair_ts(lhs, rhs));
+     // Ground truth: BigInt sum wrapped into the signed 64-bit range.
+     const want = BigInt.asIntN(64, i64PairToBigint(lhs) + i64PairToBigint(rhs));
+     if (actual === want) continue;
+     mismatches += 1;
+     if (mismatches <= 3) {
+       console.error(
+         `i64_add fail: a=${i64PairToBigint(lhs)} b=${i64PairToBigint(rhs)} got=${actual} expected=${want}`,
+       );
+     }
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("i64_sub_pair matches BigInt ground truth (200 random)", () => {
+   const rng = makeRng(0x88888888);
+   let mismatches = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     const lhs = makeI64Pair(rng, 60);
+     const rhs = makeI64Pair(rng, 60);
+     const actual = i64PairToBigint(i64_sub_pair_ts(lhs, rhs));
+     // Ground truth: BigInt difference wrapped into the signed 64-bit range.
+     const want = BigInt.asIntN(64, i64PairToBigint(lhs) - i64PairToBigint(rhs));
+     if (actual !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("i64_shl1_pair matches BigInt ground truth (200 random)", () => {
+   const rng = makeRng(0x99999999);
+   let mismatches = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     const value = makeI64Pair(rng, 60);
+     const actual = i64PairToBigint(i64_shl1_pair_ts(value));
+     // Ground truth: doubled value wrapped into the signed 64-bit range.
+     const want = BigInt.asIntN(64, i64PairToBigint(value) << 1n);
+     if (actual !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ it("i64_neg_pair matches BigInt ground truth (200 random)", () => {
+   const rng = makeRng(0xaaaaaaaa);
+   let mismatches = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     const value = makeI64Pair(rng, 60);
+     const actual = i64PairToBigint(i64_neg_pair_ts(value));
+     // Ground truth: negated value wrapped into the signed 64-bit range.
+     const want = BigInt.asIntN(64, -i64PairToBigint(value));
+     if (actual !== want) mismatches += 1;
+   }
+   expect(mismatches).toBe(0);
+ });
+
+ // ---- signed_mul_split ----
+ it("signed_mul_split: 500 random (a,b) with bounds + partial-product audit", () => {
+   const rng = makeRng(0xbbbbbbbb);
+   let failures = 0;
+   signedMulSplitMaxPartial = 0n;
+   for (let i = 0; i < 500; i++) {
+     // a in (-2^29, 2^29): a 29-bit magnitude with a random sign.
+     const aMag = rng() & ((1 << 29) - 1);
+     const aSign = (rng() & 1) ? -1 : 1;
+     const a = aSign * aMag;
+     // One draw is consumed and discarded here. It used to feed an unused
+     // `bSign`; keeping the draw leaves the seeded sample stream unchanged.
+     rng();
+     // b covers the full i32 range [-2^31, 2^31): any 32-bit pattern read as signed.
+     const b = rng() | 0;
+
+     const got = signed_mul_split_ts(a, b);
+     const reconstructed = BigInt(got[0]) + BigInt(got[1]) * POW_29N;
+     const expected = BigInt(a) * BigInt(b);
+     if (reconstructed !== expected) {
+       failures++;
+       if (failures <= 3) {
+         console.error(
+           `signed_mul_split fail: a=${a} b=${b} got=(lo29=${got[0]}, hi=${got[1]}) reconstructed=${reconstructed} expected=${expected}`,
+         );
+       }
+     }
+     // lo29 must be in [0, 2^29).
+     if (got[0] < 0 || got[0] >= POW_29) failures++;
+   }
+   expect(failures).toBe(0);
+   // Width audit (per plan): every signed partial product fits in i32.
+   // With the 15-bit signed split, the worst case is |a_lo| * |b_hi| or
+   // |a_hi| * |b_hi| <= 2^14 * 2^16 = 2^30, well within i32.
+   expect(signedMulSplitMaxPartial).toBeLessThan(1n << 31n);
+ });
+
+ it("signed_mul_split: edge cases (zero, max-positive, min-negative)", () => {
+   // Operands here stay within [-2^29, 2^29), so |hi| <= 2^29 fits in i32.
+   // A pair such as a = -2^29, b = -2^31 gives a*b = 2^60 and hi = 2^31,
+   // which does NOT fit in a signed i32. apply_matrix keeps |b| <= 2^29
+   // (matrix-entry low half in [0, 2^29) unsigned, high half signed in
+   // [-2^29, 2^29)), so the cases below are bounded to match.
+   const top = POW_29 - 1;
+   const bottom = -POW_29;
+   const cases: Array<[number, number]> = [
+     [0, 0],
+     [1, 1],
+     [top, 0],
+     [0, top],
+     [top, top],
+     [bottom, top],
+     [top, bottom],
+     [bottom, bottom],
+     [-1, -1],
+     [1, -1],
+     [-1, 1],
+   ];
+   cases.forEach(([a, b]) => {
+     const [lo29, hi] = signed_mul_split_ts(a, b);
+     expect(BigInt(lo29) + BigInt(hi) * POW_29N).toBe(BigInt(a) * BigInt(b));
+     expect(lo29).toBeGreaterThanOrEqual(0);
+     expect(lo29).toBeLessThan(POW_29);
+   });
+ });
+
+ // ---- by_accumulate ----
+ it("by_accumulate: 200 random matches BigInt", () => {
+   const rng = makeRng(0xcccccccc);
+   const LIMB29 = (1 << 29) - 1;
+   // 29-bit magnitude followed by a random sign (two draws, in that order).
+   const signedLimb = (): number => {
+     const magnitude = rng() & LIMB29;
+     return (rng() & 1) ? -magnitude : magnitude;
+   };
+   let failures = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     // acc low part in [0, 2^29); acc high part, m_lo and x_limb each have
+     // magnitude below 2^29, matching the two-limb accumulator in apply_matrix.
+     const acc_lo = rng() & LIMB29;
+     const acc_hi = signedLimb();
+     const m_lo = signedLimb();
+     const x_limb = signedLimb();
+
+     const got = by_accumulate_ts([acc_lo, acc_hi], m_lo, x_limb);
+     const expected = BigInt(acc_lo) + BigInt(acc_hi) * POW_29N + BigInt(m_lo) * BigInt(x_limb);
+     const reconstructed = BigInt(got[0]) + BigInt(got[1]) * POW_29N;
+     if (reconstructed !== expected) {
+       failures += 1;
+       if (failures <= 3) {
+         console.error(
+           `by_accumulate fail: acc=(${acc_lo}, ${acc_hi}) m_lo=${m_lo} x_limb=${x_limb} got=(${got[0]}, ${got[1]}) reconstructed=${reconstructed} expected=${expected}`,
+         );
+       }
+     }
+     const lowInRange = got[0] >= 0 && got[0] < POW_29;
+     if (!lowInRange) failures += 1;
+   }
+   expect(failures).toBe(0);
+ });
+
+ // ---- by_normalise ----
+ it("by_normalise: 100 cases with non-canonical limbs", () => {
+   const rng = makeRng(0xdddddddd);
+   let failures = 0;
+   for (let i = 0; i < 100; i++) {
+     // Each limb is a random signed value with a 28-bit magnitude. Negative
+     // limbs are non-canonical and force borrows through the carry chain.
+     const limbs: number[] = [];
+     let originalValueBN = 0n;
+     for (let j = 0; j < 9; j++) {
+       const mag = rng() & 0xfffffff; // 28-bit
+       const sign = (rng() & 1) ? -1n : 1n;
+       const v = sign * BigInt(mag);
+       limbs.push(Number(v));
+       originalValueBN += v << BigInt(j * 29);
+     }
+     const got = by_normalise_ts(limbs);
+     // After normalise, l[0..7] are in [0, 2^29) and l[8] is signed (it
+     // carries the sign), so the plain weighted sum recovers the value.
+     let reconstructed = 0n;
+     for (let j = 0; j < 9; j++) {
+       reconstructed += BigInt(got[j]) << BigInt(j * 29);
+     }
+     if (reconstructed !== originalValueBN) {
+       failures++;
+       if (failures <= 3) {
+         console.error(
+           `by_normalise fail: original=${originalValueBN} reconstructed=${reconstructed}`,
+         );
+       }
+     }
+     for (let j = 0; j < 8; j++) {
+       if (got[j] < 0 || got[j] >= POW_29) {
+         failures++;
+         if (failures <= 3) {
+           console.error(`limb ${j} not canonical: ${got[j]}`);
+         }
+       }
+     }
+   }
+   expect(failures).toBe(0);
+ });
+
+ // ---- by_from_bigint / by_to_bigint ----
+ it("by_from_bigint / by_to_bigint round-trip on 200 BN254 base-field values", () => {
+   const rng = makeRng(0xeeeeeeee);
+   // Weighted sum of limbs, limb k contributing limbs[k] * 2^(k * width).
+   const recompose = (limbs: number[], count: number, width: number): bigint => {
+     let total = 0n;
+     for (let k = 0; k < count; k++) {
+       total += BigInt(limbs[k]) << BigInt(k * width);
+     }
+     return total;
+   };
+   let failures = 0;
+   for (let trial = 0; trial < 200; trial++) {
+     const x = randomBelow(P_BIGINT, rng);
+     const by = by_from_bigint_ts(bigintToLimbs(x));
+     // Forward conversion must preserve the integer value.
+     const v = recompose(by, 9, 29);
+     if (v !== x) {
+       failures += 1;
+       if (failures <= 3) {
+         console.error(`by_from_bigint fail: x=${x} reconstructed=${v}`);
+       }
+     }
+     // Converting back must preserve it as well.
+     const back = by_to_bigint_ts(by);
+     const backV = recompose(back, NUM_WORDS, WORD_SIZE);
+     if (backV !== x) {
+       failures += 1;
+       if (failures <= 3) {
+         console.error(`by_to_bigint fail: x=${x} back=${backV}`);
+       }
+     }
+     // Every output limb must be canonical, i.e. in [0, 2^13).
+     for (let j = 0; j < NUM_WORDS; j++) {
+       const canonical = back[j] >= 0 && back[j] <= WORD_MASK;
+       if (!canonical) failures += 1;
+     }
+   }
+   expect(failures).toBe(0);
+ });
+});
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts
new file mode 100644
index 000000000000..693c02a78dcd
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts
@@ -0,0 +1,470 @@
+// TypeScript port of the Wasm9x29 Bernstein-Yang safegcd modular inverse
+// from barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp
+// (with the driver from bernstein_yang_inverse.hpp).
+//
+// This module is the ground-truth reference for the WGSL port that follows
+// in the WebGPU MSM rewrite plan (sub-steps 1.2-1.6). It must be a faithful
+// transliteration of the C++ — same constants, same control flow, same
+// per-limb arithmetic. Performance is not a concern; correctness vs
+// `modInverse` over 1000+ random BN254 base-field values is.
+//
+// Algorithmic notes
+// -----------------
+// State: 9 × 29-bit signed limbs held as `bigint[9]`. Top limb carries
+// sign, lower limbs in [0, 2^29) post-normalise. Choice of `bigint` over
+// `number`: matrix entries grow up to 2^58 after BATCH=58 divsteps and the
+// per-limb cross-products in apply_matrix need about 58 bits, which exceeds
+// the 53 bits in which JS `number` arithmetic is exact (2^53 - 1 is
+// Number.MAX_SAFE_INTEGER). The intermediate accumulators
+// (cf, cg, cd, ce, nf, ng, nd, ne, t_d, t_e) must be `bigint`.
+//
+// The divsteps low-64-bit carriers (f_lo, g_lo) are also `bigint` since they
+// hold u64 values. Each divstep mutates these via shift / add / sub at i64
+// precision; we mask back to u64 after each mutation to avoid sign
+// extension surprises. BigInt's `>>` is arithmetic shift for negative
+// values, matching C++ `(i64) >> n`.
+//
+// Loop bound discipline (per plan §1, hard constraint): the only loops in
+// this file have static upper bounds.
+// - divsteps inner loop: BATCH = 58
+// - apply_matrix per-limb loops: N = 9
+// - reduce_to_canonical: RTC_MAX_ITERS = 36
+// - driver outer loop: NUM_OUTER = 13 (with early `g == 0` break)
+// The test file asserts the step counter never exceeds these bounds.
+
+import { BN254_BASE_FIELD } from "./bn254.js";
+
+// ---- Constants (mirror Wasm9x29::*) -----------------------------------
+// Number of limbs in a BigIntBY state.
+export const N = 9;
+// Bits per limb, and the mask selecting one limb.
+export const LIMB_BITS = 29n;
+export const LIMB_MASK = (1n << LIMB_BITS) - 1n; // 0x1FFFFFFF
+// Divsteps per batch: as a bigint for shifts and masks, and as a number for
+// the divsteps loop bound.
+export const BATCH = 58n;
+export const BATCH_NUM = 58;
+// Mask selecting the low BATCH bits.
+export const MASK_BATCH = (1n << BATCH) - 1n;
+// Upper bound on outer (divsteps + apply_matrix) iterations in `invert`.
+export const NUM_OUTER = 13;
+// `invert` reduces d and e to canonical form every REDUCE_INTERVAL outer iterations.
+export const REDUCE_INTERVAL = 4;
+// Loop bound for reduceToCanonical.
+export const RTC_MAX_ITERS = 36;
+
+// u64 mask for f_lo / g_lo carriers in divsteps.
+const U64_MASK = (1n << 64n) - 1n;
+const U64_SIGN_BIT = 1n << 63n;
+
+// ---- BN254 base-field p (matches the C++ `p` argument) ----------------
+export const P_BIGINT = BN254_BASE_FIELD;
+
+// p_inv = p^(-1) mod 2^BATCH (BATCH = 58). Used by apply_matrix's 2-adic
+// correction step. Precomputed via Hensel's lemma; verified
+// (P * P_INV) % 2^58 === 1.
+export const P_INV = 12939590167534711n;
+
+// ---- 2x2 matrix produced by BATCH divsteps ----------------------------
+// divsteps starts from the identity (u = 1, v = 0, q = 0, r = 1).
+// applyMatrix uses row (u, v) to form the new f (and d) and row (q, r) to
+// form the new g (and e): f' is built from u*f + v*g, g' from q*f + r*g.
+export interface Mat {
+ u: bigint;
+ v: bigint;
+ q: bigint;
+ r: bigint;
+}
+
+// ---- BigIntBY state: 9 × 29-bit signed limbs --------------------------
+export type BigIntBY = bigint[]; // length 9
+
+/** Returns a fresh 9-limb state with every limb set to 0n. */
+export function makeZero(): BigIntBY {
+  return Array.from({ length: 9 }, () => 0n);
+}
+
+/** Returns a fresh 9-limb state representing the integer 1. */
+export function makeOne(): BigIntBY {
+  const one = makeZero();
+  one[0] = 1n;
+  return one;
+}
+
+// uint256 -> 9 × 29-bit limbs (matches Wasm9x29(const uint256_t&) ctor).
+// `x` is read as four u64 words so the bit extraction follows the C++
+// layout; any bits of `x` at or above 2^256 are ignored. Throws on negative
+// input.
+export function fromBigint(x: bigint): BigIntBY {
+  if (x < 0n) {
+    throw new Error("fromBigint: negative input");
+  }
+  const word = (k: bigint): bigint => (x >> (64n * k)) & 0xFFFFFFFFFFFFFFFFn;
+  const w0 = word(0n);
+  const w1 = word(1n);
+  const w2 = word(2n);
+  const w3 = word(3n);
+  // Limbs 2, 4 and 6 straddle a u64 word boundary and take bits from two words.
+  return [
+    w0 & LIMB_MASK,
+    (w0 >> 29n) & LIMB_MASK,
+    ((w0 >> 58n) & 0x3Fn) | ((w1 & 0x7FFFFFn) << 6n),
+    (w1 >> 23n) & LIMB_MASK,
+    ((w1 >> 52n) & 0xFFFn) | ((w2 & 0x1FFFFn) << 12n),
+    (w2 >> 17n) & LIMB_MASK,
+    ((w2 >> 46n) & 0x3FFFFn) | ((w3 & 0x7FFn) << 18n),
+    (w3 >> 11n) & LIMB_MASK,
+    (w3 >> 40n) & 0xFFFFFFn,
+  ];
+}
+
+// 9 × 29-bit limbs -> uint256 (matches to_uint256()).
+// The limbs are re-packed into four u64 words, which are then concatenated.
+// Assumes the state is canonical (all limbs in [0, 2^29), top limb
+// non-negative).
+export function toBigint(x: BigIntBY): bigint {
+  const mask64 = (1n << 64n) - 1n;
+  const words = [
+    x[0] | (x[1] << 29n) | (x[2] << 58n),
+    (x[2] >> 6n) | (x[3] << 23n) | (x[4] << 52n),
+    (x[4] >> 12n) | (x[5] << 17n) | (x[6] << 46n),
+    (x[6] >> 18n) | (x[7] << 11n) | (x[8] << 40n),
+  ].map((w) => w & mask64);
+  // Fold from the most-significant word down: ((w3 << 64 | w2) << 64 | w1) << 64 | w0.
+  return words.reduceRight((acc, w) => (acc << 64n) | w, 0n);
+}
+
+/** True when every one of the N limbs is 0n. */
+export function isZero(x: BigIntBY): boolean {
+  for (let i = 0; i < N; i++) {
+    if (x[i] !== 0n) return false;
+  }
+  return true;
+}
+
+/** True when the top limb, which carries the sign, is negative. */
+export function isNegative(x: BigIntBY): boolean {
+  const topLimb = x[N - 1];
+  return topLimb < 0n;
+}
+
+/** Negates `x` in place: flips the sign of every limb, then re-normalises. */
+export function neg(x: BigIntBY): void {
+  let i = N;
+  while (i-- > 0) {
+    x[i] = 0n - x[i];
+  }
+  normalise(x);
+}
+
+// Canonicalise limb carries in place: each of the lower N-1 limbs is brought
+// into [0, 2^29) and its overflow (floor division by 2^29, so negative limbs
+// borrow) is pushed into the next limb. The top limb absorbs the final carry
+// and may end up negative.
+export function normalise(x: BigIntBY): void {
+  let carry = 0n;
+  let i = 0;
+  while (i < N - 1) {
+    const total = x[i] + carry;
+    carry = total >> LIMB_BITS; // arithmetic shift: floors for negative totals
+    x[i] = total - (carry << LIMB_BITS); // remainder in [0, 2^29)
+    i += 1;
+  }
+  x[N - 1] = x[N - 1] + carry;
+}
+
+/** x += b, limb by limb, followed by a carry normalisation of x. */
+function addInplace(x: BigIntBY, b: BigIntBY): void {
+  let i = N;
+  while (i-- > 0) {
+    x[i] = x[i] + b[i];
+  }
+  normalise(x);
+}
+
+/** x -= b, limb by limb, followed by a carry normalisation of x. */
+function subInplace(x: BigIntBY, b: BigIntBY): void {
+  let i = N;
+  while (i-- > 0) {
+    x[i] = x[i] - b[i];
+  }
+  normalise(x);
+}
+
+// reduce_to_canonical: bring x into [0, p) using at most 36 add-p / sub-p
+// passes. 36 covers |x| ≤ 32p under REDUCE_INTERVAL = 4 (see Wasm9x29 docs).
+// Mutates x in place. The loop is hard-bounded by RTC_MAX_ITERS; if
+// `stepRef` is supplied, `stepRef.steps` is incremented once per pass so a
+// caller can check the bound. Nothing is thrown when the bound is reached,
+// so x is left as-is if it has not yet dropped below p. The WGSL port will
+// have the same fixed bound.
+export function reduceToCanonical(
+ x: BigIntBY,
+ p: BigIntBY,
+ stepRef?: { steps: number },
+): void {
+ normalise(x);
+ for (let it = 0; it < RTC_MAX_ITERS; it++) {
+ if (stepRef) stepRef.steps++;
+ // Negative: add p and re-check on the next pass.
+ if (isNegative(x)) {
+ addInplace(x, p);
+ continue;
+ }
+ // Compare x with p from the most-significant limb down.
+ let cmp = 0;
+ for (let i = N - 1; i >= 0; i--) {
+ if (x[i] !== p[i]) {
+ cmp = x[i] > p[i] ? 1 : -1;
+ break;
+ }
+ }
+ // x < p: canonical, done. x >= p (including x == p): subtract p.
+ if (cmp < 0) {
+ break;
+ }
+ subInplace(x, p);
+ }
+}
+
+// ---- divsteps ---------------------------------------------------------
+// Run BATCH = 58 branchy divsteps on the low 64 bits of (f, g); returns
+// the transition matrix M together with the updated delta. Variable-time
+// over inner branches — non-secret inputs only.
+//
+// Inputs/outputs:
+// deltaIn: scalar bigint (interpret as i64). It is not mutated; the new
+// delta is returned in the result object alongside the matrix.
+// fLoIn, gLoIn: bigint values representing u64; both are masked to 64
+// bits on entry. Sums and differences are masked back to u64 before
+// the halving shift, so `>> 1n` always acts on a non-negative bigint
+// and behaves like the C++ unsigned `>> 1` after u64 wrap-around.
+// stepRef (optional): `steps` is incremented once per divstep.
+export function divsteps(
+ deltaIn: bigint,
+ fLoIn: bigint,
+ gLoIn: bigint,
+ stepRef?: { steps: number },
+): { mat: Mat; delta: bigint } {
+ let delta = deltaIn;
+ let f_lo = fLoIn & U64_MASK;
+ let g_lo = gLoIn & U64_MASK;
+ // M starts as the identity matrix.
+ let u = 1n;
+ let v = 0n;
+ let q = 0n;
+ let r = 1n;
+ for (let i = 0; i < BATCH_NUM; i++) {
+ if (stepRef) stepRef.steps++;
+ if ((g_lo & 1n) !== 0n) {
+ if (delta > 0n) {
+ // (f, g) <- (g, (g - f)/2)
+ const nf = g_lo;
+ // (g_lo - f_lo) wraps mod 2^64 (C++ u64 sub), then unsigned >>1.
+ const diff = (g_lo - f_lo) & U64_MASK;
+ const ng = diff >> 1n;
+ const nu = q << 1n;
+ const nv = r << 1n;
+ const nq = q - u;
+ const nr = r - v;
+ f_lo = nf;
+ g_lo = ng;
+ u = nu;
+ v = nv;
+ q = nq;
+ r = nr;
+ delta = 1n - delta;
+ } else {
+ // g <- (g + f)/2
+ const sum = (g_lo + f_lo) & U64_MASK;
+ g_lo = sum >> 1n;
+ q = q + u;
+ r = r + v;
+ u <<= 1n;
+ v <<= 1n;
+ delta = delta + 1n;
+ }
+ } else {
+ // g <- g/2
+ g_lo >>= 1n;
+ u <<= 1n;
+ v <<= 1n;
+ delta = delta + 1n;
+ }
+ }
+ return { mat: { u, v, q, r }, delta };
+}
+
+// ---- apply_matrix -----------------------------------------------------
+// Apply M to (f, g, d, e) using the streamed schoolbook described in the
+// C++ comment block. Matrix entries can reach |2^58| post-divsteps; we
+// split into (lo: 29-bit, hi: i32) parts via arithmetic shift, matching
+// the C++ `u_lo = m.u & LIMB_MASK; u_hi = m.u >> LIMB_BITS;`.
+//
+// The exact `/ 2^BATCH = 2^58 = 2 * LIMB_BITS` at the end is realised by
+// writing the per-limb result at position `i - 2` (drop bottom two limbs).
+//
+// f, g, d and e are updated in place; m and p are only read. `pInv` is
+// used as p^(-1) mod 2^BATCH when computing the multiples of p added to
+// d and e.
+export function applyMatrix(
+ m: Mat,
+ f: BigIntBY,
+ g: BigIntBY,
+ d: BigIntBY,
+ e: BigIntBY,
+ p: BigIntBY,
+ pInv: bigint,
+): void {
+ // (signed) split into low 29 bits + high (arithmetic shift).
+ const u_lo = m.u & LIMB_MASK;
+ const u_hi = m.u >> LIMB_BITS;
+ const v_lo = m.v & LIMB_MASK;
+ const v_hi = m.v >> LIMB_BITS;
+ const q_lo = m.q & LIMB_MASK;
+ const q_hi = m.q >> LIMB_BITS;
+ const r_lo = m.r & LIMB_MASK;
+ const r_hi = m.r >> LIMB_BITS;
+
+ // ---- (f, g) pass ----------------------------------------------------
+ {
+ // cf / cg: running carries. fp / gp: the previous input limbs, which
+ // pair with the high halves of the matrix entries.
+ let cf = 0n;
+ let cg = 0n;
+ let fp = 0n;
+ let gp = 0n;
+ for (let i = 0; i < N; i++) {
+ const fi = f[i];
+ const gi = g[i];
+ const nf = u_lo * fi + v_lo * gi + u_hi * fp + v_hi * gp + cf;
+ const ng = q_lo * fi + r_lo * gi + q_hi * fp + r_hi * gp + cg;
+ cf = nf >> LIMB_BITS;
+ cg = ng >> LIMB_BITS;
+ // Output limbs 0 and 1 are dropped (division by 2^58); only their
+ // carries are kept.
+ if (i >= 2) {
+ f[i - 2] = nf & LIMB_MASK;
+ g[i - 2] = ng & LIMB_MASK;
+ }
+ fp = fi;
+ gp = gi;
+ }
+ // Final column: only the high halves times the last input limb remain.
+ const nf9 = u_hi * fp + v_hi * gp + cf;
+ const ng9 = q_hi * fp + r_hi * gp + cg;
+ f[N - 2] = nf9 & LIMB_MASK;
+ g[N - 2] = ng9 & LIMB_MASK;
+ f[N - 1] = nf9 >> LIMB_BITS;
+ g[N - 1] = ng9 >> LIMB_BITS;
+ }
+
+ // ---- (d, e) pass with 2-adic k·p correction -------------------------
+ {
+ const d0 = d[0];
+ const e0 = e[0];
+ const d1 = d[1];
+ const e1 = e[1];
+ const nd0 = u_lo * d0 + v_lo * e0;
+ const ne0 = q_lo * d0 + r_lo * e0;
+ const nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0;
+ const ne1 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0;
+ // Reconstruct low 58 bits of nd and ne for k computation.
+ // C++ does `(u64)nd0 & LIMB_MASK | (((u64)(nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK) << LIMB_BITS)`.
+ // The `(u64)` casts truncate signed -> unsigned mod 2^64. We replicate
+ // by masking the assembled value with U64_MASK.
+ const nd0_low = nd0 & LIMB_MASK;
+ const nd1_carry = (nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK;
+ const t_d = (nd0_low | (nd1_carry << LIMB_BITS)) & U64_MASK;
+ const ne0_low = ne0 & LIMB_MASK;
+ const ne1_carry = (ne1 + (ne0 >> LIMB_BITS)) & LIMB_MASK;
+ const t_e = (ne0_low | (ne1_carry << LIMB_BITS)) & U64_MASK;
+ // k = (-t mod 2^64) * p_inv mod 2^BATCH.
+ // (-t) mod 2^64 = (~t + 1) & U64_MASK = (U64_MASK - t + 1) & U64_MASK.
+ const neg_td = (U64_MASK - t_d + 1n) & U64_MASK;
+ const neg_te = (U64_MASK - t_e + 1n) & U64_MASK;
+ const k_d = (neg_td * pInv) & MASK_BATCH;
+ const k_e = (neg_te * pInv) & MASK_BATCH;
+ // Split k into 29-bit halves, like the matrix entries.
+ const kd_lo = k_d & LIMB_MASK;
+ const kd_hi = k_d >> LIMB_BITS;
+ const ke_lo = k_e & LIMB_MASK;
+ const ke_hi = k_e >> LIMB_BITS;
+ // Initial carries seed from limb-1 partials (matches the C++
+ // `cd = (nd1 + kd_lo*p[1] + kd_hi*p[0] + ((nd0 + kd_lo*p[0]) >> LIMB_BITS)) >> LIMB_BITS`).
+ let cd =
+ (nd1 + kd_lo * p[1] + kd_hi * p[0] + ((nd0 + kd_lo * p[0]) >> LIMB_BITS)) >>
+ LIMB_BITS;
+ let ce =
+ (ne1 + ke_lo * p[1] + ke_hi * p[0] + ((ne0 + ke_lo * p[0]) >> LIMB_BITS)) >>
+ LIMB_BITS;
+
+ // dp / ep: previous input limbs, as fp / gp in the (f, g) pass.
+ let dp = d1;
+ let ep = e1;
+ for (let i = 2; i < N; i++) {
+ const di = d[i];
+ const ei = e[i];
+ const nd =
+ u_lo * di +
+ v_lo * ei +
+ u_hi * dp +
+ v_hi * ep +
+ kd_lo * p[i] +
+ kd_hi * p[i - 1] +
+ cd;
+ const ne =
+ q_lo * di +
+ r_lo * ei +
+ q_hi * dp +
+ r_hi * ep +
+ ke_lo * p[i] +
+ ke_hi * p[i - 1] +
+ ce;
+ cd = nd >> LIMB_BITS;
+ ce = ne >> LIMB_BITS;
+ d[i - 2] = nd & LIMB_MASK;
+ e[i - 2] = ne & LIMB_MASK;
+ dp = di;
+ ep = ei;
+ }
+ // Final column: high halves times the last limbs of d, e and p.
+ const nd9 = u_hi * dp + v_hi * ep + kd_hi * p[N - 1] + cd;
+ const ne9 = q_hi * dp + r_hi * ep + ke_hi * p[N - 1] + ce;
+ d[N - 2] = nd9 & LIMB_MASK;
+ e[N - 2] = ne9 & LIMB_MASK;
+ d[N - 1] = nd9 >> LIMB_BITS;
+ e[N - 1] = ne9 >> LIMB_BITS;
+ }
+}
+
+// ---- low_64: low 64 bits of the state ---------------------------------
+// Mirrors Wasm9x29::low_64(): limbs 0 and 1 supply bits 0..57 and the low
+// 6 bits of limb 2 supply bits 58..63. Returns a non-negative bigint < 2^64.
+export function low64(x: BigIntBY): bigint {
+  const bits0to28 = x[0] & LIMB_MASK;
+  const bits29to57 = (x[1] & LIMB_MASK) << 29n;
+  const bits58to63 = (x[2] & 0x3Fn) << 58n;
+  return (bits0to28 | bits29to57 | bits58to63) & U64_MASK;
+}
+
+// No helper is needed to convert the low64 result into a signed-64 view
+// before passing it to divsteps as `f_lo`/`g_lo`: divsteps masks its inputs
+// with U64_MASK and performs (g_lo - f_lo) as a u64 subtraction with wrap.
+
+// ---- driver: invert_bernsteinyang19 -----------------------------------
+// Returns a^(-1) mod p, or 0 if a == 0. Runs at most NUM_OUTER (13) rounds
+// of divsteps + applyMatrix, stopping early once g == 0; each
+// reduceToCanonical call is itself bounded by RTC_MAX_ITERS.
+//
+// `pInv` must be p^(-1) mod 2^BATCH for the `p` supplied. stepRef
+// (optional) accumulates the bounded-loop step count for test assertions:
+// divsteps inner steps plus reduceToCanonical passes.
+export function invert(
+  a: bigint,
+  p: bigint = P_BIGINT,
+  pInv: bigint = P_INV,
+  stepRef?: { steps: number },
+): bigint {
+  if (a === 0n) return 0n;
+  const modulus = fromBigint(p);
+  const f = fromBigint(p);
+  const g = fromBigint(a);
+  const d = makeZero();
+  const e = makeOne();
+  let delta = 1n;
+  let round = 0;
+  while (round < NUM_OUTER) {
+    const step = divsteps(delta, low64(f), low64(g), stepRef);
+    delta = step.delta;
+    applyMatrix(step.mat, f, g, d, e, modulus, pInv);
+    if (isZero(g)) break;
+    round += 1;
+    // Keep d and e bounded by reducing them every REDUCE_INTERVAL rounds.
+    if (round % REDUCE_INTERVAL === 0) {
+      reduceToCanonical(d, modulus, stepRef);
+      reduceToCanonical(e, modulus, stepRef);
+    }
+  }
+  reduceToCanonical(d, modulus, stepRef);
+  // If f ended up negative, d is negated and reduced again.
+  if (isNegative(f)) {
+    neg(d);
+    reduceToCanonical(d, modulus, stepRef);
+  }
+  return toBigint(d);
+}
+
+// ---- Convenience namespace ---------------------------------------------
+// Same shape as the C++ `Wasm9x29::*` static interface so the WGSL port
+// can be diffed line-by-line against this file.
+
+// p^(-1) mod 2^BATCH for an odd modulus, by Newton (Hensel) lifting: if
+// inv is correct mod 2^k then inv * (2 - p * inv) is correct mod 2^(2k).
+// Starting from 1 (correct mod 2), the precision doubles until it covers
+// BATCH bits.
+function derivePInv(p: bigint): bigint {
+  if ((p & 1n) === 0n) {
+    throw new Error("Wasm9x29.invert: modulus must be odd");
+  }
+  let inv = 1n;
+  for (let bits = 2n; bits < 2n * BATCH; bits *= 2n) {
+    inv = (inv * (2n - p * inv)) & ((1n << bits) - 1n);
+  }
+  return inv & MASK_BATCH;
+}
+
+export const Wasm9x29 = {
+  N,
+  LIMB_BITS,
+  LIMB_MASK,
+  BATCH,
+  BATCH_NUM,
+  NUM_OUTER,
+  REDUCE_INTERVAL,
+  RTC_MAX_ITERS,
+  P: P_BIGINT,
+  P_INV,
+  makeZero,
+  makeOne,
+  fromBigint,
+  toBigint,
+  isZero,
+  isNegative,
+  neg,
+  normalise,
+  reduceToCanonical: (x: BigIntBY, p: BigIntBY) => reduceToCanonical(x, p),
+  divsteps: (delta: bigint, f_lo: bigint, g_lo: bigint) =>
+    divsteps(delta, f_lo, g_lo),
+  applyMatrix,
+  low64,
+  // The precomputed P_INV is only valid for the BN254 prime. For any other
+  // modulus the inverse is derived here unless the caller supplies `pInv`,
+  // so a custom `p` is no longer paired with the BN254 constant.
+  invert: (a: bigint, p: bigint = P_BIGINT, pInv?: bigint) =>
+    invert(a, p, pInv ?? (p === P_BIGINT ? P_INV : derivePInv(p))),
+};
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts
new file mode 100644
index 000000000000..c838408cc970
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts
@@ -0,0 +1,247 @@
+// Option A TS reference: BY safegcd inverse on 20 x 13-bit BigInt with
+// BATCH=26, NUM_OUTER=29. Validates the WGSL port at fr_inv_by_a against
+// modInverse. Uses bigint throughout (no overflow concerns).
+
+import { BN254_BASE_FIELD, modInverse } from "./bn254.js";
+
+// Number of limbs in a BigA value.
+const N = 20;
+// Bits held by each limb (N * LIMB_BITS = 260 bits in total).
+const LIMB_BITS = 13n;
+const LIMB_MASK = (1n << LIMB_BITS) - 1n;
+// Width of one divsteps batch: BATCH is the bigint form used for masks,
+// BATCH_NUM the plain-number form used as the divsteps loop bound.
+const BATCH = 26n;
+const BATCH_NUM = 26;
+const MASK_BATCH = (1n << BATCH) - 1n;
+// Upper bound on divsteps/applyMatrix rounds in invertA
+// (NUM_OUTER * BATCH_NUM = 754 division steps).
+const NUM_OUTER = 29;
+// invertA canonicalises d and e every REDUCE_INTERVAL rounds.
+const REDUCE_INTERVAL = 4;
+// Cap on add-p / subtract-p corrections inside one reduceToCanonical call.
+const RTC_MAX_ITERS = 36;
+
+const U64_MASK = (1n << 64n) - 1n;
+// Default modulus: the BN254 base field prime.
+const P_BIGINT = BN254_BASE_FIELD;
+// Inverse of p modulo 2^BATCH, obtained by Hensel/Newton lifting: if
+// inv == p^(-1) (mod 2^k) then inv * (2 - p * inv) == p^(-1) (mod 2^(2k)).
+// The seed inv = 1 is the inverse modulo 2 of any odd p.
+function computePInv26(p: bigint): bigint {
+  let inv = 1n;
+  // `precision` is the number of low bits of `inv` that are already correct.
+  let precision = 1n;
+  while (precision < BATCH) {
+    precision <<= 1n;
+    const modulus = 1n << precision;
+    inv = (inv * (2n - p * inv)) & (modulus - 1n);
+  }
+  // The last lift may overshoot BATCH bits; keep only the low BATCH bits.
+  return inv & MASK_BATCH;
+}
+const P_INV = computePInv26(P_BIGINT);
+
+// Limb vector with N entries; limb i carries bits [i * LIMB_BITS, (i + 1) * LIMB_BITS).
+type BigA = bigint[];
+
+// Returns a fresh limb vector with every limb set to 0n.
+function makeZero(): BigA {
+  return Array.from({ length: N }, () => 0n);
+}
+
+// Returns a fresh limb vector representing the value 1 (only limb 0 is set).
+function makeOne(): BigA {
+  const limbs = makeZero();
+  limbs[0] = 1n;
+  return limbs;
+}
+
+// Splits x into N limbs of LIMB_BITS bits each, least significant limb first.
+// Bits above N * LIMB_BITS are dropped.
+function fromBigint(x: bigint): BigA {
+  const limbs: BigA = [];
+  let rest = x;
+  for (let i = 0; i < N; i++) {
+    limbs.push(rest & LIMB_MASK);
+    rest >>= LIMB_BITS;
+  }
+  return limbs;
+}
+
+// Reassembles the limbs of x into one non-negative bigint. Every limb,
+// including the top one, is masked to LIMB_BITS bits, so any sign or
+// overflow held in a limb is discarded.
+function toBigint(x: BigA): bigint {
+  let value = 0n;
+  let shift = 0n;
+  for (let i = 0; i < N; i++) {
+    value |= (x[i] & LIMB_MASK) << shift;
+    shift += LIMB_BITS;
+  }
+  return value;
+}
+
+// Low 64 bits of x as an unsigned value. Limbs 0..4 cover 5 * 13 = 65 bits,
+// so the surplus top bit is masked off at the end.
+function lowU64(x: BigA): bigint {
+  let acc = 0n;
+  let shift = 0n;
+  for (let i = 0; i <= 4; i++) {
+    acc |= (x[i] & LIMB_MASK) << shift;
+    shift += LIMB_BITS;
+  }
+  return acc & U64_MASK;
+}
+
+// True when all N limbs of x are exactly 0n.
+function isZero(x: BigA): boolean {
+  let idx = N;
+  while (idx-- > 0) {
+    if (x[idx] !== 0n) return false;
+  }
+  return true;
+}
+
+// Sign test: only the most significant limb is inspected.
+function isNegative(x: BigA): boolean {
+  const top = x[N - 1];
+  return top < 0n;
+}
+
+// Carry-propagates x in place: limbs 0..N-2 end up in [0, 2^LIMB_BITS), while
+// the top limb absorbs the final (possibly negative) carry unmasked.
+function normalise(x: BigA): void {
+  let carry = 0n;
+  let i = 0;
+  while (i < N - 1) {
+    const sum = x[i] + carry;
+    // Arithmetic shift: a negative sum yields a negative carry (a borrow).
+    carry = sum >> LIMB_BITS;
+    x[i] = sum & LIMB_MASK;
+    i += 1;
+  }
+  x[N - 1] = x[N - 1] + carry;
+}
+
+// Negates x in place: flips the sign of every limb, then re-normalises so the
+// lower limbs are non-negative again.
+function neg(x: BigA): void {
+  for (let i = N - 1; i >= 0; i--) {
+    x[i] = 0n - x[i];
+  }
+  normalise(x);
+}
+
+// x += p, limb by limb, followed by carry propagation. Mutates x.
+function addP(x: BigA, p: BigA): void {
+  let i = 0;
+  while (i < N) {
+    x[i] = x[i] + p[i];
+    i += 1;
+  }
+  normalise(x);
+}
+
+// x -= p, limb by limb, followed by carry propagation. Mutates x.
+function subP(x: BigA, p: BigA): void {
+  let i = 0;
+  while (i < N) {
+    x[i] = x[i] - p[i];
+    i += 1;
+  }
+  normalise(x);
+}
+
+// Moves x towards the range [0, p) in place by repeatedly adding p while x is
+// negative and subtracting p while x >= p. At most RTC_MAX_ITERS corrections
+// are applied, so the loop is bounded even for inputs far outside the range.
+function reduceToCanonical(x: BigA, p: BigA): void {
+  normalise(x);
+  let remaining = RTC_MAX_ITERS;
+  while (remaining-- > 0) {
+    if (isNegative(x)) {
+      addP(x, p);
+    } else {
+      // Locate the most significant limb where x and p disagree.
+      let i = N - 1;
+      while (i >= 0 && x[i] === p[i]) i--;
+      const belowP = i >= 0 && x[i] < p[i];
+      if (belowP) return;
+      // x >= p (equality included): take one copy of p off.
+      subP(x, p);
+    }
+  }
+}
+
+// 2x2 transition matrix produced by one divsteps batch. applyMatrix uses the
+// (u, v) row for the new f / d and the (q, r) row for the new g / e.
+interface Mat { u: bigint; v: bigint; q: bigint; r: bigint; }
+
+// Runs BATCH_NUM division steps on the low 64 bits of f and g, starting from
+// the given delta. Returns the accumulated matrix together with the updated
+// delta; the inputs themselves are not modified (they are bigint values).
+function divsteps(deltaIn: bigint, fLoIn: bigint, gLoIn: bigint): { mat: Mat; delta: bigint } {
+ let delta = deltaIn;
+ // Work on unsigned 64-bit views of the inputs.
+ let f_lo = fLoIn & U64_MASK;
+ let g_lo = gLoIn & U64_MASK;
+ // The matrix starts as the identity.
+ let u = 1n, v = 0n, q = 0n, r = 1n;
+ for (let i = 0; i < BATCH_NUM; i++) {
+ if ((g_lo & 1n) !== 0n) {
+ if (delta > 0n) {
+ // g odd, delta > 0: f takes the old g, g becomes (g - f) / 2 (mod 2^64),
+ // and delta becomes 1 - delta. All new values are computed from the old
+ // ones before anything is assigned.
+ const nf = g_lo;
+ const diff = (g_lo - f_lo) & U64_MASK;
+ const ng = diff >> 1n;
+ const nu = q << 1n;
+ const nv = r << 1n;
+ const nq = q - u;
+ const nr = r - v;
+ f_lo = nf; g_lo = ng;
+ u = nu; v = nv; q = nq; r = nr;
+ delta = 1n - delta;
+ } else {
+ // g odd, delta <= 0: g becomes (g + f) / 2 (mod 2^64), f is kept.
+ // (q, r) must absorb (u, v) before (u, v) is doubled.
+ const sum = (g_lo + f_lo) & U64_MASK;
+ g_lo = sum >> 1n;
+ q = q + u; r = r + v;
+ u <<= 1n; v <<= 1n;
+ delta = delta + 1n;
+ }
+ } else {
+ // g even: halve g and double the (u, v) row; f is kept.
+ g_lo >>= 1n;
+ u <<= 1n; v <<= 1n;
+ delta = delta + 1n;
+ }
+ }
+ return { mat: { u, v, q, r }, delta };
+}
+
+// Applies the matrix m to both state pairs in place:
+//   (f, g) <- (u*f + v*g, q*f + r*g) with the two lowest result limbs dropped,
+//   (d, e) <- (u*d + v*e + k_d*p, q*d + r*e + k_e*p) with the two lowest
+//             result limbs dropped,
+// i.e. each result is shifted right by 2 * LIMB_BITS = 26 bits. k_d and k_e
+// are chosen from pInv so that the low 26 bits of the (d, e) sums cancel.
+// The dropped low limbs of the (f, g) products are presumably zero because m
+// was derived from those same low bits - this is not checked here.
+function applyMatrix(m: Mat, f: BigA, g: BigA, d: BigA, e: BigA, p: BigA, pInv: bigint): void {
+ // Use TS-style decomposition: u_lo unsigned in [0, 2^13), u_hi signed.
+ // Each entry equals hi * 2^13 + lo, so the hi part multiplies the limb one
+ // position below the one the lo part multiplies.
+ const u_lo = m.u & LIMB_MASK;
+ const u_hi = m.u >> LIMB_BITS;
+ const v_lo = m.v & LIMB_MASK;
+ const v_hi = m.v >> LIMB_BITS;
+ const q_lo = m.q & LIMB_MASK;
+ const q_hi = m.q >> LIMB_BITS;
+ const r_lo = m.r & LIMB_MASK;
+ const r_hi = m.r >> LIMB_BITS;
+
+ // (f, g) pass
+ {
+ // cf / cg: running carries; fp / gp: the previous (one lower) input limbs.
+ let cf = 0n, cg = 0n, fp = 0n, gp = 0n;
+ for (let i = 0; i < N; i++) {
+ // Read the inputs before f[i - 2] / g[i - 2] are overwritten below.
+ const fi = f[i];
+ const gi = g[i];
+ const nf = u_lo * fi + v_lo * gi + u_hi * fp + v_hi * gp + cf;
+ const ng = q_lo * fi + r_lo * gi + q_hi * fp + r_hi * gp + cg;
+ cf = nf >> LIMB_BITS;
+ cg = ng >> LIMB_BITS;
+ // Result limbs 0 and 1 are not stored; only their carries propagate.
+ if (i >= 2) {
+ f[i - 2] = nf & LIMB_MASK;
+ g[i - 2] = ng & LIMB_MASK;
+ }
+ fp = fi;
+ gp = gi;
+ }
+ // One extra result limb comes from the hi parts times the top input limb.
+ const nfTop = u_hi * fp + v_hi * gp + cf;
+ const ngTop = q_hi * fp + r_hi * gp + cg;
+ f[N - 2] = nfTop & LIMB_MASK;
+ g[N - 2] = ngTop & LIMB_MASK;
+ // The top limb is stored unmasked so that it keeps the sign.
+ f[N - 1] = nfTop >> LIMB_BITS;
+ g[N - 1] = ngTop >> LIMB_BITS;
+ }
+
+ // (d, e) pass with k*p
+ {
+ // Result limbs 0 and 1 are computed up front to obtain the low 26 bits
+ // (t_d, t_e) of the matrix products before any multiple of p is added.
+ const d0 = d[0], e0 = e[0], d1 = d[1], e1 = e[1];
+ const nd0 = u_lo * d0 + v_lo * e0;
+ const ne0 = q_lo * d0 + r_lo * e0;
+ const nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0;
+ const ne1 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0;
+ const nd0_low = nd0 & LIMB_MASK;
+ const nd1_carry = (nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK;
+ const t_d = (nd0_low | (nd1_carry << LIMB_BITS)) & MASK_BATCH;
+ const ne0_low = ne0 & LIMB_MASK;
+ const ne1_carry = (ne1 + (ne0 >> LIMB_BITS)) & LIMB_MASK;
+ const t_e = (ne0_low | (ne1_carry << LIMB_BITS)) & MASK_BATCH;
+ // neg_t = (-t) mod 2^26 and k = neg_t * pInv mod 2^26, so with
+ // pInv = p^(-1) mod 2^26 the sum t + k * p is 0 modulo 2^26.
+ const neg_td = (MASK_BATCH - t_d + 1n) & MASK_BATCH;
+ const neg_te = (MASK_BATCH - t_e + 1n) & MASK_BATCH;
+ const k_d = (neg_td * pInv) & MASK_BATCH;
+ const k_e = (neg_te * pInv) & MASK_BATCH;
+ // k is split into 13-bit halves exactly like the matrix entries.
+ const kd_lo = k_d & LIMB_MASK;
+ const kd_hi = k_d >> LIMB_BITS;
+ const ke_lo = k_e & LIMB_MASK;
+ const ke_hi = k_e >> LIMB_BITS;
+ // Carry out of result limb 1 once the k*p contributions are included;
+ // this seeds the main loop, which starts at limb 2.
+ let cd =
+ (nd1 + kd_lo * p[1] + kd_hi * p[0] + ((nd0 + kd_lo * p[0]) >> LIMB_BITS)) >>
+ LIMB_BITS;
+ let ce =
+ (ne1 + ke_lo * p[1] + ke_hi * p[0] + ((ne0 + ke_lo * p[0]) >> LIMB_BITS)) >>
+ LIMB_BITS;
+ let dp = d1, ep = e1;
+ for (let i = 2; i < N; i++) {
+ // Read the inputs before d[i - 2] / e[i - 2] are overwritten below.
+ const di = d[i];
+ const ei = e[i];
+ const nd = u_lo * di + v_lo * ei + u_hi * dp + v_hi * ep + kd_lo * p[i] + kd_hi * p[i - 1] + cd;
+ const ne = q_lo * di + r_lo * ei + q_hi * dp + r_hi * ep + ke_lo * p[i] + ke_hi * p[i - 1] + ce;
+ cd = nd >> LIMB_BITS;
+ ce = ne >> LIMB_BITS;
+ d[i - 2] = nd & LIMB_MASK;
+ e[i - 2] = ne & LIMB_MASK;
+ dp = di;
+ ep = ei;
+ }
+ // Extra result limb from the hi parts; the top limb keeps the sign.
+ const ndTop = u_hi * dp + v_hi * ep + kd_hi * p[N - 1] + cd;
+ const neTop = q_hi * dp + r_hi * ep + ke_hi * p[N - 1] + ce;
+ d[N - 2] = ndTop & LIMB_MASK;
+ e[N - 2] = neTop & LIMB_MASK;
+ d[N - 1] = ndTop >> LIMB_BITS;
+ e[N - 1] = neTop >> LIMB_BITS;
+ }
+}
+
+// Modular inverse of `a` modulo `p` using the batched divsteps / applyMatrix
+// driver on 20 x 13-bit limbs (at most NUM_OUTER rounds of BATCH_NUM steps).
+//
+// - `a`: any integer. It is first reduced into [0, p); multiples of p
+//   (including 0) have no inverse and yield 0n.
+// - `p`: the modulus, defaulting to the BN254 base field prime. It must be
+//   odd; presumably it must also fit the N * LIMB_BITS limb layout with room
+//   for the sign - neither is checked here.
+// - `pInv`: must equal p^(-1) mod 2^BATCH. When omitted it is the precomputed
+//   P_INV for the default modulus and is derived with computePInv26 for any
+//   other p. (Reusing the BN254 constant for a different p would break the
+//   low-bit cancellation in applyMatrix and give a wrong result.)
+export function invertA(
+  a: bigint,
+  p: bigint = P_BIGINT,
+  pInv: bigint = p === P_BIGINT ? P_INV : computePInv26(p),
+): bigint {
+  if (a === 0n) return 0n;
+  // fromBigint keeps only the raw low bits, so negative or oversized inputs
+  // must be brought into [0, p) before they are split into limbs.
+  const aCanon = ((a % p) + p) % p;
+  if (aCanon === 0n) return 0n;
+
+  const pa = fromBigint(p);
+  // Working pair (f, g) = (p, a); companion pair (d, e) = (0, 1).
+  const f = fromBigint(p);
+  const g = fromBigint(aCanon);
+  const d = makeZero();
+  const e = makeOne();
+  let delta = 1n;
+  for (let iter = 0; iter < NUM_OUTER; iter++) {
+    // Stop as soon as g has been driven to zero.
+    if (isZero(g)) break;
+    const flo = lowU64(f);
+    const glo = lowU64(g);
+    const { mat, delta: nd } = divsteps(delta, flo, glo);
+    delta = nd;
+    applyMatrix(mat, f, g, d, e, pa, pInv);
+    // Periodically pull d and e back into canonical range.
+    if (((iter + 1) % REDUCE_INTERVAL) === 0) {
+      reduceToCanonical(d, pa);
+      reduceToCanonical(e, pa);
+    }
+  }
+  reduceToCanonical(d, pa);
+  // A negative final f means d carries the opposite sign of the inverse.
+  if (isNegative(f)) {
+    neg(d);
+    reduceToCanonical(d, pa);
+  }
+  return toBigint(d);
+}
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts
index d7d2c02af0cc..10fb068a6833 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts
@@ -1,5 +1,11 @@
// Request a high-performance GPU device. Enables the optional "timestamp-query"
// feature when the adapter supports it, so per-pass timings can be collected.
+//
+// Explicitly requests the adapter's MAX for `maxComputeWorkgroupStorageSize`
+// so workgroup-shared scratch can scale with hardware support.
+// WebGPU's default limit is 16 KiB; Apple M1/M2 Metal exposes ~32 KiB, M3+
+// ~64 KiB. Requesting the max keeps room for larger per-WG scratch buffers
+// without falling back to the default cap.
export const get_device = async (): Promise => {
const gpuErrMsg = 'Please use a browser that has WebGPU enabled.';
const adapter = await navigator.gpu.requestAdapter({
@@ -15,10 +21,65 @@ export const get_device = async (): Promise => {
requiredFeatures.push('timestamp-query');
}
- const device = await adapter.requestDevice({ requiredFeatures });
+ const requiredLimits: Record = {};
+ const adapterLimits = adapter.limits as unknown as Record;
+ const wgStorageMax = adapterLimits['maxComputeWorkgroupStorageSize'];
+ if (typeof wgStorageMax === 'number') {
+ requiredLimits['maxComputeWorkgroupStorageSize'] = wgStorageMax;
+ }
+
+ const device = await adapter.requestDevice({ requiredFeatures, requiredLimits });
+ const grantedLimits = device.limits as unknown as Record;
+ console.log(
+ `[gpu] requested maxComputeWorkgroupStorageSize=${wgStorageMax}B,` +
+ ` granted=${grantedLimits['maxComputeWorkgroupStorageSize']}B`,
+ );
return device;
};
+// Pick the largest SLAB that fits NUM_BUCKETS u32 atomics in the device's
+// workgroup-shared memory limit AND leaves headroom for compiler-allocated
+// temporaries. Returns `{ slab, slabs, bytes }` where
+// slab * 4 + HEADROOM <= workgroup_storage_limit
+// slabs = ceil(num_columns / slab)
+//
+// Selection:
+// - If `num_columns * 4 + HEADROOM <= limit`, use `slab = num_columns`
+// (single slab, zero tiling — every scalar is Booth-recoded exactly
+// once per window).
+// - Otherwise pick the largest power-of-2 SLAB that fits. Power-of-2 is
+// a convenience: it keeps `slot = k * WG_SIZE + tid` integer divisions
+// cheap and the per-slab digit range a clean [s, s+SLAB) interval.
+//
+// Headroom rationale: kernels also use small amounts of compiler-allocated
+// workgroup memory (loop counters, barrier state). Leaving a 1 KiB margin
+// keeps us inside the granted limit on drivers that over-allocate slightly.
+export const pick_shared_digit_slab = (
+  num_columns: number,
+  workgroup_storage_limit: number,
+): { slab: number; slabs: number; bytes: number } => {
+  const BYTES_PER_CELL = 4; // one u32 per cell
+  const HEADROOM_BYTES = 1024;
+  const MIN_TILED_CELLS = 1024;
+
+  // Cells that fit once the headroom has been set aside.
+  const budget_bytes = Math.max(0, workgroup_storage_limit - HEADROOM_BYTES);
+  const capacity = Math.floor(budget_bytes / BYTES_PER_CELL);
+  if (capacity <= 0) {
+    throw new Error(
+      `pick_shared_digit_slab: workgroup_storage_limit=${workgroup_storage_limit}B is too small even for the 1 KiB headroom`,
+    );
+  }
+
+  // Everything fits: a single slab covering all columns, no tiling.
+  if (capacity >= num_columns) {
+    return { slab: num_columns, slabs: 1, bytes: num_columns * BYTES_PER_CELL };
+  }
+
+  // Otherwise take the largest power of two that does not exceed capacity.
+  let cells = 1;
+  for (let next = 2; next <= capacity; next *= 2) {
+    cells = next;
+  }
+  if (cells < MIN_TILED_CELLS) {
+    throw new Error(
+      `pick_shared_digit_slab: device workgroup storage ${workgroup_storage_limit}B can only fit ${cells} cells; below the 1024-cell floor`,
+    );
+  }
+  return {
+    slab: cells,
+    slabs: Math.ceil(num_columns / cells),
+    bytes: cells * BYTES_PER_CELL,
+  };
+};
+
export const read_write_buffer_usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
// Create and write a storage buffer
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
index 7730efe0e8f8..fb4018b6f9cd 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
@@ -1,9 +1,11 @@
import mustache from 'mustache';
import {
+ apply_matrix_bench as apply_matrix_bench_shader,
barrett as barrett_funcs,
batch_affine_apply as batch_affine_apply_shader,
batch_affine_apply_scatter as batch_affine_apply_scatter_shader,
batch_affine_dispatch_args as batch_affine_dispatch_args_shader,
+ bench_batch_affine as bench_batch_affine_shader,
batch_affine_finalize as batch_affine_finalize_shader,
batch_affine_finalize_apply as batch_affine_finalize_apply_shader,
batch_affine_finalize_collect as batch_affine_finalize_collect_shader,
@@ -12,17 +14,33 @@ import {
batch_inverse as batch_inverse_shader,
batch_inverse_parallel as batch_inverse_parallel_shader,
bigint as bigint_funcs,
+ bigint_by as bigint_by_funcs,
+ bigint_f32 as bigint_f32_funcs,
+ // by_inverse hosts the Mat struct + by_divsteps (and grows to host
+ // by_apply_matrix / fr_inv_by in subsequent sub-steps of the BY rewrite).
+ by_inverse as by_inverse_funcs,
+ // Option A BY safegcd inverse on 20 x 13-bit BigInt with BATCH=26 /
+ // NUM_OUTER=29. Hosts MatA, bya_divsteps, bya_apply_matrix_{fg,de}, the
+ // bya_reduce_to_canonical helper chain, and the fr_inv_by_a driver.
+ by_inverse_a as by_inverse_a_funcs,
bpr_bn254 as bpr_bn254_shader,
convert_point_coords_and_decompose_scalars,
convert_points_only as convert_points_only_shader,
decompose_scalars_signed_only as decompose_scalars_signed_only_shader,
decompress_g1_bn254 as decompress_g1_bn254_shader,
+ divsteps_bench as divsteps_bench_shader,
ec_bn254 as ec_bn254_funcs,
extract_word_from_bytes_le as extract_word_from_bytes_le_funcs,
field as field_funcs,
+ field_mul_bench_f32 as field_mul_bench_f32_shader,
+ field_mul_bench_u32 as field_mul_bench_u32_shader,
+ fr_inv_bench as fr_inv_bench_shader,
fr_pow as fr_pow_funcs,
horner_reduce_bn254 as horner_reduce_bn254_shader,
mont_pro_product as montgomery_product_funcs,
+ mont_pro_product_f32_22_sos3uv3 as montgomery_product_f32_22_sos3uv3_funcs,
+ mont_pro_product_karat_yuval as montgomery_product_karat_yuval_funcs,
+ mulhilo_22 as mulhilo_22_funcs,
smvp_bn254 as smvp_bn254_shader,
structs,
transpose_parallel_count as transpose_parallel_count_shader,
@@ -31,15 +49,31 @@ import {
transpose_serial as transpose_serial_shader,
} from '../wgsl/_generated/shaders.js';
import {
+ compute_by_p_inv_a,
+ compute_by_p_inv_split,
compute_misc_params,
compute_mod_inverse_pow2,
gen_p_limbs,
+ gen_p_limbs_by_initializer,
+ gen_p_limbs_f32,
gen_r_limbs,
gen_mu_limbs,
gen_wgsl_limbs_code,
} from './utils.js';
import { BN254_CURVE_CONFIG, CurveConfig } from './curve_config.js';
+// Inverse of `a` modulo `m` via the extended Euclidean algorithm. `a` is first
+// reduced into [0, m) and the result is returned in [0, m). No check is made
+// that gcd(a, m) is 1; when it is not, the value returned is not an inverse.
+function modinv(a: bigint, m: bigint): bigint {
+  let prevRem = ((a % m) + m) % m;
+  let rem = m;
+  // Coefficients of `a` in the Bezout relation for prevRem and rem.
+  let prevCoef = 1n;
+  let coef = 0n;
+  while (rem !== 0n) {
+    const quotient = prevRem / rem;
+    const nextRem = prevRem - quotient * rem;
+    const nextCoef = prevCoef - quotient * coef;
+    prevRem = rem;
+    rem = nextRem;
+    prevCoef = coef;
+    coef = nextCoef;
+  }
+  return ((prevCoef % m) + m) % m;
+}
+
+
// Generates parameterised WGSL shader sources for the BN254 MSM
// pipeline. Pre-computes Montgomery / Barrett constants for the
// configured word size on construction so the per-shader render
@@ -65,8 +99,43 @@ export class ShaderManager {
public r_cubed_limbs: string;
public b3_mont_limbs: string;
public sqrt_exp_limbs: string;
+ // (p - 2) as a BigInt literal — exponent for the Fermat-based fr_pow_inv
+ // bench variant. Plain (non-Montgomery) since fr_pow's `exp` is consumed
+ // bit-by-bit as a raw integer.
+ public p_minus_2_limbs: string;
public p_inv_mod_2w: number;
public mu_limbs: string;
+ // 22-bit-limb f32 Montgomery params. Used exclusively by
+ // `gen_field_mul_bench_f32_shader` for the sos3uv3 micro-benchmark.
+ // The 22-bit width buys a 4-way exact sum (4·2^22 = 2^24 fits in f32
+ // mantissa), enabling the per-slot (tlo, thi) chain-break in sos3uv3.
+ public num_limbs_f32_22: number;
+ public n0_f32_22: bigint;
+ public p_limbs_f32_22_str: string;
+ // 9 × 29-bit BY limb representation of `p` for the BY safegcd inverse
+ // path. Used by gen_apply_matrix_bench_shader (and in future sub-steps,
+ // by the fr_inv_by wiring). The initializer string is comma-separated
+ // limbs suitable for `BigIntBY(array({{{ p_limbs_by }}}))`.
+ public p_limbs_by_initializer: string;
+ // P_INV = p^(-1) mod 2^58, split as (low 32, high <=26) bits. The WASM
+ // convention is a single u64 (`p_inv` argument to Wasm9x29::apply_matrix);
+ // WGSL has no u64 so we precompute the split here and inject as two
+ // constants. Hensel-lifted from p mod 2 up to mod 2^58.
+ public p_inv_by_lo: number;
+ public p_inv_by_hi: number;
+ // 26-bit p^(-1) mod 2^26 for the Option A BY safegcd inverse driver
+ // (BATCH=26 / NUM_OUTER=29 on 20 x 13-bit BigInt). Single u32, since 26
+ // bits fit comfortably.
+ public p_inv_by_a_lo: number;
+ // Pre-rendered u32 Montgomery product source used as the
+ // `montgomery_product_funcs` mustache partial by every MSM shader that
+ // needs a base-field multiply. Defaults to the Karatsuba + Yuval body
+ // (see `renderKaratYuvalMont`), which benches ~27% faster than the
+ // runtime-loop CIOS at n=2^20, k=100 on Apple GPU. Both bodies expose
+ // the same `fn montgomery_product(x, y) -> BigInt` symbol and the same
+ // `get_p` / `conditional_reduce` helpers, so swapping the partial is
+ // a drop-in change at every callsite.
+ public mont_product_src: string;
public curveConfig: CurveConfig;
public recompile = '';
@@ -100,12 +169,38 @@ export class ShaderManager {
// (q + 1) / 4: closed-form sqrt exponent for q ≡ 3 (mod 4).
const sqrt_exp = (this.p + 1n) / 4n;
this.sqrt_exp_limbs = gen_wgsl_limbs_code(sqrt_exp, 'e', this.num_words, this.word_size);
+ // (p - 2): exponent for Fermat-based inversion in fr_pow_inv.
+ this.p_minus_2_limbs = gen_wgsl_limbs_code(this.p - 2n, 'e', this.num_words, this.word_size);
this.p_inv_mod_2w = compute_mod_inverse_pow2(this.p, this.word_size);
this.mu_limbs = gen_mu_limbs(this.p, this.num_words, this.word_size);
this.p_bitlength = this.p.toString(2).length;
this.slack = this.num_words * this.word_size - this.p_bitlength;
this.w_mask = (1 << this.word_size) - 1;
+ // 22-bit-limb f32 path (bench only). compute_misc_params(p, 22)
+ // gives num_words = 12 for BN254 (12·22 = 264 ≥ 254).
+ const params_f32_22 = compute_misc_params(this.p, 22);
+ this.num_limbs_f32_22 = params_f32_22.num_words;
+ this.n0_f32_22 = params_f32_22.n0;
+ this.p_limbs_f32_22_str = gen_p_limbs_f32(this.p, this.num_limbs_f32_22, 22);
+
+ // BY safegcd 9 × 29-bit representation of p and 58-bit p_inv split.
+ // Both feed `gen_apply_matrix_bench_shader` (and downstream by_inverse
+ // production wiring). The split is the WASM `p_inv` u64 broken into
+ // low-32 + high-26 chunks; the Mustache substitution is a flat u32
+ // constant on each side.
+ this.p_limbs_by_initializer = gen_p_limbs_by_initializer(this.p);
+ const p_inv_split = compute_by_p_inv_split(this.p);
+ this.p_inv_by_lo = p_inv_split.lo;
+ this.p_inv_by_hi = p_inv_split.hi;
+ // Option A 26-bit p_inv (single u32) for the BATCH=26 BY driver.
+ this.p_inv_by_a_lo = compute_by_p_inv_a(this.p);
+
+ // Render the Karatsuba+Yuval Mont body once. This is the default
+ // u32 multiplier used by every MSM shader that includes the
+ // `montgomery_product_funcs` mustache partial.
+ this.mont_product_src = this.renderKaratYuvalMont();
+
if (force_recompile) {
const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32;
this.recompile = `
@@ -166,7 +261,7 @@ export class ShaderManager {
bigint_funcs,
field_funcs,
barrett_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
extract_word_from_bytes_le_funcs,
},
);
@@ -203,7 +298,7 @@ export class ShaderManager {
bigint_funcs,
field_funcs,
barrett_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
extract_word_from_bytes_le_funcs,
},
);
@@ -225,6 +320,7 @@ export class ShaderManager {
mu_limbs: this.mu_limbs,
b3_mont_limbs: this.b3_mont_limbs,
sqrt_exp_limbs: this.sqrt_exp_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
w_mask: this.w_mask,
slack: this.slack,
num_words_mul_two: this.num_words * 2,
@@ -235,7 +331,7 @@ export class ShaderManager {
bigint_funcs,
field_funcs,
barrett_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
fr_pow_funcs,
},
);
@@ -321,7 +417,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
ec_funcs: ec_bn254_funcs,
},
@@ -338,22 +434,30 @@ export class ShaderManager {
p_limbs: this.p_limbs,
r_limbs: this.r_limbs,
r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
mask: this.mask,
two_pow_word_size: this.two_pow_word_size,
p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
recompile: this.recompile,
},
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
},
);
}
- public gen_batch_inverse_parallel_shader(num_sub_wgs: number): string {
+ // `windows_per_batch` (WPB) sets how many consecutive subtask pair pools
+ // get merged into ONE fr_inv_by_a call per (batch, sub_wg). Z dispatch
+ // dim must be ceil(num_subtasks / WPB). Pass WPB=1 to recover the
+ // pre-pooling behaviour byte-for-byte.
+ public gen_batch_inverse_parallel_shader(num_sub_wgs: number, windows_per_batch: number): string {
return mustache.render(
batch_inverse_parallel_shader,
{
@@ -363,18 +467,72 @@ export class ShaderManager {
p_limbs: this.p_limbs,
r_limbs: this.r_limbs,
r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
mask: this.mask,
two_pow_word_size: this.two_pow_word_size,
p_inv_mod_2w: this.p_inv_mod_2w,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
num_sub_wgs,
+ windows_per_batch,
recompile: this.recompile,
},
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
fr_pow_funcs,
+ bigint_by_funcs,
+ by_inverse_a_funcs,
+ },
+ );
+ }
+
+ // Standalone bench-only entry shader for batch-affine EC addition. One
+ // workgroup processes BATCH_SIZE pairs via the two-phase Montgomery
+ // batch-inverse trick (per-thread serial chunk + workgroup Hillis-Steele
+ // scan + single fr_inv_by_a + back-walk). Used by bench-batch-affine.ts
+ // to find the sweet spot where amortising the single inverse stops
+ // beating thread under-utilisation.
+ //
+ // `batch_size` must be an exact multiple of `tpb`; the caller picks both
+ // from a hand-built table (see bench-batch-affine.ts). BS = batch_size /
+ // tpb is baked into the shader as a compile-time constant so the inner
+ // forward-and-backward walks have static loop bounds.
+ public gen_bench_batch_affine_shader(batch_size: number, tpb: number): string {
+   // Both sizes must be positive and tpb must divide batch_size exactly.
+   const sizes_ok = batch_size > 0 && tpb > 0 && batch_size % tpb === 0;
+   if (!sizes_ok) {
+     throw new Error(
+       `gen_bench_batch_affine_shader: batch_size (${batch_size}) must be a positive multiple of tpb (${tpb})`,
+     );
+   }
+
+   // Values substituted into the template.
+   const view = {
+     batch_size,
+     tpb,
+     // Pairs handled by each thread.
+     per_thread_count: batch_size / tpb,
+     word_size: this.word_size,
+     num_words: this.num_words,
+     n0: this.n0,
+     p_limbs: this.p_limbs,
+     r_limbs: this.r_limbs,
+     r_cubed_limbs: this.r_cubed_limbs,
+     p_minus_2_limbs: this.p_minus_2_limbs,
+     mask: this.mask,
+     two_pow_word_size: this.two_pow_word_size,
+     p_inv_mod_2w: this.p_inv_mod_2w,
+     p_inv_by_a_lo: this.p_inv_by_a_lo,
+     recompile: this.recompile,
+   };
+
+   // Shader fragments made available to the template as mustache partials.
+   const partials = {
+     structs,
+     bigint_funcs,
+     montgomery_product_funcs: this.mont_product_src,
+     field_funcs,
+     fr_pow_funcs,
+     bigint_by_funcs,
+     by_inverse_a_funcs,
+   };
+
+   return mustache.render(bench_batch_affine_shader, view, partials);
+ }
@@ -391,8 +549,12 @@ export class ShaderManager {
);
}
- public gen_batch_affine_dispatch_args_shader(): string {
- return mustache.render(batch_affine_dispatch_args_shader, {}, {});
+ // `windows_per_batch` (WPB) is baked into the shader at render time —
+ // dispatch_args derives `num_batches = ceil(num_subtasks / WPB)` and
+ // uses it as the inverse-pass Z dispatch dim. Must match the WPB used
+ // by the corresponding gen_batch_inverse_parallel_shader call.
+ public gen_batch_affine_dispatch_args_shader(windows_per_batch: number): string {
+ return mustache.render(batch_affine_dispatch_args_shader, { windows_per_batch }, {});
}
public gen_batch_affine_schedule_shader(workgroup_size: number): string {
@@ -413,7 +575,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
},
);
@@ -437,7 +599,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
},
);
@@ -456,6 +618,7 @@ export class ShaderManager {
p_limbs: this.p_limbs,
r_limbs: this.r_limbs,
r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
mask: this.mask,
two_pow_word_size: this.two_pow_word_size,
p_inv_mod_2w: this.p_inv_mod_2w,
@@ -464,7 +627,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
fr_pow_funcs,
ec_funcs: ec_bn254_funcs,
@@ -492,7 +655,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
},
);
@@ -518,7 +681,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
},
);
@@ -542,7 +705,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
},
);
@@ -560,6 +723,11 @@ export class ShaderManager {
bench_no_store?: boolean;
} = {},
safe_first_add_no_collision = false,
+ // Multi-window BPR: each thread loops over WPB consecutive subtasks,
+ // sharing kernel-launch and header overhead. WPB=1 keeps the legacy
+ // one-subtask-per-workgroup behaviour. Const-bounded inside the
+ // shader so Tint can fully unroll when WPB is small.
+ windows_per_batch = 1,
) {
const bench_null = !!bench_flags.bench_null;
const bench_compute_only = !!bench_flags.bench_compute_only;
@@ -584,6 +752,7 @@ export class ShaderManager {
p_inv_mod_2w: this.p_inv_mod_2w,
index_shift: this.index_shift,
workgroup_size,
+ windows_per_batch,
recompile: this.recompile,
capture_debug,
assume_affine_buckets,
@@ -597,7 +766,7 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
ec_funcs: ec_bn254_funcs,
},
@@ -624,10 +793,499 @@ export class ShaderManager {
{
structs,
bigint_funcs,
- montgomery_product_funcs,
+ montgomery_product_funcs: this.mont_product_src,
field_funcs,
ec_funcs: ec_bn254_funcs,
},
);
}
+
+ // Builds the WGSL source for the u32 Montgomery-product micro-benchmark.
+ // The bundle is: structs, bigint helpers, one Montgomery body, and the
+ // bench entry point (each thread chains `k` Mont mults over an (a, b)
+ // pair). `variant` picks the Montgomery body:
+ // - 'karat': the class-level pre-rendered `mont_product_src`
+ // (recursive Karatsuba + Yuval, the production default).
+ // - 'cios' (default): `montgomery_product_funcs` rendered on the spot,
+ // so the bench can compare both bodies side by side.
+ public gen_field_mul_bench_u32_shader(
+ workgroup_size: number,
+ variant: 'cios' | 'karat' = 'cios',
+ ): string {
+ const struct_part = mustache.render(structs, { num_words: this.num_words });
+ const bigint_part = mustache.render(bigint_funcs, {});
+ let mont_part = this.mont_product_src;
+ if (variant !== 'karat') {
+ // Any non-'karat' request gets the inline CIOS template.
+ const cios_ctx = {
+ num_words: this.num_words,
+ word_size: this.word_size,
+ n0: this.n0,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_limbs: this.p_limbs,
+ };
+ mont_part = mustache.render(montgomery_product_funcs, cios_ctx);
+ }
+ const entry_part = mustache.render(field_mul_bench_u32_shader, { workgroup_size });
+ // Sections are separated by exactly one newline.
+ return `${struct_part}\n${bigint_part}\n${mont_part}\n${entry_part}`;
+ }
+
+ // Assembles the WGSL bundle for the BY `by_divsteps` micro-benchmark
+ // (used by the bench-divsteps.html page to compare against the TS
+ // Wasm9x29 port). Each thread reads one (f_lo, g_lo, delta) tuple, calls
+ // `by_divsteps`, and writes the resulting 8-field Mat + updated delta.
+ //
+ // Sections, in emission order:
+ // structs, bigint_funcs, Montgomery product, field_funcs, get_r,
+ // fr_pow_funcs, bigint_by, by_inverse, divsteps_bench entry.
+ //
+ // The bench itself only needs the BY helpers and `by_divsteps`, but the
+ // by_inverse partial also hosts `fr_inv_by`, which references
+ // `montgomery_product` and `get_r_cubed`. The partial must compile as a
+ // whole, so the same Mont + field + fr_pow surface that
+ // gen_fr_inv_bench_shader uses is pulled in here as dead code. Likewise
+ // bigint_by is rendered with the standard `num_words` so its
+ // by_from/to_bigint helpers stay syntactically valid even though unused.
+ public gen_divsteps_bench_shader(workgroup_size: number): string {
+ // Substitutions shared by the field and fr_pow templates.
+ const field_ctx = {
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ };
+ // fr_pow additionally needs R^3 and p - 2.
+ const fr_pow_ctx = {
+ ...field_ctx,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ };
+ const sections = [
+ mustache.render(structs, { num_words: this.num_words }),
+ mustache.render(bigint_funcs, {}),
+ this.mont_product_src,
+ mustache.render(field_funcs, field_ctx),
+ // get_r is emitted ahead of fr_pow, whose body calls `get_r()`.
+ this.renderGetRFn(),
+ mustache.render(fr_pow_funcs, fr_pow_ctx),
+ mustache.render(bigint_by_funcs, { num_words: this.num_words }),
+ this.renderByInverseFuncs(),
+ mustache.render(divsteps_bench_shader, { workgroup_size }),
+ ];
+ return sections.join('\n');
+ }
+
+ // Assembles the WGSL bundle for the `by_apply_matrix_fg` +
+ // `by_apply_matrix_de` micro-benchmark. Each thread reads one
+ // (Mat, f, g, d, e) record, runs both passes, and writes the updated
+ // (f', g', d', e') as 36 i32 values; the bench-apply-matrix Playwright
+ // driver validates them against the TS `Wasm9x29.applyMatrix` reference.
+ //
+ // Sections, in emission order:
+ // structs, bigint_funcs, Montgomery product, field_funcs, get_r,
+ // fr_pow_funcs, bigint_by, by_inverse, apply_matrix_bench entry.
+ // The Mont + field + fr_pow surface is dead code for this bench; it is
+ // present because `fr_inv_by` lives inside the by_inverse partial and
+ // references those symbols (same reason as gen_divsteps_bench_shader).
+ //
+ // Mustache substitutions on the entry shader:
+ // - workgroup_size
+ // - p_limbs_by: BigIntBY initializer for p
+ // - p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58
+ // - p_inv_by_hi: high (up to 26) bits of P_INV
+ public gen_apply_matrix_bench_shader(workgroup_size: number): string {
+ // Substitutions shared by the field and fr_pow templates.
+ const field_ctx = {
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ };
+ // fr_pow additionally needs R^3 and p - 2.
+ const fr_pow_ctx = {
+ ...field_ctx,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ };
+ const entry_ctx = {
+ workgroup_size,
+ p_limbs_by: this.p_limbs_by_initializer,
+ p_inv_by_lo: this.p_inv_by_lo,
+ p_inv_by_hi: this.p_inv_by_hi,
+ };
+ const sections = [
+ mustache.render(structs, { num_words: this.num_words }),
+ mustache.render(bigint_funcs, {}),
+ this.mont_product_src,
+ mustache.render(field_funcs, field_ctx),
+ // get_r is emitted ahead of fr_pow, whose body calls `get_r()`.
+ this.renderGetRFn(),
+ mustache.render(fr_pow_funcs, fr_pow_ctx),
+ mustache.render(bigint_by_funcs, { num_words: this.num_words }),
+ this.renderByInverseFuncs(),
+ mustache.render(apply_matrix_bench_shader, entry_ctx),
+ ];
+ return sections.join('\n');
+ }
+
+ // Assembles the WGSL bundle for the field-inversion micro-benchmark. Each
+ // thread reads one BN254 base-field value `a` (in Montgomery form), runs
+ // `k` chained inversions, and writes the final value back. Used by the
+ // bench-fr-inv Playwright driver to validate against the host
+ // `Wasm9x29.invert` + Mont-correction reference.
+ //
+ // `variant` is forwarded verbatim to the entry shader as `inv_fn` (the
+ // symbol it calls) and also decides which BY partials are bundled:
+ // - 'fr_inv_by' : bigint_by + by_inverse (Mat, by_divsteps,
+ // by_apply_matrix_*, by_reduce_to_canonical,
+ // fr_inv_by).
+ // - 'fr_inv_by_a' : bigint_by + by_inverse_a (Option A). Only the u64
+ // helpers of bigint_by are needed; the rest is
+ // unused but cheap.
+ // - 'fr_inv' / 'fr_pow_inv' : no BY partials at all, so the bundle stays
+ // small and no compile time is spent on ≈700 lines
+ // of dead BY plumbing.
+ //
+ // Always bundled, in emission order: structs, bigint_funcs, Montgomery
+ // product, field_funcs, fr_pow_funcs, [BY partials], fr_inv_bench entry.
+ // - field_funcs supplies `fr_sub` / `fr_add` / `bigint_halve_k_mod_p`,
+ // which fr_pow's alternate variants reference; every symbol must
+ // resolve or shader compilation fails.
+ // - fr_pow_funcs supplies `get_r_cubed` (used by fr_inv_by for the Mont
+ // correction) as well as `fr_inv` for the legacy variant, so it is
+ // included for every variant.
+ //
+ // Mustache substitutions consumed by `by_inverse` (via renderByInverseFuncs):
+ // - p_limbs_by: BigIntBY initializer for the BN254 base-field modulus
+ // - p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58
+ // - p_inv_by_hi: high (up to 26) bits of P_INV
+ public gen_fr_inv_bench_shader(
+ workgroup_size: number,
+ variant: 'fr_inv_by' | 'fr_inv' | 'fr_inv_by_a' | 'fr_pow_inv' = 'fr_inv_by',
+ ): string {
+ // Substitutions shared by the field and fr_pow templates.
+ const field_ctx = {
+ word_size: this.word_size,
+ num_words: this.num_words,
+ n0: this.n0,
+ p_limbs: this.p_limbs,
+ r_limbs: this.r_limbs,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ };
+ // fr_pow additionally needs R^3 and p - 2.
+ const fr_pow_ctx = {
+ ...field_ctx,
+ r_cubed_limbs: this.r_cubed_limbs,
+ p_minus_2_limbs: this.p_minus_2_limbs,
+ };
+ const structs_part = mustache.render(structs, { num_words: this.num_words });
+ const bigint_part = mustache.render(bigint_funcs, {});
+ const field_part = mustache.render(field_funcs, field_ctx);
+ const fr_pow_part = mustache.render(fr_pow_funcs, fr_pow_ctx);
+ // BY partials: bigint_by first, then the variant-specific driver.
+ // Left empty for the non-BY variants.
+ const by_partials: string[] = [];
+ if (variant === 'fr_inv_by' || variant === 'fr_inv_by_a') {
+ by_partials.push(mustache.render(bigint_by_funcs, { num_words: this.num_words }));
+ by_partials.push(
+ variant === 'fr_inv_by' ? this.renderByInverseFuncs() : this.renderByInverseAFuncs(),
+ );
+ }
+ const by_part = by_partials.join('\n');
+ const entry_part = mustache.render(fr_inv_bench_shader, {
+ workgroup_size,
+ r_limbs: this.r_limbs,
+ inv_fn: variant,
+ });
+ // An empty `by_part` still occupies its own (blank) line in the output.
+ return [
+ structs_part,
+ bigint_part,
+ this.mont_product_src,
+ field_part,
+ fr_pow_part,
+ by_part,
+ entry_part,
+ ].join('\n');
+ }
+
+ // Renders the by_inverse partial. Its BY-specific substitutions are the
+ // BigIntBY initializer for p (`{{{ p_limbs_by }}}`) and the two halves of
+ // the 58-bit p_inv (`{{ p_inv_by_* }}`). Centralised here so the
+ // divsteps / apply_matrix / fr_inv bench assemblies all resolve those
+ // placeholders to valid WGSL the same way.
+ private renderByInverseFuncs(): string {
+ const by_ctx = {
+ p_limbs_by: this.p_limbs_by_initializer,
+ p_inv_by_lo: this.p_inv_by_lo,
+ p_inv_by_hi: this.p_inv_by_hi,
+ };
+ return mustache.render(by_inverse_funcs, by_ctx);
+ }
+
+ // Renders the by_inverse_a partial (Option A: BATCH=26 / NUM_OUTER=29 BY
+ // safegcd on 20 x 13-bit BigInt). Substitutions: `{{ num_words }}` for the
+ // streaming-loop bound and the 26-bit p_inv constant.
+ private renderByInverseAFuncs(): string {
+ const option_a_ctx = {
+ num_words: this.num_words,
+ p_inv_by_a_lo: this.p_inv_by_a_lo,
+ };
+ return mustache.render(by_inverse_a_funcs, option_a_ctx);
+ }
+
+ // Emits a WGSL `get_r()` whose body is the curve-specific R limb
+ // assignments (`this.r_limbs`). Production MSM shaders each define their
+ // own; the divsteps / apply_matrix bench assemblies use this helper so
+ // the definition can be placed ahead of fr_pow_funcs, which calls
+ // `get_r()` from inside fr_pow.
+ private renderGetRFn(): string {
+ const body_lines = [
+ 'fn get_r() -> BigInt {',
+ ' var r: BigInt;',
+ `${this.r_limbs}`,
+ ' return r;',
+ '}',
+ ];
+ return body_lines.join('\n');
+ }
+
+ // Builds the WGSL source for the f32 Montgomery-product micro-benchmark:
+ // the sos3uv3 helper bundle followed by the bench entry point. sos3uv3
+ // (22-bit limbs, per-slot tlo/thi chain-break) is the only f32 variant
+ // wired up; it was the fastest f32 Mont mul in the wider variant sweep
+ // and is kept as a reference point alongside the u32 paths. Any other
+ // `variant` value throws.
+ public gen_field_mul_bench_f32_shader(
+ workgroup_size: number,
+ variant: 'sos3uv3' = 'sos3uv3',
+ ): string {
+ if (variant === 'sos3uv3') {
+ const helper_src = this.gen_montgomery_product_f32_22_sos3uv3_shader();
+ const bench_entry = mustache.render(field_mul_bench_f32_shader, { workgroup_size });
+ return `${helper_src}\n${bench_entry}`;
+ }
+ // Reachable only when an untyped caller bypasses the literal type.
+ throw new Error(`f32 bench variant must be 'sos3uv3', got '${variant}'`);
+ }
+
+ // Renders mont_pro_product_karat_yuval.template.wgsl. The .wgsl file
+ // owns the algorithm structure (chunks → sums → 9 schoolbook
+ // sub-sub-products → inner combines → outer combine → Yuval reduce →
+ // final canonicalize); this method only builds the index arrays and the
+ // r_inv limb constants that the template's mustache sections iterate over.
+ private renderKaratYuvalMont(): string {
+ const N = this.num_words; // 20 — NOTE(review): the 5/10/15/19/39 literals below hard-code N = 20; confirm no other num_words reaches here
+ const WS = this.word_size; // 13
+ const W = 1n << BigInt(WS); // limb radix 2^WS
+
+ const r_inv = modinv(W, this.p); // presumably W^-1 mod p — confirm modinv's argument order
+ const mask = W - 1n;
+ const limbs: number[] = []; // r_inv split into N little-endian WS-bit limbs
+ let v = r_inv;
+ for (let i = 0; i < N; i++) {
+ limbs.push(Number(v & mask));
+ v >>= BigInt(WS);
+ }
+ const r_inv_consts = limbs.map((val, idx) => ({ idx, val }));
+
+ const input_loads: Array<{ name: string; ptr: string; k: number }> = []; // one entry per operand limb: local name, source pointer, limb index
+ const chunks = [ // [name prefix, source pointer, first limb index]; each chunk spans 5 limbs
+ ['x_lo_lo', 'x_ptr', 0],
+ ['x_lo_hi', 'x_ptr', 5],
+ ['x_hi_lo', 'x_ptr', 10],
+ ['x_hi_hi', 'x_ptr', 15],
+ ['y_lo_lo', 'y_ptr', 0],
+ ['y_lo_hi', 'y_ptr', 5],
+ ['y_hi_lo', 'y_ptr', 10],
+ ['y_hi_hi', 'y_ptr', 15],
+ ] as const;
+ for (const [prefix, ptr, base] of chunks) {
+ for (let k = 0; k < 5; k++) {
+ input_loads.push({ name: `${prefix}_${k}`, ptr, k: (base as number) + k });
+ }
+ }
+
+ const sum_lets: Array<{ name: string; lhs: string; rhs: string }> = []; // limb-wise pairings name_k <- (lhs_k, rhs_k); the combining operator lives in the template
+ const sumDefs = [ // [result prefix, lhs prefix, rhs prefix]; the last two rows consume earlier rows' results
+ ['a_lo_sum', 'x_lo_lo', 'x_lo_hi'],
+ ['b_lo_sum', 'y_lo_lo', 'y_lo_hi'],
+ ['a_hi_sum', 'x_hi_lo', 'x_hi_hi'],
+ ['b_hi_sum', 'y_hi_lo', 'y_hi_hi'],
+ ['a_cr_lo', 'x_lo_lo', 'x_hi_lo'],
+ ['a_cr_hi', 'x_lo_hi', 'x_hi_hi'],
+ ['b_cr_lo', 'y_lo_lo', 'y_hi_lo'],
+ ['b_cr_hi', 'y_lo_hi', 'y_hi_hi'],
+ ['a_cr_sum', 'a_cr_lo', 'a_cr_hi'],
+ ['b_cr_sum', 'b_cr_lo', 'b_cr_hi'],
+ ] as const;
+ for (const [name, lhs, rhs] of sumDefs) {
+ for (let k = 0; k < 5; k++) {
+ sum_lets.push({ name: `${name}_${k}`, lhs: `${lhs}_${k}`, rhs: `${rhs}_${k}` });
+ }
+ }
+
+ const schoolbooks = [ // the 9 sub-sub-products: output prefix and the two 5-limb operand prefixes
+ { label: 'PP_lo_LL = x_lo_lo · y_lo_lo', out_prefix: 'pp_lo_ll', a_prefix: 'x_lo_lo', b_prefix: 'y_lo_lo' },
+ { label: 'PP_lo_HH = x_lo_hi · y_lo_hi', out_prefix: 'pp_lo_hh', a_prefix: 'x_lo_hi', b_prefix: 'y_lo_hi' },
+ { label: 'PP_lo_C = a_lo_sum · b_lo_sum', out_prefix: 'pp_lo_c', a_prefix: 'a_lo_sum', b_prefix: 'b_lo_sum' },
+ { label: 'PP_hi_LL = x_hi_lo · y_hi_lo', out_prefix: 'pp_hi_ll', a_prefix: 'x_hi_lo', b_prefix: 'y_hi_lo' },
+ { label: 'PP_hi_HH = x_hi_hi · y_hi_hi', out_prefix: 'pp_hi_hh', a_prefix: 'x_hi_hi', b_prefix: 'y_hi_hi' },
+ { label: 'PP_hi_C = a_hi_sum · b_hi_sum', out_prefix: 'pp_hi_c', a_prefix: 'a_hi_sum', b_prefix: 'b_hi_sum' },
+ { label: 'PP_cr_LL = a_cr_lo · b_cr_lo', out_prefix: 'pp_cr_ll', a_prefix: 'a_cr_lo', b_prefix: 'b_cr_lo' },
+ { label: 'PP_cr_HH = a_cr_hi · b_cr_hi', out_prefix: 'pp_cr_hh', a_prefix: 'a_cr_hi', b_prefix: 'b_cr_hi' },
+ { label: 'PP_cr_C = a_cr_sum · b_cr_sum', out_prefix: 'pp_cr_c', a_prefix: 'a_cr_sum', b_prefix: 'b_cr_sum' },
+ ];
+
+ const inner_combines = [ // each row folds one (ll, hh, c) triple of sub-sub-products into p_lo / p_hi / p_cr
+ { label: 'P_lo from pp_lo_*', out_prefix: 'p_lo', ll_prefix: 'pp_lo_ll', hh_prefix: 'pp_lo_hh', c_prefix: 'pp_lo_c' },
+ { label: 'P_hi from pp_hi_*', out_prefix: 'p_hi', ll_prefix: 'pp_hi_ll', hh_prefix: 'pp_hi_hh', c_prefix: 'pp_hi_c' },
+ { label: 'P_cr from pp_cr_*', out_prefix: 'p_cr', ll_prefix: 'pp_cr_ll', hh_prefix: 'pp_cr_hh', c_prefix: 'pp_cr_c' },
+ ];
+
+ const outer_init: Array<{ slot: number; init_expr: string }> = []; // 40 slots: 0..18 <- p_lo_k, 19 <- 0u, 20..38 <- p_hi_k, 39 <- 0u
+ for (let k = 0; k < 19; k++) outer_init.push({ slot: k, init_expr: `p_lo_${k}` });
+ outer_init.push({ slot: 19, init_expr: '0u' });
+ for (let k = 0; k < 19; k++) outer_init.push({ slot: 20 + k, init_expr: `p_hi_${k}` });
+ outer_init.push({ slot: 39, init_expr: '0u' });
+
+ const outer_cross: Array<{ slot: number; k: number }> = []; // cross-term index k targets slot 10 + k (slots 10..28)
+ for (let k = 0; k < 19; k++) outer_cross.push({ slot: 10 + k, k });
+
+ const yuval_iters: Array<{ i: number; writes: Array<{ slot: number; r_idx: number; first: boolean }> }> = []; // N-1 iterations, each with N writes pairing slot i+1+j with r_inv limb j
+ for (let i = 0; i < N - 1; i++) {
+ const writes = [];
+ for (let j = 0; j < N; j++) {
+ writes.push({ slot: i + 1 + j, r_idx: j, first: j === 0 });
+ }
+ yuval_iters.push({ i, writes });
+ }
+
+ const i_std = N - 1; // index of the one remaining step, which pairs slots with p limbs (p_idx) instead of r_inv limbs
+ const standard_writes: Array<{ slot: number; p_idx: number; first: boolean }> = [];
+ for (let j = 0; j < N; j++) {
+ standard_writes.push({ slot: i_std + j, p_idx: j, first: j === 1 }); // NOTE(review): `first` is flagged at j === 1 here but at j === 0 in yuval_iters — confirm this is intentional
+ }
+
+ const final_drain: Array<{ slot: number }> = []; // slots N..2N-1
+ for (let i = 0; i < N; i++) final_drain.push({ slot: N + i });
+
+ const extract: Array<{ out_k: number; src_slot: number }> = []; // output limb i is taken from slot N + i
+ for (let i = 0; i < N; i++) extract.push({ out_k: i, src_slot: N + i });
+
+ return mustache.render(montgomery_product_karat_yuval_funcs, {
+ num_words: N,
+ word_size: WS,
+ n0: this.n0,
+ mask: this.mask,
+ two_pow_word_size: this.two_pow_word_size,
+ p_inv_mod_2w: this.p_inv_mod_2w,
+ p_limbs: this.p_limbs,
+ r_inv_consts,
+ input_loads,
+ sum_lets,
+ schoolbooks,
+ inner_combines,
+ outer_init,
+ outer_cross,
+ yuval_iters,
+ i_std,
+ standard_writes,
+ final_drain,
+ extract,
+ });
+ }
+
+ // Renders the f32 sos3uv3 Montgomery bundle: the mulhilo_22 helpers, the
+ // f32 bigint helpers, and mont_pro_product_f32_22_sos3uv3.template.wgsl,
+ // joined by newlines. The .wgsl owns the algorithm (separate per-slot
+ // tlo/thi f32 accumulators, no inter-j carry chain, drain at the end of
+ // each outer iteration via bias_split_f32_le4w). This method only
+ // supplies the index arrays for the template's slot-init / inner-pairs /
+ // drain-cols sections plus the n0 constants.
+ public gen_montgomery_product_f32_22_sos3uv3_shader(): string {
+ const limb_count = this.num_limbs_f32_22;
+ const W_INV = 2.384185791015625e-7; // 2^-22
+ const n0_times_w_inv = Number(this.n0_f32_22) * W_INV;
+
+ // Accumulator initialisers, emitted as (tlo_k, thi_k) pairs per slot.
+ // Outer iteration 0: tlo0 = init_slot0, every other accumulator 0.0.
+ // Outer iteration i >= 1: tlo0 = init_slot0 + s1, tlo_k = s_{k+1} for
+ // the middle slots, tlo_{N-1} = 0.0, and every thi_k = 0.0.
+ const zero_thi = (k: number) => ({ name: `thi${k}`, init_expr: '0.0' });
+ const slot_inits_i0: Array<{ name: string; init_expr: string }> = [];
+ const slot_inits_general: Array<{ name: string; init_expr: string }> = [];
+ for (let k = 0; k < limb_count; k++) {
+ const tlo_name = `tlo${k}`;
+ slot_inits_i0.push(
+ { name: tlo_name, init_expr: k === 0 ? 'init_slot0' : '0.0' },
+ zero_thi(k),
+ );
+ // The k === 0 case is tested first, matching a single-limb layout.
+ const general_tlo =
+ k === 0 ? 'init_slot0 + s1' : k === limb_count - 1 ? '0.0' : `s${k + 1}`;
+ slot_inits_general.push({ name: tlo_name, init_expr: general_tlo }, zero_thi(k));
+ }
+
+ // Inner-j pairs for j = 1..N-1: tlo[j-1] += lo_sum, thi[j] += hi_sum.
+ const inner_pairs = Array.from({ length: limb_count - 1 }, (_, idx) => ({
+ j: idx + 1,
+ km1: idx,
+ k: idx + 1,
+ }));
+
+ // Drain columns k = 0..N-1.
+ const drain_cols: Array<{ k: number }> = [];
+ for (let k = 0; k < limb_count; k++) drain_cols.push({ k });
+
+ // One context feeds both templates.
+ const ctx = {
+ num_limbs: limb_count,
+ n0: this.n0_f32_22.toString() + '.0',
+ n0_scaled: String(n0_times_w_inv),
+ p_limbs_f32: this.p_limbs_f32_22_str,
+ slot_inits_i0,
+ slot_inits_general,
+ inner_pairs,
+ drain_cols,
+ };
+ const bigint_part = mustache.render(bigint_f32_funcs, ctx);
+ const mont_part = mustache.render(montgomery_product_f32_22_sos3uv3_funcs, ctx);
+ return `${mulhilo_22_funcs}\n${bigint_part}\n${mont_part}`;
+ }
}
diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts
index 489081f278d3..c52165b49d5a 100644
--- a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts
+++ b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts
@@ -211,6 +211,35 @@ export const gen_wgsl_limbs_code = (
return r;
};
+// f32 counterpart of gen_wgsl_limbs_code: each limb is written as a `.0`
+// float literal rather than a `u`-suffixed integer. Used by the
+// FMA-Montgomery (23-bit f32 limb) path. The little-endian limb split is
+// done locally because `to_words_le` returns a Uint16Array, which
+// truncates any word_size > 16.
+//
+// Returns `num_words` lines of the form ` <var_name>.limbs[i] = <limb>.0;`,
+// each terminated by a newline, least-significant limb first.
+export const gen_wgsl_limbs_code_f32 = (
+ val: bigint,
+ var_name: string,
+ num_words: number,
+ word_size: number,
+): string => {
+ const limb_bits = BigInt(word_size);
+ const limb_mask = (BigInt(1) << limb_bits) - BigInt(1);
+ const lines: string[] = [];
+ for (let idx = 0, rest = val; idx < num_words; idx++, rest >>= limb_bits) {
+ lines.push(` ${var_name}.limbs[${idx}] = ${Number(rest & limb_mask)}.0;\n`);
+ }
+ return lines.join("");
+};
+
+// Emits the f32 limb assignments for the modulus, always targeting the
+// WGSL variable named `p`.
+export const gen_p_limbs_f32 = (
+ p: bigint,
+ num_words: number,
+ word_size: number,
+): string => gen_wgsl_limbs_code_f32(p, "p", num_words, word_size);
+
export const gen_barrett_domb_m_limbs = (
m: bigint,
num_words: number,
@@ -227,6 +256,110 @@ export const gen_p_limbs = (
return gen_wgsl_limbs_code(p, "p", num_words, word_size);
};
+// Emit a comma-separated `array<i32, 9>` initializer list for the 9 × 29-bit
+// limbs of a non-negative bigint. Used to inject the BN254 base-field modulus
+// `p` into the apply_matrix bench shader as a literal `BigIntBY` initializer.
+//
+// Pre: p in [0, 2^256); anything outside that range throws instead of being
+// silently truncated into a wrong modulus literal.
+// Post: e.g. "1, 2, 3, 4, 5, 6, 7, 8, 9" (least-significant limb first),
+// suitable for substitution inside
+// `BigIntBY(array<i32, 9>({{{ p_limbs_by }}}))`.
+export const gen_p_limbs_by_initializer = (p: bigint): string => {
+ if (p < 0n) {
+ throw new Error("gen_p_limbs_by_initializer: negative input");
+ }
+ if (p >> 256n !== 0n) {
+ throw new Error("gen_p_limbs_by_initializer: input must be < 2^256");
+ }
+ const BY_NUM_LIMBS = 9;
+ const BY_LIMB_BITS = 29n;
+ const BY_LIMB_MASK = (1n << BY_LIMB_BITS) - 1n;
+ // Plain little-endian base-2^29 decomposition. Eight full limbs cover
+ // bits 0..231, so for p < 2^256 the ninth limb holds the remaining
+ // 24 bits and every limb is a non-negative value below 2^29 — i.e. a
+ // valid positive i32 literal.
+ const limbs: string[] = [];
+ let rest = p;
+ for (let i = 0; i < BY_NUM_LIMBS; i++) {
+ limbs.push((rest & BY_LIMB_MASK).toString());
+ rest >>= BY_LIMB_BITS;
+ }
+ return limbs.join(", ");
+};
+
+// Computes p_inv = p^(-1) mod 2^58, the constant used by the BY safegcd
+// 2-adic correction step. WASM's `Wasm9x29::apply_matrix` takes it as one
+// u64; WGSL has no native u64, so it is returned as two u32 halves which
+// the caller injects through the by_apply_matrix_de Mustache substitutions
+// `{{ p_inv_by_lo }}` / `{{ p_inv_by_hi }}`.
+//
+// Pre: p positive and odd (throws otherwise).
+// Returns { lo, hi } where (lo + hi * 2^32) === p_inv and 0 <= lo < 2^32,
+// 0 <= hi < 2^26.
+export const compute_by_p_inv_split = (p: bigint): { lo: number; hi: number } => {
+ if (p < 1n) {
+ throw new Error("compute_by_p_inv_split: p must be positive");
+ }
+ if ((p & 1n) === 0n) {
+ throw new Error("compute_by_p_inv_split: p must be odd");
+ }
+ const TARGET_BITS = 58n;
+ // Keep only the low `bits` bits; also maps negative bigints to their
+ // non-negative residue mod 2^bits.
+ const low_bits = (value: bigint, bits: bigint): bigint => value & ((1n << bits) - 1n);
+ // Hensel / Newton lifting: if p * inv ≡ 1 (mod 2^k) then
+ // inv * (2 - p * inv) is an inverse mod 2^(2k). inv = 1 is valid mod 2
+ // because p is odd, so precision runs 2, 4, 8, ... until it covers
+ // TARGET_BITS.
+ let inverse = 1n;
+ let precision = 1n;
+ do {
+ precision <<= 1n;
+ inverse = low_bits(inverse * (2n - p * inverse), precision);
+ } while (precision < TARGET_BITS);
+ inverse = low_bits(inverse, TARGET_BITS);
+ if (low_bits(p * inverse, TARGET_BITS) !== 1n) {
+ throw new Error("compute_by_p_inv_split: Hensel lift failed sanity check");
+ }
+ return {
+ lo: Number(low_bits(inverse, 32n)),
+ hi: Number(low_bits(inverse >> 32n, 32n)),
+ };
+};
+
+// Computes p_inv = p^(-1) mod 2^26 for the Option A BY safegcd inverse
+// driver (BATCH=26 / NUM_OUTER=29 on 20 x 13-bit BigInt). Same lifting as
+// compute_by_p_inv_split, but at 26-bit precision so the result fits in a
+// single u32 WGSL constant.
+//
+// Pre: p odd, positive (throws otherwise).
+// Post: the returned value `out` is below 2^26 and satisfies
+// (p * out) & ((1<<26) - 1) === 1.
+export const compute_by_p_inv_a = (p: bigint): number => {
+ if (p < 1n) {
+ throw new Error("compute_by_p_inv_a: p must be positive");
+ }
+ if ((p & 1n) === 0n) {
+ throw new Error("compute_by_p_inv_a: p must be odd");
+ }
+ const TARGET_BITS = 26n;
+ // Keep only the low `bits` bits; also maps negative bigints to their
+ // non-negative residue mod 2^bits.
+ const low_bits = (value: bigint, bits: bigint): bigint => value & ((1n << bits) - 1n);
+ // Newton doubling from the trivial inverse mod 2 (p is odd): precision
+ // runs 2, 4, 8, 16, 32, then the result is cut back to 26 bits.
+ let inverse = 1n;
+ let precision = 1n;
+ do {
+ precision <<= 1n;
+ inverse = low_bits(inverse * (2n - p * inverse), precision);
+ } while (precision < TARGET_BITS);
+ inverse = low_bits(inverse, TARGET_BITS);
+ if (low_bits(p * inverse, TARGET_BITS) !== 1n) {
+ throw new Error("compute_by_p_inv_a: Hensel lift failed sanity check");
+ }
+ return Number(inverse);
+};
+
export const gen_r_limbs = (
r: bigint,
num_words: number,
diff --git a/barretenberg/ts/src/msm_webgpu/msm.ts b/barretenberg/ts/src/msm_webgpu/msm.ts
index bb23cdf002c3..f47d43fa4f55 100644
--- a/barretenberg/ts/src/msm_webgpu/msm.ts
+++ b/barretenberg/ts/src/msm_webgpu/msm.ts
@@ -948,7 +948,14 @@ const compute_curve_msm = async (
// loop's wraparound bug AND keeps the simpler one-iteration dispatch
// that the warm-cached bind-group cache already assumes.
const num_subtasks_per_bpr_1 = num_subtasks;
- const b_num_x_workgroups = num_subtasks_per_bpr_1;
+ // BPR_WINDOWS_PER_BATCH: each workgroup handles WPB consecutive
+ // subtasks via an in-kernel const-bounded loop. WPB=1 = legacy
+ // dispatch shape (X = num_subtasks). WPB > 1 trades thread count for
+ // per-thread work — see bpr_bn254.template.wgsl stage_1 comment for
+ // the register-pressure tradeoff. Override per call via the
+ // bpr_inner_loop knob downstream if needed.
+ const BPR_WINDOWS_PER_BATCH = 1;
+ const b_num_x_workgroups = Math.ceil(num_subtasks_per_bpr_1 / BPR_WINDOWS_PER_BATCH);
const b_workgroup_size = 256;
// Output of the parallel bucket points reduction (BPR) shader
@@ -1002,6 +1009,7 @@ const compute_curve_msm = async (
/* mixed_safe_buckets */ bpr_mixed_safe,
/* bench_flags */ bpr_bench_flags,
/* safe_first_add_no_collision */ bpr_safe_first,
+ /* windows_per_batch */ BPR_WINDOWS_PER_BATCH,
);
// Compact key derived from the bench flags. Forwarded into the bpr_1
// and bpr_2 pipeline cache keys so each variant compiles its own
@@ -1041,6 +1049,8 @@ const compute_curve_msm = async (
bpr_mixed_safe,
bpr_bench_key,
input_size,
+ num_subtasks,
+ BPR_WINDOWS_PER_BATCH,
);
}
@@ -1070,7 +1080,7 @@ const compute_curve_msm = async (
// Bucket points reduction (BPR) - stage 2
// Same as bpr_1: dispatch all T subtasks in one outer iter.
const num_subtasks_per_bpr_2 = num_subtasks;
- const b_2_num_x_workgroups = num_subtasks_per_bpr_2;
+ const b_2_num_x_workgroups = Math.ceil(num_subtasks_per_bpr_2 / BPR_WINDOWS_PER_BATCH);
for (let subtask_idx = 0; subtask_idx < num_subtasks; subtask_idx += num_subtasks_per_bpr_2) {
await bpr_2(
bpr_shader,
@@ -1096,6 +1106,8 @@ const compute_curve_msm = async (
bpr_mixed_safe,
bpr_bench_key,
input_size,
+ num_subtasks,
+ BPR_WINDOWS_PER_BATCH,
);
}
cpu_timer.phaseFrom('bpr_host_total', 'bpr_host_begin');
@@ -2359,6 +2371,11 @@ const bpr_1 = async (
// first call's buffers, BPR would read stale data, and downstream
// Horner would read an unwritten g_points buffer and return identity.
input_size_key = 0,
+ // Total subtask count + WPB. Passed to the shader as params[3] so
+ // multi-window dispatches with WPB > 1 can skip out-of-range subtasks
+ // in the tail batch (when num_subtasks is not a multiple of WPB).
+ num_subtasks_total = 0,
+ windows_per_batch = 1,
) => {
let original_bucket_sum_x_sb;
let original_bucket_sum_y_sb;
@@ -2375,11 +2392,13 @@ const bpr_1 = async (
commandEncoder.copyBufferToBuffer(bucket_sum_z_sb, 0, original_bucket_sum_z_sb, 0, bucket_sum_z_sb.size);
}
- // Parameters as a uniform buffer. Contents (subtask_idx, num_columns,
- // num_x_workgroups) are constant per (subtask_idx, layout) tuple, so
- // we cache them on the context when one is provided.
- const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups]);
- const bpr1_key = `${curveId ?? 'x'}:bpr1:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`;
+ // Parameters as a uniform buffer. Layout: (subtask_idx_base,
+ // num_columns, num_subtasks_per_bpr, num_subtasks_total). The 4th
+ // slot lets the WPB-aware shader skip out-of-range subtasks in tail
+ // batches. Constant per (subtask_idx, layout, WPB) tuple, so we cache
+ // it on the context when one is provided.
+ const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups, num_subtasks_total]);
+ const bpr1_key = `${curveId ?? 'x'}:bpr1:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`;
let params_ub: GPUBuffer;
if (context !== undefined && !debug && !debug_capture_sb) {
const got = context.acquirePersistentUniform(`${bpr1_key}:params_ub`, params_bytes.length);
@@ -2406,7 +2425,7 @@ const bpr_1 = async (
// Debug-capture variant compiles a different shader (Mustache flag),
// so keys must not collide with the non-debug one.
context,
- `${curveId ?? 'x'}:bpr1:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`,
+ `${curveId ?? 'x'}:bpr1:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`,
);
cpu_timer?.accumulate('compile_bpr1_shader', performance.now() - _b1_compile_t0);
@@ -2554,10 +2573,15 @@ const bpr_2 = async (
// group across MSM calls with different `input_size`, whose
// workspace buffers are equally sized but distinct GPUBuffer objects.
input_size_key = 0,
+ // See bpr_1: total subtasks (4th param slot) + WPB (cache key).
+ num_subtasks_total = 0,
+ windows_per_batch = 1,
) => {
// Parameters as a uniform buffer (cached on context when not debug).
- const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups]);
- const bpr2_key = `${curveId ?? 'x'}:bpr2:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`;
+ // Layout: (subtask_idx_base, num_columns, num_subtasks_per_bpr,
+ // num_subtasks_total). See bpr_1 for the layout rationale.
+ const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups, num_subtasks_total]);
+ const bpr2_key = `${curveId ?? 'x'}:bpr2:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`;
let params_ub: GPUBuffer;
if (context !== undefined && !debug && !debug_capture_sb) {
const got = context.acquirePersistentUniform(`${bpr2_key}:params_ub`, params_bytes.length);
@@ -2578,7 +2602,7 @@ const bpr_2 = async (
shaderCode,
'stage_2',
context,
- `${curveId ?? 'x'}:bpr2:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`,
+ `${curveId ?? 'x'}:bpr2:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`,
);
cpu_timer?.accumulate('compile_bpr2_shader', performance.now() - _b2_compile_t0);
diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
index 68ffad6d7bd9..2ef6db054f7d 100644
--- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
+++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts
@@ -1,6 +1,6 @@
// AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT.
// Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate.
-// 29 shader sources inlined.
+// 42 shader sources inlined.
/* eslint-disable */
@@ -506,6 +506,415 @@ fn bigint_signed_axby_modp_halve_k(
}
`;
+export const bigint_by = `// WGSL bigint helpers for the Bernstein-Yang safegcd inversion port
+// (sub-step 1.2 of the WebGPU MSM rewrite plan). The semantics here MUST
+// match the TS reference at cuzk/bernstein_yang.ts exactly so divsteps /
+// apply_matrix can be transliterated against this file in sub-step 1.3.
+//
+// REPRESENTATIONS
+// - BigIntBY: 9 limbs of signed-29-bit chunks (i32). Top limb carries
+// sign; lower limbs in [0, 2^29) post-normalise. Matches Wasm9x29.
+// - u64 as vec2: value = .x + (.y << 32). All u64 ops use this.
+// - i64 as vec2: value = u32(.x) + (i32(.y) << 32). The low half
+// carries the raw bit pattern (treated as unsigned), the high half
+// is signed. Used for matrix entries u/v/q/r which grow to ±2^58.
+//
+// LOOP BOUND DISCIPLINE
+// Every \`for\` loop in this file uses a const upper bound (BY_NUM_LIMBS).
+// The WGSL bounded-loop audit is \`grep 'for ' ...\`, run after Mustache
+// render — every match has \`< BY_NUM_LIMBS\` or \`< {{ num_words }}\`.
+
+const BY_NUM_LIMBS: u32 = 9u;
+const BY_LIMB_BITS: u32 = 29u;
+const BY_LIMB_MASK: u32 = (1u << 29u) - 1u;
+const BY_BATCH: u32 = 58u;
+const BY_NUM_OUTER: u32 = 13u;
+const BY_REDUCE_INTERVAL: u32 = 4u;
+const BY_RTC_MAX_ITERS: u32 = 36u;
+
+// 29-bit signed limb wrapped as i32. Lower limbs canonical in [0, 2^29);
+// top limb is the signed extension carrier.
+struct BigIntBY {
+ l: array,
+}
+
+// ============================================================
+// u64 arithmetic as vec2 (x = .x + (.y << 32))
+// ============================================================
+
+// u64 add. Pre: any a, b. Post: low 64 bits of a + b.
+fn u64_add(a: vec2, b: vec2) -> vec2 {
+ let lo = a.x + b.x;
+ let carry = select(0u, 1u, lo < a.x);
+ let hi = a.y + b.y + carry;
+ return vec2(lo, hi);
+}
+
+// u64 sub. Pre: any a, b. Post: low 64 bits of (a - b) mod 2^64.
+fn u64_sub(a: vec2, b: vec2) -> vec2 {
+ let lo = a.x - b.x;
+ let borrow = select(0u, 1u, a.x < b.x);
+ let hi = a.y - b.y - borrow;
+ return vec2(lo, hi);
+}
+
+// u64 logical right shift by 1. Pre: any x. Post: x >> 1 unsigned.
+fn u64_shr1(x: vec2) -> vec2 {
+ let lo = (x.x >> 1u) | ((x.y & 1u) << 31u);
+ let hi = x.y >> 1u;
+ return vec2(lo, hi);
+}
+
+// Bit 0 of a u64. Pre: any x. Post: 0 or 1.
+fn u64_low_bit(x: vec2) -> u32 {
+ return x.x & 1u;
+}
+
+// u64 two's-complement negate. Pre: any x. Post: (-x) mod 2^64.
+fn u64_neg(x: vec2) -> vec2 {
+ let nx = vec2(~x.x, ~x.y);
+ return u64_add(nx, vec2(1u, 0u));
+}
+
+// u64 bitwise AND. Pre: any a, mask. Post: a & mask, limbwise.
+fn u64_and(a: vec2, mask: vec2) -> vec2 {
+ return vec2(a.x & mask.x, a.y & mask.y);
+}
+
+// ============================================================
+// i64 signed arithmetic as vec2 (lo bits, hi signed)
+// value = u32(.x) + (i32(.y) << 32) treated as signed two's-complement.
+// ============================================================
+
+// i64 add. Pre: |a|, |b| < 2^63. Post: low 64 bits of a + b, signed.
+fn i64_add_pair(a: vec2, b: vec2) -> vec2 {
+ let au = u32(a.x);
+ let bu = u32(b.x);
+ let lo = au + bu;
+ let carry = select(0u, 1u, lo < au);
+ let hi = a.y + b.y + i32(carry);
+ return vec2(i32(lo), hi);
+}
+
+// i64 sub. Pre: |a|, |b| < 2^63. Post: low 64 bits of a - b, signed.
+fn i64_sub_pair(a: vec2, b: vec2) -> vec2 {
+ let au = u32(a.x);
+ let bu = u32(b.x);
+ let lo = au - bu;
+ let borrow = select(0u, 1u, au < bu);
+ let hi = a.y - b.y - i32(borrow);
+ return vec2(i32(lo), hi);
+}
+
+// i64 left shift by 1. Pre: a in (-2^62, 2^62). Post: a << 1, signed.
+fn i64_shl1_pair(a: vec2) -> vec2 {
+ let au = u32(a.x);
+ let lo = au << 1u;
+ let bit31 = au >> 31u;
+ // Arithmetic shift left of hi half: bring bottom carry bit up.
+ let hi_u = (u32(a.y) << 1u) | bit31;
+ return vec2(i32(lo), i32(hi_u));
+}
+
+// i64 two's-complement negate. Pre: any a (low 64 bits well-defined).
+// Post: (-a) mod 2^64 as a signed pair.
+fn i64_neg_pair(a: vec2) -> vec2 {
+ let nlo = ~u32(a.x);
+ let nhi = ~u32(a.y);
+ let lo = nlo + 1u;
+ let carry = select(0u, 1u, lo < nlo);
+ let hi_u = nhi + carry;
+ return vec2(i32(lo), i32(hi_u));
+}
+
+// ============================================================
+// signed_mul_split: signed 29 × 31-bit product, split into 29-bit chunks.
+//
+// Pre: |a| <= 2^29 (one BY limb), |b| <= 2^31 - 1 (i32 matrix-entry low).
+// Post: returns (lo29, hi) with a*b = lo29 + (hi * 2^29),
+// 0 <= lo29 < 2^29; |hi| <= 2^31 (fits in i32 except at the single
+// corner a == -2^29, b == -2^31 which the caller does not pass).
+//
+// Partial-product width split (the plan-mandated proof sketch):
+// Split each operand as v = v_lo + v_hi * 2^15 where v_lo is sign-
+// extended low 15 bits, i.e. v_lo in [-2^14, 2^14). Then v_hi = (v -
+// v_lo) >> 15.
+// |a_lo|, |b_lo| <= 2^14
+// |a_hi| <= ceil(2^29 / 2^15) = 2^14
+// |b_hi| <= ceil(2^31 / 2^15) = 2^16
+// Four signed partial products:
+// pll = a_lo * b_lo, |pll| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓]
+// plh = a_lo * b_hi, |plh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓]
+// phl = a_hi * b_lo, |phl| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓]
+// phh = a_hi * b_hi, |phh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓]
+// Middle sum:
+// mid = plh + phl, |mid| <= 2^30 + 2^28 < 2^31 [fits i32]
+// Every partial fits in signed i32, satisfying the plan's audit. The
+// subsequent combination uses u32 wrap arithmetic to assemble the low
+// 64 bits of the signed product without further overflow concerns.
+fn signed_mul_split(a: i32, b: i32) -> vec2 {
+ // Sign-extend the low 15 bits of each operand: (v << 17) >> 17 fills
+ // bits 15..31 with the sign of bit 14. WGSL i32 shifts are well-defined
+ // (left shift discards high bits; right shift on i32 is arithmetic).
+ let a_lo: i32 = (a << 17u) >> 17u;
+ let a_hi: i32 = (a - a_lo) >> 15u;
+ let b_lo: i32 = (b << 17u) >> 17u;
+ let b_hi: i32 = (b - b_lo) >> 15u;
+
+ let pll: i32 = a_lo * b_lo;
+ let plh: i32 = a_lo * b_hi;
+ let phl: i32 = a_hi * b_lo;
+ let phh: i32 = a_hi * b_hi;
+ let mid: i32 = plh + phl;
+
+ // total = pll + mid * 2^15 + phh * 2^30, computed as the low 64 bits
+ // of an i64 via u32 wrap arithmetic.
+ //
+ // Each i32 piece extended to i64 (lo, hi):
+ // pll: lo = u32(pll), hi = u32(pll >> 31) (sign mask)
+ // mid << 15: lo = u32(mid) << 15, hi = u32(mid >> 17)
+ // phh << 30: lo = u32(phh) << 30, hi = u32(phh >> 2)
+ // The >> on signed pieces is arithmetic so the high half carries the
+ // correct sign extension.
+ let pll_lo: u32 = u32(pll);
+ let pll_hi: u32 = u32(pll >> 31u);
+ let mid_lo: u32 = u32(mid) << 15u;
+ let mid_hi: u32 = u32(mid >> 17u);
+ let phh_lo: u32 = u32(phh) << 30u;
+ let phh_hi: u32 = u32(phh >> 2u);
+
+ // Sum the three i64 values with carry chains.
+ let s1_lo: u32 = pll_lo + mid_lo;
+ let c1: u32 = select(0u, 1u, s1_lo < pll_lo);
+ let s1_hi: u32 = pll_hi + mid_hi + c1;
+
+ let sum_lo: u32 = s1_lo + phh_lo;
+ let c2: u32 = select(0u, 1u, sum_lo < s1_lo);
+ let sum_hi: u32 = s1_hi + phh_hi + c2;
+
+ // Re-split into lo29 (low 29 bits) and hi = ARS by 29 (signed).
+ let lo29: u32 = sum_lo & 0x1FFFFFFFu;
+ // hi bits 0..2 from sum_lo[29..31]; bits 3..31 from sum_hi[0..28].
+ // sum_hi[29..31] are sign-extension bits under the precondition
+ // |a*b| <= 2^60 and are absorbed correctly when bitcasting hi_u to i32.
+ let hi_u: u32 = (sum_lo >> 29u) | (sum_hi << 3u);
+ return vec2(i32(lo29), i32(hi_u));
+}
+
+// by_accumulate: acc + m_lo * x_limb as a 58-bit signed two-limb pair.
+//
+// Pre: |acc.x| < 2^29 (canonical low half), |acc.y| < 2^29 + 2^k for
+// some small k (caller bounded), |m_lo| <= 2^29, |x_limb| <= 2^29.
+// Post: returns (lo29, hi) with acc + m_lo * x_limb = lo29 + hi * 2^29,
+// 0 <= lo29 < 2^29 (canonical low limb). hi grows by at most one
+// carry bit per call.
+//
+// This is the per-limb cross-product helper used by apply_matrix. The
+// product |m_lo * x_limb| <= 2^58, well within signed_mul_split's bounds.
+fn by_accumulate(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 {
+ let prod = signed_mul_split(m_lo, x_limb);
+ // Add (prod.x, prod.y) to (acc.x, acc.y) treating both halves as signed
+ // 32-bit lanes. acc.x in [0, 2^29), prod.x in [0, 2^29), so the low-half
+ // sum is in [0, 2^30) — fits i32; its bit 29 is carried into hi below.
+ let lo_sum = acc.x + prod.x;
+ let lo29 = lo_sum & i32(BY_LIMB_MASK);
+ // The high half absorbs prod.y plus any bits of lo_sum above bit 28.
+ // WGSL requires the rhs of \`>>\` to be u32 (the lhs stays i32 — i32 shift
+ // is arithmetic, preserving sign).
+ let lo_overflow = lo_sum >> BY_LIMB_BITS;
+ let hi = acc.y + prod.y + lo_overflow;
+ return vec2(lo29, hi);
+}
+
+// by_low_u64_lohi: low 64 bits of a BigIntBY as a vec2 (.x = low 32,
+// .y = high 32). Mirrors \`Wasm9x29::low_64()\` and the TS \`low64\`:
+// bits 0..28 = x.l[0] & LIMB_MASK
+// bits 29..57 = x.l[1] & LIMB_MASK
+// bits 58..63 = x.l[2] & 0x3F
+// All inputs treated as unsigned u32 bit patterns. Negative limbs
+// reinterpreted via \`u32(...)\` — the caller is responsible for ensuring the
+// lower three limbs are canonical (in [0, 2^29) etc.), which is the case in
+// the BY driver after each by_normalise.
+fn by_low_u64_lohi(x: BigIntBY) -> vec2 {
+ let l0: u32 = u32(x.l[0]) & BY_LIMB_MASK;
+ let l1: u32 = u32(x.l[1]) & BY_LIMB_MASK;
+ let l2: u32 = u32(x.l[2]) & 0x3Fu;
+ // low 32 bits: l0 in bits 0..28, low 3 bits of l1 at 29..31.
+ let lo32: u32 = l0 | ((l1 & 0x7u) << 29u);
+ // high 32 bits: high 26 bits of l1 at 0..25, l2 (6 bits) at 26..31.
+ let hi32: u32 = (l1 >> 3u) | (l2 << 26u);
+ return vec2(lo32, hi32);
+}
+
+// by_normalise: carry-propagate so each limb i in [0, N-1) ends up in
+// [0, 2^29) and the top limb absorbs the final signed carry.
+//
+// Pre: any signed i32 limb values such that limb + carry (carry in [-4, 3]) does not wrap i32.
+// Post: x.l[0..N-2] in [0, 2^29); x.l[N-1] is the signed extension.
+fn by_normalise(x: ptr) {
+ var c: i32 = 0;
+ for (var i: u32 = 0u; i < BY_NUM_LIMBS - 1u; i = i + 1u) {
+ let v = (*x).l[i] + c;
+ (*x).l[i] = v & i32(BY_LIMB_MASK);
+ // Arithmetic shift on i32 propagates sign; matches BigInt(v) >> 29n in TS.
+ // WGSL requires the rhs of \`>>\` to be u32 (the lhs i32 then gets the
+ // arithmetic shift semantics we want).
+ c = v >> BY_LIMB_BITS;
+ }
+ (*x).l[BY_NUM_LIMBS - 1u] = (*x).l[BY_NUM_LIMBS - 1u] + c;
+}
+
+// ============================================================
+// Conversion between 20×13-bit BigInt and 9×29-bit BigIntBY.
+//
+// The conversion is a bit-rewrite: read 29 source bits at offset
+// 29*k into limb k. Each 29-bit window spans up to 4 consecutive
+// source limbs (worst case bit-offset 12, covers 27 bits in 3 limbs;
+// needs 2 more from a 4th).
+//
+// Pre (by_from_bigint): x in [0, 2^260) represented in 20 × 13-bit limbs.
+// Post: same value as x in 9 canonical 29-bit limbs (< 2^256 only if x was).
+// Pre (by_to_bigint): x canonical (lower limbs in [0, 2^29), top non-neg).
+// Post: result in [0, 2^260), 20 × 13-bit unsigned limbs.
+// ============================================================
+
+const BY_SRC_LIMB_BITS: u32 = 13u;
+const BY_SRC_LIMB_MASK: u32 = (1u << 13u) - 1u;
+
+// Read a 29-bit window from the 20×13-bit BigInt starting at bit \`bit_lo\`.
+// The window spans up to 4 consecutive source limbs (ceil((29 + 12) / 13) = 4).
+fn by_read_29bit_window(x: BigInt, bit_lo: u32) -> u32 {
+ let base_idx = bit_lo / BY_SRC_LIMB_BITS; // first source limb touched
+ let in_limb_bit = bit_lo % BY_SRC_LIMB_BITS; // bit offset within base limb
+ // Up to four consecutive limbs (out-of-range reads return 0).
+ var v0: u32 = 0u;
+ var v1: u32 = 0u;
+ var v2: u32 = 0u;
+ var v3: u32 = 0u;
+ if (base_idx < {{ num_words }}u) { v0 = x.limbs[base_idx]; }
+ if (base_idx + 1u < {{ num_words }}u) { v1 = x.limbs[base_idx + 1u]; }
+ if (base_idx + 2u < {{ num_words }}u) { v2 = x.limbs[base_idx + 2u]; }
+ if (base_idx + 3u < {{ num_words }}u) { v3 = x.limbs[base_idx + 3u]; }
+ // Stitch: contributions are v0 >> in_limb_bit, v1..v3 at increasing offsets.
+ // The v3 shift can reach 13*3 - 0 = 39, which exceeds 31. WGSL does NOT
+ // zero such shifts: a runtime shift amount is taken modulo the bit width.
+ // v3 is only needed when in_limb_bit > 10, where the shift is
+ // 39 - in_limb_bit ∈ [27, 28]; for in_limb_bit < 11 the window is fully
+ // covered by v0..v2. The shift3 < 32 guard below forces s3 = 0 otherwise.
+ let s0 = v0 >> in_limb_bit;
+ let s1 = v1 << (BY_SRC_LIMB_BITS - in_limb_bit);
+ let s2 = v2 << ((BY_SRC_LIMB_BITS * 2u) - in_limb_bit);
+ let shift3: u32 = (BY_SRC_LIMB_BITS * 3u) - in_limb_bit;
+ // shift3 is in [27, 39]; for shift3 >= 32 we explicitly produce 0 to
+ // avoid WGSL's undefined-behavior region for shifts >= bit-width.
+ var s3: u32 = 0u;
+ if (shift3 < 32u) {
+ s3 = v3 << shift3;
+ }
+ return (s0 | s1 | s2 | s3) & BY_LIMB_MASK;
+}
+
+fn by_from_bigint(x: BigInt) -> BigIntBY {
+ var r: BigIntBY;
+ for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) {
+ r.l[i] = i32(by_read_29bit_window(x, i * BY_LIMB_BITS));
+ }
+ return r;
+}
+
+// Read a 13-bit window from the 9×29-bit BigIntBY starting at bit \`bit_lo\`.
+// Reads up to two consecutive 29-bit BY limbs and emits the right 13-bit slice.
+fn by_read_13bit_window(x: BigIntBY, bit_lo: u32) -> u32 {
+ let base_idx = bit_lo / BY_LIMB_BITS;
+ let in_limb_bit = bit_lo % BY_LIMB_BITS;
+ // Up to two BY limbs cover any 13-bit window (since 29 > 13).
+ var v0: u32 = 0u;
+ var v1: u32 = 0u;
+ if (base_idx < BY_NUM_LIMBS) { v0 = u32(x.l[base_idx]); }
+ if (base_idx + 1u < BY_NUM_LIMBS) { v1 = u32(x.l[base_idx + 1u]); }
+ let s0 = v0 >> in_limb_bit;
+ // Shift amount for v1 is (29 - in_limb_bit). In WGSL shifts of >= 32 are
+ // undefined; (29 - in_limb_bit) is in [1, 29] so always < 32.
+ let shift_up = BY_LIMB_BITS - in_limb_bit;
+ let s1 = v1 << shift_up;
+ return (s0 | s1) & BY_SRC_LIMB_MASK;
+}
+
+fn by_to_bigint(x: BigIntBY) -> BigInt {
+ var out: BigInt;
+ for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+ out.limbs[i] = by_read_13bit_window(x, i * BY_SRC_LIMB_BITS);
+ }
+ return out;
+}
+`;
+
+export const bigint_f32 = `// f32-limb mirror of \`bigint.template.wgsl\`. Each limb holds an
+// integer-valued f32 in [0, 2^WORD_SIZE_F32) = [0, W). Mirrors the
+// subset of \`bigint.template.wgsl\` needed by \`montgomery_product_f32\`
+// and the f32 field ops in \`field/field_f32.template.wgsl\` /
+// \`field/fr_pow_f32.template.wgsl\`.
+
+struct BigIntF32 {
+ limbs: array
+}
+
+fn bigint_f32_gt(x: ptr, y: ptr) -> bool {
+ for (var idx = 0u; idx < {{ num_limbs }}u; idx ++) {
+ let i = {{ num_limbs }}u - 1u - idx;
+ if ((*x).limbs[i] < (*y).limbs[i]) { return false; }
+ if ((*x).limbs[i] > (*y).limbs[i]) { return true; }
+ }
+ return false;
+}
+
+fn bigint_f32_eq(x: ptr, y: ptr) -> bool {
+ for (var i = 0u; i < {{ num_limbs }}u; i ++) {
+ if ((*x).limbs[i] != (*y).limbs[i]) { return false; }
+ }
+ return true;
+}
+
+// res = a - b. Per-limb borrow lives as an f32 in {0.0, 1.0};
+// \`step(diff, -0.5)\` is 1.0 iff diff is a negative integer (i.e.,
+// underflowed). Adding \`underflow * W\` then canonicalises the limb
+// back into [0, W). Returns the final borrow-out (0.0 or 1.0): callers
+// that need to know whether a >= b consult this flag, mirroring the
+// u32 \`bigint_sub\`'s return convention.
+fn bigint_f32_sub(a: ptr, b: ptr, res: ptr) -> f32 {
+ var borrow: f32 = 0.0;
+ for (var i = 0u; i < {{ num_limbs }}u; i ++) {
+ let diff = (*a).limbs[i] - (*b).limbs[i] - borrow;
+ let underflow = step(diff, -0.5);
+ (*res).limbs[i] = diff + underflow * W;
+ borrow = underflow;
+ }
+ return borrow;
+}
+
+// res = a + b. Per-limb carry lives as an f32 in {0.0, 1.0}; \`step(W-0.5, sum)\`
+// is 1.0 iff sum >= W (the bias-free branchless test). Each sum is at most
+// 2*(W-1) + 1 = 2^24 - 1, exact in the 24-bit f32 mantissa. Returns the
+// final carry-out (0.0 or 1.0) for downstream conditional-reduce logic.
+fn bigint_f32_add(a: ptr, b: ptr, res: ptr) -> f32 {
+ var carry: f32 = 0.0;
+ for (var i = 0u; i < {{ num_limbs }}u; i ++) {
+ let sum = (*a).limbs[i] + (*b).limbs[i] + carry;
+ let overflow = step(W - 0.5, sum);
+ (*res).limbs[i] = sum - overflow * W;
+ carry = overflow;
+ }
+ return carry;
+}
+
+fn bigint_f32_is_zero(x: ptr) -> bool {
+ for (var i = 0u; i < {{ num_limbs }}u; i ++) {
+ if ((*x).limbs[i] != 0.0) { return false; }
+ }
+ return true;
+}
+`;
+
export const ec_bn254 = `// Jacobian-coordinate EC arithmetic for BN254 (short Weierstrass a=0, b=3).
// Affine interpretation: (X, Y, Z) represents affine (X/Z^2, Y/Z^3).
// Identity: any point with Z = 0.
@@ -854,6 +1263,94 @@ fn add_points_mixed_no_collision(p1: Point, p2: Point) -> Point {
}
`;
+export const apply_matrix_bench = `// Single-thread-per-input bench shader for \`by_apply_matrix_fg\` /
+// \`by_apply_matrix_de\`. Each thread runs both passes on one input record
+// and writes the resulting 4 × 9 limbs (f', g', d', e') as 36 i32 outputs.
+//
+// INPUT LAYOUT (44 u32 per thread):
+// Offset 0..7: Mat fields { u, v, q, r, u_hi, v_hi, q_hi, r_hi } as i32.
+// Offset 8..16: f limbs (9 × i32)
+// Offset 17..25: g limbs (9 × i32)
+// Offset 26..34: d limbs (9 × i32)
+// Offset 35..43: e limbs (9 × i32)
+// All limbs are signed-canonical (lower limbs in [0, 2^29), top limb signed).
+//
+// OUTPUT LAYOUT (36 i32 per thread):
+// Offset 0..8: f' limbs (9 × i32) — output of by_apply_matrix_fg
+// Offset 9..17: g' limbs (9 × i32) — output of by_apply_matrix_fg
+// Offset 18..26: d' limbs (9 × i32) — output of by_apply_matrix_de
+// Offset 27..35: e' limbs (9 × i32) — output of by_apply_matrix_de
+//
+// CONSTANTS (Mustache injected):
+// p_limbs_by: BigIntBY initializer for the BN254 base-field modulus p.
+// p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58.
+// p_inv_by_hi: high 32 bits of P_INV (only the low 26 bits are non-zero).
+//
+// LOOP BOUNDS: by_apply_matrix_fg and by_apply_matrix_de each use a
+// single \`for\` bounded by BY_NUM_LIMBS (= 9u const). The entry point adds
+// two limb-copy loops (decode, write-back), also bounded by BY_NUM_LIMBS.
+//
+// SAFETY: \`if (tid >= n) return;\` static guard, no data-dependent loops.
+
+@group(0) @binding(0) var inputs: array;
+@group(0) @binding(1) var outputs: array;
+@group(0) @binding(2) var params: vec2;
+
+const APPLY_MATRIX_INPUT_STRIDE: u32 = 44u; // 8 (Mat) + 4 * 9 (limbs)
+const APPLY_MATRIX_OUTPUT_STRIDE: u32 = 36u; // 4 * 9 (limbs)
+const P_INV_BY_LO: u32 = {{ p_inv_by_lo }}u;
+const P_INV_BY_HI: u32 = {{ p_inv_by_hi }}u;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let n = params.x;
+ let tid = gid.x;
+ if (tid >= n) { return; }
+
+ let in_base: u32 = tid * APPLY_MATRIX_INPUT_STRIDE;
+ let out_base: u32 = tid * APPLY_MATRIX_OUTPUT_STRIDE;
+
+ // Decode Mat. The 8 i32 fields land in bindings as u32 — bitcast.
+ let m: Mat = Mat(
+ bitcast(inputs[in_base + 0u]),
+ bitcast(inputs[in_base + 1u]),
+ bitcast(inputs[in_base + 2u]),
+ bitcast(inputs[in_base + 3u]),
+ bitcast(inputs[in_base + 4u]),
+ bitcast(inputs[in_base + 5u]),
+ bitcast(inputs[in_base + 6u]),
+ bitcast(inputs[in_base + 7u]),
+ );
+
+ // Decode f, g, d, e as BigIntBY (9 × i32 each).
+ var f: BigIntBY;
+ var g: BigIntBY;
+ var d: BigIntBY;
+ var e: BigIntBY;
+ for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) {
+ f.l[i] = bitcast(inputs[in_base + 8u + i]);
+ g.l[i] = bitcast(inputs[in_base + 17u + i]);
+ d.l[i] = bitcast(inputs[in_base + 26u + i]);
+ e.l[i] = bitcast(inputs[in_base + 35u + i]);
+ }
+
+ // Modulus p, fixed across all threads — Mustache-injected.
+ var p: BigIntBY = BigIntBY(array({{{ p_limbs_by }}}));
+
+ // Apply matrix passes.
+ by_apply_matrix_fg(m, &f, &g);
+ by_apply_matrix_de(m, &d, &e, &p, P_INV_BY_LO, P_INV_BY_HI);
+
+ // Write outputs.
+ for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) {
+ outputs[out_base + 0u + i] = f.l[i];
+ outputs[out_base + 9u + i] = g.l[i];
+ outputs[out_base + 18u + i] = d.l[i];
+ outputs[out_base + 27u + i] = e.l[i];
+ }
+}
+`;
+
export const barrett = `const W_MASK = {{ w_mask }}u;
const SLACK = {{ slack }}u;
@@ -1195,6 +1692,11 @@ var round_count: array>;
@group(0) @binding(2)
var dispatch_args: array;
+// WINDOWS_PER_BATCH — baked at render time. Controls the inverse-pass
+// Z dispatch dim: num_batches = ceil(num_subtasks / WPB). See
+// batch_inverse_parallel.template.wgsl for the merged-pool inverse.
+const WPB: u32 = {{ windows_per_batch }}u;
+
// params[0] = num_subtasks
// params[1] = apply_workgroup_size
// params[2] = sched_x_groups (= ceil(num_columns / schedule_workgroup_size))
@@ -1209,6 +1711,7 @@ fn main() {
let apply_wg_size = params[1];
let sched_x_groups = params[2];
let num_sub_wgs = params[3];
+ let num_batches = (num_subtasks + WPB - 1u) / WPB;
var max_count: u32 = 0u;
for (var i: u32 = 0u; i < num_subtasks; i = i + 1u) {
@@ -1233,13 +1736,14 @@ fn main() {
dispatch_args[1] = 1u;
dispatch_args[2] = num_subtasks;
- // THIS round's inverse: (W, 1, T) workgroups — W sub-WGs per subtask
- // splitting each subtask's pair pool into W contiguous slices, each
- // independently inverted with its own fr_inv. Drops Phase A/D
- // per-thread sequential cost by W. See batch_inverse_parallel for
- // the algorithm.
+ // THIS round's inverse: (W, 1, num_batches) workgroups — W sub-WGs
+ // per batch, splitting each batch's MERGED pair pool (of
+ // size sum_{w} round_count[batch * WPB + w]) into W contiguous
+ // slices, each independently inverted with its own fr_inv. Pooling
+ // amortises one fr_inv across WPB subtasks. Z dim drops from T to
+ // ceil(T/WPB) accordingly.
let inverse_x = select(0u, num_sub_wgs, any_work);
- let inverse_z = select(0u, num_subtasks, any_work);
+ let inverse_z = select(0u, num_batches, any_work);
dispatch_args[3] = inverse_x;
dispatch_args[4] = 1u;
dispatch_args[5] = inverse_z;
@@ -2042,6 +2546,8 @@ export const batch_inverse = `{{> structs }}
{{> montgomery_product_funcs }}
{{> field_funcs }}
{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
// get_r returns Montgomery R (= the integer 1 in Montgomery form). Each
// main shader defines its own with \`r_limbs\` substitution; the
@@ -2113,7 +2619,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) {
}
// Step 2: invert the final product.
- var inv_acc: BigInt = fr_inv(acc);
+ var inv_acc: BigInt = fr_inv_by_a(acc);
// Step 3: walk back, emitting individual inverses.
for (var idx = 0u; idx < n - 1u; idx = idx + 1u) {
@@ -2136,6 +2642,8 @@ export const batch_inverse_parallel = `{{> structs }}
{{> montgomery_product_funcs }}
{{> field_funcs }}
{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
// Parallel Montgomery batch-inverse on the GPU.
//
@@ -2146,74 +2654,65 @@ export const batch_inverse_parallel = `{{> structs }}
// workgroup of TPB threads, dropping the per-dispatch cost from ~30 ms
// to ~1-3 ms at the realistic round sizes we see.
//
-// MULTI-WORKGROUP MODE. The dispatch is shaped (NUM_SUB_WGS, 1, T).
-// \`wid.z\` selects the subtask (count_buf[subtask]) and \`wid.x\` is the
-// sub-workgroup index inside that subtask (0..NUM_SUB_WGS-1). Each
-// sub-workgroup independently inverts a contiguous slice of length
-// per_sub_chunk = ceil(n / NUM_SUB_WGS), using its own fr_inv. Reasoning:
+// MULTI-WORKGROUP + WINDOWS-PER-BATCH MODE. The dispatch is shaped
+// (NUM_SUB_WGS, 1, num_batches), where num_batches = ceil(T / WPB).
+// \`wid.z\` selects the batch; \`wid.x\` is the sub-workgroup index inside
+// the batch. Each (wid.z, wid.x) pair inverts its contiguous slice of the
+// COMBINED pair pool of WPB consecutive subtasks — one logical pool of
+// length sum_{w=0..WPB-1} count_buf[batch * WPB + w] — with one
+// fr_inv_by_a call per (batch, sub_wg).
//
-// - Phase A (per-thread fwd) and Phase D (per-thread back-walk) are
-// sequential per thread with bs = ceil(n / TPB). Splitting each
-// subtask across W sub-workgroups drops bs by W, giving up to W×
-// speedup on the dominant per-round latency at large N.
+// Amortisation: with WPB=4 we run T/WPB × W fr_inv_by_a calls per round
+// instead of T × W — a 4× drop in the dominant Phase C cost. The
+// per-subtask layout of the pair pool (and the \`prefix\` / \`outputs\`
+// scratch buffers) is unchanged, so \`batch_affine_apply_scatter\` reads
+// outputs[subtask_global * pitch + slot_in_subtask] just like before.
+// The kernel "sews together" the gaps between subtask slices in the
+// scan: a logical position p in the merged pool decodes to
+// (subtask_in_batch, slot_in_subtask) via the per-batch exclusive
+// prefix-sum \`wg_cum\`, then to a physical buffer position
+// (batch * WPB + subtask_in_batch) * pitch + slot_in_subtask.
//
-// - Each sub-workgroup runs its own fr_inv. fr_invs across sub-WGs
-// run concurrently on different SMs, so the EXTRA fr_invs cost zero
-// wall time (gated only by SM occupancy). With T=16 subtasks × W=8
-// sub-WGs = 128 workgroups in flight, RTX-class GPUs are fully
-// occupied during the inverse pass.
+// Per-thread walks advance a (cur_w, cur_off) cursor as they step
+// through their chunk so the per-step decode is O(1); the only WPB-loop
+// is the const-bounded prefix-sum setup at the top of the kernel.
//
// Two clients today:
-// - SMVP (cross-subtask): pitch = num_columns, count_buf[wid.z] =
-// pair_counter[wid.z] (per-subtask atomic).
-// - Finalize: pitch = half_num_columns, count_buf[wid.z] =
-// half_num_columns for all wid.z (pre-populated by host).
+// - SMVP (cross-subtask): pitch = num_columns, count_buf[g] =
+// round_count[g] (per-subtask, forwarded by dispatch_args).
+// - Finalize: pitch = half_num_columns, count_buf[g] =
+// half_num_columns for all g (pre-populated by host). With WPB=1 in
+// the finalize call this degenerates to the previous behaviour.
//
-// Algorithm (Montgomery's batch-inverse trick, two-level):
+// Algorithm (Montgomery's batch-inverse trick, two-level, merged pool):
//
// 1. Each thread i computes block_inclusive_prefix[k] for k in its
// chunk into the \`prefix\` scratch buffer (serial, length bs =
-// ceil(n / TPB)). Captures block_total[i] in a register.
+// ceil(n / TPB)). Captures block_total[i] in a register. The walk
+// is over LOGICAL positions in the merged pool, but each write
+// lands at the corresponding PHYSICAL buffer position (decoded via
+// wg_cum).
// 2. All TPB block_totals are pushed into workgroup memory.
// 3. Two parallel inclusive scans over wg memory: forward (wg_fwd)
// and backward (wg_bwd). Hillis-Steele, log2(TPB) passes.
-// After scan:
-// wg_fwd[i] = block_total[0] * ... * block_total[i]
-// wg_bwd[i] = block_total[i] * ... * block_total[TPB-1]
-// So global_total = wg_fwd[TPB-1] = wg_bwd[0].
-// 4. Thread 0 computes inv_global = inv(global_total) — single fr_inv.
-// Broadcast via wg_inv_total.
-// 5. Each thread walks back through its chunk. The key trick: within
-// a single block, block_excl_prefix cancels algebraically, so the
-// back-walk reduces to the standard 2-mul/element batch inverse
-// on the per-block prefix array.
-//
-// inv(global_prefix[k]) · global_prefix[k-1]
-// = inv(block_excl_prefix · P_in[k]) · (block_excl_prefix · P_in[k-1])
-// = inv(P_in[k]) · P_in[k-1] // block_excl_prefix cancels
-//
-// So:
-// block_excl_prefix[i] = wg_fwd[i-1] (or R for i=0)
-// block_excl_suffix[i] = wg_bwd[i+1] (or R for i=TPB-1)
-// // Setup: inv(block_total) = inv_global · block_excl_prefix · block_excl_suffix
-// inv_acc = inv(block_total[i]) // 2 muls (setup, one-time)
-// For k from chunk_end-1 down to chunk_start:
-// out[k] = inv_acc · (k>chunk_start ? prefix[k-1] : R) // 1 mul (or 0)
-// inv_acc = inv_acc · inputs[k] // 1 mul
-//
-// Cost: 2N muls in back-walk + N muls in forward + 2 muls setup
-// = 3N + O(1) per workgroup. Previously 3N + N back-walk muls
-// (extra mul-by-block_excl_prefix per element) = 4N total.
+// 4. Thread 0 computes inv_global = inv(global_total) — single
+// fr_inv_by_a per (batch, sub_wg). Broadcast via wg_inv_total.
+// 5. Each thread walks back through its chunk. Within a single block,
+// block_excl_prefix cancels algebraically, so the back-walk runs
+// as the standard 2-mul/element batch inverse on the per-block
+// prefix array (same as before). Outputs land at physical buffer
+// positions matching the input layout.
//
// TPB = 64 keeps workgroup memory at 2 * 64 * sizeof(BigInt) = 10240
// bytes (BN254: BigInt = 80 bytes), comfortably under the WebGPU
// default maxComputeWorkgroupStorageSize = 16384.
//
-// When n == 0 for this workgroup the kernel returns immediately at the
-// top.
+// When total_n == 0 for this (batch, sub_wg) the kernel returns
+// immediately at the top.
const TPB: u32 = 64u;
const NUM_SUB_WGS: u32 = {{ num_sub_wgs }}u;
+const WPB: u32 = {{ windows_per_batch }}u;
@group(0) @binding(0)
var inputs: array;
@@ -2227,7 +2726,7 @@ var outputs: array;
@group(0) @binding(3)
var count_buf: array>;
-// params[0] = pitch (per-workgroup slice stride)
+// params[0] = pitch (per-subtask slice stride inside inputs/prefix/outputs)
// params[1..3] = unused
@group(0) @binding(4)
var params: vec4;
@@ -2241,19 +2740,44 @@ fn get_r() -> BigInt {
var wg_fwd: array;
var wg_bwd: array;
var wg_inv_total: BigInt;
-// Broadcast slot for the pair count. atomicLoad returns a non-uniform
-// value as far as the WGSL uniformity analysis is concerned, so the
-// downstream \`workgroupBarrier()\`s would be ill-formed if we branched
-// directly on it. Funnelling it through a workgroup variable + an
-// explicit \`workgroupUniformLoad\` re-uniforms the value (and acts as
-// an implicit barrier).
-var wg_n: u32;
-// Per-sub-WG element offset into the subtask's slice. Set by tid 0
-// alongside wg_n; broadcast through workgroup memory so every thread
-// agrees on the slice base. (workgroupUniformLoad over wg_n implicitly
-// synchronises this write too, since both are written before the load.)
+// Per-batch exclusive prefix of subtask counts: wg_cum[w] is the
+// cumulative count of subtasks 0..w-1 within this batch (so wg_cum[0]
+// is always 0 and the merged-pool total length is
+// wg_cum[WPB-1] + wg_counts[WPB-1]). Used to decode "logical position
+// p in the merged pool" → "(subtask_in_batch w, slot s in that
+// subtask)" via the largest w with wg_cum[w] <= p.
+var wg_counts: array;
+var wg_cum: array;
+// Broadcast slot for this sub-WG's element count within the merged pool
+// (despite the name, tid 0 stores sub_n here, not the pool total).
+// Values derived from atomicLoad are non-uniform as far as the WGSL
+// uniformity analysis is concerned, so branching on them ahead of a
+// \`workgroupBarrier()\` is ill-formed; \`workgroupUniformLoad\` on a
+// workgroup variable re-uniforms the value (and acts as a barrier).
+var wg_total_n: u32;
+// Per-sub-WG element offset into the merged pool. Set by tid 0
+// alongside wg_total_n; broadcast through workgroup memory so every
+// thread agrees on the slice base. (workgroupUniformLoad over
+// wg_total_n implicitly synchronises this write too, since both are
+// written before the load.)
var wg_sub_offset: u32;
+// Decode a logical position p in the merged batch pool into
+// (subtask_in_batch, slot_in_subtask). The WPB-loop is const-bounded.
+struct PoolPos { w: u32, slot: u32 };
+fn decode_pool_pos(p: u32) -> PoolPos {
+ var out: PoolPos;
+ out.w = 0u;
+ out.slot = p;
+ for (var i = 1u; i < WPB; i = i + 1u) {
+ if (p >= wg_cum[i]) {
+ out.w = i;
+ out.slot = p - wg_cum[i];
+ }
+ }
+ return out;
+}
+
@compute
@workgroup_size(64)
fn main(
@@ -2262,15 +2786,27 @@ fn main(
) {
let tid = lid.x;
let pitch = params[0];
- let subtask_idx = wid.z;
+ let batch_idx = wid.z;
let sub_idx = wid.x;
+ let subtask_base = batch_idx * WPB;
- // Thread 0 reads the subtask's atomic count and the sub-WG's slice
- // bounds; broadcast via wg_n. workgroupUniformLoad ensures all
- // threads see a uniform \`n\` (= this sub-WG's element count) and
- // synchronises the workgroup before continuing.
+ // Thread 0 reads the WPB per-subtask atomic counts in this batch,
+ // stores them in wg_counts with their exclusive prefix-sum in wg_cum,
+ // and resolves the sub-WG's slice bounds inside the merged pool.
+ // It broadcasts the sub-WG's element count via wg_total_n and its
+ // start via wg_sub_offset; workgroupUniformLoad re-uniforms the count.
if (tid == 0u) {
- let total_n = atomicLoad(&count_buf[subtask_idx]);
+ var cum: u32 = 0u;
+ for (var w = 0u; w < WPB; w = w + 1u) {
+ let g = subtask_base + w;
+ // count_buf is sized to a multiple of WPB so the tail loads
+ // here return the host's initial zero (no out-of-bounds).
+ let c = atomicLoad(&count_buf[g]);
+ wg_counts[w] = c;
+ wg_cum[w] = cum;
+ cum = cum + c;
+ }
+ let total_n = cum;
// Split [0, total_n) into NUM_SUB_WGS contiguous chunks. Chunks
// are sized ceil(total_n / NUM_SUB_WGS); trailing sub-WGs may
// see a shorter (or empty) chunk if total_n isn't a multiple of
@@ -2285,39 +2821,70 @@ fn main(
sub_n = clamped_end - raw_start;
}
wg_sub_offset = raw_start;
- wg_n = sub_n;
+ wg_total_n = sub_n;
}
- let n = workgroupUniformLoad(&wg_n);
+ let n = workgroupUniformLoad(&wg_total_n);
if (n == 0u) {
return;
}
- let sub_offset_in_subtask = wg_sub_offset;
-
- // Subtask owns a slice of \`pitch\` elements at offset
- // subtask_idx * pitch. This sub-WG owns
- // [subtask_offset + sub_offset_in_subtask,
- // subtask_offset + sub_offset_in_subtask + n).
- let slice_offset = subtask_idx * pitch + sub_offset_in_subtask;
+ let sub_offset_in_batch = wg_sub_offset;
// Block size = ceil(n / TPB). Last thread may have a shorter chunk.
let bs = (n + TPB - 1u) / TPB;
- let chunk_start = tid * bs;
- var chunk_end = chunk_start + bs;
- if (chunk_end > n) {
- chunk_end = n;
+ let chunk_start_in_sub = tid * bs;
+ var chunk_end_in_sub = chunk_start_in_sub + bs;
+ if (chunk_end_in_sub > n) {
+ chunk_end_in_sub = n;
}
+ // Translate this thread's [chunk_start, chunk_end) window from
+ // sub-WG-local indexing into batch-merged-pool indexing.
+ let chunk_start = sub_offset_in_batch + chunk_start_in_sub;
+ let chunk_end = sub_offset_in_batch + chunk_end_in_sub;
+
// Phase A: per-thread inclusive prefix product over the chunk.
- // block_total = product of all elements in this thread's chunk,
- // or R (Montgomery 1) if the chunk is empty.
+ // Each k iterates over LOGICAL positions in the merged pool but
+ // every read/write lands at the PHYSICAL buffer position decoded
+ // through wg_cum. block_total = product of all elements in this
+ // thread's chunk, or R (Montgomery 1) if the chunk is empty.
var block_total: BigInt = get_r();
- if (chunk_start < n) {
- var acc: BigInt = inputs[slice_offset + chunk_start];
- prefix[slice_offset + chunk_start] = acc;
+ if (chunk_start_in_sub < n) {
+ // Decode the starting logical position once, then carry
+ // (cur_w, cur_off) forward through the chunk so the per-step
+ // decode is O(1). cur_off is the slot within the current
+ // subtask; when it reaches wg_counts[cur_w] we advance to the
+ // next subtask.
+ let start_pos = decode_pool_pos(chunk_start);
+ var cur_w: u32 = start_pos.w;
+ var cur_off: u32 = start_pos.slot;
+ let g0 = subtask_base + cur_w;
+ let phys0 = g0 * pitch + cur_off;
+ var acc: BigInt = inputs[phys0];
+ prefix[phys0] = acc;
+ // Advance one step past the seed (k = chunk_start) and walk to
+ // chunk_end. The k variable here is the BATCH-MERGED logical
+ // position; advance cur_off in lockstep, rolling cur_w + cur_off
+ // across the wg_counts[cur_w] boundary when the subtask runs out.
for (var k = chunk_start + 1u; k < chunk_end; k = k + 1u) {
- var x: BigInt = inputs[slice_offset + k];
+ cur_off = cur_off + 1u;
+ // Bounded rollover: a step usually crosses at most one
+ // boundary (counts never exceed pitch, which exceeds bs in
+ // practice), but empty subtasks (wg_counts[cur_w] == 0) are
+ // skipped too. The loop is const-bounded: at most WPB hops,
+ // since a batch holds only WPB subtasks.
+ for (var hop = 0u; hop < WPB; hop = hop + 1u) {
+ if (cur_off >= wg_counts[cur_w]) {
+ cur_w = cur_w + 1u;
+ cur_off = 0u;
+ } else {
+ break;
+ }
+ }
+ let g = subtask_base + cur_w;
+ let phys = g * pitch + cur_off;
+ var x: BigInt = inputs[phys];
acc = montgomery_product(&acc, &x);
- prefix[slice_offset + k] = acc;
+ prefix[phys] = acc;
}
block_total = acc;
}
@@ -2347,14 +2914,16 @@ fn main(
}
// Phase C: thread 0 inverts the global total. Broadcast via workgroup mem.
+ // One fr_inv_by_a per (batch, sub_wg), independent of WPB — that's
+ // the amortisation this layout buys.
if (tid == 0u) {
var global_total: BigInt = wg_fwd[TPB - 1u];
- wg_inv_total = fr_inv(global_total);
+ wg_inv_total = fr_inv_by_a(global_total);
}
workgroupBarrier();
// Phase D: walk back through this thread's chunk, emitting inverses.
- if (chunk_start >= n) {
+ if (chunk_start_in_sub >= n) {
return;
}
@@ -2379,6 +2948,13 @@ fn main(
var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix);
inv_acc = montgomery_product(&inv_acc, &block_excl_suffix);
+ // Decode the END position (chunk_end - 1) and walk backward,
+ // mirroring Phase A's forward decode. Maintain (cur_w, cur_off) as
+ // we step k from chunk_end-1 down to chunk_start.
+ let end_pos = decode_pool_pos(chunk_end - 1u);
+ var cur_w: u32 = end_pos.w;
+ var cur_off: u32 = end_pos.slot;
+
// Walk from k = chunk_end-1 down to chunk_start. Within a block,
// block_excl_prefix cancels algebraically (see header comment), so
// we run the standard backward batch-inverse over the per-block
@@ -2387,23 +2963,288 @@ fn main(
while (k > chunk_start) {
k = k - 1u;
- // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself for k = chunk_start)
+ let g = subtask_base + cur_w;
+ let phys = g * pitch + cur_off;
+
+ // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself
+ // for k = chunk_start). The "k-1" position in the BLOCK is the
+ // previous logical position in the merged pool; if that lies
+ // in a different subtask, prefix is still indexed by physical
+ // position so we decode the neighbour separately.
var inv_a_k: BigInt;
if (k > chunk_start) {
- var prev_in_block: BigInt = prefix[slice_offset + k - 1u];
+ // Compute the physical position of (k - 1), the logical
+ // predecessor of the current step. Derive a temporary
+ // (prev_w, prev_off) cursor from (cur_w, cur_off) without
+ // modifying it: step back by one slot; if cur_off is 0,
+ // roll back to the last slot of the nearest preceding
+ // non-empty subtask.
+ var prev_w: u32 = cur_w;
+ var prev_off: u32;
+ if (cur_off > 0u) {
+ prev_off = cur_off - 1u;
+ } else {
+ // Roll back to the largest preceding subtask with a
+ // non-empty count. Const-bounded by WPB.
+ prev_off = 0u;
+ for (var hop = 0u; hop < WPB; hop = hop + 1u) {
+ if (prev_w == 0u) {
+ // Should not happen — we'd be past chunk_start.
+ break;
+ }
+ prev_w = prev_w - 1u;
+ if (wg_counts[prev_w] > 0u) {
+ prev_off = wg_counts[prev_w] - 1u;
+ break;
+ }
+ }
+ }
+ let g_prev = subtask_base + prev_w;
+ let phys_prev = g_prev * pitch + prev_off;
+ var prev_in_block: BigInt = prefix[phys_prev];
inv_a_k = montgomery_product(&inv_acc, &prev_in_block);
} else {
inv_a_k = inv_acc;
}
- outputs[slice_offset + k] = inv_a_k;
+ outputs[phys] = inv_a_k;
// Update: inv_acc <- inv_acc * a[k] = inv(prefix_in_block[k-1])
// for the next iteration. The update on the last iteration
// (k = chunk_start) is wasted — minor cost, kept for code
// simplicity; the loop exit condition skips it naturally.
if (k > chunk_start) {
- var a_k: BigInt = inputs[slice_offset + k];
+ var a_k: BigInt = inputs[phys];
inv_acc = montgomery_product(&inv_acc, &a_k);
+ // Step (cur_w, cur_off) back by one logical position, same
+ // rollover logic as the prev_* decode above.
+ if (cur_off > 0u) {
+ cur_off = cur_off - 1u;
+ } else {
+ for (var hop = 0u; hop < WPB; hop = hop + 1u) {
+ if (cur_w == 0u) {
+ break;
+ }
+ cur_w = cur_w - 1u;
+ if (wg_counts[cur_w] > 0u) {
+ cur_off = wg_counts[cur_w] - 1u;
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ {{{ recompile }}}
+}
+`;
+
+export const bench_batch_affine = `{{> structs }}
+{{> bigint_funcs }}
+{{> montgomery_product_funcs }}
+{{> field_funcs }}
+{{> fr_pow_funcs }}
+{{> bigint_by_funcs }}
+{{> by_inverse_a_funcs }}
+
+// Standalone single-dispatch micro-benchmark for batch-affine EC addition.
+// Measures the per-pair cost of \`R_i = P_i + Q_i\` using Montgomery's
+// batch-inverse trick at a fixed BATCH_SIZE, swept across a range of sizes
+// by the host to find the sweet spot for amortising the single
+// \`fr_inv_by_a\` call per batch.
+//
+// Single-workgroup-per-batch layout. Each workgroup of TPB threads
+// processes one batch of BATCH_SIZE consecutive pairs. The host dispatches
+// (TOTAL_PAIRS / BATCH_SIZE, 1, 1) workgroups.
+//
+// Phase A: each thread serially walks its BS = BATCH_SIZE / TPB pairs,
+// computing delta_x_k = Q_k.x - P_k.x and accumulating the running
+// product \`prefix[k] = prefix[k-1] * delta_x_k\` for k in
+// [chunk_start, chunk_end). Captures \`block_total\` = product of
+// all chunk deltas.
+//
+// Phase B: workgroup-shared Hillis-Steele scan (forward + backward) over
+// the TPB block_totals.
+//
+// Phase C: thread 0 runs the SINGLE \`fr_inv_by_a\` call on the overall
+// product and broadcasts the result via workgroup memory.
+//
+// Phase D: each thread back-walks its chunk emitting one affine add per
+// step. Uses standard 2-mul/element backward batch-inverse:
+// \`inv_acc = inv(block_total)\` at the start, then for each k from
+// chunk_end-1 down to chunk_start:
+// inv_dx = (k > chunk_start) ? inv_acc * prefix[k-1] : inv_acc
+// ... affine add ...
+// outputs[2k] = R.x
+// outputs[2k+1] = R.y
+// if (k > chunk_start) inv_acc = inv_acc * delta_x_k
+//
+// LOOP BOUNDS — every loop in this kernel is bounded by a compile-time
+// \`const\`:
+// - PHASE_A_LOOP and PHASE_D_LOOP iterate \`BS\` times (compile-time
+// Mustache constant \`{{ per_thread_count }}\`).
+// - PHASE_B_LOOP runs while \`stride < TPB\` (compile-time const).
+// - Every inner \`montgomery_product\` / \`fr_inv_by_a\` / \`fr_sub\` loop is
+// bounded by \`NUM_WORDS\` or other compile-time constants in the
+// included partials.
+//
+// MEMORY BUDGET. Workgroup memory: 2 × TPB × sizeof(BigInt) = 2 × TPB × 80
+// bytes (BN254 BigInt = 20 × u32 = 80 B) + 1 × BigInt broadcast slot.
+// At TPB=64 this is 10 KiB + 80 B, well under the 16 KiB default.
+
+const BATCH_SIZE: u32 = {{ batch_size }}u;
+const TPB: u32 = {{ tpb }}u;
+const BS: u32 = {{ per_thread_count }}u; // = BATCH_SIZE / TPB
+
+@group(0) @binding(0)
+var inputs: array;
+
+@group(0) @binding(1)
+var prefix: array;
+
+@group(0) @binding(2)
+var outputs: array;
+
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+var wg_fwd: array;
+var wg_bwd: array;
+var wg_inv_total: BigInt;
+
+@compute
+@workgroup_size({{ tpb }})
+fn main(
+ @builtin(local_invocation_id) lid: vec3,
+ @builtin(workgroup_id) wid: vec3,
+) {
+ let tid = lid.x;
+ let batch_idx = wid.x;
+ let batch_base = batch_idx * BATCH_SIZE;
+ let chunk_start = tid * BS;
+ // chunk_end is statically BS past chunk_start since BATCH_SIZE is an
+ // exact multiple of TPB by construction; no clamping needed.
+ let chunk_end = chunk_start + BS;
+
+ // Phase A: per-thread inclusive-prefix product over \`delta_x_k\`.
+ var block_total: BigInt = get_r();
+ // First iter (k = chunk_start): seed acc with delta_x at that slot.
+ {
+ let k = chunk_start;
+ let pair_base = (batch_base + k) * 4u;
+ var p_x: BigInt = inputs[pair_base + 0u];
+ var q_x: BigInt = inputs[pair_base + 2u];
+ var dx: BigInt = fr_sub(&q_x, &p_x);
+ // prefix slot is per-WG-relative inside this batch's slice.
+ prefix[batch_base + k] = dx;
+ block_total = dx;
+ }
+ // Subsequent iters: const-bounded loop over the remaining BS-1 slots.
+ // PHASE_A_LOOP: bound = BS (compile-time constant).
+ for (var i = 1u; i < BS; i = i + 1u) {
+ let k = chunk_start + i;
+ let pair_base = (batch_base + k) * 4u;
+ var p_x: BigInt = inputs[pair_base + 0u];
+ var q_x: BigInt = inputs[pair_base + 2u];
+ var dx: BigInt = fr_sub(&q_x, &p_x);
+ block_total = montgomery_product(&block_total, &dx);
+ prefix[batch_base + k] = block_total;
+ }
+
+ wg_fwd[tid] = block_total;
+ wg_bwd[tid] = block_total;
+ workgroupBarrier();
+
+ // Phase B: Hillis-Steele forward + backward scans on the TPB
+ // block_totals. PHASE_B_LOOP: bound = TPB (compile-time constant).
+ for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) {
+ var fwd_x: BigInt = wg_fwd[tid];
+ if (tid >= stride) {
+ var lhs: BigInt = wg_fwd[tid - stride];
+ fwd_x = montgomery_product(&lhs, &fwd_x);
+ }
+ var bwd_x: BigInt = wg_bwd[tid];
+ if (tid + stride < TPB) {
+ var rhs: BigInt = wg_bwd[tid + stride];
+ bwd_x = montgomery_product(&bwd_x, &rhs);
+ }
+ workgroupBarrier();
+ wg_fwd[tid] = fwd_x;
+ wg_bwd[tid] = bwd_x;
+ workgroupBarrier();
+ }
+
+ // Phase C: thread 0 inverts the global total via fr_inv_by_a.
+ if (tid == 0u) {
+ var global_total: BigInt = wg_fwd[TPB - 1u];
+ wg_inv_total = fr_inv_by_a(global_total);
+ }
+ workgroupBarrier();
+
+ // Setup: inv_acc = inv(block_total[tid])
+ // = inv_global * block_excl_prefix * block_excl_suffix
+ var block_excl_prefix: BigInt = get_r();
+ if (tid > 0u) {
+ block_excl_prefix = wg_fwd[tid - 1u];
+ }
+ var block_excl_suffix: BigInt = get_r();
+ if (tid + 1u < TPB) {
+ block_excl_suffix = wg_bwd[tid + 1u];
+ }
+ var inv_global: BigInt = wg_inv_total;
+ var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix);
+ inv_acc = montgomery_product(&inv_acc, &block_excl_suffix);
+
+ // Phase D: back-walk this thread's chunk, emitting affine adds.
+ // PHASE_D_LOOP: bound = BS (compile-time constant). Use the
+ // forward-form \`for (off = 0u; off < BS; ...)\` and derive the
+ // descending index k = chunk_end - 1 - off so the bound is provably
+ // static at compile time. NO \`while\`/\`loop\` constructs.
+ for (var off = 0u; off < BS; off = off + 1u) {
+ let k = chunk_start + (BS - 1u - off);
+ let pair_base = (batch_base + k) * 4u;
+
+ // Load all four pair coords once.
+ var p_x: BigInt = inputs[pair_base + 0u];
+ var p_y: BigInt = inputs[pair_base + 1u];
+ var q_x: BigInt = inputs[pair_base + 2u];
+ var q_y: BigInt = inputs[pair_base + 3u];
+
+ // Backward batch-inverse: inv_dx_k = inv_acc * prefix[k-1] for
+ // k > chunk_start, or inv_acc itself at the chunk's start slot.
+ var inv_dx: BigInt;
+ if (k > chunk_start) {
+ var prev_prefix: BigInt = prefix[batch_base + (k - 1u)];
+ inv_dx = montgomery_product(&inv_acc, &prev_prefix);
+ } else {
+ inv_dx = inv_acc;
+ }
+
+ // Affine add (Montgomery form):
+ // λ = (Q.y - P.y) * inv_dx
+ // R.x = λ^2 - P.x - Q.x
+ // R.y = λ * (P.x - R.x) - P.y
+ var dy: BigInt = fr_sub(&q_y, &p_y);
+ var slope: BigInt = montgomery_product(&dy, &inv_dx);
+ var slope_sq: BigInt = montgomery_product(&slope, &slope);
+ var t1: BigInt = fr_sub(&slope_sq, &p_x);
+ var r_x: BigInt = fr_sub(&t1, &q_x);
+ var dx_back: BigInt = fr_sub(&p_x, &r_x);
+ var ldx: BigInt = montgomery_product(&slope, &dx_back);
+ var r_y: BigInt = fr_sub(&ldx, &p_y);
+
+ // Output: 2 BigInts per pair.
+ let out_base = (batch_base + k) * 2u;
+ outputs[out_base + 0u] = r_x;
+ outputs[out_base + 1u] = r_y;
+
+ // Update inv_acc for the next (smaller-k) iteration, unless we
+ // just emitted the chunk-start slot.
+ if (k > chunk_start) {
+ var dx_k: BigInt = fr_sub(&q_x, &p_x);
+ inv_acc = montgomery_product(&inv_acc, &dx_k);
}
}
@@ -2434,9 +3275,14 @@ var g_points_y: array;
@group(0) @binding(5)
var g_points_z: array;
-// Unfiform storage buffer.
+// Uniform storage buffer. Layout: (subtask_idx_base, num_columns,
+// num_subtasks_per_bpr, num_subtasks_total). The 4th slot is the
+// dispatch's bounds-check ceiling for the WPB tail batch — values past
+// num_subtasks_total are skipped per-thread inside the multi-window
+// loop. Promoted from vec3 to vec4 so the host can pass a single
+// uniform across both legacy (WPB=1) and multi-window dispatches.
@group(0) @binding(6)
-var params: vec3;
+var params: vec4;
{{#capture_debug}}
// Debug capture for stage_2's inline add-2007-bl formula. Layout: 8 BigInts
@@ -2562,15 +3408,29 @@ fn bench_xor_into(a: Point, b: Point) -> Point {
return r;
}
+// Multi-window BPR (WPB > 1) trades thread count for per-thread work:
+// fewer workgroups dispatched, each thread runs the w_local subtask loop
+// {{ windows_per_batch }}u times (one per subtask in the batch). Saves
+// total launches and amortises kernel header / dispatch overhead, but
+// inflates per-thread register pressure — a single subtask already
+// carries ~10 BigInt locals (~800B/thread) inside the add_points formula,
+// so WPB > 1 risks spilling on Apple Silicon SIMD-groups (32 KB shared
+// register file). When WPB=1 the w_local loop runs once and the dispatch
+// shape collapses to the legacy (X=num_subtasks, 256) layout.
@compute
@workgroup_size({{ workgroup_size }})
-fn stage_1(@builtin(global_invocation_id) global_id: vec3) {
- let thread_id = global_id.x;
+fn stage_1(
+ @builtin(global_invocation_id) global_id: vec3,
+ @builtin(workgroup_id) wg_id: vec3,
+ @builtin(local_invocation_id) local_id: vec3,
+) {
+ let thread_id = local_id.x;
let num_threads_per_subtask = {{ workgroup_size }}u;
- let subtask_idx = params[0]; // 0, 2, 4, 6, ...
+ let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter)
let num_columns = params[1]; // 65536
- let num_subtasks_per_bpr = params[2]; // Must be a power of 2
+ let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy)
+ let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1
/// Number of buckets per subtask.
let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768
@@ -2580,17 +3440,30 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) {
let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr;
/// Number of buckets to reduce per thread.
- let buckets_per_thread = num_buckets_per_subtask /
+ let buckets_per_thread = num_buckets_per_subtask /
num_threads_per_subtask;
- let multiplier = subtask_idx + (thread_id / num_threads_per_subtask);
- let offset = num_buckets_per_subtask * multiplier;
+ // WPB-aware subtask iteration. Each workgroup owns WPB consecutive
+ // subtasks; thread \`thread_id\` processes the same in-subtask slice
+ // for every one of them. With WPB=1 the outer loop runs once and
+ // the semantics are byte-identical to the legacy dispatch.
+ for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) {
+ let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local;
+ let subtask_idx = subtask_idx_base + subtask_in_dispatch;
+ if (subtask_idx >= num_subtasks_total) {
+ // Tail batch: skip out-of-range subtasks (num_subtasks may
+ // not be a multiple of WPB).
+ continue;
+ }
- var idx = offset;
- if (thread_id % num_threads_per_subtask != 0u) {
- idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) *
- buckets_per_thread + offset;
- }
+ let multiplier = subtask_idx;
+ let offset = num_buckets_per_subtask * multiplier;
+
+ var idx = offset;
+ if (thread_id % num_threads_per_subtask != 0u) {
+ idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) *
+ buckets_per_thread + offset;
+ }
{{#bench_compute_only}}
// Microbench: synthesize the initial m so this load is also stripped.
@@ -2865,40 +3738,46 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) {
{{/mixed_safe_buckets}}
{{/assume_affine_buckets}}
- let t = (subtask_idx / num_subtasks_per_bpr) *
- (num_threads_per_subtask * num_subtasks_per_bpr) +
- thread_id;
+ let t = (subtask_idx / num_subtasks_per_bpr) *
+ (num_threads_per_subtask * num_subtasks_per_bpr) +
+ thread_id;
{{#bench_skip_writes}}
- // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points).
- // Funnel one observable write per thread into the workgroup atomic
- // sink so the WGSL compiler can't DCE the inner loop. atomicXor is
- // the cheapest read-modify-write that can't be folded; it produces
- // 1 LDS op per thread vs 6 storage writes in the production path.
- atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t);
+ // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points).
+ // Funnel one observable write per thread into the workgroup atomic
+ // sink so the WGSL compiler can't DCE the inner loop. atomicXor is
+ // the cheapest read-modify-write that can't be folded; it produces
+ // 1 LDS op per thread vs 6 storage writes in the production path.
+ atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t);
{{/bench_skip_writes}}
{{^bench_skip_writes}}
- bucket_sum_x[idx] = m.x;
- bucket_sum_y[idx] = m.y;
- bucket_sum_z[idx] = m.z;
+ bucket_sum_x[idx] = m.x;
+ bucket_sum_y[idx] = m.y;
+ bucket_sum_z[idx] = m.z;
- g_points_x[t] = g.x;
- g_points_y[t] = g.y;
- g_points_z[t] = g.z;
+ g_points_x[t] = g.x;
+ g_points_y[t] = g.y;
+ g_points_z[t] = g.z;
{{/bench_skip_writes}}
+ } // end of w_local loop (WPB-aware multi-window outer)
{{{ recompile }}}
}
@compute
@workgroup_size({{ workgroup_size }})
-fn stage_2(@builtin(global_invocation_id) global_id: vec3) {
- let thread_id = global_id.x;
+fn stage_2(
+ @builtin(global_invocation_id) global_id: vec3,
+ @builtin(workgroup_id) wg_id: vec3,
+ @builtin(local_invocation_id) local_id: vec3,
+) {
+ let thread_id = local_id.x;
let num_threads_per_subtask = {{ workgroup_size }}u;
- let subtask_idx = params[0]; // 0, 2, 4, 6, ...
+ let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter)
let num_columns = params[1]; // 65536
- let num_subtasks_per_bpr = params[2]; // Must be a power of 2
+ let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy)
+ let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1
/// Number of buckets per subtask.
let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768
@@ -2908,27 +3787,35 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) {
let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr;
/// Number of buckets to reduce per thread.
- let buckets_per_thread = num_buckets_per_subtask /
+ let buckets_per_thread = num_buckets_per_subtask /
num_threads_per_subtask;
- let multiplier = subtask_idx + (thread_id / num_threads_per_subtask);
- let offset = num_buckets_per_subtask * multiplier;
+ // WPB-aware subtask iteration. See stage_1 for the rationale.
+ for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) {
+ let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local;
+ let subtask_idx = subtask_idx_base + subtask_in_dispatch;
+ if (subtask_idx >= num_subtasks_total) {
+ continue;
+ }
- var idx = offset;
- if (thread_id % num_threads_per_subtask != 0u) {
- idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) *
- buckets_per_thread + offset;
- }
+ let multiplier = subtask_idx;
+ let offset = num_buckets_per_subtask * multiplier;
- var m = load_bucket_sum(idx);
+ var idx = offset;
+ if (thread_id % num_threads_per_subtask != 0u) {
+ idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) *
+ buckets_per_thread + offset;
+ }
- let t = (subtask_idx / num_subtasks_per_bpr) *
- (num_threads_per_subtask * num_subtasks_per_bpr) +
- thread_id;
- var g = load_g_point(t);
+ var m = load_bucket_sum(idx);
- let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u);
- let sm: Point = double_and_add(m, s);
+ let t = (subtask_idx / num_subtasks_per_bpr) *
+ (num_threads_per_subtask * num_subtasks_per_bpr) +
+ thread_id;
+ var g = load_g_point(t);
+
+ let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u);
+ let sm: Point = double_and_add(m, s);
// The add-2007-bl formula is inlined here (rather than calling
// add_points(g, sm)) because on Dawn/Metal the function-return slot for
@@ -3016,16 +3903,17 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) {
}
}
- g_points_x[t] = X3_out;
- g_points_y[t] = Y3_out;
- g_points_z[t] = Z3_out;
+ g_points_x[t] = X3_out;
+ g_points_y[t] = Y3_out;
+ g_points_z[t] = Z3_out;
- {{#capture_debug}}
- // Final values of the formula outputs, read from outer scope.
- debug_capture[thread_id * 8u + 4u] = X3_out;
- debug_capture[thread_id * 8u + 5u] = Y3_out;
- debug_capture[thread_id * 8u + 7u] = Z3_out;
- {{/capture_debug}}
+ {{#capture_debug}}
+ // Final values of the formula outputs, read from outer scope.
+ debug_capture[thread_id * 8u + 4u] = X3_out;
+ debug_capture[thread_id * 8u + 5u] = Y3_out;
+ debug_capture[thread_id * 8u + 7u] = Z3_out;
+ {{/capture_debug}}
+ } // end of w_local loop (WPB-aware multi-window outer)
{{{ recompile }}}
}
@@ -3556,6 +4444,52 @@ fn main(@builtin(global_invocation_id) global_id: vec3) {
}
`;
+export const divsteps_bench = `// Single-thread-per-input bench shader for \`by_divsteps\`.
+//
+// Each thread:
+// 1. Reads (f_lo, g_lo) as a vec4 from \`inputs_fg[tid]\`
+// (= f_lo.x, f_lo.y, g_lo.x, g_lo.y).
+// 2. Reads the initial delta from \`inputs_delta[tid]\`.
+// 3. Calls \`by_divsteps(&delta, f_lo, g_lo)\`.
+// 4. Writes the 8 Mat fields + the updated delta (9 i32) into \`outputs[tid]\`.
+//
+// LOOP BOUNDS
+// The only loops in the included partials are bounded by WGSL \`const\`s
+// (BY_NUM_LIMBS, BY_BATCH) or Mustache-const values (\`{{ num_words }}\`).
+// No loop is added in this entry shader. The \`if (tid >= n)\` guard is not
+// a loop bound.
+
+@group(0) @binding(0) var inputs_fg: array>;
+@group(0) @binding(1) var inputs_delta: array;
+@group(0) @binding(2) var outputs: array;
+@group(0) @binding(3) var params: vec2;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let n = params.x;
+ let tid = gid.x;
+ if (tid >= n) { return; }
+
+ let fg = inputs_fg[tid];
+ let f_lo: vec2 = vec2(fg.x, fg.y);
+ let g_lo: vec2 = vec2(fg.z, fg.w);
+ var delta: i32 = inputs_delta[tid];
+
+ let m: Mat = by_divsteps(&delta, f_lo, g_lo);
+
+ let base: u32 = tid * 9u;
+ outputs[base + 0u] = m.u;
+ outputs[base + 1u] = m.v;
+ outputs[base + 2u] = m.q;
+ outputs[base + 3u] = m.r;
+ outputs[base + 4u] = m.u_hi;
+ outputs[base + 5u] = m.v_hi;
+ outputs[base + 6u] = m.q_hi;
+ outputs[base + 7u] = m.r_hi;
+ outputs[base + 8u] = delta;
+}
+`;
+
export const extract_word_from_bytes_le = `fn extract_word_from_coord_bytes_le(
input: array,
word_idx: u32,
@@ -3609,6 +4543,140 @@ fn extract_word_from_bytes_le(
}
`;
+export const field_mul_bench_f32 = `// Field-mul micro-benchmark, f32 / 12×23-bit limbs (CIOS over f32 FMA).
+// Every thread loads one (a, b) pair, runs \`k\` chained Montgomery
+// products (a = montgomery_product_f32_unreduced(a, b), k times), applies
+// one final \`conditional_reduce_f32\`, and writes the result to
+// \`outputs\`. The host caps \`k\` at <=100 before passing it in.
+//
+// Loop bounds. The only data-dependent loop is \`for (var i = 0u; i < k; i++)\`
+// where \`k = params.y\` is a host-uniform capped at 100. The early-out
+// \`if (tid >= n)\` is a guard, not a loop bound, and \`n = params.x\` is
+// host-capped at 2^23 to keep dispatch sizes reasonable. The inner
+// \`montgomery_product_f32\` loops are bounded by the compile-time constant
+// \`NUM_LIMBS\`.
+
+// Inputs are split into two separate \`xs\`/\`ys\` arrays (one BigIntF32
+// per thread per array). This mirrors the working layout used by
+// \`testMontgomeryProductF32\` in wgsl_unit_tests.ts — packing into a
+// single \`array\` produced all-zero outputs on Dawn/Metal
+// even when the inputs read back correctly via a passthrough variant,
+// so we sidestep struct-in-array layout concerns entirely.
+@group(0) @binding(0) var xs: array;
+@group(0) @binding(1) var ys: array;
+@group(0) @binding(2) var outputs: array;
+@group(0) @binding(3) var params: vec2;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let n = params.x;
+ let k = params.y;
+ let tid = gid.x;
+ if (tid >= n) { return; }
+
+ var a = xs[tid];
+ var b = ys[tid];
+ // Chained loop calls the unreduced body — Mont is closed over inputs
+ // in [0, 2p) since all limbs stay < W. Final conditional_reduce
+ // canonicalizes the result.
+ for (var i = 0u; i < k; i = i + 1u) {
+ a = montgomery_product_f32_unreduced(&a, &b);
+ }
+ var p_final = get_p_f32();
+ outputs[tid] = conditional_reduce_f32(&a, &p_final);
+}
+`;
+
+export const field_mul_bench_u32 = `// Field-mul micro-benchmark, u32 / 20×13-bit limbs (Mitschabaude-CIOS).
+// Every thread loads one (a, b) pair, runs \`k\` chained Montgomery
+// products (a = montgomery_product(a, b) repeated k times), and writes
+// the final \`a\` back. The host caps \`k\` at <=100 before passing it in.
+//
+// Loop bounds. The only data-dependent loop is \`for (var i = 0u; i < k; i++)\`
+// where \`k = params.y\` is a host-uniform capped at 100. The early-out
+// \`if (tid >= n)\` is a guard, not a loop bound, and \`n = params.x\` is
+// host-capped at 2^23 to keep dispatch sizes reasonable. The inner
+// \`montgomery_product\` loops are bounded by the compile-time constant
+// \`NUM_WORDS\`.
+
+// Inputs are split into two separate \`xs\`/\`ys\` arrays (one BigInt per
+// thread per array). Matches the f32 path's layout for symmetry — the
+// \`array\` packing was found to produce all-zero outputs on
+// Dawn/Metal in the f32 variant of this shader; using separate arrays
+// sidesteps that concern.
+@group(0) @binding(0) var xs: array;
+@group(0) @binding(1) var ys: array;
+@group(0) @binding(2) var outputs: array;
+@group(0) @binding(3) var params: vec2;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let n = params.x;
+ let k = params.y;
+ let tid = gid.x;
+ if (tid >= n) { return; }
+
+ var a = xs[tid];
+ var b = ys[tid];
+ for (var i = 0u; i < k; i = i + 1u) {
+ a = montgomery_product(&a, &b);
+ }
+ outputs[tid] = a;
+}
+`;
+
+export const fr_inv_bench = `// Per-thread chained-inversion micro-benchmark for the field inverse.
+//
+// Each thread reads one BN254 base-field value \`a\` (in Montgomery form) and
+// runs \`k\` chained calls to the inverse function selected by the
+// \`{{ inv_fn }}\` Mustache substitution (default \`fr_inv_by\`; the legacy
+// \`fr_inv\` Pornin jumpy K=12 safegcd is the other supported variant):
+// a <- {{ inv_fn }}(a) repeated k times,
+// writing the final \`a\` back to \`outputs[tid]\`. The host caps k at 100
+// before dispatch.
+//
+// LOOP BOUNDS
+// - The only data-dependent loop in this entry shader is
+// \`for (var i = 0u; i < k; i = i + 1u)\` with \`k = params.y\` host-capped
+// at 100. Every loop in the selected inverse function (and its callees)
+// is bounded by a compile-time \`const\` — see by_inverse.template.wgsl,
+// bigint_by, fr_pow.template.wgsl, and the inner \`montgomery_product\`
+// (\`for i, j < NUM_WORDS\`).
+// - The \`if (tid >= n)\` guard is not a loop bound.
+//
+// Bind layout mirrors \`field_mul_bench_u32\`: a single inputs buffer \`xs\`
+// (no \`ys\`, since inversion is unary) and an outputs buffer. \`params.x\` is
+// \`n\` (threads), \`params.y\` is \`k\` (chain length).
+
+// get_r returns Montgomery R (= the integer 1 in Montgomery form). Each
+// cuzk entry shader that uses fr_pow / fr_inv defines its own copy.
+// fr_pow_funcs references it via fr_pow; it is not called on fr_inv_by's
+// hot path, but it must be declared here for that partial to resolve.
+fn get_r() -> BigInt {
+ var r: BigInt;
+{{{ r_limbs }}}
+ return r;
+}
+
+@group(0) @binding(0) var xs: array;
+@group(0) @binding(1) var outputs: array;
+@group(0) @binding(2) var params: vec2;
+
+@compute @workgroup_size({{ workgroup_size }})
+fn main(@builtin(global_invocation_id) gid: vec3) {
+ let n = params.x;
+ let k = params.y;
+ let tid = gid.x;
+ if (tid >= n) { return; }
+
+ var a = xs[tid];
+ for (var i = 0u; i < k; i = i + 1u) {
+ a = {{ inv_fn }}(a);
+ }
+ outputs[tid] = a;
+}
+`;
+
export const horner_reduce_bn254 = `{{> structs }}
{{> bigint_funcs }}
{{> montgomery_product_funcs }}
@@ -4267,6 +5335,1754 @@ fn main(@builtin(global_invocation_id) global_id: vec3) {
{{{ recompile }}}
}`;
+export const by_inverse = `// Bernstein-Yang safegcd inversion for the BN254 base field, WGSL port.
+//
+// This file grows over sub-steps 1.3-1.5 of the WebGPU MSM rewrite plan
+// to host the full \`fr_inv_by\` entrypoint. It contains the inner
+// \`by_divsteps\` primitive (sub-step 1.3) — a line-for-line transliteration
+// of \`Wasm9x29::divsteps\` (bernstein_yang_inverse_wasm.hpp lines 147-178) —
+// together with the apply_matrix helpers and the \`by_apply_matrix_fg\` /
+// \`by_apply_matrix_de\` streaming passes listed under LOOP BOUND DISCIPLINE.
+//
+// REFERENCES
+// - TS port (ground truth): src/msm_webgpu/cuzk/bernstein_yang.ts
+// - bigint helpers used here: src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl
+//
+// REPRESENTATIONS
+// - delta: i32 counter (mirrors the i64 in C++; only the low value
+// matters because BATCH=58 caps |delta| growth per call).
+// - f_lo, g_lo: u64 carriers held as vec2 (.x = low 32, .y = high 32).
+// All u64 ops via u64_add / u64_sub / u64_shr1 helpers.
+// - u, v, q, r: i64 matrix entries held as paired i32 (.x = low 32 unsigned
+// bit pattern, .y = high 32 signed). After BATCH divsteps
+// |entry| <= 2^58, fits in i64. All ops via i64_*_pair.
+//
+// LOOP BOUND DISCIPLINE
+// Three loops total in this file:
+// - by_divsteps inner loop: \`for ... i < BY_BATCH\` (= 58u const)
+// - by_apply_matrix_fg streaming: \`for ... i < BY_NUM_LIMBS\` (= 9u const)
+// - by_apply_matrix_de streaming: \`for ... i < BY_NUM_LIMBS\` (= 9u const)
+// Both bound constants (BY_BATCH and BY_NUM_LIMBS) are compile-time WGSL
+// \`const\`s defined in bigint_by, so the plan's "bounded loops" rule is
+// satisfied. The audit grep
+// \`grep -nE 'for *\\(' ... | grep -v -E '< [A-Z][A-Z_]*[a-z]?|< [0-9]+|...'\`
+// returns no matches.
+
+// 2x2 matrix produced by BATCH=58 divsteps. Each entry is a 64-bit signed
+// integer stored as a paired (lo: i32, hi: i32) — value = u32(lo) | (i32(hi) << 32),
+// interpreted as two's complement. The naming matches the C++ struct field
+// names (m.u, m.v, m.q, m.r) suffixed with \`_hi\` for the high half.
+struct Mat {
+ // Field order is significant: by_divsteps fills this struct with a
+ // positional constructor (four low halves, then four high halves).
+ // Low 32 bits of each entry (unsigned bit pattern carried in an i32).
+ u: i32,
+ v: i32,
+ q: i32,
+ r: i32,
+ // High 32 bits of each entry (signed, two's complement).
+ u_hi: i32,
+ v_hi: i32,
+ q_hi: i32,
+ r_hi: i32,
+}
+
+// by_divsteps: run BATCH = 58 branchy divsteps on the low 64 bits of (f, g);
+// returns the transition matrix M and updates \`*delta\`.
+//
+// Mirrors Wasm9x29::divsteps line-for-line. The branches are variable-time
+// over inputs, which is acceptable here because BN254 base-field inversion
+// operates on public values in the MSM pipeline.
+//
+// Pre: delta is the current divstep counter (signed i32 view).
+// f_lo, g_lo are the low 64 bits of f and g respectively (vec2).
+// Post: returns the matrix M = ((u, v), (q, r)) such that
+// (f_new, g_new) = M * (f_old, g_old) / 2^BATCH
+// after BATCH=58 divsteps. *delta is updated.
+//
+// The TS reference uses \`(g_lo - f_lo) & U64_MASK\` then \`>> 1\` to mimic the
+// C++ \`(u64)(g_lo - f_lo) >> 1\` semantics: u64 wrap, then unsigned shift.
+// u64_sub + u64_shr1 here is exactly that.
+fn by_divsteps(delta: ptr, f_lo_in: vec2, g_lo_in: vec2) -> Mat {
+ // Mutable working copies of the low 64 bits of f and g.
+ var f_lo: vec2 = f_lo_in;
+ var g_lo: vec2 = g_lo_in;
+ // Matrix entries as paired i32 (i64). u = 1, v = 0, q = 0, r = 1.
+ // That is the identity matrix; each pair is (.x = low word, .y = high word).
+ var u: vec2 = vec2(1, 0);
+ var v: vec2 = vec2(0, 0);
+ var q: vec2 = vec2(0, 0);
+ var r: vec2 = vec2(1, 0);
+ // Local copy of the counter; written back through the pointer after the loop.
+ var d: i32 = *delta;
+ // Fixed trip count: exactly BY_BATCH iterations regardless of the inputs.
+ for (var i: u32 = 0u; i < BY_BATCH; i = i + 1u) {
+ // Low bit of g_lo set, i.e. g is odd.
+ if (u64_low_bit(g_lo) != 0u) {
+ if (d > 0) {
+ // Every new value below is computed from the old state before any of
+ // f_lo, g_lo, u, v, q, r is overwritten.
+ // (f, g) <- (g, (g - f) >> 1) using u64 wrap-sub then unsigned >> 1.
+ let nf: vec2 = g_lo;
+ let diff: vec2 = u64_sub(g_lo, f_lo);
+ let ng: vec2 = u64_shr1(diff);
+ // (u, v, q, r) <- (q << 1, r << 1, q - u, r - v).
+ let nu: vec2 = i64_shl1_pair(q);
+ let nv: vec2 = i64_shl1_pair(r);
+ let nq: vec2 = i64_sub_pair(q, u);
+ let nr: vec2 = i64_sub_pair(r, v);
+ f_lo = nf;
+ g_lo = ng;
+ u = nu;
+ v = nv;
+ q = nq;
+ r = nr;
+ d = 1 - d;
+ } else {
+ // g <- (g + f) >> 1; q += u; r += v; u <<= 1; v <<= 1; d += 1.
+ // q and r are updated before u and v are doubled, so they pick up the
+ // old u and v.
+ let sum: vec2 = u64_add(g_lo, f_lo);
+ g_lo = u64_shr1(sum);
+ q = i64_add_pair(q, u);
+ r = i64_add_pair(r, v);
+ u = i64_shl1_pair(u);
+ v = i64_shl1_pair(v);
+ d = d + 1;
+ }
+ } else {
+ // g <- g >> 1; u <<= 1; v <<= 1; d += 1.
+ g_lo = u64_shr1(g_lo);
+ u = i64_shl1_pair(u);
+ v = i64_shl1_pair(v);
+ d = d + 1;
+ }
+ }
+ // Publish the updated counter to the caller.
+ *delta = d;
+ // Positional constructor: four low words, then four high words, matching
+ // the field order of Mat.
+ return Mat(u.x, v.x, q.x, r.x, u.y, v.y, q.y, r.y);
+}
+
+// ============================================================
+// apply_matrix helpers
+//
+// \`signed_mul_split\` accepts |a| <= 2^29 (one BY limb) and |b| <= 2^31 - 1
+// and returns (lo29, hi) with a*b = lo29 + hi * 2^29. The streaming
+// schoolbook below feeds it (m_*, x_limb) pairs whose products lie in
+// [-2^58, 2^58], well inside the helper's contract.
+//
+// Each per-limb position i computes
+// acc <- m_lo * x_i + m_hi * x_{i-1} + carry_in (+ k * p_i terms in de pass)
+// then carry_out = acc >> 29 (arithmetic), with the masked low-29 bits
+// landing at output position i - 2 (= exact >> BATCH = >> 58 = >> (2 * 29)).
+//
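+// signed_mul_split returns lo29 in [0, 2^29) and a signed hi. Adding
+// four cross-products together can push the high half beyond i32 range
+// only when |sum| exceeds ~2^31 — at the per-limb level this is bounded
+// by the four-product sum of |2^58| / 2^29 + carry, well inside i32.
+// NOTE(review): taken literally, four terms of 2^58 / 2^29 = 2^29 already
+// sum to 2^31, which is at the edge of i32 rather than well inside it.
+// by_add_mul accumulates in an i64 pair, so confirm whether this i32 bound
+// is still relied upon anywhere.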
+
+// Convert a (lo29, hi) signed-product split to a full i64 (vec2).
+//
+// The signed_mul_split helper returns (lo29, hi) where
+// value = lo29 + hi * 2^29 with lo29 in [0, 2^29)
+// To add this into an i64 accumulator we need to re-express it as a
+// (lo32, hi32) pair where
+// value = u32(lo32) + i32(hi32) * 2^32 (two's complement)
+//
+// Bit layout:
+// bits 0..28 of value = lo29 (from the lo29 half)
+// bits 29..31 of value = bits 0..2 of \`hi\`
+// bits 32..63 of value = bits 3..31 of \`hi\`, sign-extended to 32 bits
+//
+// \`hi << 29u\` on a u32 keeps only the low 3 bits of \`hi\` after shifting,
+// which yields exactly the contribution to bits 29..31. The high half is
+// \`hi >> 3u\` with signed arithmetic shift, which sign-extends to fill bit
+// 63 correctly when \`hi\` is negative.
+fn by_split_to_i64(split: vec2) -> vec2 {
+ // split.x is lo29 and split.y is hi. The shift is done on the u32 view, so
+ // only the low 3 bits of hi survive and land in bits 29..31 of the low word.
+ let lo32: u32 = u32(split.x) | (u32(split.y) << 29u);
+ // Shift on the signed value: the remaining bits of hi form the high word,
+ // sign-extended.
+ let hi32: i32 = split.y >> 3u;
+ // The low word's bit pattern is reinterpreted as i32 to fit the pair layout.
+ return vec2(i32(lo32), hi32);
+}
+
+// Add \`m_lo * x_limb\` (a signed 58-bit product) into an i64 accumulator.
+// Helper used pervasively in the streaming schoolbook below.
+fn by_add_mul(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 {
+ // (lo29, hi) split of the product m_lo * x_limb.
+ let split = signed_mul_split(m_lo, x_limb);
+ // Re-pack the split as a (lo32, hi32) pair so it can be added to acc.
+ let prod_i64 = by_split_to_i64(split);
+ // Returns the new accumulator value; acc itself is not modified.
+ return i64_add_pair(acc, prod_i64);
+}
+
+// Arithmetic right shift of an i64 (vec2