diff --git a/anneal/src/aeneas.rs b/anneal/src/aeneas.rs index 0088e6943e..fa689c9e15 100644 --- a/anneal/src/aeneas.rs +++ b/anneal/src/aeneas.rs @@ -346,8 +346,10 @@ pub fn generate_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact] let slug = artifact.artifact_slug(); let output_dir = lean_generated_root.join(&slug); + let funs_types = parse_funs_types_in_dir(&output_dir)?; + // Generate Anneal specs - let generated = generate::generate_artifact(artifact); + let generated = generate::generate_artifact_with_funs_types(artifact, &funs_types); let specs_path = output_dir.join(artifact.lean_spec_file_name()); let map_path = output_dir.join(format!("{}.lean.map", artifact.artifact_slug())); @@ -591,6 +593,17 @@ pub fn generate_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact] Ok(()) } +pub(crate) fn parse_funs_types_in_dir(output_dir: &Path) -> Result { + let funs_path = output_dir.join("Funs.lean"); + if !funs_path.exists() { + return Ok(crate::funs_types::FunsTypeMap::new()); + } + + let content = std::fs::read_to_string(&funs_path) + .with_context(|| format!("Failed to read {}", funs_path.display()))?; + Ok(crate::funs_types::parse_funs_types(&content)) +} + /// Completes Lean verification by generating Anneal `Specs.lean`, writing `Generated.lean`, /// and running `lake build` + diagnostics. pub fn verify_lean_workspace(roots: &LockedRoots, artifacts: &[AnnealArtifact]) -> Result<()> { diff --git a/anneal/src/funs_types.rs b/anneal/src/funs_types.rs new file mode 100644 index 0000000000..239e8b63cc --- /dev/null +++ b/anneal/src/funs_types.rs @@ -0,0 +1,211 @@ +//! Parses Aeneas-generated `Funs.lean` signatures. +//! +//! Anneal's spec generator re-derives types from the Rust AST, losing +//! qualification for `use`-imported names (e.g., `Ordering` instead of +//! `core.sync.atomic.Ordering`). This module provides a lookup table from the +//! Aeneas output as a corrective. + +use std::collections::HashMap; + +/// A function signature parsed from Aeneas-generated Lean. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct FunsSignature { + pub params: Vec<(String, String)>, + pub ret: Option, +} + +/// Function name → Aeneas-generated function signature. +pub type FunsTypeMap = HashMap; + +/// Parses `def` signatures from `Funs.lean`. +/// +/// This extracts explicit parameter types and return types while skipping +/// implicit `{}` parameters. Malformed signatures are ignored. +pub fn parse_funs_types(content: &str) -> FunsTypeMap { + let mut map = FunsTypeMap::new(); + let lines: Vec<&str> = content.lines().collect(); + + for i in 0..lines.len() { + let Some(rest) = lines[i].trim().strip_prefix("def ") else { + continue; + }; + let name = rest.split([' ', '(', '{', ':']).next().unwrap_or(""); + if name.is_empty() { + continue; + } + + // Collect signature lines until `:=`. + let mut sig = String::from(rest); + for j in (i + 1)..lines.len() { + if sig.contains(":=") { + break; + } + sig.push(' '); + sig.push_str(lines[j].trim()); + } + + let params = extract_params(&sig); + let ret = extract_return_type(&sig); + if !params.is_empty() || ret.is_some() { + map.insert(name.to_string(), FunsSignature { params, ret }); + } + } + map +} + +/// Extracts `(name : type)` bindings, skipping `{implicit}` params. +fn extract_params(sig: &str) -> Vec<(String, String)> { + let mut params = Vec::new(); + let mut chars = sig.chars().peekable(); + + while let Some(&c) = chars.peek() { + match c { + '(' => { + chars.next(); + if let Some(pair) = parse_binding(&collect_delimited(&mut chars, ')')) { + params.push(pair); + } + } + '{' => { + chars.next(); + collect_delimited(&mut chars, '}'); + } + ':' => break, + _ => { + chars.next(); + } + } + } + params +} + +/// Extracts the function return type. +fn extract_return_type(sig: &str) -> Option { + let return_start = top_level_return_colon(sig)? + 1; + let return_end = sig[return_start..].find(":=").map(|i| return_start + i).unwrap_or(sig.len()); + let ret = sig[return_start..return_end].trim(); + if ret.is_empty() { + return None; + } + Some(ret.strip_prefix("Result ").unwrap_or(ret).trim().to_string()) +} + +/// Finds the colon that separates the parameter list from the return type. +fn top_level_return_colon(sig: &str) -> Option { + let mut paren_depth = 0u32; + let mut brace_depth = 0u32; + for (i, c) in sig.char_indices() { + match c { + '(' => paren_depth += 1, + ')' => paren_depth = paren_depth.saturating_sub(1), + '{' => brace_depth += 1, + '}' => brace_depth = brace_depth.saturating_sub(1), + ':' if paren_depth == 0 && brace_depth == 0 => return Some(i), + _ => {} + } + } + None +} + +/// Reads chars until the matching `close` delimiter, handling nesting. +fn collect_delimited(chars: &mut std::iter::Peekable>, close: char) -> String { + let open = if close == ')' { '(' } else { '{' }; + let mut depth = 1u32; + let mut buf = String::new(); + for c in chars.by_ref() { + if c == open { + depth += 1; + } else if c == close { + depth -= 1; + if depth == 0 { + return buf; + } + } + buf.push(c); + } + buf +} + +/// Splits `"name : type"` on the first ` : `. +fn parse_binding(s: &str) -> Option<(String, String)> { + let s = s.trim(); + let i = s.find(" : ")?; + let (name, ty) = (s[..i].trim(), s[i + 3..].trim()); + (!name.is_empty() && !ty.is_empty()).then_some((name.to_string(), ty.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple() { + let m = parse_funs_types("def hash_key (k : Std.Usize) : Result Std.Usize := do\n"); + assert_eq!(m["hash_key"].params, [("k".into(), "Std.Usize".into())]); + assert_eq!(m["hash_key"].ret.as_deref(), Some("Std.Usize")); + } + + #[test] + fn multiline() { + let m = parse_funs_types( + "def frame.AtomicFrameState.load\n\ + \x20 (self : frame.AtomicFrameState) (order : core.sync.atomic.Ordering) :\n\ + \x20 Result frame.FrameState\n\ + \x20 := do\n", + ); + let p = &m["frame.AtomicFrameState.load"].params; + assert_eq!(p[0], ("self".into(), "frame.AtomicFrameState".into())); + assert_eq!(p[1], ("order".into(), "core.sync.atomic.Ordering".into())); + assert_eq!(m["frame.AtomicFrameState.load"].ret.as_deref(), Some("frame.FrameState")); + } + + #[test] + fn skips_implicits() { + let m = parse_funs_types( + "def HashMap.alloc {T : Type} (slots : alloc.vec.Vec T) (n : Std.Usize) :\n\ + \x20 Result Unit := do\n", + ); + let p = &m["HashMap.alloc"].params; + assert_eq!(p.len(), 2); + assert_eq!(p[0].0, "slots"); + assert_eq!(p[1], ("n".into(), "Std.Usize".into())); + } + + #[test] + fn multiple_ordering_params() { + let m = parse_funs_types( + "def f.compare_exchange\n\ + \x20 (self : f.T) (expected : f.S)\n\ + \x20 (success : core.sync.atomic.Ordering)\n\ + \x20 (failure : core.sync.atomic.Ordering) :\n\ + \x20 Result Unit := do\n", + ); + let p = &m["f.compare_exchange"].params; + assert_eq!(p.len(), 4); + assert_eq!(p[2].1, "core.sync.atomic.Ordering"); + assert_eq!(p[3].1, "core.sync.atomic.Ordering"); + } + + #[test] + fn preserves_escaped_keyword_params() { + let m = parse_funs_types("def f (show1 : core.sync.atomic.Ordering) : Result Unit := do\n"); + assert_eq!(m["f"].params, [("show1".into(), "core.sync.atomic.Ordering".into())]); + } + + #[test] + fn parses_return_type() { + let m = parse_funs_types( + "def f (order : core.sync.atomic.Ordering) :\n\ + \x20 Result core.sync.atomic.Ordering\n\ + \x20 := do\n", + ); + assert_eq!(m["f"].ret.as_deref(), Some("core.sync.atomic.Ordering")); + } + + #[test] + fn trait_instance_no_params() { + let m = parse_funs_types("def Foo.Clone : core.clone.Clone Foo := {\n clone := x\n}\n"); + assert!(m["Foo.Clone"].params.is_empty()); + assert_eq!(m["Foo.Clone"].ret.as_deref(), Some("core.clone.Clone Foo")); + } +} diff --git a/anneal/src/generate.rs b/anneal/src/generate.rs index f087023d08..06d5b0ae28 100644 --- a/anneal/src/generate.rs +++ b/anneal/src/generate.rs @@ -236,11 +236,19 @@ pub fn generate_item( item: &crate::parse::ParsedLeanItem, builder: &mut LeanBuilder, naming_context: &NamingContext, + funs_types: &crate::funs_types::FunsTypeMap, ) { + let namespace = naming_context.item_namespace(item); match &item.item { - ParsedItem::Function(func) => { - generate_function(&func.item, &func.anneal, builder, &item.source_file, naming_context) - } + ParsedItem::Function(func) => generate_function_with_funs_types( + &func.item, + &func.anneal, + builder, + &item.source_file, + naming_context, + funs_types, + &namespace, + ), ParsedItem::Type(ty) => { generate_type(&ty.item, &ty.anneal, builder, &item.source_file, naming_context) } @@ -259,6 +267,13 @@ pub fn generate_item( /// 3. Iterates over all items in the artifact, generating code for each. /// 4. Wraps items in their respective module namespaces. pub fn generate_artifact(artifact: &crate::scanner::AnnealArtifact) -> GeneratedArtifact { + generate_artifact_with_funs_types(artifact, &Default::default()) +} + +pub fn generate_artifact_with_funs_types( + artifact: &crate::scanner::AnnealArtifact, + funs_types: &crate::funs_types::FunsTypeMap, +) -> GeneratedArtifact { let mut builder = LeanBuilder::new(); builder.push_str("-- This file is automatically generated by Anneal.\n"); builder.push_str("-- Do not edit manually.\n\n"); @@ -293,7 +308,7 @@ pub fn generate_artifact(artifact: &crate::scanner::AnnealArtifact) -> Generated builder.push_str(&format!("namespace {}\n\n", namespace)); } - generate_item(item, &mut builder, &naming_context); + generate_item(item, &mut builder, &naming_context, funs_types); builder.push('\n'); if !namespace.is_empty() { @@ -571,6 +586,26 @@ fn generate_function( builder: &mut LeanBuilder, source_file: &std::path::Path, naming_context: &NamingContext, +) { + generate_function_with_funs_types( + func, + block, + builder, + source_file, + naming_context, + &Default::default(), + "", + ) +} + +fn generate_function_with_funs_types( + func: &FunctionItem, + block: &FunctionAnnealBlock, + builder: &mut LeanBuilder, + source_file: &std::path::Path, + naming_context: &NamingContext, + funs_types: &crate::funs_types::FunsTypeMap, + item_namespace: &str, ) { let (fn_name, fn_span, impl_struct_name, generic_params, generic_bounds, dict_args) = match func { @@ -599,7 +634,27 @@ fn generate_function( (n.inner.sig.ident.clone(), n.inner.sig.name_span, None, p, b, d) } }; - let args = extract_args_metadata(func, &impl_struct_name); + let mut args = extract_args_metadata(func, &impl_struct_name); + let signature_name = if item_namespace.is_empty() { + fn_name.to_string() + } else { + format!("{}.{}", item_namespace, fn_name) + }; + let aeneas_signature = funs_types.get(&signature_name); + + // Override parameter types from Aeneas-generated Funs.lean when available. + // This ensures fully-qualified type names are used, preventing ambiguity + // with Lean builtins (e.g., `Ordering` vs `core.sync.atomic.Ordering`). + if let Some(aeneas_signature) = aeneas_signature { + for arg in args.iter_mut() { + if let Some((_, aeneas_type)) = + aeneas_signature.params.iter().find(|(name, _)| *name == arg.name) + { + arg.lean_type = aeneas_type.clone(); + } + } + } + let has_return_value = !is_unit_return(func); builder.push_str(&format!("namespace {}\n\n", fn_name)); @@ -747,7 +802,10 @@ fn generate_function( use crate::parse::hkd::SafeReturnType; let ret_lean_type = match &func.sig().output { SafeReturnType::Default => "Unit".to_string(), - SafeReturnType::Type(ty) => map_type(ty), + SafeReturnType::Type(ty) => aeneas_signature + .and_then(|sig| sig.ret.as_ref()) + .cloned() + .unwrap_or_else(|| map_type(ty)), }; post_outputs.push_str(&format!("(ret : {})", ret_lean_type)); } @@ -1529,6 +1587,68 @@ mod tests { assert!(requires_idx < return_type_idx, "Requires should come before return type"); } + #[test] + fn test_gen_uses_aeneas_param_and_return_types() { + let item: syn::ItemFn = parse_quote! { fn foo(order: Ordering) -> Ordering {} }; + let func = FunctionItem::Free(AstNode { inner: item.mirror() }); + let block = mk_block(vec![], vec![], Some(vec![]), None, vec![]); + let mut funs_types = crate::funs_types::FunsTypeMap::new(); + funs_types.insert( + "foo".to_string(), + crate::funs_types::FunsSignature { + params: vec![("order".to_string(), "core.sync.atomic.Ordering".to_string())], + ret: Some("core.sync.atomic.Ordering".to_string()), + }, + ); + + let mut builder = LeanBuilder::new(); + let naming_context = NamingContext::new("test".to_string()); + generate_function_with_funs_types( + &func, + &block, + &mut builder, + Path::new("test.rs"), + &naming_context, + &funs_types, + "", + ); + let out = builder.buf; + + assert!(out.contains("(order : core.sync.atomic.Ordering)")); + assert!(out.contains("(ret : core.sync.atomic.Ordering)")); + assert!(!out.contains("(order : Ordering)")); + assert!(!out.contains("(ret : Ordering)")); + } + + #[test] + fn test_gen_matches_escaped_aeneas_param_names() { + let item: syn::ItemFn = parse_quote! { fn foo(show: Ordering) {} }; + let func = FunctionItem::Free(AstNode { inner: item.mirror() }); + let block = mk_block(vec![], vec![], Some(vec![]), None, vec![]); + let mut funs_types = crate::funs_types::FunsTypeMap::new(); + funs_types.insert( + "foo".to_string(), + crate::funs_types::FunsSignature { + params: vec![("show1".to_string(), "core.sync.atomic.Ordering".to_string())], + ret: Some("Unit".to_string()), + }, + ); + + let mut builder = LeanBuilder::new(); + let naming_context = NamingContext::new("test".to_string()); + generate_function_with_funs_types( + &func, + &block, + &mut builder, + Path::new("test.rs"), + &naming_context, + &funs_types, + "", + ); + + assert!(builder.buf.contains("(show1 : core.sync.atomic.Ordering)")); + } + #[test] fn test_gen_unsafe_axiom() { let item: syn::ItemFn = parse_quote! { unsafe fn ffi(p: *const u8) {} }; diff --git a/anneal/src/main.rs b/anneal/src/main.rs index 05748849d8..3f2fafd2f6 100644 --- a/anneal/src/main.rs +++ b/anneal/src/main.rs @@ -2,6 +2,7 @@ mod aeneas; mod charon; mod diagnostics; mod errors; +mod funs_types; mod generate; mod parse; mod resolve; @@ -140,7 +141,12 @@ fn main() -> anyhow::Result<()> { if emit_anneal { println!("--- Anneal ---"); - let generated = generate::generate_artifact(artifact); + let funs_types = aeneas::parse_funs_types_in_dir(&output_dir)?; + let generated = if funs_types.is_empty() { + generate::generate_artifact(artifact) + } else { + generate::generate_artifact_with_funs_types(artifact, &funs_types) + }; println!("{}", generated.code); } }