diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index da7aefd769..4f16d4fad7 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -7,8 +7,10 @@ using namespace native; template struct HeaderSpecificCols { T pc; - T registers[5]; - MemoryReadAuxCols read_records[8]; + T registers[3]; + T prod_id; + T logup_id; + MemoryReadAuxCols read_records[6]; MemoryWriteAuxCols write_records; }; @@ -61,7 +63,7 @@ template struct NativeSumcheckCols { T start_timestamp; T last_timestamp; - T register_ptrs[5]; + T register_ptrs[3]; T ctx[EXT_DEG * 2]; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 173abebbe2..def3db0d1d 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -11,7 +11,7 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { - for (uint32_t i = 0; i < 8; ++i) { + for (uint32_t i = 0; i < 6; ++i) { mem_fill_base( mem_helper, start_timestamp + i, @@ -26,32 +26,32 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + // mem_fill_base( + // mem_helper, + // start_timestamp, + // specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) ); } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + // mem_fill_base( + // mem_helper, + // start_timestamp, + // specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + // ); mem_fill_base( mem_helper, start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) ); mem_fill_base( mem_helper, - start_timestamp + 2, + start_timestamp + 1, specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) ); } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 5bec217a1f..a82141332b 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -23,6 +23,10 @@ use crate::{ }, }; +pub const TOPLEVEL_TIMESTAMP_DIFF: usize = 6; +pub const NUM_RWS_FOR_PRODUCT: usize = 1; +pub const NUM_RWS_FOR_LOGUP: usize = 2; + #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, @@ -226,21 +230,23 @@ impl Air for NativeSumcheckAir { .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + AB::F::from_canonical_usize(8), + start_timestamp + AB::F::from_canonical_usize(TOPLEVEL_TIMESTAMP_DIFF), ); builder .when(prod_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::TWO, + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), ); builder .when(logup_row) .when(next.prod_row + next.logup_row) .assert_eq( next.start_timestamp, - start_timestamp + within_round_limit * AB::F::from_canonical_usize(3), + start_timestamp + + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), ); // Termination condition @@ -288,27 +294,31 @@ impl Air for NativeSumcheckAir { .execute_and_increment_pc( AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), [ - registers[4].into(), + registers[2].into(), registers[0].into(), registers[1].into(), native_as.into(), native_as.into(), - registers[2].into(), - registers[3].into(), + header_row_specific.prod_id.into(), + header_row_specific.logup_id.into(), ], ExecutionState::new(header_row_specific.pc, first_timestamp), last_timestamp - first_timestamp, ) .eval(builder, header_row); + let mut header_timestamp_diff = (0..TOPLEVEL_TIMESTAMP_DIFF) + .into_iter() + .map(|i| AB::F::from_canonical_usize(i)); + let mut header_read_records_iter = header_row_specific.read_records.iter(); // Read registers - for i in 0..5usize { + for i in 0..3usize { self.memory_bridge .read( MemoryAddress::new(native_as, registers[i]), [register_ptrs[i]], - first_timestamp + AB::F::from_canonical_usize(i), - &header_row_specific.read_records[i], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); } @@ -318,8 +328,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[0]), ctx, - first_timestamp + AB::F::from_canonical_usize(5), - &header_row_specific.read_records[5], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); @@ -328,8 +338,8 @@ impl Air for NativeSumcheckAir { .read( MemoryAddress::new(native_as, register_ptrs[1]), challenges, - first_timestamp + AB::F::from_canonical_usize(6), - &header_row_specific.read_records[6], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); @@ -341,15 +351,15 @@ impl Air for NativeSumcheckAir { register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), [max_round], - first_timestamp + AB::F::from_canonical_usize(7), - &header_row_specific.read_records[7], + first_timestamp + header_timestamp_diff.next().unwrap(), + header_read_records_iter.next().unwrap(), ) .eval(builder, header_row); // Write final result self.memory_bridge .write( - MemoryAddress::new(native_as, register_ptrs[4]), + MemoryAddress::new(native_as, register_ptrs[2]), eval_acc, last_timestamp - AB::F::ONE, &header_row_specific.write_records, @@ -383,14 +393,17 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - prod_row_specific.p, - start_timestamp, - &prod_row_specific.read_records[0], - ) - .eval(builder, prod_row * within_round_limit); + let mut prod_timestamp_diff = (0..1).into_iter().map(|i| AB::F::from_canonical_usize(i)); + + // TODO: read from hint space then write to memory + // self.memory_bridge + // .read( + // MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + // prod_row_specific.p, + // start_timestamp, + // &prod_row_specific.read_records[0], + // ) + // .eval(builder, prod_row * within_round_limit); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] @@ -401,14 +414,16 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + register_ptrs[2] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + prod_timestamp_diff.next().unwrap(), &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); + assert!(prod_timestamp_diff.next().is_none()); + // Calculate evaluations let next_round_p_evals = FieldExtension::add( FieldExtension::multiply::(p1, c1), @@ -471,14 +486,16 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - logup_row_specific.pq, - start_timestamp, - &logup_row_specific.read_records[0], - ) - .eval(builder, logup_row * within_round_limit); + let mut logup_timestamp_diff = (0..2).into_iter().map(|i| AB::F::from_canonical_usize(i)); + + // self.memory_bridge + // .read( + // MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + // logup_row_specific.pq, + // start_timestamp, + // &logup_row_specific.read_records[0], + // ) + // .eval(builder, logup_row * within_round_limit); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] @@ -496,11 +513,11 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + register_ptrs[2] + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + logup_timestamp_diff.next().unwrap(), &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -510,12 +527,12 @@ impl Air for NativeSumcheckAir { .write( MemoryAddress::new( native_as, - register_ptrs[4] + register_ptrs[2] + (num_prod_spec + num_logup_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::TWO, + start_timestamp + logup_timestamp_diff.next().unwrap(), &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d5e6f49a62..4f7ba79efd 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -125,8 +125,8 @@ where c: challenges_reg, d: data_address_space, e: register_address_space, - f: prod_evals_reg, - g: logup_evals_reg, + f: prod_evals_id_ptr, + g: logup_evals_id_ptr, } = instruction; // This opcode supports two modes of operation: @@ -167,53 +167,49 @@ where head_specific.registers[0] = ctx_reg; head_specific.registers[1] = challenges_reg; - head_specific.registers[2] = prod_evals_reg; - head_specific.registers[3] = logup_evals_reg; - head_specific.registers[4] = r_evals_reg; + head_specific.registers[2] = r_evals_reg; + let mut head_read_records_iter = head_specific.read_records.iter_mut().map(|r| r.as_mut()); // read pointers let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, ctx_reg.as_canonical_u32(), - head_specific.read_records[0].as_mut(), + head_read_records_iter.next().unwrap(), ); let [challenges_ptr]: [F; 1] = tracing_read_native_helper( state.memory, challenges_reg.as_canonical_u32(), - head_specific.read_records[1].as_mut(), - ); - let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( - state.memory, - prod_evals_reg.as_canonical_u32(), - head_specific.read_records[2].as_mut(), - ); - let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( - state.memory, - logup_evals_reg.as_canonical_u32(), - head_specific.read_records[3].as_mut(), + head_read_records_iter.next().unwrap(), ); + let [prod_evals_id]: [F; 1] = + memory_read_native(state.memory.data(), prod_evals_id_ptr.as_canonical_u32()); + let [logup_evals_id]: [F; 1] = + memory_read_native(state.memory.data(), logup_evals_id_ptr.as_canonical_u32()); let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( state.memory, r_evals_reg.as_canonical_u32(), - head_specific.read_records[4].as_mut(), + head_read_records_iter.next().unwrap(), ); let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32(), - head_specific.read_records[5].as_mut(), + head_read_records_iter.next().unwrap(), ); let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( state.memory, challenges_ptr.as_canonical_u32(), - head_specific.read_records[6].as_mut(), + head_read_records_iter.next().unwrap(), ); let [max_round]: [F; 1] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, - head_specific.read_records[7].as_mut(), + head_read_records_iter.next().unwrap(), ); - cur_timestamp += 8; // 5 register reads + ctx read + challenges read + max_round read + assert!(head_read_records_iter.next().is_none()); + cur_timestamp += 6; // 3 register reads + ctx read + challenges read + max_round read head_row.challenges.copy_from_slice(&challenges); + head_specific.prod_id = prod_evals_id_ptr; + head_specific.logup_id = logup_evals_id_ptr; // challenges = [alpha, c1=r, c2=1-r] let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -235,12 +231,15 @@ where F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); row.register_ptrs[0] = ctx_ptr; row.register_ptrs[1] = challenges_ptr; - row.register_ptrs[2] = prod_evals_ptr; - row.register_ptrs[3] = logup_evals_ptr; - row.register_ptrs[4] = r_evals_ptr; + row.register_ptrs[2] = r_evals_ptr; row.max_round = max_round; } + let prod_evals_id = prod_evals_id.as_canonical_u32(); + let logup_evals_id = logup_evals_id.as_canonical_u32(); + let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); + // product rows for (i, prod_row) in rows .iter_mut() @@ -271,15 +270,12 @@ where i as u32, round, 0, - ); - prod_specific.data_ptr = F::from_canonical_u32(start); + ) as usize; + prod_specific.data_ptr = F::from_canonical_usize(start); - // read p1, p2 - let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.read_records[0].as_mut(), - ); + // read p1, p2 from hint space + let ps: [F; EXT_DEG * 2] = + prod_evals[start..start + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); @@ -313,7 +309,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; + cur_timestamp += 1; // 1 write let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -357,15 +353,12 @@ where i as u32, round, 0, - ); - logup_specific.data_ptr = F::from_canonical_u32(start); + ) as usize; + logup_specific.data_ptr = F::from_canonical_usize(start); - // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.read_records[0].as_mut(), - ); + // read p1, p2, q1, q2 from hint space + let pqs: [F; EXT_DEG * 4] = + logup_evals[start..start + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -423,7 +416,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes + cur_timestamp += 2; // 0 read, 2 writes let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -502,7 +495,7 @@ impl TraceFiller for NativeSumcheckFiller { let header: &mut HeaderSpecificCols = cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); - for i in 0..8usize { + for i in 0..6usize { mem_fill_helper( mem_helper, start_timestamp + i as u32, @@ -519,16 +512,16 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.read_records[0].as_mut(), - ); + // TODO: read p1, p2 from hint space, then write to memory + // mem_fill_helper( + // mem_helper, + // start_timestamp, + // prod_row_specific.read_records[0].as_mut(), + // ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp, prod_row_specific.write_record.as_mut(), ); } @@ -537,22 +530,22 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // read p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.read_records[0].as_mut(), - ); + // TODO: read p1, p2, q1, q2 from hint space + // mem_fill_helper( + // mem_helper, + // start_timestamp, + // logup_row_specific.read_records[0].as_mut(), + // ); // write p_eval mem_fill_helper( mem_helper, - start_timestamp + 1, + start_timestamp, logup_row_specific.write_records[0].as_mut(), ); // write q_eval mem_fill_helper( mem_helper, - start_timestamp + 2, + start_timestamp + 1, logup_row_specific.write_records[1].as_mut(), ); } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index 51eb6d39cf..4904e28278 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -38,7 +38,7 @@ pub struct NativeSumcheckCols { pub last_timestamp: T, // Register values - pub register_ptrs: [T; 5], + pub register_ptrs: [T; 3], // Context variables // [ @@ -91,9 +91,11 @@ pub struct NativeSumcheckCols { #[derive(AlignedBorrow)] pub struct HeaderSpecificCols { pub pc: T, - pub registers: [T; 5], - /// 5 register reads + ctx read + max round read + challenges read - pub read_records: [MemoryReadAuxCols; 8], + pub registers: [T; 3], + pub prod_id: T, + pub logup_id: T, + /// 3 register reads + ctx read + max round read + challenges read + pub read_records: [MemoryReadAuxCols; 6], /// Write the final evaluation pub write_records: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 54e06e540c..95901f86ef 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -23,8 +23,8 @@ struct NativeSumcheckPreCompute { r_evals_reg: u32, ctx_reg: u32, challenges_reg: u32, - prod_evals_reg: u32, - logup_evals_reg: u32, + prod_evals_id_ptr: u32, + logup_evals_id_ptr: u32, } impl NativeSumcheckExecutor { @@ -49,8 +49,8 @@ impl NativeSumcheckExecutor { let r_evals_reg = a.as_canonical_u32(); let ctx_reg = b.as_canonical_u32(); let challenges_reg = c.as_canonical_u32(); - let prod_evals_reg = f.as_canonical_u32(); - let logup_evals_reg = g.as_canonical_u32(); + let prod_evals_id_ptr = f.as_canonical_u32(); + let logup_evals_id_ptr = g.as_canonical_u32(); if d.as_canonical_u32() != NATIVE_AS { return Err(StaticProgramError::InvalidInstruction(pc)); @@ -63,8 +63,8 @@ impl NativeSumcheckExecutor { r_evals_reg, ctx_reg, challenges_reg, - prod_evals_reg, - logup_evals_reg, + prod_evals_id_ptr, + logup_evals_id_ptr, }; Ok(()) @@ -199,13 +199,13 @@ unsafe fn execute_e12_impl( let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); - let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); - let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + let [prod_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.prod_evals_id_ptr); + let [logup_evals_id]: [F; 1] = exec_state.host_read(NATIVE_AS, pre_compute.logup_evals_id_ptr); let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); - let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); - let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + let logup_evals_id = logup_evals_id.as_canonical_u32(); + let prod_evals_id = prod_evals_id.as_canonical_u32(); let ctx: [u32; 8] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32) @@ -224,6 +224,9 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); + let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); + for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( prod_specs_inner_inner_len, @@ -231,10 +234,10 @@ unsafe fn execute_e12_impl( i, round, 0, - ); + ) as usize; if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let ps: &[F] = &prod_evals[start..start + EXT_DEG * 2]; let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); @@ -268,14 +271,14 @@ unsafe fn execute_e12_impl( i, round, 0, - ); + ) as usize; let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); let alpha_numerator = alpha_acc; if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let pqs: &[F] = &logup_evals[start..start + EXT_DEG * 4]; let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index e71608c190..79797ff0ae 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -536,7 +536,7 @@ fn convert_instruction>( AS::Immediate, )] }, - AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + AsmInstruction::SumcheckLayerEval(ctx, cs, prod_id, logup_id, r_ptr) => vec![ Instruction { opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), a: i32_f(r_ptr), @@ -544,8 +544,8 @@ fn convert_instruction>( c: i32_f(cs), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(p_ptr), - g: i32_f(l_ptr), + f: i32_f(prod_id), + g: i32_f(logup_id), } ], }; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 78347283d5..d44bead6f9 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -341,8 +341,8 @@ pub enum DslIr { // 7. Operational mode indicator // 8+. usize-type variables indicating maximum rounds Ptr, // Challenges: alpha, coeffs - Ptr, // prod_specs_eval - Ptr, // logup_specs_eval + Var, // prod_specs_eval + Var, // logup_specs_eval Ptr, // output ), } diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs index 0237fd6740..792c9543b2 100644 --- a/extensions/native/compiler/src/ir/sumcheck.rs +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -1,4 +1,4 @@ -use super::{Array, Builder, Config, DslIr, Ext, Usize}; +use super::{Array, Builder, Config, DslIr, Ext, Usize, Var}; impl Builder { /// Extends native VM ability to calculate the evaluation for a sumcheck layer @@ -30,18 +30,16 @@ impl Builder { &mut self, input_ctx: &Array>, // Context variables challenges: &Array>, // Challenges - prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened - * from 3D array. */ - logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened - * from 3D array. */ + prod_specs_eval_id: Var, /* ID for GKR product IOP evaluations hint. */ + logup_specs_eval_id: Var, /* ID for GKR logup IOP evaluations hint. */ r_evals: &Array>, /* Next layer's evaluations (pointer used for - * storing opcode output) */ + * storing opcode output) */ ) { self.operations.push(DslIr::SumcheckLayerEval( input_ctx.ptr(), challenges.ptr(), - prod_specs_eval.ptr(), - logup_specs_eval.ptr(), + prod_specs_eval_id, + logup_specs_eval_id, r_evals.ptr(), )); } diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 284a103021..145336c362 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -26,9 +26,32 @@ pub type E = BinomialExtensionField; #[test] fn test_sumcheck_layer_eval() { + let mut rng = thread_rng(); let mut builder = AsmBuilder::>::default(); - build_test_program(&mut builder); + let num_layers = 8; + let num_prod_specs = 6; + let num_logup_specs = 8; + + let prod_evals: Vec = (0..(num_prod_specs * num_layers * 2)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + let logup_evals: Vec = (0..(num_logup_specs * num_layers * 4)) + .into_iter() + .map(|_| new_rand_ext(&mut rng)) + .collect(); + + build_test_program( + &mut builder, + prod_evals.clone(), + logup_evals.clone(), + num_prod_specs, + num_logup_specs, + num_layers, + 3, + ); let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); @@ -49,13 +72,35 @@ fn test_sumcheck_layer_eval() { standard_fri_params_with_100_bits_conjectured_security(1) }; + let mut input_stream: Vec> = vec![]; + input_stream.push( + prod_evals + .into_iter() + .flat_map(|e| >::as_base_slice(&e).to_vec()) + .collect(), + ); + input_stream.push( + logup_evals + .into_iter() + .flat_map(|e| >::as_base_slice(&e).to_vec()) + .collect(), + ); + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); #[cfg(not(feature = "cuda"))] - air_test_impl::(fri_params, vb, config, program, vec![], 1, true) - .unwrap(); + air_test_impl::( + fri_params, + vb, + config, + program, + input_stream, + 1, + true, + ) + .unwrap(); #[cfg(feature = "cuda")] { air_test_impl::( @@ -63,7 +108,7 @@ fn test_sumcheck_layer_eval() { vb, config, program, - vec![], + input_stream, 1, true, ) @@ -71,22 +116,24 @@ fn test_sumcheck_layer_eval() { } } -fn new_rand_ext(rng: &mut R) -> C::EF { - C::EF::from_base_slice(&[ - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), - C::F::from_canonical_u32(rng.next_u32()), +fn new_rand_ext(rng: &mut R) -> E { + E::from_base_slice(&[ + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), + F::from_canonical_u32(rng.next_u32()), ]) } -fn build_test_program(builder: &mut Builder) { - let mut rng = thread_rng(); - // 6 prod specs in 8 layers, 5 logup specs in 8 layers - let round = 3; - let num_prod_specs = 6; - let num_logup_specs = 5; - let num_layers = 8; +fn build_test_program( + builder: &mut Builder, + prod_evals: Vec, + logup_evals: Vec, + num_prod_specs: usize, + num_logup_specs: usize, + num_layers: usize, + round: usize, +) { let mode = 1; // current_layer let mut ctx_u32s = vec![ @@ -128,7 +175,7 @@ fn build_test_program(builder: &mut Builder) { let num_prod_evals = num_prod_specs * num_layers * 2; let prod_spec_evals: Array> = builder.dyn_array(num_prod_evals); for idx in 0..num_prod_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(prod_evals[idx]); builder.set(&prod_spec_evals, idx, e); } @@ -136,7 +183,7 @@ fn build_test_program(builder: &mut Builder) { let num_logup_evals = num_logup_specs * num_layers * 4; let logup_spec_evals: Array> = builder.dyn_array(num_logup_evals); for idx in 0..num_logup_evals { - let e: Ext = builder.constant(new_rand_ext::(&mut rng)); + let e: Ext = builder.constant(logup_evals[idx]); builder.set(&logup_spec_evals, idx, e); } @@ -202,13 +249,16 @@ fn build_test_program(builder: &mut Builder) { .chain(logup_q_evals) .collect::>(); + let prod_spec_evals_id = builder.hint_load(); + let logup_spec_evals_id = builder.hint_load(); + let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); builder.sumcheck_layer_eval( &ctx, &challenges, - &prod_spec_evals, - &logup_spec_evals, + prod_spec_evals_id, + logup_spec_evals_id, &next_layer_evals, );