diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 911ea53f..6466e83a 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -357,12 +357,7 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { for line in block.iter() { match line { Line::ForwardDeclaration { var } => { - let last_scope = ctx.scopes.last_mut().unwrap(); - assert!( - !last_scope.vars.contains(var), - "Variable declared multiple times in the same scope: {var}", - ); - last_scope.vars.insert(var.clone()); + ctx.add_var(var); } Line::Match { value, arms } => { check_expr_scoping(value, ctx); @@ -374,12 +369,9 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { } Line::Assignment { var, value } => { check_expr_scoping(value, ctx); - let last_scope = ctx.scopes.last_mut().unwrap(); - assert!( - !last_scope.vars.contains(var), - "Variable declared multiple times in the same scope: {var}", - ); - last_scope.vars.insert(var.clone()); + if !ctx.defines(var) { + ctx.add_var(var); + } } Line::ArrayAssign { array, index, value } => { check_simple_expr_scoping(array, ctx); @@ -428,15 +420,12 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { for arg in args { check_expr_scoping(arg, ctx); } - let last_scope = ctx.scopes.last_mut().unwrap(); for target in return_data { match target { AssignmentTarget::Var(var) => { - assert!( - !last_scope.vars.contains(var), - "Variable declared multiple times in the same scope: {var}", - ); - last_scope.vars.insert(var.clone()); + if !ctx.defines(var) { + ctx.add_var(var); + } } AssignmentTarget::ArrayAccess { .. } => {} } @@ -475,12 +464,9 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { } => { check_expr_scoping(size, ctx); check_expr_scoping(vectorized_len, ctx); - let last_scope = ctx.scopes.last_mut().unwrap(); - assert!( - !last_scope.vars.contains(var), - "Variable declared multiple times in the same scope: {var}", - ); - last_scope.vars.insert(var.clone()); + if !ctx.defines(var) { + ctx.add_var(var); + } } Line::CustomHint(_, args) => { for arg in args { @@ -488,12 +474,7 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) { } } Line::PrivateInputStart { result } => { - let last_scope = ctx.scopes.last_mut().unwrap(); - assert!( - !last_scope.vars.contains(result), - "Variable declared multiple times in the same scope: {result}" - ); - last_scope.vars.insert(result.clone()); + ctx.add_var(result); } } } @@ -2291,7 +2272,17 @@ fn handle_inlined_functions(program: &mut Program) { if !func.inlined { let old_body = func.body.clone(); - handle_inlined_functions_helper(&mut func.body, &inlined_functions, &mut counter1, &mut counter2); + let mut ctx = Context::new(); + for (var, _) in func.arguments.iter() { + ctx.add_var(var); + } + func.body = handle_inlined_functions_helper( + &mut ctx, + &func.body, + &inlined_functions, + &mut counter1, + &mut counter2, + ); if func.body != old_body { any_changes = true; @@ -2305,7 +2296,11 @@ fn handle_inlined_functions(program: &mut Program) { if func.inlined { let old_body = func.body.clone(); - handle_inlined_functions_helper(&mut func.body, &inlined_functions, &mut counter1, &mut counter2); + let mut ctx = Context::new(); + for (var, _) in func.arguments.iter() { + ctx.add_var(var); + } + handle_inlined_functions_helper(&mut ctx, &func.body, &inlined_functions, &mut counter1, &mut counter2); if func.body != old_body { any_changes = true; @@ -2331,51 +2326,58 @@ fn handle_inlined_functions(program: &mut Program) { /// Recursively extracts inlined function calls from an expression. /// Returns the modified expression and lines to prepend (forward declarations and function calls). fn extract_inlined_calls_from_expr( - expr: &mut Expression, + expr: &Expression, inlined_functions: &BTreeMap, inlined_var_counter: &mut Counter, -) -> Vec { +) -> (Expression, Vec) { let mut lines = vec![]; match expr { - Expression::Value(_) => {} - Expression::ArrayAccess { index, .. } => { - for idx in index.iter_mut() { - lines.extend(extract_inlined_calls_from_expr( - idx, - inlined_functions, - inlined_var_counter, - )); - } + Expression::Value(_) => (expr.clone(), vec![]), + Expression::ArrayAccess { array, index } => { + let mut index_new = vec![]; + for idx in index { + let (idx, idx_lines) = extract_inlined_calls_from_expr(idx, inlined_functions, inlined_var_counter); + lines.extend(idx_lines); + index_new.push(idx); + } + ( + Expression::ArrayAccess { + array: array.clone(), + index: index_new, + }, + lines, + ) } - Expression::Binary { left, right, .. } => { - lines.extend(extract_inlined_calls_from_expr( - left, - inlined_functions, - inlined_var_counter, - )); - lines.extend(extract_inlined_calls_from_expr( - right, - inlined_functions, - inlined_var_counter, - )); + Expression::Binary { left, operation, right } => { + let (left, left_lines) = extract_inlined_calls_from_expr(left, inlined_functions, inlined_var_counter); + lines.extend(left_lines); + let (right, right_lines) = extract_inlined_calls_from_expr(right, inlined_functions, inlined_var_counter); + lines.extend(right_lines); + ( + Expression::Binary { + left: Box::new(left), + operation: *operation, + right: Box::new(right), + }, + lines, + ) } - Expression::MathExpr(_, args) => { - for arg in args.iter_mut() { - lines.extend(extract_inlined_calls_from_expr( - arg, - inlined_functions, - inlined_var_counter, - )); + Expression::MathExpr(formula, args) => { + let mut args_new = vec![]; + for arg in args { + let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter); + lines.extend(arg_lines); + args_new.push(arg); } + (Expression::MathExpr(*formula, args_new), lines) } Expression::FunctionCall { function_name, args } => { - for arg in args.iter_mut() { - lines.extend(extract_inlined_calls_from_expr( - arg, - inlined_functions, - inlined_var_counter, - )); + let mut args_new = vec![]; + for arg in args { + let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter); + args_new.push(arg); + lines.extend(arg_lines); } if inlined_functions.contains_key(function_name) { @@ -2383,84 +2385,96 @@ fn extract_inlined_calls_from_expr( lines.push(Line::ForwardDeclaration { var: aux_var.clone() }); lines.push(Line::FunctionCall { function_name: function_name.clone(), - args: std::mem::take(args), + args: args.clone(), return_data: vec![AssignmentTarget::Var(aux_var.clone())], line_number: 0, }); - *expr = Expression::Value(SimpleExpr::Var(aux_var)); - } - } - Expression::Len { indices, .. } => { - for idx in indices.iter_mut() { - lines.extend(extract_inlined_calls_from_expr( - idx, - inlined_functions, - inlined_var_counter, - )); - } + (Expression::Value(SimpleExpr::Var(aux_var)), lines) + } else { + (expr.clone(), lines) + } + } + Expression::Len { array, indices } => { + let mut new_indices = vec![]; + for idx in indices.iter() { + let (idx, idx_lines) = extract_inlined_calls_from_expr(idx, inlined_functions, inlined_var_counter); + lines.extend(idx_lines); + new_indices.push(idx); + } + ( + Expression::Len { + array: array.clone(), + indices: new_indices, + }, + lines, + ) } } - - lines } fn extract_inlined_calls_from_boolean_expr( - boolean: &mut BooleanExpr, + boolean: &BooleanExpr, inlined_functions: &BTreeMap, inlined_var_counter: &mut Counter, -) -> Vec { - let mut lines = vec![]; - lines.extend(extract_inlined_calls_from_expr( - &mut boolean.left, - inlined_functions, - inlined_var_counter, - )); - lines.extend(extract_inlined_calls_from_expr( - &mut boolean.right, - inlined_functions, - inlined_var_counter, - )); - lines +) -> (BooleanExpr, Vec) { + let (left, mut lines) = extract_inlined_calls_from_expr(&boolean.left, inlined_functions, inlined_var_counter); + let (right, right_lines) = extract_inlined_calls_from_expr(&boolean.right, inlined_functions, inlined_var_counter); + lines.extend(right_lines); + let boolean = BooleanExpr { + kind: boolean.kind, + left, + right, + }; + (boolean, lines) } fn extract_inlined_calls_from_condition( - condition: &mut Condition, + condition: &Condition, inlined_functions: &BTreeMap, inlined_var_counter: &mut Counter, -) -> Vec { +) -> (Condition, Vec) { match condition { - Condition::Expression(expr, _) => extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter), + Condition::Expression(expr, assume_boolean) => { + let (expr, expr_lines) = extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); + (Condition::Expression(expr, *assume_boolean), expr_lines) + } Condition::Comparison(boolean) => { - extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter) + let (boolean, boolean_lines) = + extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter); + (Condition::Comparison(boolean), boolean_lines) } } } fn handle_inlined_functions_helper( - lines: &mut Vec, + ctx: &mut Context, + lines_in: &Vec, inlined_functions: &BTreeMap, inlined_var_counter: &mut Counter, total_inlined_counter: &mut Counter, -) { - // First pass: extract inlined function calls from expressions and handle Line::FunctionCall inlining - // We iterate in reverse to handle splicing correctly - let mut i = lines.len(); - while i > 0 { - i -= 1; - let prepend_lines = match &mut lines[i] { +) -> Vec { + let mut lines_out = vec![]; + for line in lines_in { + match line { + Line::Break | Line::Panic | Line::LocationReport { .. } => { + lines_out.push(line.clone()); + } Line::FunctionCall { function_name, args, return_data, line_number: _, } => { - if let Some(func) = inlined_functions.get(&*function_name) { + if let Some(func) = inlined_functions.get(function_name) { let mut inlined_lines = vec![]; // Only add forward declarations for variable targets, not array accesses for target in return_data.iter() { - if let AssignmentTarget::Var(var) = target { + if let AssignmentTarget::Var(var) = target + && !ctx.defines(var) + { inlined_lines.push(Line::ForwardDeclaration { var: var.clone() }); + ctx.add_var(var); } } @@ -2510,151 +2524,223 @@ fn handle_inlined_functions_helper( let mut func_body = func.body.clone(); inline_lines(&mut func_body, &inlined_args, return_data, total_inlined_counter.next()); inlined_lines.extend(func_body); - - lines.remove(i); // remove the call to the inlined function - lines.splice(i..i, inlined_lines); + lines_out.extend(inlined_lines); + } else { + lines_out.push(line.clone()); } - vec![] } - Line::IfCondition { condition, .. } => { - extract_inlined_calls_from_condition(condition, inlined_functions, inlined_var_counter) - } - Line::ForLoop { start, end, .. } => { - let mut prepend = vec![]; - prepend.extend(extract_inlined_calls_from_expr( - start, + Line::IfCondition { + condition, + then_branch, + else_branch, + line_number, + } => { + extract_inlined_calls_from_condition(condition, inlined_functions, inlined_var_counter); + ctx.scopes.push(Scope::default()); + let then_branch_out = handle_inlined_functions_helper( + ctx, + then_branch, inlined_functions, inlined_var_counter, - )); - prepend.extend(extract_inlined_calls_from_expr( - end, + total_inlined_counter, + ); + ctx.scopes.pop(); + ctx.scopes.push(Scope::default()); + let else_branch_out = handle_inlined_functions_helper( + ctx, + else_branch, inlined_functions, inlined_var_counter, - )); - prepend + total_inlined_counter, + ); + ctx.scopes.pop(); + lines_out.push(Line::IfCondition { + condition: condition.clone(), + then_branch: then_branch_out, + else_branch: else_branch_out, + line_number: *line_number, + }); } - Line::Assert { boolean, .. } => { - extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter) + Line::Match { value, arms } => { + let mut arms_out: Vec<(usize, Vec)> = Vec::new(); + for (i, arm) in arms { + ctx.scopes.push(Scope::default()); + let arm_out = handle_inlined_functions_helper( + ctx, + arm, + inlined_functions, + inlined_var_counter, + total_inlined_counter, + ); + ctx.scopes.pop(); + arms_out.push((*i, arm_out)); + } + lines_out.push(Line::Match { + value: value.clone(), + arms: arms_out, + }); } - Line::Assignment { value, .. } => { - extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter) + Line::ForwardDeclaration { var } => { + lines_out.push(line.clone()); + ctx.add_var(var); } - Line::ArrayAssign { index, value, .. } => { - let mut prepend = vec![]; - prepend.extend(extract_inlined_calls_from_expr( - index, + Line::PrivateInputStart { result } => { + lines_out.push(line.clone()); + if !ctx.defines(result) { + ctx.add_var(result); + } + } + Line::ForLoop { + iterator, + start, + end, + body, + rev, + unroll, + line_number, + } => { + // Handle inlining in the loop bounds + let (start, start_lines) = + extract_inlined_calls_from_expr(start, inlined_functions, inlined_var_counter); + lines_out.extend(start_lines); + let (end, end_lines) = extract_inlined_calls_from_expr(end, inlined_functions, inlined_var_counter); + lines_out.extend(end_lines); + + // Handle inlining in the loop body + ctx.scopes.push(Scope::default()); + ctx.add_var(iterator); + let loop_body_out = handle_inlined_functions_helper( + ctx, + body, inlined_functions, inlined_var_counter, - )); - prepend.extend(extract_inlined_calls_from_expr( + total_inlined_counter, + ); + ctx.scopes.pop(); + + // Push modified loop + lines_out.push(Line::ForLoop { + iterator: iterator.clone(), + start, + end, + body: loop_body_out, + rev: *rev, + unroll: *unroll, + line_number: *line_number, + }); + } + Line::Assert { + debug, + boolean, + line_number, + } => { + let (boolean, boolean_lines) = + extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter); + lines_out.extend(boolean_lines); + lines_out.push(Line::Assert { + debug: *debug, + boolean, + line_number: *line_number, + }); + } + Line::Assignment { var, value } => { + let (value, value_lines) = + extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); + lines_out.extend(value_lines); + if !ctx.defines(var) { + ctx.add_var(var); + } + lines_out.push(Line::Assignment { + var: var.clone(), value, - inlined_functions, - inlined_var_counter, - )); - prepend + }); } - Line::Print { content, .. } => { - let mut prepend = vec![]; - for expr in content.iter_mut() { - prepend.extend(extract_inlined_calls_from_expr( - expr, - inlined_functions, - inlined_var_counter, - )); + Line::ArrayAssign { array, index, value } => { + let (index, index_lines) = + extract_inlined_calls_from_expr(index, inlined_functions, inlined_var_counter); + lines_out.extend(index_lines); + let (value, value_lines) = + extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); + lines_out.extend(value_lines); + lines_out.push(Line::ArrayAssign { + array: array.clone(), + index, + value, + }); + } + Line::Print { line_info, content } => { + let mut new_content = vec![]; + for expr in content { + let (expr, expr_lines) = + extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); + lines_out.extend(expr_lines); + new_content.push(expr); } - prepend + lines_out.push(Line::Print { + line_info: line_info.clone(), + content: new_content, + }); } Line::FunctionRet { return_data } => { - let mut prepend = vec![]; - for expr in return_data.iter_mut() { - prepend.extend(extract_inlined_calls_from_expr( - expr, - inlined_functions, - inlined_var_counter, - )); + let mut new_return_data = vec![]; + for expr in return_data { + let (expr, expr_lines) = + extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); + lines_out.extend(expr_lines); + new_return_data.push(expr); } - prepend + lines_out.push(Line::FunctionRet { + return_data: new_return_data, + }); } - Line::Precompile { args, .. } => { - let mut prepend = vec![]; - for expr in args.iter_mut() { - prepend.extend(extract_inlined_calls_from_expr( - expr, - inlined_functions, - inlined_var_counter, - )); + Line::Precompile { table, args } => { + let mut new_args = vec![]; + for expr in args { + let (expr, new_lines) = + extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); + lines_out.extend(new_lines); + new_args.push(expr); } - prepend + lines_out.push(Line::Precompile { + table: *table, + args: new_args, + }); } Line::MAlloc { - size, vectorized_len, .. + var, + size, + vectorized, + vectorized_len, } => { - let mut prepend = vec![]; - prepend.extend(extract_inlined_calls_from_expr( + let (size, size_lines) = extract_inlined_calls_from_expr(size, inlined_functions, inlined_var_counter); + lines_out.extend(size_lines); + let (vectorized_len, vectorized_len_lines) = + extract_inlined_calls_from_expr(vectorized_len, inlined_functions, inlined_var_counter); + lines_out.extend(vectorized_len_lines); + + if !ctx.defines(var) { + ctx.add_var(var); + } + lines_out.push(Line::MAlloc { + var: var.clone(), size, - inlined_functions, - inlined_var_counter, - )); - prepend.extend(extract_inlined_calls_from_expr( + vectorized: *vectorized, vectorized_len, - inlined_functions, - inlined_var_counter, - )); - prepend + }); } - Line::CustomHint(_, args) => { - let mut prepend = vec![]; - for expr in args.iter_mut() { - prepend.extend(extract_inlined_calls_from_expr( - expr, - inlined_functions, - inlined_var_counter, - )); + Line::CustomHint(hint, args) => { + let mut new_args = vec![]; + for expr in args { + let (expr, new_lines) = + extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); + lines_out.extend(new_lines); + new_args.push(expr); } - prepend + lines_out.push(Line::CustomHint(*hint, new_args)); } - _ => vec![], }; - - if !prepend_lines.is_empty() { - let prepend_count = prepend_lines.len(); - lines.splice(i..i, prepend_lines); - i += prepend_count; // Adjust i to account for the inserted lines - } - } - - // Second pass: recursively process nested blocks - for line in lines.iter_mut() { - match line { - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - handle_inlined_functions_helper( - then_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - handle_inlined_functions_helper( - else_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - } - Line::ForLoop { body, .. } => { - handle_inlined_functions_helper(body, inlined_functions, inlined_var_counter, total_inlined_counter); - } - Line::Match { arms, .. } => { - for (_, arm) in arms { - handle_inlined_functions_helper(arm, inlined_functions, inlined_var_counter, total_inlined_counter); - } - } - _ => {} - } } + lines_out } fn handle_const_arguments(program: &mut Program) -> bool { diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 3e460c48..53af6669 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -220,7 +220,7 @@ impl From for ConstExpression { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AssumeBoolean { AssumeBoolean, DoNotAssumeBoolean, @@ -486,6 +486,7 @@ pub enum Line { } /// A context specifying which variables are in scope. +#[derive(Debug)] pub struct Context { /// A list of lexical scopes, innermost scope last. pub scopes: Vec, @@ -494,6 +495,13 @@ pub struct Context { } impl Context { + pub fn new() -> Context { + Context { + scopes: vec![Scope::default()], + const_arrays: BTreeMap::new(), + } + } + pub fn defines(&self, var: &Var) -> bool { if self.const_arrays.contains_key(var) { return true; @@ -505,9 +513,18 @@ impl Context { } false } + + pub fn add_var(&mut self, var: &Var) { + let last_scope = self.scopes.last_mut().unwrap(); + assert!( + !last_scope.vars.contains(var), + "Variable declared multiple times in the same scope: {var}", + ); + last_scope.vars.insert(var.clone()); + } } -#[derive(Default)] +#[derive(Debug, Default)] pub struct Scope { /// A set of declared variables. pub vars: BTreeSet, diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 7170c8dd..d6a09859 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -270,15 +270,13 @@ impl FunctionCallParser { if let Some(hint) = CustomHint::find_by_name(&function_name) { if !return_data.is_empty() { return Err(SemanticError::new(format!( - "Custom hint: \"{}\" should not return values", - function_name + "Custom hint: \"{function_name}\" should not return values", )) .into()); } if !hint.n_args_range().contains(&args.len()) { return Err(SemanticError::new(format!( - "Custom hint: \"{}\" : invalid number of arguments", - function_name + "Custom hint: \"{function_name}\" : invalid number of arguments", )) .into()); } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 9518b6e8..759040d7 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -563,6 +563,33 @@ fn test_inlined_2() { ); } +#[test] +fn test_inlined_3() { + let program = r#" + fn main() { + x = func(); + return; + } + fn func() -> 1 { + var a; + if 0 == 0 { + a = aux(); + } + return a; + } + + fn aux() inline -> 1 { + return 1; + } + "#; + compile_and_run( + &ProgramSource::Raw(program.to_string()), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); +} + #[test] fn test_match() { let program = r#"