From 8f414ef918d70dd7af96614b7da343e2bb890738 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Sat, 16 May 2026 18:59:27 +0100 Subject: [PATCH 1/4] feat(bb): WebGPU field-mul micro-benchmark + Karatsuba/sos3uv3 Mont mults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a standalone WebGPU micro-benchmark page comparing three BN254 Montgomery product implementations for chained-mul throughput: - cios (u32): mitschabaude runtime-loop CIOS over 20×13-bit limbs. Baseline, ~109 ms at n=2^20, k=100. - karat (u32): recursive Karatsuba + Yuval reduction. 9 5×5 schoolbook sub-sub-products are computed independently and combined via two Karatsuba levels; reduction uses precomputed r_inv = W^-1 mod p with zero drains in the multiply phase (unsigned wrap unwinds via subsequent subtraction). ~80 ms (~28% faster than cios). - sos3uv3 (f32, reference): 22-bit f32 limbs with separate per-slot tlo/thi accumulators that break the inner-j carry chain. Single drain per outer iter via bias_split_f32_le4w. ~79 ms. The bench harness: - bench-field-mul.html is a standalone page; reads ?path=u32|f32 &n=N&k=K&validate-n=N&reps=R&variant=V from the URL. - bench-field-mul.ts runs k chained Mont mults per thread, validates the first `validate-n` outputs against a host BigInt reference, and writes timing into window.__bench. - scripts/bench-field-mul.mjs is a Playwright driver for headless invocation from the CLI (added playwright-core as devDependency). --- .../ts/dev/msm-webgpu/bench-field-mul.html | 37 + .../ts/dev/msm-webgpu/bench-field-mul.ts | 921 ++++++++++++++++++ .../msm-webgpu/scripts/bench-field-mul.mjs | 256 +++++ barretenberg/ts/dev/msm-webgpu/tsconfig.json | 10 + barretenberg/ts/package.json | 1 + .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 259 +++++ barretenberg/ts/src/msm_webgpu/cuzk/utils.ts | 29 + .../src/msm_webgpu/wgsl/_generated/shaders.ts | 606 +++++++++++- .../wgsl/bigint/bigint_f32.template.wgsl | 64 ++ .../cuzk/field_mul_bench_f32.template.wgsl | 42 + .../cuzk/field_mul_bench_u32.template.wgsl | 36 + ...t_pro_product_f32_22_sos3uv3.template.wgsl | 215 ++++ ...mont_pro_product_karat_yuval.template.wgsl | 192 ++++ .../wgsl/montgomery/mulhilo_22.wgsl | 43 + barretenberg/ts/yarn.lock | 10 + 15 files changed, 2720 insertions(+), 1 deletion(-) create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-field-mul.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts create mode 100644 barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs create mode 100644 barretenberg/ts/dev/msm-webgpu/tsconfig.json create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_f32.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_f32.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_u32.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_f32_22_sos3uv3.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_karat_yuval.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mulhilo_22.wgsl diff --git a/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html new file mode 100644 index 000000000000..df7202bc6b07 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.html @@ -0,0 +1,37 @@ + + + + + BN254 field-mul micro-benchmark + + + +

BN254 field-mul micro-benchmark

+

Query params: ?path=u32|f32|both&n=N&k=K&validate-n=N&reps=R&variant=V

+
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts new file mode 100644 index 000000000000..d4bab32d56a2 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts @@ -0,0 +1,921 @@ +/// +// Field-mul micro-benchmark page. Mounted standalone (no SRS, no MSM +// pipeline). Reads `?path=u32|f32&n=N&k=K&validate-n=N&reps=R`, generates +// random BN254 base-field pairs, runs `k` chained Montgomery products +// per thread, validates the first `validate-n` outputs against a host +// BigInt reference, and reports timing via `window.__bench`. +// +// Safety: `k` is capped at 100, `n` at 2^23, both checked before any +// dispatch. The only loop in either shader is the k-loop with this +// bound (see field_mul_bench_*.template.wgsl). + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; + +type Path = 'u32' | 'f32'; +interface SampleSummary { + reps: number; + msSamples: number[]; + msMedian: number; + msMin: number; + msMax: number; + multsPerSec: number; +} +interface PathResult { + path: Path; + validateOk: boolean; + mismatches: string[]; + timing: SampleSummary | null; +} +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { + path: Path | 'both'; + n: number; + k: number; + validateN: number; + reps: number; + } | null; + results: PathResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-field-mul] ${msg}`); +} + +const N_MAX = 1 << 23; +const K_MAX = 100; + +// Limb layouts are fixed for this bench: u32 = 20×13-bit, f32 = 12×22-bit. 
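+// (20 × 13 = 260 and 12 × 22 = 264 limb-bits, covering the 254-bit BN254
+// base-field modulus with 6 and 10 bits of slack respectively.)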
+const NUM_LIMBS_F32 = 12; +const WORD_SIZE_F32 = 22; +const NUM_LIMBS_U32 = 20; +const WORD_SIZE_U32 = 13; + +const W_U32 = 1n << BigInt(WORD_SIZE_U32); +const MASK_U32 = W_U32 - 1n; +const W_F32 = 1n << BigInt(WORD_SIZE_F32); +const MASK_F32 = W_F32 - 1n; + +function bigintToLimbsU32(v: bigint): number[] { + const limbs: number[] = new Array(NUM_LIMBS_U32); + let x = v; + for (let i = 0; i < NUM_LIMBS_U32; i++) { + limbs[i] = Number(x & MASK_U32); + x >>= BigInt(WORD_SIZE_U32); + } + return limbs; +} +function limbsU32ToBigint(limbs: ArrayLike): bigint { + let v = 0n; + for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) { + v = (v << BigInt(WORD_SIZE_U32)) | BigInt(limbs[i] >>> 0); + } + return v; +} +function bigintToLimbsF32(v: bigint): number[] { + const limbs: number[] = new Array(NUM_LIMBS_F32); + let x = v; + for (let i = 0; i < NUM_LIMBS_F32; i++) { + limbs[i] = Number(x & MASK_F32); + x >>= BigInt(WORD_SIZE_F32); + } + return limbs; +} +function limbsF32ToBigint(limbs: ArrayLike): bigint { + let v = 0n; + for (let i = NUM_LIMBS_F32 - 1; i >= 0; i--) { + v = (v << BigInt(WORD_SIZE_F32)) | BigInt(Math.round(limbs[i])); + } + return v; +} + +// Seeded LCG (Numerical Recipes constants) for reproducible pair gen. +// Math.random() is fine for input pairs (we're not testing RNG quality), +// but a deterministic stream makes failures repeatable. +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + while (true) { + let v = 0n; + for (let i = 0; i < byteLen; i++) { + v = (v << 8n) | BigInt(rng() & 0xff); + } + v &= (1n << BigInt(bitlen)) - 1n; + if (v < p) return v; + } +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// Host BigInt reference for `k` chained Mont products. Inputs are in +// Mont form (x_m = x * R mod p). Each Mont multiply returns +// (x_m * y_m * R^-1) mod p. After `k` rounds starting from a_m, the +// result is a_m * b_m^k * (R^-1)^k mod p (in Mont form). 
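+// Decoding that back to canonical form (one extra multiply by R^-1 mod p)
+// yields a * b^k mod p, i.e. the chain computes the Mont encoding of a * b^k.
+// Each call does k BigInt multiply-and-reduce rounds on the host, so the page
+// defaults validate-n to min(64, n) to keep this cheap even at large n.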
+function chainedMontReference( + aMont: bigint, + bMont: bigint, + k: number, + Rinv: bigint, + p: bigint, +): bigint { + let acc = aMont; + for (let i = 0; i < k; i++) { + acc = (acc * bMont * Rinv) % p; + } + return acc; +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function runPath( + device: GPUDevice, + sm: ShaderManager, + path: Path, + n: number, + k: number, + validateN: number, + reps: number, +): Promise { + // Variant selects which Mont algorithm to bench: + // u32 path: 'cios' (default, mitschabaude runtime-loop CIOS) or 'karat' + // (recursive Karatsuba + Yuval reduction). + // f32 path: 'sos3uv3' only (separate per-slot tlo/thi f32 accumulators). + const qp = new URLSearchParams(window.location.search); + const variant = (qp.get('variant') ?? (path === 'u32' ? 'cios' : 'sos3uv3')) as + | 'cios' + | 'karat' + | 'sos3uv3'; + + log('info', `path=${path}: building pairs (n=${n}, k=${k}, validate-n=${validateN}, reps=${reps})`); + const p = BN254_CURVE_CONFIG.baseFieldModulus; + const wordSize = path === 'u32' ? WORD_SIZE_U32 : WORD_SIZE_F32; + const numLimbs = path === 'u32' ? NUM_LIMBS_U32 : NUM_LIMBS_F32; + const params = compute_misc_params(p, wordSize); + if (params.num_words !== numLimbs) { + throw new Error(`expected num_words=${numLimbs} for path=${path}, got ${params.num_words}`); + } + const R = params.r; + const Rinv = params.rinv; + if ((R * Rinv) % p !== 1n) { + throw new Error(`R * Rinv mod p != 1 for path=${path}`); + } + + // Generate random a_canonical, b_canonical pairs (CPU side BigInts). + const rng = makeRng(0xc0ffee + (path === 'u32' ? 0 : 1)); + const aCanonical: bigint[] = new Array(n); + const bCanonical: bigint[] = new Array(n); + for (let i = 0; i < n; i++) { + aCanonical[i] = randomBelow(p, rng); + bCanonical[i] = randomBelow(p, rng); + } + + // Encode in the appropriate Mont ring. + const aMont: bigint[] = aCanonical.map(x => (x * R) % p); + const bMont: bigint[] = bCanonical.map(x => (x * R) % p); + + // CPU reference for the first validate-n pairs. 
+ log('info', `path=${path}: computing host reference for ${validateN} pairs`); + const expected: bigint[] = new Array(validateN); + for (let i = 0; i < validateN; i++) { + expected[i] = chainedMontReference(aMont[i], bMont[i], k, Rinv, p); + } + + // Pack separate `xs` / `ys` buffers — one BigInt per thread per buffer. + const bytesPerLimbArray = numLimbs * 4; + const xsBytes = new ArrayBuffer(n * bytesPerLimbArray); + const ysBytes = new ArrayBuffer(n * bytesPerLimbArray); + + if (path === 'u32') { + const xv = new Uint32Array(xsBytes); + const yv = new Uint32Array(ysBytes); + for (let i = 0; i < n; i++) { + const aLimbs = bigintToLimbsU32(aMont[i]); + const bLimbs = bigintToLimbsU32(bMont[i]); + const off = i * NUM_LIMBS_U32; + for (let j = 0; j < NUM_LIMBS_U32; j++) xv[off + j] = aLimbs[j]; + for (let j = 0; j < NUM_LIMBS_U32; j++) yv[off + j] = bLimbs[j]; + } + } else { + const xv = new Float32Array(xsBytes); + const yv = new Float32Array(ysBytes); + for (let i = 0; i < n; i++) { + const aLimbs = bigintToLimbsF32(aMont[i]); + const bLimbs = bigintToLimbsF32(bMont[i]); + const off = i * NUM_LIMBS_F32; + for (let j = 0; j < NUM_LIMBS_F32; j++) xv[off + j] = aLimbs[j]; + for (let j = 0; j < NUM_LIMBS_F32; j++) yv[off + j] = bLimbs[j]; + } + } + + const WORKGROUP_SIZE = 64; + let code: string; + if (path === 'u32') { + if (variant !== 'cios' && variant !== 'karat') { + throw new Error(`u32 path supports variant 'cios' or 'karat', got '${variant}'`); + } + code = sm.gen_field_mul_bench_u32_shader(WORKGROUP_SIZE, variant); + } else { + if (variant !== 'sos3uv3') { + throw new Error(`f32 path supports variant 'sos3uv3', got '${variant}'`); + } + code = sm.gen_field_mul_bench_f32_shader(WORKGROUP_SIZE, variant); + } + const cacheKey = `field-mul-bench-${path}-wg${WORKGROUP_SIZE}`; + log('info', `path=${path}: compiling shader (${code.length} chars)`); + // Stash the rendered shader on window so external tooling can dump it + // for post-mortem analysis when validation fails. + (window as unknown as Record)[`__shader_${path}`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + // Buffers. + const xsBuf = device.createBuffer({ + size: xsBytes.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(xsBuf, 0, xsBytes); + const ysBuf = device.createBuffer({ + size: ysBytes.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(ysBuf, 0, ysBytes); + const outBytes = n * numLimbs * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + const uniformBytes = new ArrayBuffer(16); + const uniformView = new Uint32Array(uniformBytes); + uniformView[0] = n; + uniformView[1] = k; + const uniformBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(uniformBuf, 0, uniformBytes); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: xsBuf } }, + { binding: 1, resource: { buffer: ysBuf } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: uniformBuf } }, + ], + }); + + const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE); + log('info', `path=${path}: dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`); + + // Warmup pass — issued and awaited before the timed reps. 
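+  // The first dispatch typically absorbs one-time driver warm-up costs (lazy
+  // pipeline specialisation, first-touch buffer residency), so the timed reps
+  // further down measure steady-state throughput.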
+ { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', `path=${path}: warmup OK`); + + // Validation pass — read back outputs. + const stagingBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + } + const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0); + stagingBuf.unmap(); + stagingBuf.destroy(); + + // Compare first validate-n outputs. + let mismatches: string[] = []; + let validateOk = true; + if (path === 'u32') { + const outU32 = new Uint32Array(outBytesCopy); + for (let i = 0; i < validateN; i++) { + const limbs = outU32.subarray(i * NUM_LIMBS_U32, (i + 1) * NUM_LIMBS_U32); + const got = limbsU32ToBigint(limbs); + if (got !== expected[i]) { + validateOk = false; + if (mismatches.length < 5) { + mismatches.push( + `pair[${i}]: a_can=0x${aCanonical[i].toString(16)} b_can=0x${bCanonical[i].toString(16)}\n` + + ` expected: 0x${expected[i].toString(16)}\n` + + ` actual: 0x${got.toString(16)}\n` + + ` expected_limbs: [${bigintToLimbsU32(expected[i]).join(', ')}]\n` + + ` actual_limbs: [${Array.from(limbs).join(', ')}]`, + ); + } + } + } + } else { + const outF32 = new Float32Array(outBytesCopy); + for (let i = 0; i < validateN; i++) { + const limbs = outF32.subarray(i * NUM_LIMBS_F32, (i + 1) * NUM_LIMBS_F32); + const got = limbsF32ToBigint(limbs); + if (got !== expected[i]) { + validateOk = false; + if (mismatches.length < 5) { + mismatches.push( + `pair[${i}]: a_can=0x${aCanonical[i].toString(16)} b_can=0x${bCanonical[i].toString(16)}\n` + + ` expected: 0x${expected[i].toString(16)}\n` + + ` actual: 0x${got.toString(16)}\n` + + ` expected_limbs: [${bigintToLimbsF32(expected[i]).join(', ')}]\n` + + ` actual_limbs: [${Array.from(limbs).map(x => Math.round(x)).join(', ')}]`, + ); + } + } + } + } + + if (!validateOk) { + log('err', `path=${path}: VALIDATION FAILED (${mismatches.length}/${validateN} mismatches shown of total)`); + for (const m of mismatches) log('err', m); + // Diagnostic dump: log the first pair's input limbs as they were + // packed into the GPU buffer. If all outputs are zero this confirms + // the shader is not writing to the output buffer (vs. writing the + // wrong value). 
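+    // Comparing the limbs as packed against limbs re-derived from aMont[0]
+    // (both logged below) additionally separates a host-side packing bug
+    // from a shader-side bug.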
+ const inLimbsA: number[] = []; + const inLimbsB: number[] = []; + if (path === 'u32') { + const xv = new Uint32Array(xsBytes, 0, NUM_LIMBS_U32); + const yv = new Uint32Array(ysBytes, 0, NUM_LIMBS_U32); + for (let j = 0; j < NUM_LIMBS_U32; j++) inLimbsA.push(xv[j]); + for (let j = 0; j < NUM_LIMBS_U32; j++) inLimbsB.push(yv[j]); + } else { + const xv = new Float32Array(xsBytes, 0, NUM_LIMBS_F32); + const yv = new Float32Array(ysBytes, 0, NUM_LIMBS_F32); + for (let j = 0; j < NUM_LIMBS_F32; j++) inLimbsA.push(xv[j]); + for (let j = 0; j < NUM_LIMBS_F32; j++) inLimbsB.push(yv[j]); + } + log('err', `path=${path}: pair[0] input limbs as packed: a=[${inLimbsA.join(', ')}] b=[${inLimbsB.join(', ')}]`); + log('err', `path=${path}: pair[0] input limbs expected from canonical: a=[${(path === 'u32' ? bigintToLimbsU32(aMont[0]) : bigintToLimbsF32(aMont[0])).join(', ')}]`); + xsBuf.destroy(); + ysBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + return { path, validateOk: false, mismatches, timing: null }; + } + log('ok', `path=${path}: VALIDATION OK (${validateN} pairs)`); + + // Timed reps. Each rep = one dispatch + queue.onSubmittedWorkDone wait. + const msSamples: number[] = []; + for (let rep = 0; rep < reps; rep++) { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const t1 = performance.now(); + msSamples.push(t1 - t0); + } + const msMed = median(msSamples); + const msMin = Math.min(...msSamples); + const msMax = Math.max(...msSamples); + const totalMults = n * k; + const multsPerSec = totalMults / (msMed / 1000); + log( + 'ok', + `path=${path}: timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms mults/s=${multsPerSec.toExponential(3)} (n*k=${totalMults.toLocaleString()})`, + ); + + xsBuf.destroy(); + ysBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + + return { + path, + validateOk: true, + mismatches: [], + timing: { reps, msSamples, msMedian: msMed, msMin, msMax, multsPerSec }, + }; +} + +function parseParams(): { + path: Path | 'both'; + n: number; + k: number; + validateN: number; + reps: number; + debug: string | null; +} { + const qp = new URLSearchParams(window.location.search); + const pathStr = qp.get('path') ?? 'both'; + if (pathStr !== 'u32' && pathStr !== 'f32' && pathStr !== 'both') { + throw new Error(`?path must be u32|f32|both, got ${pathStr}`); + } + const n = parseInt(qp.get('n') ?? '64', 10); + const k = parseInt(qp.get('k') ?? '1', 10); + const validateN = parseInt(qp.get('validate-n') ?? String(Math.min(64, n)), 10); + const reps = parseInt(qp.get('reps') ?? 
'3', 10); + const debug = qp.get('debug'); + if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`); + } + if (!Number.isFinite(k) || k <= 0 || k > K_MAX) { + throw new Error(`?k must be in (0, ${K_MAX}], got ${qp.get('k')}`); + } + if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`); + } + if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`); + } + return { path: pathStr as Path | 'both', n, k, validateN, reps, debug }; +} + +async function runDebugMulhiloF32(device: GPUDevice, _sm: ShaderManager): Promise { + log('info', `[debug=mulhilo] testing mulhilo on hard-coded values`); + // Test cases: hand-picked products that exercise specific bit patterns. + const cases = [ + { a: 1443728, b: 418697 }, + { a: 1, b: 1 }, + { a: 8388607, b: 8388607 }, + { a: 4194304, b: 2 }, + { a: 100, b: 200 }, + ]; + const inputBytes = new ArrayBuffer(cases.length * 8); + const iv = new Float32Array(inputBytes); + for (let i = 0; i < cases.length; i++) { + iv[i * 2] = cases[i].a; + iv[i * 2 + 1] = cases[i].b; + } + const outBytes = cases.length * 2 * 4; + const inBuf = device.createBuffer({ size: inputBytes.byteLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST }); + device.queue.writeBuffer(inBuf, 0, inputBytes); + const outBuf = device.createBuffer({ size: outBytes, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC }); + + // Just inline the mulhilo constants and function directly. + const code = ` +const BIAS: f32 = 70368744177664.0; +const W: f32 = 8388608.0; +const W_INV: f32 = 1.1920928955078125e-7; + +fn mulhilo(a: f32, b: f32) -> vec2 { + let q = fma(a, b, BIAS) - BIAS; + let lo0 = fma(a, b, -q); + let underflow = step(lo0, -0.5); + let hi = q * W_INV - underflow; + let lo = lo0 + underflow * W; + return vec2(hi, lo); +} + +@group(0) @binding(0) var ins: array>; +@group(0) @binding(1) var outs: array>; + +@compute @workgroup_size(1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = arrayLength(&ins); + if (gid.x >= n) { return; } + let v = ins[gid.x]; + outs[gid.x] = mulhilo(v.x, v.y); +} +`; + const module = device.createShaderModule({ code }); + await module.getCompilationInfo(); + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + const bg = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inBuf } }, + { binding: 1, resource: { buffer: outBuf } }, + ], + }); + const enc = device.createCommandEncoder(); + const ps = enc.beginComputePass(); + ps.setPipeline(pipeline); + ps.setBindGroup(0, bg); + ps.dispatchWorkgroups(cases.length, 1, 1); + ps.end(); + const stagingBuf = device.createBuffer({ size: outBytes, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); + enc.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + const out = new Float32Array(stagingBuf.getMappedRange().slice(0)); + 
stagingBuf.unmap(); + stagingBuf.destroy(); + inBuf.destroy(); + outBuf.destroy(); + + for (let i = 0; i < cases.length; i++) { + const { a, b } = cases[i]; + const prod = BigInt(a) * BigInt(b); + const W = 8388608n; + const expHi = prod / W; + const expLo = prod % W; + const gotHi = Math.round(out[i * 2]); + const gotLo = Math.round(out[i * 2 + 1]); + const ok = (gotHi === Number(expHi)) && (gotLo === Number(expLo)); + log(ok ? 'ok' : 'err', `[debug] mulhilo(${a}, ${b}) = (hi=${gotHi}, lo=${gotLo}); expected (hi=${expHi}, lo=${expLo}) ${ok ? 'OK' : 'WRONG'}`); + } +} + +async function runDebugF32(device: GPUDevice, sm: ShaderManager, debugTag: string): Promise { + log('info', `[debug=${debugTag}] running f32 Mont debug shader`); + const p = BN254_CURVE_CONFIG.baseFieldModulus; + const params_f32 = compute_misc_params(p, WORD_SIZE_F32); + const R = params_f32.r; + const Rinv = params_f32.rinv; + log('info', `[debug] n0_f32=${params_f32.n0.toString()} num_words=${params_f32.num_words}`); + + // Pair: a = 1, b = 1 (canonical). aMont = R mod p, bMont = R mod p. + // Mont(R, R, R^-1) = R^2 * R^-1 = R = aMont. So expected output limbs == aMont limbs. + const aMont = R % p; + const bMont = R % p; + const aLimbs = bigintToLimbsF32(aMont); + const bLimbs = bigintToLimbsF32(bMont); + const xsBytes = new ArrayBuffer(NUM_LIMBS_F32 * 4); + const ysBytes = new ArrayBuffer(NUM_LIMBS_F32 * 4); + const xv = new Float32Array(xsBytes); + const yv = new Float32Array(ysBytes); + for (let j = 0; j < NUM_LIMBS_F32; j++) xv[j] = aLimbs[j]; + for (let j = 0; j < NUM_LIMBS_F32; j++) yv[j] = bLimbs[j]; + + // Capture intermediates: 64 f32 slots. + // [0..11] = s[0..11] AFTER i=0 outer iter + // [12..15] = (xy0_lo, xy0_hi, sum0, qi) at i=0 + // [16..19] = (qp0_lo, qp0_hi, c_lo_init, c_hi_init) at i=0 + // [20..31] = output limbs (after full montgomery_product) + // [32..33] = direct mulhilo(1443728.0, 418697.0).{x,y} + // [34..35] = direct mulhilo(sum0_s.y, N0).{x,y} + // [36..39] = (sum0_s.x, sum0_s.y, N0 echo, x.limbs[0] echo) + // [40..41] = bias_split_f32(1443728.0).{x,y} + // [42..43] = bias_split_f32(sum0).{x,y} // sum0 is the actual variable + // [44..47] = (bias_split_f32(rv).{x,y}, bias_split_f32(xy0.y).{x,y}) + // [48..63] = j=11 intermediates (xyj.{x,y}, qpj.{x,y}, t1, t1_s.{x,y}, t2, t2_s.{x,y}, t3, t3_s.{x,y}, c_lo_before, c_lo_after) + const DEBUG_SLOTS = 64; + const xsBuf = device.createBuffer({ + size: xsBytes.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(xsBuf, 0, xsBytes); + const ysBuf = device.createBuffer({ + size: ysBytes.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(ysBuf, 0, ysBytes); + const dbgBuf = device.createBuffer({ + size: DEBUG_SLOTS * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + + // Use ShaderManager's helper bundle (mulhilo + bigint_f32 + montgomery_product_f32). + const helpers = sm.gen_montgomery_product_f32_shader(); + // Append a debug entry point that mirrors the Mont algorithm but captures + // per-position state for i=0 and finally writes the full Mont output. + const debugEntry = ` +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var ys: array; +@group(0) @binding(2) var dbg: array; + +@compute @workgroup_size(1) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x != 0u) { return; } + var x = xs[0]; + var y = ys[0]; + + // Run Mont. 
Capture intermediates by inlining the first outer iteration. + var s: BigIntF32; + for (var k = 0u; k < NUM_LIMBS; k ++) { s.limbs[k] = 0.0; } + var pp = get_p_f32(); + + // i = 0 + let xy0 = mulhilo(x.limbs[0], y.limbs[0]); + let sum0 = s.limbs[0] + xy0.y; + let sum0_s = bias_split_f32(sum0); + let qi = mulhilo(sum0_s.y, N0).y; + let qp0 = mulhilo(qi, pp.limbs[0]); + let lo_cancel = sum0_s.y + qp0.y; + let c_small = lo_cancel * W_INV + sum0_s.x; + let hi_pair = xy0.x + qp0.x; + let carry_full = hi_pair + c_small; + let carry_s = bias_split_f32(carry_full); + var c_hi = carry_s.x; + var c_lo = carry_s.y; + + dbg[12] = xy0.y; + dbg[13] = xy0.x; + dbg[14] = sum0; + dbg[15] = qi; + dbg[16] = qp0.y; + dbg[17] = qp0.x; + dbg[18] = c_lo; + dbg[19] = c_hi; + + // Direct test: mulhilo(1443728.0, 418697.0).y. Should be 1489936. + let direct_test = mulhilo(1443728.0, 418697.0); + dbg[32] = direct_test.x; // hi + dbg[33] = direct_test.y; // lo + // Also test with sum0_s.y and N0 directly (no Mont context). + let direct_test2 = mulhilo(sum0_s.y, N0); + dbg[34] = direct_test2.x; + dbg[35] = direct_test2.y; + dbg[36] = sum0_s.x; + dbg[37] = sum0_s.y; + dbg[38] = N0; + dbg[39] = x.limbs[0]; + + // Direct call to bias_split_f32 with a hard-coded constant. + let bs_const = bias_split_f32(1443728.0); + dbg[40] = bs_const.x; + dbg[41] = bs_const.y; + // Direct call to bias_split_f32 with the actual sum0 value. + let bs_var = bias_split_f32(sum0); + dbg[42] = bs_var.x; + dbg[43] = bs_var.y; + // Direct call with a runtime-derived variable that equals 1443728. + var rv: f32 = 1443728.0; + let bs_rv = bias_split_f32(rv); + dbg[44] = bs_rv.x; + dbg[45] = bs_rv.y; + // Direct call with xy0.y (the mulhilo result, runtime variable). + let bs_xy = bias_split_f32(xy0.y); + dbg[46] = bs_xy.x; + dbg[47] = bs_xy.y; + + + for (var j = 1u; j < NUM_LIMBS; j ++) { + let xyj = mulhilo(x.limbs[0], y.limbs[j]); + let qpj = mulhilo(qi, pp.limbs[j]); + let t1 = s.limbs[j] + xyj.y; + let t1_s = bias_split_f32(t1); + let t2 = t1_s.y + qpj.y; + let t2_s = bias_split_f32(t2); + let c_lo_before = c_lo; + let t3 = t2_s.y + c_lo; + let t3_s = bias_split_f32(t3); + if (j == 11u) { + dbg[60] = c_lo_before; + dbg[61] = t3_s.y; + // Direct re-check: what does bias_split_f32(5290618.0) return here? + let dt = bias_split_f32(5290618.0); + dbg[62] = dt.x; + dbg[63] = dt.y; + } + s.limbs[j - 1u] = t3_s.y; + let sum_overflow = t1_s.x + t2_s.x + t3_s.x + c_hi; + let nc1 = xyj.x + qpj.x; + let nc1_s = bias_split_f32(nc1); + let nc2 = nc1_s.y + sum_overflow; + let nc2_s = bias_split_f32(nc2); + c_hi = nc1_s.x + nc2_s.x; + c_lo = nc2_s.y; + if (j == 11u) { + // Capture j=11 intermediates. + dbg[48] = xyj.x; + dbg[49] = xyj.y; + dbg[50] = qpj.x; + dbg[51] = qpj.y; + dbg[52] = t1; + dbg[53] = t1_s.y; + dbg[54] = t2; + dbg[55] = t2_s.y; + dbg[56] = t3; + dbg[57] = t3_s.x; + dbg[58] = t3_s.y; + dbg[59] = c_lo; // this is c_lo AFTER update (since we capture after assignment) + } + } + s.limbs[NUM_LIMBS - 1u] = fma(c_hi, W, c_lo); + + for (var k = 0u; k < NUM_LIMBS; k ++) { dbg[k] = s.limbs[k]; } + + // Now call the full Mont function to compare. 
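+  // Since a = b = Mont(1) = R mod p and Mont(R, R) = R * R * R^-1 = R (mod p),
+  // both the inlined limbs above and full.limbs below are expected to match
+  // the input limbs (up to one conditional subtraction of p).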
+ var x2 = xs[0]; + var y2 = ys[0]; + let full = montgomery_product_f32(&x2, &y2); + for (var k = 0u; k < NUM_LIMBS; k ++) { dbg[20u + k] = full.limbs[k]; } +} +`; + const code = `${helpers}\n${debugEntry}`; + const module = device.createShaderModule({ code }); + const ci = await module.getCompilationInfo(); + let hasErr = false; + for (const m of ci.messages) { + if (m.type === 'error') { + console.error(`[debug shader] error: ${m.message} (line ${m.lineNum})`); + hasErr = true; + } else { + console.warn(`[debug shader] ${m.type}: ${m.message} (line ${m.lineNum})`); + } + } + if (hasErr) throw new Error('debug shader compile failed'); + + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + const bg = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: xsBuf } }, + { binding: 1, resource: { buffer: ysBuf } }, + { binding: 2, resource: { buffer: dbgBuf } }, + ], + }); + + const enc = device.createCommandEncoder(); + const ps = enc.beginComputePass(); + ps.setPipeline(pipeline); + ps.setBindGroup(0, bg); + ps.dispatchWorkgroups(1, 1, 1); + ps.end(); + const stagingBuf = device.createBuffer({ + size: DEBUG_SLOTS * 4, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + enc.copyBufferToBuffer(dbgBuf, 0, stagingBuf, 0, DEBUG_SLOTS * 4); + device.queue.submit([enc.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + const view = new Float32Array(stagingBuf.getMappedRange().slice(0)); + stagingBuf.unmap(); + stagingBuf.destroy(); + + const fmtArr = (arr: Float32Array, start: number, end: number) => + Array.from(arr.subarray(start, end)) + .map(v => Math.round(v)) + .join(', '); + + log('info', `[debug] input a (Mont(1)): [${aLimbs.join(', ')}]`); + log('info', `[debug] input b (Mont(1)): [${bLimbs.join(', ')}]`); + log('info', `[debug] inline-i0 s[0..11] = [${fmtArr(view, 0, 12)}]`); + log('info', `[debug] inline-i0 xy0=(lo=${Math.round(view[12])}, hi=${Math.round(view[13])}) sum0=${Math.round(view[14])} qi=${Math.round(view[15])} qp0=(lo=${Math.round(view[16])}, hi=${Math.round(view[17])}) c_lo=${Math.round(view[18])} c_hi=${Math.round(view[19])}`); + log('info', `[debug] mont_full output = [${fmtArr(view, 20, 32)}]`); + log('info', `[debug] direct mulhilo(1443728, 418697) = (hi=${Math.round(view[32])}, lo=${Math.round(view[33])})`); + log('info', `[debug] direct mulhilo(sum0_s.y, N0) = (hi=${Math.round(view[34])}, lo=${Math.round(view[35])})`); + log('info', `[debug] sum0_s=(hi=${Math.round(view[36])}, lo=${Math.round(view[37])}) N0_echo=${Math.round(view[38])} x.limbs[0]_echo=${Math.round(view[39])}`); + log('info', `[debug] bias_split_f32(1443728.0) = (hi=${Math.round(view[40])}, lo=${Math.round(view[41])})`); + log('info', `[debug] bias_split_f32(sum0) = (hi=${Math.round(view[42])}, lo=${Math.round(view[43])})`); + log('info', `[debug] bias_split_f32(var rv=1443728.0) = (hi=${Math.round(view[44])}, lo=${Math.round(view[45])})`); + log('info', `[debug] bias_split_f32(xy0.y) = (hi=${Math.round(view[46])}, lo=${Math.round(view[47])})`); 
+ log('info', `[debug] j=11: xyj=(hi=${Math.round(view[48])}, lo=${Math.round(view[49])}) qpj=(hi=${Math.round(view[50])}, lo=${Math.round(view[51])})`); + log('info', `[debug] j=11: t1=${Math.round(view[52])} t1_s.y=${Math.round(view[53])} t2=${Math.round(view[54])} t2_s.y=${Math.round(view[55])}`); + log('info', `[debug] j=11: t3=${Math.round(view[56])} t3_s.x=${Math.round(view[57])} t3_s.y=${Math.round(view[58])}`); + log('info', `[debug] j=11: c_lo_before=${Math.round(view[60])} t3_s.y_recheck=${Math.round(view[61])}`); + log('info', `[debug] j=11: bias_split_f32(5290618.0) direct = (hi=${Math.round(view[62])}, lo=${Math.round(view[63])})`); + log('info', `[debug] j=11: c_lo_after = ${Math.round(view[59])}`); + + xsBuf.destroy(); + ysBuf.destroy(); + dbgBuf.destroy(); +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log('info', `params: path=${params.path} n=${params.n} k=${params.k} validate-n=${params.validateN} reps=${params.reps}`); + + benchState.state = 'running'; + const device = await get_device(); + log('info', `WebGPU device acquired`); + + // ShaderManager is keyed on chunk_size / input_size for the MSM + // pipeline; for the micro-bench we only need its Mont-constant + // pre-computation, so values are arbitrary. + const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false); + + if (params.debug) { + if (params.debug === 'mulhilo') { + await runDebugMulhiloF32(device, sm); + } else { + await runDebugF32(device, sm, params.debug); + } + benchState.state = 'done'; + log('ok', `[debug] done`); + return; + } + + const paths: Path[] = params.path === 'both' ? ['u32', 'f32'] : [params.path]; + for (const path of paths) { + const result = await runPath(device, sm, path, params.n, params.k, params.validateN, params.reps); + benchState.results.push(result); + if (!result.validateOk) { + // Surface the failure but continue with the other path so the + // caller can see both results in one shot if requested. + log('err', `path=${path} failed validation — stopping path traversal`); + break; + } + } + + benchState.state = 'done'; + log('ok', `bench done: ${benchState.results.length} paths`); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main().catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; +}); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs new file mode 100644 index 000000000000..d10b377d4255 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs @@ -0,0 +1,256 @@ +#!/usr/bin/env node +// Field-mul micro-benchmark driver. Launches Playwright Chromium with +// WebGPU enabled, navigates to the standalone bench-field-mul.html page, +// waits for `window.__bench.state === 'done' | 'error'`, prints results. +// +// Host-side safety: every flag value is validated before any browser +// launch. The page itself also caps `n` at 2^23 and `k` at 100; this +// script rejects anything outside those bounds before navigation. 
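+//
+// Example (with the Vite dev server running on localhost:5173):
+//   node dev/msm-webgpu/scripts/bench-field-mul.mjs --path u32 --variant karat --n 1048576 --k 100 --reps 5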
+ +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-field-mul.html'; +const N_MAX = 1 << 23; +const K_MAX = 100; + +const { values: argv } = parseArgs({ + options: { + path: { type: 'string', default: 'both' }, + n: { type: 'string', default: '64' }, + k: { type: 'string', default: '1' }, + 'validate-n': { type: 'string' }, + reps: { type: 'string', default: '3' }, + url: { type: 'string', default: DEFAULT_URL_BASE }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '60' }, + json: { type: 'boolean', default: false }, + debug: { type: 'string' }, + variant: { type: 'string', default: 'cios' }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `Field-mul WebGPU micro-benchmark driver + +Usage: + node dev/msm-webgpu/scripts/bench-field-mul.mjs [options] + +Options: + --path u32|f32|both (default both) + --n N (default 64, max ${N_MAX}) + --k K (default 1, max ${K_MAX}) + --validate-n N (default min(64, n)) + --reps R (default 3) + --url URL (default ${DEFAULT_URL_BASE}) + --headed Run with visible browser window + --timeout SECS Bench page completion timeout (default 60) + --json Machine-readable JSON only output + --variant V u32: cios|karat. f32: sos3uv3. (default cios) + --help Show this help +`, + ); + process.exit(0); +} + +const path = String(argv.path); +if (path !== 'u32' && path !== 'f32' && path !== 'both') { + process.stderr.write(`error: --path must be u32|f32|both, got "${path}"\n`); + process.exit(2); +} +const n = parseInt(String(argv.n), 10); +if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + process.stderr.write(`error: --n must be in (0, ${N_MAX}], got ${argv.n}\n`); + process.exit(2); +} +const k = parseInt(String(argv.k), 10); +if (!Number.isFinite(k) || k <= 0 || k > K_MAX) { + process.stderr.write(`error: --k must be in (0, ${K_MAX}], got ${argv.k}\n`); + process.exit(2); +} +const validateN = + argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n); +if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + process.stderr.write(`error: --validate-n must be in [0, n], got ${argv['validate-n']}\n`); + process.exit(2); +} +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + process.stderr.write(`error: --reps must be in (0, 100], got ${argv.reps}\n`); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const jsonOnly = Boolean(argv.json); +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[bench-field-mul] ${msg}\n`); +} +function out(msg) { + if (!jsonOnly) process.stdout.write(`${msg}\n`); +} + +async function reachable(targetUrl) { + // Vite returns 404 for the bare `/` when there's no index at the root, + // so hit the actual bench page URL (GET; HEAD is sometimes blocked by + // middleware too). Any response — even a non-200 — means the server is + // up; a network/connect-refused error throws. 
+ try { + const res = await fetch(targetUrl, { method: 'GET' }); + return res.status >= 200 && res.status < 500; + } catch { + return false; + } +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +async function tryLaunch(headless) { + return chromium.launch({ headless, args: CHROMIUM_ARGS }); +} + +function buildUrl() { + const u = new URL(baseUrl); + u.searchParams.set('path', path); + u.searchParams.set('n', String(n)); + u.searchParams.set('k', String(k)); + u.searchParams.set('validate-n', String(validateN)); + u.searchParams.set('reps', String(reps)); + u.searchParams.set('variant', String(argv.variant)); + if (argv.debug !== undefined) { + u.searchParams.set('debug', String(argv.debug)); + } + return u.toString(); +} + +async function runOnce(headless) { + let browser; + try { + browser = await tryLaunch(headless); + const context = await browser.newContext({ + viewport: { width: 900, height: 600 }, + permissions: [], + bypassCSP: false, + }); + const page = await context.newPage(); + page.on('console', msg => { + const txt = msg.text(); + if (!txt.startsWith('[vite]')) { + err(`[page:${msg.type()}] ${txt}`); + } + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + const navUrl = buildUrl(); + err(`navigating to ${navUrl} (headless=${headless})`); + await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + const hasWebGpu = await page.evaluate(() => 'gpu' in navigator); + if (!hasWebGpu) { + throw new Error('navigator.gpu missing in this Chromium instance'); + } + + err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`); + const t0 = Date.now(); + await page.waitForFunction( + () => window.__bench?.state === 'done' || window.__bench?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const elapsed = (Date.now() - t0) / 1000; + err(`bench reached terminal state in ${elapsed.toFixed(1)}s`); + + const result = await page.evaluate(() => { + const b = window.__bench; + return { + state: b.state, + params: b.params, + results: JSON.parse(JSON.stringify(b.results)), + error: b.error, + log: b.log.slice(), + }; + }); + result.elapsedSec = elapsed; + return result; + } finally { + try { + if (browser) await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +(async () => { + if (!(await reachable(baseUrl))) { + err(`dev server not reachable at ${new URL(baseUrl).origin}`); + err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`); + process.exit(3); + } + + let result; + try { + result = await runOnce(!headed); + } catch (e) { + err(`run failed: ${e.message}`); + if (!headed) { + err('retrying in headed mode'); + try { + result = await runOnce(false); + } catch (e2) { + err(`headed retry also failed: ${e2.message}`); + process.exit(4); + } + } else { + process.exit(4); + } + } + + if (result.state === 'error') { + err(`page reported error: ${result.error}`); + if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + process.exit(5); + } + + if (!jsonOnly) { + out(`## field-mul micro-benchmark`); + out(''); + out(`- params: ${JSON.stringify(result.params)}`); + out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`); + out(''); + out('| path | validate ok | reps | median ms | min ms | max ms | mults/s |'); + 
out('|------|-------------|-----:|----------:|-------:|-------:|------------:|'); + for (const r of result.results) { + const t = r.timing; + const okStr = r.validateOk ? 'OK' : 'FAIL'; + if (t) { + out( + `| ${r.path.padEnd(4)} | ${okStr.padEnd(11)} | ${String(t.reps).padStart(4)} | ` + + `${t.msMedian.toFixed(3).padStart(9)} | ${t.msMin.toFixed(3).padStart(6)} | ${t.msMax.toFixed(3).padStart(6)} | ${t.multsPerSec.toExponential(3).padStart(11)} |`, + ); + } else { + out(`| ${r.path.padEnd(4)} | ${okStr.padEnd(11)} | (no timing — validation failed) |`); + for (const m of r.mismatches) { + out(` ${m.replace(/\n/g, '\n ')}`); + } + } + } + out(''); + } + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + + const anyFailure = result.results.some(r => !r.validateOk); + process.exit(anyFailure ? 1 : 0); +})().catch(e => { + err(`unexpected: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/tsconfig.json b/barretenberg/ts/dev/msm-webgpu/tsconfig.json new file mode 100644 index 000000000000..edf718d1c43b --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../tsconfig.json", + "compilerOptions": { + "rootDir": ".", + "composite": false, + "ignoreDeprecations": "6.0", + "types": ["@webgpu/types", "vite/client"] + }, + "include": ["**/*.ts"] +} diff --git a/barretenberg/ts/package.json b/barretenberg/ts/package.json index f3f3d4a1e5bf..5fd3c8371a14 100644 --- a/barretenberg/ts/package.json +++ b/barretenberg/ts/package.json @@ -99,6 +99,7 @@ "eslint": "^9.26.0", "eslint-config-prettier": "^10.1.5", "jest": "^30.0.0", + "playwright-core": "^1.59.1", "prettier": "^3.5.3", "ts-jest": "^29.4.0", "ts-loader": "^9.4.2", diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 7730efe0e8f8..c87f929b7fc5 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -12,6 +12,7 @@ import { batch_inverse as batch_inverse_shader, batch_inverse_parallel as batch_inverse_parallel_shader, bigint as bigint_funcs, + bigint_f32 as bigint_f32_funcs, bpr_bn254 as bpr_bn254_shader, convert_point_coords_and_decompose_scalars, convert_points_only as convert_points_only_shader, @@ -20,9 +21,14 @@ import { ec_bn254 as ec_bn254_funcs, extract_word_from_bytes_le as extract_word_from_bytes_le_funcs, field as field_funcs, + field_mul_bench_f32 as field_mul_bench_f32_shader, + field_mul_bench_u32 as field_mul_bench_u32_shader, fr_pow as fr_pow_funcs, horner_reduce_bn254 as horner_reduce_bn254_shader, mont_pro_product as montgomery_product_funcs, + mont_pro_product_f32_22_sos3uv3 as montgomery_product_f32_22_sos3uv3_funcs, + mont_pro_product_karat_yuval as montgomery_product_karat_yuval_funcs, + mulhilo_22 as mulhilo_22_funcs, smvp_bn254 as smvp_bn254_shader, structs, transpose_parallel_count as transpose_parallel_count_shader, @@ -34,12 +40,25 @@ import { compute_misc_params, compute_mod_inverse_pow2, gen_p_limbs, + gen_p_limbs_f32, gen_r_limbs, gen_mu_limbs, gen_wgsl_limbs_code, } from './utils.js'; import { BN254_CURVE_CONFIG, CurveConfig } from './curve_config.js'; +// Modular inverse via extended Euclidean. Returns a^-1 mod m. Both inputs > 0. 
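+// (Invariant: old_r ≡ old_s * a (mod m) throughout, so when the loop exits
+// with old_r = gcd(a, m) = 1, old_s is the inverse. The single caller below,
+// renderKaratYuvalMont, passes a = W = 2^13 and m = p, which are coprime.)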
+function modinv(a: bigint, m: bigint): bigint { + let [old_r, r] = [((a % m) + m) % m, m]; + let [old_s, s] = [1n, 0n]; + while (r !== 0n) { + const q = old_r / r; + [old_r, r] = [r, old_r - q * r]; + [old_s, s] = [s, old_s - q * s]; + } + return ((old_s % m) + m) % m; +} + // Generates parameterised WGSL shader sources for the BN254 MSM // pipeline. Pre-computes Montgomery / Barrett constants for the // configured word size on construction so the per-shader render @@ -67,6 +86,13 @@ export class ShaderManager { public sqrt_exp_limbs: string; public p_inv_mod_2w: number; public mu_limbs: string; + // 22-bit-limb f32 Montgomery params. Used exclusively by + // `gen_field_mul_bench_f32_shader` for the sos3uv3 micro-benchmark. + // The 22-bit width buys a 4-way exact sum (4·2^22 = 2^24 fits in f32 + // mantissa), enabling the per-slot (tlo, thi) chain-break in sos3uv3. + public num_limbs_f32_22: number; + public n0_f32_22: bigint; + public p_limbs_f32_22_str: string; public curveConfig: CurveConfig; public recompile = ''; @@ -106,6 +132,13 @@ export class ShaderManager { this.slack = this.num_words * this.word_size - this.p_bitlength; this.w_mask = (1 << this.word_size) - 1; + // 22-bit-limb f32 path (bench only). compute_misc_params(p, 22) + // gives num_words = 12 for BN254 (12·22 = 264 ≥ 254). + const params_f32_22 = compute_misc_params(this.p, 22); + this.num_limbs_f32_22 = params_f32_22.num_words; + this.n0_f32_22 = params_f32_22.n0; + this.p_limbs_f32_22_str = gen_p_limbs_f32(this.p, this.num_limbs_f32_22, 22); + if (force_recompile) { const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32; this.recompile = ` @@ -630,4 +663,230 @@ export class ShaderManager { }, ); } + + // Bench-only entry shader for the u32 Montgomery product. Each thread + // chains `k` Mont mults over an (a, b) pair. `variant='cios'` selects + // the runtime-loop CIOS in `mont_pro_product.template.wgsl`; 'karat' + // selects the recursive Karatsuba + Yuval body below. + public gen_field_mul_bench_u32_shader( + workgroup_size: number, + variant: 'cios' | 'karat' = 'cios', + ): string { + const structs_src = mustache.render(structs, { num_words: this.num_words }); + const bigint_src = mustache.render(bigint_funcs, {}); + const mont_src = + variant === 'karat' + ? this.renderKaratYuvalMont() + : mustache.render(montgomery_product_funcs, { + num_words: this.num_words, + word_size: this.word_size, + n0: this.n0, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_limbs: this.p_limbs, + }); + const entry_src = mustache.render(field_mul_bench_u32_shader, { + workgroup_size, + }); + return `${structs_src} +${bigint_src} +${mont_src} +${entry_src}`; + } + + // Bench-only entry shader for the f32 Montgomery product. Only the + // sos3uv3 variant (22-bit limbs, per-slot tlo/thi chain-break) is wired + // up here — it is the fastest f32 Mont mul found in the wider variant + // sweep and is kept as a reference point alongside the u32 paths. + public gen_field_mul_bench_f32_shader( + workgroup_size: number, + variant: 'sos3uv3' = 'sos3uv3', + ): string { + if (variant !== 'sos3uv3') { + throw new Error(`f32 bench variant must be 'sos3uv3', got '${variant}'`); + } + const helpers = this.gen_montgomery_product_f32_22_sos3uv3_shader(); + const entry_src = mustache.render(field_mul_bench_f32_shader, { + workgroup_size, + }); + return `${helpers} +${entry_src}`; + } + + // Renders mont_pro_product_karat_yuval.template.wgsl. 
The .wgsl file + // owns the algorithm structure (chunks → sums → 9 schoolbook + // sub-sub-products → inner combines → outer combine → Yuval reduce → + // final canonicalize) via mustache `{{#each}}` sections. The TS here + // just provides the index arrays + r_inv limb constants. + private renderKaratYuvalMont(): string { + const N = this.num_words; // 20 + const WS = this.word_size; // 13 + const W = 1n << BigInt(WS); + + const r_inv = modinv(W, this.p); + const mask = W - 1n; + const limbs: number[] = []; + let v = r_inv; + for (let i = 0; i < N; i++) { + limbs.push(Number(v & mask)); + v >>= BigInt(WS); + } + const r_inv_consts = limbs.map((val, idx) => ({ idx, val })); + + const input_loads: Array<{ name: string; ptr: string; k: number }> = []; + const chunks = [ + ['x_lo_lo', 'x_ptr', 0], + ['x_lo_hi', 'x_ptr', 5], + ['x_hi_lo', 'x_ptr', 10], + ['x_hi_hi', 'x_ptr', 15], + ['y_lo_lo', 'y_ptr', 0], + ['y_lo_hi', 'y_ptr', 5], + ['y_hi_lo', 'y_ptr', 10], + ['y_hi_hi', 'y_ptr', 15], + ] as const; + for (const [prefix, ptr, base] of chunks) { + for (let k = 0; k < 5; k++) { + input_loads.push({ name: `${prefix}_${k}`, ptr, k: (base as number) + k }); + } + } + + const sum_lets: Array<{ name: string; lhs: string; rhs: string }> = []; + const sumDefs = [ + ['a_lo_sum', 'x_lo_lo', 'x_lo_hi'], + ['b_lo_sum', 'y_lo_lo', 'y_lo_hi'], + ['a_hi_sum', 'x_hi_lo', 'x_hi_hi'], + ['b_hi_sum', 'y_hi_lo', 'y_hi_hi'], + ['a_cr_lo', 'x_lo_lo', 'x_hi_lo'], + ['a_cr_hi', 'x_lo_hi', 'x_hi_hi'], + ['b_cr_lo', 'y_lo_lo', 'y_hi_lo'], + ['b_cr_hi', 'y_lo_hi', 'y_hi_hi'], + ['a_cr_sum', 'a_cr_lo', 'a_cr_hi'], + ['b_cr_sum', 'b_cr_lo', 'b_cr_hi'], + ] as const; + for (const [name, lhs, rhs] of sumDefs) { + for (let k = 0; k < 5; k++) { + sum_lets.push({ name: `${name}_${k}`, lhs: `${lhs}_${k}`, rhs: `${rhs}_${k}` }); + } + } + + const schoolbooks = [ + { label: 'PP_lo_LL = x_lo_lo · y_lo_lo', out_prefix: 'pp_lo_ll', a_prefix: 'x_lo_lo', b_prefix: 'y_lo_lo' }, + { label: 'PP_lo_HH = x_lo_hi · y_lo_hi', out_prefix: 'pp_lo_hh', a_prefix: 'x_lo_hi', b_prefix: 'y_lo_hi' }, + { label: 'PP_lo_C = a_lo_sum · b_lo_sum', out_prefix: 'pp_lo_c', a_prefix: 'a_lo_sum', b_prefix: 'b_lo_sum' }, + { label: 'PP_hi_LL = x_hi_lo · y_hi_lo', out_prefix: 'pp_hi_ll', a_prefix: 'x_hi_lo', b_prefix: 'y_hi_lo' }, + { label: 'PP_hi_HH = x_hi_hi · y_hi_hi', out_prefix: 'pp_hi_hh', a_prefix: 'x_hi_hi', b_prefix: 'y_hi_hi' }, + { label: 'PP_hi_C = a_hi_sum · b_hi_sum', out_prefix: 'pp_hi_c', a_prefix: 'a_hi_sum', b_prefix: 'b_hi_sum' }, + { label: 'PP_cr_LL = a_cr_lo · b_cr_lo', out_prefix: 'pp_cr_ll', a_prefix: 'a_cr_lo', b_prefix: 'b_cr_lo' }, + { label: 'PP_cr_HH = a_cr_hi · b_cr_hi', out_prefix: 'pp_cr_hh', a_prefix: 'a_cr_hi', b_prefix: 'b_cr_hi' }, + { label: 'PP_cr_C = a_cr_sum · b_cr_sum', out_prefix: 'pp_cr_c', a_prefix: 'a_cr_sum', b_prefix: 'b_cr_sum' }, + ]; + + const inner_combines = [ + { label: 'P_lo from pp_lo_*', out_prefix: 'p_lo', ll_prefix: 'pp_lo_ll', hh_prefix: 'pp_lo_hh', c_prefix: 'pp_lo_c' }, + { label: 'P_hi from pp_hi_*', out_prefix: 'p_hi', ll_prefix: 'pp_hi_ll', hh_prefix: 'pp_hi_hh', c_prefix: 'pp_hi_c' }, + { label: 'P_cr from pp_cr_*', out_prefix: 'p_cr', ll_prefix: 'pp_cr_ll', hh_prefix: 'pp_cr_hh', c_prefix: 'pp_cr_c' }, + ]; + + const outer_init: Array<{ slot: number; init_expr: string }> = []; + for (let k = 0; k < 19; k++) outer_init.push({ slot: k, init_expr: `p_lo_${k}` }); + outer_init.push({ slot: 19, init_expr: '0u' }); + for (let k = 0; k < 19; k++) outer_init.push({ slot: 20 + k, init_expr: 
`p_hi_${k}` }); + outer_init.push({ slot: 39, init_expr: '0u' }); + + const outer_cross: Array<{ slot: number; k: number }> = []; + for (let k = 0; k < 19; k++) outer_cross.push({ slot: 10 + k, k }); + + const yuval_iters: Array<{ i: number; writes: Array<{ slot: number; r_idx: number; first: boolean }> }> = []; + for (let i = 0; i < N - 1; i++) { + const writes = []; + for (let j = 0; j < N; j++) { + writes.push({ slot: i + 1 + j, r_idx: j, first: j === 0 }); + } + yuval_iters.push({ i, writes }); + } + + const i_std = N - 1; + const standard_writes: Array<{ slot: number; p_idx: number; first: boolean }> = []; + for (let j = 0; j < N; j++) { + standard_writes.push({ slot: i_std + j, p_idx: j, first: j === 1 }); + } + + const final_drain: Array<{ slot: number }> = []; + for (let i = 0; i < N; i++) final_drain.push({ slot: N + i }); + + const extract: Array<{ out_k: number; src_slot: number }> = []; + for (let i = 0; i < N; i++) extract.push({ out_k: i, src_slot: N + i }); + + return mustache.render(montgomery_product_karat_yuval_funcs, { + num_words: N, + word_size: WS, + n0: this.n0, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_limbs: this.p_limbs, + r_inv_consts, + input_loads, + sum_lets, + schoolbooks, + inner_combines, + outer_init, + outer_cross, + yuval_iters, + i_std, + standard_writes, + final_drain, + extract, + }); + } + + // Renders mont_pro_product_f32_22_sos3uv3.template.wgsl. The .wgsl owns + // the algorithm — separate per-slot tlo/thi f32 accumulators, no + // inter-j carry chain, drain at end of each outer iter via + // bias_split_f32_le4w. This TS supplies index arrays for the mustache + // slot-init / inner-pairs / drain-cols sections. + public gen_montgomery_product_f32_22_sos3uv3_shader(): string { + const N = this.num_limbs_f32_22; + const W_INV_VAL = 2.384185791015625e-7; + const n0Num = Number(this.n0_f32_22); + const n0Scaled = n0Num * W_INV_VAL; + + // Slot init for iter 0: tlo[0] = init_slot0, everything else 0. + // Slot init for i>=1: tlo[0] = init_slot0 + s1, tlo[k] = s[k+1] for + // k=1..N-2, tlo[N-1] = 0; thi[*] = 0. + const slotInitsI0: Array<{ name: string; init_expr: string }> = []; + const slotInitsGeneral: Array<{ name: string; init_expr: string }> = []; + for (let k = 0; k < N; k++) { + slotInitsI0.push({ name: `tlo${k}`, init_expr: k === 0 ? 'init_slot0' : '0.0' }); + slotInitsI0.push({ name: `thi${k}`, init_expr: '0.0' }); + let genTlo: string; + if (k === 0) genTlo = 'init_slot0 + s1'; + else if (k === N - 1) genTlo = '0.0'; + else genTlo = `s${k + 1}`; + slotInitsGeneral.push({ name: `tlo${k}`, init_expr: genTlo }); + slotInitsGeneral.push({ name: `thi${k}`, init_expr: '0.0' }); + } + + // Inner-j pairs: for j=1..N-1, write tlo[j-1] += lo_sum, thi[j] += hi_sum. + const innerPairs = []; + for (let j = 1; j < N; j++) innerPairs.push({ j, km1: j - 1, k: j }); + + // Drain cols: k=0..N-1. 
+ const drainCols = Array.from({ length: N }, (_, k) => ({ k })); + + const ctx = { + num_limbs: N, + n0: `${this.n0_f32_22.toString()}.0`, + n0_scaled: n0Scaled.toString(), + p_limbs_f32: this.p_limbs_f32_22_str, + slot_inits_i0: slotInitsI0, + slot_inits_general: slotInitsGeneral, + inner_pairs: innerPairs, + drain_cols: drainCols, + }; + const bigint_f32_src = mustache.render(bigint_f32_funcs, ctx); + const mont_src = mustache.render(montgomery_product_f32_22_sos3uv3_funcs, ctx); + return `${mulhilo_22_funcs}\n${bigint_f32_src}\n${mont_src}`; + } } diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts index 489081f278d3..8570a077e4f0 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts @@ -211,6 +211,35 @@ export const gen_wgsl_limbs_code = ( return r; }; +// f32 variant of gen_wgsl_limbs_code: emits `.0` literals instead of `u`. +// Used by the FMA-Montgomery (23-bit f32 limb) path. Does its own +// little-endian limb extraction because `to_words_le` returns a +// Uint16Array, which truncates any word_size > 16. +export const gen_wgsl_limbs_code_f32 = ( + val: bigint, + var_name: string, + num_words: number, + word_size: number, +): string => { + const mask = (BigInt(1) << BigInt(word_size)) - BigInt(1); + let r = ""; + let v = val; + for (let i = 0; i < num_words; i++) { + const limb = Number(v & mask); + r += ` ${var_name}.limbs[${i}]` + " = " + limb.toString() + ".0;\n"; + v >>= BigInt(word_size); + } + return r; +}; + +export const gen_p_limbs_f32 = ( + p: bigint, + num_words: number, + word_size: number, +): string => { + return gen_wgsl_limbs_code_f32(p, "p", num_words, word_size); +}; + export const gen_barrett_domb_m_limbs = ( m: bigint, num_words: number, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 68ffad6d7bd9..94e18b0bf47e 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 29 shader sources inlined. +// 35 shader sources inlined. /* eslint-disable */ @@ -506,6 +506,72 @@ fn bigint_signed_axby_modp_halve_k( } `; +export const bigint_f32 = `// f32-limb mirror of \`bigint.template.wgsl\`. Each limb holds an +// integer-valued f32 in [0, 2^WORD_SIZE_F32) = [0, W). Mirrors the +// subset of \`bigint.template.wgsl\` needed by \`montgomery_product_f32\` +// and the f32 field ops in \`field/field_f32.template.wgsl\` / +// \`field/fr_pow_f32.template.wgsl\`. + +struct BigIntF32 { + limbs: array +} + +fn bigint_f32_gt(x: ptr, y: ptr) -> bool { + for (var idx = 0u; idx < {{ num_limbs }}u; idx ++) { + let i = {{ num_limbs }}u - 1u - idx; + if ((*x).limbs[i] < (*y).limbs[i]) { return false; } + if ((*x).limbs[i] > (*y).limbs[i]) { return true; } + } + return false; +} + +fn bigint_f32_eq(x: ptr, y: ptr) -> bool { + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + if ((*x).limbs[i] != (*y).limbs[i]) { return false; } + } + return true; +} + +// res = a - b. Per-limb borrow lives as an f32 in {0.0, 1.0}; +// \`step(diff, -0.5)\` is 1.0 iff diff is a negative integer (i.e., +// underflowed). Adding \`underflow * W\` then canonicalises the limb +// back into [0, W). 
Returns the final borrow-out (0.0 or 1.0): callers +// that need to know whether a >= b consult this flag, mirroring the +// u32 \`bigint_sub\`'s return convention. +fn bigint_f32_sub(a: ptr, b: ptr, res: ptr) -> f32 { + var borrow: f32 = 0.0; + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + let diff = (*a).limbs[i] - (*b).limbs[i] - borrow; + let underflow = step(diff, -0.5); + (*res).limbs[i] = diff + underflow * W; + borrow = underflow; + } + return borrow; +} + +// res = a + b. Per-limb carry lives as an f32 in {0.0, 1.0}; \`step(W-0.5, sum)\` +// is 1.0 iff sum >= W (the bias-free branchless test). Each sum is at most +// 2*(W-1) + 1 = 2^24 - 1, exact in the 24-bit f32 mantissa. Returns the +// final carry-out (0.0 or 1.0) for downstream conditional-reduce logic. +fn bigint_f32_add(a: ptr, b: ptr, res: ptr) -> f32 { + var carry: f32 = 0.0; + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + let sum = (*a).limbs[i] + (*b).limbs[i] + carry; + let overflow = step(W - 0.5, sum); + (*res).limbs[i] = sum - overflow * W; + carry = overflow; + } + return carry; +} + +fn bigint_f32_is_zero(x: ptr) -> bool { + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + if ((*x).limbs[i] != 0.0) { return false; } + } + return true; +} +`; + export const ec_bn254 = `// Jacobian-coordinate EC arithmetic for BN254 (short Weierstrass a=0, b=3). // Affine interpretation: (X, Y, Z) represents affine (X/Z^2, Y/Z^3). // Identity: any point with Z = 0. @@ -3609,6 +3675,88 @@ fn extract_word_from_bytes_le( } `; +export const field_mul_bench_f32 = `// Field-mul micro-benchmark, f32 / 12×23-bit limbs (CIOS over f32 FMA). +// Every thread loads one (a, b) pair, runs \`k\` chained Montgomery +// products (a = montgomery_product_f32(a, b) repeated k times), and +// writes the final \`a\` back. The host caps \`k\` at <=100 before passing +// it in. +// +// Loop bounds. The only data-dependent loop is \`for (var i = 0u; i < k; i++)\` +// where \`k = params.y\` is a host-uniform capped at 100. The early-out +// \`if (tid >= n)\` is a guard, not a loop bound, and \`n = params.x\` is +// host-capped at 2^23 to keep dispatch sizes reasonable. The inner +// \`montgomery_product_f32\` loops are bounded by the compile-time constant +// \`NUM_LIMBS\`. + +// Inputs are split into two separate \`xs\`/\`ys\` arrays (one BigIntF32 +// per thread per array). This mirrors the working layout used by +// \`testMontgomeryProductF32\` in wgsl_unit_tests.ts — packing into a +// single \`array\` produced all-zero outputs on Dawn/Metal +// even when the inputs read back correctly via a passthrough variant, +// so we sidestep struct-in-array layout concerns entirely. +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var ys: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + var b = ys[tid]; + // Chained loop calls the unreduced body — Mont is closed over inputs + // in [0, 2p) since all limbs stay < W. Final conditional_reduce + // canonicalizes the result. + for (var i = 0u; i < k; i = i + 1u) { + a = montgomery_product_f32_unreduced(&a, &b); + } + var p_final = get_p_f32(); + outputs[tid] = conditional_reduce_f32(&a, &p_final); +} +`; + +export const field_mul_bench_u32 = `// Field-mul micro-benchmark, u32 / 20×13-bit limbs (Mitschabaude-CIOS). 
+// Every thread loads one (a, b) pair, runs \`k\` chained Montgomery +// products (a = montgomery_product(a, b) repeated k times), and writes +// the final \`a\` back. The host caps \`k\` at <=100 before passing it in. +// +// Loop bounds. The only data-dependent loop is \`for (var i = 0u; i < k; i++)\` +// where \`k = params.y\` is a host-uniform capped at 100. The early-out +// \`if (tid >= n)\` is a guard, not a loop bound, and \`n = params.x\` is +// host-capped at 2^23 to keep dispatch sizes reasonable. The inner +// \`montgomery_product\` loops are bounded by the compile-time constant +// \`NUM_WORDS\`. + +// Inputs are split into two separate \`xs\`/\`ys\` arrays (one BigInt per +// thread per array). Matches the f32 path's layout for symmetry — the +// \`array\` packing was found to produce all-zero outputs on +// Dawn/Metal in the f32 variant of this shader; using separate arrays +// sidesteps that concern. +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var ys: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + var b = ys[tid]; + for (var i = 0u; i < k; i = i + 1u) { + a = montgomery_product(&a, &b); + } + outputs[tid] = a; +} +`; + export const horner_reduce_bn254 = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -4777,6 +4925,462 @@ fn conditional_reduce(x: ptr, y: ptr) -> Big } `; +export const mont_pro_product_f32_22_sos3uv3 = `// 22-bit-limb f32 Montgomery product, sos3uv3 — chain breaking via +// SEPARATE per-slot f32 accumulators (lo and hi). +// +// Inside the inner-j loop: +// - sos3uv2 had a single chain c_hi/c_lo flowing j → j+1 (serial). +// - sos3u32 used i32 accumulators (broken chain but slow int ops). +// - sos3uv3 uses TWO per-slot f32 accumulators (tlo[k], thi[k]). +// Each j writes UNIQUE tlo[j-1] and thi[j] — no overlap across j's, +// no carry chain. Stays in f32 throughout (no int conversions). +// +// Per-slot bound: tlo[k] = s_old[k+1] + 1 lo (≤ 2.5W). thi[k] = 1 hi +// (≤ W). Combined at drain: 3.5W < 2²³·⁸ < 2²⁴ — fits f32 exactly. +// +// Drain at end of each outer iter: combine tlo[k] + thi[k] + carry, +// floor-based bias_split, write canonical s[k]. Serial across cols. +// +// === Structure === +// Outer iter i=0 is unrolled (no shifted-in s values to read). Iters +// 1..N-1 are in a runtime \`for\` loop with a separate slot_init that pulls +// in s1..s11. Inner-j pairs and drain cols are mustache-expanded — keeps +// the named locals (tlo0..tlo11, thi0..thi11) which Metal can register- +// allocate instead of spilling to thread-private memory. 
+ +const NUM_LIMBS: u32 = {{ num_limbs }}u; +const N0: f32 = {{ n0 }}; +const N0_SCALED: f32 = {{ n0_scaled }}; +const TWO_W: f32 = 8388608.0; // 2 * 2^22 + +fn get_p_f32() -> BigIntF32 { + var p: BigIntF32; +{{{ p_limbs_f32 }}} + return p; +} + +fn bias_split_f32_le2w(x: f32) -> vec2 { + let hi = step(W, x); + let lo = fma(-hi, W, x); + return vec2(hi, lo); +} + +fn bias_split_f32_le4w(x: f32) -> vec2 { + let hi = floor(x * W_INV); + let lo = fma(-hi, W, x); + return vec2(hi, lo); +} + +fn mulhilo_sos3_corr(a: f32, a_scaled: f32, b: f32) -> vec2 { + let hi_off_inner = fma(a_scaled, b, W); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(-W, hi_off, BIAS); + let lo0 = fma(a, b, neg_hi_w); + let hi_pre = hi_off - W; + let underflow = step(lo0, -0.5); + let hi = hi_pre - underflow; + let lo = fma(underflow, W, lo0); + return vec2(hi, lo); +} + +fn mulhilo_sos3_qi_lo(a: f32, b_scaled: f32, b: f32) -> f32 { + let hi_off_inner = fma(a, b_scaled, W); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(-W, hi_off, BIAS); + let lo0 = fma(a, b, neg_hi_w); + let underflow = step(lo0, -0.5); + let lo = fma(underflow, W, lo0); + return lo; +} + +fn mulhilo_sos3_2_v2(a: vec2, a_scaled: vec2, b: vec2) -> vec4 { + let hi_off_inner = fma(a_scaled, b, vec2(W, W)); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(vec2(-W, -W), hi_off, vec2(BIAS, BIAS)); + let lo = fma(a, b, neg_hi_w); + return vec4(hi_off.x, lo.x, hi_off.y, lo.y); +} + +fn montgomery_product_f32_unreduced(x: ptr, y: ptr) -> BigIntF32 { + var s0: f32 = 0.0; + var s1: f32 = 0.0; + var s2: f32 = 0.0; + var s3: f32 = 0.0; + var s4: f32 = 0.0; + var s5: f32 = 0.0; + var s6: f32 = 0.0; + var s7: f32 = 0.0; + var s8: f32 = 0.0; + var s9: f32 = 0.0; + var s10: f32 = 0.0; + var s11: f32 = 0.0; + + var p = get_p_f32(); + + // ===== Outer iter i=0 (unrolled — s[*] all zero, no shift-in). ===== + { + let x_i = (*x).limbs[0]; + let x_i_scaled = x_i * W_INV; + let xy0 = mulhilo_sos3_corr(x_i, x_i_scaled, (*y).limbs[0]); + let qi = mulhilo_sos3_qi_lo(xy0.y, N0_SCALED, N0); + let qi_scaled = qi * W_INV; + + let c_cancel = step(0.5, xy0.y); + let qp0_lo = c_cancel * (W - xy0.y); + let qp0_hi = fma(qi, p.limbs[0], -qp0_lo) * W_INV; + + // j=0's contribution to slot 0 (= position 1 unshifted). + let init_slot0 = xy0.x + qp0_hi + c_cancel; + + // Slot init for iter 0: tlo[0] gets init_slot0; everything else 0. +{{#slot_inits_i0}} + var {{name}}: f32 = {{init_expr}}; +{{/slot_inits_i0}} + + let xq_scaled = vec2(x_i_scaled, qi_scaled); + let xq = vec2(x_i, qi); + + // Inner-j writes: pair (y[j], p[j]) mulhilo, lo to tlo[j-1], hi to thi[j]. +{{#inner_pairs}} + { + let mh = mulhilo_sos3_2_v2(xq, xq_scaled, vec2((*y).limbs[{{j}}u], p.limbs[{{j}}u])); + let lo_sum = mh.y + mh.w; + let hi_sum = mh.x + mh.z - TWO_W; + tlo{{km1}} = tlo{{km1}} + lo_sum; + thi{{k}} = thi{{k}} + hi_sum; + } +{{/inner_pairs}} + + // Drain: combine tlo[k]+thi[k]+carry, split. Serial across cols. + var carry: f32 = 0.0; +{{#drain_cols}} + { + let sum = tlo{{k}} + thi{{k}} + carry; + let split = bias_split_f32_le4w(sum); + s{{k}} = split.y; + carry = split.x; + } +{{/drain_cols}} + } + + // ===== Outer iter i=1..NUM_LIMBS-1 (runtime loop). 
===== + for (var i = 1u; i < NUM_LIMBS; i = i + 1u) { + let x_i = (*x).limbs[i]; + let x_i_scaled = x_i * W_INV; + + let xy0 = mulhilo_sos3_corr(x_i, x_i_scaled, (*y).limbs[0]); + let sum0 = s0 + xy0.y; + let sum0_s = bias_split_f32_le2w(sum0); + let qi = mulhilo_sos3_qi_lo(sum0_s.y, N0_SCALED, N0); + let qi_scaled = qi * W_INV; + + let c_cancel = step(0.5, sum0_s.y); + let qp0_lo = c_cancel * (W - sum0_s.y); + let qp0_hi = fma(qi, p.limbs[0], -qp0_lo) * W_INV; + + // j=0's contribution to slot 0 + s_old[1] shifted in. + let init_slot0 = xy0.x + qp0_hi + c_cancel + sum0_s.x; + + // Slot init for i>=1: tlo[k] pulls in s[k+1] as the shifted carry. +{{#slot_inits_general}} + var {{name}}: f32 = {{init_expr}}; +{{/slot_inits_general}} + + let xq_scaled = vec2(x_i_scaled, qi_scaled); + let xq = vec2(x_i, qi); + +{{#inner_pairs}} + { + let mh = mulhilo_sos3_2_v2(xq, xq_scaled, vec2((*y).limbs[{{j}}u], p.limbs[{{j}}u])); + let lo_sum = mh.y + mh.w; + let hi_sum = mh.x + mh.z - TWO_W; + tlo{{km1}} = tlo{{km1}} + lo_sum; + thi{{k}} = thi{{k}} + hi_sum; + } +{{/inner_pairs}} + + var carry: f32 = 0.0; +{{#drain_cols}} + { + let sum = tlo{{k}} + thi{{k}} + carry; + let split = bias_split_f32_le4w(sum); + s{{k}} = split.y; + carry = split.x; + } +{{/drain_cols}} + } + + var s: BigIntF32; + s.limbs[0] = s0; + s.limbs[1] = s1; + s.limbs[2] = s2; + s.limbs[3] = s3; + s.limbs[4] = s4; + s.limbs[5] = s5; + s.limbs[6] = s6; + s.limbs[7] = s7; + s.limbs[8] = s8; + s.limbs[9] = s9; + s.limbs[10] = s10; + s.limbs[11] = s11; + return s; +} + +fn montgomery_product_f32(x: ptr, y: ptr) -> BigIntF32 { + var s = montgomery_product_f32_unreduced(x, y); + var p = get_p_f32(); + return conditional_reduce_f32(&s, &p); +} + +fn conditional_reduce_f32(x: ptr, y: ptr) -> BigIntF32 { + if (bigint_f32_gt(x, y) || bigint_f32_eq(x, y)) { + var res: BigIntF32; + let _borrow = bigint_f32_sub(x, y, &res); + return res; + } + return *x; +} +`; + +export const mont_pro_product_karat_yuval = `// u32 Montgomery product via recursive KARATSUBA + YUVAL reduction. +// Fully unrolled — all indices compile-time constants so WGSL→MSL can +// SROA the temp slots into registers instead of thread-local memory. 
+// +// === Layout === +// 4 input chunks per operand (5 limbs each, named locals): +// x_lo_lo = x[0..4], x_lo_hi = x[5..9] +// x_hi_lo = x[10..14], x_hi_hi = x[15..19] +// (same naming for y) +// +// 9 sub-sub-products (each 5×5 schoolbook → 9 output limbs): +// For P_lo = x[0..9]·y[0..9]: pp_lo_LL, pp_lo_HH, pp_lo_C +// For P_hi = x[10..19]·y[10..19]: pp_hi_LL, pp_hi_HH, pp_hi_C +// For P_cr = (x[0..9]+x[10..19])·(y[0..9]+y[10..19]): +// pp_cr_LL, pp_cr_HH, pp_cr_C +// +// 3 outer sub-products (inner Karat combine): +// P_lo[k] = pp_lo_LL[k] + (pp_lo_C[k-5] - pp_lo_LL[k-5] - pp_lo_HH[k-5]) + pp_lo_HH[k-10] +// P_hi[k] = pp_hi_LL[k] + (pp_hi_C[k-5] - pp_hi_LL[k-5] - pp_hi_HH[k-5]) + pp_hi_HH[k-10] +// P_cr[k] = pp_cr_LL[k] + (pp_cr_C[k-5] - pp_cr_LL[k-5] - pp_cr_HH[k-5]) + pp_cr_HH[k-10] +// +// Outer combine into temp[0..38]: +// temp[k] += P_lo[k] for k in [0, 18] +// temp[k+20] += P_hi[k] for k in [0, 18] +// temp[k+10] += P_cr[k] - P_lo[k] - P_hi[k] for k in [0, 18] +// +// Yuval reduce (N-1 calls + 1 standard): +// for i in 0..N-1: +// t_mask = temp[i] & MASK; carry = temp[i] >> WORD_SIZE +// temp[i+1] += t_mask·R_INV[0] + carry +// temp[i+1+j] += t_mask·R_INV[j] for j in [1, N) +// standard reduce at i=N-1: k = (temp[i]&MASK)·N0 & MASK; +// temp[i+j] += k·p[j] for j in [0, N) +// plus the (temp[i]>>WORD_SIZE) carry folded into temp[i+1] +// +// Final canonicalization: single carry pass over the upper N slots. +// +// === Why no drains in the multiply phase === +// One inner sub-product (pp_cr_C slot 4 = 80·W² = 2³²·³²) wraps u32 by +// ~1.25×. This wrap is HARMLESS: the subsequent \`pp_cr_C - pp_cr_LL - +// pp_cr_HH\` subtraction unwinds the wrap via modular arithmetic, giving +// the correct (mathematically non-negative) pp_cr_mid value, which fits +// u32. Every other intermediate fits u32 directly. See +// karat_intermediate_check.mjs for the per-slot proof. +// +// Final temp[k] math bound: 40·W² = 2³¹·³² < 2³². ✓ Zero drains needed. + +const NUM_WORDS: u32 = {{ num_words }}u; +const WORD_SIZE: u32 = {{ word_size }}u; +const MASK: u32 = {{ mask }}u; +const TWO_POW_WORD_SIZE: u32 = {{ two_pow_word_size }}u; +const N0: u32 = {{ n0 }}u; +const P_INV_MOD_2W: u32 = {{ p_inv_mod_2w }}u; + +// r_inv = 2^{-WORD_SIZE} mod p, as N individual constants (NOT array — +// naga rejects runtime indexing into a const array, and the unrolled +// Yuval below uses each limb at a compile-time position anyway). +{{#r_inv_consts}} +const R_INV_{{idx}}: u32 = {{val}}u; +{{/r_inv_consts}} + +fn get_p() -> BigInt { + var p: BigInt; +{{{ p_limbs }}} + return p; +} + +fn montgomery_product(x_ptr: ptr, y_ptr: ptr) -> BigInt { + var p = get_p(); + + // === Input load: 40 named locals, one per limb. === +{{#input_loads}} + let {{name}}: u32 = (*{{ptr}}).limbs[{{k}}u]; +{{/input_loads}} + + // === Sums for inner Karatsuba: === + // a_lo_sum[k] = x_lo_lo[k] + x_lo_hi[k], a_hi_sum = x_hi_lo + x_hi_hi. + // For P_cross outer: a_cr_lo[k] = x_lo_lo[k] + x_hi_lo[k], a_cr_hi = + // x_lo_hi + x_hi_hi. Then inner cross sum: a_cr_sum = a_cr_lo + a_cr_hi. +{{#sum_lets}} + let {{name}}: u32 = {{lhs}} + {{rhs}}; +{{/sum_lets}} + + // === 9 sub-sub-products (5×5 schoolbook each). === + // Each output slot is a single \`let\` with a sum-of-products expression. 
+{{#schoolbooks}} + // --- {{label}}: out_prefix={{out_prefix}}, a={{a_prefix}}, b={{b_prefix}} --- + let {{out_prefix}}_0: u32 = {{a_prefix}}_0 * {{b_prefix}}_0; + let {{out_prefix}}_1: u32 = {{a_prefix}}_0 * {{b_prefix}}_1 + {{a_prefix}}_1 * {{b_prefix}}_0; + let {{out_prefix}}_2: u32 = {{a_prefix}}_0 * {{b_prefix}}_2 + {{a_prefix}}_1 * {{b_prefix}}_1 + {{a_prefix}}_2 * {{b_prefix}}_0; + let {{out_prefix}}_3: u32 = {{a_prefix}}_0 * {{b_prefix}}_3 + {{a_prefix}}_1 * {{b_prefix}}_2 + {{a_prefix}}_2 * {{b_prefix}}_1 + {{a_prefix}}_3 * {{b_prefix}}_0; + let {{out_prefix}}_4: u32 = {{a_prefix}}_0 * {{b_prefix}}_4 + {{a_prefix}}_1 * {{b_prefix}}_3 + {{a_prefix}}_2 * {{b_prefix}}_2 + {{a_prefix}}_3 * {{b_prefix}}_1 + {{a_prefix}}_4 * {{b_prefix}}_0; + let {{out_prefix}}_5: u32 = {{a_prefix}}_1 * {{b_prefix}}_4 + {{a_prefix}}_2 * {{b_prefix}}_3 + {{a_prefix}}_3 * {{b_prefix}}_2 + {{a_prefix}}_4 * {{b_prefix}}_1; + let {{out_prefix}}_6: u32 = {{a_prefix}}_2 * {{b_prefix}}_4 + {{a_prefix}}_3 * {{b_prefix}}_3 + {{a_prefix}}_4 * {{b_prefix}}_2; + let {{out_prefix}}_7: u32 = {{a_prefix}}_3 * {{b_prefix}}_4 + {{a_prefix}}_4 * {{b_prefix}}_3; + let {{out_prefix}}_8: u32 = {{a_prefix}}_4 * {{b_prefix}}_4; +{{/schoolbooks}} + + // === Inner Karatsuba combine: form P_lo, P_hi, P_cross as named locals. === + // For each, 19 outputs combining the 3 sub-sub-products. The mid term + // uses unsigned subtraction; underflow is impossible because the + // algebraic identity P_mid[m] = Σ a_lo·b_hi + a_hi·b_lo is non-neg + // per-limb at the lazy values. +{{#inner_combines}} + // --- {{label}}: out_prefix={{out_prefix}}, ll={{ll_prefix}}, hh={{hh_prefix}}, c={{c_prefix}} --- + let {{out_prefix}}_0: u32 = {{ll_prefix}}_0; + let {{out_prefix}}_1: u32 = {{ll_prefix}}_1; + let {{out_prefix}}_2: u32 = {{ll_prefix}}_2; + let {{out_prefix}}_3: u32 = {{ll_prefix}}_3; + let {{out_prefix}}_4: u32 = {{ll_prefix}}_4; + let {{out_prefix}}_5: u32 = {{ll_prefix}}_5 + {{c_prefix}}_0 - {{ll_prefix}}_0 - {{hh_prefix}}_0; + let {{out_prefix}}_6: u32 = {{ll_prefix}}_6 + {{c_prefix}}_1 - {{ll_prefix}}_1 - {{hh_prefix}}_1; + let {{out_prefix}}_7: u32 = {{ll_prefix}}_7 + {{c_prefix}}_2 - {{ll_prefix}}_2 - {{hh_prefix}}_2; + let {{out_prefix}}_8: u32 = {{ll_prefix}}_8 + {{c_prefix}}_3 - {{ll_prefix}}_3 - {{hh_prefix}}_3; + let {{out_prefix}}_9: u32 = {{c_prefix}}_4 - {{ll_prefix}}_4 - {{hh_prefix}}_4; + let {{out_prefix}}_10: u32 = {{c_prefix}}_5 - {{ll_prefix}}_5 - {{hh_prefix}}_5 + {{hh_prefix}}_0; + let {{out_prefix}}_11: u32 = {{c_prefix}}_6 - {{ll_prefix}}_6 - {{hh_prefix}}_6 + {{hh_prefix}}_1; + let {{out_prefix}}_12: u32 = {{c_prefix}}_7 - {{ll_prefix}}_7 - {{hh_prefix}}_7 + {{hh_prefix}}_2; + let {{out_prefix}}_13: u32 = {{c_prefix}}_8 - {{ll_prefix}}_8 - {{hh_prefix}}_8 + {{hh_prefix}}_3; + let {{out_prefix}}_14: u32 = {{hh_prefix}}_4; + let {{out_prefix}}_15: u32 = {{hh_prefix}}_5; + let {{out_prefix}}_16: u32 = {{hh_prefix}}_6; + let {{out_prefix}}_17: u32 = {{hh_prefix}}_7; + let {{out_prefix}}_18: u32 = {{hh_prefix}}_8; +{{/inner_combines}} + + // === Outer combine: initialize temp[0..39]. === + // Slots [0,18] = P_lo. Slots [20,38] = P_hi. Slots [10,28] += P_cr - P_lo - P_hi. + // We initialize as \`var\` (mutable) so the Yuval phase can mutate. +{{#outer_init}} + var t{{slot}}: u32 = {{init_expr}}; +{{/outer_init}} + + // === Outer Karatsuba cross combine: temp[k+10] += P_cr[k] - P_lo[k] - P_hi[k]. 
=== + // Unsigned subtraction is safe per the outer-level algebraic identity: + // P_mid_outer[k] = Σ x_lo·y_hi + x_hi·y_lo is non-neg per-limb. +{{#outer_cross}} + t{{slot}} = t{{slot}} + p_cr_{{k}} - p_lo_{{k}} - p_hi_{{k}}; +{{/outer_cross}} + + // === Yuval reduce: 19 Yuval calls + 1 standard. === + // Each call extracts t_mask & carry from t{i}, then accumulates + // t_mask·R_INV[j] into t{i+1+j} for j=0..N-1 (the carry folds into + // the j=0 write). +{{#yuval_iters}} + { + let t_mask: u32 = t{{i}} & MASK; + let carry: u32 = t{{i}} >> WORD_SIZE; +{{#writes}} + t{{slot}} = t{{slot}} + t_mask * R_INV_{{r_idx}}{{#first}} + carry{{/first}}; +{{/writes}} + } +{{/yuval_iters}} + + // Standard Mont reduce for the last iter (i = N-1 = 19). + { + let t_mask: u32 = t{{i_std}} & MASK; + let k_std: u32 = (t_mask * N0) & MASK; +{{#standard_writes}} + t{{slot}} = t{{slot}} + k_std * p.limbs[{{p_idx}}u]{{#first}} + (t{{i_std}} >> WORD_SIZE){{/first}}; +{{/standard_writes}} + } + + // === Final canonicalization (single carry pass over t20..t39). === + var c: u32 = 0u; +{{#final_drain}} + { + let v: u32 = t{{slot}} + c; + c = v >> WORD_SIZE; + t{{slot}} = v & MASK; + } +{{/final_drain}} + + // === Repack canonical limbs into BigInt and return. === + var s: BigInt; +{{#extract}} + s.limbs[{{out_k}}u] = t{{src_slot}}; +{{/extract}} + + return conditional_reduce(&s, &p); +} + +fn conditional_reduce(x: ptr, y: ptr) -> BigInt { + var x_gt_y = bigint_gt(x, y); + var x_eq_y = bigint_eq(x, y); + if (x_gt_y == 1u || x_eq_y) { + var res: BigInt; + bigint_sub(x, y, &res); + return res; + } + return *x; +} +`; + +export const mulhilo_22 = `// 22-bit-limb sibling of \`mulhilo.wgsl\`. Same FMA-bias trick, different +// constants (W = 2^22, BIAS = 2^44). The two files are never compiled +// into the same WGSL module — callers pick one bundle or the other. +// +// Inputs: a, b in [0, 2^22) stored as f32 (integer-valued). +// Output: vec2(hi, lo) with a*b == hi * 2^22 + lo, hi, lo in [0, 2^22). +// +// Math (load-bearing): +// fma(a, b, BIAS) computes a*b + 2^44 with infinite intermediate +// precision and rounds ONCE to f32. Inside [2^44, 2^45) the f32 ULP is +// exactly 2^21, so the rounded result lies on a multiple of 2^21 +// (not 2^22!). Dividing by W and flooring effectively rounds to W, +// giving a non-negative integer q/W < W. The lo correction follows the +// same step()-based underflow scheme as the 23-bit variant. +// +// Important difference from 23-bit: at 22-bit, a*b in [0, 2^44) plus +// BIAS = 2^44 sits exactly at the boundary [2^44, 2^45]. Inside that +// range ULP is 2^21, so \`floor(abp * W_INV)\` rounds to the nearest +// multiple of W/2 — but then \`lo0 = fma(a, b, -hi_pre * W)\` recovers +// exactly. The step() correction handles the half-W slop. 
+const BIAS: f32 = 17592186044416.0; // 2^44 +const W: f32 = 4194304.0; // 2^22 +const W_INV: f32 = 2.384185791015625e-7; // 2^-22 + +fn mulhilo(a: f32, b: f32) -> vec2 { + let abp = fma(a, b, BIAS); + let hi_pre = floor(abp * W_INV) - W; + let lo0 = fma(a, b, -hi_pre * W); + let underflow = step(lo0, -0.5); + let hi = hi_pre - underflow; + let lo = lo0 + underflow * W; + return vec2(hi, lo); +} + +fn mulhilo2(a: vec2, b: vec2) -> vec4 { + let abp = fma(a, b, vec2(BIAS, BIAS)); + let hi_pre = floor(abp * W_INV) - vec2(W, W); + let lo0 = fma(a, b, -hi_pre * W); + let underflow = step(lo0, vec2(-0.5, -0.5)); + let hi = hi_pre - underflow; + let lo = lo0 + underflow * W; + return vec4(hi.x, lo.x, hi.y, lo.y); +} +`; + export const structs = `struct Point { x: BigInt, y: BigInt, diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_f32.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_f32.template.wgsl new file mode 100644 index 000000000000..2d50925fb5b3 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_f32.template.wgsl @@ -0,0 +1,64 @@ +// f32-limb mirror of `bigint.template.wgsl`. Each limb holds an +// integer-valued f32 in [0, 2^WORD_SIZE_F32) = [0, W). Mirrors the +// subset of `bigint.template.wgsl` needed by `montgomery_product_f32` +// and the f32 field ops in `field/field_f32.template.wgsl` / +// `field/fr_pow_f32.template.wgsl`. + +struct BigIntF32 { + limbs: array +} + +fn bigint_f32_gt(x: ptr, y: ptr) -> bool { + for (var idx = 0u; idx < {{ num_limbs }}u; idx ++) { + let i = {{ num_limbs }}u - 1u - idx; + if ((*x).limbs[i] < (*y).limbs[i]) { return false; } + if ((*x).limbs[i] > (*y).limbs[i]) { return true; } + } + return false; +} + +fn bigint_f32_eq(x: ptr, y: ptr) -> bool { + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + if ((*x).limbs[i] != (*y).limbs[i]) { return false; } + } + return true; +} + +// res = a - b. Per-limb borrow lives as an f32 in {0.0, 1.0}; +// `step(diff, -0.5)` is 1.0 iff diff is a negative integer (i.e., +// underflowed). Adding `underflow * W` then canonicalises the limb +// back into [0, W). Returns the final borrow-out (0.0 or 1.0): callers +// that need to know whether a >= b consult this flag, mirroring the +// u32 `bigint_sub`'s return convention. +fn bigint_f32_sub(a: ptr, b: ptr, res: ptr) -> f32 { + var borrow: f32 = 0.0; + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + let diff = (*a).limbs[i] - (*b).limbs[i] - borrow; + let underflow = step(diff, -0.5); + (*res).limbs[i] = diff + underflow * W; + borrow = underflow; + } + return borrow; +} + +// res = a + b. Per-limb carry lives as an f32 in {0.0, 1.0}; `step(W-0.5, sum)` +// is 1.0 iff sum >= W (the bias-free branchless test). Each sum is at most +// 2*(W-1) + 1 = 2^24 - 1, exact in the 24-bit f32 mantissa. Returns the +// final carry-out (0.0 or 1.0) for downstream conditional-reduce logic. 
+fn bigint_f32_add(a: ptr, b: ptr, res: ptr) -> f32 { + var carry: f32 = 0.0; + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + let sum = (*a).limbs[i] + (*b).limbs[i] + carry; + let overflow = step(W - 0.5, sum); + (*res).limbs[i] = sum - overflow * W; + carry = overflow; + } + return carry; +} + +fn bigint_f32_is_zero(x: ptr) -> bool { + for (var i = 0u; i < {{ num_limbs }}u; i ++) { + if ((*x).limbs[i] != 0.0) { return false; } + } + return true; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_f32.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_f32.template.wgsl new file mode 100644 index 000000000000..869da2faf83c --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_f32.template.wgsl @@ -0,0 +1,42 @@ +// Field-mul micro-benchmark, f32 / 12×23-bit limbs (CIOS over f32 FMA). +// Every thread loads one (a, b) pair, runs `k` chained Montgomery +// products (a = montgomery_product_f32(a, b) repeated k times), and +// writes the final `a` back. The host caps `k` at <=100 before passing +// it in. +// +// Loop bounds. The only data-dependent loop is `for (var i = 0u; i < k; i++)` +// where `k = params.y` is a host-uniform capped at 100. The early-out +// `if (tid >= n)` is a guard, not a loop bound, and `n = params.x` is +// host-capped at 2^23 to keep dispatch sizes reasonable. The inner +// `montgomery_product_f32` loops are bounded by the compile-time constant +// `NUM_LIMBS`. + +// Inputs are split into two separate `xs`/`ys` arrays (one BigIntF32 +// per thread per array). This mirrors the working layout used by +// `testMontgomeryProductF32` in wgsl_unit_tests.ts — packing into a +// single `array` produced all-zero outputs on Dawn/Metal +// even when the inputs read back correctly via a passthrough variant, +// so we sidestep struct-in-array layout concerns entirely. +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var ys: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + var b = ys[tid]; + // Chained loop calls the unreduced body — Mont is closed over inputs + // in [0, 2p) since all limbs stay < W. Final conditional_reduce + // canonicalizes the result. + for (var i = 0u; i < k; i = i + 1u) { + a = montgomery_product_f32_unreduced(&a, &b); + } + var p_final = get_p_f32(); + outputs[tid] = conditional_reduce_f32(&a, &p_final); +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_u32.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_u32.template.wgsl new file mode 100644 index 000000000000..f45ca08133c2 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/field_mul_bench_u32.template.wgsl @@ -0,0 +1,36 @@ +// Field-mul micro-benchmark, u32 / 20×13-bit limbs (Mitschabaude-CIOS). +// Every thread loads one (a, b) pair, runs `k` chained Montgomery +// products (a = montgomery_product(a, b) repeated k times), and writes +// the final `a` back. The host caps `k` at <=100 before passing it in. +// +// Loop bounds. The only data-dependent loop is `for (var i = 0u; i < k; i++)` +// where `k = params.y` is a host-uniform capped at 100. The early-out +// `if (tid >= n)` is a guard, not a loop bound, and `n = params.x` is +// host-capped at 2^23 to keep dispatch sizes reasonable. 
The inner +// `montgomery_product` loops are bounded by the compile-time constant +// `NUM_WORDS`. + +// Inputs are split into two separate `xs`/`ys` arrays (one BigInt per +// thread per array). Matches the f32 path's layout for symmetry — the +// `array` packing was found to produce all-zero outputs on +// Dawn/Metal in the f32 variant of this shader; using separate arrays +// sidesteps that concern. +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var ys: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + var b = ys[tid]; + for (var i = 0u; i < k; i = i + 1u) { + a = montgomery_product(&a, &b); + } + outputs[tid] = a; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_f32_22_sos3uv3.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_f32_22_sos3uv3.template.wgsl new file mode 100644 index 000000000000..933594bf4f10 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_f32_22_sos3uv3.template.wgsl @@ -0,0 +1,215 @@ +// 22-bit-limb f32 Montgomery product, sos3uv3 — chain breaking via +// SEPARATE per-slot f32 accumulators (lo and hi). +// +// Inside the inner-j loop: +// - sos3uv2 had a single chain c_hi/c_lo flowing j → j+1 (serial). +// - sos3u32 used i32 accumulators (broken chain but slow int ops). +// - sos3uv3 uses TWO per-slot f32 accumulators (tlo[k], thi[k]). +// Each j writes UNIQUE tlo[j-1] and thi[j] — no overlap across j's, +// no carry chain. Stays in f32 throughout (no int conversions). +// +// Per-slot bound: tlo[k] = s_old[k+1] + 1 lo (≤ 2.5W). thi[k] = 1 hi +// (≤ W). Combined at drain: 3.5W < 2²³·⁸ < 2²⁴ — fits f32 exactly. +// +// Drain at end of each outer iter: combine tlo[k] + thi[k] + carry, +// floor-based bias_split, write canonical s[k]. Serial across cols. +// +// === Structure === +// Outer iter i=0 is unrolled (no shifted-in s values to read). Iters +// 1..N-1 are in a runtime `for` loop with a separate slot_init that pulls +// in s1..s11. Inner-j pairs and drain cols are mustache-expanded — keeps +// the named locals (tlo0..tlo11, thi0..thi11) which Metal can register- +// allocate instead of spilling to thread-private memory. 
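+// (Editorial note, spelling out the bound above with concrete numbers:
+// W = 2^22 = 4194304, so the combined drain bound 3.5*W = 14680064 sits
+// well below 2^24 = 16777216, and f32 represents every integer up to 2^24
+// exactly, so the bias_split_f32_le4w drain incurs no rounding.)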
+ +const NUM_LIMBS: u32 = {{ num_limbs }}u; +const N0: f32 = {{ n0 }}; +const N0_SCALED: f32 = {{ n0_scaled }}; +const TWO_W: f32 = 8388608.0; // 2 * 2^22 + +fn get_p_f32() -> BigIntF32 { + var p: BigIntF32; +{{{ p_limbs_f32 }}} + return p; +} + +fn bias_split_f32_le2w(x: f32) -> vec2 { + let hi = step(W, x); + let lo = fma(-hi, W, x); + return vec2(hi, lo); +} + +fn bias_split_f32_le4w(x: f32) -> vec2 { + let hi = floor(x * W_INV); + let lo = fma(-hi, W, x); + return vec2(hi, lo); +} + +fn mulhilo_sos3_corr(a: f32, a_scaled: f32, b: f32) -> vec2 { + let hi_off_inner = fma(a_scaled, b, W); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(-W, hi_off, BIAS); + let lo0 = fma(a, b, neg_hi_w); + let hi_pre = hi_off - W; + let underflow = step(lo0, -0.5); + let hi = hi_pre - underflow; + let lo = fma(underflow, W, lo0); + return vec2(hi, lo); +} + +fn mulhilo_sos3_qi_lo(a: f32, b_scaled: f32, b: f32) -> f32 { + let hi_off_inner = fma(a, b_scaled, W); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(-W, hi_off, BIAS); + let lo0 = fma(a, b, neg_hi_w); + let underflow = step(lo0, -0.5); + let lo = fma(underflow, W, lo0); + return lo; +} + +fn mulhilo_sos3_2_v2(a: vec2, a_scaled: vec2, b: vec2) -> vec4 { + let hi_off_inner = fma(a_scaled, b, vec2(W, W)); + let hi_off = floor(hi_off_inner); + let neg_hi_w = fma(vec2(-W, -W), hi_off, vec2(BIAS, BIAS)); + let lo = fma(a, b, neg_hi_w); + return vec4(hi_off.x, lo.x, hi_off.y, lo.y); +} + +fn montgomery_product_f32_unreduced(x: ptr, y: ptr) -> BigIntF32 { + var s0: f32 = 0.0; + var s1: f32 = 0.0; + var s2: f32 = 0.0; + var s3: f32 = 0.0; + var s4: f32 = 0.0; + var s5: f32 = 0.0; + var s6: f32 = 0.0; + var s7: f32 = 0.0; + var s8: f32 = 0.0; + var s9: f32 = 0.0; + var s10: f32 = 0.0; + var s11: f32 = 0.0; + + var p = get_p_f32(); + + // ===== Outer iter i=0 (unrolled — s[*] all zero, no shift-in). ===== + { + let x_i = (*x).limbs[0]; + let x_i_scaled = x_i * W_INV; + let xy0 = mulhilo_sos3_corr(x_i, x_i_scaled, (*y).limbs[0]); + let qi = mulhilo_sos3_qi_lo(xy0.y, N0_SCALED, N0); + let qi_scaled = qi * W_INV; + + let c_cancel = step(0.5, xy0.y); + let qp0_lo = c_cancel * (W - xy0.y); + let qp0_hi = fma(qi, p.limbs[0], -qp0_lo) * W_INV; + + // j=0's contribution to slot 0 (= position 1 unshifted). + let init_slot0 = xy0.x + qp0_hi + c_cancel; + + // Slot init for iter 0: tlo[0] gets init_slot0; everything else 0. +{{#slot_inits_i0}} + var {{name}}: f32 = {{init_expr}}; +{{/slot_inits_i0}} + + let xq_scaled = vec2(x_i_scaled, qi_scaled); + let xq = vec2(x_i, qi); + + // Inner-j writes: pair (y[j], p[j]) mulhilo, lo to tlo[j-1], hi to thi[j]. +{{#inner_pairs}} + { + let mh = mulhilo_sos3_2_v2(xq, xq_scaled, vec2((*y).limbs[{{j}}u], p.limbs[{{j}}u])); + let lo_sum = mh.y + mh.w; + let hi_sum = mh.x + mh.z - TWO_W; + tlo{{km1}} = tlo{{km1}} + lo_sum; + thi{{k}} = thi{{k}} + hi_sum; + } +{{/inner_pairs}} + + // Drain: combine tlo[k]+thi[k]+carry, split. Serial across cols. + var carry: f32 = 0.0; +{{#drain_cols}} + { + let sum = tlo{{k}} + thi{{k}} + carry; + let split = bias_split_f32_le4w(sum); + s{{k}} = split.y; + carry = split.x; + } +{{/drain_cols}} + } + + // ===== Outer iter i=1..NUM_LIMBS-1 (runtime loop). 
===== + for (var i = 1u; i < NUM_LIMBS; i = i + 1u) { + let x_i = (*x).limbs[i]; + let x_i_scaled = x_i * W_INV; + + let xy0 = mulhilo_sos3_corr(x_i, x_i_scaled, (*y).limbs[0]); + let sum0 = s0 + xy0.y; + let sum0_s = bias_split_f32_le2w(sum0); + let qi = mulhilo_sos3_qi_lo(sum0_s.y, N0_SCALED, N0); + let qi_scaled = qi * W_INV; + + let c_cancel = step(0.5, sum0_s.y); + let qp0_lo = c_cancel * (W - sum0_s.y); + let qp0_hi = fma(qi, p.limbs[0], -qp0_lo) * W_INV; + + // j=0's contribution to slot 0 + s_old[1] shifted in. + let init_slot0 = xy0.x + qp0_hi + c_cancel + sum0_s.x; + + // Slot init for i>=1: tlo[k] pulls in s[k+1] as the shifted carry. +{{#slot_inits_general}} + var {{name}}: f32 = {{init_expr}}; +{{/slot_inits_general}} + + let xq_scaled = vec2(x_i_scaled, qi_scaled); + let xq = vec2(x_i, qi); + +{{#inner_pairs}} + { + let mh = mulhilo_sos3_2_v2(xq, xq_scaled, vec2((*y).limbs[{{j}}u], p.limbs[{{j}}u])); + let lo_sum = mh.y + mh.w; + let hi_sum = mh.x + mh.z - TWO_W; + tlo{{km1}} = tlo{{km1}} + lo_sum; + thi{{k}} = thi{{k}} + hi_sum; + } +{{/inner_pairs}} + + var carry: f32 = 0.0; +{{#drain_cols}} + { + let sum = tlo{{k}} + thi{{k}} + carry; + let split = bias_split_f32_le4w(sum); + s{{k}} = split.y; + carry = split.x; + } +{{/drain_cols}} + } + + var s: BigIntF32; + s.limbs[0] = s0; + s.limbs[1] = s1; + s.limbs[2] = s2; + s.limbs[3] = s3; + s.limbs[4] = s4; + s.limbs[5] = s5; + s.limbs[6] = s6; + s.limbs[7] = s7; + s.limbs[8] = s8; + s.limbs[9] = s9; + s.limbs[10] = s10; + s.limbs[11] = s11; + return s; +} + +fn montgomery_product_f32(x: ptr, y: ptr) -> BigIntF32 { + var s = montgomery_product_f32_unreduced(x, y); + var p = get_p_f32(); + return conditional_reduce_f32(&s, &p); +} + +fn conditional_reduce_f32(x: ptr, y: ptr) -> BigIntF32 { + if (bigint_f32_gt(x, y) || bigint_f32_eq(x, y)) { + var res: BigIntF32; + let _borrow = bigint_f32_sub(x, y, &res); + return res; + } + return *x; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_karat_yuval.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_karat_yuval.template.wgsl new file mode 100644 index 000000000000..2deb60efe5da --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mont_pro_product_karat_yuval.template.wgsl @@ -0,0 +1,192 @@ +// u32 Montgomery product via recursive KARATSUBA + YUVAL reduction. +// Fully unrolled — all indices compile-time constants so WGSL→MSL can +// SROA the temp slots into registers instead of thread-local memory. 
+// +// === Layout === +// 4 input chunks per operand (5 limbs each, named locals): +// x_lo_lo = x[0..4], x_lo_hi = x[5..9] +// x_hi_lo = x[10..14], x_hi_hi = x[15..19] +// (same naming for y) +// +// 9 sub-sub-products (each 5×5 schoolbook → 9 output limbs): +// For P_lo = x[0..9]·y[0..9]: pp_lo_LL, pp_lo_HH, pp_lo_C +// For P_hi = x[10..19]·y[10..19]: pp_hi_LL, pp_hi_HH, pp_hi_C +// For P_cr = (x[0..9]+x[10..19])·(y[0..9]+y[10..19]): +// pp_cr_LL, pp_cr_HH, pp_cr_C +// +// 3 outer sub-products (inner Karat combine): +// P_lo[k] = pp_lo_LL[k] + (pp_lo_C[k-5] - pp_lo_LL[k-5] - pp_lo_HH[k-5]) + pp_lo_HH[k-10] +// P_hi[k] = pp_hi_LL[k] + (pp_hi_C[k-5] - pp_hi_LL[k-5] - pp_hi_HH[k-5]) + pp_hi_HH[k-10] +// P_cr[k] = pp_cr_LL[k] + (pp_cr_C[k-5] - pp_cr_LL[k-5] - pp_cr_HH[k-5]) + pp_cr_HH[k-10] +// +// Outer combine into temp[0..38]: +// temp[k] += P_lo[k] for k in [0, 18] +// temp[k+20] += P_hi[k] for k in [0, 18] +// temp[k+10] += P_cr[k] - P_lo[k] - P_hi[k] for k in [0, 18] +// +// Yuval reduce (N-1 calls + 1 standard): +// for i in 0..N-1: +// t_mask = temp[i] & MASK; carry = temp[i] >> WORD_SIZE +// temp[i+1] += t_mask·R_INV[0] + carry +// temp[i+1+j] += t_mask·R_INV[j] for j in [1, N) +// standard reduce at i=N-1: k = (temp[i]&MASK)·N0 & MASK; +// temp[i+j] += k·p[j] for j in [0, N) +// plus the (temp[i]>>WORD_SIZE) carry folded into temp[i+1] +// +// Final canonicalization: single carry pass over the upper N slots. +// +// === Why no drains in the multiply phase === +// One inner sub-product (pp_cr_C slot 4 = 80·W² = 2³²·³²) wraps u32 by +// ~1.25×. This wrap is HARMLESS: the subsequent `pp_cr_C - pp_cr_LL - +// pp_cr_HH` subtraction unwinds the wrap via modular arithmetic, giving +// the correct (mathematically non-negative) pp_cr_mid value, which fits +// u32. Every other intermediate fits u32 directly. See +// karat_intermediate_check.mjs for the per-slot proof. +// +// Final temp[k] math bound: 40·W² = 2³¹·³² < 2³². ✓ Zero drains needed. + +const NUM_WORDS: u32 = {{ num_words }}u; +const WORD_SIZE: u32 = {{ word_size }}u; +const MASK: u32 = {{ mask }}u; +const TWO_POW_WORD_SIZE: u32 = {{ two_pow_word_size }}u; +const N0: u32 = {{ n0 }}u; +const P_INV_MOD_2W: u32 = {{ p_inv_mod_2w }}u; + +// r_inv = 2^{-WORD_SIZE} mod p, as N individual constants (NOT array — +// naga rejects runtime indexing into a const array, and the unrolled +// Yuval below uses each limb at a compile-time position anyway). +{{#r_inv_consts}} +const R_INV_{{idx}}: u32 = {{val}}u; +{{/r_inv_consts}} + +fn get_p() -> BigInt { + var p: BigInt; +{{{ p_limbs }}} + return p; +} + +fn montgomery_product(x_ptr: ptr, y_ptr: ptr) -> BigInt { + var p = get_p(); + + // === Input load: 40 named locals, one per limb. === +{{#input_loads}} + let {{name}}: u32 = (*{{ptr}}).limbs[{{k}}u]; +{{/input_loads}} + + // === Sums for inner Karatsuba: === + // a_lo_sum[k] = x_lo_lo[k] + x_lo_hi[k], a_hi_sum = x_hi_lo + x_hi_hi. + // For P_cross outer: a_cr_lo[k] = x_lo_lo[k] + x_hi_lo[k], a_cr_hi = + // x_lo_hi + x_hi_hi. Then inner cross sum: a_cr_sum = a_cr_lo + a_cr_hi. +{{#sum_lets}} + let {{name}}: u32 = {{lhs}} + {{rhs}}; +{{/sum_lets}} + + // === 9 sub-sub-products (5×5 schoolbook each). === + // Each output slot is a single `let` with a sum-of-products expression. 
+{{#schoolbooks}} + // --- {{label}}: out_prefix={{out_prefix}}, a={{a_prefix}}, b={{b_prefix}} --- + let {{out_prefix}}_0: u32 = {{a_prefix}}_0 * {{b_prefix}}_0; + let {{out_prefix}}_1: u32 = {{a_prefix}}_0 * {{b_prefix}}_1 + {{a_prefix}}_1 * {{b_prefix}}_0; + let {{out_prefix}}_2: u32 = {{a_prefix}}_0 * {{b_prefix}}_2 + {{a_prefix}}_1 * {{b_prefix}}_1 + {{a_prefix}}_2 * {{b_prefix}}_0; + let {{out_prefix}}_3: u32 = {{a_prefix}}_0 * {{b_prefix}}_3 + {{a_prefix}}_1 * {{b_prefix}}_2 + {{a_prefix}}_2 * {{b_prefix}}_1 + {{a_prefix}}_3 * {{b_prefix}}_0; + let {{out_prefix}}_4: u32 = {{a_prefix}}_0 * {{b_prefix}}_4 + {{a_prefix}}_1 * {{b_prefix}}_3 + {{a_prefix}}_2 * {{b_prefix}}_2 + {{a_prefix}}_3 * {{b_prefix}}_1 + {{a_prefix}}_4 * {{b_prefix}}_0; + let {{out_prefix}}_5: u32 = {{a_prefix}}_1 * {{b_prefix}}_4 + {{a_prefix}}_2 * {{b_prefix}}_3 + {{a_prefix}}_3 * {{b_prefix}}_2 + {{a_prefix}}_4 * {{b_prefix}}_1; + let {{out_prefix}}_6: u32 = {{a_prefix}}_2 * {{b_prefix}}_4 + {{a_prefix}}_3 * {{b_prefix}}_3 + {{a_prefix}}_4 * {{b_prefix}}_2; + let {{out_prefix}}_7: u32 = {{a_prefix}}_3 * {{b_prefix}}_4 + {{a_prefix}}_4 * {{b_prefix}}_3; + let {{out_prefix}}_8: u32 = {{a_prefix}}_4 * {{b_prefix}}_4; +{{/schoolbooks}} + + // === Inner Karatsuba combine: form P_lo, P_hi, P_cross as named locals. === + // For each, 19 outputs combining the 3 sub-sub-products. The mid term + // uses unsigned subtraction; underflow is impossible because the + // algebraic identity P_mid[m] = Σ a_lo·b_hi + a_hi·b_lo is non-neg + // per-limb at the lazy values. +{{#inner_combines}} + // --- {{label}}: out_prefix={{out_prefix}}, ll={{ll_prefix}}, hh={{hh_prefix}}, c={{c_prefix}} --- + let {{out_prefix}}_0: u32 = {{ll_prefix}}_0; + let {{out_prefix}}_1: u32 = {{ll_prefix}}_1; + let {{out_prefix}}_2: u32 = {{ll_prefix}}_2; + let {{out_prefix}}_3: u32 = {{ll_prefix}}_3; + let {{out_prefix}}_4: u32 = {{ll_prefix}}_4; + let {{out_prefix}}_5: u32 = {{ll_prefix}}_5 + {{c_prefix}}_0 - {{ll_prefix}}_0 - {{hh_prefix}}_0; + let {{out_prefix}}_6: u32 = {{ll_prefix}}_6 + {{c_prefix}}_1 - {{ll_prefix}}_1 - {{hh_prefix}}_1; + let {{out_prefix}}_7: u32 = {{ll_prefix}}_7 + {{c_prefix}}_2 - {{ll_prefix}}_2 - {{hh_prefix}}_2; + let {{out_prefix}}_8: u32 = {{ll_prefix}}_8 + {{c_prefix}}_3 - {{ll_prefix}}_3 - {{hh_prefix}}_3; + let {{out_prefix}}_9: u32 = {{c_prefix}}_4 - {{ll_prefix}}_4 - {{hh_prefix}}_4; + let {{out_prefix}}_10: u32 = {{c_prefix}}_5 - {{ll_prefix}}_5 - {{hh_prefix}}_5 + {{hh_prefix}}_0; + let {{out_prefix}}_11: u32 = {{c_prefix}}_6 - {{ll_prefix}}_6 - {{hh_prefix}}_6 + {{hh_prefix}}_1; + let {{out_prefix}}_12: u32 = {{c_prefix}}_7 - {{ll_prefix}}_7 - {{hh_prefix}}_7 + {{hh_prefix}}_2; + let {{out_prefix}}_13: u32 = {{c_prefix}}_8 - {{ll_prefix}}_8 - {{hh_prefix}}_8 + {{hh_prefix}}_3; + let {{out_prefix}}_14: u32 = {{hh_prefix}}_4; + let {{out_prefix}}_15: u32 = {{hh_prefix}}_5; + let {{out_prefix}}_16: u32 = {{hh_prefix}}_6; + let {{out_prefix}}_17: u32 = {{hh_prefix}}_7; + let {{out_prefix}}_18: u32 = {{hh_prefix}}_8; +{{/inner_combines}} + + // === Outer combine: initialize temp[0..39]. === + // Slots [0,18] = P_lo. Slots [20,38] = P_hi. Slots [10,28] += P_cr - P_lo - P_hi. + // We initialize as `var` (mutable) so the Yuval phase can mutate. +{{#outer_init}} + var t{{slot}}: u32 = {{init_expr}}; +{{/outer_init}} + + // === Outer Karatsuba cross combine: temp[k+10] += P_cr[k] - P_lo[k] - P_hi[k]. 
=== + // Unsigned subtraction is safe per the outer-level algebraic identity: + // P_mid_outer[k] = Σ x_lo·y_hi + x_hi·y_lo is non-neg per-limb. +{{#outer_cross}} + t{{slot}} = t{{slot}} + p_cr_{{k}} - p_lo_{{k}} - p_hi_{{k}}; +{{/outer_cross}} + + // === Yuval reduce: 19 Yuval calls + 1 standard. === + // Each call extracts t_mask & carry from t{i}, then accumulates + // t_mask·R_INV[j] into t{i+1+j} for j=0..N-1 (the carry folds into + // the j=0 write). +{{#yuval_iters}} + { + let t_mask: u32 = t{{i}} & MASK; + let carry: u32 = t{{i}} >> WORD_SIZE; +{{#writes}} + t{{slot}} = t{{slot}} + t_mask * R_INV_{{r_idx}}{{#first}} + carry{{/first}}; +{{/writes}} + } +{{/yuval_iters}} + + // Standard Mont reduce for the last iter (i = N-1 = 19). + { + let t_mask: u32 = t{{i_std}} & MASK; + let k_std: u32 = (t_mask * N0) & MASK; +{{#standard_writes}} + t{{slot}} = t{{slot}} + k_std * p.limbs[{{p_idx}}u]{{#first}} + (t{{i_std}} >> WORD_SIZE){{/first}}; +{{/standard_writes}} + } + + // === Final canonicalization (single carry pass over t20..t39). === + var c: u32 = 0u; +{{#final_drain}} + { + let v: u32 = t{{slot}} + c; + c = v >> WORD_SIZE; + t{{slot}} = v & MASK; + } +{{/final_drain}} + + // === Repack canonical limbs into BigInt and return. === + var s: BigInt; +{{#extract}} + s.limbs[{{out_k}}u] = t{{src_slot}}; +{{/extract}} + + return conditional_reduce(&s, &p); +} + +fn conditional_reduce(x: ptr, y: ptr) -> BigInt { + var x_gt_y = bigint_gt(x, y); + var x_eq_y = bigint_eq(x, y); + if (x_gt_y == 1u || x_eq_y) { + var res: BigInt; + bigint_sub(x, y, &res); + return res; + } + return *x; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mulhilo_22.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mulhilo_22.wgsl new file mode 100644 index 000000000000..343bc8d35bf6 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/montgomery/mulhilo_22.wgsl @@ -0,0 +1,43 @@ +// 22-bit-limb sibling of `mulhilo.wgsl`. Same FMA-bias trick, different +// constants (W = 2^22, BIAS = 2^44). The two files are never compiled +// into the same WGSL module — callers pick one bundle or the other. +// +// Inputs: a, b in [0, 2^22) stored as f32 (integer-valued). +// Output: vec2(hi, lo) with a*b == hi * 2^22 + lo, hi, lo in [0, 2^22). +// +// Math (load-bearing): +// fma(a, b, BIAS) computes a*b + 2^44 with infinite intermediate +// precision and rounds ONCE to f32. Inside [2^44, 2^45) the f32 ULP is +// exactly 2^21, so the rounded result lies on a multiple of 2^21 +// (not 2^22!). Dividing by W and flooring effectively rounds to W, +// giving a non-negative integer q/W < W. The lo correction follows the +// same step()-based underflow scheme as the 23-bit variant. +// +// Important difference from 23-bit: at 22-bit, a*b in [0, 2^44) plus +// BIAS = 2^44 sits exactly at the boundary [2^44, 2^45]. Inside that +// range ULP is 2^21, so `floor(abp * W_INV)` rounds to the nearest +// multiple of W/2 — but then `lo0 = fma(a, b, -hi_pre * W)` recovers +// exactly. The step() correction handles the half-W slop. 
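+// Worked check (editorial illustration, not part of the original comment):
+// take a = b = 2^22 - 1 = 4194303, so a*b = 2^44 - 2^23 + 1. Then
+// fma(a, b, BIAS) = 2^45 - 2^23 + 1, which rounds (ULP 2^21 in that range)
+// to 2^45 - 2^23; floor(abp * W_INV) = 2^23 - 2, so hi_pre = 2^22 - 2.
+// fma(a, b, -hi_pre * W) then yields exactly 1 with no underflow, giving
+// hi = 4194302 and lo = 1, i.e. a*b = hi * 2^22 + lo exactly.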
+const BIAS: f32 = 17592186044416.0; // 2^44 +const W: f32 = 4194304.0; // 2^22 +const W_INV: f32 = 2.384185791015625e-7; // 2^-22 + +fn mulhilo(a: f32, b: f32) -> vec2 { + let abp = fma(a, b, BIAS); + let hi_pre = floor(abp * W_INV) - W; + let lo0 = fma(a, b, -hi_pre * W); + let underflow = step(lo0, -0.5); + let hi = hi_pre - underflow; + let lo = lo0 + underflow * W; + return vec2(hi, lo); +} + +fn mulhilo2(a: vec2, b: vec2) -> vec4 { + let abp = fma(a, b, vec2(BIAS, BIAS)); + let hi_pre = floor(abp * W_INV) - vec2(W, W); + let lo0 = fma(a, b, -hi_pre * W); + let underflow = step(lo0, vec2(-0.5, -0.5)); + let hi = hi_pre - underflow; + let lo = lo0 + underflow * W; + return vec4(hi.x, lo.x, hi.y, lo.y); +} diff --git a/barretenberg/ts/yarn.lock b/barretenberg/ts/yarn.lock index b434e0a41ef8..d7e848e84fe8 100644 --- a/barretenberg/ts/yarn.lock +++ b/barretenberg/ts/yarn.lock @@ -31,6 +31,7 @@ __metadata: msgpackr: "npm:^1.11.2" mustache: "npm:^4.2.0" pako: "npm:^2.1.0" + playwright-core: "npm:^1.59.1" prettier: "npm:^3.5.3" ts-jest: "npm:^29.4.0" ts-loader: "npm:^9.4.2" @@ -4797,6 +4798,15 @@ __metadata: languageName: node linkType: hard +"playwright-core@npm:^1.59.1": + version: 1.59.1 + resolution: "playwright-core@npm:1.59.1" + bin: + playwright-core: cli.js + checksum: 10/d27857a6701587c2a9bfa26fed9a5d8c617a392299b99b187f2ddc198d012a1e296449806bc907220debea938152677e8b4d91d304ed00645f762f778de3abec + languageName: node + linkType: hard + "postcss@npm:^8.5.6": version: 8.5.14 resolution: "postcss@npm:8.5.14" From 400d0c1bada4a25a0d7a2b15dcc35a0cde83254c Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Sat, 16 May 2026 19:40:31 +0100 Subject: [PATCH 2/4] feat(bb): use Karatsuba+Yuval as the default Mont mult in WebGPU MSM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Routes the `montgomery_product_funcs` mustache partial through a pre-rendered Karatsuba+Yuval body in every MSM shader that does a base-field multiply (15 callsites: convert_points, smvp, horner, batch_affine_{apply,schedule,finalize_*,init,apply_scatter}, batch_inverse{,_parallel}, bpr, decompress_g1, montgomery_parity). The Karatsuba body benches ~27% faster than the mitschabaude runtime-loop CIOS at n=2^20, k=100 (80 ms vs 109 ms). It exposes the same `fn montgomery_product(x, y) -> BigInt` symbol plus the same `get_p` / `conditional_reduce` helpers and uses the same 20×13-bit limb layout, so the swap is a drop-in change with no callsite churn. The field-mul bench retains both options (`?variant=cios` renders the original template inline, `?variant=karat` reuses the class-level default) so the two bodies can be compared side-by-side. --- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index c87f929b7fc5..3697d208e933 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -93,6 +93,15 @@ export class ShaderManager { public num_limbs_f32_22: number; public n0_f32_22: bigint; public p_limbs_f32_22_str: string; + // Pre-rendered u32 Montgomery product source used as the + // `montgomery_product_funcs` mustache partial by every MSM shader that + // needs a base-field multiply. 
Defaults to the Karatsuba + Yuval body + // (see `renderKaratYuvalMont`), which benches ~27% faster than the + // runtime-loop CIOS at n=2^20, k=100 on Apple GPU. Both bodies expose + // the same `fn montgomery_product(x, y) -> BigInt` symbol and the same + // `get_p` / `conditional_reduce` helpers, so swapping the partial is + // a drop-in change at every callsite. + public mont_product_src: string; public curveConfig: CurveConfig; public recompile = ''; @@ -139,6 +148,11 @@ export class ShaderManager { this.n0_f32_22 = params_f32_22.n0; this.p_limbs_f32_22_str = gen_p_limbs_f32(this.p, this.num_limbs_f32_22, 22); + // Render the Karatsuba+Yuval Mont body once. This is the default + // u32 multiplier used by every MSM shader that includes the + // `montgomery_product_funcs` mustache partial. + this.mont_product_src = this.renderKaratYuvalMont(); + if (force_recompile) { const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32; this.recompile = ` @@ -199,7 +213,7 @@ export class ShaderManager { bigint_funcs, field_funcs, barrett_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, extract_word_from_bytes_le_funcs, }, ); @@ -236,7 +250,7 @@ export class ShaderManager { bigint_funcs, field_funcs, barrett_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, extract_word_from_bytes_le_funcs, }, ); @@ -268,7 +282,7 @@ export class ShaderManager { bigint_funcs, field_funcs, barrett_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, fr_pow_funcs, }, ); @@ -354,7 +368,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, ec_funcs: ec_bn254_funcs, }, @@ -379,7 +393,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, fr_pow_funcs, }, @@ -405,7 +419,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, fr_pow_funcs, }, @@ -446,7 +460,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, }, ); @@ -470,7 +484,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, }, ); @@ -497,7 +511,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, fr_pow_funcs, ec_funcs: ec_bn254_funcs, @@ -525,7 +539,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, }, ); @@ -551,7 +565,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, }, ); @@ -575,7 +589,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, }, ); @@ -630,7 +644,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: this.mont_product_src, field_funcs, ec_funcs: ec_bn254_funcs, }, @@ -657,7 +671,7 @@ export class ShaderManager { { structs, bigint_funcs, - montgomery_product_funcs, + montgomery_product_funcs: 
this.mont_product_src, field_funcs, ec_funcs: ec_bn254_funcs, }, @@ -674,9 +688,12 @@ export class ShaderManager { ): string { const structs_src = mustache.render(structs, { num_words: this.num_words }); const bigint_src = mustache.render(bigint_funcs, {}); + // 'karat' reuses the pre-rendered class-level default; 'cios' renders + // the original mitschabaude template inline so the bench can compare + // both bodies even though karat is the production default. const mont_src = variant === 'karat' - ? this.renderKaratYuvalMont() + ? this.mont_product_src : mustache.render(montgomery_product_funcs, { num_words: this.num_words, word_size: this.word_size, From d3fdb6bddcffbcdb28b11b19f3272d6160136808 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Sun, 17 May 2026 12:23:37 +0100 Subject: [PATCH 3/4] feat(bb): BY field inversion + tree-reduce MSM design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1 LANDED — BY safegcd inversion (fr_inv_by_a, Option A: 20×13-bit, BATCH=26, carry-free apply_matrix): - Production swap-in: wgsl/cuzk/batch_inverse{,_parallel}.template.wgsl call fr_inv_by_a - 1.5× faster than legacy fr_inv (Pornin K=12) at chained-inverse bench - ~8% MSM wall reduction at logN=16 sanity check - TS port (cuzk/bernstein_yang.ts, bernstein_yang_a.ts) + Jest tests (24 passing) - WGSL impls: wgsl/field/by_inverse{,_a}.template.wgsl + wgsl/bigint/bigint_by.template.wgsl Phase 2 EXPLORATORY — multi-window pooled batch_inverse + multi-window BPR: - WPB plumbing in batch_inverse_parallel + dispatch_args + batch_affine.ts - Default WPB=1 (= legacy behavior, no perf change) - BPR_WINDOWS_PER_BATCH knob in bpr_bn254.template.wgsl - Empirical: pooling without growing WG count gives 0% gain — design needs restructure Standalone bench infrastructure: - bench-divsteps, bench-apply-matrix, bench-fr-inv, bench-batch-affine - Each with HTML page + TS dispatcher + Playwright runner under dev/msm-webgpu/scripts/ - profile-sanity.mjs for per-pass GPU time breakdown on the Quick Sanity Check Tree-reduce design (Stage B) for autonomous remote execution: - .claude/plans/msm-tree-reduce.md — full design (adaptive batch sizing, analytical slice partition, 2 distinct phase kernels) - .claude/plans/remote-agent-brief.md — remote agent execution brief Co-authored with Claude. 
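For context on the batch_inverse swap: those kernels amortize one true field
inversion over a whole batch of elements via Montgomery's batch-inversion
trick, which is why a single fr_inv_by_a call per workgroup is enough. Below
is a minimal host-side TypeScript sketch of that trick. It is illustrative
only: `modInverse` (Fermat exponentiation) stands in for the GPU's
fr_inv_by_a, and none of these helper names exist in this patch.

```ts
// One modular inversion + ~3n multiplications instead of n inversions.
function modMul(a: bigint, b: bigint, p: bigint): bigint {
  return (a * b) % p;
}

// Placeholder inversion via Fermat's little theorem: a^(p-2) mod p.
function modInverse(a: bigint, p: bigint): bigint {
  let result = 1n;
  let base = a % p;
  let e = p - 2n;
  while (e > 0n) {
    if (e & 1n) result = modMul(result, base, p);
    base = modMul(base, base, p);
    e >>= 1n;
  }
  return result;
}

// Assumes every xs[i] is nonzero mod p.
function batchInverse(xs: bigint[], p: bigint): bigint[] {
  if (xs.length === 0) return [];
  // Forward pass: prefix[i] = xs[0] * ... * xs[i] mod p.
  const prefix: bigint[] = [];
  let acc = 1n;
  for (const x of xs) {
    acc = modMul(acc, x, p);
    prefix.push(acc);
  }
  // The only true inversion for the whole batch.
  let inv = modInverse(acc, p);
  // Backward pass: peel one element off the running inverse per step.
  const out = new Array<bigint>(xs.length);
  for (let i = xs.length - 1; i > 0; i--) {
    out[i] = modMul(inv, prefix[i - 1], p); // = 1 / xs[i]
    inv = modMul(inv, xs[i], p);            // = 1 / (xs[0] * ... * xs[i-1])
  }
  out[0] = inv;
  return out;
}
```

The GPU kernels organize the two passes differently (the tree-reduce plan in
this patch describes a cooperative per-workgroup scan), but the point is the
same: one real inversion amortized over the whole batch.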
--- .claude/plans/msm-tree-reduce.md | 172 + .claude/plans/msm-webgpu-rewrite.md | 357 ++ .../ts/dev/msm-webgpu/bench-apply-matrix.html | 37 + .../ts/dev/msm-webgpu/bench-apply-matrix.ts | 464 +++ .../ts/dev/msm-webgpu/bench-batch-affine.html | 37 + .../ts/dev/msm-webgpu/bench-batch-affine.ts | 423 +++ .../ts/dev/msm-webgpu/bench-divsteps.html | 37 + .../ts/dev/msm-webgpu/bench-divsteps.ts | 398 +++ .../ts/dev/msm-webgpu/bench-fr-inv.html | 37 + .../ts/dev/msm-webgpu/bench-fr-inv.ts | 472 +++ barretenberg/ts/dev/msm-webgpu/main.ts | 15 + .../msm-webgpu/scripts/bench-apply-matrix.mjs | 219 ++ .../msm-webgpu/scripts/bench-batch-affine.mjs | 212 ++ .../dev/msm-webgpu/scripts/bench-divsteps.mjs | 219 ++ .../dev/msm-webgpu/scripts/bench-fr-inv.mjs | 239 ++ .../dev/msm-webgpu/scripts/profile-sanity.mjs | 189 ++ .../ts/dev/msm-webgpu/wgsl_unit_tests.ts | 1 + .../ts/src/msm_webgpu/cuzk/batch_affine.ts | 77 +- .../cuzk/batch_affine_bn254.test.ts | 366 ++ .../src/msm_webgpu/cuzk/batch_affine_bn254.ts | 481 +++ .../msm_webgpu/cuzk/bernstein_yang.test.ts | 818 +++++ .../ts/src/msm_webgpu/cuzk/bernstein_yang.ts | 470 +++ .../src/msm_webgpu/cuzk/bernstein_yang_a.ts | 247 ++ barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts | 63 +- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 388 ++- barretenberg/ts/src/msm_webgpu/cuzk/utils.ts | 104 + barretenberg/ts/src/msm_webgpu/msm.ts | 46 +- .../src/msm_webgpu/wgsl/_generated/shaders.ts | 3010 ++++++++++++++++- .../wgsl/bigint/bigint_by.template.wgsl | 341 ++ .../cuzk/apply_matrix_bench.template.wgsl | 86 + .../batch_affine_dispatch_args.template.wgsl | 19 +- .../wgsl/cuzk/batch_inverse.template.wgsl | 4 +- .../cuzk/batch_inverse_parallel.template.wgsl | 307 +- .../cuzk/bench_batch_affine.template.wgsl | 212 ++ .../wgsl/cuzk/bpr_bn254.template.wgsl | 163 +- .../wgsl/cuzk/divsteps_bench.template.wgsl | 44 + .../wgsl/cuzk/fr_inv_bench.template.wgsl | 50 + .../wgsl/field/by_inverse.template.wgsl | 1075 ++++++ .../wgsl/field/by_inverse_a.template.wgsl | 669 ++++ .../wgsl/field/fr_pow.template.wgsl | 22 + 40 files changed, 12234 insertions(+), 356 deletions(-) create mode 100644 .claude/plans/msm-tree-reduce.md create mode 100644 .claude/plans/msm-webgpu-rewrite.md create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-divsteps.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html create mode 100644 barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts create mode 100755 barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs create mode 100644 barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs create mode 100755 barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs create mode 100755 barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs create mode 100644 barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs create mode 100644 barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts create mode 100644 barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts create mode 100644 barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts create mode 100644 barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts create mode 100644 
barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/apply_matrix_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bench_batch_affine.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/divsteps_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl create mode 100644 barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse_a.template.wgsl diff --git a/.claude/plans/msm-tree-reduce.md b/.claude/plans/msm-tree-reduce.md new file mode 100644 index 000000000000..042ef4a51693 --- /dev/null +++ b/.claude/plans/msm-tree-reduce.md @@ -0,0 +1,172 @@ +# Stage B — Tree-reduce per bucket with adaptive batch sizing + +> Replaces the current SMVP round-loop (`smvp_batch_affine_gpu` + 5 shaders) with +> a tree-reduce structure that scales logarithmically in max bucket population +> instead of linearly. Designed for skewed real-world ZK workloads where the +> current round-loop's MAX_ROUNDS bound is dominated by a few heavy buckets. + +## Constants + +``` +SWEET_B = 1024 // peak per-pair throughput (24.4 ns/pair from bench) +MIN_B = 32 // floor: TPB=32, 1 SIMD group, no cross-SIMD barriers, 1.56× sweet cost +TARGET_THREADS = 40_000 // Apple Silicon (M-series Pro/Max) resident thread budget +TPB_DEFAULT = 64 +TPB_MIN_B = 32 // matches Apple's SIMD group width +MAX_PHASES = 10 // recursion safety cap; pre-pass usually computes exact depth +``` + +## Adaptive batch sizing + +``` +function pickBatch(total_adds): + candidate_B = total_adds / (TARGET_THREADS / TPB_DEFAULT) # = total_adds / 625 + + if candidate_B >= SWEET_B: # plenty of work + return (SWEET_B, ceil(total_adds / SWEET_B), 64) + elif candidate_B >= 64: # mid: largest pow-2 ≤ candidate + B = floor_pow2(candidate_B) + return (B, ceil(total_adds / B), 64) + else: # tail: floor at MIN_B with TPB=32 + return (MIN_B, ceil(total_adds / MIN_B), 32) +``` + +## Phase structure + +### Pre-pass kernel (per phase) + +One small dispatch. Inputs: sorted schedule (Phase 1) or partials buffer (Phase ≥2). For each entry: +1. Determine WG slice membership via index-partitioning of `total_adds`. +2. Flag if entry is first of its bucket in its WG slice. +3. Flag if entry pairs with the next entry (same bucket, both in same WG slice). +4. Emit pair (idx_a, idx_b) to per-WG pair-list slot. + +Then host-side prefix sums produce: +- `wg_pair_offset[]`, `wg_pair_count[]` — pair-list slice per WG +- `wg_output_offset[]`, `wg_output_count[]` — output partials slice per WG +- `wg_first_bucket[]` — bucket_id of first partial (for cross-WG boundary detection) +- `max_pop_remaining` — if 0, no more dup buckets, terminate + +No atomics. Per-entry kernel work is O(1); host prefix-sum is O(num_WGs) = O(1000) = trivial. + +### Phase 1: per-WG slice batch-affine + +One dispatch, `num_WGs` workgroups of TPB threads (from `pickBatch`). Inputs: +- Bucket-sorted schedule +- Pre-computed pair list + +Each WG: +1. Reads its `wg_pair_count` pre-computed pairs from `pair_list[wg_pair_offset[wg_id] : wg_pair_offset[wg_id] + wg_pair_count[wg_id]]` +2. For each pair (a, b): loads `P = points[scalar_idx_a]` (with sign-flip per SCHEDULE_SIGN_BIT), `Q = points[scalar_idx_b]`, computes `delta_x = Q.x - P.x` +3. Cooperative Phase A/B/C/D batch inverse (workgroup-shared scan, 1 fr_inv_by_a per WG) +4. Per-pair: compute slope, R = P + Q +5. 
Compaction: each pair's result is the partial for some bucket. Adjacent same-bucket pairs in the slice get combined into running sum; final partials written to `output[wg_output_offset[wg_id] + slot]` with `bucket_id` tag. + +Single fr_inv per WG amortises over `wg_pair_count` ≈ B pair-adds. + +### Phase ≥2: tree-reduce on partials + +Re-sort phase 1's output by bucket_id globally (use existing transpose pattern — fast on GPU). Then re-run pre-pass + Phase 1 kernel on the bucket-sorted partials buffer. + +Key difference from Phase 1: load is `partials[idx]` (a point) instead of `points[scalar_idx]` (a point with sign). Faster — no negation per load. + +### Phase final: BPR / Horner + +After all phases collapse buckets to 1 point each, hand off to existing BPR per window + Horner combine across windows. No change to those. + +## Memory budget (logN=16, N_entries=1.1M, B_active≈272K) + +- `pair_list` (Phase 1): ~825K pairs × 8 bytes = **6.6 MB** +- `wg_*` arrays: 1000 WGs × ~5 × 4 bytes = **20 KB** +- `output partials` (Phase 1): ~325K × 68 bytes = **22 MB** +- `output partials` (Phase 2): ~80K × 68 bytes = **5.5 MB** +- `output partials` (Phase ≥3): rapidly shrinking +- Total scratch: **~35 MB**, well under any device limit + +## Phase count (theoretical) + +For typical 4-entries-per-bucket: `max_pop ≈ 16`, `log2(16/1024) ≤ 0` → **Phase 1 alone resolves it**. + +For skewed (heavy bucket pop=10K): `max_pop = 10000`, `log2(10000/1024) ≈ 4` → **Phase 1 + 3-4 recursion levels**. + +For uniform with sweet B fill: ~5 phases worst case. + +vs current: 32 rounds. **6× fewer dispatches in typical case**. + +## What we save + +- **Dispatch overhead**: 5 phases × 3 dispatches each = 15 vs current 32 rounds × 3 = 96. Saves ~1.6 ms. +- **Late-round amortisation collapse**: gone — adaptive sizing keeps per-WG batch at sweet through phase 5+. +- **Pathological skew**: round count goes from O(max_pop) to O(log max_pop). **The big win for production ZK workloads.** + +## Open implementation decisions + +### Per-WG slice compaction (within phase 1 / phase ≥2) + +Each WG's batch-affine produces `wg_pair_count` result points. These need to be COMPACTED into per-bucket partials (one partial per distinct bucket the WG touched). + +Two sub-options: +- **(a) Within-WG sequential merge**: after batch-affine, one thread walks the pair results, merges adjacent same-bucket results, writes final partials. ~B sequential adds (cheap, 63/64 threads idle but only briefly). +- **(b) Within-WG segmented reduce**: parallel reduction grouping by bucket_id. More complex. + +Going with **(a)** — simpler, the post-merge work is negligible compared to the batch-affine. + +### Re-sort between phases + +Phase k output is grouped by WG; Phase k+1 needs bucket-grouped input. Options: +- **Transpose-style**: use existing `transpose_parallel_{count,scan,scatter}` infrastructure on the new layout. Adds ~3 dispatches per phase. +- **Per-WG outputs are SORTED by bucket already** (since schedule was bucket-sorted). Just need a parallel MERGE of K sorted lists. O(N log K). Cheap. + +Going with **merge** — fewer dispatches. + +### Pair-list pre-pass + +Single dispatch, one thread per schedule entry. Per entry: +- Determine WG = `entry_idx * num_WGs / total_adds_density` (uses precomputed running-adds index) +- Check predecessor entry: same bucket + same WG slice → emit pair (predecessor, self) to per-WG slot + +Per-WG pair slot allocation: pre-pre-pass counts per-WG pair count, host prefix-sums. + +So phase structure is actually: +1. 
count-pass — count pairs per WG (1 atomic per WG, only num_WGs increments, low contention) +2. host prefix-sum — compute pair_offsets +3. fill-pass — write pairs to per-WG slots (one atomic per WG for local cursor, or use 2-thread cooperation to make atomicLess) +4. phase 1 batch-affine + +Actually atomics per-WG are TRIVIAL (one address per WG = no contention). Acceptable. + +OR even cleaner: do the count-pass and fill-pass in ONE kernel with per-thread local pair-buffer in registers, flushed at WG boundary. Avoids any global atomics. Complexity vs simplicity tradeoff. + +For first implementation: 2-pass pre-pass with per-WG-local atomics. Optimize later. + +## Phase count termination + +Pre-pass computes per-bucket population at phase 0. `MAX_PHASES = ceil(log2(max_pop / SWEET_B)) + 2`. Hard-coded; no runtime detection. + +OR per-phase: if `num_distinct_buckets_output == num_distinct_buckets_input`, no reduction happened → done. + +Use the formula-based approach (cleaner; hardcoded loop count). Loop: +``` +for phase in 0..MAX_PHASES: + if total_adds_remaining == 0: break + dispatch pre-pass + dispatch phase k + re-sort output → input of next phase +``` + +## What this does NOT include (per user scope) + +- Duplicate stripping +- Two bucket widths +- Adaptive bucket width (c stays constant) +- GLV scalar split + +## Estimated impact + +For UNIFORM data at logN=16 (current bench case): +- ba_inverse_Σ: 10.8 ms → estimated 6-8 ms. Saving ~3 ms = 4% MSM wall. +- Dispatch overhead saving: ~1.6 ms. + +For SKEWED data (typical ZK workloads): +- ba_inverse phase: estimated 3-5× faster due to log_2 vs linear in max_pop. +- Could be ~25-40% MSM wall reduction. Numbers depend heavily on actual workload skew profile. diff --git a/.claude/plans/msm-webgpu-rewrite.md b/.claude/plans/msm-webgpu-rewrite.md new file mode 100644 index 000000000000..e270eb9b7027 --- /dev/null +++ b/.claude/plans/msm-webgpu-rewrite.md @@ -0,0 +1,357 @@ +# Plan: WebGPU BN254 MSM rewrite — BY field inversion + multi-window Pippenger + 32-bit point schedule + +> Produced by Plan subagent on 2026-05-16. The execution loop owner (orchestrator +> Claude) iterates phases below via coder + reviewer subagents. Source of truth +> for what "done" means: this plan's acceptance gates. + +## 0. Plan summary + +**Two ideas to implement.** + +**Idea 1 — Replace WGSL `fr_inv` with the Bernstein–Yang (BY) safegcd inversion that the WASM uses.** +Port `Wasm9x29::divsteps` + `Wasm9x29::apply_matrix` to WGSL (9 × 29-bit signed limbs, BATCH=58 inner divsteps per outer iter, NUM_OUTER=13 outer iters with early `g == 0` break, REDUCE_INTERVAL=4). Each outer iter folds 58 divsteps into one 2×2 matrix and applies it via a streamed schoolbook with limb-by-limb carry. Target: at least 2× wall reduction on the `fr_inv` critical path vs the existing jumpy safegcd `fr_inv`; ideally 3–5×. + +**Idea 2 — Multi-window batched Pippenger + 32-bit point schedule.** +- Replace the per-round bucket-cursor + atomic pair counter design with a bucket-sorted 32-bit schedule built via histogram → per-window prefix-sum → scatter (Stages 1/2/3/4 of the WASM). +- Make the batch-affine reduce phase consume `NUM_WINDOWS_PER_BATCH × num_columns` pairs in one batched inversion so the inversion amortises over both buckets AND windows. +- Extend `batch_inverse_parallel`'s workgroup-Z dimension to `num_subtasks × NUM_WINDOWS_PER_BATCH`. +- Schedule entry layout matches the WASM's `Constantine` packed digit: bit 31 = sign, bits 0..28 = scalar_idx (29-bit payload). 
Dedup-redirect / dedup-skip bits exist in the encoding but are unused. + +**Out of scope (explicit per spec):** duplicate stripping, two bucket widths for one MSM, adaptive c. + +--- + +## 1. Required reading + +The first coding agent MUST read these before writing any code. + +**WASM reference (source of ideas):** +1. `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse.hpp` +2. `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp` — `Wasm9x29` (closest to WGSL target) +3. `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp`: + - lines 1–350 (file header, `get_scalar_slice_low`, `compute_constantine_slice_params`, `get_constantine_packed_digit`) + - lines 540–710 (schedule entry bit constants, `VariableWindowSchedule`, `RegionView`, cost model) + - lines 1620–1800 (`pippenger_round_parallel_jacobian_fast` — single-thread textbook structural reference) + - lines 2671–2830 (entry to `pippenger_round_parallel`, Arena setup, Phase 1) + - lines 3780–4080 (Stage 1 histogram, Stage 2/3 bucket-offset, Stage 4 scatter; skip Phase A dedup body) + - lines 4210–4550 (Stage 6 partition, Stage 6a/6b bucket reduction across `windows_in_batch`) + - lines 4550–4610 (per-region dispatch driver, lower/upper regions, batch loop) + +**Existing WebGPU MSM (target codebase):** +4. `barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl` +5. `barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint.template.wgsl` +6. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl` +7. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl` +8. `barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts` +9. `barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts` +10. `barretenberg/ts/src/msm_webgpu/msm.ts` lines 540–924 +11. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_schedule.template.wgsl` +12. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_apply_scatter.template.wgsl` +13. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_init.template.wgsl`, `batch_affine_dispatch_args.template.wgsl`, `batch_affine_finalize_collect.template.wgsl`, `batch_affine_finalize_apply.template.wgsl` +14. `barretenberg/ts/src/msm_webgpu/cuzk/bn254.ts` (host BigInt reference) +15. `barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts`, `bench-field-mul.html` +16. `barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs` +17. `barretenberg/ts/dev/msm-webgpu/main.ts` lines 1180–1300 (`Quick sanity check (WebGPU only)` button) + +**Hard constraints for every coding agent (repeated in every task brief):** + +- WebGPU on Apple Silicon Metal is FRAGILE. A wedged shader can require a reboot. +- **Every WGSL loop MUST have a compile-time-constant upper bound** (`for (var i = 0u; i < CONST; i = i + 1u)` where `CONST` is a `const` or substituted Mustache value). Reject any shader that fails this audit. +- For the BY divsteps inner loop: bound is `BATCH = 58` as a `const`. Outer loop bound is `NUM_OUTER = 13`, also a `const`. +- For BY `apply_matrix` streamed schoolbook: bound is `const N: u32 = NUM_LIMBS_BY` — must be a `const`, not a runtime expression. +- Test order: `bench-field-mul` micro-bench first (n=2^10 to 2^14), then **only after green**, the dev-page Sanity Check at logN=16. Never invoke any MSM-runtime harness from Node. 
+- The base-field multiplication (`montgomery_product` Karat+Yuval) just landed; **do not modify it**. +- Never delete the existing `fr_inv` / `fr_inv_plain` / `fr_inv_bgcd`. They stay as A/B fallbacks. + +--- + +## 2. Phase 1 — BY field inversion in WGSL + +### 2.1 Locate the algorithm + +- Driver: `bernstein_yang_inverse.hpp` lines 290–326 (`invert_bernsteinyang19`). +- 9×29-bit engine: `bernstein_yang_inverse_wasm.hpp` lines 1–258. + - `Wasm9x29::divsteps(delta, f_lo, g_lo)` — lines 147–178. + - `Wasm9x29::apply_matrix(m, f, g, d, e, p, p_inv)` — lines 187–255. + - `Wasm9x29::reduce_to_canonical(p)` — lines 125–145. +- Convergence bound: 735 divsteps cited at header lines 26–27. With BATCH=58 → ⌈735/58⌉ = 13 outer iters. + +### 2.2 Iteration count and determinism + +- `NUM_OUTER = 13` hard cap, with early exit on `g == 0`. +- `BATCH = 58` inner divsteps per outer iter. +- Variable-time over branches; BN254 base-field values in our pipeline are public, so OK. +- Fully deterministic for a given input. + +### 2.3 WGSL representation + +Decisive choice: **Option B — `BigIntBY = array` of 29-bit signed limbs.** This matches the WASM and reaches the perf target. + +Conversion on entry/exit between the 20×13-bit `BigInt` and `BigIntBY`. The conversion is ~20 ops each way; per-call cost amortises over NUM_OUTER × BATCH ≈ 750 inner ops + 13 matrix applications. + +### 2.4 New WGSL files / signatures + +Create (Mustache partial `{{> by_inverse_funcs }}`): + +**File: `barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl`** + +Top-level entry (drop-in replacement for existing `fr_inv`): + +```wgsl +// Bernstein-Yang safegcd inverse on 9 × 29-bit signed limbs. +// Input in Montgomery form. Output mont(a^(-1)). +fn fr_inv_by(a: BigInt) -> BigInt +``` + +Required constants and helpers (loop bounds all `const`): + +```wgsl +const BY_NUM_LIMBS: u32 = 9u; +const BY_LIMB_BITS: u32 = 29u; +const BY_LIMB_MASK: u32 = (1u << 29u) - 1u; +const BY_BATCH: u32 = 58u; +const BY_NUM_OUTER: u32 = 13u; +const BY_REDUCE_INTERVAL: u32 = 4u; +const BY_RTC_MAX_ITERS: u32 = 36u; // matches Wasm9x29::reduce_to_canonical + +struct BigIntBY { l: array }; + +fn by_from_bigint(x: BigInt) -> BigIntBY; +fn by_to_bigint(x: BigIntBY) -> BigInt; +fn by_get_p() -> BigIntBY; +fn by_one() -> BigIntBY; +fn by_low_u64_lohi(x: BigIntBY) -> vec2; +fn by_is_zero(x: BigIntBY) -> bool; +fn by_is_negative(x: BigIntBY) -> bool; +fn by_neg(x: BigIntBY) -> BigIntBY; +fn by_normalise(x: ptr); +fn by_reduce_to_canonical(x: ptr, p: ptr); + +// Matrix entries split into (lo: i32, hi: i32) representing i64 values. +// After BATCH=58 divsteps, |entry| ≤ 2^58. +struct Mat { u: i32, v: i32, q: i32, r: i32, u_hi: i32, v_hi: i32, q_hi: i32, r_hi: i32 }; + +fn by_divsteps(delta: ptr, f_lo: vec2, g_lo: vec2) -> Mat; +fn by_apply_matrix_fg(m: Mat, f: ptr, g: ptr); +fn by_apply_matrix_de(m: Mat, d: ptr, e: ptr, + p: ptr, p_inv_lo: u32, p_inv_hi: u32); + +fn fr_inv_by(a: BigInt) -> BigInt; +``` + +`by_divsteps`: transliterate `Wasm9x29::divsteps` lines 147–178. Use `vec2` for the 64-bit `f_lo` and `g_lo` carriers (WGSL has no native i64). Carry the matrix entries `u, v, q, r` as paired `(lo: i32, hi: i32)` because they grow up to 2^58. Loop bound: `for (var i: u32 = 0u; i < BY_BATCH; i = i + 1u) { ... }`. + +`by_apply_matrix_fg` / `by_apply_matrix_de`: transliterate lines 196–254. Each per-limb `m_lo * limb` is an i58, NOT i32. Define a single safe `signed_mul_split(a: i32, b: i32) -> vec2` helper bounded to |a|, |b| ≤ 2^29 and reuse everywhere. 
The coding agent picks the exact partial-product splits; the contract is only that each partial fits in i32. + +### 2.5 Test harness — `fr_inv` micro-bench + +Add `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl` (mirrors `field_mul_bench_u32.template.wgsl`). Per-thread chained `fr_inv_` `k` times, write to outputs. + +Host-side: `gen_fr_inv_bench_shader(workgroup_size, variant)` in `shader_manager.ts`, `--variant fr_inv_by` whitelisted in `bench-field-mul.mjs` and `bench-field-mul.ts`. Reference: `modInverse` from `cuzk/bn254.ts` with Mont conversion. + +**Acceptance criteria for Phase 1:** +1. `bench-field-mul.mjs --path u32 --variant fr_inv_by --n 1024 --k 1 --validate-n 1024` → all 1024 match host reference. +2. `--n 65536 --k 10` runs to completion (no hang, no `[shader fr_inv_bench] error:` console message). +3. `fr_inv_by` ≥ 2× faster than `fr_inv` median wall (target 3–5×). + +### 2.6 Wiring into production + +After Phase 1 acceptance: +1. `wgsl/cuzk/batch_inverse_parallel.template.wgsl` line ~219: `fr_inv` → `fr_inv_by`. +2. `wgsl/cuzk/batch_inverse.template.wgsl` line ~77: `fr_inv` → `fr_inv_by`. +3. `shader_manager.ts`: include `{{> by_inverse_funcs }}` in `gen_batch_inverse_parallel_shader` / `gen_batch_inverse_shader` partials. +4. Run Quick Sanity Check at logN=16 via Playwright; expect `[sanity] PASS`. If FAIL, revert and bisect via the micro-bench. + +--- + +## 3. Phase 2 — Multi-window batched Pippenger + 32-bit point schedule + +### 3.1 WASM multi-round structure + +(All line refs in `scalar_multiplication.cpp`.) + +**Outer dispatch (4551–4604):** Lower region + optional Upper region. We use only the lower region (single c). + +**Per region (4570–4602):** iterate windows in batches of `windows_per_batch`. Within one batch: +- **Stage 1 (3785–3877):** per-thread per-window digit histogram. Output `digit_cursors[(w · T + t) · bucket_stride + d]`. +- **Stage 2 (3879–3909):** per-thread → per-window prefix-sum. Writes per-(window, thread, digit) cursor base; writes per-digit totals to `bucket_start[d+1]`. +- **Stage 3 (3911–3937):** per-window serial prefix-sum on `bucket_start`. +- **Stage 4 (3939–4075):** scatter. Re-decodes each scalar's window-w digit, writes the 32-bit schedule entry to `schedule[w * capacity + bucket_start[d] + cursor[d]++]`. Dedup OFF. +- **Stage 5 (4211–4217):** per-window chunk partition. +- **Stage 6a (4344–4399):** per-(thread, window) batched-affine bucket reduction → `bucket_partials_dense`. +- **Stage 6b (4401–4525):** cross-thread, per-task slice `[d_lo, d_hi]`, `recursive_affine_bucket_reduce_strided` — the multi-window batched inversion. +- **Stage 7 (4534–4548):** per-window combine of per-thread partials. + +Final Horner combine over all windows: lines 4606–4615. + +### 3.2 32-bit schedule entry encoding + +Adopt bit-for-bit from WASM (lines 552–567): +- bit 31: sign +- bit 30: dedup redirect (always zero) +- bit 29: dedup skip (always zero) +- bits 0..28: scalar_idx (≤ 2^29 = 512M, plenty for logN ≤ 28) + +### 3.3 WGSL changes + +**Replace:** +- `wgsl/cuzk/batch_affine_schedule.template.wgsl` — delete the per-round bucket-cursor / atomic pair counter. Replace with three new shaders: + +**New `wgsl/cuzk/schedule_histogram.template.wgsl`** (Stage 1) +```wgsl +// Per-thread per-window per-digit histogram. 
+// Dispatch: (ceil(n / wg_size), 1, num_subtasks_in_batch) +// const NUM_WINDOWS_IN_BATCH: u32 = {{ num_windows_in_batch }}u; +// const NUM_BUCKETS: u32 = {{ num_columns }}u; +// Writes digit_cursors[(w * num_threads + tid) * num_buckets + d]. +``` + +**New `wgsl/cuzk/schedule_offsets.template.wgsl`** (Stage 2 + 3) +```wgsl +// One workgroup per (window, bucket-slice). Per-window prefix-sum. +// Output: bucket_start[w][d+1], digit_cursors[w][t][d]. +``` + +**New `wgsl/cuzk/schedule_scatter.template.wgsl`** (Stage 4) +```wgsl +// Dispatch: (ceil(n / wg_size), 1, num_subtasks_in_batch) +// sched[w * capacity + bucket_start[w][d] + cursor++] = sign << 31 | scalar_idx +``` + +**Keep + extend:** +- `batch_affine_apply_scatter.template.wgsl`: bind layout reads from bucket-sorted schedule; affine-add math unchanged. +- `batch_inverse_parallel.template.wgsl`: Z dimension becomes `num_subtasks × NUM_WINDOWS_IN_BATCH`. Inside, decode `wid.z` into `(subtask_in_batch, window_in_batch)`. +- `batch_affine_finalize_collect.template.wgsl` / `_apply.template.wgsl`: unchanged (called once at end of MSM). + +**New `wgsl/cuzk/bucket_reduce.template.wgsl`** (Stage 6a per-window single-thread bucket accumulator). Per-window kernel that: +1. Reads `schedule[w][chunk_start..chunk_end]` (bucket-sorted). +2. Accumulates each run of contiguous same-bucket entries via the existing batched-affine tree reduce (reuses `batch_inverse_parallel`). +3. Output per-(thread, window) `bucket_partials_dense`. + +### 3.4 Host TS changes + +`cuzk/batch_affine.ts` — major rewrite of `smvp_batch_affine_gpu`: +1. Add `windows_per_batch: number` (start = 4). +2. Replace init + schedule + (per-round inverse+apply) with: dispatch histogram → offsets → scatter → outer loop over batches → per-batch round loop with Z dispatch `windows_per_batch × num_subtasks_in_batch`. +3. Buffer changes: drop `pair_counter` (replaced by per-(w, subtask) atomic). Drop `bucket_cursor` (replaced by `digit_cursors`). Add `bucket_start`. Add `schedule` (32-bit bucket-sorted, ~`num_subtasks × num_columns × 4` bytes ≈ 2 MB at logN=16). + +`cuzk/shader_manager.ts` — add: +- `gen_schedule_histogram_shader(workgroup_size, num_columns, num_windows_in_batch)` +- `gen_schedule_offsets_shader(workgroup_size, num_columns, num_windows_in_batch)` +- `gen_schedule_scatter_shader(workgroup_size, num_columns, num_windows_in_batch)` + +Bump cache keys with new tag `mwb-v1`. + +`msm.ts` — at the `smvp_batch_affine_gpu` call, add `windows_per_batch: 4`. + +`cuzk/batch_affine_bn254.ts` (host reference) — extend `batchAffineMSM` with `windowsPerBatch`; one batched inversion spans pairs from all windows in the batch. **Required as ground truth for correctness tests.** + +### 3.5 Constants exported WASM → WGSL + +| Constant | Value | WGSL exposure | +|---|---|---| +| `SCHEDULE_SIGN_BIT` (line 559) | `1 << 31` | `const SCHED_SIGN_BIT: u32 = 1u << 31u;` | +| `DEDUP_REDIRECT_BIT` (560) | `1 << 30` | `const SCHED_REDIRECT_BIT: u32 = 1u << 30u;` (always zero) | +| `DEDUP_SKIP_BIT` (561) | `1 << 29` | `const SCHED_SKIP_BIT: u32 = 1u << 29u;` (always zero) | +| `SCHEDULE_INDEX_MASK` (562) | `(1<<29) - 1` | `const SCHED_INDEX_MASK: u32 = (1u << 29u) - 1u;` | +| `BATCH_CAPACITY` (596) | 256 | `const BATCH_AFFINE_BREAKEVEN: u32 = 256u;` | +| `BATCH_AFFINE_BREAKEVEN` (1525) | 32 | `const BATCH_AFFINE_DRAIN_THRESHOLD: u32 = 32u;` | + +`chunk_size` (c) stays at 15/16 per `msm.ts:554`. `num_columns = 2^c`. 
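For concreteness, a minimal TypeScript sketch of the entry packing this table pins down (host-side helper names are illustrative; only the bit layout from §3.2 is normative):

```ts
// Schedule-entry bit layout, bit-for-bit from the WASM (§3.2 / table above).
const SCHED_SIGN_BIT = 0x8000_0000;     // bit 31: point is negated
const SCHED_REDIRECT_BIT = 0x4000_0000; // bit 30: dedup redirect (always 0 in this plan)
const SCHED_SKIP_BIT = 0x2000_0000;     // bit 29: dedup skip (always 0 in this plan)
const SCHED_INDEX_MASK = (1 << 29) - 1; // bits 0..28: scalar_idx (enough for logN ≤ 28)

function encodeScheduleEntry(scalarIdx: number, negate: boolean): number {
  // >>> 0 forces the result back to an unsigned 32-bit value.
  return ((negate ? SCHED_SIGN_BIT : 0) | (scalarIdx & SCHED_INDEX_MASK)) >>> 0;
}

function decodeScheduleEntry(entry: number): { scalarIdx: number; negate: boolean } {
  return {
    scalarIdx: entry & SCHED_INDEX_MASK,
    negate: (entry & SCHED_SIGN_BIT) !== 0,
  };
}
```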
+ +### 3.6 Intermediate validation milestones (each is a hard gate) + +For each milestone, the test is the Quick Sanity Check button at logN=16 via Playwright (`[sanity] PASS in N ms`): + +- After `shader_manager` additions, before host orchestrator changes: WGSL compile-only check via `getCompilationInfo()`. +- After histogram + offsets + scatter, `windows_per_batch = 1`: read back schedule on n=2^10 / 2^12; per-(w, d, k) entry matches host's bucket-sorted ground truth (set equality). +- After `bucket_reduce`, `windows_per_batch = 1`: Sanity Check PASS at logN=16. +- After `NUM_WINDOWS_PER_BATCH = 2`: Sanity Check PASS at logN=16. +- After `NUM_WINDOWS_PER_BATCH = 4`: Sanity Check PASS at logN=16 + visible `ba_inverse + ba_apply` wall reduction in `Profiler.report()`. + +### 3.7 Workgroup sizing + +- `schedule_histogram`: WG=256, dispatch `(ceil(n/256), 1, num_subtasks_in_batch)`. Per-thread arrays (no shared workgroup atomics). +- `schedule_offsets`: WG=64, dispatch `(1, 1, num_windows_in_batch)`. Per-thread → cross-thread → per-digit prefix sums. +- `schedule_scatter`: WG=256, same dispatch shape as histogram. +- `bucket_reduce`: WG=64 (matches existing apply_scatter). Z = `num_subtasks × NUM_WINDOWS_PER_BATCH`. +- `batch_inverse_parallel`: WG=64. Z = `num_subtasks × NUM_WINDOWS_PER_BATCH`. + +### 3.8 Loop-bound audit + +All loops introduced in Phase 2 use a `const`-bounded counter: +- `schedule_histogram` inner: `for (var w = 0u; w < NUM_WINDOWS_IN_BATCH; ...)`. +- `schedule_offsets` reductions: `for (var t = 0u; t < TPB; ...)`. +- Hillis-Steele scan: `for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u)`. +- `schedule_scatter` window loop: `NUM_WINDOWS_IN_BATCH`. +- `bucket_reduce` tree-reduce pass: bounded by `BATCH_AFFINE_BREAKEVEN`. + +Audit step after every render: `grep -E 'for *\(.*<' rendered.wgsl | grep -v -E '< [A-Z][A-Z_]*[a-z]?|< [0-9]+|< [a-z_]+\.x' | grep -v 'workgroup_size'`. + +--- + +## 4. Test plan + +| Phase | Test | Harness | Pass | +|---|---|---|---| +| 1.A | BY divsteps TS unit test | new Jest `cuzk/bernstein_yang.test.ts`, ~1000 random inputs vs `modInverse` | all match | +| 1.B | WGSL `fr_inv_by` correctness | `bench-field-mul.mjs --variant fr_inv_by --n 1024 --validate-n 1024 --k 1` | all 1024 match | +| 1.C | WGSL `fr_inv_by` perf | same w/ `--n 65536 --k 10 --reps 5` | ≥ 2× faster median than `fr_inv` | +| 1.D | E2E Sanity w/ BY swap-in | Playwright Quick Sanity Check button | `[sanity] PASS` | +| 2.A | Schedule correctness | `wgsl_unit_tests.ts` helper, n=2^10 | set equality vs host ground truth | +| 2.B | Bucket reduction `windows_per_batch=1` | Quick Sanity Check at logN=16 | PASS | +| 2.C | `NUM_WINDOWS_PER_BATCH=2` | Quick Sanity Check at logN=16 | PASS | +| 2.D | `NUM_WINDOWS_PER_BATCH=4` | Quick Sanity Check at logN=16, 18 | PASS + ≥ 1.5× wall reduction | + +**Critical safety rule:** Full MSM correctness ONLY via the Quick Sanity Check button via Playwright. NEVER invoke `compute_bn254_msm_*` directly from Node — the dev-page-button-with-Playwright is the only path validated against Apple Silicon Metal. Micro-bench (`bench-field-mul`) is for primitives only. + +--- + +## 5. Iteration breakdown (17 sub-steps, ≥ 10 floor met) + +**Phase 1:** + +1. **1.1** — Transliterate `Wasm9x29::divsteps` + `apply_matrix` + `reduce_to_canonical` + driver to TS. Jest `bernstein_yang.test.ts` with 1000 random inputs vs `modInverse`. Gate: all match. +2. 
**1.2** — Add WGSL bigint helpers: `signed_mul_split`, vec2 64-bit add/sub/shift, `by_normalise` carry propagation. New `wgsl/bigint/bigint_by.template.wgsl`. Unit-test via scratch shader. +3. **1.3** — Write WGSL `by_divsteps`. Validate via `divsteps_bench` shader vs TS port. +4. **1.4** — Write WGSL `by_apply_matrix_fg` / `by_apply_matrix_de`. Precompute `p_inv_by_lo` / `p_inv_by_hi` via Mustache in `shader_manager.ts`. +5. **1.5** — Wire `fr_inv_by` + `by_reduce_to_canonical`. Add `gen_fr_inv_bench_shader` + `--variant fr_inv_by`. `--n 1024 --validate-n 1024 --k 1` → all match. Hard gate. +6. **1.6** — Perf pass. `--n 65536 --k 10 --reps 5`. Hard gate: ≥ 2× over `fr_inv`. +7. **1.7** — Swap `fr_inv` → `fr_inv_by` in `batch_inverse_parallel` and `batch_inverse`. Quick Sanity Check at logN=16. Hard gate: PASS. + +**Phase 2:** + +8. **2.1** — Host BigInt reference for multi-window batched Pippenger. Extend `cuzk/batch_affine_bn254.ts` with `windowsPerBatch`. Jest cross-check vs `windowsPerBatch=1`. +9. **2.2** — Stage 1 `schedule_histogram`. Add unit test in `wgsl_unit_tests.ts` dispatching on n=2^10, compare per-(w, t, d) vs host. +10. **2.3** — Stage 2/3 `schedule_offsets`. Validate `bucket_start` after kernel = exclusive prefix of `Σ_t digit_cursors`. +11. **2.4** — Stage 4 `schedule_scatter`. Validate via read-back test 2.A. Gate: set equality. +12. **2.5** — `bucket_reduce` for one window (`NUM_WINDOWS_PER_BATCH=1`). Reuse `batch_affine_apply_scatter` math; rewire input from bucket-sorted schedule. +13. **2.6** — Rewire `batch_affine.ts` to dispatch histogram → offsets → scatter → bucket_reduce → finalize at `windows_per_batch=1`. Gate: Sanity Check PASS at logN=16, 14, 12. +14. **2.7** — Bump `NUM_WINDOWS_PER_BATCH` to 2. Decode `wid.z` into (subtask_in_batch, window_in_batch). Gate: Sanity Check PASS at logN=16. +15. **2.8** — Bump to 4 + profile. Gates: Sanity Check PASS at logN=16, 18; ≥ 1.5× wall reduction on `ba_inverse + ba_apply` summed across batches. +16. **2.9** — Cleanup + cache-key bump to `mwb-v1`. Re-run all sanity gates. +17. **2.10** — Final integration. Sanity Check at logN=14, 15, 16, 17, 18, 19, 20. Each must PASS. Wall time vs pre-rewrite baseline. + +--- + +## 6. Out of scope (per user) + +- Duplicate stripping (Phase A / dedup). Bits 29 and 30 of the schedule stay zero. +- Two bucket widths for one MSM (variable-window split). +- Adaptive c. + +If a coding agent finds themselves implementing any of these three, STOP. + +--- + +## Critical files for implementation + +- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl` +- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint.template.wgsl` +- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl` +- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts` +- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts` + +Reference-only — source of all the algorithm structure: +- `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp` +- `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp` lines 540–710, 1620–1800, 2671–2830, 3780–4080, 4210–4610. 
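Appendix sketch for §2.3's entry/exit conversion: a hedged TypeScript mirror of what `by_from_bigint` / `by_to_bigint` must do, assuming the limb counts and widths fixed above (20 × 13-bit in, 9 × 29-bit out). The WGSL versions repack limb-by-limb without a bigint intermediate; this is only a reference for the expected values.

```ts
// Repack 20 × 13-bit unsigned limbs (the existing BigInt) into 9 × 29-bit
// limbs (BigIntBY) and back. On entry/exit all limbs are canonical and
// non-negative; signed limbs only appear inside the BY iteration itself.
function fromLimbs(limbs: bigint[], width: bigint): bigint {
  let v = 0n;
  for (let i = limbs.length - 1; i >= 0; i--) v = (v << width) | limbs[i];
  return v;
}

function toLimbs(v: bigint, width: bigint, count: number): bigint[] {
  const mask = (1n << width) - 1n;
  const out: bigint[] = [];
  for (let i = 0; i < count; i++) {
    out.push(v & mask);
    v >>= width;
  }
  return out;
}

const byFromBigintRef = (limbs13: bigint[]) => toLimbs(fromLimbs(limbs13, 13n), 29n, 9); // 20×13 → 9×29
const byToBigintRef = (limbs29: bigint[]) => toLimbs(fromLimbs(limbs29, 29n), 13n, 20);  // 9×29 → 20×13
```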
diff --git a/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html new file mode 100644 index 000000000000..83c13009529e --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.html @@ -0,0 +1,37 @@ + + + + + BY apply_matrix WebGPU dispatch test + + + +

+    

BY apply_matrix WebGPU dispatch test

+

Query params: ?n=N&validate-n=N&reps=R
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts new file mode 100644 index 000000000000..e05603a24dfb --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-apply-matrix.ts @@ -0,0 +1,464 @@ +/// +// BY apply_matrix WebGPU dispatch test. Mounted standalone (no SRS, no MSM +// pipeline). Reads `?n=N&validate-n=N&reps=R`, generates seeded random +// (Mat, f, g, d, e) records by first sampling random u64 (f_lo, g_lo, delta) +// and running the TS `Wasm9x29.divsteps` to get a valid Mat (which +// guarantees matrix-entry bounds), then synthesising plausible random f, g, +// d, e BigIntBY states, runs `by_apply_matrix_fg` and `by_apply_matrix_de` +// on the GPU per input record, and validates each output against the TS +// `Wasm9x29.applyMatrix` reference. +// +// Safety. The shader's only loops are inherited from `by_apply_matrix_fg` +// and `by_apply_matrix_de`, both bounded by the WGSL `const BY_NUM_LIMBS = 9u`. +// One thread per input record, `n` capped at 2^20 to keep memory pressure +// manageable. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { + Wasm9x29, + fromBigint as byFromBigint, + makeZero as byMakeZero, + normalise as byNormalise, + N as BY_N, +} from '../../src/msm_webgpu/cuzk/bernstein_yang.js'; + +interface SampleSummary { + reps: number; + msSamples: number[]; + msMedian: number; + msMin: number; + msMax: number; + applyMatrixPerSec: number; +} +interface BenchResult { + validateOk: boolean; + mismatches: string[]; + timing: SampleSummary | null; +} +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { + n: number; + validateN: number; + reps: number; + } | null; + result: BenchResult | null; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + result: null, + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-apply-matrix] ${msg}`); +} + +const N_MAX = 1 << 20; +const INPUT_STRIDE_U32 = 44; +const OUTPUT_STRIDE_I32 = 36; + +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomU64(rng: () => number): { lo: number; hi: number } { + const lo = rng() >>> 0; + const hi = rng() >>> 0; + return { lo, hi }; +} + +// Random signed 29-bit limb in [-2^28, 2^28). The +shift then sub keeps the +// distribution roughly uniform across the signed range while staying inside +// the canonical |limb| < 2^29 invariant that `by_normalise` preserves. +function randomSignedLimb29(rng: () => number): bigint { + const v = rng(); + // Map u32 → signed [-2^28, 2^28). 
+ const u29 = v & 0x1fffffff; + return BigInt(u29) - (1n << 28n); +} + +// Make a random BigIntBY with limbs roughly uniform in [-2^28, 2^28), +// then normalise so the lower limbs canonicalise to [0, 2^29). This matches +// the post-by_normalise invariant the WGSL functions expect on input. +function randomNormalisedBigIntBY(rng: () => number): bigint[] { + const x = byMakeZero(); + for (let i = 0; i < BY_N; i++) { + x[i] = randomSignedLimb29(rng); + } + byNormalise(x); + return x; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// Reassemble a (lo: i32, hi: i32) GPU pair as a signed bigint (two's comp). +function pairToSignedBig(lo: number, hi: number): bigint { + const loBig = BigInt(lo >>> 0); + const hiBig = BigInt(hi >>> 0); + let v = loBig | (hiBig << 32n); + if (v >= 1n << 63n) v -= 1n << 64n; + return v; +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +// Serialise a signed i64 bigint into (lo: u32, hi: u32) — the on-wire +// representation of Mat fields in the input buffer. +function i64ToLoHi(v: bigint): { lo: number; hi: number } { + const U64 = (1n << 64n) - 1n; + let u = v & U64; + if (u < 0n) u += 1n << 64n; + const lo = Number(u & 0xffffffffn) >>> 0; + const hi = Number((u >> 32n) & 0xffffffffn) >>> 0; + return { lo, hi }; +} + +async function runBench( + device: GPUDevice, + sm: ShaderManager, + n: number, + validateN: number, + reps: number, +): Promise { + log('info', `building inputs (n=${n}, validate-n=${validateN}, reps=${reps})`); + + // Seeded inputs. + const rng = makeRng(0xa57e1234); + const inputsBuf = new Uint32Array(n * INPUT_STRIDE_U32); + + // Per-input ground truth: f', g', d', e' after applyMatrix. + const expectedOutputs = new Int32Array(validateN * OUTPUT_STRIDE_I32); + + for (let i = 0; i < n; i++) { + // 1. Sample random u64 f_lo, g_lo, and a small random delta; run the TS + // Wasm9x29.divsteps to produce a valid Mat (matrix entries satisfy + // the |entry| <= 2^58 bound that apply_matrix downstream expects). + const f64 = randomU64(rng); + const g64 = randomU64(rng); + const delta = (rng() % 1025) - 512; + const fBig = BigInt(f64.lo >>> 0) | (BigInt(f64.hi >>> 0) << 32n); + const gBig = BigInt(g64.lo >>> 0) | (BigInt(g64.hi >>> 0) << 32n); + const { mat } = Wasm9x29.divsteps(BigInt(delta), fBig, gBig); + + // 2. 
Synthesise random signed-canonical f, g, d, e BigIntBY states. + const fState = randomNormalisedBigIntBY(rng); + const gState = randomNormalisedBigIntBY(rng); + const dState = randomNormalisedBigIntBY(rng); + const eState = randomNormalisedBigIntBY(rng); + + // 3. Encode into the input buffer. Mat fields go (u, v, q, r, u_hi, v_hi, + // q_hi, r_hi). Each is the low 32 / high 32 of the i64 stored in + // bigint form. + const base = i * INPUT_STRIDE_U32; + const u_pair = i64ToLoHi(mat.u); + const v_pair = i64ToLoHi(mat.v); + const q_pair = i64ToLoHi(mat.q); + const r_pair = i64ToLoHi(mat.r); + inputsBuf[base + 0] = u_pair.lo; + inputsBuf[base + 1] = v_pair.lo; + inputsBuf[base + 2] = q_pair.lo; + inputsBuf[base + 3] = r_pair.lo; + inputsBuf[base + 4] = u_pair.hi; + inputsBuf[base + 5] = v_pair.hi; + inputsBuf[base + 6] = q_pair.hi; + inputsBuf[base + 7] = r_pair.hi; + for (let j = 0; j < BY_N; j++) { + // Limbs are bigint after normalisation; coerce to signed 32-bit + // representation. The `>>> 0` then re-cast to i32 via Int32Array + // happens via writing as u32 in the buffer and bitcasting in WGSL. + inputsBuf[base + 8 + j] = Number(BigInt.asIntN(32, fState[j])) | 0; + inputsBuf[base + 17 + j] = Number(BigInt.asIntN(32, gState[j])) | 0; + inputsBuf[base + 26 + j] = Number(BigInt.asIntN(32, dState[j])) | 0; + inputsBuf[base + 35 + j] = Number(BigInt.asIntN(32, eState[j])) | 0; + } + + // 4. Compute reference for validate-n. + if (i < validateN) { + const fRef = fState.slice(); + const gRef = gState.slice(); + const dRef = dState.slice(); + const eRef = eState.slice(); + const pRef = byFromBigint(Wasm9x29.P); + Wasm9x29.applyMatrix(mat, fRef, gRef, dRef, eRef, pRef, Wasm9x29.P_INV); + // applyMatrix output is signed; serialise as i32. + const outBase = i * OUTPUT_STRIDE_I32; + for (let j = 0; j < BY_N; j++) { + expectedOutputs[outBase + 0 + j] = Number(BigInt.asIntN(32, fRef[j])) | 0; + expectedOutputs[outBase + 9 + j] = Number(BigInt.asIntN(32, gRef[j])) | 0; + expectedOutputs[outBase + 18 + j] = Number(BigInt.asIntN(32, dRef[j])) | 0; + expectedOutputs[outBase + 27 + j] = Number(BigInt.asIntN(32, eRef[j])) | 0; + } + } + } + + // Shader & pipeline. + const WORKGROUP_SIZE = 64; + const code = sm.gen_apply_matrix_bench_shader(WORKGROUP_SIZE); + const cacheKey = `apply-matrix-bench-wg${WORKGROUP_SIZE}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + // Buffers. 
+ const inputsGpu = device.createBuffer({ + size: inputsBuf.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inputsGpu, 0, inputsBuf); + + const outBytes = n * OUTPUT_STRIDE_I32 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + + const uniformBytes = new ArrayBuffer(16); + const uniformView = new Uint32Array(uniformBytes); + uniformView[0] = n; + uniformView[1] = 0; + const uniformBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(uniformBuf, 0, uniformBytes); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inputsGpu } }, + { binding: 1, resource: { buffer: outBuf } }, + { binding: 2, resource: { buffer: uniformBuf } }, + ], + }); + + const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE); + log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`); + + // Warmup pass. + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup OK'); + + // Validation pass — read back outputs. + const stagingBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + } + const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0); + stagingBuf.unmap(); + stagingBuf.destroy(); + + const outI32 = new Int32Array(outBytesCopy); + const mismatches: string[] = []; + let validateOk = true; + for (let i = 0; i < validateN; i++) { + const base = i * OUTPUT_STRIDE_I32; + let pairOk = true; + for (let j = 0; j < OUTPUT_STRIDE_I32; j++) { + if (outI32[base + j] !== expectedOutputs[base + j]) { + pairOk = false; + break; + } + } + if (!pairOk) { + validateOk = false; + if (mismatches.length < 5) { + const got = Array.from(outI32.slice(base, base + OUTPUT_STRIDE_I32)); + const want = Array.from(expectedOutputs.slice(base, base + OUTPUT_STRIDE_I32)); + const labelLimbs = (arr: number[], offset: number, name: string) => + `${name}: [${arr.slice(offset, offset + 9).join(', ')}]`; + mismatches.push( + `pair[${i}]:\n` + + ` expected:\n` + + ` ${labelLimbs(want, 0, "f'")}\n` + + ` ${labelLimbs(want, 9, "g'")}\n` + + ` ${labelLimbs(want, 18, "d'")}\n` + + ` ${labelLimbs(want, 27, "e'")}\n` + + ` actual:\n` + + ` ${labelLimbs(got, 0, "f'")}\n` + + ` ${labelLimbs(got, 9, "g'")}\n` + + ` ${labelLimbs(got, 18, "d'")}\n` + + ` ${labelLimbs(got, 27, "e'")}`, + ); + } + } + } + + if (!validateOk) { + log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`); + for (const m of mismatches) log('err', m); + inputsGpu.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + return { validateOk: false, 
mismatches, timing: null }; + } + log('ok', `VALIDATION OK (${validateN} pairs)`); + + // Timed reps. + const msSamples: number[] = []; + for (let rep = 0; rep < reps; rep++) { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const t1 = performance.now(); + msSamples.push(t1 - t0); + } + const msMed = median(msSamples); + const msMin = Math.min(...msSamples); + const msMax = Math.max(...msSamples); + const applyMatrixPerSec = n / (msMed / 1000); + log( + 'ok', + `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms apply_matrix_calls/s=${applyMatrixPerSec.toExponential(3)} (n=${n})`, + ); + + inputsGpu.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + + return { + validateOk: true, + mismatches: [], + timing: { reps, msSamples, msMedian: msMed, msMin, msMax, applyMatrixPerSec }, + }; +} + +function parseParams(): { n: number; validateN: number; reps: number } { + const qp = new URLSearchParams(window.location.search); + const n = parseInt(qp.get('n') ?? '1024', 10); + const validateN = parseInt(qp.get('validate-n') ?? String(Math.min(64, n)), 10); + const reps = parseInt(qp.get('reps') ?? '3', 10); + if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`); + } + if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`); + } + if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`); + } + return { n, validateN, reps }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log('info', `params: n=${params.n} validate-n=${params.validateN} reps=${params.reps}`); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + // ShaderManager is keyed on chunk_size / input_size for the MSM pipeline; + // for this bench we only need its constants table so values are arbitrary. + const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false); + + const result = await runBench(device, sm, params.n, params.validateN, params.reps); + benchState.result = result; + + benchState.state = 'done'; + if (result.validateOk) { + log('ok', 'bench done'); + } else { + log('err', 'bench done with validation failures'); + } + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main().catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; +}); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html new file mode 100644 index 000000000000..dcb1507d12ef --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.html @@ -0,0 +1,37 @@ + + + + + Batch-affine EC add amortisation bench (WebGPU) + + + +

 +

Batch-affine EC add amortisation bench (WebGPU)

+

Query params: ?reps=R
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts new file mode 100644 index 000000000000..1667518702fa --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-batch-affine.ts @@ -0,0 +1,423 @@ +/// +// Standalone single-dispatch WebGPU benchmark for batch-affine EC addition. +// Sweeps BATCH_SIZE in {32, 64, 128, 256, 512, 1024, 2048} at fixed +// TOTAL_PAIRS = 65536 to find the sweet spot where amortising the single +// `fr_inv_by_a` per batch stops beating the loss of GPU thread occupancy. +// +// SAFETY: NO MSM pipeline touched. The shader has only compile-time-const +// loop bounds (BS, TPB, NUM_WORDS). Single dispatch per measurement, no +// recursion. Total VRAM under 40 MB. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; + +// TOTAL_PAIRS defaults to 65536 (2^16). Overridable via ?total=N for +// smoke-tests on small slices (NOT used in the headline sweep, but +// useful for "does this dispatch return at all" runs at N=1024). +const DEFAULT_TOTAL_PAIRS = 1 << 16; // 65536 +let TOTAL_PAIRS = DEFAULT_TOTAL_PAIRS; +// BATCH_SIZES is also overridable via ?sizes=A,B,C for the smoke-test +// where we only want to run one size. +const DEFAULT_BATCH_SIZES = [32, 64, 128, 256, 512, 1024, 2048] as const; +let BATCH_SIZES: readonly number[] = DEFAULT_BATCH_SIZES; + +// (batch_size, tpb, per_thread_count) table. Per the bench spec: +// B=32: TPB=32, BS=1 +// B=64: TPB=64, BS=1 +// B=128+: TPB=64, BS=B/64 +function tpbFor(batchSize: number): { tpb: number; bs: number } { + if (batchSize === 32) return { tpb: 32, bs: 1 }; + if (batchSize === 64) return { tpb: 64, bs: 1 }; + return { tpb: 64, bs: batchSize / 64 }; +} + +const NUM_LIMBS_U32 = 20; +const WORD_SIZE_U32 = 13; +const W_U32 = 1n << BigInt(WORD_SIZE_U32); +const MASK_U32 = W_U32 - 1n; + +function bigintToLimbsU32(v: bigint): number[] { + const limbs: number[] = new Array(NUM_LIMBS_U32); + let x = v; + for (let i = 0; i < NUM_LIMBS_U32; i++) { + limbs[i] = Number(x & MASK_U32); + x >>= BigInt(WORD_SIZE_U32); + } + return limbs; +} + +function limbsU32ToBigint(limbs: ArrayLike): bigint { + let v = 0n; + for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) { + v = (v << BigInt(WORD_SIZE_U32)) | BigInt(limbs[i] >>> 0); + } + return v; +} + +// Seeded LCG (Numerical Recipes constants). Matches the seeding convention +// used by the other dev-page benches. 
+function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + while (true) { + let v = 0n; + for (let i = 0; i < byteLen; i++) { + v = (v << 8n) | BigInt(rng() & 0xff); + } + v &= (1n << BigInt(bitlen)) - 1n; + if (v < p) return v; + } +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +interface PerSizeResult { + batch_size: number; + tpb: number; + bs: number; + num_wgs: number; + total_threads: number; + median_ms: number; + min_ms: number; + max_ms: number; + ns_per_pair: number; + samples_ms: number[]; + validated: boolean; +} + +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { reps: number } | null; + results: PerSizeResult[]; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + results: [], + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-batch-affine] ${msg}`); +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +// Build inputs ONCE per batch_size (deterministic via LCG seeded with +// 0xdeadbeef + batch_size). Generates `TOTAL_PAIRS * 4` BigInts holding +// random Mont-form values: pair k = [P_k.x, P_k.y, Q_k.x, Q_k.y]. The +// algebra is garbage (these aren't on-curve points), but the shader +// produces SOMETHING and the dispatch shape is what we're timing. 
+function buildInputs(batchSize: number, R: bigint, p: bigint): ArrayBuffer { + const rng = makeRng(0xdeadbeef + batchSize); + const bytes = new ArrayBuffer(TOTAL_PAIRS * 4 * NUM_LIMBS_U32 * 4); + const buf = new Uint32Array(bytes); + for (let k = 0; k < TOTAL_PAIRS; k++) { + // Generate four field elements per pair, then Montgomery-form them so + // the shader's `montgomery_product` calls produce in-Mont-form + // outputs. We must also ensure P.x != Q.x (delta != 0) per pair, since + // batch-inverse on a zero delta is undefined and would NaN the run. + let pxMont: bigint, qxMont: bigint; + do { + const pxCan = randomBelow(p, rng); + pxMont = (pxCan * R) % p; + const qxCan = randomBelow(p, rng); + qxMont = (qxCan * R) % p; + } while (pxMont === qxMont); + const pyMont = (randomBelow(p, rng) * R) % p; + const qyMont = (randomBelow(p, rng) * R) % p; + const coords = [pxMont, pyMont, qxMont, qyMont]; + const base = k * 4 * NUM_LIMBS_U32; + for (let c = 0; c < 4; c++) { + const limbs = bigintToLimbsU32(coords[c]); + const off = base + c * NUM_LIMBS_U32; + for (let j = 0; j < NUM_LIMBS_U32; j++) buf[off + j] = limbs[j]; + } + } + return bytes; +} + +async function runOne( + device: GPUDevice, + sm: ShaderManager, + batchSize: number, + reps: number, + R: bigint, + p: bigint, +): Promise { + const { tpb, bs } = tpbFor(batchSize); + const numWgs = TOTAL_PAIRS / batchSize; + const totalThreads = numWgs * tpb; + log( + 'info', + `=== batch_size=${batchSize}: TPB=${tpb} BS=${bs} num_WGs=${numWgs} total_threads=${totalThreads}`, + ); + + // Compile shader. + const code = sm.gen_bench_batch_affine_shader(batchSize, tpb); + const cacheKey = `bench-batch-affine-B${batchSize}-T${tpb}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader_${batchSize}`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + // Build inputs. + const inputsAB = buildInputs(batchSize, R, p); + + // Buffers. + const inputBytes = inputsAB.byteLength; // TOTAL_PAIRS * 4 * 80 = 21 MB + const inputsBuf = device.createBuffer({ + size: inputBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inputsBuf, 0, inputsAB); + + const prefixBytes = TOTAL_PAIRS * NUM_LIMBS_U32 * 4; // ~5.2 MB + const prefixBuf = device.createBuffer({ + size: prefixBytes, + usage: GPUBufferUsage.STORAGE, + }); + + const outputBytes = TOTAL_PAIRS * 2 * NUM_LIMBS_U32 * 4; // ~10.5 MB + const outputsBuf = device.createBuffer({ + size: outputBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inputsBuf } }, + { binding: 1, resource: { buffer: prefixBuf } }, + { binding: 2, resource: { buffer: outputsBuf } }, + ], + }); + + // Warmup + small validation: confirm the first pair's R.x/R.y are + // nonzero. We DON'T check algebraic correctness — the inputs aren't on + // an elliptic curve. We just need the shader to do SOMETHING and not + // hang or produce all zeros. + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup dispatch returned'); + + // Sanity readback: copy first pair's R.x to host and confirm non-zero. 
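+  // (The readback below interprets the output as R.x then R.y, each
+  // NUM_LIMBS_U32 u32 limbs, so pair[0] occupies the first
+  // 2 * NUM_LIMBS_U32 * 4 = 160 bytes — exactly the `sanityBytes` copied.)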
+ const sanityBytes = 2 * NUM_LIMBS_U32 * 4; + const stagingBuf = device.createBuffer({ + size: sanityBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + { + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(outputsBuf, 0, stagingBuf, 0, sanityBytes); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + } + const sanityBytesCopy = stagingBuf.getMappedRange(0, sanityBytes).slice(0); + stagingBuf.unmap(); + stagingBuf.destroy(); + const sanityU32 = new Uint32Array(sanityBytesCopy); + const rx = limbsU32ToBigint(sanityU32.subarray(0, NUM_LIMBS_U32)); + const ry = limbsU32ToBigint(sanityU32.subarray(NUM_LIMBS_U32, 2 * NUM_LIMBS_U32)); + if (rx === 0n && ry === 0n) { + log('err', `SANITY FAIL @ B=${batchSize}: first pair R.x and R.y both zero`); + inputsBuf.destroy(); + prefixBuf.destroy(); + outputsBuf.destroy(); + throw new Error(`sanity fail at batch_size=${batchSize}`); + } + log('ok', `sanity OK: pair[0].R.x=0x${rx.toString(16).slice(0, 16)}... R.y=0x${ry.toString(16).slice(0, 16)}...`); + + // Timed reps. Each rep includes a fresh encoder so we measure + // submit→idle (the shape the user cares about). Match bench-fr-inv.ts. + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWgs, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const t1 = performance.now(); + samples.push(t1 - t0); + } + const med = median(samples); + const mn = Math.min(...samples); + const mx = Math.max(...samples); + const nsPerPair = (med * 1e6) / TOTAL_PAIRS; + + log( + 'ok', + `B=${batchSize}: median=${med.toFixed(3)}ms min=${mn.toFixed(3)}ms max=${mx.toFixed(3)}ms ns/pair=${nsPerPair.toFixed(1)}`, + ); + + inputsBuf.destroy(); + prefixBuf.destroy(); + outputsBuf.destroy(); + + return { + batch_size: batchSize, + tpb, + bs, + num_wgs: numWgs, + total_threads: totalThreads, + median_ms: med, + min_ms: mn, + max_ms: mx, + ns_per_pair: nsPerPair, + samples_ms: samples, + validated: true, + }; +} + +function parseParams(): { reps: number } { + const qp = new URLSearchParams(window.location.search); + const reps = parseInt(qp.get('reps') ?? 
'5', 10); + if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + throw new Error(`?reps must be in (0, 50], got ${qp.get('reps')}`); + } + const totalStr = qp.get('total'); + if (totalStr !== null) { + const total = parseInt(totalStr, 10); + if (!Number.isFinite(total) || total <= 0 || total > (1 << 20)) { + throw new Error(`?total must be in (0, 2^20], got ${totalStr}`); + } + TOTAL_PAIRS = total; + } + const sizesStr = qp.get('sizes'); + if (sizesStr !== null) { + const sizes = sizesStr.split(',').map(s => parseInt(s, 10)); + for (const s of sizes) { + if (!Number.isFinite(s) || s <= 0 || s > 4096) { + throw new Error(`?sizes entries must be in (0, 4096], got ${s}`); + } + if (TOTAL_PAIRS % s !== 0) { + throw new Error(`?sizes entry ${s} does not divide TOTAL_PAIRS=${TOTAL_PAIRS}`); + } + } + BATCH_SIZES = sizes; + } + return { reps }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log('info', `params: reps=${params.reps} TOTAL_PAIRS=${TOTAL_PAIRS}`); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const p = BN254_BASE_FIELD; + const miscParams = compute_misc_params(p, WORD_SIZE_U32); + if (miscParams.num_words !== NUM_LIMBS_U32) { + throw new Error(`expected num_words=${NUM_LIMBS_U32}, got ${miscParams.num_words}`); + } + const R = miscParams.r; + + const sm = new ShaderManager(4, TOTAL_PAIRS, BN254_CURVE_CONFIG, false); + + for (const B of BATCH_SIZES) { + try { + const r = await runOne(device, sm, B, params.reps, R, p); + benchState.results.push(r); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log('err', `B=${B} failed: ${msg} — STOPPING sweep at first failure`); + benchState.state = 'error'; + benchState.error = msg; + return; + } + } + + benchState.state = 'done'; + log('ok', 'all batches done'); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main().catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; +}); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html new file mode 100644 index 000000000000..71ee4b64ff3d --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.html @@ -0,0 +1,37 @@ + + + + + BY divsteps WebGPU dispatch test + + + +

+    BY divsteps WebGPU dispatch test
+    Query params: ?n=N&validate-n=N&reps=R
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts new file mode 100644 index 000000000000..92e2f15cbf33 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-divsteps.ts @@ -0,0 +1,398 @@ +/// +// BY divsteps WebGPU dispatch test. Mounted standalone (no SRS, no MSM +// pipeline). Reads `?n=N&validate-n=N&reps=R`, generates seeded random +// (f_lo, g_lo, delta) tuples, runs the `by_divsteps` shader once per +// thread, validates each output against the TS `Wasm9x29.divsteps` +// reference, and reports timing via `window.__bench`. +// +// Safety. The shader's only loop is bounded by the WGSL `const BY_BATCH = 58u` +// (in bigint_by.template.wgsl). The dispatch is one thread per input tuple, +// `n` capped at 2^23. Inputs are u64 (for f_lo, g_lo) and i32 (for delta) +// — by_divsteps is variable-time over branches but always exits after exactly +// 58 iterations. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { Wasm9x29 } from '../../src/msm_webgpu/cuzk/bernstein_yang.js'; + +interface SampleSummary { + reps: number; + msSamples: number[]; + msMedian: number; + msMin: number; + msMax: number; + divstepsPerSec: number; +} +interface BenchResult { + validateOk: boolean; + mismatches: string[]; + timing: SampleSummary | null; +} +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { + n: number; + validateN: number; + reps: number; + } | null; + result: BenchResult | null; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + result: null, + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-divsteps] ${msg}`); +} + +const N_MAX = 1 << 23; + +// Seeded LCG (Numerical Recipes constants). Matches the pattern used in +// bench-field-mul.ts; the spec mandates seed 0xb33fb33f. +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +// Produce a random u64 as a (lo32, hi32) pair via two LCG draws. +function randomU64(rng: () => number): { lo: number; hi: number } { + const lo = rng() >>> 0; + const hi = rng() >>> 0; + return { lo, hi }; +} + +// Random delta in [-512, 512] inclusive (1025 values). +function randomDelta(rng: () => number): number { + return (rng() % 1025) - 512; +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// Reassemble a (lo: i32, hi: i32) GPU output pair as a bigint (signed two's +// complement, 64 bits) for cross-checking against the TS reference's +// `bigint` matrix entries. 
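+// (Worked example: lo = -1 and hi = -1 both mask to 0xffffffff, reassemble
+// to 2^64 - 1, and the >= 2^63 correction maps that to -1n.)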
+function pairToSignedBig(lo: number, hi: number): bigint { + const loBig = BigInt(lo >>> 0); + const hiBig = BigInt(hi >>> 0); + let v = loBig | (hiBig << 32n); + if (v >= 1n << 63n) v -= 1n << 64n; + return v; +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function runBench( + device: GPUDevice, + sm: ShaderManager, + n: number, + validateN: number, + reps: number, +): Promise { + log('info', `building inputs (n=${n}, validate-n=${validateN}, reps=${reps})`); + + // Generate seeded inputs and compute TS reference in lockstep. + const rng = makeRng(0xb33fb33f); + const inputsFg = new Uint32Array(n * 4); + const inputsDelta = new Int32Array(n); + const expectedU = new BigInt64Array(n); + const expectedV = new BigInt64Array(n); + const expectedQ = new BigInt64Array(n); + const expectedR = new BigInt64Array(n); + const expectedDelta = new Int32Array(n); + for (let i = 0; i < n; i++) { + const f = randomU64(rng); + const g = randomU64(rng); + const d = randomDelta(rng); + inputsFg[i * 4 + 0] = f.lo; + inputsFg[i * 4 + 1] = f.hi; + inputsFg[i * 4 + 2] = g.lo; + inputsFg[i * 4 + 3] = g.hi; + inputsDelta[i] = d; + // Reassemble u64 bigints for the TS reference; delta is signed i32. + const fBig = BigInt(f.lo >>> 0) | (BigInt(f.hi >>> 0) << 32n); + const gBig = BigInt(g.lo >>> 0) | (BigInt(g.hi >>> 0) << 32n); + if (i < validateN) { + const { mat, delta: deltaOut } = Wasm9x29.divsteps(BigInt(d), fBig, gBig); + expectedU[i] = mat.u; + expectedV[i] = mat.v; + expectedQ[i] = mat.q; + expectedR[i] = mat.r; + // The TS port returns delta as bigint (the C++ i64 view). For our + // i32 carrier on the GPU this is fine — under BATCH=58 inner ops, + // delta changes by at most 58 per call, so for |delta_in| <= 512 the + // result fits well inside i32. + expectedDelta[i] = Number(deltaOut); + } + } + + // Shader & pipeline. + const WORKGROUP_SIZE = 64; + const code = sm.gen_divsteps_bench_shader(WORKGROUP_SIZE); + const cacheKey = `divsteps-bench-wg${WORKGROUP_SIZE}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + // Buffers. 
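+  // (Output layout assumed by the validation loop further down: nine i32s
+  // per thread — [u_lo, v_lo, q_lo, r_lo, u_hi, v_hi, q_hi, r_hi, delta] —
+  // which is where outBytes = n * 9 * 4 comes from.)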
+ const inputsFgBuf = device.createBuffer({ + size: inputsFg.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inputsFgBuf, 0, inputsFg); + const inputsDeltaBuf = device.createBuffer({ + size: inputsDelta.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(inputsDeltaBuf, 0, inputsDelta); + const outBytes = n * 9 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + const uniformBytes = new ArrayBuffer(16); + const uniformView = new Uint32Array(uniformBytes); + uniformView[0] = n; + uniformView[1] = 0; + const uniformBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(uniformBuf, 0, uniformBytes); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: inputsFgBuf } }, + { binding: 1, resource: { buffer: inputsDeltaBuf } }, + { binding: 2, resource: { buffer: outBuf } }, + { binding: 3, resource: { buffer: uniformBuf } }, + ], + }); + + const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE); + log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads each (${n} threads total)`); + + // Warmup pass. + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup OK'); + + // Validation pass — read back outputs. + const stagingBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + } + const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0); + stagingBuf.unmap(); + stagingBuf.destroy(); + + const outI32 = new Int32Array(outBytesCopy); + let mismatches: string[] = []; + let validateOk = true; + for (let i = 0; i < validateN; i++) { + const base = i * 9; + const u = pairToSignedBig(outI32[base + 0], outI32[base + 4]); + const v = pairToSignedBig(outI32[base + 1], outI32[base + 5]); + const q = pairToSignedBig(outI32[base + 2], outI32[base + 6]); + const r = pairToSignedBig(outI32[base + 3], outI32[base + 7]); + const dOut = outI32[base + 8]; + const okU = u === expectedU[i]; + const okV = v === expectedV[i]; + const okQ = q === expectedQ[i]; + const okR = r === expectedR[i]; + const okD = dOut === expectedDelta[i]; + if (!(okU && okV && okQ && okR && okD)) { + validateOk = false; + if (mismatches.length < 5) { + const fLo = BigInt(inputsFg[i * 4] >>> 0) | (BigInt(inputsFg[i * 4 + 1] >>> 0) << 32n); + const gLo = BigInt(inputsFg[i * 4 + 2] >>> 0) | (BigInt(inputsFg[i * 4 + 3] >>> 0) << 32n); + mismatches.push( + `pair[${i}]: delta_in=${inputsDelta[i]} f_lo=0x${fLo.toString(16)} g_lo=0x${gLo.toString(16)}\n` + + ` expected: u=${expectedU[i]} v=${expectedV[i]} q=${expectedQ[i]} r=${expectedR[i]} 
delta=${expectedDelta[i]}\n` + + ` actual: u=${u} v=${v} q=${q} r=${r} delta=${dOut}`, + ); + } + } + } + + if (!validateOk) { + log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`); + for (const m of mismatches) log('err', m); + inputsFgBuf.destroy(); + inputsDeltaBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + return { validateOk: false, mismatches, timing: null }; + } + log('ok', `VALIDATION OK (${validateN} pairs)`); + + // Timed reps. + const msSamples: number[] = []; + for (let rep = 0; rep < reps; rep++) { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const t1 = performance.now(); + msSamples.push(t1 - t0); + } + const msMed = median(msSamples); + const msMin = Math.min(...msSamples); + const msMax = Math.max(...msSamples); + const divstepsPerSec = n / (msMed / 1000); + log( + 'ok', + `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms divsteps_calls/s=${divstepsPerSec.toExponential(3)} (n=${n})`, + ); + + inputsFgBuf.destroy(); + inputsDeltaBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + + return { + validateOk: true, + mismatches: [], + timing: { reps, msSamples, msMedian: msMed, msMin, msMax, divstepsPerSec }, + }; +} + +function parseParams(): { n: number; validateN: number; reps: number } { + const qp = new URLSearchParams(window.location.search); + const n = parseInt(qp.get('n') ?? '1024', 10); + const validateN = parseInt(qp.get('validate-n') ?? String(Math.min(64, n)), 10); + const reps = parseInt(qp.get('reps') ?? '3', 10); + if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`); + } + if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`); + } + if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`); + } + return { n, validateN, reps }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log('info', `params: n=${params.n} validate-n=${params.validateN} reps=${params.reps}`); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + // ShaderManager is keyed on chunk_size / input_size for the MSM pipeline; + // for this bench we only need its constants table (num_words, etc.) so + // values are arbitrary. + const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false); + + const result = await runBench(device, sm, params.n, params.validateN, params.reps); + benchState.result = result; + + benchState.state = 'done'; + if (result.validateOk) { + log('ok', 'bench done'); + } else { + log('err', 'bench done with validation failures'); + } + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main().catch(e => { + const msg = e instanceof Error ? 
e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; +}); diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html new file mode 100644 index 000000000000..72d615c5f776 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.html @@ -0,0 +1,37 @@ + + + + + BY fr_inv_by WebGPU dispatch test + + + +

+    BY fr_inv_by WebGPU dispatch test
+    Query params: ?n=N&k=K&validate-n=N&reps=R
+ + + diff --git a/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts new file mode 100644 index 000000000000..3db8bb7944a8 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/bench-fr-inv.ts @@ -0,0 +1,472 @@ +/// +// fr_inv WebGPU dispatch test. Mounted standalone (no SRS, no MSM +// pipeline). Reads `?n=N&k=K&validate-n=N&reps=R&variant=fr_inv|fr_inv_by`, +// generates seeded random BN254 base-field values in Montgomery form, runs +// `k` chained inverse calls per thread, validates each output against the +// host `modInverse` + Mont-correction reference (same reference for both +// variants — the algorithms must produce identical outputs), and reports +// timing via `window.__bench`. +// +// Safety: `n` is capped at 2^20, `k` at 100. The only data-dependent loop +// in the entry shader is `for (var i = 0u; i < k; ...)` with `k` capped. +// Every loop inside `fr_inv` / `fr_inv_by` (and its callees) is bounded by +// a compile-time `const` — see by_inverse.template.wgsl, bigint_by, +// fr_pow.template.wgsl, and montgomery_product. + +import { ShaderManager } from '../../src/msm_webgpu/cuzk/shader_manager.js'; +import { BN254_CURVE_CONFIG } from '../../src/msm_webgpu/cuzk/curve_config.js'; +import { get_device } from '../../src/msm_webgpu/cuzk/gpu.js'; +import { compute_misc_params } from '../../src/msm_webgpu/cuzk/utils.js'; +import { modInverse, BN254_BASE_FIELD } from '../../src/msm_webgpu/cuzk/bn254.js'; + +interface SampleSummary { + reps: number; + msSamples: number[]; + msMedian: number; + msMin: number; + msMax: number; + inversionsPerSec: number; +} +interface BenchResult { + validateOk: boolean; + mismatches: string[]; + timing: SampleSummary | null; +} +interface BenchState { + state: 'boot' | 'running' | 'done' | 'error'; + params: { + n: number; + k: number; + validateN: number; + reps: number; + variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv'; + } | null; + result: BenchResult | null; + error: string | null; + log: string[]; +} + +const benchState: BenchState = { + state: 'boot', + params: null, + result: null, + error: null, + log: [], +}; +(window as unknown as { __bench: BenchState }).__bench = benchState; + +const $log = document.getElementById('log') as HTMLDivElement; +function log(level: 'info' | 'ok' | 'err' | 'warn', msg: string) { + const cls = level === 'ok' ? 'ok' : level === 'err' ? 'err' : level === 'warn' ? 'warn' : ''; + const span = document.createElement('div'); + span.className = cls; + span.textContent = msg; + $log.appendChild(span); + benchState.log.push(`[${level}] ${msg}`); + console.log(`[bench-fr-inv] ${msg}`); +} + +const N_MAX = 1 << 20; +const K_MAX = 100; + +// 20 x 13-bit limb layout matches the u32 MSM path (the only path that +// fr_inv_by is wired into in this bench). +const NUM_LIMBS_U32 = 20; +const WORD_SIZE_U32 = 13; +const W_U32 = 1n << BigInt(WORD_SIZE_U32); +const MASK_U32 = W_U32 - 1n; + +function bigintToLimbsU32(v: bigint): number[] { + const limbs: number[] = new Array(NUM_LIMBS_U32); + let x = v; + for (let i = 0; i < NUM_LIMBS_U32; i++) { + limbs[i] = Number(x & MASK_U32); + x >>= BigInt(WORD_SIZE_U32); + } + return limbs; +} + +function limbsU32ToBigint(limbs: ArrayLike): bigint { + let v = 0n; + for (let i = NUM_LIMBS_U32 - 1; i >= 0; i--) { + v = (v << BigInt(WORD_SIZE_U32)) | BigInt(limbs[i] >>> 0); + } + return v; +} + +// Seeded LCG (Numerical Recipes constants) for reproducible input gen. +// Matches the seeding convention in bench-field-mul.ts. 
+function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + while (true) { + let v = 0n; + for (let i = 0; i < byteLen; i++) { + v = (v << 8n) | BigInt(rng() & 0xff); + } + v &= (1n << BigInt(bitlen)) - 1n; + if (v < p) return v; + } +} + +function median(xs: number[]): number { + if (xs.length === 0) return NaN; + const s = xs.slice().sort((a, b) => a - b); + return s[Math.floor(s.length / 2)]; +} + +// Host reference: chained `fr_inv_by` over Mont-form inputs. +// +// fr_inv_by maps Mont(A) -> Mont(A^(-1)). Concretely, if input is +// a_m = A * R mod p, then output is A^(-1) * R mod p. +// +// One way to compute this on the host: unmont -> modInverse -> mont. +// A = a_m * Rinv mod p +// A_inv = modInverse(A, p) = A^(-1) mod p (canonical) +// out_m = A_inv * R mod p +// +// Equivalently: out_m = modInverse(a_m, p) * R^2 mod p (since +// modInverse(a_m) = (A*R)^(-1) = A^(-1) * R^(-1), times R^2 lands at +// A^(-1) * R = out_m). Either formula works; using the unmont->inv->mont +// chain mirrors how a user typically thinks about Mont arithmetic. +function fr_inv_by_host_once(a_mont: bigint, R: bigint, Rinv: bigint, p: bigint): bigint { + if (a_mont === 0n) return 0n; + const a_canonical = (a_mont * Rinv) % p; + const a_inv = modInverse(a_canonical, p); + return (a_inv * R) % p; +} + +function fr_inv_by_host_chained( + a_mont: bigint, + k: number, + R: bigint, + Rinv: bigint, + p: bigint, +): bigint { + let acc = a_mont; + for (let i = 0; i < k; i++) { + acc = fr_inv_by_host_once(acc, R, Rinv, p); + } + return acc; +} + +async function createPipeline( + device: GPUDevice, + code: string, + cacheKey: string, +): Promise<{ pipeline: GPUComputePipeline; layout: GPUBindGroupLayout }> { + const module = device.createShaderModule({ code }); + const info = await module.getCompilationInfo(); + let hasError = false; + for (const msg of info.messages) { + const line = `[shader ${cacheKey}] ${msg.type}: ${msg.message} (line ${msg.lineNum}, col ${msg.linePos})`; + if (msg.type === 'error') { + console.error(line); + hasError = true; + } else { + console.warn(line); + } + } + if (hasError) { + throw new Error(`WGSL compile failed for ${cacheKey}`); + } + const layout = device.createBindGroupLayout({ + entries: [ + { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } }, + { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }, + { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } }, + ], + }); + const pipeline = await device.createComputePipelineAsync({ + layout: device.createPipelineLayout({ bindGroupLayouts: [layout] }), + compute: { module, entryPoint: 'main' }, + }); + return { pipeline, layout }; +} + +async function runBench( + device: GPUDevice, + sm: ShaderManager, + n: number, + k: number, + validateN: number, + reps: number, + variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv', +): Promise { + log('info', `building inputs (n=${n}, k=${k}, validate-n=${validateN}, reps=${reps}, variant=${variant})`); + const p = BN254_BASE_FIELD; + const params = compute_misc_params(p, WORD_SIZE_U32); + if (params.num_words !== NUM_LIMBS_U32) { + throw new Error(`expected num_words=${NUM_LIMBS_U32}, got ${params.num_words}`); + } + const R = params.r; + 
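+  // (For context — an assumption not verified beyond the product test
+  // below: with 20 limbs of 13 bits, R is the Montgomery radix
+  // 2^260 mod p and Rinv its modular inverse.)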
const Rinv = params.rinv; + if ((R * Rinv) % p !== 1n) { + throw new Error(`R * Rinv mod p != 1`); + } + + // Generate random canonical inputs. + const rng = makeRng(0xc0ffee); + const aCanonical: bigint[] = new Array(n); + for (let i = 0; i < n; i++) { + // Ensure nonzero — fr_inv_by of 0 is undefined / 0 by convention but + // the chained case requires invertibility every step. Reject 0. + let v = randomBelow(p, rng); + while (v === 0n) v = randomBelow(p, rng); + aCanonical[i] = v; + } + const aMont: bigint[] = aCanonical.map(x => (x * R) % p); + + // Host reference for the first validate-n inputs. + log('info', `computing host reference for ${validateN} pairs`); + const expected: bigint[] = new Array(validateN); + for (let i = 0; i < validateN; i++) { + expected[i] = fr_inv_by_host_chained(aMont[i], k, R, Rinv, p); + } + + // Pack input buffer: n BigInts (20 x u32 each). + const bytesPerLimbArray = NUM_LIMBS_U32 * 4; + const xsBytes = new ArrayBuffer(n * bytesPerLimbArray); + const xv = new Uint32Array(xsBytes); + for (let i = 0; i < n; i++) { + const limbs = bigintToLimbsU32(aMont[i]); + const off = i * NUM_LIMBS_U32; + for (let j = 0; j < NUM_LIMBS_U32; j++) xv[off + j] = limbs[j]; + } + + // Shader & pipeline. + const WORKGROUP_SIZE = 64; + const code = sm.gen_fr_inv_bench_shader(WORKGROUP_SIZE, variant); + const cacheKey = `fr-inv-bench-${variant}-wg${WORKGROUP_SIZE}`; + log('info', `compiling shader (${code.length} chars)`); + (window as unknown as Record)[`__shader`] = code; + const { pipeline, layout } = await createPipeline(device, code, cacheKey); + + // Buffers. + const xsBuf = device.createBuffer({ + size: xsBytes.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(xsBuf, 0, xsBytes); + const outBytes = n * NUM_LIMBS_U32 * 4; + const outBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + const uniformBytes = new ArrayBuffer(16); + const uniformView = new Uint32Array(uniformBytes); + uniformView[0] = n; + uniformView[1] = k; + const uniformBuf = device.createBuffer({ + size: 16, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + device.queue.writeBuffer(uniformBuf, 0, uniformBytes); + + const bindGroup = device.createBindGroup({ + layout, + entries: [ + { binding: 0, resource: { buffer: xsBuf } }, + { binding: 1, resource: { buffer: outBuf } }, + { binding: 2, resource: { buffer: uniformBuf } }, + ], + }); + + const numWorkgroups = Math.ceil(n / WORKGROUP_SIZE); + log('info', `dispatching ${numWorkgroups} workgroups of ${WORKGROUP_SIZE} threads (${n} threads total)`); + + // Warmup pass. + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + } + log('info', 'warmup OK'); + + // Validation pass — read back outputs. 
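+  // (Interpretation note: inversion is an involution, so for even k the
+  // expected output equals the input and for odd k it is the plain inverse;
+  // the shader still performs all k chained calls either way.)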
+ const stagingBuf = device.createBuffer({ + size: outBytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + encoder.copyBufferToBuffer(outBuf, 0, stagingBuf, 0, outBytes); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + await stagingBuf.mapAsync(GPUMapMode.READ); + } + const outBytesCopy = stagingBuf.getMappedRange(0, outBytes).slice(0); + stagingBuf.unmap(); + stagingBuf.destroy(); + + const outU32 = new Uint32Array(outBytesCopy); + const mismatches: string[] = []; + let validateOk = true; + for (let i = 0; i < validateN; i++) { + const limbs = outU32.subarray(i * NUM_LIMBS_U32, (i + 1) * NUM_LIMBS_U32); + const got = limbsU32ToBigint(limbs); + if (got !== expected[i]) { + validateOk = false; + if (mismatches.length < 5) { + mismatches.push( + `pair[${i}]: a_canonical=0x${aCanonical[i].toString(16)} a_mont=0x${aMont[i].toString(16)}\n` + + ` expected: 0x${expected[i].toString(16)}\n` + + ` actual: 0x${got.toString(16)}\n` + + ` expected_limbs: [${bigintToLimbsU32(expected[i]).join(', ')}]\n` + + ` actual_limbs: [${Array.from(limbs).join(', ')}]`, + ); + } + } + } + + if (!validateOk) { + log('err', `VALIDATION FAILED (${mismatches.length} mismatches shown; first ${validateN} pairs checked)`); + for (const m of mismatches) log('err', m); + xsBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + return { validateOk: false, mismatches, timing: null }; + } + log('ok', `VALIDATION OK (${validateN} pairs)`); + + // Timed reps. + const msSamples: number[] = []; + for (let rep = 0; rep < reps; rep++) { + const encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(numWorkgroups, 1, 1); + pass.end(); + const t0 = performance.now(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + const t1 = performance.now(); + msSamples.push(t1 - t0); + } + const msMed = median(msSamples); + const msMin = Math.min(...msSamples); + const msMax = Math.max(...msSamples); + const totalInv = n * k; + const inversionsPerSec = totalInv / (msMed / 1000); + log( + 'ok', + `timing reps=${reps} median=${msMed.toFixed(3)}ms min=${msMin.toFixed(3)}ms max=${msMax.toFixed(3)}ms inversions/s=${inversionsPerSec.toExponential(3)} (n*k=${totalInv.toLocaleString()})`, + ); + + xsBuf.destroy(); + outBuf.destroy(); + uniformBuf.destroy(); + + return { + validateOk: true, + mismatches: [], + timing: { reps, msSamples, msMedian: msMed, msMin, msMax, inversionsPerSec }, + }; +} + +function parseParams(): { + n: number; + k: number; + validateN: number; + reps: number; + variant: 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv'; +} { + const qp = new URLSearchParams(window.location.search); + const n = parseInt(qp.get('n') ?? '64', 10); + const k = parseInt(qp.get('k') ?? '1', 10); + const validateN = parseInt(qp.get('validate-n') ?? String(Math.min(64, n)), 10); + const reps = parseInt(qp.get('reps') ?? '1', 10); + const variantStr = qp.get('variant') ?? 
'fr_inv_by'; + if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + throw new Error(`?n must be in (0, ${N_MAX}], got ${qp.get('n')}`); + } + if (!Number.isFinite(k) || k <= 0 || k > K_MAX) { + throw new Error(`?k must be in (0, ${K_MAX}], got ${qp.get('k')}`); + } + if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + throw new Error(`?validate-n must be in [0, n], got ${qp.get('validate-n')}`); + } + if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + throw new Error(`?reps must be in (0, 100], got ${qp.get('reps')}`); + } + if ( + variantStr !== 'fr_inv' && + variantStr !== 'fr_inv_by' && + variantStr !== 'fr_inv_by_a' && + variantStr !== 'fr_pow_inv' + ) { + throw new Error( + `?variant must be 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv', got '${variantStr}'`, + ); + } + return { n, k, validateN, reps, variant: variantStr }; +} + +async function main() { + try { + if (!('gpu' in navigator)) { + throw new Error('navigator.gpu missing — WebGPU not available'); + } + const params = parseParams(); + benchState.params = params; + log( + 'info', + `params: n=${params.n} k=${params.k} validate-n=${params.validateN} reps=${params.reps} variant=${params.variant}`, + ); + + benchState.state = 'running'; + const device = await get_device(); + log('info', 'WebGPU device acquired'); + + const sm = new ShaderManager(4, params.n, BN254_CURVE_CONFIG, false); + + const result = await runBench( + device, + sm, + params.n, + params.k, + params.validateN, + params.reps, + params.variant, + ); + benchState.result = result; + + benchState.state = 'done'; + if (result.validateOk) { + log('ok', 'bench done'); + } else { + log('err', 'bench done with validation failures'); + } + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `FATAL: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; + } +} + +main().catch(e => { + const msg = e instanceof Error ? e.message : String(e); + log('err', `unhandled: ${msg}`); + benchState.state = 'error'; + benchState.error = msg; +}); diff --git a/barretenberg/ts/dev/msm-webgpu/main.ts b/barretenberg/ts/dev/msm-webgpu/main.ts index 1d9ce2f1cfaa..6eb897f041a3 100644 --- a/barretenberg/ts/dev/msm-webgpu/main.ts +++ b/barretenberg/ts/dev/msm-webgpu/main.ts @@ -1190,6 +1190,7 @@ $probeGpu?.addEventListener('click', async () => { $runSanity.addEventListener('click', async () => { $log.innerHTML = ''; abortRequested = false; + (window as unknown as { __sanity?: unknown }).__sanity = { state: 'running' }; setBusy(true, 'sanity check…'); try { log('info', '[sanity] WebGPU-only smoke test, log₂(n)=16, no WASM, no noble'); @@ -1252,9 +1253,23 @@ $runSanity.addEventListener('click', async () => { // view after a fresh page reload. $results.innerHTML = renderBreakdownTable([{ logN: 16, captures: [gpu.capture] }]); $results.classList.add('visible'); + // Expose the raw capture so Playwright-driven profile scripts can + // pull per-stage GPU times without scraping the rendered table. + // Cleared at the start of every click, so a stale value from a + // previous run never bleeds into the next read. + (window as unknown as { __sanity?: unknown }).__sanity = { + state: 'done', + logN: 16, + ms: gpu.ms, + capture: JSON.parse(JSON.stringify(gpu.capture)), + }; } catch (err) { log(abortRequested ? 'warn' : 'err', `[sanity] ${err instanceof Error ? 
err.message : String(err)}`); if (!abortRequested && err instanceof Error && err.stack) log('err', err.stack); + (window as unknown as { __sanity?: unknown }).__sanity = { + state: 'error', + error: err instanceof Error ? err.message : String(err), + }; } finally { setBusy(false); } diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs new file mode 100755 index 000000000000..0badddd1c4b1 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-apply-matrix.mjs @@ -0,0 +1,219 @@ +#!/usr/bin/env node +// BY apply_matrix WebGPU dispatch test driver. Launches Playwright Chromium +// with WebGPU enabled, navigates to the standalone bench-apply-matrix.html +// page, waits for `window.__bench.state === 'done' | 'error'`, prints +// results. Mirrors the bench-divsteps.mjs structure. + +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-apply-matrix.html'; +const N_MAX = 1 << 20; + +const { values: argv } = parseArgs({ + options: { + n: { type: 'string', default: '1024' }, + 'validate-n': { type: 'string' }, + reps: { type: 'string', default: '1' }, + url: { type: 'string', default: DEFAULT_URL_BASE }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '60' }, + json: { type: 'boolean', default: false }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `BY apply_matrix WebGPU dispatch test driver + +Usage: + node dev/msm-webgpu/scripts/bench-apply-matrix.mjs [options] + +Options: + --n N (default 1024, max ${N_MAX}) + --validate-n N (default min(64, n)) + --reps R (default 1) + --url URL (default ${DEFAULT_URL_BASE}) + --headed Run with visible browser window + --timeout SECS Bench page completion timeout (default 60) + --json Machine-readable JSON only output + --help Show this help +`, + ); + process.exit(0); +} + +const n = parseInt(String(argv.n), 10); +if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + process.stderr.write(`error: --n must be in (0, ${N_MAX}], got ${argv.n}\n`); + process.exit(2); +} +const validateN = + argv['validate-n'] !== undefined ? 
parseInt(String(argv['validate-n']), 10) : Math.min(64, n); +if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + process.stderr.write(`error: --validate-n must be in [0, n], got ${argv['validate-n']}\n`); + process.exit(2); +} +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + process.stderr.write(`error: --reps must be in (0, 100], got ${argv.reps}\n`); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const jsonOnly = Boolean(argv.json); +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[bench-apply-matrix] ${msg}\n`); +} +function out(msg) { + if (!jsonOnly) process.stdout.write(`${msg}\n`); +} + +async function reachable(targetUrl) { + try { + const res = await fetch(targetUrl, { method: 'GET' }); + return res.status >= 200 && res.status < 500; + } catch { + return false; + } +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +async function tryLaunch(headless) { + return chromium.launch({ headless, args: CHROMIUM_ARGS }); +} + +function buildUrl() { + const u = new URL(baseUrl); + u.searchParams.set('n', String(n)); + u.searchParams.set('validate-n', String(validateN)); + u.searchParams.set('reps', String(reps)); + return u.toString(); +} + +async function runOnce(headless) { + let browser; + try { + browser = await tryLaunch(headless); + const context = await browser.newContext({ + viewport: { width: 900, height: 600 }, + permissions: [], + bypassCSP: false, + }); + const page = await context.newPage(); + page.on('console', msg => { + const txt = msg.text(); + if (!txt.startsWith('[vite]')) { + err(`[page:${msg.type()}] ${txt}`); + } + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + const navUrl = buildUrl(); + err(`navigating to ${navUrl} (headless=${headless})`); + await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + const hasWebGpu = await page.evaluate(() => 'gpu' in navigator); + if (!hasWebGpu) { + throw new Error('navigator.gpu missing in this Chromium instance'); + } + + err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`); + const t0 = Date.now(); + await page.waitForFunction( + () => window.__bench?.state === 'done' || window.__bench?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const elapsed = (Date.now() - t0) / 1000; + err(`bench reached terminal state in ${elapsed.toFixed(1)}s`); + + const result = await page.evaluate(() => { + const b = window.__bench; + return { + state: b.state, + params: b.params, + result: JSON.parse(JSON.stringify(b.result)), + error: b.error, + log: b.log.slice(), + }; + }); + result.elapsedSec = elapsed; + return result; + } finally { + try { + if (browser) await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +(async () => { + if (!(await reachable(baseUrl))) { + err(`dev server not reachable at ${new URL(baseUrl).origin}`); + err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`); + process.exit(3); + } + + let result; + try { + result = await runOnce(!headed); + } catch (e) { + err(`run failed: ${e.message}`); + if (!headed) { + err('retrying in headed mode'); + try { + result = await runOnce(false); + } catch (e2) { + 
err(`headed retry also failed: ${e2.message}`); + process.exit(4); + } + } else { + process.exit(4); + } + } + + if (result.state === 'error') { + err(`page reported error: ${result.error}`); + if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + process.exit(5); + } + + if (!jsonOnly) { + out(`## BY apply_matrix WebGPU dispatch test`); + out(''); + out(`- params: ${JSON.stringify(result.params)}`); + out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`); + out(''); + const r = result.result; + if (r && r.timing) { + const t = r.timing; + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`); + out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms apply_matrix_calls/s=${t.applyMatrixPerSec.toExponential(3)}`); + } else if (r) { + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`); + for (const m of r.mismatches) { + out(` ${m.replace(/\n/g, '\n ')}`); + } + } + out(''); + } + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + + const failed = !result.result || !result.result.validateOk; + process.exit(failed ? 1 : 0); +})().catch(e => { + err(`unexpected: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs new file mode 100644 index 000000000000..811949065ebe --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-batch-affine.mjs @@ -0,0 +1,212 @@ +#!/usr/bin/env node +// Batch-affine EC add amortisation bench driver. +// Mirrors bench-fr-inv.mjs structure: Playwright Chromium with WebGPU +// enabled, navigates to the standalone bench-batch-affine.html page, +// waits for `window.__bench.state === 'done' | 'error'`, dumps results +// as a table. 
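+//
+// Example invocation (assumes the vite dev server behind the default --url
+// is already listening on port 5197):
+//   node dev/msm-webgpu/scripts/bench-batch-affine.mjs --reps 5 --sizes 64,256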
+ +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL_BASE = 'http://localhost:5197/dev/msm-webgpu/bench-batch-affine.html'; + +const { values: argv } = parseArgs({ + options: { + reps: { type: 'string', default: '5' }, + total: { type: 'string' }, + sizes: { type: 'string' }, + url: { type: 'string', default: DEFAULT_URL_BASE }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '180' }, + json: { type: 'boolean', default: false }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `batch-affine WebGPU amortisation bench driver + +Usage: + node dev/msm-webgpu/scripts/bench-batch-affine.mjs [options] + +Options: + --reps R (default 5) + --url URL (default ${DEFAULT_URL_BASE}) + --headed Run with visible browser window + --timeout SECS Bench page completion timeout (default 180) + --json Machine-readable JSON only output + --help Show this help +`, + ); + process.exit(0); +} + +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 50) { + process.stderr.write(`error: --reps must be in (0, 50], got ${argv.reps}\n`); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const jsonOnly = Boolean(argv.json); +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[bench-batch-affine] ${msg}\n`); +} +function out(msg) { + if (!jsonOnly) process.stdout.write(`${msg}\n`); +} + +async function reachable(targetUrl) { + try { + const res = await fetch(targetUrl, { method: 'GET' }); + return res.status >= 200 && res.status < 500; + } catch { + return false; + } +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +async function tryLaunch(headless) { + return chromium.launch({ headless, args: CHROMIUM_ARGS }); +} + +function buildUrl() { + const u = new URL(baseUrl); + u.searchParams.set('reps', String(reps)); + if (argv.total !== undefined) u.searchParams.set('total', String(argv.total)); + if (argv.sizes !== undefined) u.searchParams.set('sizes', String(argv.sizes)); + return u.toString(); +} + +async function runOnce(headless) { + let browser; + try { + browser = await tryLaunch(headless); + const context = await browser.newContext({ + viewport: { width: 900, height: 600 }, + permissions: [], + bypassCSP: false, + }); + const page = await context.newPage(); + page.on('console', msg => { + const txt = msg.text(); + if (!txt.startsWith('[vite]')) { + err(`[page:${msg.type()}] ${txt}`); + } + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + const navUrl = buildUrl(); + err(`navigating to ${navUrl} (headless=${headless})`); + await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + const hasWebGpu = await page.evaluate(() => 'gpu' in navigator); + if (!hasWebGpu) { + throw new Error('navigator.gpu missing in this Chromium instance'); + } + + err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`); + const t0 = Date.now(); + await page.waitForFunction( + () => window.__bench?.state === 'done' || window.__bench?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const elapsed = (Date.now() - t0) / 1000; + err(`bench reached terminal state in ${elapsed.toFixed(1)}s`); + + const result = 
await page.evaluate(() => { + const b = window.__bench; + return { + state: b.state, + params: b.params, + results: JSON.parse(JSON.stringify(b.results)), + error: b.error, + log: b.log.slice(), + }; + }); + result.elapsedSec = elapsed; + return result; + } finally { + try { + if (browser) await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +(async () => { + if (!(await reachable(baseUrl))) { + err(`dev server not reachable at ${new URL(baseUrl).origin}`); + err( + `start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --port 5197 --strictPort --no-open`, + ); + process.exit(3); + } + + let result; + try { + result = await runOnce(!headed); + } catch (e) { + err(`run failed: ${e.message}`); + if (!headed) { + err('retrying in headed mode'); + try { + result = await runOnce(false); + } catch (e2) { + err(`headed retry also failed: ${e2.message}`); + process.exit(4); + } + } else { + process.exit(4); + } + } + + if (result.state === 'error') { + err(`page reported error: ${result.error}`); + if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + // Don't exit non-zero — partial results may still be useful (e.g. + // first batch_size failed). Print what we have. + } + + if (!jsonOnly) { + out(`## Batch-affine WebGPU amortisation bench (reps=${reps})`); + out(''); + out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`); + out(`- state: ${result.state}`); + if (result.error) out(`- error: ${result.error}`); + out(''); + if (result.results && result.results.length > 0) { + const base = result.results[result.results.length - 1].ns_per_pair; + out('batch_size | num_WGs | TPB | total_threads | median_ms | ns/pair | inv_amort_ratio'); + out('---------- | ------- | --- | ------------- | --------- | ------- | --------------'); + for (const r of result.results) { + const ratio = (r.ns_per_pair / base).toFixed(3); + out( + `${String(r.batch_size).padEnd(10)} | ${String(r.num_wgs).padEnd(7)} | ${String(r.tpb).padEnd(3)} | ${String(r.total_threads).padEnd(13)} | ${r.median_ms.toFixed(3).padStart(9)} | ${r.ns_per_pair.toFixed(1).padStart(7)} | ${String(ratio).padStart(14)}`, + ); + } + } else { + out('(no results)'); + } + out(''); + } + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + + const failed = result.state === 'error'; + process.exit(failed ? 1 : 0); +})().catch(e => { + err(`unexpected: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs new file mode 100755 index 000000000000..54dc7259c0d7 --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-divsteps.mjs @@ -0,0 +1,219 @@ +#!/usr/bin/env node +// BY divsteps WebGPU dispatch test driver. Launches Playwright Chromium +// with WebGPU enabled, navigates to the standalone bench-divsteps.html +// page, waits for `window.__bench.state === 'done' | 'error'`, prints +// results. Mirrors the bench-field-mul.mjs structure. 
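+//
+// Example invocation (dev server on the default port 5173 assumed):
+//   node dev/msm-webgpu/scripts/bench-divsteps.mjs --n 65536 --reps 3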
+ +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-divsteps.html'; +const N_MAX = 1 << 23; + +const { values: argv } = parseArgs({ + options: { + n: { type: 'string', default: '1024' }, + 'validate-n': { type: 'string' }, + reps: { type: 'string', default: '1' }, + url: { type: 'string', default: DEFAULT_URL_BASE }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '60' }, + json: { type: 'boolean', default: false }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `BY divsteps WebGPU dispatch test driver + +Usage: + node dev/msm-webgpu/scripts/bench-divsteps.mjs [options] + +Options: + --n N (default 1024, max ${N_MAX}) + --validate-n N (default min(64, n)) + --reps R (default 1) + --url URL (default ${DEFAULT_URL_BASE}) + --headed Run with visible browser window + --timeout SECS Bench page completion timeout (default 60) + --json Machine-readable JSON only output + --help Show this help +`, + ); + process.exit(0); +} + +const n = parseInt(String(argv.n), 10); +if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + process.stderr.write(`error: --n must be in (0, ${N_MAX}], got ${argv.n}\n`); + process.exit(2); +} +const validateN = + argv['validate-n'] !== undefined ? parseInt(String(argv['validate-n']), 10) : Math.min(64, n); +if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + process.stderr.write(`error: --validate-n must be in [0, n], got ${argv['validate-n']}\n`); + process.exit(2); +} +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + process.stderr.write(`error: --reps must be in (0, 100], got ${argv.reps}\n`); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const jsonOnly = Boolean(argv.json); +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[bench-divsteps] ${msg}\n`); +} +function out(msg) { + if (!jsonOnly) process.stdout.write(`${msg}\n`); +} + +async function reachable(targetUrl) { + try { + const res = await fetch(targetUrl, { method: 'GET' }); + return res.status >= 200 && res.status < 500; + } catch { + return false; + } +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +async function tryLaunch(headless) { + return chromium.launch({ headless, args: CHROMIUM_ARGS }); +} + +function buildUrl() { + const u = new URL(baseUrl); + u.searchParams.set('n', String(n)); + u.searchParams.set('validate-n', String(validateN)); + u.searchParams.set('reps', String(reps)); + return u.toString(); +} + +async function runOnce(headless) { + let browser; + try { + browser = await tryLaunch(headless); + const context = await browser.newContext({ + viewport: { width: 900, height: 600 }, + permissions: [], + bypassCSP: false, + }); + const page = await context.newPage(); + page.on('console', msg => { + const txt = msg.text(); + if (!txt.startsWith('[vite]')) { + err(`[page:${msg.type()}] ${txt}`); + } + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + const navUrl = buildUrl(); + err(`navigating to ${navUrl} (headless=${headless})`); + await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + const hasWebGpu = await 
page.evaluate(() => 'gpu' in navigator); + if (!hasWebGpu) { + throw new Error('navigator.gpu missing in this Chromium instance'); + } + + err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`); + const t0 = Date.now(); + await page.waitForFunction( + () => window.__bench?.state === 'done' || window.__bench?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const elapsed = (Date.now() - t0) / 1000; + err(`bench reached terminal state in ${elapsed.toFixed(1)}s`); + + const result = await page.evaluate(() => { + const b = window.__bench; + return { + state: b.state, + params: b.params, + result: JSON.parse(JSON.stringify(b.result)), + error: b.error, + log: b.log.slice(), + }; + }); + result.elapsedSec = elapsed; + return result; + } finally { + try { + if (browser) await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +(async () => { + if (!(await reachable(baseUrl))) { + err(`dev server not reachable at ${new URL(baseUrl).origin}`); + err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config dev/msm-webgpu/vite.config.ts --no-open`); + process.exit(3); + } + + let result; + try { + result = await runOnce(!headed); + } catch (e) { + err(`run failed: ${e.message}`); + if (!headed) { + err('retrying in headed mode'); + try { + result = await runOnce(false); + } catch (e2) { + err(`headed retry also failed: ${e2.message}`); + process.exit(4); + } + } else { + process.exit(4); + } + } + + if (result.state === 'error') { + err(`page reported error: ${result.error}`); + if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + process.exit(5); + } + + if (!jsonOnly) { + out(`## BY divsteps WebGPU dispatch test`); + out(''); + out(`- params: ${JSON.stringify(result.params)}`); + out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`); + out(''); + const r = result.result; + if (r && r.timing) { + const t = r.timing; + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`); + out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms divsteps_calls/s=${t.divstepsPerSec.toExponential(3)}`); + } else if (r) { + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`); + for (const m of r.mismatches) { + out(` ${m.replace(/\n/g, '\n ')}`); + } + } + out(''); + } + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + + const failed = !result.result || !result.result.validateOk; + process.exit(failed ? 1 : 0); +})().catch(e => { + err(`unexpected: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs new file mode 100755 index 000000000000..7e36eb607b0b --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/bench-fr-inv.mjs @@ -0,0 +1,239 @@ +#!/usr/bin/env node +// fr_inv WebGPU dispatch test driver. Launches Playwright Chromium with +// WebGPU enabled, navigates to the standalone bench-fr-inv.html page, +// waits for `window.__bench.state === 'done' | 'error'`, prints results. +// Mirrors the bench-divsteps.mjs / bench-apply-matrix.mjs structure. 
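// Example invocation (illustrative parameter values only; assumes the Vite
// dev server described in the commit message is already serving the bench
// pages):
//
//   node dev/msm-webgpu/scripts/bench-fr-inv.mjs \
//     --n 65536 --k 4 --variant fr_inv_by --reps 3 --json > fr-inv.json
//
// With --json the full result object is printed on stdout, and the exit code
// is 0 only when the page-side validation passed, so a consumer can read
// e.g. `.result.timing.msMedian` straight out of the saved file.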
+ +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL_BASE = 'http://localhost:5173/dev/msm-webgpu/bench-fr-inv.html'; +const N_MAX = 1 << 20; +const K_MAX = 100; +const VARIANTS = new Set(['fr_inv', 'fr_inv_by', 'fr_inv_by_a', 'fr_pow_inv']); + +const { values: argv } = parseArgs({ + options: { + n: { type: 'string', default: '1024' }, + k: { type: 'string', default: '1' }, + 'validate-n': { type: 'string' }, + reps: { type: 'string', default: '1' }, + variant: { type: 'string', default: 'fr_inv_by' }, + url: { type: 'string', default: DEFAULT_URL_BASE }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '120' }, + json: { type: 'boolean', default: false }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `fr_inv WebGPU dispatch test driver + +Usage: + node dev/msm-webgpu/scripts/bench-fr-inv.mjs [options] + +Options: + --n N (default 1024, max ${N_MAX}) + --k K (default 1, max ${K_MAX}) + --validate-n N (default min(64, n)) + --reps R (default 1) + --variant V 'fr_inv' | 'fr_inv_by' | 'fr_inv_by_a' | 'fr_pow_inv' (default fr_inv_by) + --url URL (default ${DEFAULT_URL_BASE}) + --headed Run with visible browser window + --timeout SECS Bench page completion timeout (default 120) + --json Machine-readable JSON only output + --help Show this help +`, + ); + process.exit(0); +} + +const n = parseInt(String(argv.n), 10); +if (!Number.isFinite(n) || n <= 0 || n > N_MAX) { + process.stderr.write(`error: --n must be in (0, ${N_MAX}], got ${argv.n}\n`); + process.exit(2); +} +const k = parseInt(String(argv.k), 10); +if (!Number.isFinite(k) || k <= 0 || k > K_MAX) { + process.stderr.write(`error: --k must be in (0, ${K_MAX}], got ${argv.k}\n`); + process.exit(2); +} +const validateN = + argv['validate-n'] !== undefined ? 
parseInt(String(argv['validate-n']), 10) : Math.min(64, n); +if (!Number.isFinite(validateN) || validateN < 0 || validateN > n) { + process.stderr.write(`error: --validate-n must be in [0, n], got ${argv['validate-n']}\n`); + process.exit(2); +} +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 100) { + process.stderr.write(`error: --reps must be in (0, 100], got ${argv.reps}\n`); + process.exit(2); +} +const variant = String(argv.variant); +if (!VARIANTS.has(variant)) { + process.stderr.write( + `error: --variant must be one of ${[...VARIANTS].join(', ')}, got '${variant}'\n`, + ); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const jsonOnly = Boolean(argv.json); +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[bench-fr-inv] ${msg}\n`); +} +function out(msg) { + if (!jsonOnly) process.stdout.write(`${msg}\n`); +} + +async function reachable(targetUrl) { + try { + const res = await fetch(targetUrl, { method: 'GET' }); + return res.status >= 200 && res.status < 500; + } catch { + return false; + } +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +async function tryLaunch(headless) { + return chromium.launch({ headless, args: CHROMIUM_ARGS }); +} + +function buildUrl() { + const u = new URL(baseUrl); + u.searchParams.set('n', String(n)); + u.searchParams.set('k', String(k)); + u.searchParams.set('validate-n', String(validateN)); + u.searchParams.set('reps', String(reps)); + u.searchParams.set('variant', variant); + return u.toString(); +} + +async function runOnce(headless) { + let browser; + try { + browser = await tryLaunch(headless); + const context = await browser.newContext({ + viewport: { width: 900, height: 600 }, + permissions: [], + bypassCSP: false, + }); + const page = await context.newPage(); + page.on('console', msg => { + const txt = msg.text(); + if (!txt.startsWith('[vite]')) { + err(`[page:${msg.type()}] ${txt}`); + } + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + const navUrl = buildUrl(); + err(`navigating to ${navUrl} (headless=${headless})`); + await page.goto(navUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + const hasWebGpu = await page.evaluate(() => 'gpu' in navigator); + if (!hasWebGpu) { + throw new Error('navigator.gpu missing in this Chromium instance'); + } + + err(`waiting for bench to complete (up to ${(timeoutMs / 1000).toFixed(0)}s)`); + const t0 = Date.now(); + await page.waitForFunction( + () => window.__bench?.state === 'done' || window.__bench?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const elapsed = (Date.now() - t0) / 1000; + err(`bench reached terminal state in ${elapsed.toFixed(1)}s`); + + const result = await page.evaluate(() => { + const b = window.__bench; + return { + state: b.state, + params: b.params, + result: JSON.parse(JSON.stringify(b.result)), + error: b.error, + log: b.log.slice(), + }; + }); + result.elapsedSec = elapsed; + return result; + } finally { + try { + if (browser) await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +(async () => { + if (!(await reachable(baseUrl))) { + err(`dev server not reachable at ${new URL(baseUrl).origin}`); + err(`start it with: cd barretenberg/ts && ./node_modules/.bin/vite --config 
dev/msm-webgpu/vite.config.ts --no-open`); + process.exit(3); + } + + let result; + try { + result = await runOnce(!headed); + } catch (e) { + err(`run failed: ${e.message}`); + if (!headed) { + err('retrying in headed mode'); + try { + result = await runOnce(false); + } catch (e2) { + err(`headed retry also failed: ${e2.message}`); + process.exit(4); + } + } else { + process.exit(4); + } + } + + if (result.state === 'error') { + err(`page reported error: ${result.error}`); + if (jsonOnly) process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + process.exit(5); + } + + if (!jsonOnly) { + out(`## fr_inv WebGPU dispatch test (variant=${variant})`); + out(''); + out(`- params: ${JSON.stringify(result.params)}`); + out(`- elapsed wall (s): ${result.elapsedSec.toFixed(2)}`); + out(''); + const r = result.result; + if (r && r.timing) { + const t = r.timing; + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'}`); + out(`- timing reps=${t.reps} median=${t.msMedian.toFixed(3)}ms min=${t.msMin.toFixed(3)}ms max=${t.msMax.toFixed(3)}ms inversions/s=${t.inversionsPerSec.toExponential(3)}`); + } else if (r) { + out(`- validate ok: ${r.validateOk ? 'OK' : 'FAIL'} (no timing — validation failed)`); + for (const m of r.mismatches) { + out(` ${m.replace(/\n/g, '\n ')}`); + } + } + out(''); + } + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + + const failed = !result.result || !result.result.validateOk; + process.exit(failed ? 1 : 0); +})().catch(e => { + err(`unexpected: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs new file mode 100644 index 000000000000..c2c9eb94e83b --- /dev/null +++ b/barretenberg/ts/dev/msm-webgpu/scripts/profile-sanity.mjs @@ -0,0 +1,189 @@ +#!/usr/bin/env node +// Drive the MSM dev page's Quick Sanity Check button via Playwright, +// scrape window.__sanity.capture for per-stage GPU times, and print a +// compact roll-up. Used to bench WPB / BPR-multi-window variants +// without rebuilding the full bench/sweep UI flow. +// +// Outputs JSON like: +// { +// "reps": 5, +// "ms_total": { "samples": [...], "median": ... }, +// "per_stage": { "bpr_1": { "samples": [...], "median": ... }, ... 
} +// } + +import { chromium } from 'playwright-core'; +import { parseArgs } from 'node:util'; + +const DEFAULT_URL = 'http://localhost:5195/dev/msm-webgpu/index.html'; + +const { values: argv } = parseArgs({ + options: { + url: { type: 'string', default: DEFAULT_URL }, + reps: { type: 'string', default: '5' }, + headed: { type: 'boolean', default: false }, + timeout: { type: 'string', default: '120' }, + help: { type: 'boolean', default: false }, + }, + allowPositionals: false, +}); + +if (argv.help) { + process.stdout.write( + `Profile-sanity WebGPU MSM driver +Usage: + node dev/msm-webgpu/scripts/profile-sanity.mjs [options] +Options: + --url URL (default ${DEFAULT_URL}) + --reps R (default 5) + --headed Show browser window + --timeout SECS Per-sanity-click timeout (default 120) + --help Show this help +`, + ); + process.exit(0); +} + +const reps = parseInt(String(argv.reps), 10); +if (!Number.isFinite(reps) || reps <= 0 || reps > 20) { + process.stderr.write(`error: --reps must be in (0,20], got ${argv.reps}\n`); + process.exit(2); +} +const headed = Boolean(argv.headed); +const timeoutMs = parseFloat(String(argv.timeout)) * 1000; +const baseUrl = String(argv.url); + +function err(msg) { + process.stderr.write(`[profile-sanity] ${msg}\n`); +} + +const CHROMIUM_ARGS = [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan,WebGPU', + '--use-angle=metal', + '--disable-features=ServiceWorker', + '--ignore-gpu-blocklist', +]; + +function rollupLabel(label) { + const idx = label.indexOf('['); + return idx >= 0 ? label.substring(0, idx) : label; +} + +function aggregateCapture(capture) { + // Per-rep stage sums: roll up `[r=...]` / `[subtask=...]` suffixes. + // Skip region entries (they overlap and would double-count). + const totals = new Map(); + if (!capture?.profile) return { stages: {}, gpuWall: null, profiledSum: null, untimestamped: null }; + for (const e of capture.profile) { + if (e.kind === 'region') continue; + const k = rollupLabel(e.label); + totals.set(k, (totals.get(k) ?? 0) + e.ms); + } + return { + stages: Object.fromEntries(totals), + gpuWall: capture.gpu_readback?.gpu_compute_wall ?? null, + profiledSum: capture.gpu_readback?.profiled_passes_sum ?? null, + untimestamped: capture.gpu_readback?.untimestamped ?? null, + }; +} + +function median(arr) { + if (!arr.length) return null; + const s = [...arr].sort((a, b) => a - b); + const m = Math.floor(s.length / 2); + return s.length % 2 === 0 ? (s[m - 1] + s[m]) / 2 : s[m]; +} + +async function run() { + const browser = await chromium.launch({ headless: !headed, args: CHROMIUM_ARGS }); + try { + const context = await browser.newContext({ viewport: { width: 1100, height: 800 }, bypassCSP: false }); + const page = await context.newPage(); + page.on('console', msg => { + const t = msg.text(); + if (!t.startsWith('[vite]')) err(`[page:${msg.type()}] ${t}`); + }); + page.on('pageerror', e => err(`[page:pageerror] ${e.message}`)); + + err(`navigating to ${baseUrl} (headless=${!headed})`); + await page.goto(baseUrl, { waitUntil: 'domcontentloaded', timeout: 30_000 }); + + // Probe adapter info so the operator can verify Metal/Apple. + const adapter = await page.evaluate(async () => { + if (!('gpu' in navigator)) return null; + const a = await navigator.gpu.requestAdapter(); + if (!a) return null; + // adapter.info is the modern API; fallback to requestAdapterInfo() for older. + // eslint-disable-next-line no-undef + const info = a.info ?? (await a.requestAdapterInfo?.()); + return info ? 
{ vendor: info.vendor, architecture: info.architecture, device: info.device, description: info.description } : null; + }); + err(`adapter: ${JSON.stringify(adapter)}`); + + // Wait for sanity button to be enabled (SRS loaded). + await page.waitForFunction( + () => { + const b = document.getElementById('run-sanity'); + return b && !b.disabled; + }, + { timeout: 60_000, polling: 200 }, + ); + err('sanity button enabled (SRS loaded)'); + + const samples = []; // per-rep aggregates + for (let r = 0; r < reps; r++) { + err(`rep ${r + 1}/${reps}: clicking Quick sanity check`); + // Reset state and click. + await page.evaluate(() => { + window.__sanity = { state: 'running' }; + }); + await page.click('#run-sanity'); + await page.waitForFunction( + () => window.__sanity?.state === 'done' || window.__sanity?.state === 'error', + { timeout: timeoutMs, polling: 250 }, + ); + const out = await page.evaluate(() => JSON.parse(JSON.stringify(window.__sanity))); + if (out.state === 'error') { + err(`rep ${r + 1} ERROR: ${out.error}`); + process.exit(5); + } + const agg = aggregateCapture(out.capture); + agg.ms = out.ms; + samples.push(agg); + const bpr1 = agg.stages?.bpr_1; + err(`rep ${r + 1} ok: ms=${out.ms?.toFixed(1)} bpr_1=${bpr1?.toFixed?.(2) ?? '?'} gpuWall=${agg.gpuWall?.toFixed?.(2) ?? '?'}`); + } + + // Aggregate per-stage medians + ms_total median. + const stageNames = new Set(); + for (const s of samples) for (const k of Object.keys(s.stages)) stageNames.add(k); + const per_stage = {}; + for (const k of stageNames) { + const xs = samples.map(s => s.stages[k]).filter(v => Number.isFinite(v)); + per_stage[k] = { samples: xs, median: median(xs) }; + } + const ms_total = samples.map(s => s.ms); + const gpu_wall = samples.map(s => s.gpuWall).filter(v => Number.isFinite(v)); + const profiled_sum = samples.map(s => s.profiledSum).filter(v => Number.isFinite(v)); + const result = { + reps, + adapter, + ms_total: { samples: ms_total, median: median(ms_total) }, + gpu_wall: { samples: gpu_wall, median: median(gpu_wall) }, + profiled_sum: { samples: profiled_sum, median: median(profiled_sum) }, + per_stage, + }; + process.stdout.write(JSON.stringify(result, null, 2) + '\n'); + } finally { + try { + await browser.close(); + } catch (e) { + err(`browser.close failed: ${e.message}`); + } + } +} + +run().catch(e => { + err(`fatal: ${e.stack ?? e.message ?? String(e)}`); + process.exit(99); +}); diff --git a/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts b/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts index 821081266812..de4229968626 100644 --- a/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts +++ b/barretenberg/ts/dev/msm-webgpu/wgsl_unit_tests.ts @@ -515,5 +515,6 @@ export async function runAllWgslUnitTests(): Promise { results.push(await testTransposeAtChunkSize(15, 256)); results.push(await testTransposeAtChunkSize(4, 256)); results.push(await testTransposeAtChunkSize(16, 256)); + return results; } diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts index 93b3738d7c41..6f926ba6f514 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts @@ -140,6 +140,26 @@ export const smvp_batch_affine_gpu = async ( const num_words = shaderManager.num_words; const limb_byte_length = num_words * 4; + // WINDOWS_PER_BATCH pools the pair pools of WPB consecutive subtasks + // into ONE fr_inv_by_a per (batch, sub_wg) — see + // batch_inverse_parallel.template.wgsl. 
With T=17 subtasks at logN=16 + // and WPB=4, num_batches = ceil(17/4) = 5, so a round runs 5×W + // workgroups in the inverse pass instead of 17×W (~4× fewer fr_inv + // calls overall). The pair-pool layout is per-subtask, unchanged; the + // pooling only changes how the inverse kernel reads/writes scratch. + // WPB=1 falls back to byte-identical pre-pooling behaviour. + // + // Default is 1: the per-pass profiler at logN=16 on Apple Silicon + // showed ba_inverse Σ ≈ 14% of MSM, so pooling buys little relative + // to BPR (51%). The plumbing stays in place so the knob is one edit + // away once BPR is no longer the dominant cost. + const WINDOWS_PER_BATCH = 1; + const num_batches = Math.ceil(num_subtasks / WINDOWS_PER_BATCH); + // count_buf (round_count + finalize_count) must be safely loadable + // for every slot the inverse kernel reads — i.e. through the tail of + // the last batch. + const count_slots_padded = num_batches * WINDOWS_PER_BATCH; + // ----- Workspace buffers ----- // When a persistent context is provided, all per-MSM workspace buffers // are pulled from `context.acquirePersistentBuffer` and survive across @@ -190,7 +210,12 @@ export const smvp_batch_affine_gpu = async ( // THIS round's inverse + apply. Decouples the schedule writer from // the inverse/apply readers so we can zero pair_counter in dispatch_args // without losing the values inverse/apply still need. - const round_count_sb = acquire_ws('round_count', num_subtasks * 4); + // + // Padded to num_batches × WPB so the inverse kernel's per-batch WPB- + // wide atomicLoad scan can safely touch the tail of the last batch + // when num_subtasks isn't a multiple of WPB. Persistent buffers come + // zero-initialised; nothing else writes those tail slots. + const round_count_sb = acquire_ws('round_count', count_slots_padded * 4); // Indirect-dispatch arg buffer. Layout: 9 u32s (36 B): // args[0..3] = schedule (X, Y, Z) — for the *next* round // args[3..6] = inverse (X, Y, Z) — for *this* round @@ -216,9 +241,15 @@ export const smvp_batch_affine_gpu = async ( // Contents are constant across MSM calls, so we write once on first // acquisition and skip the writeBuffer thereafter. let finalize_count_sb: GPUBuffer; + // Sized to count_slots_padded so the inverse's per-batch WPB-wide scan + // sees zeros for tail slots past num_subtasks (which the kernel + // skips by treating wg_counts[w]=0 as an empty subtask). The first + // num_subtasks slots carry half_num_columns each. + const finalize_counts_arr = new Uint32Array(count_slots_padded); + for (let i = 0; i < num_subtasks; i++) finalize_counts_arr[i] = half_num_columns; if (context !== undefined) { - finalize_count_sb = context.acquirePersistentBuffer(`${ws_key}:finalize_count`, num_subtasks * 4); - // Re-write the constants every time (cheap — T u32s = 64 B at T=16) + finalize_count_sb = context.acquirePersistentBuffer(`${ws_key}:finalize_count`, count_slots_padded * 4); + // Re-write the constants every time (cheap — at most ~20 u32s = 80 B) // rather than tracking a "first call" flag. 
Avoids a subtle bug where // a bigger N then smaller N would leave stale values in the cached // buffer (acquirePersistentBuffer resets identity on size change but @@ -226,12 +257,12 @@ export const smvp_batch_affine_gpu = async ( device.queue.writeBuffer( finalize_count_sb, 0, - new Uint8Array(new Uint32Array(new Array(num_subtasks).fill(half_num_columns)).buffer), + new Uint8Array(finalize_counts_arr.buffer, finalize_counts_arr.byteOffset, finalize_counts_arr.byteLength), ); } else { finalize_count_sb = create_and_write_sb( device, - new Uint8Array(new Uint32Array(new Array(num_subtasks).fill(half_num_columns)).buffer), + new Uint8Array(finalize_counts_arr.buffer, finalize_counts_arr.byteOffset, finalize_counts_arr.byteLength), ); } @@ -256,21 +287,20 @@ export const smvp_batch_affine_gpu = async ( const apply_workgroup_size = 64; const finalize_workgroup_size_default = 256; - // Inverse pass parallelism: each subtask's pair pool is split across - // NUM_SUB_WGS_PER_SUBTASK contiguous slices, with one workgroup of - // TPB=64 threads inverting each slice independently (its own fr_inv). + // Inverse pass parallelism: each (batch, sub_wg) inverts a contiguous + // slice of the merged pool. NUM_SUB_WGS_PER_SUBTASK workgroups of + // TPB=64 threads each handle one slice independently (own fr_inv). // Drops per-thread sequential bs by W× → W× faster Phase A/D at large // N. The W extra fr_invs run concurrently across SMs (zero wall-time - // cost). W=8 keeps SM occupancy high (T=16 × W=8 = 128 in-flight - // workgroups, matching RTX-class GPU SM counts). + // cost). const NUM_SUB_WGS_PER_SUBTASK = 8; const init_shader = shaderManager.gen_batch_affine_init_shader(init_workgroup_size); const schedule_shader = shaderManager.gen_batch_affine_schedule_shader(schedule_workgroup_size); - const dispatch_args_shader = shaderManager.gen_batch_affine_dispatch_args_shader(); - // Multi-workgroup batch-inverse (W sub-WGs per subtask). See + const dispatch_args_shader = shaderManager.gen_batch_affine_dispatch_args_shader(WINDOWS_PER_BATCH); + // Multi-workgroup batch-inverse (W sub-WGs per (batch, sub_wg)). See // batch_inverse_parallel.template.wgsl for the algorithm. 
- const inverse_shader = shaderManager.gen_batch_inverse_parallel_shader(NUM_SUB_WGS_PER_SUBTASK); + const inverse_shader = shaderManager.gen_batch_inverse_parallel_shader(NUM_SUB_WGS_PER_SUBTASK, WINDOWS_PER_BATCH); const apply_shader = shaderManager.gen_batch_affine_apply_scatter_shader(apply_workgroup_size); // Finalize dispatch geometry — single dispatch covers all T·h IDs; // z workgroups = num_subtasks, threading inside the kernel mirrors @@ -350,7 +380,7 @@ export const smvp_batch_affine_gpu = async ( ], inverse_shader, context, - `bn254:batch_affine_inverse:parallel-v3-pitch:${num_columns}:${input_size}`, + `bn254:batch_affine_inverse:parallel-v4-wpb${WINDOWS_PER_BATCH}:${num_columns}:${input_size}`, ); const apply_pipe = await compile_pipeline_for( @@ -381,7 +411,7 @@ export const smvp_batch_affine_gpu = async ( ], dispatch_args_shader, context, - `bn254:batch_affine_dispatch_args:v2-fwd:${num_subtasks}:${apply_workgroup_size}`, + `bn254:batch_affine_dispatch_args:v3-wpb${WINDOWS_PER_BATCH}:${num_subtasks}:${apply_workgroup_size}`, ); const finalize_collect_pipe = await compile_pipeline_for( @@ -532,7 +562,7 @@ export const smvp_batch_affine_gpu = async ( // Cache suffix bumped to "v2" so we don't accidentally reuse a cached // bind group from the pre-fold v1 layout (same buffer count, different // semantics — the v2 suffix tags the new wiring explicitly). - const inverse_bg = acquire_bg('inverse_bg:v2', inverse_pipe.layout, [ + const inverse_bg = acquire_bg(`inverse_bg:v3-wpb${WINDOWS_PER_BATCH}`, inverse_pipe.layout, [ pair_delta_sb, pair_prefix_sb, pair_inv_sb, @@ -555,7 +585,7 @@ export const smvp_batch_affine_gpu = async ( // Suffix bumped to v2 by P2-clear: bind group now binds 4 buffers // (pair_counter + round_count + dispatch_args + ub) instead of 3. // W suffix dropped along with the P1 revert. - const dispatch_args_bg = acquire_bg('dispatch_args_bg:v2', dispatch_args_pipe.layout, [ + const dispatch_args_bg = acquire_bg(`dispatch_args_bg:v3-wpb${WINDOWS_PER_BATCH}`, dispatch_args_pipe.layout, [ pair_counter_sb, round_count_sb, dispatch_args_sb, @@ -576,7 +606,7 @@ export const smvp_batch_affine_gpu = async ( finalize_ub, ]); - const finalize_inverse_bg = acquire_bg('finalize_inverse_bg', inverse_pipe.layout, [ + const finalize_inverse_bg = acquire_bg(`finalize_inverse_bg:v2-wpb${WINDOWS_PER_BATCH}`, inverse_pipe.layout, [ pair_delta_sb, pair_prefix_sb, pair_inv_sb, @@ -786,17 +816,18 @@ export const smvp_batch_affine_gpu = async ( profiler?.stage('ba_finalize_collect'), ); - // Pass B (batch_inverse): (W, 1, T) workgroups in parallel — each - // subtask's h-element slice is split across W sub-WGs, each - // independently inverting its sub-slice. Same kernel as the round - // loop's inverse pass. + // Pass B (batch_inverse): (W, 1, num_batches) workgroups in parallel + // — each (batch, sub_wg) inverts the merged h-element slices of WPB + // consecutive subtasks with one fr_inv_by_a. Same kernel as the + // round loop's inverse pass; finalize_count_sb is pre-populated with + // half_num_columns for valid subtasks and zero for the padded tail. 
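  // Worked example (illustrative numbers, taken from the WINDOWS_PER_BATCH
  // comment above): with T = 17 subtasks and WPB = 4, num_batches = 5, so at
  // W = NUM_SUB_WGS_PER_SUBTASK = 8 this pass dispatches (8, 1, 5) = 40
  // workgroups and issues one fr_inv_by_a per (batch, sub_wg); at WPB = 1 it
  // degenerates to (8, 1, 17) = 136 workgroups, the pre-pooling behaviour.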
await execute_pipeline( commandEncoder, inverse_pipe.pipeline, finalize_inverse_bg, NUM_SUB_WGS_PER_SUBTASK, 1, - num_subtasks, + num_batches, profiler?.stage('ba_finalize_inverse'), ); diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts new file mode 100644 index 000000000000..8df1ebbfc6a0 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.test.ts @@ -0,0 +1,366 @@ +// Jest unit tests for the multi-window batched-Pippenger host +// reference (sub-step 2.1 of the WebGPU MSM rewrite plan). +// +// What this validates (the WASM cross-window batched-inverse pattern +// the GPU shaders will emulate in Phase 2): +// +// 1. windowsPerBatch=1 with the new signed-Booth path equals the +// existing unsigned-digit `batchAffineMSM`. Asserts that the +// signed-digit recoding plus 32-bit schedule encoding produces a +// mathematically identical MSM result. +// +// 2. windowsPerBatch=2 and =4 agree with windowsPerBatch=1 on the +// same inputs. Confirms the cross-window pooled batched-inverse +// is associativity-correct (changing inversion batch size never +// changes the result). +// +// 3. Edge cases: all-zero scalars, single non-zero scalar, scalars +// at p-1, scalars at c-bit boundary values (2^k - 1, 2^k, 2^k + 1). +// +// 4. Schedule encoding round-trip: encode → decode is identity; +// dedup bits (29|30) round-trip as zero (the consumer rejects +// non-zero dedup bits). +// +// 5. Signed-Booth recoding identity: reassembling the recoded digits +// via Σ_w (sign[w] ? -mag[w] : mag[w]) * 2^(c*w) == scalar. + +import { + batchAffineMSM, + batchAffineMSMMultiWindow, + decodeScheduleEntry, + encodeScheduleEntry, + SCHEDULE_INDEX_MASK, + SCHEDULE_SIGN_BIT, + _testOnly, +} from "./batch_affine_bn254.js"; +import { + BN254_SCALAR_FIELD, + BN254_ZERO, + bn254ScalarField, + doubleBn254Point, + msmBn254, + scalarMultBn254Point, + type Bn254Point, +} from "./bn254.js"; + +// Reproducible RNG (mirrors bench-field-mul.ts / bernstein_yang.test.ts). +const makeRng = (seed: number): (() => number) => { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +}; + +// Random scalar in [0, p) via 8 × 32-bit limbs (>= 254 bits of entropy) +// then reduced mod p. Plenty above the scalar field size. +const randomScalar = (rng: () => number): bigint => { + let s = 0n; + for (let i = 0; i < 8; i++) { + s = (s << 32n) | BigInt(rng() >>> 0); + } + return s % BN254_SCALAR_FIELD; +}; + +// BN254 generator (matches the canonical hex used by other tests in +// this directory: G_x = 1, G_y = 2). +const BN254_G: Bn254Point = { x: 1n, y: 2n }; + +// Build N test points as small multiples of G — cheap to construct, +// and the random scalars we sample never collide with these multiples +// modulo p (sampling space is ~2^254 vs N up to a few hundred). +const buildPoints = (n: number, rng: () => number): Bn254Point[] => { + const pts: Bn254Point[] = new Array(n); + for (let i = 0; i < n; i++) { + // Use a fixed small offset + scalar from rng. Avoids ever passing + // identity by choosing scalars from [1, 2^31). 
+ const k = BigInt((rng() & 0x7fffffff) + 1); + pts[i] = scalarMultBn254Point(BN254_G, k); + } + return pts; +}; + +const pointEq = (a: Bn254Point, b: Bn254Point): boolean => { + const aZ = a.infinity === true; + const bZ = b.infinity === true; + if (aZ && bZ) return true; + if (aZ || bZ) return false; + return a.x === b.x && a.y === b.y; +}; + +describe("Schedule entry encoding (32-bit layout)", () => { + it("encode/decode identity for positive sign", () => { + const entry = encodeScheduleEntry(42, 0); + expect(entry).toBe(42); + const dec = decodeScheduleEntry(entry); + expect(dec.scalarIdx).toBe(42); + expect(dec.sign).toBe(0); + }); + + it("encode/decode identity for negative sign", () => { + const entry = encodeScheduleEntry(42, 1); + // bit 31 set + scalarIdx = 42. As unsigned 32-bit: + expect(entry).toBe((SCHEDULE_SIGN_BIT | 42) >>> 0); + const dec = decodeScheduleEntry(entry); + expect(dec.scalarIdx).toBe(42); + expect(dec.sign).toBe(1); + }); + + it("encode/decode at 29-bit index boundary", () => { + const entry = encodeScheduleEntry(SCHEDULE_INDEX_MASK, 1); + const dec = decodeScheduleEntry(entry); + expect(dec.scalarIdx).toBe(SCHEDULE_INDEX_MASK); + expect(dec.sign).toBe(1); + }); + + it("encode throws on out-of-range scalarIdx", () => { + expect(() => + encodeScheduleEntry(SCHEDULE_INDEX_MASK + 1, 0), + ).toThrow(); + }); + + it("decode throws when dedup bit 29 is set", () => { + const entry = (1 << 29) | 7; + expect(() => decodeScheduleEntry(entry)).toThrow(); + }); + + it("decode throws when dedup bit 30 is set", () => { + const entry = (1 << 30) | 7; + expect(() => decodeScheduleEntry(entry)).toThrow(); + }); + + it("zero entry decodes to (0, 0)", () => { + const dec = decodeScheduleEntry(0); + expect(dec.scalarIdx).toBe(0); + expect(dec.sign).toBe(0); + }); +}); + +describe("Signed-Booth digit recoding identity", () => { + const CS = [2, 4, 8, 15, 16]; + const SCALAR_BITS = 254; + + for (const c of CS) { + it(`c=${c}: reassembled digits equal scalar (random + edge cases)`, () => { + // Match `batchAffineMSMMultiWindow`'s formula (+2 headroom for + // the signed-Booth carry). Without this the top window's sign + // bit can fall inside the scalar's MSB region at small c values + // (e.g. c=2 with scalar = p-1, MSB = bit 253 ⇒ top window sign + // bit = bit 253 ⇒ phantom carry past num_windows). + const numWindows = Math.ceil((SCALAR_BITS + 2) / c); + const rng = makeRng(0xc001 + c); + const samples: bigint[] = [ + 0n, + 1n, + 2n, + (1n << BigInt(c)) - 1n, + 1n << BigInt(c), + (1n << BigInt(c)) + 1n, + (1n << BigInt(c - 1)) - 1n, + 1n << BigInt(c - 1), + BN254_SCALAR_FIELD - 1n, + BN254_SCALAR_FIELD - 7n, + ]; + for (let i = 0; i < 32; i++) samples.push(randomScalar(rng)); + + for (const s of samples) { + const { magnitudes, signs } = _testOnly.recodeScalarBooth( + s, + c, + numWindows, + ); + // Reassemble digit_w * 2^(c*w). All magnitudes fit in [0, 2^(c-1)]. + let recon = 0n; + const shifts: bigint[] = new Array(numWindows); + for (let w = 0; w < numWindows; w++) { + shifts[w] = BigInt(c * w); + } + for (let w = 0; w < numWindows; w++) { + const mag = BigInt(magnitudes[w]); + const signed = signs[w] === 1 ? -mag : mag; + recon += signed << shifts[w]; + // Magnitude bound check. 
+ expect(magnitudes[w]).toBeGreaterThanOrEqual(0); + expect(magnitudes[w]).toBeLessThanOrEqual(1 << (c - 1)); + } + expect(recon).toBe(s); + } + }); + } +}); + +describe("batchAffineMSMMultiWindow vs batchAffineMSM (windowsPerBatch=1)", () => { + // c=16 is the production chunk_size, but the unsigned-digit + // batchAffineMSM allocates 2^c = 65536 buckets per window — fine but + // slow in pure-bigint at large c. We use c=8 here so the test fixture + // stays under a second per case. + const C = 8; + const NS = [16, 64, 256]; + + for (const n of NS) { + it(`n=${n}, c=${C}: signed-Booth (wpb=1) == unsigned baseline`, () => { + const rng = makeRng(0xdead + n); + const points = buildPoints(n, rng); + const scalars: bigint[] = new Array(n); + for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng); + + const expected = batchAffineMSM(points, scalars, C); + const actual = batchAffineMSMMultiWindow(points, scalars, C, 1); + expect(pointEq(actual, expected)).toBe(true); + + // Defence-in-depth: also cross-check against the brute-force MSM + // in bn254.ts. Both helpers above could be wrong in the same way + // (the signed-Booth recoder + bucket math is bespoke), so this + // third independent path locks down the truth. + const truth = msmBn254(points, scalars); + expect(pointEq(actual, truth)).toBe(true); + }); + } + + it("5 random seeds for n=64 all agree", () => { + for (let seed = 1; seed <= 5; seed++) { + const rng = makeRng(0xa11ce + seed); + const points = buildPoints(64, rng); + const scalars: bigint[] = new Array(64); + for (let i = 0; i < 64; i++) scalars[i] = randomScalar(rng); + + const expected = batchAffineMSM(points, scalars, C); + const actual = batchAffineMSMMultiWindow(points, scalars, C, 1); + expect(pointEq(actual, expected)).toBe(true); + } + }); +}); + +describe("batchAffineMSMMultiWindow: windowsPerBatch invariance", () => { + const C = 8; + const NS = [16, 64, 256]; + + for (const n of NS) { + it(`n=${n}, c=${C}: wpb=2 == wpb=1`, () => { + const rng = makeRng(0xb0b + n); + const points = buildPoints(n, rng); + const scalars: bigint[] = new Array(n); + for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng); + + const ref1 = batchAffineMSMMultiWindow(points, scalars, C, 1); + const ref2 = batchAffineMSMMultiWindow(points, scalars, C, 2); + expect(pointEq(ref2, ref1)).toBe(true); + }); + + it(`n=${n}, c=${C}: wpb=4 == wpb=1`, () => { + const rng = makeRng(0xfee1 + n); + const points = buildPoints(n, rng); + const scalars: bigint[] = new Array(n); + for (let i = 0; i < n; i++) scalars[i] = randomScalar(rng); + + const ref1 = batchAffineMSMMultiWindow(points, scalars, C, 1); + const ref4 = batchAffineMSMMultiWindow(points, scalars, C, 4); + expect(pointEq(ref4, ref1)).toBe(true); + }); + } +}); + +describe("batchAffineMSMMultiWindow: edge cases", () => { + const C = 8; + const N = 32; + + it("all-zero scalars → identity", () => { + const rng = makeRng(0xc0de); + const points = buildPoints(N, rng); + const scalars: bigint[] = new Array(N).fill(0n); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, C, wpb); + expect(r.infinity).toBe(true); + } + }); + + it("single non-zero scalar agrees with scalar mult", () => { + const rng = makeRng(0x5eed); + const points = buildPoints(N, rng); + const idx = 7; + const k = 123456789n; + const scalars: bigint[] = new Array(N).fill(0n); + scalars[idx] = k; + const expected = scalarMultBn254Point(points[idx], k); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, C, 
wpb); + expect(pointEq(r, expected)).toBe(true); + } + }); + + it("scalar = p - 1 (maximum valid scalar)", () => { + const rng = makeRng(0xbeef); + const points = buildPoints(N, rng); + const scalars: bigint[] = new Array(N); + for (let i = 0; i < N; i++) scalars[i] = BN254_SCALAR_FIELD - 1n; + const truth = msmBn254(points, scalars); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, C, wpb); + expect(pointEq(r, truth)).toBe(true); + } + }); + + it("scalars at c-bit boundary values (2^k - 1, 2^k, 2^k + 1)", () => { + const rng = makeRng(0xface); + const points = buildPoints(8, rng); + // 8 scalars: covers 2^(c-1) - 1, 2^(c-1), 2^(c-1) + 1, 2^c - 1, + // 2^c, 2^c + 1, 2^(2c) - 1, 2^(2c). + const cBI = BigInt(C); + const scalars: bigint[] = [ + (1n << (cBI - 1n)) - 1n, + 1n << (cBI - 1n), + (1n << (cBI - 1n)) + 1n, + (1n << cBI) - 1n, + 1n << cBI, + (1n << cBI) + 1n, + (1n << (cBI * 2n)) - 1n, + 1n << (cBI * 2n), + ]; + const truth = msmBn254(points, scalars); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, C, wpb); + expect(pointEq(r, truth)).toBe(true); + } + }); + + it("mixed (zero + nonzero + p-1) and wpb=1/2/4 agree", () => { + const rng = makeRng(0xface5); + const N2 = 16; + const points = buildPoints(N2, rng); + const scalars: bigint[] = new Array(N2); + for (let i = 0; i < N2; i++) { + if (i % 4 === 0) scalars[i] = 0n; + else if (i % 4 === 1) scalars[i] = BN254_SCALAR_FIELD - 1n; + else scalars[i] = randomScalar(rng); + } + const truth = msmBn254(points, scalars); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, C, wpb); + expect(pointEq(r, truth)).toBe(true); + } + }); +}); + +describe("batchAffineMSMMultiWindow: c=16 (production chunk size)", () => { + // Smoke test at c=16 to confirm the production chunk_size works. + // Limit n=16 to keep the 2^15-bucket-per-window per-batch allocation + // pure-bigint cost under a second. + it("c=16, n=16, wpb=1/2/4 agree with msmBn254 truth", () => { + const rng = makeRng(0x16); + const points = buildPoints(16, rng); + const scalars: bigint[] = new Array(16); + for (let i = 0; i < 16; i++) scalars[i] = randomScalar(rng); + const truth = msmBn254(points, scalars); + for (const wpb of [1, 2, 4]) { + const r = batchAffineMSMMultiWindow(points, scalars, 16, wpb); + expect(pointEq(r, truth)).toBe(true); + } + }); +}); + +// Silence unused-import warnings: bn254ScalarField, doubleBn254Point, +// BN254_ZERO are imported for future-proofing tests that exercise the +// pooled-inverse boundary cases; the linter is otherwise vocal. +void bn254ScalarField; +void doubleBn254Point; +void BN254_ZERO; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts index 0168270f98c1..3611ee30e6a9 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine_bn254.ts @@ -32,9 +32,23 @@ import { doubleBn254Point, isBn254Zero, modInverse, + negateBn254Point, type Bn254Point, } from "./bn254.js"; +// 32-bit schedule-entry encoding (mirrors WASM +// scalar_multiplication.cpp lines 552-567). Stage 4 stores only the +// point sign + scalar index; bucket magnitude is recovered later from +// bucket_start ranges because the schedule is bucket-contiguous. 
+// bit 31: sign bit (1 = subtract the point) +// bit 30: dedup redirect (always zero in this path) +// bit 29: dedup skip (always zero in this path) +// bits 0..28: scalar_idx (29-bit payload, ≤ 2^29 = 512M) +// Use `>>> 0` to keep these as unsigned 32-bit; JS's `1 << 31` would +// be the negative number -2147483648 otherwise. +export const SCHEDULE_SIGN_BIT = (1 << 31) >>> 0; // 0x8000_0000 +export const SCHEDULE_INDEX_MASK = (1 << 29) - 1; // 0x1FFF_FFFF + /** * Montgomery's batch-inverse trick. * @@ -282,3 +296,470 @@ export const _testOnlyInvViaFermat = (a: bigint): bigint => { // a^(p-2) mod p return modInverse(a); }; + +// ===================================================================== +// Multi-window batched Pippenger (host reference for Phase 2 GPU port) +// ===================================================================== +// +// Mirrors the WASM `pippenger_round_parallel` multi-window structure +// (`scalar_multiplication.cpp` lines 3780-4080 for Stages 1-4, lines +// 4210-4550 for Stage 6, lines 4550-4610 for the per-region driver). +// Specifically un-dedup, un-GLV path. +// +// Key shape change vs `batchAffineMSM` above: +// - Signed-digit (Booth) recoding per window: each c-bit digit is in +// [-2^(c-1), 2^(c-1)], halving the bucket count to 2^(c-1) + 1 and +// pairing positive/negative magnitudes via the schedule's sign bit. +// - A "batch" of `windowsPerBatch` consecutive windows shares a +// single batched inversion in the bucket-accumulation phase: pairs +// pending from EVERY window in the batch are pooled into one +// batchInverse call (WASM Stage 6b +// `recursive_affine_bucket_reduce_strided`). With +// windowsPerBatch = 4 this amortises the inversion ~4× harder than +// the single-window path. + +// 32-bit schedule entry layout encode (bake bits 29, 30 = 0). +export const encodeScheduleEntry = ( + scalarIdx: number, + sign: 0 | 1, +): number => { + if (scalarIdx < 0 || scalarIdx > SCHEDULE_INDEX_MASK) { + throw new Error( + `encodeScheduleEntry: scalarIdx ${scalarIdx} out of 29-bit range`, + ); + } + // `>>> 0` keeps the result a non-negative 32-bit integer in JS's + // double-backed Number representation. + return ((sign << 31) | scalarIdx) >>> 0; +}; + +// 32-bit schedule entry layout decode. Verifies the dedup bits 29, 30 +// are zero — they must be in this path. +export const decodeScheduleEntry = ( + entry: number, +): { scalarIdx: number; sign: 0 | 1 } => { + const u = entry >>> 0; + // Bits 29 and 30 must always be zero in the non-dedup path. Strict + // check here so consumers can rely on the invariant rather than + // silently masking them off. + if ((u & ((1 << 30) | (1 << 29))) !== 0) { + throw new Error( + `decodeScheduleEntry: dedup bits (29|30) set in entry 0x${u.toString(16)}`, + ); + } + return { + scalarIdx: u & SCHEDULE_INDEX_MASK, + sign: (u >>> 31) as 0 | 1, + }; +}; + +/** + * Constantine signed-Booth recoding of one scalar into `numWindows` + * c-bit signed digits. + * + * Returns digit magnitudes (always non-negative, in [0, 2^(c-1)]) and + * signs (0 = +, 1 = -) per window. The reconstruction identity: + * + * scalar = sum_w (sign[w] ? -magnitude[w] : magnitude[w]) * 2^(c*w) + * + * Mirrors `get_constantine_packed_digit` + * (`scalar_multiplication.cpp` lines 217-271). Carry between windows + * is implicit via the c+1-bit read that overlaps the previous window's + * high bit — no explicit propagation needed. + * + * The bottom window reads bits [0, c) with a synthetic 0 as the + * lookback bit. 
Non-bottom windows read [c*w - 1, c*(w+1)), with the + * shared boundary bit at c*w - 1 acting as the lookback. + */ +const recodeScalarBooth = ( + scalar: bigint, + c: number, + numWindows: number, +): { magnitudes: number[]; signs: (0 | 1)[] } => { + if (c < 1 || c > 28) { + // c+1 must fit in the 32-bit `raw_wide` intermediate, and bucket + // magnitudes must fit in the 29-bit schedule payload. Both are + // satisfied for the chunk sizes we care about (c = 15, 16). Plus a + // generous test margin. + throw new Error(`recodeScalarBooth: c=${c} out of supported range`); + } + const cBI = BigInt(c); + const magMask = (1n << cBI) - 1n; // val_mask in WASM + + const magnitudes: number[] = new Array(numWindows); + const signs: (0 | 1)[] = new Array(numWindows); + + for (let w = 0; w < numWindows; w++) { + let raw: bigint; + if (w === 0) { + // Bottom window: synthetic-zero lookback at bit -1. + // raw == (scalar[0:c]) << 1 + raw = (scalar & magMask) << 1n; + } else { + // raw == c+1 bits at [c*w - 1, c*w + c). + const shift = BigInt(w * c - 1); + raw = (scalar >> shift) & ((1n << (cBI + 1n)) - 1n); + } + const neg = Number((raw >> cBI) & 1n) as 0 | 1; + // encode = (raw + 1) >> 1 — the inner "(encode - neg) ^ neg_mask" + // dance is the WASM's branchless negate-if-sign idiom. Below is + // the equivalent: when neg=0, mag = encode; when neg=1, mag = + // (-encode) & ((1<> 1n; + let magBI: bigint; + if (neg === 0) { + magBI = encode & magMask; + } else { + // Two's-complement negate, masked to c bits. + // (-encode) mod 2^c == (2^c - encode) mod 2^c. + // BigInt's bit-AND with a positive mask returns a non-negative + // result on negative inputs (effectively reduces mod 2^c), so + // `(-encode) & magMask` is exactly what we want. + magBI = (-encode) & magMask; + } + magnitudes[w] = Number(magBI); + signs[w] = neg; + } + + return { magnitudes, signs }; +}; + +/** + * Cross-window pooled batch-affine bucket reducer. + * + * Each window has its own `buckets[w][b]` stack of points. Per round, + * we pop ONE pair from EVERY (window, bucket) whose stack has ≥ 2 + * points, into a single global pool, and run one batched inversion + * across the entire pool. Then we push each sum back to its + * originating (window, bucket). Loop until all stacks ≤ 1. + * + * This is the host analogue of WASM Stage 6b's + * `recursive_affine_bucket_reduce_strided` (lines 1276-1500): the + * inversion amortises across `windows_in_batch * num_buckets` pairs + * rather than just `num_buckets`. With windowsPerBatch=4 and ~2^c-1 + * non-empty buckets per window, this is a ~4× larger inversion batch. + * + * Mutates the input bucket stacks (drains them). + */ +const pooledBatchSumBuckets = ( + bucketsPerWindow: Bn254Point[][][], +): Bn254Point[][] => { + const numWindows = bucketsPerWindow.length; + const numBuckets = numWindows > 0 ? bucketsPerWindow[0].length : 0; + + while (true) { + // Collect one pair per (window, bucket) stack with ≥ 2 entries. 
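    // Drain-rate note: the pop-two/push-one step below shrinks a stack of
    // k ≥ 1 points by exactly one per round, so it reaches a single point
    // after k - 1 rounds, and the while-loop performs (deepest initial stack
    // depth) - 1 pairing rounds in total, each backed by one batched
    // inversion spanning every window in the batch.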
+ const ps: Bn254Point[] = []; + const qs: Bn254Point[] = []; + const targetW: number[] = []; + const targetB: number[] = []; + + for (let w = 0; w < numWindows; w++) { + const buckets = bucketsPerWindow[w]; + for (let b = 0; b < numBuckets; b++) { + if (buckets[b].length >= 2) { + const q = buckets[b].pop()!; + const p = buckets[b].pop()!; + ps.push(p); + qs.push(q); + targetW.push(w); + targetB.push(b); + } + } + } + + if (ps.length === 0) break; + + // Single batched inversion spans pairs from every window in the + // batch — this is the structural win over per-window reduction. + const sums = batchAffineAdd(ps, qs); + for (let i = 0; i < sums.length; i++) { + bucketsPerWindow[targetW[i]][targetB[i]].push(sums[i]); + } + } + + // Each (window, bucket) stack now has 0 or 1 point. + const out: Bn254Point[][] = new Array(numWindows); + for (let w = 0; w < numWindows; w++) { + const buckets = bucketsPerWindow[w]; + const row: Bn254Point[] = new Array(numBuckets); + for (let b = 0; b < numBuckets; b++) { + row[b] = buckets[b].length === 1 ? buckets[b][0] : BN254_ZERO; + } + out[w] = row; + } + return out; +}; + +/** + * Build per-window 32-bit schedule for the given window range, bucket- + * sorted (Stages 1+2+3+4 of the WASM compressed into one pass). + * + * Returns: + * - `schedule[w]`: array of 32-bit entries laid out in bucket order + * (all entries for bucket 1 first, then bucket 2, ...). + * - `bucketStart[w]`: exclusive prefix of per-bucket counts; + * `bucketStart[w][b+1] - bucketStart[w][b]` is the # of entries + * pinned to bucket b in window w. `bucketStart[w][0]` and + * `bucketStart[w][1]` are both 0 because bucket 0 is dropped + * (digit zero contributes nothing). + */ +const buildScheduleForBatch = ( + scalars: bigint[], + c: number, + windowStart: number, + windowsInBatch: number, + numBuckets: number, + numWindowsTotal: number, +): { schedule: number[][]; bucketStart: number[][] } => { + const n = scalars.length; + const schedule: number[][] = new Array(windowsInBatch); + const bucketStart: number[][] = new Array(windowsInBatch); + + // Stage 1: histogram per (window, bucket). One pass per window. + const counts: number[][] = new Array(windowsInBatch); + // Per-scalar per-window recoded digit and sign, cached so Stage 4 + // doesn't recompute. Matches the WASM's `fill_packed_digit_buffer` + // call pattern (one shared buffer reused across Stages 1 and 4). + const recoded: { mag: number; sign: 0 | 1 }[][] = new Array( + windowsInBatch, + ); + + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + counts[wInBatch] = new Array(numBuckets).fill(0); + recoded[wInBatch] = new Array(n); + } + + for (let i = 0; i < n; i++) { + const all = recodeScalarBooth(scalars[i], c, numWindowsTotal); + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + const wAbs = windowStart + wInBatch; + const mag = all.magnitudes[wAbs]; + const sign = all.signs[wAbs]; + recoded[wInBatch][i] = { mag, sign }; + if (mag !== 0) { + counts[wInBatch][mag]++; + } + } + } + + // Stage 2+3: per-window exclusive prefix-sum of counts → bucketStart. + // bucketStart[w][b] = sum_{b' < b} counts[w][b'], with [0] = [1] = 0 + // because bucket 0 is dropped (digit==0 contributes nothing). 
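  // Worked example (illustrative): with numBuckets = 4 and counts =
  // [_, 3, 0, 2] (bucket 0 is never counted), the loop below yields
  // bucketStart = [0, 0, 3, 3, 5]: bucket 1's entries occupy slots [0, 3),
  // bucket 2 is empty, and bucket 3 occupies [3, 5).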
+ for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + const bs: number[] = new Array(numBuckets + 1); + bs[0] = 0; + bs[1] = 0; + let running = 0; + for (let b = 1; b < numBuckets; b++) { + bs[b] = running; + running += counts[wInBatch][b]; + } + bs[numBuckets] = running; + bucketStart[wInBatch] = bs; + } + + // Stage 4: scatter into bucket-sorted schedule. Allocate per-window + // schedule arrays sized to the total entry count. + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + schedule[wInBatch] = new Array(bucketStart[wInBatch][numBuckets]); + } + const cursor: number[][] = new Array(windowsInBatch); + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + cursor[wInBatch] = new Array(numBuckets).fill(0); + } + + for (let i = 0; i < n; i++) { + if (i > SCHEDULE_INDEX_MASK) { + throw new Error( + `buildScheduleForBatch: scalar index ${i} exceeds 29-bit schedule payload`, + ); + } + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + const { mag, sign } = recoded[wInBatch][i]; + if (mag === 0) continue; + const slot = bucketStart[wInBatch][mag] + cursor[wInBatch][mag]++; + schedule[wInBatch][slot] = encodeScheduleEntry(i, sign); + } + } + + return { schedule, bucketStart }; +}; + +/** + * Consume a per-(window, bucket) sorted 32-bit schedule and accumulate + * each window's MSM contribution via cross-window pooled batch-affine. + * + * Returns one Bn254Point per window in the batch. + */ +const reduceBatchSchedule = ( + points: Bn254Point[], + schedule: number[][], + bucketStart: number[][], + numBuckets: number, + c: number, +): Bn254Point[] => { + const windowsInBatch = schedule.length; + + // Build per-window per-bucket point stacks by decoding schedule + // entries. Sign bit ⇒ negate the loaded point before pushing. + const bucketsPerWindow: Bn254Point[][][] = new Array(windowsInBatch); + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + const buckets: Bn254Point[][] = new Array(numBuckets); + for (let b = 0; b < numBuckets; b++) buckets[b] = []; + const sched = schedule[wInBatch]; + const bs = bucketStart[wInBatch]; + // Walk each bucket's contiguous run [bs[b], bs[b+1]) of entries. + for (let b = 1; b < numBuckets; b++) { + const lo = bs[b]; + const hi = bs[b + 1]; + for (let k = lo; k < hi; k++) { + const { scalarIdx, sign } = decodeScheduleEntry(sched[k]); + const p = sign === 1 + ? negateBn254Point(points[scalarIdx]) + : points[scalarIdx]; + buckets[b].push(p); + } + } + bucketsPerWindow[wInBatch] = buckets; + } + + // The pooled reducer is the structural win: ONE batched inversion + // per round spans pairs from EVERY window in the batch. + const bucketSumsPerWindow = pooledBatchSumBuckets(bucketsPerWindow); + + // Per-window running-sum reduction + // S_w = sum_{b=1}^{B-1} b * bucketSums[w][b] + // is identical to the per-window path. Bucket reduction is sequential + // by nature (data-dependency through `running`); the cross-window win + // happens in the bucket-accumulation phase above, not here. 
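  // Worked example (illustrative, numBuckets = 4): walking b = 3, 2, 1 in
  // the loop below, `running` takes the values S3, S3+S2, S3+S2+S1 and
  // `result` accumulates to 3*S3 + 2*S2 + 1*S1, i.e. Σ_b b * bucketSums[w][b]
  // without ever multiplying a point by its bucket index.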
+ const out: Bn254Point[] = new Array(windowsInBatch); + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + const sums = bucketSumsPerWindow[wInBatch]; + let running: Bn254Point = BN254_ZERO; + let result: Bn254Point = BN254_ZERO; + for (let b = numBuckets - 1; b >= 1; b--) { + running = addBn254Points(running, sums[b]); + result = addBn254Points(result, running); + } + out[wInBatch] = result; + } + // Bind c into the result via Horner combine at the call site; this + // helper only returns per-window sums. + void c; + return out; +}; + +/** + * Multi-window batched Pippenger MSM with cross-window pooled + * batched-affine bucket reduction. + * + * This is the host ground-truth for the WebGPU multi-window Pippenger + * rewrite (Phase 2 of the plan). It implements the WASM structure with + * the un-GLV, un-dedup path: + * + * - Signed-Booth recoding splits each scalar into ceil(254/c) c-bit + * signed digits. + * - Per region (single c here), iterate windows in batches of + * `windowsPerBatch`. Per batch: + * 1. Stages 1+2+3+4 — histogram → prefix-sum → bucket-sorted + * 32-bit schedule (`buildScheduleForBatch`). + * 2. Stage 6 — pooled batch-affine bucket accumulation across + * all windows in the batch (`reduceBatchSchedule`). ONE + * inversion per round pools pairs from every window. + * 3. Stage 7 — per-window bucket running-sum reduction. + * - Final Horner combine over `c` doublings folds windows into one + * point. + * + * Cross-checks against `batchAffineMSM`: + * - windowsPerBatch=1 with this signed-Booth path equals the + * unsigned-digit `batchAffineMSM`. + * - windowsPerBatch=2 and =4 equal windowsPerBatch=1 (the cross- + * window batched inversion is associativity-only — output is + * identical, only the inversion amortisation factor changes). + * + * Used by Jest tests as ground truth for the GPU shader pipeline; not + * tuned for speed. + */ +export const batchAffineMSMMultiWindow = ( + points: Bn254Point[], + scalars: bigint[], + c: number, + windowsPerBatch: number, +): Bn254Point => { + if (points.length !== scalars.length) { + throw new Error( + "batchAffineMSMMultiWindow: points and scalars length mismatch", + ); + } + if (windowsPerBatch < 1) { + throw new Error( + `batchAffineMSMMultiWindow: windowsPerBatch=${windowsPerBatch} must be ≥ 1`, + ); + } + + // BN254 scalar field is 254 bits wide. + const scalarBits = 254; + // `+ 2` headroom for the signed-Booth carry — matches WASM's + // `(NUM_BITS + 2 + window_bits - 1) / window_bits` + // (`scalar_multiplication.cpp` lines 514, 2484). Without this, scalars + // whose MSB lands on the top window's sign-bit position (e.g. p-1 + // with small c) would generate a phantom carry past the topmost + // window and the reconstruction would drop the high contribution. + const numWindows = Math.ceil((scalarBits + 2) / c); + // Signed-digit recoding halves the bucket count: digits live in + // [-2^(c-1), 2^(c-1)], with magnitude 0 dropped. The bucket array + // is dimensioned 2^(c-1) + 1 so bucket index can equal 2^(c-1). 
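  // Illustrative sizes from the formulas above: c = 8 gives (1 << 7) + 1 =
  // 129 buckets per window (vs 2^8 = 256 on the unsigned-digit path) across
  // ceil(256 / 8) = 32 windows; c = 16 gives 32,769 buckets (vs 65,536)
  // across 16 windows.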
+ const numBuckets = (1 << (c - 1)) + 1; + + const windowSums: Bn254Point[] = new Array(numWindows); + + for ( + let batchStart = 0; + batchStart < numWindows; + batchStart += windowsPerBatch + ) { + const windowsInBatch = Math.min( + windowsPerBatch, + numWindows - batchStart, + ); + + const { schedule, bucketStart } = buildScheduleForBatch( + scalars, + c, + batchStart, + windowsInBatch, + numBuckets, + numWindows, + ); + + const batchWindowSums = reduceBatchSchedule( + points, + schedule, + bucketStart, + numBuckets, + c, + ); + + for (let wInBatch = 0; wInBatch < windowsInBatch; wInBatch++) { + windowSums[batchStart + wInBatch] = batchWindowSums[wInBatch]; + } + } + + // Horner combine across windows: msm = Σ_w 2^(c*w) * windowSums[w]. + let acc: Bn254Point = BN254_ZERO; + for (let w = numWindows - 1; w >= 0; w--) { + for (let bit = 0; bit < c; bit++) { + acc = doubleBn254Point(acc); + } + acc = addBn254Points(acc, windowSums[w]); + } + return acc; +}; + +// Test-only export: exposed so the test file can exercise the +// signed-Booth recoder in isolation against a brute-force reconstruction. +export const _testOnly = { + recodeScalarBooth, + buildScheduleForBatch, +}; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts new file mode 100644 index 000000000000..a8c8e5fabf83 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.test.ts @@ -0,0 +1,818 @@ +// Jest unit tests for the Wasm9x29 Bernstein-Yang TS port (sub-step 1.1 of +// the WebGPU MSM rewrite plan). Asserts: +// 1. invert(a, p) matches modInverse(a, p) on 1000 seeded-LCG random +// BN254 base-field inputs. +// 2. Edge cases: a == 1, a == p-1, a == 2, a == (p-1)/2. +// 3. (a * invert(a)) mod p === 1 for every successful invert. +// 4. The internal step counter never exceeds the static bound implied +// by the WASM source (NUM_OUTER * (BATCH + 2 * RTC_MAX_ITERS) + +// 2 * RTC_MAX_ITERS for the final reduce). This enforces the +// "strictly bounded loops" constraint from the plan. + +import { + invert, + Wasm9x29, + P_BIGINT, + BATCH_NUM, + NUM_OUTER, + RTC_MAX_ITERS, +} from "./bernstein_yang.js"; +import { modInverse } from "./bn254.js"; + +// Seeded LCG (Numerical Recipes constants) mirrored from +// dev/msm-webgpu/bench-field-mul.ts. Reproducible across runs. +function makeRng(seed: number): () => number { + let state = (seed >>> 0) || 1; + return () => { + state = (Math.imul(state, 1664525) + 1013904223) >>> 0; + return state; + }; +} + +function randomBelow(p: bigint, rng: () => number): bigint { + const bitlen = p.toString(2).length; + const byteLen = Math.ceil(bitlen / 8); + // Bounded: max retries cap to keep the test loop hard-bounded. + // For BN254 base field, p ≈ 0.94 * 2^254, so rejection probability is ~6%; + // 64 attempts gives failure probability ~10^-79. + const MAX_TRIES = 64; + for (let attempt = 0; attempt < MAX_TRIES; attempt++) { + let v = 0n; + for (let i = 0; i < byteLen; i++) { + v = (v << 8n) | BigInt(rng() & 0xff); + } + v &= (1n << BigInt(bitlen)) - 1n; + if (v < p) return v; + } + throw new Error("randomBelow: rejection sampling exceeded MAX_TRIES"); +} + +// Hard upper bound on the bounded-loop step counter incremented by +// divsteps inner steps + reduce_to_canonical inner steps. Tighter than +// the absolute max so a regression that loosens the loop bound trips a +// test failure. Components: +// - NUM_OUTER * BATCH = 13 * 58 = 754 divstep inner steps. 
+// - In-loop reduces: ⌊NUM_OUTER / REDUCE_INTERVAL⌋ = 3 boundary hits, +// each calling reduce_to_canonical twice (d, e) of ≤ RTC_MAX_ITERS = 36. +// → 3 * 2 * 36 = 216. +// - Final reduce_to_canonical(d), plus optional reduce after neg(d) on +// f-negative branch. → 2 * 36 = 72. +// Total upper bound: 754 + 216 + 72 = 1042. +const STEP_BOUND = + NUM_OUTER * BATCH_NUM + + 2 * Math.floor(NUM_OUTER / 4) * RTC_MAX_ITERS + + 2 * RTC_MAX_ITERS; + +describe("Bernstein-Yang Wasm9x29 TS port", () => { + const p = P_BIGINT; + + it("invert(1) == 1", () => { + const stepRef = { steps: 0 }; + const r = invert(1n, p, Wasm9x29.P_INV, stepRef); + expect(r).toBe(1n); + expect(stepRef.steps).toBeLessThan(STEP_BOUND); + }); + + it("invert(2) * 2 mod p == 1", () => { + const stepRef = { steps: 0 }; + const inv2 = invert(2n, p, Wasm9x29.P_INV, stepRef); + expect((inv2 * 2n) % p).toBe(1n); + expect(inv2).toBe(modInverse(2n, p)); + expect(stepRef.steps).toBeLessThan(STEP_BOUND); + }); + + it("invert(p - 1) * (p - 1) mod p == 1", () => { + const a = p - 1n; + const stepRef = { steps: 0 }; + const r = invert(a, p, Wasm9x29.P_INV, stepRef); + expect((a * r) % p).toBe(1n); + expect(r).toBe(modInverse(a, p)); + // p - 1 is its own inverse mod p (since (p-1)^2 = p^2 - 2p + 1 ≡ 1). + expect(r).toBe(p - 1n); + expect(stepRef.steps).toBeLessThan(STEP_BOUND); + }); + + it("invert((p - 1) / 2) * ((p - 1) / 2) mod p == 1", () => { + const a = (p - 1n) / 2n; + const stepRef = { steps: 0 }; + const r = invert(a, p, Wasm9x29.P_INV, stepRef); + expect((a * r) % p).toBe(1n); + expect(r).toBe(modInverse(a, p)); + expect(stepRef.steps).toBeLessThan(STEP_BOUND); + }); + + it("invert(0) == 0", () => { + expect(invert(0n, p)).toBe(0n); + }); + + it("matches modInverse on 1000 seeded-LCG random inputs", () => { + const rng = makeRng(0xdeadbeef); + const N = 1000; + let passes = 0; + let mismatches = 0; + let maxSteps = 0; + for (let i = 0; i < N; i++) { + const a = randomBelow(p, rng); + if (a === 0n) { + // BY returns 0 for 0; modInverse throws. + expect(invert(a, p)).toBe(0n); + passes++; + continue; + } + const stepRef = { steps: 0 }; + const byInv = invert(a, p, Wasm9x29.P_INV, stepRef); + const refInv = modInverse(a, p); + const ok = byInv === refInv && (a * byInv) % p === 1n; + if (!ok) { + mismatches++; + if (mismatches <= 5) { + // Log up to 5 details for debugging. + // eslint-disable-next-line no-console + console.error( + `mismatch at i=${i}: a=${a}\n by =${byInv}\n ref=${refInv}`, + ); + } + } else { + passes++; + } + if (stepRef.steps > maxSteps) maxSteps = stepRef.steps; + } + expect(mismatches).toBe(0); + expect(passes).toBe(N); + expect(maxSteps).toBeLessThan(STEP_BOUND); + // eslint-disable-next-line no-console + console.log( + `[BY] ${passes}/${N} random inputs matched modInverse. max bounded-loop steps observed: ${maxSteps} (bound ${STEP_BOUND})`, + ); + }); + + it("invert(a) * a mod p == 1 (1000 random, separate seed)", () => { + const rng = makeRng(0xa5a5a5a5); + const N = 1000; + let bad = 0; + for (let i = 0; i < N; i++) { + const a = randomBelow(p, rng); + if (a === 0n) continue; + const r = invert(a, p); + if ((a * r) % p !== 1n) bad++; + } + expect(bad).toBe(0); + }); + + // --- Inner-piece sanity tests: lock down individual primitives so the + // --- WGSL port can diff against the same intermediate values. 
+ + it("fromBigint/toBigint round-trip on random values", () => { + const rng = makeRng(0x1234); + for (let i = 0; i < 100; i++) { + const x = randomBelow(p, rng); + const back = Wasm9x29.toBigint(Wasm9x29.fromBigint(x)); + expect(back).toBe(x); + } + }); + + it("divsteps: gcd(p, a) ends with g == 0 after NUM_OUTER iters", () => { + // Spot-check: drive the full algorithm and assert the early-break + // condition fires within NUM_OUTER for typical inputs. + const rng = makeRng(7); + let everEarlyExit = false; + for (let i = 0; i < 50; i++) { + const a = randomBelow(p, rng); + if (a === 0n) continue; + // Re-run the driver loop and count outer iters until g==0. + const P = Wasm9x29.fromBigint(p); + const f = Wasm9x29.fromBigint(p); + const g = Wasm9x29.fromBigint(a); + const d = Wasm9x29.makeZero(); + const e = Wasm9x29.makeOne(); + let delta = 1n; + let outer = 0; + for (; outer < NUM_OUTER; outer++) { + const { mat, delta: nd } = Wasm9x29.divsteps( + delta, + Wasm9x29.low64(f), + Wasm9x29.low64(g), + ); + delta = nd; + Wasm9x29.applyMatrix(mat, f, g, d, e, P, Wasm9x29.P_INV); + if (Wasm9x29.isZero(g)) break; + } + expect(outer).toBeLessThan(NUM_OUTER); + if (outer < NUM_OUTER - 1) everEarlyExit = true; + } + expect(everEarlyExit).toBe(true); + }); +}); + +// ============================================================ +// WGSL helper TS mirror tests (sub-step 1.2) +// ============================================================ +// +// The TS helpers below mirror, byte-for-byte, the algorithms in +// wgsl/bigint/bigint_by.template.wgsl. WGSL operates on 32-bit u32/i32 +// types that wrap on overflow; TS `number` plus `bigint` masking lets us +// emulate the same bit semantics. These tests validate the algorithms +// independent of any GPU dispatch — sub-step 1.3 will run the WGSL +// helpers end-to-end in a divsteps shader. +// +// Each mirror function takes the same input shape as its WGSL counterpart +// (typically a vec2 represented as a 2-element JS tuple) and returns the +// same shape. Ground-truth checks use BigInt. + +const U32_MASK = 0xFFFFFFFFn; +const I32_MIN = -(2 ** 31); +const I32_MAX = 2 ** 31 - 1; +const POW_29 = 1 << 29; +const POW_29N = 1n << 29n; +const POW_32N = 1n << 32n; +const POW_64N = 1n << 64n; +const I64_SIGN = 1n << 63n; +const BY_LIMB_MASK_N = (1n << 29n) - 1n; + +// ---- 64-bit BigInt helpers used for ground-truth comparison ---- +function pairToU64(p: [number, number]): bigint { + const lo = BigInt(p[0] >>> 0); + const hi = BigInt(p[1] >>> 0); + return lo | (hi << 32n); +} +function u64ToPair(x: bigint): [number, number] { + const v = x & ((1n << 64n) - 1n); + return [Number(v & U32_MASK) >>> 0, Number((v >> 32n) & U32_MASK) >>> 0]; +} +function i64PairToBigint(p: [number, number]): bigint { + // Low half is treated as unsigned 32 bits; high half is signed 32. + const lo = BigInt(p[0] >>> 0); + const hi = BigInt(p[1] | 0); // sign-extend via | 0 + return lo + (hi << 32n); +} +function bigintToI64Pair(x: bigint): [number, number] { + const v = x & ((1n << 64n) - 1n); + const lo = Number(v & U32_MASK) >>> 0; + // hi is the upper 32 bits, interpreted as i32 (so values >= 2^31 wrap to negative). + const hiU = Number((v >> 32n) & U32_MASK); + const hi = hiU >= 2 ** 31 ? hiU - 2 ** 32 : hiU; + return [lo, hi]; +} + +// ---- TS mirrors of the WGSL helpers ---- +function u64_add_ts(a: [number, number], b: [number, number]): [number, number] { + const lo = (a[0] + b[0]) >>> 0; + const carry = lo < (a[0] >>> 0) ? 
1 : 0; + // (a[1] + b[1] + carry) wraps to u32 like WGSL u32 add. + const hi = (a[1] + b[1] + carry) >>> 0; + return [lo, hi]; +} +function u64_sub_ts(a: [number, number], b: [number, number]): [number, number] { + const lo = (a[0] - b[0]) >>> 0; + const borrow = (a[0] >>> 0) < (b[0] >>> 0) ? 1 : 0; + const hi = (a[1] - b[1] - borrow) >>> 0; + return [lo, hi]; +} +function u64_shr1_ts(x: [number, number]): [number, number] { + const lo = (((x[0] >>> 0) >>> 1) | (((x[1] >>> 0) & 1) << 31)) >>> 0; + const hi = (x[1] >>> 0) >>> 1; + return [lo, hi]; +} +function u64_low_bit_ts(x: [number, number]): number { + return (x[0] >>> 0) & 1; +} +function u64_neg_ts(x: [number, number]): [number, number] { + const nx: [number, number] = [(~x[0]) >>> 0, (~x[1]) >>> 0]; + return u64_add_ts(nx, [1, 0]); +} +function u64_and_ts( + a: [number, number], + mask: [number, number], +): [number, number] { + return [((a[0] & mask[0]) >>> 0), ((a[1] & mask[1]) >>> 0)]; +} + +function i64_add_pair_ts( + a: [number, number], + b: [number, number], +): [number, number] { + const au = a[0] >>> 0; + const bu = b[0] >>> 0; + const lo = (au + bu) >>> 0; + const carry = lo < au ? 1 : 0; + // (a[1] + b[1] + carry) treated as i32 add with wrap. + const hi = (a[1] + b[1] + carry) | 0; + return [lo, hi]; +} +function i64_sub_pair_ts( + a: [number, number], + b: [number, number], +): [number, number] { + const au = a[0] >>> 0; + const bu = b[0] >>> 0; + const lo = (au - bu) >>> 0; + const borrow = au < bu ? 1 : 0; + const hi = (a[1] - b[1] - borrow) | 0; + return [lo, hi]; +} +function i64_shl1_pair_ts(a: [number, number]): [number, number] { + const au = a[0] >>> 0; + const lo = (au << 1) >>> 0; + const bit31 = (au >>> 31) & 1; + // (a[1] << 1) | bit31, treated as i32. + const hi_u = (((a[1] >>> 0) << 1) | bit31) >>> 0; + const hi = hi_u >= 2 ** 31 ? hi_u - 2 ** 32 : hi_u; + return [lo, hi]; +} +function i64_neg_pair_ts(a: [number, number]): [number, number] { + const nlo = (~a[0]) >>> 0; + const nhi = (~a[1]) >>> 0; + const lo = (nlo + 1) >>> 0; + const carry = lo < nlo ? 1 : 0; + const hi_u = (nhi + carry) >>> 0; + const hi = hi_u >= 2 ** 31 ? hi_u - 2 ** 32 : hi_u; + return [lo, hi]; +} + +// Track the worst-case partial-product magnitude across all calls. Tests +// assert this stays under 2^31 to validate the width-choice claim. +let signedMulSplitMaxPartial = 0n; + +function signed_mul_split_ts(a: number, b: number): [number, number] { + // Mirror the WGSL implementation: signed (15-bit lo, 14/16-bit hi) split. + // a_lo = (a << 17) >> 17 sign-extends low 15 bits. JS `|0` coerces to + // i32 with wrap; left/right shifts on Number use i32 semantics. + const a_lo = (a << 17) >> 17; + const a_hi = ((a - a_lo) >> 15) | 0; + const b_lo = (b << 17) >> 17; + const b_hi = ((b - b_lo) >> 15) | 0; + + const pll = Math.imul(a_lo, b_lo); + const plh = Math.imul(a_lo, b_hi); + const phl = Math.imul(a_hi, b_lo); + const phh = Math.imul(a_hi, b_hi); + const mid = (plh + phl) | 0; + + // Audit: every partial must fit in i32. + for (const p of [pll, plh, phl, phh, mid]) { + const bp = BigInt(Math.abs(p)); + if (bp > signedMulSplitMaxPartial) signedMulSplitMaxPartial = bp; + } + + // Compose low 64 bits of i64 = pll + (mid << 15) + (phh << 30) via u32 + // wrap arithmetic with carries. JS Number can hold all intermediate + // values exactly (well under 2^53). 
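+  // Worked example (illustrative): a = -3, b = 5 gives a * b = -15, which
+  // the (lo29, hi) return decomposes as lo29 = 2^29 - 15 = 536870897 and
+  // hi = -1, since lo29 + hi * 2^29 = -15. The tests below assert exactly
+  // this reconstruction identity.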
+ const pll_lo = pll >>> 0; + const pll_hi = (pll >> 31) >>> 0; // sign mask + const mid_lo = ((mid << 15) >>> 0); + // (mid >> 17) gives signed arith shift; >>> 0 converts to u32 bit pattern. + const mid_hi = (mid >> 17) >>> 0; + const phh_lo = ((phh << 30) >>> 0); + const phh_hi = (phh >> 2) >>> 0; + + // Sum (pll_lo, pll_hi) + (mid_lo, mid_hi). + const s1_lo_raw = pll_lo + mid_lo; + const s1_lo = s1_lo_raw >>> 0; + const c1 = s1_lo_raw >= 2 ** 32 ? 1 : 0; + const s1_hi_raw = pll_hi + mid_hi + c1; + const s1_hi = s1_hi_raw >>> 0; + + // Add (phh_lo, phh_hi). + const sum_lo_raw = s1_lo + phh_lo; + const sum_lo = sum_lo_raw >>> 0; + const c2 = sum_lo_raw >= 2 ** 32 ? 1 : 0; + const sum_hi_raw = s1_hi + phh_hi + c2; + const sum_hi = sum_hi_raw >>> 0; + + const lo29 = sum_lo & 0x1FFFFFFF; + const hi_u = (((sum_lo >>> 29) | (sum_hi << 3)) >>> 0); + const hi = hi_u >= 2 ** 31 ? hi_u - 2 ** 32 : hi_u; + return [lo29, hi]; +} + +function by_accumulate_ts( + acc: [number, number], + m_lo: number, + x_limb: number, +): [number, number] { + const prod = signed_mul_split_ts(m_lo, x_limb); + const lo_sum = (acc[0] + prod[0]) | 0; + const lo29 = lo_sum & ((1 << 29) - 1); + // Arithmetic shift right by 29 in i32: JS `>>` is sign-preserving on i32. + const lo_overflow = lo_sum >> 29; + const hi = ((acc[1] + prod[1] + lo_overflow) | 0); + return [lo29, hi]; +} + +const BY_NUM_LIMBS_N = 9; +function by_normalise_ts(x: number[]): number[] { + const out = x.slice(); + let c = 0n; + for (let i = 0; i < BY_NUM_LIMBS_N - 1; i++) { + // Use BigInt math to mirror WGSL i32 semantics exactly: limb values can + // be ±2^60 from caller, and we want the low 29 bits + arithmetic shift. + const v = BigInt(out[i]) + c; + out[i] = Number(v & BY_LIMB_MASK_N); + c = v >> 29n; + } + out[BY_NUM_LIMBS_N - 1] = Number(BigInt(out[BY_NUM_LIMBS_N - 1]) + c); + return out; +} + +// ---- 20×13-bit BigInt <-> 9×29-bit BigIntBY (TS mirror) ---- +const NUM_WORDS = 20; +const WORD_SIZE = 13; +const WORD_MASK = (1 << 13) - 1; + +function read29Window(limbs: number[], bit_lo: number): number { + // 29-bit window from 13-bit limbs: requires up to ceil((29 + 12) / 13) = 4 + // source limbs (worst case in_limb_bit = 12). + const base_idx = Math.floor(bit_lo / WORD_SIZE); + const in_limb_bit = bit_lo % WORD_SIZE; + const v0 = base_idx < NUM_WORDS ? limbs[base_idx] : 0; + const v1 = base_idx + 1 < NUM_WORDS ? limbs[base_idx + 1] : 0; + const v2 = base_idx + 2 < NUM_WORDS ? limbs[base_idx + 2] : 0; + const v3 = base_idx + 3 < NUM_WORDS ? limbs[base_idx + 3] : 0; + // Use BigInt to avoid 32-bit signed overflow when shifting up to 39 bits. + const s0 = BigInt(v0) >> BigInt(in_limb_bit); + const s1 = BigInt(v1) << BigInt(WORD_SIZE - in_limb_bit); + const s2 = BigInt(v2) << BigInt(WORD_SIZE * 2 - in_limb_bit); + const s3 = BigInt(v3) << BigInt(WORD_SIZE * 3 - in_limb_bit); + return Number((s0 | s1 | s2 | s3) & BY_LIMB_MASK_N); +} +function by_from_bigint_ts(limbs: number[]): number[] { + const r = new Array(9).fill(0); + for (let i = 0; i < 9; i++) { + r[i] = read29Window(limbs, i * 29); + } + return r; +} +function read13Window(by_l: number[], bit_lo: number): number { + const base_idx = Math.floor(bit_lo / 29); + const in_limb_bit = bit_lo % 29; + const v0 = base_idx < 9 ? by_l[base_idx] >>> 0 : 0; + const v1 = base_idx + 1 < 9 ? 
by_l[base_idx + 1] >>> 0 : 0; + const s0 = BigInt(v0) >> BigInt(in_limb_bit); + const s1 = BigInt(v1) << BigInt(29 - in_limb_bit); + return Number((s0 | s1) & BigInt(WORD_MASK)); +} +function by_to_bigint_ts(by_l: number[]): number[] { + const out = new Array(NUM_WORDS).fill(0); + for (let i = 0; i < NUM_WORDS; i++) { + out[i] = read13Window(by_l, i * WORD_SIZE); + } + return out; +} + +// Convert a 256-bit bigint into a 20×13-bit limb array (canonical, all +// limbs in [0, 2^13)). +function bigintToLimbs(x: bigint): number[] { + const out: number[] = new Array(NUM_WORDS).fill(0); + let v = x; + for (let i = 0; i < NUM_WORDS; i++) { + out[i] = Number(v & BigInt(WORD_MASK)); + v = v >> 13n; + } + return out; +} + +describe("WGSL helper TS mirror", () => { + it("u64_add matches BigInt ground truth (500 random)", () => { + const rng = makeRng(0x11111111); + let failures = 0; + for (let i = 0; i < 500; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const b: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_add_ts(a, b); + const expectedBN = (pairToU64(a) + pairToU64(b)) & ((1n << 64n) - 1n); + if (pairToU64(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("u64_sub matches BigInt ground truth (500 random)", () => { + const rng = makeRng(0x22222222); + let failures = 0; + for (let i = 0; i < 500; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const b: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_sub_ts(a, b); + const expectedBN = + (pairToU64(a) - pairToU64(b) + (1n << 64n)) & ((1n << 64n) - 1n); + if (pairToU64(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("u64_shr1 matches BigInt ground truth (100 random)", () => { + const rng = makeRng(0x33333333); + let failures = 0; + for (let i = 0; i < 100; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_shr1_ts(a); + const expectedBN = pairToU64(a) >> 1n; + if (pairToU64(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("u64_low_bit matches BigInt ground truth (100 random)", () => { + const rng = makeRng(0x44444444); + let failures = 0; + for (let i = 0; i < 100; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_low_bit_ts(a); + const expected = Number(pairToU64(a) & 1n); + if (got !== expected) failures++; + } + expect(failures).toBe(0); + }); + + it("u64_neg matches BigInt ground truth (100 random)", () => { + const rng = makeRng(0x55555555); + let failures = 0; + for (let i = 0; i < 100; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_neg_ts(a); + const expectedBN = + (((1n << 64n) - pairToU64(a)) & ((1n << 64n) - 1n)) % (1n << 64n); + if (pairToU64(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("u64_and matches BigInt ground truth (100 random)", () => { + const rng = makeRng(0x66666666); + let failures = 0; + for (let i = 0; i < 100; i++) { + const a: [number, number] = [rng() >>> 0, rng() >>> 0]; + const m: [number, number] = [rng() >>> 0, rng() >>> 0]; + const got = u64_and_ts(a, m); + const expectedBN = pairToU64(a) & pairToU64(m); + if (pairToU64(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + // ---- i64 signed pair helpers ---- + function makeI64Pair(rng: () => number, magBits: number): [number, number] { + // Generate a uniform signed integer in [-2^magBits, 2^magBits). 
+ let v = 0n; + const bytes = Math.ceil(magBits / 8) + 1; + for (let i = 0; i < bytes; i++) { + v = (v << 8n) | BigInt(rng() & 0xff); + } + const range = 1n << BigInt(magBits); + // Map into [-range, range). + v = (v % (2n * range)) - range; + return bigintToI64Pair(v); + } + + it("i64_add_pair matches BigInt ground truth (200 random)", () => { + const rng = makeRng(0x77777777); + let failures = 0; + for (let i = 0; i < 200; i++) { + const a = makeI64Pair(rng, 60); + const b = makeI64Pair(rng, 60); + const got = i64_add_pair_ts(a, b); + let expectedBN = (i64PairToBigint(a) + i64PairToBigint(b)) & ((1n << 64n) - 1n); + if (expectedBN >= I64_SIGN) expectedBN -= 1n << 64n; + if (i64PairToBigint(got) !== expectedBN) { + failures++; + if (failures <= 3) { + console.error( + `i64_add fail: a=${i64PairToBigint(a)} b=${i64PairToBigint(b)} got=${i64PairToBigint(got)} expected=${expectedBN}`, + ); + } + } + } + expect(failures).toBe(0); + }); + + it("i64_sub_pair matches BigInt ground truth (200 random)", () => { + const rng = makeRng(0x88888888); + let failures = 0; + for (let i = 0; i < 200; i++) { + const a = makeI64Pair(rng, 60); + const b = makeI64Pair(rng, 60); + const got = i64_sub_pair_ts(a, b); + let expectedBN = (i64PairToBigint(a) - i64PairToBigint(b)) & ((1n << 64n) - 1n); + if (expectedBN >= I64_SIGN) expectedBN -= 1n << 64n; + if (i64PairToBigint(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("i64_shl1_pair matches BigInt ground truth (200 random)", () => { + const rng = makeRng(0x99999999); + let failures = 0; + for (let i = 0; i < 200; i++) { + const a = makeI64Pair(rng, 60); + const got = i64_shl1_pair_ts(a); + let expectedBN = (i64PairToBigint(a) << 1n) & ((1n << 64n) - 1n); + if (expectedBN >= I64_SIGN) expectedBN -= 1n << 64n; + if (i64PairToBigint(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + it("i64_neg_pair matches BigInt ground truth (200 random)", () => { + const rng = makeRng(0xaaaaaaaa); + let failures = 0; + for (let i = 0; i < 200; i++) { + const a = makeI64Pair(rng, 60); + const got = i64_neg_pair_ts(a); + let expectedBN = (-i64PairToBigint(a)) & ((1n << 64n) - 1n); + if (expectedBN >= I64_SIGN) expectedBN -= 1n << 64n; + if (i64PairToBigint(got) !== expectedBN) failures++; + } + expect(failures).toBe(0); + }); + + // ---- signed_mul_split ---- + it("signed_mul_split: 500 random (a,b) with bounds + partial-product audit", () => { + const rng = makeRng(0xbbbbbbbb); + let failures = 0; + signedMulSplitMaxPartial = 0n; + for (let i = 0; i < 500; i++) { + // a in [-2^29, 2^29] (one BY limb plus possible sign-bit overflow). + // b in [-2^31, 2^31) (i32 range, matrix entry low half). + const aMag = rng() & ((1 << 29) - 1); // 29-bit unsigned in [0, 2^29) + const aSign = (rng() & 1) ? -1 : 1; + const a = aSign * aMag; + const bSign = (rng() & 1) ? -1 : 1; + // Generate b in [-2^31, 2^31). i32 representation: any 32-bit value. + const bMagU32 = rng() >>> 0; + const b = bMagU32 >= 2 ** 31 ? bMagU32 - 2 ** 32 : bMagU32; + // Clamp |a| to ≤ 2^29 — generate could be exactly 2^29-1, which is fine. 
+ if (Math.abs(a) > 2 ** 29) continue; + + const got = signed_mul_split_ts(a, b); + const lo29 = BigInt(got[0]); + const hi = BigInt(got[1]); + const reconstructed = lo29 + hi * POW_29N; + const expected = BigInt(a) * BigInt(b); + if (reconstructed !== expected) { + failures++; + if (failures <= 3) { + console.error( + `signed_mul_split fail: a=${a} b=${b} got=(lo29=${got[0]}, hi=${got[1]}) reconstructed=${reconstructed} expected=${expected}`, + ); + } + } + // lo29 must be in [0, 2^29). + if (got[0] < 0 || got[0] >= POW_29) failures++; + } + expect(failures).toBe(0); + // Width audit (per plan): every signed partial product fits in i32. + // With the 15-bit signed split, the worst case is |a_lo| * |b_hi| or + // |a_hi| * |b_hi| <= 2^14 * 2^16 = 2^30, well within i32. + expect(signedMulSplitMaxPartial).toBeLessThan(1n << 31n); + }); + + it("signed_mul_split: edge cases (zero, max-positive, min-negative)", () => { + // For |a|, |b| <= 2^29-1 (strict), |hi| <= 2^29 fits in i32 comfortably. + // Note: a = -2^29 with b = -2^31 produces a*b = 2^60, hi = 2^31 which + // does NOT fit in signed i32. The actual use case in apply_matrix has + // |b| <= 2^29 (matrix-entry low half is in [0, 2^29) unsigned, matrix- + // entry high half is signed in [-2^29, 2^29)), so the edge cases here + // are bounded accordingly. + const edges: Array<[number, number]> = [ + [0, 0], + [1, 1], + [POW_29 - 1, 0], + [0, POW_29 - 1], + [POW_29 - 1, POW_29 - 1], + [-POW_29, POW_29 - 1], + [POW_29 - 1, -POW_29], + [-POW_29, -POW_29], + [-1, -1], + [1, -1], + [-1, 1], + ]; + for (const [a, b] of edges) { + const got = signed_mul_split_ts(a, b); + const reconstructed = BigInt(got[0]) + BigInt(got[1]) * POW_29N; + const expected = BigInt(a) * BigInt(b); + expect(reconstructed).toBe(expected); + expect(got[0]).toBeGreaterThanOrEqual(0); + expect(got[0]).toBeLessThan(POW_29); + } + }); + + // ---- by_accumulate ---- + it("by_accumulate: 200 random matches BigInt", () => { + const rng = makeRng(0xcccccccc); + let failures = 0; + for (let i = 0; i < 200; i++) { + // acc.x in [0, 2^29); acc.y bounded so |acc| < 2^58 (matches the + // 2-limb accumulator used by apply_matrix). Generate |acc_hi| <= 2^29. + const acc_lo = rng() & ((1 << 29) - 1); + const acc_hi_mag = rng() & ((1 << 29) - 1); + const acc_hi = (rng() & 1) ? -acc_hi_mag : acc_hi_mag; + const m_lo_mag = rng() & ((1 << 29) - 1); + const m_lo = (rng() & 1) ? -m_lo_mag : m_lo_mag; + const x_limb_mag = rng() & ((1 << 29) - 1); + const x_limb = (rng() & 1) ? -x_limb_mag : x_limb_mag; + + const got = by_accumulate_ts([acc_lo, acc_hi], m_lo, x_limb); + const accBN = BigInt(acc_lo) + BigInt(acc_hi) * POW_29N; + const expected = accBN + BigInt(m_lo) * BigInt(x_limb); + const reconstructed = BigInt(got[0]) + BigInt(got[1]) * POW_29N; + if (reconstructed !== expected) { + failures++; + if (failures <= 3) { + console.error( + `by_accumulate fail: acc=(${acc_lo}, ${acc_hi}) m_lo=${m_lo} x_limb=${x_limb} got=(${got[0]}, ${got[1]}) reconstructed=${reconstructed} expected=${expected}`, + ); + } + } + if (got[0] < 0 || got[0] >= POW_29) failures++; + } + expect(failures).toBe(0); + }); + + // ---- by_normalise ---- + it("by_normalise: 100 cases with non-canonical limbs", () => { + const rng = makeRng(0xdddddddd); + let failures = 0; + for (let i = 0; i < 100; i++) { + const limbs: number[] = []; + // Build truly non-canonical limbs in [-2^60, 2^60) using BigInt then + // reduce mod 2^32 for the JS array (since we re-read via BigInt below). 
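+      // Concrete picture (illustrative): by_normalise_ts([(1 << 29) + 5,
+      // 0, 0, 0, 0, 0, 0, 0, 0]) yields [5, 1, 0, ..., 0]: the overflow of
+      // limb 0 is carried into limb 1, leaving limbs 0..7 in [0, 2^29).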
+ const limbBN: bigint[] = []; + let originalValueBN = 0n; + for (let j = 0; j < 9; j++) { + // Random signed up to ~2^31 (since JS Number array would lose + // precision otherwise). We're still exercising overflow scenarios. + const mag = rng() & 0xfffffff; // 28-bit + const sign = (rng() & 1) ? -1n : 1n; + const v = sign * BigInt(mag); + limbBN.push(v); + limbs.push(Number(v)); + originalValueBN += v << BigInt(j * 29); + } + const got = by_normalise_ts(limbs); + // Compute reconstructed (treating limbs as signed where limb 8 is the + // sign carrier). After normalise, l[0..7] are in [0, 2^29) and l[8] + // is signed. + let reconstructed = 0n; + for (let j = 0; j < 8; j++) { + reconstructed += BigInt(got[j]) << BigInt(j * 29); + } + // l[8] is signed (no normalisation past). + reconstructed += BigInt(got[8]) << BigInt(8 * 29); + if (reconstructed !== originalValueBN) { + failures++; + if (failures <= 3) { + console.error( + `by_normalise fail: original=${originalValueBN} reconstructed=${reconstructed}`, + ); + } + } + for (let j = 0; j < 8; j++) { + if (got[j] < 0 || got[j] >= POW_29) { + failures++; + if (failures <= 3) { + console.error(`limb ${j} not canonical: ${got[j]}`); + } + } + } + } + expect(failures).toBe(0); + }); + + // ---- by_from_bigint / by_to_bigint ---- + it("by_from_bigint / by_to_bigint round-trip on 200 BN254 base-field values", () => { + const rng = makeRng(0xeeeeeeee); + let failures = 0; + for (let i = 0; i < 200; i++) { + const x = randomBelow(P_BIGINT, rng); + const srcLimbs = bigintToLimbs(x); + const by = by_from_bigint_ts(srcLimbs); + // Reconstruct from BY limbs to verify the conversion. + let v = 0n; + for (let j = 0; j < 9; j++) { + v += BigInt(by[j]) << BigInt(j * 29); + } + if (v !== x) { + failures++; + if (failures <= 3) { + console.error(`by_from_bigint fail: x=${x} reconstructed=${v}`); + } + } + // Now round-trip back. + const back = by_to_bigint_ts(by); + let backV = 0n; + for (let j = 0; j < NUM_WORDS; j++) { + backV += BigInt(back[j]) << BigInt(j * WORD_SIZE); + } + if (backV !== x) { + failures++; + if (failures <= 3) { + console.error(`by_to_bigint fail: x=${x} back=${backV}`); + } + } + // Every output limb is canonical. + for (let j = 0; j < NUM_WORDS; j++) { + if (back[j] < 0 || back[j] > WORD_MASK) failures++; + } + } + expect(failures).toBe(0); + }); +}); diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts new file mode 100644 index 000000000000..693c02a78dcd --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang.ts @@ -0,0 +1,470 @@ +// TypeScript port of the Wasm9x29 Bernstein-Yang safegcd modular inverse +// from barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp +// (with the driver from bernstein_yang_inverse.hpp). +// +// This module is the ground-truth reference for the WGSL port that follows +// in the WebGPU MSM rewrite plan (sub-steps 1.2-1.6). It must be a faithful +// transliteration of the C++ — same constants, same control flow, same +// per-limb arithmetic. Performance is not a concern; correctness vs +// `modInverse` over 1000+ random BN254 base-field values is. +// +// Algorithmic notes +// ----------------- +// State: 9 × 29-bit signed limbs held as `bigint[9]`. Top limb carries +// sign, lower limbs in [0, 2^29) post-normalise. 
Choice of `bigint` over +// `number`: matrix entries grow up to 2^58 after BATCH=58 divsteps and the +// per-limb cross-products in apply_matrix reach |i58|, which overflow JS's +// 32-bit safe integer range for arithmetic. The intermediate accumulators +// (cf, cg, cd, ce, nf, ng, nd, ne, t_d, t_e) must be `bigint`. +// +// The divsteps low-64-bit carriers (f_lo, g_lo) are also `bigint` since they +// hold u64 values. Each divstep mutates these via shift / add / sub at i64 +// precision; we mask back to u64 after each mutation to avoid sign +// extension surprises. BigInt's `>>` is arithmetic shift for negative +// values, matching C++ `(i64) >> n`. +// +// Loop bound discipline (per plan §1, hard constraint): the only loops in +// this file have static upper bounds. +// - divsteps inner loop: BATCH = 58 +// - apply_matrix per-limb loops: N = 9 +// - reduce_to_canonical: RTC_MAX_ITERS = 36 +// - driver outer loop: NUM_OUTER = 13 (with early `g == 0` break) +// The test file asserts the step counter never exceeds these bounds. + +import { BN254_BASE_FIELD } from "./bn254.js"; + +// ---- Constants (mirror Wasm9x29::*) ----------------------------------- +export const N = 9; +export const LIMB_BITS = 29n; +export const LIMB_MASK = (1n << LIMB_BITS) - 1n; // 0x1FFFFFFF +export const BATCH = 58n; +export const BATCH_NUM = 58; +export const MASK_BATCH = (1n << BATCH) - 1n; +export const NUM_OUTER = 13; +export const REDUCE_INTERVAL = 4; +export const RTC_MAX_ITERS = 36; + +// u64 mask for f_lo / g_lo carriers in divsteps. +const U64_MASK = (1n << 64n) - 1n; +const U64_SIGN_BIT = 1n << 63n; + +// ---- BN254 base-field p (matches the C++ `p` argument) ---------------- +export const P_BIGINT = BN254_BASE_FIELD; + +// p_inv = p^(-1) mod 2^BATCH (BATCH = 58). Used by apply_matrix's 2-adic +// correction step. Precomputed via Hensel's lemma; verified +// (P * P_INV) % 2^58 === 1. +export const P_INV = 12939590167534711n; + +// ---- 2x2 matrix produced by BATCH divsteps ---------------------------- +export interface Mat { + u: bigint; + v: bigint; + q: bigint; + r: bigint; +} + +// ---- BigIntBY state: 9 × 29-bit signed limbs -------------------------- +export type BigIntBY = bigint[]; // length 9 + +export function makeZero(): BigIntBY { + return [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n]; +} + +export function makeOne(): BigIntBY { + const r = makeZero(); + r[0] = 1n; + return r; +} + +// uint256 -> 9 × 29-bit limbs (matches Wasm9x29(const uint256_t&) ctor). +// Input `x` is the canonical integer in [0, 2^256); we treat it as 4 × u64 +// limbs internally to mirror the C++ bit extraction exactly. +export function fromBigint(x: bigint): BigIntBY { + if (x < 0n) { + throw new Error("fromBigint: negative input"); + } + const mask64 = (1n << 64n) - 1n; + const d0 = x & mask64; + const d1 = (x >> 64n) & mask64; + const d2 = (x >> 128n) & mask64; + const d3 = (x >> 192n) & mask64; + const l: BigIntBY = makeZero(); + l[0] = d0 & LIMB_MASK; + l[1] = (d0 >> 29n) & LIMB_MASK; + l[2] = ((d0 >> 58n) & 0x3Fn) | ((d1 & 0x7FFFFFn) << 6n); + l[3] = (d1 >> 23n) & LIMB_MASK; + l[4] = ((d1 >> 52n) & 0xFFFn) | ((d2 & 0x1FFFFn) << 12n); + l[5] = (d2 >> 17n) & LIMB_MASK; + l[6] = ((d2 >> 46n) & 0x3FFFFn) | ((d3 & 0x7FFn) << 18n); + l[7] = (d3 >> 11n) & LIMB_MASK; + l[8] = (d3 >> 40n) & 0xFFFFFFn; + return l; +} + +// 9 × 29-bit limbs -> uint256 (matches to_uint256()). +// We re-pack into 4 × u64 chunks then concatenate. Assumes the state is +// canonical (all limbs in [0, 2^29), top limb non-negative). 
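+// Width check (illustrative): 9 limbs of 29 bits span 261 bits, so for a
+// canonical uint256 input the top limb needs only 256 - 8*29 = 24 magnitude
+// bits (fromBigint masks it with 0xFFFFFF); it also carries the sign of the
+// whole value.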
+export function toBigint(x: BigIntBY): bigint { + const mask64 = (1n << 64n) - 1n; + const w0 = (x[0] | (x[1] << 29n) | (x[2] << 58n)) & mask64; + const w1 = ((x[2] >> 6n) | (x[3] << 23n) | (x[4] << 52n)) & mask64; + const w2 = ((x[4] >> 12n) | (x[5] << 17n) | (x[6] << 46n)) & mask64; + const w3 = ((x[6] >> 18n) | (x[7] << 11n) | (x[8] << 40n)) & mask64; + return w0 | (w1 << 64n) | (w2 << 128n) | (w3 << 192n); +} + +export function isZero(x: BigIntBY): boolean { + let a = 0n; + for (let i = 0; i < N; i++) { + a |= x[i]; + } + return a === 0n; +} + +export function isNegative(x: BigIntBY): boolean { + return x[N - 1] < 0n; +} + +export function neg(x: BigIntBY): void { + for (let i = 0; i < N; i++) { + x[i] = -x[i]; + } + normalise(x); +} + +// Canonicalise limb carries: propagate so each lower limb is in +// [0, 2^29) using arithmetic shift; the top limb absorbs the final +// carry (and may go negative). +export function normalise(x: BigIntBY): void { + let c = 0n; + for (let i = 0; i < N - 1; i++) { + const v = x[i] + c; + x[i] = v & LIMB_MASK; + c = v >> LIMB_BITS; // arithmetic shift for negative v + } + x[N - 1] += c; +} + +function addInplace(x: BigIntBY, b: BigIntBY): void { + for (let i = 0; i < N; i++) { + x[i] += b[i]; + } + normalise(x); +} + +function subInplace(x: BigIntBY, b: BigIntBY): void { + for (let i = 0; i < N; i++) { + x[i] -= b[i]; + } + normalise(x); +} + +// reduce_to_canonical: bring x into [0, p) using at most 36 add-p / sub-p +// passes. 36 covers |x| ≤ 32p under REDUCE_INTERVAL = 4 (see Wasm9x29 docs). +// Tracks step count and asserts it never exceeds the bound — the WGSL port +// will have the same fixed bound. +export function reduceToCanonical( + x: BigIntBY, + p: BigIntBY, + stepRef?: { steps: number }, +): void { + normalise(x); + for (let it = 0; it < RTC_MAX_ITERS; it++) { + if (stepRef) stepRef.steps++; + if (isNegative(x)) { + addInplace(x, p); + continue; + } + let cmp = 0; + for (let i = N - 1; i >= 0; i--) { + if (x[i] !== p[i]) { + cmp = x[i] > p[i] ? 1 : -1; + break; + } + } + if (cmp < 0) { + break; + } + subInplace(x, p); + } +} + +// ---- divsteps --------------------------------------------------------- +// Run BATCH = 58 branchy divsteps on the low 64 bits of (f, g); returns +// the transition matrix M and updates delta. Variable-time over inner +// branches — non-secret inputs only. +// +// Inputs/outputs: +// delta: scalar bigint (interpret as i64), mutated by the caller via the +// `deltaRef` wrapper. We return the new delta in the tuple to keep +// the signature simple for testing. +// f_lo, g_lo: bigint values representing u64. We mask to u64 after each +// mutation; arithmetic shifts on bigint handle the sign correctly +// because we explicitly convert g_lo to its signed form for the +// subtraction `(g_lo - f_lo) >> 1` (C++ relies on u64-wrap then +// unsigned >>1, which we mimic exactly with `& U64_MASK`). +export function divsteps( + deltaIn: bigint, + fLoIn: bigint, + gLoIn: bigint, + stepRef?: { steps: number }, +): { mat: Mat; delta: bigint } { + let delta = deltaIn; + let f_lo = fLoIn & U64_MASK; + let g_lo = gLoIn & U64_MASK; + let u = 1n; + let v = 0n; + let q = 0n; + let r = 1n; + for (let i = 0; i < BATCH_NUM; i++) { + if (stepRef) stepRef.steps++; + if ((g_lo & 1n) !== 0n) { + if (delta > 0n) { + // (f, g) <- (g, (g - f)/2) + const nf = g_lo; + // (g_lo - f_lo) wraps mod 2^64 (C++ u64 sub), then unsigned >>1. 
+ const diff = (g_lo - f_lo) & U64_MASK; + const ng = diff >> 1n; + const nu = q << 1n; + const nv = r << 1n; + const nq = q - u; + const nr = r - v; + f_lo = nf; + g_lo = ng; + u = nu; + v = nv; + q = nq; + r = nr; + delta = 1n - delta; + } else { + // g <- (g + f)/2 + const sum = (g_lo + f_lo) & U64_MASK; + g_lo = sum >> 1n; + q = q + u; + r = r + v; + u <<= 1n; + v <<= 1n; + delta = delta + 1n; + } + } else { + // g <- g/2 + g_lo >>= 1n; + u <<= 1n; + v <<= 1n; + delta = delta + 1n; + } + } + return { mat: { u, v, q, r }, delta }; +} + +// ---- apply_matrix ----------------------------------------------------- +// Apply M to (f, g, d, e) using the streamed schoolbook described in the +// C++ comment block. Matrix entries can reach |2^58| post-divsteps; we +// split into (lo: 29-bit, hi: i32) parts via arithmetic shift, matching +// the C++ `u_lo = m.u & LIMB_MASK; u_hi = m.u >> LIMB_BITS;`. +// +// The exact `/ 2^BATCH = 2^58 = 2 * LIMB_BITS` at the end is realised by +// writing the per-limb result at position `i - 2` (drop bottom two limbs). +export function applyMatrix( + m: Mat, + f: BigIntBY, + g: BigIntBY, + d: BigIntBY, + e: BigIntBY, + p: BigIntBY, + pInv: bigint, +): void { + // (signed) split into low 29 bits + high (arithmetic shift). + const u_lo = m.u & LIMB_MASK; + const u_hi = m.u >> LIMB_BITS; + const v_lo = m.v & LIMB_MASK; + const v_hi = m.v >> LIMB_BITS; + const q_lo = m.q & LIMB_MASK; + const q_hi = m.q >> LIMB_BITS; + const r_lo = m.r & LIMB_MASK; + const r_hi = m.r >> LIMB_BITS; + + // ---- (f, g) pass ---------------------------------------------------- + { + let cf = 0n; + let cg = 0n; + let fp = 0n; + let gp = 0n; + for (let i = 0; i < N; i++) { + const fi = f[i]; + const gi = g[i]; + const nf = u_lo * fi + v_lo * gi + u_hi * fp + v_hi * gp + cf; + const ng = q_lo * fi + r_lo * gi + q_hi * fp + r_hi * gp + cg; + cf = nf >> LIMB_BITS; + cg = ng >> LIMB_BITS; + if (i >= 2) { + f[i - 2] = nf & LIMB_MASK; + g[i - 2] = ng & LIMB_MASK; + } + fp = fi; + gp = gi; + } + const nf9 = u_hi * fp + v_hi * gp + cf; + const ng9 = q_hi * fp + r_hi * gp + cg; + f[N - 2] = nf9 & LIMB_MASK; + g[N - 2] = ng9 & LIMB_MASK; + f[N - 1] = nf9 >> LIMB_BITS; + g[N - 1] = ng9 >> LIMB_BITS; + } + + // ---- (d, e) pass with 2-adic k·p correction ------------------------- + { + const d0 = d[0]; + const e0 = e[0]; + const d1 = d[1]; + const e1 = e[1]; + const nd0 = u_lo * d0 + v_lo * e0; + const ne0 = q_lo * d0 + r_lo * e0; + const nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0; + const ne1 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0; + // Reconstruct low 58 bits of nd and ne for k computation. + // C++ does `(u64)nd0 & LIMB_MASK | (((u64)(nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK) << LIMB_BITS)`. + // The `(u64)` casts truncate signed -> unsigned mod 2^64. We replicate + // by masking with U64_MASK before assembling. + const nd0_low = nd0 & LIMB_MASK; + const nd1_carry = (nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK; + const t_d = (nd0_low | (nd1_carry << LIMB_BITS)) & U64_MASK; + const ne0_low = ne0 & LIMB_MASK; + const ne1_carry = (ne1 + (ne0 >> LIMB_BITS)) & LIMB_MASK; + const t_e = (ne0_low | (ne1_carry << LIMB_BITS)) & U64_MASK; + // k = (-t mod 2^64) * p_inv mod 2^BATCH. + // (-t) mod 2^64 = (~t + 1) & U64_MASK = (U64_MASK - t + 1) & U64_MASK. 
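+    // Reasoning note: k satisfies t + k*p ≡ 0 (mod 2^58), so adding k*p
+    // zeroes the bottom two 29-bit limbs of nd/ne and dropping them (the
+    // `i - 2` writes below) is an exact division by 2^BATCH, the standard
+    // Montgomery-REDC argument.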
+ const neg_td = (U64_MASK - t_d + 1n) & U64_MASK; + const neg_te = (U64_MASK - t_e + 1n) & U64_MASK; + const k_d = (neg_td * pInv) & MASK_BATCH; + const k_e = (neg_te * pInv) & MASK_BATCH; + const kd_lo = k_d & LIMB_MASK; + const kd_hi = k_d >> LIMB_BITS; + const ke_lo = k_e & LIMB_MASK; + const ke_hi = k_e >> LIMB_BITS; + // Initial carries seed from limb-1 partials (matches the C++ + // `cd = (nd1 + kd_lo*p[1] + kd_hi*p[0] + ((nd0 + kd_lo*p[0]) >> LIMB_BITS)) >> LIMB_BITS`). + let cd = + (nd1 + kd_lo * p[1] + kd_hi * p[0] + ((nd0 + kd_lo * p[0]) >> LIMB_BITS)) >> + LIMB_BITS; + let ce = + (ne1 + ke_lo * p[1] + ke_hi * p[0] + ((ne0 + ke_lo * p[0]) >> LIMB_BITS)) >> + LIMB_BITS; + + let dp = d1; + let ep = e1; + for (let i = 2; i < N; i++) { + const di = d[i]; + const ei = e[i]; + const nd = + u_lo * di + + v_lo * ei + + u_hi * dp + + v_hi * ep + + kd_lo * p[i] + + kd_hi * p[i - 1] + + cd; + const ne = + q_lo * di + + r_lo * ei + + q_hi * dp + + r_hi * ep + + ke_lo * p[i] + + ke_hi * p[i - 1] + + ce; + cd = nd >> LIMB_BITS; + ce = ne >> LIMB_BITS; + d[i - 2] = nd & LIMB_MASK; + e[i - 2] = ne & LIMB_MASK; + dp = di; + ep = ei; + } + const nd9 = u_hi * dp + v_hi * ep + kd_hi * p[N - 1] + cd; + const ne9 = q_hi * dp + r_hi * ep + ke_hi * p[N - 1] + ce; + d[N - 2] = nd9 & LIMB_MASK; + e[N - 2] = ne9 & LIMB_MASK; + d[N - 1] = nd9 >> LIMB_BITS; + e[N - 1] = ne9 >> LIMB_BITS; + } +} + +// ---- low_64: low 64 bits of the state --------------------------------- +// Mirrors Wasm9x29::low_64(): re-pack limbs [0,1,2] (low 6 bits of limb 2) +// into a u64. Returns a positive bigint <2^64. +export function low64(x: BigIntBY): bigint { + return ((x[0] & LIMB_MASK) | ((x[1] & LIMB_MASK) << 29n) | ((x[2] & 0x3Fn) << 58n)) & U64_MASK; +} + +// Convert a positive bigint into a signed-64 view (for use as `f_lo`/`g_lo` +// in the divsteps inner where (g_lo - f_lo) interprets as u64 sub-wrap). +// We don't need to convert — divsteps masks the inputs to U64_MASK. + +// ---- driver: invert_bernsteinyang19 ----------------------------------- +// Returns a^(-1) mod p, or 0 if a == 0. Bounded by NUM_OUTER outer iters +// (≤ 13) with early break on g == 0, RTC_MAX_ITERS per reduce_to_canonical. +// +// stepRef (optional) accumulates the total bounded-loop step count for +// test assertions: divsteps inner + reduce_to_canonical inner steps. +export function invert( + a: bigint, + p: bigint = P_BIGINT, + pInv: bigint = P_INV, + stepRef?: { steps: number }, +): bigint { + if (a === 0n) return 0n; + const P = fromBigint(p); + const f = fromBigint(p); + const g = fromBigint(a); + const d = makeZero(); + const e = makeOne(); + let delta = 1n; + for (let i = 0; i < NUM_OUTER; i++) { + const f_lo = low64(f); + const g_lo = low64(g); + const { mat, delta: newDelta } = divsteps(delta, f_lo, g_lo, stepRef); + delta = newDelta; + applyMatrix(mat, f, g, d, e, P, pInv); + if (isZero(g)) break; + if ((i + 1) % REDUCE_INTERVAL === 0) { + reduceToCanonical(d, P, stepRef); + reduceToCanonical(e, P, stepRef); + } + } + reduceToCanonical(d, P, stepRef); + if (isNegative(f)) { + neg(d); + reduceToCanonical(d, P, stepRef); + } + return toBigint(d); +} + +// ---- Convenience namespace --------------------------------------------- +// Same shape as the C++ `Wasm9x29::*` static interface so the WGSL port +// can be diffed line-by-line against this file. 
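+//
+// Usage sketch (illustrative, not exercised here):
+//   const a = 123456789n;
+//   const inv = invert(a);           // a^(-1) mod the BN254 base field p
+//   // (a * inv) % P_BIGINT === 1n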
+export const Wasm9x29 = { + N, + LIMB_BITS, + LIMB_MASK, + BATCH, + BATCH_NUM, + NUM_OUTER, + REDUCE_INTERVAL, + RTC_MAX_ITERS, + P: P_BIGINT, + P_INV, + makeZero, + makeOne, + fromBigint, + toBigint, + isZero, + isNegative, + neg, + normalise, + reduceToCanonical: (x: BigIntBY, p: BigIntBY) => reduceToCanonical(x, p), + divsteps: (delta: bigint, f_lo: bigint, g_lo: bigint) => + divsteps(delta, f_lo, g_lo), + applyMatrix, + low64, + invert: (a: bigint, p: bigint = P_BIGINT) => invert(a, p, P_INV), +}; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts new file mode 100644 index 000000000000..c838408cc970 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/cuzk/bernstein_yang_a.ts @@ -0,0 +1,247 @@ +// Option A TS reference: BY safegcd inverse on 20 x 13-bit BigInt with +// BATCH=26, NUM_OUTER=29. Validates the WGSL port at fr_inv_by_a against +// modInverse. Uses bigint throughout (no overflow concerns). + +import { BN254_BASE_FIELD, modInverse } from "./bn254.js"; + +const N = 20; +const LIMB_BITS = 13n; +const LIMB_MASK = (1n << LIMB_BITS) - 1n; +const BATCH = 26n; +const BATCH_NUM = 26; +const MASK_BATCH = (1n << BATCH) - 1n; +const NUM_OUTER = 29; +const REDUCE_INTERVAL = 4; +const RTC_MAX_ITERS = 36; + +const U64_MASK = (1n << 64n) - 1n; +const P_BIGINT = BN254_BASE_FIELD; +// p^(-1) mod 2^26. Hensel-lifted. +function computePInv26(p: bigint): bigint { + let inv = 1n; + let cur = 1n; + while (cur < BATCH) { + cur *= 2n; + const mask = (1n << cur) - 1n; + inv = (inv * (2n - p * inv)) & mask; + } + inv &= MASK_BATCH; + return inv; +} +const P_INV = computePInv26(P_BIGINT); + +type BigA = bigint[]; // length 20 + +function makeZero(): BigA { return new Array(N).fill(0n); } +function makeOne(): BigA { const r = makeZero(); r[0] = 1n; return r; } + +function fromBigint(x: bigint): BigA { + const r = makeZero(); + for (let i = 0; i < N; i++) r[i] = (x >> (BigInt(i) * LIMB_BITS)) & LIMB_MASK; + return r; +} + +function toBigint(x: BigA): bigint { + let v = 0n; + for (let i = N - 1; i >= 0; i--) v = (v << LIMB_BITS) | (x[i] & LIMB_MASK); + return v; +} + +function lowU64(x: BigA): bigint { + // Low 64 bits of x. Uses limbs 0..4 (= 65 bits, top bit unused). + let v = 0n; + for (let i = 4; i >= 0; i--) v = (v << LIMB_BITS) | (x[i] & LIMB_MASK); + return v & U64_MASK; +} + +function isZero(x: BigA): boolean { + for (let i = 0; i < N; i++) if (x[i] !== 0n) return false; + return true; +} +function isNegative(x: BigA): boolean { return x[N - 1] < 0n; } + +function normalise(x: BigA): void { + let c = 0n; + for (let i = 0; i < N - 1; i++) { + const v = x[i] + c; + x[i] = v & LIMB_MASK; + c = v >> LIMB_BITS; // bigint arith shift + } + x[N - 1] += c; +} + +function neg(x: BigA): void { + for (let i = 0; i < N; i++) x[i] = -x[i]; + normalise(x); +} + +function addP(x: BigA, p: BigA): void { + for (let i = 0; i < N; i++) x[i] += p[i]; + normalise(x); +} +function subP(x: BigA, p: BigA): void { + for (let i = 0; i < N; i++) x[i] -= p[i]; + normalise(x); +} + +function reduceToCanonical(x: BigA, p: BigA): void { + normalise(x); + for (let it = 0; it < RTC_MAX_ITERS; it++) { + if (isNegative(x)) { addP(x, p); continue; } + let cmp = 0; + for (let i = N - 1; i >= 0; i--) { + if (x[i] !== p[i]) { cmp = x[i] > p[i] ? 
1 : -1; break; } + } + if (cmp < 0) break; + subP(x, p); + } +} + +interface Mat { u: bigint; v: bigint; q: bigint; r: bigint; } + +function divsteps(deltaIn: bigint, fLoIn: bigint, gLoIn: bigint): { mat: Mat; delta: bigint } { + let delta = deltaIn; + let f_lo = fLoIn & U64_MASK; + let g_lo = gLoIn & U64_MASK; + let u = 1n, v = 0n, q = 0n, r = 1n; + for (let i = 0; i < BATCH_NUM; i++) { + if ((g_lo & 1n) !== 0n) { + if (delta > 0n) { + const nf = g_lo; + const diff = (g_lo - f_lo) & U64_MASK; + const ng = diff >> 1n; + const nu = q << 1n; + const nv = r << 1n; + const nq = q - u; + const nr = r - v; + f_lo = nf; g_lo = ng; + u = nu; v = nv; q = nq; r = nr; + delta = 1n - delta; + } else { + const sum = (g_lo + f_lo) & U64_MASK; + g_lo = sum >> 1n; + q = q + u; r = r + v; + u <<= 1n; v <<= 1n; + delta = delta + 1n; + } + } else { + g_lo >>= 1n; + u <<= 1n; v <<= 1n; + delta = delta + 1n; + } + } + return { mat: { u, v, q, r }, delta }; +} + +function applyMatrix(m: Mat, f: BigA, g: BigA, d: BigA, e: BigA, p: BigA, pInv: bigint): void { + // Use TS-style decomposition: u_lo unsigned in [0, 2^13), u_hi signed. + const u_lo = m.u & LIMB_MASK; + const u_hi = m.u >> LIMB_BITS; + const v_lo = m.v & LIMB_MASK; + const v_hi = m.v >> LIMB_BITS; + const q_lo = m.q & LIMB_MASK; + const q_hi = m.q >> LIMB_BITS; + const r_lo = m.r & LIMB_MASK; + const r_hi = m.r >> LIMB_BITS; + + // (f, g) pass + { + let cf = 0n, cg = 0n, fp = 0n, gp = 0n; + for (let i = 0; i < N; i++) { + const fi = f[i]; + const gi = g[i]; + const nf = u_lo * fi + v_lo * gi + u_hi * fp + v_hi * gp + cf; + const ng = q_lo * fi + r_lo * gi + q_hi * fp + r_hi * gp + cg; + cf = nf >> LIMB_BITS; + cg = ng >> LIMB_BITS; + if (i >= 2) { + f[i - 2] = nf & LIMB_MASK; + g[i - 2] = ng & LIMB_MASK; + } + fp = fi; + gp = gi; + } + const nfTop = u_hi * fp + v_hi * gp + cf; + const ngTop = q_hi * fp + r_hi * gp + cg; + f[N - 2] = nfTop & LIMB_MASK; + g[N - 2] = ngTop & LIMB_MASK; + f[N - 1] = nfTop >> LIMB_BITS; + g[N - 1] = ngTop >> LIMB_BITS; + } + + // (d, e) pass with k*p + { + const d0 = d[0], e0 = e[0], d1 = d[1], e1 = e[1]; + const nd0 = u_lo * d0 + v_lo * e0; + const ne0 = q_lo * d0 + r_lo * e0; + const nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0; + const ne1 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0; + const nd0_low = nd0 & LIMB_MASK; + const nd1_carry = (nd1 + (nd0 >> LIMB_BITS)) & LIMB_MASK; + const t_d = (nd0_low | (nd1_carry << LIMB_BITS)) & MASK_BATCH; + const ne0_low = ne0 & LIMB_MASK; + const ne1_carry = (ne1 + (ne0 >> LIMB_BITS)) & LIMB_MASK; + const t_e = (ne0_low | (ne1_carry << LIMB_BITS)) & MASK_BATCH; + const neg_td = (MASK_BATCH - t_d + 1n) & MASK_BATCH; + const neg_te = (MASK_BATCH - t_e + 1n) & MASK_BATCH; + const k_d = (neg_td * pInv) & MASK_BATCH; + const k_e = (neg_te * pInv) & MASK_BATCH; + const kd_lo = k_d & LIMB_MASK; + const kd_hi = k_d >> LIMB_BITS; + const ke_lo = k_e & LIMB_MASK; + const ke_hi = k_e >> LIMB_BITS; + let cd = + (nd1 + kd_lo * p[1] + kd_hi * p[0] + ((nd0 + kd_lo * p[0]) >> LIMB_BITS)) >> + LIMB_BITS; + let ce = + (ne1 + ke_lo * p[1] + ke_hi * p[0] + ((ne0 + ke_lo * p[0]) >> LIMB_BITS)) >> + LIMB_BITS; + let dp = d1, ep = e1; + for (let i = 2; i < N; i++) { + const di = d[i]; + const ei = e[i]; + const nd = u_lo * di + v_lo * ei + u_hi * dp + v_hi * ep + kd_lo * p[i] + kd_hi * p[i - 1] + cd; + const ne = q_lo * di + r_lo * ei + q_hi * dp + r_hi * ep + ke_lo * p[i] + ke_hi * p[i - 1] + ce; + cd = nd >> LIMB_BITS; + ce = ne >> LIMB_BITS; + d[i - 2] = nd & LIMB_MASK; + e[i - 2] = ne & 
LIMB_MASK; + dp = di; + ep = ei; + } + const ndTop = u_hi * dp + v_hi * ep + kd_hi * p[N - 1] + cd; + const neTop = q_hi * dp + r_hi * ep + ke_hi * p[N - 1] + ce; + d[N - 2] = ndTop & LIMB_MASK; + e[N - 2] = neTop & LIMB_MASK; + d[N - 1] = ndTop >> LIMB_BITS; + e[N - 1] = neTop >> LIMB_BITS; + } +} + +export function invertA(a: bigint, p: bigint = P_BIGINT, pInv: bigint = P_INV): bigint { + if (a === 0n) return 0n; + const pa = fromBigint(p); + const f = fromBigint(p); + const g = fromBigint(a); + const d = makeZero(); + const e = makeOne(); + let delta = 1n; + for (let iter = 0; iter < NUM_OUTER; iter++) { + if (isZero(g)) break; + const flo = lowU64(f); + const glo = lowU64(g); + const { mat, delta: nd } = divsteps(delta, flo, glo); + delta = nd; + applyMatrix(mat, f, g, d, e, pa, pInv); + if (((iter + 1) % REDUCE_INTERVAL) === 0) { + reduceToCanonical(d, pa); + reduceToCanonical(e, pa); + } + } + reduceToCanonical(d, pa); + if (isNegative(f)) { + neg(d); + reduceToCanonical(d, pa); + } + return toBigint(d); +} diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts index d7d2c02af0cc..10fb068a6833 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts @@ -1,5 +1,11 @@ // Request a high-performance GPU device. Enables the optional "timestamp-query" // feature when the adapter supports it, so per-pass timings can be collected. +// +// Explicitly requests the adapter's MAX for `maxComputeWorkgroupStorageSize` +// so workgroup-shared scratch can scale with hardware support. +// WebGPU's default limit is 16 KiB; Apple M1/M2 Metal exposes ~32 KiB, M3+ +// ~64 KiB. Requesting the max keeps room for larger per-WG scratch buffers +// without falling back to the default cap. export const get_device = async (): Promise => { const gpuErrMsg = 'Please use a browser that has WebGPU enabled.'; const adapter = await navigator.gpu.requestAdapter({ @@ -15,10 +21,65 @@ export const get_device = async (): Promise => { requiredFeatures.push('timestamp-query'); } - const device = await adapter.requestDevice({ requiredFeatures }); + const requiredLimits: Record = {}; + const adapterLimits = adapter.limits as unknown as Record; + const wgStorageMax = adapterLimits['maxComputeWorkgroupStorageSize']; + if (typeof wgStorageMax === 'number') { + requiredLimits['maxComputeWorkgroupStorageSize'] = wgStorageMax; + } + + const device = await adapter.requestDevice({ requiredFeatures, requiredLimits }); + const grantedLimits = device.limits as unknown as Record; + console.log( + `[gpu] requested maxComputeWorkgroupStorageSize=${wgStorageMax}B,` + + ` granted=${grantedLimits['maxComputeWorkgroupStorageSize']}B`, + ); return device; }; +// Pick the largest SLAB that fits NUM_BUCKETS u32 atomics in the device's +// workgroup-shared memory limit AND leaves headroom for compiler-allocated +// temporaries. Returns `{ slab, slabs, bytes }` where +// slab * 4 + HEADROOM <= workgroup_storage_limit +// slabs = ceil(num_columns / slab) +// +// Selection: +// - If `num_columns * 4 + HEADROOM <= limit`, use `slab = num_columns` +// (single slab, zero tiling — every scalar is Booth-recoded exactly +// once per window). +// - Otherwise pick the largest power-of-2 SLAB that fits. Power-of-2 is +// a convenience: it keeps `slot = k * WG_SIZE + tid` integer divisions +// cheap and the per-slab digit range a clean [s, s+SLAB) interval. 
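+//
+// Worked example (illustrative, assuming the 1 KiB headroom described
+// below): num_columns = 65536 with workgroup_storage_limit = 32768 B gives
+// usable = 31744 B, max_cells = 7936, so slab = 4096 (largest power of 2
+// that fits), slabs = 16, bytes = 16384.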
+// +// Headroom rationale: kernels also use small amounts of compiler-allocated +// workgroup memory (loop counters, barrier state). Leaving a 1 KiB margin +// keeps us inside the granted limit on drivers that over-allocate slightly. +export const pick_shared_digit_slab = ( + num_columns: number, + workgroup_storage_limit: number, +): { slab: number; slabs: number; bytes: number } => { + const HEADROOM_BYTES = 1024; + const usable_bytes = Math.max(0, workgroup_storage_limit - HEADROOM_BYTES); + const max_cells = Math.floor(usable_bytes / 4); + if (max_cells <= 0) { + throw new Error( + `pick_shared_digit_slab: workgroup_storage_limit=${workgroup_storage_limit}B is too small even for the 1 KiB headroom`, + ); + } + if (max_cells >= num_columns) { + return { slab: num_columns, slabs: 1, bytes: num_columns * 4 }; + } + let slab = 1; + while (slab * 2 <= max_cells) slab *= 2; + if (slab < 1024) { + throw new Error( + `pick_shared_digit_slab: device workgroup storage ${workgroup_storage_limit}B can only fit ${slab} cells; below the 1024-cell floor`, + ); + } + const slabs = Math.ceil(num_columns / slab); + return { slab, slabs, bytes: slab * 4 }; +}; + export const read_write_buffer_usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST; // Create and write a storage buffer diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 3697d208e933..fb4018b6f9cd 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -1,9 +1,11 @@ import mustache from 'mustache'; import { + apply_matrix_bench as apply_matrix_bench_shader, barrett as barrett_funcs, batch_affine_apply as batch_affine_apply_shader, batch_affine_apply_scatter as batch_affine_apply_scatter_shader, batch_affine_dispatch_args as batch_affine_dispatch_args_shader, + bench_batch_affine as bench_batch_affine_shader, batch_affine_finalize as batch_affine_finalize_shader, batch_affine_finalize_apply as batch_affine_finalize_apply_shader, batch_affine_finalize_collect as batch_affine_finalize_collect_shader, @@ -12,17 +14,27 @@ import { batch_inverse as batch_inverse_shader, batch_inverse_parallel as batch_inverse_parallel_shader, bigint as bigint_funcs, + bigint_by as bigint_by_funcs, bigint_f32 as bigint_f32_funcs, + // by_inverse hosts the Mat struct + by_divsteps (and grows to host + // by_apply_matrix / fr_inv_by in subsequent sub-steps of the BY rewrite). + by_inverse as by_inverse_funcs, + // Option A BY safegcd inverse on 20 x 13-bit BigInt with BATCH=26 / + // NUM_OUTER=29. Hosts MatA, bya_divsteps, bya_apply_matrix_{fg,de}, the + // bya_reduce_to_canonical helper chain, and the fr_inv_by_a driver. 
+ by_inverse_a as by_inverse_a_funcs, bpr_bn254 as bpr_bn254_shader, convert_point_coords_and_decompose_scalars, convert_points_only as convert_points_only_shader, decompose_scalars_signed_only as decompose_scalars_signed_only_shader, decompress_g1_bn254 as decompress_g1_bn254_shader, + divsteps_bench as divsteps_bench_shader, ec_bn254 as ec_bn254_funcs, extract_word_from_bytes_le as extract_word_from_bytes_le_funcs, field as field_funcs, field_mul_bench_f32 as field_mul_bench_f32_shader, field_mul_bench_u32 as field_mul_bench_u32_shader, + fr_inv_bench as fr_inv_bench_shader, fr_pow as fr_pow_funcs, horner_reduce_bn254 as horner_reduce_bn254_shader, mont_pro_product as montgomery_product_funcs, @@ -37,9 +49,12 @@ import { transpose_serial as transpose_serial_shader, } from '../wgsl/_generated/shaders.js'; import { + compute_by_p_inv_a, + compute_by_p_inv_split, compute_misc_params, compute_mod_inverse_pow2, gen_p_limbs, + gen_p_limbs_by_initializer, gen_p_limbs_f32, gen_r_limbs, gen_mu_limbs, @@ -84,6 +99,10 @@ export class ShaderManager { public r_cubed_limbs: string; public b3_mont_limbs: string; public sqrt_exp_limbs: string; + // (p - 2) as a BigInt literal — exponent for the Fermat-based fr_pow_inv + // bench variant. Plain (non-Montgomery) since fr_pow's `exp` is consumed + // bit-by-bit as a raw integer. + public p_minus_2_limbs: string; public p_inv_mod_2w: number; public mu_limbs: string; // 22-bit-limb f32 Montgomery params. Used exclusively by @@ -93,6 +112,21 @@ export class ShaderManager { public num_limbs_f32_22: number; public n0_f32_22: bigint; public p_limbs_f32_22_str: string; + // 9 × 29-bit BY limb representation of `p` for the BY safegcd inverse + // path. Used by gen_apply_matrix_bench_shader (and in future sub-steps, + // by the fr_inv_by wiring). The initializer string is comma-separated + // limbs suitable for `BigIntBY(array({{{ p_limbs_by }}}))`. + public p_limbs_by_initializer: string; + // P_INV = p^(-1) mod 2^58, split as (low 32, high <=26) bits. The WASM + // convention is a single u64 (`p_inv` argument to Wasm9x29::apply_matrix); + // WGSL has no u64 so we precompute the split here and inject as two + // constants. Hensel-lifted from p mod 2 up to mod 2^58. + public p_inv_by_lo: number; + public p_inv_by_hi: number; + // 26-bit p^(-1) mod 2^26 for the Option A BY safegcd inverse driver + // (BATCH=26 / NUM_OUTER=29 on 20 x 13-bit BigInt). Single u32, since 26 + // bits fit comfortably. + public p_inv_by_a_lo: number; // Pre-rendered u32 Montgomery product source used as the // `montgomery_product_funcs` mustache partial by every MSM shader that // needs a base-field multiply. Defaults to the Karatsuba + Yuval body @@ -135,6 +169,8 @@ export class ShaderManager { // (q + 1) / 4: closed-form sqrt exponent for q ≡ 3 (mod 4). const sqrt_exp = (this.p + 1n) / 4n; this.sqrt_exp_limbs = gen_wgsl_limbs_code(sqrt_exp, 'e', this.num_words, this.word_size); + // (p - 2): exponent for Fermat-based inversion in fr_pow_inv. + this.p_minus_2_limbs = gen_wgsl_limbs_code(this.p - 2n, 'e', this.num_words, this.word_size); this.p_inv_mod_2w = compute_mod_inverse_pow2(this.p, this.word_size); this.mu_limbs = gen_mu_limbs(this.p, this.num_words, this.word_size); this.p_bitlength = this.p.toString(2).length; @@ -148,6 +184,18 @@ export class ShaderManager { this.n0_f32_22 = params_f32_22.n0; this.p_limbs_f32_22_str = gen_p_limbs_f32(this.p, this.num_limbs_f32_22, 22); + // BY safegcd 9 × 29-bit representation of p and 58-bit p_inv split. 
+ // Both feed `gen_apply_matrix_bench_shader` (and downstream by_inverse + // production wiring). The split is the WASM `p_inv` u64 broken into + // low-32 + high-26 chunks; the Mustache substitution is a flat u32 + // constant on each side. + this.p_limbs_by_initializer = gen_p_limbs_by_initializer(this.p); + const p_inv_split = compute_by_p_inv_split(this.p); + this.p_inv_by_lo = p_inv_split.lo; + this.p_inv_by_hi = p_inv_split.hi; + // Option A 26-bit p_inv (single u32) for the BATCH=26 BY driver. + this.p_inv_by_a_lo = compute_by_p_inv_a(this.p); + // Render the Karatsuba+Yuval Mont body once. This is the default // u32 multiplier used by every MSM shader that includes the // `montgomery_product_funcs` mustache partial. @@ -272,6 +320,7 @@ export class ShaderManager { mu_limbs: this.mu_limbs, b3_mont_limbs: this.b3_mont_limbs, sqrt_exp_limbs: this.sqrt_exp_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, w_mask: this.w_mask, slack: this.slack, num_words_mul_two: this.num_words * 2, @@ -385,9 +434,11 @@ export class ShaderManager { p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, recompile: this.recompile, }, { @@ -396,11 +447,17 @@ export class ShaderManager { montgomery_product_funcs: this.mont_product_src, field_funcs, fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, }, ); } - public gen_batch_inverse_parallel_shader(num_sub_wgs: number): string { + // `windows_per_batch` (WPB) sets how many consecutive subtask pair pools + // get merged into ONE fr_inv_by_a call per (batch, sub_wg). Z dispatch + // dim must be ceil(num_subtasks / WPB). Pass WPB=1 to recover the + // pre-pooling behaviour byte-for-byte. + public gen_batch_inverse_parallel_shader(num_sub_wgs: number, windows_per_batch: number): string { return mustache.render( batch_inverse_parallel_shader, { @@ -410,10 +467,62 @@ export class ShaderManager { p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, num_sub_wgs, + windows_per_batch, + recompile: this.recompile, + }, + { + structs, + bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, + fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, + }, + ); + } + + // Standalone bench-only entry shader for batch-affine EC addition. One + // workgroup processes BATCH_SIZE pairs via the two-phase Montgomery + // batch-inverse trick (per-thread serial chunk + workgroup Hillis-Steele + // scan + single fr_inv_by_a + back-walk). Used by bench-batch-affine.ts + // to find the sweet spot where amortising the single inverse stops + // beating thread under-utilisation. + // + // `batch_size` must be an exact multiple of `tpb`; the caller picks both + // from a hand-built table (see bench-batch-affine.ts). BS = batch_size / + // tpb is baked into the shader as a compile-time constant so the inner + // forward-and-backward walks have static loop bounds. 
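+  //
+  // Illustrative sizing only (not a value taken from the hand-built
+  // table): batch_size = 2048 with tpb = 256 bakes in BS = 2048 / 256 = 8,
+  // so each thread walks an 8-element chunk forward and back around the
+  // workgroup's single shared inverse.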
+ public gen_bench_batch_affine_shader(batch_size: number, tpb: number): string { + if (batch_size <= 0 || tpb <= 0 || batch_size % tpb !== 0) { + throw new Error( + `gen_bench_batch_affine_shader: batch_size (${batch_size}) must be a positive multiple of tpb (${tpb})`, + ); + } + const per_thread_count = batch_size / tpb; + return mustache.render( + bench_batch_affine_shader, + { + batch_size, + tpb, + per_thread_count, + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, recompile: this.recompile, }, { @@ -422,6 +531,8 @@ export class ShaderManager { montgomery_product_funcs: this.mont_product_src, field_funcs, fr_pow_funcs, + bigint_by_funcs, + by_inverse_a_funcs, }, ); } @@ -438,8 +549,12 @@ export class ShaderManager { ); } - public gen_batch_affine_dispatch_args_shader(): string { - return mustache.render(batch_affine_dispatch_args_shader, {}, {}); + // `windows_per_batch` (WPB) is baked into the shader at render time — + // dispatch_args derives `num_batches = ceil(num_subtasks / WPB)` and + // uses it as the inverse-pass Z dispatch dim. Must match the WPB used + // by the corresponding gen_batch_inverse_parallel_shader call. + public gen_batch_affine_dispatch_args_shader(windows_per_batch: number): string { + return mustache.render(batch_affine_dispatch_args_shader, { windows_per_batch }, {}); } public gen_batch_affine_schedule_shader(workgroup_size: number): string { @@ -503,6 +618,7 @@ export class ShaderManager { p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, @@ -607,6 +723,11 @@ export class ShaderManager { bench_no_store?: boolean; } = {}, safe_first_add_no_collision = false, + // Multi-window BPR: each thread loops over WPB consecutive subtasks, + // sharing kernel-launch and header overhead. WPB=1 keeps the legacy + // one-subtask-per-workgroup behaviour. Const-bounded inside the + // shader so Tint can fully unroll when WPB is small. + windows_per_batch = 1, ) { const bench_null = !!bench_flags.bench_null; const bench_compute_only = !!bench_flags.bench_compute_only; @@ -631,6 +752,7 @@ export class ShaderManager { p_inv_mod_2w: this.p_inv_mod_2w, index_shift: this.index_shift, workgroup_size, + windows_per_batch, recompile: this.recompile, capture_debug, assume_affine_buckets, @@ -712,6 +834,266 @@ ${mont_src} ${entry_src}`; } + // Bench-only entry shader for the BY `by_divsteps` primitive. Assembles + // the BY helpers (bigint_by) + the by_inverse partial (which hosts the + // `Mat` struct and `by_divsteps`) + the per-thread bench entry. + // + // Each thread reads one (f_lo, g_lo, delta) tuple, calls `by_divsteps`, + // and writes the resulting 8-field Mat + updated delta. Used by the + // bench-divsteps.html page to compare against the TS Wasm9x29 port. + // + // Note: bigint_funcs / structs are NOT included here because divsteps + // only needs the BY-specific helpers (signed_mul_split, u64_*_pair, + // i64_*_pair). The BigInt-related portions of bigint_by (by_from_bigint + // et al.) are still rendered for completeness; they're dead code in this + // bench but will be needed in step 1.5 when fr_inv_by is wired up. 
+ // We also strip BigInt-conversion helpers from the bigint_by render here + // by skipping the `{{ num_words }}` substitution — instead we set + // num_words to the standard MSM value so the by_from/to_bigint helpers + // remain syntactically valid even though unused. + public gen_divsteps_bench_shader(workgroup_size: number): string { + // Minimal BigInt struct declaration: by_from_bigint/by_to_bigint in + // bigint_by reference the BigInt type for completeness. This bench + // shader never calls them so the struct is dead but must compile. + const structs_src = mustache.render(structs, { num_words: this.num_words }); + const bigint_src = mustache.render(bigint_funcs, {}); + // by_inverse hosts `fr_inv_by`, which references `montgomery_product`, + // `get_r_cubed`, and the {{ p_limbs_by }} / {{ p_inv_by_* }} Mustache + // substitutions. divsteps_bench itself never calls fr_inv_by, but the + // partial must compile cleanly — so we pull in the same Mont + field + + // fr_pow surface that gen_fr_inv_bench_shader uses. + const mont_src = this.mont_product_src; + const field_src = mustache.render(field_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + const fr_pow_src = mustache.render(fr_pow_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + const bigint_by_src = mustache.render(bigint_by_funcs, { + num_words: this.num_words, + }); + const by_inverse_src = this.renderByInverseFuncs(); + const entry_src = mustache.render(divsteps_bench_shader, { workgroup_size }); + const get_r_src = this.renderGetRFn(); + return `${structs_src} +${bigint_src} +${mont_src} +${field_src} +${get_r_src} +${fr_pow_src} +${bigint_by_src} +${by_inverse_src} +${entry_src}`; + } + + // Bench-only entry shader for `by_apply_matrix_fg` + `by_apply_matrix_de`. + // Each thread reads one (Mat, f, g, d, e) record, runs both passes, and + // writes the updated (f', g', d', e') as 36 i32 values. Validates against + // the TS `Wasm9x29.applyMatrix` reference (used by the bench-apply-matrix + // Playwright driver). + // + // Renders: + // - structs (with num_words for BigInt declaration — dead-coded in this + // bench, kept so the bigint_by partial compiles cleanly). + // - bigint_by partial (signed_mul_split, u64/i64 helpers, by_normalise). + // - by_inverse partial (Mat, by_divsteps, by_apply_matrix_*). + // - apply_matrix_bench entry (decode → apply → encode). + // Mustache substitutions on the entry shader: + // - workgroup_size + // - p_limbs_by: BigIntBY initializer for p + // - p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58 + // - p_inv_by_hi: high (up to 26) bits of P_INV + public gen_apply_matrix_bench_shader(workgroup_size: number): string { + const structs_src = mustache.render(structs, { num_words: this.num_words }); + const bigint_src = mustache.render(bigint_funcs, {}); + // See gen_divsteps_bench_shader for why the Mont + field + fr_pow surface + // is included — `fr_inv_by` lives inside the by_inverse partial and + // references those symbols, even though apply_matrix_bench never calls it. 
+ const mont_src = this.mont_product_src; + const field_src = mustache.render(field_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + const fr_pow_src = mustache.render(fr_pow_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + const bigint_by_src = mustache.render(bigint_by_funcs, { + num_words: this.num_words, + }); + const by_inverse_src = this.renderByInverseFuncs(); + const entry_src = mustache.render(apply_matrix_bench_shader, { + workgroup_size, + p_limbs_by: this.p_limbs_by_initializer, + p_inv_by_lo: this.p_inv_by_lo, + p_inv_by_hi: this.p_inv_by_hi, + }); + const get_r_src = this.renderGetRFn(); + return `${structs_src} +${bigint_src} +${mont_src} +${field_src} +${get_r_src} +${fr_pow_src} +${bigint_by_src} +${by_inverse_src} +${entry_src}`; + } + + // Bench-only entry shader for the BY top-level `fr_inv_by` driver. Each + // thread reads one BN254 base-field value `a` (in Montgomery form), runs + // `k` chained `fr_inv_by` calls, and writes the final value back. Used by + // the bench-fr-inv Playwright driver to validate against the host + // `Wasm9x29.invert` + Mont-correction reference. + // + // The render bundles: + // - structs (BigInt declaration, NUM_WORDS limbs) + // - bigint_funcs (basic BigInt utilities) + // - karat+yuval Mont product (provides `montgomery_product`, `get_p`) + // - fr_pow_funcs (provides `get_r_cubed` + `fr_inv`) + // - bigint_by (variant=fr_inv_by) (signed_mul_split, u64/i64 helpers, + // by_normalise, by_from/to_bigint) + // - by_inverse (variant=fr_inv_by) (Mat, by_divsteps, by_apply_matrix_*, + // by_reduce_to_canonical, fr_inv_by) + // - fr_inv_bench entry (per-thread chained inversion) + // + // The `variant` arg picks the symbol the entry shader calls. For + // 'fr_inv_by' we include the by_inverse + bigint_by partials; for the + // legacy 'fr_inv' (Pornin jumpy K=12 safegcd in fr_pow) we omit them so + // the rendered bundle stays small and dead-code-free. + // + // Mustache substitutions consumed by `by_inverse` (via renderByInverseFuncs): + // - p_limbs_by: BigIntBY initializer for the BN254 base-field modulus + // - p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58 + // - p_inv_by_hi: high (up to 26) bits of P_INV + public gen_fr_inv_bench_shader( + workgroup_size: number, + variant: 'fr_inv_by' | 'fr_inv' | 'fr_inv_by_a' | 'fr_pow_inv' = 'fr_inv_by', + ): string { + const structs_src = mustache.render(structs, { num_words: this.num_words }); + const bigint_src = mustache.render(bigint_funcs, {}); + const mont_src = this.mont_product_src; + // field_funcs provides `fr_sub` / `fr_add` / `bigint_halve_k_mod_p` and + // other helpers that fr_pow's alternate variants reference. `fr_inv_by` + // itself does not call them but the partial must resolve every symbol + // or shader compilation fails. 
+ const field_src = mustache.render(field_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + // fr_pow exports `fr_pow`, `fr_inv`, `fr_inv_plain`, `fr_inv_bgcd`, and + // the `get_r_cubed` helper that `fr_inv_by` uses for the Mont + // correction. We include it in BOTH variants so `get_r_cubed` is in + // scope for fr_inv_by and `fr_inv` is in scope for the legacy variant. + const fr_pow_src = mustache.render(fr_pow_funcs, { + word_size: this.word_size, + num_words: this.num_words, + n0: this.n0, + p_limbs: this.p_limbs, + r_limbs: this.r_limbs, + r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, + mask: this.mask, + two_pow_word_size: this.two_pow_word_size, + p_inv_mod_2w: this.p_inv_mod_2w, + }); + // bigint_by + by_inverse are gated on variant: they are large partials + // (≈700 lines of BY plumbing) and only needed when fr_inv_by is called. + // For the legacy fr_inv variant we omit them entirely so the bundle + // stays small and we don't pay compile time for dead code. + let by_blocks = ''; + if (variant === 'fr_inv_by') { + by_blocks = `${mustache.render(bigint_by_funcs, { num_words: this.num_words })} +${this.renderByInverseFuncs()}`; + } else if (variant === 'fr_inv_by_a') { + // Option A reuses the u64 helpers from bigint_by (u64_add, u64_sub, + // u64_shr1, u64_low_bit) but does NOT need the 9 x 29-bit BigIntBY + // struct or its conversion helpers. We still pull in bigint_by for + // the u64 helpers and signed_mul_split (unused here but cheap). + by_blocks = `${mustache.render(bigint_by_funcs, { num_words: this.num_words })} +${this.renderByInverseAFuncs()}`; + } + const entry_src = mustache.render(fr_inv_bench_shader, { + workgroup_size, + r_limbs: this.r_limbs, + inv_fn: variant, + }); + return `${structs_src} +${bigint_src} +${mont_src} +${field_src} +${fr_pow_src} +${by_blocks} +${entry_src}`; + } + + // Render the by_inverse partial with the BY-specific Mustache constants + // (BigIntBY initializer for p, and the p_inv 58-bit split). Shared by the + // divsteps / apply_matrix / fr_inv bench renders so the partial's + // `{{{ p_limbs_by }}}` and `{{ p_inv_by_* }}` substitutions resolve to + // valid WGSL in every assembly. + private renderByInverseFuncs(): string { + return mustache.render(by_inverse_funcs, { + p_limbs_by: this.p_limbs_by_initializer, + p_inv_by_lo: this.p_inv_by_lo, + p_inv_by_hi: this.p_inv_by_hi, + }); + } + + // Render the by_inverse_a partial (Option A: BATCH=26 / NUM_OUTER=29 BY + // safegcd on 20 x 13-bit BigInt). Mustache substitutions: the 26-bit + // p_inv constant and {{ num_words }} for the streaming-loop bound. + private renderByInverseAFuncs(): string { + return mustache.render(by_inverse_a_funcs, { + num_words: this.num_words, + p_inv_by_a_lo: this.p_inv_by_a_lo, + }); + } + + // Inlined `get_r` definition with the curve-specific R limbs. Every + // production MSM shader defines its own; the bench harnesses pull this + // from a single helper so the divsteps / apply_matrix benches can hoist + // it before fr_pow_funcs (which calls `get_r()` from inside fr_pow). + private renderGetRFn(): string { + return `fn get_r() -> BigInt {\n var r: BigInt;\n${this.r_limbs}\n return r;\n}`; + } + // Bench-only entry shader for the f32 Montgomery product. 
Only the // sos3uv3 variant (22-bit limbs, per-slot tlo/thi chain-break) is wired // up here — it is the fastest f32 Mont mul found in the wider variant diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts index 8570a077e4f0..c52165b49d5a 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/utils.ts @@ -256,6 +256,110 @@ export const gen_p_limbs = ( return gen_wgsl_limbs_code(p, "p", num_words, word_size); }; +// Emit a comma-separated `array` initializer list for the 9 × 29-bit +// limbs of a non-negative bigint. Used to inject the BN254 base-field modulus +// `p` into the apply_matrix bench shader as a literal `BigIntBY` initializer. +// +// Pre: p in [0, 2^256). Post: e.g. "1, 2, 3, 4, 5, 6, 7, 8, 9" suitable for +// substitution inside `BigIntBY(array({{{ p_limbs_by }}}))`. +export const gen_p_limbs_by_initializer = (p: bigint): string => { + if (p < 0n) { + throw new Error("gen_p_limbs_by_initializer: negative input"); + } + const BY_NUM_LIMBS = 9; + const BY_LIMB_BITS = 29n; + const BY_LIMB_MASK = (1n << BY_LIMB_BITS) - 1n; + const limbs: string[] = []; + // 9 × 29-bit limb decomposition matches Wasm9x29::fromBigint exactly. + // Limbs are non-negative for inputs in [0, 2^256) since the top limb's + // bit 28 is 0 for p < 2^254. + const mask64 = (1n << 64n) - 1n; + const d0 = p & mask64; + const d1 = (p >> 64n) & mask64; + const d2 = (p >> 128n) & mask64; + const d3 = (p >> 192n) & mask64; + limbs.push((d0 & BY_LIMB_MASK).toString()); + limbs.push(((d0 >> 29n) & BY_LIMB_MASK).toString()); + limbs.push((((d0 >> 58n) & 0x3Fn) | ((d1 & 0x7FFFFFn) << 6n)).toString()); + limbs.push(((d1 >> 23n) & BY_LIMB_MASK).toString()); + limbs.push((((d1 >> 52n) & 0xFFFn) | ((d2 & 0x1FFFFn) << 12n)).toString()); + limbs.push(((d2 >> 17n) & BY_LIMB_MASK).toString()); + limbs.push((((d2 >> 46n) & 0x3FFFFn) | ((d3 & 0x7FFn) << 18n)).toString()); + limbs.push(((d3 >> 11n) & BY_LIMB_MASK).toString()); + limbs.push(((d3 >> 40n) & 0xFFFFFFn).toString()); + if (limbs.length !== BY_NUM_LIMBS) { + throw new Error(`gen_p_limbs_by_initializer: expected ${BY_NUM_LIMBS} limbs, got ${limbs.length}`); + } + return limbs.join(", "); +}; + +// p_inv = p^(-1) mod 2^58 for the BY safegcd 2-adic correction step. WASM's +// `Wasm9x29::apply_matrix` accepts this as a single u64; WGSL has no native +// u64 so we split it as two u32s — low 32 bits and high (up to 26) bits. The +// caller injects these as constants into the by_apply_matrix_de Mustache +// substitution `{{ p_inv_by_lo }}` / `{{ p_inv_by_hi }}`. +// +// Returns { lo, hi } where (lo + hi * 2^32) === p_inv and 0 <= lo < 2^32, +// 0 <= hi < 2^26. +export const compute_by_p_inv_split = (p: bigint): { lo: number; hi: number } => { + const BATCH = 58n; + const MASK_BATCH = (1n << BATCH) - 1n; + if (p < 1n) { + throw new Error("compute_by_p_inv_split: p must be positive"); + } + if ((p & 1n) === 0n) { + throw new Error("compute_by_p_inv_split: p must be odd"); + } + // Hensel lift: invert p mod 2, then double precision each Newton step + // until we have >= 58 bits. inv_n satisfies p * inv_n ≡ 1 mod 2^n. + // Newton's formula: inv' = inv * (2 - p * inv) mod 2^(2k), valid when + // inv was valid mod 2^k. Start at k=1 (inv=1 works since p is odd). 
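+  //
+  // Worked example (illustrative, with a tiny odd modulus p = 7 instead
+  // of BN254's p): start at inv = 1 (7 * 1 ≡ 1 mod 2); one step gives
+  // 1 * (2 - 7) ≡ 3 mod 4 (7 * 3 = 21 ≡ 1 mod 4); the next gives
+  // 3 * (2 - 21) ≡ 7 mod 16 (7 * 7 = 49 ≡ 1 mod 16); then
+  // 7 * (2 - 49) ≡ 183 mod 256 (7 * 183 = 1281 ≡ 1 mod 256), and so on
+  // until the precision reaches 58 bits.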
+ let inv = 1n; + let cur = 1n; + while (cur < BATCH) { + cur *= 2n; + const mask = (1n << cur) - 1n; + inv = (inv * (2n - p * inv)) & mask; + } + inv &= MASK_BATCH; + if (((p * inv) & MASK_BATCH) !== 1n) { + throw new Error("compute_by_p_inv_split: Hensel lift failed sanity check"); + } + const lo = Number(inv & ((1n << 32n) - 1n)); + const hi = Number((inv >> 32n) & ((1n << 32n) - 1n)); + return { lo, hi }; +}; + +// p_inv = p^(-1) mod 2^26 for the Option A BY safegcd inverse driver +// (BATCH=26 / NUM_OUTER=29 on 20 x 13-bit BigInt). Mirrors compute_by_p_inv_split +// but at 26-bit precision so it fits in a single u32 WGSL constant. +// +// Pre: p odd, positive. +// Post: low 26 bits of p^(-1) mod 2^26 packed in a u32, satisfying +// (p * out) & ((1<<26) - 1) === 1. +export const compute_by_p_inv_a = (p: bigint): number => { + const BATCH = 26n; + const MASK_BATCH = (1n << BATCH) - 1n; + if (p < 1n) { + throw new Error("compute_by_p_inv_a: p must be positive"); + } + if ((p & 1n) === 0n) { + throw new Error("compute_by_p_inv_a: p must be odd"); + } + let inv = 1n; + let cur = 1n; + while (cur < BATCH) { + cur *= 2n; + const mask = (1n << cur) - 1n; + inv = (inv * (2n - p * inv)) & mask; + } + inv &= MASK_BATCH; + if (((p * inv) & MASK_BATCH) !== 1n) { + throw new Error("compute_by_p_inv_a: Hensel lift failed sanity check"); + } + return Number(inv); +}; + export const gen_r_limbs = ( r: bigint, num_words: number, diff --git a/barretenberg/ts/src/msm_webgpu/msm.ts b/barretenberg/ts/src/msm_webgpu/msm.ts index bb23cdf002c3..f47d43fa4f55 100644 --- a/barretenberg/ts/src/msm_webgpu/msm.ts +++ b/barretenberg/ts/src/msm_webgpu/msm.ts @@ -948,7 +948,14 @@ const compute_curve_msm = async ( // loop's wraparound bug AND keeps the simpler one-iteration dispatch // that the warm-cached bind-group cache already assumes. const num_subtasks_per_bpr_1 = num_subtasks; - const b_num_x_workgroups = num_subtasks_per_bpr_1; + // BPR_WINDOWS_PER_BATCH: each workgroup handles WPB consecutive + // subtasks via an in-kernel const-bounded loop. WPB=1 = legacy + // dispatch shape (X = num_subtasks). WPB > 1 trades thread count for + // per-thread work — see bpr_bn254.template.wgsl stage_1 comment for + // the register-pressure tradeoff. Override per call via the + // bpr_inner_loop knob downstream if needed. + const BPR_WINDOWS_PER_BATCH = 1; + const b_num_x_workgroups = Math.ceil(num_subtasks_per_bpr_1 / BPR_WINDOWS_PER_BATCH); const b_workgroup_size = 256; // Output of the parallel bucket points reduction (BPR) shader @@ -1002,6 +1009,7 @@ const compute_curve_msm = async ( /* mixed_safe_buckets */ bpr_mixed_safe, /* bench_flags */ bpr_bench_flags, /* safe_first_add_no_collision */ bpr_safe_first, + /* windows_per_batch */ BPR_WINDOWS_PER_BATCH, ); // Compact key derived from the bench flags. Forwarded into the bpr_1 // and bpr_2 pipeline cache keys so each variant compiles its own @@ -1041,6 +1049,8 @@ const compute_curve_msm = async ( bpr_mixed_safe, bpr_bench_key, input_size, + num_subtasks, + BPR_WINDOWS_PER_BATCH, ); } @@ -1070,7 +1080,7 @@ const compute_curve_msm = async ( // Bucket points reduction (BPR) - stage 2 // Same as bpr_1: dispatch all T subtasks in one outer iter. 
const num_subtasks_per_bpr_2 = num_subtasks; - const b_2_num_x_workgroups = num_subtasks_per_bpr_2; + const b_2_num_x_workgroups = Math.ceil(num_subtasks_per_bpr_2 / BPR_WINDOWS_PER_BATCH); for (let subtask_idx = 0; subtask_idx < num_subtasks; subtask_idx += num_subtasks_per_bpr_2) { await bpr_2( bpr_shader, @@ -1096,6 +1106,8 @@ const compute_curve_msm = async ( bpr_mixed_safe, bpr_bench_key, input_size, + num_subtasks, + BPR_WINDOWS_PER_BATCH, ); } cpu_timer.phaseFrom('bpr_host_total', 'bpr_host_begin'); @@ -2359,6 +2371,11 @@ const bpr_1 = async ( // first call's buffers, BPR would read stale data, and downstream // Horner would read an unwritten g_points buffer and return identity. input_size_key = 0, + // Total subtask count + WPB. Passed to the shader as params[3] so + // multi-window dispatches with WPB > 1 can skip out-of-range subtasks + // in the tail batch (when num_subtasks is not a multiple of WPB). + num_subtasks_total = 0, + windows_per_batch = 1, ) => { let original_bucket_sum_x_sb; let original_bucket_sum_y_sb; @@ -2375,11 +2392,13 @@ const bpr_1 = async ( commandEncoder.copyBufferToBuffer(bucket_sum_z_sb, 0, original_bucket_sum_z_sb, 0, bucket_sum_z_sb.size); } - // Parameters as a uniform buffer. Contents (subtask_idx, num_columns, - // num_x_workgroups) are constant per (subtask_idx, layout) tuple, so - // we cache them on the context when one is provided. - const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups]); - const bpr1_key = `${curveId ?? 'x'}:bpr1:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`; + // Parameters as a uniform buffer. Layout: (subtask_idx_base, + // num_columns, num_subtasks_per_bpr, num_subtasks_total). The 4th + // slot lets the WPB-aware shader skip out-of-range subtasks in tail + // batches. Constant per (subtask_idx, layout, WPB) tuple, so we cache + // it on the context when one is provided. + const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups, num_subtasks_total]); + const bpr1_key = `${curveId ?? 'x'}:bpr1:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`; let params_ub: GPUBuffer; if (context !== undefined && !debug && !debug_capture_sb) { const got = context.acquirePersistentUniform(`${bpr1_key}:params_ub`, params_bytes.length); @@ -2406,7 +2425,7 @@ const bpr_1 = async ( // Debug-capture variant compiles a different shader (Mustache flag), // so keys must not collide with the non-debug one. context, - `${curveId ?? 'x'}:bpr1:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`, + `${curveId ?? 'x'}:bpr1:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`, ); cpu_timer?.accumulate('compile_bpr1_shader', performance.now() - _b1_compile_t0); @@ -2554,10 +2573,15 @@ const bpr_2 = async ( // group across MSM calls with different `input_size`, whose // workspace buffers are equally sized but distinct GPUBuffer objects. input_size_key = 0, + // See bpr_1: total subtasks (4th param slot) + WPB (cache key). + num_subtasks_total = 0, + windows_per_batch = 1, ) => { // Parameters as a uniform buffer (cached on context when not debug). 
- const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups]); - const bpr2_key = `${curveId ?? 'x'}:bpr2:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`; + // Layout: (subtask_idx_base, num_columns, num_subtasks_per_bpr, + // num_subtasks_total). See bpr_1 for the layout rationale. + const params_bytes = numbers_to_u8s_for_gpu([subtask_idx, num_columns, num_x_workgroups, num_subtasks_total]); + const bpr2_key = `${curveId ?? 'x'}:bpr2:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${num_x_workgroups}:${subtask_idx}:N=${input_size_key}`; let params_ub: GPUBuffer; if (context !== undefined && !debug && !debug_capture_sb) { const got = context.acquirePersistentUniform(`${bpr2_key}:params_ub`, params_bytes.length); @@ -2578,7 +2602,7 @@ const bpr_2 = async ( shaderCode, 'stage_2', context, - `${curveId ?? 'x'}:bpr2:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`, + `${curveId ?? 'x'}:bpr2:wpb${windows_per_batch}:${workgroup_size}:${num_columns}:${debug_capture_sb ? 'dbg' : 'nodbg'}:${assume_affine_buckets ? 'aff' : mixed_safe_buckets ? 'mxs-v2' : 'gen'}:bench=${bench_flags_key || 'none'}`, ); cpu_timer?.accumulate('compile_bpr2_shader', performance.now() - _b2_compile_t0); diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index 94e18b0bf47e..2ef6db054f7d 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 35 shader sources inlined. +// 42 shader sources inlined. /* eslint-disable */ @@ -506,6 +506,349 @@ fn bigint_signed_axby_modp_halve_k( } `; +export const bigint_by = `// WGSL bigint helpers for the Bernstein-Yang safegcd inversion port +// (sub-step 1.2 of the WebGPU MSM rewrite plan). The semantics here MUST +// match the TS reference at cuzk/bernstein_yang.ts exactly so divsteps / +// apply_matrix can be transliterated against this file in sub-step 1.3. +// +// REPRESENTATIONS +// - BigIntBY: 9 limbs of signed-29-bit chunks (i32). Top limb carries +// sign; lower limbs in [0, 2^29) post-normalise. Matches Wasm9x29. +// - u64 as vec2: value = .x + (.y << 32). All u64 ops use this. +// - i64 as vec2: value = u32(.x) + (i32(.y) << 32). The low half +// carries the raw bit pattern (treated as unsigned), the high half +// is signed. Used for matrix entries u/v/q/r which grow to ±2^58. +// +// LOOP BOUND DISCIPLINE +// Every \`for\` loop in this file uses a const upper bound (BY_NUM_LIMBS). +// The WGSL bounded-loop audit is \`grep 'for ' ...\`, run after Mustache +// render — every match has \`< BY_NUM_LIMBS\` or \`< {{ num_words }}\`. + +const BY_NUM_LIMBS: u32 = 9u; +const BY_LIMB_BITS: u32 = 29u; +const BY_LIMB_MASK: u32 = (1u << 29u) - 1u; +const BY_BATCH: u32 = 58u; +const BY_NUM_OUTER: u32 = 13u; +const BY_REDUCE_INTERVAL: u32 = 4u; +const BY_RTC_MAX_ITERS: u32 = 36u; + +// 29-bit signed limb wrapped as i32. Lower limbs canonical in [0, 2^29); +// top limb is the signed extension carrier. 
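+//
+// Illustrative example: the value 2^29 + 3 is held canonically as
+// l = [3, 1, 0, 0, 0, 0, 0, 0, 0]; a non-canonical l = [2^29 + 3, 0, ...]
+// denotes the same value and by_normalise carries the excess into l[1].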
+struct BigIntBY { + l: array, +} + +// ============================================================ +// u64 arithmetic as vec2 (x = .x + (.y << 32)) +// ============================================================ + +// u64 add. Pre: any a, b. Post: low 64 bits of a + b. +fn u64_add(a: vec2, b: vec2) -> vec2 { + let lo = a.x + b.x; + let carry = select(0u, 1u, lo < a.x); + let hi = a.y + b.y + carry; + return vec2(lo, hi); +} + +// u64 sub. Pre: any a, b. Post: low 64 bits of (a - b) mod 2^64. +fn u64_sub(a: vec2, b: vec2) -> vec2 { + let lo = a.x - b.x; + let borrow = select(0u, 1u, a.x < b.x); + let hi = a.y - b.y - borrow; + return vec2(lo, hi); +} + +// u64 logical right shift by 1. Pre: any x. Post: x >> 1 unsigned. +fn u64_shr1(x: vec2) -> vec2 { + let lo = (x.x >> 1u) | ((x.y & 1u) << 31u); + let hi = x.y >> 1u; + return vec2(lo, hi); +} + +// Bit 0 of a u64. Pre: any x. Post: 0 or 1. +fn u64_low_bit(x: vec2) -> u32 { + return x.x & 1u; +} + +// u64 two's-complement negate. Pre: any x. Post: (-x) mod 2^64. +fn u64_neg(x: vec2) -> vec2 { + let nx = vec2(~x.x, ~x.y); + return u64_add(nx, vec2(1u, 0u)); +} + +// u64 bitwise AND. Pre: any a, mask. Post: a & mask, limbwise. +fn u64_and(a: vec2, mask: vec2) -> vec2 { + return vec2(a.x & mask.x, a.y & mask.y); +} + +// ============================================================ +// i64 signed arithmetic as vec2 (lo bits, hi signed) +// value = u32(.x) + (i32(.y) << 32) treated as signed two's-complement. +// ============================================================ + +// i64 add. Pre: |a|, |b| < 2^63. Post: low 64 bits of a + b, signed. +fn i64_add_pair(a: vec2, b: vec2) -> vec2 { + let au = u32(a.x); + let bu = u32(b.x); + let lo = au + bu; + let carry = select(0u, 1u, lo < au); + let hi = a.y + b.y + i32(carry); + return vec2(i32(lo), hi); +} + +// i64 sub. Pre: |a|, |b| < 2^63. Post: low 64 bits of a - b, signed. +fn i64_sub_pair(a: vec2, b: vec2) -> vec2 { + let au = u32(a.x); + let bu = u32(b.x); + let lo = au - bu; + let borrow = select(0u, 1u, au < bu); + let hi = a.y - b.y - i32(borrow); + return vec2(i32(lo), hi); +} + +// i64 left shift by 1. Pre: a in (-2^62, 2^62). Post: a << 1, signed. +fn i64_shl1_pair(a: vec2) -> vec2 { + let au = u32(a.x); + let lo = au << 1u; + let bit31 = au >> 31u; + // Arithmetic shift left of hi half: bring bottom carry bit up. + let hi_u = (u32(a.y) << 1u) | bit31; + return vec2(i32(lo), i32(hi_u)); +} + +// i64 two's-complement negate. Pre: any a (low 64 bits well-defined). +// Post: (-a) mod 2^64 as a signed pair. +fn i64_neg_pair(a: vec2) -> vec2 { + let nlo = ~u32(a.x); + let nhi = ~u32(a.y); + let lo = nlo + 1u; + let carry = select(0u, 1u, lo < nlo); + let hi_u = nhi + carry; + return vec2(i32(lo), i32(hi_u)); +} + +// ============================================================ +// signed_mul_split: signed 29 × 31-bit product, split into 29-bit chunks. +// +// Pre: |a| <= 2^29 (one BY limb), |b| <= 2^31 - 1 (i32 matrix-entry low). +// Post: returns (lo29, hi) with a*b = lo29 + (hi * 2^29), +// 0 <= lo29 < 2^29; |hi| <= 2^31 (fits in i32 except at the single +// corner a == -2^29, b == -2^31 which the caller does not pass). +// +// Partial-product width split (the plan-mandated proof sketch): +// Split each operand as v = v_lo + v_hi * 2^15 where v_lo is sign- +// extended low 15 bits, i.e. v_lo in [-2^14, 2^14). Then v_hi = (v - +// v_lo) >> 15. 
+// |a_lo|, |b_lo| <= 2^14 +// |a_hi| <= ceil(2^29 / 2^15) = 2^14 +// |b_hi| <= ceil(2^31 / 2^15) = 2^16 +// Four signed partial products: +// pll = a_lo * b_lo, |pll| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓] +// plh = a_lo * b_hi, |plh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓] +// phl = a_hi * b_lo, |phl| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓] +// phh = a_hi * b_hi, |phh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓] +// Middle sum: +// mid = plh + phl, |mid| <= 2^30 + 2^28 < 2^31 [fits i32] +// Every partial fits in signed i32, satisfying the plan's audit. The +// subsequent combination uses u32 wrap arithmetic to assemble the low +// 64 bits of the signed product without further overflow concerns. +fn signed_mul_split(a: i32, b: i32) -> vec2 { + // Sign-extend the low 15 bits of each operand: (v << 17) >> 17 fills + // bits 15..31 with the sign of bit 14. WGSL i32 shifts are well-defined + // (left shift discards high bits; right shift on i32 is arithmetic). + let a_lo: i32 = (a << 17u) >> 17u; + let a_hi: i32 = (a - a_lo) >> 15u; + let b_lo: i32 = (b << 17u) >> 17u; + let b_hi: i32 = (b - b_lo) >> 15u; + + let pll: i32 = a_lo * b_lo; + let plh: i32 = a_lo * b_hi; + let phl: i32 = a_hi * b_lo; + let phh: i32 = a_hi * b_hi; + let mid: i32 = plh + phl; + + // total = pll + mid * 2^15 + phh * 2^30, computed as the low 64 bits + // of an i64 via u32 wrap arithmetic. + // + // Each i32 piece extended to i64 (lo, hi): + // pll: lo = u32(pll), hi = u32(pll >> 31) (sign mask) + // mid << 15: lo = u32(mid) << 15, hi = u32(mid >> 17) + // phh << 30: lo = u32(phh) << 30, hi = u32(phh >> 2) + // The >> on signed pieces is arithmetic so the high half carries the + // correct sign extension. + let pll_lo: u32 = u32(pll); + let pll_hi: u32 = u32(pll >> 31u); + let mid_lo: u32 = u32(mid) << 15u; + let mid_hi: u32 = u32(mid >> 17u); + let phh_lo: u32 = u32(phh) << 30u; + let phh_hi: u32 = u32(phh >> 2u); + + // Sum the three i64 values with carry chains. + let s1_lo: u32 = pll_lo + mid_lo; + let c1: u32 = select(0u, 1u, s1_lo < pll_lo); + let s1_hi: u32 = pll_hi + mid_hi + c1; + + let sum_lo: u32 = s1_lo + phh_lo; + let c2: u32 = select(0u, 1u, sum_lo < s1_lo); + let sum_hi: u32 = s1_hi + phh_hi + c2; + + // Re-split into lo29 (low 29 bits) and hi = ARS by 29 (signed). + let lo29: u32 = sum_lo & 0x1FFFFFFFu; + // hi bits 0..2 from sum_lo[29..31]; bits 3..31 from sum_hi[0..28]. + // sum_hi[29..31] are sign-extension bits under the precondition + // |a*b| <= 2^60 and are absorbed correctly when bitcasting hi_u to i32. + let hi_u: u32 = (sum_lo >> 29u) | (sum_hi << 3u); + return vec2(i32(lo29), i32(hi_u)); +} + +// by_accumulate: acc + m_lo * x_limb as a 58-bit signed two-limb pair. +// +// Pre: |acc.x| < 2^29 (canonical low half), |acc.y| < 2^29 + 2^k for +// some small k (caller bounded), |m_lo| <= 2^29, |x_limb| <= 2^29. +// Post: returns (lo29, hi) with acc + m_lo * x_limb = lo29 + hi * 2^29, +// 0 <= lo29 < 2^29 (canonical low limb). hi grows by at most one +// carry bit per call. +// +// This is the per-limb cross-product helper used by apply_matrix. The +// product |m_lo * x_limb| <= 2^58, well within signed_mul_split's bounds. +fn by_accumulate(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 { + let prod = signed_mul_split(m_lo, x_limb); + // Add (prod.x, prod.y) to (acc.x, acc.y) treating both halves as signed + // 32-bit lanes. acc.x in [0, 2^29), prod.x in [0, 2^29), so the low-half + // sum is in [0, 2^30) — fits i32, no carry to track for this helper. 
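+  //
+  // Illustrative example: acc = (5, 2), m_lo = 3, x_limb = 4 gives
+  // prod = (12, 0) and lo_sum = 17 (< 2^29, so no overflow bits), hence
+  // the result is (17, 2), i.e. 17 + 2 * 2^29 = (5 + 2 * 2^29) + 3 * 4.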
+ let lo_sum = acc.x + prod.x; + let lo29 = lo_sum & i32(BY_LIMB_MASK); + // The high half absorbs prod.y plus any bits of lo_sum above bit 28. + // WGSL requires the rhs of \`>>\` to be u32 (the lhs stays i32 — i32 shift + // is arithmetic, preserving sign). + let lo_overflow = lo_sum >> BY_LIMB_BITS; + let hi = acc.y + prod.y + lo_overflow; + return vec2(lo29, hi); +} + +// by_low_u64_lohi: low 64 bits of a BigIntBY as a vec2 (.x = low 32, +// .y = high 32). Mirrors \`Wasm9x29::low_64()\` and the TS \`low64\`: +// bits 0..28 = x.l[0] & LIMB_MASK +// bits 29..57 = x.l[1] & LIMB_MASK +// bits 58..63 = x.l[2] & 0x3F +// All inputs treated as unsigned u32 bit patterns. Negative limbs +// reinterpreted via \`u32(...)\` — the caller is responsible for ensuring the +// lower three limbs are canonical (in [0, 2^29) etc.), which is the case in +// the BY driver after each by_normalise. +fn by_low_u64_lohi(x: BigIntBY) -> vec2 { + let l0: u32 = u32(x.l[0]) & BY_LIMB_MASK; + let l1: u32 = u32(x.l[1]) & BY_LIMB_MASK; + let l2: u32 = u32(x.l[2]) & 0x3Fu; + // low 32 bits: l0 in bits 0..28, low 3 bits of l1 at 29..31. + let lo32: u32 = l0 | ((l1 & 0x7u) << 29u); + // high 32 bits: high 26 bits of l1 at 0..25, l2 (6 bits) at 26..31. + let hi32: u32 = (l1 >> 3u) | (l2 << 26u); + return vec2(lo32, hi32); +} + +// by_normalise: carry-propagate so each limb i in [0, N-1) ends up in +// [0, 2^29) and the top limb absorbs the final signed carry. +// +// Pre: any signed limb values (caller may pass non-canonical ±2^60). +// Post: x.l[0..N-2] in [0, 2^29); x.l[N-1] is the signed extension. +fn by_normalise(x: ptr) { + var c: i32 = 0; + for (var i: u32 = 0u; i < BY_NUM_LIMBS - 1u; i = i + 1u) { + let v = (*x).l[i] + c; + (*x).l[i] = v & i32(BY_LIMB_MASK); + // Arithmetic shift on i32 propagates sign; matches BigInt(v) >> 29n in TS. + // WGSL requires the rhs of \`>>\` to be u32 (the lhs i32 then gets the + // arithmetic shift semantics we want). + c = v >> BY_LIMB_BITS; + } + (*x).l[BY_NUM_LIMBS - 1u] = (*x).l[BY_NUM_LIMBS - 1u] + c; +} + +// ============================================================ +// Conversion between 20×13-bit BigInt and 9×29-bit BigIntBY. +// +// The conversion is a bit-rewrite: read 29 source bits at offset +// 29*k into limb k. Each 29-bit window spans up to 4 consecutive +// source limbs (worst case bit-offset 12, covers 27 bits in 3 limbs; +// needs 2 more from a 4th). +// +// Pre (by_from_bigint): x in [0, 2^260) represented in 20 × 13-bit limbs. +// Post: result in [0, 2^256), 9 canonical 29-bit limbs. +// Pre (by_to_bigint): x canonical (lower limbs in [0, 2^29), top non-neg). +// Post: result in [0, 2^260), 20 × 13-bit unsigned limbs. +// ============================================================ + +const BY_SRC_LIMB_BITS: u32 = 13u; +const BY_SRC_LIMB_MASK: u32 = (1u << 13u) - 1u; + +// Read a 29-bit window from the 20×13-bit BigInt starting at bit \`bit_lo\`. +// The window spans up to 4 consecutive source limbs (ceil((29 + 12) / 13) = 4). +fn by_read_29bit_window(x: BigInt, bit_lo: u32) -> u32 { + let base_idx = bit_lo / BY_SRC_LIMB_BITS; // first source limb touched + let in_limb_bit = bit_lo % BY_SRC_LIMB_BITS; // bit offset within base limb + // Up to four consecutive limbs (out-of-range reads return 0). 
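+  // Illustrative: bit_lo = 29 gives base_idx = 2, in_limb_bit = 3, so the
+  // window stitches bits 3..12 of limb 2, all 13 bits of limb 3 and the
+  // low 6 bits of limb 4; shift3 = 36, so v3 contributes nothing here.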
+ var v0: u32 = 0u; + var v1: u32 = 0u; + var v2: u32 = 0u; + var v3: u32 = 0u; + if (base_idx < {{ num_words }}u) { v0 = x.limbs[base_idx]; } + if (base_idx + 1u < {{ num_words }}u) { v1 = x.limbs[base_idx + 1u]; } + if (base_idx + 2u < {{ num_words }}u) { v2 = x.limbs[base_idx + 2u]; } + if (base_idx + 3u < {{ num_words }}u) { v3 = x.limbs[base_idx + 3u]; } + // Stitch: contributions are v0 >> in_limb_bit, v1..v3 at increasing offsets. + // The v3 shift can reach 13*3 - 0 = 39 which exceeds 32 — but in WGSL + // shifting a 13-bit limb by >= 32 yields zero by spec; we only need v3 + // when in_limb_bit > 10, where the shift is 39 - in_limb_bit ∈ [27, 28]. + // For in_limb_bit < 11 the contribution of v3 is irrelevant (covered by + // v0..v2). For safety we mask before shifting so >= 32 shifts vanish. + let s0 = v0 >> in_limb_bit; + let s1 = v1 << (BY_SRC_LIMB_BITS - in_limb_bit); + let s2 = v2 << ((BY_SRC_LIMB_BITS * 2u) - in_limb_bit); + let shift3: u32 = (BY_SRC_LIMB_BITS * 3u) - in_limb_bit; + // shift3 is in [27, 39]; for shift3 >= 32 we explicitly produce 0 to + // avoid WGSL's undefined-behavior region for shifts >= bit-width. + var s3: u32 = 0u; + if (shift3 < 32u) { + s3 = v3 << shift3; + } + return (s0 | s1 | s2 | s3) & BY_LIMB_MASK; +} + +fn by_from_bigint(x: BigInt) -> BigIntBY { + var r: BigIntBY; + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + r.l[i] = i32(by_read_29bit_window(x, i * BY_LIMB_BITS)); + } + return r; +} + +// Write a 13-bit window into the 20×13-bit BigInt at bit \`bit_lo\`. +// Reads three consecutive 29-bit BY limbs and emits the right 13-bit slice. +fn by_read_13bit_window(x: BigIntBY, bit_lo: u32) -> u32 { + let base_idx = bit_lo / BY_LIMB_BITS; + let in_limb_bit = bit_lo % BY_LIMB_BITS; + // Up to two BY limbs cover any 13-bit window (since 29 > 13). + var v0: u32 = 0u; + var v1: u32 = 0u; + if (base_idx < BY_NUM_LIMBS) { v0 = u32(x.l[base_idx]); } + if (base_idx + 1u < BY_NUM_LIMBS) { v1 = u32(x.l[base_idx + 1u]); } + let s0 = v0 >> in_limb_bit; + // Shift amount for v1 is (29 - in_limb_bit). In WGSL shifts of >= 32 are + // undefined; (29 - in_limb_bit) is in [1, 29] so always < 32. + let shift_up = BY_LIMB_BITS - in_limb_bit; + let s1 = v1 << shift_up; + return (s0 | s1) & BY_SRC_LIMB_MASK; +} + +fn by_to_bigint(x: BigIntBY) -> BigInt { + var out: BigInt; + for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) { + out.limbs[i] = by_read_13bit_window(x, i * BY_SRC_LIMB_BITS); + } + return out; +} +`; + export const bigint_f32 = `// f32-limb mirror of \`bigint.template.wgsl\`. Each limb holds an // integer-valued f32 in [0, 2^WORD_SIZE_F32) = [0, W). Mirrors the // subset of \`bigint.template.wgsl\` needed by \`montgomery_product_f32\` @@ -920,6 +1263,94 @@ fn add_points_mixed_no_collision(p1: Point, p2: Point) -> Point { } `; +export const apply_matrix_bench = `// Single-thread-per-input bench shader for \`by_apply_matrix_fg\` / +// \`by_apply_matrix_de\`. Each thread runs both passes on one input record +// and writes the resulting 4 × 9 limbs (f', g', d', e') as 36 i32 outputs. +// +// INPUT LAYOUT (44 u32 per thread): +// Offset 0..7: Mat fields { u, v, q, r, u_hi, v_hi, q_hi, r_hi } as i32. +// Offset 8..16: f limbs (9 × i32) +// Offset 17..25: g limbs (9 × i32) +// Offset 26..34: d limbs (9 × i32) +// Offset 35..43: e limbs (9 × i32) +// All limbs are signed-canonical (lower limbs in [0, 2^29), top limb signed). 
+// +// OUTPUT LAYOUT (36 i32 per thread): +// Offset 0..8: f' limbs (9 × i32) — output of by_apply_matrix_fg +// Offset 9..17: g' limbs (9 × i32) — output of by_apply_matrix_fg +// Offset 18..26: d' limbs (9 × i32) — output of by_apply_matrix_de +// Offset 27..35: e' limbs (9 × i32) — output of by_apply_matrix_de +// +// CONSTANTS (Mustache injected): +// p_limbs_by: BigIntBY initializer for the BN254 base-field modulus p. +// p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58. +// p_inv_by_hi: high 32 bits of P_INV (only the low 26 bits are non-zero). +// +// LOOP BOUNDS: only inherited loops — by_apply_matrix_fg and +// by_apply_matrix_de each use a single \`for\` bounded by BY_NUM_LIMBS +// (= 9u const). The dispatch entry shader itself has no loops. +// +// SAFETY: \`if (tid >= n) return;\` static guard, no data-dependent loops. + +@group(0) @binding(0) var inputs: array; +@group(0) @binding(1) var outputs: array; +@group(0) @binding(2) var params: vec2; + +const APPLY_MATRIX_INPUT_STRIDE: u32 = 44u; // 8 (Mat) + 4 * 9 (limbs) +const APPLY_MATRIX_OUTPUT_STRIDE: u32 = 36u; // 4 * 9 (limbs) +const P_INV_BY_LO: u32 = {{ p_inv_by_lo }}u; +const P_INV_BY_HI: u32 = {{ p_inv_by_hi }}u; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let tid = gid.x; + if (tid >= n) { return; } + + let in_base: u32 = tid * APPLY_MATRIX_INPUT_STRIDE; + let out_base: u32 = tid * APPLY_MATRIX_OUTPUT_STRIDE; + + // Decode Mat. The 8 i32 fields land in bindings as u32 — bitcast. + let m: Mat = Mat( + bitcast(inputs[in_base + 0u]), + bitcast(inputs[in_base + 1u]), + bitcast(inputs[in_base + 2u]), + bitcast(inputs[in_base + 3u]), + bitcast(inputs[in_base + 4u]), + bitcast(inputs[in_base + 5u]), + bitcast(inputs[in_base + 6u]), + bitcast(inputs[in_base + 7u]), + ); + + // Decode f, g, d, e as BigIntBY (9 × i32 each). + var f: BigIntBY; + var g: BigIntBY; + var d: BigIntBY; + var e: BigIntBY; + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + f.l[i] = bitcast(inputs[in_base + 8u + i]); + g.l[i] = bitcast(inputs[in_base + 17u + i]); + d.l[i] = bitcast(inputs[in_base + 26u + i]); + e.l[i] = bitcast(inputs[in_base + 35u + i]); + } + + // Modulus p, fixed across all threads — Mustache-injected. + var p: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + + // Apply matrix passes. + by_apply_matrix_fg(m, &f, &g); + by_apply_matrix_de(m, &d, &e, &p, P_INV_BY_LO, P_INV_BY_HI); + + // Write outputs. + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + outputs[out_base + 0u + i] = f.l[i]; + outputs[out_base + 9u + i] = g.l[i]; + outputs[out_base + 18u + i] = d.l[i]; + outputs[out_base + 27u + i] = e.l[i]; + } +} +`; + export const barrett = `const W_MASK = {{ w_mask }}u; const SLACK = {{ slack }}u; @@ -1261,6 +1692,11 @@ var round_count: array>; @group(0) @binding(2) var dispatch_args: array; +// WINDOWS_PER_BATCH — baked at render time. Controls the inverse-pass +// Z dispatch dim: num_batches = ceil(num_subtasks / WPB). See +// batch_inverse_parallel.template.wgsl for the merged-pool inverse. 
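+// Illustrative: num_subtasks = 16 with WPB = 4 gives num_batches = 4, so
+// the inverse pass is dispatched as (num_sub_wgs, 1, 4) workgroups
+// instead of (num_sub_wgs, 1, 16).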
+const WPB: u32 = {{ windows_per_batch }}u; + // params[0] = num_subtasks // params[1] = apply_workgroup_size // params[2] = sched_x_groups (= ceil(num_columns / schedule_workgroup_size)) @@ -1275,6 +1711,7 @@ fn main() { let apply_wg_size = params[1]; let sched_x_groups = params[2]; let num_sub_wgs = params[3]; + let num_batches = (num_subtasks + WPB - 1u) / WPB; var max_count: u32 = 0u; for (var i: u32 = 0u; i < num_subtasks; i = i + 1u) { @@ -1299,13 +1736,14 @@ fn main() { dispatch_args[1] = 1u; dispatch_args[2] = num_subtasks; - // THIS round's inverse: (W, 1, T) workgroups — W sub-WGs per subtask - // splitting each subtask's pair pool into W contiguous slices, each - // independently inverted with its own fr_inv. Drops Phase A/D - // per-thread sequential cost by W. See batch_inverse_parallel for - // the algorithm. + // THIS round's inverse: (W, 1, num_batches) workgroups — W sub-WGs + // per (batch, sub_wg) splitting each batch's MERGED pair pool (of + // size sum_{w} round_count[batch * WPB + w]) into W contiguous + // slices, each independently inverted with its own fr_inv. Pooling + // amortises one fr_inv across WPB subtasks. Z dim drops from T to + // ceil(T/WPB) accordingly. let inverse_x = select(0u, num_sub_wgs, any_work); - let inverse_z = select(0u, num_subtasks, any_work); + let inverse_z = select(0u, num_batches, any_work); dispatch_args[3] = inverse_x; dispatch_args[4] = 1u; dispatch_args[5] = inverse_z; @@ -2108,6 +2546,8 @@ export const batch_inverse = `{{> structs }} {{> montgomery_product_funcs }} {{> field_funcs }} {{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} // get_r returns Montgomery R (= the integer 1 in Montgomery form). Each // main shader defines its own with \`r_limbs\` substitution; the @@ -2179,7 +2619,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } // Step 2: invert the final product. - var inv_acc: BigInt = fr_inv(acc); + var inv_acc: BigInt = fr_inv_by_a(acc); // Step 3: walk back, emitting individual inverses. for (var idx = 0u; idx < n - 1u; idx = idx + 1u) { @@ -2202,6 +2642,8 @@ export const batch_inverse_parallel = `{{> structs }} {{> montgomery_product_funcs }} {{> field_funcs }} {{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} // Parallel Montgomery batch-inverse on the GPU. // @@ -2212,74 +2654,65 @@ export const batch_inverse_parallel = `{{> structs }} // workgroup of TPB threads, dropping the per-dispatch cost from ~30 ms // to ~1-3 ms at the realistic round sizes we see. // -// MULTI-WORKGROUP MODE. The dispatch is shaped (NUM_SUB_WGS, 1, T). -// \`wid.z\` selects the subtask (count_buf[subtask]) and \`wid.x\` is the -// sub-workgroup index inside that subtask (0..NUM_SUB_WGS-1). Each -// sub-workgroup independently inverts a contiguous slice of length -// per_sub_chunk = ceil(n / NUM_SUB_WGS), using its own fr_inv. Reasoning: -// -// - Phase A (per-thread fwd) and Phase D (per-thread back-walk) are -// sequential per thread with bs = ceil(n / TPB). Splitting each -// subtask across W sub-workgroups drops bs by W, giving up to W× -// speedup on the dominant per-round latency at large N. -// -// - Each sub-workgroup runs its own fr_inv. fr_invs across sub-WGs -// run concurrently on different SMs, so the EXTRA fr_invs cost zero -// wall time (gated only by SM occupancy). With T=16 subtasks × W=8 -// sub-WGs = 128 workgroups in flight, RTX-class GPUs are fully -// occupied during the inverse pass. +// MULTI-WORKGROUP + WINDOWS-PER-BATCH MODE. 
The dispatch is shaped +// (NUM_SUB_WGS, 1, num_batches), where num_batches = ceil(T / WPB). +// \`wid.z\` selects the batch; \`wid.x\` is the sub-workgroup index inside +// the batch. Each (wid.z, wid.x) pair inverts the COMBINED pair pool of +// WPB consecutive subtasks, treating it as one big logical pool of total +// length sum_{w=0..WPB-1} count_buf[batch * WPB + w], with one +// fr_inv_by_a call per (batch, sub_wg). +// +// Amortisation: with WPB=4 we run T/WPB × W fr_inv_by_a calls per round +// instead of T × W — a 4× drop in the dominant Phase C cost. The +// per-subtask layout of the pair pool (and the \`prefix\` / \`outputs\` +// scratch buffers) is unchanged, so \`batch_affine_apply_scatter\` reads +// outputs[subtask_global * pitch + slot_in_subtask] just like before. +// The kernel "sews together" the gaps between subtask slices in the +// scan: a logical position p in the merged pool decodes to +// (subtask_in_batch, slot_in_subtask) via the per-batch exclusive +// prefix-sum \`wg_cum\`, then to a physical buffer position +// (batch * WPB + subtask_in_batch) * pitch + slot_in_subtask. +// +// Per-thread walks advance a (cur_w, cur_off) cursor as they step +// through their chunk so the per-step decode is O(1); the only WPB-loop +// is the const-bounded prefix-sum setup at the top of the kernel. // // Two clients today: -// - SMVP (cross-subtask): pitch = num_columns, count_buf[wid.z] = -// pair_counter[wid.z] (per-subtask atomic). -// - Finalize: pitch = half_num_columns, count_buf[wid.z] = -// half_num_columns for all wid.z (pre-populated by host). +// - SMVP (cross-subtask): pitch = num_columns, count_buf[g] = +// round_count[g] (per-subtask, forwarded by dispatch_args). +// - Finalize: pitch = half_num_columns, count_buf[g] = +// half_num_columns for all g (pre-populated by host). With WPB=1 in +// the finalize call this degenerates to the previous behaviour. // -// Algorithm (Montgomery's batch-inverse trick, two-level): +// Algorithm (Montgomery's batch-inverse trick, two-level, merged pool): // // 1. Each thread i computes block_inclusive_prefix[k] for k in its // chunk into the \`prefix\` scratch buffer (serial, length bs = -// ceil(n / TPB)). Captures block_total[i] in a register. +// ceil(n / TPB)). Captures block_total[i] in a register. The walk +// is over LOGICAL positions in the merged pool, but each write +// lands at the corresponding PHYSICAL buffer position (decoded via +// wg_cum). // 2. All TPB block_totals are pushed into workgroup memory. // 3. Two parallel inclusive scans over wg memory: forward (wg_fwd) // and backward (wg_bwd). Hillis-Steele, log2(TPB) passes. -// After scan: -// wg_fwd[i] = block_total[0] * ... * block_total[i] -// wg_bwd[i] = block_total[i] * ... * block_total[TPB-1] -// So global_total = wg_fwd[TPB-1] = wg_bwd[0]. -// 4. Thread 0 computes inv_global = inv(global_total) — single fr_inv. -// Broadcast via wg_inv_total. -// 5. Each thread walks back through its chunk. The key trick: within -// a single block, block_excl_prefix cancels algebraically, so the -// back-walk reduces to the standard 2-mul/element batch inverse -// on the per-block prefix array. 
-// -// inv(global_prefix[k]) · global_prefix[k-1] -// = inv(block_excl_prefix · P_in[k]) · (block_excl_prefix · P_in[k-1]) -// = inv(P_in[k]) · P_in[k-1] // block_excl_prefix cancels -// -// So: -// block_excl_prefix[i] = wg_fwd[i-1] (or R for i=0) -// block_excl_suffix[i] = wg_bwd[i+1] (or R for i=TPB-1) -// // Setup: inv(block_total) = inv_global · block_excl_prefix · block_excl_suffix -// inv_acc = inv(block_total[i]) // 2 muls (setup, one-time) -// For k from chunk_end-1 down to chunk_start: -// out[k] = inv_acc · (k>chunk_start ? prefix[k-1] : R) // 1 mul (or 0) -// inv_acc = inv_acc · inputs[k] // 1 mul -// -// Cost: 2N muls in back-walk + N muls in forward + 2 muls setup -// = 3N + O(1) per workgroup. Previously 3N + N back-walk muls -// (extra mul-by-block_excl_prefix per element) = 4N total. +// 4. Thread 0 computes inv_global = inv(global_total) — single +// fr_inv_by_a per (batch, sub_wg). Broadcast via wg_inv_total. +// 5. Each thread walks back through its chunk. Within a single block, +// block_excl_prefix cancels algebraically, so the back-walk runs +// as the standard 2-mul/element batch inverse on the per-block +// prefix array (same as before). Outputs land at physical buffer +// positions matching the input layout. // // TPB = 64 keeps workgroup memory at 2 * 64 * sizeof(BigInt) = 10240 // bytes (BN254: BigInt = 80 bytes), comfortably under the WebGPU // default maxComputeWorkgroupStorageSize = 16384. // -// When n == 0 for this workgroup the kernel returns immediately at the -// top. +// When total_n == 0 for this (batch, sub_wg) the kernel returns +// immediately at the top. const TPB: u32 = 64u; const NUM_SUB_WGS: u32 = {{ num_sub_wgs }}u; +const WPB: u32 = {{ windows_per_batch }}u; @group(0) @binding(0) var inputs: array; @@ -2293,7 +2726,7 @@ var outputs: array; @group(0) @binding(3) var count_buf: array>; -// params[0] = pitch (per-workgroup slice stride) +// params[0] = pitch (per-subtask slice stride inside inputs/prefix/outputs) // params[1..3] = unused @group(0) @binding(4) var params: vec4; @@ -2307,19 +2740,44 @@ fn get_r() -> BigInt { var wg_fwd: array; var wg_bwd: array; var wg_inv_total: BigInt; -// Broadcast slot for the pair count. atomicLoad returns a non-uniform -// value as far as the WGSL uniformity analysis is concerned, so the -// downstream \`workgroupBarrier()\`s would be ill-formed if we branched -// directly on it. Funnelling it through a workgroup variable + an -// explicit \`workgroupUniformLoad\` re-uniforms the value (and acts as -// an implicit barrier). -var wg_n: u32; -// Per-sub-WG element offset into the subtask's slice. Set by tid 0 -// alongside wg_n; broadcast through workgroup memory so every thread -// agrees on the slice base. (workgroupUniformLoad over wg_n implicitly -// synchronises this write too, since both are written before the load.) +// Per-batch exclusive prefix of subtask counts: wg_cum[w] is the +// cumulative count of subtasks 0..w-1 within this batch (so wg_cum[0] +// is always 0 and the merged-pool total length is +// wg_cum[WPB-1] + wg_counts[WPB-1]). Used to decode "logical position +// p in the merged pool" → "(subtask_in_batch w, slot s in that +// subtask)" via the largest w with wg_cum[w] <= p. +var wg_counts: array; +var wg_cum: array; +// Broadcast slot for the merged-pool total count. atomicLoad returns a +// non-uniform value as far as the WGSL uniformity analysis is +// concerned, so the downstream \`workgroupBarrier()\`s would be +// ill-formed if we branched directly on it. 
Funnelling it through a +// workgroup variable + an explicit \`workgroupUniformLoad\` re-uniforms +// the value (and acts as an implicit barrier). +var wg_total_n: u32; +// Per-sub-WG element offset into the merged pool. Set by tid 0 +// alongside wg_total_n; broadcast through workgroup memory so every +// thread agrees on the slice base. (workgroupUniformLoad over +// wg_total_n implicitly synchronises this write too, since both are +// written before the load.) var wg_sub_offset: u32; +// Decode a logical position p in the merged batch pool into +// (subtask_in_batch, slot_in_subtask). The WPB-loop is const-bounded. +struct PoolPos { w: u32, slot: u32 }; +fn decode_pool_pos(p: u32) -> PoolPos { + var out: PoolPos; + out.w = 0u; + out.slot = p; + for (var i = 1u; i < WPB; i = i + 1u) { + if (p >= wg_cum[i]) { + out.w = i; + out.slot = p - wg_cum[i]; + } + } + return out; +} + @compute @workgroup_size(64) fn main( @@ -2328,15 +2786,27 @@ fn main( ) { let tid = lid.x; let pitch = params[0]; - let subtask_idx = wid.z; + let batch_idx = wid.z; let sub_idx = wid.x; + let subtask_base = batch_idx * WPB; - // Thread 0 reads the subtask's atomic count and the sub-WG's slice - // bounds; broadcast via wg_n. workgroupUniformLoad ensures all - // threads see a uniform \`n\` (= this sub-WG's element count) and - // synchronises the workgroup before continuing. + // Thread 0 reads the WPB per-subtask atomic counts in this batch, + // computes the exclusive prefix-sum into wg_cum + wg_counts, and + // resolves the sub-WG's slice bounds inside the merged pool. + // Broadcasts via wg_total_n + wg_sub_offset. workgroupUniformLoad + // over wg_total_n then re-uniforms across the workgroup. if (tid == 0u) { - let total_n = atomicLoad(&count_buf[subtask_idx]); + var cum: u32 = 0u; + for (var w = 0u; w < WPB; w = w + 1u) { + let g = subtask_base + w; + // count_buf is sized to a multiple of WPB so the tail loads + // here return the host's initial zero (no out-of-bounds). + let c = atomicLoad(&count_buf[g]); + wg_counts[w] = c; + wg_cum[w] = cum; + cum = cum + c; + } + let total_n = cum; // Split [0, total_n) into NUM_SUB_WGS contiguous chunks. Chunks // are sized ceil(total_n / NUM_SUB_WGS); trailing sub-WGs may // see a shorter (or empty) chunk if total_n isn't a multiple of @@ -2351,39 +2821,70 @@ fn main( sub_n = clamped_end - raw_start; } wg_sub_offset = raw_start; - wg_n = sub_n; + wg_total_n = sub_n; } - let n = workgroupUniformLoad(&wg_n); + let n = workgroupUniformLoad(&wg_total_n); if (n == 0u) { return; } - let sub_offset_in_subtask = wg_sub_offset; - - // Subtask owns a slice of \`pitch\` elements at offset - // subtask_idx * pitch. This sub-WG owns - // [subtask_offset + sub_offset_in_subtask, - // subtask_offset + sub_offset_in_subtask + n). - let slice_offset = subtask_idx * pitch + sub_offset_in_subtask; + let sub_offset_in_batch = wg_sub_offset; // Block size = ceil(n / TPB). Last thread may have a shorter chunk. let bs = (n + TPB - 1u) / TPB; - let chunk_start = tid * bs; - var chunk_end = chunk_start + bs; - if (chunk_end > n) { - chunk_end = n; + let chunk_start_in_sub = tid * bs; + var chunk_end_in_sub = chunk_start_in_sub + bs; + if (chunk_end_in_sub > n) { + chunk_end_in_sub = n; } + // Translate this thread's [chunk_start, chunk_end) window from + // sub-WG-local indexing into batch-merged-pool indexing. 
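+    // A worked example of the wg_cum decode used by Phases A/D below
+    // (illustrative, made-up counts): with WPB = 4 and per-subtask counts
+    // [3, 0, 5, 2], wg_cum = [0, 3, 3, 8] and the merged pool holds 10
+    // logical slots. Logical position p = 6 picks the largest w with
+    // wg_cum[w] <= p, i.e. w = 2 and slot = 6 - wg_cum[2] = 3, landing at
+    // physical position (subtask_base + 2u) * pitch + 3u.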
+ let chunk_start = sub_offset_in_batch + chunk_start_in_sub; + let chunk_end = sub_offset_in_batch + chunk_end_in_sub; + // Phase A: per-thread inclusive prefix product over the chunk. - // block_total = product of all elements in this thread's chunk, - // or R (Montgomery 1) if the chunk is empty. + // Each k iterates over LOGICAL positions in the merged pool but + // every read/write lands at the PHYSICAL buffer position decoded + // through wg_cum. block_total = product of all elements in this + // thread's chunk, or R (Montgomery 1) if the chunk is empty. var block_total: BigInt = get_r(); - if (chunk_start < n) { - var acc: BigInt = inputs[slice_offset + chunk_start]; - prefix[slice_offset + chunk_start] = acc; + if (chunk_start_in_sub < n) { + // Decode the starting logical position once, then carry + // (cur_w, cur_off) forward through the chunk so the per-step + // decode is O(1). cur_off is the slot within the current + // subtask; when it reaches wg_counts[cur_w] we advance to the + // next subtask. + let start_pos = decode_pool_pos(chunk_start); + var cur_w: u32 = start_pos.w; + var cur_off: u32 = start_pos.slot; + let g0 = subtask_base + cur_w; + let phys0 = g0 * pitch + cur_off; + var acc: BigInt = inputs[phys0]; + prefix[phys0] = acc; + // Advance one step past the seed (k = chunk_start) and walk to + // chunk_end. The k variable here is the BATCH-MERGED logical + // position; advance cur_off in lockstep, rolling cur_w + cur_off + // across the wg_counts[cur_w] boundary when the subtask runs out. for (var k = chunk_start + 1u; k < chunk_end; k = k + 1u) { - var x: BigInt = inputs[slice_offset + k]; + cur_off = cur_off + 1u; + // Bounded rollover: cross at most one boundary per step + // (counts never exceed pitch, which exceeds bs in practice). + // The const-bounded loop guards against any degenerate + // input where wg_counts[cur_w] could be zero — we still + // can't iterate past WPB steps because cur_w is in [0, WPB). + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (cur_off >= wg_counts[cur_w]) { + cur_w = cur_w + 1u; + cur_off = 0u; + } else { + break; + } + } + let g = subtask_base + cur_w; + let phys = g * pitch + cur_off; + var x: BigInt = inputs[phys]; acc = montgomery_product(&acc, &x); - prefix[slice_offset + k] = acc; + prefix[phys] = acc; } block_total = acc; } @@ -2413,14 +2914,16 @@ fn main( } // Phase C: thread 0 inverts the global total. Broadcast via workgroup mem. + // One fr_inv_by_a per (batch, sub_wg), independent of WPB — that's + // the amortisation this layout buys. if (tid == 0u) { var global_total: BigInt = wg_fwd[TPB - 1u]; - wg_inv_total = fr_inv(global_total); + wg_inv_total = fr_inv_by_a(global_total); } workgroupBarrier(); // Phase D: walk back through this thread's chunk, emitting inverses. - if (chunk_start >= n) { + if (chunk_start_in_sub >= n) { return; } @@ -2445,6 +2948,13 @@ fn main( var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + // Decode the END position (chunk_end - 1) and walk backward, + // mirroring Phase A's forward decode. Maintain (cur_w, cur_off) as + // we step k from chunk_end-1 down to chunk_start. + let end_pos = decode_pool_pos(chunk_end - 1u); + var cur_w: u32 = end_pos.w; + var cur_off: u32 = end_pos.slot; + // Walk from k = chunk_end-1 down to chunk_start. 
Within a block, // block_excl_prefix cancels algebraically (see header comment), so // we run the standard backward batch-inverse over the per-block @@ -2453,23 +2963,288 @@ fn main( while (k > chunk_start) { k = k - 1u; - // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself for k = chunk_start) + let g = subtask_base + cur_w; + let phys = g * pitch + cur_off; + + // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself + // for k = chunk_start). The "k-1" position in the BLOCK is the + // previous logical position in the merged pool; if that lies + // in a different subtask, prefix is still indexed by physical + // position so we decode the neighbour separately. var inv_a_k: BigInt; if (k > chunk_start) { - var prev_in_block: BigInt = prefix[slice_offset + k - 1u]; + // Compute the physical position of (k - 1). This is the + // logical predecessor of the current step. We carry a + // dedicated (prev_w, prev_off) cursor by mirroring the + // forward-from-decode but starting from (cur_w, cur_off): + // step back by one slot; if cur_off was 0, roll back to + // the previous subtask's last slot. + var prev_w: u32 = cur_w; + var prev_off: u32; + if (cur_off > 0u) { + prev_off = cur_off - 1u; + } else { + // Roll back to the largest preceding subtask with a + // non-empty count. Const-bounded by WPB. + prev_off = 0u; + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (prev_w == 0u) { + // Should not happen — we'd be past chunk_start. + break; + } + prev_w = prev_w - 1u; + if (wg_counts[prev_w] > 0u) { + prev_off = wg_counts[prev_w] - 1u; + break; + } + } + } + let g_prev = subtask_base + prev_w; + let phys_prev = g_prev * pitch + prev_off; + var prev_in_block: BigInt = prefix[phys_prev]; inv_a_k = montgomery_product(&inv_acc, &prev_in_block); } else { inv_a_k = inv_acc; } - outputs[slice_offset + k] = inv_a_k; + outputs[phys] = inv_a_k; // Update: inv_acc <- inv_acc * a[k] = inv(prefix_in_block[k-1]) // for the next iteration. The update on the last iteration // (k = chunk_start) is wasted — minor cost, kept for code // simplicity; the loop exit condition skips it naturally. if (k > chunk_start) { - var a_k: BigInt = inputs[slice_offset + k]; + var a_k: BigInt = inputs[phys]; inv_acc = montgomery_product(&inv_acc, &a_k); + // Step (cur_w, cur_off) back by one logical position, same + // rollover logic as the prev_* decode above. + if (cur_off > 0u) { + cur_off = cur_off - 1u; + } else { + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (cur_w == 0u) { + break; + } + cur_w = cur_w - 1u; + if (wg_counts[cur_w] > 0u) { + cur_off = wg_counts[cur_w] - 1u; + break; + } + } + } + } + } + + {{{ recompile }}} +} +`; + +export const bench_batch_affine = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +// Standalone single-dispatch micro-benchmark for batch-affine EC addition. +// Measures the per-pair cost of \`R_i = P_i + Q_i\` using Montgomery's +// batch-inverse trick at a fixed BATCH_SIZE, swept across a range of sizes +// by the host to find the sweet spot for amortising the single +// \`fr_inv_by_a\` call per batch. +// +// Single-workgroup-per-batch layout. Each workgroup of TPB threads +// processes one batch of BATCH_SIZE consecutive pairs. The host dispatches +// (TOTAL_PAIRS / BATCH_SIZE, 1, 1) workgroups. 
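+//
+// As a small illustration of the trick (generic 3-element case, not tied to
+// this kernel's layout): with prefix products p1 = a1, p2 = a1*a2,
+// p3 = a1*a2*a3 and a single inversion inv3 = inv(p3),
+//   inv(a3) = inv3 * p2, and inv3 * a3 = inv(p2)
+//   inv(a2) = inv(p2) * p1, and inv(p2) * a2 = inv(p1) = inv(a1)
+// i.e. two Montgomery products per element once the single inversion is paid.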
+// +// Phase A: each thread serially walks its BS = BATCH_SIZE / TPB pairs, +// computing delta_x_k = Q_k.x - P_k.x and accumulating the running +// product \`prefix[k] = prefix[k-1] * delta_x_k\` for k in +// [chunk_start, chunk_end). Captures \`block_total\` = product of +// all chunk deltas. +// +// Phase B: workgroup-shared Hillis-Steele scan (forward + backward) over +// the TPB block_totals. +// +// Phase C: thread 0 runs the SINGLE \`fr_inv_by_a\` call on the overall +// product and broadcasts the result via workgroup memory. +// +// Phase D: each thread back-walks its chunk emitting one affine add per +// step. Uses standard 2-mul/element backward batch-inverse: +// \`inv_acc = inv(block_total)\` at the start, then for each k from +// chunk_end-1 down to chunk_start: +// inv_dx = (k > chunk_start) ? inv_acc * prefix[k-1] : inv_acc +// ... affine add ... +// outputs[2k] = R.x +// outputs[2k+1] = R.y +// if (k > chunk_start) inv_acc = inv_acc * delta_x_k +// +// LOOP BOUNDS — every loop in this kernel is bounded by a compile-time +// \`const\`: +// - PHASE_A_LOOP and PHASE_D_LOOP iterate \`BS\` times (compile-time +// Mustache constant \`{{ per_thread_count }}\`). +// - PHASE_B_LOOP runs while \`stride < TPB\` (compile-time const). +// - Every inner \`montgomery_product\` / \`fr_inv_by_a\` / \`fr_sub\` loop is +// bounded by \`NUM_WORDS\` or other compile-time constants in the +// included partials. +// +// MEMORY BUDGET. Workgroup memory: 2 × TPB × sizeof(BigInt) = 2 × TPB × 80 +// bytes (BN254 BigInt = 20 × u32 = 80 B) + 1 × BigInt broadcast slot. +// At TPB=64 this is 10 KiB + 80 B, well under the 16 KiB default. + +const BATCH_SIZE: u32 = {{ batch_size }}u; +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ per_thread_count }}u; // = BATCH_SIZE / TPB + +@group(0) @binding(0) +var inputs: array; + +@group(0) @binding(1) +var prefix: array; + +@group(0) @binding(2) +var outputs: array; + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let batch_idx = wid.x; + let batch_base = batch_idx * BATCH_SIZE; + let chunk_start = tid * BS; + // chunk_end is statically BS past chunk_start since BATCH_SIZE is an + // exact multiple of TPB by construction; no clamping needed. + let chunk_end = chunk_start + BS; + + // Phase A: per-thread inclusive-prefix product over \`delta_x_k\`. + var block_total: BigInt = get_r(); + // First iter (k = chunk_start): seed acc with delta_x at that slot. + { + let k = chunk_start; + let pair_base = (batch_base + k) * 4u; + var p_x: BigInt = inputs[pair_base + 0u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var dx: BigInt = fr_sub(&q_x, &p_x); + // prefix slot is per-WG-relative inside this batch's slice. + prefix[batch_base + k] = dx; + block_total = dx; + } + // Subsequent iters: const-bounded loop over the remaining BS-1 slots. + // PHASE_A_LOOP: bound = BS (compile-time constant). 
+ for (var i = 1u; i < BS; i = i + 1u) { + let k = chunk_start + i; + let pair_base = (batch_base + k) * 4u; + var p_x: BigInt = inputs[pair_base + 0u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix[batch_base + k] = block_total; + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B: Hillis-Steele forward + backward scans on the TPB + // block_totals. PHASE_B_LOOP: bound = TPB (compile-time constant). + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: BigInt = wg_fwd[tid]; + if (tid >= stride) { + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); + } + var bwd_x: BigInt = wg_bwd[tid]; + if (tid + stride < TPB) { + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C: thread 0 inverts the global total via fr_inv_by_a. + if (tid == 0u) { + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); + } + workgroupBarrier(); + + // Setup: inv_acc = inv(block_total[tid]) + // = inv_global * block_excl_prefix * block_excl_suffix + var block_excl_prefix: BigInt = get_r(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: BigInt = get_r(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + + // Phase D: back-walk this thread's chunk, emitting affine adds. + // PHASE_D_LOOP: bound = BS (compile-time constant). Use the + // forward-form \`for (off = 0u; off < BS; ...)\` and derive the + // descending index k = chunk_end - 1 - off so the bound is provably + // static at compile time. NO \`while\`/\`loop\` constructs. + for (var off = 0u; off < BS; off = off + 1u) { + let k = chunk_start + (BS - 1u - off); + let pair_base = (batch_base + k) * 4u; + + // Load all four pair coords once. + var p_x: BigInt = inputs[pair_base + 0u]; + var p_y: BigInt = inputs[pair_base + 1u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var q_y: BigInt = inputs[pair_base + 3u]; + + // Backward batch-inverse: inv_dx_k = inv_acc * prefix[k-1] for + // k > chunk_start, or inv_acc itself at the chunk's start slot. + var inv_dx: BigInt; + if (k > chunk_start) { + var prev_prefix: BigInt = prefix[batch_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); + } else { + inv_dx = inv_acc; + } + + // Affine add (Montgomery form): + // λ = (Q.y - P.y) * inv_dx + // R.x = λ^2 - P.x - Q.x + // R.y = λ * (P.x - R.x) - P.y + var dy: BigInt = fr_sub(&q_y, &p_y); + var slope: BigInt = montgomery_product(&dy, &inv_dx); + var slope_sq: BigInt = montgomery_product(&slope, &slope); + var t1: BigInt = fr_sub(&slope_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&slope, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); + + // Output: 2 BigInts per pair. + let out_base = (batch_base + k) * 2u; + outputs[out_base + 0u] = r_x; + outputs[out_base + 1u] = r_y; + + // Update inv_acc for the next (smaller-k) iteration, unless we + // just emitted the chunk-start slot. 
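+    // Descriptive note: at the top of each iteration inv_acc holds
+    // inv(delta_x[chunk_start] * ... * delta_x[k]), which is why
+    // inv_acc * prefix[k-1] collapses to inv(delta_x[k]) above, and why
+    // multiplying by dx_k below restores that invariant for k - 1.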
+ if (k > chunk_start) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); } } @@ -2500,9 +3275,14 @@ var g_points_y: array; @group(0) @binding(5) var g_points_z: array; -// Unfiform storage buffer. +// Uniform storage buffer. Layout: (subtask_idx_base, num_columns, +// num_subtasks_per_bpr, num_subtasks_total). The 4th slot is the +// dispatch's bounds-check ceiling for the WPB tail batch — values past +// num_subtasks_total are skipped per-thread inside the multi-window +// loop. Promoted from vec3 to vec4 so the host can pass a single +// uniform across both legacy (WPB=1) and multi-window dispatches. @group(0) @binding(6) -var params: vec3; +var params: vec4; {{#capture_debug}} // Debug capture for stage_2's inline add-2007-bl formula. Layout: 8 BigInts @@ -2628,15 +3408,29 @@ fn bench_xor_into(a: Point, b: Point) -> Point { return r; } +// Multi-window BPR (WPB > 1) trades thread count for per-thread work: +// fewer workgroups dispatched, each thread runs the WPB-window inner loop +// {{ windows_per_batch }}u times (one per subtask in the batch). Saves +// total launches and amortises kernel header / dispatch overhead, but +// inflates per-thread register pressure — a single subtask already +// carries ~10 BigInt locals (~800B/thread) inside the add_points formula, +// so WPB > 1 risks spilling on Apple Silicon SIMD-groups (32 KB shared +// register file). When WPB=1 the inner loop runs once and the dispatch +// shape collapses to the legacy (X=num_subtasks, 256) layout. @compute @workgroup_size({{ workgroup_size }}) -fn stage_1(@builtin(global_invocation_id) global_id: vec3) { - let thread_id = global_id.x; +fn stage_1( + @builtin(global_invocation_id) global_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let thread_id = local_id.x; let num_threads_per_subtask = {{ workgroup_size }}u; - let subtask_idx = params[0]; // 0, 2, 4, 6, ... + let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter) let num_columns = params[1]; // 65536 - let num_subtasks_per_bpr = params[2]; // Must be a power of 2 + let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy) + let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1 /// Number of buckets per subtask. let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768 @@ -2646,17 +3440,30 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) { let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr; /// Number of buckets to reduce per thread. - let buckets_per_thread = num_buckets_per_subtask / + let buckets_per_thread = num_buckets_per_subtask / num_threads_per_subtask; - let multiplier = subtask_idx + (thread_id / num_threads_per_subtask); - let offset = num_buckets_per_subtask * multiplier; + // WPB-aware subtask iteration. Each workgroup owns WPB consecutive + // subtasks; thread \`thread_id\` processes the same in-subtask slice + // for every one of them. With WPB=1 the outer loop runs once and + // the semantics are byte-identical to the legacy dispatch. + for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) { + let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local; + let subtask_idx = subtask_idx_base + subtask_in_dispatch; + if (subtask_idx >= num_subtasks_total) { + // Tail batch: skip out-of-range subtasks (num_subtasks may + // not be a multiple of WPB). 
+ continue; + } - var idx = offset; - if (thread_id % num_threads_per_subtask != 0u) { - idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * - buckets_per_thread + offset; - } + let multiplier = subtask_idx; + let offset = num_buckets_per_subtask * multiplier; + + var idx = offset; + if (thread_id % num_threads_per_subtask != 0u) { + idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * + buckets_per_thread + offset; + } {{#bench_compute_only}} // Microbench: synthesize the initial m so this load is also stripped. @@ -2931,40 +3738,46 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) { {{/mixed_safe_buckets}} {{/assume_affine_buckets}} - let t = (subtask_idx / num_subtasks_per_bpr) * - (num_threads_per_subtask * num_subtasks_per_bpr) + - thread_id; + let t = (subtask_idx / num_subtasks_per_bpr) * + (num_threads_per_subtask * num_subtasks_per_bpr) + + thread_id; {{#bench_skip_writes}} - // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points). - // Funnel one observable write per thread into the workgroup atomic - // sink so the WGSL compiler can't DCE the inner loop. atomicXor is - // the cheapest read-modify-write that can't be folded; it produces - // 1 LDS op per thread vs 6 storage writes in the production path. - atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t); + // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points). + // Funnel one observable write per thread into the workgroup atomic + // sink so the WGSL compiler can't DCE the inner loop. atomicXor is + // the cheapest read-modify-write that can't be folded; it produces + // 1 LDS op per thread vs 6 storage writes in the production path. + atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t); {{/bench_skip_writes}} {{^bench_skip_writes}} - bucket_sum_x[idx] = m.x; - bucket_sum_y[idx] = m.y; - bucket_sum_z[idx] = m.z; + bucket_sum_x[idx] = m.x; + bucket_sum_y[idx] = m.y; + bucket_sum_z[idx] = m.z; - g_points_x[t] = g.x; - g_points_y[t] = g.y; - g_points_z[t] = g.z; + g_points_x[t] = g.x; + g_points_y[t] = g.y; + g_points_z[t] = g.z; {{/bench_skip_writes}} + } // end of w_local loop (WPB-aware multi-window outer) {{{ recompile }}} } @compute @workgroup_size({{ workgroup_size }}) -fn stage_2(@builtin(global_invocation_id) global_id: vec3) { - let thread_id = global_id.x; +fn stage_2( + @builtin(global_invocation_id) global_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let thread_id = local_id.x; let num_threads_per_subtask = {{ workgroup_size }}u; - let subtask_idx = params[0]; // 0, 2, 4, 6, ... + let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter) let num_columns = params[1]; // 65536 - let num_subtasks_per_bpr = params[2]; // Must be a power of 2 + let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy) + let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1 /// Number of buckets per subtask. let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768 @@ -2974,27 +3787,35 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) { let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr; /// Number of buckets to reduce per thread. 
- let buckets_per_thread = num_buckets_per_subtask / + let buckets_per_thread = num_buckets_per_subtask / num_threads_per_subtask; - let multiplier = subtask_idx + (thread_id / num_threads_per_subtask); - let offset = num_buckets_per_subtask * multiplier; + // WPB-aware subtask iteration. See stage_1 for the rationale. + for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) { + let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local; + let subtask_idx = subtask_idx_base + subtask_in_dispatch; + if (subtask_idx >= num_subtasks_total) { + continue; + } + + let multiplier = subtask_idx; + let offset = num_buckets_per_subtask * multiplier; - var idx = offset; - if (thread_id % num_threads_per_subtask != 0u) { - idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * - buckets_per_thread + offset; - } + var idx = offset; + if (thread_id % num_threads_per_subtask != 0u) { + idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * + buckets_per_thread + offset; + } - var m = load_bucket_sum(idx); + var m = load_bucket_sum(idx); - let t = (subtask_idx / num_subtasks_per_bpr) * - (num_threads_per_subtask * num_subtasks_per_bpr) + - thread_id; - var g = load_g_point(t); + let t = (subtask_idx / num_subtasks_per_bpr) * + (num_threads_per_subtask * num_subtasks_per_bpr) + + thread_id; + var g = load_g_point(t); - let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u); - let sm: Point = double_and_add(m, s); + let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u); + let sm: Point = double_and_add(m, s); // The add-2007-bl formula is inlined here (rather than calling // add_points(g, sm)) because on Dawn/Metal the function-return slot for @@ -3082,16 +3903,17 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) { } } - g_points_x[t] = X3_out; - g_points_y[t] = Y3_out; - g_points_z[t] = Z3_out; + g_points_x[t] = X3_out; + g_points_y[t] = Y3_out; + g_points_z[t] = Z3_out; - {{#capture_debug}} - // Final values of the formula outputs, read from outer scope. - debug_capture[thread_id * 8u + 4u] = X3_out; - debug_capture[thread_id * 8u + 5u] = Y3_out; - debug_capture[thread_id * 8u + 7u] = Z3_out; - {{/capture_debug}} + {{#capture_debug}} + // Final values of the formula outputs, read from outer scope. + debug_capture[thread_id * 8u + 4u] = X3_out; + debug_capture[thread_id * 8u + 5u] = Y3_out; + debug_capture[thread_id * 8u + 7u] = Z3_out; + {{/capture_debug}} + } // end of w_local loop (WPB-aware multi-window outer) {{{ recompile }}} } @@ -3622,6 +4444,52 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } `; +export const divsteps_bench = `// Single-thread-per-input bench shader for \`by_divsteps\`. +// +// Each thread: +// 1. Reads (f_lo, g_lo) as a vec4 from \`inputs_fg[tid]\` +// (= f_lo.x, f_lo.y, g_lo.x, g_lo.y). +// 2. Reads the initial delta from \`inputs_delta[tid]\`. +// 3. Calls \`by_divsteps(&delta, f_lo, g_lo)\`. +// 4. Writes the 8 Mat fields + the updated delta (9 i32) into \`outputs[tid]\`. +// +// LOOP BOUNDS +// The only loops in the included partials are bounded by WGSL \`const\`s +// (BY_NUM_LIMBS, BY_BATCH) or Mustache-const values (\`{{ num_words }}\`). +// No loop is added in this entry shader. The \`if (tid >= n)\` guard is not +// a loop bound. 
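+//
+// Host-side packing sketch (illustrative; assumes the low 64 bits are held as
+// JS BigInts on the host, and the buffer/helper names here are hypothetical):
+//   fg_u32[4 * tid + 0] = Number(f_lo & 0xFFFFFFFFn); // f_lo.x (low 32)
+//   fg_u32[4 * tid + 1] = Number(f_lo >> 32n);        // f_lo.y (high 32)
+//   fg_u32[4 * tid + 2] = Number(g_lo & 0xFFFFFFFFn); // g_lo.x
+//   fg_u32[4 * tid + 3] = Number(g_lo >> 32n);        // g_lo.y
+//   delta_i32[tid]      = initial_delta;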
+ +@group(0) @binding(0) var inputs_fg: array>; +@group(0) @binding(1) var inputs_delta: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let tid = gid.x; + if (tid >= n) { return; } + + let fg = inputs_fg[tid]; + let f_lo: vec2 = vec2(fg.x, fg.y); + let g_lo: vec2 = vec2(fg.z, fg.w); + var delta: i32 = inputs_delta[tid]; + + let m: Mat = by_divsteps(&delta, f_lo, g_lo); + + let base: u32 = tid * 9u; + outputs[base + 0u] = m.u; + outputs[base + 1u] = m.v; + outputs[base + 2u] = m.q; + outputs[base + 3u] = m.r; + outputs[base + 4u] = m.u_hi; + outputs[base + 5u] = m.v_hi; + outputs[base + 6u] = m.q_hi; + outputs[base + 7u] = m.r_hi; + outputs[base + 8u] = delta; +} +`; + export const extract_word_from_bytes_le = `fn extract_word_from_coord_bytes_le( input: array, word_idx: u32, @@ -3757,6 +4625,58 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } `; +export const fr_inv_bench = `// Per-thread chained-inversion micro-benchmark for the field inverse. +// +// Each thread reads one BN254 base-field value \`a\` (in Montgomery form) and +// runs \`k\` chained calls to the inverse function selected by the +// \`{{ inv_fn }}\` Mustache substitution (default \`fr_inv_by\`; the legacy +// \`fr_inv\` Pornin jumpy K=12 safegcd is the other supported variant): +// a <- {{ inv_fn }}(a) repeated k times, +// writing the final \`a\` back to \`outputs[tid]\`. The host caps k at 100 +// before dispatch. +// +// LOOP BOUNDS +// - The only data-dependent loop in this entry shader is +// \`for (var i = 0u; i < k; i = i + 1u)\` with \`k = params.y\` host-capped +// at 100. Every loop in the selected inverse function (and its callees) +// is bounded by a compile-time \`const\` — see by_inverse.template.wgsl, +// bigint_by, fr_pow.template.wgsl, and the inner \`montgomery_product\` +// (\`for i, j < NUM_WORDS\`). +// - The \`if (tid >= n)\` guard is not a loop bound. +// +// Bind layout mirrors \`field_mul_bench_u32\`: a single inputs buffer \`xs\` +// (no \`ys\`, since inversion is unary) and an outputs buffer. \`params.x\` is +// \`n\` (threads), \`params.y\` is \`k\` (chain length). + +// get_r returns Montgomery R (= the integer 1 in Montgomery form). Defined +// per-entry shader by every cuzk pipeline kernel that uses fr_pow / fr_inv. +// fr_pow_funcs references it via fr_pow; not called on fr_inv_by's hot path +// but kept for binary compatibility with the partial. +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var outputs: array; +@group(0) @binding(2) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + for (var i = 0u; i < k; i = i + 1u) { + a = {{ inv_fn }}(a); + } + outputs[tid] = a; +} +`; + export const horner_reduce_bn254 = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} @@ -4415,6 +5335,1754 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { {{{ recompile }}} }`; +export const by_inverse = `// Bernstein-Yang safegcd inversion for the BN254 base field, WGSL port. +// +// This file will grow over sub-steps 1.3-1.5 of the WebGPU MSM rewrite plan +// to host the full \`fr_inv_by\` entrypoint. 
Currently it contains the inner +// \`by_divsteps\` primitive (sub-step 1.3) — a line-for-line transliteration +// of \`Wasm9x29::divsteps\` (bernstein_yang_inverse_wasm.hpp lines 147-178). +// +// REFERENCES +// - TS port (ground truth): src/msm_webgpu/cuzk/bernstein_yang.ts +// - bigint helpers used here: src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl +// +// REPRESENTATIONS +// - delta: i32 counter (mirrors the i64 in C++; only the low value +// matters because BATCH=58 caps |delta| growth per call). +// - f_lo, g_lo: u64 carriers held as vec2 (.x = low 32, .y = high 32). +// All u64 ops via u64_add / u64_sub / u64_shr1 helpers. +// - u, v, q, r: i64 matrix entries held as paired i32 (.x = low 32 unsigned +// bit pattern, .y = high 32 signed). After BATCH divsteps +// |entry| <= 2^58, fits in i64. All ops via i64_*_pair. +// +// LOOP BOUND DISCIPLINE +// Three loops total in this file: +// - by_divsteps inner loop: \`for ... i < BY_BATCH\` (= 58u const) +// - by_apply_matrix_fg streaming: \`for ... i < BY_NUM_LIMBS\` (= 9u const) +// - by_apply_matrix_de streaming: \`for ... i < BY_NUM_LIMBS\` (= 9u const) +// Both bounds are compile-time WGSL \`const\`s defined in bigint_by, so the +// plan's "bounded loops" rule is satisfied. The audit grep +// \`grep -nE 'for *\\(' ... | grep -v -E '< [A-Z][A-Z_]*[a-z]?|< [0-9]+|...'\` +// returns no matches. + +// 2x2 matrix produced by BATCH=58 divsteps. Each entry is a 64-bit signed +// integer stored as a paired (lo: i32, hi: i32) — value = u32(lo) | (i32(hi) << 32), +// interpreted as two's complement. The naming matches the C++ struct field +// names (m.u, m.v, m.q, m.r) suffixed with \`_hi\` for the high half. +struct Mat { + u: i32, + v: i32, + q: i32, + r: i32, + u_hi: i32, + v_hi: i32, + q_hi: i32, + r_hi: i32, +} + +// by_divsteps: run BATCH = 58 branchy divsteps on the low 64 bits of (f, g); +// returns the transition matrix M and updates \`*delta\`. +// +// Mirrors Wasm9x29::divsteps line-for-line. The branches are variable-time +// over inputs, which is acceptable here because BN254 base-field inversion +// operates on public values in the MSM pipeline. +// +// Pre: delta is the current divstep counter (signed i32 view). +// f_lo, g_lo are the low 64 bits of f and g respectively (vec2). +// Post: returns the matrix M = ((u, v), (q, r)) such that +// (f_new, g_new) = M * (f_old, g_old) / 2^BATCH +// after BATCH=58 divsteps. *delta is updated. +// +// The TS reference uses \`(g_lo - f_lo) & U64_MASK\` then \`>> 1\` to mimic the +// C++ \`(u64)(g_lo - f_lo) >> 1\` semantics: u64 wrap, then unsigned shift. +// u64_sub + u64_shr1 here is exactly that. +fn by_divsteps(delta: ptr, f_lo_in: vec2, g_lo_in: vec2) -> Mat { + var f_lo: vec2 = f_lo_in; + var g_lo: vec2 = g_lo_in; + // Matrix entries as paired i32 (i64). u = 1, v = 0, q = 0, r = 1. + var u: vec2 = vec2(1, 0); + var v: vec2 = vec2(0, 0); + var q: vec2 = vec2(0, 0); + var r: vec2 = vec2(1, 0); + var d: i32 = *delta; + for (var i: u32 = 0u; i < BY_BATCH; i = i + 1u) { + if (u64_low_bit(g_lo) != 0u) { + if (d > 0) { + // (f, g) <- (g, (g - f) >> 1) using u64 wrap-sub then unsigned >> 1. + let nf: vec2 = g_lo; + let diff: vec2 = u64_sub(g_lo, f_lo); + let ng: vec2 = u64_shr1(diff); + // (u, v, q, r) <- (q << 1, r << 1, q - u, r - v). 
+ let nu: vec2 = i64_shl1_pair(q); + let nv: vec2 = i64_shl1_pair(r); + let nq: vec2 = i64_sub_pair(q, u); + let nr: vec2 = i64_sub_pair(r, v); + f_lo = nf; + g_lo = ng; + u = nu; + v = nv; + q = nq; + r = nr; + d = 1 - d; + } else { + // g <- (g + f) >> 1; q += u; r += v; u <<= 1; v <<= 1; d += 1. + let sum: vec2 = u64_add(g_lo, f_lo); + g_lo = u64_shr1(sum); + q = i64_add_pair(q, u); + r = i64_add_pair(r, v); + u = i64_shl1_pair(u); + v = i64_shl1_pair(v); + d = d + 1; + } + } else { + // g <- g >> 1; u <<= 1; v <<= 1; d += 1. + g_lo = u64_shr1(g_lo); + u = i64_shl1_pair(u); + v = i64_shl1_pair(v); + d = d + 1; + } + } + *delta = d; + return Mat(u.x, v.x, q.x, r.x, u.y, v.y, q.y, r.y); +} + +// ============================================================ +// apply_matrix helpers +// +// \`signed_mul_split\` accepts |a| <= 2^29 (one BY limb) and |b| <= 2^31 - 1 +// and returns (lo29, hi) with a*b = lo29 + hi * 2^29. The streaming +// schoolbook below feeds it (m_*, x_limb) pairs whose products lie in +// [-2^58, 2^58], well inside the helper's contract. +// +// Each per-limb position i computes +// acc <- m_lo * x_i + m_hi * x_{i-1} + carry_in (+ k * p_i terms in de pass) +// then carry_out = acc >> 29 (arithmetic), with the masked low-29 bits +// landing at output position i - 2 (= exact >> BATCH = >> 58 = >> (2 * 29)). +// +// The signed_mul_split returns lo29 in [0, 2^29) and a signed hi. Adding +// four cross-products together can push the high half beyond i32 range +// only when |sum| exceeds ~2^31 — at the per-limb level this is bounded +// by the four-product sum of |2^58| / 2^29 + carry, well inside i32. + +// Convert a (lo29, hi) signed-product split to a full i64 (vec2). +// +// The signed_mul_split helper returns (lo29, hi) where +// value = lo29 + hi * 2^29 with lo29 in [0, 2^29) +// To add this into an i64 accumulator we need to re-express it as a +// (lo32, hi32) pair where +// value = u32(lo32) + i32(hi32) * 2^32 (two's complement) +// +// Bit layout: +// bits 0..28 of value = lo29 (from the lo29 half) +// bits 29..31 of value = bits 0..2 of \`hi\` +// bits 32..63 of value = bits 3..34 of \`hi\` (sign-extended) +// +// \`hi << 29u\` on a u32 keeps only the low 3 bits of \`hi\` after shifting, +// which yields exactly the contribution to bits 29..31. The high half is +// \`hi >> 3u\` with signed arithmetic shift, which sign-extends to fill bit +// 63 correctly when \`hi\` is negative. +fn by_split_to_i64(split: vec2) -> vec2 { + let lo32: u32 = u32(split.x) | (u32(split.y) << 29u); + let hi32: i32 = split.y >> 3u; + return vec2(i32(lo32), hi32); +} + +// Add \`m_lo * x_limb\` (a signed 58-bit product) into an i64 accumulator. +// Helper used pervasively in the streaming schoolbook below. +fn by_add_mul(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 { + let split = signed_mul_split(m_lo, x_limb); + let prod_i64 = by_split_to_i64(split); + return i64_add_pair(acc, prod_i64); +} + +// Arithmetic right shift of an i64 (vec2) by 29. +// result.lo = (u32(acc.x) >> 29) | (u32(acc.y) << 3) +// result.hi = acc.y >> 29 (signed arithmetic shift) +fn i64_ars29(acc: vec2) -> vec2 { + let lo_u: u32 = (u32(acc.x) >> 29u) | (u32(acc.y) << 3u); + let hi: i32 = acc.y >> 29u; + return vec2(i32(lo_u), hi); +} + +// Low 29 bits of an i64. Returns i32 in [0, 2^29). +fn i64_low29(acc: vec2) -> i32 { + return i32(u32(acc.x) & BY_LIMB_MASK); +} + +// u64_mul_low64: low 64 bits of an unsigned u64 * u64 product. 
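+// (Host-side reference for validation, illustrative: in JS/TS BigInt this is
+// (a * b) & ((1n << 64n) - 1n).)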
+// +// Implements via four 16x16 partials per operand half (16 partials total +// to compute the full 128-bit product), summing only the bits that +// land in the low 64. Used to evaluate \`k = ((-t) * p_inv) mod 2^58\` +// inside by_apply_matrix_de — the C++ does \`(u64)(-(i64)t) * p_inv\` and +// keeps the low 58 bits; we keep the low 64 and let the caller mask. +// +// Pre: any a, b u64 (as vec2). +// Post: low 64 bits of a*b, two's complement-equivalent under masking. +fn u64_mul_low64(a: vec2, b: vec2) -> vec2 { + // Split each u32 half into 16-bit pieces: + // a.x = a0 + a1 * 2^16, a.y = a2 + a3 * 2^16 + // b.x = b0 + b1 * 2^16, b.y = b2 + b3 * 2^16 + let MASK16: u32 = 0xFFFFu; + let a0: u32 = a.x & MASK16; + let a1: u32 = a.x >> 16u; + let a2: u32 = a.y & MASK16; + let a3: u32 = a.y >> 16u; + let b0: u32 = b.x & MASK16; + let b1: u32 = b.x >> 16u; + let b2: u32 = b.y & MASK16; + let b3: u32 = b.y >> 16u; + + // Partials landing in bits 0..15 (only one: a0*b0). + let p00: u32 = a0 * b0; + // Partials in bits 16..47 (a0*b1, a1*b0; we'll split further). + let p01: u32 = a0 * b1; + let p10: u32 = a1 * b0; + // Partials in bits 32..63. + let p02: u32 = a0 * b2; + let p20: u32 = a2 * b0; + let p11: u32 = a1 * b1; + // Partials in bits 48..79 (we only keep the part falling in [0, 64)). + let p03: u32 = a0 * b3; + let p30: u32 = a3 * b0; + let p12: u32 = a1 * b2; + let p21: u32 = a2 * b1; + // Partials in bits 64..95 (a1*b3, a3*b1, a2*b2) — discarded except for + // the part that wraps into the low 64 via the cross sums below. With + // bit-offset >= 64 the contribution is zero in the low 64. + + // Build the low 64 bits: + // bits 0..15: p00 low 16 + // bits 16..47: p00 high 16 + (p01 + p10) low 32 + // bits 32..63: carries from above + (p02 + p20 + p11) low 32 + + // ((p01 + p10) >> 16) plus higher partials' low pieces + // bits 48..63: (p03 + p30 + p12 + p21) low 16 + + // Sum bits 0..31 (the low u32 of the result). + let lo16 = p00 & MASK16; + let mid_a = (p00 >> 16u) + (p01 & MASK16) + (p10 & MASK16); + let lo_u32 = lo16 | (mid_a << 16u); + // Carry into bits 32+ from \`mid_a\` (the high part beyond 16 bits). + let mid_a_hi = mid_a >> 16u; + + // Sum bits 32..63. + // Contributions landing entirely in [32, 64): + // (p01 + p10) >> 16 (these are 32-bit values; the >> 16 lands them at bit 32) + // p02, p20, p11 (start at bit 32; whole 32 bits land in [32, 64)) + // Contributions landing partially in [32, 64) starting at bit 48: + // p03 << 16, p30 << 16, p12 << 16, p21 << 16 + let mid_b = (p01 >> 16u) + (p10 >> 16u) + p02 + p20 + p11; + let mid_c = (p03 + p30 + p12 + p21) << 16u; + let hi_u32 = mid_a_hi + mid_b + mid_c; + + return vec2(lo_u32, hi_u32); +} + +// by_apply_matrix_fg +// +// Mirrors \`Wasm9x29::apply_matrix\` lines 196-217 — the (f, g) streaming +// pass. After BATCH=58 divsteps we apply the 2x2 transition matrix M to +// (f, g) and divide by 2^58. The streamed schoolbook produces one (nf, ng) +// pair per source limb position i and writes the masked low-29 bits at +// output position i - 2 (= the exact >> 58 = >> (2 * 29) drop). +// +// PERF: inlined hot path. Replaces the four \`by_add_mul\` calls per +// accumulator with a single fused 15+14-bit partial-product schoolbook +// that sums all four products' lane-i pieces into a single i32 before +// any carry propagation. This eliminates the per-call (lo29,hi)→(lo32,hi32) +// conversion overhead and reduces 4 i64 adds to one composite extract. 
+// +// LANE PARTIAL-PRODUCT BOUND: +// Each per-limb cross product (a_l, a_h) * (b_l, b_h) yields four 28-bit +// signed pieces: pll, plh, phl, phh (each |.| < 2^28). +// For 4 products into one accumulator: per-lane sum |.| < 4 * 2^28 = 2^30 +// (fits i32 comfortably). The combined "mid" lane (plh+phl summed across +// 4 products) is |.| < 2 * 4 * 2^28 = 2^31 — still fits i32 (signed). +// +// LOOP BOUND: \`for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u)\` — const. +fn by_apply_matrix_fg(m: Mat, f: ptr, g: ptr) { + // Decompose each of the four matrix entries into low/high 29-bit halves, + // and then split each half into (15-bit signed low, 14-bit high) chunks + // for the partial-product schoolbook below. Hoisted out of the inner + // loop (loop-invariant). + let u_lo: i32 = m.u & i32(BY_LIMB_MASK); + let v_lo: i32 = m.v & i32(BY_LIMB_MASK); + let q_lo: i32 = m.q & i32(BY_LIMB_MASK); + let r_lo: i32 = m.r & i32(BY_LIMB_MASK); + let u_hi: i32 = i32((u32(m.u) >> 29u) | (u32(m.u_hi) << 3u)); + let v_hi: i32 = i32((u32(m.v) >> 29u) | (u32(m.v_hi) << 3u)); + let q_hi: i32 = i32((u32(m.q) >> 29u) | (u32(m.q_hi) << 3u)); + let r_hi: i32 = i32((u32(m.r) >> 29u) | (u32(m.r_hi) << 3u)); + + let u_lo_l: i32 = (u_lo << 17u) >> 17u; + let u_lo_h: i32 = (u_lo - u_lo_l) >> 15u; + let v_lo_l: i32 = (v_lo << 17u) >> 17u; + let v_lo_h: i32 = (v_lo - v_lo_l) >> 15u; + let q_lo_l: i32 = (q_lo << 17u) >> 17u; + let q_lo_h: i32 = (q_lo - q_lo_l) >> 15u; + let r_lo_l: i32 = (r_lo << 17u) >> 17u; + let r_lo_h: i32 = (r_lo - r_lo_l) >> 15u; + let u_hi_l: i32 = (u_hi << 17u) >> 17u; + let u_hi_h: i32 = (u_hi - u_hi_l) >> 15u; + let v_hi_l: i32 = (v_hi << 17u) >> 17u; + let v_hi_h: i32 = (v_hi - v_hi_l) >> 15u; + let q_hi_l: i32 = (q_hi << 17u) >> 17u; + let q_hi_h: i32 = (q_hi - q_hi_l) >> 15u; + let r_hi_l: i32 = (r_hi << 17u) >> 17u; + let r_hi_h: i32 = (r_hi - r_hi_l) >> 15u; + + // Streaming accumulator as i64 (lo, hi). + var cf_lo: u32 = 0u; + var cf_hi: i32 = 0; + var cg_lo: u32 = 0u; + var cg_hi: i32 = 0; + + // Previous limb 15/14-bit pre-splits (for u_hi * fp etc.). Start at 0; + // slide forward each iter to avoid re-splitting next time. + var fp_l: i32 = 0; + var fp_h: i32 = 0; + var gp_l: i32 = 0; + var gp_h: i32 = 0; + + // Single loop with conditional output: the per-iter \`if (i >= 2)\` check + // costs less than the duplicated loop body of a prologue/main split. The + // compiler can predicate the store on most GPUs. 
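+  // Example of the 15 + 14 bit signed split used above (illustrative): for
+  // x = 20000, x_l = (x << 17u) >> 17u = -12768 (the low 15 bits, sign
+  // interpreted) and x_h = (x - x_l) >> 15u = 1, so x = x_l + (x_h << 15).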
+ for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + let fi: i32 = (*f).l[i]; + let gi: i32 = (*g).l[i]; + let fi_l: i32 = (fi << 17u) >> 17u; + let fi_h: i32 = (fi - fi_l) >> 15u; + let gi_l: i32 = (gi << 17u) >> 17u; + let gi_h: i32 = (gi - gi_l) >> 15u; + + let nf_pll: i32 = u_lo_l * fi_l + v_lo_l * gi_l + u_hi_l * fp_l + v_hi_l * gp_l; + let nf_mid: i32 = + u_lo_l * fi_h + u_lo_h * fi_l + + v_lo_l * gi_h + v_lo_h * gi_l + + u_hi_l * fp_h + u_hi_h * fp_l + + v_hi_l * gp_h + v_hi_h * gp_l; + let nf_phh: i32 = u_lo_h * fi_h + v_lo_h * gi_h + u_hi_h * fp_h + v_hi_h * gp_h; + let ng_pll: i32 = q_lo_l * fi_l + r_lo_l * gi_l + q_hi_l * fp_l + r_hi_l * gp_l; + let ng_mid: i32 = + q_lo_l * fi_h + q_lo_h * fi_l + + r_lo_l * gi_h + r_lo_h * gi_l + + q_hi_l * fp_h + q_hi_h * fp_l + + r_hi_l * gp_h + r_hi_h * gp_l; + let ng_phh: i32 = q_lo_h * fi_h + r_lo_h * gi_h + q_hi_h * fp_h + r_hi_h * gp_h; + + let nf_pll_u: u32 = u32(nf_pll); + let nf_mid_u: u32 = u32(nf_mid); + let nf_phh_u: u32 = u32(nf_phh); + let nf_pll_hi: i32 = nf_pll >> 31u; + let nf_mid_hi: i32 = nf_mid >> 17u; + let nf_phh_hi: i32 = nf_phh >> 2u; + let nf_s1_lo: u32 = nf_pll_u + (nf_mid_u << 15u); + let nf_s1_c: i32 = select(0i, 1i, nf_s1_lo < nf_pll_u); + let nf_s2_lo: u32 = nf_s1_lo + (nf_phh_u << 30u); + let nf_s2_c: i32 = select(0i, 1i, nf_s2_lo < nf_s1_lo); + let nf_total_lo: u32 = nf_s2_lo + cf_lo; + let nf_total_c: i32 = select(0i, 1i, nf_total_lo < nf_s2_lo); + let nf_total_hi: i32 = nf_pll_hi + nf_mid_hi + nf_phh_hi + nf_s1_c + nf_s2_c + nf_total_c + cf_hi; + + let ng_pll_u: u32 = u32(ng_pll); + let ng_mid_u: u32 = u32(ng_mid); + let ng_phh_u: u32 = u32(ng_phh); + let ng_pll_hi: i32 = ng_pll >> 31u; + let ng_mid_hi: i32 = ng_mid >> 17u; + let ng_phh_hi: i32 = ng_phh >> 2u; + let ng_s1_lo: u32 = ng_pll_u + (ng_mid_u << 15u); + let ng_s1_c: i32 = select(0i, 1i, ng_s1_lo < ng_pll_u); + let ng_s2_lo: u32 = ng_s1_lo + (ng_phh_u << 30u); + let ng_s2_c: i32 = select(0i, 1i, ng_s2_lo < ng_s1_lo); + let ng_total_lo: u32 = ng_s2_lo + cg_lo; + let ng_total_c: i32 = select(0i, 1i, ng_total_lo < ng_s2_lo); + let ng_total_hi: i32 = ng_pll_hi + ng_mid_hi + ng_phh_hi + ng_s1_c + ng_s2_c + ng_total_c + cg_hi; + + if (i >= 2u) { + (*f).l[i - 2u] = i32(nf_total_lo & 0x1FFFFFFFu); + (*g).l[i - 2u] = i32(ng_total_lo & 0x1FFFFFFFu); + } + cf_lo = (nf_total_lo >> 29u) | (u32(nf_total_hi) << 3u); + cf_hi = nf_total_hi >> 29u; + cg_lo = (ng_total_lo >> 29u) | (u32(ng_total_hi) << 3u); + cg_hi = ng_total_hi >> 29u; + + fp_l = fi_l; fp_h = fi_h; gp_l = gi_l; gp_h = gi_h; + } + // Top finalisation: nf9 = u_hi * fp + v_hi * fp_prev + cf (only 2 products + // now, since we've consumed all the input limbs and fi=0). Same shape as + // the inner loop body but with the *_lo terms dropped. 
+ let nf9_pll: i32 = u_hi_l * fp_l + v_hi_l * gp_l; + let nf9_mid: i32 = u_hi_l * fp_h + u_hi_h * fp_l + v_hi_l * gp_h + v_hi_h * gp_l; + let nf9_phh: i32 = u_hi_h * fp_h + v_hi_h * gp_h; + let ng9_pll: i32 = q_hi_l * fp_l + r_hi_l * gp_l; + let ng9_mid: i32 = q_hi_l * fp_h + q_hi_h * fp_l + r_hi_l * gp_h + r_hi_h * gp_l; + let ng9_phh: i32 = q_hi_h * fp_h + r_hi_h * gp_h; + + let nf9_pll_u: u32 = u32(nf9_pll); + let nf9_mid_u: u32 = u32(nf9_mid); + let nf9_phh_u: u32 = u32(nf9_phh); + let nf9_pll_hi: i32 = nf9_pll >> 31u; + let nf9_mid_hi: i32 = nf9_mid >> 17u; + let nf9_phh_hi: i32 = nf9_phh >> 2u; + let nf9_s1_lo: u32 = nf9_pll_u + (nf9_mid_u << 15u); + let nf9_s1_c: i32 = select(0i, 1i, nf9_s1_lo < nf9_pll_u); + let nf9_s2_lo: u32 = nf9_s1_lo + (nf9_phh_u << 30u); + let nf9_s2_c: i32 = select(0i, 1i, nf9_s2_lo < nf9_s1_lo); + let nf9_total_lo: u32 = nf9_s2_lo + cf_lo; + let nf9_total_c: i32 = select(0i, 1i, nf9_total_lo < nf9_s2_lo); + let nf9_total_hi: i32 = nf9_pll_hi + nf9_mid_hi + nf9_phh_hi + nf9_s1_c + nf9_s2_c + nf9_total_c + cf_hi; + + let ng9_pll_u: u32 = u32(ng9_pll); + let ng9_mid_u: u32 = u32(ng9_mid); + let ng9_phh_u: u32 = u32(ng9_phh); + let ng9_pll_hi: i32 = ng9_pll >> 31u; + let ng9_mid_hi: i32 = ng9_mid >> 17u; + let ng9_phh_hi: i32 = ng9_phh >> 2u; + let ng9_s1_lo: u32 = ng9_pll_u + (ng9_mid_u << 15u); + let ng9_s1_c: i32 = select(0i, 1i, ng9_s1_lo < ng9_pll_u); + let ng9_s2_lo: u32 = ng9_s1_lo + (ng9_phh_u << 30u); + let ng9_s2_c: i32 = select(0i, 1i, ng9_s2_lo < ng9_s1_lo); + let ng9_total_lo: u32 = ng9_s2_lo + cg_lo; + let ng9_total_c: i32 = select(0i, 1i, ng9_total_lo < ng9_s2_lo); + let ng9_total_hi: i32 = ng9_pll_hi + ng9_mid_hi + ng9_phh_hi + ng9_s1_c + ng9_s2_c + ng9_total_c + cg_hi; + + (*f).l[BY_NUM_LIMBS - 2u] = i32(nf9_total_lo & 0x1FFFFFFFu); + (*g).l[BY_NUM_LIMBS - 2u] = i32(ng9_total_lo & 0x1FFFFFFFu); + // Top limb: the value above bit 29 of (nf9_total_lo, nf9_total_hi). + (*f).l[BY_NUM_LIMBS - 1u] = i32((nf9_total_lo >> 29u) | (u32(nf9_total_hi) << 3u)); + (*g).l[BY_NUM_LIMBS - 1u] = i32((ng9_total_lo >> 29u) | (u32(ng9_total_hi) << 3u)); + // by_normalise is a no-op: all lower limbs already masked to [0, 2^29). +} + +// by_apply_matrix_de +// +// Mirrors \`Wasm9x29::apply_matrix\` lines 222-254 — the (d, e) pass with +// the 2-adic k·p correction. The first two output limbs are zero by +// construction (k chosen to clear the low 58 bits of (M · (d, e)) mod 2^58), +// so the streaming pass folds k·p in from position 2 onward. +// +// \`p_inv_lo\`, \`p_inv_hi\`: the 58-bit constant p^(-1) mod 2^58 split as the +// low 32 bits and the high 32 bits respectively. The WASM C++ stores it as +// a single u64 \`p_inv\`; the WGSL caller pre-splits it because WGSL has no +// native u64. Naming reflects the split: \`p_inv = p_inv_lo + (p_inv_hi << 32)\`. +// +// Loop bound is \`BY_NUM_LIMBS\` — const, satisfying the plan rule. +fn by_apply_matrix_de( + m: Mat, + d: ptr, + e: ptr, + p: ptr, + p_inv_lo: u32, + p_inv_hi: u32, +) { + // Same matrix split as the f/g pass, with 15+14-bit pre-splits hoisted + // out of the inner loop (loop-invariant). 
+ let u_lo: i32 = m.u & i32(BY_LIMB_MASK); + let v_lo: i32 = m.v & i32(BY_LIMB_MASK); + let q_lo: i32 = m.q & i32(BY_LIMB_MASK); + let r_lo: i32 = m.r & i32(BY_LIMB_MASK); + let u_hi: i32 = i32((u32(m.u) >> 29u) | (u32(m.u_hi) << 3u)); + let v_hi: i32 = i32((u32(m.v) >> 29u) | (u32(m.v_hi) << 3u)); + let q_hi: i32 = i32((u32(m.q) >> 29u) | (u32(m.q_hi) << 3u)); + let r_hi: i32 = i32((u32(m.r) >> 29u) | (u32(m.r_hi) << 3u)); + + let u_lo_l: i32 = (u_lo << 17u) >> 17u; + let u_lo_h: i32 = (u_lo - u_lo_l) >> 15u; + let v_lo_l: i32 = (v_lo << 17u) >> 17u; + let v_lo_h: i32 = (v_lo - v_lo_l) >> 15u; + let q_lo_l: i32 = (q_lo << 17u) >> 17u; + let q_lo_h: i32 = (q_lo - q_lo_l) >> 15u; + let r_lo_l: i32 = (r_lo << 17u) >> 17u; + let r_lo_h: i32 = (r_lo - r_lo_l) >> 15u; + let u_hi_l: i32 = (u_hi << 17u) >> 17u; + let u_hi_h: i32 = (u_hi - u_hi_l) >> 15u; + let v_hi_l: i32 = (v_hi << 17u) >> 17u; + let v_hi_h: i32 = (v_hi - v_hi_l) >> 15u; + let q_hi_l: i32 = (q_hi << 17u) >> 17u; + let q_hi_h: i32 = (q_hi - q_hi_l) >> 15u; + let r_hi_l: i32 = (r_hi << 17u) >> 17u; + let r_hi_h: i32 = (r_hi - r_hi_l) >> 15u; + + let d0: i32 = (*d).l[0]; + let e0: i32 = (*e).l[0]; + let d1: i32 = (*d).l[1]; + let e1: i32 = (*e).l[1]; + + let d0_l: i32 = (d0 << 17u) >> 17u; + let d0_h: i32 = (d0 - d0_l) >> 15u; + let e0_l: i32 = (e0 << 17u) >> 17u; + let e0_h: i32 = (e0 - e0_l) >> 15u; + let d1_l: i32 = (d1 << 17u) >> 17u; + let d1_h: i32 = (d1 - d1_l) >> 15u; + let e1_l: i32 = (e1 << 17u) >> 17u; + let e1_h: i32 = (e1 - e1_l) >> 15u; + + // nd0 = u_lo * d0 + v_lo * e0 (2 products) — inline the 15+14 schoolbook. + let nd0_pll: i32 = u_lo_l * d0_l + v_lo_l * e0_l; + let nd0_mid: i32 = + u_lo_l * d0_h + u_lo_h * d0_l + + v_lo_l * e0_h + v_lo_h * e0_l; + let nd0_phh: i32 = u_lo_h * d0_h + v_lo_h * e0_h; + let ne0_pll: i32 = q_lo_l * d0_l + r_lo_l * e0_l; + let ne0_mid: i32 = + q_lo_l * d0_h + q_lo_h * d0_l + + r_lo_l * e0_h + r_lo_h * e0_l; + let ne0_phh: i32 = q_lo_h * d0_h + r_lo_h * e0_h; + + // nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0 (4 products). + let nd1_pll: i32 = u_lo_l * d1_l + v_lo_l * e1_l + u_hi_l * d0_l + v_hi_l * e0_l; + let nd1_mid: i32 = + u_lo_l * d1_h + u_lo_h * d1_l + + v_lo_l * e1_h + v_lo_h * e1_l + + u_hi_l * d0_h + u_hi_h * d0_l + + v_hi_l * e0_h + v_hi_h * e0_l; + let nd1_phh: i32 = u_lo_h * d1_h + v_lo_h * e1_h + u_hi_h * d0_h + v_hi_h * e0_h; + let ne1_pll: i32 = q_lo_l * d1_l + r_lo_l * e1_l + q_hi_l * d0_l + r_hi_l * e0_l; + let ne1_mid: i32 = + q_lo_l * d1_h + q_lo_h * d1_l + + r_lo_l * e1_h + r_lo_h * e1_l + + q_hi_l * d0_h + q_hi_h * d0_l + + r_hi_l * e0_h + r_hi_h * e0_l; + let ne1_phh: i32 = q_lo_h * d1_h + r_lo_h * e1_h + q_hi_h * d0_h + r_hi_h * e0_h; + + // Helper-equivalent extraction: convert (pll, mid, phh) → i64 (lo, hi). + // Inlined to avoid function-call overhead. 
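+  // The assembled value is pll + (mid << 15) + (phh << 30) as a signed 64-bit
+  // quantity: each lane's high half is its arithmetic sign extension shifted
+  // into place (pll >> 31, mid >> 17, phh >> 2), plus the unsigned-wrap
+  // carries from the low-32-bit additions.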
+ let nd0_pll_u: u32 = u32(nd0_pll); + let nd0_mid_u: u32 = u32(nd0_mid); + let nd0_phh_u: u32 = u32(nd0_phh); + let nd0_pll_hi: i32 = nd0_pll >> 31u; + let nd0_mid_hi: i32 = nd0_mid >> 17u; + let nd0_phh_hi: i32 = nd0_phh >> 2u; + let nd0_s1: u32 = nd0_pll_u + (nd0_mid_u << 15u); + let nd0_s1c: i32 = select(0i, 1i, nd0_s1 < nd0_pll_u); + let nd0_lo: u32 = nd0_s1 + (nd0_phh_u << 30u); + let nd0_s2c: i32 = select(0i, 1i, nd0_lo < nd0_s1); + let nd0_hi: i32 = nd0_pll_hi + nd0_mid_hi + nd0_phh_hi + nd0_s1c + nd0_s2c; + + let ne0_pll_u: u32 = u32(ne0_pll); + let ne0_mid_u: u32 = u32(ne0_mid); + let ne0_phh_u: u32 = u32(ne0_phh); + let ne0_pll_hi: i32 = ne0_pll >> 31u; + let ne0_mid_hi: i32 = ne0_mid >> 17u; + let ne0_phh_hi: i32 = ne0_phh >> 2u; + let ne0_s1: u32 = ne0_pll_u + (ne0_mid_u << 15u); + let ne0_s1c: i32 = select(0i, 1i, ne0_s1 < ne0_pll_u); + let ne0_lo: u32 = ne0_s1 + (ne0_phh_u << 30u); + let ne0_s2c: i32 = select(0i, 1i, ne0_lo < ne0_s1); + let ne0_hi: i32 = ne0_pll_hi + ne0_mid_hi + ne0_phh_hi + ne0_s1c + ne0_s2c; + + let nd1_pll_u: u32 = u32(nd1_pll); + let nd1_mid_u: u32 = u32(nd1_mid); + let nd1_phh_u: u32 = u32(nd1_phh); + let nd1_pll_hi: i32 = nd1_pll >> 31u; + let nd1_mid_hi: i32 = nd1_mid >> 17u; + let nd1_phh_hi: i32 = nd1_phh >> 2u; + let nd1_s1: u32 = nd1_pll_u + (nd1_mid_u << 15u); + let nd1_s1c: i32 = select(0i, 1i, nd1_s1 < nd1_pll_u); + let nd1_lo: u32 = nd1_s1 + (nd1_phh_u << 30u); + let nd1_s2c: i32 = select(0i, 1i, nd1_lo < nd1_s1); + let nd1_hi: i32 = nd1_pll_hi + nd1_mid_hi + nd1_phh_hi + nd1_s1c + nd1_s2c; + + let ne1_pll_u: u32 = u32(ne1_pll); + let ne1_mid_u: u32 = u32(ne1_mid); + let ne1_phh_u: u32 = u32(ne1_phh); + let ne1_pll_hi: i32 = ne1_pll >> 31u; + let ne1_mid_hi: i32 = ne1_mid >> 17u; + let ne1_phh_hi: i32 = ne1_phh >> 2u; + let ne1_s1: u32 = ne1_pll_u + (ne1_mid_u << 15u); + let ne1_s1c: i32 = select(0i, 1i, ne1_s1 < ne1_pll_u); + let ne1_lo: u32 = ne1_s1 + (ne1_phh_u << 30u); + let ne1_s2c: i32 = select(0i, 1i, ne1_lo < ne1_s1); + let ne1_hi: i32 = ne1_pll_hi + ne1_mid_hi + ne1_phh_hi + ne1_s1c + ne1_s2c; + + // Reconstruct low 58 bits of nd and ne for k computation. + // td = (nd0_low29 + (nd1_plus_low29 << 29)) where nd1_plus = nd1 + (nd0 >> 29). + let nd0_low29: u32 = nd0_lo & BY_LIMB_MASK; + let ne0_low29: u32 = ne0_lo & BY_LIMB_MASK; + // nd0 >> 29 arithmetic shift, as i64 (lo, hi): + let nd0_ars_lo: u32 = (nd0_lo >> 29u) | (u32(nd0_hi) << 3u); + let nd0_ars_hi: i32 = nd0_hi >> 29u; + let ne0_ars_lo: u32 = (ne0_lo >> 29u) | (u32(ne0_hi) << 3u); + let ne0_ars_hi: i32 = ne0_hi >> 29u; + let nd1p_lo: u32 = nd1_lo + nd0_ars_lo; + let nd1p_c: i32 = select(0i, 1i, nd1p_lo < nd1_lo); + let nd1p_hi: i32 = nd1_hi + nd0_ars_hi + nd1p_c; + let ne1p_lo: u32 = ne1_lo + ne0_ars_lo; + let ne1p_c: i32 = select(0i, 1i, ne1p_lo < ne1_lo); + let ne1p_hi: i32 = ne1_hi + ne0_ars_hi + ne1p_c; + let nd1_low29: u32 = nd1p_lo & BY_LIMB_MASK; + let ne1_low29: u32 = ne1p_lo & BY_LIMB_MASK; + + let td: vec2 = vec2(nd0_low29 | (nd1_low29 << 29u), nd1_low29 >> 3u); + let te: vec2 = vec2(ne0_low29 | (ne1_low29 << 29u), ne1_low29 >> 3u); + + // k_d = ((-t_d) * p_inv) & MASK_BATCH. 
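+  // Why this clears the low 58 bits: p_inv is p^(-1) mod 2^58, so
+  // t + k * p = t + (-t) * p_inv * p = t - t = 0 (mod 2^58), which is what
+  // allows the streaming pass below to drop the two low 29-bit output limbs.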
+ let neg_td: vec2 = u64_neg(td); + let neg_te: vec2 = u64_neg(te); + let p_inv: vec2 = vec2(p_inv_lo, p_inv_hi); + let kd_prod: vec2 = u64_mul_low64(neg_td, p_inv); + let ke_prod: vec2 = u64_mul_low64(neg_te, p_inv); + + let MASK_BATCH_HI: u32 = (1u << 26u) - 1u; + let kd_lo32: u32 = kd_prod.x; + let kd_hi26: u32 = kd_prod.y & MASK_BATCH_HI; + let ke_lo32: u32 = ke_prod.x; + let ke_hi26: u32 = ke_prod.y & MASK_BATCH_HI; + + let kd_lo: i32 = i32(kd_lo32 & BY_LIMB_MASK); + let kd_hi: i32 = i32((kd_lo32 >> 29u) | (kd_hi26 << 3u)); + let ke_lo: i32 = i32(ke_lo32 & BY_LIMB_MASK); + let ke_hi: i32 = i32((ke_lo32 >> 29u) | (ke_hi26 << 3u)); + + // Split k_*_lo, k_*_hi into 15+14 chunks for the inner loop. + let kd_lo_l: i32 = (kd_lo << 17u) >> 17u; + let kd_lo_h: i32 = (kd_lo - kd_lo_l) >> 15u; + let kd_hi_l: i32 = (kd_hi << 17u) >> 17u; + let kd_hi_h: i32 = (kd_hi - kd_hi_l) >> 15u; + let ke_lo_l: i32 = (ke_lo << 17u) >> 17u; + let ke_lo_h: i32 = (ke_lo - ke_lo_l) >> 15u; + let ke_hi_l: i32 = (ke_hi << 17u) >> 17u; + let ke_hi_h: i32 = (ke_hi - ke_hi_l) >> 15u; + + // Initial seed: nd0_plus = nd0 + kd_lo*p[0], cd_acc = nd1 + kd_lo*p[1] + kd_hi*p[0] + (nd0_plus >> 29). + // p[0] and p[1] are small (the BN254 modulus); we still split for correctness. + let p0: i32 = (*p).l[0]; + let p1: i32 = (*p).l[1]; + let p0_l: i32 = (p0 << 17u) >> 17u; + let p0_h: i32 = (p0 - p0_l) >> 15u; + let p1_l: i32 = (p1 << 17u) >> 17u; + let p1_h: i32 = (p1 - p1_l) >> 15u; + + // nd0_plus = nd0 + kd_lo*p[0] + let np0_pll: i32 = kd_lo_l * p0_l; + let np0_mid: i32 = kd_lo_l * p0_h + kd_lo_h * p0_l; + let np0_phh: i32 = kd_lo_h * p0_h; + let np0_pll_u: u32 = u32(np0_pll); + let np0_mid_u: u32 = u32(np0_mid); + let np0_phh_u: u32 = u32(np0_phh); + let np0_pll_hi: i32 = np0_pll >> 31u; + let np0_mid_hi: i32 = np0_mid >> 17u; + let np0_phh_hi: i32 = np0_phh >> 2u; + let np0_s1: u32 = np0_pll_u + (np0_mid_u << 15u); + let np0_s1c: i32 = select(0i, 1i, np0_s1 < np0_pll_u); + let np0_lo: u32 = np0_s1 + (np0_phh_u << 30u); + let np0_s2c: i32 = select(0i, 1i, np0_lo < np0_s1); + let np0_hi: i32 = np0_pll_hi + np0_mid_hi + np0_phh_hi + np0_s1c + np0_s2c; + // nd0_plus = nd0 + np0 + let nd0p_lo: u32 = nd0_lo + np0_lo; + let nd0p_c: i32 = select(0i, 1i, nd0p_lo < nd0_lo); + let nd0p_hi: i32 = nd0_hi + np0_hi + nd0p_c; + // (nd0_plus >> 29) signed arithmetic + let nd0p_ars_lo: u32 = (nd0p_lo >> 29u) | (u32(nd0p_hi) << 3u); + let nd0p_ars_hi: i32 = nd0p_hi >> 29u; + + let ne0p_pll: i32 = ke_lo_l * p0_l; + let ne0p_mid: i32 = ke_lo_l * p0_h + ke_lo_h * p0_l; + let ne0p_phh: i32 = ke_lo_h * p0_h; + let ne0p_pll_u: u32 = u32(ne0p_pll); + let ne0p_mid_u: u32 = u32(ne0p_mid); + let ne0p_phh_u: u32 = u32(ne0p_phh); + let ne0p_pll_hi: i32 = ne0p_pll >> 31u; + let ne0p_mid_hi: i32 = ne0p_mid >> 17u; + let ne0p_phh_hi: i32 = ne0p_phh >> 2u; + let ne0p_s1: u32 = ne0p_pll_u + (ne0p_mid_u << 15u); + let ne0p_s1c: i32 = select(0i, 1i, ne0p_s1 < ne0p_pll_u); + let ne0p_lo: u32 = ne0p_s1 + (ne0p_phh_u << 30u); + let ne0p_s2c: i32 = select(0i, 1i, ne0p_lo < ne0p_s1); + let ne0p_hi: i32 = ne0p_pll_hi + ne0p_mid_hi + ne0p_phh_hi + ne0p_s1c + ne0p_s2c; + let ne0pa_lo: u32 = ne0_lo + ne0p_lo; + let ne0pa_c: i32 = select(0i, 1i, ne0pa_lo < ne0_lo); + let ne0pa_hi: i32 = ne0_hi + ne0p_hi + ne0pa_c; + let ne0pa_ars_lo: u32 = (ne0pa_lo >> 29u) | (u32(ne0pa_hi) << 3u); + let ne0pa_ars_hi: i32 = ne0pa_hi >> 29u; + + // cd_acc = nd1 + kd_lo*p[1] + kd_hi*p[0] + (nd0_plus >> 29) + let cda_pll: i32 = kd_lo_l * p1_l + kd_hi_l * p0_l; + let cda_mid: i32 = + 
kd_lo_l * p1_h + kd_lo_h * p1_l + + kd_hi_l * p0_h + kd_hi_h * p0_l; + let cda_phh: i32 = kd_lo_h * p1_h + kd_hi_h * p0_h; + let cda_pll_u: u32 = u32(cda_pll); + let cda_mid_u: u32 = u32(cda_mid); + let cda_phh_u: u32 = u32(cda_phh); + let cda_pll_hi: i32 = cda_pll >> 31u; + let cda_mid_hi: i32 = cda_mid >> 17u; + let cda_phh_hi: i32 = cda_phh >> 2u; + let cda_s1: u32 = cda_pll_u + (cda_mid_u << 15u); + let cda_s1c: i32 = select(0i, 1i, cda_s1 < cda_pll_u); + let cda_p_lo: u32 = cda_s1 + (cda_phh_u << 30u); + let cda_s2c: i32 = select(0i, 1i, cda_p_lo < cda_s1); + let cda_p_hi: i32 = cda_pll_hi + cda_mid_hi + cda_phh_hi + cda_s1c + cda_s2c; + // cda = nd1 + cda_p + nd0p_ars + let cda_a_lo: u32 = nd1_lo + cda_p_lo; + let cda_a_c: i32 = select(0i, 1i, cda_a_lo < nd1_lo); + let cda_a_hi: i32 = nd1_hi + cda_p_hi + cda_a_c; + let cda_b_lo: u32 = cda_a_lo + nd0p_ars_lo; + let cda_b_c: i32 = select(0i, 1i, cda_b_lo < cda_a_lo); + let cda_b_hi: i32 = cda_a_hi + nd0p_ars_hi + cda_b_c; + // cd = cda >> 29 (signed arithmetic) + var cd_lo: u32 = (cda_b_lo >> 29u) | (u32(cda_b_hi) << 3u); + var cd_hi: i32 = cda_b_hi >> 29u; + + let cea_pll: i32 = ke_lo_l * p1_l + ke_hi_l * p0_l; + let cea_mid: i32 = + ke_lo_l * p1_h + ke_lo_h * p1_l + + ke_hi_l * p0_h + ke_hi_h * p0_l; + let cea_phh: i32 = ke_lo_h * p1_h + ke_hi_h * p0_h; + let cea_pll_u: u32 = u32(cea_pll); + let cea_mid_u: u32 = u32(cea_mid); + let cea_phh_u: u32 = u32(cea_phh); + let cea_pll_hi: i32 = cea_pll >> 31u; + let cea_mid_hi: i32 = cea_mid >> 17u; + let cea_phh_hi: i32 = cea_phh >> 2u; + let cea_s1: u32 = cea_pll_u + (cea_mid_u << 15u); + let cea_s1c: i32 = select(0i, 1i, cea_s1 < cea_pll_u); + let cea_p_lo: u32 = cea_s1 + (cea_phh_u << 30u); + let cea_s2c: i32 = select(0i, 1i, cea_p_lo < cea_s1); + let cea_p_hi: i32 = cea_pll_hi + cea_mid_hi + cea_phh_hi + cea_s1c + cea_s2c; + let cea_a_lo: u32 = ne1_lo + cea_p_lo; + let cea_a_c: i32 = select(0i, 1i, cea_a_lo < ne1_lo); + let cea_a_hi: i32 = ne1_hi + cea_p_hi + cea_a_c; + let cea_b_lo: u32 = cea_a_lo + ne0pa_ars_lo; + let cea_b_c: i32 = select(0i, 1i, cea_b_lo < cea_a_lo); + let cea_b_hi: i32 = cea_a_hi + ne0pa_ars_hi + cea_b_c; + var ce_lo: u32 = (cea_b_lo >> 29u) | (u32(cea_b_hi) << 3u); + var ce_hi: i32 = cea_b_hi >> 29u; + + // Slide-forward previous-limb splits for the inner loop. \`pc_l\`/\`pc_h\` + // hold p[i-1] entering iter i; after the body we set pc = p[i]. + var dp_l: i32 = d1_l; + var dp_h: i32 = d1_h; + var ep_l: i32 = e1_l; + var ep_h: i32 = e1_h; + var pc_l: i32 = p1_l; + var pc_h: i32 = p1_h; + + for (var i: u32 = 2u; i < BY_NUM_LIMBS; i = i + 1u) { + let di: i32 = (*d).l[i]; + let ei: i32 = (*e).l[i]; + let pi: i32 = (*p).l[i]; + let di_l: i32 = (di << 17u) >> 17u; + let di_h: i32 = (di - di_l) >> 15u; + let ei_l: i32 = (ei << 17u) >> 17u; + let ei_h: i32 = (ei - ei_l) >> 15u; + let pi_l: i32 = (pi << 17u) >> 17u; + let pi_h: i32 = (pi - pi_l) >> 15u; + + // nd = u_lo*di + v_lo*ei + u_hi*dp + v_hi*ep + kd_lo*p[i] + kd_hi*p[i-1] + cd + // 6 products. Bound check: each pll/phh < 2^28, sum < 6*2^28 < 2^31 ✓ + // each plh+phl < 2*2^28 = 2^29, sum < 6*2^29 < 2^32 ✗ overflow ! + // The "mid" lane needs care. Split: sum 6 lh-products and 6 hl-products SEPARATELY, + // each < 6*2^28 < 2^31. Then combine in i64 via two adds. + let nd_pll: i32 = + u_lo_l * di_l + v_lo_l * ei_l + + u_hi_l * dp_l + v_hi_l * ep_l + + kd_lo_l * pi_l + kd_hi_l * pc_l; + // Two mid sub-lanes: low_high_products + high_low_products. 
+ let nd_mid_lh: i32 = + u_lo_l * di_h + v_lo_l * ei_h + + u_hi_l * dp_h + v_hi_l * ep_h + + kd_lo_l * pi_h + kd_hi_l * pc_h; + let nd_mid_hl: i32 = + u_lo_h * di_l + v_lo_h * ei_l + + u_hi_h * dp_l + v_hi_h * ep_l + + kd_lo_h * pi_l + kd_hi_h * pc_l; + let nd_phh: i32 = + u_lo_h * di_h + v_lo_h * ei_h + + u_hi_h * dp_h + v_hi_h * ep_h + + kd_lo_h * pi_h + kd_hi_h * pc_h; + + let ne_pll: i32 = + q_lo_l * di_l + r_lo_l * ei_l + + q_hi_l * dp_l + r_hi_l * ep_l + + ke_lo_l * pi_l + ke_hi_l * pc_l; + let ne_mid_lh: i32 = + q_lo_l * di_h + r_lo_l * ei_h + + q_hi_l * dp_h + r_hi_l * ep_h + + ke_lo_l * pi_h + ke_hi_l * pc_h; + let ne_mid_hl: i32 = + q_lo_h * di_l + r_lo_h * ei_l + + q_hi_h * dp_l + r_hi_h * ep_l + + ke_lo_h * pi_l + ke_hi_h * pc_l; + let ne_phh: i32 = + q_lo_h * di_h + r_lo_h * ei_h + + q_hi_h * dp_h + r_hi_h * ep_h + + ke_lo_h * pi_h + ke_hi_h * pc_h; + + // Combine nd_pll + (nd_mid_lh + nd_mid_hl) << 15 + nd_phh << 30 + cd into i64. + // First fold nd_mid_lh + nd_mid_hl as i64 (mid lane could overflow i32 if combined). + // Each is < 2^31; sum needs 33 bits. + let nd_pll_u: u32 = u32(nd_pll); + let nd_mlh_u: u32 = u32(nd_mid_lh); + let nd_mhl_u: u32 = u32(nd_mid_hl); + let nd_phh_u: u32 = u32(nd_phh); + let nd_pll_hi: i32 = nd_pll >> 31u; + let nd_mlh_hi: i32 = nd_mid_lh >> 17u; + let nd_mhl_hi: i32 = nd_mid_hl >> 17u; + let nd_phh_hi: i32 = nd_phh >> 2u; + + // s = pll + mlh<<15 + mhl<<15 + phh<<30 + cd + let nd_a_lo: u32 = nd_pll_u + (nd_mlh_u << 15u); + let nd_a_c: i32 = select(0i, 1i, nd_a_lo < nd_pll_u); + let nd_b_lo: u32 = nd_a_lo + (nd_mhl_u << 15u); + let nd_b_c: i32 = select(0i, 1i, nd_b_lo < nd_a_lo); + let nd_c_lo: u32 = nd_b_lo + (nd_phh_u << 30u); + let nd_c_c: i32 = select(0i, 1i, nd_c_lo < nd_b_lo); + let nd_d_lo: u32 = nd_c_lo + cd_lo; + let nd_d_c: i32 = select(0i, 1i, nd_d_lo < nd_c_lo); + let nd_d_hi: i32 = nd_pll_hi + nd_mlh_hi + nd_mhl_hi + nd_phh_hi + nd_a_c + nd_b_c + nd_c_c + nd_d_c + cd_hi; + + let ne_pll_u: u32 = u32(ne_pll); + let ne_mlh_u: u32 = u32(ne_mid_lh); + let ne_mhl_u: u32 = u32(ne_mid_hl); + let ne_phh_u: u32 = u32(ne_phh); + let ne_pll_hi: i32 = ne_pll >> 31u; + let ne_mlh_hi: i32 = ne_mid_lh >> 17u; + let ne_mhl_hi: i32 = ne_mid_hl >> 17u; + let ne_phh_hi: i32 = ne_phh >> 2u; + + let ne_a_lo: u32 = ne_pll_u + (ne_mlh_u << 15u); + let ne_a_c: i32 = select(0i, 1i, ne_a_lo < ne_pll_u); + let ne_b_lo: u32 = ne_a_lo + (ne_mhl_u << 15u); + let ne_b_c: i32 = select(0i, 1i, ne_b_lo < ne_a_lo); + let ne_c_lo: u32 = ne_b_lo + (ne_phh_u << 30u); + let ne_c_c: i32 = select(0i, 1i, ne_c_lo < ne_b_lo); + let ne_d_lo: u32 = ne_c_lo + ce_lo; + let ne_d_c: i32 = select(0i, 1i, ne_d_lo < ne_c_lo); + let ne_d_hi: i32 = ne_pll_hi + ne_mlh_hi + ne_mhl_hi + ne_phh_hi + ne_a_c + ne_b_c + ne_c_c + ne_d_c + ce_hi; + + (*d).l[i - 2u] = i32(nd_d_lo & BY_LIMB_MASK); + (*e).l[i - 2u] = i32(ne_d_lo & BY_LIMB_MASK); + cd_lo = (nd_d_lo >> 29u) | (u32(nd_d_hi) << 3u); + cd_hi = nd_d_hi >> 29u; + ce_lo = (ne_d_lo >> 29u) | (u32(ne_d_hi) << 3u); + ce_hi = ne_d_hi >> 29u; + + // Slide previous-limb splits. 
+ dp_l = di_l; dp_h = di_h; + ep_l = ei_l; ep_h = ei_h; + pc_l = pi_l; pc_h = pi_h; + } + + // Top-limb finalisation: + // nd9 = u_hi * dp + v_hi * ep + kd_hi * p[N-1] + cd (3 products) + // ne9 = q_hi * dp + r_hi * ep + ke_hi * p[N-1] + ce + let p_top: i32 = (*p).l[BY_NUM_LIMBS - 1u]; + let pt_l: i32 = (p_top << 17u) >> 17u; + let pt_h: i32 = (p_top - pt_l) >> 15u; + + let nd9_pll: i32 = u_hi_l * dp_l + v_hi_l * ep_l + kd_hi_l * pt_l; + let nd9_mid: i32 = + u_hi_l * dp_h + u_hi_h * dp_l + + v_hi_l * ep_h + v_hi_h * ep_l + + kd_hi_l * pt_h + kd_hi_h * pt_l; + let nd9_phh: i32 = u_hi_h * dp_h + v_hi_h * ep_h + kd_hi_h * pt_h; + + let ne9_pll: i32 = q_hi_l * dp_l + r_hi_l * ep_l + ke_hi_l * pt_l; + let ne9_mid: i32 = + q_hi_l * dp_h + q_hi_h * dp_l + + r_hi_l * ep_h + r_hi_h * ep_l + + ke_hi_l * pt_h + ke_hi_h * pt_l; + let ne9_phh: i32 = q_hi_h * dp_h + r_hi_h * ep_h + ke_hi_h * pt_h; + + // For 3 products, mid sum ≤ 3 * 2 * 2^28 = 3 * 2^29 = 1.5 * 2^30 < 2^31 ✓ + let nd9_pll_u: u32 = u32(nd9_pll); + let nd9_mid_u: u32 = u32(nd9_mid); + let nd9_phh_u: u32 = u32(nd9_phh); + let nd9_pll_hi: i32 = nd9_pll >> 31u; + let nd9_mid_hi: i32 = nd9_mid >> 17u; + let nd9_phh_hi: i32 = nd9_phh >> 2u; + let nd9_s1: u32 = nd9_pll_u + (nd9_mid_u << 15u); + let nd9_s1c: i32 = select(0i, 1i, nd9_s1 < nd9_pll_u); + let nd9_s2: u32 = nd9_s1 + (nd9_phh_u << 30u); + let nd9_s2c: i32 = select(0i, 1i, nd9_s2 < nd9_s1); + let nd9_total_lo: u32 = nd9_s2 + cd_lo; + let nd9_total_c: i32 = select(0i, 1i, nd9_total_lo < nd9_s2); + let nd9_total_hi: i32 = nd9_pll_hi + nd9_mid_hi + nd9_phh_hi + nd9_s1c + nd9_s2c + nd9_total_c + cd_hi; + + let ne9_pll_u: u32 = u32(ne9_pll); + let ne9_mid_u: u32 = u32(ne9_mid); + let ne9_phh_u: u32 = u32(ne9_phh); + let ne9_pll_hi: i32 = ne9_pll >> 31u; + let ne9_mid_hi: i32 = ne9_mid >> 17u; + let ne9_phh_hi: i32 = ne9_phh >> 2u; + let ne9_s1: u32 = ne9_pll_u + (ne9_mid_u << 15u); + let ne9_s1c: i32 = select(0i, 1i, ne9_s1 < ne9_pll_u); + let ne9_s2: u32 = ne9_s1 + (ne9_phh_u << 30u); + let ne9_s2c: i32 = select(0i, 1i, ne9_s2 < ne9_s1); + let ne9_total_lo: u32 = ne9_s2 + ce_lo; + let ne9_total_c: i32 = select(0i, 1i, ne9_total_lo < ne9_s2); + let ne9_total_hi: i32 = ne9_pll_hi + ne9_mid_hi + ne9_phh_hi + ne9_s1c + ne9_s2c + ne9_total_c + ce_hi; + + (*d).l[BY_NUM_LIMBS - 2u] = i32(nd9_total_lo & BY_LIMB_MASK); + (*e).l[BY_NUM_LIMBS - 2u] = i32(ne9_total_lo & BY_LIMB_MASK); + (*d).l[BY_NUM_LIMBS - 1u] = i32((nd9_total_lo >> 29u) | (u32(nd9_total_hi) << 3u)); + (*e).l[BY_NUM_LIMBS - 1u] = i32((ne9_total_lo >> 29u) | (u32(ne9_total_hi) << 3u)); + // by_normalise no-op: lower limbs already masked. +} + +// ============================================================ +// fr_inv_by driver helpers +// ============================================================ + +// by_is_zero: returns true iff every limb of x is zero. +// Pre: any state (need not be canonical). Post: bool. +fn by_is_zero(x: ptr) -> bool { + var a: i32 = 0; + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + a = a | (*x).l[i]; + } + return a == 0; +} + +// by_is_negative: top-limb sign check on a normalised BigIntBY. +// Pre: x normalised so the top limb carries the sign. Post: bool. +fn by_is_negative(x: ptr) -> bool { + return (*x).l[BY_NUM_LIMBS - 1u] < 0; +} + +// by_neg_inplace: negate x then re-normalise so lower limbs are in +// [0, 2^29) again. Mirrors \`neg(x)\` in bernstein_yang.ts. 
+fn by_neg_inplace(x: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = -(*x).l[i]; + } + by_normalise(x); +} + +// by_add_p_inplace: x <- x + p (limbwise) then normalise. +fn by_add_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = (*x).l[i] + (*p).l[i]; + } + by_normalise(x); +} + +// by_sub_p_inplace: x <- x - p (limbwise) then normalise. +fn by_sub_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = (*x).l[i] - (*p).l[i]; + } + by_normalise(x); +} + +// by_gte_p: true iff x >= p, assuming both x and p are non-negative +// canonical-limb 9-limb BigIntBY values (lower limbs in [0, 2^29), top limb +// non-negative). Walks limbs from high to low. +fn by_gte_p(x: ptr, p: ptr) -> bool { + var gt: bool = false; + var lt: bool = false; + for (var ii: u32 = 0u; ii < BY_NUM_LIMBS; ii = ii + 1u) { + let i: u32 = BY_NUM_LIMBS - 1u - ii; + let a: i32 = (*x).l[i]; + let b: i32 = (*p).l[i]; + let still_undecided: bool = !(gt || lt); + if (still_undecided) { + if (a > b) { gt = true; } + else if (a < b) { lt = true; } + } + } + return gt || !lt; +} + +// by_reduce_to_canonical: bring x into [0, p) using at most BY_RTC_MAX_ITERS +// (= 36) add-p / sub-p passes. Mirrors \`reduceToCanonical\` in +// bernstein_yang.ts exactly: if x is negative, add p; else if x >= p, +// subtract p; else break. The 36-iter bound suffices for |x| <= 32 p under +// REDUCE_INTERVAL = 4 (see Wasm9x29 docs). +// +// LOOP BOUND: \`for (var it: u32 = 0u; it < BY_RTC_MAX_ITERS; ...)\` — const. +// +// Pre: x is a possibly-non-canonical signed BigIntBY (post-by_normalise: +// lower limbs in [0, 2^29), top limb carries sign). p is the modulus +// in BigIntBY form (positive, canonical). +// Post: x in [0, p), canonical. +fn by_reduce_to_canonical(x: ptr, p: ptr) { + by_normalise(x); + var done: bool = false; + for (var it: u32 = 0u; it < BY_RTC_MAX_ITERS; it = it + 1u) { + if (done) { continue; } + if (by_is_negative(x)) { + by_add_p_inplace(x, p); + } else if (by_gte_p(x, p)) { + by_sub_p_inplace(x, p); + } else { + done = true; + } + } +} + +// 58-bit p_inv split as low 32 / high (<=26) bits. Mustache-injected by +// \`gen_fr_inv_bench_shader\` (and the production wiring in step 1.7) from +// \`compute_by_p_inv_split\` in cuzk/utils.ts. Matching pair to the +// \`P_INV_BY_LO\` / \`P_INV_BY_HI\` constants the apply_matrix bench uses. +const FR_INV_BY_P_INV_LO: u32 = {{ p_inv_by_lo }}u; +const FR_INV_BY_P_INV_HI: u32 = {{ p_inv_by_hi }}u; + +// fr_inv_by: Bernstein-Yang safegcd inverse driver, mirroring +// \`invert_bernsteinyang19\` (bernstein_yang_inverse.hpp lines 290-326) +// and the TS reference \`Wasm9x29.invert\` (bernstein_yang.ts:409-440). +// +// Input \`a\` is in Montgomery form: bigint_value(a) = A * R mod p. +// Output is in Montgomery form: bigint_value(output) = A^(-1) * R mod p. +// +// Algorithm sketch: +// 1. Convert (a, p) to 9 x 29-bit BigIntBY representation. +// 2. Run NUM_OUTER = 13 outer iterations. Each outer iter: +// a. Compute low-64-bit views (f_lo, g_lo) of (f, g). +// b. by_divsteps(&delta, f_lo, g_lo) -> Mat (BATCH = 58 inner). +// c. by_apply_matrix_fg(M, &f, &g) — folds 58 divsteps into f, g. +// d. by_apply_matrix_de(M, &d, &e, &p, p_inv) — same on (d, e). +// e. Every BY_REDUCE_INTERVAL = 4 iters, reduce_to_canonical(d, e). +// f. 
Early break on \`by_is_zero(g)\` — the const NUM_OUTER bound is +// still respected by the WGSL emitter via a guard flag, not by +// shrinking the loop count. +// 3. After the loop, reduce_to_canonical(d) and, if f is negative, +// negate d mod p (mirrors the C++ \`sign(f) * d\` step). +// 4. The BY output is \`inv_native = (A * R)^(-1) mod p = A^(-1) * R^(-1)\` +// in canonical [0, p). Apply the standard Mont correction via +// \`montgomery_product(inv_native, R^3)\` = +// inv_native * R^3 * R^(-1) = inv_native * R^2 = A^(-1) * R, in +// Montgomery form. Pattern matches \`fr_inv\` in fr_pow.template.wgsl. +// 5. Convert back to 20 x 13-bit BigInt and return. +// +// LOOP BOUND DISCIPLINE: +// - outer loop: \`for (... iter < BY_NUM_OUTER; ...)\` (const 13). +// - by_divsteps: \`for (... i < BY_BATCH; ...)\` (const 58). +// - by_apply_matrix_*: \`for (... i < BY_NUM_LIMBS; ...)\` (const 9). +// - by_reduce_to_canonical: \`for (... it < BY_RTC_MAX_ITERS; ...)\`(const 36). +// - by_normalise / by_neg: \`for (... i < BY_NUM_LIMBS; ...)\` (const 9). +// - by_from_bigint / by_to_bigint loops bounded by const BY_NUM_LIMBS +// and Mustache-const \`{{ num_words }}\`. +// No data-dependent loop bounds anywhere on the inversion path. +fn fr_inv_by(a: BigInt) -> BigInt { + // Modulus p in BigIntBY form. Use the same Mustache-injected initializer + // as the apply_matrix bench; this gates fr_inv_by's behaviour on the + // ShaderManager-supplied p_limbs_by, matching the rest of the BY surface. + var p_by: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + var f: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + var g: BigIntBY = by_from_bigint(a); + + var d: BigIntBY; + var e: BigIntBY; + for (var k: u32 = 0u; k < BY_NUM_LIMBS; k = k + 1u) { + d.l[k] = 0; + e.l[k] = 0; + } + e.l[0] = 1; + + var delta: i32 = 1; + var done: bool = false; + for (var iter: u32 = 0u; iter < BY_NUM_OUTER; iter = iter + 1u) { + if (done) { continue; } + // low_64 view of f and g for divsteps. Inlined by_low_u64_lohi. + let f_l0: u32 = u32(f.l[0]) & BY_LIMB_MASK; + let f_l1: u32 = u32(f.l[1]) & BY_LIMB_MASK; + let f_l2: u32 = u32(f.l[2]) & 0x3Fu; + let f_lo: vec2 = vec2(f_l0 | ((f_l1 & 0x7u) << 29u), (f_l1 >> 3u) | (f_l2 << 26u)); + let g_l0: u32 = u32(g.l[0]) & BY_LIMB_MASK; + let g_l1: u32 = u32(g.l[1]) & BY_LIMB_MASK; + let g_l2: u32 = u32(g.l[2]) & 0x3Fu; + let g_lo: vec2 = vec2(g_l0 | ((g_l1 & 0x7u) << 29u), (g_l1 >> 3u) | (g_l2 << 26u)); + let m: Mat = by_divsteps(&delta, f_lo, g_lo); + by_apply_matrix_fg(m, &f, &g); + by_apply_matrix_de(m, &d, &e, &p_by, FR_INV_BY_P_INV_LO, FR_INV_BY_P_INV_HI); + if (((iter + 1u) % BY_REDUCE_INTERVAL) == 0u) { + by_reduce_to_canonical(&d, &p_by); + by_reduce_to_canonical(&e, &p_by); + } + if (by_is_zero(&g)) { + done = true; + } + } + + by_reduce_to_canonical(&d, &p_by); + if (by_is_negative(&f)) { + by_neg_inplace(&d); + by_reduce_to_canonical(&d, &p_by); + } + + // inv_native = A^(-1) * R^(-1) mod p (canonical [0, p)). Mont correction + // via \`montgomery_product(inv_native, R^3)\` lands at A^(-1) * R, matching + // the pattern used by fr_inv in fr_pow.template.wgsl. + var inv_native: BigInt = by_to_bigint(d); + var r_cubed: BigInt = get_r_cubed(); + return montgomery_product(&inv_native, &r_cubed); +} +`; + +export const by_inverse_a = `// Option A: Bernstein-Yang safegcd inverse on the 20 x 13-bit BigInt +// representation. Tight wide-multiply apply_matrix variant. 
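+// Compared with the 9 x 29-bit fr_inv_by above, this option stays on the
+// existing 20 x 13-bit BigInt type, trading roughly twice as many outer
+// iterations (29 vs 13) for per-limb accumulators that fit a single i32
+// (no simulated 64-bit arithmetic in apply_matrix) and a fully unrolled body.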
+// +// LAYOUT +// - BigInt: 20 x 13-bit limbs (canonical input) or 20 limbs storing +// SIGNED i32 bitcast into u32 (non-canonical between iters, magnitude +// bounded by 2^15). +// - Matrix entries u, v, q, r: signed i32. After BATCH=26 inner divsteps +// |entry| <= 2^26, fits comfortably in i32. +// - Inner divsteps operate on the LOW 64 BITS of (f, g) carried as a +// vec2. We need >= BATCH bits to drive divstep decisions +// correctly; 64 gives us 38 bits of headroom for sign propagation. +// +// APPLY_MATRIX DESIGN +// - Per-output-limb raw accumulators are each ONE inline expression of +// four 13-bit muls + three adds. No common-subexpression pre-compute +// (each lo/hi*limb product is used in exactly ONE slot). The compiler +// issues them back-to-back, the GPU keeps registers tight. +// - Carry-propagation is TWO parallel passes (each reads only the prior +// pass's output, not its own in-progress writes). After two passes the +// limbs fit in [-2^14, 2^14] and we store them as u32 bitcast. +// - We do NOT canonicalize between outer iterations: limbs stay signed +// non-canonical up to 2^15 magnitude. Next iter's multiply tolerates +// this because (2^13_matrix * 2^15_limb) * 4_terms = 2^30 < 2^31. +// - We DO canonicalize d at the very end (before the Mont correction). +// +// LOOP BOUND DISCIPLINE +// - Outer driver: \`for (var iter < BYA_NUM_OUTER)\` (const = 29) +// - Inner divsteps: \`for (var i < BYA_BATCH)\` (const = 26) +// - Apply matrix: fully unrolled (no loops) +// - Reduce-to-canonical: \`for (var it < BYA_RTC_MAX_ITERS)\` (const = 4) +// +// CONVERGENCE +// Bernstein-Yang safegcd bound for 256-bit modulus: 735 divsteps. +// BATCH=26 -> NUM_OUTER = ceil(735/26) = 29. + +const BYA_BATCH: u32 = 26u; +const BYA_NUM_OUTER: u32 = 29u; +const BYA_REDUCE_INTERVAL: u32 = 4u; +const BYA_RTC_MAX_ITERS: u32 = 4u; +const BYA_MASK13: u32 = (1u << 13u) - 1u; +const BYA_MASK13_I32: i32 = (1 << 13) - 1; + +// 2x2 matrix entries after BATCH=26 divsteps. Each entry is an i32 with +// |.| <= 2^26. +struct MatA { + u: i32, + v: i32, + q: i32, + r: i32, +} + +// ============================================================ +// bya_divsteps: BATCH=26 branchy divsteps on the low 64 bits of (f, g). +// +// Matrix entries u, v, q, r grow by at most one shl + one sub per iter, +// so after BATCH=26 we have |entry| <= 2^26. +// ============================================================ +fn bya_divsteps(delta: ptr, f_lo_in: vec2, g_lo_in: vec2) -> MatA { + var f_lo: vec2 = f_lo_in; + var g_lo: vec2 = g_lo_in; + var u: i32 = 1; + var v: i32 = 0; + var q: i32 = 0; + var r: i32 = 1; + var d: i32 = *delta; + for (var i: u32 = 0u; i < BYA_BATCH; i = i + 1u) { + if (u64_low_bit(g_lo) != 0u) { + if (d > 0) { + let nf: vec2 = g_lo; + let diff: vec2 = u64_sub(g_lo, f_lo); + let ng: vec2 = u64_shr1(diff); + let nu: i32 = q << 1u; + let nv: i32 = r << 1u; + let nq: i32 = q - u; + let nr: i32 = r - v; + f_lo = nf; + g_lo = ng; + u = nu; + v = nv; + q = nq; + r = nr; + d = 1 - d; + } else { + let sum: vec2 = u64_add(g_lo, f_lo); + g_lo = u64_shr1(sum); + q = q + u; + r = r + v; + u = u << 1u; + v = v << 1u; + d = d + 1; + } + } else { + g_lo = u64_shr1(g_lo); + u = u << 1u; + v = v << 1u; + d = d + 1; + } + } + *delta = d; + return MatA(u, v, q, r); +} + +// ============================================================ +// bya_low_u64_lohi: low 64 bits of a 20 x 13-bit BigInt with canonical +// 13-bit limbs (the serial-carry apply_matrix output guarantees this). 
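+// Bit map: l0 -> bits 0..12, l1 -> 13..25, l2 -> 26..38 (straddling the
+// 32-bit boundary), l3 -> 39..51, l4 -> 52..63 (only its low 12 bits fit);
+// five 13-bit limbs (65 bits) fully cover the 64-bit window.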
+// ============================================================ +fn bya_low_u64_lohi(x: BigInt) -> vec2 { + let l0: u32 = x.limbs[0] & MASK; + let l1: u32 = x.limbs[1] & MASK; + let l2: u32 = x.limbs[2] & MASK; + let l3: u32 = x.limbs[3] & MASK; + let l4: u32 = x.limbs[4] & MASK; + let lo32: u32 = l0 | (l1 << 13u) | (l2 << 26u); + let hi32: u32 = (l2 >> 6u) | (l3 << 7u) | (l4 << 20u); + return vec2(lo32, hi32); +} + +// ============================================================ +// bya_normalise: carry-propagate so each limb in [0, N-1) is in +// [0, 2^13) canonical and the top limb absorbs the signed extension. +// Used by reduce_to_canonical at the END of fr_inv_by_a. +// ============================================================ +fn bya_normalise(x: ptr) { + var c: i32 = 0; + for (var i: u32 = 0u; i < {{ num_words }}u - 1u; i = i + 1u) { + let v = i32((*x).limbs[i]) + c; + (*x).limbs[i] = u32(v) & MASK; + c = v >> WORD_SIZE; + } + (*x).limbs[{{ num_words }}u - 1u] = u32(i32((*x).limbs[{{ num_words }}u - 1u]) + c) & MASK; +} + +// ============================================================ +// bya_apply_matrix_fg +// +// Compute (f_new, g_new) = ((u*f + v*g) >> 26, (q*f + r*g) >> 26). +// +// Matrix entry split: m = m_lo + m_hi * 2^13 where m_lo in [0, 2^13) +// (taken as low-13-bit unsigned) and m_hi in [-2^13, 2^13) (taken as +// arithmetic shift right of i32). The product is recovered as: +// m * x = m_lo * x + m_hi * x * 2^13 +// +// For each output position k in [0, 19], the raw value is +// nf[k] = u_lo*f[k+2] + v_lo*g[k+2] + u_hi*f[k+1] + v_hi*g[k+1] +// with the convention f[20] = f[21] = 0 (and same for g). The two +// dropped low product positions contribute a boundary carry into +// output 0 — see "boundary carry" comment below. +// +// Sign of f/g: limbs in [0, N-2] are non-negative (in [-2^15, 2^15) when +// non-canonical between iters); the top limb f[N-1] carries the signed +// extension of the full integer and is sign-extended via arithmetic +// shifts before multiplying. +// +// |nf[k]| <= 4 * (2^13 * 2^15) = 2^30 with non-canonical limbs, fits i32. +// ============================================================ +fn bya_apply_matrix_fg(m: MatA, f: ptr, g: ptr) { + // Matrix splits. _lo in [0, 2^13); _hi signed in [-2^13, 2^13). + let u_lo: i32 = i32(u32(m.u) & MASK); + let u_hi: i32 = m.u >> WORD_SIZE; + let v_lo: i32 = i32(u32(m.v) & MASK); + let v_hi: i32 = m.v >> WORD_SIZE; + let q_lo: i32 = i32(u32(m.q) & MASK); + let q_hi: i32 = m.q >> WORD_SIZE; + let r_lo: i32 = i32(u32(m.r) & MASK); + let r_hi: i32 = m.r >> WORD_SIZE; + + // Load all limbs into named locals to give the compiler a chance to + // hoist the loads above the multiply chain. 
+ let f0: i32 = i32((*f).limbs[0]); + let f1: i32 = i32((*f).limbs[1]); + let f2: i32 = i32((*f).limbs[2]); + let f3: i32 = i32((*f).limbs[3]); + let f4: i32 = i32((*f).limbs[4]); + let f5: i32 = i32((*f).limbs[5]); + let f6: i32 = i32((*f).limbs[6]); + let f7: i32 = i32((*f).limbs[7]); + let f8: i32 = i32((*f).limbs[8]); + let f9: i32 = i32((*f).limbs[9]); + let f10: i32 = i32((*f).limbs[10]); + let f11: i32 = i32((*f).limbs[11]); + let f12: i32 = i32((*f).limbs[12]); + let f13: i32 = i32((*f).limbs[13]); + let f14: i32 = i32((*f).limbs[14]); + let f15: i32 = i32((*f).limbs[15]); + let f16: i32 = i32((*f).limbs[16]); + let f17: i32 = i32((*f).limbs[17]); + let f18: i32 = i32((*f).limbs[18]); + // Sign-extension of the top limb (bit 12 is the sign bit for canonical + // input; for non-canonical input we still arithmetic-shift the full + // i32 — high bits already carry sign). + let f19_raw: u32 = (*f).limbs[19]; + let f19: i32 = (i32(f19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let g0: i32 = i32((*g).limbs[0]); + let g1: i32 = i32((*g).limbs[1]); + let g2: i32 = i32((*g).limbs[2]); + let g3: i32 = i32((*g).limbs[3]); + let g4: i32 = i32((*g).limbs[4]); + let g5: i32 = i32((*g).limbs[5]); + let g6: i32 = i32((*g).limbs[6]); + let g7: i32 = i32((*g).limbs[7]); + let g8: i32 = i32((*g).limbs[8]); + let g9: i32 = i32((*g).limbs[9]); + let g10: i32 = i32((*g).limbs[10]); + let g11: i32 = i32((*g).limbs[11]); + let g12: i32 = i32((*g).limbs[12]); + let g13: i32 = i32((*g).limbs[13]); + let g14: i32 = i32((*g).limbs[14]); + let g15: i32 = i32((*g).limbs[15]); + let g16: i32 = i32((*g).limbs[16]); + let g17: i32 = i32((*g).limbs[17]); + let g18: i32 = i32((*g).limbs[18]); + let g19_raw: u32 = (*g).limbs[19]; + let g19: i32 = (i32(g19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + // Boundary carry from the two dropped low product positions (positions + // 0 and 1). Carry-propagates as in a serial chain — the parallel-pass + // identity \`(A >> 26) + (B >> 13)\` is OFF BY 1 from + // \`((B + (A >> 13)) >> 13)\` in general because shifts don't distribute + // over addition. The boundary lands at output limb 0 below. + let rp0_f: i32 = u_lo * f0 + v_lo * g0; + let rp1_f: i32 = u_lo * f1 + v_lo * g1 + u_hi * f0 + v_hi * g0; + let boundary_f: i32 = (rp1_f + (rp0_f >> 13u)) >> 13u; + + let rp0_g: i32 = q_lo * f0 + r_lo * g0; + let rp1_g: i32 = q_lo * f1 + r_lo * g1 + q_hi * f0 + r_hi * g0; + let boundary_g: i32 = (rp1_g + (rp0_g >> 13u)) >> 13u; + + // MULTIPLY PHASE with shared partial products. Each individual product + // is used in EXACTLY ONE output slot — the names just help the GPU + // compiler pipeline issue muls + adds without re-reading the limb. 
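+  // Naming: ulfK = u_lo * fK, uhfK = u_hi * fK, vlgK = v_lo * gK,
+  // vhgK = v_hi * gK, with the q/r analogues feeding the g output row, so
+  // output slot k sums exactly the four products named in the header
+  // formula nf[k] = u_lo*f[k+2] + v_lo*g[k+2] + u_hi*f[k+1] + v_hi*g[k+1].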
+ let ulf2 = u_lo * f2; let ulf3 = u_lo * f3; let ulf4 = u_lo * f4; let ulf5 = u_lo * f5; + let ulf6 = u_lo * f6; let ulf7 = u_lo * f7; let ulf8 = u_lo * f8; let ulf9 = u_lo * f9; + let ulf10 = u_lo * f10; let ulf11 = u_lo * f11; let ulf12 = u_lo * f12; let ulf13 = u_lo * f13; + let ulf14 = u_lo * f14; let ulf15 = u_lo * f15; let ulf16 = u_lo * f16; let ulf17 = u_lo * f17; + let ulf18 = u_lo * f18; let ulf19 = u_lo * f19; + let uhf1 = u_hi * f1; let uhf2 = u_hi * f2; let uhf3 = u_hi * f3; let uhf4 = u_hi * f4; + let uhf5 = u_hi * f5; let uhf6 = u_hi * f6; let uhf7 = u_hi * f7; let uhf8 = u_hi * f8; + let uhf9 = u_hi * f9; let uhf10 = u_hi * f10; let uhf11 = u_hi * f11; let uhf12 = u_hi * f12; + let uhf13 = u_hi * f13; let uhf14 = u_hi * f14; let uhf15 = u_hi * f15; let uhf16 = u_hi * f16; + let uhf17 = u_hi * f17; let uhf18 = u_hi * f18; let uhf19 = u_hi * f19; + let vlg2 = v_lo * g2; let vlg3 = v_lo * g3; let vlg4 = v_lo * g4; let vlg5 = v_lo * g5; + let vlg6 = v_lo * g6; let vlg7 = v_lo * g7; let vlg8 = v_lo * g8; let vlg9 = v_lo * g9; + let vlg10 = v_lo * g10; let vlg11 = v_lo * g11; let vlg12 = v_lo * g12; let vlg13 = v_lo * g13; + let vlg14 = v_lo * g14; let vlg15 = v_lo * g15; let vlg16 = v_lo * g16; let vlg17 = v_lo * g17; + let vlg18 = v_lo * g18; let vlg19 = v_lo * g19; + let vhg1 = v_hi * g1; let vhg2 = v_hi * g2; let vhg3 = v_hi * g3; let vhg4 = v_hi * g4; + let vhg5 = v_hi * g5; let vhg6 = v_hi * g6; let vhg7 = v_hi * g7; let vhg8 = v_hi * g8; + let vhg9 = v_hi * g9; let vhg10 = v_hi * g10; let vhg11 = v_hi * g11; let vhg12 = v_hi * g12; + let vhg13 = v_hi * g13; let vhg14 = v_hi * g14; let vhg15 = v_hi * g15; let vhg16 = v_hi * g16; + let vhg17 = v_hi * g17; let vhg18 = v_hi * g18; let vhg19 = v_hi * g19; + + let qlf2 = q_lo * f2; let qlf3 = q_lo * f3; let qlf4 = q_lo * f4; let qlf5 = q_lo * f5; + let qlf6 = q_lo * f6; let qlf7 = q_lo * f7; let qlf8 = q_lo * f8; let qlf9 = q_lo * f9; + let qlf10 = q_lo * f10; let qlf11 = q_lo * f11; let qlf12 = q_lo * f12; let qlf13 = q_lo * f13; + let qlf14 = q_lo * f14; let qlf15 = q_lo * f15; let qlf16 = q_lo * f16; let qlf17 = q_lo * f17; + let qlf18 = q_lo * f18; let qlf19 = q_lo * f19; + let qhf1 = q_hi * f1; let qhf2 = q_hi * f2; let qhf3 = q_hi * f3; let qhf4 = q_hi * f4; + let qhf5 = q_hi * f5; let qhf6 = q_hi * f6; let qhf7 = q_hi * f7; let qhf8 = q_hi * f8; + let qhf9 = q_hi * f9; let qhf10 = q_hi * f10; let qhf11 = q_hi * f11; let qhf12 = q_hi * f12; + let qhf13 = q_hi * f13; let qhf14 = q_hi * f14; let qhf15 = q_hi * f15; let qhf16 = q_hi * f16; + let qhf17 = q_hi * f17; let qhf18 = q_hi * f18; let qhf19 = q_hi * f19; + let rlg2 = r_lo * g2; let rlg3 = r_lo * g3; let rlg4 = r_lo * g4; let rlg5 = r_lo * g5; + let rlg6 = r_lo * g6; let rlg7 = r_lo * g7; let rlg8 = r_lo * g8; let rlg9 = r_lo * g9; + let rlg10 = r_lo * g10; let rlg11 = r_lo * g11; let rlg12 = r_lo * g12; let rlg13 = r_lo * g13; + let rlg14 = r_lo * g14; let rlg15 = r_lo * g15; let rlg16 = r_lo * g16; let rlg17 = r_lo * g17; + let rlg18 = r_lo * g18; let rlg19 = r_lo * g19; + let rhg1 = r_hi * g1; let rhg2 = r_hi * g2; let rhg3 = r_hi * g3; let rhg4 = r_hi * g4; + let rhg5 = r_hi * g5; let rhg6 = r_hi * g6; let rhg7 = r_hi * g7; let rhg8 = r_hi * g8; + let rhg9 = r_hi * g9; let rhg10 = r_hi * g10; let rhg11 = r_hi * g11; let rhg12 = r_hi * g12; + let rhg13 = r_hi * g13; let rhg14 = r_hi * g14; let rhg15 = r_hi * g15; let rhg16 = r_hi * g16; + let rhg17 = r_hi * g17; let rhg18 = r_hi * g18; let rhg19 = r_hi * g19; + + let nf0: i32 = ulf2 + vlg2 + uhf1 + vhg1 + 
boundary_f; + let nf1: i32 = ulf3 + vlg3 + uhf2 + vhg2; + let nf2: i32 = ulf4 + vlg4 + uhf3 + vhg3; + let nf3: i32 = ulf5 + vlg5 + uhf4 + vhg4; + let nf4: i32 = ulf6 + vlg6 + uhf5 + vhg5; + let nf5: i32 = ulf7 + vlg7 + uhf6 + vhg6; + let nf6: i32 = ulf8 + vlg8 + uhf7 + vhg7; + let nf7: i32 = ulf9 + vlg9 + uhf8 + vhg8; + let nf8: i32 = ulf10 + vlg10 + uhf9 + vhg9; + let nf9: i32 = ulf11 + vlg11 + uhf10 + vhg10; + let nf10: i32 = ulf12 + vlg12 + uhf11 + vhg11; + let nf11: i32 = ulf13 + vlg13 + uhf12 + vhg12; + let nf12: i32 = ulf14 + vlg14 + uhf13 + vhg13; + let nf13: i32 = ulf15 + vlg15 + uhf14 + vhg14; + let nf14: i32 = ulf16 + vlg16 + uhf15 + vhg15; + let nf15: i32 = ulf17 + vlg17 + uhf16 + vhg16; + let nf16: i32 = ulf18 + vlg18 + uhf17 + vhg17; + let nf17: i32 = ulf19 + vlg19 + uhf18 + vhg18; + let nf18: i32 = uhf19 + vhg19; + + let ng0: i32 = qlf2 + rlg2 + qhf1 + rhg1 + boundary_g; + let ng1: i32 = qlf3 + rlg3 + qhf2 + rhg2; + let ng2: i32 = qlf4 + rlg4 + qhf3 + rhg3; + let ng3: i32 = qlf5 + rlg5 + qhf4 + rhg4; + let ng4: i32 = qlf6 + rlg6 + qhf5 + rhg5; + let ng5: i32 = qlf7 + rlg7 + qhf6 + rhg6; + let ng6: i32 = qlf8 + rlg8 + qhf7 + rhg7; + let ng7: i32 = qlf9 + rlg9 + qhf8 + rhg8; + let ng8: i32 = qlf10 + rlg10 + qhf9 + rhg9; + let ng9: i32 = qlf11 + rlg11 + qhf10 + rhg10; + let ng10: i32 = qlf12 + rlg12 + qhf11 + rhg11; + let ng11: i32 = qlf13 + rlg13 + qhf12 + rhg12; + let ng12: i32 = qlf14 + rlg14 + qhf13 + rhg13; + let ng13: i32 = qlf15 + rlg15 + qhf14 + rhg14; + let ng14: i32 = qlf16 + rlg16 + qhf15 + rhg15; + let ng15: i32 = qlf17 + rlg17 + qhf16 + rhg16; + let ng16: i32 = qlf18 + rlg18 + qhf17 + rhg17; + let ng17: i32 = qlf19 + rlg19 + qhf18 + rhg18; + let ng18: i32 = qhf19 + rhg19; + + // SERIAL CARRY PASS — empirically faster than 2-pass parallel on this + // GPU (the carry chain is short enough that scheduler latency dominates + // any pipelining advantage). 
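+  // Carry-bound sketch: each raw slot is bounded by roughly 2^30 (see the
+  // header), so with |carry in| < 2^18 the running sum stays below 2^31
+  // (fits i32) and the next carry, an arithmetic shift by 13, is again
+  // below 2^18 in magnitude; the whole chain is therefore i32-safe.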
+ var cf: i32 = 0; + let vf_0: i32 = nf0 + cf; (*f).limbs[0] = u32(vf_0) & MASK; cf = vf_0 >> 13u; + let vf_1: i32 = nf1 + cf; (*f).limbs[1] = u32(vf_1) & MASK; cf = vf_1 >> 13u; + let vf_2: i32 = nf2 + cf; (*f).limbs[2] = u32(vf_2) & MASK; cf = vf_2 >> 13u; + let vf_3: i32 = nf3 + cf; (*f).limbs[3] = u32(vf_3) & MASK; cf = vf_3 >> 13u; + let vf_4: i32 = nf4 + cf; (*f).limbs[4] = u32(vf_4) & MASK; cf = vf_4 >> 13u; + let vf_5: i32 = nf5 + cf; (*f).limbs[5] = u32(vf_5) & MASK; cf = vf_5 >> 13u; + let vf_6: i32 = nf6 + cf; (*f).limbs[6] = u32(vf_6) & MASK; cf = vf_6 >> 13u; + let vf_7: i32 = nf7 + cf; (*f).limbs[7] = u32(vf_7) & MASK; cf = vf_7 >> 13u; + let vf_8: i32 = nf8 + cf; (*f).limbs[8] = u32(vf_8) & MASK; cf = vf_8 >> 13u; + let vf_9: i32 = nf9 + cf; (*f).limbs[9] = u32(vf_9) & MASK; cf = vf_9 >> 13u; + let vf_10: i32 = nf10 + cf; (*f).limbs[10] = u32(vf_10) & MASK; cf = vf_10 >> 13u; + let vf_11: i32 = nf11 + cf; (*f).limbs[11] = u32(vf_11) & MASK; cf = vf_11 >> 13u; + let vf_12: i32 = nf12 + cf; (*f).limbs[12] = u32(vf_12) & MASK; cf = vf_12 >> 13u; + let vf_13: i32 = nf13 + cf; (*f).limbs[13] = u32(vf_13) & MASK; cf = vf_13 >> 13u; + let vf_14: i32 = nf14 + cf; (*f).limbs[14] = u32(vf_14) & MASK; cf = vf_14 >> 13u; + let vf_15: i32 = nf15 + cf; (*f).limbs[15] = u32(vf_15) & MASK; cf = vf_15 >> 13u; + let vf_16: i32 = nf16 + cf; (*f).limbs[16] = u32(vf_16) & MASK; cf = vf_16 >> 13u; + let vf_17: i32 = nf17 + cf; (*f).limbs[17] = u32(vf_17) & MASK; cf = vf_17 >> 13u; + let vf_18: i32 = nf18 + cf; (*f).limbs[18] = u32(vf_18) & MASK; cf = vf_18 >> 13u; + (*f).limbs[19] = u32(cf); + + var cg: i32 = 0; + let vg_0: i32 = ng0 + cg; (*g).limbs[0] = u32(vg_0) & MASK; cg = vg_0 >> 13u; + let vg_1: i32 = ng1 + cg; (*g).limbs[1] = u32(vg_1) & MASK; cg = vg_1 >> 13u; + let vg_2: i32 = ng2 + cg; (*g).limbs[2] = u32(vg_2) & MASK; cg = vg_2 >> 13u; + let vg_3: i32 = ng3 + cg; (*g).limbs[3] = u32(vg_3) & MASK; cg = vg_3 >> 13u; + let vg_4: i32 = ng4 + cg; (*g).limbs[4] = u32(vg_4) & MASK; cg = vg_4 >> 13u; + let vg_5: i32 = ng5 + cg; (*g).limbs[5] = u32(vg_5) & MASK; cg = vg_5 >> 13u; + let vg_6: i32 = ng6 + cg; (*g).limbs[6] = u32(vg_6) & MASK; cg = vg_6 >> 13u; + let vg_7: i32 = ng7 + cg; (*g).limbs[7] = u32(vg_7) & MASK; cg = vg_7 >> 13u; + let vg_8: i32 = ng8 + cg; (*g).limbs[8] = u32(vg_8) & MASK; cg = vg_8 >> 13u; + let vg_9: i32 = ng9 + cg; (*g).limbs[9] = u32(vg_9) & MASK; cg = vg_9 >> 13u; + let vg_10: i32 = ng10 + cg; (*g).limbs[10] = u32(vg_10) & MASK; cg = vg_10 >> 13u; + let vg_11: i32 = ng11 + cg; (*g).limbs[11] = u32(vg_11) & MASK; cg = vg_11 >> 13u; + let vg_12: i32 = ng12 + cg; (*g).limbs[12] = u32(vg_12) & MASK; cg = vg_12 >> 13u; + let vg_13: i32 = ng13 + cg; (*g).limbs[13] = u32(vg_13) & MASK; cg = vg_13 >> 13u; + let vg_14: i32 = ng14 + cg; (*g).limbs[14] = u32(vg_14) & MASK; cg = vg_14 >> 13u; + let vg_15: i32 = ng15 + cg; (*g).limbs[15] = u32(vg_15) & MASK; cg = vg_15 >> 13u; + let vg_16: i32 = ng16 + cg; (*g).limbs[16] = u32(vg_16) & MASK; cg = vg_16 >> 13u; + let vg_17: i32 = ng17 + cg; (*g).limbs[17] = u32(vg_17) & MASK; cg = vg_17 >> 13u; + let vg_18: i32 = ng18 + cg; (*g).limbs[18] = u32(vg_18) & MASK; cg = vg_18 >> 13u; + (*g).limbs[19] = u32(cg); +} + +// ============================================================ +// bya_apply_matrix_de — same shape as fg, plus k_d/k_e * p folded in. +// +// k_d, k_e are chosen so the low 26 bits of (u*d + v*e), (q*d + r*e) +// cancel mod p. The "low 26" reconstruction uses the same two-limb +// pre-compute as before. 
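+// Only the low BYA_BATCH = 26 bits of p^(-1) matter here, since the k_d /
+// k_e products are masked with MASK_BATCH = 2^26 - 1; that is why this
+// variant takes a single u32 p_inv_lo rather than the 58-bit lo/hi split
+// used by the 9 x 29-bit variant (p_inv_by_a_lo is assumed to be generated
+// accordingly).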
+// +// |nd[k]| <= 6 * (2^13 * 2^15) = 3 * 2^29 ≈ 2^30 — fits i32 with margin. +// ============================================================ +fn bya_apply_matrix_de( + m: MatA, + d: ptr, + e: ptr, + p: ptr, + p_inv_lo: u32, +) { + let u_lo: i32 = i32(u32(m.u) & MASK); + let u_hi: i32 = m.u >> WORD_SIZE; + let v_lo: i32 = i32(u32(m.v) & MASK); + let v_hi: i32 = m.v >> WORD_SIZE; + let q_lo: i32 = i32(u32(m.q) & MASK); + let q_hi: i32 = m.q >> WORD_SIZE; + let r_lo: i32 = i32(u32(m.r) & MASK); + let r_hi: i32 = m.r >> WORD_SIZE; + + // Load all limbs into named locals. + let d0: i32 = i32((*d).limbs[0]); + let d1: i32 = i32((*d).limbs[1]); + let d2: i32 = i32((*d).limbs[2]); + let d3: i32 = i32((*d).limbs[3]); + let d4: i32 = i32((*d).limbs[4]); + let d5: i32 = i32((*d).limbs[5]); + let d6: i32 = i32((*d).limbs[6]); + let d7: i32 = i32((*d).limbs[7]); + let d8: i32 = i32((*d).limbs[8]); + let d9: i32 = i32((*d).limbs[9]); + let d10: i32 = i32((*d).limbs[10]); + let d11: i32 = i32((*d).limbs[11]); + let d12: i32 = i32((*d).limbs[12]); + let d13: i32 = i32((*d).limbs[13]); + let d14: i32 = i32((*d).limbs[14]); + let d15: i32 = i32((*d).limbs[15]); + let d16: i32 = i32((*d).limbs[16]); + let d17: i32 = i32((*d).limbs[17]); + let d18: i32 = i32((*d).limbs[18]); + let d19_raw: u32 = (*d).limbs[19]; + let d19: i32 = (i32(d19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let e0: i32 = i32((*e).limbs[0]); + let e1: i32 = i32((*e).limbs[1]); + let e2: i32 = i32((*e).limbs[2]); + let e3: i32 = i32((*e).limbs[3]); + let e4: i32 = i32((*e).limbs[4]); + let e5: i32 = i32((*e).limbs[5]); + let e6: i32 = i32((*e).limbs[6]); + let e7: i32 = i32((*e).limbs[7]); + let e8: i32 = i32((*e).limbs[8]); + let e9: i32 = i32((*e).limbs[9]); + let e10: i32 = i32((*e).limbs[10]); + let e11: i32 = i32((*e).limbs[11]); + let e12: i32 = i32((*e).limbs[12]); + let e13: i32 = i32((*e).limbs[13]); + let e14: i32 = i32((*e).limbs[14]); + let e15: i32 = i32((*e).limbs[15]); + let e16: i32 = i32((*e).limbs[16]); + let e17: i32 = i32((*e).limbs[17]); + let e18: i32 = i32((*e).limbs[18]); + let e19_raw: u32 = (*e).limbs[19]; + let e19: i32 = (i32(e19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let p0: i32 = i32((*p).limbs[0]); + let p1: i32 = i32((*p).limbs[1]); + let p2: i32 = i32((*p).limbs[2]); + let p3: i32 = i32((*p).limbs[3]); + let p4: i32 = i32((*p).limbs[4]); + let p5: i32 = i32((*p).limbs[5]); + let p6: i32 = i32((*p).limbs[6]); + let p7: i32 = i32((*p).limbs[7]); + let p8: i32 = i32((*p).limbs[8]); + let p9: i32 = i32((*p).limbs[9]); + let p10: i32 = i32((*p).limbs[10]); + let p11: i32 = i32((*p).limbs[11]); + let p12: i32 = i32((*p).limbs[12]); + let p13: i32 = i32((*p).limbs[13]); + let p14: i32 = i32((*p).limbs[14]); + let p15: i32 = i32((*p).limbs[15]); + let p16: i32 = i32((*p).limbs[16]); + let p17: i32 = i32((*p).limbs[17]); + let p18: i32 = i32((*p).limbs[18]); + let p19: i32 = i32((*p).limbs[19]); + + // === Step 1: m-trick. Compute low 26 bits of (u*d + v*e), (q*d + r*e) + // to derive k_d, k_e so the result is divisible by 2^26. 
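+  // Sketch of why the low bits vanish (assuming p_inv_lo == p^(-1) mod 2^26):
+  // with k_d = (-t_d) * p_inv_lo mod 2^26 we get
+  //   t_d + k_d * p == t_d - t_d == 0 (mod 2^26),
+  // so positions 0 and 1 of the full product contribute only the shift-out
+  // captured in boundary_d / boundary_e, and dropping them divides by 2^26
+  // exactly.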
+ let nd0_pre: i32 = u_lo * d0 + v_lo * e0; + let nd1_pre: i32 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0; + let ne0_pre: i32 = q_lo * d0 + r_lo * e0; + let ne1_pre: i32 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0; + + let nd1_full: i32 = nd1_pre + (nd0_pre >> 13u); + let ne1_full: i32 = ne1_pre + (ne0_pre >> 13u); + let td_low26: u32 = (u32(nd0_pre) & MASK) | ((u32(nd1_full) & MASK) << 13u); + let te_low26: u32 = (u32(ne0_pre) & MASK) | ((u32(ne1_full) & MASK) << 13u); + + let MASK_BATCH: u32 = (1u << BYA_BATCH) - 1u; + let neg_td: u32 = (~td_low26 + 1u) & MASK_BATCH; + let neg_te: u32 = (~te_low26 + 1u) & MASK_BATCH; + let kd_full: u32 = (neg_td * p_inv_lo) & MASK_BATCH; + let ke_full: u32 = (neg_te * p_inv_lo) & MASK_BATCH; + + let kd_lo: i32 = i32(kd_full & MASK); + let kd_hi: i32 = i32(kd_full >> WORD_SIZE); + let ke_lo: i32 = i32(ke_full & MASK); + let ke_hi: i32 = i32(ke_full >> WORD_SIZE); + + // Boundary carry from positions 0, 1 of the full product. After + // m-trick, the low 26 bits ARE zero, so boundary is exactly the + // shift-out from positions 0 and 1. + let rp0_d: i32 = nd0_pre + kd_lo * p0; + let rp1_d: i32 = nd1_pre + kd_lo * p1 + kd_hi * p0; + let boundary_d: i32 = (rp1_d + (rp0_d >> 13u)) >> 13u; + + let rp0_e: i32 = ne0_pre + ke_lo * p0; + let rp1_e: i32 = ne1_pre + ke_lo * p1 + ke_hi * p0; + let boundary_e: i32 = (rp1_e + (rp0_e >> 13u)) >> 13u; + + // PARALLEL MULTIPLY PHASE. + // raw_nd[k] = u_lo*d[k+2] + v_lo*e[k+2] + u_hi*d[k+1] + v_hi*e[k+1] + // + kd_lo*p[k+2] + kd_hi*p[k+1] + let nd0: i32 = u_lo * d2 + v_lo * e2 + u_hi * d1 + v_hi * e1 + kd_lo * p2 + kd_hi * p1 + boundary_d; + let nd1: i32 = u_lo * d3 + v_lo * e3 + u_hi * d2 + v_hi * e2 + kd_lo * p3 + kd_hi * p2; + let nd2: i32 = u_lo * d4 + v_lo * e4 + u_hi * d3 + v_hi * e3 + kd_lo * p4 + kd_hi * p3; + let nd3: i32 = u_lo * d5 + v_lo * e5 + u_hi * d4 + v_hi * e4 + kd_lo * p5 + kd_hi * p4; + let nd4: i32 = u_lo * d6 + v_lo * e6 + u_hi * d5 + v_hi * e5 + kd_lo * p6 + kd_hi * p5; + let nd5: i32 = u_lo * d7 + v_lo * e7 + u_hi * d6 + v_hi * e6 + kd_lo * p7 + kd_hi * p6; + let nd6: i32 = u_lo * d8 + v_lo * e8 + u_hi * d7 + v_hi * e7 + kd_lo * p8 + kd_hi * p7; + let nd7: i32 = u_lo * d9 + v_lo * e9 + u_hi * d8 + v_hi * e8 + kd_lo * p9 + kd_hi * p8; + let nd8: i32 = u_lo * d10 + v_lo * e10 + u_hi * d9 + v_hi * e9 + kd_lo * p10 + kd_hi * p9; + let nd9: i32 = u_lo * d11 + v_lo * e11 + u_hi * d10 + v_hi * e10 + kd_lo * p11 + kd_hi * p10; + let nd10: i32 = u_lo * d12 + v_lo * e12 + u_hi * d11 + v_hi * e11 + kd_lo * p12 + kd_hi * p11; + let nd11: i32 = u_lo * d13 + v_lo * e13 + u_hi * d12 + v_hi * e12 + kd_lo * p13 + kd_hi * p12; + let nd12: i32 = u_lo * d14 + v_lo * e14 + u_hi * d13 + v_hi * e13 + kd_lo * p14 + kd_hi * p13; + let nd13: i32 = u_lo * d15 + v_lo * e15 + u_hi * d14 + v_hi * e14 + kd_lo * p15 + kd_hi * p14; + let nd14: i32 = u_lo * d16 + v_lo * e16 + u_hi * d15 + v_hi * e15 + kd_lo * p16 + kd_hi * p15; + let nd15: i32 = u_lo * d17 + v_lo * e17 + u_hi * d16 + v_hi * e16 + kd_lo * p17 + kd_hi * p16; + let nd16: i32 = u_lo * d18 + v_lo * e18 + u_hi * d17 + v_hi * e17 + kd_lo * p18 + kd_hi * p17; + let nd17: i32 = u_lo * d19 + v_lo * e19 + u_hi * d18 + v_hi * e18 + kd_lo * p19 + kd_hi * p18; + let nd18: i32 = u_hi * d19 + v_hi * e19 + kd_hi * p19; + + let ne0: i32 = q_lo * d2 + r_lo * e2 + q_hi * d1 + r_hi * e1 + ke_lo * p2 + ke_hi * p1 + boundary_e; + let ne1: i32 = q_lo * d3 + r_lo * e3 + q_hi * d2 + r_hi * e2 + ke_lo * p3 + ke_hi * p2; + let ne2: i32 = q_lo * d4 + r_lo * e4 + q_hi * d3 + r_hi * e3 + 
ke_lo * p4 + ke_hi * p3; + let ne3: i32 = q_lo * d5 + r_lo * e5 + q_hi * d4 + r_hi * e4 + ke_lo * p5 + ke_hi * p4; + let ne4: i32 = q_lo * d6 + r_lo * e6 + q_hi * d5 + r_hi * e5 + ke_lo * p6 + ke_hi * p5; + let ne5: i32 = q_lo * d7 + r_lo * e7 + q_hi * d6 + r_hi * e6 + ke_lo * p7 + ke_hi * p6; + let ne6: i32 = q_lo * d8 + r_lo * e8 + q_hi * d7 + r_hi * e7 + ke_lo * p8 + ke_hi * p7; + let ne7: i32 = q_lo * d9 + r_lo * e9 + q_hi * d8 + r_hi * e8 + ke_lo * p9 + ke_hi * p8; + let ne8: i32 = q_lo * d10 + r_lo * e10 + q_hi * d9 + r_hi * e9 + ke_lo * p10 + ke_hi * p9; + let ne9: i32 = q_lo * d11 + r_lo * e11 + q_hi * d10 + r_hi * e10 + ke_lo * p11 + ke_hi * p10; + let ne10: i32 = q_lo * d12 + r_lo * e12 + q_hi * d11 + r_hi * e11 + ke_lo * p12 + ke_hi * p11; + let ne11: i32 = q_lo * d13 + r_lo * e13 + q_hi * d12 + r_hi * e12 + ke_lo * p13 + ke_hi * p12; + let ne12: i32 = q_lo * d14 + r_lo * e14 + q_hi * d13 + r_hi * e13 + ke_lo * p14 + ke_hi * p13; + let ne13: i32 = q_lo * d15 + r_lo * e15 + q_hi * d14 + r_hi * e14 + ke_lo * p15 + ke_hi * p14; + let ne14: i32 = q_lo * d16 + r_lo * e16 + q_hi * d15 + r_hi * e15 + ke_lo * p16 + ke_hi * p15; + let ne15: i32 = q_lo * d17 + r_lo * e17 + q_hi * d16 + r_hi * e16 + ke_lo * p17 + ke_hi * p16; + let ne16: i32 = q_lo * d18 + r_lo * e18 + q_hi * d17 + r_hi * e17 + ke_lo * p18 + ke_hi * p17; + let ne17: i32 = q_lo * d19 + r_lo * e19 + q_hi * d18 + r_hi * e18 + ke_lo * p19 + ke_hi * p18; + let ne18: i32 = q_hi * d19 + r_hi * e19 + ke_hi * p19; + + // SERIAL CARRY PASS. + var cd: i32 = 0; + let vd_0: i32 = nd0 + cd; (*d).limbs[0] = u32(vd_0) & MASK; cd = vd_0 >> 13u; + let vd_1: i32 = nd1 + cd; (*d).limbs[1] = u32(vd_1) & MASK; cd = vd_1 >> 13u; + let vd_2: i32 = nd2 + cd; (*d).limbs[2] = u32(vd_2) & MASK; cd = vd_2 >> 13u; + let vd_3: i32 = nd3 + cd; (*d).limbs[3] = u32(vd_3) & MASK; cd = vd_3 >> 13u; + let vd_4: i32 = nd4 + cd; (*d).limbs[4] = u32(vd_4) & MASK; cd = vd_4 >> 13u; + let vd_5: i32 = nd5 + cd; (*d).limbs[5] = u32(vd_5) & MASK; cd = vd_5 >> 13u; + let vd_6: i32 = nd6 + cd; (*d).limbs[6] = u32(vd_6) & MASK; cd = vd_6 >> 13u; + let vd_7: i32 = nd7 + cd; (*d).limbs[7] = u32(vd_7) & MASK; cd = vd_7 >> 13u; + let vd_8: i32 = nd8 + cd; (*d).limbs[8] = u32(vd_8) & MASK; cd = vd_8 >> 13u; + let vd_9: i32 = nd9 + cd; (*d).limbs[9] = u32(vd_9) & MASK; cd = vd_9 >> 13u; + let vd_10: i32 = nd10 + cd; (*d).limbs[10] = u32(vd_10) & MASK; cd = vd_10 >> 13u; + let vd_11: i32 = nd11 + cd; (*d).limbs[11] = u32(vd_11) & MASK; cd = vd_11 >> 13u; + let vd_12: i32 = nd12 + cd; (*d).limbs[12] = u32(vd_12) & MASK; cd = vd_12 >> 13u; + let vd_13: i32 = nd13 + cd; (*d).limbs[13] = u32(vd_13) & MASK; cd = vd_13 >> 13u; + let vd_14: i32 = nd14 + cd; (*d).limbs[14] = u32(vd_14) & MASK; cd = vd_14 >> 13u; + let vd_15: i32 = nd15 + cd; (*d).limbs[15] = u32(vd_15) & MASK; cd = vd_15 >> 13u; + let vd_16: i32 = nd16 + cd; (*d).limbs[16] = u32(vd_16) & MASK; cd = vd_16 >> 13u; + let vd_17: i32 = nd17 + cd; (*d).limbs[17] = u32(vd_17) & MASK; cd = vd_17 >> 13u; + let vd_18: i32 = nd18 + cd; (*d).limbs[18] = u32(vd_18) & MASK; cd = vd_18 >> 13u; + (*d).limbs[19] = u32(cd); + + var ce: i32 = 0; + let ve_0: i32 = ne0 + ce; (*e).limbs[0] = u32(ve_0) & MASK; ce = ve_0 >> 13u; + let ve_1: i32 = ne1 + ce; (*e).limbs[1] = u32(ve_1) & MASK; ce = ve_1 >> 13u; + let ve_2: i32 = ne2 + ce; (*e).limbs[2] = u32(ve_2) & MASK; ce = ve_2 >> 13u; + let ve_3: i32 = ne3 + ce; (*e).limbs[3] = u32(ve_3) & MASK; ce = ve_3 >> 13u; + let ve_4: i32 = ne4 + ce; (*e).limbs[4] = u32(ve_4) & MASK; ce = ve_4 >> 
13u; + let ve_5: i32 = ne5 + ce; (*e).limbs[5] = u32(ve_5) & MASK; ce = ve_5 >> 13u; + let ve_6: i32 = ne6 + ce; (*e).limbs[6] = u32(ve_6) & MASK; ce = ve_6 >> 13u; + let ve_7: i32 = ne7 + ce; (*e).limbs[7] = u32(ve_7) & MASK; ce = ve_7 >> 13u; + let ve_8: i32 = ne8 + ce; (*e).limbs[8] = u32(ve_8) & MASK; ce = ve_8 >> 13u; + let ve_9: i32 = ne9 + ce; (*e).limbs[9] = u32(ve_9) & MASK; ce = ve_9 >> 13u; + let ve_10: i32 = ne10 + ce; (*e).limbs[10] = u32(ve_10) & MASK; ce = ve_10 >> 13u; + let ve_11: i32 = ne11 + ce; (*e).limbs[11] = u32(ve_11) & MASK; ce = ve_11 >> 13u; + let ve_12: i32 = ne12 + ce; (*e).limbs[12] = u32(ve_12) & MASK; ce = ve_12 >> 13u; + let ve_13: i32 = ne13 + ce; (*e).limbs[13] = u32(ve_13) & MASK; ce = ve_13 >> 13u; + let ve_14: i32 = ne14 + ce; (*e).limbs[14] = u32(ve_14) & MASK; ce = ve_14 >> 13u; + let ve_15: i32 = ne15 + ce; (*e).limbs[15] = u32(ve_15) & MASK; ce = ve_15 >> 13u; + let ve_16: i32 = ne16 + ce; (*e).limbs[16] = u32(ve_16) & MASK; ce = ve_16 >> 13u; + let ve_17: i32 = ne17 + ce; (*e).limbs[17] = u32(ve_17) & MASK; ce = ve_17 >> 13u; + let ve_18: i32 = ne18 + ce; (*e).limbs[18] = u32(ve_18) & MASK; ce = ve_18 >> 13u; + (*e).limbs[19] = u32(ce); +} + +// ============================================================ +// Driver helpers +// ============================================================ + +fn bya_is_zero(x: ptr) -> bool { + var a: u32 = 0u; + for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) { + a = a | (*x).limbs[i]; + } + return a == 0u; +} + +fn bya_neg_inplace(x: ptr) { + for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) { + (*x).limbs[i] = u32(-i32((*x).limbs[i])); + } + bya_normalise(x); +} + +fn bya_add_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) { + (*x).limbs[i] = u32(i32((*x).limbs[i]) + i32((*p).limbs[i])); + } + bya_normalise(x); +} + +fn bya_sub_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) { + (*x).limbs[i] = u32(i32((*x).limbs[i]) - i32((*p).limbs[i])); + } + bya_normalise(x); +} + +fn bya_reduce_to_canonical(x: ptr, p: ptr) { + bya_normalise(x); + var done: bool = false; + for (var it: u32 = 0u; it < BYA_RTC_MAX_ITERS; it = it + 1u) { + if (done) { continue; } + if (bigint_is_neg_2c(x)) { + bya_add_p_inplace(x, p); + } else if (bigint_gte(x, p)) { + bya_sub_p_inplace(x, p); + } else { + done = true; + } + } +} + +const FR_INV_BY_A_P_INV_LO: u32 = {{ p_inv_by_a_lo }}u; + +// fr_inv_by_a: Bernstein-Yang safegcd inverse driver, BATCH=26 / NUM_OUTER=29 +// on the 20 x 13-bit BigInt representation. Tight inline-mul apply_matrix. 
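+// As with fr_inv_by above, the raw safegcd output d equals A^(-1) * R^(-1)
+// mod p for a Montgomery-form input A * R, so the final
+// montgomery_product(inv_native, R^3) lands back at A^(-1) * R, i.e. the
+// inverse in Montgomery form.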
+fn fr_inv_by_a(a: BigInt) -> BigInt { + var p_loc: BigInt = get_p(); + var f: BigInt = get_p(); + var g: BigInt = a; + + var d: BigInt; + var e: BigInt; + for (var k: u32 = 0u; k < {{ num_words }}u; k = k + 1u) { + d.limbs[k] = 0u; + e.limbs[k] = 0u; + } + e.limbs[0] = 1u; + + var delta: i32 = 1; + var done: bool = false; + for (var iter: u32 = 0u; iter < BYA_NUM_OUTER; iter = iter + 1u) { + if (done) { continue; } + let f_lo: vec2 = bya_low_u64_lohi(f); + let g_lo: vec2 = bya_low_u64_lohi(g); + let m: MatA = bya_divsteps(&delta, f_lo, g_lo); + bya_apply_matrix_fg(m, &f, &g); + bya_apply_matrix_de(m, &d, &e, &p_loc, FR_INV_BY_A_P_INV_LO); + if (((iter + 1u) % BYA_REDUCE_INTERVAL) == 0u) { + bya_reduce_to_canonical(&d, &p_loc); + bya_reduce_to_canonical(&e, &p_loc); + } + if (bya_is_zero(&g)) { + done = true; + } + } + + bya_reduce_to_canonical(&d, &p_loc); + if (bigint_is_neg_2c(&f)) { + bya_neg_inplace(&d); + bya_reduce_to_canonical(&d, &p_loc); + } + + var inv_native: BigInt = d; + var r_cubed: BigInt = get_r_cubed(); + return montgomery_product(&inv_native, &r_cubed); +} +`; + export const field = `fn fr_add(a: ptr, b: ptr) -> BigInt { var res: BigInt; bigint_add(a, b, &res); @@ -4492,6 +7160,28 @@ fn fr_pow(base: BigInt, exp: BigInt) -> BigInt { return result; } +// (p - 2) as a plain (non-Montgomery) BigInt. Used by fr_pow_inv as the +// exponent in Fermat's little theorem: a^(p-2) ≡ a^(-1) (mod p). +fn get_p_minus_2() -> BigInt { + var e: BigInt; +{{{ p_minus_2_limbs }}} + return e; +} + +// Field inversion via Fermat's little theorem: a^(-1) ≡ a^(p-2) (mod p). +// Both input and output are in Montgomery form. Since \`fr_pow\` preserves +// Montgomery form (Mont(base)^exp -> Mont(base^exp)), the result of +// \`fr_pow(Mont(a), p-2)\` is directly Mont(a^(-1)). No extra correction. +// +// Cost: ~254 squarings + ~127 expected multiplies (half the bits of p-2 +// are set), ≈ 381 montgomery_products per call. Compare to fr_inv's +// jumpy K=12 safegcd which converges in ~62 outer iters with ~10 +// BigInt-ops each plus ONE montgomery_product at the end. +fn fr_pow_inv(a: BigInt) -> BigInt { + var exp: BigInt = get_p_minus_2(); + return fr_pow(a, exp); +} + // R^3 mod p. Used by \`fr_inv\` to convert the binary-GCD output (which is // in native form, pre-multiplied by R^(-1) because the input was in // Montgomery form) back into Montgomery form via a single diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl new file mode 100644 index 000000000000..341e174fb8cd --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl @@ -0,0 +1,341 @@ +// WGSL bigint helpers for the Bernstein-Yang safegcd inversion port +// (sub-step 1.2 of the WebGPU MSM rewrite plan). The semantics here MUST +// match the TS reference at cuzk/bernstein_yang.ts exactly so divsteps / +// apply_matrix can be transliterated against this file in sub-step 1.3. +// +// REPRESENTATIONS +// - BigIntBY: 9 limbs of signed-29-bit chunks (i32). Top limb carries +// sign; lower limbs in [0, 2^29) post-normalise. Matches Wasm9x29. +// - u64 as vec2: value = .x + (.y << 32). All u64 ops use this. +// - i64 as vec2: value = u32(.x) + (i32(.y) << 32). The low half +// carries the raw bit pattern (treated as unsigned), the high half +// is signed. Used for matrix entries u/v/q/r which grow to ±2^58. +// +// LOOP BOUND DISCIPLINE +// Every `for` loop in this file uses a const upper bound (BY_NUM_LIMBS). 
+// The WGSL bounded-loop audit is `grep 'for ' ...`, run after Mustache +// render — every match has `< BY_NUM_LIMBS` or `< {{ num_words }}`. + +const BY_NUM_LIMBS: u32 = 9u; +const BY_LIMB_BITS: u32 = 29u; +const BY_LIMB_MASK: u32 = (1u << 29u) - 1u; +const BY_BATCH: u32 = 58u; +const BY_NUM_OUTER: u32 = 13u; +const BY_REDUCE_INTERVAL: u32 = 4u; +const BY_RTC_MAX_ITERS: u32 = 36u; + +// 29-bit signed limb wrapped as i32. Lower limbs canonical in [0, 2^29); +// top limb is the signed extension carrier. +struct BigIntBY { + l: array, +} + +// ============================================================ +// u64 arithmetic as vec2 (x = .x + (.y << 32)) +// ============================================================ + +// u64 add. Pre: any a, b. Post: low 64 bits of a + b. +fn u64_add(a: vec2, b: vec2) -> vec2 { + let lo = a.x + b.x; + let carry = select(0u, 1u, lo < a.x); + let hi = a.y + b.y + carry; + return vec2(lo, hi); +} + +// u64 sub. Pre: any a, b. Post: low 64 bits of (a - b) mod 2^64. +fn u64_sub(a: vec2, b: vec2) -> vec2 { + let lo = a.x - b.x; + let borrow = select(0u, 1u, a.x < b.x); + let hi = a.y - b.y - borrow; + return vec2(lo, hi); +} + +// u64 logical right shift by 1. Pre: any x. Post: x >> 1 unsigned. +fn u64_shr1(x: vec2) -> vec2 { + let lo = (x.x >> 1u) | ((x.y & 1u) << 31u); + let hi = x.y >> 1u; + return vec2(lo, hi); +} + +// Bit 0 of a u64. Pre: any x. Post: 0 or 1. +fn u64_low_bit(x: vec2) -> u32 { + return x.x & 1u; +} + +// u64 two's-complement negate. Pre: any x. Post: (-x) mod 2^64. +fn u64_neg(x: vec2) -> vec2 { + let nx = vec2(~x.x, ~x.y); + return u64_add(nx, vec2(1u, 0u)); +} + +// u64 bitwise AND. Pre: any a, mask. Post: a & mask, limbwise. +fn u64_and(a: vec2, mask: vec2) -> vec2 { + return vec2(a.x & mask.x, a.y & mask.y); +} + +// ============================================================ +// i64 signed arithmetic as vec2 (lo bits, hi signed) +// value = u32(.x) + (i32(.y) << 32) treated as signed two's-complement. +// ============================================================ + +// i64 add. Pre: |a|, |b| < 2^63. Post: low 64 bits of a + b, signed. +fn i64_add_pair(a: vec2, b: vec2) -> vec2 { + let au = u32(a.x); + let bu = u32(b.x); + let lo = au + bu; + let carry = select(0u, 1u, lo < au); + let hi = a.y + b.y + i32(carry); + return vec2(i32(lo), hi); +} + +// i64 sub. Pre: |a|, |b| < 2^63. Post: low 64 bits of a - b, signed. +fn i64_sub_pair(a: vec2, b: vec2) -> vec2 { + let au = u32(a.x); + let bu = u32(b.x); + let lo = au - bu; + let borrow = select(0u, 1u, au < bu); + let hi = a.y - b.y - i32(borrow); + return vec2(i32(lo), hi); +} + +// i64 left shift by 1. Pre: a in (-2^62, 2^62). Post: a << 1, signed. +fn i64_shl1_pair(a: vec2) -> vec2 { + let au = u32(a.x); + let lo = au << 1u; + let bit31 = au >> 31u; + // Arithmetic shift left of hi half: bring bottom carry bit up. + let hi_u = (u32(a.y) << 1u) | bit31; + return vec2(i32(lo), i32(hi_u)); +} + +// i64 two's-complement negate. Pre: any a (low 64 bits well-defined). +// Post: (-a) mod 2^64 as a signed pair. +fn i64_neg_pair(a: vec2) -> vec2 { + let nlo = ~u32(a.x); + let nhi = ~u32(a.y); + let lo = nlo + 1u; + let carry = select(0u, 1u, lo < nlo); + let hi_u = nhi + carry; + return vec2(i32(lo), i32(hi_u)); +} + +// ============================================================ +// signed_mul_split: signed 29 × 31-bit product, split into 29-bit chunks. +// +// Pre: |a| <= 2^29 (one BY limb), |b| <= 2^31 - 1 (i32 matrix-entry low). 
+// Post: returns (lo29, hi) with a*b = lo29 + (hi * 2^29), +// 0 <= lo29 < 2^29; |hi| <= 2^31 (fits in i32 except at the single +// corner a == -2^29, b == -2^31 which the caller does not pass). +// +// Partial-product width split (the plan-mandated proof sketch): +// Split each operand as v = v_lo + v_hi * 2^15 where v_lo is sign- +// extended low 15 bits, i.e. v_lo in [-2^14, 2^14). Then v_hi = (v - +// v_lo) >> 15. +// |a_lo|, |b_lo| <= 2^14 +// |a_hi| <= ceil(2^29 / 2^15) = 2^14 +// |b_hi| <= ceil(2^31 / 2^15) = 2^16 +// Four signed partial products: +// pll = a_lo * b_lo, |pll| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓] +// plh = a_lo * b_hi, |plh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓] +// phl = a_hi * b_lo, |phl| <= 2^14 * 2^14 = 2^28 [< 2^31 ✓] +// phh = a_hi * b_hi, |phh| <= 2^14 * 2^16 = 2^30 [< 2^31 ✓] +// Middle sum: +// mid = plh + phl, |mid| <= 2^30 + 2^28 < 2^31 [fits i32] +// Every partial fits in signed i32, satisfying the plan's audit. The +// subsequent combination uses u32 wrap arithmetic to assemble the low +// 64 bits of the signed product without further overflow concerns. +fn signed_mul_split(a: i32, b: i32) -> vec2 { + // Sign-extend the low 15 bits of each operand: (v << 17) >> 17 fills + // bits 15..31 with the sign of bit 14. WGSL i32 shifts are well-defined + // (left shift discards high bits; right shift on i32 is arithmetic). + let a_lo: i32 = (a << 17u) >> 17u; + let a_hi: i32 = (a - a_lo) >> 15u; + let b_lo: i32 = (b << 17u) >> 17u; + let b_hi: i32 = (b - b_lo) >> 15u; + + let pll: i32 = a_lo * b_lo; + let plh: i32 = a_lo * b_hi; + let phl: i32 = a_hi * b_lo; + let phh: i32 = a_hi * b_hi; + let mid: i32 = plh + phl; + + // total = pll + mid * 2^15 + phh * 2^30, computed as the low 64 bits + // of an i64 via u32 wrap arithmetic. + // + // Each i32 piece extended to i64 (lo, hi): + // pll: lo = u32(pll), hi = u32(pll >> 31) (sign mask) + // mid << 15: lo = u32(mid) << 15, hi = u32(mid >> 17) + // phh << 30: lo = u32(phh) << 30, hi = u32(phh >> 2) + // The >> on signed pieces is arithmetic so the high half carries the + // correct sign extension. + let pll_lo: u32 = u32(pll); + let pll_hi: u32 = u32(pll >> 31u); + let mid_lo: u32 = u32(mid) << 15u; + let mid_hi: u32 = u32(mid >> 17u); + let phh_lo: u32 = u32(phh) << 30u; + let phh_hi: u32 = u32(phh >> 2u); + + // Sum the three i64 values with carry chains. + let s1_lo: u32 = pll_lo + mid_lo; + let c1: u32 = select(0u, 1u, s1_lo < pll_lo); + let s1_hi: u32 = pll_hi + mid_hi + c1; + + let sum_lo: u32 = s1_lo + phh_lo; + let c2: u32 = select(0u, 1u, sum_lo < s1_lo); + let sum_hi: u32 = s1_hi + phh_hi + c2; + + // Re-split into lo29 (low 29 bits) and hi = ARS by 29 (signed). + let lo29: u32 = sum_lo & 0x1FFFFFFFu; + // hi bits 0..2 from sum_lo[29..31]; bits 3..31 from sum_hi[0..28]. + // sum_hi[29..31] are sign-extension bits under the precondition + // |a*b| <= 2^60 and are absorbed correctly when bitcasting hi_u to i32. + let hi_u: u32 = (sum_lo >> 29u) | (sum_hi << 3u); + return vec2(i32(lo29), i32(hi_u)); +} + +// by_accumulate: acc + m_lo * x_limb as a 58-bit signed two-limb pair. +// +// Pre: |acc.x| < 2^29 (canonical low half), |acc.y| < 2^29 + 2^k for +// some small k (caller bounded), |m_lo| <= 2^29, |x_limb| <= 2^29. +// Post: returns (lo29, hi) with acc + m_lo * x_limb = lo29 + hi * 2^29, +// 0 <= lo29 < 2^29 (canonical low limb). hi grows by at most one +// carry bit per call. +// +// This is the per-limb cross-product helper used by apply_matrix. 
The +// product |m_lo * x_limb| <= 2^58, well within signed_mul_split's bounds. +fn by_accumulate(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 { + let prod = signed_mul_split(m_lo, x_limb); + // Add (prod.x, prod.y) to (acc.x, acc.y) treating both halves as signed + // 32-bit lanes. acc.x in [0, 2^29), prod.x in [0, 2^29), so the low-half + // sum is in [0, 2^30) — fits i32, no carry to track for this helper. + let lo_sum = acc.x + prod.x; + let lo29 = lo_sum & i32(BY_LIMB_MASK); + // The high half absorbs prod.y plus any bits of lo_sum above bit 28. + // WGSL requires the rhs of `>>` to be u32 (the lhs stays i32 — i32 shift + // is arithmetic, preserving sign). + let lo_overflow = lo_sum >> BY_LIMB_BITS; + let hi = acc.y + prod.y + lo_overflow; + return vec2(lo29, hi); +} + +// by_low_u64_lohi: low 64 bits of a BigIntBY as a vec2 (.x = low 32, +// .y = high 32). Mirrors `Wasm9x29::low_64()` and the TS `low64`: +// bits 0..28 = x.l[0] & LIMB_MASK +// bits 29..57 = x.l[1] & LIMB_MASK +// bits 58..63 = x.l[2] & 0x3F +// All inputs treated as unsigned u32 bit patterns. Negative limbs +// reinterpreted via `u32(...)` — the caller is responsible for ensuring the +// lower three limbs are canonical (in [0, 2^29) etc.), which is the case in +// the BY driver after each by_normalise. +fn by_low_u64_lohi(x: BigIntBY) -> vec2 { + let l0: u32 = u32(x.l[0]) & BY_LIMB_MASK; + let l1: u32 = u32(x.l[1]) & BY_LIMB_MASK; + let l2: u32 = u32(x.l[2]) & 0x3Fu; + // low 32 bits: l0 in bits 0..28, low 3 bits of l1 at 29..31. + let lo32: u32 = l0 | ((l1 & 0x7u) << 29u); + // high 32 bits: high 26 bits of l1 at 0..25, l2 (6 bits) at 26..31. + let hi32: u32 = (l1 >> 3u) | (l2 << 26u); + return vec2(lo32, hi32); +} + +// by_normalise: carry-propagate so each limb i in [0, N-1) ends up in +// [0, 2^29) and the top limb absorbs the final signed carry. +// +// Pre: any signed limb values (caller may pass non-canonical ±2^60). +// Post: x.l[0..N-2] in [0, 2^29); x.l[N-1] is the signed extension. +fn by_normalise(x: ptr) { + var c: i32 = 0; + for (var i: u32 = 0u; i < BY_NUM_LIMBS - 1u; i = i + 1u) { + let v = (*x).l[i] + c; + (*x).l[i] = v & i32(BY_LIMB_MASK); + // Arithmetic shift on i32 propagates sign; matches BigInt(v) >> 29n in TS. + // WGSL requires the rhs of `>>` to be u32 (the lhs i32 then gets the + // arithmetic shift semantics we want). + c = v >> BY_LIMB_BITS; + } + (*x).l[BY_NUM_LIMBS - 1u] = (*x).l[BY_NUM_LIMBS - 1u] + c; +} + +// ============================================================ +// Conversion between 20×13-bit BigInt and 9×29-bit BigIntBY. +// +// The conversion is a bit-rewrite: read 29 source bits at offset +// 29*k into limb k. Each 29-bit window spans up to 4 consecutive +// source limbs (worst case bit-offset 12, covers 27 bits in 3 limbs; +// needs 2 more from a 4th). +// +// Pre (by_from_bigint): x in [0, 2^260) represented in 20 × 13-bit limbs. +// Post: result in [0, 2^256), 9 canonical 29-bit limbs. +// Pre (by_to_bigint): x canonical (lower limbs in [0, 2^29), top non-neg). +// Post: result in [0, 2^260), 20 × 13-bit unsigned limbs. +// ============================================================ + +const BY_SRC_LIMB_BITS: u32 = 13u; +const BY_SRC_LIMB_MASK: u32 = (1u << 13u) - 1u; + +// Read a 29-bit window from the 20×13-bit BigInt starting at bit `bit_lo`. +// The window spans up to 4 consecutive source limbs (ceil((29 + 12) / 13) = 4). 
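The window-read helpers below implement this bit-rewrite in WGSL; a host-side TypeScript sketch (hypothetical, not part of the patch) of the same re-chunking is useful for generating round-trip test vectors:

```ts
// Hypothetical host-side model of the limb-width conversion: a value is
// just re-chunked into windows of a different bit width, so converting
// 20x13-bit -> 9x29-bit -> 20x13-bit must be the identity for x < 2^256.
function toLimbs(x: bigint, limbBits: number, numLimbs: number): number[] {
  const mask = (1n << BigInt(limbBits)) - 1n;
  const limbs: number[] = [];
  for (let i = 0; i < numLimbs; i++) limbs.push(Number((x >> BigInt(i * limbBits)) & mask));
  return limbs;
}

function fromLimbs(limbs: number[], limbBits: number): bigint {
  return limbs.reduce((acc, l, i) => acc | (BigInt(l) << BigInt(i * limbBits)), 0n);
}

// Round-trip check on a pseudo-random 254-bit value.
let x = 0n;
for (let i = 0; i < 8; i++) x = (x << 32n) | BigInt(Math.floor(Math.random() * 2 ** 32));
x &= (1n << 254n) - 1n;

const limbs13 = toLimbs(x, 13, 20); // host-side layout of a BigInt
const limbs29 = toLimbs(x, 29, 9);  // host-side layout of a BigIntBY
if (fromLimbs(limbs13, 13) !== x || fromLimbs(limbs29, 29) !== x) {
  throw new Error('limb conversion round trip failed');
}
```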
+fn by_read_29bit_window(x: BigInt, bit_lo: u32) -> u32 {
+    let base_idx = bit_lo / BY_SRC_LIMB_BITS;    // first source limb touched
+    let in_limb_bit = bit_lo % BY_SRC_LIMB_BITS; // bit offset within base limb
+    // Up to four consecutive limbs (out-of-range reads return 0).
+    var v0: u32 = 0u;
+    var v1: u32 = 0u;
+    var v2: u32 = 0u;
+    var v3: u32 = 0u;
+    if (base_idx < {{ num_words }}u) { v0 = x.limbs[base_idx]; }
+    if (base_idx + 1u < {{ num_words }}u) { v1 = x.limbs[base_idx + 1u]; }
+    if (base_idx + 2u < {{ num_words }}u) { v2 = x.limbs[base_idx + 2u]; }
+    if (base_idx + 3u < {{ num_words }}u) { v3 = x.limbs[base_idx + 3u]; }
+    // Stitch: contributions are v0 >> in_limb_bit, v1..v3 at increasing offsets.
+    // The v3 shift amount can reach 13*3 - 0 = 39, which exceeds 32. WGSL does
+    // not define an over-long shift as zero: a non-constant shift amount of 32
+    // or more is either masked to its low five bits or rejected, so v3 << 39
+    // cannot be relied on to vanish. We only genuinely need v3 when
+    // in_limb_bit > 10, where the shift is 39 - in_limb_bit ∈ [27, 28]; for
+    // in_limb_bit < 11 the v3 contribution is already covered by v0..v2.
+    let s0 = v0 >> in_limb_bit;
+    let s1 = v1 << (BY_SRC_LIMB_BITS - in_limb_bit);
+    let s2 = v2 << ((BY_SRC_LIMB_BITS * 2u) - in_limb_bit);
+    let shift3: u32 = (BY_SRC_LIMB_BITS * 3u) - in_limb_bit;
+    // shift3 is in [27, 39]. The shift3 < 32 guard keeps the expression
+    // well-defined; whenever shift3 >= 32 the v3 contribution is not needed
+    // anyway, and for 29 <= shift3 < 32 the spurious bits land above bit 28
+    // and are removed by the final mask.
+    var s3: u32 = 0u;
+    if (shift3 < 32u) {
+        s3 = v3 << shift3;
+    }
+    return (s0 | s1 | s2 | s3) & BY_LIMB_MASK;
+}
+
+fn by_from_bigint(x: BigInt) -> BigIntBY {
+    var r: BigIntBY;
+    for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) {
+        r.l[i] = i32(by_read_29bit_window(x, i * BY_LIMB_BITS));
+    }
+    return r;
+}
+
+// Read a 13-bit window from the 9×29-bit BigIntBY at bit `bit_lo`; used by
+// by_to_bigint to write each 13-bit limb of the 20×13-bit BigInt. Reads up
+// to two consecutive 29-bit BY limbs and emits the 13-bit slice.
+fn by_read_13bit_window(x: BigIntBY, bit_lo: u32) -> u32 {
+    let base_idx = bit_lo / BY_LIMB_BITS;
+    let in_limb_bit = bit_lo % BY_LIMB_BITS;
+    // Up to two BY limbs cover any 13-bit window (since 29 > 13).
+    var v0: u32 = 0u;
+    var v1: u32 = 0u;
+    if (base_idx < BY_NUM_LIMBS) { v0 = u32(x.l[base_idx]); }
+    if (base_idx + 1u < BY_NUM_LIMBS) { v1 = u32(x.l[base_idx + 1u]); }
+    let s0 = v0 >> in_limb_bit;
+    // Shift amount for v1 is (29 - in_limb_bit), which is in [1, 29] and so
+    // always < 32; no shift-width guard is needed here.
+    let shift_up = BY_LIMB_BITS - in_limb_bit;
+    let s1 = v1 << shift_up;
+    return (s0 | s1) & BY_SRC_LIMB_MASK;
+}
+
+fn by_to_bigint(x: BigIntBY) -> BigInt {
+    var out: BigInt;
+    for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+        out.limbs[i] = by_read_13bit_window(x, i * BY_SRC_LIMB_BITS);
+    }
+    return out;
+}
diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/apply_matrix_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/apply_matrix_bench.template.wgsl
new file mode 100644
index 000000000000..cfba8f0416cc
--- /dev/null
+++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/apply_matrix_bench.template.wgsl
@@ -0,0 +1,86 @@
+// Single-thread-per-input bench shader for `by_apply_matrix_fg` /
+// `by_apply_matrix_de`. Each thread runs both passes on one input record
+// and writes the resulting 4 × 9 limbs (f', g', d', e') as 36 i32 outputs.
+//
+// INPUT LAYOUT (44 u32 per thread):
+// Offset 0..7: Mat fields { u, v, q, r, u_hi, v_hi, q_hi, r_hi } as i32.
+// Offset 8..16: f limbs (9 × i32) +// Offset 17..25: g limbs (9 × i32) +// Offset 26..34: d limbs (9 × i32) +// Offset 35..43: e limbs (9 × i32) +// All limbs are signed-canonical (lower limbs in [0, 2^29), top limb signed). +// +// OUTPUT LAYOUT (36 i32 per thread): +// Offset 0..8: f' limbs (9 × i32) — output of by_apply_matrix_fg +// Offset 9..17: g' limbs (9 × i32) — output of by_apply_matrix_fg +// Offset 18..26: d' limbs (9 × i32) — output of by_apply_matrix_de +// Offset 27..35: e' limbs (9 × i32) — output of by_apply_matrix_de +// +// CONSTANTS (Mustache injected): +// p_limbs_by: BigIntBY initializer for the BN254 base-field modulus p. +// p_inv_by_lo: low 32 bits of P_INV = p^(-1) mod 2^58. +// p_inv_by_hi: high 32 bits of P_INV (only the low 26 bits are non-zero). +// +// LOOP BOUNDS: only inherited loops — by_apply_matrix_fg and +// by_apply_matrix_de each use a single `for` bounded by BY_NUM_LIMBS +// (= 9u const). The dispatch entry shader itself has no loops. +// +// SAFETY: `if (tid >= n) return;` static guard, no data-dependent loops. + +@group(0) @binding(0) var inputs: array; +@group(0) @binding(1) var outputs: array; +@group(0) @binding(2) var params: vec2; + +const APPLY_MATRIX_INPUT_STRIDE: u32 = 44u; // 8 (Mat) + 4 * 9 (limbs) +const APPLY_MATRIX_OUTPUT_STRIDE: u32 = 36u; // 4 * 9 (limbs) +const P_INV_BY_LO: u32 = {{ p_inv_by_lo }}u; +const P_INV_BY_HI: u32 = {{ p_inv_by_hi }}u; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let tid = gid.x; + if (tid >= n) { return; } + + let in_base: u32 = tid * APPLY_MATRIX_INPUT_STRIDE; + let out_base: u32 = tid * APPLY_MATRIX_OUTPUT_STRIDE; + + // Decode Mat. The 8 i32 fields land in bindings as u32 — bitcast. + let m: Mat = Mat( + bitcast(inputs[in_base + 0u]), + bitcast(inputs[in_base + 1u]), + bitcast(inputs[in_base + 2u]), + bitcast(inputs[in_base + 3u]), + bitcast(inputs[in_base + 4u]), + bitcast(inputs[in_base + 5u]), + bitcast(inputs[in_base + 6u]), + bitcast(inputs[in_base + 7u]), + ); + + // Decode f, g, d, e as BigIntBY (9 × i32 each). + var f: BigIntBY; + var g: BigIntBY; + var d: BigIntBY; + var e: BigIntBY; + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + f.l[i] = bitcast(inputs[in_base + 8u + i]); + g.l[i] = bitcast(inputs[in_base + 17u + i]); + d.l[i] = bitcast(inputs[in_base + 26u + i]); + e.l[i] = bitcast(inputs[in_base + 35u + i]); + } + + // Modulus p, fixed across all threads — Mustache-injected. + var p: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + + // Apply matrix passes. + by_apply_matrix_fg(m, &f, &g); + by_apply_matrix_de(m, &d, &e, &p, P_INV_BY_LO, P_INV_BY_HI); + + // Write outputs. + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + outputs[out_base + 0u + i] = f.l[i]; + outputs[out_base + 9u + i] = g.l[i]; + outputs[out_base + 18u + i] = d.l[i]; + outputs[out_base + 27u + i] = e.l[i]; + } +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_dispatch_args.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_dispatch_args.template.wgsl index 2bb9c655e876..44640c667de4 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_dispatch_args.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_dispatch_args.template.wgsl @@ -60,6 +60,11 @@ var round_count: array>; @group(0) @binding(2) var dispatch_args: array; +// WINDOWS_PER_BATCH — baked at render time. Controls the inverse-pass +// Z dispatch dim: num_batches = ceil(num_subtasks / WPB). 
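For completeness, the 44-u32 apply_matrix_bench input record described above could be packed on the host along these lines (hypothetical TypeScript helper, not part of the patch):

```ts
// Hypothetical packer for one apply_matrix_bench input record: 8 Mat fields
// followed by the f, g, d, e limbs (9 signed limbs each), bitcast to u32
// words for the storage buffer. Offsets match the layout comment above.
interface MatFields { u: number; v: number; q: number; r: number; uHi: number; vHi: number; qHi: number; rHi: number }

function packApplyMatrixRecord(
  m: MatFields,
  f: Int32Array, g: Int32Array, d: Int32Array, e: Int32Array,
): Uint32Array {
  if ([f, g, d, e].some((limbs) => limbs.length !== 9)) throw new Error('expected 9 limbs per operand');
  const record = new Int32Array(44);
  record.set([m.u, m.v, m.q, m.r, m.uHi, m.vHi, m.qHi, m.rHi], 0); // offsets 0..7
  record.set(f, 8);  // offsets 8..16
  record.set(g, 17); // offsets 17..25
  record.set(d, 26); // offsets 26..34
  record.set(e, 35); // offsets 35..43
  // Same bytes, reinterpreted as u32 words (the shader bitcasts them back).
  return new Uint32Array(record.buffer);
}
```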
See +// batch_inverse_parallel.template.wgsl for the merged-pool inverse. +const WPB: u32 = {{ windows_per_batch }}u; + // params[0] = num_subtasks // params[1] = apply_workgroup_size // params[2] = sched_x_groups (= ceil(num_columns / schedule_workgroup_size)) @@ -74,6 +79,7 @@ fn main() { let apply_wg_size = params[1]; let sched_x_groups = params[2]; let num_sub_wgs = params[3]; + let num_batches = (num_subtasks + WPB - 1u) / WPB; var max_count: u32 = 0u; for (var i: u32 = 0u; i < num_subtasks; i = i + 1u) { @@ -98,13 +104,14 @@ fn main() { dispatch_args[1] = 1u; dispatch_args[2] = num_subtasks; - // THIS round's inverse: (W, 1, T) workgroups — W sub-WGs per subtask - // splitting each subtask's pair pool into W contiguous slices, each - // independently inverted with its own fr_inv. Drops Phase A/D - // per-thread sequential cost by W. See batch_inverse_parallel for - // the algorithm. + // THIS round's inverse: (W, 1, num_batches) workgroups — W sub-WGs + // per (batch, sub_wg) splitting each batch's MERGED pair pool (of + // size sum_{w} round_count[batch * WPB + w]) into W contiguous + // slices, each independently inverted with its own fr_inv. Pooling + // amortises one fr_inv across WPB subtasks. Z dim drops from T to + // ceil(T/WPB) accordingly. let inverse_x = select(0u, num_sub_wgs, any_work); - let inverse_z = select(0u, num_subtasks, any_work); + let inverse_z = select(0u, num_batches, any_work); dispatch_args[3] = inverse_x; dispatch_args[4] = 1u; dispatch_args[5] = inverse_z; diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl index 3102c494d4db..4ea97a102d4c 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl @@ -3,6 +3,8 @@ {{> montgomery_product_funcs }} {{> field_funcs }} {{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} // get_r returns Montgomery R (= the integer 1 in Montgomery form). Each // main shader defines its own with `r_limbs` substitution; the @@ -74,7 +76,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } // Step 2: invert the final product. - var inv_acc: BigInt = fr_inv(acc); + var inv_acc: BigInt = fr_inv_by_a(acc); // Step 3: walk back, emitting individual inverses. for (var idx = 0u; idx < n - 1u; idx = idx + 1u) { diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl index a76132e75733..64ffbffd7451 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl @@ -3,6 +3,8 @@ {{> montgomery_product_funcs }} {{> field_funcs }} {{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} // Parallel Montgomery batch-inverse on the GPU. // @@ -13,74 +15,65 @@ // workgroup of TPB threads, dropping the per-dispatch cost from ~30 ms // to ~1-3 ms at the realistic round sizes we see. // -// MULTI-WORKGROUP MODE. The dispatch is shaped (NUM_SUB_WGS, 1, T). -// `wid.z` selects the subtask (count_buf[subtask]) and `wid.x` is the -// sub-workgroup index inside that subtask (0..NUM_SUB_WGS-1). Each -// sub-workgroup independently inverts a contiguous slice of length -// per_sub_chunk = ceil(n / NUM_SUB_WGS), using its own fr_inv. Reasoning: +// MULTI-WORKGROUP + WINDOWS-PER-BATCH MODE. 
The dispatch is shaped +// (NUM_SUB_WGS, 1, num_batches), where num_batches = ceil(T / WPB). +// `wid.z` selects the batch; `wid.x` is the sub-workgroup index inside +// the batch. Each (wid.z, wid.x) pair inverts the COMBINED pair pool of +// WPB consecutive subtasks, treating it as one big logical pool of total +// length sum_{w=0..WPB-1} count_buf[batch * WPB + w], with one +// fr_inv_by_a call per (batch, sub_wg). // -// - Phase A (per-thread fwd) and Phase D (per-thread back-walk) are -// sequential per thread with bs = ceil(n / TPB). Splitting each -// subtask across W sub-workgroups drops bs by W, giving up to W× -// speedup on the dominant per-round latency at large N. +// Amortisation: with WPB=4 we run T/WPB × W fr_inv_by_a calls per round +// instead of T × W — a 4× drop in the dominant Phase C cost. The +// per-subtask layout of the pair pool (and the `prefix` / `outputs` +// scratch buffers) is unchanged, so `batch_affine_apply_scatter` reads +// outputs[subtask_global * pitch + slot_in_subtask] just like before. +// The kernel "sews together" the gaps between subtask slices in the +// scan: a logical position p in the merged pool decodes to +// (subtask_in_batch, slot_in_subtask) via the per-batch exclusive +// prefix-sum `wg_cum`, then to a physical buffer position +// (batch * WPB + subtask_in_batch) * pitch + slot_in_subtask. // -// - Each sub-workgroup runs its own fr_inv. fr_invs across sub-WGs -// run concurrently on different SMs, so the EXTRA fr_invs cost zero -// wall time (gated only by SM occupancy). With T=16 subtasks × W=8 -// sub-WGs = 128 workgroups in flight, RTX-class GPUs are fully -// occupied during the inverse pass. +// Per-thread walks advance a (cur_w, cur_off) cursor as they step +// through their chunk so the per-step decode is O(1); the only WPB-loop +// is the const-bounded prefix-sum setup at the top of the kernel. // // Two clients today: -// - SMVP (cross-subtask): pitch = num_columns, count_buf[wid.z] = -// pair_counter[wid.z] (per-subtask atomic). -// - Finalize: pitch = half_num_columns, count_buf[wid.z] = -// half_num_columns for all wid.z (pre-populated by host). +// - SMVP (cross-subtask): pitch = num_columns, count_buf[g] = +// round_count[g] (per-subtask, forwarded by dispatch_args). +// - Finalize: pitch = half_num_columns, count_buf[g] = +// half_num_columns for all g (pre-populated by host). With WPB=1 in +// the finalize call this degenerates to the previous behaviour. // -// Algorithm (Montgomery's batch-inverse trick, two-level): +// Algorithm (Montgomery's batch-inverse trick, two-level, merged pool): // // 1. Each thread i computes block_inclusive_prefix[k] for k in its // chunk into the `prefix` scratch buffer (serial, length bs = -// ceil(n / TPB)). Captures block_total[i] in a register. +// ceil(n / TPB)). Captures block_total[i] in a register. The walk +// is over LOGICAL positions in the merged pool, but each write +// lands at the corresponding PHYSICAL buffer position (decoded via +// wg_cum). // 2. All TPB block_totals are pushed into workgroup memory. // 3. Two parallel inclusive scans over wg memory: forward (wg_fwd) // and backward (wg_bwd). Hillis-Steele, log2(TPB) passes. -// After scan: -// wg_fwd[i] = block_total[0] * ... * block_total[i] -// wg_bwd[i] = block_total[i] * ... * block_total[TPB-1] -// So global_total = wg_fwd[TPB-1] = wg_bwd[0]. -// 4. Thread 0 computes inv_global = inv(global_total) — single fr_inv. -// Broadcast via wg_inv_total. -// 5. Each thread walks back through its chunk. 
The key trick: within -// a single block, block_excl_prefix cancels algebraically, so the -// back-walk reduces to the standard 2-mul/element batch inverse -// on the per-block prefix array. -// -// inv(global_prefix[k]) · global_prefix[k-1] -// = inv(block_excl_prefix · P_in[k]) · (block_excl_prefix · P_in[k-1]) -// = inv(P_in[k]) · P_in[k-1] // block_excl_prefix cancels -// -// So: -// block_excl_prefix[i] = wg_fwd[i-1] (or R for i=0) -// block_excl_suffix[i] = wg_bwd[i+1] (or R for i=TPB-1) -// // Setup: inv(block_total) = inv_global · block_excl_prefix · block_excl_suffix -// inv_acc = inv(block_total[i]) // 2 muls (setup, one-time) -// For k from chunk_end-1 down to chunk_start: -// out[k] = inv_acc · (k>chunk_start ? prefix[k-1] : R) // 1 mul (or 0) -// inv_acc = inv_acc · inputs[k] // 1 mul -// -// Cost: 2N muls in back-walk + N muls in forward + 2 muls setup -// = 3N + O(1) per workgroup. Previously 3N + N back-walk muls -// (extra mul-by-block_excl_prefix per element) = 4N total. +// 4. Thread 0 computes inv_global = inv(global_total) — single +// fr_inv_by_a per (batch, sub_wg). Broadcast via wg_inv_total. +// 5. Each thread walks back through its chunk. Within a single block, +// block_excl_prefix cancels algebraically, so the back-walk runs +// as the standard 2-mul/element batch inverse on the per-block +// prefix array (same as before). Outputs land at physical buffer +// positions matching the input layout. // // TPB = 64 keeps workgroup memory at 2 * 64 * sizeof(BigInt) = 10240 // bytes (BN254: BigInt = 80 bytes), comfortably under the WebGPU // default maxComputeWorkgroupStorageSize = 16384. // -// When n == 0 for this workgroup the kernel returns immediately at the -// top. +// When total_n == 0 for this (batch, sub_wg) the kernel returns +// immediately at the top. const TPB: u32 = 64u; const NUM_SUB_WGS: u32 = {{ num_sub_wgs }}u; +const WPB: u32 = {{ windows_per_batch }}u; @group(0) @binding(0) var inputs: array; @@ -94,7 +87,7 @@ var outputs: array; @group(0) @binding(3) var count_buf: array>; -// params[0] = pitch (per-workgroup slice stride) +// params[0] = pitch (per-subtask slice stride inside inputs/prefix/outputs) // params[1..3] = unused @group(0) @binding(4) var params: vec4; @@ -108,19 +101,44 @@ fn get_r() -> BigInt { var wg_fwd: array; var wg_bwd: array; var wg_inv_total: BigInt; -// Broadcast slot for the pair count. atomicLoad returns a non-uniform -// value as far as the WGSL uniformity analysis is concerned, so the -// downstream `workgroupBarrier()`s would be ill-formed if we branched -// directly on it. Funnelling it through a workgroup variable + an -// explicit `workgroupUniformLoad` re-uniforms the value (and acts as -// an implicit barrier). -var wg_n: u32; -// Per-sub-WG element offset into the subtask's slice. Set by tid 0 -// alongside wg_n; broadcast through workgroup memory so every thread -// agrees on the slice base. (workgroupUniformLoad over wg_n implicitly -// synchronises this write too, since both are written before the load.) +// Per-batch exclusive prefix of subtask counts: wg_cum[w] is the +// cumulative count of subtasks 0..w-1 within this batch (so wg_cum[0] +// is always 0 and the merged-pool total length is +// wg_cum[WPB-1] + wg_counts[WPB-1]). Used to decode "logical position +// p in the merged pool" → "(subtask_in_batch w, slot s in that +// subtask)" via the largest w with wg_cum[w] <= p. +var wg_counts: array; +var wg_cum: array; +// Broadcast slot for the merged-pool total count. 
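A host-side TypeScript model of this decode (hypothetical, not part of the patch; it mirrors the decode_pool_pos helper defined just below and the physical-index formula described above) is handy when validating merged-pool outputs:

```ts
// Hypothetical model of the merged-pool decode: wgCum is the exclusive
// prefix of the WPB per-subtask counts in one batch; a logical position p
// maps to (subtask_in_batch w, slot), and then to a physical buffer index
// (batch * WPB + w) * pitch + slot, exactly as the kernel computes it.
function decodePoolPos(wgCum: number[], p: number): { w: number; slot: number } {
  let w = 0, slot = p;
  for (let i = 1; i < wgCum.length; i++) {
    if (p >= wgCum[i]) { w = i; slot = p - wgCum[i]; } // largest i with wgCum[i] <= p
  }
  return { w, slot };
}

function physicalIndex(batch: number, wpb: number, pitch: number, wgCum: number[], p: number): number {
  const { w, slot } = decodePoolPos(wgCum, p);
  return (batch * wpb + w) * pitch + slot;
}

// Example: WPB = 4, per-subtask counts [5, 0, 3, 7] -> wgCum = [0, 5, 5, 8].
// Logical position 6 lands in subtask 2 at slot 1; with pitch 16 and batch 0
// that is physical index (0 * 4 + 2) * 16 + 1 = 33. Empty subtasks are
// skipped naturally because their cumulative value repeats.
const counts = [5, 0, 3, 7];
const wgCum = counts.map((_, i) => counts.slice(0, i).reduce((a, c) => a + c, 0));
console.log(decodePoolPos(wgCum, 6), physicalIndex(0, 4, 16, wgCum, 6)); // { w: 2, slot: 1 } 33
```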
atomicLoad returns a +// non-uniform value as far as the WGSL uniformity analysis is +// concerned, so the downstream `workgroupBarrier()`s would be +// ill-formed if we branched directly on it. Funnelling it through a +// workgroup variable + an explicit `workgroupUniformLoad` re-uniforms +// the value (and acts as an implicit barrier). +var wg_total_n: u32; +// Per-sub-WG element offset into the merged pool. Set by tid 0 +// alongside wg_total_n; broadcast through workgroup memory so every +// thread agrees on the slice base. (workgroupUniformLoad over +// wg_total_n implicitly synchronises this write too, since both are +// written before the load.) var wg_sub_offset: u32; +// Decode a logical position p in the merged batch pool into +// (subtask_in_batch, slot_in_subtask). The WPB-loop is const-bounded. +struct PoolPos { w: u32, slot: u32 }; +fn decode_pool_pos(p: u32) -> PoolPos { + var out: PoolPos; + out.w = 0u; + out.slot = p; + for (var i = 1u; i < WPB; i = i + 1u) { + if (p >= wg_cum[i]) { + out.w = i; + out.slot = p - wg_cum[i]; + } + } + return out; +} + @compute @workgroup_size(64) fn main( @@ -129,15 +147,27 @@ fn main( ) { let tid = lid.x; let pitch = params[0]; - let subtask_idx = wid.z; + let batch_idx = wid.z; let sub_idx = wid.x; + let subtask_base = batch_idx * WPB; - // Thread 0 reads the subtask's atomic count and the sub-WG's slice - // bounds; broadcast via wg_n. workgroupUniformLoad ensures all - // threads see a uniform `n` (= this sub-WG's element count) and - // synchronises the workgroup before continuing. + // Thread 0 reads the WPB per-subtask atomic counts in this batch, + // computes the exclusive prefix-sum into wg_cum + wg_counts, and + // resolves the sub-WG's slice bounds inside the merged pool. + // Broadcasts via wg_total_n + wg_sub_offset. workgroupUniformLoad + // over wg_total_n then re-uniforms across the workgroup. if (tid == 0u) { - let total_n = atomicLoad(&count_buf[subtask_idx]); + var cum: u32 = 0u; + for (var w = 0u; w < WPB; w = w + 1u) { + let g = subtask_base + w; + // count_buf is sized to a multiple of WPB so the tail loads + // here return the host's initial zero (no out-of-bounds). + let c = atomicLoad(&count_buf[g]); + wg_counts[w] = c; + wg_cum[w] = cum; + cum = cum + c; + } + let total_n = cum; // Split [0, total_n) into NUM_SUB_WGS contiguous chunks. Chunks // are sized ceil(total_n / NUM_SUB_WGS); trailing sub-WGs may // see a shorter (or empty) chunk if total_n isn't a multiple of @@ -152,39 +182,70 @@ fn main( sub_n = clamped_end - raw_start; } wg_sub_offset = raw_start; - wg_n = sub_n; + wg_total_n = sub_n; } - let n = workgroupUniformLoad(&wg_n); + let n = workgroupUniformLoad(&wg_total_n); if (n == 0u) { return; } - let sub_offset_in_subtask = wg_sub_offset; - - // Subtask owns a slice of `pitch` elements at offset - // subtask_idx * pitch. This sub-WG owns - // [subtask_offset + sub_offset_in_subtask, - // subtask_offset + sub_offset_in_subtask + n). - let slice_offset = subtask_idx * pitch + sub_offset_in_subtask; + let sub_offset_in_batch = wg_sub_offset; // Block size = ceil(n / TPB). Last thread may have a shorter chunk. let bs = (n + TPB - 1u) / TPB; - let chunk_start = tid * bs; - var chunk_end = chunk_start + bs; - if (chunk_end > n) { - chunk_end = n; + let chunk_start_in_sub = tid * bs; + var chunk_end_in_sub = chunk_start_in_sub + bs; + if (chunk_end_in_sub > n) { + chunk_end_in_sub = n; } + // Translate this thread's [chunk_start, chunk_end) window from + // sub-WG-local indexing into batch-merged-pool indexing. 
+ let chunk_start = sub_offset_in_batch + chunk_start_in_sub; + let chunk_end = sub_offset_in_batch + chunk_end_in_sub; + // Phase A: per-thread inclusive prefix product over the chunk. - // block_total = product of all elements in this thread's chunk, - // or R (Montgomery 1) if the chunk is empty. + // Each k iterates over LOGICAL positions in the merged pool but + // every read/write lands at the PHYSICAL buffer position decoded + // through wg_cum. block_total = product of all elements in this + // thread's chunk, or R (Montgomery 1) if the chunk is empty. var block_total: BigInt = get_r(); - if (chunk_start < n) { - var acc: BigInt = inputs[slice_offset + chunk_start]; - prefix[slice_offset + chunk_start] = acc; + if (chunk_start_in_sub < n) { + // Decode the starting logical position once, then carry + // (cur_w, cur_off) forward through the chunk so the per-step + // decode is O(1). cur_off is the slot within the current + // subtask; when it reaches wg_counts[cur_w] we advance to the + // next subtask. + let start_pos = decode_pool_pos(chunk_start); + var cur_w: u32 = start_pos.w; + var cur_off: u32 = start_pos.slot; + let g0 = subtask_base + cur_w; + let phys0 = g0 * pitch + cur_off; + var acc: BigInt = inputs[phys0]; + prefix[phys0] = acc; + // Advance one step past the seed (k = chunk_start) and walk to + // chunk_end. The k variable here is the BATCH-MERGED logical + // position; advance cur_off in lockstep, rolling cur_w + cur_off + // across the wg_counts[cur_w] boundary when the subtask runs out. for (var k = chunk_start + 1u; k < chunk_end; k = k + 1u) { - var x: BigInt = inputs[slice_offset + k]; + cur_off = cur_off + 1u; + // Bounded rollover: cross at most one boundary per step + // (counts never exceed pitch, which exceeds bs in practice). + // The const-bounded loop guards against any degenerate + // input where wg_counts[cur_w] could be zero — we still + // can't iterate past WPB steps because cur_w is in [0, WPB). + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (cur_off >= wg_counts[cur_w]) { + cur_w = cur_w + 1u; + cur_off = 0u; + } else { + break; + } + } + let g = subtask_base + cur_w; + let phys = g * pitch + cur_off; + var x: BigInt = inputs[phys]; acc = montgomery_product(&acc, &x); - prefix[slice_offset + k] = acc; + prefix[phys] = acc; } block_total = acc; } @@ -214,14 +275,16 @@ fn main( } // Phase C: thread 0 inverts the global total. Broadcast via workgroup mem. + // One fr_inv_by_a per (batch, sub_wg), independent of WPB — that's + // the amortisation this layout buys. if (tid == 0u) { var global_total: BigInt = wg_fwd[TPB - 1u]; - wg_inv_total = fr_inv(global_total); + wg_inv_total = fr_inv_by_a(global_total); } workgroupBarrier(); // Phase D: walk back through this thread's chunk, emitting inverses. - if (chunk_start >= n) { + if (chunk_start_in_sub >= n) { return; } @@ -246,6 +309,13 @@ fn main( var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + // Decode the END position (chunk_end - 1) and walk backward, + // mirroring Phase A's forward decode. Maintain (cur_w, cur_off) as + // we step k from chunk_end-1 down to chunk_start. + let end_pos = decode_pool_pos(chunk_end - 1u); + var cur_w: u32 = end_pos.w; + var cur_off: u32 = end_pos.slot; + // Walk from k = chunk_end-1 down to chunk_start. 
Within a block, // block_excl_prefix cancels algebraically (see header comment), so // we run the standard backward batch-inverse over the per-block @@ -254,23 +324,74 @@ fn main( while (k > chunk_start) { k = k - 1u; - // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself for k = chunk_start) + let g = subtask_base + cur_w; + let phys = g * pitch + cur_off; + + // out[k] = inv_acc * prefix_in_block[k-1] (or inv_acc itself + // for k = chunk_start). The "k-1" position in the BLOCK is the + // previous logical position in the merged pool; if that lies + // in a different subtask, prefix is still indexed by physical + // position so we decode the neighbour separately. var inv_a_k: BigInt; if (k > chunk_start) { - var prev_in_block: BigInt = prefix[slice_offset + k - 1u]; + // Compute the physical position of (k - 1). This is the + // logical predecessor of the current step. We carry a + // dedicated (prev_w, prev_off) cursor by mirroring the + // forward-from-decode but starting from (cur_w, cur_off): + // step back by one slot; if cur_off was 0, roll back to + // the previous subtask's last slot. + var prev_w: u32 = cur_w; + var prev_off: u32; + if (cur_off > 0u) { + prev_off = cur_off - 1u; + } else { + // Roll back to the largest preceding subtask with a + // non-empty count. Const-bounded by WPB. + prev_off = 0u; + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (prev_w == 0u) { + // Should not happen — we'd be past chunk_start. + break; + } + prev_w = prev_w - 1u; + if (wg_counts[prev_w] > 0u) { + prev_off = wg_counts[prev_w] - 1u; + break; + } + } + } + let g_prev = subtask_base + prev_w; + let phys_prev = g_prev * pitch + prev_off; + var prev_in_block: BigInt = prefix[phys_prev]; inv_a_k = montgomery_product(&inv_acc, &prev_in_block); } else { inv_a_k = inv_acc; } - outputs[slice_offset + k] = inv_a_k; + outputs[phys] = inv_a_k; // Update: inv_acc <- inv_acc * a[k] = inv(prefix_in_block[k-1]) // for the next iteration. The update on the last iteration // (k = chunk_start) is wasted — minor cost, kept for code // simplicity; the loop exit condition skips it naturally. if (k > chunk_start) { - var a_k: BigInt = inputs[slice_offset + k]; + var a_k: BigInt = inputs[phys]; inv_acc = montgomery_product(&inv_acc, &a_k); + // Step (cur_w, cur_off) back by one logical position, same + // rollover logic as the prev_* decode above. + if (cur_off > 0u) { + cur_off = cur_off - 1u; + } else { + for (var hop = 0u; hop < WPB; hop = hop + 1u) { + if (cur_w == 0u) { + break; + } + cur_w = cur_w - 1u; + if (wg_counts[cur_w] > 0u) { + cur_off = wg_counts[cur_w] - 1u; + break; + } + } + } } } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bench_batch_affine.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bench_batch_affine.template.wgsl new file mode 100644 index 000000000000..a75a101dc698 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bench_batch_affine.template.wgsl @@ -0,0 +1,212 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> by_inverse_a_funcs }} + +// Standalone single-dispatch micro-benchmark for batch-affine EC addition. +// Measures the per-pair cost of `R_i = P_i + Q_i` using Montgomery's +// batch-inverse trick at a fixed BATCH_SIZE, swept across a range of sizes +// by the host to find the sweet spot for amortising the single +// `fr_inv_by_a` call per batch. +// +// Single-workgroup-per-batch layout. 
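The Montgomery batch-inverse trick that both this bench and batch_inverse_parallel rely on reduces to a few lines on the host. A TypeScript model (hypothetical, not part of the patch; plain residues mod p stand in for Montgomery form) is:

```ts
// Hypothetical host-side model of Montgomery's batch-inverse trick: one
// modular inversion for N elements plus roughly 3N multiplications
// (forward prefix pass, then a 2-mul/element back-walk).
const P = 21888242871839275222246405745257275088696311157297823662689037894645226208583n; // assumed BN254 base-field modulus

function modInv(a: bigint): bigint {
  // Fermat inverse a^(p-2) mod p; fine for a reference model since p is prime.
  let result = 1n, base = a % P, e = P - 2n;
  while (e > 0n) {
    if (e & 1n) result = (result * base) % P;
    base = (base * base) % P;
    e >>= 1n;
  }
  return result;
}

function batchInverse(inputs: bigint[]): bigint[] {
  const n = inputs.length;
  const prefix = new Array<bigint>(n);
  let acc = 1n;
  for (let k = 0; k < n; k++) { acc = (acc * inputs[k]) % P; prefix[k] = acc; } // forward pass
  let invAcc = modInv(acc);                                                     // the single inversion
  const out = new Array<bigint>(n);
  for (let k = n - 1; k >= 0; k--) {                                            // back-walk
    out[k] = k > 0 ? (invAcc * prefix[k - 1]) % P : invAcc;
    invAcc = (invAcc * inputs[k]) % P;
  }
  return out;
}

// Spot check: x * inv(x) == 1 for every element.
const xs = [3n, 5n, 7n, 11n];
const inv = batchInverse(xs);
xs.forEach((x, i) => { if ((x * inv[i]) % P !== 1n) throw new Error('batch inverse mismatch'); });
```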
Each workgroup of TPB threads +// processes one batch of BATCH_SIZE consecutive pairs. The host dispatches +// (TOTAL_PAIRS / BATCH_SIZE, 1, 1) workgroups. +// +// Phase A: each thread serially walks its BS = BATCH_SIZE / TPB pairs, +// computing delta_x_k = Q_k.x - P_k.x and accumulating the running +// product `prefix[k] = prefix[k-1] * delta_x_k` for k in +// [chunk_start, chunk_end). Captures `block_total` = product of +// all chunk deltas. +// +// Phase B: workgroup-shared Hillis-Steele scan (forward + backward) over +// the TPB block_totals. +// +// Phase C: thread 0 runs the SINGLE `fr_inv_by_a` call on the overall +// product and broadcasts the result via workgroup memory. +// +// Phase D: each thread back-walks its chunk emitting one affine add per +// step. Uses standard 2-mul/element backward batch-inverse: +// `inv_acc = inv(block_total)` at the start, then for each k from +// chunk_end-1 down to chunk_start: +// inv_dx = (k > chunk_start) ? inv_acc * prefix[k-1] : inv_acc +// ... affine add ... +// outputs[2k] = R.x +// outputs[2k+1] = R.y +// if (k > chunk_start) inv_acc = inv_acc * delta_x_k +// +// LOOP BOUNDS — every loop in this kernel is bounded by a compile-time +// `const`: +// - PHASE_A_LOOP and PHASE_D_LOOP iterate `BS` times (compile-time +// Mustache constant `{{ per_thread_count }}`). +// - PHASE_B_LOOP runs while `stride < TPB` (compile-time const). +// - Every inner `montgomery_product` / `fr_inv_by_a` / `fr_sub` loop is +// bounded by `NUM_WORDS` or other compile-time constants in the +// included partials. +// +// MEMORY BUDGET. Workgroup memory: 2 × TPB × sizeof(BigInt) = 2 × TPB × 80 +// bytes (BN254 BigInt = 20 × u32 = 80 B) + 1 × BigInt broadcast slot. +// At TPB=64 this is 10 KiB + 80 B, well under the 16 KiB default. + +const BATCH_SIZE: u32 = {{ batch_size }}u; +const TPB: u32 = {{ tpb }}u; +const BS: u32 = {{ per_thread_count }}u; // = BATCH_SIZE / TPB + +@group(0) @binding(0) +var inputs: array; + +@group(0) @binding(1) +var prefix: array; + +@group(0) @binding(2) +var outputs: array; + +fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +var wg_fwd: array; +var wg_bwd: array; +var wg_inv_total: BigInt; + +@compute +@workgroup_size({{ tpb }}) +fn main( + @builtin(local_invocation_id) lid: vec3, + @builtin(workgroup_id) wid: vec3, +) { + let tid = lid.x; + let batch_idx = wid.x; + let batch_base = batch_idx * BATCH_SIZE; + let chunk_start = tid * BS; + // chunk_end is statically BS past chunk_start since BATCH_SIZE is an + // exact multiple of TPB by construction; no clamping needed. + let chunk_end = chunk_start + BS; + + // Phase A: per-thread inclusive-prefix product over `delta_x_k`. + var block_total: BigInt = get_r(); + // First iter (k = chunk_start): seed acc with delta_x at that slot. + { + let k = chunk_start; + let pair_base = (batch_base + k) * 4u; + var p_x: BigInt = inputs[pair_base + 0u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var dx: BigInt = fr_sub(&q_x, &p_x); + // prefix slot is per-WG-relative inside this batch's slice. + prefix[batch_base + k] = dx; + block_total = dx; + } + // Subsequent iters: const-bounded loop over the remaining BS-1 slots. + // PHASE_A_LOOP: bound = BS (compile-time constant). 
+ for (var i = 1u; i < BS; i = i + 1u) { + let k = chunk_start + i; + let pair_base = (batch_base + k) * 4u; + var p_x: BigInt = inputs[pair_base + 0u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var dx: BigInt = fr_sub(&q_x, &p_x); + block_total = montgomery_product(&block_total, &dx); + prefix[batch_base + k] = block_total; + } + + wg_fwd[tid] = block_total; + wg_bwd[tid] = block_total; + workgroupBarrier(); + + // Phase B: Hillis-Steele forward + backward scans on the TPB + // block_totals. PHASE_B_LOOP: bound = TPB (compile-time constant). + for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u) { + var fwd_x: BigInt = wg_fwd[tid]; + if (tid >= stride) { + var lhs: BigInt = wg_fwd[tid - stride]; + fwd_x = montgomery_product(&lhs, &fwd_x); + } + var bwd_x: BigInt = wg_bwd[tid]; + if (tid + stride < TPB) { + var rhs: BigInt = wg_bwd[tid + stride]; + bwd_x = montgomery_product(&bwd_x, &rhs); + } + workgroupBarrier(); + wg_fwd[tid] = fwd_x; + wg_bwd[tid] = bwd_x; + workgroupBarrier(); + } + + // Phase C: thread 0 inverts the global total via fr_inv_by_a. + if (tid == 0u) { + var global_total: BigInt = wg_fwd[TPB - 1u]; + wg_inv_total = fr_inv_by_a(global_total); + } + workgroupBarrier(); + + // Setup: inv_acc = inv(block_total[tid]) + // = inv_global * block_excl_prefix * block_excl_suffix + var block_excl_prefix: BigInt = get_r(); + if (tid > 0u) { + block_excl_prefix = wg_fwd[tid - 1u]; + } + var block_excl_suffix: BigInt = get_r(); + if (tid + 1u < TPB) { + block_excl_suffix = wg_bwd[tid + 1u]; + } + var inv_global: BigInt = wg_inv_total; + var inv_acc: BigInt = montgomery_product(&inv_global, &block_excl_prefix); + inv_acc = montgomery_product(&inv_acc, &block_excl_suffix); + + // Phase D: back-walk this thread's chunk, emitting affine adds. + // PHASE_D_LOOP: bound = BS (compile-time constant). Use the + // forward-form `for (off = 0u; off < BS; ...)` and derive the + // descending index k = chunk_end - 1 - off so the bound is provably + // static at compile time. NO `while`/`loop` constructs. + for (var off = 0u; off < BS; off = off + 1u) { + let k = chunk_start + (BS - 1u - off); + let pair_base = (batch_base + k) * 4u; + + // Load all four pair coords once. + var p_x: BigInt = inputs[pair_base + 0u]; + var p_y: BigInt = inputs[pair_base + 1u]; + var q_x: BigInt = inputs[pair_base + 2u]; + var q_y: BigInt = inputs[pair_base + 3u]; + + // Backward batch-inverse: inv_dx_k = inv_acc * prefix[k-1] for + // k > chunk_start, or inv_acc itself at the chunk's start slot. + var inv_dx: BigInt; + if (k > chunk_start) { + var prev_prefix: BigInt = prefix[batch_base + (k - 1u)]; + inv_dx = montgomery_product(&inv_acc, &prev_prefix); + } else { + inv_dx = inv_acc; + } + + // Affine add (Montgomery form): + // λ = (Q.y - P.y) * inv_dx + // R.x = λ^2 - P.x - Q.x + // R.y = λ * (P.x - R.x) - P.y + var dy: BigInt = fr_sub(&q_y, &p_y); + var slope: BigInt = montgomery_product(&dy, &inv_dx); + var slope_sq: BigInt = montgomery_product(&slope, &slope); + var t1: BigInt = fr_sub(&slope_sq, &p_x); + var r_x: BigInt = fr_sub(&t1, &q_x); + var dx_back: BigInt = fr_sub(&p_x, &r_x); + var ldx: BigInt = montgomery_product(&slope, &dx_back); + var r_y: BigInt = fr_sub(&ldx, &p_y); + + // Output: 2 BigInts per pair. + let out_base = (batch_base + k) * 2u; + outputs[out_base + 0u] = r_x; + outputs[out_base + 1u] = r_y; + + // Update inv_acc for the next (smaller-k) iteration, unless we + // just emitted the chunk-start slot. 
+ if (k > chunk_start) { + var dx_k: BigInt = fr_sub(&q_x, &p_x); + inv_acc = montgomery_product(&inv_acc, &dx_k); + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bpr_bn254.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bpr_bn254.template.wgsl index 54eb9a0446f4..9ef1a54613e3 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bpr_bn254.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/bpr_bn254.template.wgsl @@ -21,9 +21,14 @@ var g_points_y: array; @group(0) @binding(5) var g_points_z: array; -// Unfiform storage buffer. +// Uniform storage buffer. Layout: (subtask_idx_base, num_columns, +// num_subtasks_per_bpr, num_subtasks_total). The 4th slot is the +// dispatch's bounds-check ceiling for the WPB tail batch — values past +// num_subtasks_total are skipped per-thread inside the multi-window +// loop. Promoted from vec3 to vec4 so the host can pass a single +// uniform across both legacy (WPB=1) and multi-window dispatches. @group(0) @binding(6) -var params: vec3; +var params: vec4; {{#capture_debug}} // Debug capture for stage_2's inline add-2007-bl formula. Layout: 8 BigInts @@ -149,15 +154,29 @@ fn bench_xor_into(a: Point, b: Point) -> Point { return r; } +// Multi-window BPR (WPB > 1) trades thread count for per-thread work: +// fewer workgroups dispatched, each thread runs the WPB-window inner loop +// {{ windows_per_batch }}u times (one per subtask in the batch). Saves +// total launches and amortises kernel header / dispatch overhead, but +// inflates per-thread register pressure — a single subtask already +// carries ~10 BigInt locals (~800B/thread) inside the add_points formula, +// so WPB > 1 risks spilling on Apple Silicon SIMD-groups (32 KB shared +// register file). When WPB=1 the inner loop runs once and the dispatch +// shape collapses to the legacy (X=num_subtasks, 256) layout. @compute @workgroup_size({{ workgroup_size }}) -fn stage_1(@builtin(global_invocation_id) global_id: vec3) { - let thread_id = global_id.x; +fn stage_1( + @builtin(global_invocation_id) global_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let thread_id = local_id.x; let num_threads_per_subtask = {{ workgroup_size }}u; - let subtask_idx = params[0]; // 0, 2, 4, 6, ... + let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter) let num_columns = params[1]; // 65536 - let num_subtasks_per_bpr = params[2]; // Must be a power of 2 + let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy) + let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1 /// Number of buckets per subtask. let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768 @@ -167,17 +186,30 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) { let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr; /// Number of buckets to reduce per thread. - let buckets_per_thread = num_buckets_per_subtask / + let buckets_per_thread = num_buckets_per_subtask / num_threads_per_subtask; - let multiplier = subtask_idx + (thread_id / num_threads_per_subtask); - let offset = num_buckets_per_subtask * multiplier; + // WPB-aware subtask iteration. Each workgroup owns WPB consecutive + // subtasks; thread `thread_id` processes the same in-subtask slice + // for every one of them. With WPB=1 the outer loop runs once and + // the semantics are byte-identical to the legacy dispatch. 
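The workgroup-to-subtask mapping that the loop below implements can be sketched on the host as follows (hypothetical TypeScript, not part of the patch):

```ts
// Hypothetical model of the WPB-aware BPR dispatch: each workgroup owns WPB
// consecutive subtasks; tail batches simply skip out-of-range subtasks, so
// num_subtasks need not be a multiple of WPB.
function bprDispatchPlan(numSubtasks: number, wpb: number): { workgroups: number; owned: number[][] } {
  const workgroups = Math.ceil(numSubtasks / wpb);
  const owned: number[][] = [];
  for (let wg = 0; wg < workgroups; wg++) {
    const mine: number[] = [];
    for (let wLocal = 0; wLocal < wpb; wLocal++) {
      const subtask = wg * wpb + wLocal;
      if (subtask >= numSubtasks) continue; // mirrors the shader's tail-batch guard
      mine.push(subtask);
    }
    owned.push(mine);
  }
  return { workgroups, owned };
}

// 16 subtasks at WPB=4 -> 4 full workgroups; 10 subtasks at WPB=4 -> 3
// workgroups, the last owning only subtasks 8 and 9.
console.log(bprDispatchPlan(16, 4).workgroups, bprDispatchPlan(10, 4).owned[2]); // 4 [ 8, 9 ]
```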
+ for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) { + let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local; + let subtask_idx = subtask_idx_base + subtask_in_dispatch; + if (subtask_idx >= num_subtasks_total) { + // Tail batch: skip out-of-range subtasks (num_subtasks may + // not be a multiple of WPB). + continue; + } - var idx = offset; - if (thread_id % num_threads_per_subtask != 0u) { - idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * - buckets_per_thread + offset; - } + let multiplier = subtask_idx; + let offset = num_buckets_per_subtask * multiplier; + + var idx = offset; + if (thread_id % num_threads_per_subtask != 0u) { + idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * + buckets_per_thread + offset; + } {{#bench_compute_only}} // Microbench: synthesize the initial m so this load is also stripped. @@ -452,40 +484,46 @@ fn stage_1(@builtin(global_invocation_id) global_id: vec3) { {{/mixed_safe_buckets}} {{/assume_affine_buckets}} - let t = (subtask_idx / num_subtasks_per_bpr) * - (num_threads_per_subtask * num_subtasks_per_bpr) + - thread_id; + let t = (subtask_idx / num_subtasks_per_bpr) * + (num_threads_per_subtask * num_subtasks_per_bpr) + + thread_id; {{#bench_skip_writes}} - // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points). - // Funnel one observable write per thread into the workgroup atomic - // sink so the WGSL compiler can't DCE the inner loop. atomicXor is - // the cheapest read-modify-write that can't be folded; it produces - // 1 LDS op per thread vs 6 storage writes in the production path. - atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t); + // V_NULL / V3: skip the 6 storage writes (bucket_sum + g_points). + // Funnel one observable write per thread into the workgroup atomic + // sink so the WGSL compiler can't DCE the inner loop. atomicXor is + // the cheapest read-modify-write that can't be folded; it produces + // 1 LDS op per thread vs 6 storage writes in the production path. + atomicXor(&bench_sink, m.x.limbs[0] ^ g.x.limbs[0] ^ t); {{/bench_skip_writes}} {{^bench_skip_writes}} - bucket_sum_x[idx] = m.x; - bucket_sum_y[idx] = m.y; - bucket_sum_z[idx] = m.z; + bucket_sum_x[idx] = m.x; + bucket_sum_y[idx] = m.y; + bucket_sum_z[idx] = m.z; - g_points_x[t] = g.x; - g_points_y[t] = g.y; - g_points_z[t] = g.z; + g_points_x[t] = g.x; + g_points_y[t] = g.y; + g_points_z[t] = g.z; {{/bench_skip_writes}} + } // end of w_local loop (WPB-aware multi-window outer) {{{ recompile }}} } @compute @workgroup_size({{ workgroup_size }}) -fn stage_2(@builtin(global_invocation_id) global_id: vec3) { - let thread_id = global_id.x; +fn stage_2( + @builtin(global_invocation_id) global_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let thread_id = local_id.x; let num_threads_per_subtask = {{ workgroup_size }}u; - let subtask_idx = params[0]; // 0, 2, 4, 6, ... + let subtask_idx_base = params[0]; // base subtask for this dispatch (host-side outer iter) let num_columns = params[1]; // 65536 - let num_subtasks_per_bpr = params[2]; // Must be a power of 2 + let num_subtasks_per_bpr = params[2]; // packs g_points output across subtasks (legacy) + let num_subtasks_total = params[3]; // bounds check for tail batches when WPB > 1 /// Number of buckets per subtask. 
let num_buckets_per_subtask = num_columns / 2u; // 2 ** 15 = 32768 @@ -495,27 +533,35 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) { let num_buckets_per_bpr = num_buckets_per_subtask * num_subtasks_per_bpr; /// Number of buckets to reduce per thread. - let buckets_per_thread = num_buckets_per_subtask / + let buckets_per_thread = num_buckets_per_subtask / num_threads_per_subtask; - let multiplier = subtask_idx + (thread_id / num_threads_per_subtask); - let offset = num_buckets_per_subtask * multiplier; + // WPB-aware subtask iteration. See stage_1 for the rationale. + for (var w_local = 0u; w_local < {{ windows_per_batch }}u; w_local = w_local + 1u) { + let subtask_in_dispatch = wg_id.x * {{ windows_per_batch }}u + w_local; + let subtask_idx = subtask_idx_base + subtask_in_dispatch; + if (subtask_idx >= num_subtasks_total) { + continue; + } - var idx = offset; - if (thread_id % num_threads_per_subtask != 0u) { - idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * - buckets_per_thread + offset; - } + let multiplier = subtask_idx; + let offset = num_buckets_per_subtask * multiplier; - var m = load_bucket_sum(idx); + var idx = offset; + if (thread_id % num_threads_per_subtask != 0u) { + idx = (num_threads_per_subtask - (thread_id % num_threads_per_subtask)) * + buckets_per_thread + offset; + } + + var m = load_bucket_sum(idx); - let t = (subtask_idx / num_subtasks_per_bpr) * - (num_threads_per_subtask * num_subtasks_per_bpr) + - thread_id; - var g = load_g_point(t); + let t = (subtask_idx / num_subtasks_per_bpr) * + (num_threads_per_subtask * num_subtasks_per_bpr) + + thread_id; + var g = load_g_point(t); - let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u); - let sm: Point = double_and_add(m, s); + let s = buckets_per_thread * (num_threads_per_subtask - (thread_id % num_threads_per_subtask) - 1u); + let sm: Point = double_and_add(m, s); // The add-2007-bl formula is inlined here (rather than calling // add_points(g, sm)) because on Dawn/Metal the function-return slot for @@ -603,16 +649,17 @@ fn stage_2(@builtin(global_invocation_id) global_id: vec3) { } } - g_points_x[t] = X3_out; - g_points_y[t] = Y3_out; - g_points_z[t] = Z3_out; - - {{#capture_debug}} - // Final values of the formula outputs, read from outer scope. - debug_capture[thread_id * 8u + 4u] = X3_out; - debug_capture[thread_id * 8u + 5u] = Y3_out; - debug_capture[thread_id * 8u + 7u] = Z3_out; - {{/capture_debug}} + g_points_x[t] = X3_out; + g_points_y[t] = Y3_out; + g_points_z[t] = Z3_out; + + {{#capture_debug}} + // Final values of the formula outputs, read from outer scope. + debug_capture[thread_id * 8u + 4u] = X3_out; + debug_capture[thread_id * 8u + 5u] = Y3_out; + debug_capture[thread_id * 8u + 7u] = Z3_out; + {{/capture_debug}} + } // end of w_local loop (WPB-aware multi-window outer) {{{ recompile }}} } diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/divsteps_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/divsteps_bench.template.wgsl new file mode 100644 index 000000000000..f2ac49f775ff --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/divsteps_bench.template.wgsl @@ -0,0 +1,44 @@ +// Single-thread-per-input bench shader for `by_divsteps`. +// +// Each thread: +// 1. Reads (f_lo, g_lo) as a vec4 from `inputs_fg[tid]` +// (= f_lo.x, f_lo.y, g_lo.x, g_lo.y). +// 2. Reads the initial delta from `inputs_delta[tid]`. +// 3. Calls `by_divsteps(&delta, f_lo, g_lo)`. +// 4. 
Writes the 8 Mat fields + the updated delta (9 i32) into `outputs[tid]`. +// +// LOOP BOUNDS +// The only loops in the included partials are bounded by WGSL `const`s +// (BY_NUM_LIMBS, BY_BATCH) or Mustache-const values (`{{ num_words }}`). +// No loop is added in this entry shader. The `if (tid >= n)` guard is not +// a loop bound. + +@group(0) @binding(0) var inputs_fg: array>; +@group(0) @binding(1) var inputs_delta: array; +@group(0) @binding(2) var outputs: array; +@group(0) @binding(3) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let tid = gid.x; + if (tid >= n) { return; } + + let fg = inputs_fg[tid]; + let f_lo: vec2 = vec2(fg.x, fg.y); + let g_lo: vec2 = vec2(fg.z, fg.w); + var delta: i32 = inputs_delta[tid]; + + let m: Mat = by_divsteps(&delta, f_lo, g_lo); + + let base: u32 = tid * 9u; + outputs[base + 0u] = m.u; + outputs[base + 1u] = m.v; + outputs[base + 2u] = m.q; + outputs[base + 3u] = m.r; + outputs[base + 4u] = m.u_hi; + outputs[base + 5u] = m.v_hi; + outputs[base + 6u] = m.q_hi; + outputs[base + 7u] = m.r_hi; + outputs[base + 8u] = delta; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl new file mode 100644 index 000000000000..7b6226ee52da --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl @@ -0,0 +1,50 @@ +// Per-thread chained-inversion micro-benchmark for the field inverse. +// +// Each thread reads one BN254 base-field value `a` (in Montgomery form) and +// runs `k` chained calls to the inverse function selected by the +// `{{ inv_fn }}` Mustache substitution (default `fr_inv_by`; the legacy +// `fr_inv` Pornin jumpy K=12 safegcd is the other supported variant): +// a <- {{ inv_fn }}(a) repeated k times, +// writing the final `a` back to `outputs[tid]`. The host caps k at 100 +// before dispatch. +// +// LOOP BOUNDS +// - The only data-dependent loop in this entry shader is +// `for (var i = 0u; i < k; i = i + 1u)` with `k = params.y` host-capped +// at 100. Every loop in the selected inverse function (and its callees) +// is bounded by a compile-time `const` — see by_inverse.template.wgsl, +// bigint_by, fr_pow.template.wgsl, and the inner `montgomery_product` +// (`for i, j < NUM_WORDS`). +// - The `if (tid >= n)` guard is not a loop bound. +// +// Bind layout mirrors `field_mul_bench_u32`: a single inputs buffer `xs` +// (no `ys`, since inversion is unary) and an outputs buffer. `params.x` is +// `n` (threads), `params.y` is `k` (chain length). + +// get_r returns Montgomery R (= the integer 1 in Montgomery form). Defined +// per-entry shader by every cuzk pipeline kernel that uses fr_pow / fr_inv. +// fr_pow_funcs references it via fr_pow; not called on fr_inv_by's hot path +// but kept for binary compatibility with the partial. 
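A host-side reference for the chained-inversion bench above might look like this (hypothetical TypeScript, not part of the patch; values assumed to be mapped out of Montgomery form already):

```ts
// Hypothetical validation helper: apply k modular inversions to each input
// and compare against the GPU outputs. Because inversion is an involution,
// even k must return the input and odd k a single inverse, which is a cheap
// sanity check on the harness itself.
const P_FQ = 21888242871839275222246405745257275088696311157297823662689037894645226208583n; // assumed BN254 base-field modulus

function modPow(base: bigint, exp: bigint): bigint {
  let result = 1n, b = base % P_FQ, e = exp;
  while (e > 0n) {
    if (e & 1n) result = (result * b) % P_FQ;
    b = (b * b) % P_FQ;
    e >>= 1n;
  }
  return result;
}
const modInv = (a: bigint): bigint => modPow(a, P_FQ - 2n); // Fermat; P_FQ is prime

function chainedInvReference(inputs: bigint[], k: number): bigint[] {
  return inputs.map((a0) => {
    let a = a0;
    for (let i = 0; i < k; i++) a = modInv(a);
    return a;
  });
}

console.log(chainedInvReference([5n], 2)[0] === 5n);         // true: inv(inv(5)) = 5
console.log(chainedInvReference([5n], 3)[0] === modInv(5n)); // true: odd chain = single inverse
```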
+fn get_r() -> BigInt { + var r: BigInt; +{{{ r_limbs }}} + return r; +} + +@group(0) @binding(0) var xs: array; +@group(0) @binding(1) var outputs: array; +@group(0) @binding(2) var params: vec2; + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = params.x; + let k = params.y; + let tid = gid.x; + if (tid >= n) { return; } + + var a = xs[tid]; + for (var i = 0u; i < k; i = i + 1u) { + a = {{ inv_fn }}(a); + } + outputs[tid] = a; +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl new file mode 100644 index 000000000000..3663c4e72419 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl @@ -0,0 +1,1075 @@ +// Bernstein-Yang safegcd inversion for the BN254 base field, WGSL port. +// +// This file will grow over sub-steps 1.3-1.5 of the WebGPU MSM rewrite plan +// to host the full `fr_inv_by` entrypoint. Currently it contains the inner +// `by_divsteps` primitive (sub-step 1.3) — a line-for-line transliteration +// of `Wasm9x29::divsteps` (bernstein_yang_inverse_wasm.hpp lines 147-178). +// +// REFERENCES +// - TS port (ground truth): src/msm_webgpu/cuzk/bernstein_yang.ts +// - bigint helpers used here: src/msm_webgpu/wgsl/bigint/bigint_by.template.wgsl +// +// REPRESENTATIONS +// - delta: i32 counter (mirrors the i64 in C++; only the low value +// matters because BATCH=58 caps |delta| growth per call). +// - f_lo, g_lo: u64 carriers held as vec2 (.x = low 32, .y = high 32). +// All u64 ops via u64_add / u64_sub / u64_shr1 helpers. +// - u, v, q, r: i64 matrix entries held as paired i32 (.x = low 32 unsigned +// bit pattern, .y = high 32 signed). After BATCH divsteps +// |entry| <= 2^58, fits in i64. All ops via i64_*_pair. +// +// LOOP BOUND DISCIPLINE +// Three loops total in this file: +// - by_divsteps inner loop: `for ... i < BY_BATCH` (= 58u const) +// - by_apply_matrix_fg streaming: `for ... i < BY_NUM_LIMBS` (= 9u const) +// - by_apply_matrix_de streaming: `for ... i < BY_NUM_LIMBS` (= 9u const) +// Both bounds are compile-time WGSL `const`s defined in bigint_by, so the +// plan's "bounded loops" rule is satisfied. The audit grep +// `grep -nE 'for *\(' ... | grep -v -E '< [A-Z][A-Z_]*[a-z]?|< [0-9]+|...'` +// returns no matches. + +// 2x2 matrix produced by BATCH=58 divsteps. Each entry is a 64-bit signed +// integer stored as a paired (lo: i32, hi: i32) — value = u32(lo) | (i32(hi) << 32), +// interpreted as two's complement. The naming matches the C++ struct field +// names (m.u, m.v, m.q, m.r) suffixed with `_hi` for the high half. +struct Mat { + u: i32, + v: i32, + q: i32, + r: i32, + u_hi: i32, + v_hi: i32, + q_hi: i32, + r_hi: i32, +} + +// by_divsteps: run BATCH = 58 branchy divsteps on the low 64 bits of (f, g); +// returns the transition matrix M and updates `*delta`. +// +// Mirrors Wasm9x29::divsteps line-for-line. The branches are variable-time +// over inputs, which is acceptable here because BN254 base-field inversion +// operates on public values in the MSM pipeline. +// +// Pre: delta is the current divstep counter (signed i32 view). +// f_lo, g_lo are the low 64 bits of f and g respectively (vec2). +// Post: returns the matrix M = ((u, v), (q, r)) such that +// (f_new, g_new) = M * (f_old, g_old) / 2^BATCH +// after BATCH=58 divsteps. *delta is updated. 
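For orientation before the WGSL body below, the 58-divstep batch can be modelled on the host with BigInt wrap arithmetic (a hedged sketch, not the repo's TS reference file):

```ts
// Hypothetical BigInt model of the BATCH = 58 branchy divsteps: operates on
// the low 64 bits of (f, g) with u64 wrap semantics and returns the 2x2
// transition matrix plus the updated delta, mirroring the WGSL control flow.
const U64 = (1n << 64n) - 1n;
const BATCH = 58;

interface Mat2 { u: bigint; v: bigint; q: bigint; r: bigint }

function divsteps(delta: bigint, fLoIn: bigint, gLoIn: bigint): { delta: bigint; m: Mat2 } {
  let d = delta, f = fLoIn & U64, g = gLoIn & U64;
  let u = 1n, v = 0n, q = 0n, r = 1n;
  for (let i = 0; i < BATCH; i++) {
    if (g & 1n) {
      if (d > 0n) {
        [f, g] = [g, ((g - f) & U64) >> 1n];             // u64 wrap-sub, then unsigned >> 1
        [u, v, q, r] = [q << 1n, r << 1n, q - u, r - v];
        d = 1n - d;
      } else {
        g = ((g + f) & U64) >> 1n;
        q += u; r += v; u <<= 1n; v <<= 1n; d += 1n;
      }
    } else {
      g >>= 1n; u <<= 1n; v <<= 1n; d += 1n;
    }
  }
  return { delta: d, m: { u, v, q, r } };
}
```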
+// +// The TS reference uses `(g_lo - f_lo) & U64_MASK` then `>> 1` to mimic the +// C++ `(u64)(g_lo - f_lo) >> 1` semantics: u64 wrap, then unsigned shift. +// u64_sub + u64_shr1 here is exactly that. +fn by_divsteps(delta: ptr, f_lo_in: vec2, g_lo_in: vec2) -> Mat { + var f_lo: vec2 = f_lo_in; + var g_lo: vec2 = g_lo_in; + // Matrix entries as paired i32 (i64). u = 1, v = 0, q = 0, r = 1. + var u: vec2 = vec2(1, 0); + var v: vec2 = vec2(0, 0); + var q: vec2 = vec2(0, 0); + var r: vec2 = vec2(1, 0); + var d: i32 = *delta; + for (var i: u32 = 0u; i < BY_BATCH; i = i + 1u) { + if (u64_low_bit(g_lo) != 0u) { + if (d > 0) { + // (f, g) <- (g, (g - f) >> 1) using u64 wrap-sub then unsigned >> 1. + let nf: vec2 = g_lo; + let diff: vec2 = u64_sub(g_lo, f_lo); + let ng: vec2 = u64_shr1(diff); + // (u, v, q, r) <- (q << 1, r << 1, q - u, r - v). + let nu: vec2 = i64_shl1_pair(q); + let nv: vec2 = i64_shl1_pair(r); + let nq: vec2 = i64_sub_pair(q, u); + let nr: vec2 = i64_sub_pair(r, v); + f_lo = nf; + g_lo = ng; + u = nu; + v = nv; + q = nq; + r = nr; + d = 1 - d; + } else { + // g <- (g + f) >> 1; q += u; r += v; u <<= 1; v <<= 1; d += 1. + let sum: vec2 = u64_add(g_lo, f_lo); + g_lo = u64_shr1(sum); + q = i64_add_pair(q, u); + r = i64_add_pair(r, v); + u = i64_shl1_pair(u); + v = i64_shl1_pair(v); + d = d + 1; + } + } else { + // g <- g >> 1; u <<= 1; v <<= 1; d += 1. + g_lo = u64_shr1(g_lo); + u = i64_shl1_pair(u); + v = i64_shl1_pair(v); + d = d + 1; + } + } + *delta = d; + return Mat(u.x, v.x, q.x, r.x, u.y, v.y, q.y, r.y); +} + +// ============================================================ +// apply_matrix helpers +// +// `signed_mul_split` accepts |a| <= 2^29 (one BY limb) and |b| <= 2^31 - 1 +// and returns (lo29, hi) with a*b = lo29 + hi * 2^29. The streaming +// schoolbook below feeds it (m_*, x_limb) pairs whose products lie in +// [-2^58, 2^58], well inside the helper's contract. +// +// Each per-limb position i computes +// acc <- m_lo * x_i + m_hi * x_{i-1} + carry_in (+ k * p_i terms in de pass) +// then carry_out = acc >> 29 (arithmetic), with the masked low-29 bits +// landing at output position i - 2 (= exact >> BATCH = >> 58 = >> (2 * 29)). +// +// The signed_mul_split returns lo29 in [0, 2^29) and a signed hi. Adding +// four cross-products together can push the high half beyond i32 range +// only when |sum| exceeds ~2^31 — at the per-limb level this is bounded +// by the four-product sum of |2^58| / 2^29 + carry, well inside i32. + +// Convert a (lo29, hi) signed-product split to a full i64 (vec2). +// +// The signed_mul_split helper returns (lo29, hi) where +// value = lo29 + hi * 2^29 with lo29 in [0, 2^29) +// To add this into an i64 accumulator we need to re-express it as a +// (lo32, hi32) pair where +// value = u32(lo32) + i32(hi32) * 2^32 (two's complement) +// +// Bit layout: +// bits 0..28 of value = lo29 (from the lo29 half) +// bits 29..31 of value = bits 0..2 of `hi` +// bits 32..63 of value = bits 3..34 of `hi` (sign-extended) +// +// `hi << 29u` on a u32 keeps only the low 3 bits of `hi` after shifting, +// which yields exactly the contribution to bits 29..31. The high half is +// `hi >> 3u` with signed arithmetic shift, which sign-extends to fill bit +// 63 correctly when `hi` is negative. +fn by_split_to_i64(split: vec2) -> vec2 { + let lo32: u32 = u32(split.x) | (u32(split.y) << 29u); + let hi32: i32 = split.y >> 3u; + return vec2(i32(lo32), hi32); +} + +// Add `m_lo * x_limb` (a signed 58-bit product) into an i64 accumulator. 
+// Helper used pervasively in the streaming schoolbook below. +fn by_add_mul(acc: vec2, m_lo: i32, x_limb: i32) -> vec2 { + let split = signed_mul_split(m_lo, x_limb); + let prod_i64 = by_split_to_i64(split); + return i64_add_pair(acc, prod_i64); +} + +// Arithmetic right shift of an i64 (vec2) by 29. +// result.lo = (u32(acc.x) >> 29) | (u32(acc.y) << 3) +// result.hi = acc.y >> 29 (signed arithmetic shift) +fn i64_ars29(acc: vec2) -> vec2 { + let lo_u: u32 = (u32(acc.x) >> 29u) | (u32(acc.y) << 3u); + let hi: i32 = acc.y >> 29u; + return vec2(i32(lo_u), hi); +} + +// Low 29 bits of an i64. Returns i32 in [0, 2^29). +fn i64_low29(acc: vec2) -> i32 { + return i32(u32(acc.x) & BY_LIMB_MASK); +} + +// u64_mul_low64: low 64 bits of an unsigned u64 * u64 product. +// +// Implements via four 16x16 partials per operand half (16 partials total +// to compute the full 128-bit product), summing only the bits that +// land in the low 64. Used to evaluate `k = ((-t) * p_inv) mod 2^58` +// inside by_apply_matrix_de — the C++ does `(u64)(-(i64)t) * p_inv` and +// keeps the low 58 bits; we keep the low 64 and let the caller mask. +// +// Pre: any a, b u64 (as vec2). +// Post: low 64 bits of a*b, two's complement-equivalent under masking. +fn u64_mul_low64(a: vec2, b: vec2) -> vec2 { + // Split each u32 half into 16-bit pieces: + // a.x = a0 + a1 * 2^16, a.y = a2 + a3 * 2^16 + // b.x = b0 + b1 * 2^16, b.y = b2 + b3 * 2^16 + let MASK16: u32 = 0xFFFFu; + let a0: u32 = a.x & MASK16; + let a1: u32 = a.x >> 16u; + let a2: u32 = a.y & MASK16; + let a3: u32 = a.y >> 16u; + let b0: u32 = b.x & MASK16; + let b1: u32 = b.x >> 16u; + let b2: u32 = b.y & MASK16; + let b3: u32 = b.y >> 16u; + + // Partials landing in bits 0..15 (only one: a0*b0). + let p00: u32 = a0 * b0; + // Partials in bits 16..47 (a0*b1, a1*b0; we'll split further). + let p01: u32 = a0 * b1; + let p10: u32 = a1 * b0; + // Partials in bits 32..63. + let p02: u32 = a0 * b2; + let p20: u32 = a2 * b0; + let p11: u32 = a1 * b1; + // Partials in bits 48..79 (we only keep the part falling in [0, 64)). + let p03: u32 = a0 * b3; + let p30: u32 = a3 * b0; + let p12: u32 = a1 * b2; + let p21: u32 = a2 * b1; + // Partials in bits 64..95 (a1*b3, a3*b1, a2*b2) — discarded except for + // the part that wraps into the low 64 via the cross sums below. With + // bit-offset >= 64 the contribution is zero in the low 64. + + // Build the low 64 bits: + // bits 0..15: p00 low 16 + // bits 16..47: p00 high 16 + (p01 + p10) low 32 + // bits 32..63: carries from above + (p02 + p20 + p11) low 32 + + // ((p01 + p10) >> 16) plus higher partials' low pieces + // bits 48..63: (p03 + p30 + p12 + p21) low 16 + + // Sum bits 0..31 (the low u32 of the result). + let lo16 = p00 & MASK16; + let mid_a = (p00 >> 16u) + (p01 & MASK16) + (p10 & MASK16); + let lo_u32 = lo16 | (mid_a << 16u); + // Carry into bits 32+ from `mid_a` (the high part beyond 16 bits). + let mid_a_hi = mid_a >> 16u; + + // Sum bits 32..63. 
+ // Contributions landing entirely in [32, 64): + // (p01 + p10) >> 16 (these are 32-bit values; the >> 16 lands them at bit 32) + // p02, p20, p11 (start at bit 32; whole 32 bits land in [32, 64)) + // Contributions landing partially in [32, 64) starting at bit 48: + // p03 << 16, p30 << 16, p12 << 16, p21 << 16 + let mid_b = (p01 >> 16u) + (p10 >> 16u) + p02 + p20 + p11; + let mid_c = (p03 + p30 + p12 + p21) << 16u; + let hi_u32 = mid_a_hi + mid_b + mid_c; + + return vec2(lo_u32, hi_u32); +} + +// by_apply_matrix_fg +// +// Mirrors `Wasm9x29::apply_matrix` lines 196-217 — the (f, g) streaming +// pass. After BATCH=58 divsteps we apply the 2x2 transition matrix M to +// (f, g) and divide by 2^58. The streamed schoolbook produces one (nf, ng) +// pair per source limb position i and writes the masked low-29 bits at +// output position i - 2 (= the exact >> 58 = >> (2 * 29) drop). +// +// PERF: inlined hot path. Replaces the four `by_add_mul` calls per +// accumulator with a single fused 15+14-bit partial-product schoolbook +// that sums all four products' lane-i pieces into a single i32 before +// any carry propagation. This eliminates the per-call (lo29,hi)→(lo32,hi32) +// conversion overhead and reduces 4 i64 adds to one composite extract. +// +// LANE PARTIAL-PRODUCT BOUND: +// Each per-limb cross product (a_l, a_h) * (b_l, b_h) yields four 28-bit +// signed pieces: pll, plh, phl, phh (each |.| < 2^28). +// For 4 products into one accumulator: per-lane sum |.| < 4 * 2^28 = 2^30 +// (fits i32 comfortably). The combined "mid" lane (plh+phl summed across +// 4 products) is |.| < 2 * 4 * 2^28 = 2^31 — still fits i32 (signed). +// +// LOOP BOUND: `for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u)` — const. +fn by_apply_matrix_fg(m: Mat, f: ptr, g: ptr) { + // Decompose each of the four matrix entries into low/high 29-bit halves, + // and then split each half into (15-bit signed low, 14-bit high) chunks + // for the partial-product schoolbook below. Hoisted out of the inner + // loop (loop-invariant). + let u_lo: i32 = m.u & i32(BY_LIMB_MASK); + let v_lo: i32 = m.v & i32(BY_LIMB_MASK); + let q_lo: i32 = m.q & i32(BY_LIMB_MASK); + let r_lo: i32 = m.r & i32(BY_LIMB_MASK); + let u_hi: i32 = i32((u32(m.u) >> 29u) | (u32(m.u_hi) << 3u)); + let v_hi: i32 = i32((u32(m.v) >> 29u) | (u32(m.v_hi) << 3u)); + let q_hi: i32 = i32((u32(m.q) >> 29u) | (u32(m.q_hi) << 3u)); + let r_hi: i32 = i32((u32(m.r) >> 29u) | (u32(m.r_hi) << 3u)); + + let u_lo_l: i32 = (u_lo << 17u) >> 17u; + let u_lo_h: i32 = (u_lo - u_lo_l) >> 15u; + let v_lo_l: i32 = (v_lo << 17u) >> 17u; + let v_lo_h: i32 = (v_lo - v_lo_l) >> 15u; + let q_lo_l: i32 = (q_lo << 17u) >> 17u; + let q_lo_h: i32 = (q_lo - q_lo_l) >> 15u; + let r_lo_l: i32 = (r_lo << 17u) >> 17u; + let r_lo_h: i32 = (r_lo - r_lo_l) >> 15u; + let u_hi_l: i32 = (u_hi << 17u) >> 17u; + let u_hi_h: i32 = (u_hi - u_hi_l) >> 15u; + let v_hi_l: i32 = (v_hi << 17u) >> 17u; + let v_hi_h: i32 = (v_hi - v_hi_l) >> 15u; + let q_hi_l: i32 = (q_hi << 17u) >> 17u; + let q_hi_h: i32 = (q_hi - q_hi_l) >> 15u; + let r_hi_l: i32 = (r_hi << 17u) >> 17u; + let r_hi_h: i32 = (r_hi - r_hi_l) >> 15u; + + // Streaming accumulator as i64 (lo, hi). + var cf_lo: u32 = 0u; + var cf_hi: i32 = 0; + var cg_lo: u32 = 0u; + var cg_hi: i32 = 0; + + // Previous limb 15/14-bit pre-splits (for u_hi * fp etc.). Start at 0; + // slide forward each iter to avoid re-splitting next time. 
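+  // Split identity used throughout: x = x_l + (x_h << 15), where
+  // x_l = (x << 17) >> 17 is the sign-extended low 15 bits (in [-2^14, 2^14))
+  // and x_h = (x - x_l) >> 15, so each half-product stays within ~2^28.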
+ var fp_l: i32 = 0; + var fp_h: i32 = 0; + var gp_l: i32 = 0; + var gp_h: i32 = 0; + + // Single loop with conditional output: the per-iter `if (i >= 2)` check + // costs less than the duplicated loop body of a prologue/main split. The + // compiler can predicate the store on most GPUs. + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + let fi: i32 = (*f).l[i]; + let gi: i32 = (*g).l[i]; + let fi_l: i32 = (fi << 17u) >> 17u; + let fi_h: i32 = (fi - fi_l) >> 15u; + let gi_l: i32 = (gi << 17u) >> 17u; + let gi_h: i32 = (gi - gi_l) >> 15u; + + let nf_pll: i32 = u_lo_l * fi_l + v_lo_l * gi_l + u_hi_l * fp_l + v_hi_l * gp_l; + let nf_mid: i32 = + u_lo_l * fi_h + u_lo_h * fi_l + + v_lo_l * gi_h + v_lo_h * gi_l + + u_hi_l * fp_h + u_hi_h * fp_l + + v_hi_l * gp_h + v_hi_h * gp_l; + let nf_phh: i32 = u_lo_h * fi_h + v_lo_h * gi_h + u_hi_h * fp_h + v_hi_h * gp_h; + let ng_pll: i32 = q_lo_l * fi_l + r_lo_l * gi_l + q_hi_l * fp_l + r_hi_l * gp_l; + let ng_mid: i32 = + q_lo_l * fi_h + q_lo_h * fi_l + + r_lo_l * gi_h + r_lo_h * gi_l + + q_hi_l * fp_h + q_hi_h * fp_l + + r_hi_l * gp_h + r_hi_h * gp_l; + let ng_phh: i32 = q_lo_h * fi_h + r_lo_h * gi_h + q_hi_h * fp_h + r_hi_h * gp_h; + + let nf_pll_u: u32 = u32(nf_pll); + let nf_mid_u: u32 = u32(nf_mid); + let nf_phh_u: u32 = u32(nf_phh); + let nf_pll_hi: i32 = nf_pll >> 31u; + let nf_mid_hi: i32 = nf_mid >> 17u; + let nf_phh_hi: i32 = nf_phh >> 2u; + let nf_s1_lo: u32 = nf_pll_u + (nf_mid_u << 15u); + let nf_s1_c: i32 = select(0i, 1i, nf_s1_lo < nf_pll_u); + let nf_s2_lo: u32 = nf_s1_lo + (nf_phh_u << 30u); + let nf_s2_c: i32 = select(0i, 1i, nf_s2_lo < nf_s1_lo); + let nf_total_lo: u32 = nf_s2_lo + cf_lo; + let nf_total_c: i32 = select(0i, 1i, nf_total_lo < nf_s2_lo); + let nf_total_hi: i32 = nf_pll_hi + nf_mid_hi + nf_phh_hi + nf_s1_c + nf_s2_c + nf_total_c + cf_hi; + + let ng_pll_u: u32 = u32(ng_pll); + let ng_mid_u: u32 = u32(ng_mid); + let ng_phh_u: u32 = u32(ng_phh); + let ng_pll_hi: i32 = ng_pll >> 31u; + let ng_mid_hi: i32 = ng_mid >> 17u; + let ng_phh_hi: i32 = ng_phh >> 2u; + let ng_s1_lo: u32 = ng_pll_u + (ng_mid_u << 15u); + let ng_s1_c: i32 = select(0i, 1i, ng_s1_lo < ng_pll_u); + let ng_s2_lo: u32 = ng_s1_lo + (ng_phh_u << 30u); + let ng_s2_c: i32 = select(0i, 1i, ng_s2_lo < ng_s1_lo); + let ng_total_lo: u32 = ng_s2_lo + cg_lo; + let ng_total_c: i32 = select(0i, 1i, ng_total_lo < ng_s2_lo); + let ng_total_hi: i32 = ng_pll_hi + ng_mid_hi + ng_phh_hi + ng_s1_c + ng_s2_c + ng_total_c + cg_hi; + + if (i >= 2u) { + (*f).l[i - 2u] = i32(nf_total_lo & 0x1FFFFFFFu); + (*g).l[i - 2u] = i32(ng_total_lo & 0x1FFFFFFFu); + } + cf_lo = (nf_total_lo >> 29u) | (u32(nf_total_hi) << 3u); + cf_hi = nf_total_hi >> 29u; + cg_lo = (ng_total_lo >> 29u) | (u32(ng_total_hi) << 3u); + cg_hi = ng_total_hi >> 29u; + + fp_l = fi_l; fp_h = fi_h; gp_l = gi_l; gp_h = gi_h; + } + // Top finalisation: nf9 = u_hi * fp + v_hi * fp_prev + cf (only 2 products + // now, since we've consumed all the input limbs and fi=0). Same shape as + // the inner loop body but with the *_lo terms dropped. 
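+  // (After the loop fp_* / gp_* hold the splits of f[8] and g[8], so this
+  //  virtual position i = 9 only involves the *_hi matrix halves.)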
+ let nf9_pll: i32 = u_hi_l * fp_l + v_hi_l * gp_l; + let nf9_mid: i32 = u_hi_l * fp_h + u_hi_h * fp_l + v_hi_l * gp_h + v_hi_h * gp_l; + let nf9_phh: i32 = u_hi_h * fp_h + v_hi_h * gp_h; + let ng9_pll: i32 = q_hi_l * fp_l + r_hi_l * gp_l; + let ng9_mid: i32 = q_hi_l * fp_h + q_hi_h * fp_l + r_hi_l * gp_h + r_hi_h * gp_l; + let ng9_phh: i32 = q_hi_h * fp_h + r_hi_h * gp_h; + + let nf9_pll_u: u32 = u32(nf9_pll); + let nf9_mid_u: u32 = u32(nf9_mid); + let nf9_phh_u: u32 = u32(nf9_phh); + let nf9_pll_hi: i32 = nf9_pll >> 31u; + let nf9_mid_hi: i32 = nf9_mid >> 17u; + let nf9_phh_hi: i32 = nf9_phh >> 2u; + let nf9_s1_lo: u32 = nf9_pll_u + (nf9_mid_u << 15u); + let nf9_s1_c: i32 = select(0i, 1i, nf9_s1_lo < nf9_pll_u); + let nf9_s2_lo: u32 = nf9_s1_lo + (nf9_phh_u << 30u); + let nf9_s2_c: i32 = select(0i, 1i, nf9_s2_lo < nf9_s1_lo); + let nf9_total_lo: u32 = nf9_s2_lo + cf_lo; + let nf9_total_c: i32 = select(0i, 1i, nf9_total_lo < nf9_s2_lo); + let nf9_total_hi: i32 = nf9_pll_hi + nf9_mid_hi + nf9_phh_hi + nf9_s1_c + nf9_s2_c + nf9_total_c + cf_hi; + + let ng9_pll_u: u32 = u32(ng9_pll); + let ng9_mid_u: u32 = u32(ng9_mid); + let ng9_phh_u: u32 = u32(ng9_phh); + let ng9_pll_hi: i32 = ng9_pll >> 31u; + let ng9_mid_hi: i32 = ng9_mid >> 17u; + let ng9_phh_hi: i32 = ng9_phh >> 2u; + let ng9_s1_lo: u32 = ng9_pll_u + (ng9_mid_u << 15u); + let ng9_s1_c: i32 = select(0i, 1i, ng9_s1_lo < ng9_pll_u); + let ng9_s2_lo: u32 = ng9_s1_lo + (ng9_phh_u << 30u); + let ng9_s2_c: i32 = select(0i, 1i, ng9_s2_lo < ng9_s1_lo); + let ng9_total_lo: u32 = ng9_s2_lo + cg_lo; + let ng9_total_c: i32 = select(0i, 1i, ng9_total_lo < ng9_s2_lo); + let ng9_total_hi: i32 = ng9_pll_hi + ng9_mid_hi + ng9_phh_hi + ng9_s1_c + ng9_s2_c + ng9_total_c + cg_hi; + + (*f).l[BY_NUM_LIMBS - 2u] = i32(nf9_total_lo & 0x1FFFFFFFu); + (*g).l[BY_NUM_LIMBS - 2u] = i32(ng9_total_lo & 0x1FFFFFFFu); + // Top limb: the value above bit 29 of (nf9_total_lo, nf9_total_hi). + (*f).l[BY_NUM_LIMBS - 1u] = i32((nf9_total_lo >> 29u) | (u32(nf9_total_hi) << 3u)); + (*g).l[BY_NUM_LIMBS - 1u] = i32((ng9_total_lo >> 29u) | (u32(ng9_total_hi) << 3u)); + // by_normalise is a no-op: all lower limbs already masked to [0, 2^29). +} + +// by_apply_matrix_de +// +// Mirrors `Wasm9x29::apply_matrix` lines 222-254 — the (d, e) pass with +// the 2-adic k·p correction. The first two output limbs are zero by +// construction (k chosen to clear the low 58 bits of (M · (d, e)) mod 2^58), +// so the streaming pass folds k·p in from position 2 onward. +// +// `p_inv_lo`, `p_inv_hi`: the 58-bit constant p^(-1) mod 2^58 split as the +// low 32 bits and the high 32 bits respectively. The WASM C++ stores it as +// a single u64 `p_inv`; the WGSL caller pre-splits it because WGSL has no +// native u64. Naming reflects the split: `p_inv = p_inv_lo + (p_inv_hi << 32)`. +// +// Loop bound is `BY_NUM_LIMBS` — const, satisfying the plan rule. +fn by_apply_matrix_de( + m: Mat, + d: ptr, + e: ptr, + p: ptr, + p_inv_lo: u32, + p_inv_hi: u32, +) { + // Same matrix split as the f/g pass, with 15+14-bit pre-splits hoisted + // out of the inner loop (loop-invariant). 
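+  // Pass identity: (d', e') = (M * (d, e) + (k_d, k_e) * p) / 2^58, with
+  // k_d, k_e chosen below so the division is exact. Adding multiples of p
+  // does not change d, e mod p, which is the only thing they track.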
+ let u_lo: i32 = m.u & i32(BY_LIMB_MASK); + let v_lo: i32 = m.v & i32(BY_LIMB_MASK); + let q_lo: i32 = m.q & i32(BY_LIMB_MASK); + let r_lo: i32 = m.r & i32(BY_LIMB_MASK); + let u_hi: i32 = i32((u32(m.u) >> 29u) | (u32(m.u_hi) << 3u)); + let v_hi: i32 = i32((u32(m.v) >> 29u) | (u32(m.v_hi) << 3u)); + let q_hi: i32 = i32((u32(m.q) >> 29u) | (u32(m.q_hi) << 3u)); + let r_hi: i32 = i32((u32(m.r) >> 29u) | (u32(m.r_hi) << 3u)); + + let u_lo_l: i32 = (u_lo << 17u) >> 17u; + let u_lo_h: i32 = (u_lo - u_lo_l) >> 15u; + let v_lo_l: i32 = (v_lo << 17u) >> 17u; + let v_lo_h: i32 = (v_lo - v_lo_l) >> 15u; + let q_lo_l: i32 = (q_lo << 17u) >> 17u; + let q_lo_h: i32 = (q_lo - q_lo_l) >> 15u; + let r_lo_l: i32 = (r_lo << 17u) >> 17u; + let r_lo_h: i32 = (r_lo - r_lo_l) >> 15u; + let u_hi_l: i32 = (u_hi << 17u) >> 17u; + let u_hi_h: i32 = (u_hi - u_hi_l) >> 15u; + let v_hi_l: i32 = (v_hi << 17u) >> 17u; + let v_hi_h: i32 = (v_hi - v_hi_l) >> 15u; + let q_hi_l: i32 = (q_hi << 17u) >> 17u; + let q_hi_h: i32 = (q_hi - q_hi_l) >> 15u; + let r_hi_l: i32 = (r_hi << 17u) >> 17u; + let r_hi_h: i32 = (r_hi - r_hi_l) >> 15u; + + let d0: i32 = (*d).l[0]; + let e0: i32 = (*e).l[0]; + let d1: i32 = (*d).l[1]; + let e1: i32 = (*e).l[1]; + + let d0_l: i32 = (d0 << 17u) >> 17u; + let d0_h: i32 = (d0 - d0_l) >> 15u; + let e0_l: i32 = (e0 << 17u) >> 17u; + let e0_h: i32 = (e0 - e0_l) >> 15u; + let d1_l: i32 = (d1 << 17u) >> 17u; + let d1_h: i32 = (d1 - d1_l) >> 15u; + let e1_l: i32 = (e1 << 17u) >> 17u; + let e1_h: i32 = (e1 - e1_l) >> 15u; + + // nd0 = u_lo * d0 + v_lo * e0 (2 products) — inline the 15+14 schoolbook. + let nd0_pll: i32 = u_lo_l * d0_l + v_lo_l * e0_l; + let nd0_mid: i32 = + u_lo_l * d0_h + u_lo_h * d0_l + + v_lo_l * e0_h + v_lo_h * e0_l; + let nd0_phh: i32 = u_lo_h * d0_h + v_lo_h * e0_h; + let ne0_pll: i32 = q_lo_l * d0_l + r_lo_l * e0_l; + let ne0_mid: i32 = + q_lo_l * d0_h + q_lo_h * d0_l + + r_lo_l * e0_h + r_lo_h * e0_l; + let ne0_phh: i32 = q_lo_h * d0_h + r_lo_h * e0_h; + + // nd1 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0 (4 products). + let nd1_pll: i32 = u_lo_l * d1_l + v_lo_l * e1_l + u_hi_l * d0_l + v_hi_l * e0_l; + let nd1_mid: i32 = + u_lo_l * d1_h + u_lo_h * d1_l + + v_lo_l * e1_h + v_lo_h * e1_l + + u_hi_l * d0_h + u_hi_h * d0_l + + v_hi_l * e0_h + v_hi_h * e0_l; + let nd1_phh: i32 = u_lo_h * d1_h + v_lo_h * e1_h + u_hi_h * d0_h + v_hi_h * e0_h; + let ne1_pll: i32 = q_lo_l * d1_l + r_lo_l * e1_l + q_hi_l * d0_l + r_hi_l * e0_l; + let ne1_mid: i32 = + q_lo_l * d1_h + q_lo_h * d1_l + + r_lo_l * e1_h + r_lo_h * e1_l + + q_hi_l * d0_h + q_hi_h * d0_l + + r_hi_l * e0_h + r_hi_h * e0_l; + let ne1_phh: i32 = q_lo_h * d1_h + r_lo_h * e1_h + q_hi_h * d0_h + r_hi_h * e0_h; + + // Helper-equivalent extraction: convert (pll, mid, phh) → i64 (lo, hi). + // Inlined to avoid function-call overhead. 
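+  // Evaluates value = pll + (mid << 15) + (phh << 30): each i32 term is
+  // widened to an (lo: u32, hi: i32) pair, the low words are summed with
+  // explicit carry flags, and the carries are folded into the hi sum.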
+ let nd0_pll_u: u32 = u32(nd0_pll); + let nd0_mid_u: u32 = u32(nd0_mid); + let nd0_phh_u: u32 = u32(nd0_phh); + let nd0_pll_hi: i32 = nd0_pll >> 31u; + let nd0_mid_hi: i32 = nd0_mid >> 17u; + let nd0_phh_hi: i32 = nd0_phh >> 2u; + let nd0_s1: u32 = nd0_pll_u + (nd0_mid_u << 15u); + let nd0_s1c: i32 = select(0i, 1i, nd0_s1 < nd0_pll_u); + let nd0_lo: u32 = nd0_s1 + (nd0_phh_u << 30u); + let nd0_s2c: i32 = select(0i, 1i, nd0_lo < nd0_s1); + let nd0_hi: i32 = nd0_pll_hi + nd0_mid_hi + nd0_phh_hi + nd0_s1c + nd0_s2c; + + let ne0_pll_u: u32 = u32(ne0_pll); + let ne0_mid_u: u32 = u32(ne0_mid); + let ne0_phh_u: u32 = u32(ne0_phh); + let ne0_pll_hi: i32 = ne0_pll >> 31u; + let ne0_mid_hi: i32 = ne0_mid >> 17u; + let ne0_phh_hi: i32 = ne0_phh >> 2u; + let ne0_s1: u32 = ne0_pll_u + (ne0_mid_u << 15u); + let ne0_s1c: i32 = select(0i, 1i, ne0_s1 < ne0_pll_u); + let ne0_lo: u32 = ne0_s1 + (ne0_phh_u << 30u); + let ne0_s2c: i32 = select(0i, 1i, ne0_lo < ne0_s1); + let ne0_hi: i32 = ne0_pll_hi + ne0_mid_hi + ne0_phh_hi + ne0_s1c + ne0_s2c; + + let nd1_pll_u: u32 = u32(nd1_pll); + let nd1_mid_u: u32 = u32(nd1_mid); + let nd1_phh_u: u32 = u32(nd1_phh); + let nd1_pll_hi: i32 = nd1_pll >> 31u; + let nd1_mid_hi: i32 = nd1_mid >> 17u; + let nd1_phh_hi: i32 = nd1_phh >> 2u; + let nd1_s1: u32 = nd1_pll_u + (nd1_mid_u << 15u); + let nd1_s1c: i32 = select(0i, 1i, nd1_s1 < nd1_pll_u); + let nd1_lo: u32 = nd1_s1 + (nd1_phh_u << 30u); + let nd1_s2c: i32 = select(0i, 1i, nd1_lo < nd1_s1); + let nd1_hi: i32 = nd1_pll_hi + nd1_mid_hi + nd1_phh_hi + nd1_s1c + nd1_s2c; + + let ne1_pll_u: u32 = u32(ne1_pll); + let ne1_mid_u: u32 = u32(ne1_mid); + let ne1_phh_u: u32 = u32(ne1_phh); + let ne1_pll_hi: i32 = ne1_pll >> 31u; + let ne1_mid_hi: i32 = ne1_mid >> 17u; + let ne1_phh_hi: i32 = ne1_phh >> 2u; + let ne1_s1: u32 = ne1_pll_u + (ne1_mid_u << 15u); + let ne1_s1c: i32 = select(0i, 1i, ne1_s1 < ne1_pll_u); + let ne1_lo: u32 = ne1_s1 + (ne1_phh_u << 30u); + let ne1_s2c: i32 = select(0i, 1i, ne1_lo < ne1_s1); + let ne1_hi: i32 = ne1_pll_hi + ne1_mid_hi + ne1_phh_hi + ne1_s1c + ne1_s2c; + + // Reconstruct low 58 bits of nd and ne for k computation. + // td = (nd0_low29 + (nd1_plus_low29 << 29)) where nd1_plus = nd1 + (nd0 >> 29). + let nd0_low29: u32 = nd0_lo & BY_LIMB_MASK; + let ne0_low29: u32 = ne0_lo & BY_LIMB_MASK; + // nd0 >> 29 arithmetic shift, as i64 (lo, hi): + let nd0_ars_lo: u32 = (nd0_lo >> 29u) | (u32(nd0_hi) << 3u); + let nd0_ars_hi: i32 = nd0_hi >> 29u; + let ne0_ars_lo: u32 = (ne0_lo >> 29u) | (u32(ne0_hi) << 3u); + let ne0_ars_hi: i32 = ne0_hi >> 29u; + let nd1p_lo: u32 = nd1_lo + nd0_ars_lo; + let nd1p_c: i32 = select(0i, 1i, nd1p_lo < nd1_lo); + let nd1p_hi: i32 = nd1_hi + nd0_ars_hi + nd1p_c; + let ne1p_lo: u32 = ne1_lo + ne0_ars_lo; + let ne1p_c: i32 = select(0i, 1i, ne1p_lo < ne1_lo); + let ne1p_hi: i32 = ne1_hi + ne0_ars_hi + ne1p_c; + let nd1_low29: u32 = nd1p_lo & BY_LIMB_MASK; + let ne1_low29: u32 = ne1p_lo & BY_LIMB_MASK; + + let td: vec2 = vec2(nd0_low29 | (nd1_low29 << 29u), nd1_low29 >> 3u); + let te: vec2 = vec2(ne0_low29 | (ne1_low29 << 29u), ne1_low29 >> 3u); + + // k_d = ((-t_d) * p_inv) & MASK_BATCH. 
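+  // Montgomery-style cancellation: p_inv = p^(-1) mod 2^58, so
+  //   t + k*p ≡ t - t * p_inv * p ≡ 0 (mod 2^58)
+  // and the implicit division by 2^58 in the output indexing is exact.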
+ let neg_td: vec2 = u64_neg(td); + let neg_te: vec2 = u64_neg(te); + let p_inv: vec2 = vec2(p_inv_lo, p_inv_hi); + let kd_prod: vec2 = u64_mul_low64(neg_td, p_inv); + let ke_prod: vec2 = u64_mul_low64(neg_te, p_inv); + + let MASK_BATCH_HI: u32 = (1u << 26u) - 1u; + let kd_lo32: u32 = kd_prod.x; + let kd_hi26: u32 = kd_prod.y & MASK_BATCH_HI; + let ke_lo32: u32 = ke_prod.x; + let ke_hi26: u32 = ke_prod.y & MASK_BATCH_HI; + + let kd_lo: i32 = i32(kd_lo32 & BY_LIMB_MASK); + let kd_hi: i32 = i32((kd_lo32 >> 29u) | (kd_hi26 << 3u)); + let ke_lo: i32 = i32(ke_lo32 & BY_LIMB_MASK); + let ke_hi: i32 = i32((ke_lo32 >> 29u) | (ke_hi26 << 3u)); + + // Split k_*_lo, k_*_hi into 15+14 chunks for the inner loop. + let kd_lo_l: i32 = (kd_lo << 17u) >> 17u; + let kd_lo_h: i32 = (kd_lo - kd_lo_l) >> 15u; + let kd_hi_l: i32 = (kd_hi << 17u) >> 17u; + let kd_hi_h: i32 = (kd_hi - kd_hi_l) >> 15u; + let ke_lo_l: i32 = (ke_lo << 17u) >> 17u; + let ke_lo_h: i32 = (ke_lo - ke_lo_l) >> 15u; + let ke_hi_l: i32 = (ke_hi << 17u) >> 17u; + let ke_hi_h: i32 = (ke_hi - ke_hi_l) >> 15u; + + // Initial seed: nd0_plus = nd0 + kd_lo*p[0], cd_acc = nd1 + kd_lo*p[1] + kd_hi*p[0] + (nd0_plus >> 29). + // p[0] and p[1] are small (the BN254 modulus); we still split for correctness. + let p0: i32 = (*p).l[0]; + let p1: i32 = (*p).l[1]; + let p0_l: i32 = (p0 << 17u) >> 17u; + let p0_h: i32 = (p0 - p0_l) >> 15u; + let p1_l: i32 = (p1 << 17u) >> 17u; + let p1_h: i32 = (p1 - p1_l) >> 15u; + + // nd0_plus = nd0 + kd_lo*p[0] + let np0_pll: i32 = kd_lo_l * p0_l; + let np0_mid: i32 = kd_lo_l * p0_h + kd_lo_h * p0_l; + let np0_phh: i32 = kd_lo_h * p0_h; + let np0_pll_u: u32 = u32(np0_pll); + let np0_mid_u: u32 = u32(np0_mid); + let np0_phh_u: u32 = u32(np0_phh); + let np0_pll_hi: i32 = np0_pll >> 31u; + let np0_mid_hi: i32 = np0_mid >> 17u; + let np0_phh_hi: i32 = np0_phh >> 2u; + let np0_s1: u32 = np0_pll_u + (np0_mid_u << 15u); + let np0_s1c: i32 = select(0i, 1i, np0_s1 < np0_pll_u); + let np0_lo: u32 = np0_s1 + (np0_phh_u << 30u); + let np0_s2c: i32 = select(0i, 1i, np0_lo < np0_s1); + let np0_hi: i32 = np0_pll_hi + np0_mid_hi + np0_phh_hi + np0_s1c + np0_s2c; + // nd0_plus = nd0 + np0 + let nd0p_lo: u32 = nd0_lo + np0_lo; + let nd0p_c: i32 = select(0i, 1i, nd0p_lo < nd0_lo); + let nd0p_hi: i32 = nd0_hi + np0_hi + nd0p_c; + // (nd0_plus >> 29) signed arithmetic + let nd0p_ars_lo: u32 = (nd0p_lo >> 29u) | (u32(nd0p_hi) << 3u); + let nd0p_ars_hi: i32 = nd0p_hi >> 29u; + + let ne0p_pll: i32 = ke_lo_l * p0_l; + let ne0p_mid: i32 = ke_lo_l * p0_h + ke_lo_h * p0_l; + let ne0p_phh: i32 = ke_lo_h * p0_h; + let ne0p_pll_u: u32 = u32(ne0p_pll); + let ne0p_mid_u: u32 = u32(ne0p_mid); + let ne0p_phh_u: u32 = u32(ne0p_phh); + let ne0p_pll_hi: i32 = ne0p_pll >> 31u; + let ne0p_mid_hi: i32 = ne0p_mid >> 17u; + let ne0p_phh_hi: i32 = ne0p_phh >> 2u; + let ne0p_s1: u32 = ne0p_pll_u + (ne0p_mid_u << 15u); + let ne0p_s1c: i32 = select(0i, 1i, ne0p_s1 < ne0p_pll_u); + let ne0p_lo: u32 = ne0p_s1 + (ne0p_phh_u << 30u); + let ne0p_s2c: i32 = select(0i, 1i, ne0p_lo < ne0p_s1); + let ne0p_hi: i32 = ne0p_pll_hi + ne0p_mid_hi + ne0p_phh_hi + ne0p_s1c + ne0p_s2c; + let ne0pa_lo: u32 = ne0_lo + ne0p_lo; + let ne0pa_c: i32 = select(0i, 1i, ne0pa_lo < ne0_lo); + let ne0pa_hi: i32 = ne0_hi + ne0p_hi + ne0pa_c; + let ne0pa_ars_lo: u32 = (ne0pa_lo >> 29u) | (u32(ne0pa_hi) << 3u); + let ne0pa_ars_hi: i32 = ne0pa_hi >> 29u; + + // cd_acc = nd1 + kd_lo*p[1] + kd_hi*p[0] + (nd0_plus >> 29) + let cda_pll: i32 = kd_lo_l * p1_l + kd_hi_l * p0_l; + let cda_mid: i32 = + 
kd_lo_l * p1_h + kd_lo_h * p1_l + + kd_hi_l * p0_h + kd_hi_h * p0_l; + let cda_phh: i32 = kd_lo_h * p1_h + kd_hi_h * p0_h; + let cda_pll_u: u32 = u32(cda_pll); + let cda_mid_u: u32 = u32(cda_mid); + let cda_phh_u: u32 = u32(cda_phh); + let cda_pll_hi: i32 = cda_pll >> 31u; + let cda_mid_hi: i32 = cda_mid >> 17u; + let cda_phh_hi: i32 = cda_phh >> 2u; + let cda_s1: u32 = cda_pll_u + (cda_mid_u << 15u); + let cda_s1c: i32 = select(0i, 1i, cda_s1 < cda_pll_u); + let cda_p_lo: u32 = cda_s1 + (cda_phh_u << 30u); + let cda_s2c: i32 = select(0i, 1i, cda_p_lo < cda_s1); + let cda_p_hi: i32 = cda_pll_hi + cda_mid_hi + cda_phh_hi + cda_s1c + cda_s2c; + // cda = nd1 + cda_p + nd0p_ars + let cda_a_lo: u32 = nd1_lo + cda_p_lo; + let cda_a_c: i32 = select(0i, 1i, cda_a_lo < nd1_lo); + let cda_a_hi: i32 = nd1_hi + cda_p_hi + cda_a_c; + let cda_b_lo: u32 = cda_a_lo + nd0p_ars_lo; + let cda_b_c: i32 = select(0i, 1i, cda_b_lo < cda_a_lo); + let cda_b_hi: i32 = cda_a_hi + nd0p_ars_hi + cda_b_c; + // cd = cda >> 29 (signed arithmetic) + var cd_lo: u32 = (cda_b_lo >> 29u) | (u32(cda_b_hi) << 3u); + var cd_hi: i32 = cda_b_hi >> 29u; + + let cea_pll: i32 = ke_lo_l * p1_l + ke_hi_l * p0_l; + let cea_mid: i32 = + ke_lo_l * p1_h + ke_lo_h * p1_l + + ke_hi_l * p0_h + ke_hi_h * p0_l; + let cea_phh: i32 = ke_lo_h * p1_h + ke_hi_h * p0_h; + let cea_pll_u: u32 = u32(cea_pll); + let cea_mid_u: u32 = u32(cea_mid); + let cea_phh_u: u32 = u32(cea_phh); + let cea_pll_hi: i32 = cea_pll >> 31u; + let cea_mid_hi: i32 = cea_mid >> 17u; + let cea_phh_hi: i32 = cea_phh >> 2u; + let cea_s1: u32 = cea_pll_u + (cea_mid_u << 15u); + let cea_s1c: i32 = select(0i, 1i, cea_s1 < cea_pll_u); + let cea_p_lo: u32 = cea_s1 + (cea_phh_u << 30u); + let cea_s2c: i32 = select(0i, 1i, cea_p_lo < cea_s1); + let cea_p_hi: i32 = cea_pll_hi + cea_mid_hi + cea_phh_hi + cea_s1c + cea_s2c; + let cea_a_lo: u32 = ne1_lo + cea_p_lo; + let cea_a_c: i32 = select(0i, 1i, cea_a_lo < ne1_lo); + let cea_a_hi: i32 = ne1_hi + cea_p_hi + cea_a_c; + let cea_b_lo: u32 = cea_a_lo + ne0pa_ars_lo; + let cea_b_c: i32 = select(0i, 1i, cea_b_lo < cea_a_lo); + let cea_b_hi: i32 = cea_a_hi + ne0pa_ars_hi + cea_b_c; + var ce_lo: u32 = (cea_b_lo >> 29u) | (u32(cea_b_hi) << 3u); + var ce_hi: i32 = cea_b_hi >> 29u; + + // Slide-forward previous-limb splits for the inner loop. `pc_l`/`pc_h` + // hold p[i-1] entering iter i; after the body we set pc = p[i]. + var dp_l: i32 = d1_l; + var dp_h: i32 = d1_h; + var ep_l: i32 = e1_l; + var ep_h: i32 = e1_h; + var pc_l: i32 = p1_l; + var pc_h: i32 = p1_h; + + for (var i: u32 = 2u; i < BY_NUM_LIMBS; i = i + 1u) { + let di: i32 = (*d).l[i]; + let ei: i32 = (*e).l[i]; + let pi: i32 = (*p).l[i]; + let di_l: i32 = (di << 17u) >> 17u; + let di_h: i32 = (di - di_l) >> 15u; + let ei_l: i32 = (ei << 17u) >> 17u; + let ei_h: i32 = (ei - ei_l) >> 15u; + let pi_l: i32 = (pi << 17u) >> 17u; + let pi_h: i32 = (pi - pi_l) >> 15u; + + // nd = u_lo*di + v_lo*ei + u_hi*dp + v_hi*ep + kd_lo*p[i] + kd_hi*p[i-1] + cd + // 6 products. Bound check: each pll/phh < 2^28, sum < 6*2^28 < 2^31 ✓ + // each plh+phl < 2*2^28 = 2^29, sum < 6*2^29 < 2^32 ✗ overflow ! + // The "mid" lane needs care. Split: sum 6 lh-products and 6 hl-products SEPARATELY, + // each < 6*2^28 < 2^31. Then combine in i64 via two adds. + let nd_pll: i32 = + u_lo_l * di_l + v_lo_l * ei_l + + u_hi_l * dp_l + v_hi_l * ep_l + + kd_lo_l * pi_l + kd_hi_l * pc_l; + // Two mid sub-lanes: low_high_products + high_low_products. 
+ let nd_mid_lh: i32 = + u_lo_l * di_h + v_lo_l * ei_h + + u_hi_l * dp_h + v_hi_l * ep_h + + kd_lo_l * pi_h + kd_hi_l * pc_h; + let nd_mid_hl: i32 = + u_lo_h * di_l + v_lo_h * ei_l + + u_hi_h * dp_l + v_hi_h * ep_l + + kd_lo_h * pi_l + kd_hi_h * pc_l; + let nd_phh: i32 = + u_lo_h * di_h + v_lo_h * ei_h + + u_hi_h * dp_h + v_hi_h * ep_h + + kd_lo_h * pi_h + kd_hi_h * pc_h; + + let ne_pll: i32 = + q_lo_l * di_l + r_lo_l * ei_l + + q_hi_l * dp_l + r_hi_l * ep_l + + ke_lo_l * pi_l + ke_hi_l * pc_l; + let ne_mid_lh: i32 = + q_lo_l * di_h + r_lo_l * ei_h + + q_hi_l * dp_h + r_hi_l * ep_h + + ke_lo_l * pi_h + ke_hi_l * pc_h; + let ne_mid_hl: i32 = + q_lo_h * di_l + r_lo_h * ei_l + + q_hi_h * dp_l + r_hi_h * ep_l + + ke_lo_h * pi_l + ke_hi_h * pc_l; + let ne_phh: i32 = + q_lo_h * di_h + r_lo_h * ei_h + + q_hi_h * dp_h + r_hi_h * ep_h + + ke_lo_h * pi_h + ke_hi_h * pc_h; + + // Combine nd_pll + (nd_mid_lh + nd_mid_hl) << 15 + nd_phh << 30 + cd into i64. + // First fold nd_mid_lh + nd_mid_hl as i64 (mid lane could overflow i32 if combined). + // Each is < 2^31; sum needs 33 bits. + let nd_pll_u: u32 = u32(nd_pll); + let nd_mlh_u: u32 = u32(nd_mid_lh); + let nd_mhl_u: u32 = u32(nd_mid_hl); + let nd_phh_u: u32 = u32(nd_phh); + let nd_pll_hi: i32 = nd_pll >> 31u; + let nd_mlh_hi: i32 = nd_mid_lh >> 17u; + let nd_mhl_hi: i32 = nd_mid_hl >> 17u; + let nd_phh_hi: i32 = nd_phh >> 2u; + + // s = pll + mlh<<15 + mhl<<15 + phh<<30 + cd + let nd_a_lo: u32 = nd_pll_u + (nd_mlh_u << 15u); + let nd_a_c: i32 = select(0i, 1i, nd_a_lo < nd_pll_u); + let nd_b_lo: u32 = nd_a_lo + (nd_mhl_u << 15u); + let nd_b_c: i32 = select(0i, 1i, nd_b_lo < nd_a_lo); + let nd_c_lo: u32 = nd_b_lo + (nd_phh_u << 30u); + let nd_c_c: i32 = select(0i, 1i, nd_c_lo < nd_b_lo); + let nd_d_lo: u32 = nd_c_lo + cd_lo; + let nd_d_c: i32 = select(0i, 1i, nd_d_lo < nd_c_lo); + let nd_d_hi: i32 = nd_pll_hi + nd_mlh_hi + nd_mhl_hi + nd_phh_hi + nd_a_c + nd_b_c + nd_c_c + nd_d_c + cd_hi; + + let ne_pll_u: u32 = u32(ne_pll); + let ne_mlh_u: u32 = u32(ne_mid_lh); + let ne_mhl_u: u32 = u32(ne_mid_hl); + let ne_phh_u: u32 = u32(ne_phh); + let ne_pll_hi: i32 = ne_pll >> 31u; + let ne_mlh_hi: i32 = ne_mid_lh >> 17u; + let ne_mhl_hi: i32 = ne_mid_hl >> 17u; + let ne_phh_hi: i32 = ne_phh >> 2u; + + let ne_a_lo: u32 = ne_pll_u + (ne_mlh_u << 15u); + let ne_a_c: i32 = select(0i, 1i, ne_a_lo < ne_pll_u); + let ne_b_lo: u32 = ne_a_lo + (ne_mhl_u << 15u); + let ne_b_c: i32 = select(0i, 1i, ne_b_lo < ne_a_lo); + let ne_c_lo: u32 = ne_b_lo + (ne_phh_u << 30u); + let ne_c_c: i32 = select(0i, 1i, ne_c_lo < ne_b_lo); + let ne_d_lo: u32 = ne_c_lo + ce_lo; + let ne_d_c: i32 = select(0i, 1i, ne_d_lo < ne_c_lo); + let ne_d_hi: i32 = ne_pll_hi + ne_mlh_hi + ne_mhl_hi + ne_phh_hi + ne_a_c + ne_b_c + ne_c_c + ne_d_c + ce_hi; + + (*d).l[i - 2u] = i32(nd_d_lo & BY_LIMB_MASK); + (*e).l[i - 2u] = i32(ne_d_lo & BY_LIMB_MASK); + cd_lo = (nd_d_lo >> 29u) | (u32(nd_d_hi) << 3u); + cd_hi = nd_d_hi >> 29u; + ce_lo = (ne_d_lo >> 29u) | (u32(ne_d_hi) << 3u); + ce_hi = ne_d_hi >> 29u; + + // Slide previous-limb splits. 
+ dp_l = di_l; dp_h = di_h; + ep_l = ei_l; ep_h = ei_h; + pc_l = pi_l; pc_h = pi_h; + } + + // Top-limb finalisation: + // nd9 = u_hi * dp + v_hi * ep + kd_hi * p[N-1] + cd (3 products) + // ne9 = q_hi * dp + r_hi * ep + ke_hi * p[N-1] + ce + let p_top: i32 = (*p).l[BY_NUM_LIMBS - 1u]; + let pt_l: i32 = (p_top << 17u) >> 17u; + let pt_h: i32 = (p_top - pt_l) >> 15u; + + let nd9_pll: i32 = u_hi_l * dp_l + v_hi_l * ep_l + kd_hi_l * pt_l; + let nd9_mid: i32 = + u_hi_l * dp_h + u_hi_h * dp_l + + v_hi_l * ep_h + v_hi_h * ep_l + + kd_hi_l * pt_h + kd_hi_h * pt_l; + let nd9_phh: i32 = u_hi_h * dp_h + v_hi_h * ep_h + kd_hi_h * pt_h; + + let ne9_pll: i32 = q_hi_l * dp_l + r_hi_l * ep_l + ke_hi_l * pt_l; + let ne9_mid: i32 = + q_hi_l * dp_h + q_hi_h * dp_l + + r_hi_l * ep_h + r_hi_h * ep_l + + ke_hi_l * pt_h + ke_hi_h * pt_l; + let ne9_phh: i32 = q_hi_h * dp_h + r_hi_h * ep_h + ke_hi_h * pt_h; + + // For 3 products, mid sum ≤ 3 * 2 * 2^28 = 3 * 2^29 = 1.5 * 2^30 < 2^31 ✓ + let nd9_pll_u: u32 = u32(nd9_pll); + let nd9_mid_u: u32 = u32(nd9_mid); + let nd9_phh_u: u32 = u32(nd9_phh); + let nd9_pll_hi: i32 = nd9_pll >> 31u; + let nd9_mid_hi: i32 = nd9_mid >> 17u; + let nd9_phh_hi: i32 = nd9_phh >> 2u; + let nd9_s1: u32 = nd9_pll_u + (nd9_mid_u << 15u); + let nd9_s1c: i32 = select(0i, 1i, nd9_s1 < nd9_pll_u); + let nd9_s2: u32 = nd9_s1 + (nd9_phh_u << 30u); + let nd9_s2c: i32 = select(0i, 1i, nd9_s2 < nd9_s1); + let nd9_total_lo: u32 = nd9_s2 + cd_lo; + let nd9_total_c: i32 = select(0i, 1i, nd9_total_lo < nd9_s2); + let nd9_total_hi: i32 = nd9_pll_hi + nd9_mid_hi + nd9_phh_hi + nd9_s1c + nd9_s2c + nd9_total_c + cd_hi; + + let ne9_pll_u: u32 = u32(ne9_pll); + let ne9_mid_u: u32 = u32(ne9_mid); + let ne9_phh_u: u32 = u32(ne9_phh); + let ne9_pll_hi: i32 = ne9_pll >> 31u; + let ne9_mid_hi: i32 = ne9_mid >> 17u; + let ne9_phh_hi: i32 = ne9_phh >> 2u; + let ne9_s1: u32 = ne9_pll_u + (ne9_mid_u << 15u); + let ne9_s1c: i32 = select(0i, 1i, ne9_s1 < ne9_pll_u); + let ne9_s2: u32 = ne9_s1 + (ne9_phh_u << 30u); + let ne9_s2c: i32 = select(0i, 1i, ne9_s2 < ne9_s1); + let ne9_total_lo: u32 = ne9_s2 + ce_lo; + let ne9_total_c: i32 = select(0i, 1i, ne9_total_lo < ne9_s2); + let ne9_total_hi: i32 = ne9_pll_hi + ne9_mid_hi + ne9_phh_hi + ne9_s1c + ne9_s2c + ne9_total_c + ce_hi; + + (*d).l[BY_NUM_LIMBS - 2u] = i32(nd9_total_lo & BY_LIMB_MASK); + (*e).l[BY_NUM_LIMBS - 2u] = i32(ne9_total_lo & BY_LIMB_MASK); + (*d).l[BY_NUM_LIMBS - 1u] = i32((nd9_total_lo >> 29u) | (u32(nd9_total_hi) << 3u)); + (*e).l[BY_NUM_LIMBS - 1u] = i32((ne9_total_lo >> 29u) | (u32(ne9_total_hi) << 3u)); + // by_normalise no-op: lower limbs already masked. +} + +// ============================================================ +// fr_inv_by driver helpers +// ============================================================ + +// by_is_zero: returns true iff every limb of x is zero. +// Pre: any state (need not be canonical). Post: bool. +fn by_is_zero(x: ptr) -> bool { + var a: i32 = 0; + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + a = a | (*x).l[i]; + } + return a == 0; +} + +// by_is_negative: top-limb sign check on a normalised BigIntBY. +// Pre: x normalised so the top limb carries the sign. Post: bool. +fn by_is_negative(x: ptr) -> bool { + return (*x).l[BY_NUM_LIMBS - 1u] < 0; +} + +// by_neg_inplace: negate x then re-normalise so lower limbs are in +// [0, 2^29) again. Mirrors `neg(x)` in bernstein_yang.ts. 
+fn by_neg_inplace(x: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = -(*x).l[i]; + } + by_normalise(x); +} + +// by_add_p_inplace: x <- x + p (limbwise) then normalise. +fn by_add_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = (*x).l[i] + (*p).l[i]; + } + by_normalise(x); +} + +// by_sub_p_inplace: x <- x - p (limbwise) then normalise. +fn by_sub_p_inplace(x: ptr, p: ptr) { + for (var i: u32 = 0u; i < BY_NUM_LIMBS; i = i + 1u) { + (*x).l[i] = (*x).l[i] - (*p).l[i]; + } + by_normalise(x); +} + +// by_gte_p: true iff x >= p, assuming both x and p are non-negative +// canonical-limb 9-limb BigIntBY values (lower limbs in [0, 2^29), top limb +// non-negative). Walks limbs from high to low. +fn by_gte_p(x: ptr, p: ptr) -> bool { + var gt: bool = false; + var lt: bool = false; + for (var ii: u32 = 0u; ii < BY_NUM_LIMBS; ii = ii + 1u) { + let i: u32 = BY_NUM_LIMBS - 1u - ii; + let a: i32 = (*x).l[i]; + let b: i32 = (*p).l[i]; + let still_undecided: bool = !(gt || lt); + if (still_undecided) { + if (a > b) { gt = true; } + else if (a < b) { lt = true; } + } + } + return gt || !lt; +} + +// by_reduce_to_canonical: bring x into [0, p) using at most BY_RTC_MAX_ITERS +// (= 36) add-p / sub-p passes. Mirrors `reduceToCanonical` in +// bernstein_yang.ts exactly: if x is negative, add p; else if x >= p, +// subtract p; else break. The 36-iter bound suffices for |x| <= 32 p under +// REDUCE_INTERVAL = 4 (see Wasm9x29 docs). +// +// LOOP BOUND: `for (var it: u32 = 0u; it < BY_RTC_MAX_ITERS; ...)` — const. +// +// Pre: x is a possibly-non-canonical signed BigIntBY (post-by_normalise: +// lower limbs in [0, 2^29), top limb carries sign). p is the modulus +// in BigIntBY form (positive, canonical). +// Post: x in [0, p), canonical. +fn by_reduce_to_canonical(x: ptr, p: ptr) { + by_normalise(x); + var done: bool = false; + for (var it: u32 = 0u; it < BY_RTC_MAX_ITERS; it = it + 1u) { + if (done) { continue; } + if (by_is_negative(x)) { + by_add_p_inplace(x, p); + } else if (by_gte_p(x, p)) { + by_sub_p_inplace(x, p); + } else { + done = true; + } + } +} + +// 58-bit p_inv split as low 32 / high (<=26) bits. Mustache-injected by +// `gen_fr_inv_bench_shader` (and the production wiring in step 1.7) from +// `compute_by_p_inv_split` in cuzk/utils.ts. Matching pair to the +// `P_INV_BY_LO` / `P_INV_BY_HI` constants the apply_matrix bench uses. +const FR_INV_BY_P_INV_LO: u32 = {{ p_inv_by_lo }}u; +const FR_INV_BY_P_INV_HI: u32 = {{ p_inv_by_hi }}u; + +// fr_inv_by: Bernstein-Yang safegcd inverse driver, mirroring +// `invert_bernsteinyang19` (bernstein_yang_inverse.hpp lines 290-326) +// and the TS reference `Wasm9x29.invert` (bernstein_yang.ts:409-440). +// +// Input `a` is in Montgomery form: bigint_value(a) = A * R mod p. +// Output is in Montgomery form: bigint_value(output) = A^(-1) * R mod p. +// +// Algorithm sketch: +// 1. Convert (a, p) to 9 x 29-bit BigIntBY representation. +// 2. Run NUM_OUTER = 13 outer iterations. Each outer iter: +// a. Compute low-64-bit views (f_lo, g_lo) of (f, g). +// b. by_divsteps(&delta, f_lo, g_lo) -> Mat (BATCH = 58 inner). +// c. by_apply_matrix_fg(M, &f, &g) — folds 58 divsteps into f, g. +// d. by_apply_matrix_de(M, &d, &e, &p, p_inv) — same on (d, e). +// e. Every BY_REDUCE_INTERVAL = 4 iters, reduce_to_canonical(d, e). +// f. Early break on `by_is_zero(g)` — the const NUM_OUTER bound is +// still respected by the WGSL emitter via a guard flag, not by +// shrinking the loop count. +// 3. 
After the loop, reduce_to_canonical(d) and, if f is negative, +// negate d mod p (mirrors the C++ `sign(f) * d` step). +// 4. The BY output is `inv_native = (A * R)^(-1) mod p = A^(-1) * R^(-1)` +// in canonical [0, p). Apply the standard Mont correction via +// `montgomery_product(inv_native, R^3)` = +// inv_native * R^3 * R^(-1) = inv_native * R^2 = A^(-1) * R, in +// Montgomery form. Pattern matches `fr_inv` in fr_pow.template.wgsl. +// 5. Convert back to 20 x 13-bit BigInt and return. +// +// LOOP BOUND DISCIPLINE: +// - outer loop: `for (... iter < BY_NUM_OUTER; ...)` (const 13). +// - by_divsteps: `for (... i < BY_BATCH; ...)` (const 58). +// - by_apply_matrix_*: `for (... i < BY_NUM_LIMBS; ...)` (const 9). +// - by_reduce_to_canonical: `for (... it < BY_RTC_MAX_ITERS; ...)`(const 36). +// - by_normalise / by_neg: `for (... i < BY_NUM_LIMBS; ...)` (const 9). +// - by_from_bigint / by_to_bigint loops bounded by const BY_NUM_LIMBS +// and Mustache-const `{{ num_words }}`. +// No data-dependent loop bounds anywhere on the inversion path. +fn fr_inv_by(a: BigInt) -> BigInt { + // Modulus p in BigIntBY form. Use the same Mustache-injected initializer + // as the apply_matrix bench; this gates fr_inv_by's behaviour on the + // ShaderManager-supplied p_limbs_by, matching the rest of the BY surface. + var p_by: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + var f: BigIntBY = BigIntBY(array({{{ p_limbs_by }}})); + var g: BigIntBY = by_from_bigint(a); + + var d: BigIntBY; + var e: BigIntBY; + for (var k: u32 = 0u; k < BY_NUM_LIMBS; k = k + 1u) { + d.l[k] = 0; + e.l[k] = 0; + } + e.l[0] = 1; + + var delta: i32 = 1; + var done: bool = false; + for (var iter: u32 = 0u; iter < BY_NUM_OUTER; iter = iter + 1u) { + if (done) { continue; } + // low_64 view of f and g for divsteps. Inlined by_low_u64_lohi. + let f_l0: u32 = u32(f.l[0]) & BY_LIMB_MASK; + let f_l1: u32 = u32(f.l[1]) & BY_LIMB_MASK; + let f_l2: u32 = u32(f.l[2]) & 0x3Fu; + let f_lo: vec2 = vec2(f_l0 | ((f_l1 & 0x7u) << 29u), (f_l1 >> 3u) | (f_l2 << 26u)); + let g_l0: u32 = u32(g.l[0]) & BY_LIMB_MASK; + let g_l1: u32 = u32(g.l[1]) & BY_LIMB_MASK; + let g_l2: u32 = u32(g.l[2]) & 0x3Fu; + let g_lo: vec2 = vec2(g_l0 | ((g_l1 & 0x7u) << 29u), (g_l1 >> 3u) | (g_l2 << 26u)); + let m: Mat = by_divsteps(&delta, f_lo, g_lo); + by_apply_matrix_fg(m, &f, &g); + by_apply_matrix_de(m, &d, &e, &p_by, FR_INV_BY_P_INV_LO, FR_INV_BY_P_INV_HI); + if (((iter + 1u) % BY_REDUCE_INTERVAL) == 0u) { + by_reduce_to_canonical(&d, &p_by); + by_reduce_to_canonical(&e, &p_by); + } + if (by_is_zero(&g)) { + done = true; + } + } + + by_reduce_to_canonical(&d, &p_by); + if (by_is_negative(&f)) { + by_neg_inplace(&d); + by_reduce_to_canonical(&d, &p_by); + } + + // inv_native = A^(-1) * R^(-1) mod p (canonical [0, p)). Mont correction + // via `montgomery_product(inv_native, R^3)` lands at A^(-1) * R, matching + // the pattern used by fr_inv in fr_pow.template.wgsl. + var inv_native: BigInt = by_to_bigint(d); + var r_cubed: BigInt = get_r_cubed(); + return montgomery_product(&inv_native, &r_cubed); +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse_a.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse_a.template.wgsl new file mode 100644 index 000000000000..1c8714859974 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse_a.template.wgsl @@ -0,0 +1,669 @@ +// Option A: Bernstein-Yang safegcd inverse on the 20 x 13-bit BigInt +// representation. Tight wide-multiply apply_matrix variant. 
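+// Same safegcd algorithm as by_inverse.template.wgsl, but it stays on the
+// native 20 x 13-bit limb form, so no 9 x 29-bit conversion is needed at
+// the entry/exit boundaries.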
+// +// LAYOUT +// - BigInt: 20 x 13-bit limbs (canonical input) or 20 limbs storing +// SIGNED i32 bitcast into u32 (non-canonical between iters, magnitude +// bounded by 2^15). +// - Matrix entries u, v, q, r: signed i32. After BATCH=26 inner divsteps +// |entry| <= 2^26, fits comfortably in i32. +// - Inner divsteps operate on the LOW 64 BITS of (f, g) carried as a +// vec2. We need >= BATCH bits to drive divstep decisions +// correctly; 64 gives us 38 bits of headroom for sign propagation. +// +// APPLY_MATRIX DESIGN +// - Per-output-limb raw accumulators are each ONE inline expression of +// four 13-bit muls + three adds. No common-subexpression pre-compute +// (each lo/hi*limb product is used in exactly ONE slot). The compiler +// issues them back-to-back, the GPU keeps registers tight. +// - Carry-propagation is TWO parallel passes (each reads only the prior +// pass's output, not its own in-progress writes). After two passes the +// limbs fit in [-2^14, 2^14] and we store them as u32 bitcast. +// - We do NOT canonicalize between outer iterations: limbs stay signed +// non-canonical up to 2^15 magnitude. Next iter's multiply tolerates +// this because (2^13_matrix * 2^15_limb) * 4_terms = 2^30 < 2^31. +// - We DO canonicalize d at the very end (before the Mont correction). +// +// LOOP BOUND DISCIPLINE +// - Outer driver: `for (var iter < BYA_NUM_OUTER)` (const = 29) +// - Inner divsteps: `for (var i < BYA_BATCH)` (const = 26) +// - Apply matrix: fully unrolled (no loops) +// - Reduce-to-canonical: `for (var it < BYA_RTC_MAX_ITERS)` (const = 4) +// +// CONVERGENCE +// Bernstein-Yang safegcd bound for 256-bit modulus: 735 divsteps. +// BATCH=26 -> NUM_OUTER = ceil(735/26) = 29. + +const BYA_BATCH: u32 = 26u; +const BYA_NUM_OUTER: u32 = 29u; +const BYA_REDUCE_INTERVAL: u32 = 4u; +const BYA_RTC_MAX_ITERS: u32 = 4u; +const BYA_MASK13: u32 = (1u << 13u) - 1u; +const BYA_MASK13_I32: i32 = (1 << 13) - 1; + +// 2x2 matrix entries after BATCH=26 divsteps. Each entry is an i32 with +// |.| <= 2^26. +struct MatA { + u: i32, + v: i32, + q: i32, + r: i32, +} + +// ============================================================ +// bya_divsteps: BATCH=26 branchy divsteps on the low 64 bits of (f, g). +// +// Matrix entries u, v, q, r grow by at most one shl + one sub per iter, +// so after BATCH=26 we have |entry| <= 2^26. +// ============================================================ +fn bya_divsteps(delta: ptr, f_lo_in: vec2, g_lo_in: vec2) -> MatA { + var f_lo: vec2 = f_lo_in; + var g_lo: vec2 = g_lo_in; + var u: i32 = 1; + var v: i32 = 0; + var q: i32 = 0; + var r: i32 = 1; + var d: i32 = *delta; + for (var i: u32 = 0u; i < BYA_BATCH; i = i + 1u) { + if (u64_low_bit(g_lo) != 0u) { + if (d > 0) { + let nf: vec2 = g_lo; + let diff: vec2 = u64_sub(g_lo, f_lo); + let ng: vec2 = u64_shr1(diff); + let nu: i32 = q << 1u; + let nv: i32 = r << 1u; + let nq: i32 = q - u; + let nr: i32 = r - v; + f_lo = nf; + g_lo = ng; + u = nu; + v = nv; + q = nq; + r = nr; + d = 1 - d; + } else { + let sum: vec2 = u64_add(g_lo, f_lo); + g_lo = u64_shr1(sum); + q = q + u; + r = r + v; + u = u << 1u; + v = v << 1u; + d = d + 1; + } + } else { + g_lo = u64_shr1(g_lo); + u = u << 1u; + v = v << 1u; + d = d + 1; + } + } + *delta = d; + return MatA(u, v, q, r); +} + +// ============================================================ +// bya_low_u64_lohi: low 64 bits of a 20 x 13-bit BigInt with canonical +// 13-bit limbs (the serial-carry apply_matrix output guarantees this). 
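+// Five 13-bit limbs cover 65 bits, so limbs 0..4 suffice: limb 4 contributes
+// only its low 12 bits (landing at bit 52) and its top bit falls off the end.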
+// ============================================================ +fn bya_low_u64_lohi(x: BigInt) -> vec2 { + let l0: u32 = x.limbs[0] & MASK; + let l1: u32 = x.limbs[1] & MASK; + let l2: u32 = x.limbs[2] & MASK; + let l3: u32 = x.limbs[3] & MASK; + let l4: u32 = x.limbs[4] & MASK; + let lo32: u32 = l0 | (l1 << 13u) | (l2 << 26u); + let hi32: u32 = (l2 >> 6u) | (l3 << 7u) | (l4 << 20u); + return vec2(lo32, hi32); +} + +// ============================================================ +// bya_normalise: carry-propagate so each limb in [0, N-1) is in +// [0, 2^13) canonical and the top limb absorbs the signed extension. +// Used by reduce_to_canonical at the END of fr_inv_by_a. +// ============================================================ +fn bya_normalise(x: ptr) { + var c: i32 = 0; + for (var i: u32 = 0u; i < {{ num_words }}u - 1u; i = i + 1u) { + let v = i32((*x).limbs[i]) + c; + (*x).limbs[i] = u32(v) & MASK; + c = v >> WORD_SIZE; + } + (*x).limbs[{{ num_words }}u - 1u] = u32(i32((*x).limbs[{{ num_words }}u - 1u]) + c) & MASK; +} + +// ============================================================ +// bya_apply_matrix_fg +// +// Compute (f_new, g_new) = ((u*f + v*g) >> 26, (q*f + r*g) >> 26). +// +// Matrix entry split: m = m_lo + m_hi * 2^13 where m_lo in [0, 2^13) +// (taken as low-13-bit unsigned) and m_hi in [-2^13, 2^13) (taken as +// arithmetic shift right of i32). The product is recovered as: +// m * x = m_lo * x + m_hi * x * 2^13 +// +// For each output position k in [0, 19], the raw value is +// nf[k] = u_lo*f[k+2] + v_lo*g[k+2] + u_hi*f[k+1] + v_hi*g[k+1] +// with the convention f[20] = f[21] = 0 (and same for g). The two +// dropped low product positions contribute a boundary carry into +// output 0 — see "boundary carry" comment below. +// +// Sign of f/g: limbs in [0, N-2] are non-negative (in [-2^15, 2^15) when +// non-canonical between iters); the top limb f[N-1] carries the signed +// extension of the full integer and is sign-extended via arithmetic +// shifts before multiplying. +// +// |nf[k]| <= 4 * (2^13 * 2^15) = 2^30 with non-canonical limbs, fits i32. +// ============================================================ +fn bya_apply_matrix_fg(m: MatA, f: ptr, g: ptr) { + // Matrix splits. _lo in [0, 2^13); _hi signed in [-2^13, 2^13). + let u_lo: i32 = i32(u32(m.u) & MASK); + let u_hi: i32 = m.u >> WORD_SIZE; + let v_lo: i32 = i32(u32(m.v) & MASK); + let v_hi: i32 = m.v >> WORD_SIZE; + let q_lo: i32 = i32(u32(m.q) & MASK); + let q_hi: i32 = m.q >> WORD_SIZE; + let r_lo: i32 = i32(u32(m.r) & MASK); + let r_hi: i32 = m.r >> WORD_SIZE; + + // Load all limbs into named locals to give the compiler a chance to + // hoist the loads above the multiply chain. 
+ let f0: i32 = i32((*f).limbs[0]); + let f1: i32 = i32((*f).limbs[1]); + let f2: i32 = i32((*f).limbs[2]); + let f3: i32 = i32((*f).limbs[3]); + let f4: i32 = i32((*f).limbs[4]); + let f5: i32 = i32((*f).limbs[5]); + let f6: i32 = i32((*f).limbs[6]); + let f7: i32 = i32((*f).limbs[7]); + let f8: i32 = i32((*f).limbs[8]); + let f9: i32 = i32((*f).limbs[9]); + let f10: i32 = i32((*f).limbs[10]); + let f11: i32 = i32((*f).limbs[11]); + let f12: i32 = i32((*f).limbs[12]); + let f13: i32 = i32((*f).limbs[13]); + let f14: i32 = i32((*f).limbs[14]); + let f15: i32 = i32((*f).limbs[15]); + let f16: i32 = i32((*f).limbs[16]); + let f17: i32 = i32((*f).limbs[17]); + let f18: i32 = i32((*f).limbs[18]); + // Sign-extension of the top limb (bit 12 is the sign bit for canonical + // input; for non-canonical input we still arithmetic-shift the full + // i32 — high bits already carry sign). + let f19_raw: u32 = (*f).limbs[19]; + let f19: i32 = (i32(f19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let g0: i32 = i32((*g).limbs[0]); + let g1: i32 = i32((*g).limbs[1]); + let g2: i32 = i32((*g).limbs[2]); + let g3: i32 = i32((*g).limbs[3]); + let g4: i32 = i32((*g).limbs[4]); + let g5: i32 = i32((*g).limbs[5]); + let g6: i32 = i32((*g).limbs[6]); + let g7: i32 = i32((*g).limbs[7]); + let g8: i32 = i32((*g).limbs[8]); + let g9: i32 = i32((*g).limbs[9]); + let g10: i32 = i32((*g).limbs[10]); + let g11: i32 = i32((*g).limbs[11]); + let g12: i32 = i32((*g).limbs[12]); + let g13: i32 = i32((*g).limbs[13]); + let g14: i32 = i32((*g).limbs[14]); + let g15: i32 = i32((*g).limbs[15]); + let g16: i32 = i32((*g).limbs[16]); + let g17: i32 = i32((*g).limbs[17]); + let g18: i32 = i32((*g).limbs[18]); + let g19_raw: u32 = (*g).limbs[19]; + let g19: i32 = (i32(g19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + // Boundary carry from the two dropped low product positions (positions + // 0 and 1). Carry-propagates as in a serial chain — the parallel-pass + // identity `(A >> 26) + (B >> 13)` is OFF BY 1 from + // `((B + (A >> 13)) >> 13)` in general because shifts don't distribute + // over addition. The boundary lands at output limb 0 below. + let rp0_f: i32 = u_lo * f0 + v_lo * g0; + let rp1_f: i32 = u_lo * f1 + v_lo * g1 + u_hi * f0 + v_hi * g0; + let boundary_f: i32 = (rp1_f + (rp0_f >> 13u)) >> 13u; + + let rp0_g: i32 = q_lo * f0 + r_lo * g0; + let rp1_g: i32 = q_lo * f1 + r_lo * g1 + q_hi * f0 + r_hi * g0; + let boundary_g: i32 = (rp1_g + (rp0_g >> 13u)) >> 13u; + + // MULTIPLY PHASE with shared partial products. Each individual product + // is used in EXACTLY ONE output slot — the names just help the GPU + // compiler pipeline issue muls + adds without re-reading the limb. 
+ let ulf2 = u_lo * f2; let ulf3 = u_lo * f3; let ulf4 = u_lo * f4; let ulf5 = u_lo * f5; + let ulf6 = u_lo * f6; let ulf7 = u_lo * f7; let ulf8 = u_lo * f8; let ulf9 = u_lo * f9; + let ulf10 = u_lo * f10; let ulf11 = u_lo * f11; let ulf12 = u_lo * f12; let ulf13 = u_lo * f13; + let ulf14 = u_lo * f14; let ulf15 = u_lo * f15; let ulf16 = u_lo * f16; let ulf17 = u_lo * f17; + let ulf18 = u_lo * f18; let ulf19 = u_lo * f19; + let uhf1 = u_hi * f1; let uhf2 = u_hi * f2; let uhf3 = u_hi * f3; let uhf4 = u_hi * f4; + let uhf5 = u_hi * f5; let uhf6 = u_hi * f6; let uhf7 = u_hi * f7; let uhf8 = u_hi * f8; + let uhf9 = u_hi * f9; let uhf10 = u_hi * f10; let uhf11 = u_hi * f11; let uhf12 = u_hi * f12; + let uhf13 = u_hi * f13; let uhf14 = u_hi * f14; let uhf15 = u_hi * f15; let uhf16 = u_hi * f16; + let uhf17 = u_hi * f17; let uhf18 = u_hi * f18; let uhf19 = u_hi * f19; + let vlg2 = v_lo * g2; let vlg3 = v_lo * g3; let vlg4 = v_lo * g4; let vlg5 = v_lo * g5; + let vlg6 = v_lo * g6; let vlg7 = v_lo * g7; let vlg8 = v_lo * g8; let vlg9 = v_lo * g9; + let vlg10 = v_lo * g10; let vlg11 = v_lo * g11; let vlg12 = v_lo * g12; let vlg13 = v_lo * g13; + let vlg14 = v_lo * g14; let vlg15 = v_lo * g15; let vlg16 = v_lo * g16; let vlg17 = v_lo * g17; + let vlg18 = v_lo * g18; let vlg19 = v_lo * g19; + let vhg1 = v_hi * g1; let vhg2 = v_hi * g2; let vhg3 = v_hi * g3; let vhg4 = v_hi * g4; + let vhg5 = v_hi * g5; let vhg6 = v_hi * g6; let vhg7 = v_hi * g7; let vhg8 = v_hi * g8; + let vhg9 = v_hi * g9; let vhg10 = v_hi * g10; let vhg11 = v_hi * g11; let vhg12 = v_hi * g12; + let vhg13 = v_hi * g13; let vhg14 = v_hi * g14; let vhg15 = v_hi * g15; let vhg16 = v_hi * g16; + let vhg17 = v_hi * g17; let vhg18 = v_hi * g18; let vhg19 = v_hi * g19; + + let qlf2 = q_lo * f2; let qlf3 = q_lo * f3; let qlf4 = q_lo * f4; let qlf5 = q_lo * f5; + let qlf6 = q_lo * f6; let qlf7 = q_lo * f7; let qlf8 = q_lo * f8; let qlf9 = q_lo * f9; + let qlf10 = q_lo * f10; let qlf11 = q_lo * f11; let qlf12 = q_lo * f12; let qlf13 = q_lo * f13; + let qlf14 = q_lo * f14; let qlf15 = q_lo * f15; let qlf16 = q_lo * f16; let qlf17 = q_lo * f17; + let qlf18 = q_lo * f18; let qlf19 = q_lo * f19; + let qhf1 = q_hi * f1; let qhf2 = q_hi * f2; let qhf3 = q_hi * f3; let qhf4 = q_hi * f4; + let qhf5 = q_hi * f5; let qhf6 = q_hi * f6; let qhf7 = q_hi * f7; let qhf8 = q_hi * f8; + let qhf9 = q_hi * f9; let qhf10 = q_hi * f10; let qhf11 = q_hi * f11; let qhf12 = q_hi * f12; + let qhf13 = q_hi * f13; let qhf14 = q_hi * f14; let qhf15 = q_hi * f15; let qhf16 = q_hi * f16; + let qhf17 = q_hi * f17; let qhf18 = q_hi * f18; let qhf19 = q_hi * f19; + let rlg2 = r_lo * g2; let rlg3 = r_lo * g3; let rlg4 = r_lo * g4; let rlg5 = r_lo * g5; + let rlg6 = r_lo * g6; let rlg7 = r_lo * g7; let rlg8 = r_lo * g8; let rlg9 = r_lo * g9; + let rlg10 = r_lo * g10; let rlg11 = r_lo * g11; let rlg12 = r_lo * g12; let rlg13 = r_lo * g13; + let rlg14 = r_lo * g14; let rlg15 = r_lo * g15; let rlg16 = r_lo * g16; let rlg17 = r_lo * g17; + let rlg18 = r_lo * g18; let rlg19 = r_lo * g19; + let rhg1 = r_hi * g1; let rhg2 = r_hi * g2; let rhg3 = r_hi * g3; let rhg4 = r_hi * g4; + let rhg5 = r_hi * g5; let rhg6 = r_hi * g6; let rhg7 = r_hi * g7; let rhg8 = r_hi * g8; + let rhg9 = r_hi * g9; let rhg10 = r_hi * g10; let rhg11 = r_hi * g11; let rhg12 = r_hi * g12; + let rhg13 = r_hi * g13; let rhg14 = r_hi * g14; let rhg15 = r_hi * g15; let rhg16 = r_hi * g16; + let rhg17 = r_hi * g17; let rhg18 = r_hi * g18; let rhg19 = r_hi * g19; + + let nf0: i32 = ulf2 + vlg2 + uhf1 + vhg1 + 
boundary_f; + let nf1: i32 = ulf3 + vlg3 + uhf2 + vhg2; + let nf2: i32 = ulf4 + vlg4 + uhf3 + vhg3; + let nf3: i32 = ulf5 + vlg5 + uhf4 + vhg4; + let nf4: i32 = ulf6 + vlg6 + uhf5 + vhg5; + let nf5: i32 = ulf7 + vlg7 + uhf6 + vhg6; + let nf6: i32 = ulf8 + vlg8 + uhf7 + vhg7; + let nf7: i32 = ulf9 + vlg9 + uhf8 + vhg8; + let nf8: i32 = ulf10 + vlg10 + uhf9 + vhg9; + let nf9: i32 = ulf11 + vlg11 + uhf10 + vhg10; + let nf10: i32 = ulf12 + vlg12 + uhf11 + vhg11; + let nf11: i32 = ulf13 + vlg13 + uhf12 + vhg12; + let nf12: i32 = ulf14 + vlg14 + uhf13 + vhg13; + let nf13: i32 = ulf15 + vlg15 + uhf14 + vhg14; + let nf14: i32 = ulf16 + vlg16 + uhf15 + vhg15; + let nf15: i32 = ulf17 + vlg17 + uhf16 + vhg16; + let nf16: i32 = ulf18 + vlg18 + uhf17 + vhg17; + let nf17: i32 = ulf19 + vlg19 + uhf18 + vhg18; + let nf18: i32 = uhf19 + vhg19; + + let ng0: i32 = qlf2 + rlg2 + qhf1 + rhg1 + boundary_g; + let ng1: i32 = qlf3 + rlg3 + qhf2 + rhg2; + let ng2: i32 = qlf4 + rlg4 + qhf3 + rhg3; + let ng3: i32 = qlf5 + rlg5 + qhf4 + rhg4; + let ng4: i32 = qlf6 + rlg6 + qhf5 + rhg5; + let ng5: i32 = qlf7 + rlg7 + qhf6 + rhg6; + let ng6: i32 = qlf8 + rlg8 + qhf7 + rhg7; + let ng7: i32 = qlf9 + rlg9 + qhf8 + rhg8; + let ng8: i32 = qlf10 + rlg10 + qhf9 + rhg9; + let ng9: i32 = qlf11 + rlg11 + qhf10 + rhg10; + let ng10: i32 = qlf12 + rlg12 + qhf11 + rhg11; + let ng11: i32 = qlf13 + rlg13 + qhf12 + rhg12; + let ng12: i32 = qlf14 + rlg14 + qhf13 + rhg13; + let ng13: i32 = qlf15 + rlg15 + qhf14 + rhg14; + let ng14: i32 = qlf16 + rlg16 + qhf15 + rhg15; + let ng15: i32 = qlf17 + rlg17 + qhf16 + rhg16; + let ng16: i32 = qlf18 + rlg18 + qhf17 + rhg17; + let ng17: i32 = qlf19 + rlg19 + qhf18 + rhg18; + let ng18: i32 = qhf19 + rhg19; + + // SERIAL CARRY PASS — empirically faster than 2-pass parallel on this + // GPU (the carry chain is short enough that scheduler latency dominates + // any pipelining advantage). 
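+  // Each step adds the incoming (possibly negative) carry, stores the low
+  // 13 bits, and propagates the arithmetic >> 13 remainder; the final carry
+  // becomes the signed top limb stored at index 19.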
+ var cf: i32 = 0; + let vf_0: i32 = nf0 + cf; (*f).limbs[0] = u32(vf_0) & MASK; cf = vf_0 >> 13u; + let vf_1: i32 = nf1 + cf; (*f).limbs[1] = u32(vf_1) & MASK; cf = vf_1 >> 13u; + let vf_2: i32 = nf2 + cf; (*f).limbs[2] = u32(vf_2) & MASK; cf = vf_2 >> 13u; + let vf_3: i32 = nf3 + cf; (*f).limbs[3] = u32(vf_3) & MASK; cf = vf_3 >> 13u; + let vf_4: i32 = nf4 + cf; (*f).limbs[4] = u32(vf_4) & MASK; cf = vf_4 >> 13u; + let vf_5: i32 = nf5 + cf; (*f).limbs[5] = u32(vf_5) & MASK; cf = vf_5 >> 13u; + let vf_6: i32 = nf6 + cf; (*f).limbs[6] = u32(vf_6) & MASK; cf = vf_6 >> 13u; + let vf_7: i32 = nf7 + cf; (*f).limbs[7] = u32(vf_7) & MASK; cf = vf_7 >> 13u; + let vf_8: i32 = nf8 + cf; (*f).limbs[8] = u32(vf_8) & MASK; cf = vf_8 >> 13u; + let vf_9: i32 = nf9 + cf; (*f).limbs[9] = u32(vf_9) & MASK; cf = vf_9 >> 13u; + let vf_10: i32 = nf10 + cf; (*f).limbs[10] = u32(vf_10) & MASK; cf = vf_10 >> 13u; + let vf_11: i32 = nf11 + cf; (*f).limbs[11] = u32(vf_11) & MASK; cf = vf_11 >> 13u; + let vf_12: i32 = nf12 + cf; (*f).limbs[12] = u32(vf_12) & MASK; cf = vf_12 >> 13u; + let vf_13: i32 = nf13 + cf; (*f).limbs[13] = u32(vf_13) & MASK; cf = vf_13 >> 13u; + let vf_14: i32 = nf14 + cf; (*f).limbs[14] = u32(vf_14) & MASK; cf = vf_14 >> 13u; + let vf_15: i32 = nf15 + cf; (*f).limbs[15] = u32(vf_15) & MASK; cf = vf_15 >> 13u; + let vf_16: i32 = nf16 + cf; (*f).limbs[16] = u32(vf_16) & MASK; cf = vf_16 >> 13u; + let vf_17: i32 = nf17 + cf; (*f).limbs[17] = u32(vf_17) & MASK; cf = vf_17 >> 13u; + let vf_18: i32 = nf18 + cf; (*f).limbs[18] = u32(vf_18) & MASK; cf = vf_18 >> 13u; + (*f).limbs[19] = u32(cf); + + var cg: i32 = 0; + let vg_0: i32 = ng0 + cg; (*g).limbs[0] = u32(vg_0) & MASK; cg = vg_0 >> 13u; + let vg_1: i32 = ng1 + cg; (*g).limbs[1] = u32(vg_1) & MASK; cg = vg_1 >> 13u; + let vg_2: i32 = ng2 + cg; (*g).limbs[2] = u32(vg_2) & MASK; cg = vg_2 >> 13u; + let vg_3: i32 = ng3 + cg; (*g).limbs[3] = u32(vg_3) & MASK; cg = vg_3 >> 13u; + let vg_4: i32 = ng4 + cg; (*g).limbs[4] = u32(vg_4) & MASK; cg = vg_4 >> 13u; + let vg_5: i32 = ng5 + cg; (*g).limbs[5] = u32(vg_5) & MASK; cg = vg_5 >> 13u; + let vg_6: i32 = ng6 + cg; (*g).limbs[6] = u32(vg_6) & MASK; cg = vg_6 >> 13u; + let vg_7: i32 = ng7 + cg; (*g).limbs[7] = u32(vg_7) & MASK; cg = vg_7 >> 13u; + let vg_8: i32 = ng8 + cg; (*g).limbs[8] = u32(vg_8) & MASK; cg = vg_8 >> 13u; + let vg_9: i32 = ng9 + cg; (*g).limbs[9] = u32(vg_9) & MASK; cg = vg_9 >> 13u; + let vg_10: i32 = ng10 + cg; (*g).limbs[10] = u32(vg_10) & MASK; cg = vg_10 >> 13u; + let vg_11: i32 = ng11 + cg; (*g).limbs[11] = u32(vg_11) & MASK; cg = vg_11 >> 13u; + let vg_12: i32 = ng12 + cg; (*g).limbs[12] = u32(vg_12) & MASK; cg = vg_12 >> 13u; + let vg_13: i32 = ng13 + cg; (*g).limbs[13] = u32(vg_13) & MASK; cg = vg_13 >> 13u; + let vg_14: i32 = ng14 + cg; (*g).limbs[14] = u32(vg_14) & MASK; cg = vg_14 >> 13u; + let vg_15: i32 = ng15 + cg; (*g).limbs[15] = u32(vg_15) & MASK; cg = vg_15 >> 13u; + let vg_16: i32 = ng16 + cg; (*g).limbs[16] = u32(vg_16) & MASK; cg = vg_16 >> 13u; + let vg_17: i32 = ng17 + cg; (*g).limbs[17] = u32(vg_17) & MASK; cg = vg_17 >> 13u; + let vg_18: i32 = ng18 + cg; (*g).limbs[18] = u32(vg_18) & MASK; cg = vg_18 >> 13u; + (*g).limbs[19] = u32(cg); +} + +// ============================================================ +// bya_apply_matrix_de — same shape as fg, plus k_d/k_e * p folded in. +// +// k_d, k_e are chosen so the low 26 bits of (u*d + v*e), (q*d + r*e) +// cancel mod p. The "low 26" reconstruction uses the same two-limb +// pre-compute as before. 
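+// (Equivalently: k_d, k_e make u*d + v*e + k_d*p and q*d + r*e + k_e*p
+// divisible by 2^26, while adding multiples of p leaves d, e unchanged mod p.)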
+// +// |nd[k]| <= 6 * (2^13 * 2^15) = 3 * 2^29 ≈ 2^30 — fits i32 with margin. +// ============================================================ +fn bya_apply_matrix_de( + m: MatA, + d: ptr, + e: ptr, + p: ptr, + p_inv_lo: u32, +) { + let u_lo: i32 = i32(u32(m.u) & MASK); + let u_hi: i32 = m.u >> WORD_SIZE; + let v_lo: i32 = i32(u32(m.v) & MASK); + let v_hi: i32 = m.v >> WORD_SIZE; + let q_lo: i32 = i32(u32(m.q) & MASK); + let q_hi: i32 = m.q >> WORD_SIZE; + let r_lo: i32 = i32(u32(m.r) & MASK); + let r_hi: i32 = m.r >> WORD_SIZE; + + // Load all limbs into named locals. + let d0: i32 = i32((*d).limbs[0]); + let d1: i32 = i32((*d).limbs[1]); + let d2: i32 = i32((*d).limbs[2]); + let d3: i32 = i32((*d).limbs[3]); + let d4: i32 = i32((*d).limbs[4]); + let d5: i32 = i32((*d).limbs[5]); + let d6: i32 = i32((*d).limbs[6]); + let d7: i32 = i32((*d).limbs[7]); + let d8: i32 = i32((*d).limbs[8]); + let d9: i32 = i32((*d).limbs[9]); + let d10: i32 = i32((*d).limbs[10]); + let d11: i32 = i32((*d).limbs[11]); + let d12: i32 = i32((*d).limbs[12]); + let d13: i32 = i32((*d).limbs[13]); + let d14: i32 = i32((*d).limbs[14]); + let d15: i32 = i32((*d).limbs[15]); + let d16: i32 = i32((*d).limbs[16]); + let d17: i32 = i32((*d).limbs[17]); + let d18: i32 = i32((*d).limbs[18]); + let d19_raw: u32 = (*d).limbs[19]; + let d19: i32 = (i32(d19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let e0: i32 = i32((*e).limbs[0]); + let e1: i32 = i32((*e).limbs[1]); + let e2: i32 = i32((*e).limbs[2]); + let e3: i32 = i32((*e).limbs[3]); + let e4: i32 = i32((*e).limbs[4]); + let e5: i32 = i32((*e).limbs[5]); + let e6: i32 = i32((*e).limbs[6]); + let e7: i32 = i32((*e).limbs[7]); + let e8: i32 = i32((*e).limbs[8]); + let e9: i32 = i32((*e).limbs[9]); + let e10: i32 = i32((*e).limbs[10]); + let e11: i32 = i32((*e).limbs[11]); + let e12: i32 = i32((*e).limbs[12]); + let e13: i32 = i32((*e).limbs[13]); + let e14: i32 = i32((*e).limbs[14]); + let e15: i32 = i32((*e).limbs[15]); + let e16: i32 = i32((*e).limbs[16]); + let e17: i32 = i32((*e).limbs[17]); + let e18: i32 = i32((*e).limbs[18]); + let e19_raw: u32 = (*e).limbs[19]; + let e19: i32 = (i32(e19_raw) << (32u - WORD_SIZE)) >> (32u - WORD_SIZE); + + let p0: i32 = i32((*p).limbs[0]); + let p1: i32 = i32((*p).limbs[1]); + let p2: i32 = i32((*p).limbs[2]); + let p3: i32 = i32((*p).limbs[3]); + let p4: i32 = i32((*p).limbs[4]); + let p5: i32 = i32((*p).limbs[5]); + let p6: i32 = i32((*p).limbs[6]); + let p7: i32 = i32((*p).limbs[7]); + let p8: i32 = i32((*p).limbs[8]); + let p9: i32 = i32((*p).limbs[9]); + let p10: i32 = i32((*p).limbs[10]); + let p11: i32 = i32((*p).limbs[11]); + let p12: i32 = i32((*p).limbs[12]); + let p13: i32 = i32((*p).limbs[13]); + let p14: i32 = i32((*p).limbs[14]); + let p15: i32 = i32((*p).limbs[15]); + let p16: i32 = i32((*p).limbs[16]); + let p17: i32 = i32((*p).limbs[17]); + let p18: i32 = i32((*p).limbs[18]); + let p19: i32 = i32((*p).limbs[19]); + + // === Step 1: m-trick. Compute low 26 bits of (u*d + v*e), (q*d + r*e) + // to derive k_d, k_e so the result is divisible by 2^26. 
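+    // (This is the standard Montgomery-style cancellation: p_inv_lo is the
+    // precomputed p^-1 mod 2^BYA_BATCH, so k = ((-t) * p_inv_lo) mod 2^BYA_BATCH
+    // gives t + k*p ≡ 0 (mod 2^BYA_BATCH), and the low 26 bits then shift out
+    // for free below.)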
+ let nd0_pre: i32 = u_lo * d0 + v_lo * e0; + let nd1_pre: i32 = u_lo * d1 + v_lo * e1 + u_hi * d0 + v_hi * e0; + let ne0_pre: i32 = q_lo * d0 + r_lo * e0; + let ne1_pre: i32 = q_lo * d1 + r_lo * e1 + q_hi * d0 + r_hi * e0; + + let nd1_full: i32 = nd1_pre + (nd0_pre >> 13u); + let ne1_full: i32 = ne1_pre + (ne0_pre >> 13u); + let td_low26: u32 = (u32(nd0_pre) & MASK) | ((u32(nd1_full) & MASK) << 13u); + let te_low26: u32 = (u32(ne0_pre) & MASK) | ((u32(ne1_full) & MASK) << 13u); + + let MASK_BATCH: u32 = (1u << BYA_BATCH) - 1u; + let neg_td: u32 = (~td_low26 + 1u) & MASK_BATCH; + let neg_te: u32 = (~te_low26 + 1u) & MASK_BATCH; + let kd_full: u32 = (neg_td * p_inv_lo) & MASK_BATCH; + let ke_full: u32 = (neg_te * p_inv_lo) & MASK_BATCH; + + let kd_lo: i32 = i32(kd_full & MASK); + let kd_hi: i32 = i32(kd_full >> WORD_SIZE); + let ke_lo: i32 = i32(ke_full & MASK); + let ke_hi: i32 = i32(ke_full >> WORD_SIZE); + + // Boundary carry from positions 0, 1 of the full product. After + // m-trick, the low 26 bits ARE zero, so boundary is exactly the + // shift-out from positions 0 and 1. + let rp0_d: i32 = nd0_pre + kd_lo * p0; + let rp1_d: i32 = nd1_pre + kd_lo * p1 + kd_hi * p0; + let boundary_d: i32 = (rp1_d + (rp0_d >> 13u)) >> 13u; + + let rp0_e: i32 = ne0_pre + ke_lo * p0; + let rp1_e: i32 = ne1_pre + ke_lo * p1 + ke_hi * p0; + let boundary_e: i32 = (rp1_e + (rp0_e >> 13u)) >> 13u; + + // PARALLEL MULTIPLY PHASE. + // raw_nd[k] = u_lo*d[k+2] + v_lo*e[k+2] + u_hi*d[k+1] + v_hi*e[k+1] + // + kd_lo*p[k+2] + kd_hi*p[k+1] + let nd0: i32 = u_lo * d2 + v_lo * e2 + u_hi * d1 + v_hi * e1 + kd_lo * p2 + kd_hi * p1 + boundary_d; + let nd1: i32 = u_lo * d3 + v_lo * e3 + u_hi * d2 + v_hi * e2 + kd_lo * p3 + kd_hi * p2; + let nd2: i32 = u_lo * d4 + v_lo * e4 + u_hi * d3 + v_hi * e3 + kd_lo * p4 + kd_hi * p3; + let nd3: i32 = u_lo * d5 + v_lo * e5 + u_hi * d4 + v_hi * e4 + kd_lo * p5 + kd_hi * p4; + let nd4: i32 = u_lo * d6 + v_lo * e6 + u_hi * d5 + v_hi * e5 + kd_lo * p6 + kd_hi * p5; + let nd5: i32 = u_lo * d7 + v_lo * e7 + u_hi * d6 + v_hi * e6 + kd_lo * p7 + kd_hi * p6; + let nd6: i32 = u_lo * d8 + v_lo * e8 + u_hi * d7 + v_hi * e7 + kd_lo * p8 + kd_hi * p7; + let nd7: i32 = u_lo * d9 + v_lo * e9 + u_hi * d8 + v_hi * e8 + kd_lo * p9 + kd_hi * p8; + let nd8: i32 = u_lo * d10 + v_lo * e10 + u_hi * d9 + v_hi * e9 + kd_lo * p10 + kd_hi * p9; + let nd9: i32 = u_lo * d11 + v_lo * e11 + u_hi * d10 + v_hi * e10 + kd_lo * p11 + kd_hi * p10; + let nd10: i32 = u_lo * d12 + v_lo * e12 + u_hi * d11 + v_hi * e11 + kd_lo * p12 + kd_hi * p11; + let nd11: i32 = u_lo * d13 + v_lo * e13 + u_hi * d12 + v_hi * e12 + kd_lo * p13 + kd_hi * p12; + let nd12: i32 = u_lo * d14 + v_lo * e14 + u_hi * d13 + v_hi * e13 + kd_lo * p14 + kd_hi * p13; + let nd13: i32 = u_lo * d15 + v_lo * e15 + u_hi * d14 + v_hi * e14 + kd_lo * p15 + kd_hi * p14; + let nd14: i32 = u_lo * d16 + v_lo * e16 + u_hi * d15 + v_hi * e15 + kd_lo * p16 + kd_hi * p15; + let nd15: i32 = u_lo * d17 + v_lo * e17 + u_hi * d16 + v_hi * e16 + kd_lo * p17 + kd_hi * p16; + let nd16: i32 = u_lo * d18 + v_lo * e18 + u_hi * d17 + v_hi * e17 + kd_lo * p18 + kd_hi * p17; + let nd17: i32 = u_lo * d19 + v_lo * e19 + u_hi * d18 + v_hi * e18 + kd_lo * p19 + kd_hi * p18; + let nd18: i32 = u_hi * d19 + v_hi * e19 + kd_hi * p19; + + let ne0: i32 = q_lo * d2 + r_lo * e2 + q_hi * d1 + r_hi * e1 + ke_lo * p2 + ke_hi * p1 + boundary_e; + let ne1: i32 = q_lo * d3 + r_lo * e3 + q_hi * d2 + r_hi * e2 + ke_lo * p3 + ke_hi * p2; + let ne2: i32 = q_lo * d4 + r_lo * e4 + q_hi * d3 + r_hi * e3 + 
ke_lo * p4 + ke_hi * p3; + let ne3: i32 = q_lo * d5 + r_lo * e5 + q_hi * d4 + r_hi * e4 + ke_lo * p5 + ke_hi * p4; + let ne4: i32 = q_lo * d6 + r_lo * e6 + q_hi * d5 + r_hi * e5 + ke_lo * p6 + ke_hi * p5; + let ne5: i32 = q_lo * d7 + r_lo * e7 + q_hi * d6 + r_hi * e6 + ke_lo * p7 + ke_hi * p6; + let ne6: i32 = q_lo * d8 + r_lo * e8 + q_hi * d7 + r_hi * e7 + ke_lo * p8 + ke_hi * p7; + let ne7: i32 = q_lo * d9 + r_lo * e9 + q_hi * d8 + r_hi * e8 + ke_lo * p9 + ke_hi * p8; + let ne8: i32 = q_lo * d10 + r_lo * e10 + q_hi * d9 + r_hi * e9 + ke_lo * p10 + ke_hi * p9; + let ne9: i32 = q_lo * d11 + r_lo * e11 + q_hi * d10 + r_hi * e10 + ke_lo * p11 + ke_hi * p10; + let ne10: i32 = q_lo * d12 + r_lo * e12 + q_hi * d11 + r_hi * e11 + ke_lo * p12 + ke_hi * p11; + let ne11: i32 = q_lo * d13 + r_lo * e13 + q_hi * d12 + r_hi * e12 + ke_lo * p13 + ke_hi * p12; + let ne12: i32 = q_lo * d14 + r_lo * e14 + q_hi * d13 + r_hi * e13 + ke_lo * p14 + ke_hi * p13; + let ne13: i32 = q_lo * d15 + r_lo * e15 + q_hi * d14 + r_hi * e14 + ke_lo * p15 + ke_hi * p14; + let ne14: i32 = q_lo * d16 + r_lo * e16 + q_hi * d15 + r_hi * e15 + ke_lo * p16 + ke_hi * p15; + let ne15: i32 = q_lo * d17 + r_lo * e17 + q_hi * d16 + r_hi * e16 + ke_lo * p17 + ke_hi * p16; + let ne16: i32 = q_lo * d18 + r_lo * e18 + q_hi * d17 + r_hi * e17 + ke_lo * p18 + ke_hi * p17; + let ne17: i32 = q_lo * d19 + r_lo * e19 + q_hi * d18 + r_hi * e18 + ke_lo * p19 + ke_hi * p18; + let ne18: i32 = q_hi * d19 + r_hi * e19 + ke_hi * p19; + + // SERIAL CARRY PASS. + var cd: i32 = 0; + let vd_0: i32 = nd0 + cd; (*d).limbs[0] = u32(vd_0) & MASK; cd = vd_0 >> 13u; + let vd_1: i32 = nd1 + cd; (*d).limbs[1] = u32(vd_1) & MASK; cd = vd_1 >> 13u; + let vd_2: i32 = nd2 + cd; (*d).limbs[2] = u32(vd_2) & MASK; cd = vd_2 >> 13u; + let vd_3: i32 = nd3 + cd; (*d).limbs[3] = u32(vd_3) & MASK; cd = vd_3 >> 13u; + let vd_4: i32 = nd4 + cd; (*d).limbs[4] = u32(vd_4) & MASK; cd = vd_4 >> 13u; + let vd_5: i32 = nd5 + cd; (*d).limbs[5] = u32(vd_5) & MASK; cd = vd_5 >> 13u; + let vd_6: i32 = nd6 + cd; (*d).limbs[6] = u32(vd_6) & MASK; cd = vd_6 >> 13u; + let vd_7: i32 = nd7 + cd; (*d).limbs[7] = u32(vd_7) & MASK; cd = vd_7 >> 13u; + let vd_8: i32 = nd8 + cd; (*d).limbs[8] = u32(vd_8) & MASK; cd = vd_8 >> 13u; + let vd_9: i32 = nd9 + cd; (*d).limbs[9] = u32(vd_9) & MASK; cd = vd_9 >> 13u; + let vd_10: i32 = nd10 + cd; (*d).limbs[10] = u32(vd_10) & MASK; cd = vd_10 >> 13u; + let vd_11: i32 = nd11 + cd; (*d).limbs[11] = u32(vd_11) & MASK; cd = vd_11 >> 13u; + let vd_12: i32 = nd12 + cd; (*d).limbs[12] = u32(vd_12) & MASK; cd = vd_12 >> 13u; + let vd_13: i32 = nd13 + cd; (*d).limbs[13] = u32(vd_13) & MASK; cd = vd_13 >> 13u; + let vd_14: i32 = nd14 + cd; (*d).limbs[14] = u32(vd_14) & MASK; cd = vd_14 >> 13u; + let vd_15: i32 = nd15 + cd; (*d).limbs[15] = u32(vd_15) & MASK; cd = vd_15 >> 13u; + let vd_16: i32 = nd16 + cd; (*d).limbs[16] = u32(vd_16) & MASK; cd = vd_16 >> 13u; + let vd_17: i32 = nd17 + cd; (*d).limbs[17] = u32(vd_17) & MASK; cd = vd_17 >> 13u; + let vd_18: i32 = nd18 + cd; (*d).limbs[18] = u32(vd_18) & MASK; cd = vd_18 >> 13u; + (*d).limbs[19] = u32(cd); + + var ce: i32 = 0; + let ve_0: i32 = ne0 + ce; (*e).limbs[0] = u32(ve_0) & MASK; ce = ve_0 >> 13u; + let ve_1: i32 = ne1 + ce; (*e).limbs[1] = u32(ve_1) & MASK; ce = ve_1 >> 13u; + let ve_2: i32 = ne2 + ce; (*e).limbs[2] = u32(ve_2) & MASK; ce = ve_2 >> 13u; + let ve_3: i32 = ne3 + ce; (*e).limbs[3] = u32(ve_3) & MASK; ce = ve_3 >> 13u; + let ve_4: i32 = ne4 + ce; (*e).limbs[4] = u32(ve_4) & MASK; ce = ve_4 >> 
13u;
+    let ve_5: i32 = ne5 + ce; (*e).limbs[5] = u32(ve_5) & MASK; ce = ve_5 >> 13u;
+    let ve_6: i32 = ne6 + ce; (*e).limbs[6] = u32(ve_6) & MASK; ce = ve_6 >> 13u;
+    let ve_7: i32 = ne7 + ce; (*e).limbs[7] = u32(ve_7) & MASK; ce = ve_7 >> 13u;
+    let ve_8: i32 = ne8 + ce; (*e).limbs[8] = u32(ve_8) & MASK; ce = ve_8 >> 13u;
+    let ve_9: i32 = ne9 + ce; (*e).limbs[9] = u32(ve_9) & MASK; ce = ve_9 >> 13u;
+    let ve_10: i32 = ne10 + ce; (*e).limbs[10] = u32(ve_10) & MASK; ce = ve_10 >> 13u;
+    let ve_11: i32 = ne11 + ce; (*e).limbs[11] = u32(ve_11) & MASK; ce = ve_11 >> 13u;
+    let ve_12: i32 = ne12 + ce; (*e).limbs[12] = u32(ve_12) & MASK; ce = ve_12 >> 13u;
+    let ve_13: i32 = ne13 + ce; (*e).limbs[13] = u32(ve_13) & MASK; ce = ve_13 >> 13u;
+    let ve_14: i32 = ne14 + ce; (*e).limbs[14] = u32(ve_14) & MASK; ce = ve_14 >> 13u;
+    let ve_15: i32 = ne15 + ce; (*e).limbs[15] = u32(ve_15) & MASK; ce = ve_15 >> 13u;
+    let ve_16: i32 = ne16 + ce; (*e).limbs[16] = u32(ve_16) & MASK; ce = ve_16 >> 13u;
+    let ve_17: i32 = ne17 + ce; (*e).limbs[17] = u32(ve_17) & MASK; ce = ve_17 >> 13u;
+    let ve_18: i32 = ne18 + ce; (*e).limbs[18] = u32(ve_18) & MASK; ce = ve_18 >> 13u;
+    (*e).limbs[19] = u32(ce);
+}
+
+// ============================================================
+// Driver helpers
+// ============================================================
+
+fn bya_is_zero(x: ptr<function, BigInt>) -> bool {
+    var a: u32 = 0u;
+    for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+        a = a | (*x).limbs[i];
+    }
+    return a == 0u;
+}
+
+fn bya_neg_inplace(x: ptr<function, BigInt>) {
+    for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+        (*x).limbs[i] = u32(-i32((*x).limbs[i]));
+    }
+    bya_normalise(x);
+}
+
+fn bya_add_p_inplace(x: ptr<function, BigInt>, p: ptr<function, BigInt>) {
+    for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+        (*x).limbs[i] = u32(i32((*x).limbs[i]) + i32((*p).limbs[i]));
+    }
+    bya_normalise(x);
+}
+
+fn bya_sub_p_inplace(x: ptr<function, BigInt>, p: ptr<function, BigInt>) {
+    for (var i: u32 = 0u; i < {{ num_words }}u; i = i + 1u) {
+        (*x).limbs[i] = u32(i32((*x).limbs[i]) - i32((*p).limbs[i]));
+    }
+    bya_normalise(x);
+}
+
+fn bya_reduce_to_canonical(x: ptr<function, BigInt>, p: ptr<function, BigInt>) {
+    bya_normalise(x);
+    var done: bool = false;
+    for (var it: u32 = 0u; it < BYA_RTC_MAX_ITERS; it = it + 1u) {
+        if (done) { continue; }
+        if (bigint_is_neg_2c(x)) {
+            bya_add_p_inplace(x, p);
+        } else if (bigint_gte(x, p)) {
+            bya_sub_p_inplace(x, p);
+        } else {
+            done = true;
+        }
+    }
+}
+
+const FR_INV_BY_A_P_INV_LO: u32 = {{ p_inv_by_a_lo }}u;
+
+// fr_inv_by_a: Bernstein-Yang safegcd inverse driver, BATCH=26 / NUM_OUTER=29
+// on the 20 x 13-bit BigInt representation. Tight inline-mul apply_matrix.
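+// NUM_OUTER = 29 covers the ~735-divstep worst-case convergence bound for
+// 254-bit inputs at BATCH = 26 (ceil(735 / 26) = 29); the `g == 0` early exit
+// in the loop usually terminates well before that.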
+fn fr_inv_by_a(a: BigInt) -> BigInt { + var p_loc: BigInt = get_p(); + var f: BigInt = get_p(); + var g: BigInt = a; + + var d: BigInt; + var e: BigInt; + for (var k: u32 = 0u; k < {{ num_words }}u; k = k + 1u) { + d.limbs[k] = 0u; + e.limbs[k] = 0u; + } + e.limbs[0] = 1u; + + var delta: i32 = 1; + var done: bool = false; + for (var iter: u32 = 0u; iter < BYA_NUM_OUTER; iter = iter + 1u) { + if (done) { continue; } + let f_lo: vec2 = bya_low_u64_lohi(f); + let g_lo: vec2 = bya_low_u64_lohi(g); + let m: MatA = bya_divsteps(&delta, f_lo, g_lo); + bya_apply_matrix_fg(m, &f, &g); + bya_apply_matrix_de(m, &d, &e, &p_loc, FR_INV_BY_A_P_INV_LO); + if (((iter + 1u) % BYA_REDUCE_INTERVAL) == 0u) { + bya_reduce_to_canonical(&d, &p_loc); + bya_reduce_to_canonical(&e, &p_loc); + } + if (bya_is_zero(&g)) { + done = true; + } + } + + bya_reduce_to_canonical(&d, &p_loc); + if (bigint_is_neg_2c(&f)) { + bya_neg_inplace(&d); + bya_reduce_to_canonical(&d, &p_loc); + } + + var inv_native: BigInt = d; + var r_cubed: BigInt = get_r_cubed(); + return montgomery_product(&inv_native, &r_cubed); +} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl index b0d512966a26..a001ee1b89da 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl +++ b/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl @@ -34,6 +34,28 @@ fn fr_pow(base: BigInt, exp: BigInt) -> BigInt { return result; } +// (p - 2) as a plain (non-Montgomery) BigInt. Used by fr_pow_inv as the +// exponent in Fermat's little theorem: a^(p-2) ≡ a^(-1) (mod p). +fn get_p_minus_2() -> BigInt { + var e: BigInt; +{{{ p_minus_2_limbs }}} + return e; +} + +// Field inversion via Fermat's little theorem: a^(-1) ≡ a^(p-2) (mod p). +// Both input and output are in Montgomery form. Since `fr_pow` preserves +// Montgomery form (Mont(base)^exp -> Mont(base^exp)), the result of +// `fr_pow(Mont(a), p-2)` is directly Mont(a^(-1)). No extra correction. +// +// Cost: ~254 squarings + ~127 expected multiplies (half the bits of p-2 +// are set), ≈ 381 montgomery_products per call. Compare to fr_inv's +// jumpy K=12 safegcd which converges in ~62 outer iters with ~10 +// BigInt-ops each plus ONE montgomery_product at the end. +fn fr_pow_inv(a: BigInt) -> BigInt { + var exp: BigInt = get_p_minus_2(); + return fr_pow(a, exp); +} + // R^3 mod p. Used by `fr_inv` to convert the binary-GCD output (which is // in native form, pre-multiplied by R^(-1) because the input was in // Montgomery form) back into Montgomery form via a single From 84e7e3cbfa2a49c4881f01a23220575699678c61 Mon Sep 17 00:00:00 2001 From: zac-williamson Date: Sun, 17 May 2026 12:53:19 +0100 Subject: [PATCH 4/4] chore: remove plan files (delivered out-of-band) --- .claude/plans/msm-tree-reduce.md | 172 -------------- .claude/plans/msm-webgpu-rewrite.md | 357 ---------------------------- 2 files changed, 529 deletions(-) delete mode 100644 .claude/plans/msm-tree-reduce.md delete mode 100644 .claude/plans/msm-webgpu-rewrite.md diff --git a/.claude/plans/msm-tree-reduce.md b/.claude/plans/msm-tree-reduce.md deleted file mode 100644 index 042ef4a51693..000000000000 --- a/.claude/plans/msm-tree-reduce.md +++ /dev/null @@ -1,172 +0,0 @@ -# Stage B — Tree-reduce per bucket with adaptive batch sizing - -> Replaces the current SMVP round-loop (`smvp_batch_affine_gpu` + 5 shaders) with -> a tree-reduce structure that scales logarithmically in max bucket population -> instead of linearly. 
Designed for skewed real-world ZK workloads where the -> current round-loop's MAX_ROUNDS bound is dominated by a few heavy buckets. - -## Constants - -``` -SWEET_B = 1024 // peak per-pair throughput (24.4 ns/pair from bench) -MIN_B = 32 // floor: TPB=32, 1 SIMD group, no cross-SIMD barriers, 1.56× sweet cost -TARGET_THREADS = 40_000 // Apple Silicon (M-series Pro/Max) resident thread budget -TPB_DEFAULT = 64 -TPB_MIN_B = 32 // matches Apple's SIMD group width -MAX_PHASES = 10 // recursion safety cap; pre-pass usually computes exact depth -``` - -## Adaptive batch sizing - -``` -function pickBatch(total_adds): - candidate_B = total_adds / (TARGET_THREADS / TPB_DEFAULT) # = total_adds / 625 - - if candidate_B >= SWEET_B: # plenty of work - return (SWEET_B, ceil(total_adds / SWEET_B), 64) - elif candidate_B >= 64: # mid: largest pow-2 ≤ candidate - B = floor_pow2(candidate_B) - return (B, ceil(total_adds / B), 64) - else: # tail: floor at MIN_B with TPB=32 - return (MIN_B, ceil(total_adds / MIN_B), 32) -``` - -## Phase structure - -### Pre-pass kernel (per phase) - -One small dispatch. Inputs: sorted schedule (Phase 1) or partials buffer (Phase ≥2). For each entry: -1. Determine WG slice membership via index-partitioning of `total_adds`. -2. Flag if entry is first of its bucket in its WG slice. -3. Flag if entry pairs with the next entry (same bucket, both in same WG slice). -4. Emit pair (idx_a, idx_b) to per-WG pair-list slot. - -Then host-side prefix sums produce: -- `wg_pair_offset[]`, `wg_pair_count[]` — pair-list slice per WG -- `wg_output_offset[]`, `wg_output_count[]` — output partials slice per WG -- `wg_first_bucket[]` — bucket_id of first partial (for cross-WG boundary detection) -- `max_pop_remaining` — if 0, no more dup buckets, terminate - -No atomics. Per-entry kernel work is O(1); host prefix-sum is O(num_WGs) = O(1000) = trivial. - -### Phase 1: per-WG slice batch-affine - -One dispatch, `num_WGs` workgroups of TPB threads (from `pickBatch`). Inputs: -- Bucket-sorted schedule -- Pre-computed pair list - -Each WG: -1. Reads its `wg_pair_count` pre-computed pairs from `pair_list[wg_pair_offset[wg_id] : wg_pair_offset[wg_id] + wg_pair_count[wg_id]]` -2. For each pair (a, b): loads `P = points[scalar_idx_a]` (with sign-flip per SCHEDULE_SIGN_BIT), `Q = points[scalar_idx_b]`, computes `delta_x = Q.x - P.x` -3. Cooperative Phase A/B/C/D batch inverse (workgroup-shared scan, 1 fr_inv_by_a per WG) -4. Per-pair: compute slope, R = P + Q -5. Compaction: each pair's result is the partial for some bucket. Adjacent same-bucket pairs in the slice get combined into running sum; final partials written to `output[wg_output_offset[wg_id] + slot]` with `bucket_id` tag. - -Single fr_inv per WG amortises over `wg_pair_count` ≈ B pair-adds. - -### Phase ≥2: tree-reduce on partials - -Re-sort phase 1's output by bucket_id globally (use existing transpose pattern — fast on GPU). Then re-run pre-pass + Phase 1 kernel on the bucket-sorted partials buffer. - -Key difference from Phase 1: load is `partials[idx]` (a point) instead of `points[scalar_idx]` (a point with sign). Faster — no negation per load. - -### Phase final: BPR / Horner - -After all phases collapse buckets to 1 point each, hand off to existing BPR per window + Horner combine across windows. No change to those. 
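For reference, a minimal TypeScript sketch of the adaptive batch sizing above (constants mirror the Constants section; `floorPow2` and the camelCase names are illustrative only, not existing code):

```typescript
const SWEET_B = 1024;          // peak per-pair throughput batch size
const MIN_B = 32;              // floor batch size (TPB = 32)
const TARGET_THREADS = 40_000; // Apple Silicon resident thread budget
const TPB_DEFAULT = 64;
const TPB_MIN_B = 32;

// Largest power of two <= x (assumes x >= 1).
function floorPow2(x: number): number {
  return 2 ** Math.floor(Math.log2(x));
}

// Returns [batch size B, number of workgroups, threads per workgroup].
function pickBatch(totalAdds: number): [number, number, number] {
  const candidateB = totalAdds / (TARGET_THREADS / TPB_DEFAULT); // = totalAdds / 625
  if (candidateB >= SWEET_B) {
    // Plenty of work: run at the sweet spot.
    return [SWEET_B, Math.ceil(totalAdds / SWEET_B), TPB_DEFAULT];
  } else if (candidateB >= 64) {
    // Mid range: largest power of two <= candidate.
    const b = floorPow2(candidateB);
    return [b, Math.ceil(totalAdds / b), TPB_DEFAULT];
  }
  // Tail: floor at MIN_B with TPB = 32.
  return [MIN_B, Math.ceil(totalAdds / MIN_B), TPB_MIN_B];
}
```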
- -## Memory budget (logN=16, N_entries=1.1M, B_active≈272K) - -- `pair_list` (Phase 1): ~825K pairs × 8 bytes = **6.6 MB** -- `wg_*` arrays: 1000 WGs × ~5 × 4 bytes = **20 KB** -- `output partials` (Phase 1): ~325K × 68 bytes = **22 MB** -- `output partials` (Phase 2): ~80K × 68 bytes = **5.5 MB** -- `output partials` (Phase ≥3): rapidly shrinking -- Total scratch: **~35 MB**, well under any device limit - -## Phase count (theoretical) - -For typical 4-entries-per-bucket: `max_pop ≈ 16`, `log2(16/1024) ≤ 0` → **Phase 1 alone resolves it**. - -For skewed (heavy bucket pop=10K): `max_pop = 10000`, `log2(10000/1024) ≈ 4` → **Phase 1 + 3-4 recursion levels**. - -For uniform with sweet B fill: ~5 phases worst case. - -vs current: 32 rounds. **6× fewer dispatches in typical case**. - -## What we save - -- **Dispatch overhead**: 5 phases × 3 dispatches each = 15 vs current 32 rounds × 3 = 96. Saves ~1.6 ms. -- **Late-round amortisation collapse**: gone — adaptive sizing keeps per-WG batch at sweet through phase 5+. -- **Pathological skew**: round count goes from O(max_pop) to O(log max_pop). **The big win for production ZK workloads.** - -## Open implementation decisions - -### Per-WG slice compaction (within phase 1 / phase ≥2) - -Each WG's batch-affine produces `wg_pair_count` result points. These need to be COMPACTED into per-bucket partials (one partial per distinct bucket the WG touched). - -Two sub-options: -- **(a) Within-WG sequential merge**: after batch-affine, one thread walks the pair results, merges adjacent same-bucket results, writes final partials. ~B sequential adds (cheap, 63/64 threads idle but only briefly). -- **(b) Within-WG segmented reduce**: parallel reduction grouping by bucket_id. More complex. - -Going with **(a)** — simpler, the post-merge work is negligible compared to the batch-affine. - -### Re-sort between phases - -Phase k output is grouped by WG; Phase k+1 needs bucket-grouped input. Options: -- **Transpose-style**: use existing `transpose_parallel_{count,scan,scatter}` infrastructure on the new layout. Adds ~3 dispatches per phase. -- **Per-WG outputs are SORTED by bucket already** (since schedule was bucket-sorted). Just need a parallel MERGE of K sorted lists. O(N log K). Cheap. - -Going with **merge** — fewer dispatches. - -### Pair-list pre-pass - -Single dispatch, one thread per schedule entry. Per entry: -- Determine WG = `entry_idx * num_WGs / total_adds_density` (uses precomputed running-adds index) -- Check predecessor entry: same bucket + same WG slice → emit pair (predecessor, self) to per-WG slot - -Per-WG pair slot allocation: pre-pre-pass counts per-WG pair count, host prefix-sums. - -So phase structure is actually: -1. count-pass — count pairs per WG (1 atomic per WG, only num_WGs increments, low contention) -2. host prefix-sum — compute pair_offsets -3. fill-pass — write pairs to per-WG slots (one atomic per WG for local cursor, or use 2-thread cooperation to make atomicLess) -4. phase 1 batch-affine - -Actually atomics per-WG are TRIVIAL (one address per WG = no contention). Acceptable. - -OR even cleaner: do the count-pass and fill-pass in ONE kernel with per-thread local pair-buffer in registers, flushed at WG boundary. Avoids any global atomics. Complexity vs simplicity tradeoff. - -For first implementation: 2-pass pre-pass with per-WG-local atomics. Optimize later. - -## Phase count termination - -Pre-pass computes per-bucket population at phase 0. `MAX_PHASES = ceil(log2(max_pop / SWEET_B)) + 2`. Hard-coded; no runtime detection. 
- -OR per-phase: if `num_distinct_buckets_output == num_distinct_buckets_input`, no reduction happened → done. - -Use the formula-based approach (cleaner; hardcoded loop count). Loop: -``` -for phase in 0..MAX_PHASES: - if total_adds_remaining == 0: break - dispatch pre-pass - dispatch phase k - re-sort output → input of next phase -``` - -## What this does NOT include (per user scope) - -- Duplicate stripping -- Two bucket widths -- Adaptive bucket width (c stays constant) -- GLV scalar split - -## Estimated impact - -For UNIFORM data at logN=16 (current bench case): -- ba_inverse_Σ: 10.8 ms → estimated 6-8 ms. Saving ~3 ms = 4% MSM wall. -- Dispatch overhead saving: ~1.6 ms. - -For SKEWED data (typical ZK workloads): -- ba_inverse phase: estimated 3-5× faster due to log_2 vs linear in max_pop. -- Could be ~25-40% MSM wall reduction. Numbers depend heavily on actual workload skew profile. diff --git a/.claude/plans/msm-webgpu-rewrite.md b/.claude/plans/msm-webgpu-rewrite.md deleted file mode 100644 index e270eb9b7027..000000000000 --- a/.claude/plans/msm-webgpu-rewrite.md +++ /dev/null @@ -1,357 +0,0 @@ -# Plan: WebGPU BN254 MSM rewrite — BY field inversion + multi-window Pippenger + 32-bit point schedule - -> Produced by Plan subagent on 2026-05-16. The execution loop owner (orchestrator -> Claude) iterates phases below via coder + reviewer subagents. Source of truth -> for what "done" means: this plan's acceptance gates. - -## 0. Plan summary - -**Two ideas to implement.** - -**Idea 1 — Replace WGSL `fr_inv` with the Bernstein–Yang (BY) safegcd inversion that the WASM uses.** -Port `Wasm9x29::divsteps` + `Wasm9x29::apply_matrix` to WGSL (9 × 29-bit signed limbs, BATCH=58 inner divsteps per outer iter, NUM_OUTER=13 outer iters with early `g == 0` break, REDUCE_INTERVAL=4). Each outer iter folds 58 divsteps into one 2×2 matrix and applies it via a streamed schoolbook with limb-by-limb carry. Target: at least 2× wall reduction on the `fr_inv` critical path vs the existing jumpy safegcd `fr_inv`; ideally 3–5×. - -**Idea 2 — Multi-window batched Pippenger + 32-bit point schedule.** -- Replace the per-round bucket-cursor + atomic pair counter design with a bucket-sorted 32-bit schedule built via histogram → per-window prefix-sum → scatter (Stages 1/2/3/4 of the WASM). -- Make the batch-affine reduce phase consume `NUM_WINDOWS_PER_BATCH × num_columns` pairs in one batched inversion so the inversion amortises over both buckets AND windows. -- Extend `batch_inverse_parallel`'s workgroup-Z dimension to `num_subtasks × NUM_WINDOWS_PER_BATCH`. -- Schedule entry layout matches the WASM's `Constantine` packed digit: bit 31 = sign, bits 0..28 = scalar_idx (29-bit payload). Dedup-redirect / dedup-skip bits exist in the encoding but are unused. - -**Out of scope (explicit per spec):** duplicate stripping, two bucket widths for one MSM, adaptive c. - ---- - -## 1. Required reading - -The first coding agent MUST read these before writing any code. - -**WASM reference (source of ideas):** -1. `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse.hpp` -2. `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp` — `Wasm9x29` (closest to WGSL target) -3. 
`/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp`: - - lines 1–350 (file header, `get_scalar_slice_low`, `compute_constantine_slice_params`, `get_constantine_packed_digit`) - - lines 540–710 (schedule entry bit constants, `VariableWindowSchedule`, `RegionView`, cost model) - - lines 1620–1800 (`pippenger_round_parallel_jacobian_fast` — single-thread textbook structural reference) - - lines 2671–2830 (entry to `pippenger_round_parallel`, Arena setup, Phase 1) - - lines 3780–4080 (Stage 1 histogram, Stage 2/3 bucket-offset, Stage 4 scatter; skip Phase A dedup body) - - lines 4210–4550 (Stage 6 partition, Stage 6a/6b bucket reduction across `windows_in_batch`) - - lines 4550–4610 (per-region dispatch driver, lower/upper regions, batch loop) - -**Existing WebGPU MSM (target codebase):** -4. `barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl` -5. `barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint.template.wgsl` -6. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl` -7. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse.template.wgsl` -8. `barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts` -9. `barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts` -10. `barretenberg/ts/src/msm_webgpu/msm.ts` lines 540–924 -11. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_schedule.template.wgsl` -12. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_apply_scatter.template.wgsl` -13. `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_affine_init.template.wgsl`, `batch_affine_dispatch_args.template.wgsl`, `batch_affine_finalize_collect.template.wgsl`, `batch_affine_finalize_apply.template.wgsl` -14. `barretenberg/ts/src/msm_webgpu/cuzk/bn254.ts` (host BigInt reference) -15. `barretenberg/ts/dev/msm-webgpu/bench-field-mul.ts`, `bench-field-mul.html` -16. `barretenberg/ts/dev/msm-webgpu/scripts/bench-field-mul.mjs` -17. `barretenberg/ts/dev/msm-webgpu/main.ts` lines 1180–1300 (`Quick sanity check (WebGPU only)` button) - -**Hard constraints for every coding agent (repeated in every task brief):** - -- WebGPU on Apple Silicon Metal is FRAGILE. A wedged shader can require a reboot. -- **Every WGSL loop MUST have a compile-time-constant upper bound** (`for (var i = 0u; i < CONST; i = i + 1u)` where `CONST` is a `const` or substituted Mustache value). Reject any shader that fails this audit. -- For the BY divsteps inner loop: bound is `BATCH = 58` as a `const`. Outer loop bound is `NUM_OUTER = 13`, also a `const`. -- For BY `apply_matrix` streamed schoolbook: bound is `const N: u32 = NUM_LIMBS_BY` — must be a `const`, not a runtime expression. -- Test order: `bench-field-mul` micro-bench first (n=2^10 to 2^14), then **only after green**, the dev-page Sanity Check at logN=16. Never invoke any MSM-runtime harness from Node. -- The base-field multiplication (`montgomery_product` Karat+Yuval) just landed; **do not modify it**. -- Never delete the existing `fr_inv` / `fr_inv_plain` / `fr_inv_bgcd`. They stay as A/B fallbacks. - ---- - -## 2. Phase 1 — BY field inversion in WGSL - -### 2.1 Locate the algorithm - -- Driver: `bernstein_yang_inverse.hpp` lines 290–326 (`invert_bernsteinyang19`). -- 9×29-bit engine: `bernstein_yang_inverse_wasm.hpp` lines 1–258. - - `Wasm9x29::divsteps(delta, f_lo, g_lo)` — lines 147–178. - - `Wasm9x29::apply_matrix(m, f, g, d, e, p, p_inv)` — lines 187–255. - - `Wasm9x29::reduce_to_canonical(p)` — lines 125–145. 
-- Convergence bound: 735 divsteps cited at header lines 26–27. With BATCH=58 → ⌈735/58⌉ = 13 outer iters. - -### 2.2 Iteration count and determinism - -- `NUM_OUTER = 13` hard cap, with early exit on `g == 0`. -- `BATCH = 58` inner divsteps per outer iter. -- Variable-time over branches; BN254 base-field values in our pipeline are public, so OK. -- Fully deterministic for a given input. - -### 2.3 WGSL representation - -Decisive choice: **Option B — `BigIntBY = array` of 29-bit signed limbs.** This matches the WASM and reaches the perf target. - -Conversion on entry/exit between the 20×13-bit `BigInt` and `BigIntBY`. The conversion is ~20 ops each way; per-call cost amortises over NUM_OUTER × BATCH ≈ 750 inner ops + 13 matrix applications. - -### 2.4 New WGSL files / signatures - -Create (Mustache partial `{{> by_inverse_funcs }}`): - -**File: `barretenberg/ts/src/msm_webgpu/wgsl/field/by_inverse.template.wgsl`** - -Top-level entry (drop-in replacement for existing `fr_inv`): - -```wgsl -// Bernstein-Yang safegcd inverse on 9 × 29-bit signed limbs. -// Input in Montgomery form. Output mont(a^(-1)). -fn fr_inv_by(a: BigInt) -> BigInt -``` - -Required constants and helpers (loop bounds all `const`): - -```wgsl -const BY_NUM_LIMBS: u32 = 9u; -const BY_LIMB_BITS: u32 = 29u; -const BY_LIMB_MASK: u32 = (1u << 29u) - 1u; -const BY_BATCH: u32 = 58u; -const BY_NUM_OUTER: u32 = 13u; -const BY_REDUCE_INTERVAL: u32 = 4u; -const BY_RTC_MAX_ITERS: u32 = 36u; // matches Wasm9x29::reduce_to_canonical - -struct BigIntBY { l: array }; - -fn by_from_bigint(x: BigInt) -> BigIntBY; -fn by_to_bigint(x: BigIntBY) -> BigInt; -fn by_get_p() -> BigIntBY; -fn by_one() -> BigIntBY; -fn by_low_u64_lohi(x: BigIntBY) -> vec2; -fn by_is_zero(x: BigIntBY) -> bool; -fn by_is_negative(x: BigIntBY) -> bool; -fn by_neg(x: BigIntBY) -> BigIntBY; -fn by_normalise(x: ptr); -fn by_reduce_to_canonical(x: ptr, p: ptr); - -// Matrix entries split into (lo: i32, hi: i32) representing i64 values. -// After BATCH=58 divsteps, |entry| ≤ 2^58. -struct Mat { u: i32, v: i32, q: i32, r: i32, u_hi: i32, v_hi: i32, q_hi: i32, r_hi: i32 }; - -fn by_divsteps(delta: ptr, f_lo: vec2, g_lo: vec2) -> Mat; -fn by_apply_matrix_fg(m: Mat, f: ptr, g: ptr); -fn by_apply_matrix_de(m: Mat, d: ptr, e: ptr, - p: ptr, p_inv_lo: u32, p_inv_hi: u32); - -fn fr_inv_by(a: BigInt) -> BigInt; -``` - -`by_divsteps`: transliterate `Wasm9x29::divsteps` lines 147–178. Use `vec2` for the 64-bit `f_lo` and `g_lo` carriers (WGSL has no native i64). Carry the matrix entries `u, v, q, r` as paired `(lo: i32, hi: i32)` because they grow up to 2^58. Loop bound: `for (var i: u32 = 0u; i < BY_BATCH; i = i + 1u) { ... }`. - -`by_apply_matrix_fg` / `by_apply_matrix_de`: transliterate lines 196–254. Each per-limb `m_lo * limb` is an i58, NOT i32. Define a single safe `signed_mul_split(a: i32, b: i32) -> vec2` helper bounded to |a|, |b| ≤ 2^29 and reuse everywhere. The coding agent picks the exact partial-product splits; the contract is only that each partial fits in i32. - -### 2.5 Test harness — `fr_inv` micro-bench - -Add `barretenberg/ts/src/msm_webgpu/wgsl/cuzk/fr_inv_bench.template.wgsl` (mirrors `field_mul_bench_u32.template.wgsl`). Per-thread chained `fr_inv_` `k` times, write to outputs. - -Host-side: `gen_fr_inv_bench_shader(workgroup_size, variant)` in `shader_manager.ts`, `--variant fr_inv_by` whitelisted in `bench-field-mul.mjs` and `bench-field-mul.ts`. Reference: `modInverse` from `cuzk/bn254.ts` with Mont conversion. - -**Acceptance criteria for Phase 1:** -1. 
`bench-field-mul.mjs --path u32 --variant fr_inv_by --n 1024 --k 1 --validate-n 1024` → all 1024 match host reference. -2. `--n 65536 --k 10` runs to completion (no hang, no `[shader fr_inv_bench] error:` console message). -3. `fr_inv_by` ≥ 2× faster than `fr_inv` median wall (target 3–5×). - -### 2.6 Wiring into production - -After Phase 1 acceptance: -1. `wgsl/cuzk/batch_inverse_parallel.template.wgsl` line ~219: `fr_inv` → `fr_inv_by`. -2. `wgsl/cuzk/batch_inverse.template.wgsl` line ~77: `fr_inv` → `fr_inv_by`. -3. `shader_manager.ts`: include `{{> by_inverse_funcs }}` in `gen_batch_inverse_parallel_shader` / `gen_batch_inverse_shader` partials. -4. Run Quick Sanity Check at logN=16 via Playwright; expect `[sanity] PASS`. If FAIL, revert and bisect via the micro-bench. - ---- - -## 3. Phase 2 — Multi-window batched Pippenger + 32-bit point schedule - -### 3.1 WASM multi-round structure - -(All line refs in `scalar_multiplication.cpp`.) - -**Outer dispatch (4551–4604):** Lower region + optional Upper region. We use only the lower region (single c). - -**Per region (4570–4602):** iterate windows in batches of `windows_per_batch`. Within one batch: -- **Stage 1 (3785–3877):** per-thread per-window digit histogram. Output `digit_cursors[(w · T + t) · bucket_stride + d]`. -- **Stage 2 (3879–3909):** per-thread → per-window prefix-sum. Writes per-(window, thread, digit) cursor base; writes per-digit totals to `bucket_start[d+1]`. -- **Stage 3 (3911–3937):** per-window serial prefix-sum on `bucket_start`. -- **Stage 4 (3939–4075):** scatter. Re-decodes each scalar's window-w digit, writes the 32-bit schedule entry to `schedule[w * capacity + bucket_start[d] + cursor[d]++]`. Dedup OFF. -- **Stage 5 (4211–4217):** per-window chunk partition. -- **Stage 6a (4344–4399):** per-(thread, window) batched-affine bucket reduction → `bucket_partials_dense`. -- **Stage 6b (4401–4525):** cross-thread, per-task slice `[d_lo, d_hi]`, `recursive_affine_bucket_reduce_strided` — the multi-window batched inversion. -- **Stage 7 (4534–4548):** per-window combine of per-thread partials. - -Final Horner combine over all windows: lines 4606–4615. - -### 3.2 32-bit schedule entry encoding - -Adopt bit-for-bit from WASM (lines 552–567): -- bit 31: sign -- bit 30: dedup redirect (always zero) -- bit 29: dedup skip (always zero) -- bits 0..28: scalar_idx (≤ 2^29 = 512M, plenty for logN ≤ 28) - -### 3.3 WGSL changes - -**Replace:** -- `wgsl/cuzk/batch_affine_schedule.template.wgsl` — delete the per-round bucket-cursor / atomic pair counter. Replace with three new shaders: - -**New `wgsl/cuzk/schedule_histogram.template.wgsl`** (Stage 1) -```wgsl -// Per-thread per-window per-digit histogram. -// Dispatch: (ceil(n / wg_size), 1, num_subtasks_in_batch) -// const NUM_WINDOWS_IN_BATCH: u32 = {{ num_windows_in_batch }}u; -// const NUM_BUCKETS: u32 = {{ num_columns }}u; -// Writes digit_cursors[(w * num_threads + tid) * num_buckets + d]. -``` - -**New `wgsl/cuzk/schedule_offsets.template.wgsl`** (Stage 2 + 3) -```wgsl -// One workgroup per (window, bucket-slice). Per-window prefix-sum. -// Output: bucket_start[w][d+1], digit_cursors[w][t][d]. -``` - -**New `wgsl/cuzk/schedule_scatter.template.wgsl`** (Stage 4) -```wgsl -// Dispatch: (ceil(n / wg_size), 1, num_subtasks_in_batch) -// sched[w * capacity + bucket_start[w][d] + cursor++] = sign << 31 | scalar_idx -``` - -**Keep + extend:** -- `batch_affine_apply_scatter.template.wgsl`: bind layout reads from bucket-sorted schedule; affine-add math unchanged. 
-- `batch_inverse_parallel.template.wgsl`: Z dimension becomes `num_subtasks × NUM_WINDOWS_IN_BATCH`. Inside, decode `wid.z` into `(subtask_in_batch, window_in_batch)`. -- `batch_affine_finalize_collect.template.wgsl` / `_apply.template.wgsl`: unchanged (called once at end of MSM). - -**New `wgsl/cuzk/bucket_reduce.template.wgsl`** (Stage 6a per-window single-thread bucket accumulator). Per-window kernel that: -1. Reads `schedule[w][chunk_start..chunk_end]` (bucket-sorted). -2. Accumulates each run of contiguous same-bucket entries via the existing batched-affine tree reduce (reuses `batch_inverse_parallel`). -3. Output per-(thread, window) `bucket_partials_dense`. - -### 3.4 Host TS changes - -`cuzk/batch_affine.ts` — major rewrite of `smvp_batch_affine_gpu`: -1. Add `windows_per_batch: number` (start = 4). -2. Replace init + schedule + (per-round inverse+apply) with: dispatch histogram → offsets → scatter → outer loop over batches → per-batch round loop with Z dispatch `windows_per_batch × num_subtasks_in_batch`. -3. Buffer changes: drop `pair_counter` (replaced by per-(w, subtask) atomic). Drop `bucket_cursor` (replaced by `digit_cursors`). Add `bucket_start`. Add `schedule` (32-bit bucket-sorted, ~`num_subtasks × num_columns × 4` bytes ≈ 2 MB at logN=16). - -`cuzk/shader_manager.ts` — add: -- `gen_schedule_histogram_shader(workgroup_size, num_columns, num_windows_in_batch)` -- `gen_schedule_offsets_shader(workgroup_size, num_columns, num_windows_in_batch)` -- `gen_schedule_scatter_shader(workgroup_size, num_columns, num_windows_in_batch)` - -Bump cache keys with new tag `mwb-v1`. - -`msm.ts` — at the `smvp_batch_affine_gpu` call, add `windows_per_batch: 4`. - -`cuzk/batch_affine_bn254.ts` (host reference) — extend `batchAffineMSM` with `windowsPerBatch`; one batched inversion spans pairs from all windows in the batch. **Required as ground truth for correctness tests.** - -### 3.5 Constants exported WASM → WGSL - -| Constant | Value | WGSL exposure | -|---|---|---| -| `SCHEDULE_SIGN_BIT` (line 559) | `1 << 31` | `const SCHED_SIGN_BIT: u32 = 1u << 31u;` | -| `DEDUP_REDIRECT_BIT` (560) | `1 << 30` | `const SCHED_REDIRECT_BIT: u32 = 1u << 30u;` (always zero) | -| `DEDUP_SKIP_BIT` (561) | `1 << 29` | `const SCHED_SKIP_BIT: u32 = 1u << 29u;` (always zero) | -| `SCHEDULE_INDEX_MASK` (562) | `(1<<29) - 1` | `const SCHED_INDEX_MASK: u32 = (1u << 29u) - 1u;` | -| `BATCH_CAPACITY` (596) | 256 | `const BATCH_AFFINE_BREAKEVEN: u32 = 256u;` | -| `BATCH_AFFINE_BREAKEVEN` (1525) | 32 | `const BATCH_AFFINE_DRAIN_THRESHOLD: u32 = 32u;` | - -`chunk_size` (c) stays at 15/16 per `msm.ts:554`. `num_columns = 2^c`. - -### 3.6 Intermediate validation milestones (each is a hard gate) - -For each milestone, the test is the Quick Sanity Check button at logN=16 via Playwright (`[sanity] PASS in N ms`): - -- After `shader_manager` additions, before host orchestrator changes: WGSL compile-only check via `getCompilationInfo()`. -- After histogram + offsets + scatter, `windows_per_batch = 1`: read back schedule on n=2^10 / 2^12; per-(w, d, k) entry matches host's bucket-sorted ground truth (set equality). -- After `bucket_reduce`, `windows_per_batch = 1`: Sanity Check PASS at logN=16. -- After `NUM_WINDOWS_PER_BATCH = 2`: Sanity Check PASS at logN=16. -- After `NUM_WINDOWS_PER_BATCH = 4`: Sanity Check PASS at logN=16 + visible `ba_inverse + ba_apply` wall reduction in `Profiler.report()`. - -### 3.7 Workgroup sizing - -- `schedule_histogram`: WG=256, dispatch `(ceil(n/256), 1, num_subtasks_in_batch)`. 
Per-thread arrays (no shared workgroup atomics). -- `schedule_offsets`: WG=64, dispatch `(1, 1, num_windows_in_batch)`. Per-thread → cross-thread → per-digit prefix sums. -- `schedule_scatter`: WG=256, same dispatch shape as histogram. -- `bucket_reduce`: WG=64 (matches existing apply_scatter). Z = `num_subtasks × NUM_WINDOWS_PER_BATCH`. -- `batch_inverse_parallel`: WG=64. Z = `num_subtasks × NUM_WINDOWS_PER_BATCH`. - -### 3.8 Loop-bound audit - -All loops introduced in Phase 2 use a `const`-bounded counter: -- `schedule_histogram` inner: `for (var w = 0u; w < NUM_WINDOWS_IN_BATCH; ...)`. -- `schedule_offsets` reductions: `for (var t = 0u; t < TPB; ...)`. -- Hillis-Steele scan: `for (var stride: u32 = 1u; stride < TPB; stride = stride * 2u)`. -- `schedule_scatter` window loop: `NUM_WINDOWS_IN_BATCH`. -- `bucket_reduce` tree-reduce pass: bounded by `BATCH_AFFINE_BREAKEVEN`. - -Audit step after every render: `grep -E 'for *\(.*<' rendered.wgsl | grep -v -E '< [A-Z][A-Z_]*[a-z]?|< [0-9]+|< [a-z_]+\.x' | grep -v 'workgroup_size'`. - ---- - -## 4. Test plan - -| Phase | Test | Harness | Pass | -|---|---|---|---| -| 1.A | BY divsteps TS unit test | new Jest `cuzk/bernstein_yang.test.ts`, ~1000 random inputs vs `modInverse` | all match | -| 1.B | WGSL `fr_inv_by` correctness | `bench-field-mul.mjs --variant fr_inv_by --n 1024 --validate-n 1024 --k 1` | all 1024 match | -| 1.C | WGSL `fr_inv_by` perf | same w/ `--n 65536 --k 10 --reps 5` | ≥ 2× faster median than `fr_inv` | -| 1.D | E2E Sanity w/ BY swap-in | Playwright Quick Sanity Check button | `[sanity] PASS` | -| 2.A | Schedule correctness | `wgsl_unit_tests.ts` helper, n=2^10 | set equality vs host ground truth | -| 2.B | Bucket reduction `windows_per_batch=1` | Quick Sanity Check at logN=16 | PASS | -| 2.C | `NUM_WINDOWS_PER_BATCH=2` | Quick Sanity Check at logN=16 | PASS | -| 2.D | `NUM_WINDOWS_PER_BATCH=4` | Quick Sanity Check at logN=16, 18 | PASS + ≥ 1.5× wall reduction | - -**Critical safety rule:** Full MSM correctness ONLY via the Quick Sanity Check button via Playwright. NEVER invoke `compute_bn254_msm_*` directly from Node — the dev-page-button-with-Playwright is the only path validated against Apple Silicon Metal. Micro-bench (`bench-field-mul`) is for primitives only. - ---- - -## 5. Iteration breakdown (17 sub-steps, ≥ 10 floor met) - -**Phase 1:** - -1. **1.1** — Transliterate `Wasm9x29::divsteps` + `apply_matrix` + `reduce_to_canonical` + driver to TS. Jest `bernstein_yang.test.ts` with 1000 random inputs vs `modInverse`. Gate: all match. -2. **1.2** — Add WGSL bigint helpers: `signed_mul_split`, vec2 64-bit add/sub/shift, `by_normalise` carry propagation. New `wgsl/bigint/bigint_by.template.wgsl`. Unit-test via scratch shader. -3. **1.3** — Write WGSL `by_divsteps`. Validate via `divsteps_bench` shader vs TS port. -4. **1.4** — Write WGSL `by_apply_matrix_fg` / `by_apply_matrix_de`. Precompute `p_inv_by_lo` / `p_inv_by_hi` via Mustache in `shader_manager.ts`. -5. **1.5** — Wire `fr_inv_by` + `by_reduce_to_canonical`. Add `gen_fr_inv_bench_shader` + `--variant fr_inv_by`. `--n 1024 --validate-n 1024 --k 1` → all match. Hard gate. -6. **1.6** — Perf pass. `--n 65536 --k 10 --reps 5`. Hard gate: ≥ 2× over `fr_inv`. -7. **1.7** — Swap `fr_inv` → `fr_inv_by` in `batch_inverse_parallel` and `batch_inverse`. Quick Sanity Check at logN=16. Hard gate: PASS. - -**Phase 2:** - -8. **2.1** — Host BigInt reference for multi-window batched Pippenger. Extend `cuzk/batch_affine_bn254.ts` with `windowsPerBatch`. 
Jest cross-check vs `windowsPerBatch=1`. -9. **2.2** — Stage 1 `schedule_histogram`. Add unit test in `wgsl_unit_tests.ts` dispatching on n=2^10, compare per-(w, t, d) vs host. -10. **2.3** — Stage 2/3 `schedule_offsets`. Validate `bucket_start` after kernel = exclusive prefix of `Σ_t digit_cursors`. -11. **2.4** — Stage 4 `schedule_scatter`. Validate via read-back test 2.A. Gate: set equality. -12. **2.5** — `bucket_reduce` for one window (`NUM_WINDOWS_PER_BATCH=1`). Reuse `batch_affine_apply_scatter` math; rewire input from bucket-sorted schedule. -13. **2.6** — Rewire `batch_affine.ts` to dispatch histogram → offsets → scatter → bucket_reduce → finalize at `windows_per_batch=1`. Gate: Sanity Check PASS at logN=16, 14, 12. -14. **2.7** — Bump `NUM_WINDOWS_PER_BATCH` to 2. Decode `wid.z` into (subtask_in_batch, window_in_batch). Gate: Sanity Check PASS at logN=16. -15. **2.8** — Bump to 4 + profile. Gates: Sanity Check PASS at logN=16, 18; ≥ 1.5× wall reduction on `ba_inverse + ba_apply` summed across batches. -16. **2.9** — Cleanup + cache-key bump to `mwb-v1`. Re-run all sanity gates. -17. **2.10** — Final integration. Sanity Check at logN=14, 15, 16, 17, 18, 19, 20. Each must PASS. Wall time vs pre-rewrite baseline. - ---- - -## 6. Out of scope (per user) - -- Duplicate stripping (Phase A / dedup). Bits 29 and 30 of the schedule stay zero. -- Two bucket widths for one MSM (variable-window split). -- Adaptive c. - -If a coding agent finds themselves implementing any of these three, STOP. - ---- - -## Critical files for implementation - -- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/field/fr_pow.template.wgsl` -- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/bigint/bigint.template.wgsl` -- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/batch_inverse_parallel.template.wgsl` -- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/cuzk/batch_affine.ts` -- `/Users/zac/aztec-packages/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts` - -Reference-only — source of all the algorithm structure: -- `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/fields/bernstein_yang_inverse_wasm.hpp` -- `/Users/zac/barretenberg-claude-2/barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/scalar_multiplication.cpp` lines 540–710, 1620–1800, 2671–2830, 3780–4080, 4210–4610.
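As a closing reference for §3.2, a minimal TypeScript sketch of the 32-bit schedule-entry encoding (constant values per the §3.5 table; `encodeScheduleEntry` / `decodeScheduleEntry` are illustrative helpers, not part of the shipped WGSL or host code):

```typescript
// Bit layout of one 32-bit schedule entry (see §3.2 / §3.5):
//   bit 31      sign of the digit
//   bit 30      dedup redirect (always zero in this plan)
//   bit 29      dedup skip     (always zero in this plan)
//   bits 0..28  scalar_idx     (supports logN <= 28)
const SCHED_SIGN_BIT = 0x80000000;
const SCHED_INDEX_MASK = (1 << 29) - 1;

function encodeScheduleEntry(scalarIdx: number, negative: boolean): number {
  return ((negative ? SCHED_SIGN_BIT : 0) | (scalarIdx & SCHED_INDEX_MASK)) >>> 0;
}

function decodeScheduleEntry(entry: number): { scalarIdx: number; negative: boolean } {
  return { scalarIdx: entry & SCHED_INDEX_MASK, negative: (entry >>> 31) === 1 };
}
```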