diff --git a/cprover_bindings/src/goto_program/expr.rs b/cprover_bindings/src/goto_program/expr.rs
index 72c01338c32..5d9de5c782c 100644
--- a/cprover_bindings/src/goto_program/expr.rs
+++ b/cprover_bindings/src/goto_program/expr.rs
@@ -353,6 +353,160 @@ impl Expr {
 
 /// Predicates
 impl Expr {
+    /// Replace all free occurrences of `Symbol { identifier: old_id }` with `replacement`.
+    /// Returns `(new_expr, changed)` where `changed` indicates if any substitution occurred.
+    ///
+    /// A `Forall`/`Exists` whose bound variable is `old_id` is returned unchanged: inside it
+    /// the identifier refers to the bound variable, not to the symbol being replaced.
+    ///
+    /// Note: Does NOT recurse into `StatementExpression` nodes. These must be
+    /// flattened first via `inline_as_pure_expr` before substitution.
+    pub fn substitute_symbol(self, old_id: &InternedString, replacement: &Expr) -> (Expr, bool) {
+        let loc = self.location;
+        let typ = self.typ.clone();
+        let ann = self.size_of_annotation.clone();
+        let mk = |value: ExprValue| Expr {
+            value: Box::new(value),
+            typ: typ.clone(),
+            location: loc,
+            size_of_annotation: ann.clone(),
+        };
+        let sub = |e: Expr| e.substitute_symbol(old_id, replacement);
+        let sub_vec = |v: Vec<Expr>| -> (Vec<Expr>, bool) {
+            let mut changed = false;
+            let v: Vec<_> = v
+                .into_iter()
+                .map(|e| {
+                    let (e, c) = sub(e);
+                    changed |= c;
+                    e
+                })
+                .collect();
+            (v, changed)
+        };
+        // True when `v` is the symbol `old_id`, i.e. a quantifier binding it shadows `old_id`.
+        let binds_old_id = |v: &Expr| {
+            matches!(v.value(), ExprValue::Symbol { identifier } if *identifier == *old_id)
+        };
+
+        match *self.value {
+            ExprValue::Symbol { identifier } if identifier == *old_id => {
+                (replacement.clone().with_location(loc), true)
+            }
+            ExprValue::AddressOf(e) => {
+                let (e, c) = sub(e);
+                (mk(AddressOf(e)), c)
+            }
+            ExprValue::Dereference(e) => {
+                let (e, c) = sub(e);
+                (mk(Dereference(e)), c)
+            }
+            ExprValue::Typecast(e) => {
+                let (e, c) = sub(e);
+                (mk(Typecast(e)), c)
+            }
+            ExprValue::UnOp { op, e } => {
+                let (e, c) = sub(e);
+                (mk(UnOp { op, e }), c)
+            }
+            ExprValue::BinOp { op, lhs, rhs } => {
+                let (l, c1) = sub(lhs);
+                let (r, c2) = sub(rhs);
+                (mk(BinOp { op, lhs: l, rhs: r }), c1 || c2)
+            }
+            ExprValue::If { c, t, e } => {
+                let (c, c1) = sub(c);
+                let (t, c2) = sub(t);
+                let (e, c3) = sub(e);
+                (mk(If { c, t, e }), c1 || c2 || c3)
+            }
+            ExprValue::Index { array, index } => {
+                let (a, c1) = sub(array);
+                let (i, c2) = sub(index);
+                (mk(Index { array: a, index: i }), c1 || c2)
+            }
+            ExprValue::Member { lhs, field } => {
+                let (l, c) = sub(lhs);
+                (mk(Member { lhs: l, field }), c)
+            }
+            ExprValue::FunctionCall { function, arguments } => {
+                let (f, c1) = sub(function);
+                let (a, c2) = sub_vec(arguments);
+                (mk(FunctionCall { function: f, arguments: a }), c1 || c2)
+            }
+            ExprValue::Array { elems } => {
+                let (e, c) = sub_vec(elems);
+                (mk(Array { elems: e }), c)
+            }
+            ExprValue::Struct { values } => {
+                let (v, c) = sub_vec(values);
+                (mk(Struct { values: v }), c)
+            }
+            ExprValue::Assign { left, right } => {
+                let (l, c1) = sub(left);
+                let (r, c2) = sub(right);
+                (mk(Assign { left: l, right: r }), c1 || c2)
+            }
+            ExprValue::ReadOk { ptr, size } => {
+                let (p, c1) = sub(ptr);
+                let (s, c2) = sub(size);
+                (mk(ReadOk { ptr: p, size: s }), c1 || c2)
+            }
+            ExprValue::ArrayOf { elem } => {
+                let (e, c) = sub(elem);
+                (mk(ArrayOf { elem: e }), c)
+            }
+            ExprValue::ByteExtract { e, offset } => {
+                let (e, c) = sub(e);
+                (mk(ByteExtract { e, offset }), c)
+            }
+            ExprValue::SelfOp { op, e } => {
+                let (e, c) = sub(e);
+                (mk(SelfOp { op, e }), c)
+            }
+            ExprValue::Union { value, field } => {
+                let (v, c) = sub(value);
+                (mk(Union { value: v, field }), c)
+            }
+            ExprValue::Forall { variable, domain } => {
+                // The binder is never rewritten; a binder equal to `old_id` also shields the body.
+                let (d, c) = if binds_old_id(&variable) { (domain, false) } else { sub(domain) };
+                (mk(Forall { variable, domain: d }), c)
+            }
+            ExprValue::Exists { variable, domain } => {
+                // The binder is never rewritten; a binder equal to `old_id` also shields the body.
+                let (d, c) = if binds_old_id(&variable) { (domain, false) } else { sub(domain) };
+                (mk(Exists { variable, domain: d }), c)
+            }
+            ExprValue::Vector { elems } => {
+                let (e, c) = sub_vec(elems);
+                (mk(Vector { elems: e }), c)
+            }
+            ExprValue::ShuffleVector { vector1, vector2, indexes } => {
+                let (v1, c1) = sub(vector1);
+                let (v2, c2) = sub(vector2);
+                let (ix, c3) = sub_vec(indexes);
+                (mk(ShuffleVector { vector1: v1, vector2: v2, indexes: ix }), c1 || c2 || c3)
+            }
+            // Leaf nodes — no substitution possible
+            ExprValue::Symbol { .. }
+            | ExprValue::IntConstant(_)
+            | ExprValue::BoolConstant(_)
+            | ExprValue::CBoolConstant(_)
+            | ExprValue::DoubleConstant(_)
+            | ExprValue::FloatConstant(_)
+            | ExprValue::Float16Constant(_)
+            | ExprValue::Float128Constant(_)
+            | ExprValue::PointerConstant(_)
+            | ExprValue::StringConstant { .. }
+            | ExprValue::Nondet
+            | ExprValue::EmptyUnion => (self, false),
+            // StatementExpression: not recursed into — must be flattened via
+            // inline_as_pure_expr before substitution.
+            ExprValue::StatementExpression { .. } => (self, false),
+        }
+    }
+
     pub fn is_int_constant(&self) -> bool {
         match *self.value {
             IntConstant(_) => true,
@@ -1762,3 +1916,116 @@ impl Expr {
         exprs
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn sym(name: &str) -> Expr {
+        Expr::symbol_expression(name, Type::signed_int(32))
+    }
+
+    fn int(val: i64) -> Expr {
+        Expr::int_constant(val, Type::signed_int(32))
+    }
+
+    #[test]
+    fn substitute_symbol_leaf_match() {
+        let old: InternedString = "x".into();
+        let replacement = int(42);
+        let (result, changed) = sym("x").substitute_symbol(&old, &replacement);
+        assert!(changed);
+        assert!(matches!(result.value(), ExprValue::IntConstant(v) if *v == 42.into()));
+    }
+
+    #[test]
+    fn substitute_symbol_leaf_no_match() {
+        let old: InternedString = "x".into();
+        let replacement = int(42);
+        let (result, changed) = sym("y").substitute_symbol(&old, &replacement);
+        assert!(!changed);
+        assert!(matches!(result.value(), ExprValue::Symbol { identifier } if *identifier == "y"));
+    }
+
+    #[test]
+    fn substitute_symbol_in_binop() {
+        let old: InternedString = "x".into();
+        let replacement = int(10);
+        // x + 1 → 10 + 1
+        let expr = sym("x").plus(int(1));
+        let (result, changed) = expr.substitute_symbol(&old, &replacement);
+        assert!(changed);
+        if let ExprValue::BinOp { lhs, rhs, .. } = result.value() {
+            assert!(matches!(lhs.value(), ExprValue::IntConstant(v) if *v == 10.into()));
+            assert!(matches!(rhs.value(), ExprValue::IntConstant(v) if *v == 1.into()));
+        } else {
+            panic!("Expected BinOp");
+        }
+    }
+
+    #[test]
+    fn substitute_symbol_nested() {
+        let old: InternedString = "x".into();
+        let replacement = int(5);
+        // (x + x) * 2 → (5 + 5) * 2
+        let expr = sym("x").plus(sym("x")).mul(int(2));
+        let (result, changed) = expr.substitute_symbol(&old, &replacement);
+        assert!(changed);
+        if let ExprValue::BinOp { lhs, .. } = result.value() {
+            if let ExprValue::BinOp { lhs: ll, rhs: lr, .. } = lhs.value() {
+                assert!(matches!(ll.value(), ExprValue::IntConstant(v) if *v == 5.into()));
+                assert!(matches!(lr.value(), ExprValue::IntConstant(v) if *v == 5.into()));
+            } else {
+                panic!("Expected inner BinOp");
+            }
+        } else {
+            panic!("Expected outer BinOp");
+        }
+    }
+
+    #[test]
+    fn substitute_symbol_in_typecast() {
+        let old: InternedString = "x".into();
+        let replacement = int(7);
+        let expr = sym("x").cast_to(Type::signed_int(64));
+        let (result, changed) = expr.substitute_symbol(&old, &replacement);
+        assert!(changed);
+        if let ExprValue::Typecast(inner) = result.value() {
+            assert!(matches!(inner.value(), ExprValue::IntConstant(v) if *v == 7.into()));
+        } else {
+            panic!("Expected Typecast");
+        }
+    }
+
+    #[test]
+    fn substitute_preserves_unrelated_symbols() {
+        let old: InternedString = "x".into();
+        let replacement = int(1);
+        // y + x → y + 1
+        let expr = sym("y").plus(sym("x"));
+        let (result, changed) = expr.substitute_symbol(&old, &replacement);
+        assert!(changed);
+        if let ExprValue::BinOp { lhs, rhs, .. } = result.value() {
+            assert!(matches!(lhs.value(), ExprValue::Symbol { identifier } if *identifier == "y"));
+            assert!(matches!(rhs.value(), ExprValue::IntConstant(v) if *v == 1.into()));
+        } else {
+            panic!("Expected BinOp");
+        }
+    }
+
+    #[test]
+    fn substitute_symbol_skips_shadowing_quantifier() {
+        let old: InternedString = "x".into();
+        let replacement = Expr::symbol_expression("z", Type::Bool);
+        // forall x. x — `x` is bound by the quantifier, so it must not be replaced.
+        let bound = Expr::symbol_expression("x", Type::Bool);
+        let expr = Expr::forall_expr(Type::Bool, bound.clone(), bound);
+        let (result, changed) = expr.substitute_symbol(&old, &replacement);
+        assert!(!changed);
+        let ExprValue::Forall { variable, domain } = result.value() else {
+            panic!("Expected Forall");
+        };
+        assert!(matches!(variable.value(), ExprValue::Symbol { identifier: id } if *id == "x"));
+        assert!(matches!(domain.value(), ExprValue::Symbol { identifier: id } if *id == "x"));
+    }
+}
diff --git a/docs/dev/pure-expression-inliner.md b/docs/dev/pure-expression-inliner.md
new file mode 100644
index 00000000000..b472ce91d73
--- /dev/null
+++ b/docs/dev/pure-expression-inliner.md
@@ -0,0 +1,75 @@
+# Pure Expression Inliner
+
+## Overview
+
+The pure expression inliner (`inline_as_pure_expr`) inlines function calls
+within expression trees as side-effect-free expressions. Unlike the existing
+`inline_function_calls_in_expr` which wraps inlined bodies in CBMC
+`StatementExpression` nodes, this produces pure expression trees.
+
+## How It Works
+
+For a function call `f(arg1, arg2)` where `f` is defined as:
+```c
+ret_type f(param1, param2) {
+    local1 = expr1(param1);
+    local2 = expr2(local1, param2);
+    return local2;
+}
+```
+
+The pure inliner:
+1. **Collects assignments**: `{local1 → expr1(param1), local2 → expr2(local1, param2)}`
+2. **Finds the return symbol**: `local2`
+3. **Resolves intermediates**: `local2` → `expr2(local1, param2)` → `expr2(expr1(param1), param2)`
+4. **Flattens StatementExpressions**: e.g., `({ assert(2!=0); assume(2!=0); i%2 })` → `i%2`
+5. **Substitutes parameters** (arguments are inlined first): `expr2(expr1(arg1), arg2)`
+6. **Recursively inlines** any remaining function calls in the result
+
+The resolution step (3) uses `Expr::substitute_symbol` iteratively until a
+fixed point is reached. Change detection uses the `(Expr, bool)` return from
+`substitute_symbol` — no string comparison needed.
+
+## Soundness Implications
+
+**Checked arithmetic in quantifier bodies**: When flattening `StatementExpression`
+nodes (e.g., from checked division or remainder), the pure inliner drops the
+`Assert` and `Assume` statements that check for overflow and division by zero.
+
+- **Division by zero** inside a quantifier body will NOT be detected.
+- **Arithmetic overflow** inside a quantifier body will NOT be detected.
+
+**Future improvement**: The dropped assertions could be hoisted outside the
+quantifier as preconditions, preserving soundness while keeping the body pure.
+
+## Limitations
+
+- **No control flow**: Functions with `if`/`else` or `match` that produce
+  multiple assignments to the return variable are not fully supported. The
+  inliner takes the last assignment and emits a `tracing::debug!` diagnostic.
+- **No loops**: Functions containing loops cannot be inlined as pure expressions.
+- **No recursion**: Recursive functions are detected and the original expression
+  is returned unchanged (with a `tracing::warn!` diagnostic). No ICE.
+- **StatementExpression in substitute_symbol**: `Expr::substitute_symbol` does
+  NOT recurse into `StatementExpression` nodes. These must be flattened via
+  `inline_as_pure_expr` before substitution.
+- **Quantifier binders**: `Expr::substitute_symbol` leaves a `Forall`/`Exists`
+  untouched when it binds the symbol being replaced.
+
+## API
+
+```rust
+// Public entry point — manages the visited set internally
+pub fn inline_as_pure_expr_toplevel(&self, expr: &Expr) -> Expr;
+
+// Expr method — returns (new_expr, changed) for reliable change detection
+pub fn substitute_symbol(self, old_id: &InternedString, replacement: &Expr) -> (Expr, bool);
+```
+
+## Files
+
+- `cprover_bindings/src/goto_program/expr.rs` — `Expr::substitute_symbol()`
+- `kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs` — `inline_as_pure_expr()`,
+  `inline_as_pure_expr_toplevel()`, `inline_call_as_pure_expr()`,
+  `collect_assignments_from_stmt()`, `find_return_symbol_in_stmt()`,
+  `resolve_intermediates_iterative()`
diff --git a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs
index a1e807ee4d2..df3b68fb011 100644
--- a/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs
+++ b/kani-compiler/src/codegen_cprover_gotoc/context/goto_ctx.rs
@@ -23,7 +23,7 @@ use crate::kani_middle::transform::BodyTransformation;
 use crate::kani_queries::QueryDb;
 use cbmc::goto_program::{
     CIntType, DatatypeComponent, Expr, ExprValue, Location, Stmt, StmtBody, SwitchCase, Symbol,
-    SymbolTable, SymbolValues, Type,
+    SymbolTable, SymbolValues, Type, UnaryOperator,
 };
 use cbmc::utils::aggr_tag;
 use cbmc::{InternedString, MachineModel};
@@ -41,7 +41,7 @@ use rustc_public::ty::Allocation;
 use rustc_span::Span;
 use rustc_span::source_map::respan;
 use rustc_target::callconv::FnAbi;
-use std::collections::{BTreeMap, HashSet};
+use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::Debug;
 
 /// A minimal context needed for recording our results. This allows us to move ownership of the
@@ -338,6 +338,230 @@ impl<'tcx, 'r> GotocCtx<'tcx, 'r> {
     }
 }
 
+/// Pure Expression Inlining
+///
+/// Inline function calls within an expression tree, producing a pure (side-effect-free)
+/// expression. Unlike `inline_function_calls_in_expr` which wraps inlined bodies in
+/// `StatementExpression` nodes, this produces expressions using only `If` (ternary),
+/// `BinOp`, `UnOp`, etc. — no statements, no gotos, no labels.
+///
+/// **Soundness note**: When flattening `StatementExpression` nodes (e.g., from checked
+/// arithmetic), runtime checks (Assert/Assume for overflow, division by zero) are dropped.
+/// Users must ensure arithmetic in quantifier predicates cannot overflow or divide by zero.
+///
+/// TODO(#4567): Remove `#[allow(dead_code)]` when used by quantifier-pure-expressions branch.
+#[allow(dead_code)]
+impl GotocCtx<'_, '_> {
+    /// Inline all function calls in `expr` as pure expressions.
+    /// Prefer `inline_as_pure_expr_toplevel` for the public entry point.
+    fn inline_as_pure_expr(&self, expr: &Expr, visited: &mut HashSet<InternedString>) -> Expr {
+        match expr.value() {
+            ExprValue::FunctionCall { function, arguments } => {
+                if let ExprValue::Symbol { identifier } = function.value() {
+                    self.inline_call_as_pure_expr(identifier, arguments, expr, visited)
+                } else {
+                    expr.clone()
+                }
+            }
+            ExprValue::BinOp { op, lhs, rhs } => self
+                .inline_as_pure_expr(lhs, visited)
+                .binop(*op, self.inline_as_pure_expr(rhs, visited)),
+            ExprValue::UnOp { op, e } => {
+                let inlined = self.inline_as_pure_expr(e, visited);
+                match op {
+                    UnaryOperator::Not => inlined.not(),
+                    UnaryOperator::Bitnot => inlined.bitnot(),
+                    UnaryOperator::UnaryMinus => inlined.neg(),
+                    other => {
+                        tracing::warn!(
+                            ?other,
+                            "Unknown UnaryOperator in pure inliner, preserving original"
+                        );
+                        expr.clone()
+                    }
+                }
+            }
+            ExprValue::Typecast(e) => {
+                self.inline_as_pure_expr(e, visited).cast_to(expr.typ().clone())
+            }
+            ExprValue::If { c, t, e } => self.inline_as_pure_expr(c, visited).ternary(
+                self.inline_as_pure_expr(t, visited),
+                self.inline_as_pure_expr(e, visited),
+            ),
+            ExprValue::Index { array, index } => self
+                .inline_as_pure_expr(array, visited)
+                .index(self.inline_as_pure_expr(index, visited)),
+            ExprValue::Member { lhs, field } => {
+                self.inline_as_pure_expr(lhs, visited).member(*field, &self.symbol_table)
+            }
+            ExprValue::Dereference(e) => self.inline_as_pure_expr(e, visited).dereference(),
+            ExprValue::AddressOf(e) => Expr::address_of(self.inline_as_pure_expr(e, visited)),
+            ExprValue::Forall { variable, domain } => Expr::forall_expr(
+                Type::Bool,
+                variable.clone(),
+                self.inline_as_pure_expr(domain, visited),
+            ),
+            ExprValue::Exists { variable, domain } => Expr::exists_expr(
+                Type::Bool,
+                variable.clone(),
+                self.inline_as_pure_expr(domain, visited),
+            ),
+            ExprValue::StatementExpression { statements, .. } => {
+                // Extract the final expression from the statement block.
+                // This handles checked arithmetic (Assert + Assume + Expression).
+                for stmt in statements.iter().rev() {
+                    if let StmtBody::Expression(e) = stmt.body() {
+                        return self.inline_as_pure_expr(e, visited);
+                    }
+                }
+                expr.clone()
+            }
+            _ => expr.clone(),
+        }
+    }
+
+    /// Public entry point for pure expression inlining.
+    pub fn inline_as_pure_expr_toplevel(&self, expr: &Expr) -> Expr {
+        self.inline_as_pure_expr(expr, &mut HashSet::new())
+    }
+
+    /// Inline a single function call as a pure expression.
+    /// Returns the original expression unchanged if the function cannot be inlined
+    /// (recursive, no body, non-symbol return).
+    ///
+    /// Call arguments are themselves inlined before being substituted for the
+    /// parameters, so calls nested in argument position do not survive in the result.
+    fn inline_call_as_pure_expr(
+        &self,
+        fn_id: &InternedString,
+        arguments: &[Expr],
+        original_expr: &Expr,
+        visited: &mut HashSet<InternedString>,
+    ) -> Expr {
+        if visited.contains(fn_id) {
+            tracing::warn!(%fn_id, "Recursive function in quantifier body, cannot inline as pure expression");
+            return original_expr.clone();
+        }
+
+        let function_body = self.symbol_table.lookup(*fn_id).and_then(|sym| match &sym.value {
+            SymbolValues::Stmt(stmt) => Some(stmt.clone()),
+            _ => None,
+        });
+
+        let Some(body) = function_body else {
+            return original_expr.clone();
+        };
+
+        let mut assignments: HashMap<InternedString, Expr> = HashMap::new();
+        Self::collect_assignments_from_stmt(&body, &mut assignments);
+
+        let Some(ret_sym) = Self::find_return_symbol_in_stmt(&body) else {
+            tracing::debug!(%fn_id, "No return symbol found, cannot inline as pure expression");
+            return original_expr.clone();
+        };
+
+        let Some(ret_expr) = assignments.remove(&ret_sym) else {
+            return original_expr.clone();
+        };
+
+        let resolved = Self::resolve_intermediates_iterative(ret_expr, &assignments);
+
+        // Only the callee's body is inlined with `fn_id` marked as in progress; the
+        // arguments belong to the caller, so `f(f(x))` is not treated as recursion.
+        visited.insert(*fn_id);
+        let flattened = self.inline_as_pure_expr(&resolved, visited);
+        visited.remove(fn_id);
+
+        let Some(params) = self.symbol_table.lookup_parameters(*fn_id) else {
+            return flattened;
+        };
+        let mut expr = flattened;
+        for (param, arg) in params.iter().zip(arguments.iter()) {
+            let arg = self.inline_as_pure_expr(arg, visited);
+            expr = expr.substitute_symbol(param, &arg).0;
+        }
+        expr
+    }
+
+    /// Collect all assignments (symbol = expr) from a statement tree.
+    /// Note: for variables assigned multiple times (e.g., in if/else branches),
+    /// only the last assignment is kept. This is a known limitation — functions
+    /// with control-flow-dependent assignments cannot be fully inlined as pure
+    /// expressions.
+    fn collect_assignments_from_stmt(stmt: &Stmt, map: &mut HashMap<InternedString, Expr>) {
+        match stmt.body() {
+            StmtBody::Assign { lhs, rhs } => {
+                if let ExprValue::Symbol { identifier } = lhs.value() {
+                    // `insert` returns the previous value, so one lookup detects a re-assignment.
+                    if map.insert(*identifier, rhs.clone()).is_some() {
+                        tracing::debug!(
+                            %identifier,
+                            "Multiple assignments to same variable in function body; \
+                             last-write-wins may produce incorrect pure expression"
+                        );
+                    }
+                }
+            }
+            StmtBody::Block(stmts) => {
+                for s in stmts {
+                    Self::collect_assignments_from_stmt(s, map);
+                }
+            }
+            StmtBody::Label { body, .. } => Self::collect_assignments_from_stmt(body, map),
+            _ => {}
+        }
+    }
+
+    /// Find the symbol identifier returned by a Return statement.
+    /// Returns None (with a debug diagnostic) if the return is a direct expression
+    /// rather than a symbol reference.
+    fn find_return_symbol_in_stmt(stmt: &Stmt) -> Option<InternedString> {
+        match stmt.body() {
+            StmtBody::Return(Some(expr)) => {
+                if let ExprValue::Symbol { identifier } = expr.value() {
+                    Some(*identifier)
+                } else {
+                    tracing::debug!(
+                        ?expr,
+                        "Return expression is not a symbol, cannot inline as pure expression"
+                    );
+                    None
+                }
+            }
+            StmtBody::Block(stmts) => {
+                for s in stmts {
+                    if let Some(sym) = Self::find_return_symbol_in_stmt(s) {
+                        return Some(sym);
+                    }
+                }
+                None
+            }
+            StmtBody::Label { body, .. } => Self::find_return_symbol_in_stmt(body),
+            _ => None,
+        }
+    }
+
+    /// Iteratively resolve intermediate variables in an expression.
+    /// Uses the `changed` flag from `substitute_symbol` for reliable change detection.
+    fn resolve_intermediates_iterative(
+        mut expr: Expr,
+        assignments: &HashMap<InternedString, Expr>,
+    ) -> Expr {
+        for _ in 0..assignments.len() + 1 {
+            let mut any_changed = false;
+            for (sym, rhs) in assignments {
+                let (new_expr, changed) = expr.substitute_symbol(sym, rhs);
+                expr = new_expr;
+                any_changed |= changed;
+            }
+            if !any_changed {
+                break;
+            }
+        }
+        expr
+    }
+}
+
 /// Quantifiers Related
 impl GotocCtx<'_, '_> {
     /// Find all quantifier expressions and recursively inline functions in the quantifier bodies.