diff --git a/Cargo.lock b/Cargo.lock index d938586b8..981ad6981 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,6 +105,7 @@ dependencies = [ "mt-utils", "mt-whir", "rayon", + "sha2 0.11.0", "tracing", ] @@ -547,6 +548,7 @@ dependencies = [ "pest", "pest_derive", "rand", + "sha2 0.11.0", "tracing", "utils", "xmss", @@ -621,6 +623,7 @@ dependencies = [ "mt-utils", "rayon", "serde", + "sha2 0.11.0", "tracing", ] @@ -710,6 +713,7 @@ dependencies = [ "mt-utils", "rand", "rayon", + "sha2 0.10.9", "system-info", "tracing", "tracing-forest", @@ -836,7 +840,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" dependencies = [ "pest", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -1043,6 +1047,17 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.2", +] + [[package]] name = "sha3" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index f8e2ada76..a83615197 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,8 @@ tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] } postcard = { version = "1.1.3", features = ["alloc"] } lz4_flex = "0.13.0" include_dir = "0.7" +sha2 = "0.11.0" + [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957af..3ebd8084b 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -15,3 +15,5 @@ tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } koala-bear = { path = "koala-bear", package = "mt-koala-bear" } utils = { path = "utils", package = "mt-utils" } +sha2.workspace = true + diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index ec8649bc2..d277c3e04 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -11,3 +11,4 @@ utils = { path = "../utils", package = "mt-utils" } tracing.workspace = true serde.workspace = true rayon.workspace = true +sha2.workspace = true \ No newline at end of file diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index 34fcd94ab..5a570e673 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -1,4 +1,7 @@ -use field::PrimeField64; +use std::marker::PhantomData; + +use field::{PrimeField32, PrimeField64}; +use sha2::{Digest, Sha256}; use symetric::Compression; pub(crate) const RATE: usize = 8; @@ -73,3 +76,141 @@ impl> Challenger { res } } + +#[derive(Clone, Debug)] +pub struct ChallengerSha2 { + pub sha2: Sha256, + _marker: PhantomData, +} + +impl ChallengerSha2 { + pub fn new() -> Self { + Self { + sha2: Sha256::new(), + _marker: PhantomData, + } + } + + pub fn observe(&mut self, value: [F; RATE]) { + for val in value { + self.sha2.update(val.as_canonical_u32().to_le_bytes()); + } + } + + pub fn observe_scalars(&mut self, scalars: &[F]) { + for chunk in scalars.chunks(RATE) { + let mut buffer = [F::ZERO; RATE]; + for (i, val) in chunk.iter().enumerate() { + buffer[i] = *val; + } + self.observe(buffer); + } + } + + pub fn observe_bytes(&mut self, bytes: &[u8]) { + self.sha2.update(bytes); + } + + pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> { + let mut sampled = Vec::with_capacity(n + 1); + for i in 0..n + 1 { + sampled.push(self.sample_chunk(i)); + } + let last = sampled.pop().unwrap(); + self.sha2 = Sha256::new(); + self.observe(last); + sampled + } + + pub fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + assert!(bits < F::bits()); + let sampled_fe = self.sample_many(n_samples.div_ceil(RATE)).into_iter().flatten(); + let mut res = Vec::new(); + for fe in sampled_fe.take(n_samples) { + let rand_usize = fe.as_canonical_u64() as usize; + res.push(rand_usize & ((1 << bits) - 1)); + } + res + } + + pub fn pow_grinding_sample_matches(&self, bits: usize) -> bool { + assert!(bits < F::bits()); + let sample = self.sample_first_word(0, 0); + let rand_usize = sample.as_canonical_u64() as usize; + (rand_usize & ((1 << bits) - 1)) == 0 + } + + pub fn pow_grinding_witness_matches(&self, witness: F, bits: usize) -> bool { + let mut challenger = self.clone(); + challenger.observe_scalars(&[witness]); + challenger.pow_grinding_sample_matches(bits) + } + + fn sample_chunk(&self, domain_sep: usize) -> [F; RATE] { + let mut words = Vec::with_capacity(RATE); + for block_idx in 0u64.. { + let digest = self.sample_digest(domain_sep, block_idx); + for word in digest.chunks_exact(size_of::()) { + let word = u32::from_le_bytes(word.try_into().unwrap()); + words.push(F::from_int(word)); + if words.len() == RATE { + return words.try_into().unwrap(); + } + } + } + unreachable!() + } + + fn sample_first_word(&self, domain_sep: usize, block_idx: u64) -> F { + let digest = self.sample_digest(domain_sep, block_idx); + let word = u32::from_le_bytes(digest[..size_of::()].try_into().unwrap()); + F::from_int(word) + } + + fn sample_digest(&self, domain_sep: usize, block_idx: u64) -> sha2::digest::Output { + let mut hasher = self.sha2.clone(); + hasher.update((domain_sep as u64).to_le_bytes()); + hasher.update(block_idx.to_le_bytes()); + hasher.finalize() + } +} + +impl Default for ChallengerSha2 { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use field::PrimeCharacteristicRing; + use koala_bear::KoalaBear; + + use super::ChallengerSha2; + + #[test] + fn sha2_pow_grinding_direct_predicate_matches_sampling_path() { + let transcript_prefixes = [ + vec![], + vec![KoalaBear::ONE], + (0..17).map(KoalaBear::from_usize).collect::>(), + (100..141).map(KoalaBear::from_usize).collect::>(), + ]; + + for prefix in transcript_prefixes { + let mut challenger = ChallengerSha2::new(); + challenger.observe_scalars(&prefix); + + for bits in 1..=20 { + for candidate in [0, 1, 2, 3, 5, 8, 13, 21, 55, 89, 144, 233, 377, 610] { + let witness = KoalaBear::from_usize(candidate); + let mut sampling_path = challenger.clone(); + sampling_path.observe_scalars(&[witness]); + let expected = sampling_path.sample_in_range(bits, 1)[0] == 0; + let actual = challenger.pow_grinding_witness_matches(witness, bits); + assert_eq!(actual, expected); + } + } + } + } +} diff --git a/crates/backend/fiat-shamir/src/merkle_pruning.rs b/crates/backend/fiat-shamir/src/merkle_pruning.rs index 2461d0bbc..c6221d269 100644 --- a/crates/backend/fiat-shamir/src/merkle_pruning.rs +++ b/crates/backend/fiat-shamir/src/merkle_pruning.rs @@ -1,13 +1,13 @@ use serde::{Deserialize, Serialize}; -use crate::{DIGEST_LEN_FE, MerklePath, MerklePaths}; +use crate::{MerklePath, MerklePaths}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PrunedMerklePaths { +pub struct PrunedMerklePaths { pub merkle_height: usize, pub original_order: Vec, pub leaf_data: Vec>, - pub paths: Vec<(usize, Vec<[F; DIGEST_LEN_FE]>)>, + pub paths: Vec<(usize, Vec)>, pub n_trailing_zeros: usize, } @@ -15,8 +15,8 @@ fn lca_level(a: usize, b: usize) -> usize { (usize::BITS - (a ^ b).leading_zeros()) as usize } -impl MerklePaths { - pub fn prune(self) -> PrunedMerklePaths +impl MerklePaths { + pub fn prune(self) -> PrunedMerklePaths where Data: Default + PartialEq, { @@ -27,7 +27,7 @@ impl MerklePaths { indexed.sort_by_key(|(_, p)| p.leaf_index); let mut original_order = vec![0; indexed.len()]; - let mut deduped = Vec::>::new(); + let mut deduped = Vec::>::new(); for (orig_idx, path) in indexed { if deduped.last().map(|p| p.leaf_index) == Some(path.leaf_index) { @@ -83,12 +83,12 @@ impl MerklePaths { } } -impl PrunedMerklePaths { +impl PrunedMerklePaths { pub fn restore( mut self, - hash_leaf: &impl Fn(&[Data]) -> [F; DIGEST_LEN_FE], - hash_combine: &impl Fn(&[F; DIGEST_LEN_FE], &[F; DIGEST_LEN_FE]) -> [F; DIGEST_LEN_FE], - ) -> Option> + hash_leaf: &impl Fn(&[Data]) -> Digest, + hash_combine: &impl Fn(&Digest, &Digest) -> Digest, + ) -> Option> where Data: Default, { @@ -112,7 +112,7 @@ impl PrunedMerklePaths { let skip = |i: usize| self.paths.get(i + 1).map(|p| lca_level(self.paths[i].0, p.0) - 1); // Backward pass: compute subtree hashes needed to restore skipped siblings - let mut subtree_hashes: Vec> = vec![vec![]; n]; + let mut subtree_hashes: Vec> = vec![vec![]; n]; for i in (0..n).rev() { let (leaf_idx, ref stored) = self.paths[i]; @@ -139,7 +139,7 @@ impl PrunedMerklePaths { } // Forward pass: build full sibling arrays - let mut restored: Vec> = Vec::with_capacity(n); + let mut restored: Vec> = Vec::with_capacity(n); for i in 0..n { let (leaf_idx, ref stored) = self.paths[i]; @@ -178,6 +178,7 @@ impl PrunedMerklePaths { #[cfg(test)] mod tests { use super::*; + use crate::DIGEST_LEN_FE; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -231,7 +232,7 @@ mod tests { leaf_data: Vec, leaf_index: usize, tree: &[Vec<[u8; DIGEST_LEN_FE]>], - ) -> MerklePath { + ) -> MerklePath { let height = tree.len() - 1; let mut sibling_hashes = Vec::with_capacity(height); diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 2ea95580d..2f693aec9 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -1,18 +1,19 @@ use crate::{ MerklePaths, PrunedMerklePaths, - challenger::{Challenger, RATE, WIDTH}, + challenger::{Challenger, ChallengerSha2, RATE, WIDTH}, *, }; use field::Field; use field::PackedValue; use field::PrimeCharacteristicRing; use field::integers::QuotientMap; -use field::{ExtensionField, PrimeField64}; +use field::{ExtensionField, PrimeField32, PrimeField64}; use rayon::prelude::*; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; use symetric::Compression; +use symetric::merkle::Sha256Digest; static POW_GRINDING_NANOS: AtomicU64 = AtomicU64::new(0); @@ -28,7 +29,7 @@ pub fn reset_pow_grinding_time() { pub struct ProverState>, P> { challenger: Challenger, P>, transcript: Vec>, - merkle_paths: Vec, PF>>, + merkle_paths: Vec, [PF; DIGEST_LEN_FE]>>, } impl>, P: Compression<[PF; WIDTH]>> ProverState @@ -48,6 +49,7 @@ where pub fn into_proof(self) -> Proof> { Proof { transcript: self.transcript, + commitments: Vec::new(), merkle_paths: self.merkle_paths, } } @@ -71,11 +73,17 @@ impl>, P: Compression<[PF; WIDTH]> + Compression<[ where PF: PrimeField64, { + type Digest = [PF; DIGEST_LEN_FE]; + fn add_base_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); self.transcript.extend_from_slice(scalars); } + fn add_commitment(&mut self, commitment: &Self::Digest) { + self.add_base_scalars(commitment); + } + fn observe_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); } @@ -109,7 +117,7 @@ where } } - fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>) { + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>) { self.merkle_paths.push(MerklePaths(paths).prune()); } @@ -170,3 +178,126 @@ where POW_GRINDING_NANOS.fetch_add(elapsed.as_nanos() as u64, Ordering::Relaxed); } } + +#[derive(Debug)] +pub struct ProverStateSha2>> { + challenger: ChallengerSha2>, + transcript: Vec>, + commitments: Vec, + merkle_paths: Vec, Sha256Digest>>, +} + +impl>> ProverStateSha2 +where + PF: PrimeField32, +{ + #[must_use] + pub fn new() -> Self { + assert!(EF::DIMENSION <= RATE); + Self { + challenger: ChallengerSha2::new(), + transcript: Vec::new(), + commitments: Vec::new(), + merkle_paths: Vec::new(), + } + } + + pub fn into_proof(self) -> Proof, Sha256Digest> { + Proof { + transcript: self.transcript, + commitments: self.commitments, + merkle_paths: self.merkle_paths, + } + } +} + +impl>> ChallengeSampler for ProverStateSha2 +where + PF: PrimeField32, +{ + fn sample_vec(&mut self, len: usize) -> Vec { + let sampled_fe = self + .challenger + .sample_many((len * EF::DIMENSION).div_ceil(RATE)) + .into_iter() + .flatten() + .take(len * EF::DIMENSION) + .collect::>>(); + pack_scalars_to_extension(&sampled_fe) + } + + fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + self.challenger.sample_in_range(bits, n_samples) + } +} + +impl>> FSProver for ProverStateSha2 +where + PF: PrimeField32, +{ + type Digest = Sha256Digest; + + fn add_base_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + self.transcript.extend_from_slice(scalars); + } + + fn add_commitment(&mut self, commitment: &Self::Digest) { + self.challenger.observe_bytes(commitment); + self.commitments.push(*commitment); + } + + fn observe_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + } + + fn state(&self) -> String { + format!("sha2 transcript (n_items: {})", self.transcript.len()) + } + + fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option) { + match eq_alpha { + None => { + let scalars = flatten_scalars_to_base(coeffs); + self.challenger.observe_scalars(&scalars); + self.transcript.extend_from_slice(&scalars[EF::DIMENSION..]); + } + Some(alpha) => { + let bare_scalars = flatten_scalars_to_base(coeffs); + let full_scalars = flatten_scalars_to_base(&expand_bare_to_full(coeffs, alpha)); + self.challenger.observe_scalars(&full_scalars); + self.transcript.extend_from_slice(&bare_scalars[EF::DIMENSION..]); + } + } + } + + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>) { + self.merkle_paths.push(MerklePaths(paths).prune()); + } + + fn pow_grinding(&mut self, bits: usize) { + assert!(bits < PF::::bits()); + + if bits == 0 { + return; + } + + let time = Instant::now(); + let challenger = self.challenger.clone(); + let witness = (0..PF::::ORDER_U32) + .into_par_iter() + .find_map_any(|candidate| { + let witness = unsafe { PF::::from_canonical_unchecked(candidate) }; + challenger + .pow_grinding_witness_matches(witness, bits) + .then_some(witness) + }) + .expect("failed to find witness"); + + self.challenger.observe_scalars(&[witness]); + self.transcript.push(witness); + + let elapsed = time.elapsed(); + POW_GRINDING_NANOS.fetch_add(elapsed.as_nanos() as u64, Ordering::Relaxed); + } +} diff --git a/crates/backend/fiat-shamir/src/traits.rs b/crates/backend/fiat-shamir/src/traits.rs index 5aba9f667..baf098b8a 100644 --- a/crates/backend/fiat-shamir/src/traits.rs +++ b/crates/backend/fiat-shamir/src/traits.rs @@ -13,11 +13,14 @@ pub trait ChallengeSampler { } pub trait FSProver>>: ChallengeSampler { + type Digest: Clone; + fn state(&self) -> String; fn add_base_scalars(&mut self, scalars: &[PF]); + fn add_commitment(&mut self, commitment: &Self::Digest); fn observe_scalars(&mut self, scalars: &[PF]); fn pow_grinding(&mut self, bits: usize); - fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>); + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>); fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option); fn add_extension_scalars(&mut self, scalars: &[EF]) { @@ -28,7 +31,7 @@ pub trait FSProver>>: ChallengeSampler { self.add_extension_scalars(&[scalar]); } - fn hint_merkle_paths_extension(&mut self, paths: Vec>>) { + fn hint_merkle_paths_extension(&mut self, paths: Vec>) { self.hint_merkle_paths_base( paths .into_iter() @@ -43,10 +46,13 @@ pub trait FSProver>>: ChallengeSampler { } pub trait FSVerifier>>: ChallengeSampler { + type Digest: Clone; + fn state(&self) -> String; fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError>; + fn next_commitment(&mut self) -> Result; fn observe_scalars(&mut self, scalars: &[PF]); - fn next_merkle_opening(&mut self) -> Result>, ProofError>; + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError>; fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError>; fn next_sumcheck_polynomial( &mut self, diff --git a/crates/backend/fiat-shamir/src/transcript.rs b/crates/backend/fiat-shamir/src/transcript.rs index 612c2d109..9b88dbd5d 100644 --- a/crates/backend/fiat-shamir/src/transcript.rs +++ b/crates/backend/fiat-shamir/src/transcript.rs @@ -6,9 +6,9 @@ use crate::PrunedMerklePaths; pub const DIGEST_LEN_FE: usize = 8; #[derive(Debug, Clone)] -pub struct MerkleOpening { +pub struct MerkleOpening { pub leaf_data: Vec, - pub path: Vec<[F; DIGEST_LEN_FE]>, + pub path: Vec, } /// "RawProof": the format which is used in the zkVM recursion program (no Merkle pruning, no sumcheck optimization to send less data, etc) @@ -19,20 +19,22 @@ pub struct RawProof { } #[derive(Debug, Clone)] -pub struct MerklePath { +pub struct MerklePath { pub leaf_data: Vec, - pub sibling_hashes: Vec<[F; DIGEST_LEN_FE]>, + pub sibling_hashes: Vec, // does not appear in the proof itself, but useful for Merkle pruning pub leaf_index: usize, } #[derive(Debug, Clone)] -pub struct MerklePaths(pub(crate) Vec>); +pub struct MerklePaths(pub(crate) Vec>); #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Proof { +pub struct Proof { pub(crate) transcript: Vec, - pub(crate) merkle_paths: Vec>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(crate) commitments: Vec, + pub(crate) merkle_paths: Vec>, } impl Proof { diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..3728119cf 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -3,14 +3,16 @@ use std::iter::repeat_n; use crate::{ MerkleOpening, MerklePaths, PrunedMerklePaths, RawProof, - challenger::{Challenger, RATE, WIDTH}, + challenger::{Challenger, ChallengerSha2, RATE, WIDTH}, transcript::{DIGEST_LEN_FE, Proof}, *, }; use field::PrimeCharacteristicRing; -use field::{ExtensionField, PrimeField64}; +use field::{ExtensionField, PrimeField32, PrimeField64}; use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; +use sha2::{Digest as Sha2DigestTrait, Sha256}; use symetric::Compression; +use symetric::merkle::Sha256Digest; pub struct VerifierState>, P> { challenger: Challenger, P>, @@ -26,6 +28,10 @@ where PF: PrimeField64, { pub fn new(proof: Proof>, compressor: C) -> Result { + if !proof.commitments.is_empty() { + return Err(ProofError::InvalidProof); + } + let mut merkle_openings = Vec::new(); for paths in proof.merkle_paths { let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?; @@ -67,16 +73,18 @@ where } #[allow(clippy::missing_transmute_annotations)] - fn restore_merkle_paths(paths: PrunedMerklePaths, PF>) -> Option>>> { + fn restore_merkle_paths( + paths: PrunedMerklePaths, [PF; DIGEST_LEN_FE]>, + ) -> Option>>> { assert_eq!(TypeId::of::>(), TypeId::of::()); // SAFETY: We've confirmed PF == KoalaBear - let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; + let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; - let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; + let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; let openings: Vec> = restored .0 .into_iter() @@ -106,6 +114,8 @@ impl>, C: Compression<[PF; WIDTH]>> FSVerifier where PF: PrimeField64, { + type Digest = [PF; DIGEST_LEN_FE]; + fn state(&self) -> String { format!( "state {} (offset: {}, merkle_idx: {})", @@ -130,7 +140,12 @@ where Ok(scalars) } - fn next_merkle_opening(&mut self) -> Result>, ProofError> { + fn next_commitment(&mut self) -> Result { + self.next_base_scalars_vec(DIGEST_LEN_FE) + .map(|scalars| scalars.try_into().unwrap()) + } + + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError> { if self.merkle_opening_index >= self.merkle_openings.len() { return Err(ProofError::ExceededTranscript); } @@ -191,3 +206,188 @@ where } } } + +pub struct VerifierStateSha2>> { + challenger: ChallengerSha2>, + transcript: Vec>, + transcript_offset: usize, + commitments: Vec, + commitment_index: usize, + merkle_openings: Vec, Sha256Digest>>, + merkle_opening_index: usize, +} + +impl>> VerifierStateSha2 +where + PF: PrimeField32, +{ + pub fn new(proof: Proof, Sha256Digest>) -> Result { + let mut merkle_openings = Vec::new(); + for paths in proof.merkle_paths { + let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?; + merkle_openings.extend(restored); + } + + Ok(Self { + challenger: ChallengerSha2::new(), + transcript: proof.transcript, + transcript_offset: 0, + commitments: proof.commitments, + commitment_index: 0, + merkle_openings, + merkle_opening_index: 0, + }) + } + + fn read_transcript(&mut self, n: usize) -> Result>, ProofError> { + if self.transcript_offset + n > self.transcript.len() { + return Err(ProofError::ExceededTranscript); + } + let scalars = self.transcript[self.transcript_offset..self.transcript_offset + n].to_vec(); + self.transcript_offset += n; + Ok(scalars) + } + + fn restore_merkle_paths( + paths: PrunedMerklePaths, Sha256Digest>, + ) -> Option, Sha256Digest>>> { + let hash_fn = |data: &[PF]| { + let mut hasher = Sha256::new(); + for value in data { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }; + let combine_fn = |left: &Sha256Digest, right: &Sha256Digest| { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }; + let restored = paths.restore(&hash_fn, &combine_fn)?; + Some( + restored + .0 + .into_iter() + .map(|path| MerkleOpening { + leaf_data: path.leaf_data, + path: path.sibling_hashes, + }) + .collect(), + ) + } +} + +impl>> ChallengeSampler for VerifierStateSha2 +where + PF: PrimeField32, +{ + fn sample_vec(&mut self, len: usize) -> Vec { + let sampled_fe = self + .challenger + .sample_many((len * EF::DIMENSION).div_ceil(RATE)) + .into_iter() + .flatten() + .take(len * EF::DIMENSION) + .collect::>>(); + pack_scalars_to_extension(&sampled_fe) + } + + fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + self.challenger.sample_in_range(bits, n_samples) + } +} + +impl>> FSVerifier for VerifierStateSha2 +where + PF: PrimeField32, +{ + type Digest = Sha256Digest; + + fn state(&self) -> String { + format!( + "sha2 verifier (offset: {}, commitment_idx: {}, merkle_idx: {})", + self.transcript_offset, self.commitment_index, self.merkle_opening_index, + ) + } + + fn observe_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + } + + fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError> { + let scalars = self.read_transcript(n)?; + self.challenger.observe_scalars(&scalars); + Ok(scalars) + } + + fn next_commitment(&mut self) -> Result { + if self.commitment_index >= self.commitments.len() { + return Err(ProofError::ExceededTranscript); + } + let commitment = self.commitments[self.commitment_index]; + self.commitment_index += 1; + self.challenger.observe_bytes(&commitment); + Ok(commitment) + } + + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError> { + if self.merkle_opening_index >= self.merkle_openings.len() { + return Err(ProofError::ExceededTranscript); + } + let opening = self.merkle_openings[self.merkle_opening_index].clone(); + self.merkle_opening_index += 1; + Ok(opening) + } + + fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError> { + if bits == 0 { + return Ok(()); + } + let witness = self.read_transcript(1)?[0]; + self.challenger.observe_scalars(&[witness]); + if !self.challenger.pow_grinding_sample_matches(bits) { + return Err(ProofError::InvalidGrindingWitness); + } + Ok(()) + } + + fn next_sumcheck_polynomial( + &mut self, + n_coeffs: usize, + claimed_sum: EF, + eq_alpha: Option, + ) -> ProofResult> { + match eq_alpha { + None => { + let rest_scalars = self.read_transcript((n_coeffs - 1) * EF::DIMENSION)?; + let rest_coeffs: Vec = pack_scalars_to_extension(&rest_scalars); + let c0 = (claimed_sum - rest_coeffs.iter().copied().sum::()).halve(); + + let mut full_coeffs = Vec::with_capacity(n_coeffs); + full_coeffs.push(c0); + full_coeffs.extend_from_slice(&rest_coeffs); + + let mut all_scalars = flatten_scalars_to_base(&[c0]); + all_scalars.extend_from_slice(&rest_scalars); + self.challenger.observe_scalars(&all_scalars); + Ok(full_coeffs) + } + Some(alpha) => { + let rest_scalars = self.read_transcript((n_coeffs - 2) * EF::DIMENSION)?; + let rest_bare: Vec = pack_scalars_to_extension(&rest_scalars); + let h0 = claimed_sum - alpha * rest_bare.iter().copied().sum::(); + + let mut bare = Vec::with_capacity(n_coeffs - 1); + bare.push(h0); + bare.extend_from_slice(&rest_bare); + + let full_coeffs = expand_bare_to_full(&bare, alpha); + self.challenger.observe_scalars(&flatten_scalars_to_base(&full_coeffs)); + Ok(full_coeffs) + } + } + } +} diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..891be80d2 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -16,6 +16,13 @@ pub struct MerkleTree { pub digest_layers: Vec>, } +pub type Sha256Digest = [u8; 16]; +/// A Merkle tree storing only the digest layers (no leaf data). +#[derive(Debug, Clone)] +pub struct MerkleTreeSha2 { + pub digest_layers: Vec>, +} + impl MerkleTree { /// Build a Merkle tree from a pre-computed first digest layer. pub fn from_first_layer(comp: &Comp, first_layer: Vec<[F; DIGEST_ELEMS]>) -> Self @@ -47,6 +54,19 @@ impl MerkleT } } +impl MerkleTreeSha2 { + #[must_use] + pub fn root(&self) -> Sha256Digest { + self.digest_layers.last().unwrap()[0] + } + + pub fn open_siblings(&self, index: usize, log_height: usize) -> Vec { + (0..log_height) + .map(|i| self.digest_layers[i][(index >> i) ^ 1]) + .collect() + } +} + pub fn compress_layer( prev_layer: &[[P::Value; DIGEST_ELEMS]], comp: &Comp, diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index e784d27f8..9de93f0db 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -8,7 +8,7 @@ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, - PrecompileCompTimeArgs, SourceLocation, + PrecompileCompTimeArgs, SHA256_COMPRESS_NAME, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -2308,6 +2308,32 @@ fn simplify_lines( continue; } + // Special handling for SHA256 compression precompile + if function_name == SHA256_COMPRESS_NAME { + if !targets.is_empty() { + return Err(format!( + "Precompile {function_name} should not return values, at {location}" + )); + } + if args.len() != 3 { + return Err(format!( + "Precompile {function_name} expects 3 arguments (state_ptr, block_ptr, out_ptr), got {}, at {location}", + args.len() + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Precompile(PrecompileArgs { + arg_0: simplified_args[0].clone(), + arg_1: simplified_args[1].clone(), + res: simplified_args[2].clone(), + data: PrecompileCompTimeArgs::Sha256Compress, + })); + continue; + } + // Special handling for custom hints if let Some(hint) = CustomHint::find_by_name(function_name) { if !targets.is_empty() { diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index 1060e3be4..d1658cb8d 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -59,6 +59,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val } + PrecompileCompTimeArgs::Sha256Compress => SHA256_PRECOMPILE_DATA, PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 576e5af60..beebd81a5 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode}; +use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode, SHA256_COMPRESS_NAME}; /// Reserved function names that users cannot define. pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -36,6 +36,9 @@ fn is_reserved_function_name(name: &str) -> bool { if ALL_POSEIDON16_NAMES.contains(&name) { return true; } + if name == SHA256_COMPRESS_NAME { + return true; + } if ExtensionOpMode::from_name(name).is_some() { return true; } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2c187a08e..635955b78 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use backend::BasedVectorSpace; +use backend::{BasedVectorSpace, PrimeCharacteristicRing}; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; @@ -26,6 +26,30 @@ def main(): let _ = dbg!(poseidon16_compress(public_input)); } +#[test] +fn test_sha256_compress() { + let program = r#" +def main(): + state = 0 + block = 16 + expected = 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return + "#; + + let mut public_input = vec![F::ZERO; 64]; + public_input[0..16].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[16..48].copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[48..64].copy_from_slice(&expected); + + compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); +} + #[test] fn test_div_extension_field() { let program = r#" diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index fa86a3ae2..3e25281d1 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -1,22 +1,111 @@ use std::collections::BTreeMap; use crate::*; +use backend::merkle::Sha256Digest; use lean_vm::*; use serde::{Deserialize, Serialize}; use sub_protocols::*; use tracing::info_span; use utils::ansi::Colorize; -use utils::{build_prover_state, from_end}; +use utils::{build_prover_state, build_prover_state_sha2, from_end}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ExecutionProof { - pub proof: Proof, +#[serde(bound( + serialize = "Proof: Serialize", + deserialize = "Proof: Deserialize<'de>" +))] +pub struct ExecutionProof { + pub proof: Proof, // benchmark / debug purpose #[serde(skip, default)] pub metadata: Option, } +trait ExecutionProverBackend

+where + P: FSProver, +{ + type InnerWitness; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness; + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ); +} + +struct PoseidonExecutionBackend; + +impl

ExecutionProverBackend

for PoseidonExecutionBackend +where + P: FSProver, +{ + type InnerWitness = Witness; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness { + stack_polynomials_and_commit(prover_state, whir_config, memory, memory_acc, bytecode_acc, traces) + } + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ) { + whir_config.prove(prover_state, statements, witness, polynomial); + } +} + +struct Sha2ExecutionBackend; + +impl

ExecutionProverBackend

for Sha2ExecutionBackend +where + P: FSProver, +{ + type InnerWitness = Witness2; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness { + stack_polynomials_and_commit_sha2(prover_state, whir_config, memory, memory_acc, bytecode_acc, traces) + } + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ) { + whir_config.prove2(prover_state, statements, witness, polynomial); + } +} + pub fn prove_execution( bytecode: &Bytecode, public_input: &[F], @@ -24,6 +113,49 @@ pub fn prove_execution( whir_config: &WhirConfigBuilder, vm_profiler: bool, ) -> Result { + prove_execution_with::<_, PoseidonExecutionBackend, _>( + bytecode, + public_input, + witness, + whir_config, + vm_profiler, + build_prover_state(), + |prover_state| prover_state.into_proof(), + ) +} + +pub fn prove_execution_sha2( + bytecode: &Bytecode, + public_input: &[F], + witness: &ExecutionWitness, + whir_config: &WhirConfigBuilder, + vm_profiler: bool, +) -> Result, ProverError> { + prove_execution_with::<_, Sha2ExecutionBackend, _>( + bytecode, + public_input, + witness, + whir_config, + vm_profiler, + build_prover_state_sha2(), + |prover_state| prover_state.into_proof(), + ) +} + +fn prove_execution_with( + bytecode: &Bytecode, + public_input: &[F], + witness: &ExecutionWitness, + whir_config: &WhirConfigBuilder, + vm_profiler: bool, + mut prover_state: P, + into_proof: IntoProof, +) -> Result, ProverError> +where + P: FSProver, + B: ExecutionProverBackend

, + IntoProof: FnOnce(P) -> Proof, +{ check_rate(whir_config.starting_log_inv_rate) .map_err(|err| panic!("{err}")) .unwrap(); @@ -44,23 +176,22 @@ pub fn prove_execution( if memory.len() < min_memory_size { memory.resize(min_memory_size, F::ZERO); } - let mut prover_state = build_prover_state(); prover_state.observe_scalars(public_input); - prover_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); - prover_state.add_base_scalars( - &[ - vec![ - whir_config.starting_log_inv_rate, - log2_strict_usize(memory.len()), - public_input.len(), - ], - traces.values().map(|t| t.log_n_rows).collect::>(), - ] - .concat() - .into_iter() - .map(F::from_usize) - .collect::>(), - ); + let bytecode_hash_with_domain_sep = poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP); + prover_state.observe_scalars(&bytecode_hash_with_domain_sep); + let execution_metadata_scalars = [ + vec![ + whir_config.starting_log_inv_rate, + log2_strict_usize(memory.len()), + public_input.len(), + ], + traces.values().map(|t| t.log_n_rows).collect::>(), + ] + .concat() + .into_iter() + .map(F::from_usize) + .collect::>(); + prover_state.add_base_scalars(&execution_metadata_scalars); for (table, table_trace) in &traces { let log_n_rows = table_trace.log_n_rows; assert!(log_n_rows >= MIN_LOG_N_ROWS_PER_TABLE, "missing padding"); @@ -110,7 +241,7 @@ pub fn prove_execution( }); // 1st Commitment - let stacked_pcs_witness = stack_polynomials_and_commit( + let stacked_pcs_witness = B::stack_polynomials_and_commit( &mut prover_state, whir_config, &memory, @@ -245,7 +376,6 @@ pub fn prove_execution( )], ), ]; - let global_statements_base = stacked_pcs_global_statements( stacked_pcs_witness.stacked_n_vars, log2_strict_usize(memory.len()), @@ -255,7 +385,8 @@ pub fn prove_execution( &committed_statements, ); - WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()).prove( + B::prove_whir( + &WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()), &mut prover_state, global_statements_base, stacked_pcs_witness.inner_witness, @@ -266,7 +397,7 @@ pub fn prove_execution( reset_pow_grinding_time(); Ok(ExecutionProof { - proof: prover_state.into_proof(), + proof: into_proof(prover_state), metadata: Some(metadata), }) } diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index f4e87c947..46bb04b86 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -1,9 +1,56 @@ -use crate::{default_whir_config, prove_execution::prove_execution, verify_execution::verify_execution}; +use crate::{ + default_whir_config, + prove_execution::{prove_execution, prove_execution_sha2}, + verify_execution::verify, +}; use backend::*; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::{init_tracing, poseidon16_compress}; +use utils::{get_poseidon16, init_tracing, poseidon16_compress}; + +#[test] +#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_sha256_compress -- --ignored --nocapture`"] +fn bench_sha256_compress() { + utils::init_tracing(); + let n_sha_calls = std::env::var("SHA256_BENCH_CALLS") + .ok() + .map(|raw| raw.parse::().expect("SHA256_BENCH_CALLS must be a usize")) + .unwrap_or(1); + const SHA_FIXTURE_STRIDE: usize = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS + SHA256_STATE_LIMBS; + let program_str = format!( + r#" +N_SHA_CALLS = {n_sha_calls} +SHA_FIXTURE_STRIDE = 64 + +def main(): + for j in unroll(0, N_SHA_CALLS): + base = j * SHA_FIXTURE_STRIDE + state = base + block = base + 16 + expected = base + 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return +"# + ); + + let mut public_input = vec![F::ZERO; n_sha_calls * SHA_FIXTURE_STRIDE]; + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + for j in 0..n_sha_calls { + let base = j * SHA_FIXTURE_STRIDE; + public_input[base..base + SHA256_STATE_LIMBS].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[base + SHA256_STATE_LIMBS..base + SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + public_input[base + SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS..base + SHA_FIXTURE_STRIDE] + .copy_from_slice(&expected); + } + + test_zk_vm_helper(&program_str, &public_input); +} #[test] fn test_zk_vm_all_precompiles() { @@ -18,6 +65,15 @@ def main(): pub_start = 0 poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) + # Keep the SHA fixture away from the extension-op fixture ranges below. + sha_state = pub_start + 1400 + sha_block = sha_state + 16 + sha_expected = sha_block + 32 + sha_out = Array(16) + sha256_compress(sha_state, sha_block, sha_out) + for i in unroll(0, 16): + assert sha_out[i] == sha_expected[i] + # poseidon16_compress_half: only first 4 FE constrained full_out = pub_start + 6 * DIGEST_LEN half_out = pub_start + 80 @@ -103,6 +159,17 @@ def main(): F::from_usize(444), ]); + // SHA256 compression test data: IV + padded "abc" block. + let sha_state_ptr = 1400; + let sha_block_ptr = sha_state_ptr + SHA256_STATE_LIMBS; + let sha_expected_ptr = sha_block_ptr + SHA256_BLOCK_LIMBS; + public_input[sha_state_ptr..sha_state_ptr + SHA256_STATE_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[sha_block_ptr..sha_block_ptr + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let sha_expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[sha_expected_ptr..sha_expected_ptr + SHA256_STATE_LIMBS].copy_from_slice(&sha_expected); + let hardcoded_data: [F; 4] = rng.random(); let hardcoded_prefix: [F; 4] = rng.random(); public_input[1496..1500].copy_from_slice(&hardcoded_data); @@ -235,19 +302,65 @@ def fibonacci_const(a, b, n: Const): fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { utils::init_tracing(); let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); - let time = std::time::Instant::now(); + + test_zk_vm_bytecode_helper_poseidon(&bytecode, public_input); + test_zk_vm_bytecode_helper_sha2(&bytecode, public_input); +} + +fn test_zk_vm_helper_poseidon(program_str: &str, public_input: &[F]) { + utils::init_tracing(); + let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + test_zk_vm_bytecode_helper_poseidon(&bytecode, public_input); +} + +fn test_zk_vm_helper_sha2(program_str: &str, public_input: &[F]) { + utils::init_tracing(); + let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + test_zk_vm_bytecode_helper_sha2(&bytecode, public_input); +} + +fn test_zk_vm_bytecode_helper_poseidon(bytecode: &Bytecode, public_input: &[F]) { let starting_log_inv_rate = 1; let witness = ExecutionWitness::default(); + + let time = std::time::Instant::now(); let proof = prove_execution( - &bytecode, + bytecode, public_input, &witness, &default_whir_config(starting_log_inv_rate), false, ) .unwrap(); - let proof_time = time.elapsed(); - verify_execution(&bytecode, public_input, proof.proof).unwrap(); + let poseidon_proof_time = time.elapsed(); + + println!("Poseidon proof"); println!("{}", proof.metadata.as_ref().unwrap().display()); - println!("Proof time: {:.3} s", proof_time.as_secs_f32()); + println!("Proof time: {:.3} s", poseidon_proof_time.as_secs_f32()); + + let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); + verify(bytecode, public_input, &mut verifier_state).unwrap(); +} + +fn test_zk_vm_bytecode_helper_sha2(bytecode: &Bytecode, public_input: &[F]) { + let starting_log_inv_rate = 1; + let witness = ExecutionWitness::default(); + + let time = std::time::Instant::now(); + let proof2 = prove_execution_sha2( + bytecode, + public_input, + &witness, + &default_whir_config(starting_log_inv_rate), + false, + ) + .unwrap(); + let sha2_proof_time = time.elapsed(); + + println!("SHA2 proof"); + println!("{}", proof2.metadata.as_ref().unwrap().display()); + println!("Proof time: {:.3} s", sha2_proof_time.as_secs_f32()); + + let mut verifier_state2 = VerifierStateSha2::::new(proof2.proof).unwrap(); + verify(bytecode, public_input, &mut verifier_state2).unwrap(); } diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 1801a5b62..5d0aa9346 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -99,6 +99,24 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let null_poseidon_16_hash_ptr = memory_padded.len(); memory_padded.extend_from_slice(get_poseidon_16_of_zero()); + let sha256_padding_state_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_IV)); + let sha256_padding_block_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_ZERO_BLOCK)); + let sha256_padding_out_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(sha256_compress_words( + SHA256_IV, + SHA256_ZERO_BLOCK, + ))); + + let padding_memory = PaddingMemory { + zero_vec_ptr: padding_zero_vec_ptr, + null_poseidon_16_hash_ptr, + sha256_state_ptr: sha256_padding_state_ptr, + sha256_block_ptr: sha256_padding_block_ptr, + sha256_out_ptr: sha256_padding_out_ptr, + }; + // IMPORTANT: memory size should always be >= number of VM cycles let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); memory_padded.resize(padded_memory_len, F::ZERO); @@ -142,7 +160,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul }, ); for table in traces.keys().copied().collect::>() { - pad_table(&table, &mut traces, padding_zero_vec_ptr, null_poseidon_16_hash_ptr); + pad_table(&table, &mut traces, &padding_memory); } ExecutionTrace { @@ -153,12 +171,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul } } -fn pad_table( - table: &Table, - traces: &mut BTreeMap, - zero_vec_ptr: usize, - null_poseidon_16_hash_ptr: usize, -) { +fn pad_table(table: &Table, traces: &mut BTreeMap, padding_memory: &PaddingMemory) { let trace = traces.get_mut(table).unwrap(); let h = trace.columns[0].len(); trace @@ -170,7 +183,7 @@ fn pad_table( trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE); let n_rows = 1 << trace.log_n_rows; - let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr); + let padding_row = table.padding_row(padding_memory); trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 0909691d7..cdfcf0cb3 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,22 +1,24 @@ use std::collections::BTreeMap; use crate::*; -use backend::{Proof, RawProof, VerifierState}; use lean_vm::*; use sub_protocols::*; -use utils::{ToUsize, from_end, get_poseidon16}; +use utils::{ToUsize, from_end}; #[derive(Debug, Clone)] pub struct ProofVerificationDetails { pub bytecode_evaluation: Evaluation, } -pub fn verify_execution( +pub fn verify( bytecode: &Bytecode, public_input: &[F], - proof: Proof, -) -> Result<(ProofVerificationDetails, RawProof), ProofError> { - let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone())?; + verifier_state: &mut V, +) -> Result +where + V: FSVerifier, + V::Digest: WhirVerifierDigest, +{ verifier_state.observe_scalars(public_input); verifier_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); let dims = verifier_state @@ -26,7 +28,7 @@ pub fn verify_execution( .collect::>(); let log_inv_rate = dims[0]; let log_memory = dims[1]; - let public_input_len = dims[2]; // enforce the exact length of the public input to pass through Fiat Shamir (otherwise we could have 2 public inputs, only differing by a few (<8) zeros in the end, leading to the same fiat shamir state: tipically giving the advseary 2 or 3 bits of advantage in the subsequent part where the public input is evaluated as a multilinear polynomial) + let public_input_len = dims[2]; if public_input_len != public_input.len() { return Err(ProofError::InvalidProof); } @@ -47,7 +49,6 @@ pub fn verify_execution( .into()); } } - // check memory is bigger than any other table if log_memory < (*table_n_vars.values().max().unwrap()).max(bytecode.log_size()) { return Err(ProofError::InvalidProof); } @@ -62,9 +63,9 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let parsed_commitment = stacked_pcs_parse_commitment( + let parsed_commitment = stacked_pcs_parse_commitment_generic( &whir_config, - &mut verifier_state, + verifier_state, log_memory, bytecode.log_size(), &table_n_vars, @@ -75,7 +76,7 @@ pub fn verify_execution( let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = verify_generic_logup( - &mut verifier_state, + verifier_state, logup_c, &logup_alphas, &logup_alphas_eq_poly, @@ -100,7 +101,7 @@ pub fn verify_execution( let bus_beta = verifier_state.sample(); let air_alpha = verifier_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); - let eta: EF = verifier_state.sample(); // batching the sumchecks proving validity of AIR tables + let eta: EF = verifier_state.sample(); let tables_sorted = sort_tables_by_height(&table_n_vars); @@ -140,7 +141,7 @@ pub fn verify_execution( let Evaluation { point: sumcheck_air_point, value: claimed_air_final_value, - } = sumcheck_verify(&mut verifier_state, n_max, max_full_degree, initial_sum, None)?; + } = sumcheck_verify(verifier_state, n_max, max_full_degree, initial_sum, None)?; let mut my_air_final_value = EF::ZERO; for vd in &verify_data { @@ -211,22 +212,18 @@ pub fn verify_execution( &committed_statements, ); - // sanity check (not necessary for soundness) let num_whir_statements = global_statements_base.iter().map(|s| s.values.len()).sum::(); assert_eq!(num_whir_statements, total_whir_statements()); WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify( - &mut verifier_state, + verifier_state, &parsed_commitment, global_statements_base, )?; - Ok(( - ProofVerificationDetails { - bytecode_evaluation: logup_statements.bytecode_evaluation.unwrap(), - }, - verifier_state.into_raw_proof(), - )) + Ok(ProofVerificationDetails { + bytecode_evaluation: logup_statements.bytecode_evaluation.unwrap(), + }) } fn back_loaded_table_contribution>>( diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 32e50f199..a138feb5d 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -15,3 +15,6 @@ rand.workspace = true tracing.workspace = true backend.workspace = true itertools.workspace = true + +[dev-dependencies] +sha2.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index afe6bc2d8..088054ad2 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -21,10 +21,13 @@ pub const MIN_BYTECODE_LOG_SIZE: usize = 8; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. -pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ +pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 4] = [ (Table::execution(), 24), (Table::extension_op(), 21), (Table::poseidon16(), 21), + // Direct Plonky3-style SHA256 has 7524 columns. 2^13 rows already exceeds + // the current commitment-surface guard; 2^12 is the largest safe cap today. + (Table::sha256_compress(), 12), ]; pub fn max_log_n_rows_per_table(table: &Table) -> usize { diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 16492d708..48cc51fed 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -23,6 +23,7 @@ pub enum RunnerError { range: usize, }, InvalidExtensionOp, + InvalidSha256Input, ParallelSegmentFailed(usize, Box), } @@ -55,6 +56,7 @@ impl Display for RunnerError { ) } Self::InvalidExtensionOp => write!(f, "invalid extension op"), + Self::InvalidSha256Input => write!(f, "invalid sha256 input"), Self::ParallelSegmentFailed(id, err) => { write!(f, "parallel segment {id} failed: {err}") } diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0cd..df9505c22 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -10,6 +10,7 @@ pub struct ExecutionMetadata { pub cycles: usize, pub memory: usize, pub n_poseidons: usize, + pub n_sha256_compress: usize, pub n_extension_ops: usize, pub bytecode_size: usize, pub public_input_size: usize, @@ -57,6 +58,12 @@ impl ExecutionMetadata { self.cycles / self.n_poseidons )); } + if self.n_sha256_compress > 0 { + out.push_str(&format!( + "SHA256Compress calls: {}\n", + pretty_integer(self.n_sha256_compress) + )); + } if self.n_extension_ops > 0 { out.push_str(&format!( "ExtensionOp calls: {}\n", diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 5d244f62f..2c8e35cd2 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -331,6 +331,7 @@ fn execute_bytecode_helper( cycles: trace.pcs.len(), memory: memory.0.len(), n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), + n_sha256_compress: trace.tables[&Table::sha256_compress()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.code.len(), public_input_size: public_input.len(), diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..c05afd287 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -2,12 +2,11 @@ use super::Operation; use super::operands::{MemOrConstant, MemOrFpOrConstant}; -use crate::POSEIDON16_NAME; use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; use crate::tables::TableT; -use crate::{ExtensionOpMode, Table, TableTrace}; +use crate::{ExtensionOpMode, POSEIDON16_NAME, SHA256_COMPRESS_NAME, Table, TableTrace}; use backend::*; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; @@ -69,6 +68,7 @@ pub enum PrecompileCompTimeArgs { // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | m[arg_a..arg_a+4] (arg_a is the first runtime parameter) hardcoded_offset_left: Option, }, + Sha256Compress, ExtensionOp { size: S, mode: ExtensionOpMode, @@ -79,6 +79,7 @@ impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { Self::Poseidon16 { .. } => Table::poseidon16(), + Self::Sha256Compress => Table::sha256_compress(), Self::ExtensionOp { .. } => Table::extension_op(), } } @@ -92,6 +93,7 @@ impl PrecompileCompTimeArgs { half_output, hardcoded_offset_left: hardcoded_left_4.map(&mut f), }, + Self::Sha256Compress => PrecompileCompTimeArgs::Sha256Compress, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -262,6 +264,9 @@ impl Display for PrecompileArgs { "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" ), }, + PrecompileCompTimeArgs::Sha256Compress => { + write!(f, "{SHA256_COMPRESS_NAME}({arg_0}, {arg_1}, {res})") + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 10b854c04..0f13e08fe 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -56,7 +56,7 @@ impl TableT for ExecutionTable { } } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut padding_row = vec![F::ZERO; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; padding_row[COL_PC] = F::from_usize(ENDING_PC); padding_row[COL_JUMP] = F::ONE; @@ -65,9 +65,9 @@ impl TableT for ExecutionTable { padding_row[COL_FLAG_B] = F::ONE; padding_row[COL_FLAG_C_FP] = F::ONE; // this is kind of arbitrary padding_row[COL_EXEC_NU_A] = F::ONE; // because at the end of program, we always jump (looping at pc=0, so condition = nu_a = 1) - padding_row[COL_MEM_ADDRESS_A] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_B] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_C] = F::from_usize(zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_A] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_B] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_C] = F::from_usize(padding.zero_vec_ptr); padding_row } diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index 03cc0045c..c7cb4583c 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -122,14 +122,14 @@ impl TableT for ExtensionOpPrecompile { self.n_columns() + 2 // +2 for COL_ACTIVATION_FLAG and COL_AUX_EXTENSION_OP (non-AIR, used in bus logup) } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; self.n_columns_total()]; row[COL_START] = F::ONE; row[COL_LEN] = F::ONE; row[COL_AUX_EXTENSION_OP] = F::from_usize(EXT_OP_LEN_MULTIPLIER); - row[COL_IDX_A] = F::from_usize(zero_vec_ptr); - row[COL_IDX_B] = F::from_usize(zero_vec_ptr); - row[COL_IDX_RES] = F::from_usize(zero_vec_ptr); + row[COL_IDX_A] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_B] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_RES] = F::from_usize(padding.zero_vec_ptr); row } diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index bf9523291..c90f1dd41 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -4,6 +4,9 @@ pub use extension_op::*; mod poseidon_16; pub use poseidon_16::*; +pub mod sha256_compress; +pub use sha256_compress::*; + mod table_enum; pub use table_enum::*; @@ -16,10 +19,10 @@ pub use execution::*; mod utils; pub(crate) use utils::*; -// `PRECOMPILE_DATA` is the bus discriminator separating the two precompile -// tables. Disjointness is by parity of bit 0: +// `PRECOMPILE_DATA` is the bus discriminator separating precompile tables: // // Poseidon16 (odd): 1 + 2·flag_half + 4·flag_left + 8·flag_left·offset_left +// Sha256 (even): 2 // ExtensionOp (even): 4·is_be + 8·flag_add + 16·flag_mul + 32·flag_poly_eq + 64·len // // Multiplying `offset_left` by `flag_left` is needed for soundness: see 3.4.1 in minimal_zkVM.pdf diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 9d0994d12..d0964b823 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -174,7 +174,7 @@ impl TableT for Poseidon16Precompile { } } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; num_cols_total_poseidon_16()]; let ptrs: Vec<*mut F> = (0..num_cols_poseidon_16()) .map(|i| unsafe { row.as_mut_ptr().add(i) }) @@ -183,15 +183,15 @@ impl TableT for Poseidon16Precompile { let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); *perm.flag_active = F::ZERO; - *perm.index_b = F::from_usize(zero_vec_ptr); - *perm.index_res = F::from_usize(null_hash_ptr); + *perm.index_b = F::from_usize(padding.zero_vec_ptr); + *perm.index_res = F::from_usize(padding.null_poseidon_16_hash_ptr); *perm.flag_half_output = F::ZERO; *perm.flag_hardcoded_left = F::ZERO; *perm.offset_hardcoded_left = F::ZERO; - *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); - *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); + *perm.effective_index_left_first = F::from_usize(padding.zero_vec_ptr); + *perm.effective_index_left_second = F::from_usize(padding.zero_vec_ptr + HALF_DIGEST_LEN); // Non-committed columns - row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); + row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(padding.zero_vec_ptr); row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); generate_trace_rows_for_perm(perm); diff --git a/crates/lean_vm/src/tables/sha256_compress/air.rs b/crates/lean_vm/src/tables/sha256_compress/air.rs new file mode 100644 index 000000000..cbab2af8e --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/air.rs @@ -0,0 +1,373 @@ +use backend::*; + +use super::{ + NUM_SHA256_COMPRESS_COLS, SHA256_BLOCK_WORDS, SHA256_CHAIN_LEN, SHA256_K, SHA256_PRECOMPILE_DATA, + SHA256_SCHEDULE_EXTENSIONS, SHA256_U32_LIMBS, SHA256_WORD_BITS, Sha256Cols, Sha256CompressCols, + Sha256CompressPrecompile, +}; +use crate::{EF, ExtraDataForBuses, eval_virtual_bus_column}; + +const BITS_PER_LIMB: usize = 16; + +impl Air for Sha256CompressPrecompile { + type ExtraData = ExtraDataForBuses; + + fn n_columns(&self) -> usize { + NUM_SHA256_COMPRESS_COLS + } + + fn degree_air(&self) -> usize { + 3 + } + + fn n_constraints(&self) -> usize { + 7840 + 32 + 1 + BUS as usize + } + + fn down_column_indexes(&self) -> Vec { + vec![] + } + + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let cols: &Sha256CompressCols = { + let up = builder.up(); + let (prefix, shorts, suffix) = unsafe { up.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + unsafe { &*shorts.as_ptr() } + }; + + if BUS { + builder.assert_zero_ef(eval_virtual_bus_column::( + extra_data, + cols.flag, + &[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ], + )); + } else { + builder.declare_values(std::slice::from_ref(&cols.flag)); + builder.declare_values(&[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ]); + } + + builder.assert_bool(cols.flag); + eval_sha256_air(builder, &cols.sha); + eval_block_limb_bridges(builder, &cols); + } +} + +fn eval_sha256_air(builder: &mut AB, local: &Sha256Cols) { + eval_bit_range_checks(builder, local); + eval_initial_state(builder, local); + eval_message_schedule(builder, local); + eval_compression(builder, local); + eval_finalization(builder, local); +} + +fn eval_bit_range_checks(builder: &mut AB, local: &Sha256Cols) { + for word in &local.w { + assert_bools(builder, word); + } + for word in &local.a_chain { + assert_bools(builder, word); + } + for word in &local.e_chain { + assert_bools(builder, word); + } +} + +fn eval_initial_state(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[i], &local.a_chain[chain_idx]); + } + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[4 + i], &local.e_chain[chain_idx]); + } +} + +fn eval_block_limb_bridges(builder: &mut AB, cols: &Sha256CompressCols) { + for i in 0..SHA256_BLOCK_WORDS { + assert_packed_equals_bits(builder, &cols.block_limbs[i], &cols.sha.w[i]); + } +} + +fn eval_message_schedule(builder: &mut AB, local: &Sha256Cols) { + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + let t = i + SHA256_BLOCK_WORDS; + + assert_sigma_matches( + builder, + &local.w[t - 15], + SigmaSpec::SmallSigma0, + &local.sched_sigma0[i], + ); + assert_sigma_matches(builder, &local.w[t - 2], SigmaSpec::SmallSigma1, &local.sched_sigma1[i]); + + let w_tm7_packed = pack_word::(&local.w[t - 7]); + add2(builder, &local.sched_tmp[i], &local.sched_sigma1[i], &w_tm7_packed); + + let w_t_packed = pack_word::(&local.w[t]); + let sched_sigma0 = local.sched_sigma0[i]; + let w_tm16_packed = pack_word::(&local.w[t - 16]); + add3_expr_out(builder, &w_t_packed, &local.sched_tmp[i], &sched_sigma0, &w_tm16_packed); + } +} + +fn eval_compression(builder: &mut AB, local: &Sha256Cols) { + for (t, round) in local.rounds.iter().enumerate() { + let a_bits = &local.a_chain[t + 3]; + let b_bits = &local.a_chain[t + 2]; + let c_bits = &local.a_chain[t + 1]; + let d_bits = &local.a_chain[t]; + let e_bits = &local.e_chain[t + 3]; + let f_bits = &local.e_chain[t + 2]; + let g_bits = &local.e_chain[t + 1]; + let h_bits = &local.e_chain[t]; + + assert_sigma_matches(builder, e_bits, SigmaSpec::BigSigma1, &round.sigma1_e); + assert_ch_matches(builder, e_bits, f_bits, g_bits, &round.ch); + + let h_packed = pack_word::(h_bits); + add3(builder, &round.tmp1, &round.sigma1_e, &round.ch, &h_packed); + + let k = [ + AB::IF::from_u32(SHA256_K[t] & 0xffff), + AB::IF::from_u32(SHA256_K[t] >> BITS_PER_LIMB), + ]; + let w_packed = pack_word::(&local.w[t]); + add3(builder, &round.t1, &round.tmp1, &k, &w_packed); + + assert_sigma_matches(builder, a_bits, SigmaSpec::BigSigma0, &round.sigma0_a); + assert_maj_matches(builder, a_bits, b_bits, c_bits, &round.maj); + + let new_a_packed = pack_word::(&local.a_chain[t + 4]); + add3_expr_out(builder, &new_a_packed, &round.t1, &round.sigma0_a, &round.maj); + + let new_e_packed = pack_word::(&local.e_chain[t + 4]); + let d_packed = pack_word::(d_bits); + add2_expr_out(builder, &new_e_packed, &round.t1, &d_packed); + } +} + +fn eval_finalization(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let final_bits = &local.a_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[i], &local.h_in[i], &packed); + } + for i in 0..4 { + let final_bits = &local.e_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[4 + i], &local.h_in[4 + i], &packed); + } +} + +#[inline] +fn assert_bools(builder: &mut AB, bits: &[AB::IF; SHA256_WORD_BITS]) { + for &bit in bits { + builder.assert_bool(bit); + } +} + +#[inline] +fn pack_word(bits: &[AB::IF; SHA256_WORD_BITS]) -> [AB::IF; SHA256_U32_LIMBS] { + [ + pack_bits_le::(&bits[..BITS_PER_LIMB]), + pack_bits_le::(&bits[BITS_PER_LIMB..]), + ] +} + +#[inline] +fn pack_bits_le(bits: &[AB::IF]) -> AB::IF { + let mut acc = AB::IF::ZERO; + for &bit in bits.iter().rev() { + acc = acc.double() + bit; + } + acc +} + +#[inline] +fn assert_packed_equals_bits( + builder: &mut AB, + packed: &[AB::IF; SHA256_U32_LIMBS], + bits: &[AB::IF; SHA256_WORD_BITS], +) { + let built = pack_word::(bits); + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +#[inline] +fn add2( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + add2_expr_out(builder, a, b, c); +} + +#[inline] +fn add3( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + add3_expr_out(builder, a, b, c, d); +} + +#[inline] +fn add2_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0]; + let acc_32 = a[1] - b[1] - c[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32)); + builder.assert_zero(acc_16 * (acc_16 + two_16)); +} + +#[inline] +fn add3_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0] - d[0]; + let acc_32 = a[1] - b[1] - c[1] - d[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32) * (acc + two_32.double())); + builder.assert_zero(acc_16 * (acc_16 + two_16) * (acc_16 + two_16.double())); +} + +#[derive(Copy, Clone)] +enum SigmaSpec { + BigSigma0, + BigSigma1, + SmallSigma0, + SmallSigma1, +} + +#[derive(Copy, Clone)] +enum ShiftKind { + Rotate, + Logical, +} + +#[inline] +const fn sigma_params(spec: SigmaSpec) -> (usize, usize, usize, ShiftKind) { + match spec { + SigmaSpec::BigSigma0 => (2, 13, 22, ShiftKind::Rotate), + SigmaSpec::BigSigma1 => (6, 11, 25, ShiftKind::Rotate), + SigmaSpec::SmallSigma0 => (7, 18, 3, ShiftKind::Logical), + SigmaSpec::SmallSigma1 => (17, 19, 10, ShiftKind::Logical), + } +} + +fn assert_sigma_matches( + builder: &mut AB, + bits: &[AB::IF; SHA256_WORD_BITS], + spec: SigmaSpec, + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let (r1, r2, r3, kind) = sigma_params(spec); + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let b1 = bits[(i + r1) % SHA256_WORD_BITS]; + let b2 = bits[(i + r2) % SHA256_WORD_BITS]; + let b3 = match kind { + ShiftKind::Rotate => bits[(i + r3) % SHA256_WORD_BITS], + ShiftKind::Logical => { + let src = i + r3; + if src < SHA256_WORD_BITS { + bits[src] + } else { + AB::IF::ZERO + } + } + }; + acc = acc.double() + b1.xor3(&b2, &b3); + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_ch_matches( + builder: &mut AB, + e: &[AB::IF; SHA256_WORD_BITS], + f: &[AB::IF; SHA256_WORD_BITS], + g: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let ei = e[i]; + let ch_i = ei * f[i] + (AB::IF::ONE - ei) * g[i]; + acc = acc.double() + ch_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_maj_matches( + builder: &mut AB, + a: &[AB::IF; SHA256_WORD_BITS], + b: &[AB::IF; SHA256_WORD_BITS], + c: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let maj_i = a[i] * b[i] + c[i] * a[i].xor(&b[i]); + acc = acc.double() + maj_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} diff --git a/crates/lean_vm/src/tables/sha256_compress/columns.rs b/crates/lean_vm/src/tables/sha256_compress/columns.rs new file mode 100644 index 000000000..60780fce0 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/columns.rs @@ -0,0 +1,103 @@ +use core::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use super::{ + SHA256_BLOCK_LIMBS, SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, + SHA256_U32_LIMBS, SHA256_WORD_BITS, +}; + +pub const SHA256_CHAIN_LEN: usize = 4 + SHA256_COMPRESS_ROUNDS; + +pub const SHA256_COL_FLAG: usize = 0; +pub const SHA256_COL_STATE_PTR: usize = 1; +pub const SHA256_COL_BLOCK_PTR: usize = 2; +pub const SHA256_COL_OUT_PTR: usize = 3; +pub const SHA256_COL_BLOCK_LIMBS_START: usize = 4; +pub const SHA256_COL_AIR_START: usize = SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS; +pub const SHA256_COL_STATE_LIMBS_START: usize = SHA256_COL_AIR_START; +pub const SHA256_COL_OUT_LIMBS_START: usize = NUM_SHA256_COMPRESS_COLS - SHA256_STATE_WORDS * SHA256_U32_LIMBS; + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256RoundCols { + pub sigma1_e: [T; SHA256_U32_LIMBS], + pub ch: [T; SHA256_U32_LIMBS], + pub tmp1: [T; SHA256_U32_LIMBS], + pub t1: [T; SHA256_U32_LIMBS], + pub sigma0_a: [T; SHA256_U32_LIMBS], + pub maj: [T; SHA256_U32_LIMBS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256Cols { + pub h_in: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], + pub a_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub e_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub w: [[T; SHA256_WORD_BITS]; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub rounds: [Sha256RoundCols; SHA256_COMPRESS_ROUNDS], + pub h_out: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256CompressCols { + pub flag: T, + pub state_ptr: T, + pub block_ptr: T, + pub out_ptr: T, + pub block_limbs: [[T; SHA256_U32_LIMBS]; SHA256_BLOCK_WORDS], + pub sha: Sha256Cols, +} + +pub const NUM_SHA256_AIR_COLS: usize = size_of::>(); +pub const NUM_SHA256_COMPRESS_COLS: usize = size_of::>(); + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/mod.rs b/crates/lean_vm/src/tables/sha256_compress/mod.rs new file mode 100644 index 000000000..731221715 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/mod.rs @@ -0,0 +1,562 @@ +use crate::{F, PrecompileCompTimeArgs, RunnerError, Table}; +use backend::{PrimeCharacteristicRing, PrimeField32}; +use utils::ToUsize; + +mod air; + +mod columns; +pub use columns::*; + +mod trace_gen; +pub use trace_gen::*; + +pub const SHA256_STATE_WORDS: usize = 8; +pub const SHA256_BLOCK_WORDS: usize = 16; +pub const SHA256_INPUT_WORDS: usize = SHA256_BLOCK_WORDS + SHA256_STATE_WORDS; +pub const SHA256_WORD_BITS: usize = 32; +pub const SHA256_U32_LIMBS: usize = 2; +pub const SHA256_COMPRESS_ROUNDS: usize = 64; +pub const SHA256_SCHEDULE_EXTENSIONS: usize = SHA256_COMPRESS_ROUNDS - SHA256_BLOCK_WORDS; + +pub const SHA256_STATE_LIMBS: usize = SHA256_STATE_WORDS * SHA256_U32_LIMBS; +pub const SHA256_BLOCK_LIMBS: usize = SHA256_BLOCK_WORDS * SHA256_U32_LIMBS; + +pub const SHA256_IV: [u32; SHA256_STATE_WORDS] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub const SHA256_ABC_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0x61626380, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x18]; + +pub const SHA256_ZERO_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0; SHA256_BLOCK_WORDS]; + +pub const SHA256_PRECOMPILE_DATA: usize = 2; +pub const SHA256_COMPRESS_NAME: &str = "sha256_compress"; + +const SHA256_K: [u32; SHA256_COMPRESS_ROUNDS] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, + 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, + 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, + 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, + 0xc67178f2, +]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Sha256RoundWitness { + pub sigma1_e: u32, + pub ch: u32, + pub tmp1: u32, + pub t1: u32, + pub sigma0_a: u32, + pub maj: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Sha256CompressionWitness { + pub h_in: [u32; SHA256_STATE_WORDS], + pub block: [u32; SHA256_BLOCK_WORDS], + pub w: [u32; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub a_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub e_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub rounds: [Sha256RoundWitness; SHA256_COMPRESS_ROUNDS], + pub h_out: [u32; SHA256_STATE_WORDS], +} + +pub fn generate_sha256_compression_witness( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Sha256CompressionWitness { + let mut w = [0u32; SHA256_COMPRESS_ROUNDS]; + w[..SHA256_BLOCK_WORDS].copy_from_slice(&block); + + let mut sched_sigma0 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_sigma1 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_tmp = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + + for t in SHA256_BLOCK_WORDS..SHA256_COMPRESS_ROUNDS { + let i = t - SHA256_BLOCK_WORDS; + let s0 = small_sigma0(w[t - 15]); + let s1 = small_sigma1(w[t - 2]); + let tmp = s1.wrapping_add(w[t - 7]); + w[t] = tmp.wrapping_add(s0).wrapping_add(w[t - 16]); + sched_sigma0[i] = s0; + sched_sigma1[i] = s1; + sched_tmp[i] = tmp; + } + + let mut a_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + let mut e_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + a_chain[0] = h_in[3]; + a_chain[1] = h_in[2]; + a_chain[2] = h_in[1]; + a_chain[3] = h_in[0]; + e_chain[0] = h_in[7]; + e_chain[1] = h_in[6]; + e_chain[2] = h_in[5]; + e_chain[3] = h_in[4]; + + let empty_round = Sha256RoundWitness { + sigma1_e: 0, + ch: 0, + tmp1: 0, + t1: 0, + sigma0_a: 0, + maj: 0, + }; + let mut rounds = [empty_round; SHA256_COMPRESS_ROUNDS]; + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = h_in; + + for t in 0..SHA256_COMPRESS_ROUNDS { + let sigma1_e = big_sigma1(e); + let ch = ch(e, f, g); + let tmp1 = h.wrapping_add(sigma1_e).wrapping_add(ch); + let t1 = tmp1.wrapping_add(SHA256_K[t]).wrapping_add(w[t]); + let sigma0_a = big_sigma0(a); + let maj = maj(a, b, c); + let new_a = t1.wrapping_add(sigma0_a).wrapping_add(maj); + let new_e = d.wrapping_add(t1); + + rounds[t] = Sha256RoundWitness { + sigma1_e, + ch, + tmp1, + t1, + sigma0_a, + maj, + }; + a_chain[t + 4] = new_a; + e_chain[t + 4] = new_e; + + h = g; + g = f; + f = e; + e = new_e; + d = c; + c = b; + b = a; + a = new_a; + } + + let final_state = [a, b, c, d, e, f, g, h]; + let h_out = core::array::from_fn(|i| h_in[i].wrapping_add(final_state[i])); + + Sha256CompressionWitness { + h_in, + block, + w, + sched_sigma0, + sched_sigma1, + sched_tmp, + a_chain, + e_chain, + rounds, + h_out, + } +} + +pub fn sha256_compress_words( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> [u32; SHA256_STATE_WORDS] { + generate_sha256_compression_witness(h_in, block).h_out +} + +pub const fn u32_to_u16_limbs_le(word: u32) -> [u16; SHA256_U32_LIMBS] { + [(word & 0xffff) as u16, (word >> 16) as u16] +} + +pub const fn u16_limbs_le_to_u32(limbs: [u16; SHA256_U32_LIMBS]) -> u32 { + limbs[0] as u32 | ((limbs[1] as u32) << 16) +} + +pub fn words_to_u16_limbs_le(words: impl IntoIterator) -> Vec { + let mut limbs = Vec::new(); + for word in words { + let word_limbs = u32_to_u16_limbs_le(word); + limbs.extend_from_slice(&word_limbs); + } + limbs +} + +pub fn words_to_field_limbs_le(words: [u32; N]) -> Vec { + words_to_u16_limbs_le(words) + .into_iter() + .map(|limb| F::from_usize(usize::from(limb))) + .collect() +} + +pub fn u32_to_bits_le(word: u32) -> [bool; SHA256_WORD_BITS] { + core::array::from_fn(|i| ((word >> i) & 1) == 1) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Sha256CompressPrecompile; + +impl crate::TableT for Sha256CompressPrecompile { + fn name(&self) -> &'static str { + SHA256_COMPRESS_NAME + } + + fn table(&self) -> Table { + Table::sha256_compress() + } + + fn lookups(&self) -> Vec { + vec![ + crate::LookupIntoMemory { + index: SHA256_COL_STATE_PTR, + values: (SHA256_COL_STATE_LIMBS_START..SHA256_COL_STATE_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_BLOCK_PTR, + values: (SHA256_COL_BLOCK_LIMBS_START..SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_OUT_PTR, + values: (SHA256_COL_OUT_LIMBS_START..SHA256_COL_OUT_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + ] + } + + fn bus(&self) -> crate::Bus { + crate::Bus { + direction: crate::BusDirection::Pull, + selector: SHA256_COL_FLAG, + data: vec![ + crate::BusData::Constant(SHA256_PRECOMPILE_DATA), + crate::BusData::Column(SHA256_COL_STATE_PTR), + crate::BusData::Column(SHA256_COL_BLOCK_PTR), + crate::BusData::Column(SHA256_COL_OUT_PTR), + ], + } + } + + fn padding_row(&self, padding: &crate::PaddingMemory) -> Vec { + sha256_compress_trace_row( + F::ZERO, + F::from_usize(padding.sha256_state_ptr), + F::from_usize(padding.sha256_block_ptr), + F::from_usize(padding.sha256_out_ptr), + SHA256_IV, + SHA256_ZERO_BLOCK, + ) + } + + fn execute( + &self, + arg_a: F, + arg_b: F, + arg_c: F, + args: PrecompileCompTimeArgs, + ctx: &mut crate::InstructionContext<'_, M>, + ) -> Result<(), RunnerError> { + let PrecompileCompTimeArgs::Sha256Compress = args else { + unreachable!("Sha256Compress table called with non-Sha256Compress args"); + }; + + let state_ptr = arg_a.to_usize(); + let block_ptr = arg_b.to_usize(); + let out_ptr = arg_c.to_usize(); + + let h_in = field_limbs_to_words::(&ctx.memory.get_slice(state_ptr, SHA256_STATE_LIMBS)?)?; + let block = field_limbs_to_words::(&ctx.memory.get_slice(block_ptr, SHA256_BLOCK_LIMBS)?)?; + let witness = generate_sha256_compression_witness(h_in, block); + ctx.memory.set_slice(out_ptr, &words_to_field_limbs_le(witness.h_out))?; + + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + push_sha256_compress_trace_row_from_witness(trace, F::ONE, arg_a, arg_b, arg_c, &witness); + + Ok(()) + } +} + +fn field_limbs_to_words(limbs: &[F]) -> Result<[u32; N], RunnerError> { + assert_eq!(limbs.len(), N * SHA256_U32_LIMBS); + let mut words = [0u32; N]; + for (word, limb_pair) in words.iter_mut().zip(limbs.chunks_exact(SHA256_U32_LIMBS)) { + let lo = limb_to_u16(limb_pair[0])?; + let hi = limb_to_u16(limb_pair[1])?; + *word = u16_limbs_le_to_u32([lo, hi]); + } + Ok(words) +} + +fn limb_to_u16(limb: F) -> Result { + let value = limb.as_canonical_u32(); + u16::try_from(value).map_err(|_| RunnerError::InvalidSha256Input) +} + +#[inline] +const fn small_sigma0(x: u32) -> u32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) +} + +#[inline] +const fn small_sigma1(x: u32) -> u32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) +} + +#[inline] +const fn big_sigma0(x: u32) -> u32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +#[inline] +const fn big_sigma1(x: u32) -> u32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} + +#[inline] +const fn ch(e: u32, f: u32, g: u32) -> u32 { + (e & f) ^ ((!e) & g) +} + +#[inline] +const fn maj(a: u32, b: u32, c: u32) -> u32 { + (a & b) ^ (a & c) ^ (b & c) +} + +#[cfg(test)] +mod tests { + use super::*; + use backend::{ + Air, PrimeCharacteristicRing, PrimeField32, SumcheckComputation, get_symbolic_constraints_and_bus_data_values, + }; + use core::borrow::Borrow; + use std::collections::BTreeMap; + + use crate::{ + EF, ExtraDataForBuses, InstructionContext, InstructionCounts, Memory, MemoryAccess, TableT, TableTrace, + }; + + fn words_to_hex(words: [u32; SHA256_STATE_WORDS]) -> String { + words.iter().map(|word| format!("{word:08x}")).collect() + } + + fn extract_packed_words(limbs: &[[F; SHA256_U32_LIMBS]; SHA256_STATE_WORDS]) -> [u32; SHA256_STATE_WORDS] { + core::array::from_fn(|i| { + let lo = limbs[i][0].as_canonical_u32(); + let hi = limbs[i][1].as_canonical_u32(); + lo | (hi << 16) + }) + } + + fn extract_trace_output(row: &[F]) -> [u32; SHA256_STATE_WORDS] { + let cols: &Sha256CompressCols = row.borrow(); + extract_packed_words(&cols.sha.h_out) + } + + fn sha2_compress_reference( + block: [u32; SHA256_BLOCK_WORDS], + h_in: [u32; SHA256_STATE_WORDS], + ) -> [u32; SHA256_STATE_WORDS] { + let mut block_bytes = [0u8; 64]; + for (i, word) in block.iter().enumerate() { + block_bytes[i * 4..i * 4 + 4].copy_from_slice(&word.to_be_bytes()); + } + let mut state = h_in; + sha2::block_api::compress256(&mut state, core::slice::from_ref(&block_bytes)); + state + } + + fn air_extra_data(n_constraints: usize) -> ExtraDataForBuses { + let mut powers = Vec::with_capacity(n_constraints + 1); + let alpha = EF::from(F::from_usize(7)); + let mut current = EF::ONE; + for _ in 0..=n_constraints { + powers.push(current); + current *= alpha; + } + ExtraDataForBuses::new(Vec::new(), EF::ZERO, powers) + } + + #[test] + fn abc_single_block_matches_known_digest() { + let out = sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK); + assert_eq!( + words_to_hex(out), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn low_then_high_limb_order_is_locked() { + assert_eq!(u32_to_u16_limbs_le(0x61626380), [0x6380, 0x6162]); + assert_eq!(u32_to_u16_limbs_le(0x18), [0x0018, 0x0000]); + assert_eq!(u16_limbs_le_to_u32([0x6380, 0x6162]), 0x61626380); + } + + #[test] + fn column_counts_match_plonky3_baseline_plus_leanvm_prefix() { + assert_eq!(NUM_SHA256_AIR_COLS, 7488); + assert_eq!(SHA256_COL_AIR_START, 36); + assert_eq!(NUM_SHA256_COMPRESS_COLS, 7524); + } + + #[test] + fn trace_row_populates_prefix_block_and_output() { + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + assert_eq!(row.len(), NUM_SHA256_COMPRESS_COLS); + + let cols: &Sha256CompressCols = row.as_slice().borrow(); + assert_eq!(cols.flag, F::ONE); + assert_eq!(cols.state_ptr, F::from_usize(10)); + assert_eq!(cols.block_ptr, F::from_usize(20)); + assert_eq!(cols.out_ptr, F::from_usize(30)); + assert_eq!(cols.block_limbs[0], [F::from_usize(0x6380), F::from_usize(0x6162)]); + assert_eq!(cols.block_limbs[15], [F::from_usize(0x18), F::ZERO]); + + assert_eq!( + words_to_hex(extract_packed_words(&cols.sha.h_out)), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn trace_generation_matches_sha2_compress_reference() { + let cases = [ + (SHA256_IV, SHA256_ABC_BLOCK), + (SHA256_IV, SHA256_ZERO_BLOCK), + ( + [ + 0x0123_4567, + 0x89ab_cdef, + 0xfedc_ba98, + 0x7654_3210, + 0x0f1e_2d3c, + 0x4b5a_6978, + 0x8877_6655, + 0x4433_2211, + ], + [ + 0xffff_ffff, + 0, + 0x1357_9bdf, + 0x2468_ace0, + 0xdead_beef, + 0xcafe_babe, + 0x0001_0002, + 0x0003_0004, + 0x0102_0304, + 0x1111_2222, + 0x3333_4444, + 0x5555_6666, + 0x7777_8888, + 0x9999_aaaa, + 0xbbbb_cccc, + 0xdddd_eeee, + ], + ), + ]; + + for (h_in, block) in cases { + let row = sha256_compress_trace_row(F::ONE, F::ZERO, F::ZERO, F::ZERO, h_in, block); + assert_eq!(extract_trace_output(&row), sha2_compress_reference(block, h_in)); + } + } + + #[test] + fn symbolic_constraint_count_matches_declared_count() { + let table = Sha256CompressPrecompile::; + let (constraints, bus_flag, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); + assert_eq!(constraints.len(), table.n_constraints()); + assert_eq!( + bus_flag, + backend::SymbolicExpression::Variable(backend::SymbolicVariable::new(SHA256_COL_FLAG)) + ); + assert_eq!(bus_data.len(), 4); + } + + #[test] + fn generated_trace_row_satisfies_air_and_tampered_row_fails() { + let table = Sha256CompressPrecompile::; + let extra_data = air_extra_data(table.n_constraints()); + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + + assert_eq!( + as SumcheckComputation>::eval_base(&table, &row, &extra_data), + EF::ZERO + ); + + let mut tampered = row; + tampered[SHA256_COL_AIR_START] = F::TWO; + assert_ne!( + as SumcheckComputation>::eval_base(&table, &tampered, &extra_data), + EF::ZERO + ); + } + + #[test] + fn precompile_execute_writes_output_and_trace_row() { + let state_ptr = 0; + let block_ptr = SHA256_STATE_LIMBS; + let out_ptr = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS; + + let mut memory = Memory::new(vec![]); + memory + .set_slice(state_ptr, &words_to_field_limbs_le(SHA256_IV)) + .unwrap(); + memory + .set_slice(block_ptr, &words_to_field_limbs_le(SHA256_ABC_BLOCK)) + .unwrap(); + + let table = Table::sha256_compress(); + let mut traces = BTreeMap::new(); + traces.insert(table, TableTrace::new(&Sha256CompressPrecompile::)); + let mut fp = 0; + let mut pc = 0; + let pcs = vec![0]; + let mut counts = InstructionCounts::default(); + let mut ctx = InstructionContext { + memory: &mut memory, + fp: &mut fp, + pc: &mut pc, + pcs: &pcs, + traces: &mut traces, + counts: &mut counts, + }; + + table + .execute( + F::from_usize(state_ptr), + F::from_usize(block_ptr), + F::from_usize(out_ptr), + PrecompileCompTimeArgs::Sha256Compress, + &mut ctx, + ) + .unwrap(); + + let out = ctx.memory.get_slice(out_ptr, SHA256_STATE_LIMBS).unwrap(); + let out_words = field_limbs_to_words::(&out).unwrap(); + assert_eq!( + words_to_hex(out_words), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + + let trace = ctx.traces.get(&table).unwrap(); + assert_eq!(trace.columns.len(), NUM_SHA256_COMPRESS_COLS); + assert_eq!(trace.columns[SHA256_COL_FLAG], [F::ONE]); + assert_eq!(trace.columns[SHA256_COL_STATE_PTR], [F::from_usize(state_ptr)]); + assert_eq!(trace.columns[SHA256_COL_BLOCK_PTR], [F::from_usize(block_ptr)]); + assert_eq!(trace.columns[SHA256_COL_OUT_PTR], [F::from_usize(out_ptr)]); + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs new file mode 100644 index 000000000..514a1226d --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs @@ -0,0 +1,126 @@ +use core::borrow::BorrowMut; + +use backend::PrimeCharacteristicRing; + +use crate::{F, TableTrace}; + +use super::{ + SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, Sha256Cols, + Sha256CompressCols, Sha256CompressionWitness, Sha256RoundCols, generate_sha256_compression_witness, u32_to_bits_le, + u32_to_u16_limbs_le, +}; + +pub fn sha256_compress_trace_row( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Vec { + let witness = generate_sha256_compression_witness(h_in, block); + sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, &witness) +} + +pub fn sha256_compress_trace_row_from_witness( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) -> Vec { + let mut row = F::zero_vec(super::NUM_SHA256_COMPRESS_COLS); + let cols: &mut Sha256CompressCols = row.as_mut_slice().borrow_mut(); + fill_sha256_compress_cols(cols, flag, state_ptr, block_ptr, out_ptr, witness); + row +} + +pub fn push_sha256_compress_trace_row( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) { + let row = sha256_compress_trace_row(flag, state_ptr, block_ptr, out_ptr, h_in, block); + push_row(trace, row); +} + +pub fn push_sha256_compress_trace_row_from_witness( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + let row = sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, witness); + push_row(trace, row); +} + +fn push_row(trace: &mut TableTrace, row: Vec) { + debug_assert_eq!(trace.columns.len(), row.len()); + for (column, value) in trace.columns.iter_mut().zip(row) { + column.push(value); + } +} + +pub fn fill_sha256_compress_cols( + cols: &mut Sha256CompressCols, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + cols.flag = flag; + cols.state_ptr = state_ptr; + cols.block_ptr = block_ptr; + cols.out_ptr = out_ptr; + + for (dst, &word) in cols.block_limbs.iter_mut().zip(&witness.block) { + *dst = word_limbs(word); + } + + fill_sha256_air_cols(&mut cols.sha, witness); +} + +pub fn fill_sha256_air_cols(cols: &mut Sha256Cols, witness: &Sha256CompressionWitness) { + for i in 0..SHA256_STATE_WORDS { + cols.h_in[i] = word_limbs(witness.h_in[i]); + cols.h_out[i] = word_limbs(witness.h_out[i]); + } + + for i in 0..(4 + SHA256_COMPRESS_ROUNDS) { + cols.a_chain[i] = word_bits(witness.a_chain[i]); + cols.e_chain[i] = word_bits(witness.e_chain[i]); + } + + for i in 0..SHA256_COMPRESS_ROUNDS { + cols.w[i] = word_bits(witness.w[i]); + cols.rounds[i] = Sha256RoundCols { + sigma1_e: word_limbs(witness.rounds[i].sigma1_e), + ch: word_limbs(witness.rounds[i].ch), + tmp1: word_limbs(witness.rounds[i].tmp1), + t1: word_limbs(witness.rounds[i].t1), + sigma0_a: word_limbs(witness.rounds[i].sigma0_a), + maj: word_limbs(witness.rounds[i].maj), + }; + } + + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + cols.sched_sigma0[i] = word_limbs(witness.sched_sigma0[i]); + cols.sched_sigma1[i] = word_limbs(witness.sched_sigma1[i]); + cols.sched_tmp[i] = word_limbs(witness.sched_tmp[i]); + } +} + +fn word_limbs(word: u32) -> [F; 2] { + u32_to_u16_limbs_le(word).map(|limb| F::from_usize(usize::from(limb))) +} + +fn word_bits(word: u32) -> [F; 32] { + u32_to_bits_le(word).map(F::from_bool) +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 55be30e28..cefb64a54 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -3,8 +3,13 @@ use backend::*; use crate::execution::memory::MemoryAccess; use crate::*; -pub const N_TABLES: usize = 3; -pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; +pub const N_TABLES: usize = 4; +pub const ALL_TABLES: [Table; N_TABLES] = [ + Table::execution(), + Table::extension_op(), + Table::poseidon16(), + Table::sha256_compress(), +]; pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -13,6 +18,7 @@ pub enum Table { Execution(ExecutionTable), ExtensionOp(ExtensionOpPrecompile), Poseidon16(Poseidon16Precompile), + Sha256Compress(Sha256CompressPrecompile), } #[macro_export] @@ -22,6 +28,7 @@ macro_rules! delegate_to_inner { match $self { Self::ExtensionOp(p) => p.$method($($($arg),*)?), Self::Poseidon16(p) => p.$method($($($arg),*)?), + Self::Sha256Compress(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), } }; @@ -30,6 +37,7 @@ macro_rules! delegate_to_inner { match $self { Table::ExtensionOp(p) => $macro_name!(p), Table::Poseidon16(p) => $macro_name!(p), + Table::Sha256Compress(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), } }; @@ -45,6 +53,9 @@ impl Table { pub const fn poseidon16() -> Self { Self::Poseidon16(Poseidon16Precompile) } + pub const fn sha256_compress() -> Self { + Self::Sha256Compress(Sha256CompressPrecompile) + } pub fn embed(&self) -> PF { PF::from_usize(self.index()) } @@ -69,8 +80,8 @@ impl TableT for Table { fn bus(&self) -> Bus { delegate_to_inner!(self, bus) } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec> { - delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr) + fn padding_row(&self, padding: &PaddingMemory) -> Vec> { + delegate_to_inner!(self, padding_row, padding) } fn execute( &self, diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index cbb773c61..cd650f194 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -46,6 +46,15 @@ pub struct Bus { pub data: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PaddingMemory { + pub zero_vec_ptr: usize, + pub null_poseidon_16_hash_ptr: usize, + pub sha256_state_ptr: usize, + pub sha256_block_ptr: usize, + pub sha256_out_ptr: usize, +} + #[derive(Debug, Default)] pub struct TableTrace { pub columns: Vec>, @@ -126,7 +135,7 @@ pub trait TableT: Air { fn table(&self) -> Table; fn lookups(&self) -> Vec; fn bus(&self) -> Bus; - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec; + fn padding_row(&self, padding: &PaddingMemory) -> Vec; fn execute( &self, arg_a: F, diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 3180d23ab..77571c5d7 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -5,18 +5,18 @@ mod compilation; mod type_1_aggregation; mod type_2_aggregation; -use backend::{Evaluation, Proof, ProofError, RawProof}; +use backend::{Evaluation, Proof, ProofError, RawProof, VerifierState}; pub use compilation::{ MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, NUM_REPEATED_ONES, PREAMBLE_MEMORY_LEN, ZERO_VEC_LEN, get_aggregation_bytecode, init_aggregation_bytecode, }; -use lean_prover::verify_execution::verify_execution; +use lean_prover::verify_execution::verify; use lean_vm::{DIGEST_LEN, EF, F}; pub use type_1_aggregation::{TypeOneInfo, TypeOneMultiSignature, aggregate_type_1, verify_type_1}; pub use type_2_aggregation::{ TypeTwoMultiSignature, merge_many_type_1, split_type_2, split_type_2_by_msg, verify_type_2, }; -use utils::poseidon_compress_slice; +use utils::{get_poseidon16, poseidon_compress_slice}; #[allow(missing_debug_implementations)] pub struct InnerVerified { @@ -29,7 +29,9 @@ pub struct InnerVerified { pub(crate) fn verify_inner(input_data: Vec, proof: Proof) -> Result { let input_data_hash = poseidon_compress_slice(&input_data, true); let bytecode = get_aggregation_bytecode(); - let (verif, raw_proof) = verify_execution(bytecode, &input_data_hash, proof)?; + let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone())?; + let verif = verify(bytecode, &input_data_hash, &mut verifier_state)?; + let raw_proof = verifier_state.into_raw_proof(); Ok(InnerVerified { input_data, input_data_hash, diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index e715af3c3..6320be630 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -1,3 +1,4 @@ +use backend::merkle::Sha256Digest; use backend::*; use lean_vm::{ ALL_TABLES, COL_PC, CommittedStatements, ENDING_PC, MIN_LOG_MEMORY_SIZE, MIN_LOG_N_ROWS_PER_TABLE, @@ -31,9 +32,9 @@ Stacking of various (multilinear) polynomials into a single -big- (multilinear) */ #[derive(Debug)] -pub struct StackedPcsWitness { +pub struct StackedPcsWitness { pub stacked_n_vars: VarCount, - pub inner_witness: Witness, + pub inner_witness: InnerWitness, pub global_polynomial: MleOwned, } @@ -96,13 +97,49 @@ pub fn stacked_pcs_global_statements( #[instrument(skip_all)] pub fn stack_polynomials_and_commit( - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, memory: &[F], memory_acc: &[F], bytecode_acc: &[F], traces: &BTreeMap, -) -> StackedPcsWitness { +) -> StackedPcsWitness> { + stack_polynomials_and_commit_with( + whir_config_builder, + memory, + memory_acc, + bytecode_acc, + traces, + |whir_config, global_polynomial, offset| whir_config.commit(prover_state, global_polynomial, offset), + ) +} + +pub fn stack_polynomials_and_commit_sha2( + prover_state: &mut impl FSProver, + whir_config_builder: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, +) -> StackedPcsWitness> { + stack_polynomials_and_commit_with( + whir_config_builder, + memory, + memory_acc, + bytecode_acc, + traces, + |whir_config, global_polynomial, offset| whir_config.commit2(prover_state, global_polynomial, offset), + ) +} + +fn stack_polynomials_and_commit_with( + whir_config_builder: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + commit: impl FnOnce(&WhirConfig, &MleOwned, usize) -> InnerWitness, +) -> StackedPcsWitness { assert_eq!(memory.len(), memory_acc.len()); let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); let tables_heights_sorted = sort_tables_by_height(&tables_heights); @@ -146,8 +183,8 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); - let inner_witness = - WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial, offset); + let whir_config = WhirConfig::new(whir_config_builder, stacked_n_vars); + let inner_witness = commit(&whir_config, &global_polynomial, offset); StackedPcsWitness { stacked_n_vars, inner_witness, @@ -157,11 +194,43 @@ pub fn stack_polynomials_and_commit( pub fn stacked_pcs_parse_commitment( whir_config_builder: &WhirConfigBuilder, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier, log_memory: usize, log_bytecode: usize, tables_heights: &BTreeMap, ) -> Result, ProofError> { + stacked_pcs_parse_commitment_generic( + whir_config_builder, + verifier_state, + log_memory, + log_bytecode, + tables_heights, + ) +} + +pub fn stacked_pcs_parse_commitment_sha2( + whir_config_builder: &WhirConfigBuilder, + verifier_state: &mut impl FSVerifier, + log_memory: usize, + log_bytecode: usize, + tables_heights: &BTreeMap, +) -> Result, ProofError> { + stacked_pcs_parse_commitment_generic( + whir_config_builder, + verifier_state, + log_memory, + log_bytecode, + tables_heights, + ) +} + +pub fn stacked_pcs_parse_commitment_generic( + whir_config_builder: &WhirConfigBuilder, + verifier_state: &mut impl FSVerifier, + log_memory: usize, + log_bytecode: usize, + tables_heights: &BTreeMap, +) -> Result, ProofError> { if log_memory < tables_heights[&Table::execution()] || tables_heights[&Table::execution()] < tables_heights.values().copied().max().unwrap() { @@ -176,7 +245,7 @@ pub fn stacked_pcs_parse_commitment( { return Err(ProofError::InvalidProof); } - WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment(verifier_state) + WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment_generic(verifier_state) } fn compute_stacked_n_vars( diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index c8aa8e40f..5ddb5f6b2 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -9,6 +9,10 @@ pub fn build_prover_state() -> ProverState ProverState::new(get_poseidon16().clone()) } +pub fn build_prover_state_sha2() -> ProverStateSha2 { + ProverStateSha2::new() +} + pub fn build_verifier_state( prover_state: ProverState, ) -> Result, ProofError> { diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1c2a2b0a7..75b3f961b 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -17,6 +17,7 @@ itertools.workspace = true rayon.workspace = true rand.workspace = true tracing.workspace = true +sha2 = "0.10.9" [dev-dependencies] tracing-forest.workspace = true diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb3502..0aecf8002 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -3,6 +3,7 @@ use fiat_shamir::FSProver; use field::{ExtensionField, TwoAdicField}; use poly::*; +use symetric::merkle::Sha256Digest; use tracing::{info_span, instrument}; use crate::*; @@ -13,6 +14,12 @@ pub enum MerkleData>> { Extension(RoundMerkleTree>), } +#[derive(Debug, Clone)] +pub enum MerkleData2>> { + Base(RoundMerkleTreeSha2>), + Extension(RoundMerkleTreeSha2>), +} + impl>> MerkleData { pub(crate) fn build( matrix: DftOutput, @@ -45,6 +52,34 @@ impl>> MerkleData { } } +impl>> MerkleData2 { + pub(crate) fn build(matrix: DftOutput, full_n_cols: usize, effective_n_cols: usize) -> (Self, Sha256Digest) { + match matrix { + DftOutput::Base(m) => { + let (root, prover_data) = merkle_commit_sha2::, PF>(m, full_n_cols, effective_n_cols); + (MerkleData2::Base(prover_data), root) + } + DftOutput::Extension(m) => { + let (root, prover_data) = merkle_commit_sha2::, EF>(m, full_n_cols, effective_n_cols); + (MerkleData2::Extension(prover_data), root) + } + } + } + + pub(crate) fn open(&self, index: usize) -> (MleOwned, Vec) { + match self { + MerkleData2::Base(prover_data) => { + let (leaf, proof) = merkle_open_sha2::, PF>(prover_data, index); + (MleOwned::Base(leaf), proof) + } + MerkleData2::Extension(prover_data) => { + let (leaf, proof) = merkle_open_sha2::, EF>(prover_data, index); + (MleOwned::Extension(leaf), proof) + } + } + } +} + #[derive(Debug, Clone)] pub struct Witness where @@ -55,6 +90,16 @@ where pub ood_answers: Vec, } +#[derive(Debug, Clone)] +pub struct Witness2 +where + EF: ExtensionField>, +{ + pub prover_data: MerkleData2, + pub ood_points: Vec, + pub ood_answers: Vec, +} + impl WhirConfig where EF: ExtensionField>, @@ -63,7 +108,7 @@ where #[instrument(skip_all)] pub fn commit( &self, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, polynomial: &MleOwned, actual_data_len: usize, // polynomial[actual_data_len..] is zero ) -> Witness { @@ -84,7 +129,7 @@ where let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, effective_n_cols); - prover_state.add_base_scalars(&root); + prover_state.add_commitment(&root); let (ood_points, ood_answers) = sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { @@ -97,4 +142,42 @@ where ood_answers, } } + + #[instrument(skip_all)] + pub fn commit2( + &self, + prover_state: &mut impl FSProver, + polynomial: &MleOwned, + actual_data_len: usize, // polynomial[actual_data_len..] is zero + ) -> Witness2 { + let n_blocks = 1usize << self.folding_factor.at_round(0); + let evals_len = 1usize << self.num_variables; + let effective_n_cols = actual_data_len.div_ceil(evals_len / n_blocks); + // DFT matrix width: skip as many zero columns as possible, aligned to packing (SIMD) + let dft_n_cols = effective_n_cols.next_multiple_of(packing_width::()).min(n_blocks); + + let folded_matrix = info_span!("FFT").in_scope(|| { + reorder_and_dft( + &polynomial.by_ref(), + self.folding_factor.at_round(0), + self.starting_log_inv_rate, + dft_n_cols, + ) + }); + + let (prover_data, root) = MerkleData2::build(folded_matrix, n_blocks, effective_n_cols); + + prover_state.add_commitment(&root); + + let (ood_points, ood_answers) = + sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { + polynomial.evaluate(point) + }); + + Witness2 { + prover_data, + ood_points, + ood_answers, + } + } } diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..681c36bbf 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,11 +8,14 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; +use field::PrimeField32; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; use rayon::prelude::*; use symetric::Compression; +use symetric::merkle::MerkleTreeSha2; +use symetric::merkle::Sha256Digest; use symetric::merkle::unpack_array; use tracing::instrument; use utils::log2_ceil_usize; @@ -23,6 +26,9 @@ use crate::Matrix; pub use symetric::DIGEST_ELEMS; pub(crate) type RoundMerkleTree = WhirMerkleTree, DIGEST_ELEMS>; +pub(crate) type RoundMerkleTreeSha2 = WhirMerkleTreeSha2>; + +use sha2::{Digest, Sha256}; #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( @@ -55,6 +61,35 @@ pub(crate) fn merkle_commit>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_commit_sha2>( + matrix: DenseMatrix, + full_n_cols: usize, + effective_n_cols: usize, +) -> (Sha256Digest, RoundMerkleTreeSha2) { + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let dim = >::DIMENSION; + let dft_base_width = matrix.width * dim; + let full_base_width = full_n_cols * dim; + let effective_base_width = effective_n_cols * dim; + let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); + let base_matrix = DenseMatrix::::new(base_values, dft_base_width); + let tree = build_merkle_tree_sha256(base_matrix, full_base_width, effective_base_width); + let root = tree.root(); + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTreeSha2>(tree) }; + (root, tree) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let tree = build_merkle_tree_sha256(matrix, full_n_cols, effective_n_cols); + let root = tree.root(); + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTreeSha2>(tree) }; + (root, tree) + } else { + unimplemented!() + } +} + #[instrument(name = "build merkle tree", skip_all)] fn build_merkle_tree_koalabear( leaf: DenseMatrix, @@ -87,6 +122,15 @@ fn build_merkle_tree_koalabear( } } +#[instrument(name = "build merkle tree sha256", skip_all)] +fn build_merkle_tree_sha256( + leaf: DenseMatrix, + full_base_width: usize, + effective_base_width: usize, +) -> RoundMerkleTreeSha2 { + WhirMerkleTreeSha2::new(leaf, full_base_width, effective_base_width) +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_open>( merkle_tree: &RoundMerkleTree, @@ -111,6 +155,28 @@ pub(crate) fn merkle_open>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_open_sha2>( + merkle_tree: &RoundMerkleTreeSha2, + index: usize, +) -> (Vec, Vec) { + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTreeSha2>(merkle_tree) }; + let (inner_leaf, proof) = merkle_tree.open(index); + let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf); + let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; + (leaf, proof) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTreeSha2>(merkle_tree) }; + let (inner_leaf, proof) = merkle_tree.open(index); + let leaf = KoalaBear::reconstitute_from_base(inner_leaf); + let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; + (leaf, proof) + } else { + unimplemented!() + } +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_verify>( merkle_root: [F; DIGEST_ELEMS], @@ -152,6 +218,28 @@ pub(crate) fn merkle_verify>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_verify_sha2>( + merkle_root: Sha256Digest, + index: usize, + dimension: Dimensions, + data: Vec, + proof: &Vec, +) -> bool { + let log_max_height = utils::log2_strict_usize(dimension.height.next_power_of_two()); + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let base_data: Vec = QuinticExtensionFieldKB::flatten_to_base(data); + sha2_merkle_verify(&merkle_root, log_max_height, index, &base_data, proof) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let base_data = KoalaBear::flatten_to_base(data); + sha2_merkle_verify(&merkle_root, log_max_height, index, &base_data, proof) + } else { + unimplemented!() + } +} + #[derive(Debug, Clone)] pub struct WhirMerkleTree { pub(crate) leaf: M, @@ -159,6 +247,14 @@ pub struct WhirMerkleTree { full_leaf_base_width: usize, } +#[derive(Debug, Clone)] +pub struct WhirMerkleTreeSha2> { + pub(crate) leaf: M, + pub(crate) tree: MerkleTreeSha2, + full_leaf_base_width: usize, + _marker: std::marker::PhantomData, +} + impl, const DIGEST_ELEMS: usize> WhirMerkleTree { @@ -286,3 +382,237 @@ where digests } + +#[instrument(name = "first digest layer", level = "debug", skip_all)] +fn sha2_first_digest_layer(h: &Sha256, matrix: &M, full_width: usize) -> Vec +where + F: PrimeField32, + M: Matrix, +{ + let height = matrix.height(); + let matrix_width = matrix.width(); + let n_trailing_zeros = full_width - matrix_width; + + (0..height) + .into_par_iter() + .map(|r| { + let mut hasher = h.clone(); + for value in matrix.row(r).unwrap() { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + for _ in 0..n_trailing_zeros { + hasher.update(F::ZERO.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }) + .collect() +} + +fn sha2_compress_pair(left: &Sha256Digest, right: &Sha256Digest) -> Sha256Digest { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() +} + +fn sha2_merkle_verify( + commit: &Sha256Digest, + log_height: usize, + mut index: usize, + opened_values: &[F], + opening_proof: &[Sha256Digest], +) -> bool { + if opening_proof.len() != log_height { + return false; + } + + let mut hasher = Sha256::new(); + for value in opened_values { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + let mut root: Sha256Digest = digest[..16].try_into().unwrap(); + + for sibling in opening_proof { + let (left, right) = if index & 1 == 0 { + (root, *sibling) + } else { + (*sibling, root) + }; + root = sha2_compress_pair(&left, &right); + index >>= 1; + } + + commit == &root +} + +impl> WhirMerkleTreeSha2 { + #[instrument(name = "build merkle tree", skip_all)] + pub fn new(leaf: M, full_leaf_base_width: usize, effective_base_width: usize) -> Self { + assert!(leaf.height().is_power_of_two()); + assert!(leaf.width() <= full_leaf_base_width); + assert!(effective_base_width <= full_leaf_base_width); + + let first_layer = sha2_first_digest_layer(&Sha256::new(), &leaf, full_leaf_base_width); + let mut digest_layers = vec![first_layer]; + + tracing::debug_span!("asc").in_scope(|| { + while digest_layers.last().unwrap().len() > 1 { + let prev_layer = digest_layers.last().unwrap(); + assert!(prev_layer.len().is_multiple_of(2)); + let next_layer = prev_layer + .par_chunks_exact(2) + .map(|pair| sha2_compress_pair(&pair[0], &pair[1])) + .collect(); + digest_layers.push(next_layer); + } + }); + + let tree = MerkleTreeSha2 { digest_layers }; + Self { + leaf, + tree, + full_leaf_base_width, + _marker: std::marker::PhantomData, + } + } + + #[must_use] + pub fn root(&self) -> Sha256Digest { + self.tree.root() + } + + pub fn open(&self, index: usize) -> (Vec, Vec) { + let log_height = log2_ceil_usize(self.leaf.height()); + let mut opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); + opening.resize(self.full_leaf_base_width, F::default()); + let proof = self.tree.open_siblings(index, log_height); + (opening, proof) + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use field::integers::QuotientMap; + + use super::*; + + fn pseudo_random_koalabear(index: usize) -> KoalaBear { + let mut x = index as u64; + x = x.wrapping_add(0x9e37_79b9_7f4a_7c15); + x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9); + x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb); + KoalaBear::from_int(x ^ (x >> 31)) + } + + #[test] + fn whir_sha2_merkle_tree_builds_expected_layers() { + let width = 8; + let height = 4; + let values: Vec<_> = (0..width * height).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + let (root, tree) = merkle_commit_sha2::(matrix, width, width); + + assert_eq!(tree.tree.digest_layers.len(), log2_ceil_usize(height) + 1); + assert_eq!(tree.tree.digest_layers[0].len(), height); + assert_eq!(tree.tree.digest_layers[1].len(), height / 2); + assert_eq!(tree.tree.digest_layers[2].len(), 1); + assert_eq!(tree.root(), root); + } + + #[test] + fn whir_sha2_merkle_opening_verifies() { + let width = 8; + let height = 4; + let index = 2; + let values: Vec<_> = (0..width * height).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + let (root, tree) = merkle_commit_sha2::(matrix, width, width); + let (leaf, proof) = merkle_open_sha2::(&tree, index); + + assert!(merkle_verify_sha2::( + root, + index, + Dimensions { width, height }, + leaf, + &proof, + )); + } + + use tracing_forest::{ForestLayer, util::LevelFilter}; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; + + pub fn init_tracing() { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + let _ = Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .try_init(); + } + + #[test] + fn bench_merkle_commit_koalabear_width_32_log_sizes() { + let folding_factor = 7; + let width = 1usize << folding_factor; + + init_tracing(); + + for log_size in 20..=26 { + let n_values = 1usize << log_size; + let height = n_values / width; + let values: Vec<_> = (0..n_values).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + assert_eq!(matrix.width(), width); + assert_eq!(matrix.height(), height); + + let start = Instant::now(); + let (_root, prover_data) = merkle_commit::(matrix, width, width); + let elapsed = start.elapsed(); + + assert_eq!(prover_data.leaf.width(), width); + assert_eq!(prover_data.leaf.height(), height); + assert_eq!(prover_data.full_leaf_base_width, width); + + let log_height = log2_ceil_usize(height); + println!("poseidon log_size={log_size}, log_height={log_height}, width={width}, time={elapsed:?}",); + } + } + + #[test] + fn bench_sha256_merkle_commit_koalabear_width_32_log_sizes() { + let folding_factor = 7; + let width = 1usize << folding_factor; + + init_tracing(); + + for log_size in 20..=26 { + let n_values = 1usize << log_size; + let height = n_values / width; + let values: Vec<_> = (0..n_values).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + assert_eq!(matrix.width(), width); + assert_eq!(matrix.height(), height); + + let start = Instant::now(); + let (root, tree) = merkle_commit_sha2::(matrix, width, width); + let elapsed = start.elapsed(); + + assert_eq!(tree.tree.digest_layers[0].len(), height); + assert_eq!(tree.tree.digest_layers.last().unwrap()[0], root); + + let log_height = log2_ceil_usize(height); + println!("sha256 log={log_size}, log_height={log_height}, width={width}, time={elapsed:?}",); + } + } +} diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 8b8b4031c..5f7eb110f 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -7,10 +7,41 @@ use field::{ExtensionField, Field, TwoAdicField}; use poly::*; use rayon::prelude::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; +use symetric::merkle::Sha256Digest; use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; +#[instrument( + skip_all, + fields(hash = "poseidon", round = round_index) +)] +fn merkle_data_build( + folded_matrix: DftOutput, + full_n_cols: usize, + round_index: usize, +) -> (MerkleData, [PF; DIGEST_ELEMS]) +where + EF: ExtensionField>, +{ + MerkleData::build(folded_matrix, full_n_cols, full_n_cols) +} + +#[instrument( + skip_all, + fields(hash = "sha2", round = round_index) +)] +fn merkle_data_build_sha2( + folded_matrix: DftOutput, + full_n_cols: usize, + round_index: usize, +) -> (MerkleData2, Sha256Digest) +where + EF: ExtensionField>, +{ + MerkleData2::build(folded_matrix, full_n_cols, full_n_cols) +} + impl WhirConfig where EF: ExtensionField>, @@ -36,7 +67,7 @@ where #[instrument(name = "WHIR prove", skip_all)] pub fn prove( &self, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, statement: Vec>, witness: Witness, polynomial: &MleRef<'_, EF>, @@ -55,10 +86,32 @@ where MultilinearPoint(round_state.randomness_vec) } + #[instrument(name = "WHIR prove", skip_all)] + pub fn prove2( + &self, + prover_state: &mut impl FSProver, + statement: Vec>, + witness: Witness2, + polynomial: &MleRef<'_, EF>, + ) -> MultilinearPoint { + assert!(self.validate_parameters()); + // assert!(self.validate_witness(&witness, polynomial)); + self.validate_statement(&statement); + + let mut round_state = + RoundState2::initialize_first_round_state(self, prover_state, statement, witness, polynomial).unwrap(); + + for round in 0..=self.n_rounds() { + self.round2(round, prover_state, &mut round_state).unwrap(); + } + + MultilinearPoint(round_state.randomness_vec) + } + fn round( &self, round_index: usize, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, round_state: &mut RoundState, ) -> ProofResult<()> { let folded_evaluations = &round_state.sumcheck_prover.evals; @@ -88,9 +141,9 @@ where }); let full = 1 << folding_factor_next; - let (prover_data, root) = MerkleData::build(folded_matrix, full, full); + let (prover_data, root) = merkle_data_build(folded_matrix, full, round_index); - prover_state.add_base_scalars(&root); + prover_state.add_commitment(&root); // Handle OOD (Out-Of-Domain) samples let (ood_points, ood_answers) = @@ -178,10 +231,136 @@ where Ok(()) } + fn round2( + &self, + round_index: usize, + prover_state: &mut impl FSProver, + round_state: &mut RoundState2, + ) -> ProofResult<()> { + let folded_evaluations = &round_state.sumcheck_prover.evals; + let num_variables = self.num_variables - self.folding_factor.total_number(round_index); + + // Base case: final round reached + if round_index == self.n_rounds() { + return self.final_round2(round_index, prover_state, round_state); + } + + let round_params = &self.round_parameters[round_index]; + + // Compute the folding factors for later use + let folding_factor_next = self.folding_factor.at_round(round_index + 1); + + // Compute polynomial evaluations and build Merkle tree + let domain_reduction = 1 << self.rs_reduction_factor(round_index); + let new_domain_size = round_state.domain_size / domain_reduction; + let inv_rate = new_domain_size >> num_variables; + let folded_matrix = info_span!("FFT").in_scope(|| { + reorder_and_dft( + &folded_evaluations.by_ref(), + folding_factor_next, + log2_strict_usize(inv_rate), + 1 << folding_factor_next, + ) + }); + + let full = 1 << folding_factor_next; + let (prover_data, root) = merkle_data_build_sha2(folded_matrix, full, round_index); + + prover_state.add_commitment(&root); + + // Handle OOD (Out-Of-Domain) samples + let (ood_points, ood_answers) = + sample_ood_points::(prover_state, round_params.ood_samples, num_variables, |point| { + info_span!("ood evaluation").in_scope(|| folded_evaluations.evaluate(point)) + }); + + prover_state.pow_grinding(round_params.query_pow_bits); + + let (ood_challenges, stir_challenges, stir_challenges_indexes) = self.compute_stir_queries2( + prover_state, + round_state, + num_variables, + round_params, + &ood_points, + round_index, + )?; + + let folding_randomness = round_state.folding_randomness( + self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, + ); + + let stir_evaluations = if let Some(data_b) = &round_state.commitment_merkle_prover_data_b { + let answers_a = open_merkle_tree_at_challenges2( + &round_state.merkle_prover_data, + prover_state, + &stir_challenges_indexes, + ); + let answers_b = open_merkle_tree_at_challenges2(data_b, prover_state, &stir_challenges_indexes); + let mut stir_evaluations = Vec::new(); + for (answer_a, answer_b) in answers_a.iter().zip(&answers_b) { + let vars_a = answer_a.by_ref().n_vars(); + let vars_b = answer_b.by_ref().n_vars(); + let a_trunc = folding_randomness[1..].to_vec(); + let eval_a = answer_a.evaluate(&MultilinearPoint(a_trunc)); + let b_trunc = folding_randomness[vars_a - vars_b + 1..].to_vec(); + let eval_b = answer_b.evaluate(&MultilinearPoint(b_trunc)); + let last_fold_rand_a = folding_randomness[0]; + let last_fold_rand_b = folding_randomness[..vars_a - vars_b + 1] + .iter() + .map(|&x| EF::ONE - x) + .product::(); + stir_evaluations.push(eval_a * last_fold_rand_a + eval_b * last_fold_rand_b); + } + + stir_evaluations + } else { + open_merkle_tree_at_challenges2(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) + .iter() + .map(|answer| answer.evaluate(&folding_randomness)) + .collect() + }; + + // Randomness for combination + let combination_randomness_gen: EF = prover_state.sample(); + let ood_combination_randomness: Vec<_> = combination_randomness_gen.powers().collect_n(ood_challenges.len()); + round_state + .sumcheck_prover + .add_new_equality(&ood_challenges, &ood_answers, &ood_combination_randomness); + let stir_combination_randomness = combination_randomness_gen + .powers() + .skip(ood_challenges.len()) + .take(stir_challenges.len()) + .collect::>(); + + round_state.sumcheck_prover.add_new_base_equality( + &stir_challenges, + &stir_evaluations, + &stir_combination_randomness, + ); + + let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( + None, + prover_state, + folding_factor_next, + round_params.folding_pow_bits, + ); + + round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); + + // Update round state + round_state.domain_size = new_domain_size; + round_state.next_domain_gen = + PF::::two_adic_generator(log2_strict_usize(new_domain_size) - folding_factor_next); + round_state.merkle_prover_data = prover_data; + round_state.commitment_merkle_prover_data_b = None; + + Ok(()) + } + fn final_round( &self, round_index: usize, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, round_state: &mut RoundState, ) -> ProofResult<()> { // Convert evaluations to coefficient form and send to the verifier. @@ -246,6 +425,74 @@ where Ok(()) } + fn final_round2( + &self, + round_index: usize, + prover_state: &mut impl FSProver, + round_state: &mut RoundState2, + ) -> ProofResult<()> { + // Convert evaluations to coefficient form and send to the verifier. + let mut coeffs = match &round_state.sumcheck_prover.evals { + MleOwned::Extension(evals) => evals.clone(), + MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), + _ => unreachable!(), + }; + evals_to_coeffs(&mut coeffs); + prover_state.add_extension_scalars(&coeffs); + + prover_state.pow_grinding(self.final_query_pow_bits); + + // Final verifier queries and answers. The indices are over the folded domain. + let final_challenge_indexes = get_challenge_stir_queries( + // The size of the original domain before folding + round_state.domain_size >> self.folding_factor.at_round(round_index), + self.final_queries, + prover_state, + ); + + let mut base_paths = Vec::new(); + let mut ext_paths = Vec::new(); + for challenge in final_challenge_indexes { + let (answer, sibling_hashes) = round_state.merkle_prover_data.open(challenge); + + match answer { + MleOwned::Base(leaf) => { + base_paths.push(MerklePath { + leaf_data: leaf, + sibling_hashes, + leaf_index: challenge, + }); + } + MleOwned::Extension(leaf) => { + ext_paths.push(MerklePath { + leaf_data: leaf, + sibling_hashes, + leaf_index: challenge, + }); + } + _ => unreachable!(), + } + } + if !base_paths.is_empty() { + prover_state.hint_merkle_paths_base(base_paths); + } + if !ext_paths.is_empty() { + prover_state.hint_merkle_paths_extension(ext_paths); + } + + // Run final sumcheck if required + if self.final_sumcheck_rounds > 0 { + let final_folding_randomness = + round_state + .sumcheck_prover + .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + + round_state.randomness_vec.extend(final_folding_randomness.0); + } + + Ok(()) + } + #[allow(clippy::type_complexity)] fn compute_stir_queries( &self, @@ -274,11 +521,82 @@ where Ok((ood_challenges, stir_challenges, stir_challenges_indexes)) } + + #[allow(clippy::type_complexity)] + fn compute_stir_queries2( + &self, + prover_state: &mut impl FSProver, + round_state: &RoundState2, + num_variables: usize, + round_params: &RoundConfig, + ood_points: &[EF], + round_index: usize, + ) -> ProofResult<(Vec>, Vec>>, Vec)> { + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain_size >> self.folding_factor.at_round(round_index), + round_params.num_queries, + prover_state, + ); + + let domain_scaled_gen = round_state.next_domain_gen; + let ood_challenges = ood_points + .iter() + .map(|univariate| MultilinearPoint::expand_from_univariate(*univariate, num_variables)) + .collect(); + let stir_challenges = stir_challenges_indexes + .iter() + .map(|i| MultilinearPoint::expand_from_univariate(domain_scaled_gen.exp_u64(*i as u64), num_variables)) + .collect(); + + Ok((ood_challenges, stir_challenges, stir_challenges_indexes)) + } } fn open_merkle_tree_at_challenges>>( merkle_tree: &MerkleData, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, + stir_challenges_indexes: &[usize], +) -> Vec> { + let mut answers = Vec::new(); + let mut base_paths = Vec::new(); + let mut ext_paths = Vec::new(); + + for &challenge in stir_challenges_indexes { + let (answer, sibling_hashes) = merkle_tree.open(challenge); + + match &answer { + MleOwned::Base(leaf) => { + base_paths.push(MerklePath { + leaf_data: leaf.clone(), + sibling_hashes, + leaf_index: challenge, + }); + } + MleOwned::Extension(leaf) => { + ext_paths.push(MerklePath { + leaf_data: leaf.clone(), + sibling_hashes, + leaf_index: challenge, + }); + } + _ => unreachable!(), + } + answers.push(answer); + } + + if !base_paths.is_empty() { + prover_state.hint_merkle_paths_base(base_paths); + } + if !ext_paths.is_empty() { + prover_state.hint_merkle_paths_extension(ext_paths); + } + + answers +} + +fn open_merkle_tree_at_challenges2>>( + merkle_tree: &MerkleData2, + prover_state: &mut impl FSProver, stir_challenges_indexes: &[usize], ) -> Vec> { let mut answers = Vec::new(); @@ -457,6 +775,19 @@ where randomness_vec: Vec, } +#[derive(Debug)] +pub(crate) struct RoundState2 +where + EF: ExtensionField>, +{ + domain_size: usize, + next_domain_gen: PF, + sumcheck_prover: SumcheckSingle, + commitment_merkle_prover_data_b: Option>, + merkle_prover_data: MerkleData2, + randomness_vec: Vec, +} + #[allow(clippy::mismatching_type_param_order)] impl RoundState where @@ -512,6 +843,61 @@ where } } +#[allow(clippy::mismatching_type_param_order)] +impl RoundState2 +where + EF: ExtensionField>, + PF: TwoAdicField, +{ + pub(crate) fn initialize_first_round_state( + prover: &WhirConfig, + prover_state: &mut impl FSProver, + mut statement: Vec>, + witness: Witness2, + polynomial: &MleRef<'_, EF>, + ) -> ProofResult { + let ood_statements = witness + .ood_points + .into_iter() + .zip(witness.ood_answers) + .map(|(point, evaluation)| { + SparseStatement::dense( + MultilinearPoint::expand_from_univariate(point, prover.num_variables), + evaluation, + ) + }) + .collect::>(); + + statement.splice(0..0, ood_statements); + + let combination_randomness_gen: EF = prover_state.sample(); + + let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds( + polynomial, + &statement, + combination_randomness_gen, + prover_state, + prover.folding_factor.at_round(0), + prover.starting_folding_pow_bits, + ); + + Ok(Self { + domain_size: prover.starting_domain_size(), + next_domain_gen: PF::::two_adic_generator( + log2_strict_usize(prover.starting_domain_size()) - prover.folding_factor.at_round(0), + ), + sumcheck_prover, + merkle_prover_data: witness.prover_data, + commitment_merkle_prover_data_b: None, + randomness_vec: folding_randomness.0.clone(), + }) + } + + fn folding_randomness(&self, folding_factor: usize) -> MultilinearPoint { + MultilinearPoint(self.randomness_vec[self.randomness_vec.len() - folding_factor..].to_vec()) + } +} + #[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) where diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b287..6d78e2c4f 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -5,21 +5,22 @@ use std::{fmt::Debug, marker::PhantomData}; use fiat_shamir::{FSVerifier, ProofError, ProofResult, pack_scalars_to_extension}; use field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use poly::*; +use symetric::merkle::Sha256Digest; use crate::*; #[derive(Debug, Clone)] -pub struct ParsedCommitment> { +pub struct ParsedCommitment, Digest = [PF; DIGEST_ELEMS]> { pub num_variables: usize, - pub root: [PF; DIGEST_ELEMS], + pub root: Digest, pub ood_points: Vec, pub ood_answers: Vec, pub base_field: PhantomData, } -impl> ParsedCommitment { +impl, Digest: Clone> ParsedCommitment { pub fn parse( - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier, num_variables: usize, ood_samples: usize, ) -> ProofResult @@ -28,7 +29,7 @@ impl> ParsedCommitment { EF: ExtensionField + TwoAdicField, EF: ExtensionField>, { - let root = verifier_state.next_base_scalars_vec(DIGEST_ELEMS)?.try_into().unwrap(); + let root = verifier_state.next_commitment()?; let mut ood_points = vec![]; let ood_answers = if ood_samples > 0 { ood_points = verifier_state.sample_vec(ood_samples); @@ -59,19 +60,102 @@ impl> ParsedCommitment { } } +pub trait WhirVerifierDigest: Clone + Sized +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier; +} + +impl WhirVerifierDigest for [PF; DIGEST_ELEMS] +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier, + { + config.verify_stir_challenges(verifier_state, params, commitment, folding_randomness, round_index) + } +} + +impl WhirVerifierDigest for Sha256Digest +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier, + { + config.verify_stir_challenges2(verifier_state, params, commitment, folding_randomness, round_index) + } +} + impl WhirConfig where EF: TwoAdicField + ExtensionField>, { pub fn parse_commitment( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, ) -> ProofResult> where EF: ExtensionField, { ParsedCommitment::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) } + + pub fn parse_commitment_generic( + &self, + verifier_state: &mut impl FSVerifier, + ) -> ProofResult> + where + EF: ExtensionField, + { + ParsedCommitment::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) + } + + pub fn parse_commitment2( + &self, + verifier_state: &mut impl FSVerifier, + ) -> ProofResult> + where + EF: ExtensionField, + { + self.parse_commitment_generic(verifier_state) + } } impl WhirConfig @@ -80,28 +164,27 @@ where PF: TwoAdicField, { #[allow(clippy::too_many_lines)] - pub fn verify( + pub fn verify( &self, - verifier_state: &mut impl FSVerifier, - parsed_commitment: &ParsedCommitment, + verifier_state: &mut V, + parsed_commitment: &ParsedCommitment, statement: Vec>, ) -> ProofResult> where F: TwoAdicField + ExtensionField>, EF: ExtensionField, + Digest: WhirVerifierDigest, + V: FSVerifier, { statement .iter() .for_each(|c| assert_eq!(c.total_num_variables, parsed_commitment.num_variables)); - // During the rounds we collect constraints, combination randomness, folding randomness - // and we update the claimed sum of constraint evaluation. let mut round_constraints = Vec::new(); let mut round_folding_randomness = Vec::new(); let mut claimed_sum = EF::ZERO; let mut prev_commitment = parsed_commitment.clone(); - // Combine OODS and statement constraints to claimed_sum let constraints: Vec<_> = prev_commitment .oods_constraints() .into_iter() @@ -110,7 +193,6 @@ where let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; round_constraints.push((combination_randomness, constraints)); - // Initial sumcheck let folding_randomness = verify_sumcheck_rounds::( verifier_state, &mut claimed_sum, @@ -120,15 +202,15 @@ where round_folding_randomness.push(folding_randomness); for round_index in 0..self.n_rounds() { - // Fetch round parameters from config let round_params = &self.round_parameters[round_index]; + let new_commitment = ParsedCommitment::::parse( + verifier_state, + round_params.num_variables, + round_params.ood_samples, + )?; - // Receive commitment to the folded polynomial (likely encoded at higher expansion) - let new_commitment = - ParsedCommitment::::parse(verifier_state, round_params.num_variables, round_params.ood_samples)?; - - // Verify in-domain challenges on the previous commitment. - let stir_constraints = self.verify_stir_challenges( + let stir_constraints = Digest::verify_stir_challenges( + self, verifier_state, round_params, &prev_commitment, @@ -136,7 +218,6 @@ where round_index, )?; - // Add out-of-domain and in-domain constraints to claimed_sum let constraints: Vec> = new_commitment .oods_constraints() .into_iter() @@ -154,17 +235,14 @@ where )?; round_folding_randomness.push(folding_randomness); - - // Update round parameters prev_commitment = new_commitment; } - // In the final round we receive the full polynomial instead of a commitment. let n_final_coeffs = 1 << self.n_vars_of_final_polynomial(); let final_coefficients = verifier_state.next_extension_scalars_vec(n_final_coeffs)?; - // Verify in-domain challenges on the previous commitment. - let stir_constraints = self.verify_stir_challenges( + let stir_constraints = Digest::verify_stir_challenges( + self, verifier_state, &self.final_round_config(), &prev_commitment, @@ -172,7 +250,6 @@ where self.n_rounds(), )?; - // Verify stir constraints directly on final polynomial stir_constraints .iter() .all(|c| verify_constraint_coeffs(c, &final_coefficients)) @@ -183,7 +260,6 @@ where verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.final_sumcheck_rounds, 0)?; round_folding_randomness.push(final_sumcheck_randomness.clone()); - // Compute folding randomness across all rounds. let folding_randomness = MultilinearPoint( round_folding_randomness .into_iter() @@ -193,7 +269,6 @@ where let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); - // Check the final sumcheck evaluation (coefficient form, reversed point) let mut reversed_point = final_sumcheck_randomness.0.clone(); reversed_point.reverse(); let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); @@ -226,7 +301,7 @@ where fn verify_stir_challenges( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, params: &RoundConfig, commitment: &ParsedCommitment, folding_randomness: &MultilinearPoint, @@ -284,10 +359,64 @@ where Ok(stir_constraints) } + fn verify_stir_challenges2( + &self, + verifier_state: &mut impl FSVerifier, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + F: Field + ExtensionField>, + EF: ExtensionField, + { + let leafs_base_field = round_index == 0; + + verifier_state.check_pow_grinding(params.query_pow_bits)?; + + let stir_challenges_indexes = get_challenge_stir_queries( + params.domain_size >> params.folding_factor, + params.num_queries, + verifier_state, + ); + + let dimensions = vec![Dimensions { + height: params.domain_size >> params.folding_factor, + width: 1 << params.folding_factor, + }]; + let answers = self.verify_merkle_proof2::( + verifier_state, + &commitment.root, + &stir_challenges_indexes, + &dimensions, + leafs_base_field, + )?; + + let folds: Vec<_> = answers + .into_iter() + .map(|answers| answers.evaluate(folding_randomness)) + .collect(); + + let stir_constraints = stir_challenges_indexes + .iter() + .map(|&index| params.folded_domain_gen.exp_u64(index as u64)) + .zip(&folds) + .map(|(point, &value)| { + SparseStatement::dense( + MultilinearPoint::expand_from_univariate(EF::from(point), params.num_variables), + value, + ) + }) + .collect(); + + Ok(stir_constraints) + } + #[allow(clippy::too_many_arguments)] fn verify_merkle_proof( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, root: &[PF; DIGEST_ELEMS], indices: &[usize], dimensions: &[Dimensions], @@ -341,6 +470,62 @@ where Ok(res) } + fn verify_merkle_proof2( + &self, + verifier_state: &mut impl FSVerifier, + root: &Sha256Digest, + indices: &[usize], + dimensions: &[Dimensions], + leafs_base_field: bool, + ) -> ProofResult>> + where + F: Field + ExtensionField>, + EF: ExtensionField, + { + let res = if leafs_base_field { + let mut answers = Vec::>::new(); + let mut merkle_proofs = Vec::new(); + + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, F>(&opening.leaf_data)); + merkle_proofs.push(opening.path); + } + + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify_sha2::, F>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) + { + return Err(ProofError::InvalidProof); + } + } + + answers + .into_iter() + .map(|inner| inner.iter().map(|&f_el| f_el.into()).collect()) + .collect() + } else { + let mut answers = vec![]; + let mut merkle_proofs = Vec::new(); + + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, EF>(&opening.leaf_data)); + merkle_proofs.push(opening.path); + } + + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify_sha2::, EF>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) + { + return Err(ProofError::InvalidProof); + } + } + + answers + }; + + Ok(res) + } + fn eval_constraints_poly( &self, constraints: &[(Vec, Vec>)], diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 105460ca3..cd6b15d81 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -128,7 +128,7 @@ fn test_run_whir() { let parsed_commitment = params.parse_commitment::(&mut verifier_state).unwrap(); params - .verify::(&mut verifier_state, &parsed_commitment, statement.clone()) + .verify(&mut verifier_state, &parsed_commitment, statement.clone()) .unwrap(); println!(