Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ tracing-subscriber = { version = "0.3.23", features = ["std", "env-filter"] }
tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] }
postcard = { version = "1.1.3", features = ["alloc"] }
lz4_flex = "0.13.0"
include_dir = "0.7"

[features]
prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"]
Expand Down
1 change: 1 addition & 0 deletions crates/lean_compiler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ xmss.workspace = true
rand.workspace = true

tracing.workspace = true
include_dir.workspace = true
sub_protocols.workspace = true
lean_vm.workspace = true
backend.workspace = true
Expand Down
30 changes: 17 additions & 13 deletions crates/lean_compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,26 +91,30 @@ impl From<RunnerError> for Error {
pub enum ProgramSource {
Raw(String),
Filepath(String),
Embedded {
entry: String,
dir: &'static include_dir::Dir<'static>,
},
}

impl ProgramSource {
pub fn get_content(&self, flags: &CompilationFlags) -> Result<String, String> {
match self {
ProgramSource::Raw(src) => {
let mut result = src.clone();
for (key, value) in flags.replacements.iter() {
result = result.replace(key, value);
}
Ok(result)
}
let raw = match self {
ProgramSource::Raw(src) => src.clone(),
ProgramSource::Filepath(fp) => {
let mut result = std::fs::read_to_string(fp).map_err(|e| format!("Failed to read file {fp}: {e}"))?;
for (key, value) in flags.replacements.iter() {
result = result.replace(key, value);
}
Ok(result)
std::fs::read_to_string(fp).map_err(|e| format!("Failed to read file {fp}: {e}"))?
}
ProgramSource::Embedded { entry, dir } => dir
.get_file(entry)
.and_then(|f| f.contents_utf8())
.ok_or_else(|| format!("Embedded entry '{entry}' not found or not valid UTF-8"))?
.to_string(),
};
let mut result = raw;
for (key, value) in flags.replacements.iter() {
result = result.replace(key, value);
}
Ok(result)
}
}

Expand Down
14 changes: 10 additions & 4 deletions crates/lean_compiler/src/parser/parsers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,24 @@ pub struct ParseContext {
pub next_file_id: usize,
/// Compilation flags
pub flags: CompilationFlags,
/// If `Some`, imports resolve against this embedded directory instead of the filesystem.
pub embedded_dir: Option<&'static include_dir::Dir<'static>>,
}

impl ParseContext {
pub fn new(input: &ProgramSource, flags: CompilationFlags) -> Result<Self, SemanticError> {
let current_source_code = input.get_content(&flags).unwrap();
let (current_filepath, imported_filepaths) = match input {
ProgramSource::Raw(_) => ("<raw_input>".to_string(), BTreeSet::new()),
let current_source_code = input.get_content(&flags).map_err(SemanticError::new)?;
let (current_filepath, imported_filepaths, embedded_dir) = match input {
ProgramSource::Raw(_) => ("<raw_input>".to_string(), BTreeSet::new(), None),
ProgramSource::Filepath(fp) => {
let canonical = std::fs::canonicalize(fp)
.map_err(|e| SemanticError::new(format!("Cannot resolve filepath '{}': {}", fp, e)))?
.to_string_lossy()
.to_string();
(canonical.clone(), [canonical].into_iter().collect())
(canonical.clone(), [canonical].into_iter().collect(), None)
}
ProgramSource::Embedded { entry, dir } => {
(entry.clone(), [entry.clone()].into_iter().collect(), Some(*dir))
}
};
let import_stack = vec![current_filepath.clone()];
Expand All @@ -132,6 +137,7 @@ impl ParseContext {
current_source_code,
next_file_id: 1,
flags,
embedded_dir,
})
}

Expand Down
67 changes: 55 additions & 12 deletions crates/lean_compiler/src/parser/parsers/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,30 @@ impl Parse<Program> for ProgramParser {
ctx.import_root.clone()
};
let raw_path = Path::new(&base_dir).join(&relative_path);
let filepath = raw_path
.canonicalize()
.map_err(|e| {
SemanticError::new(format!(
"Cannot resolve import '{}' (resolved to '{}'): {}",
relative_path,
raw_path.display(),
e
let filepath = if let Some(dir) = ctx.embedded_dir {
let key = lexical_normalize(&raw_path);
if dir.get_file(Path::new(&key)).is_none() {
return Err(SemanticError::new(format!(
"Cannot resolve embedded import '{}' (resolved to '{}')",
relative_path, key
))
})?
.to_string_lossy()
.to_string();
.into());
}
key
} else {
raw_path
.canonicalize()
.map_err(|e| {
SemanticError::new(format!(
"Cannot resolve import '{}' (resolved to '{}'): {}",
relative_path,
raw_path.display(),
e
))
})?
.to_string_lossy()
.to_string()
};

// Check for circular imports
if ctx.import_stack.contains(&filepath) {
Expand Down Expand Up @@ -100,7 +112,15 @@ impl Parse<Program> for ProgramParser {
}
ctx.imported_filepaths.insert(filepath.clone());
ctx.import_stack.push(filepath.clone());
ctx.current_source_code = ProgramSource::Filepath(filepath).get_content(&ctx.flags)?;
let import_source = if let Some(dir) = ctx.embedded_dir {
ProgramSource::Embedded {
entry: filepath.clone(),
dir,
}
} else {
ProgramSource::Filepath(filepath.clone())
};
ctx.current_source_code = import_source.get_content(&ctx.flags)?;
let subprogram = parse_program_helper(ctx)?;
ctx.import_stack.pop();
functions.extend(subprogram.functions);
Expand Down Expand Up @@ -143,6 +163,29 @@ impl Parse<Program> for ProgramParser {
}
}

/// Lexically normalize a path for embedded-source lookups: collapse `.` and
/// `..` components and join with `/` regardless of host OS, so the same key
/// works on every platform.
fn lexical_normalize(path: &Path) -> String {
use std::path::Component;
let mut parts: Vec<String> = Vec::new();
for c in path.components() {
match c {
Component::CurDir => {}
Component::ParentDir => {
if matches!(parts.last().map(String::as_str), Some("..") | None) {
parts.push("..".to_string());
} else {
parts.pop();
}
}
Component::Normal(s) => parts.push(s.to_string_lossy().into_owned()),
Component::RootDir | Component::Prefix(_) => {}
}
}
parts.join("/")
}

/// Parser for import statements.
pub struct ImportStatementParser;

Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/zkDSL.md
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ dot_product_ee(x, y, z) # z = x * y

# Copy extension element (multiply by [1,0,0,0,0]).
# `ONE_EF_PTR` is a guest-program constant that the program must materialize
# in its preamble memory at startup; see `crates/rec_aggregation/utils.py`
# in its preamble memory at startup; see `crates/rec_aggregation/zkdsl_implem/utils.py`
# for an example (`build_preamble_memory`).
dot_product_ee(src, ONE_EF_PTR, dst)

Expand Down
1 change: 1 addition & 0 deletions crates/rec_aggregation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ xmss.workspace = true
rand.workspace = true

tracing.workspace = true
include_dir.workspace = true
sub_protocols.workspace = true
lean_vm.workspace = true
lean_compiler.workspace = true
Expand Down
14 changes: 7 additions & 7 deletions crates/rec_aggregation/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use lean_prover::{
};
use lean_vm::*;
use std::collections::{BTreeMap, HashMap};
use std::path::Path;
use std::sync::OnceLock;
use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements};
use tracing::instrument;
Expand Down Expand Up @@ -38,6 +37,8 @@ pub fn init_aggregation_bytecode() {
BYTECODE.get_or_init(compile_main_program_self_referential);
}

static EMBEDDED_ZK_DSL: include_dir::Dir<'_> = include_dir::include_dir!("$CARGO_MANIFEST_DIR/zkdsl_implem");

pub const MAX_RECURSIONS: usize = 16;
pub const MAX_XMSS_AGGREGATED: usize = 1 << 15; // TODO increase (we would need a bigger minimal memory size, totally doable)
pub const MAX_XMSS_DUPLICATES: usize = 1 << 15; // ...same
Expand Down Expand Up @@ -69,12 +70,11 @@ pub(crate) fn type1_input_data_size_padded(program_log_size: usize) -> usize {
fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode {
let replacements = build_replacements(program_log_size, bytecode_zero_eval);

let filepath = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("main.py")
.to_str()
.unwrap()
.to_string();
compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements })
let source = ProgramSource::Embedded {
entry: "main.py".to_string(),
dir: &EMBEDDED_ZK_DSL,
};
compile_program_with_flags(&source, CompilationFlags { replacements })
}

#[instrument(skip_all)]
Expand Down
2 changes: 1 addition & 1 deletion crates/rec_aggregation/tests/test_log2_ceil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from snark_lib import *
from ..utils import *
from ..zkdsl_implem.utils import *


def main():
Expand Down
Loading