diff --git a/README.md b/README.md index b075ffd3d..e46ecedd5 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,6 @@ Documentation: [PDF](minimal_zkVM.pdf) The VM design is inspired by the famous [Cairo paper](https://eprint.iacr.org/2021/1063.pdf). -## Security - -124 bits of provable security, given by Johnson bound + degree 5 extension of koala-bear. (128 bits would require hash digests of more than 8 field elements, todo?). In the benchmarks, we also display performance with conjectured security, even though leanVM targets the proven regime by default. - ## Benchmarks Machine: M4 Max 48GB (CPU only) @@ -31,13 +27,14 @@ Machine: M4 Max 48GB (CPU only) ### XMSS aggregation ```bash -cargo run --release -- xmss --n-signatures 1400 --log-inv-rate 1 +cargo run --release -- xmss --n-signatures 1500 --log-inv-rate 1 ``` | WHIR rate | Proven Regime | Proximity Gaps Conjecture | | --------- | --------------------- | ------------------------- | -| 1/2 | 1193 XMSS/s - 377 KiB | 1207 XMSS/s - 191 KiB | -| 1/4 | 863 XMSS/s - 243 KiB | 872 XMSS/s - 129 KiB | +| 1/2 | 1319 XMSS/s - 338 KiB | 1345 XMSS/s - 176 KiB | +| 1/4 | 961 XMSS/s - 228 KiB | 969 XMSS/s - 126 KiB | + (Proving throughput - proof size) @@ -53,14 +50,15 @@ cargo run --release -- recursion --n 2 --log-inv-rate 2 | n | WHIR rate | Proven Regime | Proximity Gaps Conjecture | | --- | --------- | --------------------------- | --------------------------- | -| 1 | 1/2 | 0.35s = 1 x 0.35s - 256 KiB | 0.24s = 1 x 0.24s - 146 KiB | -| 1 | 1/4 | 0.33s = 1 x 0.33s - 183 KiB | 0.26s = 1 x 0.26s - 98 KiB | -| 2 | 1/2 | 0.65s = 2 x 0.33s - 272 KiB | 0.43s = 2 x 0.21s - 157 KiB | -| 2 | 1/4 | 0.56s = 2 x 0.28s - 190 KiB | 0.41s = 2 x 0.21s - 101 KiB | -| 3 | 1/2 | 0.83s = 3 x 0.28s - 303 KiB | 0.62s = 3 x 0.21s - 150 KiB | -| 3 | 1/4 | 0.86s = 3 x 0.29s - 192 KiB | 0.71s = 3 x 0.24s - 107 KiB | -| 4 | 1/2 | 1.23s = 4 x 0.31s - 327 KiB | 0.76s = 4 x 0.19s - 166 KiB | -| 4 | 1/4 | 1.01s = 4 x 0.25s - 200 KiB | 0.76s = 4 x 
0.19s - 106 KiB | +| 1 | 1/2 | 0.39s = 1 x 0.39s - 278 KiB | 0.24s = 1 x 0.24s - 147 KiB | +| 1 | 1/4 | 0.32s = 1 x 0.32s - 188 KiB | 0.27s = 1 x 0.27s - 100 KiB | +| 2 | 1/2 | 0.7s = 2 x 0.35s - 293 KiB | 0.43s = 2 x 0.21s - 157 KiB | +| 2 | 1/4 | 0.56s = 2 x 0.28s - 194 KiB | 0.43s = 2 x 0.22s - 102 KiB | +| 3 | 1/2 | 0.85s = 3 x 0.28s - 312 KiB | 0.63s = 3 x 0.21s - 150 KiB | +| 3 | 1/4 | 0.94s = 3 x 0.31s - 203 KiB | 0.73s = 3 x 0.24s - 108 KiB | +| 4 | 1/2 | 1.27s = 4 x 0.32s - 308 KiB | 0.78s = 4 x 0.2s - 166 KiB | +| 4 | 1/4 | 1.02s = 4 x 0.26s - 206 KiB | 0.79s = 4 x 0.2s - 108 KiB | + (time for n->1 recursive aggregation - proof size) @@ -75,6 +73,20 @@ cargo run --release -- fancy-aggregation (Proven regime) +## Security + +### snark + +≈ 124 bits of provable security, given by Johnson bound + degree 5 extension of koala-bear. (128 bits requires bigger hash digests (8 koalabears ≈ 248 bits) -> TODO). In the benchmarks, we also display performance with conjectured security, even though leanVM targets the proven regime by default. + +### XMSS + +Currently, we use an [XMSS](crates/xmss/xmss.md) with hash digests of 4 field elements ≈ 124 bits. Tweaks and public parameters ensure domain separation. An analysis in the ROM (resp. QROM), inspired by Section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103) would lead to ≈ 124 (resp. 62) bits of classical (resp. quantum) security. Going to 128 / 64 bits of classical / quantum security, i.e. NIST level 1 (in the ROM/QROM), is an ongoing effort. It requires either: +- hash digests of 5 field elements (drawback: we need to double the hash chain length from 8 to 16 if we want to stay below one IPv6 MTU = 1280 bytes) +- a new prime, close to 32 bits (typically p = 125·2^25 + 1) or 64 bits ([goldilocks](https://2π.com/22/goldilocks/)) + +It's important to mention that a security analysis in the ROM / QROM is not the most conservative. 
In particular, [eprint 2025/055](https://eprint.iacr.org/2025/055.pdf)'s security proof holds in the standard model (at the cost of bigger hash digests): the implementation is available in the [leanSig](https://github.com/leanEthereum/leanSig) repository. A compatible version of leanMultisig can be found in the [devnet4](https://github.com/leanEthereum/leanMultisig/tree/devnet4) branch. + ## Credits - [Plonky3](https://github.com/Plonky3/Plonky3) for its various performant crates diff --git a/crates/backend/fiat-shamir/src/challenger.rs b/crates/backend/fiat-shamir/src/challenger.rs index d650d68aa..34fcd94ab 100644 --- a/crates/backend/fiat-shamir/src/challenger.rs +++ b/crates/backend/fiat-shamir/src/challenger.rs @@ -43,7 +43,7 @@ impl> Challenger { } pub fn sample_many(&mut self, n: usize) -> Vec<[F; RATE]> { - let mut sampled = Vec::with_capacity(n); + let mut sampled = Vec::with_capacity(n + 1); for i in 0..n + 1 { let mut domain_sep = [F::ZERO; RATE]; domain_sep[0] = F::from_usize(i); diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index f3b8aae8f..5d13b761f 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -67,8 +67,26 @@ def pop(self): self._data.pop() -def poseidon16_compress(left, right, output, mode): - _ = left, right, output, mode +def poseidon16_compress(left, right, output): + _ = left, right, output + + +def poseidon16_compress_half(left, right, output): + """Poseidon16 compression outputting only the first 4 FE (last 4 unconstrained).""" + _ = left, right, output + + +def poseidon16_compress_hardcoded_left(left, right, output, offset): + """Poseidon16 compression where the first 4 FE of the left input are read from + memory[offset..offset+4] instead of memory[left..left+4]. The last 4 FE of the + left input come from memory[left..left+4]. 
`offset` must be a compile-time + constant expression.""" + _ = left, right, output, offset + + +def poseidon16_compress_half_hardcoded_left(left, right, output, offset): + """Composition of `poseidon16_compress_half` and `poseidon16_compress_hardcoded_left`.""" + _ = left, right, output, offset def add_be(a, b, result, length=None): diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index dc54e0dd2..bc268fb80 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -6,8 +6,9 @@ use crate::{ }; use backend::PrimeCharacteristicRing; use lean_vm::{ - Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, PrecompileArgs, PrecompileCompTimeArgs, - SourceLocation, Table, TableT, + ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, + PrecompileCompTimeArgs, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -2258,16 +2259,27 @@ fn simplify_lines( continue; } - // Special handling for poseidon16 precompile - if function_name == Table::poseidon16().name() { + // Special handling for poseidon16 precompile (4 variants) + if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) { if !targets.is_empty() { return Err(format!( "Precompile {function_name} should not return values, at {location}" )); } - if args.len() != 3 { + let half_output = [POSEIDON16_HALF_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] + .contains(&function_name.as_str()); + let is_hardcoded_left = + [POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] + .contains(&function_name.as_str()); + let expected_args = if is_hardcoded_left { 4 } else { 3 }; + if args.len() != expected_args { + let signature = if is_hardcoded_left { + "(ptr_a, ptr_b, ptr_res, offset)" + } else { + "(ptr_a, ptr_b, ptr_res)" + }; 
return Err(format!( - "Precompile {function_name} expects 3 arguments (ptr_a, ptr_b, ptr_res), got {}, at {location}", + "Precompile {function_name} expects {expected_args} arguments {signature}, got {}, at {location}", args.len() )); } @@ -2275,11 +2287,23 @@ fn simplify_lines( .iter() .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) .collect::, _>>()?; + let hardcoded_offset_left = if is_hardcoded_left { + Some(simplified_args[3].as_constant().ok_or_else(|| { + format!( + "{function_name}: offset argument must be a compile-time constant, at {location}" + ) + })?) + } else { + None + }; res.push(SimpleLine::Precompile(PrecompileArgs { arg_0: simplified_args[0].clone(), arg_1: simplified_args[1].clone(), res: simplified_args[2].clone(), - data: PrecompileCompTimeArgs::Poseidon16, + data: PrecompileCompTimeArgs::Poseidon16 { + half_output, + hardcoded_offset_left, + }, })); continue; } diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index c97a4c3eb..1060e3be4 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -48,7 +48,17 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { } Instruction::Precompile(precompile) => { let precompile_data = match &precompile.data { - PrecompileCompTimeArgs::Poseidon16 => POSEIDON_PRECOMPILE_DATA, + PrecompileCompTimeArgs::Poseidon16 { + half_output, + hardcoded_offset_left, + } => { + let flag_left = hardcoded_offset_left.is_some() as usize; + let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); + POSEIDON_PRECOMPILE_DATA + + POSEIDON_HALF_OUTPUT_SHIFT * (*half_output as usize) + + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + 
EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 04fc1541b..576e5af60 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME}; +use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode}; /// Reserved function names that users cannot define. pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -33,8 +33,7 @@ fn is_reserved_function_name(name: &str) -> bool { if RESERVED_FUNCTION_NAMES.contains(&name) || CUSTOM_HINTS.iter().any(|hint| hint.name() == name) { return true; } - // Check precompile names (poseidon16, extension_op functions) - if name == POSEIDON16_NAME { + if ALL_POSEIDON16_NAMES.contains(&name) { return true; } if ExtensionOpMode::from_name(name).is_some() { diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 7d5c29d88..d4742bad8 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -12,11 +12,43 @@ DIM = 5 N = 11 M = 3 DIGEST_LEN = 8 +HALF_DIGEST_LEN = 4 def main(): pub_start = 0 poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) + # poseidon16_compress_half: only first 4 FE constrained + full_out = pub_start + 6 * DIGEST_LEN + half_out = pub_start + 80 + poseidon16_compress_half(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, half_out) + for i in unroll(0, HALF_DIGEST_LEN): + assert full_out[i] == half_out[i] + + # poseidon16_compress_hardcoded_left: with the new convention, only 4 FE are read + # at the left pointer (the 4-element data digest at pub_start + 1496) and the first + # 4 FE of the left input come from memory[pub_start + 1500 .. pub_start + 1504] + # (the hardcoded prefix). 
+ hardcoded_left = pub_start + 1496 + hardcoded_full_out = pub_start + 1504 + poseidon16_compress_hardcoded_left( + hardcoded_left, + pub_start + 5 * DIGEST_LEN, + hardcoded_full_out, + pub_start + 1500 + ) + + # Same, but only first 4 FE of the output are constrained. + hardcoded_half_out = pub_start + 1512 + poseidon16_compress_half_hardcoded_left( + hardcoded_left, + pub_start + 5 * DIGEST_LEN, + hardcoded_half_out, + pub_start + 1500 + ) + for i in unroll(0, HALF_DIGEST_LEN): + assert hardcoded_full_out[i] == hardcoded_half_out[i] + base_ptr = pub_start + 88 ext_a_ptr = pub_start + 88 + N ext_b_ptr = pub_start + 88 + N * (DIM + 1) @@ -58,9 +90,38 @@ def main(): // Poseidon test data let poseidon_16_compress_input: [F; 16] = rng.random(); public_input[32..48].copy_from_slice(&poseidon_16_compress_input); - public_input[48..56].copy_from_slice(&poseidon16_compress(poseidon_16_compress_input)[..8]); + let poseidon_output = poseidon16_compress(poseidon_16_compress_input); + public_input[48..56].copy_from_slice(&poseidon_output[..8]); let poseidon_24_input: [F; 24] = rng.random(); public_input[56..80].copy_from_slice(&poseidon_24_input); + // poseidon16_compress_half output at offset 80: first 4 = hash, last 4 = arbitrary pre-existing data + public_input[80..84].copy_from_slice(&poseidon_output[..4]); + public_input[84..88].copy_from_slice(&[ + F::from_usize(111), + F::from_usize(222), + F::from_usize(333), + F::from_usize(444), + ]); + + let hardcoded_data: [F; 4] = rng.random(); + let hardcoded_prefix: [F; 4] = rng.random(); + public_input[1496..1500].copy_from_slice(&hardcoded_data); + public_input[1500..1504].copy_from_slice(&hardcoded_prefix); + let mut hardcoded_input = [F::ZERO; 16]; + hardcoded_input[..4].copy_from_slice(&hardcoded_prefix); + hardcoded_input[4..8].copy_from_slice(&hardcoded_data); + hardcoded_input[8..16].copy_from_slice(&poseidon_16_compress_input[8..16]); + let hardcoded_output = poseidon16_compress(hardcoded_input); + // Full output at 
1504..1512 + public_input[1504..1512].copy_from_slice(&hardcoded_output); + // Half output at 1512..1520: first 4 = hash, last 4 = arbitrary pre-existing data + public_input[1512..1516].copy_from_slice(&hardcoded_output[..4]); + public_input[1516..1520].copy_from_slice(&[ + F::from_usize(555), + F::from_usize(666), + F::from_usize(777), + F::from_usize(888), + ]); // Extension op operands: base[N], ext_a[N], ext_b[N] let base_slice: [F; N] = rng.random(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 331152a4b..1801a5b62 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -108,6 +108,28 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let poseidon_trace = traces.get_mut(&Table::poseidon16()).unwrap(); fill_trace_poseidon_16(&mut poseidon_trace.columns); + // For half_output rows, override last 4 output columns with actual memory values + // (the AIR doesn't constrain them, but the lookup checks against memory). 
+ { + let split = POSEIDON_16_COL_OUTPUT_START + HALF_DIGEST_LEN; + let (left, right) = poseidon_trace.columns.split_at_mut(split); + let half_output_col = &left[POSEIDON_16_COL_FLAG_HALF_OUTPUT]; + let res_col = &left[POSEIDON_16_COL_INDEX_INPUT_RES]; + let output_cols: &mut [Vec; HALF_DIGEST_LEN] = (&mut right[..HALF_DIGEST_LEN]).try_into().unwrap(); + + transposed_par_iter_mut(output_cols) + .zip(half_output_col) + .zip(res_col) + .for_each(|((row, &half), &res)| { + if half == F::ONE { + let base = res.to_usize() + HALF_DIGEST_LEN; + for j in 0..HALF_DIGEST_LEN { + *row[j] = memory_padded[base + j]; + } + } + }); + } + let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); fill_trace_extension_op(extension_op_trace, &memory_padded); diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 59011adc2..afe6bc2d8 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -22,7 +22,7 @@ pub const MIN_BYTECODE_LOG_SIZE: usize = 8; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. 
pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ - (Table::execution(), 25), + (Table::execution(), 24), (Table::extension_op(), 21), (Table::poseidon16(), 21), ]; diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index ec635ed08..f0b7ef212 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -63,21 +63,35 @@ pub struct PrecompileArgs { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PrecompileCompTimeArgs { - Poseidon16, - ExtensionOp { size: S, mode: ExtensionOpMode }, + Poseidon16 { + half_output: bool, + // hardcoded_offset_left = None: left_input = m[arg_a..arg_a+8] + // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | m[arg_a..arg_a+4] (arg_a is the first runtime parameter) + hardcoded_offset_left: Option, + }, + ExtensionOp { + size: S, + mode: ExtensionOpMode, + }, } impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { - Self::Poseidon16 => Table::poseidon16(), + Self::Poseidon16 { .. } => Table::poseidon16(), Self::ExtensionOp { .. 
} => Table::extension_op(), } } - pub fn map_size(self, f: impl FnOnce(S) -> T) -> PrecompileCompTimeArgs { + pub fn map_size(self, mut f: impl FnMut(S) -> T) -> PrecompileCompTimeArgs { match self { - Self::Poseidon16 => PrecompileCompTimeArgs::Poseidon16, + Self::Poseidon16 { + half_output, + hardcoded_offset_left: hardcoded_left_4, + } => PrecompileCompTimeArgs::Poseidon16 { + half_output, + hardcoded_offset_left: hardcoded_left_4.map(&mut f), + }, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -236,9 +250,18 @@ impl Display for PrecompileArgs { data, } = self; match data { - PrecompileCompTimeArgs::Poseidon16 => { - write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})") - } + PrecompileCompTimeArgs::Poseidon16 { + half_output, + hardcoded_offset_left: hardcoded_left_4, + } => match (*half_output, hardcoded_left_4) { + (false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), + (true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), + (false, Some(off)) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})"), + (true, Some(off)) => write!( + f, + "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" + ), + }, PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index c50ac663d..03cc0045c 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -6,9 +6,7 @@ use air::*; mod exec; pub use exec::fill_trace_extension_op; -// domain separation: Poseidon16=1, Poseidon24= 2 or 3 or 4, ExtensionOp>=8 -/// Extension op PRECOMPILE_DATA bit-field encoding: -/// aux = 4*is_be + 8*flag_add + 16*flag_mul + 32*flag_poly_eq + 64*len +// `PRECOMPILE_DATA` encoding: see `tables/mod.rs`. 
pub(crate) const EXT_OP_FLAG_IS_BE: usize = 4; pub(crate) const EXT_OP_FLAG_ADD: usize = 8; pub(crate) const EXT_OP_FLAG_MUL: usize = 16; diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index 3010d39fd..bf9523291 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -15,3 +15,11 @@ pub use execution::*; mod utils; pub(crate) use utils::*; + +// `PRECOMPILE_DATA` is the bus discriminator separating the two precompile +// tables. Disjointness is by parity of bit 0: +// +// Poseidon16 (odd): 1 + 2·flag_half + 4·flag_left + 8·flag_left·offset_left +// ExtensionOp (even): 4·is_be + 8·flag_add + 16·flag_mul + 32·flag_poly_eq + 64·len +// +// Multiplying `offset_left` by `flag_left` is needed for soundness: see 3.4.1 in minimal_zkVM.pdf diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 68a9a300d..5cffe5194 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -89,16 +89,37 @@ const HALF_INITIAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS / 2; const PARTIAL_ROUNDS: usize = POSEIDON1_PARTIAL_ROUNDS; const HALF_FINAL_FULL_ROUNDS: usize = POSEIDON1_HALF_FULL_ROUNDS / 2; -pub const POSEIDON_PRECOMPILE_DATA: usize = 1; // domain separation: Poseidon16=1, Poseidon24=2 or 3 or 4, ExtensionOp>=8 +// `PRECOMPILE_DATA` encoding: see `tables/mod.rs`. 
+pub const POSEIDON_PRECOMPILE_DATA: usize = 1; +pub const POSEIDON_HALF_OUTPUT_SHIFT: usize = 1 << 1; +pub const POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT: usize = 1 << 2; +pub const POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT: usize = 1 << 3; pub const POSEIDON_16_COL_FLAG: ColIndex = 0; -pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = 1; -pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 2; -pub const POSEIDON_16_COL_INDEX_INPUT_RES: ColIndex = 3; -pub const POSEIDON_16_COL_INPUT_START: ColIndex = 4; +pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 1; +pub const POSEIDON_16_COL_INDEX_INPUT_RES: ColIndex = 2; +pub const POSEIDON_16_COL_FLAG_HALF_OUTPUT: ColIndex = 3; +pub const POSEIDON_16_COL_FLAG_HARDCODED_LEFT: ColIndex = 4; +pub const POSEIDON_16_COL_OFFSET_LEFT_HARDCODED: ColIndex = 5; +pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST: ColIndex = 6; +pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND: ColIndex = 7; +pub const POSEIDON_16_COL_INPUT_START: ColIndex = 8; pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; +/// Non-committed columns ("virtual"): +pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = num_cols_poseidon_16(); +pub const POSEIDON_16_COL_PRECOMPILE_DATA: ColIndex = num_cols_poseidon_16() + 1; pub const POSEIDON16_NAME: &str = "poseidon16_compress"; +pub const POSEIDON16_HALF_NAME: &str = "poseidon16_compress_half"; +pub const POSEIDON16_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_hardcoded_left"; +pub const POSEIDON16_HALF_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_half_hardcoded_left"; +pub const ALL_POSEIDON16_NAMES: [&str; 4] = [ + POSEIDON16_NAME, + POSEIDON16_HALF_NAME, + POSEIDON16_HARDCODED_LEFT_NAME, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, +]; +pub const HALF_DIGEST_LEN: usize = DIGEST_LEN / 2; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Poseidon16Precompile; @@ -115,8 +136,13 @@ impl TableT for Poseidon16Precompile { fn lookups(&self) 
-> Vec { vec![ LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_LEFT, - values: (POSEIDON_16_COL_INPUT_START..POSEIDON_16_COL_INPUT_START + DIGEST_LEN).collect(), + index: POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST, + values: (POSEIDON_16_COL_INPUT_START..POSEIDON_16_COL_INPUT_START + HALF_DIGEST_LEN).collect(), + }, + LookupIntoMemory { + index: POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND, + values: (POSEIDON_16_COL_INPUT_START + HALF_DIGEST_LEN..POSEIDON_16_COL_INPUT_START + DIGEST_LEN) + .collect(), }, LookupIntoMemory { index: POSEIDON_16_COL_INDEX_INPUT_RIGHT, @@ -130,10 +156,14 @@ impl TableT for Poseidon16Precompile { ] } + fn n_columns_total(&self) -> usize { + num_cols_total_poseidon_16() + } + #[allow(clippy::vec_init_then_push)] // https://github.com/leanEthereum/leanMultisig/issues/198 fn bus(&self) -> Bus { let mut data = Vec::with_capacity(4); - data.push(BusData::Constant(POSEIDON_PRECOMPILE_DATA)); + data.push(BusData::Column(POSEIDON_16_COL_PRECOMPILE_DATA)); data.push(BusData::Column(POSEIDON_16_COL_INDEX_INPUT_LEFT)); data.push(BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RIGHT)); data.push(BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RES)); @@ -145,17 +175,24 @@ impl TableT for Poseidon16Precompile { } fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { - let mut row = vec![F::ZERO; num_cols_poseidon_16()]; + let mut row = vec![F::ZERO; num_cols_total_poseidon_16()]; let ptrs: Vec<*mut F> = (0..num_cols_poseidon_16()) .map(|i| unsafe { row.as_mut_ptr().add(i) }) .collect(); let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); - *perm.flag = F::ZERO; - *perm.index_a = F::from_usize(zero_vec_ptr); + *perm.flag_active = F::ZERO; *perm.index_b = F::from_usize(zero_vec_ptr); *perm.index_res = F::from_usize(null_hash_ptr); + *perm.flag_half_output = F::ZERO; + *perm.flag_hardcoded_left = F::ZERO; + 
*perm.offset_hardcoded_left = F::ZERO; + *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); + *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); + // Non-committed columns + row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); + row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); generate_trace_rows_for_perm(perm); row @@ -167,31 +204,69 @@ impl TableT for Poseidon16Precompile { arg_a: F, arg_b: F, index_res_a: F, - _: PrecompileCompTimeArgs, + args: PrecompileCompTimeArgs, ctx: &mut InstructionContext<'_, M>, ) -> Result<(), RunnerError> { + let PrecompileCompTimeArgs::Poseidon16 { + half_output, + hardcoded_offset_left, + } = args + else { + unreachable!("Poseidon16 table called with non-Poseidon16 args"); + }; let trace = ctx.traces.get_mut(&self.table()).unwrap(); - let arg0 = ctx.memory.get_slice(arg_a.to_usize(), DIGEST_LEN)?; + let arg_a_usize = arg_a.to_usize(); + let flag_hardcoded = hardcoded_offset_left.is_some(); + // Convention: + // flag_hardcoded = 0: left input = m[arg_a..arg_a+8] (split as [arg_a..+4], [arg_a+4..+8]) + // flag_hardcoded = 1: left input = m[offset..offset+4] | m[arg_a..arg_a+4] + // (i.e. 
arg_a now points to a 4-element data digest, and the first 4 + // elements come from the hardcoded prefix at `offset`) + let left_first_addr = hardcoded_offset_left.unwrap_or(arg_a_usize); + let left_second_addr = if flag_hardcoded { + arg_a_usize + } else { + arg_a_usize + HALF_DIGEST_LEN + }; + let arg0_first = ctx.memory.get_slice(left_first_addr, HALF_DIGEST_LEN)?; + let arg0_second = ctx.memory.get_slice(left_second_addr, HALF_DIGEST_LEN)?; let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST_LEN)?; let mut input = [F::ZERO; DIGEST_LEN * 2]; - input[..DIGEST_LEN].copy_from_slice(&arg0); + input[..HALF_DIGEST_LEN].copy_from_slice(&arg0_first); + input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); input[DIGEST_LEN..].copy_from_slice(&arg1); let output = poseidon16_compress(input); - let res_a: [F; DIGEST_LEN] = output[..DIGEST_LEN].try_into().unwrap(); + if half_output { + ctx.memory + .set_slice(index_res_a.to_usize(), &output[..HALF_DIGEST_LEN])?; + } else { + ctx.memory.set_slice(index_res_a.to_usize(), &output)?; + } - ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; + let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); trace.columns[POSEIDON_16_COL_FLAG].push(F::ONE); - trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); trace.columns[POSEIDON_16_COL_INDEX_INPUT_RIGHT].push(arg_b); trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES].push(index_res_a); + trace.columns[POSEIDON_16_COL_FLAG_HALF_OUTPUT].push(if half_output { F::ONE } else { F::ZERO }); + trace.columns[POSEIDON_16_COL_FLAG_HARDCODED_LEFT].push(if flag_hardcoded { F::ONE } else { F::ZERO }); + trace.columns[POSEIDON_16_COL_OFFSET_LEFT_HARDCODED].push(F::from_usize(hardcoded_offset_left_val)); + trace.columns[POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST].push(F::from_usize(left_first_addr)); + trace.columns[POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND].push(F::from_usize(left_second_addr)); for (i, value) in input.iter().enumerate() { 
trace.columns[POSEIDON_16_COL_INPUT_START + i].push(*value); } + // Non-committed columns + trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); + let precompile_data = POSEIDON_PRECOMPILE_DATA + + POSEIDON_HALF_OUTPUT_SHIFT * (half_output as usize) + + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * (flag_hardcoded as usize) + + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val; + trace.columns[POSEIDON_16_COL_PRECOMPILE_DATA].push(F::from_usize(precompile_data)); // the rest of the trace is filled at the end of the execution (to get parallelism + SIMD) @@ -205,7 +280,7 @@ impl Air for Poseidon16Precompile { num_cols_poseidon_16() } fn degree_air(&self) -> usize { - 9 + 10 // Last 4 output constraints are gated by (1 - half_output), raising degree from 9 to 10 } fn low_degree_air(&self) -> Option<(usize, usize)> { // Each partial round contributes one `assert_eq_low` per round (1 S-box / round), of degree 3 (= the "low" degree part) @@ -215,7 +290,7 @@ impl Air for Poseidon16Precompile { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 76 + BUS as usize + 80 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon1Cols16 = { @@ -227,29 +302,36 @@ impl Air for Poseidon16Precompile { unsafe { std::ptr::read(&shorts[0]) } }; - // Bus data: [POSEIDON_PRECOMPILE_DATA (constant), a, b, res] + let precompile_data_reconstructed = AB::IF::ONE + + cols.flag_half_output * AB::F::from_usize(POSEIDON_HALF_OUTPUT_SHIFT) + + cols.flag_hardcoded_left * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT) + + cols.flag_hardcoded_left + * cols.offset_hardcoded_left + * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT); + + // effective_index_left_first = index_a * (1 - flag_hardcoded_left_4) + offset * flag_hardcoded_left_4 + let one_minus_flag_hardcoded_left = AB::IF::ONE - cols.flag_hardcoded_left; + let index_a = + cols.effective_index_left_second - one_minus_flag_hardcoded_left * 
AB::F::from_usize(HALF_DIGEST_LEN); + + // Bus data: [precompile_data, a, b, res] if BUS { builder.eval_virtual_column(eval_virtual_bus_column::( extra_data, - cols.flag, - &[ - AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), - cols.index_a, - cols.index_b, - cols.index_res, - ], + cols.flag_active, + &[precompile_data_reconstructed, index_a, cols.index_b, cols.index_res], )); } else { - builder.declare_values(std::slice::from_ref(&cols.flag)); - builder.declare_values(&[ - AB::IF::from_usize(POSEIDON_PRECOMPILE_DATA), - cols.index_a, - cols.index_b, - cols.index_res, - ]); + builder.declare_values(std::slice::from_ref(&cols.flag_active)); + builder.declare_values(&[precompile_data_reconstructed, index_a, cols.index_b, cols.index_res]); } - builder.assert_bool(cols.flag); + builder.assert_bool(cols.flag_active); + builder.assert_bool(cols.flag_half_output); + builder.assert_bool(cols.flag_hardcoded_left); + + builder.assert_zero(cols.flag_hardcoded_left * (cols.offset_hardcoded_left - cols.effective_index_left_first)); + builder.assert_zero(one_minus_flag_hardcoded_left * (index_a - cols.effective_index_left_first)); eval_poseidon1_16(builder, &cols) } @@ -258,10 +340,14 @@ impl Air for Poseidon16Precompile { #[repr(C)] #[derive(Debug)] pub(super) struct Poseidon1Cols16 { - pub flag: T, - pub index_a: T, + pub flag_active: T, // 0 = padding, 1 = active pub index_b: T, pub index_res: T, + pub flag_half_output: T, + pub flag_hardcoded_left: T, + pub offset_hardcoded_left: T, + pub effective_index_left_first: T, + pub effective_index_left_second: T, pub inputs: [T; WIDTH], pub beginning_full_rounds: [[T; WIDTH]; HALF_INITIAL_FULL_ROUNDS], @@ -329,6 +415,7 @@ fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16 usize { size_of::>() } +pub const fn num_cols_total_poseidon_16() -> usize { + // +2 for non-committed columns: POSEIDON_16_COL_INDEX_INPUT_LEFT, POSEIDON_16_COL_PRECOMPILE_DATA + num_cols_poseidon_16() + 2 +} + #[inline] fn eval_2_full_rounds_16( state: 
&mut [AB::IF; WIDTH], @@ -368,6 +460,7 @@ fn eval_last_2_full_rounds_16( outputs: &[AB::IF; WIDTH / 2], round_constants_1: &[F; WIDTH], round_constants_2: &[F; WIDTH], + flag_half_output: AB::IF, builder: &mut AB, ) { for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { @@ -384,8 +477,15 @@ fn eval_last_2_full_rounds_16( for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { *state_i += *init_state_i; } - for (state_i, output_i) in state.iter_mut().zip(outputs) { - builder.assert_eq(*state_i, *output_i); + let one_minus_flag_half_output = AB::IF::ONE - flag_half_output; + for (idx, (state_i, output_i)) in state.iter_mut().zip(outputs).enumerate() { + if idx < HALF_DIGEST_LEN { + // First 4 outputs: always constrained + builder.assert_eq(*state_i, *output_i); + } else { + // Last 4 outputs: constrained only when half_output = 0 + builder.assert_zero(one_minus_flag_half_output * (*state_i - *output_i)); + } *state_i = *output_i; } } diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index f0d637346..b914e379b 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -7,11 +7,12 @@ MAX_N_SIGS = 2**15 MAX_N_DUPS = 2**15 -INNER_PUB_MEM_SIZE = 2**INNER_PUBLIC_MEMORY_LOG_SIZE # = DIGEST_LEN - INPUT_DATA_SIZE_PADDED = INPUT_DATA_SIZE_PADDED_PLACEHOLDER INPUT_DATA_NUM_CHUNKS = INPUT_DATA_SIZE_PADDED / DIGEST_LEN -BYTECODE_CLAIM_OFFSET = 1 + DIGEST_LEN + MESSAGE_LEN + 2 + N_MERKLE_CHUNKS +# data_buf layout: n_sigs(1) + slice_hash(8) + message + merkle_chunks_for_slot +# + tweaks_hash(8) + bytecode_claim_padded + bytecode_hash_domsep(8) +TWEAKS_HASH_OFFSET = 1 + DIGEST_LEN + MESSAGE_LEN + N_MERKLE_CHUNKS +BYTECODE_CLAIM_OFFSET = TWEAKS_HASH_OFFSET + DIGEST_LEN BYTECODE_HASH_DOMSEP_OFFSET = BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE_PADDED BYTECODE_SUMCHECK_PROOF_SIZE = BYTECODE_SUMCHECK_PROOF_SIZE_PLACEHOLDER @@ -21,6 +22,9 @@ def main(): pub_mem = 0 # See hashing.py for the memory layout 
build_preamble_memory() + tweak_table: Mut = TWEAK_TABLE_ADDR + hint_witness("tweak_table", tweak_table) + data_buf = Array(INPUT_DATA_SIZE_PADDED) hint_witness("input_data", data_buf) n_sigs = data_buf[0] @@ -28,10 +32,8 @@ def main(): assert n_sigs - 1 < MAX_N_SIGS pubkeys_hash_expected = data_buf + 1 message = pubkeys_hash_expected + DIGEST_LEN - slot_ptr = message + MESSAGE_LEN - slot_lo = slot_ptr[0] - slot_hi = slot_ptr[1] - merkle_chunks_for_slot = slot_ptr + 2 + merkle_chunks_for_slot = message + MESSAGE_LEN + tweaks_hash_expected = data_buf + TWEAKS_HASH_OFFSET bytecode_claim_output = data_buf + BYTECODE_CLAIM_OFFSET bytecode_hash_domsep = data_buf + BYTECODE_HASH_DOMSEP_OFFSET @@ -53,15 +55,17 @@ def main(): aggregate_sizes = Array(n_recursions) hint_witness("aggregate_sizes", aggregate_sizes) + computed_tweaks_hash = slice_hash(tweak_table, TWEAK_TABLE_SIZE_FE_PADDED / DIGEST_LEN) + copy_8(computed_tweaks_hash, tweaks_hash_expected) + # 1->1 optimization if n_recursions == 1: assert n_dup == 0 if n_raw_xmss == 0: inner_data_buf = build_inner_data_buf( - n_sigs, pubkeys_hash_expected, message, slot_lo, slot_hi, - merkle_chunks_for_slot, bytecode_hash_domsep, + n_sigs, pubkeys_hash_expected, message, + merkle_chunks_for_slot, tweaks_hash_expected, bytecode_hash_domsep, ) - inner_pub_mem = Array(INNER_PUB_MEM_SIZE) copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) bytecode_claims = Array(2) @@ -75,7 +79,7 @@ def main(): return # General path - computed_pubkeys_hash = slice_hash_with_iv_dynamic_unroll(all_pubkeys, n_sigs * DIGEST_LEN, MAX_LOG_MEMORY_SIZE) + computed_pubkeys_hash = slice_hash_with_iv_dynamic_unroll(all_pubkeys, n_sigs * PUB_KEY_SIZE, MAX_LOG_MEMORY_SIZE) copy_8(computed_pubkeys_hash, pubkeys_hash_expected) # Buffer for partition verification @@ -87,9 +91,9 @@ def main(): idx = raw_indices[i] assert idx < n_total buffer[idx] = i - # Verify raw XMSS signatures - pk = all_pubkeys + idx * DIGEST_LEN - 
xmss_verify(pk, message, slot_lo, slot_hi, merkle_chunks_for_slot) + # Verify raw XMSS signatures. + pk = all_pubkeys + idx * PUB_KEY_SIZE + xmss_verify(pk, message, merkle_chunks_for_slot) counter: Mut = n_raw_xmss @@ -109,7 +113,7 @@ def main(): assert idx0 < n_total buffer[idx0] = counter counter += 1 - pk0 = all_pubkeys + idx0 * DIGEST_LEN + pk0 = all_pubkeys + idx0 * PUB_KEY_SIZE running_hash: Mut = Array(DIGEST_LEN) poseidon16_compress(ZERO_VEC_PTR, pk0, running_hash) @@ -118,14 +122,14 @@ def main(): assert idx < n_total buffer[idx] = counter counter += 1 - pk = all_pubkeys + idx * DIGEST_LEN + pk = all_pubkeys + idx * PUB_KEY_SIZE new_hash = Array(DIGEST_LEN) poseidon16_compress(running_hash, pk, new_hash) running_hash = new_hash inner_data_buf = build_inner_data_buf( - n_sub, running_hash, message, slot_lo, slot_hi, - merkle_chunks_for_slot, bytecode_hash_domsep, + n_sub, running_hash, message, + merkle_chunks_for_slot, tweaks_hash_expected, bytecode_hash_domsep, ) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) @@ -198,18 +202,18 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou return @inline -def build_inner_data_buf(n_sub, pubkeys_hash, message, slot_lo, slot_hi, merkle_chunks_for_slot, bytecode_hash_domsep): +def build_inner_data_buf(n_sub, pubkeys_hash, message, merkle_chunks_for_slot, tweaks_hash, bytecode_hash_domsep): inner_data_buf = Array(INPUT_DATA_SIZE_PADDED) inner_data_buf[0] = n_sub copy_8(pubkeys_hash, inner_data_buf + 1) inner_msg = inner_data_buf + 1 + DIGEST_LEN - debug_assert(MESSAGE_LEN == 9) - copy_9(message, inner_msg) - inner_msg[MESSAGE_LEN] = slot_lo - inner_msg[MESSAGE_LEN + 1] = slot_hi + copy_8(message, inner_msg) for k in unroll(0, N_MERKLE_CHUNKS): - inner_msg[MESSAGE_LEN + 2 + k] = merkle_chunks_for_slot[k] + inner_msg[MESSAGE_LEN + k] = merkle_chunks_for_slot[k] + copy_8(tweaks_hash, inner_data_buf + 
TWEAKS_HASH_OFFSET) hint_witness("inner_bytecode_claim", inner_data_buf + BYTECODE_CLAIM_OFFSET) + for k in unroll(BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE, BYTECODE_HASH_DOMSEP_OFFSET): + inner_data_buf[k] = 0 copy_8(bytecode_hash_domsep, inner_data_buf + BYTECODE_HASH_DOMSEP_OFFSET) for k in unroll(BYTECODE_HASH_DOMSEP_OFFSET + DIGEST_LEN, INPUT_DATA_SIZE_PADDED): inner_data_buf[k] = 0 diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index ff43dc183..7d8b11eef 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -11,7 +11,7 @@ use std::sync::OnceLock; use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements}; use tracing::instrument; use utils::Counter; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, V_GRINDING, W}; +use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, XMSS_DIGEST_LEN}; use crate::{MERKLE_LEVELS_PER_CHUNK_FOR_SLOT, N_MERKLE_CHUNKS_FOR_SLOT, NUM_REPEATED_ONES, ZERO_VEC_LEN}; @@ -27,14 +27,17 @@ pub fn init_aggregation_bytecode() { BYTECODE.get_or_init(compile_main_program_self_referential); } -fn compile_main_program(inner_program_log_size: usize, bytecode_zero_eval: F) -> Bytecode { - let bytecode_point_n_vars = inner_program_log_size + log2_ceil_usize(N_INSTRUCTION_COLUMNS); +fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode { + let bytecode_point_n_vars = program_log_size + log2_ceil_usize(N_INSTRUCTION_COLUMNS); let claim_data_size = (bytecode_point_n_vars + 1) * DIMENSION; let claim_data_size_padded = claim_data_size.next_multiple_of(DIGEST_LEN); + // input_data_buf layout (part of the witness, "hinted" then hashed to a single digest that should match public input): + // n_sigs(1) + pubkeys_hash(8) + message + merkle_chunks_for_slot + // + tweaks_hash(8) + bytecode_claim_padded + 
bytecode_hash_domsep(8) let input_data_size = - 1 + DIGEST_LEN + MESSAGE_LEN_FE + 2 + N_MERKLE_CHUNKS_FOR_SLOT + claim_data_size_padded + DIGEST_LEN; + 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN + claim_data_size_padded + DIGEST_LEN; let input_data_size_padded = input_data_size.next_multiple_of(DIGEST_LEN); - let replacements = build_replacements(inner_program_log_size, bytecode_zero_eval, input_data_size_padded); + let replacements = build_replacements(program_log_size, bytecode_zero_eval, input_data_size_padded); let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) .join("main.py") @@ -344,16 +347,20 @@ fn build_replacements( // XMSS-specific replacements replacements.insert("V_PLACEHOLDER".to_string(), V.to_string()); - replacements.insert("V_GRINDING_PLACEHOLDER".to_string(), V_GRINDING.to_string()); replacements.insert("W_PLACEHOLDER".to_string(), W.to_string()); replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MESSAGE_LEN_FE.to_string()); replacements.insert("RANDOMNESS_LEN_PLACEHOLDER".to_string(), RANDOMNESS_LEN_FE.to_string()); + replacements.insert( + "PUBLIC_PARAM_LEN_FE_PLACEHOLDER".to_string(), + PUBLIC_PARAM_LEN_FE.to_string(), + ); replacements.insert( "MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER".to_string(), MERKLE_LEVELS_PER_CHUNK_FOR_SLOT.to_string(), ); + replacements.insert("XMSS_DIGEST_LEN_PLACEHOLDER".to_string(), XMSS_DIGEST_LEN.to_string()); // Bytecode zero eval replacements.insert( diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 9f0f569ab..a2062162e 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -7,7 +7,7 @@ use lean_prover::verify_execution::verify_execution; use lean_vm::*; use tracing::instrument; use utils::{build_prover_state, get_poseidon16, 
poseidon_compress_slice, poseidon16_compress_pair}; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, SIG_SIZE_FE, XmssPublicKey, XmssSignature, slot_to_field_elements}; +use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUB_KEY_FLAT_SIZE, V, W, WOTS_SIG_SIZE_FE, XmssPublicKey, XmssSignature}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -20,11 +20,31 @@ mod compilation; const MERKLE_LEVELS_PER_CHUNK_FOR_SLOT: usize = 4; const N_MERKLE_CHUNKS_FOR_SLOT: usize = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK_FOR_SLOT; +const CHAIN_LENGTH: usize = 1 << W; -// preamble memory layout: see `build_preamble_memory` in utils.py +// Tweak types (must match xmss crate) +const TWEAK_TYPE_CHAIN: usize = 0; +const TWEAK_TYPE_WOTS_PK: usize = 1; +const TWEAK_TYPE_MERKLE: usize = 2; +const TWEAK_TYPE_ENCODING: usize = 3; + +/// Number of tweaks in the table: 1 encoding + V*CHAIN_LENGTH chains + 1 wots_pk + LOG_LIFETIME merkle +const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; +/// All, tweaks are stored as a 4-FE slot [tw[0], tw[1], 0, 0]. +const TWEAK_SLOT_SIZE: usize = 4; +const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE).next_multiple_of(DIGEST_LEN); + +const TWEAKS_HASHING_USE_IV: bool = false; // fixed size → no IV needed + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Digest(pub [F; DIGEST_LEN]); + +// preamble memory layout: see `build_preamble_memory` in utils.py: +// [000.. (ZERO_VEC_LEN)][10000000 (fiat-shamir domain sep)][10000 (one in extension field)][111... 
(NUM_REPEATED_ONES)][tweak table] pub const ZERO_VEC_LEN: usize = 16; pub const NUM_REPEATED_ONES: usize = 32; -pub const PREAMBLE_MEMORY_LEN: usize = ZERO_VEC_LEN + DIGEST_LEN + DIMENSION + NUM_REPEATED_ONES; +pub const PREAMBLE_MEMORY_LEN: usize = + ZERO_VEC_LEN + DIGEST_LEN + DIMENSION + NUM_REPEATED_ONES + TWEAK_TABLE_SIZE_FE_PADDED; #[derive(Debug, Clone)] pub struct AggregationTopology { @@ -59,9 +79,49 @@ pub(crate) fn count_signers(topology: &AggregationTopology) -> usize { topology.raw_xmss + child_count - topology.overlap * n_overlaps } -pub fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] { - let flat: Vec = pub_keys.iter().flat_map(|pk| pk.merkle_root.iter().copied()).collect(); - poseidon_compress_slice(&flat, true) +pub fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> Digest { + let flat: Vec = pub_keys.iter().flat_map(|pk| pk.flaten().into_iter()).collect(); + Digest(poseidon_compress_slice(&flat, true)) +} + +fn make_tweak_values(tweak_type: usize, sub_position: usize, index: u32) -> [F; 2] { + let index_lo = (index & 0xFFFF) as usize; + let index_hi = (index >> 16) as usize; + [ + F::from_usize((tweak_type << 26) + (index_hi << 10) + sub_position), + F::from_usize(index_lo), + ] +} + +/// Tweak slots are 4-FE [tw[0], tw[1], 0, 0] +fn compute_tweak_table(slot: u32) -> Vec { + let mut table = Vec::new(); + + let push_padded = |table: &mut Vec, tweak_type: usize, sub_position: usize, index: u32| { + table.extend(make_tweak_values(tweak_type, sub_position, index)); + table.extend(std::iter::repeat_n(F::ZERO, 2)); + }; + + // Encoding tweak + push_padded(&mut table, TWEAK_TYPE_ENCODING, 0, slot); + + // Chain tweaks + for i in 0..V { + for s in 0..CHAIN_LENGTH { + push_padded(&mut table, TWEAK_TYPE_CHAIN, i * CHAIN_LENGTH + s, slot); + } + } + + // WOTS_PK tweak + push_padded(&mut table, TWEAK_TYPE_WOTS_PK, 0, slot); + + // Merkle tweaks + for level in 0..LOG_LIFETIME { + let parent_index = ((slot as u64) >> (level + 1)) as u32; + 
push_padded(&mut table, TWEAK_TYPE_MERKLE, level + 1, parent_index); + } + table.resize(TWEAK_TABLE_SIZE_FE_PADDED, F::ZERO); + table } fn compute_merkle_chunks_for_slot(slot: u32) -> Vec { @@ -86,6 +146,7 @@ fn build_input_data( slice_hash: &[F; DIGEST_LEN], message: &[F; MESSAGE_LEN_FE], slot: u32, + tweaks_hash: &[F; DIGEST_LEN], bytecode_claim_output: &[F], bytecode_hash: &[F; DIGEST_LEN], ) -> Vec { @@ -93,10 +154,8 @@ fn build_input_data( data.push(F::from_usize(n_sigs)); data.extend_from_slice(slice_hash); data.extend_from_slice(message); - let [slot_lo, slot_hi] = slot_to_field_elements(slot); - data.push(slot_lo); - data.push(slot_hi); data.extend(compute_merkle_chunks_for_slot(slot)); + data.extend_from_slice(tweaks_hash); data.extend_from_slice(bytecode_claim_output); // Pad the bytecode claim itself up to DIGEST_LEN let claim_padding = bytecode_claim_output.len().next_multiple_of(DIGEST_LEN) - bytecode_claim_output.len(); @@ -112,14 +171,11 @@ pub(crate) fn hash_input_data(data: &[F]) -> [F; DIGEST_LEN] { poseidon_compress_slice(data, true) } -fn encode_xmss_signature(sig: &XmssSignature) -> Vec { +fn encode_wots_signature(sig: &XmssSignature) -> Vec { let mut data = vec![]; data.extend(sig.wots_signature.randomness.to_vec()); data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec())); - for neighbor in &sig.merkle_proof { - data.extend(neighbor.to_vec()); - } - assert_eq!(data.len(), SIG_SIZE_FE); + assert_eq!(data.len(), WOTS_SIG_SIZE_FE); data } @@ -164,12 +220,15 @@ impl AggregatedXMSS { assert_eq!(bytecode_claim_output.len(), bytecode_claim_size); let slice_hash = hash_pubkeys(pub_keys); + let tweak_table = compute_tweak_table(slot); + let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); build_input_data( pub_keys.len(), - &slice_hash, + &slice_hash.0, message, slot, + &tweaks_hash, &bytecode_claim_output, &bytecode.hash, ) @@ -205,7 +264,7 @@ pub fn xmss_aggregate( log_inv_rate: usize, ) -> 
(Vec, AggregatedXMSS) { raw_xmss.sort_by(|(a, _), (b, _)| a.cmp(b)); - raw_xmss.dedup_by(|(a, _), (b, _)| a.merkle_root == b.merkle_root); + raw_xmss.dedup_by(|(a, _), (b, _)| a == b); let n_recursions = children.len(); let raw_count = raw_xmss.len(); @@ -225,6 +284,10 @@ pub fn xmss_aggregate( global_pub_keys.dedup(); let n_sigs = global_pub_keys.len(); + // Compute tweak table and its hash + let tweak_table = compute_tweak_table(slot); + let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); + // Verify child proofs let mut child_input_data = vec![]; let mut child_input_hashes = vec![]; @@ -242,7 +305,7 @@ pub fn xmss_aggregate( // Bytecode sumcheck reduction let (bytecode_claim_output, bytecode_point, final_sumcheck_transcript) = if n_recursions > 0 { - let bytecode_claim_offset = 1 + DIGEST_LEN + 2 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; + let bytecode_claim_offset = 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; let mut claims = vec![]; for (i, _child) in children.iter().enumerate() { let first_claim = extract_bytecode_claim_from_input_data( @@ -319,9 +382,10 @@ pub fn xmss_aggregate( let slice_hash = hash_pubkeys(&global_pub_keys); let pub_input_data = build_input_data( n_sigs, - &slice_hash, + &slice_hash.0, message, slot, + &tweaks_hash, &bytecode_claim_output, &bytecode.hash, ); @@ -330,7 +394,14 @@ pub fn xmss_aggregate( let mut claimed: HashSet = HashSet::new(); let mut dup_pub_keys: Vec = Vec::new(); - let xmss_signatures: Vec> = raw_xmss.iter().map(|(_, sig)| encode_xmss_signature(sig)).collect(); + // Raw XMSS data is split into two named hints — `wots` (randomness | chain_tips, + // one entry per signature) and `xmss_merkle_node` (one entry per 4-FE merkle node, + // flattened in the order `do_4_merkle_levels` consumes them at runtime). 
+ let wots_blobs: Vec> = raw_xmss.iter().map(|(_, sig)| encode_wots_signature(sig)).collect(); + let xmss_merkle_node_blobs: Vec> = raw_xmss + .iter() + .flat_map(|(_, sig)| sig.merkle_proof.iter().map(|d| d.to_vec())) + .collect(); // Raw XMSS indices. let raw_indices: Vec = raw_xmss @@ -347,7 +418,7 @@ pub fn xmss_aggregate( let mut inner_bytecode_claim_blobs = Vec::with_capacity(n_recursions); let mut proof_transcript_blobs = Vec::with_capacity(n_recursions); - let claim_offset_in_input = 1 + DIGEST_LEN + 2 + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT; + let claim_offset_in_input = 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; let claim_size_padded = bytecode_claim_size.next_multiple_of(DIGEST_LEN); // Sources 1..n_recursions: recursive children @@ -375,12 +446,12 @@ pub fn xmss_aggregate( let n_dup = dup_pub_keys.len(); - let mut pubkeys_blob: Vec = Vec::with_capacity((n_sigs + n_dup) * DIGEST_LEN); + let mut pubkeys_blob: Vec = Vec::with_capacity((n_sigs + n_dup) * PUB_KEY_FLAT_SIZE); for pk in &global_pub_keys { - pubkeys_blob.extend_from_slice(&pk.merkle_root); + pubkeys_blob.extend_from_slice(&pk.flaten()); } for pk in &dup_pub_keys { - pubkeys_blob.extend_from_slice(&pk.merkle_root); + pubkeys_blob.extend_from_slice(&pk.flaten()); } let (merkle_leaf_blobs, merkle_path_blobs): (Vec>, Vec>) = child_raw_proofs @@ -422,10 +493,12 @@ pub fn xmss_aggregate( .collect(), ); hints.insert("proof_transcript".to_string(), proof_transcript_blobs); - hints.insert("xmss_signature".to_string(), xmss_signatures); + hints.insert("wots".to_string(), wots_blobs); + hints.insert("xmss_merkle_node".to_string(), xmss_merkle_node_blobs); hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs); hints.insert("merkle_path".to_string(), merkle_path_blobs); hints.insert("aggregate_sizes".to_string(), vec![aggregate_sizes]); + hints.insert("tweak_table".to_string(), vec![tweak_table]); if n_recursions > 0 { 
hints.insert("bytecode_sumcheck_proof".to_string(), vec![final_sumcheck_transcript]); } diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 752d9b359..3999f4548 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -360,6 +360,11 @@ def set_to_5_zeros(a): dot_product_ee(a, ONE_EF_PTR, zero_ptr) return +@inline +def copy_6(a, b): + dot_product_ee(a, ONE_EF_PTR, b) + a[5] = b[5] + return @inline def set_to_7_zeros(a): @@ -385,12 +390,6 @@ def copy_8(a, b): return -@inline -def copy_9(a, b): - dot_product_ee(a, ONE_EF_PTR, b) - dot_product_ee(a + (9 - DIM), ONE_EF_PTR, b + (9 - DIM)) - return - @inline def copy_16(a, b): dot_product_ee(a, ONE_EF_PTR, b) diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py index 9c51fd5ec..39090e6c6 100644 --- a/crates/rec_aggregation/xmss_aggregate.py +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -2,42 +2,62 @@ from utils import * V = V_PLACEHOLDER -V_GRINDING = V_GRINDING_PLACEHOLDER W = W_PLACEHOLDER CHAIN_LENGTH = 2**W TARGET_SUM = TARGET_SUM_PLACEHOLDER LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER RANDOMNESS_LEN = RANDOMNESS_LEN_PLACEHOLDER -SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN -NUM_ENCODING_FE = div_ceil((V + V_GRINDING), (24 / W)) # 24 should be divisible by W (works for W=2,3,4) +PUBLIC_PARAM_LEN_FE = PUBLIC_PARAM_LEN_FE_PLACEHOLDER +XMSS_DIGEST_LEN = XMSS_DIGEST_LEN_PLACEHOLDER +PUB_KEY_SIZE = XMSS_DIGEST_LEN + PUBLIC_PARAM_LEN_FE +PP_IN_LEFT = DIGEST_LEN - XMSS_DIGEST_LEN +WOTS_SIG_SIZE = RANDOMNESS_LEN + V * XMSS_DIGEST_LEN +# wots_public_key pair stride: each pair occupies 10 cells `[leading_0 | tip_a(4) | tip_b(4) | trailing_0]`. In order to be able to use copy_5 on both sides. 
+WOTS_PK_PAIR_STRIDE = 2 + 2 * XMSS_DIGEST_LEN +NUM_ENCODING_FE = div_ceil(V, (24 / W)) MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER N_MERKLE_CHUNKS = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK - +INNER_PUB_MEM_SIZE = 2**INNER_PUBLIC_MEMORY_LOG_SIZE # = DIGEST_LEN +TWEAK_TABLE_ADDR = PREAMBLE_MEMORY_END + +# Tweak table layout: all tweaks are stored as a 4-FE slot [tw[0], tw[1], 0, 0] +TWEAK_LEN = 4 # stride / slot size for non-encoding tweaks +N_TWEAKS = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME +TWEAK_TABLE_SIZE_FE_PADDED = next_multiple_of(N_TWEAKS * TWEAK_LEN, DIGEST_LEN) +TWEAK_ENCODING_OFFSET = 0 +TWEAK_CHAIN_OFFSET = TWEAK_ENCODING_OFFSET + TWEAK_LEN # just after the encoding tweak +TWEAK_WOTS_PK_OFFSET = TWEAK_CHAIN_OFFSET + V * CHAIN_LENGTH * TWEAK_LEN +TWEAK_MERKLE_OFFSET = TWEAK_WOTS_PK_OFFSET + TWEAK_LEN @inline -def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): - # signature: randomness | chain_tips | merkle_path - # return the hashed xmss public key - signature = Array(SIG_SIZE) - hint_witness("xmss_signature", signature) - randomness = signature - chain_starts = signature + RANDOMNESS_LEN - merkle_path = chain_starts + V * DIGEST_LEN - - # 1) We encode message_hash + randomness into the layer of the hypercube with target sum = TARGET_SUM - +def xmss_verify(pub_key, message, merkle_chunks): + wots = Array(WOTS_SIG_SIZE) + hint_witness("wots", wots) + + public_param = pub_key + XMSS_DIGEST_LEN + randomness = wots + chain_starts = wots + RANDOMNESS_LEN + + # 1) Encode: poseidon16_compress(message[0:8], [randomness(6) | tweak_encoding(2)) + # poseidon16_compress(pre_compressed, [pp(4) | zeros(4)]) + encoding_tweak = TWEAK_TABLE_ADDR + TWEAK_ENCODING_OFFSET a_input_right = Array(DIGEST_LEN) - b_input = Array(DIGEST_LEN * 2) - a_input_right[0] = message[DIGEST_LEN] - copy_7(randomness, a_input_right + 1) - poseidon16_compress(message, a_input_right, b_input) - b_input[DIGEST_LEN] = slot_lo - b_input[DIGEST_LEN + 1] = 
slot_hi - copy_6(merkle_root, b_input + DIGEST_LEN + 2) + copy_6(randomness, a_input_right) + a_input_right[6] = encoding_tweak[0] + a_input_right[7] = encoding_tweak[1] + pre_compressed = Array(DIGEST_LEN) + poseidon16_compress(message, a_input_right, pre_compressed) + + public_params_paded_buff = Array(DIGEST_LEN + 2) # 0 [public_param(4) | zeros(4)] 0 + copy_5(public_param - 1, public_params_paded_buff) + set_to_5_zeros(public_params_paded_buff + 5) + public_params_paded = public_params_paded_buff + 1 encoding_fe = Array(DIGEST_LEN) - poseidon16_compress(b_input, b_input + DIGEST_LEN, encoding_fe) + poseidon16_compress(pre_compressed, public_params_paded, encoding_fe) + # Decompose the encoding into chunks of 2*W bits. Each chunk packs the chain step + # counts of two consecutive WOTS chains: chunk i = step_{2i} + CHAIN_LENGTH * step_{2i+1}. encoding = Array(NUM_ENCODING_FE * 24 / (2 * W)) hint_decompose_bits_xmss(encoding, encoding_fe, NUM_ENCODING_FE, 2 * W) @@ -51,86 +71,141 @@ def xmss_verify(merkle_root, message, slot_lo, slot_hi, merkle_chunks): for j in unroll(1, 24 / (2 * W)): partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH**2) ** j - # p = 2^31 - 2^24 + 1, so inv(2^24) = -127 (mod p). + # p = 2^31 - 2^24 + 1 = 127.2^24 + 1, so inv(2^24) = -127 (mod p). 
# Deduce remaining_i from partial_sum + remaining_i * 2^24 == encoding_fe[i]: # remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 127 remaining_i = (partial_sum - encoding_fe[i]) * 127 - assert remaining_i < 2**7 - 1 # ensures uniformity + prevent overflow + assert remaining_i < 127 # ensures uniformity + prevent overflow - # grinding - debug_assert(V_GRINDING % 2 == 0) - debug_assert(V % 2 == 0) - for i in unroll(V / 2, (V + V_GRINDING) / 2): - assert encoding[i] == CHAIN_LENGTH**2 - 1 + debug_assert(V % 2 == 0) + wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) target_sum: Mut = 0 - - wots_public_key = Array(V * DIGEST_LEN) - - local_zero_buff = Array(DIGEST_LEN) - set_to_8_zeros(local_zero_buff) - for i in unroll(0, V / 2): - # num_hashes = (CHAIN_LENGTH - 1) - encoding[i] - chain_start = chain_starts + i * (DIGEST_LEN * 2) - chain_end = wots_public_key + i * (DIGEST_LEN * 2) - pair_chain_length_sum_ptr = Array(1) + chain_start_a = chain_starts + (2 * i) * XMSS_DIGEST_LEN + chain_start_b = chain_starts + (2 * i + 1) * XMSS_DIGEST_LEN + chain_end_a = wots_public_key + i * WOTS_PK_PAIR_STRIDE + 1 + chain_end_b = chain_end_a + XMSS_DIGEST_LEN + tweaks_a = TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i) * CHAIN_LENGTH * TWEAK_LEN + tweaks_b = TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i + 1) * CHAIN_LENGTH * TWEAK_LEN + pair_sum_ptr = Array(1) + match_range( - encoding[i], range(0, CHAIN_LENGTH**2), lambda n: chain_hash(chain_start, n, chain_end, pair_chain_length_sum_ptr, local_zero_buff) + encoding[i], + range(0, CHAIN_LENGTH**2), + lambda n: chain_hash_pair( + chain_start_a, + chain_start_b, + n, + chain_end_a, + chain_end_b, + tweaks_a, + tweaks_b, + public_params_paded, + pair_sum_ptr, + ), ) - target_sum += pair_chain_length_sum_ptr[0] + target_sum += pair_sum_ptr[0] assert target_sum == TARGET_SUM - wots_pubkey_hashed = slice_hash(wots_public_key, V) - - xmss_merkle_verify(wots_pubkey_hashed, merkle_path, 
merkle_chunks, merkle_root) + merkle_leaf = wots_pk_hash(wots_public_key, public_param) + merkle_tweaks = TWEAK_TABLE_ADDR + TWEAK_MERKLE_OFFSET + xmss_merkle_verify(merkle_leaf, merkle_chunks, pub_key, public_param, merkle_tweaks) return @inline -def chain_hash(input_left, n, output_left, pair_chain_length_sum_ptr, local_zero_buff): - debug_assert(n < CHAIN_LENGTH**2) +def chain_hash_pa(input, n, output, chain_i_tweaks, chain_right): + starting_step = CHAIN_LENGTH - 1 - n + if n == 1: + first_tweak = chain_i_tweaks + starting_step * TWEAK_LEN + poseidon16_compress_half_hardcoded_left(input, chain_right, output, first_tweak) + else: + digests = Array(n * XMSS_DIGEST_LEN) + + # Hash 0: input → digests[0..4] + first_tweak = chain_i_tweaks + starting_step * TWEAK_LEN + poseidon16_compress_half_hardcoded_left(input, chain_right, digests, first_tweak) + + # Hashes 1..n-2: digests[(j-1)*4..j*4] → digests[j*4..(j+1)*4] + for j in unroll(1, n - 1): + cur_tweak = chain_i_tweaks + (starting_step + j) * TWEAK_LEN + poseidon16_compress_half_hardcoded_left( + digests + (j - 1) * XMSS_DIGEST_LEN, + chain_right, + digests + j * XMSS_DIGEST_LEN, + cur_tweak, + ) + + # Final hash: digests[(n-2)*4..(n-1)*4] → output + last_tweak = chain_i_tweaks + (starting_step + n - 1) * TWEAK_LEN + poseidon16_compress_half_hardcoded_left( + digests + (n - 2) * XMSS_DIGEST_LEN, chain_right, output, last_tweak + ) + return - raw_left = n % CHAIN_LENGTH - raw_right = (n - raw_left) / CHAIN_LENGTH - n_left = (CHAIN_LENGTH - 1) - raw_left - if n_left == 0: - copy_8(input_left, output_left) - elif n_left == 1: - poseidon16_compress(input_left, local_zero_buff, output_left) +@inline +def chain_hash_pair( + input_a, + input_b, + n, + output_a, + output_b, + tweaks_a, + tweaks_b, + chain_right, + pair_sum_ptr, +): + # Pair-encoded chain hash. 
`n` is a compile-time constant in [0, CHAIN_LENGTH^2) + raw_a = n % CHAIN_LENGTH + raw_b = (n - raw_a) / CHAIN_LENGTH + num_hashes_a = (CHAIN_LENGTH - 1) - raw_a + num_hashes_b = (CHAIN_LENGTH - 1) - raw_b + + if num_hashes_a == 0: + copy_5(input_a - 1, output_a - 1) else: - states_left = Array((n_left - 1) * DIGEST_LEN) - poseidon16_compress(input_left, local_zero_buff, states_left) - for i in unroll(1, n_left - 1): - poseidon16_compress(states_left + (i - 1) * DIGEST_LEN, local_zero_buff, states_left + i * DIGEST_LEN) - poseidon16_compress(states_left + (n_left - 2) * DIGEST_LEN, local_zero_buff, output_left) - - n_right = (CHAIN_LENGTH - 1) - raw_right - debug_assert(raw_right < CHAIN_LENGTH) - input_right = input_left + DIGEST_LEN - output_right = output_left + DIGEST_LEN - if n_right == 0: - copy_8(input_right, output_right) - elif n_right == 1: - poseidon16_compress(input_right, local_zero_buff, output_right) + chain_hash_pa(input_a, num_hashes_a, output_a, tweaks_a, chain_right) + + if num_hashes_b == 0: + copy_5(input_b, output_b) else: - states_right = Array((n_right - 1) * DIGEST_LEN) - poseidon16_compress(input_right, local_zero_buff, states_right) - for i in unroll(1, n_right - 1): - poseidon16_compress(states_right + (i - 1) * DIGEST_LEN, local_zero_buff, states_right + i * DIGEST_LEN) - poseidon16_compress(states_right + (n_right - 2) * DIGEST_LEN, local_zero_buff, output_right) + chain_hash_pa(input_b, num_hashes_b, output_b, tweaks_b, chain_right) + + pair_sum_ptr[0] = raw_a + raw_b + return + + +@inline +def wots_pk_hash(wots_public_key, public_param): + N_CHUNKS = V / 2 + states = Array((N_CHUNKS + 1) * DIGEST_LEN) + poseidon16_compress_hardcoded_left( + public_param, ZERO_VEC_PTR, states, TWEAK_TABLE_ADDR + TWEAK_WOTS_PK_OFFSET + ) + for i in unroll(0, N_CHUNKS): + poseidon16_compress( + states + i * DIGEST_LEN, + wots_public_key + i * WOTS_PK_PAIR_STRIDE + 1, + states + (i + 1) * DIGEST_LEN, + ) + + return states + N_CHUNKS * DIGEST_LEN - 
pair_chain_length_sum_ptr[0] = raw_left + raw_right +@inline +def set_buf_prefix_right(buf, public_param): + # Writes [pp(4)] to buf[0..4] — the RIGHT-input prefix. + for k in unroll(0, PP_IN_LEFT): + buf[k] = public_param[k] return @inline -def do_4_merkle_levels(b, state_in, path_chunk, state_out): - # Extract bits of b (compile-time; each division is exact so field div == integer div) +def do_4_merkle_levels(b, state_in, state_out, public_param, merkle_tweaks_chunk): b0 = b % 2 r1 = (b - b0) / 2 b1 = r1 % 2 @@ -139,71 +214,82 @@ def do_4_merkle_levels(b, state_in, path_chunk, state_out): r3 = (r2 - b2) / 2 b3 = r3 % 2 - temps = Array(3 * DIGEST_LEN) - - # Level 0: state_in -> temps - if b0 == 0: - poseidon16_compress(path_chunk, state_in, temps) + buf0_alloc = Array(XMSS_DIGEST_LEN * 2 + 2) + buf0 = buf0_alloc + 1 + if b0 == 1: + # state_in is the LEFT child → state_in[0..4] lands at buf0[0..4]. + copy_5(state_in - 1, buf0 - 1) + hint_witness("xmss_merkle_node", buf0 + XMSS_DIGEST_LEN) else: - poseidon16_compress(state_in, path_chunk, temps) - - # Level 1 - if b1 == 0: - poseidon16_compress(path_chunk + 1 * DIGEST_LEN, temps, temps + DIGEST_LEN) + # state_in is the RIGHT child → state_in[0..4] lands at buf0[4..8]. 
+ hint_witness("xmss_merkle_node", buf0) + copy_5(state_in, buf0 + XMSS_DIGEST_LEN) + + # Level 0 hash + buf1 = Array(XMSS_DIGEST_LEN * 2) + if b1 == 1: + poseidon16_compress_half_hardcoded_left(public_param, buf0, buf1, merkle_tweaks_chunk) + hint_witness("xmss_merkle_node", buf1 + XMSS_DIGEST_LEN) else: - poseidon16_compress(temps, path_chunk + 1 * DIGEST_LEN, temps + DIGEST_LEN) - - # Level 2 - if b2 == 0: - poseidon16_compress(path_chunk + 2 * DIGEST_LEN, temps + DIGEST_LEN, temps + 2 * DIGEST_LEN) + poseidon16_compress_half_hardcoded_left(public_param, buf0, buf1 + XMSS_DIGEST_LEN, merkle_tweaks_chunk) + hint_witness("xmss_merkle_node", buf1) + + # Level 1 hash → buf2 + buf2 = Array(XMSS_DIGEST_LEN * 2) + if b2 == 1: + poseidon16_compress_half_hardcoded_left(public_param, buf1, buf2, merkle_tweaks_chunk + 1 * TWEAK_LEN) + hint_witness("xmss_merkle_node", buf2 + XMSS_DIGEST_LEN) else: - poseidon16_compress(temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN) - - # Level 3: -> state_out - if b3 == 0: - poseidon16_compress(path_chunk + 3 * DIGEST_LEN, temps + 2 * DIGEST_LEN, state_out) + poseidon16_compress_half_hardcoded_left(public_param, buf1, buf2 + XMSS_DIGEST_LEN, merkle_tweaks_chunk + 1 * TWEAK_LEN) + hint_witness("xmss_merkle_node", buf2) + + # Level 2 hash → buf3 + buf3 = Array(XMSS_DIGEST_LEN * 2) + if b3 == 1: + poseidon16_compress_half_hardcoded_left(public_param, buf2, buf3, merkle_tweaks_chunk + 2 * TWEAK_LEN) + hint_witness("xmss_merkle_node", buf3 + XMSS_DIGEST_LEN) else: - poseidon16_compress(temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out) + poseidon16_compress_half_hardcoded_left(public_param, buf2, buf3 + XMSS_DIGEST_LEN, merkle_tweaks_chunk + 2 * TWEAK_LEN) + hint_witness("xmss_merkle_node", buf3) + + poseidon16_compress_half_hardcoded_left(public_param, buf3, state_out, merkle_tweaks_chunk + 3 * TWEAK_LEN) return @inline -def xmss_merkle_verify(leaf_digest, merkle_path, merkle_chunks, expected_root): - 
states = Array((N_MERKLE_CHUNKS - 1) * DIGEST_LEN) +def xmss_merkle_verify(leaf_digest, merkle_chunks, expected_root, public_param, merkle_tweaks): + states_alloc = Array(DIM * N_MERKLE_CHUNKS) + states = states_alloc + 1 - # First chunk: leaf_digest -> states - match_range(merkle_chunks[0], range(0, 16), lambda b: do_4_merkle_levels(b, leaf_digest, merkle_path, states)) + # First chunk + match_range(merkle_chunks[0], range(0, 16), lambda b: do_4_merkle_levels(b, leaf_digest, states, public_param, merkle_tweaks)) - # Middle chunks + state_indexes = Array(N_MERKLE_CHUNKS - 1) + state_indexes[0] = states for j in unroll(1, N_MERKLE_CHUNKS - 1): + state_indexes[j] = state_indexes[j - 1] + DIM match_range( merkle_chunks[j], range(0, 16), lambda b: do_4_merkle_levels( - b, states + (j - 1) * DIGEST_LEN, merkle_path + j * MERKLE_LEVELS_PER_CHUNK * DIGEST_LEN, states + j * DIGEST_LEN + b, + state_indexes[j - 1], + state_indexes[j], + public_param, + merkle_tweaks + j * MERKLE_LEVELS_PER_CHUNK * TWEAK_LEN, ), ) - # Last chunk: -> expected_root + # last chunk → write directly to expected_root match_range( merkle_chunks[N_MERKLE_CHUNKS - 1], range(0, 16), lambda b: do_4_merkle_levels( - b, states + (N_MERKLE_CHUNKS - 2) * DIGEST_LEN, merkle_path + (N_MERKLE_CHUNKS - 1) * MERKLE_LEVELS_PER_CHUNK * DIGEST_LEN, expected_root + b, + state_indexes[N_MERKLE_CHUNKS - 2], + expected_root, + public_param, + merkle_tweaks + (N_MERKLE_CHUNKS - 1) * MERKLE_LEVELS_PER_CHUNK * TWEAK_LEN, ), ) return - - -@inline -def copy_7(x, y): - dot_product_ee(x, ONE_EF_PTR, y) - dot_product_ee(x + (7 - DIM), ONE_EF_PTR, y + (7 - DIM)) - return - - -@inline -def copy_6(x, y): - dot_product_ee(x, ONE_EF_PTR, y) - y[DIM] = x[DIM] - return diff --git a/crates/xmss/params.md b/crates/xmss/params.md deleted file mode 100644 index 71b39aee9..000000000 --- a/crates/xmss/params.md +++ /dev/null @@ -1,72 +0,0 @@ -# XMSS parameters (WIP) - -> **Warning:** The current implementation does not match the 
[leanSig](https://github.com/leanEthereum/leanSig) paper and does not provide 128-bit security in the Standard Model (though it may still be secure in the ROM/QROM). Expect changes in the future. - -## 1. Field and Hash - -**Field:** KoalaBear, p = 2^31 - 2^24 + 1. Each field element fits in a u32. - -**Hash:** Poseidon2 (width 16) in compression mode: `compress: [F; 16] -> [F; 8]`. Applies the Poseidon2 permutation, adds the input (feed-forward), and returns the first 8 elements. - -**Digest:** 8 field elements (~248 bits). Used for tree nodes, and chain values. - -**Chain step:** `chain_step(x) = compress(x, 0)`. Iterated n times: `iterate_hash(x, n) = chain_step^n(x)`. - -## 2. WOTS - -| Parameter | Symbol | Value | -|---|---|---| -| Chains | V | 40 | -| Winternitz parameter | W | 3 | -| Chain length | CHAIN_LENGTH | 2^W = 8 | -| Verifier chain hashes | NUM_CHAIN_HASHES | 120 | -| Signer chain hashes | TARGET_SUM | 160 (= V*(CHAIN_LENGTH-1) - NUM_CHAIN_HASHES) | -| Grinding chains | V_GRINDING | 3 | -| Message length | MESSAGE_LEN_FE | 9 | -| Randomness length | RANDOMNESS_LEN_FE | 7 | -| Truncated root length | TRUNCATED_MERKLE_ROOT_LEN_FE | 6 | - -### 2.1 Encoding - -Converts (message, randomness, slot, truncated_merkle_root) into 40 chain indices via a **fixed-sum encoding** (indices sum to TARGET_SUM, eliminating the need for checksum chains). - -1. `A = compress(message[0..8], [message[8], randomness[0..7]])` -2. `B = compress(A, [slot_lo, slot_hi, merkle_root[0..6]])` where slot is split into two 16-bit field elements. -3. Reject if any element of B equals -1 (uniformity guard). -4. Extract 24 bits per element of B (little-endian), split into 3-bit chunks, take first 43. -5. Valid iff: first 40 sum to 160, last 3 all equal 7. Otherwise retry with new randomness. 
- -(Note: adding part of the merkle root to the encoding computation contributes to multi-user security via domain-separation, otherwise the security of the encoding W * (V + V_GRINDING) would degrade bellow 128 bits with multiple users.) - -### 2.2 Keys - -- **Secret key:** 40 random pre-image digests. -- **Public key:** `pk[i] = iterate_hash(pre_image[i], 7)` for each chain. -- **Public key hash:** sequential left fold: `compress(compress(...compress(pk[0], pk[1])..., pk[38]), pk[39])` (39 compressions). - -### 2.3 Sign and Verify - -**Sign:** Find randomness r yielding a valid encoding, then `chain_tip[i] = iterate_hash(pre_image[i], encoding[i])`. Signature = (chain_tips, r). - -**Verify (public key recovery):** Recompute encoding from (message, slot, truncated_root, r), then `recovered_pk[i] = iterate_hash(chain_tip[i], 7 - encoding[i])`. - -## 3. XMSS - -**Tree:** Binary Merkle tree of depth LOG_LIFETIME = 32 (2^32 slots). Nodes = `compress(left, right)`. - -### 3.1 Key Generation - -Inputs: seed (32 bytes), slot range [start, end]. Only WOTS leaves for [start, end] are generated; Merkle nodes outside this range are filled with deterministic random digests (derived from the seed). To an observer, the resulting tree is indistinguishable from a full 2^32-leaf tree. - -**Public key:** the Merkle root (single digest). - - -... -TODO - -## 4. Properties - -- public key size: 31 bytes -- num. hashes at signing: < 2^16 (mostly grinding at encoding) -- num. hashes at verification: 2 (encoding) + NUM_CHAIN_HASHES + V + LOG_LIFETIME = 194 -- sig. 
size : RANDOMNESS_LEN_FE + 8 * (V + LOG_LIFETIME) = 583 field elements = 2.21 KiB \ No newline at end of file diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 7e5fe8d21..3a41f2164 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -1,15 +1,20 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] +use backend::PrimeCharacteristicRing; +use backend::{DIGEST_LEN_FE, KoalaBear, POSEIDON1_WIDTH}; + pub mod signers_cache; mod wots; -use backend::KoalaBear; pub use wots::*; mod xmss; pub use xmss::*; -pub(crate) const DIGEST_SIZE: usize = 8; +pub const XMSS_DIGEST_LEN: usize = 4; +pub(crate) const TWEAK_LEN: usize = 2; type F = KoalaBear; -type Digest = [F; DIGEST_SIZE]; +type Digest = [F; XMSS_DIGEST_LEN]; +type PublicParam = [F; PUBLIC_PARAM_LEN_FE]; +type Randomness = [F; RANDOMNESS_LEN_FE]; // WOTS pub const V: usize = 42; @@ -17,10 +22,62 @@ pub const W: usize = 3; pub const CHAIN_LENGTH: usize = 1 << W; pub const NUM_CHAIN_HASHES: usize = 110; pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; -pub const V_GRINDING: usize = 2; +pub const RANDOMNESS_LEN_FE: usize = 6; +pub const MESSAGE_LEN_FE: usize = 8; +pub const PUBLIC_PARAM_LEN_FE: usize = 4; +pub const PUB_KEY_FLAT_SIZE: usize = XMSS_DIGEST_LEN + PUBLIC_PARAM_LEN_FE; +pub const WOTS_SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + V * XMSS_DIGEST_LEN; + +// XMSS pub const LOG_LIFETIME: usize = 32; -pub const RANDOMNESS_LEN_FE: usize = 7; -pub const MESSAGE_LEN_FE: usize = 9; -pub const TRUNCATED_MERKLE_ROOT_LEN_FE: usize = 6; -pub const SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + (V + LOG_LIFETIME) * DIGEST_SIZE; +// Tweak: domain separation within each hash. 
+pub(crate) const TWEAK_TYPE_CHAIN: usize = 0; +pub(crate) const TWEAK_TYPE_WOTS_PK: usize = 1; +pub(crate) const TWEAK_TYPE_MERKLE: usize = 2; +pub(crate) const TWEAK_TYPE_ENCODING: usize = 3; + +const _: () = assert!(V.is_multiple_of(2)); // For efficiency of the snark (we can batch chains in pairs) + +/// index = slot or node_index in Merkle tree +pub(crate) fn make_tweak(tweak_type: usize, sub_position: usize, index: u32) -> [F; TWEAK_LEN] { + assert!(tweak_type < 4); + assert!(sub_position < 1 << 10); + let index_lo = (index & 0xFFFF) as usize; + let index_hi = (index >> 16) as usize; + [ + F::from_usize((tweak_type << 26) + (index_hi << 10) + sub_position), + F::from_usize(index_lo), + ] +} + +/// [tweak(2) | zeros(2) | public_param(4) | left_child(4) | right_child(4)] +pub(crate) fn build_merkle_data( + tweak: [F; TWEAK_LEN], + public_param: &PublicParam, + left_child: &Digest, + right_child: &Digest, +) -> [F; POSEIDON1_WIDTH] { + let mut data = [F::default(); POSEIDON1_WIDTH]; + data[..TWEAK_LEN].copy_from_slice(&tweak); + // data[2..4] = zeros (default) + data[DIGEST_LEN_FE - PUBLIC_PARAM_LEN_FE..][..PUBLIC_PARAM_LEN_FE].copy_from_slice(public_param); + data[DIGEST_LEN_FE..][..XMSS_DIGEST_LEN].copy_from_slice(left_child); + data[DIGEST_LEN_FE + XMSS_DIGEST_LEN..].copy_from_slice(right_child); + data +} + +/// [tweak(2) | zeros(2) | data(4)] +pub(crate) fn build_left_chain_input(tweak: [F; TWEAK_LEN], data: &Digest) -> [F; DIGEST_LEN_FE] { + let mut left = [F::default(); DIGEST_LEN_FE]; + left[..TWEAK_LEN].copy_from_slice(&tweak); + left[DIGEST_LEN_FE - XMSS_DIGEST_LEN..].copy_from_slice(data); + left +} + +/// [public_param(4) | zeros(4)] +pub(crate) fn build_right_chain_input(public_param: &PublicParam) -> [F; DIGEST_LEN_FE] { + let mut right = [F::default(); DIGEST_LEN_FE]; + right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(public_param); + right +} diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 20c35b361..bce2e6a4c 100644 --- 
a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -21,18 +21,20 @@ pub struct WotsSignature { bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>") )] pub chain_tips: [Digest; V], - pub randomness: [F; RANDOMNESS_LEN_FE], + pub randomness: Randomness, } impl WotsSecretKey { - pub fn random(rng: &mut impl CryptoRng) -> Self { - Self::new(rng.random()) + pub fn random(rng: &mut impl CryptoRng, public_param: PublicParam, slot: u32) -> Self { + Self::new(rng.random(), public_param, slot) } - pub fn new(pre_images: [Digest; V]) -> Self { + pub fn new(pre_images: [Digest; V], public_param: PublicParam, slot: u32) -> Self { Self { pre_images, - public_key: WotsPublicKey(std::array::from_fn(|i| iterate_hash(&pre_images[i], CHAIN_LENGTH - 1))), + public_key: WotsPublicKey(std::array::from_fn(|i| { + iterate_hash(&pre_images[i], CHAIN_LENGTH - 1, public_param, slot, i, 0) + })), } } @@ -44,16 +46,24 @@ impl WotsSecretKey { &self, message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - randomness: [F; RANDOMNESS_LEN_FE], + xmss_pub_key: &XmssPublicKey, + randomness: Randomness, ) -> WotsSignature { - let encoding = wots_encode(message, slot, truncated_merkle_root, &randomness).unwrap(); - self.sign_with_encoding(randomness, &encoding) + let encoding = wots_encode(message, slot, xmss_pub_key, &randomness).unwrap(); + self.sign_with_encoding(randomness, &encoding, xmss_pub_key.public_param, slot) } - fn sign_with_encoding(&self, randomness: [F; RANDOMNESS_LEN_FE], encoding: &[u8; V]) -> WotsSignature { + fn sign_with_encoding( + &self, + randomness: Randomness, + encoding: &[u8; V], + public_param: PublicParam, + slot: u32, + ) -> WotsSignature { WotsSignature { - chain_tips: std::array::from_fn(|i| iterate_hash(&self.pre_images[i], encoding[i] as usize)), + chain_tips: std::array::from_fn(|i| { + iterate_hash(&self.pre_images[i], encoding[i] as usize, public_param, slot, i, 0) + }), randomness, } } @@ 
-64,40 +74,76 @@ impl WotsSignature { &self, message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], + xmss_pub_key: &XmssPublicKey, signature: &Self, ) -> Option { - let encoding = wots_encode(message, slot, truncated_merkle_root, &signature.randomness)?; + let encoding = wots_encode(message, slot, xmss_pub_key, &signature.randomness)?; Some(WotsPublicKey(std::array::from_fn(|i| { - iterate_hash(&self.chain_tips[i], CHAIN_LENGTH - 1 - encoding[i] as usize) + iterate_hash( + &self.chain_tips[i], + CHAIN_LENGTH - 1 - encoding[i] as usize, + xmss_pub_key.public_param, + slot, + i, + encoding[i] as usize, + ) }))) } } impl WotsPublicKey { - pub fn hash(&self) -> Digest { - let init = poseidon16_compress_pair(&self.0[0], &self.0[1]); - self.0[2..] - .iter() - .fold(init, |digest, chunk| poseidon16_compress_pair(&digest, chunk)) + // We use a T-Sponge with replacement, i.e. we use Poseidon in compression mode + replace (instead of modular addition) when ingesting 8 new field elements. 
+ pub fn hash(&self, public_param: PublicParam, slot: u32) -> Digest { + // IV: [tweak(2) | 00 | pp(4)] + let tweak = make_tweak(TWEAK_TYPE_WOTS_PK, 0, slot); + let mut state = [F::default(); 8]; + state[..TWEAK_LEN].copy_from_slice(&tweak); + // state[2..4] = 00 (default) + state[4..4 + PUBLIC_PARAM_LEN_FE].copy_from_slice(&public_param); + + let zeros = [F::ZERO; 8]; // for snark-friendliness (not necessary for security) + state = poseidon16_compress_pair(&state, &zeros); + + for i in (0..V).step_by(2) { + let mut chunk = [F::default(); 8]; + chunk[..XMSS_DIGEST_LEN].copy_from_slice(&self.0[i]); + chunk[XMSS_DIGEST_LEN..].copy_from_slice(&self.0[i + 1]); + state = poseidon16_compress_pair(&state, &chunk); + } + state[..XMSS_DIGEST_LEN].try_into().unwrap() + } } -pub fn iterate_hash(a: &Digest, n: usize) -> Digest { - (0..n).fold(*a, |acc, _| poseidon16_compress_pair(&acc, &Default::default())) +pub fn iterate_hash( + a: &Digest, + n: usize, + public_param: PublicParam, + slot: u32, + chain_index: usize, + start_step: usize, +) -> Digest { + // Chain hash layout: left = [tweak (2) | zeros (2) | data (4)], right = [public_param(4) | zeros(4)]. 
+ let right = build_right_chain_input(&public_param); + (0..n).fold(*a, |acc, j| { + let tweak = make_tweak(TWEAK_TYPE_CHAIN, chain_index * CHAIN_LENGTH + start_step + j, slot); + let left = build_left_chain_input(tweak, &acc); + poseidon16_compress_pair(&left, &right)[..XMSS_DIGEST_LEN] + .try_into() + .unwrap() + }) } pub fn find_randomness_for_wots_encoding( message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], + xmss_pub_key: &XmssPublicKey, rng: &mut impl CryptoRng, -) -> ([F; RANDOMNESS_LEN_FE], [u8; V], usize) { +) -> (Randomness, [u8; V], usize) { let mut num_iters = 0; loop { num_iters += 1; let randomness = rng.random(); - if let Some(encoding) = wots_encode(message, slot, truncated_merkle_root, &randomness) { + if let Some(encoding) = wots_encode(message, slot, xmss_pub_key, &randomness) { return (randomness, encoding, num_iters); } } @@ -106,24 +152,18 @@ pub fn find_randomness_for_wots_encoding( pub fn wots_encode( message: &[F; MESSAGE_LEN_FE], slot: u32, - truncated_merkle_root: &[F; TRUNCATED_MERKLE_ROOT_LEN_FE], - randomness: &[F; RANDOMNESS_LEN_FE], + xmss_pub_key: &XmssPublicKey, + randomness: &Randomness, ) -> Option<[u8; V]> { - // Encode slot as 2 field elements (16 bits each) - let [slot_lo, slot_hi] = slot_to_field_elements(slot); - - // A = poseidon(message (9 fe), randomness (7 fe)) - let mut a_input_right = [F::default(); 8]; - a_input_right[0] = message[8]; - a_input_right[1..1 + RANDOMNESS_LEN_FE].copy_from_slice(randomness); - let a = poseidon16_compress_pair(message[..8].try_into().unwrap(), &a_input_right); - - // B = poseidon(A (8 fe), slot (2 fe), truncated_merkle_root (6 fe)) - let mut b_input_right = [F::default(); 8]; - b_input_right[0] = slot_lo; - b_input_right[1] = slot_hi; - b_input_right[2..8].copy_from_slice(truncated_merkle_root); - let compressed = poseidon16_compress_pair(&a, &b_input_right); + let first_input_left = message; + let mut first_input_right = [F::default(); 
DIGEST_LEN_FE]; + first_input_right[..RANDOMNESS_LEN_FE].copy_from_slice(randomness); + first_input_right[RANDOMNESS_LEN_FE..][..TWEAK_LEN].copy_from_slice(&make_tweak(TWEAK_TYPE_ENCODING, 0, slot)); + let pre_compressed = poseidon16_compress_pair(first_input_left, &first_input_right); + + let mut second_input_right = [F::default(); DIGEST_LEN_FE]; + second_input_right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(&xmss_pub_key.public_param); + let compressed = poseidon16_compress_pair(&pre_compressed, &second_input_right); if compressed.iter().any(|&kb| kb == -F::ONE) { // ensures uniformity of encoding @@ -134,7 +174,7 @@ pub fn wots_encode( .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) .collect::>() .chunks_exact(W) - .take(V + V_GRINDING) + .take(V) .map(|chunk| { chunk .iter() @@ -146,27 +186,14 @@ pub fn wots_encode( } fn is_valid_encoding(encoding: &[u8]) -> bool { - if encoding.len() != V + V_GRINDING { + if encoding.len() != V { return false; } - // All indices must be < CHAIN_LENGTH if !encoding.iter().all(|&x| (x as usize) < CHAIN_LENGTH) { return false; } - // First V indices must sum to TARGET_SUM - if encoding[..V].iter().map(|&x| x as usize).sum::() != TARGET_SUM { - return false; - } - // Last V_GRINDING indices must all be CHAIN_LENGTH-1 (grinding constraint) - if !encoding[V..].iter().all(|&x| x as usize == CHAIN_LENGTH - 1) { + if encoding.iter().map(|&x| x as usize).sum::() != TARGET_SUM { return false; } true } - -pub fn slot_to_field_elements(slot: u32) -> [F; 2] { - [ - F::from_usize((slot & 0xFFFF) as usize), - F::from_usize(((slot >> 16) & 0xFFFF) as usize), - ] -} diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 88f9e6f22..02c5f6b4d 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -1,7 +1,8 @@ use backend::*; use rand::{CryptoRng, RngExt, SeedableRng, rngs::StdRng}; use serde::{Deserialize, Serialize}; -use utils::poseidon16_compress_pair; +use sha3::{Digest as Sha3Digest, Keccak256}; +use 
utils::poseidon16_compress; use crate::*; @@ -9,7 +10,8 @@ use crate::*; pub struct XmssSecretKey { pub(crate) slot_start: u32, // inclusive pub(crate) slot_end: u32, // inclusive - pub(crate) seed: [u8; 20], + pub(crate) public_param: PublicParam, + pub(crate) seed: [u8; 32], // At level l, stored indices go from (slot_start >> l) to (slot_end >> l). pub(crate) merkle_tree: Vec>, } @@ -23,25 +25,43 @@ pub struct XmssSignature { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct XmssPublicKey { pub merkle_root: Digest, + pub public_param: PublicParam, } -fn gen_wots_secret_key(seed: &[u8; 20], slot: u32) -> WotsSecretKey { - let mut rng_seed = [0u8; 32]; - rng_seed[..20].copy_from_slice(seed); - rng_seed[20] = 0x00; - rng_seed[21..25].copy_from_slice(&slot.to_le_bytes()); - let mut rng = StdRng::from_seed(rng_seed); - WotsSecretKey::random(&mut rng) +impl XmssPublicKey { + pub fn flaten(&self) -> [F; PUB_KEY_FLAT_SIZE] { + let mut output = [F::default(); PUB_KEY_FLAT_SIZE]; + output[..XMSS_DIGEST_LEN].copy_from_slice(&self.merkle_root); + output[XMSS_DIGEST_LEN..].copy_from_slice(&self.public_param); + output + } +} + +fn gen_wots_secret_key(seed: &[u8; 32], slot: u32, public_param: PublicParam) -> WotsSecretKey { + let mut hasher = Keccak256::new(); + hasher.update(b"wots_secret_key"); + hasher.update(seed); + hasher.update(slot.to_le_bytes()); + let mut rng = StdRng::from_seed(hasher.finalize().into()); + WotsSecretKey::random(&mut rng, public_param, slot) +} + +fn gen_public_param(seed: &[u8; 32]) -> PublicParam { + let mut hasher = Keccak256::new(); + hasher.update(b"public_param"); + hasher.update(seed); + let mut rng = StdRng::from_seed(hasher.finalize().into()); + rng.random() } /// Deterministic pseudo-random digest for an out-of-range tree node. 
-fn gen_random_node(seed: &[u8; 20], level: usize, index: u32) -> Digest { - let mut rng_seed = [0u8; 32]; - rng_seed[..20].copy_from_slice(seed); - rng_seed[20] = 0x01; - rng_seed[21] = level as u8; - rng_seed[22..26].copy_from_slice(&index.to_le_bytes()); - let mut rng = StdRng::from_seed(rng_seed); +fn gen_random_node(seed: &[u8; 32], level: usize, index: u64) -> Digest { + let mut hasher = Keccak256::new(); + hasher.update(b"random_node"); + hasher.update(seed); + hasher.update((level as u64).to_le_bytes()); + hasher.update(index.to_le_bytes()); + let mut rng = StdRng::from_seed(hasher.finalize().into()); rng.random() } @@ -51,20 +71,20 @@ pub enum XmssKeyGenError { } pub fn xmss_key_gen( - seed: [u8; 20], + seed: [u8; 32], slot_start: u32, slot_end: u32, ) -> Result<(XmssSecretKey, XmssPublicKey), XmssKeyGenError> { - if slot_start > slot_end { + if slot_start > slot_end || slot_end as u64 >= (1 << LOG_LIFETIME) { return Err(XmssKeyGenError::InvalidRange); } - let perm = default_koalabear_poseidon1_16(); + let public_param: PublicParam = gen_public_param(&seed); // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] let leaves: Vec = (slot_start..=slot_end) .into_par_iter() .map(|slot| { - let wots = gen_wots_secret_key(&seed, slot); - wots.public_key().hash() + let wots = gen_wots_secret_key(&seed, slot, public_param); + wots.public_key().hash(public_param, slot) }) .collect(); let mut merkle_tree = vec![leaves]; @@ -72,10 +92,10 @@ pub fn xmss_key_gen( // At level l, we store nodes with index in [(slot_start >> l), (slot_end >> l)]. // Children outside [slot_start, slot_end]'s subtree are replaced by gen_random_node. 
for level in 1..=LOG_LIFETIME { - let base = u64::from(slot_start) >> level; - let top = u64::from(slot_end) >> level; - let prev_base = u64::from(slot_start) >> (level - 1); - let prev_top = u64::from(slot_end) >> (level - 1); + let base: u64 = (slot_start as u64) >> level; + let top: u64 = (slot_end as u64) >> level; + let prev_base: u64 = (slot_start as u64) >> (level - 1); + let prev_top: u64 = (slot_end as u64) >> (level - 1); let nodes: Vec = { let prev = &merkle_tree[level - 1]; (base..=top) @@ -86,16 +106,20 @@ pub fn xmss_key_gen( let left = if left_idx >= prev_base && left_idx <= prev_top { prev[(left_idx - prev_base) as usize] } else { - assert!(left_idx < 1u64 << 32); - gen_random_node(&seed, level - 1, left_idx as u32) + gen_random_node(&seed, level - 1, left_idx) }; let right = if right_idx >= prev_base && right_idx <= prev_top { prev[(right_idx - prev_base) as usize] } else { - assert!(right_idx < 1u64 << 32); - gen_random_node(&seed, level - 1, right_idx as u32) + gen_random_node(&seed, level - 1, right_idx) }; - compress(&perm, [left, right]) + let merkle_data = build_merkle_data( + make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), + &public_param, + &left, + &right, + ); + poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap() }) .collect() }; @@ -103,10 +127,12 @@ pub fn xmss_key_gen( } let pub_key = XmssPublicKey { merkle_root: merkle_tree.last().unwrap()[0], + public_param, }; let secret_key = XmssSecretKey { slot_start, slot_end, + public_param, seed, merkle_tree, }; @@ -124,9 +150,7 @@ pub fn xmss_sign( message: &[F; MESSAGE_LEN_FE], slot: u32, ) -> Result { - let merkle_root = secret_key.public_key().merkle_root; - let truncated_merkle_root = merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); - let (randomness, _, _) = find_randomness_for_wots_encoding(message, slot, truncated_merkle_root, rng); + let (randomness, _, _) = find_randomness_for_wots_encoding(message, slot, &secret_key.public_key(), rng); 
xmss_sign_with_randomness(secret_key, message, slot, randomness) } @@ -139,15 +163,13 @@ pub fn xmss_sign_with_randomness( if slot < secret_key.slot_start || slot > secret_key.slot_end { return Err(XmssSignatureError::SlotOutOfRange); } - let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot); - let merkle_root = secret_key.public_key().merkle_root; - let truncated_merkle_root = merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); - let wots_signature = wots_secret_key.sign_with_randomness(message, slot, &truncated_merkle_root, randomness); + let wots_secret_key = gen_wots_secret_key(&secret_key.seed, slot, secret_key.public_param); + let wots_signature = wots_secret_key.sign_with_randomness(message, slot, &secret_key.public_key(), randomness); let merkle_proof = (0..LOG_LIFETIME) .map(|level| { - let neighbour_index = (slot >> level) ^ 1; - let base = secret_key.slot_start >> level; - let top = secret_key.slot_end >> level; + let neighbour_index = ((slot as u64) >> level) ^ 1; + let base = (secret_key.slot_start as u64) >> level; + let top = (secret_key.slot_end as u64) >> level; if neighbour_index >= base && neighbour_index <= top { secret_key.merkle_tree[level][(neighbour_index - base) as usize] } else { @@ -165,6 +187,7 @@ impl XmssSecretKey { pub fn public_key(&self) -> XmssPublicKey { XmssPublicKey { merkle_root: self.merkle_tree.last().unwrap()[0], + public_param: self.public_param, } } } @@ -181,22 +204,29 @@ pub fn xmss_verify( signature: &XmssSignature, slot: u32, ) -> Result<(), XmssVerifyError> { - let truncated_merkle_root = pub_key.merkle_root[0..TRUNCATED_MERKLE_ROOT_LEN_FE].try_into().unwrap(); let wots_public_key = signature .wots_signature - .recover_public_key(message, slot, &truncated_merkle_root, &signature.wots_signature) + .recover_public_key(message, slot, pub_key, &signature.wots_signature) .ok_or(XmssVerifyError::InvalidWots)?; - let mut current_hash = wots_public_key.hash(); + let mut current_hash = 
wots_public_key.hash(pub_key.public_param, slot); if signature.merkle_proof.len() != LOG_LIFETIME { return Err(XmssVerifyError::InvalidMerklePath); } for (level, neighbour) in signature.merkle_proof.iter().enumerate() { - let is_left = ((slot >> level) & 1) == 0; - if is_left { - current_hash = poseidon16_compress_pair(¤t_hash, neighbour); + let is_left = (((slot as u64) >> level) & 1) == 0; + let parent_index = ((slot as u64) >> (level + 1)) as u32; + let (left_child, right_child) = if is_left { + (current_hash, *neighbour) } else { - current_hash = poseidon16_compress_pair(neighbour, ¤t_hash); - } + (*neighbour, current_hash) + }; + let merkle_data = build_merkle_data( + make_tweak(TWEAK_TYPE_MERKLE, level + 1, parent_index), + &pub_key.public_param, + &left_child, + &right_child, + ); + current_hash = poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap(); } if current_hash == pub_key.merkle_root { Ok(()) diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 40bbb6377..0fb08e01d 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -6,7 +6,7 @@ type F = KoalaBear; #[test] fn test_xmss_serialize_deserialize() { - let keygen_seed: [u8; 20] = std::array::from_fn(|i| i as u8); + let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); let slot_start = 100; let slot_end = 115; @@ -28,14 +28,12 @@ fn test_xmss_serialize_deserialize() { #[test] fn keygen_sign_verify() { - let keygen_seed: [u8; 20] = std::array::from_fn(|i| i as u8); + let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); - let slot_start = 100; - let slot_end = 115; - let (sk, pk) = xmss_key_gen(keygen_seed, slot_start, slot_end).unwrap(); - for slot in slot_start..=slot_end { - let sig = xmss_sign(&mut StdRng::seed_from_u64(u64::from(slot)), 
&sk, &message, slot).unwrap(); + for slot in [0, 1234, u32::MAX] { + let (sk, pk) = xmss_key_gen(keygen_seed, slot.saturating_sub(1), slot.saturating_add(2)).unwrap(); + let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); xmss_verify(&pk, &message, &sig, slot).unwrap(); } } @@ -44,15 +42,18 @@ fn keygen_sign_verify() { #[ignore] fn encoding_grinding_bits() { let n = 100; + let xmss_pub_key = XmssPublicKey { + merkle_root: Default::default(), + public_param: Default::default(), + }; let total_iters = (0..n) .into_par_iter() .map(|i| { let message: [F; MESSAGE_LEN_FE] = Default::default(); let slot = i as u32; - let truncated_merkle_root: [F; TRUNCATED_MERKLE_ROOT_LEN_FE] = Default::default(); let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = - find_randomness_for_wots_encoding(&message, slot, &truncated_merkle_root, &mut rng); + find_randomness_for_wots_encoding(&message, slot, &xmss_pub_key, &mut rng); num_iters }) .sum::(); diff --git a/crates/xmss/xmss.md b/crates/xmss/xmss.md new file mode 100644 index 000000000..1a03e7d4c --- /dev/null +++ b/crates/xmss/xmss.md @@ -0,0 +1,48 @@ +# XMSS high-level specification + +## Field + +KoalaBear (p = 2^31 - 2^24 + 1). + +## Hash function + +[Poseidon](https://eprint.iacr.org/2019/458), in compression mode (feedforward addition). Input: 16 field elements. Output: 8 field elements. We denote it `H`. Chain hashes, Merkle hashes, and the final WOTS-pubkey hash truncate the output to 4 field elements (`n`); the encoding step and the intermediate WOTS-pubkey sponge states keep the full 8 elements. 
+ +## Sizes (in field elements) + +- `n = 4`: digest size +- `|pp| = 4`: public parameter +- `|randomness| = 6`: signature randomness +- `|msg| = 8`: message size +- `|tweak| = 2`: tweak (domain separation: `encoding`, `chain`, `wots_pk`, `merkle`) + +## WOTS (Winternitz One Time Signature) + +- `v = 42`: number of hash chains +- `w = 3`, `chain_length = 2^w = 8` +- `target_sum = 184`: a WOTS encoding `(e_0, ..., e_{v-1})` is valid iff each `e_i < chain_length` and `sum(e_i) = target_sum`. The signer grinds `randomness` until the encoding is valid (avoids checksum chains). + +## XMSS + +`log_lifetime = 32`: a key is valid for up to `2^32` slots. `log_lifetime` corresponds to the Merkle tree height. + +## Verification + +Inputs: public key `(merkle_root, pp)`, message `msg`, slot `s`, signature `(randomness, chain_tips, merkle_proof)`. + +1. **Encode**: compute the 8-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 0000)`. Reject if any limb of `D` equals `-1` (for a uniform sampling). For each limb, take its canonical representative's low 24 bits in little-endian order, concatenate to get 192 bits, then take the first `v · w = 126` bits split into `v = 42` little-endian chunks of `w = 3` bits → encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. +2. **Recover WOTS public key**: for each `i`, walk chain `i` from `chain_tips[i]` for `chain_length - 1 - e_i` steps, where each step is `H(tweak_chain(i, step, s) | 00 | previous_value | pp | 0000)` truncated to `n`. +3. **Hash WOTS public key**: T-sponge with replacement over the `v` recovered chain ends, with IV `[tweak_wots_pk(s) | 00 | pp]`, ingesting two chain end digests at a time. Output is the Merkle leaf. +4. **Walk Merkle path**: for `level = 0..log_lifetime`, combine the current node with `merkle_proof[level]` (left/right determined by bit `level` of `s`) via `H(tweak_merkle(level+1, parent_index) | 00 | pp | left | right)` truncated to `n`. 
+5. **Check root**: accept iff the final hash equals `merkle_root`. + + +## Security + +target = 123.9 ≈ 124 bits of classical security in the ROM, and ≈ 62 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. + +## Signature size + +**1171 bytes** `log2(p).(|randomness| + n.(v + log_lifetime))` + +below IPv6 [MTU](https://fr.wikipedia.org/wiki/Maximum_transmission_unit) (1280 bytes) diff --git a/minimal_zkVM.pdf b/minimal_zkVM.pdf index c3f56f730..68a052ba2 100644 Binary files a/minimal_zkVM.pdf and b/minimal_zkVM.pdf differ diff --git a/misc/images/fancy-aggregation.png b/misc/images/fancy-aggregation.png index 711939507..2b729a0ff 100644 Binary files a/misc/images/fancy-aggregation.png and b/misc/images/fancy-aggregation.png differ diff --git a/misc/minimal_zkVM.tex b/misc/minimal_zkVM.tex index d2c2c22dc..1567d3a23 100644 --- a/misc/minimal_zkVM.tex +++ b/misc/minimal_zkVM.tex @@ -37,7 +37,8 @@ \newtheorem{lemma}{Lemma} -\title{Minimal zkVM for Lean Ethereum (draft 0.6.0)} +\title{Minimal zkVM for Lean Ethereum (draft 0.7.0)} +\author{} \date{} \begin{document} @@ -269,16 +270,18 @@ \subsection{Precompiles} \subsubsection{POSEIDON} -Compression of 16 field elements (two blocks of 8) into 8 field elements. +Compression (feed-forward) of 16 field elements (two blocks of 8) into 8 field elements. $$ -\textbf{m}[\nu_C..\nu_C + 8] = \text{Poseidon}(\textbf{m}[\nu_A..\nu_A + 8] | \textbf{m}[\nu_B..\nu_B + 8]) + \textbf{m}[\nu_A..\nu_A + 8] +\textbf{m}[\nu_C..\nu_C + 8] = \text{Poseidon}(\textbf{m}[\nu_A..\nu_A + 8] \;\|\; \textbf{m}[\nu_B..\nu_B + 8]) + \textbf{m}[\nu_A..\nu_A + 8] $$ \vspace{2mm} \texttt{PRECOMPILE\_DATA} $= 1$ +Recently some additional parameters were introduced (see \ref{efficiently_verrifying_hash-based_signatures} for details), allowing more granular input / output. 
+ \subsubsection{EXTENSION\_OP}\label{extension_op_instruction} EXTENSION\_OP enables computations of one these 3 forms in the extension field $\Fq$: @@ -346,7 +349,7 @@ \subsubsection{Loops}\label{loops} \subsubsection{Range checks} -\fbox{It's possible to check that a given memory cell is smaller than some value $t$ (for $t \leq 2^{16})$ in 3 cycles.} +\fbox{It's possible to check that a memory cell is smaller than $t$ (for $t \leq 2^{16}$) in 3 cycles.} We denote by \textbf{m}[\textbf{fp} + $x$] the memory cell for which we want to ensure \textbf{m}[\textbf{fp} + $x$] $< t$. We also denote by \textbf{m}[\textbf{fp} + $i$], \textbf{m}[\textbf{fp} + $j$] and \textbf{m}[\textbf{fp} + $k$] 3 auxiliary memory cells (that have not been used yet). @@ -393,11 +396,10 @@ \subsubsection{Switch statements} Suppose we want a different logic depending on the value $x$ of a given memory cell, where $x$ is known to be $< k$ (if the value $x$ comes from a "hint", don't forget to range-check it). -Each of the $k$ different value leads to a different branch at runtime, represented by a block of code. We want to jump to the correct block of code depending on $x$. -One efficient implementation consists in placing our blocks of code at regular intervals, and to jump to a $a+ b.x$, where $a$ is the offset of the first block of code (in case $x = 0$), and $b$ is the distance between two consecutive blocks. -\newline -\newline -Example: During XMSS verification, for each of the $v$ chains, we need to hash a pre-image, a number of times depending on the encoding, but known to be $< w$. Here $k = w$, and the $i-th$ block of code we could jump to will execute $i$ times the hash function (unrolled loop). +Each of the $k$ different values leads to a different branch at runtime, represented by a block of code. We want to jump to the correct block of code depending on $x$. 
+One efficient implementation consists in placing our blocks of code at regular intervals, and to jump to $a + b \cdot x$, where $a$ is the offset of the first block of code (in case $x = 0$), and $b$ is the distance between two consecutive blocks. + +Example: During XMSS verification, for each of the $v$ chains, we need to hash a pre-image, a number of times depending on the encoding, but known to be $< w$. Here $k = w$, and the $i$-th block of code we could jump to will execute $i$ times the hash function (unrolled loop). \section{Proving system}\label{sec:proving} @@ -558,7 +560,42 @@ \subsection{Poseidon table} We use poseidon \cite{poseidon1} over 16 field elements, in compression mode, i.e. for $\texttt{input} \in \Fp^{16}$ and interpreting the addition as coordinate-wise in $\Fp^8$: $$\text{poseidon\_compress}(\texttt{input}) = \text{poseidon}(\texttt{input})[..8] + \texttt{input}[..8]$$. -The Poseidon precompile receives 3 arguments: $\nu_A$, $\nu_B$, $\nu_C$ interpreted as memory pointers. 3 lookups (each of size 8) into the memory are used to fetch $\texttt{left}$ = $\textbf{m}[\nu_A..\nu_A + 8] \in \Fp^8$, $\texttt{right}$ = $\textbf{m}[\nu_B..\nu_B + 8] \in \Fp^8$, and $\texttt{res}$ = $\textbf{m}[\nu_C..\nu_C + 8] \in \Fp^8$. AIR constraints, of degree 9, assert that $\texttt{res} = \text{poseidon\_compress}(\texttt{left} | \texttt{right})$. Degree 3 is also an alternative, but with $\approx$ 2x more columns committed. +The Poseidon precompile receives 3 runtime arguments: $\nu_A$, $\nu_B$, $\nu_C$ interpreted as memory pointers. 3 lookups (each of size 8) into the memory are used to fetch $\texttt{left}$ = $\textbf{m}[\nu_A..\nu_A + 8] \in \Fp^8$, $\texttt{right}$ = $\textbf{m}[\nu_B..\nu_B + 8] \in \Fp^8$, and $\texttt{res}$ = $\textbf{m}[\nu_C..\nu_C + 8] \in \Fp^8$. AIR constraints, of degree 9, assert that $\texttt{res} = \text{poseidon\_compress}(\texttt{left} \;\|\; \texttt{right})$. 
Degree 3 is also an alternative, at the cost of more committed columns ($\approx 160$ vs. $\approx 100$). + +\subsubsection{Efficiently verifying hash-based signatures}\label{efficiently_verrifying_hash-based_signatures} + +Hash-based signatures often rely on tweaks and public parameters (see \cite{ethereum_signatures}). + +We present two independent (and composable) optimizations of the poseidon precompile, when both the hash digest and the public parameters are composed of $n = 4$ field elements. We also assume each tweak occupies less than $n$ field elements (in practice 2 is enough). + +\vspace{5mm} + +\textbf{1. Chopped output} + +We introduce a boolean flag $\texttt{flag\_short}$ that signals that only the first $\texttt{n}$ field elements of the output should be written to memory, at $\textbf{m}[\nu_C..\nu_C + \texttt{n}]$ (the remaining $8 - \texttt{n}$ ones are ignored). + +\vspace{5mm} + +\textbf{2. Custom first-half of left-input} + +We introduce a boolean flag $\texttt{flag\_left}$, alongside a compile-time parameter $\texttt{offset\_left}$: +$$\texttt{left} = \begin{cases} \textbf{m}[\nu_A..\nu_A + 8] & \text{if } \texttt{flag\_left} = 0 \text{ (default)} \\ \textbf{m}[\texttt{offset\_left}..\texttt{offset\_left} + 4] \;\|\; \textbf{m}[\nu_A..\nu_A + 4] & \text{if } \texttt{flag\_left} = 1 \end{cases}$$ + +\vspace{5mm} + +Both optimizations can be implemented (together), at the cost of 4 additional columns, and incrementing the degree of the constraints by 1. 
+ +Both flags ($\texttt{flag\_short}$ and $\texttt{flag\_left}$) and the offset ($\texttt{offset\_left}$) can be encoded in a single (compile-time) parameter, as follows: +$$\texttt{AUX} = 1 + 2 \cdot \texttt{flag\_short} + 4 \cdot \texttt{flag\_left} + 8 \cdot \texttt{offset\_left} \cdot \texttt{flag\_left}$$ + +Soundness of this encoding (encoding multiple data into a single field element requires care, to avoid overflows that would break injectivity): +\begin{itemize} + \item both flags are asserted to be boolean by the AIR constraints + \item if $\texttt{flag\_left} = 0$: then $\texttt{offset\_left}$ is unconstrained: but it does not contribute to the encoding $\texttt{AUX}$, and it has no effect in the AIR logic. + \item if $\texttt{flag\_left} = 1$: then for the lookup into memory $\textbf{m}[\texttt{offset\_left}..\texttt{offset\_left} + 4]$ to be valid, $\texttt{offset\_left}$ must be smaller than the memory size $M \leq 2^{26} < p / 2$. As a result no overflow can occur in the encoding of $\texttt{AUX}$, which is thus injective. +\end{itemize} + +\textbf{Conclusion.} Every tuple $(\nu_A, \nu_B, \nu_C, \texttt{AUX})$ pushed by the execution table is faithfully pulled and decoded by the Poseidon table: the recovered triple $(\texttt{flag\_short}, \texttt{flag\_left}, \texttt{offset\_left})$ matches the one used at the call site, with one harmless exception: when $\texttt{flag\_left} = 0$, $\texttt{offset\_left}$ is left unconstrained, but this is inconsequential since it then has no effect at the AIR level. \subsection{Extension op table} @@ -638,7 +675,7 @@ \subsubsection{Memory lookups} \end{itemize} \subsubsection{Bus interaction} -On rows where $\texttt{activation\_flag} = 1$ (i.e.\ $\texttt{start} = 1$ and $\texttt{active} = 1$), the table \textsc{Pull}s $(\texttt{aux}, \text{idx}_A, \text{idx}_B, \text{idx}_R)$ from the precompile bus. The execution table \textsc{Push}es a matching tuple for each EXTENSION\_OP instruction. 
+On rows where $\texttt{activation\_flag} = 1$ (i.e.\ $\texttt{start} = 1$ and $\texttt{active} = 1$), the table \textsc{Pulls} the tuple $(\texttt{aux}, \text{idx}_A, \text{idx}_B, \text{idx}_R)$ from the precompile bus. The execution table \textsc{Pushes} a matching tuple for each EXTENSION\_OP instruction. The $\texttt{aux}$ encoding ensures both the mode and the length are bound to the bus data. Since $\texttt{is\_be}$, $\texttt{flag\_add}$, $\texttt{flag\_mul}$, and $\texttt{flag\_poly\_eq}$ are constrained to be boolean, and since the length is constrained to be $\leq 2^{20}$ (by Lemma~\ref{lem:len-bound}), no overflow can occur modulo $p$ and the $\texttt{aux}$ value is unique for each combination of parameters, which enforces that all values of PRECOMPILE\_DATA sent by the execution table are correctly received by the extension\_op table. diff --git a/src/bench.sh b/src/bench.sh index 7b72f9029..8ede13ae3 100755 --- a/src/bench.sh +++ b/src/bench.sh @@ -6,6 +6,8 @@ OUTPUT_FILE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/new_tables.md" +N_SIGNATURES=1550 + set -euo pipefail if ! 
command -v jq >/dev/null 2>&1; then @@ -56,10 +58,10 @@ recursion_cell() { } # --- XMSS aggregation runs --- -xmss_r1_proven=$(run_bench "" xmss --n-signatures 1400 --log-inv-rate 1); sleep 1 -xmss_r2_proven=$(run_bench "" xmss --n-signatures 1400 --log-inv-rate 2); sleep 1 -xmss_r1_conj=$(run_bench prox-gaps-conjecture xmss --n-signatures 1400 --log-inv-rate 1); sleep 1 -xmss_r2_conj=$(run_bench prox-gaps-conjecture xmss --n-signatures 1400 --log-inv-rate 2); sleep 1 +xmss_r1_proven=$(run_bench "" xmss --n-signatures "$N_SIGNATURES" --log-inv-rate 1); sleep 1 +xmss_r2_proven=$(run_bench "" xmss --n-signatures "$N_SIGNATURES" --log-inv-rate 2); sleep 1 +xmss_r1_conj=$(run_bench prox-gaps-conjecture xmss --n-signatures "$N_SIGNATURES" --log-inv-rate 1); sleep 1 +xmss_r2_conj=$(run_bench prox-gaps-conjecture xmss --n-signatures "$N_SIGNATURES" --log-inv-rate 2); sleep 1 # --- Recursion runs: len(RECURSION_NS) fan-ins x 2 rates x 2 regimes --- # Stored in flat shell variables `rec___` for bash 3.2 compatibility. 
diff --git a/src/main.rs b/src/main.rs index 5a0fda6cd..bcfe87fb0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -91,7 +91,7 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 700, + raw_xmss: 775, children: vec![], log_inv_rate, overlap: 0, @@ -116,13 +116,13 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 1400, + raw_xmss: 1550, children: vec![], log_inv_rate: 1, overlap: 0, }, AggregationTopology { - raw_xmss: 658, + raw_xmss: 508, children: vec![], log_inv_rate: 2, overlap: 0, @@ -135,13 +135,13 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 1400, + raw_xmss: 1550, children: vec![], log_inv_rate: 2, overlap: 0, }, AggregationTopology { - raw_xmss: 658, + raw_xmss: 508, children: vec![], log_inv_rate: 2, overlap: 0, @@ -158,7 +158,7 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 700, + raw_xmss: 775, children: vec![], log_inv_rate: 2, overlap: 0,