diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index ef9821f859..63c114193d 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -62,7 +62,7 @@ impl MemoryWriteAuxCols { #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryReadAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, + pub base: MemoryBaseAuxCols, } impl MemoryReadAuxCols { diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index 4f16d4fad7..a67b134aaf 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -7,10 +7,10 @@ using namespace native; template struct HeaderSpecificCols { T pc; - T registers[3]; - T prod_id; - T logup_id; - MemoryReadAuxCols read_records[6]; + T registers[5]; + T prod_evals_id; + T logup_evals_id; + MemoryReadAuxCols read_records[8]; MemoryWriteAuxCols write_records; }; @@ -63,7 +63,7 @@ template struct NativeSumcheckCols { T start_timestamp; T last_timestamp; - T register_ptrs[3]; + T register_ptrs[5]; T ctx[EXT_DEG * 2]; @@ -82,6 +82,8 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; + T is_hint_src_id; + T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index def3db0d1d..f89b93f72d 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { - for (uint32_t i = 0; i < 6; ++i) { + for (uint32_t i = 0; i < 8; ++i) { mem_fill_base( mem_helper, start_timestamp + i, @@ -26,32 +26,32 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - // mem_fill_base( - // mem_helper, - // start_timestamp, - // specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - // ); mem_fill_base( mem_helper, start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) ); } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - // mem_fill_base( - // mem_helper, - // start_timestamp, - // specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) ); mem_fill_base( mem_helper, start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) ); } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index a82141332b..0f988cef55 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -2,7 +2,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::memory::{MemoryAddress, offline_checker::{MemoryBridge, MemoryReadAuxCols}}, }; use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; use openvm_instructions::{LocalOpcode, NATIVE_AS}; @@ -23,9 +23,8 @@ use crate::{ }, }; -pub const TOPLEVEL_TIMESTAMP_DIFF: usize = 6; -pub const NUM_RWS_FOR_PRODUCT: usize = 1; -pub const NUM_RWS_FOR_LOGUP: usize = 2; +pub const NUM_RWS_FOR_PRODUCT: usize = 2; +pub const NUM_RWS_FOR_LOGUP: usize = 3; #[derive(Clone, Debug)] pub struct NativeSumcheckAir { @@ -103,6 +102,7 @@ impl Air for NativeSumcheckAir { within_round_limit, should_acc, eval_acc, + is_hint_src_id, specific, } = local; @@ -230,7 +230,7 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::from_canonical_usize(TOPLEVEL_TIMESTAMP_DIFF), + start_timestamp + AB::F::from_canonical_usize(8), ); builder .when(prod_row) @@ -294,31 +294,27 @@ impl Air for NativeSumcheckAir { .execute_and_increment_pc( AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), [ - registers[2].into(), + registers[4].into(), registers[0].into(), registers[1].into(), native_as.into(), native_as.into(), - header_row_specific.prod_id.into(), - header_row_specific.logup_id.into(), + registers[2].into(), + registers[3].into(), ], ExecutionState::new(header_row_specific.pc, first_timestamp), last_timestamp - first_timestamp, ) .eval(builder, header_row); - let mut header_timestamp_diff = (0..TOPLEVEL_TIMESTAMP_DIFF) - .into_iter() - .map(|i| AB::F::from_canonical_usize(i)); - let mut header_read_records_iter = header_row_specific.read_records.iter(); // Read registers - for i in 0..3usize { + for i in 0..5usize { self.memory_bridge .read( MemoryAddress::new(native_as, registers[i]), [register_ptrs[i]], - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], ) .eval(builder, header_row); } @@ -328,8 +324,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[0]), ctx, - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], ) .eval(builder, header_row); @@ -338,8 +334,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[1]), challenges, - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], ) .eval(builder, header_row); @@ -350,16 +346,16 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round], - first_timestamp + header_timestamp_diff.next().unwrap(), - header_read_records_iter.next().unwrap(), + [max_round, is_hint_src_id, header_row_specific.prod_evals_id, header_row_specific.logup_evals_id], + first_timestamp + AB::F::from_canonical_usize(7), + &header_row_specific.read_records[7], ) .eval(builder, header_row); // Write final result self.memory_bridge .write( - MemoryAddress::new(native_as, register_ptrs[2]), + MemoryAddress::new(native_as, register_ptrs[4]), eval_acc, last_timestamp - AB::F::ONE, &header_row_specific.write_records, @@ -393,17 +389,33 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - let mut prod_timestamp_diff = (0..1).into_iter().map(|i| AB::F::from_canonical_usize(i)); - - // TODO: read from hint space then write to memory - // self.memory_bridge - // .read( - // MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - // prod_row_specific.p, - // start_timestamp, - // &prod_row_specific.read_records[0], - // ) - // .eval(builder, prod_row * within_round_limit); + // Read p1, p2 from witness arrays + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp, + &MemoryReadAuxCols { + base: prod_row_specific.ps_record.base, + }, + ) + .eval( + builder, + (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), + ); + + // Obtain p1, p2 from hint space and write back to witness arrays + self.memory_bridge + .write( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp, + &prod_row_specific.ps_record, + ) + .eval( + builder, + (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id, + ); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -414,16 +426,14 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + prod_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); - assert!(prod_timestamp_diff.next().is_none()); - // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -486,17 +496,33 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - let mut logup_timestamp_diff = (0..2).into_iter().map(|i| AB::F::from_canonical_usize(i)); - - // self.memory_bridge - // .read( - // MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - // logup_row_specific.pq, - // start_timestamp, - // &logup_row_specific.read_records[0], - // ) - // .eval(builder, logup_row * within_round_limit); - + // Read p1, p2, q1, q2 from witness arrays + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp, + &MemoryReadAuxCols { + base: logup_row_specific.pqs_record.base, + }, + ) + .eval( + builder, + (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), + ); + + // Obtain p1, p2, q1, q2 from hint space + self.memory_bridge + .write( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp, + &logup_row_specific.pqs_record, + ) + .eval( + builder, + (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id, + ); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() @@ -513,11 +539,11 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + register_ptrs[4] + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + logup_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -527,12 +553,12 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[2] + register_ptrs[4] + (num_prod_spec + num_logup_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + logup_timestamp_diff.next().unwrap(), + start_timestamp + AB::F::TWO, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 4f7ba79efd..715912feae 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -125,8 +125,8 @@ where c: challenges_reg, d: data_address_space, e: register_address_space, - f: prod_evals_id_ptr, - g: logup_evals_id_ptr, + f: prod_evals_reg, + g: logup_evals_reg, } = instruction; // This opcode supports two modes of operation: @@ -167,49 +167,54 @@ where head_specific.registers[0] = ctx_reg; head_specific.registers[1] = challenges_reg; - head_specific.registers[2] = r_evals_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; - let mut head_read_records_iter = head_specific.read_records.iter_mut().map(|r| r.as_mut()); // read pointers let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, ctx_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[0].as_mut(), ); let [challenges_ptr]: [F; 1] = tracing_read_native_helper( state.memory, challenges_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), ); - let [prod_evals_id]: [F; 1] = - memory_read_native(state.memory.data(), prod_evals_id_ptr.as_canonical_u32()); - let [logup_evals_id]: [F; 1] = - memory_read_native(state.memory.data(), logup_evals_id_ptr.as_canonical_u32()); let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( state.memory, r_evals_reg.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[4].as_mut(), ); let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[5].as_mut(), ); let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), - head_read_records_iter.next().unwrap(), + head_specific.read_records[6].as_mut(), ); - let [max_round]: [F; 1] = tracing_read_native_helper( - state.memory, - ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_read_records_iter.next().unwrap(), - ); - assert!(head_read_records_iter.next().is_none()); - cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read + let [max_round, is_hint_src_id, prod_evals_id, logup_evals_id]: [F; EXT_DEG] = + tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, + head_specific.read_records[7].as_mut(), + ); + cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); - head_specific.prod_id = prod_evals_id_ptr; - head_specific.logup_id = logup_evals_id_ptr; // challenges = [alpha, c1=r, c2=1-r] let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -219,7 +224,7 @@ where let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); - // all rows share same register values, ctx, challenges, max_round + // all rows share same register values, ctx, challenges, max_round, hint_space_ptrs (optional) for row in rows.iter_mut() { // c1, c2 are same during the entire execution row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); @@ -231,14 +236,25 @@ where F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); row.register_ptrs[0] = ctx_ptr; row.register_ptrs[1] = challenges_ptr; - row.register_ptrs[2] = r_evals_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; + row.is_hint_src_id = is_hint_src_id; } + // Load hints if source is a ptr + let is_hint_src_id = is_hint_src_id > F::ZERO; let prod_evals_id = prod_evals_id.as_canonical_u32(); let logup_evals_id = logup_evals_id.as_canonical_u32(); - let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); - let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); + let (prod_evals, logup_evals) = if is_hint_src_id { + ( + state.streams.hint_space[prod_evals_id as usize].clone(), + state.streams.hint_space[logup_evals_id as usize].clone(), + ) + } else { + (Vec::new(), Vec::new()) + }; // product rows for (i, prod_row) in rows @@ -270,17 +286,36 @@ where i as u32, round, 0, - ) as usize; - prod_specific.data_ptr = F::from_canonical_usize(start); - - // read p1, p2 from hint space - let ps: [F; EXT_DEG * 2] = - prod_evals[start..start + EXT_DEG * 2].try_into().unwrap(); + ); + prod_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2 + let ps: [F; EXT_DEG * 2] = if is_hint_src_id { + prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] + .try_into() + .unwrap() + } else { + tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.ps_record.as_mut(), + ) + }; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); prod_specific.p = ps; + // If p values come from the hint stream, write back to the actual witness array + if is_hint_src_id { + tracing_write_native_inplace( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + ps, + &mut prod_specific.ps_record, + ); + } + // compute expected eval let eval = match mode { NEXT_LAYER_MODE => FieldExtension::add( @@ -309,7 +344,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 1; // 1 write + cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -353,12 +388,21 @@ where i as u32, round, 0, - ) as usize; - logup_specific.data_ptr = F::from_canonical_usize(start); - - // read p1, p2, q1, q2 from hint space - let pqs: [F; EXT_DEG * 4] = - logup_evals[start..start + EXT_DEG * 4].try_into().unwrap(); + ); + logup_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { + logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + .try_into() + .unwrap() + } else { + tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.pqs_record.as_mut(), + ) + }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -366,6 +410,16 @@ where logup_specific.pq = pqs; + // write pqs + if is_hint_src_id { + tracing_write_native_inplace( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + pqs, + &mut logup_specific.pqs_record, + ); + } + // compute expected evals let p_eval = match mode { NEXT_LAYER_MODE => FieldExtension::add( @@ -416,7 +470,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 2; // 0 read, 2 writes + cur_timestamp += 3; // 1 read, 2 writes (witness array case) or 3 writes (hint space ptr case) let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -495,7 +549,7 @@ impl TraceFiller for NativeSumcheckFiller { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - for i in 0..6usize { + for i in 0..8usize { mem_fill_helper( mem_helper, start_timestamp + i as u32, @@ -512,16 +566,16 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // TODO: read p1, p2 from hint space, then write to memory - // mem_fill_helper( - // mem_helper, - // start_timestamp, - // prod_row_specific.read_records[0].as_mut(), - // ); - // write p_eval + // obtain p1, p2 mem_fill_helper( mem_helper, start_timestamp, + prod_row_specific.ps_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, prod_row_specific.write_record.as_mut(), ); } @@ -530,22 +584,22 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // TODO: read p1, p2, q1, q2 from hint space - // mem_fill_helper( - // mem_helper, - // start_timestamp, - // logup_row_specific.read_records[0].as_mut(), - // ); - // write p_eval + // obtain p1, p2, q1, q2 mem_fill_helper( mem_helper, start_timestamp, + logup_row_specific.pqs_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, logup_row_specific.write_records[0].as_mut(), ); // write q_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp + 2, logup_row_specific.write_records[1].as_mut(), ); } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 4904e28278..f02f154cf2 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -38,7 +38,7 @@ pub struct NativeSumcheckCols { pub last_timestamp: T, // Register values - pub register_ptrs: [T; 3], + pub register_ptrs: [T; 5], // Context variables // [ @@ -73,6 +73,9 @@ pub struct NativeSumcheckCols { // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], + // Indicator for an alternative source form of the inputs prod_evals/logup_evals + pub is_hint_src_id: T, + // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 // pub read_records: [MemoryReadAuxCols; 7], @@ -91,11 +94,12 @@ pub struct NativeSumcheckCols { #[derive(AlignedBorrow)] pub struct HeaderSpecificCols { pub pc: T, - pub registers: [T; 3], - pub prod_id: T, - pub logup_id: T, - /// 3 register reads + ctx read + max round read + challenges read - pub read_records: [MemoryReadAuxCols; 6], + pub registers: [T; 5], + // Pointer ids for hint slices + pub prod_evals_id: T, + pub logup_evals_id: T, + /// 5 register reads + ctx read + max round/hint ptrs read + challenges read + pub read_records: [MemoryReadAuxCols; 8], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } @@ -107,12 +111,13 @@ pub struct ProdSpecificCols { pub data_ptr: T, /// 2 extension elements pub p: [T; EXT_DEG * 2], - /// read 2 p values - pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// write p_evals pub write_record: MemoryWriteAuxCols, + /// Scenario 1: read p1, p2 values from witness array + /// Scenario 2: write p1, p2 values back to witness array if the source is hint space id + pub ps_record: MemoryWriteAuxCols, /// p_evals * alpha^i pub eval_rlc: [T; EXT_DEG], } @@ -125,11 +130,13 @@ pub struct LogupSpecificCols { /// 4 extension elements pub pq: [T; EXT_DEG * 4], /// read 4 values: p1, p2, q1, q2 - pub read_records: [MemoryReadAuxCols; 1], /// Calculated p evals pub p_evals: [T; EXT_DEG], /// Calculated q evals pub q_evals: [T; EXT_DEG], + /// Scenario 1: read p1, p2, q1, q2 from witness array + /// Scenario 2: write p1, p2, q1, q2 back to witness array if the source is hint space id + pub pqs_record: MemoryWriteAuxCols, /// write both p_evals and q_evals pub write_records: [MemoryWriteAuxCols; 2], /// Evaluation for the accumulator diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 95901f86ef..b4e117135d 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -23,8 +23,8 @@ struct NativeSumcheckPreCompute { r_evals_reg: u32, ctx_reg: u32, challenges_reg: u32, - prod_evals_id_ptr: u32, - logup_evals_id_ptr: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, } impl NativeSumcheckExecutor { @@ -49,8 +49,8 @@ impl NativeSumcheckExecutor { let r_evals_reg = a.as_canonical_u32(); let ctx_reg = b.as_canonical_u32(); let challenges_reg = c.as_canonical_u32(); - let prod_evals_id_ptr = f.as_canonical_u32(); - let logup_evals_id_ptr = g.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); if d.as_canonical_u32() != NATIVE_AS { return Err(StaticProgramError::InvalidInstruction(pc)); @@ -63,8 +63,8 @@ impl NativeSumcheckExecutor { r_evals_reg, ctx_reg, challenges_reg, - prod_evals_id_ptr, - logup_evals_id_ptr, + prod_evals_reg, + logup_evals_reg, }; Ok(()) @@ -199,13 +199,13 @@ unsafe fn execute_e12_impl( let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); - let [prod_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.prod_evals_id_ptr); - let [logup_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.logup_evals_id_ptr); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); - let logup_evals_id = logup_evals_id.as_canonical_u32(); - let prod_evals_id = prod_evals_id.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); let ctx: [u32; 8] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32) @@ -214,8 +214,9 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round]: [u32; 1] = - exec_state.vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32); + let [max_round, is_hint_space_ids, prod_evals_id, logup_evals_id]: [u32; 4] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32) + .map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); @@ -224,8 +225,14 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); - let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); + let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { + ( + exec_state.streams.hint_space[prod_evals_id as usize].clone(), + exec_state.streams.hint_space[logup_evals_id as usize].clone(), + ) + } else { + (Vec::new(), Vec::new()) + }; for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( @@ -234,10 +241,15 @@ unsafe fn execute_e12_impl( i, round, 0, - ) as usize; + ); if round < max_round - 1 { - let ps: &[F] = &prod_evals[start..start + EXT_DEG * 2]; + let ps: &[F] = if is_hint_space_ids > 0 { + &prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2] + } else { + &exec_state.vm_read::<_, { EXT_DEG * 2 }>(NATIVE_AS, prod_evals_ptr + start) + }; + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -271,14 +283,18 @@ unsafe fn execute_e12_impl( i, round, 0, - ) as usize; + ); let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); let alpha_numerator = alpha_acc; if round < max_round - 1 { // read logup_evals - let pqs: &[F] = &logup_evals[start..start + EXT_DEG * 4]; + let pqs: &[F] = if is_hint_space_ids > 0 { + &logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + } else { + &exec_state.vm_read::<_, { EXT_DEG * 4 }>(NATIVE_AS, logup_evals_ptr + start) + }; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 79797ff0ae..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -536,7 +536,7 @@ fn convert_instruction>( AS::Immediate, )] }, - AsmInstruction::SumcheckLayerEval(ctx, cs, prod_id, logup_id, r_ptr) => vec![ + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ Instruction { opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), a: i32_f(r_ptr), @@ -544,8 +544,8 @@ fn convert_instruction>( c: i32_f(cs), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(prod_id), - g: i32_f(logup_id), + f: i32_f(p_ptr), + g: i32_f(l_ptr), } ], }; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index d44bead6f9..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -341,8 +341,8 @@ pub enum DslIr { // 7. Operational mode indicator // 8+. usize-type variables indicating maximum rounds Ptr, // Challenges: alpha, coeffs - Var, // prod_specs_eval - Var, // logup_specs_eval + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval Ptr, // output ), } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 792c9543b2..0237fd6740 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,4 +1,4 @@ -use super::{Array, Builder, Config, DslIr, Ext, Usize, Var}; +use super::{Array, Builder, Config, DslIr, Ext, Usize}; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer @@ -30,16 +30,18 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ - logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ + prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + * from 3D array. */ + logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for - * storing opcode output) */ + * storing opcode output) */ ) { self.operations.push(DslIr::SumcheckLayerEval( input_ctx.ptr(), challenges.ptr(), - prod_specs_eval_id, - logup_specs_eval_id, + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), r_evals.ptr(), )); } diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 145336c362..97bbff57c8 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -20,12 +20,13 @@ use openvm_stark_sdk::{ p3_baby_bear::BabyBear, }; use rand::{thread_rng, RngCore}; +const PRIME: u32 = 0x78000001; pub type F = BabyBear; pub type E = BinomialExtensionField; #[test] -fn test_sumcheck_layer_eval() { +fn test_sumcheck_layer_eval_with_hint_ids() { let mut rng = thread_rng(); let mut builder = AsmBuilder::>::default(); @@ -118,10 +119,10 @@ fn test_sumcheck_layer_eval() { fn new_rand_ext(rng: &mut R) -> E { E::from_base_slice(&[ - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), - F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), + F::from_canonical_u32(rng.next_u32() % PRIME), ]) } @@ -145,9 +146,11 @@ fn build_test_program( num_layers, 4, mode, + 999, // max round + 1, // input from hint ids + 0, + 0, ]; - ctx_u32s.extend(repeat_n(num_layers, num_prod_specs + num_logup_specs)); - let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { builder.set(&ctx, idx, Usize::from(n)); @@ -251,14 +254,16 @@ fn build_test_program( let prod_spec_evals_id = builder.hint_load(); let logup_spec_evals_id = builder.hint_load(); + builder.set(&ctx, 10, prod_spec_evals_id); + builder.set(&ctx, 11, logup_spec_evals_id); let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); builder.sumcheck_layer_eval( &ctx, &challenges, - prod_spec_evals_id, - logup_spec_evals_id, + &prod_spec_evals, + &logup_spec_evals, &next_layer_evals, );