From ecab2ca82a5749369343eaff4807f1820dc9dd0d Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 15:21:42 +0300 Subject: [PATCH 01/19] bench soft --- Cargo.lock | 1 + crates/whir/Cargo.toml | 1 + crates/whir/src/merkle.rs | 144 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index d938586b8..209470eae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -710,6 +710,7 @@ dependencies = [ "mt-utils", "rand", "rayon", + "sha2", "system-info", "tracing", "tracing-forest", diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1c2a2b0a7..696459af9 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -19,5 +19,6 @@ rand.workspace = true tracing.workspace = true [dev-dependencies] +sha2 = "0.10.9" tracing-forest.workspace = true tracing-subscriber.workspace = true diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..94fd5f627 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -286,3 +286,147 @@ where digests } + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use field::PrimeField32; + use field::integers::QuotientMap; + use rand::{RngExt, SeedableRng, rngs::StdRng}; + use sha2::{Digest, Sha256}; + + use super::*; + + type Sha256Digest = [u8; 16]; + + #[derive(Debug)] + struct Sha256MerkleTree { + digest_layers: Vec>, + } + + fn pseudo_random_koalabear(index: usize) -> KoalaBear { + let mut x = index as u64; + x = x.wrapping_add(0x9e37_79b9_7f4a_7c15); + x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9); + x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb); + KoalaBear::from_int(x ^ (x >> 31)) + } + + fn sha256_truncated(hasher: Sha256) -> Sha256Digest { + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + } + + fn sha256_hash_row(row: &[KoalaBear]) -> Sha256Digest { + let mut hasher = Sha256::new(); + for value in row { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + sha256_truncated(hasher) + } + + fn sha256_hash_pair(left: &Sha256Digest, right: &Sha256Digest) -> Sha256Digest { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + sha256_truncated(hasher) + } + + #[tracing::instrument(name = "sha256 merkle commit", skip_all)] + fn sha256_merkle_commit(matrix: DenseMatrix) -> (Sha256Digest, Sha256MerkleTree) { + let width = matrix.width(); + let height = matrix.height(); + assert!(height.is_power_of_two()); + + let first_layer: Vec<_> = tracing::info_span!("leafs") + .in_scope(|| matrix.values.par_chunks_exact(width).map(sha256_hash_row).collect()); + let mut digest_layers = vec![first_layer]; + + tracing::info_span!("asc").in_scope(|| { + while digest_layers.last().unwrap().len() > 1 { + let prev_layer = digest_layers.last().unwrap(); + assert!(prev_layer.len().is_multiple_of(2)); + let next_layer = prev_layer + .par_chunks_exact(2) + .map(|pair| sha256_hash_pair(&pair[0], &pair[1])) + .collect(); + digest_layers.push(next_layer); + } + }); + + let root = digest_layers.last().unwrap()[0]; + (root, Sha256MerkleTree { digest_layers }) + } + + use tracing_forest::{ForestLayer, util::LevelFilter}; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; + + pub fn init_tracing() { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + let _ = Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .try_init(); + } + + #[test] + fn bench_merkle_commit_koalabear_width_32_log_sizes() { + let folding_factor = 7; + let width = 1usize << folding_factor; + + init_tracing(); + + for log_size in 20..=26 { + let n_values = 1usize << log_size; + let height = n_values / width; + let values: Vec<_> = (0..n_values).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + assert_eq!(matrix.width(), width); + assert_eq!(matrix.height(), height); + + let start = Instant::now(); + let (root, prover_data) = merkle_commit::(matrix, width, width); + let elapsed = start.elapsed(); + + assert_eq!(prover_data.leaf.width(), width); + assert_eq!(prover_data.leaf.height(), height); + assert_eq!(prover_data.full_leaf_base_width, width); + + let log_height = log2_ceil_usize(height); + println!("poseidon log_size={log_size}, log_height={log_height}, width={width}, time={elapsed:?}",); + } + } + + #[test] + fn bench_sha256_merkle_commit_koalabear_width_32_log_sizes() { + let folding_factor = 7; + let width = 1usize << folding_factor; + + init_tracing(); + + for log_size in 20..=26 { + let n_values = 1usize << log_size; + let height = n_values / width; + let values: Vec<_> = (0..n_values).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + assert_eq!(matrix.width(), width); + assert_eq!(matrix.height(), height); + + let start = Instant::now(); + let (root, tree) = sha256_merkle_commit(matrix); + let elapsed = start.elapsed(); + + assert_eq!(tree.digest_layers[0].len(), height); + assert_eq!(tree.digest_layers.last().unwrap()[0], root); + + let log_height = log2_ceil_usize(height); + println!("sha256 log={log_size}, log_height={log_height}, width={width}, time={elapsed:?}",); + } + } +} From 467aee7cbbb96b67736f7a2c77471c579fd054a1 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 16:54:06 +0300 Subject: [PATCH 02/19] wip --- crates/backend/symetric/src/merkle.rs | 6 ++++++ crates/whir/Cargo.toml | 2 +- crates/whir/src/merkle.rs | 26 ++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..0663a26ad 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -16,6 +16,12 @@ pub struct MerkleTree { pub digest_layers: Vec>, } +/// A Merkle tree storing only the digest layers (no leaf data). +#[derive(Debug, Clone)] +pub struct MerkleTree2 { + pub digest_layers: Vec>, +} + impl MerkleTree { /// Build a Merkle tree from a pre-computed first digest layer. pub fn from_first_layer(comp: &Comp, first_layer: Vec<[F; DIGEST_ELEMS]>) -> Self diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 696459af9..75b3f961b 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -17,8 +17,8 @@ itertools.workspace = true rayon.workspace = true rand.workspace = true tracing.workspace = true +sha2 = "0.10.9" [dev-dependencies] -sha2 = "0.10.9" tracing-forest.workspace = true tracing-subscriber.workspace = true diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 94fd5f627..90bf034a2 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,6 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; +use field::PrimeField32; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; @@ -247,6 +248,31 @@ where digests } +type Sha256Digest = [u8; 16]; +use sha2::{Digest, Sha256}; +#[instrument(name = "first digest layer", level = "debug", skip_all)] +fn sha2_first_digest_layer(h: &Sha256, matrix: &M, _full_width: usize) -> Vec +where + P: PrimeField32, + M: Matrix

, +{ + let height = matrix.height(); + let matrix_width = matrix.width(); + assert!(matrix_width <= full_width); + + (0..height) + .into_par_iter() + .map(|r| { + let mut hasher = h.clone(); + for value in matrix.row(r).unwrap() { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }) + .collect() +} + #[instrument(skip_all)] fn first_digest_layer_with_initial_state( perm: &Perm, From 1809640840ab9bee5192b7b4a263f52cc294c09c Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 18:06:53 +0300 Subject: [PATCH 03/19] dep --- Cargo.lock | 17 +++++++++++++++-- Cargo.toml | 2 ++ crates/backend/Cargo.toml | 2 ++ crates/backend/fiat-shamir/Cargo.toml | 1 + 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 209470eae..b3b63a8ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,6 +105,7 @@ dependencies = [ "mt-utils", "mt-whir", "rayon", + "sha2 0.11.0", "tracing", ] @@ -621,6 +622,7 @@ dependencies = [ "mt-utils", "rayon", "serde", + "sha2 0.11.0", "tracing", ] @@ -710,7 +712,7 @@ dependencies = [ "mt-utils", "rand", "rayon", - "sha2", + "sha2 0.10.9", "system-info", "tracing", "tracing-forest", @@ -837,7 +839,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" dependencies = [ "pest", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -1044,6 +1046,17 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.2", +] + [[package]] name = "sha3" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index f8e2ada76..a83615197 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,8 @@ tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] } postcard = { version = "1.1.3", features = ["alloc"] } lz4_flex = "0.13.0" include_dir = "0.7" +sha2 = "0.11.0" + [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957af..3ebd8084b 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -15,3 +15,5 @@ tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } koala-bear = { path = "koala-bear", package = "mt-koala-bear" } utils = { path = "utils", package = "mt-utils" } +sha2.workspace = true + diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index ec8649bc2..d277c3e04 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -11,3 +11,4 @@ utils = { path = "../utils", package = "mt-utils" } tracing.workspace = true serde.workspace = true rayon.workspace = true +sha2.workspace = true \ No newline at end of file From 75badb0a6e629864a14ebf9f6893c7455303b72d Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 18:07:33 +0300 Subject: [PATCH 04/19] merkle --- crates/backend/symetric/src/merkle.rs | 18 +- crates/whir/src/merkle.rs | 307 +++++++++++++++++++------- 2 files changed, 247 insertions(+), 78 deletions(-) diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 0663a26ad..891be80d2 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -16,10 +16,11 @@ pub struct MerkleTree { pub digest_layers: Vec>, } +pub type Sha256Digest = [u8; 16]; /// A Merkle tree storing only the digest layers (no leaf data). #[derive(Debug, Clone)] -pub struct MerkleTree2 { - pub digest_layers: Vec>, +pub struct MerkleTreeSha2 { + pub digest_layers: Vec>, } impl MerkleTree { @@ -53,6 +54,19 @@ impl MerkleT } } +impl MerkleTreeSha2 { + #[must_use] + pub fn root(&self) -> Sha256Digest { + self.digest_layers.last().unwrap()[0] + } + + pub fn open_siblings(&self, index: usize, log_height: usize) -> Vec { + (0..log_height) + .map(|i| self.digest_layers[i][(index >> i) ^ 1]) + .collect() + } +} + pub fn compress_layer( prev_layer: &[[P::Value; DIGEST_ELEMS]], comp: &Comp, diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 90bf034a2..18402c1cb 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -14,6 +14,8 @@ use poly::*; use rayon::prelude::*; use symetric::Compression; +use symetric::merkle::MerkleTreeSha2; +use symetric::merkle::Sha256Digest; use symetric::merkle::unpack_array; use tracing::instrument; use utils::log2_ceil_usize; @@ -24,6 +26,9 @@ use crate::Matrix; pub use symetric::DIGEST_ELEMS; pub(crate) type RoundMerkleTree = WhirMerkleTree, DIGEST_ELEMS>; +pub(crate) type RoundMerkleTreeSha2 = WhirMerkleTreeSha2>; + +use sha2::{Digest, Sha256}; #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( @@ -56,6 +61,35 @@ pub(crate) fn merkle_commit>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_commit_sha2>( + matrix: DenseMatrix, + full_n_cols: usize, + effective_n_cols: usize, +) -> (Sha256Digest, RoundMerkleTreeSha2) { + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let dim = >::DIMENSION; + let dft_base_width = matrix.width * dim; + let full_base_width = full_n_cols * dim; + let effective_base_width = effective_n_cols * dim; + let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); + let base_matrix = DenseMatrix::::new(base_values, dft_base_width); + let tree = build_merkle_tree_sha256(base_matrix, full_base_width, effective_base_width); + let root = tree.root(); + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTreeSha2>(tree) }; + (root, tree) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; + let tree = build_merkle_tree_sha256(matrix, full_n_cols, effective_n_cols); + let root = tree.root(); + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTreeSha2>(tree) }; + (root, tree) + } else { + unimplemented!() + } +} + #[instrument(name = "build merkle tree", skip_all)] fn build_merkle_tree_koalabear( leaf: DenseMatrix, @@ -88,6 +122,15 @@ fn build_merkle_tree_koalabear( } } +#[instrument(name = "build merkle tree sha256", skip_all)] +fn build_merkle_tree_sha256( + leaf: DenseMatrix, + full_base_width: usize, + effective_base_width: usize, +) -> RoundMerkleTreeSha2 { + WhirMerkleTreeSha2::new(leaf, full_base_width, effective_base_width) +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_open>( merkle_tree: &RoundMerkleTree, @@ -112,6 +155,28 @@ pub(crate) fn merkle_open>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_open_sha2>( + merkle_tree: &RoundMerkleTreeSha2, + index: usize, +) -> (Vec, Vec) { + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTreeSha2>(merkle_tree) }; + let (inner_leaf, proof) = merkle_tree.open(index); + let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf); + let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; + (leaf, proof) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTreeSha2>(merkle_tree) }; + let (inner_leaf, proof) = merkle_tree.open(index); + let leaf = KoalaBear::reconstitute_from_base(inner_leaf); + let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; + (leaf, proof) + } else { + unimplemented!() + } +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_verify>( merkle_root: [F; DIGEST_ELEMS], @@ -153,6 +218,28 @@ pub(crate) fn merkle_verify>( } } +#[allow(clippy::missing_transmute_annotations)] +pub(crate) fn merkle_verify_sha2>( + merkle_root: Sha256Digest, + index: usize, + dimension: Dimensions, + data: Vec, + proof: &Vec, +) -> bool { + let log_max_height = utils::log2_strict_usize(dimension.height.next_power_of_two()); + if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let base_data: Vec = QuinticExtensionFieldKB::flatten_to_base(data); + sha2_merkle_verify(&merkle_root, log_max_height, index, &base_data, proof) + } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { + let data = unsafe { std::mem::transmute::<_, Vec>(data) }; + let base_data = KoalaBear::flatten_to_base(data); + sha2_merkle_verify(&merkle_root, log_max_height, index, &base_data, proof) + } else { + unimplemented!() + } +} + #[derive(Debug, Clone)] pub struct WhirMerkleTree { pub(crate) leaf: M, @@ -160,6 +247,14 @@ pub struct WhirMerkleTree { full_leaf_base_width: usize, } +#[derive(Debug, Clone)] +pub struct WhirMerkleTreeSha2> { + pub(crate) leaf: M, + pub(crate) tree: MerkleTreeSha2, + full_leaf_base_width: usize, + _marker: std::marker::PhantomData, +} + impl, const DIGEST_ELEMS: usize> WhirMerkleTree { @@ -248,31 +343,6 @@ where digests } -type Sha256Digest = [u8; 16]; -use sha2::{Digest, Sha256}; -#[instrument(name = "first digest layer", level = "debug", skip_all)] -fn sha2_first_digest_layer(h: &Sha256, matrix: &M, _full_width: usize) -> Vec -where - P: PrimeField32, - M: Matrix

, -{ - let height = matrix.height(); - let matrix_width = matrix.width(); - assert!(matrix_width <= full_width); - - (0..height) - .into_par_iter() - .map(|r| { - let mut hasher = h.clone(); - for value in matrix.row(r).unwrap() { - hasher.update(value.as_canonical_u32().to_le_bytes()); - } - let digest = hasher.finalize(); - digest[..16].try_into().unwrap() - }) - .collect() -} - #[instrument(skip_all)] fn first_digest_layer_with_initial_state( perm: &Perm, @@ -313,76 +383,161 @@ where digests } -#[cfg(test)] -mod tests { - use std::time::Instant; - - use field::PrimeField32; - use field::integers::QuotientMap; - use rand::{RngExt, SeedableRng, rngs::StdRng}; - use sha2::{Digest, Sha256}; - - use super::*; +#[instrument(name = "first digest layer", level = "debug", skip_all)] +fn sha2_first_digest_layer(h: &Sha256, matrix: &M, _full_width: usize) -> Vec +where + F: PrimeField32, + M: Matrix, +{ + let height = matrix.height(); - type Sha256Digest = [u8; 16]; + (0..height) + .into_par_iter() + .map(|r| { + let mut hasher = h.clone(); + for value in matrix.row(r).unwrap() { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }) + .collect() +} - #[derive(Debug)] - struct Sha256MerkleTree { - digest_layers: Vec>, - } +fn sha2_compress_pair(left: &Sha256Digest, right: &Sha256Digest) -> Sha256Digest { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() +} - fn pseudo_random_koalabear(index: usize) -> KoalaBear { - let mut x = index as u64; - x = x.wrapping_add(0x9e37_79b9_7f4a_7c15); - x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9); - x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb); - KoalaBear::from_int(x ^ (x >> 31)) +fn sha2_merkle_verify( + commit: &Sha256Digest, + log_height: usize, + mut index: usize, + opened_values: &[F], + opening_proof: &[Sha256Digest], +) -> bool { + if opening_proof.len() != log_height { + return false; } - fn sha256_truncated(hasher: Sha256) -> Sha256Digest { - let digest = hasher.finalize(); - digest[..16].try_into().unwrap() + let mut hasher = Sha256::new(); + for value in opened_values { + hasher.update(value.as_canonical_u32().to_le_bytes()); } + let digest = hasher.finalize(); + let mut root: Sha256Digest = digest[..16].try_into().unwrap(); - fn sha256_hash_row(row: &[KoalaBear]) -> Sha256Digest { - let mut hasher = Sha256::new(); - for value in row { - hasher.update(value.as_canonical_u32().to_le_bytes()); - } - sha256_truncated(hasher) + for sibling in opening_proof { + let (left, right) = if index & 1 == 0 { + (root, *sibling) + } else { + (*sibling, root) + }; + root = sha2_compress_pair(&left, &right); + index >>= 1; } - fn sha256_hash_pair(left: &Sha256Digest, right: &Sha256Digest) -> Sha256Digest { - let mut hasher = Sha256::new(); - hasher.update(left); - hasher.update(right); - sha256_truncated(hasher) - } + commit == &root +} - #[tracing::instrument(name = "sha256 merkle commit", skip_all)] - fn sha256_merkle_commit(matrix: DenseMatrix) -> (Sha256Digest, Sha256MerkleTree) { - let width = matrix.width(); - let height = matrix.height(); - assert!(height.is_power_of_two()); +impl> WhirMerkleTreeSha2 { + #[instrument(name = "build merkle tree", skip_all)] + pub fn new(leaf: M, full_leaf_base_width: usize, effective_base_width: usize) -> Self { + assert!(leaf.height().is_power_of_two()); + assert!(leaf.width() <= full_leaf_base_width); + assert!(effective_base_width <= full_leaf_base_width); - let first_layer: Vec<_> = tracing::info_span!("leafs") - .in_scope(|| matrix.values.par_chunks_exact(width).map(sha256_hash_row).collect()); + let first_layer = sha2_first_digest_layer(&Sha256::new(), &leaf, full_leaf_base_width); let mut digest_layers = vec![first_layer]; - tracing::info_span!("asc").in_scope(|| { + tracing::debug_span!("asc").in_scope(|| { while digest_layers.last().unwrap().len() > 1 { let prev_layer = digest_layers.last().unwrap(); assert!(prev_layer.len().is_multiple_of(2)); let next_layer = prev_layer .par_chunks_exact(2) - .map(|pair| sha256_hash_pair(&pair[0], &pair[1])) + .map(|pair| sha2_compress_pair(&pair[0], &pair[1])) .collect(); digest_layers.push(next_layer); } }); - let root = digest_layers.last().unwrap()[0]; - (root, Sha256MerkleTree { digest_layers }) + let tree = MerkleTreeSha2 { digest_layers }; + Self { + leaf, + tree, + full_leaf_base_width, + _marker: std::marker::PhantomData, + } + } + + #[must_use] + pub fn root(&self) -> Sha256Digest { + self.tree.root() + } + + pub fn open(&self, index: usize) -> (Vec, Vec) { + let log_height = log2_ceil_usize(self.leaf.height()); + let mut opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); + opening.resize(self.full_leaf_base_width, F::default()); + let proof = self.tree.open_siblings(index, log_height); + (opening, proof) + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use field::integers::QuotientMap; + + use super::*; + + fn pseudo_random_koalabear(index: usize) -> KoalaBear { + let mut x = index as u64; + x = x.wrapping_add(0x9e37_79b9_7f4a_7c15); + x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9); + x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb); + KoalaBear::from_int(x ^ (x >> 31)) + } + + #[test] + fn whir_sha2_merkle_tree_builds_expected_layers() { + let width = 8; + let height = 4; + let values: Vec<_> = (0..width * height).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + let (root, tree) = merkle_commit_sha2::(matrix, width, width); + + assert_eq!(tree.tree.digest_layers.len(), log2_ceil_usize(height) + 1); + assert_eq!(tree.tree.digest_layers[0].len(), height); + assert_eq!(tree.tree.digest_layers[1].len(), height / 2); + assert_eq!(tree.tree.digest_layers[2].len(), 1); + assert_eq!(tree.root(), root); + } + + #[test] + fn whir_sha2_merkle_opening_verifies() { + let width = 8; + let height = 4; + let index = 2; + let values: Vec<_> = (0..width * height).map(pseudo_random_koalabear).collect(); + let matrix = DenseMatrix::new(values, width); + + let (root, tree) = merkle_commit_sha2::(matrix, width, width); + let (leaf, proof) = merkle_open_sha2::(&tree, index); + + assert!(merkle_verify_sha2::( + root, + index, + Dimensions { width, height }, + leaf, + &proof, + )); } use tracing_forest::{ForestLayer, util::LevelFilter}; @@ -416,7 +571,7 @@ mod tests { assert_eq!(matrix.height(), height); let start = Instant::now(); - let (root, prover_data) = merkle_commit::(matrix, width, width); + let (_root, prover_data) = merkle_commit::(matrix, width, width); let elapsed = start.elapsed(); assert_eq!(prover_data.leaf.width(), width); @@ -445,11 +600,11 @@ mod tests { assert_eq!(matrix.height(), height); let start = Instant::now(); - let (root, tree) = sha256_merkle_commit(matrix); + let (root, tree) = merkle_commit_sha2::(matrix, width, width); let elapsed = start.elapsed(); - assert_eq!(tree.digest_layers[0].len(), height); - assert_eq!(tree.digest_layers.last().unwrap()[0], root); + assert_eq!(tree.tree.digest_layers[0].len(), height); + assert_eq!(tree.tree.digest_layers.last().unwrap()[0], root); let log_height = log2_ceil_usize(height); println!("sha256 log={log_size}, log_height={log_height}, width={width}, time={elapsed:?}",); From f687087c0a21f3a32a28eb1d453cbddd4242fc0a Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 18:07:52 +0300 Subject: [PATCH 05/19] state --- crates/backend/fiat-shamir/src/challenger.rs | 82 ++++++++++++- crates/backend/fiat-shamir/src/prover.rs | 115 ++++++++++++++++++- crates/lean_prover/src/prove_execution.rs | 1 + crates/utils/src/wrappers.rs | 4 + 4 files changed, 199 insertions(+), 3 deletions(-) diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index 34fcd94ab..9304a410c 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -1,4 +1,7 @@ -use field::PrimeField64; +use std::marker::PhantomData; + +use field::{PrimeField32, PrimeField64}; +use sha2::{Digest, Sha256}; use symetric::Compression; pub(crate) const RATE: usize = 8; @@ -73,3 +76,80 @@ impl> Challenger { res } } + +#[derive(Clone, Debug)] +pub struct ChallengerSha2 { + pub sha2: Sha256, + _marker: PhantomData, +} + +impl ChallengerSha2 { + pub fn new() -> Self { + Self { + sha2: Sha256::new(), + _marker: PhantomData, + } + } + + pub fn observe(&mut self, value: [F; RATE]) { + for val in value { + self.sha2.update(val.as_canonical_u32().to_le_bytes()); + } + } + + pub fn observe_scalars(&mut self, scalars: &[F]) { + for chunk in scalars.chunks(RATE) { + let mut buffer = [F::ZERO; RATE]; + for (i, val) in chunk.iter().enumerate() { + buffer[i] = *val; + } + self.observe(buffer); + } + } + + pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> { + let mut sampled = Vec::with_capacity(n + 1); + for i in 0..n + 1 { + sampled.push(self.sample_chunk(i)); + } + let last = sampled.pop().unwrap(); + self.sha2 = Sha256::new(); + self.observe(last); + sampled + } + + pub fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + assert!(bits < F::bits()); + let sampled_fe = self.sample_many(n_samples.div_ceil(RATE)).into_iter().flatten(); + let mut res = Vec::new(); + for fe in sampled_fe.take(n_samples) { + let rand_usize = fe.as_canonical_u64() as usize; + res.push(rand_usize & ((1 << bits) - 1)); + } + res + } + + fn sample_chunk(&self, domain_sep: usize) -> [F; RATE] { + let mut words = Vec::with_capacity(RATE); + for block_idx in 0u64.. { + let mut hasher = self.sha2.clone(); + hasher.update((domain_sep as u64).to_le_bytes()); + hasher.update(block_idx.to_le_bytes()); + let digest = hasher.finalize(); + for word in digest.chunks_exact(size_of::()) { + let word = u32::from_le_bytes(word.try_into().unwrap()); + words.push(F::from_int(word)); + if words.len() == RATE { + return words.try_into().unwrap(); + } + } + } + unreachable!() + } +} + +impl Default for ChallengerSha2 { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 2ea95580d..7467975bd 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -1,13 +1,13 @@ use crate::{ MerklePaths, PrunedMerklePaths, - challenger::{Challenger, RATE, WIDTH}, + challenger::{Challenger, ChallengerSha2, RATE, WIDTH}, *, }; use field::Field; use field::PackedValue; use field::PrimeCharacteristicRing; use field::integers::QuotientMap; -use field::{ExtensionField, PrimeField64}; +use field::{ExtensionField, PrimeField32, PrimeField64}; use rayon::prelude::*; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; @@ -170,3 +170,114 @@ where POW_GRINDING_NANOS.fetch_add(elapsed.as_nanos() as u64, Ordering::Relaxed); } } + +#[derive(Debug)] +pub struct ProverStateSha2>> { + challenger: ChallengerSha2>, + transcript: Vec>, + merkle_paths: Vec, PF>>, +} + +impl>> ProverStateSha2 +where + PF: PrimeField32, +{ + #[must_use] + pub fn new() -> Self { + assert!(EF::DIMENSION <= RATE); + Self { + challenger: ChallengerSha2::new(), + transcript: Vec::new(), + merkle_paths: Vec::new(), + } + } + + pub fn into_proof(self) -> Proof> { + Proof { + transcript: self.transcript, + merkle_paths: self.merkle_paths, + } + } +} + +impl>> ChallengeSampler for ProverStateSha2 +where + PF: PrimeField32, +{ + fn sample_vec(&mut self, len: usize) -> Vec { + let sampled_fe = self + .challenger + .sample_many((len * EF::DIMENSION).div_ceil(RATE)) + .into_iter() + .flatten() + .take(len * EF::DIMENSION) + .collect::>>(); + pack_scalars_to_extension(&sampled_fe) + } + + fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + self.challenger.sample_in_range(bits, n_samples) + } +} + +impl>> FSProver for ProverStateSha2 +where + PF: PrimeField32, +{ + fn add_base_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + self.transcript.extend_from_slice(scalars); + } + + fn observe_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + } + + fn state(&self) -> String { + format!("sha2 transcript (n_items: {})", self.transcript.len()) + } + + fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option) { + match eq_alpha { + None => { + let scalars = flatten_scalars_to_base(coeffs); + self.challenger.observe_scalars(&scalars); + self.transcript.extend_from_slice(&scalars[EF::DIMENSION..]); + } + Some(alpha) => { + let bare_scalars = flatten_scalars_to_base(coeffs); + let full_scalars = flatten_scalars_to_base(&expand_bare_to_full(coeffs, alpha)); + self.challenger.observe_scalars(&full_scalars); + self.transcript.extend_from_slice(&bare_scalars[EF::DIMENSION..]); + } + } + } + + fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>) { + self.merkle_paths.push(MerklePaths(paths).prune()); + } + + fn pow_grinding(&mut self, bits: usize) { + assert!(bits < PF::::bits()); + + if bits == 0 { + return; + } + + let time = Instant::now(); + let witness = (0..PF::::ORDER_U32) + .find_map(|candidate| { + let witness = unsafe { PF::::from_canonical_unchecked(candidate) }; + let mut challenger = self.challenger.clone(); + challenger.observe_scalars(&[witness]); + (challenger.sample_in_range(bits, 1)[0] == 0).then_some(witness) + }) + .expect("failed to find witness"); + + self.challenger.observe_scalars(&[witness]); + self.transcript.push(witness); + + let elapsed = time.elapsed(); + POW_GRINDING_NANOS.fetch_add(elapsed.as_nanos() as u64, Ordering::Relaxed); + } +} diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index fa86a3ae2..d3ea60e62 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -45,6 +45,7 @@ pub fn prove_execution( memory.resize(min_memory_size, F::ZERO); } let mut prover_state = build_prover_state(); + let mut prover_state2 = build_prover_state_sha2(); prover_state.observe_scalars(public_input); prover_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); prover_state.add_base_scalars( diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index c8aa8e40f..5ddb5f6b2 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -9,6 +9,10 @@ pub fn build_prover_state() -> ProverState ProverState::new(get_poseidon16().clone()) } +pub fn build_prover_state_sha2() -> ProverStateSha2 { + ProverStateSha2::new() +} + pub fn build_verifier_state( prover_state: ProverState, ) -> Result, ProofError> { From 3d6378b854c4b38da9291b477b531fbd39093f8e Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 18:11:30 +0300 Subject: [PATCH 06/19] start wiring --- crates/lean_prover/src/prove_execution.rs | 35 ++++++++------- crates/whir/src/commit.rs | 54 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 16 deletions(-) diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index d3ea60e62..fe579aee9 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use sub_protocols::*; use tracing::info_span; use utils::ansi::Colorize; -use utils::{build_prover_state, from_end}; +use utils::{build_prover_state, build_prover_state_sha2, from_end}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionProof { @@ -47,21 +47,24 @@ pub fn prove_execution( let mut prover_state = build_prover_state(); let mut prover_state2 = build_prover_state_sha2(); prover_state.observe_scalars(public_input); - prover_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); - prover_state.add_base_scalars( - &[ - vec![ - whir_config.starting_log_inv_rate, - log2_strict_usize(memory.len()), - public_input.len(), - ], - traces.values().map(|t| t.log_n_rows).collect::>(), - ] - .concat() - .into_iter() - .map(F::from_usize) - .collect::>(), - ); + prover_state2.observe_scalars(public_input); + let bytecode_hash_with_domain_sep = poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP); + prover_state.observe_scalars(&bytecode_hash_with_domain_sep); + prover_state2.observe_scalars(&bytecode_hash_with_domain_sep); + let execution_metadata_scalars = [ + vec![ + whir_config.starting_log_inv_rate, + log2_strict_usize(memory.len()), + public_input.len(), + ], + traces.values().map(|t| t.log_n_rows).collect::>(), + ] + .concat() + .into_iter() + .map(F::from_usize) + .collect::>(); + prover_state.add_base_scalars(&execution_metadata_scalars); + prover_state2.add_base_scalars(&execution_metadata_scalars); for (table, table_trace) in &traces { let log_n_rows = table_trace.log_n_rows; assert!(log_n_rows >= MIN_LOG_N_ROWS_PER_TABLE, "missing padding"); diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb3502..d7e83ae17 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -13,6 +13,12 @@ pub enum MerkleData>> { Extension(RoundMerkleTree>), } +#[derive(Debug, Clone)] +pub enum MerkleData2>> { + Base(RoundMerkleTreeSha2>), + Extension(RoundMerkleTreeSha2>), +} + impl>> MerkleData { pub(crate) fn build( matrix: DftOutput, @@ -55,6 +61,16 @@ where pub ood_answers: Vec, } +#[derive(Debug, Clone)] +pub struct Witness2 +where + EF: ExtensionField>, +{ + pub prover_data: MerkleData2, + pub ood_points: Vec, + pub ood_answers: Vec, +} + impl WhirConfig where EF: ExtensionField>, @@ -97,4 +113,42 @@ where ood_answers, } } + + #[instrument(skip_all)] + pub fn commit2( + &self, + prover_state: &mut impl FSProver, + polynomial: &MleOwned, + actual_data_len: usize, // polynomial[actual_data_len..] is zero + ) -> Witness2 { + let n_blocks = 1usize << self.folding_factor.at_round(0); + let evals_len = 1usize << self.num_variables; + let effective_n_cols = actual_data_len.div_ceil(evals_len / n_blocks); + // DFT matrix width: skip as many zero columns as possible, aligned to packing (SIMD) + let dft_n_cols = effective_n_cols.next_multiple_of(packing_width::()).min(n_blocks); + + let folded_matrix = info_span!("FFT").in_scope(|| { + reorder_and_dft( + &polynomial.by_ref(), + self.folding_factor.at_round(0), + self.starting_log_inv_rate, + dft_n_cols, + ) + }); + + let (prover_data, root) = MerkleData2::build(folded_matrix, n_blocks, effective_n_cols); + + prover_state.add_base_scalars(&root); + + let (ood_points, ood_answers) = + sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { + polynomial.evaluate(point) + }); + + Witness { + prover_data, + ood_points, + ood_answers, + } + } } From 1d73eaebbfe83070891b3da19d89a8df1f7d3be1 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 18:37:09 +0300 Subject: [PATCH 07/19] cont wiring --- crates/lean_prover/src/prove_execution.rs | 108 ++++++++++++++++++++++ crates/sub_protocols/src/stacked_pcs.rs | 8 +- crates/whir/src/commit.rs | 43 ++++++++- 3 files changed, 154 insertions(+), 5 deletions(-) diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index fe579aee9..f7a490883 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -116,6 +116,7 @@ pub fn prove_execution( // 1st Commitment let stacked_pcs_witness = stack_polynomials_and_commit( &mut prover_state, + &mut prover_state2, whir_config, &memory, &memory_acc, @@ -138,8 +139,24 @@ pub fn prove_execution( &bytecode_acc, &traces, ); + let logup_c2 = prover_state2.sample(); + let logup_alphas2 = prover_state2.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); + let logup_alphas_eq_poly2 = eval_eq(&logup_alphas2); + + let logup_statements2 = prove_generic_logup( + &mut prover_state2, + logup_c2, + &logup_alphas_eq_poly2, + &memory, + &memory_acc, + &bytecode.instructions_multilinear, + &bytecode_acc, + &traces, + ); let gkr_point = &logup_statements.gkr_point; + let gkr_point2 = &logup_statements2.gkr_point; let mut committed_statements: CommittedStatements = Default::default(); + let mut committed_statements2: CommittedStatements = Default::default(); for table in ALL_TABLES { let log_n_rows = traces[&table].log_n_rows; committed_statements.insert( @@ -150,12 +167,24 @@ pub fn prove_execution( BTreeMap::new(), )], ); + committed_statements2.insert( + table, + vec![( + MultilinearPoint(from_end(gkr_point2, log_n_rows).to_vec()), + logup_statements2.columns_values[&table].clone(), + BTreeMap::new(), + )], + ); } let bus_beta = prover_state.sample(); let air_alpha = prover_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); let air_eta: EF = prover_state.sample(); + let bus_beta2 = prover_state2.sample(); + let air_alpha2 = prover_state2.sample(); + let air_alpha_powers2: Vec = air_alpha2.powers().collect_n(max_air_constraints() + 1); + let air_eta2: EF = prover_state2.sample(); let tables_log_heights: BTreeMap = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); @@ -178,6 +207,7 @@ pub fn prove_execution( .collect(); std::mem::drop(_span); let mut sessions = Vec::with_capacity(tables_sorted.len()); + let mut sessions2 = Vec::with_capacity(tables_sorted.len()); for (idx, (table, log_n_rows)) in tables_sorted.iter().enumerate() { let bus_numerator_value = logup_statements.bus_numerators_values[table]; let bus_denominator_value = logup_statements.bus_denominators_values[table]; @@ -205,10 +235,44 @@ pub fn prove_execution( }}; } sessions.push(delegate_to_inner!(table => make_session)); + + let bus_numerator_value2 = logup_statements2.bus_numerators_values[table]; + let bus_denominator_value2 = logup_statements2.bus_denominators_values[table]; + let bus_final_value2 = bus_numerator_value2 + * match table.bus().direction { + BusDirection::Pull => EF::NEG_ONE, + BusDirection::Push => EF::ONE, + } + + bus_beta2 * (bus_denominator_value2 - logup_c2); + + let eq_suffix2 = from_end(gkr_point2, *log_n_rows).to_vec(); + + let extra_data2 = ExtraDataForBuses::new(logup_alphas_eq_poly2.clone(), bus_beta2, air_alpha_powers2.clone()); + + let mut up_down2: Vec<&[PF]> = column_refs[idx].to_vec(); + up_down2.extend(shifted_rows[idx].iter().map(Vec::as_slice)); + let packed2 = MleGroupRef::::Base(up_down2).pack(); + + macro_rules! make_session2 { + ($t:expr) => {{ + let session = AirSumcheckSession::new( + packed2, + eq_suffix2, + bus_final_value2, + *$t, + extra_data2, + non_padded, + ); + Box::new(session) as Box + '_> + }}; + } + sessions2.push(delegate_to_inner!(table => make_session2)); } let sumcheck_air_point = info_span!("batched AIR sumcheck") .in_scope(|| prove_batched_air_sumcheck(&mut prover_state, &mut sessions, air_eta)); + let sumcheck_air_point2 = info_span!("batched AIR sumcheck sha2") + .in_scope(|| prove_batched_air_sumcheck(&mut prover_state2, &mut sessions2, air_eta2)); for (idx, (table, _)) in tables_sorted.iter().enumerate() { let col_evals = sessions[idx].final_column_evals(); @@ -221,10 +285,23 @@ pub fn prove_execution( } let claim = delegate_to_inner!(table => split); committed_statements.get_mut(table).unwrap().push(claim); + + let col_evals2 = sessions2[idx].final_column_evals(); + prover_state2.add_extension_scalars(&col_evals2); + + let natural_ordering_point2 = + natural_ordering_point_for_session(&sumcheck_air_point2.0, traces[table].log_n_rows); + macro_rules! split2 { + ($t:expr) => {{ columns_evals_up_and_down($t, &col_evals2, &natural_ordering_point2) }}; + } + let claim2 = delegate_to_inner!(table => split2); + committed_statements2.get_mut(table).unwrap().push(claim2); } let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); + let public_memory_random_point2 = MultilinearPoint(prover_state2.sample_vec(log2_strict_usize(public_memory_size))); + let public_memory_eval2 = (&memory[..public_memory_size]).evaluate(&public_memory_random_point2); let previous_statements = vec![ SparseStatement::new( @@ -249,6 +326,29 @@ pub fn prove_execution( )], ), ]; + let previous_statements2 = vec![ + SparseStatement::new( + stacked_pcs_witness.stacked_n_vars, + logup_statements2.memory_and_acc_point, + vec![ + SparseValue::new(0, logup_statements2.value_memory), + SparseValue::new(1, logup_statements2.value_memory_acc), + ], + ), + SparseStatement::new( + stacked_pcs_witness.stacked_n_vars, + public_memory_random_point2, + vec![SparseValue::new(0, public_memory_eval2)], + ), + SparseStatement::new( + stacked_pcs_witness.stacked_n_vars, + logup_statements2.bytecode_and_acc_point, + vec![SparseValue::new( + (2 * memory.len()) >> bytecode.log_size(), + logup_statements2.value_bytecode_acc, + )], + ), + ]; let global_statements_base = stacked_pcs_global_statements( stacked_pcs_witness.stacked_n_vars, @@ -258,6 +358,14 @@ pub fn prove_execution( &tables_log_heights, &committed_statements, ); + let _global_statements_base2 = stacked_pcs_global_statements( + stacked_pcs_witness.stacked_n_vars, + log2_strict_usize(memory.len()), + bytecode.log_size(), + previous_statements2, + &tables_log_heights, + &committed_statements2, + ); WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()).prove( &mut prover_state, diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index e715af3c3..b76e95b8d 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -34,6 +34,7 @@ Stacking of various (multilinear) polynomials into a single -big- (multilinear) pub struct StackedPcsWitness { pub stacked_n_vars: VarCount, pub inner_witness: Witness, + pub inner_witness2: Witness2, pub global_polynomial: MleOwned, } @@ -97,6 +98,7 @@ pub fn stacked_pcs_global_statements( #[instrument(skip_all)] pub fn stack_polynomials_and_commit( prover_state: &mut impl FSProver, + prover_state2: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, memory: &[F], memory_acc: &[F], @@ -146,11 +148,13 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); - let inner_witness = - WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial, offset); + let whir_config = WhirConfig::new(whir_config_builder, stacked_n_vars); + let inner_witness = whir_config.commit(prover_state, &global_polynomial, offset); + let inner_witness2 = whir_config.commit2(prover_state2, &global_polynomial, offset); StackedPcsWitness { stacked_n_vars, inner_witness, + inner_witness2, global_polynomial, } } diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index d7e83ae17..f869e5cc3 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -1,8 +1,9 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). use fiat_shamir::FSProver; -use field::{ExtensionField, TwoAdicField}; +use field::{ExtensionField, PrimeField32, TwoAdicField}; use poly::*; +use symetric::merkle::Sha256Digest; use tracing::{info_span, instrument}; use crate::*; @@ -51,6 +52,42 @@ impl>> MerkleData { } } +impl>> MerkleData2 { + pub(crate) fn build(matrix: DftOutput, full_n_cols: usize, effective_n_cols: usize) -> (Self, Sha256Digest) { + match matrix { + DftOutput::Base(m) => { + let (root, prover_data) = merkle_commit_sha2::, PF>(m, full_n_cols, effective_n_cols); + (MerkleData2::Base(prover_data), root) + } + DftOutput::Extension(m) => { + let (root, prover_data) = merkle_commit_sha2::, EF>(m, full_n_cols, effective_n_cols); + (MerkleData2::Extension(prover_data), root) + } + } + } + + pub(crate) fn open(&self, index: usize) -> (MleOwned, Vec) { + match self { + MerkleData2::Base(prover_data) => { + let (leaf, proof) = merkle_open_sha2::, PF>(prover_data, index); + (MleOwned::Base(leaf), proof) + } + MerkleData2::Extension(prover_data) => { + let (leaf, proof) = merkle_open_sha2::, EF>(prover_data, index); + (MleOwned::Extension(leaf), proof) + } + } + } +} + +fn sha256_digest_to_scalars(digest: &Sha256Digest) -> [F; 4] { + std::array::from_fn(|i| { + let offset = i * 4; + let word = u32::from_le_bytes(digest[offset..offset + 4].try_into().unwrap()); + F::from_int(word) + }) +} + #[derive(Debug, Clone)] pub struct Witness where @@ -138,14 +175,14 @@ where let (prover_data, root) = MerkleData2::build(folded_matrix, n_blocks, effective_n_cols); - prover_state.add_base_scalars(&root); + prover_state.add_base_scalars(&sha256_digest_to_scalars(&root)); let (ood_points, ood_answers) = sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { polynomial.evaluate(point) }); - Witness { + Witness2 { prover_data, ood_points, ood_answers, From 2861180cd4da523a023c8bdab04afe1cef9f04b9 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 19:48:18 +0300 Subject: [PATCH 08/19] whir --- crates/backend/fiat-shamir/src/challenger.rs | 4 + .../backend/fiat-shamir/src/merkle_pruning.rs | 27 +- crates/backend/fiat-shamir/src/prover.rs | 26 +- crates/backend/fiat-shamir/src/traits.rs | 7 +- crates/backend/fiat-shamir/src/transcript.rs | 10 +- crates/backend/fiat-shamir/src/verifier.rs | 8 +- crates/lean_prover/src/prove_execution.rs | 12 +- crates/sub_protocols/src/stacked_pcs.rs | 5 +- crates/whir/src/commit.rs | 18 +- crates/whir/src/open.rs | 366 +++++++++++++++++- 10 files changed, 434 insertions(+), 49 deletions(-) diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index 9304a410c..81ea2036c 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -107,6 +107,10 @@ impl ChallengerSha2 { } } + pub fn observe_bytes(&mut self, bytes: &[u8]) { + self.sha2.update(bytes); + } + pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> { let mut sampled = Vec::with_capacity(n + 1); for i in 0..n + 1 { diff --git a/crates/backend/fiat-shamir/src/merkle_pruning.rs b/crates/backend/fiat-shamir/src/merkle_pruning.rs index 2461d0bbc..c6221d269 100644 --- a/crates/backend/fiat-shamir/src/merkle_pruning.rs +++ b/crates/backend/fiat-shamir/src/merkle_pruning.rs @@ -1,13 +1,13 @@ use serde::{Deserialize, Serialize}; -use crate::{DIGEST_LEN_FE, MerklePath, MerklePaths}; +use crate::{MerklePath, MerklePaths}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PrunedMerklePaths { +pub struct PrunedMerklePaths { pub merkle_height: usize, pub original_order: Vec, pub leaf_data: Vec>, - pub paths: Vec<(usize, Vec<[F; DIGEST_LEN_FE]>)>, + pub paths: Vec<(usize, Vec)>, pub n_trailing_zeros: usize, } @@ -15,8 +15,8 @@ fn lca_level(a: usize, b: usize) -> usize { (usize::BITS - (a ^ b).leading_zeros()) as usize } -impl MerklePaths { - pub fn prune(self) -> PrunedMerklePaths +impl MerklePaths { + pub fn prune(self) -> PrunedMerklePaths where Data: Default + PartialEq, { @@ -27,7 +27,7 @@ impl MerklePaths { indexed.sort_by_key(|(_, p)| p.leaf_index); let mut original_order = vec![0; indexed.len()]; - let mut deduped = Vec::>::new(); + let mut deduped = Vec::>::new(); for (orig_idx, path) in indexed { if deduped.last().map(|p| p.leaf_index) == Some(path.leaf_index) { @@ -83,12 +83,12 @@ impl MerklePaths { } } -impl PrunedMerklePaths { +impl PrunedMerklePaths { pub fn restore( mut self, - hash_leaf: &impl Fn(&[Data]) -> [F; DIGEST_LEN_FE], - hash_combine: &impl Fn(&[F; DIGEST_LEN_FE], &[F; DIGEST_LEN_FE]) -> [F; DIGEST_LEN_FE], - ) -> Option> + hash_leaf: &impl Fn(&[Data]) -> Digest, + hash_combine: &impl Fn(&Digest, &Digest) -> Digest, + ) -> Option> where Data: Default, { @@ -112,7 +112,7 @@ impl PrunedMerklePaths { let skip = |i: usize| self.paths.get(i + 1).map(|p| lca_level(self.paths[i].0, p.0) - 1); // Backward pass: compute subtree hashes needed to restore skipped siblings - let mut subtree_hashes: Vec> = vec![vec![]; n]; + let mut subtree_hashes: Vec> = vec![vec![]; n]; for i in (0..n).rev() { let (leaf_idx, ref stored) = self.paths[i]; @@ -139,7 +139,7 @@ impl PrunedMerklePaths { } // Forward pass: build full sibling arrays - let mut restored: Vec> = Vec::with_capacity(n); + let mut restored: Vec> = Vec::with_capacity(n); for i in 0..n { let (leaf_idx, ref stored) = self.paths[i]; @@ -178,6 +178,7 @@ impl PrunedMerklePaths { #[cfg(test)] mod tests { use super::*; + use crate::DIGEST_LEN_FE; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -231,7 +232,7 @@ mod tests { leaf_data: Vec, leaf_index: usize, tree: &[Vec<[u8; DIGEST_LEN_FE]>], - ) -> MerklePath { + ) -> MerklePath { let height = tree.len() - 1; let mut sibling_hashes = Vec::with_capacity(height); diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 7467975bd..b15fa83a1 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -13,6 +13,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; use symetric::Compression; +use symetric::merkle::Sha256Digest; static POW_GRINDING_NANOS: AtomicU64 = AtomicU64::new(0); @@ -28,7 +29,7 @@ pub fn reset_pow_grinding_time() { pub struct ProverState>, P> { challenger: Challenger, P>, transcript: Vec>, - merkle_paths: Vec, PF>>, + merkle_paths: Vec, [PF; DIGEST_LEN_FE]>>, } impl>, P: Compression<[PF; WIDTH]>> ProverState @@ -71,11 +72,17 @@ impl>, P: Compression<[PF; WIDTH]> + Compression<[ where PF: PrimeField64, { + type Digest = [PF; DIGEST_LEN_FE]; + fn add_base_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); self.transcript.extend_from_slice(scalars); } + fn add_commitment(&mut self, commitment: &Self::Digest) { + self.add_base_scalars(commitment); + } + fn observe_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); } @@ -109,7 +116,7 @@ where } } - fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>) { + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>) { self.merkle_paths.push(MerklePaths(paths).prune()); } @@ -175,7 +182,8 @@ where pub struct ProverStateSha2>> { challenger: ChallengerSha2>, transcript: Vec>, - merkle_paths: Vec, PF>>, + commitments: Vec, + merkle_paths: Vec, Sha256Digest>>, } impl>> ProverStateSha2 @@ -188,11 +196,12 @@ where Self { challenger: ChallengerSha2::new(), transcript: Vec::new(), + commitments: Vec::new(), merkle_paths: Vec::new(), } } - pub fn into_proof(self) -> Proof> { + pub fn into_proof(self) -> Proof, Sha256Digest> { Proof { transcript: self.transcript, merkle_paths: self.merkle_paths, @@ -224,11 +233,18 @@ impl>> FSProver for ProverStateSha2 where PF: PrimeField32, { + type Digest = Sha256Digest; + fn add_base_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); self.transcript.extend_from_slice(scalars); } + fn add_commitment(&mut self, commitment: &Self::Digest) { + self.challenger.observe_bytes(commitment); + self.commitments.push(*commitment); + } + fn observe_scalars(&mut self, scalars: &[PF]) { self.challenger.observe_scalars(scalars); } @@ -253,7 +269,7 @@ where } } - fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>) { + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>) { self.merkle_paths.push(MerklePaths(paths).prune()); } diff --git a/crates/backend/fiat-shamir/src/traits.rs b/crates/backend/fiat-shamir/src/traits.rs index 5aba9f667..e3bd9b4b5 100644 --- a/crates/backend/fiat-shamir/src/traits.rs +++ b/crates/backend/fiat-shamir/src/traits.rs @@ -13,11 +13,14 @@ pub trait ChallengeSampler { } pub trait FSProver>>: ChallengeSampler { + type Digest: Clone; + fn state(&self) -> String; fn add_base_scalars(&mut self, scalars: &[PF]); + fn add_commitment(&mut self, commitment: &Self::Digest); fn observe_scalars(&mut self, scalars: &[PF]); fn pow_grinding(&mut self, bits: usize); - fn hint_merkle_paths_base(&mut self, paths: Vec, PF>>); + fn hint_merkle_paths_base(&mut self, paths: Vec, Self::Digest>>); fn add_sumcheck_polynomial(&mut self, coeffs: &[EF], eq_alpha: Option); fn add_extension_scalars(&mut self, scalars: &[EF]) { @@ -28,7 +31,7 @@ pub trait FSProver>>: ChallengeSampler { self.add_extension_scalars(&[scalar]); } - fn hint_merkle_paths_extension(&mut self, paths: Vec>>) { + fn hint_merkle_paths_extension(&mut self, paths: Vec>) { self.hint_merkle_paths_base( paths .into_iter() diff --git a/crates/backend/fiat-shamir/src/transcript.rs b/crates/backend/fiat-shamir/src/transcript.rs index 612c2d109..4cf751c60 100644 --- a/crates/backend/fiat-shamir/src/transcript.rs +++ b/crates/backend/fiat-shamir/src/transcript.rs @@ -19,20 +19,20 @@ pub struct RawProof { } #[derive(Debug, Clone)] -pub struct MerklePath { +pub struct MerklePath { pub leaf_data: Vec, - pub sibling_hashes: Vec<[F; DIGEST_LEN_FE]>, + pub sibling_hashes: Vec, // does not appear in the proof itself, but useful for Merkle pruning pub leaf_index: usize, } #[derive(Debug, Clone)] -pub struct MerklePaths(pub(crate) Vec>); +pub struct MerklePaths(pub(crate) Vec>); #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Proof { +pub struct Proof { pub(crate) transcript: Vec, - pub(crate) merkle_paths: Vec>, + pub(crate) merkle_paths: Vec>, } impl Proof { diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..8c40ff1dd 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -67,16 +67,18 @@ where } #[allow(clippy::missing_transmute_annotations)] - fn restore_merkle_paths(paths: PrunedMerklePaths, PF>) -> Option>>> { + fn restore_merkle_paths( + paths: PrunedMerklePaths, [PF; DIGEST_LEN_FE]>, + ) -> Option>>> { assert_eq!(TypeId::of::>(), TypeId::of::()); // SAFETY: We've confirmed PF == KoalaBear - let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; + let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; - let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; + let restored: MerklePaths = paths.restore(&hash_fn, &combine_fn)?; let openings: Vec> = restored .0 .into_iter() diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index f7a490883..a74c7e28d 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use crate::*; +use backend::merkle::Sha256Digest; use lean_vm::*; use serde::{Deserialize, Serialize}; @@ -12,6 +13,7 @@ use utils::{build_prover_state, build_prover_state_sha2, from_end}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionProof { pub proof: Proof, + pub proof2: Proof, // benchmark / debug purpose #[serde(skip, default)] pub metadata: Option, @@ -358,7 +360,7 @@ pub fn prove_execution( &tables_log_heights, &committed_statements, ); - let _global_statements_base2 = stacked_pcs_global_statements( + let global_statements_base2 = stacked_pcs_global_statements( stacked_pcs_witness.stacked_n_vars, log2_strict_usize(memory.len()), bytecode.log_size(), @@ -374,11 +376,19 @@ pub fn prove_execution( &stacked_pcs_witness.global_polynomial.by_ref(), ); + WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()).prove2( + &mut prover_state2, + global_statements_base2, + stacked_pcs_witness.inner_witness2, + &stacked_pcs_witness.global_polynomial.by_ref(), + ); + tracing::info!("total pow_grinding time: {} ms", pow_grinding_time().as_millis()); reset_pow_grinding_time(); Ok(ExecutionProof { proof: prover_state.into_proof(), + proof2: prover_state2.into_proof(), metadata: Some(metadata), }) } diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index b76e95b8d..c8dfae4fc 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -1,3 +1,4 @@ +use backend::merkle::Sha256Digest; use backend::*; use lean_vm::{ ALL_TABLES, COL_PC, CommittedStatements, ENDING_PC, MIN_LOG_MEMORY_SIZE, MIN_LOG_N_ROWS_PER_TABLE, @@ -97,8 +98,8 @@ pub fn stacked_pcs_global_statements( #[instrument(skip_all)] pub fn stack_polynomials_and_commit( - prover_state: &mut impl FSProver, - prover_state2: &mut impl FSProver, + prover_state: &mut impl FSProver, + prover_state2: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, memory: &[F], memory_acc: &[F], diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index f869e5cc3..0aecf8002 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -1,7 +1,7 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). use fiat_shamir::FSProver; -use field::{ExtensionField, PrimeField32, TwoAdicField}; +use field::{ExtensionField, TwoAdicField}; use poly::*; use symetric::merkle::Sha256Digest; use tracing::{info_span, instrument}; @@ -80,14 +80,6 @@ impl>> MerkleData2 { } } -fn sha256_digest_to_scalars(digest: &Sha256Digest) -> [F; 4] { - std::array::from_fn(|i| { - let offset = i * 4; - let word = u32::from_le_bytes(digest[offset..offset + 4].try_into().unwrap()); - F::from_int(word) - }) -} - #[derive(Debug, Clone)] pub struct Witness where @@ -116,7 +108,7 @@ where #[instrument(skip_all)] pub fn commit( &self, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, polynomial: &MleOwned, actual_data_len: usize, // polynomial[actual_data_len..] is zero ) -> Witness { @@ -137,7 +129,7 @@ where let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, effective_n_cols); - prover_state.add_base_scalars(&root); + prover_state.add_commitment(&root); let (ood_points, ood_answers) = sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { @@ -154,7 +146,7 @@ where #[instrument(skip_all)] pub fn commit2( &self, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver, polynomial: &MleOwned, actual_data_len: usize, // polynomial[actual_data_len..] is zero ) -> Witness2 { @@ -175,7 +167,7 @@ where let (prover_data, root) = MerkleData2::build(folded_matrix, n_blocks, effective_n_cols); - prover_state.add_base_scalars(&sha256_digest_to_scalars(&root)); + prover_state.add_commitment(&root); let (ood_points, ood_answers) = sample_ood_points::(prover_state, self.commitment_ood_samples, self.num_variables, |point| { diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 8b8b4031c..d9bbd36c5 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -7,6 +7,7 @@ use field::{ExtensionField, Field, TwoAdicField}; use poly::*; use rayon::prelude::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; +use symetric::merkle::Sha256Digest; use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; @@ -36,7 +37,7 @@ where #[instrument(name = "WHIR prove", skip_all)] pub fn prove( &self, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, statement: Vec>, witness: Witness, polynomial: &MleRef<'_, EF>, @@ -55,10 +56,32 @@ where MultilinearPoint(round_state.randomness_vec) } + #[instrument(name = "WHIR prove", skip_all)] + pub fn prove2( + &self, + prover_state: &mut impl FSProver, + statement: Vec>, + witness: Witness2, + polynomial: &MleRef<'_, EF>, + ) -> MultilinearPoint { + assert!(self.validate_parameters()); + // assert!(self.validate_witness(&witness, polynomial)); + self.validate_statement(&statement); + + let mut round_state = + RoundState2::initialize_first_round_state(self, prover_state, statement, witness, polynomial).unwrap(); + + for round in 0..=self.n_rounds() { + self.round2(round, prover_state, &mut round_state).unwrap(); + } + + MultilinearPoint(round_state.randomness_vec) + } + fn round( &self, round_index: usize, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, round_state: &mut RoundState, ) -> ProofResult<()> { let folded_evaluations = &round_state.sumcheck_prover.evals; @@ -90,7 +113,7 @@ where let full = 1 << folding_factor_next; let (prover_data, root) = MerkleData::build(folded_matrix, full, full); - prover_state.add_base_scalars(&root); + prover_state.add_commitment(&root); // Handle OOD (Out-Of-Domain) samples let (ood_points, ood_answers) = @@ -178,10 +201,136 @@ where Ok(()) } + fn round2( + &self, + round_index: usize, + prover_state: &mut impl FSProver, + round_state: &mut RoundState2, + ) -> ProofResult<()> { + let folded_evaluations = &round_state.sumcheck_prover.evals; + let num_variables = self.num_variables - self.folding_factor.total_number(round_index); + + // Base case: final round reached + if round_index == self.n_rounds() { + return self.final_round2(round_index, prover_state, round_state); + } + + let round_params = &self.round_parameters[round_index]; + + // Compute the folding factors for later use + let folding_factor_next = self.folding_factor.at_round(round_index + 1); + + // Compute polynomial evaluations and build Merkle tree + let domain_reduction = 1 << self.rs_reduction_factor(round_index); + let new_domain_size = round_state.domain_size / domain_reduction; + let inv_rate = new_domain_size >> num_variables; + let folded_matrix = info_span!("FFT").in_scope(|| { + reorder_and_dft( + &folded_evaluations.by_ref(), + folding_factor_next, + log2_strict_usize(inv_rate), + 1 << folding_factor_next, + ) + }); + + let full = 1 << folding_factor_next; + let (prover_data, root) = MerkleData2::build(folded_matrix, full, full); + + prover_state.add_commitment(&root); + + // Handle OOD (Out-Of-Domain) samples + let (ood_points, ood_answers) = + sample_ood_points::(prover_state, round_params.ood_samples, num_variables, |point| { + info_span!("ood evaluation").in_scope(|| folded_evaluations.evaluate(point)) + }); + + prover_state.pow_grinding(round_params.query_pow_bits); + + let (ood_challenges, stir_challenges, stir_challenges_indexes) = self.compute_stir_queries2( + prover_state, + round_state, + num_variables, + round_params, + &ood_points, + round_index, + )?; + + let folding_randomness = round_state.folding_randomness( + self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, + ); + + let stir_evaluations = if let Some(data_b) = &round_state.commitment_merkle_prover_data_b { + let answers_a = open_merkle_tree_at_challenges2( + &round_state.merkle_prover_data, + prover_state, + &stir_challenges_indexes, + ); + let answers_b = open_merkle_tree_at_challenges2(data_b, prover_state, &stir_challenges_indexes); + let mut stir_evaluations = Vec::new(); + for (answer_a, answer_b) in answers_a.iter().zip(&answers_b) { + let vars_a = answer_a.by_ref().n_vars(); + let vars_b = answer_b.by_ref().n_vars(); + let a_trunc = folding_randomness[1..].to_vec(); + let eval_a = answer_a.evaluate(&MultilinearPoint(a_trunc)); + let b_trunc = folding_randomness[vars_a - vars_b + 1..].to_vec(); + let eval_b = answer_b.evaluate(&MultilinearPoint(b_trunc)); + let last_fold_rand_a = folding_randomness[0]; + let last_fold_rand_b = folding_randomness[..vars_a - vars_b + 1] + .iter() + .map(|&x| EF::ONE - x) + .product::(); + stir_evaluations.push(eval_a * last_fold_rand_a + eval_b * last_fold_rand_b); + } + + stir_evaluations + } else { + open_merkle_tree_at_challenges2(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) + .iter() + .map(|answer| answer.evaluate(&folding_randomness)) + .collect() + }; + + // Randomness for combination + let combination_randomness_gen: EF = prover_state.sample(); + let ood_combination_randomness: Vec<_> = combination_randomness_gen.powers().collect_n(ood_challenges.len()); + round_state + .sumcheck_prover + .add_new_equality(&ood_challenges, &ood_answers, &ood_combination_randomness); + let stir_combination_randomness = combination_randomness_gen + .powers() + .skip(ood_challenges.len()) + .take(stir_challenges.len()) + .collect::>(); + + round_state.sumcheck_prover.add_new_base_equality( + &stir_challenges, + &stir_evaluations, + &stir_combination_randomness, + ); + + let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( + None, + prover_state, + folding_factor_next, + round_params.folding_pow_bits, + ); + + round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); + + // Update round state + round_state.domain_size = new_domain_size; + round_state.next_domain_gen = + PF::::two_adic_generator(log2_strict_usize(new_domain_size) - folding_factor_next); + round_state.merkle_prover_data = prover_data; + round_state.commitment_merkle_prover_data_b = None; + + Ok(()) + } + fn final_round( &self, round_index: usize, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, round_state: &mut RoundState, ) -> ProofResult<()> { // Convert evaluations to coefficient form and send to the verifier. @@ -246,6 +395,74 @@ where Ok(()) } + fn final_round2( + &self, + round_index: usize, + prover_state: &mut impl FSProver, + round_state: &mut RoundState2, + ) -> ProofResult<()> { + // Convert evaluations to coefficient form and send to the verifier. + let mut coeffs = match &round_state.sumcheck_prover.evals { + MleOwned::Extension(evals) => evals.clone(), + MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), + _ => unreachable!(), + }; + evals_to_coeffs(&mut coeffs); + prover_state.add_extension_scalars(&coeffs); + + prover_state.pow_grinding(self.final_query_pow_bits); + + // Final verifier queries and answers. The indices are over the folded domain. + let final_challenge_indexes = get_challenge_stir_queries( + // The size of the original domain before folding + round_state.domain_size >> self.folding_factor.at_round(round_index), + self.final_queries, + prover_state, + ); + + let mut base_paths = Vec::new(); + let mut ext_paths = Vec::new(); + for challenge in final_challenge_indexes { + let (answer, sibling_hashes) = round_state.merkle_prover_data.open(challenge); + + match answer { + MleOwned::Base(leaf) => { + base_paths.push(MerklePath { + leaf_data: leaf, + sibling_hashes, + leaf_index: challenge, + }); + } + MleOwned::Extension(leaf) => { + ext_paths.push(MerklePath { + leaf_data: leaf, + sibling_hashes, + leaf_index: challenge, + }); + } + _ => unreachable!(), + } + } + if !base_paths.is_empty() { + prover_state.hint_merkle_paths_base(base_paths); + } + if !ext_paths.is_empty() { + prover_state.hint_merkle_paths_extension(ext_paths); + } + + // Run final sumcheck if required + if self.final_sumcheck_rounds > 0 { + let final_folding_randomness = + round_state + .sumcheck_prover + .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + + round_state.randomness_vec.extend(final_folding_randomness.0); + } + + Ok(()) + } + #[allow(clippy::type_complexity)] fn compute_stir_queries( &self, @@ -274,11 +491,82 @@ where Ok((ood_challenges, stir_challenges, stir_challenges_indexes)) } + + #[allow(clippy::type_complexity)] + fn compute_stir_queries2( + &self, + prover_state: &mut impl FSProver, + round_state: &RoundState2, + num_variables: usize, + round_params: &RoundConfig, + ood_points: &[EF], + round_index: usize, + ) -> ProofResult<(Vec>, Vec>>, Vec)> { + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain_size >> self.folding_factor.at_round(round_index), + round_params.num_queries, + prover_state, + ); + + let domain_scaled_gen = round_state.next_domain_gen; + let ood_challenges = ood_points + .iter() + .map(|univariate| MultilinearPoint::expand_from_univariate(*univariate, num_variables)) + .collect(); + let stir_challenges = stir_challenges_indexes + .iter() + .map(|i| MultilinearPoint::expand_from_univariate(domain_scaled_gen.exp_u64(*i as u64), num_variables)) + .collect(); + + Ok((ood_challenges, stir_challenges, stir_challenges_indexes)) + } } fn open_merkle_tree_at_challenges>>( merkle_tree: &MerkleData, - prover_state: &mut impl FSProver, + prover_state: &mut impl FSProver; DIGEST_ELEMS]>, + stir_challenges_indexes: &[usize], +) -> Vec> { + let mut answers = Vec::new(); + let mut base_paths = Vec::new(); + let mut ext_paths = Vec::new(); + + for &challenge in stir_challenges_indexes { + let (answer, sibling_hashes) = merkle_tree.open(challenge); + + match &answer { + MleOwned::Base(leaf) => { + base_paths.push(MerklePath { + leaf_data: leaf.clone(), + sibling_hashes, + leaf_index: challenge, + }); + } + MleOwned::Extension(leaf) => { + ext_paths.push(MerklePath { + leaf_data: leaf.clone(), + sibling_hashes, + leaf_index: challenge, + }); + } + _ => unreachable!(), + } + answers.push(answer); + } + + if !base_paths.is_empty() { + prover_state.hint_merkle_paths_base(base_paths); + } + if !ext_paths.is_empty() { + prover_state.hint_merkle_paths_extension(ext_paths); + } + + answers +} + +fn open_merkle_tree_at_challenges2>>( + merkle_tree: &MerkleData2, + prover_state: &mut impl FSProver, stir_challenges_indexes: &[usize], ) -> Vec> { let mut answers = Vec::new(); @@ -457,6 +745,19 @@ where randomness_vec: Vec, } +#[derive(Debug)] +pub(crate) struct RoundState2 +where + EF: ExtensionField>, +{ + domain_size: usize, + next_domain_gen: PF, + sumcheck_prover: SumcheckSingle, + commitment_merkle_prover_data_b: Option>, + merkle_prover_data: MerkleData2, + randomness_vec: Vec, +} + #[allow(clippy::mismatching_type_param_order)] impl RoundState where @@ -512,6 +813,61 @@ where } } +#[allow(clippy::mismatching_type_param_order)] +impl RoundState2 +where + EF: ExtensionField>, + PF: TwoAdicField, +{ + pub(crate) fn initialize_first_round_state( + prover: &WhirConfig, + prover_state: &mut impl FSProver, + mut statement: Vec>, + witness: Witness2, + polynomial: &MleRef<'_, EF>, + ) -> ProofResult { + let ood_statements = witness + .ood_points + .into_iter() + .zip(witness.ood_answers) + .map(|(point, evaluation)| { + SparseStatement::dense( + MultilinearPoint::expand_from_univariate(point, prover.num_variables), + evaluation, + ) + }) + .collect::>(); + + statement.splice(0..0, ood_statements); + + let combination_randomness_gen: EF = prover_state.sample(); + + let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds( + polynomial, + &statement, + combination_randomness_gen, + prover_state, + prover.folding_factor.at_round(0), + prover.starting_folding_pow_bits, + ); + + Ok(Self { + domain_size: prover.starting_domain_size(), + next_domain_gen: PF::::two_adic_generator( + log2_strict_usize(prover.starting_domain_size()) - prover.folding_factor.at_round(0), + ), + sumcheck_prover, + merkle_prover_data: witness.prover_data, + commitment_merkle_prover_data_b: None, + randomness_vec: folding_randomness.0.clone(), + }) + } + + fn folding_randomness(&self, folding_factor: usize) -> MultilinearPoint { + MultilinearPoint(self.randomness_vec[self.randomness_vec.len() - folding_factor..].to_vec()) + } +} + #[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) where From 98855d39518f2427e4564a1f489886a22ec18028 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 19:51:20 +0300 Subject: [PATCH 09/19] fix --- crates/backend/fiat-shamir/src/prover.rs | 2 ++ crates/backend/fiat-shamir/src/transcript.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index b15fa83a1..1a2719686 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -49,6 +49,7 @@ where pub fn into_proof(self) -> Proof> { Proof { transcript: self.transcript, + commitments: Vec::new(), merkle_paths: self.merkle_paths, } } @@ -204,6 +205,7 @@ where pub fn into_proof(self) -> Proof, Sha256Digest> { Proof { transcript: self.transcript, + commitments: self.commitments, merkle_paths: self.merkle_paths, } } diff --git a/crates/backend/fiat-shamir/src/transcript.rs b/crates/backend/fiat-shamir/src/transcript.rs index 4cf751c60..3963db26c 100644 --- a/crates/backend/fiat-shamir/src/transcript.rs +++ b/crates/backend/fiat-shamir/src/transcript.rs @@ -32,6 +32,8 @@ pub struct MerklePaths(pub(crate) Vec>); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Proof { pub(crate) transcript: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub(crate) commitments: Vec, pub(crate) merkle_paths: Vec>, } From 9662351c5860344ac11fc9e628eb5ae41f76a285 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 20:21:52 +0300 Subject: [PATCH 10/19] prep for v --- crates/backend/fiat-shamir/src/traits.rs | 3 +++ crates/backend/fiat-shamir/src/verifier.rs | 11 +++++++++++ crates/sub_protocols/src/stacked_pcs.rs | 2 +- crates/whir/src/verify.rs | 8 ++++---- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/crates/backend/fiat-shamir/src/traits.rs b/crates/backend/fiat-shamir/src/traits.rs index e3bd9b4b5..687ec7f5b 100644 --- a/crates/backend/fiat-shamir/src/traits.rs +++ b/crates/backend/fiat-shamir/src/traits.rs @@ -46,8 +46,11 @@ pub trait FSProver>>: ChallengeSampler { } pub trait FSVerifier>>: ChallengeSampler { + type Digest: Clone; + fn state(&self) -> String; fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError>; + fn next_commitment(&mut self) -> Result; fn observe_scalars(&mut self, scalars: &[PF]); fn next_merkle_opening(&mut self) -> Result>, ProofError>; fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError>; diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 8c40ff1dd..84c6d4a75 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -26,6 +26,10 @@ where PF: PrimeField64, { pub fn new(proof: Proof>, compressor: C) -> Result { + if !proof.commitments.is_empty() { + return Err(ProofError::InvalidProof); + } + let mut merkle_openings = Vec::new(); for paths in proof.merkle_paths { let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?; @@ -108,6 +112,8 @@ impl>, C: Compression<[PF; WIDTH]>> FSVerifier where PF: PrimeField64, { + type Digest = [PF; DIGEST_LEN_FE]; + fn state(&self) -> String { format!( "state {} (offset: {}, merkle_idx: {})", @@ -132,6 +138,11 @@ where Ok(scalars) } + fn next_commitment(&mut self) -> Result { + self.next_base_scalars_vec(DIGEST_LEN_FE) + .map(|scalars| scalars.try_into().unwrap()) + } + fn next_merkle_opening(&mut self) -> Result>, ProofError> { if self.merkle_opening_index >= self.merkle_openings.len() { return Err(ProofError::ExceededTranscript); diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index c8dfae4fc..5921a3122 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -162,7 +162,7 @@ pub fn stack_polynomials_and_commit( pub fn stacked_pcs_parse_commitment( whir_config_builder: &WhirConfigBuilder, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier, log_memory: usize, log_bytecode: usize, tables_heights: &BTreeMap, diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b287..ebb616109 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -19,7 +19,7 @@ pub struct ParsedCommitment> { impl> ParsedCommitment { pub fn parse( - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, num_variables: usize, ood_samples: usize, ) -> ProofResult @@ -28,7 +28,7 @@ impl> ParsedCommitment { EF: ExtensionField + TwoAdicField, EF: ExtensionField>, { - let root = verifier_state.next_base_scalars_vec(DIGEST_ELEMS)?.try_into().unwrap(); + let root = verifier_state.next_commitment()?; let mut ood_points = vec![]; let ood_answers = if ood_samples > 0 { ood_points = verifier_state.sample_vec(ood_samples); @@ -65,7 +65,7 @@ where { pub fn parse_commitment( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, ) -> ProofResult> where EF: ExtensionField, @@ -82,7 +82,7 @@ where #[allow(clippy::too_many_lines)] pub fn verify( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, parsed_commitment: &ParsedCommitment, statement: Vec>, ) -> ProofResult> From 323d1c1e23afd7ef64480c9a79e74870db012c78 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 20:51:12 +0300 Subject: [PATCH 11/19] verifier --- crates/backend/fiat-shamir/src/traits.rs | 2 +- crates/backend/fiat-shamir/src/transcript.rs | 4 +- crates/backend/fiat-shamir/src/verifier.rs | 194 ++++++++++++- crates/lean_prover/src/test_zkvm.rs | 7 +- crates/lean_prover/src/verify_execution.rs | 216 +++++++++++++- crates/sub_protocols/src/stacked_pcs.rs | 22 ++ crates/whir/src/merkle.rs | 7 +- crates/whir/src/verify.rs | 288 ++++++++++++++++++- 8 files changed, 729 insertions(+), 11 deletions(-) diff --git a/crates/backend/fiat-shamir/src/traits.rs b/crates/backend/fiat-shamir/src/traits.rs index 687ec7f5b..baf098b8a 100644 --- a/crates/backend/fiat-shamir/src/traits.rs +++ b/crates/backend/fiat-shamir/src/traits.rs @@ -52,7 +52,7 @@ pub trait FSVerifier>>: ChallengeSampler { fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError>; fn next_commitment(&mut self) -> Result; fn observe_scalars(&mut self, scalars: &[PF]); - fn next_merkle_opening(&mut self) -> Result>, ProofError>; + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError>; fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError>; fn next_sumcheck_polynomial( &mut self, diff --git a/crates/backend/fiat-shamir/src/transcript.rs b/crates/backend/fiat-shamir/src/transcript.rs index 3963db26c..9b88dbd5d 100644 --- a/crates/backend/fiat-shamir/src/transcript.rs +++ b/crates/backend/fiat-shamir/src/transcript.rs @@ -6,9 +6,9 @@ use crate::PrunedMerklePaths; pub const DIGEST_LEN_FE: usize = 8; #[derive(Debug, Clone)] -pub struct MerkleOpening { +pub struct MerkleOpening { pub leaf_data: Vec, - pub path: Vec<[F; DIGEST_LEN_FE]>, + pub path: Vec, } /// "RawProof": the format which is used in the zkVM recursion program (no Merkle pruning, no sumcheck optimization to send less data, etc) diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 84c6d4a75..aec46f64d 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -3,14 +3,16 @@ use std::iter::repeat_n; use crate::{ MerkleOpening, MerklePaths, PrunedMerklePaths, RawProof, - challenger::{Challenger, RATE, WIDTH}, + challenger::{Challenger, ChallengerSha2, RATE, WIDTH}, transcript::{DIGEST_LEN_FE, Proof}, *, }; use field::PrimeCharacteristicRing; -use field::{ExtensionField, PrimeField64}; +use field::{ExtensionField, PrimeField32, PrimeField64}; use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; +use sha2::{Digest as Sha2DigestTrait, Sha256}; use symetric::Compression; +use symetric::merkle::Sha256Digest; pub struct VerifierState>, P> { challenger: Challenger, P>, @@ -143,7 +145,7 @@ where .map(|scalars| scalars.try_into().unwrap()) } - fn next_merkle_opening(&mut self) -> Result>, ProofError> { + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError> { if self.merkle_opening_index >= self.merkle_openings.len() { return Err(ProofError::ExceededTranscript); } @@ -204,3 +206,189 @@ where } } } + +pub struct VerifierStateSha2>> { + challenger: ChallengerSha2>, + transcript: Vec>, + transcript_offset: usize, + commitments: Vec, + commitment_index: usize, + merkle_openings: Vec, Sha256Digest>>, + merkle_opening_index: usize, +} + +impl>> VerifierStateSha2 +where + PF: PrimeField32, +{ + pub fn new(proof: Proof, Sha256Digest>) -> Result { + let mut merkle_openings = Vec::new(); + for paths in proof.merkle_paths { + let restored = Self::restore_merkle_paths(paths).ok_or(ProofError::InvalidProof)?; + merkle_openings.extend(restored); + } + + Ok(Self { + challenger: ChallengerSha2::new(), + transcript: proof.transcript, + transcript_offset: 0, + commitments: proof.commitments, + commitment_index: 0, + merkle_openings, + merkle_opening_index: 0, + }) + } + + fn read_transcript(&mut self, n: usize) -> Result>, ProofError> { + if self.transcript_offset + n > self.transcript.len() { + return Err(ProofError::ExceededTranscript); + } + let scalars = self.transcript[self.transcript_offset..self.transcript_offset + n].to_vec(); + self.transcript_offset += n; + Ok(scalars) + } + + fn restore_merkle_paths( + paths: PrunedMerklePaths, Sha256Digest>, + ) -> Option, Sha256Digest>>> { + let hash_fn = |data: &[PF]| { + let mut hasher = Sha256::new(); + for value in data { + hasher.update(value.as_canonical_u32().to_le_bytes()); + } + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }; + let combine_fn = |left: &Sha256Digest, right: &Sha256Digest| { + let mut hasher = Sha256::new(); + hasher.update(left); + hasher.update(right); + let digest = hasher.finalize(); + digest[..16].try_into().unwrap() + }; + let restored = paths.restore(&hash_fn, &combine_fn)?; + Some( + restored + .0 + .into_iter() + .map(|path| MerkleOpening { + leaf_data: path.leaf_data, + path: path.sibling_hashes, + }) + .collect(), + ) + } +} + +impl>> ChallengeSampler for VerifierStateSha2 +where + PF: PrimeField32, +{ + fn sample_vec(&mut self, len: usize) -> Vec { + let sampled_fe = self + .challenger + .sample_many((len * EF::DIMENSION).div_ceil(RATE)) + .into_iter() + .flatten() + .take(len * EF::DIMENSION) + .collect::>>(); + pack_scalars_to_extension(&sampled_fe) + } + + fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec { + self.challenger.sample_in_range(bits, n_samples) + } +} + +impl>> FSVerifier for VerifierStateSha2 +where + PF: PrimeField32, +{ + type Digest = Sha256Digest; + + fn state(&self) -> String { + format!( + "sha2 verifier (offset: {}, commitment_idx: {}, merkle_idx: {})", + self.transcript_offset, self.commitment_index, self.merkle_opening_index, + ) + } + + fn observe_scalars(&mut self, scalars: &[PF]) { + self.challenger.observe_scalars(scalars); + } + + fn next_base_scalars_vec(&mut self, n: usize) -> Result>, ProofError> { + let scalars = self.read_transcript(n)?; + self.challenger.observe_scalars(&scalars); + Ok(scalars) + } + + fn next_commitment(&mut self) -> Result { + if self.commitment_index >= self.commitments.len() { + return Err(ProofError::ExceededTranscript); + } + let commitment = self.commitments[self.commitment_index]; + self.commitment_index += 1; + self.challenger.observe_bytes(&commitment); + Ok(commitment) + } + + fn next_merkle_opening(&mut self) -> Result, Self::Digest>, ProofError> { + if self.merkle_opening_index >= self.merkle_openings.len() { + return Err(ProofError::ExceededTranscript); + } + let opening = self.merkle_openings[self.merkle_opening_index].clone(); + self.merkle_opening_index += 1; + Ok(opening) + } + + fn check_pow_grinding(&mut self, bits: usize) -> Result<(), ProofError> { + if bits == 0 { + return Ok(()); + } + let witness = self.read_transcript(1)?[0]; + self.challenger.observe_scalars(&[witness]); + let mut challenger = self.challenger.clone(); + if challenger.sample_in_range(bits, 1)[0] != 0 { + return Err(ProofError::InvalidGrindingWitness); + } + Ok(()) + } + + fn next_sumcheck_polynomial( + &mut self, + n_coeffs: usize, + claimed_sum: EF, + eq_alpha: Option, + ) -> ProofResult> { + match eq_alpha { + None => { + let rest_scalars = self.read_transcript((n_coeffs - 1) * EF::DIMENSION)?; + let rest_coeffs: Vec = pack_scalars_to_extension(&rest_scalars); + let c0 = (claimed_sum - rest_coeffs.iter().copied().sum::()).halve(); + + let mut full_coeffs = Vec::with_capacity(n_coeffs); + full_coeffs.push(c0); + full_coeffs.extend_from_slice(&rest_coeffs); + + let mut all_scalars = flatten_scalars_to_base(&[c0]); + all_scalars.extend_from_slice(&rest_scalars); + self.challenger.observe_scalars(&all_scalars); + Ok(full_coeffs) + } + Some(alpha) => { + let rest_scalars = self.read_transcript((n_coeffs - 2) * EF::DIMENSION)?; + let rest_bare: Vec = pack_scalars_to_extension(&rest_scalars); + let h0 = claimed_sum - alpha * rest_bare.iter().copied().sum::(); + + let mut bare = Vec::with_capacity(n_coeffs - 1); + bare.push(h0); + bare.extend_from_slice(&rest_bare); + + let full_coeffs = expand_bare_to_full(&bare, alpha); + self.challenger.observe_scalars(&flatten_scalars_to_base(&full_coeffs)); + Ok(full_coeffs) + } + } + } +} diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index f4e87c947..672adf5bf 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -1,4 +1,8 @@ -use crate::{default_whir_config, prove_execution::prove_execution, verify_execution::verify_execution}; +use crate::{ + default_whir_config, + prove_execution::prove_execution, + verify_execution::{verify_execution, verify_execution_sha2}, +}; use backend::*; use lean_compiler::*; use lean_vm::*; @@ -248,6 +252,7 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { .unwrap(); let proof_time = time.elapsed(); verify_execution(&bytecode, public_input, proof.proof).unwrap(); + verify_execution_sha2(&bytecode, public_input, proof.proof2).unwrap(); println!("{}", proof.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 0909691d7..6c5fd674f 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,7 +1,8 @@ use std::collections::BTreeMap; use crate::*; -use backend::{Proof, RawProof, VerifierState}; +use backend::merkle::Sha256Digest; +use backend::{Proof, RawProof, VerifierState, VerifierStateSha2}; use lean_vm::*; use sub_protocols::*; use utils::{ToUsize, from_end, get_poseidon16}; @@ -229,6 +230,219 @@ pub fn verify_execution( )) } +pub fn verify_execution_sha2( + bytecode: &Bytecode, + public_input: &[F], + proof: Proof, +) -> Result { + let mut verifier_state = VerifierStateSha2::::new(proof)?; + verifier_state.observe_scalars(public_input); + verifier_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); + let dims = verifier_state + .next_base_scalars_vec(3 + N_TABLES)? + .into_iter() + .map(|x| x.to_usize()) + .collect::>(); + let log_inv_rate = dims[0]; + let log_memory = dims[1]; + let public_input_len = dims[2]; + if public_input_len != public_input.len() { + return Err(ProofError::InvalidProof); + } + let table_n_vars: BTreeMap = (0..N_TABLES).map(|i| (ALL_TABLES[i], dims[i + 3])).collect(); + check_rate(log_inv_rate)?; + let whir_config = default_whir_config(log_inv_rate); + for (table, &log_n_rows) in &table_n_vars { + if log_n_rows < MIN_LOG_N_ROWS_PER_TABLE { + return Err(ProofError::InvalidProof); + } + let log_limit = max_log_n_rows_per_table(table); + if log_n_rows > log_limit { + return Err(TooBigTableError { + table_name: table.name(), + log_n_rows, + log_limit, + } + .into()); + } + } + if log_memory < (*table_n_vars.values().max().unwrap()).max(bytecode.log_size()) { + return Err(ProofError::InvalidProof); + } + + let public_memory = padd_with_zero_to_next_power_of_two(public_input); + + if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) { + return Err(ProofError::InvalidProof); + } + + if bytecode.log_size() < MIN_BYTECODE_LOG_SIZE { + return Err(ProofError::InvalidProof); + } + + let parsed_commitment = stacked_pcs_parse_commitment_sha2( + &whir_config, + &mut verifier_state, + log_memory, + bytecode.log_size(), + &table_n_vars, + )?; + + let logup_c = verifier_state.sample(); + let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); + let logup_alphas_eq_poly = eval_eq(&logup_alphas); + + let logup_statements = verify_generic_logup( + &mut verifier_state, + logup_c, + &logup_alphas, + &logup_alphas_eq_poly, + log_memory, + &bytecode.instructions_multilinear, + &table_n_vars, + )?; + let gkr_point = &logup_statements.gkr_point; + let mut committed_statements: CommittedStatements = Default::default(); + for table in ALL_TABLES { + let log_n = table_n_vars[&table]; + committed_statements.insert( + table, + vec![( + MultilinearPoint(from_end(gkr_point, log_n).to_vec()), + logup_statements.columns_values[&table].clone(), + BTreeMap::new(), + )], + ); + } + + let bus_beta = verifier_state.sample(); + let air_alpha = verifier_state.sample(); + let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); + let eta: EF = verifier_state.sample(); + + let tables_sorted = sort_tables_by_height(&table_n_vars); + + struct TableVerifyData { + table: Table, + extra_data: ExtraDataForBuses, + eta_power: EF, + } + let mut verify_data: Vec = Vec::new(); + let mut initial_sum = EF::ZERO; + let mut eta_power = EF::ONE; + + for (table, _) in &tables_sorted { + let bus_numerator_value = logup_statements.bus_numerators_values[table]; + let bus_denominator_value = logup_statements.bus_denominators_values[table]; + let bus_final_value = bus_numerator_value + * match table.bus().direction { + BusDirection::Pull => EF::NEG_ONE, + BusDirection::Push => EF::ONE, + } + + bus_beta * (bus_denominator_value - logup_c); + + initial_sum += eta_power * bus_final_value; + + verify_data.push(TableVerifyData { + table: *table, + eta_power, + extra_data: ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), bus_beta, air_alpha_powers.clone()), + }); + + eta_power *= eta; + } + + let max_full_degree = tables_sorted.iter().map(|(t, _)| t.degree_air() + 1).max().unwrap(); + + let n_max = tables_sorted[0].1; + let Evaluation { + point: sumcheck_air_point, + value: claimed_air_final_value, + } = sumcheck_verify(&mut verifier_state, n_max, max_full_degree, initial_sum, None)?; + + let mut my_air_final_value = EF::ZERO; + for vd in &verify_data { + let n_cols_total = vd.table.n_columns() + vd.table.n_down_columns(); + let col_evals = verifier_state.next_extension_scalars_vec(n_cols_total)?; + + macro_rules! eval_constraint { + ($t:expr) => {{ <_ as SumcheckComputation>::eval_extension($t, &col_evals, &vd.extra_data) }}; + } + let constraint_eval = delegate_to_inner!(&vd.table => eval_constraint); + + let bus_point = from_end(gkr_point, table_n_vars[&vd.table]); + let natural_ordering_point = natural_ordering_point_for_session(&sumcheck_air_point.0, table_n_vars[&vd.table]); + my_air_final_value += back_loaded_table_contribution( + bus_point, + &sumcheck_air_point.0, + &natural_ordering_point, + constraint_eval, + vd.eta_power, + ); + + macro_rules! split { + ($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &natural_ordering_point) }}; + } + let claim = delegate_to_inner!(&vd.table => split); + + committed_statements.get_mut(&vd.table).unwrap().push(claim); + } + + if my_air_final_value != claimed_air_final_value { + return Err(ProofError::InvalidProof); + } + + let public_memory_random_point = + MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); + let public_memory_eval = public_memory.evaluate(&public_memory_random_point); + + let previous_statements = vec![ + SparseStatement::new( + parsed_commitment.num_variables, + logup_statements.memory_and_acc_point, + vec![ + SparseValue::new(0, logup_statements.value_memory), + SparseValue::new(1, logup_statements.value_memory_acc), + ], + ), + SparseStatement::new( + parsed_commitment.num_variables, + public_memory_random_point, + vec![SparseValue::new(0, public_memory_eval)], + ), + SparseStatement::new( + parsed_commitment.num_variables, + logup_statements.bytecode_and_acc_point, + vec![SparseValue::new( + (2 << log_memory) >> bytecode.log_size(), + logup_statements.value_bytecode_acc, + )], + ), + ]; + + let global_statements_base = stacked_pcs_global_statements( + parsed_commitment.num_variables, + log_memory, + bytecode.log_size(), + previous_statements, + &table_n_vars, + &committed_statements, + ); + + let num_whir_statements = global_statements_base.iter().map(|s| s.values.len()).sum::(); + assert_eq!(num_whir_statements, total_whir_statements()); + + WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify2( + &mut verifier_state, + &parsed_commitment, + global_statements_base, + )?; + + Ok(ProofVerificationDetails { + bytecode_evaluation: logup_statements.bytecode_evaluation.unwrap(), + }) +} + fn back_loaded_table_contribution>>( bus_point: &[EF], sumcheck_air_point: &[EF], diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 5921a3122..df1fe399d 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -184,6 +184,28 @@ pub fn stacked_pcs_parse_commitment( WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment(verifier_state) } +pub fn stacked_pcs_parse_commitment_sha2( + whir_config_builder: &WhirConfigBuilder, + verifier_state: &mut impl FSVerifier, + log_memory: usize, + log_bytecode: usize, + tables_heights: &BTreeMap, +) -> Result, ProofError> { + if log_memory < tables_heights[&Table::execution()] + || tables_heights[&Table::execution()] < tables_heights.values().copied().max().unwrap() + { + return Err(ProofError::InvalidProof); + } + + let stacked_n_vars = compute_stacked_n_vars(log_memory, log_bytecode, tables_heights); + if stacked_n_vars + > F::TWO_ADICITY + whir_config_builder.folding_factor.at_round(0) - whir_config_builder.starting_log_inv_rate + { + return Err(ProofError::InvalidProof); + } + WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment2(verifier_state) +} + fn compute_stacked_n_vars( log_memory: usize, log_bytecode: usize, diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 18402c1cb..681c36bbf 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -384,12 +384,14 @@ where } #[instrument(name = "first digest layer", level = "debug", skip_all)] -fn sha2_first_digest_layer(h: &Sha256, matrix: &M, _full_width: usize) -> Vec +fn sha2_first_digest_layer(h: &Sha256, matrix: &M, full_width: usize) -> Vec where F: PrimeField32, M: Matrix, { let height = matrix.height(); + let matrix_width = matrix.width(); + let n_trailing_zeros = full_width - matrix_width; (0..height) .into_par_iter() @@ -398,6 +400,9 @@ where for value in matrix.row(r).unwrap() { hasher.update(value.as_canonical_u32().to_le_bytes()); } + for _ in 0..n_trailing_zeros { + hasher.update(F::ZERO.as_canonical_u32().to_le_bytes()); + } let digest = hasher.finalize(); digest[..16].try_into().unwrap() }) diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index ebb616109..8b00178d9 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -5,6 +5,7 @@ use std::{fmt::Debug, marker::PhantomData}; use fiat_shamir::{FSVerifier, ProofError, ProofResult, pack_scalars_to_extension}; use field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use poly::*; +use symetric::merkle::Sha256Digest; use crate::*; @@ -17,6 +18,15 @@ pub struct ParsedCommitment> { pub base_field: PhantomData, } +#[derive(Debug, Clone)] +pub struct ParsedCommitment2> { + pub num_variables: usize, + pub root: Sha256Digest, + pub ood_points: Vec, + pub ood_answers: Vec, + pub base_field: PhantomData, +} + impl> ParsedCommitment { pub fn parse( verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, @@ -59,6 +69,48 @@ impl> ParsedCommitment { } } +impl> ParsedCommitment2 { + pub fn parse( + verifier_state: &mut impl FSVerifier, + num_variables: usize, + ood_samples: usize, + ) -> ProofResult + where + F: TwoAdicField, + EF: ExtensionField + TwoAdicField, + EF: ExtensionField>, + { + let root = verifier_state.next_commitment()?; + let mut ood_points = vec![]; + let ood_answers = if ood_samples > 0 { + ood_points = verifier_state.sample_vec(ood_samples); + verifier_state.next_extension_scalars_vec(ood_samples)? + } else { + Vec::new() + }; + Ok(Self { + num_variables, + root, + ood_points, + ood_answers, + base_field: PhantomData, + }) + } + + pub fn oods_constraints(&self) -> Vec> { + self.ood_points + .iter() + .zip(&self.ood_answers) + .map(|(&point, &eval)| { + SparseStatement::dense( + MultilinearPoint::expand_from_univariate(point, self.num_variables), + eval, + ) + }) + .collect() + } +} + impl WhirConfig where EF: TwoAdicField + ExtensionField>, @@ -72,6 +124,16 @@ where { ParsedCommitment::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) } + + pub fn parse_commitment2( + &self, + verifier_state: &mut impl FSVerifier, + ) -> ProofResult> + where + EF: ExtensionField, + { + ParsedCommitment2::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) + } } impl WhirConfig @@ -204,6 +266,118 @@ where Ok(folding_randomness) } + #[allow(clippy::too_many_lines)] + pub fn verify2( + &self, + verifier_state: &mut impl FSVerifier, + parsed_commitment: &ParsedCommitment2, + statement: Vec>, + ) -> ProofResult> + where + F: TwoAdicField + ExtensionField>, + EF: ExtensionField, + { + statement + .iter() + .for_each(|c| assert_eq!(c.total_num_variables, parsed_commitment.num_variables)); + + let mut round_constraints = Vec::new(); + let mut round_folding_randomness = Vec::new(); + let mut claimed_sum = EF::ZERO; + let mut prev_commitment = parsed_commitment.clone(); + + let constraints: Vec<_> = prev_commitment + .oods_constraints() + .into_iter() + .chain(statement) + .collect(); + let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; + round_constraints.push((combination_randomness, constraints)); + + let folding_randomness = verify_sumcheck_rounds::( + verifier_state, + &mut claimed_sum, + self.folding_factor.at_round(0), + self.starting_folding_pow_bits, + )?; + round_folding_randomness.push(folding_randomness); + + for round_index in 0..self.n_rounds() { + let round_params = &self.round_parameters[round_index]; + let new_commitment = ParsedCommitment2::::parse( + verifier_state, + round_params.num_variables, + round_params.ood_samples, + )?; + + let stir_constraints = self.verify_stir_challenges2( + verifier_state, + round_params, + &prev_commitment, + round_folding_randomness.last().unwrap(), + round_index, + )?; + + let constraints: Vec> = new_commitment + .oods_constraints() + .into_iter() + .chain(stir_constraints) + .collect(); + + let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; + round_constraints.push((combination_randomness.clone(), constraints)); + + let folding_randomness = verify_sumcheck_rounds::( + verifier_state, + &mut claimed_sum, + self.folding_factor.at_round(round_index + 1), + round_params.folding_pow_bits, + )?; + + round_folding_randomness.push(folding_randomness); + prev_commitment = new_commitment; + } + + let n_final_coeffs = 1 << self.n_vars_of_final_polynomial(); + let final_coefficients = verifier_state.next_extension_scalars_vec(n_final_coeffs)?; + + let stir_constraints = self.verify_stir_challenges2( + verifier_state, + &self.final_round_config(), + &prev_commitment, + round_folding_randomness.last().unwrap(), + self.n_rounds(), + )?; + + stir_constraints + .iter() + .all(|c| verify_constraint_coeffs(c, &final_coefficients)) + .then_some(()) + .ok_or(ProofError::InvalidProof)?; + + let final_sumcheck_randomness = + verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.final_sumcheck_rounds, 0)?; + round_folding_randomness.push(final_sumcheck_randomness.clone()); + + let folding_randomness = MultilinearPoint( + round_folding_randomness + .into_iter() + .flat_map(|poly| poly.0.into_iter()) + .collect(), + ); + + let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); + + let mut reversed_point = final_sumcheck_randomness.0.clone(); + reversed_point.reverse(); + let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); + if claimed_sum != evaluation_of_weights * final_value { + return Err(ProofError::InvalidProof); + } + + Ok(folding_randomness) + } + pub(crate) fn combine_constraints( &self, verifier_state: &mut impl FSVerifier, @@ -226,7 +400,7 @@ where fn verify_stir_challenges( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, params: &RoundConfig, commitment: &ParsedCommitment, folding_randomness: &MultilinearPoint, @@ -284,10 +458,64 @@ where Ok(stir_constraints) } + fn verify_stir_challenges2( + &self, + verifier_state: &mut impl FSVerifier, + params: &RoundConfig, + commitment: &ParsedCommitment2, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + F: Field + ExtensionField>, + EF: ExtensionField, + { + let leafs_base_field = round_index == 0; + + verifier_state.check_pow_grinding(params.query_pow_bits)?; + + let stir_challenges_indexes = get_challenge_stir_queries( + params.domain_size >> params.folding_factor, + params.num_queries, + verifier_state, + ); + + let dimensions = vec![Dimensions { + height: params.domain_size >> params.folding_factor, + width: 1 << params.folding_factor, + }]; + let answers = self.verify_merkle_proof2::( + verifier_state, + &commitment.root, + &stir_challenges_indexes, + &dimensions, + leafs_base_field, + )?; + + let folds: Vec<_> = answers + .into_iter() + .map(|answers| answers.evaluate(folding_randomness)) + .collect(); + + let stir_constraints = stir_challenges_indexes + .iter() + .map(|&index| params.folded_domain_gen.exp_u64(index as u64)) + .zip(&folds) + .map(|(point, &value)| { + SparseStatement::dense( + MultilinearPoint::expand_from_univariate(EF::from(point), params.num_variables), + value, + ) + }) + .collect(); + + Ok(stir_constraints) + } + #[allow(clippy::too_many_arguments)] fn verify_merkle_proof( &self, - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, root: &[PF; DIGEST_ELEMS], indices: &[usize], dimensions: &[Dimensions], @@ -341,6 +569,62 @@ where Ok(res) } + fn verify_merkle_proof2( + &self, + verifier_state: &mut impl FSVerifier, + root: &Sha256Digest, + indices: &[usize], + dimensions: &[Dimensions], + leafs_base_field: bool, + ) -> ProofResult>> + where + F: Field + ExtensionField>, + EF: ExtensionField, + { + let res = if leafs_base_field { + let mut answers = Vec::>::new(); + let mut merkle_proofs = Vec::new(); + + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, F>(&opening.leaf_data)); + merkle_proofs.push(opening.path); + } + + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify_sha2::, F>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) + { + return Err(ProofError::InvalidProof); + } + } + + answers + .into_iter() + .map(|inner| inner.iter().map(|&f_el| f_el.into()).collect()) + .collect() + } else { + let mut answers = vec![]; + let mut merkle_proofs = Vec::new(); + + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, EF>(&opening.leaf_data)); + merkle_proofs.push(opening.path); + } + + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify_sha2::, EF>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) + { + return Err(ProofError::InvalidProof); + } + } + + answers + }; + + Ok(res) + } + fn eval_constraints_poly( &self, constraints: &[(Vec, Vec>)], From 8260c64eef272afb6b70e2c4c6200ae81891db50 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 20:53:48 +0300 Subject: [PATCH 12/19] generic --- crates/sub_protocols/src/stacked_pcs.rs | 44 ++++++++------ crates/whir/src/verify.rs | 79 ++++++------------------- 2 files changed, 46 insertions(+), 77 deletions(-) diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index df1fe399d..e85810d5f 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -167,21 +167,13 @@ pub fn stacked_pcs_parse_commitment( log_bytecode: usize, tables_heights: &BTreeMap, ) -> Result, ProofError> { - if log_memory < tables_heights[&Table::execution()] - || tables_heights[&Table::execution()] < tables_heights.values().copied().max().unwrap() - { - // memory must be at least as large as the number of cycles - // execution table must be the largest table - return Err(ProofError::InvalidProof); - } - - let stacked_n_vars = compute_stacked_n_vars(log_memory, log_bytecode, tables_heights); - if stacked_n_vars - > F::TWO_ADICITY + whir_config_builder.folding_factor.at_round(0) - whir_config_builder.starting_log_inv_rate - { - return Err(ProofError::InvalidProof); - } - WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment(verifier_state) + stacked_pcs_parse_commitment_generic( + whir_config_builder, + verifier_state, + log_memory, + log_bytecode, + tables_heights, + ) } pub fn stacked_pcs_parse_commitment_sha2( @@ -190,10 +182,28 @@ pub fn stacked_pcs_parse_commitment_sha2( log_memory: usize, log_bytecode: usize, tables_heights: &BTreeMap, -) -> Result, ProofError> { +) -> Result, ProofError> { + stacked_pcs_parse_commitment_generic( + whir_config_builder, + verifier_state, + log_memory, + log_bytecode, + tables_heights, + ) +} + +fn stacked_pcs_parse_commitment_generic( + whir_config_builder: &WhirConfigBuilder, + verifier_state: &mut impl FSVerifier, + log_memory: usize, + log_bytecode: usize, + tables_heights: &BTreeMap, +) -> Result, ProofError> { if log_memory < tables_heights[&Table::execution()] || tables_heights[&Table::execution()] < tables_heights.values().copied().max().unwrap() { + // memory must be at least as large as the number of cycles + // execution table must be the largest table return Err(ProofError::InvalidProof); } @@ -203,7 +213,7 @@ pub fn stacked_pcs_parse_commitment_sha2( { return Err(ProofError::InvalidProof); } - WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment2(verifier_state) + WhirConfig::new(whir_config_builder, stacked_n_vars).parse_commitment_generic(verifier_state) } fn compute_stacked_n_vars( diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 8b00178d9..b8956f4b8 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -10,68 +10,17 @@ use symetric::merkle::Sha256Digest; use crate::*; #[derive(Debug, Clone)] -pub struct ParsedCommitment> { +pub struct ParsedCommitment, Digest = [PF; DIGEST_ELEMS]> { pub num_variables: usize, - pub root: [PF; DIGEST_ELEMS], + pub root: Digest, pub ood_points: Vec, pub ood_answers: Vec, pub base_field: PhantomData, } -#[derive(Debug, Clone)] -pub struct ParsedCommitment2> { - pub num_variables: usize, - pub root: Sha256Digest, - pub ood_points: Vec, - pub ood_answers: Vec, - pub base_field: PhantomData, -} - -impl> ParsedCommitment { +impl, Digest: Clone> ParsedCommitment { pub fn parse( - verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, - num_variables: usize, - ood_samples: usize, - ) -> ProofResult - where - F: TwoAdicField, - EF: ExtensionField + TwoAdicField, - EF: ExtensionField>, - { - let root = verifier_state.next_commitment()?; - let mut ood_points = vec![]; - let ood_answers = if ood_samples > 0 { - ood_points = verifier_state.sample_vec(ood_samples); - verifier_state.next_extension_scalars_vec(ood_samples)? - } else { - Vec::new() - }; - Ok(Self { - num_variables, - root, - ood_points, - ood_answers, - base_field: PhantomData, - }) - } - - pub fn oods_constraints(&self) -> Vec> { - self.ood_points - .iter() - .zip(&self.ood_answers) - .map(|(&point, &eval)| { - SparseStatement::dense( - MultilinearPoint::expand_from_univariate(point, self.num_variables), - eval, - ) - }) - .collect() - } -} - -impl> ParsedCommitment2 { - pub fn parse( - verifier_state: &mut impl FSVerifier, + verifier_state: &mut impl FSVerifier, num_variables: usize, ood_samples: usize, ) -> ProofResult @@ -125,14 +74,24 @@ where ParsedCommitment::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) } + pub fn parse_commitment_generic( + &self, + verifier_state: &mut impl FSVerifier, + ) -> ProofResult> + where + EF: ExtensionField, + { + ParsedCommitment::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) + } + pub fn parse_commitment2( &self, verifier_state: &mut impl FSVerifier, - ) -> ProofResult> + ) -> ProofResult> where EF: ExtensionField, { - ParsedCommitment2::::parse(verifier_state, self.num_variables, self.commitment_ood_samples) + self.parse_commitment_generic(verifier_state) } } @@ -270,7 +229,7 @@ where pub fn verify2( &self, verifier_state: &mut impl FSVerifier, - parsed_commitment: &ParsedCommitment2, + parsed_commitment: &ParsedCommitment, statement: Vec>, ) -> ProofResult> where @@ -304,7 +263,7 @@ where for round_index in 0..self.n_rounds() { let round_params = &self.round_parameters[round_index]; - let new_commitment = ParsedCommitment2::::parse( + let new_commitment = ParsedCommitment::::parse( verifier_state, round_params.num_variables, round_params.ood_samples, @@ -462,7 +421,7 @@ where &self, verifier_state: &mut impl FSVerifier, params: &RoundConfig, - commitment: &ParsedCommitment2, + commitment: &ParsedCommitment, folding_randomness: &MultilinearPoint, round_index: usize, ) -> ProofResult>> From bc5993bc57d4620075bf3402e926f06915dba34a Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:08:16 +0300 Subject: [PATCH 13/19] generic --- crates/lean_prover/src/verify_execution.rs | 2 +- crates/whir/src/verify.rs | 204 ++++++++------------- crates/whir/tests/run_whir.rs | 2 +- 3 files changed, 75 insertions(+), 133 deletions(-) diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 6c5fd674f..04b8840bb 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -432,7 +432,7 @@ pub fn verify_execution_sha2( let num_whir_statements = global_statements_base.iter().map(|s| s.values.len()).sum::(); assert_eq!(num_whir_statements, total_whir_statements()); - WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify2( + WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify( &mut verifier_state, &parsed_commitment, global_statements_base, diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index b8956f4b8..6d78e2c4f 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -60,6 +60,69 @@ impl, Digest: Clone> ParsedCommitment: Clone + Sized +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier; +} + +impl WhirVerifierDigest for [PF; DIGEST_ELEMS] +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier, + { + config.verify_stir_challenges(verifier_state, params, commitment, folding_randomness, round_index) + } +} + +impl WhirVerifierDigest for Sha256Digest +where + EF: TwoAdicField + ExtensionField>, + PF: TwoAdicField, + F: Field + ExtensionField>, + EF: ExtensionField, +{ + fn verify_stir_challenges( + config: &WhirConfig, + verifier_state: &mut V, + params: &RoundConfig, + commitment: &ParsedCommitment, + folding_randomness: &MultilinearPoint, + round_index: usize, + ) -> ProofResult>> + where + V: FSVerifier, + { + config.verify_stir_challenges2(verifier_state, params, commitment, folding_randomness, round_index) + } +} + impl WhirConfig where EF: TwoAdicField + ExtensionField>, @@ -101,140 +164,17 @@ where PF: TwoAdicField, { #[allow(clippy::too_many_lines)] - pub fn verify( + pub fn verify( &self, - verifier_state: &mut impl FSVerifier; DIGEST_ELEMS]>, - parsed_commitment: &ParsedCommitment, - statement: Vec>, - ) -> ProofResult> - where - F: TwoAdicField + ExtensionField>, - EF: ExtensionField, - { - statement - .iter() - .for_each(|c| assert_eq!(c.total_num_variables, parsed_commitment.num_variables)); - - // During the rounds we collect constraints, combination randomness, folding randomness - // and we update the claimed sum of constraint evaluation. - let mut round_constraints = Vec::new(); - let mut round_folding_randomness = Vec::new(); - let mut claimed_sum = EF::ZERO; - let mut prev_commitment = parsed_commitment.clone(); - - // Combine OODS and statement constraints to claimed_sum - let constraints: Vec<_> = prev_commitment - .oods_constraints() - .into_iter() - .chain(statement) - .collect(); - let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; - round_constraints.push((combination_randomness, constraints)); - - // Initial sumcheck - let folding_randomness = verify_sumcheck_rounds::( - verifier_state, - &mut claimed_sum, - self.folding_factor.at_round(0), - self.starting_folding_pow_bits, - )?; - round_folding_randomness.push(folding_randomness); - - for round_index in 0..self.n_rounds() { - // Fetch round parameters from config - let round_params = &self.round_parameters[round_index]; - - // Receive commitment to the folded polynomial (likely encoded at higher expansion) - let new_commitment = - ParsedCommitment::::parse(verifier_state, round_params.num_variables, round_params.ood_samples)?; - - // Verify in-domain challenges on the previous commitment. - let stir_constraints = self.verify_stir_challenges( - verifier_state, - round_params, - &prev_commitment, - round_folding_randomness.last().unwrap(), - round_index, - )?; - - // Add out-of-domain and in-domain constraints to claimed_sum - let constraints: Vec> = new_commitment - .oods_constraints() - .into_iter() - .chain(stir_constraints) - .collect(); - - let combination_randomness = self.combine_constraints(verifier_state, &mut claimed_sum, &constraints)?; - round_constraints.push((combination_randomness.clone(), constraints)); - - let folding_randomness = verify_sumcheck_rounds::( - verifier_state, - &mut claimed_sum, - self.folding_factor.at_round(round_index + 1), - round_params.folding_pow_bits, - )?; - - round_folding_randomness.push(folding_randomness); - - // Update round parameters - prev_commitment = new_commitment; - } - - // In the final round we receive the full polynomial instead of a commitment. - let n_final_coeffs = 1 << self.n_vars_of_final_polynomial(); - let final_coefficients = verifier_state.next_extension_scalars_vec(n_final_coeffs)?; - - // Verify in-domain challenges on the previous commitment. - let stir_constraints = self.verify_stir_challenges( - verifier_state, - &self.final_round_config(), - &prev_commitment, - round_folding_randomness.last().unwrap(), - self.n_rounds(), - )?; - - // Verify stir constraints directly on final polynomial - stir_constraints - .iter() - .all(|c| verify_constraint_coeffs(c, &final_coefficients)) - .then_some(()) - .ok_or(ProofError::InvalidProof)?; - - let final_sumcheck_randomness = - verify_sumcheck_rounds::(verifier_state, &mut claimed_sum, self.final_sumcheck_rounds, 0)?; - round_folding_randomness.push(final_sumcheck_randomness.clone()); - - // Compute folding randomness across all rounds. - let folding_randomness = MultilinearPoint( - round_folding_randomness - .into_iter() - .flat_map(|poly| poly.0.into_iter()) - .collect(), - ); - - let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); - - // Check the final sumcheck evaluation (coefficient form, reversed point) - let mut reversed_point = final_sumcheck_randomness.0.clone(); - reversed_point.reverse(); - let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); - if claimed_sum != evaluation_of_weights * final_value { - return Err(ProofError::InvalidProof); - } - - Ok(folding_randomness) - } - - #[allow(clippy::too_many_lines)] - pub fn verify2( - &self, - verifier_state: &mut impl FSVerifier, - parsed_commitment: &ParsedCommitment, + verifier_state: &mut V, + parsed_commitment: &ParsedCommitment, statement: Vec>, ) -> ProofResult> where F: TwoAdicField + ExtensionField>, EF: ExtensionField, + Digest: WhirVerifierDigest, + V: FSVerifier, { statement .iter() @@ -263,13 +203,14 @@ where for round_index in 0..self.n_rounds() { let round_params = &self.round_parameters[round_index]; - let new_commitment = ParsedCommitment::::parse( + let new_commitment = ParsedCommitment::::parse( verifier_state, round_params.num_variables, round_params.ood_samples, )?; - let stir_constraints = self.verify_stir_challenges2( + let stir_constraints = Digest::verify_stir_challenges( + self, verifier_state, round_params, &prev_commitment, @@ -300,7 +241,8 @@ where let n_final_coeffs = 1 << self.n_vars_of_final_polynomial(); let final_coefficients = verifier_state.next_extension_scalars_vec(n_final_coeffs)?; - let stir_constraints = self.verify_stir_challenges2( + let stir_constraints = Digest::verify_stir_challenges( + self, verifier_state, &self.final_round_config(), &prev_commitment, diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 105460ca3..cd6b15d81 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -128,7 +128,7 @@ fn test_run_whir() { let parsed_commitment = params.parse_commitment::(&mut verifier_state).unwrap(); params - .verify::(&mut verifier_state, &parsed_commitment, statement.clone()) + .verify(&mut verifier_state, &parsed_commitment, statement.clone()) .unwrap(); println!( From 9b7abe90e314cd623e637c745f7bc1e01c3f0d06 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:18:21 +0300 Subject: [PATCH 14/19] better v --- crates/lean_prover/src/test_zkvm.rs | 14 +- crates/lean_prover/src/verify_execution.rs | 243 ++------------------- crates/rec_aggregation/src/lib.rs | 10 +- crates/sub_protocols/src/stacked_pcs.rs | 2 +- 4 files changed, 26 insertions(+), 243 deletions(-) diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 672adf5bf..3a74ee447 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -1,13 +1,9 @@ -use crate::{ - default_whir_config, - prove_execution::prove_execution, - verify_execution::{verify_execution, verify_execution_sha2}, -}; +use crate::{default_whir_config, prove_execution::prove_execution, verify_execution::verify}; use backend::*; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; -use utils::{init_tracing, poseidon16_compress}; +use utils::{get_poseidon16, init_tracing, poseidon16_compress}; #[test] fn test_zk_vm_all_precompiles() { @@ -251,8 +247,10 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { ) .unwrap(); let proof_time = time.elapsed(); - verify_execution(&bytecode, public_input, proof.proof).unwrap(); - verify_execution_sha2(&bytecode, public_input, proof.proof2).unwrap(); + let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); + verify(&bytecode, public_input, &mut verifier_state).unwrap(); + let mut verifier_state2 = VerifierStateSha2::::new(proof.proof2).unwrap(); + verify(&bytecode, public_input, &mut verifier_state2).unwrap(); println!("{}", proof.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); } diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 04b8840bb..cdfcf0cb3 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,241 +1,24 @@ use std::collections::BTreeMap; use crate::*; -use backend::merkle::Sha256Digest; -use backend::{Proof, RawProof, VerifierState, VerifierStateSha2}; use lean_vm::*; use sub_protocols::*; -use utils::{ToUsize, from_end, get_poseidon16}; +use utils::{ToUsize, from_end}; #[derive(Debug, Clone)] pub struct ProofVerificationDetails { pub bytecode_evaluation: Evaluation, } -pub fn verify_execution( +pub fn verify( bytecode: &Bytecode, public_input: &[F], - proof: Proof, -) -> Result<(ProofVerificationDetails, RawProof), ProofError> { - let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone())?; - verifier_state.observe_scalars(public_input); - verifier_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); - let dims = verifier_state - .next_base_scalars_vec(3 + N_TABLES)? - .into_iter() - .map(|x| x.to_usize()) - .collect::>(); - let log_inv_rate = dims[0]; - let log_memory = dims[1]; - let public_input_len = dims[2]; // enforce the exact length of the public input to pass through Fiat Shamir (otherwise we could have 2 public inputs, only differing by a few (<8) zeros in the end, leading to the same fiat shamir state: tipically giving the advseary 2 or 3 bits of advantage in the subsequent part where the public input is evaluated as a multilinear polynomial) - if public_input_len != public_input.len() { - return Err(ProofError::InvalidProof); - } - let table_n_vars: BTreeMap = (0..N_TABLES).map(|i| (ALL_TABLES[i], dims[i + 3])).collect(); - check_rate(log_inv_rate)?; - let whir_config = default_whir_config(log_inv_rate); - for (table, &log_n_rows) in &table_n_vars { - if log_n_rows < MIN_LOG_N_ROWS_PER_TABLE { - return Err(ProofError::InvalidProof); - } - let log_limit = max_log_n_rows_per_table(table); - if log_n_rows > log_limit { - return Err(TooBigTableError { - table_name: table.name(), - log_n_rows, - log_limit, - } - .into()); - } - } - // check memory is bigger than any other table - if log_memory < (*table_n_vars.values().max().unwrap()).max(bytecode.log_size()) { - return Err(ProofError::InvalidProof); - } - - let public_memory = padd_with_zero_to_next_power_of_two(public_input); - - if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) { - return Err(ProofError::InvalidProof); - } - - if bytecode.log_size() < MIN_BYTECODE_LOG_SIZE { - return Err(ProofError::InvalidProof); - } - - let parsed_commitment = stacked_pcs_parse_commitment( - &whir_config, - &mut verifier_state, - log_memory, - bytecode.log_size(), - &table_n_vars, - )?; - - let logup_c = verifier_state.sample(); - let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); - let logup_alphas_eq_poly = eval_eq(&logup_alphas); - - let logup_statements = verify_generic_logup( - &mut verifier_state, - logup_c, - &logup_alphas, - &logup_alphas_eq_poly, - log_memory, - &bytecode.instructions_multilinear, - &table_n_vars, - )?; - let gkr_point = &logup_statements.gkr_point; - let mut committed_statements: CommittedStatements = Default::default(); - for table in ALL_TABLES { - let log_n = table_n_vars[&table]; - committed_statements.insert( - table, - vec![( - MultilinearPoint(from_end(gkr_point, log_n).to_vec()), - logup_statements.columns_values[&table].clone(), - BTreeMap::new(), - )], - ); - } - - let bus_beta = verifier_state.sample(); - let air_alpha = verifier_state.sample(); - let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); - let eta: EF = verifier_state.sample(); // batching the sumchecks proving validity of AIR tables - - let tables_sorted = sort_tables_by_height(&table_n_vars); - - struct TableVerifyData { - table: Table, - extra_data: ExtraDataForBuses, - eta_power: EF, - } - let mut verify_data: Vec = Vec::new(); - let mut initial_sum = EF::ZERO; - let mut eta_power = EF::ONE; - - for (table, _) in &tables_sorted { - let bus_numerator_value = logup_statements.bus_numerators_values[table]; - let bus_denominator_value = logup_statements.bus_denominators_values[table]; - let bus_final_value = bus_numerator_value - * match table.bus().direction { - BusDirection::Pull => EF::NEG_ONE, - BusDirection::Push => EF::ONE, - } - + bus_beta * (bus_denominator_value - logup_c); - - initial_sum += eta_power * bus_final_value; - - verify_data.push(TableVerifyData { - table: *table, - eta_power, - extra_data: ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), bus_beta, air_alpha_powers.clone()), - }); - - eta_power *= eta; - } - - let max_full_degree = tables_sorted.iter().map(|(t, _)| t.degree_air() + 1).max().unwrap(); - - let n_max = tables_sorted[0].1; - let Evaluation { - point: sumcheck_air_point, - value: claimed_air_final_value, - } = sumcheck_verify(&mut verifier_state, n_max, max_full_degree, initial_sum, None)?; - - let mut my_air_final_value = EF::ZERO; - for vd in &verify_data { - let n_cols_total = vd.table.n_columns() + vd.table.n_down_columns(); - let col_evals = verifier_state.next_extension_scalars_vec(n_cols_total)?; - - macro_rules! eval_constraint { - ($t:expr) => {{ <_ as SumcheckComputation>::eval_extension($t, &col_evals, &vd.extra_data) }}; - } - let constraint_eval = delegate_to_inner!(&vd.table => eval_constraint); - - let bus_point = from_end(gkr_point, table_n_vars[&vd.table]); - let natural_ordering_point = natural_ordering_point_for_session(&sumcheck_air_point.0, table_n_vars[&vd.table]); - my_air_final_value += back_loaded_table_contribution( - bus_point, - &sumcheck_air_point.0, - &natural_ordering_point, - constraint_eval, - vd.eta_power, - ); - - macro_rules! split { - ($t:expr) => {{ columns_evals_up_and_down($t, &col_evals, &natural_ordering_point) }}; - } - let claim = delegate_to_inner!(&vd.table => split); - - committed_statements.get_mut(&vd.table).unwrap().push(claim); - } - - if my_air_final_value != claimed_air_final_value { - return Err(ProofError::InvalidProof); - } - - let public_memory_random_point = - MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); - let public_memory_eval = public_memory.evaluate(&public_memory_random_point); - - let previous_statements = vec![ - SparseStatement::new( - parsed_commitment.num_variables, - logup_statements.memory_and_acc_point, - vec![ - SparseValue::new(0, logup_statements.value_memory), - SparseValue::new(1, logup_statements.value_memory_acc), - ], - ), - SparseStatement::new( - parsed_commitment.num_variables, - public_memory_random_point, - vec![SparseValue::new(0, public_memory_eval)], - ), - SparseStatement::new( - parsed_commitment.num_variables, - logup_statements.bytecode_and_acc_point, - vec![SparseValue::new( - (2 << log_memory) >> bytecode.log_size(), - logup_statements.value_bytecode_acc, - )], - ), - ]; - - let global_statements_base = stacked_pcs_global_statements( - parsed_commitment.num_variables, - log_memory, - bytecode.log_size(), - previous_statements, - &table_n_vars, - &committed_statements, - ); - - // sanity check (not necessary for soundness) - let num_whir_statements = global_statements_base.iter().map(|s| s.values.len()).sum::(); - assert_eq!(num_whir_statements, total_whir_statements()); - - WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify( - &mut verifier_state, - &parsed_commitment, - global_statements_base, - )?; - - Ok(( - ProofVerificationDetails { - bytecode_evaluation: logup_statements.bytecode_evaluation.unwrap(), - }, - verifier_state.into_raw_proof(), - )) -} - -pub fn verify_execution_sha2( - bytecode: &Bytecode, - public_input: &[F], - proof: Proof, -) -> Result { - let mut verifier_state = VerifierStateSha2::::new(proof)?; + verifier_state: &mut V, +) -> Result +where + V: FSVerifier, + V::Digest: WhirVerifierDigest, +{ verifier_state.observe_scalars(public_input); verifier_state.observe_scalars(&poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP)); let dims = verifier_state @@ -280,9 +63,9 @@ pub fn verify_execution_sha2( return Err(ProofError::InvalidProof); } - let parsed_commitment = stacked_pcs_parse_commitment_sha2( + let parsed_commitment = stacked_pcs_parse_commitment_generic( &whir_config, - &mut verifier_state, + verifier_state, log_memory, bytecode.log_size(), &table_n_vars, @@ -293,7 +76,7 @@ pub fn verify_execution_sha2( let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = verify_generic_logup( - &mut verifier_state, + verifier_state, logup_c, &logup_alphas, &logup_alphas_eq_poly, @@ -358,7 +141,7 @@ pub fn verify_execution_sha2( let Evaluation { point: sumcheck_air_point, value: claimed_air_final_value, - } = sumcheck_verify(&mut verifier_state, n_max, max_full_degree, initial_sum, None)?; + } = sumcheck_verify(verifier_state, n_max, max_full_degree, initial_sum, None)?; let mut my_air_final_value = EF::ZERO; for vd in &verify_data { @@ -433,7 +216,7 @@ pub fn verify_execution_sha2( assert_eq!(num_whir_statements, total_whir_statements()); WhirConfig::new(&whir_config, parsed_commitment.num_variables).verify( - &mut verifier_state, + verifier_state, &parsed_commitment, global_statements_base, )?; diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 3180d23ab..77571c5d7 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -5,18 +5,18 @@ mod compilation; mod type_1_aggregation; mod type_2_aggregation; -use backend::{Evaluation, Proof, ProofError, RawProof}; +use backend::{Evaluation, Proof, ProofError, RawProof, VerifierState}; pub use compilation::{ MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, NUM_REPEATED_ONES, PREAMBLE_MEMORY_LEN, ZERO_VEC_LEN, get_aggregation_bytecode, init_aggregation_bytecode, }; -use lean_prover::verify_execution::verify_execution; +use lean_prover::verify_execution::verify; use lean_vm::{DIGEST_LEN, EF, F}; pub use type_1_aggregation::{TypeOneInfo, TypeOneMultiSignature, aggregate_type_1, verify_type_1}; pub use type_2_aggregation::{ TypeTwoMultiSignature, merge_many_type_1, split_type_2, split_type_2_by_msg, verify_type_2, }; -use utils::poseidon_compress_slice; +use utils::{get_poseidon16, poseidon_compress_slice}; #[allow(missing_debug_implementations)] pub struct InnerVerified { @@ -29,7 +29,9 @@ pub struct InnerVerified { pub(crate) fn verify_inner(input_data: Vec, proof: Proof) -> Result { let input_data_hash = poseidon_compress_slice(&input_data, true); let bytecode = get_aggregation_bytecode(); - let (verif, raw_proof) = verify_execution(bytecode, &input_data_hash, proof)?; + let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone())?; + let verif = verify(bytecode, &input_data_hash, &mut verifier_state)?; + let raw_proof = verifier_state.into_raw_proof(); Ok(InnerVerified { input_data, input_data_hash, diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index e85810d5f..027e89bdf 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -192,7 +192,7 @@ pub fn stacked_pcs_parse_commitment_sha2( ) } -fn stacked_pcs_parse_commitment_generic( +pub fn stacked_pcs_parse_commitment_generic( whir_config_builder: &WhirConfigBuilder, verifier_state: &mut impl FSVerifier, log_memory: usize, From 37ffe902c6e282c648a9482afa89e60085375ad0 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:25:21 +0300 Subject: [PATCH 15/19] a little more generic prover --- crates/lean_prover/src/prove_execution.rs | 265 +++++++++++----------- crates/lean_prover/src/test_zkvm.rs | 16 +- crates/sub_protocols/src/stacked_pcs.rs | 48 +++- 3 files changed, 191 insertions(+), 138 deletions(-) diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index a74c7e28d..3e25281d1 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -11,14 +11,101 @@ use utils::ansi::Colorize; use utils::{build_prover_state, build_prover_state_sha2, from_end}; #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ExecutionProof { - pub proof: Proof, - pub proof2: Proof, +#[serde(bound( + serialize = "Proof: Serialize", + deserialize = "Proof: Deserialize<'de>" +))] +pub struct ExecutionProof { + pub proof: Proof, // benchmark / debug purpose #[serde(skip, default)] pub metadata: Option, } +trait ExecutionProverBackend

+where + P: FSProver, +{ + type InnerWitness; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness; + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ); +} + +struct PoseidonExecutionBackend; + +impl

ExecutionProverBackend

for PoseidonExecutionBackend +where + P: FSProver, +{ + type InnerWitness = Witness; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness { + stack_polynomials_and_commit(prover_state, whir_config, memory, memory_acc, bytecode_acc, traces) + } + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ) { + whir_config.prove(prover_state, statements, witness, polynomial); + } +} + +struct Sha2ExecutionBackend; + +impl

ExecutionProverBackend

for Sha2ExecutionBackend +where + P: FSProver, +{ + type InnerWitness = Witness2; + + fn stack_polynomials_and_commit( + prover_state: &mut P, + whir_config: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + ) -> StackedPcsWitness { + stack_polynomials_and_commit_sha2(prover_state, whir_config, memory, memory_acc, bytecode_acc, traces) + } + + fn prove_whir( + whir_config: &WhirConfig, + prover_state: &mut P, + statements: Vec>, + witness: Self::InnerWitness, + polynomial: &MleRef<'_, EF>, + ) { + whir_config.prove2(prover_state, statements, witness, polynomial); + } +} + pub fn prove_execution( bytecode: &Bytecode, public_input: &[F], @@ -26,6 +113,49 @@ pub fn prove_execution( whir_config: &WhirConfigBuilder, vm_profiler: bool, ) -> Result { + prove_execution_with::<_, PoseidonExecutionBackend, _>( + bytecode, + public_input, + witness, + whir_config, + vm_profiler, + build_prover_state(), + |prover_state| prover_state.into_proof(), + ) +} + +pub fn prove_execution_sha2( + bytecode: &Bytecode, + public_input: &[F], + witness: &ExecutionWitness, + whir_config: &WhirConfigBuilder, + vm_profiler: bool, +) -> Result, ProverError> { + prove_execution_with::<_, Sha2ExecutionBackend, _>( + bytecode, + public_input, + witness, + whir_config, + vm_profiler, + build_prover_state_sha2(), + |prover_state| prover_state.into_proof(), + ) +} + +fn prove_execution_with( + bytecode: &Bytecode, + public_input: &[F], + witness: &ExecutionWitness, + whir_config: &WhirConfigBuilder, + vm_profiler: bool, + mut prover_state: P, + into_proof: IntoProof, +) -> Result, ProverError> +where + P: FSProver, + B: ExecutionProverBackend

, + IntoProof: FnOnce(P) -> Proof, +{ check_rate(whir_config.starting_log_inv_rate) .map_err(|err| panic!("{err}")) .unwrap(); @@ -46,13 +176,9 @@ pub fn prove_execution( if memory.len() < min_memory_size { memory.resize(min_memory_size, F::ZERO); } - let mut prover_state = build_prover_state(); - let mut prover_state2 = build_prover_state_sha2(); prover_state.observe_scalars(public_input); - prover_state2.observe_scalars(public_input); let bytecode_hash_with_domain_sep = poseidon16_compress_pair(&bytecode.hash, &SNARK_DOMAIN_SEP); prover_state.observe_scalars(&bytecode_hash_with_domain_sep); - prover_state2.observe_scalars(&bytecode_hash_with_domain_sep); let execution_metadata_scalars = [ vec![ whir_config.starting_log_inv_rate, @@ -66,7 +192,6 @@ pub fn prove_execution( .map(F::from_usize) .collect::>(); prover_state.add_base_scalars(&execution_metadata_scalars); - prover_state2.add_base_scalars(&execution_metadata_scalars); for (table, table_trace) in &traces { let log_n_rows = table_trace.log_n_rows; assert!(log_n_rows >= MIN_LOG_N_ROWS_PER_TABLE, "missing padding"); @@ -116,9 +241,8 @@ pub fn prove_execution( }); // 1st Commitment - let stacked_pcs_witness = stack_polynomials_and_commit( + let stacked_pcs_witness = B::stack_polynomials_and_commit( &mut prover_state, - &mut prover_state2, whir_config, &memory, &memory_acc, @@ -141,24 +265,8 @@ pub fn prove_execution( &bytecode_acc, &traces, ); - let logup_c2 = prover_state2.sample(); - let logup_alphas2 = prover_state2.sample_vec(log2_ceil_usize(max_bus_width_including_domainsep())); - let logup_alphas_eq_poly2 = eval_eq(&logup_alphas2); - - let logup_statements2 = prove_generic_logup( - &mut prover_state2, - logup_c2, - &logup_alphas_eq_poly2, - &memory, - &memory_acc, - &bytecode.instructions_multilinear, - &bytecode_acc, - &traces, - ); let gkr_point = &logup_statements.gkr_point; - let gkr_point2 = &logup_statements2.gkr_point; let mut committed_statements: CommittedStatements = Default::default(); - let mut committed_statements2: CommittedStatements = Default::default(); for table in ALL_TABLES { let log_n_rows = traces[&table].log_n_rows; committed_statements.insert( @@ -169,24 +277,12 @@ pub fn prove_execution( BTreeMap::new(), )], ); - committed_statements2.insert( - table, - vec![( - MultilinearPoint(from_end(gkr_point2, log_n_rows).to_vec()), - logup_statements2.columns_values[&table].clone(), - BTreeMap::new(), - )], - ); } let bus_beta = prover_state.sample(); let air_alpha = prover_state.sample(); let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); let air_eta: EF = prover_state.sample(); - let bus_beta2 = prover_state2.sample(); - let air_alpha2 = prover_state2.sample(); - let air_alpha_powers2: Vec = air_alpha2.powers().collect_n(max_air_constraints() + 1); - let air_eta2: EF = prover_state2.sample(); let tables_log_heights: BTreeMap = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); @@ -209,7 +305,6 @@ pub fn prove_execution( .collect(); std::mem::drop(_span); let mut sessions = Vec::with_capacity(tables_sorted.len()); - let mut sessions2 = Vec::with_capacity(tables_sorted.len()); for (idx, (table, log_n_rows)) in tables_sorted.iter().enumerate() { let bus_numerator_value = logup_statements.bus_numerators_values[table]; let bus_denominator_value = logup_statements.bus_denominators_values[table]; @@ -237,44 +332,10 @@ pub fn prove_execution( }}; } sessions.push(delegate_to_inner!(table => make_session)); - - let bus_numerator_value2 = logup_statements2.bus_numerators_values[table]; - let bus_denominator_value2 = logup_statements2.bus_denominators_values[table]; - let bus_final_value2 = bus_numerator_value2 - * match table.bus().direction { - BusDirection::Pull => EF::NEG_ONE, - BusDirection::Push => EF::ONE, - } - + bus_beta2 * (bus_denominator_value2 - logup_c2); - - let eq_suffix2 = from_end(gkr_point2, *log_n_rows).to_vec(); - - let extra_data2 = ExtraDataForBuses::new(logup_alphas_eq_poly2.clone(), bus_beta2, air_alpha_powers2.clone()); - - let mut up_down2: Vec<&[PF]> = column_refs[idx].to_vec(); - up_down2.extend(shifted_rows[idx].iter().map(Vec::as_slice)); - let packed2 = MleGroupRef::::Base(up_down2).pack(); - - macro_rules! make_session2 { - ($t:expr) => {{ - let session = AirSumcheckSession::new( - packed2, - eq_suffix2, - bus_final_value2, - *$t, - extra_data2, - non_padded, - ); - Box::new(session) as Box + '_> - }}; - } - sessions2.push(delegate_to_inner!(table => make_session2)); } let sumcheck_air_point = info_span!("batched AIR sumcheck") .in_scope(|| prove_batched_air_sumcheck(&mut prover_state, &mut sessions, air_eta)); - let sumcheck_air_point2 = info_span!("batched AIR sumcheck sha2") - .in_scope(|| prove_batched_air_sumcheck(&mut prover_state2, &mut sessions2, air_eta2)); for (idx, (table, _)) in tables_sorted.iter().enumerate() { let col_evals = sessions[idx].final_column_evals(); @@ -287,23 +348,10 @@ pub fn prove_execution( } let claim = delegate_to_inner!(table => split); committed_statements.get_mut(table).unwrap().push(claim); - - let col_evals2 = sessions2[idx].final_column_evals(); - prover_state2.add_extension_scalars(&col_evals2); - - let natural_ordering_point2 = - natural_ordering_point_for_session(&sumcheck_air_point2.0, traces[table].log_n_rows); - macro_rules! split2 { - ($t:expr) => {{ columns_evals_up_and_down($t, &col_evals2, &natural_ordering_point2) }}; - } - let claim2 = delegate_to_inner!(table => split2); - committed_statements2.get_mut(table).unwrap().push(claim2); } let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); - let public_memory_random_point2 = MultilinearPoint(prover_state2.sample_vec(log2_strict_usize(public_memory_size))); - let public_memory_eval2 = (&memory[..public_memory_size]).evaluate(&public_memory_random_point2); let previous_statements = vec![ SparseStatement::new( @@ -328,30 +376,6 @@ pub fn prove_execution( )], ), ]; - let previous_statements2 = vec![ - SparseStatement::new( - stacked_pcs_witness.stacked_n_vars, - logup_statements2.memory_and_acc_point, - vec![ - SparseValue::new(0, logup_statements2.value_memory), - SparseValue::new(1, logup_statements2.value_memory_acc), - ], - ), - SparseStatement::new( - stacked_pcs_witness.stacked_n_vars, - public_memory_random_point2, - vec![SparseValue::new(0, public_memory_eval2)], - ), - SparseStatement::new( - stacked_pcs_witness.stacked_n_vars, - logup_statements2.bytecode_and_acc_point, - vec![SparseValue::new( - (2 * memory.len()) >> bytecode.log_size(), - logup_statements2.value_bytecode_acc, - )], - ), - ]; - let global_statements_base = stacked_pcs_global_statements( stacked_pcs_witness.stacked_n_vars, log2_strict_usize(memory.len()), @@ -360,35 +384,20 @@ pub fn prove_execution( &tables_log_heights, &committed_statements, ); - let global_statements_base2 = stacked_pcs_global_statements( - stacked_pcs_witness.stacked_n_vars, - log2_strict_usize(memory.len()), - bytecode.log_size(), - previous_statements2, - &tables_log_heights, - &committed_statements2, - ); - WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()).prove( + B::prove_whir( + &WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()), &mut prover_state, global_statements_base, stacked_pcs_witness.inner_witness, &stacked_pcs_witness.global_polynomial.by_ref(), ); - WhirConfig::new(whir_config, stacked_pcs_witness.global_polynomial.by_ref().n_vars()).prove2( - &mut prover_state2, - global_statements_base2, - stacked_pcs_witness.inner_witness2, - &stacked_pcs_witness.global_polynomial.by_ref(), - ); - tracing::info!("total pow_grinding time: {} ms", pow_grinding_time().as_millis()); reset_pow_grinding_time(); Ok(ExecutionProof { - proof: prover_state.into_proof(), - proof2: prover_state2.into_proof(), + proof: into_proof(prover_state), metadata: Some(metadata), }) } diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 3a74ee447..21ef1b7c9 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -1,4 +1,8 @@ -use crate::{default_whir_config, prove_execution::prove_execution, verify_execution::verify}; +use crate::{ + default_whir_config, + prove_execution::{prove_execution, prove_execution_sha2}, + verify_execution::verify, +}; use backend::*; use lean_compiler::*; use lean_vm::*; @@ -246,10 +250,18 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { false, ) .unwrap(); + let proof2 = prove_execution_sha2( + &bytecode, + public_input, + &witness, + &default_whir_config(starting_log_inv_rate), + false, + ) + .unwrap(); let proof_time = time.elapsed(); let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); verify(&bytecode, public_input, &mut verifier_state).unwrap(); - let mut verifier_state2 = VerifierStateSha2::::new(proof.proof2).unwrap(); + let mut verifier_state2 = VerifierStateSha2::::new(proof2.proof).unwrap(); verify(&bytecode, public_input, &mut verifier_state2).unwrap(); println!("{}", proof.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 027e89bdf..6320be630 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -32,10 +32,9 @@ Stacking of various (multilinear) polynomials into a single -big- (multilinear) */ #[derive(Debug)] -pub struct StackedPcsWitness { +pub struct StackedPcsWitness { pub stacked_n_vars: VarCount, - pub inner_witness: Witness, - pub inner_witness2: Witness2, + pub inner_witness: InnerWitness, pub global_polynomial: MleOwned, } @@ -99,13 +98,48 @@ pub fn stacked_pcs_global_statements( #[instrument(skip_all)] pub fn stack_polynomials_and_commit( prover_state: &mut impl FSProver, - prover_state2: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, memory: &[F], memory_acc: &[F], bytecode_acc: &[F], traces: &BTreeMap, -) -> StackedPcsWitness { +) -> StackedPcsWitness> { + stack_polynomials_and_commit_with( + whir_config_builder, + memory, + memory_acc, + bytecode_acc, + traces, + |whir_config, global_polynomial, offset| whir_config.commit(prover_state, global_polynomial, offset), + ) +} + +pub fn stack_polynomials_and_commit_sha2( + prover_state: &mut impl FSProver, + whir_config_builder: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, +) -> StackedPcsWitness> { + stack_polynomials_and_commit_with( + whir_config_builder, + memory, + memory_acc, + bytecode_acc, + traces, + |whir_config, global_polynomial, offset| whir_config.commit2(prover_state, global_polynomial, offset), + ) +} + +fn stack_polynomials_and_commit_with( + whir_config_builder: &WhirConfigBuilder, + memory: &[F], + memory_acc: &[F], + bytecode_acc: &[F], + traces: &BTreeMap, + commit: impl FnOnce(&WhirConfig, &MleOwned, usize) -> InnerWitness, +) -> StackedPcsWitness { assert_eq!(memory.len(), memory_acc.len()); let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); let tables_heights_sorted = sort_tables_by_height(&tables_heights); @@ -150,12 +184,10 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); let whir_config = WhirConfig::new(whir_config_builder, stacked_n_vars); - let inner_witness = whir_config.commit(prover_state, &global_polynomial, offset); - let inner_witness2 = whir_config.commit2(prover_state2, &global_polynomial, offset); + let inner_witness = commit(&whir_config, &global_polynomial, offset); StackedPcsWitness { stacked_n_vars, inner_witness, - inner_witness2, global_polynomial, } } From 7c7556493d854c7d73b052ba1feb3dfe611ab85d Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:42:32 +0300 Subject: [PATCH 16/19] better tracing --- crates/lean_prover/src/test_zkvm.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 21ef1b7c9..aca563e04 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -239,9 +239,10 @@ def fibonacci_const(a, b, n: Const): fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { utils::init_tracing(); let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); - let time = std::time::Instant::now(); let starting_log_inv_rate = 1; let witness = ExecutionWitness::default(); + + let time = std::time::Instant::now(); let proof = prove_execution( &bytecode, public_input, @@ -250,6 +251,13 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { false, ) .unwrap(); + let poseidon_proof_time = time.elapsed(); + + println!("Poseidon proof"); + println!("{}", proof.metadata.as_ref().unwrap().display()); + println!("Proof time: {:.3} s", poseidon_proof_time.as_secs_f32()); + + let time = std::time::Instant::now(); let proof2 = prove_execution_sha2( &bytecode, public_input, @@ -258,11 +266,14 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { false, ) .unwrap(); - let proof_time = time.elapsed(); + let sha2_proof_time = time.elapsed(); + + println!("SHA2 proof"); + println!("{}", proof2.metadata.as_ref().unwrap().display()); + println!("Proof time: {:.3} s", sha2_proof_time.as_secs_f32()); + let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); verify(&bytecode, public_input, &mut verifier_state).unwrap(); let mut verifier_state2 = VerifierStateSha2::::new(proof2.proof).unwrap(); verify(&bytecode, public_input, &mut verifier_state2).unwrap(); - println!("{}", proof.metadata.as_ref().unwrap().display()); - println!("Proof time: {:.3} s", proof_time.as_secs_f32()); } From 9530db5a6edfebb75ed2846ddc3b6347d55a0c47 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:49:03 +0300 Subject: [PATCH 17/19] timings --- crates/whir/src/open.rs | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index d9bbd36c5..5f7eb110f 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -12,6 +12,36 @@ use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; +#[instrument( + skip_all, + fields(hash = "poseidon", round = round_index) +)] +fn merkle_data_build( + folded_matrix: DftOutput, + full_n_cols: usize, + round_index: usize, +) -> (MerkleData, [PF; DIGEST_ELEMS]) +where + EF: ExtensionField>, +{ + MerkleData::build(folded_matrix, full_n_cols, full_n_cols) +} + +#[instrument( + skip_all, + fields(hash = "sha2", round = round_index) +)] +fn merkle_data_build_sha2( + folded_matrix: DftOutput, + full_n_cols: usize, + round_index: usize, +) -> (MerkleData2, Sha256Digest) +where + EF: ExtensionField>, +{ + MerkleData2::build(folded_matrix, full_n_cols, full_n_cols) +} + impl WhirConfig where EF: ExtensionField>, @@ -111,7 +141,7 @@ where }); let full = 1 << folding_factor_next; - let (prover_data, root) = MerkleData::build(folded_matrix, full, full); + let (prover_data, root) = merkle_data_build(folded_matrix, full, round_index); prover_state.add_commitment(&root); @@ -234,7 +264,7 @@ where }); let full = 1 << folding_factor_next; - let (prover_data, root) = MerkleData2::build(folded_matrix, full, full); + let (prover_data, root) = merkle_data_build_sha2(folded_matrix, full, round_index); prover_state.add_commitment(&root); From 5e8a43c9f85a0a9b81db0bd16efce2ae64a2c365 Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 21:49:24 +0300 Subject: [PATCH 18/19] grind faster --- crates/backend/fiat-shamir/src/challenger.rs | 65 ++++++++++++++++++-- crates/backend/fiat-shamir/src/prover.rs | 10 +-- crates/backend/fiat-shamir/src/verifier.rs | 3 +- 3 files changed, 68 insertions(+), 10 deletions(-) diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index 81ea2036c..5a570e673 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -133,13 +133,23 @@ impl ChallengerSha2 { res } + pub fn pow_grinding_sample_matches(&self, bits: usize) -> bool { + assert!(bits < F::bits()); + let sample = self.sample_first_word(0, 0); + let rand_usize = sample.as_canonical_u64() as usize; + (rand_usize & ((1 << bits) - 1)) == 0 + } + + pub fn pow_grinding_witness_matches(&self, witness: F, bits: usize) -> bool { + let mut challenger = self.clone(); + challenger.observe_scalars(&[witness]); + challenger.pow_grinding_sample_matches(bits) + } + fn sample_chunk(&self, domain_sep: usize) -> [F; RATE] { let mut words = Vec::with_capacity(RATE); for block_idx in 0u64.. { - let mut hasher = self.sha2.clone(); - hasher.update((domain_sep as u64).to_le_bytes()); - hasher.update(block_idx.to_le_bytes()); - let digest = hasher.finalize(); + let digest = self.sample_digest(domain_sep, block_idx); for word in digest.chunks_exact(size_of::()) { let word = u32::from_le_bytes(word.try_into().unwrap()); words.push(F::from_int(word)); @@ -150,6 +160,19 @@ impl ChallengerSha2 { } unreachable!() } + + fn sample_first_word(&self, domain_sep: usize, block_idx: u64) -> F { + let digest = self.sample_digest(domain_sep, block_idx); + let word = u32::from_le_bytes(digest[..size_of::()].try_into().unwrap()); + F::from_int(word) + } + + fn sample_digest(&self, domain_sep: usize, block_idx: u64) -> sha2::digest::Output { + let mut hasher = self.sha2.clone(); + hasher.update((domain_sep as u64).to_le_bytes()); + hasher.update(block_idx.to_le_bytes()); + hasher.finalize() + } } impl Default for ChallengerSha2 { @@ -157,3 +180,37 @@ impl Default for ChallengerSha2 { Self::new() } } + +#[cfg(test)] +mod tests { + use field::PrimeCharacteristicRing; + use koala_bear::KoalaBear; + + use super::ChallengerSha2; + + #[test] + fn sha2_pow_grinding_direct_predicate_matches_sampling_path() { + let transcript_prefixes = [ + vec![], + vec![KoalaBear::ONE], + (0..17).map(KoalaBear::from_usize).collect::>(), + (100..141).map(KoalaBear::from_usize).collect::>(), + ]; + + for prefix in transcript_prefixes { + let mut challenger = ChallengerSha2::new(); + challenger.observe_scalars(&prefix); + + for bits in 1..=20 { + for candidate in [0, 1, 2, 3, 5, 8, 13, 21, 55, 89, 144, 233, 377, 610] { + let witness = KoalaBear::from_usize(candidate); + let mut sampling_path = challenger.clone(); + sampling_path.observe_scalars(&[witness]); + let expected = sampling_path.sample_in_range(bits, 1)[0] == 0; + let actual = challenger.pow_grinding_witness_matches(witness, bits); + assert_eq!(actual, expected); + } + } + } + } +} diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 1a2719686..2f693aec9 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -283,12 +283,14 @@ where } let time = Instant::now(); + let challenger = self.challenger.clone(); let witness = (0..PF::::ORDER_U32) - .find_map(|candidate| { + .into_par_iter() + .find_map_any(|candidate| { let witness = unsafe { PF::::from_canonical_unchecked(candidate) }; - let mut challenger = self.challenger.clone(); - challenger.observe_scalars(&[witness]); - (challenger.sample_in_range(bits, 1)[0] == 0).then_some(witness) + challenger + .pow_grinding_witness_matches(witness, bits) + .then_some(witness) }) .expect("failed to find witness"); diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index aec46f64d..3728119cf 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -348,8 +348,7 @@ where } let witness = self.read_transcript(1)?[0]; self.challenger.observe_scalars(&[witness]); - let mut challenger = self.challenger.clone(); - if challenger.sample_in_range(bits, 1)[0] != 0 { + if !self.challenger.pow_grinding_sample_matches(bits) { return Err(ProofError::InvalidGrindingWitness); } Ok(()) From 612a2828d91a67bc7e26af3a6381caf817e791bf Mon Sep 17 00:00:00 2001 From: kilic Date: Thu, 14 May 2026 22:51:48 +0300 Subject: [PATCH 19/19] port sha256 precompile --- Cargo.lock | 1 + .../lean_compiler/src/a_simplify_lang/mod.rs | 28 +- .../lean_compiler/src/instruction_encoder.rs | 1 + .../src/parser/parsers/function.rs | 5 +- crates/lean_compiler/tests/test_compiler.rs | 26 +- crates/lean_prover/src/test_zkvm.rs | 97 ++- crates/lean_prover/src/trace_gen.rs | 29 +- crates/lean_vm/Cargo.toml | 3 + crates/lean_vm/src/core/constants.rs | 5 +- crates/lean_vm/src/diagnostics/error.rs | 2 + crates/lean_vm/src/diagnostics/exec_result.rs | 7 + crates/lean_vm/src/execution/runner.rs | 1 + crates/lean_vm/src/isa/instruction.rs | 9 +- crates/lean_vm/src/tables/execution/mod.rs | 8 +- crates/lean_vm/src/tables/extension_op/mod.rs | 8 +- crates/lean_vm/src/tables/mod.rs | 7 +- crates/lean_vm/src/tables/poseidon_16/mod.rs | 12 +- .../lean_vm/src/tables/sha256_compress/air.rs | 373 ++++++++++++ .../src/tables/sha256_compress/columns.rs | 103 ++++ .../lean_vm/src/tables/sha256_compress/mod.rs | 562 ++++++++++++++++++ .../src/tables/sha256_compress/trace_gen.rs | 126 ++++ crates/lean_vm/src/tables/table_enum.rs | 19 +- crates/lean_vm/src/tables/table_trait.rs | 11 +- 23 files changed, 1403 insertions(+), 40 deletions(-) create mode 100644 crates/lean_vm/src/tables/sha256_compress/air.rs create mode 100644 crates/lean_vm/src/tables/sha256_compress/columns.rs create mode 100644 crates/lean_vm/src/tables/sha256_compress/mod.rs create mode 100644 crates/lean_vm/src/tables/sha256_compress/trace_gen.rs diff --git a/Cargo.lock b/Cargo.lock index b3b63a8ae..981ad6981 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -548,6 +548,7 @@ dependencies = [ "pest", "pest_derive", "rand", + "sha2 0.11.0", "tracing", "utils", "xmss", diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index e784d27f8..9de93f0db 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -8,7 +8,7 @@ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, - PrecompileCompTimeArgs, SourceLocation, + PrecompileCompTimeArgs, SHA256_COMPRESS_NAME, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -2308,6 +2308,32 @@ fn simplify_lines( continue; } + // Special handling for SHA256 compression precompile + if function_name == SHA256_COMPRESS_NAME { + if !targets.is_empty() { + return Err(format!( + "Precompile {function_name} should not return values, at {location}" + )); + } + if args.len() != 3 { + return Err(format!( + "Precompile {function_name} expects 3 arguments (state_ptr, block_ptr, out_ptr), got {}, at {location}", + args.len() + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Precompile(PrecompileArgs { + arg_0: simplified_args[0].clone(), + arg_1: simplified_args[1].clone(), + res: simplified_args[2].clone(), + data: PrecompileCompTimeArgs::Sha256Compress, + })); + continue; + } + // Special handling for custom hints if let Some(hint) = CustomHint::find_by_name(function_name) { if !targets.is_empty() { diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index 1060e3be4..d1658cb8d 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -59,6 +59,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val } + PrecompileCompTimeArgs::Sha256Compress => SHA256_PRECOMPILE_DATA, PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 576e5af60..beebd81a5 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode}; +use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode, SHA256_COMPRESS_NAME}; /// Reserved function names that users cannot define. pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -36,6 +36,9 @@ fn is_reserved_function_name(name: &str) -> bool { if ALL_POSEIDON16_NAMES.contains(&name) { return true; } + if name == SHA256_COMPRESS_NAME { + return true; + } if ExtensionOpMode::from_name(name).is_some() { return true; } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2c187a08e..635955b78 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use backend::BasedVectorSpace; +use backend::{BasedVectorSpace, PrimeCharacteristicRing}; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; @@ -26,6 +26,30 @@ def main(): let _ = dbg!(poseidon16_compress(public_input)); } +#[test] +fn test_sha256_compress() { + let program = r#" +def main(): + state = 0 + block = 16 + expected = 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return + "#; + + let mut public_input = vec![F::ZERO; 64]; + public_input[0..16].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[16..48].copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[48..64].copy_from_slice(&expected); + + compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); +} + #[test] fn test_div_extension_field() { let program = r#" diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index aca563e04..46bb04b86 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -9,6 +9,49 @@ use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use utils::{get_poseidon16, init_tracing, poseidon16_compress}; +#[test] +#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_sha256_compress -- --ignored --nocapture`"] +fn bench_sha256_compress() { + utils::init_tracing(); + let n_sha_calls = std::env::var("SHA256_BENCH_CALLS") + .ok() + .map(|raw| raw.parse::().expect("SHA256_BENCH_CALLS must be a usize")) + .unwrap_or(1); + const SHA_FIXTURE_STRIDE: usize = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS + SHA256_STATE_LIMBS; + let program_str = format!( + r#" +N_SHA_CALLS = {n_sha_calls} +SHA_FIXTURE_STRIDE = 64 + +def main(): + for j in unroll(0, N_SHA_CALLS): + base = j * SHA_FIXTURE_STRIDE + state = base + block = base + 16 + expected = base + 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return +"# + ); + + let mut public_input = vec![F::ZERO; n_sha_calls * SHA_FIXTURE_STRIDE]; + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + for j in 0..n_sha_calls { + let base = j * SHA_FIXTURE_STRIDE; + public_input[base..base + SHA256_STATE_LIMBS].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[base + SHA256_STATE_LIMBS..base + SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + public_input[base + SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS..base + SHA_FIXTURE_STRIDE] + .copy_from_slice(&expected); + } + + test_zk_vm_helper(&program_str, &public_input); +} + #[test] fn test_zk_vm_all_precompiles() { let program_str = r#" @@ -22,6 +65,15 @@ def main(): pub_start = 0 poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) + # Keep the SHA fixture away from the extension-op fixture ranges below. + sha_state = pub_start + 1400 + sha_block = sha_state + 16 + sha_expected = sha_block + 32 + sha_out = Array(16) + sha256_compress(sha_state, sha_block, sha_out) + for i in unroll(0, 16): + assert sha_out[i] == sha_expected[i] + # poseidon16_compress_half: only first 4 FE constrained full_out = pub_start + 6 * DIGEST_LEN half_out = pub_start + 80 @@ -107,6 +159,17 @@ def main(): F::from_usize(444), ]); + // SHA256 compression test data: IV + padded "abc" block. + let sha_state_ptr = 1400; + let sha_block_ptr = sha_state_ptr + SHA256_STATE_LIMBS; + let sha_expected_ptr = sha_block_ptr + SHA256_BLOCK_LIMBS; + public_input[sha_state_ptr..sha_state_ptr + SHA256_STATE_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[sha_block_ptr..sha_block_ptr + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let sha_expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[sha_expected_ptr..sha_expected_ptr + SHA256_STATE_LIMBS].copy_from_slice(&sha_expected); + let hardcoded_data: [F; 4] = rng.random(); let hardcoded_prefix: [F; 4] = rng.random(); public_input[1496..1500].copy_from_slice(&hardcoded_data); @@ -239,12 +302,30 @@ def fibonacci_const(a, b, n: Const): fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { utils::init_tracing(); let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + + test_zk_vm_bytecode_helper_poseidon(&bytecode, public_input); + test_zk_vm_bytecode_helper_sha2(&bytecode, public_input); +} + +fn test_zk_vm_helper_poseidon(program_str: &str, public_input: &[F]) { + utils::init_tracing(); + let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + test_zk_vm_bytecode_helper_poseidon(&bytecode, public_input); +} + +fn test_zk_vm_helper_sha2(program_str: &str, public_input: &[F]) { + utils::init_tracing(); + let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + test_zk_vm_bytecode_helper_sha2(&bytecode, public_input); +} + +fn test_zk_vm_bytecode_helper_poseidon(bytecode: &Bytecode, public_input: &[F]) { let starting_log_inv_rate = 1; let witness = ExecutionWitness::default(); let time = std::time::Instant::now(); let proof = prove_execution( - &bytecode, + bytecode, public_input, &witness, &default_whir_config(starting_log_inv_rate), @@ -257,9 +338,17 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { println!("{}", proof.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", poseidon_proof_time.as_secs_f32()); + let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); + verify(bytecode, public_input, &mut verifier_state).unwrap(); +} + +fn test_zk_vm_bytecode_helper_sha2(bytecode: &Bytecode, public_input: &[F]) { + let starting_log_inv_rate = 1; + let witness = ExecutionWitness::default(); + let time = std::time::Instant::now(); let proof2 = prove_execution_sha2( - &bytecode, + bytecode, public_input, &witness, &default_whir_config(starting_log_inv_rate), @@ -272,8 +361,6 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { println!("{}", proof2.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", sha2_proof_time.as_secs_f32()); - let mut verifier_state = VerifierState::::new(proof.proof, get_poseidon16().clone()).unwrap(); - verify(&bytecode, public_input, &mut verifier_state).unwrap(); let mut verifier_state2 = VerifierStateSha2::::new(proof2.proof).unwrap(); - verify(&bytecode, public_input, &mut verifier_state2).unwrap(); + verify(bytecode, public_input, &mut verifier_state2).unwrap(); } diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 1801a5b62..5d0aa9346 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -99,6 +99,24 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let null_poseidon_16_hash_ptr = memory_padded.len(); memory_padded.extend_from_slice(get_poseidon_16_of_zero()); + let sha256_padding_state_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_IV)); + let sha256_padding_block_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_ZERO_BLOCK)); + let sha256_padding_out_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(sha256_compress_words( + SHA256_IV, + SHA256_ZERO_BLOCK, + ))); + + let padding_memory = PaddingMemory { + zero_vec_ptr: padding_zero_vec_ptr, + null_poseidon_16_hash_ptr, + sha256_state_ptr: sha256_padding_state_ptr, + sha256_block_ptr: sha256_padding_block_ptr, + sha256_out_ptr: sha256_padding_out_ptr, + }; + // IMPORTANT: memory size should always be >= number of VM cycles let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); memory_padded.resize(padded_memory_len, F::ZERO); @@ -142,7 +160,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul }, ); for table in traces.keys().copied().collect::>() { - pad_table(&table, &mut traces, padding_zero_vec_ptr, null_poseidon_16_hash_ptr); + pad_table(&table, &mut traces, &padding_memory); } ExecutionTrace { @@ -153,12 +171,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul } } -fn pad_table( - table: &Table, - traces: &mut BTreeMap, - zero_vec_ptr: usize, - null_poseidon_16_hash_ptr: usize, -) { +fn pad_table(table: &Table, traces: &mut BTreeMap, padding_memory: &PaddingMemory) { let trace = traces.get_mut(table).unwrap(); let h = trace.columns[0].len(); trace @@ -170,7 +183,7 @@ fn pad_table( trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE); let n_rows = 1 << trace.log_n_rows; - let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr); + let padding_row = table.padding_row(padding_memory); trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 32e50f199..a138feb5d 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -15,3 +15,6 @@ rand.workspace = true tracing.workspace = true backend.workspace = true itertools.workspace = true + +[dev-dependencies] +sha2.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index afe6bc2d8..088054ad2 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -21,10 +21,13 @@ pub const MIN_BYTECODE_LOG_SIZE: usize = 8; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. -pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ +pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 4] = [ (Table::execution(), 24), (Table::extension_op(), 21), (Table::poseidon16(), 21), + // Direct Plonky3-style SHA256 has 7524 columns. 2^13 rows already exceeds + // the current commitment-surface guard; 2^12 is the largest safe cap today. + (Table::sha256_compress(), 12), ]; pub fn max_log_n_rows_per_table(table: &Table) -> usize { diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 16492d708..48cc51fed 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -23,6 +23,7 @@ pub enum RunnerError { range: usize, }, InvalidExtensionOp, + InvalidSha256Input, ParallelSegmentFailed(usize, Box), } @@ -55,6 +56,7 @@ impl Display for RunnerError { ) } Self::InvalidExtensionOp => write!(f, "invalid extension op"), + Self::InvalidSha256Input => write!(f, "invalid sha256 input"), Self::ParallelSegmentFailed(id, err) => { write!(f, "parallel segment {id} failed: {err}") } diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0cd..df9505c22 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -10,6 +10,7 @@ pub struct ExecutionMetadata { pub cycles: usize, pub memory: usize, pub n_poseidons: usize, + pub n_sha256_compress: usize, pub n_extension_ops: usize, pub bytecode_size: usize, pub public_input_size: usize, @@ -57,6 +58,12 @@ impl ExecutionMetadata { self.cycles / self.n_poseidons )); } + if self.n_sha256_compress > 0 { + out.push_str(&format!( + "SHA256Compress calls: {}\n", + pretty_integer(self.n_sha256_compress) + )); + } if self.n_extension_ops > 0 { out.push_str(&format!( "ExtensionOp calls: {}\n", diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 5d244f62f..2c8e35cd2 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -331,6 +331,7 @@ fn execute_bytecode_helper( cycles: trace.pcs.len(), memory: memory.0.len(), n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), + n_sha256_compress: trace.tables[&Table::sha256_compress()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.code.len(), public_input_size: public_input.len(), diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..c05afd287 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -2,12 +2,11 @@ use super::Operation; use super::operands::{MemOrConstant, MemOrFpOrConstant}; -use crate::POSEIDON16_NAME; use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; use crate::tables::TableT; -use crate::{ExtensionOpMode, Table, TableTrace}; +use crate::{ExtensionOpMode, POSEIDON16_NAME, SHA256_COMPRESS_NAME, Table, TableTrace}; use backend::*; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; @@ -69,6 +68,7 @@ pub enum PrecompileCompTimeArgs { // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | m[arg_a..arg_a+4] (arg_a is the first runtime parameter) hardcoded_offset_left: Option, }, + Sha256Compress, ExtensionOp { size: S, mode: ExtensionOpMode, @@ -79,6 +79,7 @@ impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { Self::Poseidon16 { .. } => Table::poseidon16(), + Self::Sha256Compress => Table::sha256_compress(), Self::ExtensionOp { .. } => Table::extension_op(), } } @@ -92,6 +93,7 @@ impl PrecompileCompTimeArgs { half_output, hardcoded_offset_left: hardcoded_left_4.map(&mut f), }, + Self::Sha256Compress => PrecompileCompTimeArgs::Sha256Compress, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -262,6 +264,9 @@ impl Display for PrecompileArgs { "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" ), }, + PrecompileCompTimeArgs::Sha256Compress => { + write!(f, "{SHA256_COMPRESS_NAME}({arg_0}, {arg_1}, {res})") + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 10b854c04..0f13e08fe 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -56,7 +56,7 @@ impl TableT for ExecutionTable { } } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut padding_row = vec![F::ZERO; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; padding_row[COL_PC] = F::from_usize(ENDING_PC); padding_row[COL_JUMP] = F::ONE; @@ -65,9 +65,9 @@ impl TableT for ExecutionTable { padding_row[COL_FLAG_B] = F::ONE; padding_row[COL_FLAG_C_FP] = F::ONE; // this is kind of arbitrary padding_row[COL_EXEC_NU_A] = F::ONE; // because at the end of program, we always jump (looping at pc=0, so condition = nu_a = 1) - padding_row[COL_MEM_ADDRESS_A] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_B] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_C] = F::from_usize(zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_A] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_B] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_C] = F::from_usize(padding.zero_vec_ptr); padding_row } diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index 03cc0045c..c7cb4583c 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -122,14 +122,14 @@ impl TableT for ExtensionOpPrecompile { self.n_columns() + 2 // +2 for COL_ACTIVATION_FLAG and COL_AUX_EXTENSION_OP (non-AIR, used in bus logup) } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; self.n_columns_total()]; row[COL_START] = F::ONE; row[COL_LEN] = F::ONE; row[COL_AUX_EXTENSION_OP] = F::from_usize(EXT_OP_LEN_MULTIPLIER); - row[COL_IDX_A] = F::from_usize(zero_vec_ptr); - row[COL_IDX_B] = F::from_usize(zero_vec_ptr); - row[COL_IDX_RES] = F::from_usize(zero_vec_ptr); + row[COL_IDX_A] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_B] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_RES] = F::from_usize(padding.zero_vec_ptr); row } diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index bf9523291..c90f1dd41 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -4,6 +4,9 @@ pub use extension_op::*; mod poseidon_16; pub use poseidon_16::*; +pub mod sha256_compress; +pub use sha256_compress::*; + mod table_enum; pub use table_enum::*; @@ -16,10 +19,10 @@ pub use execution::*; mod utils; pub(crate) use utils::*; -// `PRECOMPILE_DATA` is the bus discriminator separating the two precompile -// tables. Disjointness is by parity of bit 0: +// `PRECOMPILE_DATA` is the bus discriminator separating precompile tables: // // Poseidon16 (odd): 1 + 2·flag_half + 4·flag_left + 8·flag_left·offset_left +// Sha256 (even): 2 // ExtensionOp (even): 4·is_be + 8·flag_add + 16·flag_mul + 32·flag_poly_eq + 64·len // // Multiplying `offset_left` by `flag_left` is needed for soundness: see 3.4.1 in minimal_zkVM.pdf diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 9d0994d12..d0964b823 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -174,7 +174,7 @@ impl TableT for Poseidon16Precompile { } } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; num_cols_total_poseidon_16()]; let ptrs: Vec<*mut F> = (0..num_cols_poseidon_16()) .map(|i| unsafe { row.as_mut_ptr().add(i) }) @@ -183,15 +183,15 @@ impl TableT for Poseidon16Precompile { let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); *perm.flag_active = F::ZERO; - *perm.index_b = F::from_usize(zero_vec_ptr); - *perm.index_res = F::from_usize(null_hash_ptr); + *perm.index_b = F::from_usize(padding.zero_vec_ptr); + *perm.index_res = F::from_usize(padding.null_poseidon_16_hash_ptr); *perm.flag_half_output = F::ZERO; *perm.flag_hardcoded_left = F::ZERO; *perm.offset_hardcoded_left = F::ZERO; - *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); - *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); + *perm.effective_index_left_first = F::from_usize(padding.zero_vec_ptr); + *perm.effective_index_left_second = F::from_usize(padding.zero_vec_ptr + HALF_DIGEST_LEN); // Non-committed columns - row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); + row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(padding.zero_vec_ptr); row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); generate_trace_rows_for_perm(perm); diff --git a/crates/lean_vm/src/tables/sha256_compress/air.rs b/crates/lean_vm/src/tables/sha256_compress/air.rs new file mode 100644 index 000000000..cbab2af8e --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/air.rs @@ -0,0 +1,373 @@ +use backend::*; + +use super::{ + NUM_SHA256_COMPRESS_COLS, SHA256_BLOCK_WORDS, SHA256_CHAIN_LEN, SHA256_K, SHA256_PRECOMPILE_DATA, + SHA256_SCHEDULE_EXTENSIONS, SHA256_U32_LIMBS, SHA256_WORD_BITS, Sha256Cols, Sha256CompressCols, + Sha256CompressPrecompile, +}; +use crate::{EF, ExtraDataForBuses, eval_virtual_bus_column}; + +const BITS_PER_LIMB: usize = 16; + +impl Air for Sha256CompressPrecompile { + type ExtraData = ExtraDataForBuses; + + fn n_columns(&self) -> usize { + NUM_SHA256_COMPRESS_COLS + } + + fn degree_air(&self) -> usize { + 3 + } + + fn n_constraints(&self) -> usize { + 7840 + 32 + 1 + BUS as usize + } + + fn down_column_indexes(&self) -> Vec { + vec![] + } + + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let cols: &Sha256CompressCols = { + let up = builder.up(); + let (prefix, shorts, suffix) = unsafe { up.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + unsafe { &*shorts.as_ptr() } + }; + + if BUS { + builder.assert_zero_ef(eval_virtual_bus_column::( + extra_data, + cols.flag, + &[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ], + )); + } else { + builder.declare_values(std::slice::from_ref(&cols.flag)); + builder.declare_values(&[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ]); + } + + builder.assert_bool(cols.flag); + eval_sha256_air(builder, &cols.sha); + eval_block_limb_bridges(builder, &cols); + } +} + +fn eval_sha256_air(builder: &mut AB, local: &Sha256Cols) { + eval_bit_range_checks(builder, local); + eval_initial_state(builder, local); + eval_message_schedule(builder, local); + eval_compression(builder, local); + eval_finalization(builder, local); +} + +fn eval_bit_range_checks(builder: &mut AB, local: &Sha256Cols) { + for word in &local.w { + assert_bools(builder, word); + } + for word in &local.a_chain { + assert_bools(builder, word); + } + for word in &local.e_chain { + assert_bools(builder, word); + } +} + +fn eval_initial_state(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[i], &local.a_chain[chain_idx]); + } + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[4 + i], &local.e_chain[chain_idx]); + } +} + +fn eval_block_limb_bridges(builder: &mut AB, cols: &Sha256CompressCols) { + for i in 0..SHA256_BLOCK_WORDS { + assert_packed_equals_bits(builder, &cols.block_limbs[i], &cols.sha.w[i]); + } +} + +fn eval_message_schedule(builder: &mut AB, local: &Sha256Cols) { + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + let t = i + SHA256_BLOCK_WORDS; + + assert_sigma_matches( + builder, + &local.w[t - 15], + SigmaSpec::SmallSigma0, + &local.sched_sigma0[i], + ); + assert_sigma_matches(builder, &local.w[t - 2], SigmaSpec::SmallSigma1, &local.sched_sigma1[i]); + + let w_tm7_packed = pack_word::(&local.w[t - 7]); + add2(builder, &local.sched_tmp[i], &local.sched_sigma1[i], &w_tm7_packed); + + let w_t_packed = pack_word::(&local.w[t]); + let sched_sigma0 = local.sched_sigma0[i]; + let w_tm16_packed = pack_word::(&local.w[t - 16]); + add3_expr_out(builder, &w_t_packed, &local.sched_tmp[i], &sched_sigma0, &w_tm16_packed); + } +} + +fn eval_compression(builder: &mut AB, local: &Sha256Cols) { + for (t, round) in local.rounds.iter().enumerate() { + let a_bits = &local.a_chain[t + 3]; + let b_bits = &local.a_chain[t + 2]; + let c_bits = &local.a_chain[t + 1]; + let d_bits = &local.a_chain[t]; + let e_bits = &local.e_chain[t + 3]; + let f_bits = &local.e_chain[t + 2]; + let g_bits = &local.e_chain[t + 1]; + let h_bits = &local.e_chain[t]; + + assert_sigma_matches(builder, e_bits, SigmaSpec::BigSigma1, &round.sigma1_e); + assert_ch_matches(builder, e_bits, f_bits, g_bits, &round.ch); + + let h_packed = pack_word::(h_bits); + add3(builder, &round.tmp1, &round.sigma1_e, &round.ch, &h_packed); + + let k = [ + AB::IF::from_u32(SHA256_K[t] & 0xffff), + AB::IF::from_u32(SHA256_K[t] >> BITS_PER_LIMB), + ]; + let w_packed = pack_word::(&local.w[t]); + add3(builder, &round.t1, &round.tmp1, &k, &w_packed); + + assert_sigma_matches(builder, a_bits, SigmaSpec::BigSigma0, &round.sigma0_a); + assert_maj_matches(builder, a_bits, b_bits, c_bits, &round.maj); + + let new_a_packed = pack_word::(&local.a_chain[t + 4]); + add3_expr_out(builder, &new_a_packed, &round.t1, &round.sigma0_a, &round.maj); + + let new_e_packed = pack_word::(&local.e_chain[t + 4]); + let d_packed = pack_word::(d_bits); + add2_expr_out(builder, &new_e_packed, &round.t1, &d_packed); + } +} + +fn eval_finalization(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let final_bits = &local.a_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[i], &local.h_in[i], &packed); + } + for i in 0..4 { + let final_bits = &local.e_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[4 + i], &local.h_in[4 + i], &packed); + } +} + +#[inline] +fn assert_bools(builder: &mut AB, bits: &[AB::IF; SHA256_WORD_BITS]) { + for &bit in bits { + builder.assert_bool(bit); + } +} + +#[inline] +fn pack_word(bits: &[AB::IF; SHA256_WORD_BITS]) -> [AB::IF; SHA256_U32_LIMBS] { + [ + pack_bits_le::(&bits[..BITS_PER_LIMB]), + pack_bits_le::(&bits[BITS_PER_LIMB..]), + ] +} + +#[inline] +fn pack_bits_le(bits: &[AB::IF]) -> AB::IF { + let mut acc = AB::IF::ZERO; + for &bit in bits.iter().rev() { + acc = acc.double() + bit; + } + acc +} + +#[inline] +fn assert_packed_equals_bits( + builder: &mut AB, + packed: &[AB::IF; SHA256_U32_LIMBS], + bits: &[AB::IF; SHA256_WORD_BITS], +) { + let built = pack_word::(bits); + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +#[inline] +fn add2( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + add2_expr_out(builder, a, b, c); +} + +#[inline] +fn add3( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + add3_expr_out(builder, a, b, c, d); +} + +#[inline] +fn add2_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0]; + let acc_32 = a[1] - b[1] - c[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32)); + builder.assert_zero(acc_16 * (acc_16 + two_16)); +} + +#[inline] +fn add3_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0] - d[0]; + let acc_32 = a[1] - b[1] - c[1] - d[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32) * (acc + two_32.double())); + builder.assert_zero(acc_16 * (acc_16 + two_16) * (acc_16 + two_16.double())); +} + +#[derive(Copy, Clone)] +enum SigmaSpec { + BigSigma0, + BigSigma1, + SmallSigma0, + SmallSigma1, +} + +#[derive(Copy, Clone)] +enum ShiftKind { + Rotate, + Logical, +} + +#[inline] +const fn sigma_params(spec: SigmaSpec) -> (usize, usize, usize, ShiftKind) { + match spec { + SigmaSpec::BigSigma0 => (2, 13, 22, ShiftKind::Rotate), + SigmaSpec::BigSigma1 => (6, 11, 25, ShiftKind::Rotate), + SigmaSpec::SmallSigma0 => (7, 18, 3, ShiftKind::Logical), + SigmaSpec::SmallSigma1 => (17, 19, 10, ShiftKind::Logical), + } +} + +fn assert_sigma_matches( + builder: &mut AB, + bits: &[AB::IF; SHA256_WORD_BITS], + spec: SigmaSpec, + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let (r1, r2, r3, kind) = sigma_params(spec); + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let b1 = bits[(i + r1) % SHA256_WORD_BITS]; + let b2 = bits[(i + r2) % SHA256_WORD_BITS]; + let b3 = match kind { + ShiftKind::Rotate => bits[(i + r3) % SHA256_WORD_BITS], + ShiftKind::Logical => { + let src = i + r3; + if src < SHA256_WORD_BITS { + bits[src] + } else { + AB::IF::ZERO + } + } + }; + acc = acc.double() + b1.xor3(&b2, &b3); + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_ch_matches( + builder: &mut AB, + e: &[AB::IF; SHA256_WORD_BITS], + f: &[AB::IF; SHA256_WORD_BITS], + g: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let ei = e[i]; + let ch_i = ei * f[i] + (AB::IF::ONE - ei) * g[i]; + acc = acc.double() + ch_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_maj_matches( + builder: &mut AB, + a: &[AB::IF; SHA256_WORD_BITS], + b: &[AB::IF; SHA256_WORD_BITS], + c: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let maj_i = a[i] * b[i] + c[i] * a[i].xor(&b[i]); + acc = acc.double() + maj_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} diff --git a/crates/lean_vm/src/tables/sha256_compress/columns.rs b/crates/lean_vm/src/tables/sha256_compress/columns.rs new file mode 100644 index 000000000..60780fce0 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/columns.rs @@ -0,0 +1,103 @@ +use core::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use super::{ + SHA256_BLOCK_LIMBS, SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, + SHA256_U32_LIMBS, SHA256_WORD_BITS, +}; + +pub const SHA256_CHAIN_LEN: usize = 4 + SHA256_COMPRESS_ROUNDS; + +pub const SHA256_COL_FLAG: usize = 0; +pub const SHA256_COL_STATE_PTR: usize = 1; +pub const SHA256_COL_BLOCK_PTR: usize = 2; +pub const SHA256_COL_OUT_PTR: usize = 3; +pub const SHA256_COL_BLOCK_LIMBS_START: usize = 4; +pub const SHA256_COL_AIR_START: usize = SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS; +pub const SHA256_COL_STATE_LIMBS_START: usize = SHA256_COL_AIR_START; +pub const SHA256_COL_OUT_LIMBS_START: usize = NUM_SHA256_COMPRESS_COLS - SHA256_STATE_WORDS * SHA256_U32_LIMBS; + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256RoundCols { + pub sigma1_e: [T; SHA256_U32_LIMBS], + pub ch: [T; SHA256_U32_LIMBS], + pub tmp1: [T; SHA256_U32_LIMBS], + pub t1: [T; SHA256_U32_LIMBS], + pub sigma0_a: [T; SHA256_U32_LIMBS], + pub maj: [T; SHA256_U32_LIMBS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256Cols { + pub h_in: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], + pub a_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub e_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub w: [[T; SHA256_WORD_BITS]; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub rounds: [Sha256RoundCols; SHA256_COMPRESS_ROUNDS], + pub h_out: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256CompressCols { + pub flag: T, + pub state_ptr: T, + pub block_ptr: T, + pub out_ptr: T, + pub block_limbs: [[T; SHA256_U32_LIMBS]; SHA256_BLOCK_WORDS], + pub sha: Sha256Cols, +} + +pub const NUM_SHA256_AIR_COLS: usize = size_of::>(); +pub const NUM_SHA256_COMPRESS_COLS: usize = size_of::>(); + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/mod.rs b/crates/lean_vm/src/tables/sha256_compress/mod.rs new file mode 100644 index 000000000..731221715 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/mod.rs @@ -0,0 +1,562 @@ +use crate::{F, PrecompileCompTimeArgs, RunnerError, Table}; +use backend::{PrimeCharacteristicRing, PrimeField32}; +use utils::ToUsize; + +mod air; + +mod columns; +pub use columns::*; + +mod trace_gen; +pub use trace_gen::*; + +pub const SHA256_STATE_WORDS: usize = 8; +pub const SHA256_BLOCK_WORDS: usize = 16; +pub const SHA256_INPUT_WORDS: usize = SHA256_BLOCK_WORDS + SHA256_STATE_WORDS; +pub const SHA256_WORD_BITS: usize = 32; +pub const SHA256_U32_LIMBS: usize = 2; +pub const SHA256_COMPRESS_ROUNDS: usize = 64; +pub const SHA256_SCHEDULE_EXTENSIONS: usize = SHA256_COMPRESS_ROUNDS - SHA256_BLOCK_WORDS; + +pub const SHA256_STATE_LIMBS: usize = SHA256_STATE_WORDS * SHA256_U32_LIMBS; +pub const SHA256_BLOCK_LIMBS: usize = SHA256_BLOCK_WORDS * SHA256_U32_LIMBS; + +pub const SHA256_IV: [u32; SHA256_STATE_WORDS] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub const SHA256_ABC_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0x61626380, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x18]; + +pub const SHA256_ZERO_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0; SHA256_BLOCK_WORDS]; + +pub const SHA256_PRECOMPILE_DATA: usize = 2; +pub const SHA256_COMPRESS_NAME: &str = "sha256_compress"; + +const SHA256_K: [u32; SHA256_COMPRESS_ROUNDS] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, + 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, + 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, + 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, + 0xc67178f2, +]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Sha256RoundWitness { + pub sigma1_e: u32, + pub ch: u32, + pub tmp1: u32, + pub t1: u32, + pub sigma0_a: u32, + pub maj: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Sha256CompressionWitness { + pub h_in: [u32; SHA256_STATE_WORDS], + pub block: [u32; SHA256_BLOCK_WORDS], + pub w: [u32; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub a_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub e_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub rounds: [Sha256RoundWitness; SHA256_COMPRESS_ROUNDS], + pub h_out: [u32; SHA256_STATE_WORDS], +} + +pub fn generate_sha256_compression_witness( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Sha256CompressionWitness { + let mut w = [0u32; SHA256_COMPRESS_ROUNDS]; + w[..SHA256_BLOCK_WORDS].copy_from_slice(&block); + + let mut sched_sigma0 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_sigma1 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_tmp = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + + for t in SHA256_BLOCK_WORDS..SHA256_COMPRESS_ROUNDS { + let i = t - SHA256_BLOCK_WORDS; + let s0 = small_sigma0(w[t - 15]); + let s1 = small_sigma1(w[t - 2]); + let tmp = s1.wrapping_add(w[t - 7]); + w[t] = tmp.wrapping_add(s0).wrapping_add(w[t - 16]); + sched_sigma0[i] = s0; + sched_sigma1[i] = s1; + sched_tmp[i] = tmp; + } + + let mut a_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + let mut e_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + a_chain[0] = h_in[3]; + a_chain[1] = h_in[2]; + a_chain[2] = h_in[1]; + a_chain[3] = h_in[0]; + e_chain[0] = h_in[7]; + e_chain[1] = h_in[6]; + e_chain[2] = h_in[5]; + e_chain[3] = h_in[4]; + + let empty_round = Sha256RoundWitness { + sigma1_e: 0, + ch: 0, + tmp1: 0, + t1: 0, + sigma0_a: 0, + maj: 0, + }; + let mut rounds = [empty_round; SHA256_COMPRESS_ROUNDS]; + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = h_in; + + for t in 0..SHA256_COMPRESS_ROUNDS { + let sigma1_e = big_sigma1(e); + let ch = ch(e, f, g); + let tmp1 = h.wrapping_add(sigma1_e).wrapping_add(ch); + let t1 = tmp1.wrapping_add(SHA256_K[t]).wrapping_add(w[t]); + let sigma0_a = big_sigma0(a); + let maj = maj(a, b, c); + let new_a = t1.wrapping_add(sigma0_a).wrapping_add(maj); + let new_e = d.wrapping_add(t1); + + rounds[t] = Sha256RoundWitness { + sigma1_e, + ch, + tmp1, + t1, + sigma0_a, + maj, + }; + a_chain[t + 4] = new_a; + e_chain[t + 4] = new_e; + + h = g; + g = f; + f = e; + e = new_e; + d = c; + c = b; + b = a; + a = new_a; + } + + let final_state = [a, b, c, d, e, f, g, h]; + let h_out = core::array::from_fn(|i| h_in[i].wrapping_add(final_state[i])); + + Sha256CompressionWitness { + h_in, + block, + w, + sched_sigma0, + sched_sigma1, + sched_tmp, + a_chain, + e_chain, + rounds, + h_out, + } +} + +pub fn sha256_compress_words( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> [u32; SHA256_STATE_WORDS] { + generate_sha256_compression_witness(h_in, block).h_out +} + +pub const fn u32_to_u16_limbs_le(word: u32) -> [u16; SHA256_U32_LIMBS] { + [(word & 0xffff) as u16, (word >> 16) as u16] +} + +pub const fn u16_limbs_le_to_u32(limbs: [u16; SHA256_U32_LIMBS]) -> u32 { + limbs[0] as u32 | ((limbs[1] as u32) << 16) +} + +pub fn words_to_u16_limbs_le(words: impl IntoIterator) -> Vec { + let mut limbs = Vec::new(); + for word in words { + let word_limbs = u32_to_u16_limbs_le(word); + limbs.extend_from_slice(&word_limbs); + } + limbs +} + +pub fn words_to_field_limbs_le(words: [u32; N]) -> Vec { + words_to_u16_limbs_le(words) + .into_iter() + .map(|limb| F::from_usize(usize::from(limb))) + .collect() +} + +pub fn u32_to_bits_le(word: u32) -> [bool; SHA256_WORD_BITS] { + core::array::from_fn(|i| ((word >> i) & 1) == 1) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Sha256CompressPrecompile; + +impl crate::TableT for Sha256CompressPrecompile { + fn name(&self) -> &'static str { + SHA256_COMPRESS_NAME + } + + fn table(&self) -> Table { + Table::sha256_compress() + } + + fn lookups(&self) -> Vec { + vec![ + crate::LookupIntoMemory { + index: SHA256_COL_STATE_PTR, + values: (SHA256_COL_STATE_LIMBS_START..SHA256_COL_STATE_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_BLOCK_PTR, + values: (SHA256_COL_BLOCK_LIMBS_START..SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_OUT_PTR, + values: (SHA256_COL_OUT_LIMBS_START..SHA256_COL_OUT_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + ] + } + + fn bus(&self) -> crate::Bus { + crate::Bus { + direction: crate::BusDirection::Pull, + selector: SHA256_COL_FLAG, + data: vec![ + crate::BusData::Constant(SHA256_PRECOMPILE_DATA), + crate::BusData::Column(SHA256_COL_STATE_PTR), + crate::BusData::Column(SHA256_COL_BLOCK_PTR), + crate::BusData::Column(SHA256_COL_OUT_PTR), + ], + } + } + + fn padding_row(&self, padding: &crate::PaddingMemory) -> Vec { + sha256_compress_trace_row( + F::ZERO, + F::from_usize(padding.sha256_state_ptr), + F::from_usize(padding.sha256_block_ptr), + F::from_usize(padding.sha256_out_ptr), + SHA256_IV, + SHA256_ZERO_BLOCK, + ) + } + + fn execute( + &self, + arg_a: F, + arg_b: F, + arg_c: F, + args: PrecompileCompTimeArgs, + ctx: &mut crate::InstructionContext<'_, M>, + ) -> Result<(), RunnerError> { + let PrecompileCompTimeArgs::Sha256Compress = args else { + unreachable!("Sha256Compress table called with non-Sha256Compress args"); + }; + + let state_ptr = arg_a.to_usize(); + let block_ptr = arg_b.to_usize(); + let out_ptr = arg_c.to_usize(); + + let h_in = field_limbs_to_words::(&ctx.memory.get_slice(state_ptr, SHA256_STATE_LIMBS)?)?; + let block = field_limbs_to_words::(&ctx.memory.get_slice(block_ptr, SHA256_BLOCK_LIMBS)?)?; + let witness = generate_sha256_compression_witness(h_in, block); + ctx.memory.set_slice(out_ptr, &words_to_field_limbs_le(witness.h_out))?; + + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + push_sha256_compress_trace_row_from_witness(trace, F::ONE, arg_a, arg_b, arg_c, &witness); + + Ok(()) + } +} + +fn field_limbs_to_words(limbs: &[F]) -> Result<[u32; N], RunnerError> { + assert_eq!(limbs.len(), N * SHA256_U32_LIMBS); + let mut words = [0u32; N]; + for (word, limb_pair) in words.iter_mut().zip(limbs.chunks_exact(SHA256_U32_LIMBS)) { + let lo = limb_to_u16(limb_pair[0])?; + let hi = limb_to_u16(limb_pair[1])?; + *word = u16_limbs_le_to_u32([lo, hi]); + } + Ok(words) +} + +fn limb_to_u16(limb: F) -> Result { + let value = limb.as_canonical_u32(); + u16::try_from(value).map_err(|_| RunnerError::InvalidSha256Input) +} + +#[inline] +const fn small_sigma0(x: u32) -> u32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) +} + +#[inline] +const fn small_sigma1(x: u32) -> u32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) +} + +#[inline] +const fn big_sigma0(x: u32) -> u32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +#[inline] +const fn big_sigma1(x: u32) -> u32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} + +#[inline] +const fn ch(e: u32, f: u32, g: u32) -> u32 { + (e & f) ^ ((!e) & g) +} + +#[inline] +const fn maj(a: u32, b: u32, c: u32) -> u32 { + (a & b) ^ (a & c) ^ (b & c) +} + +#[cfg(test)] +mod tests { + use super::*; + use backend::{ + Air, PrimeCharacteristicRing, PrimeField32, SumcheckComputation, get_symbolic_constraints_and_bus_data_values, + }; + use core::borrow::Borrow; + use std::collections::BTreeMap; + + use crate::{ + EF, ExtraDataForBuses, InstructionContext, InstructionCounts, Memory, MemoryAccess, TableT, TableTrace, + }; + + fn words_to_hex(words: [u32; SHA256_STATE_WORDS]) -> String { + words.iter().map(|word| format!("{word:08x}")).collect() + } + + fn extract_packed_words(limbs: &[[F; SHA256_U32_LIMBS]; SHA256_STATE_WORDS]) -> [u32; SHA256_STATE_WORDS] { + core::array::from_fn(|i| { + let lo = limbs[i][0].as_canonical_u32(); + let hi = limbs[i][1].as_canonical_u32(); + lo | (hi << 16) + }) + } + + fn extract_trace_output(row: &[F]) -> [u32; SHA256_STATE_WORDS] { + let cols: &Sha256CompressCols = row.borrow(); + extract_packed_words(&cols.sha.h_out) + } + + fn sha2_compress_reference( + block: [u32; SHA256_BLOCK_WORDS], + h_in: [u32; SHA256_STATE_WORDS], + ) -> [u32; SHA256_STATE_WORDS] { + let mut block_bytes = [0u8; 64]; + for (i, word) in block.iter().enumerate() { + block_bytes[i * 4..i * 4 + 4].copy_from_slice(&word.to_be_bytes()); + } + let mut state = h_in; + sha2::block_api::compress256(&mut state, core::slice::from_ref(&block_bytes)); + state + } + + fn air_extra_data(n_constraints: usize) -> ExtraDataForBuses { + let mut powers = Vec::with_capacity(n_constraints + 1); + let alpha = EF::from(F::from_usize(7)); + let mut current = EF::ONE; + for _ in 0..=n_constraints { + powers.push(current); + current *= alpha; + } + ExtraDataForBuses::new(Vec::new(), EF::ZERO, powers) + } + + #[test] + fn abc_single_block_matches_known_digest() { + let out = sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK); + assert_eq!( + words_to_hex(out), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn low_then_high_limb_order_is_locked() { + assert_eq!(u32_to_u16_limbs_le(0x61626380), [0x6380, 0x6162]); + assert_eq!(u32_to_u16_limbs_le(0x18), [0x0018, 0x0000]); + assert_eq!(u16_limbs_le_to_u32([0x6380, 0x6162]), 0x61626380); + } + + #[test] + fn column_counts_match_plonky3_baseline_plus_leanvm_prefix() { + assert_eq!(NUM_SHA256_AIR_COLS, 7488); + assert_eq!(SHA256_COL_AIR_START, 36); + assert_eq!(NUM_SHA256_COMPRESS_COLS, 7524); + } + + #[test] + fn trace_row_populates_prefix_block_and_output() { + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + assert_eq!(row.len(), NUM_SHA256_COMPRESS_COLS); + + let cols: &Sha256CompressCols = row.as_slice().borrow(); + assert_eq!(cols.flag, F::ONE); + assert_eq!(cols.state_ptr, F::from_usize(10)); + assert_eq!(cols.block_ptr, F::from_usize(20)); + assert_eq!(cols.out_ptr, F::from_usize(30)); + assert_eq!(cols.block_limbs[0], [F::from_usize(0x6380), F::from_usize(0x6162)]); + assert_eq!(cols.block_limbs[15], [F::from_usize(0x18), F::ZERO]); + + assert_eq!( + words_to_hex(extract_packed_words(&cols.sha.h_out)), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn trace_generation_matches_sha2_compress_reference() { + let cases = [ + (SHA256_IV, SHA256_ABC_BLOCK), + (SHA256_IV, SHA256_ZERO_BLOCK), + ( + [ + 0x0123_4567, + 0x89ab_cdef, + 0xfedc_ba98, + 0x7654_3210, + 0x0f1e_2d3c, + 0x4b5a_6978, + 0x8877_6655, + 0x4433_2211, + ], + [ + 0xffff_ffff, + 0, + 0x1357_9bdf, + 0x2468_ace0, + 0xdead_beef, + 0xcafe_babe, + 0x0001_0002, + 0x0003_0004, + 0x0102_0304, + 0x1111_2222, + 0x3333_4444, + 0x5555_6666, + 0x7777_8888, + 0x9999_aaaa, + 0xbbbb_cccc, + 0xdddd_eeee, + ], + ), + ]; + + for (h_in, block) in cases { + let row = sha256_compress_trace_row(F::ONE, F::ZERO, F::ZERO, F::ZERO, h_in, block); + assert_eq!(extract_trace_output(&row), sha2_compress_reference(block, h_in)); + } + } + + #[test] + fn symbolic_constraint_count_matches_declared_count() { + let table = Sha256CompressPrecompile::; + let (constraints, bus_flag, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); + assert_eq!(constraints.len(), table.n_constraints()); + assert_eq!( + bus_flag, + backend::SymbolicExpression::Variable(backend::SymbolicVariable::new(SHA256_COL_FLAG)) + ); + assert_eq!(bus_data.len(), 4); + } + + #[test] + fn generated_trace_row_satisfies_air_and_tampered_row_fails() { + let table = Sha256CompressPrecompile::; + let extra_data = air_extra_data(table.n_constraints()); + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + + assert_eq!( + as SumcheckComputation>::eval_base(&table, &row, &extra_data), + EF::ZERO + ); + + let mut tampered = row; + tampered[SHA256_COL_AIR_START] = F::TWO; + assert_ne!( + as SumcheckComputation>::eval_base(&table, &tampered, &extra_data), + EF::ZERO + ); + } + + #[test] + fn precompile_execute_writes_output_and_trace_row() { + let state_ptr = 0; + let block_ptr = SHA256_STATE_LIMBS; + let out_ptr = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS; + + let mut memory = Memory::new(vec![]); + memory + .set_slice(state_ptr, &words_to_field_limbs_le(SHA256_IV)) + .unwrap(); + memory + .set_slice(block_ptr, &words_to_field_limbs_le(SHA256_ABC_BLOCK)) + .unwrap(); + + let table = Table::sha256_compress(); + let mut traces = BTreeMap::new(); + traces.insert(table, TableTrace::new(&Sha256CompressPrecompile::)); + let mut fp = 0; + let mut pc = 0; + let pcs = vec![0]; + let mut counts = InstructionCounts::default(); + let mut ctx = InstructionContext { + memory: &mut memory, + fp: &mut fp, + pc: &mut pc, + pcs: &pcs, + traces: &mut traces, + counts: &mut counts, + }; + + table + .execute( + F::from_usize(state_ptr), + F::from_usize(block_ptr), + F::from_usize(out_ptr), + PrecompileCompTimeArgs::Sha256Compress, + &mut ctx, + ) + .unwrap(); + + let out = ctx.memory.get_slice(out_ptr, SHA256_STATE_LIMBS).unwrap(); + let out_words = field_limbs_to_words::(&out).unwrap(); + assert_eq!( + words_to_hex(out_words), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + + let trace = ctx.traces.get(&table).unwrap(); + assert_eq!(trace.columns.len(), NUM_SHA256_COMPRESS_COLS); + assert_eq!(trace.columns[SHA256_COL_FLAG], [F::ONE]); + assert_eq!(trace.columns[SHA256_COL_STATE_PTR], [F::from_usize(state_ptr)]); + assert_eq!(trace.columns[SHA256_COL_BLOCK_PTR], [F::from_usize(block_ptr)]); + assert_eq!(trace.columns[SHA256_COL_OUT_PTR], [F::from_usize(out_ptr)]); + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs new file mode 100644 index 000000000..514a1226d --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs @@ -0,0 +1,126 @@ +use core::borrow::BorrowMut; + +use backend::PrimeCharacteristicRing; + +use crate::{F, TableTrace}; + +use super::{ + SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, Sha256Cols, + Sha256CompressCols, Sha256CompressionWitness, Sha256RoundCols, generate_sha256_compression_witness, u32_to_bits_le, + u32_to_u16_limbs_le, +}; + +pub fn sha256_compress_trace_row( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Vec { + let witness = generate_sha256_compression_witness(h_in, block); + sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, &witness) +} + +pub fn sha256_compress_trace_row_from_witness( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) -> Vec { + let mut row = F::zero_vec(super::NUM_SHA256_COMPRESS_COLS); + let cols: &mut Sha256CompressCols = row.as_mut_slice().borrow_mut(); + fill_sha256_compress_cols(cols, flag, state_ptr, block_ptr, out_ptr, witness); + row +} + +pub fn push_sha256_compress_trace_row( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) { + let row = sha256_compress_trace_row(flag, state_ptr, block_ptr, out_ptr, h_in, block); + push_row(trace, row); +} + +pub fn push_sha256_compress_trace_row_from_witness( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + let row = sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, witness); + push_row(trace, row); +} + +fn push_row(trace: &mut TableTrace, row: Vec) { + debug_assert_eq!(trace.columns.len(), row.len()); + for (column, value) in trace.columns.iter_mut().zip(row) { + column.push(value); + } +} + +pub fn fill_sha256_compress_cols( + cols: &mut Sha256CompressCols, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + cols.flag = flag; + cols.state_ptr = state_ptr; + cols.block_ptr = block_ptr; + cols.out_ptr = out_ptr; + + for (dst, &word) in cols.block_limbs.iter_mut().zip(&witness.block) { + *dst = word_limbs(word); + } + + fill_sha256_air_cols(&mut cols.sha, witness); +} + +pub fn fill_sha256_air_cols(cols: &mut Sha256Cols, witness: &Sha256CompressionWitness) { + for i in 0..SHA256_STATE_WORDS { + cols.h_in[i] = word_limbs(witness.h_in[i]); + cols.h_out[i] = word_limbs(witness.h_out[i]); + } + + for i in 0..(4 + SHA256_COMPRESS_ROUNDS) { + cols.a_chain[i] = word_bits(witness.a_chain[i]); + cols.e_chain[i] = word_bits(witness.e_chain[i]); + } + + for i in 0..SHA256_COMPRESS_ROUNDS { + cols.w[i] = word_bits(witness.w[i]); + cols.rounds[i] = Sha256RoundCols { + sigma1_e: word_limbs(witness.rounds[i].sigma1_e), + ch: word_limbs(witness.rounds[i].ch), + tmp1: word_limbs(witness.rounds[i].tmp1), + t1: word_limbs(witness.rounds[i].t1), + sigma0_a: word_limbs(witness.rounds[i].sigma0_a), + maj: word_limbs(witness.rounds[i].maj), + }; + } + + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + cols.sched_sigma0[i] = word_limbs(witness.sched_sigma0[i]); + cols.sched_sigma1[i] = word_limbs(witness.sched_sigma1[i]); + cols.sched_tmp[i] = word_limbs(witness.sched_tmp[i]); + } +} + +fn word_limbs(word: u32) -> [F; 2] { + u32_to_u16_limbs_le(word).map(|limb| F::from_usize(usize::from(limb))) +} + +fn word_bits(word: u32) -> [F; 32] { + u32_to_bits_le(word).map(F::from_bool) +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 55be30e28..cefb64a54 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -3,8 +3,13 @@ use backend::*; use crate::execution::memory::MemoryAccess; use crate::*; -pub const N_TABLES: usize = 3; -pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; +pub const N_TABLES: usize = 4; +pub const ALL_TABLES: [Table; N_TABLES] = [ + Table::execution(), + Table::extension_op(), + Table::poseidon16(), + Table::sha256_compress(), +]; pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -13,6 +18,7 @@ pub enum Table { Execution(ExecutionTable), ExtensionOp(ExtensionOpPrecompile), Poseidon16(Poseidon16Precompile), + Sha256Compress(Sha256CompressPrecompile), } #[macro_export] @@ -22,6 +28,7 @@ macro_rules! delegate_to_inner { match $self { Self::ExtensionOp(p) => p.$method($($($arg),*)?), Self::Poseidon16(p) => p.$method($($($arg),*)?), + Self::Sha256Compress(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), } }; @@ -30,6 +37,7 @@ macro_rules! delegate_to_inner { match $self { Table::ExtensionOp(p) => $macro_name!(p), Table::Poseidon16(p) => $macro_name!(p), + Table::Sha256Compress(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), } }; @@ -45,6 +53,9 @@ impl Table { pub const fn poseidon16() -> Self { Self::Poseidon16(Poseidon16Precompile) } + pub const fn sha256_compress() -> Self { + Self::Sha256Compress(Sha256CompressPrecompile) + } pub fn embed(&self) -> PF { PF::from_usize(self.index()) } @@ -69,8 +80,8 @@ impl TableT for Table { fn bus(&self) -> Bus { delegate_to_inner!(self, bus) } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec> { - delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr) + fn padding_row(&self, padding: &PaddingMemory) -> Vec> { + delegate_to_inner!(self, padding_row, padding) } fn execute( &self, diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index cbb773c61..cd650f194 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -46,6 +46,15 @@ pub struct Bus { pub data: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PaddingMemory { + pub zero_vec_ptr: usize, + pub null_poseidon_16_hash_ptr: usize, + pub sha256_state_ptr: usize, + pub sha256_block_ptr: usize, + pub sha256_out_ptr: usize, +} + #[derive(Debug, Default)] pub struct TableTrace { pub columns: Vec>, @@ -126,7 +135,7 @@ pub trait TableT: Air { fn table(&self) -> Table; fn lookups(&self) -> Vec; fn bus(&self) -> Bus; - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec; + fn padding_row(&self, padding: &PaddingMemory) -> Vec; fn execute( &self, arg_a: F,