From 57aa045208d2689481290c92a4876f3cd621d803 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Feb 2026 17:06:02 -0500 Subject: [PATCH 1/9] Revert "Feat: read `prod_evals` and `logup_evals` from hint in `sumcheck_layer_eval` (#30)" This reverts commit ef22e8ecb9965091783d2c0369b8379e7f683f53. --- .../circuit/cuda/include/native/sumcheck.cuh | 8 +- .../native/circuit/cuda/src/sumcheck.cu | 24 ++-- extensions/native/circuit/src/sumcheck/air.rs | 93 ++++++--------- .../native/circuit/src/sumcheck/chip.rs | 111 ++++++++++-------- .../native/circuit/src/sumcheck/columns.rs | 10 +- .../native/circuit/src/sumcheck/execution.rs | 31 +++-- .../native/compiler/src/conversion/mod.rs | 6 +- .../native/compiler/src/ir/instructions.rs | 4 +- extensions/native/compiler/src/ir/sumcheck.rs | 14 ++- extensions/native/recursion/tests/sumcheck.rs | 92 ++++----------- 10 files changed, 164 insertions(+), 229 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index 4f16d4fad7..da7aefd769 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -7,10 +7,8 @@ using namespace native; template struct HeaderSpecificCols { T pc; - T registers[3]; - T prod_id; - T logup_id; - MemoryReadAuxCols read_records[6]; + T registers[5]; + MemoryReadAuxCols read_records[8]; MemoryWriteAuxCols write_records; }; @@ -63,7 +61,7 @@ template struct NativeSumcheckCols { T start_timestamp; T last_timestamp; - T register_ptrs[3]; + T register_ptrs[5]; T ctx[EXT_DEG * 2]; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index def3db0d1d..173abebbe2 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { - for (uint32_t i = 0; i < 6; ++i) { + for (uint32_t i = 0; i < 8; ++i) { mem_fill_base( mem_helper, start_timestamp + i, @@ -26,32 +26,32 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - // mem_fill_base( - // mem_helper, - // start_timestamp, - // specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - // ); mem_fill_base( mem_helper, start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) ); } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - // mem_fill_base( - // mem_helper, - // start_timestamp, - // specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) ); mem_fill_base( mem_helper, start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) ); } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index a82141332b..5bec217a1f 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -23,10 +23,6 @@ use crate::{ }, }; -pub const TOPLEVEL_TIMESTAMP_DIFF: usize = 6; -pub const NUM_RWS_FOR_PRODUCT: usize = 1; -pub const NUM_RWS_FOR_LOGUP: usize = 2; - #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -230,23 +226,21 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::from_canonical_usize(TOPLEVEL_TIMESTAMP_DIFF), + start_timestamp + AB::F::from_canonical_usize(8), ); builder .when(prod_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), + start_timestamp + within_round_limit * AB::F::TWO, ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), + start_timestamp + within_round_limit * AB::F::from_canonical_usize(3), ); // Termination condition @@ -294,31 +288,27 @@ impl Air for NativeSumcheckAir { .execute_and_increment_pc( AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), [ - registers[2].into(), + registers[4].into(), registers[0].into(), registers[1].into(), native_as.into(), native_as.into(), - header_row_specific.prod_id.into(), - header_row_specific.logup_id.into(), + registers[2].into(), + registers[3].into(), ], ExecutionState::new(header_row_specific.pc, first_timestamp), last_timestamp - first_timestamp, ) .eval(builder, header_row); - let mut header_timestamp_diff = (0..TOPLEVEL_TIMESTAMP_DIFF) - .into_iter() - .map(|i| AB::F::from_canonical_usize(i)); - let mut header_read_records_iter = header_row_specific.read_records.iter(); // Read registers - for i in 0..3usize { + for i in 0..5usize { self.memory_bridge .read( MemoryAddress::new(native_as, registers[i]), [register_ptrs[i]], - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], ) .eval(builder, header_row); } @@ -328,8 +318,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[0]), ctx, - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], ) .eval(builder, header_row); @@ -338,8 +328,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[1]), challenges, - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], ) .eval(builder, header_row); @@ -351,15 +341,15 @@ impl Air for NativeSumcheckAir { register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), [max_round], - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(7), + &header_row_specific.read_records[7], ) .eval(builder, header_row); // Write final result self.memory_bridge .write( - MemoryAddress::new(native_as, register_ptrs[2]), + MemoryAddress::new(native_as, register_ptrs[4]), eval_acc, last_timestamp - AB::F::ONE, &header_row_specific.write_records, @@ -393,17 +383,14 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - let mut prod_timestamp_diff = (0..1).into_iter().map(|i| AB::F::from_canonical_usize(i)); - - // TODO: read from hint space then write to memory - // self.memory_bridge - // .read( - // MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - // prod_row_specific.p, - // start_timestamp, - // &prod_row_specific.read_records[0], - // ) - // .eval(builder, prod_row * within_round_limit); + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row * within_round_limit); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -414,16 +401,14 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + prod_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); - assert!(prod_timestamp_diff.next().is_none()); - // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -486,16 +471,14 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - let mut logup_timestamp_diff = (0..2).into_iter().map(|i| AB::F::from_canonical_usize(i)); - - // self.memory_bridge - // .read( - // MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - // logup_row_specific.pq, - // start_timestamp, - // &logup_row_specific.read_records[0], - // ) - // .eval(builder, logup_row * within_round_limit); + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row * within_round_limit); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] @@ -513,11 +496,11 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + register_ptrs[4] + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + logup_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -527,12 +510,12 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + register_ptrs[4] + (num_prod_spec + num_logup_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + logup_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::TWO, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 4f7ba79efd..d5e6f49a62 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -125,8 +125,8 @@ where c: challenges_reg, d: data_address_space, e: register_address_space, - f: prod_evals_id_ptr, - g: logup_evals_id_ptr, + f: prod_evals_reg, + g: logup_evals_reg, } = instruction; // This opcode supports two modes of operation: @@ -167,49 +167,53 @@ where head_specific.registers[0] = ctx_reg; head_specific.registers[1] = challenges_reg; - head_specific.registers[2] = r_evals_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; - let mut head_read_records_iter = head_specific.read_records.iter_mut().map(|r| r.as_mut()); // read pointers let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, ctx_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[0].as_mut(), ); let [challenges_ptr]: [F; 1] = tracing_read_native_helper( state.memory, challenges_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), ); - let [prod_evals_id]: [F; 1] = - memory_read_native(state.memory.data(), prod_evals_id_ptr.as_canonical_u32()); - let [logup_evals_id]: [F; 1] = - memory_read_native(state.memory.data(), logup_evals_id_ptr.as_canonical_u32()); let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( state.memory, r_evals_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[4].as_mut(), ); let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[5].as_mut(), ); let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[6].as_mut(), ); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_read_records_iter.next().unwrap(), + head_specific.read_records[7].as_mut(), ); - assert!(head_read_records_iter.next().is_none()); - cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read + cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); - head_specific.prod_id = prod_evals_id_ptr; - head_specific.logup_id = logup_evals_id_ptr; // challenges = [alpha, c1=r, c2=1-r] let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -231,15 +235,12 @@ where F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); row.register_ptrs[0] = ctx_ptr; row.register_ptrs[1] = challenges_ptr; - row.register_ptrs[2] = r_evals_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; } - let prod_evals_id = prod_evals_id.as_canonical_u32(); - let logup_evals_id = logup_evals_id.as_canonical_u32(); - let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); - let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); - // product rows for (i, prod_row) in rows .iter_mut() @@ -270,12 +271,15 @@ where i as u32, round, 0, - ) as usize; - prod_specific.data_ptr = F::from_canonical_usize(start); + ); + prod_specific.data_ptr = F::from_canonical_u32(start); - // read p1, p2 from hint space - let ps: [F; EXT_DEG * 2] = - prod_evals[start..start + EXT_DEG * 2].try_into().unwrap(); + // read p1, p2 + let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[0].as_mut(), + ); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -309,7 +313,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 1; // 1 write + cur_timestamp += 2; let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -353,12 +357,15 @@ where i as u32, round, 0, - ) as usize; - logup_specific.data_ptr = F::from_canonical_usize(start); + ); + logup_specific.data_ptr = F::from_canonical_u32(start); - // read p1, p2, q1, q2 from hint space - let pqs: [F; EXT_DEG * 4] = - logup_evals[start..start + EXT_DEG * 4].try_into().unwrap(); + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[0].as_mut(), + ); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -416,7 +423,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 2; // 0 read, 2 writes + cur_timestamp += 3; // 1 read, 2 writes let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -495,7 +502,7 @@ impl TraceFiller for NativeSumcheckFiller { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - for i in 0..6usize { + for i in 0..8usize { mem_fill_helper( mem_helper, start_timestamp + i as u32, @@ -512,16 +519,16 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // TODO: read p1, p2 from hint space, then write to memory - // mem_fill_helper( - // mem_helper, - // start_timestamp, - // prod_row_specific.read_records[0].as_mut(), - // ); - // write p_eval + // read p1, p2 mem_fill_helper( mem_helper, start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, prod_row_specific.write_record.as_mut(), ); } @@ -530,22 +537,22 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // TODO: read p1, p2, q1, q2 from hint space - // mem_fill_helper( - // mem_helper, - // start_timestamp, - // logup_row_specific.read_records[0].as_mut(), - // ); - // write p_eval + // read p1, p2, q1, q2 mem_fill_helper( mem_helper, start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, logup_row_specific.write_records[0].as_mut(), ); // write q_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp + 2, logup_row_specific.write_records[1].as_mut(), ); } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 4904e28278..51eb6d39cf 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -38,7 +38,7 @@ pub struct NativeSumcheckCols { pub last_timestamp: T, // Register values - pub register_ptrs: [T; 3], + pub register_ptrs: [T; 5], // Context variables // [ @@ -91,11 +91,9 @@ pub struct NativeSumcheckCols { #[derive(AlignedBorrow)] pub struct HeaderSpecificCols { pub pc: T, - pub registers: [T; 3], - pub prod_id: T, - pub logup_id: T, - /// 3 register reads + ctx read + max round read + challenges read - pub read_records: [MemoryReadAuxCols; 6], + pub registers: [T; 5], + /// 5 register reads + ctx read + max round read + challenges read + pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 95901f86ef..54e06e540c 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -23,8 +23,8 @@ struct NativeSumcheckPreCompute { r_evals_reg: u32, ctx_reg: u32, challenges_reg: u32, - prod_evals_id_ptr: u32, - logup_evals_id_ptr: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, } impl NativeSumcheckExecutor { @@ -49,8 +49,8 @@ impl NativeSumcheckExecutor { let r_evals_reg = a.as_canonical_u32(); let ctx_reg = b.as_canonical_u32(); let challenges_reg = c.as_canonical_u32(); - let prod_evals_id_ptr = f.as_canonical_u32(); - let logup_evals_id_ptr = g.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); if d.as_canonical_u32() != NATIVE_AS { return Err(StaticProgramError::InvalidInstruction(pc)); @@ -63,8 +63,8 @@ impl NativeSumcheckExecutor { r_evals_reg, ctx_reg, challenges_reg, - prod_evals_id_ptr, - logup_evals_id_ptr, + prod_evals_reg, + logup_evals_reg, }; Ok(()) @@ -199,13 +199,13 @@ unsafe fn execute_e12_impl( let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); - let [prod_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.prod_evals_id_ptr); - let [logup_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.logup_evals_id_ptr); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); - let logup_evals_id = logup_evals_id.as_canonical_u32(); - let prod_evals_id = prod_evals_id.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); let ctx: [u32; 8] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32) @@ -224,9 +224,6 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); - let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); - for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -234,10 +231,10 @@ unsafe fn execute_e12_impl( i, round, 0, - ) as usize; + ); if round < max_round - 1 { - let ps: &[F] = &prod_evals[start..start + EXT_DEG * 2]; + let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -271,14 +268,14 @@ unsafe fn execute_e12_impl( i, round, 0, - ) as usize; + ); let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); let alpha_numerator = alpha_acc; if round < max_round - 1 { // read logup_evals - let pqs: &[F] = &logup_evals[start..start + EXT_DEG * 4]; + let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 79797ff0ae..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -536,7 +536,7 @@ fn convert_instruction>( AS::Immediate, )] }, - AsmInstruction::SumcheckLayerEval(ctx, cs, prod_id, logup_id, r_ptr) => vec![ + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ Instruction { opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), a: i32_f(r_ptr), @@ -544,8 +544,8 @@ fn convert_instruction>( c: i32_f(cs), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(prod_id), - g: i32_f(logup_id), + f: i32_f(p_ptr), + g: i32_f(l_ptr), } ], }; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index d44bead6f9..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -341,8 +341,8 @@ pub enum DslIr { // 7. Operational mode indicator // 8+. usize-type variables indicating maximum rounds Ptr, // Challenges: alpha, coeffs - Var, // prod_specs_eval - Var, // logup_specs_eval + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval Ptr, // output ), } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 792c9543b2..0237fd6740 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,4 +1,4 @@ -use super::{Array, Builder, Config, DslIr, Ext, Usize, Var}; +use super::{Array, Builder, Config, DslIr, Ext, Usize}; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer @@ -30,16 +30,18 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ - logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ + prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + * from 3D array. */ + logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for - * storing opcode output) */ + * storing opcode output) */ ) { self.operations.push(DslIr::SumcheckLayerEval( input_ctx.ptr(), challenges.ptr(), - prod_specs_eval_id, - logup_specs_eval_id, + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), r_evals.ptr(), )); } diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 145336c362..284a103021 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -26,32 +26,9 @@ pub type E = BinomialExtensionField; #[test] fn test_sumcheck_layer_eval() { - let mut rng = thread_rng(); let mut builder = AsmBuilder::>::default(); - let num_layers = 8; - let num_prod_specs = 6; - let num_logup_specs = 8; - - let prod_evals: Vec = (0..(num_prod_specs * num_layers * 2)) - .into_iter() - .map(|_| new_rand_ext(&mut rng)) - .collect(); - - let logup_evals: Vec = (0..(num_logup_specs * num_layers * 4)) - .into_iter() - .map(|_| new_rand_ext(&mut rng)) - .collect(); - - build_test_program( - &mut builder, - prod_evals.clone(), - logup_evals.clone(), - num_prod_specs, - num_logup_specs, - num_layers, - 3, - ); + build_test_program(&mut builder); let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); @@ -72,35 +49,13 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; - let mut input_stream: Vec> = vec![]; - input_stream.push( - prod_evals - .into_iter() - .flat_map(|e| >::as_base_slice(&e).to_vec()) - .collect(), - ); - input_stream.push( - logup_evals - .into_iter() - .flat_map(|e| >::as_base_slice(&e).to_vec()) - .collect(), - ); - let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); #[cfg(not(feature = "cuda"))] - air_test_impl::( - fri_params, - vb, - config, - program, - input_stream, - 1, - true, - ) - .unwrap(); + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); #[cfg(feature = "cuda")] { air_test_impl::( @@ -108,7 +63,7 @@ fn test_sumcheck_layer_eval() { vb, config, program, - input_stream, + vec![], 1, true, ) @@ -116,24 +71,22 @@ fn test_sumcheck_layer_eval() { } } -fn new_rand_ext(rng: &mut R) -> E { - E::from_base_slice(&[ - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), +fn new_rand_ext(rng: &mut R) -> C::EF { + C::EF::from_base_slice(&[ + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), ]) } -fn build_test_program( - builder: &mut Builder, - prod_evals: Vec, - logup_evals: Vec, - num_prod_specs: usize, - num_logup_specs: usize, - num_layers: usize, - round: usize, -) { +fn build_test_program(builder: &mut Builder) { + let mut rng = thread_rng(); + // 6 prod specs in 8 layers, 5 logup specs in 8 layers + let round = 3; + let num_prod_specs = 6; + let num_logup_specs = 5; + let num_layers = 8; let mode = 1; // current_layer let mut ctx_u32s = vec![ @@ -175,7 +128,7 @@ fn build_test_program( let num_prod_evals = num_prod_specs * num_layers * 2; let prod_spec_evals: Array> = builder.dyn_array(num_prod_evals); for idx in 0..num_prod_evals { - let e: Ext = builder.constant(prod_evals[idx]); + let e: Ext = builder.constant(new_rand_ext::(&mut rng)); builder.set(&prod_spec_evals, idx, e); } @@ -183,7 +136,7 @@ fn build_test_program( let num_logup_evals = num_logup_specs * num_layers * 4; let logup_spec_evals: Array> = builder.dyn_array(num_logup_evals); for idx in 0..num_logup_evals { - let e: Ext = builder.constant(logup_evals[idx]); + let e: Ext = builder.constant(new_rand_ext::(&mut rng)); builder.set(&logup_spec_evals, idx, e); } @@ -249,16 +202,13 @@ fn build_test_program( .chain(logup_q_evals) .collect::>(); - let prod_spec_evals_id = builder.hint_load(); - let logup_spec_evals_id = builder.hint_load(); - let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); builder.sumcheck_layer_eval( &ctx, &challenges, - prod_spec_evals_id, - logup_spec_evals_id, + &prod_spec_evals, + &logup_spec_evals, &next_layer_evals, ); From 0910b68d9790e43ce6f8397945dfe00e8eb0de2b Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Feb 2026 19:46:43 -0500 Subject: [PATCH 2/9] consolidate hint slice changes --- extensions/native/circuit/src/sumcheck/air.rs | 38 +++++- .../native/circuit/src/sumcheck/chip.rs | 114 ++++++++++++++---- .../native/circuit/src/sumcheck/columns.rs | 12 +- .../native/circuit/src/sumcheck/execution.rs | 23 +++- 4 files changed, 151 insertions(+), 36 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 5bec217a1f..01d9a59e0d 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -23,6 +23,9 @@ use crate::{ }, }; +pub const NUM_RWS_FOR_PRODUCT: usize = 2; +pub const NUM_RWS_FOR_LOGUP: usize = 3; + #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -99,6 +102,9 @@ impl Air for NativeSumcheckAir { within_round_limit, should_acc, eval_acc, + is_hint_src_id, + prod_evals_id: _, + logup_evals_id: _, specific, } = local; @@ -233,14 +239,14 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::TWO, + start_timestamp + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::from_canonical_usize(3), + start_timestamp + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), ); // Termination condition @@ -390,7 +396,18 @@ impl Air for NativeSumcheckAir { start_timestamp, &prod_row_specific.read_records[0], ) - .eval(builder, prod_row * within_round_limit); + .eval(builder, prod_row * within_round_limit * not(is_hint_src_id)); + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[2] + prod_row_specific.data_ptr, + ), + prod_row_specific.p, + start_timestamp, + &prod_row_specific.write_ps_record, + ) + .eval(builder, prod_row * within_round_limit * is_hint_src_id); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -478,8 +495,19 @@ impl Air for NativeSumcheckAir { start_timestamp, &logup_row_specific.read_records[0], ) - .eval(builder, logup_row * within_round_limit); - + .eval(builder, logup_row * within_round_limit * not(is_hint_src_id)); + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[3] + + logup_row_specific.data_ptr, + ), + logup_row_specific.pq, + start_timestamp, + &logup_row_specific.write_pqs_record, + ) + .eval(builder, logup_row * within_round_limit * is_hint_src_id); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d5e6f49a62..b76d7f8b5c 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -207,7 +207,7 @@ where challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - let [max_round]: [F; 1] = tracing_read_native_helper( + let [max_round, is_hint_src_id, prod_evals_id, logup_evals_id]: [F; EXT_DEG] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, head_specific.read_records[7].as_mut(), @@ -223,7 +223,7 @@ where let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - // all rows share same register values, ctx, challenges, max_round + // all rows share same register values, ctx, challenges, max_round, hint_space_ptrs (optional) for row in rows.iter_mut() { // c1, c2 are same during the entire execution row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); @@ -239,8 +239,24 @@ where row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; + row.is_hint_src_id = is_hint_src_id; + row.prod_evals_id = prod_evals_id; + row.logup_evals_id = logup_evals_id; } + // Load hints if source is a ptr + let is_hint_src_id = is_hint_src_id > F::ZERO; + let prod_evals_id = prod_evals_id.as_canonical_u32(); + let logup_evals_id = logup_evals_id.as_canonical_u32(); + let (prod_evals, logup_evals) = if is_hint_src_id { + ( + state.streams.hint_space[prod_evals_id as usize].clone(), + state.streams.hint_space[logup_evals_id as usize].clone(), + ) + } else { + (Vec::new(), Vec::new()) + }; + // product rows for (i, prod_row) in rows .iter_mut() @@ -275,16 +291,30 @@ where prod_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2 - let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[0].as_mut(), - ); + let ps: [F; EXT_DEG * 2] = if is_hint_src_id { + prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)].try_into().unwrap() + } else { + tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[0].as_mut(), + ) + }; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); prod_specific.p = ps; + // If p values come from the hint stream, write back to the actual witness array + if is_hint_src_id { + tracing_write_native_inplace( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + ps, + &mut prod_specific.write_ps_record, + ); + } + // compute expected eval let eval = match mode { NEXT_LAYER_MODE => FieldExtension::add( @@ -313,7 +343,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; + cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -361,11 +391,15 @@ where logup_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[0].as_mut(), - ); + let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { + logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap() + } else { + tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[0].as_mut(), + ) + }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -373,6 +407,16 @@ where logup_specific.pq = pqs; + // write pqs + if is_hint_src_id { + tracing_write_native_inplace( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + pqs, + &mut logup_specific.write_pqs_record, + ); + } + // compute expected evals let p_eval = match mode { NEXT_LAYER_MODE => FieldExtension::add( @@ -423,7 +467,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes + cur_timestamp += 3; // 1 read, 2 writes (witness array case) or 3 writes (hint space ptr case) let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -519,12 +563,21 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.read_records[0].as_mut(), - ); + if cols.is_hint_src_id == F::ONE { + // write p1, p2 to witness arrays from hint space pointers + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.write_ps_record.as_mut(), + ); + } else { + // read p1, p2 from witness arrays + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + } // write p_eval mem_fill_helper( mem_helper, @@ -537,12 +590,21 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.read_records[0].as_mut(), - ); + if cols.is_hint_src_id == F::ONE { + // write p1, p2, q1, q2 to witness arrays from hint space pointers + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.write_pqs_record.as_mut(), + ); + } else { + // read p1, p2, q1, q2 from witness arrays + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + } // write p_eval mem_fill_helper( mem_helper, diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 51eb6d39cf..ee83f05e74 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -73,6 +73,12 @@ pub struct NativeSumcheckCols { // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], + // Indicator for an alternative source form of the inputs prod_evals/logup_evals + pub is_hint_src_id: T, + // Pointer ids for hint slices + pub prod_evals_id: T, + pub logup_evals_id: T, + // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 // pub read_records: [MemoryReadAuxCols; 7], @@ -92,7 +98,7 @@ pub struct NativeSumcheckCols { pub struct HeaderSpecificCols { pub pc: T, pub registers: [T; 5], - /// 5 register reads + ctx read + max round read + challenges read + /// 5 register reads + ctx read + max round/hint ptrs read + challenges read pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, @@ -111,6 +117,8 @@ pub struct ProdSpecificCols { pub p_evals: [T; EXT_DEG], /// write p_evals pub write_record: MemoryWriteAuxCols, + /// write p1, p2 values back to witness array if the source is hint space id + pub write_ps_record: MemoryWriteAuxCols, /// p_evals * alpha^i pub eval_rlc: [T; EXT_DEG], } @@ -130,6 +138,8 @@ pub struct LogupSpecificCols { pub q_evals: [T; EXT_DEG], /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], + /// write p1, p2, q1, q2 back to witness array if the source is hint space id + pub write_pqs_record: MemoryWriteAuxCols, /// Evaluation for the accumulator pub eval_rlc: [T; EXT_DEG], } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 54e06e540c..f0a6bef04a 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -214,8 +214,8 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round]: [u32; 1] = - exec_state.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32); + let [max_round, is_hint_space_ids, prod_evals_id, logup_evals_id]: [u32; 4] = + exec_state.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32).map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); @@ -224,6 +224,12 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); + let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { + (exec_state.streams.hint_space[prod_evals_id as usize].clone(), exec_state.streams.hint_space[logup_evals_id as usize].clone()) + } else { + (Vec::new(), Vec::new()) + }; + for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -234,7 +240,12 @@ unsafe fn execute_e12_impl( ); if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let ps: &[F] = if is_hint_space_ids > 0 { + &prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2] + } else { + &exec_state.vm_read::<_, {EXT_DEG * 2}>(NATIVE_AS, prod_evals_ptr + start) + }; + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -275,7 +286,11 @@ unsafe fn execute_e12_impl( if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let pqs: &[F] = if is_hint_space_ids > 0 { + &logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + } else { + &exec_state.vm_read::<_, {EXT_DEG * 4}>(NATIVE_AS, logup_evals_ptr + start) + }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); From f197f67fc2bd14cd611847a9634c3d632c22a0c6 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Feb 2026 16:06:27 -0500 Subject: [PATCH 3/9] debug --- extensions/native/circuit/src/sumcheck/air.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 01d9a59e0d..d63b6d28f0 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -103,8 +103,8 @@ impl Air for NativeSumcheckAir { should_acc, eval_acc, is_hint_src_id, - prod_evals_id: _, - logup_evals_id: _, + prod_evals_id, + logup_evals_id, specific, } = local; @@ -346,7 +346,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round], + [max_round, is_hint_src_id, prod_evals_id, logup_evals_id], first_timestamp + AB::F::from_canonical_usize(7), &header_row_specific.read_records[7], ) From 43d63efb6e9cead7a3bb8a4ffde8fdcf681a8225 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Feb 2026 16:19:02 -0500 Subject: [PATCH 4/9] debug --- .../native/circuit/cuda/include/native/sumcheck.cuh | 4 ++++ extensions/native/circuit/src/sumcheck/air.rs | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index da7aefd769..30389134ea 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -80,6 +80,10 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; + T is_hint_src_id; + T prod_evals_id; + T logup_evals_id; + T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index d63b6d28f0..42917c7b4e 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -396,7 +396,7 @@ impl Air for NativeSumcheckAir { start_timestamp, &prod_row_specific.read_records[0], ) - .eval(builder, prod_row * within_round_limit * not(is_hint_src_id)); + .eval(builder, (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id)); self.memory_bridge .write( MemoryAddress::new( @@ -407,7 +407,7 @@ impl Air for NativeSumcheckAir { start_timestamp, &prod_row_specific.write_ps_record, ) - .eval(builder, prod_row * within_round_limit * is_hint_src_id); + .eval(builder, (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -495,7 +495,7 @@ impl Air for NativeSumcheckAir { start_timestamp, &logup_row_specific.read_records[0], ) - .eval(builder, logup_row * within_round_limit * not(is_hint_src_id)); + .eval(builder, (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id)); self.memory_bridge .write( MemoryAddress::new( @@ -507,7 +507,7 @@ impl Air for NativeSumcheckAir { start_timestamp, &logup_row_specific.write_pqs_record, ) - .eval(builder, logup_row * within_round_limit * is_hint_src_id); + .eval(builder, (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() From 32a4c370059e4ef9aaee92d28a0fc13953c7ad88 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Feb 2026 17:57:34 -0500 Subject: [PATCH 5/9] fmt --- extensions/native/circuit/src/sumcheck/air.rs | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 42917c7b4e..d32daad91b 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -239,14 +239,16 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), ); // Termination condition @@ -396,18 +398,21 @@ impl Air for NativeSumcheckAir { start_timestamp, &prod_row_specific.read_records[0], ) - .eval(builder, (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id)); + .eval( + builder, + (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), + ); self.memory_bridge .write( - MemoryAddress::new( - native_as, - register_ptrs[2] + prod_row_specific.data_ptr, - ), + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, start_timestamp, &prod_row_specific.write_ps_record, ) - .eval(builder, (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id); + .eval( + builder, + (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id, + ); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -495,19 +500,21 @@ impl Air for NativeSumcheckAir { start_timestamp, &logup_row_specific.read_records[0], ) - .eval(builder, (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id)); + .eval( + builder, + (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), + ); self.memory_bridge .write( - MemoryAddress::new( - native_as, - register_ptrs[3] - + logup_row_specific.data_ptr, - ), + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, start_timestamp, &logup_row_specific.write_pqs_record, ) - .eval(builder, (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id); + .eval( + builder, + (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id, + ); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() From 1ddfaa78d85ddd25c07b29a8b20057e29b2f10d0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Feb 2026 18:56:22 -0500 Subject: [PATCH 6/9] cuda --- .../native/circuit/cuda/src/sumcheck.cu | 36 +++++++++++++------ .../native/circuit/src/sumcheck/chip.rs | 21 ++++++----- .../native/circuit/src/sumcheck/columns.rs | 4 +-- .../native/circuit/src/sumcheck/execution.rs | 14 +++++--- 4 files changed, 50 insertions(+), 25 deletions(-) diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 173abebbe2..aab8a90093 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -26,11 +26,19 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_hint_src_id)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_ps_record.base)) + ); + } else { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + ); + } mem_fill_base( mem_helper, start_timestamp + 1, @@ -39,11 +47,19 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_hint_src_id)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_pqs_record.base)) + ); + } else { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + ); + } mem_fill_base( mem_helper, start_timestamp + 1, diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index b76d7f8b5c..0d707c1458 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -207,11 +207,12 @@ where challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - let [max_round, is_hint_src_id, prod_evals_id, logup_evals_id]: [F; EXT_DEG] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_specific.read_records[7].as_mut(), - ); + let [max_round, is_hint_src_id, prod_evals_id, logup_evals_id]: [F; EXT_DEG] = + tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, + head_specific.read_records[7].as_mut(), + ); cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); @@ -292,7 +293,9 @@ where // read p1, p2 let ps: [F; EXT_DEG * 2] = if is_hint_src_id { - prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)].try_into().unwrap() + prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] + .try_into() + .unwrap() } else { tracing_read_native_helper( state.memory, @@ -343,7 +346,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) + cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -392,7 +395,9 @@ where // read p1, p2, q1, q2 let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap() + logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + .try_into() + .unwrap() } else { tracing_read_native_helper( state.memory, diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index ee83f05e74..7331400771 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -118,7 +118,7 @@ pub struct ProdSpecificCols { /// write p_evals pub write_record: MemoryWriteAuxCols, /// write p1, p2 values back to witness array if the source is hint space id - pub write_ps_record: MemoryWriteAuxCols, + pub write_ps_record: MemoryWriteAuxCols, /// p_evals * alpha^i pub eval_rlc: [T; EXT_DEG], } @@ -139,7 +139,7 @@ pub struct LogupSpecificCols { /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], /// write p1, p2, q1, q2 back to witness array if the source is hint space id - pub write_pqs_record: MemoryWriteAuxCols, + pub write_pqs_record: MemoryWriteAuxCols, /// Evaluation for the accumulator pub eval_rlc: [T; EXT_DEG], } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index f0a6bef04a..b4e117135d 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -214,8 +214,9 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round, is_hint_space_ids, prod_evals_id, logup_evals_id]: [u32; 4] = - exec_state.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32).map(|x: F| x.as_canonical_u32()); + let [max_round, is_hint_space_ids, prod_evals_id, logup_evals_id]: [u32; 4] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32) + .map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); @@ -225,7 +226,10 @@ unsafe fn execute_e12_impl( let mut eval_acc = elem_to_ext(F::ZERO); let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { - (exec_state.streams.hint_space[prod_evals_id as usize].clone(), exec_state.streams.hint_space[logup_evals_id as usize].clone()) + ( + exec_state.streams.hint_space[prod_evals_id as usize].clone(), + exec_state.streams.hint_space[logup_evals_id as usize].clone(), + ) } else { (Vec::new(), Vec::new()) }; @@ -243,7 +247,7 @@ unsafe fn execute_e12_impl( let ps: &[F] = if is_hint_space_ids > 0 { &prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2] } else { - &exec_state.vm_read::<_, {EXT_DEG * 2}>(NATIVE_AS, prod_evals_ptr + start) + &exec_state.vm_read::<_, { EXT_DEG * 2 }>(NATIVE_AS, prod_evals_ptr + start) }; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); @@ -289,7 +293,7 @@ unsafe fn execute_e12_impl( let pqs: &[F] = if is_hint_space_ids > 0 { &logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] } else { - &exec_state.vm_read::<_, {EXT_DEG * 4}>(NATIVE_AS, logup_evals_ptr + start) + &exec_state.vm_read::<_, { EXT_DEG * 4 }>(NATIVE_AS, logup_evals_ptr + start) }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); From 2ddcfef8c0d2fce81cf207ade9671752c2b83e24 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 3 Feb 2026 17:14:52 -0500 Subject: [PATCH 7/9] reduce cell/column usage --- .../system/memory/offline_checker/columns.rs | 2 +- .../circuit/cuda/include/native/sumcheck.cuh | 4 +- .../native/circuit/cuda/src/sumcheck.cu | 36 ++++--------- extensions/native/circuit/src/sumcheck/air.rs | 24 ++++++--- .../native/circuit/src/sumcheck/chip.rs | 52 ++++++------------- .../native/circuit/src/sumcheck/columns.rs | 19 ++++--- 6 files changed, 54 insertions(+), 83 deletions(-) diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index ef9821f859..63c114193d 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -62,7 +62,7 @@ impl MemoryWriteAuxCols { #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryReadAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, + pub base: MemoryBaseAuxCols, } impl MemoryReadAuxCols { diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index 30389134ea..a67b134aaf 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -8,6 +8,8 @@ using namespace native; template struct HeaderSpecificCols { T pc; T registers[5]; + T prod_evals_id; + T logup_evals_id; MemoryReadAuxCols read_records[8]; MemoryWriteAuxCols write_records; }; @@ -81,8 +83,6 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; T is_hint_src_id; - T prod_evals_id; - T logup_evals_id; T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index aab8a90093..f89b93f72d 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -26,19 +26,11 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - if (row[COL_INDEX(NativeSumcheckCols, is_hint_src_id)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, write_ps_record.base)) - ); - } else { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - ); - } + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) + ); mem_fill_base( mem_helper, start_timestamp + 1, @@ -47,19 +39,11 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - if (row[COL_INDEX(NativeSumcheckCols, is_hint_src_id)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_pqs_record.base)) - ); - } else { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - ); - } + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) + ); mem_fill_base( mem_helper, start_timestamp + 1, diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index d32daad91b..0f988cef55 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -2,7 +2,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::memory::{MemoryAddress, offline_checker::{MemoryBridge, MemoryReadAuxCols}}, }; use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; use openvm_instructions::{LocalOpcode, NATIVE_AS}; @@ -103,8 +103,6 @@ impl Air for NativeSumcheckAir { should_acc, eval_acc, is_hint_src_id, - prod_evals_id, - logup_evals_id, specific, } = local; @@ -348,7 +346,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round, is_hint_src_id, prod_evals_id, logup_evals_id], + [max_round, is_hint_src_id, header_row_specific.prod_evals_id, header_row_specific.logup_evals_id], first_timestamp + AB::F::from_canonical_usize(7), &header_row_specific.read_records[7], ) @@ -391,23 +389,28 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); + // Read p1, p2 from witness arrays self.memory_bridge .read( MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, start_timestamp, - &prod_row_specific.read_records[0], + &MemoryReadAuxCols { + base: prod_row_specific.ps_record.base, + }, ) .eval( builder, (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), ); + + // Obtain p1, p2 from hint space and write back to witness arrays self.memory_bridge .write( MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), prod_row_specific.p, start_timestamp, - &prod_row_specific.write_ps_record, + &prod_row_specific.ps_record, ) .eval( builder, @@ -493,23 +496,28 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); + // Read p1, p2, q1, q2 from witness arrays self.memory_bridge .read( MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, start_timestamp, - &logup_row_specific.read_records[0], + &MemoryReadAuxCols { + base: logup_row_specific.pqs_record.base, + }, ) .eval( builder, (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), ); + + // Obtain p1, p2, q1, q2 from hint space self.memory_bridge .write( MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), logup_row_specific.pq, start_timestamp, - &logup_row_specific.write_pqs_record, + &logup_row_specific.pqs_record, ) .eval( builder, diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 0d707c1458..715912feae 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -241,8 +241,6 @@ where row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; row.is_hint_src_id = is_hint_src_id; - row.prod_evals_id = prod_evals_id; - row.logup_evals_id = logup_evals_id; } // Load hints if source is a ptr @@ -300,7 +298,7 @@ where tracing_read_native_helper( state.memory, prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[0].as_mut(), + prod_specific.ps_record.as_mut(), ) }; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); @@ -314,7 +312,7 @@ where state.memory, prod_evals_ptr.as_canonical_u32() + start, ps, - &mut prod_specific.write_ps_record, + &mut prod_specific.ps_record, ); } @@ -402,7 +400,7 @@ where tracing_read_native_helper( state.memory, logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[0].as_mut(), + logup_specific.pqs_record.as_mut(), ) }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); @@ -418,7 +416,7 @@ where state.memory, logup_evals_ptr.as_canonical_u32() + start, pqs, - &mut logup_specific.write_pqs_record, + &mut logup_specific.pqs_record, ); } @@ -568,21 +566,12 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - if cols.is_hint_src_id == F::ONE { - // write p1, p2 to witness arrays from hint space pointers - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.write_ps_record.as_mut(), - ); - } else { - // read p1, p2 from witness arrays - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.read_records[0].as_mut(), - ); - } + // obtain p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.ps_record.as_mut(), + ); // write p_eval mem_fill_helper( mem_helper, @@ -595,21 +584,12 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - if cols.is_hint_src_id == F::ONE { - // write p1, p2, q1, q2 to witness arrays from hint space pointers - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.write_pqs_record.as_mut(), - ); - } else { - // read p1, p2, q1, q2 from witness arrays - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.read_records[0].as_mut(), - ); - } + // obtain p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.pqs_record.as_mut(), + ); // write p_eval mem_fill_helper( mem_helper, diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 7331400771..f02f154cf2 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -75,9 +75,6 @@ pub struct NativeSumcheckCols { // Indicator for an alternative source form of the inputs prod_evals/logup_evals pub is_hint_src_id: T, - // Pointer ids for hint slices - pub prod_evals_id: T, - pub logup_evals_id: T, // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 @@ -98,6 +95,9 @@ pub struct NativeSumcheckCols { pub struct HeaderSpecificCols { pub pc: T, pub registers: [T; 5], + // Pointer ids for hint slices + pub prod_evals_id: T, + pub logup_evals_id: T, /// 5 register reads + ctx read + max round/hint ptrs read + challenges read pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation @@ -111,14 +111,13 @@ pub struct ProdSpecificCols { pub data_ptr: T, /// 2 extension elements pub p: [T; EXT_DEG * 2], - /// read 2 p values - pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// write p_evals pub write_record: MemoryWriteAuxCols, - /// write p1, p2 values back to witness array if the source is hint space id - pub write_ps_record: MemoryWriteAuxCols, + /// Scenario 1: read p1, p2 values from witness array + /// Scenario 2: write p1, p2 values back to witness array if the source is hint space id + pub ps_record: MemoryWriteAuxCols, /// p_evals * alpha^i pub eval_rlc: [T; EXT_DEG], } @@ -131,15 +130,15 @@ pub struct LogupSpecificCols { /// 4 extension elements pub pq: [T; EXT_DEG * 4], /// read 4 values: p1, p2, q1, q2 - pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// Calculated q evals pub q_evals: [T; EXT_DEG], + /// Scenario 1: read p1, p2, q1, q2 from witness array + /// Scenario 2: write p1, p2, q1, q2 back to witness array if the source is hint space id + pub pqs_record: MemoryWriteAuxCols, /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], - /// write p1, p2, q1, q2 back to witness array if the source is hint space id - pub write_pqs_record: MemoryWriteAuxCols, /// Evaluation for the accumulator pub eval_rlc: [T; EXT_DEG], } From 63552f2b59fe2e66ea3deaa372de405f06d6f7b6 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 3 Feb 2026 17:36:18 -0500 Subject: [PATCH 8/9] restore sumcheck hint ids unit test --- extensions/native/recursion/tests/sumcheck.rs | 96 +++++++++++++++---- 1 file changed, 76 insertions(+), 20 deletions(-) diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 284a103021..6d421f1865 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -25,10 +25,33 @@ pub type F = BabyBear; pub type E = BinomialExtensionField; #[test] -fn test_sumcheck_layer_eval() { +fn test_sumcheck_layer_eval_with_hint_ids() { + let mut rng = thread_rng(); let mut builder = AsmBuilder::>::default(); - build_test_program(&mut builder); + let num_layers = 8; + let num_prod_specs = 6; + let num_logup_specs = 8; + + let prod_evals: Vec = (0..(num_prod_specs * num_layers * 2)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + let logup_evals: Vec = (0..(num_logup_specs * num_layers * 4)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + build_test_program( + &mut builder, + prod_evals.clone(), + logup_evals.clone(), + num_prod_specs, + num_logup_specs, + num_layers, + 3, + ); let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); @@ -49,13 +72,35 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; + let mut input_stream: Vec> = vec![]; + input_stream.push( + prod_evals + .into_iter() + .flat_map(|e| >::as_base_slice(&e).to_vec()) + .collect(), + ); + input_stream.push( + logup_evals + .into_iter() + .flat_map(|e| >::as_base_slice(&e).to_vec()) + .collect(), + ); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); #[cfg(not(feature = "cuda"))] - air_test_impl::(fri_params, vb, config, program, vec![], 1, true) - .unwrap(); + air_test_impl::( + fri_params, + vb, + config, + program, + input_stream, + 1, + true, + ) + .unwrap(); #[cfg(feature = "cuda")] { air_test_impl::( @@ -63,7 +108,7 @@ fn test_sumcheck_layer_eval() { vb, config, program, - vec![], + input_stream, 1, true, ) @@ -71,22 +116,24 @@ fn test_sumcheck_layer_eval() { } } -fn new_rand_ext(rng: &mut R) -> C::EF { - C::EF::from_base_slice(&[ - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), +fn new_rand_ext(rng: &mut R) -> E { + E::from_base_slice(&[ + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), ]) } -fn build_test_program(builder: &mut Builder) { - let mut rng = thread_rng(); - // 6 prod specs in 8 layers, 5 logup specs in 8 layers - let round = 3; - let num_prod_specs = 6; - let num_logup_specs = 5; - let num_layers = 8; +fn build_test_program( + builder: &mut Builder, + prod_evals: Vec, + logup_evals: Vec, + num_prod_specs: usize, + num_logup_specs: usize, + num_layers: usize, + round: usize, +) { let mode = 1; // current_layer let mut ctx_u32s = vec![ @@ -98,6 +145,10 @@ fn build_test_program(builder: &mut Builder) { num_layers, 4, mode, + 999, + 1, + 0, + 0, ]; ctx_u32s.extend(repeat_n(num_layers, num_prod_specs + num_logup_specs)); @@ -128,7 +179,7 @@ fn build_test_program(builder: &mut Builder) { let num_prod_evals = num_prod_specs * num_layers * 2; let prod_spec_evals: Array> = builder.dyn_array(num_prod_evals); for idx in 0..num_prod_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(prod_evals[idx]); builder.set(&prod_spec_evals, idx, e); } @@ -136,7 +187,7 @@ fn build_test_program(builder: &mut Builder) { let num_logup_evals = num_logup_specs * num_layers * 4; let logup_spec_evals: Array> = builder.dyn_array(num_logup_evals); for idx in 0..num_logup_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(logup_evals[idx]); builder.set(&logup_spec_evals, idx, e); } @@ -202,6 +253,11 @@ fn build_test_program(builder: &mut Builder) { .chain(logup_q_evals) .collect::>(); + let prod_spec_evals_id = builder.hint_load(); + let logup_spec_evals_id = builder.hint_load(); + builder.set(&ctx, 10, prod_spec_evals_id); + builder.set(&ctx, 11, logup_spec_evals_id); + let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); builder.sumcheck_layer_eval( From fc9ce55cb175342ba44b3edecb286d7377c15627 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 3 Feb 2026 18:23:15 -0500 Subject: [PATCH 9/9] fix sumcheck test --- extensions/native/recursion/tests/sumcheck.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 6d421f1865..97bbff57c8 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -20,6 +20,7 @@ use openvm_stark_sdk::{ p3_baby_bear::BabyBear, }; use rand::{thread_rng, RngCore}; +const PRIME: u32 = 0x78000001; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -118,10 +119,10 @@ fn test_sumcheck_layer_eval_with_hint_ids() { fn new_rand_ext(rng: &mut R) -> E { E::from_base_slice(&[ - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), ]) } @@ -145,13 +146,11 @@ fn build_test_program( num_layers, 4, mode, - 999, - 1, + 999, // max round + 1, // input from hint ids 0, 0, ]; - ctx_u32s.extend(repeat_n(num_layers, num_prod_specs + num_logup_specs)); - let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { builder.set(&ctx, idx, Usize::from(n));