diff --git a/Strata/DDM/AST.lean b/Strata/DDM/AST.lean index 1077807c34..44df751988 100644 --- a/Strata/DDM/AST.lean +++ b/Strata/DDM/AST.lean @@ -150,6 +150,7 @@ inductive SepFormat where | comma -- Comma separator (CommaSepBy) | space -- Space separator (SpaceSepBy) | spacePrefix -- Space before each element (SpacePrefixSepBy) +| newline -- Newline separator (NewlineSepBy) deriving Inhabited, Repr, BEq namespace SepFormat @@ -159,18 +160,21 @@ def toString : SepFormat → String | .comma => "commaSepBy" | .space => "spaceSepBy" | .spacePrefix => "spacePrefixSepBy" + | .newline => "newlineSepBy" def toIonName : SepFormat → String | .none => "seq" | .comma => "commaSepList" | .space => "spaceSepList" | .spacePrefix => "spacePrefixedList" + | .newline => "newlineSepList" def fromIonName? : String → Option SepFormat | "seq" => some .none | "commaSepList" => some .comma | "spaceSepList" => some .space | "spacePrefixedList" => some .spacePrefix + | "newlineSepList" => some .newline | _ => none theorem fromIonName_toIonName_roundtrip (sep : SepFormat) : diff --git a/Strata/DDM/BuiltinDialects/Init.lean b/Strata/DDM/BuiltinDialects/Init.lean index 20ebfda384..927bb3600c 100644 --- a/Strata/DDM/BuiltinDialects/Init.lean +++ b/Strata/DDM/BuiltinDialects/Init.lean @@ -20,6 +20,7 @@ def SyntaxCat.mkSeq (c:SyntaxCat) : SyntaxCat := { ann := .none, name := q`Init. def SyntaxCat.mkCommaSepBy (c:SyntaxCat) : SyntaxCat := { ann := .none, name := q`Init.CommaSepBy, args := #[c] } def SyntaxCat.mkSpaceSepBy (c:SyntaxCat) : SyntaxCat := { ann := .none, name := q`Init.SpaceSepBy, args := #[c] } def SyntaxCat.mkSpacePrefixSepBy (c:SyntaxCat) : SyntaxCat := { ann := .none, name := q`Init.SpacePrefixSepBy, args := #[c] } +def SyntaxCat.mkNewlineSepBy (c:SyntaxCat) : SyntaxCat := { ann := .none, name := q`Init.NewlineSepBy, args := #[c] } def initDialect : Dialect := BuiltinM.create! "Init" #[] do let Ident : ArgDeclKind := .cat <| .atom .none q`Init.Ident @@ -56,6 +57,8 @@ def initDialect : Dialect := BuiltinM.create! "Init" #[] do declareCat q`Init.SpacePrefixSepBy #["a"] + declareCat q`Init.NewlineSepBy #["a"] + let QualifiedIdent := q`Init.QualifiedIdent declareCat QualifiedIdent declareOp { diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index 92c52cf284..a242257375 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -1173,6 +1173,8 @@ partial def catElaborator (c : SyntaxCat) : TypingContext → Syntax → ElabM T elabSeqWith c .space "spaceSepBy" (·.getSepArgs) | q`Init.SpacePrefixSepBy => elabSeqWith c .spacePrefix "spacePrefixSepBy" (·.getArgs) + | q`Init.NewlineSepBy => + elabSeqWith c .newline "newlineSepBy" (·.getArgs) | _ => assert! c.args.isEmpty elabOperation diff --git a/Strata/DDM/Format.lean b/Strata/DDM/Format.lean index 03fdd0f164..debc8df28b 100644 --- a/Strata/DDM/Format.lean +++ b/Strata/DDM/Format.lean @@ -316,7 +316,7 @@ private def SyntaxDefAtom.formatArgs (opts : FormatOptions) (args : Array PrecFo match stx with | .ident lvl prec _ => let ⟨r, innerPrec⟩ := args[lvl]! - if prec > 0 ∧ (innerPrec ≤ prec ∨ opts.alwaysParen) then + if prec > 0 ∧ (innerPrec < prec ∨ opts.alwaysParen) then f!"({r})" else r @@ -397,6 +397,13 @@ private partial def ArgF.mformatM {α} : ArgF α → FormatM PrecFormat | .spacePrefix => .atom <$> entries.foldlM (init := .nil) fun p a => return (p ++ " " ++ (← a.mformatM).format) + | .newline => + if z : entries.size = 0 then + pure (.atom .nil) + else do + let f i q s := return s ++ .line ++ (← entries[i].mformatM).format + let a := (← entries[0].mformatM).format + .atom <$> entries.size.foldlM f (start := 1) a private partial def ppArgs (f : StrataFormat) (rargs : Array Arg) : FormatM PrecFormat := if rargs.isEmpty then diff --git a/Strata/DDM/Integration/Java/Gen.lean b/Strata/DDM/Integration/Java/Gen.lean index 577cf00dc4..0004d4a6fd 100644 --- a/Strata/DDM/Integration/Java/Gen.lean +++ b/Strata/DDM/Integration/Java/Gen.lean @@ -117,14 +117,14 @@ partial def syntaxCatToJavaType (cat : SyntaxCat) : JavaType := else if abstractCategories.contains cat.name then .simple (abstractJavaName cat.name) else match cat.name with - | ⟨"Init", "Option"⟩ => + | q`Init.Option => match cat.args[0]? with | some inner => .optional (syntaxCatToJavaType inner) | none => panic! "Init.Option requires a type argument" - | ⟨"Init", "Seq"⟩ | ⟨"Init", "CommaSepBy"⟩ => + | q`Init.Seq | q`Init.CommaSepBy | q`Init.NewlineSepBy | q`Init.SpaceSepBy | q`Init.SpacePrefixSepBy => match cat.args[0]? with | some inner => .list (syntaxCatToJavaType inner) - | none => panic! "Init.Seq/CommaSepBy requires a type argument" + | none => panic! "List category requires a type argument" | ⟨"Init", _⟩ => panic! s!"Unknown Init category: {cat.name.name}" | ⟨_, name⟩ => .simple (escapeJavaName (toPascalCase name)) @@ -132,12 +132,23 @@ def argDeclKindToJavaType : ArgDeclKind → JavaType | .type _ => .simple "Expr" | .cat c => syntaxCatToJavaType c +/-- Get Ion separator name for a list category, or none if not a list. -/ +def getSeparator (c : SyntaxCat) : Option String := + match c.name with + | q`Init.Seq => some "seq" + | q`Init.CommaSepBy => some "commaSepList" + | q`Init.NewlineSepBy => some "newlineSepList" + | q`Init.SpaceSepBy => some "spaceSepList" + | q`Init.SpacePrefixSepBy => some "spacePrefixedList" + | _ => none + /-- Extract the QualifiedIdent for categories that need Java interfaces, or none for primitives. -/ partial def syntaxCatToQualifiedName (cat : SyntaxCat) : Option QualifiedIdent := if primitiveCategories.contains cat.name then none else if abstractCategories.contains cat.name then some cat.name else match cat.name with - | ⟨"Init", "Option"⟩ | ⟨"Init", "Seq"⟩ | ⟨"Init", "CommaSepBy"⟩ => + | q`Init.Option | q`Init.Seq | q`Init.CommaSepBy + | q`Init.NewlineSepBy | q`Init.SpaceSepBy | q`Init.SpacePrefixSepBy => cat.args[0]?.bind syntaxCatToQualifiedName | ⟨"Init", _⟩ => none | qid => some qid @@ -178,8 +189,7 @@ structure NameAssignments where /-! ## Code Generation -/ def argDeclToJavaField (decl : ArgDecl) : JavaField := - { name := escapeJavaName decl.ident - type := argDeclKindToJavaType decl.kind } + { name := escapeJavaName decl.ident, type := argDeclKindToJavaType decl.kind } def JavaField.toParam (f : JavaField) : String := s!"{f.type.toJava} {f.name}" @@ -225,8 +235,9 @@ def generateNodeInterface (package : String) (categories : List String) : String def generateStubInterface (package : String) (name : String) : String × String := (s!"{name}.java", s!"package {package};\n\npublic non-sealed interface {name} extends Node \{}\n") -def generateSerializer (package : String) : String := +def generateSerializer (package : String) (separatorMap : String) : String := serializerTemplate.replace templatePackage package + |>.replace "/*SEPARATOR_MAP*/" separatorMap /-- Assign unique Java names to all generated types -/ def assignAllNames (d : Dialect) : NameAssignments := @@ -240,7 +251,7 @@ def assignAllNames (d : Dialect) : NameAssignments := let cats := if cats.contains op.category then cats else cats.push op.category let refs := op.argDecls.toArray.foldl (init := refs) fun refs arg => match arg.kind with - | .type _ => refs.insert ⟨"Init", "Expr"⟩ + | .type _ => refs.insert q`Init.Expr | .cat c => match syntaxCatToQualifiedName c with | some qid => refs.insert qid | none => refs @@ -307,17 +318,30 @@ def opDeclToJavaRecord (dialectName : String) (names : NameAssignments) (op : Op fields := op.argDecls.toArray.map argDeclToJavaField } def generateBuilders (package : String) (dialectName : String) (d : Dialect) (names : NameAssignments) : String := - let method (op : OpDecl) := + let methods (op : OpDecl) := let fields := op.argDecls.toArray.map argDeclToJavaField - let (ps, as) := fields.foldl (init := (#[], #[])) fun (ps, as) f => + let (ps, as, checks) := fields.foldl (init := (#[], #[], #[])) fun (ps, as, checks) f => match f.type with - | .simple "java.math.BigInteger" _ => (ps.push s!"long {f.name}", as.push s!"java.math.BigInteger.valueOf({f.name})") - | .simple "java.math.BigDecimal" _ => (ps.push s!"double {f.name}", as.push s!"java.math.BigDecimal.valueOf({f.name})") - | t => (ps.push s!"{t.toJava} {f.name}", as.push f.name) + | .simple "java.math.BigInteger" _ => + (ps.push s!"long {f.name}", + as.push s!"java.math.BigInteger.valueOf({f.name})", + checks.push s!"if ({f.name} < 0) throw new IllegalArgumentException(\"{f.name} must be non-negative\");") + | .simple "java.math.BigDecimal" _ => (ps.push s!"double {f.name}", as.push s!"java.math.BigDecimal.valueOf({f.name})", checks) + | t => (ps.push s!"{t.toJava} {f.name}", as.push f.name, checks) let methodName := escapeJavaName op.name - s!" public static {names.categories[op.category]!} {methodName}({", ".intercalate ps.toList}) \{ return new {names.operators[(op.category, op.name)]!}(SourceRange.NONE{if as.isEmpty then "" else ", " ++ ", ".intercalate as.toList}); }" - let methods := d.declarations.filterMap fun | .op op => some (method op) | _ => none - s!"package {package};\n\npublic class {dialectName} \{\n{"\n".intercalate methods.toList}\n}\n" + let returnType := names.categories[op.category]! + let recordName := names.operators[(op.category, op.name)]! + let checksStr := if checks.isEmpty then "" else " ".intercalate checks.toList ++ " " + let argsStr := if as.isEmpty then "" else ", " ++ ", ".intercalate as.toList + let paramsStr := ", ".intercalate ps.toList + -- Overload with SourceRange parameter + let srParams := if ps.isEmpty then "SourceRange sourceRange" else s!"SourceRange sourceRange, {paramsStr}" + let withSR := s!" public static {returnType} {methodName}({srParams}) \{ {checksStr}return new {recordName}(sourceRange{argsStr}); }" + -- Convenience overload without SourceRange + let withoutSR := s!" public static {returnType} {methodName}({paramsStr}) \{ {checksStr}return new {recordName}(SourceRange.NONE{argsStr}); }" + s!"{withSR}\n{withoutSR}" + let allMethods := d.declarations.filterMap fun | .op op => some (methods op) | _ => none + s!"package {package};\n\npublic class {dialectName} \{\n{"\n\n".intercalate allMethods.toList}\n}\n" def generateDialect (d : Dialect) (package : String) : Except String GeneratedFiles := do let names := assignAllNames d @@ -351,13 +375,30 @@ def generateDialect (d : Dialect) (package : String) : Except String GeneratedFi -- All interface names for Node permits clause let allInterfaceNames := (sealedInterfaces ++ stubInterfaces).map (·.1.dropRight 5) + -- Generate separator map for list fields + let separatorEntries := d.declarations.toList.filterMap fun decl => + match decl with + | .op op => + let opName := s!"{d.name}.{op.name}" + let fieldEntries := op.argDecls.toArray.toList.filterMap fun arg => + match arg.kind with + | .cat c => match getSeparator c with + | some sep => some s!"\"{escapeJavaName arg.ident}\", \"{sep}\"" + | none => none + | _ => none + if fieldEntries.isEmpty then none + else some s!" \"{opName}\", java.util.Map.of({", ".intercalate fieldEntries})" + | _ => none + let separatorMap := if separatorEntries.isEmpty then "java.util.Map.of()" + else s!"java.util.Map.of(\n{",\n".intercalate separatorEntries})" + return { sourceRange := generateSourceRange package node := generateNodeInterface package allInterfaceNames interfaces := sealedInterfaces.toArray ++ stubInterfaces.toArray records := records.toArray builders := (s!"{names.builders}.java", generateBuilders package names.builders d names) - serializer := generateSerializer package + serializer := generateSerializer package separatorMap } /-! ## File Output -/ diff --git a/Strata/DDM/Integration/Java/templates/IonSerializer.java b/Strata/DDM/Integration/Java/templates/IonSerializer.java index 2a0157fca7..ae1d512215 100644 --- a/Strata/DDM/Integration/Java/templates/IonSerializer.java +++ b/Strata/DDM/Integration/Java/templates/IonSerializer.java @@ -6,6 +6,8 @@ public class IonSerializer { private final IonSystem ion; + private static final java.util.Map> SEPARATORS = /*SEPARATOR_MAP*/; + public IonSerializer(IonSystem ion) { this.ion = ion; } @@ -22,14 +24,17 @@ public IonValue serialize(Node node) { private IonSexp serializeNode(Node node) { IonSexp sexp = ion.newEmptySexp(); - sexp.add(ion.newSymbol(node.operationName())); + String opName = node.operationName(); + sexp.add(ion.newSymbol(opName)); sexp.add(serializeSourceRange(node.sourceRange())); + var fieldSeps = SEPARATORS.getOrDefault(opName, java.util.Map.of()); for (var component : node.getClass().getRecordComponents()) { if (component.getName().equals("sourceRange")) continue; try { java.lang.Object value = component.getAccessor().invoke(node); - sexp.add(serializeArg(value, component.getType(), component.getGenericType())); + String sep = fieldSeps.get(component.getName()); + sexp.add(serializeArg(value, sep, component.getType())); } catch (java.lang.Exception e) { throw new java.lang.RuntimeException("Failed to serialize " + component.getName(), e); } @@ -54,7 +59,7 @@ private IonValue serializeSourceRange(SourceRange sr) { return sexp; } - private IonValue serializeArg(java.lang.Object value, java.lang.Class type, java.lang.reflect.Type genericType) { + private IonValue serializeArg(java.lang.Object value, String sep, java.lang.Class type) { if (value == null) { return serializeOption(java.util.Optional.empty()); } @@ -80,7 +85,7 @@ private IonValue serializeArg(java.lang.Object value, java.lang.Class type, j return serializeOption(opt); } if (value instanceof java.util.List list) { - return serializeSeq(list, genericType); + return serializeSeq(list, sep != null ? sep : "seq"); } throw new java.lang.IllegalArgumentException("Unsupported type: " + type); } @@ -129,17 +134,17 @@ private IonValue serializeOption(java.util.Optional opt) { sexp.add(ion.newSymbol("option")); sexp.add(ion.newNull()); if (opt.isPresent()) { - sexp.add(serializeArg(opt.get(), opt.get().getClass(), opt.get().getClass())); + sexp.add(serializeArg(opt.get(), null, opt.get().getClass())); } return sexp; } - private IonValue serializeSeq(java.util.List list, java.lang.reflect.Type genericType) { + private IonValue serializeSeq(java.util.List list, String sepType) { IonSexp sexp = ion.newEmptySexp(); - sexp.add(ion.newSymbol("seq")); + sexp.add(ion.newSymbol(sepType)); sexp.add(ion.newNull()); for (java.lang.Object item : list) { - sexp.add(serializeArg(item, item.getClass(), item.getClass())); + sexp.add(serializeArg(item, null, item.getClass())); } return sexp; } diff --git a/Strata/DDM/Integration/Lean/Gen.lean b/Strata/DDM/Integration/Lean/Gen.lean index 031604d6f8..aeb89b5a98 100644 --- a/Strata/DDM/Integration/Lean/Gen.lean +++ b/Strata/DDM/Integration/Lean/Gen.lean @@ -744,6 +744,8 @@ partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) (unwrap : Bool := false) toAstApplyArgSeq v cat ``SepFormat.space | q`Init.SpacePrefixSepBy => do toAstApplyArgSeq v cat ``SepFormat.spacePrefix + | q`Init.NewlineSepBy => do + toAstApplyArgSeq v cat ``SepFormat.newline | q`Init.Seq => do toAstApplyArgSeq v cat ``SepFormat.none | q`Init.Option => do @@ -909,6 +911,8 @@ partial def getOfIdentArgWithUnwrap (varName : String) (cat : SyntaxCat) (unwrap getOfIdentArgSeq varName cat e ``SepFormat.space | q`Init.SpacePrefixSepBy => do getOfIdentArgSeq varName cat e ``SepFormat.spacePrefix + | q`Init.NewlineSepBy => do + getOfIdentArgSeq varName cat e ``SepFormat.newline | q`Init.Seq => do getOfIdentArgSeq varName cat e ``SepFormat.none | q`Init.Option => do diff --git a/Strata/DDM/Integration/Lean/ToExpr.lean b/Strata/DDM/Integration/Lean/ToExpr.lean index 16b5e302e4..ac86492a63 100644 --- a/Strata/DDM/Integration/Lean/ToExpr.lean +++ b/Strata/DDM/Integration/Lean/ToExpr.lean @@ -40,6 +40,7 @@ instance : ToExpr SepFormat where | .comma => mkConst ``SepFormat.comma | .space => mkConst ``SepFormat.space | .spacePrefix => mkConst ``SepFormat.spacePrefix + | .newline => mkConst ``SepFormat.newline end SepFormat diff --git a/Strata/DDM/Parser.lean b/Strata/DDM/Parser.lean index c1a3ced85f..412ddf8a70 100644 --- a/Strata/DDM/Parser.lean +++ b/Strata/DDM/Parser.lean @@ -228,10 +228,8 @@ private partial def whitespace : ParserFn := fun c s => let curr := c.get j match curr with | '/' => - match c.tokens.matchPrefix c.inputString i with - | some _ => s - | none => - andthenFn (takeUntilFn (fun c => c = '\n')) whitespace c (s.next c j) + -- // is always a line comment, regardless of whether / is a token + andthenFn (takeUntilFn (fun c => c = '\n')) whitespace c (s.next c j) | '*' => match c.tokens.matchPrefix c.inputString i with | some _ => s @@ -897,7 +895,7 @@ partial def catParser (ctx : ParsingContext) (cat : SyntaxCat) (metadata : Metad assert! cat.args.size = 1 let isNonempty := q`StrataDDL.nonempty ∈ metadata commaSepByParserHelper isNonempty <$> catParser ctx cat.args[0]! - | q`Init.SpaceSepBy | q`Init.SpacePrefixSepBy | q`Init.Seq => + | q`Init.SpaceSepBy | q`Init.SpacePrefixSepBy | q`Init.NewlineSepBy | q`Init.Seq => assert! cat.args.size = 1 let isNonempty := q`StrataDDL.nonempty ∈ metadata (if isNonempty then many1Parser else manyParser) <$> catParser ctx cat.args[0]! diff --git a/Strata/DL/Lambda/LExprEval.lean b/Strata/DL/Lambda/LExprEval.lean index 73fb6f7203..a1001974f5 100644 --- a/Strata/DL/Lambda/LExprEval.lean +++ b/Strata/DL/Lambda/LExprEval.lean @@ -164,7 +164,7 @@ def eval (n : Nat) (σ : LState TBase) (e : (LExpr TBase.mono)) -- At least one argument in the function call is symbolic. new_e | none => - -- Not a call of a factory function. + -- Not a call of a factory function - go through evalCore evalCore n' σ e def evalCore (n' : Nat) (σ : LState TBase) (e : LExpr TBase.mono) : LExpr TBase.mono := diff --git a/Strata/DL/Lambda/LExprWF.lean b/Strata/DL/Lambda/LExprWF.lean index 0fbedf2cc8..fc45c58adc 100644 --- a/Strata/DL/Lambda/LExprWF.lean +++ b/Strata/DL/Lambda/LExprWF.lean @@ -256,11 +256,23 @@ theorem varOpen_of_varClose {T} {GenericTy} [BEq T.Metadata] [LawfulBEq T.Metada /-! ### Substitution on `LExpr`s -/ /-- -Substitute `(.fvar x _)` in `e` with `s`. Note that unlike `substK`, `varClose`, -and `varOpen`, this function is agnostic of types. +Increment all bound variable indices in `e` by `n`. Used to avoid capture when +substituting under binders. +-/ +def liftBVars (n : Nat) (e : LExpr ⟨T, GenericTy⟩) : LExpr ⟨T, GenericTy⟩ := + match e with + | .const _ _ => e | .op _ _ _ => e | .fvar _ _ _ => e + | .bvar m i => .bvar m (i + n) + | .abs m ty e' => .abs m ty (liftBVars n e') + | .quant m qk ty tr' e' => .quant m qk ty (liftBVars n tr') (liftBVars n e') + | .app m fn e' => .app m (liftBVars n fn) (liftBVars n e') + | .ite m c t e' => .ite m (liftBVars n c) (liftBVars n t) (liftBVars n e') + | .eq m e1 e2 => .eq m (liftBVars n e1) (liftBVars n e2) -Also see function `subst`, where `subst s e` substitutes the outermost _bound_ -variable in `e` with `s`. +/-- +Substitute `(.fvar x _)` in `e` with `to`. Does NOT lift de Bruijn indices in `to` +when going under binders - safe when `to` contains no bvars (e.g., substituting +fvar→fvar). Use `substFvarLifting` when `to` contains bvars. -/ def substFvar [BEq T.IDMeta] (e : LExpr ⟨T, GenericTy⟩) (fr : T.Identifier) (to : LExpr ⟨T, GenericTy⟩) : (LExpr ⟨T, GenericTy⟩) := @@ -273,6 +285,28 @@ def substFvar [BEq T.IDMeta] (e : LExpr ⟨T, GenericTy⟩) (fr : T.Identifier) | .ite m c t e' => .ite m (substFvar c fr to) (substFvar t fr to) (substFvar e' fr to) | .eq m e1 e2 => .eq m (substFvar e1 fr to) (substFvar e2 fr to) +/-- +Like `substFvar`, but properly lifts de Bruijn indices in `to` when going under +binders. Use this when `to` contains bound variables that should be preserved. +-/ +def substFvarLifting [BEq T.IDMeta] (e : LExpr ⟨T, GenericTy⟩) (fr : T.Identifier) (to : LExpr ⟨T, GenericTy⟩) + : (LExpr ⟨T, GenericTy⟩) := + go e 0 +where + go (e : LExpr ⟨T, GenericTy⟩) (depth : Nat) : LExpr ⟨T, GenericTy⟩ := + match e with + | .const _ _ => e | .bvar _ _ => e | .op _ _ _ => e + | .fvar _ name _ => if name == fr then liftBVars depth to else e + | .abs m ty e' => .abs m ty (go e' (depth + 1)) + | .quant m qk ty tr' e' => .quant m qk ty (go tr' (depth + 1)) (go e' (depth + 1)) + | .app m fn e' => .app m (go fn depth) (go e' depth) + | .ite m c t f => .ite m (go c depth) (go t depth) (go f depth) + | .eq m e1 e2 => .eq m (go e1 depth) (go e2 depth) + +def substFvarsLifting [BEq T.IDMeta] (e : LExpr ⟨T, GenericTy⟩) (sm : Map T.Identifier (LExpr ⟨T, GenericTy⟩)) + : LExpr ⟨T, GenericTy⟩ := + List.foldl (fun e (var, s) => substFvarLifting e var s) e sm + def substFvars [BEq T.IDMeta] (e : LExpr ⟨T, GenericTy⟩) (sm : Map T.Identifier (LExpr ⟨T, GenericTy⟩)) : LExpr ⟨T, GenericTy⟩ := List.foldl (fun e (var, s) => substFvar e var s) e sm diff --git a/Strata/DL/SMT/Encoder.lean b/Strata/DL/SMT/Encoder.lean index 8a8b74e024..c8d95fb80c 100644 --- a/Strata/DL/SMT/Encoder.lean +++ b/Strata/DL/SMT/Encoder.lean @@ -89,6 +89,10 @@ def encodeType (ty : TermType) : EncoderM String := do | .trigger => return "Trigger" | .bitvec n => return s!"(_ BitVec {n})" | .option oty => return s!"(Option {← encodeType oty})" + | .constr "Map" [k, v] => + let k' ← encodeType k + let v' ← encodeType v + return s!"(Array {k'} {v'})" | .constr id targs => -- let targs' ← targs.mapM (fun t => encodeType t) let targs' ← go targs diff --git a/Strata/Languages/C_Simp/DDMTransform/Parse.lean b/Strata/Languages/C_Simp/DDMTransform/Parse.lean index 0f0e66a704..0bd232448c 100644 --- a/Strata/Languages/C_Simp/DDMTransform/Parse.lean +++ b/Strata/Languages/C_Simp/DDMTransform/Parse.lean @@ -123,24 +123,24 @@ op annotation (a : Annotation) : Statement => a; -- Test -private def testPrg := -#strata -program C_Simp; - -int procedure simpleTest (x: int, y: int) - //@pre y > 0; - //@post true; -{ - var z : int; - z = x + y; - //@assert [test_assert] z > x; - if (z > 10) { - z = z - 1; - } else { - z = z + 1; - } - //@assume [test_assume] z > 0; - return 0; -} - -#end +-- private def testPrg := +-- #strata +-- program C_Simp; + +-- int procedure simpleTest (x: int, y: int) +-- //@pre y > 0; +-- //@post true; +-- { +-- var z : int; +-- z = x + y; +-- //@assert [test_assert] z > x; +-- if (z > 10) { +-- z = z - 1; +-- } else { +-- z = z + 1; +-- } +-- //@assume [test_assume] z > 0; +-- return 0; +-- } + +-- #end diff --git a/Strata/Languages/Core/Env.lean b/Strata/Languages/Core/Env.lean index 2ecb1694d7..0848603f9d 100644 --- a/Strata/Languages/Core/Env.lean +++ b/Strata/Languages/Core/Env.lean @@ -256,7 +256,7 @@ def Env.genFVar (E : Env) (xt : (Lambda.IdentT Lambda.LMonoTy Visibility)) : let (xid, E) := E.genVar xt.ident let xe := match xt.ty? with | none => .fvar () xid none - | some xty => .fvar () xid xty + | some xty => .fvar () xid (some xty) (xe, E) /-- diff --git a/Strata/Languages/Core/Procedure.lean b/Strata/Languages/Core/Procedure.lean index e1fd6bf428..8bd650879e 100644 --- a/Strata/Languages/Core/Procedure.lean +++ b/Strata/Languages/Core/Procedure.lean @@ -79,11 +79,11 @@ instance : Std.ToFormat Procedure.CheckAttr where structure Procedure.Check where expr : Expression.Expr attr : CheckAttr := .Default - md : Imperative.MetaData Expression := #[] + md : Imperative.MetaData Expression deriving Repr, DecidableEq instance : Inhabited Procedure.Check where - default := { expr := Inhabited.default } + default := { expr := Inhabited.default, md := #[] } instance : ToFormat Procedure.Check where format c := f!"{c.expr}{c.attr}" diff --git a/Strata/Languages/Core/SMTEncoder.lean b/Strata/Languages/Core/SMTEncoder.lean index 7819858bbc..df6d21c4d5 100644 --- a/Strata/Languages/Core/SMTEncoder.lean +++ b/Strata/Languages/Core/SMTEncoder.lean @@ -88,6 +88,7 @@ private def lMonoTyToSMTString (ty : LMonoTy) : String := | .tcons "real" [] => "Real" | .tcons "string" [] => "String" | .tcons "regex" [] => "RegLan" + | .tcons "Map" [k, v] => s!"(Array {lMonoTyToSMTString k} {lMonoTyToSMTString v})" | .tcons name args => if args.isEmpty then name else s!"({name} {String.intercalate " " (args.map lMonoTyToSMTString)})" @@ -346,13 +347,21 @@ partial def appToSMTTerm (E : Env) (bvs : BoundVars) (e : LExpr CoreLParams.mono let (op, retty, ctx) ← toSMTOp E fn fnty ctx let (e1t, ctx) ← toSMTTerm E bvs e1 ctx .ok (op (e1t :: acc) retty, ctx) - | .app _ (.fvar _ fn (.some (.arrow intty outty))) e1 => do + | .app _ (.fvar _ fn (.some fnty)) e1 => do + let tys := LMonoTy.destructArrow fnty + let outty := tys.getLast (by exact @LMonoTy.destructArrow_non_empty fnty) + let intys := tys.take (tys.length - 1) let (smt_outty, ctx) ← LMonoTy.toSMTType E outty ctx - let (smt_intty, ctx) ← LMonoTy.toSMTType E intty ctx - let argvars := [TermVar.mk (toString $ format intty) smt_intty] let (e1t, ctx) ← toSMTTerm E bvs e1 ctx + let allArgs := e1t :: acc + let mut argvars : List TermVar := [] + let mut ctx := ctx + for inty in intys do + let (smt_inty, ctx') ← LMonoTy.toSMTType E inty ctx + ctx := ctx' + argvars := argvars ++ [TermVar.mk (toString $ format inty) smt_inty] let uf := UF.mk (id := (toString $ format fn)) (args := argvars) (out := smt_outty) - .ok (((Term.app (.uf uf) [e1t] smt_outty)), ctx) + .ok (Term.app (.uf uf) allArgs smt_outty, ctx) | .app _ _ _ => .error f!"Cannot encode expression {e}" @@ -576,9 +585,9 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte | none => .ok (ctx.addUF uf, !ctx.ufs.contains uf) | some body => -- Substitute the formals in the function body with appropriate - -- `.bvar`s. + -- `.bvar`s. Use substFvarsLifting to properly lift indices under binders. let bvars := (List.range formals.length).map (fun i => LExpr.bvar () i) - let body := LExpr.substFvars body (formals.zip bvars) + let body := LExpr.substFvarsLifting body (formals.zip bvars) let (term, ctx) ← toSMTTerm E bvs body ctx .ok (ctx.addIF uf term, !ctx.ifs.contains ({ uf := uf, body := term })) if isNew then diff --git a/Strata/Languages/Core/Verifier.lean b/Strata/Languages/Core/Verifier.lean index 372b2d3054..07687a7102 100644 --- a/Strata/Languages/Core/Verifier.lean +++ b/Strata/Languages/Core/Verifier.lean @@ -474,20 +474,30 @@ def toDiagnosticModel (vcr : Core.VCResult) : Option DiagnosticModel := do match vcr.result with | .pass => none -- Verification succeeded, no diagnostic | result => - let fileRangeElem ← vcr.obligation.metadata.findElem Imperative.MetaData.fileRange - match fileRangeElem.value with - | .fileRange fileRange => - let message := match result with - | .fail => "assertion does not hold" - | .unknown => "assertion could not be proved" - | .implementationError msg => s!"verification error: {msg}" - | _ => panic "impossible" - - some { - fileRange := fileRange - message := message - } - | _ => none + let message := match result with + | .fail => "assertion does not hold" + | .unknown => "assertion could not be proved" + | .implementationError msg => s!"verification error: {msg}" + | _ => panic "impossible" + + let .some fileRangeElem := vcr.obligation.metadata.findElem Imperative.MetaData.fileRange + | some { + fileRange := default + message := s!"Internal error: diagnostics without position! obligation label: {repr vcr.obligation.label}" + } + + let result := match fileRangeElem.value with + | .fileRange fileRange => + some { + fileRange := fileRange + message := message + } + | _ => + some { + fileRange := default + message := s!"Internal error: diagnostics without position! Metadata value for fileRange key was not a fileRange. obligation label: {repr vcr.obligation.label}" + } + result structure Diagnostic where start : Lean.Position diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 2e6f4b8ef3..1c39abcb02 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -75,18 +75,28 @@ instance : Inhabited HighType where default := .TVoid instance : Inhabited Parameter where - default := { name := "", type := .TVoid } + default := { name := "", type := ⟨.TVoid, #[]⟩ } -def translateHighType (arg : Arg) : TransM HighType := do +/-- Create a HighTypeMd with the given metadata -/ +def mkHighTypeMd (t : HighType) (md : MetaData Core.Expression) : HighTypeMd := ⟨t, md⟩ + +/-- Create a StmtExprMd with the given metadata -/ +def mkStmtExprMd (e : StmtExpr) (md : MetaData Core.Expression) : StmtExprMd := ⟨e, md⟩ + +partial def translateHighType (arg : Arg) : TransM HighTypeMd := do + let md ← getArgMetaData arg match arg with | .op op => match op.name, op.args with - | q`Laurel.intType, _ => return .TInt - | q`Laurel.boolType, _ => return .TBool + | q`Laurel.intType, _ => return mkHighTypeMd .TInt md + | q`Laurel.boolType, _ => return mkHighTypeMd .TBool md + | q`Laurel.arrayType, #[elemArg] => + let elemType ← translateHighType elemArg + return mkHighTypeMd (.Applied (mkHighTypeMd (.UserDefined "Array") md) [elemType]) md | q`Laurel.compositeType, #[nameArg] => let name ← translateIdent nameArg - return .UserDefined name - | _, _ => TransM.error s!"translateHighType expects intType, boolType or compositeType, got {repr op.name}" + return mkHighTypeMd (.UserDefined name) md + | _, _ => TransM.error s!"translateHighType expects intType, boolType, arrayType or compositeType, got {repr op.name}" | _ => TransM.error s!"translateHighType expects operation" def translateNat (arg : Arg) : TransM Nat := do @@ -118,42 +128,56 @@ instance : Inhabited Procedure where name := "" inputs := [] outputs := [] - precondition := .LiteralBool true + preconditions := [] decreases := none - body := .Transparent (.LiteralBool true) + body := .Transparent ⟨.LiteralBool true, #[]⟩ } def getBinaryOp? (name : QualifiedIdent) : Option Operation := match name with | q`Laurel.add => some Operation.Add + | q`Laurel.sub => some Operation.Sub + | q`Laurel.mul => some Operation.Mul + | q`Laurel.div => some Operation.Div + | q`Laurel.mod => some Operation.Mod + | q`Laurel.divT => some Operation.DivT + | q`Laurel.modT => some Operation.ModT | q`Laurel.eq => some Operation.Eq | q`Laurel.neq => some Operation.Neq | q`Laurel.gt => some Operation.Gt | q`Laurel.lt => some Operation.Lt | q`Laurel.le => some Operation.Leq | q`Laurel.ge => some Operation.Geq + | q`Laurel.and => some Operation.And + | q`Laurel.or => some Operation.Or + | q`Laurel.implies => some Operation.Implies + | _ => none + +def getUnaryOp? (name : QualifiedIdent) : Option Operation := + match name with + | q`Laurel.not => some Operation.Not + | q`Laurel.neg => some Operation.Neg | _ => none mutual -partial def translateStmtExpr (arg : Arg) : TransM StmtExpr := do +partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do + let md ← getArgMetaData arg match arg with | .op op => match op.name, op.args with | q`Laurel.assert, #[arg0] => let cond ← translateStmtExpr arg0 - let md ← getArgMetaData (.op op) - return .Assert cond md + return mkStmtExprMd (.Assert cond) md | q`Laurel.assume, #[arg0] => let cond ← translateStmtExpr arg0 - let md ← getArgMetaData (.op op) - return .Assume cond md + return mkStmtExprMd (.Assume cond) md | q`Laurel.block, #[arg0] => let stmts ← translateSeqCommand arg0 - return .Block stmts none - | q`Laurel.literalBool, #[arg0] => return .LiteralBool (← translateBool arg0) + return mkStmtExprMd (.Block stmts none) md + | q`Laurel.literalBool, #[arg0] => return mkStmtExprMd (.LiteralBool (← translateBool arg0)) md | q`Laurel.int, #[arg0] => let n ← translateNat arg0 - return .LiteralInt n + return mkStmtExprMd (.LiteralInt n) md | q`Laurel.varDecl, #[arg0, typeArg, assignArg] => let name ← translateIdent arg0 let varType ← match typeArg with @@ -167,28 +191,27 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExpr := do | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" | .option _ none => pure none | _ => TransM.error s!"assignArg {repr assignArg} didn't match expected pattern for variable {name}" - return .LocalVariable name varType value + return mkStmtExprMd (.LocalVariable name varType value) md | q`Laurel.identifier, #[arg0] => let name ← translateIdent arg0 - return .Identifier name + return mkStmtExprMd (.Identifier name) md | q`Laurel.parenthesis, #[arg0] => translateStmtExpr arg0 | q`Laurel.assign, #[arg0, arg1] => let target ← translateStmtExpr arg0 let value ← translateStmtExpr arg1 - let md ← getArgMetaData (.op op) - return .Assign target value md + return mkStmtExprMd (.Assign target value) md | q`Laurel.call, #[arg0, argsSeq] => let callee ← translateStmtExpr arg0 - let calleeName := match callee with + let calleeName := match callee.val with | .Identifier name => name | _ => "" let argsList ← match argsSeq with | .seq _ .comma args => args.toList.mapM translateStmtExpr | _ => pure [] - return .StaticCall calleeName argsList + return mkStmtExprMd (.StaticCall calleeName argsList) md | q`Laurel.return, #[arg0] => let value ← translateStmtExpr arg0 - return .Return (some value) + return mkStmtExprMd (.Return (some value)) md | q`Laurel.ifThenElse, #[arg0, arg1, elseArg] => let cond ← translateStmtExpr arg0 let thenBranch ← translateStmtExpr arg1 @@ -197,30 +220,62 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExpr := do | q`Laurel.optionalElse, #[elseArg0] => translateStmtExpr elseArg0 >>= (pure ∘ some) | _, _ => pure none | _ => pure none - return .IfThenElse cond thenBranch elseBranch + return mkStmtExprMd (.IfThenElse cond thenBranch elseBranch) md | q`Laurel.fieldAccess, #[objArg, fieldArg] => let obj ← translateStmtExpr objArg let field ← translateIdent fieldArg - return .FieldSelect obj field + return mkStmtExprMd (.FieldSelect obj field) md + | q`Laurel.arrayIndex, #[arrArg, idxArg] => + let arr ← translateStmtExpr arrArg + let idx ← translateStmtExpr idxArg + return mkStmtExprMd (.StaticCall "Array.Get" [arr, idx]) md + | q`Laurel.while, #[condArg, invSeqArg, bodyArg] => + let cond ← translateStmtExpr condArg + let invariants ← match invSeqArg with + | .seq _ _ clauses => clauses.toList.mapM fun arg => match arg with + | .op invOp => match invOp.name, invOp.args with + | q`Laurel.invariantClause, #[exprArg] => translateStmtExpr exprArg + | _, _ => TransM.error "Expected invariantClause" + | _ => TransM.error "Expected operation" + | _ => pure [] + let body ← translateStmtExpr bodyArg + return mkStmtExprMd (.While cond invariants none body) md + | _, #[arg0] => match getUnaryOp? op.name with + | some primOp => + let inner ← translateStmtExpr arg0 + return mkStmtExprMd (.PrimitiveOp primOp [inner]) md + | none => TransM.error s!"Unknown unary operation: {op.name}" + | q`Laurel.forallExpr, #[nameArg, tyArg, bodyArg] => + let name ← translateIdent nameArg + let ty ← translateHighType tyArg + let body ← translateStmtExpr bodyArg + return mkStmtExprMd (.Forall name ty body) md + | q`Laurel.existsExpr, #[nameArg, tyArg, bodyArg] => + let name ← translateIdent nameArg + let ty ← translateHighType tyArg + let body ← translateStmtExpr bodyArg + return mkStmtExprMd (.Exists name ty body) md | _, #[arg0, arg1] => match getBinaryOp? op.name with | some primOp => let lhs ← translateStmtExpr arg0 let rhs ← translateStmtExpr arg1 - return .PrimitiveOp primOp [lhs, rhs] + return mkStmtExprMd (.PrimitiveOp primOp [lhs, rhs]) md | none => TransM.error s!"Unknown operation: {op.name}" | _, _ => TransM.error s!"Unknown operation: {op.name}" | _ => TransM.error s!"translateStmtExpr expects operation" -partial def translateSeqCommand (arg : Arg) : TransM (List StmtExpr) := do - let .seq _ .none args := arg - | TransM.error s!"translateSeqCommand expects seq" - let mut stmts : List StmtExpr := [] +partial def translateSeqCommand (arg : Arg) : TransM (List StmtExprMd) := do + let args ← match arg with + | .seq _ .none args => pure args + | .seq _ .newline args => pure args -- NewlineSepBy for block statements + | _ => TransM.error s!"translateSeqCommand expects seq or newlineSepBy" + let mut stmts : List StmtExprMd := [] for arg in args do let stmt ← translateStmtExpr arg stmts := stmts ++ [stmt] return stmts -partial def translateCommand (arg : Arg) : TransM StmtExpr := do +partial def translateCommand (arg : Arg) : TransM StmtExprMd := do translateStmtExpr arg end @@ -251,30 +306,32 @@ def parseProcedure (arg : Arg) : TransM Procedure := do | .option _ none => pure [] | _ => TransM.error s!"Expected returnParameters operation, got {repr returnParamsArg}" | _ => TransM.error s!"Expected optionalReturnType operation, got {repr returnTypeArg}" - -- Parse precondition (requires clause) - let precondition ← match requiresArg with - | .option _ (some (.op requiresOp)) => match requiresOp.name, requiresOp.args with - | q`Laurel.optionalRequires, #[exprArg] => translateStmtExpr exprArg - | _, _ => TransM.error s!"Expected optionalRequires operation, got {repr requiresOp.name}" - | .option _ none => pure (.LiteralBool true) - | _ => TransM.error s!"Expected optionalRequires operation, got {repr requiresArg}" - -- Parse postcondition (ensures clause) - let postcondition ← match ensuresArg with - | .option _ (some (.op ensuresOp)) => match ensuresOp.name, ensuresOp.args with - | q`Laurel.optionalEnsures, #[exprArg] => translateStmtExpr exprArg >>= (pure ∘ some) - | _, _ => TransM.error s!"Expected optionalEnsures operation, got {repr ensuresOp.name}" - | .option _ none => pure none - | _ => TransM.error s!"Expected optionalEnsures operation, got {repr ensuresArg}" + -- Parse preconditions (requires clauses) + let preconditions ← match requiresArg with + | .seq _ .none clauses => clauses.toList.mapM fun arg => match arg with + | .op reqOp => match reqOp.name, reqOp.args with + | q`Laurel.requiresClause, #[exprArg] => translateStmtExpr exprArg + | _, _ => TransM.error "Expected requiresClause" + | _ => TransM.error "Expected operation" + | _ => pure [] + -- Parse postconditions (ensures clauses) + let postconditions ← match ensuresArg with + | .seq _ .none clauses => clauses.toList.mapM fun arg => match arg with + | .op ensOp => match ensOp.name, ensOp.args with + | q`Laurel.ensuresClause, #[exprArg] => translateStmtExpr exprArg + | _, _ => TransM.error "Expected ensuresClause" + | _ => TransM.error "Expected operation" + | _ => pure [] let body ← translateCommand bodyArg - -- If there's a postcondition, use Opaque body; otherwise use Transparent - let procBody := match postcondition with - | some post => Body.Opaque post (some body) .nondeterministic none - | none => Body.Transparent body + -- If there are postconditions, use Opaque body; otherwise use Transparent + let procBody := match postconditions with + | [] => Body.Transparent body + | posts => Body.Opaque posts (some body) .nondeterministic none return { name := name inputs := parameters outputs := returnParameters - precondition := precondition + preconditions := preconditions decreases := none body := procBody } @@ -283,19 +340,40 @@ def parseProcedure (arg : Arg) : TransM Procedure := do | _, _ => TransM.error s!"parseProcedure expects procedure, got {repr op.name}" -def parseTopLevel (arg : Arg) : TransM (Option Procedure) := do +def parseConstrainedType (arg : Arg) : TransM ConstrainedType := do + let .op op := arg + | TransM.error s!"parseConstrainedType expects operation" + match op.name, op.args with + | q`Laurel.constrainedType, #[nameArg, valueNameArg, baseArg, constraintArg, witnessArg] => + let name ← translateIdent nameArg + let valueName ← translateIdent valueNameArg + let base ← translateHighType baseArg + let constraint ← translateStmtExpr constraintArg + let witness ← translateStmtExpr witnessArg + return { name, base, valueName, constraint, witness } + | _, _ => + TransM.error s!"parseConstrainedType expects constrainedType, got {repr op.name}" + +inductive TopLevelItem where + | proc (p : Procedure) + | typeDef (t : TypeDefinition) + +def parseTopLevel (arg : Arg) : TransM (Option TopLevelItem) := do let .op op := arg | TransM.error s!"parseTopLevel expects operation" match op.name, op.args with | q`Laurel.topLevelProcedure, #[procArg] => let proc ← parseProcedure procArg - return some proc + return some (.proc proc) | q`Laurel.topLevelComposite, #[_compositeArg] => -- TODO: handle composite types return none + | q`Laurel.topLevelConstrainedType, #[ctArg] => + let ct ← parseConstrainedType ctArg + return some (.typeDef (.Constrained ct)) | _, _ => - TransM.error s!"parseTopLevel expects topLevelProcedure or topLevelComposite, got {repr op.name}" + TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelConstrainedType, got {repr op.name}" /-- Translate concrete Laurel syntax into abstract Laurel syntax @@ -317,15 +395,17 @@ def parseProgram (prog : Strata.Program) : TransM Laurel.Program := do prog.commands let mut procedures : List Procedure := [] + let mut types : List TypeDefinition := [] for op in commands do let result ← parseTopLevel (.op op) match result with - | some proc => procedures := procedures ++ [proc] + | some (.proc proc) => procedures := procedures ++ [proc] + | some (.typeDef td) => types := types ++ [td] | none => pure () -- composite types are skipped for now return { staticProcedures := procedures staticFields := [] - types := [] + types := types } end Laurel diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index 54723e20bf..5b72ddd620 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -4,6 +4,7 @@ dialect Laurel; category LaurelType; op intType : LaurelType => "int"; op boolType : LaurelType => "bool"; +op arrayType (elemType: LaurelType): LaurelType => "Array" "<" elemType ">"; op compositeType (name: Ident): LaurelType => name; category StmtExpr; @@ -12,49 +13,78 @@ op int(n : Num) : StmtExpr => n; // Variable declarations category OptionalType; -op optionalType(varType: LaurelType): OptionalType => ":" varType; +op optionalType(varType: LaurelType): OptionalType => ": " varType; category OptionalAssignment; -op optionalAssignment(value: StmtExpr): OptionalAssignment => ":=" value:0; +op optionalAssignment(value: StmtExpr): OptionalAssignment => " := " value:0; op varDecl (name: Ident, varType: Option OptionalType, assignment: Option OptionalAssignment): StmtExpr => @[prec(0)] "var " name varType assignment ";"; -op call(callee: StmtExpr, args: CommaSepBy StmtExpr): StmtExpr => callee "(" args ")"; +op call(callee: StmtExpr, args: CommaSepBy StmtExpr): StmtExpr => @[prec(95)] callee:85 "(" args ")"; // Field access op fieldAccess (obj: StmtExpr, field: Ident): StmtExpr => @[prec(90)] obj "#" field; +// Array indexing +op arrayIndex (arr: StmtExpr, idx: StmtExpr): StmtExpr => @[prec(90)] arr "[" idx "]"; + // Identifiers/Variables - must come after fieldAccess so c.value parses as fieldAccess not identifier op identifier (name: Ident): StmtExpr => name; op parenthesis (inner: StmtExpr): StmtExpr => "(" inner ")"; // Assignment -op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target ":=" value ";"; +op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target " := " value ";"; // Binary operators -op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60)] lhs "+" rhs; -op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "==" rhs; -op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "!=" rhs; -op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">" rhs; -op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<" rhs; -op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs "<=" rhs; -op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs ">=" rhs; +op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " + " rhs; +op sub (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " - " rhs; +op mul (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(70), leftassoc] lhs " * " rhs; +op div (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(70), leftassoc] lhs " / " rhs; +op mod (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(70), leftassoc] lhs " % " rhs; +op divT (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(70), leftassoc] lhs " /t " rhs; +op modT (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(70), leftassoc] lhs " %t " rhs; +op eq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " == " rhs; +op neq (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " != " rhs; +op gt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " > " rhs; +op lt (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " < " rhs; +op le (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " <= " rhs; +op ge (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(40)] lhs " >= " rhs; +op and (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(30), leftassoc] lhs " && " rhs; +op or (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(20), leftassoc] lhs " || " rhs; +op implies (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(15), rightassoc] lhs " ==> " rhs; + +// Unary operators +op not (inner: StmtExpr): StmtExpr => @[prec(80)] "!" inner; +op neg (inner: StmtExpr): StmtExpr => @[prec(80)] "-" inner; + +// Quantifiers +op forallExpr (name: Ident, ty: LaurelType, body: StmtExpr): StmtExpr + => "forall(" name ": " ty ") => " body:0; +op existsExpr (name: Ident, ty: LaurelType, body: StmtExpr): StmtExpr + => "exists(" name ": " ty ") => " body:0; // If-else category OptionalElse; -op optionalElse(stmts : StmtExpr) : OptionalElse => "else" stmts; +op optionalElse(stmts : StmtExpr) : OptionalElse => "else " stmts:0; op ifThenElse (cond: StmtExpr, thenBranch: StmtExpr, elseBranch: Option OptionalElse): StmtExpr => - @[prec(20)] "if (" cond ") " thenBranch:0 elseBranch:0; + @[prec(20)] "if (" cond ") " thenBranch:0 " " elseBranch:0; op assert (cond : StmtExpr) : StmtExpr => "assert " cond ";"; op assume (cond : StmtExpr) : StmtExpr => "assume " cond ";"; op return (value : StmtExpr) : StmtExpr => "return " value ";"; -op block (stmts : Seq StmtExpr) : StmtExpr => @[prec(1000)] "{" stmts "}"; +op block (stmts : NewlineSepBy StmtExpr) : StmtExpr => @[prec(1000)] "{" indent(2, "\n" stmts) "\n}"; + +// While loops +category InvariantClause; +op invariantClause (cond: StmtExpr): InvariantClause => "\n invariant " cond:0; + +op while (cond: StmtExpr, invariants: Seq InvariantClause, body: StmtExpr): StmtExpr + => "while" "(" cond ")" invariants body:0; category Parameter; -op parameter (name: Ident, paramType: LaurelType): Parameter => name ":" paramType; +op parameter (name: Ident, paramType: LaurelType): Parameter => name ": " paramType; // Composite types category Field; @@ -68,11 +98,11 @@ op composite (name: Ident, fields: Seq Field): Composite => "composite " name "{ category OptionalReturnType; op optionalReturnType(returnType: LaurelType): OptionalReturnType => ":" returnType; -category OptionalRequires; -op optionalRequires(cond: StmtExpr): OptionalRequires => "requires" cond:0; +category RequiresClause; +op requiresClause(cond: StmtExpr): RequiresClause => "\n requires " cond:0; -category OptionalEnsures; -op optionalEnsures(cond: StmtExpr): OptionalEnsures => "ensures" cond:0; +category EnsuresClause; +op ensuresClause(cond: StmtExpr): EnsuresClause => "\n ensures " cond:0; category ReturnParameters; op returnParameters(parameters: CommaSepBy Parameter): ReturnParameters => "returns" "(" parameters ")"; @@ -81,13 +111,20 @@ category Procedure; op procedure (name : Ident, parameters: CommaSepBy Parameter, returnType: Option OptionalReturnType, returnParameters: Option ReturnParameters, - requires: Option OptionalRequires, - ensures: Option OptionalEnsures, + requires: Seq RequiresClause, + ensures: Seq EnsuresClause, body : StmtExpr) : Procedure => - "procedure " name "(" parameters ")" returnType returnParameters requires ensures body:0; + "procedure " name "(" parameters ")" returnType returnParameters requires ensures "\n" body:0; + +// Constrained types +category ConstrainedType; +op constrainedType (name: Ident, valueName: Ident, base: LaurelType, + constraint: StmtExpr, witness: StmtExpr): ConstrainedType + => "constrained " name " = " valueName ": " base " where " constraint:0 " witness " witness:0; category TopLevel; -op topLevelComposite(composite: Composite): TopLevel => composite; -op topLevelProcedure(procedure: Procedure): TopLevel => procedure; +op topLevelComposite(composite: Composite): TopLevel => composite "\n"; +op topLevelProcedure(procedure: Procedure): TopLevel => procedure "\n"; +op topLevelConstrainedType(ct: ConstrainedType): TopLevel => ct "\n"; op program (items: Seq TopLevel): Command => items; \ No newline at end of file diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index 4bf9803c51..97a92c5353 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -21,8 +21,8 @@ structure AnalysisResult where readsHeapDirectly : Bool := false callees : List Identifier := [] -partial def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do - match expr with +partial def collectExpr (expr : StmtExprMd) : StateM AnalysisResult Unit := do + match expr.val with | .FieldSelect target _ => modify fun s => { s with readsHeapDirectly := true }; collectExpr target | .InstanceCall target _ args => modify fun s => { s with readsHeapDirectly := true }; collectExpr target; for a in args do collectExpr a @@ -30,9 +30,9 @@ partial def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .IfThenElse c t e => collectExpr c; collectExpr t; if let some x := e then collectExpr x | .Block stmts _ => for s in stmts do collectExpr s | .LocalVariable _ _ i => if let some x := i then collectExpr x - | .While c i d b => collectExpr c; collectExpr b; if let some x := i then collectExpr x; if let some x := d then collectExpr x + | .While c invs d b => collectExpr c; collectExpr b; for i in invs do collectExpr i; if let some x := d then collectExpr x | .Return v => if let some x := v then collectExpr x - | .Assign t v _ => collectExpr t; collectExpr v + | .Assign t v => collectExpr t; collectExpr v | .PureFieldUpdate t _ v => collectExpr t; collectExpr v | .PrimitiveOp _ args => for a in args do collectExpr a | .ReferenceEquals l r => collectExpr l; collectExpr r @@ -43,8 +43,8 @@ partial def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .Assigned n => collectExpr n | .Old v => collectExpr v | .Fresh v => collectExpr v - | .Assert c _ => collectExpr c - | .Assume c _ => collectExpr c + | .Assert c => collectExpr c + | .Assume c => collectExpr c | .ProveBy v p => collectExpr v; collectExpr p | .ContractOf _ f => collectExpr f | _ => pure () @@ -78,60 +78,69 @@ abbrev TransformM := StateM TransformState def addFieldConstant (name : Identifier) : TransformM Unit := modify fun s => if s.fieldConstants.any (·.name == name) then s - else { s with fieldConstants := { name := name, type := .TField } :: s.fieldConstants } + else { s with fieldConstants := { name := name, type := ⟨.TField, #[]⟩ } :: s.fieldConstants } def readsHeap (name : Identifier) : TransformM Bool := do return (← get).heapReaders.contains name -partial def heapTransformExpr (heap : Identifier) (expr : StmtExpr) : TransformM StmtExpr := do - match expr with +/-- Helper to create a StmtExprMd with the same metadata as the input -/ +def mkStmtExprMdFrom (orig : StmtExprMd) (e : StmtExpr) : StmtExprMd := ⟨e, orig.md⟩ + +/-- Helper to create a StmtExprMd with empty metadata -/ +def mkStmtExprMdEmpty (e : StmtExpr) : StmtExprMd := ⟨e, #[]⟩ + +partial def heapTransformExpr (heap : Identifier) (expr : StmtExprMd) : TransformM StmtExprMd := do + let md := expr.md + match expr.val with | .FieldSelect target fieldName => addFieldConstant fieldName let t ← heapTransformExpr heap target - return .StaticCall "heapRead" [.Identifier heap, t, .Identifier fieldName] + return ⟨.StaticCall "heapRead" [mkStmtExprMdEmpty (.Identifier heap), t, mkStmtExprMdEmpty (.Identifier fieldName)], md⟩ | .StaticCall callee args => let args' ← args.mapM (heapTransformExpr heap) - return if ← readsHeap callee then .StaticCall callee (.Identifier heap :: args') else .StaticCall callee args' + return if ← readsHeap callee + then ⟨.StaticCall callee (mkStmtExprMdEmpty (.Identifier heap) :: args'), md⟩ + else ⟨.StaticCall callee args', md⟩ | .InstanceCall target callee args => let t ← heapTransformExpr heap target let args' ← args.mapM (heapTransformExpr heap) - return .InstanceCall t callee (.Identifier heap :: args') - | .IfThenElse c t e => return .IfThenElse (← heapTransformExpr heap c) (← heapTransformExpr heap t) (← e.mapM (heapTransformExpr heap)) - | .Block stmts label => return .Block (← stmts.mapM (heapTransformExpr heap)) label - | .LocalVariable n ty i => return .LocalVariable n ty (← i.mapM (heapTransformExpr heap)) - | .While c i d b => return .While (← heapTransformExpr heap c) (← i.mapM (heapTransformExpr heap)) (← d.mapM (heapTransformExpr heap)) (← heapTransformExpr heap b) - | .Return v => return .Return (← v.mapM (heapTransformExpr heap)) - | .Assign t v md => - match t with + return ⟨.InstanceCall t callee (mkStmtExprMdEmpty (.Identifier heap) :: args'), md⟩ + | .IfThenElse c t e => return ⟨.IfThenElse (← heapTransformExpr heap c) (← heapTransformExpr heap t) (← e.mapM (heapTransformExpr heap)), md⟩ + | .Block stmts label => return ⟨.Block (← stmts.mapM (heapTransformExpr heap)) label, md⟩ + | .LocalVariable n ty i => return ⟨.LocalVariable n ty (← i.mapM (heapTransformExpr heap)), md⟩ + | .While c invs d b => return ⟨.While (← heapTransformExpr heap c) (← invs.mapM (heapTransformExpr heap)) (← d.mapM (heapTransformExpr heap)) (← heapTransformExpr heap b), md⟩ + | .Return v => return ⟨.Return (← v.mapM (heapTransformExpr heap)), md⟩ + | .Assign t v => + match t.val with | .FieldSelect target fieldName => addFieldConstant fieldName let target' ← heapTransformExpr heap target let v' ← heapTransformExpr heap v -- heap := heapStore(heap, target, field, value) - return .Assign (.Identifier heap) (.StaticCall "heapStore" [.Identifier heap, target', .Identifier fieldName, v']) md - | _ => return .Assign (← heapTransformExpr heap t) (← heapTransformExpr heap v) md - | .PureFieldUpdate t f v => return .PureFieldUpdate (← heapTransformExpr heap t) f (← heapTransformExpr heap v) - | .PrimitiveOp op args => return .PrimitiveOp op (← args.mapM (heapTransformExpr heap)) - | .ReferenceEquals l r => return .ReferenceEquals (← heapTransformExpr heap l) (← heapTransformExpr heap r) - | .AsType t ty => return .AsType (← heapTransformExpr heap t) ty - | .IsType t ty => return .IsType (← heapTransformExpr heap t) ty - | .Forall n ty b => return .Forall n ty (← heapTransformExpr heap b) - | .Exists n ty b => return .Exists n ty (← heapTransformExpr heap b) - | .Assigned n => return .Assigned (← heapTransformExpr heap n) - | .Old v => return .Old (← heapTransformExpr heap v) - | .Fresh v => return .Fresh (← heapTransformExpr heap v) - | .Assert c md => return .Assert (← heapTransformExpr heap c) md - | .Assume c md => return .Assume (← heapTransformExpr heap c) md - | .ProveBy v p => return .ProveBy (← heapTransformExpr heap v) (← heapTransformExpr heap p) - | .ContractOf ty f => return .ContractOf ty (← heapTransformExpr heap f) - | other => return other + return ⟨.Assign (mkStmtExprMdEmpty (.Identifier heap)) (⟨.StaticCall "heapStore" [mkStmtExprMdEmpty (.Identifier heap), target', mkStmtExprMdEmpty (.Identifier fieldName), v'], md⟩), md⟩ + | _ => return ⟨.Assign (← heapTransformExpr heap t) (← heapTransformExpr heap v), md⟩ + | .PureFieldUpdate t f v => return ⟨.PureFieldUpdate (← heapTransformExpr heap t) f (← heapTransformExpr heap v), md⟩ + | .PrimitiveOp op args => return ⟨.PrimitiveOp op (← args.mapM (heapTransformExpr heap)), md⟩ + | .ReferenceEquals l r => return ⟨.ReferenceEquals (← heapTransformExpr heap l) (← heapTransformExpr heap r), md⟩ + | .AsType t ty => return ⟨.AsType (← heapTransformExpr heap t) ty, md⟩ + | .IsType t ty => return ⟨.IsType (← heapTransformExpr heap t) ty, md⟩ + | .Forall n ty b => return ⟨.Forall n ty (← heapTransformExpr heap b), md⟩ + | .Exists n ty b => return ⟨.Exists n ty (← heapTransformExpr heap b), md⟩ + | .Assigned n => return ⟨.Assigned (← heapTransformExpr heap n), md⟩ + | .Old v => return ⟨.Old (← heapTransformExpr heap v), md⟩ + | .Fresh v => return ⟨.Fresh (← heapTransformExpr heap v), md⟩ + | .Assert c => return ⟨.Assert (← heapTransformExpr heap c), md⟩ + | .Assume c => return ⟨.Assume (← heapTransformExpr heap c), md⟩ + | .ProveBy v p => return ⟨.ProveBy (← heapTransformExpr heap v) (← heapTransformExpr heap p), md⟩ + | .ContractOf ty f => return ⟨.ContractOf ty (← heapTransformExpr heap f), md⟩ + | other => return ⟨other, md⟩ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do if (← get).heapReaders.contains proc.name then match proc.body with | .Transparent bodyExpr => let body' ← heapTransformExpr "heap" bodyExpr - return { proc with inputs := { name := "heap", type := .THeap } :: proc.inputs, body := .Transparent body' } + return { proc with inputs := { name := "heap", type := ⟨.THeap, #[]⟩ } :: proc.inputs, body := .Transparent body' } | _ => return proc else return proc diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index 74abb64520..b030d7558c 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -52,9 +52,9 @@ inductive Operation: Type where /- Works on Bool -/ /- Equality on composite types uses reference equality for impure types, and structural equality for pure ones -/ | Eq | Neq - | And | Or | Not + | And | Or | Not | Implies /- Works on Int/Float64 -/ - | Neg | Add | Sub | Mul | Div | Mod + | Neg | Add | Sub | Mul | Div | Mod | DivT | ModT | Lt | Leq | Gt | Geq deriving Repr @@ -62,21 +62,33 @@ inductive Operation: Type where instance : Repr (Imperative.MetaData Core.Expression) := inferInstance mutual +/-- A wrapper that adds metadata to any type -/ +structure HighTypeMd where + val : HighType + md : Imperative.MetaData Core.Expression + deriving Repr + +/-- A wrapper that adds metadata to any type -/ +structure StmtExprMd where + val : StmtExpr + md : Imperative.MetaData Core.Expression + deriving Repr + structure Procedure: Type where name : Identifier inputs : List Parameter outputs : List Parameter - precondition : StmtExpr - decreases : Option StmtExpr -- optionally prove termination + preconditions : List StmtExprMd + decreases : Option StmtExprMd -- optionally prove termination body : Body inductive Determinism where - | deterministic (reads: Option StmtExpr) + | deterministic (reads: Option StmtExprMd) | nondeterministic structure Parameter where name : Identifier - type : HighType + type : HighTypeMd inductive HighType : Type where | TVoid @@ -86,22 +98,22 @@ inductive HighType : Type where | THeap /- Internal type for heap parameterization pass. Not accessible via grammar. -/ | TField /- Internal type for field constants in heap parameterization pass. Not accessible via grammar. -/ | UserDefined (name: Identifier) - | Applied (base : HighType) (typeArguments : List HighType) + | Applied (base : HighTypeMd) (typeArguments : List HighTypeMd) /- Pure represents a composite type that does not support reference equality -/ - | Pure(base: HighType) + | Pure(base: HighTypeMd) /- Java has implicit intersection types. Example: ` ? RustanLeino : AndersHejlsberg` could be typed as `Scientist & Scandinavian`-/ - | Intersection (types : List HighType) + | Intersection (types : List HighTypeMd) deriving Repr /- No support for something like function-by-method yet -/ inductive Body where - | Transparent (body : StmtExpr) + | Transparent (body : StmtExprMd) /- Without an implementation, the postcondition is assumed -/ - | Opaque (postcondition : StmtExpr) (implementation : Option StmtExpr) (determinism: Determinism) (modifies : Option StmtExpr) + | Opaque (postconditions : List StmtExprMd) (implementation : Option StmtExprMd) (determinism: Determinism) (modifies : Option StmtExprMd) /- An abstract body is useful for types that are extending. A type containing any members with abstract bodies can not be instantiated. -/ - | Abstract (postcondition : StmtExpr) + | Abstract (postconditions : List StmtExprMd) /- A StmtExpr contains both constructs that we typically find in statements and those in expressions. @@ -116,46 +128,46 @@ for example in `Option (StmtExpr isPure)` -/ inductive StmtExpr : Type where /- Statement like -/ - | IfThenElse (cond : StmtExpr) (thenBranch : StmtExpr) (elseBranch : Option StmtExpr) - | Block (statements : List StmtExpr) (label : Option Identifier) + | IfThenElse (cond : StmtExprMd) (thenBranch : StmtExprMd) (elseBranch : Option StmtExprMd) + | Block (statements : List StmtExprMd) (label : Option Identifier) /- The initializer must be set if this StmtExpr is pure -/ - | LocalVariable (name : Identifier) (type : HighType) (initializer : Option StmtExpr) + | LocalVariable (name : Identifier) (type : HighTypeMd) (initializer : Option StmtExprMd) /- While is only allowed in an impure context - The invariant and decreases are always pure + The invariants and decreases are always pure -/ - | While (cond : StmtExpr) (invariant : Option StmtExpr) (decreases: Option StmtExpr) (body : StmtExpr) + | While (cond : StmtExprMd) (invariants : List StmtExprMd) (decreases: Option StmtExprMd) (body : StmtExprMd) | Exit (target: Identifier) - | Return (value : Option StmtExpr) + | Return (value : Option StmtExprMd) /- Expression like -/ | LiteralInt (value: Int) | LiteralBool (value: Bool) | Identifier (name : Identifier) /- Assign is only allowed in an impure context -/ - | Assign (target : StmtExpr) (value : StmtExpr) (md : Imperative.MetaData Core.Expression) + | Assign (target : StmtExprMd) (value : StmtExprMd) /- Used by itself for fields reads and in combination with Assign for field writes -/ - | FieldSelect (target : StmtExpr) (fieldName : Identifier) + | FieldSelect (target : StmtExprMd) (fieldName : Identifier) /- PureFieldUpdate is the only way to assign values to fields of pure types -/ - | PureFieldUpdate (target : StmtExpr) (fieldName : Identifier) (newValue : StmtExpr) - | StaticCall (callee : Identifier) (arguments : List StmtExpr) - | PrimitiveOp (operator: Operation) (arguments : List StmtExpr) + | PureFieldUpdate (target : StmtExprMd) (fieldName : Identifier) (newValue : StmtExprMd) + | StaticCall (callee : Identifier) (arguments : List StmtExprMd) + | PrimitiveOp (operator: Operation) (arguments : List StmtExprMd) /- Instance related -/ | This - | ReferenceEquals (lhs: StmtExpr) (rhs: StmtExpr) - | AsType (target: StmtExpr) (targetType: HighType) - | IsType (target : StmtExpr) (type: HighType) - | InstanceCall (target : StmtExpr) (callee : Identifier) (arguments : List StmtExpr) + | ReferenceEquals (lhs: StmtExprMd) (rhs: StmtExprMd) + | AsType (target: StmtExprMd) (targetType: HighTypeMd) + | IsType (target : StmtExprMd) (type: HighTypeMd) + | InstanceCall (target : StmtExprMd) (callee : Identifier) (arguments : List StmtExprMd) /- Verification specific -/ - | Forall (name: Identifier) (type: HighType) (body: StmtExpr) - | Exists (name: Identifier) (type: HighType) (body: StmtExpr) - | Assigned (name : StmtExpr) - | Old (value : StmtExpr) + | Forall (name: Identifier) (type: HighTypeMd) (body: StmtExprMd) + | Exists (name: Identifier) (type: HighTypeMd) (body: StmtExprMd) + | Assigned (name : StmtExprMd) + | Old (value : StmtExprMd) /- Fresh may only target impure composite types -/ - | Fresh(value : StmtExpr) + | Fresh(value : StmtExprMd) /- Related to proofs -/ - | Assert (condition: StmtExpr) (md : Imperative.MetaData Core.Expression) - | Assume (condition: StmtExpr) (md : Imperative.MetaData Core.Expression) + | Assert (condition: StmtExprMd) + | Assume (condition: StmtExprMd) /- ProveBy allows writing proof trees. Its semantics are the same as that of the given `value`, but the `proof` is used to help prove any assertions in `value`. @@ -168,10 +180,10 @@ ProveBy( ) ) -/ - | ProveBy (value: StmtExpr) (proof: StmtExpr) + | ProveBy (value: StmtExprMd) (proof: StmtExprMd) -- ContractOf allows extracting the contract of a function - | ContractOf (type: ContractType) (function: StmtExpr) + | ContractOf (type: ContractType) (function: StmtExprMd) /- Abstract can be used as the root expr in a contract for reads/modifies/precondition/postcondition. For example: `reads(abstract)` It can only be used for instance procedures and it makes the containing type abstract, meaning it can not be instantiated. @@ -189,7 +201,7 @@ end instance : Inhabited StmtExpr where default := .Hole -def highEq (a: HighType) (b: HighType) : Bool := match a, b with +partial def highEq (a: HighTypeMd) (b: HighTypeMd) : Bool := match a.val, b.val with | HighType.TVoid, HighType.TVoid => true | HighType.TBool, HighType.TBool => true | HighType.TInt, HighType.TInt => true @@ -198,27 +210,25 @@ def highEq (a: HighType) (b: HighType) : Bool := match a, b with | HighType.TField, HighType.TField => true | HighType.UserDefined n1, HighType.UserDefined n2 => n1 == n2 | HighType.Applied b1 args1, HighType.Applied b2 args2 => - highEq b1 b2 && args1.length == args2.length && (args1.attach.zip args2 |>.all (fun (a1, a2) => highEq a1.1 a2)) + highEq b1 b2 && args1.length == args2.length && (args1.zip args2 |>.all (fun (a1, a2) => highEq a1 a2)) + | HighType.Pure b1, HighType.Pure b2 => highEq b1 b2 | HighType.Intersection ts1, HighType.Intersection ts2 => - ts1.length == ts2.length && (ts1.attach.zip ts2 |>.all (fun (t1, t2) => highEq t1.1 t2)) + ts1.length == ts2.length && (ts1.zip ts2 |>.all (fun (t1, t2) => highEq t1 t2)) | _, _ => false - termination_by (SizeOf.sizeOf a) - decreasing_by - all_goals(simp_wf; try omega) - . cases a1; simp; rename_i hin; have := List.sizeOf_lt_of_mem hin; omega - . cases t1; simp; rename_i hin; have := List.sizeOf_lt_of_mem hin; omega -instance : BEq HighType where +instance : BEq HighTypeMd where beq := highEq def HighType.isBool : HighType → Bool | TBool => true | _ => false +def HighTypeMd.isBool (t : HighTypeMd) : Bool := t.val.isBool + structure Field where name : Identifier isMutable : Bool - type : HighType + type : HighTypeMd structure CompositeType where name : Identifier @@ -232,10 +242,10 @@ structure CompositeType where structure ConstrainedType where name : Identifier - base : HighType + base : HighTypeMd valueName : Identifier - constraint : StmtExpr - witness : StmtExpr + constraint : StmtExprMd + witness : StmtExprMd /- Note that there are no explicit 'inductive datatypes'. Typed unions are created by @@ -255,7 +265,7 @@ inductive TypeDefinition where structure Constant where name : Identifier - type : HighType + type : HighTypeMd structure Program where staticProcedures : List Procedure diff --git a/Strata/Languages/Laurel/LaurelEval.lean b/Strata/Languages/Laurel/LaurelEval.lean index fd81fc67d9..6ebd199cdc 100644 --- a/Strata/Languages/Laurel/LaurelEval.lean +++ b/Strata/Languages/Laurel/LaurelEval.lean @@ -209,8 +209,9 @@ partial def eval (expr : StmtExpr) : Eval TypedValue := else setLocal param.name arg ) - let precondition ← eval callable.precondition - assertBool precondition + for precondition in callable.preconditions do + let precondResult ← eval precondition + assertBool precondResult -- TODO, handle decreases let result: TypedValue ← match callable.body with @@ -246,9 +247,9 @@ partial def eval (expr : StmtExpr) : Eval TypedValue := let tv ← eval valExpr withResult (EvalResult.Return tv.val) | StmtExpr.Return none => fun env => (EvalResult.Success { val := Value.VVoid, ty := env.returnType }, env) - | StmtExpr.While _ none _ _ => withResult <| EvalResult.TypeError "While invariant was not derived" + | StmtExpr.While _ [] _ _ => withResult <| EvalResult.TypeError "While invariant was not derived" | StmtExpr.While _ _ none _ => withResult <| EvalResult.TypeError "While decreases was not derived" - | StmtExpr.While condExpr (some invariantExpr) (some decreasedExpr) bodyExpr => do + | StmtExpr.While condExpr (invariantExpr :: _) (some decreasedExpr) bodyExpr => do let rec loop : Eval TypedValue := do let cond ← eval condExpr if (cond.ty.isBool) then diff --git a/Strata/Languages/Laurel/LaurelFormat.lean b/Strata/Languages/Laurel/LaurelFormat.lean index 7b3628d5d4..c887796542 100644 --- a/Strata/Languages/Laurel/LaurelFormat.lean +++ b/Strata/Languages/Laurel/LaurelFormat.lean @@ -11,12 +11,12 @@ namespace Laurel open Std (Format) -mutual def formatOperation : Operation → Format | .Eq => "==" | .Neq => "!=" | .And => "&&" | .Or => "||" + | .Implies => "==>" | .Not => "!" | .Neg => "-" | .Add => "+" @@ -24,12 +24,17 @@ def formatOperation : Operation → Format | .Mul => "*" | .Div => "/" | .Mod => "%" + | .DivT => "/t" + | .ModT => "%t" | .Lt => "<" | .Leq => "<=" | .Gt => ">" | .Geq => ">=" -def formatHighType : HighType → Format +mutual +partial def formatHighType (t : HighTypeMd) : Format := formatHighTypeVal t.val + +partial def formatHighTypeVal : HighType → Format | .TVoid => "void" | .TBool => "bool" | .TInt => "int" @@ -44,8 +49,10 @@ def formatHighType : HighType → Format | .Intersection types => Format.joinSep (types.map formatHighType) " & " -def formatStmtExpr (s:StmtExpr) : Format := - match h: s with +partial def formatStmtExpr (s : StmtExprMd) : Format := formatStmtExprVal s.val + +partial def formatStmtExprVal (s:StmtExpr) : Format := + match s with | .IfThenElse cond thenBr elseBr => "if " ++ formatStmtExpr cond ++ " then " ++ formatStmtExpr thenBr ++ match elseBr with @@ -58,8 +65,10 @@ def formatStmtExpr (s:StmtExpr) : Format := match init with | none => "" | some e => " := " ++ formatStmtExpr e - | .While cond _ _ body => - "while " ++ formatStmtExpr cond ++ " " ++ formatStmtExpr body + | .While cond invs _ body => + "while " ++ formatStmtExpr cond ++ + (if invs.isEmpty then Format.nil else " invariant " ++ Format.joinSep (invs.map formatStmtExpr) "; ") ++ + " " ++ formatStmtExpr body | .Exit target => "exit " ++ Format.text target | .Return value => "return" ++ @@ -69,7 +78,7 @@ def formatStmtExpr (s:StmtExpr) : Format := | .LiteralInt n => Format.text (toString n) | .LiteralBool b => if b then "true" else "false" | .Identifier name => Format.text name - | .Assign target value _ => + | .Assign target value => formatStmtExpr target ++ " := " ++ formatStmtExpr value | .FieldSelect target field => formatStmtExpr target ++ "." ++ Format.text field @@ -99,65 +108,61 @@ def formatStmtExpr (s:StmtExpr) : Format := | .Assigned name => "assigned(" ++ formatStmtExpr name ++ ")" | .Old value => "old(" ++ formatStmtExpr value ++ ")" | .Fresh value => "fresh(" ++ formatStmtExpr value ++ ")" - | .Assert cond _ => "assert " ++ formatStmtExpr cond - | .Assume cond _ => "assume " ++ formatStmtExpr cond + | .Assert cond => "assert " ++ formatStmtExpr cond + | .Assume cond => "assume " ++ formatStmtExpr cond | .ProveBy value proof => "proveBy(" ++ formatStmtExpr value ++ ", " ++ formatStmtExpr proof ++ ")" | .ContractOf _ fn => "contractOf(" ++ formatStmtExpr fn ++ ")" | .Abstract => "abstract" | .All => "all" | .Hole => "" - decreasing_by - all_goals (simp_wf; try omega) - any_goals (rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega) - subst_vars; cases h; rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega -def formatParameter (p : Parameter) : Format := +partial def formatParameter (p : Parameter) : Format := Format.text p.name ++ ": " ++ formatHighType p.type -def formatDeterminism : Determinism → Format +partial def formatDeterminism : Determinism → Format | .deterministic none => "deterministic" | .deterministic (some reads) => "deterministic reads " ++ formatStmtExpr reads | .nondeterministic => "nondeterministic" -def formatBody : Body → Format +partial def formatBody : Body → Format | .Transparent body => formatStmtExpr body - | .Opaque post impl determ modif => + | .Opaque posts impl determ modif => "opaque " ++ formatDeterminism determ ++ (match modif with | none => "" | some m => " modifies " ++ formatStmtExpr m) ++ - " ensures " ++ formatStmtExpr post ++ + Format.join (posts.map (fun p => " ensures " ++ formatStmtExpr p)) ++ match impl with | none => "" | some e => " := " ++ formatStmtExpr e - | .Abstract post => "abstract ensures " ++ formatStmtExpr post + | .Abstract posts => "abstract" ++ Format.join (posts.map (fun p => " ensures " ++ formatStmtExpr p)) -def formatProcedure (proc : Procedure) : Format := +partial def formatProcedure (proc : Procedure) : Format := "procedure " ++ Format.text proc.name ++ "(" ++ Format.joinSep (proc.inputs.map formatParameter) ", " ++ ") returns " ++ Format.line ++ "(" ++ Format.joinSep (proc.outputs.map formatParameter) ", " ++ ")" ++ Format.line ++ formatBody proc.body -def formatField (f : Field) : Format := +partial def formatField (f : Field) : Format := (if f.isMutable then "var " else "val ") ++ Format.text f.name ++ ": " ++ formatHighType f.type -def formatCompositeType (ct : CompositeType) : Format := +partial def formatCompositeType (ct : CompositeType) : Format := "composite " ++ Format.text ct.name ++ (if ct.extending.isEmpty then Format.nil else " extends " ++ Format.joinSep (ct.extending.map Format.text) ", ") ++ " { " ++ Format.joinSep (ct.fields.map formatField) "; " ++ " }" -def formatConstrainedType (ct : ConstrainedType) : Format := +partial def formatConstrainedType (ct : ConstrainedType) : Format := "constrained " ++ Format.text ct.name ++ " = " ++ Format.text ct.valueName ++ ": " ++ formatHighType ct.base ++ " | " ++ formatStmtExpr ct.constraint -def formatTypeDefinition : TypeDefinition → Format +partial def formatTypeDefinition : TypeDefinition → Format | .Composite ty => formatCompositeType ty | .Constrained ty => formatConstrainedType ty -def formatProgram (prog : Program) : Format := +partial def formatProgram (prog : Program) : Format := Format.joinSep (prog.staticProcedures.map formatProcedure) "\n\n" end @@ -165,12 +170,18 @@ end instance : Std.ToFormat Operation where format := formatOperation -instance : Std.ToFormat HighType where +instance : Std.ToFormat HighTypeMd where format := formatHighType -instance : Std.ToFormat StmtExpr where +instance : Std.ToFormat HighType where + format := formatHighTypeVal + +instance : Std.ToFormat StmtExprMd where format := formatStmtExpr +instance : Std.ToFormat StmtExpr where + format := formatStmtExprVal + instance : Std.ToFormat Parameter where format := formatParameter diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 2fb17b3e1c..5807620150 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -25,214 +25,600 @@ namespace Strata.Laurel open Strata open Lambda (LMonoTy LTy LExpr) +def boolImpliesOp : Core.Expression.Expr := + .op () (Core.CoreIdent.unres "Bool.Implies") (some (LMonoTy.arrow LMonoTy.bool (LMonoTy.arrow LMonoTy.bool LMonoTy.bool))) + +def intDivTOp : Core.Expression.Expr := + .op () (Core.CoreIdent.unres "Int.DivT") (some (LMonoTy.arrow LMonoTy.int (LMonoTy.arrow LMonoTy.int LMonoTy.int))) + +def intModTOp : Core.Expression.Expr := + .op () (Core.CoreIdent.unres "Int.ModT") (some (LMonoTy.arrow LMonoTy.int (LMonoTy.arrow LMonoTy.int LMonoTy.int))) + +/-- Map from constrained type name to its definition -/ +abbrev ConstrainedTypeMap := Std.HashMap Identifier ConstrainedType + +/-- Pre-translated constraint: base type and Core expression with free variable for the value -/ +structure TranslatedConstraint where + base : HighType + valueName : Identifier + /-- Core expression for constraint, with valueName as free variable -/ + coreConstraint : Core.Expression.Expr + +/-- Map from constrained type name to pre-translated constraint -/ +abbrev TranslatedConstraintMap := Std.HashMap Identifier TranslatedConstraint + +/-- Map from function name to its type (for user-defined pure functions) -/ +abbrev FunctionTypeMap := Std.HashMap Identifier LMonoTy + +/-- Build a map of constrained types from a program -/ +def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap := + types.foldl (init := {}) fun m td => + match td with + | .Constrained ct => m.insert ct.name ct + | _ => m + +/-- Get the base type for a type, resolving constrained types -/ +partial def resolveBaseType (ctMap : ConstrainedTypeMap) (ty : HighType) : HighType := + match ty with + | .UserDefined name => + match ctMap.get? name with + | some ct => resolveBaseType ctMap ct.base.val + | none => ty + | .Applied ctor args => + .Applied ctor (args.map fun arg => ⟨resolveBaseType ctMap arg.val, arg.md⟩) + | _ => ty + /- Translate Laurel HighType to Core Type -/ -def translateType (ty : HighType) : LMonoTy := +partial def translateType (ty : HighType) : LMonoTy := match ty with | .TInt => LMonoTy.int | .TBool => LMonoTy.bool | .TVoid => LMonoTy.bool | .THeap => .tcons "Heap" [] - | .TField => .tcons "Field" [LMonoTy.int] -- For now, all fields hold int + | .TField => .tcons "Field" [LMonoTy.int] + | .Applied ctor [elemTy] => + match ctor.val with + | .UserDefined "Array" => .tcons "Array" [translateType elemTy.val] + | _ => panic s!"unsupported applied type {repr ty}" | .UserDefined _ => .tcons "Composite" [] | _ => panic s!"unsupported type {repr ty}" -abbrev TypeEnv := List (Identifier × HighType) +/-- Translate type, resolving constrained types to their base type recursively -/ +partial def translateTypeWithCT (ctMap : ConstrainedTypeMap) (ty : HighType) : LMonoTy := + match ty with + | .Applied ctor [elemTy] => + match ctor.val with + | .UserDefined "Array" => .tcons "Array" [translateTypeWithCT ctMap elemTy.val] + | _ => translateType (resolveBaseType ctMap ty) + | _ => translateType (resolveBaseType ctMap ty) + +/-- Translate HighTypeMd, extracting the value -/ +def translateTypeMdWithCT (ctMap : ConstrainedTypeMap) (ty : HighTypeMd) : LMonoTy := + translateTypeWithCT ctMap ty.val + +/-- Get the function type for a procedure (input types → output type) -/ +def getProcedureFunctionType (ctMap : ConstrainedTypeMap) (proc : Procedure) : LMonoTy := + let inputTypes := proc.inputs.flatMap fun p => + match p.type.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => [translateTypeMdWithCT ctMap p.type, LMonoTy.int] + | _ => [translateTypeMdWithCT ctMap p.type] + | _ => [translateTypeMdWithCT ctMap p.type] + let outputType := match proc.outputs.head? with + | some p => translateTypeMdWithCT ctMap p.type + | none => LMonoTy.bool -- default for void functions + LMonoTy.mkArrow' outputType inputTypes + +/-- Build a map from function names to their types -/ +def buildFunctionTypeMap (ctMap : ConstrainedTypeMap) (procs : List Procedure) : FunctionTypeMap := + procs.foldl (init := {}) fun m proc => + m.insert proc.name (getProcedureFunctionType ctMap proc) -def lookupType (env : TypeEnv) (name : Identifier) : LMonoTy := +abbrev TypeEnv := List (Identifier × HighTypeMd) + +def lookupType (ctMap : ConstrainedTypeMap) (env : TypeEnv) (name : Identifier) : Except String LMonoTy := match env.find? (fun (n, _) => n == name) with - | some (_, ty) => translateType ty - | none => LMonoTy.int -- fallback + | some (_, ty) => pure (translateTypeMdWithCT ctMap ty) + | none => throw s!"Unknown identifier: {name}" + +/-- Sequence bounds: array with start (inclusive) and end (exclusive) indices -/ +structure SeqBounds where + arr : Core.Expression.Expr -- the underlying array + start : Core.Expression.Expr -- start index (inclusive) + «end» : Core.Expression.Expr -- end index (exclusive) +deriving Inhabited + +/-- Expand array argument to include length parameter -/ +def expandArrayArgs (env : TypeEnv) (args : List StmtExprMd) (translatedArgs : List Core.Expression.Expr) : List Core.Expression.Expr := + (args.zip translatedArgs).flatMap fun (arg, translated) => + match arg.val with + | .Identifier arrName => + match env.find? (fun (n, _) => n == arrName) with + | some (_, ty) => + match ty.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => [translated, .fvar () (Core.CoreIdent.locl (arrName ++ "_len")) (some LMonoTy.int)] + | _ => [translated] + | _ => [translated] + | _ => [translated] + | _ => [translated] + +/-- Translate a binary operation to Core -/ +def translateBinOp (op : Operation) (e1 e2 : Core.Expression.Expr) : Except String Core.Expression.Expr := + let binOp (bop : Core.Expression.Expr) := LExpr.mkApp () bop [e1, e2] + match op with + | .Eq => pure (.eq () e1 e2) + | .Neq => pure (.app () boolNotOp (.eq () e1 e2)) + | .And => pure (binOp boolAndOp) | .Or => pure (binOp boolOrOp) + | .Implies => pure (binOp boolImpliesOp) + | .Add => pure (binOp intAddOp) | .Sub => pure (binOp intSubOp) | .Mul => pure (binOp intMulOp) + | .Div => pure (binOp intDivOp) | .Mod => pure (binOp intModOp) + | .DivT => pure (binOp intDivTOp) | .ModT => pure (binOp intModTOp) + | .Lt => pure (binOp intLtOp) | .Leq => pure (binOp intLeOp) | .Gt => pure (binOp intGtOp) | .Geq => pure (binOp intGeOp) + | _ => throw s!"translateBinOp: unsupported {repr op}" + +/-- Translate a unary operation to Core -/ +def translateUnaryOp (op : Operation) (e : Core.Expression.Expr) : Except String Core.Expression.Expr := + match op with + | .Not => pure (.app () boolNotOp e) + | .Neg => pure (.app () intNegOp e) + | _ => throw s!"translateUnaryOp: unsupported {repr op}" + +/-- Translate simple expressions (for constraints - no quantifiers) -/ +partial def translateSimpleExpr (ctMap : ConstrainedTypeMap) (env : TypeEnv) (expr : StmtExprMd) : Except String Core.Expression.Expr := + match expr.val with + | .LiteralBool b => pure (.const () (.boolConst b)) + | .LiteralInt i => pure (.const () (.intConst i)) + | .Identifier name => do + let ty ← lookupType ctMap env name + pure (.fvar () (Core.CoreIdent.locl name) (some ty)) + | .PrimitiveOp op [e] => do + let e' ← translateSimpleExpr ctMap env e + translateUnaryOp op e' + | .PrimitiveOp op [e1, e2] => do + let e1' ← translateSimpleExpr ctMap env e1 + let e2' ← translateSimpleExpr ctMap env e2 + translateBinOp op e1' e2' + | .Forall _ _ _ => throw "Quantifiers not supported in constrained type constraints" + | .Exists _ _ _ => throw "Quantifiers not supported in constrained type constraints" + | _ => throw "Unsupported expression in constrained type constraint" + +/-- Build map of pre-translated constraints -/ +def buildTranslatedConstraintMap (ctMap : ConstrainedTypeMap) : Except String TranslatedConstraintMap := + ctMap.foldM (init := {}) fun m name ct => do + let env : TypeEnv := [(ct.valueName, ct.base)] + let coreExpr ← translateSimpleExpr ctMap env ct.constraint + pure (m.insert name { base := ct.base.val, valueName := ct.valueName, coreConstraint := coreExpr }) + +/-- Close free variable by name, converting fvar to bvar at depth k -/ +def varCloseByName (k : Nat) (x : Core.CoreIdent) (e : Core.Expression.Expr) : Core.Expression.Expr := + match e with + | .const m c => .const m c + | .op m o ty => .op m o ty + | .bvar m i => .bvar m i + | .fvar m y yty => if x == y then .bvar m k else .fvar m y yty + | .abs m ty e' => .abs m ty (varCloseByName (k + 1) x e') + | .quant m qk ty tr e' => .quant m qk ty (varCloseByName (k + 1) x tr) (varCloseByName (k + 1) x e') + | .app m e1 e2 => .app m (varCloseByName k x e1) (varCloseByName k x e2) + | .ite m c t f => .ite m (varCloseByName k x c) (varCloseByName k x t) (varCloseByName k x f) + | .eq m e1 e2 => .eq m (varCloseByName k x e1) (varCloseByName k x e2) + +/-- Translate simple expression (identifier or literal) to Core - for sequence bounds -/ +def translateSimpleBound (expr : StmtExprMd) : Except String Core.Expression.Expr := + match expr.val with + | .Identifier name => pure (.fvar () (Core.CoreIdent.locl name) (some LMonoTy.int)) + | .LiteralInt i => pure (.const () (.intConst i)) + | _ => throw "Expected simple bound expression (identifier or literal)" + +/-- Normalize callee name by removing «» quotes if present -/ +def normalizeCallee (callee : Identifier) : Identifier := + if callee.startsWith "«" && callee.endsWith "»" then + callee.drop 1 |>.dropRight 1 + else + callee + +/-- Extract sequence bounds from Seq.From/Take/Drop chain -/ +partial def translateSeqBounds (env : TypeEnv) (expr : StmtExprMd) : Except String SeqBounds := + match expr.val with + | .StaticCall callee [arr] => + if normalizeCallee callee == "Seq.From" then + match arr.val with + | .Identifier name => + -- Validate that name is an array + match env.find? (fun (n, _) => n == name) with + | some (_, ty) => + match ty.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => + pure { arr := .fvar () (Core.CoreIdent.locl name) none + , start := .const () (.intConst 0) + , «end» := .fvar () (Core.CoreIdent.locl (name ++ "_len")) (some LMonoTy.int) } + | _ => throw s!"Seq.From expects array, got {repr ty}" + | _ => throw s!"Seq.From expects array, got {repr ty}" + | none => throw s!"Unknown identifier in Seq.From: {name}" + | _ => throw "Seq.From on complex expressions not supported" + else + throw s!"Not a sequence expression: {callee}" + | .StaticCall callee [seq, n] => + let norm := normalizeCallee callee + if norm == "Seq.Take" then do + let inner ← translateSeqBounds env seq + let bound ← translateSimpleBound n + pure { inner with «end» := bound } + else if norm == "Seq.Drop" then do + let inner ← translateSeqBounds env seq + let bound ← translateSimpleBound n + pure { inner with start := bound } + else + throw s!"Not a sequence expression: {callee}" + | _ => throw "Not a sequence expression" + +/-- Inject constraint into quantifier body. For forall uses ==>, for exists uses &&. -/ +def injectQuantifierConstraint (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) + (isForall : Bool) (ty : HighTypeMd) (coreIdent : Core.CoreIdent) (closedBody : Core.Expression.Expr) : Core.Expression.Expr := + match ty.val with + | .UserDefined typeName => match tcMap.get? typeName with + | some tc => + let substConstraint := tc.coreConstraint.substFvar (Core.CoreIdent.locl tc.valueName) + (.fvar () coreIdent (some (translateTypeMdWithCT ctMap ty))) + let op := if isForall then boolImpliesOp else boolAndOp + LExpr.mkApp () op [varCloseByName 0 coreIdent substConstraint, closedBody] + | none => closedBody + | _ => closedBody /-- Translate Laurel StmtExpr to Core Expression -/ -def translateExpr (env : TypeEnv) (expr : StmtExpr) : Core.Expression.Expr := - match h: expr with - | .LiteralBool b => .const () (.boolConst b) - | .LiteralInt i => .const () (.intConst i) - | .Identifier name => - let ident := Core.CoreIdent.locl name - .fvar () ident (some (lookupType env name)) - | .PrimitiveOp op [e] => - match op with - | .Not => .app () boolNotOp (translateExpr env e) - | .Neg => .app () intNegOp (translateExpr env e) - | _ => panic! s!"translateExpr: Invalid unary op: {repr op}" - | .PrimitiveOp op [e1, e2] => - let binOp (bop : Core.Expression.Expr): Core.Expression.Expr := - LExpr.mkApp () bop [translateExpr env e1, translateExpr env e2] - match op with - | .Eq => .eq () (translateExpr env e1) (translateExpr env e2) - | .Neq => .app () boolNotOp (.eq () (translateExpr env e1) (translateExpr env e2)) - | .And => binOp boolAndOp - | .Or => binOp boolOrOp - | .Add => binOp intAddOp - | .Sub => binOp intSubOp - | .Mul => binOp intMulOp - | .Div => binOp intDivOp - | .Mod => binOp intModOp - | .Lt => binOp intLtOp - | .Leq => binOp intLeOp - | .Gt => binOp intGtOp - | .Geq => binOp intGeOp - | _ => panic! s!"translateExpr: Invalid binary op: {repr op}" +partial def translateExpr (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (ftMap : FunctionTypeMap) (env : TypeEnv) (expr : StmtExprMd) : Except String Core.Expression.Expr := + match expr.val with + | .LiteralBool b => pure (.const () (.boolConst b)) + | .LiteralInt i => pure (.const () (.intConst i)) + | .Identifier name => do + let ty ← lookupType ctMap env name + pure (.fvar () (Core.CoreIdent.locl name) (some ty)) + | .PrimitiveOp op [e] => do + let e' ← translateExpr ctMap tcMap ftMap env e + translateUnaryOp op e' + | .PrimitiveOp op [e1, e2] => do + let e1' ← translateExpr ctMap tcMap ftMap env e1 + let e2' ← translateExpr ctMap tcMap ftMap env e2 + translateBinOp op e1' e2' | .PrimitiveOp op args => - panic! s!"translateExpr: PrimitiveOp {repr op} with {args.length} args" - | .IfThenElse cond thenBranch elseBranch => - let bcond := translateExpr env cond - let bthen := translateExpr env thenBranch - let belse := match elseBranch with - | some e => translateExpr env e - | none => .const () (.intConst 0) - .ite () bcond bthen belse - | .Assign _ value _ => translateExpr env value - | .StaticCall name args => - let ident := Core.CoreIdent.glob name - let fnOp := .op () ident none - args.foldl (fun acc arg => .app () acc (translateExpr env arg)) fnOp - | .ReferenceEquals e1 e2 => - .eq () (translateExpr env e1) (translateExpr env e2) - | .Block [single] _ => translateExpr env single - | _ => panic! Std.Format.pretty (Std.ToFormat.format expr) - decreasing_by - all_goals (simp_wf; try omega) - rename_i x_in; have := List.sizeOf_lt_of_mem x_in; omega + throw s!"translateExpr: PrimitiveOp {repr op} with {args.length} args" + | .IfThenElse cond thenBranch elseBranch => do + let bcond ← translateExpr ctMap tcMap ftMap env cond + let bthen ← translateExpr ctMap tcMap ftMap env thenBranch + let belse ← match elseBranch with + | some e => translateExpr ctMap tcMap ftMap env e + | none => pure (.const () (.intConst 0)) + pure (.ite () bcond bthen belse) + | .Assign _ value => translateExpr ctMap tcMap ftMap env value + | .StaticCall callee [arg] => + let norm := normalizeCallee callee + if norm == "Array.Length" then + match arg.val with + | .Identifier name => pure (.fvar () (Core.CoreIdent.locl (name ++ "_len")) (some LMonoTy.int)) + | _ => throw "Array.Length on complex expressions not supported" + else do + let calleeOp := LExpr.op () (Core.CoreIdent.glob norm) (ftMap.get? norm) + let translated ← translateExpr ctMap tcMap ftMap env arg + let expandedArgs := expandArrayArgs env [arg] [translated] + pure (expandedArgs.foldl (fun acc a => .app () acc a) calleeOp) + | .StaticCall callee [arg1, arg2] => + let norm := normalizeCallee callee + if norm == "Array.Get" then do + let arrExpr ← translateExpr ctMap tcMap ftMap env arg1 + let idxExpr ← translateExpr ctMap tcMap ftMap env arg2 + let selectOp := LExpr.op () (Core.CoreIdent.unres "select") none + pure (LExpr.mkApp () selectOp [arrExpr, idxExpr]) + else if norm == "Seq.Contains" then do + -- exists i :: start <= i < end && arr[i] == elem + let bounds ← translateSeqBounds env arg1 + let elemExpr ← translateExpr ctMap tcMap ftMap env arg2 + let i := LExpr.bvar () 0 + -- start <= i + let geStart := LExpr.mkApp () intLeOp [bounds.start, i] + -- i < end + let ltEnd := LExpr.mkApp () intLtOp [i, bounds.«end»] + -- arr[i] + let selectOp := LExpr.op () (Core.CoreIdent.unres "select") none + let arrAtI := LExpr.mkApp () selectOp [bounds.arr, i] + -- arr[i] == elem + let eqElem := LExpr.eq () arrAtI elemExpr + -- start <= i && i < end && arr[i] == elem + let body := LExpr.mkApp () boolAndOp [geStart, LExpr.mkApp () boolAndOp [ltEnd, eqElem]] + pure (LExpr.quant () .exist (some LMonoTy.int) (LExpr.noTrigger ()) body) + else do + -- Default: treat as function call with array expansion + let calleeOp := LExpr.op () (Core.CoreIdent.glob norm) (ftMap.get? norm) + let e1 ← translateExpr ctMap tcMap ftMap env arg1 + let e2 ← translateExpr ctMap tcMap ftMap env arg2 + let expandedArgs := expandArrayArgs env [arg1, arg2] [e1, e2] + pure (expandedArgs.foldl (fun acc a => .app () acc a) calleeOp) + | .StaticCall name args => do + let normName := normalizeCallee name + let fnTy := ftMap.get? normName + let fnOp := LExpr.op () (Core.CoreIdent.glob normName) fnTy + let translatedArgs ← args.mapM (translateExpr ctMap tcMap ftMap env) + let expandedArgs := expandArrayArgs env args translatedArgs + pure (expandedArgs.foldl (fun acc a => .app () acc a) fnOp) + | .ReferenceEquals e1 e2 => do + let e1' ← translateExpr ctMap tcMap ftMap env e1 + let e2' ← translateExpr ctMap tcMap ftMap env e2 + pure (.eq () e1' e2') + | .Block [single] _ => translateExpr ctMap tcMap ftMap env single + | .Forall _name ty body => do + let coreType := translateTypeMdWithCT ctMap ty + let env' := (_name, ty) :: env + let bodyExpr ← translateExpr ctMap tcMap ftMap env' body + let coreIdent := Core.CoreIdent.locl _name + let closedBody := varCloseByName 0 coreIdent bodyExpr + let finalBody := injectQuantifierConstraint ctMap tcMap true ty coreIdent closedBody + pure (LExpr.quant () .all (some coreType) (LExpr.noTrigger ()) finalBody) + | .Exists _name ty body => do + let coreType := translateTypeMdWithCT ctMap ty + let env' := (_name, ty) :: env + let bodyExpr ← translateExpr ctMap tcMap ftMap env' body + let coreIdent := Core.CoreIdent.locl _name + let closedBody := varCloseByName 0 coreIdent bodyExpr + let finalBody := injectQuantifierConstraint ctMap tcMap false ty coreIdent closedBody + pure (LExpr.quant () .exist (some coreType) (LExpr.noTrigger ()) finalBody) + | .Return (some e) => translateExpr ctMap tcMap ftMap env e + | _ => throw s!"translateExpr: unsupported {Std.Format.pretty (Std.ToFormat.format expr.val)}" def getNameFromMd (md : Imperative.MetaData Core.Expression): String := let fileRange := (Imperative.getFileRange md).get! s!"({fileRange.range.start})" +def genConstraintCheck (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (param : Parameter) : Option Core.Expression.Expr := + match param.type.val with + | .UserDefined name => + match tcMap.get? name with + | some tc => + let paramIdent := Core.CoreIdent.locl param.name + let valueIdent := Core.CoreIdent.locl tc.valueName + let baseTy := translateTypeMdWithCT ctMap param.type + some (tc.coreConstraint.substFvar valueIdent (.fvar () paramIdent (some baseTy))) + | none => none + | _ => none + +def genConstraintAssert (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (name : Identifier) (ty : HighTypeMd) : List Core.Statement := + match genConstraintCheck ctMap tcMap { name, type := ty } with + | some expr => [Core.Statement.assert s!"{name}_constraint" expr ty.md] + | none => [] + +def defaultExprForType (ctMap : ConstrainedTypeMap) (ty : HighTypeMd) : Except String Core.Expression.Expr := + match resolveBaseType ctMap ty.val with + | .TInt => pure (.const () (.intConst 0)) + | .TBool => pure (.const () (.boolConst false)) + | other => throw s!"No default value for type {repr other}" + +def isHeapFunction (name : Identifier) : Bool := + name == "heapRead" || name == "heapStore" + +/-- Check if a StaticCall should be translated as an expression (not a procedure call) -/ +def isExpressionCall (callee : Identifier) : Bool := + let norm := normalizeCallee callee + isHeapFunction norm || norm.startsWith "Seq." || norm.startsWith "Array." + /-- Translate Laurel StmtExpr to Core Statements -Takes the type environment and output parameter names +Takes the type environment, output parameter names, and postconditions to assert at returns -/ -def translateStmt (env : TypeEnv) (outputParams : List Parameter) (stmt : StmtExpr) : TypeEnv × List Core.Statement := - match stmt with - | @StmtExpr.Assert cond md => - let boogieExpr := translateExpr env cond - (env, [Core.Statement.assert ("assert" ++ getNameFromMd md) boogieExpr md]) - | @StmtExpr.Assume cond md => - let boogieExpr := translateExpr env cond - (env, [Core.Statement.assume ("assume" ++ getNameFromMd md) boogieExpr md]) - | .Block stmts _ => - let (env', stmtsList) := stmts.foldl (fun (e, acc) s => - let (e', ss) := translateStmt e outputParams s - (e', acc ++ ss)) (env, []) - (env', stmtsList) - | .LocalVariable name ty initializer => +partial def translateStmt (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (ftMap : FunctionTypeMap) (env : TypeEnv) (outputParams : List Parameter) (postconds : List (String × Core.Expression.Expr)) (stmt : StmtExprMd) : Except String (TypeEnv × List Core.Statement) := + match stmt.val with + | .Assert cond => do + let boogieExpr ← translateExpr ctMap tcMap ftMap env cond + pure (env, [Core.Statement.assert ("assert" ++ getNameFromMd stmt.md) boogieExpr stmt.md]) + | .Assume cond => do + let boogieExpr ← translateExpr ctMap tcMap ftMap env cond + pure (env, [Core.Statement.assume ("assume" ++ getNameFromMd stmt.md) boogieExpr stmt.md]) + | .Block stmts _ => do + let mut env' := env + let mut stmtsList := [] + for s in stmts do + let (e', ss) ← translateStmt ctMap tcMap ftMap env' outputParams postconds s + env' := e' + stmtsList := stmtsList ++ ss + pure (env', stmtsList) + | .LocalVariable name ty initializer => do let env' := (name, ty) :: env - let boogieMonoType := translateType ty - let boogieType := LTy.forAll [] boogieMonoType + let boogieType := LTy.forAll [] (translateTypeMdWithCT ctMap ty) let ident := Core.CoreIdent.locl name + let constraintCheck := genConstraintAssert ctMap tcMap name ty match initializer with - | some (.StaticCall callee args) => - -- Check if this is a heap function (heapRead/heapStore) or a regular procedure call - -- Heap functions should be translated as expressions, not call statements - if callee == "heapRead" || callee == "heapStore" then - -- Translate as expression (function application) - let boogieExpr := translateExpr env (.StaticCall callee args) - (env', [Core.Statement.init ident boogieType boogieExpr]) - else - -- Translate as: var name; call name := callee(args) - let boogieArgs := args.map (translateExpr env) - let defaultExpr := match ty with - | .TInt => .const () (.intConst 0) - | .TBool => .const () (.boolConst false) - | _ => .const () (.intConst 0) - let initStmt := Core.Statement.init ident boogieType defaultExpr - let callStmt := Core.Statement.call [ident] callee boogieArgs - (env', [initStmt, callStmt]) - | some initExpr => - let boogieExpr := translateExpr env initExpr - (env', [Core.Statement.init ident boogieType boogieExpr]) - | none => - let defaultExpr := match ty with - | .TInt => .const () (.intConst 0) - | .TBool => .const () (.boolConst false) - | _ => .const () (.intConst 0) - (env', [Core.Statement.init ident boogieType defaultExpr]) - | .Assign target value _ => - match target with - | .Identifier name => + | some init => + match init.val with + | .StaticCall callee args => + if isExpressionCall callee then do + let boogieExpr ← translateExpr ctMap tcMap ftMap env init + pure (env', [Core.Statement.init ident boogieType boogieExpr] ++ constraintCheck) + else do + let boogieArgs ← args.mapM (translateExpr ctMap tcMap ftMap env) + let defaultVal ← defaultExprForType ctMap ty + let initStmt := Core.Statement.init ident boogieType defaultVal + let callStmt := Core.Statement.call [ident] callee boogieArgs + pure (env', [initStmt, callStmt] ++ constraintCheck) + | _ => do + let boogieExpr ← translateExpr ctMap tcMap ftMap env init + pure (env', [Core.Statement.init ident boogieType boogieExpr] ++ constraintCheck) + | none => do + let defaultVal ← defaultExprForType ctMap ty + pure (env', [Core.Statement.init ident boogieType defaultVal] ++ constraintCheck) + | .Assign target value => + match target.val with + | .Identifier name => do let ident := Core.CoreIdent.locl name - let boogieExpr := translateExpr env value - (env, [Core.Statement.set ident boogieExpr]) - | _ => (env, []) - | .IfThenElse cond thenBranch elseBranch => - let bcond := translateExpr env cond - let (_, bthen) := translateStmt env outputParams thenBranch - let belse := match elseBranch with - | some e => (translateStmt env outputParams e).2 - | none => [] - (env, [Imperative.Stmt.ite bcond bthen belse .empty]) - | .StaticCall name args => - -- Heap functions (heapRead/heapStore) should not appear as standalone statements - -- Only translate actual procedure calls to call statements - if name == "heapRead" || name == "heapStore" then - -- This shouldn't happen in well-formed programs, but handle gracefully - (env, []) - else - let boogieArgs := args.map (translateExpr env) - (env, [Core.Statement.call [] name boogieArgs]) - | .Return valueOpt => + let constraintCheck := match env.find? (fun (n, _) => n == name) with + | some (_, ty) => genConstraintAssert ctMap tcMap name ty + | none => [] + match value.val with + | .StaticCall callee args => + if isExpressionCall callee then do + let boogieExpr ← translateExpr ctMap tcMap ftMap env value + pure (env, [Core.Statement.set ident boogieExpr] ++ constraintCheck) + else do + let boogieArgs ← args.mapM (translateExpr ctMap tcMap ftMap env) + pure (env, [Core.Statement.call [ident] callee boogieArgs] ++ constraintCheck) + | _ => do + let boogieExpr ← translateExpr ctMap tcMap ftMap env value + pure (env, [Core.Statement.set ident boogieExpr] ++ constraintCheck) + | _ => throw s!"translateStmt: unsupported assignment target {Std.Format.pretty (Std.ToFormat.format target.val)}" + | .IfThenElse cond thenBranch elseBranch => do + let bcond ← translateExpr ctMap tcMap ftMap env cond + let (_, bthen) ← translateStmt ctMap tcMap ftMap env outputParams postconds thenBranch + let belse ← match elseBranch with + | some e => do let (_, s) ← translateStmt ctMap tcMap ftMap env outputParams postconds e; pure s + | none => pure [] + pure (env, [Imperative.Stmt.ite bcond bthen belse stmt.md]) + | .While cond invariants _decOpt body => do + let condExpr ← translateExpr ctMap tcMap ftMap env cond + -- Combine multiple invariants with && for Core (which expects single invariant) + let invExpr ← match invariants with + | [] => pure none + | [single] => do let e ← translateExpr ctMap tcMap ftMap env single; pure (some e) + | first :: rest => do + let firstExpr ← translateExpr ctMap tcMap ftMap env first + let combined ← rest.foldlM (fun acc inv => do + let invExpr ← translateExpr ctMap tcMap ftMap env inv + pure (LExpr.mkApp () boolAndOp [acc, invExpr])) firstExpr + pure (some combined) + let (_, bodyStmts) ← translateStmt ctMap tcMap ftMap env outputParams postconds body + pure (env, [Imperative.Stmt.loop condExpr none invExpr bodyStmts stmt.md]) + | .StaticCall name args => do + if isHeapFunction (normalizeCallee name) then pure (env, []) + else do + let boogieArgs ← args.mapM (translateExpr ctMap tcMap ftMap env) + pure (env, [Core.Statement.call [] name boogieArgs]) + | .Return valueOpt => do + -- Generate postcondition assertions before assuming false + let postAsserts := postconds.map fun (label, expr) => + Core.Statement.assert label expr stmt.md match valueOpt, outputParams.head? with - | some value, some outParam => + | some value, some outParam => do let ident := Core.CoreIdent.locl outParam.name - let boogieExpr := translateExpr env value + let boogieExpr ← translateExpr ctMap tcMap ftMap env value let assignStmt := Core.Statement.set ident boogieExpr - let noFallThrough := Core.Statement.assume "return" (.const () (.boolConst false)) .empty - (env, [assignStmt, noFallThrough]) + let noFallThrough := Core.Statement.assume "return" (.const () (.boolConst false)) stmt.md + pure (env, [assignStmt] ++ postAsserts ++ [noFallThrough]) | none, _ => - let noFallThrough := Core.Statement.assume "return" (.const () (.boolConst false)) .empty - (env, [noFallThrough]) + let noFallThrough := Core.Statement.assume "return" (.const () (.boolConst false)) stmt.md + pure (env, postAsserts ++ [noFallThrough]) | some _, none => - panic! "Return statement with value but procedure has no output parameters" - | _ => (env, []) + throw "Return statement with value but procedure has no output parameters" + | _ => throw s!"translateStmt: unsupported {Std.Format.pretty (Std.ToFormat.format stmt.val)}" /-- Translate Laurel Parameter to Core Signature entry -/ def translateParameterToCore (param : Parameter) : (Core.CoreIdent × LMonoTy) := let ident := Core.CoreIdent.locl param.name - let ty := translateType param.type + let ty := translateType param.type.val + (ident, ty) + +/-- Translate parameter with constrained type resolution -/ +def translateParameterToCoreWithCT (ctMap : ConstrainedTypeMap) (param : Parameter) : (Core.CoreIdent × LMonoTy) := + let ident := Core.CoreIdent.locl param.name + let ty := translateTypeMdWithCT ctMap param.type (ident, ty) +/-- Expand array parameter to (arr, arr_len) pair -/ +def expandArrayParam (ctMap : ConstrainedTypeMap) (param : Parameter) : List (Core.CoreIdent × LMonoTy) := + match param.type.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => + [ (Core.CoreIdent.locl param.name, translateTypeMdWithCT ctMap param.type) + , (Core.CoreIdent.locl (param.name ++ "_len"), LMonoTy.int) ] + | _ => [translateParameterToCoreWithCT ctMap param] + | _ => [translateParameterToCoreWithCT ctMap param] + +def HighType.isHeap : HighType → Bool + | .THeap => true + | _ => false + /-- Translate Laurel Procedure to Core Procedure -/ -def translateProcedure (constants : List Constant) (proc : Procedure) : Core.Procedure := +def translateProcedure (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (ftMap : FunctionTypeMap) + (constants : List Constant) (proc : Procedure) : Except String Core.Decl := do -- Check if this procedure has a heap parameter (first input named "heap") - let hasHeapParam := proc.inputs.any (fun p => p.name == "heap" && p.type == .THeap) + let hasHeapParam := proc.inputs.any (fun p => p.name == "heap" && p.type.val.isHeap) -- Rename heap input to heap_in if present let renamedInputs := proc.inputs.map (fun p => - if p.name == "heap" && p.type == .THeap then { p with name := "heap_in" } else p) - let inputPairs := renamedInputs.map translateParameterToCore - let inputs := inputPairs + if p.name == "heap" && p.type.val.isHeap then { p with name := "heap_in" } else p) + let inputs := renamedInputs.flatMap (expandArrayParam ctMap) let header : Core.Procedure.Header := { name := proc.name typeArgs := [] inputs := inputs - outputs := proc.outputs.map translateParameterToCore + outputs := proc.outputs.flatMap (expandArrayParam ctMap) } + -- Build type environment with original types (for constraint checks) + -- Include array length parameters + let arrayLenEnv : TypeEnv := proc.inputs.filterMap (fun p => + match p.type.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => some (p.name ++ "_len", ⟨.TInt, p.type.md⟩) + | _ => none + | _ => none) let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ proc.outputs.map (fun p => (p.name, p.type)) ++ + arrayLenEnv ++ constants.map (fun c => (c.name, c.type)) - -- Translate precondition if it's not just LiteralBool true - let preconditions : ListMap Core.CoreLabel Core.Procedure.Check := - match proc.precondition with - | .LiteralBool true => [] - | precond => - let check : Core.Procedure.Check := { expr := translateExpr initEnv precond } - [("requires", check)] - -- Translate postcondition for Opaque bodies - let postconditions : ListMap Core.CoreLabel Core.Procedure.Check := - match proc.body with - | .Opaque postcond _ _ _ => - let check : Core.Procedure.Check := { expr := translateExpr initEnv postcond } - [("ensures", check)] - | _ => [] + -- Generate constraint checks for input parameters with constrained types + let inputConstraints : List (Core.CoreLabel × Core.Procedure.Check) ← + proc.inputs.filterMapM (fun p => do + match genConstraintCheck ctMap tcMap p with + | some expr => pure (some (s!"{proc.name}_input_{p.name}_constraint", { expr, md := p.type.md })) + | none => pure none) + -- Array lengths are implicitly >= 0 + let arrayLenConstraints : List (Core.CoreLabel × Core.Procedure.Check) := + proc.inputs.filterMap (fun p => + match p.type.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => + let lenVar := LExpr.fvar () (Core.CoreIdent.locl (p.name ++ "_len")) (some LMonoTy.int) + let zero := LExpr.intConst () 0 + let geZero := LExpr.mkApp () intLeOp [zero, lenVar] + some (s!"{proc.name}_input_{p.name}_len_constraint", { expr := geZero, md := p.type.md }) + | _ => none + | _ => none) + -- Translate explicit preconditions + let mut explicitPreconditions : List (Core.CoreLabel × Core.Procedure.Check) := [] + for h : i in [:proc.preconditions.length] do + let precond := proc.preconditions[i] + let expr ← translateExpr ctMap tcMap ftMap initEnv precond + let check : Core.Procedure.Check := { expr, md := precond.md } + explicitPreconditions := explicitPreconditions ++ [(s!"{proc.name}_pre_{i}", check)] + let preconditions := inputConstraints ++ arrayLenConstraints ++ explicitPreconditions + -- Generate constraint checks for output parameters with constrained types + let outputConstraints : List (Core.CoreLabel × Core.Procedure.Check) ← + proc.outputs.filterMapM (fun p => do + match genConstraintCheck ctMap tcMap p with + | some expr => pure (some (s!"{proc.name}_output_{p.name}_constraint", { expr, md := p.type.md })) + | none => pure none) + -- Translate explicit postconditions for Opaque bodies + let mut explicitPostconditions : List (Core.CoreLabel × Core.Procedure.Check) := [] + match proc.body with + | .Opaque posts _ _ _ => + for h : i in [:posts.length] do + let postcond := posts[i] + let expr ← translateExpr ctMap tcMap ftMap initEnv postcond + let check : Core.Procedure.Check := { expr, md := postcond.md } + explicitPostconditions := explicitPostconditions ++ [(s!"{proc.name}_post_{i}", check)] + | _ => pure () + let postconditions := explicitPostconditions ++ outputConstraints + -- Extract postcondition expressions for early return checking + let postcondExprs : List (String × Core.Expression.Expr) := + postconditions.map fun (label, check) => (label, check.expr) let spec : Core.Procedure.Spec := { modifies := [] preconditions := preconditions @@ -247,20 +633,25 @@ def translateProcedure (constants : List Constant) (proc : Procedure) : Core.Pro let heapInExpr := LExpr.fvar () (Core.CoreIdent.locl "heap_in") (some heapTy) [Core.Statement.init heapIdent heapType heapInExpr] else [] - let body : List Core.Statement := + let body : List Core.Statement ← match proc.body with - | .Transparent bodyExpr => heapInit ++ (translateStmt initEnv proc.outputs bodyExpr).2 - | .Opaque _postcond (some impl) _ _ => heapInit ++ (translateStmt initEnv proc.outputs impl).2 - | _ => [] - { + | .Transparent bodyExpr => do + let (_, stmts) ← translateStmt ctMap tcMap ftMap initEnv proc.outputs postcondExprs bodyExpr + pure (heapInit ++ stmts) + | .Opaque _posts (some impl) _ _ => do + let (_, stmts) ← translateStmt ctMap tcMap ftMap initEnv proc.outputs postcondExprs impl + pure (heapInit ++ stmts) + | _ => pure [] + pure <| Core.Decl.proc ({ header := header spec := spec body := body - } + }) .empty def heapTypeDecl : Core.Decl := .type (.con { name := "Heap", numargs := 0 }) def fieldTypeDecl : Core.Decl := .type (.con { name := "Field", numargs := 1 }) def compositeTypeDecl : Core.Decl := .type (.con { name := "Composite", numargs := 0 }) +def arrayTypeSynonym : Core.Decl := .type (.syn { name := "Array", typeArgs := ["T"], type := .tcons "Map" [.int, .ftvar "T"] }) def readFunction : Core.Decl := let heapTy := LMonoTy.tcons "Heap" [] @@ -268,7 +659,7 @@ def readFunction : Core.Decl := let tVar := LMonoTy.ftvar "T" let fieldTy := LMonoTy.tcons "Field" [tVar] .func { - name := Core.CoreIdent.glob "heapRead" + name := Core.CoreIdent.unres "heapRead" typeArgs := ["T"] inputs := [(Core.CoreIdent.locl "heap", heapTy), (Core.CoreIdent.locl "obj", compTy), @@ -283,7 +674,7 @@ def updateFunction : Core.Decl := let tVar := LMonoTy.ftvar "T" let fieldTy := LMonoTy.tcons "Field" [tVar] .func { - name := Core.CoreIdent.glob "heapStore" + name := Core.CoreIdent.unres "heapStore" typeArgs := ["T"] inputs := [(Core.CoreIdent.locl "heap", heapTy), (Core.CoreIdent.locl "obj", compTy), @@ -306,8 +697,8 @@ def readUpdateSameAxiom : Core.Decl := let o := LExpr.bvar () 1 let f := LExpr.bvar () 2 let v := LExpr.bvar () 3 - let updateOp := LExpr.op () (Core.CoreIdent.glob "heapStore") none - let readOp := LExpr.op () (Core.CoreIdent.glob "heapRead") none + let updateOp := LExpr.op () (Core.CoreIdent.unres "heapStore") none + let readOp := LExpr.op () (Core.CoreIdent.unres "heapRead") none let updateExpr := LExpr.mkApp () updateOp [h, o, f, v] let readExpr := LExpr.mkApp () readOp [updateExpr, o, f] let eqBody := LExpr.eq () readExpr v @@ -331,8 +722,8 @@ def readUpdateDiffObjAxiom : Core.Decl := let o2 := LExpr.bvar () 2 let f := LExpr.bvar () 3 let v := LExpr.bvar () 4 - let updateOp := LExpr.op () (Core.CoreIdent.glob "heapStore") none - let readOp := LExpr.op () (Core.CoreIdent.glob "heapRead") none + let updateOp := LExpr.op () (Core.CoreIdent.unres "heapStore") none + let readOp := LExpr.op () (Core.CoreIdent.unres "heapRead") none let updateExpr := LExpr.mkApp () updateOp [h, o1, f, v] let lhs := LExpr.mkApp () readOp [updateExpr, o2, f] let rhs := LExpr.mkApp () readOp [h, o2, f] @@ -345,8 +736,44 @@ def readUpdateDiffObjAxiom : Core.Decl := LExpr.all () (some heapTy) implBody .ax { name := "heapRead_heapStore_diff_obj", e := body } +/-- Truncating division (Java/C semantics): truncates toward zero -/ +def intDivTFunc : Core.Decl := + let a := LExpr.fvar () (Core.CoreIdent.locl "a") (some LMonoTy.int) + let b := LExpr.fvar () (Core.CoreIdent.locl "b") (some LMonoTy.int) + let zero := LExpr.intConst () 0 + let aGeZero := LExpr.mkApp () intGeOp [a, zero] + let bGeZero := LExpr.mkApp () intGeOp [b, zero] + let sameSign := LExpr.eq () aGeZero bGeZero + let euclidDiv := LExpr.mkApp () intDivOp [a, b] + let negA := LExpr.mkApp () intNegOp [a] + let negADivB := LExpr.mkApp () intDivOp [negA, b] + let negResult := LExpr.mkApp () intNegOp [negADivB] + let body := LExpr.ite () sameSign euclidDiv negResult + .func { + name := Core.CoreIdent.unres "Int.DivT" + typeArgs := [] + inputs := [(Core.CoreIdent.locl "a", LMonoTy.int), (Core.CoreIdent.locl "b", LMonoTy.int)] + output := LMonoTy.int + body := some body + } + +/-- Truncating modulo (Java/C semantics): a %t b = a - (a /t b) * b -/ +def intModTFunc : Core.Decl := + let a := LExpr.fvar () (Core.CoreIdent.locl "a") (some LMonoTy.int) + let b := LExpr.fvar () (Core.CoreIdent.locl "b") (some LMonoTy.int) + let divT := LExpr.mkApp () intDivTOp [a, b] + let mulDivB := LExpr.mkApp () intMulOp [divT, b] + let body := LExpr.mkApp () intSubOp [a, mulDivB] + .func { + name := Core.CoreIdent.unres "Int.ModT" + typeArgs := [] + inputs := [(Core.CoreIdent.locl "a", LMonoTy.int), (Core.CoreIdent.locl "b", LMonoTy.int)] + output := LMonoTy.int + body := some body + } + def translateConstant (c : Constant) : Core.Decl := - let ty := translateType c.type + let ty := translateType c.type.val .func { name := Core.CoreIdent.glob c.name typeArgs := [] @@ -360,18 +787,20 @@ Check if a StmtExpr is a pure expression (can be used as a Core function body). Pure expressions don't contain statements like assignments, loops, or local variables. A Block with a single pure expression is also considered pure. -/ -def isPureExpr : StmtExpr → Bool - | .LiteralBool _ => true - | .LiteralInt _ => true - | .Identifier _ => true - | .PrimitiveOp _ args => args.attach.all (fun ⟨a, _⟩ => isPureExpr a) - | .IfThenElse c t none => isPureExpr c && isPureExpr t - | .IfThenElse c t (some e) => isPureExpr c && isPureExpr t && isPureExpr e - | .StaticCall _ args => args.attach.all (fun ⟨a, _⟩ => isPureExpr a) - | .ReferenceEquals e1 e2 => isPureExpr e1 && isPureExpr e2 - | .Block [single] _ => isPureExpr single +partial def isPureExpr : StmtExprMd → Bool + | ⟨.LiteralBool _, _⟩ => true + | ⟨.LiteralInt _, _⟩ => true + | ⟨.Identifier _, _⟩ => true + | ⟨.PrimitiveOp _ args, _⟩ => args.all isPureExpr + | ⟨.IfThenElse c t none, _⟩ => isPureExpr c && isPureExpr t + | ⟨.IfThenElse c t (some e), _⟩ => isPureExpr c && isPureExpr t && isPureExpr e + | ⟨.StaticCall _ args, _⟩ => args.all isPureExpr + | ⟨.ReferenceEquals e1 e2, _⟩ => isPureExpr e1 && isPureExpr e2 + | ⟨.Block [single] _, _⟩ => isPureExpr single + | ⟨.Forall _ _ body, _⟩ => isPureExpr body + | ⟨.Exists _ _ body, _⟩ => isPureExpr body + | ⟨.Return (some e), _⟩ => isPureExpr e | _ => false -termination_by e => sizeOf e /-- Check if a procedure can be translated as a Core function. @@ -384,29 +813,38 @@ def canBeBoogieFunction (proc : Procedure) : Bool := match proc.body with | .Transparent bodyExpr => isPureExpr bodyExpr && - (match proc.precondition with | .LiteralBool true => true | _ => false) && + proc.preconditions.isEmpty && proc.outputs.length == 1 | _ => false /-- Translate a Laurel Procedure to a Core Function (when applicable) -/ -def translateProcedureToFunction (proc : Procedure) : Core.Decl := - let inputs := proc.inputs.map translateParameterToCore - let outputTy := match proc.outputs.head? with - | some p => translateType p.type - | none => LMonoTy.int - let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) - let body := match proc.body with - | .Transparent bodyExpr => some (translateExpr initEnv bodyExpr) - | _ => none - .func { +def translateProcedureToFunction (ctMap : ConstrainedTypeMap) (tcMap : TranslatedConstraintMap) (ftMap : FunctionTypeMap) (proc : Procedure) : Except String Core.Decl := do + let inputs := proc.inputs.flatMap (expandArrayParam ctMap) + let outputTy ← match proc.outputs.head? with + | some p => pure (translateTypeMdWithCT ctMap p.type) + | none => throw s!"translateProcedureToFunction: {proc.name} has no output parameter" + let arrayLenEnv : TypeEnv := proc.inputs.filterMap (fun p => + match p.type.val with + | .Applied ctor _ => + match ctor.val with + | .UserDefined "Array" => some (p.name ++ "_len", ⟨.TInt, p.type.md⟩) + | _ => none + | _ => none) + let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ arrayLenEnv + let body ← match proc.body with + | .Transparent bodyExpr => do + let expr ← translateExpr ctMap tcMap ftMap initEnv bodyExpr + pure (some expr) + | _ => pure none + pure (.func { name := Core.CoreIdent.glob proc.name typeArgs := [] inputs := inputs output := outputTy body := body - } + }) /-- Translate Laurel Program to Core Program @@ -414,14 +852,18 @@ Translate Laurel Program to Core Program def translate (program : Program) : Except (Array DiagnosticModel) Core.Program := do let sequencedProgram ← liftExpressionAssignments program let heapProgram := heapParameterization sequencedProgram + -- Build constrained type maps + let ctMap := buildConstrainedTypeMap heapProgram.types + let tcMap ← buildTranslatedConstraintMap ctMap |>.mapError fun e => #[{ fileRange := default, message := e }] -- Separate procedures that can be functions from those that must be procedures let (funcProcs, procProcs) := heapProgram.staticProcedures.partition canBeBoogieFunction - let procedures := procProcs.map (translateProcedure heapProgram.constants) - let procDecls := procedures.map (fun p => Core.Decl.proc p .empty) - let laurelFuncDecls := funcProcs.map translateProcedureToFunction + -- Build function type map from procedures that will become functions + let ftMap := buildFunctionTypeMap ctMap funcProcs + let procDecls ← procProcs.mapM (translateProcedure ctMap tcMap ftMap heapProgram.constants) |>.mapError fun e => #[{ fileRange := default, message := e }] + let laurelFuncDecls ← funcProcs.mapM (translateProcedureToFunction ctMap tcMap ftMap) |>.mapError fun e => #[{ fileRange := default, message := e }] let constDecls := heapProgram.constants.map translateConstant - let typeDecls := [heapTypeDecl, fieldTypeDecl, compositeTypeDecl] - let funcDecls := [readFunction, updateFunction] + let typeDecls := [heapTypeDecl, fieldTypeDecl, compositeTypeDecl, arrayTypeSynonym] + let funcDecls := [readFunction, updateFunction, intDivTFunc, intModTFunc] let axiomDecls := [readUpdateSameAxiom, readUpdateDiffObjAxiom] return { decls := typeDecls ++ funcDecls ++ axiomDecls ++ constDecls ++ laurelFuncDecls ++ procDecls } @@ -439,9 +881,6 @@ def verifyToVcResults (smtsolver : String) (program : Program) match boogieProgramExcept with | .error e => return .error e | .ok boogieProgram => - dbg_trace "=== Generated Core Program ===" - dbg_trace (toString (Std.Format.pretty (Std.ToFormat.format boogieProgram))) - dbg_trace "=================================" let runner tempDir := EIO.toIO (fun f => IO.Error.userError (toString f)) diff --git a/Strata/Languages/Laurel/LiftExpressionAssignments.lean b/Strata/Languages/Laurel/LiftExpressionAssignments.lean index f112aaed5a..fbeb787dd5 100644 --- a/Strata/Languages/Laurel/LiftExpressionAssignments.lean +++ b/Strata/Languages/Laurel/LiftExpressionAssignments.lean @@ -28,13 +28,13 @@ Becomes: structure SequenceState where insideCondition : Bool - prependedStmts : List StmtExpr := [] + prependedStmts : List StmtExprMd := [] diagnostics : List DiagnosticModel tempCounter : Nat := 0 abbrev SequenceM := StateM SequenceState -def SequenceM.addPrependedStmt (stmt : StmtExpr) : SequenceM Unit := +def SequenceM.addPrependedStmt (stmt : StmtExprMd) : SequenceM Unit := modify fun s => { s with prependedStmts := stmt :: s.prependedStmts } def SequenceM.addDiagnostic (d : DiagnosticModel) : SequenceM Unit := @@ -52,7 +52,7 @@ def checkOutsideCondition(md: Imperative.MetaData Core.Expression): SequenceM Un def SequenceM.setInsideCondition : SequenceM Unit := do modify fun s => { s with insideCondition := true } -def SequenceM.takePrependedStmts : SequenceM (List StmtExpr) := do +def SequenceM.takePrependedStmts : SequenceM (List StmtExprMd) := do let stmts := (← get).prependedStmts modify fun s => { s with prependedStmts := [] } return stmts.reverse @@ -62,32 +62,40 @@ def SequenceM.freshTemp : SequenceM Identifier := do modify fun s => { s with tempCounter := s.tempCounter + 1 } return s!"__t{counter}" +/-- Helper to create a StmtExprMd with empty metadata -/ +def mkStmtExprMdEmpty' (e : StmtExpr) : StmtExprMd := ⟨e, #[]⟩ + +-- Add Inhabited instance for StmtExprMd to help with partial definitions +instance : Inhabited StmtExprMd where + default := ⟨.Hole, #[]⟩ + mutual /- Process an expression, extracting any assignments to preceding statements. Returns the transformed expression with assignments replaced by variable references. -/ -def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do - match expr with - | .Assign target value md => +partial def transformExpr (expr : StmtExprMd) : SequenceM StmtExprMd := do + let md := expr.md + match expr.val with + | .Assign target value => checkOutsideCondition md -- This is an assignment in expression context -- We need to: 1) execute the assignment, 2) capture the value in a temporary -- This prevents subsequent assignments to the same variable from changing the value let seqValue ← transformExpr value - let assignStmt := StmtExpr.Assign target seqValue md + let assignStmt : StmtExprMd := ⟨.Assign target seqValue, md⟩ SequenceM.addPrependedStmt assignStmt -- Create a temporary variable to capture the assigned value -- Use TInt as the type (could be refined with type inference) let tempName ← SequenceM.freshTemp - let tempDecl := StmtExpr.LocalVariable tempName .TInt (some target) + let tempDecl : StmtExprMd := ⟨.LocalVariable tempName ⟨.TInt, #[]⟩ (some target), md⟩ SequenceM.addPrependedStmt tempDecl -- Return the temporary variable as the expression value - return .Identifier tempName + return ⟨.Identifier tempName, md⟩ | .PrimitiveOp op args => let seqArgs ← args.mapM transformExpr - return .PrimitiveOp op seqArgs + return ⟨.PrimitiveOp op seqArgs, md⟩ | .IfThenElse cond thenBranch elseBranch => let seqCond ← transformExpr cond @@ -96,22 +104,22 @@ def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do let seqElse ← match elseBranch with | some e => transformExpr e >>= (pure ∘ some) | none => pure none - return .IfThenElse seqCond seqThen seqElse + return ⟨.IfThenElse seqCond seqThen seqElse, md⟩ | .StaticCall name args => let seqArgs ← args.mapM transformExpr - return .StaticCall name seqArgs + return ⟨.StaticCall name seqArgs, md⟩ | .Block stmts metadata => -- Block in expression position: move all but last statement to prepended - let rec next (remStmts: List StmtExpr) := match remStmts with + let rec next (remStmts: List StmtExprMd) := match remStmts with | [last] => transformExpr last | head :: tail => do let seqStmt ← transformStmt head for s in seqStmt do SequenceM.addPrependedStmt s next tail - | [] => return .Block [] metadata + | [] => return ⟨.Block [] metadata, md⟩ next stmts @@ -126,36 +134,37 @@ def transformExpr (expr : StmtExpr) : SequenceM StmtExpr := do Process a statement, handling any assignments in its sub-expressions. Returns a list of statements (the original one may be split into multiple). -/ -def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do - match stmt with - | @StmtExpr.Assert cond md => +partial def transformStmt (stmt : StmtExprMd) : SequenceM (List StmtExprMd) := do + let md := stmt.md + match stmt.val with + | .Assert cond => -- Process the condition, extracting any assignments let seqCond ← transformExpr cond - SequenceM.addPrependedStmt <| StmtExpr.Assert seqCond md + SequenceM.addPrependedStmt ⟨.Assert seqCond, md⟩ SequenceM.takePrependedStmts - | @StmtExpr.Assume cond md => + | .Assume cond => let seqCond ← transformExpr cond - SequenceM.addPrependedStmt <| StmtExpr.Assume seqCond md + SequenceM.addPrependedStmt ⟨.Assume seqCond, md⟩ SequenceM.takePrependedStmts | .Block stmts metadata => let seqStmts ← stmts.mapM transformStmt - return [.Block (seqStmts.flatten) metadata] + return [⟨.Block (seqStmts.flatten) metadata, md⟩] | .LocalVariable name ty initializer => match initializer with | some initExpr => do let seqInit ← transformExpr initExpr - SequenceM.addPrependedStmt <| .LocalVariable name ty (some seqInit) + SequenceM.addPrependedStmt ⟨.LocalVariable name ty (some seqInit), md⟩ SequenceM.takePrependedStmts | none => return [stmt] - | .Assign target value md => + | .Assign target value => let seqTarget ← transformExpr target let seqValue ← transformExpr value - SequenceM.addPrependedStmt <| .Assign seqTarget seqValue md + SequenceM.addPrependedStmt ⟨.Assign seqTarget seqValue, md⟩ SequenceM.takePrependedStmts | .IfThenElse cond thenBranch elseBranch => @@ -163,20 +172,20 @@ def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do SequenceM.setInsideCondition let seqThen ← transformStmt thenBranch - let thenBlock := .Block seqThen none + let thenBlock : StmtExprMd := ⟨.Block seqThen none, md⟩ let seqElse ← match elseBranch with | some e => let se ← transformStmt e - pure (some (.Block se none)) + pure (some (⟨.Block se none, md⟩ : StmtExprMd)) | none => pure none - SequenceM.addPrependedStmt <| .IfThenElse seqCond thenBlock seqElse + SequenceM.addPrependedStmt ⟨.IfThenElse seqCond thenBlock seqElse, md⟩ SequenceM.takePrependedStmts | .StaticCall name args => let seqArgs ← args.mapM transformExpr - SequenceM.addPrependedStmt <| .StaticCall name seqArgs + SequenceM.addPrependedStmt ⟨.StaticCall name seqArgs, md⟩ SequenceM.takePrependedStmts | _ => @@ -184,11 +193,11 @@ def transformStmt (stmt : StmtExpr) : SequenceM (List StmtExpr) := do end -def transformProcedureBody (body : StmtExpr) : SequenceM StmtExpr := do +def transformProcedureBody (body : StmtExprMd) : SequenceM StmtExprMd := do let seqStmts <- transformStmt body match seqStmts with | [single] => pure single - | multiple => pure <| .Block multiple.reverse none + | multiple => pure ⟨.Block multiple.reverse none, body.md⟩ def transformProcedure (proc : Procedure) : SequenceM Procedure := do match proc.body with diff --git a/StrataMain.lean b/StrataMain.lean index 287e66da14..e5c2a0519d 100644 --- a/StrataMain.lean +++ b/StrataMain.lean @@ -251,14 +251,12 @@ def deserializeIonToLaurelFiles (bytes : ByteArray) : IO (List Strata.StrataFile | .ok files => pure files | .error msg => exitFailure msg -def laurelAnalyzeCommand : Command where - name := "laurelAnalyze" +def laurelAnalyzeBinaryCommand : Command where + name := "laurelAnalyzeBinary" args := [] - help := "Analyze a Laurel Ion program from stdin. Write diagnostics to stdout." + help := "Analyze a Laurel program from binary (Ion) stdin. Write diagnostics to stdout." callback := fun _ _ => do - -- Read bytes from stdin let stdinBytes ← (← IO.getStdin).readBinToEnd - let strataFiles ← deserializeIonToLaurelFiles stdinBytes let mut combinedProgram : Strata.Laurel.Program := { @@ -268,12 +266,10 @@ def laurelAnalyzeCommand : Command where } for strataFile in strataFiles do - let transResult := Strata.Laurel.TransM.run (Strata.Uri.file strataFile.filePath) (Strata.Laurel.parseProgram strataFile.program) match transResult with | .error transErrors => exitFailure s!"Translation errors in {strataFile.filePath}: {transErrors}" | .ok laurelProgram => - combinedProgram := { staticProcedures := combinedProgram.staticProcedures ++ laurelProgram.staticProcedures staticFields := combinedProgram.staticFields ++ laurelProgram.staticFields @@ -286,8 +282,108 @@ def laurelAnalyzeCommand : Command where for diag in diagnostics do IO.println s!"{Std.format diag.fileRange.file}:{diag.fileRange.range.start}-{diag.fileRange.range.stop}: {diag.message}" +def laurelParseCommand : Command where + name := "laurelParse" + args := [ "file" ] + help := "Parse a Laurel source file (no verification)." + callback := fun _ v => do + let path : System.FilePath := v[0] + let content ← IO.FS.readFile path + let input := Strata.Parser.stringInputContext path content + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[Strata.initDialect, Strata.Laurel.Laurel] + let strataProgram ← Strata.Elab.parseStrataProgramFromDialect dialects Strata.Laurel.Laurel.name input + + let uri := Strata.Uri.file path.toString + let transResult := Strata.Laurel.TransM.run uri (Strata.Laurel.parseProgram strataProgram) + match transResult with + | .error transErrors => exitFailure s!"Translation errors: {transErrors}" + | .ok _ => IO.println "Parse successful" + +def laurelAnalyzeCommand : Command where + name := "laurelAnalyze" + args := [ "file" ] + help := "Analyze a Laurel source file. Write diagnostics to stdout." + callback := fun _ v => do + let path : System.FilePath := v[0] + let content ← IO.FS.readFile path + let input := Strata.Parser.stringInputContext path content + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[Strata.initDialect, Strata.Laurel.Laurel] + let strataProgram ← Strata.Elab.parseStrataProgramFromDialect dialects Strata.Laurel.Laurel.name input + + let uri := Strata.Uri.file path.toString + let transResult := Strata.Laurel.TransM.run uri (Strata.Laurel.parseProgram strataProgram) + match transResult with + | .error transErrors => exitFailure s!"Translation errors: {transErrors}" + | .ok laurelProgram => + let results ← Strata.Laurel.verifyToVcResults "z3" laurelProgram Options.default none + match results with + | .error errors => + IO.println s!"==== ERRORS ====" + for err in errors do + IO.println s!"{err.message}" + | .ok vcResults => + IO.println s!"==== RESULTS ====" + for vc in vcResults do + IO.println s!"{vc.obligation.label}: {repr vc.result}" + +def laurelPrintCommand : Command where + name := "laurelPrint" + args := [] + help := "Read Laurel Ion from stdin and print in concrete syntax to stdout." + callback := fun _ _ => do + let stdinBytes ← (← IO.getStdin).readBinToEnd + let strataFiles ← deserializeIonToLaurelFiles stdinBytes + for strataFile in strataFiles do + IO.println s!"// File: {strataFile.filePath}" + let p := strataFile.program + let c := p.formatContext {} + let s := p.formatState + let fmt := p.commands.foldl (init := f!"") fun f cmd => + f ++ (Strata.mformat cmd c s).format + IO.println (fmt.pretty 100) + IO.println "" + +def prettyPrintCore (p : Core.Program) : String := + let decls := p.decls.map fun d => + let s := toString (Std.format d) + -- Add newlines after major sections in procedures + s.replace "preconditions:" "\n preconditions:" + |>.replace "postconditions:" "\n postconditions:" + |>.replace "body:" "\n body:\n " + |>.replace "assert [" "\n assert [" + |>.replace "init (" "\n init (" + |>.replace "while (" "\n while (" + |>.replace "if (" "\n if (" + |>.replace "call [" "\n call [" + |>.replace "else{" "\n else {" + |>.replace "}}" "}\n }" + String.intercalate "\n" decls + +def laurelToCoreCommand : Command where + name := "laurelToCore" + args := [ "file" ] + help := "Translate a Laurel source file to Core and print to stdout." + callback := fun _ v => do + let path : System.FilePath := v[0] + let content ← IO.FS.readFile path + let input := Strata.Parser.stringInputContext path content + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[Strata.initDialect, Strata.Laurel.Laurel] + let strataProgram ← Strata.Elab.parseStrataProgramFromDialect dialects Strata.Laurel.Laurel.name input + + let uri := Strata.Uri.file path.toString + let transResult := Strata.Laurel.TransM.run uri (Strata.Laurel.parseProgram strataProgram) + match transResult with + | .error transErrors => exitFailure s!"Translation errors: {transErrors}" + | .ok laurelProgram => + match Strata.Laurel.translate laurelProgram with + | .error diags => exitFailure s!"Core translation errors: {diags.map (·.message)}" + | .ok coreProgram => IO.println (prettyPrintCore coreProgram) + def commandList : List Command := [ javaGenCommand, + laurelPrintCommand, + laurelParseCommand, + laurelToCoreCommand, checkCommand, toIonCommand, printCommand, @@ -295,6 +391,7 @@ def commandList : List Command := [ pyAnalyzeCommand, pyTranslateCommand, laurelAnalyzeCommand, + laurelAnalyzeBinaryCommand, ] def commandMap : Std.HashMap String Command := diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean index 3ad972ee0d..d3858c8123 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean @@ -8,23 +8,19 @@ import StrataTest.Util.TestDiagnostics import StrataTest.Languages.Laurel.TestExamples open StrataTest.Util -open Strata +namespace Strata namespace Laurel def program := r" constrained nat = x: int where x >= 0 witness 0 -composite Option {} -composite Some extends Option { - value: int -} -composite None extends Option -constrained SealedOption = x: Option where x is Some || x is None witness None - -procedure foo() returns (r: nat) { +procedure double(n: nat) returns (r: nat) + ensures r == n + n +{ + return n + n; } " --- Not working yet --- #eval! testInput "ConstrainedTypes" program processLaurelFile +#guard_msgs(drop info, error) in +#eval testInputWithOffset "ConstrainedTypes" program 14 processLaurelFile diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T11_Arrays.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T11_Arrays.lean new file mode 100644 index 0000000000..385df7b4a9 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T11_Arrays.lean @@ -0,0 +1,25 @@ +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples +open StrataTest.Util +namespace Strata.Laurel + +def program := r" +procedure getFirst(arr: Array, len: int) returns (r: int) +requires len > 0 +ensures r == arr[0] +{ + return arr[0]; +} + +procedure sumTwo(arr: Array, len: int) returns (r: int) +requires len >= 2 +ensures r == arr[0] + arr[1] +{ + return arr[0] + arr[1]; +} +" + +#guard_msgs(drop info, error) in +#eval testInputWithOffset "T11_Arrays" program 5 processLaurelFile + +end Strata.Laurel diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T12_Sequences.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T12_Sequences.lean new file mode 100644 index 0000000000..aaa610e74a --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T12_Sequences.lean @@ -0,0 +1,27 @@ +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples +open StrataTest.Util +namespace Strata.Laurel + +def program := r" +procedure containsTarget(arr: Array, len: int, target: int) returns (r: bool) +requires len >= 0 +ensures r == Seq.Contains(Seq.From(arr), target) +{ + return Seq.Contains(Seq.From(arr), target); +} + +procedure containsInPrefix(arr: Array, len: int, n: int, target: int) returns (r: bool) +requires len >= 0 +requires n >= 0 +requires n <= len +ensures r == Seq.Contains(Seq.Take(Seq.From(arr), n), target) +{ + return Seq.Contains(Seq.Take(Seq.From(arr), n), target); +} +" + +#guard_msgs(drop info, error) in +#eval testInputWithOffset "T12_Sequences" program 5 processLaurelFile + +end Strata.Laurel diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T1b_Operators.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T1b_Operators.lean new file mode 100644 index 0000000000..37d2a74f58 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T1b_Operators.lean @@ -0,0 +1,61 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util + +namespace Strata +namespace Laurel + +def program := r" +procedure testArithmetic() { + var a: int := 10; + var b: int := 3; + var x: int := a - b; + assert x == 7; + var y: int := x * 2; + assert y == 14; + var z: int := y / 2; + assert z == 7; + var r: int := 17 % 5; + assert r == 2; +} + +procedure testLogical() { + var t: bool := true; + var f: bool := false; + var a: bool := t && f; + assert a == false; + var b: bool := t || f; + assert b == true; + var c: bool := !f; + assert c == true; + assert t ==> t; + assert f ==> t; +} + +procedure testUnary() { + var x: int := 5; + var y: int := -x; + assert y == 0 - 5; +} + +procedure testTruncatingDiv() { + // Truncating division rounds toward zero (Java/C semantics) + // For positive numbers, same as Euclidean + assert 7 /t 3 == 2; + assert 7 %t 3 == 1; + // For negative dividend, truncates toward zero (not floor) + // -7 /t 3 = -2 (not -3), -7 %t 3 = -1 (not 2) + assert (0 - 7) /t 3 == 0 - 2; + assert (0 - 7) %t 3 == 0 - 1; +} +" + +#guard_msgs(drop info, error) in +#eval testInputWithOffset "Operators" program 14 processLaurelFile diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T4_WhileBasic.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T4_WhileBasic.lean new file mode 100644 index 0000000000..d06cdaa7a0 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T4_WhileBasic.lean @@ -0,0 +1,40 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples + +open StrataTest.Util + +namespace Strata +namespace Laurel + +def program := r" +procedure countdown() { + var i: int := 3; + while(i > 0) + invariant i >= 0 + { + i := i - 1; + } + assert i == 0; +} + +procedure countUp() { + var n: int := 5; + var i: int := 0; + while(i < n) + invariant i >= 0 + invariant i <= n + { + i := i + 1; + } + assert i == n; +} +" + +#guard_msgs(drop info, error) in +#eval testInputWithOffset "WhileBasic" program 14 processLaurelFile diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T5_Quantifiers.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T5_Quantifiers.lean new file mode 100644 index 0000000000..da5cff4428 --- /dev/null +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T5_Quantifiers.lean @@ -0,0 +1,25 @@ +import StrataTest.Util.TestDiagnostics +import StrataTest.Languages.Laurel.TestExamples +open StrataTest.Util +namespace Strata.Laurel + +def program := r" +procedure test(x: int) +requires forall(i: int) => i >= 0 +ensures exists(j: int) => j == x +{} + +procedure multiContract(x: int) returns (r: int) +requires x >= 0 +requires x <= 100 +ensures r >= x +ensures r <= x + 10 +{ + return x + 5; +} +" + +#guard_msgs(drop info) in +#eval testInputWithOffset "T5_Quantifiers" program 5 processLaurelFile + +end Strata.Laurel