From 3714bb63c5720827a46bc2eb454b64bfe724ffd2 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 15 Apr 2026 18:18:55 +0200 Subject: [PATCH 01/31] reduce degree AIR poseidon --- crates/lean_vm/src/core/constants.rs | 2 +- crates/lean_vm/src/tables/poseidon_16/mod.rs | 75 ++++++++----------- .../src/tables/poseidon_16/trace_gen.rs | 41 ++++------ 3 files changed, 44 insertions(+), 74 deletions(-) diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 7edfff82..c0da3506 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -22,7 +22,7 @@ pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ (Table::execution(), 25), (Table::extension_op(), 20), - (Table::poseidon16(), 21), + (Table::poseidon16(), 20), ]; /// Starting program counter diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 9c0613b0..4aac4b39 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -9,9 +9,9 @@ use utils::{ToUsize, poseidon16_compress}; /// For `SymbolicExpression` we use the dense form so the zkDSL generator can /// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. #[inline(always)] -fn mds_air_16(state: &mut [A; WIDTH]) { +fn mds_air(state: &mut [A; WIDTH]) { if TypeId::of::() == TypeId::of::>() { - dense_mat_vec_air_16(mds_dense_16(), state); + dense_mat_vec_air(mds_dense_16(), state); return; } macro_rules! 
dispatch { @@ -85,9 +85,9 @@ mod trace_gen; pub use trace_gen::fill_trace_poseidon_16; pub(super) const WIDTH: usize = 16; -const HALF_INITIAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS / 2; +const INITIAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS; const PARTIAL_ROUNDS: usize = POSEIDON1_PARTIAL_ROUNDS; -const HALF_FINAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS / 2; +const FINAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS; pub const POSEIDON_PRECOMPILE_DATA: usize = 1; // domain separation: Poseidon16=1, Poseidon24=2 or 3 or 4, ExtensionOp>=8 @@ -204,13 +204,13 @@ impl Air for Poseidon16Precompile { num_cols_poseidon_16() } fn degree_air(&self) -> usize { - 9 + 3 } fn down_column_indexes(&self) -> Vec { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 76 + BUS as usize + 140 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon1Cols16 = { @@ -246,7 +246,7 @@ impl Air for Poseidon16Precompile { builder.assert_bool(cols.flag); - eval_poseidon1_16(builder, &cols) + eval_poseidon(builder, &cols) } } @@ -259,22 +259,21 @@ pub(super) struct Poseidon1Cols16 { pub index_res: T, pub inputs: [T; WIDTH], - pub beginning_full_rounds: [[T; WIDTH]; HALF_INITIAL_FULL_ROUNDS], + pub beginning_full_rounds: [[T; WIDTH]; INITIAL_FULL_ROUNDS], pub partial_rounds: [T; PARTIAL_ROUNDS], - pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1], + pub ending_full_rounds: [[T; WIDTH]; FINAL_FULL_ROUNDS - 1], pub outputs: [T; WIDTH / 2], } -fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16) { +fn eval_poseidon(builder: &mut AB, local: &Poseidon1Cols16) { let mut state: [_; WIDTH] = local.inputs; let initial_constants = poseidon1_initial_constants(); - for round in 0..HALF_INITIAL_FULL_ROUNDS { - eval_2_full_rounds_16( + for round in 0..INITIAL_FULL_ROUNDS { + eval_full_round( &mut state, &local.beginning_full_rounds[round], - &initial_constants[2 * round], - &initial_constants[2 * round + 1], + 
&initial_constants[round], builder, ); } @@ -285,7 +284,7 @@ fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16(builder: &mut AB, local: &Poseidon1Cols16 usize { } #[inline] -fn eval_2_full_rounds_16( +fn eval_full_round( state: &mut [AB::IF; WIDTH], post_full_round: &[AB::IF; WIDTH], - round_constants_1: &[F; WIDTH], - round_constants_2: &[F; WIDTH], + round_constants: &[F; WIDTH], builder: &mut AB, ) { - for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { + for (s, r) in state.iter_mut().zip(round_constants.iter()) { add_kb(s, *r); *s = s.cube(); } - mds_air_16(state); - for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { - add_kb(s, *r); - *s = s.cube(); - } - mds_air_16(state); + mds_air(state); for (state_i, post_i) in state.iter_mut().zip(post_full_round) { builder.assert_eq(*state_i, *post_i); *state_i = *post_i; @@ -353,24 +344,18 @@ fn eval_2_full_rounds_16( } #[inline] -fn eval_last_2_full_rounds_16( +fn eval_last_full_round( initial_state: &[AB::IF; WIDTH], state: &mut [AB::IF; WIDTH], outputs: &[AB::IF; WIDTH / 2], - round_constants_1: &[F; WIDTH], - round_constants_2: &[F; WIDTH], + round_constants: &[F; WIDTH], builder: &mut AB, ) { - for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { - add_kb(s, *r); - *s = s.cube(); - } - mds_air_16(state); - for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { + for (s, r) in state.iter_mut().zip(round_constants.iter()) { add_kb(s, *r); *s = s.cube(); } - mds_air_16(state); + mds_air(state); // add inputs to outputs (for compression) for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { *state_i += *init_state_i; @@ -382,7 +367,7 @@ fn eval_last_2_full_rounds_16( } #[inline] -fn dense_mat_vec_air_16(mat: &[[F; 16]; 16], state: &mut [A; WIDTH]) { +fn dense_mat_vec_air(mat: &[[F; 16]; 16], state: &mut [A; WIDTH]) { let input = *state; for i in 0..WIDTH { let mut acc = A::ZERO; @@ -394,7 +379,7 @@ fn dense_mat_vec_air_16(mat: &[[F; 16]; 16 } 
#[inline] -fn sparse_mat_air_16( +fn sparse_mat_air( state: &mut [A; WIDTH], first_row: &[F; WIDTH], v: &[F; WIDTH], diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index fca71225..1dce2c4e 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -50,9 +50,9 @@ pub(super) fn generate_trace_rows_for_perm + Copy>(perm: & for (full_round, constants) in perm .beginning_full_rounds .iter_mut() - .zip(poseidon1_initial_constants().chunks_exact(2)) + .zip(poseidon1_initial_constants().iter()) { - generate_2_full_round(&mut state, full_round, &constants[0], &constants[1]); + generate_1_full_round(&mut state, full_round, constants); } // --- Sparse partial rounds --- @@ -94,35 +94,27 @@ pub(super) fn generate_trace_rows_for_perm + Copy>(perm: & for (full_round, constants) in perm .ending_full_rounds .iter_mut() - .zip(poseidon1_final_constants().chunks_exact(2)) + .zip(poseidon1_final_constants().iter()) { - generate_2_full_round(&mut state, full_round, &constants[0], &constants[1]); + generate_1_full_round(&mut state, full_round, constants); } - // Last 2 full rounds with compression (add inputs to outputs) - generate_last_2_full_rounds( + // Last full round with compression (add inputs to outputs) + generate_last_1_full_round( &mut state, &inputs, &mut perm.outputs, - &poseidon1_final_constants()[2 * n_ending_full_rounds], - &poseidon1_final_constants()[2 * n_ending_full_rounds + 1], + &poseidon1_final_constants()[n_ending_full_rounds], ); } #[inline] -fn generate_2_full_round + Copy>( +fn generate_1_full_round + Copy>( state: &mut [F; WIDTH], post_full_round: &mut [&mut F; WIDTH], - round_constants_1: &[KoalaBear; WIDTH], - round_constants_2: &[KoalaBear; WIDTH], + round_constants: &[KoalaBear; WIDTH], ) { - for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { - *state_i += *const_i; - *state_i = state_i.cube(); - } - 
mds_circ_16(state); - - for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + for (state_i, const_i) in state.iter_mut().zip(round_constants) { *state_i += *const_i; *state_i = state_i.cube(); } @@ -134,20 +126,13 @@ fn generate_2_full_round + Copy>( } #[inline] -fn generate_last_2_full_rounds + Copy>( +fn generate_last_1_full_round + Copy>( state: &mut [F; WIDTH], inputs: &[F; WIDTH], outputs: &mut [&mut F; WIDTH / 2], - round_constants_1: &[KoalaBear; WIDTH], - round_constants_2: &[KoalaBear; WIDTH], + round_constants: &[KoalaBear; WIDTH], ) { - for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { - *state_i += *const_i; - *state_i = state_i.cube(); - } - mds_circ_16(state); - - for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + for (state_i, const_i) in state.iter_mut().zip(round_constants) { *state_i += *const_i; *state_i = state_i.cube(); } From 68e4e4c05fa27fbcee16de812629acace8ba6e38 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 15 Apr 2026 23:32:18 +0200 Subject: [PATCH 02/31] wip --- Cargo.lock | 25 +- Cargo.toml | 2 +- crates/backend/Cargo.toml | 2 +- crates/backend/fiat-shamir/Cargo.toml | 2 +- crates/backend/fiat-shamir/src/challenger.rs | 2 +- crates/backend/fiat-shamir/src/transcript.rs | 2 +- crates/backend/fiat-shamir/src/verifier.rs | 20 +- crates/backend/fiat-shamir/tests/grinding.rs | 8 +- crates/backend/field/src/exponentiation.rs | 2 +- crates/backend/field/src/field.rs | 2 +- crates/backend/goldilocks/Cargo.toml | 16 + .../backend/goldilocks/src/cubic_extension.rs | 622 ++++++++++++++++++ crates/backend/goldilocks/src/goldilocks.rs | 585 ++++++++++++++++ crates/backend/goldilocks/src/helpers.rs | 73 ++ crates/backend/goldilocks/src/lib.rs | 17 + crates/backend/goldilocks/src/poseidon1.rs | 396 +++++++++++ crates/backend/poly/Cargo.toml | 2 +- crates/backend/poly/src/eq_mle.rs | 6 +- crates/backend/poly/src/evals.rs | 6 +- crates/backend/poly/src/mle/mle_custom.rs | 4 +- 
crates/backend/poly/src/next_mle.rs | 4 +- crates/backend/src/lib.rs | 2 +- .../sumcheck/src/product_computation.rs | 36 +- crates/backend/symetric/Cargo.toml | 2 +- crates/backend/symetric/src/merkle.rs | 2 +- crates/backend/symetric/src/permutation.rs | 8 +- crates/lean_compiler/src/a_simplify_lang.rs | 6 +- crates/lean_compiler/src/c_compile_final.rs | 2 +- .../lean_compiler/src/instruction_encoder.rs | 2 +- .../src/parser/parsers/function.rs | 6 +- crates/lean_compiler/tests/test_compiler.rs | 8 +- crates/lean_prover/src/lib.rs | 23 +- crates/lean_prover/src/prove_execution.rs | 2 +- crates/lean_prover/src/test_zkvm.rs | 20 +- crates/lean_prover/src/trace_gen.rs | 8 +- crates/lean_prover/src/verify_execution.rs | 6 +- crates/lean_vm/src/core/constants.rs | 8 +- crates/lean_vm/src/core/types.rs | 6 +- crates/lean_vm/src/diagnostics/exec_result.rs | 2 +- crates/lean_vm/src/execution/runner.rs | 2 +- crates/lean_vm/src/execution/tests.rs | 26 +- crates/lean_vm/src/isa/instruction.rs | 12 +- crates/lean_vm/src/tables/extension_op/air.rs | 108 +-- crates/lean_vm/src/tables/extension_op/mod.rs | 2 +- crates/lean_vm/src/tables/mod.rs | 4 +- crates/lean_vm/src/tables/poseidon_16/mod.rs | 396 ----------- .../src/tables/poseidon_16/trace_gen.rs | 145 ---- crates/lean_vm/src/tables/poseidon_8/mod.rs | 199 ++++++ .../src/tables/poseidon_8/trace_gen.rs | 18 + crates/lean_vm/src/tables/table_enum.rs | 12 +- crates/rec_aggregation/src/compilation.rs | 20 +- crates/rec_aggregation/src/lib.rs | 8 +- crates/sub_protocols/src/quotient_gkr.rs | 4 +- crates/utils/src/multilinear.rs | 4 +- crates/utils/src/poseidon.rs | 67 +- crates/utils/src/wrappers.rs | 14 +- crates/whir/Cargo.toml | 2 +- crates/whir/src/dft.rs | 6 +- crates/whir/src/merkle.rs | 56 +- crates/whir/tests/run_whir.rs | 12 +- crates/xmss/src/lib.rs | 6 +- crates/xmss/src/signers_cache.rs | 4 +- crates/xmss/src/wots.rs | 64 +- crates/xmss/src/xmss.rs | 8 +- crates/xmss/tests/xmss_tests.rs | 2 +- src/lib.rs | 2 +- 66 
files changed, 2282 insertions(+), 868 deletions(-) create mode 100644 crates/backend/goldilocks/Cargo.toml create mode 100644 crates/backend/goldilocks/src/cubic_extension.rs create mode 100644 crates/backend/goldilocks/src/goldilocks.rs create mode 100644 crates/backend/goldilocks/src/helpers.rs create mode 100644 crates/backend/goldilocks/src/lib.rs create mode 100644 crates/backend/goldilocks/src/poseidon1.rs delete mode 100644 crates/lean_vm/src/tables/poseidon_16/mod.rs delete mode 100644 crates/lean_vm/src/tables/poseidon_16/trace_gen.rs create mode 100644 crates/lean_vm/src/tables/poseidon_8/mod.rs create mode 100644 crates/lean_vm/src/tables/poseidon_8/trace_gen.rs diff --git a/Cargo.lock b/Cargo.lock index ef793a49..aa8549da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,7 +98,7 @@ dependencies = [ "mt-air", "mt-fiat-shamir", "mt-field", - "mt-koala-bear", + "mt-goldilocks", "mt-poly", "mt-sumcheck", "mt-symetric", @@ -593,7 +593,7 @@ name = "mt-fiat-shamir" version = "0.1.0" dependencies = [ "mt-field", - "mt-koala-bear", + "mt-goldilocks", "mt-symetric", "mt-utils", "rayon", @@ -614,6 +614,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "mt-goldilocks" +version = "0.1.0" +dependencies = [ + "itertools", + "mt-field", + "mt-utils", + "num-bigint", + "paste", + "rand", + "rayon", + "serde", + "tracing", +] + [[package]] name = "mt-koala-bear" version = "0.1.0" @@ -635,7 +650,7 @@ version = "0.1.0" dependencies = [ "itertools", "mt-field", - "mt-koala-bear", + "mt-goldilocks", "mt-utils", "rand", "rayon", @@ -660,7 +675,7 @@ name = "mt-symetric" version = "0.1.0" dependencies = [ "mt-field", - "mt-koala-bear", + "mt-goldilocks", "rayon", ] @@ -678,7 +693,7 @@ dependencies = [ "itertools", "mt-fiat-shamir", "mt-field", - "mt-koala-bear", + "mt-goldilocks", "mt-poly", "mt-sumcheck", "mt-symetric", diff --git a/Cargo.toml b/Cargo.toml index 31cb6877..e36191f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ members = [ "crates/*", 
"crates/backend/utils", "crates/backend/field", - "crates/backend/koala-bear", + "crates/backend/goldilocks", "crates/backend/poly", "crates/backend/symetric", "crates/backend/air", diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957a..7557dcbb 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -13,5 +13,5 @@ rayon.workspace = true whir = { path = "../whir", package = "mt-whir" } tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } -koala-bear = { path = "koala-bear", package = "mt-koala-bear" } +goldilocks = { path = "goldilocks", package = "mt-goldilocks" } utils = { path = "utils", package = "mt-utils" } diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index b6c9faff..d26773a0 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true [dependencies] field = { path = "../field", package = "mt-field" } -koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } +goldilocks = { path = "../goldilocks", package = "mt-goldilocks" } symetric = { path = "../symetric", package = "mt-symetric" } utils = { path = "../utils", package = "mt-utils" } diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index d650d68a..cd968ba4 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -1,7 +1,7 @@ use field::PrimeField64; use symetric::Compression; -pub(crate) const RATE: usize = 8; +pub(crate) const RATE: usize = 4; pub(crate) const WIDTH: usize = RATE * 2; #[derive(Clone, Debug)] diff --git a/crates/backend/fiat-shamir/src/transcript.rs b/crates/backend/fiat-shamir/src/transcript.rs index 612c2d10..b9c2e946 100644 --- a/crates/backend/fiat-shamir/src/transcript.rs +++ b/crates/backend/fiat-shamir/src/transcript.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, 
Serialize}; use crate::PrunedMerklePaths; -pub const DIGEST_LEN_FE: usize = 8; +pub const DIGEST_LEN_FE: usize = 4; #[derive(Debug, Clone)] pub struct MerkleOpening { diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd..c9258f23 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -9,7 +9,7 @@ use crate::{ }; use field::PrimeCharacteristicRing; use field::{ExtensionField, PrimeField64}; -use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; +use goldilocks::{Goldilocks, default_goldilocks_poseidon1_8}; use symetric::Compression; pub struct VerifierState>, P> { @@ -68,16 +68,16 @@ where #[allow(clippy::missing_transmute_annotations)] fn restore_merkle_paths(paths: PrunedMerklePaths, PF>) -> Option>>> { - assert_eq!(TypeId::of::>(), TypeId::of::()); - // SAFETY: We've confirmed PF == KoalaBear - let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; - let perm = default_koalabear_poseidon1_16(); - let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); - let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { + assert_eq!(TypeId::of::>(), TypeId::of::()); + // SAFETY: We've confirmed PF == Goldilocks + let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; + let perm = default_goldilocks_poseidon1_8(); + let hash_fn = |data: &[Goldilocks]| symetric::hash_slice::<_, _, 8, 4, DIGEST_LEN_FE>(&perm, data); + let combine_fn = |left: &[Goldilocks; DIGEST_LEN_FE], right: &[Goldilocks; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; - let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; - let openings: Vec> = restored + let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; + let openings: Vec> = restored .0 .into_iter() .map(|path| MerkleOpening { @@ -85,7 +85,7 @@ where path: path.sibling_hashes, }) 
.collect(); - // SAFETY: PF == KoalaBear + // SAFETY: PF == Goldilocks Some(unsafe { std::mem::transmute(openings) }) } } diff --git a/crates/backend/fiat-shamir/tests/grinding.rs b/crates/backend/fiat-shamir/tests/grinding.rs index b45d336c..58a0b666 100644 --- a/crates/backend/fiat-shamir/tests/grinding.rs +++ b/crates/backend/fiat-shamir/tests/grinding.rs @@ -1,22 +1,22 @@ -use koala_bear::{QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; +use goldilocks::{CubicExtensionFieldGL, default_goldilocks_poseidon1_8}; use mt_fiat_shamir::{FSProver, FSVerifier, ProverState, VerifierState}; use std::time::Instant; -type EF = QuinticExtensionFieldKB; +type EF = CubicExtensionFieldGL; #[test] #[ignore] fn bench_grinding() { let n_reps = 100; for grinding_bits in 20..=20 { - let mut prover_state = ProverState::::new(default_koalabear_poseidon1_16()); + let mut prover_state = ProverState::::new(default_goldilocks_poseidon1_8()); let time = Instant::now(); for _ in 0..n_reps { prover_state.pow_grinding(grinding_bits); } let elapsed = time.elapsed(); let mut verifier_state = - VerifierState::::new(prover_state.into_proof(), default_koalabear_poseidon1_16()).unwrap(); + VerifierState::::new(prover_state.into_proof(), default_goldilocks_poseidon1_8()).unwrap(); for _ in 0..n_reps { verifier_state.check_pow_grinding(grinding_bits).unwrap() } diff --git a/crates/backend/field/src/exponentiation.rs b/crates/backend/field/src/exponentiation.rs index 2e9f567e..92fb17f6 100644 --- a/crates/backend/field/src/exponentiation.rs +++ b/crates/backend/field/src/exponentiation.rs @@ -8,7 +8,7 @@ pub(crate) const fn bits_u64(n: u64) -> usize { /// Compute the exponential `x -> x^1420470955` using a custom addition chain. /// -/// This map computes the third root of `x` if `x` is a member of the field `KoalaBear`. +/// This map computes the third root of `x` if `x` is a member of the field `Goldilocks`. 
/// This follows from the computation: `3 * 1420470955 = 2*(2^31 - 2^24) + 1 = 1 mod (p - 1)`. #[must_use] pub fn exp_1420470955(val: R) -> R { diff --git a/crates/backend/field/src/field.rs b/crates/backend/field/src/field.rs index aa20e3f2..cfbea46f 100644 --- a/crates/backend/field/src/field.rs +++ b/crates/backend/field/src/field.rs @@ -71,7 +71,7 @@ pub trait PrimeCharacteristicRing: + PartialEq { /// The field `ℤ/p` where the characteristic of this ring is p. - type PrimeSubfield: PrimeField32; + type PrimeSubfield: PrimeField64; /// The additive identity of the ring. /// diff --git a/crates/backend/goldilocks/Cargo.toml b/crates/backend/goldilocks/Cargo.toml new file mode 100644 index 00000000..d602351b --- /dev/null +++ b/crates/backend/goldilocks/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "mt-goldilocks" +version.workspace = true +edition.workspace = true + +[dependencies] +field = { path = "../field", package = "mt-field" } +utils = { path = "../utils", package = "mt-utils" } + +rand.workspace = true +rayon.workspace = true +serde.workspace = true +itertools.workspace = true +tracing.workspace = true +num-bigint = "*" +paste = "1" diff --git a/crates/backend/goldilocks/src/cubic_extension.rs b/crates/backend/goldilocks/src/cubic_extension.rs new file mode 100644 index 00000000..68abf20e --- /dev/null +++ b/crates/backend/goldilocks/src/cubic_extension.rs @@ -0,0 +1,622 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +//! Degree-3 trinomial extension of Goldilocks, `F_p[X] / (X^3 - X - 1)`. +//! +//! Elements are `a_0 + a_1*X + a_2*X^2`. Reduction rule: `X^3 = X + 1`, +//! consequently `X^4 = X^2 + X`. 
+ +use alloc::format; +use alloc::string::ToString; +use alloc::vec::Vec; +use core::array; +use core::fmt::{self, Debug, Display, Formatter}; +use core::iter::{Product, Sum}; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use field::{ + Algebra, BasedVectorSpace, ExtensionField, Field, Packable, PackedFieldExtension, PackedValue, + PrimeCharacteristicRing, RawDataSerializable, TwoAdicField, field_to_array, +}; +use itertools::Itertools; +use num_bigint::BigUint; +use rand::distr::{Distribution, StandardUniform}; +use rand::prelude::Rng; +use serde::{Deserialize, Serialize}; +use utils::{as_base_slice, as_base_slice_mut, flatten_to_base, reconstitute_from_base}; + +use crate::Goldilocks; + +/// Frobenius coefficients for `X^3 - X - 1` over Goldilocks. +/// +/// `FROBENIUS_COEFFS[0]` is `X^p mod (X^3 - X - 1)`, `FROBENIUS_COEFFS[1]` is `X^{2p} mod …`. +/// +/// Values verified by the companion `plonky3/goldilocks` code. +pub const FROBENIUS_COEFFS: [[Goldilocks; 3]; 2] = [ + [ + Goldilocks::new(10615703402128488253), + Goldilocks::new(10050274602728160328), + Goldilocks::new(11746561000929144102), + ], + [ + Goldilocks::new(6700183068485440220), + Goldilocks::new(14531223735771536287), + Goldilocks::new(8396469466686423992), + ], +]; + +/// Generator of the multiplicative group of the cubic extension, as a coefficient triple. +const EXT_GENERATOR: [Goldilocks; 3] = [Goldilocks::new(2), Goldilocks::new(1), Goldilocks::new(0)]; + +/// Degree-3 trinomial extension of Goldilocks. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, PartialOrd, Ord)] +#[repr(transparent)] +#[must_use] +pub struct CubicExtensionFieldGL { + #[serde(with = "utils::array_serialization")] + pub(crate) value: [Goldilocks; 3], +} + +impl CubicExtensionFieldGL { + /// Construct from a coefficient triple `[a_0, a_1, a_2]`. 
+ #[inline] + pub const fn new(value: [Goldilocks; 3]) -> Self { + Self { value } + } +} + +impl Default for CubicExtensionFieldGL { + fn default() -> Self { + Self::new([Goldilocks::ZERO; 3]) + } +} + +impl From for CubicExtensionFieldGL { + fn from(x: Goldilocks) -> Self { + Self::new(field_to_array(x)) + } +} + +impl From<[Goldilocks; 3]> for CubicExtensionFieldGL { + fn from(x: [Goldilocks; 3]) -> Self { + Self::new(x) + } +} + +impl Packable for CubicExtensionFieldGL {} + +impl BasedVectorSpace for CubicExtensionFieldGL { + const DIMENSION: usize = 3; + + #[inline] + fn as_basis_coefficients_slice(&self) -> &[Goldilocks] { + &self.value + } + + #[inline] + fn from_basis_coefficients_fn Goldilocks>(f: Fn) -> Self { + Self::new(array::from_fn(f)) + } + + #[inline] + fn from_basis_coefficients_iter>( + mut iter: I, + ) -> Option { + (iter.len() == 3).then(|| Self::new(array::from_fn(|_| iter.next().unwrap()))) + } + + #[inline] + fn flatten_to_base(vec: Vec) -> Vec { + // SAFETY: `Self` is `repr(transparent)` over `[Goldilocks; 3]`. + unsafe { flatten_to_base::(vec) } + } + + #[inline] + fn reconstitute_from_base(vec: Vec) -> Vec { + // SAFETY: `Self` is `repr(transparent)` over `[Goldilocks; 3]`. + unsafe { reconstitute_from_base::(vec) } + } +} + +impl ExtensionField for CubicExtensionFieldGL { + type ExtensionPacking = Self; + + #[inline] + fn is_in_basefield(&self) -> bool { + self.value[1].is_zero() && self.value[2].is_zero() + } + + #[inline] + fn as_base(&self) -> Option { + >::is_in_basefield(self).then(|| self.value[0]) + } +} + +impl CubicExtensionFieldGL { + /// Apply the Frobenius `x -> x^p`. + /// + /// `φ(a) = a_0 + a_1 * X^p + a_2 * X^{2p}`, reduced with the stored coefficients. 
+ #[inline] + pub fn frobenius(&self) -> Self { + let a = &self.value; + let fc = &FROBENIUS_COEFFS; + let tail = [a[1], a[2]]; + let c0 = a[0] + Goldilocks::dot_product::<2>(&tail, &[fc[0][0], fc[1][0]]); + let c1 = Goldilocks::dot_product::<2>(&tail, &[fc[0][1], fc[1][1]]); + let c2 = Goldilocks::dot_product::<2>(&tail, &[fc[0][2], fc[1][2]]); + Self::new([c0, c1, c2]) + } +} + +impl PrimeCharacteristicRing for CubicExtensionFieldGL { + type PrimeSubfield = ::PrimeSubfield; + + const ZERO: Self = Self::new([Goldilocks::ZERO; 3]); + const ONE: Self = Self::new(field_to_array(Goldilocks::ONE)); + const TWO: Self = Self::new(field_to_array(Goldilocks::TWO)); + const NEG_ONE: Self = Self::new(field_to_array(Goldilocks::NEG_ONE)); + + #[inline] + fn from_prime_subfield(f: Self::PrimeSubfield) -> Self { + ::from_prime_subfield(f).into() + } + + #[inline] + fn halve(&self) -> Self { + Self::new(self.value.map(|x| x.halve())) + } + + #[inline] + fn square(&self) -> Self { + let mut res = Self::default(); + cubic_square(&self.value, &mut res.value); + res + } + + #[inline] + fn mul_2exp_u64(&self, exp: u64) -> Self { + Self::new(self.value.map(|x| x.mul_2exp_u64(exp))) + } + + #[inline] + fn div_2exp_u64(&self, exp: u64) -> Self { + Self::new(self.value.map(|x| x.div_2exp_u64(exp))) + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + // SAFETY: `repr(transparent)` over `[Goldilocks; 3]`. 
+ unsafe { reconstitute_from_base(Goldilocks::zero_vec(len * 3)) } + } +} + +impl Algebra for CubicExtensionFieldGL {} + +impl RawDataSerializable for CubicExtensionFieldGL { + const NUM_BYTES: usize = ::NUM_BYTES * 3; + + #[inline] + fn into_bytes(self) -> impl IntoIterator { + self.value.into_iter().flat_map(|x| x.into_bytes()) + } + + #[inline] + fn into_byte_stream(input: impl IntoIterator) -> impl IntoIterator { + Goldilocks::into_byte_stream(input.into_iter().flat_map(|x| x.value)) + } + + #[inline] + fn into_u32_stream(input: impl IntoIterator) -> impl IntoIterator { + Goldilocks::into_u32_stream(input.into_iter().flat_map(|x| x.value)) + } + + #[inline] + fn into_u64_stream(input: impl IntoIterator) -> impl IntoIterator { + Goldilocks::into_u64_stream(input.into_iter().flat_map(|x| x.value)) + } + + #[inline] + fn into_parallel_byte_streams( + input: impl IntoIterator, + ) -> impl IntoIterator { + Goldilocks::into_parallel_byte_streams( + input + .into_iter() + .flat_map(|x| (0..3).map(move |i| array::from_fn(|j| x[j].value[i]))), + ) + } + + #[inline] + fn into_parallel_u32_streams( + input: impl IntoIterator, + ) -> impl IntoIterator { + Goldilocks::into_parallel_u32_streams( + input + .into_iter() + .flat_map(|x| (0..3).map(move |i| array::from_fn(|j| x[j].value[i]))), + ) + } + + #[inline] + fn into_parallel_u64_streams( + input: impl IntoIterator, + ) -> impl IntoIterator { + Goldilocks::into_parallel_u64_streams( + input + .into_iter() + .flat_map(|x| (0..3).map(move |i| array::from_fn(|j| x[j].value[i]))), + ) + } +} + +impl Field for CubicExtensionFieldGL { + type Packing = Self; + + const GENERATOR: Self = Self::new(EXT_GENERATOR); + + fn try_inverse(&self) -> Option { + if self.is_zero() { + return None; + } + Some(cubic_inv(self)) + } + + #[inline] + fn add_slices(slice_1: &mut [Self], slice_2: &[Self]) { + // SAFETY: `repr(transparent)` + addition is base-linear. 
+ unsafe { + let base_slice_1 = as_base_slice_mut(slice_1); + let base_slice_2 = as_base_slice(slice_2); + Goldilocks::add_slices(base_slice_1, base_slice_2); + } + } + + #[inline] + fn order() -> BigUint { + Goldilocks::order().pow(3) + } +} + +impl Display for CubicExtensionFieldGL { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.is_zero() { + write!(f, "0") + } else { + let str = self + .value + .iter() + .enumerate() + .filter(|(_, x)| !x.is_zero()) + .map(|(i, x)| match (i, x.is_one()) { + (0, _) => format!("{x}"), + (1, true) => "X".to_string(), + (1, false) => format!("{x} X"), + (_, true) => format!("X^{i}"), + (_, false) => format!("{x} X^{i}"), + }) + .join(" + "); + write!(f, "{str}") + } + } +} + +impl Neg for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn neg(self) -> Self { + Self::new(self.value.map(Goldilocks::neg)) + } +} + +impl Add for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + Self::new([ + self.value[0] + rhs.value[0], + self.value[1] + rhs.value[1], + self.value[2] + rhs.value[2], + ]) + } +} + +impl Add for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn add(mut self, rhs: Goldilocks) -> Self { + self.value[0] += rhs; + self + } +} + +impl AddAssign for CubicExtensionFieldGL { + #[inline] + fn add_assign(&mut self, rhs: Self) { + for i in 0..3 { + self.value[i] += rhs.value[i]; + } + } +} + +impl AddAssign for CubicExtensionFieldGL { + #[inline] + fn add_assign(&mut self, rhs: Goldilocks) { + self.value[0] += rhs; + } +} + +impl Sum for CubicExtensionFieldGL { + #[inline] + fn sum>(iter: I) -> Self { + iter.reduce(|acc, x| acc + x).unwrap_or(Self::ZERO) + } +} + +impl Sub for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::new([ + self.value[0] - rhs.value[0], + self.value[1] - rhs.value[1], + self.value[2] - rhs.value[2], + ]) + } +} + +impl Sub for CubicExtensionFieldGL { + type 
Output = Self; + + #[inline] + fn sub(mut self, rhs: Goldilocks) -> Self { + self.value[0] -= rhs; + self + } +} + +impl SubAssign for CubicExtensionFieldGL { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + for i in 0..3 { + self.value[i] -= rhs.value[i]; + } + } +} + +impl SubAssign for CubicExtensionFieldGL { + #[inline] + fn sub_assign(&mut self, rhs: Goldilocks) { + self.value[0] -= rhs; + } +} + +impl Mul for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut res = Self::default(); + cubic_mul(&self.value, &rhs.value, &mut res.value); + res + } +} + +impl Mul for CubicExtensionFieldGL { + type Output = Self; + + #[inline] + fn mul(self, rhs: Goldilocks) -> Self { + Self::new([self.value[0] * rhs, self.value[1] * rhs, self.value[2] * rhs]) + } +} + +impl MulAssign for CubicExtensionFieldGL { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl MulAssign for CubicExtensionFieldGL { + #[inline] + fn mul_assign(&mut self, rhs: Goldilocks) { + *self = *self * rhs; + } +} + +impl Product for CubicExtensionFieldGL { + #[inline] + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE) + } +} + +impl Div for CubicExtensionFieldGL { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + #[inline] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl DivAssign for CubicExtensionFieldGL { + #[inline] + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> CubicExtensionFieldGL { + CubicExtensionFieldGL::new(array::from_fn(|_| self.sample(rng))) + } +} + +impl TwoAdicField for CubicExtensionFieldGL { + const TWO_ADICITY: usize = Goldilocks::TWO_ADICITY; + + #[inline] + fn two_adic_generator(bits: usize) -> Self { + Goldilocks::two_adic_generator(bits).into() + } +} + +// PackedFieldExtension: since 
Goldilocks has trivial packing (Packing = Self), the cubic +// extension is also its own packing. +impl PackedFieldExtension for CubicExtensionFieldGL { + #[inline] + fn from_ext_slice(ext_slice: &[CubicExtensionFieldGL]) -> Self { + // Goldilocks::Packing::WIDTH == 1, so the input is a single element. + *CubicExtensionFieldGL::from_slice(ext_slice) + } + + #[inline] + fn packed_ext_powers(base: CubicExtensionFieldGL) -> field::Powers { + // `Powers` is seeded with `current = ONE` (i.e. `base^0`); NOTE(review): so the sequence presumably starts at `k = 0`, not `k = 1` — confirm against `field::Powers`. + use field::Powers; + Powers { + base, + current: Self::ONE, + } + } +} + +// ============================================================================ +// Arithmetic kernels for `F_p[X] / (X^3 - X - 1)`. +// ============================================================================ + +/// Multiply two cubic extension elements. +/// +/// Given `a = a_0 + a_1 X + a_2 X^2` and `b = b_0 + b_1 X + b_2 X^2`, compute the +/// product reduced by `X^3 - X - 1` (so `X^3 = X + 1`, `X^4 = X^2 + X`). +#[inline] +pub fn cubic_mul(a: &[Goldilocks; 3], b: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { + let a0 = a[0]; + let a1 = a[1]; + let a2 = a[2]; + let b0 = b[0]; + let b1 = b[1]; + let b2 = b[2]; + + let a1b2 = a1 * b2; + let a2b1 = a2 * b1; + let a2b2 = a2 * b2; + + // constant: a0 b0 + a1 b2 + a2 b1 + res[0] = a0 * b0 + a1b2 + a2b1; + // linear: a0 b1 + a1 b0 + a1 b2 + a2 b1 + a2 b2 + res[1] = a0 * b1 + a1 * b0 + a1b2 + a2b1 + a2b2; + // quadratic: a0 b2 + a1 b1 + a2 b0 + a2 b2 + res[2] = a0 * b2 + a1 * b1 + a2 * b0 + a2b2; +} + +/// Square a cubic extension element (same reduction rule as `cubic_mul`). 
+#[inline] +pub fn cubic_square(a: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { + let a0 = a[0]; + let a1 = a[1]; + let a2 = a[2]; + + let a0_sq = a0.square(); + let a1_sq = a1.square(); + let a2_sq = a2.square(); + let two_a0_a1 = (a0 * a1).double(); + let two_a0_a2 = (a0 * a2).double(); + let two_a1_a2 = (a1 * a2).double(); + + // constant: a0^2 + 2 a1 a2 + res[0] = a0_sq + two_a1_a2; + // linear: 2 a0 a1 + 2 a1 a2 + a2^2 + res[1] = two_a0_a1 + two_a1_a2 + a2_sq; + // quadratic: 2 a0 a2 + a1^2 + a2^2 + res[2] = two_a0_a2 + a1_sq + a2_sq; +} + +/// Invert a cubic extension element via adjugate/determinant — no Frobenius round trip needed. +/// +/// The multiplication-by-`a` matrix (in the basis `{1, X, X^2}`, using `X^3 = X + 1`) is +/// +/// ```text +/// M = | a0 a2 a1 | +/// | a1 a0 + a2 a1 + a2 | +/// | a2 a1 a0 + a2 | +/// ``` +/// +/// so `a^{-1} = adj(M) · e_0 / det(M)`. +#[inline] +fn cubic_inv(a: &CubicExtensionFieldGL) -> CubicExtensionFieldGL { + let [a0, a1, a2] = a.value; + + let a0_sq = a0.square(); + let a1_sq = a1.square(); + let a2_sq = a2.square(); + let a0a1 = a0 * a1; + let a0a2 = a0 * a2; + let a1a2 = a1 * a2; + + // Negated cofactors of the first row of `M` (see matrix above), `n_i = -C_{0i}`; the sign cancels against `t = -det(M)` below: + // n0 = a1 a2 + a1^2 - a0^2 - 2 a0 a2 - a2^2 + let n0 = a1a2 + a1_sq - a0_sq - a0a2.double() - a2_sq; + // n1 = a0 a1 - a2^2 + let n1 = a0a1 - a2_sq; + // n2 = a0 a2 + a2^2 - a1^2 + let n2 = a0a2 + a2_sq - a1_sq; + + // `t = -det(M) = a0 n0 + a2 n1 + a1 n2`. + let t = a0 * n0 + a2 * n1 + a1 * n2; + let t_inv = t.inverse(); + + CubicExtensionFieldGL::new([n0 * t_inv, n1 * t_inv, n2 * t_inv]) +} + +// ============================================================================ +// Sanity tests (inverse round trip, `X^3 = X + 1` reduction, Frobenius) — exercised during `cargo test`. 
+// ============================================================================ + +#[cfg(test)] +mod tests { + use field::{Field, PrimeCharacteristicRing, PrimeField64}; + use rand::rngs::StdRng; + use rand::{RngExt, SeedableRng}; + + use super::*; + + #[test] + fn inverse_roundtrip() { + let mut rng = StdRng::seed_from_u64(1); + for _ in 0..32 { + let a: CubicExtensionFieldGL = rng.random(); + if a.is_zero() { + continue; + } + let a_inv = a.inverse(); + assert_eq!(a * a_inv, CubicExtensionFieldGL::ONE); + } + } + + #[test] + fn x_cubed_equals_x_plus_one() { + // The extension is `F_p[X]/(X^3 - X - 1)`, so `X^3 = X + 1`. + let x = CubicExtensionFieldGL::new([Goldilocks::ZERO, Goldilocks::ONE, Goldilocks::ZERO]); + let x_cubed = x * x * x; + let expected = CubicExtensionFieldGL::new([Goldilocks::ONE, Goldilocks::ONE, Goldilocks::ZERO]); + assert_eq!(x_cubed, expected); + } + + #[test] + fn frobenius_matches_pth_power() { + let mut rng = StdRng::seed_from_u64(2); + for _ in 0..8 { + let a: CubicExtensionFieldGL = rng.random(); + let a_frob = a.frobenius(); + let a_pth = a.exp_u64(Goldilocks::ORDER_U64); + assert_eq!(a_frob, a_pth); + } + } +} diff --git a/crates/backend/goldilocks/src/goldilocks.rs b/crates/backend/goldilocks/src/goldilocks.rs new file mode 100644 index 00000000..6a4e6be9 --- /dev/null +++ b/crates/backend/goldilocks/src/goldilocks.rs @@ -0,0 +1,585 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
+ +use alloc::vec; +use alloc::vec::Vec; +use core::fmt::{Debug, Display, Formatter}; +use core::hash::{Hash, Hasher}; +use core::iter::{Product, Sum}; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use core::{array, fmt}; + +use field::integers::QuotientMap; +use field::op_assign_macros::{impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign}; +use field::{ + Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField, + PrimeField64, RawDataSerializable, TwoAdicField, impl_raw_serializable_primefield64, + quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int, +}; +use num_bigint::BigUint; +use rand::Rng; +use rand::distr::{Distribution, StandardUniform}; +use serde::{Deserialize, Serialize}; +use utils::{assume, branch_hint, flatten_to_base}; + +use crate::helpers::{exp_10540996611094048183, gcd_inner}; + +/// The Goldilocks prime. +pub(crate) const P: u64 = 0xFFFF_FFFF_0000_0001; + +/// The prime field known as Goldilocks, defined as `F_p` where `p = 2^64 - 2^32 + 1`. +/// +/// The internal representation is not necessarily canonical — any `u64` is allowed. +#[derive(Copy, Clone, Default, Serialize, Deserialize)] +#[repr(transparent)] +#[must_use] +pub struct Goldilocks { + pub(crate) value: u64, +} + +impl Goldilocks { + /// Create a new field element from any `u64`. + /// + /// Any `u64` value is accepted. No reduction is performed since Goldilocks uses a + /// non-canonical internal representation. + #[inline] + pub const fn new(value: u64) -> Self { + Self { value } + } + + /// Convert a `[u64; N]` array to an array of field elements. + #[inline] + pub const fn new_array(input: [u64; N]) -> [Self; N] { + let mut output = [Self::ZERO; N]; + let mut i = 0; + while i < N { + output[i].value = input[i]; + i += 1; + } + output + } + + /// Convert a `[[u64; N]; M]` array to a 2D array of field elements. 
+ #[inline] + pub const fn new_2d_array( + input: [[u64; N]; M], + ) -> [[Self; N]; M] { + let mut output = [[Self::ZERO; N]; M]; + let mut i = 0; + while i < M { + output[i] = Self::new_array(input[i]); + i += 1; + } + output + } + + /// Two's complement of `ORDER`, i.e. `2^64 - ORDER = 2^32 - 1`. + const NEG_ORDER: u64 = Self::ORDER_U64.wrapping_neg(); + + /// Generators of the two-adic subgroups: `TWO_ADIC_GENERATORS[0] = 1`, + /// `TWO_ADIC_GENERATORS[i+1]^2 = TWO_ADIC_GENERATORS[i]`. + pub const TWO_ADIC_GENERATORS: [Self; 33] = Self::new_array([ + 0x0000000000000001, + 0xffffffff00000000, + 0x0001000000000000, + 0xfffffffeff000001, + 0xefffffff00000001, + 0x00003fffffffc000, + 0x0000008000000000, + 0xf80007ff08000001, + 0xbf79143ce60ca966, + 0x1905d02a5c411f4e, + 0x9d8f2ad78bfed972, + 0x0653b4801da1c8cf, + 0xf2c35199959dfcb6, + 0x1544ef2335d17997, + 0xe0ee099310bba1e2, + 0xf6b2cffe2306baac, + 0x54df9630bf79450e, + 0xabd0a6e8aa3d8a0e, + 0x81281a7b05f9beac, + 0xfbd41c6b8caa3302, + 0x30ba2ecd5e93e76d, + 0xf502aef532322654, + 0x4b2a18ade67246b5, + 0xea9d5a1336fbc98b, + 0x86cdcc31c307e171, + 0x4bbaf5976ecfefd8, + 0xed41d05b78d6e286, + 0x10d78dd8915a171d, + 0x59049500004a4485, + 0xdfa8c93ba46d2666, + 0x7e9bd009b86a0845, + 0x400a7f755588e659, + 0x185629dcda58878c, + ]); + + /// Powers of two from 2^0 to 2^95 (inclusive). + /// + /// Note that `2^96 = -1 mod P`, so any power of two can be derived from this table. 
+ const POWERS_OF_TWO: [Self; 96] = { + let mut powers_of_two = [Self::ONE; 96]; + let mut i = 1; + while i < 64 { + powers_of_two[i] = Self::new(1 << i); + i += 1; + } + let mut var = Self::new(1 << 63); + while i < 96 { + var = const_add(var, var); + powers_of_two[i] = var; + i += 1; + } + powers_of_two + }; +} + +impl PartialEq for Goldilocks { + fn eq(&self, other: &Self) -> bool { + self.as_canonical_u64() == other.as_canonical_u64() + } +} + +impl Eq for Goldilocks {} + +impl Packable for Goldilocks {} + +impl Hash for Goldilocks { + fn hash(&self, state: &mut H) { + state.write_u64(self.as_canonical_u64()); + } +} + +impl Ord for Goldilocks { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.as_canonical_u64().cmp(&other.as_canonical_u64()) + } +} + +impl PartialOrd for Goldilocks { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Display for Goldilocks { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.as_canonical_u64(), f) + } +} + +impl Debug for Goldilocks { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.as_canonical_u64(), f) + } +} + +impl Distribution for StandardUniform { + fn sample(&self, rng: &mut R) -> Goldilocks { + loop { + let next_u64 = rng.next_u64(); + if next_u64 < Goldilocks::ORDER_U64 { + return Goldilocks::new(next_u64); + } + } + } +} + +impl PrimeCharacteristicRing for Goldilocks { + type PrimeSubfield = Self; + + const ZERO: Self = Self::new(0); + const ONE: Self = Self::new(1); + const TWO: Self = Self::new(2); + const NEG_ONE: Self = Self::new(Self::ORDER_U64 - 1); + + #[inline] + fn from_prime_subfield(f: Self::PrimeSubfield) -> Self { + f + } + + #[inline] + fn from_bool(b: bool) -> Self { + Self::new(b.into()) + } + + #[inline] + fn halve(&self) -> Self { + Self::new(crate::helpers::halve_u64::

(self.value)) + } + + #[inline] + fn mul_2exp_u64(&self, exp: u64) -> Self { + // 2^96 = -1 mod P, 2^192 = 1 mod P. + if exp < 96 { + *self * Self::POWERS_OF_TWO[exp as usize] + } else if exp < 192 { + -*self * Self::POWERS_OF_TWO[(exp - 96) as usize] + } else { + self.mul_2exp_u64(exp % 192) + } + } + + #[inline] + fn div_2exp_u64(&self, mut exp: u64) -> Self { + // 2^{-n} = 2^{192 - n} mod P. + exp %= 192; + self.mul_2exp_u64(192 - exp) + } + + #[inline] + fn sum_array(input: &[Self]) -> Self { + assert_eq!(N, input.len()); + match N { + 0 => Self::ZERO, + 1 => input[0], + 2 => input[0] + input[1], + 3 => input[0] + input[1] + input[2], + _ => input.iter().copied().sum(), + } + } + + #[inline] + fn dot_product(lhs: &[Self; N], rhs: &[Self; N]) -> Self { + // OFFSET has two key properties: + // 1. it's a multiple of P, + // 2. it exceeds the maximum sum of two u64 products. + const OFFSET: u128 = ((P as u128) << 64) - (P as u128) + ((P as u128) << 32); + const { + assert!((N as u32) <= (1 << 31)); + } + match N { + 0 => Self::ZERO, + 1 => lhs[0] * rhs[0], + 2 => { + let long_prod_0 = (lhs[0].value as u128) * (rhs[0].value as u128); + let long_prod_1 = (lhs[1].value as u128) * (rhs[1].value as u128); + let (sum, over) = long_prod_0.overflowing_add(long_prod_1); + let sum_corr = sum.wrapping_sub(OFFSET); + if over { + reduce128(sum_corr) + } else { + reduce128(sum) + } + } + _ => { + let (lo_plus_hi, hi) = lhs + .iter() + .zip(rhs) + .map(|(x, y)| (x.value as u128) * (y.value as u128)) + .fold((0_u128, 0_u64), |(acc_lo, acc_hi), val| { + let val_hi = (val >> 96) as u64; + unsafe { (acc_lo.wrapping_add(val), acc_hi.unchecked_add(val_hi)) } + }); + let lo = lo_plus_hi.wrapping_sub((hi as u128) << 96); + let sum = unsafe { lo.unchecked_add(P.unchecked_sub(hi) as u128) }; + reduce128(sum) + } + } + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + // SAFETY: `#[repr(transparent)]` means `Goldilocks` and `u64` share layout. 
+ unsafe { flatten_to_base(vec![0u64; len]) } + } +} + +/// `p - 1 = 2^32 * 3 * 5 * 17 * 257 * 65537`. The smallest `D` with `gcd(p - 1, D) = 1` is 7. +impl InjectiveMonomial<7> for Goldilocks {} + +impl PermutationMonomial<7> for Goldilocks { + fn injective_exp_root_n(&self) -> Self { + exp_10540996611094048183(*self) + } +} + +impl RawDataSerializable for Goldilocks { + impl_raw_serializable_primefield64!(); +} + +impl Field for Goldilocks { + type Packing = Self; + + const GENERATOR: Self = Self::new(7); + + #[inline] + fn is_zero(&self) -> bool { + self.value == 0 || self.value == Self::ORDER_U64 + } + + fn try_inverse(&self) -> Option { + if self.is_zero() { + return None; + } + Some(gcd_inversion(*self)) + } + + #[inline] + fn order() -> BigUint { + P.into() + } +} + +quotient_map_small_int!(Goldilocks, u64, [u8, u16, u32]); +quotient_map_small_int!(Goldilocks, i64, [i8, i16, i32]); +quotient_map_large_uint!( + Goldilocks, + u64, + Goldilocks::ORDER_U64, + "`[0, 2^64 - 2^32]`", + "`[0, 2^64 - 1]`", + [u128] +); +quotient_map_large_iint!( + Goldilocks, + i64, + "`[-(2^63 - 2^31), 2^63 - 2^31]`", + "`[1 + 2^32 - 2^64, 2^64 - 1]`", + [(i128, u128)] +); + +impl QuotientMap for Goldilocks { + #[inline] + fn from_int(int: u64) -> Self { + Self::new(int) + } + + #[inline] + fn from_canonical_checked(int: u64) -> Option { + (int < Self::ORDER_U64).then(|| Self::new(int)) + } + + #[inline(always)] + unsafe fn from_canonical_unchecked(int: u64) -> Self { + Self::new(int) + } +} + +impl QuotientMap for Goldilocks { + #[inline] + fn from_int(int: i64) -> Self { + if int >= 0 { + Self::new(int as u64) + } else { + Self::new(Self::ORDER_U64.wrapping_add_signed(int)) + } + } + + #[inline] + fn from_canonical_checked(int: i64) -> Option { + const POS_BOUND: i64 = (P >> 1) as i64; + const NEG_BOUND: i64 = -POS_BOUND; + match int { + 0..=POS_BOUND => Some(Self::new(int as u64)), + NEG_BOUND..0 => Some(Self::new(Self::ORDER_U64.wrapping_add_signed(int))), + _ => None, + } + } + 
+ #[inline(always)] + unsafe fn from_canonical_unchecked(int: i64) -> Self { + Self::from_int(int) + } +} + +impl PrimeField for Goldilocks { + fn as_canonical_biguint(&self) -> BigUint { + self.as_canonical_u64().into() + } +} + +impl PrimeField64 for Goldilocks { + const ORDER_U64: u64 = P; + + #[inline] + fn as_canonical_u64(&self) -> u64 { + let mut c = self.value; + // Single conditional subtraction is sufficient since `2 * ORDER` would overflow u64. + if c >= Self::ORDER_U64 { + c -= Self::ORDER_U64; + } + c + } +} + +impl TwoAdicField for Goldilocks { + const TWO_ADICITY: usize = 32; + + fn two_adic_generator(bits: usize) -> Self { + assert!(bits <= Self::TWO_ADICITY); + Self::TWO_ADIC_GENERATORS[bits] + } +} + +/// `const` version of addition — useful for building const tables. +#[inline] +const fn const_add(lhs: Goldilocks, rhs: Goldilocks) -> Goldilocks { + let (sum, over) = lhs.value.overflowing_add(rhs.value); + let (mut sum, over) = sum.overflowing_add((over as u64) * Goldilocks::NEG_ORDER); + if over { + sum += Goldilocks::NEG_ORDER; + } + Goldilocks::new(sum) +} + +impl Add for Goldilocks { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + let (sum, over) = self.value.overflowing_add(rhs.value); + let (mut sum, over) = sum.overflowing_add(u64::from(over) * Self::NEG_ORDER); + if over { + unsafe { + assume(self.value > Self::ORDER_U64 && rhs.value > Self::ORDER_U64); + } + branch_hint(); + sum += Self::NEG_ORDER; + } + Self::new(sum) + } +} + +impl Sub for Goldilocks { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self { + let (diff, under) = self.value.overflowing_sub(rhs.value); + let (mut diff, under) = diff.overflowing_sub(u64::from(under) * Self::NEG_ORDER); + if under { + unsafe { + assume(self.value < Self::NEG_ORDER - 1 && rhs.value > Self::ORDER_U64); + } + branch_hint(); + diff -= Self::NEG_ORDER; + } + Self::new(diff) + } +} + +impl Neg for Goldilocks { + type Output = Self; + + #[inline] + fn 
neg(self) -> Self::Output { + Self::new(Self::ORDER_U64 - self.as_canonical_u64()) + } +} + +impl Mul for Goldilocks { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + reduce128(u128::from(self.value) * u128::from(rhs.value)) + } +} + +impl_add_assign!(Goldilocks); +impl_sub_assign!(Goldilocks); +impl_mul_methods!(Goldilocks); +impl_div_methods!(Goldilocks, Goldilocks); + +impl Sum for Goldilocks { + fn sum>(iter: I) -> Self { + // Faster than `reduce` for iterators of length > 2; cannot overflow provided len < 2^64. + let sum = iter.map(|x| x.value as u128).sum::(); + reduce128(sum) + } +} + +/// Reduce to a 64-bit value. Output may be in `[0, 2^64)`, i.e. not necessarily canonical. +#[inline] +pub(crate) fn reduce128(x: u128) -> Goldilocks { + let (x_lo, x_hi) = split(x); + let x_hi_hi = x_hi >> 32; + let x_hi_lo = x_hi & Goldilocks::NEG_ORDER; + + let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi); + if borrow { + branch_hint(); + t0 -= Goldilocks::NEG_ORDER; + } + let t1 = x_hi_lo * Goldilocks::NEG_ORDER; + let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) }; + Goldilocks::new(t2) +} + +#[inline] +#[allow(clippy::cast_possible_truncation)] +const fn split(x: u128) -> (u64, u64) { + (x as u64, (x >> 64) as u64) +} + +/// Fast addition modulo `ORDER` on x86-64, using CF/SBB to pick the adjustment branchlessly. +/// +/// # Safety +/// - Only correct if `x + y < 2^64 + ORDER = 0x1_FFFF_FFFF_0000_0001`. +/// - Overwrites both inputs in registers on x86; avoid reusing them. 
+#[inline(always)] +#[cfg(target_arch = "x86_64")] +unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + unsafe { + let res_wrapped: u64; + let adjustment: u64; + core::arch::asm!( + "add {0}, {1}", + "sbb {1:e}, {1:e}", + inlateout(reg) x => res_wrapped, + inlateout(reg) y => adjustment, + options(pure, nomem, nostack), + ); + assume(x != 0 || (res_wrapped == y && adjustment == 0)); + assume(y != 0 || (res_wrapped == x && adjustment == 0)); + res_wrapped + adjustment + } +} + +#[inline(always)] +#[cfg(not(target_arch = "x86_64"))] +unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + let (res_wrapped, carry) = x.overflowing_add(y); + res_wrapped + Goldilocks::NEG_ORDER * u64::from(carry) +} + +/// Binary-GCD inversion for Goldilocks. +/// +/// Uses the "update factor" variant from https://eprint.iacr.org/2020/972.pdf: compute +/// factors off by a known power of two, then correct at the end via a linear combination. +fn gcd_inversion(input: Goldilocks) -> Goldilocks { + let (mut a, mut b) = (input.value, P); + + // `len(a) + len(b) <= 128` initially; 126 iterations suffice to drive it to <= 2. + // Split into 2 rounds of 63. + const ROUND_SIZE: usize = 63; + + let (f00, _, f10, _) = gcd_inner::(&mut a, &mut b); + let (_, _, f11, g11) = gcd_inner::(&mut a, &mut b); + + // The update factors are i64's, but we interpret `-2^63` as `2^63` because + // `gcd_inner` outputs sit in `(-2^ROUND_SIZE, 2^ROUND_SIZE]`. + let u = from_unusual_int(f00); + let v = from_unusual_int(f10); + let u_fac11 = from_unusual_int(f11); + let v_fac11 = from_unusual_int(g11); + + // Each iteration introduced a factor of 2, so we need to divide by `2^126`. + // `2^192 = 1 mod P`, so multiply by `2^66` instead (192 - 126 = 66). + (u * u_fac11 + v * v_fac11).mul_2exp_u64(66) +} + +/// Convert an `i64` to Goldilocks, interpreting `i64::MIN` as `2^63` (not `-2^63`). 
+const fn from_unusual_int(int: i64) -> Goldilocks { + if (int >= 0) || (int == i64::MIN) { + Goldilocks::new(int as u64) + } else { + Goldilocks::new(Goldilocks::ORDER_U64.wrapping_add_signed(int)) + } +} + +// A few unused-variable suppression helpers that clippy might warn about +#[allow(dead_code)] +fn _unused_array_touch() { + let _ = array::from_fn::(|_| 0); +} diff --git a/crates/backend/goldilocks/src/helpers.rs b/crates/backend/goldilocks/src/helpers.rs new file mode 100644 index 00000000..65b08b10 --- /dev/null +++ b/crates/backend/goldilocks/src/helpers.rs @@ -0,0 +1,73 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +//! Helpers ported from `p3_util` and `p3_field::exponentiation`, scoped to what the +//! Goldilocks field implementation needs. + +use field::PrimeCharacteristicRing; + +/// Given an element `x` from a 64-bit field `F_P`, compute `x / 2`. +#[inline] +#[must_use] +pub const fn halve_u64(x: u64) -> u64 { + let shift = (P + 1) >> 1; + let half = x >> 1; + if x & 1 == 0 { half } else { half + shift } +} + +/// Inner loop of the binary-GCD-based inversion algorithm used by Goldilocks. +/// +/// See https://eprint.iacr.org/2020/972.pdf for background; this mini-GCD builds up +/// a small transformation using u64 ops and bit shifts, which we then apply to the +/// big-int values in the outer loop. +#[inline] +pub const fn gcd_inner(a: &mut u64, b: &mut u64) -> (i64, i64, i64, i64) { + let (mut f0, mut g0, mut f1, mut g1) = (1, 0, 0, 1); + + let mut round = 0; + while round < NUM_ROUNDS { + if *a & 1 == 0 { + *a >>= 1; + } else { + if *a < *b { + core::mem::swap(a, b); + (f0, f1) = (f1, f0); + (g0, g1) = (g1, g0); + } + *a -= *b; + *a >>= 1; + f0 -= f1; + g0 -= g1; + } + f1 <<= 1; + g1 <<= 1; + + round += 1; + } + + (f0, g0, f1, g1) +} + +/// Compute `x -> x^{10540996611094048183}` using a custom addition chain. 
+/// +/// This map computes the seventh root of `x` if `x` is a member of the `Goldilocks` field. +/// It follows from: `7 * 10540996611094048183 = 4*(2^64 - 2^32) + 1 = 1 mod (p - 1)`. +#[must_use] +pub fn exp_10540996611094048183(val: R) -> R { + let p1 = val; + let p10 = p1.square(); + let p11 = p10 * p1; + let p100 = p10.square(); + let p111 = p100 * p11; + let p1_30 = p100.exp_power_of_2(30); + let p1_30_11 = p1_30 * p11; + let p1_30_11_000 = p1_30_11.exp_power_of_2(3); + let p1_30_11_011 = p1_30_11_000 * p1_30_11; + let p1_30_11_011_000000 = p1_30_11_011.exp_power_of_2(6); + let p_chunk12 = p1_30_11_011_000000 * p1_30_11_011; + let p_chunk12_000000000000 = p_chunk12.exp_power_of_2(12); + let p_chunk24 = p_chunk12_000000000000 * p_chunk12; + let p_chunk24_000000 = p_chunk24.exp_power_of_2(6); + let p_chunk30 = p_chunk24_000000 * p1_30_11; + let p_chunk30_0000 = p_chunk30.exp_power_of_2(4); + p_chunk30_0000 * p111 +} diff --git a/crates/backend/goldilocks/src/lib.rs b/crates/backend/goldilocks/src/lib.rs new file mode 100644 index 00000000..26f147b0 --- /dev/null +++ b/crates/backend/goldilocks/src/lib.rs @@ -0,0 +1,17 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +//! The Goldilocks prime field `F_p` where `p = 2^64 - 2^32 + 1`, and a degree-3 extension. +//! +//! This is a port of `plonky3/goldilocks/` adapted to the in-tree `field` trait crate. + +extern crate alloc; + +mod cubic_extension; +mod goldilocks; +mod helpers; +mod poseidon1; + +pub use cubic_extension::*; +pub use goldilocks::*; +pub use helpers::*; +pub use poseidon1::*; diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs new file mode 100644 index 00000000..b8354880 --- /dev/null +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -0,0 +1,396 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +//! Scalar Poseidon1 permutation at width 8 for Goldilocks. +//! 
+//! Parameters: +//! - S-box `x^7` (smallest `d` with `gcd(d, p - 1) = 1` for Goldilocks) +//! - `R_F = 8` full rounds (4 initial + 4 terminal) +//! - `R_P = 22` partial rounds in the middle +//! - External MDS is the circulant matrix with first row `[7, 1, 3, 8, 8, 3, 4, 9]` +//! (Plonky2/upstream-Plonky3 "small MDS" — same matrix the upstream +//! `MdsMatrixGoldilocks` uses at width 8). +//! +//! The permutation is generic over any algebra `R` over `Goldilocks` that also +//! implements `InjectiveMonomial<7>`, mirroring the koala-bear crate's +//! Poseidon1 surface. + +use field::{Algebra, InjectiveMonomial, PrimeCharacteristicRing}; + +use crate::Goldilocks; + +pub const POSEIDON1_WIDTH: usize = 8; +pub const POSEIDON1_HALF_FULL_ROUNDS: usize = 4; +pub const POSEIDON1_PARTIAL_ROUNDS: usize = 22; +pub const POSEIDON1_SBOX_DEGREE: u64 = 7; +pub const POSEIDON1_DIGEST_LEN: usize = 4; + +const POSEIDON1_N_ROUNDS: usize = + 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; + +// ========================================================================= +// MDS matrix (circulant, width 8) +// ========================================================================= +// +// First row of the circulant MDS matrix. Row `i` of a circulant is the first +// row cyclically shifted right by `i`, so `M[i][j] = MDS8_ROW[(j - i) mod N]`. +// (The first column would be `COL[i] = ROW[(N - i) mod N]`, which gives +// `M[i][j] = COL[(i - j) mod N]`; only the row form is used by the code below.) +const MDS8_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + +/// Apply the width-8 circulant MDS matrix in place, generic over `R`. +/// +/// The matrix has tiny integer entries (max 9), so even without any delayed +/// reduction a plain algebra-over-Goldilocks multiply is fine. +#[inline] +fn mds_mul_generic>(state: &mut [R; 8]) { + // Precompute the constants as Goldilocks once — `From` for `R` 
+ let coeffs: [Goldilocks; 8] = { + let mut arr = [Goldilocks::ZERO; 8]; + for i in 0..8 { + arr[i] = Goldilocks::new(MDS8_ROW[i] as u64); + } + arr + }; + + let input = *state; + for i in 0..8 { + // `row_i · input = sum_j ROW[(j - i) mod 8] · input[j]` + let mut acc = input[0] * coeffs[(8 - i) % 8]; + for j in 1..8 { + acc = acc + input[j] * coeffs[(j + 8 - i) % 8]; + } + state[i] = acc; + } +} + +/// Specialized fast MDS for the concrete `Goldilocks` scalar — uses a single +/// `u128` accumulator and one `reduce128` per output lane (coefficients ≤ 9, +/// so `8 × 9 × 2^64 ≈ 2^71` fits comfortably). +#[inline] +fn mds_mul_scalar(state: &mut [Goldilocks; 8]) { + let mut out = [Goldilocks::ZERO; 8]; + for i in 0..8 { + let mut acc: u128 = 0; + for j in 0..8 { + let c = MDS8_ROW[(j + 8 - i) % 8] as u128; + acc = acc.wrapping_add(c.wrapping_mul(state[j].value as u128)); + } + out[i] = crate::goldilocks::reduce128(acc); + } + *state = out; +} + +// ========================================================================= +// Round constants (width 8) +// ========================================================================= +// +// Layout: [4 initial full][22 partial][4 terminal full]. +// Generated by the Grain LFSR (Poseidon1, Appendix E) with +// `field_type = 1, alpha = 7, n = 64, t = 8, R_F = 8, R_P = 22`. +// Values carried over verbatim from `plonky3/goldilocks/src/poseidon1.rs`. 
+pub const GOLDILOCKS_POSEIDON1_RC_8: [[Goldilocks; POSEIDON1_WIDTH]; POSEIDON1_N_ROUNDS] = + Goldilocks::new_2d_array([ + // ---- Initial full rounds (4) ---- + [ + 0xdd5743e7f2a5a5d9, 0xcb3a864e58ada44b, 0xffa2449ed32f8cdc, 0x42025f65d6bd13ee, + 0x7889175e25506323, 0x34b98bb03d24b737, 0xbdcc535ecc4faa2a, 0x5b20ad869fc0d033, + ], + [ + 0xf1dda5b9259dfcb4, 0x27515210be112d59, 0x4227d1718c766c3f, 0x26d333161a5bd794, + 0x49b938957bf4b026, 0x4a56b5938b213669, 0x1120426b48c8353d, 0x6b323c3f10a56cad, + ], + [ + 0xce57d6245ddca6b2, 0xb1fc8d402bba1eb1, 0xb5c5096ca959bd04, 0x6db55cd306d31f7f, + 0xc49d293a81cb9641, 0x1ce55a4fe979719f, 0xa92e60a9d178a4d1, 0x002cc64973bcfd8c, + ], + [ + 0xcea721cce82fb11b, 0xe5b55eb8098ece81, 0x4e30525c6f1ddd66, 0x43c6702827070987, + 0xaca68430a7b5762a, 0x3674238634df9c93, 0x88cee1c825e33433, 0xde99ae8d74b57176, + ], + // ---- Partial rounds (22) ---- + [ + 0x488897d85ff51f56, 0x1140737ccb162218, 0xa7eeb9215866ed35, 0x9bd2976fee49fcc9, + 0xc0c8f0de580a3fcc, 0x4fb2dae6ee8fc793, 0x343a89f35f37395b, 0x223b525a77ca72c8, + ], + [ + 0x56ccb62574aaa918, 0xc4d507d8027af9ed, 0xa080673cf0b7e95c, 0xf0184884eb70dcf8, + 0x044f10b0cb3d5c69, 0xe9e3f7993938f186, 0x1b761c80e772f459, 0x606cec607a1b5fac, + ], + [ + 0x14a0c2e1d45f03cd, 0x4eace8855398574f, 0xf905ca7103eff3e6, 0xf8c8f8d20862c059, + 0xb524fe8bdd678e5a, 0xfbb7865901a1ec41, 0x014ef1197d341346, 0x9725e20825d07394, + ], + [ + 0xfdb25aef2c5bae3b, 0xbe5402dc598c971e, 0x93a5711f04cdca3d, 0xc45a9a5b2f8fb97b, + 0xfe8946a924933545, 0x2af997a27369091c, 0xaa62c88e0b294011, 0x058eb9d810ce9f74, + ], + [ + 0xb3cb23eced349ae4, 0xa3648177a77b4a84, 0x43153d905992d95d, 0xf4e2a97cda44aa4b, + 0x5baa2702b908682f, 0x082923bdf4f750d1, 0x98ae09a325893803, 0xf8a6475077968838, + ], + [ + 0xceb0735bf00b2c5f, 0x0a1a5d953888e072, 0x2fcb190489f94475, 0xb5be06270dec69fc, + 0x739cb934b09acf8b, 0x537750b75ec7f25b, 0xe9dd318bae1f3961, 0xf7462137299efe1a, + ], + [ + 0xb1f6b8eee9adb940, 0xbdebcc8a809dfe6b, 0x40fc1f791b178113, 
0x3ac1c3362d014864, + 0x9a016184bdb8aeba, 0x95f2394459fbc25e, 0xe3f34a07a76a66c2, 0x8df25f9ad98b1b96, + ], + [ + 0x85ffc27171439d9d, 0xddcb9a2dcfd26910, 0x26b5ba4bf3afb94e, 0xffff9cc7c7651e2f, + 0x8c88364698280b55, 0xebc114167b910501, 0x2d77b4d89ecfb516, 0x332e0828eba151f2, + ], + [ + 0x46fa6a6450dd4735, 0xd00db7dd92384a33, 0x5fd4fb751f3a5fc5, 0x496fb90c0bb65ea2, + 0xf3baec0bb87cc5c7, 0x862a3c0a7d4c7713, 0xbf5f38336a3f47d8, 0x41ad9dbc1394a20c, + ], + [ + 0xcc535945b7dbf0f7, 0x82af2bc93685bcec, 0x8e4c8d0c8cebfccd, 0x17cb39417e84597e, + 0xd4a965a8c749b232, 0xa2cab040f33f3ee5, 0xa98811a1fed4e3a6, 0x1cc48b54f377e2a1, + ], + [ + 0xe40cd4f6c5609a27, 0x11de79ebca97a4a4, 0x9177c73d8b7e929d, 0x2a6fe8085797e792, + 0x3de6e93329f8d5ae, 0x3f7af9125da962ff, 0xd710682cfc77d3ac, 0x48faf05f3b053cf4, + ], + [ + 0x287db8630da89c8b, 0x4d0de32053cb30e9, 0x8b37a4f20c5ada7b, 0xe7cc6ebe78c84ecf, + 0x240bdc0a66a2610d, 0x8299e7f02caa1650, 0x380a53fefb6e754e, 0x684a1d8cf8eb6810, + ], + [ + 0xe839452eb4b8a5e1, 0xb03fa62e90626af4, 0x11a688602fbc5efc, 0x30dda75c355a2d62, + 0x0f712adcb73810de, 0xffdc1102187f1ae1, 0x40c34f398254b99c, 0xede021b9dc289a4a, + ], + [ + 0x8b7b05225c4e7dad, 0x3bc794346f9d9ff9, 0xfccb5a57f2ca86ff, 0xbb1502015a7da9d4, + 0xd7e0a35d4352a015, 0x27af7a44f8160931, 0xc37442f6782f4615, 0xbdf392a9bd095dcb, + ], + [ + 0xc17f55037cf00de9, 0xbcffedd34c71a874, 0x5eb45d2a8133d1f2, 0xbabe251e1612ebdf, + 0x3efeb9fbe438c536, 0x2d7cef97b4afe1cf, 0xe5de1b4660016c0b, 0xcdcc26c332f5657c, + ], + [ + 0xe01dd653daf15809, 0xb0a6bdd4b41094b5, 0x27eac858b0b03a05, 0x51d43b5e93adbdc0, + 0x8b89a23b0fea5fc9, 0xdc8ac3b14f7f2fc1, 0xe793f82f1efec039, 0x9f6f2cf8969e7b80, + ], + [ + 0x49d45382e0f21d4a, 0x5f4ad1797cd72786, 0x4dc3dbebfd45f795, 0x03a3ef84dba6e1bc, + 0x204bc9b3d3fc4c01, 0x9ad706081e89b9ba, 0x638bfb4d840e9f89, 0x5ef2938cd095ae35, + ], + [ + 0x42cca18ebeb265c8, 0xb7b2ec5c29aecbf8, 0x0d84f9535dc78f0f, 0x04e64ad942e77b8c, + 0xb4880dffffc9da0b, 0x16db16d9c29adeb1, 0x09bbaf2a0590cd1e, 
0x76460e74961fcf8d, + ], + [ + 0xed12a2276dfa1553, 0x0b5acec5de0436fd, 0x3c6cfea033a1f0a8, 0x2b5ecefe546cac15, + 0x6e2d82884cd3bf6f, 0xc134878d1add7b83, 0x997963422eb7a280, 0x5e834537ac648cf6, + ], + [ + 0x89e779214737c0b7, 0x1a8c05e8581ad95b, 0x8d18b72796437cf7, 0xe7252c949e04b106, + 0x53267c4fd174585a, 0xa16ef5d9c81dad47, 0xda65191937270a46, 0xcb2a5b55f2df664c, + ], + [ + 0x854aee2dc1924137, 0xf37013c9d479ece6, 0x0e163bc0630c4696, 0x384ee64955048f76, + 0xf65d814e28ee4ec5, 0xe57bc564fd82f1b1, 0x4b338937b6876614, 0x66ee0b04ed43cd8d, + ], + [ + 0x49884bf25f4ef15d, 0xeb51fe28de1c6f54, 0x2cd64e84fce8dfcc, 0x29164a96a541a013, + 0x173ce7558f4cacb8, 0xeb5b1ce5877c89e9, 0x5faff4b0f5217bf6, 0xac42d0b1c20f205e, + ], + // ---- Terminal full rounds (4) ---- + [ + 0xfb1d6bf0ca43221b, 0x97b0a1b01d6a2955, 0x08c60bd622952b30, 0x43f2be0f9e24147c, + 0xfa7268b7d3730f5d, 0x43a6c419a23983bb, 0xcd77c1f7b29b113c, 0xcfa43c9db8eec29f, + ], + [ + 0xcaaa95a6c7365dec, 0x0a91193f798f3be0, 0x1104497652735dc6, 0x35aecb93663b515e, + 0x8dbc9916065aa858, 0xada8f7a0266579ed, 0x524dee7bec1ea789, 0xa93aee9dd5af9521, + ], + [ + 0x9d1f1b54750d707e, 0x7c9feab87096d5dc, 0xa2e1fb19f9d4261b, 0xb714deb448de6346, + 0x225d1f0d011c5403, 0x1549b7f1d28cedc0, 0xaef3e46f97d43942, 0x6dfc7ffe0b38bf08, + ], + [ + 0x7de853fdc542b663, 0xa68ecc96610657b2, 0xe88bb5428af289b1, 0xd7cfa1504c5569f5, + 0x78a9aad0d642d30a, 0xd68315f2353dce52, 0x46e56300f86fcfd5, 0x323d95332b145fd6, + ], + ]); + +// ========================================================================= +// S-box helpers +// ========================================================================= + +#[inline(always)] +fn sbox_full>(x: R) -> R { + x.injective_exp_n() +} + +// ========================================================================= +// Permutation driver +// ========================================================================= + +/// Width-8 Poseidon1 permutation for Goldilocks. 
+/// +/// Zero-sized — all state lives in the round-constant tables above. Mirrors +/// `Poseidon1KoalaBear16`'s public surface: `permute{,_mut}`, +/// `compress{,_in_place}`, plus a `default_goldilocks_poseidon1_8()` constructor. +#[derive(Clone, Copy, Debug, Default)] +pub struct Poseidon1Goldilocks8; + +impl Poseidon1Goldilocks8 { + /// Fast scalar permutation — direct `Goldilocks` arithmetic with a `u128` + /// MDS accumulator. + pub fn permute(&self, mut state: [Goldilocks; POSEIDON1_WIDTH]) -> [Goldilocks; POSEIDON1_WIDTH] { + self.permute_mut(&mut state); + state + } + + pub fn permute_mut(&self, state: &mut [Goldilocks; POSEIDON1_WIDTH]) { + for r in 0..POSEIDON1_HALF_FULL_ROUNDS { + for i in 0..POSEIDON1_WIDTH { + state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + for s in state.iter_mut() { + *s = sbox_full::(*s); + } + mds_mul_scalar(state); + } + + for r in POSEIDON1_HALF_FULL_ROUNDS + ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + { + for i in 0..POSEIDON1_WIDTH { + state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + state[0] = sbox_full::(state[0]); + mds_mul_scalar(state); + } + + for r in POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS..POSEIDON1_N_ROUNDS { + for i in 0..POSEIDON1_WIDTH { + state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + for s in state.iter_mut() { + *s = sbox_full::(*s); + } + mds_mul_scalar(state); + } + } + + /// Generic permutation over any algebra `R` over `Goldilocks` with `x^7` + /// as an injective monomial. Used by the AIR / symbolic trace builders.
+ pub fn permute_generic(&self, state: &mut [R; POSEIDON1_WIDTH]) + where + R: Algebra + InjectiveMonomial<7> + Copy, + { + for r in 0..POSEIDON1_HALF_FULL_ROUNDS { + for i in 0..POSEIDON1_WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + for s in state.iter_mut() { + *s = sbox_full::(*s); + } + mds_mul_generic(state); + } + + for r in POSEIDON1_HALF_FULL_ROUNDS + ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + { + for i in 0..POSEIDON1_WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + state[0] = sbox_full::(state[0]); + mds_mul_generic(state); + } + + for r in POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS..POSEIDON1_N_ROUNDS { + for i in 0..POSEIDON1_WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + } + for s in state.iter_mut() { + *s = sbox_full::(*s); + } + mds_mul_generic(state); + } + } + + /// Pure-permutation compress: apply the permutation to `input` and return the + /// full width-8 state. Callers that want a digest truncate to the first + /// `POSEIDON1_DIGEST_LEN = 4` lanes. + #[inline] + pub fn compress(&self, input: [Goldilocks; POSEIDON1_WIDTH]) -> [Goldilocks; POSEIDON1_WIDTH] { + self.permute(input) + } + + /// Compression-mode in-place permutation: `output = permute(input) + input`. + /// + /// Matches the koala-bear `Poseidon1KoalaBear16::compress_in_place` shape + /// so the `Compression<[R; 8]>` impl can reuse it. + #[inline] + pub fn compress_in_place(&self, state: &mut [R; POSEIDON1_WIDTH]) + where + R: Algebra + InjectiveMonomial<7> + Copy, + { + let initial = *state; + self.permute_generic(state); + for (s, init) in state.iter_mut().zip(initial) { + *s = *s + init; + } + } +} + +/// Return the default width-8 Poseidon1 permutation.
+#[inline] +pub fn default_goldilocks_poseidon1_8() -> Poseidon1Goldilocks8 { + Poseidon1Goldilocks8 +} + +// ========================================================================= +// Tests +// ========================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + /// The scalar and generic paths must agree on all inputs. + #[test] + fn scalar_matches_generic() { + let p = Poseidon1Goldilocks8; + let mut input = [Goldilocks::ZERO; 8]; + for i in 0..8 { + input[i] = Goldilocks::new(0xdead_beef_0000_0001u64.wrapping_mul(i as u64 + 1)); + } + let fast = p.permute(input); + let mut slow = input; + p.permute_generic(&mut slow); + assert_eq!(fast, slow); + } + + /// The permutation is deterministic and non-trivial. + #[test] + fn permutation_is_deterministic() { + let input: [Goldilocks; 8] = [ + Goldilocks::new(1), Goldilocks::new(2), Goldilocks::new(3), Goldilocks::new(4), + Goldilocks::new(5), Goldilocks::new(6), Goldilocks::new(7), Goldilocks::new(8), + ]; + let p = Poseidon1Goldilocks8; + let a = p.permute(input); + let b = p.permute(input); + assert_eq!(a, b); + assert_ne!(a, input); + } + + /// Rough avalanche smoke test: distinct inputs produce distinct outputs. 
+ #[test] + fn permutation_is_injective_on_small_inputs() { + let p = Poseidon1Goldilocks8; + let mut seen = std::collections::HashSet::new(); + for i in 0..64u64 { + let mut input = [Goldilocks::ZERO; 8]; + input[0] = Goldilocks::new(i); + let out = p.permute(input); + assert!(seen.insert(out[0].value), "collision at i={i}"); + } + } +} diff --git a/crates/backend/poly/Cargo.toml b/crates/backend/poly/Cargo.toml index 328b3d3f..2f050394 100644 --- a/crates/backend/poly/Cargo.toml +++ b/crates/backend/poly/Cargo.toml @@ -13,4 +13,4 @@ rand.workspace = true serde.workspace = true [dev-dependencies] -koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } +goldilocks = { path = "../goldilocks", package = "mt-goldilocks" } diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index b7a5a23e..90a55e74 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -1062,12 +1062,12 @@ mod tests { use std::time::Instant; use field::Field; - use koala_bear::QuinticExtensionFieldKB; + use goldilocks::CubicExtensionFieldGL; use rand::{RngExt, SeedableRng, rngs::StdRng}; use super::*; - type F = koala_bear::KoalaBear; - type EF = QuinticExtensionFieldKB; + type F = goldilocks::Goldilocks; + type EF = CubicExtensionFieldGL; #[test] fn test_compute_sparse_eval() { diff --git a/crates/backend/poly/src/evals.rs b/crates/backend/poly/src/evals.rs index 7e0e07b4..46926dc6 100644 --- a/crates/backend/poly/src/evals.rs +++ b/crates/backend/poly/src/evals.rs @@ -350,11 +350,11 @@ where mod tests { use std::time::Instant; - use koala_bear::QuinticExtensionFieldKB; + use goldilocks::CubicExtensionFieldGL; use rand::{RngExt, SeedableRng, rngs::StdRng}; - type F = QuinticExtensionFieldKB; - type EF = QuinticExtensionFieldKB; + type F = CubicExtensionFieldGL; + type EF = CubicExtensionFieldGL; use super::*; diff --git a/crates/backend/poly/src/mle/mle_custom.rs b/crates/backend/poly/src/mle/mle_custom.rs index 
58b22066..ebc8ac90 100644 --- a/crates/backend/poly/src/mle/mle_custom.rs +++ b/crates/backend/poly/src/mle/mle_custom.rs @@ -17,11 +17,11 @@ pub fn mle_of_zeros_then_ones(n_zeros: usize, point: &[F]) -> F { mod tests { use crate::{EvaluationsList, MultilinearPoint}; use field::PrimeCharacteristicRing; - use koala_bear::KoalaBear; + use goldilocks::Goldilocks; use rand::{RngExt, SeedableRng, rngs::StdRng}; use super::*; - type F = KoalaBear; + type F = Goldilocks; #[test] fn test_mle_of_zeros_then_ones() { diff --git a/crates/backend/poly/src/next_mle.rs b/crates/backend/poly/src/next_mle.rs index 7c9c687c..960387e6 100644 --- a/crates/backend/poly/src/next_mle.rs +++ b/crates/backend/poly/src/next_mle.rs @@ -56,11 +56,11 @@ where #[cfg(test)] mod tests { use field::PrimeCharacteristicRing; - use koala_bear::KoalaBear; + use goldilocks::Goldilocks; use crate::{EvaluationsList, MultilinearPoint, matrix_next_mle_folded, next_mle, to_big_endian_in_field}; - type F = KoalaBear; + type F = Goldilocks; #[test] fn test_matrix_down_folded() { diff --git a/crates/backend/src/lib.rs b/crates/backend/src/lib.rs index cbd44fb2..be346c38 100644 --- a/crates/backend/src/lib.rs +++ b/crates/backend/src/lib.rs @@ -1,7 +1,7 @@ pub use air::*; pub use fiat_shamir::*; pub use field::*; -pub use koala_bear::*; +pub use goldilocks::*; pub use poly::*; pub use rayon; pub use rayon::prelude::*; diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index ecce379f..027bb5a3 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -45,8 +45,8 @@ pub fn run_product_sumcheck>>( assert!(n_rounds >= 1); let first_sumcheck_poly = match (pol_a, pol_b) { (MleRef::BasePacked(evals), MleRef::ExtensionPacked(weights)) => { - if EF::DIMENSION == 5 { - compute_product_sumcheck_polynomial_base_ext_packed::<5, _, _, _, EF>(evals, weights, sum) + if EF::DIMENSION == 3 { + 
compute_product_sumcheck_polynomial_base_ext_packed::<3, _, _, _, EF>(evals, weights, sum) } else { unimplemented!() } @@ -168,10 +168,12 @@ pub fn compute_product_sumcheck_polynomial< DensePolynomial::new(vec![c0, c1, c2]) } -// using delayed modular reduction +// Generic over PrimeField64 (e.g. Goldilocks). The previous PrimeField32-only +// delayed u128/i128 accumulation path was dropped here; it remains a specialization candidate for a +// future pass — see `crates/backend/goldilocks/README.md`. pub fn compute_product_sumcheck_polynomial_base_ext_packed< const DIM: usize, - F: PrimeField32, + F: PrimeField64, PF: PackedField, EFP: BasedVectorSpace + Copy + Send + Sync, EF: Field + BasedVectorSpace, @@ -186,8 +188,6 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< assert!(n.is_power_of_two()); let half = n / 2; - type Acc = ([u128; D], [i128; D]); - let chunk_size = 1024; let (c0_acc, c2_acc) = pol_0[..half] @@ -199,8 +199,8 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< .zip(pol_1[half..].par_chunks(chunk_size)), ) .map(|((b_lo, b_hi), (e_lo, e_hi))| { - let mut c0 = [0u128; DIM]; - let mut c2 = [0i128; DIM]; + let mut c0 = [F::ZERO; DIM]; + let mut c2 = [F::ZERO; DIM]; for i in 0..b_lo.len() { let x0_lanes = b_lo[i].as_slice(); let x1_lanes = b_hi[i].as_slice(); @@ -210,20 +210,20 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< let y0_j = y0_coords[j].as_slice(); let y1_j = y1_coords[j].as_slice(); for lane in 0..PF::WIDTH { - let x0 = x0_lanes[lane].to_unique_u32() as u64; - let y0 = y0_j[lane].to_unique_u32(); - let y1 = y1_j[lane].to_unique_u32(); - c0[j] += (y0 as u64 * x0) as u128; - c2[j] += (y1 as i64 - y0 as i64) as i128 - * (x1_lanes[lane].to_unique_u32() as i64 - x0 as i64) as i128; + let x0 = x0_lanes[lane]; + let x1 = x1_lanes[lane]; + let y0 = y0_j[lane]; + let y1 = y1_j[lane]; + c0[j] += y0 * x0; + c2[j] += (y1 - y0) * (x1 - x0); } } } (c0, c2) }) .reduce( - || ([0u128; DIM], [0i128; DIM]), - |(mut
a0, mut a2): Acc, (b0, b2): Acc| { + || ([F::ZERO; DIM], [F::ZERO; DIM]), + |(mut a0, mut a2): ([F; DIM], [F; DIM]), (b0, b2): ([F; DIM], [F; DIM])| { for j in 0..DIM { a0[j] += b0[j]; a2[j] += b2[j]; @@ -232,8 +232,8 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< }, ); - let c0 = EF::from_basis_coefficients_fn(|j| F::reduce_product_sum(c0_acc[j])); - let c2 = EF::from_basis_coefficients_fn(|j| F::reduce_signed_product_sum(c2_acc[j])); + let c0 = EF::from_basis_coefficients_fn(|j| c0_acc[j]); + let c2 = EF::from_basis_coefficients_fn(|j| c2_acc[j]); let c1 = sum - c0.double() - c2; DensePolynomial::new(vec![c0, c1, c2]) diff --git a/crates/backend/symetric/Cargo.toml b/crates/backend/symetric/Cargo.toml index 125fb553..d959ae0e 100644 --- a/crates/backend/symetric/Cargo.toml +++ b/crates/backend/symetric/Cargo.toml @@ -4,6 +4,6 @@ version.workspace = true edition.workspace = true [dependencies] -koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } +goldilocks = { path = "../goldilocks", package = "mt-goldilocks" } field = { path = "../field", package = "mt-field" } rayon.workspace = true diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3..5e4347b8 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -8,7 +8,7 @@ use rayon::prelude::*; use crate::Compression; -pub const DIGEST_ELEMS: usize = 8; +pub const DIGEST_ELEMS: usize = 4; /// A Merkle tree storing only the digest layers (no leaf data). #[derive(Debug, Clone)] diff --git a/crates/backend/symetric/src/permutation.rs b/crates/backend/symetric/src/permutation.rs index c129a1dc..8df7eb65 100644 --- a/crates/backend/symetric/src/permutation.rs +++ b/crates/backend/symetric/src/permutation.rs @@ -1,7 +1,7 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
use field::{Algebra, InjectiveMonomial}; -use koala_bear::{KoalaBear, Poseidon1KoalaBear16}; +use goldilocks::{Goldilocks, Poseidon1Goldilocks8}; pub trait Compression: Clone + Sync { #[inline(always)] @@ -13,10 +13,10 @@ pub trait Compression: Clone + Sync { fn compress_mut(&self, input: &mut T); } -impl + InjectiveMonomial<3> + Send + Sync + 'static> Compression<[R; 16]> - for Poseidon1KoalaBear16 +impl + InjectiveMonomial<7> + Copy + Send + Sync + 'static> + Compression<[R; 8]> for Poseidon1Goldilocks8 { - fn compress_mut(&self, input: &mut [R; 16]) { + fn compress_mut(&self, input: &mut [R; 8]) { self.compress_in_place(input); } } diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index cb552560..5ba5332b 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2193,8 +2193,8 @@ fn simplify_lines( continue; } - // Special handling for poseidon16 precompile - if function_name == Table::poseidon16().name() { + // Special handling for poseidon8 precompile + if function_name == Table::poseidon8().name() { if !targets.is_empty() { return Err(format!( "Precompile {function_name} should not return values, at {location}" @@ -2214,7 +2214,7 @@ fn simplify_lines( arg_0: simplified_args[0].clone(), arg_1: simplified_args[1].clone(), res: simplified_args[2].clone(), - data: PrecompileCompTimeArgs::Poseidon16, + data: PrecompileCompTimeArgs::Poseidon8, })); continue; } diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index e4fbb575..29aae7b1 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -192,7 +192,7 @@ fn compile_block( let dest = try_as_mem_or_constant(&dest).expect("Fatal: Could not materialize jump destination"); let label = match dest { MemOrConstant::Constant(dest) => hints - .get(&usize::try_from(dest.as_canonical_u32()).unwrap()) + 
.get(&usize::try_from(dest.as_canonical_u64()).unwrap()) .and_then(|hints: &Vec| { hints.iter().find_map(|x| match x { Hint::Label { label } => Some(label), diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index c97a4c3e..8eef6746 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -48,7 +48,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { } Instruction::Precompile(precompile) => { let precompile_data = match &precompile.data { - PrecompileCompTimeArgs::Poseidon16 => POSEIDON_PRECOMPILE_DATA, + PrecompileCompTimeArgs::Poseidon8 => POSEIDON_PRECOMPILE_DATA, PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 1a37a3ea..1ab25198 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -9,7 +9,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME}; +use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON8_NAME}; /// Reserved function names that users cannot define. 
pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -34,8 +34,8 @@ fn is_reserved_function_name(name: &str) -> bool { if RESERVED_FUNCTION_NAMES.contains(&name) || CUSTOM_HINTS.iter().any(|hint| hint.name() == name) { return true; } - // Check precompile names (poseidon16, extension_op functions) - if name == POSEIDON16_NAME { + // Check precompile names (poseidon8, extension_op functions) + if name == POSEIDON8_NAME { return true; } if ExtensionOpMode::from_name(name).is_some() { diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2c187a08..2da5a707 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -4,7 +4,7 @@ use backend::BasedVectorSpace; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::poseidon16_compress; +use utils::poseidon8_compress; #[test] fn test_poseidon() { @@ -13,7 +13,7 @@ def main(): a = 0 b = a + 8 c = Array(8) - poseidon16_compress(a, b, c) + poseidon8_compress(a, b, c) for i in range(0, 8): cc = c[i] @@ -23,7 +23,7 @@ def main(): let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); - let _ = dbg!(poseidon16_compress(public_input)); + let _ = dbg!(poseidon8_compress(public_input)); } #[test] @@ -211,7 +211,7 @@ def main(): @inline def func(a, b): - poseidon16_compress(a, a, b) + poseidon8_compress(a, a, b) return "#; let bytecode = compile_program(&ProgramSource::Raw(program.to_string())); diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index 26753d97..6ca641d6 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -24,8 +24,13 @@ pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 7; pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 5; pub const RS_DOMAIN_INITIAL_REDUCTION_FACTOR: usize = 5; -pub const SNARK_DOMAIN_SEP: [F; 8] = 
F::new_array([ - 130704175, 1303721200, 493664240, 1035493700, 2063844858, 1410214009, 1938905908, 1696767928, +// Domain-separation digest for the zkVM SNARK. Arbitrary nothing-up-my-sleeve field +// elements; size matches `DIGEST_LEN = 4` for the Goldilocks width-8 Poseidon. +pub const SNARK_DOMAIN_SEP: [F; 4] = F::new_array([ + 0x4c45_414e_5f5a_4b56, // "LEAN_ZKV" + 0x4d5f_534e_4152_4b5f, // "M_SNARK_" + 0x444f_4d53_4550_3031, // "DOMSEP01" + 0xcccc_cccc_cccc_cccc, // nothing-up-my-sleeve tail ]); pub fn default_whir_config(starting_log_inv_rate: usize) -> WhirConfigBuilder { @@ -54,10 +59,10 @@ pub(crate) fn check_rate(log_inv_rate: usize) -> Result<(), ProofError> { #[cfg(test)] mod tests { - use backend::{PrimeCharacteristicRing, default_koalabear_poseidon1_16, hash_slice}; + use backend::{PrimeCharacteristicRing, default_goldilocks_poseidon1_8, hash_slice}; use lean_vm::F; use rec_aggregation::{get_aggregation_bytecode, init_aggregation_bytecode}; - use utils::poseidon16_compress_pair; + use utils::poseidon8_compress_pair; #[test] fn compute_snark_domain_sep() { @@ -68,19 +73,19 @@ mod tests { .iter() .map(|b| F::from_u8(*b)) .collect::>(); - let mut prefix_free_name_fe = vec![F::ZERO; 8]; + let mut prefix_free_name_fe = vec![F::ZERO; 4]; let len = name_fe.len(); prefix_free_name_fe.extend(name_fe); - while prefix_free_name_fe.len() % 8 != 7 { + while prefix_free_name_fe.len() % 4 != 3 { prefix_free_name_fe.push(F::ZERO); } prefix_free_name_fe.push(F::from_u64(len as u64)); - let comp = default_koalabear_poseidon1_16(); - let name_hash = hash_slice::<_, _, _, 8, 8>(&comp, &prefix_free_name_fe); + let comp = default_goldilocks_poseidon1_8(); + let name_hash = hash_slice::<_, _, _, 8, 4>(&comp, &prefix_free_name_fe); // We incorporate the recursion program hash, containing all the verifier logic, into fiat shamir domain separator // (likely not necessary but why not, is there a cleaner approach?) 
- let domain_sep = poseidon16_compress_pair(&name_hash, &recursion_bytecode_hash); + let domain_sep = poseidon8_compress_pair(&name_hash, &recursion_bytecode_hash); println!("Computed SNARK_DOMAIN_SEP: {:?}", domain_sep); // We dont assert equality here to avoid the pain of having to update the hardcoded SNARK_DOMAIN_SEP every time we change the recursion program } diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 180f45e1..7c7f07d4 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -43,7 +43,7 @@ pub fn prove_execution( } let mut prover_state = build_prover_state(); prover_state.observe_scalars(public_input); - prover_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); + prover_state.observe_scalars(&poseidon8_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); prover_state.add_base_scalars( &[ vec![ diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 7e47f344..38ad2991 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -3,19 +3,19 @@ use backend::*; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::{init_tracing, poseidon16_compress}; +use utils::{init_tracing, poseidon8_compress}; #[test] fn test_zk_vm_all_precompiles() { let program_str = r#" -DIM = 5 +DIM = 3 N = 11 M = 3 -DIGEST_LEN = 8 +DIGEST_LEN = 4 def main(): pub_start = 0 - poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) + poseidon8_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) base_ptr = pub_start + 88 ext_a_ptr = pub_start + 88 + N @@ -55,12 +55,12 @@ def main(): let mut rng = StdRng::seed_from_u64(0); let mut public_input = F::zero_vec(1 << 13); - // Poseidon test data - let poseidon_16_compress_input: [F; 16] = rng.random(); - 
public_input[32..48].copy_from_slice(&poseidon_16_compress_input); - public_input[48..56].copy_from_slice(&poseidon16_compress(poseidon_16_compress_input)[..8]); - let poseidon_24_input: [F; 24] = rng.random(); - public_input[56..80].copy_from_slice(&poseidon_24_input); + // Poseidon test data — width 8 / digest 4 for Goldilocks. + // DSL uses `pub_start + 4*DIGEST_LEN..6*DIGEST_LEN` (positions 16..24) for the input + // and `pub_start + 6*DIGEST_LEN..7*DIGEST_LEN` (positions 24..28) for the output. + let poseidon_8_compress_input: [F; 8] = rng.random(); + public_input[16..24].copy_from_slice(&poseidon_8_compress_input); + public_input[24..28].copy_from_slice(&poseidon8_compress(poseidon_8_compress_input)); // Extension op operands: base[N], ext_a[N], ext_b[N] let base_slice: [F; N] = rng.random(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index d333d0ac..22c65ae1 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -1,7 +1,7 @@ use backend::*; use lean_vm::*; use std::{array, collections::BTreeMap}; -use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; +use utils::{ToUsize, get_poseidon_8_of_zero, transposed_par_iter_mut}; #[derive(Debug)] pub struct ExecutionTrace { @@ -97,7 +97,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let padding_zero_vec_ptr = memory_padded.len(); memory_padded.extend(std::iter::repeat_n(F::ZERO, 16)); let null_poseidon_16_hash_ptr = memory_padded.len(); - memory_padded.extend_from_slice(get_poseidon_16_of_zero()); + memory_padded.extend_from_slice(get_poseidon_8_of_zero()); // IMPORTANT: memory size should always be >= number of VM cycles let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); @@ -105,8 +105,8 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let ExecutionResult { mut traces, .. 
} = execution_result; - let poseidon_trace = traces.get_mut(&Table::poseidon16()).unwrap(); - fill_trace_poseidon_16(&mut poseidon_trace.columns); + let poseidon_trace = traces.get_mut(&Table::poseidon8()).unwrap(); + fill_trace_poseidon_8(&mut poseidon_trace.columns); let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); fill_trace_extension_op(extension_op_trace, &memory_padded); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index fc3ad5dc..f72bf1c9 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -4,7 +4,7 @@ use crate::*; use backend::{Proof, RawProof, VerifierState}; use lean_vm::*; use sub_protocols::*; -use utils::{ToUsize, from_end, get_poseidon16}; +use utils::{ToUsize, from_end, get_poseidon8}; #[derive(Debug, Clone)] pub struct ProofVerificationDetails { @@ -16,9 +16,9 @@ pub fn verify_execution( public_input: &[F], proof: Proof, ) -> Result<(ProofVerificationDetails, RawProof), ProofError> { - let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone())?; + let mut verifier_state = VerifierState::::new(proof, get_poseidon8().clone())?; verifier_state.observe_scalars(public_input); - verifier_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); + verifier_state.observe_scalars(&poseidon8_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); let dims = verifier_state .next_base_scalars_vec(3 + N_TABLES)? 
.into_iter() diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index c0da3506..538e77a6 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -5,10 +5,10 @@ pub const LOGUP_MEMORY_DOMAINSEP: usize = 0; pub const LOGUP_PRECOMPILE_DOMAINSEP: usize = 1; pub const LOGUP_BYTECODE_DOMAINSEP: usize = 2; -/// Large field = extension field of degree DIMENSION over koala-bear -pub const DIMENSION: usize = 5; +/// Large field = extension field of degree DIMENSION over Goldilocks +pub const DIMENSION: usize = 3; -pub const DIGEST_LEN: usize = 8; +pub const DIGEST_LEN: usize = 4; pub const MIN_WHIR_LOG_INV_RATE: usize = 1; pub const MAX_WHIR_LOG_INV_RATE: usize = 4; @@ -22,7 +22,7 @@ pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ (Table::execution(), 25), (Table::extension_op(), 20), - (Table::poseidon16(), 20), + (Table::poseidon8(), 20), ]; /// Starting program counter diff --git a/crates/lean_vm/src/core/types.rs b/crates/lean_vm/src/core/types.rs index fbad9af1..393f5e25 100644 --- a/crates/lean_vm/src/core/types.rs +++ b/crates/lean_vm/src/core/types.rs @@ -3,13 +3,13 @@ use std::{ fmt::{Display, Formatter}, }; -use backend::{KoalaBear, QuinticExtensionFieldKB}; +use backend::{CubicExtensionFieldGL, Goldilocks}; /// Base field type for VM operations -pub type F = KoalaBear; +pub type F = Goldilocks; /// Extension field type for VM operations -pub type EF = QuinticExtensionFieldKB; +pub type EF = CubicExtensionFieldGL; /// Line number in source code for debugging pub type SourceLineNumber = usize; diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0c..0c231533 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -52,7 +52,7 @@ impl ExecutionMetadata { out.push('\n'); if 
self.n_poseidons > 0 { out.push_str(&format!( - "Poseidon16 calls: {} (1 poseidon per {} instructions)\n", + "Poseidon8 calls: {} (1 poseidon per {} instructions)\n", pretty_integer(self.n_poseidons), self.cycles / self.n_poseidons )); diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 91db24ae..bf7243fa 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -328,7 +328,7 @@ fn execute_bytecode_helper( let metadata = ExecutionMetadata { cycles: trace.pcs.len(), memory: memory.0.len(), - n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), + n_poseidons: trace.tables[&Table::poseidon8()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.instructions.len(), public_input_size: public_input.len(), diff --git a/crates/lean_vm/src/execution/tests.rs b/crates/lean_vm/src/execution/tests.rs index 60ba768d..4328ebdc 100644 --- a/crates/lean_vm/src/execution/tests.rs +++ b/crates/lean_vm/src/execution/tests.rs @@ -24,15 +24,23 @@ fn test_memory_already_set_error() { // Setting same value should work memory.set(0, F::ONE).unwrap(); - // Setting different value should fail - assert!(matches!( - memory.set(0, F::ZERO), - Err(RunnerError::MemoryAlreadySet { - address: 0, - prev_value: F::ONE, - new_value: F::ZERO, - }) - )); + // Setting different value should fail. + // Goldilocks has two redundant representations for each canonical value + // (x and x + ORDER both reduce to x), so it isn't `StructuralPartialEq` + // and can't be matched on directly — compare the fields explicitly instead. 
+ let err = memory.set(0, F::ZERO).unwrap_err(); + match err { + RunnerError::MemoryAlreadySet { + address, + prev_value, + new_value, + } => { + assert_eq!(address, 0); + assert_eq!(prev_value, F::ONE); + assert_eq!(new_value, F::ZERO); + } + other => panic!("unexpected error variant: {other:?}"), + } } #[test] diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 53455065..a8362300 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -2,7 +2,7 @@ use super::Operation; use super::operands::{MemOrConstant, MemOrFpOrConstant}; -use crate::POSEIDON16_NAME; +use crate::POSEIDON8_NAME; use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; @@ -63,21 +63,21 @@ pub struct PrecompileArgs { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PrecompileCompTimeArgs { - Poseidon16, + Poseidon8, ExtensionOp { size: S, mode: ExtensionOpMode }, } impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { - Self::Poseidon16 => Table::poseidon16(), + Self::Poseidon8 => Table::poseidon8(), Self::ExtensionOp { .. 
} => Table::extension_op(), } } pub fn map_size(self, f: impl FnOnce(S) -> T) -> PrecompileCompTimeArgs { match self { - Self::Poseidon16 => PrecompileCompTimeArgs::Poseidon16, + Self::Poseidon8 => PrecompileCompTimeArgs::Poseidon8, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -233,8 +233,8 @@ impl Display for PrecompileArgs { data, } = self; match data { - PrecompileCompTimeArgs::Poseidon16 => { - write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})") + PrecompileCompTimeArgs::Poseidon8 => { + write!(f, "{POSEIDON8_NAME}({arg_0}, {arg_1}, {res})") } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) diff --git a/crates/lean_vm/src/tables/extension_op/air.rs b/crates/lean_vm/src/tables/extension_op/air.rs index 2f39a668..16c11daa 100644 --- a/crates/lean_vm/src/tables/extension_op/air.rs +++ b/crates/lean_vm/src/tables/extension_op/air.rs @@ -1,11 +1,12 @@ use crate::{ - EF, EXT_OP_FLAG_ADD, EXT_OP_FLAG_IS_BE, EXT_OP_FLAG_MUL, EXT_OP_FLAG_POLY_EQ, ExtraDataForBuses, - eval_virtual_bus_column, + DIMENSION, EF, EXT_OP_FLAG_ADD, EXT_OP_FLAG_IS_BE, EXT_OP_FLAG_MUL, EXT_OP_FLAG_POLY_EQ, + ExtraDataForBuses, eval_virtual_bus_column, tables::extension_op::{EXT_OP_LEN_MULTIPLIER, ExtensionOpPrecompile}, }; use backend::*; -//0..5 columns (AIR, 29 total) +// ---------- Column layout (cubic extension, DIMENSION = 3) ---------- +// 0..9: flags / indices (9 scalar cols) pub(super) const COL_IS_BE: usize = 0; pub(super) const COL_START: usize = 1; pub(super) const COL_FLAG_ADD: usize = 2; @@ -16,39 +17,60 @@ pub(super) const COL_IDX_A: usize = 6; pub(super) const COL_IDX_B: usize = 7; pub(super) const COL_IDX_RES: usize = 8; -/// value_a coordinates (5 columns) +// 9..12: value_a coordinates (3 cols) pub(super) const COL_VA: usize = 9; -/// value_b coordinates (5 columns) -pub(super) const COL_VB: usize = 14; -/// result coordinates (5 columns). 
-pub(super) const COL_VRES: usize = 19; -/// computation coordinates (5 columns) -pub(super) const COL_COMP: usize = 24; - -// Virtual columns (not explicitely in AIR) -pub(super) const COL_ACTIVATION_FLAG: usize = 29; -pub(super) const COL_AUX_EXTENSION_OP: usize = 30; - -use backend::quintic_extension::extension::quintic_mul; - +// 12..15: value_b coordinates (3 cols) +pub(super) const COL_VB: usize = 12; +// 15..18: result coordinates (3 cols) +pub(super) const COL_VRES: usize = 15; +// 18..21: computation coordinates (3 cols) +pub(super) const COL_COMP: usize = 18; + +// Virtual columns (not materialized) +pub(super) const COL_ACTIVATION_FLAG: usize = 21; +pub(super) const COL_AUX_EXTENSION_OP: usize = 22; + +pub(super) const AIR_N_COLUMNS: usize = 21; + +// ---------- Cubic multiplication gate (`F[X] / (X^3 - X - 1)`, so `X^3 = X + 1`) ---------- +// +// (a0 + a1·X + a2·X^2)·(b0 + b1·X + b2·X^2), reduced: +// c0 = a0·b0 + a1·b2 + a2·b1 +// c1 = a0·b1 + a1·b0 + a1·b2 + a2·b1 + a2·b2 +// c2 = a0·b2 + a1·b1 + a2·b0 + a2·b2 #[inline] -fn quintic_mul_air(a: &[T; 5], b: &[T; 5]) -> [T; 5] { - quintic_mul(a, b, |x, y| { - x[0] * y[0] + x[1] * y[1] + x[2] * y[2] + x[3] * y[3] + x[4] * y[4] - }) +fn cubic_mul_air(a: &[T; 3], b: &[T; 3]) -> [T; 3] { + let a1b2 = a[1] * b[2]; + let a2b1 = a[2] * b[1]; + let a2b2 = a[2] * b[2]; + [ + a[0] * b[0] + a1b2 + a2b1, + a[0] * b[1] + a[1] * b[0] + a1b2 + a2b1 + a2b2, + a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + a2b2, + ] } impl Air for ExtensionOpPrecompile { type ExtraData = ExtraDataForBuses; fn n_columns(&self) -> usize { - 29 + AIR_N_COLUMNS } fn degree_air(&self) -> usize { + // The poly_eq gate has degree 6: `va * is_ee` (2) times `vb` (1) gives 3 inside the + // first cubic_mul, the second cubic_mul by `comp_down_or_one` (2) gives 5, and the + // `flag_poly_eq` selector adds 1, so 6 is required rather than a loose upper bound. 
6 } fn n_constraints(&self) -> usize { - 33 + // 5 boolean gates + // + 3 * flag_add + // + 3 * flag_mul + // + 3 * flag_poly_eq + // + 3 * start (vres vs comp) + // + 6 transition gates (len/is_be/flags/idx transitions) + 1 start-row-length + // + 1 bus (if BUS) + 5 + 3 + 3 + 3 + 3 + 7 + BUS as usize } fn down_column_indexes(&self) -> Vec { vec![ @@ -63,8 +85,6 @@ impl Air for ExtensionOpPrecompile { COL_COMP, COL_COMP + 1, COL_COMP + 2, - COL_COMP + 3, - COL_COMP + 4, ] } @@ -82,10 +102,10 @@ impl Air for ExtensionOpPrecompile { let idx_a = up[COL_IDX_A]; let idx_b = up[COL_IDX_B]; - let va: [AB::IF; 5] = std::array::from_fn(|k| up[COL_VA + k]); - let vb: [AB::IF; 5] = std::array::from_fn(|k| up[COL_VB + k]); - let vres: [AB::IF; 5] = std::array::from_fn(|k| up[COL_VRES + k]); - let comp: [AB::IF; 5] = std::array::from_fn(|k| up[COL_COMP + k]); + let va: [AB::IF; 3] = std::array::from_fn(|k| up[COL_VA + k]); + let vb: [AB::IF; 3] = std::array::from_fn(|k| up[COL_VB + k]); + let vres: [AB::IF; 3] = std::array::from_fn(|k| up[COL_VRES + k]); + let comp: [AB::IF; 3] = std::array::from_fn(|k| up[COL_COMP + k]); let start_down = down[0]; // COL_START let is_be_down = down[1]; // COL_IS_BE @@ -95,7 +115,7 @@ impl Air for ExtensionOpPrecompile { let flag_poly_eq_down = down[5]; // COL_FLAG_POLY_EQ let idx_a_down = down[6]; // COL_IDX_A let idx_b_down = down[7]; // COL_IDX_B - let comp_down: [AB::IF; 5] = std::array::from_fn(|k| down[8 + k]); // COL_COMP+0..5 + let comp_down: [AB::IF; 3] = std::array::from_fn(|k| down[8 + k]); // COL_COMP+0..3 let active = flag_add + flag_mul + flag_poly_eq; let activation_flag = start * active; @@ -122,9 +142,11 @@ impl Air for ExtensionOpPrecompile { let is_ee = -(is_be - AB::F::ONE); let not_start_down = -(start_down - AB::F::ONE); - let va_f_or_ef: [AB::IF; 5] = std::array::from_fn(|k| if k == 0 { va[0] } else { va[k] * is_ee }); + // For "base-extension" ops, value_a is a base-field scalar embedded into EF as + // `(va[0], 0, 0)`: 
zero the upper coordinates when `is_be` is 1. + let va_f_or_ef: [AB::IF; 3] = std::array::from_fn(|k| if k == 0 { va[0] } else { va[k] * is_ee }); - let comp_tail: [AB::IF; 5] = std::array::from_fn(|k| comp_down[k] * not_start_down); + let comp_tail: [AB::IF; 3] = std::array::from_fn(|k| comp_down[k] * not_start_down); builder.assert_bool(is_be); builder.assert_bool(start); @@ -132,33 +154,35 @@ impl Air for ExtensionOpPrecompile { builder.assert_bool(flag_mul); builder.assert_bool(flag_poly_eq); - for k in 0..5 { + for k in 0..3 { builder.assert_zero((comp[k] - (va_f_or_ef[k] + vb[k] + comp_tail[k])) * flag_add); } - let va_times_vb = quintic_mul_air(&va_f_or_ef, &vb); + let va_times_vb = cubic_mul_air(&va_f_or_ef, &vb); - for k in 0..5 { + for k in 0..3 { builder.assert_zero((comp[k] - (va_times_vb[k] + comp_tail[k])) * flag_mul); } - let poly_eq_val: [AB::IF; 5] = std::array::from_fn(|k| { + // poly_eq per element: `2 a b - a - b + 1` (constant coord only gets +1), + // accumulated via multiplication. 
+ let poly_eq_val: [AB::IF; 3] = std::array::from_fn(|k| { let base = va_times_vb[k].double() - va_f_or_ef[k] - vb[k]; if k == 0 { base + AB::F::ONE } else { base } }); - let comp_down_or_one: [AB::IF; 5] = std::array::from_fn(|k| { + let comp_down_or_one: [AB::IF; 3] = std::array::from_fn(|k| { if k == 0 { comp_down[0] * not_start_down + start_down } else { comp_down[k] * not_start_down } }); - let poly_eq_result = quintic_mul_air(&poly_eq_val, &comp_down_or_one); - for k in 0..5 { + let poly_eq_result = cubic_mul_air(&poly_eq_val, &comp_down_or_one); + for k in 0..3 { builder.assert_zero((comp[k] - poly_eq_result[k]) * flag_poly_eq); } - for k in 0..5 { + for k in 0..3 { builder.assert_zero((comp[k] - vres[k]) * start); } @@ -167,9 +191,9 @@ impl Air for ExtensionOpPrecompile { builder.assert_zero(not_start_down * (flag_add - flag_add_down)); builder.assert_zero(not_start_down * (flag_mul - flag_mul_down)); builder.assert_zero(not_start_down * (flag_poly_eq - flag_poly_eq_down)); - let a_increment = is_be + is_ee * AB::F::from_usize(crate::DIMENSION); + let a_increment = is_be + is_ee * AB::F::from_usize(DIMENSION); builder.assert_zero(not_start_down * (idx_a_down - idx_a - a_increment)); - builder.assert_zero(not_start_down * (idx_b_down - idx_b - AB::F::from_usize(crate::DIMENSION))); + builder.assert_zero(not_start_down * (idx_b_down - idx_b - AB::F::from_usize(DIMENSION))); builder.assert_zero(start_down * (len - AB::F::ONE)); } diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index 406fa724..397b0ed8 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -6,7 +6,7 @@ use air::*; mod exec; pub use exec::fill_trace_extension_op; -// domain separation: Poseidon16=1, Poseidon24= 2 or 3 or 4, ExtensionOp>=8 +// domain separation: Poseidon8=1, Poseidon24= 2 or 3 or 4, ExtensionOp>=8 /// Extension op PRECOMPILE_DATA bit-field encoding: /// aux = 
4*is_be + 8*flag_add + 16*flag_mul + 32*flag_poly_eq + 64*len pub(crate) const EXT_OP_FLAG_IS_BE: usize = 4; diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index 3010d39f..bd191816 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -1,8 +1,8 @@ mod extension_op; pub use extension_op::*; -mod poseidon_16; -pub use poseidon_16::*; +mod poseidon_8; +pub use poseidon_8::*; mod table_enum; pub use table_enum::*; diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs deleted file mode 100644 index 4aac4b39..00000000 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ /dev/null @@ -1,396 +0,0 @@ -use std::any::TypeId; - -use crate::*; -use crate::{execution::memory::MemoryAccess, tables::poseidon_16::trace_gen::generate_trace_rows_for_perm}; -use backend::*; -use utils::{ToUsize, poseidon16_compress}; - -/// Dispatch `mds_circ_16` through concrete types. -/// For `SymbolicExpression` we use the dense form so the zkDSL generator can -/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. -#[inline(always)] -fn mds_air(state: &mut [A; WIDTH]) { - if TypeId::of::() == TypeId::of::>() { - dense_mat_vec_air(mds_dense_16(), state); - return; - } - macro_rules! 
dispatch { - ($t:ty) => { - if TypeId::of::() == TypeId::of::<$t>() { - mds_circ_16::<$t>(unsafe { &mut *(state as *mut [A; WIDTH] as *mut [$t; WIDTH]) }); - return; - } - }; - } - dispatch!(F); - dispatch!(EF); - dispatch!(FPacking); - dispatch!(EFPacking); - unreachable!() -} - -fn mds_dense_16() -> &'static [[F; 16]; 16] { - use std::sync::OnceLock; - static MAT: OnceLock<[[KoalaBear; 16]; 16]> = OnceLock::new(); - MAT.get_or_init(|| { - let cols: [[F; 16]; 16] = std::array::from_fn(|j| { - let mut e = [F::ZERO; 16]; - e[j] = F::ONE; - mds_circ_16(&mut e); - e - }); - std::array::from_fn(|i| std::array::from_fn(|j| cols[j][i])) - }) -} - -/// Add a `KoalaBear` constant to any AIR type. -#[inline(always)] -fn add_kb(a: &mut A, value: F) { - macro_rules! dispatch { - ($t:ty) => { - if TypeId::of::() == TypeId::of::<$t>() { - *unsafe { &mut *(a as *mut A as *mut $t) } += value; - return; - } - }; - } - dispatch!(F); - dispatch!(EF); - dispatch!(FPacking); - dispatch!(EFPacking); - dispatch!(SymbolicExpression); - unreachable!() -} - -/// Multiply any AIR type by a `KoalaBear` constant. -#[inline(always)] -fn mul_kb(a: A, value: F) -> A { - macro_rules! 
dispatch { - ($t:ty) => { - if TypeId::of::() == TypeId::of::<$t>() { - let r = unsafe { std::ptr::read(&a as *const A as *const $t) } * value; - return unsafe { std::ptr::read(&r as *const $t as *const A) }; - } - }; - } - dispatch!(F); - dispatch!(EF); - dispatch!(FPacking); - dispatch!(EFPacking); - dispatch!(SymbolicExpression); - unreachable!() -} - -mod trace_gen; -pub use trace_gen::fill_trace_poseidon_16; - -pub(super) const WIDTH: usize = 16; -const INITIAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS; -const PARTIAL_ROUNDS: usize = POSEIDON1_PARTIAL_ROUNDS; -const FINAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS; - -pub const POSEIDON_PRECOMPILE_DATA: usize = 1; // domain separation: Poseidon16=1, Poseidon24=2 or 3 or 4, ExtensionOp>=8 - -pub const POSEIDON_16_COL_FLAG: ColIndex = 0; -pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = 1; -pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 2; -pub const POSEIDON_16_COL_INDEX_INPUT_RES: ColIndex = 3; -pub const POSEIDON_16_COL_INPUT_START: ColIndex = 4; -pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; - -pub const POSEIDON16_NAME: &str = "poseidon16_compress"; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon16Precompile; - -impl TableT for Poseidon16Precompile { - fn name(&self) -> &'static str { - POSEIDON16_NAME - } - - fn table(&self) -> Table { - Table::poseidon16() - } - - fn lookups(&self) -> Vec { - vec![ - LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_LEFT, - values: (POSEIDON_16_COL_INPUT_START..POSEIDON_16_COL_INPUT_START + DIGEST_LEN).collect(), - }, - LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_RIGHT, - values: (POSEIDON_16_COL_INPUT_START + DIGEST_LEN..POSEIDON_16_COL_INPUT_START + DIGEST_LEN * 2) - .collect(), - }, - LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_RES, - values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), - }, - ] - } 
- - fn bus(&self) -> Bus { - Bus { - direction: BusDirection::Pull, - selector: POSEIDON_16_COL_FLAG, - data: vec![ - BusData::Constant(POSEIDON_PRECOMPILE_DATA), - BusData::Column(POSEIDON_16_COL_INDEX_INPUT_LEFT), - BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RIGHT), - BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RES), - ], - } - } - - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { - let mut row = vec![F::ZERO; num_cols_poseidon_16()]; - let ptrs: Vec<*mut F> = (0..num_cols_poseidon_16()) - .map(|i| unsafe { row.as_mut_ptr().add(i) }) - .collect(); - - let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; - perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); - *perm.flag = F::ZERO; - *perm.index_a = F::from_usize(zero_vec_ptr); - *perm.index_b = F::from_usize(zero_vec_ptr); - *perm.index_res = F::from_usize(null_hash_ptr); - - generate_trace_rows_for_perm(perm); - row - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - index_res_a: F, - _: PrecompileCompTimeArgs, - ctx: &mut InstructionContext<'_, M>, - ) -> Result<(), RunnerError> { - let trace = ctx.traces.get_mut(&self.table()).unwrap(); - - let arg0 = ctx.memory.get_slice(arg_a.to_usize(), DIGEST_LEN)?; - let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST_LEN)?; - - let mut input = [F::ZERO; DIGEST_LEN * 2]; - input[..DIGEST_LEN].copy_from_slice(&arg0); - input[DIGEST_LEN..].copy_from_slice(&arg1); - - let output = poseidon16_compress(input); - - let res_a: [F; DIGEST_LEN] = output[..DIGEST_LEN].try_into().unwrap(); - - ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; - - trace.columns[POSEIDON_16_COL_FLAG].push(F::ONE); - trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); - trace.columns[POSEIDON_16_COL_INDEX_INPUT_RIGHT].push(arg_b); - trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES].push(index_res_a); - for (i, value) in input.iter().enumerate() { - 
trace.columns[POSEIDON_16_COL_INPUT_START + i].push(*value); - } - - // the rest of the trace is filled at the end of the execution (to get parallelism + SIMD) - - Ok(()) - } -} - -impl Air for Poseidon16Precompile { - type ExtraData = ExtraDataForBuses; - fn n_columns(&self) -> usize { - num_cols_poseidon_16() - } - fn degree_air(&self) -> usize { - 3 - } - fn down_column_indexes(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - BUS as usize + 140 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let cols: Poseidon1Cols16 = { - let up = builder.up(); - let (prefix, shorts, suffix) = unsafe { up.align_to::>() }; - debug_assert!(prefix.is_empty(), "Alignment should match"); - debug_assert!(suffix.is_empty(), "Alignment should match"); - debug_assert_eq!(shorts.len(), 1); - unsafe { std::ptr::read(&shorts[0]) } - }; - - // Bus data: [POSEIDON_PRECOMPILE_DATA (constant), a, b, res] - if BUS { - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - cols.flag, - &[ - AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), - cols.index_a, - cols.index_b, - cols.index_res, - ], - )); - } else { - builder.declare_values(std::slice::from_ref(&cols.flag)); - builder.declare_values(&[ - AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), - cols.index_a, - cols.index_b, - cols.index_res, - ]); - } - - builder.assert_bool(cols.flag); - - eval_poseidon(builder, &cols) - } -} - -#[repr(C)] -#[derive(Debug)] -pub(super) struct Poseidon1Cols16 { - pub flag: T, - pub index_a: T, - pub index_b: T, - pub index_res: T, - - pub inputs: [T; WIDTH], - pub beginning_full_rounds: [[T; WIDTH]; INITIAL_FULL_ROUNDS], - pub partial_rounds: [T; PARTIAL_ROUNDS], - pub ending_full_rounds: [[T; WIDTH]; FINAL_FULL_ROUNDS - 1], - pub outputs: [T; WIDTH / 2], -} - -fn eval_poseidon(builder: &mut AB, local: &Poseidon1Cols16) { - let mut state: [_; WIDTH] = local.inputs; - - let initial_constants = poseidon1_initial_constants(); - for round in 
0..INITIAL_FULL_ROUNDS { - eval_full_round( - &mut state, - &local.beginning_full_rounds[round], - &initial_constants[round], - builder, - ); - } - - // --- Sparse partial rounds --- - // Transition: add first-round constants, multiply by m_i - let frc = poseidon1_sparse_first_round_constants(); - for (s, &c) in state.iter_mut().zip(frc.iter()) { - add_kb(s, c); - } - dense_mat_vec_air(poseidon1_sparse_m_i(), &mut state); - - let first_rows = poseidon1_sparse_first_row(); - let v_vecs = poseidon1_sparse_v(); - let scalar_rc = poseidon1_sparse_scalar_round_constants(); - for round in 0..PARTIAL_ROUNDS { - // S-box on state[0] - state[0] = state[0].cube(); - builder.assert_eq(state[0], local.partial_rounds[round]); - state[0] = local.partial_rounds[round]; - // Scalar round constant (not on last round) - if round < PARTIAL_ROUNDS - 1 { - add_kb(&mut state[0], scalar_rc[round]); - } - // Sparse matrix: new_s0 = dot(first_row, state), state[i] += old_s0 * v[i-1] - sparse_mat_air(&mut state, &first_rows[round], &v_vecs[round]); - } - - let final_constants = poseidon1_final_constants(); - for round in 0..FINAL_FULL_ROUNDS - 1 { - eval_full_round( - &mut state, - &local.ending_full_rounds[round], - &final_constants[round], - builder, - ); - } - - eval_last_full_round( - &local.inputs, - &mut state, - &local.outputs, - &final_constants[FINAL_FULL_ROUNDS - 1], - builder, - ); -} - -pub const fn num_cols_poseidon_16() -> usize { - size_of::>() -} - -#[inline] -fn eval_full_round( - state: &mut [AB::IF; WIDTH], - post_full_round: &[AB::IF; WIDTH], - round_constants: &[F; WIDTH], - builder: &mut AB, -) { - for (s, r) in state.iter_mut().zip(round_constants.iter()) { - add_kb(s, *r); - *s = s.cube(); - } - mds_air(state); - for (state_i, post_i) in state.iter_mut().zip(post_full_round) { - builder.assert_eq(*state_i, *post_i); - *state_i = *post_i; - } -} - -#[inline] -fn eval_last_full_round( - initial_state: &[AB::IF; WIDTH], - state: &mut [AB::IF; WIDTH], - outputs: 
&[AB::IF; WIDTH / 2], - round_constants: &[F; WIDTH], - builder: &mut AB, -) { - for (s, r) in state.iter_mut().zip(round_constants.iter()) { - add_kb(s, *r); - *s = s.cube(); - } - mds_air(state); - // add inputs to outputs (for compression) - for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { - *state_i += *init_state_i; - } - for (state_i, output_i) in state.iter_mut().zip(outputs) { - builder.assert_eq(*state_i, *output_i); - *state_i = *output_i; - } -} - -#[inline] -fn dense_mat_vec_air(mat: &[[F; 16]; 16], state: &mut [A; WIDTH]) { - let input = *state; - for i in 0..WIDTH { - let mut acc = A::ZERO; - for j in 0..WIDTH { - acc += mul_kb(input[j], mat[i][j]); - } - state[i] = acc; - } -} - -#[inline] -fn sparse_mat_air( - state: &mut [A; WIDTH], - first_row: &[F; WIDTH], - v: &[F; WIDTH], -) { - let old_s0 = state[0]; - let mut new_s0 = A::ZERO; - for j in 0..WIDTH { - new_s0 += mul_kb(state[j], first_row[j]); - } - state[0] = new_s0; - for i in 1..WIDTH { - state[i] += mul_kb(old_s0, v[i - 1]); - } -} diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs deleted file mode 100644 index 1dce2c4e..00000000 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ /dev/null @@ -1,145 +0,0 @@ -use tracing::instrument; - -use crate::{ - F, - tables::{Poseidon1Cols16, WIDTH}, -}; -use backend::*; - -#[instrument(name = "generate Poseidon16 AIR trace", skip_all)] -pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { - let n = trace.iter().map(|col| col.len()).max().unwrap(); - for col in trace.iter_mut() { - if col.len() != n { - col.resize(n, F::ZERO); - } - } - - let m = n - (n % packing_width::()); - let trace_packed: Vec<_> = trace.iter().map(|col| FPacking::::pack_slice(&col[..m])).collect(); - - // fill the packed rows - (0..m / packing_width::()).into_par_iter().for_each(|i| { - let ptrs: Vec<*mut FPacking> = trace_packed - .iter() - .map(|col| unsafe { (col.as_ptr() as *mut 
FPacking).add(i) }) - .collect(); - let perm: &mut Poseidon1Cols16<&mut FPacking> = - unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut FPacking>) }; - - generate_trace_rows_for_perm(perm); - }); - - // fill the remaining rows (non packed) - for i in m..n { - let ptrs: Vec<*mut F> = trace - .iter() - .map(|col| unsafe { (col.as_ptr() as *mut F).add(i) }) - .collect(); - let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; - generate_trace_rows_for_perm(perm); - } -} - -pub(super) fn generate_trace_rows_for_perm + Copy>(perm: &mut Poseidon1Cols16<&mut F>) { - let inputs: [F; WIDTH] = std::array::from_fn(|i| *perm.inputs[i]); - let mut state = inputs; - - // No initial linear layer for Poseidon1 (unlike Poseidon2) - - for (full_round, constants) in perm - .beginning_full_rounds - .iter_mut() - .zip(poseidon1_initial_constants().iter()) - { - generate_1_full_round(&mut state, full_round, constants); - } - - // --- Sparse partial rounds --- - // Transition: add first-round constants, multiply by m_i - let frc = poseidon1_sparse_first_round_constants(); - for (s, &c) in state.iter_mut().zip(frc.iter()) { - *s += c; - } - let m_i = poseidon1_sparse_m_i(); - let input_for_mi = state; - for i in 0..WIDTH { - let row: [F; WIDTH] = m_i[i].map(F::from); - state[i] = F::dot_product(&input_for_mi, &row); - } - - let first_rows = poseidon1_sparse_first_row(); - let v_vecs = poseidon1_sparse_v(); - let scalar_rc = poseidon1_sparse_scalar_round_constants(); - let n_partial = perm.partial_rounds.len(); - for round in 0..n_partial { - // S-box on state[0] - state[0] = state[0].cube(); - *perm.partial_rounds[round] = state[0]; - // Scalar round constant (not on last round) - if round < n_partial - 1 { - state[0] += scalar_rc[round]; - } - // Sparse matrix - let old_s0 = state[0]; - let row: [F; WIDTH] = first_rows[round].map(F::from); - let new_s0 = F::dot_product(&state, &row); - state[0] = new_s0; - for i in 1..WIDTH { 
- state[i] += old_s0 * v_vecs[round][i - 1]; - } - } - - let n_ending_full_rounds = perm.ending_full_rounds.len(); - for (full_round, constants) in perm - .ending_full_rounds - .iter_mut() - .zip(poseidon1_final_constants().iter()) - { - generate_1_full_round(&mut state, full_round, constants); - } - - // Last full round with compression (add inputs to outputs) - generate_last_1_full_round( - &mut state, - &inputs, - &mut perm.outputs, - &poseidon1_final_constants()[n_ending_full_rounds], - ); -} - -#[inline] -fn generate_1_full_round + Copy>( - state: &mut [F; WIDTH], - post_full_round: &mut [&mut F; WIDTH], - round_constants: &[KoalaBear; WIDTH], -) { - for (state_i, const_i) in state.iter_mut().zip(round_constants) { - *state_i += *const_i; - *state_i = state_i.cube(); - } - mds_circ_16(state); - - post_full_round.iter_mut().zip(*state).for_each(|(post, x)| { - **post = x; - }); -} - -#[inline] -fn generate_last_1_full_round + Copy>( - state: &mut [F; WIDTH], - inputs: &[F; WIDTH], - outputs: &mut [&mut F; WIDTH / 2], - round_constants: &[KoalaBear; WIDTH], -) { - for (state_i, const_i) in state.iter_mut().zip(round_constants) { - *state_i += *const_i; - *state_i = state_i.cube(); - } - mds_circ_16(state); - - // Add inputs to outputs (compression) - for ((output, state_i), &input_i) in outputs.iter_mut().zip(state).zip(inputs) { - **output = *state_i + input_i; - } -} diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs new file mode 100644 index 00000000..f25ada57 --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -0,0 +1,199 @@ +use crate::*; +use crate::execution::memory::MemoryAccess; +use backend::*; +use utils::{ToUsize, poseidon8_compress}; + +// TODO(goldilocks-migration): this AIR is currently a soundness stub. 
+// +// The KoalaBear predecessor implemented Poseidon1 width-16 as an AIR with a +// sparse-matrix factorization for the partial rounds, an `x^3` S-box (degree-3 +// compliant with `degree_air = 3`), and tight column packing. +// +// Goldilocks Poseidon1 is width-8 with `x^7` S-box and 22 partial rounds; the +// sbox alone needs witness decomposition (`y2 = x*x`, `y4 = y2*y2`, `y7 = x*y2*y4`) +// to fit under degree 3. That's a fresh column layout and gate algebra — out of +// scope for this migration pass. +// +// The stub below keeps the I/O columns (flag, index_a, index_b, index_res, +// inputs[8], outputs[4]) and the memory lookups + bus, so callers from +// `execute` and the verifier still wire up. The permutation itself is *not* +// constrained — the prover commits the correct `poseidon8_compress` output +// via trace generation, and the verifier accepts it because no gate rejects a +// mismatch. **This is unsound and must be replaced** before shipping. + +mod trace_gen; +pub use trace_gen::fill_trace_poseidon_8; + +pub(super) const WIDTH: usize = 8; +pub(super) const DIGEST: usize = DIGEST_LEN; // 4 + +pub const POSEIDON_PRECOMPILE_DATA: usize = 1; // domain separation: Poseidon8=1, ExtensionOp>=8 + +pub const POSEIDON_8_COL_FLAG: ColIndex = 0; +pub const POSEIDON_8_COL_INDEX_INPUT_LEFT: ColIndex = 1; +pub const POSEIDON_8_COL_INDEX_INPUT_RIGHT: ColIndex = 2; +pub const POSEIDON_8_COL_INDEX_INPUT_RES: ColIndex = 3; +pub const POSEIDON_8_COL_INPUT_START: ColIndex = 4; +pub const POSEIDON_8_COL_OUTPUT_START: ColIndex = POSEIDON_8_COL_INPUT_START + WIDTH; + +// Legacy aliases used by other tables/compiler code that still refers to the +// KoalaBear-era names. Keeping them as shims keeps the diff small. 
+pub const POSEIDON_16_COL_FLAG: ColIndex = POSEIDON_8_COL_FLAG; +pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = POSEIDON_8_COL_INDEX_INPUT_LEFT; +pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = POSEIDON_8_COL_INDEX_INPUT_RIGHT; +pub const POSEIDON_16_COL_INDEX_INPUT_RES: ColIndex = POSEIDON_8_COL_INDEX_INPUT_RES; +pub const POSEIDON_16_COL_INPUT_START: ColIndex = POSEIDON_8_COL_INPUT_START; + +pub const POSEIDON8_NAME: &str = "poseidon8_compress"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon8Precompile; + +impl TableT for Poseidon8Precompile { + fn name(&self) -> &'static str { + POSEIDON8_NAME + } + + fn table(&self) -> Table { + Table::poseidon8() + } + + fn lookups(&self) -> Vec { + vec![ + LookupIntoMemory { + index: POSEIDON_8_COL_INDEX_INPUT_LEFT, + values: (POSEIDON_8_COL_INPUT_START..POSEIDON_8_COL_INPUT_START + DIGEST).collect(), + }, + LookupIntoMemory { + index: POSEIDON_8_COL_INDEX_INPUT_RIGHT, + values: (POSEIDON_8_COL_INPUT_START + DIGEST..POSEIDON_8_COL_INPUT_START + DIGEST * 2) + .collect(), + }, + LookupIntoMemory { + index: POSEIDON_8_COL_INDEX_INPUT_RES, + values: (POSEIDON_8_COL_OUTPUT_START..POSEIDON_8_COL_OUTPUT_START + DIGEST).collect(), + }, + ] + } + + fn bus(&self) -> Bus { + Bus { + direction: BusDirection::Pull, + selector: POSEIDON_8_COL_FLAG, + data: vec![ + BusData::Constant(POSEIDON_PRECOMPILE_DATA), + BusData::Column(POSEIDON_8_COL_INDEX_INPUT_LEFT), + BusData::Column(POSEIDON_8_COL_INDEX_INPUT_RIGHT), + BusData::Column(POSEIDON_8_COL_INDEX_INPUT_RES), + ], + } + } + + fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { + let mut row = vec![F::ZERO; num_cols_poseidon_8()]; + row[POSEIDON_8_COL_FLAG] = F::ZERO; + row[POSEIDON_8_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); + row[POSEIDON_8_COL_INDEX_INPUT_RIGHT] = F::from_usize(zero_vec_ptr); + row[POSEIDON_8_COL_INDEX_INPUT_RES] = F::from_usize(null_hash_ptr); + // inputs stay zero. 
outputs = poseidon8_compress(0) — truncated to DIGEST. + let out = poseidon8_compress([F::ZERO; WIDTH]); + for (i, v) in out.iter().enumerate() { + row[POSEIDON_8_COL_OUTPUT_START + i] = *v; + } + row + } + + #[inline(always)] + fn execute( + &self, + arg_a: F, + arg_b: F, + index_res_a: F, + _: PrecompileCompTimeArgs, + ctx: &mut InstructionContext<'_, M>, + ) -> Result<(), RunnerError> { + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + + let arg0 = ctx.memory.get_slice(arg_a.to_usize(), DIGEST)?; + let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST)?; + + let mut input = [F::ZERO; WIDTH]; + input[..DIGEST].copy_from_slice(&arg0); + input[DIGEST..].copy_from_slice(&arg1); + + let output = poseidon8_compress(input); + + let res_a: [F; DIGEST] = output; + + ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; + + trace.columns[POSEIDON_8_COL_FLAG].push(F::ONE); + trace.columns[POSEIDON_8_COL_INDEX_INPUT_LEFT].push(arg_a); + trace.columns[POSEIDON_8_COL_INDEX_INPUT_RIGHT].push(arg_b); + trace.columns[POSEIDON_8_COL_INDEX_INPUT_RES].push(index_res_a); + for (i, value) in input.iter().enumerate() { + trace.columns[POSEIDON_8_COL_INPUT_START + i].push(*value); + } + for (i, value) in output.iter().enumerate() { + trace.columns[POSEIDON_8_COL_OUTPUT_START + i].push(*value); + } + + Ok(()) + } +} + +impl Air for Poseidon8Precompile { + type ExtraData = ExtraDataForBuses; + fn n_columns(&self) -> usize { + num_cols_poseidon_8() + } + fn degree_air(&self) -> usize { + 3 + } + fn down_column_indexes(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + // Only the boolean flag gate, plus the bus / declared values. 
+ 1 + BUS as usize + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up(); + let flag = up[POSEIDON_8_COL_FLAG]; + let index_a = up[POSEIDON_8_COL_INDEX_INPUT_LEFT]; + let index_b = up[POSEIDON_8_COL_INDEX_INPUT_RIGHT]; + let index_res = up[POSEIDON_8_COL_INDEX_INPUT_RES]; + + if BUS { + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + flag, + &[ + AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), + index_a, + index_b, + index_res, + ], + )); + } else { + builder.declare_values(std::slice::from_ref(&flag)); + builder.declare_values(&[ + AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), + index_a, + index_b, + index_res, + ]); + } + + builder.assert_bool(flag); + + // TODO(goldilocks-migration): constrain outputs to equal + // `poseidon8_compress([inputs[0..8]])`. Currently unconstrained — the + // prover is trusted to fill correct outputs via trace generation. + } +} + +pub const fn num_cols_poseidon_8() -> usize { + // flag + 3 indices + 8 inputs + 4 outputs + 4 + WIDTH + DIGEST_LEN +} diff --git a/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs new file mode 100644 index 00000000..68d254d2 --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs @@ -0,0 +1,18 @@ +use tracing::instrument; + +use crate::F; +use backend::PrimeCharacteristicRing; + +// TODO(goldilocks-migration): once the Goldilocks Poseidon1-8 AIR has real +// per-round witness columns, this is where we'll fill them. Today the stub AIR +// has no per-round columns, so the `execute` path already writes every column +// it needs. 
+#[instrument(name = "generate Poseidon8 AIR trace (stub)", skip_all)] +pub fn fill_trace_poseidon_8(trace: &mut [Vec]) { + let n = trace.iter().map(|col| col.len()).max().unwrap_or(0); + for col in trace.iter_mut() { + if col.len() != n { + col.resize(n, F::ZERO); + } + } +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 55be30e2..b9894d7a 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -4,7 +4,7 @@ use crate::execution::memory::MemoryAccess; use crate::*; pub const N_TABLES: usize = 3; -pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; +pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon8()]; pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -12,7 +12,7 @@ pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; pub enum Table { Execution(ExecutionTable), ExtensionOp(ExtensionOpPrecompile), - Poseidon16(Poseidon16Precompile), + Poseidon8(Poseidon8Precompile), } #[macro_export] @@ -21,7 +21,7 @@ macro_rules! delegate_to_inner { ($self:expr, $method:ident $(, $($arg:expr),*)?) => { match $self { Self::ExtensionOp(p) => p.$method($($($arg),*)?), - Self::Poseidon16(p) => p.$method($($($arg),*)?), + Self::Poseidon8(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), } }; @@ -29,7 +29,7 @@ macro_rules! 
delegate_to_inner { ($self:expr => $macro_name:ident) => { match $self { Table::ExtensionOp(p) => $macro_name!(p), - Table::Poseidon16(p) => $macro_name!(p), + Table::Poseidon8(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), } }; @@ -42,8 +42,8 @@ impl Table { pub const fn extension_op() -> Self { Self::ExtensionOp(ExtensionOpPrecompile) } - pub const fn poseidon16() -> Self { - Self::Poseidon16(Poseidon16Precompile) + pub const fn poseidon8() -> Self { + Self::Poseidon8(Poseidon8Precompile) } pub fn embed(&self) -> PF { PF::from_usize(self.index()) diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index ff43dc18..5a4bc280 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -378,7 +378,7 @@ fn all_air_evals_in_zk_dsl() -> String { let mut res = String::new(); res += &air_eval_in_zk_dsl(ExecutionTable:: {}); res += &air_eval_in_zk_dsl(ExtensionOpPrecompile:: {}); - res += &air_eval_in_zk_dsl(Poseidon16Precompile:: {}); + res += &air_eval_in_zk_dsl(Poseidon8Precompile:: {}); res } @@ -386,8 +386,8 @@ const AIR_INNER_VALUES_VAR: &str = "inner_evals"; struct AirCodegenCtx { expr_cache: HashMap, - consts_cache: HashMap, String>, - ef_const_cache: HashMap, + consts_cache: HashMap, String>, + ef_const_cache: HashMap, ctr: Counter, } @@ -401,7 +401,7 @@ impl AirCodegenCtx { } } - fn write_base_constants(&mut self, values: &[u32], res: &mut String) -> String { + fn write_base_constants(&mut self, values: &[u64], res: &mut String) -> String { if let Some(name) = self.consts_cache.get(values) { return name.clone(); } @@ -414,7 +414,7 @@ impl AirCodegenCtx { name } - fn write_embedded_constant(&mut self, c: u32, res: &mut String) -> String { + fn write_embedded_constant(&mut self, c: u64, res: &mut String) -> String { if let Some(name) = self.ef_const_cache.get(&c) { return name.clone(); } @@ -485,7 +485,7 @@ fn eval_air_constraint( res: &mut String, ) -> 
String { let v = match expr { - SymbolicExpression::Constant(c) => ctx.write_embedded_constant(c.as_canonical_u32(), res), + SymbolicExpression::Constant(c) => ctx.write_embedded_constant(c.as_canonical_u64(), res), SymbolicExpression::Variable(v) => format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, v.index), SymbolicExpression::Operation(idx) => { if let Some(v) = ctx.expr_cache.get(&idx) { @@ -545,7 +545,7 @@ fn try_emit_dot_product_be(idx: u32, dest: Option<&str>, ctx: &mut AirCodegenCtx } let (c, expr) = match (mul.lhs, mul.rhs) { (SymbolicExpression::Constant(c), o) | (o, SymbolicExpression::Constant(c)) => { - (c.as_canonical_u32(), o) + (c.as_canonical_u64(), o) } _ => return None, }; @@ -616,11 +616,11 @@ fn eval_air_binary_op( res: &mut String, ) -> String { let c0 = match lhs { - SymbolicExpression::Constant(c) => Some(c.as_canonical_u32()), + SymbolicExpression::Constant(c) => Some(c.as_canonical_u64()), _ => None, }; let c1 = match rhs { - SymbolicExpression::Constant(c) => Some(c.as_canonical_u32()), + SymbolicExpression::Constant(c) => Some(c.as_canonical_u64()), _ => None, }; @@ -708,5 +708,5 @@ fn display_all_air_evals_in_zk_dsl() { #[test] fn display_poseidon_air_in_zk_dsl() { - println!("{}", air_eval_in_zk_dsl(Poseidon16Precompile:: {})); + println!("{}", air_eval_in_zk_dsl(Poseidon8Precompile:: {})); } diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 21c8b180..314edcdf 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -6,7 +6,7 @@ use lean_prover::verify_execution::ProofVerificationDetails; use lean_prover::verify_execution::verify_execution; use lean_vm::*; use tracing::instrument; -use utils::{build_prover_state, get_poseidon16, poseidon_compress_slice, poseidon16_compress_pair}; +use utils::{build_prover_state, get_poseidon8, poseidon_compress_slice, poseidon8_compress_pair}; use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, SIG_SIZE_FE, XmssPublicKey, XmssSignature, 
slot_to_field_elements}; use serde::{Deserialize, Serialize}; @@ -81,7 +81,7 @@ fn build_input_data( // Pad the bytecode claim itself up to DIGEST_LEN let claim_padding = bytecode_claim_output.len().next_multiple_of(DIGEST_LEN) - bytecode_claim_output.len(); data.extend(std::iter::repeat_n(F::ZERO, claim_padding)); - data.extend_from_slice(&poseidon16_compress_pair(bytecode_hash, &SNARK_DOMAIN_SEP)); + data.extend_from_slice(&poseidon8_compress_pair(bytecode_hash, &SNARK_DOMAIN_SEP)); // Round the whole buffer up to DIGEST_LEN so `slice_hash_with_iv` can absorb it chunk by chunk. data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); data @@ -277,7 +277,7 @@ pub fn xmss_aggregate( let final_sumcheck_proof = { // Recover the transcript of the final sumcheck (for bytecode claim reduction) - let mut vs = VerifierState::::new(reduction_prover.into_proof(), get_poseidon16().clone()).unwrap(); + let mut vs = VerifierState::::new(reduction_prover.into_proof(), get_poseidon8().clone()).unwrap(); vs.next_base_scalars_vec(claims_hash.len()).unwrap(); let _: EF = vs.sample(); sumcheck_verify(&mut vs, bytecode_point_n_vars, 2, claimed_sum, None).unwrap(); @@ -443,7 +443,7 @@ pub fn hash_bytecode_claims(claims: &[Evaluation]) -> [F; DIGEST_LEN] { data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); let claim_hash = poseidon_compress_slice(&data, false); - running_hash = poseidon16_compress_pair(&running_hash, &claim_hash); + running_hash = poseidon8_compress_pair(&running_hash, &claim_hash); } running_hash } diff --git a/crates/sub_protocols/src/quotient_gkr.rs b/crates/sub_protocols/src/quotient_gkr.rs index e0f8ccb9..a4779ab3 100644 --- a/crates/sub_protocols/src/quotient_gkr.rs +++ b/crates/sub_protocols/src/quotient_gkr.rs @@ -306,8 +306,8 @@ mod tests { use std::time::Instant; use utils::{build_prover_state, build_verifier_state, init_tracing}; - type F = KoalaBear; - type EF = QuinticExtensionFieldKB; + type F = Goldilocks; + type EF = 
CubicExtensionFieldGL; fn sum_all_quotients(nums: &[F], den: &[EF]) -> EF { nums.par_iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 593e7680..42e00258 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -103,8 +103,8 @@ mod tests { use super::*; - type F = KoalaBear; - type EF = QuinticExtensionFieldKB; + type F = Goldilocks; + type EF = CubicExtensionFieldGL; #[test] fn test_evaluate_as_larger_multilinear_pol() { diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index bde7db7d..02835147 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -1,61 +1,64 @@ use backend::*; use std::sync::OnceLock; -pub type Poseidon16 = Poseidon1KoalaBear16; +pub type Poseidon8 = Poseidon1Goldilocks8; -pub const HALF_FULL_ROUNDS_16: usize = POSEIDON1_HALF_FULL_ROUNDS; -pub const PARTIAL_ROUNDS_16: usize = POSEIDON1_PARTIAL_ROUNDS; +pub const HALF_FULL_ROUNDS_8: usize = POSEIDON1_HALF_FULL_ROUNDS; +pub const PARTIAL_ROUNDS_8: usize = POSEIDON1_PARTIAL_ROUNDS; -static POSEIDON_16_INSTANCE: OnceLock = OnceLock::new(); -static POSEIDON_16_OF_ZERO: OnceLock<[KoalaBear; 8]> = OnceLock::new(); +static POSEIDON_8_INSTANCE: OnceLock = OnceLock::new(); +static POSEIDON_8_OF_ZERO: OnceLock<[Goldilocks; 4]> = OnceLock::new(); #[inline(always)] -pub fn get_poseidon16() -> &'static Poseidon16 { - POSEIDON_16_INSTANCE.get_or_init(default_koalabear_poseidon1_16) +pub fn get_poseidon8() -> &'static Poseidon8 { + POSEIDON_8_INSTANCE.get_or_init(default_goldilocks_poseidon1_8) } #[inline(always)] -pub fn get_poseidon_16_of_zero() -> &'static [KoalaBear; 8] { - POSEIDON_16_OF_ZERO.get_or_init(|| poseidon16_compress([KoalaBear::default(); 16])) +pub fn get_poseidon_8_of_zero() -> &'static [Goldilocks; 4] { + POSEIDON_8_OF_ZERO.get_or_init(|| poseidon8_compress([Goldilocks::default(); 8])) } #[inline(always)] -pub fn 
poseidon16_compress(input: [KoalaBear; 16]) -> [KoalaBear; 8] { - get_poseidon16().compress(input)[0..8].try_into().unwrap() +pub fn poseidon8_compress(input: [Goldilocks; 8]) -> [Goldilocks; 4] { + get_poseidon8().compress(input)[0..4].try_into().unwrap() } -pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) -> [KoalaBear; 8] { - let mut input = [KoalaBear::default(); 16]; - input[..8].copy_from_slice(left); - input[8..].copy_from_slice(right); - poseidon16_compress(input) +pub fn poseidon8_compress_pair( + left: &[Goldilocks; 4], + right: &[Goldilocks; 4], +) -> [Goldilocks; 4] { + let mut input = [Goldilocks::default(); 8]; + input[..4].copy_from_slice(left); + input[4..].copy_from_slice(right); + poseidon8_compress(input) } /// If `use_iv` is false, the length of the slice must be constant (not malleable). -pub fn poseidon_compress_slice(data: &[KoalaBear], use_iv: bool) -> [KoalaBear; 8] { +pub fn poseidon_compress_slice(data: &[Goldilocks], use_iv: bool) -> [Goldilocks; 4] { assert!(!data.is_empty()); if use_iv { - let mut hash = [KoalaBear::default(); 8]; - for chunk in data.chunks(8) { - let mut block = [KoalaBear::default(); 16]; - block[..8].copy_from_slice(&hash); - block[8..8 + chunk.len()].copy_from_slice(chunk); - hash = poseidon16_compress(block); + let mut hash = [Goldilocks::default(); 4]; + for chunk in data.chunks(4) { + let mut block = [Goldilocks::default(); 8]; + block[..4].copy_from_slice(&hash); + block[4..4 + chunk.len()].copy_from_slice(chunk); + hash = poseidon8_compress(block); } hash } else { let len = data.len(); - if len <= 16 { - let mut padded = [KoalaBear::default(); 16]; + if len <= 8 { + let mut padded = [Goldilocks::default(); 8]; padded[..len].copy_from_slice(data); - return poseidon16_compress(padded); + return poseidon8_compress(padded); } - let mut hash = poseidon16_compress(data[0..16].try_into().unwrap()); - for chunk in data[16..].chunks(8) { - let mut block = [KoalaBear::default(); 16]; - 
block[..8].copy_from_slice(&hash); - block[8..8 + chunk.len()].copy_from_slice(chunk); - hash = poseidon16_compress(block); + let mut hash = poseidon8_compress(data[0..8].try_into().unwrap()); + for chunk in data[8..].chunks(4) { + let mut block = [Goldilocks::default(); 8]; + block[..4].copy_from_slice(&hash); + block[4..4 + chunk.len()].copy_from_slice(chunk); + hash = poseidon8_compress(block); } hash } diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index c8aa8e40..a7e1b3b0 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -1,18 +1,18 @@ use backend::*; -use crate::Poseidon16; -use crate::get_poseidon16; +use crate::Poseidon8; +use crate::get_poseidon8; pub type VarCount = usize; -pub fn build_prover_state() -> ProverState { - ProverState::new(get_poseidon16().clone()) +pub fn build_prover_state() -> ProverState { + ProverState::new(*get_poseidon8()) } pub fn build_verifier_state( - prover_state: ProverState, -) -> Result, ProofError> { - VerifierState::new(prover_state.into_proof(), get_poseidon16().clone()) + prover_state: ProverState, +) -> Result, ProofError> { + VerifierState::new(prover_state.into_proof(), *get_poseidon8()) } pub trait ToUsize { diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1da08f78..dd524578 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true [dependencies] field = { path = "../backend/field", package = "mt-field" } -koala-bear = { path = "../backend/koala-bear", package = "mt-koala-bear" } +goldilocks = { path = "../backend/goldilocks", package = "mt-goldilocks" } poly = { path = "../backend/poly", package = "mt-poly" } sumcheck = { path = "../backend/sumcheck", package = "mt-sumcheck" } fiat-shamir = { path = "../backend/fiat-shamir", package = "mt-fiat-shamir" } diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 912520ec..329344c6 100644 --- a/crates/whir/src/dft.rs +++ 
b/crates/whir/src/dft.rs @@ -574,14 +574,14 @@ impl Butterfly for EvalsButterfly { #[cfg(test)] mod tests { use field::{PrimeCharacteristicRing, TwoAdicField}; - use koala_bear::{KoalaBear, QuinticExtensionFieldKB}; + use goldilocks::{Goldilocks, CubicExtensionFieldGL}; use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use crate::*; - type F = KoalaBear; - type EF = QuinticExtensionFieldKB; + type F = Goldilocks; + type EF = CubicExtensionFieldGL; #[test] fn test_eval_dft() { diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 49a94769..e2e82880 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,7 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; -use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; +use goldilocks::{Goldilocks, CubicExtensionFieldGL, default_goldilocks_poseidon1_8}; use poly::*; use rayon::prelude::*; @@ -31,22 +31,22 @@ pub(crate) fn merkle_commit>( full_n_cols: usize, effective_n_cols: usize, ) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { - let perm = default_koalabear_poseidon1_16(); - if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { - let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let perm = default_goldilocks_poseidon1_8(); + if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, CubicExtensionFieldGL)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; let view = FlatMatrixView::new(matrix); - let dim = >::DIMENSION; + let dim = >::DIMENSION; let full_base_width = full_n_cols * dim; let effective_base_width = effective_n_cols * dim; let tree = - WhirMerkleTree::new::, _, 16, 8>(&perm, view, full_base_width, effective_base_width); + WhirMerkleTree::new::, _, 8, 4>(&perm, view, full_base_width, effective_base_width); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; 
DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) - } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { - let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, matrix, full_n_cols, effective_n_cols); + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, Goldilocks)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let tree = WhirMerkleTree::new::, _, 8, 4>(&perm, matrix, full_n_cols, effective_n_cols); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; @@ -61,18 +61,18 @@ pub(crate) fn merkle_open>( merkle_tree: &RoundMerkleTree, index: usize, ) -> (Vec, Vec<[F; DIGEST_ELEMS]>) { - if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, CubicExtensionFieldGL)>() { let merkle_tree = - unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; + unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; let (inner_leaf, proof) = merkle_tree.open(index); - let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf); + let leaf = CubicExtensionFieldGL::reconstitute_from_base(inner_leaf); let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; let proof = unsafe { std::mem::transmute::<_, Vec<[F; DIGEST_ELEMS]>>(proof) }; (leaf, proof) - } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { - let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, Goldilocks)>() { + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; let (inner_leaf, proof) = merkle_tree.open(index); - 
let leaf = KoalaBear::reconstitute_from_base(inner_leaf); + let leaf = Goldilocks::reconstitute_from_base(inner_leaf); let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; let proof = unsafe { std::mem::transmute::<_, Vec<[F; DIGEST_ELEMS]>>(proof) }; (leaf, proof) @@ -89,14 +89,14 @@ pub(crate) fn merkle_verify>( data: Vec, proof: &Vec<[F; DIGEST_ELEMS]>, ) -> bool { - let perm = default_koalabear_poseidon1_16(); + let perm = default_goldilocks_poseidon1_8(); let log_max_height = utils::log2_strict_usize(dimension.height.next_power_of_two()); - if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { - let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; - let data = unsafe { std::mem::transmute::<_, Vec>(data) }; - let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = QuinticExtensionFieldKB::flatten_to_base(data); - symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, CubicExtensionFieldGL)>() { + let merkle_root = unsafe { std::mem::transmute_copy::<_, [Goldilocks; DIGEST_ELEMS]>(&merkle_root) }; + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let proof = unsafe { std::mem::transmute::<_, &Vec<[Goldilocks; DIGEST_ELEMS]>>(proof) }; + let base_data = CubicExtensionFieldGL::flatten_to_base(data); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 8, 4>( &perm, &merkle_root, log_max_height, @@ -104,12 +104,12 @@ pub(crate) fn merkle_verify>( &base_data, proof, ) - } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { - let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; - let data = unsafe { std::mem::transmute::<_, Vec>(data) }; - let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = KoalaBear::flatten_to_base(data); - 
symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(Goldilocks, Goldilocks)>() { + let merkle_root = unsafe { std::mem::transmute_copy::<_, [Goldilocks; DIGEST_ELEMS]>(&merkle_root) }; + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let proof = unsafe { std::mem::transmute::<_, &Vec<[Goldilocks; DIGEST_ELEMS]>>(proof) }; + let base_data = Goldilocks::flatten_to_base(data); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 8, 4>( &perm, &merkle_root, log_max_height, diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 33df5507..19a3985f 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -4,15 +4,15 @@ use std::time::Instant; use fiat_shamir::{ProverState, VerifierState}; use field::{Field, TwoAdicField}; -use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; +use goldilocks::{Goldilocks, CubicExtensionFieldGL, default_goldilocks_poseidon1_8}; use mt_whir::*; use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; +type F = Goldilocks; +type EF = CubicExtensionFieldGL; /* WHIR_NUM_VARIABLES=25 cargo test --release --package mt-whir --test run_whir -- test_run_whir --exact --nocapture @@ -30,7 +30,7 @@ fn test_run_whir() { .with(ForestLayer::default()) .try_init(); } - let poseidon16 = default_koalabear_poseidon1_16(); + let poseidon8 = default_goldilocks_poseidon1_8(); let num_variables = std::env::var("WHIR_NUM_VARIABLES") .ok() @@ -95,7 +95,7 @@ fn test_run_whir() { )); } - let mut prover_state = ProverState::new(poseidon16.clone()); + let mut prover_state = ProverState::new(poseidon8.clone()); precompute_dft_twiddles::(1 << F::TWO_ADICITY); @@ -118,7 +118,7 @@ fn test_run_whir() { let 
proof_size_single = pruned_proof.proof_size_fe() as f64 * F::bits() as f64 / 8.0; - let mut verifier_state = VerifierState::::new(pruned_proof, poseidon16.clone()).unwrap(); + let mut verifier_state = VerifierState::::new(pruned_proof, poseidon8.clone()).unwrap(); let parsed_commitment = params.parse_commitment::(&mut verifier_state).unwrap(); diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 7e5fe8d2..0e756073 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -1,14 +1,14 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] pub mod signers_cache; mod wots; -use backend::KoalaBear; +use backend::Goldilocks; pub use wots::*; mod xmss; pub use xmss::*; -pub(crate) const DIGEST_SIZE: usize = 8; +pub(crate) const DIGEST_SIZE: usize = 4; -type F = KoalaBear; +type F = Goldilocks; type Digest = [F; DIGEST_SIZE]; // WOTS diff --git a/crates/xmss/src/signers_cache.rs b/crates/xmss/src/signers_cache.rs index 46a5514f..fa6acab5 100644 --- a/crates/xmss/src/signers_cache.rs +++ b/crates/xmss/src/signers_cache.rs @@ -34,10 +34,10 @@ fn cache_footprint(first_pubkey: &XmssPublicKey) -> u128 { hasher.update(NUM_BENCHMARK_SIGNERS.to_le_bytes()); hasher.update(BENCHMARK_SLOT.to_le_bytes()); for f in message_for_benchmark() { - hasher.update(f.as_canonical_u32().to_le_bytes()); + hasher.update(f.as_canonical_u64().to_le_bytes()); } for f in first_pubkey.merkle_root { - hasher.update(f.as_canonical_u32().to_le_bytes()); + hasher.update(f.as_canonical_u64().to_le_bytes()); } let hash = hasher.finalize(); u128::from_le_bytes(hash[..16].try_into().unwrap()) diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 20c35b36..4c7d61ea 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,7 +1,7 @@ use backend::*; use rand::{CryptoRng, RngExt}; use serde::{Deserialize, Serialize}; -use utils::{ToUsize, poseidon16_compress_pair}; +use utils::{ToUsize, poseidon8_compress_pair}; use crate::*; @@ -76,15 +76,17 @@ impl 
WotsSignature { impl WotsPublicKey { pub fn hash(&self) -> Digest { - let init = poseidon16_compress_pair(&self.0[0], &self.0[1]); - self.0[2..] - .iter() - .fold(init, |digest, chunk| poseidon16_compress_pair(&digest, chunk)) + // TODO(goldilocks-migration): re-derive WOTS hashing over width-8 Poseidon / + // digest-4. The KoalaBear version chained 8-element digests through a width-16 + // permutation; the parameter layout doesn't port one-to-one. Stubbed for now + // because XMSS isn't exercised by `test_zk_vm_all_precompiles`. + unimplemented!("WOTS hash not yet reworked for Goldilocks digest-4") } } -pub fn iterate_hash(a: &Digest, n: usize) -> Digest { - (0..n).fold(*a, |acc, _| poseidon16_compress_pair(&acc, &Default::default())) +pub fn iterate_hash(_a: &Digest, _n: usize) -> Digest { + // TODO(goldilocks-migration): see `WotsPublicKey::hash`. + unimplemented!("WOTS iterate_hash not yet reworked for Goldilocks digest-4") } pub fn find_randomness_for_wots_encoding( @@ -104,45 +106,17 @@ pub fn find_randomness_for_wots_encoding( } pub fn wots_encode( - message: &[F; MESSAGE_LEN_FE], - slot: u32, - truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - randomness: &[F; RANDOMNESS_LEN_FE], + _message: &[F; MESSAGE_LEN_FE], + _slot: u32, + _truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], + _randomness: &[F; RANDOMNESS_LEN_FE], ) -> Option<[u8; V]> { - // Encode slot as 2 field elements (16 bits each) - let [slot_lo, slot_hi] = slot_to_field_elements(slot); - - // A = poseidon(message (9 fe), randomness (7 fe)) - let mut a_input_right = [F::default(); 8]; - a_input_right[0] = message[8]; - a_input_right[1..1 + RANDOMNESS_LEN_FE].copy_from_slice(randomness); - let a = poseidon16_compress_pair(message[..8].try_into().unwrap(), &a_input_right); - - // B = poseidon(A (8 fe), slot (2 fe), truncated_merkle_root (6 fe)) - let mut b_input_right = [F::default(); 8]; - b_input_right[0] = slot_lo; - b_input_right[1] = slot_hi; - 
b_input_right[2..8].copy_from_slice(truncated_merkle_root); - let compressed = poseidon16_compress_pair(&a, &b_input_right); - - if compressed.iter().any(|&kb| kb == -F::ONE) { - // ensures uniformity of encoding - return None; - } - let all_indices: Vec<_> = compressed - .iter() - .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) - .collect::>() - .chunks_exact(W) - .take(V + V_GRINDING) - .map(|chunk| { - chunk - .iter() - .enumerate() - .fold(0u8, |acc, (i, &bit)| acc | (u8::from(bit) << i)) - }) - .collect(); - is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) + // TODO(goldilocks-migration): WOTS encoding depends on Poseidon width 16 / digest 8 + // layout, and on a 24-bit little-endian decomposition of a 31-bit KoalaBear value. + // For Goldilocks we need a fresh parameter choice (64-bit lanes, width-8 permutation, + // digest-4). Stubbed for now because XMSS isn't exercised by + // `test_zk_vm_all_precompiles`. + unimplemented!("WOTS encoding not yet reworked for Goldilocks") } fn is_valid_encoding(encoding: &[u8]) -> bool { diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 88f9e6f2..3a634ea0 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -1,7 +1,7 @@ use backend::*; use rand::{CryptoRng, RngExt, SeedableRng, rngs::StdRng}; use serde::{Deserialize, Serialize}; -use utils::poseidon16_compress_pair; +use utils::poseidon8_compress_pair; use crate::*; @@ -58,7 +58,7 @@ pub fn xmss_key_gen( if slot_start > slot_end { return Err(XmssKeyGenError::InvalidRange); } - let perm = default_koalabear_poseidon1_16(); + let perm = default_goldilocks_poseidon1_8(); // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] let leaves: Vec = (slot_start..=slot_end) .into_par_iter() @@ -193,9 +193,9 @@ pub fn xmss_verify( for (level, neighbour) in signature.merkle_proof.iter().enumerate() { let is_left = ((slot >> level) & 1) == 0; if is_left { - current_hash = poseidon16_compress_pair(¤t_hash, 
neighbour); + current_hash = poseidon8_compress_pair(¤t_hash, neighbour); } else { - current_hash = poseidon16_compress_pair(neighbour, ¤t_hash); + current_hash = poseidon8_compress_pair(neighbour, ¤t_hash); } } if current_hash == pub_key.merkle_root { diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 40bbb637..9ce1451a 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -2,7 +2,7 @@ use backend::*; use rand::{SeedableRng, rngs::StdRng}; use xmss::*; -type F = KoalaBear; +type F = Goldilocks; #[test] fn test_xmss_serialize_deserialize() { diff --git a/src/lib.rs b/src/lib.rs index 092f00a0..2899df68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ pub use backend::ProofError; pub use rec_aggregation::{AggregatedXMSS, AggregationTopology, xmss_aggregate, xmss_verify_aggregation}; pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss_key_gen, xmss_sign, xmss_verify}; -pub type F = KoalaBear; +pub type F = Goldilocks; /// Call once before proving. Compiles the aggregation program and precomputes DFT twiddles. 
pub fn setup_prover() { From b9f7c21981a9d4be171cec8ec6953c625a8e0914 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 10:23:26 +0200 Subject: [PATCH 03/31] wip --- crates/backend/goldilocks/src/poseidon1.rs | 8 --- crates/backend/poly/src/eq_mle.rs | 8 +-- crates/lean_compiler/tests/test_compiler.rs | 17 +++--- crates/utils/src/poseidon.rs | 4 +- crates/xmss/src/lib.rs | 6 +-- crates/xmss/src/wots.rs | 58 ++++++++++++++------- crates/xmss/src/xmss.rs | 3 +- 7 files changed, 59 insertions(+), 45 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index b8354880..9add5fcd 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -314,14 +314,6 @@ impl Poseidon1Goldilocks8 { } } - /// Pure-permutation compress: apply the permutation in place, return the - /// full width-8 state. Callers that want a digest truncate to the first - /// `POSEIDON1_DIGEST_LEN = 4` lanes. - #[inline] - pub fn compress(&self, input: [Goldilocks; POSEIDON1_WIDTH]) -> [Goldilocks; POSEIDON1_WIDTH] { - self.permute(input) - } - /// Compression-mode in-place permutation: `output = permute(input) + input`. 
/// /// Matches the koala-bear `Poseidon1Goldilocks8::compress_in_place` shape diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 90a55e74..bdf8e6c3 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -1153,7 +1153,7 @@ mod tests { println!("EXTENSION PACKED: {:?}", time.elapsed()); let unpacked_out_2: Vec = - >::ExtensionPacking::to_ext_iter_vec(out_2.clone()); + <>::ExtensionPacking as PackedFieldExtension>::to_ext_iter_vec(out_2.clone()); assert_eq!(out_1, unpacked_out_2); let mut out_3 = EF::zero_vec(1 << n_vars); @@ -1161,7 +1161,7 @@ mod tests { compute_eval_eq::(&eval, &mut out_3, scalar); let out_3_packed = out_3 .par_chunks_exact(packing_width) - .map(>::ExtensionPacking::from_ext_slice) + .map(<>::ExtensionPacking as PackedFieldExtension>::from_ext_slice) .collect::>(); println!("EXTENSION PACKED AFTER: {:?}", time.elapsed()); @@ -1188,7 +1188,7 @@ mod tests { println!("BASE PACKED: {:?}", time.elapsed()); let unpacked_out_2: Vec = - >::ExtensionPacking::to_ext_iter_vec(out_2.clone()); + <>::ExtensionPacking as PackedFieldExtension>::to_ext_iter_vec(out_2.clone()); assert_eq!(out_1, unpacked_out_2); let mut out_3 = EF::zero_vec(1 << n_vars); @@ -1196,7 +1196,7 @@ mod tests { compute_eval_eq_base::(&eval, &mut out_3, scalar); let out_3_packed = out_3 .par_chunks_exact(packing_width) - .map(>::ExtensionPacking::from_ext_slice) + .map(<>::ExtensionPacking as PackedFieldExtension>::from_ext_slice) .collect::>(); println!("BASE PACKED AFTER: {:?}", time.elapsed()); diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2da5a707..ef9012c8 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -8,19 +8,20 @@ use utils::poseidon8_compress; #[test] fn test_poseidon() { + // Goldilocks width-8 Poseidon: two 4-element halves in, one 4-element digest out. 
let program = r#" def main(): a = 0 - b = a + 8 - c = Array(8) + b = a + 4 + c = Array(4) poseidon8_compress(a, b, c) - for i in range(0, 8): + for i in range(0, 4): cc = c[i] print(cc) return "#; - let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); + let public_input: [F; 8] = (0..8).map(F::new).collect::>().try_into().unwrap(); compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); let _ = dbg!(poseidon8_compress(public_input)); @@ -29,7 +30,7 @@ def main(): #[test] fn test_div_extension_field() { let program = r#" -DIM = 5 +DIM = 3 def main(): n = 0 @@ -170,8 +171,8 @@ def main(): let bytecode = compile_program(&ProgramSource::Raw(program)); let run = |end_val: u32| -> usize { - let expected_sum = (start..end_val).map(|i| i as u64).sum::() as u32; - let public_input = [F::new(end_val), F::new(expected_sum)]; + let expected_sum = (start..end_val).map(|i| i as u64).sum::(); + let public_input = [F::new(end_val as u64), F::new(expected_sum)]; let result = try_execute_bytecode(&bytecode, &public_input, &ExecutionWitness::default(), false).unwrap(); result.pcs.len() }; @@ -288,7 +289,7 @@ fn test_soundness_suite() { ("soundness_5", &[3, 4, 7, 19, 49, 28, 1, 3], &[(0, 4), (1, 5), (2, 8), (3, 20), (4, 50), (5, 29), (6, 0), (6, 2), (7, 4)]), ]; - let to_input = |v: &[u32]| v.iter().copied().map(F::new).collect::>(); + let to_input = |v: &[u32]| v.iter().copied().map(|x| F::new(x as u64)).collect::>(); for &(name, valid, perturbations) in cases { let path = format!("{}/{}.py", test_data_dir(), name); diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index 02835147..1de26ac3 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -21,7 +21,9 @@ pub fn get_poseidon_8_of_zero() -> &'static [Goldilocks; 4] { #[inline(always)] pub fn poseidon8_compress(input: [Goldilocks; 8]) -> [Goldilocks; 4] { - get_poseidon8().compress(input)[0..4].try_into().unwrap() + let mut state = 
input; + get_poseidon8().compress_in_place(&mut state); + state[0..4].try_into().unwrap() } pub fn poseidon8_compress_pair( diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 0e756073..21b675a5 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -19,8 +19,8 @@ pub const NUM_CHAIN_HASHES: usize = 110; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; pub const V_GRINDING: usize = 2; pub const LOG_LIFETIME: usize = 32; -pub const RANDOMNESS_LEN_FE: usize = 7; -pub const MESSAGE_LEN_FE: usize = 9; -pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 6; +pub const RANDOMNESS_LEN_FE: usize = 4; +pub const MESSAGE_LEN_FE: usize = 4; +pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 4; pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_SIZE; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 4c7d61ea..94d6e0d6 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,7 +1,7 @@ use backend::*; use rand::{CryptoRng, RngExt}; use serde::{Deserialize, Serialize}; -use utils::{ToUsize, poseidon8_compress_pair}; +use utils::{poseidon8_compress_pair, poseidon_compress_slice}; use crate::*; @@ -76,17 +76,15 @@ impl WotsSignature { impl WotsPublicKey { pub fn hash(&self) -> Digest { - // TODO(goldilocks-migration): re-derive WOTS hashing over width-8 Poseidon / - // digest-4. The KoalaBear version chained 8-element digests through a width-16 - // permutation; the parameter layout doesn't port one-to-one. Stubbed for now - // because XMSS isn't exercised by `test_zk_vm_all_precompiles`. - unimplemented!("WOTS hash not yet reworked for Goldilocks digest-4") + let init = poseidon8_compress_pair(&self.0[0], &self.0[1]); + self.0[2..] + .iter() + .fold(init, |digest, chunk| poseidon8_compress_pair(&digest, chunk)) } } -pub fn iterate_hash(_a: &Digest, _n: usize) -> Digest { - // TODO(goldilocks-migration): see `WotsPublicKey::hash`. 
- unimplemented!("WOTS iterate_hash not yet reworked for Goldilocks digest-4") +pub fn iterate_hash(a: &Digest, n: usize) -> Digest { + (0..n).fold(*a, |acc, _| poseidon8_compress_pair(&acc, &Default::default())) } pub fn find_randomness_for_wots_encoding( @@ -106,17 +104,39 @@ pub fn find_randomness_for_wots_encoding( } pub fn wots_encode( - _message: &[F; MESSAGE_LEN_FE], - _slot: u32, - _truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - _randomness: &[F; RANDOMNESS_LEN_FE], + message: &[F; MESSAGE_LEN_FE], + slot: u32, + truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], + randomness: &[F; RANDOMNESS_LEN_FE], ) -> Option<[u8; V]> { - // TODO(goldilocks-migration): WOTS encoding depends on Poseidon width 16 / digest 8 - // layout, and on a 24-bit little-endian decomposition of a 31-bit KoalaBear value. - // For Goldilocks we need a fresh parameter choice (64-bit lanes, width-8 permutation, - // digest-4). Stubbed for now because XMSS isn't exercised by - // `test_zk_vm_all_precompiles`. 
- unimplemented!("WOTS encoding not yet reworked for Goldilocks") + let [slot_lo, slot_hi] = slot_to_field_elements(slot); + + const INPUT_LEN: usize = MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 2 + TRUNCATED_MERKLE_ROOT_LEN_FE; + let mut input = [F::default(); INPUT_LEN]; + input[..MESSAGE_LEN_FE].copy_from_slice(message); + input[MESSAGE_LEN_FE..MESSAGE_LEN_FE + RANDOMNESS_LEN_FE].copy_from_slice(randomness); + input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE] = slot_lo; + input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 1] = slot_hi; + input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 2..].copy_from_slice(truncated_merkle_root); + + let encoding_fe = poseidon_compress_slice(&input, false); + + if encoding_fe.iter().any(|&fe| fe == -F::ONE) { + return None; + } + + const CHUNKS_PER_FE: usize = (V + V_GRINDING) / DIGEST_SIZE; + const MASK: u64 = (1u64 << W) - 1; + debug_assert_eq!(CHUNKS_PER_FE * DIGEST_SIZE, V + V_GRINDING); + + let mut all_indices = [0u8; V + V_GRINDING]; + for (i, fe) in encoding_fe.iter().enumerate() { + let value = fe.as_canonical_u64(); + for j in 0..CHUNKS_PER_FE { + all_indices[i * CHUNKS_PER_FE + j] = ((value >> (j * W)) & MASK) as u8; + } + } + is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) } fn is_valid_encoding(encoding: &[u8]) -> bool { diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 3a634ea0..fc810b08 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -58,7 +58,6 @@ pub fn xmss_key_gen( if slot_start > slot_end { return Err(XmssKeyGenError::InvalidRange); } - let perm = default_goldilocks_poseidon1_8(); // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] let leaves: Vec = (slot_start..=slot_end) .into_par_iter() @@ -95,7 +94,7 @@ pub fn xmss_key_gen( assert!(right_idx < 1u64 << 32); gen_random_node(&seed, level - 1, right_idx as u32) }; - compress(&perm, [left, right]) + poseidon8_compress_pair(&left, &right) }) .collect() }; From bb7be6f3dbe10cc298a8c2921f7ae49a26dfe06b Mon 
Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 10:30:57 +0200 Subject: [PATCH 04/31] test_plonky3_compatibility --- crates/backend/goldilocks/src/poseidon1.rs | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 9add5fcd..76b92960 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -385,4 +385,29 @@ mod tests { assert!(seen.insert(out[0].value), "collision at i={i}"); } } + + /// Plonky3-compatibility known-answer vector. + /// + /// Reference: `plonky3/goldilocks/src/poseidon1.rs::tests::test_poseidon_goldilocks_width_8` + /// — input `[0..8)`, expected output hardcoded from upstream. + #[test] + fn test_plonky3_compatibility() { + use field::PrimeField64; + + let p = default_goldilocks_poseidon1_8(); + let mut input: [Goldilocks; 8] = [0, 1, 2, 3, 4, 5, 6, 7].map(Goldilocks::new); + p.permute_mut(&mut input); + let expected: [u64; 8] = [ + 2431226948502761687, + 9427563026145807618, + 6827549936272051660, + 16907684411084503785, + 10131745626715172913, + 17448305483431576765, + 9066501914269485014, + 12095238468458521303, + ]; + let got: [u64; 8] = input.map(|x| x.as_canonical_u64()); + assert_eq!(got, expected); + } } From 82c624e2dd6105d2d7eb89354f16c7d924d9c382 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 10:46:17 +0200 Subject: [PATCH 05/31] wip --- crates/backend/goldilocks/src/poseidon1.rs | 4 +- crates/lean_vm/src/tables/poseidon_8/mod.rs | 236 +++++++++++++++--- .../src/tables/poseidon_8/trace_gen.rs | 10 +- 3 files changed, 202 insertions(+), 48 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 76b92960..c0e6c6b7 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -24,7 +24,7 @@ pub const POSEIDON1_PARTIAL_ROUNDS: usize = 
22; pub const POSEIDON1_SBOX_DEGREE: u64 = 7; pub const POSEIDON1_DIGEST_LEN: usize = 4; -const POSEIDON1_N_ROUNDS: usize = +pub const POSEIDON1_N_ROUNDS: usize = 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; // ========================================================================= @@ -35,7 +35,7 @@ const POSEIDON1_N_ROUNDS: usize = // the first column — more convenient for a row-major apply of a circulant // since `row_i = cyclic_shift(col, i)`, i.e. `M[i][j] = COL[(j - i + N) mod N]` // (equivalently `ROW[(j - i) mod N]`). -const MDS8_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; +pub const MDS8_ROW: [i64; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; /// Apply the width-8 circulant MDS matrix in place, generic over `R`. /// diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs index f25ada57..2cadf3b9 100644 --- a/crates/lean_vm/src/tables/poseidon_8/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -3,24 +3,6 @@ use crate::execution::memory::MemoryAccess; use backend::*; use utils::{ToUsize, poseidon8_compress}; -// TODO(goldilocks-migration): this AIR is currently a soundness stub. -// -// The KoalaBear predecessor implemented Poseidon1 width-16 as an AIR with a -// sparse-matrix factorization for the partial rounds, an `x^3` S-box (degree-3 -// compliant with `degree_air = 3`), and tight column packing. -// -// Goldilocks Poseidon1 is width-8 with `x^7` S-box and 22 partial rounds; the -// sbox alone needs witness decomposition (`y2 = x*x`, `y4 = y2*y2`, `y7 = x*y2*y4`) -// to fit under degree 3. That's a fresh column layout and gate algebra — out of -// scope for this migration pass. -// -// The stub below keeps the I/O columns (flag, index_a, index_b, index_res, -// inputs[8], outputs[4]) and the memory lookups + bus, so callers from -// `execute` and the verifier still wire up. 
The permutation itself is *not* -// constrained — the prover commits the correct `poseidon8_compress` output -// via trace generation, and the verifier accepts it because no gate rejects a -// mismatch. **This is unsound and must be replaced** before shipping. - mod trace_gen; pub use trace_gen::fill_trace_poseidon_8; @@ -29,12 +11,14 @@ pub(super) const DIGEST: usize = DIGEST_LEN; // 4 pub const POSEIDON_PRECOMPILE_DATA: usize = 1; // domain separation: Poseidon8=1, ExtensionOp>=8 +// ---------- I/O columns ---------- pub const POSEIDON_8_COL_FLAG: ColIndex = 0; pub const POSEIDON_8_COL_INDEX_INPUT_LEFT: ColIndex = 1; pub const POSEIDON_8_COL_INDEX_INPUT_RIGHT: ColIndex = 2; pub const POSEIDON_8_COL_INDEX_INPUT_RES: ColIndex = 3; pub const POSEIDON_8_COL_INPUT_START: ColIndex = 4; -pub const POSEIDON_8_COL_OUTPUT_START: ColIndex = POSEIDON_8_COL_INPUT_START + WIDTH; +pub const POSEIDON_8_COL_OUTPUT_START: ColIndex = POSEIDON_8_COL_INPUT_START + WIDTH; // 12 +pub const POSEIDON_8_COL_ROUND_START: ColIndex = POSEIDON_8_COL_OUTPUT_START + DIGEST; // 16 // Legacy aliases used by other tables/compiler code that still refers to the // KoalaBear-era names. Keeping them as shims keeps the diff small. @@ -46,6 +30,108 @@ pub const POSEIDON_16_COL_INPUT_START: ColIndex = POSEIDON_8_COL_INPUT_START; pub const POSEIDON8_NAME: &str = "poseidon8_compress"; +// ---------- Per-round aux columns ---------- +// +// Goldilocks Poseidon1 width-8: 30 rounds, x⁷ S-box, circulant MDS row +// `MDS8_ROW = [7,1,3,8,8,3,4,9]`, Davies-Meyer feed-forward on the first +// `DIGEST` lanes. Full rounds apply the S-box to every lane; partial rounds +// apply it only to lane 0. +// +// For each round we commit: +// - `committed_x3[i] = (state[i] + RC[r][i])^3` for every S-box lane, +// - `post[i]` = entire state after MDS multiply. +// +// The S-box output expression `(x^3)^2 * x = x^7` stays at degree 3, so the +// whole constraint system fits under `degree_air = 3`. 
+ +const FULL_ROUND_COLS: usize = WIDTH + WIDTH; // 8 S-box + 8 post-state +const PARTIAL_ROUND_COLS: usize = 1 + WIDTH; // 1 S-box + 8 post-state + +pub const fn is_full_round(r: usize) -> bool { + r < POSEIDON1_HALF_FULL_ROUNDS + || r >= POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS +} + +/// First column index of round `r`'s data (committed_x3, then post-state). +pub const fn round_data_offset(r: usize) -> usize { + let mut off = POSEIDON_8_COL_ROUND_START; + let mut i = 0; + while i < r { + off += if is_full_round(i) { FULL_ROUND_COLS } else { PARTIAL_ROUND_COLS }; + i += 1; + } + off +} + +/// Number of S-box (committed x³) columns in round `r`. +const fn round_sbox_lanes(r: usize) -> usize { + if is_full_round(r) { WIDTH } else { 1 } +} + +/// First column index of round `r`'s committed post-state. +pub const fn round_post_offset(r: usize) -> usize { + round_data_offset(r) + round_sbox_lanes(r) +} + +pub const fn num_cols_poseidon_8() -> usize { + round_data_offset(POSEIDON1_N_ROUNDS) +} + +const AUX_COLS_PER_ROW: usize = num_cols_poseidon_8() - POSEIDON_8_COL_ROUND_START; + +// ---------- Witness computation ---------- +// +// Replay the Poseidon1-8 permutation on `input`, returning every intermediate +// witness column (in trace order: round 0's committed_x3s, round 0's post, +// round 1's, …) together with the Davies-Meyer `[F; DIGEST]` digest. + +fn mds_row_coeff(j: usize, i: usize) -> F { + // Circulant matrix: `row_i · sbox_out = Σ_j COEFF[(j - i) mod W] * sbox_out[j]`. + F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64) +} + +pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGEST]) { + let mut state = input; + let mut aux = Vec::with_capacity(AUX_COLS_PER_ROW); + + for round in 0..POSEIDON1_N_ROUNDS { + // AddRoundConstants. + for i in 0..WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]; + } + + // S-box: commit x³ for each active lane, replace state[i] with x⁷. 
+ let sbox_lanes = round_sbox_lanes(round); + let mut committed = [F::ZERO; WIDTH]; + for i in 0..sbox_lanes { + let x3 = state[i].cube(); + committed[i] = x3; + aux.push(x3); + state[i] = x3 * x3 * state[i]; // x⁷ = (x³)² · x + } + + // MDS multiply. + let mut post = [F::ZERO; WIDTH]; + for i in 0..WIDTH { + let mut acc = state[0] * mds_row_coeff(0, i); + for j in 1..WIDTH { + acc = acc + state[j] * mds_row_coeff(j, i); + } + post[i] = acc; + } + for i in 0..WIDTH { + aux.push(post[i]); + } + state = post; + + let _ = committed; // silence unused-var if WIDTH changes upstream + } + + // Davies-Meyer feed-forward: output = final_state + input, truncated. + let output: [F; DIGEST] = std::array::from_fn(|i| state[i] + input[i]); + (aux, output) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Poseidon8Precompile; @@ -95,11 +181,17 @@ impl TableT for Poseidon8Precompile { row[POSEIDON_8_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); row[POSEIDON_8_COL_INDEX_INPUT_RIGHT] = F::from_usize(zero_vec_ptr); row[POSEIDON_8_COL_INDEX_INPUT_RES] = F::from_usize(null_hash_ptr); - // inputs stay zero. outputs = poseidon8_compress(0) — truncated to DIGEST. - let out = poseidon8_compress([F::ZERO; WIDTH]); - for (i, v) in out.iter().enumerate() { + // Inputs stay zero; compute and fill the matching witness + output. + let (aux, output) = compute_poseidon8_witness([F::ZERO; WIDTH]); + debug_assert_eq!(aux.len(), AUX_COLS_PER_ROW); + for (i, v) in output.iter().enumerate() { row[POSEIDON_8_COL_OUTPUT_START + i] = *v; } + for (i, v) in aux.iter().enumerate() { + row[POSEIDON_8_COL_ROUND_START + i] = *v; + } + // Sanity: pure-Poseidon compress should agree with the Davies-Meyer witness. 
+ debug_assert_eq!(output, poseidon8_compress([F::ZERO; WIDTH])); row } @@ -121,11 +213,9 @@ impl TableT for Poseidon8Precompile { input[..DIGEST].copy_from_slice(&arg0); input[DIGEST..].copy_from_slice(&arg1); - let output = poseidon8_compress(input); - - let res_a: [F; DIGEST] = output; + let (aux, output) = compute_poseidon8_witness(input); - ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; + ctx.memory.set_slice(index_res_a.to_usize(), &output)?; trace.columns[POSEIDON_8_COL_FLAG].push(F::ONE); trace.columns[POSEIDON_8_COL_INDEX_INPUT_LEFT].push(arg_a); @@ -137,6 +227,9 @@ impl TableT for Poseidon8Precompile { for (i, value) in output.iter().enumerate() { trace.columns[POSEIDON_8_COL_OUTPUT_START + i].push(*value); } + for (i, value) in aux.iter().enumerate() { + trace.columns[POSEIDON_8_COL_ROUND_START + i].push(*value); + } Ok(()) } @@ -154,16 +247,47 @@ impl Air for Poseidon8Precompile { vec![] } fn n_constraints(&self) -> usize { - // Only the boolean flag gate, plus the bus / declared values. - 1 + BUS as usize + // 1 boolean flag + // + 326 per-round gates: each full round = 8 S-box + 8 post; each partial = 1 + 8 + // + 4 Davies-Meyer + // + bus + let mut n = 1 + 4 + BUS as usize; + let mut r = 0; + while r < POSEIDON1_N_ROUNDS { + n += round_sbox_lanes(r) + WIDTH; + r += 1; + } + n } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up(); - let flag = up[POSEIDON_8_COL_FLAG]; - let index_a = up[POSEIDON_8_COL_INDEX_INPUT_LEFT]; - let index_b = up[POSEIDON_8_COL_INDEX_INPUT_RIGHT]; - let index_res = up[POSEIDON_8_COL_INDEX_INPUT_RES]; + // Phase 1 — snapshot every column read from `up` into owned locals so + // we can then use `builder` mutably without fighting the borrow checker. 
+ let flag; + let index_a; + let index_b; + let index_res; + let inputs: [AB::IF; WIDTH]; + let outputs: [AB::IF; DIGEST]; + let mut sbox_commits: Vec> = Vec::with_capacity(POSEIDON1_N_ROUNDS); + let mut post_states: Vec<[AB::IF; WIDTH]> = Vec::with_capacity(POSEIDON1_N_ROUNDS); + { + let up = builder.up(); + flag = up[POSEIDON_8_COL_FLAG]; + index_a = up[POSEIDON_8_COL_INDEX_INPUT_LEFT]; + index_b = up[POSEIDON_8_COL_INDEX_INPUT_RIGHT]; + index_res = up[POSEIDON_8_COL_INDEX_INPUT_RES]; + inputs = std::array::from_fn(|i| up[POSEIDON_8_COL_INPUT_START + i]); + outputs = std::array::from_fn(|i| up[POSEIDON_8_COL_OUTPUT_START + i]); + for r in 0..POSEIDON1_N_ROUNDS { + let data_off = round_data_offset(r); + let post_off = round_post_offset(r); + let n_sbox = round_sbox_lanes(r); + sbox_commits.push((0..n_sbox).map(|i| up[data_off + i]).collect()); + post_states.push(std::array::from_fn(|i| up[post_off + i])); + } + } + // Phase 2 — bus / declare. if BUS { builder.eval_virtual_column(eval_virtual_bus_column::( extra_data, @@ -187,13 +311,43 @@ impl Air for Poseidon8Precompile { builder.assert_bool(flag); - // TODO(goldilocks-migration): constrain outputs to equal - // `poseidon8_compress([inputs[0..8]])`. Currently unconstrained — the - // prover is trusted to fill correct outputs via trace generation. - } -} + // Phase 3 — Poseidon1-8 permutation constraints with Davies-Meyer feed-forward. + let mut state: [AB::IF; WIDTH] = inputs; + for round in 0..POSEIDON1_N_ROUNDS { + let sbox_lanes = round_sbox_lanes(round); -pub const fn num_cols_poseidon_8() -> usize { - // flag + 3 indices + 8 inputs + 4 outputs - 4 + WIDTH + DIGEST_LEN + // AddRoundConstants: x[i] = state[i] + RC[r][i] (expression, degree 1). + let x: [AB::IF; WIDTH] = std::array::from_fn(|i| { + state[i] + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()) + }); + + // S-box lanes: commit committed_x3 = x³ and compute sbox_out = x⁷. 
+ let mut sbox_out: [AB::IF; WIDTH] = x; + for i in 0..sbox_lanes { + let committed_x3 = sbox_commits[round][i]; + builder.assert_zero(committed_x3 - x[i] * x[i] * x[i]); + sbox_out[i] = committed_x3 * committed_x3 * x[i]; + } + + // MDS: post[i] = Σ_j MDS[(j-i) mod W] * sbox_out[j]. + let post = post_states[round]; + for i in 0..WIDTH { + let coeff0 = AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); + let mut acc = sbox_out[0] * coeff0; + for j in 1..WIDTH { + let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + acc = acc + sbox_out[j] * coeff; + } + builder.assert_zero(post[i] - acc); + } + + // Reset state to the committed post-state for the next round. + state = post; + } + + // Davies-Meyer: outputs[i] = final_state[i] + inputs[i] for i in 0..DIGEST. + for i in 0..DIGEST { + builder.assert_zero(outputs[i] - state[i] - inputs[i]); + } + } } diff --git a/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs index 68d254d2..38e6a168 100644 --- a/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_8/trace_gen.rs @@ -3,11 +3,11 @@ use tracing::instrument; use crate::F; use backend::PrimeCharacteristicRing; -// TODO(goldilocks-migration): once the Goldilocks Poseidon1-8 AIR has real -// per-round witness columns, this is where we'll fill them. Today the stub AIR -// has no per-round columns, so the `execute` path already writes every column -// it needs. -#[instrument(name = "generate Poseidon8 AIR trace (stub)", skip_all)] +// `execute()` writes every column for each active row (I/O + all per-round +// witness cols), and `padding_row()` supplies the zero-input witness for +// trailing padding. This pass just equalises column lengths in case any got +// out of sync during trace construction. 
+#[instrument(name = "generate Poseidon8 AIR trace", skip_all)] pub fn fill_trace_poseidon_8(trace: &mut [Vec]) { let n = trace.iter().map(|col| col.len()).max().unwrap_or(0); for col in trace.iter_mut() { From d1c525f1f8c249270ede65cbe2b815564d3fd2ba Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 11:05:01 +0200 Subject: [PATCH 06/31] wip --- crates/lean_vm/src/tables/poseidon_8/mod.rs | 319 +++++++++---- .../lean_vm/src/tables/poseidon_8/sparse.rs | 451 ++++++++++++++++++ 2 files changed, 682 insertions(+), 88 deletions(-) create mode 100644 crates/lean_vm/src/tables/poseidon_8/sparse.rs diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs index 2cadf3b9..d5e4a55b 100644 --- a/crates/lean_vm/src/tables/poseidon_8/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -3,9 +3,12 @@ use crate::execution::memory::MemoryAccess; use backend::*; use utils::{ToUsize, poseidon8_compress}; +mod sparse; mod trace_gen; pub use trace_gen::fill_trace_poseidon_8; +use sparse::{get_partial_constants, PARTIAL_ROUNDS as SPARSE_PARTIAL_ROUNDS}; + pub(super) const WIDTH: usize = 8; pub(super) const DIGEST: usize = DIGEST_LEN; // 4 @@ -32,47 +35,43 @@ pub const POSEIDON8_NAME: &str = "poseidon8_compress"; // ---------- Per-round aux columns ---------- // -// Goldilocks Poseidon1 width-8: 30 rounds, x⁷ S-box, circulant MDS row -// `MDS8_ROW = [7,1,3,8,8,3,4,9]`, Davies-Meyer feed-forward on the first -// `DIGEST` lanes. Full rounds apply the S-box to every lane; partial rounds -// apply it only to lane 0. -// -// For each round we commit: -// - `committed_x3[i] = (state[i] + RC[r][i])^3` for every S-box lane, -// - `post[i]` = entire state after MDS multiply. +// Goldilocks Poseidon1-8 with the Appendix B sparse partial-round decomposition +// (see `sparse.rs`). 
For each full round we commit: +// - `committed_x3[i]` for every S-box lane (8 cols) +// - `post[i]` = state after MDS (8 cols) +// For each partial round — after the one-shot `first_round_constants + m_i` +// transform — we commit only the lane-0 S-box data: +// - `committed_x3` (1 col) +// - `post_sbox` (1 col, the x⁷ output; lanes 1..W are expressed symbolically +// as rank-1 updates of previous `post_sbox`/committed values) // -// The S-box output expression `(x^3)^2 * x = x^7` stays at degree 3, so the -// whole constraint system fits under `degree_air = 3`. +// S-box gate is `committed_x3 = x³` (deg 3) and `post_sbox = committed_x3² · x` +// (deg 3 equality). Partial-round state[1..W] stays degree-1 via the sparse +// matmul `cheap_matmul`, so the whole system fits under `degree_air = 3`. -const FULL_ROUND_COLS: usize = WIDTH + WIDTH; // 8 S-box + 8 post-state -const PARTIAL_ROUND_COLS: usize = 1 + WIDTH; // 1 S-box + 8 post-state +const FULL_ROUND_COLS: usize = WIDTH + WIDTH; // 8 committed_x3 + 8 post-state +const PARTIAL_ROUND_COLS: usize = 2; // committed_x3 + post_sbox pub const fn is_full_round(r: usize) -> bool { r < POSEIDON1_HALF_FULL_ROUNDS || r >= POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS } -/// First column index of round `r`'s data (committed_x3, then post-state). +/// First column index of round `r`'s data. pub const fn round_data_offset(r: usize) -> usize { let mut off = POSEIDON_8_COL_ROUND_START; let mut i = 0; while i < r { - off += if is_full_round(i) { FULL_ROUND_COLS } else { PARTIAL_ROUND_COLS }; + off += if is_full_round(i) { + FULL_ROUND_COLS + } else { + PARTIAL_ROUND_COLS + }; i += 1; } off } -/// Number of S-box (committed x³) columns in round `r`. -const fn round_sbox_lanes(r: usize) -> usize { - if is_full_round(r) { WIDTH } else { 1 } -} - -/// First column index of round `r`'s committed post-state. 
-pub const fn round_post_offset(r: usize) -> usize { - round_data_offset(r) + round_sbox_lanes(r) -} - pub const fn num_cols_poseidon_8() -> usize { round_data_offset(POSEIDON1_N_ROUNDS) } @@ -81,54 +80,109 @@ const AUX_COLS_PER_ROW: usize = num_cols_poseidon_8() - POSEIDON_8_COL_ROUND_STA // ---------- Witness computation ---------- // -// Replay the Poseidon1-8 permutation on `input`, returning every intermediate -// witness column (in trace order: round 0's committed_x3s, round 0's post, -// round 1's, …) together with the Davies-Meyer `[F; DIGEST]` digest. - -fn mds_row_coeff(j: usize, i: usize) -> F { - // Circulant matrix: `row_i · sbox_out = Σ_j COEFF[(j - i) mod W] * sbox_out[j]`. - F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64) +// Replay the Poseidon1-8 permutation on `input`, emitting every committed +// column value in trace order. The partial phase uses the sparse +// decomposition so only 2 cols/round are emitted. + +fn mds_vec_mul(state: &[F; WIDTH]) -> [F; WIDTH] { + let mut out = [F::ZERO; WIDTH]; + for i in 0..WIDTH { + let mut acc = state[0] * F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); + for j in 1..WIDTH { + acc = acc + state[j] * F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + } + out[i] = acc; + } + out } pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGEST]) { + let c = get_partial_constants(); let mut state = input; let mut aux = Vec::with_capacity(AUX_COLS_PER_ROW); - for round in 0..POSEIDON1_N_ROUNDS { - // AddRoundConstants. + // Initial full rounds. + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { for i in 0..WIDTH { state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]; } - - // S-box: commit x³ for each active lane, replace state[i] with x⁷. 
- let sbox_lanes = round_sbox_lanes(round); - let mut committed = [F::ZERO; WIDTH]; - for i in 0..sbox_lanes { + for i in 0..WIDTH { let x3 = state[i].cube(); - committed[i] = x3; aux.push(x3); - state[i] = x3 * x3 * state[i]; // x⁷ = (x³)² · x + state[i] = x3 * x3 * state[i]; // x⁷ + } + let post = mds_vec_mul(&state); + for v in &post { + aux.push(*v); } + state = post; + } - // MDS multiply. - let mut post = [F::ZERO; WIDTH]; + // Partial phase: absorb first_round_constants, apply m_i, then sparse rounds. + for i in 0..WIDTH { + state[i] = state[i] + c.first_round_constants[i]; + } + { + let mut after = [F::ZERO; WIDTH]; for i in 0..WIDTH { - let mut acc = state[0] * mds_row_coeff(0, i); - for j in 1..WIDTH { - acc = acc + state[j] * mds_row_coeff(j, i); + let mut acc = F::ZERO; + for j in 0..WIDTH { + acc = acc + c.m_i[i][j] * state[j]; } - post[i] = acc; + after[i] = acc; } + state = after; + } + + for r in 0..SPARSE_PARTIAL_ROUNDS { + let x = state[0]; + let x3 = x.cube(); + aux.push(x3); + let post_sbox = x3 * x3 * x; // x⁷ + aux.push(post_sbox); + + // state[0] becomes post_sbox (+ scalar RC, except last round). + state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { + post_sbox + c.round_constants[r] + } else { + post_sbox + }; + + // cheap_matmul: + // new_state[0] = Σ_j sparse_first_row[r][j] · state[j] + // new_state[i] = state[i] + v[r][i-1] · old_state[0] (for i ≥ 1) + let old_s0 = state[0]; + let mut new_s0 = F::ZERO; + for j in 0..WIDTH { + new_s0 = new_s0 + c.sparse_first_row[r][j] * state[j]; + } + state[0] = new_s0; + for i in 1..WIDTH { + state[i] = state[i] + c.v[r][i - 1] * old_s0; + } + } + + // Terminal full rounds. 
+ for round in 0..POSEIDON1_HALF_FULL_ROUNDS { + let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; for i in 0..WIDTH { - aux.push(post[i]); + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[abs][i]; + } + for i in 0..WIDTH { + let x3 = state[i].cube(); + aux.push(x3); + state[i] = x3 * x3 * state[i]; + } + let post = mds_vec_mul(&state); + for v in &post { + aux.push(*v); } state = post; - - let _ = committed; // silence unused-var if WIDTH changes upstream } - // Davies-Meyer feed-forward: output = final_state + input, truncated. + // Davies-Meyer feed-forward. let output: [F; DIGEST] = std::array::from_fn(|i| state[i] + input[i]); + debug_assert_eq!(aux.len(), AUX_COLS_PER_ROW); (aux, output) } @@ -183,14 +237,13 @@ impl TableT for Poseidon8Precompile { row[POSEIDON_8_COL_INDEX_INPUT_RES] = F::from_usize(null_hash_ptr); // Inputs stay zero; compute and fill the matching witness + output. let (aux, output) = compute_poseidon8_witness([F::ZERO; WIDTH]); - debug_assert_eq!(aux.len(), AUX_COLS_PER_ROW); for (i, v) in output.iter().enumerate() { row[POSEIDON_8_COL_OUTPUT_START + i] = *v; } for (i, v) in aux.iter().enumerate() { row[POSEIDON_8_COL_ROUND_START + i] = *v; } - // Sanity: pure-Poseidon compress should agree with the Davies-Meyer witness. + // Sanity: Davies-Meyer witness must agree with the direct primitive. debug_assert_eq!(output, poseidon8_compress([F::ZERO; WIDTH])); row } @@ -235,6 +288,20 @@ impl TableT for Poseidon8Precompile { } } +/// Constraint count, computed once at monomorphisation. Must match the number +/// of `assert_*` / `eval_virtual_column` / `declare_values` calls issued in +/// `eval()` exactly; used by the proving pipeline for pre-allocation. +const fn poseidon8_n_constraints(bus: bool) -> usize { + // 1 boolean flag. + // Initial + terminal full rounds: 8 S-box gates + 8 MDS gates per round. + // Partial rounds: 1 S-box gate + 1 post_sbox gate per round. + // Davies-Meyer: 4 output gates. 
+ // + bus (if enabled). + let full_gates = 2 * POSEIDON1_HALF_FULL_ROUNDS * (WIDTH + WIDTH); + let partial_gates = POSEIDON1_PARTIAL_ROUNDS * 2; + 1 + full_gates + partial_gates + DIGEST + bus as usize +} + impl Air for Poseidon8Precompile { type ExtraData = ExtraDataForBuses; fn n_columns(&self) -> usize { @@ -247,29 +314,26 @@ impl Air for Poseidon8Precompile { vec![] } fn n_constraints(&self) -> usize { - // 1 boolean flag - // + 326 per-round gates: each full round = 8 S-box + 8 post; each partial = 1 + 8 - // + 4 Davies-Meyer - // + bus - let mut n = 1 + 4 + BUS as usize; - let mut r = 0; - while r < POSEIDON1_N_ROUNDS { - n += round_sbox_lanes(r) + WIDTH; - r += 1; - } - n + poseidon8_n_constraints(BUS) } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - // Phase 1 — snapshot every column read from `up` into owned locals so - // we can then use `builder` mutably without fighting the borrow checker. + let c = get_partial_constants(); + + // Phase 1 — snapshot every `up[…]` column read into owned locals so we + // can then use `builder` mutably without fighting the borrow checker. let flag; let index_a; let index_b; let index_res; let inputs: [AB::IF; WIDTH]; let outputs: [AB::IF; DIGEST]; - let mut sbox_commits: Vec> = Vec::with_capacity(POSEIDON1_N_ROUNDS); - let mut post_states: Vec<[AB::IF; WIDTH]> = Vec::with_capacity(POSEIDON1_N_ROUNDS); + // For each full round we need `committed_x3[0..W]` and `post[0..W]`. 
+ let mut full_commits: Vec<[AB::IF; WIDTH]> = + Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); + let mut full_posts: Vec<[AB::IF; WIDTH]> = + Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); + let mut partial_commits: Vec = Vec::with_capacity(SPARSE_PARTIAL_ROUNDS); + let mut partial_post_sboxes: Vec = Vec::with_capacity(SPARSE_PARTIAL_ROUNDS); { let up = builder.up(); flag = up[POSEIDON_8_COL_FLAG]; @@ -278,12 +342,18 @@ impl Air for Poseidon8Precompile { index_res = up[POSEIDON_8_COL_INDEX_INPUT_RES]; inputs = std::array::from_fn(|i| up[POSEIDON_8_COL_INPUT_START + i]); outputs = std::array::from_fn(|i| up[POSEIDON_8_COL_OUTPUT_START + i]); - for r in 0..POSEIDON1_N_ROUNDS { - let data_off = round_data_offset(r); - let post_off = round_post_offset(r); - let n_sbox = round_sbox_lanes(r); - sbox_commits.push((0..n_sbox).map(|i| up[data_off + i]).collect()); - post_states.push(std::array::from_fn(|i| up[post_off + i])); + + for round in 0..POSEIDON1_N_ROUNDS { + let off = round_data_offset(round); + if is_full_round(round) { + let commit: [AB::IF; WIDTH] = std::array::from_fn(|i| up[off + i]); + let post: [AB::IF; WIDTH] = std::array::from_fn(|i| up[off + WIDTH + i]); + full_commits.push(commit); + full_posts.push(post); + } else { + partial_commits.push(up[off]); + partial_post_sboxes.push(up[off + 1]); + } } } @@ -313,35 +383,108 @@ impl Air for Poseidon8Precompile { // Phase 3 — Poseidon1-8 permutation constraints with Davies-Meyer feed-forward. let mut state: [AB::IF; WIDTH] = inputs; - for round in 0..POSEIDON1_N_ROUNDS { - let sbox_lanes = round_sbox_lanes(round); - // AddRoundConstants: x[i] = state[i] + RC[r][i] (expression, degree 1). 
+ // ---- Initial full rounds ---- + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { let x: [AB::IF; WIDTH] = std::array::from_fn(|i| { - state[i] + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()) + state[i] + + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()) }); - - // S-box lanes: commit committed_x3 = x³ and compute sbox_out = x⁷. let mut sbox_out: [AB::IF; WIDTH] = x; - for i in 0..sbox_lanes { - let committed_x3 = sbox_commits[round][i]; + for i in 0..WIDTH { + let committed_x3 = full_commits[round][i]; builder.assert_zero(committed_x3 - x[i] * x[i] * x[i]); sbox_out[i] = committed_x3 * committed_x3 * x[i]; } - - // MDS: post[i] = Σ_j MDS[(j-i) mod W] * sbox_out[j]. - let post = post_states[round]; + let post = full_posts[round]; for i in 0..WIDTH { - let coeff0 = AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); - let mut acc = sbox_out[0] * coeff0; + let mut acc = sbox_out[0] + * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); for j in 1..WIDTH { - let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + let coeff = + AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); acc = acc + sbox_out[j] * coeff; } builder.assert_zero(post[i] - acc); } + state = post; + } + + // ---- Partial phase: first_round_constants, m_i, sparse-matmul loop ---- + for i in 0..WIDTH { + state[i] = state[i] + + AB::F::from_u64(c.first_round_constants[i].as_canonical_u64()); + } + { + let mut after: [AB::IF; WIDTH] = std::array::from_fn(|i| { + let mut acc = state[0] * AB::F::from_u64(c.m_i[i][0].as_canonical_u64()); + for j in 1..WIDTH { + acc = acc + state[j] * AB::F::from_u64(c.m_i[i][j].as_canonical_u64()); + } + acc + }); + std::mem::swap(&mut state, &mut after); + } + + for r in 0..SPARSE_PARTIAL_ROUNDS { + let x = state[0]; + let committed_x3 = partial_commits[r]; + let post_sbox = partial_post_sboxes[r]; + + // committed_x3 = x³. 
+ builder.assert_zero(committed_x3 - x * x * x); + // post_sbox = committed_x3² · x = x⁷. + builder.assert_zero(post_sbox - committed_x3 * committed_x3 * x); + + // state[0] becomes post_sbox (+ scalar RC, except last round). + state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { + post_sbox + + AB::F::from_u64(c.round_constants[r].as_canonical_u64()) + } else { + post_sbox + }; + + // cheap_matmul. + let old_s0 = state[0]; + let mut new_s0 = state[0] + * AB::F::from_u64(c.sparse_first_row[r][0].as_canonical_u64()); + for j in 1..WIDTH { + new_s0 = new_s0 + + state[j] + * AB::F::from_u64(c.sparse_first_row[r][j].as_canonical_u64()); + } + state[0] = new_s0; + for i in 1..WIDTH { + state[i] = state[i] + + old_s0 + * AB::F::from_u64(c.v[r][i - 1].as_canonical_u64()); + } + } - // Reset state to the committed post-state for the next round. + // ---- Terminal full rounds ---- + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { + let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; + let x: [AB::IF; WIDTH] = std::array::from_fn(|i| { + state[i] + + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[abs][i].as_canonical_u64()) + }); + let mut sbox_out: [AB::IF; WIDTH] = x; + for i in 0..WIDTH { + let committed_x3 = full_commits[POSEIDON1_HALF_FULL_ROUNDS + round][i]; + builder.assert_zero(committed_x3 - x[i] * x[i] * x[i]); + sbox_out[i] = committed_x3 * committed_x3 * x[i]; + } + let post = full_posts[POSEIDON1_HALF_FULL_ROUNDS + round]; + for i in 0..WIDTH { + let mut acc = sbox_out[0] + * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); + for j in 1..WIDTH { + let coeff = + AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + acc = acc + sbox_out[j] * coeff; + } + builder.assert_zero(post[i] - acc); + } state = post; } diff --git a/crates/lean_vm/src/tables/poseidon_8/sparse.rs b/crates/lean_vm/src/tables/poseidon_8/sparse.rs new file mode 100644 index 00000000..e633ec54 --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_8/sparse.rs @@ -0,0 +1,451 @@ 
+//! Sparse matrix decomposition for Poseidon1-8 partial rounds. +//! +//! Port of `plonky3/poseidon1/src/utils.rs` specialised to the Goldilocks width-8 +//! configuration. Produces the transition matrix `m_i`, the per-round sparse +//! matrices (`sparse_first_row[r]`, `v[r]`), and the compressed round constants +//! (`first_round_constants` + per-round scalar `round_constants`), so that all +//! 22 partial rounds can be constrained with 2 committed columns each instead +//! of the naive 9. +//! +//! References: +//! - Grassi et al., "Poseidon: A New Hash Function for Zero-Knowledge Proof +//! Systems", USENIX Security 2021, Appendix B. +//! - `plonky3/poseidon1/src/{utils.rs, internal.rs}`. + +use std::sync::OnceLock; + +use backend::{ + Field, GOLDILOCKS_POSEIDON1_RC_8, MDS8_ROW, PrimeCharacteristicRing, + POSEIDON1_HALF_FULL_ROUNDS, POSEIDON1_PARTIAL_ROUNDS, +}; + +use crate::F; + +pub(super) const WIDTH: usize = 8; +pub(super) const PARTIAL_ROUNDS: usize = POSEIDON1_PARTIAL_ROUNDS; + +/// Precomputed constants for the sparse partial-round layer. +#[derive(Debug, Clone)] +pub(super) struct PartialConstants { + /// Absorbs the original partial-round 0 vector plus backward-substituted + /// remainders from rounds 1..RP. Added once before the m_i multiply. + pub first_round_constants: [F; WIDTH], + /// Dense transition matrix applied once after adding + /// `first_round_constants` and before the sparse-round loop. + pub m_i: [[F; WIDTH]; WIDTH], + /// Per-round pre-assembled first row of the sparse matrix: + /// `[mds[0][0], w_hat[0], …, w_hat[WIDTH-2]]`. + pub sparse_first_row: [[F; WIDTH]; PARTIAL_ROUNDS], + /// Per-round first-column coefficients (WIDTH-1 entries; we use `[F; WIDTH-1]`). + pub v: [[F; WIDTH - 1]; PARTIAL_ROUNDS], + /// Scalar round constants for partial rounds 0..RP-1 (the last round uses + /// no additive constant — it was absorbed by the backward substitution). 
+ pub round_constants: [F; PARTIAL_ROUNDS - 1], +} + +static PARTIAL_CONSTANTS: OnceLock = OnceLock::new(); + +pub(super) fn get_partial_constants() -> &'static PartialConstants { + PARTIAL_CONSTANTS.get_or_init(compute_partial_constants) +} + +/// Build the dense WxW circulant MDS matrix from `MDS8_ROW`. +/// +/// `M[i][j] = MDS8_ROW[(j - i) mod W]`, stored as `F`. +pub(super) fn mds_dense() -> [[F; WIDTH]; WIDTH] { + let mut m = [[F::ZERO; WIDTH]; WIDTH]; + for i in 0..WIDTH { + for j in 0..WIDTH { + m[i][j] = F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + } + } + m +} + +fn matrix_transpose(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { + let mut r = [[F::ZERO; WIDTH]; WIDTH]; + for i in 0..WIDTH { + for j in 0..WIDTH { + r[i][j] = m[j][i]; + } + } + r +} + +fn matrix_mul( + a: &[[F; WIDTH]; WIDTH], + b: &[[F; WIDTH]; WIDTH], +) -> [[F; WIDTH]; WIDTH] { + let mut c = [[F::ZERO; WIDTH]; WIDTH]; + for i in 0..WIDTH { + for j in 0..WIDTH { + let mut acc = F::ZERO; + for k in 0..WIDTH { + acc = acc + a[i][k] * b[k][j]; + } + c[i][j] = acc; + } + } + c +} + +fn matrix_vec_mul(m: &[[F; WIDTH]; WIDTH], v: &[F; WIDTH]) -> [F; WIDTH] { + let mut r = [F::ZERO; WIDTH]; + for i in 0..WIDTH { + let mut acc = F::ZERO; + for j in 0..WIDTH { + acc = acc + m[i][j] * v[j]; + } + r[i] = acc; + } + r +} + +fn matrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { + let mut aug = *m; + let mut inv = [[F::ZERO; WIDTH]; WIDTH]; + for i in 0..WIDTH { + inv[i][i] = F::ONE; + } + for col in 0..WIDTH { + let pivot = (col..WIDTH) + .find(|&r| aug[r][col] != F::ZERO) + .expect("mds matrix is singular"); + if pivot != col { + aug.swap(col, pivot); + inv.swap(col, pivot); + } + let pivot_inv = aug[col][col].inverse(); + for j in 0..WIDTH { + aug[col][j] = aug[col][j] * pivot_inv; + inv[col][j] = inv[col][j] * pivot_inv; + } + for i in 0..WIDTH { + if i == col { + continue; + } + let factor = aug[i][col]; + if factor == F::ZERO { + continue; + } + let aug_row = aug[col]; 
+ let inv_row = inv[col]; + for j in 0..WIDTH { + aug[i][j] = aug[i][j] - factor * aug_row[j]; + inv[i][j] = inv[i][j] - factor * inv_row[j]; + } + } + } + inv +} + +/// Inverse of the bottom-right (W-1)x(W-1) submatrix `m[1..][1..]`. +fn submatrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH - 1]; WIDTH - 1] { + const N: usize = WIDTH - 1; + let mut sub = [[F::ZERO; N]; N]; + for i in 0..N { + for j in 0..N { + sub[i][j] = m[i + 1][j + 1]; + } + } + let mut inv = [[F::ZERO; N]; N]; + for i in 0..N { + inv[i][i] = F::ONE; + } + for col in 0..N { + let pivot = (col..N) + .find(|&r| sub[r][col] != F::ZERO) + .expect("mds submatrix is singular"); + if pivot != col { + sub.swap(col, pivot); + inv.swap(col, pivot); + } + let pivot_inv = sub[col][col].inverse(); + for j in 0..N { + sub[col][j] = sub[col][j] * pivot_inv; + inv[col][j] = inv[col][j] * pivot_inv; + } + for i in 0..N { + if i == col { + continue; + } + let factor = sub[i][col]; + if factor == F::ZERO { + continue; + } + let sub_row = sub[col]; + let inv_row = inv[col]; + for j in 0..N { + sub[i][j] = sub[i][j] - factor * sub_row[j]; + inv[i][j] = inv[i][j] - factor * inv_row[j]; + } + } + } + inv +} + +/// Factor the dense MDS matrix into `RP` sparse factors. +/// +/// Returns `(m_i, v_collection, w_hat_collection)` all in forward application +/// order; `v_collection[r]` and `w_hat_collection[r]` have `WIDTH-1` meaningful +/// entries (the last slot is zero padding for fixed-size arrays). +fn compute_equivalent_matrices( + mds: &[[F; WIDTH]; WIDTH], + rounds_p: usize, +) -> ( + [[F; WIDTH]; WIDTH], + Vec<[F; WIDTH]>, + Vec<[F; WIDTH]>, +) { + let mut v_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); + let mut w_hat_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); + + let mds_t = matrix_transpose(mds); + let mut m_mul = mds_t; + let mut m_i = [[F::ZERO; WIDTH]; WIDTH]; + + for _ in 0..rounds_p { + // v = first row of m_mul (excl [0,0]). 
In the transposed domain this is + // the first column of M'' in the non-transposed view. + let v_arr: [F; WIDTH] = std::array::from_fn(|j| { + if j < WIDTH - 1 { + m_mul[0][j + 1] + } else { + F::ZERO + } + }); + + // w = first column of m_mul (excl [0,0]). + let mut w = [F::ZERO; WIDTH - 1]; + for i in 0..WIDTH - 1 { + w[i] = m_mul[i + 1][0]; + } + // w_hat = M_hat^{-1} * w. + let m_hat_inv = submatrix_inverse(&m_mul); + let w_hat_arr: [F; WIDTH] = std::array::from_fn(|i| { + if i < WIDTH - 1 { + let mut acc = F::ZERO; + for k in 0..WIDTH - 1 { + acc = acc + m_hat_inv[i][k] * w[k]; + } + acc + } else { + F::ZERO + } + }); + + v_collection.push(v_arr); + w_hat_collection.push(w_hat_arr); + + // m_i = identity-like with m_mul's first row/column stored, then + // "absorb" the rest: first column zeroed, first row zeroed, [0][0]=1. + m_i = m_mul; + m_i[0][0] = F::ONE; + for r in 1..WIDTH { + m_i[r][0] = F::ZERO; + } + for c in 1..WIDTH { + m_i[0][c] = F::ZERO; + } + + // Accumulate: m_mul = M^T * m_i. + m_mul = matrix_mul(&mds_t, &m_i); + } + + // Transpose m_i back (HorizenLabs works in the transposed domain). + let m_i_returned = matrix_transpose(&m_i); + + v_collection.reverse(); + w_hat_collection.reverse(); + + (m_i_returned, v_collection, w_hat_collection) +} + +/// Backward-substitute partial round constants through M^{-1}, producing the +/// full first-round vector and per-round scalar offsets. 
+fn equivalent_round_constants( + partial_rc: &[[F; WIDTH]], + mds_inv: &[[F; WIDTH]; WIDTH], +) -> ([F; WIDTH], Vec) { + let rounds_p = partial_rc.len(); + let mut opt_partial_rc = vec![F::ZERO; rounds_p]; + + let mut tmp = partial_rc[rounds_p - 1]; + for i in (0..rounds_p - 1).rev() { + let inv_cip = matrix_vec_mul(mds_inv, &tmp); + opt_partial_rc[i + 1] = inv_cip[0]; + tmp = partial_rc[i]; + for j in 1..WIDTH { + tmp[j] = tmp[j] + inv_cip[j]; + } + } + let first_round_constants = tmp; + let opt_partial_rc = opt_partial_rc[1..].to_vec(); + (first_round_constants, opt_partial_rc) +} + +fn compute_partial_constants() -> PartialConstants { + let mds = mds_dense(); + let mds_inv = matrix_inverse(&mds); + + // Slice out the partial-round RCs from the monolithic RC table. + let partial_rc: Vec<[F; WIDTH]> = (0..PARTIAL_ROUNDS) + .map(|r| GOLDILOCKS_POSEIDON1_RC_8[POSEIDON1_HALF_FULL_ROUNDS + r]) + .collect(); + + let (first_round_constants, round_constants_vec) = + equivalent_round_constants(&partial_rc, &mds_inv); + let (m_i, v_collection, w_hat_collection) = + compute_equivalent_matrices(&mds, PARTIAL_ROUNDS); + + // sparse_first_row[r] = [mds[0][0], w_hat[r][0], …, w_hat[r][W-2]]. + let mds_0_0 = mds[0][0]; + let mut sparse_first_row = [[F::ZERO; WIDTH]; PARTIAL_ROUNDS]; + for r in 0..PARTIAL_ROUNDS { + sparse_first_row[r][0] = mds_0_0; + for i in 1..WIDTH { + sparse_first_row[r][i] = w_hat_collection[r][i - 1]; + } + } + + // v[r] stripped to [F; WIDTH-1] (drop the zero tail). 
+ let mut v = [[F::ZERO; WIDTH - 1]; PARTIAL_ROUNDS]; + for r in 0..PARTIAL_ROUNDS { + for i in 0..WIDTH - 1 { + v[r][i] = v_collection[r][i]; + } + } + + let mut round_constants = [F::ZERO; PARTIAL_ROUNDS - 1]; + for i in 0..PARTIAL_ROUNDS - 1 { + round_constants[i] = round_constants_vec[i]; + } + + PartialConstants { + first_round_constants, + m_i, + sparse_first_row, + v, + round_constants, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use backend::{POSEIDON1_HALF_FULL_ROUNDS, PrimeField64}; + use utils::poseidon8_compress; + + fn sbox7(x: F) -> F { + let x2 = x * x; + let x4 = x2 * x2; + x4 * x2 * x + } + + /// Textbook (non-sparse) partial-round phase, used as a reference. + fn textbook_partial_phase(mut state: [F; WIDTH]) -> [F; WIDTH] { + let mds = mds_dense(); + for r in 0..PARTIAL_ROUNDS { + let abs = POSEIDON1_HALF_FULL_ROUNDS + r; + for i in 0..WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[abs][i]; + } + state[0] = sbox7(state[0]); + state = matrix_vec_mul(&mds, &state); + } + state + } + + /// Sparse partial-round phase — what the AIR witness computes. + fn sparse_partial_phase(mut state: [F; WIDTH]) -> [F; WIDTH] { + let c = get_partial_constants(); + for i in 0..WIDTH { + state[i] = state[i] + c.first_round_constants[i]; + } + state = matrix_vec_mul(&c.m_i, &state); + for r in 0..PARTIAL_ROUNDS - 1 { + state[0] = sbox7(state[0]); + state[0] = state[0] + c.round_constants[r]; + let old_s0 = state[0]; + // new_state[0] = dot(sparse_first_row[r], state). + let mut acc = F::ZERO; + for j in 0..WIDTH { + acc = acc + c.sparse_first_row[r][j] * state[j]; + } + state[0] = acc; + for i in 1..WIDTH { + state[i] = state[i] + c.v[r][i - 1] * old_s0; + } + } + // last round — no RC. 
+ { + let r = PARTIAL_ROUNDS - 1; + state[0] = sbox7(state[0]); + let old_s0 = state[0]; + let mut acc = F::ZERO; + for j in 0..WIDTH { + acc = acc + c.sparse_first_row[r][j] * state[j]; + } + state[0] = acc; + for i in 1..WIDTH { + state[i] = state[i] + c.v[r][i - 1] * old_s0; + } + } + state + } + + /// The sparse decomposition must reproduce the textbook phase bit-for-bit. + #[test] + fn sparse_matches_textbook() { + let mut seed = 0u64; + for trial in 0..4 { + seed = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); + let input: [F; WIDTH] = std::array::from_fn(|i| { + F::from_u64(seed.wrapping_mul(i as u64 + 1 + trial as u64)) + }); + let a = textbook_partial_phase(input); + let b = sparse_partial_phase(input); + for i in 0..WIDTH { + assert_eq!( + a[i].as_canonical_u64(), + b[i].as_canonical_u64(), + "trial {trial} lane {i}" + ); + } + } + } + + /// End-to-end: a full permutation via the sparse decomposition must match + /// `poseidon8_compress` (Davies-Meyer around the Goldilocks Poseidon1-8). + #[test] + fn full_permutation_matches_poseidon8_compress() { + let input: [F; WIDTH] = std::array::from_fn(|i| F::from_u64(i as u64 * 37 + 1)); + // Initial full rounds. + let mut state = input; + let mds = mds_dense(); + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { + for i in 0..WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]; + } + for i in 0..WIDTH { + state[i] = sbox7(state[i]); + } + state = matrix_vec_mul(&mds, &state); + } + state = sparse_partial_phase(state); + for round in 0..POSEIDON1_HALF_FULL_ROUNDS { + let abs = POSEIDON1_HALF_FULL_ROUNDS + PARTIAL_ROUNDS + round; + for i in 0..WIDTH { + state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[abs][i]; + } + for i in 0..WIDTH { + state[i] = sbox7(state[i]); + } + state = matrix_vec_mul(&mds, &state); + } + // Davies-Meyer. 
+ let output: [F; 4] = std::array::from_fn(|i| state[i] + input[i]); + let expected = poseidon8_compress(input); + assert_eq!(output, expected); + } +} From beaf0d6acd7720a064af9c477722721ae55cb865 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 11:10:05 +0200 Subject: [PATCH 07/31] degree 7 air (instead of 3) for poseidon --- crates/lean_vm/src/tables/poseidon_8/mod.rs | 116 ++++++++------------ 1 file changed, 48 insertions(+), 68 deletions(-) diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs index d5e4a55b..a2bf7bae 100644 --- a/crates/lean_vm/src/tables/poseidon_8/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -36,21 +36,21 @@ pub const POSEIDON8_NAME: &str = "poseidon8_compress"; // ---------- Per-round aux columns ---------- // // Goldilocks Poseidon1-8 with the Appendix B sparse partial-round decomposition -// (see `sparse.rs`). For each full round we commit: -// - `committed_x3[i]` for every S-box lane (8 cols) -// - `post[i]` = state after MDS (8 cols) -// For each partial round — after the one-shot `first_round_constants + m_i` -// transform — we commit only the lane-0 S-box data: -// - `committed_x3` (1 col) -// - `post_sbox` (1 col, the x⁷ output; lanes 1..W are expressed symbolically -// as rank-1 updates of previous `post_sbox`/committed values) +// (see `sparse.rs`). The S-box is `x → x⁷` emitted directly as a degree-7 +// expression `x·x²·x⁴`, so we commit only the minimum needed to reset degree +// between rounds — no `committed_x3` intermediates. // -// S-box gate is `committed_x3 = x³` (deg 3) and `post_sbox = committed_x3² · x` -// (deg 3 equality). Partial-round state[1..W] stays degree-1 via the sparse -// matmul `cheap_matmul`, so the whole system fits under `degree_air = 3`. +// Per full round: 8 `post[i]` cols (state after MDS). 
+// Per partial round: 1 `post_sbox` col (the x⁷ output for lane 0); lanes 1..W +// are expressed symbolically as rank-1 updates via `cheap_matmul`. +// +// Constraints: +// - Full round: `post[i] - Σ_j MDS[i][j] · x[j]⁷ = 0` (deg 7 equality). +// - Partial round: `post_sbox - x⁷ = 0` (deg 7 equality). +// - Davies-Meyer: `outputs[i] - final_state[i] - inputs[i] = 0` (deg 1). -const FULL_ROUND_COLS: usize = WIDTH + WIDTH; // 8 committed_x3 + 8 post-state -const PARTIAL_ROUND_COLS: usize = 2; // committed_x3 + post_sbox +const FULL_ROUND_COLS: usize = WIDTH; // 8 post-state +const PARTIAL_ROUND_COLS: usize = 1; // post_sbox pub const fn is_full_round(r: usize) -> bool { r < POSEIDON1_HALF_FULL_ROUNDS @@ -96,6 +96,12 @@ fn mds_vec_mul(state: &[F; WIDTH]) -> [F; WIDTH] { out } +fn sbox7(x: F) -> F { + let x2 = x * x; + let x4 = x2 * x2; + x4 * x2 * x +} + pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGEST]) { let c = get_partial_constants(); let mut state = input; @@ -104,12 +110,7 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES // Initial full rounds. for round in 0..POSEIDON1_HALF_FULL_ROUNDS { for i in 0..WIDTH { - state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]; - } - for i in 0..WIDTH { - let x3 = state[i].cube(); - aux.push(x3); - state[i] = x3 * x3 * state[i]; // x⁷ + state[i] = sbox7(state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]); } let post = mds_vec_mul(&state); for v in &post { @@ -135,13 +136,9 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES } for r in 0..SPARSE_PARTIAL_ROUNDS { - let x = state[0]; - let x3 = x.cube(); - aux.push(x3); - let post_sbox = x3 * x3 * x; // x⁷ + let post_sbox = sbox7(state[0]); aux.push(post_sbox); - // state[0] becomes post_sbox (+ scalar RC, except last round). 
state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { post_sbox + c.round_constants[r] } else { @@ -166,12 +163,7 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES for round in 0..POSEIDON1_HALF_FULL_ROUNDS { let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; for i in 0..WIDTH { - state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[abs][i]; - } - for i in 0..WIDTH { - let x3 = state[i].cube(); - aux.push(x3); - state[i] = x3 * x3 * state[i]; + state[i] = sbox7(state[i] + GOLDILOCKS_POSEIDON1_RC_8[abs][i]); } let post = mds_vec_mul(&state); for v in &post { @@ -293,12 +285,12 @@ impl TableT for Poseidon8Precompile { /// `eval()` exactly; used by the proving pipeline for pre-allocation. const fn poseidon8_n_constraints(bus: bool) -> usize { // 1 boolean flag. - // Initial + terminal full rounds: 8 S-box gates + 8 MDS gates per round. - // Partial rounds: 1 S-box gate + 1 post_sbox gate per round. - // Davies-Meyer: 4 output gates. + // Initial + terminal full rounds: 8 MDS equality gates per round (deg 7). + // Partial rounds: 1 post_sbox gate per round (deg 7). + // Davies-Meyer: 4 output gates (deg 1). // + bus (if enabled). - let full_gates = 2 * POSEIDON1_HALF_FULL_ROUNDS * (WIDTH + WIDTH); - let partial_gates = POSEIDON1_PARTIAL_ROUNDS * 2; + let full_gates = 2 * POSEIDON1_HALF_FULL_ROUNDS * WIDTH; + let partial_gates = POSEIDON1_PARTIAL_ROUNDS; 1 + full_gates + partial_gates + DIGEST + bus as usize } @@ -308,7 +300,7 @@ impl Air for Poseidon8Precompile { num_cols_poseidon_8() } fn degree_air(&self) -> usize { - 3 + 7 } fn down_column_indexes(&self) -> Vec { vec![] @@ -327,12 +319,9 @@ impl Air for Poseidon8Precompile { let index_res; let inputs: [AB::IF; WIDTH]; let outputs: [AB::IF; DIGEST]; - // For each full round we need `committed_x3[0..W]` and `post[0..W]`. - let mut full_commits: Vec<[AB::IF; WIDTH]> = - Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); + // Per full round: `post[0..W]`. 
Per partial round: `post_sbox`. let mut full_posts: Vec<[AB::IF; WIDTH]> = Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); - let mut partial_commits: Vec = Vec::with_capacity(SPARSE_PARTIAL_ROUNDS); let mut partial_post_sboxes: Vec = Vec::with_capacity(SPARSE_PARTIAL_ROUNDS); { let up = builder.up(); @@ -346,13 +335,10 @@ impl Air for Poseidon8Precompile { for round in 0..POSEIDON1_N_ROUNDS { let off = round_data_offset(round); if is_full_round(round) { - let commit: [AB::IF; WIDTH] = std::array::from_fn(|i| up[off + i]); - let post: [AB::IF; WIDTH] = std::array::from_fn(|i| up[off + WIDTH + i]); - full_commits.push(commit); + let post: [AB::IF; WIDTH] = std::array::from_fn(|i| up[off + i]); full_posts.push(post); } else { - partial_commits.push(up[off]); - partial_post_sboxes.push(up[off + 1]); + partial_post_sboxes.push(up[off]); } } } @@ -386,16 +372,14 @@ impl Air for Poseidon8Precompile { // ---- Initial full rounds ---- for round in 0..POSEIDON1_HALF_FULL_ROUNDS { - let x: [AB::IF; WIDTH] = std::array::from_fn(|i| { - state[i] - + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()) + let sbox_out: [AB::IF; WIDTH] = std::array::from_fn(|i| { + let x = state[i] + + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()); + // x⁷ = x · (x²)² · x² — 4 Mul nodes in the symbolic DAG. + let x2 = x * x; + let x4 = x2 * x2; + x4 * x2 * x }); - let mut sbox_out: [AB::IF; WIDTH] = x; - for i in 0..WIDTH { - let committed_x3 = full_commits[round][i]; - builder.assert_zero(committed_x3 - x[i] * x[i] * x[i]); - sbox_out[i] = committed_x3 * committed_x3 * x[i]; - } let post = full_posts[round]; for i in 0..WIDTH { let mut acc = sbox_out[0] @@ -428,13 +412,12 @@ impl Air for Poseidon8Precompile { for r in 0..SPARSE_PARTIAL_ROUNDS { let x = state[0]; - let committed_x3 = partial_commits[r]; let post_sbox = partial_post_sboxes[r]; - // committed_x3 = x³. 
- builder.assert_zero(committed_x3 - x * x * x); - // post_sbox = committed_x3² · x = x⁷. - builder.assert_zero(post_sbox - committed_x3 * committed_x3 * x); + // post_sbox = x⁷ (deg 7). + let x2 = x * x; + let x4 = x2 * x2; + builder.assert_zero(post_sbox - x4 * x2 * x); // state[0] becomes post_sbox (+ scalar RC, except last round). state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { @@ -464,16 +447,13 @@ impl Air for Poseidon8Precompile { // ---- Terminal full rounds ---- for round in 0..POSEIDON1_HALF_FULL_ROUNDS { let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; - let x: [AB::IF; WIDTH] = std::array::from_fn(|i| { - state[i] - + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[abs][i].as_canonical_u64()) + let sbox_out: [AB::IF; WIDTH] = std::array::from_fn(|i| { + let x = state[i] + + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[abs][i].as_canonical_u64()); + let x2 = x * x; + let x4 = x2 * x2; + x4 * x2 * x }); - let mut sbox_out: [AB::IF; WIDTH] = x; - for i in 0..WIDTH { - let committed_x3 = full_commits[POSEIDON1_HALF_FULL_ROUNDS + round][i]; - builder.assert_zero(committed_x3 - x[i] * x[i] * x[i]); - sbox_out[i] = committed_x3 * committed_x3 * x[i]; - } let post = full_posts[POSEIDON1_HALF_FULL_ROUNDS + round]; for i in 0..WIDTH { let mut acc = sbox_out[0] From c7448bcfba73e3b5e85be6ff02c9adac6cc202ce Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 11:15:15 +0200 Subject: [PATCH 08/31] w --- crates/lean_compiler/tests/test_data/program_15.py | 2 +- crates/lean_compiler/tests/test_data/program_30.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py index 6ea149c2..093e580f 100644 --- a/crates/lean_compiler/tests/test_data/program_15.py +++ b/crates/lean_compiler/tests/test_data/program_15.py @@ -10,7 +10,7 @@ def main(): i, j, k = func_1(x, y) assert i == 2 assert j == 3 - assert k == 2130706432 + assert k 
== 18446744069414584320 # -1 mod P_Goldilocks g = Array(8) h = Array(8) diff --git a/crates/lean_compiler/tests/test_data/program_30.py b/crates/lean_compiler/tests/test_data/program_30.py index 0348faa6..4cb608e9 100644 --- a/crates/lean_compiler/tests/test_data/program_30.py +++ b/crates/lean_compiler/tests/test_data/program_30.py @@ -9,7 +9,7 @@ def main(): for i in unroll(0, 2): res = f1(ARR[i]) buff[i + 1] = res - assert buff[2] == 1390320454 + assert buff[2] == 17401132340371191870 # regenerated for Goldilocks (P=2^64-2^32+1) return From 4b78c6ecf68b3ddc215cefc246085ec4713ff2d7 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 11:48:55 +0200 Subject: [PATCH 09/31] wip --- crates/rec_aggregation/fiat_shamir.py | 22 +++---- crates/rec_aggregation/hashing.py | 78 ++++++++++++------------ crates/rec_aggregation/main.py | 6 +- crates/rec_aggregation/utils.py | 40 ++++++++---- crates/rec_aggregation/xmss_aggregate.py | 36 +++++------ 5 files changed, 99 insertions(+), 83 deletions(-) diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py index 52a89e9b..9b464caa 100644 --- a/crates/rec_aggregation/fiat_shamir.py +++ b/crates/rec_aggregation/fiat_shamir.py @@ -16,10 +16,10 @@ def fs_new(transcript_ptr): @inline def fs_observe_chunks(fs, data, n_chunks): result: Mut = Array(9) - poseidon16_compress(fs, data, result) + poseidon8_compress(fs, data, result) for i in unroll(1, n_chunks): new_result = Array(9) - poseidon16_compress(result, data + i * DIGEST_LEN, new_result) + poseidon8_compress(result, data + i * DIGEST_LEN, new_result) result = new_result result[8] = fs[8] # preserve transcript pointer return result @@ -37,7 +37,7 @@ def fs_observe(fs, data, length: Const): for j in unroll(remainder, DIGEST_LEN): padded[j] = 0 final_result = Array(9) - poseidon16_compress(intermediate, padded, final_result) + poseidon8_compress(intermediate, padded, final_result) final_result[8] = fs[8] # preserve transcript pointer 
return final_result @@ -49,7 +49,7 @@ def fs_grinding(fs, bits): set_to_7_zeros(transcript_ptr + 1) new_fs = Array(9) - poseidon16_compress(fs, transcript_ptr, new_fs) + poseidon8_compress(fs, transcript_ptr, new_fs) new_fs[8] = transcript_ptr + 8 sampled = new_fs[0] @@ -68,7 +68,7 @@ def fs_sample_chunks(fs, n_chunks: Const): domain_sep = Array(8) domain_sep[0] = i set_to_7_zeros(domain_sep + 1) - poseidon16_compress( + poseidon8_compress( domain_sep, fs, sampled + i * 8, @@ -81,9 +81,9 @@ def fs_sample_chunks(fs, n_chunks: Const): @inline def fs_sample_ef(fs): sampled = Array(8) - poseidon16_compress(ZERO_VEC_PTR, fs, sampled) + poseidon8_compress(ZERO_VEC_PTR, fs, sampled) new_fs = Array(9) - poseidon16_compress(SAMPLING_DOMAIN_SEPARATOR_PTR, fs, new_fs) + poseidon8_compress(SAMPLING_DOMAIN_SEPARATOR_PTR, fs, new_fs) new_fs[8] = fs[8] # same transcript pointer return new_fs, sampled @@ -113,9 +113,9 @@ def fs_receive_chunks(fs, n_chunks: Const): transcript_ptr = fs[8] new_fs[8 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer - poseidon16_compress(fs, transcript_ptr, new_fs) + poseidon8_compress(fs, transcript_ptr, new_fs) for i in unroll(1, n_chunks): - poseidon16_compress( + poseidon8_compress( new_fs + ((i - 1) * 8), transcript_ptr + i * 8, new_fs + i * 8, @@ -161,7 +161,7 @@ def fs_sample_data_with_offset(fs, n_chunks: Const, offset): domain_sep = Array(8) domain_sep[0] = offset + i set_to_7_zeros(domain_sep + 1) - poseidon16_compress(domain_sep, fs, sampled + i * 8) + poseidon8_compress(domain_sep, fs, sampled + i * 8) return sampled @@ -172,7 +172,7 @@ def fs_finalize_sample(fs, total_n_chunks): domain_sep = Array(8) domain_sep[0] = total_n_chunks set_to_7_zeros(domain_sep + 1) - poseidon16_compress(domain_sep, fs, new_fs) + poseidon8_compress(domain_sep, fs, new_fs) new_fs[8] = fs[8] # same transcript pointer return new_fs diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index d90448ba..5613491d 
100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -1,7 +1,7 @@ from snark_lib import * -DIM = 5 # extension degree -DIGEST_LEN = 8 +DIM = 3 # extension degree (Goldilocks cubic extension) +DIGEST_LEN = 4 # memory layout: [public_input (PUBLIC_INPUT_LEN)] [preamble_memory (PREAMBLE_MEMORY_LEN)] [runtime ...] # `preamble_memory` is a region that is filled by the guest program, with usefull constants [0000...][1000...]... @@ -58,18 +58,18 @@ def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hash def slice_hash_rtl(data, num_chunks): states = Array((num_chunks - 1) * DIGEST_LEN) - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) + poseidon8_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) + poseidon8_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) return states + (num_chunks - 2) * DIGEST_LEN @inline def slice_hash(data, num_chunks): states = Array((num_chunks - 1) * DIGEST_LEN) - poseidon16_compress(data, data + DIGEST_LEN, states) + poseidon8_compress(data, data + DIGEST_LEN, states) for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (j + 1) * DIGEST_LEN, states + j * DIGEST_LEN) + poseidon8_compress(states + (j - 1) * DIGEST_LEN, data + (j + 1) * DIGEST_LEN, states + j * DIGEST_LEN) return states + (num_chunks - 2) * DIGEST_LEN @@ -77,9 +77,9 @@ def slice_hash(data, num_chunks): def slice_hash_with_iv(data, num_chunks): debug_assert(0 < num_chunks) states = Array(num_chunks * DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, states) + poseidon8_compress(ZERO_VEC_PTR, data, states) for j in unroll(1, num_chunks): - poseidon16_compress(states + (j 
- 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) + poseidon8_compress(states + (j - 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) return states + (num_chunks - 1) * DIGEST_LEN @@ -92,21 +92,21 @@ def slice_hash_with_iv_dynamic_unroll(data, len, len_bits: Const): left = Array(DIGEST_LEN) fill_padded_chunk(left, data, remainder) result = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, left, result) + poseidon8_compress(ZERO_VEC_PTR, left, result) return result if num_full_chunks == 1: if remainder == 0: result = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, result) + poseidon8_compress(ZERO_VEC_PTR, data, result) return result else: h0 = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, h0) + poseidon8_compress(ZERO_VEC_PTR, data, h0) right = Array(DIGEST_LEN) fill_padded_chunk(right, data + DIGEST_LEN, remainder) result = Array(DIGEST_LEN) - poseidon16_compress(h0, right, result) + poseidon8_compress(h0, right, result) return result partial_hash = slice_hash_chunks_with_iv(data, num_full_chunks, len_bits) @@ -116,7 +116,7 @@ def slice_hash_with_iv_dynamic_unroll(data, len, len_bits: Const): padded_last = Array(DIGEST_LEN) fill_padded_chunk(padded_last, data + num_full_elements, remainder) final_hash = Array(DIGEST_LEN) - poseidon16_compress(partial_hash, padded_last, final_hash) + poseidon8_compress(partial_hash, padded_last, final_hash) return final_hash @@ -124,13 +124,13 @@ def slice_hash_with_iv_dynamic_unroll(data, len, len_bits: Const): def slice_hash_chunks_with_iv(data, num_chunks, num_chunks_bits): debug_assert(1 < num_chunks) states = Array(num_chunks * DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, data, states) + poseidon8_compress(ZERO_VEC_PTR, data, states) n_iters = num_chunks - 1 state_ptr: Mut = states data_ptr: Mut = data + DIGEST_LEN for _ in dynamic_unroll(0, n_iters, num_chunks_bits): new_state = state_ptr + DIGEST_LEN - poseidon16_compress(state_ptr, data_ptr, new_state) + 
poseidon8_compress(state_ptr, data_ptr, new_state) state_ptr = new_state data_ptr = data_ptr + DIGEST_LEN return state_ptr @@ -180,24 +180,24 @@ def whir_do_4_merkle_levels(b, state_in, path_chunk, state_out): temps = Array(3 * DIGEST_LEN) if b0 == 0: - poseidon16_compress(state_in, path_chunk, temps) + poseidon8_compress(state_in, path_chunk, temps) else: - poseidon16_compress(path_chunk, state_in, temps) + poseidon8_compress(path_chunk, state_in, temps) if b1 == 0: - poseidon16_compress(temps, path_chunk + DIGEST_LEN, temps + DIGEST_LEN) + poseidon8_compress(temps, path_chunk + DIGEST_LEN, temps + DIGEST_LEN) else: - poseidon16_compress(path_chunk + DIGEST_LEN, temps, temps + DIGEST_LEN) + poseidon8_compress(path_chunk + DIGEST_LEN, temps, temps + DIGEST_LEN) if b2 == 0: - poseidon16_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN) + poseidon8_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN) else: - poseidon16_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, temps + 2 * DIGEST_LEN) + poseidon8_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, temps + 2 * DIGEST_LEN) if b3 == 0: - poseidon16_compress(temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out) + poseidon8_compress(temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out) else: - poseidon16_compress(path_chunk + 3 * DIGEST_LEN, temps + 2 * DIGEST_LEN, state_out) + poseidon8_compress(path_chunk + 3 * DIGEST_LEN, temps + 2 * DIGEST_LEN, state_out) return @@ -212,19 +212,19 @@ def whir_do_3_merkle_levels(b, state_in, path_chunk, state_out): temps = Array(2 * DIGEST_LEN) if b0 == 0: - poseidon16_compress(state_in, path_chunk, temps) + poseidon8_compress(state_in, path_chunk, temps) else: - poseidon16_compress(path_chunk, state_in, temps) + poseidon8_compress(path_chunk, state_in, temps) if b1 == 0: - poseidon16_compress(temps, path_chunk + DIGEST_LEN, temps + DIGEST_LEN) + poseidon8_compress(temps, path_chunk + 
DIGEST_LEN, temps + DIGEST_LEN) else: - poseidon16_compress(path_chunk + DIGEST_LEN, temps, temps + DIGEST_LEN) + poseidon8_compress(path_chunk + DIGEST_LEN, temps, temps + DIGEST_LEN) if b2 == 0: - poseidon16_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, state_out) + poseidon8_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, state_out) else: - poseidon16_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, state_out) + poseidon8_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, state_out) return @@ -237,14 +237,14 @@ def whir_do_2_merkle_levels(b, state_in, path_chunk, state_out): temp = Array(DIGEST_LEN) if b0 == 0: - poseidon16_compress(state_in, path_chunk, temp) + poseidon8_compress(state_in, path_chunk, temp) else: - poseidon16_compress(path_chunk, state_in, temp) + poseidon8_compress(path_chunk, state_in, temp) if b1 == 0: - poseidon16_compress(temp, path_chunk + DIGEST_LEN, state_out) + poseidon8_compress(temp, path_chunk + DIGEST_LEN, state_out) else: - poseidon16_compress(path_chunk + DIGEST_LEN, temp, state_out) + poseidon8_compress(path_chunk + DIGEST_LEN, temp, state_out) return @@ -253,9 +253,9 @@ def whir_do_1_merkle_level(b, state_in, path_chunk, state_out): b0 = b % 2 if b0 == 0: - poseidon16_compress(state_in, path_chunk, state_out) + poseidon8_compress(state_in, path_chunk, state_out) else: - poseidon16_compress(path_chunk, state_in, state_out) + poseidon8_compress(path_chunk, state_in, state_out) return @@ -300,22 +300,22 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Co # First merkle round match leaf_position_bits[0]: case 0: - poseidon16_compress(leaf_digest, merkle_path, states) + poseidon8_compress(leaf_digest, merkle_path, states) case 1: - poseidon16_compress(merkle_path, leaf_digest, states) + poseidon8_compress(merkle_path, leaf_digest, states) # Remaining merkle rounds for j in unroll(1, height): # Warning: this works only if leaf_position_bits[i] is known to be 
boolean: match leaf_position_bits[j]: case 0: - poseidon16_compress( + poseidon8_compress( states + (j - 1) * DIGEST_LEN, merkle_path + j * DIGEST_LEN, states + j * DIGEST_LEN, ) case 1: - poseidon16_compress( + poseidon8_compress( merkle_path + j * DIGEST_LEN, states + (j - 1) * DIGEST_LEN, states + j * DIGEST_LEN, diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index 2d043672..82f02754 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -111,7 +111,7 @@ def main(): counter += 1 pk0 = all_pubkeys + idx0 * DIGEST_LEN running_hash: Mut = Array(DIGEST_LEN) - poseidon16_compress(ZERO_VEC_PTR, pk0, running_hash) + poseidon8_compress(ZERO_VEC_PTR, pk0, running_hash) for j in dynamic_unroll(1, n_sub, log2_ceil(MAX_N_SIGS)): idx = sub_indices_arr[j] @@ -120,7 +120,7 @@ def main(): counter += 1 pk = all_pubkeys + idx * DIGEST_LEN new_hash = Array(DIGEST_LEN) - poseidon16_compress(running_hash, pk, new_hash) + poseidon8_compress(running_hash, pk, new_hash) running_hash = new_hash inner_data_buf = build_inner_data_buf( @@ -161,7 +161,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou assert claim_ptr[k] == 0 claim_hash = slice_hash(claim_ptr, BYTECODE_CLAIM_SIZE_PADDED / DIGEST_LEN) new_hash = Array(DIGEST_LEN) - poseidon16_compress(bytecode_claims_hash, claim_hash, new_hash) + poseidon8_compress(bytecode_claims_hash, claim_hash, new_hash) bytecode_claims_hash = new_hash bytecode_sumcheck_proof = Array(BYTECODE_SUMCHECK_PROOF_SIZE) diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 2d1a7034..17c7680e 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -1,10 +1,10 @@ from snark_lib import * from hashing import * -F_BITS = 31 # koala-bear = 31 bits +F_BITS = 64 # Goldilocks (P = 2^64 - 2^32 + 1, values fit in u64) -TWO_ADICITY = 24 -ROOT = 1791270792 # of order 2^TWO_ADICITY +TWO_ADICITY = 32 +ROOT = 
1753635133440165772 # = 0x185629dcda58878c, of order 2^TWO_ADICITY @inline @@ -352,14 +352,27 @@ def sub_extension_ret(a, b): return c +# Legacy copy_N / set_to_N_zeros names from the KoalaBear era (DIM=5, DIGEST_LEN=8, +# MESSAGE_LEN=9) are kept for minimal churn but redefined to Goldilocks sizes +# (DIM=3, DIGEST_LEN=4, MESSAGE_LEN=4). Semantic roles: +# copy_5 / set_to_5_zeros → one extension-field element (DIM entries). +# copy_8 / set_to_8_zeros → one digest (DIGEST_LEN entries). +# copy_9 → one message (MESSAGE_LEN = DIGEST_LEN entries now). +# set_to_7_zeros → digest tail after a domain-sep slot: DIGEST_LEN-1 +# = DIM entries under Goldilocks. +# copy_16 → a full Poseidon input state (2 × DIGEST_LEN entries). + + @inline def copy_5(a, b): + # Copy DIM=3 elements (one extension-field element). dot_product_ee(a, ONE_EF_PTR, b) return @inline def set_to_5_zeros(a): + # Zero DIM=3 elements. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) return @@ -367,40 +380,43 @@ def set_to_5_zeros(a): @inline def set_to_7_zeros(a): + # Zero DIGEST_LEN-1 = 3 elements (= DIM). Used after writing a domain-sep + # byte into slot 0 of a digest-sized buffer. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) - a[5] = 0 - a[6] = 0 return @inline def set_to_8_zeros(a): + # Zero DIGEST_LEN=4 elements via two overlapping DIM=3 clears. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) - dot_product_ee(a + (8 - DIM), ONE_EF_PTR, zero_ptr) + dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, zero_ptr) return @inline def copy_8(a, b): + # Copy DIGEST_LEN=4 elements via two overlapping DIM=3 copies. dot_product_ee(a, ONE_EF_PTR, b) - dot_product_ee(a + (8 - DIM), ONE_EF_PTR, b + (8 - DIM)) + dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, b + (DIGEST_LEN - DIM)) return @inline def copy_9(a, b): + # Copy MESSAGE_LEN=4 elements (equal to DIGEST_LEN under Goldilocks). 
dot_product_ee(a, ONE_EF_PTR, b) - dot_product_ee(a + (9 - DIM), ONE_EF_PTR, b + (9 - DIM)) + dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, b + (DIGEST_LEN - DIM)) return + @inline def copy_16(a, b): - dot_product_ee(a, ONE_EF_PTR, b) - dot_product_ee(a + 5, ONE_EF_PTR, b + 5) - dot_product_ee(a + 10, ONE_EF_PTR, b + 10) - a[15] = b[15] + # Copy a full Poseidon input block = 2 × DIGEST_LEN = 8 elements. + copy_8(a, b) + copy_8(a + DIGEST_LEN, b + DIGEST_LEN) return diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 9c51fd5e..27a71979 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -31,12 +31,12 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): b_input = Array(DIGEST_LEN * 2) a_input_right[0] = message[DIGEST_LEN] copy_7(randomness, a_input_right + 1) - poseidon16_compress(message, a_input_right, b_input) + poseidon8_compress(message, a_input_right, b_input) b_input[DIGEST_LEN] = slot_lo b_input[DIGEST_LEN + 1] = slot_hi copy_6(merkle_root, b_input + DIGEST_LEN + 2) encoding_fe = Array(DIGEST_LEN) - poseidon16_compress(b_input, b_input + DIGEST_LEN, encoding_fe) + poseidon8_compress(b_input, b_input + DIGEST_LEN, encoding_fe) encoding = Array(NUM_ENCODING_FE * 24 / (2 * W)) @@ -100,13 +100,13 @@ def chain_hash(input_left, n, output_left, pair_chain_length_sum_ptr, local_zero if n_left == 0: copy_8(input_left, output_left) elif n_left == 1: - poseidon16_compress(input_left, local_zero_buff, output_left) + poseidon8_compress(input_left, local_zero_buff, output_left) else: states_left = Array((n_left - 1) * DIGEST_LEN) - poseidon16_compress(input_left, local_zero_buff, states_left) + poseidon8_compress(input_left, local_zero_buff, states_left) for i in unroll(1, n_left - 1): - poseidon16_compress(states_left + (i - 1) * DIGEST_LEN, local_zero_buff, states_left + i * DIGEST_LEN) - poseidon16_compress(states_left + (n_left - 2) * 
DIGEST_LEN, local_zero_buff, output_left) + poseidon8_compress(states_left + (i - 1) * DIGEST_LEN, local_zero_buff, states_left + i * DIGEST_LEN) + poseidon8_compress(states_left + (n_left - 2) * DIGEST_LEN, local_zero_buff, output_left) n_right = (CHAIN_LENGTH - 1) - raw_right debug_assert(raw_right < CHAIN_LENGTH) @@ -115,13 +115,13 @@ def chain_hash(input_left, n, output_left, pair_chain_length_sum_ptr, local_zero if n_right == 0: copy_8(input_right, output_right) elif n_right == 1: - poseidon16_compress(input_right, local_zero_buff, output_right) + poseidon8_compress(input_right, local_zero_buff, output_right) else: states_right = Array((n_right - 1) * DIGEST_LEN) - poseidon16_compress(input_right, local_zero_buff, states_right) + poseidon8_compress(input_right, local_zero_buff, states_right) for i in unroll(1, n_right - 1): - poseidon16_compress(states_right + (i - 1) * DIGEST_LEN, local_zero_buff, states_right + i * DIGEST_LEN) - poseidon16_compress(states_right + (n_right - 2) * DIGEST_LEN, local_zero_buff, output_right) + poseidon8_compress(states_right + (i - 1) * DIGEST_LEN, local_zero_buff, states_right + i * DIGEST_LEN) + poseidon8_compress(states_right + (n_right - 2) * DIGEST_LEN, local_zero_buff, output_right) pair_chain_length_sum_ptr[0] = raw_left + raw_right @@ -143,27 +143,27 @@ def do_4_merkle_levels(b, state_in, path_chunk, state_out): # Level 0: state_in -> temps if b0 == 0: - poseidon16_compress(path_chunk, state_in, temps) + poseidon8_compress(path_chunk, state_in, temps) else: - poseidon16_compress(state_in, path_chunk, temps) + poseidon8_compress(state_in, path_chunk, temps) # Level 1 if b1 == 0: - poseidon16_compress(path_chunk + 1 * DIGEST_LEN, temps, temps + DIGEST_LEN) + poseidon8_compress(path_chunk + 1 * DIGEST_LEN, temps, temps + DIGEST_LEN) else: - poseidon16_compress(temps, path_chunk + 1 * DIGEST_LEN, temps + DIGEST_LEN) + poseidon8_compress(temps, path_chunk + 1 * DIGEST_LEN, temps + DIGEST_LEN) # Level 2 if b2 == 0: - 
poseidon16_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, temps + 2 * DIGEST_LEN) + poseidon8_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, temps + 2 * DIGEST_LEN) else: - poseidon16_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN) + poseidon8_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN) # Level 3: -> state_out if b3 == 0: - poseidon16_compress(path_chunk + 3 * DIGEST_LEN, temps + 2 * DIGEST_LEN, state_out) + poseidon8_compress(path_chunk + 3 * DIGEST_LEN, temps + 2 * DIGEST_LEN, state_out) else: - poseidon16_compress(temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out) + poseidon8_compress(temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out) return From ae2401de168197e21ea0cb8607503637148ce6b7 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 12:58:32 +0200 Subject: [PATCH 10/31] wip --- crates/lean_vm/src/isa/hint.rs | 49 ++++--- crates/rec_aggregation/utils.py | 43 +++--- crates/rec_aggregation/whir.py | 4 +- crates/rec_aggregation/xmss_aggregate.py | 160 ++++++++++++----------- crates/xmss/src/wots.rs | 27 ++-- 5 files changed, 146 insertions(+), 137 deletions(-) diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index dd4058be..c9a2eb39 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -129,8 +129,8 @@ impl CustomHint { pub fn n_args(&self) -> usize { match self { - Self::DecomposeBitsXMSS => 4, - Self::DecomposeBitsMerkleWhir => 3, + Self::DecomposeBitsXMSS => 5, + Self::DecomposeBitsMerkleWhir => 4, Self::DecomposeBits => 4, Self::LessThan => 3, Self::Log2Ceil => 2, @@ -144,33 +144,42 @@ impl CustomHint { ) -> Result<(), RunnerError> { match self { Self::DecomposeBitsXMSS => { + // Decompose `num_fe` field elements into `chunks_per_fe` chunks of + // `chunk_size` bits each. 
Extracts the low `chunks_per_fe * chunk_size` + // bits of each FE's canonical u64 representation (caller is responsible + // for ensuring `chunks_per_fe * chunk_size <= F::bits()`). let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); let to_decompose_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); - let num_to_decompose = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); - let chunk_size = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); - assert!(24_usize.is_multiple_of(chunk_size)); + let num_fe = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); + let chunks_per_fe = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + let chunk_size = args[4].read_value(ctx.memory, ctx.fp)?.to_usize(); + assert!(chunks_per_fe * chunk_size <= F::bits()); let mut memory_index_decomposed = decomposed_ptr; - #[allow(clippy::explicit_counter_loop)] - for i in 0..num_to_decompose { - let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize(); - for i in 0..24 / chunk_size { - let value = F::from_usize((value >> (chunk_size * i)) & ((1 << chunk_size) - 1)); - ctx.memory.set(memory_index_decomposed, value)?; + for i in 0..num_fe { + let value = ctx.memory.get(to_decompose_ptr + i)?.as_canonical_u64(); + for j in 0..chunks_per_fe { + let chunk = F::from_u64( + (value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1), + ); + ctx.memory.set(memory_index_decomposed, chunk)?; memory_index_decomposed += 1; } } } Self::DecomposeBitsMerkleWhir => { + // Decompose a single FE's canonical u64 into `num_chunks` chunks of + // `chunk_size` bits (low bits first). Caller must ensure + // `num_chunks * chunk_size <= F::bits()`. 
let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); - let value = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); - let chunk_size = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); - assert!(24_usize.is_multiple_of(chunk_size)); - let mut memory_index_decomposed = decomposed_ptr; - #[allow(clippy::explicit_counter_loop)] - for i in 0..24 / chunk_size { - let value = F::from_usize((value >> (chunk_size * i)) & ((1 << chunk_size) - 1)); - ctx.memory.set(memory_index_decomposed, value)?; - memory_index_decomposed += 1; + let value = args[1].read_value(ctx.memory, ctx.fp)?.as_canonical_u64(); + let num_chunks = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); + let chunk_size = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + assert!(num_chunks * chunk_size <= F::bits()); + for j in 0..num_chunks { + let chunk = F::from_u64( + (value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1), + ); + ctx.memory.set(decomposed_ptr + j, chunk)?; } } Self::DecomposeBits => { diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 17c7680e..607e2728 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -523,25 +523,22 @@ def whir_1_merkle_step_and_pow(v, state_in, path_chunk, state_out, power_shift): return ROOT ** (power_shift * (v % 2)) -@inline -def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): - nibbles = Array(6) - hint_decompose_bits_merkle_whir(nibbles, a, 4) - - for i in unroll(0, 6): +def decompose_and_verify_merkle_query(a, domain_size: Const, prev_root, num_chunks: Const): + # Goldilocks FRI: query indices fit in TWO_ADICITY = 32 bits. Decompose `a` + # into 8 × 4-bit nibbles and assert `a == partial_sum`; that single equality + # enforces both the decomposition and `a < 2^32` (since partial_sum ≤ 2^32−1). 
+ NUM_NIBBLES = 8 + nibbles = Array(NUM_NIBBLES) + hint_decompose_bits_merkle_whir(nibbles, a, NUM_NIBBLES, 4) + + for i in unroll(0, NUM_NIBBLES): assert nibbles[i] < 16 partial_sum: Mut = nibbles[0] - for i in unroll(1, 6): + for i in unroll(1, NUM_NIBBLES): partial_sum += nibbles[i] * 16**i - # p = 2^31 - 2^24 + 1, so 2^24 * 127 = p - 1 ≡ -1 (mod p), hence inv(2^24) = -127. - # Deduce top7 from the identity partial_sum + top7 * 2^24 == a: - # top7 = (a - partial_sum) * inv(2^24) = (partial_sum - a) * 127 - top7 = (partial_sum - a) * 127 - assert top7 < 2**7 - if top7 == 2**7 - 1: - assert partial_sum == 0 + assert a == partial_sum leaf_data = Array(num_chunks * DIGEST_LEN) hint_witness("merkle_leaf", leaf_data) @@ -556,16 +553,15 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): prod: Mut = 1 # First nibble: leaf_hash -> states[0] - nib_pow = match_range( + prod *= match_range( nibbles[0], range(0, 16), lambda v: whir_4_merkle_step_and_pow(v, leaf_hash, merkle_path, states, 2 ** (TWO_ADICITY - domain_size)), ) - prod *= nib_pow # Middle nibbles: states[k-1] -> states[k] for k in unroll(1, n_nibbles - 1): - nib_pow = match_range( + prod *= match_range( nibbles[k], range(0, 16), lambda v: whir_4_merkle_step_and_pow( @@ -576,7 +572,6 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): 2 ** (TWO_ADICITY - domain_size + 4 * k), ), ) - prod *= nib_pow # Last nibble: states[-1] -> prev_root last_k = n_nibbles - 1 @@ -584,33 +579,29 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): last_path = merkle_path + 4 * last_k * DIGEST_LEN last_power_shift = 2 ** (TWO_ADICITY - domain_size + 4 * last_k) if domain_size % 4 == 0: - nib_pow = match_range( + prod *= match_range( nibbles[last_k], range(0, 16), lambda v: whir_4_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) - prod *= nib_pow elif domain_size % 4 == 1: - nib_pow = match_range( + prod *= 
match_range( nibbles[last_k], range(0, 16), lambda v: whir_1_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) - prod *= nib_pow elif domain_size % 4 == 2: - nib_pow = match_range( + prod *= match_range( nibbles[last_k], range(0, 16), lambda v: whir_2_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) - prod *= nib_pow elif domain_size % 4 == 3: - nib_pow = match_range( + prod *= match_range( nibbles[last_k], range(0, 16), lambda v: whir_3_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) - prod *= nib_pow return leaf_data, prod diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index c7b21b50..ecdf962f 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -235,7 +235,9 @@ def decompose_and_verify_merkle_batch_with_height(num_queries, sampled, root, he def decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height: Const, num_chunks: Const, circle_values, merkle_leaves): for i in range(0, num_queries): - merkle_leaves[i], circle_values[i] = decompose_and_verify_merkle_query(sampled[i], height, root, num_chunks) + leaf, circle = decompose_and_verify_merkle_query(sampled[i], height, root, num_chunks) + merkle_leaves[i] = leaf + circle_values[i] = circle return diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 27a71979..a103119d 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -10,7 +10,13 @@ MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER RANDOMNESS_LEN = RANDOMNESS_LEN_PLACEHOLDER SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN -NUM_ENCODING_FE = div_ceil((V + V_GRINDING), (24 / W)) # 24 should be divisible by W (works for W=2,3,4) +# Goldilocks encoding: 3 Poseidon8 output FE, each decomposed into CHUNKS_PER_FE +# chunks of W bits. 
We assume ~64 bits of uniform entropy per FE and consume +# CHUNKS_PER_FE*W = 63 bits per FE (1-bit remainder). 3*21 = 63 total chunks, +# of which the first V+V_GRINDING are the Winternitz indices. Must match +# `crates/xmss/src/wots.rs::wots_encode`. +NUM_ENCODING_FE = 3 +CHUNKS_PER_FE = 21 MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER N_MERKLE_CHUNKS = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK @@ -25,60 +31,71 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): chain_starts = signature + RANDOMNESS_LEN merkle_path = chain_starts + V * DIGEST_LEN - # 1) We encode message_hash + randomness into the layer of the hypercube with target sum = TARGET_SUM - - a_input_right = Array(DIGEST_LEN) - b_input = Array(DIGEST_LEN * 2) - a_input_right[0] = message[DIGEST_LEN] - copy_7(randomness, a_input_right + 1) - poseidon8_compress(message, a_input_right, b_input) - b_input[DIGEST_LEN] = slot_lo - b_input[DIGEST_LEN + 1] = slot_hi - copy_6(merkle_root, b_input + DIGEST_LEN + 2) + # 1) Hash (message, randomness, slot, merkle_root) into 3 output FE via a + # 3-call Poseidon8 sponge chain, mirroring `poseidon_compress_slice` on + # 14 input FE in the Rust side. 
+ # + # Call 1: poseidon8(message[0..4], randomness[0..4]) → a + a = Array(DIGEST_LEN) + poseidon8_compress(message, randomness, a) + + # Call 2: poseidon8(a, [slot_lo, slot_hi, root[0], root[1]]) → b + rhs2 = Array(DIGEST_LEN) + rhs2[0] = slot_lo + rhs2[1] = slot_hi + rhs2[2] = merkle_root[0] + rhs2[3] = merkle_root[1] + b = Array(DIGEST_LEN) + poseidon8_compress(a, rhs2, b) + + # Call 3: poseidon8(b, [root[2], root[3], 0, 0]) → encoding_fe (4 FE; we use the first 3) + rhs3 = Array(DIGEST_LEN) + rhs3[0] = merkle_root[2] + rhs3[1] = merkle_root[3] + rhs3[2] = 0 + rhs3[3] = 0 encoding_fe = Array(DIGEST_LEN) - poseidon8_compress(b_input, b_input + DIGEST_LEN, encoding_fe) - - encoding = Array(NUM_ENCODING_FE * 24 / (2 * W)) - - hint_decompose_bits_xmss(encoding, encoding_fe, NUM_ENCODING_FE, 2 * W) - - # check that the decomposition is correct - for i in unroll(0, NUM_ENCODING_FE): - for j in unroll(0, 24 / (2 * W)): - assert encoding[i * (24 / (2 * W)) + j] < CHAIN_LENGTH**2 - - partial_sum: Mut = encoding[i * (24 / (2 * W))] - for j in unroll(1, 24 / (2 * W)): - partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH**2) ** j - - # p = 2^31 - 2^24 + 1, so inv(2^24) = -127 (mod p). - # Deduce remaining_i from partial_sum + remaining_i * 2^24 == encoding_fe[i]: - # remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 127 - remaining_i = (partial_sum - encoding_fe[i]) * 127 - assert remaining_i < 2**7 - 1 # ensures uniformity + prevent overflow - - # grinding - debug_assert(V_GRINDING % 2 == 0) - debug_assert(V % 2 == 0) - for i in unroll(V / 2, (V + V_GRINDING) / 2): - assert encoding[i] == CHAIN_LENGTH**2 - 1 - + poseidon8_compress(b, rhs3, encoding_fe) + + # 2) Decompose each of the first 3 FE into 21 3-bit chunks = 63 bits per FE + # (1-bit remainder). 3 × 21 = 63 total chunks; first V+V_GRINDING used. 
+ encoding = Array(63) + hint_decompose_bits_xmss(encoding, encoding_fe, 3, 21, W) + + # Each chunk must be a valid W-bit Winternitz index. + for i in unroll(0, 63): + assert encoding[i] < CHAIN_LENGTH + + # For each FE: partial_sum = Σ_j encoding[i*K+j] * 2^(W*j) is the low 63 + # bits of encoding_fe[i]; the remainder is a single bit. Factorise the + # remainder equality so no inverse is needed: + # (encoding_fe[i] − partial_sum) · (encoding_fe[i] − partial_sum − 2^63) == 0 + for i in unroll(0, 3): + partial_sum: Mut = encoding[i * 21] + for j in unroll(1, 21): + partial_sum += encoding[i * 21 + j] * (CHAIN_LENGTH ** j) + diff = encoding_fe[i] - partial_sum + assert diff * (diff - 2**63) == 0 + + # Grinding: last V_GRINDING indices must each be CHAIN_LENGTH - 1. + for i in unroll(V, V + V_GRINDING): + assert encoding[i] == CHAIN_LENGTH - 1 + + # 3) Chain-hash each of the V WOTS secret-key tips. target_sum: Mut = 0 - wots_public_key = Array(V * DIGEST_LEN) - local_zero_buff = Array(DIGEST_LEN) set_to_8_zeros(local_zero_buff) - for i in unroll(0, V / 2): - # num_hashes = (CHAIN_LENGTH - 1) - encoding[i] - chain_start = chain_starts + i * (DIGEST_LEN * 2) - chain_end = wots_public_key + i * (DIGEST_LEN * 2) - pair_chain_length_sum_ptr = Array(1) + for i in unroll(0, V): + chain_start = chain_starts + i * DIGEST_LEN + chain_end = wots_public_key + i * DIGEST_LEN + chain_length_ptr = Array(1) match_range( - encoding[i], range(0, CHAIN_LENGTH**2), lambda n: chain_hash(chain_start, n, chain_end, pair_chain_length_sum_ptr, local_zero_buff) + encoding[i], range(0, CHAIN_LENGTH), + lambda n: chain_hash(chain_start, n, chain_end, chain_length_ptr, local_zero_buff), ) - target_sum += pair_chain_length_sum_ptr[0] + target_sum += chain_length_ptr[0] assert target_sum == TARGET_SUM @@ -90,40 +107,25 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): @inline -def chain_hash(input_left, n, output_left, pair_chain_length_sum_ptr, local_zero_buff): - 
debug_assert(n < CHAIN_LENGTH**2) - - raw_left = n % CHAIN_LENGTH - raw_right = (n - raw_left) / CHAIN_LENGTH - - n_left = (CHAIN_LENGTH - 1) - raw_left - if n_left == 0: - copy_8(input_left, output_left) - elif n_left == 1: - poseidon8_compress(input_left, local_zero_buff, output_left) - else: - states_left = Array((n_left - 1) * DIGEST_LEN) - poseidon8_compress(input_left, local_zero_buff, states_left) - for i in unroll(1, n_left - 1): - poseidon8_compress(states_left + (i - 1) * DIGEST_LEN, local_zero_buff, states_left + i * DIGEST_LEN) - poseidon8_compress(states_left + (n_left - 2) * DIGEST_LEN, local_zero_buff, output_left) - - n_right = (CHAIN_LENGTH - 1) - raw_right - debug_assert(raw_right < CHAIN_LENGTH) - input_right = input_left + DIGEST_LEN - output_right = output_left + DIGEST_LEN - if n_right == 0: - copy_8(input_right, output_right) - elif n_right == 1: - poseidon8_compress(input_right, local_zero_buff, output_right) +def chain_hash(input_ptr, n, output_ptr, chain_length_ptr, local_zero_buff): + # Iterate the WOTS chain hash `n_hashes = (CHAIN_LENGTH-1) - n` times, starting + # from the signer's chain tip at `input_ptr`, producing `output_ptr`. Records + # `n` in `chain_length_ptr[0]` so the caller can accumulate the target sum. 
+ debug_assert(n < CHAIN_LENGTH) + + n_hashes = (CHAIN_LENGTH - 1) - n + if n_hashes == 0: + copy_8(input_ptr, output_ptr) + elif n_hashes == 1: + poseidon8_compress(input_ptr, local_zero_buff, output_ptr) else: - states_right = Array((n_right - 1) * DIGEST_LEN) - poseidon8_compress(input_right, local_zero_buff, states_right) - for i in unroll(1, n_right - 1): - poseidon8_compress(states_right + (i - 1) * DIGEST_LEN, local_zero_buff, states_right + i * DIGEST_LEN) - poseidon8_compress(states_right + (n_right - 2) * DIGEST_LEN, local_zero_buff, output_right) + states = Array((n_hashes - 1) * DIGEST_LEN) + poseidon8_compress(input_ptr, local_zero_buff, states) + for i in unroll(1, n_hashes - 1): + poseidon8_compress(states + (i - 1) * DIGEST_LEN, local_zero_buff, states + i * DIGEST_LEN) + poseidon8_compress(states + (n_hashes - 2) * DIGEST_LEN, local_zero_buff, output_ptr) - pair_chain_length_sum_ptr[0] = raw_left + raw_right + chain_length_ptr[0] = n return diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 94d6e0d6..a7ca9d66 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -119,24 +119,29 @@ pub fn wots_encode( input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 1] = slot_hi; input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 2..].copy_from_slice(truncated_merkle_root); - let encoding_fe = poseidon_compress_slice(&input, false); - - if encoding_fe.iter().any(|&fe| fe == -F::ONE) { - return None; - } - - const CHUNKS_PER_FE: usize = (V + V_GRINDING) / DIGEST_SIZE; + // `poseidon_compress_slice` returns 4 FE; we use the first 3 (= NUM_ENCODING_FE). + // Assumption (for now): each Goldilocks FE yields ~64 bits of almost-uniform entropy. + // We decompose each FE into 21 × W=3 bit chunks (63 bits total), leaving a 1-bit + // remainder. The DSL verifier asserts that remainder ∈ {0, 1}, so the low 63 bits + // of encoding_fe[i] are canonical. 3 × 21 = 63 total chunks, of which we use the + // first V + V_GRINDING = 44 as Winternitz indices. 
+ let full_output = poseidon_compress_slice(&input, false); + + const NUM_ENCODING_FE: usize = 3; + const CHUNKS_PER_FE: usize = 21; const MASK: u64 = (1u64 << W) - 1; - debug_assert_eq!(CHUNKS_PER_FE * DIGEST_SIZE, V + V_GRINDING); + debug_assert!(CHUNKS_PER_FE * W <= 63); // 1-bit remainder + debug_assert!(NUM_ENCODING_FE * CHUNKS_PER_FE >= V + V_GRINDING); - let mut all_indices = [0u8; V + V_GRINDING]; - for (i, fe) in encoding_fe.iter().enumerate() { + let mut all_indices = [0u8; NUM_ENCODING_FE * CHUNKS_PER_FE]; + for (i, fe) in full_output.iter().take(NUM_ENCODING_FE).enumerate() { let value = fe.as_canonical_u64(); for j in 0..CHUNKS_PER_FE { all_indices[i * CHUNKS_PER_FE + j] = ((value >> (j * W)) & MASK) as u8; } } - is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) + let used: [u8; V + V_GRINDING] = all_indices[..V + V_GRINDING].try_into().unwrap(); + is_valid_encoding(&used).then(|| used[..V].try_into().unwrap()) } fn is_valid_encoding(encoding: &[u8]) -> bool { From 77714546951efee4c9f4c5a3ba41bc85235a4279 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 14:25:20 +0200 Subject: [PATCH 11/31] w --- crates/lean_compiler/src/a_simplify_lang.rs | 1 + crates/rec_aggregation/hashing.py | 21 ++++++++++++++++++--- crates/rec_aggregation/utils.py | 21 ++++++++++++++------- crates/rec_aggregation/whir.py | 7 +------ 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 5ba5332b..0508e39f 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2396,6 +2396,7 @@ fn simplify_lines( res.push(SimpleLine::equality(target_var, SimpleExpr::Constant(result))); } else { if !operation.supports_runtime() { + eprintln!("[COMPILE-TIME-OP DEBUG] operation={operation:?}, args={args_simplified:?}, var={var:?}, target_var={target_var:?}, is_mutable={is_mutable}"); return 
Err(format!( "Operation `{operation}` is compile-time only; all operands must be constants" )); diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index 5613491d..bcb20d25 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -47,6 +47,13 @@ def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, nu def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const): + # num_chunks=1 is a trivial pass-through: the "hash" of a single digest is + # the digest itself. Handled inline here so `slice_hash_rtl` never has to + # deal with num_chunks<2 (it would otherwise generate invalid offsets). + if num_chunks == 1: + for i in range(0, num_queries): + all_resulting_hashes[i] = all_data_to_hash[i] + return for i in range(0, num_queries): data = all_data_to_hash[i] res = slice_hash_rtl(data, num_chunks) @@ -56,6 +63,10 @@ def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hash @inline def slice_hash_rtl(data, num_chunks): + # Precondition: num_chunks >= 2. Callers must dispatch the num_chunks=1 case + # separately (the single chunk is its own hash). Without this the generated + # offset `data + (num_chunks-2) * DIGEST_LEN` underflows for num_chunks=1 + # and confuses @inline expansion. 
states = Array((num_chunks - 1) * DIGEST_LEN) poseidon8_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) @@ -86,7 +97,7 @@ def slice_hash_with_iv(data, num_chunks): def slice_hash_with_iv_dynamic_unroll(data, len, len_bits: Const): remainder = modulo_8(len, len_bits) num_full_elements = len - remainder - num_full_chunks = num_full_elements / 8 + num_full_chunks = num_full_elements / DIGEST_LEN if num_full_chunks == 0: left = Array(DIGEST_LEN) @@ -152,7 +163,9 @@ def fill_padded_chunk_const(dst, src, n: Const): def modulo_8(n, n_bits: Const): - debug_assert(2 < n_bits) + # Name is legacy; returns `n mod DIGEST_LEN`. For DIGEST_LEN=4 (Goldilocks) + # this is the low 2 bits of n; for DIGEST_LEN=8 (KoalaBear) it's the low 3. + debug_assert(1 < n_bits) debug_assert(n < 2**n_bits) bits = Array(n_bits) hint_decompose_bits(n, bits, n_bits, BIG_ENDIAN) @@ -164,7 +177,9 @@ def modulo_8(n, n_bits: Const): assert b * (1 - b) == 0 partial_sums[i] = partial_sums[i - 1] + b * 2**i assert n == partial_sums[n_bits - 1] - return partial_sums[2] + # DIGEST_LEN = 2^DIGEST_LEN_BITS; we want partial_sums[DIGEST_LEN_BITS - 1] + # (low DIGEST_LEN_BITS bits). log2(4) = 2 → index 1; log2(8) = 3 → index 2. + return partial_sums[log2_ceil(DIGEST_LEN) - 1] @inline diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 607e2728..fe14da2a 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -523,7 +523,8 @@ def whir_1_merkle_step_and_pow(v, state_in, path_chunk, state_out, power_shift): return ROOT ** (power_shift * (v % 2)) -def decompose_and_verify_merkle_query(a, domain_size: Const, prev_root, num_chunks: Const): +@inline +def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): # Goldilocks FRI: query indices fit in TWO_ADICITY = 32 bits. 
Decompose `a` # into 8 × 4-bit nibbles and assert `a == partial_sum`; that single equality # enforces both the decomposition and `a < 2^32` (since partial_sum ≤ 2^32−1). @@ -553,15 +554,16 @@ def decompose_and_verify_merkle_query(a, domain_size: Const, prev_root, num_chun prod: Mut = 1 # First nibble: leaf_hash -> states[0] - prod *= match_range( + nib_pow = match_range( nibbles[0], range(0, 16), lambda v: whir_4_merkle_step_and_pow(v, leaf_hash, merkle_path, states, 2 ** (TWO_ADICITY - domain_size)), ) + prod *= nib_pow # Middle nibbles: states[k-1] -> states[k] for k in unroll(1, n_nibbles - 1): - prod *= match_range( + nib_pow = match_range( nibbles[k], range(0, 16), lambda v: whir_4_merkle_step_and_pow( @@ -572,6 +574,7 @@ def decompose_and_verify_merkle_query(a, domain_size: Const, prev_root, num_chun 2 ** (TWO_ADICITY - domain_size + 4 * k), ), ) + prod *= nib_pow # Last nibble: states[-1] -> prev_root last_k = n_nibbles - 1 @@ -579,29 +582,33 @@ def decompose_and_verify_merkle_query(a, domain_size: Const, prev_root, num_chun last_path = merkle_path + 4 * last_k * DIGEST_LEN last_power_shift = 2 ** (TWO_ADICITY - domain_size + 4 * last_k) if domain_size % 4 == 0: - prod *= match_range( + nib_pow = match_range( nibbles[last_k], range(0, 16), lambda v: whir_4_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) + prod *= nib_pow elif domain_size % 4 == 1: - prod *= match_range( + nib_pow = match_range( nibbles[last_k], range(0, 16), lambda v: whir_1_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) + prod *= nib_pow elif domain_size % 4 == 2: - prod *= match_range( + nib_pow = match_range( nibbles[last_k], range(0, 16), lambda v: whir_2_merkle_step_and_pow(v, last_state_in, last_path, prev_root, last_power_shift), ) + prod *= nib_pow elif domain_size % 4 == 3: - prod *= match_range( + nib_pow = match_range( nibbles[last_k], range(0, 16), lambda v: whir_3_merkle_step_and_pow(v, last_state_in, last_path, 
prev_root, last_power_shift), ) + prod *= nib_pow return leaf_data, prod diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index ecdf962f..4bdd48ce 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -220,9 +220,6 @@ def decompose_and_verify_merkle_batch_with_height(num_queries, sampled, root, he if num_chunks == 20: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 20, circle_values, answers) return - if num_chunks == 1: - decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 1, circle_values, answers) - return if num_chunks == 4: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 4, circle_values, answers) return @@ -235,9 +232,7 @@ def decompose_and_verify_merkle_batch_with_height(num_queries, sampled, root, he def decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height: Const, num_chunks: Const, circle_values, merkle_leaves): for i in range(0, num_queries): - leaf, circle = decompose_and_verify_merkle_query(sampled[i], height, root, num_chunks) - merkle_leaves[i] = leaf - circle_values[i] = circle + merkle_leaves[i], circle_values[i] = decompose_and_verify_merkle_query(sampled[i], height, root, num_chunks) return From ff61a47b8a23a2d589dfeb1316e6909bd75b1d1c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 14:28:15 +0200 Subject: [PATCH 12/31] w --- crates/rec_aggregation/main.py | 3 +-- crates/rec_aggregation/src/compilation.rs | 2 +- crates/rec_aggregation/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index 82f02754..5a18044c 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -204,8 +204,7 @@ def build_inner_data_buf(n_sub, pubkeys_hash, message, slot_lo, slot_hi, merkle_ inner_data_buf[0] = n_sub copy_8(pubkeys_hash, inner_data_buf + 1) inner_msg = inner_data_buf 
+ 1 + DIGEST_LEN - debug_assert(MESSAGE_LEN == 9) - copy_9(message, inner_msg) + copy_message(message, inner_msg) # copies MESSAGE_LEN=4 elements under Goldilocks inner_msg[MESSAGE_LEN] = slot_lo inner_msg[MESSAGE_LEN + 1] = slot_hi for k in unroll(0, N_MERKLE_CHUNKS): diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 5a4bc280..4129d4b0 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -46,7 +46,7 @@ fn compile_main_program(inner_program_log_size: usize, bytecode_zero_eval: F) -> #[instrument(skip_all)] fn compile_main_program_self_referential() -> Bytecode { - let mut log_size_guess = 19; + let mut log_size_guess = 18; let bytecode_zero_eval = F::ONE; loop { let bytecode = compile_main_program(log_size_guess, bytecode_zero_eval); diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index fe14da2a..0e9f68f8 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -405,8 +405,8 @@ def copy_8(a, b): @inline -def copy_9(a, b): - # Copy MESSAGE_LEN=4 elements (equal to DIGEST_LEN under Goldilocks). +def copy_message(a, b): + # Copy MESSAGE_LEN=4 elements (= DIGEST_LEN under Goldilocks). 
dot_product_ee(a, ONE_EF_PTR, b) dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, b + (DIGEST_LEN - DIM)) return From ec26c5238b53d2b292a4bdc64afef47c887a7fa9 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 14:31:42 +0200 Subject: [PATCH 13/31] wip --- crates/rec_aggregation/fiat_shamir.py | 12 +++--- crates/rec_aggregation/hashing.py | 2 +- crates/rec_aggregation/main.py | 24 ++++++------ crates/rec_aggregation/recursion.py | 16 ++++---- crates/rec_aggregation/src/compilation.rs | 4 +- crates/rec_aggregation/utils.py | 48 +++++++++++------------ crates/rec_aggregation/whir.py | 16 ++++---- crates/rec_aggregation/xmss_aggregate.py | 4 +- 8 files changed, 61 insertions(+), 65 deletions(-) diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py index 9b464caa..e28dc6b8 100644 --- a/crates/rec_aggregation/fiat_shamir.py +++ b/crates/rec_aggregation/fiat_shamir.py @@ -8,7 +8,7 @@ def fs_new(transcript_ptr): fs_state = Array(9) - set_to_8_zeros(fs_state) + zero_digest(fs_state) fs_state[8] = transcript_ptr return fs_state @@ -46,7 +46,7 @@ def fs_grinding(fs, bits): if bits == 0: return fs # no grinding transcript_ptr = fs[8] - set_to_7_zeros(transcript_ptr + 1) + zero_digest_tail(transcript_ptr + 1) new_fs = Array(9) poseidon8_compress(fs, transcript_ptr, new_fs) @@ -67,7 +67,7 @@ def fs_sample_chunks(fs, n_chunks: Const): for i in unroll(0, (n_chunks + 1)): domain_sep = Array(8) domain_sep[0] = i - set_to_7_zeros(domain_sep + 1) + zero_digest_tail(domain_sep + 1) poseidon8_compress( domain_sep, fs, @@ -102,7 +102,7 @@ def fs_hint(fs, n): # return the updated fiat-shamir, and a pointer to n field elements from the transcript transcript_ptr = fs[8] new_fs = Array(9) - copy_8(fs, new_fs) + copy_digest(fs, new_fs) new_fs[8] = fs[8] + n # advance transcript pointer return new_fs, transcript_ptr @@ -160,7 +160,7 @@ def fs_sample_data_with_offset(fs, n_chunks: Const, offset): for i in unroll(0, n_chunks): domain_sep = 
Array(8) domain_sep[0] = offset + i - set_to_7_zeros(domain_sep + 1) + zero_digest_tail(domain_sep + 1) poseidon8_compress(domain_sep, fs, sampled + i * 8) return sampled @@ -171,7 +171,7 @@ def fs_finalize_sample(fs, total_n_chunks): new_fs = Array(9) domain_sep = Array(8) domain_sep[0] = total_n_chunks - set_to_7_zeros(domain_sep + 1) + zero_digest_tail(domain_sep + 1) poseidon8_compress(domain_sep, fs, new_fs) new_fs[8] = fs[8] # same transcript pointer return new_fs diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index bcb20d25..953f4b88 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -335,5 +335,5 @@ def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Co states + (j - 1) * DIGEST_LEN, states + j * DIGEST_LEN, ) - copy_8(states + (height - 1) * DIGEST_LEN, root) + copy_digest(states + (height - 1) * DIGEST_LEN, root) return diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index 5a18044c..3fcdbd8f 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -63,7 +63,7 @@ def main(): ) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) + copy_digest(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) bytecode_claims = Array(2) bytecode_claims[0] = inner_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[1] = recursion(inner_pub_mem, bytecode_hash_domsep) @@ -71,12 +71,12 @@ def main(): # All fields of `data_buf` are now written: hash it and assert the digest # matches the (single-element) public input by writing into public memory. 
outer_hash = slice_hash_with_iv(data_buf, INPUT_DATA_NUM_CHUNKS) - copy_8(outer_hash, pub_mem) + copy_digest(outer_hash, pub_mem) return # General path computed_pubkeys_hash = slice_hash_with_iv_dynamic_unroll(all_pubkeys, n_sigs * DIGEST_LEN, MAX_LOG_MEMORY_SIZE) - copy_8(computed_pubkeys_hash, pubkeys_hash_expected) + copy_digest(computed_pubkeys_hash, pubkeys_hash_expected) # Buffer for partition verification n_total = n_sigs + n_dup @@ -128,7 +128,7 @@ def main(): merkle_chunks_for_slot, bytecode_hash_domsep, ) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) + copy_digest(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) bytecode_claims[2 * rec_idx] = inner_data_buf + BYTECODE_CLAIM_OFFSET # Verify recursive proof - returns the second bytecode claim @@ -140,7 +140,7 @@ def main(): # Bytecode claims if n_recursions == 0: for k in unroll(0, BYTECODE_POINT_N_VARS): - set_to_5_zeros(bytecode_claim_output + k * DIM) + zero_ef(bytecode_claim_output + k * DIM) bytecode_claim_output[BYTECODE_POINT_N_VARS * DIM] = BYTECODE_ZERO_EVAL for k in unroll(1, DIM): bytecode_claim_output[BYTECODE_POINT_N_VARS * DIM + k] = 0 @@ -150,7 +150,7 @@ def main(): # All fields of `data_buf` are now written: hash it and assert the digest # matches the (single-element) public input by writing into public memory. 
outer_hash = slice_hash_with_iv(data_buf, INPUT_DATA_NUM_CHUNKS) - copy_8(outer_hash, pub_mem) + copy_digest(outer_hash, pub_mem) return def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output): @@ -168,7 +168,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou hint_witness("bytecode_sumcheck_proof", bytecode_sumcheck_proof) reduction_fs: Mut = fs_new(bytecode_sumcheck_proof) reduction_fs, received_claims_hash = fs_receive_chunks(reduction_fs, 1) - copy_8(bytecode_claims_hash, received_claims_hash) + copy_digest(bytecode_claims_hash, received_claims_hash) reduction_fs, alpha = fs_sample_ef(reduction_fs) alpha_powers = powers(alpha, n_bytecode_claims) @@ -176,7 +176,7 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou all_values = Array(n_bytecode_claims * DIM) for i in range(0, n_bytecode_claims): claim_ptr = bytecode_claims[i] - copy_5(claim_ptr + BYTECODE_POINT_N_VARS * DIM, all_values + i * DIM) + copy_ef(claim_ptr + BYTECODE_POINT_N_VARS * DIM, all_values + i * DIM) claimed_sum = Array(DIM) dot_product_ee_dynamic(all_values, alpha_powers, claimed_sum, n_bytecode_claims) @@ -188,21 +188,21 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou for i in range(0, n_bytecode_claims): claim_ptr = bytecode_claims[i] eq_val = eq_mle_extension(claim_ptr, challenges, BYTECODE_POINT_N_VARS) - copy_5(eq_val, eq_evals + i * DIM) + copy_ef(eq_val, eq_evals + i * DIM) w_r = Array(DIM) dot_product_ee_dynamic(eq_evals, alpha_powers, w_r, n_bytecode_claims) bytecode_value_at_r = div_extension_ret(final_eval, w_r) copy_many_ef(challenges, bytecode_claim_output, BYTECODE_POINT_N_VARS) - copy_5(bytecode_value_at_r, bytecode_claim_output + BYTECODE_POINT_N_VARS * DIM) + copy_ef(bytecode_value_at_r, bytecode_claim_output + BYTECODE_POINT_N_VARS * DIM) return @inline def build_inner_data_buf(n_sub, pubkeys_hash, message, slot_lo, slot_hi, 
merkle_chunks_for_slot, bytecode_hash_domsep): inner_data_buf = Array(INPUT_DATA_SIZE_PADDED) inner_data_buf[0] = n_sub - copy_8(pubkeys_hash, inner_data_buf + 1) + copy_digest(pubkeys_hash, inner_data_buf + 1) inner_msg = inner_data_buf + 1 + DIGEST_LEN copy_message(message, inner_msg) # copies MESSAGE_LEN=4 elements under Goldilocks inner_msg[MESSAGE_LEN] = slot_lo @@ -210,7 +210,7 @@ def build_inner_data_buf(n_sub, pubkeys_hash, message, slot_lo, slot_hi, merkle_ for k in unroll(0, N_MERKLE_CHUNKS): inner_msg[MESSAGE_LEN + 2 + k] = merkle_chunks_for_slot[k] hint_witness("inner_bytecode_claim", inner_data_buf + BYTECODE_CLAIM_OFFSET) - copy_8(bytecode_hash_domsep, inner_data_buf + BYTECODE_HASH_DOMSEP_OFFSET) + copy_digest(bytecode_hash_domsep, inner_data_buf + BYTECODE_HASH_DOMSEP_OFFSET) for k in unroll(BYTECODE_HASH_DOMSEP_OFFSET + DIGEST_LEN, INPUT_DATA_SIZE_PADDED): inner_data_buf[k] = 0 return inner_data_buf diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index d87e09c5..bdadcbb1 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -104,7 +104,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): n_vars_logup_gkr = compute_total_gkr_n_vars(log_memory, log_bytecode_padded, table_heights) fs, quotient_gkr, point_gkr, numerators_value, denominators_value = verify_gkr_quotient(fs, n_vars_logup_gkr) - set_to_5_zeros(quotient_gkr) + zero_ef(quotient_gkr) memory_and_acc_prefix = multilinear_location_prefix(0, n_vars_logup_gkr - log_memory, point_gkr) @@ -372,8 +372,8 @@ def continue_recursion_ordered( mle_of_zeros_then_ones(point_gkr, offset, n_vars_logup_gkr), ) - copy_5(retrieved_numerators_value, numerators_value) - copy_5(retrieved_denominators_value, denominators_value) + copy_ef(retrieved_numerators_value, numerators_value) + copy_ef(retrieved_denominators_value, denominators_value) memory_and_acc_point = point_gkr + (n_vars_logup_gkr - log_memory) * DIM @@ -458,7 +458,7 
@@ def continue_recursion_ordered( pcs_values_down[table_index][last_index][AIR_DOWN_COLUMNS[table_index][i]].push(evals_down + i * DIM) # verify that the AIR-batched sumcheck is valid - copy_5(check_sum, batched_air_final_value) + copy_ef(check_sum, batched_air_final_value) fs, public_memory_random_point = fs_sample_many_ef(fs, INNER_PUBLIC_MEMORY_LOG_SIZE) poly_eq_public_mem = poly_eq_extension(public_memory_random_point, INNER_PUBLIC_MEMORY_LOG_SIZE) @@ -645,7 +645,7 @@ def continue_recursion_ordered( curr_randomness += DIM offset += n_rows * total_num_cols - copy_5(mul_extension_ret(s, final_value), end_sum) + copy_ef(mul_extension_ret(s, final_value), end_sum) return @@ -657,8 +657,8 @@ def multilinear_location_prefix(offset, n_vars, point): def fingerprint_2(table_index, data_1, data_2, logup_alphas_eq_poly): buff = Array(DIM * 2) - copy_5(data_1, buff) - copy_5(data_2, buff + DIM) + copy_ef(data_1, buff) + copy_ef(data_2, buff + DIM) res: Mut = dot_product_ee_ret(buff, logup_alphas_eq_poly, 2) res = add_extension_ret(res, mul_base_extension_ret(table_index, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM)) return res @@ -737,7 +737,7 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): new_claim_num = dot_product_ee_ret(inner_evals, point_poly_eq, 2) new_claim_den = dot_product_ee_ret(inner_evals + 2 * DIM, point_poly_eq, 2) - copy_5(beta, postponed_point) + copy_ef(beta, postponed_point) return fs, postponed_point, new_claim_num, new_claim_den diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 4129d4b0..f7bffd74 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -450,7 +450,7 @@ where res += &format!("\n buff = Array(DIM * {})", bus_data.len()); for (i, data) in bus_data.iter().enumerate() { let data_str = eval_air_constraint(*data, None, &mut ctx, &mut res); - res += &format!("\n copy_5({}, buff + DIM * 
{})", data_str, i); + res += &format!("\n copy_ef({}, buff + DIM * {})", data_str, i); } // dot product: bus_res = sum(buff[i] * logup_alphas_eq_poly[i]) for i in 0..bus_data.len() res += "\n bus_res_init = Array(DIM)"; @@ -512,7 +512,7 @@ fn eval_air_constraint( if let Some(d) = dest && v != d { - res.push_str(&format!("\n copy_5({}, {})", v, d)); + res.push_str(&format!("\n copy_ef({}, {})", v, d)); } v } diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 0e9f68f8..61e70d80 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -158,7 +158,7 @@ def expand_from_univariate_base_const(alpha, n: Const): def expand_from_univariate_ext(alpha, n): res = Array(n * DIM) - copy_5(alpha, res) + copy_ef(alpha, res) for i in range(0, n - 1): mul_extension(res + i * DIM, res + i * DIM, res + (i + 1) * DIM) return res @@ -352,44 +352,39 @@ def sub_extension_ret(a, b): return c -# Legacy copy_N / set_to_N_zeros names from the KoalaBear era (DIM=5, DIGEST_LEN=8, -# MESSAGE_LEN=9) are kept for minimal churn but redefined to Goldilocks sizes -# (DIM=3, DIGEST_LEN=4, MESSAGE_LEN=4). Semantic roles: -# copy_5 / set_to_5_zeros → one extension-field element (DIM entries). -# copy_8 / set_to_8_zeros → one digest (DIGEST_LEN entries). -# copy_9 → one message (MESSAGE_LEN = DIGEST_LEN entries now). -# set_to_7_zeros → digest tail after a domain-sep slot: DIGEST_LEN-1 -# = DIM entries under Goldilocks. -# copy_16 → a full Poseidon input state (2 × DIGEST_LEN entries). +# Semantic copy / zero helpers. Sized to Goldilocks (DIM=3, DIGEST_LEN=4, +# MESSAGE_LEN=4). Each helper is a thin wrapper over `dot_product_ee(_, ONE_EF_PTR, _)` +# which copies DIM elements via the extension-op precompile. @inline -def copy_5(a, b): - # Copy DIM=3 elements (one extension-field element). +def copy_ef(a, b): + # Copy one extension-field element = DIM entries. 
dot_product_ee(a, ONE_EF_PTR, b) return @inline -def set_to_5_zeros(a): - # Zero DIM=3 elements. +def zero_ef(a): + # Zero one extension-field element = DIM entries. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) return @inline -def set_to_7_zeros(a): - # Zero DIGEST_LEN-1 = 3 elements (= DIM). Used after writing a domain-sep - # byte into slot 0 of a digest-sized buffer. +def zero_digest_tail(a): + # Zero DIGEST_LEN-1 entries — typically called on `ptr + 1` after writing a + # domain-sep byte into slot 0 of a digest-sized buffer. Under Goldilocks + # DIGEST_LEN-1 == DIM so one dot_product_ee suffices. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) return @inline -def set_to_8_zeros(a): - # Zero DIGEST_LEN=4 elements via two overlapping DIM=3 clears. +def zero_digest(a): + # Zero one digest = DIGEST_LEN entries via two overlapping DIM clears. zero_ptr = ZERO_VEC_PTR dot_product_ee(a, ONE_EF_PTR, zero_ptr) dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, zero_ptr) @@ -397,8 +392,8 @@ def set_to_8_zeros(a): @inline -def copy_8(a, b): - # Copy DIGEST_LEN=4 elements via two overlapping DIM=3 copies. +def copy_digest(a, b): + # Copy one digest = DIGEST_LEN entries via two overlapping DIM copies. dot_product_ee(a, ONE_EF_PTR, b) dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, b + (DIGEST_LEN - DIM)) return @@ -406,17 +401,18 @@ def copy_8(a, b): @inline def copy_message(a, b): - # Copy MESSAGE_LEN=4 elements (= DIGEST_LEN under Goldilocks). + # Copy one message = MESSAGE_LEN entries. Under Goldilocks MESSAGE_LEN == + # DIGEST_LEN, so this is structurally identical to `copy_digest`. dot_product_ee(a, ONE_EF_PTR, b) dot_product_ee(a + (DIGEST_LEN - DIM), ONE_EF_PTR, b + (DIGEST_LEN - DIM)) return @inline -def copy_16(a, b): - # Copy a full Poseidon input block = 2 × DIGEST_LEN = 8 elements. 
- copy_8(a, b) - copy_8(a + DIGEST_LEN, b + DIGEST_LEN) +def copy_poseidon_input(a, b): + # Copy a full Poseidon8 input block = 2 × DIGEST_LEN entries. + copy_digest(a, b) + copy_digest(a + DIGEST_LEN, b + DIGEST_LEN) return diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 4bdd48ce..db806a8c 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -96,7 +96,7 @@ def whir_open( range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: univariate_eval_on_base(final_coeffcients, alpha, n), ) - copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM) + copy_ef(final_pol_evaluated_on_circle, final_folds + i * DIM) fs, all_folding_randomness[n_rounds + 1], end_sum = sumcheck_verify(fs, n_final_vars, claimed_sum, 2) @@ -105,10 +105,10 @@ def whir_open( start: Mut = folding_randomness_global for i in range(0, n_rounds + 1): for j in range(0, folding_factors[i]): - copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) + copy_ef(all_folding_randomness[i] + j * DIM, start + j * DIM) start += folding_factors[i] * DIM for j in range(0, n_final_vars): - copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) + copy_ef(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) all_ood_recovered_evals = Array(num_oods[0] * DIM) for i in range(0, num_oods[0]): @@ -159,7 +159,7 @@ def whir_open( range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n), ) - # copy_5(mul_extension_ret(s, final_value), end_sum); + # copy_ef(mul_extension_ret(s, final_value), end_sum); return fs, folding_randomness_global, s, final_value, end_sum @@ -174,10 +174,10 @@ def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, ch for sc_round in range(0, n_steps): fs, poly = 
fs_receive_ef_inlined(fs, degree + 1) sum_over_boolean_hypercube = polynomial_sum_at_0_and_1(poly, degree) - copy_5(sum_over_boolean_hypercube, claimed_sum) + copy_ef(sum_over_boolean_hypercube, claimed_sum) fs, rand = fs_sample_ef(fs) claimed_sum = univariate_polynomial_eval(poly, rand, degree) - copy_5(rand, challenges + sc_round * DIM) + copy_ef(rand, challenges + sc_round * DIM) return fs, claimed_sum @@ -187,11 +187,11 @@ def sumcheck_verify_with_grinding(fs: Mut, n_steps, claimed_sum: Mut, degree: Co for sc_round in range(0, n_steps): fs, poly = fs_receive_ef_inlined(fs, degree + 1) sum_over_boolean_hypercube = polynomial_sum_at_0_and_1(poly, degree) - copy_5(sum_over_boolean_hypercube, claimed_sum) + copy_ef(sum_over_boolean_hypercube, claimed_sum) fs = fs_grinding(fs, folding_grinding_bits) fs, rand = fs_sample_ef(fs) claimed_sum = univariate_polynomial_eval(poly, rand, degree) - copy_5(rand, challenges + sc_round * DIM) + copy_ef(rand, challenges + sc_round * DIM) return fs, challenges, claimed_sum diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index a103119d..bbdfda4e 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -85,7 +85,7 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): target_sum: Mut = 0 wots_public_key = Array(V * DIGEST_LEN) local_zero_buff = Array(DIGEST_LEN) - set_to_8_zeros(local_zero_buff) + zero_digest(local_zero_buff) for i in unroll(0, V): chain_start = chain_starts + i * DIGEST_LEN @@ -115,7 +115,7 @@ def chain_hash(input_ptr, n, output_ptr, chain_length_ptr, local_zero_buff): n_hashes = (CHAIN_LENGTH - 1) - n if n_hashes == 0: - copy_8(input_ptr, output_ptr) + copy_digest(input_ptr, output_ptr) elif n_hashes == 1: poseidon8_compress(input_ptr, local_zero_buff, output_ptr) else: From 4d912247637b462c1a10b228da19279a6adaa9e8 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 16 Apr 2026 14:42:29 +0200 
Subject: [PATCH 14/31] w --- crates/rec_aggregation/fiat_shamir.py | 106 ++++++++++++---------- crates/rec_aggregation/recursion.py | 11 ++- crates/rec_aggregation/src/compilation.rs | 2 +- crates/rec_aggregation/utils.py | 58 ++++++------ crates/rec_aggregation/whir.py | 15 +++ 5 files changed, 108 insertions(+), 84 deletions(-) diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py index e28dc6b8..a6a650d1 100644 --- a/crates/rec_aggregation/fiat_shamir.py +++ b/crates/rec_aggregation/fiat_shamir.py @@ -1,27 +1,32 @@ from snark_lib import * -# FIAT SHAMIR layout: 17 field elements -# 0..8 -> first half of sponge state -# 8 -> transcript pointer + +# FIAT SHAMIR layout (Goldilocks, DIGEST_LEN=4): 1 + DIGEST_LEN field elements +# slots 0..DIGEST_LEN → sponge state (one digest) +# slot DIGEST_LEN → transcript pointer from utils import * +FS_SIZE = DIGEST_LEN + 1 +FS_TPTR = DIGEST_LEN # index of transcript pointer inside an FS state + + def fs_new(transcript_ptr): - fs_state = Array(9) + fs_state = Array(FS_SIZE) zero_digest(fs_state) - fs_state[8] = transcript_ptr + fs_state[FS_TPTR] = transcript_ptr return fs_state @inline def fs_observe_chunks(fs, data, n_chunks): - result: Mut = Array(9) + result: Mut = Array(FS_SIZE) poseidon8_compress(fs, data, result) for i in unroll(1, n_chunks): - new_result = Array(9) + new_result = Array(FS_SIZE) poseidon8_compress(result, data + i * DIGEST_LEN, new_result) result = new_result - result[8] = fs[8] # preserve transcript pointer + result[FS_TPTR] = fs[FS_TPTR] # preserve transcript pointer return result @@ -36,21 +41,21 @@ def fs_observe(fs, data, length: Const): padded[j] = data[n_full_chunks * DIGEST_LEN + j] for j in unroll(remainder, DIGEST_LEN): padded[j] = 0 - final_result = Array(9) + final_result = Array(FS_SIZE) poseidon8_compress(intermediate, padded, final_result) - final_result[8] = fs[8] # preserve transcript pointer + final_result[FS_TPTR] = fs[FS_TPTR] # preserve transcript 
pointer return final_result def fs_grinding(fs, bits): if bits == 0: return fs # no grinding - transcript_ptr = fs[8] + transcript_ptr = fs[FS_TPTR] zero_digest_tail(transcript_ptr + 1) - new_fs = Array(9) + new_fs = Array(FS_SIZE) poseidon8_compress(fs, transcript_ptr, new_fs) - new_fs[8] = transcript_ptr + 8 + new_fs[FS_TPTR] = transcript_ptr + DIGEST_LEN sampled = new_fs[0] _, partial_sums_24 = checked_decompose_bits(sampled) @@ -61,36 +66,36 @@ def fs_grinding(fs, bits): def fs_sample_chunks(fs, n_chunks: Const): - # return the updated fiat-shamir, and a pointer to n_chunks chunks of 8 field elements + # return the updated fiat-shamir, and a pointer to n_chunks chunks of DIGEST_LEN field elements - sampled = Array((n_chunks + 1) * 8 + 1) + sampled = Array((n_chunks + 1) * DIGEST_LEN + 1) for i in unroll(0, (n_chunks + 1)): - domain_sep = Array(8) + domain_sep = Array(DIGEST_LEN) domain_sep[0] = i zero_digest_tail(domain_sep + 1) poseidon8_compress( domain_sep, fs, - sampled + i * 8, + sampled + i * DIGEST_LEN, ) - sampled[(n_chunks + 1) * 8] = fs[8] # same transcript pointer - new_fs = sampled + n_chunks * 8 + sampled[(n_chunks + 1) * DIGEST_LEN] = fs[FS_TPTR] # same transcript pointer + new_fs = sampled + n_chunks * DIGEST_LEN return new_fs, sampled @inline def fs_sample_ef(fs): - sampled = Array(8) + sampled = Array(DIGEST_LEN) poseidon8_compress(ZERO_VEC_PTR, fs, sampled) - new_fs = Array(9) + new_fs = Array(FS_SIZE) poseidon8_compress(SAMPLING_DOMAIN_SEPARATOR_PTR, fs, new_fs) - new_fs[8] = fs[8] # same transcript pointer + new_fs[FS_TPTR] = fs[FS_TPTR] # same transcript pointer return new_fs, sampled def fs_sample_many_ef(fs, n): # return the updated fiat-shamir, and a pointer to n (continuous) extension field elements - n_chunks = div_ceil_dynamic(n * DIM, 8) + n_chunks = div_ceil_dynamic(n * DIM, DIGEST_LEN) debug_assert(n_chunks <= 31) debug_assert(1 <= n_chunks) new_fs, sampled = match_range(n_chunks, range(1, 32), lambda nc: fs_sample_chunks(fs, nc)) 
@@ -100,33 +105,33 @@ def fs_sample_many_ef(fs, n): @inline def fs_hint(fs, n): # return the updated fiat-shamir, and a pointer to n field elements from the transcript - transcript_ptr = fs[8] - new_fs = Array(9) + transcript_ptr = fs[FS_TPTR] + new_fs = Array(FS_SIZE) copy_digest(fs, new_fs) - new_fs[8] = fs[8] + n # advance transcript pointer + new_fs[FS_TPTR] = fs[FS_TPTR] + n # advance transcript pointer return new_fs, transcript_ptr def fs_receive_chunks(fs, n_chunks: Const): - # each chunk = 8 field elements - new_fs = Array(1 + 8 * n_chunks) - transcript_ptr = fs[8] - new_fs[8 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer + # each chunk = DIGEST_LEN field elements + new_fs = Array(1 + DIGEST_LEN * n_chunks) + transcript_ptr = fs[FS_TPTR] + new_fs[DIGEST_LEN * n_chunks] = transcript_ptr + DIGEST_LEN * n_chunks # advance transcript pointer poseidon8_compress(fs, transcript_ptr, new_fs) for i in unroll(1, n_chunks): poseidon8_compress( - new_fs + ((i - 1) * 8), - transcript_ptr + i * 8, - new_fs + i * 8, + new_fs + ((i - 1) * DIGEST_LEN), + transcript_ptr + i * DIGEST_LEN, + new_fs + i * DIGEST_LEN, ) - return new_fs + 8 * (n_chunks - 1), transcript_ptr + return new_fs + DIGEST_LEN * (n_chunks - 1), transcript_ptr @inline def fs_receive_ef_inlined(fs, n): - new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8)) - for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): + new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, DIGEST_LEN)) + for i in unroll(n * DIM, next_multiple_of(n * DIM, DIGEST_LEN)): assert ef_ptr[i] == 0 return new_fs, ef_ptr @@ -141,14 +146,14 @@ def fs_receive_ef_by_log_dynamic(fs, log_n, min_value: Const, max_value: Const): def fs_receive_ef(fs, n: Const): - new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, 8)) - for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): + new_fs, ef_ptr = fs_receive_chunks(fs, div_ceil(n * DIM, DIGEST_LEN)) + for i in unroll(n * DIM, next_multiple_of(n * DIM, 
DIGEST_LEN)): assert ef_ptr[i] == 0 return new_fs, ef_ptr def fs_print_state(fs_state): - for i in unroll(0, 9): + for i in unroll(0, FS_SIZE): print(i, fs_state[i]) return @@ -156,37 +161,38 @@ def fs_print_state(fs_state): def fs_sample_data_with_offset(fs, n_chunks: Const, offset): # Like fs_sample_chunks but uses domain separators [offset..offset+n_chunks-1]. # Only returns the sampled data, does NOT update fs. - sampled = Array(n_chunks * 8) + sampled = Array(n_chunks * DIGEST_LEN) for i in unroll(0, n_chunks): - domain_sep = Array(8) + domain_sep = Array(DIGEST_LEN) domain_sep[0] = offset + i zero_digest_tail(domain_sep + 1) - poseidon8_compress(domain_sep, fs, sampled + i * 8) + poseidon8_compress(domain_sep, fs, sampled + i * DIGEST_LEN) return sampled def fs_finalize_sample(fs, total_n_chunks): # Compute new fs state using domain_sep = total_n_chunks # (same as the last poseidon call in fs_sample_chunks). - new_fs = Array(9) - domain_sep = Array(8) + new_fs = Array(FS_SIZE) + domain_sep = Array(DIGEST_LEN) domain_sep[0] = total_n_chunks zero_digest_tail(domain_sep + 1) poseidon8_compress(domain_sep, fs, new_fs) - new_fs[8] = fs[8] # same transcript pointer + new_fs[FS_TPTR] = fs[FS_TPTR] # same transcript pointer return new_fs @inline def fs_sample_queries(fs, n_samples): debug_assert(n_samples < 512) - # Compute total_chunks = ceil(n_samples / 8) via bit decomposition. - # Big-endian: nb[0]=bit8 (MSB), nb[8]=bit0 (LSB). + # total_chunks = ceil(n_samples / DIGEST_LEN). With DIGEST_LEN=4 we shift + # right by 2 and check whether the low 2 bits are nonzero. BE decomposition: + # nb[0] = bit 8 (MSB), nb[8] = bit 0 (LSB). 
nb = checked_decompose_bits_small_value_const(n_samples, 9) - floor_div = nb[0] * 32 + nb[1] * 16 + nb[2] * 8 + nb[3] * 4 + nb[4] * 2 + nb[5] - has_remainder = 1 - (1 - nb[6]) * (1 - nb[7]) * (1 - nb[8]) + floor_div = nb[0] * 64 + nb[1] * 32 + nb[2] * 16 + nb[3] * 8 + nb[4] * 4 + nb[5] * 2 + nb[6] + has_remainder = 1 - (1 - nb[7]) * (1 - nb[8]) total_chunks = floor_div + has_remainder # Sample exactly the needed chunks (dispatch via match_range to keep n_chunks const) - sampled = match_range(total_chunks, range(0, 65), lambda nc: fs_sample_data_with_offset(fs, nc, 0)) + sampled = match_range(total_chunks, range(0, 129), lambda nc: fs_sample_data_with_offset(fs, nc, 0)) new_fs = fs_finalize_sample(fs, total_chunks) return sampled, new_fs diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index bdadcbb1..dd54157c 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -41,7 +41,7 @@ BYTECODE_ZERO_EVAL = BYTECODE_ZERO_EVAL_PLACEHOLDER BYTECODE_CLAIM_SIZE = (BYTECODE_POINT_N_VARS + 1) * DIM BYTECODE_CLAIM_SIZE_PADDED = next_multiple_of(BYTECODE_CLAIM_SIZE, DIGEST_LEN) -INNER_PUBLIC_MEMORY_LOG_SIZE = 3 # public input = 1 hash digest = 8 field elements +INNER_PUBLIC_MEMORY_LOG_SIZE = 2 # Goldilocks: public input = 1 hash digest = 4 field elements PUB_INPUT_SIZE = DIGEST_LEN # the public input is a single digest @@ -55,10 +55,11 @@ def recursion(inner_public_memory, bytecode_hash_domsep): fs = fs_observe(fs, inner_public_memory, PUB_INPUT_SIZE) # observe public input (the data digest) fs = fs_observe(fs, bytecode_hash_domsep, DIGEST_LEN) # observe hash(bytecode hash, domain sep) - # table dims - debug_assert(N_TABLES + 1 < DIGEST_LEN) - fs, dims = fs_receive_chunks(fs, 1) - for i in unroll(N_TABLES + 3, 8): + # table dims — 3 leading slots (whir_log_inv_rate, log_memory, public_input_len) + # + N_TABLES per-table heights. 
Under Goldilocks DIGEST_LEN=4 so one chunk + # is not enough; we pull two (8 slots). Surplus slots must be zero. + fs, dims = fs_receive_chunks(fs, 2) + for i in unroll(N_TABLES + 3, 2 * DIGEST_LEN): assert dims[i] == 0 whir_log_inv_rate = dims[0] log_memory = dims[1] diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index f7bffd74..a7eee819 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -46,7 +46,7 @@ fn compile_main_program(inner_program_log_size: usize, bytecode_zero_eval: F) -> #[instrument(skip_all)] fn compile_main_program_self_referential() -> Bytecode { - let mut log_size_guess = 18; + let mut log_size_guess = 19; let bytecode_zero_eval = F::ONE; loop { let bytecode = compile_main_program(log_size_guess, bytecode_zero_eval); diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 61e70d80..2e67c494 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -2,6 +2,7 @@ from hashing import * F_BITS = 64 # Goldilocks (P = 2^64 - 2^32 + 1, values fit in u64) +HALF_BITS = 32 # Goldilocks splits cleanly at 32:32 for canonical-form checks. TWO_ADICITY = 32 ROOT = 1753635133440165772 # = 0x185629dcda58878c, of order 2^TWO_ADICITY @@ -464,35 +465,34 @@ def sum_2_ef_fractions(a_num, a_den, b_num, b_den): return sum_num, common_den -# p = 2^31 - 2^24 + 1 -# in binary: p = 1111111000000000000000000000001 -# p - 1 = 1111111000000000000000000000000 -# p - 2 = 1111110111111111111111111111111 -# p - 3 = 1111110111111111111111111111110 -# ... +# Goldilocks: p = 2^64 - 2^32 + 1 = 0xFFFFFFFF_00000001 +# p - 1 = 0xFFFFFFFF_00000000 +# p - 2 = 0xFFFFFFFE_FFFFFFFF +# ... 
# Any field element (< p) is either: -# - 1111111 | 00...00 -# - not(1111111) | xx...xx +# - high 32 bits = 0xFFFFFFFF and low 32 bits = 0 +# - high 32 bits < 0xFFFFFFFF and low 32 bits arbitrary def checked_decompose_bits(a): - # return a pointer to the 31 bits of a - # .. and the first 24 partial sums of these bits + # Return a pointer to the F_BITS=64 little-endian bits of `a`, plus the + # partial sums over the low HALF_BITS=32 bits. Enforces canonicality. bits = Array(F_BITS) hint_decompose_bits(a, bits, F_BITS, LITTLE_ENDIAN) for i in unroll(0, F_BITS): assert bits[i] * (1 - bits[i]) == 0 - partial_sums_24 = Array(24) - partial_sums_24[0] = bits[0] - for i in unroll(1, 24): - partial_sums_24[i] = partial_sums_24[i - 1] + bits[i] * 2**i - sum_7: Mut = bits[24] - for i in unroll(1, 7): - sum_7 += bits[24 + i] * 2**i - if sum_7 == 127: - assert partial_sums_24[23] == 0 + partial_sums_low = Array(HALF_BITS) + partial_sums_low[0] = bits[0] + for i in unroll(1, HALF_BITS): + partial_sums_low[i] = partial_sums_low[i - 1] + bits[i] * 2**i + sum_high: Mut = bits[HALF_BITS] + for i in unroll(1, F_BITS - HALF_BITS): + sum_high += bits[HALF_BITS + i] * 2**i + # If the high 32 bits are all set, the low 32 bits must be zero (only p-1). + if sum_high == 2**(F_BITS - HALF_BITS) - 1: + assert partial_sums_low[HALF_BITS - 1] == 0 - assert a == partial_sums_24[23] + sum_7 * 2**24 - return bits, partial_sums_24 + assert a == partial_sums_low[HALF_BITS - 1] + sum_high * 2**HALF_BITS + return bits, partial_sums_low @inline @@ -521,10 +521,11 @@ def whir_1_merkle_step_and_pow(v, state_in, path_chunk, state_out, power_shift): @inline def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): - # Goldilocks FRI: query indices fit in TWO_ADICITY = 32 bits. Decompose `a` - # into 8 × 4-bit nibbles and assert `a == partial_sum`; that single equality - # enforces both the decomposition and `a < 2^32` (since partial_sum ≤ 2^32−1). 
- NUM_NIBBLES = 8 + # Decompose the full 64-bit Goldilocks FE `a` into 16 × 4-bit nibbles so + # that `a == partial_sum` holds for any valid field element (no top-bits + # restriction). The first `n_nibbles = ceil(domain_size/4)` nibbles encode + # the Merkle query index modulo 2^domain_size. + NUM_NIBBLES = F_BITS / 4 nibbles = Array(NUM_NIBBLES) hint_decompose_bits_merkle_whir(nibbles, a, NUM_NIBBLES, 4) @@ -748,11 +749,12 @@ def _verify_log2_large(n, log2: Const): def log2_ceil_runtime(n): - # requires: 2 < n <= 2^30 + # requires: 2 < n <= 2^30 (still inside HALF_BITS=32, so `_verify_log2_small` + # is always chosen under Goldilocks). log2: Imu hint_log2_ceil(n, log2) assert log2 < 31 if two_exp(log2) != n: - _, partial_sums_24 = checked_decompose_bits(n) - match_range(log2, range(2, 24), lambda i: _verify_log2_small(n, partial_sums_24, i), range(24, 31), lambda i: _verify_log2_large(n, i)) + _, partial_sums_low = checked_decompose_bits(n) + match_range(log2, range(2, 31), lambda i: _verify_log2_small(n, partial_sums_low, i)) return log2 diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index db806a8c..d2a83b51 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -208,18 +208,33 @@ def decompose_and_verify_merkle_batch(num_queries, sampled, root, height, num_ch def decompose_and_verify_merkle_batch_with_height(num_queries, sampled, root, height: Const, num_chunks, circle_values, answers): + # Under Goldilocks (DIGEST_LEN=4, DIM=3) the value `num_chunks = two_pow_folding_factor * {1,DIM} / DIGEST_LEN` + # roughly doubles vs KoalaBear (DIGEST_LEN=8). We dispatch the union of both + # configurations so the same file compiles for either field. 
if num_chunks == DIM * 2: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, DIM * 2, circle_values, answers) return if num_chunks == 16: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 16, circle_values, answers) return + if num_chunks == 32: + decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 32, circle_values, answers) + return if num_chunks == 8: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 8, circle_values, answers) return + if num_chunks == 12: + decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 12, circle_values, answers) + return if num_chunks == 20: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 20, circle_values, answers) return + if num_chunks == 24: + decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 24, circle_values, answers) + return + if num_chunks == 2: + decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 2, circle_values, answers) + return if num_chunks == 4: decompose_and_verify_merkle_batch_const(num_queries, sampled, root, height, 4, circle_values, answers) return From 1daffe24b41d3960580dc773afc580935137a294 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 17 Apr 2026 09:57:48 +0200 Subject: [PATCH 15/31] w --- crates/rec_aggregation/main.py | 20 +++++++-------- crates/rec_aggregation/src/compilation.rs | 2 +- crates/rec_aggregation/src/lib.rs | 10 +++----- crates/rec_aggregation/xmss_aggregate.py | 30 +++++++++-------------- crates/xmss/src/lib.rs | 2 +- crates/xmss/src/wots.rs | 16 +++++------- 6 files changed, 32 insertions(+), 48 deletions(-) diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index 3fcdbd8f..87368eb0 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -11,7 +11,7 @@ INPUT_DATA_SIZE_PADDED = INPUT_DATA_SIZE_PADDED_PLACEHOLDER INPUT_DATA_NUM_CHUNKS 
= INPUT_DATA_SIZE_PADDED / DIGEST_LEN -BYTECODE_CLAIM_OFFSET = 1 + DIGEST_LEN + MESSAGE_LEN + 2 + N_MERKLE_CHUNKS +BYTECODE_CLAIM_OFFSET = 1 + DIGEST_LEN + MESSAGE_LEN + 1 + N_MERKLE_CHUNKS BYTECODE_HASH_DOMSEP_OFFSET = BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE_PADDED BYTECODE_SUMCHECK_PROOF_SIZE = BYTECODE_SUMCHECK_PROOF_SIZE_PLACEHOLDER @@ -29,9 +29,8 @@ def main(): pubkeys_hash_expected = data_buf + 1 message = pubkeys_hash_expected + DIGEST_LEN slot_ptr = message + MESSAGE_LEN - slot_lo = slot_ptr[0] - slot_hi = slot_ptr[1] - merkle_chunks_for_slot = slot_ptr + 2 + slot = slot_ptr[0] + merkle_chunks_for_slot = slot_ptr + 1 bytecode_claim_output = data_buf + BYTECODE_CLAIM_OFFSET bytecode_hash_domsep = data_buf + BYTECODE_HASH_DOMSEP_OFFSET @@ -58,7 +57,7 @@ def main(): assert n_dup == 0 if n_raw_xmss == 0: inner_data_buf = build_inner_data_buf( - n_sigs, pubkeys_hash_expected, message, slot_lo, slot_hi, + n_sigs, pubkeys_hash_expected, message, slot, merkle_chunks_for_slot, bytecode_hash_domsep, ) @@ -89,7 +88,7 @@ def main(): buffer[idx] = i # Verify raw XMSS signatures pk = all_pubkeys + idx * DIGEST_LEN - xmss_verify(pk, message, slot_lo, slot_hi, merkle_chunks_for_slot) + xmss_verify(pk, message, slot, merkle_chunks_for_slot) counter: Mut = n_raw_xmss @@ -124,7 +123,7 @@ def main(): running_hash = new_hash inner_data_buf = build_inner_data_buf( - n_sub, running_hash, message, slot_lo, slot_hi, + n_sub, running_hash, message, slot, merkle_chunks_for_slot, bytecode_hash_domsep, ) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) @@ -199,16 +198,15 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou return @inline -def build_inner_data_buf(n_sub, pubkeys_hash, message, slot_lo, slot_hi, merkle_chunks_for_slot, bytecode_hash_domsep): +def build_inner_data_buf(n_sub, pubkeys_hash, message, slot, merkle_chunks_for_slot, bytecode_hash_domsep): inner_data_buf = Array(INPUT_DATA_SIZE_PADDED) inner_data_buf[0] = n_sub 
copy_digest(pubkeys_hash, inner_data_buf + 1) inner_msg = inner_data_buf + 1 + DIGEST_LEN copy_message(message, inner_msg) # copies MESSAGE_LEN=4 elements under Goldilocks - inner_msg[MESSAGE_LEN] = slot_lo - inner_msg[MESSAGE_LEN + 1] = slot_hi + inner_msg[MESSAGE_LEN] = slot for k in unroll(0, N_MERKLE_CHUNKS): - inner_msg[MESSAGE_LEN + 2 + k] = merkle_chunks_for_slot[k] + inner_msg[MESSAGE_LEN + 1 + k] = merkle_chunks_for_slot[k] hint_witness("inner_bytecode_claim", inner_data_buf + BYTECODE_CLAIM_OFFSET) copy_digest(bytecode_hash_domsep, inner_data_buf + BYTECODE_HASH_DOMSEP_OFFSET) for k in unroll(BYTECODE_HASH_DOMSEP_OFFSET + DIGEST_LEN, INPUT_DATA_SIZE_PADDED): diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index a7eee819..cce8dd6b 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -32,7 +32,7 @@ fn compile_main_program(inner_program_log_size: usize, bytecode_zero_eval: F) -> let claim_data_size = (bytecode_point_n_vars + 1) * DIMENSION; let claim_data_size_padded = claim_data_size.next_multiple_of(DIGEST_LEN); let input_data_size = - 1 + DIGEST_LEN + MESSAGE_LEN_FE + 2 + N_MERKLE_CHUNKS_FOR_SLOT + claim_data_size_padded + DIGEST_LEN; + 1 + DIGEST_LEN + MESSAGE_LEN_FE + 1 + N_MERKLE_CHUNKS_FOR_SLOT + claim_data_size_padded + DIGEST_LEN; let input_data_size_padded = input_data_size.next_multiple_of(DIGEST_LEN); let replacements = build_replacements(inner_program_log_size, bytecode_zero_eval, input_data_size_padded); diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 314edcdf..0dd49a0d 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -7,7 +7,7 @@ use lean_prover::verify_execution::verify_execution; use lean_vm::*; use tracing::instrument; use utils::{build_prover_state, get_poseidon8, poseidon_compress_slice, poseidon8_compress_pair}; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, 
SIG_SIZE_FE, XmssPublicKey, XmssSignature, slot_to_field_elements}; +use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, SIG_SIZE_FE, XmssPublicKey, XmssSignature, slot_to_field_element}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -73,9 +73,7 @@ fn build_input_data( data.push(F::from_usize(n_sigs)); data.extend_from_slice(slice_hash); data.extend_from_slice(message); - let [slot_lo, slot_hi] = slot_to_field_elements(slot); - data.push(slot_lo); - data.push(slot_hi); + data.push(slot_to_field_element(slot)); data.extend(compute_merkle_chunks_for_slot(slot)); data.extend_from_slice(bytecode_claim_output); // Pad the bytecode claim itself up to DIGEST_LEN @@ -222,7 +220,7 @@ pub fn xmss_aggregate( // Bytecode sumcheck reduction let (bytecode_claim_output, bytecode_point, final_sumcheck_transcript) = if n_recursions > 0 { - let bytecode_claim_offset = 1 + DIGEST_LEN + 2 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; + let bytecode_claim_offset = 1 + DIGEST_LEN + 1 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; let mut claims = vec![]; for (i, _child) in children.iter().enumerate() { let first_claim = extract_bytecode_claim_from_input_data( @@ -327,7 +325,7 @@ pub fn xmss_aggregate( let mut inner_bytecode_claim_blobs = Vec::with_capacity(n_recursions); let mut proof_transcript_blobs = Vec::with_capacity(n_recursions); - let claim_offset_in_input = 1 + DIGEST_LEN + 2 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; + let claim_offset_in_input = 1 + DIGEST_LEN + 1 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; let claim_size_padded = bytecode_claim_size.next_multiple_of(DIGEST_LEN); // Sources 1..n_recursions: recursive children diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index bbdfda4e..2459e9aa 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -22,7 +22,7 @@ @inline -def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): +def 
xmss_verify(merkle_root, message, slot, merkle_chunks): # signature: randomness | chain_tips | merkle_path # return the hashed xmss public key signature = Array(SIG_SIZE) @@ -31,31 +31,23 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): chain_starts = signature + RANDOMNESS_LEN merkle_path = chain_starts + V * DIGEST_LEN - # 1) Hash (message, randomness, slot, merkle_root) into 3 output FE via a - # 3-call Poseidon8 sponge chain, mirroring `poseidon_compress_slice` on - # 14 input FE in the Rust side. + # 1) Hash (message, randomness, slot, merkle_root) into the 3 output FE via + # a 2-call Poseidon8 sponge chain, mirroring `poseidon_compress_slice` on + # 12 input FE in the Rust side. # # Call 1: poseidon8(message[0..4], randomness[0..4]) → a a = Array(DIGEST_LEN) poseidon8_compress(message, randomness, a) - # Call 2: poseidon8(a, [slot_lo, slot_hi, root[0], root[1]]) → b + # Call 2: poseidon8(a, [slot, root[0], root[1], root[2]]) → encoding_fe + # (4 FE; we use the first 3 as the Winternitz encoding). rhs2 = Array(DIGEST_LEN) - rhs2[0] = slot_lo - rhs2[1] = slot_hi - rhs2[2] = merkle_root[0] - rhs2[3] = merkle_root[1] - b = Array(DIGEST_LEN) - poseidon8_compress(a, rhs2, b) - - # Call 3: poseidon8(b, [root[2], root[3], 0, 0]) → encoding_fe (4 FE; we use the first 3) - rhs3 = Array(DIGEST_LEN) - rhs3[0] = merkle_root[2] - rhs3[1] = merkle_root[3] - rhs3[2] = 0 - rhs3[3] = 0 + rhs2[0] = slot + rhs2[1] = merkle_root[0] + rhs2[2] = merkle_root[1] + rhs2[3] = merkle_root[2] encoding_fe = Array(DIGEST_LEN) - poseidon8_compress(b, rhs3, encoding_fe) + poseidon8_compress(a, rhs2, encoding_fe) # 2) Decompose each of the first 3 FE into 21 3-bit chunks = 63 bits per FE # (1-bit remainder). 3 × 21 = 63 total chunks; first V+V_GRINDING used. 
diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 21b675a5..fb5e49ce 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -21,6 +21,6 @@ pub const V_GRINDING: usize = 2; pub const LOG_LIFETIME: usize = 32; pub const RANDOMNESS_LEN_FE: usize = 4; pub const MESSAGE_LEN_FE: usize = 4; -pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 4; +pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 3; pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_SIZE; diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index a7ca9d66..22a55069 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -109,15 +109,14 @@ pub fn wots_encode( truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], randomness: &[F; RANDOMNESS_LEN_FE], ) -> Option<[u8; V]> { - let [slot_lo, slot_hi] = slot_to_field_elements(slot); + let slot_fe = slot_to_field_element(slot); - const INPUT_LEN: usize = MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 2 + TRUNCATED_MERKLE_ROOT_LEN_FE; + const INPUT_LEN: usize = MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 1 + TRUNCATED_MERKLE_ROOT_LEN_FE; let mut input = [F::default(); INPUT_LEN]; input[..MESSAGE_LEN_FE].copy_from_slice(message); input[MESSAGE_LEN_FE..MESSAGE_LEN_FE + RANDOMNESS_LEN_FE].copy_from_slice(randomness); - input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE] = slot_lo; - input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 1] = slot_hi; - input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 2..].copy_from_slice(truncated_merkle_root); + input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE] = slot_fe; + input[MESSAGE_LEN_FE + RANDOMNESS_LEN_FE + 1..].copy_from_slice(truncated_merkle_root); // `poseidon_compress_slice` returns 4 FE; we use the first 3 (= NUM_ENCODING_FE). // Assumption (for now): each Goldilocks FE yields ~64 bits of almost-uniform entropy. 
@@ -163,9 +162,6 @@ fn is_valid_encoding(encoding: &[u8]) -> bool { true } -pub fn slot_to_field_elements(slot: u32) -> [F; 2] { - [ - F::from_usize((slot & 0xFFFF) as usize), - F::from_usize(((slot >> 16) & 0xFFFF) as usize), - ] +pub fn slot_to_field_element(slot: u32) -> F { + F::from_usize(slot as usize) } From a635928f22891abd46afee63177ffdc24d11b057 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 25 Apr 2026 17:20:18 +0200 Subject: [PATCH 16/31] w Co-authored-by: Copilot --- .../src/benchmark_poseidons_goldilocks.rs | 39 ++ .../backend/goldilocks/src/cubic_extension.rs | 4 +- crates/backend/goldilocks/src/goldilocks.rs | 16 +- crates/backend/goldilocks/src/lib.rs | 3 + crates/backend/goldilocks/src/poseidon1.rs | 452 ++++++++++++------ ...ns.rs => benchmark_poseidons_koalabear.rs} | 2 +- crates/backend/koala-bear/src/lib.rs | 2 +- 7 files changed, 366 insertions(+), 152 deletions(-) create mode 100644 crates/backend/goldilocks/src/benchmark_poseidons_goldilocks.rs rename crates/backend/koala-bear/src/{benchmark_poseidons.rs => benchmark_poseidons_koalabear.rs} (93%) diff --git a/crates/backend/goldilocks/src/benchmark_poseidons_goldilocks.rs b/crates/backend/goldilocks/src/benchmark_poseidons_goldilocks.rs new file mode 100644 index 00000000..170c6f07 --- /dev/null +++ b/crates/backend/goldilocks/src/benchmark_poseidons_goldilocks.rs @@ -0,0 +1,39 @@ +use std::hint::black_box; +use std::time::Instant; + +use field::Field; +use field::PackedValue; +use field::PrimeCharacteristicRing; + +use crate::{Goldilocks, default_goldilocks_poseidon1_8}; + +type FPacking = ::Packing; +const PACKING_WIDTH: usize = ::WIDTH; + +#[test] +#[ignore] +fn bench_poseidon() { + // cargo test --release --package mt-goldilocks --lib -- benchmark_poseidons_goldilocks::bench_poseidon --exact --nocapture --ignored + + let n = 1 << 23; + let poseidon1_8 = default_goldilocks_poseidon1_8(); + + // warming + let mut state_8: [FPacking; 8] = [FPacking::ZERO; 8]; + for _ in 0..1 << 15 
{ + poseidon1_8.compress_in_place(&mut state_8); + } + let _ = black_box(state_8); + + let time = Instant::now(); + for _ in 0..n / PACKING_WIDTH { + poseidon1_8.compress_in_place(&mut state_8); + } + let _ = black_box(state_8); + let time_p1_simd = time.elapsed(); + println!( + "Poseidon1 8 SIMD (width {}): {:.2}M hashes/s", + PACKING_WIDTH, + (n as f64 / time_p1_simd.as_secs_f64() / 1_000_000.0) + ); +} diff --git a/crates/backend/goldilocks/src/cubic_extension.rs b/crates/backend/goldilocks/src/cubic_extension.rs index 68abf20e..15c24b99 100644 --- a/crates/backend/goldilocks/src/cubic_extension.rs +++ b/crates/backend/goldilocks/src/cubic_extension.rs @@ -98,9 +98,7 @@ impl BasedVectorSpace for CubicExtensionFieldGL { } #[inline] - fn from_basis_coefficients_iter>( - mut iter: I, - ) -> Option { + fn from_basis_coefficients_iter>(mut iter: I) -> Option { (iter.len() == 3).then(|| Self::new(array::from_fn(|_| iter.next().unwrap()))) } diff --git a/crates/backend/goldilocks/src/goldilocks.rs b/crates/backend/goldilocks/src/goldilocks.rs index 6a4e6be9..d2672ff6 100644 --- a/crates/backend/goldilocks/src/goldilocks.rs +++ b/crates/backend/goldilocks/src/goldilocks.rs @@ -11,9 +11,9 @@ use core::{array, fmt}; use field::integers::QuotientMap; use field::op_assign_macros::{impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign}; use field::{ - Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField, - PrimeField64, RawDataSerializable, TwoAdicField, impl_raw_serializable_primefield64, - quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int, + Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField, PrimeField64, + RawDataSerializable, TwoAdicField, impl_raw_serializable_primefield64, quotient_map_large_iint, + quotient_map_large_uint, quotient_map_small_int, }; use num_bigint::BigUint; use rand::Rng; @@ -60,9 +60,7 @@ impl Goldilocks { /// Convert a `[[u64; N]; 
M]` array to a 2D array of field elements. #[inline] - pub const fn new_2d_array( - input: [[u64; N]; M], - ) -> [[Self; N]; M] { + pub const fn new_2d_array(input: [[u64; N]; M]) -> [[Self; N]; M] { let mut output = [[Self::ZERO; N]; M]; let mut i = 0; while i < M { @@ -255,11 +253,7 @@ impl PrimeCharacteristicRing for Goldilocks { let long_prod_1 = (lhs[1].value as u128) * (rhs[1].value as u128); let (sum, over) = long_prod_0.overflowing_add(long_prod_1); let sum_corr = sum.wrapping_sub(OFFSET); - if over { - reduce128(sum_corr) - } else { - reduce128(sum) - } + if over { reduce128(sum_corr) } else { reduce128(sum) } } _ => { let (lo_plus_hi, hi) = lhs diff --git a/crates/backend/goldilocks/src/lib.rs b/crates/backend/goldilocks/src/lib.rs index 26f147b0..0d6cb6a1 100644 --- a/crates/backend/goldilocks/src/lib.rs +++ b/crates/backend/goldilocks/src/lib.rs @@ -11,6 +11,9 @@ mod goldilocks; mod helpers; mod poseidon1; +#[cfg(test)] +mod benchmark_poseidons_goldilocks; + pub use cubic_extension::*; pub use goldilocks::*; pub use helpers::*; diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index c0e6c6b7..3100afeb 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -24,8 +24,7 @@ pub const POSEIDON1_PARTIAL_ROUNDS: usize = 22; pub const POSEIDON1_SBOX_DEGREE: u64 = 7; pub const POSEIDON1_DIGEST_LEN: usize = 4; -pub const POSEIDON1_N_ROUNDS: usize = - 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; +pub const POSEIDON1_N_ROUNDS: usize = 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; // ========================================================================= // MDS matrix (circulant, width 8) @@ -89,132 +88,311 @@ fn mds_mul_scalar(state: &mut [Goldilocks; 8]) { // Generated by the Grain LFSR (Poseidon1, Appendix E) with // `field_type = 1, alpha = 7, n = 64, t = 8, R_F = 8, R_P = 22`. 
// Values carried over verbatim from `plonky3/goldilocks/src/poseidon1.rs`. -pub const GOLDILOCKS_POSEIDON1_RC_8: [[Goldilocks; POSEIDON1_WIDTH]; POSEIDON1_N_ROUNDS] = - Goldilocks::new_2d_array([ - // ---- Initial full rounds (4) ---- - [ - 0xdd5743e7f2a5a5d9, 0xcb3a864e58ada44b, 0xffa2449ed32f8cdc, 0x42025f65d6bd13ee, - 0x7889175e25506323, 0x34b98bb03d24b737, 0xbdcc535ecc4faa2a, 0x5b20ad869fc0d033, - ], - [ - 0xf1dda5b9259dfcb4, 0x27515210be112d59, 0x4227d1718c766c3f, 0x26d333161a5bd794, - 0x49b938957bf4b026, 0x4a56b5938b213669, 0x1120426b48c8353d, 0x6b323c3f10a56cad, - ], - [ - 0xce57d6245ddca6b2, 0xb1fc8d402bba1eb1, 0xb5c5096ca959bd04, 0x6db55cd306d31f7f, - 0xc49d293a81cb9641, 0x1ce55a4fe979719f, 0xa92e60a9d178a4d1, 0x002cc64973bcfd8c, - ], - [ - 0xcea721cce82fb11b, 0xe5b55eb8098ece81, 0x4e30525c6f1ddd66, 0x43c6702827070987, - 0xaca68430a7b5762a, 0x3674238634df9c93, 0x88cee1c825e33433, 0xde99ae8d74b57176, - ], - // ---- Partial rounds (22) ---- - [ - 0x488897d85ff51f56, 0x1140737ccb162218, 0xa7eeb9215866ed35, 0x9bd2976fee49fcc9, - 0xc0c8f0de580a3fcc, 0x4fb2dae6ee8fc793, 0x343a89f35f37395b, 0x223b525a77ca72c8, - ], - [ - 0x56ccb62574aaa918, 0xc4d507d8027af9ed, 0xa080673cf0b7e95c, 0xf0184884eb70dcf8, - 0x044f10b0cb3d5c69, 0xe9e3f7993938f186, 0x1b761c80e772f459, 0x606cec607a1b5fac, - ], - [ - 0x14a0c2e1d45f03cd, 0x4eace8855398574f, 0xf905ca7103eff3e6, 0xf8c8f8d20862c059, - 0xb524fe8bdd678e5a, 0xfbb7865901a1ec41, 0x014ef1197d341346, 0x9725e20825d07394, - ], - [ - 0xfdb25aef2c5bae3b, 0xbe5402dc598c971e, 0x93a5711f04cdca3d, 0xc45a9a5b2f8fb97b, - 0xfe8946a924933545, 0x2af997a27369091c, 0xaa62c88e0b294011, 0x058eb9d810ce9f74, - ], - [ - 0xb3cb23eced349ae4, 0xa3648177a77b4a84, 0x43153d905992d95d, 0xf4e2a97cda44aa4b, - 0x5baa2702b908682f, 0x082923bdf4f750d1, 0x98ae09a325893803, 0xf8a6475077968838, - ], - [ - 0xceb0735bf00b2c5f, 0x0a1a5d953888e072, 0x2fcb190489f94475, 0xb5be06270dec69fc, - 0x739cb934b09acf8b, 0x537750b75ec7f25b, 0xe9dd318bae1f3961, 0xf7462137299efe1a, - 
], - [ - 0xb1f6b8eee9adb940, 0xbdebcc8a809dfe6b, 0x40fc1f791b178113, 0x3ac1c3362d014864, - 0x9a016184bdb8aeba, 0x95f2394459fbc25e, 0xe3f34a07a76a66c2, 0x8df25f9ad98b1b96, - ], - [ - 0x85ffc27171439d9d, 0xddcb9a2dcfd26910, 0x26b5ba4bf3afb94e, 0xffff9cc7c7651e2f, - 0x8c88364698280b55, 0xebc114167b910501, 0x2d77b4d89ecfb516, 0x332e0828eba151f2, - ], - [ - 0x46fa6a6450dd4735, 0xd00db7dd92384a33, 0x5fd4fb751f3a5fc5, 0x496fb90c0bb65ea2, - 0xf3baec0bb87cc5c7, 0x862a3c0a7d4c7713, 0xbf5f38336a3f47d8, 0x41ad9dbc1394a20c, - ], - [ - 0xcc535945b7dbf0f7, 0x82af2bc93685bcec, 0x8e4c8d0c8cebfccd, 0x17cb39417e84597e, - 0xd4a965a8c749b232, 0xa2cab040f33f3ee5, 0xa98811a1fed4e3a6, 0x1cc48b54f377e2a1, - ], - [ - 0xe40cd4f6c5609a27, 0x11de79ebca97a4a4, 0x9177c73d8b7e929d, 0x2a6fe8085797e792, - 0x3de6e93329f8d5ae, 0x3f7af9125da962ff, 0xd710682cfc77d3ac, 0x48faf05f3b053cf4, - ], - [ - 0x287db8630da89c8b, 0x4d0de32053cb30e9, 0x8b37a4f20c5ada7b, 0xe7cc6ebe78c84ecf, - 0x240bdc0a66a2610d, 0x8299e7f02caa1650, 0x380a53fefb6e754e, 0x684a1d8cf8eb6810, - ], - [ - 0xe839452eb4b8a5e1, 0xb03fa62e90626af4, 0x11a688602fbc5efc, 0x30dda75c355a2d62, - 0x0f712adcb73810de, 0xffdc1102187f1ae1, 0x40c34f398254b99c, 0xede021b9dc289a4a, - ], - [ - 0x8b7b05225c4e7dad, 0x3bc794346f9d9ff9, 0xfccb5a57f2ca86ff, 0xbb1502015a7da9d4, - 0xd7e0a35d4352a015, 0x27af7a44f8160931, 0xc37442f6782f4615, 0xbdf392a9bd095dcb, - ], - [ - 0xc17f55037cf00de9, 0xbcffedd34c71a874, 0x5eb45d2a8133d1f2, 0xbabe251e1612ebdf, - 0x3efeb9fbe438c536, 0x2d7cef97b4afe1cf, 0xe5de1b4660016c0b, 0xcdcc26c332f5657c, - ], - [ - 0xe01dd653daf15809, 0xb0a6bdd4b41094b5, 0x27eac858b0b03a05, 0x51d43b5e93adbdc0, - 0x8b89a23b0fea5fc9, 0xdc8ac3b14f7f2fc1, 0xe793f82f1efec039, 0x9f6f2cf8969e7b80, - ], - [ - 0x49d45382e0f21d4a, 0x5f4ad1797cd72786, 0x4dc3dbebfd45f795, 0x03a3ef84dba6e1bc, - 0x204bc9b3d3fc4c01, 0x9ad706081e89b9ba, 0x638bfb4d840e9f89, 0x5ef2938cd095ae35, - ], - [ - 0x42cca18ebeb265c8, 0xb7b2ec5c29aecbf8, 0x0d84f9535dc78f0f, 0x04e64ad942e77b8c, - 
0xb4880dffffc9da0b, 0x16db16d9c29adeb1, 0x09bbaf2a0590cd1e, 0x76460e74961fcf8d, - ], - [ - 0xed12a2276dfa1553, 0x0b5acec5de0436fd, 0x3c6cfea033a1f0a8, 0x2b5ecefe546cac15, - 0x6e2d82884cd3bf6f, 0xc134878d1add7b83, 0x997963422eb7a280, 0x5e834537ac648cf6, - ], - [ - 0x89e779214737c0b7, 0x1a8c05e8581ad95b, 0x8d18b72796437cf7, 0xe7252c949e04b106, - 0x53267c4fd174585a, 0xa16ef5d9c81dad47, 0xda65191937270a46, 0xcb2a5b55f2df664c, - ], - [ - 0x854aee2dc1924137, 0xf37013c9d479ece6, 0x0e163bc0630c4696, 0x384ee64955048f76, - 0xf65d814e28ee4ec5, 0xe57bc564fd82f1b1, 0x4b338937b6876614, 0x66ee0b04ed43cd8d, - ], - [ - 0x49884bf25f4ef15d, 0xeb51fe28de1c6f54, 0x2cd64e84fce8dfcc, 0x29164a96a541a013, - 0x173ce7558f4cacb8, 0xeb5b1ce5877c89e9, 0x5faff4b0f5217bf6, 0xac42d0b1c20f205e, - ], - // ---- Terminal full rounds (4) ---- - [ - 0xfb1d6bf0ca43221b, 0x97b0a1b01d6a2955, 0x08c60bd622952b30, 0x43f2be0f9e24147c, - 0xfa7268b7d3730f5d, 0x43a6c419a23983bb, 0xcd77c1f7b29b113c, 0xcfa43c9db8eec29f, - ], - [ - 0xcaaa95a6c7365dec, 0x0a91193f798f3be0, 0x1104497652735dc6, 0x35aecb93663b515e, - 0x8dbc9916065aa858, 0xada8f7a0266579ed, 0x524dee7bec1ea789, 0xa93aee9dd5af9521, - ], - [ - 0x9d1f1b54750d707e, 0x7c9feab87096d5dc, 0xa2e1fb19f9d4261b, 0xb714deb448de6346, - 0x225d1f0d011c5403, 0x1549b7f1d28cedc0, 0xaef3e46f97d43942, 0x6dfc7ffe0b38bf08, - ], - [ - 0x7de853fdc542b663, 0xa68ecc96610657b2, 0xe88bb5428af289b1, 0xd7cfa1504c5569f5, - 0x78a9aad0d642d30a, 0xd68315f2353dce52, 0x46e56300f86fcfd5, 0x323d95332b145fd6, - ], - ]); +pub const GOLDILOCKS_POSEIDON1_RC_8: [[Goldilocks; POSEIDON1_WIDTH]; POSEIDON1_N_ROUNDS] = Goldilocks::new_2d_array([ + // ---- Initial full rounds (4) ---- + [ + 0xdd5743e7f2a5a5d9, + 0xcb3a864e58ada44b, + 0xffa2449ed32f8cdc, + 0x42025f65d6bd13ee, + 0x7889175e25506323, + 0x34b98bb03d24b737, + 0xbdcc535ecc4faa2a, + 0x5b20ad869fc0d033, + ], + [ + 0xf1dda5b9259dfcb4, + 0x27515210be112d59, + 0x4227d1718c766c3f, + 0x26d333161a5bd794, + 0x49b938957bf4b026, + 0x4a56b5938b213669, + 
0x1120426b48c8353d, + 0x6b323c3f10a56cad, + ], + [ + 0xce57d6245ddca6b2, + 0xb1fc8d402bba1eb1, + 0xb5c5096ca959bd04, + 0x6db55cd306d31f7f, + 0xc49d293a81cb9641, + 0x1ce55a4fe979719f, + 0xa92e60a9d178a4d1, + 0x002cc64973bcfd8c, + ], + [ + 0xcea721cce82fb11b, + 0xe5b55eb8098ece81, + 0x4e30525c6f1ddd66, + 0x43c6702827070987, + 0xaca68430a7b5762a, + 0x3674238634df9c93, + 0x88cee1c825e33433, + 0xde99ae8d74b57176, + ], + // ---- Partial rounds (22) ---- + [ + 0x488897d85ff51f56, + 0x1140737ccb162218, + 0xa7eeb9215866ed35, + 0x9bd2976fee49fcc9, + 0xc0c8f0de580a3fcc, + 0x4fb2dae6ee8fc793, + 0x343a89f35f37395b, + 0x223b525a77ca72c8, + ], + [ + 0x56ccb62574aaa918, + 0xc4d507d8027af9ed, + 0xa080673cf0b7e95c, + 0xf0184884eb70dcf8, + 0x044f10b0cb3d5c69, + 0xe9e3f7993938f186, + 0x1b761c80e772f459, + 0x606cec607a1b5fac, + ], + [ + 0x14a0c2e1d45f03cd, + 0x4eace8855398574f, + 0xf905ca7103eff3e6, + 0xf8c8f8d20862c059, + 0xb524fe8bdd678e5a, + 0xfbb7865901a1ec41, + 0x014ef1197d341346, + 0x9725e20825d07394, + ], + [ + 0xfdb25aef2c5bae3b, + 0xbe5402dc598c971e, + 0x93a5711f04cdca3d, + 0xc45a9a5b2f8fb97b, + 0xfe8946a924933545, + 0x2af997a27369091c, + 0xaa62c88e0b294011, + 0x058eb9d810ce9f74, + ], + [ + 0xb3cb23eced349ae4, + 0xa3648177a77b4a84, + 0x43153d905992d95d, + 0xf4e2a97cda44aa4b, + 0x5baa2702b908682f, + 0x082923bdf4f750d1, + 0x98ae09a325893803, + 0xf8a6475077968838, + ], + [ + 0xceb0735bf00b2c5f, + 0x0a1a5d953888e072, + 0x2fcb190489f94475, + 0xb5be06270dec69fc, + 0x739cb934b09acf8b, + 0x537750b75ec7f25b, + 0xe9dd318bae1f3961, + 0xf7462137299efe1a, + ], + [ + 0xb1f6b8eee9adb940, + 0xbdebcc8a809dfe6b, + 0x40fc1f791b178113, + 0x3ac1c3362d014864, + 0x9a016184bdb8aeba, + 0x95f2394459fbc25e, + 0xe3f34a07a76a66c2, + 0x8df25f9ad98b1b96, + ], + [ + 0x85ffc27171439d9d, + 0xddcb9a2dcfd26910, + 0x26b5ba4bf3afb94e, + 0xffff9cc7c7651e2f, + 0x8c88364698280b55, + 0xebc114167b910501, + 0x2d77b4d89ecfb516, + 0x332e0828eba151f2, + ], + [ + 0x46fa6a6450dd4735, + 0xd00db7dd92384a33, + 
0x5fd4fb751f3a5fc5, + 0x496fb90c0bb65ea2, + 0xf3baec0bb87cc5c7, + 0x862a3c0a7d4c7713, + 0xbf5f38336a3f47d8, + 0x41ad9dbc1394a20c, + ], + [ + 0xcc535945b7dbf0f7, + 0x82af2bc93685bcec, + 0x8e4c8d0c8cebfccd, + 0x17cb39417e84597e, + 0xd4a965a8c749b232, + 0xa2cab040f33f3ee5, + 0xa98811a1fed4e3a6, + 0x1cc48b54f377e2a1, + ], + [ + 0xe40cd4f6c5609a27, + 0x11de79ebca97a4a4, + 0x9177c73d8b7e929d, + 0x2a6fe8085797e792, + 0x3de6e93329f8d5ae, + 0x3f7af9125da962ff, + 0xd710682cfc77d3ac, + 0x48faf05f3b053cf4, + ], + [ + 0x287db8630da89c8b, + 0x4d0de32053cb30e9, + 0x8b37a4f20c5ada7b, + 0xe7cc6ebe78c84ecf, + 0x240bdc0a66a2610d, + 0x8299e7f02caa1650, + 0x380a53fefb6e754e, + 0x684a1d8cf8eb6810, + ], + [ + 0xe839452eb4b8a5e1, + 0xb03fa62e90626af4, + 0x11a688602fbc5efc, + 0x30dda75c355a2d62, + 0x0f712adcb73810de, + 0xffdc1102187f1ae1, + 0x40c34f398254b99c, + 0xede021b9dc289a4a, + ], + [ + 0x8b7b05225c4e7dad, + 0x3bc794346f9d9ff9, + 0xfccb5a57f2ca86ff, + 0xbb1502015a7da9d4, + 0xd7e0a35d4352a015, + 0x27af7a44f8160931, + 0xc37442f6782f4615, + 0xbdf392a9bd095dcb, + ], + [ + 0xc17f55037cf00de9, + 0xbcffedd34c71a874, + 0x5eb45d2a8133d1f2, + 0xbabe251e1612ebdf, + 0x3efeb9fbe438c536, + 0x2d7cef97b4afe1cf, + 0xe5de1b4660016c0b, + 0xcdcc26c332f5657c, + ], + [ + 0xe01dd653daf15809, + 0xb0a6bdd4b41094b5, + 0x27eac858b0b03a05, + 0x51d43b5e93adbdc0, + 0x8b89a23b0fea5fc9, + 0xdc8ac3b14f7f2fc1, + 0xe793f82f1efec039, + 0x9f6f2cf8969e7b80, + ], + [ + 0x49d45382e0f21d4a, + 0x5f4ad1797cd72786, + 0x4dc3dbebfd45f795, + 0x03a3ef84dba6e1bc, + 0x204bc9b3d3fc4c01, + 0x9ad706081e89b9ba, + 0x638bfb4d840e9f89, + 0x5ef2938cd095ae35, + ], + [ + 0x42cca18ebeb265c8, + 0xb7b2ec5c29aecbf8, + 0x0d84f9535dc78f0f, + 0x04e64ad942e77b8c, + 0xb4880dffffc9da0b, + 0x16db16d9c29adeb1, + 0x09bbaf2a0590cd1e, + 0x76460e74961fcf8d, + ], + [ + 0xed12a2276dfa1553, + 0x0b5acec5de0436fd, + 0x3c6cfea033a1f0a8, + 0x2b5ecefe546cac15, + 0x6e2d82884cd3bf6f, + 0xc134878d1add7b83, + 0x997963422eb7a280, + 0x5e834537ac648cf6, + ], + [ + 
0x89e779214737c0b7, + 0x1a8c05e8581ad95b, + 0x8d18b72796437cf7, + 0xe7252c949e04b106, + 0x53267c4fd174585a, + 0xa16ef5d9c81dad47, + 0xda65191937270a46, + 0xcb2a5b55f2df664c, + ], + [ + 0x854aee2dc1924137, + 0xf37013c9d479ece6, + 0x0e163bc0630c4696, + 0x384ee64955048f76, + 0xf65d814e28ee4ec5, + 0xe57bc564fd82f1b1, + 0x4b338937b6876614, + 0x66ee0b04ed43cd8d, + ], + [ + 0x49884bf25f4ef15d, + 0xeb51fe28de1c6f54, + 0x2cd64e84fce8dfcc, + 0x29164a96a541a013, + 0x173ce7558f4cacb8, + 0xeb5b1ce5877c89e9, + 0x5faff4b0f5217bf6, + 0xac42d0b1c20f205e, + ], + // ---- Terminal full rounds (4) ---- + [ + 0xfb1d6bf0ca43221b, + 0x97b0a1b01d6a2955, + 0x08c60bd622952b30, + 0x43f2be0f9e24147c, + 0xfa7268b7d3730f5d, + 0x43a6c419a23983bb, + 0xcd77c1f7b29b113c, + 0xcfa43c9db8eec29f, + ], + [ + 0xcaaa95a6c7365dec, + 0x0a91193f798f3be0, + 0x1104497652735dc6, + 0x35aecb93663b515e, + 0x8dbc9916065aa858, + 0xada8f7a0266579ed, + 0x524dee7bec1ea789, + 0xa93aee9dd5af9521, + ], + [ + 0x9d1f1b54750d707e, + 0x7c9feab87096d5dc, + 0xa2e1fb19f9d4261b, + 0xb714deb448de6346, + 0x225d1f0d011c5403, + 0x1549b7f1d28cedc0, + 0xaef3e46f97d43942, + 0x6dfc7ffe0b38bf08, + ], + [ + 0x7de853fdc542b663, + 0xa68ecc96610657b2, + 0xe88bb5428af289b1, + 0xd7cfa1504c5569f5, + 0x78a9aad0d642d30a, + 0xd68315f2353dce52, + 0x46e56300f86fcfd5, + 0x323d95332b145fd6, + ], +]); // ========================================================================= // S-box helpers @@ -256,9 +434,7 @@ impl Poseidon1Goldilocks8 { mds_mul_scalar(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS - ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS - { + for r in POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { for i in 0..POSEIDON1_WIDTH { state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; } @@ -293,9 +469,7 @@ impl Poseidon1Goldilocks8 { mds_mul_generic(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS - ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS - { + for r in 
POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { for i in 0..POSEIDON1_WIDTH { state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; } @@ -363,8 +537,14 @@ mod tests { #[test] fn permutation_is_deterministic() { let input: [Goldilocks; 8] = [ - Goldilocks::new(1), Goldilocks::new(2), Goldilocks::new(3), Goldilocks::new(4), - Goldilocks::new(5), Goldilocks::new(6), Goldilocks::new(7), Goldilocks::new(8), + Goldilocks::new(1), + Goldilocks::new(2), + Goldilocks::new(3), + Goldilocks::new(4), + Goldilocks::new(5), + Goldilocks::new(6), + Goldilocks::new(7), + Goldilocks::new(8), ]; let p = Poseidon1Goldilocks8; let a = p.permute(input); diff --git a/crates/backend/koala-bear/src/benchmark_poseidons.rs b/crates/backend/koala-bear/src/benchmark_poseidons_koalabear.rs similarity index 93% rename from crates/backend/koala-bear/src/benchmark_poseidons.rs rename to crates/backend/koala-bear/src/benchmark_poseidons_koalabear.rs index 66c6a5a0..ec34b729 100644 --- a/crates/backend/koala-bear/src/benchmark_poseidons.rs +++ b/crates/backend/koala-bear/src/benchmark_poseidons_koalabear.rs @@ -13,7 +13,7 @@ const PACKING_WIDTH: usize = ::WIDTH; #[test] #[ignore] fn bench_poseidon() { - // cargo test --release --package mt-koala-bear --lib -- benchmark_poseidons::bench_poseidon --exact --nocapture --ignored + // cargo test --release --package mt-koala-bear --lib -- benchmark_poseidons_koalabear::bench_poseidon --exact --nocapture --ignored let n = 1 << 23; let poseidon1_16 = default_koalabear_poseidon1_16(); diff --git a/crates/backend/koala-bear/src/lib.rs b/crates/backend/koala-bear/src/lib.rs index 959ed3ad..d329843c 100644 --- a/crates/backend/koala-bear/src/lib.rs +++ b/crates/backend/koala-bear/src/lib.rs @@ -11,7 +11,7 @@ pub mod quintic_extension; pub mod symmetric; #[cfg(test)] -mod benchmark_poseidons; +mod benchmark_poseidons_koalabear; #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] mod aarch64_neon; From 
ec4acbdf0f5bfa586f06ab57ed715382afc020ab Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 25 Apr 2026 18:07:26 +0200 Subject: [PATCH 17/31] fix --- .../backend/goldilocks/src/cubic_extension.rs | 4 +- crates/backend/goldilocks/src/goldilocks.rs | 16 +- crates/backend/goldilocks/src/poseidon1.rs | 452 ++++++++++++------ crates/backend/poly/src/eq_mle.rs | 12 +- crates/backend/symetric/src/permutation.rs | 4 +- crates/lean_compiler/src/a_simplify_lang.rs | 4 +- .../tests/test_data/program_166.py | 2 +- crates/lean_vm/src/isa/hint.rs | 8 +- crates/lean_vm/src/tables/extension_op/air.rs | 4 +- crates/lean_vm/src/tables/poseidon_8/mod.rs | 48 +- .../lean_vm/src/tables/poseidon_8/sparse.rs | 39 +- crates/rec_aggregation/tests/test_hashing.py | 2 +- .../src/quotient_gkr/sumcheck_utils.rs | 27 ++ crates/utils/src/poseidon.rs | 5 +- crates/whir/src/dft.rs | 2 +- crates/whir/src/merkle.rs | 4 +- crates/whir/tests/run_whir.rs | 2 +- crates/xmss/src/wots.rs | 2 +- 18 files changed, 400 insertions(+), 237 deletions(-) diff --git a/crates/backend/goldilocks/src/cubic_extension.rs b/crates/backend/goldilocks/src/cubic_extension.rs index 68abf20e..15c24b99 100644 --- a/crates/backend/goldilocks/src/cubic_extension.rs +++ b/crates/backend/goldilocks/src/cubic_extension.rs @@ -98,9 +98,7 @@ impl BasedVectorSpace for CubicExtensionFieldGL { } #[inline] - fn from_basis_coefficients_iter>( - mut iter: I, - ) -> Option { + fn from_basis_coefficients_iter>(mut iter: I) -> Option { (iter.len() == 3).then(|| Self::new(array::from_fn(|_| iter.next().unwrap()))) } diff --git a/crates/backend/goldilocks/src/goldilocks.rs b/crates/backend/goldilocks/src/goldilocks.rs index 6a4e6be9..d2672ff6 100644 --- a/crates/backend/goldilocks/src/goldilocks.rs +++ b/crates/backend/goldilocks/src/goldilocks.rs @@ -11,9 +11,9 @@ use core::{array, fmt}; use field::integers::QuotientMap; use field::op_assign_macros::{impl_add_assign, impl_div_methods, impl_mul_methods, impl_sub_assign}; use field::{ - 
Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField, - PrimeField64, RawDataSerializable, TwoAdicField, impl_raw_serializable_primefield64, - quotient_map_large_iint, quotient_map_large_uint, quotient_map_small_int, + Field, InjectiveMonomial, Packable, PermutationMonomial, PrimeCharacteristicRing, PrimeField, PrimeField64, + RawDataSerializable, TwoAdicField, impl_raw_serializable_primefield64, quotient_map_large_iint, + quotient_map_large_uint, quotient_map_small_int, }; use num_bigint::BigUint; use rand::Rng; @@ -60,9 +60,7 @@ impl Goldilocks { /// Convert a `[[u64; N]; M]` array to a 2D array of field elements. #[inline] - pub const fn new_2d_array( - input: [[u64; N]; M], - ) -> [[Self; N]; M] { + pub const fn new_2d_array(input: [[u64; N]; M]) -> [[Self; N]; M] { let mut output = [[Self::ZERO; N]; M]; let mut i = 0; while i < M { @@ -255,11 +253,7 @@ impl PrimeCharacteristicRing for Goldilocks { let long_prod_1 = (lhs[1].value as u128) * (rhs[1].value as u128); let (sum, over) = long_prod_0.overflowing_add(long_prod_1); let sum_corr = sum.wrapping_sub(OFFSET); - if over { - reduce128(sum_corr) - } else { - reduce128(sum) - } + if over { reduce128(sum_corr) } else { reduce128(sum) } } _ => { let (lo_plus_hi, hi) = lhs diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index c0e6c6b7..3100afeb 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -24,8 +24,7 @@ pub const POSEIDON1_PARTIAL_ROUNDS: usize = 22; pub const POSEIDON1_SBOX_DEGREE: u64 = 7; pub const POSEIDON1_DIGEST_LEN: usize = 4; -pub const POSEIDON1_N_ROUNDS: usize = - 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; +pub const POSEIDON1_N_ROUNDS: usize = 2 * POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS; // ========================================================================= // MDS matrix (circulant, width 8) @@ -89,132 +88,311 @@ 
fn mds_mul_scalar(state: &mut [Goldilocks; 8]) { // Generated by the Grain LFSR (Poseidon1, Appendix E) with // `field_type = 1, alpha = 7, n = 64, t = 8, R_F = 8, R_P = 22`. // Values carried over verbatim from `plonky3/goldilocks/src/poseidon1.rs`. -pub const GOLDILOCKS_POSEIDON1_RC_8: [[Goldilocks; POSEIDON1_WIDTH]; POSEIDON1_N_ROUNDS] = - Goldilocks::new_2d_array([ - // ---- Initial full rounds (4) ---- - [ - 0xdd5743e7f2a5a5d9, 0xcb3a864e58ada44b, 0xffa2449ed32f8cdc, 0x42025f65d6bd13ee, - 0x7889175e25506323, 0x34b98bb03d24b737, 0xbdcc535ecc4faa2a, 0x5b20ad869fc0d033, - ], - [ - 0xf1dda5b9259dfcb4, 0x27515210be112d59, 0x4227d1718c766c3f, 0x26d333161a5bd794, - 0x49b938957bf4b026, 0x4a56b5938b213669, 0x1120426b48c8353d, 0x6b323c3f10a56cad, - ], - [ - 0xce57d6245ddca6b2, 0xb1fc8d402bba1eb1, 0xb5c5096ca959bd04, 0x6db55cd306d31f7f, - 0xc49d293a81cb9641, 0x1ce55a4fe979719f, 0xa92e60a9d178a4d1, 0x002cc64973bcfd8c, - ], - [ - 0xcea721cce82fb11b, 0xe5b55eb8098ece81, 0x4e30525c6f1ddd66, 0x43c6702827070987, - 0xaca68430a7b5762a, 0x3674238634df9c93, 0x88cee1c825e33433, 0xde99ae8d74b57176, - ], - // ---- Partial rounds (22) ---- - [ - 0x488897d85ff51f56, 0x1140737ccb162218, 0xa7eeb9215866ed35, 0x9bd2976fee49fcc9, - 0xc0c8f0de580a3fcc, 0x4fb2dae6ee8fc793, 0x343a89f35f37395b, 0x223b525a77ca72c8, - ], - [ - 0x56ccb62574aaa918, 0xc4d507d8027af9ed, 0xa080673cf0b7e95c, 0xf0184884eb70dcf8, - 0x044f10b0cb3d5c69, 0xe9e3f7993938f186, 0x1b761c80e772f459, 0x606cec607a1b5fac, - ], - [ - 0x14a0c2e1d45f03cd, 0x4eace8855398574f, 0xf905ca7103eff3e6, 0xf8c8f8d20862c059, - 0xb524fe8bdd678e5a, 0xfbb7865901a1ec41, 0x014ef1197d341346, 0x9725e20825d07394, - ], - [ - 0xfdb25aef2c5bae3b, 0xbe5402dc598c971e, 0x93a5711f04cdca3d, 0xc45a9a5b2f8fb97b, - 0xfe8946a924933545, 0x2af997a27369091c, 0xaa62c88e0b294011, 0x058eb9d810ce9f74, - ], - [ - 0xb3cb23eced349ae4, 0xa3648177a77b4a84, 0x43153d905992d95d, 0xf4e2a97cda44aa4b, - 0x5baa2702b908682f, 0x082923bdf4f750d1, 0x98ae09a325893803, 0xf8a6475077968838, - 
], - [ - 0xceb0735bf00b2c5f, 0x0a1a5d953888e072, 0x2fcb190489f94475, 0xb5be06270dec69fc, - 0x739cb934b09acf8b, 0x537750b75ec7f25b, 0xe9dd318bae1f3961, 0xf7462137299efe1a, - ], - [ - 0xb1f6b8eee9adb940, 0xbdebcc8a809dfe6b, 0x40fc1f791b178113, 0x3ac1c3362d014864, - 0x9a016184bdb8aeba, 0x95f2394459fbc25e, 0xe3f34a07a76a66c2, 0x8df25f9ad98b1b96, - ], - [ - 0x85ffc27171439d9d, 0xddcb9a2dcfd26910, 0x26b5ba4bf3afb94e, 0xffff9cc7c7651e2f, - 0x8c88364698280b55, 0xebc114167b910501, 0x2d77b4d89ecfb516, 0x332e0828eba151f2, - ], - [ - 0x46fa6a6450dd4735, 0xd00db7dd92384a33, 0x5fd4fb751f3a5fc5, 0x496fb90c0bb65ea2, - 0xf3baec0bb87cc5c7, 0x862a3c0a7d4c7713, 0xbf5f38336a3f47d8, 0x41ad9dbc1394a20c, - ], - [ - 0xcc535945b7dbf0f7, 0x82af2bc93685bcec, 0x8e4c8d0c8cebfccd, 0x17cb39417e84597e, - 0xd4a965a8c749b232, 0xa2cab040f33f3ee5, 0xa98811a1fed4e3a6, 0x1cc48b54f377e2a1, - ], - [ - 0xe40cd4f6c5609a27, 0x11de79ebca97a4a4, 0x9177c73d8b7e929d, 0x2a6fe8085797e792, - 0x3de6e93329f8d5ae, 0x3f7af9125da962ff, 0xd710682cfc77d3ac, 0x48faf05f3b053cf4, - ], - [ - 0x287db8630da89c8b, 0x4d0de32053cb30e9, 0x8b37a4f20c5ada7b, 0xe7cc6ebe78c84ecf, - 0x240bdc0a66a2610d, 0x8299e7f02caa1650, 0x380a53fefb6e754e, 0x684a1d8cf8eb6810, - ], - [ - 0xe839452eb4b8a5e1, 0xb03fa62e90626af4, 0x11a688602fbc5efc, 0x30dda75c355a2d62, - 0x0f712adcb73810de, 0xffdc1102187f1ae1, 0x40c34f398254b99c, 0xede021b9dc289a4a, - ], - [ - 0x8b7b05225c4e7dad, 0x3bc794346f9d9ff9, 0xfccb5a57f2ca86ff, 0xbb1502015a7da9d4, - 0xd7e0a35d4352a015, 0x27af7a44f8160931, 0xc37442f6782f4615, 0xbdf392a9bd095dcb, - ], - [ - 0xc17f55037cf00de9, 0xbcffedd34c71a874, 0x5eb45d2a8133d1f2, 0xbabe251e1612ebdf, - 0x3efeb9fbe438c536, 0x2d7cef97b4afe1cf, 0xe5de1b4660016c0b, 0xcdcc26c332f5657c, - ], - [ - 0xe01dd653daf15809, 0xb0a6bdd4b41094b5, 0x27eac858b0b03a05, 0x51d43b5e93adbdc0, - 0x8b89a23b0fea5fc9, 0xdc8ac3b14f7f2fc1, 0xe793f82f1efec039, 0x9f6f2cf8969e7b80, - ], - [ - 0x49d45382e0f21d4a, 0x5f4ad1797cd72786, 0x4dc3dbebfd45f795, 0x03a3ef84dba6e1bc, - 
0x204bc9b3d3fc4c01, 0x9ad706081e89b9ba, 0x638bfb4d840e9f89, 0x5ef2938cd095ae35, - ], - [ - 0x42cca18ebeb265c8, 0xb7b2ec5c29aecbf8, 0x0d84f9535dc78f0f, 0x04e64ad942e77b8c, - 0xb4880dffffc9da0b, 0x16db16d9c29adeb1, 0x09bbaf2a0590cd1e, 0x76460e74961fcf8d, - ], - [ - 0xed12a2276dfa1553, 0x0b5acec5de0436fd, 0x3c6cfea033a1f0a8, 0x2b5ecefe546cac15, - 0x6e2d82884cd3bf6f, 0xc134878d1add7b83, 0x997963422eb7a280, 0x5e834537ac648cf6, - ], - [ - 0x89e779214737c0b7, 0x1a8c05e8581ad95b, 0x8d18b72796437cf7, 0xe7252c949e04b106, - 0x53267c4fd174585a, 0xa16ef5d9c81dad47, 0xda65191937270a46, 0xcb2a5b55f2df664c, - ], - [ - 0x854aee2dc1924137, 0xf37013c9d479ece6, 0x0e163bc0630c4696, 0x384ee64955048f76, - 0xf65d814e28ee4ec5, 0xe57bc564fd82f1b1, 0x4b338937b6876614, 0x66ee0b04ed43cd8d, - ], - [ - 0x49884bf25f4ef15d, 0xeb51fe28de1c6f54, 0x2cd64e84fce8dfcc, 0x29164a96a541a013, - 0x173ce7558f4cacb8, 0xeb5b1ce5877c89e9, 0x5faff4b0f5217bf6, 0xac42d0b1c20f205e, - ], - // ---- Terminal full rounds (4) ---- - [ - 0xfb1d6bf0ca43221b, 0x97b0a1b01d6a2955, 0x08c60bd622952b30, 0x43f2be0f9e24147c, - 0xfa7268b7d3730f5d, 0x43a6c419a23983bb, 0xcd77c1f7b29b113c, 0xcfa43c9db8eec29f, - ], - [ - 0xcaaa95a6c7365dec, 0x0a91193f798f3be0, 0x1104497652735dc6, 0x35aecb93663b515e, - 0x8dbc9916065aa858, 0xada8f7a0266579ed, 0x524dee7bec1ea789, 0xa93aee9dd5af9521, - ], - [ - 0x9d1f1b54750d707e, 0x7c9feab87096d5dc, 0xa2e1fb19f9d4261b, 0xb714deb448de6346, - 0x225d1f0d011c5403, 0x1549b7f1d28cedc0, 0xaef3e46f97d43942, 0x6dfc7ffe0b38bf08, - ], - [ - 0x7de853fdc542b663, 0xa68ecc96610657b2, 0xe88bb5428af289b1, 0xd7cfa1504c5569f5, - 0x78a9aad0d642d30a, 0xd68315f2353dce52, 0x46e56300f86fcfd5, 0x323d95332b145fd6, - ], - ]); +pub const GOLDILOCKS_POSEIDON1_RC_8: [[Goldilocks; POSEIDON1_WIDTH]; POSEIDON1_N_ROUNDS] = Goldilocks::new_2d_array([ + // ---- Initial full rounds (4) ---- + [ + 0xdd5743e7f2a5a5d9, + 0xcb3a864e58ada44b, + 0xffa2449ed32f8cdc, + 0x42025f65d6bd13ee, + 0x7889175e25506323, + 0x34b98bb03d24b737, + 
0xbdcc535ecc4faa2a, + 0x5b20ad869fc0d033, + ], + [ + 0xf1dda5b9259dfcb4, + 0x27515210be112d59, + 0x4227d1718c766c3f, + 0x26d333161a5bd794, + 0x49b938957bf4b026, + 0x4a56b5938b213669, + 0x1120426b48c8353d, + 0x6b323c3f10a56cad, + ], + [ + 0xce57d6245ddca6b2, + 0xb1fc8d402bba1eb1, + 0xb5c5096ca959bd04, + 0x6db55cd306d31f7f, + 0xc49d293a81cb9641, + 0x1ce55a4fe979719f, + 0xa92e60a9d178a4d1, + 0x002cc64973bcfd8c, + ], + [ + 0xcea721cce82fb11b, + 0xe5b55eb8098ece81, + 0x4e30525c6f1ddd66, + 0x43c6702827070987, + 0xaca68430a7b5762a, + 0x3674238634df9c93, + 0x88cee1c825e33433, + 0xde99ae8d74b57176, + ], + // ---- Partial rounds (22) ---- + [ + 0x488897d85ff51f56, + 0x1140737ccb162218, + 0xa7eeb9215866ed35, + 0x9bd2976fee49fcc9, + 0xc0c8f0de580a3fcc, + 0x4fb2dae6ee8fc793, + 0x343a89f35f37395b, + 0x223b525a77ca72c8, + ], + [ + 0x56ccb62574aaa918, + 0xc4d507d8027af9ed, + 0xa080673cf0b7e95c, + 0xf0184884eb70dcf8, + 0x044f10b0cb3d5c69, + 0xe9e3f7993938f186, + 0x1b761c80e772f459, + 0x606cec607a1b5fac, + ], + [ + 0x14a0c2e1d45f03cd, + 0x4eace8855398574f, + 0xf905ca7103eff3e6, + 0xf8c8f8d20862c059, + 0xb524fe8bdd678e5a, + 0xfbb7865901a1ec41, + 0x014ef1197d341346, + 0x9725e20825d07394, + ], + [ + 0xfdb25aef2c5bae3b, + 0xbe5402dc598c971e, + 0x93a5711f04cdca3d, + 0xc45a9a5b2f8fb97b, + 0xfe8946a924933545, + 0x2af997a27369091c, + 0xaa62c88e0b294011, + 0x058eb9d810ce9f74, + ], + [ + 0xb3cb23eced349ae4, + 0xa3648177a77b4a84, + 0x43153d905992d95d, + 0xf4e2a97cda44aa4b, + 0x5baa2702b908682f, + 0x082923bdf4f750d1, + 0x98ae09a325893803, + 0xf8a6475077968838, + ], + [ + 0xceb0735bf00b2c5f, + 0x0a1a5d953888e072, + 0x2fcb190489f94475, + 0xb5be06270dec69fc, + 0x739cb934b09acf8b, + 0x537750b75ec7f25b, + 0xe9dd318bae1f3961, + 0xf7462137299efe1a, + ], + [ + 0xb1f6b8eee9adb940, + 0xbdebcc8a809dfe6b, + 0x40fc1f791b178113, + 0x3ac1c3362d014864, + 0x9a016184bdb8aeba, + 0x95f2394459fbc25e, + 0xe3f34a07a76a66c2, + 0x8df25f9ad98b1b96, + ], + [ + 0x85ffc27171439d9d, + 0xddcb9a2dcfd26910, + 
0x26b5ba4bf3afb94e, + 0xffff9cc7c7651e2f, + 0x8c88364698280b55, + 0xebc114167b910501, + 0x2d77b4d89ecfb516, + 0x332e0828eba151f2, + ], + [ + 0x46fa6a6450dd4735, + 0xd00db7dd92384a33, + 0x5fd4fb751f3a5fc5, + 0x496fb90c0bb65ea2, + 0xf3baec0bb87cc5c7, + 0x862a3c0a7d4c7713, + 0xbf5f38336a3f47d8, + 0x41ad9dbc1394a20c, + ], + [ + 0xcc535945b7dbf0f7, + 0x82af2bc93685bcec, + 0x8e4c8d0c8cebfccd, + 0x17cb39417e84597e, + 0xd4a965a8c749b232, + 0xa2cab040f33f3ee5, + 0xa98811a1fed4e3a6, + 0x1cc48b54f377e2a1, + ], + [ + 0xe40cd4f6c5609a27, + 0x11de79ebca97a4a4, + 0x9177c73d8b7e929d, + 0x2a6fe8085797e792, + 0x3de6e93329f8d5ae, + 0x3f7af9125da962ff, + 0xd710682cfc77d3ac, + 0x48faf05f3b053cf4, + ], + [ + 0x287db8630da89c8b, + 0x4d0de32053cb30e9, + 0x8b37a4f20c5ada7b, + 0xe7cc6ebe78c84ecf, + 0x240bdc0a66a2610d, + 0x8299e7f02caa1650, + 0x380a53fefb6e754e, + 0x684a1d8cf8eb6810, + ], + [ + 0xe839452eb4b8a5e1, + 0xb03fa62e90626af4, + 0x11a688602fbc5efc, + 0x30dda75c355a2d62, + 0x0f712adcb73810de, + 0xffdc1102187f1ae1, + 0x40c34f398254b99c, + 0xede021b9dc289a4a, + ], + [ + 0x8b7b05225c4e7dad, + 0x3bc794346f9d9ff9, + 0xfccb5a57f2ca86ff, + 0xbb1502015a7da9d4, + 0xd7e0a35d4352a015, + 0x27af7a44f8160931, + 0xc37442f6782f4615, + 0xbdf392a9bd095dcb, + ], + [ + 0xc17f55037cf00de9, + 0xbcffedd34c71a874, + 0x5eb45d2a8133d1f2, + 0xbabe251e1612ebdf, + 0x3efeb9fbe438c536, + 0x2d7cef97b4afe1cf, + 0xe5de1b4660016c0b, + 0xcdcc26c332f5657c, + ], + [ + 0xe01dd653daf15809, + 0xb0a6bdd4b41094b5, + 0x27eac858b0b03a05, + 0x51d43b5e93adbdc0, + 0x8b89a23b0fea5fc9, + 0xdc8ac3b14f7f2fc1, + 0xe793f82f1efec039, + 0x9f6f2cf8969e7b80, + ], + [ + 0x49d45382e0f21d4a, + 0x5f4ad1797cd72786, + 0x4dc3dbebfd45f795, + 0x03a3ef84dba6e1bc, + 0x204bc9b3d3fc4c01, + 0x9ad706081e89b9ba, + 0x638bfb4d840e9f89, + 0x5ef2938cd095ae35, + ], + [ + 0x42cca18ebeb265c8, + 0xb7b2ec5c29aecbf8, + 0x0d84f9535dc78f0f, + 0x04e64ad942e77b8c, + 0xb4880dffffc9da0b, + 0x16db16d9c29adeb1, + 0x09bbaf2a0590cd1e, + 0x76460e74961fcf8d, + ], + [ + 
0xed12a2276dfa1553, + 0x0b5acec5de0436fd, + 0x3c6cfea033a1f0a8, + 0x2b5ecefe546cac15, + 0x6e2d82884cd3bf6f, + 0xc134878d1add7b83, + 0x997963422eb7a280, + 0x5e834537ac648cf6, + ], + [ + 0x89e779214737c0b7, + 0x1a8c05e8581ad95b, + 0x8d18b72796437cf7, + 0xe7252c949e04b106, + 0x53267c4fd174585a, + 0xa16ef5d9c81dad47, + 0xda65191937270a46, + 0xcb2a5b55f2df664c, + ], + [ + 0x854aee2dc1924137, + 0xf37013c9d479ece6, + 0x0e163bc0630c4696, + 0x384ee64955048f76, + 0xf65d814e28ee4ec5, + 0xe57bc564fd82f1b1, + 0x4b338937b6876614, + 0x66ee0b04ed43cd8d, + ], + [ + 0x49884bf25f4ef15d, + 0xeb51fe28de1c6f54, + 0x2cd64e84fce8dfcc, + 0x29164a96a541a013, + 0x173ce7558f4cacb8, + 0xeb5b1ce5877c89e9, + 0x5faff4b0f5217bf6, + 0xac42d0b1c20f205e, + ], + // ---- Terminal full rounds (4) ---- + [ + 0xfb1d6bf0ca43221b, + 0x97b0a1b01d6a2955, + 0x08c60bd622952b30, + 0x43f2be0f9e24147c, + 0xfa7268b7d3730f5d, + 0x43a6c419a23983bb, + 0xcd77c1f7b29b113c, + 0xcfa43c9db8eec29f, + ], + [ + 0xcaaa95a6c7365dec, + 0x0a91193f798f3be0, + 0x1104497652735dc6, + 0x35aecb93663b515e, + 0x8dbc9916065aa858, + 0xada8f7a0266579ed, + 0x524dee7bec1ea789, + 0xa93aee9dd5af9521, + ], + [ + 0x9d1f1b54750d707e, + 0x7c9feab87096d5dc, + 0xa2e1fb19f9d4261b, + 0xb714deb448de6346, + 0x225d1f0d011c5403, + 0x1549b7f1d28cedc0, + 0xaef3e46f97d43942, + 0x6dfc7ffe0b38bf08, + ], + [ + 0x7de853fdc542b663, + 0xa68ecc96610657b2, + 0xe88bb5428af289b1, + 0xd7cfa1504c5569f5, + 0x78a9aad0d642d30a, + 0xd68315f2353dce52, + 0x46e56300f86fcfd5, + 0x323d95332b145fd6, + ], +]); // ========================================================================= // S-box helpers @@ -256,9 +434,7 @@ impl Poseidon1Goldilocks8 { mds_mul_scalar(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS - ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS - { + for r in POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { for i in 0..POSEIDON1_WIDTH { state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; } @@ -293,9 +469,7 @@ impl 
Poseidon1Goldilocks8 { mds_mul_generic(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS - ..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS - { + for r in POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { for i in 0..POSEIDON1_WIDTH { state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; } @@ -363,8 +537,14 @@ mod tests { #[test] fn permutation_is_deterministic() { let input: [Goldilocks; 8] = [ - Goldilocks::new(1), Goldilocks::new(2), Goldilocks::new(3), Goldilocks::new(4), - Goldilocks::new(5), Goldilocks::new(6), Goldilocks::new(7), Goldilocks::new(8), + Goldilocks::new(1), + Goldilocks::new(2), + Goldilocks::new(3), + Goldilocks::new(4), + Goldilocks::new(5), + Goldilocks::new(6), + Goldilocks::new(7), + Goldilocks::new(8), ]; let p = Poseidon1Goldilocks8; let a = p.permute(input); diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 85c9be93..49aad126 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -1210,8 +1210,10 @@ mod tests { compute_eval_eq_packed::<_, true>(&eval, &mut out_2, scalar); println!("EXTENSION PACKED: {:?}", time.elapsed()); - let unpacked_out_2: Vec = - <>::ExtensionPacking as PackedFieldExtension>::to_ext_iter_vec(out_2.clone()); + let unpacked_out_2: Vec = <>::ExtensionPacking as PackedFieldExtension< + F, + EF, + >>::to_ext_iter_vec(out_2.clone()); assert_eq!(out_1, unpacked_out_2); let mut out_3 = EF::zero_vec(1 << n_vars); @@ -1245,8 +1247,10 @@ mod tests { compute_eval_eq_base_packed::(&eval, &mut out_2, scalar); println!("BASE PACKED: {:?}", time.elapsed()); - let unpacked_out_2: Vec = - <>::ExtensionPacking as PackedFieldExtension>::to_ext_iter_vec(out_2.clone()); + let unpacked_out_2: Vec = <>::ExtensionPacking as PackedFieldExtension< + F, + EF, + >>::to_ext_iter_vec(out_2.clone()); assert_eq!(out_1, unpacked_out_2); let mut out_3 = EF::zero_vec(1 << n_vars); diff --git a/crates/backend/symetric/src/permutation.rs 
b/crates/backend/symetric/src/permutation.rs index 8df7eb65..17cecf77 100644 --- a/crates/backend/symetric/src/permutation.rs +++ b/crates/backend/symetric/src/permutation.rs @@ -13,8 +13,8 @@ pub trait Compression: Clone + Sync { fn compress_mut(&self, input: &mut T); } -impl + InjectiveMonomial<7> + Copy + Send + Sync + 'static> - Compression<[R; 8]> for Poseidon1Goldilocks8 +impl + InjectiveMonomial<7> + Copy + Send + Sync + 'static> Compression<[R; 8]> + for Poseidon1Goldilocks8 { fn compress_mut(&self, input: &mut [R; 8]) { self.compress_in_place(input); diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 0508e39f..1cf23028 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2396,7 +2396,9 @@ fn simplify_lines( res.push(SimpleLine::equality(target_var, SimpleExpr::Constant(result))); } else { if !operation.supports_runtime() { - eprintln!("[COMPILE-TIME-OP DEBUG] operation={operation:?}, args={args_simplified:?}, var={var:?}, target_var={target_var:?}, is_mutable={is_mutable}"); + eprintln!( + "[COMPILE-TIME-OP DEBUG] operation={operation:?}, args={args_simplified:?}, var={var:?}, target_var={target_var:?}, is_mutable={is_mutable}" + ); return Err(format!( "Operation `{operation}` is compile-time only; all operands must be constants" )); diff --git a/crates/lean_compiler/tests/test_data/program_166.py b/crates/lean_compiler/tests/test_data/program_166.py index ccd24f05..bff572b3 100644 --- a/crates/lean_compiler/tests/test_data/program_166.py +++ b/crates/lean_compiler/tests/test_data/program_166.py @@ -35,7 +35,7 @@ def main(): @inline -def copy_5(a, b): +def copy_ef(a, b): dot_product_ee(a, ONE_EF_PTR, b) return diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index c9a2eb39..23c5d05c 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -158,9 +158,7 @@ impl CustomHint { for i in 
0..num_fe { let value = ctx.memory.get(to_decompose_ptr + i)?.as_canonical_u64(); for j in 0..chunks_per_fe { - let chunk = F::from_u64( - (value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1), - ); + let chunk = F::from_u64((value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1)); ctx.memory.set(memory_index_decomposed, chunk)?; memory_index_decomposed += 1; } @@ -176,9 +174,7 @@ impl CustomHint { let chunk_size = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); assert!(num_chunks * chunk_size <= F::bits()); for j in 0..num_chunks { - let chunk = F::from_u64( - (value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1), - ); + let chunk = F::from_u64((value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1)); ctx.memory.set(decomposed_ptr + j, chunk)?; } } diff --git a/crates/lean_vm/src/tables/extension_op/air.rs b/crates/lean_vm/src/tables/extension_op/air.rs index 16c11daa..c7213429 100644 --- a/crates/lean_vm/src/tables/extension_op/air.rs +++ b/crates/lean_vm/src/tables/extension_op/air.rs @@ -1,6 +1,6 @@ use crate::{ - DIMENSION, EF, EXT_OP_FLAG_ADD, EXT_OP_FLAG_IS_BE, EXT_OP_FLAG_MUL, EXT_OP_FLAG_POLY_EQ, - ExtraDataForBuses, eval_virtual_bus_column, + DIMENSION, EF, EXT_OP_FLAG_ADD, EXT_OP_FLAG_IS_BE, EXT_OP_FLAG_MUL, EXT_OP_FLAG_POLY_EQ, ExtraDataForBuses, + eval_virtual_bus_column, tables::extension_op::{EXT_OP_LEN_MULTIPLIER, ExtensionOpPrecompile}, }; use backend::*; diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs index a2bf7bae..83e3d58d 100644 --- a/crates/lean_vm/src/tables/poseidon_8/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -1,5 +1,5 @@ -use crate::*; use crate::execution::memory::MemoryAccess; +use crate::*; use backend::*; use utils::{ToUsize, poseidon8_compress}; @@ -7,7 +7,7 @@ mod sparse; mod trace_gen; pub use trace_gen::fill_trace_poseidon_8; -use sparse::{get_partial_constants, PARTIAL_ROUNDS as SPARSE_PARTIAL_ROUNDS}; +use sparse::{PARTIAL_ROUNDS as 
SPARSE_PARTIAL_ROUNDS, get_partial_constants}; pub(super) const WIDTH: usize = 8; pub(super) const DIGEST: usize = DIGEST_LEN; // 4 @@ -53,8 +53,7 @@ const FULL_ROUND_COLS: usize = WIDTH; // 8 post-state const PARTIAL_ROUND_COLS: usize = 1; // post_sbox pub const fn is_full_round(r: usize) -> bool { - r < POSEIDON1_HALF_FULL_ROUNDS - || r >= POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + r < POSEIDON1_HALF_FULL_ROUNDS || r >= POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS } /// First column index of round `r`'s data. @@ -198,8 +197,7 @@ impl TableT for Poseidon8Precompile { }, LookupIntoMemory { index: POSEIDON_8_COL_INDEX_INPUT_RIGHT, - values: (POSEIDON_8_COL_INPUT_START + DIGEST..POSEIDON_8_COL_INPUT_START + DIGEST * 2) - .collect(), + values: (POSEIDON_8_COL_INPUT_START + DIGEST..POSEIDON_8_COL_INPUT_START + DIGEST * 2).collect(), }, LookupIntoMemory { index: POSEIDON_8_COL_INDEX_INPUT_RES, @@ -320,8 +318,7 @@ impl Air for Poseidon8Precompile { let inputs: [AB::IF; WIDTH]; let outputs: [AB::IF; DIGEST]; // Per full round: `post[0..W]`. Per partial round: `post_sbox`. - let mut full_posts: Vec<[AB::IF; WIDTH]> = - Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); + let mut full_posts: Vec<[AB::IF; WIDTH]> = Vec::with_capacity(2 * POSEIDON1_HALF_FULL_ROUNDS); let mut partial_post_sboxes: Vec = Vec::with_capacity(SPARSE_PARTIAL_ROUNDS); { let up = builder.up(); @@ -373,8 +370,7 @@ impl Air for Poseidon8Precompile { // ---- Initial full rounds ---- for round in 0..POSEIDON1_HALF_FULL_ROUNDS { let sbox_out: [AB::IF; WIDTH] = std::array::from_fn(|i| { - let x = state[i] - + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()); + let x = state[i] + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[round][i].as_canonical_u64()); // x⁷ = x · (x²)² · x² — 4 Mul nodes in the symbolic DAG. 
let x2 = x * x; let x4 = x2 * x2; @@ -382,11 +378,9 @@ impl Air for Poseidon8Precompile { }); let post = full_posts[round]; for i in 0..WIDTH { - let mut acc = sbox_out[0] - * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); + let mut acc = sbox_out[0] * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); for j in 1..WIDTH { - let coeff = - AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); acc = acc + sbox_out[j] * coeff; } builder.assert_zero(post[i] - acc); @@ -396,8 +390,7 @@ impl Air for Poseidon8Precompile { // ---- Partial phase: first_round_constants, m_i, sparse-matmul loop ---- for i in 0..WIDTH { - state[i] = state[i] - + AB::F::from_u64(c.first_round_constants[i].as_canonical_u64()); + state[i] = state[i] + AB::F::from_u64(c.first_round_constants[i].as_canonical_u64()); } { let mut after: [AB::IF; WIDTH] = std::array::from_fn(|i| { @@ -421,26 +414,20 @@ impl Air for Poseidon8Precompile { // state[0] becomes post_sbox (+ scalar RC, except last round). state[0] = if r < SPARSE_PARTIAL_ROUNDS - 1 { - post_sbox - + AB::F::from_u64(c.round_constants[r].as_canonical_u64()) + post_sbox + AB::F::from_u64(c.round_constants[r].as_canonical_u64()) } else { post_sbox }; // cheap_matmul. 
let old_s0 = state[0]; - let mut new_s0 = state[0] - * AB::F::from_u64(c.sparse_first_row[r][0].as_canonical_u64()); + let mut new_s0 = state[0] * AB::F::from_u64(c.sparse_first_row[r][0].as_canonical_u64()); for j in 1..WIDTH { - new_s0 = new_s0 - + state[j] - * AB::F::from_u64(c.sparse_first_row[r][j].as_canonical_u64()); + new_s0 = new_s0 + state[j] * AB::F::from_u64(c.sparse_first_row[r][j].as_canonical_u64()); } state[0] = new_s0; for i in 1..WIDTH { - state[i] = state[i] - + old_s0 - * AB::F::from_u64(c.v[r][i - 1].as_canonical_u64()); + state[i] = state[i] + old_s0 * AB::F::from_u64(c.v[r][i - 1].as_canonical_u64()); } } @@ -448,19 +435,16 @@ impl Air for Poseidon8Precompile { for round in 0..POSEIDON1_HALF_FULL_ROUNDS { let abs = POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS + round; let sbox_out: [AB::IF; WIDTH] = std::array::from_fn(|i| { - let x = state[i] - + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[abs][i].as_canonical_u64()); + let x = state[i] + AB::F::from_u64(GOLDILOCKS_POSEIDON1_RC_8[abs][i].as_canonical_u64()); let x2 = x * x; let x4 = x2 * x2; x4 * x2 * x }); let post = full_posts[POSEIDON1_HALF_FULL_ROUNDS + round]; for i in 0..WIDTH { - let mut acc = sbox_out[0] - * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); + let mut acc = sbox_out[0] * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); for j in 1..WIDTH { - let coeff = - AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); acc = acc + sbox_out[j] * coeff; } builder.assert_zero(post[i] - acc); diff --git a/crates/lean_vm/src/tables/poseidon_8/sparse.rs b/crates/lean_vm/src/tables/poseidon_8/sparse.rs index e633ec54..4d486a6a 100644 --- a/crates/lean_vm/src/tables/poseidon_8/sparse.rs +++ b/crates/lean_vm/src/tables/poseidon_8/sparse.rs @@ -15,8 +15,8 @@ use std::sync::OnceLock; use backend::{ - Field, GOLDILOCKS_POSEIDON1_RC_8, MDS8_ROW, PrimeCharacteristicRing, - 
POSEIDON1_HALF_FULL_ROUNDS, POSEIDON1_PARTIAL_ROUNDS, + Field, GOLDILOCKS_POSEIDON1_RC_8, MDS8_ROW, POSEIDON1_HALF_FULL_ROUNDS, POSEIDON1_PARTIAL_ROUNDS, + PrimeCharacteristicRing, }; use crate::F; @@ -72,10 +72,7 @@ fn matrix_transpose(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { r } -fn matrix_mul( - a: &[[F; WIDTH]; WIDTH], - b: &[[F; WIDTH]; WIDTH], -) -> [[F; WIDTH]; WIDTH] { +fn matrix_mul(a: &[[F; WIDTH]; WIDTH], b: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { let mut c = [[F::ZERO; WIDTH]; WIDTH]; for i in 0..WIDTH { for j in 0..WIDTH { @@ -192,11 +189,7 @@ fn submatrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH - 1]; WIDTH - 1] { fn compute_equivalent_matrices( mds: &[[F; WIDTH]; WIDTH], rounds_p: usize, -) -> ( - [[F; WIDTH]; WIDTH], - Vec<[F; WIDTH]>, - Vec<[F; WIDTH]>, -) { +) -> ([[F; WIDTH]; WIDTH], Vec<[F; WIDTH]>, Vec<[F; WIDTH]>) { let mut v_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); let mut w_hat_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); @@ -207,13 +200,7 @@ fn compute_equivalent_matrices( for _ in 0..rounds_p { // v = first row of m_mul (excl [0,0]). In the transposed domain this is // the first column of M'' in the non-transposed view. - let v_arr: [F; WIDTH] = std::array::from_fn(|j| { - if j < WIDTH - 1 { - m_mul[0][j + 1] - } else { - F::ZERO - } - }); + let v_arr: [F; WIDTH] = std::array::from_fn(|j| if j < WIDTH - 1 { m_mul[0][j + 1] } else { F::ZERO }); // w = first column of m_mul (excl [0,0]). let mut w = [F::ZERO; WIDTH - 1]; @@ -263,10 +250,7 @@ fn compute_equivalent_matrices( /// Backward-substitute partial round constants through M^{-1}, producing the /// full first-round vector and per-round scalar offsets. 
-fn equivalent_round_constants( - partial_rc: &[[F; WIDTH]], - mds_inv: &[[F; WIDTH]; WIDTH], -) -> ([F; WIDTH], Vec) { +fn equivalent_round_constants(partial_rc: &[[F; WIDTH]], mds_inv: &[[F; WIDTH]; WIDTH]) -> ([F; WIDTH], Vec) { let rounds_p = partial_rc.len(); let mut opt_partial_rc = vec![F::ZERO; rounds_p]; @@ -293,10 +277,8 @@ fn compute_partial_constants() -> PartialConstants { .map(|r| GOLDILOCKS_POSEIDON1_RC_8[POSEIDON1_HALF_FULL_ROUNDS + r]) .collect(); - let (first_round_constants, round_constants_vec) = - equivalent_round_constants(&partial_rc, &mds_inv); - let (m_i, v_collection, w_hat_collection) = - compute_equivalent_matrices(&mds, PARTIAL_ROUNDS); + let (first_round_constants, round_constants_vec) = equivalent_round_constants(&partial_rc, &mds_inv); + let (m_i, v_collection, w_hat_collection) = compute_equivalent_matrices(&mds, PARTIAL_ROUNDS); // sparse_first_row[r] = [mds[0][0], w_hat[r][0], …, w_hat[r][W-2]]. let mds_0_0 = mds[0][0]; @@ -400,9 +382,8 @@ mod tests { let mut seed = 0u64; for trial in 0..4 { seed = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); - let input: [F; WIDTH] = std::array::from_fn(|i| { - F::from_u64(seed.wrapping_mul(i as u64 + 1 + trial as u64)) - }); + let input: [F; WIDTH] = + std::array::from_fn(|i| F::from_u64(seed.wrapping_mul(i as u64 + 1 + trial as u64))); let a = textbook_partial_phase(input); let b = sparse_partial_phase(input); for i in 0..WIDTH { diff --git a/crates/rec_aggregation/tests/test_hashing.py b/crates/rec_aggregation/tests/test_hashing.py index f27aa0e3..b485ac70 100644 --- a/crates/rec_aggregation/tests/test_hashing.py +++ b/crates/rec_aggregation/tests/test_hashing.py @@ -13,5 +13,5 @@ def main(): data = Array(len) hint_witness("input", data) hash = slice_hash_with_iv_dynamic_unroll(data, len, 15) - copy_8(hash, expected_hash) + copy_digest(hash, expected_hash) return diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 
5a153465..18fad9f3 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -197,6 +197,33 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( initial_pending_r: Option, ) -> (Vec, [EF; 4]) { let w = packing_log_width::(); + // When `w == 0` (no SIMD packing, e.g. Goldilocks), a `PackedBr(0)` layer + // can reach this function. In that case the data is already in natural, + // unpacked form, the inner loop has no rounds to run, and the + // `eq_outer` / `padding_sum` computed below would also panic on the slice + // — so skip phase 1 entirely and go straight to phase 2. + if layer_chunk_log == 0 { + debug_assert_eq!(w, 0); + debug_assert!(initial_pending_r.is_none()); + debug_assert!(precomputed_eq_outer.is_none()); + let nums_nat = unpack_extension::(nums.as_ref()); + let dens_nat = unpack_extension::(dens.as_ref()); + let (num_l, num_r) = even_odd_split(&nums_nat); + let (den_l, den_r) = even_odd_split(&dens_nat); + return run_phase2_sumcheck( + prover_state, + num_l, + num_r, + den_l, + den_r, + remaining_eq, + q_natural, + alpha, + sum, + mmf, + ); + } + let head_len = (remaining_eq.len() + 1).saturating_sub(layer_chunk_log); let outer_point: Vec = remaining_eq[..head_len].to_vec(); let eq_outer: Vec = precomputed_eq_outer.unwrap_or_else(|| eval_eq(&outer_point)); diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index 1de26ac3..e56172cb 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -26,10 +26,7 @@ pub fn poseidon8_compress(input: [Goldilocks; 8]) -> [Goldilocks; 4] { state[0..4].try_into().unwrap() } -pub fn poseidon8_compress_pair( - left: &[Goldilocks; 4], - right: &[Goldilocks; 4], -) -> [Goldilocks; 4] { +pub fn poseidon8_compress_pair(left: &[Goldilocks; 4], right: &[Goldilocks; 4]) -> [Goldilocks; 4] { let mut input = [Goldilocks::default(); 8]; input[..4].copy_from_slice(left); 
input[4..].copy_from_slice(right); diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 329344c6..27924097 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -574,7 +574,7 @@ impl Butterfly for EvalsButterfly { #[cfg(test)] mod tests { use field::{PrimeCharacteristicRing, TwoAdicField}; - use goldilocks::{Goldilocks, CubicExtensionFieldGL}; + use goldilocks::{CubicExtensionFieldGL, Goldilocks}; use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index e7a4e63c..fb4d02f0 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,7 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; -use goldilocks::{Goldilocks, CubicExtensionFieldGL, default_goldilocks_poseidon1_8}; +use goldilocks::{CubicExtensionFieldGL, Goldilocks, default_goldilocks_poseidon1_8}; use poly::*; use rayon::prelude::*; @@ -62,7 +62,7 @@ fn build_merkle_tree_goldilocks( effective_base_width: usize, ) -> RoundMerkleTree { let perm = default_goldilocks_poseidon1_8(); - let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; + let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 4; let first_layer = if n_zero_suffix_rate_chunks >= 2 { let scalar_state = symetric::precompute_zero_suffix_state::( &perm, diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 19a3985f..5c7ba50b 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -4,7 +4,7 @@ use std::time::Instant; use fiat_shamir::{ProverState, VerifierState}; use field::{Field, TwoAdicField}; -use goldilocks::{Goldilocks, CubicExtensionFieldGL, default_goldilocks_poseidon1_8}; +use goldilocks::{CubicExtensionFieldGL, Goldilocks, default_goldilocks_poseidon1_8}; use mt_whir::*; use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; diff --git a/crates/xmss/src/wots.rs 
b/crates/xmss/src/wots.rs index 22a55069..ec05c118 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,7 +1,7 @@ use backend::*; use rand::{CryptoRng, RngExt}; use serde::{Deserialize, Serialize}; -use utils::{poseidon8_compress_pair, poseidon_compress_slice}; +use utils::{poseidon_compress_slice, poseidon8_compress_pair}; use crate::*; From 1663a9e6d3784bba1f1c6c9ceca096399f37fb0a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sun, 26 Apr 2026 12:00:54 +0200 Subject: [PATCH 18/31] low level optis --- .../goldilocks/src/aarch64_neon/mod.rs | 5 + .../goldilocks/src/aarch64_neon/packing.rs | 312 ++++++++++++++ .../backend/goldilocks/src/cubic_extension.rs | 55 +-- crates/backend/goldilocks/src/goldilocks.rs | 16 +- crates/backend/goldilocks/src/lib.rs | 17 + .../goldilocks/src/packed_cubic_extension.rs | 377 +++++++++++++++++ .../backend/goldilocks/src/x86_64_avx2/mod.rs | 5 + .../goldilocks/src/x86_64_avx2/packing.rs | 385 ++++++++++++++++++ .../goldilocks/src/x86_64_avx512/mod.rs | 5 + .../goldilocks/src/x86_64_avx512/packing.rs | 331 +++++++++++++++ 10 files changed, 1480 insertions(+), 28 deletions(-) create mode 100644 crates/backend/goldilocks/src/aarch64_neon/mod.rs create mode 100644 crates/backend/goldilocks/src/aarch64_neon/packing.rs create mode 100644 crates/backend/goldilocks/src/packed_cubic_extension.rs create mode 100644 crates/backend/goldilocks/src/x86_64_avx2/mod.rs create mode 100644 crates/backend/goldilocks/src/x86_64_avx2/packing.rs create mode 100644 crates/backend/goldilocks/src/x86_64_avx512/mod.rs create mode 100644 crates/backend/goldilocks/src/x86_64_avx512/packing.rs diff --git a/crates/backend/goldilocks/src/aarch64_neon/mod.rs b/crates/backend/goldilocks/src/aarch64_neon/mod.rs new file mode 100644 index 00000000..730a8675 --- /dev/null +++ b/crates/backend/goldilocks/src/aarch64_neon/mod.rs @@ -0,0 +1,5 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
+ +mod packing; + +pub use packing::*; diff --git a/crates/backend/goldilocks/src/aarch64_neon/packing.rs b/crates/backend/goldilocks/src/aarch64_neon/packing.rs new file mode 100644 index 00000000..32044303 --- /dev/null +++ b/crates/backend/goldilocks/src/aarch64_neon/packing.rs @@ -0,0 +1,312 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +use alloc::vec::Vec; +use core::arch::aarch64::{ + uint64x2_t, vaddq_u64, vandq_u64, vdupq_n_u64, vgetq_lane_u64, vsetq_lane_u64, vshrq_n_u64, vsubq_u64, +}; +use core::fmt::Debug; +use core::iter::{Product, Sum}; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use field::op_assign_macros::{ + impl_add_assign, impl_add_base_field, impl_div_methods, impl_mul_base_field, impl_mul_methods, impl_packed_value, + impl_rng, impl_sub_assign, impl_sub_base_field, impl_sum_prod_base_field, ring_sum, +}; +use field::{ + Algebra, Field, InjectiveMonomial, PackedField, PackedFieldPow2, PackedValue, PermutationMonomial, + PrimeCharacteristicRing, PrimeField64, +}; +use rand::Rng; +use rand::distr::{Distribution, StandardUniform}; +use utils::reconstitute_from_base; + +use crate::helpers::exp_10540996611094048183; +use crate::{Goldilocks, P}; + +const WIDTH: usize = 2; + +/// Equal to `2^32 - 1 = 2^64 mod P`. +const EPSILON: u64 = Goldilocks::ORDER_U64.wrapping_neg(); + +/// Goldilocks scalar addition: `(a + b) mod P`. Handles u64 overflow via EPSILON. +#[inline(always)] +pub(super) const fn gadd(a: u64, b: u64) -> u64 { + let (sum, overflow) = a.overflowing_add(b); + let (res, _) = sum.overflowing_add(if overflow { EPSILON } else { 0 }); + res +} + +/// Goldilocks scalar subtraction: `(a - b) mod P`. Handles u64 underflow via EPSILON. 
+#[inline(always)] +pub(super) const fn gsub(a: u64, b: u64) -> u64 { + let (diff, borrow) = a.overflowing_sub(b); + let (res, _) = diff.overflowing_sub(if borrow { EPSILON } else { 0 }); + res +} + +/// Single Goldilocks 64x64 -> 64 mul + reduction in pure Rust. +/// +/// LLVM emits `mul + umulh + 10-op reduction`, just like the inline-asm version, but as plain +/// instructions the compiler can interleave across independent calls instead of treating each as +/// an opaque asm block. +#[inline(always)] +pub(super) const fn mul_reduce(a: u64, b: u64) -> u64 { + let prod = (a as u128) * (b as u128); + let lo = prod as u64; + let hi = (prod >> 64) as u64; + + let hi_hi = hi >> 32; + let hi_lo = hi & 0xFFFF_FFFF; + + // tmp = lo - hi_hi; on borrow subtract EPSILON to fold P back in. + let (tmp_pre, borrow) = lo.overflowing_sub(hi_hi); + let tmp = tmp_pre.wrapping_sub(if borrow { EPSILON } else { 0 }); + + // hi_lo * (2^32 - 1) without an actual multiply. + let hi_lo_eps = (hi_lo << 32).wrapping_sub(hi_lo); + + // result = tmp + hi_lo_eps; on overflow add EPSILON. + let (res_pre, overflow) = tmp.overflowing_add(hi_lo_eps); + res_pre.wrapping_add(if overflow { EPSILON } else { 0 }) +} + +/// Hand-scheduled inline-asm variant of [`mul_reduce`], tuned for the **scalar / single-lane Mul** +/// path on aarch64. Saves one ALU op vs the LLVM-emitted form by collapsing `lsr+subs` into the +/// shifted-register `subs xT, lo, hi, lsr #32` form. +#[inline(always)] +pub(super) fn mul_reduce_asm(a: u64, b: u64) -> u64 { + let result: u64; + // SAFETY: integer ALU only; `pure, nomem, nostack` lets LLVM schedule, CSE, DCE. 
+ unsafe { + core::arch::asm!( + "mul {lo}, {a}, {b}", + "umulh {hi}, {a}, {b}", + "subs {tmp}, {lo}, {hi}, lsr #32", + "csel {corr1}, {p}, xzr, lo", + "add {tmp}, {corr1}, {tmp}", + "lsl {hi_lo_eps}, {hi}, #32", + "sub {hi_lo_eps}, {hi_lo_eps}, {hi:w}, uxtw", + "adds {res}, {tmp}, {hi_lo_eps}", + "csel {corr2}, {eps}, xzr, hs", + "add {result}, {corr2}, {res}", + a = in(reg) a, + b = in(reg) b, + lo = out(reg) _, + hi = out(reg) _, + tmp = out(reg) _, + corr1 = out(reg) _, + hi_lo_eps = out(reg) _, + res = out(reg) _, + corr2 = out(reg) _, + result = lateout(reg) result, + p = in(reg) Goldilocks::ORDER_U64, + eps = in(reg) EPSILON, + options(pure, nomem, nostack), + ); + } + result +} + +/// Vectorized NEON implementation of `Goldilocks` arithmetic. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] +#[must_use] +pub struct PackedGoldilocksNeon(pub [Goldilocks; WIDTH]); + +impl PackedGoldilocksNeon { + #[inline] + #[must_use] + pub(crate) fn to_vector(self) -> uint64x2_t { + unsafe { transmute(self) } + } + + #[inline] + pub(crate) fn from_vector(vector: uint64x2_t) -> Self { + unsafe { transmute(vector) } + } + + #[inline] + const fn broadcast(value: Goldilocks) -> Self { + Self([value; WIDTH]) + } +} + +impl From for PackedGoldilocksNeon { + fn from(x: Goldilocks) -> Self { + Self::broadcast(x) + } +} + +// Add/Sub/Neg are emulated as two independent scalar Goldilocks ops. On Apple Silicon's wide +// scalar pipeline, two pipelined scalar adds beat the NEON modular-reduction sequence (XOR-shift +// + signed compare + conditional add) per element. Storage stays as `[Goldilocks; 2]` (16 bytes) +// so the compiler keeps elements in either GPRs or NEON regs as needed; only `mul`/`square` use +// the dual-lane interleaved ASM. 
+impl Add for PackedGoldilocksNeon { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} + +impl Sub for PackedGoldilocksNeon { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]]) + } +} + +impl Neg for PackedGoldilocksNeon { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self([-self.0[0], -self.0[1]]) + } +} + +impl Mul for PackedGoldilocksNeon { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + // Hand-scheduled `mul_reduce_asm` saves one ALU op per lane vs LLVM's pure-Rust form. + Self([ + Goldilocks::new(mul_reduce_asm(self.0[0].value, rhs.0[0].value)), + Goldilocks::new(mul_reduce_asm(self.0[1].value, rhs.0[1].value)), + ]) + } +} + +impl_add_assign!(PackedGoldilocksNeon); +impl_sub_assign!(PackedGoldilocksNeon); +impl_mul_methods!(PackedGoldilocksNeon); +ring_sum!(PackedGoldilocksNeon); +impl_rng!(PackedGoldilocksNeon); + +impl PrimeCharacteristicRing for PackedGoldilocksNeon { + type PrimeSubfield = Goldilocks; + + const ZERO: Self = Self::broadcast(Goldilocks::ZERO); + const ONE: Self = Self::broadcast(Goldilocks::ONE); + const TWO: Self = Self::broadcast(Goldilocks::TWO); + const NEG_ONE: Self = Self::broadcast(Goldilocks::NEG_ONE); + + #[inline] + fn from_prime_subfield(f: Self::PrimeSubfield) -> Self { + f.into() + } + + #[inline] + fn halve(&self) -> Self { + Self::from_vector(halve(self.to_vector())) + } + + #[inline] + fn dot_product(lhs: &[Self; N], rhs: &[Self; N]) -> Self { + Self::from_fn(|lane| { + let lhs_lane: [Goldilocks; N] = core::array::from_fn(|i| lhs[i].as_slice()[lane]); + let rhs_lane: [Goldilocks; N] = core::array::from_fn(|i| rhs[i].as_slice()[lane]); + Goldilocks::dot_product(&lhs_lane, &rhs_lane) + }) + } + + #[inline] + fn square(&self) -> Self { + // Same rationale as `Mul`: scalar reduction avoids NEON<->GPR moves. 
+ let x0 = self.0[0].value; + let x1 = self.0[1].value; + Self([ + Goldilocks::new(mul_reduce_asm(x0, x0)), + Goldilocks::new(mul_reduce_asm(x1, x1)), + ]) + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + unsafe { reconstitute_from_base(Goldilocks::zero_vec(len * WIDTH)) } + } +} + +impl InjectiveMonomial<7> for PackedGoldilocksNeon {} + +impl PermutationMonomial<7> for PackedGoldilocksNeon { + fn injective_exp_root_n(&self) -> Self { + exp_10540996611094048183(*self) + } +} + +impl_add_base_field!(PackedGoldilocksNeon, Goldilocks); +impl_sub_base_field!(PackedGoldilocksNeon, Goldilocks); +impl_mul_base_field!(PackedGoldilocksNeon, Goldilocks); +impl_div_methods!(PackedGoldilocksNeon, Goldilocks); +impl_sum_prod_base_field!(PackedGoldilocksNeon, Goldilocks); + +impl Algebra for PackedGoldilocksNeon {} + +impl_packed_value!(PackedGoldilocksNeon, Goldilocks, WIDTH); + +unsafe impl PackedField for PackedGoldilocksNeon { + type Scalar = Goldilocks; +} + +/// Interleave two 64-bit vectors at the element level. +/// For block_len=1: `[a0, a1]` x `[b0, b1]` -> `[a0, b0]`, `[a1, b1]`. +#[inline] +pub fn interleave_u64(v0: uint64x2_t, v1: uint64x2_t) -> (uint64x2_t, uint64x2_t) { + unsafe { + let a0 = vgetq_lane_u64::<0>(v0); + let a1 = vgetq_lane_u64::<1>(v0); + let b0 = vgetq_lane_u64::<0>(v1); + let b1 = vgetq_lane_u64::<1>(v1); + + let r0 = vsetq_lane_u64::<1>(b0, vsetq_lane_u64::<0>(a0, vdupq_n_u64(0))); + let r1 = vsetq_lane_u64::<1>(b1, vsetq_lane_u64::<0>(a1, vdupq_n_u64(0))); + + (r0, r1) + } +} + +unsafe impl PackedFieldPow2 for PackedGoldilocksNeon { + fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) { + let (v0, v1) = (self.to_vector(), other.to_vector()); + let (res0, res1) = match block_len { + 1 => interleave_u64(v0, v1), + 2 => (v0, v1), + _ => panic!("unsupported block length"), + }; + (Self::from_vector(res0), Self::from_vector(res1)) + } +} + +/// Halve a vector of Goldilocks field elements. 
+#[inline(always)] +pub(crate) fn halve(input: uint64x2_t) -> uint64x2_t { + unsafe { + let one = vdupq_n_u64(1); + let zero = vdupq_n_u64(0); + let half = vdupq_n_u64(P.div_ceil(2)); + + let least_bit = vandq_u64(input, one); + let t = vshrq_n_u64::<1>(input); + // neg_least_bit is 0 or -1 (all bits 1). + let neg_least_bit = vsubq_u64(zero, least_bit); + let maybe_half = vandq_u64(half, neg_least_bit); + vaddq_u64(t, maybe_half) + } +} + +#[cfg(test)] +mod tests { + use super::{Goldilocks, PackedGoldilocksNeon, WIDTH}; + + const SPECIAL_VALS: [Goldilocks; WIDTH] = Goldilocks::new_array([0xFFFF_FFFF_0000_0000, 0xFFFF_FFFF_FFFF_FFFF]); + + #[test] + fn pack_round_trip() { + let p = PackedGoldilocksNeon(SPECIAL_VALS); + let v = p.to_vector(); + assert_eq!(PackedGoldilocksNeon::from_vector(v).0, SPECIAL_VALS); + } +} diff --git a/crates/backend/goldilocks/src/cubic_extension.rs b/crates/backend/goldilocks/src/cubic_extension.rs index 15c24b99..0a4513aa 100644 --- a/crates/backend/goldilocks/src/cubic_extension.rs +++ b/crates/backend/goldilocks/src/cubic_extension.rs @@ -14,8 +14,8 @@ use core::iter::{Product, Sum}; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use field::{ - Algebra, BasedVectorSpace, ExtensionField, Field, Packable, PackedFieldExtension, PackedValue, - PrimeCharacteristicRing, RawDataSerializable, TwoAdicField, field_to_array, + Algebra, BasedVectorSpace, ExtensionField, Field, Packable, PrimeCharacteristicRing, RawDataSerializable, + TwoAdicField, field_to_array, }; use itertools::Itertools; use num_bigint::BigUint; @@ -116,7 +116,7 @@ impl BasedVectorSpace for CubicExtensionFieldGL { } impl ExtensionField for CubicExtensionFieldGL { - type ExtensionPacking = Self; + type ExtensionPacking = crate::packed_cubic_extension::PackedCubicExtensionFieldGL<::Packing>; #[inline] fn is_in_basefield(&self) -> bool { @@ -465,36 +465,22 @@ impl TwoAdicField for CubicExtensionFieldGL { } } -// PackedFieldExtension: since 
Goldilocks has trivial packing (Packing = Self), the cubic -// extension is also its own packing. -impl PackedFieldExtension for CubicExtensionFieldGL { - #[inline] - fn from_ext_slice(ext_slice: &[CubicExtensionFieldGL]) -> Self { - // Goldilocks::Packing::WIDTH == 1, so the input is a single element. - *CubicExtensionFieldGL::from_slice(ext_slice) - } - - #[inline] - fn packed_ext_powers(base: CubicExtensionFieldGL) -> field::Powers { - // `Powers` is just an iterator over `base^k` starting at `k = 1`. - use field::Powers; - Powers { - base, - current: Self::ONE, - } - } -} +// `PackedFieldExtension` is implemented by +// `PackedCubicExtensionFieldGL<::Packing>` (see `packed_cubic_extension.rs`). // ============================================================================ // Arithmetic kernels for `F_p[X] / (X^3 - X - 1)`. // ============================================================================ -/// Multiply two cubic extension elements. +/// Multiply two cubic extension elements over any algebra `R` over `Goldilocks`. /// -/// Given `a = a_0 + a_1 X + a_2 X^2` and `b = b_0 + b_1 X + b_2 X^2`, compute the +/// Given `a = a_0 + a_1 X + a_2 X^2` and `b = b_0 + b_1 X + b_2 X^2`, computes the /// product reduced by `X^3 - X - 1` (so `X^3 = X + 1`, `X^4 = X^2 + X`). #[inline] -pub fn cubic_mul(a: &[Goldilocks; 3], b: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { +pub fn cubic_mul_generic(a: &[R; 3], b: &[R; 3], res: &mut [R; 3]) +where + R: Copy + core::ops::Mul + core::ops::Add, +{ let a0 = a[0]; let a1 = a[1]; let a2 = a[2]; @@ -514,9 +500,12 @@ pub fn cubic_mul(a: &[Goldilocks; 3], b: &[Goldilocks; 3], res: &mut [Goldilocks res[2] = a0 * b2 + a1 * b1 + a2 * b0 + a2b2; } -/// Square a cubic extension element (same reduction rule as `cubic_mul`). +/// Square a cubic extension element (same reduction rule as `cubic_mul_generic`). 
#[inline] -pub fn cubic_square(a: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { +pub fn cubic_square_generic(a: &[R; 3], res: &mut [R; 3]) +where + R: PrimeCharacteristicRing + Copy, +{ let a0 = a[0]; let a1 = a[1]; let a2 = a[2]; @@ -536,6 +525,18 @@ pub fn cubic_square(a: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { res[2] = two_a0_a2 + a1_sq + a2_sq; } +/// Multiply two cubic extension elements (Goldilocks scalars). +#[inline] +pub fn cubic_mul(a: &[Goldilocks; 3], b: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { + cubic_mul_generic(a, b, res); +} + +/// Square a cubic extension element (Goldilocks scalar). +#[inline] +pub fn cubic_square(a: &[Goldilocks; 3], res: &mut [Goldilocks; 3]) { + cubic_square_generic(a, res); +} + /// Invert a cubic extension element via adjugate/determinant — no Frobenius round trip needed. /// /// The multiplication-by-`a` matrix (in the basis `{1, X, X^2}`, using `X^3 = X + 1`) is diff --git a/crates/backend/goldilocks/src/goldilocks.rs b/crates/backend/goldilocks/src/goldilocks.rs index d2672ff6..d0d6667d 100644 --- a/crates/backend/goldilocks/src/goldilocks.rs +++ b/crates/backend/goldilocks/src/goldilocks.rs @@ -24,7 +24,7 @@ use utils::{assume, branch_hint, flatten_to_base}; use crate::helpers::{exp_10540996611094048183, gcd_inner}; /// The Goldilocks prime. -pub(crate) const P: u64 = 0xFFFF_FFFF_0000_0001; +pub const P: u64 = 0xFFFF_FFFF_0000_0001; /// The prime field known as Goldilocks, defined as `F_p` where `p = 2^64 - 2^32 + 1`. 
/// @@ -292,6 +292,20 @@ impl RawDataSerializable for Goldilocks { } impl Field for Goldilocks { + #[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] + type Packing = crate::PackedGoldilocksAVX2; + + #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] + type Packing = crate::PackedGoldilocksAVX512; + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + type Packing = crate::PackedGoldilocksNeon; + + #[cfg(not(any( + all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")), + all(target_arch = "x86_64", target_feature = "avx512f"), + all(target_arch = "aarch64", target_feature = "neon"), + )))] type Packing = Self; const GENERATOR: Self = Self::new(7); diff --git a/crates/backend/goldilocks/src/lib.rs b/crates/backend/goldilocks/src/lib.rs index 0d6cb6a1..f8e3e578 100644 --- a/crates/backend/goldilocks/src/lib.rs +++ b/crates/backend/goldilocks/src/lib.rs @@ -9,6 +9,7 @@ extern crate alloc; mod cubic_extension; mod goldilocks; mod helpers; +mod packed_cubic_extension; mod poseidon1; #[cfg(test)] @@ -17,4 +18,20 @@ mod benchmark_poseidons_goldilocks; pub use cubic_extension::*; pub use goldilocks::*; pub use helpers::*; +pub use packed_cubic_extension::*; pub use poseidon1::*; + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +mod aarch64_neon; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +pub use aarch64_neon::*; + +#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] +mod x86_64_avx2; +#[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] +pub use x86_64_avx2::*; + +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +mod x86_64_avx512; +#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] +pub use x86_64_avx512::*; diff --git a/crates/backend/goldilocks/src/packed_cubic_extension.rs 
b/crates/backend/goldilocks/src/packed_cubic_extension.rs new file mode 100644 index 00000000..57827fd7 --- /dev/null +++ b/crates/backend/goldilocks/src/packed_cubic_extension.rs @@ -0,0 +1,377 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +//! Packed (SIMD) version of the cubic extension `F_p[X] / (X^3 - X - 1)`. +//! +//! Mirrors `koala-bear`'s `PackedQuinticExtensionField` shape: a SoA array of +//! `[PF; 3]` packed-base-field lanes, so each field operation is a SIMD +//! multiply/add over `PF::WIDTH` extension elements at once. + +use alloc::vec::Vec; +use core::array; +use core::fmt::Debug; +use core::iter::{Product, Sum}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use field::{ + Algebra, BasedVectorSpace, Field, PackedField, PackedFieldExtension, PackedValue, Powers, PrimeCharacteristicRing, + field_to_array, +}; +use itertools::Itertools; +use rand::distr::{Distribution, StandardUniform}; +use serde::{Deserialize, Serialize}; +use utils::{flatten_to_base, reconstitute_from_base}; + +use crate::Goldilocks; +use crate::cubic_extension::{CubicExtensionFieldGL, cubic_mul_generic, cubic_square_generic}; + +const D: usize = 3; + +/// Packed cubic extension over `Goldilocks`, parameterized by base field packing `PF`. 
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, PartialOrd, Ord)] +#[repr(transparent)] +pub struct PackedCubicExtensionFieldGL> { + #[serde( + with = "utils::array_serialization", + bound(serialize = "PF: Serialize", deserialize = "PF: Deserialize<'de>") + )] + pub(crate) value: [PF; D], +} + +impl> PackedCubicExtensionFieldGL { + const fn new(value: [PF; D]) -> Self { + Self { value } + } +} + +impl> Default for PackedCubicExtensionFieldGL { + #[inline] + fn default() -> Self { + Self { + value: array::from_fn(|_| PF::ZERO), + } + } +} + +impl> From for PackedCubicExtensionFieldGL { + #[inline] + fn from(x: CubicExtensionFieldGL) -> Self { + Self { + value: x.value.map(Into::into), + } + } +} + +impl> From for PackedCubicExtensionFieldGL { + #[inline] + fn from(x: PF) -> Self { + Self { + value: field_to_array(x), + } + } +} + +impl> Distribution> for StandardUniform +where + Self: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> PackedCubicExtensionFieldGL { + PackedCubicExtensionFieldGL::new(array::from_fn(|_| self.sample(rng))) + } +} + +impl> Algebra for PackedCubicExtensionFieldGL {} + +impl> Algebra for PackedCubicExtensionFieldGL {} + +impl> PrimeCharacteristicRing for PackedCubicExtensionFieldGL { + type PrimeSubfield = PF::PrimeSubfield; + + const ZERO: Self = Self { value: [PF::ZERO; D] }; + + const ONE: Self = Self { + value: field_to_array(PF::ONE), + }; + + const TWO: Self = Self { + value: field_to_array(PF::TWO), + }; + + const NEG_ONE: Self = Self { + value: field_to_array(PF::NEG_ONE), + }; + + #[inline] + fn from_prime_subfield(val: Self::PrimeSubfield) -> Self { + PF::from_prime_subfield(val).into() + } + + #[inline] + fn from_bool(b: bool) -> Self { + PF::from_bool(b).into() + } + + #[inline(always)] + fn square(&self) -> Self { + let mut res = Self::default(); + cubic_square_generic(&self.value, &mut res.value); + res + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + // SAFETY: this is a 
repr(transparent) wrapper around an array. + unsafe { reconstitute_from_base(PF::zero_vec(len * D)) } + } +} + +impl> BasedVectorSpace for PackedCubicExtensionFieldGL { + const DIMENSION: usize = D; + + #[inline] + fn as_basis_coefficients_slice(&self) -> &[PF] { + &self.value + } + + #[inline] + fn from_basis_coefficients_fn PF>(f: Fn) -> Self { + Self { + value: array::from_fn(f), + } + } + + #[inline] + fn from_basis_coefficients_iter>(mut iter: I) -> Option { + (iter.len() == D).then(|| Self::new(array::from_fn(|_| iter.next().unwrap()))) + } + + #[inline] + fn flatten_to_base(vec: Vec) -> Vec { + // SAFETY: `Self` is `repr(transparent)` over `[PF; D]`. + unsafe { flatten_to_base(vec) } + } + + #[inline] + fn reconstitute_from_base(vec: Vec) -> Vec { + // SAFETY: `Self` is `repr(transparent)` over `[PF; D]`. + unsafe { reconstitute_from_base(vec) } + } +} + +impl PackedFieldExtension + for PackedCubicExtensionFieldGL<::Packing> +{ + #[inline] + fn from_ext_slice(ext_slice: &[CubicExtensionFieldGL]) -> Self { + let width = ::Packing::WIDTH; + assert_eq!(ext_slice.len(), width); + + let res = array::from_fn(|i| ::Packing::from_fn(|j| ext_slice[j].value[i])); + Self::new(res) + } + + #[inline] + fn to_ext_iter(iter: impl IntoIterator) -> impl Iterator { + let width = ::Packing::WIDTH; + iter.into_iter().flat_map(move |x| { + (0..width).map(move |i| { + let values = array::from_fn(|j| x.value[j].as_slice()[i]); + CubicExtensionFieldGL::new(values) + }) + }) + } + + #[inline] + fn packed_ext_powers(base: CubicExtensionFieldGL) -> Powers { + let width = ::Packing::WIDTH; + let powers = base.powers().take(width + 1).collect_vec(); + let current = Self::from_ext_slice(&powers[..width]); + let multiplier = powers[width].into(); + + Powers { + base: multiplier, + current, + } + } +} + +impl> Neg for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self { + value: self.value.map(PF::neg), + } + } +} + +impl> Add for 
PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self { + value: array::from_fn(|i| self.value[i] + rhs.value[i]), + } + } +} + +impl> Add for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn add(self, rhs: CubicExtensionFieldGL) -> Self { + Self { + value: array::from_fn(|i| self.value[i] + rhs.value[i]), + } + } +} + +impl> Add for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn add(mut self, rhs: PF) -> Self { + self.value[0] += rhs; + self + } +} + +impl> AddAssign for PackedCubicExtensionFieldGL { + #[inline] + fn add_assign(&mut self, rhs: Self) { + for i in 0..D { + self.value[i] += rhs.value[i]; + } + } +} + +impl> AddAssign for PackedCubicExtensionFieldGL { + #[inline] + fn add_assign(&mut self, rhs: CubicExtensionFieldGL) { + for i in 0..D { + self.value[i] += rhs.value[i]; + } + } +} + +impl> AddAssign for PackedCubicExtensionFieldGL { + #[inline] + fn add_assign(&mut self, rhs: PF) { + self.value[0] += rhs; + } +} + +impl> Sum for PackedCubicExtensionFieldGL { + #[inline] + fn sum>(iter: I) -> Self { + iter.reduce(|acc, x| acc + x).unwrap_or(Self::ZERO) + } +} + +impl> Sub for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self { + value: array::from_fn(|i| self.value[i] - rhs.value[i]), + } + } +} + +impl> Sub for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn sub(self, rhs: CubicExtensionFieldGL) -> Self { + Self { + value: array::from_fn(|i| self.value[i] - rhs.value[i]), + } + } +} + +impl> Sub for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn sub(self, rhs: PF) -> Self { + let mut res = self.value; + res[0] -= rhs; + Self { value: res } + } +} + +impl> SubAssign for PackedCubicExtensionFieldGL { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl> SubAssign for PackedCubicExtensionFieldGL { + #[inline] + fn 
sub_assign(&mut self, rhs: CubicExtensionFieldGL) { + *self = *self - rhs; + } +} + +impl> SubAssign for PackedCubicExtensionFieldGL { + #[inline] + fn sub_assign(&mut self, rhs: PF) { + *self = *self - rhs; + } +} + +impl> Mul for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + let mut res = Self::default(); + cubic_mul_generic(&self.value, &rhs.value, &mut res.value); + res + } +} + +impl> Mul for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: CubicExtensionFieldGL) -> Self { + let b: [PF; D] = rhs.value.map(|x| x.into()); + let mut res = Self::default(); + cubic_mul_generic(&self.value, &b, &mut res.value); + res + } +} + +impl> Mul for PackedCubicExtensionFieldGL { + type Output = Self; + #[inline] + fn mul(self, rhs: PF) -> Self { + Self { + value: self.value.map(|x| x * rhs), + } + } +} + +impl> Product for PackedCubicExtensionFieldGL { + #[inline] + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE) + } +} + +impl> MulAssign for PackedCubicExtensionFieldGL { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl> MulAssign for PackedCubicExtensionFieldGL { + #[inline(always)] + fn mul_assign(&mut self, rhs: CubicExtensionFieldGL) { + *self = *self * rhs; + } +} + +impl> MulAssign for PackedCubicExtensionFieldGL { + #[inline] + fn mul_assign(&mut self, rhs: PF) { + *self = *self * rhs; + } +} diff --git a/crates/backend/goldilocks/src/x86_64_avx2/mod.rs b/crates/backend/goldilocks/src/x86_64_avx2/mod.rs new file mode 100644 index 00000000..730a8675 --- /dev/null +++ b/crates/backend/goldilocks/src/x86_64_avx2/mod.rs @@ -0,0 +1,5 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
+ +mod packing; + +pub use packing::*; diff --git a/crates/backend/goldilocks/src/x86_64_avx2/packing.rs b/crates/backend/goldilocks/src/x86_64_avx2/packing.rs new file mode 100644 index 00000000..30ad75c6 --- /dev/null +++ b/crates/backend/goldilocks/src/x86_64_avx2/packing.rs @@ -0,0 +1,385 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +use alloc::vec::Vec; +use core::arch::x86_64::*; +use core::fmt::Debug; +use core::iter::{Product, Sum}; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use field::interleave::{interleave_u64, interleave_u128}; +use field::op_assign_macros::{ + impl_add_assign, impl_add_base_field, impl_div_methods, impl_mul_base_field, impl_mul_methods, impl_packed_value, + impl_rng, impl_sub_assign, impl_sub_base_field, impl_sum_prod_base_field, ring_sum, +}; +use field::{ + Algebra, Field, InjectiveMonomial, PackedField, PackedFieldPow2, PackedValue, PermutationMonomial, + PrimeCharacteristicRing, PrimeField64, impl_packed_field_pow_2, +}; +use rand::Rng; +use rand::distr::{Distribution, StandardUniform}; +use utils::reconstitute_from_base; + +use crate::helpers::exp_10540996611094048183; +use crate::{Goldilocks, P}; + +const WIDTH: usize = 4; + +/// Vectorized AVX2 implementation of `Goldilocks` arithmetic. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] // Needed to make `transmute`s safe. +#[must_use] +pub struct PackedGoldilocksAVX2(pub [Goldilocks; WIDTH]); + +impl PackedGoldilocksAVX2 { + /// Get an arch-specific vector representing the packed values. + #[inline] + #[must_use] + pub(crate) fn to_vector(self) -> __m256i { + unsafe { + // Safety: `Goldilocks` is `repr(transparent)` over `u64`, so + // `[Goldilocks; 4]` and `__m256i` share size and layout. + transmute(self) + } + } + + /// Make a packed field vector from an arch-specific vector. 
+ /// + /// Elements of `Goldilocks` are allowed to be arbitrary u64s so this function + /// is safe (unlike Mersenne31/MontyField31 variants). + #[inline] + pub(crate) fn from_vector(vector: __m256i) -> Self { + unsafe { transmute(vector) } + } + + /// Copy `value` to all positions in a packed vector. This is the same as + /// `From::from`, but `const`. + #[inline] + const fn broadcast(value: Goldilocks) -> Self { + Self([value; WIDTH]) + } +} + +impl From for PackedGoldilocksAVX2 { + fn from(x: Goldilocks) -> Self { + Self::broadcast(x) + } +} + +impl Add for PackedGoldilocksAVX2 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::from_vector(add(self.to_vector(), rhs.to_vector())) + } +} + +impl Sub for PackedGoldilocksAVX2 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::from_vector(sub(self.to_vector(), rhs.to_vector())) + } +} + +impl Neg for PackedGoldilocksAVX2 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self::from_vector(neg(self.to_vector())) + } +} + +impl Mul for PackedGoldilocksAVX2 { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::from_vector(mul(self.to_vector(), rhs.to_vector())) + } +} + +impl_add_assign!(PackedGoldilocksAVX2); +impl_sub_assign!(PackedGoldilocksAVX2); +impl_mul_methods!(PackedGoldilocksAVX2); +ring_sum!(PackedGoldilocksAVX2); +impl_rng!(PackedGoldilocksAVX2); + +impl PrimeCharacteristicRing for PackedGoldilocksAVX2 { + type PrimeSubfield = Goldilocks; + + const ZERO: Self = Self::broadcast(Goldilocks::ZERO); + const ONE: Self = Self::broadcast(Goldilocks::ONE); + const TWO: Self = Self::broadcast(Goldilocks::TWO); + const NEG_ONE: Self = Self::broadcast(Goldilocks::NEG_ONE); + + #[inline] + fn from_prime_subfield(f: Self::PrimeSubfield) -> Self { + f.into() + } + + #[inline] + fn halve(&self) -> Self { + Self::from_vector(halve(self.to_vector())) + } + + #[inline] + fn square(&self) -> Self { + 
Self::from_vector(square(self.to_vector())) + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + // SAFETY: this is a repr(transparent) wrapper around an array. + unsafe { reconstitute_from_base(Goldilocks::zero_vec(len * WIDTH)) } + } +} + +// Goldilocks: p - 1 = 2^32 * 3 * 5 * 17 * ...; smallest D coprime to (p-1) is 7. +impl InjectiveMonomial<7> for PackedGoldilocksAVX2 {} + +impl PermutationMonomial<7> for PackedGoldilocksAVX2 { + fn injective_exp_root_n(&self) -> Self { + exp_10540996611094048183(*self) + } +} + +impl_add_base_field!(PackedGoldilocksAVX2, Goldilocks); +impl_sub_base_field!(PackedGoldilocksAVX2, Goldilocks); +impl_mul_base_field!(PackedGoldilocksAVX2, Goldilocks); +impl_div_methods!(PackedGoldilocksAVX2, Goldilocks); +impl_sum_prod_base_field!(PackedGoldilocksAVX2, Goldilocks); + +impl Algebra for PackedGoldilocksAVX2 {} + +impl_packed_value!(PackedGoldilocksAVX2, Goldilocks, WIDTH); + +unsafe impl PackedField for PackedGoldilocksAVX2 { + type Scalar = Goldilocks; +} + +impl_packed_field_pow_2!( + PackedGoldilocksAVX2; + [ + (1, interleave_u64), + (2, interleave_u128), + ], + WIDTH +); + +// Resources: +// 1. Intel Intrinsics Guide: https://software.intel.com/sites/landingpage/IntrinsicsGuide/ +// 2. uops.info: https://uops.info/table.html +// +// Implementation notes: +// - AVX has no unsigned 64-bit comparisons. We emulate them via signed comparisons after a +// 1<<63 shift (`shift`/`canonicalize_s`/etc). +// - AVX has no add-with-carry; emulated via `result < operand` overflow detection. + +const SIGN_BIT: __m256i = unsafe { transmute([i64::MIN; WIDTH]) }; +const SHIFTED_FIELD_ORDER: __m256i = unsafe { transmute([Goldilocks::ORDER_U64 ^ (i64::MIN as u64); WIDTH]) }; + +/// Equal to `2^32 - 1 = 2^64 mod P`. +const EPSILON: __m256i = unsafe { transmute([Goldilocks::ORDER_U64.wrapping_neg(); WIDTH]) }; + +/// Add 2^63 (XOR with sign bit). Used to emulate unsigned compares with signed ones. 
+#[inline] +fn shift(x: __m256i) -> __m256i { + unsafe { _mm256_xor_si256(x, SIGN_BIT) } +} + +/// Convert to canonical representation. Argument is shifted by 1<<63; result is too. +#[inline] +unsafe fn canonicalize_s(x_s: __m256i) -> __m256i { + unsafe { + let mask = _mm256_cmpgt_epi64(SHIFTED_FIELD_ORDER, x_s); + let wrapback_amt = _mm256_andnot_si256(mask, EPSILON); + _mm256_add_epi64(x_s, wrapback_amt) + } +} + +/// Add `x + y_s` where `y_s` is pre-shifted; output is shifted. Assumes `x + y < 2^64 + P`. +#[inline] +unsafe fn add_no_double_overflow_64_64s_s(x: __m256i, y_s: __m256i) -> __m256i { + unsafe { + let res_wrapped_s = _mm256_add_epi64(x, y_s); + let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); + let wrapback_amt = _mm256_srli_epi64::<32>(mask); + _mm256_add_epi64(res_wrapped_s, wrapback_amt) + } +} + +/// Goldilocks modular addition. Result may exceed `P`. +#[inline] +fn add(x: __m256i, y: __m256i) -> __m256i { + unsafe { + let y_s = shift(y); + let res_s = add_no_double_overflow_64_64s_s(x, canonicalize_s(y_s)); + shift(res_s) + } +} + +/// Goldilocks modular subtraction. Result may exceed `P`. +#[inline] +fn sub(x: __m256i, y: __m256i) -> __m256i { + unsafe { + let mut y_s = shift(y); + y_s = canonicalize_s(y_s); + let x_s = shift(x); + let mask = _mm256_cmpgt_epi64(y_s, x_s); + let wrapback_amt = _mm256_srli_epi64::<32>(mask); + let res_wrapped = _mm256_sub_epi64(x_s, y_s); + _mm256_sub_epi64(res_wrapped, wrapback_amt) + } +} + +/// Goldilocks modular negation. Result may exceed `P`. +#[inline] +fn neg(y: __m256i) -> __m256i { + unsafe { + let y_s = shift(y); + _mm256_sub_epi64(SHIFTED_FIELD_ORDER, canonicalize_s(y_s)) + } +} + +/// Halve a vector of Goldilocks field elements. +#[inline(always)] +pub(crate) fn halve(input: __m256i) -> __m256i { + // For val in [0, P): val even -> val/2 = val>>1; val odd -> (val+P)/2 = (val>>1) + (P+1)/2. 
+ unsafe { + const ONE: __m256i = unsafe { transmute([1_i64; 4]) }; + const ZERO: __m256i = unsafe { transmute([0_i64; 4]) }; + let half = _mm256_set1_epi64x(P.div_ceil(2) as i64); + + let least_bit = _mm256_and_si256(input, ONE); + let t = _mm256_srli_epi64::<1>(input); + let neg_least_bit = _mm256_sub_epi64(ZERO, least_bit); + let maybe_half = _mm256_and_si256(half, neg_least_bit); + _mm256_add_epi64(t, maybe_half) + } +} + +/// Full 64x64 -> 128 multiplication, returning `(hi, lo)`. +#[inline] +fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) { + unsafe { + // Move the high 32 bits of each lane into the low 32 bits via a float-domain swizzle. + // (vpshufd / movehdup runs on port 5 and doesn't compete with the multiplier on ports 0/1.) + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y))); + + let mul_ll = _mm256_mul_epu32(x, y); + let mul_lh = _mm256_mul_epu32(x, y_hi); + let mul_hl = _mm256_mul_epu32(x_hi, y); + let mul_hh = _mm256_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll); + let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi); + let t0_lo = _mm256_and_si256(t0, EPSILON); + let t0_hi = _mm256_srli_epi64::<32>(t0); + let t1 = _mm256_add_epi64(mul_lh, t0_lo); + let t2 = _mm256_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm256_srli_epi64::<32>(t1); + let res_hi = _mm256_add_epi64(t2, t1_hi); + + let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1))); + let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo); + + (res_hi, res_lo) + } +} + +/// Full 64-bit squaring. 
+#[inline] +fn square64(x: __m256i) -> (__m256i, __m256i) { + unsafe { + let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))); + + let mul_ll = _mm256_mul_epu32(x, x); + let mul_lh = _mm256_mul_epu32(x, x_hi); + let mul_hh = _mm256_mul_epu32(x_hi, x_hi); + + let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll); + let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi); + let t0_hi = _mm256_srli_epi64::<31>(t0); + let res_hi = _mm256_add_epi64(mul_hh, t0_hi); + + let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh); + let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo); + + (res_hi, res_lo) + } +} + +/// Add `x_s + y` where `x_s` is pre-shifted by 2^63 and `y <= 2^64 - 2^32`. Result is shifted. +#[inline] +unsafe fn add_small_64s_64_s(x_s: __m256i, y: __m256i) -> __m256i { + unsafe { + let res_wrapped_s = _mm256_add_epi64(x_s, y); + let mask = _mm256_cmpgt_epi32(x_s, res_wrapped_s); + let wrapback_amt = _mm256_srli_epi64::<32>(mask); + _mm256_add_epi64(res_wrapped_s, wrapback_amt) + } +} + +/// Subtract `y` from `x_s` (`x_s` pre-shifted, `y <= 2^64 - 2^32`). Result is shifted. +#[inline] +unsafe fn sub_small_64s_64_s(x_s: __m256i, y: __m256i) -> __m256i { + unsafe { + let res_wrapped_s = _mm256_sub_epi64(x_s, y); + let mask = _mm256_cmpgt_epi32(res_wrapped_s, x_s); + let wrapback_amt = _mm256_srli_epi64::<32>(mask); + _mm256_sub_epi64(res_wrapped_s, wrapback_amt) + } +} + +/// Reduce a 128-bit value (high, low) modulo `P`. Result may exceed `P`. +#[inline] +fn reduce128(x: (__m256i, __m256i)) -> __m256i { + unsafe { + let (hi0, lo0) = x; + + let lo0_s = shift(lo0); + + let hi_hi0 = _mm256_srli_epi64::<32>(hi0); + + // 2^96 = -1 mod P. + let lo1_s = sub_small_64s_64_s(lo0_s, hi_hi0); + + // Bottom 32 bits of hi0 times 2^64 = 2^32 - 1 = EPSILON mod P. + let t1 = _mm256_mul_epu32(hi0, EPSILON); + + let lo2_s = add_small_64s_64_s(lo1_s, t1); + shift(lo2_s) + } +} + +/// Goldilocks modular multiplication. Result may exceed `P`. 
+#[inline] +fn mul(x: __m256i, y: __m256i) -> __m256i { + reduce128(mul64_64(x, y)) +} + +/// Goldilocks modular square. Result may exceed `P`. +#[inline] +fn square(x: __m256i) -> __m256i { + reduce128(square64(x)) +} + +#[cfg(test)] +mod tests { + use super::{Goldilocks, PackedGoldilocksAVX2, WIDTH}; + + const SPECIAL_VALS: [Goldilocks; WIDTH] = Goldilocks::new_array([ + 0xFFFF_FFFF_0000_0000, + 0xFFFF_FFFF_FFFF_FFFF, + 0x0000_0000_0000_0001, + 0xFFFF_FFFF_0000_0001, + ]); + + #[test] + fn pack_round_trip() { + let p = PackedGoldilocksAVX2(SPECIAL_VALS); + let v = p.to_vector(); + assert_eq!(PackedGoldilocksAVX2::from_vector(v).0, SPECIAL_VALS); + } +} diff --git a/crates/backend/goldilocks/src/x86_64_avx512/mod.rs b/crates/backend/goldilocks/src/x86_64_avx512/mod.rs new file mode 100644 index 00000000..730a8675 --- /dev/null +++ b/crates/backend/goldilocks/src/x86_64_avx512/mod.rs @@ -0,0 +1,5 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +mod packing; + +pub use packing::*; diff --git a/crates/backend/goldilocks/src/x86_64_avx512/packing.rs b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs new file mode 100644 index 00000000..1484e764 --- /dev/null +++ b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs @@ -0,0 +1,331 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
+ +use alloc::vec::Vec; +use core::arch::x86_64::*; +use core::fmt::Debug; +use core::iter::{Product, Sum}; +use core::mem::transmute; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use field::interleave::{interleave_u64, interleave_u128, interleave_u256}; +use field::op_assign_macros::{ + impl_add_assign, impl_add_base_field, impl_div_methods, impl_mul_base_field, impl_mul_methods, impl_packed_value, + impl_rng, impl_sub_assign, impl_sub_base_field, impl_sum_prod_base_field, ring_sum, +}; +use field::{ + Algebra, Field, InjectiveMonomial, PackedField, PackedFieldPow2, PackedValue, PermutationMonomial, + PrimeCharacteristicRing, PrimeField64, impl_packed_field_pow_2, +}; +use rand::Rng; +use rand::distr::{Distribution, StandardUniform}; +use utils::reconstitute_from_base; + +use crate::helpers::exp_10540996611094048183; +use crate::{Goldilocks, P}; + +const WIDTH: usize = 8; + +/// Vectorized AVX512 implementation of `Goldilocks` arithmetic. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +#[repr(transparent)] // Needed to make `transmute`s safe. +#[must_use] +pub struct PackedGoldilocksAVX512(pub [Goldilocks; WIDTH]); + +impl PackedGoldilocksAVX512 { + /// Get an arch-specific vector representing the packed values. + #[inline] + #[must_use] + pub(crate) fn to_vector(self) -> __m512i { + unsafe { transmute(self) } + } + + /// Make a packed field vector from an arch-specific vector. + /// + /// Goldilocks elements may be arbitrary u64s, so this is always safe. + #[inline] + pub(crate) fn from_vector(vector: __m512i) -> Self { + unsafe { transmute(vector) } + } + + /// Copy `value` to all positions in a packed vector. `const` version of `From`. 
+ #[inline] + const fn broadcast(value: Goldilocks) -> Self { + Self([value; WIDTH]) + } +} + +impl From for PackedGoldilocksAVX512 { + fn from(x: Goldilocks) -> Self { + Self::broadcast(x) + } +} + +impl Add for PackedGoldilocksAVX512 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::from_vector(add(self.to_vector(), rhs.to_vector())) + } +} + +impl Sub for PackedGoldilocksAVX512 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::from_vector(sub(self.to_vector(), rhs.to_vector())) + } +} + +impl Neg for PackedGoldilocksAVX512 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self::from_vector(neg(self.to_vector())) + } +} + +impl Mul for PackedGoldilocksAVX512 { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::from_vector(mul(self.to_vector(), rhs.to_vector())) + } +} + +impl_add_assign!(PackedGoldilocksAVX512); +impl_sub_assign!(PackedGoldilocksAVX512); +impl_mul_methods!(PackedGoldilocksAVX512); +ring_sum!(PackedGoldilocksAVX512); +impl_rng!(PackedGoldilocksAVX512); + +impl PrimeCharacteristicRing for PackedGoldilocksAVX512 { + type PrimeSubfield = Goldilocks; + + const ZERO: Self = Self::broadcast(Goldilocks::ZERO); + const ONE: Self = Self::broadcast(Goldilocks::ONE); + const TWO: Self = Self::broadcast(Goldilocks::TWO); + const NEG_ONE: Self = Self::broadcast(Goldilocks::NEG_ONE); + + #[inline] + fn from_prime_subfield(f: Self::PrimeSubfield) -> Self { + f.into() + } + + #[inline] + fn halve(&self) -> Self { + Self::from_vector(halve(self.to_vector())) + } + + #[inline] + fn square(&self) -> Self { + Self::from_vector(square(self.to_vector())) + } + + #[inline] + fn zero_vec(len: usize) -> Vec { + // SAFETY: this is a repr(transparent) wrapper around an array. 
+ unsafe { reconstitute_from_base(Goldilocks::zero_vec(len * WIDTH)) } + } +} + +impl_add_base_field!(PackedGoldilocksAVX512, Goldilocks); +impl_sub_base_field!(PackedGoldilocksAVX512, Goldilocks); +impl_mul_base_field!(PackedGoldilocksAVX512, Goldilocks); +impl_div_methods!(PackedGoldilocksAVX512, Goldilocks); +impl_sum_prod_base_field!(PackedGoldilocksAVX512, Goldilocks); + +impl Algebra for PackedGoldilocksAVX512 {} + +impl InjectiveMonomial<7> for PackedGoldilocksAVX512 {} + +impl PermutationMonomial<7> for PackedGoldilocksAVX512 { + fn injective_exp_root_n(&self) -> Self { + exp_10540996611094048183(*self) + } +} + +impl_packed_value!(PackedGoldilocksAVX512, Goldilocks, WIDTH); + +unsafe impl PackedField for PackedGoldilocksAVX512 { + type Scalar = Goldilocks; +} + +impl_packed_field_pow_2!( + PackedGoldilocksAVX512; + [ + (1, interleave_u64), + (2, interleave_u128), + (4, interleave_u256), + ], + WIDTH +); + +const FIELD_ORDER: __m512i = unsafe { transmute([Goldilocks::ORDER_U64; WIDTH]) }; +const EPSILON: __m512i = unsafe { transmute([Goldilocks::ORDER_U64.wrapping_neg(); WIDTH]) }; + +#[inline] +unsafe fn canonicalize(x: __m512i) -> __m512i { + unsafe { + let mask = _mm512_cmpge_epu64_mask(x, FIELD_ORDER); + _mm512_mask_sub_epi64(x, mask, x, FIELD_ORDER) + } +} + +/// Compute `x + y mod P`. Result may be > P. +/// +/// # Safety +/// Caller must ensure `x + y < 2^64 + P`. +#[inline] +unsafe fn add_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i { + unsafe { + let res_wrapped = _mm512_add_epi64(x, y); + let mask = _mm512_cmplt_epu64_mask(res_wrapped, y); + _mm512_mask_sub_epi64(res_wrapped, mask, res_wrapped, FIELD_ORDER) + } +} + +/// Compute `x - y mod P`. Result may be > P. +/// +/// # Safety +/// Caller must ensure `x - y > -P`. 
+#[inline] +unsafe fn sub_no_double_overflow_64_64(x: __m512i, y: __m512i) -> __m512i { + unsafe { + let mask = _mm512_cmplt_epu64_mask(x, y); + let res_wrapped = _mm512_sub_epi64(x, y); + _mm512_mask_add_epi64(res_wrapped, mask, res_wrapped, FIELD_ORDER) + } +} + +#[inline] +fn add(x: __m512i, y: __m512i) -> __m512i { + unsafe { add_no_double_overflow_64_64(x, canonicalize(y)) } +} + +#[inline] +fn sub(x: __m512i, y: __m512i) -> __m512i { + unsafe { sub_no_double_overflow_64_64(x, canonicalize(y)) } +} + +#[inline] +fn neg(y: __m512i) -> __m512i { + unsafe { _mm512_sub_epi64(FIELD_ORDER, canonicalize(y)) } +} + +/// Halve a vector of Goldilocks field elements. +#[inline(always)] +pub(crate) fn halve(input: __m512i) -> __m512i { + // For val in [0, P): val even -> val/2 = val>>1; val odd -> (val+P)/2 = (val>>1) + (P+1)/2. + unsafe { + const ONE: __m512i = unsafe { transmute([1_i64; 8]) }; + let half = _mm512_set1_epi64(P.div_ceil(2) as i64); + + let least_bit = _mm512_test_epi64_mask(input, ONE); + let t = _mm512_srli_epi64::<1>(input); + _mm512_mask_add_epi64(t, least_bit, t, half) + } +} + +#[allow(clippy::useless_transmute)] +const LO_32_BITS_MASK: __mmask16 = unsafe { transmute(0b0101010101010101u16) }; + +/// Full 64x64 -> 128 multiplication, returning `(hi, lo)`. 
+#[inline] +fn mul64_64(x: __m512i, y: __m512i) -> (__m512i, __m512i) { + unsafe { + let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))); + let y_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(y))); + + let mul_ll = _mm512_mul_epu32(x, y); + let mul_lh = _mm512_mul_epu32(x, y_hi); + let mul_hl = _mm512_mul_epu32(x_hi, y); + let mul_hh = _mm512_mul_epu32(x_hi, y_hi); + + let mul_ll_hi = _mm512_srli_epi64::<32>(mul_ll); + let t0 = _mm512_add_epi64(mul_hl, mul_ll_hi); + let t0_lo = _mm512_and_si512(t0, EPSILON); + let t0_hi = _mm512_srli_epi64::<32>(t0); + let t1 = _mm512_add_epi64(mul_lh, t0_lo); + let t2 = _mm512_add_epi64(mul_hh, t0_hi); + let t1_hi = _mm512_srli_epi64::<32>(t1); + let res_hi = _mm512_add_epi64(t2, t1_hi); + + let t1_lo = _mm512_castps_si512(_mm512_moveldup_ps(_mm512_castsi512_ps(t1))); + let res_lo = _mm512_mask_blend_epi32(LO_32_BITS_MASK, t1_lo, mul_ll); + + (res_hi, res_lo) + } +} + +/// Full 64-bit squaring. +#[inline] +fn square64(x: __m512i) -> (__m512i, __m512i) { + unsafe { + let x_hi = _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))); + + let mul_ll = _mm512_mul_epu32(x, x); + let mul_lh = _mm512_mul_epu32(x, x_hi); + let mul_hh = _mm512_mul_epu32(x_hi, x_hi); + + let mul_ll_hi = _mm512_srli_epi64::<33>(mul_ll); + let t0 = _mm512_add_epi64(mul_lh, mul_ll_hi); + let t0_hi = _mm512_srli_epi64::<31>(t0); + let res_hi = _mm512_add_epi64(mul_hh, t0_hi); + + let mul_lh_lo = _mm512_slli_epi64::<33>(mul_lh); + let res_lo = _mm512_add_epi64(mul_ll, mul_lh_lo); + + (res_hi, res_lo) + } +} + +/// Reduce a 128-bit value (high, low) modulo `P`. Result may be > P. +#[inline] +fn reduce128(x: (__m512i, __m512i)) -> __m512i { + unsafe { + let (hi0, lo0) = x; + + let hi_hi0 = _mm512_srli_epi64::<32>(hi0); + + // 2^96 = -1 mod P. + let lo1 = sub_no_double_overflow_64_64(lo0, hi_hi0); + + // Bottom 32 bits of hi0 times 2^64 = 2^32 - 1 mod P. 
+ let t1 = _mm512_mul_epu32(hi0, EPSILON); + + add_no_double_overflow_64_64(lo1, t1) + } +} + +#[inline] +fn mul(x: __m512i, y: __m512i) -> __m512i { + reduce128(mul64_64(x, y)) +} + +#[inline] +fn square(x: __m512i) -> __m512i { + reduce128(square64(x)) +} + +#[cfg(test)] +mod tests { + use super::{Goldilocks, PackedGoldilocksAVX512, WIDTH}; + + const SPECIAL_VALS: [Goldilocks; WIDTH] = Goldilocks::new_array([ + 0xFFFF_FFFF_0000_0001, + 0xFFFF_FFFF_0000_0000, + 0xFFFF_FFFE_FFFF_FFFF, + 0xFFFF_FFFF_FFFF_FFFF, + 0x0000_0000_0000_0000, + 0x0000_0000_0000_0001, + 0x0000_0000_0000_0002, + 0x0FFF_FFFF_F000_0000, + ]); + + #[test] + fn pack_round_trip() { + let p = PackedGoldilocksAVX512(SPECIAL_VALS); + let v = p.to_vector(); + assert_eq!(PackedGoldilocksAVX512::from_vector(v).0, SPECIAL_VALS); + } +} From 84f208b97ffce1eb1af872ff8c1538bfc96e4253 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sun, 26 Apr 2026 12:05:06 +0200 Subject: [PATCH 19/31] w --- .../goldilocks/src/aarch64_neon/packing.rs | 44 +------------------ crates/backend/goldilocks/src/poseidon1.rs | 28 +++++++----- 2 files changed, 19 insertions(+), 53 deletions(-) diff --git a/crates/backend/goldilocks/src/aarch64_neon/packing.rs b/crates/backend/goldilocks/src/aarch64_neon/packing.rs index 32044303..6f0bb93a 100644 --- a/crates/backend/goldilocks/src/aarch64_neon/packing.rs +++ b/crates/backend/goldilocks/src/aarch64_neon/packing.rs @@ -29,49 +29,7 @@ const WIDTH: usize = 2; /// Equal to `2^32 - 1 = 2^64 mod P`. const EPSILON: u64 = Goldilocks::ORDER_U64.wrapping_neg(); -/// Goldilocks scalar addition: `(a + b) mod P`. Handles u64 overflow via EPSILON. -#[inline(always)] -pub(super) const fn gadd(a: u64, b: u64) -> u64 { - let (sum, overflow) = a.overflowing_add(b); - let (res, _) = sum.overflowing_add(if overflow { EPSILON } else { 0 }); - res -} - -/// Goldilocks scalar subtraction: `(a - b) mod P`. Handles u64 underflow via EPSILON. 
-#[inline(always)] -pub(super) const fn gsub(a: u64, b: u64) -> u64 { - let (diff, borrow) = a.overflowing_sub(b); - let (res, _) = diff.overflowing_sub(if borrow { EPSILON } else { 0 }); - res -} - -/// Single Goldilocks 64x64 -> 64 mul + reduction in pure Rust. -/// -/// LLVM emits `mul + umulh + 10-op reduction`, just like the inline-asm version, but as plain -/// instructions the compiler can interleave across independent calls instead of treating each as -/// an opaque asm block. -#[inline(always)] -pub(super) const fn mul_reduce(a: u64, b: u64) -> u64 { - let prod = (a as u128) * (b as u128); - let lo = prod as u64; - let hi = (prod >> 64) as u64; - - let hi_hi = hi >> 32; - let hi_lo = hi & 0xFFFF_FFFF; - - // tmp = lo - hi_hi; on borrow subtract EPSILON to fold P back in. - let (tmp_pre, borrow) = lo.overflowing_sub(hi_hi); - let tmp = tmp_pre.wrapping_sub(if borrow { EPSILON } else { 0 }); - - // hi_lo * (2^32 - 1) without an actual multiply. - let hi_lo_eps = (hi_lo << 32).wrapping_sub(hi_lo); - - // result = tmp + hi_lo_eps; on overflow add EPSILON. - let (res_pre, overflow) = tmp.overflowing_add(hi_lo_eps); - res_pre.wrapping_add(if overflow { EPSILON } else { 0 }) -} - -/// Hand-scheduled inline-asm variant of [`mul_reduce`], tuned for the **scalar / single-lane Mul** +/// Hand-scheduled inline-asm variant tuned for the **scalar / single-lane Mul** /// path on aarch64. Saves one ALU op vs the LLVM-emitted form by collapsing `lsr+subs` into the /// shifted-register `subs xT, lo, hi, lsr #32` form. 
#[inline(always)] diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 3100afeb..4fad7db4 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -57,7 +57,7 @@ fn mds_mul_generic>(state: &mut [R; 8]) { // `row_i · input = sum_j ROW[(j - i) mod 8] · input[j]` let mut acc = input[0] * coeffs[(8 - i) % 8]; for j in 1..8 { - acc = acc + input[j] * coeffs[(j + 8 - i) % 8]; + acc += input[j] * coeffs[(j + 8 - i) % 8]; } state[i] = acc; } @@ -424,9 +424,9 @@ impl Poseidon1Goldilocks8 { } pub fn permute_mut(&self, state: &mut [Goldilocks; POSEIDON1_WIDTH]) { - for r in 0..POSEIDON1_HALF_FULL_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } for s in state.iter_mut() { *s = sbox_full::(*s); @@ -434,17 +434,25 @@ impl Poseidon1Goldilocks8 { mds_mul_scalar(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .skip(POSEIDON1_HALF_FULL_ROUNDS) + .take(POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } state[0] = sbox_full::(state[0]); mds_mul_scalar(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS..POSEIDON1_N_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] += GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .take(POSEIDON1_N_ROUNDS) + .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } for s in state.iter_mut() { *s = sbox_full::(*s); From 80b3a9874a166271e8253efd2743b2d5722d6bea Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sun, 
26 Apr 2026 12:11:19 +0200 Subject: [PATCH 20/31] w --- crates/backend/goldilocks/src/poseidon1.rs | 28 ++++++---- crates/lean_prover/src/verify_execution.rs | 2 +- crates/lean_vm/src/tables/poseidon_8/mod.rs | 56 +++++++++---------- .../lean_vm/src/tables/poseidon_8/sparse.rs | 51 ++++++++--------- crates/rec_aggregation/src/lib.rs | 2 +- crates/xmss/src/wots.rs | 4 +- 6 files changed, 74 insertions(+), 69 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 4fad7db4..368d62d0 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -467,9 +467,9 @@ impl Poseidon1Goldilocks8 { where R: Algebra + InjectiveMonomial<7> + Copy, { - for r in 0..POSEIDON1_HALF_FULL_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } for s in state.iter_mut() { *s = sbox_full::(*s); @@ -477,17 +477,25 @@ impl Poseidon1Goldilocks8 { mds_mul_generic(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS..POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .skip(POSEIDON1_HALF_FULL_ROUNDS) + .take(POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } state[0] = sbox_full::(state[0]); mds_mul_generic(state); } - for r in POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS..POSEIDON1_N_ROUNDS { - for i in 0..POSEIDON1_WIDTH { - state[i] = state[i] + GOLDILOCKS_POSEIDON1_RC_8[r][i]; + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .take(POSEIDON1_N_ROUNDS) + .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += rc[i]; } for s in state.iter_mut() { *s 
= sbox_full::(*s); @@ -508,7 +516,7 @@ impl Poseidon1Goldilocks8 { let initial = *state; self.permute_generic(state); for (s, init) in state.iter_mut().zip(initial) { - *s = *s + init; + *s += init; } } } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 531b1c20..c8a7b26d 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -16,7 +16,7 @@ pub fn verify_execution( public_input: &[F], proof: Proof, ) -> Result<(ProofVerificationDetails, RawProof), ProofError> { - let mut verifier_state = VerifierState::::new(proof, get_poseidon8().clone())?; + let mut verifier_state = VerifierState::::new(proof, *get_poseidon8())?; verifier_state.observe_scalars(public_input); verifier_state.observe_scalars(&poseidon8_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); let dims = verifier_state diff --git a/crates/lean_vm/src/tables/poseidon_8/mod.rs b/crates/lean_vm/src/tables/poseidon_8/mod.rs index 83e3d58d..3c141a1c 100644 --- a/crates/lean_vm/src/tables/poseidon_8/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_8/mod.rs @@ -88,7 +88,7 @@ fn mds_vec_mul(state: &[F; WIDTH]) -> [F; WIDTH] { for i in 0..WIDTH { let mut acc = state[0] * F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); for j in 1..WIDTH { - acc = acc + state[j] * F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); + acc += state[j] * F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); } out[i] = acc; } @@ -107,9 +107,9 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES let mut aux = Vec::with_capacity(AUX_COLS_PER_ROW); // Initial full rounds. 
- for round in 0..POSEIDON1_HALF_FULL_ROUNDS { - for i in 0..WIDTH { - state[i] = sbox7(state[i] + GOLDILOCKS_POSEIDON1_RC_8[round][i]); + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for (i, s) in state.iter_mut().enumerate() { + *s = sbox7(*s + rc[i]); } let post = mds_vec_mul(&state); for v in &post { @@ -119,17 +119,17 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES } // Partial phase: absorb first_round_constants, apply m_i, then sparse rounds. - for i in 0..WIDTH { - state[i] = state[i] + c.first_round_constants[i]; + for (i, s) in state.iter_mut().enumerate() { + *s += c.first_round_constants[i]; } { let mut after = [F::ZERO; WIDTH]; - for i in 0..WIDTH { + for (i, dst) in after.iter_mut().enumerate() { let mut acc = F::ZERO; - for j in 0..WIDTH { - acc = acc + c.m_i[i][j] * state[j]; + for (j, sj) in state.iter().enumerate() { + acc += c.m_i[i][j] * *sj; } - after[i] = acc; + *dst = acc; } state = after; } @@ -149,12 +149,12 @@ pub(crate) fn compute_poseidon8_witness(input: [F; WIDTH]) -> (Vec, [F; DIGES // new_state[i] = state[i] + v[r][i-1] · old_state[0] (for i ≥ 1) let old_s0 = state[0]; let mut new_s0 = F::ZERO; - for j in 0..WIDTH { - new_s0 = new_s0 + c.sparse_first_row[r][j] * state[j]; + for (j, sj) in state.iter().enumerate() { + new_s0 += c.sparse_first_row[r][j] * *sj; } state[0] = new_s0; - for i in 1..WIDTH { - state[i] = state[i] + c.v[r][i - 1] * old_s0; + for (i, s) in state.iter_mut().enumerate().skip(1) { + *s += c.v[r][i - 1] * old_s0; } } @@ -379,9 +379,9 @@ impl Air for Poseidon8Precompile { let post = full_posts[round]; for i in 0..WIDTH { let mut acc = sbox_out[0] * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); - for j in 1..WIDTH { + for (j, sj) in sbox_out.iter().enumerate().skip(1) { let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); - acc = acc + sbox_out[j] * coeff; + acc += *sj * coeff; } builder.assert_zero(post[i] - acc); } @@ 
-389,23 +389,23 @@ impl Air for Poseidon8Precompile { } // ---- Partial phase: first_round_constants, m_i, sparse-matmul loop ---- - for i in 0..WIDTH { - state[i] = state[i] + AB::F::from_u64(c.first_round_constants[i].as_canonical_u64()); + for (i, s) in state.iter_mut().enumerate() { + *s += AB::F::from_u64(c.first_round_constants[i].as_canonical_u64()); } { let mut after: [AB::IF; WIDTH] = std::array::from_fn(|i| { let mut acc = state[0] * AB::F::from_u64(c.m_i[i][0].as_canonical_u64()); - for j in 1..WIDTH { - acc = acc + state[j] * AB::F::from_u64(c.m_i[i][j].as_canonical_u64()); + for (j, sj) in state.iter().enumerate().skip(1) { + acc += *sj * AB::F::from_u64(c.m_i[i][j].as_canonical_u64()); } acc }); std::mem::swap(&mut state, &mut after); } - for r in 0..SPARSE_PARTIAL_ROUNDS { + for (r, post_sbox) in partial_post_sboxes.iter().enumerate().take(SPARSE_PARTIAL_ROUNDS) { let x = state[0]; - let post_sbox = partial_post_sboxes[r]; + let post_sbox = *post_sbox; // post_sbox = x⁷ (deg 7). let x2 = x * x; @@ -422,12 +422,12 @@ impl Air for Poseidon8Precompile { // cheap_matmul. 
let old_s0 = state[0]; let mut new_s0 = state[0] * AB::F::from_u64(c.sparse_first_row[r][0].as_canonical_u64()); - for j in 1..WIDTH { - new_s0 = new_s0 + state[j] * AB::F::from_u64(c.sparse_first_row[r][j].as_canonical_u64()); + for (j, sj) in state.iter().enumerate().skip(1) { + new_s0 += *sj * AB::F::from_u64(c.sparse_first_row[r][j].as_canonical_u64()); } state[0] = new_s0; - for i in 1..WIDTH { - state[i] = state[i] + old_s0 * AB::F::from_u64(c.v[r][i - 1].as_canonical_u64()); + for (i, s) in state.iter_mut().enumerate().skip(1) { + *s += old_s0 * AB::F::from_u64(c.v[r][i - 1].as_canonical_u64()); } } @@ -443,9 +443,9 @@ impl Air for Poseidon8Precompile { let post = full_posts[POSEIDON1_HALF_FULL_ROUNDS + round]; for i in 0..WIDTH { let mut acc = sbox_out[0] * AB::F::from_u64(MDS8_ROW[(WIDTH - i) % WIDTH] as u64); - for j in 1..WIDTH { + for (j, sj) in sbox_out.iter().enumerate().skip(1) { let coeff = AB::F::from_u64(MDS8_ROW[(j + WIDTH - i) % WIDTH] as u64); - acc = acc + sbox_out[j] * coeff; + acc += *sj * coeff; } builder.assert_zero(post[i] - acc); } diff --git a/crates/lean_vm/src/tables/poseidon_8/sparse.rs b/crates/lean_vm/src/tables/poseidon_8/sparse.rs index 4d486a6a..2ce825f3 100644 --- a/crates/lean_vm/src/tables/poseidon_8/sparse.rs +++ b/crates/lean_vm/src/tables/poseidon_8/sparse.rs @@ -78,7 +78,7 @@ fn matrix_mul(a: &[[F; WIDTH]; WIDTH], b: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; for j in 0..WIDTH { let mut acc = F::ZERO; for k in 0..WIDTH { - acc = acc + a[i][k] * b[k][j]; + acc += a[i][k] * b[k][j]; } c[i][j] = acc; } @@ -91,7 +91,7 @@ fn matrix_vec_mul(m: &[[F; WIDTH]; WIDTH], v: &[F; WIDTH]) -> [F; WIDTH] { for i in 0..WIDTH { let mut acc = F::ZERO; for j in 0..WIDTH { - acc = acc + m[i][j] * v[j]; + acc += m[i][j] * v[j]; } r[i] = acc; } @@ -101,8 +101,8 @@ fn matrix_vec_mul(m: &[[F; WIDTH]; WIDTH], v: &[F; WIDTH]) -> [F; WIDTH] { fn matrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { let mut aug = *m; let mut inv = [[F::ZERO; 
WIDTH]; WIDTH]; - for i in 0..WIDTH { - inv[i][i] = F::ONE; + for (i, row) in inv.iter_mut().enumerate().take(WIDTH) { + row[i] = F::ONE; } for col in 0..WIDTH { let pivot = (col..WIDTH) @@ -114,8 +114,8 @@ fn matrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { } let pivot_inv = aug[col][col].inverse(); for j in 0..WIDTH { - aug[col][j] = aug[col][j] * pivot_inv; - inv[col][j] = inv[col][j] * pivot_inv; + aug[col][j] *= pivot_inv; + inv[col][j] *= pivot_inv; } for i in 0..WIDTH { if i == col { @@ -128,8 +128,8 @@ fn matrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH]; WIDTH] { let aug_row = aug[col]; let inv_row = inv[col]; for j in 0..WIDTH { - aug[i][j] = aug[i][j] - factor * aug_row[j]; - inv[i][j] = inv[i][j] - factor * inv_row[j]; + aug[i][j] -= factor * aug_row[j]; + inv[i][j] -= factor * inv_row[j]; } } } @@ -146,8 +146,8 @@ fn submatrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH - 1]; WIDTH - 1] { } } let mut inv = [[F::ZERO; N]; N]; - for i in 0..N { - inv[i][i] = F::ONE; + for (i, row) in inv.iter_mut().enumerate().take(N) { + row[i] = F::ONE; } for col in 0..N { let pivot = (col..N) @@ -159,8 +159,8 @@ fn submatrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH - 1]; WIDTH - 1] { } let pivot_inv = sub[col][col].inverse(); for j in 0..N { - sub[col][j] = sub[col][j] * pivot_inv; - inv[col][j] = inv[col][j] * pivot_inv; + sub[col][j] *= pivot_inv; + inv[col][j] *= pivot_inv; } for i in 0..N { if i == col { @@ -173,23 +173,22 @@ fn submatrix_inverse(m: &[[F; WIDTH]; WIDTH]) -> [[F; WIDTH - 1]; WIDTH - 1] { let sub_row = sub[col]; let inv_row = inv[col]; for j in 0..N { - sub[i][j] = sub[i][j] - factor * sub_row[j]; - inv[i][j] = inv[i][j] - factor * inv_row[j]; + sub[i][j] -= factor * sub_row[j]; + inv[i][j] -= factor * inv_row[j]; } } } inv } +type EquivalentMatrices = ([[F; WIDTH]; WIDTH], Vec<[F; WIDTH]>, Vec<[F; WIDTH]>); + /// Factor the dense MDS matrix into `RP` sparse factors. 
/// /// Returns `(m_i, v_collection, w_hat_collection)` all in forward application /// order; `v_collection[r]` and `w_hat_collection[r]` have `WIDTH-1` meaningful /// entries (the last slot is zero padding for fixed-size arrays). -fn compute_equivalent_matrices( - mds: &[[F; WIDTH]; WIDTH], - rounds_p: usize, -) -> ([[F; WIDTH]; WIDTH], Vec<[F; WIDTH]>, Vec<[F; WIDTH]>) { +fn compute_equivalent_matrices(mds: &[[F; WIDTH]; WIDTH], rounds_p: usize) -> EquivalentMatrices { let mut v_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); let mut w_hat_collection: Vec<[F; WIDTH]> = Vec::with_capacity(rounds_p); @@ -213,7 +212,7 @@ fn compute_equivalent_matrices( if i < WIDTH - 1 { let mut acc = F::ZERO; for k in 0..WIDTH - 1 { - acc = acc + m_hat_inv[i][k] * w[k]; + acc += m_hat_inv[i][k] * w[k]; } acc } else { @@ -228,11 +227,11 @@ fn compute_equivalent_matrices( // "absorb" the rest: first column zeroed, first row zeroed, [0][0]=1. m_i = m_mul; m_i[0][0] = F::ONE; - for r in 1..WIDTH { - m_i[r][0] = F::ZERO; + for row in m_i.iter_mut().take(WIDTH).skip(1) { + row[0] = F::ZERO; } - for c in 1..WIDTH { - m_i[0][c] = F::ZERO; + for entry in m_i[0].iter_mut().take(WIDTH).skip(1) { + *entry = F::ZERO; } // Accumulate: m_mul = M^T * m_i. 
@@ -260,7 +259,7 @@ fn equivalent_round_constants(partial_rc: &[[F; WIDTH]], mds_inv: &[[F; WIDTH]; opt_partial_rc[i + 1] = inv_cip[0]; tmp = partial_rc[i]; for j in 1..WIDTH { - tmp[j] = tmp[j] + inv_cip[j]; + tmp[j] += inv_cip[j]; } } let first_round_constants = tmp; @@ -299,9 +298,7 @@ fn compute_partial_constants() -> PartialConstants { } let mut round_constants = [F::ZERO; PARTIAL_ROUNDS - 1]; - for i in 0..PARTIAL_ROUNDS - 1 { - round_constants[i] = round_constants_vec[i]; - } + round_constants[..(PARTIAL_ROUNDS - 1)].copy_from_slice(&round_constants_vec[..(PARTIAL_ROUNDS - 1)]); PartialConstants { first_round_constants, diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 0dd49a0d..61b47073 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -275,7 +275,7 @@ pub fn xmss_aggregate( let final_sumcheck_proof = { // Recover the transcript of the final sumcheck (for bytecode claim reduction) - let mut vs = VerifierState::::new(reduction_prover.into_proof(), get_poseidon8().clone()).unwrap(); + let mut vs = VerifierState::::new(reduction_prover.into_proof(), *get_poseidon8()).unwrap(); vs.next_base_scalars_vec(claims_hash.len()).unwrap(); let _: EF = vs.sample(); sumcheck_verify(&mut vs, bytecode_point_n_vars, 2, claimed_sum, None).unwrap(); diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index ec05c118..fd165194 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -129,8 +129,8 @@ pub fn wots_encode( const NUM_ENCODING_FE: usize = 3; const CHUNKS_PER_FE: usize = 21; const MASK: u64 = (1u64 << W) - 1; - debug_assert!(CHUNKS_PER_FE * W <= 63); // 1-bit remainder - debug_assert!(NUM_ENCODING_FE * CHUNKS_PER_FE >= V + V_GRINDING); + const _: () = assert!(CHUNKS_PER_FE * W <= 63); // 1-bit remainder + const _: () = assert!(NUM_ENCODING_FE * CHUNKS_PER_FE >= V + V_GRINDING); let mut all_indices = [0u8; NUM_ENCODING_FE * CHUNKS_PER_FE]; for (i, fe) in 
full_output.iter().take(NUM_ENCODING_FE).enumerate() { From 89a2dc54acaaef46dc0120c93dc7e24a3a8a233d Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sun, 26 Apr 2026 12:53:47 +0200 Subject: [PATCH 21/31] 2x faster poseidon --- crates/backend/goldilocks/src/poseidon1.rs | 207 +++++++++++++++++++-- 1 file changed, 189 insertions(+), 18 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 368d62d0..93a4aaa9 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -14,7 +14,7 @@ //! implements `InjectiveMonomial<7>`, mirroring the koala-bear crate's //! Poseidon1 surface. -use field::{Algebra, InjectiveMonomial, PrimeCharacteristicRing}; +use field::{Algebra, Field, InjectiveMonomial, PackedValue, PrimeCharacteristicRing}; use crate::Goldilocks; @@ -63,21 +63,45 @@ fn mds_mul_generic>(state: &mut [R; 8]) { } } -/// Specialized fast MDS for the concrete `Goldilocks` scalar — uses a single -/// `u128` accumulator and one `reduce128` per output lane (coefficients ≤ 9, -/// so `8 × 9 × 2^64 ≈ 2^71` fits comfortably). -#[inline] +/// Specialized fast MDS for the concrete `Goldilocks` scalar. +/// +/// Each output is a dot product `sum_j MDS_ROW[(j-i) mod 8] * state[j]` with +/// MDS coefficients in `{1, 3, 4, 7, 8, 9}` (all fit in 4 bits). With the +/// constants spelled out explicitly LLVM strength-reduces `c * s` to shifts +/// and adds (e.g. `8*s = s<<3`, `7*s = (s<<3)-s`), eliminating the variable +/// multiplications entirely. We accumulate into `u128` (8·9·2^64 ≈ 2^71 fits +/// comfortably) and reduce once per output via `reduce128`. The explicit +/// `1 *` factors keep the circulant structure readable column-by-column. 
+#[inline(always)] +#[allow(clippy::identity_op)] fn mds_mul_scalar(state: &mut [Goldilocks; 8]) { - let mut out = [Goldilocks::ZERO; 8]; - for i in 0..8 { - let mut acc: u128 = 0; - for j in 0..8 { - let c = MDS8_ROW[(j + 8 - i) % 8] as u128; - acc = acc.wrapping_add(c.wrapping_mul(state[j].value as u128)); - } - out[i] = crate::goldilocks::reduce128(acc); - } - *state = out; + let s0 = state[0].value as u128; + let s1 = state[1].value as u128; + let s2 = state[2].value as u128; + let s3 = state[3].value as u128; + let s4 = state[4].value as u128; + let s5 = state[5].value as u128; + let s6 = state[6].value as u128; + let s7 = state[7].value as u128; + + // MDS_ROW = [7, 1, 3, 8, 8, 3, 4, 9]; row i is MDS_ROW rotated right by i. + let acc0 = 7 * s0 + 1 * s1 + 3 * s2 + 8 * s3 + 8 * s4 + 3 * s5 + 4 * s6 + 9 * s7; + let acc1 = 9 * s0 + 7 * s1 + 1 * s2 + 3 * s3 + 8 * s4 + 8 * s5 + 3 * s6 + 4 * s7; + let acc2 = 4 * s0 + 9 * s1 + 7 * s2 + 1 * s3 + 3 * s4 + 8 * s5 + 8 * s6 + 3 * s7; + let acc3 = 3 * s0 + 4 * s1 + 9 * s2 + 7 * s3 + 1 * s4 + 3 * s5 + 8 * s6 + 8 * s7; + let acc4 = 8 * s0 + 3 * s1 + 4 * s2 + 9 * s3 + 7 * s4 + 1 * s5 + 3 * s6 + 8 * s7; + let acc5 = 8 * s0 + 8 * s1 + 3 * s2 + 4 * s3 + 9 * s4 + 7 * s5 + 1 * s6 + 3 * s7; + let acc6 = 3 * s0 + 8 * s1 + 8 * s2 + 3 * s3 + 4 * s4 + 9 * s5 + 7 * s6 + 1 * s7; + let acc7 = 1 * s0 + 3 * s1 + 8 * s2 + 8 * s3 + 3 * s4 + 4 * s5 + 9 * s6 + 7 * s7; + + state[0] = crate::goldilocks::reduce128(acc0); + state[1] = crate::goldilocks::reduce128(acc1); + state[2] = crate::goldilocks::reduce128(acc2); + state[3] = crate::goldilocks::reduce128(acc3); + state[4] = crate::goldilocks::reduce128(acc4); + state[5] = crate::goldilocks::reduce128(acc5); + state[6] = crate::goldilocks::reduce128(acc6); + state[7] = crate::goldilocks::reduce128(acc7); } // ========================================================================= @@ -423,6 +447,7 @@ impl Poseidon1Goldilocks8 { state } + #[inline] pub fn permute_mut(&self, state: &mut 
[Goldilocks; POSEIDON1_WIDTH]) { for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { for (i, s) in state.iter_mut().enumerate() { @@ -506,19 +531,128 @@ impl Poseidon1Goldilocks8 { /// Compression-mode in-place permutation: `output = permute(input) + input`. /// - /// Matches the koala-bear `Poseidon1Goldilocks8::compress_in_place` shape - /// so the `Compression<[R; 8]>` impl can reuse it. + /// When `R` matches the architecture's packed Goldilocks type, dispatches + /// to the SIMD-parallel path (deinterleave → per-lane scalar permute with + /// `u128`-accumulator MDS → reinterleave). When `R == Goldilocks`, uses the + /// scalar fast path (avoids the symbolic-friendly but slow `permute_generic`). + /// Otherwise falls back to the generic algebra path. #[inline] pub fn compress_in_place(&self, state: &mut [R; POSEIDON1_WIDTH]) where - R: Algebra + InjectiveMonomial<7> + Copy, + R: Algebra + InjectiveMonomial<7> + Copy + 'static, { + use core::any::TypeId; + + type Packing = ::Packing; + + if TypeId::of::() == TypeId::of::() { + // SAFETY: TypeId equality guarantees R has the same layout as Packing, + // and the array is repr-transparent as a slice of W*8 Goldilocks. + let s = unsafe { &mut *(state as *mut [R; POSEIDON1_WIDTH] as *mut [Packing; POSEIDON1_WIDTH]) }; + self.compress_in_place_simd(s); + return; + } + if TypeId::of::() == TypeId::of::() { + // SAFETY: TypeId equality. + let s = unsafe { &mut *(state as *mut [R; POSEIDON1_WIDTH] as *mut [Goldilocks; POSEIDON1_WIDTH]) }; + let initial = *s; + self.permute_mut(s); + for (slot, init) in s.iter_mut().zip(initial) { + *slot += init; + } + return; + } + let initial = *state; self.permute_generic(state); for (s, init) in state.iter_mut().zip(initial) { *s += init; } } + + /// SIMD-parallel compression over `::Packing`. 
+ /// + /// The packed type's `Mul` fully reduces each multiplication, but the MDS + /// coefficients are tiny (max 9) so a `u128` accumulator with a single + /// `reduce128` per output (`mds_mul_scalar`) is far cheaper. We deinterleave + /// to per-lane scalar arrays, then run the rounds in lockstep across all + /// W lanes (RC add → sbox → MDS, all lanes per step) so the OoO core sees + /// W independent chains at every stage. + #[inline] + fn compress_in_place_simd(&self, state: &mut [::Packing; POSEIDON1_WIDTH]) { + type P = ::Packing; + const W: usize =
<P as PackedValue>
::WIDTH; + + let mut lanes: [[Goldilocks; POSEIDON1_WIDTH]; W] = [[Goldilocks::ZERO; POSEIDON1_WIDTH]; W]; + for i in 0..POSEIDON1_WIDTH { + let s = state[i].as_slice(); + for (k, lane) in lanes.iter_mut().enumerate() { + lane[i] = s[k]; + } + } + let initial = lanes; + + // Initial full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } + } + for lane in lanes.iter_mut() { + for s in lane.iter_mut() { + *s = sbox_full::(*s); + } + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); + } + } + + // Partial rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .skip(POSEIDON1_HALF_FULL_ROUNDS) + .take(POSEIDON1_PARTIAL_ROUNDS) + { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } + } + for lane in lanes.iter_mut() { + lane[0] = sbox_full::(lane[0]); + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); + } + } + + // Terminal full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .take(POSEIDON1_N_ROUNDS) + .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) + { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } + } + for lane in lanes.iter_mut() { + for s in lane.iter_mut() { + *s = sbox_full::(*s); + } + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); + } + } + + for i in 0..POSEIDON1_WIDTH { + state[i] = P::from_fn(|k| lanes[k][i] + initial[k][i]); + } + } } /// Return the default width-8 Poseidon1 permutation. @@ -549,6 +683,43 @@ mod tests { assert_eq!(fast, slow); } + /// `compress_in_place::` must agree with per-lane scalar compression. + /// Exercises the SIMD dispatch branch. + #[test] + fn compress_in_place_dispatches_packed_correctly() { + type P = ::Packing; + let width =
<P as PackedValue>
::WIDTH; + let p = Poseidon1Goldilocks8; + + // Build distinct inputs per lane so we'd notice a swap or duplication. + let mut packed: [P; 8] = [P::ZERO; 8]; + for i in 0..8 { + packed[i] = + P::from_fn(|k| Goldilocks::new(0xa5a5_0000_0000_0001u64.wrapping_mul((i * 17 + k * 31 + 1) as u64))); + } + let initial = packed; + + // Reference: per-lane scalar compress. + let mut expected_lanes: Vec<[Goldilocks; 8]> = (0..width) + .map(|k| std::array::from_fn(|i| initial[i].as_slice()[k])) + .collect(); + for lane in expected_lanes.iter_mut() { + p.compress_in_place(lane); + } + + p.compress_in_place(&mut packed); + + for i in 0..8 { + for k in 0..width { + assert_eq!( + packed[i].as_slice()[k], + expected_lanes[k][i], + "mismatch at slot {i}, lane {k}" + ); + } + } + } + /// The permutation is deterministic and non-trivial. #[test] fn permutation_is_deterministic() { From 6efc06130223bdfbbb9c21695cb24a692cc7f28f Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sun, 26 Apr 2026 13:49:57 +0200 Subject: [PATCH 22/31] much faster poseidn on avx512 --- crates/backend/goldilocks/src/poseidon1.rs | 247 +++++++++++++----- .../backend/goldilocks/src/x86_64_avx2/mod.rs | 2 +- .../goldilocks/src/x86_64_avx2/packing.rs | 61 +++++ .../goldilocks/src/x86_64_avx512/mod.rs | 2 +- .../goldilocks/src/x86_64_avx512/packing.rs | 69 ++++- 5 files changed, 314 insertions(+), 67 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 93a4aaa9..63a92898 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -14,7 +14,15 @@ //! implements `InjectiveMonomial<7>`, mirroring the koala-bear crate's //! Poseidon1 surface. 
-use field::{Algebra, Field, InjectiveMonomial, PackedValue, PrimeCharacteristicRing}; +#[cfg(any( + test, + not(any( + all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")), + all(target_arch = "x86_64", target_feature = "avx512f"), + )), +))] +use field::PackedValue; +use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing}; use crate::Goldilocks; @@ -532,9 +540,8 @@ impl Poseidon1Goldilocks8 { /// Compression-mode in-place permutation: `output = permute(input) + input`. /// /// When `R` matches the architecture's packed Goldilocks type, dispatches - /// to the SIMD-parallel path (deinterleave → per-lane scalar permute with - /// `u128`-accumulator MDS → reinterleave). When `R == Goldilocks`, uses the - /// scalar fast path (avoids the symbolic-friendly but slow `permute_generic`). + /// to the SIMD-parallel path. When `R == Goldilocks`, uses the scalar fast + /// path (avoids the symbolic-friendly but slow `permute_generic`). /// Otherwise falls back to the generic algebra path. #[inline] pub fn compress_in_place(&self, state: &mut [R; POSEIDON1_WIDTH]) @@ -572,85 +579,154 @@ impl Poseidon1Goldilocks8 { /// SIMD-parallel compression over `::Packing`. /// - /// The packed type's `Mul` fully reduces each multiplication, but the MDS - /// coefficients are tiny (max 9) so a `u128` accumulator with a single - /// `reduce128` per output (`mds_mul_scalar`) is far cheaper. We deinterleave - /// to per-lane scalar arrays, then run the rounds in lockstep across all - /// W lanes (RC add → sbox → MDS, all lanes per step) so the OoO core sees - /// W independent chains at every stage. + /// On x86_64 (AVX2 or AVX512), keeps state in packed registers throughout + /// the rounds. RC adds and sboxes use the packed `Add`/`square`/`Mul` + /// (which fully reduce), and the MDS uses the dedicated `mds_mul_simd` + /// (delayed reduction via shift+add multiplication by tiny constants). + /// + /// On other architectures (e.g. 
aarch64+NEON, scalar fallback), we + /// deinterleave to per-lane scalar arrays and run the rounds in lockstep + /// across all W lanes. The MDS coefficients are tiny (max 9), so the + /// scalar `mds_mul_scalar` (u128 accumulator + single `reduce128` per + /// output) is far cheaper than the packed type's fully-reducing `Mul`. #[inline] fn compress_in_place_simd(&self, state: &mut [::Packing; POSEIDON1_WIDTH]) { - type P = ::Packing; - const W: usize =
<P as PackedValue>
::WIDTH; + #[cfg(any( + all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")), + all(target_arch = "x86_64", target_feature = "avx512f"), + ))] + { + type P = ::Packing; + + #[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] + use crate::x86_64_avx2::packing::mds_mul_simd; + #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] + use crate::x86_64_avx512::packing::mds_mul_simd; - let mut lanes: [[Goldilocks; POSEIDON1_WIDTH]; W] = [[Goldilocks::ZERO; POSEIDON1_WIDTH]; W]; - for i in 0..POSEIDON1_WIDTH { - let s = state[i].as_slice(); - for (k, lane) in lanes.iter_mut().enumerate() { - lane[i] = s[k]; + let initial = *state; + + // Initial full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for (i, s) in state.iter_mut().enumerate() { + *s += P::from(rc[i]); + } + for s in state.iter_mut() { + *s = sbox_full::

(*s); + } + mds_mul_simd(state); } - } - let initial = lanes; - // Initial full rounds. - for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { - for lane in lanes.iter_mut() { - for (i, s) in lane.iter_mut().enumerate() { - *s += rc[i]; + // Partial rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .skip(POSEIDON1_HALF_FULL_ROUNDS) + .take(POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += P::from(rc[i]); } + state[0] = sbox_full::

(state[0]); + mds_mul_simd(state); } - for lane in lanes.iter_mut() { - for s in lane.iter_mut() { - *s = sbox_full::(*s); + + // Terminal full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .take(POSEIDON1_N_ROUNDS) + .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) + { + for (i, s) in state.iter_mut().enumerate() { + *s += P::from(rc[i]); } + for s in state.iter_mut() { + *s = sbox_full::

(*s); + } + mds_mul_simd(state); } - for lane in lanes.iter_mut() { - mds_mul_scalar(lane); + + // Compression-mode add-back of the original input. + for (s, init) in state.iter_mut().zip(initial) { + *s += init; } } - // Partial rounds. - for rc in GOLDILOCKS_POSEIDON1_RC_8 - .iter() - .skip(POSEIDON1_HALF_FULL_ROUNDS) - .take(POSEIDON1_PARTIAL_ROUNDS) + #[cfg(not(any( + all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")), + all(target_arch = "x86_64", target_feature = "avx512f"), + )))] { - for lane in lanes.iter_mut() { - for (i, s) in lane.iter_mut().enumerate() { - *s += rc[i]; + type P = ::Packing; + const W: usize =

::WIDTH; + + let mut lanes: [[Goldilocks; POSEIDON1_WIDTH]; W] = [[Goldilocks::ZERO; POSEIDON1_WIDTH]; W]; + for i in 0..POSEIDON1_WIDTH { + let s = state[i].as_slice(); + for (k, lane) in lanes.iter_mut().enumerate() { + lane[i] = s[k]; } } - for lane in lanes.iter_mut() { - lane[0] = sbox_full::(lane[0]); - } - for lane in lanes.iter_mut() { - mds_mul_scalar(lane); + let initial = lanes; + + // Initial full rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8.iter().take(POSEIDON1_HALF_FULL_ROUNDS) { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } + } + for lane in lanes.iter_mut() { + for s in lane.iter_mut() { + *s = sbox_full::(*s); + } + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); + } } - } - // Terminal full rounds. - for rc in GOLDILOCKS_POSEIDON1_RC_8 - .iter() - .take(POSEIDON1_N_ROUNDS) - .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) - { - for lane in lanes.iter_mut() { - for (i, s) in lane.iter_mut().enumerate() { - *s += rc[i]; + // Partial rounds. + for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .skip(POSEIDON1_HALF_FULL_ROUNDS) + .take(POSEIDON1_PARTIAL_ROUNDS) + { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } } - } - for lane in lanes.iter_mut() { - for s in lane.iter_mut() { - *s = sbox_full::(*s); + for lane in lanes.iter_mut() { + lane[0] = sbox_full::(lane[0]); + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); } } - for lane in lanes.iter_mut() { - mds_mul_scalar(lane); + + // Terminal full rounds. 
+ for rc in GOLDILOCKS_POSEIDON1_RC_8 + .iter() + .take(POSEIDON1_N_ROUNDS) + .skip(POSEIDON1_HALF_FULL_ROUNDS + POSEIDON1_PARTIAL_ROUNDS) + { + for lane in lanes.iter_mut() { + for (i, s) in lane.iter_mut().enumerate() { + *s += rc[i]; + } + } + for lane in lanes.iter_mut() { + for s in lane.iter_mut() { + *s = sbox_full::(*s); + } + } + for lane in lanes.iter_mut() { + mds_mul_scalar(lane); + } } - } - for i in 0..POSEIDON1_WIDTH { - state[i] = P::from_fn(|k| lanes[k][i] + initial[k][i]); + for i in 0..POSEIDON1_WIDTH { + state[i] = P::from_fn(|k| lanes[k][i] + initial[k][i]); + } } } } @@ -683,6 +759,55 @@ mod tests { assert_eq!(fast, slow); } + /// SIMD MDS path must match the scalar MDS for arbitrary state. + #[cfg(any( + all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")), + all(target_arch = "x86_64", target_feature = "avx512f"), + ))] + #[test] + fn simd_mds_matches_scalar_mds() { + type P = ::Packing; + let width =

::WIDTH; + + // Build packed state with distinct per-lane values, including some + // u64s near the field-order boundary to stress the reduction. + let mut packed: [P; 8] = [P::ZERO; 8]; + let edges: [u64; 4] = [0, 1, crate::P - 1, u64::MAX]; + for i in 0..8 { + packed[i] = P::from_fn(|k| { + if k < 4 && i % 2 == 0 { + Goldilocks::new(edges[k]) + } else { + Goldilocks::new(0xa5a5_0000_0000_0001u64.wrapping_mul((i * 17 + k * 31 + 1) as u64)) + } + }); + } + let initial = packed; + + // Reference: per-lane scalar MDS. + let mut expected_lanes: Vec<[Goldilocks; 8]> = (0..width) + .map(|k| std::array::from_fn(|i| initial[i].as_slice()[k])) + .collect(); + for lane in expected_lanes.iter_mut() { + mds_mul_scalar(lane); + } + + #[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(target_feature = "avx512f")))] + crate::x86_64_avx2::packing::mds_mul_simd(&mut packed); + #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] + crate::x86_64_avx512::packing::mds_mul_simd(&mut packed); + + for i in 0..8 { + for k in 0..width { + assert_eq!( + packed[i].as_slice()[k], + expected_lanes[k][i], + "mismatch at slot {i}, lane {k}" + ); + } + } + } + /// `compress_in_place::` must agree with per-lane scalar compression. /// Exercises the SIMD dispatch branch. #[test] diff --git a/crates/backend/goldilocks/src/x86_64_avx2/mod.rs b/crates/backend/goldilocks/src/x86_64_avx2/mod.rs index 730a8675..4e8ba31a 100644 --- a/crates/backend/goldilocks/src/x86_64_avx2/mod.rs +++ b/crates/backend/goldilocks/src/x86_64_avx2/mod.rs @@ -1,5 +1,5 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
-mod packing; +pub(crate) mod packing; pub use packing::*; diff --git a/crates/backend/goldilocks/src/x86_64_avx2/packing.rs b/crates/backend/goldilocks/src/x86_64_avx2/packing.rs index 30ad75c6..f4edc707 100644 --- a/crates/backend/goldilocks/src/x86_64_avx2/packing.rs +++ b/crates/backend/goldilocks/src/x86_64_avx2/packing.rs @@ -365,6 +365,67 @@ fn square(x: __m256i) -> __m256i { reduce128(square64(x)) } +// ========================================================================= +// SIMD-vectorized Poseidon1 MDS multiplication +// ========================================================================= +// +// Computes the width-8 circulant MDS matrix-vector product entirely in +// `__m256i` registers, with delayed reduction. Each output is +// `sum_j MDS_ROW[(j-i) mod 8] * state[j]`. Coefficients are in +// {1, 3, 4, 7, 8, 9} (max 9), so per-term products fit in u68 and sums of +// 8 terms fit comfortably in u71. +// +// We multiply via two 32x32 `_mm256_mul_epu32` calls (low half and high +// half of state). Sums of the low and high halves are accumulated +// separately into u64s, then we assemble the (hi, lo) u128 pair and call +// `reduce128`. + +use crate::poseidon1::{MDS8_ROW, POSEIDON1_WIDTH}; + +/// SIMD MDS multiplication for the width-8 circulant Poseidon1 matrix. +#[inline] +pub(crate) fn mds_mul_simd(state: &mut [PackedGoldilocksAVX2; POSEIDON1_WIDTH]) { + unsafe { + let s: [__m256i; 8] = core::array::from_fn(|i| state[i].to_vector()); + // Precompute the high 32 bits of every state slot once. + let s_hi: [__m256i; 8] = core::array::from_fn(|i| _mm256_srli_epi64::<32>(s[i])); + + let mut out: [__m256i; 8] = [_mm256_setzero_si256(); 8]; + + for i in 0..8 { + let mut sum_ll = _mm256_setzero_si256(); + let mut sum_hl = _mm256_setzero_si256(); + for j in 0..8 { + // Row i is `MDS8_ROW` rotated right by i. 
+ let c = MDS8_ROW[(j + 8 - i) % 8]; + let c_vec = _mm256_set1_epi64x(c); + sum_ll = _mm256_add_epi64(sum_ll, _mm256_mul_epu32(s[j], c_vec)); + sum_hl = _mm256_add_epi64(sum_hl, _mm256_mul_epu32(s_hi[j], c_vec)); + } + + // Total = sum_ll + (sum_hl << 32). Compose into (hi, lo) u128. + // sum_ll < 2^39, sum_hl < 2^39 (so sum_hl >> 32 < 2^7). + let sum_hl_shifted = _mm256_slli_epi64::<32>(sum_hl); + let lo = _mm256_add_epi64(sum_ll, sum_hl_shifted); + // Detect unsigned overflow: lo < sum_hl_shifted iff the add wrapped. + // AVX2 has no native unsigned compare; XOR with sign bit to convert. + let lo_s = _mm256_xor_si256(lo, SIGN_BIT); + let sum_hl_shifted_s = _mm256_xor_si256(sum_hl_shifted, SIGN_BIT); + // Mask is all-ones in lanes where lo < sum_hl_shifted. + let carry_mask = _mm256_cmpgt_epi64(sum_hl_shifted_s, lo_s); + let hi_no_carry = _mm256_srli_epi64::<32>(sum_hl); + // mask = -1 on overflow, 0 otherwise. Subtracting -1 is +1. + let hi = _mm256_sub_epi64(hi_no_carry, carry_mask); + + out[i] = reduce128((hi, lo)); + } + + for i in 0..8 { + state[i] = PackedGoldilocksAVX2::from_vector(out[i]); + } + } +} + #[cfg(test)] mod tests { use super::{Goldilocks, PackedGoldilocksAVX2, WIDTH}; diff --git a/crates/backend/goldilocks/src/x86_64_avx512/mod.rs b/crates/backend/goldilocks/src/x86_64_avx512/mod.rs index 730a8675..4e8ba31a 100644 --- a/crates/backend/goldilocks/src/x86_64_avx512/mod.rs +++ b/crates/backend/goldilocks/src/x86_64_avx512/mod.rs @@ -1,5 +1,5 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). 
-mod packing; +pub(crate) mod packing; pub use packing::*; diff --git a/crates/backend/goldilocks/src/x86_64_avx512/packing.rs b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs index 1484e764..383ef6e4 100644 --- a/crates/backend/goldilocks/src/x86_64_avx512/packing.rs +++ b/crates/backend/goldilocks/src/x86_64_avx512/packing.rs @@ -165,10 +165,10 @@ const EPSILON: __m512i = unsafe { transmute([Goldilocks::ORDER_U64.wrapping_neg( #[inline] unsafe fn canonicalize(x: __m512i) -> __m512i { - unsafe { - let mask = _mm512_cmpge_epu64_mask(x, FIELD_ORDER); - _mm512_mask_sub_epi64(x, mask, x, FIELD_ORDER) - } + // For `x < ORDER`, `x - ORDER` underflows to a huge u64, so `min` picks the + // original. For `x >= ORDER`, `x - ORDER` is the canonical form (smaller), + // so `min` picks it. One sub + one min instead of cmpge + masked sub. + unsafe { _mm512_min_epu64(x, _mm512_sub_epi64(x, FIELD_ORDER)) } } /// Compute `x + y mod P`. Result may be > P. @@ -307,6 +307,67 @@ fn square(x: __m512i) -> __m512i { reduce128(square64(x)) } +// ========================================================================= +// SIMD-vectorized Poseidon1 MDS multiplication +// ========================================================================= +// +// Computes the width-8 circulant MDS matrix-vector product entirely in +// `__m512i` registers, with delayed reduction. Each output is +// `sum_j MDS_ROW[(j-i) mod 8] * state[j]`. Coefficients are in +// {1, 3, 4, 7, 8, 9} (max 9), so per-term products fit in u68 and sums of +// 8 terms fit comfortably in u71. +// +// We multiply via two 32x32 `_mm512_mul_epu32` calls (low half and high +// half of state), which exploits that the constants fit in 4 bits (so the +// "high 32 bits" operand of mul_epu32 is zero by construction). Sums of +// the low and high halves are accumulated separately into u64s, then we +// assemble the (hi, lo) u128 pair and call `reduce128`. 
+ +use crate::poseidon1::{MDS8_ROW, POSEIDON1_WIDTH}; + +/// SIMD MDS multiplication for the width-8 circulant Poseidon1 matrix. +#[inline] +pub(crate) fn mds_mul_simd(state: &mut [PackedGoldilocksAVX512; POSEIDON1_WIDTH]) { + unsafe { + let s: [__m512i; 8] = core::array::from_fn(|i| state[i].to_vector()); + // Precompute the high 32 bits of every state slot once. + let s_hi: [__m512i; 8] = core::array::from_fn(|i| _mm512_srli_epi64::<32>(s[i])); + + // For each output i, accumulate the dot product. With the inner loop + // bound by `8` and `MDS8_ROW` const, LLVM can fully unroll and fold + // the coefficient lookups. + let mut out: [__m512i; 8] = [_mm512_setzero_si512(); 8]; + + for i in 0..8 { + let mut sum_ll = _mm512_setzero_si512(); + let mut sum_hl = _mm512_setzero_si512(); + for j in 0..8 { + // Row i is `MDS8_ROW` rotated right by i, i.e. coefficient for + // `state[j]` in output `i` is `MDS8_ROW[(j + 8 - i) % 8]`. + let c = MDS8_ROW[(j + 8 - i) % 8]; + let c_vec = _mm512_set1_epi64(c); + sum_ll = _mm512_add_epi64(sum_ll, _mm512_mul_epu32(s[j], c_vec)); + sum_hl = _mm512_add_epi64(sum_hl, _mm512_mul_epu32(s_hi[j], c_vec)); + } + + // Total value = sum_ll + (sum_hl << 32). Compose into (hi, lo) u128. + // sum_ll < 2^39, sum_hl < 2^39, so sum_hl >> 32 < 2^7. + let sum_hl_shifted = _mm512_slli_epi64::<32>(sum_hl); + let lo = _mm512_add_epi64(sum_ll, sum_hl_shifted); + // Detect unsigned overflow: lo < sum_hl_shifted iff the add wrapped. 
+ let carry_mask = _mm512_cmplt_epu64_mask(lo, sum_hl_shifted); + let hi_no_carry = _mm512_srli_epi64::<32>(sum_hl); + let hi = _mm512_mask_add_epi64(hi_no_carry, carry_mask, hi_no_carry, _mm512_set1_epi64(1)); + + out[i] = reduce128((hi, lo)); + } + + for i in 0..8 { + state[i] = PackedGoldilocksAVX512::from_vector(out[i]); + } + } +} + #[cfg(test)] mod tests { use super::{Goldilocks, PackedGoldilocksAVX512, WIDTH}; From c308fb6e31308d0e454d395fed9bf39f5316c1c8 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 09:58:10 +0200 Subject: [PATCH 23/31] w --- crates/rec_aggregation/xmss_aggregate.py | 57 ++++++++++++++---------- crates/xmss/src/wots.rs | 34 ++++++++------ 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 6de66f22..ef0458e3 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -17,10 +17,12 @@ # `[leading_0 | tip_a(2) | tip_b(2) | trailing_0]` so that copy_ef can be used on # both halves under Goldilocks (DIM = 3 = 1 + XMSS_DIGEST_LEN). WOTS_PK_PAIR_STRIDE = 2 + 2 * XMSS_DIGEST_LEN -# Goldilocks encoding: 21 chunks of W bits per FE, with a 1-bit canonical check -# (factored as (diff)·(diff − 2^63) == 0). 2 FE × 21 chunks = 42 = V chunks. -NUM_ENCODING_FE = 2 -CHUNKS_PER_FE = 21 +# Goldilocks encoding: low 32 bits of each of 4 output FE are concatenated into +# a 128-bit pool. V × W = 42 × 3 = 126 bits are used as Winternitz indices via +# chunks that straddle FE boundaries; the top 2 bits of the pool are forced to +# zero, giving 128-bit collision security on the (message → encoding) map. 
+NUM_ENCODING_FE = 4 +LOW_BITS_PER_ENCODING_FE = 32 MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER N_MERKLE_CHUNKS = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK INNER_PUB_MEM_SIZE = 2**INNER_PUBLIC_MEMORY_LOG_SIZE # = DIGEST_LEN @@ -63,25 +65,34 @@ def xmss_verify(pub_key, message, merkle_chunks): encoding_fe = Array(DIGEST_LEN) poseidon8_compress(pre_compressed, public_params_paded, encoding_fe) - # Decompose first NUM_ENCODING_FE (=2) output FE into 21 W-bit chunks each. - # 2 × 21 = 42 = V chunks; per-FE canonical check is the 1-bit slack form - # (diff)·(diff − 2^63) == 0 from utils.checked_decompose_bits. - encoding = Array(NUM_ENCODING_FE * CHUNKS_PER_FE) - hint_decompose_bits_xmss(encoding, encoding_fe, NUM_ENCODING_FE, CHUNKS_PER_FE, W) - - # Each chunk must be a valid W-bit Winternitz index. - for i in unroll(0, NUM_ENCODING_FE * CHUNKS_PER_FE): - assert encoding[i] < CHAIN_LENGTH - - # For each FE: partial_sum = Σ_j encoding[i*K+j] * 2^(W*j) is the low 63 bits; - # the remainder is a single bit. Factorise the equality so no inverse is needed: - # (encoding_fe[i] − partial_sum) · (encoding_fe[i] − partial_sum − 2^63) == 0 - for i in unroll(0, NUM_ENCODING_FE): - partial_sum: Mut = encoding[i * CHUNKS_PER_FE] - for j in unroll(1, CHUNKS_PER_FE): - partial_sum += encoding[i * CHUNKS_PER_FE + j] * (CHAIN_LENGTH ** j) - diff = encoding_fe[i] - partial_sum - assert diff * (diff - 2**63) == 0 + # Decompose each output FE into 64 canonical bits (checked_decompose_bits + # enforces a == low + 2^32 · high mod p with the unique canonical form). + bits_0, _ = checked_decompose_bits(encoding_fe[0]) + bits_1, _ = checked_decompose_bits(encoding_fe[1]) + bits_2, _ = checked_decompose_bits(encoding_fe[2]) + bits_3, _ = checked_decompose_bits(encoding_fe[3]) + + # Build V = 42 Winternitz indices from the 128-bit stream + # bit p ∈ [0, 128) = bits_(p // 32)[p % 32] + # via straddling at FE0/FE1 (chunk 10) and FE1/FE2 (chunk 21) boundaries. 
+ # Each chunk is `b0 + 2·b1 + 4·b2` of 3 bool bits, so it is automatically + # a valid W-bit Winternitz index in [0, 8). + encoding = Array(V) + for k in unroll(0, 10): + encoding[k] = bits_0[3*k] + bits_0[3*k + 1] * 2 + bits_0[3*k + 2] * 4 + encoding[10] = bits_0[30] + bits_0[31] * 2 + bits_1[0] * 4 + for k in unroll(11, 21): + encoding[k] = bits_1[3*k - 32] + bits_1[3*k - 31] * 2 + bits_1[3*k - 30] * 4 + encoding[21] = bits_1[31] + bits_2[0] * 2 + bits_2[1] * 4 + for k in unroll(22, 32): + encoding[k] = bits_2[3*k - 64] + bits_2[3*k - 63] * 2 + bits_2[3*k - 62] * 4 + for k in unroll(32, 42): + encoding[k] = bits_3[3*k - 96] + bits_3[3*k - 95] * 2 + bits_3[3*k - 94] * 4 + + # Top 2 bits of the 128-bit stream must be zero — adds 2 bits of constrained + # entropy on top of V·W = 126 chunk bits to reach 128-bit security. + assert bits_3[30] == 0 + assert bits_3[31] == 0 debug_assert(V % 2 == 0) wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 72831925..216b0598 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -149,14 +149,13 @@ pub fn find_randomness_for_wots_encoding( } } -// Each encoding FE is decomposed into `CHUNKS_PER_FE` chunks of `W` bits. -// W = 3, CHUNKS_PER_FE = 21 → 63 bits used per Goldilocks element, 1-bit slack -// (verifier asserts it's 0 or 1, factored as (diff)·(diff − 2^63) = 0). -// Total chunks across all 4 output FE: 4 × 21 = 84, of which the first V = 42 -// are used as Winternitz indices. -const CHUNKS_PER_FE: usize = 21; -const _: () = assert!(CHUNKS_PER_FE * W <= 63); // 1-bit slack -const _: () = assert!(DIGEST_LEN_FE * CHUNKS_PER_FE >= V); +// Encoding stream: low 32 bits of each output FE are concatenated into a +// 128-bit pool that straddles FE boundaries. 
V = 42 chunks of W = 3 bits = 126 +// bits are used as Winternitz indices; the remaining 2 bits at the top of the +// stream are forced to zero, giving 128-bit collision security on the +// (message → encoding) map. +const LOW_BITS_PER_ENCODING_FE: usize = 32; +const _: () = assert!(DIGEST_LEN_FE * LOW_BITS_PER_ENCODING_FE == V * W + 2); pub fn wots_encode( message: &[F; MESSAGE_LEN_FE], @@ -174,14 +173,21 @@ pub fn wots_encode( second_input_right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(&xmss_pub_key.public_param); let compressed = poseidon8_compress_pair(&pre_compressed, &second_input_right); - let mut all_indices = [0u8; DIGEST_LEN_FE * CHUNKS_PER_FE]; + let mut stream: u128 = 0; for (i, g) in compressed.iter().enumerate() { - let value = g.as_canonical_u64(); - for j in 0..CHUNKS_PER_FE { - all_indices[i * CHUNKS_PER_FE + j] = ((value >> (j * W)) & ((1u64 << W) - 1)) as u8; - } + let low = g.as_canonical_u64() & ((1u64 << LOW_BITS_PER_ENCODING_FE) - 1); + stream |= u128::from(low) << (LOW_BITS_PER_ENCODING_FE * i); + } + if (stream >> (V * W)) != 0 { + // The 2 high bits of the 128-bit stream must be zero — the signer + // brute-forces randomness until this holds (≈ 2 grinding bits). 
+ return None; + } + + let mut used = [0u8; V]; + for (j, slot) in used.iter_mut().enumerate() { + *slot = ((stream >> (j * W)) & ((1u128 << W) - 1)) as u8; } - let used: [u8; V] = all_indices[..V].try_into().unwrap(); is_valid_encoding(&used).then_some(used) } From 0470d7a276088d5f8a30245a4fab6bad37040cae Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 12:14:31 +0200 Subject: [PATCH 24/31] better encoding --- README.md | 8 ++-- crates/lean_compiler/zkDSL.md | 2 +- crates/lean_vm/src/isa/hint.rs | 44 +++++++++---------- crates/rec_aggregation/xmss_aggregate.py | 54 ++++++++++-------------- crates/xmss/src/lib.rs | 6 ++- crates/xmss/src/wots.rs | 39 +++++++---------- crates/xmss/xmss.md | 31 +++++++------- 7 files changed, 84 insertions(+), 100 deletions(-) diff --git a/README.md b/README.md index e46ecedd..e14de1b1 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Machine: M4 Max 48GB (CPU only) ### XMSS aggregation ```bash -cargo run --release -- xmss --n-signatures 1500 --log-inv-rate 1 +cargo run --release -- xmss --n-signatures 1550 --log-inv-rate 1 ``` | WHIR rate | Proven Regime | Proximity Gaps Conjecture | @@ -40,7 +40,7 @@ cargo run --release -- xmss --n-signatures 1500 --log-inv-rate 1 ### Recursion -Aggregating together n previously aggregated signatures, each containing 700 XMSS. +Aggregating together n previously aggregated signatures, each containing 775 XMSS. ```bash @@ -81,9 +81,7 @@ cargo run --release -- fancy-aggregation ### XMSS -Currently, we use an [XMSS](crates/xmss/xmss.md) with hash digests of 4 field elements ≈ 124 bits. Tweaks and public parameters ensure domain separation. An analysis in the ROM (resp. QROM), inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103) would lead to ≈ 124 (resp. 62) bits of classical (resp. quantum) security. Going to 128 / 64 bits of classical / quantum security, i.e. NIST level 1 (in the ROM/QROM), is an ongoing effort. 
It requires either: -- hash digests of 5 field elements (drawback: we need to double the hash chain length from 8 to 16 if we want to stay below one IPv6 MTU = 1280 bytes) -- a new prime, close to 32 bits (typically p = 125.2^25 + 1) or 64 bits ([goldilocks](https://2π.com/22/goldilocks/)) +Currently, we use an [XMSS](crates/xmss/xmss.md) with hash digests of 2 goldilocks ≈ 128 bits. Tweaks and public parameters ensure domain separation. An analysis in the ROM (resp. QROM), inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103) would lead to ≈ 128 (resp. 64) bits of classical (resp. quantum) security, i.e. NIST level 1. It's important to mention that a security analysis in the ROM / QROM is not the most conservative. In particular, [eprint 2025/055](https://eprint.iacr.org/2025/055.pdf)'s security proof holds in the standard model (at the cost of bigger hash digests): the implementation is available in the [leanSig](https://github.com/leanEthereum/leanSig) repository. A compatible version of leanMultisig can be found in the [devnet4](https://github.com/leanEthereum/leanMultisig/tree/devnet4) branch. diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 682d8909..f3dad32d 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -456,7 +456,7 @@ hints = prover-supplied values at runtime (without adding snark constraints). 
Li | `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian | | `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` | | `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` | -| `hint_decompose_bits_xmss` | `(decomposed_ptr, remaining_ptr, to_decompose_ptr, num_to_decompose, chunk_size)` | XMSS-specific decomposition (see `crates/lean_vm/src/isa/hint.rs`) | +| `hint_decompose_bits_xmss` | `(chunks_ptr, limbs_ptr, src_value)` | WOTS-encoding decomposition of one Goldilocks FE: 10 W-bit chunks of the low 30 bits at `chunks_ptr[0..10]` + 2 u16 limbs of the high 32 bits at `limbs_ptr[0..2]` (the top 2 bits of the low limb are implicit zeros — see `crates/lean_vm/src/isa/hint.rs`) | | `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition | Hints only *suggest* a value; the guest must add appropriate constraints to bind that value to its specification. diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 365f727a..8630d180 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -101,11 +101,10 @@ impl HintWitnessDestination { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum CustomHint { - // Decompose values into their custom representations: - /// each field element x is decomposed to: (a0, a1, a2, ..., a11, b) where: - /// x = a0 + a1.4 + a2.4^2 + a3.4^3 + ... + a11.4^11 + b.2^24 - /// and ai < 4, b < 2^7 - 1 - /// The decomposition is unique, and always exists (except for x = -1) + /// WOTS-encoding decomposition of one Goldilocks FE. + /// Args: (chunks_ptr, limbs_ptr, src_value). 
+ /// Writes 10 W=3-bit chunks of the low 30 bits to `chunks_ptr[0..10]` + /// and 2 u16 limbs of the high 32 bits to `limbs_ptr[0..2]`. DecomposeBitsXMSS, DecomposeBitsMerkleWhir, DecomposeBits, @@ -134,7 +133,7 @@ impl CustomHint { pub fn n_args(&self) -> usize { match self { - Self::DecomposeBitsXMSS => 5, + Self::DecomposeBitsXMSS => 3, Self::DecomposeBitsMerkleWhir => 4, Self::DecomposeBits => 4, Self::LessThan => 3, @@ -149,25 +148,22 @@ impl CustomHint { ) -> Result<(), RunnerError> { match self { Self::DecomposeBitsXMSS => { - // Decompose `num_fe` field elements into `chunks_per_fe` chunks of - // `chunk_size` bits each. Extracts the low `chunks_per_fe * chunk_size` - // bits of each FE's canonical u64 representation (caller is responsible - // for ensuring `chunks_per_fe * chunk_size <= F::bits()`). - let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); - let to_decompose_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); - let num_fe = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); - let chunks_per_fe = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); - let chunk_size = args[4].read_value(ctx.memory, ctx.fp)?.to_usize(); - assert!(chunks_per_fe * chunk_size <= F::bits()); - let mut memory_index_decomposed = decomposed_ptr; - for i in 0..num_fe { - let value = ctx.memory.get(to_decompose_ptr + i)?.as_canonical_u64(); - for j in 0..chunks_per_fe { - let chunk = F::from_u64((value >> (chunk_size * j)) & ((1u64 << chunk_size) - 1)); - ctx.memory.set(memory_index_decomposed, chunk)?; - memory_index_decomposed += 1; - } + // WOTS-encoding decomposition. Writes: + // chunks_ptr[0..10] = 10 chunks of W=3 bits (low bits 0..29) + // limbs_ptr[0..2] = 2 u16 limbs of the high 32 bits (bits 32..47, 48..63) + // The 2 high bits of the low limb are implicit zeros, enforced by + // the SNARK constraint structure (and rejected at signing time). 
+ let chunks_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); + let limbs_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); + let value = args[2].read_value(ctx.memory, ctx.fp)?.as_canonical_u64(); + const NUM_CHUNKS: usize = 10; + const CHUNK_SIZE: usize = 3; + for j in 0..NUM_CHUNKS { + let chunk = (value >> (CHUNK_SIZE * j)) & ((1u64 << CHUNK_SIZE) - 1); + ctx.memory.set(chunks_ptr + j, F::from_u64(chunk))?; } + ctx.memory.set(limbs_ptr, F::from_u64((value >> 32) & 0xFFFF))?; + ctx.memory.set(limbs_ptr + 1, F::from_u64((value >> 48) & 0xFFFF))?; } Self::DecomposeBitsMerkleWhir => { // Decompose a single FE's canonical u64 into `num_chunks` chunks of diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index ef0458e3..cc8f7bbe 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -17,10 +17,6 @@ # `[leading_0 | tip_a(2) | tip_b(2) | trailing_0]` so that copy_ef can be used on # both halves under Goldilocks (DIM = 3 = 1 + XMSS_DIGEST_LEN). WOTS_PK_PAIR_STRIDE = 2 + 2 * XMSS_DIGEST_LEN -# Goldilocks encoding: low 32 bits of each of 4 output FE are concatenated into -# a 128-bit pool. V × W = 42 × 3 = 126 bits are used as Winternitz indices via -# chunks that straddle FE boundaries; the top 2 bits of the pool are forced to -# zero, giving 128-bit collision security on the (message → encoding) map. NUM_ENCODING_FE = 4 LOW_BITS_PER_ENCODING_FE = 32 MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER @@ -65,34 +61,9 @@ def xmss_verify(pub_key, message, merkle_chunks): encoding_fe = Array(DIGEST_LEN) poseidon8_compress(pre_compressed, public_params_paded, encoding_fe) - # Decompose each output FE into 64 canonical bits (checked_decompose_bits - # enforces a == low + 2^32 · high mod p with the unique canonical form). 
- bits_0, _ = checked_decompose_bits(encoding_fe[0]) - bits_1, _ = checked_decompose_bits(encoding_fe[1]) - bits_2, _ = checked_decompose_bits(encoding_fe[2]) - bits_3, _ = checked_decompose_bits(encoding_fe[3]) - - # Build V = 42 Winternitz indices from the 128-bit stream - # bit p ∈ [0, 128) = bits_(p // 32)[p % 32] - # via straddling at FE0/FE1 (chunk 10) and FE1/FE2 (chunk 21) boundaries. - # Each chunk is `b0 + 2·b1 + 4·b2` of 3 bool bits, so it is automatically - # a valid W-bit Winternitz index in [0, 8). encoding = Array(V) - for k in unroll(0, 10): - encoding[k] = bits_0[3*k] + bits_0[3*k + 1] * 2 + bits_0[3*k + 2] * 4 - encoding[10] = bits_0[30] + bits_0[31] * 2 + bits_1[0] * 4 - for k in unroll(11, 21): - encoding[k] = bits_1[3*k - 32] + bits_1[3*k - 31] * 2 + bits_1[3*k - 30] * 4 - encoding[21] = bits_1[31] + bits_2[0] * 2 + bits_2[1] * 4 - for k in unroll(22, 32): - encoding[k] = bits_2[3*k - 64] + bits_2[3*k - 63] * 2 + bits_2[3*k - 62] * 4 - for k in unroll(32, 42): - encoding[k] = bits_3[3*k - 96] + bits_3[3*k - 95] * 2 + bits_3[3*k - 94] * 4 - - # Top 2 bits of the 128-bit stream must be zero — adds 2 bits of constrained - # entropy on top of V·W = 126 chunk bits to reach 128-bit security. 
- assert bits_3[30] == 0 - assert bits_3[31] == 0 + for i in unroll(0, NUM_ENCODING_FE): + decompose_encoding_fe(encoding_fe[i], encoding + i * (V / NUM_ENCODING_FE)) debug_assert(V % 2 == 0) wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) @@ -201,6 +172,27 @@ def chain_hash_b(input, n, output, chain_i_tweaks, chain_right, chain_length_ptr return +@inline +def decompose_encoding_fe(fe_value, chunks_ptr): + limbs = Array(2) + hint_decompose_bits_xmss(chunks_ptr, limbs, fe_value) + + for k in unroll(0, 10): + assert chunks_ptr[k] < CHAIN_LENGTH + assert limbs[0] < 2**16 + assert limbs[1] < 2**16 + + low: Mut = chunks_ptr[0] + for k in unroll(1, 10): + low += chunks_ptr[k] * (2 ** (W * k)) + + high = limbs[0] + limbs[1] * (2**16) + assert fe_value == low + (2**32) * high + assert high != 2**32 - 1 # ensures uniformity + prevents overflow + + return + + @inline def wots_pk_hash(wots_public_key, public_param): # T-Sponge with replacement: IV = poseidon8([tweak(1)|0|pp(2)], zeros) diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 2d5a9916..09bbb4b9 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -17,11 +17,15 @@ type PublicParam = [F; PUBLIC_PARAM_LEN_FE]; type Randomness = [F; RANDOMNESS_LEN_FE]; // WOTS -pub const V: usize = 42; +pub const V: usize = 40; pub const W: usize = 3; pub const CHAIN_LENGTH: usize = 1 << W; pub const NUM_CHAIN_HASHES: usize = 110; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; +pub const ENCODING_NUM_FINAL_ZEROS: usize = 8; +const _: () = assert!(V * W + ENCODING_NUM_FINAL_ZEROS == DIGEST_LEN_FE * 32); +const _: () = assert!(V.is_multiple_of(DIGEST_LEN_FE)); // V chunks split evenly across the 4 FEs +const _: () = assert!(ENCODING_NUM_FINAL_ZEROS.is_multiple_of(DIGEST_LEN_FE)); // same for the zero bits pub const RANDOMNESS_LEN_FE: usize = 3; pub const MESSAGE_LEN_FE: usize = 4; pub const PUBLIC_PARAM_LEN_FE: usize = 2; diff --git a/crates/xmss/src/wots.rs 
b/crates/xmss/src/wots.rs index 216b0598..9bb623b5 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -149,14 +149,6 @@ pub fn find_randomness_for_wots_encoding( } } -// Encoding stream: low 32 bits of each output FE are concatenated into a -// 128-bit pool that straddles FE boundaries. V = 42 chunks of W = 3 bits = 126 -// bits are used as Winternitz indices; the remaining 2 bits at the top of the -// stream are forced to zero, giving 128-bit collision security on the -// (message → encoding) map. -const LOW_BITS_PER_ENCODING_FE: usize = 32; -const _: () = assert!(DIGEST_LEN_FE * LOW_BITS_PER_ENCODING_FE == V * W + 2); - pub fn wots_encode( message: &[F; MESSAGE_LEN_FE], slot: u32, @@ -173,22 +165,23 @@ pub fn wots_encode( second_input_right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(&xmss_pub_key.public_param); let compressed = poseidon8_compress_pair(&pre_compressed, &second_input_right); - let mut stream: u128 = 0; - for (i, g) in compressed.iter().enumerate() { - let low = g.as_canonical_u64() & ((1u64 << LOW_BITS_PER_ENCODING_FE) - 1); - stream |= u128::from(low) << (LOW_BITS_PER_ENCODING_FE * i); - } - if (stream >> (V * W)) != 0 { - // The 2 high bits of the 128-bit stream must be zero — the signer - // brute-forces randomness until this holds (≈ 2 grinding bits). 
- return None; - } - - let mut used = [0u8; V]; - for (j, slot) in used.iter_mut().enumerate() { - *slot = ((stream >> (j * W)) & ((1u128 << W) - 1)) as u8; + // Per-FE decomposition: each output FE contributes V/DIGEST_LEN_FE + // = 10 W-bit chunks from the low 30 bits of its low limb; the top 2 bits + // of each FE's low limb must be zero (ENCODING_NUM_FINAL_ZEROS = 8 bits + // total, evenly distributed = 2 per FE) + const CHUNKS_PER_FE: usize = V / DIGEST_LEN_FE; + const CHUNK_BITS_PER_FE: usize = CHUNKS_PER_FE * W; + let mut all_indices = [0u8; V]; + for (i, fe) in compressed.iter().enumerate() { + let low = fe.as_canonical_u64() & ((1u64 << 32) - 1); + if (low >> CHUNK_BITS_PER_FE) != 0 { + return None; + } + for j in 0..CHUNKS_PER_FE { + all_indices[i * CHUNKS_PER_FE + j] = ((low >> (W * j)) & ((1u64 << W) - 1)) as u8; + } } - is_valid_encoding(&used).then_some(used) + is_valid_encoding(&all_indices).then_some(all_indices) } fn is_valid_encoding(encoding: &[u8]) -> bool { diff --git a/crates/xmss/xmss.md b/crates/xmss/xmss.md index 1a03e7d4..febb8e14 100644 --- a/crates/xmss/xmss.md +++ b/crates/xmss/xmss.md @@ -2,25 +2,26 @@ ## Field -KoalaBear (p = 2^31 - 2^24 + 1). +Goldilocks (p = 2^64 - 2^32 + 1). ## Hash function -[Poseidon](https://eprint.iacr.org/2019/458), in compression mode (feedforward addition). Input: 16 field elements. Output: 8 field elements. We denote it `H`. Chain hashes, Merkle hashes, and the final WOTS-pubkey hash truncate the output to 4 field elements (`n`); the encoding step and the intermediate WOTS-pubkey sponge states keep the full 8 elements. +[Poseidon](https://eprint.iacr.org/2019/458), in compression mode (feedforward addition). Input: 8 field elements. Output: 4 field elements. We denote it `H`. Chain hashes, Merkle hashes, and the final WOTS-pubkey hash truncate the output to 2 field elements (`n`); the encoding step and the intermediate WOTS-pubkey sponge states keep the full 4 elements. 
## Sizes (in field elements) -- `n = 4`: digest size -- `|pp| = 4`: public parameter -- `|randomness| = 6`: signature randomness -- `|msg| = 8`: message size -- `|tweak| = 2`: tweak (domain separation: `encoding`, `chain`, `wots_pk`, `merkle`) +- `n = 2`: digest size +- `|pp| = 2`: public parameter +- `|randomness| = 3`: signature randomness +- `|msg| = 4`: message size +- `|tweak| = 1`: tweak (domain separation: `encoding`, `chain`, `wots_pk`, `merkle`) ## WOTS (Winternitz One Time Signature) -- `v = 42`: number of hash chains +- `v = 40`: number of hash chains - `w = 3`, `chain_length = 2^w = 8` -- `target_sum = 184`: a WOTS encoding `(e_0, ..., e_{v-1})` is valid iff each `e_i < chain_length` and `sum(e_i) = target_sum`. The signer grinds `randomness` until the encoding is valid (avoids checksum chains). +- `target_sum = 170`: a WOTS encoding `(e_0, ..., e_{v-1})` is valid iff each `e_i < chain_length` and `sum(e_i) = target_sum`. The signer grinds `randomness` until the encoding is valid (avoids checksum chains). +- `encoding_num_final_zeros = 8`: the 2 high bits of the low 32-bit half of each of the 4 encoding-digest limbs must be zero (2 zero bits per limb × 4 limbs). ## XMSS @@ -30,19 +31,19 @@ KoalaBear (p = 2^31 - 2^24 + 1). Inputs: public key `(merkle_root, pp)`, message `msg`, slot `s`, signature `(randomness, chain_tips, merkle_proof)`. -1. **Encode**: compute the 8-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 0000)`. Reject if any limb of `D` equals `-1` (for a uniform sampling). For each limb, take its canonical representative's low 24 bits in little-endian order, concatenate to get 192 bits, then take the first `v · w = 126` bits split into `v = 42` little-endian chunks of `w = 3` bits → encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. -2. 
**Recover WOTS public key**: for each `i`, walk chain `i` from `chain_tips[i]` for `chain_length - 1 - e_i` steps, where each step is `H(tweak_chain(i, step, s) | 00 | previous_value | pp | 0000)` truncated to `n`. -3. **Hash WOTS public key**: T-sponge with replacement over the `v` recovered chain ends, with IV `[tweak_wots_pk(s) | 00 | pp]`, ingesting two chain end digests at a time. Output is the Merkle leaf. -4. **Walk Merkle path**: for `level = 0..log_lifetime`, combine the current node with `merkle_proof[level]` (left/right determined by bit `level` of `s`) via `H(tweak_merkle(level+1, parent_index) | 00 | pp | left | right)` truncated to `n`. +1. **Encode**: compute the 4-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 00)`. For each limb `D_i`, take the canonical representative `D_i = low + 2^32 · high` (with `low, high < 2^32`) and reject if `high == 2^32 - 1` (needed for uniformity of the encoding). Reject if the 2 high bits of `low` are non-zero (`encoding_num_final_zeros / 4 = 2` bits per limb). Otherwise split the low 30 bits of `low` into `v / 4 = 10` little-endian chunks of `w = 3` bits each, giving the encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. +2. **Recover WOTS public key**: for each `i`, walk chain `i` from `chain_tips[i]` for `chain_length - 1 - e_i` steps, where each step is `H(tweak_chain(i, step, s) | 0 | previous_value | pp | 00)` truncated to `n`. +3. **Hash WOTS public key**: T-sponge with replacement over the `v` recovered chain ends, with IV `[tweak_wots_pk(s) | 0 | pp]`, ingesting two chain end digests at a time. Output is the Merkle leaf. +4. **Walk Merkle path**: for `level = 0..log_lifetime`, combine the current node with `merkle_proof[level]` (left/right determined by bit `level` of `s`) via `H(tweak_merkle(level+1, parent_index) | 0 | pp | left | right)` truncated to `n`. 5. **Check root**: accept iff the final hash equals `merkle_root`. 
## Security -target = 123,9 ≈ 124 bits of classical security in the ROM, and ≈ 62 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. +target ≈ 128 bits of classical security in the ROM, and ≈ 64 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. ## Signature size -**1171 bytes** `log2(p).(|randomness| + n.(v + log_lifetime))` +**1176 bytes** `log2(p).(|randomness| + n.(v + log_lifetime)) / 8` below IPv6 [MTU](https://fr.wikipedia.org/wiki/Maximum_transmission_unit) (1280 bytes) From 4c1209a3a116396d97378e1120bca11f8caa203d Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 12:17:57 +0200 Subject: [PATCH 25/31] clippy --- crates/backend/goldilocks/src/poseidon1.rs | 1 + crates/lean_vm/src/tables/poseidon_8/sparse.rs | 1 + crates/whir/tests/run_whir.rs | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/backend/goldilocks/src/poseidon1.rs b/crates/backend/goldilocks/src/poseidon1.rs index 63a92898..f4ae8dee 100644 --- a/crates/backend/goldilocks/src/poseidon1.rs +++ b/crates/backend/goldilocks/src/poseidon1.rs @@ -742,6 +742,7 @@ pub fn default_goldilocks_poseidon1_8() -> Poseidon1Goldilocks8 { // ========================================================================= #[cfg(test)] +#[allow(clippy::needless_range_loop)] mod tests { use super::*; diff --git a/crates/lean_vm/src/tables/poseidon_8/sparse.rs b/crates/lean_vm/src/tables/poseidon_8/sparse.rs index 2ce825f3..b1f897c5 100644 --- a/crates/lean_vm/src/tables/poseidon_8/sparse.rs +++ b/crates/lean_vm/src/tables/poseidon_8/sparse.rs @@ -310,6 +310,7 @@ fn compute_partial_constants() -> PartialConstants { } #[cfg(test)] +#[allow(clippy::needless_range_loop, clippy::assign_op_pattern)] mod 
tests { use super::*; use backend::{POSEIDON1_HALF_FULL_ROUNDS, PrimeField64}; diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index ca65fd17..cf0a7f49 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -95,7 +95,7 @@ fn test_run_whir() { )); } - let mut prover_state = ProverState::new(poseidon8.clone()); + let mut prover_state = ProverState::new(poseidon8); precompute_dft_twiddles::(1 << F::TWO_ADICITY); @@ -118,7 +118,7 @@ fn test_run_whir() { let proof_size_single = pruned_proof.proof_size_fe() as f64 * F::bits() as f64 / 8.0; - let mut verifier_state = VerifierState::::new(pruned_proof, poseidon8.clone()).unwrap(); + let mut verifier_state = VerifierState::::new(pruned_proof, poseidon8).unwrap(); let parsed_commitment = params.parse_commitment::(&mut verifier_state).unwrap(); From 086ab06855587eea971e1587ffab5da718b27d51 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 14:50:30 +0200 Subject: [PATCH 26/31] f --- crates/whir/tests/run_whir.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index cf0a7f49..d1cc3c88 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -97,7 +97,7 @@ fn test_run_whir() { let mut prover_state = ProverState::new(poseidon8); - precompute_dft_twiddles::(1 << F::TWO_ADICITY); + precompute_dft_twiddles::(1 << 24); let polynomial: MleOwned = MleOwned::Base(polynomial); From 37f052fde2119c07fd28578a8e496f7a2f35aad1 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 15:23:02 +0200 Subject: [PATCH 27/31] whir: remove folding pow grinding (not needed when field is big enough, like the cubic extension of goldilocks) --- crates/rec_aggregation/src/compilation.rs | 26 +----- crates/rec_aggregation/whir.py | 52 ++---------- crates/whir/src/config.rs | 98 +++-------------------- crates/whir/src/open.rs | 27 ++----- crates/whir/src/verify.rs | 14 +--- 
crates/whir/tests/run_whir.rs | 6 +- 6 files changed, 28 insertions(+), 195 deletions(-) diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index d9d4f3c6..3465f90c 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -1,7 +1,7 @@ use backend::*; use lean_compiler::{CompilationFlags, ProgramSource, compile_program_with_flags}; use lean_prover::{ - GRINDING_BITS, MAX_NUM_VARIABLES_TO_SEND_COEFFS, RS_DOMAIN_INITIAL_REDUCTION_FACTOR, WHIR_INITIAL_FOLDING_FACTOR, + MAX_NUM_VARIABLES_TO_SEND_COEFFS, RS_DOMAIN_INITIAL_REDUCTION_FACTOR, WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR, default_whir_config, }; use lean_vm::*; @@ -80,8 +80,6 @@ fn build_replacements( let mut all_potential_num_queries = vec![]; let mut all_potential_query_grinding = vec![]; let mut all_potential_num_oods = vec![]; - let mut all_potential_folding_grinding = vec![]; - let mut too_much_grinding = false; for log_inv_rate in MIN_WHIR_LOG_INV_RATE..=MAX_WHIR_LOG_INV_RATE { let max_n_vars = F::TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate; let whir_config_builder = default_whir_config(log_inv_rate); @@ -89,22 +87,16 @@ fn build_replacements( let mut queries_for_rate = vec![]; let mut query_grinding_for_rate = vec![]; let mut oods_for_rate = vec![]; - let mut folding_grinding_for_rate = vec![]; for n_vars in min_stacked..=max_n_vars { let cfg = WhirConfig::::new(&whir_config_builder, n_vars); - if cfg.max_folding_pow_bits() > GRINDING_BITS { - too_much_grinding = true; - } let mut num_queries = vec![]; let mut query_grinding_bits = vec![]; let mut oods = vec![cfg.commitment_ood_samples]; - let mut folding_grinding = vec![cfg.starting_folding_pow_bits]; for round in &cfg.round_parameters { num_queries.push(round.num_queries); query_grinding_bits.push(round.query_pow_bits); oods.push(round.ood_samples); - folding_grinding.push(round.folding_pow_bits); } 
num_queries.push(cfg.final_queries); query_grinding_bits.push(cfg.final_query_pow_bits); @@ -125,22 +117,10 @@ fn build_replacements( "[{}]", oods.iter().map(|o| o.to_string()).collect::>().join(", ") )); - folding_grinding_for_rate.push(format!( - "[{}]", - folding_grinding - .iter() - .map(|g| g.to_string()) - .collect::>() - .join(", ") - )); } all_potential_num_queries.push(format!("[{}]", queries_for_rate.join(", "))); all_potential_query_grinding.push(format!("[{}]", query_grinding_for_rate.join(", "))); all_potential_num_oods.push(format!("[{}]", oods_for_rate.join(", "))); - all_potential_folding_grinding.push(format!("[{}]", folding_grinding_for_rate.join(", "))); - } - if too_much_grinding { - tracing::info!("Warning: Too much grinding for WHIR folding"); // TODO } replacements.insert( "WHIR_FIRST_RS_REDUCTION_FACTOR_PLACEHOLDER".to_string(), @@ -158,10 +138,6 @@ fn build_replacements( "WHIR_ALL_POTENTIAL_NUM_OODS_PLACEHOLDER".to_string(), format!("[{}]", all_potential_num_oods.join(", ")), ); - replacements.insert( - "WHIR_ALL_POTENTIAL_FOLDING_GRINDING_PLACEHOLDER".to_string(), - format!("[{}]", all_potential_folding_grinding.join(", ")), - ); replacements.insert("MIN_STACKED_N_VARS_PLACEHOLDER".to_string(), min_stacked.to_string()); // VM recursion parameters (different from WHIR) diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 311b6cd9..2e23f0af 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -11,7 +11,6 @@ WHIR_ALL_POTENTIAL_NUM_QUERIES = WHIR_ALL_POTENTIAL_NUM_QUERIES_PLACEHOLDER WHIR_ALL_POTENTIAL_QUERY_GRINDING = WHIR_ALL_POTENTIAL_QUERY_GRINDING_PLACEHOLDER WHIR_ALL_POTENTIAL_NUM_OODS = WHIR_ALL_POTENTIAL_NUM_OODS_PLACEHOLDER -WHIR_ALL_POTENTIAL_FOLDING_GRINDING = WHIR_ALL_POTENTIAL_FOLDING_GRINDING_PLACEHOLDER MIN_STACKED_N_VARS = MIN_STACKED_N_VARS_PLACEHOLDER @@ -24,7 +23,7 @@ def whir_open( combination_randomness_powers_0, claimed_sum: Mut, ): - n_rounds, n_final_vars, 
num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params(n_vars, initial_log_inv_rate) + n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits = get_whir_params(n_vars, initial_log_inv_rate) folding_factors = Array(n_rounds + 1) folding_factors[0] = WHIR_INITIAL_FOLDING_FACTOR for i in range(1, n_rounds + 1): @@ -61,15 +60,14 @@ def whir_open( claimed_sum, query_grinding_bits[r], num_oods[r + 1], - folding_grinding[r], ) if r == 0: domain_sz -= WHIR_FIRST_RS_REDUCTION_FACTOR else: domain_sz -= 1 - fs, all_folding_randomness[n_rounds], claimed_sum = sumcheck_verify_with_grinding( - fs, WHIR_SUBSEQUENT_FOLDING_FACTOR, claimed_sum, 2, folding_grinding[n_rounds] + fs, all_folding_randomness[n_rounds], claimed_sum = sumcheck_verify( + fs, WHIR_SUBSEQUENT_FOLDING_FACTOR, claimed_sum, 2 ) fs, final_coeffcients = fs_receive_ef_by_log_dynamic( @@ -207,19 +205,6 @@ def sumcheck_verify_reversed_helper_const(fs: Mut, n_steps: Const, claimed_sum: return fs, claimed_sum -def sumcheck_verify_with_grinding(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, folding_grinding_bits): - challenges = Array(n_steps * DIM) - for sc_round in range(0, n_steps): - fs, poly = fs_receive_ef_inlined(fs, degree + 1) - polynomial_sum_at_0_and_1(poly, degree, claimed_sum) - fs = fs_grinding(fs, folding_grinding_bits) - fs, rand = fs_sample_ef(fs) - claimed_sum = univariate_polynomial_eval(poly, rand, degree) - copy_ef(rand, challenges + sc_round * DIM) - - return fs, challenges, claimed_sum - - @inline def decompose_and_verify_merkle_batch(num_queries, sampled, root, height, num_chunks, circle_values, answers): debug_assert(height < 25) @@ -336,9 +321,8 @@ def whir_round( claimed_sum, query_grinding_bits, num_ood, - folding_grinding_bits, ): - fs, folding_randomness, new_claimed_sum_a = sumcheck_verify_with_grinding(fs, folding_factor, claimed_sum, 2, folding_grinding_bits) + fs, folding_randomness, new_claimed_sum_a = sumcheck_verify(fs, folding_factor, 
claimed_sum, 2) fs, root, ood_points, ood_evals = parse_commitment(fs, num_ood) @@ -420,10 +404,7 @@ def get_whir_params(n_vars, log_inv_rate): num_oods = get_num_oods(log_inv_rate, n_vars) - folding_grinding: Imu - folding_grinding = get_folding_grinding(log_inv_rate, n_vars) - - return n_rounds, final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding + return n_rounds, final_vars, num_queries, num_oods, query_grinding_bits @inline @@ -472,29 +453,6 @@ def get_query_grinding_bits_const(log_inv_rate: Const, n_vars: Const): return query_grinding_bits -@inline -def get_folding_grinding(log_inv_rate, n_vars): - res = match_range(log_inv_rate, range(MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE + 1), lambda r: get_folding_grinding_const_rate(r, n_vars)) - return res - - -def get_folding_grinding_const_rate(log_inv_rate: Const, n_vars): - res = match_range( - n_vars, - range(MIN_STACKED_N_VARS, TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), - lambda nv: get_folding_grinding_const(log_inv_rate, nv), - ) - return res - - -def get_folding_grinding_const(log_inv_rate: Const, n_vars: Const): - max = len(WHIR_ALL_POTENTIAL_FOLDING_GRINDING[log_inv_rate - MIN_WHIR_LOG_INV_RATE][n_vars - MIN_STACKED_N_VARS]) - folding_grinding = Array(max) - for i in unroll(0, max): - folding_grinding[i] = WHIR_ALL_POTENTIAL_FOLDING_GRINDING[log_inv_rate - MIN_WHIR_LOG_INV_RATE][n_vars - MIN_STACKED_N_VARS][i] - return folding_grinding - - def get_num_oods(log_inv_rate, n_vars): res = match_range(log_inv_rate, range(MIN_WHIR_LOG_INV_RATE, MAX_WHIR_LOG_INV_RATE + 1), lambda r: get_num_oods_const_rate(r, n_vars)) return res diff --git a/crates/whir/src/config.rs b/crates/whir/src/config.rs index 5f57b600..f67edd35 100644 --- a/crates/whir/src/config.rs +++ b/crates/whir/src/config.rs @@ -103,7 +103,6 @@ pub struct WhirConfigBuilder { #[derive(Debug, Clone)] pub struct RoundConfig { pub query_pow_bits: usize, - pub folding_pow_bits: usize, pub num_queries: usize, 
pub ood_samples: usize, pub log_inv_rate: usize, @@ -119,7 +118,6 @@ pub struct WhirConfig { pub commitment_ood_samples: usize, pub starting_log_inv_rate: usize, - pub starting_folding_pow_bits: usize, pub folding_factor: FoldingFactor, pub rs_domain_initial_reduction_factor: usize, @@ -137,40 +135,22 @@ where PF: TwoAdicField, { /// `log_c` controls the proximity parameter `η` (η = √ρ/c for JB, η = ρ/c for CB). - /// Increasing `log_c` shrinks `η`, which: - /// - reduces the number of queries, but - /// - grows the list size, which tightens the `prox_gaps_error` and `sumcheck_error` -> more PoW grinding - /// - /// Both errors are decreasing functions in `log_c`. Among feasible `m ∈ [3, 100]` (with `log_c = log2(2m)`, - /// and `folding_pow_bits ≤ pow_bits`) we pick the smallest `m` that achieves the minimum query count. - fn compute_optimal_log_c_for_rate( - whir_parameters: &WhirConfigBuilder, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - ) -> f64 { + /// Increasing `log_c` shrinks `η`, which reduces the number of queries but grows the list size. + /// With the field we use (192-bit cubic extension over Goldilocks), the list size never threatens + /// `prox_gaps_error` / `sumcheck_error` enough to require folding PoW, so we pick the smallest + /// `m ∈ [3, 100]` (with `log_c = log2(2m)`) that achieves the minimum query count — keeping + /// `log_c` (and the dependent OOD sample count) as small as possible. 
+ fn compute_optimal_log_c_for_rate(whir_parameters: &WhirConfigBuilder, log_inv_rate: usize) -> f64 { if matches!(whir_parameters.soundness_type, SecurityAssumption::UniqueDecoding) { return 0.0; } - let pow_budget = whir_parameters.pow_bits; - let query_security_level = whir_parameters.security_level.saturating_sub(pow_budget); + let query_security_level = whir_parameters.security_level.saturating_sub(whir_parameters.pow_bits); let mut best_m = 3; let mut best_queries = usize::MAX; for m in 3..=100 { let log_c = (2.0 * m as f64).log2(); - let folding_pow = Self::folding_pow_bits( - whir_parameters.security_level, - whir_parameters.soundness_type, - field_size_bits, - num_variables, - log_inv_rate, - log_c, - ); - if folding_pow.ceil() as usize > pow_budget { - break; - } let queries = whir_parameters .soundness_type .queries(query_security_level, log_inv_rate, log_c); @@ -182,7 +162,6 @@ where (2.0 * best_m as f64).log2() } - #[allow(clippy::too_many_lines)] pub fn new(whir_parameters: &WhirConfigBuilder, num_variables: usize) -> Self { whir_parameters.folding_factor.check_validity(num_variables).unwrap(); @@ -208,8 +187,7 @@ where .folding_factor .compute_number_of_rounds(num_variables, whir_parameters.max_num_variables_to_send_coeffs); - let mut log_c_old = - Self::compute_optimal_log_c_for_rate(whir_parameters, field_size_bits, num_variables, log_inv_rate); + let mut log_c_old = Self::compute_optimal_log_c_for_rate(whir_parameters, log_inv_rate); let commitment_ood_samples = whir_parameters.soundness_type.determine_ood_samples( whir_parameters.security_level, @@ -219,15 +197,6 @@ where log_c_old, ); - let starting_folding_pow_bits = Self::folding_pow_bits( - whir_parameters.security_level, - whir_parameters.soundness_type, - field_size_bits, - num_variables, - log_inv_rate, - log_c_old, - ); - let mut round_parameters = Vec::with_capacity(num_rounds); let mut num_variables_moving = num_variables; @@ -241,8 +210,7 @@ where }; let next_rate = log_inv_rate + 
(whir_parameters.folding_factor.at_round(round) - rs_reduction_factor); - let log_c_new = - Self::compute_optimal_log_c_for_rate(whir_parameters, field_size_bits, num_variables_moving, next_rate); + let log_c_new = Self::compute_optimal_log_c_for_rate(whir_parameters, next_rate); let num_queries = whir_parameters .soundness_type @@ -272,21 +240,12 @@ where let query_pow_bits = 0_f64.max(whir_parameters.security_level as f64 - (query_error.min(combination_error))); - let folding_pow_bits = Self::folding_pow_bits( - whir_parameters.security_level, - whir_parameters.soundness_type, - field_size_bits, - num_variables_moving, - next_rate, - log_c_new, - ); let folding_factor = whir_parameters.folding_factor.at_round(round); let next_folding_factor = whir_parameters.folding_factor.at_round(round + 1); let folded_domain_gen = PF::::two_adic_generator(domain_size.ilog2() as usize - folding_factor); round_parameters.push(RoundConfig { query_pow_bits: query_pow_bits.ceil() as usize, - folding_pow_bits: folding_pow_bits.ceil() as usize, num_queries, ood_samples, log_inv_rate, @@ -322,7 +281,6 @@ where commitment_ood_samples, num_variables, starting_log_inv_rate: whir_parameters.starting_log_inv_rate, - starting_folding_pow_bits: starting_folding_pow_bits.ceil() as usize, folding_factor: whir_parameters.folding_factor, rs_domain_initial_reduction_factor: whir_parameters.rs_domain_initial_reduction_factor, round_parameters, @@ -366,41 +324,6 @@ where self.num_variables - self.folding_factor.total_number(self.n_rounds()) } - pub fn max_folding_pow_bits(&self) -> usize { - self.round_parameters.iter().map(|r| r.folding_pow_bits).max().unwrap() - } - - #[must_use] - pub fn rbr_soundness_fold_sumcheck( - soundness_type: SecurityAssumption, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_c: f64, - ) -> f64 { - let list_size = soundness_type.list_size_bits(num_variables, log_inv_rate, log_c); - - field_size_bits as f64 - (list_size + 1.) 
- } - - #[must_use] - pub fn folding_pow_bits( - security_level: usize, - soundness_type: SecurityAssumption, - field_size_bits: usize, - num_variables: usize, - log_inv_rate: usize, - log_c: f64, - ) -> f64 { - let prox_gaps_error = soundness_type.prox_gaps_error(num_variables, log_inv_rate, field_size_bits, 2, log_c); - let sumcheck_error = - Self::rbr_soundness_fold_sumcheck(soundness_type, field_size_bits, num_variables, log_inv_rate, log_c); - - let error = prox_gaps_error.min(sumcheck_error); - - 0_f64.max(security_level as f64 - error) - } - #[must_use] pub fn rbr_soundness_queries_combination( soundness_type: SecurityAssumption, @@ -436,7 +359,6 @@ where domain_size, folded_domain_gen, ood_samples: last.ood_samples, - folding_pow_bits: 0, log_inv_rate: last.log_inv_rate, } } @@ -462,7 +384,7 @@ impl SecurityAssumption { /// /// `log_c` is log2 of the divisor c, where η = √ρ/c (JB) or ρ/c (CB). /// It is computed per rate phase by `WhirConfig::compute_optimal_log_c_for_rate` to - /// balance folding PoW vs queries. + /// minimize the query count. 
#[must_use] pub fn log_eta(&self, log_inv_rate: usize, log_c: f64) -> f64 { match self { diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 8b8b4031..bb81d09d 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -159,12 +159,10 @@ where &stir_combination_randomness, ); - let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( - None, - prover_state, - folding_factor_next, - round_params.folding_pow_bits, - ); + let next_folding_randomness = + round_state + .sumcheck_prover + .run_sumcheck_many_rounds(None, prover_state, folding_factor_next); round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); @@ -238,7 +236,7 @@ where let final_folding_randomness = round_state .sumcheck_prover - .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds); round_state.randomness_vec.extend(final_folding_randomness.0); } @@ -385,7 +383,6 @@ where prev_folding_scalar: Option, prover_state: &mut impl FSProver, n_rounds: usize, - pow_bits: usize, ) -> MultilinearPoint { let (challenges, folds, new_sum) = sumcheck_prove_many_rounds( MleGroupRef::merge(&[&self.evals.by_ref(), &self.weights.by_ref()]), @@ -398,7 +395,7 @@ where None, n_rounds, false, - pow_bits, + 0, ); self.sum = new_sum; @@ -414,7 +411,6 @@ where combination_randomness: EF, prover_state: &mut impl FSProver, folding_factor: usize, - pow_bits: usize, ) -> (Self, MultilinearPoint) { assert_ne!(folding_factor, 0); @@ -422,14 +418,8 @@ where let mut evals = evals.pack(); let mut weights = Mle::Owned(MleOwned::ExtensionPacked(weights)); - let (challengess, new_sum, new_evals, new_weights) = run_product_sumcheck( - &evals.by_ref(), - &weights.by_ref(), - prover_state, - sum, - folding_factor, - pow_bits, - ); + let (challengess, new_sum, new_evals, new_weights) = + run_product_sumcheck(&evals.by_ref(), &weights.by_ref(), prover_state, sum, 
folding_factor, 0); evals = new_evals.into(); weights = new_weights.into(); @@ -492,7 +482,6 @@ where combination_randomness_gen, prover_state, prover.folding_factor.at_round(0), - prover.starting_folding_pow_bits, ); Ok(Self { diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b28..98897615 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -111,12 +111,8 @@ where round_constraints.push((combination_randomness, constraints)); // Initial sumcheck - let folding_randomness = verify_sumcheck_rounds::( - verifier_state, - &mut claimed_sum, - self.folding_factor.at_round(0), - self.starting_folding_pow_bits, - )?; + let folding_randomness = + verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.folding_factor.at_round(0))?; round_folding_randomness.push(folding_randomness); for round_index in 0..self.n_rounds() { @@ -150,7 +146,6 @@ where verifier_state, &mut claimed_sum, self.folding_factor.at_round(round_index + 1), - round_params.folding_pow_bits, )?; round_folding_randomness.push(folding_randomness); @@ -180,7 +175,7 @@ where .ok_or(ProofError::InvalidProof)?; let final_sumcheck_randomness = - verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.final_sumcheck_rounds, 0)?; + verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.final_sumcheck_rounds)?; round_folding_randomness.push(final_sumcheck_randomness.clone()); // Compute folding randomness across all rounds. 
@@ -404,7 +399,6 @@ pub(crate) fn verify_sumcheck_rounds( verifier_state: &mut impl FSVerifier, claimed_sum: &mut EF, rounds: usize, - pow_bits: usize, ) -> ProofResult> where F: TwoAdicField, @@ -417,8 +411,6 @@ where let coeffs = verifier_state.next_sumcheck_polynomial(3, *claimed_sum, None)?; let poly = DensePolynomial::new(coeffs); - verifier_state.check_pow_grinding(pow_bits)?; - // Sample the next verifier folding randomness rᵢ let rand: EF = verifier_state.sample(); diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index d1cc3c88..35c204bd 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -153,9 +153,6 @@ fn display_whir_round_info() { rs_domain_initial_reduction_factor: 5, }; let params = WhirConfig::::new(&params, n_vars); - let folding_pow_bits = std::iter::once(params.starting_folding_pow_bits) - .chain(params.round_parameters.iter().map(|r| r.folding_pow_bits)) - .collect::>(); let query_pow_bits = params .round_parameters .iter() @@ -163,7 +160,7 @@ fn display_whir_round_info() { .chain(std::iter::once(params.final_query_pow_bits)) .collect::>(); println!( - "n_vars: {}, log_inv_rate: {}, num_queries: {:?}, folding_pow_bits: {:?}, query_pow_bits: {:?}", + "n_vars: {}, log_inv_rate: {}, num_queries: {:?}, query_pow_bits: {:?}", n_vars, log_inv_rate, params @@ -171,7 +168,6 @@ fn display_whir_round_info() { .iter() .map(|r| r.num_queries) .collect::>(), - folding_pow_bits, query_pow_bits, ); } From ddfd8493435236f222ad8236347c1827f5e59430 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 15:57:04 +0200 Subject: [PATCH 28/31] chunks of 2w in xmss --- crates/lean_compiler/zkDSL.md | 2 +- crates/lean_vm/src/isa/hint.rs | 14 ++-- crates/rec_aggregation/xmss_aggregate.py | 90 ++++++++++++------------ 3 files changed, 55 insertions(+), 51 deletions(-) diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index f3dad32d..076a946a 100644 --- a/crates/lean_compiler/zkDSL.md
+++ b/crates/lean_compiler/zkDSL.md @@ -456,7 +456,7 @@ hints = prover-supplied values at runtime (without adding snark constraints). Li | `hint_decompose_bits` | `(to_decompose, ptr, num_bits, endianness)` | `num_bits` field elements at `ptr` (the 0/1 bit decomposition of `to_decompose`); `endianness` is `0` for big-endian, `1` for little-endian | | `hint_less_than` | `(a, b, result_ptr)` | `1` at `result_ptr` if `a < b` else `0` | | `hint_log2_ceil` | `(n, result_ptr)` | `ceil(log2(n))` at `result_ptr` | -| `hint_decompose_bits_xmss` | `(chunks_ptr, limbs_ptr, src_value)` | WOTS-encoding decomposition of one Goldilocks FE: 10 W-bit chunks of the low 30 bits at `chunks_ptr[0..10]` + 2 u16 limbs of the high 32 bits at `limbs_ptr[0..2]` (the top 2 bits of the low limb are implicit zeros — see `crates/lean_vm/src/isa/hint.rs`) | +| `hint_decompose_bits_xmss` | `(chunks_ptr, limbs_ptr, src_value)` | WOTS-encoding decomposition of one Goldilocks FE: 5 2W-bit chunks of the low 30 bits at `chunks_ptr[0..5]` (each chunk packs two consecutive chain steps as `step_a + CHAIN_LENGTH * step_b`) + 2 u16 limbs of the high 32 bits at `limbs_ptr[0..2]` (the top 2 bits of the low limb are implicit zeros — see `crates/lean_vm/src/isa/hint.rs`) | | `hint_decompose_bits_merkle_whir` | `(decomposed_ptr, remaining_ptr, value, chunk_size)` | Merkle/WHIR-specific decomposition | Hints only *suggest* a value; the guest must add appropriate constraints to bind that value to its specification. diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 8630d180..6dd6c1ac 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -103,8 +103,10 @@ impl HintWitnessDestination { pub enum CustomHint { /// WOTS-encoding decomposition of one Goldilocks FE. /// Args: (chunks_ptr, limbs_ptr, src_value). - /// Writes 10 W=3-bit chunks of the low 30 bits to `chunks_ptr[0..10]` - /// and 2 u16 limbs of the high 32 bits to `limbs_ptr[0..2]`. 
+ /// Writes 5 2W=6-bit chunks of the low 30 bits to `chunks_ptr[0..5]` + /// (each chunk packs two consecutive W=3-bit chain steps as + /// `step_a + CHAIN_LENGTH * step_b`) and 2 u16 limbs of the high + /// 32 bits to `limbs_ptr[0..2]`. DecomposeBitsXMSS, DecomposeBitsMerkleWhir, DecomposeBits, @@ -149,15 +151,17 @@ impl CustomHint { match self { Self::DecomposeBitsXMSS => { // WOTS-encoding decomposition. Writes: - // chunks_ptr[0..10] = 10 chunks of W=3 bits (low bits 0..29) + // chunks_ptr[0..5] = 5 chunks of 2W=6 bits (low bits 0..29). + // Each chunk packs two consecutive chain + // steps as `step_a + CHAIN_LENGTH * step_b`. // limbs_ptr[0..2] = 2 u16 limbs of the high 32 bits (bits 32..47, 48..63) // The 2 high bits of the low limb are implicit zeros, enforced by // the SNARK constraint structure (and rejected at signing time). let chunks_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); let limbs_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); let value = args[2].read_value(ctx.memory, ctx.fp)?.as_canonical_u64(); - const NUM_CHUNKS: usize = 10; - const CHUNK_SIZE: usize = 3; + const NUM_CHUNKS: usize = 5; + const CHUNK_SIZE: usize = 6; for j in 0..NUM_CHUNKS { let chunk = (value >> (CHUNK_SIZE * j)) & ((1u64 << CHUNK_SIZE) - 1); ctx.memory.set(chunks_ptr + j, F::from_u64(chunk))?; diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index cc8f7bbe..ca1ba8c7 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -61,11 +61,11 @@ def xmss_verify(pub_key, message, merkle_chunks): encoding_fe = Array(DIGEST_LEN) poseidon8_compress(pre_compressed, public_params_paded, encoding_fe) - encoding = Array(V) + debug_assert(V % 2 == 0) + encoding = Array(V / 2) for i in unroll(0, NUM_ENCODING_FE): - decompose_encoding_fe(encoding_fe[i], encoding + i * (V / NUM_ENCODING_FE)) + decompose_encoding_fe(encoding_fe[i], encoding + i * ((V / 2) / NUM_ENCODING_FE)) - 
debug_assert(V % 2 == 0) wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) target_sum: Mut = 0 # Pair structure: `[leading_0 | tip_a(XMSS_DIGEST_LEN) | tip_b(XMSS_DIGEST_LEN) | trailing_0]` @@ -77,25 +77,24 @@ def xmss_verify(pub_key, message, merkle_chunks): chain_end_b = chain_end_a + XMSS_DIGEST_LEN tweaks_a = TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i) * CHAIN_LENGTH * TWEAK_LEN tweaks_b = TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i + 1) * CHAIN_LENGTH * TWEAK_LEN - len_a_ptr = Array(1) - len_b_ptr = Array(1) + pair_sum_ptr = Array(1) match_range( - encoding[2 * i], - range(0, CHAIN_LENGTH), - lambda n: chain_hash_a( - chain_start_a, n, chain_end_a, tweaks_a, public_params_paded, len_a_ptr, - ), - ) - match_range( - encoding[2 * i + 1], - range(0, CHAIN_LENGTH), - lambda n: chain_hash_b( - chain_start_b, n, chain_end_b, tweaks_b, public_params_paded, len_b_ptr, + encoding[i], + range(0, CHAIN_LENGTH**2), + lambda n: chain_hash_pair( + chain_start_a, + chain_start_b, + n, + chain_end_a, + chain_end_b, + tweaks_a, + tweaks_b, + public_params_paded, + pair_sum_ptr, ), ) - target_sum += len_a_ptr[0] - target_sum += len_b_ptr[0] + target_sum += pair_sum_ptr[0] assert target_sum == TARGET_SUM @@ -143,32 +142,33 @@ def chain_hash_inner(input, n, output, chain_i_tweaks, chain_right): @inline -def chain_hash_a(input, n, output, chain_i_tweaks, chain_right, chain_length_ptr): - # Even (chain_a) variant: when num_hashes == 0, the buffer slot occupies - # `[output-1 .. output+XMSS_DIGEST_LEN)` (= leading_0 | tip_a) so copy_ef - # writes through the leading zero cell. 
- debug_assert(n < CHAIN_LENGTH) - num_hashes = (CHAIN_LENGTH - 1) - n - if num_hashes == 0: - copy_ef(input - 1, output - 1) +def chain_hash_pair( + input_a, + input_b, + n, + output_a, + output_b, + tweaks_a, + tweaks_b, + chain_right, + pair_sum_ptr, +): + raw_a = n % CHAIN_LENGTH + raw_b = (n - raw_a) / CHAIN_LENGTH + num_hashes_a = (CHAIN_LENGTH - 1) - raw_a + num_hashes_b = (CHAIN_LENGTH - 1) - raw_b + + if num_hashes_a == 0: + copy_ef(input_a - 1, output_a - 1) else: - chain_hash_inner(input, n, output, chain_i_tweaks, chain_right) - chain_length_ptr[0] = n - return - + chain_hash_inner(input_a, raw_a, output_a, tweaks_a, chain_right) -@inline -def chain_hash_b(input, n, output, chain_i_tweaks, chain_right, chain_length_ptr): - # Odd (chain_b) variant: when num_hashes == 0, the buffer slot occupies - # `[output .. output+XMSS_DIGEST_LEN+1)` (= tip_b | trailing_0) so copy_ef - # writes through the trailing zero cell. - debug_assert(n < CHAIN_LENGTH) - num_hashes = (CHAIN_LENGTH - 1) - n - if num_hashes == 0: - copy_ef(input, output) + if num_hashes_b == 0: + copy_ef(input_b, output_b) else: - chain_hash_inner(input, n, output, chain_i_tweaks, chain_right) - chain_length_ptr[0] = n + chain_hash_inner(input_b, raw_b, output_b, tweaks_b, chain_right) + + pair_sum_ptr[0] = raw_a + raw_b return @@ -177,14 +177,14 @@ def decompose_encoding_fe(fe_value, chunks_ptr): limbs = Array(2) hint_decompose_bits_xmss(chunks_ptr, limbs, fe_value) - for k in unroll(0, 10): - assert chunks_ptr[k] < CHAIN_LENGTH + for k in unroll(0, 5): + assert chunks_ptr[k] < CHAIN_LENGTH**2 assert limbs[0] < 2**16 assert limbs[1] < 2**16 low: Mut = chunks_ptr[0] - for k in unroll(1, 10): - low += chunks_ptr[k] * (2 ** (W * k)) + for k in unroll(1, 5): + low += chunks_ptr[k] * (2 ** (2 * W * k)) high = limbs[0] + limbs[1] * (2**16) assert fe_value == low + (2**32) * high From e5617e07a3597342345896645273de001cc6a78c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 15:59:45 
+0200 Subject: [PATCH 29/31] 128 bit security --- README.md | 2 +- TODO.md | 1 - crates/lean_prover/src/lib.rs | 2 +- crates/whir/tests/run_whir.rs | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e14de1b1..e18d7d7f 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ cargo run --release -- fancy-aggregation ### snark -≈ 124 bits of provable security, given by Johnson bound + degree 5 extension of koala-bear. (128 bits requires bigger hash digests (8 koalabears ≈ 248 bits) -> TODO). In the benchmarks, we also display performance with conjectured security, even though leanVM targets the proven regime by default. +≈ 128 bits of provable security, given by Johnson bound + degree 5 extension of koala-bear. (128 bits requires hash digests of 8 koalabears ≈ 248 bits.) In the benchmarks, we also display performance with conjectured security, even though leanVM targets the proven regime by default. ### XMSS diff --git a/TODO.md b/TODO.md index be64e116..a8d8b256 100644 --- a/TODO.md +++ b/TODO.md @@ -9,7 +9,6 @@ ## Security: -- 128 bits security? (currently 124) - Fiat Shamir: add a claim tracing feature, to ensure all the claims are indeed checked (Lev) - Double Check AIR constraints, logup overflows etc - Do we need to enforce some values at the first row of the dot-product table? diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index d9980caa..7424170a 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -17,7 +17,7 @@ mod test_zkvm; use trace_gen::*; // Right now, hash digests = 8 koala-bear (p = 2^31 - 2^24 + 1, i.e.
≈ 31 bits per field element) -pub const SECURITY_BITS: usize = 124; // TODO 128 bits security +pub const SECURITY_BITS: usize = 128; // TODO 128 bits security pub const GRINDING_BITS: usize = 16; pub const MAX_NUM_VARIABLES_TO_SEND_COEFFS: usize = 8; diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 35c204bd..20d4dd30 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -39,7 +39,7 @@ fn test_run_whir() { let num_coeffs = 1 << num_variables; let params = WhirConfigBuilder { - security_level: 124, + security_level: 128, max_num_variables_to_send_coeffs: 9, pow_bits: 18, folding_factor: FoldingFactor::new(7, 4), @@ -144,7 +144,7 @@ fn display_whir_round_info() { continue; } let params = WhirConfigBuilder { - security_level: 124, + security_level: 128, max_num_variables_to_send_coeffs: 8, pow_bits: 16, folding_factor: FoldingFactor::new(first_folding_factor, 5), From c0befb93c4e7e557767bef5fc8e97583920d6c19 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 16:21:01 +0200 Subject: [PATCH 30/31] EFFECTIVE_TWO_ADICITY = 24 --- crates/rec_aggregation/recursion.py | 2 +- crates/rec_aggregation/src/compilation.rs | 6 +++++- crates/rec_aggregation/tests/test_hashing.rs | 6 +++++- crates/rec_aggregation/tests/test_log2_ceil.rs | 4 ++++ crates/rec_aggregation/utils.py | 2 ++ crates/rec_aggregation/whir.py | 6 +++--- crates/whir/src/config.rs | 17 +++++++++++++++-- 7 files changed, 35 insertions(+), 8 deletions(-) diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py index 7fa6039f..3cf0fd48 100644 --- a/crates/rec_aggregation/recursion.py +++ b/crates/rec_aggregation/recursion.py @@ -88,7 +88,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): assert LOG_GUEST_BYTECODE_LEN <= log_memory stacked_n_vars = compute_stacked_n_vars(log_memory, log_bytecode_padded, table_heights) - assert stacked_n_vars <= TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - 
whir_log_inv_rate + assert stacked_n_vars <= EFFECTIVE_TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - whir_log_inv_rate num_oods = get_num_oods(whir_log_inv_rate, stacked_n_vars) num_ood_at_commitment = num_oods[0] diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 3465f90c..32c8b459 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -81,7 +81,7 @@ fn build_replacements( let mut all_potential_query_grinding = vec![]; let mut all_potential_num_oods = vec![]; for log_inv_rate in MIN_WHIR_LOG_INV_RATE..=MAX_WHIR_LOG_INV_RATE { - let max_n_vars = F::TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate; + let max_n_vars = EFFECTIVE_TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate; let whir_config_builder = default_whir_config(log_inv_rate); let mut queries_for_rate = vec![]; @@ -174,6 +174,10 @@ fn build_replacements( "WHIR_SUBSEQUENT_FOLDING_FACTOR_PLACEHOLDER".to_string(), WHIR_SUBSEQUENT_FOLDING_FACTOR.to_string(), ); + replacements.insert( + "EFFECTIVE_TWO_ADICITY_PLACEHOLDER".to_string(), + EFFECTIVE_TWO_ADICITY.to_string(), + ); replacements.insert( "MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER".to_string(), format!( diff --git a/crates/rec_aggregation/tests/test_hashing.rs b/crates/rec_aggregation/tests/test_hashing.rs index c696e90b..0f14d0cf 100644 --- a/crates/rec_aggregation/tests/test_hashing.rs +++ b/crates/rec_aggregation/tests/test_hashing.rs @@ -1,4 +1,4 @@ -use backend::PrimeCharacteristicRing; +use backend::{EFFECTIVE_TWO_ADICITY, PrimeCharacteristicRing}; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; @@ -15,6 +15,10 @@ fn test_slice_hashing() { "NUM_REPEATED_ONES_PLACEHOLDER".to_string(), NUM_REPEATED_ONES.to_string(), ), + ( + "EFFECTIVE_TWO_ADICITY_PLACEHOLDER".to_string(), + EFFECTIVE_TWO_ADICITY.to_string(), + ), ]); let bytecode = compile_program_with_flags(&ProgramSource::Filepath(path), 
CompilationFlags { replacements }); diff --git a/crates/rec_aggregation/tests/test_log2_ceil.rs b/crates/rec_aggregation/tests/test_log2_ceil.rs index bced2ee5..5f6d7aeb 100644 --- a/crates/rec_aggregation/tests/test_log2_ceil.rs +++ b/crates/rec_aggregation/tests/test_log2_ceil.rs @@ -14,6 +14,10 @@ fn test_log2_ceil() { "NUM_REPEATED_ONES_PLACEHOLDER".to_string(), NUM_REPEATED_ONES.to_string(), ), + ( + "EFFECTIVE_TWO_ADICITY_PLACEHOLDER".to_string(), + EFFECTIVE_TWO_ADICITY.to_string(), + ), ]); let bytecode = compile_program_with_flags(&ProgramSource::Filepath(path), CompilationFlags { replacements }); let witness = ExecutionWitness { diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index ea3ade11..9d7acc0a 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -7,6 +7,8 @@ TWO_ADICITY = 32 ROOT = 1753635133440165772 # = 0x185629dcda58878c, of order 2^TWO_ADICITY +EFFECTIVE_TWO_ADICITY = EFFECTIVE_TWO_ADICITY_PLACEHOLDER + @inline def build_preamble_memory(): diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 2e23f0af..2bf6bb43 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -416,7 +416,7 @@ def get_num_queries(log_inv_rate, n_vars): def get_num_queries_const_rate(log_inv_rate: Const, n_vars): res = match_range( n_vars, - range(MIN_STACKED_N_VARS, TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), + range(MIN_STACKED_N_VARS, EFFECTIVE_TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), lambda nv: get_num_queries_const(log_inv_rate, nv), ) return res @@ -439,7 +439,7 @@ def get_query_grinding_bits(log_inv_rate, n_vars): def get_query_grinding_bits_const_rate(log_inv_rate: Const, n_vars): res = match_range( n_vars, - range(MIN_STACKED_N_VARS, TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), + range(MIN_STACKED_N_VARS, EFFECTIVE_TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), lambda nv: 
get_query_grinding_bits_const(log_inv_rate, nv), ) return res @@ -461,7 +461,7 @@ def get_num_oods(log_inv_rate, n_vars): def get_num_oods_const_rate(log_inv_rate: Const, n_vars): res = match_range( n_vars, - range(MIN_STACKED_N_VARS, TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), + range(MIN_STACKED_N_VARS, EFFECTIVE_TWO_ADICITY + WHIR_INITIAL_FOLDING_FACTOR - log_inv_rate + 1), lambda nv: get_num_oods_const(log_inv_rate, nv), ) return res diff --git a/crates/whir/src/config.rs b/crates/whir/src/config.rs index f67edd35..62687154 100644 --- a/crates/whir/src/config.rs +++ b/crates/whir/src/config.rs @@ -3,6 +3,11 @@ use field::{Field, TwoAdicField}; use poly::*; +// (Goldilocks two adicity is 32) We use a smaller one to avoid having to deal with PoW grinding at folding in WHIR +// TODO we likely want a bit more than 24, so we should reintroduce PoW grinding for folding in the future +// But hopefully we will have better proximity gaps formulas by then +pub const EFFECTIVE_TWO_ADICITY: usize = 24; + /// Defines the folding factor for polynomial commitments. 
#[derive(Debug, Clone, Copy)] pub struct FoldingFactor { @@ -179,8 +184,16 @@ where let log_folded_domain_size = log_domain_size - whir_parameters.folding_factor.at_round(0); assert!( - log_folded_domain_size <= PF::::TWO_ADICITY, - "Increase folding_factor_0" + log_folded_domain_size <= EFFECTIVE_TWO_ADICITY, + "num_variables + log_inv_rate must be ≤ EFFECTIVE_TWO_ADICITY ({}) + first_folding_factor ({}); \ + got {}.", + EFFECTIVE_TWO_ADICITY, + whir_parameters.folding_factor.at_round(0), + log_domain_size, + ); + debug_assert!( + EFFECTIVE_TWO_ADICITY <= PF::::TWO_ADICITY, + "EFFECTIVE_TWO_ADICITY exceeds the field's actual two-adicity", ); let (num_rounds, final_sumcheck_rounds) = whir_parameters From 41a15290cde955815514ef68eb960cb0db250674 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 4 May 2026 17:21:53 +0200 Subject: [PATCH 31/31] change whir params --- crates/lean_prover/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index 57aacb1d..b6f35400 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -23,8 +23,8 @@ pub const SECURITY_BITS: usize = 128; // TODO 128 bits security pub const GRINDING_BITS: usize = 16; pub const MAX_NUM_VARIABLES_TO_SEND_COEFFS: usize = 8; -pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 7; -pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 5; +pub const WHIR_INITIAL_FOLDING_FACTOR: usize = 6; +pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 4; pub const RS_DOMAIN_INITIAL_REDUCTION_FACTOR: usize = 5; // Domain-separation digest for the zkVM SNARK. Arbitrary nothing-up-my-sleeve field