Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] }
postcard = { version = "1.1.3", features = ["alloc"] }
lz4_flex = "0.13.0"
include_dir = "0.7"
sha2 = "0.11.0"


[features]
prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"]
Expand Down
2 changes: 2 additions & 0 deletions crates/backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ tracing.workspace = true
fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" }
koala-bear = { path = "koala-bear", package = "mt-koala-bear" }
utils = { path = "utils", package = "mt-utils" }
sha2.workspace = true

1 change: 1 addition & 0 deletions crates/backend/fiat-shamir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ utils = { path = "../utils", package = "mt-utils" }
tracing.workspace = true
serde.workspace = true
rayon.workspace = true
sha2.workspace = true
143 changes: 142 additions & 1 deletion crates/backend/fiat-shamir/src/challenger.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use field::PrimeField64;
use std::marker::PhantomData;

use field::{PrimeField32, PrimeField64};
use sha2::{Digest, Sha256};
use symetric::Compression;

pub(crate) const RATE: usize = 8;
Expand Down Expand Up @@ -73,3 +76,141 @@ impl<F: PrimeField64, P: Compression<[F; WIDTH]>> Challenger<F, P> {
res
}
}

/// Fiat-Shamir challenger whose transcript is a running SHA-256 state.
///
/// Field elements are absorbed as their canonical little-endian `u32`
/// encodings (see `observe`), and challenges are squeezed by finalizing
/// domain-separated clones of the state (see `sample_digest`).
///
/// `F` only fixes the field that observed/sampled values live in; no field
/// element is stored in the struct itself (hence the `PhantomData`).
#[derive(Clone, Debug)]
pub struct ChallengerSha2<F> {
    // Running hash over everything observed so far; cloned (never consumed)
    // when sampling, so the transcript keeps accumulating.
    pub sha2: Sha256,
    _marker: PhantomData<F>,
}

impl<F: PrimeField32> ChallengerSha2<F> {
    /// Creates a challenger over an empty transcript.
    pub fn new() -> Self {
        Self {
            sha2: Sha256::new(),
            _marker: PhantomData,
        }
    }

    /// Absorbs a full rate-sized block, feeding each element's canonical
    /// `u32` value into the hash in little-endian byte order.
    pub fn observe(&mut self, value: [F; RATE]) {
        for fe in value.iter() {
            self.sha2.update(fe.as_canonical_u32().to_le_bytes());
        }
    }

    /// Absorbs an arbitrary slice of scalars, zero-padding the final block
    /// up to `RATE` elements.
    pub fn observe_scalars(&mut self, scalars: &[F]) {
        for chunk in scalars.chunks(RATE) {
            let mut block = [F::ZERO; RATE];
            block.iter_mut().zip(chunk).for_each(|(slot, &fe)| *slot = fe);
            self.observe(block);
        }
    }

    /// Absorbs raw bytes directly into the transcript.
    pub fn observe_bytes(&mut self, bytes: &[u8]) {
        self.sha2.update(bytes);
    }

    /// Squeezes `n` rate-sized chunks of field elements.
    ///
    /// `n + 1` chunks are derived; the extra final chunk re-keys the
    /// transcript (fresh hasher + observe) so that subsequent samples are
    /// bound to this squeeze.
    pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> {
        let mut chunks: Vec<[F; RATE]> = (0..=n).map(|i| self.sample_chunk(i)).collect();
        let carry = chunks.pop().expect("0..=n is never empty");
        self.sha2 = Sha256::new();
        self.observe(carry);
        chunks
    }

    /// Squeezes `n_samples` uniform-ish values in `[0, 2^bits)` by masking
    /// the low `bits` of each sampled element's canonical value.
    pub fn sample_in_range(&mut self, bits: usize, n_samples: usize) -> Vec<usize> {
        assert!(bits < F::bits());
        let mask = (1usize << bits) - 1;
        self.sample_many(n_samples.div_ceil(RATE))
            .into_iter()
            .flatten()
            .take(n_samples)
            .map(|fe| (fe.as_canonical_u64() as usize) & mask)
            .collect()
    }

    /// Returns whether the current transcript's first sampled word has its
    /// low `bits` all zero (proof-of-work grinding predicate). Read-only:
    /// the transcript is not advanced.
    pub fn pow_grinding_sample_matches(&self, bits: usize) -> bool {
        assert!(bits < F::bits());
        let word = self.sample_first_word(0, 0).as_canonical_u64() as usize;
        (word & ((1 << bits) - 1)) == 0
    }

    /// Returns whether `witness` satisfies the grinding predicate when
    /// appended to (a fork of) the current transcript.
    pub fn pow_grinding_witness_matches(&self, witness: F, bits: usize) -> bool {
        let mut forked = self.clone();
        forked.observe_scalars(&[witness]);
        forked.pow_grinding_sample_matches(bits)
    }

    /// Derives one rate-sized chunk by hashing under `domain_sep`,
    /// consuming as many counter-indexed digests as needed to fill `RATE`
    /// words (one 32-byte digest already yields `RATE = 8` words).
    fn sample_chunk(&self, domain_sep: usize) -> [F; RATE] {
        let mut words = Vec::with_capacity(RATE);
        let mut block_idx = 0u64;
        loop {
            let digest = self.sample_digest(domain_sep, block_idx);
            for bytes in digest.chunks_exact(size_of::<u32>()) {
                words.push(F::from_int(u32::from_le_bytes(bytes.try_into().unwrap())));
                if words.len() == RATE {
                    return words.try_into().unwrap();
                }
            }
            block_idx += 1;
        }
    }

    /// Derives only the first word of the digest at `(domain_sep, block_idx)`.
    fn sample_first_word(&self, domain_sep: usize, block_idx: u64) -> F {
        let digest = self.sample_digest(domain_sep, block_idx);
        let (head, _) = digest.split_at(size_of::<u32>());
        F::from_int(u32::from_le_bytes(head.try_into().unwrap()))
    }

    /// Finalizes a clone of the running state after appending the two
    /// little-endian `u64` separators, leaving the transcript untouched.
    fn sample_digest(&self, domain_sep: usize, block_idx: u64) -> sha2::digest::Output<Sha256> {
        let mut forked = self.sha2.clone();
        forked.update((domain_sep as u64).to_le_bytes());
        forked.update(block_idx.to_le_bytes());
        forked.finalize()
    }
}

impl<F: PrimeField32> Default for ChallengerSha2<F> {
fn default() -> Self {
Self::new()
}
}

#[cfg(test)]
mod tests {
    use field::PrimeCharacteristicRing;
    use koala_bear::KoalaBear;

    use super::ChallengerSha2;

    /// The read-only PoW predicate (`pow_grinding_witness_matches`) must agree
    /// with the reference path that absorbs the witness and then samples one
    /// masked value via `sample_in_range`.
    #[test]
    fn sha2_pow_grinding_direct_predicate_matches_sampling_path() {
        let prefixes: [Vec<KoalaBear>; 4] = [
            vec![],
            vec![KoalaBear::ONE],
            (0..17).map(KoalaBear::from_usize).collect(),
            (100..141).map(KoalaBear::from_usize).collect(),
        ];

        for prefix in prefixes {
            let mut base = ChallengerSha2::new();
            base.observe_scalars(&prefix);

            for bits in 1..=20 {
                for candidate in [0, 1, 2, 3, 5, 8, 13, 21, 55, 89, 144, 233, 377, 610] {
                    let witness = KoalaBear::from_usize(candidate);

                    // Reference: fork the transcript, absorb, then sample.
                    let mut reference = base.clone();
                    reference.observe_scalars(&[witness]);
                    let via_sampling = reference.sample_in_range(bits, 1)[0] == 0;

                    let direct = base.pow_grinding_witness_matches(witness, bits);
                    assert_eq!(direct, via_sampling);
                }
            }
        }
    }
}
27 changes: 14 additions & 13 deletions crates/backend/fiat-shamir/src/merkle_pruning.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use serde::{Deserialize, Serialize};

use crate::{DIGEST_LEN_FE, MerklePath, MerklePaths};
use crate::{MerklePath, MerklePaths};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrunedMerklePaths<Data, F> {
pub struct PrunedMerklePaths<Data, Digest> {
pub merkle_height: usize,
pub original_order: Vec<usize>,
pub leaf_data: Vec<Vec<Data>>,
pub paths: Vec<(usize, Vec<[F; DIGEST_LEN_FE]>)>,
pub paths: Vec<(usize, Vec<Digest>)>,
pub n_trailing_zeros: usize,
}

/// Level (counting from the leaves, 1-based) of the lowest common ancestor of
/// leaf indices `a` and `b`: one past the position of their highest differing
/// bit. Returns 0 when `a == b`.
fn lca_level(a: usize, b: usize) -> usize {
    match a ^ b {
        0 => 0,
        diff => diff.ilog2() as usize + 1,
    }
}

impl<Data: Clone, F: Clone> MerklePaths<Data, F> {
pub fn prune(self) -> PrunedMerklePaths<Data, F>
impl<Data: Clone, Digest: Clone> MerklePaths<Data, Digest> {
pub fn prune(self) -> PrunedMerklePaths<Data, Digest>
where
Data: Default + PartialEq,
{
Expand All @@ -27,7 +27,7 @@ impl<Data: Clone, F: Clone> MerklePaths<Data, F> {
indexed.sort_by_key(|(_, p)| p.leaf_index);

let mut original_order = vec![0; indexed.len()];
let mut deduped = Vec::<MerklePath<Data, F>>::new();
let mut deduped = Vec::<MerklePath<Data, Digest>>::new();

for (orig_idx, path) in indexed {
if deduped.last().map(|p| p.leaf_index) == Some(path.leaf_index) {
Expand Down Expand Up @@ -83,12 +83,12 @@ impl<Data: Clone, F: Clone> MerklePaths<Data, F> {
}
}

impl<Data: Clone, F: Clone> PrunedMerklePaths<Data, F> {
impl<Data: Clone, Digest: Clone> PrunedMerklePaths<Data, Digest> {
pub fn restore(
mut self,
hash_leaf: &impl Fn(&[Data]) -> [F; DIGEST_LEN_FE],
hash_combine: &impl Fn(&[F; DIGEST_LEN_FE], &[F; DIGEST_LEN_FE]) -> [F; DIGEST_LEN_FE],
) -> Option<MerklePaths<Data, F>>
hash_leaf: &impl Fn(&[Data]) -> Digest,
hash_combine: &impl Fn(&Digest, &Digest) -> Digest,
) -> Option<MerklePaths<Data, Digest>>
where
Data: Default,
{
Expand All @@ -112,7 +112,7 @@ impl<Data: Clone, F: Clone> PrunedMerklePaths<Data, F> {
let skip = |i: usize| self.paths.get(i + 1).map(|p| lca_level(self.paths[i].0, p.0) - 1);

// Backward pass: compute subtree hashes needed to restore skipped siblings
let mut subtree_hashes: Vec<Vec<[F; DIGEST_LEN_FE]>> = vec![vec![]; n];
let mut subtree_hashes: Vec<Vec<Digest>> = vec![vec![]; n];

for i in (0..n).rev() {
let (leaf_idx, ref stored) = self.paths[i];
Expand All @@ -139,7 +139,7 @@ impl<Data: Clone, F: Clone> PrunedMerklePaths<Data, F> {
}

// Forward pass: build full sibling arrays
let mut restored: Vec<MerklePath<Data, F>> = Vec::with_capacity(n);
let mut restored: Vec<MerklePath<Data, Digest>> = Vec::with_capacity(n);

for i in 0..n {
let (leaf_idx, ref stored) = self.paths[i];
Expand Down Expand Up @@ -178,6 +178,7 @@ impl<Data: Clone, F: Clone> PrunedMerklePaths<Data, F> {
#[cfg(test)]
mod tests {
use super::*;
use crate::DIGEST_LEN_FE;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

Expand Down Expand Up @@ -231,7 +232,7 @@ mod tests {
leaf_data: Vec<u8>,
leaf_index: usize,
tree: &[Vec<[u8; DIGEST_LEN_FE]>],
) -> MerklePath<u8, u8> {
) -> MerklePath<u8, [u8; DIGEST_LEN_FE]> {
let height = tree.len() - 1;
let mut sibling_hashes = Vec::with_capacity(height);

Expand Down
Loading