Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Strata/Languages/Laurel/ConstrainedTypeElim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def constraintCallFor (ptMap : ConstrainedTypeMap) (ty : HighType)
(varName : Identifier) (src : Option FileRange := none) : Option StmtExprMd :=
match ty with
| .UserDefined name => if ptMap.contains name.text then
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Identifier varName, src⟩], src⟩
some ⟨.StaticCall (mkId s!"{name.text}$constraint") [⟨.Variable (.Local varName), src⟩], src⟩
else none
| _ => none

Expand All @@ -68,7 +68,7 @@ def mkConstraintFunc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Proce
if ptMap.contains parent.text then
let paramId := { ct.valueName with uniqueId := none }
let paramRef : StmtExprMd :=
{ val := .Identifier paramId, source := none }
{ val := .Variable (.Local paramId), source := none }
let parentCall : StmtExprMd :=
{ val := .StaticCall (mkId s!"{parent.text}$constraint") [paramRef], source := none }
{ val := .PrimitiveOp .And [ct.constraint, parentCall], source := none }
Expand Down Expand Up @@ -133,7 +133,7 @@ def elimStmt (ptMap : ConstrainedTypeMap)
pure ([⟨.LocalVariable name ty init', source⟩] ++ check)

| .Assign [target] _ => match target.val with
| .Identifier name => do
| .Local name => do
match (← get).get? name.text with
| some ty =>
let assert := (constraintCallFor ptMap ty name (src := source)).toList.map
Expand Down
8 changes: 6 additions & 2 deletions Strata/Languages/Laurel/CoreGroupingAndOrdering.lean
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def collectStaticCallNames (expr : StmtExprMd) : List String :=
| none => []
| .Block stmts _ => stmts.flatMap (fun s => collectStaticCallNames s)
| .Assign targets v =>
targets.flatMap (fun t => collectStaticCallNames t) ++
targets.flatMap (fun t => match ht : t.val with
| .Local _ => []
| .Field target _ =>
have : sizeOf target < sizeOf t := Variable.sizeOf_field_target t ht
collectStaticCallNames target) ++
collectStaticCallNames v
| .LocalVariable _ _ initOption =>
match initOption with
Expand All @@ -85,7 +89,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String :=
| some t => collectStaticCallNames t
| none => []) ++
collectStaticCallNames body
| .FieldSelect t _ => collectStaticCallNames t
| .Variable (.Field t _) => collectStaticCallNames t
| .PureFieldUpdate t _ v => collectStaticCallNames t ++ collectStaticCallNames v
| .InstanceCall t _ args =>
collectStaticCallNames t ++ args.flatMap (fun a => collectStaticCallNames a)
Expand Down
2 changes: 1 addition & 1 deletion Strata/Languages/Laurel/EliminateHoles.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ private def mkHoleCall (holeType : HighTypeMd) : ElimHoleM StmtExprMd := do
body := .Opaque [] none []
}
modify fun s => { s with generatedFunctions := s.generatedFunctions ++ [holeProc] }
return bare (.StaticCall holeName (inputs.map (fun p => bare (.Identifier p.name))))
return bare (.StaticCall holeName (inputs.map (fun p => bare (.Variable (.Local p.name)))))

/-- Replace a deterministic `.Hole` with a call to a fresh uninterpreted function.
Non-hole nodes pass through unchanged; recursion is handled by `mapStmtExprM`. -/
Expand Down
2 changes: 1 addition & 1 deletion Strata/Languages/Laurel/EliminateValueReturns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private def eliminateValueReturnNode (outParam : Identifier) (stmt : StmtExprMd)
match stmt.val with
| .Return (some value) =>
-- Synthesized nodes use default metadata since no diagnostics should be reported on them
let target : StmtExprMd := { val := .Identifier outParam, source := none }
let target : VariableMd := { val := .Local outParam, source := none }
let assign : StmtExprMd := { val := .Assign [target] value, source := none }
let ret : StmtExprMd := { val := .Return none, source := stmt.source }
{ val := .Block [assign, ret] none, source := none }
Expand Down
11 changes: 7 additions & 4 deletions Strata/Languages/Laurel/FilterPrelude.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do
dec.forM collectExprNames
collectExprNames body
| .Assign targets value =>
collectExprNames value; targets.forM collectExprNames
| .FieldSelect target _ => collectExprNames target
collectExprNames value
targets.forM fun t => match t.val with
| .Local _ => pure ()
| .Field target _ => collectExprNames target
| .Variable (.Field target _) => collectExprNames target
| .PureFieldUpdate target _ newVal =>
collectExprNames target; collectExprNames newVal
| .PrimitiveOp _ args => args.forM collectExprNames
Expand All @@ -120,7 +123,7 @@ private partial def collectExprNames (expr : StmtExprMd) : CollectM Unit := do
| .ReferenceEquals lhs rhs => collectExprNames lhs; collectExprNames rhs
| .Hole _ ty => ty.forM collectHighTypeNames
| .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _
| .Identifier _ | .This | .Abstract | .All => pure ()
| .Variable (.Local _) | .This | .Abstract | .All => pure ()

/-- Collect names from a procedure body. -/
private def collectBodyNames (body : Body) : CollectM Unit := do
Expand Down Expand Up @@ -177,7 +180,7 @@ private partial def collectInvokeOnTargets (expr : StmtExprMd)
| .StaticCall callee args =>
let rest ← args.flatMapM collectInvokeOnTargets
return callee.text :: rest
| .Identifier _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _
| .Variable (.Local _) | .LiteralInt _ | .LiteralBool _ | .LiteralString _
| .LiteralDecimal _ => return []
| _ =>
throw s!"FilterPrelude.collectInvokeOnTargets: unexpected node in invokeOn expression"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ private def operationName : Operation → String
| .Gt => "gt" | .Geq => "ge" | .StrConcat => "strConcat"

-- Internal-only: public because `partial` prevents `private` in this section
mutual
partial def variableToArg (v : VariableMd) : Arg :=
match v.val with
| .Local name => laurelOp "identifier" #[ident name.text]
| .Field target field => laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text]

partial def stmtExprToArg (s : StmtExprMd) : Arg :=
stmtExprValToArg s.val
where
Expand All @@ -90,7 +96,9 @@ where
| .LiteralString s => laurelOp "string" #[.strlit sr s]
| .Hole true _ => laurelOp "hole"
| .Hole false _ => laurelOp "nondetHole"
| .Identifier name => laurelOp "identifier" #[ident name.text]
| .Variable (.Local name) => laurelOp "identifier" #[ident name.text]
| .Variable (.Field target field) =>
laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text]
| .Block stmts label =>
let stmtArgs := stmts.map stmtExprToArg |>.toArray
match label with
Expand All @@ -101,13 +109,16 @@ where
let initOpt := optionArg (init.map fun e => laurelOp "initializer" #[stmtExprToArg e])
laurelOp "varDecl" #[ident name.text, typeOpt, initOpt]
| .Assign targets value =>
-- Grammar only supports single-target assign; use first target or placeholder
let targetArg := match targets with
| t :: _ => stmtExprToArg t
| [] => laurelOp "identifier" #[ident "_"]
laurelOp "assign" #[targetArg, stmtExprToArg value]
| .FieldSelect target field =>
laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text]
if targets.length > 1 then
let targetArgs := targets.map fun t => match t.val with
| .Local name => ident name.text
| .Field _ _ => ident "_"
laurelOp "multiAssign" #[commaSep targetArgs.toArray, stmtExprToArg value]
else
let targetArg := match targets with
| t :: _ => variableToArg t
| [] => laurelOp "identifier" #[ident "_"]
laurelOp "assign" #[targetArg, stmtExprToArg value]
| .StaticCall callee args =>
let calleeArg := laurelOp "identifier" #[ident callee.text]
let argsArr := args.map stmtExprToArg |>.toArray
Expand Down Expand Up @@ -165,9 +176,10 @@ where
| .PureFieldUpdate target field value =>
-- Not directly in grammar; emit as assignment to field
laurelOp "assign" #[
laurelOp "fieldAccess" #[stmtExprToArg target, ident field.text],
variableToArg ⟨.Field target field, none⟩,
stmtExprToArg value
]
end

private def parameterToArg (p : Parameter) : Arg :=
laurelOp "parameter" #[ident p.name.text, highTypeToArg p.type]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ instance : Inhabited Parameter where
def mkHighTypeMd (t : HighType) (source : Option FileRange) : HighTypeMd := { val := t, source := source }
def mkStmtExprMd (e : StmtExpr) (source : Option FileRange) : StmtExprMd := { val := e, source := source }

/-- Convert a parsed StmtExprMd (from the assign target position) into a VariableMd. -/
def stmtExprToVariable (e : StmtExprMd) : VariableMd :=
match e.val with
| .Variable v => ⟨v, e.source⟩
| _ => ⟨.Local { text := "_invalid_" }, e.source⟩

def translateNat (arg : Arg) : TransM Nat := do
let .num _ n := arg
| TransM.error s!"translateNat expects num literal"
Expand Down Expand Up @@ -243,12 +249,20 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do
return mkStmtExprMd (.LocalVariable name varType value) src
| q`Laurel.identifier, #[arg0] =>
let name ← translateIdent arg0
return mkStmtExprMd (.Identifier name) src
return mkStmtExprMd (.Variable (.Local name)) src
| q`Laurel.parenthesis, #[arg0] => translateStmtExpr arg0
| q`Laurel.assign, #[arg0, arg1] =>
let target ← translateStmtExpr arg0
let varTarget := stmtExprToVariable target
let value ← translateStmtExpr arg1
return mkStmtExprMd (.Assign [varTarget] value) src
| q`Laurel.multiAssign, #[targetsSeq, arg1] =>
let targetIdents ← match targetsSeq with
| .seq _ .comma args => args.toList.mapM translateIdent
| _ => pure []
let variables := targetIdents.map fun name => (⟨.Local name, name.source⟩ : VariableMd)
let value ← translateStmtExpr arg1
return mkStmtExprMd (.Assign [target] value) src
return mkStmtExprMd (.Assign variables value) src
| q`Laurel.new, #[nameArg] =>
let name ← translateIdent nameArg
return mkStmtExprMd (.New name) src
Expand All @@ -263,7 +277,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do
| q`Laurel.call, #[arg0, argsSeq] =>
let callee ← translateStmtExpr arg0
let calleeName := match callee.val with
| .Identifier name => name
| .Variable (.Local name) => name
| _ => ""
let argsList ← match argsSeq with
| .seq _ .comma args => args.toList.mapM translateStmtExpr
Expand All @@ -285,7 +299,7 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do
let obj ← translateStmtExpr objArg
let field ← translateIdent fieldArg
let fieldSrc ← getArgFileRange fieldArg
return mkStmtExprMd (.FieldSelect obj field) fieldSrc
return mkStmtExprMd (.Variable (.Field obj field)) fieldSrc
| q`Laurel.while, #[condArg, invSeqArg, bodyArg] =>
let cond ← translateStmtExpr condArg
let invariants ← match invSeqArg with
Expand Down
2 changes: 1 addition & 1 deletion Strata/Languages/Laurel/Grammar/LaurelGrammar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module
-- Laurel dialect definition, loaded from LaurelGrammar.st
-- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system.
-- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st.
-- Last grammar change: added modifiesWildcard for `modifies *` and opaque keyword
-- Last grammar change: added multiAssign with parenthesized syntax
public import Strata.DDM.Integration.Lean
public meta import Strata.DDM.Integration.Lean

Expand Down
1 change: 1 addition & 0 deletions Strata/Languages/Laurel/Grammar/LaurelGrammar.st
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ op parenthesis (inner: StmtExpr): StmtExpr => "(" inner ")";

// Assignment
op assign (target: StmtExpr, value: StmtExpr): StmtExpr => @[prec(10)] target " := " value;
op multiAssign (targets: CommaSepBy Ident, value: StmtExpr): StmtExpr => @[prec(10)] "(" targets ") := " value;

// Binary operators
op add (lhs: StmtExpr, rhs: StmtExpr): StmtExpr => @[prec(60), leftassoc] lhs " + " rhs;
Expand Down
56 changes: 30 additions & 26 deletions Strata/Languages/Laurel/HeapParameterization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def collectExprMd (expr : StmtExprMd) : StateM AnalysisResult Unit := collectExp

def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do
match _: expr with
| .FieldSelect target _ =>
| .Variable (.Field target _) =>
modify fun s => { s with readsHeapDirectly := true }; collectExprMd target
| .InstanceCall target _ args => collectExprMd target; for a in args do collectExprMd a
| .StaticCall callee args => modify fun s => { s with callees := callee :: s.callees }; for a in args do collectExprMd a
Expand All @@ -69,11 +69,12 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do
| .Assign assignTargets v =>
-- Check if any target is a field assignment (heap write)
for ⟨assignTarget, _⟩ in assignTargets.attach do
match assignTarget.val with
| .FieldSelect _ _ =>
match ht : assignTarget.val with
| .Field target _ =>
have : sizeOf target < sizeOf assignTarget := Variable.sizeOf_field_target assignTarget ht
modify fun s => { s with writesHeapDirectly := true }
| _ => pure ()
collectExprMd assignTarget
collectExprMd target
| .Local _ => pure ()
collectExprMd v
| .PureFieldUpdate t _ v => collectExprMd t; collectExprMd v
| .PrimitiveOp _ args => for a in args do collectExprMd a
Expand Down Expand Up @@ -237,6 +238,8 @@ def freshVarName : TransformM Identifier := do

/-- Helper to wrap a StmtExpr into StmtExprMd with empty metadata -/
private def mkMd (e : StmtExpr) : StmtExprMd := { val := e, source := none }
/-- Helper to wrap a Variable into VariableMd with empty metadata -/
private def mkMd' (v : Variable) : VariableMd := { val := v, source := none }

/--
Resolve the owning composite type name for a field access by computing the target expression's type.
Expand All @@ -260,12 +263,12 @@ where
recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do
let ⟨expr, source⟩ := exprMd
match _h : expr with
| .FieldSelect selectTarget fieldName => do
| .Variable (.Field selectTarget fieldName) => do
let some qualifiedName := resolveQualifiedFieldName model fieldName
| return ⟨ .Hole, source ⟩

let valTy := (model.get fieldName).getType
let readExpr := ⟨ .StaticCall "readField" [mkMd (.Identifier heapVar), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩
let readExpr := ⟨ .StaticCall "readField" [mkMd (.Variable (.Local heapVar)), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩
-- Unwrap Box: apply the appropriate destructor
recordBoxConstructor model valTy.val
return mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr]
Expand All @@ -278,13 +281,13 @@ where
let freshVar ← freshVarName
let varDecl := mkMd (.LocalVariable freshVar (computeExprType model exprMd) none)
let callWithHeap := ⟨ .Assign
[mkMd (.Identifier heapVar), mkMd (.Identifier freshVar)]
(⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩), source ⟩
return ⟨ .Block [varDecl, callWithHeap, mkMd (.Identifier freshVar)] none, source ⟩
[mkMd' (.Local heapVar), mkMd' (.Local freshVar)]
(⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩), source ⟩
return ⟨ .Block [varDecl, callWithHeap, mkMd (.Variable (.Local freshVar))] none, source ⟩
else
return ⟨ .Assign [mkMd (.Identifier heapVar)] (⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩), source ⟩
return ⟨ .Assign [mkMd' (.Local heapVar)] (⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩), source ⟩
else if calleeReadsHeap then
return ⟨ .StaticCall callee (mkMd (.Identifier heapVar) :: args'), source ⟩
return ⟨ .StaticCall callee (mkMd (.Variable (.Local heapVar)) :: args'), source ⟩
else
return ⟨ .StaticCall callee args', source ⟩
| .InstanceCall callTarget callee args =>
Expand Down Expand Up @@ -318,7 +321,7 @@ where
return ⟨ .Return v', source ⟩
| .Assign targets v =>
match targets with
| [⟨.FieldSelect target fieldName, _⟩] =>
| [⟨.Field target fieldName, _⟩] =>
let some qualifiedName := resolveQualifiedFieldName model fieldName
| return ⟨ .Hole, source ⟩
let valTy := (model.get fieldName).getType
Expand All @@ -327,21 +330,21 @@ where
-- Wrap value in Box constructor
recordBoxConstructor model valTy.val
let boxedVal := mkMd <| .StaticCall (boxConstructorName model valTy.val) [v']
let heapAssign := ⟨ .Assign [mkMd (.Identifier heapVar)]
(mkMd (.StaticCall "updateField" [mkMd (.Identifier heapVar), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩
let heapAssign := ⟨ .Assign [mkMd' (.Local heapVar)]
(mkMd (.StaticCall "updateField" [mkMd (.Variable (.Local heapVar)), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩
if valueUsed then
return ⟨ .Block [heapAssign, v'] none, source ⟩
else
return heapAssign
| [fieldSelectMd] =>
let tgt' ← recurse fieldSelectMd
return ⟨ .Assign [tgt'] (← recurse v), source ⟩
| [] =>
return ⟨ .Assign [] (← recurse v), source ⟩
| tgt :: rest =>
let tgt' ← recurse tgt
let targets' ← rest.mapM (recurse ·)
return ⟨ .Assign (tgt' :: targets') (← recurse v), source ⟩
| _ =>
let targets' ← targets.attach.mapM fun ⟨vm, hmem⟩ => do
match hvm : vm.val with
| .Local _ => pure vm
| .Field target fieldName =>
have _h1 : sizeOf target < sizeOf vm := Variable.sizeOf_field_target vm hvm
have _h2 : sizeOf vm < sizeOf targets := List.sizeOf_lt_of_mem hmem
pure ⟨.Field (← recurse target) fieldName, vm.source⟩
return ⟨ .Assign targets' (← recurse v), source ⟩
| .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse t) f (← recurse v), source ⟩
| .PrimitiveOp op args =>
let args' ← args.mapM (recurse ·)
Expand Down Expand Up @@ -385,6 +388,7 @@ where
| .ContractOf ty f => return ⟨ .ContractOf ty (← recurse f), source ⟩
| _ => return exprMd
termination_by sizeOf exprMd
decreasing_by all_goals simp_wf; all_goals (try omega); all_goals (try term_by_mem); all_goals (try simp_all); all_goals omega

def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do
let heapName : Identifier := "$heap"
Expand All @@ -408,15 +412,15 @@ def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : Transform
let body' ← match proc.body with
| .Transparent bodyExpr =>
-- First assign $heap_in to $heap, then transform body using $heap
let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName)))
let assignHeap := mkMd (.Assign [mkMd' (.Local heapName)] (mkMd (.Variable (.Local heapInName))))
let bodyExpr' ← heapTransformExpr heapName model bodyExpr bodyValueIsUsed
pure (.Transparent (mkMd (.Block [assignHeap, bodyExpr'] none)))
| .Opaque postconds impl modif =>
-- Postconditions use $heap (the output state)
let postconds' ← postconds.mapM (·.mapM (heapTransformExpr heapName model))
let impl' ← match impl with
| some implExpr =>
let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName)))
let assignHeap := mkMd (.Assign [mkMd' (.Local heapName)] (mkMd (.Variable (.Local heapInName))))
let implExpr' ← heapTransformExpr heapName model implExpr bodyValueIsUsed
pure (some (mkMd (.Block [assignHeap, implExpr'] none)))
| none => pure none
Expand Down
Loading
Loading