From 279187119f1590edcbac26feed169d5dde08e1d8 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 29 Jan 2026 01:47:39 +0800 Subject: [PATCH 1/4] read prod_evals & logup_evals from hint space --- extensions/native/circuit/src/sumcheck/air.rs | 60 +++++++++--------- .../native/circuit/src/sumcheck/chip.rs | 63 ++++++++----------- .../native/circuit/src/sumcheck/columns.rs | 10 +-- .../native/circuit/src/sumcheck/execution.rs | 31 ++++----- .../native/compiler/src/conversion/mod.rs | 6 +- .../native/compiler/src/ir/instructions.rs | 4 +- extensions/native/compiler/src/ir/sumcheck.rs | 11 ++-- 7 files changed, 90 insertions(+), 95 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 5bec217a1f..2268f77e67 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -288,13 +288,13 @@ impl Air for NativeSumcheckAir { .execute_and_increment_pc( AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), [ - registers[4].into(), + registers[2].into(), registers[0].into(), registers[1].into(), native_as.into(), native_as.into(), - registers[2].into(), - registers[3].into(), + header_row_specific.prod_id.into(), + header_row_specific.logup_id.into(), ], ExecutionState::new(header_row_specific.pc, first_timestamp), last_timestamp - first_timestamp, @@ -302,7 +302,7 @@ impl Air for NativeSumcheckAir { .eval(builder, header_row); // Read registers - for i in 0..5usize { + for i in 0..3usize { self.memory_bridge .read( MemoryAddress::new(native_as, registers[i]), @@ -318,8 +318,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[0]), ctx, - first_timestamp + AB::F::from_canonical_usize(5), - &header_row_specific.read_records[5], + first_timestamp + AB::F::from_canonical_usize(3), + &header_row_specific.read_records[3], ) .eval(builder, header_row); @@ -328,8 +328,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[1]), challenges, - first_timestamp + AB::F::from_canonical_usize(6), - &header_row_specific.read_records[6], + first_timestamp + AB::F::from_canonical_usize(4), + &header_row_specific.read_records[4], ) .eval(builder, header_row); @@ -341,15 +341,15 @@ impl Air for NativeSumcheckAir { register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), [max_round], - first_timestamp + AB::F::from_canonical_usize(7), - &header_row_specific.read_records[7], + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], ) .eval(builder, header_row); // Write final result self.memory_bridge .write( - MemoryAddress::new(native_as, register_ptrs[4]), + MemoryAddress::new(native_as, register_ptrs[2]), eval_acc, last_timestamp - AB::F::ONE, &header_row_specific.write_records, @@ -383,14 +383,14 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - prod_row_specific.p, - start_timestamp, - &prod_row_specific.read_records[0], - ) - .eval(builder, prod_row * within_round_limit); + // self.memory_bridge + // .read( + // MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + // prod_row_specific.p, + // start_timestamp, + // &prod_row_specific.read_records[0], + // ) + // .eval(builder, prod_row * within_round_limit); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -401,7 +401,7 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[2] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, start_timestamp + AB::F::ONE, @@ -471,14 +471,14 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - logup_row_specific.pq, - start_timestamp, - &logup_row_specific.read_records[0], - ) - .eval(builder, logup_row * within_round_limit); + // self.memory_bridge + // .read( + // MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + // logup_row_specific.pq, + // start_timestamp, + // &logup_row_specific.read_records[0], + // ) + // .eval(builder, logup_row * within_round_limit); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] @@ -496,7 +496,7 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + register_ptrs[2] + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, @@ -510,7 +510,7 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + register_ptrs[2] + (num_prod_spec + num_logup_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d5e6f49a62..7f1e0c9b0b 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -125,8 +125,8 @@ where c: challenges_reg, d: data_address_space, e: register_address_space, - f: prod_evals_reg, - g: logup_evals_reg, + f: prod_evals_id_ptr, + g: logup_evals_id_ptr, } = instruction; // This opcode supports two modes of operation: @@ -167,9 +167,7 @@ where head_specific.registers[0] = ctx_reg; head_specific.registers[1] = challenges_reg; - head_specific.registers[2] = prod_evals_reg; - head_specific.registers[3] = logup_evals_reg; - head_specific.registers[4] = r_evals_reg; + head_specific.registers[2] = r_evals_reg; // read pointers let [ctx_ptr]: [F; 1] = tracing_read_native_helper( @@ -182,37 +180,31 @@ where challenges_reg.as_canonical_u32(), head_specific.read_records[1].as_mut(), ); - let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( - state.memory, - prod_evals_reg.as_canonical_u32(), - head_specific.read_records[2].as_mut(), - ); - let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( - state.memory, - logup_evals_reg.as_canonical_u32(), - head_specific.read_records[3].as_mut(), - ); + let [prod_evals_id]: [F; 1] = + memory_read_native(state.memory.data(), prod_evals_id_ptr.as_canonical_u32()); + let [logup_evals_id]: [F; 1] = + memory_read_native(state.memory.data(), logup_evals_id_ptr.as_canonical_u32()); let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( state.memory, r_evals_reg.as_canonical_u32(), - head_specific.read_records[4].as_mut(), + head_specific.read_records[2].as_mut(), ); let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), - head_specific.read_records[5].as_mut(), + head_specific.read_records[3].as_mut(), ); let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), - head_specific.read_records[6].as_mut(), + head_specific.read_records[4].as_mut(), ); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_specific.read_records[7].as_mut(), + head_specific.read_records[5].as_mut(), ); - cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read + cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); // challenges = [alpha, c1=r, c2=1-r] @@ -235,12 +227,15 @@ where F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); row.register_ptrs[0] = ctx_ptr; row.register_ptrs[1] = challenges_ptr; - row.register_ptrs[2] = prod_evals_ptr; - row.register_ptrs[3] = logup_evals_ptr; - row.register_ptrs[4] = r_evals_ptr; + row.register_ptrs[2] = r_evals_ptr; row.max_round = max_round; } + let prod_evals_id = prod_evals_id.as_canonical_u32(); + let logup_evals_id = logup_evals_id.as_canonical_u32(); + let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); + // product rows for (i, prod_row) in rows .iter_mut() @@ -271,15 +266,12 @@ where i as u32, round, 0, - ); - prod_specific.data_ptr = F::from_canonical_u32(start); + ) as usize; + prod_specific.data_ptr = F::from_canonical_usize(start); // read p1, p2 - let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[0].as_mut(), - ); + let ps: [F; EXT_DEG * 2] = + prod_evals[start..start + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -357,15 +349,12 @@ where i as u32, round, 0, - ); - logup_specific.data_ptr = F::from_canonical_u32(start); + ) as usize; + logup_specific.data_ptr = F::from_canonical_usize(start); // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[0].as_mut(), - ); + let pqs: [F; EXT_DEG * 4] = + logup_evals[start..start + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 51eb6d39cf..4904e28278 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -38,7 +38,7 @@ pub struct NativeSumcheckCols { pub last_timestamp: T, // Register values - pub register_ptrs: [T; 5], + pub register_ptrs: [T; 3], // Context variables // [ @@ -91,9 +91,11 @@ pub struct NativeSumcheckCols { #[derive(AlignedBorrow)] pub struct HeaderSpecificCols { pub pc: T, - pub registers: [T; 5], - /// 5 register reads + ctx read + max round read + challenges read - pub read_records: [MemoryReadAuxCols; 8], + pub registers: [T; 3], + pub prod_id: T, + pub logup_id: T, + /// 3 register reads + ctx read + max round read + challenges read + pub read_records: [MemoryReadAuxCols; 6], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 54e06e540c..95901f86ef 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -23,8 +23,8 @@ struct NativeSumcheckPreCompute { r_evals_reg: u32, ctx_reg: u32, challenges_reg: u32, - prod_evals_reg: u32, - logup_evals_reg: u32, + prod_evals_id_ptr: u32, + logup_evals_id_ptr: u32, } impl NativeSumcheckExecutor { @@ -49,8 +49,8 @@ impl NativeSumcheckExecutor { let r_evals_reg = a.as_canonical_u32(); let ctx_reg = b.as_canonical_u32(); let challenges_reg = c.as_canonical_u32(); - let prod_evals_reg = f.as_canonical_u32(); - let logup_evals_reg = g.as_canonical_u32(); + let prod_evals_id_ptr = f.as_canonical_u32(); + let logup_evals_id_ptr = g.as_canonical_u32(); if d.as_canonical_u32() != NATIVE_AS { return Err(StaticProgramError::InvalidInstruction(pc)); @@ -63,8 +63,8 @@ impl NativeSumcheckExecutor { r_evals_reg, ctx_reg, challenges_reg, - prod_evals_reg, - logup_evals_reg, + prod_evals_id_ptr, + logup_evals_id_ptr, }; Ok(()) @@ -199,13 +199,13 @@ unsafe fn execute_e12_impl( let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); - let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); - let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + let [prod_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.prod_evals_id_ptr); + let [logup_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.logup_evals_id_ptr); let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); - let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); - let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + let logup_evals_id = logup_evals_id.as_canonical_u32(); + let prod_evals_id = prod_evals_id.as_canonical_u32(); let ctx: [u32; 8] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32) @@ -224,6 +224,9 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); + let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); + for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -231,10 +234,10 @@ unsafe fn execute_e12_impl( i, round, 0, - ); + ) as usize; if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let ps: &[F] = &prod_evals[start..start + EXT_DEG * 2]; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -268,14 +271,14 @@ unsafe fn execute_e12_impl( i, round, 0, - ); + ) as usize; let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); let alpha_numerator = alpha_acc; if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let pqs: &[F] = &logup_evals[start..start + EXT_DEG * 4]; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index e71608c190..79797ff0ae 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -536,7 +536,7 @@ fn convert_instruction>( AS::Immediate, )] }, - AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + AsmInstruction::SumcheckLayerEval(ctx, cs, prod_id, logup_id, r_ptr) => vec![ Instruction { opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), a: i32_f(r_ptr), @@ -544,8 +544,8 @@ fn convert_instruction>( c: i32_f(cs), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(p_ptr), - g: i32_f(l_ptr), + f: i32_f(prod_id), + g: i32_f(logup_id), } ], }; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 78347283d5..d44bead6f9 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -341,8 +341,8 @@ pub enum DslIr { // 7. Operational mode indicator // 8+. usize-type variables indicating maximum rounds Ptr, // Challenges: alpha, coeffs - Ptr, // prod_specs_eval - Ptr, // logup_specs_eval + Var, // prod_specs_eval + Var, // logup_specs_eval Ptr, // output ), } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 0237fd6740..e79dac08b3 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,4 +1,5 @@ use super::{Array, Builder, Config, DslIr, Ext, Usize}; +use crate::ir::Var; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer @@ -30,18 +31,18 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + prod_specs_eval_id: Var, /* GKR product IOP evaluations. Flattened * from 3D array. */ - logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened - * from 3D array. */ + logup_specs_eval_id: Var, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for * storing opcode output) */ ) { self.operations.push(DslIr::SumcheckLayerEval( input_ctx.ptr(), challenges.ptr(), - prod_specs_eval.ptr(), - logup_specs_eval.ptr(), + prod_specs_eval_id, + logup_specs_eval_id, r_evals.ptr(), )); } From 5cdc6bad1632dda3aa0f1a6696c91dac31b5f20d Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 29 Jan 2026 13:47:17 +0800 Subject: [PATCH 2/4] wip --- extensions/native/circuit/src/sumcheck/air.rs | 24 +++-- .../native/circuit/src/sumcheck/chip.rs | 16 ++-- extensions/native/compiler/src/ir/sumcheck.rs | 9 +- extensions/native/recursion/tests/sumcheck.rs | 90 ++++++++++++++----- 4 files changed, 97 insertions(+), 42 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 2268f77e67..dc05bb1130 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -23,6 +23,8 @@ use crate::{ }, }; +pub const TOPLEVEL_TIMESTAMP_DIFF: usize = 6; + #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -226,7 +228,7 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::from_canonical_usize(8), + start_timestamp + AB::F::from_canonical_usize(TOPLEVEL_TIMESTAMP_DIFF), ); builder .when(prod_row) @@ -301,14 +303,18 @@ impl Air for NativeSumcheckAir { ) .eval(builder, header_row); + let mut header_timestamp_diff = (0..TOPLEVEL_TIMESTAMP_DIFF) + .into_iter() + .map(|i| AB::F::from_canonical_usize(i)); + let mut header_read_records_iter = header_row_specific.read_records.iter(); // Read registers for i in 0..3usize { self.memory_bridge .read( MemoryAddress::new(native_as, registers[i]), [register_ptrs[i]], - first_timestamp + AB::F::from_canonical_usize(i), - &header_row_specific.read_records[i], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); } @@ -318,8 +324,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[0]), ctx, - first_timestamp + AB::F::from_canonical_usize(3), - &header_row_specific.read_records[3], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); @@ -328,8 +334,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[1]), challenges, - first_timestamp + AB::F::from_canonical_usize(4), - &header_row_specific.read_records[4], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); @@ -341,8 +347,8 @@ impl Air for NativeSumcheckAir { register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), [max_round], - first_timestamp + AB::F::from_canonical_usize(5), - &header_row_specific.read_records[5], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 7f1e0c9b0b..79c3763925 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -169,16 +169,17 @@ where head_specific.registers[1] = challenges_reg; head_specific.registers[2] = r_evals_reg; + let mut head_read_records_iter = head_specific.read_records.iter_mut().map(|r| r.as_mut()); // read pointers let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, ctx_reg.as_canonical_u32(), - head_specific.read_records[0].as_mut(), + head_read_records_iter.next().unwrap(), ); let [challenges_ptr]: [F; 1] = tracing_read_native_helper( state.memory, challenges_reg.as_canonical_u32(), - head_specific.read_records[1].as_mut(), + head_read_records_iter.next().unwrap(), ); let [prod_evals_id]: [F; 1] = memory_read_native(state.memory.data(), prod_evals_id_ptr.as_canonical_u32()); @@ -187,23 +188,24 @@ where let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( state.memory, r_evals_reg.as_canonical_u32(), - head_specific.read_records[2].as_mut(), + head_read_records_iter.next().unwrap(), ); let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), - head_specific.read_records[3].as_mut(), + head_read_records_iter.next().unwrap(), ); let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), - head_specific.read_records[4].as_mut(), + head_read_records_iter.next().unwrap(), ); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_specific.read_records[5].as_mut(), + head_read_records_iter.next().unwrap(), ); + assert!(head_read_records_iter.next().is_none()); cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); @@ -491,7 +493,7 @@ impl TraceFiller for NativeSumcheckFiller { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - for i in 0..8usize { + for i in 0..6usize { mem_fill_helper( mem_helper, start_timestamp + i as u32, diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index e79dac08b3..dd5bfd2e26 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,5 +1,4 @@ -use super::{Array, Builder, Config, DslIr, Ext, Usize}; -use crate::ir::Var; +use super::{Array, Builder, Config, DslIr, Ext, Usize, Var}; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer @@ -31,10 +30,8 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval_id: Var, /* GKR product IOP evaluations. Flattened - * from 3D array. */ - logup_specs_eval_id: Var, /* GKR logup IOP evaluations. Flattened - * from 3D array. */ + prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ + logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for * storing opcode output) */ ) { diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 284a103021..c7b1b95720 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -26,9 +26,32 @@ pub type E = BinomialExtensionField; #[test] fn test_sumcheck_layer_eval() { + let mut rng = thread_rng(); let mut builder = AsmBuilder::>::default(); - build_test_program(&mut builder); + let num_layers = 8; + let num_prod_specs = 6; + let num_logup_specs = 8; + + let prod_evals: Vec = (0..(num_prod_specs * num_layers * 2)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + let logup_evals: Vec = (0..(num_logup_specs * num_layers * 4)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + build_test_program( + &mut builder, + prod_evals.clone(), + logup_evals.clone(), + num_prod_specs, + num_logup_specs, + num_layers, + 3, + ); let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); @@ -49,13 +72,35 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; + let mut input_stream = vec![]; + input_stream.push( + prod_evals + .into_iter() + .flat_map(|e| e.as_base_slice().iter().cloned().collect::>()) + .collect(), + ); + input_stream.push( + logup_evals + .into_iter() + .flat_map(|e| e.as_base_slice().iter().cloned().collect::>()) + .collect(), + ); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); #[cfg(not(feature = "cuda"))] - air_test_impl::(fri_params, vb, config, program, vec![], 1, true) - .unwrap(); + air_test_impl::( + fri_params, + vb, + config, + program, + input_stream, + 1, + true, + ) + .unwrap(); #[cfg(feature = "cuda")] { air_test_impl::( @@ -71,22 +116,24 @@ fn test_sumcheck_layer_eval() { } } -fn new_rand_ext(rng: &mut R) -> C::EF { - C::EF::from_base_slice(&[ - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), +fn new_rand_ext(rng: &mut R) -> E { + E::from_base_slice(&[ + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), ]) } -fn build_test_program(builder: &mut Builder) { - let mut rng = thread_rng(); - // 6 prod specs in 8 layers, 5 logup specs in 8 layers - let round = 3; - let num_prod_specs = 6; - let num_logup_specs = 5; - let num_layers = 8; +fn build_test_program( + builder: &mut Builder, + prod_evals: Vec, + logup_evals: Vec, + num_prod_specs: usize, + num_logup_specs: usize, + num_layers: usize, + round: usize, +) { let mode = 1; // current_layer let mut ctx_u32s = vec![ @@ -128,7 +175,7 @@ fn build_test_program(builder: &mut Builder) { let num_prod_evals = num_prod_specs * num_layers * 2; let prod_spec_evals: Array> = builder.dyn_array(num_prod_evals); for idx in 0..num_prod_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(prod_evals[idx]); builder.set(&prod_spec_evals, idx, e); } @@ -136,7 +183,7 @@ fn build_test_program(builder: &mut Builder) { let num_logup_evals = num_logup_specs * num_layers * 4; let logup_spec_evals: Array> = builder.dyn_array(num_logup_evals); for idx in 0..num_logup_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(logup_evals[idx]); builder.set(&logup_spec_evals, idx, e); } @@ -202,13 +249,16 @@ fn build_test_program(builder: &mut Builder) { .chain(logup_q_evals) .collect::>(); + let prod_spec_evals_id = builder.hint_load(); + let logup_spec_evals_id = builder.hint_load(); + let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); builder.sumcheck_layer_eval( &ctx, &challenges, - &prod_spec_evals, - &logup_spec_evals, + prod_spec_evals_id, + logup_spec_evals_id, &next_layer_evals, ); From a25d99e798786908e06d4abc5abcab9372b94a5a Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 29 Jan 2026 14:20:23 +0800 Subject: [PATCH 3/4] unit test passed --- extensions/native/circuit/src/sumcheck/air.rs | 21 +++++++--- .../native/circuit/src/sumcheck/chip.rs | 40 ++++++++++--------- extensions/native/compiler/src/ir/sumcheck.rs | 6 +-- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index dc05bb1130..a82141332b 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -24,6 +24,8 @@ use crate::{ }; pub const TOPLEVEL_TIMESTAMP_DIFF: usize = 6; +pub const NUM_RWS_FOR_PRODUCT: usize = 1; +pub const NUM_RWS_FOR_LOGUP: usize = 2; #[derive(Clone, Debug)] pub struct NativeSumcheckAir { @@ -235,14 +237,16 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::TWO, + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::from_canonical_usize(3), + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), ); // Termination condition @@ -389,6 +393,9 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); + let mut prod_timestamp_diff = (0..1).into_iter().map(|i| AB::F::from_canonical_usize(i)); + + // TODO: read from hint space then write to memory // self.memory_bridge // .read( // MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), @@ -410,11 +417,13 @@ impl Air for NativeSumcheckAir { register_ptrs[2] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + prod_timestamp_diff.next().unwrap(), &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); + assert!(prod_timestamp_diff.next().is_none()); + // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -477,6 +486,8 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); + let mut logup_timestamp_diff = (0..2).into_iter().map(|i| AB::F::from_canonical_usize(i)); + // self.memory_bridge // .read( // MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), @@ -506,7 +517,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + logup_timestamp_diff.next().unwrap(), &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -521,7 +532,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::TWO, + start_timestamp + logup_timestamp_diff.next().unwrap(), &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 79c3763925..4f7ba79efd 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -208,6 +208,8 @@ where assert!(head_read_records_iter.next().is_none()); cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); + head_specific.prod_id = prod_evals_id_ptr; + head_specific.logup_id = logup_evals_id_ptr; // challenges = [alpha, c1=r, c2=1-r] let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -271,7 +273,7 @@ where ) as usize; prod_specific.data_ptr = F::from_canonical_usize(start); - // read p1, p2 + // read p1, p2 from hint space let ps: [F; EXT_DEG * 2] = prod_evals[start..start + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); @@ -307,7 +309,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; + cur_timestamp += 1; // 1 write let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -354,7 +356,7 @@ where ) as usize; logup_specific.data_ptr = F::from_canonical_usize(start); - // read p1, p2, q1, q2 + // read p1, p2, q1, q2 from hint space let pqs: [F; EXT_DEG * 4] = logup_evals[start..start + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); @@ -414,7 +416,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes + cur_timestamp += 2; // 0 read, 2 writes let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -510,16 +512,16 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.read_records[0].as_mut(), - ); + // TODO: read p1, p2 from hint space, then write to memory + // mem_fill_helper( + // mem_helper, + // start_timestamp, + // prod_row_specific.read_records[0].as_mut(), + // ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp, prod_row_specific.write_record.as_mut(), ); } @@ -528,22 +530,22 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.read_records[0].as_mut(), - ); + // TODO: read p1, p2, q1, q2 from hint space + // mem_fill_helper( + // mem_helper, + // start_timestamp, + // logup_row_specific.read_records[0].as_mut(), + // ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp, logup_row_specific.write_records[0].as_mut(), ); // write q_eval mem_fill_helper( mem_helper, - start_timestamp + 2, + start_timestamp + 1, logup_row_specific.write_records[1].as_mut(), ); } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index dd5bfd2e26..792c9543b2 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -30,10 +30,10 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ - logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ + prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ + logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for - * storing opcode output) */ + * storing opcode output) */ ) { self.operations.push(DslIr::SumcheckLayerEval( input_ctx.ptr(), From d2179dae1cf3edd12a0785664c99cf976c36e1c1 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 29 Jan 2026 14:52:17 +0800 Subject: [PATCH 4/4] gpu unit test pass --- .../circuit/cuda/include/native/sumcheck.cuh | 8 ++++--- .../native/circuit/cuda/src/sumcheck.cu | 24 +++++++++---------- extensions/native/recursion/tests/sumcheck.rs | 8 +++---- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index da7aefd769..4f16d4fad7 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -7,8 +7,10 @@ using namespace native; template struct HeaderSpecificCols { T pc; - T registers[5]; - MemoryReadAuxCols read_records[8]; + T registers[3]; + T prod_id; + T logup_id; + MemoryReadAuxCols read_records[6]; MemoryWriteAuxCols write_records; }; @@ -61,7 +63,7 @@ template struct NativeSumcheckCols { T start_timestamp; T last_timestamp; - T register_ptrs[5]; + T register_ptrs[3]; T ctx[EXT_DEG * 2]; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 173abebbe2..def3db0d1d 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { - for (uint32_t i = 0; i < 8; ++i) { + for (uint32_t i = 0; i < 6; ++i) { mem_fill_base( mem_helper, start_timestamp + i, @@ -26,32 +26,32 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + // mem_fill_base( + // mem_helper, + // start_timestamp, + // specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) ); } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + // mem_fill_base( + // mem_helper, + // start_timestamp, + // specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) ); mem_fill_base( mem_helper, - start_timestamp + 2, + start_timestamp + 1, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) ); } diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index c7b1b95720..145336c362 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -72,17 +72,17 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; - let mut input_stream = vec![]; + let mut input_stream: Vec> = vec![]; input_stream.push( prod_evals .into_iter() - .flat_map(|e| e.as_base_slice().iter().cloned().collect::>()) + .flat_map(|e| >::as_base_slice(&e).to_vec()) .collect(), ); input_stream.push( logup_evals .into_iter() - .flat_map(|e| e.as_base_slice().iter().cloned().collect::>()) + .flat_map(|e| >::as_base_slice(&e).to_vec()) .collect(), ); @@ -108,7 +108,7 @@ fn test_sumcheck_layer_eval() { vb, config, program, - vec![], + input_stream, 1, true, )