From 506371d45d255dcf4a7efa38fbb1048e72f6e81a Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Mon, 18 May 2026 11:36:21 -0400 Subject: [PATCH 01/71] Refactor interproc facts domain --- compiler/ast/expr.go | 11 + compiler/cfg/analysis/dominators.go | 138 +-- compiler/cfg/graph.go | 54 +- compiler/cfg/graph_test.go | 28 +- compiler/cfg/local_functions.go | 36 + compiler/cfg/types.go | 7 + compiler/check/api/doc.go | 10 +- compiler/check/api/env.go | 141 +-- compiler/check/api/facts.go | 74 +- compiler/check/api/facts_test.go | 45 +- compiler/check/api/keys.go | 15 +- compiler/check/api/keys_test.go | 18 +- compiler/check/api/parents.go | 22 +- compiler/check/api/parents_test.go | 28 +- compiler/check/api/result.go | 14 +- compiler/check/api/store.go | 20 +- compiler/check/api/store_attach.go | 10 +- compiler/check/callsite/callee_symbols.go | 8 +- compiler/check/callsite/candidates.go | 126 ++- compiler/check/callsite/canonical_symbol.go | 14 +- compiler/check/checker.go | 24 +- .../check/erreffect/error_return_infer.go | 63 +- .../erreffect/error_return_infer_test.go | 35 + compiler/check/flowbuild/assign/emit.go | 14 +- compiler/check/flowbuild/assign/emit_test.go | 2 +- compiler/check/flowbuild/assign/infer.go | 5 +- .../flowbuild/assign/structured_overlay.go | 7 +- compiler/check/flowbuild/core/context.go | 3 + compiler/check/flowbuild/keyscoll/keyscoll.go | 53 +- .../check/flowbuild/keyscoll/keyscoll_test.go | 32 +- compiler/check/flowbuild/run.go | 2 +- compiler/check/infer/interproc/doc.go | 2 +- compiler/check/infer/interproc/postflow.go | 137 ++- compiler/check/infer/interproc/writer.go | 25 +- compiler/check/infer/interproc/writer_test.go | 16 +- compiler/check/infer/nested/processor.go | 67 +- compiler/check/infer/nested/processor_test.go | 2 +- .../check/infer/paramhints/param_hints.go | 12 +- .../infer/paramhints/param_hints_test.go | 4 +- compiler/check/infer/return/infer.go | 141 ++- compiler/check/infer/return/infer_test.go | 65 +- 
.../check/infer/return/overlay_pipeline.go | 202 ++-- compiler/check/infer/return/scc.go | 42 +- compiler/check/nested/enrich.go | 12 +- compiler/check/nested/enrich_test.go | 8 +- compiler/check/phase/flow.go | 13 +- compiler/check/phase/flow_test.go | 6 +- compiler/check/phase/narrow.go | 5 +- compiler/check/phase/scope.go | 31 +- compiler/check/phase/types.go | 106 +- compiler/check/phase/types_test.go | 30 +- compiler/check/pipeline/driver.go | 29 +- compiler/check/pipeline/runner.go | 66 +- compiler/check/pipeline/runner_stages.go | 29 +- .../check/returns/callgraph_symbol_test.go | 4 +- compiler/check/returns/doc.go | 2 +- compiler/check/returns/domain_law_test.go | 203 ++++ compiler/check/returns/equal.go | 19 +- compiler/check/returns/equal_test.go | 68 +- compiler/check/returns/function_facts.go | 213 +--- compiler/check/returns/join.go | 569 ++++++++++- compiler/check/returns/join_test.go | 115 +++ compiler/check/returns/kernel.go | 121 +-- compiler/check/returns/kernel_test.go | 236 ++--- compiler/check/returns/scc.go | 2 +- compiler/check/returns/signature.go | 2 +- compiler/check/returns/types.go | 21 +- compiler/check/returns/widen.go | 909 ++++++++++++++++-- compiler/check/returns/widen_test.go | 353 ++++++- compiler/check/session.go | 10 +- compiler/check/siblings/doc.go | 2 +- compiler/check/siblings/overlay.go | 16 +- compiler/check/siblings/overlay_test.go | 16 +- compiler/check/siblings/siblings.go | 15 +- compiler/check/siblings/siblings_test.go | 5 +- compiler/check/store/doc.go | 6 +- compiler/check/store/snapshot_inputs.go | 261 +++++ compiler/check/store/store.go | 186 ++-- compiler/check/store/store_test.go | 206 +++- compiler/check/synth/ops/logical_test.go | 7 +- compiler/check/synth/phase/core/params.go | 12 + .../synth/phase/extract/callback_env_infer.go | 4 +- .../phase/extract/callback_env_infer_test.go | 4 +- compiler/check/synth/phase/extract/deps.go | 21 +- compiler/check/synth/phase/extract/doc.go | 2 +- 
compiler/check/synth/phase/extract/expr.go | 11 +- .../check/synth/phase/extract/function.go | 264 ++--- .../phase/extract/manifest_enrich_test.go | 4 +- .../synth/phase/extract/named_function.go | 54 +- .../phase/extract/named_function_test.go | 2 +- .../check/synth/phase/extract/synthesizer.go | 110 ++- .../synth/phase/extract/synthesizer_test.go | 4 +- .../tests/errors/error_correlation_test.go | 10 +- .../tests/flow/fixpoint_unification_test.go | 18 +- .../inference/closure_return_infer_test.go | 24 +- ...nel_select_helper_return_narrowing_test.go | 8 +- .../contract_open_dynamic_return_test.go | 6 +- .../regression/false_positives_unit_test.go | 35 + .../imported_record_helper_param_test.go | 105 ++ ...ocal_function_narrow_return_repair_test.go | 6 +- .../regression/logical_or_soundness_test.go | 25 + .../param_hint_depth_convergence_test.go | 137 +++ .../regression/wippy_false_positives_test.go | 7 +- .../wippy_sorted_keys_param_hints_test.go | 5 +- .../contract.lua | 12 +- .../contract.lua | 12 +- .../dynamic-registry-renderer-guard/main.lua | 6 +- .../realworld/sql-repository/repository.lua | 2 +- .../deadlock-compiler-lua/manifest.json | 6 + .../deadlock-dataflow-node/manifest.json | 6 + types/constraint/condition.go | 2 +- types/flow/edge.go | 3 + types/flow/numeric.go | 54 +- types/flow/numeric/state.go | 167 +++- types/flow/propagate/propagate.go | 34 +- types/flow/solver.go | 43 +- types/flow/solver_helpers.go | 5 +- types/flow/transfer.go | 3 + types/typ/function.go | 11 + types/typ/join/join.go | 46 +- types/typ/policy.go | 28 +- types/typ/policy_test.go | 7 +- types/typ/rebuild.go | 38 + types/typ/record_test.go | 23 + types/typ/unwrap/unwrap.go | 343 ++++--- 125 files changed, 5191 insertions(+), 2214 deletions(-) create mode 100644 compiler/cfg/local_functions.go create mode 100644 compiler/check/erreffect/error_return_infer_test.go create mode 100644 compiler/check/returns/domain_law_test.go create mode 100644 
compiler/check/store/snapshot_inputs.go create mode 100644 compiler/check/tests/regression/imported_record_helper_param_test.go create mode 100644 compiler/check/tests/regression/logical_or_soundness_test.go create mode 100644 testdata/fixtures/regression/deadlock-compiler-lua/manifest.json create mode 100644 testdata/fixtures/regression/deadlock-dataflow-node/manifest.json diff --git a/compiler/ast/expr.go b/compiler/ast/expr.go index 778df614..6bf7da19 100644 --- a/compiler/ast/expr.go +++ b/compiler/ast/expr.go @@ -89,6 +89,17 @@ type FuncCallExpr struct { AdjustRet bool // Whether return count should be adjusted } +// CanProduceMultipleValues reports whether expr can expand to multiple Lua values +// when it appears in the final slot of an expression list. +func CanProduceMultipleValues(expr Expr) bool { + switch expr.(type) { + case *FuncCallExpr, *Comma3Expr: + return true + default: + return false + } +} + // LogicalOpExpr represents a logical operator (and, or). type LogicalOpExpr struct { ExprBase diff --git a/compiler/cfg/analysis/dominators.go b/compiler/cfg/analysis/dominators.go index 9b03fa76..e850acbe 100644 --- a/compiler/cfg/analysis/dominators.go +++ b/compiler/cfg/analysis/dominators.go @@ -33,6 +33,13 @@ type rpoReader interface { RPOReadOnly() []basecfg.Point } +type immediateDominatorData struct { + rpo []basecfg.Point + rpoNum []int + idomByPoint []basecfg.Point + hasIDom []bool +} + func predecessorsOf(g basecfg.Graph, point basecfg.Point) []basecfg.Point { if direct, ok := g.(predecessorsReader); ok { return direct.PredecessorsReadOnly(point) @@ -57,33 +64,28 @@ func rpoOf(g basecfg.Graph) []basecfg.Point { return g.RPO() } -// ComputeDominators computes immediate dominators and the dominator tree. -// Uses the Cooper-Harvey-Kennedy algorithm with RPO iteration. 
-func ComputeDominators(g basecfg.Graph) (idom map[basecfg.Point]basecfg.Point, domTree map[basecfg.Point][]basecfg.Point) { +func computeImmediateDominatorData(g basecfg.Graph) immediateDominatorData { rpo := rpoOf(g) if len(rpo) == 0 { - return make(map[basecfg.Point]basecfg.Point), make(map[basecfg.Point][]basecfg.Point) + return immediateDominatorData{} } graphSize := g.Size() - if graphSize == 0 { - return make(map[basecfg.Point]basecfg.Point), make(map[basecfg.Point][]basecfg.Point) + return immediateDominatorData{} } - // Build RPO number lookup for intersection and deterministic sorting. rpoNum := make([]int, graphSize) for i, p := range rpo { if int(p) >= graphSize { continue } - rpoNum[p] = i } entry := g.Entry() if int(entry) >= graphSize { - return make(map[basecfg.Point]basecfg.Point), make(map[basecfg.Point][]basecfg.Point) + return immediateDominatorData{} } idomByPoint := make([]basecfg.Point, graphSize) @@ -91,96 +93,111 @@ func ComputeDominators(g basecfg.Graph) (idom map[basecfg.Point]basecfg.Point, d idomByPoint[entry] = entry hasIDom[entry] = true - // intersect finds the common dominator of two nodes - intersect := func(b1, b2 basecfg.Point) basecfg.Point { - finger1, finger2 := b1, b2 - - for finger1 != finger2 { - for rpoNum[finger1] > rpoNum[finger2] { - finger1 = idomByPoint[finger1] + intersect := func(pointA, pointB basecfg.Point) basecfg.Point { + fingerA, fingerB := pointA, pointB + for fingerA != fingerB { + for rpoNum[fingerA] > rpoNum[fingerB] { + fingerA = idomByPoint[fingerA] } - - for rpoNum[finger2] > rpoNum[finger1] { - finger2 = idomByPoint[finger2] + for rpoNum[fingerB] > rpoNum[fingerA] { + fingerB = idomByPoint[fingerB] } } - - return finger1 + return fingerA } - // Iterate until fixed point changed := true for changed { changed = false - for _, b := range rpo { - if b == entry { - continue - } - - if int(b) >= graphSize { + for _, block := range rpo { + if block == entry || int(block) >= graphSize { continue } - preds := 
predecessorsOf(g, b) + preds := predecessorsOf(g, block) if len(preds) == 0 { continue } - // Find first predecessor with defined idom - var newIdom basecfg.Point - + var newIDom basecfg.Point found := false - - for _, p := range preds { - if int(p) >= graphSize { + for _, pred := range preds { + predIdx := int(pred) + if predIdx >= graphSize { continue } - - if hasIDom[p] { - newIdom = p + if hasIDom[predIdx] { + newIDom = pred found = true - break } } - if !found { continue } - // Intersect with other defined predecessors - for _, p := range preds { - if p == newIdom { + for _, pred := range preds { + if pred == newIDom { continue } - - if int(p) >= graphSize { + predIdx := int(pred) + if predIdx >= graphSize { continue } - - if hasIDom[p] { - newIdom = intersect(p, newIdom) + if hasIDom[predIdx] { + newIDom = intersect(pred, newIDom) } } - if !hasIDom[b] || idomByPoint[b] != newIdom { - idomByPoint[b] = newIdom - hasIDom[b] = true + blockIdx := int(block) + if !hasIDom[blockIdx] || idomByPoint[blockIdx] != newIDom { + idomByPoint[blockIdx] = newIDom + hasIDom[blockIdx] = true changed = true } } } - idom = make(map[basecfg.Point]basecfg.Point, len(rpo)) - for _, point := range rpo { - if int(point) >= graphSize || !hasIDom[point] { + return immediateDominatorData{ + rpo: rpo, + rpoNum: rpoNum, + idomByPoint: idomByPoint, + hasIDom: hasIDom, + } +} + +func (d immediateDominatorData) asMap() map[basecfg.Point]basecfg.Point { + if len(d.rpo) == 0 { + return make(map[basecfg.Point]basecfg.Point) + } + idom := make(map[basecfg.Point]basecfg.Point, len(d.rpo)) + for _, point := range d.rpo { + idx := int(point) + if idx >= len(d.hasIDom) || !d.hasIDom[idx] { continue } + idom[point] = d.idomByPoint[idx] + } + return idom +} + +// ComputeImmediateDominators computes only the immediate-dominator map. +// +// Use this when callers only need dominance predicates. It avoids building the +// dominator tree, which is meaningful allocation in hot type-checking paths. 
+func ComputeImmediateDominators(g basecfg.Graph) map[basecfg.Point]basecfg.Point { + return computeImmediateDominatorData(g).asMap() +} - idom[point] = idomByPoint[point] +// ComputeDominators computes immediate dominators and the dominator tree. +// Uses the Cooper-Harvey-Kennedy algorithm with RPO iteration. +func ComputeDominators(g basecfg.Graph) (idom map[basecfg.Point]basecfg.Point, domTree map[basecfg.Point][]basecfg.Point) { + data := computeImmediateDominatorData(g) + if len(data.rpo) == 0 { + return make(map[basecfg.Point]basecfg.Point), make(map[basecfg.Point][]basecfg.Point) } - // Build dominator tree from idom + idom = data.asMap() domTree = make(map[basecfg.Point][]basecfg.Point, len(idom)) for n, dom := range idom { @@ -192,10 +209,10 @@ func ComputeDominators(g basecfg.Graph) (idom map[basecfg.Point]basecfg.Point, d // Sort children for deterministic order. for p := range domTree { slices.SortFunc(domTree[p], func(a, b basecfg.Point) int { - if rpoNum[a] < rpoNum[b] { + if data.rpoNum[a] < data.rpoNum[b] { return -1 } - if rpoNum[a] > rpoNum[b] { + if data.rpoNum[a] > data.rpoNum[b] { return 1 } return 0 @@ -599,6 +616,11 @@ func ComputePostDominators(graph basecfg.Graph) (map[basecfg.Point]basecfg.Point return ComputeDominators(&reversedGraph{g: graph}) } +// ComputeImmediatePostDominators computes only the immediate post-dominator map. +func ComputeImmediatePostDominators(graph basecfg.Graph) map[basecfg.Point]basecfg.Point { + return ComputeImmediateDominators(&reversedGraph{g: graph}) +} + // PostDominates returns true if a post-dominates b (a is on every path from b to exit). 
func PostDominates(postIdom map[basecfg.Point]basecfg.Point, pointA, pointB basecfg.Point) bool { return Dominates(postIdom, pointA, pointB) diff --git a/compiler/cfg/graph.go b/compiler/cfg/graph.go index c00a10a0..1f94efd5 100644 --- a/compiler/cfg/graph.go +++ b/compiler/cfg/graph.go @@ -24,6 +24,7 @@ type Graph struct { orderedBranchPoints []Point orderedFuncDefPoints []Point orderedTypeDefPoints []Point + localFunctionAssigns []LocalFunctionAssignment // Binding table (AST ident -> symbol, populated before CFG build) bindings *bind.BindingTable @@ -217,6 +218,7 @@ func BuildWithBindings(fn *ast.FunctionExpr, bindings *bind.BindingTable) *Graph size := b.Cfg.Size() pointIdx := buildPointIndex(b.Info, size) infoByPoint := denseNodeInfoByPoint(b.Info, size) + localFunctionAssigns := buildLocalFunctionAssignments(infoByPoint, pointIdx.assign) if len(visibleVersionByPoint) == 0 { visibleVersionByPoint = denseVisibleVersionByPoint(visibleVersion, size) @@ -234,6 +236,7 @@ func BuildWithBindings(fn *ast.FunctionExpr, bindings *bind.BindingTable) *Graph orderedBranchPoints: pointIdx.branch, orderedFuncDefPoints: pointIdx.funcDef, orderedTypeDefPoints: pointIdx.typeDef, + localFunctionAssigns: localFunctionAssigns, bindings: bindings, phiNodes: b.PhiNodes, visibleVersion: visibleVersion, @@ -299,6 +302,7 @@ func BuildBlock(stmts []ast.Stmt, globals ...string) *Graph { size := b.Cfg.Size() pointIdx := buildPointIndex(b.Info, size) infoByPoint := denseNodeInfoByPoint(b.Info, size) + localFunctionAssigns := buildLocalFunctionAssignments(infoByPoint, pointIdx.assign) if len(visibleVersionByPoint) == 0 { visibleVersionByPoint = denseVisibleVersionByPoint(visibleVersion, size) @@ -316,6 +320,7 @@ func BuildBlock(stmts []ast.Stmt, globals ...string) *Graph { orderedBranchPoints: pointIdx.branch, orderedFuncDefPoints: pointIdx.funcDef, orderedTypeDefPoints: pointIdx.typeDef, + localFunctionAssigns: localFunctionAssigns, bindings: bindings, phiNodes: b.PhiNodes, visibleVersion: 
visibleVersion, @@ -735,6 +740,14 @@ func (g *Graph) NestedFunctions() []NestedFunc { return g.nested } +// LocalFunctionAssignments returns local identifiers bound directly to function literals. +func (g *Graph) LocalFunctionAssignments() []LocalFunctionAssignment { + if g == nil { + return nil + } + return g.localFunctionAssigns +} + // CFG delegated methods. // Node returns the base CFG node at point p. @@ -1207,13 +1220,40 @@ func (g *Graph) EachAliasSymbol(targetSym basecfg.SymbolID, fn func(basecfg.Symb return } - seen := make(map[basecfg.SymbolID]struct{}, 4) + var smallSeen [8]basecfg.SymbolID + seenCount := 0 + var seen map[basecfg.SymbolID]struct{} + remember := func(sym basecfg.SymbolID) bool { + if seen != nil { + if _, ok := seen[sym]; ok { + return false + } + seen[sym] = struct{}{} + return true + } + for i := 0; i < seenCount; i++ { + if smallSeen[i] == sym { + return false + } + } + if seenCount < len(smallSeen) { + smallSeen[seenCount] = sym + seenCount++ + return true + } + seen = make(map[basecfg.SymbolID]struct{}, len(smallSeen)+1) + for i := 0; i < seenCount; i++ { + seen[smallSeen[i]] = struct{}{} + } + seen[sym] = struct{}{} + return true + } + current := targetSym for current != 0 { - if _, ok := seen[current]; ok { + if !remember(current) { return } - seen[current] = struct{}{} if fn(current) { return @@ -1311,6 +1351,14 @@ func (g *Graph) SymbolKind(sym basecfg.SymbolID) (basecfg.SymbolKind, bool) { return kind, ok } +// SymbolCount returns the number of symbols tracked by the graph. +func (g *Graph) SymbolCount() int { + if g == nil { + return 0 + } + return len(g.symbolKinds) +} + // HasScopeTracking returns true if scope visibility was computed during build. 
func (g *Graph) HasScopeTracking() bool { return g != nil && g.symbolScope != nil diff --git a/compiler/cfg/graph_test.go b/compiler/cfg/graph_test.go index 3a4235ed..b92b68e8 100644 --- a/compiler/cfg/graph_test.go +++ b/compiler/cfg/graph_test.go @@ -587,17 +587,16 @@ func TestGraph_CFGMethods(t *testing.T) { func TestGraph_NestedFunctions(t *testing.T) { t.Parallel() + nestedFn := &ast.FunctionExpr{ + ParList: &ast.ParList{Names: []string{"a"}}, + Stmts: []ast.Stmt{}, + } fn := &ast.FunctionExpr{ ParList: &ast.ParList{}, Stmts: []ast.Stmt{ &ast.LocalAssignStmt{ Names: []string{"fn"}, - Exprs: []ast.Expr{ - &ast.FunctionExpr{ - ParList: &ast.ParList{Names: []string{"a"}}, - Stmts: []ast.Stmt{}, - }, - }, + Exprs: []ast.Expr{nestedFn}, }, }, } @@ -611,6 +610,20 @@ func TestGraph_NestedFunctions(t *testing.T) { if len(nested) != 1 { t.Errorf("Expected 1 nested function, got %d", len(nested)) } + + localFns := g.LocalFunctionAssignments() + if len(localFns) != 1 { + t.Fatalf("Expected 1 local function assignment, got %d", len(localFns)) + } + if localFns[0].Name != "fn" { + t.Fatalf("LocalFunctionAssignments()[0].Name = %q, want fn", localFns[0].Name) + } + if localFns[0].Symbol == 0 { + t.Fatal("LocalFunctionAssignments()[0].Symbol should be non-zero") + } + if localFns[0].Func != nestedFn { + t.Fatal("LocalFunctionAssignments()[0].Func should be the assigned function literal") + } } // TestGraph_PopulateSymbols tests symbol population. 
@@ -974,6 +987,9 @@ func TestGraph_SymbolKind(t *testing.T) { if g == nil { t.Fatal("Build should return graph") } + if got, want := g.SymbolCount(), 4; got != want { + t.Fatalf("SymbolCount() = %d, want %d", got, want) + } // Check parameter symbols are basecfg.SymbolParam paramSymbols := g.ParamSymbols() diff --git a/compiler/cfg/local_functions.go b/compiler/cfg/local_functions.go new file mode 100644 index 00000000..f3c39e95 --- /dev/null +++ b/compiler/cfg/local_functions.go @@ -0,0 +1,36 @@ +package cfg + +import "github.com/wippyai/go-lua/compiler/ast" + +func buildLocalFunctionAssignments(infoByPoint []NodeInfo, assignPoints []Point) []LocalFunctionAssignment { + if len(infoByPoint) == 0 || len(assignPoints) == 0 { + return nil + } + + var out []LocalFunctionAssignment + for _, p := range assignPoints { + idx := int(p) + if idx < 0 || idx >= len(infoByPoint) { + continue + } + info, ok := infoByPoint[idx].(*AssignInfo) + if !ok || info == nil || !info.IsLocal || len(info.Targets) == 0 { + continue + } + info.EachTargetSource(func(_ int, target AssignTarget, source ast.Expr) { + if target.Kind != TargetIdent || target.Symbol == 0 { + return + } + fn, ok := source.(*ast.FunctionExpr) + if !ok || fn == nil { + return + } + out = append(out, LocalFunctionAssignment{ + Symbol: target.Symbol, + Name: target.Name, + Func: fn, + }) + }) + } + return out +} diff --git a/compiler/cfg/types.go b/compiler/cfg/types.go index 8e86289e..d19eddc0 100644 --- a/compiler/cfg/types.go +++ b/compiler/cfg/types.go @@ -133,6 +133,13 @@ type AssignInfo struct { singleTargetVersion [1]Version } +// LocalFunctionAssignment describes `local f = function(...) ... end`. +type LocalFunctionAssignment struct { + Symbol basecfg.SymbolID + Name string + Func *ast.FunctionExpr +} + func (*AssignInfo) nodeInfo() {} // Kind returns the node kind for AssignInfo. 
diff --git a/compiler/check/api/doc.go b/compiler/check/api/doc.go index 770e3f45..7d893a88 100644 --- a/compiler/check/api/doc.go +++ b/compiler/check/api/doc.go @@ -12,7 +12,7 @@ // - [GraphStore]: Access to built CFGs by ID // - [ParentScopes]: Parent scope lookup for nested functions // - [SnapshotStore]: Stable interprocedural fact snapshots -// - [StoreView]: Read-only view combining all above +// - [StoreReader]: Read contract combining the immutable stores above // - [IterationStore]: Adds mutation for fixpoint iteration // // These interfaces allow different phases to declare their dependencies and @@ -31,14 +31,14 @@ // function graph: // // - [FunctionFacts]: Canonical per-function return/signature facts -// - [ReturnSummaries]: Inferred return types by function symbol -// - [NarrowReturnSummaries]: Post-narrowing return types // - [ParamHints]: Parameter types inferred from call sites -// - [FuncTypes]: Canonical types for local function symbols // - [LiteralSigs]: Signatures for anonymous function literals +// - [CapturedTypes]: Flow-derived types for captured variables // - [CapturedFieldAssigns]: Field assignments to captured variables +// - [CapturedContainerMutations]: Container writes to captured variables +// - [ConstructorFields]: Instance fields collected from constructors // -// Facts are computed incrementally and stored per (graph, parent-scope) pair. +// Facts are emitted as canonical deltas and stored per (graph, parent-scope) pair. // The [GraphKey] type provides the canonical key for this lookup. // // # Function References diff --git a/compiler/check/api/env.go b/compiler/check/api/env.go index 4ff2c471..ca5ffb15 100644 --- a/compiler/check/api/env.go +++ b/compiler/check/api/env.go @@ -1,6 +1,6 @@ // env.go defines phase-typed environments for synthesis. // DeclaredEnv is used pre-flow; NarrowEnv is used post-flow. -// This split prevents pre-flow return summaries from being accessed in narrowing. 
+// This split keeps declared and flow-refined function facts phase-explicit. package api import ( @@ -37,7 +37,7 @@ func (p Phase) String() string { } // BaseEnv is the shared environment interface for synthesis. -// It intentionally excludes return summaries to prevent cross-phase misuse. +// It intentionally excludes FunctionFacts to prevent cross-phase misuse. type BaseEnv interface { Phase() Phase Graph() cfg.VersionedGraph @@ -53,16 +53,16 @@ type BaseEnv interface { WithGlobalOverlay(overlay map[string]typ.Type) BaseEnv } -// DeclaredEnv provides access to pre-flow return summaries. +// DeclaredEnv provides canonical function facts in declared phase. type DeclaredEnv interface { BaseEnv - ReturnSummaries() map[cfg.SymbolID][]typ.Type + FunctionFacts() FunctionFacts } -// NarrowEnv provides access to post-flow return summaries. +// NarrowEnv provides canonical function facts in narrowing phase. type NarrowEnv interface { BaseEnv - NarrowReturnSummaries() map[cfg.SymbolID][]typ.Type + FunctionFacts() FunctionFacts } type envBase struct { @@ -77,8 +77,8 @@ type envBase struct { globalTypes map[string]typ.Type } -func (b *envBase) withGlobalOverlay(overlay map[string]typ.Type) *envBase { - if b == nil || len(overlay) == 0 { +func (b envBase) withGlobalOverlay(overlay map[string]typ.Type) envBase { + if len(overlay) == 0 { return b } merged := make(map[string]typ.Type, len(b.globalTypes)+len(overlay)) @@ -88,25 +88,28 @@ func (b *envBase) withGlobalOverlay(overlay map[string]typ.Type) *envBase { for k, v := range overlay { merged[k] = v } - next := *b + next := b next.globalTypes = merged - return &next + return next } type envCommon struct { - base *envBase + base envBase } -func (c *envCommon) withGlobalOverlay(overlay map[string]typ.Type) *envCommon { +func (c *envCommon) withGlobalOverlay(overlay map[string]typ.Type) envCommon { if c == nil || len(overlay) == 0 { - return c + if c == nil { + return envCommon{} + } + return *c } - return &envCommon{base: 
c.base.withGlobalOverlay(overlay)} + return envCommon{base: c.base.withGlobalOverlay(overlay)} } // Phase returns the current checking phase. func (c *envCommon) Phase() Phase { - if c == nil || c.base == nil { + if c == nil { return PhaseScopeCompute } return c.base.phase @@ -114,7 +117,7 @@ func (c *envCommon) Phase() Phase { // Graph returns the versioned CFG graph. func (c *envCommon) Graph() cfg.VersionedGraph { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.graph @@ -122,7 +125,7 @@ func (c *envCommon) Graph() cfg.VersionedGraph { // Types returns the type facts provider. func (c *envCommon) Types() flow.TypeFacts { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.types @@ -130,7 +133,7 @@ func (c *envCommon) Types() flow.TypeFacts { // Consts returns the flow solution for constant value lookup. func (c *envCommon) Consts() *flow.Solution { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.solution @@ -138,7 +141,7 @@ func (c *envCommon) Consts() *flow.Solution { // Refinements returns the refinement facts provider. func (c *envCommon) Refinements() RefinementFacts { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.refinements @@ -146,7 +149,7 @@ func (c *envCommon) Refinements() RefinementFacts { // TypeNames returns the scope state for type name resolution. func (c *envCommon) TypeNames() *scope.State { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.typeNames @@ -154,7 +157,7 @@ func (c *envCommon) TypeNames() *scope.State { // Bindings returns the binding table for AST-based symbol resolution. func (c *envCommon) Bindings() *bind.BindingTable { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.bindings @@ -162,7 +165,7 @@ func (c *envCommon) Bindings() *bind.BindingTable { // ModuleAliases returns the module alias map (symbol -> module path). 
func (c *envCommon) ModuleAliases() map[cfg.SymbolID]string { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.moduleAliases @@ -170,7 +173,7 @@ func (c *envCommon) ModuleAliases() map[cfg.SymbolID]string { // ModuleAlias returns the module path for a symbol assigned from require(). func (c *envCommon) ModuleAlias(sym cfg.SymbolID) string { - if c == nil || c.base == nil || c.base.moduleAliases == nil { + if c == nil || c.base.moduleAliases == nil { return "" } return c.base.moduleAliases[sym] @@ -178,7 +181,7 @@ func (c *envCommon) ModuleAlias(sym cfg.SymbolID) string { // GlobalType returns the global type for a symbol if it is a confirmed global. func (c *envCommon) GlobalType(sym cfg.SymbolID) (typ.Type, bool) { - if c == nil || c.base == nil || c.base.globalTypes == nil || sym == 0 { + if c == nil || c.base.globalTypes == nil || sym == 0 { return nil, false } if c.base.bindings == nil { @@ -198,7 +201,7 @@ func (c *envCommon) GlobalType(sym cfg.SymbolID) (typ.Type, bool) { // GlobalTypes returns the global type map. func (c *envCommon) GlobalTypes() map[string]typ.Type { - if c == nil || c.base == nil { + if c == nil { return nil } return c.base.globalTypes @@ -206,14 +209,14 @@ func (c *envCommon) GlobalTypes() map[string]typ.Type { // DeclaredEnvImpl is the concrete declared-phase environment. type DeclaredEnvImpl struct { - *envCommon - returnSummaries map[cfg.SymbolID][]typ.Type + envCommon + functionFacts FunctionFacts } // NarrowEnvImpl is the concrete narrowing-phase environment. type NarrowEnvImpl struct { - *envCommon - narrowReturns map[cfg.SymbolID][]typ.Type + envCommon + functionFacts FunctionFacts } var _ BaseEnv = (*DeclaredEnvImpl)(nil) @@ -423,20 +426,20 @@ func (e *NarrowEnvImpl) WithGlobalOverlay(overlay map[string]typ.Type) BaseEnv { return &next } -// ReturnSummaries returns the return type summaries for sibling functions. 
-func (e *DeclaredEnvImpl) ReturnSummaries() map[cfg.SymbolID][]typ.Type { +// FunctionFacts returns canonical function facts for sibling functions. +func (e *DeclaredEnvImpl) FunctionFacts() FunctionFacts { if e == nil { return nil } - return e.returnSummaries + return e.functionFacts } -// NarrowReturnSummaries returns post-flow return summaries for narrowing. -func (e *NarrowEnvImpl) NarrowReturnSummaries() map[cfg.SymbolID][]typ.Type { +// FunctionFacts returns canonical function facts for sibling functions. +func (e *NarrowEnvImpl) FunctionFacts() FunctionFacts { if e == nil { return nil } - return e.narrowReturns + return e.functionFacts } // DeclaredEnvConfig holds inputs for building a declared-phase Env. @@ -449,25 +452,23 @@ type DeclaredEnvConfig struct { RefinementStore RefinementStore ModuleAliases map[cfg.SymbolID]string GlobalTypes map[string]typ.Type - SiblingTypes map[cfg.SymbolID]typ.Type LiteralTypes flow.DeclaredTypes - ReturnSummaries map[cfg.SymbolID][]typ.Type + FunctionFacts FunctionFacts } // NarrowEnvConfig holds inputs for building a narrowing-phase Env. 
type NarrowEnvConfig struct { - Graph cfg.VersionedGraph - Bindings *bind.BindingTable - DeclaredTypes flow.DeclaredTypes - AnnotatedVars map[cfg.SymbolID]bool - Solution *flow.Solution - BaseScope *scope.State - RefinementStore RefinementStore - ModuleAliases map[cfg.SymbolID]string - GlobalTypes map[string]typ.Type - SiblingTypes map[cfg.SymbolID]typ.Type - LiteralTypes flow.DeclaredTypes - NarrowReturnSummaries map[cfg.SymbolID][]typ.Type + Graph cfg.VersionedGraph + Bindings *bind.BindingTable + DeclaredTypes flow.DeclaredTypes + AnnotatedVars map[cfg.SymbolID]bool + Solution *flow.Solution + BaseScope *scope.State + RefinementStore RefinementStore + ModuleAliases map[cfg.SymbolID]string + GlobalTypes map[string]typ.Type + LiteralTypes flow.DeclaredTypes + FunctionFacts FunctionFacts } func newEnvBase( @@ -480,8 +481,8 @@ func newEnvBase( typeNames *scope.State, moduleAliases map[cfg.SymbolID]string, globalTypes map[string]typ.Type, -) *envBase { - return &envBase{ +) envBase { + return envBase{ phase: phase, graph: graph, bindings: bindings, @@ -503,14 +504,14 @@ func NewDeclaredEnv(cfg DeclaredEnvConfig) *DeclaredEnvImpl { PhaseScopeCompute, cfg.Graph, cfg.Bindings, - newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.SiblingTypes, cfg.LiteralTypes, cfg.AnnotatedVars, nil), + newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.FunctionFacts, cfg.LiteralTypes, cfg.AnnotatedVars, nil), nil, NewRefinementFacts(cfg.RefinementStore), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, ) - return &DeclaredEnvImpl{envCommon: &envCommon{base: base}, returnSummaries: cfg.ReturnSummaries} + return &DeclaredEnvImpl{envCommon: envCommon{base: base}, functionFacts: cfg.FunctionFacts} } // NewNarrowEnv creates a narrowing-phase Env. 
@@ -522,25 +523,25 @@ func NewNarrowEnv(cfg NarrowEnvConfig) *NarrowEnvImpl { PhaseNarrowing, cfg.Graph, cfg.Bindings, - newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.SiblingTypes, cfg.LiteralTypes, cfg.AnnotatedVars, cfg.Solution), + newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.FunctionFacts, cfg.LiteralTypes, cfg.AnnotatedVars, cfg.Solution), cfg.Solution, NewRefinementFacts(cfg.RefinementStore), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, ) - return &NarrowEnvImpl{envCommon: &envCommon{base: base}, narrowReturns: cfg.NarrowReturnSummaries} + return &NarrowEnvImpl{envCommon: envCommon{base: base}, functionFacts: cfg.FunctionFacts} } // ReturnInferenceEnvConfig holds inputs for return type inference. type ReturnInferenceEnvConfig struct { - Graph cfg.VersionedGraph - Bindings *bind.BindingTable - BaseScope *scope.State - DeclaredTypes flow.DeclaredTypes - GlobalTypes map[string]typ.Type - ModuleAliases map[cfg.SymbolID]string - ReturnSummaries map[cfg.SymbolID][]typ.Type + Graph cfg.VersionedGraph + Bindings *bind.BindingTable + BaseScope *scope.State + DeclaredTypes flow.DeclaredTypes + GlobalTypes map[string]typ.Type + ModuleAliases map[cfg.SymbolID]string + FunctionFacts FunctionFacts } // NewReturnInferenceEnv creates a declared-phase Env for return inference. @@ -552,21 +553,21 @@ func NewReturnInferenceEnv(cfg ReturnInferenceEnvConfig) *DeclaredEnvImpl { PhaseScopeCompute, cfg.Graph, cfg.Bindings, - newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, nil, nil, nil, nil), + newUnifiedTypeFacts(cfg.Graph, cfg.DeclaredTypes, cfg.FunctionFacts, nil, nil, nil), nil, NewRefinementFacts(nil), cfg.BaseScope, cfg.ModuleAliases, cfg.GlobalTypes, ) - return &DeclaredEnvImpl{envCommon: &envCommon{base: base}, returnSummaries: cfg.ReturnSummaries} + return &DeclaredEnvImpl{envCommon: envCommon{base: base}, functionFacts: cfg.FunctionFacts} } // unifiedTypeFacts implements flow.TypeFacts with layered type source lookup. 
type unifiedTypeFacts struct { graph cfg.VersionedGraph declaredTypes flow.DeclaredTypes - siblingTypes map[cfg.SymbolID]typ.Type + functionFacts FunctionFacts literalTypes flow.DeclaredTypes annotatedVars map[cfg.SymbolID]bool solution *flow.Solution @@ -575,7 +576,7 @@ type unifiedTypeFacts struct { func newUnifiedTypeFacts( graph cfg.VersionedGraph, declared flow.DeclaredTypes, - siblings map[cfg.SymbolID]typ.Type, + functionFacts FunctionFacts, literals flow.DeclaredTypes, annotated map[cfg.SymbolID]bool, solution *flow.Solution, @@ -583,7 +584,7 @@ func newUnifiedTypeFacts( return &unifiedTypeFacts{ graph: graph, declaredTypes: declared, - siblingTypes: siblings, + functionFacts: functionFacts, literalTypes: literals, annotatedVars: annotated, solution: solution, @@ -604,8 +605,8 @@ func (f *unifiedTypeFacts) DeclaredAt(p cfg.Point, sym cfg.SymbolID) flow.TypedV } } } - if f.siblingTypes != nil { - if t, ok := f.siblingTypes[sym]; ok && t != nil { + if f.functionFacts != nil { + if t := f.functionFacts.FunctionType(sym); t != nil { return f.toTypedValue(t) } } @@ -616,7 +617,7 @@ func (f *unifiedTypeFacts) DeclaredAt(p cfg.Point, sym cfg.SymbolID) flow.TypedV } // Literal overlays are synthesized from function/table literals and can lag // behind canonical declared/sibling symbol types during fixpoint iterations. - // Keep them as fallback only when no canonical symbol type is available. + // Use them only after canonical symbol facts and declared types are absent. if f.literalTypes != nil { if f.annotatedVars == nil || !f.annotatedVars[sym] { if t, ok := f.literalTypes[sym]; ok && t != nil { diff --git a/compiler/check/api/facts.go b/compiler/check/api/facts.go index ada6fe9e..97d88094 100644 --- a/compiler/check/api/facts.go +++ b/compiler/check/api/facts.go @@ -12,38 +12,60 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// ReturnSummaries maps function symbols to their inferred return type vectors. 
-// Each entry is a slice of types representing the tuple of values returned -// by the function. For example, a function returning (value, error) has -// a two-element slice [valueType, errorType]. -type ReturnSummaries = map[cfg.SymbolID][]typ.Type - -// NarrowReturnSummaries holds post-flow return summaries with narrowing applied. -// These are computed after flow analysis and reflect the precise types at -// each return statement, accounting for control flow narrowing. -type NarrowReturnSummaries = map[cfg.SymbolID][]typ.Type - // ParamHints maps function symbols to parameter type hints inferred from call sites. // When a function is called with known argument types, those types are recorded // as hints and propagated to the function's parameter declarations. type ParamHints = map[cfg.SymbolID][]typ.Type -// FuncTypes maps local function symbols to their canonical function types. -// Used for sibling function lookups where the function is defined in the -// same scope as the call site. -type FuncTypes = map[cfg.SymbolID]typ.Type - // FunctionFact is the canonical function-related interproc fact for one symbol. -// Legacy channels (ReturnSummaries/NarrowReturns/FuncTypes) are compatibility -// views and should be derivable from this value. +// All return and local-function type evidence for a function converges here. type FunctionFact struct { + // Summary is the declared/pre-flow return vector. Summary []typ.Type - Narrow []typ.Type - Func typ.Type + // Narrow is the post-flow return vector. + Narrow []typ.Type + // Type is the canonical local function type evidence. + Type typ.Type } // FunctionFacts maps function symbols to their canonical function facts. -type FunctionFacts = map[cfg.SymbolID]FunctionFact +type FunctionFacts map[cfg.SymbolID]FunctionFact + +// Fact returns the canonical fact for sym. 
+func (facts FunctionFacts) Fact(sym cfg.SymbolID) (FunctionFact, bool) { + if len(facts) == 0 || sym == 0 { + return FunctionFact{}, false + } + ff, ok := facts[sym] + return ff, ok +} + +// Summary returns the declared/pre-flow return vector for sym. +func (facts FunctionFacts) Summary(sym cfg.SymbolID) []typ.Type { + ff, ok := facts.Fact(sym) + if !ok { + return nil + } + return ff.Summary +} + +// NarrowSummary returns the post-flow return vector for sym. +func (facts FunctionFacts) NarrowSummary(sym cfg.SymbolID) []typ.Type { + ff, ok := facts.Fact(sym) + if !ok { + return nil + } + return ff.Narrow +} + +// FunctionType returns the canonical local function type for sym. +func (facts FunctionFacts) FunctionType(sym cfg.SymbolID) typ.Type { + ff, ok := facts.Fact(sym) + if !ok { + return nil + } + return ff.Type +} // LiteralSigs maps anonymous function literal expressions to their signatures. // Used when function literals are passed as arguments or assigned to variables @@ -89,14 +111,8 @@ type ConstructorFields = map[cfg.SymbolID]map[string]typ.Type // Facts bundles all interprocedural analysis results for a single function graph. // These facts are computed during analysis and stored per (graph, parent) pair. type Facts struct { - FunctionFacts FunctionFacts - // Compatibility mirror derived from FunctionFacts. - ReturnSummaries ReturnSummaries - // Compatibility mirror derived from FunctionFacts. - NarrowReturns NarrowReturnSummaries - ParamHints ParamHints - // Compatibility mirror derived from FunctionFacts. 
- FuncTypes FuncTypes + FunctionFacts FunctionFacts + ParamHints ParamHints LiteralSigs LiteralSigs CapturedTypes CapturedTypes CapturedFields CapturedFieldAssigns diff --git a/compiler/check/api/facts_test.go b/compiler/check/api/facts_test.go index 2335f50f..dcefa7d1 100644 --- a/compiler/check/api/facts_test.go +++ b/compiler/check/api/facts_test.go @@ -13,18 +13,9 @@ func TestFacts_Zero(t *testing.T) { if f.FunctionFacts != nil { t.Error("zero Facts should have nil FunctionFacts") } - if f.ReturnSummaries != nil { - t.Error("zero Facts should have nil ReturnSummaries") - } - if f.NarrowReturns != nil { - t.Error("zero Facts should have nil NarrowReturns") - } if f.ParamHints != nil { t.Error("zero Facts should have nil ParamHints") } - if f.FuncTypes != nil { - t.Error("zero Facts should have nil FuncTypes") - } if f.LiteralSigs != nil { t.Error("zero Facts should have nil LiteralSigs") } @@ -39,15 +30,12 @@ func TestFacts_Zero(t *testing.T) { } } -func TestReturnSummaries_Basic(t *testing.T) { - summaries := make(ReturnSummaries) +func TestFunctionFacts_Summary(t *testing.T) { + facts := make(FunctionFacts) sym := cfg.SymbolID(1) - summaries[sym] = []typ.Type{typ.String, typ.Nil} + facts[sym] = FunctionFact{Summary: []typ.Type{typ.String, typ.Nil}} - rets, ok := summaries[sym] - if !ok { - t.Fatal("expected symbol to be in summaries") - } + rets := facts.Summary(sym) if len(rets) != 2 { t.Errorf("expected 2 return types, got %d", len(rets)) } @@ -67,16 +55,13 @@ func TestParamHints_Basic(t *testing.T) { } } -func TestFuncTypes_Basic(t *testing.T) { - funcTypes := make(FuncTypes) +func TestFunctionFacts_FunctionType(t *testing.T) { + facts := make(FunctionFacts) sym := cfg.SymbolID(1) fn := typ.Func().Param("x", typ.Number).Returns(typ.String).Build() - funcTypes[sym] = fn + facts[sym] = FunctionFact{Type: fn} - retrieved, ok := funcTypes[sym] - if !ok { - t.Fatal("expected symbol to be in funcTypes") - } + retrieved := facts.FunctionType(sym) if retrieved == 
nil { t.Error("expected non-nil function type") } @@ -174,30 +159,18 @@ func TestFacts_WithData(t *testing.T) { 4: { Summary: []typ.Type{typ.Boolean}, Narrow: []typ.Type{typ.Boolean}, - Func: typ.Func().Returns(typ.Boolean).Build(), + Type: typ.Func().Returns(typ.Boolean).Build(), }, }, - ReturnSummaries: ReturnSummaries{ - 1: []typ.Type{typ.String}, - }, ParamHints: ParamHints{ 2: []typ.Type{typ.Number}, }, - FuncTypes: FuncTypes{ - 3: typ.Func().Build(), - }, } - if len(f.ReturnSummaries) != 1 { - t.Error("expected 1 return summary") - } if len(f.FunctionFacts) != 1 { t.Error("expected 1 function fact") } if len(f.ParamHints) != 1 { t.Error("expected 1 param hint") } - if len(f.FuncTypes) != 1 { - t.Error("expected 1 func type") - } } diff --git a/compiler/check/api/keys.go b/compiler/check/api/keys.go index 962076b5..54e665e5 100644 --- a/compiler/check/api/keys.go +++ b/compiler/check/api/keys.go @@ -20,8 +20,9 @@ type SymbolKey struct { ParentHash uint64 } -// FuncKey uniquely identifies a function analysis request for memoization purposes. -// The key combines three components to ensure cache correctness: +// FuncKey uniquely identifies a function analysis request for memoization. +// Snapshot dependencies are tracked by the query database as the function result +// reads interprocedural facts, refinements, and constructor fields. // // - GraphID: Unique identifier for the function's control flow graph. Each CFG // receives a monotonically increasing ID during construction, ensuring distinct @@ -30,15 +31,9 @@ type SymbolKey struct { // - ParentHash: Hash of the parent scope state. Functions with identical code but // different lexical environments (e.g., different captured variables or type // definitions in scope) must be analyzed separately. -// -// - StoreRevision: Counter incremented at each fixpoint iteration boundary. 
-// This ensures cached results are invalidated when inter-function summaries -// (return types, effects, sibling types) change, forcing recomputation with -// updated cross-function information. type FuncKey struct { - GraphID uint64 - ParentHash uint64 - StoreRevision uint64 + GraphID uint64 + ParentHash uint64 } // KeyForGraph creates a GraphKey from a graph and parent scope. diff --git a/compiler/check/api/keys_test.go b/compiler/check/api/keys_test.go index 423e6df3..eb8f0b98 100644 --- a/compiler/check/api/keys_test.go +++ b/compiler/check/api/keys_test.go @@ -44,24 +44,24 @@ func TestSymbolKey_Equality(t *testing.T) { func TestFuncKey_Zero(t *testing.T) { k := FuncKey{} - if k.GraphID != 0 || k.ParentHash != 0 || k.StoreRevision != 0 { + if k.GraphID != 0 || k.ParentHash != 0 { t.Error("zero FuncKey should have zero fields") } } func TestFuncKey_Equality(t *testing.T) { - a := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 3} - b := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 3} + a := FuncKey{GraphID: 1, ParentHash: 2} + b := FuncKey{GraphID: 1, ParentHash: 2} if a != b { t.Error("equal FuncKeys should be ==") } } -func TestFuncKey_DifferentRevision(t *testing.T) { - a := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 3} - b := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 4} +func TestFuncKey_DifferentParent(t *testing.T) { + a := FuncKey{GraphID: 1, ParentHash: 2} + b := FuncKey{GraphID: 1, ParentHash: 3} if a == b { - t.Error("FuncKeys with different revisions should not be ==") + t.Error("FuncKeys with different parents should not be ==") } } @@ -87,8 +87,8 @@ func TestKeyForGraph_AsMapKey(t *testing.T) { func TestFuncKey_AsMapKey(t *testing.T) { m := make(map[FuncKey]int) - k1 := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 3} - k2 := FuncKey{GraphID: 1, ParentHash: 2, StoreRevision: 3} + k1 := FuncKey{GraphID: 1, ParentHash: 2} + k2 := FuncKey{GraphID: 1, ParentHash: 2} m[k1] = 42 if m[k2] != 42 { t.Error("FuncKey should work as 
map key") diff --git a/compiler/check/api/parents.go b/compiler/check/api/parents.go index 99aeb3bf..3bece288 100644 --- a/compiler/check/api/parents.go +++ b/compiler/check/api/parents.go @@ -3,33 +3,33 @@ package api import "github.com/wippyai/go-lua/compiler/check/scope" // ParentScopeForGraph resolves the canonical parent scope for a graph. -// It prefers the stable parent-scope snapshot recorded in store and falls -// back to fallback when no stable parent is available. -func ParentScopeForGraph(store ParentScopes, graphID uint64, fallback *scope.State) *scope.State { +// It prefers the stable parent-scope snapshot recorded in store and uses +// defaultScope only when no stable parent is available. +func ParentScopeForGraph(store ParentScopes, graphID uint64, defaultScope *scope.State) *scope.State { if store == nil || graphID == 0 { - return fallback + return defaultScope } parentHash := store.GraphParentHashOf(graphID) if parentHash == 0 { - return fallback + return defaultScope } if parent := store.Parents()[parentHash]; parent != nil { return parent } - return fallback + return defaultScope } // ParentHashForGraph resolves the canonical parent hash for a graph. -// It prefers the stable graph-parent hash recorded in store and falls back to -// fallback.Hash() when no stable hash exists. -func ParentHashForGraph(store ParentScopes, graphID uint64, fallback *scope.State) uint64 { +// It prefers the stable graph-parent hash recorded in store and uses +// defaultScope.Hash() when no stable hash exists. 
+func ParentHashForGraph(store ParentScopes, graphID uint64, defaultScope *scope.State) uint64 { if store != nil && graphID != 0 { if parentHash := store.GraphParentHashOf(graphID); parentHash != 0 { return parentHash } } - if fallback != nil { - return fallback.Hash() + if defaultScope != nil { + return defaultScope.Hash() } return 0 } diff --git a/compiler/check/api/parents_test.go b/compiler/check/api/parents_test.go index 102f37b7..34b03392 100644 --- a/compiler/check/api/parents_test.go +++ b/compiler/check/api/parents_test.go @@ -25,54 +25,54 @@ func (s *parentScopeStoreStub) GraphKeyFor(graph *cfg.Graph, parent *scope.State } func TestParentScopeForGraph_PrefersStoredParent(t *testing.T) { - fallback := scope.New() + defaultScope := scope.New() stored := scope.New() store := &parentScopeStoreStub{ parents: map[uint64]*scope.State{11: stored}, hashes: map[uint64]uint64{7: 11}, } - got := ParentScopeForGraph(store, 7, fallback) + got := ParentScopeForGraph(store, 7, defaultScope) if got != stored { t.Fatalf("expected stored parent, got %p want %p", got, stored) } } -func TestParentScopeForGraph_FallsBackWhenStoredMissing(t *testing.T) { - fallback := scope.New() +func TestParentScopeForGraph_UsesDefaultWhenStoredMissing(t *testing.T) { + defaultScope := scope.New() store := &parentScopeStoreStub{ parents: map[uint64]*scope.State{}, hashes: map[uint64]uint64{7: 11}, } - got := ParentScopeForGraph(store, 7, fallback) - if got != fallback { - t.Fatalf("expected fallback parent, got %p want %p", got, fallback) + got := ParentScopeForGraph(store, 7, defaultScope) + if got != defaultScope { + t.Fatalf("expected default parent, got %p want %p", got, defaultScope) } } func TestParentHashForGraph_PrefersStoredHash(t *testing.T) { - fallback := scope.New() + defaultScope := scope.New() store := &parentScopeStoreStub{ parents: map[uint64]*scope.State{}, hashes: map[uint64]uint64{7: 11}, } - got := ParentHashForGraph(store, 7, fallback) + got := ParentHashForGraph(store, 
7, defaultScope) if got != 11 { t.Fatalf("expected stored hash 11, got %d", got) } } -func TestParentHashForGraph_FallsBackToScopeHash(t *testing.T) { - fallback := scope.New() +func TestParentHashForGraph_UsesDefaultScopeHash(t *testing.T) { + defaultScope := scope.New() store := &parentScopeStoreStub{ parents: map[uint64]*scope.State{}, hashes: map[uint64]uint64{}, } - got := ParentHashForGraph(store, 7, fallback) - if got != fallback.Hash() { - t.Fatalf("expected fallback hash %d, got %d", fallback.Hash(), got) + got := ParentHashForGraph(store, 7, defaultScope) + if got != defaultScope.Hash() { + t.Fatalf("expected default hash %d, got %d", defaultScope.Hash(), got) } } diff --git a/compiler/check/api/result.go b/compiler/check/api/result.go index 72fd5905..ef88989f 100644 --- a/compiler/check/api/result.go +++ b/compiler/check/api/result.go @@ -24,8 +24,8 @@ type FuncResult struct { // binding information, and iteration metadata. Graph *cfg.Graph - // ModuleBindings is the module-level binding table used as fallback when - // graph-local bindings are insufficient for canonical symbol resolution. + // ModuleBindings is the module-level binding table used when graph-local + // bindings are insufficient for canonical symbol resolution. ModuleBindings *bind.BindingTable // BaseScope is the function's entry scope containing parameters, @@ -109,9 +109,9 @@ func (r *FuncResult) ExcludesTypeAt(p cfg.Point, path constraint.Path, declared return r.FlowSolution.ExcludesTypeAt(p, path, declared) } -// FuncResultView is the minimal view of a function analysis result +// FuncResultSnapshot is the stable slice of a function analysis result // required by nested processing and interprocedural helpers. 
-type FuncResultView struct { +type FuncResultSnapshot struct { Graph *cfg.Graph Scopes map[cfg.Point]*scope.State Facts flow.TypeFacts @@ -119,12 +119,12 @@ type FuncResultView struct { NarrowSynth Synth } -// ViewFromResult constructs a minimal view from a full function result. -func ViewFromResult(r *FuncResult) *FuncResultView { +// SnapshotFromResult constructs a stable snapshot from a full function result. +func SnapshotFromResult(r *FuncResult) *FuncResultSnapshot { if r == nil { return nil } - return &FuncResultView{ + return &FuncResultSnapshot{ Graph: r.Graph, Scopes: r.Scopes, Facts: r.Facts, diff --git a/compiler/check/api/store.go b/compiler/check/api/store.go index 51343247..98bdc65b 100644 --- a/compiler/check/api/store.go +++ b/compiler/check/api/store.go @@ -11,8 +11,8 @@ // NestedMetaStore - Nested function metadata // SnapshotStore - Stable interproc fact snapshots // FunctionRefs - Symbol/function bidirectional lookup -// StoreView - Read-only combination of above -// NestedStore - StoreView + constructor field storage +// StoreReader - Read-only combination of above +// NestedStore - StoreReader + constructor field storage // IterationStore - Full mutation capability for fixpoint package api @@ -74,12 +74,10 @@ type NestedMetaStore interface { // SnapshotStore exposes stable interproc fact snapshots. 
type SnapshotStore interface { GetParamHintsSnapshot(graph *cfg.Graph, parent *scope.State) ParamHints - GetReturnSummariesSnapshot(graph *cfg.Graph, parent *scope.State) ReturnSummaries - GetNarrowReturnSummariesSnapshot(graph *cfg.Graph, parent *scope.State) NarrowReturnSummaries + GetFunctionFactsSnapshot(graph *cfg.Graph, parent *scope.State) FunctionFacts GetCapturedTypesSnapshot(graph *cfg.Graph, parent *scope.State) CapturedTypes GetCapturedFieldAssignsSnapshot(graph *cfg.Graph, parent *scope.State) CapturedFieldAssigns GetCapturedContainerMutationsSnapshot(graph *cfg.Graph, parent *scope.State) CapturedContainerMutations - GetLocalFuncTypesSnapshot(graph *cfg.Graph, parent *scope.State) FuncTypes GetLiteralSigsSnapshot(graph *cfg.Graph, parent *scope.State) LiteralSigs } @@ -92,8 +90,8 @@ type FunctionRefs interface { SymbolForFunc(fn *ast.FunctionExpr) (cfg.SymbolID, bool) } -// StoreView is the minimal interface required by pre-flow return inference. -type StoreView interface { +// StoreReader is the read contract shared by checker phases. +type StoreReader interface { ModuleStore GraphStore ParentScopes @@ -110,12 +108,12 @@ type ConstructorFieldStore interface { // InterprocFactSink provides write access to per-iteration interproc facts. type InterprocFactSink interface { - UpdateInterprocFactsNext(key GraphKey, update func(*Facts)) + MergeInterprocFactsNext(key GraphKey, delta Facts) } // NestedStore is the store interface required by nested processing. 
type NestedStore interface { - StoreView + StoreReader ConstructorFieldStore InterprocFactSink } @@ -138,8 +136,6 @@ type IterationStore interface { ClearIterationChannels() FixpointSwap() bool FixpointChannelDiffs() []string - Revision() uint64 - BumpRevision() RefinementStore() RefinementStore StoreFunctionRefinement(sym cfg.SymbolID, eff *constraint.FunctionRefinement) @@ -149,6 +145,6 @@ type IterationStore interface { SetParentScope(parentHash uint64, parent *scope.State) SetGraphParentHash(graphID, parentHash uint64) - UpdateInterprocFactsNext(key GraphKey, update func(*Facts)) + MergeInterprocFactsNext(key GraphKey, delta Facts) ParentGraphKeyForSymbol(sym cfg.SymbolID) (GraphKey, bool) } diff --git a/compiler/check/api/store_attach.go b/compiler/check/api/store_attach.go index 1eed8e4b..5c04b478 100644 --- a/compiler/check/api/store_attach.go +++ b/compiler/check/api/store_attach.go @@ -2,19 +2,19 @@ package api import "github.com/wippyai/go-lua/types/db" -// StoreKey is the typed attachment key for StoreView. -var StoreKey = db.NewAttachmentKey[StoreView]("check.StoreView") +// StoreKey is the typed attachment key for StoreReader. +var StoreKey = db.NewAttachmentKey[StoreReader]("check.StoreReader") // AttachStore attaches a store to the query context for lookup. -func AttachStore(ctx *db.QueryContext, store StoreView) { +func AttachStore(ctx *db.QueryContext, store StoreReader) { if ctx == nil || store == nil { return } db.Attach(ctx, StoreKey, store) } -// StoreFrom retrieves the StoreView from a db.QueryContext. -func StoreFrom(ctx *db.QueryContext) StoreView { +// StoreFrom retrieves the StoreReader from a db.QueryContext. 
+func StoreFrom(ctx *db.QueryContext) StoreReader { store, _ := db.Attached(ctx, StoreKey) return store } diff --git a/compiler/check/callsite/callee_symbols.go b/compiler/check/callsite/callee_symbols.go index bd152843..029be20f 100644 --- a/compiler/check/callsite/callee_symbols.go +++ b/compiler/check/callsite/callee_symbols.go @@ -19,9 +19,7 @@ func CalleeSymbolCandidates(info *cfg.CallInfo, primary, fallback *bind.BindingT return nil } set := newSymbolSet(4) - for _, sym := range exprSymbolCandidates(info.Callee, info.CalleeSymbol, primary, fallback) { - set.Add(sym) - } + addExprSymbolCandidates(set, info.Callee, info.CalleeSymbol, primary, fallback) if methodSym, ok := methodCalleeSymbolFromCall(primary, nil, info); ok { set.Add(methodSym) } @@ -59,8 +57,8 @@ func CallableCalleeSymbolCandidates( return base } set := newSymbolSet(len(base)*2 + 2) - for _, sym := range expandAliasCandidates(base, graph) { - set.Add(sym) + for _, sym := range base { + addAliasExpansion(set, graph, sym) } // Method calls may resolve method symbol only through an alias receiver base diff --git a/compiler/check/callsite/candidates.go b/compiler/check/callsite/candidates.go index e293f6a1..40cddf85 100644 --- a/compiler/check/callsite/candidates.go +++ b/compiler/check/callsite/candidates.go @@ -66,6 +66,61 @@ func (s *symbolSet) Slice() []cfg.SymbolID { return s.order } +type symbolDeduper struct { + small [symbolSetMapThreshold]cfg.SymbolID + count int + seen map[cfg.SymbolID]struct{} +} + +func (d *symbolDeduper) Add(sym cfg.SymbolID) bool { + if sym == 0 { + return false + } + if d.seen != nil { + if _, ok := d.seen[sym]; ok { + return false + } + d.seen[sym] = struct{}{} + return true + } + for i := 0; i < d.count; i++ { + if d.small[i] == sym { + return false + } + } + if d.count < len(d.small) { + d.small[d.count] = sym + d.count++ + return true + } + d.seen = make(map[cfg.SymbolID]struct{}, len(d.small)+1) + for i := 0; i < d.count; i++ { + d.seen[d.small[i]] = struct{}{} + 
} + d.seen[sym] = struct{}{} + return true +} + +type preferredSymbolSelector struct { + prefer func(cfg.SymbolID) bool + selected cfg.SymbolID + seen symbolDeduper +} + +func (s *preferredSymbolSelector) Add(sym cfg.SymbolID) bool { + if !s.seen.Add(sym) { + return false + } + if s.selected == 0 { + s.selected = sym + } + if s.prefer != nil && s.prefer(sym) { + s.selected = sym + return true + } + return false +} + // SelectPreferredSymbol returns the first candidate and, if prefer is non-nil, returns // the first candidate that satisfies the predicate. func SelectPreferredSymbol(candidates []cfg.SymbolID, prefer func(cfg.SymbolID) bool) cfg.SymbolID { @@ -81,16 +136,79 @@ func SelectPreferredSymbol(candidates []cfg.SymbolID, prefer func(cfg.SymbolID) return selected } +func visitExprSymbolCandidates( + expr ast.Expr, + raw cfg.SymbolID, + primary *bind.BindingTable, + fallback *bind.BindingTable, + visit func(cfg.SymbolID) bool, +) bool { + if visit == nil { + return false + } + if visit(raw) { + return true + } + if visit(SymbolFromExpr(expr, primary)) { + return true + } + if fallback != primary { + return visit(SymbolFromExpr(expr, fallback)) + } + return false +} + +func addExprSymbolCandidates( + set *symbolSet, + expr ast.Expr, + raw cfg.SymbolID, + primary *bind.BindingTable, + fallback *bind.BindingTable, +) { + if set == nil { + return + } + visitExprSymbolCandidates(expr, raw, primary, fallback, func(sym cfg.SymbolID) bool { + set.Add(sym) + return false + }) +} + func addAliasExpansion(set *symbolSet, graph *cfg.Graph, sym cfg.SymbolID) { if set == nil || graph == nil || sym == 0 { return } - graph.EachAliasSymbol(sym, func(candidate cfg.SymbolID) bool { + visitAliasExpansion(graph, sym, func(candidate cfg.SymbolID) bool { set.Add(candidate) return false }) } +func visitAliasExpansion(graph *cfg.Graph, sym cfg.SymbolID, visit func(cfg.SymbolID) bool) bool { + if sym == 0 || visit == nil { + return false + } + if graph == nil { + return visit(sym) + } + 
var chain symbolDeduper + current := sym + for current != 0 { + if !chain.Add(current) { + return false + } + if visit(current) { + return true + } + next := graph.DirectAliasSymbol(current) + if next == 0 || next == current { + return false + } + current = next + } + return false +} + func expandAliasCandidates(base []cfg.SymbolID, graph *cfg.Graph) []cfg.SymbolID { if graph == nil || len(base) == 0 { return base @@ -113,10 +231,6 @@ func exprSymbolCandidates( fallback *bind.BindingTable, ) []cfg.SymbolID { set := newSymbolSet(3) - set.Add(raw) - set.Add(SymbolFromExpr(expr, primary)) - if fallback != primary { - set.Add(SymbolFromExpr(expr, fallback)) - } + addExprSymbolCandidates(set, expr, raw, primary, fallback) return set.Slice() } diff --git a/compiler/check/callsite/canonical_symbol.go b/compiler/check/callsite/canonical_symbol.go index 6b367133..2fc73a77 100644 --- a/compiler/check/callsite/canonical_symbol.go +++ b/compiler/check/callsite/canonical_symbol.go @@ -17,13 +17,9 @@ func CanonicalSymbolFromExprWithAliases( fallback *bind.BindingTable, prefer func(cfg.SymbolID) bool, ) cfg.SymbolID { - base := exprSymbolCandidates(expr, raw, primary, fallback) - if len(base) == 0 { - return 0 - } - if graph == nil { - return SelectPreferredSymbol(base, prefer) - } - candidates := expandAliasCandidates(base, graph) - return SelectPreferredSymbol(candidates, prefer) + selector := preferredSymbolSelector{prefer: prefer} + visitExprSymbolCandidates(expr, raw, primary, fallback, func(sym cfg.SymbolID) bool { + return visitAliasExpansion(graph, sym, selector.Add) + }) + return selector.selected } diff --git a/compiler/check/checker.go b/compiler/check/checker.go index 8f81acca..4b5315c0 100644 --- a/compiler/check/checker.go +++ b/compiler/check/checker.go @@ -29,9 +29,8 @@ // // The checker supports interprocedural analysis through a unified interproc snapshot: // -// - ReturnSummaries: Inferred return types for local functions +// - FunctionFacts: Canonical 
return/narrow/signature facts for local functions // - ParamHints: Inferred parameter types from call sites -// - FuncTypes: Canonical local function types for sibling lookups // - LiteralSigs: Synthesized signatures for function literals // - Refinements: Function refinement summaries, stored per symbol // @@ -44,9 +43,9 @@ // // # MEMOIZATION // -// Function analysis results are memoized by (GraphID, ParentHash, StoreRevision). -// The memoization cache is cleared at each iteration boundary to force recomputation -// with updated inter-function summaries. +// Function analysis results are memoized by (GraphID, ParentHash). Interprocedural +// facts, refinements, and constructor fields are tracked as query inputs, so +// cached results are revalidated precisely when the snapshots they read change. // // # CONVERGENCE // @@ -128,8 +127,8 @@ func WithComputePass(p api.ComputePass) Option { // for analyzing multiple files in sequence or parallel. // // MEMOIZATION: Function analysis is memoized through funcResultQ keyed by FuncKey. -// The cache is cleared at each fixpoint iteration boundary to ensure fresh -// computation with updated inter-function summaries. +// Inter-function inputs are tracked through the query database, so unchanged +// functions can be reused across fixpoint iterations without a coarse revision key. // // EXTENSION POINTS: Checker supports two extension mechanisms: // - Pass: Diagnostic generators that run after fixpoint convergence @@ -336,16 +335,7 @@ func (c *Checker) runPasses(sess *Session) { } func funcResultEqual(a, b *api.FuncResult) bool { - if a == b { - return true - } - if a == nil || b == nil { - return false - } - if a.Graph != nil && b.Graph != nil { - return a.Graph.ID() == b.Graph.ID() - } - return false + return a == b } // ClearCache removes all memoized function analysis results from the query cache. 
diff --git a/compiler/check/erreffect/error_return_infer.go b/compiler/check/erreffect/error_return_infer.go index 70a650d5..57acf709 100644 --- a/compiler/check/erreffect/error_return_infer.go +++ b/compiler/check/erreffect/error_return_infer.go @@ -13,6 +13,62 @@ import ( "github.com/wippyai/go-lua/types/typ/unwrap" ) +// ErrorReturnConvention describes a return layout where one slot carries the +// success value and another carries the error. ReturnCount is exact: a function +// with extra returned values is not inferred as this convention. +type ErrorReturnConvention struct { + ValueIndex int + ErrorIndex int + ReturnCount int +} + +// CanonicalLuaValueErrorConvention returns the canonical Lua `(value, err)` layout. +func CanonicalLuaValueErrorConvention() ErrorReturnConvention { + return ErrorReturnConvention{ + ValueIndex: 0, + ErrorIndex: 1, + ReturnCount: 2, + } +} + +func (c ErrorReturnConvention) valid() bool { + return c.ValueIndex >= 0 && + c.ErrorIndex >= 0 && + c.ValueIndex != c.ErrorIndex && + c.ValueIndex < c.ReturnCount && + c.ErrorIndex < c.ReturnCount +} + +// CanClassifyReturns reports whether returnTypes has the exact shape required +// by this convention before the expensive per-return inverse-pattern proof runs. +func (c ErrorReturnConvention) CanClassifyReturns(returnTypes []typ.Type) bool { + return c.valid() && len(returnTypes) == c.ReturnCount +} + +func (c ErrorReturnConvention) canClassifyFunction(fn *typ.Function) bool { + return fn != nil && c.CanClassifyReturns(fn.Returns) +} + +// HasStrictInversePattern proves this convention from the function body. +func (c ErrorReturnConvention) HasStrictInversePattern( + graph *cfg.Graph, + solution *flow.Solution, + synth api.BaseSynth, +) bool { + if !c.valid() { + return false + } + return HasStrictInverseReturnPattern(graph, solution, synth, c.ValueIndex, c.ErrorIndex) +} + +// Attach enriches fn with this convention's ErrorReturn effect. 
+func (c ErrorReturnConvention) Attach(fn *typ.Function) *typ.Function { + if !c.valid() { + return fn + } + return AttachErrorReturnSpec(fn, c.ValueIndex, c.ErrorIndex) +} + // AttachInferredErrorReturnSpec enriches function types with a canonical // ErrorReturn effect when the function body proves the `(value, err)` pattern. func AttachInferredErrorReturnSpec( @@ -21,7 +77,8 @@ func AttachInferredErrorReturnSpec( solution *flow.Solution, synth api.Synth, ) *typ.Function { - if fn == nil || graph == nil || synth == nil || len(fn.Returns) != 2 { + convention := CanonicalLuaValueErrorConvention() + if graph == nil || synth == nil || !convention.canClassifyFunction(fn) { return fn } if HasErrorReturnLabel(fn) { @@ -31,11 +88,11 @@ func AttachInferredErrorReturnSpec( if base == nil { base = synth } - if !HasStrictInverseReturnPattern(graph, solution, base, 0, 1) { + if !convention.HasStrictInversePattern(graph, solution, base) { return fn } - return AttachErrorReturnSpec(fn, 0, 1) + return convention.Attach(fn) } func HasErrorReturnLabel(fn *typ.Function) bool { diff --git a/compiler/check/erreffect/error_return_infer_test.go b/compiler/check/erreffect/error_return_infer_test.go new file mode 100644 index 00000000..31e4cef2 --- /dev/null +++ b/compiler/check/erreffect/error_return_infer_test.go @@ -0,0 +1,35 @@ +package erreffect + +import ( + "testing" + + "github.com/wippyai/go-lua/types/typ" +) + +func TestErrorReturnConventionCanClassifyReturns(t *testing.T) { + t.Parallel() + + convention := CanonicalLuaValueErrorConvention() + if !convention.CanClassifyReturns([]typ.Type{typ.String, typ.Nil}) { + t.Fatal("canonical value/error convention should classify exactly two returns") + } + if convention.CanClassifyReturns([]typ.Type{typ.String}) { + t.Fatal("canonical value/error convention should reject missing error slot") + } + if convention.CanClassifyReturns([]typ.Type{typ.String, typ.Nil, typ.Boolean}) { + t.Fatal("canonical value/error convention should reject 
extra return slots") + } +} + +func TestErrorReturnConventionRejectsInvalidLayout(t *testing.T) { + t.Parallel() + + convention := ErrorReturnConvention{ + ValueIndex: 0, + ErrorIndex: 0, + ReturnCount: 1, + } + if convention.CanClassifyReturns([]typ.Type{typ.Nil}) { + t.Fatal("convention with overlapping value/error slots should be invalid") + } +} diff --git a/compiler/check/flowbuild/assign/emit.go b/compiler/check/flowbuild/assign/emit.go index b7245bac..8aa731a4 100644 --- a/compiler/check/flowbuild/assign/emit.go +++ b/compiler/check/flowbuild/assign/emit.go @@ -157,8 +157,11 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect typeGuards := guard.CollectTypeGuards(fc.Graph, bindings) baseSynth := synthWithOverlayAndPreflow(overlayTypes, bindings, inputs, fc.CallCtx, fc.TypeOps, preflowBranchSolution, synth) - idom, _ := cfganalysis.ComputeDominators(fc.Graph.CFG()) structuredWrites := indexStructuredWrites(fc.Graph) + var idom map[cfg.Point]cfg.Point + if len(structuredWrites) > 0 { + idom = cfganalysis.ComputeImmediateDominators(fc.Graph.CFG()) + } var wrappedSynth func(ast.Expr, cfg.Point) typ.Type wrappedSynth = func(expr ast.Expr, p cfg.Point) typ.Type { if table, ok := expr.(*ast.TableExpr); ok && !tblutil.TableHasFunctionField(table) { @@ -440,7 +443,14 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect if !ok || fn == nil { continue } - if info := keyscoll.DetectKeysCollector(fn); info != nil && info.ReturnIndex == retIndex { + var fnGraph *cfg.Graph + if fc.Graphs != nil { + fnGraph = fc.Graphs.GetOrBuildCFG(fn) + } + if fnGraph == nil { + fnGraph = cfg.BuildWithBindings(fn, fc.ModuleBindings) + } + if info := keyscoll.DetectKeysCollector(fnGraph); info != nil && info.ReturnIndex == retIndex { tableSym = callsite.SymbolOrCreateFieldFromExpr(callsite.RuntimeArgAt(call, info.ParamIndex), bindings) break } diff --git a/compiler/check/flowbuild/assign/emit_test.go 
b/compiler/check/flowbuild/assign/emit_test.go index 537c895f..df0f9a96 100644 --- a/compiler/check/flowbuild/assign/emit_test.go +++ b/compiler/check/flowbuild/assign/emit_test.go @@ -542,7 +542,7 @@ func TestExtractAssignments_KeysCollector_WithFilterBranch(t *testing.T) { return typ.Unknown }, }, - }, inputs, keyscoll.BuildKeysCollectorDetector(graph, nil)) + }, inputs, keyscoll.BuildKeysCollectorDetector(graph, nil, nil)) src, ok := inputs.KeysProvenance[suiteNamesSym] if !ok || src != suitesSym { diff --git a/compiler/check/flowbuild/assign/infer.go b/compiler/check/flowbuild/assign/infer.go index 48606dfb..6e3b70f0 100644 --- a/compiler/check/flowbuild/assign/infer.go +++ b/compiler/check/flowbuild/assign/infer.go @@ -150,8 +150,11 @@ func collectInferredTypes( if graph == nil { return inferred } - idom, _ := cfganalysis.ComputeDominators(graph.CFG()) structuredWrites := indexStructuredWrites(graph) + var idom map[cfg.Point]cfg.Point + if len(structuredWrites) > 0 { + idom = cfganalysis.ComputeImmediateDominators(graph.CFG()) + } bindings := graph.Bindings() if moduleBindings == nil { diff --git a/compiler/check/flowbuild/assign/structured_overlay.go b/compiler/check/flowbuild/assign/structured_overlay.go index 7c94be02..f1595e48 100644 --- a/compiler/check/flowbuild/assign/structured_overlay.go +++ b/compiler/check/flowbuild/assign/structured_overlay.go @@ -19,11 +19,11 @@ type structuredWrite struct { // indexStructuredWrites collects static field/index writes keyed by base symbol. 
func indexStructuredWrites(graph *cfg.Graph) map[cfg.SymbolID][]structuredWrite { - result := make(map[cfg.SymbolID][]structuredWrite) if graph == nil { - return result + return nil } + var result map[cfg.SymbolID][]structuredWrite graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { if info == nil { return @@ -33,6 +33,9 @@ func indexStructuredWrites(graph *cfg.Graph) map[cfg.SymbolID][]structuredWrite if !ok { continue } + if result == nil { + result = make(map[cfg.SymbolID][]structuredWrite) + } result[sym] = append(result[sym], write) } }) diff --git a/compiler/check/flowbuild/core/context.go b/compiler/check/flowbuild/core/context.go index 1edb7c82..87a81629 100644 --- a/compiler/check/flowbuild/core/context.go +++ b/compiler/check/flowbuild/core/context.go @@ -28,6 +28,9 @@ type FlowContext struct { // Query context for memoization CallCtx *db.QueryContext + // Graphs provides canonical CFGs for function literals. + Graphs api.GraphProvider + // Type operations TypeOps core.TypeOps diff --git a/compiler/check/flowbuild/keyscoll/keyscoll.go b/compiler/check/flowbuild/keyscoll/keyscoll.go index 04aa2c12..f89d65eb 100644 --- a/compiler/check/flowbuild/keyscoll/keyscoll.go +++ b/compiler/check/flowbuild/keyscoll/keyscoll.go @@ -14,7 +14,12 @@ type KeysCollectorInfo struct { ReturnIndex int // Which return slot carries the keys table (0-based) } -// DetectKeysCollector analyzes a function body to detect if it follows the +// GraphProvider resolves canonical CFGs for function literals. +type GraphProvider interface { + GetOrBuildCFG(fn *ast.FunctionExpr) *cfg.Graph +} + +// DetectKeysCollector analyzes a function graph to detect if it follows the // "keys collector" pattern: creates a table, iterates with pairs over a param, // inserts keys into the table, and returns it. 
// @@ -25,18 +30,15 @@ type KeysCollectorInfo struct { // table.insert(keys, k) // end // return keys -func DetectKeysCollector(fn *ast.FunctionExpr) *KeysCollectorInfo { - if fn == nil || fn.Stmts == nil || len(fn.Stmts) == 0 { +func DetectKeysCollector(graph *cfg.Graph) *KeysCollectorInfo { + if graph == nil { return nil } - - graph := cfg.Build(fn) - if graph == nil { + fn := graph.Func() + if fn == nil || fn.Stmts == nil || len(fn.Stmts) == 0 { return nil } - // Use graph's own bindings since we build a fresh CFG. - // Passed-in bindings may have different symbol IDs. bindings := graph.Bindings() // Track: which local symbol is the "keys" table @@ -48,7 +50,7 @@ func DetectKeysCollector(fn *ast.FunctionExpr) *KeysCollectorInfo { insertedKeyIntoTable := false keysReturnIndex := -1 - paramSymbols := graph.ParamSymbols() + paramSlots := graph.ParamSlotsReadOnly() // Scan for local keys = {} pattern and generic for loop with pairs graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { @@ -97,9 +99,10 @@ func DetectKeysCollector(fn *ast.FunctionExpr) *KeysCollectorInfo { argSym, _ = gb.SymbolOf(argIdent) } } - // Check if argSym is a parameter - for i, ps := range paramSymbols { - if ps == argSym { + // Check if argSym is a parameter. The slot index is the runtime + // argument index, including implicit self when present. 
+ for i, slot := range paramSlots { + if slot.Symbol == argSym { pairsParamSym = argSym pairsParamIndex = i break @@ -254,9 +257,31 @@ func isTableInsertCall(info *cfg.CallInfo) bool { return true } +func functionGraph(fn *ast.FunctionExpr, owner *cfg.Graph, graphs GraphProvider) *cfg.Graph { + if fn == nil { + return nil + } + if owner != nil && owner.Func() == fn { + return owner + } + if graphs != nil { + if graph := graphs.GetOrBuildCFG(fn); graph != nil { + return graph + } + } + if owner != nil && owner.Bindings() != nil { + return cfg.BuildWithBindings(fn, owner.Bindings()) + } + return cfg.Build(fn) +} + // BuildKeysCollectorDetector returns a callback that detects if a call is to a // keys collector function and returns the symbol of the table argument. -func BuildKeysCollectorDetector(graph *cfg.Graph, moduleBindings *bind.BindingTable) func(*cfg.CallInfo, cfg.Point, int) cfg.SymbolID { +func BuildKeysCollectorDetector( + graph *cfg.Graph, + moduleBindings *bind.BindingTable, + graphs GraphProvider, +) func(*cfg.CallInfo, cfg.Point, int) cfg.SymbolID { cache := make(map[cfg.SymbolID]*KeysCollectorInfo) bindings := graph.Bindings() @@ -284,7 +309,7 @@ func BuildKeysCollectorDetector(graph *cfg.Graph, moduleBindings *bind.BindingTa continue } - info := DetectKeysCollector(fn) + info := DetectKeysCollector(functionGraph(fn, graph, graphs)) cache[calleeSym] = info if info == nil { continue diff --git a/compiler/check/flowbuild/keyscoll/keyscoll_test.go b/compiler/check/flowbuild/keyscoll/keyscoll_test.go index b8f77a04..38ae1125 100644 --- a/compiler/check/flowbuild/keyscoll/keyscoll_test.go +++ b/compiler/check/flowbuild/keyscoll/keyscoll_test.go @@ -21,6 +21,10 @@ func TestKeysCollectorInfo_ParamIndex(t *testing.T) { } } +func detectKeysCollector(fn *ast.FunctionExpr) *keyscoll.KeysCollectorInfo { + return keyscoll.DetectKeysCollector(cfg.Build(fn)) +} + func TestDetectKeysCollector_NilFunction(t *testing.T) { result := keyscoll.DetectKeysCollector(nil) if 
result != nil { @@ -30,7 +34,7 @@ func TestDetectKeysCollector_NilFunction(t *testing.T) { func TestDetectKeysCollector_NilStmts(t *testing.T) { fn := &ast.FunctionExpr{Stmts: nil} - result := keyscoll.DetectKeysCollector(fn) + result := detectKeysCollector(fn) if result != nil { t.Error("expected nil for nil statements") } @@ -38,7 +42,7 @@ func TestDetectKeysCollector_NilStmts(t *testing.T) { func TestDetectKeysCollector_EmptyStmts(t *testing.T) { fn := &ast.FunctionExpr{Stmts: []ast.Stmt{}} - result := keyscoll.DetectKeysCollector(fn) + result := detectKeysCollector(fn) if result != nil { t.Error("expected nil for empty statements") } @@ -50,7 +54,7 @@ func TestDetectKeysCollector_SimpleReturn(t *testing.T) { &ast.ReturnStmt{Exprs: []ast.Expr{&ast.NilExpr{}}}, }, } - result := keyscoll.DetectKeysCollector(fn) + result := detectKeysCollector(fn) if result != nil { t.Error("expected nil for simple return function") } @@ -69,7 +73,7 @@ func TestDetectKeysCollector_NoKeysPattern(t *testing.T) { }, }, } - result := keyscoll.DetectKeysCollector(fn) + result := detectKeysCollector(fn) if result != nil { t.Error("expected nil for function without keys pattern") } @@ -80,7 +84,7 @@ func TestBuildKeysCollectorDetector_NilCallInfo(t *testing.T) { Stmts: []ast.Stmt{&ast.ReturnStmt{}}, } graph := cfg.Build(fn) - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) if detector == nil { t.Fatal("expected non-nil detector") } @@ -95,7 +99,7 @@ func TestBuildKeysCollectorDetector_MethodCall(t *testing.T) { Stmts: []ast.Stmt{&ast.ReturnStmt{}}, } graph := cfg.Build(fn) - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) callInfo := &cfg.CallInfo{ Method: "someMethod", Receiver: &ast.IdentExpr{Value: "obj"}, @@ -111,7 +115,7 @@ func TestBuildKeysCollectorDetector_NoCalleeSymbol(t *testing.T) { Stmts: []ast.Stmt{&ast.ReturnStmt{}}, } 
graph := cfg.Build(fn) - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) callInfo := &cfg.CallInfo{ Callee: &ast.IdentExpr{Value: "fn"}, CalleeSymbol: 0, @@ -138,7 +142,7 @@ func TestDetectKeysCollector_TableInsertAsAssignmentCallSite(t *testing.T) { ParList: &ast.ParList{Names: []string{"tbl"}}, Stmts: body, } - info := keyscoll.DetectKeysCollector(fn) + info := detectKeysCollector(fn) if info == nil { t.Fatal("expected keys collector to be detected when insert call is in assignment expression") } @@ -185,7 +189,7 @@ func TestBuildKeysCollectorDetector_NestedFieldArgument(t *testing.T) { } want := bindings.GetOrCreateFieldSymbol(stateSym, "users") - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) found := false graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil || info.CalleeName != "sorted_keys" { @@ -216,7 +220,7 @@ func TestDetectKeysCollector_MultiReturnKeysIndex(t *testing.T) { ParList: &ast.ParList{Names: []string{"tbl"}}, Stmts: body, } - info := keyscoll.DetectKeysCollector(fn) + info := detectKeysCollector(fn) if info == nil { t.Fatal("expected keys collector info") } @@ -261,7 +265,7 @@ func TestBuildKeysCollectorDetector_RespectsReturnIndex(t *testing.T) { } want := bindings.GetOrCreateFieldSymbol(stateSym, "users") - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) found := false graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil || info.CalleeName != "sorted_keys" { @@ -314,7 +318,7 @@ func TestBuildKeysCollectorDetector_UsesCanonicalCandidatesWhenRawSymbolMissing( } want := bindings.GetOrCreateFieldSymbol(stateSym, "users") - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) found := false 
graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil || info.CalleeName != "sorted_keys" { @@ -368,7 +372,7 @@ func TestBuildKeysCollectorDetector_UsesModuleBindingNameFallback(t *testing.T) moduleBindings := bind.NewBindingTable() - detector := keyscoll.BuildKeysCollectorDetector(graph, moduleBindings) + detector := keyscoll.BuildKeysCollectorDetector(graph, moduleBindings, nil) found := false graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil || info.CalleeName != "sorted_keys" { @@ -427,7 +431,7 @@ func TestBuildKeysCollectorDetector_UsesDirectAliasCandidate(t *testing.T) { } want := bindings.GetOrCreateFieldSymbol(stateSym, "users") - detector := keyscoll.BuildKeysCollectorDetector(graph, nil) + detector := keyscoll.BuildKeysCollectorDetector(graph, nil, nil) found := false graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil || info.CalleeName != "sk" { diff --git a/compiler/check/flowbuild/run.go b/compiler/check/flowbuild/run.go index 936470bc..bcba5c12 100644 --- a/compiler/check/flowbuild/run.go +++ b/compiler/check/flowbuild/run.go @@ -104,7 +104,7 @@ func Run(fc *fbcore.FlowContext) *flow.Inputs { constprop.PropagateAllConstValues(fc, inputs) // Assignments with const resolution. - assign.ExtractAssignments(fc, inputs, keyscoll.BuildKeysCollectorDetector(fc.Graph, fc.ModuleBindings)) + assign.ExtractAssignments(fc, inputs, keyscoll.BuildKeysCollectorDetector(fc.Graph, fc.ModuleBindings, fc.Graphs)) // Table mutator assignments (table.insert-like). 
mutator.ExtractTableMutatorAssignments(fc, inputs) diff --git a/compiler/check/infer/interproc/doc.go b/compiler/check/infer/interproc/doc.go index 915e77fe..a0f33a01 100644 --- a/compiler/check/infer/interproc/doc.go +++ b/compiler/check/infer/interproc/doc.go @@ -20,6 +20,6 @@ // // # Integration // -// This package bridges per-function flow analysis with the global +// This package connects per-function flow analysis with the global // fixpoint iteration that resolves cross-function dependencies. package interproc diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index b9d2bdf1..c78f6e91 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -23,9 +23,9 @@ type functionTypeWithExpected interface { // Store is the minimal store interface required to record post-flow interproc facts. type Store interface { - api.StoreView + api.StoreReader - UpdateInterprocFactsNext(key api.GraphKey, update func(*api.Facts)) + MergeInterprocFactsNext(key api.GraphKey, delta api.Facts) StoreLiteralSigs(graphID uint64, sigs map[*ast.FunctionExpr]*typ.Function) ParentGraphKeyForSymbol(sym cfg.SymbolID) (api.GraphKey, bool) } @@ -65,26 +65,29 @@ func StoreFactsFromResult( if fnType == nil { return } - narrowReturns := returns.NormalizeReturnVector(fnType.Returns) + narrowSummary := returns.NormalizeReturnVector(fnType.Returns) if snapNarrow := narrowSummarySnapshotForSymbol(store, result, parent, fnSym); len(snapNarrow) > 0 { - narrowReturns = returns.MergeReturnSummary(narrowReturns, snapNarrow) - if aligned, changed := returns.AlignFunctionTypeWithSummary(fnType, narrowReturns); changed { + narrowSummary = returns.MergeReturnSummary(narrowSummary, snapNarrow) + if aligned, changed := returns.AlignFunctionTypeWithSummary(fnType, narrowSummary); changed { fnType = aligned } } summaryFromSnapshot := returnSummarySnapshotForSymbol(store, result, parent, fnSym) - 
writer.updateParentFactsForSymbol(fnSym, func(facts *api.Facts) { - candidateFunc := fnType - if hinted := paramhints.MergeIntoSignature(fn, facts.ParamHints[fnSym], unwrap.Function(candidateFunc)); hinted != nil { + candidateFunc := fnType + if hints := store.GetParamHintsSnapshot(result.Graph, parent); len(hints) > 0 { + if hinted := paramhints.MergeIntoSignature(fn, hints[fnSym], unwrap.Function(candidateFunc)); hinted != nil { candidateFunc = hinted } - returns.MergeFunctionFactIntoFacts(facts, fnSym, returns.FunctionFactCandidate{ + } + delta := api.Facts{FunctionFacts: api.FunctionFacts{ + fnSym: returns.JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ Summary: summaryFromSnapshot, - Narrow: narrowReturns, - Func: candidateFunc, - }) - }) + Narrow: narrowSummary, + Type: candidateFunc, + }), + }} + writer.mergeParentFactsForSymbol(fnSym, delta) } func storeCapturedFactsFromResult( @@ -108,28 +111,19 @@ func storeCapturedFactsFromResult( fields := nested.CollectCapturedFieldAssignments(result.Graph, capturedSet, result.NarrowSynth.TypeOf) if len(fields) > 0 { - writer.updateParentFactsForSymbol(fnSym, func(facts *api.Facts) { - if facts.CapturedFields == nil { - facts.CapturedFields = make(api.CapturedFieldAssigns) - } - existing := facts.CapturedFields[fnSym] - facts.CapturedFields[fnSym] = returns.MergeCapturedFieldSymbolMaps(existing, fields, typ.JoinPreferNonSoft) + writer.mergeParentFactsForSymbol(fnSym, api.Facts{ + CapturedFields: api.CapturedFieldAssigns{ + fnSym: fields, + }, }) } mutations := nested.CollectCapturedContainerMutations(result.Graph, capturedSet, result.NarrowSynth.TypeOf) if len(mutations) > 0 { - writer.updateParentFactsForSymbol(fnSym, func(facts *api.Facts) { - if facts.CapturedContainers == nil { - facts.CapturedContainers = make(api.CapturedContainerMutations) - } - existing := facts.CapturedContainers[fnSym] - facts.CapturedContainers[fnSym] = returns.MergeCapturedContainerMutationMaps(existing, mutations, func(prev 
*api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { - if prev != nil { - next.ValueType = typ.JoinPreferNonSoft(prev.ValueType, next.ValueType) - } - return next - }) + writer.mergeParentFactsForSymbol(fnSym, api.Facts{ + CapturedContainers: api.CapturedContainerMutations{ + fnSym: mutations, + }, }) } } @@ -197,11 +191,11 @@ func returnSummarySnapshotForSymbol(store Store, result *api.FuncResult, parent } } } - snap := store.GetReturnSummariesSnapshot(summaryGraph, summaryScope) - if len(snap) == 0 { + facts := store.GetFunctionFactsSnapshot(summaryGraph, summaryScope) + if len(facts) == 0 { return nil } - return snap[sym] + return facts.Summary(sym) } func narrowSummarySnapshotForSymbol(store Store, result *api.FuncResult, parent *scope.State, sym cfg.SymbolID) []typ.Type { @@ -219,20 +213,20 @@ func narrowSummarySnapshotForSymbol(store Store, result *api.FuncResult, parent } } - var snap map[cfg.SymbolID][]typ.Type + var facts api.FunctionFacts if phaser, ok := any(store).(interface { WithPhase(api.Phase, func()) }); ok { phaser.WithPhase(api.PhaseNarrowing, func() { - snap = store.GetNarrowReturnSummariesSnapshot(summaryGraph, summaryScope) + facts = store.GetFunctionFactsSnapshot(summaryGraph, summaryScope) }) } else { - snap = store.GetNarrowReturnSummariesSnapshot(summaryGraph, summaryScope) + facts = store.GetFunctionFactsSnapshot(summaryGraph, summaryScope) } - if len(snap) == 0 { + if len(facts) == 0 { return nil } - return snap[sym] + return facts.NarrowSummary(sym) } func expectedFunctionFromResult(result *api.FuncResult) *typ.Function { @@ -394,45 +388,44 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc return } - store.UpdateInterprocFactsNext(parentKey, func(facts *api.Facts) { - if facts.ParamHints == nil { - facts.ParamHints = make(api.ParamHints) + deltaHints := make(api.ParamHints) + hints := paramhints.EnsureHintCapacity(nil, len(info.Args)) + for i, arg := range info.Args { + if arg == 
nil { + continue } - hints := paramhints.EnsureHintCapacity(facts.ParamHints[calleeSym], len(info.Args)) - for i, arg := range info.Args { - if arg == nil { - continue - } - if expectedFn := unwrap.Function(infer.ExpectedArgType(i)); expectedFn != nil { - argSym := checkcallsite.CanonicalSymbolFromExprWithAliases( - arg, - 0, - result.Graph, - bindings, - moduleBindings, - hasFunctionRef, - ) - if argSym != 0 && hasFunctionRef(argSym) { - hintsForFn := facts.ParamHints[argSym] - for j, param := range expectedFn.Params { - hintsForFn, _ = paramhints.MergeHintAt(hintsForFn, j, param.Type, typ.JoinPreferNonSoft) - } - if len(hintsForFn) > 0 { - facts.ParamHints[argSym] = hintsForFn - } + if expectedFn := unwrap.Function(infer.ExpectedArgType(i)); expectedFn != nil { + argSym := checkcallsite.CanonicalSymbolFromExprWithAliases( + arg, + 0, + result.Graph, + bindings, + moduleBindings, + hasFunctionRef, + ) + if argSym != 0 && hasFunctionRef(argSym) { + hintsForFn := deltaHints[argSym] + for j, param := range expectedFn.Params { + hintsForFn, _ = paramhints.MergeHintAt(hintsForFn, j, param.Type, typ.JoinPreferNonSoft) + } + if len(hintsForFn) > 0 { + deltaHints[argSym] = hintsForFn } } - - argType := argTypes[i] - if argType == nil { - argType = result.NarrowSynth.TypeOf(arg, p) - } - hints, _ = paramhints.MergeCallArgHintAt(hints, i, argType, typ.JoinPreferNonSoft, true) } - if len(hints) > 0 { - facts.ParamHints[calleeSym] = hints + + argType := argTypes[i] + if argType == nil { + argType = result.NarrowSynth.TypeOf(arg, p) } - }) + hints, _ = paramhints.MergeCallArgHintAt(hints, i, argType, typ.JoinPreferNonSoft, true) + } + if len(hints) > 0 { + deltaHints[calleeSym] = hints + } + if len(deltaHints) > 0 { + store.MergeInterprocFactsNext(parentKey, api.Facts{ParamHints: deltaHints}) + } } graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { diff --git a/compiler/check/infer/interproc/writer.go b/compiler/check/infer/interproc/writer.go index bae53cb3..263135e0 
100644 --- a/compiler/check/infer/interproc/writer.go +++ b/compiler/check/infer/interproc/writer.go @@ -9,7 +9,7 @@ import ( ) type factsWriteStore interface { - UpdateInterprocFactsNext(key api.GraphKey, update func(*api.Facts)) + MergeInterprocFactsNext(key api.GraphKey, delta api.Facts) StoreLiteralSigs(graphID uint64, sigs map[*ast.FunctionExpr]*typ.Function) GraphKeyFor(graph *cfg.Graph, parent *scope.State) (api.GraphKey, bool) ParentGraphKeyForSymbol(sym cfg.SymbolID) (api.GraphKey, bool) @@ -23,15 +23,15 @@ func newInterprocFactWriter(store factsWriteStore) interprocFactWriter { return interprocFactWriter{store: store} } -func (w interprocFactWriter) updateParentFactsForSymbol(sym cfg.SymbolID, update func(*api.Facts)) bool { - if w.store == nil || sym == 0 || update == nil { +func (w interprocFactWriter) mergeParentFactsForSymbol(sym cfg.SymbolID, delta api.Facts) bool { + if w.store == nil || sym == 0 { return false } parentKey, ok := w.store.ParentGraphKeyForSymbol(sym) if !ok { return false } - w.store.UpdateInterprocFactsNext(parentKey, update) + w.store.MergeInterprocFactsNext(parentKey, delta) return true } @@ -45,15 +45,14 @@ func (w interprocFactWriter) writeLiteralSignatures( } w.store.StoreLiteralSigs(graph.ID(), sigs) if key, ok := w.store.GraphKeyFor(graph, parent); ok { - w.store.UpdateInterprocFactsNext(key, func(facts *api.Facts) { - if facts.LiteralSigs == nil { - facts.LiteralSigs = make(api.LiteralSigs, len(sigs)) + delta := api.LiteralSigs{} + for fnExpr, sig := range sigs { + if fnExpr != nil && sig != nil { + delta[fnExpr] = sig } - for fnExpr, sig := range sigs { - if fnExpr != nil && sig != nil { - facts.LiteralSigs[fnExpr] = sig - } - } - }) + } + if len(delta) > 0 { + w.store.MergeInterprocFactsNext(key, api.Facts{LiteralSigs: delta}) + } } } diff --git a/compiler/check/infer/interproc/writer_test.go b/compiler/check/infer/interproc/writer_test.go index e5593860..811a4a28 100644 --- a/compiler/check/infer/interproc/writer_test.go 
+++ b/compiler/check/infer/interproc/writer_test.go @@ -26,10 +26,8 @@ func newFactsWriteStoreStub() *factsWriteStoreStub { } } -func (s *factsWriteStoreStub) UpdateInterprocFactsNext(key api.GraphKey, update func(*api.Facts)) { - facts := s.factsByGraphKeyNext[key] - update(&facts) - s.factsByGraphKeyNext[key] = facts +func (s *factsWriteStoreStub) MergeInterprocFactsNext(key api.GraphKey, delta api.Facts) { + s.factsByGraphKeyNext[key] = delta } func (s *factsWriteStoreStub) StoreLiteralSigs(graphID uint64, sigs map[*ast.FunctionExpr]*typ.Function) { @@ -45,16 +43,16 @@ func (s *factsWriteStoreStub) ParentGraphKeyForSymbol(sym cfg.SymbolID) (api.Gra return key, ok } -func TestInterprocFactWriter_UpdateParentFactsForSymbol(t *testing.T) { +func TestInterprocFactWriter_MergeParentFactsForSymbol(t *testing.T) { stub := newFactsWriteStoreStub() key := api.GraphKey{GraphID: 7, ParentHash: 11} stub.parentKeyBySymbol[3] = key writer := newInterprocFactWriter(stub) - ok := writer.updateParentFactsForSymbol(3, func(facts *api.Facts) { - facts.ParamHints = map[cfg.SymbolID][]typ.Type{ + ok := writer.mergeParentFactsForSymbol(3, api.Facts{ + ParamHints: map[cfg.SymbolID][]typ.Type{ 3: {typ.String}, - } + }, }) if !ok { t.Fatal("expected update to succeed") @@ -64,7 +62,7 @@ func TestInterprocFactWriter_UpdateParentFactsForSymbol(t *testing.T) { t.Fatalf("unexpected parent facts update: %#v", got.ParamHints) } - if writer.updateParentFactsForSymbol(99, func(*api.Facts) {}) { + if writer.mergeParentFactsForSymbol(99, api.Facts{}) { t.Fatal("expected update to fail for unknown symbol") } } diff --git a/compiler/check/infer/nested/processor.go b/compiler/check/infer/nested/processor.go index e4c32d76..947b87c2 100644 --- a/compiler/check/infer/nested/processor.go +++ b/compiler/check/infer/nested/processor.go @@ -39,7 +39,7 @@ import ( type CheckFunc func(fn *ast.FunctionExpr, parent *scope.State) // ResultFunc returns the analysis result for a function literal. 
-type ResultFunc func(fn *ast.FunctionExpr) *api.FuncResultView +type ResultFunc func(fn *ast.FunctionExpr) *api.FuncResultSnapshot // Config holds dependencies for nested processing. type Config struct { @@ -48,7 +48,7 @@ type Config struct { Graphs api.GraphProvider Check CheckFunc ResultForFunc ResultFunc - RootResult *api.FuncResultView + RootResult *api.FuncResultSnapshot } // Processor analyzes nested functions for a parent graph. @@ -58,7 +58,7 @@ type Processor struct { graphs api.GraphProvider check CheckFunc resultForFunc ResultFunc - rootResult *api.FuncResultView + rootResult *api.FuncResultSnapshot } // New creates a nested processor. @@ -74,7 +74,7 @@ func New(cfg Config) *Processor { } // ProcessNestedFunctions analyzes all nested function definitions within a parent graph. -func (p *Processor) ProcessNestedFunctions(graph *cfg.Graph, parentResult *api.FuncResultView) { +func (p *Processor) ProcessNestedFunctions(graph *cfg.Graph, parentResult *api.FuncResultSnapshot) { if parentResult == nil { return } @@ -155,18 +155,18 @@ func (p *Processor) processNestedGroup( graph *cfg.Graph, scopes map[cfg.Point]*scope.State, group *nestedGroup, - parentResult *api.FuncResultView, + parentResult *api.FuncResultSnapshot, parentFunc *ast.FunctionExpr, ) { - // Build sibling types for this group. - siblingTypes := p.buildSiblingTypesForGroup(graph, scopes, group.Hash, group.Funcs, parentResult) - if siblingTypes == nil { - siblingTypes = make(map[cfg.SymbolID]typ.Type) + // Build sibling function types for this group. + siblingFunctionTypes := p.buildSiblingTypesForGroup(graph, scopes, group.Hash, group.Funcs, parentResult) + if siblingFunctionTypes == nil { + siblingFunctionTypes = make(map[cfg.SymbolID]typ.Type) } // Process each function in the group. 
for _, info := range group.Funcs { - p.processNestedFunction(graph, scopes, info, siblingTypes, parentResult, parentFunc) + p.processNestedFunction(graph, scopes, info, siblingFunctionTypes, parentResult, parentFunc) } } @@ -175,8 +175,8 @@ func (p *Processor) processNestedFunction( graph *cfg.Graph, scopes map[cfg.Point]*scope.State, info *nested.FuncInfo, - siblingTypes map[cfg.SymbolID]typ.Type, - parentResult *api.FuncResultView, + siblingFunctionTypes map[cfg.SymbolID]typ.Type, + parentResult *api.FuncResultSnapshot, parentFunc *ast.FunctionExpr, ) { baseParentScope := scopes[info.NF.Point] @@ -244,7 +244,7 @@ func (p *Processor) processNestedFunction( if info.FuncDef == nil || !info.FuncDef.IsMethod { fn := info.NF.Func if phasecore.HasUnannotatedSelfParam(fn, graph.Bindings()) { - selfType, tblSym := p.resolveSelfTypeForImplicitSelf(info, siblingTypes, graph, parentResult, capturedTypes) + selfType, tblSym := p.resolveSelfTypeForImplicitSelf(info, siblingFunctionTypes, graph, parentResult, capturedTypes) if selfType != nil && tblSym != 0 && p.store != nil { selfType = nested.EnrichSelfTypeWithConstructorFields(selfType, tblSym, &nestedStoreAdapter{store: p.store}) } @@ -264,7 +264,7 @@ func (p *Processor) processNestedFunction( } // Get the result for constructor detection and sibling updates. - result := (*api.FuncResultView)(nil) + result := (*api.FuncResultSnapshot)(nil) if p.resultForFunc != nil { result = p.resultForFunc(info.NF.Func) } @@ -290,7 +290,7 @@ func (p *Processor) processNestedFunction( // Update sibling types with the fully-inferred function type. 
if info.IsLocal && info.FuncSym != 0 && result.NarrowSynth != nil { if inferredType := result.NarrowSynth.FunctionType(info.NF.Func, parentScope); inferredType != nil { - siblingTypes[info.FuncSym] = returns.MergeFunctionFactType(siblingTypes[info.FuncSym], inferredType) + siblingFunctionTypes[info.FuncSym] = returns.MergeFunctionFactType(siblingFunctionTypes[info.FuncSym], inferredType) } } } @@ -300,8 +300,8 @@ func (p *Processor) resolveSelfTypeForMethod( info *nested.FuncInfo, sym cfg.SymbolID, graph *cfg.Graph, - parentResult *api.FuncResultView, - rootResult *api.FuncResultView, + parentResult *api.FuncResultSnapshot, + rootResult *api.FuncResultSnapshot, ) typ.Type { var selfType typ.Type @@ -322,7 +322,7 @@ func (p *Processor) resolveSelfTypeForMethod( } } - // Fall back to parent result facts. + // Then consult parent result facts. if selfType == nil && parentResult != nil && parentResult.Facts != nil { tv := parentResult.Facts.EffectiveTypeAt(info.NF.Point, sym) if tv.Type != nil && tv.State == flow.StateResolved { @@ -368,17 +368,15 @@ func (p *Processor) persistCapturedTypesForNestedGraph( if len(nextCaptured) == 0 { return } - p.store.UpdateInterprocFactsNext(key, func(facts *api.Facts) { - facts.CapturedTypes = returns.WidenCapturedTypes(facts.CapturedTypes, nextCaptured) - }) + p.store.MergeInterprocFactsNext(key, api.Facts{CapturedTypes: nextCaptured}) } // resolveSelfTypeForImplicitSelf resolves the self-type for methods with implicit self parameter. 
func (p *Processor) resolveSelfTypeForImplicitSelf( info *nested.FuncInfo, - siblingTypes map[cfg.SymbolID]typ.Type, + siblingFunctionTypes map[cfg.SymbolID]typ.Type, graph *cfg.Graph, - parentResult *api.FuncResultView, + parentResult *api.FuncResultSnapshot, capturedTypes map[cfg.SymbolID]typ.Type, ) (typ.Type, cfg.SymbolID) { fn := info.NF.Func @@ -388,7 +386,7 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( // Pattern 1: Table literal methods {m = function(self)...} if tbl, tblSym = nested.FindTableLiteralOwner(graph, fn); tbl != nil && tblSym != 0 { - selfType = siblingTypes[tblSym] + selfType = siblingFunctionTypes[tblSym] // Use table literal type when available. if selfType == nil && parentResult != nil && parentResult.NarrowSynth != nil { selfType = parentResult.NarrowSynth.TypeOf(tbl, info.NF.Point) @@ -398,7 +396,7 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( path := constraint.Path{Symbol: tblSym} selfType = parentResult.FlowSolution.TypeAt(info.NF.Point, path) } - // Fall back to Facts.EffectiveTypeAt. + // Then consult Facts.EffectiveTypeAt. if selfType == nil && parentResult != nil && parentResult.Facts != nil { tv := parentResult.Facts.EffectiveTypeAt(info.NF.Point, tblSym) if tv.Type != nil && tv.State == flow.StateResolved { @@ -406,7 +404,9 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( } } if rec, ok := selfType.(*typ.Record); ok { - selfType = nested.EnrichTableTypeWithFuncTypes(rec, tbl, graph, siblingTypes) + selfType = nested.EnrichTableTypeWithFunctionLookup(rec, tbl, graph, func(sym cfg.SymbolID) typ.Type { + return siblingFunctionTypes[sym] + }) } } @@ -415,7 +415,7 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( baseSym, baseTbl, baseTblPoint := nested.FindFieldAssignmentBase(graph, fn, info.NF.Point) if baseSym != 0 { tblSym = baseSym - selfType = siblingTypes[baseSym] + selfType = siblingFunctionTypes[baseSym] // Use captured types from the parent scope (flow-derived). 
if selfType == nil && len(capturedTypes) > 0 { if t := capturedTypes[baseSym]; t != nil { @@ -431,7 +431,7 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( path := constraint.Path{Symbol: baseSym} selfType = parentResult.FlowSolution.TypeAt(info.NF.Point, path) } - // Fall back to Facts.EffectiveTypeAt. + // Then consult Facts.EffectiveTypeAt. if selfType == nil && parentResult != nil && parentResult.Facts != nil { tv := parentResult.Facts.EffectiveTypeAt(info.NF.Point, baseSym) if tv.Type != nil && tv.State == flow.StateResolved { @@ -439,7 +439,9 @@ func (p *Processor) resolveSelfTypeForImplicitSelf( } } if rec, ok := selfType.(*typ.Record); ok && baseTbl != nil { - selfType = nested.EnrichTableTypeWithFuncTypes(rec, baseTbl, graph, siblingTypes) + selfType = nested.EnrichTableTypeWithFunctionLookup(rec, baseTbl, graph, func(sym cfg.SymbolID) typ.Type { + return siblingFunctionTypes[sym] + }) } } } @@ -465,7 +467,7 @@ func (p *Processor) buildSiblingTypesForGroup( scopes map[cfg.Point]*scope.State, groupHash uint64, funcs []*nested.FuncInfo, - parentResult *api.FuncResultView, + parentResult *api.FuncResultSnapshot, ) map[cfg.SymbolID]typ.Type { if p.store == nil || graph == nil || len(funcs) == 0 { return nil @@ -488,12 +490,11 @@ func (p *Processor) buildSiblingTypesForGroup( GroupHash: groupHash, } - // Use canonical local function types (signatures + param hints + return summaries). 
var parentScope *scope.State if len(funcs) > 0 { parentScope = funcs[0].DefScope } - buildCfg.FuncTypes = p.store.GetLocalFuncTypesSnapshot(graph, parentScope) + buildCfg.FunctionFacts = p.store.GetFunctionFactsSnapshot(graph, parentScope) buildCfg.Services = siblings.BuildServicesFuncs{ CapturedSymbolsFn: func(fn *ast.FunctionExpr) []cfg.SymbolID { @@ -526,7 +527,7 @@ func (p *Processor) buildSiblingTypesForGroup( }, EnrichRecordFn: func(rec *typ.Record, sym cfg.SymbolID) typ.Type { if tbl, _ := nested.FindTableLiteralForSymbol(graph, sym); tbl != nil { - return nested.EnrichTableTypeWithFuncTypes(rec, tbl, graph, buildCfg.FuncTypes) + return nested.EnrichTableTypeWithFunctionLookup(rec, tbl, graph, buildCfg.FunctionFacts.FunctionType) } return nil }, diff --git a/compiler/check/infer/nested/processor_test.go b/compiler/check/infer/nested/processor_test.go index c343fbba..e1545e5f 100644 --- a/compiler/check/infer/nested/processor_test.go +++ b/compiler/check/infer/nested/processor_test.go @@ -13,5 +13,5 @@ func TestProcessNestedFunctions_NilResult(t *testing.T) { func TestProcessNestedFunctions_NilScopes(t *testing.T) { p := New(Config{}) - p.ProcessNestedFunctions(nil, &api.FuncResultView{}) + p.ProcessNestedFunctions(nil, &api.FuncResultSnapshot{}) } diff --git a/compiler/check/infer/paramhints/param_hints.go b/compiler/check/infer/paramhints/param_hints.go index cdbf1879..7097f781 100644 --- a/compiler/check/infer/paramhints/param_hints.go +++ b/compiler/check/infer/paramhints/param_hints.go @@ -299,11 +299,11 @@ func isInformativeHintType(t typ.Type, guard internal.RecursionGuard) bool { return true } -// BuildParamHintSigView builds a function-expression keyed hint map for this graph. +// BuildParamHintSignatures builds a function-expression keyed hint map for this graph. // It merges per-iteration scratch hints with symbol-based hints from the store. // Scratch hints take precedence over symbol-derived hints. 
-func BuildParamHintSigView( - store api.StoreView, +func BuildParamHintSignatures( + store api.StoreReader, graph *cfg.Graph, parent *scope.State, stdlib *scope.State, @@ -348,11 +348,11 @@ func BuildParamHintSigView( if meta, ok := store.NestedMetaFor(graph.ID()); ok { parentGraph := store.Graphs()[meta.ParentGraphID] if parentGraph != nil { - fallback := (*scope.State)(nil) + defaultScope := (*scope.State)(nil) if _, isNestedParent := store.NestedMetaFor(parentGraph.ID()); !isNestedParent { - fallback = stdlib + defaultScope = stdlib } - parentScope := api.ParentScopeForGraph(store, parentGraph.ID(), fallback) + parentScope := api.ParentScopeForGraph(store, parentGraph.ID(), defaultScope) if parentScope != nil { parentHints := store.GetParamHintsSnapshot(parentGraph, parentScope) if len(parentHints) > 0 { diff --git a/compiler/check/infer/paramhints/param_hints_test.go b/compiler/check/infer/paramhints/param_hints_test.go index 62fd6e82..f710ab35 100644 --- a/compiler/check/infer/paramhints/param_hints_test.go +++ b/compiler/check/infer/paramhints/param_hints_test.go @@ -111,8 +111,8 @@ func TestWidenParamHintType_RecordBecomesOpen(t *testing.T) { } } -func TestBuildParamHintSigView_NilInputs(t *testing.T) { - result := BuildParamHintSigView(nil, nil, nil, nil) +func TestBuildParamHintSignatures_NilInputs(t *testing.T) { + result := BuildParamHintSignatures(nil, nil, nil, nil) if result != nil { t.Errorf("expected nil for nil inputs, got %v", result) } diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index 06340f42..a4d38553 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -1,6 +1,6 @@ // infer.go implements return type inference for local functions. // This runs as a pre-phase before the main analysis pipeline to ensure -// return summaries are available when the parent function is analyzed. +// return vectors are available when the parent function is analyzed. 
// // # RETURN TYPE INFERENCE // @@ -31,7 +31,7 @@ // // # SEED PROPAGATION // -// Return summaries are seeded from the previous fixpoint iteration: +// Return vectors are seeded from the previous fixpoint iteration: // - Seeds provide initial return type estimates // - Iteration refines seeds using actual function body analysis // - Convergence occurs when seeds stabilize across iterations @@ -64,19 +64,19 @@ type Config struct { GlobalTypes map[string]typ.Type Manifests io.ManifestQuerier Stdlib *scope.State - Store api.StoreView + Store api.StoreReader Graphs api.GraphProvider SourceName string MaxIterations int } -// Inferencer computes pre-flow return summaries for local functions. +// Inferencer computes pre-flow return vectors for local functions. type Inferencer struct { types core.TypeOps globalTypes map[string]typ.Type manifests io.ManifestQuerier stdlib *scope.State - store api.StoreView + store api.StoreReader graphs api.GraphProvider sourceName string maxIterations int @@ -170,13 +170,13 @@ func (i *Inferencer) collectLocalFunctions( } // newReturnInferenceEngine creates a synthesis engine configured for return type -// inference within the pre-flow return summary computation phase. +// inference within the pre-flow return-vector computation phase. // // The engine operates in PhaseScopeCompute mode with: // - Declared types from the overlay (params, siblings, captured variables) // - Global types for built-in function resolution // - Module aliases for require() resolution -// - Return summaries from previous iteration for recursive call resolution +// - Return vectors from previous iteration for recursive call resolution // // Unlike the main analysis engine, this engine does not have access to flow // solution or narrowed types, producing "declared-phase" type estimates. 
@@ -197,15 +197,14 @@ func (i *Inferencer) newReturnInferenceEngine( }) } -// computeReturnSummariesForGraph computes return summaries for local functions in a graph -// and stores them into the interproc facts for the current iteration. +// ComputeForGraph computes canonical function facts for local functions in a graph. func (i *Inferencer) ComputeForGraph( run RunContext, graph *cfg.Graph, parent *scope.State, -) (api.ReturnSummaries, api.FuncTypes, []diag.Diagnostic) { +) (api.FunctionFacts, []diag.Diagnostic) { if i == nil || i.store == nil || graph == nil || parent == nil { - return nil, nil, nil + return nil, nil } parentScope := api.ParentScopeForGraph(i.store, graph.ID(), parent) @@ -214,7 +213,7 @@ func (i *Inferencer) ComputeForGraph( pointScopes := scope.BuildTypeDefScopes(graph, parentScope, engine.ResolveTypeDef) localFuncs := i.collectLocalFunctions(graph, pointScopes, graph.Func()) if len(localFuncs) == 0 { - return nil, nil, nil + return nil, nil } // Apply param hints from the stable snapshot (deterministic order). 
@@ -230,22 +229,28 @@ func (i *Inferencer) ComputeForGraph( } } - seed := i.store.GetReturnSummariesSnapshot(graph, parentScope) - summaries, diags := i.computeReturnSummariesForGroup(run, parentScope.GroupHash(), localFuncs, seed) - funcTypes := i.buildLocalFuncTypes(localFuncs, summaries, engine, parentScope) - return summaries, funcTypes, diags + seedFacts := i.store.GetFunctionFactsSnapshot(graph, parentScope) + seed := make(map[cfg.SymbolID][]typ.Type, len(seedFacts)) + for sym, fact := range seedFacts { + if len(fact.Summary) > 0 { + seed[sym] = fact.Summary + } + } + returnVectors, diags := i.computeReturnVectorsForGroup(run, parentScope.GroupHash(), localFuncs, seed) + functionTypes := i.buildLocalFunctionTypes(localFuncs, returnVectors, engine, parentScope) + return assembleFunctionFacts(returnVectors, functionTypes), diags } -func (i *Inferencer) buildLocalFuncTypes( +func (i *Inferencer) buildLocalFunctionTypes( localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, engine *synth.Engine, parentScope *scope.State, -) api.FuncTypes { +) map[cfg.SymbolID]typ.Type { if len(localFuncs) == 0 { return nil } - out := make(api.FuncTypes, len(localFuncs)) + out := make(map[cfg.SymbolID]typ.Type, len(localFuncs)) for _, sym := range cfg.SortedSymbolIDs(localFuncs) { info := localFuncs[sym] if info == nil || info.Fn == nil { @@ -275,8 +280,8 @@ func (i *Inferencer) buildLocalFuncTypes( fnType = merged } } - if summary := summaries[sym]; len(summary) > 0 { - if withSummary := returns.WithSummaryOrUnknown(fnType, summary); withSummary != nil { + if returnVector := returnVectors[sym]; len(returnVector) > 0 { + if withSummary := returns.WithSummaryOrUnknown(fnType, returnVector); withSummary != nil { fnType = withSummary } } @@ -288,7 +293,46 @@ func (i *Inferencer) buildLocalFuncTypes( return out } -// computeReturnSummariesForGroup computes return type summaries for a scope group +func 
assembleFunctionFacts( + returnVectors map[cfg.SymbolID][]typ.Type, + funcs map[cfg.SymbolID]typ.Type, +) api.FunctionFacts { + total := len(returnVectors) + len(funcs) + if total == 0 { + return nil + } + symbols := make(map[cfg.SymbolID]bool, total) + for sym := range returnVectors { + if sym != 0 { + symbols[sym] = true + } + } + for sym := range funcs { + if sym != 0 { + symbols[sym] = true + } + } + if len(symbols) == 0 { + return nil + } + out := make(api.FunctionFacts, len(symbols)) + for _, sym := range cfg.SortedSymbolIDs(symbols) { + ff := returns.JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ + Summary: returnVectors[sym], + Type: funcs[sym], + }) + if len(ff.Summary) == 0 && ff.Type == nil && len(ff.Narrow) == 0 { + continue + } + out[sym] = ff + } + if len(out) == 0 { + return nil + } + return out +} + +// computeReturnVectorsForGroup computes return type vectors for a scope group // using strongly connected component (SCC) based fixpoint iteration. // // SCC ORDERING: Functions are partitioned into SCCs by their call graph. SCCs are @@ -306,7 +350,7 @@ func (i *Inferencer) buildLocalFuncTypes( // // SEEDING: Initial return type estimates come from the seed map (previous fixpoint // iteration). This accelerates convergence for iteratively-refined modules. -func (i *Inferencer) computeReturnSummariesForGroup( +func (i *Inferencer) computeReturnVectorsForGroup( run RunContext, groupHash uint64, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, @@ -322,15 +366,15 @@ func (i *Inferencer) computeReturnSummariesForGroup( return nil, nil } - summaries := seedSummariesFromSeed(localFuncs, seed) - return summaries, i.processSCCSummaries(run, sccs, localFuncs, summaries) + returnVectors := seedReturnVectorsFromSeed(localFuncs, seed) + return returnVectors, i.processSCCReturnVectors(run, sccs, localFuncs, returnVectors) } // returnInferenceContext holds shared state for return type inference phases. 
type returnInferenceContext struct { run RunContext info *returns.LocalFuncInfo - summaries map[cfg.SymbolID][]typ.Type + returnVectors map[cfg.SymbolID][]typ.Type localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo engine *synth.Engine resolveScope *scope.State @@ -368,7 +412,7 @@ func collectReturnTypes( returnTypes = joinReturnTypes(returnTypes, types) }) - return returns.NormalizeReturnVector(returnTypes) + return returns.NormalizeReturnVectorInPlace(returnTypes) } // synthesizeReturnExprs computes types for a single return statement's expressions. @@ -381,9 +425,9 @@ func synthesizeReturnExprs( return nil } - var types []typ.Type + types := make([]typ.Type, 0, len(retInfo.Exprs)) for i, expr := range retInfo.Exprs { - if i == len(retInfo.Exprs)-1 { + if i == len(retInfo.Exprs)-1 && ast.CanProduceMultipleValues(expr) { multi := synthEngine.MultiTypeOf(expr, p) if len(multi) == 0 { multi = []typ.Type{typ.Unknown} @@ -435,15 +479,15 @@ func (i *Inferencer) inferReturnTypesFromBody( if fnGraph == nil { return narrowed } - phaseReturnSummaries := summarizeWithoutCurrent(ctx.summaries, ctx.info) + phaseFunctionFacts := functionFactsExcludingCurrent(ctx.returnVectors, ctx.info) declCheckCtx := api.NewReturnInferenceEnv(api.ReturnInferenceEnvConfig{ - Graph: fnGraph, - Bindings: ctx.bindings, - BaseScope: ctx.resolveScope, - DeclaredTypes: finalOverlay, - GlobalTypes: i.globalTypes, - ModuleAliases: ctx.moduleAliases, - ReturnSummaries: phaseReturnSummaries, + Graph: fnGraph, + Bindings: ctx.bindings, + BaseScope: ctx.resolveScope, + DeclaredTypes: finalOverlay, + GlobalTypes: i.globalTypes, + ModuleAliases: ctx.moduleAliases, + FunctionFacts: phaseFunctionFacts, }) declSynth := i.newReturnInferenceEngine( ctx.run, @@ -455,15 +499,16 @@ func (i *Inferencer) inferReturnTypesFromBody( return returns.MergeReturnSummary(declared, narrowed) } -// inferReturnWithSummary infers return types for a single function using available summaries. 
-// This is the core inference logic called by computeReturnSummariesForGroup for each function. +// inferReturnForFunction infers return types for one local function from the +// current SCC return-vector state. +// This is the core inference logic called by computeReturnVectorsForGroup for each function. // // TWO-PHASE INFERENCE: // // Phase 1 (Preliminary): Collect inferred types for local variables within the function. // This uses a preliminary synthesis engine with: // - Parameter types (from annotations or param hints) -// - Sibling function types (from summaries) +// - Sibling function types (from return vectors) // - Captured variable types (from parent function result) // // Phase 2 (Final): Compute return types using enriched overlay containing: @@ -477,10 +522,10 @@ func (i *Inferencer) inferReturnTypesFromBody( // // MULTI-RETURN: Functions may return multiple values. The inference handles multi-return // by expanding the last expression (which may be a call or vararg) and joining position-wise. -func (i *Inferencer) inferReturnWithSummary( +func (i *Inferencer) inferReturnForFunction( run RunContext, info *returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, ) []typ.Type { if info == nil || info.Fn == nil || info.Graph == nil { @@ -525,7 +570,7 @@ func (i *Inferencer) inferReturnWithSummary( ctx := &returnInferenceContext{ run: run, info: info, - summaries: summaries, + returnVectors: returnVectors, localFuncs: localFuncs, engine: engine, resolveScope: resolveScope, @@ -537,12 +582,12 @@ func (i *Inferencer) inferReturnWithSummary( // Build type overlay with parameter types. overlay := i.buildParameterOverlay(ctx) - // Add sibling function types from summaries. + // Add sibling function types from return vectors. i.enrichOverlayWithSiblings(ctx, overlay) - // Collect all return summaries and add local function types. 
- allSummaries := i.collectAllReturnSummaries(ctx) - i.enrichOverlayWithLocalFunctions(ctx, overlay, allSummaries) + // Collect normalized return vectors and add local function types. + allReturnVectors := i.collectAllReturnVectors(ctx) + i.enrichOverlayWithLocalFunctions(ctx, overlay, allReturnVectors) // Add captured variable types from parent. i.enrichOverlayWithCaptured(ctx, overlay) diff --git a/compiler/check/infer/return/infer_test.go b/compiler/check/infer/return/infer_test.go index e06faf85..da5f0436 100644 --- a/compiler/check/infer/return/infer_test.go +++ b/compiler/check/infer/return/infer_test.go @@ -10,21 +10,18 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestComputeReturnSummariesForGraph_Empty(t *testing.T) { +func TestComputeFunctionFactsForGraph_Empty(t *testing.T) { inferencer := New(Config{}) - summaries, funcTypes, diags := inferencer.ComputeForGraph(RunContext{}, nil, nil) - if summaries != nil { - t.Error("nil graph should return nil summaries") - } - if funcTypes != nil { - t.Error("nil graph should return nil function types") + functionFacts, diags := inferencer.ComputeForGraph(RunContext{}, nil, nil) + if functionFacts != nil { + t.Error("nil graph should return nil function facts") } if len(diags) != 0 { t.Error("nil graph should return no diagnostics") } } -func TestSeedSummariesFromSeed_UsesKnownFunctionSymbolsOnly(t *testing.T) { +func TestSeedReturnVectorsFromSeed_UsesKnownFunctionSymbolsOnly(t *testing.T) { localFuncs := map[cfg.SymbolID]*returns.LocalFuncInfo{ 1: nil, 2: nil, @@ -34,28 +31,28 @@ func TestSeedSummariesFromSeed_UsesKnownFunctionSymbolsOnly(t *testing.T) { 3: {typ.Number}, // not in local funcs; should be ignored } - got := seedSummariesFromSeed(localFuncs, seed) + got := seedReturnVectorsFromSeed(localFuncs, seed) if len(got) != 1 { - t.Fatalf("expected one seeded summary, got %d", len(got)) + t.Fatalf("expected one seeded return vector, got %d", len(got)) } if seeded := got[1]; len(seeded) != 1 || 
!typ.TypeEquals(seeded[0], typ.String) { - t.Fatalf("unexpected seeded summary for symbol 1: %v", seeded) + t.Fatalf("unexpected seeded return vector for symbol 1: %v", seeded) } if _, ok := got[3]; ok { t.Fatalf("unexpected seed for unknown symbol 3: %v", got[3]) } } -func TestSeedSummariesFromSeed_HandlesNilSeed(t *testing.T) { +func TestSeedReturnVectorsFromSeed_HandlesNilSeed(t *testing.T) { localFuncs := map[cfg.SymbolID]*returns.LocalFuncInfo{ 1: nil, } - got := seedSummariesFromSeed(localFuncs, nil) + got := seedReturnVectorsFromSeed(localFuncs, nil) if got == nil { - t.Fatal("expected non-nil summary map") + t.Fatal("expected non-nil return-vector map") } if len(got) != 0 { - t.Fatalf("expected empty summary map, got %v", got) + t.Fatalf("expected empty return-vector map, got %v", got) } } @@ -141,50 +138,50 @@ func TestReconcileSoftAnnotatedInference_RecordTemplateKeepsFields(t *testing.T) } } -func TestCollectAllReturnSummaries_NormalizesAndFilters(t *testing.T) { +func TestCollectAllReturnVectors_NormalizesAndFilters(t *testing.T) { inferencer := New(Config{}) ctx := &returnInferenceContext{ - summaries: map[cfg.SymbolID][]typ.Type{ + returnVectors: map[cfg.SymbolID][]typ.Type{ 0: {typ.String}, // invalid symbol id, ignored - 1: nil, // empty summary, ignored + 1: nil, // empty return vector, ignored 2: {nil, typ.String}, }, } - got := inferencer.collectAllReturnSummaries(ctx) + got := inferencer.collectAllReturnVectors(ctx) if len(got) != 1 { - t.Fatalf("expected one normalized summary, got %d (%v)", len(got), got) + t.Fatalf("expected one normalized return vector, got %d (%v)", len(got), got) } - summary := got[2] - if len(summary) != 2 { - t.Fatalf("expected 2-slot summary, got %v", summary) + returnVector := got[2] + if len(returnVector) != 2 { + t.Fatalf("expected 2-slot return vector, got %v", returnVector) } - if !typ.TypeEquals(summary[0], typ.Nil) { - t.Fatalf("expected first slot normalized to nil, got %v", summary[0]) + if 
!typ.TypeEquals(returnVector[0], typ.Nil) { + t.Fatalf("expected first slot normalized to nil, got %v", returnVector[0]) } - if !typ.TypeEquals(summary[1], typ.String) { - t.Fatalf("expected second slot string, got %v", summary[1]) + if !typ.TypeEquals(returnVector[1], typ.String) { + t.Fatalf("expected second slot string, got %v", returnVector[1]) } } -func TestResolveLocalFunctionSummary_UsesCurrentSummaryWithoutStore(t *testing.T) { +func TestResolveLocalFunctionReturns_UsesCurrentVectorWithoutStore(t *testing.T) { inferencer := New(Config{}) - got := inferencer.resolveLocalFunctionSummary(nil, map[cfg.SymbolID][]typ.Type{ + got := inferencer.resolveLocalFunctionReturns(nil, map[cfg.SymbolID][]typ.Type{ 1: {typ.String}, }, 1) if len(got) != 1 || !typ.TypeEquals(got[0], typ.String) { - t.Fatalf("expected string summary, got %v", got) + t.Fatalf("expected string return vector, got %v", got) } - unknownOnly := inferencer.resolveLocalFunctionSummary(nil, map[cfg.SymbolID][]typ.Type{ + unknownOnly := inferencer.resolveLocalFunctionReturns(nil, map[cfg.SymbolID][]typ.Type{ 1: {typ.Unknown}, }, 1) if len(unknownOnly) != 1 || !typ.TypeEquals(unknownOnly[0], typ.Unknown) { - t.Fatalf("expected unknown summary without store fallback, got %v", unknownOnly) + t.Fatalf("expected unknown return vector without store recovery, got %v", unknownOnly) } - if got := inferencer.resolveLocalFunctionSummary(nil, nil, 0); got != nil { - t.Fatalf("expected nil summary for symbol 0, got %v", got) + if got := inferencer.resolveLocalFunctionReturns(nil, nil, 0); got != nil { + t.Fatalf("expected nil return vector for symbol 0, got %v", got) } } diff --git a/compiler/check/infer/return/overlay_pipeline.go b/compiler/check/infer/return/overlay_pipeline.go index 2f0cfdbe..50a5d745 100644 --- a/compiler/check/infer/return/overlay_pipeline.go +++ b/compiler/check/infer/return/overlay_pipeline.go @@ -19,9 +19,10 @@ import ( ) func (i *Inferencer) buildParameterOverlay(ctx 
*returnInferenceContext) map[cfg.SymbolID]typ.Type { - overlay := make(map[cfg.SymbolID]typ.Type) fnGraph := ctx.info.Graph - for _, slot := range fnGraph.ParamSlotsReadOnly() { + paramSlots := fnGraph.ParamSlotsReadOnly() + overlay := make(map[cfg.SymbolID]typ.Type, overlaySymbolCapacity(fnGraph, len(paramSlots))) + for _, slot := range paramSlots { if slot.Symbol == 0 { continue } @@ -66,7 +67,18 @@ func (i *Inferencer) buildParameterOverlay(ctx *returnInferenceContext) map[cfg. return overlay } -// enrichOverlayWithSiblings adds sibling function types to the overlay using summaries. +func overlaySymbolCapacity(fnGraph *cfg.Graph, floor int) int { + if fnGraph == nil { + return floor + } + if count := fnGraph.SymbolCount(); count > floor { + return count + } + return floor +} + +// enrichOverlayWithSiblings adds sibling function types from the current +// return-vector state. func (i *Inferencer) enrichOverlayWithSiblings( ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type, @@ -82,9 +94,9 @@ func (i *Inferencer) enrichOverlayWithSiblings( } } siblingOverlay := siblings.BuildOverlay(siblings.OverlayConfig{ - Summaries: ctx.summaries, - Siblings: siblingEntries, - CurrentSym: ctx.info.Sym, + ReturnVectors: ctx.returnVectors, + Siblings: siblingEntries, + CurrentSym: ctx.info.Sym, Services: siblings.OverlayServicesFuncs{ SeedTypeFn: func(fn *ast.FunctionExpr) typ.Type { var bindings interface { @@ -105,26 +117,26 @@ func (i *Inferencer) enrichOverlayWithSiblings( } } -// collectAllReturnSummaries normalizes the current local summary map. -func (i *Inferencer) collectAllReturnSummaries(ctx *returnInferenceContext) map[cfg.SymbolID][]typ.Type { - if ctx == nil || len(ctx.summaries) == 0 { +// collectAllReturnVectors normalizes the current local return-vector map. 
+func (i *Inferencer) collectAllReturnVectors(ctx *returnInferenceContext) map[cfg.SymbolID][]typ.Type { + if ctx == nil || len(ctx.returnVectors) == 0 { return nil } - allSummaries := make(map[cfg.SymbolID][]typ.Type, len(ctx.summaries)) - for _, sym := range cfg.SortedSymbolIDs(ctx.summaries) { + allReturnVectors := make(map[cfg.SymbolID][]typ.Type, len(ctx.returnVectors)) + for _, sym := range cfg.SortedSymbolIDs(ctx.returnVectors) { if sym == 0 { continue } - normalized := returns.NormalizeReturnVector(ctx.summaries[sym]) + normalized := returns.NormalizeReturnVectorInPlace(ctx.returnVectors[sym]) if len(normalized) == 0 { continue } - allSummaries[sym] = normalized + allReturnVectors[sym] = normalized } - return allSummaries + return allReturnVectors } -func (i *Inferencer) summaryFromSnapshot( +func (i *Inferencer) returnVectorFromSnapshot( graph *cfg.Graph, parentScope *scope.State, sym cfg.SymbolID, @@ -132,67 +144,67 @@ func (i *Inferencer) summaryFromSnapshot( if i == nil || i.store == nil || graph == nil || parentScope == nil || sym == 0 { return nil } - snap := i.store.GetReturnSummariesSnapshot(graph, parentScope) - if len(snap) == 0 { + facts := i.store.GetFunctionFactsSnapshot(graph, parentScope) + if len(facts) == 0 { return nil } - normalized := returns.NormalizeReturnVector(snap[sym]) + normalized := returns.NormalizeReturnVector(facts.Summary(sym)) if len(normalized) == 0 { return nil } return normalized } -func (i *Inferencer) resolveLocalFunctionSummary( +func (i *Inferencer) resolveLocalFunctionReturns( ctx *returnInferenceContext, - allSummaries map[cfg.SymbolID][]typ.Type, + allReturnVectors map[cfg.SymbolID][]typ.Type, sym cfg.SymbolID, ) []typ.Type { if sym == 0 { return nil } - // Keep the current SCC-derived summary unless it is still unknown-only. 
- summary := returns.NormalizeReturnVector(allSummaries[sym]) - if !typ.IsUnknownOnlyOrEmpty(summary) { - return summary + // Keep the current SCC-derived return vector unless it is still unknown-only. + returnVector := returns.NormalizeReturnVectorInPlace(allReturnVectors[sym]) + if !typ.IsUnknownOnlyOrEmpty(returnVector) { + return returnVector } if ctx == nil || i == nil || i.store == nil { - return summary + return returnVector } - // First fallback: current graph snapshot under the current resolve scope. + // Snapshot recovery path: current graph under the current resolve scope. if ctx.info != nil && ctx.info.Graph != nil && ctx.resolveScope != nil { - if snapSummary := i.summaryFromSnapshot(ctx.info.Graph, ctx.resolveScope, sym); len(snapSummary) > 0 { - return snapSummary + if snapVector := i.returnVectorFromSnapshot(ctx.info.Graph, ctx.resolveScope, sym); len(snapVector) > 0 { + return snapVector } } - // Second fallback: parent graph snapshot for the function symbol, if known. + // Snapshot recovery path: parent graph for the function symbol, if known. ref := i.store.FunctionRefBySym(sym) if ref == nil || ref.ParentGraphID == 0 { - return summary + return returnVector } parentGraph := i.store.Graphs()[ref.ParentGraphID] if parentGraph == nil { - return summary + return returnVector } parentScope := api.ParentScopeForGraph(i.store, parentGraph.ID(), nil) if parentScope == nil { - return summary + return returnVector } - if snapSummary := i.summaryFromSnapshot(parentGraph, parentScope, sym); len(snapSummary) > 0 { - return snapSummary + if snapVector := i.returnVectorFromSnapshot(parentGraph, parentScope, sym); len(snapVector) > 0 { + return snapVector } - return summary + return returnVector } // enrichOverlayWithLocalFunctions adds local function types from the function body. 
func (i *Inferencer) enrichOverlayWithLocalFunctions( ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type, - allSummaries map[cfg.SymbolID][]typ.Type, + allReturnVectors map[cfg.SymbolID][]typ.Type, ) { ctx.info.Graph.EachAssign(func(p cfg.Point, assignInfo *cfg.AssignInfo) { if assignInfo == nil || !assignInfo.IsLocal || len(assignInfo.Targets) == 0 || len(assignInfo.Sources) == 0 { @@ -217,9 +229,9 @@ func (i *Inferencer) enrichOverlayWithLocalFunctions( i.store.RegisterFunctionRef(target.Symbol, fnExpr, fnGraph, ctx.info.Graph.ID(), p) } } - summary := i.resolveLocalFunctionSummary(ctx, allSummaries, target.Symbol) + returnVector := i.resolveLocalFunctionReturns(ctx, allReturnVectors, target.Symbol) sig := ctx.engine.ResolveFunctionSignature(fnExpr, ctx.resolveScope) - if fnType := returns.WithSummaryOrUnknown(sig, summary); fnType != nil { + if fnType := returns.WithSummaryOrUnknown(sig, returnVector); fnType != nil { overlay[target.Symbol] = fnType } } @@ -311,12 +323,7 @@ func (i *Inferencer) inferLocalVariableTypes( ) (map[cfg.SymbolID]typ.Type, *synth.Engine, func(ast.Expr, cfg.Point) typ.Type) { fnGraph := ctx.info.Graph annotated := make(map[cfg.SymbolID]bool, len(overlay)) - paramSet := make(map[cfg.SymbolID]bool) - for _, sym := range fnGraph.ParamSymbols() { - if sym != 0 { - paramSet[sym] = true - } - } + paramSet := paramSymbolSet(fnGraph) for sym, tp := range overlay { if paramSet[sym] { annotated[sym] = true @@ -352,13 +359,13 @@ func (i *Inferencer) inferLocalVariableTypes( fnScopes := uniformFunctionScopes(fnGraph, ctx.resolveScope) prelimCtx := api.NewReturnInferenceEnv(api.ReturnInferenceEnvConfig{ - Graph: fnGraph, - Bindings: ctx.bindings, - BaseScope: ctx.resolveScope, - DeclaredTypes: overlay, - GlobalTypes: i.globalTypes, - ModuleAliases: ctx.moduleAliases, - ReturnSummaries: ctx.summaries, + Graph: fnGraph, + Bindings: ctx.bindings, + BaseScope: ctx.resolveScope, + DeclaredTypes: overlay, + GlobalTypes: i.globalTypes, + 
ModuleAliases: ctx.moduleAliases, + FunctionFacts: functionFactsFromReturnVectors(ctx.returnVectors), }) prelimEngine := i.newReturnInferenceEngine(ctx.run, fnScopes, prelimCtx) @@ -513,13 +520,17 @@ func newOverlayMutationStage( } func paramSymbolSet(graph *cfg.Graph) map[cfg.SymbolID]bool { - out := make(map[cfg.SymbolID]bool) if graph == nil { - return out + return nil + } + paramSlots := graph.ParamSlotsReadOnly() + if len(paramSlots) == 0 { + return nil } - for _, sym := range graph.ParamSymbols() { - if sym != 0 { - out[sym] = true + out := make(map[cfg.SymbolID]bool, len(paramSlots)) + for _, slot := range paramSlots { + if slot.Symbol != 0 { + out[slot.Symbol] = true } } return out @@ -749,7 +760,7 @@ type phase2InferenceState struct { } // runPhase2FlowNarrowing executes extract->solve->narrow over the final overlay. -// This makes return summary collection path-sensitive instead of declared-only. +// This makes return collection path-sensitive instead of declared-only. func (i *Inferencer) runPhase2FlowNarrowing( ctx *returnInferenceContext, finalOverlay map[cfg.SymbolID]typ.Type, @@ -781,13 +792,13 @@ func (i *Inferencer) runPhase2FlowNarrowing( return ctx.engine.ResolveFunctionSignature(fn, sc) }), } - phaseReturnSummaries := summarizeWithoutCurrent(ctx.summaries, ctx.info) + phaseFunctionFacts := functionFactsExcludingCurrent(ctx.returnVectors, ctx.info) extractOut := phase.RunExtract(phase.FlowExtractInput{ - PhaseEnv: phaseEnv, - Resolve: phase.ResolveOutput{TypeResolver: ctx.engine}, - Scope: scopeOut, - ReturnSummaries: phaseReturnSummaries, + PhaseEnv: phaseEnv, + Resolve: phase.ResolveOutput{TypeResolver: ctx.engine}, + Scope: scopeOut, + FunctionFacts: phaseFunctionFacts, }) if extractOut.Inputs == nil { return phase2InferenceState{} @@ -800,11 +811,11 @@ func (i *Inferencer) runPhase2FlowNarrowing( }) narrowOut := phase.RunNarrow(phase.NarrowInput{ - PhaseEnv: phaseEnv, - Scope: scopeOut, - Extract: extractOut, - Solve: solveOut, - 
NarrowReturnSummaries: phaseReturnSummaries, + PhaseEnv: phaseEnv, + Scope: scopeOut, + Extract: extractOut, + Solve: solveOut, + FunctionFacts: phaseFunctionFacts, }) deadPoints := map[cfg.Point]bool{} @@ -823,15 +834,15 @@ func (i *Inferencer) runPhase2FlowNarrowing( } } - // Fallback: declared-phase synth (should be uncommon, e.g. nil solution path). + // Declared-phase recomputation path for uncommon nil-solution states. fnCheckCtx := api.NewReturnInferenceEnv(api.ReturnInferenceEnvConfig{ - Graph: fnGraph, - Bindings: ctx.bindings, - BaseScope: ctx.resolveScope, - DeclaredTypes: finalOverlay, - GlobalTypes: i.globalTypes, - ModuleAliases: ctx.moduleAliases, - ReturnSummaries: phaseReturnSummaries, + Graph: fnGraph, + Bindings: ctx.bindings, + BaseScope: ctx.resolveScope, + DeclaredTypes: finalOverlay, + GlobalTypes: i.globalTypes, + ModuleAliases: ctx.moduleAliases, + FunctionFacts: phaseFunctionFacts, }) return phase2InferenceState{ synth: i.newReturnInferenceEngine(ctx.run, fnScopes, fnCheckCtx), @@ -839,22 +850,43 @@ func (i *Inferencer) runPhase2FlowNarrowing( } } -func summarizeWithoutCurrent( - summaries map[cfg.SymbolID][]typ.Type, +func functionFactsExcludingCurrent( + returnVectors map[cfg.SymbolID][]typ.Type, info *returns.LocalFuncInfo, -) map[cfg.SymbolID][]typ.Type { - if len(summaries) == 0 || info == nil || info.Sym == 0 { - return summaries +) api.FunctionFacts { + if len(returnVectors) == 0 || info == nil || info.Sym == 0 { + return functionFactsFromReturnVectors(returnVectors) } - if _, ok := summaries[info.Sym]; !ok { - return summaries + if _, ok := returnVectors[info.Sym]; !ok { + return functionFactsFromReturnVectors(returnVectors) } - out := make(map[cfg.SymbolID][]typ.Type, len(summaries)-1) - for _, sym := range cfg.SortedSymbolIDs(summaries) { + out := make(map[cfg.SymbolID][]typ.Type, len(returnVectors)-1) + for _, sym := range cfg.SortedSymbolIDs(returnVectors) { if sym == info.Sym { continue } - out[sym] = summaries[sym] + out[sym] 
= returnVectors[sym] + } + return functionFactsFromReturnVectors(out) +} + +func functionFactsFromReturnVectors(returnVectors map[cfg.SymbolID][]typ.Type) api.FunctionFacts { + if len(returnVectors) == 0 { + return nil + } + out := make(api.FunctionFacts, len(returnVectors)) + for _, sym := range cfg.SortedSymbolIDs(returnVectors) { + if sym == 0 { + continue + } + returnVector := returns.NormalizeReturnVectorInPlace(returnVectors[sym]) + if len(returnVector) == 0 { + continue + } + out[sym] = api.FunctionFact{Summary: returnVector} + } + if len(out) == 0 { + return nil } return out } diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go index 7f4aa8b3..0689a60e 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -15,11 +15,11 @@ func (i *Inferencer) iterateSCCFixpoint( run RunContext, scc []cfg.SymbolID, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, ) bool { for iter := 0; iter < i.maxIterations; iter++ { - next, changed := i.runSCCIteration(run, scc, localFuncs, summaries) - applySCCIterationUpdates(summaries, scc, next) + next, changed := i.runSCCIteration(run, scc, localFuncs, returnVectors) + applySCCIterationUpdates(returnVectors, scc, next) if !changed { return true } @@ -40,37 +40,37 @@ func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns. 
return returns.ComputeSymbolSCCs(adj) } -func seedSummariesFromSeed( +func seedReturnVectorsFromSeed( localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, seed map[cfg.SymbolID][]typ.Type, ) map[cfg.SymbolID][]typ.Type { - summaries := make(map[cfg.SymbolID][]typ.Type, len(localFuncs)) + returnVectors := make(map[cfg.SymbolID][]typ.Type, len(localFuncs)) if seed == nil { - return summaries + return returnVectors } for _, sym := range cfg.SortedSymbolIDs(localFuncs) { if seeded := seed[sym]; len(seeded) > 0 { - summaries[sym] = seeded + returnVectors[sym] = seeded } } - return summaries + return returnVectors } -func (i *Inferencer) processSCCSummaries( +func (i *Inferencer) processSCCReturnVectors( run RunContext, sccs [][]cfg.SymbolID, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, ) []diag.Diagnostic { var diags []diag.Diagnostic for _, scc := range sccs { if len(scc) == 0 { continue } - if i.iterateSCCFixpoint(run, scc, localFuncs, summaries) { + if i.iterateSCCFixpoint(run, scc, localFuncs, returnVectors) { continue } - if warn := i.widenSCCToUnknown(scc, localFuncs, summaries); warn != nil { + if warn := i.widenSCCToUnknown(scc, localFuncs, returnVectors); warn != nil { diags = append(diags, *warn) } } @@ -81,7 +81,7 @@ func (i *Inferencer) runSCCIteration( run RunContext, scc []cfg.SymbolID, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, ) (map[cfg.SymbolID][]typ.Type, bool) { changed := false next := make(map[cfg.SymbolID][]typ.Type, len(scc)) @@ -90,8 +90,8 @@ func (i *Inferencer) runSCCIteration( if info == nil || info.Fn == nil { continue } - newReturn := i.inferReturnWithSummary(run, info, summaries, localFuncs) - oldReturn := summaries[sym] + newReturn := i.inferReturnForFunction(run, info, returnVectors, localFuncs) + oldReturn := returnVectors[sym] merged := 
returns.MergeReturnSummary(oldReturn, newReturn) next[sym] = merged if !returns.ReturnTypesEqual(merged, oldReturn) { @@ -102,13 +102,13 @@ func (i *Inferencer) runSCCIteration( } func applySCCIterationUpdates( - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, scc []cfg.SymbolID, next map[cfg.SymbolID][]typ.Type, ) { for _, sym := range scc { if v, ok := next[sym]; ok { - summaries[sym] = v + returnVectors[sym] = v } } } @@ -118,18 +118,18 @@ func applySCCIterationUpdates( func (i *Inferencer) widenSCCToUnknown( scc []cfg.SymbolID, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - summaries map[cfg.SymbolID][]typ.Type, + returnVectors map[cfg.SymbolID][]typ.Type, ) *diag.Diagnostic { for _, sym := range scc { - existing := summaries[sym] + existing := returnVectors[sym] if len(existing) == 0 { - summaries[sym] = []typ.Type{typ.Unknown} + returnVectors[sym] = []typ.Type{typ.Unknown} } else { widened := make([]typ.Type, len(existing)) for i := range widened { widened[i] = typ.Unknown } - summaries[sym] = widened + returnVectors[sym] = widened } } if info := localFuncs[scc[0]]; info != nil && info.Fn != nil { diff --git a/compiler/check/nested/enrich.go b/compiler/check/nested/enrich.go index f6736612..a258fed9 100644 --- a/compiler/check/nested/enrich.go +++ b/compiler/check/nested/enrich.go @@ -19,20 +19,20 @@ import ( // (with inferred refinements and return types) may be more precise than the initially // synthesized type. These utilities replace placeholder types with literal sigs. -// EnrichTableTypeWithFuncTypes replaces method function types in a record -// with canonical function types derived from the interproc queries. +// EnrichTableTypeWithFunctionLookup replaces method function types in a record +// with function types resolved by symbol. // // For table literals with method fields, the initially synthesized record may // have function types without inferred return types. 
After analyzing the methods, // canonical function types are available per symbol. This function updates the // record with those more precise signatures. -func EnrichTableTypeWithFuncTypes( +func EnrichTableTypeWithFunctionLookup( rec *typ.Record, tableExpr *ast.TableExpr, graph *cfg.Graph, - funcTypes map[cfg.SymbolID]typ.Type, + lookup func(cfg.SymbolID) typ.Type, ) typ.Type { - if rec == nil || tableExpr == nil || graph == nil || len(funcTypes) == 0 { + if rec == nil || tableExpr == nil || graph == nil || lookup == nil { return rec } @@ -60,7 +60,7 @@ func EnrichTableTypeWithFuncTypes( } if bindings != nil { if sym, ok := bindings.FuncLitSymbol(fnExpr); ok { - if t := funcTypes[sym]; t != nil { + if t := lookup(sym); t != nil { fieldType = t modified = true } diff --git a/compiler/check/nested/enrich_test.go b/compiler/check/nested/enrich_test.go index fedd0499..b496afb4 100644 --- a/compiler/check/nested/enrich_test.go +++ b/compiler/check/nested/enrich_test.go @@ -11,15 +11,15 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestEnrichTableTypeWithFuncTypes_NilInputs(t *testing.T) { - result := EnrichTableTypeWithFuncTypes(nil, nil, nil, nil) +func TestEnrichTableTypeWithFunctionLookup_NilInputs(t *testing.T) { + result := EnrichTableTypeWithFunctionLookup(nil, nil, nil, nil) if rec, ok := result.(*typ.Record); !ok || rec != nil { t.Error("expected nil record for nil inputs") } } -func TestEnrichTableTypeWithFuncTypes_NilRecord(t *testing.T) { - result := EnrichTableTypeWithFuncTypes(nil, nil, &cfg.Graph{}, nil) +func TestEnrichTableTypeWithFunctionLookup_NilRecord(t *testing.T) { + result := EnrichTableTypeWithFunctionLookup(nil, nil, &cfg.Graph{}, nil) if rec, ok := result.(*typ.Record); !ok || rec != nil { t.Error("expected nil record for nil record input") } diff --git a/compiler/check/phase/flow.go b/compiler/check/phase/flow.go index 2df1db33..74de7e87 100644 --- a/compiler/check/phase/flow.go +++ b/compiler/check/phase/flow.go @@ -29,9 +29,8 
@@ func RunExtract(input FlowExtractInput) FlowExtractOutput { extractionCtx := NewContextBuilder(input.PhaseEnv). WithScope(input.Scope). - WithSiblingTypes(input.SiblingTypes). + WithFunctionFacts(input.FunctionFacts). WithLiteralTypes(input.LiteralTypes). - WithReturnSummaries(input.ReturnSummaries). BuildDeclared() engine := synth.New(synth.Config{ @@ -50,6 +49,7 @@ func RunExtract(input FlowExtractInput) FlowExtractOutput { Scopes: input.Scope.Scopes, CheckCtx: extractionCtx, CallCtx: input.Ctx, + Graphs: api.GraphsFrom(input.Ctx), TypeOps: input.Types, Base: input.Scope.BaseScope, Globals: input.GlobalTypes, @@ -59,7 +59,6 @@ func RunExtract(input FlowExtractInput) FlowExtractOutput { TypeExprResolver: typeResolverFn, }, InitialDeclaredTypes: input.Scope.DeclaredTypes, - SiblingTypes: input.SiblingTypes, LiteralTypes: input.LiteralTypes, ModuleAliases: moduleAliases, ModuleBindings: input.ModuleBindings, @@ -89,7 +88,7 @@ func applyModuleAliasTypes(inputs *flow.Inputs, manifests io.ManifestQuerier) { func RunLiteral(input LiteralInput) LiteralOutput { initialCtx := NewContextBuilder(input.PhaseEnv). WithScope(input.Scope). - WithReturnSummaries(input.ReturnSummaries). + WithFunctionFacts(input.FunctionFacts). BuildDeclared() engine := synth.New(synth.Config{ @@ -179,12 +178,12 @@ func ExtractParams(fn *ast.FunctionExpr, paramTypes map[cfg.SymbolID]typ.Type, g // EnrichWithKeysCollector detects if a function is a "keys collector" // (returns keys of a parameter) and adds KeyOf constraint to OnReturn. // This enables cross-module key-provenance tracking. 
-func EnrichWithKeysCollector(eff *constraint.FunctionRefinement, fn *ast.FunctionExpr) *constraint.FunctionRefinement { - if fn == nil { +func EnrichWithKeysCollector(eff *constraint.FunctionRefinement, graph *cfg.Graph) *constraint.FunctionRefinement { + if graph == nil { return eff } - info := keyscoll.DetectKeysCollector(fn) + info := keyscoll.DetectKeysCollector(graph) if info == nil { return eff } diff --git a/compiler/check/phase/flow_test.go b/compiler/check/phase/flow_test.go index 1d1e191f..caa81735 100644 --- a/compiler/check/phase/flow_test.go +++ b/compiler/check/phase/flow_test.go @@ -91,7 +91,7 @@ func TestEnrichWithKeysCollector_NilFn(t *testing.T) { func TestEnrichWithKeysCollector_NonKeysCollector(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{}} - result := EnrichWithKeysCollector(nil, fn) + result := EnrichWithKeysCollector(nil, cfg.Build(fn)) if result != nil { t.Errorf("expected nil for non-keys-collector fn, got %v", result) } @@ -113,7 +113,7 @@ func TestEnrichWithKeysCollector_UsesDetectedReturnIndex(t *testing.T) { Stmts: body, } - result := EnrichWithKeysCollector(nil, fn) + result := EnrichWithKeysCollector(nil, cfg.Build(fn)) if result == nil { t.Fatal("expected non-nil enriched effect") } @@ -150,7 +150,7 @@ func TestEnrichWithKeysCollector_AppendsToExistingOnReturn(t *testing.T) { existing := &constraint.FunctionRefinement{ OnReturn: constraint.FromConstraints(constraint.NotNil{Path: constraint.RetPath(0)}), } - result := EnrichWithKeysCollector(existing, fn) + result := EnrichWithKeysCollector(existing, cfg.Build(fn)) if result == nil { t.Fatal("expected non-nil enriched effect") } diff --git a/compiler/check/phase/narrow.go b/compiler/check/phase/narrow.go index a53b020b..230e0bf5 100644 --- a/compiler/check/phase/narrow.go +++ b/compiler/check/phase/narrow.go @@ -52,10 +52,9 @@ func RunNarrow(input NarrowInput) NarrowOutput { WithBindings(bindings). WithDeclaredTypes(declaredTypes). WithAnnotatedVars(annotatedVars). 
- WithSiblingTypes(input.SiblingTypes). + WithFunctionFacts(input.FunctionFacts). WithLiteralTypes(input.LiteralTypes). WithSolution(input.Solve.Solution). - WithNarrowReturnSummaries(input.NarrowReturnSummaries). BuildNarrow() engine := createNarrowedEngine( @@ -70,7 +69,7 @@ func RunNarrow(input NarrowInput) NarrowOutput { ) fnEffect := InferRefinement(input.Graph, input.Solve.Solution, input.Extract.Params, input.Extract.ReturnType) - fnEffect = EnrichWithKeysCollector(fnEffect, input.Fn) + fnEffect = EnrichWithKeysCollector(fnEffect, input.Graph) return NarrowOutput{ Facts: narrowingCtx.Types(), diff --git a/compiler/check/phase/scope.go b/compiler/check/phase/scope.go index 62fc5afc..fa4aadaa 100644 --- a/compiler/check/phase/scope.go +++ b/compiler/check/phase/scope.go @@ -210,8 +210,7 @@ func RunScope(input ScopeInput) ScopeOutput { typeExprResolver, fnSignatureResolver, typeResolutionEngine, - input.SiblingTypes, - input.ReturnSummaries, + input.FunctionFacts, ) declaredTypes = applyModuleAliasExports(declaredTypes, input.ModuleAliases, input.Manifests) @@ -222,7 +221,7 @@ func RunScope(input ScopeInput) ScopeOutput { AnnotatedVars: annotatedVars, ParamTypes: paramTypes, FunctionSignatureResolver: fnSignatureResolver, - SiblingTypes: input.SiblingTypes, + FunctionFacts: input.FunctionFacts, DepthLimitExceeded: depthExceeded, } } @@ -357,8 +356,7 @@ func buildDeclaredTypes( typeExprResolver TypeResolver, fnSigResolver FunctionSignatureResolver, synthAPI api.SynthAPI, - siblingTypes map[cfg.SymbolID]typ.Type, - returnSummaries map[cfg.SymbolID][]typ.Type, + functionFacts api.FunctionFacts, ) (flow.DeclaredTypes, map[cfg.SymbolID]bool) { if graph == nil { return nil, nil @@ -368,10 +366,10 @@ func buildDeclaredTypes( annotated := make(map[cfg.SymbolID]bool) bindings := graph.Bindings() alignWithSummary := func(sym cfg.SymbolID, fn *typ.Function) *typ.Function { - if fn == nil || len(returnSummaries) == 0 || sym == 0 { + if fn == nil || len(functionFacts) == 0 
|| sym == 0 { return fn } - if summary := returnSummaries[sym]; len(summary) > 0 { + if summary := functionFacts.Summary(sym); len(summary) > 0 { return returns.WithSummaryOrUnknown(fn, summary) } return fn @@ -452,11 +450,9 @@ func buildDeclaredTypes( } if fnExpr, ok := source.(*ast.FunctionExpr); ok && fnExpr != nil { - if siblingTypes != nil { - if siblingFn := siblingTypes[sym]; siblingFn != nil { - out[sym] = siblingFn - return - } + if siblingFn := functionFacts.FunctionType(sym); siblingFn != nil { + out[sym] = siblingFn + return } if fnSigResolver != nil { if fnSig := fnSigResolver.ResolveFunctionSignature(fnExpr, sc); fnSig != nil { @@ -482,7 +478,8 @@ func buildDeclaredTypes( if info := graph.FuncDef(p); info != nil && info.Name != "" && info.FuncExpr != nil { sym := info.Symbol if sym == 0 { - // Fallback for unresolved symbols in legacy/broken binding scenarios. + // Some CFG FuncDef nodes only carry the source name; recover the + // symbol from the graph's binding table for this point. var ok bool sym, ok = graph.SymbolAt(p, info.Name) if !ok { @@ -492,11 +489,9 @@ func buildDeclaredTypes( if _, exists := out[sym]; exists { continue } - if siblingTypes != nil { - if siblingFn := siblingTypes[sym]; siblingFn != nil { - out[sym] = siblingFn - continue - } + if siblingFn := functionFacts.FunctionType(sym); siblingFn != nil { + out[sym] = siblingFn + continue } sc := scopes[p] if fnSigResolver != nil { diff --git a/compiler/check/phase/types.go b/compiler/check/phase/types.go index 80d01d05..b09c418b 100644 --- a/compiler/check/phase/types.go +++ b/compiler/check/phase/types.go @@ -176,13 +176,9 @@ type ScopeInput struct { // ParamHintSignatures contains inferred param types from call sites. // Read-only - populated from ParamHints channel during iteration. ParamHintSignatures map[*ast.FunctionExpr][]typ.Type - // SiblingTypes contains types of functions in the same scope group. - // Explicit input - not looked up from store during phase execution. 
- SiblingTypes map[cfg.SymbolID]typ.Type - - // ReturnSummaries contains pre-flow return summaries for sibling functions. - // This is declared-phase only and intentionally not part of PhaseEnv. - ReturnSummaries map[cfg.SymbolID][]typ.Type + // FunctionFacts contains canonical facts for functions in this graph + // context. Explicit input - not looked up from store during phase execution. + FunctionFacts api.FunctionFacts } // ScopeOutput contains outputs from Phase B (scope computation). @@ -200,8 +196,8 @@ type ScopeOutput struct { ParamTypes map[cfg.SymbolID]typ.Type // FunctionSignatureResolver resolves function signatures from AST. FunctionSignatureResolver FunctionSignatureResolver - // SiblingTypes contains sibling function types (passed through from input). - SiblingTypes map[cfg.SymbolID]typ.Type + // FunctionFacts contains canonical function facts (passed through from input). + FunctionFacts api.FunctionFacts // DepthLimitExceeded indicates scope depth exceeded MaxScopeDepth. DepthLimitExceeded bool } @@ -210,10 +206,8 @@ type ScopeOutput struct { // Phase B (continued): synthesizes function literal types. type LiteralInput struct { PhaseEnv - Scope ScopeOutput - SiblingTypes map[cfg.SymbolID]typ.Type - // ReturnSummaries contains pre-flow return summaries for sibling functions. - ReturnSummaries map[cfg.SymbolID][]typ.Type + Scope ScopeOutput + FunctionFacts api.FunctionFacts } // LiteralOutput contains outputs from the function literal synthesis phase. @@ -227,12 +221,10 @@ type LiteralOutput struct { // Phase B: extracts flow constraints and assignments. type FlowExtractInput struct { PhaseEnv - Resolve ResolveOutput - Scope ScopeOutput - SiblingTypes map[cfg.SymbolID]typ.Type - LiteralTypes flow.DeclaredTypes - // ReturnSummaries contains pre-flow return summaries for sibling functions. 
- ReturnSummaries map[cfg.SymbolID][]typ.Type + Resolve ResolveOutput + Scope ScopeOutput + FunctionFacts api.FunctionFacts + LiteralTypes flow.DeclaredTypes } // FlowExtractOutput contains outputs from the flow extraction phase. @@ -261,13 +253,11 @@ type FlowSolveOutput struct { // Phase D: builds TypeFacts and infers effects. type NarrowInput struct { PhaseEnv - Scope ScopeOutput - Extract FlowExtractOutput - Solve FlowSolveOutput - SiblingTypes map[cfg.SymbolID]typ.Type - LiteralTypes flow.DeclaredTypes - // NarrowReturnSummaries contains post-flow return summaries for narrowing. - NarrowReturnSummaries map[cfg.SymbolID][]typ.Type + Scope ScopeOutput + Extract FlowExtractOutput + Solve FlowSolveOutput + FunctionFacts api.FunctionFacts + LiteralTypes flow.DeclaredTypes } // NarrowOutput contains outputs from the narrowing phase. @@ -281,16 +271,14 @@ type NarrowOutput struct { // ContextBuilder constructs Env instances from phase outputs. // Centralizes the wiring that was previously duplicated across phase run files. type ContextBuilder struct { - env PhaseEnv - bindings *bind.BindingTable - baseScope *scope.State - declaredTypes flow.DeclaredTypes - annotatedVars map[cfg.SymbolID]bool - siblingTypes map[cfg.SymbolID]typ.Type - literalTypes flow.DeclaredTypes - solution *flow.Solution - returnSummaries map[cfg.SymbolID][]typ.Type - narrowReturnSummaries map[cfg.SymbolID][]typ.Type + env PhaseEnv + bindings *bind.BindingTable + baseScope *scope.State + declaredTypes flow.DeclaredTypes + annotatedVars map[cfg.SymbolID]bool + functionFacts api.FunctionFacts + literalTypes flow.DeclaredTypes + solution *flow.Solution } // NewContextBuilder creates a builder pre-populated from the shared phase environment. 
@@ -311,7 +299,7 @@ func (b *ContextBuilder) WithScope(out ScopeOutput) *ContextBuilder { b.baseScope = out.BaseScope b.declaredTypes = out.DeclaredTypes b.annotatedVars = out.AnnotatedVars - b.siblingTypes = out.SiblingTypes + b.functionFacts = out.FunctionFacts return b } @@ -351,9 +339,9 @@ func (b *ContextBuilder) WithAnnotatedVars(av map[cfg.SymbolID]bool) *ContextBui return b } -// WithSiblingTypes overrides sibling function types. -func (b *ContextBuilder) WithSiblingTypes(st map[cfg.SymbolID]typ.Type) *ContextBuilder { - b.siblingTypes = st +// WithFunctionFacts overrides canonical function facts. +func (b *ContextBuilder) WithFunctionFacts(ff api.FunctionFacts) *ContextBuilder { + b.functionFacts = ff return b } @@ -363,49 +351,35 @@ func (b *ContextBuilder) WithLiteralTypes(lt flow.DeclaredTypes) *ContextBuilder return b } -// WithReturnSummaries sets declared-phase return summaries. -func (b *ContextBuilder) WithReturnSummaries(rs map[cfg.SymbolID][]typ.Type) *ContextBuilder { - b.returnSummaries = rs - return b -} - -// WithNarrowReturnSummaries sets post-flow return summaries for narrowing. -func (b *ContextBuilder) WithNarrowReturnSummaries(rs map[cfg.SymbolID][]typ.Type) *ContextBuilder { - b.narrowReturnSummaries = rs - return b -} - // BuildDeclared constructs a declared-phase Env from accumulated fields. func (b *ContextBuilder) BuildDeclared() *api.DeclaredEnvImpl { return api.NewDeclaredEnv(api.DeclaredEnvConfig{ Graph: b.env.Graph, Bindings: b.bindings, DeclaredTypes: b.declaredTypes, - SiblingTypes: b.siblingTypes, LiteralTypes: b.literalTypes, AnnotatedVars: b.annotatedVars, BaseScope: b.baseScope, RefinementStore: b.env.RefinementStore, ModuleAliases: b.env.ModuleAliases, GlobalTypes: b.env.GlobalTypes, - ReturnSummaries: b.returnSummaries, + FunctionFacts: b.functionFacts, }) } // BuildNarrow constructs a narrowing-phase Env from accumulated fields. 
func (b *ContextBuilder) BuildNarrow() *api.NarrowEnvImpl { return api.NewNarrowEnv(api.NarrowEnvConfig{ - Graph: b.env.Graph, - Bindings: b.bindings, - DeclaredTypes: b.declaredTypes, - SiblingTypes: b.siblingTypes, - LiteralTypes: b.literalTypes, - AnnotatedVars: b.annotatedVars, - Solution: b.solution, - BaseScope: b.baseScope, - RefinementStore: b.env.RefinementStore, - ModuleAliases: b.env.ModuleAliases, - GlobalTypes: b.env.GlobalTypes, - NarrowReturnSummaries: b.narrowReturnSummaries, + Graph: b.env.Graph, + Bindings: b.bindings, + DeclaredTypes: b.declaredTypes, + LiteralTypes: b.literalTypes, + AnnotatedVars: b.annotatedVars, + Solution: b.solution, + BaseScope: b.baseScope, + RefinementStore: b.env.RefinementStore, + ModuleAliases: b.env.ModuleAliases, + GlobalTypes: b.env.GlobalTypes, + FunctionFacts: b.functionFacts, }) } diff --git a/compiler/check/phase/types_test.go b/compiler/check/phase/types_test.go index 4adde914..4368178e 100644 --- a/compiler/check/phase/types_test.go +++ b/compiler/check/phase/types_test.go @@ -106,22 +106,22 @@ func TestScopeInput_Fields(t *testing.T) { if input.ParamHintSignatures != nil { t.Error("ParamHintSignatures should be nil by default") } - if input.SiblingTypes != nil { - t.Error("SiblingTypes should be nil by default") + if input.FunctionFacts != nil { + t.Error("FunctionFacts should be nil by default") } } func TestLiteralInput_Fields(t *testing.T) { input := LiteralInput{} - if input.SiblingTypes != nil { - t.Error("SiblingTypes should be nil by default") + if input.FunctionFacts != nil { + t.Error("FunctionFacts should be nil by default") } } func TestFlowExtractInput_Fields(t *testing.T) { input := FlowExtractInput{} - if input.SiblingTypes != nil { - t.Error("SiblingTypes should be nil by default") + if input.FunctionFacts != nil { + t.Error("FunctionFacts should be nil by default") } if input.LiteralTypes != nil { t.Error("LiteralTypes should be nil by default") @@ -137,8 +137,8 @@ func 
TestFlowSolveInput_Fields(t *testing.T) { func TestNarrowInput_Fields(t *testing.T) { input := NarrowInput{} - if input.SiblingTypes != nil { - t.Error("SiblingTypes should be nil by default") + if input.FunctionFacts != nil { + t.Error("FunctionFacts should be nil by default") } if input.LiteralTypes != nil { t.Error("LiteralTypes should be nil by default") @@ -184,8 +184,8 @@ func TestScopeOutput_Fields(t *testing.T) { if out.FunctionSignatureResolver != nil { t.Error("FunctionSignatureResolver should be nil by default") } - if out.SiblingTypes != nil { - t.Error("SiblingTypes should be nil by default") + if out.FunctionFacts != nil { + t.Error("FunctionFacts should be nil by default") } } @@ -272,7 +272,7 @@ func TestContextBuilder_WithScope(t *testing.T) { BaseScope: &scope.State{}, DeclaredTypes: make(flow.DeclaredTypes), AnnotatedVars: make(map[cfg.SymbolID]bool), - SiblingTypes: make(map[cfg.SymbolID]typ.Type), + FunctionFacts: make(api.FunctionFacts), } result := builder.WithScope(out) @@ -345,13 +345,13 @@ func TestContextBuilder_WithAnnotatedVars(t *testing.T) { } } -func TestContextBuilder_WithSiblingTypes(t *testing.T) { +func TestContextBuilder_WithFunctionFacts(t *testing.T) { env := PhaseEnv{} builder := NewContextBuilder(env) - result := builder.WithSiblingTypes(make(map[cfg.SymbolID]typ.Type)) + result := builder.WithFunctionFacts(make(api.FunctionFacts)) if result != builder { - t.Error("WithSiblingTypes should return the same builder for chaining") + t.Error("WithFunctionFacts should return the same builder for chaining") } } @@ -381,7 +381,7 @@ func TestContextBuilder_Chaining(t *testing.T) { WithBaseScope(&scope.State{}). WithDeclaredTypes(make(flow.DeclaredTypes)). WithAnnotatedVars(make(map[cfg.SymbolID]bool)). - WithSiblingTypes(make(map[cfg.SymbolID]typ.Type)). + WithFunctionFacts(make(api.FunctionFacts)). WithLiteralTypes(make(flow.DeclaredTypes)). WithSolution(nil). 
BuildDeclared() diff --git a/compiler/check/pipeline/driver.go b/compiler/check/pipeline/driver.go index ff653539..96c48b9b 100644 --- a/compiler/check/pipeline/driver.go +++ b/compiler/check/pipeline/driver.go @@ -127,9 +127,6 @@ func (d *Driver) runFixpoint(sess api.AnalysisSession, fn *ast.FunctionExpr, par } func (d *Driver) prepareIterationState(sess api.AnalysisSession) { - if d.cfg.FuncResultQ != nil { - d.cfg.FuncResultQ.Clear() - } sess.ResetDiagnostics() scopeState := sess.ScopeDepthDiagState() @@ -145,7 +142,6 @@ func (d *Driver) advanceFixpoint(store api.IterationStore) bool { if !store.FixpointSwap() { return true } - store.BumpRevision() return false } @@ -194,15 +190,15 @@ func (d *Driver) processNestedFunctions( Check: func(fn *ast.FunctionExpr, parent *scope.State) { d.checkFunctionFixpoint(sess, fn, parent) }, - ResultForFunc: func(fn *ast.FunctionExpr) *api.FuncResultView { + ResultForFunc: func(fn *ast.FunctionExpr) *api.FuncResultSnapshot { if results == nil { return nil } - return api.ViewFromResult(results[fn]) + return api.SnapshotFromResult(results[fn]) }, - RootResult: api.ViewFromResult(sess.RootResultValue()), + RootResult: api.SnapshotFromResult(sess.RootResultValue()), }) - nestedProc.ProcessNestedFunctions(graph, api.ViewFromResult(result)) + nestedProc.ProcessNestedFunctions(graph, api.SnapshotFromResult(result)) } func (d *Driver) registerParentScope(store api.IterationStore, graphID uint64, parent *scope.State) uint64 { @@ -239,7 +235,7 @@ func (d *Driver) runReturnInference( refinementLookup = es.LookupRefinementBySym } - summaries, funcTypes, diags := inferencer.ComputeForGraph(returninfer.RunContext{ + functionFacts, diags := inferencer.ComputeForGraph(returninfer.RunContext{ Ctx: sess.Context(), ParentFacts: d.parentFactsForGraph(sess, store, graph.ID()), EffectLookup: refinementLookup, @@ -247,12 +243,12 @@ func (d *Driver) runReturnInference( if len(diags) > 0 { sess.AppendDiagnostics(diags...) 
} - if len(summaries) == 0 { + if len(functionFacts) == 0 { return } if key, ok := store.GraphKeyFor(graph, parent); ok { - store.UpdateInterprocFactsNext(key, func(facts *api.Facts) { - returns.MergeFunctionFactsIntoFacts(facts, summaries, nil, funcTypes) + store.MergeInterprocFactsNext(key, api.Facts{ + FunctionFacts: functionFacts, }) } } @@ -296,14 +292,9 @@ func (d *Driver) loadFunctionResult( if d.cfg.FuncResultQ == nil { return nil } - revision := uint64(0) - if store != nil { - revision = store.Revision() - } return d.cfg.FuncResultQ.Get(sess.Context(), api.FuncKey{ - GraphID: graphID, - ParentHash: parentHash, - StoreRevision: revision, + GraphID: graphID, + ParentHash: parentHash, }) } diff --git a/compiler/check/pipeline/runner.go b/compiler/check/pipeline/runner.go index dd19e746..282cb76d 100644 --- a/compiler/check/pipeline/runner.go +++ b/compiler/check/pipeline/runner.go @@ -19,7 +19,6 @@ package pipeline import ( - "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/infer/captured" "github.com/wippyai/go-lua/compiler/check/infer/paramhints" @@ -77,6 +76,12 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { if store == nil { return nil } + if tracker, ok := store.(interface { + PushSnapshotReadContext(*db.QueryContext) func() + }); ok { + pop := tracker.PushSnapshotReadContext(ctx) + defer pop() + } withPhase := func(_ api.Phase, fn func()) { fn() } if phaser, ok := store.(interface{ WithPhase(api.Phase, func()) }); ok { withPhase = phaser.WithPhase @@ -101,17 +106,10 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { setter.SetGraphParentHash(graph.ID(), key.ParentHash) } - paramHintSigs := paramhints.BuildParamHintSigView(store, graph, parent, r.stdlib) + paramHintSigs := paramhints.BuildParamHintSignatures(store, graph, parent, r.stdlib) synthSig := r.resolveSynthesizedSignature(ctx, store, graph, fn, parent, 
paramHintSigs) - // Canonical local function types for this graph (stable snapshot). - siblingTypes := store.GetLocalFuncTypesSnapshot(graph, parent) - // Return summaries include captured field assignments (stable snapshot). - returnSummaries := store.GetReturnSummariesSnapshot(graph, parent) - var narrowReturnSummaries map[cfg.SymbolID][]typ.Type - withPhase(api.PhaseNarrowing, func() { - narrowReturnSummaries = store.GetNarrowReturnSummariesSnapshot(graph, parent) - }) + functionFacts := store.GetFunctionFactsSnapshot(graph, parent) // Build shared phase environment once. localAliases := modules.CollectAliases(graph) @@ -147,49 +145,46 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { SynthesizedFunctionSig: synthSig, FunctionLiteralSignatures: literalSigs, ParamHintSignatures: paramHintSigs, - SiblingTypes: siblingTypes, - ReturnSummaries: returnSummaries, + FunctionFacts: functionFacts, }) // Declared is the default phase for scope/extract and interproc reads. if capturedTypes := store.GetCapturedTypesSnapshot(graph, parent); len(capturedTypes) > 0 { scopeOut.DeclaredTypes = captured.MergeCapturedTypes(scopeOut.DeclaredTypes, capturedTypes) } - r.mergeCapturedParentFuncTypes(store, graph, fn, &scopeOut) + r.mergeCapturedParentFunctionTypes(store, graph, fn, &scopeOut) // Populate scopes in env for later phases. env.Scopes = scopeOut.Scopes // Phase B (continued): Synthesize function literal types. literalOut := phase.RunLiteral(phase.LiteralInput{ - PhaseEnv: env, - Scope: scopeOut, - SiblingTypes: scopeOut.SiblingTypes, - ReturnSummaries: returnSummaries, + PhaseEnv: env, + Scope: scopeOut, + FunctionFacts: functionFacts, }) // Ensure literal function types use canonical local function types. 
- if len(siblingTypes) > 0 { + if len(functionFacts) > 0 { if literalOut.LiteralTypes == nil { - literalOut.LiteralTypes = make(flow.DeclaredTypes, len(siblingTypes)) + literalOut.LiteralTypes = make(flow.DeclaredTypes, len(functionFacts)) } - for sym, fnType := range siblingTypes { - if fnType == nil { + for sym, fact := range functionFacts { + if fact.Type == nil { continue } - literalOut.LiteralTypes[sym] = fnType + literalOut.LiteralTypes[sym] = fact.Type } } // Phase B (continued): Extract flow constraints. extractOut := phase.RunExtract(phase.FlowExtractInput{ - PhaseEnv: env, - Resolve: resolveOut, - Scope: scopeOut, - SiblingTypes: scopeOut.SiblingTypes, - LiteralTypes: literalOut.LiteralTypes, - ReturnSummaries: returnSummaries, + PhaseEnv: env, + Resolve: resolveOut, + Scope: scopeOut, + FunctionFacts: functionFacts, + LiteralTypes: literalOut.LiteralTypes, }) - r.appendCapturedMutatorAssignments(store, graph, parent, env, scopeOut, literalOut, returnSummaries, &extractOut) + r.appendCapturedMutatorAssignments(store, graph, parent, env, scopeOut, literalOut, functionFacts, &extractOut) // Phase C: Solve flow system. 
solveOut := phase.RunSolve(phase.FlowSolveInput{ @@ -201,13 +196,12 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { var narrowOut phase.NarrowOutput withPhase(api.PhaseNarrowing, func() { narrowOut = phase.RunNarrow(phase.NarrowInput{ - PhaseEnv: env, - Scope: scopeOut, - Extract: extractOut, - Solve: solveOut, - SiblingTypes: scopeOut.SiblingTypes, - LiteralTypes: literalOut.LiteralTypes, - NarrowReturnSummaries: narrowReturnSummaries, + PhaseEnv: env, + Scope: scopeOut, + Extract: extractOut, + Solve: solveOut, + FunctionFacts: functionFacts, + LiteralTypes: literalOut.LiteralTypes, }) }) diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index 7ad22200..64c35bd3 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -17,7 +17,7 @@ import ( func (r *Runner) resolveSynthesizedSignature( ctx *db.QueryContext, - store api.StoreView, + store api.StoreReader, graph *cfg.Graph, fn *ast.FunctionExpr, parent *scope.State, @@ -56,13 +56,13 @@ func (r *Runner) resolveSynthesizedSignature( } func (r *Runner) appendCapturedMutatorAssignments( - store api.StoreView, + store api.StoreReader, graph *cfg.Graph, parent *scope.State, env phase.PhaseEnv, scopeOut phase.ScopeOutput, literalOut phase.LiteralOutput, - returnSummaries map[cfg.SymbolID][]typ.Type, + functionFacts api.FunctionFacts, extractOut *phase.FlowExtractOutput, ) { if store == nil || graph == nil || extractOut == nil || extractOut.Inputs == nil { @@ -81,9 +81,8 @@ func (r *Runner) appendCapturedMutatorAssignments( declaredEnv := phase.NewContextBuilder(env). WithScope(scopeOut). - WithSiblingTypes(scopeOut.SiblingTypes). WithLiteralTypes(literalOut.LiteralTypes). - WithReturnSummaries(returnSummaries). + WithFunctionFacts(functionFacts). 
BuildDeclared() synthEngine := synth.New(synth.Config{ @@ -121,7 +120,7 @@ func (r *Runner) runComputePasses(graph *cfg.Graph, scopes map[cfg.Point]*scope. return extras } -func (r *Runner) literalSignatureForFunction(store api.StoreView, graph *cfg.Graph, fn *ast.FunctionExpr) *typ.Function { +func (r *Runner) literalSignatureForFunction(store api.StoreReader, graph *cfg.Graph, fn *ast.FunctionExpr) *typ.Function { if store == nil || graph == nil || fn == nil { return nil } @@ -152,7 +151,7 @@ func (r *Runner) literalSignatureForFunction(store api.StoreView, graph *cfg.Gra return nil } -func (r *Runner) literalSigProvider(store api.StoreView, graph *cfg.Graph, parent *scope.State) phase.LiteralSigsProvider { +func (r *Runner) literalSigProvider(store api.StoreReader, graph *cfg.Graph, parent *scope.State) phase.LiteralSigsProvider { if store == nil || graph == nil || parent == nil { return nil } @@ -184,7 +183,7 @@ type effectStoreProvider interface { RefinementStore() api.RefinementStore } -func effectStoreFrom(store api.StoreView) api.RefinementStore { +func effectStoreFrom(store api.StoreReader) api.RefinementStore { if store == nil { return nil } @@ -198,7 +197,7 @@ type scratchLiteralStore interface { ScratchLiteralSigs(graphID uint64) map[*ast.FunctionExpr]*typ.Function } -func scratchLiteralSigs(store api.StoreView, graphID uint64) map[*ast.FunctionExpr]*typ.Function { +func scratchLiteralSigs(store api.StoreReader, graphID uint64) map[*ast.FunctionExpr]*typ.Function { if store == nil { return nil } @@ -208,7 +207,7 @@ func scratchLiteralSigs(store api.StoreView, graphID uint64) map[*ast.FunctionEx return nil } -func (r *Runner) parentScopeForGraph(store api.StoreView, graph *cfg.Graph) *scope.State { +func (r *Runner) parentScopeForGraph(store api.StoreReader, graph *cfg.Graph) *scope.State { if store == nil || graph == nil { return nil } @@ -221,8 +220,8 @@ func (r *Runner) parentScopeForGraph(store api.StoreView, graph *cfg.Graph) *sco return nil } -func 
(r *Runner) mergeCapturedParentFuncTypes( - store api.StoreView, +func (r *Runner) mergeCapturedParentFunctionTypes( + store api.StoreReader, graph *cfg.Graph, fn *ast.FunctionExpr, scopeOut *phase.ScopeOutput, @@ -242,12 +241,12 @@ func (r *Runner) mergeCapturedParentFuncTypes( if parentScope == nil { return } - parentFuncTypes := store.GetLocalFuncTypesSnapshot(parentGraph, parentScope) - if len(parentFuncTypes) == 0 { + parentFacts := store.GetFunctionFactsSnapshot(parentGraph, parentScope) + if len(parentFacts) == 0 { return } for _, sym := range graph.Bindings().CapturedSymbols(fn) { - ft := parentFuncTypes[sym] + ft := parentFacts.FunctionType(sym) if sym == 0 || ft == nil { continue } diff --git a/compiler/check/returns/callgraph_symbol_test.go b/compiler/check/returns/callgraph_symbol_test.go index e783f6d3..c41142a4 100644 --- a/compiler/check/returns/callgraph_symbol_test.go +++ b/compiler/check/returns/callgraph_symbol_test.go @@ -55,7 +55,7 @@ func TestCanonicalLocalSymbol_PrefersKnownLocalOverRaw(t *testing.T) { } } -func TestCanonicalLocalCalleeSymbol_UsesCalleeNameFallback(t *testing.T) { +func TestCanonicalLocalCalleeSymbol_UsesCalleeNameResolution(t *testing.T) { bindings := bind.NewBindingTable() const localSym cfg.SymbolID = 3001 bindings.SetName(localSym, "runner") @@ -69,7 +69,7 @@ func TestCanonicalLocalCalleeSymbol_UsesCalleeNameFallback(t *testing.T) { } got := canonicalLocalCalleeSymbol(localFuncs, nil, nil, bindings, callInfo) if got != localSym { - t.Fatalf("canonicalLocalCalleeSymbol via name fallback = %d, want %d", got, localSym) + t.Fatalf("canonicalLocalCalleeSymbol via name resolution = %d, want %d", got, localSym) } } diff --git a/compiler/check/returns/doc.go b/compiler/check/returns/doc.go index c9312772..c67c76aa 100644 --- a/compiler/check/returns/doc.go +++ b/compiler/check/returns/doc.go @@ -29,7 +29,7 @@ // // # Overlay System // -// [Overlay] provides a mutable view over stable return summaries: +// [Overlay] provides the 
mutable return-summary layer: // - Stable summaries from previous iterations // - Pending updates from current iteration // - Atomic commit when iteration converges diff --git a/compiler/check/returns/domain_law_test.go b/compiler/check/returns/domain_law_test.go new file mode 100644 index 00000000..307a92ab --- /dev/null +++ b/compiler/check/returns/domain_law_test.go @@ -0,0 +1,203 @@ +package returns + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/typ" +) + +func TestFactsDomain_ProductOperatorsAreIdempotentAcrossAllDomains(t *testing.T) { + fnSym := cfg.SymbolID(1) + capturedSym := cfg.SymbolID(2) + classSym := cfg.SymbolID(3) + lit := &ast.FunctionExpr{} + callback := typ.Func().Param("self", typ.Unknown).Returns(typ.Boolean).Build() + fn := typ.Func().Param("name", typ.String).Returns(typ.String).Build() + raw := api.Facts{ + FunctionFacts: api.FunctionFacts{ + fnSym: {Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn}, + }, + ParamHints: api.ParamHints{ + fnSym: []typ.Type{typ.String}, + }, + LiteralSigs: api.LiteralSigs{ + lit: typ.Func().Param("name", typ.String).Returns(typ.String).Build(), + }, + CapturedTypes: api.CapturedTypes{ + capturedSym: typ.NewRecord().Field("name", typ.String).Build(), + }, + CapturedFields: api.CapturedFieldAssigns{ + fnSym: { + capturedSym: { + "callback": typ.NewOptional(callback), + }, + }, + }, + CapturedContainers: api.CapturedContainerMutations{ + fnSym: { + capturedSym: { + { + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "items"}}, + ValueType: typ.NewArray(typ.String), + }, + }, + }, + }, + ConstructorFields: api.ConstructorFields{ + classSym: { + "name": typ.String, + }, + }, + } + + normalized := WidenFacts(api.Facts{}, raw) + if !FactsEqual(normalized, WidenFacts(normalized, 
normalized)) { + t.Fatalf("Widen must be idempotent across the product domain") + } + if !FactsEqual(normalized, JoinFacts(normalized, normalized)) { + t.Fatalf("Join must be idempotent across the product domain") + } + + if got := normalized.FunctionFacts.Summary(fnSym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + t.Fatalf("summary must come from canonical FunctionFacts, got %v", got) + } + if got := normalized.FunctionFacts.NarrowSummary(fnSym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + t.Fatalf("narrow summary must come from canonical FunctionFacts, got %v", got) + } + if got := normalized.FunctionFacts.FunctionType(fnSym); got == nil { + t.Fatalf("function type must come from canonical FunctionFacts") + } +} + +func TestFactsDomain_WidenIdempotentForLiteralUnknownVsConcreteReturn(t *testing.T) { + lit := &ast.FunctionExpr{} + prev := api.Facts{ + LiteralSigs: api.LiteralSigs{ + lit: typ.Func().Param("name", typ.Unknown).Returns(typ.Unknown, typ.NewOptional(typ.String)).Build(), + }, + } + next := api.Facts{ + LiteralSigs: api.LiteralSigs{ + lit: typ.Func().Param("name", typ.Unknown).Returns( + typ.NewOptional(typ.NewRecord(). + Field("id", typ.String). + Field("priority", typ.Integer). + SetOpen(true). 
+ Build()), + typ.NewOptional(typ.String), + ).Build(), + }, + } + + widened := WidenFacts(prev, next) + widenedAgain := WidenFacts(widened, next) + if !FactsEqual(widened, widenedAgain) { + t.Fatalf("Widen must be idempotent for literal signatures:\nfirst=%#v\nsecond=%#v", widened, widenedAgain) + } + + got := widened.LiteralSigs[lit] + if got == nil || len(got.Returns) != 2 || !typ.TypeEquals(got.Returns[0], typ.Unknown) { + t.Fatalf("expected unresolved literal return to remain the stable upper bound, got %v", got) + } +} + +func TestFactsDomain_WidenFunctionParamsIsVarianceAware(t *testing.T) { + sym := cfg.SymbolID(1) + prev := api.Facts{ + FunctionFacts: api.FunctionFacts{ + sym: {Type: typ.Func().Param("path", typ.Any).Returns(typ.NewArray(typ.Unknown)).Build()}, + }, + } + next := api.Facts{ + FunctionFacts: api.FunctionFacts{ + sym: {Type: typ.Func().Param("path", typ.String).Returns(typ.NewArray(typ.Unknown)).Build()}, + }, + } + + widened := WidenFacts(prev, next) + widenedAgain := WidenFacts(widened, next) + if !FactsEqual(widened, widenedAgain) { + t.Fatalf("Widen must be idempotent for function param facts") + } + + fn := unwrapFunctionForDomainTest(t, widened.FunctionFacts.FunctionType(sym)) + if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, typ.Any) { + t.Fatalf("expected widening to preserve broad parameter upper bound, got %v", fn) + } +} + +func TestFactsDomain_WidenPreservesCapturedCallbackUnionMembers(t *testing.T) { + sym := cfg.SymbolID(9) + withPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("pending"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + withoutPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + prevFn := typ.Func(). + Param("suite", typ.Any). + Param("test_case", typ.Any). + Returns(typ.NewRecord().Field("status", withPending).Field("suite", typ.Unknown).Build()). + Build() + nextFn := typ.Func(). 
+ Param("suite", typ.Any). + Param("test_case", typ.Any). + Returns(typ.NewRecord().Field("status", withoutPending).Field("suite", typ.Unknown).Build()). + Build() + + widened := WidenFacts( + api.Facts{CapturedTypes: api.CapturedTypes{sym: prevFn}}, + api.Facts{CapturedTypes: api.CapturedTypes{sym: nextFn}}, + ) + widenedAgain := WidenFacts(widened, api.Facts{CapturedTypes: api.CapturedTypes{sym: nextFn}}) + if !FactsEqual(widened, widenedAgain) { + t.Fatalf("Widen must be idempotent for captured callback union members") + } + + fn := unwrapFunctionForDomainTest(t, widened.CapturedTypes[sym]) + rec, ok := fn.Returns[0].(*typ.Record) + if !ok { + t.Fatalf("expected callback record return, got %T", fn.Returns[0]) + } + status := rec.GetField("status") + if status == nil || !typ.TypeEquals(status.Type, withPending) { + t.Fatalf("expected status union to preserve pending member, got %v", status) + } +} + +func TestFactsDomain_UnsafeNestedUnionDropDetected(t *testing.T) { + withPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("pending"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + withoutPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + prev := typ.NewRecord().Field("status", withPending).Build() + next := typ.NewRecord().Field("status", withoutPending).Build() + if !typeUnsafePrecisionDrop(prev, next) { + t.Fatalf("expected nested union member drop to be unsafe: prev=%v next=%v", prev, next) + } +} + +func unwrapFunctionForDomainTest(t *testing.T, got typ.Type) *typ.Function { + t.Helper() + fn, ok := got.(*typ.Function) + if !ok { + t.Fatalf("expected function type, got %T %v", got, got) + } + return fn +} diff --git a/compiler/check/returns/equal.go b/compiler/check/returns/equal.go index 2ada7665..0bc2be0f 100644 --- a/compiler/check/returns/equal.go +++ b/compiler/check/returns/equal.go @@ -6,9 +6,9 @@ import ( "github.com/wippyai/go-lua/types/typ" ) 
-// FactsEqual checks if two interproc fact bundles are equal. +// FactsEqual checks if two canonical interproc fact bundles are equal. func FactsEqual(a, b api.Facts) bool { - if !FunctionFactsEqual(canonicalFunctionFacts(a), canonicalFunctionFacts(b)) { + if !FunctionFactsEqual(a.FunctionFacts, b.FunctionFacts) { return false } if !symbolTypeVectorMapEqual(a.ParamHints, b.ParamHints) { @@ -49,7 +49,7 @@ func FunctionFactsEqual(a, b api.FunctionFacts) bool { if !ReturnTypesEqual(af.Narrow, bf.Narrow) { return false } - if !typ.TypeEquals(af.Func, bf.Func) { + if !typ.TypeEquals(af.Type, bf.Type) { return false } } @@ -89,8 +89,9 @@ func symbolTypeMapEqual(a map[cfg.SymbolID]typ.Type, b map[cfg.SymbolID]typ.Type return false } for _, sym := range cfg.SortedSymbolIDs(a) { - left := a[sym] + left := canonicalInterprocValueType(a[sym]) right, ok := b[sym] + right = canonicalInterprocValueType(right) if !ok || !typ.TypeEquals(left, right) { return false } @@ -116,7 +117,9 @@ func CapturedFieldAssignsEqual(a, b api.CapturedFieldAssigns) bool { return false } for _, name := range cfg.SortedFieldNames(fields) { - if !typ.TypeEquals(fields[name], otherFields[name]) { + left := canonicalInterprocValueType(fields[name]) + right := canonicalInterprocValueType(otherFields[name]) + if !typ.TypeEquals(left, right) { return false } } @@ -161,7 +164,7 @@ func containerMutationSlicesEqual(a, b []api.ContainerMutation) bool { for _, m := range b { key := api.ContainerMutationKey(m) other, ok := index[key] - if !ok || !typ.TypeEquals(other.ValueType, m.ValueType) { + if !ok || !typ.TypeEquals(canonicalInterprocValueType(other.ValueType), canonicalInterprocValueType(m.ValueType)) { return false } } @@ -180,7 +183,9 @@ func ConstructorFieldsEqual(a, b api.ConstructorFields) bool { return false } for _, name := range cfg.SortedFieldNames(fields) { - if !typ.TypeEquals(fields[name], other[name]) { + left := canonicalInterprocValueType(fields[name]) + right := 
canonicalInterprocValueType(other[name]) + if !typ.TypeEquals(left, right) { return false } } diff --git a/compiler/check/returns/equal_test.go b/compiler/check/returns/equal_test.go index 3ee8e8cf..06540da8 100644 --- a/compiler/check/returns/equal_test.go +++ b/compiler/check/returns/equal_test.go @@ -16,7 +16,7 @@ func TestFactsEqual_Empty(t *testing.T) { } } -func TestFactsEqual_ReturnSummaries(t *testing.T) { +func TestFactsEqual_FunctionFactSummary(t *testing.T) { a := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Summary: []typ.Type{typ.String}}, @@ -32,7 +32,7 @@ func TestFactsEqual_ReturnSummaries(t *testing.T) { } } -func TestFactsEqual_DifferentReturnSummaries(t *testing.T) { +func TestFactsEqual_DifferentFunctionFactSummary(t *testing.T) { a := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Summary: []typ.Type{typ.String}}, @@ -48,7 +48,7 @@ func TestFactsEqual_DifferentReturnSummaries(t *testing.T) { } } -func TestFactsEqual_IgnoresLegacyMirrorDrift(t *testing.T) { +func TestFactsEqual_UsesCanonicalFunctionFactsOnly(t *testing.T) { sym := cfg.SymbolID(77) fn := typ.Func().Returns(typ.String).Build() @@ -57,25 +57,16 @@ func TestFactsEqual_IgnoresLegacyMirrorDrift(t *testing.T) { sym: { Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, - Func: fn, + Type: fn, }, }, - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - FuncTypes: api.FuncTypes{ - sym: typ.Func().Returns(typ.Number).Build(), - }, } b := api.Facts{ FunctionFacts: api.FunctionFacts{ sym: { Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, - Func: fn, + Type: fn, }, }, } @@ -85,46 +76,34 @@ func TestFactsEqual_IgnoresLegacyMirrorDrift(t *testing.T) { } } -func TestFactsEqual_LegacyOnlyChannelsAreComparedCanonically(t *testing.T) { +func TestFactsEqual_DifferentCanonicalFunctionFacts(t *testing.T) { sym := cfg.SymbolID(91) a := api.Facts{ - 
ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.String}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.String}, - }, - FuncTypes: api.FuncTypes{ - sym: typ.Func().Returns(typ.String).Build(), + FunctionFacts: api.FunctionFacts{ + sym: {Summary: []typ.Type{typ.String}, Type: typ.Func().Returns(typ.String).Build()}, }, } b := api.Facts{ - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - FuncTypes: api.FuncTypes{ - sym: typ.Func().Returns(typ.Number).Build(), + FunctionFacts: api.FunctionFacts{ + sym: {Summary: []typ.Type{typ.Number}, Type: typ.Func().Returns(typ.Number).Build()}, }, } if FactsEqual(a, b) { - t.Fatal("legacy-only function channels should participate in canonical equality") + t.Fatal("different canonical function facts should not be equal") } } -func TestReturnSummariesEqual_Empty(t *testing.T) { +func TestTypeVectorMapEqual_Empty(t *testing.T) { if !symbolTypeVectorMapEqual(nil, nil) { t.Error("nil summaries should be equal") } } -func TestReturnSummariesEqual_DifferentLength(t *testing.T) { - a := api.ReturnSummaries{1: []typ.Type{typ.String}} - b := api.ReturnSummaries{} +func TestTypeVectorMapEqual_DifferentLength(t *testing.T) { + a := map[cfg.SymbolID][]typ.Type{1: {typ.String}} + b := map[cfg.SymbolID][]typ.Type{} if symbolTypeVectorMapEqual(a, b) { t.Error("summaries with different lengths should not be equal") } @@ -144,16 +123,16 @@ func TestParamHintsEqual_Same(t *testing.T) { } } -func TestFuncTypesEqual_Empty(t *testing.T) { +func TestSymbolTypeMapEqual_Empty(t *testing.T) { if !symbolTypeMapEqual(nil, nil) { t.Error("nil func types should be equal") } } -func TestFuncTypesEqual_Same(t *testing.T) { +func TestSymbolTypeMapEqual_Same(t *testing.T) { fn := typ.Func().Returns(typ.String).Build() - a := api.FuncTypes{1: fn} - b := api.FuncTypes{1: fn} + a := map[cfg.SymbolID]typ.Type{1: fn} + b := 
map[cfg.SymbolID]typ.Type{1: fn} if !symbolTypeMapEqual(a, b) { t.Error("same func types should be equal") } @@ -216,3 +195,12 @@ func TestCapturedContainerMutationsEqual_Basic(t *testing.T) { t.Error("same container mutations should be equal") } } + +func TestCapturedFieldAssignsEqual_CanonicalizesOptionalFunctionValues(t *testing.T) { + fn := typ.Func().Param("fn", typ.Unknown).Build() + left := api.CapturedFieldAssigns{1: {2: {"after_all": typ.NewOptional(fn)}}} + right := api.CapturedFieldAssigns{1: {2: {"after_all": fn}}} + if !CapturedFieldAssignsEqual(left, right) { + t.Fatal("expected optional function captured field to equal canonical function value") + } +} diff --git a/compiler/check/returns/function_facts.go b/compiler/check/returns/function_facts.go index a545c35f..87d10ee7 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -3,23 +3,8 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/types/typ" ) -func collectFunctionFactChannelSymbols( - summaries api.ReturnSummaries, - narrows api.NarrowReturnSummaries, - funcs api.FuncTypes, - facts api.FunctionFacts, -) []cfg.SymbolID { - symbols := make(map[cfg.SymbolID]bool, len(summaries)+len(narrows)+len(funcs)+len(facts)) - markFunctionFactSymbols(symbols, summaries) - markFunctionFactSymbols(symbols, narrows) - markFunctionFactSymbols(symbols, funcs) - markFunctionFactSymbols(symbols, facts) - return cfg.SortedSymbolIDs(symbols) -} - func collectCanonicalFunctionFactSymbols(factSets ...api.FunctionFacts) []cfg.SymbolID { total := 0 for _, facts := range factSets { @@ -38,79 +23,52 @@ func markFunctionFactSymbols[T any](dst map[cfg.SymbolID]bool, src map[cfg.Symbo } } -func setOrDeleteReturnSummary(m *api.ReturnSummaries, sym cfg.SymbolID, summary []typ.Type) { - if len(summary) > 0 { - if *m == nil { - *m = make(api.ReturnSummaries) - } - (*m)[sym] = summary - 
return - } - if *m == nil { - return - } - delete(*m, sym) - if len(*m) == 0 { - *m = nil +func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { + return api.FunctionFact{ + Summary: canonicalReturnVector(ff.Summary), + Narrow: canonicalReturnVector(ff.Narrow), + Type: normalizeInterprocValueType(ff.Type), } } -func setOrDeleteNarrowSummary(m *api.NarrowReturnSummaries, sym cfg.SymbolID, narrow []typ.Type) { - if len(narrow) > 0 { - if *m == nil { - *m = make(api.NarrowReturnSummaries) - } - (*m)[sym] = narrow - return - } - if *m == nil { - return - } - delete(*m, sym) - if len(*m) == 0 { - *m = nil - } +func functionFactEmpty(ff api.FunctionFact) bool { + return len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Type == nil } -func setOrDeleteFuncType(m *api.FuncTypes, sym cfg.SymbolID, fn typ.Type) { - if fn != nil { - if *m == nil { - *m = make(api.FuncTypes) - } - (*m)[sym] = fn - return +func readFunctionFactFromFacts(facts *api.Facts, sym cfg.SymbolID) api.FunctionFact { + if facts == nil || sym == 0 { + return api.FunctionFact{} } - if *m == nil { - return + if facts.FunctionFacts == nil { + return api.FunctionFact{} } - delete(*m, sym) - if len(*m) == 0 { - *m = nil + ff, ok := facts.FunctionFacts[sym] + if !ok { + return api.FunctionFact{} } -} - -func functionFactFromChannels(summary, narrow []typ.Type, fn typ.Type) api.FunctionFact { - return api.FunctionFact{ - Summary: NormalizeReturnVector(summary), - Narrow: NormalizeReturnVector(narrow), - Func: fn, + canonical := NormalizeFunctionFact(ff) + if !functionFactEmpty(canonical) { + return canonical } + return api.FunctionFact{} } -func readFunctionFactFromFacts(facts *api.Facts, sym cfg.SymbolID) api.FunctionFact { - if facts == nil || sym == 0 { - return api.FunctionFact{} +func normalizeFunctionFactMap(facts api.FunctionFacts) api.FunctionFacts { + if len(facts) == 0 { + return nil } - if facts.FunctionFacts != nil { - ff, ok := facts.FunctionFacts[sym] - if ok { - canonical := 
functionFactFromChannels(ff.Summary, ff.Narrow, ff.Func) - if len(canonical.Summary) > 0 || len(canonical.Narrow) > 0 || canonical.Func != nil { - return canonical - } + out := make(api.FunctionFacts, len(facts)) + for _, sym := range cfg.SortedSymbolIDs(facts) { + canonical := NormalizeFunctionFact(facts[sym]) + if functionFactEmpty(canonical) { + continue } + out[sym] = canonical + } + if len(out) == 0 { + return nil } - return functionFactFromChannels(facts.ReturnSummaries[sym], facts.NarrowReturns[sym], facts.FuncTypes[sym]) + return out } func writeFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.FunctionFact) { @@ -118,8 +76,8 @@ func writeFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.Functio return } - ff = functionFactFromChannels(ff.Summary, ff.Narrow, ff.Func) - if len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Func == nil { + ff = NormalizeFunctionFact(ff) + if functionFactEmpty(ff) { if facts.FunctionFacts != nil { delete(facts.FunctionFacts, sym) if len(facts.FunctionFacts) == 0 { @@ -132,109 +90,12 @@ func writeFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.Functio } facts.FunctionFacts[sym] = ff } - setOrDeleteReturnSummary(&facts.ReturnSummaries, sym, ff.Summary) - setOrDeleteNarrowSummary(&facts.NarrowReturns, sym, ff.Narrow) - setOrDeleteFuncType(&facts.FuncTypes, sym, ff.Func) -} - -func projectCanonicalFunctionFactChannel[T any]( - facts api.Facts, - project func(api.FunctionFact) (T, bool), -) map[cfg.SymbolID]T { - canonical := canonicalFunctionFacts(facts) - if len(canonical) == 0 { - return nil - } - out := make(map[cfg.SymbolID]T, len(canonical)) - for _, sym := range cfg.SortedSymbolIDs(canonical) { - ff := canonical[sym] - value, ok := project(ff) - if ok { - out[sym] = value - } - } - if len(out) == 0 { - return nil - } - return out -} - -// SummaryViewFromFacts returns the canonical summary channel view derived from -// FunctionFacts. 
-func SummaryViewFromFacts(facts api.Facts) api.ReturnSummaries { - return projectCanonicalFunctionFactChannel(facts, func(ff api.FunctionFact) ([]typ.Type, bool) { - if len(ff.Summary) == 0 { - return nil, false - } - return ff.Summary, true - }) } -// NarrowViewFromFacts returns the canonical narrow-summary channel view derived -// from FunctionFacts. -func NarrowViewFromFacts(facts api.Facts) api.NarrowReturnSummaries { - return projectCanonicalFunctionFactChannel(facts, func(ff api.FunctionFact) ([]typ.Type, bool) { - if len(ff.Narrow) == 0 { - return nil, false - } - return ff.Narrow, true - }) -} - -// FuncTypeViewFromFacts returns the canonical function-type channel view -// derived from FunctionFacts. -func FuncTypeViewFromFacts(facts api.Facts) api.FuncTypes { - return projectCanonicalFunctionFactChannel(facts, func(ff api.FunctionFact) (typ.Type, bool) { - if ff.Func == nil { - return nil, false - } - return ff.Func, true - }) -} - -// NormalizeFunctionFactChannels reconciles legacy function channels into -// canonical FunctionFacts, then rewrites mirrors from canonical values. -func NormalizeFunctionFactChannels(facts *api.Facts) { +// NormalizeFunctionFacts canonicalizes the stored FunctionFacts map. 
+func NormalizeFunctionFacts(facts *api.Facts) { if facts == nil { return } - symbols := collectFunctionFactChannelSymbols( - facts.ReturnSummaries, - facts.NarrowReturns, - facts.FuncTypes, - facts.FunctionFacts, - ) - if len(symbols) == 0 { - return - } - for _, sym := range symbols { - ff := readFunctionFactFromFacts(facts, sym) - writeFunctionFactToFacts(facts, sym, ff) - } -} - -func canonicalFunctionFacts(facts api.Facts) api.FunctionFacts { - symbols := collectFunctionFactChannelSymbols( - facts.ReturnSummaries, - facts.NarrowReturns, - facts.FuncTypes, - facts.FunctionFacts, - ) - if len(symbols) == 0 { - return nil - } - - out := make(api.FunctionFacts, len(symbols)) - factsCopy := facts - for _, sym := range symbols { - ff := readFunctionFactFromFacts(&factsCopy, sym) - if len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Func == nil { - continue - } - out[sym] = ff - } - if len(out) == 0 { - return nil - } - return out + facts.FunctionFacts = normalizeFunctionFactMap(facts.FunctionFacts) } diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 37396615..bd4ff769 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -63,7 +63,7 @@ func ReturnTypesRefine(a, b []typ.Type) bool { } // ReturnTypesExtendRecord reports whether a extends b by adding record fields. -// This treats record field supersets as refinements for return summaries. +// This treats record field supersets as refinements for return vectors. 
func ReturnTypesExtendRecord(a, b []typ.Type) bool { if len(a) == 0 || len(b) == 0 { return false @@ -125,16 +125,34 @@ func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { if ReturnTypesRepairNever(b, a) { return b, true } + if ReturnTypesStopRecursiveStructuralGrowth(a, b) { + return a, true + } + if ReturnTypesStopRecursiveStructuralGrowth(b, a) { + return b, true + } + if ReturnTypesRefineFalsyMapKeys(a, b) { + return a, true + } + if ReturnTypesRefineFalsyMapKeys(b, a) { + return b, true + } if ReturnTypesRefine(a, b) { if ReturnTypesAllNil(a) && !ReturnTypesAllNil(b) { return b, true } + if ReturnTypesNestedNilOnlyRegression(a, b) { + return b, true + } return a, true } if ReturnTypesRefine(b, a) { if ReturnTypesAllNil(b) && !ReturnTypesAllNil(a) { return a, true } + if ReturnTypesNestedNilOnlyRegression(b, a) { + return a, true + } return b, true } if ReturnTypesFillNilSlots(a, b) { @@ -143,15 +161,428 @@ func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { if ReturnTypesFillNilSlots(b, a) { return b, true } - if ReturnTypesExtendRecord(a, b) || ReturnTypesElideOptional(a, b) { + if (ReturnTypesExtendRecord(a, b) || ReturnTypesElideOptional(a, b)) && !ReturnTypesNestedNilOnlyRegression(a, b) { return a, true } - if ReturnTypesExtendRecord(b, a) || ReturnTypesElideOptional(b, a) { + if (ReturnTypesExtendRecord(b, a) || ReturnTypesElideOptional(b, a)) && !ReturnTypesNestedNilOnlyRegression(b, a) { return b, true } return nil, false } +// ReturnTypesRefineFalsyMapKeys reports whether candidate is the same map-like +// shape as baseline after removing stale falsy key members from baseline. This +// handles fixed-point rounds where an early branch-insensitive dynamic index +// observes a key as `string | false`, then the solved guard proves the actual +// write key is `string`. 
+func ReturnTypesRefineFalsyMapKeys(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + refines, changed := typeRefinesFalsyMapKey(candidate[i], baseline[i]) + if !refines { + return false + } + if changed { + strict = true + } + } + return strict +} + +func typeRefinesFalsyMapKey(candidate, baseline typ.Type) (bool, bool) { + candidate = unwrapStructuralShape(candidate) + baseline = unwrapStructuralShape(baseline) + if candidate == nil || baseline == nil { + return candidate == baseline, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + switch b := baseline.(type) { + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false, false + } + return mapKeyTruthyRefinement(c.Key, c.Value, b.Key, b.Value) + case *typ.Record: + if c, ok := candidate.(*typ.Map); ok { + if len(b.Fields) != 0 || b.Metatable != nil || !b.HasMapComponent() { + return false, false + } + return mapKeyTruthyRefinement(c.Key, c.Value, b.MapKey, b.MapValue) + } + c, ok := candidate.(*typ.Record) + if !ok || !c.HasMapComponent() || !b.HasMapComponent() { + return false, false + } + if c.Open && !b.Open { + return false, false + } + if len(c.Fields) != len(b.Fields) { + return false, false + } + for _, bf := range b.Fields { + cf := c.GetField(bf.Name) + if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly || !typ.TypeEquals(cf.Type, bf.Type) { + return false, false + } + } + if (c.Metatable == nil) != (b.Metatable == nil) || (c.Metatable != nil && !typ.TypeEquals(c.Metatable, b.Metatable)) { + return false, false + } + return mapKeyTruthyRefinement(c.MapKey, c.MapValue, b.MapKey, b.MapValue) + default: + return false, false + } +} + +func mapKeyTruthyRefinement(candidateKey, candidateValue, baselineKey, baselineValue typ.Type) (bool, bool) { + if !typ.TypeEquals(candidateValue, baselineValue) { + 
return false, false + } + refinedKey := narrow.ToTruthy(baselineKey) + if refinedKey == nil || refinedKey.Kind().IsNever() || typ.TypeEquals(refinedKey, baselineKey) { + return false, false + } + if typ.TypeEquals(candidateKey, refinedKey) || subtype.IsSubtype(candidateKey, refinedKey) { + return true, true + } + return false, false +} + +// ReturnTypesNestedNilOnlyRegression reports whether candidate's apparent +// refinement only adds nested nil facts over a more useful baseline shape. A +// required `nil` field or `unknown -> nil` field does not help callers, but it +// can make iterative structural facts oscillate with later non-nil evidence. +func ReturnTypesNestedNilOnlyRegression(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + for i := range candidate { + if typeNestedNilOnlyRegression(candidate[i], baseline[i]) { + return true + } + } + return false +} + +func typeNestedNilOnlyRegression(candidate, baseline typ.Type) bool { + candidate = unwrapStructuralShape(candidate) + baseline = unwrapStructuralShape(baseline) + if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { + return false + } + if unwrap.IsNilType(candidate) { + return typ.IsAny(baseline) || typ.IsUnknown(baseline) || unwrap.IsOptionalLike(baseline) + } + + switch c := candidate.(type) { + case *typ.Record: + b, ok := baseline.(*typ.Record) + if !ok { + return false + } + for _, cf := range c.Fields { + bf := b.GetField(cf.Name) + if bf == nil { + continue + } + if unwrap.IsNilType(cf.Type) && (bf.Optional || typ.IsAny(bf.Type) || typ.IsUnknown(bf.Type) || unwrap.IsOptionalLike(bf.Type)) { + return true + } + if typeNestedNilOnlyRegression(cf.Type, bf.Type) { + return true + } + } + if c.HasMapComponent() && b.HasMapComponent() { + return typeNestedNilOnlyRegression(c.MapValue, b.MapValue) + } + case *typ.Array: + if b, ok := baseline.(*typ.Array); ok { + return 
typeNestedNilOnlyRegression(c.Element, b.Element) + } + case *typ.Map: + if b, ok := baseline.(*typ.Map); ok { + return typeNestedNilOnlyRegression(c.Value, b.Value) + } + case *typ.Tuple: + b, ok := baseline.(*typ.Tuple) + if !ok || len(c.Elements) != len(b.Elements) { + return false + } + for i := range c.Elements { + if typeNestedNilOnlyRegression(c.Elements[i], b.Elements[i]) { + return true + } + } + case *typ.Function: + b, ok := baseline.(*typ.Function) + if !ok || len(c.Returns) != len(b.Returns) { + return false + } + for i := range c.Returns { + if typeNestedNilOnlyRegression(c.Returns[i], b.Returns[i]) { + return true + } + } + } + return false +} + +// ReturnTypesStopRecursiveStructuralGrowth reports whether growing embeds the +// same structural container shape as stable beneath its root. Recursive table +// builders such as deep-copy helpers otherwise look like ever-more-specific +// refinements: {[string]: any} -> {[string]: {[string]: nil}} -> ... . The +// existing top-level shape is already a sound upper bound, so keep it once the +// candidate starts feeding that shape back through one of its children. 
+func ReturnTypesStopRecursiveStructuralGrowth(stable, growing []typ.Type) bool { + if len(stable) == 0 || len(growing) == 0 || len(stable) != len(growing) { + return false + } + + strict := false + for i := range stable { + s := stable[i] + g := growing[i] + if s == nil || g == nil { + return false + } + if typ.TypeEquals(s, g) { + continue + } + if typ.IsAbsentOrUnknown(s) || !typeCanSelfEmbed(s) { + return false + } + if !shallowStructuralShapeEquals(g, s) { + return false + } + if !typeContainsNestedStructuralShape(g, s) { + return false + } + strict = true + } + return strict +} + +func typeContainsNestedStructuralShape(haystack, needle typ.Type) bool { + return typeContainsNestedStructuralShapeDepth(haystack, needle, make(map[typ.Type]bool), false) +} + +func typeContainsNestedStructuralShapeDepth(haystack, needle typ.Type, seen map[typ.Type]bool, belowContainer bool) bool { + if haystack == nil || needle == nil { + return false + } + if seen[haystack] { + return false + } + seen[haystack] = true + + node := unwrapStructuralShape(haystack) + if node == nil { + return false + } + if belowContainer && shallowStructuralShapeEquals(node, needle) { + return true + } + + descend := func(child typ.Type, childBelowContainer bool) bool { + return typeContainsNestedStructuralShapeDepth(child, needle, seen, childBelowContainer) + } + + switch n := node.(type) { + case *typ.Optional: + return descend(n.Inner, belowContainer) + case *typ.Union: + for _, member := range n.Members { + if descend(member, belowContainer) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range n.Members { + if descend(member, belowContainer) { + return true + } + } + return false + case *typ.Array: + return descend(n.Element, true) + case *typ.Map: + return descend(n.Key, true) || descend(n.Value, true) + case *typ.Tuple: + for _, elem := range n.Elements { + if descend(elem, true) { + return true + } + } + return false + case *typ.Record: + for _, field := 
range n.Fields { + if descend(field.Type, true) { + return true + } + } + if n.Metatable != nil && descend(n.Metatable, true) { + return true + } + if n.HasMapComponent() { + return descend(n.MapKey, true) || descend(n.MapValue, true) + } + return false + case *typ.Function: + for _, param := range n.Params { + if descend(param.Type, true) { + return true + } + } + if n.Variadic != nil && descend(n.Variadic, true) { + return true + } + for _, ret := range n.Returns { + if descend(ret, true) { + return true + } + } + return false + case *typ.Instantiated: + for _, arg := range n.TypeArgs { + if descend(arg, belowContainer) { + return true + } + } + return false + case *typ.Interface: + for _, method := range n.Methods { + if method.Type != nil && descend(method.Type, true) { + return true + } + } + return false + default: + return false + } +} + +func shallowStructuralShapeEquals(a, b typ.Type) bool { + a = unwrapStructuralShape(a) + b = unwrapStructuralShape(b) + if a == nil || b == nil { + return a == b + } + + switch av := a.(type) { + case *typ.Union: + for _, member := range av.Members { + if shallowStructuralShapeEquals(member, b) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range av.Members { + if shallowStructuralShapeEquals(member, b) { + return true + } + } + return false + } + switch bv := b.(type) { + case *typ.Union: + for _, member := range bv.Members { + if shallowStructuralShapeEquals(a, member) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range bv.Members { + if shallowStructuralShapeEquals(a, member) { + return true + } + } + return false + } + + switch av := a.(type) { + case *typ.Array: + _, ok := b.(*typ.Array) + return ok + case *typ.Map: + bv, ok := b.(*typ.Map) + return ok && shallowMapKeyShapeEquals(av.Key, bv.Key) + case *typ.Tuple: + bv, ok := b.(*typ.Tuple) + return ok && len(av.Elements) == len(bv.Elements) + case *typ.Record: + bv, ok := b.(*typ.Record) + 
return ok && shallowRecordShapeEquals(av, bv) + default: + return typ.TypeEquals(a, b) + } +} + +func unwrapStructuralShape(t typ.Type) typ.Type { + for t != nil { + switch v := t.(type) { + case *typ.Annotated: + if v.Inner == nil || v.Inner == t { + return t + } + t = v.Inner + case *typ.Alias: + if v.Target == nil || v.Target == t { + return t + } + t = v.Target + case *typ.Optional: + if v.Inner == nil || v.Inner == t { + return t + } + t = v.Inner + default: + return t + } + } + return nil +} + +func shallowMapKeyShapeEquals(a, b typ.Type) bool { + if a == nil || b == nil { + return a == b + } + if typ.TypeEquals(a, b) { + return true + } + return typ.IsAny(a) || typ.IsAny(b) || typ.IsUnknown(a) || typ.IsUnknown(b) +} + +func shallowRecordShapeEquals(a, b *typ.Record) bool { + if a == nil || b == nil { + return a == b + } + if a.HasMapComponent() != b.HasMapComponent() { + return false + } + if a.HasMapComponent() && !shallowMapKeyShapeEquals(a.MapKey, b.MapKey) { + return false + } + if len(a.Fields) != len(b.Fields) { + return false + } + for _, field := range a.Fields { + if b.GetField(field.Name) == nil { + return false + } + } + return true +} + // SelectRefiningReturnVector prefers candidate only when it is a directional // refinement of baseline. It never prefers baseline over candidate. // @@ -613,25 +1044,71 @@ func NormalizeReturnVector(rets []typ.Type) []typ.Type { return nil } out := make([]typ.Type, len(rets)) + copy(out, rets) + return NormalizeReturnVectorInPlace(out) +} + +// NormalizeReturnVectorInPlace replaces nil slots with explicit nil types in an +// owned return vector. 
+func NormalizeReturnVectorInPlace(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } for i, t := range rets { if t == nil { - out[i] = typ.Nil - } else { - out[i] = t + rets[i] = typ.Nil } } - return out + return rets +} + +func canonicalReturnVector(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } + for i, t := range rets { + if t != nil { + continue + } + out := make([]typ.Type, len(rets)) + copy(out, rets) + out[i] = typ.Nil + for j := i + 1; j < len(out); j++ { + if out[j] == nil { + out[j] = typ.Nil + } + } + return out + } + return rets } func normalizeAndPruneReturnVector(rets []typ.Type) []typ.Type { - out := NormalizeReturnVector(rets) - if len(out) == 0 { + if len(rets) == 0 { return nil } - for i, ret := range out { - out[i] = typ.PruneSoftUnionMembers(ret) + var out []typ.Type + for i, ret := range rets { + normalized := ret + if normalized == nil { + normalized = typ.Nil + } + pruned := typ.PruneSoftUnionMembers(normalized) + if out != nil { + out[i] = pruned + continue + } + if pruned == ret { + continue + } + out = make([]typ.Type, len(rets)) + copy(out, rets[:i]) + out[i] = pruned } - return out + if out != nil { + return out + } + return rets } // MergeReturnSummary applies the canonical return-summary merge policy shared by @@ -702,20 +1179,27 @@ func MergeFunctionFactType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } +type functionFactVariants struct { + funcs []*typ.Function + residuals []typ.Type +} + func mergeFunctionFactVariants(existing, candidate typ.Type) (typ.Type, bool) { - existingFns := functionVariantsForFactMerge(existing) - candidateFns := functionVariantsForFactMerge(candidate) - if len(existingFns) == 0 || len(candidateFns) == 0 { + existingVariants := splitFunctionFactVariants(existing) + candidateVariants := splitFunctionFactVariants(candidate) + if len(existingVariants.funcs) == 0 || len(candidateVariants.funcs) == 0 { return nil, false } - 
all := make([]*typ.Function, 0, len(existingFns)+len(candidateFns)) - all = append(all, existingFns...) - all = append(all, candidateFns...) + + all := make([]*typ.Function, 0, len(existingVariants.funcs)+len(candidateVariants.funcs)) + all = append(all, existingVariants.funcs...) + all = append(all, candidateVariants.funcs...) for i := 1; i < len(all); i++ { if !sameFunctionShapeForFactMerge(all[0], all[i]) { return nil, false } } + merged := all[0] for i := 1; i < len(all); i++ { next, _ := mergeFunctionFactsByShape(merged, all[i]).(*typ.Function) @@ -724,40 +1208,39 @@ func mergeFunctionFactVariants(existing, candidate typ.Type) (typ.Type, bool) { } merged = next } - return merged, true + + residuals := make([]typ.Type, 0, len(existingVariants.residuals)+len(candidateVariants.residuals)+1) + residuals = append(residuals, existingVariants.residuals...) + residuals = append(residuals, candidateVariants.residuals...) + if len(residuals) == 0 { + return merged, true + } + residuals = append(residuals, merged) + return typ.NewUnion(residuals...), true } -func functionVariantsForFactMerge(t typ.Type) []*typ.Function { - if t == nil { - return nil +func splitFunctionFactVariants(t typ.Type) functionFactVariants { + var out functionFactVariants + collectFunctionFactVariants(t, &out) + return out +} + +func collectFunctionFactVariants(t typ.Type, out *functionFactVariants) { + if t == nil || out == nil { + return } switch v := unwrap.Alias(t).(type) { - case *typ.Optional: - // Optional function values include nil. Do not collapse them to a plain - // function fact or we lose optionality in merged facts. - return nil - case *typ.Function: - return []*typ.Function{v} case *typ.Union: - if len(v.Members) == 0 { - return nil + for _, member := range v.Members { + collectFunctionFactVariants(member, out) } - var out []*typ.Function - for _, m := range v.Members { - fn := unwrap.Function(m) - if fn == nil { - // Only collapse union variants when the union is function-only. 
- // Mixed unions (for example function|nil) must stay untouched. - return nil - } - out = append(out, fn) - } - return out + return } if fn := unwrap.Function(t); fn != nil { - return []*typ.Function{fn} + out.funcs = append(out.funcs, fn) + return } - return nil + out.residuals = append(out.residuals, t) } func sameFunctionShapeForFactMerge(a, b *typ.Function) bool { diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index e41facdc..6c110940 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -3,9 +3,11 @@ package returns import ( "testing" + "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" typjoin "github.com/wippyai/go-lua/types/typ/join" + "github.com/wippyai/go-lua/types/typ/unwrap" ) func TestJoinReturnVectors_Empty(t *testing.T) { @@ -333,6 +335,119 @@ func TestMergeFunctionFactType_WidensParamToCoverObservedCallsites(t *testing.T) } } +func TestMergeFunctionFactType_KeepsBaselineOverNestedNilOnlyRegression(t *testing.T) { + baselineReturn := typ.NewRecord(). + Field("full_path", typ.String). + Field("parent", typ.Unknown). + OptField("after_all", typ.Nil). + SetOpen(true). + Build() + candidateReturn := typ.NewRecord(). + Field("full_path", typ.String). + Field("parent", typ.Nil). + Field("after_all", typ.Nil). + SetOpen(true). 
+ Build() + + baseline := typ.Func().Param("name", typ.Unknown).Returns(baselineReturn).Build() + candidate := typ.Func().Param("name", typ.Unknown).Returns(candidateReturn).Build() + + merged := MergeFunctionFactType(baseline, candidate) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Returns) != 1 { + t.Fatalf("expected merged function return, got %v", merged) + } + if !typ.TypeEquals(fn.Returns[0], baselineReturn) { + t.Fatalf("expected baseline record to survive nil-only refinement, got %v", fn.Returns[0]) + } +} + +func TestMergeReturnSummary_PrefersCurrentTruthyMapKeyRefinement(t *testing.T) { + baseline := typ.NewMap(typ.NewUnion(typ.String, typ.False), typ.Unknown) + candidate := typ.NewMap(typ.String, typ.Unknown) + + merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { + t.Fatalf("expected stale falsy map key to refine to %v, got %v", candidate, merged) + } +} + +func TestMergeReturnSummary_PrefersCurrentTruthyRecordMapKeyRefinement(t *testing.T) { + entryArray := typ.NewArray(typ.Unknown) + baseline := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.Nil, typ.String, typ.False), entryArray). + SetOpen(true). + Build() + candidate := typ.NewRecord(). + MapComponent(typ.String, entryArray). + Build() + + merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { + t.Fatalf("expected stale falsy record map key to refine to %v, got %v", candidate, merged) + } +} + +func TestMergeReturnSummary_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *testing.T) { + entryArray := typ.NewArray(typ.Unknown) + baseline := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.Nil, typ.String, typ.False), entryArray). + SetOpen(true). 
+ Build() + candidate := typ.NewMap(typ.String, entryArray) + + merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { + t.Fatalf("expected map to replace stale open record map %v, got %v", candidate, merged) + } +} + +func TestMergeFunctionFactType_CollapsesMixedFunctionUnionVariants(t *testing.T) { + base := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord().Field("full_path", typ.String).SetOpen(true).Build()). + Build() + withChildren := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + Field("children", typ.NewArray(typ.Unknown)). + SetOpen(true). + Build()). + Build() + withTests := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + Field("tests", typ.NewArray(typ.Unknown)). + SetOpen(true). + Build()). + Build() + + merged := MergeFunctionFactType(typ.NewUnion(typ.Nil, base, withChildren), withTests) + if merged == nil { + t.Fatal("expected merged type") + } + if fn := unwrap.Function(merged); fn == nil { + t.Fatalf("expected merged function variant, got %v", merged) + } else if len(fn.Returns) != 1 { + t.Fatalf("expected one return, got %v", fn.Returns) + } else { + rec, ok := fn.Returns[0].(*typ.Record) + if !ok { + t.Fatalf("expected record return, got %T", fn.Returns[0]) + } + for _, field := range []string{"full_path", "children", "tests"} { + if rec.GetField(field) == nil { + t.Fatalf("expected merged field %q in %v", field, rec) + } + } + } + if merged.Kind() != kind.Optional { + t.Fatalf("expected nil residual to be preserved as optional, got %v", merged) + } +} + func TestMergeFunctionFactType_DoesNotDropNonFunctionUnionMembers(t *testing.T) { fn := typ.Func().Param("x", typ.String).Returns(typ.String).Build() existing := typ.NewUnion(fn, typ.Number) diff --git a/compiler/check/returns/kernel.go b/compiler/check/returns/kernel.go index 
92363b84..9e73e24f 100644 --- a/compiler/check/returns/kernel.go +++ b/compiler/check/returns/kernel.go @@ -1,137 +1,50 @@ package returns import ( - "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) -// ReconcileFunctionFactInput captures all channels that can influence a single -// local-function fact slot during one update step. -type ReconcileFunctionFactInput struct { - ExistingSummary []typ.Type - ExistingNarrow []typ.Type - ExistingFunc typ.Type +// JoinFunctionFact precisely merges two observations for one local function +// during a single analysis iteration. +func JoinFunctionFact(existing, candidate api.FunctionFact) api.FunctionFact { + existing = NormalizeFunctionFact(existing) + candidate = NormalizeFunctionFact(candidate) + out := existing - CandidateSummary []typ.Type - CandidateNarrow []typ.Type - CandidateFunc typ.Type -} - -// ReconcileFunctionFactOutput is the canonical reconciled state for one symbol. -type ReconcileFunctionFactOutput struct { - Summary []typ.Type - Narrow []typ.Type - Func typ.Type -} - -// FunctionFactCandidate captures incoming candidate data for one symbol's -// function-related fact channels. -type FunctionFactCandidate struct { - Summary []typ.Type - Narrow []typ.Type - Func typ.Type -} - -// ReconcileFunctionFact centralizes reconciliation of return summary, narrow -// return summary, and function type for one symbol. -// -// This is the only policy entrypoint for function-fact channel convergence. 
-func ReconcileFunctionFact(in ReconcileFunctionFactInput) ReconcileFunctionFactOutput { - out := ReconcileFunctionFactOutput{ - Summary: NormalizeReturnVector(in.ExistingSummary), - Narrow: NormalizeReturnVector(in.ExistingNarrow), - Func: in.ExistingFunc, - } - - if len(in.CandidateSummary) > 0 { - out.Summary = MergeReturnSummary(out.Summary, in.CandidateSummary) + if len(candidate.Summary) > 0 { + out.Summary = MergeReturnSummary(out.Summary, candidate.Summary) } - if len(in.CandidateNarrow) > 0 { - out.Narrow = MergeReturnSummary(out.Narrow, in.CandidateNarrow) + if len(candidate.Narrow) > 0 { + out.Narrow = MergeReturnSummary(out.Narrow, candidate.Narrow) } - if in.CandidateFunc != nil { - out.Func = MergeFunctionFactType(out.Func, in.CandidateFunc) + if candidate.Type != nil { + out.Type = MergeFunctionFactType(out.Type, candidate.Type) } - // Keep summary and narrow channels mutually refining when post-flow narrow + // Keep summary and post-flow narrow results mutually refining when narrow // provides first-order information. MergeReturnSummary is the canonical // policy and already encodes directional refinement preference. if len(out.Narrow) > 0 { if len(out.Summary) == 0 { - out.Summary = NormalizeReturnVector(out.Narrow) + out.Summary = canonicalReturnVector(out.Narrow) } else { out.Summary = MergeReturnSummary(out.Summary, out.Narrow) } } - if fn := unwrap.Function(out.Func); fn != nil { + if fn := unwrap.Function(out.Type); fn != nil { alignedSummary := out.Summary - if len(out.Narrow) > 0 { - // Canonical tie-breaker: function facts track post-flow behavior. - // Narrow summaries are produced from solved flow and are authoritative - // for call-site typing in the current iteration. 
- alignedSummary = out.Narrow - } if len(alignedSummary) > 0 { if aligned, changed := AlignFunctionTypeWithSummary(fn, alignedSummary); changed { - out.Func = aligned + out.Type = aligned fn = aligned } } if len(out.Summary) == 0 && fn != nil && len(fn.Returns) > 0 { - out.Summary = NormalizeReturnVector(fn.Returns) + out.Summary = canonicalReturnVector(fn.Returns) } } return out } - -// MergeFunctionFactIntoFacts reconciles and writes function-related facts for -// one symbol into a facts bundle using canonical kernel policy. -func MergeFunctionFactIntoFacts(facts *api.Facts, sym cfg.SymbolID, candidate FunctionFactCandidate) { - if facts == nil || sym == 0 { - return - } - NormalizeFunctionFactChannels(facts) - mergeFunctionFactIntoNormalizedFacts(facts, sym, candidate) -} - -func mergeFunctionFactIntoNormalizedFacts(facts *api.Facts, sym cfg.SymbolID, candidate FunctionFactCandidate) { - existing := readFunctionFactFromFacts(facts, sym) - reconciled := ReconcileFunctionFact(ReconcileFunctionFactInput{ - ExistingSummary: existing.Summary, - ExistingNarrow: existing.Narrow, - ExistingFunc: existing.Func, - CandidateSummary: candidate.Summary, - CandidateNarrow: candidate.Narrow, - CandidateFunc: candidate.Func, - }) - writeFunctionFactToFacts(facts, sym, api.FunctionFact{ - Summary: reconciled.Summary, - Narrow: reconciled.Narrow, - Func: reconciled.Func, - }) -} - -// MergeFunctionFactsIntoFacts merges full function-fact channel maps into facts -// via the canonical single-symbol reconciliation path. 
-func MergeFunctionFactsIntoFacts( - facts *api.Facts, - summaries api.ReturnSummaries, - narrows api.NarrowReturnSummaries, - funcs api.FuncTypes, -) { - if facts == nil { - return - } - NormalizeFunctionFactChannels(facts) - for _, sym := range collectFunctionFactChannelSymbols(summaries, narrows, funcs, nil) { - mergeFunctionFactIntoNormalizedFacts(facts, sym, FunctionFactCandidate{ - Summary: summaries[sym], - Narrow: narrows[sym], - Func: funcs[sym], - }) - } -} diff --git a/compiler/check/returns/kernel_test.go b/compiler/check/returns/kernel_test.go index 665f285f..0364c5ca 100644 --- a/compiler/check/returns/kernel_test.go +++ b/compiler/check/returns/kernel_test.go @@ -8,138 +8,109 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestMergeFunctionFactIntoFacts_InitialWrite(t *testing.T) { - facts := &api.Facts{} +func TestJoinFunctionFact_InitialObservation(t *testing.T) { sym := cfg.SymbolID(11) fn := typ.Func().Returns(typ.String).Build() - MergeFunctionFactIntoFacts(facts, sym, FunctionFactCandidate{ + facts := api.Facts{FunctionFacts: api.FunctionFacts{sym: JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, - Func: fn, - }) + Type: fn, + })}} - if got := facts.ReturnSummaries[sym]; !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.Summary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { t.Fatalf("summary mismatch: got %v", got) } - if got := facts.NarrowReturns[sym]; !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.NarrowSummary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { t.Fatalf("narrow mismatch: got %v", got) } - if got := facts.FuncTypes[sym]; !typ.TypeEquals(got, fn) { + if got := facts.FunctionFacts.FunctionType(sym); !typ.TypeEquals(got, fn) { t.Fatalf("func mismatch: got %v", got) } } -func TestMergeFunctionFactIntoFacts_MatchesKernelReconcile(t *testing.T) { - sym := cfg.SymbolID(17) 
+func TestJoinFunctionFact_MergesExistingAndCandidate(t *testing.T) { existingFn := typ.Func().Returns(typ.Number).Build() candidateFn := typ.Func().Returns(typ.String).Build() - facts := &api.Facts{ - FunctionFacts: api.FunctionFacts{ - sym: { - Summary: []typ.Type{typ.Number}, - Narrow: []typ.Type{typ.Number}, - Func: existingFn, - }, - }, - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - FuncTypes: api.FuncTypes{ - sym: existingFn, - }, + existing := api.FunctionFact{ + Summary: []typ.Type{typ.Number}, + Narrow: []typ.Type{typ.Number}, + Type: existingFn, } - candidate := FunctionFactCandidate{ + candidate := api.FunctionFact{ Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, - Func: candidateFn, - } - existing := readFunctionFactFromFacts(facts, sym) - expected := ReconcileFunctionFact(ReconcileFunctionFactInput{ - ExistingSummary: existing.Summary, - ExistingNarrow: existing.Narrow, - ExistingFunc: existing.Func, - CandidateSummary: candidate.Summary, - CandidateNarrow: candidate.Narrow, - CandidateFunc: candidate.Func, - }) - - MergeFunctionFactIntoFacts(facts, sym, candidate) + Type: candidateFn, + } + got := JoinFunctionFact(existing, candidate) - if got := facts.ReturnSummaries[sym]; !ReturnTypesEqual(got, expected.Summary) { - t.Fatalf("summary mismatch: got %v want %v", got, expected.Summary) + if !ReturnTypesEqual(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + t.Fatalf("summary mismatch: got %v", got.Summary) } - if got := facts.NarrowReturns[sym]; !ReturnTypesEqual(got, expected.Narrow) { - t.Fatalf("narrow mismatch: got %v want %v", got, expected.Narrow) + if !ReturnTypesEqual(got.Narrow, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + t.Fatalf("narrow mismatch: got %v", got.Narrow) } - if got := facts.FuncTypes[sym]; !typ.TypeEquals(got, expected.Func) { - t.Fatalf("func mismatch: got %v want %v", got, 
expected.Func) + if got.Type == nil { + t.Fatal("expected merged function type") } } -func TestMergeFunctionFactsIntoFacts_BatchMerge(t *testing.T) { +func TestJoinFacts_BatchMergeFunctionFacts(t *testing.T) { symSummary := cfg.SymbolID(21) symNarrow := cfg.SymbolID(22) symFunc := cfg.SymbolID(23) - facts := &api.Facts{} funcType := typ.Func().Returns(typ.Boolean).Build() - MergeFunctionFactsIntoFacts( - facts, - api.ReturnSummaries{ - symSummary: []typ.Type{typ.String}, - }, - api.NarrowReturnSummaries{ - symNarrow: []typ.Type{typ.Number}, + facts := JoinFacts( + api.Facts{ + FunctionFacts: api.FunctionFacts{ + symSummary: {Summary: []typ.Type{typ.String}}, + symNarrow: {Narrow: []typ.Type{typ.Number}}, + }, }, - api.FuncTypes{ - symFunc: funcType, + api.Facts{ + FunctionFacts: api.FunctionFacts{ + symFunc: {Type: funcType}, + }, }, ) - if got := facts.ReturnSummaries[symSummary]; !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.Summary(symSummary); !ReturnTypesEqual(got, []typ.Type{typ.String}) { t.Fatalf("summary mismatch: got %v", got) } - if got := facts.NarrowReturns[symNarrow]; !ReturnTypesEqual(got, []typ.Type{typ.Number}) { + if got := facts.FunctionFacts.NarrowSummary(symNarrow); !ReturnTypesEqual(got, []typ.Type{typ.Number}) { t.Fatalf("narrow mismatch: got %v", got) } - if got := facts.FuncTypes[symFunc]; !typ.TypeEquals(got, funcType) { + if got := facts.FunctionFacts.FunctionType(symFunc); !typ.TypeEquals(got, funcType) { t.Fatalf("func mismatch: got %v", got) } } -func TestReconcileFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { +func TestJoinFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { openTop := typ.NewRecord().SetOpen(true).Build() existingFunc := typ.Func().Returns(openTop).Build() candidateFunc := typ.Func().Returns(openTop).Build() narrow := []typ.Type{typ.NewArray(typ.Unknown)} - out := ReconcileFunctionFact(ReconcileFunctionFactInput{ - ExistingSummary: 
[]typ.Type{openTop}, - ExistingNarrow: nil, - ExistingFunc: existingFunc, - CandidateSummary: []typ.Type{openTop}, - CandidateNarrow: narrow, - CandidateFunc: candidateFunc, - }) + out := JoinFunctionFact( + api.FunctionFact{Summary: []typ.Type{openTop}, Type: existingFunc}, + api.FunctionFact{Summary: []typ.Type{openTop}, Narrow: narrow, Type: candidateFunc}, + ) if !ReturnTypesEqual(normalizeAndPruneReturnVector(out.Summary), normalizeAndPruneReturnVector(narrow)) { t.Fatalf("summary mismatch: got %v want %v", out.Summary, narrow) } - fn, ok := out.Func.(*typ.Function) + fn, ok := out.Type.(*typ.Function) if !ok { - t.Fatalf("expected function fact, got %T", out.Func) + t.Fatalf("expected function fact, got %T", out.Type) } if !ReturnTypesEqual(normalizeAndPruneReturnVector(fn.Returns), normalizeAndPruneReturnVector(narrow)) { t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, narrow) } } -func TestReconcileFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { +func TestJoinFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { bad := []typ.Type{ typ.NewUnion( typ.NewRecord(). 
@@ -166,12 +137,10 @@ func TestReconcileFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { } existingFunc := typ.Func().Returns(bad...).Build() - out := ReconcileFunctionFact(ReconcileFunctionFactInput{ - ExistingSummary: bad, - ExistingNarrow: nil, - ExistingFunc: existingFunc, - CandidateNarrow: good, - }) + out := JoinFunctionFact( + api.FunctionFact{Summary: bad, Type: existingFunc}, + api.FunctionFact{Narrow: good}, + ) if !ReturnTypesEqual(out.Summary, good) { t.Fatalf("summary mismatch: got %v want %v", out.Summary, good) @@ -179,109 +148,86 @@ func TestReconcileFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { if !ReturnTypesEqual(out.Narrow, good) { t.Fatalf("narrow mismatch: got %v want %v", out.Narrow, good) } - fn, ok := out.Func.(*typ.Function) + fn, ok := out.Type.(*typ.Function) if !ok { - t.Fatalf("expected function fact, got %T", out.Func) + t.Fatalf("expected function fact, got %T", out.Type) } if !ReturnTypesEqual(fn.Returns, good) { t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, good) } } -func TestMergeFunctionFactIntoFacts_ReadsLegacyAndWritesCanonical(t *testing.T) { - sym := cfg.SymbolID(41) - facts := &api.Facts{ - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - FuncTypes: api.FuncTypes{ - sym: typ.Func().Returns(typ.Number).Build(), - }, - } - - MergeFunctionFactIntoFacts(facts, sym, FunctionFactCandidate{ - Summary: []typ.Type{typ.String}, - Narrow: []typ.Type{typ.String}, - Func: typ.Func().Returns(typ.String).Build(), - }) +func TestJoinFunctionFact_DoesNotAlignFunctionToNarrowFieldRegression(t *testing.T) { + withCapturedMethod := typ.NewRecord(). + Field("x", typ.Integer). + Field("get_x", typ.Func().Param("self", typ.Unknown).Returns(typ.Number).Build()). + Build() + flowOnly := typ.NewRecord(). + Field("x", typ.Integer). 
+ Build() + existingFunc := typ.Func().Returns(flowOnly).Build() + + out := JoinFunctionFact( + api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, + api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, + ) - ff, ok := facts.FunctionFacts[sym] - if !ok { - t.Fatal("expected canonical FunctionFacts entry") - } - if !ReturnTypesEqual(ff.Summary, facts.ReturnSummaries[sym]) { - t.Fatalf("summary drift: canonical=%v legacy=%v", ff.Summary, facts.ReturnSummaries[sym]) + if !ReturnTypesEqual(out.Summary, []typ.Type{withCapturedMethod}) { + t.Fatalf("summary mismatch: got %v want %v", out.Summary, []typ.Type{withCapturedMethod}) } - if !ReturnTypesEqual(ff.Narrow, facts.NarrowReturns[sym]) { - t.Fatalf("narrow drift: canonical=%v legacy=%v", ff.Narrow, facts.NarrowReturns[sym]) + fn, ok := out.Type.(*typ.Function) + if !ok { + t.Fatalf("expected function fact, got %T", out.Type) } - if !typ.TypeEquals(ff.Func, facts.FuncTypes[sym]) { - t.Fatalf("func drift: canonical=%v legacy=%v", ff.Func, facts.FuncTypes[sym]) + if !ReturnTypesEqual(fn.Returns, []typ.Type{withCapturedMethod}) { + t.Fatalf("func returns should preserve captured method summary, got %v", fn.Returns) } } -func TestNormalizeFunctionFactChannels_PromotesLegacyIntoCanonical(t *testing.T) { +func TestNormalizeFunctionFacts_CanonicalizesStoredFunctionFacts(t *testing.T) { sym := cfg.SymbolID(77) fn := typ.Func().Returns(typ.Number).Build() facts := &api.Facts{ - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.Number}, - }, - FuncTypes: api.FuncTypes{ - sym: fn, + FunctionFacts: api.FunctionFacts{ + sym: {Summary: []typ.Type{nil}, Narrow: []typ.Type{typ.Number}, Type: fn}, }, } - NormalizeFunctionFactChannels(facts) + NormalizeFunctionFacts(facts) ff, ok := facts.FunctionFacts[sym] if !ok { - t.Fatal("expected 
canonical FunctionFacts entry from legacy channels") + t.Fatal("expected canonical FunctionFacts entry") } - if !ReturnTypesEqual(ff.Summary, facts.ReturnSummaries[sym]) { - t.Fatalf("summary drift: canonical=%v legacy=%v", ff.Summary, facts.ReturnSummaries[sym]) + if !ReturnTypesEqual(ff.Summary, []typ.Type{typ.Nil}) { + t.Fatalf("summary mismatch: got %v", ff.Summary) } - if !ReturnTypesEqual(ff.Narrow, facts.NarrowReturns[sym]) { - t.Fatalf("narrow drift: canonical=%v legacy=%v", ff.Narrow, facts.NarrowReturns[sym]) + if !ReturnTypesEqual(ff.Narrow, []typ.Type{typ.Number}) { + t.Fatalf("narrow mismatch: got %v", ff.Narrow) } - if !typ.TypeEquals(ff.Func, facts.FuncTypes[sym]) { - t.Fatalf("func drift: canonical=%v legacy=%v", ff.Func, facts.FuncTypes[sym]) + if !typ.TypeEquals(ff.Type, fn) { + t.Fatalf("func mismatch: got %v", ff.Type) } } -func TestFunctionFactViews_UseLegacyChannelsWhenCanonicalMissing(t *testing.T) { +func TestFunctionFactsAccessorsReadCanonicalFacts(t *testing.T) { sym := cfg.SymbolID(88) fn := typ.Func().Returns(typ.String).Build() facts := api.Facts{ - ReturnSummaries: api.ReturnSummaries{ - sym: []typ.Type{typ.String}, - }, - NarrowReturns: api.NarrowReturnSummaries{ - sym: []typ.Type{typ.String}, - }, - FuncTypes: api.FuncTypes{ - sym: fn, + FunctionFacts: api.FunctionFacts{ + sym: {Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn}, }, } - summaries := SummaryViewFromFacts(facts) - if got := summaries[sym]; !ReturnTypesEqual(got, []typ.Type{typ.String}) { - t.Fatalf("summary view mismatch: got %v", got) + if got := facts.FunctionFacts.Summary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + t.Fatalf("summary mismatch: got %v", got) } - narrows := NarrowViewFromFacts(facts) - if got := narrows[sym]; !ReturnTypesEqual(got, []typ.Type{typ.String}) { - t.Fatalf("narrow view mismatch: got %v", got) + if got := facts.FunctionFacts.NarrowSummary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + 
t.Fatalf("narrow mismatch: got %v", got) } - funcs := FuncTypeViewFromFacts(facts) - if got := funcs[sym]; !typ.TypeEquals(got, fn) { - t.Fatalf("func view mismatch: got %v", got) + if got := facts.FunctionFacts.FunctionType(sym); !typ.TypeEquals(got, fn) { + t.Fatalf("func mismatch: got %v", got) } } diff --git a/compiler/check/returns/scc.go b/compiler/check/returns/scc.go index 3b2e5380..445383fd 100644 --- a/compiler/check/returns/scc.go +++ b/compiler/check/returns/scc.go @@ -18,7 +18,7 @@ import ( // and so on. This ordering ensures that when processing an SCC, all functions // it depends on have already been analyzed. // -// Type conversion is performed to bridge cfg.SymbolID and the uint64-based +// Type conversion maps cfg.SymbolID to the uint64-based // internal SCC implementation. func ComputeSymbolSCCs(adj map[cfg.SymbolID][]cfg.SymbolID) [][]cfg.SymbolID { if len(adj) == 0 { diff --git a/compiler/check/returns/signature.go b/compiler/check/returns/signature.go index 583332fe..55129ac6 100644 --- a/compiler/check/returns/signature.go +++ b/compiler/check/returns/signature.go @@ -9,7 +9,7 @@ import ( ) // BuildSeedFunctionTypeWithBindings builds a placeholder function type for an -// SCC sibling that has no return summary yet. +// SCC sibling that has no inferred return vector yet. // // Optional binder metadata enables implicit-self detection in method definitions. func BuildSeedFunctionTypeWithBindings( diff --git a/compiler/check/returns/types.go b/compiler/check/returns/types.go index 809fc066..6dc9460c 100644 --- a/compiler/check/returns/types.go +++ b/compiler/check/returns/types.go @@ -13,21 +13,20 @@ // an SCC, functions are processed together using fixpoint iteration until // return types stabilize. // -// # Return Summaries +// # Return Vectors // -// A return summary is a vector of types representing the types returned by -// a function. For `return a, b, c`, the summary would be [typeof(a), typeof(b), -// typeof(c)]. 
Summaries are accumulated across all return statements in a +// A return vector represents the types returned by a function. For +// `return a, b, c`, the vector is [typeof(a), typeof(b), typeof(c)]. Vectors +// are accumulated across all return statements in a // function body and joined to produce the final return type. // -// # Canonical vs Seed Summaries +// # Canonical Function Facts vs Iteration Vectors // -// Two summary stores are maintained: -// - Canonical: Fully computed return types from completed analysis -// - Seed: Provisional return types from the current iteration +// The stored authority is api.FunctionFacts. During SCC solving, the inferencer +// also keeps a provisional map of return vectors for the current iteration. // -// During analysis, seed summaries are used for functions in the current SCC -// (to avoid circular dependence), while canonical summaries are used for +// During analysis, iteration vectors are used for functions in the current SCC +// to avoid circular dependence, while canonical function facts are used for // functions outside the SCC (whose types are already known). // // # Parameter Hint Propagation @@ -66,6 +65,6 @@ type LocalFuncInfo struct { ParamHints []typ.Type } -// MaxReturnSummaryIterations limits fixpoint iterations for ReturnSummaries. +// MaxReturnSummaryIterations limits fixpoint iterations for return-vector inference. // Exceeding this indicates a bug (non-monotonic merge) or pathological recursion. 
const MaxReturnSummaryIterations = 10 diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index d0a8ca76..898c6360 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -3,7 +3,9 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/infer/paramhints" "github.com/wippyai/go-lua/internal" + "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" typjoin "github.com/wippyai/go-lua/types/typ/join" @@ -12,9 +14,6 @@ import ( // WidenFacts merges two interproc fact bundles. func WidenFacts(prev, next api.Facts) api.Facts { - NormalizeFunctionFactChannels(&prev) - NormalizeFunctionFactChannels(&next) - out := api.Facts{ ParamHints: WidenParamHints(prev.ParamHints, next.ParamHints), LiteralSigs: WidenLiteralSigs(prev.LiteralSigs, next.LiteralSigs), @@ -33,19 +32,7 @@ func WidenFacts(prev, next api.Facts) api.Facts { for _, sym := range symbols { prevFact := readFunctionFactFromFacts(&prev, sym) nextFact := readFunctionFactFromFacts(&next, sym) - reconciled := ReconcileFunctionFact(ReconcileFunctionFactInput{ - ExistingSummary: prevFact.Summary, - ExistingNarrow: prevFact.Narrow, - ExistingFunc: prevFact.Func, - CandidateSummary: nextFact.Summary, - CandidateNarrow: nextFact.Narrow, - CandidateFunc: nextFact.Func, - }) - writeFunctionFactToFacts(&out, sym, api.FunctionFact{ - Summary: widenReturnVectorForConvergence(reconciled.Summary), - Narrow: widenReturnVectorForConvergence(reconciled.Narrow), - Func: maybeWidenTypeForConvergence(reconciled.Func), - }) + writeFunctionFactToFacts(&out, sym, widenFunctionFactForConvergence(prevFact, nextFact)) } if len(out.FunctionFacts) == 0 { out.FunctionFacts = nil @@ -53,31 +40,59 @@ func WidenFacts(prev, next api.Facts) api.Facts { return out } -// WidenReturnSummaries merges return summaries through the 
canonical -// return-summary merge policy shared by all iterative channels. -func WidenReturnSummaries(prev, next api.ReturnSummaries) api.ReturnSummaries { - if prev == nil && next == nil { - return nil +// JoinFacts performs a precise same-iteration merge of interproc facts. +// Unlike WidenFacts, this may keep directional refinements that are useful +// inside one analysis round. Recursive fixpoint boundaries must use WidenFacts. +func JoinFacts(prev, next api.Facts) api.Facts { + out := api.Facts{ + ParamHints: JoinParamHints(prev.ParamHints, next.ParamHints), + LiteralSigs: JoinLiteralSigs(prev.LiteralSigs, next.LiteralSigs), + CapturedTypes: JoinCapturedTypes(prev.CapturedTypes, next.CapturedTypes), + CapturedFields: JoinCapturedFieldAssigns(prev.CapturedFields, next.CapturedFields), + CapturedContainers: JoinCapturedContainerMutations(prev.CapturedContainers, next.CapturedContainers), + ConstructorFields: JoinConstructorFields(prev.ConstructorFields, next.ConstructorFields), } - if prev == nil { - return next + + symbols := collectCanonicalFunctionFactSymbols(prev.FunctionFacts, next.FunctionFacts) + if len(symbols) > 0 { + out.FunctionFacts = make(api.FunctionFacts, len(symbols)) } - if next == nil { - return prev + for _, sym := range symbols { + prevFact := readFunctionFactFromFacts(&prev, sym) + nextFact := readFunctionFactFromFacts(&next, sym) + writeFunctionFactToFacts(&out, sym, JoinFunctionFact(prevFact, nextFact)) } - merged := make(api.ReturnSummaries, len(prev)+len(next)) - for _, sym := range cfg.SortedSymbolIDs(prev) { - merged[sym] = widenReturnVectorForConvergence(NormalizeReturnVector(prev[sym])) + return out +} + +func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFact { + out := api.FunctionFact{ + Summary: widenReturnSummaryForConvergence(prev.Summary, next.Summary), + Narrow: widenReturnSummaryForConvergence(prev.Narrow, next.Narrow), + Type: widenFunctionFactTypeForConvergence(prev.Type, next.Type), } - for _, 
sym := range cfg.SortedSymbolIDs(next) { - rets := next[sym] - if existing := merged[sym]; existing != nil { - merged[sym] = widenReturnVectorForConvergence(MergeReturnSummary(existing, rets)) + + // Narrow summaries can refine optional/non-nil returns, but a nil-only + // narrow observation must not erase an already-informative summary. + if len(out.Narrow) > 0 && !ReturnTypesAllNil(out.Narrow) { + if len(out.Summary) == 0 { + out.Summary = canonicalReturnVector(out.Narrow) } else { - merged[sym] = widenReturnVectorForConvergence(NormalizeReturnVector(rets)) + out.Summary = widenReturnSummaryForConvergence(out.Summary, out.Narrow) } } - return merged + + if fn := unwrap.Function(out.Type); fn != nil { + if len(out.Summary) > 0 { + if aligned, changed := AlignFunctionTypeWithSummary(fn, out.Summary); changed { + out.Type = widenFunctionFactTypeForConvergence(fn, aligned) + } + } else if len(fn.Returns) > 0 { + out.Summary = widenReturnSummaryForConvergence(nil, fn.Returns) + } + } + + return out } func shouldUseMonotoneReturnJoin(a, b []typ.Type) bool { @@ -322,6 +337,149 @@ func joinReturnTypeMonotone(a, b typ.Type) typ.Type { return typ.JoinPreferNonSoft(a, b) } +func widenReturnSummaryForConvergence(prev, next []typ.Type) []typ.Type { + prev = normalizeAndPruneReturnVector(prev) + next = normalizeAndPruneReturnVector(next) + if len(prev) == 0 { + return widenReturnVectorForConvergence(next) + } + if len(next) == 0 { + return widenReturnVectorForConvergence(prev) + } + + merged := MergeReturnSummary(prev, next) + if returnVectorUnsafePrecisionDrop(prev, merged) { + merged = prev + } + return widenReturnVectorForConvergence(normalizeAndPruneReturnVector(merged)) +} + +func returnVectorUnsafePrecisionDrop(prev, merged []typ.Type) bool { + if len(prev) == 0 || len(merged) == 0 || len(prev) != len(merged) { + return false + } + for i := range prev { + if typeUnsafePrecisionDrop(prev[i], merged[i]) { + return true + } + } + return false +} + +func 
typeUnsafePrecisionDrop(prev, merged typ.Type) bool { + if prev == nil || merged == nil || typ.TypeEquals(prev, merged) { + return false + } + if typeElidesOptional(merged, prev) { + return false + } + if refines, _ := typeRefinesFalsyMapKey(merged, prev); refines { + return false + } + if typ.IsAny(prev) || typ.IsUnknown(prev) { + return true + } + + switch p := unwrapStructuralShape(prev).(type) { + case *typ.Union: + if unionStrictMemberSubset(merged, p) { + return true + } + if subtype.IsSubtype(merged, p) && !subtype.IsSubtype(p, merged) { + return true + } + case *typ.Record: + m, ok := unwrapStructuralShape(merged).(*typ.Record) + if !ok { + break + } + for _, pf := range p.Fields { + mf := m.GetField(pf.Name) + if mf != nil && typeUnsafePrecisionDrop(pf.Type, mf.Type) { + return true + } + } + if p.HasMapComponent() && m.HasMapComponent() && typeUnsafePrecisionDrop(p.MapValue, m.MapValue) { + return true + } + case *typ.Array: + if m, ok := unwrapStructuralShape(merged).(*typ.Array); ok { + return typeUnsafePrecisionDrop(p.Element, m.Element) + } + case *typ.Map: + if m, ok := unwrapStructuralShape(merged).(*typ.Map); ok { + return typeUnsafePrecisionDrop(p.Key, m.Key) || typeUnsafePrecisionDrop(p.Value, m.Value) + } + case *typ.Tuple: + m, ok := unwrapStructuralShape(merged).(*typ.Tuple) + if !ok || len(p.Elements) != len(m.Elements) { + break + } + for i := range p.Elements { + if typeUnsafePrecisionDrop(p.Elements[i], m.Elements[i]) { + return true + } + } + case *typ.Function: + m, ok := unwrapStructuralShape(merged).(*typ.Function) + if !ok { + break + } + for i := 0; i < len(p.Params) && i < len(m.Params); i++ { + if typeUnsafePrecisionDrop(p.Params[i].Type, m.Params[i].Type) { + return true + } + } + for i := 0; i < len(p.Returns) && i < len(m.Returns); i++ { + if typeUnsafePrecisionDrop(p.Returns[i], m.Returns[i]) { + return true + } + } + } + + if subtype.IsSubtype(merged, prev) && !subtype.IsSubtype(prev, merged) && !TypeExtendsRecord(merged, 
prev) { + return true + } + return false +} + +func unionStrictMemberSubset(candidate typ.Type, baseline *typ.Union) bool { + if baseline == nil { + return false + } + candidateMembers := unionMembers(candidate) + if len(candidateMembers) == 0 { + candidateMembers = []typ.Type{candidate} + } + if len(candidateMembers) >= len(baseline.Members) { + return false + } + for _, member := range candidateMembers { + found := false + for _, baseMember := range baseline.Members { + if typ.TypeEquals(member, baseMember) { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +func unionMembers(t typ.Type) []typ.Type { + switch v := unwrapStructuralShape(t).(type) { + case *typ.Union: + return v.Members + case *typ.Optional: + return append([]typ.Type{typ.Nil}, unionMembers(v.Inner)...) + default: + return nil + } +} + // WidenParamHints merges two param hint maps using monotone union. func WidenParamHints(prev, next api.ParamHints) api.ParamHints { if prev == nil && next == nil { @@ -407,6 +565,8 @@ func joinParamHintVectors(a, b []typ.Type) []typ.Type { } func joinParamHint(a, b typ.Type) typ.Type { + a = paramhints.NormalizeHintType(a) + b = paramhints.NormalizeHintType(b) if a == nil { return b } @@ -419,6 +579,24 @@ func joinParamHint(a, b typ.Type) typ.Type { if unwrap.IsNilType(b) && !unwrap.IsNilType(a) { return a } + if typeCanSelfEmbed(a) && typeContainsEquivalent(b, a) && !typ.IsAbsentOrUnknown(a) { + if typeContainsUnion(a) { + return a + } + return typ.JoinPreferNonSoft(a, b) + } + if typeCanSelfEmbed(b) && typeContainsEquivalent(a, b) && !typ.IsAbsentOrUnknown(b) { + if typeContainsUnion(b) { + return b + } + return typ.JoinPreferNonSoft(a, b) + } + if typeIsTruthyRefinement(a, b) { + return a + } + if typeIsTruthyRefinement(b, a) { + return b + } if TypeExtendsRecord(a, b) { return a } @@ -428,16 +606,314 @@ func joinParamHint(a, b typ.Type) typ.Type { return typ.JoinPreferNonSoft(a, b) } +func typeIsTruthyRefinement(candidate, 
baseline typ.Type) bool { + if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { + return false + } + refined := narrow.ToTruthy(baseline) + if refined == nil || refined.Kind().IsNever() || typ.TypeEquals(refined, baseline) { + return false + } + return typ.TypeEquals(candidate, refined) || subtype.IsSubtype(candidate, refined) +} + +func typeCanSelfEmbed(t typ.Type) bool { + if t == nil { + return false + } + switch v := t.(type) { + case *typ.Annotated: + return typeCanSelfEmbed(v.Inner) + case *typ.Alias: + return typeCanSelfEmbed(v.Target) + case *typ.Optional: + return typeCanSelfEmbed(v.Inner) + case *typ.Union: + for _, member := range v.Members { + if typeCanSelfEmbed(member) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range v.Members { + if typeCanSelfEmbed(member) { + return true + } + } + return false + case *typ.Array, *typ.Map, *typ.Tuple, *typ.Record, *typ.Function: + return true + default: + return false + } +} + +func typeContainsEquivalent(haystack, needle typ.Type) bool { + if haystack == nil || needle == nil { + return false + } + return scanType(haystack, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if typ.TypeEquals(node, needle) { + return true, false + } + return false, true + }) +} + +func typeContainsUnion(t typ.Type) bool { + if t == nil { + return false + } + return scanType(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if _, ok := node.(*typ.Union); ok { + return true, false + } + return false, true + }) +} + +func canonicalInterprocValueType(t typ.Type) typ.Type { + if t == nil { + return nil + } + if fn := unwrap.Function(t); fn != nil { + return maybeWidenTypeForConvergence(fn) + } + return maybeWidenTypeForConvergence(t) +} + +func mergeInterprocValueType(existing, candidate typ.Type) typ.Type { + existing = canonicalInterprocValueType(existing) + candidate = canonicalInterprocValueType(candidate) + if existing == nil { + return candidate + } + if 
candidate == nil { + return existing + } + if unwrap.Function(existing) != nil || unwrap.Function(candidate) != nil { + return maybeWidenTypeForConvergence(widenFunctionFactTypeForConvergence(existing, candidate)) + } + return maybeWidenTypeForConvergence(widenValueTypeForConvergence(existing, candidate)) +} + +func normalizeInterprocValueType(t typ.Type) typ.Type { + if t == nil { + return nil + } + if fn := unwrap.Function(t); fn != nil { + return fn + } + return typ.PruneSoftUnionMembers(t) +} + +func joinInterprocValueType(existing, candidate typ.Type) typ.Type { + existing = normalizeInterprocValueType(existing) + candidate = normalizeInterprocValueType(candidate) + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + if unwrap.Function(existing) != nil || unwrap.Function(candidate) != nil { + return MergeFunctionFactType(existing, candidate) + } + return typ.JoinPreferNonSoft(existing, candidate) +} + +func widenValueTypeForConvergence(existing, candidate typ.Type) typ.Type { + existing = normalizeInterprocValueType(existing) + candidate = normalizeInterprocValueType(candidate) + if existing == nil { + return maybeWidenTypeForConvergence(candidate) + } + if candidate == nil { + return maybeWidenTypeForConvergence(existing) + } + existing = maybeWidenTypeForConvergence(existing) + candidate = maybeWidenTypeForConvergence(candidate) + if typ.TypeEquals(existing, candidate) { + return existing + } + if typ.IsAny(existing) || typ.IsUnknown(existing) { + return existing + } + if typ.IsAny(candidate) || typ.IsUnknown(candidate) { + return candidate + } + if typeElidesOptional(candidate, existing) { + return candidate + } + if TypeExtendsRecord(candidate, existing) && !typeContainsNestedStructuralShape(candidate, existing) { + return candidate + } + if refines, _ := typeRefinesFalsyMapKey(candidate, existing); refines { + return candidate + } + if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) 
{ + return existing + } + if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { + return candidate + } + return typ.JoinPreferNonSoft(existing, candidate) +} + +func widenFunctionFactTypeForConvergence(existing, candidate typ.Type) typ.Type { + existing = normalizeInterprocValueType(existing) + candidate = normalizeInterprocValueType(candidate) + if existing == nil { + return maybeWidenTypeForConvergence(candidate) + } + if candidate == nil { + return maybeWidenTypeForConvergence(existing) + } + existingFn := unwrap.Function(existing) + candidateFn := unwrap.Function(candidate) + if existingFn != nil && candidateFn != nil && sameFunctionShapeForFactMerge(existingFn, candidateFn) { + return maybeWidenTypeForConvergence(widenFunctionFactsByShape(existingFn, candidateFn)) + } + return widenValueTypeForConvergence(existing, candidate) +} + +func widenFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + + builder := typ.Func() + for _, tp := range existing.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for i, p := range existing.Params { + paramType := widenFunctionParamFactTypeForConvergence(p.Type, candidate.Params[i].Type) + name := p.Name + if name == "" { + name = candidate.Params[i].Name + } + if p.Optional || candidate.Params[i].Optional { + builder = builder.OptParam(name, paramType) + } else { + builder = builder.Param(name, paramType) + } + } + if existing.Variadic != nil || candidate.Variadic != nil { + builder = builder.Variadic(widenFunctionParamFactTypeForConvergence(existing.Variadic, candidate.Variadic)) + } + if returns := widenReturnSummaryForConvergence(existing.Returns, candidate.Returns); len(returns) > 0 { + builder = builder.Returns(returns...) 
+ } + + effects := existing.Effects + if effects == nil { + effects = candidate.Effects + } + if effects != nil { + builder = builder.Effects(effects) + } + spec := existing.Spec + if spec == nil { + spec = candidate.Spec + } + if spec != nil { + builder = builder.Spec(spec) + } + refinement := existing.Refinement + if refinement == nil { + refinement = candidate.Refinement + } + if refinement != nil { + builder = builder.WithRefinement(refinement) + } + return builder.Build() +} + +func widenFunctionParamFactTypeForConvergence(existing, candidate typ.Type) typ.Type { + existing = normalizeInterprocValueType(existing) + candidate = normalizeInterprocValueType(candidate) + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + if typ.TypeEquals(existing, candidate) { + return existing + } + if typ.IsAny(existing) || typ.IsUnknown(existing) { + return existing + } + if typ.IsAny(candidate) || typ.IsUnknown(candidate) { + return candidate + } + if typeElidesOptional(candidate, existing) || typeIsTruthyRefinement(candidate, existing) { + return candidate + } + if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { + return existing + } + if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { + return candidate + } + return typ.JoinPreferNonSoft(existing, candidate) +} + +// JoinParamHints merges parameter hints inside one analysis iteration. +func JoinParamHints(prev, next api.ParamHints) api.ParamHints { + return WidenParamHints(prev, next) +} + // WidenLiteralSigs merges two literal signature maps. 
func WidenLiteralSigs(prev, next api.LiteralSigs) api.LiteralSigs { if prev == nil && next == nil { return nil } if prev == nil { - return next + return normalizeLiteralSigs(next) } if next == nil { - return prev + return normalizeLiteralSigs(prev) + } + merged := make(api.LiteralSigs, len(prev)+len(next)) + for fn, sig := range prev { + merged[fn] = maybeWidenFunctionForConvergence(sig) + } + for fn, sig := range next { + if existing := merged[fn]; existing != nil { + merged[fn] = maybeWidenFunctionForConvergence(mergeLiteralSigForConvergence(existing, sig)) + } else { + merged[fn] = maybeWidenFunctionForConvergence(sig) + } + } + return merged +} + +func normalizeLiteralSigs(sigs api.LiteralSigs) api.LiteralSigs { + if sigs == nil { + return nil + } + out := make(api.LiteralSigs, len(sigs)) + for fn, sig := range sigs { + out[fn] = maybeWidenFunctionForConvergence(sig) + } + return out +} + +// JoinLiteralSigs merges literal signatures precisely inside one iteration. +func JoinLiteralSigs(prev, next api.LiteralSigs) api.LiteralSigs { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return normalizeLiteralSigs(next) + } + if next == nil { + return normalizeLiteralSigs(prev) } merged := make(api.LiteralSigs, len(prev)+len(next)) for fn, sig := range prev { @@ -476,27 +952,61 @@ func mergeLiteralSig(prev, next *typ.Function) *typ.Function { return prev } +func mergeLiteralSigForConvergence(prev, next *typ.Function) *typ.Function { + merged := widenFunctionFactTypeForConvergence(prev, next) + if fn := unwrap.Function(merged); fn != nil { + return fn + } + return mergeLiteralSig(prev, next) +} + // WidenCapturedTypes merges two captured type maps using monotone join. 
func WidenCapturedTypes(prev, next api.CapturedTypes) api.CapturedTypes { if prev == nil && next == nil { return nil } if prev == nil { - return next + return normalizeCapturedTypes(next) } if next == nil { - return prev + return normalizeCapturedTypes(prev) + } + merged := make(api.CapturedTypes, len(prev)+len(next)) + for _, sym := range cfg.SortedSymbolIDs(prev) { + merged[sym] = canonicalInterprocValueType(prev[sym]) + } + for _, sym := range cfg.SortedSymbolIDs(next) { + t := next[sym] + if existing := merged[sym]; existing != nil { + merged[sym] = mergeInterprocValueType(existing, t) + } else { + merged[sym] = canonicalInterprocValueType(t) + } + } + return merged +} + +// JoinCapturedTypes merges captured types precisely inside one iteration. +func JoinCapturedTypes(prev, next api.CapturedTypes) api.CapturedTypes { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return normalizeCapturedTypesForJoin(next) + } + if next == nil { + return normalizeCapturedTypesForJoin(prev) } merged := make(api.CapturedTypes, len(prev)+len(next)) for _, sym := range cfg.SortedSymbolIDs(prev) { - merged[sym] = prev[sym] + merged[sym] = normalizeInterprocValueType(prev[sym]) } for _, sym := range cfg.SortedSymbolIDs(next) { t := next[sym] if existing := merged[sym]; existing != nil { - merged[sym] = maybeWidenTypeForConvergence(typ.JoinPreferNonSoft(existing, t)) + merged[sym] = joinInterprocValueType(existing, t) } else { - merged[sym] = maybeWidenTypeForConvergence(t) + merged[sym] = normalizeInterprocValueType(t) } } return merged @@ -508,53 +1018,161 @@ func WidenCapturedFieldAssigns(prev, next api.CapturedFieldAssigns) api.Captured return nil } if prev == nil { - return next + return normalizeCapturedFieldAssigns(next) } if next == nil { - return prev + return normalizeCapturedFieldAssigns(prev) } merged := make(api.CapturedFieldAssigns, len(prev)+len(next)) for _, callee := range cfg.SortedSymbolIDs(prev) { - merged[callee] = prev[callee] + 
merged[callee] = normalizeCapturedFieldSymbolMap(prev[callee]) } for _, callee := range cfg.SortedSymbolIDs(next) { captured := next[callee] existing := merged[callee] if existing == nil { - merged[callee] = captured + merged[callee] = normalizeCapturedFieldSymbolMap(captured) continue } merged[callee] = MergeCapturedFieldSymbolMaps(existing, captured, func(prev typ.Type, next typ.Type) typ.Type { if prev != nil { - return maybeWidenTypeForConvergence(typ.JoinPreferNonSoft(prev, next)) + return mergeInterprocValueType(prev, next) } - return maybeWidenTypeForConvergence(next) + return canonicalInterprocValueType(next) }) } return merged } +func normalizeCapturedTypes(types api.CapturedTypes) api.CapturedTypes { + if types == nil { + return nil + } + out := make(api.CapturedTypes, len(types)) + for _, sym := range cfg.SortedSymbolIDs(types) { + out[sym] = canonicalInterprocValueType(types[sym]) + } + return out +} + +func normalizeCapturedTypesForJoin(types api.CapturedTypes) api.CapturedTypes { + if types == nil { + return nil + } + out := make(api.CapturedTypes, len(types)) + for _, sym := range cfg.SortedSymbolIDs(types) { + out[sym] = normalizeInterprocValueType(types[sym]) + } + return out +} + +func normalizeCapturedFieldAssigns(fields api.CapturedFieldAssigns) api.CapturedFieldAssigns { + if fields == nil { + return nil + } + out := make(api.CapturedFieldAssigns, len(fields)) + for _, callee := range cfg.SortedSymbolIDs(fields) { + out[callee] = normalizeCapturedFieldSymbolMap(fields[callee]) + } + return out +} + +func normalizeCapturedFieldSymbolMap(fieldsBySym map[cfg.SymbolID]map[string]typ.Type) map[cfg.SymbolID]map[string]typ.Type { + if fieldsBySym == nil { + return nil + } + out := make(map[cfg.SymbolID]map[string]typ.Type, len(fieldsBySym)) + for _, sym := range cfg.SortedSymbolIDs(fieldsBySym) { + fields := fieldsBySym[sym] + fieldOut := make(map[string]typ.Type, len(fields)) + for _, name := range cfg.SortedFieldNames(fields) { + fieldOut[name] = 
canonicalInterprocValueType(fields[name]) + } + out[sym] = fieldOut + } + return out +} + +// JoinCapturedFieldAssigns merges captured field assignments inside one iteration. +func JoinCapturedFieldAssigns(prev, next api.CapturedFieldAssigns) api.CapturedFieldAssigns { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return normalizeCapturedFieldAssignsForJoin(next) + } + if next == nil { + return normalizeCapturedFieldAssignsForJoin(prev) + } + merged := make(api.CapturedFieldAssigns, len(prev)+len(next)) + for _, callee := range cfg.SortedSymbolIDs(prev) { + merged[callee] = normalizeCapturedFieldSymbolMapForJoin(prev[callee]) + } + for _, callee := range cfg.SortedSymbolIDs(next) { + captured := next[callee] + existing := merged[callee] + if existing == nil { + merged[callee] = normalizeCapturedFieldSymbolMapForJoin(captured) + continue + } + merged[callee] = MergeCapturedFieldSymbolMaps(existing, captured, func(prev typ.Type, next typ.Type) typ.Type { + if prev != nil { + return joinInterprocValueType(prev, next) + } + return normalizeInterprocValueType(next) + }) + } + return merged +} + +func normalizeCapturedFieldAssignsForJoin(fields api.CapturedFieldAssigns) api.CapturedFieldAssigns { + if fields == nil { + return nil + } + out := make(api.CapturedFieldAssigns, len(fields)) + for _, callee := range cfg.SortedSymbolIDs(fields) { + out[callee] = normalizeCapturedFieldSymbolMapForJoin(fields[callee]) + } + return out +} + +func normalizeCapturedFieldSymbolMapForJoin(fieldsBySym map[cfg.SymbolID]map[string]typ.Type) map[cfg.SymbolID]map[string]typ.Type { + if fieldsBySym == nil { + return nil + } + out := make(map[cfg.SymbolID]map[string]typ.Type, len(fieldsBySym)) + for _, sym := range cfg.SortedSymbolIDs(fieldsBySym) { + fields := fieldsBySym[sym] + fieldOut := make(map[string]typ.Type, len(fields)) + for _, name := range cfg.SortedFieldNames(fields) { + fieldOut[name] = normalizeInterprocValueType(fields[name]) + } + out[sym] = 
fieldOut + } + return out +} + // WidenCapturedContainerMutations merges captured container mutation maps using monotone union. func WidenCapturedContainerMutations(prev, next api.CapturedContainerMutations) api.CapturedContainerMutations { if prev == nil && next == nil { return nil } if prev == nil { - return next + return normalizeCapturedContainerMutations(next) } if next == nil { - return prev + return normalizeCapturedContainerMutations(prev) } merged := make(api.CapturedContainerMutations, len(prev)+len(next)) for _, sym := range cfg.SortedSymbolIDs(prev) { - merged[sym] = prev[sym] + merged[sym] = normalizeCapturedContainerMutationMap(prev[sym]) } for _, sym := range cfg.SortedSymbolIDs(next) { muts := next[sym] existing := merged[sym] merged[sym] = MergeCapturedContainerMutationMaps(existing, muts, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { if prev != nil { - next.ValueType = maybeWidenTypeForConvergence(typ.JoinPreferNonSoft(prev.ValueType, next.ValueType)) + next.ValueType = mergeInterprocValueType(prev.ValueType, next.ValueType) } else { next.ValueType = maybeWidenTypeForConvergence(next.ValueType) } @@ -564,20 +1182,118 @@ func WidenCapturedContainerMutations(prev, next api.CapturedContainerMutations) return merged } +func normalizeCapturedContainerMutations(muts api.CapturedContainerMutations) api.CapturedContainerMutations { + if muts == nil { + return nil + } + out := make(api.CapturedContainerMutations, len(muts)) + for _, sym := range cfg.SortedSymbolIDs(muts) { + out[sym] = normalizeCapturedContainerMutationMap(muts[sym]) + } + return out +} + +func normalizeCapturedContainerMutationMap(muts map[cfg.SymbolID][]api.ContainerMutation) map[cfg.SymbolID][]api.ContainerMutation { + if muts == nil { + return nil + } + out := make(map[cfg.SymbolID][]api.ContainerMutation, len(muts)) + for _, sym := range cfg.SortedSymbolIDs(muts) { + entries := muts[sym] + if len(entries) == 0 { + continue + } + normalized := 
make([]api.ContainerMutation, len(entries)) + for i, mut := range entries { + normalized[i] = mut + normalized[i].ValueType = canonicalInterprocValueType(mut.ValueType) + } + out[sym] = normalized + } + if len(out) == 0 { + return nil + } + return out +} + +// JoinCapturedContainerMutations merges captured container mutations inside one iteration. +func JoinCapturedContainerMutations(prev, next api.CapturedContainerMutations) api.CapturedContainerMutations { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return normalizeCapturedContainerMutationsForJoin(next) + } + if next == nil { + return normalizeCapturedContainerMutationsForJoin(prev) + } + merged := make(api.CapturedContainerMutations, len(prev)+len(next)) + for _, sym := range cfg.SortedSymbolIDs(prev) { + merged[sym] = normalizeCapturedContainerMutationMapForJoin(prev[sym]) + } + for _, sym := range cfg.SortedSymbolIDs(next) { + muts := next[sym] + existing := merged[sym] + merged[sym] = MergeCapturedContainerMutationMaps(existing, muts, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { + if prev != nil { + next.ValueType = joinInterprocValueType(prev.ValueType, next.ValueType) + } else { + next.ValueType = normalizeInterprocValueType(next.ValueType) + } + return next + }) + } + return merged +} + +func normalizeCapturedContainerMutationsForJoin(muts api.CapturedContainerMutations) api.CapturedContainerMutations { + if muts == nil { + return nil + } + out := make(api.CapturedContainerMutations, len(muts)) + for _, sym := range cfg.SortedSymbolIDs(muts) { + out[sym] = normalizeCapturedContainerMutationMapForJoin(muts[sym]) + } + return out +} + +func normalizeCapturedContainerMutationMapForJoin(muts map[cfg.SymbolID][]api.ContainerMutation) map[cfg.SymbolID][]api.ContainerMutation { + if muts == nil { + return nil + } + out := make(map[cfg.SymbolID][]api.ContainerMutation, len(muts)) + for _, sym := range cfg.SortedSymbolIDs(muts) { + entries := 
muts[sym] + if len(entries) == 0 { + continue + } + normalized := make([]api.ContainerMutation, len(entries)) + for i, mut := range entries { + normalized[i] = mut + normalized[i].ValueType = normalizeInterprocValueType(mut.ValueType) + } + out[sym] = normalized + } + if len(out) == 0 { + return nil + } + return out +} + // WidenConstructorFields merges constructor field maps using monotone join. func WidenConstructorFields(prev, next api.ConstructorFields) api.ConstructorFields { if prev == nil && next == nil { return nil } if prev == nil { - return next + return normalizeConstructorFields(next) } if next == nil { - return prev + return normalizeConstructorFields(prev) } merged := make(api.ConstructorFields, len(prev)+len(next)) for _, sym := range cfg.SortedSymbolIDs(prev) { - merged[sym] = prev[sym] + merged[sym] = normalizeConstructorFieldMap(prev[sym]) } for _, sym := range cfg.SortedSymbolIDs(next) { fields := next[sym] @@ -593,7 +1309,7 @@ func WidenConstructorFields(prev, next api.ConstructorFields) api.ConstructorFie for _, name := range cfg.SortedFieldNames(fields) { t := fields[name] if prevType := out[name]; prevType != nil { - out[name] = maybeWidenTypeForConvergence(typ.JoinPreferNonSoft(prevType, t)) + out[name] = mergeInterprocValueType(prevType, t) } else { out[name] = maybeWidenTypeForConvergence(t) } @@ -603,6 +1319,89 @@ func WidenConstructorFields(prev, next api.ConstructorFields) api.ConstructorFie return merged } +func normalizeConstructorFields(fields api.ConstructorFields) api.ConstructorFields { + if fields == nil { + return nil + } + out := make(api.ConstructorFields, len(fields)) + for _, sym := range cfg.SortedSymbolIDs(fields) { + out[sym] = normalizeConstructorFieldMap(fields[sym]) + } + return out +} + +func normalizeConstructorFieldMap(fields map[string]typ.Type) map[string]typ.Type { + if fields == nil { + return nil + } + out := make(map[string]typ.Type, len(fields)) + for _, name := range cfg.SortedFieldNames(fields) { + 
out[name] = canonicalInterprocValueType(fields[name]) + } + return out +} + +// JoinConstructorFields merges constructor field maps inside one iteration. +func JoinConstructorFields(prev, next api.ConstructorFields) api.ConstructorFields { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return normalizeConstructorFieldsForJoin(next) + } + if next == nil { + return normalizeConstructorFieldsForJoin(prev) + } + merged := make(api.ConstructorFields, len(prev)+len(next)) + for _, sym := range cfg.SortedSymbolIDs(prev) { + merged[sym] = normalizeConstructorFieldMapForJoin(prev[sym]) + } + for _, sym := range cfg.SortedSymbolIDs(next) { + fields := next[sym] + existing := merged[sym] + if existing == nil { + merged[sym] = normalizeConstructorFieldMapForJoin(fields) + continue + } + out := make(map[string]typ.Type, len(existing)+len(fields)) + for _, name := range cfg.SortedFieldNames(existing) { + out[name] = existing[name] + } + for _, name := range cfg.SortedFieldNames(fields) { + t := fields[name] + if prevType := out[name]; prevType != nil { + out[name] = joinInterprocValueType(prevType, t) + } else { + out[name] = normalizeInterprocValueType(t) + } + } + merged[sym] = out + } + return merged +} + +func normalizeConstructorFieldsForJoin(fields api.ConstructorFields) api.ConstructorFields { + if fields == nil { + return nil + } + out := make(api.ConstructorFields, len(fields)) + for _, sym := range cfg.SortedSymbolIDs(fields) { + out[sym] = normalizeConstructorFieldMapForJoin(fields[sym]) + } + return out +} + +func normalizeConstructorFieldMapForJoin(fields map[string]typ.Type) map[string]typ.Type { + if fields == nil { + return nil + } + out := make(map[string]typ.Type, len(fields)) + for _, name := range cfg.SortedFieldNames(fields) { + out[name] = normalizeInterprocValueType(fields[name]) + } + return out +} + func mergeFunctionReturnsIfSameShape(prevFn, nextFn *typ.Function) (typ.Type, bool) { if prevFn == nil || nextFn == nil { return 
nil, false diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index 0300c739..5b837420 100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/returns/widen_test.go @@ -9,72 +9,60 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestWidenFacts_DoesNotOverrideReturnSummariesWithNarrowReturns(t *testing.T) { +func TestWidenFacts_DoesNotOverrideSummaryWithNilNarrow(t *testing.T) { prev := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Summary: []typ.Type{typ.Integer}}, }, - ReturnSummaries: api.ReturnSummaries{ - 1: []typ.Type{typ.Integer}, - }, } next := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Narrow: []typ.Type{typ.Nil}}, }, - NarrowReturns: api.NarrowReturnSummaries{ - 1: []typ.Type{typ.Nil}, - }, } merged := WidenFacts(prev, next) - got := merged.ReturnSummaries[1] + got := merged.FunctionFacts.Summary(1) if len(got) != 1 || !typ.TypeEquals(got[0], typ.Integer) { - t.Fatalf("expected ReturnSummaries[1]=integer, got %v", got) + t.Fatalf("expected summary[1]=integer, got %v", got) } } -func TestWidenFacts_ElidesOptionalFromNarrowReturns(t *testing.T) { +func TestWidenFacts_ElidesOptionalFromNarrowFunctionFact(t *testing.T) { prev := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Summary: []typ.Type{typ.NewOptional(typ.Integer)}}, }, - ReturnSummaries: api.ReturnSummaries{ - 1: []typ.Type{typ.NewOptional(typ.Integer)}, - }, } next := api.Facts{ FunctionFacts: api.FunctionFacts{ 1: {Narrow: []typ.Type{typ.Integer}}, }, - NarrowReturns: api.NarrowReturnSummaries{ - 1: []typ.Type{typ.Integer}, - }, } merged := WidenFacts(prev, next) - got := merged.ReturnSummaries[1] + got := merged.FunctionFacts.Summary(1) if len(got) != 1 || !typ.TypeEquals(got[0], typ.Integer) { - t.Fatalf("expected ReturnSummaries[1]=integer, got %v", got) + t.Fatalf("expected summary[1]=integer, got %v", got) } } -func TestWidenReturnSummaries_RefinesOptionalForFirstOrderTypes(t *testing.T) { - prev := api.ReturnSummaries{ - 
1: []typ.Type{typ.NewOptional(typ.Integer)}, - } - next := api.ReturnSummaries{ - 1: []typ.Type{typ.Integer}, - } +func TestWidenFacts_RefinesOptionalForFirstOrderFunctionSummary(t *testing.T) { + prev := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{typ.NewOptional(typ.Integer)}}, + }} + next := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{typ.Integer}}, + }} - merged := WidenReturnSummaries(prev, next) - got := merged[1] + merged := WidenFacts(prev, next) + got := merged.FunctionFacts.Summary(1) if len(got) != 1 || !typ.TypeEquals(got[0], typ.Integer) { t.Fatalf("expected integer after first-order refinement, got %v", got) } } -func TestWidenReturnSummaries_UsesMonotoneJoinForHigherOrderReturns(t *testing.T) { +func TestWidenFacts_UsesMonotoneJoinForHigherOrderFunctionSummary(t *testing.T) { nestedUnknown := typ.NewRecord(). Field("next", typ.Func().Returns(typ.Unknown).Build()). Build() @@ -89,21 +77,21 @@ func TestWidenReturnSummaries_UsesMonotoneJoinForHigherOrderReturns(t *testing.T Field("build", typ.Func().Returns(nestedString).Build()). 
Build() - prev := api.ReturnSummaries{ - 1: []typ.Type{base}, - } - next := api.ReturnSummaries{ - 1: []typ.Type{refined}, - } + prev := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{base}}, + }} + next := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{refined}}, + }} - merged := WidenReturnSummaries(prev, next) - got := merged[1] + merged := WidenFacts(prev, next) + got := merged.FunctionFacts.Summary(1) if len(got) != 1 || !typ.TypeEquals(got[0], base) { t.Fatalf("expected stable upper bound for higher-order return, got %v", got) } } -func TestWidenReturnSummaries_InterfaceMethodsDoNotBlockOptionalElision(t *testing.T) { +func TestWidenFacts_InterfaceMethodsDoNotBlockOptionalElision(t *testing.T) { dbType := typ.NewInterface("sql.DB", []typ.Method{ { Name: "release", @@ -114,20 +102,271 @@ func TestWidenReturnSummaries_InterfaceMethodsDoNotBlockOptionalElision(t *testi }, }) - prev := api.ReturnSummaries{ - 1: []typ.Type{typ.NewOptional(dbType)}, - } - next := api.ReturnSummaries{ - 1: []typ.Type{dbType}, - } + prev := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{typ.NewOptional(dbType)}}, + }} + next := api.Facts{FunctionFacts: api.FunctionFacts{ + 1: {Summary: []typ.Type{dbType}}, + }} - merged := WidenReturnSummaries(prev, next) - got := merged[1] + merged := WidenFacts(prev, next) + got := merged.FunctionFacts.Summary(1) if len(got) != 1 || !typ.TypeEquals(got[0], dbType) { t.Fatalf("expected optional elision for interface return, got %v", got) } } +func TestMergeReturnSummary_StopsRecursiveContainerReturnGrowth(t *testing.T) { + recordMap := func(value typ.Type) typ.Type { + return typ.NewRecord().MapComponent(typ.String, value).Build() + } + recordField := func(value typ.Type) typ.Type { + return typ.NewRecord().Field("value", value).SetOpen(true).Build() + } + + tests := []struct { + name string + stable typ.Type + growth typ.Type + }{ + { + name: "map", + stable: 
typ.NewMap(typ.String, typ.Any), + growth: typ.NewMap(typ.String, typ.NewMap(typ.String, typ.Nil)), + }, + { + name: "record map component", + stable: recordMap(typ.Any), + growth: recordMap(recordMap(typ.Nil)), + }, + { + name: "record field", + stable: recordField(typ.Any), + growth: recordField(recordField(typ.Nil)), + }, + { + name: "array", + stable: typ.NewArray(typ.Any), + growth: typ.NewArray(typ.NewArray(typ.Nil)), + }, + { + name: "tuple", + stable: typ.NewTuple(typ.Any), + growth: typ.NewTuple(typ.NewTuple(typ.Nil)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + merged := MergeReturnSummary([]typ.Type{tt.stable}, []typ.Type{tt.growth}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], tt.stable) { + t.Fatalf("expected stable recursive return shape, got %v", merged) + } + }) + } +} + +func TestMergeReturnSummary_KeepsNonRecursiveContainerRefinement(t *testing.T) { + stable := typ.NewMap(typ.String, typ.Any) + refined := typ.NewMap(typ.String, typ.String) + + merged := MergeReturnSummary([]typ.Type{stable}, []typ.Type{refined}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], refined) { + t.Fatalf("expected non-recursive map refinement to survive, got %v", merged) + } +} + +func TestWidenParamHints_StopsSelfEmbeddingRecordGrowth(t *testing.T) { + prevHint := typ.NewUnion( + typ.Number, + typ.NewRecord(). + Field("limit", typ.Any). + SetOpen(true). + Build(), + ) + nextHint := typ.NewRecord(). + Field("limit", prevHint). + SetOpen(true). + Build() + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{prevHint}}, + api.ParamHints{1: []typ.Type{nextHint}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, prevHint) { + t.Fatalf("expected stable previous hint, got %v", got) + } +} + +func TestWidenParamHints_StopsSelfEmbeddingContainerGrowth(t *testing.T) { + prevHint := typ.NewUnion( + typ.Number, + typ.NewRecord(). + Field("limit", typ.Any). + SetOpen(true). 
+ Build(), + ) + + tests := []struct { + name string + next typ.Type + }{ + { + name: "record", + next: typ.NewRecord(). + Field("value", prevHint). + SetOpen(true). + Build(), + }, + { + name: "array", + next: typ.NewArray(prevHint), + }, + { + name: "map", + next: typ.NewMap(typ.String, prevHint), + }, + { + name: "tuple", + next: typ.NewTuple(prevHint), + }, + { + name: "function", + next: typ.Func(). + Param("value", prevHint). + Returns(prevHint). + Build(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{prevHint}}, + api.ParamHints{1: []typ.Type{tt.next}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, prevHint) { + t.Fatalf("expected stable previous hint, got %v", got) + } + }) + } +} + +func TestWidenParamHints_KeepsFirstRecordWrapperObservation(t *testing.T) { + nextHint := typ.NewRecord(). + Field("limit", typ.Number). + SetOpen(true). + Build() + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{typ.Number}}, + api.ParamHints{1: []typ.Type{nextHint}}, + ) + + got := merged[1][0] + if typ.TypeEquals(got, typ.Number) { + t.Fatalf("expected wrapper observation to be preserved, got %v", got) + } + if !typ.TypeEquals(got, typ.NewUnion(typ.Number, nextHint)) { + t.Fatalf("expected number | wrapper hint, got %v", got) + } +} + +func TestWidenParamHints_JoinsNestedRecordObservations(t *testing.T) { + nested := typ.NewRecord(). + Field("routes", typ.NewRecord().Field("users", typ.Boolean).SetOpen(true).Build()). + SetOpen(true). + Build() + outer := typ.NewRecord(). + Field("api", nested). + SetOpen(true). 
+ Build() + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{outer}}, + api.ParamHints{1: []typ.Type{nested}}, + ) + + got := merged[1][0] + want := typ.NewUnion(outer, nested) + if !typ.TypeEquals(got, want) { + t.Fatalf("expected nested record observations to be joined as %v, got %v", want, got) + } +} + +func TestWidenParamHints_ReplacesStaleBroadHintWithCurrentRefinement(t *testing.T) { + stale := typ.NewUnion(typ.String, typ.False) + current := typ.String + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{stale}}, + api.ParamHints{1: []typ.Type{current}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, current) { + t.Fatalf("expected current refined hint %v to replace stale broad hint, got %v", current, got) + } +} + +func TestWidenCapturedFieldAssigns_NormalizesOptionalFunctionValues(t *testing.T) { + fn := typ.Func().Param("fn", typ.Unknown).Build() + merged := WidenCapturedFieldAssigns(nil, api.CapturedFieldAssigns{ + 1: {2: {"after_all": typ.NewOptional(fn)}}, + }) + + got := merged[1][2]["after_all"] + if !typ.TypeEquals(got, fn) { + t.Fatalf("expected optional function value to canonicalize to function, got %v", got) + } +} + +func TestWidenCapturedFieldAssigns_MergesSameShapeFunctionValues(t *testing.T) { + prevFn := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + SetOpen(true). + Build()). + Build() + nextFn := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + Field("children", typ.NewArray(typ.Unknown)). + SetOpen(true). + Build()). 
+ Build() + + merged := WidenCapturedFieldAssigns( + api.CapturedFieldAssigns{1: {2: {"describe": prevFn}}}, + api.CapturedFieldAssigns{1: {2: {"describe": nextFn}}}, + ) + + got := merged[1][2]["describe"] + if _, ok := got.(*typ.Union); ok { + t.Fatalf("expected function observations to merge, got union %v", got) + } + fn, ok := got.(*typ.Function) + if !ok { + t.Fatalf("expected merged function, got %T", got) + } + if len(fn.Returns) != 1 { + t.Fatalf("expected one return, got %d", len(fn.Returns)) + } + rec, ok := fn.Returns[0].(*typ.Record) + if !ok { + t.Fatalf("expected record return, got %T", fn.Returns[0]) + } + if rec.GetField("full_path") == nil || rec.GetField("children") == nil { + t.Fatalf("expected merged return fields, got %v", rec) + } +} + func TestMergeFunctionReturnsIfSameShape_GenericFunctions(t *testing.T) { prev := typ.Func(). TypeParam("T", nil). @@ -174,7 +413,7 @@ func TestMergeFunctionReturnsIfSameShape_GenericTypeParamsMustMatch(t *testing.T } } -func TestMergeFuncTypes_DoesNotRegressToNarrowerNilReturn(t *testing.T) { +func TestMergeFunctionFactType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { prev := typ.Func(). Returns(typ.NewOptional(typ.Integer)). Build() @@ -213,7 +452,7 @@ func TestMergeFunctionReturnsIfSameShape_NormalizesLeakedTypeParams(t *testing.T } } -func TestMergeFuncTypes_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { +func TestMergeFunctionFactType_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { merged := MergeFunctionFactType(typ.Integer, typ.Number) if !typ.TypeEquals(merged, typ.Number) { t.Fatalf("expected wider supertype number, got %v", merged) @@ -225,7 +464,7 @@ func TestMergeFuncTypes_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { } } -func TestMergeFuncTypes_IsCommutativeForIncomparableSignatures(t *testing.T) { +func TestMergeFunctionFactType_IsCommutativeForIncomparableSignatures(t *testing.T) { coarse := typ.Func(). Param("entries", typ.Any). Returns(typ.Integer). 
@@ -242,7 +481,7 @@ func TestMergeFuncTypes_IsCommutativeForIncomparableSignatures(t *testing.T) { } } -func TestMergeFuncTypes_AliasInputsUseCanonicalJoin(t *testing.T) { +func TestMergeFunctionFactType_AliasInputsUseCanonicalJoin(t *testing.T) { coarse := typ.NewAlias("CoarseFn", typ.Func(). Param("entries", typ.Any). Returns(typ.Integer). @@ -259,7 +498,7 @@ func TestMergeFuncTypes_AliasInputsUseCanonicalJoin(t *testing.T) { } } -func TestMergeFuncTypes_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { +func TestMergeFunctionFactType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { coarse := typ.Func(). Param("t", typ.NewRecord().SetOpen(true).Build()). Returns(typ.String). @@ -329,6 +568,20 @@ func TestWidenLiteralSigs_PrefersMergedSameShapeSignature(t *testing.T) { } } +func TestWidenLiteralSigs_NormalizesNilBranch(t *testing.T) { + lit := &ast.FunctionExpr{} + sig := typ.Func(). + Returns(typ.NewUnion(typ.NewRecord().Build(), typ.String)). + Build() + + merged := WidenLiteralSigs(nil, api.LiteralSigs{lit: sig}) + got := merged[lit] + want := maybeWidenFunctionForConvergence(sig) + if got == nil || !typ.TypeEquals(got, want) { + t.Fatalf("expected nil-branch literal signature %v to be normalized to %v, got %v", sig, want, got) + } +} + func TestTypeContainsFunction_IgnoresInterfaceMethodSignatures(t *testing.T) { iface := typ.NewInterface("Reader", []typ.Method{ { diff --git a/compiler/check/session.go b/compiler/check/session.go index 3475a9ae..6a746337 100644 --- a/compiler/check/session.go +++ b/compiler/check/session.go @@ -11,8 +11,8 @@ // Contains binding tables, CFG graphs, and module aliases. Never modified // during fixpoint iteration. // -// - IterationStore: Iteration-local state used during fixpoint convergence -// (revision counter and constructor field collection). +// - Snapshot inputs: query-tracked interprocedural snapshots used to revalidate +// cached function analysis when facts/effects actually change. 
// // - IterationScratch: Single-iteration state cleared at each boundary. // Tracks which literals have been analyzed, pending parameter hints, @@ -217,7 +217,11 @@ func (s *Session) ScopeDepthDiagState() map[*ast.FunctionExpr]bool { // New creates a session for checking a file. func New(ctx *db.QueryContext, name string) *Session { - store := store.NewSessionStore() + var database *db.DB + if ctx != nil { + database = ctx.DB() + } + store := store.NewSessionStoreWithDB(database) api.AttachStore(ctx, store) sess := &Session{ Ctx: ctx, diff --git a/compiler/check/siblings/doc.go b/compiler/check/siblings/doc.go index 2f84ed92..5bcc60ec 100644 --- a/compiler/check/siblings/doc.go +++ b/compiler/check/siblings/doc.go @@ -21,7 +21,7 @@ // // # Overlay System // -// [Overlay] provides a view that combines: +// [Overlay] combines: // - Stable sibling types from previous iterations // - Pending updates from current iteration // diff --git a/compiler/check/siblings/overlay.go b/compiler/check/siblings/overlay.go index ee6f2d57..5a39f432 100644 --- a/compiler/check/siblings/overlay.go +++ b/compiler/check/siblings/overlay.go @@ -22,8 +22,8 @@ type OverlayEntry struct { // sibling functions so that calls between them can be typed during // fixpoint iteration. type OverlayConfig struct { - // Summaries maps symbols to their return type summaries. - Summaries map[cfg.SymbolID][]typ.Type + // ReturnVectors maps symbols to their current inferred return vectors. + ReturnVectors map[cfg.SymbolID][]typ.Type // Siblings are the sibling functions in this scope group. Siblings []OverlayEntry @@ -31,7 +31,7 @@ type OverlayConfig struct { // CurrentSym is the symbol of the function being analyzed (excluded from overlay). CurrentSym cfg.SymbolID - // Services provides seed type resolution for siblings without summaries. + // Services provides seed type resolution for siblings without return vectors. 
Services OverlayServices } @@ -55,18 +55,18 @@ func (o OverlayServicesFuncs) SeedType(fn *ast.FunctionExpr) typ.Type { // BuildOverlay constructs an overlay map for return inference. // // This overlay is used during SCC-based return type inference. It provides -// function types for sibling functions based on their current return summaries. +// function types for sibling functions based on their current return vectors. // The current function (CurrentSym) is excluded from the overlay to avoid // circular dependence during its own analysis. // -// For siblings without summaries yet, placeholder function types are created +// For siblings without return vectors yet, placeholder function types are created // using seed type services to preserve parameter arity. This enables the fixpoint // to make progress even when not all return types are known. func BuildOverlay(c OverlayConfig) map[cfg.SymbolID]typ.Type { overlay := make(map[cfg.SymbolID]typ.Type) - // Add sibling function types with current return summaries. - for sym, returnTypes := range c.Summaries { + // Add sibling function types with current return vectors. + for sym, returnTypes := range c.ReturnVectors { if sym == c.CurrentSym { continue } @@ -75,7 +75,7 @@ func BuildOverlay(c OverlayConfig) map[cfg.SymbolID]typ.Type { } } - // Seed siblings without summaries with placeholder function types. + // Seed siblings without return vectors with placeholder function types. 
for _, sib := range c.Siblings { if sib.Symbol == c.CurrentSym { continue diff --git a/compiler/check/siblings/overlay_test.go b/compiler/check/siblings/overlay_test.go index 23c2872e..7ddc68a4 100644 --- a/compiler/check/siblings/overlay_test.go +++ b/compiler/check/siblings/overlay_test.go @@ -16,9 +16,9 @@ func TestBuildOverlay_Empty(t *testing.T) { } } -func TestBuildOverlay_WithSummaries(t *testing.T) { +func TestBuildOverlay_WithReturnVectors(t *testing.T) { conf := OverlayConfig{ - Summaries: map[cfg.SymbolID][]typ.Type{ + ReturnVectors: map[cfg.SymbolID][]typ.Type{ 1: {typ.String}, 2: {typ.Number}, }, @@ -35,7 +35,7 @@ func TestBuildOverlay_WithSummaries(t *testing.T) { func TestBuildOverlay_ExcludesCurrent(t *testing.T) { conf := OverlayConfig{ - Summaries: map[cfg.SymbolID][]typ.Type{ + ReturnVectors: map[cfg.SymbolID][]typ.Type{ 1: {typ.String}, }, CurrentSym: 1, @@ -56,7 +56,7 @@ func TestOverlayEntry(t *testing.T) { } } -func TestBuildOverlay_SeedsSiblingsWithoutSummaries(t *testing.T) { +func TestBuildOverlay_SeedsSiblingsWithoutReturnVectors(t *testing.T) { seedType := typ.Func().Param("x", typ.Number).Build() fn := &ast.FunctionExpr{} conf := OverlayConfig{ @@ -72,14 +72,14 @@ func TestBuildOverlay_SeedsSiblingsWithoutSummaries(t *testing.T) { } result := BuildOverlay(conf) if result[1] != seedType { - t.Error("should seed sibling without summary using SeedType") + t.Error("should seed sibling without a return vector using SeedType") } } -func TestBuildOverlay_SummaryOverridesSeed(t *testing.T) { +func TestBuildOverlay_ReturnVectorOverridesSeed(t *testing.T) { fn := &ast.FunctionExpr{} conf := OverlayConfig{ - Summaries: map[cfg.SymbolID][]typ.Type{ + ReturnVectors: map[cfg.SymbolID][]typ.Type{ 1: {typ.String}, }, Siblings: []OverlayEntry{ @@ -98,7 +98,7 @@ func TestBuildOverlay_SummaryOverridesSeed(t *testing.T) { t.Fatal("should produce function type") } if len(fn2.Returns) == 0 || fn2.Returns[0] != typ.String { - t.Error("summary should take 
precedence") + t.Error("return vector should take precedence") } } diff --git a/compiler/check/siblings/siblings.go b/compiler/check/siblings/siblings.go index e93e7b7f..24263a0d 100644 --- a/compiler/check/siblings/siblings.go +++ b/compiler/check/siblings/siblings.go @@ -18,18 +18,18 @@ // // # Build Algorithm // -// The Build function constructs sibling types through four steps: +// The Build function constructs sibling types through three steps: // 1. Seed from previous iteration (monotonic accumulation across fixpoint iterations) // 2. Merge captured variable types from the parent scope -// 3. Add sibling function types enriched with return summaries -// 4. Overlay literal signatures for refined function types +// 3. Add sibling function types from canonical function facts // // The result is a SymbolID -> Type map that can be injected into the type environment // when analyzing any function in the group. // // # Integration with Fixpoint // -// Sibling types are recomputed on each fixpoint iteration as return summaries improve. +// Sibling types are recomputed on each fixpoint iteration as canonical function +// facts improve. // The monotonic accumulation (step 1) ensures that types only grow more precise, // guaranteeing convergence. package siblings @@ -37,6 +37,7 @@ package siblings import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -70,8 +71,8 @@ type BuildConfig struct { // SiblingTypesPrev are sibling types from the previous iteration (monotonic accumulation). SiblingTypesPrev map[cfg.SymbolID]typ.Type - // FuncTypes are canonical local function types for this scope group. - FuncTypes map[cfg.SymbolID]typ.Type + // FunctionFacts are canonical local function facts for this scope group. 
+ FunctionFacts api.FunctionFacts // Services provides required lookups for sibling construction. Services BuildServices @@ -178,7 +179,7 @@ func Build(c BuildConfig) map[cfg.SymbolID]typ.Type { if !entry.IsLocal || entry.Symbol == 0 { continue } - fnType := c.FuncTypes[entry.Symbol] + fnType := c.FunctionFacts.FunctionType(entry.Symbol) if fnType == nil { continue } diff --git a/compiler/check/siblings/siblings_test.go b/compiler/check/siblings/siblings_test.go index eefc4a49..9a82b982 100644 --- a/compiler/check/siblings/siblings_test.go +++ b/compiler/check/siblings/siblings_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/typ" ) @@ -22,7 +23,7 @@ func TestBuild_WithFuncs(t *testing.T) { Funcs: []FuncEntry{ {Symbol: 1, IsLocal: true}, }, - FuncTypes: map[cfg.SymbolID]typ.Type{1: fnType}, + FunctionFacts: api.FunctionFacts{1: {Type: fnType}}, } result := Build(conf) if result == nil { @@ -38,7 +39,7 @@ func TestBuild_WithPrev(t *testing.T) { Funcs: []FuncEntry{ {Symbol: 1, IsLocal: true}, }, - FuncTypes: map[cfg.SymbolID]typ.Type{1: typ.Func().Build()}, + FunctionFacts: api.FunctionFacts{1: {Type: typ.Func().Build()}}, SiblingTypesPrev: map[cfg.SymbolID]typ.Type{ 2: typ.String, }, diff --git a/compiler/check/store/doc.go b/compiler/check/store/doc.go index 1e7739ed..b6d27010 100644 --- a/compiler/check/store/doc.go +++ b/compiler/check/store/doc.go @@ -10,12 +10,14 @@ // The store holds: // - Built CFGs indexed by graph ID // - Analysis results (types, flow facts, diagnostics) per function -// - Interprocedural facts (return summaries, parameter hints) +// - Interprocedural facts (canonical function facts, parameter hints) // - Module-level bindings and alias maps +// - Query-tracked interprocedural snapshot inputs for precise function-result +// cache revalidation // // # Session Integration // 
-// The store implements [api.StoreView] and [api.IterationStore] interfaces, +// The store implements [api.StoreReader] and [api.IterationStore] interfaces, // providing read access for queries and write access for the fixpoint driver. // // # Snapshot Isolation diff --git a/compiler/check/store/snapshot_inputs.go b/compiler/check/store/snapshot_inputs.go new file mode 100644 index 00000000..40b3d9ec --- /dev/null +++ b/compiler/check/store/snapshot_inputs.go @@ -0,0 +1,261 @@ +package store + +import ( + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/db" + "github.com/wippyai/go-lua/types/typ" +) + +// snapshotInputs are the Salsa-style source inputs for interprocedural reads. +// They mirror the store's read snapshots, so FuncResult queries depend on the +// exact graph facts, refinement symbols, and constructor symbols they read. 
+type snapshotInputs struct { + database *db.DB + + facts *db.Input[api.GraphKey, api.Facts] + factValues map[api.GraphKey]api.Facts + + refinements *db.Input[cfg.SymbolID, *constraint.FunctionRefinement] + refinementValues map[cfg.SymbolID]*constraint.FunctionRefinement + + constructorFields *db.Input[cfg.SymbolID, map[string]typ.Type] + constructorFieldValues map[cfg.SymbolID]map[string]typ.Type +} + +func newSnapshotInputs(database *db.DB) *snapshotInputs { + if database == nil { + return nil + } + return &snapshotInputs{ + database: database, + facts: db.NewInput[api.GraphKey, api.Facts](database), + factValues: make(map[api.GraphKey]api.Facts), + refinements: db.NewInput[cfg.SymbolID, *constraint.FunctionRefinement](database), + refinementValues: make(map[cfg.SymbolID]*constraint.FunctionRefinement), + constructorFields: db.NewInput[cfg.SymbolID, map[string]typ.Type](database), + constructorFieldValues: make(map[cfg.SymbolID]map[string]typ.Type), + } +} + +func (in *snapshotInputs) reset() { + if in == nil || in.database == nil { + return + } + for key := range in.factValues { + in.facts.Set(key, api.Facts{}) + } + clear(in.factValues) + for sym := range in.refinementValues { + in.refinements.Set(sym, nil) + } + clear(in.refinementValues) + for sym := range in.constructorFieldValues { + in.constructorFields.Set(sym, nil) + } + clear(in.constructorFieldValues) +} + +func (in *snapshotInputs) factsFor(ctx *db.QueryContext, key api.GraphKey) (api.Facts, bool) { + if in == nil || in.facts == nil { + return api.Facts{}, false + } + return in.facts.Get(ctx, key) +} + +func (in *snapshotInputs) setFacts(key api.GraphKey, facts api.Facts) { + if in == nil || in.facts == nil { + return + } + if factsEmpty(facts) { + if _, ok := in.factValues[key]; !ok { + return + } + delete(in.factValues, key) + in.facts.Set(key, api.Facts{}) + return + } + next := facts + if prev, ok := in.factValues[key]; ok && returns.FactsEqual(prev, next) { + return + } + in.factValues[key] = 
next + in.facts.Set(key, next) +} + +func (in *snapshotInputs) refinement(ctx *db.QueryContext, sym cfg.SymbolID) (*constraint.FunctionRefinement, bool) { + if in == nil || in.refinements == nil { + return nil, false + } + return in.refinements.Get(ctx, sym) +} + +func (in *snapshotInputs) setRefinement(sym cfg.SymbolID, refinement *constraint.FunctionRefinement) { + if in == nil || in.refinements == nil || sym == 0 { + return + } + if refinement == nil { + if _, ok := in.refinementValues[sym]; !ok { + return + } + delete(in.refinementValues, sym) + in.refinements.Set(sym, nil) + return + } + if prev, ok := in.refinementValues[sym]; ok && effectsEqual(prev, refinement) { + return + } + in.refinementValues[sym] = refinement + in.refinements.Set(sym, refinement) +} + +func (in *snapshotInputs) constructorFieldsFor( + ctx *db.QueryContext, + sym cfg.SymbolID, +) (map[string]typ.Type, bool) { + if in == nil || in.constructorFields == nil { + return nil, false + } + return in.constructorFields.Get(ctx, sym) +} + +func (in *snapshotInputs) setConstructorFields(sym cfg.SymbolID, fields map[string]typ.Type) { + if in == nil || in.constructorFields == nil || sym == 0 { + return + } + if len(fields) == 0 { + if _, ok := in.constructorFieldValues[sym]; !ok { + return + } + delete(in.constructorFieldValues, sym) + in.constructorFields.Set(sym, nil) + return + } + next := cloneFieldTypes(fields) + if prev, ok := in.constructorFieldValues[sym]; ok && constructorFieldMapsEqual(sym, prev, next) { + return + } + in.constructorFieldValues[sym] = next + in.constructorFields.Set(sym, next) +} + +func cloneFieldTypes(src map[string]typ.Type) map[string]typ.Type { + if len(src) == 0 { + return nil + } + out := make(map[string]typ.Type, len(src)) + for name, t := range src { + out[name] = t + } + return out +} + +func constructorFieldMapsEqual(sym cfg.SymbolID, a, b map[string]typ.Type) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + return returns.ConstructorFieldsEqual( + 
api.ConstructorFields{sym: a}, + api.ConstructorFields{sym: b}, + ) +} + +func (s *SessionStore) PushSnapshotReadContext(ctx *db.QueryContext) func() { + if s == nil || ctx == nil || s.snapshotInputs == nil { + return func() {} + } + prev := s.snapshotCtx + s.snapshotCtx = ctx + return func() { + s.snapshotCtx = prev + } +} + +func (s *SessionStore) currentInterprocFacts(key api.GraphKey) api.Facts { + if s == nil { + return api.Facts{} + } + var prev api.Facts + if s.InterprocPrev != nil && s.InterprocPrev.Facts != nil { + prev = s.InterprocPrev.Facts[key] + } + if s.InterprocNext != nil && s.InterprocNext.Facts != nil { + if next, ok := s.InterprocNext.Facts[key]; ok { + if factsEmpty(prev) { + return next + } + if factsEmpty(next) { + return prev + } + return returns.JoinFacts(prev, next) + } + } + return prev +} + +func (s *SessionStore) syncFactsInput(key api.GraphKey) { + if s == nil || s.snapshotInputs == nil { + return + } + s.snapshotInputs.setFacts(key, s.currentInterprocFacts(key)) +} + +func (s *SessionStore) syncSnapshotInputs() { + if s == nil || s.snapshotInputs == nil { + return + } + + factKeys := make(map[api.GraphKey]struct{}, len(s.snapshotInputs.factValues)) + for key := range s.snapshotInputs.factValues { + factKeys[key] = struct{}{} + } + if s.InterprocPrev != nil { + for key := range s.InterprocPrev.Facts { + factKeys[key] = struct{}{} + } + } + if s.InterprocNext != nil { + for key := range s.InterprocNext.Facts { + factKeys[key] = struct{}{} + } + } + for key := range factKeys { + s.syncFactsInput(key) + } + + refinementSyms := make(map[cfg.SymbolID]struct{}, len(s.snapshotInputs.refinementValues)) + for sym := range s.snapshotInputs.refinementValues { + refinementSyms[sym] = struct{}{} + } + if s.InterprocPrev != nil { + for sym := range s.InterprocPrev.Refinements { + refinementSyms[sym] = struct{}{} + } + } + for sym := range refinementSyms { + var refinement *constraint.FunctionRefinement + if s.InterprocPrev != nil { + refinement = 
s.InterprocPrev.Refinements[sym] + } + s.snapshotInputs.setRefinement(sym, refinement) + } + + constructorSyms := make(map[cfg.SymbolID]struct{}, len(s.snapshotInputs.constructorFieldValues)) + for sym := range s.snapshotInputs.constructorFieldValues { + constructorSyms[sym] = struct{}{} + } + if s.InterprocPrev != nil { + for sym := range s.InterprocPrev.ConstructorFields { + constructorSyms[sym] = struct{}{} + } + } + for sym := range constructorSyms { + var fields map[string]typ.Type + if s.InterprocPrev != nil { + fields = s.InterprocPrev.ConstructorFields[sym] + } + s.snapshotInputs.setConstructorFields(sym, fields) + } +} diff --git a/compiler/check/store/store.go b/compiler/check/store/store.go index ebd1534a..aecbeba3 100644 --- a/compiler/check/store/store.go +++ b/compiler/check/store/store.go @@ -10,6 +10,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/typ" ) @@ -18,14 +19,11 @@ type SessionStore struct { // Created once at the start of checking and shared by all CFG builds. Module *ModuleStore - // Iteration contains iteration-local state for fixpoint convergence. - Iteration *IterationStore - // Scratch contains iteration-local state cleared each cycle. Scratch *IterationScratch // InterprocPrev holds the stable interproc snapshot used during analysis. - // Updated at fixpoint boundaries to provide a consistent view. + // Updated at fixpoint boundaries to provide a consistent snapshot. InterprocPrev *InterprocState // InterprocNext accumulates facts/effects produced during the current iteration. InterprocNext *InterprocState @@ -37,6 +35,9 @@ type SessionStore struct { // Stored per-session to avoid cross-session contamination. 
lastSwapDiffs []string + snapshotInputs *snapshotInputs + snapshotCtx *db.QueryContext + phase api.Phase } @@ -93,13 +94,6 @@ type FunctionRegistry struct { ByGraphID map[uint64]*api.FunctionRef } -// IterationStore holds iteration-local state for fixpoint iteration. -type IterationStore struct { - // Revision is bumped at fixpoint iteration boundary. - // Included in FuncKey to invalidate cached results when interproc facts/effects change. - Revision uint64 -} - // IterationScratch holds iteration-local state cleared each cycle. // Not double-buffered; reset at each iteration boundary. type IterationScratch struct { @@ -158,7 +152,7 @@ func widenInterprocFacts(prev, next map[api.GraphKey]api.Facts) map[api.GraphKey if existing, ok := out[key]; ok { out[key] = returns.WidenFacts(existing, facts) } else { - out[key] = facts + out[key] = returns.WidenFacts(api.Facts{}, facts) } } return out @@ -166,13 +160,21 @@ func widenInterprocFacts(prev, next map[api.GraphKey]api.Facts) map[api.GraphKey // NewSessionStore creates an initialized store with all sub-structs. func NewSessionStore() *SessionStore { + return NewSessionStoreWithDB(nil) +} + +// NewSessionStoreWithDB creates a store whose interproc snapshots are tracked +// as query inputs. The checker uses this form so function-result queries can be +// revalidated from the exact facts/effects they read instead of from a coarse +// iteration revision key. +func NewSessionStoreWithDB(database *db.DB) *SessionStore { return &SessionStore{ Module: NewModuleStore(), - Iteration: NewIterationStore(), Scratch: NewIterationScratch(), InterprocPrev: NewInterprocState(), InterprocNext: NewInterprocState(), GraphParentHash: make(map[uint64]uint64), + snapshotInputs: newSnapshotInputs(database), phase: api.PhaseScopeCompute, } } @@ -236,11 +238,6 @@ func NewModuleStore() *ModuleStore { } } -// NewIterationStore creates an initialized iteration store. 
-func NewIterationStore() *IterationStore { - return &IterationStore{} -} - // NewIterationScratch creates an initialized iteration scratch. func NewIterationScratch() *IterationScratch { return &IterationScratch{ @@ -376,6 +373,7 @@ func (s *SessionStore) FixpointSwap() bool { diffs := s.swapInterprocChannels() s.resetScratch() + s.syncSnapshotInputs() // Record which channels changed for diagnostic reporting s.lastSwapDiffs = diffs @@ -397,31 +395,16 @@ func (s *SessionStore) FixpointChannelDiffs() []string { return out } -// Revision returns the current revision counter. -func (s *SessionStore) Revision() uint64 { - if s == nil || s.Iteration == nil { - return 0 - } - return s.Iteration.Revision -} - -// BumpRevision increments the revision counter. -func (s *SessionStore) BumpRevision() { - if s == nil { - return - } - if s.Iteration == nil { - s.Iteration = NewIterationStore() - } - s.Iteration.Revision++ -} - // LookupRefinementBySym returns the refinement for a function by its SymbolID. // Reads from the stable interproc refinement snapshot for order-independent analysis. 
func (s *SessionStore) LookupRefinementBySym(sym cfg.SymbolID) *constraint.FunctionRefinement { - if sym == 0 { + if s == nil || sym == 0 { return nil } + if s.snapshotInputs != nil { + refinement, _ := s.snapshotInputs.refinement(s.snapshotCtx, sym) + return refinement + } if s.InterprocPrev == nil || s.InterprocPrev.Refinements == nil { return nil } @@ -471,6 +454,10 @@ func (s *SessionStore) LookupConstructorFields(classSym cfg.SymbolID) map[string if s == nil || classSym == 0 { return nil } + if s.snapshotInputs != nil { + fields, _ := s.snapshotInputs.constructorFieldsFor(s.snapshotCtx, classSym) + return fields + } if s.InterprocPrev == nil { return nil } @@ -482,25 +469,21 @@ func (s *SessionStore) ClearIterationChannels() { if s == nil { return } - if s.Iteration == nil { - s.Iteration = NewIterationStore() - } if s.Scratch == nil { s.Scratch = NewIterationScratch() } s.InterprocPrev = NewInterprocState() s.InterprocNext = NewInterprocState() s.resetScratch() - s.Iteration.Revision = 0 s.lastSwapDiffs = nil + if s.snapshotInputs != nil { + s.snapshotInputs.reset() + } } -// RefinementStore returns a view over the stable interproc refinement snapshot. +// RefinementStore returns a reader over the stable interproc refinement snapshot. func (s *SessionStore) RefinementStore() api.RefinementStore { - if s == nil || s.InterprocPrev == nil { - return &snapshotRefinementStore{refinements: nil} - } - return &snapshotRefinementStore{refinements: s.InterprocPrev.Refinements} + return &snapshotRefinementStore{store: s} } // ModuleBindings returns the module binding table. 
@@ -557,36 +540,38 @@ func (s *SessionStore) ParentGraphKeyForSymbol(sym cfg.SymbolID) (api.GraphKey, return api.KeyForGraph(graph, parentHash), true } -func initInterprocFacts(f *api.Facts) { - if f.FunctionFacts == nil { - f.FunctionFacts = make(api.FunctionFacts) - } - if f.ParamHints == nil { - f.ParamHints = make(map[cfg.SymbolID][]typ.Type) - } - if f.LiteralSigs == nil { - f.LiteralSigs = make(map[*ast.FunctionExpr]*typ.Function) - } - if f.CapturedTypes == nil { - f.CapturedTypes = make(api.CapturedTypes) - } - if f.CapturedFields == nil { - f.CapturedFields = make(api.CapturedFieldAssigns) - } +func factsEmpty(f api.Facts) bool { + return len(f.FunctionFacts) == 0 && + len(f.ParamHints) == 0 && + len(f.LiteralSigs) == 0 && + len(f.CapturedTypes) == 0 && + len(f.CapturedFields) == 0 && + len(f.CapturedContainers) == 0 && + len(f.ConstructorFields) == 0 } -// UpdateInterprocFactsNext updates interproc facts for the next iteration. -// This is the public entry point used by post-flow analysis to record results. -func (s *SessionStore) UpdateInterprocFactsNext(key api.GraphKey, update func(*api.Facts)) { +// MergeInterprocFactsNext merges a canonical fact delta into the next +// interprocedural snapshot for the current iteration. +func (s *SessionStore) MergeInterprocFactsNext(key api.GraphKey, delta api.Facts) { if s == nil { return } s.ensureInterprocStates() - facts := s.InterprocNext.Facts[key] - initInterprocFacts(&facts) - update(&facts) - returns.NormalizeFunctionFactChannels(&facts) + existing := s.InterprocNext.Facts[key] + facts := returns.JoinFacts(existing, delta) + if factsEmpty(facts) { + if factsEmpty(existing) { + return + } + delete(s.InterprocNext.Facts, key) + s.syncFactsInput(key) + return + } + if returns.FactsEqual(existing, facts) { + return + } s.InterprocNext.Facts[key] = facts + s.syncFactsInput(key) } // Funcs returns the function map. 
@@ -797,16 +782,29 @@ func (s *SessionStore) SetModuleAliases(aliases map[cfg.SymbolID]string) { s.Module.ModuleAliases = aliases } -// GetInterprocFactsSnapshot returns the stable interproc facts snapshot for a graph. +// GetInterprocFactsSnapshot returns the current joined interproc facts snapshot +// for a graph. It starts from the stable previous snapshot and overlays facts +// produced earlier in the current iteration, giving deterministic Gauss-Seidel +// propagation instead of forcing every local refinement through a full outer +// iteration. func (s *SessionStore) GetInterprocFactsSnapshot( graph *cfg.Graph, parent *scope.State, ) api.Facts { - if s == nil || s.InterprocPrev == nil || s.InterprocPrev.Facts == nil || graph == nil || parent == nil { + if s == nil || graph == nil { return api.Facts{} } - key := api.KeyForGraph(graph, parent.Hash()) - return s.InterprocPrev.Facts[key] + key, ok := s.GraphKeyFor(graph, parent) + if !ok { + return api.Facts{} + } + if s.snapshotInputs != nil { + if facts, ok := s.snapshotInputs.factsFor(s.snapshotCtx, key); ok { + return facts + } + return api.Facts{} + } + return s.currentInterprocFacts(key) } // GetParamHintsSnapshot returns param hints from the stable interproc snapshot. @@ -818,31 +816,15 @@ func (s *SessionStore) GetParamHintsSnapshot( return s.GetInterprocFactsSnapshot(graph, parent).ParamHints } -// GetReturnSummariesSnapshot returns return summaries from the stable interproc snapshot. -func (s *SessionStore) GetReturnSummariesSnapshot( - graph *cfg.Graph, - parent *scope.State, -) map[cfg.SymbolID][]typ.Type { - s.requirePhase(api.PhaseScopeCompute) - return returns.SummaryViewFromFacts(s.GetInterprocFactsSnapshot(graph, parent)) -} - -// GetNarrowReturnSummariesSnapshot returns post-flow return summaries from the stable snapshot. 
-func (s *SessionStore) GetNarrowReturnSummariesSnapshot( - graph *cfg.Graph, - parent *scope.State, -) map[cfg.SymbolID][]typ.Type { - s.requirePhase(api.PhaseNarrowing) - return returns.NarrowViewFromFacts(s.GetInterprocFactsSnapshot(graph, parent)) -} - -// GetLocalFuncTypesSnapshot returns canonical local function types from the stable interproc snapshot. -func (s *SessionStore) GetLocalFuncTypesSnapshot( +// GetFunctionFactsSnapshot returns canonical function facts from the stable +// interproc snapshot. +func (s *SessionStore) GetFunctionFactsSnapshot( graph *cfg.Graph, parent *scope.State, -) map[cfg.SymbolID]typ.Type { - s.requirePhase(api.PhaseScopeCompute) - return returns.FuncTypeViewFromFacts(s.GetInterprocFactsSnapshot(graph, parent)) +) api.FunctionFacts { + s.requirePhase(api.PhaseScopeCompute, api.PhaseNarrowing) + facts := s.GetInterprocFactsSnapshot(graph, parent) + return facts.FunctionFacts } // GetLiteralSigsSnapshot returns literal signatures from the stable interproc snapshot. @@ -907,18 +889,12 @@ func (s *SessionStore) GetCapturedContainerMutationsSnapshot( // snapshotRefinementStore implements api.RefinementStore using the stable snapshot. 
type snapshotRefinementStore struct { - refinements map[cfg.SymbolID]*constraint.FunctionRefinement + store *SessionStore } func (o *snapshotRefinementStore) LookupRefinementBySym(sym cfg.SymbolID) *constraint.FunctionRefinement { if o == nil || sym == 0 { return nil } - if o.refinements == nil { - return nil - } - if refinement := o.refinements[sym]; refinement != nil { - return refinement - } - return nil + return o.store.LookupRefinementBySym(sym) } diff --git a/compiler/check/store/store_test.go b/compiler/check/store/store_test.go index 7855c401..c67e78b8 100644 --- a/compiler/check/store/store_test.go +++ b/compiler/check/store/store_test.go @@ -6,9 +6,9 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/typ" ) @@ -123,6 +123,28 @@ func TestWidenInterprocFacts_OnlyNext(t *testing.T) { } } +func TestWidenInterprocFacts_NormalizesNewFacts(t *testing.T) { + fn := typ.Func().Param("value", typ.Unknown).Build() + key := api.GraphKey{GraphID: 1, ParentHash: 2} + next := map[api.GraphKey]api.Facts{ + key: { + CapturedFields: api.CapturedFieldAssigns{ + cfg.SymbolID(10): { + cfg.SymbolID(20): { + "after_all": typ.NewOptional(fn), + }, + }, + }, + }, + } + + result := widenInterprocFacts(nil, next) + got := result[key].CapturedFields[cfg.SymbolID(10)][cfg.SymbolID(20)]["after_all"] + if !typ.TypeEquals(got, fn) { + t.Fatalf("expected new facts to be normalized through WidenFacts, got %v", got) + } +} + func TestWidenInterprocFacts_Merge(t *testing.T) { prev := map[api.GraphKey]api.Facts{ {GraphID: 1}: { @@ -144,7 +166,7 @@ func TestWidenInterprocFacts_Merge(t *testing.T) { } } -func TestReturnSummariesFromFacts_FallsBackToCanonical(t *testing.T) { +func 
TestFunctionFactsSummaryAccessor(t *testing.T) { facts := api.Facts{ FunctionFacts: api.FunctionFacts{ cfg.SymbolID(1): { @@ -152,13 +174,13 @@ func TestReturnSummariesFromFacts_FallsBackToCanonical(t *testing.T) { }, }, } - got := returns.SummaryViewFromFacts(facts) - if len(got) != 1 || len(got[cfg.SymbolID(1)]) != 1 || got[cfg.SymbolID(1)][0] != typ.String { - t.Fatalf("unexpected summary view: %#v", got) + got := facts.FunctionFacts.Summary(cfg.SymbolID(1)) + if len(got) != 1 || got[0] != typ.String { + t.Fatalf("unexpected summary: %#v", got) } } -func TestNarrowReturnSummariesFromFacts_FallsBackToCanonical(t *testing.T) { +func TestFunctionFactsNarrowAccessor(t *testing.T) { facts := api.Facts{ FunctionFacts: api.FunctionFacts{ cfg.SymbolID(2): { @@ -166,24 +188,142 @@ func TestNarrowReturnSummariesFromFacts_FallsBackToCanonical(t *testing.T) { }, }, } - got := returns.NarrowViewFromFacts(facts) - if len(got) != 1 || len(got[cfg.SymbolID(2)]) != 1 || got[cfg.SymbolID(2)][0] != typ.Number { - t.Fatalf("unexpected narrow view: %#v", got) + got := facts.FunctionFacts.NarrowSummary(cfg.SymbolID(2)) + if len(got) != 1 || got[0] != typ.Number { + t.Fatalf("unexpected narrow summary: %#v", got) } } -func TestLocalFuncTypesFromFacts_FallsBackToCanonical(t *testing.T) { +func TestFunctionFactsTypeAccessor(t *testing.T) { fn := typ.Func().Returns(typ.Boolean).Build() facts := api.Facts{ FunctionFacts: api.FunctionFacts{ cfg.SymbolID(3): { - Func: fn, + Type: fn, }, }, } - got := returns.FuncTypeViewFromFacts(facts) - if len(got) != 1 || !typ.TypeEquals(got[cfg.SymbolID(3)], fn) { - t.Fatalf("unexpected func type view: %#v", got) + got := facts.FunctionFacts.FunctionType(cfg.SymbolID(3)) + if !typ.TypeEquals(got, fn) { + t.Fatalf("unexpected function type: %#v", got) + } +} + +func TestGetInterprocFactsSnapshot_UsesStoredGraphParentHash(t *testing.T) { + graph := cfg.Build(&ast.FunctionExpr{}) + if graph == nil || graph.ID() == 0 { + t.Fatal("expected graph with stable 
ID") + } + + storedParent := scope.New().WithType("T", typ.String) + currentParent := scope.New().WithType("T", typ.Number) + if storedParent.Hash() == currentParent.Hash() { + t.Fatal("test requires different parent hashes") + } + + s := NewSessionStore() + s.SetGraphParentHash(graph.ID(), storedParent.Hash()) + s.SetParentScope(storedParent.Hash(), storedParent) + key := api.KeyForGraph(graph, storedParent.Hash()) + s.InterprocPrev.Facts[key] = api.Facts{ + FunctionFacts: api.FunctionFacts{ + cfg.SymbolID(1): {Summary: []typ.Type{typ.String}}, + }, + } + + got := s.GetInterprocFactsSnapshot(graph, currentParent) + summary := got.FunctionFacts.Summary(cfg.SymbolID(1)) + if len(summary) != 1 || !typ.TypeEquals(summary[0], typ.String) { + t.Fatalf("expected snapshot from stored parent hash, got %#v", summary) + } +} + +func TestGetInterprocFactsSnapshot_OverlaysCurrentIterationFacts(t *testing.T) { + graph := cfg.Build(&ast.FunctionExpr{}) + if graph == nil || graph.ID() == 0 { + t.Fatal("expected graph with stable ID") + } + + parent := scope.New().WithType("T", typ.String) + s := NewSessionStore() + s.SetGraphParentHash(graph.ID(), parent.Hash()) + s.SetParentScope(parent.Hash(), parent) + key := api.KeyForGraph(graph, parent.Hash()) + s.InterprocPrev.Facts[key] = api.Facts{ + FunctionFacts: api.FunctionFacts{ + cfg.SymbolID(1): {Summary: []typ.Type{typ.String}}, + }, + } + s.InterprocNext.Facts[key] = api.Facts{ + FunctionFacts: api.FunctionFacts{ + cfg.SymbolID(1): {Summary: []typ.Type{typ.Number}}, + }, + } + + got := s.GetInterprocFactsSnapshot(graph, parent) + summary := got.FunctionFacts.Summary(cfg.SymbolID(1)) + want := typ.NewUnion(typ.String, typ.Number) + if len(summary) != 1 || !typ.TypeEquals(summary[0], want) { + t.Fatalf("expected widened current snapshot %v, got %#v", want, summary) + } +} + +func TestMergeInterprocFactsNext_ReconcilesDeltasWithinIteration(t *testing.T) { + key := api.GraphKey{GraphID: 1, ParentHash: 2} + sym := cfg.SymbolID(7) + 
refined := typ.Func().Param("path", typ.String).Returns(typ.String).Build() + broad := typ.Func().Param("path", typ.Any).Returns(typ.String).Build() + + s := NewSessionStore() + first := api.Facts{FunctionFacts: api.FunctionFacts{sym: {Type: refined}}} + s.MergeInterprocFactsNext(key, first) + s.MergeInterprocFactsNext(key, api.Facts{ + FunctionFacts: api.FunctionFacts{ + sym: {Type: broad}, + }, + }) + + got := s.InterprocNext.Facts[key].FunctionFacts.FunctionType(sym) + if !typ.TypeEquals(got, refined) { + t.Fatalf("expected update boundary to keep canonical refined function fact, got %v", got) + } +} + +func TestSnapshotInputs_RevalidateFactQueries(t *testing.T) { + database := db.New() + ctx := db.NewQueryContext(database) + s := NewSessionStoreWithDB(database) + key := api.GraphKey{GraphID: 1, ParentHash: 2} + sym := cfg.SymbolID(7) + + calls := 0 + q := db.NewQuery("trackedFactsTest", func(ctx *db.QueryContext, key api.GraphKey) int { + calls++ + facts, _ := s.snapshotInputs.factsFor(ctx, key) + if len(facts.FunctionFacts.Summary(sym)) == 0 { + return 0 + } + return 1 + }, func(a, b int) bool { return a == b }) + + if got := q.Get(ctx, key); got != 0 { + t.Fatalf("initial query = %d, want 0", got) + } + if got := q.Get(ctx, key); got != 0 || calls != 1 { + t.Fatalf("unchanged query = %d calls=%d, want 0/1", got, calls) + } + + delta := api.Facts{FunctionFacts: api.FunctionFacts{ + sym: {Summary: []typ.Type{typ.String}}, + }} + s.MergeInterprocFactsNext(key, delta) + if got := q.Get(ctx, key); got != 1 || calls != 2 { + t.Fatalf("changed query = %d calls=%d, want 1/2", got, calls) + } + + s.MergeInterprocFactsNext(key, delta) + if got := q.Get(ctx, key); got != 1 || calls != 2 { + t.Fatalf("equal update query = %d calls=%d, want 1/2", got, calls) } } @@ -192,16 +332,10 @@ func TestSessionStore_Fields(t *testing.T) { Module: &ModuleStore{ Graphs: make(map[uint64]*cfg.Graph), }, - Iteration: &IterationStore{ - Revision: 5, - }, } if s.Module == nil { 
t.Error("Module should be set") } - if s.Iteration.Revision != 5 { - t.Error("Revision should be 5") - } } func TestModuleStore_Fields(t *testing.T) { @@ -226,13 +360,6 @@ func TestFunctionRegistry_Fields(t *testing.T) { } } -func TestIterationStore_Fields(t *testing.T) { - i := &IterationStore{Revision: 10} - if i.Revision != 10 { - t.Error("Revision not set") - } -} - func TestIterationScratch_Fields(t *testing.T) { s := &IterationScratch{ LiteralSigsByGraphID: make(map[uint64]map[*ast.FunctionExpr]*typ.Function), @@ -291,9 +418,6 @@ func TestClearIterationChannels_InitializesMissingState(t *testing.T) { s := &SessionStore{} s.ClearIterationChannels() - if s.Iteration == nil { - t.Fatal("expected iteration store to be initialized") - } if s.Scratch == nil { t.Fatal("expected scratch to be initialized") } @@ -305,14 +429,6 @@ func TestClearIterationChannels_InitializesMissingState(t *testing.T) { } } -func TestBumpRevision_InitializesIterationStore(t *testing.T) { - s := &SessionStore{} - s.BumpRevision() - if got := s.Revision(); got != 1 { - t.Fatalf("expected revision 1, got %d", got) - } -} - func TestFixpointChannelDiffs_ReturnsCopy(t *testing.T) { s := NewSessionStore() s.StoreFunctionRefinement(1, &constraint.FunctionRefinement{Terminates: true}) @@ -331,17 +447,3 @@ func TestFixpointChannelDiffs_ReturnsCopy(t *testing.T) { t.Fatalf("expected defensive copy, got %v", diffs2) } } - -func TestClearIterationChannels_ResetsRevision(t *testing.T) { - s := NewSessionStore() - s.BumpRevision() - s.BumpRevision() - if got := s.Revision(); got != 2 { - t.Fatalf("expected revision 2, got %d", got) - } - - s.ClearIterationChannels() - if got := s.Revision(); got != 0 { - t.Fatalf("expected revision reset to 0, got %d", got) - } -} diff --git a/compiler/check/synth/ops/logical_test.go b/compiler/check/synth/ops/logical_test.go index 2927edc8..f94d31a7 100644 --- a/compiler/check/synth/ops/logical_test.go +++ b/compiler/check/synth/ops/logical_test.go @@ -87,12 +87,13 
@@ func TestLogicalOrTyped_LeftOptional(t *testing.T) { } } -func TestLogicalOrTyped_SoftOptionalPrefersRight(t *testing.T) { +func TestLogicalOrTyped_SoftOptionalPreservesLeftRuntimeAlternative(t *testing.T) { left := typ.NewOptional(typ.NewArray(typ.Any)) right := typ.NewArray(typ.Number) result := LogicalOrTyped(left, right) - if result == nil || result.String() != "number[]" { - t.Errorf("expected right to win for soft optional, got %v", result) + want := typ.NewUnion(typ.NewArray(typ.Any), right) + if !typ.TypeEquals(result, want) { + t.Errorf("expected truthy left and right alternatives, got %v, want %v", result, want) } } diff --git a/compiler/check/synth/phase/core/params.go b/compiler/check/synth/phase/core/params.go index 5d9ba389..916925ee 100644 --- a/compiler/check/synth/phase/core/params.go +++ b/compiler/check/synth/phase/core/params.go @@ -91,6 +91,7 @@ func ApplyParamList(builder *typ.FunctionBuilder, fn *ast.FunctionExpr, cfg Para if builder == nil || fn == nil || fn.ParList == nil { return } + builder.ReserveParams(paramListCapacity(fn, cfg.ImplicitSelf)) shiftExpected := false if cfg.ImplicitSelf { @@ -160,3 +161,14 @@ func ApplyParamList(builder *typ.FunctionBuilder, fn *ast.FunctionExpr, cfg Para builder.Variadic(typ.Any) } } + +func paramListCapacity(fn *ast.FunctionExpr, implicitSelf bool) int { + if fn == nil || fn.ParList == nil { + return 0 + } + n := len(fn.ParList.Names) + if implicitSelf { + n++ + } + return n +} diff --git a/compiler/check/synth/phase/extract/callback_env_infer.go b/compiler/check/synth/phase/extract/callback_env_infer.go index 310d4f33..1c716ae3 100644 --- a/compiler/check/synth/phase/extract/callback_env_infer.go +++ b/compiler/check/synth/phase/extract/callback_env_infer.go @@ -102,8 +102,8 @@ func inferCallbackEnvOverlays( return nil } - idom, _ := analysis.ComputeDominators(baseCFG) - postIdom, _ := analysis.ComputePostDominators(baseCFG) + idom := analysis.ComputeImmediateDominators(baseCFG) + postIdom := 
analysis.ComputeImmediatePostDominators(baseCFG) result := make(map[int]map[string]typ.Type) diff --git a/compiler/check/synth/phase/extract/callback_env_infer_test.go b/compiler/check/synth/phase/extract/callback_env_infer_test.go index ac4c87f8..ccbe2a9a 100644 --- a/compiler/check/synth/phase/extract/callback_env_infer_test.go +++ b/compiler/check/synth/phase/extract/callback_env_infer_test.go @@ -171,7 +171,7 @@ func TestInferCallbackEnvOverlays_UsesCanonicalCandidatesWhenRawCallSymbolMissin } } -func TestInferCallbackEnvOverlays_UsesModuleBindingNameFallback(t *testing.T) { +func TestInferCallbackEnvOverlays_UsesModuleBindingNameResolution(t *testing.T) { code := ` _G.ctx = 1 local x = cb() @@ -199,7 +199,7 @@ func TestInferCallbackEnvOverlays_UsesModuleBindingNameFallback(t *testing.T) { moduleBindings := bind.NewBindingTable() moduleBindings.SetName(paramSlots[0].Symbol, "cb_alias") - // Force callback identity recovery through module-binding name fallback. + // Force callback identity recovery through module-binding name resolution. graph.EachCallSite(func(_ cfg.Point, info *cfg.CallInfo) { if info != nil { info.CalleeSymbol = 0 diff --git a/compiler/check/synth/phase/extract/deps.go b/compiler/check/synth/phase/extract/deps.go index 53ab09c0..e184b667 100644 --- a/compiler/check/synth/phase/extract/deps.go +++ b/compiler/check/synth/phase/extract/deps.go @@ -9,6 +9,7 @@ import ( "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/io" "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/typ" ) // Deps aggregates all dependencies needed by the Synthesizer. @@ -43,6 +44,8 @@ type Deps struct { // FunctionTypeInProgress guards call-point local function specialization // against recursion across temporary synthesizers. 
FunctionTypeInProgress map[functionTypeProgressKey]bool + FunctionTypeCache map[functionTypeCacheKey]*typ.Function + StableFunctionSnapshot map[stableFunctionSnapshotKey]typ.Type // Module-level bindings for nested function CFG building. ModuleBindings *bind.BindingTable @@ -56,6 +59,20 @@ type functionTypeProgressKey struct { CapturePoint cfg.Point } +type functionTypeCacheKey struct { + Func *ast.FunctionExpr + Scope *scope.State + Expected *typ.Function + CapturePoint cfg.Point + Phase api.Phase +} + +type stableFunctionSnapshotKey struct { + GraphID uint64 + Parent *scope.State + Sym cfg.SymbolID +} + // NewDeps creates a new Deps instance. func NewDeps(ctx *db.QueryContext, types core.TypeOps, scopes api.ScopeMap, manifests io.ManifestQuerier, checkCtx api.BaseEnv) *Deps { return &Deps{ @@ -68,6 +85,8 @@ func NewDeps(ctx *db.QueryContext, types core.TypeOps, scopes api.ScopeMap, mani PreCache: make(api.Cache), NarrowCache: make(api.Cache), FunctionTypeInProgress: make(map[functionTypeProgressKey]bool), + FunctionTypeCache: make(map[functionTypeCacheKey]*typ.Function), + StableFunctionSnapshot: make(map[stableFunctionSnapshotKey]typ.Type), } } @@ -87,7 +106,7 @@ func (d *Deps) Entry() cfg.Point { return 0 } -// ScopeAt returns scope for point p with optional default fallback. +// ScopeAt returns scope for point p, using DefaultScope when point scope is absent. func (d *Deps) ScopeAt(p cfg.Point) *scope.State { if d == nil { return nil diff --git a/compiler/check/synth/phase/extract/doc.go b/compiler/check/synth/phase/extract/doc.go index 6ba53b86..f4e8a892 100644 --- a/compiler/check/synth/phase/extract/doc.go +++ b/compiler/check/synth/phase/extract/doc.go @@ -36,6 +36,6 @@ // // # Integration // -// This package bridges the CFG representation with the expression-level +// This package connects the CFG representation with the expression-level // type synthesis performed by the synth engine. 
package extract diff --git a/compiler/check/synth/phase/extract/expr.go b/compiler/check/synth/phase/extract/expr.go index c6d54a3a..0b128c35 100644 --- a/compiler/check/synth/phase/extract/expr.go +++ b/compiler/check/synth/phase/extract/expr.go @@ -532,10 +532,10 @@ func (s *Synthesizer) expandValuesCore(exprs []ast.Expr, needed int, single func if len(exprs) == 0 { return nil } - result := make([]typ.Type, 0, needed) + result := make([]typ.Type, 0, exprListResultCapacity(exprs, needed)) for i, expr := range exprs { - if i == len(exprs)-1 { + if i == len(exprs)-1 && ast.CanProduceMultipleValues(expr) { result = append(result, multi(expr)...) } else { result = append(result, single(expr)) @@ -549,6 +549,13 @@ func (s *Synthesizer) expandValuesCore(exprs []ast.Expr, needed int, single func return result } +func exprListResultCapacity(exprs []ast.Expr, needed int) int { + if needed > len(exprs) { + return needed + } + return len(exprs) +} + // expandValues expands expression list to types. func (s *Synthesizer) expandValues(exprs []ast.Expr, needed int, p cfg.Point, narrower api.FlowOps) []typ.Type { return s.expandValuesCore(exprs, needed, diff --git a/compiler/check/synth/phase/extract/function.go b/compiler/check/synth/phase/extract/function.go index ed3f9293..d54788b8 100644 --- a/compiler/check/synth/phase/extract/function.go +++ b/compiler/check/synth/phase/extract/function.go @@ -12,7 +12,7 @@ // CONTEXTUAL TYPING (EXPECTED TYPES) // // When an expected function type is available (e.g., from callback parameter context), -// it provides default types for unannotated parameters and fallback return types. +// it provides default types for unannotated parameters and return types. // This enables idioms like: // // items:filter(function(x) return x > 0 end) -- x inferred from filter's param @@ -21,7 +21,7 @@ // // Return types are inferred by analyzing all return statements in the function body. // The algorithm: -// 1. 
Check ReturnSummaries for pre-computed results (from prior iterations) +// 1. Check FunctionFacts for pre-computed results (from prior iterations) // 2. Build CFG and create type overlay with parameter types // 3. Create a temporary synthesizer environment // 4. Visit each return statement, synthesizing expression types @@ -63,7 +63,7 @@ func (s *Synthesizer) FunctionType(fn *ast.FunctionExpr, sc *scope.State) *typ.F // // When an expected function type is provided, it guides inference for: // - Unannotated parameter types (uses expected parameter types) -// - Unannotated return types (uses expected return types as fallback) +// - Unannotated return types (uses expected return types) // - Self parameter in methods (infers from expected first param) // // Processing order: @@ -98,6 +98,17 @@ func (s *Synthesizer) getOrBuildFunctionGraph(fn *ast.FunctionExpr) *cfg.Graph { return cfg.Build(fn) } +func (s *Synthesizer) currentFunctionFacts() api.FunctionFacts { + if s == nil || s.deps.CheckCtx == nil { + return nil + } + ctx, ok := s.deps.CheckCtx.(interface{ FunctionFacts() api.FunctionFacts }) + if !ok { + return nil + } + return ctx.FunctionFacts() +} + func (s *Synthesizer) synthFunctionTypeWithCapturePoint( fn *ast.FunctionExpr, sc *scope.State, @@ -108,12 +119,18 @@ func (s *Synthesizer) synthFunctionTypeWithCapturePoint( if fn == nil { return nil } + cacheKey, cacheable := s.functionTypeCacheKey(fn, sc, expected, capturePoint, captureTypes) + if cacheable && s.deps.FunctionTypeCache != nil { + if cached, ok := s.deps.FunctionTypeCache[cacheKey]; ok { + return cached + } + } if s.deps.FunctionTypeInProgress == nil { s.deps.FunctionTypeInProgress = make(map[functionTypeProgressKey]bool) } progressKey := functionTypeProgressKey{Func: fn, CapturePoint: capturePoint} if s.deps.FunctionTypeInProgress[progressKey] { - return s.buildFunctionTypeSummaryFallback(fn, sc, expected) + return s.buildFunctionTypeFromAvailableFacts(fn, sc, expected) } 
s.deps.FunctionTypeInProgress[progressKey] = true defer delete(s.deps.FunctionTypeInProgress, progressKey) @@ -188,11 +205,36 @@ func (s *Synthesizer) synthFunctionTypeWithCapturePoint( fnType := builder.Build() if inferredErrorReturn { - fnType = erreffect.AttachErrorReturnSpec(fnType, 0, 1) + fnType = erreffect.CanonicalLuaValueErrorConvention().Attach(fnType) + } + if cacheable { + if s.deps.FunctionTypeCache == nil { + s.deps.FunctionTypeCache = make(map[functionTypeCacheKey]*typ.Function) + } + s.deps.FunctionTypeCache[cacheKey] = fnType } return fnType } +func (s *Synthesizer) functionTypeCacheKey( + fn *ast.FunctionExpr, + sc *scope.State, + expected *typ.Function, + capturePoint cfg.Point, + captureTypes map[cfg.SymbolID]typ.Type, +) (functionTypeCacheKey, bool) { + if s == nil || s.deps == nil || fn == nil || len(captureTypes) != 0 { + return functionTypeCacheKey{}, false + } + return functionTypeCacheKey{ + Func: fn, + Scope: sc, + Expected: expected, + CapturePoint: capturePoint, + Phase: s.phase, + }, true +} + // inferReturnTypesFromBody infers return types from the function body. // If fnGraph is non-nil, it reuses the pre-built CFG instead of building a new one. func (s *Synthesizer) inferReturnTypesFromBody( @@ -207,18 +249,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( return nil, false } - var returnSummaries map[cfg.SymbolID][]typ.Type - if s.deps.CheckCtx != nil { - if s.IsNarrowing() { - if ctx, ok := s.deps.CheckCtx.(api.NarrowEnv); ok { - returnSummaries = ctx.NarrowReturnSummaries() - } - } else { - if ctx, ok := s.deps.CheckCtx.(api.DeclaredEnv); ok { - returnSummaries = ctx.ReturnSummaries() - } - } - } + functionFacts := s.currentFunctionFacts() var fnSym cfg.SymbolID if s.deps.CheckCtx != nil { @@ -227,14 +258,20 @@ func (s *Synthesizer) inferReturnTypesFromBody( } } - // If a return summary exists for this function symbol, declared phase can - // use it directly. 
Narrowing phase must still infer from the body so flow - // predicates can remove stale union members from pre-flow summaries. - var summaryFallback []typ.Type - if len(returnSummaries) > 0 && fnSym != 0 { - if rt := returnSummaries[fnSym]; len(rt) > 0 { + // If canonical facts already know this function's returns, declared phase + // can use them directly. Narrowing phase still analyzes the body so flow + // predicates can refine the pre-flow fact. + var canonicalReturns []typ.Type + if len(functionFacts) > 0 && fnSym != 0 { + rt := functionFacts.Summary(fnSym) + if s.IsNarrowing() { + if narrow := functionFacts.NarrowSummary(fnSym); len(narrow) > 0 { + rt = narrow + } + } + if len(rt) > 0 { if typ.HasKnownType(rt) { - summaryFallback = rt + canonicalReturns = rt if !s.IsNarrowing() && capturePoint == 0 && len(captureTypes) == 0 { return rt, false } @@ -264,26 +301,15 @@ func (s *Synthesizer) inferReturnTypesFromBody( overlay := s.buildParamOverlay(fnGraph, resolveScope, expected) - // Collect local function types from assignments using return summaries. - // Uses annotations for params and looks up return types from summaries. - // returnSummaries resolved above (pre-flow or post-flow depending on phase). + // Collect local function types from assignments using canonical function facts. + // Uses annotations for params and looks up return types from the product fact. 
- fnGraph.EachAssign(func(_ cfg.Point, info *cfg.AssignInfo) { - if info == nil || !info.IsLocal || len(info.Targets) == 0 { - return + for _, localFn := range fnGraph.LocalFunctionAssignments() { + fnType := s.buildLocalFunctionTypeFromFacts(localFn.Func, resolveScope, localFn.Symbol, functionFacts) + if fnType != nil { + overlay[localFn.Symbol] = fnType } - info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { - if target.Kind != cfg.TargetIdent || target.Symbol == 0 { - return - } - if fnExpr, ok := source.(*ast.FunctionExpr); ok { - fnType := s.buildFunctionTypeWithSummary(fnExpr, resolveScope, target.Symbol, returnSummaries) - if fnType != nil { - overlay[target.Symbol] = fnType - } - } - }) - }) + } // Include captured symbol types from the parent context. // This allows nested local functions to call sibling locals defined in the parent scope. @@ -322,7 +348,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( } // Include local function types from the parent graph that are visible at this function's definition point. - // Uses return summaries for return types instead of recursive inference. + // Uses canonical function facts for return types instead of recursive inference. 
if s.deps.CheckCtx != nil { if pg, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && pg != nil { var defPoint cfg.Point @@ -335,37 +361,21 @@ func (s *Synthesizer) inferReturnTypesFromBody( if defPoint != 0 { visible := pg.AllSymbolsAt(defPoint) if len(visible) > 0 { - visibleSyms := make(map[cfg.SymbolID]struct{}, len(visible)) - for _, sym := range visible { - if sym != 0 { - visibleSyms[sym] = struct{}{} + for _, localFn := range pg.LocalFunctionAssignments() { + if localFn.Func == fn || localFn.Name == "" { + continue } - } - pg.EachAssign(func(_ cfg.Point, info *cfg.AssignInfo) { - if info == nil || !info.IsLocal || len(info.Targets) == 0 { - return + if visibleSym, ok := visible[localFn.Name]; !ok || visibleSym != localFn.Symbol { + continue } - info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { - if target.Kind != cfg.TargetIdent || target.Symbol == 0 { - return - } - if _, ok := visibleSyms[target.Symbol]; !ok { - return - } - if _, ok := overlay[target.Symbol]; ok { - return - } - if fnExpr, ok := source.(*ast.FunctionExpr); ok { - if fnExpr == fn { - return - } - fnType := s.buildFunctionTypeWithSummary(fnExpr, parentScope, target.Symbol, returnSummaries) - if fnType != nil { - overlay[target.Symbol] = fnType - } - } - }) - }) + if _, ok := overlay[localFn.Symbol]; ok { + continue + } + fnType := s.buildLocalFunctionTypeFromFacts(localFn.Func, parentScope, localFn.Symbol, functionFacts) + if fnType != nil { + overlay[localFn.Symbol] = fnType + } + } } } } @@ -400,6 +410,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( DeclaredTypes: overlay, GlobalTypes: globalTypes, ModuleAliases: moduleAliases, + FunctionFacts: functionFacts, }) prelimDeps := &Deps{ @@ -409,8 +420,6 @@ func (s *Synthesizer) inferReturnTypesFromBody( Manifests: s.deps.Manifests, CheckCtx: prelimCtx, Graphs: s.deps.Graphs, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), FunctionTypeInProgress: s.deps.FunctionTypeInProgress, ModuleBindings: 
s.deps.ModuleBindings, ModuleAliases: moduleAliases, @@ -422,6 +431,17 @@ func (s *Synthesizer) inferReturnTypesFromBody( // Single-pass local inference from assignments (best-effort). var localInferred map[cfg.SymbolID]typ.Type + ensureLocalInferred := func() map[cfg.SymbolID]typ.Type { + if localInferred != nil { + return localInferred + } + capHint := overlaySymbolCapacity(fnGraph, 1) - len(overlay) + if capHint < 1 { + capHint = 1 + } + localInferred = make(map[cfg.SymbolID]typ.Type, capHint) + return localInferred + } fnGraph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { if info == nil || !info.IsLocal || len(info.Targets) == 0 { return @@ -469,10 +489,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( t = ensurePrelimSynth().SynthExpr(src, p, nil) } if t != nil { - if localInferred == nil { - localInferred = make(map[cfg.SymbolID]typ.Type) - } - localInferred[target.Symbol] = t + ensureLocalInferred()[target.Symbol] = t } return } @@ -488,10 +505,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( return } if i < len(values) && values[i] != nil { - if localInferred == nil { - localInferred = make(map[cfg.SymbolID]typ.Type) - } - localInferred[target.Symbol] = values[i] + ensureLocalInferred()[target.Symbol] = values[i] } }) }) @@ -540,6 +554,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( DeclaredTypes: overlay, GlobalTypes: globalTypes, ModuleAliases: moduleAliases, + FunctionFacts: functionFacts, }) tempDeps := &Deps{ @@ -549,8 +564,6 @@ func (s *Synthesizer) inferReturnTypesFromBody( Manifests: s.deps.Manifests, CheckCtx: fnCheckCtx, Graphs: s.deps.Graphs, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), FunctionTypeInProgress: s.deps.FunctionTypeInProgress, ModuleBindings: s.deps.ModuleBindings, ModuleAliases: moduleAliases, @@ -606,11 +619,15 @@ func (s *Synthesizer) inferReturnTypesFromBody( } } - if typ.IsUnknownOnlyOrEmpty(returnTypes) && len(summaryFallback) > 0 { - return summaryFallback, false + if 
typ.IsUnknownOnlyOrEmpty(returnTypes) && len(canonicalReturns) > 0 { + return canonicalReturns, false } - return returnTypes, erreffect.HasStrictInverseReturnPattern(fnGraph, nil, tempSynth, 0, 1) + convention := erreffect.CanonicalLuaValueErrorConvention() + if !convention.CanClassifyReturns(returnTypes) { + return returnTypes, false + } + return returnTypes, convention.HasStrictInversePattern(fnGraph, nil, tempSynth) } func enrichOverlayWithOrderedComparisonHints(fnGraph *cfg.Graph, overlay map[cfg.SymbolID]typ.Type) { @@ -688,19 +705,12 @@ func localFunctionSymbol(graph *cfg.Graph, fn *ast.FunctionExpr) cfg.SymbolID { } } var fnSym cfg.SymbolID - graph.EachAssign(func(_ cfg.Point, info *cfg.AssignInfo) { - if fnSym != 0 || info == nil || !info.IsLocal || len(info.Targets) == 0 { - return + for _, localFn := range graph.LocalFunctionAssignments() { + if localFn.Func == fn { + fnSym = localFn.Symbol + break } - info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { - if target.Kind != cfg.TargetIdent || target.Symbol == 0 { - return - } - if source == fn { - fnSym = target.Symbol - } - }) - }) + } if fnSym != 0 { return fnSym } @@ -716,7 +726,7 @@ func localFunctionSymbol(graph *cfg.Graph, fn *ast.FunctionExpr) cfg.SymbolID { } // inferReturnExprTypes synthesizes types from return expressions using CFG point. -// The last expression is expanded via MultiTypeOf to support multi-return calls. +// Lua expands only the final multivalue expression in a return list. func (s *Synthesizer) inferReturnExprTypes(exprs []ast.Expr, p cfg.Point) []typ.Type { if len(exprs) == 0 { return nil @@ -725,9 +735,9 @@ func (s *Synthesizer) inferReturnExprTypes(exprs []ast.Expr, p cfg.Point) []typ. 
if s.IsNarrowing() && s.deps.Flow != nil { narrower = s.deps.Flow } - var result []typ.Type + result := make([]typ.Type, 0, len(exprs)) for i, expr := range exprs { - if i == len(exprs)-1 { + if i == len(exprs)-1 && ast.CanProduceMultipleValues(expr) { multi := s.multiTypeOf(expr, p, narrower) if len(multi) == 0 { multi = []typ.Type{typ.Unknown} @@ -750,13 +760,13 @@ func (s *Synthesizer) inferReturnExprTypes(exprs []ast.Expr, p cfg.Point) []typ. return result } -// buildFunctionTypeWithSummary builds a function type using annotations for parameters -// and ReturnSummaries for return types. Does not recursively infer return types. -func (s *Synthesizer) buildFunctionTypeWithSummary( +// buildLocalFunctionTypeFromFacts builds a local function type from annotations +// and canonical function facts. It does not recursively infer returns. +func (s *Synthesizer) buildLocalFunctionTypeFromFacts( fn *ast.FunctionExpr, sc *scope.State, sym cfg.SymbolID, - returnSummaries map[cfg.SymbolID][]typ.Type, + functionFacts api.FunctionFacts, ) *typ.Function { if fn == nil { return nil @@ -773,16 +783,15 @@ func (s *Synthesizer) buildFunctionTypeWithSummary( return sig } - // Look up return types from summaries var returnTypes []typ.Type - if returnSummaries != nil && sym != 0 { - returnTypes = returnSummaries[sym] + if functionFacts != nil && sym != 0 { + returnTypes = functionFacts.Summary(sym) } return join.WithReturnsOrUnknown(sig, returnTypes) } -func (s *Synthesizer) buildFunctionTypeSummaryFallback( +func (s *Synthesizer) buildFunctionTypeFromAvailableFacts( fn *ast.FunctionExpr, sc *scope.State, expected *typ.Function, @@ -797,16 +806,7 @@ func (s *Synthesizer) buildFunctionTypeSummaryFallback( if expected != nil && len(sig.Returns) == 0 && len(expected.Returns) > 0 { sig = join.WithReturns(sig, expected.Returns) } - var summaries map[cfg.SymbolID][]typ.Type - if s.deps.CheckCtx != nil { - if s.IsNarrowing() { - if ctx, ok := s.deps.CheckCtx.(api.NarrowEnv); ok { - 
summaries = ctx.NarrowReturnSummaries() - } - } else if ctx, ok := s.deps.CheckCtx.(api.DeclaredEnv); ok { - summaries = ctx.ReturnSummaries() - } - } + functionFacts := s.currentFunctionFacts() var fnSym cfg.SymbolID if s.deps.CheckCtx != nil { if pg, ok := s.deps.CheckCtx.Graph().(*cfg.Graph); ok && pg != nil { @@ -814,14 +814,20 @@ func (s *Synthesizer) buildFunctionTypeSummaryFallback( } } if fnSym != 0 { - return join.WithReturnsOrUnknown(sig, summaries[fnSym]) + rets := functionFacts.Summary(fnSym) + if s.IsNarrowing() { + if narrow := functionFacts.NarrowSummary(fnSym); len(narrow) > 0 { + rets = narrow + } + } + return join.WithReturnsOrUnknown(sig, rets) } return join.WithReturnsOrUnknown(sig, nil) } func (s *Synthesizer) buildParamOverlay(fnGraph *cfg.Graph, sc *scope.State, expected *typ.Function) map[cfg.SymbolID]typ.Type { paramSlots := fnGraph.ParamSlotsReadOnly() - overlay := make(map[cfg.SymbolID]typ.Type, len(paramSlots)) + overlay := make(map[cfg.SymbolID]typ.Type, overlaySymbolCapacity(fnGraph, len(paramSlots))) for _, slot := range paramSlots { if slot.Symbol == 0 { continue @@ -851,6 +857,16 @@ func (s *Synthesizer) buildParamOverlay(fnGraph *cfg.Graph, sc *scope.State, exp return overlay } +func overlaySymbolCapacity(fnGraph *cfg.Graph, floor int) int { + if fnGraph == nil { + return floor + } + if count := fnGraph.SymbolCount(); count > floor { + return count + } + return floor +} + // inferCallbackOverlaySpec detects the "setup -> param call -> cleanup" pattern // and builds a contract.Spec with EnvOverlay for each callback parameter. 
func (s *Synthesizer) inferCallbackOverlaySpec( @@ -869,6 +885,7 @@ func (s *Synthesizer) inferCallbackOverlaySpec( synthExpr := func(expr ast.Expr, p cfg.Point) typ.Type { if tempSynth == nil { overlay := s.buildParamOverlay(fnGraph, sc, expected) + functionFacts := s.currentFunctionFacts() var globalTypes map[string]typ.Type var moduleAliases map[cfg.SymbolID]string @@ -887,6 +904,7 @@ func (s *Synthesizer) inferCallbackOverlaySpec( DeclaredTypes: overlay, GlobalTypes: globalTypes, ModuleAliases: moduleAliases, + FunctionFacts: functionFacts, }) tempDeps := &Deps{ Ctx: s.deps.Ctx, @@ -895,8 +913,6 @@ func (s *Synthesizer) inferCallbackOverlaySpec( Manifests: s.deps.Manifests, CheckCtx: fnCheckCtx, Graphs: s.deps.Graphs, - PreCache: make(api.Cache), - NarrowCache: make(api.Cache), ModuleBindings: s.deps.ModuleBindings, ModuleAliases: moduleAliases, } diff --git a/compiler/check/synth/phase/extract/manifest_enrich_test.go b/compiler/check/synth/phase/extract/manifest_enrich_test.go index 938f5efe..afaa77ce 100644 --- a/compiler/check/synth/phase/extract/manifest_enrich_test.go +++ b/compiler/check/synth/phase/extract/manifest_enrich_test.go @@ -31,7 +31,7 @@ func TestEnrichWithManifest_DirectLookup(t *testing.T) { } } -func TestEnrichWithManifest_ImportsFallback(t *testing.T) { +func TestEnrichWithManifest_ImportsLookup(t *testing.T) { manifest := io.NewManifest("m") manifest.SetExport(typ.NewRecord().Field("name", typ.String).Build()) got := enrichWithManifest(manifestQuerierStub{ @@ -39,7 +39,7 @@ func TestEnrichWithManifest_ImportsFallback(t *testing.T) { imports: map[string]*io.Manifest{"m": manifest}, }, typ.Number, "m", "name") if !typ.TypeEquals(got, typ.String) { - t.Fatalf("enrichWithManifest imports fallback = %v, want string", got) + t.Fatalf("enrichWithManifest imports lookup = %v, want string", got) } } diff --git a/compiler/check/synth/phase/extract/named_function.go b/compiler/check/synth/phase/extract/named_function.go index 2721b2f9..da80dc3a 
100644 --- a/compiler/check/synth/phase/extract/named_function.go +++ b/compiler/check/synth/phase/extract/named_function.go @@ -137,7 +137,7 @@ func (s *Synthesizer) hasDominatingDirectFunctionRebind(sym compcfg.SymbolID, st return false } - idom, _ := cfganalysis.ComputeDominators(graph.CFG()) + idom := cfganalysis.ComputeImmediateDominators(graph.CFG()) rebound := false graph.EachAssign(func(assignPoint cfg.Point, info *compcfg.AssignInfo) { @@ -211,29 +211,45 @@ func (s *Synthesizer) stableGraphLocalFunctionSnapshotType(sym compcfg.SymbolID) return nil } - fallbackParent := s.deps.DefaultScope - if fallbackParent == nil { - fallbackParent = s.deps.CheckCtx.TypeNames() + defaultParent := s.deps.DefaultScope + if defaultParent == nil { + defaultParent = s.deps.CheckCtx.TypeNames() } - parent := api.ParentScopeForGraph(store, graph.ID(), fallbackParent) + parent := api.ParentScopeForGraph(store, graph.ID(), defaultParent) if parent == nil { return nil } - var fnTypes map[cfg.SymbolID]typ.Type + cacheKey := stableFunctionSnapshotKey{GraphID: graph.ID(), Parent: parent, Sym: sym} + if s.deps.StableFunctionSnapshot != nil { + if cached, ok := s.deps.StableFunctionSnapshot[cacheKey]; ok { + return cached + } + } + + var facts api.FunctionFacts load := func() { - fnTypes = store.GetLocalFuncTypesSnapshot(graph, parent) + facts = store.GetFunctionFactsSnapshot(graph, parent) } if phaser, ok := store.(interface{ WithPhase(api.Phase, func()) }); ok { phaser.WithPhase(api.PhaseScopeCompute, load) } else { load() } - if len(fnTypes) == 0 { + if len(facts) == 0 { + if s.deps.StableFunctionSnapshot == nil { + s.deps.StableFunctionSnapshot = make(map[stableFunctionSnapshotKey]typ.Type) + } + s.deps.StableFunctionSnapshot[cacheKey] = nil return nil } - return fnTypes[sym] + snapshotType := facts.FunctionType(sym) + if s.deps.StableFunctionSnapshot == nil { + s.deps.StableFunctionSnapshot = make(map[stableFunctionSnapshotKey]typ.Type) + } + 
s.deps.StableFunctionSnapshot[cacheKey] = snapshotType + return snapshotType } func (s *Synthesizer) stableLocalFunctionValueType( @@ -256,9 +272,23 @@ func (s *Synthesizer) stableLocalFunctionValueType( } } } - if snapshot := s.stableGraphLocalFunctionSnapshotType(sym); snapshot != nil { - if authoritative == nil || subtype.IsSubtype(snapshot, authoritative) { - authoritative = snapshot + hasContextFact := false + if s.deps != nil && s.deps.CheckCtx != nil { + if ctx, ok := s.deps.CheckCtx.(interface{ FunctionFacts() api.FunctionFacts }); ok { + facts := ctx.FunctionFacts() + if factType := facts.FunctionType(sym); factType != nil { + hasContextFact = true + if authoritative == nil || subtype.IsSubtype(factType, authoritative) { + authoritative = factType + } + } + } + } + if !hasContextFact { + if snapshot := s.stableGraphLocalFunctionSnapshotType(sym); snapshot != nil { + if authoritative == nil || subtype.IsSubtype(snapshot, authoritative) { + authoritative = snapshot + } } } if !hasCaptures && authoritative != nil { diff --git a/compiler/check/synth/phase/extract/named_function_test.go b/compiler/check/synth/phase/extract/named_function_test.go index 2a25e35d..3074835d 100644 --- a/compiler/check/synth/phase/extract/named_function_test.go +++ b/compiler/check/synth/phase/extract/named_function_test.go @@ -24,7 +24,7 @@ func newNamedFunctionSynth(localBindings, moduleBindings *bind.BindingTable) *Sy }, api.PhaseTypeResolution) } -func TestFunctionLiteralForIdent_UsesModuleFallbackSymbolWhenPrimaryHasNoLiteral(t *testing.T) { +func TestFunctionLiteralForIdent_UsesModuleSymbolWhenPrimaryHasNoLiteral(t *testing.T) { ident := &ast.IdentExpr{Value: "f"} localBindings := bind.NewBindingTable() diff --git a/compiler/check/synth/phase/extract/synthesizer.go b/compiler/check/synth/phase/extract/synthesizer.go index f745063d..d567cf8c 100644 --- a/compiler/check/synth/phase/extract/synthesizer.go +++ b/compiler/check/synth/phase/extract/synthesizer.go @@ -44,6 +44,27 @@ 
import ( // Used as callback to allow synthExprCore to recursively synthesize sub-expressions. type ExprSynth func(expr ast.Expr) typ.Type +type exprRecurser struct { + s *Synthesizer + p cfg.Point + sc *scope.State + narrower api.FlowOps + recurse ExprSynth +} + +func newExprRecurser(s *Synthesizer, p cfg.Point, sc *scope.State, narrower api.FlowOps) *exprRecurser { + r := &exprRecurser{s: s, p: p, sc: sc, narrower: narrower} + r.recurse = r.synth + return r +} + +func (r *exprRecurser) synth(expr ast.Expr) typ.Type { + if expr == nil { + return typ.Nil + } + return r.s.synthExprCore(expr, r.sc, r.p, r.narrower, r.recurse) +} + // Synthesizer is the core type synthesis engine for expressions. // // It implements a recursive descent over the AST, computing types for each @@ -125,8 +146,8 @@ func (s *Synthesizer) TypeOfWithExpected(expr ast.Expr, p cfg.Point, expected ty return s.TypeOf(expr, p) } sc := s.deps.ScopeAt(p) - recurse := func(ex ast.Expr) typ.Type { return s.SynthExpr(ex, p, nil) } - return s.SynthExprWithExpectedCore(expr, sc, p, recurse, expected) + recurser := newExprRecurser(s, p, sc, nil) + return s.SynthExprWithExpectedCore(expr, sc, p, recurser.recurse, expected) } // MultiTypeOf synthesizes multiple types for multi-value expressions (no narrowing). 
@@ -141,10 +162,13 @@ func (s *Synthesizer) SynthMulti(expr ast.Expr, p cfg.Point, narrower api.FlowOp func (s *Synthesizer) multiTypeOf(expr ast.Expr, p cfg.Point, narrower api.FlowOps) []typ.Type { sc := s.deps.ScopeAt(p) - recurse := func(ex ast.Expr) typ.Type { return s.SynthExpr(ex, p, narrower) } - return s.synthMultiCore(expr, sc, recurse, + if t, ok := s.synthNonRecursiveExpr(expr, sc, p, narrower); ok { + return []typ.Type{t} + } + recurser := newExprRecurser(s, p, sc, narrower) + return s.synthMultiCore(expr, sc, recurser.recurse, func(call *ast.FuncCallExpr) []typ.Type { - return s.SynthCallCore(call, p, sc, narrower, recurse) + return s.SynthCallCore(call, p, sc, narrower, recurser.recurse) }, ) } @@ -180,8 +204,11 @@ func (s *Synthesizer) SynthExprAt(expr ast.Expr, p cfg.Point, sc *scope.State) t if expr == nil { return typ.Nil } - recurse := func(ex ast.Expr) typ.Type { return s.SynthExprAt(ex, p, sc) } - return s.synthExprCore(expr, sc, p, nil, recurse) + if t, ok := s.synthNonRecursiveExpr(expr, sc, p, nil); ok { + return t + } + recurser := newExprRecurser(s, p, sc, nil) + return recurser.synth(expr) } // Resolver returns a type resolver. @@ -223,27 +250,54 @@ func (s *Synthesizer) SynthExpr(expr ast.Expr, p cfg.Point, narrower api.FlowOps return typ.Nil } sc := s.deps.ScopeAt(p) - recurse := func(ex ast.Expr) typ.Type { return s.SynthExpr(ex, p, narrower) } - return s.synthExprCore(expr, sc, p, narrower, recurse) + if t, ok := s.synthNonRecursiveExpr(expr, sc, p, narrower); ok { + return t + } + recurser := newExprRecurser(s, p, sc, narrower) + return recurser.synth(expr) } -// synthExprCore is the shared expression synthesizer implementation. -func (s *Synthesizer) synthExprCore(expr ast.Expr, sc *scope.State, p cfg.Point, narrower api.FlowOps, recurse ExprSynth) typ.Type { +// synthNonRecursiveExpr handles expression forms whose type does not depend on +// recursively synthesizing child expressions. 
+func (s *Synthesizer) synthNonRecursiveExpr(expr ast.Expr, sc *scope.State, p cfg.Point, narrower api.FlowOps) (typ.Type, bool) { switch ex := expr.(type) { case *ast.NilExpr: - return typ.Nil + return typ.Nil, true case *ast.TrueExpr: - return typ.True + return typ.True, true case *ast.FalseExpr: - return typ.False + return typ.False, true case *ast.NumberExpr: - return ops.ParseNumber(ex.Value) + return ops.ParseNumber(ex.Value), true case *ast.StringExpr: - return typ.LiteralString(ex.Value) + return typ.LiteralString(ex.Value), true case *ast.Comma3Expr: - return s.synthComma3(sc) + return s.synthComma3(sc), true case *ast.IdentExpr: - return s.synthIdentCore(ex, p, sc, narrower) + return s.synthIdentCore(ex, p, sc, narrower), true + case *ast.FunctionExpr: + return s.FunctionType(ex, sc), true + case *ast.RelationalOpExpr: + return typ.Boolean, true + case *ast.StringConcatOpExpr: + return typ.String, true + case *ast.UnaryNotOpExpr: + return typ.Boolean, true + case *ast.UnaryBNotOpExpr: + return typ.Integer, true + case *ast.CastExpr: + return s.ResolveType(ex.Type, sc), true + default: + return nil, false + } +} + +// synthExprCore is the shared expression synthesizer implementation. 
+func (s *Synthesizer) synthExprCore(expr ast.Expr, sc *scope.State, p cfg.Point, narrower api.FlowOps, recurse ExprSynth) typ.Type { + if t, ok := s.synthNonRecursiveExpr(expr, sc, p, narrower); ok { + return t + } + switch ex := expr.(type) { case *ast.AttrGetExpr: return s.synthAttrGetCore(ex, p, sc, narrower, recurse) case *ast.TableExpr: @@ -254,30 +308,18 @@ func (s *Synthesizer) synthExprCore(expr ast.Expr, sc *scope.State, p cfg.Point, return types[0] } return typ.Nil - case *ast.FunctionExpr: - return s.FunctionType(ex, sc) case *ast.LogicalOpExpr: if s.IsNarrowing() && narrower != nil { return s.synthLogicalOpWithNarrowing(ex, p, sc, narrower, recurse) } return s.synthLogicalOpCore(ex, recurse) - case *ast.RelationalOpExpr: - return typ.Boolean - case *ast.StringConcatOpExpr: - return typ.String case *ast.ArithmeticOpExpr: return s.synthArithmeticOpCore(ex, recurse) case *ast.UnaryMinusOpExpr: return s.synthUnaryMinusCore(ex, recurse) - case *ast.UnaryNotOpExpr: - return typ.Boolean case *ast.UnaryLenOpExpr: operand := recurse(ex.Expr) return s.deps.Types.UnaryOp(s.deps.Ctx, "#", operand) - case *ast.UnaryBNotOpExpr: - return typ.Integer - case *ast.CastExpr: - return s.ResolveType(ex.Type, sc) case *ast.NonNilAssertExpr: inner := recurse(ex.Expr) return narrow.RemoveNil(inner) @@ -349,7 +391,7 @@ func (s *Synthesizer) synthIdentCore(ex *ast.IdentExpr, p cfg.Point, sc *scope.S // For "self" identifier, check scope's self type first. // This ensures methods assigned via field assignment (obj.method = function(self)...) - // get the correct self type before falling back to parameter type lookup. + // get the correct self type before parameter lookup. 
if ex.Value == "self" && sc != nil { if selfType := sc.SelfType(); selfType != nil { return selfType @@ -377,7 +419,7 @@ func (s *Synthesizer) synthIdentCore(ex *ast.IdentExpr, p cfg.Point, sc *scope.S requireSubtype = unwrap.Function(declared.Type) != nil } if requireSubtype && !subtype.IsSubtype(narrowed, declared.Type) { - goto fallback + goto declaredLookup } } } @@ -385,7 +427,7 @@ func (s *Synthesizer) synthIdentCore(ex *ast.IdentExpr, p cfg.Point, sc *scope.S } } -fallback: +declaredLookup: if types := ctx.Types(); types != nil { tv := types.EffectiveTypeAt(p, sym) if tv.State == flow.StateResolved && tv.Type != nil { @@ -423,7 +465,7 @@ fallback: } } - // Module alias lookup (require("mod")) as fallback when no concrete type is resolved. + // Module alias lookup (require("mod")) when no concrete type is resolved. moduleAliasSym := sym if moduleAliasSym == 0 { moduleAliasSym = moduleSym diff --git a/compiler/check/synth/phase/extract/synthesizer_test.go b/compiler/check/synth/phase/extract/synthesizer_test.go index f669fd01..c772b2b8 100644 --- a/compiler/check/synth/phase/extract/synthesizer_test.go +++ b/compiler/check/synth/phase/extract/synthesizer_test.go @@ -220,11 +220,11 @@ func TestSynthesizer_TypeOf_Ident(t *testing.T) { } } -func TestSynthesizer_TypeOf_IdentFallsBackToGraphSymbolAt(t *testing.T) { +func TestSynthesizer_TypeOf_IdentUsesGraphSymbolAt(t *testing.T) { s, _ := newTestSynthesizerWithSymbol("x", typ.Integer) result := s.TypeOf(&ast.IdentExpr{Value: "x"}, 0) if result != typ.Integer { - t.Fatalf("got %v, want integer via SymbolAt fallback", result) + t.Fatalf("got %v, want integer via SymbolAt", result) } } diff --git a/compiler/check/tests/errors/error_correlation_test.go b/compiler/check/tests/errors/error_correlation_test.go index cfac14ca..adc39a1b 100644 --- a/compiler/check/tests/errors/error_correlation_test.go +++ b/compiler/check/tests/errors/error_correlation_test.go @@ -139,9 +139,9 @@ db:release() if getDbSym != 0 && 
result.Session.Store != nil { parentHash := result.Session.Store.GraphParentHashOf(root.Graph.ID()) parent := result.Session.Store.Parents()[parentHash] - if summaries := result.Session.Store.GetReturnSummariesSnapshot(root.Graph, parent); summaries != nil { - if returns, ok := summaries[getDbSym]; ok { - t.Logf("ReturnSummaries[%d][get_db]=%v", parentHash, returns) + if facts := result.Session.Store.GetFunctionFactsSnapshot(root.Graph, parent); facts != nil { + if fact, ok := facts[getDbSym]; ok { + t.Logf("FunctionFacts[%d][get_db].Summary=%v", parentHash, fact.Summary) } } } @@ -287,14 +287,14 @@ db:release() root := result.Session.RootResult.Graph parentHash := result.Session.Store.GraphParentHashOf(root.ID()) parent := result.Session.Store.Parents()[parentHash] - summaries := result.Session.Store.GetReturnSummariesSnapshot(root, parent) + functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) for _, name := range []string{"connect", "get_connection"} { sym, ok := root.SymbolAt(root.Exit(), name) if !ok || sym == 0 { t.Fatalf("missing symbol for %s", name) } - rets := returns.NormalizeReturnVector(summaries[sym]) + rets := returns.NormalizeReturnVector(functionFacts.Summary(sym)) if len(rets) == 0 { t.Fatalf("missing return summary for %s", name) } diff --git a/compiler/check/tests/flow/fixpoint_unification_test.go b/compiler/check/tests/flow/fixpoint_unification_test.go index 5b0cb346..c74e33b7 100644 --- a/compiler/check/tests/flow/fixpoint_unification_test.go +++ b/compiler/check/tests/flow/fixpoint_unification_test.go @@ -140,8 +140,8 @@ local result: number = a() parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent) - if len(summaries) == 0 { + functionFacts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent) + if len(functionFacts) == 0 { t.Error("expected non-empty return 
summaries for the call chain") } @@ -151,14 +151,14 @@ local result: number = a() t.Fatal("missing root graph") } - for sym, rt := range summaries { + for sym, fact := range functionFacts { name := graph.NameOf(sym) if name == "a" || name == "b" || name == "c" { - if len(rt) == 0 { + if len(fact.Summary) == 0 { t.Errorf("empty return summary for %q", name) continue } - if typ.TypeEquals(rt[0], typ.Unknown) { + if typ.TypeEquals(fact.Summary[0], typ.Unknown) { t.Errorf("return type for %q is unknown, expected number", name) } } @@ -446,16 +446,16 @@ local result: number = d() checkedFunctions := make(map[string]bool) parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent) - for sym, rt := range summaries { + functionFacts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent) + for sym, fact := range functionFacts { name := graph.NameOf(sym) if name == "b" || name == "c" || name == "d" { checkedFunctions[name] = true - if len(rt) == 0 { + if len(fact.Summary) == 0 { t.Errorf("empty return summary for %q", name) continue } - if typ.TypeEquals(rt[0], typ.Unknown) { + if typ.TypeEquals(fact.Summary[0], typ.Unknown) { t.Errorf("return type for %q is unknown, expected number (hints didn't propagate)", name) } } diff --git a/compiler/check/tests/inference/closure_return_infer_test.go b/compiler/check/tests/inference/closure_return_infer_test.go index 3a3149f1..96ca146b 100644 --- a/compiler/check/tests/inference/closure_return_infer_test.go +++ b/compiler/check/tests/inference/closure_return_infer_test.go @@ -801,10 +801,10 @@ end var summary []typ.Type parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - if summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent); summaries != nil { - for sym, rt := range summaries { + if facts := 
sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent); facts != nil { + for sym, fact := range facts { if sess.RootResult.Graph.NameOf(sym) == "get_db" { - summary = rt + summary = fact.Summary break } } @@ -873,10 +873,10 @@ end var summary []typ.Type parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - if summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent); summaries != nil { - for sym, rt := range summaries { + if facts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent); facts != nil { + for sym, fact := range facts { if sess.RootResult.Graph.NameOf(sym) == "get_db" { - summary = rt + summary = fact.Summary break } } @@ -931,10 +931,10 @@ local y: string = b parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - if summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent); summaries != nil { - for sym, rt := range summaries { + if facts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent); facts != nil { + for sym, fact := range facts { name := sess.RootResult.Graph.NameOf(sym) - for i, slot := range rt { + for i, slot := range fact.Summary { if slot == nil { t.Errorf("nil slot at index %d in return summary for %q", i, name) } @@ -965,9 +965,9 @@ end found := 0 parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - if summaries := sess.Store.GetReturnSummariesSnapshot(sess.RootResult.Graph, parent); summaries != nil { - for sym, rt := range summaries { - if len(rt) == 0 { + if facts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent); facts != nil { + for sym, fact := range facts { + if len(fact.Summary) == 0 { name := "" if sess.RootResult.Graph != nil { name = sess.RootResult.Graph.NameOf(sym) diff --git 
a/compiler/check/tests/regression/channel_select_helper_return_narrowing_test.go b/compiler/check/tests/regression/channel_select_helper_return_narrowing_test.go index 56c3a1af..3eed9a43 100644 --- a/compiler/check/tests/regression/channel_select_helper_return_narrowing_test.go +++ b/compiler/check/tests/regression/channel_select_helper_return_narrowing_test.go @@ -85,18 +85,18 @@ end root := sess.RootResult.Graph parentHash := sess.Store.GraphParentHashOf(root.ID()) parent := sess.Store.Parents()[parentHash] - funcTypes := sess.Store.GetLocalFuncTypesSnapshot(root, parent) + functionFacts := sess.Store.GetFunctionFactsSnapshot(root, parent) var helperFn *typ.Function - for sym, tpe := range funcTypes { + for sym, fact := range functionFacts { if root.NameOf(sym) != "wait_for_exit" { continue } - helperFn = unwrap.Function(tpe) + helperFn = unwrap.Function(fact.Type) break } if helperFn == nil || len(helperFn.Returns) == 0 { - t.Fatalf("missing wait_for_exit function type in local func snapshot: %v", funcTypes) + t.Fatalf("missing wait_for_exit function type in FunctionFacts: %v", functionFacts) } nonNil := narrow.RemoveNil(helperFn.Returns[0]) diff --git a/compiler/check/tests/regression/contract_open_dynamic_return_test.go b/compiler/check/tests/regression/contract_open_dynamic_return_test.go index 4017e6b5..cbc79458 100644 --- a/compiler/check/tests/regression/contract_open_dynamic_return_test.go +++ b/compiler/check/tests/regression/contract_open_dynamic_return_test.go @@ -83,15 +83,15 @@ end root := result.Session.RootResult.Graph parentHash := result.Session.Store.GraphParentHashOf(root.ID()) parent := result.Session.Store.Parents()[parentHash] - funcTypes := result.Session.Store.GetLocalFuncTypesSnapshot(root, parent) + functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) sym, ok := root.SymbolAt(root.Exit(), "get_tracker") if !ok || sym == 0 { t.Fatal("missing symbol get_tracker") } - fn := unwrap.Function(funcTypes[sym]) + fn := 
unwrap.Function(functionFacts.FunctionType(sym)) if fn == nil || len(fn.Returns) == 0 || fn.Returns[0] == nil { - t.Fatalf("expected get_tracker function return type, got %v", funcTypes[sym]) + t.Fatalf("expected get_tracker function return type, got %v", functionFacts.FunctionType(sym)) } if fn.Returns[0].Kind() == kind.Nil { t.Fatalf("expected get_tracker return not to collapse to nil, got %v", fn.Returns[0]) diff --git a/compiler/check/tests/regression/false_positives_unit_test.go b/compiler/check/tests/regression/false_positives_unit_test.go index c311045c..c3ee6dd1 100644 --- a/compiler/check/tests/regression/false_positives_unit_test.go +++ b/compiler/check/tests/regression/false_positives_unit_test.go @@ -1025,6 +1025,41 @@ func TestFP_OrEmptyStringStaysString(t *testing.T) { } } +func TestFP_ErrorReturnOptionalFieldOrEmptyStringStaysString(t *testing.T) { + responseType := typ.NewRecord(). + Field("status_code", typ.Number). + OptField("body", typ.String). + Build() + httpManifest := io.NewManifest("http_client") + httpManifest.SetExport(typ.NewRecord(). + Field("post", typ.Func(). + Param("url", typ.String). + OptParam("opts", typ.Any). + Returns(typ.NewOptional(responseType), typ.NewOptional(typ.LuaError)). + Spec(contract.NewSpec().WithEffects(effect.ErrorReturn{ValueIndex: 0, ErrorIndex: 1})). + Build()). 
+ Build()) + + source := ` + local http = require("http_client") + local json = require("json") + + local client = {} + client._http_client = http + + local response, err = client._http_client.post("https://example.local", {}) + if err then + return nil, err + end + + return json.decode(response.body or "") + ` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithManifest("http_client", httpManifest)) + if result.HasError() { + t.Errorf("expected no errors for optional response body fallback, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestFP_AndGuardNarrowsNestedPath(t *testing.T) { source := ` local rec: {foo: {bar: string}?}? = nil diff --git a/compiler/check/tests/regression/imported_record_helper_param_test.go b/compiler/check/tests/regression/imported_record_helper_param_test.go new file mode 100644 index 00000000..1489c0ae --- /dev/null +++ b/compiler/check/tests/regression/imported_record_helper_param_test.go @@ -0,0 +1,105 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestRegression_ImportedRecordPassedThroughUntypedHelper(t *testing.T) { + clientModule := testutil.CheckAndExport(` + local client = {} + client.SERVICE = "bedrock" + function client.invoke(model_id, payload, options) + return {ok = true} + end + return client + `, "bedrock_client", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("provider errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local bedrock_client = require("bedrock_client") + + local handler = { + _client = bedrock_client, + } + + local function helper(client, model_id, payload, options) + return client.invoke(model_id, payload, options) + end + + local result = helper(handler._client, "model", {}, {}) + ` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("bedrock_client", clientModule)) + if result.HasError() { + t.Fatalf("expected imported 
record helper call to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ImportedRecordHelperWithConstrainedMethodUse(t *testing.T) { + clientModule := testutil.CheckAndExport(` + local client = {} + client.SERVICE = "bedrock" + function client.invoke(model_id: string, payload: any, options: {timeout: number?}?) + return {ok = true}, nil + end + return client + `, "bedrock_client", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("provider errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local bedrock_client = require("bedrock_client") + + local handler = { + _client = bedrock_client, + } + + local function helper(client, model_id, input, options) + local payload = { input = input } + local response, err = client.invoke(model_id, payload, { timeout = options and options.timeout }) + if err then + return nil, err + end + return response + end + + local result = helper(handler._client, "model", "text", { timeout = 1 }) + ` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("bedrock_client", clientModule)) + if result.HasError() { + t.Fatalf("expected imported record helper with constrained method use to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ImportedRecordHelperRejectsAnyPassedToStringMethod(t *testing.T) { + clientModule := testutil.CheckAndExport(` + local client = {} + function client.invoke(model_id: string, payload: any, options: any) + return {ok = true}, nil + end + return client + `, "bedrock_client", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("provider errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local bedrock_client = require("bedrock_client") + + local function helper(client, model_id) + return client.invoke(model_id, {}, {}) + end + + local contract_args = nil :: any + local model_id = contract_args.model + 
helper(bedrock_client, model_id) + ` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("bedrock_client", clientModule)) + if !result.HasError() { + t.Fatalf("expected an error when any flows into imported string-only method") + } +} diff --git a/compiler/check/tests/regression/local_function_narrow_return_repair_test.go b/compiler/check/tests/regression/local_function_narrow_return_repair_test.go index 715b2bf2..53be3d98 100644 --- a/compiler/check/tests/regression/local_function_narrow_return_repair_test.go +++ b/compiler/check/tests/regression/local_function_narrow_return_repair_test.go @@ -59,13 +59,13 @@ return { f = f } parent := sess.Store.Parents()[parentHash] snap := sess.Store.GetInterprocFactsSnapshot(sess.RootResult.Graph, parent) - if got := snap.ReturnSummaries[sym]; len(got) != 1 || containsNever(got[0]) { + if got := snap.FunctionFacts.Summary(sym); len(got) != 1 || containsNever(got[0]) { t.Fatalf("summary contains never artifact: %v", got) } - if got := snap.NarrowReturns[sym]; len(got) != 1 || containsNever(got[0]) { + if got := snap.FunctionFacts.NarrowSummary(sym); len(got) != 1 || containsNever(got[0]) { t.Fatalf("narrow contains never artifact: %v", got) } - if got := snap.FuncTypes[sym]; got == nil || containsNever(got) { + if got := snap.FunctionFacts.FunctionType(sym); got == nil || containsNever(got) { t.Fatalf("function fact contains never artifact: %v", got) } diff --git a/compiler/check/tests/regression/logical_or_soundness_test.go b/compiler/check/tests/regression/logical_or_soundness_test.go new file mode 100644 index 00000000..81ced82f --- /dev/null +++ b/compiler/check/tests/regression/logical_or_soundness_test.go @@ -0,0 +1,25 @@ +package regression + +import ( + "strings" + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestRegression_LogicalOrKeepsTruthyLeftAlternative(t *testing.T) { + source := ` + local function f(xs: {any}?) 
+ local ys: {number} = xs or {1} + return ys + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected logical-or assignment to reject possible any[] left branch") + } + msgs := strings.Join(testutil.ErrorMessages(result.Diagnostics), " | ") + if !strings.Contains(msgs, "cannot assign") { + t.Fatalf("expected assignment error, got: %v", msgs) + } +} diff --git a/compiler/check/tests/regression/param_hint_depth_convergence_test.go b/compiler/check/tests/regression/param_hint_depth_convergence_test.go index 435303e1..024fd043 100644 --- a/compiler/check/tests/regression/param_hint_depth_convergence_test.go +++ b/compiler/check/tests/regression/param_hint_depth_convergence_test.go @@ -42,3 +42,140 @@ func TestParamHints_DeepAliasChain_NoInterprocNonConvergenceWarning(t *testing.T } } } + +func TestParamHints_RecordWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { + code := ` + local repo = require("kb_repo") + + local function config() + return { + retrieval_iterations = tonumber(nil) or 2, + initial_vector_limit = tonumber(nil) or 4, + followup_vector_limit = tonumber(nil) or 2, + } + end + + local function search(question, kb_id, vec_limit, seen) + seen = seen or {} + local rows = repo.hybrid_search(question, kb_id, { limit = vec_limit }) + if rows then + for _, row in ipairs(rows) do + if row.node_id and not seen[row.node_id] then + seen[row.node_id] = true + end + end + end + return seen + end + + local function run(kb_id, question) + local cfg = config() + local seen = search(question, kb_id, cfg.initial_vector_limit) + for _ = 1, math.min(cfg.retrieval_iterations, 3) do + search(question, kb_id, cfg.followup_vector_limit, seen) + end + return seen + end + + return run("kb", "question") + ` + + result := testutil.Check(code, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + + for _, d := range 
result.Diagnostics { + if d.Severity == diag.SeverityWarning && strings.Contains(d.Message, "inter-function fixpoint did not converge") { + t.Fatalf("unexpected non-convergence warning: %v", d.Message) + } + } +} + +func TestParamHints_NestedWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { + code := ` + local repo = require("kb_repo") + + local function config() + return { + limit = tonumber(nil) or 4, + iterations = tonumber(nil) or 2, + } + end + + local function search(question, kb_id, limit, seen) + seen = seen or {} + local rows = repo.hybrid_search(question, kb_id, { + query = question, + options = { + limit = limit, + window = { limit }, + }, + }) + if rows then + for _, row in ipairs(rows) do + if row.node_id and not seen[row.node_id] then + seen[row.node_id] = true + end + end + end + return seen + end + + local function run(kb_id, question) + local cfg = config() + local seen = nil + for _ = 1, math.min(cfg.iterations, 3) do + seen = search(question, kb_id, cfg.limit, seen) + end + return seen + end + + return run("kb", "question") + ` + + result := testutil.Check(code, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + + for _, d := range result.Diagnostics { + if d.Severity == diag.SeverityWarning && strings.Contains(d.Message, "inter-function fixpoint did not converge") { + t.Fatalf("unexpected non-convergence warning: %v", d.Message) + } + } +} + +func TestReturnSummary_RecursiveDeepCopy_NoInterprocNonConvergenceWarning(t *testing.T) { + code := ` + local function deep_copy_table(original) + if type(original) ~= "table" then + return original + end + + local copy = {} + for key, value in pairs(original) do + if type(value) == "table" then + copy[key] = deep_copy_table(value) + else + copy[key] = value + end + end + return copy + end + + local source = { api = { routes = { users = true } } } + return deep_copy_table(source) + ` + + result := 
testutil.Check(code, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + + for _, d := range result.Diagnostics { + if d.Severity == diag.SeverityWarning && strings.Contains(d.Message, "inter-function fixpoint did not converge") { + t.Fatalf("unexpected non-convergence warning: %v", d.Message) + } + } +} diff --git a/compiler/check/tests/regression/wippy_false_positives_test.go b/compiler/check/tests/regression/wippy_false_positives_test.go index 0676c935..406bd036 100644 --- a/compiler/check/tests/regression/wippy_false_positives_test.go +++ b/compiler/check/tests/regression/wippy_false_positives_test.go @@ -275,13 +275,14 @@ func TestLocalFunctionShadowsModule_BindingDiagnostic(t *testing.T) { if sess != nil && sess.Store != nil && sess.RootResult != nil && sess.RootResult.BaseScope != nil && sess.RootResult.Graph != nil { parentHash := sess.Store.GraphParentHashOf(sess.RootResult.Graph.ID()) parent := sess.Store.Parents()[parentHash] - funcTypes := sess.Store.GetLocalFuncTypesSnapshot(sess.RootResult.Graph, parent) - t.Logf("LocalFuncTypes has %d symbols", len(funcTypes)) - for sym, ty := range funcTypes { + functionFacts := sess.Store.GetFunctionFactsSnapshot(sess.RootResult.Graph, parent) + t.Logf("FunctionFacts has %d symbols", len(functionFacts)) + for sym, fact := range functionFacts { name := "" if sess.Store.ModuleBindings() != nil { name = sess.Store.ModuleBindings().Name(sym) } + ty := fact.Type if ty != nil { t.Logf(" sym %d (%s): %s", sym, name, ty.String()) } diff --git a/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go b/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go index cb5a09f8..5a0fafa0 100644 --- a/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go +++ b/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go @@ -391,12 +391,13 @@ func TestWippyRunner_NearLiteralTestRunnerFlow(t 
*testing.T) { root := result.Session.RootResult.Graph parentHash := result.Session.Store.GraphParentHashOf(root.ID()) parent := result.Session.Store.Parents()[parentHash] - localTypes := result.Session.Store.GetLocalFuncTypesSnapshot(root, parent) + functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) hints := result.Session.Store.GetParamHintsSnapshot(root, parent) if bindings := result.Session.Store.ModuleBindings(); bindings != nil { - for sym, fnType := range localTypes { + for sym, fact := range functionFacts { name := bindings.Name(sym) if name == "sorted_keys" || name == "run_suite" || name == "run_test" || name == "group_by_suite" { + fnType := fact.Type t.Logf("local-fn %q sym=%d type=%s", name, sym, typ.Format(fnType, typ.DefaultFormatOptions)) if hv := hints[sym]; len(hv) > 0 { t.Logf("param-hints %q: %v", name, hv) diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua index 4dbbeb2e..6580b78a 100644 --- a/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua +++ b/testdata/fixtures/modules/providers-open-retry-captured-options-realtest/contract.lua @@ -1,7 +1,17 @@ local contract = {} function contract.get(_id) - return nil, "not configured" + return { + with_context = function(self, _context) + return self + end, + with_options = function(self, _options) + return self + end, + open = function(self, _provider_id) + return {}, nil + end, + }, nil end return contract diff --git a/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua b/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua index 4dbbeb2e..6580b78a 100644 --- a/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua +++ b/testdata/fixtures/modules/providers-open-retry-captured-options/contract.lua @@ -1,7 +1,17 @@ local contract = {} 
function contract.get(_id) - return nil, "not configured" + return { + with_context = function(self, _context) + return self + end, + with_options = function(self, _options) + return self + end, + open = function(self, _provider_id) + return {}, nil + end, + }, nil end return contract diff --git a/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua index 3aeeb9e2..39fbbd98 100644 --- a/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua +++ b/testdata/fixtures/narrowing/dynamic-registry-renderer-guard/main.lua @@ -9,8 +9,8 @@ local function get_page_data(page) return {}, nil end - local name: string = page.data_func -- expect-error: cannot assign string | true to string - takes_string(page.data_func) -- expect-error: argument 1: expected string, got string | true + local name: string = page.data_func + takes_string(page.data_func) return {}, nil end @@ -19,4 +19,4 @@ local page = page_registry.build_page({ data = { data_func = "load_data" }, }) -return get_page_data(page) -- expect-error: expected {data_func?: boolean | string +return get_page_data(page) diff --git a/testdata/fixtures/realworld/sql-repository/repository.lua b/testdata/fixtures/realworld/sql-repository/repository.lua index 3fedc990..8b7ba31e 100644 --- a/testdata/fixtures/realworld/sql-repository/repository.lua +++ b/testdata/fixtures/realworld/sql-repository/repository.lua @@ -46,7 +46,7 @@ function M.table_exists(database: db.Database): (boolean?, string?) return nil, "Query failed: " .. 
tostring(query_err) end if result and result[1] then - return result[1].exists or (result[1].count and result[1].count > 0), nil + return result[1].exists == true or (result[1].count and result[1].count > 0), nil end return false, nil end diff --git a/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json b/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json new file mode 100644 index 00000000..bca5acdf --- /dev/null +++ b/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json @@ -0,0 +1,6 @@ +{ + "description": "Large dynamic compiler fixture must terminate; dynamic any-to-string calls are reported soundly", + "check": { + "errors": 2 + } +} diff --git a/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json b/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json new file mode 100644 index 00000000..4aab2baa --- /dev/null +++ b/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json @@ -0,0 +1,6 @@ +{ + "description": "Large dynamic dataflow-node fixture must terminate; dynamic record-shape errors are reported soundly", + "check": { + "errors": 4 + } +} diff --git a/types/constraint/condition.go b/types/constraint/condition.go index 06ac99aa..2ffcb50a 100644 --- a/types/constraint/condition.go +++ b/types/constraint/condition.go @@ -81,7 +81,7 @@ type Condition struct { // TrueCondition returns a condition that imposes no constraints. func TrueCondition() Condition { - return Condition{Disjuncts: [][]Constraint{{}}} + return Condition{Disjuncts: [][]Constraint{nil}} } // FalseCondition returns an unsatisfiable condition. 
diff --git a/types/flow/edge.go b/types/flow/edge.go index 8b9ae132..ecadde31 100644 --- a/types/flow/edge.go +++ b/types/flow/edge.go @@ -27,6 +27,9 @@ func (s *Solution) buildEdgeConditions() { if !ec.Condition.HasConstraints() && !ec.Condition.IsFalse() { continue } + if s.edgeConditions == nil { + s.edgeConditions = make(map[edgeKey]constraint.Condition, len(s.inputs.EdgeConditions)) + } key := edgeKey{from: ec.From, to: ec.To} if existing, ok := s.edgeConditions[key]; ok && (existing.HasConstraints() || existing.IsFalse()) { s.edgeConditions[key] = constraint.And(existing, ec.Condition) diff --git a/types/flow/numeric.go b/types/flow/numeric.go index d97d35dc..89d02265 100644 --- a/types/flow/numeric.go +++ b/types/flow/numeric.go @@ -25,6 +25,9 @@ func (s *Solution) buildEdgeNumericConstraints() { if len(edge.Constraints) == 0 { continue } + if s.edgeNumericConstraints == nil { + s.edgeNumericConstraints = make(map[edgeKey][]constraint.NumericConstraint, len(s.inputs.EdgeNumericConstraints)) + } key := edgeKey{from: edge.From, to: edge.To} s.edgeNumericConstraints[key] = append(s.edgeNumericConstraints[key], edge.Constraints...) 
} @@ -122,6 +125,9 @@ func (s *Solution) checkNumericConstraints() { edgeState.ApplyConstraintWithResolver(nc, resolver) } if !edgeState.CheckSatisfiability() { + if s.unsatEdges == nil { + s.unsatEdges = make(map[edgeKey]bool) + } s.unsatEdges[key] = true } } @@ -223,7 +229,7 @@ func (s *Solution) computeNumericStateAt(c cfg.Graph, p cfg.Point, state map[cfg return nil } - var predStates []*numeric.State + var result *numeric.State for _, pred := range preds { predState := state[pred] edgeConstraints := s.edgeNumericConstraints[edgeKey{from: pred, to: p}] @@ -251,20 +257,13 @@ func (s *Solution) computeNumericStateAt(c cfg.Graph, p cfg.Point, state map[cfg if edgeState.IsTop() { continue } - predStates = append(predStates, edgeState) - } - - if len(predStates) == 0 { - return nil - } - if len(predStates) == 1 { - return predStates[0] + if result == nil { + result = edgeState + continue + } + result = numeric.Join(result, edgeState) } - result := predStates[0] - for i := 1; i < len(predStates); i++ { - result = numeric.Join(result, predStates[i]) - } return result } @@ -302,27 +301,24 @@ func (s *Solution) rekeyForPhis(state *numeric.State, pred, p cfg.Point) *numeri return state } - // Find phi nodes at point p - var relevantPhis []cfg.PhiNode + var keyRemap map[constraint.PathKey]constraint.PathKey for _, phi := range s.inputs.Graph.PhiNodes() { - if phi.Point == p { - relevantPhis = append(relevantPhis, phi) + if phi.Point != p { + continue + } + newKey := s.pkResolver.KeyAtVersion(phi.Target.Symbol, phi.Target.ID, nil) + if newKey == "" { + continue } - } - - if len(relevantPhis) == 0 { - return state - } - - // Build mapping from old keys to new keys - keyRemap := make(map[constraint.PathKey]constraint.PathKey) - for _, phi := range relevantPhis { - // Find the operand version coming from pred for _, op := range phi.Operands { - // Check if this operand's definition point is reachable from pred + if op.From != pred { + continue + } oldKey := 
s.pkResolver.KeyAtVersion(op.Version.Symbol, op.Version.ID, nil) - newKey := s.pkResolver.KeyAtVersion(phi.Target.Symbol, phi.Target.ID, nil) if oldKey != "" && newKey != "" && oldKey != newKey { + if keyRemap == nil { + keyRemap = make(map[constraint.PathKey]constraint.PathKey) + } keyRemap[oldKey] = newKey } } diff --git a/types/flow/numeric/state.go b/types/flow/numeric/state.go index 1faf78b9..dfd7803b 100644 --- a/types/flow/numeric/state.go +++ b/types/flow/numeric/state.go @@ -83,12 +83,7 @@ var ( // NewState creates an empty (top) numeric state. func NewState() *State { - return &State{ - bounds: make(map[constraint.PathKey]Interval), - modular: make(map[constraint.PathKey]ModResidue), - relations: make(map[relationKey]int64), - lenRefs: make(map[constraint.PathKey]lenRefBound), - } + return &State{} } // Bottom returns the unsatisfiable state (bottom of the lattice). @@ -122,26 +117,33 @@ func (s *State) Clone() *State { return Bottom() } - c := &State{ - bounds: make(map[constraint.PathKey]Interval, len(s.bounds)), - modular: make(map[constraint.PathKey]ModResidue, len(s.modular)), - relations: make(map[relationKey]int64, len(s.relations)), - lenRefs: make(map[constraint.PathKey]lenRefBound, len(s.lenRefs)), - } - for _, k := range constraint.SortedPathKeys(s.bounds) { - c.bounds[k] = s.bounds[k] + c := &State{} + if len(s.bounds) > 0 { + c.bounds = make(map[constraint.PathKey]Interval, len(s.bounds)) + for k, v := range s.bounds { + c.bounds[k] = v + } } - for _, k := range constraint.SortedPathKeys(s.modular) { - c.modular[k] = s.modular[k] + if len(s.modular) > 0 { + c.modular = make(map[constraint.PathKey]ModResidue, len(s.modular)) + for k, v := range s.modular { + c.modular[k] = v + } } - for _, k := range sortedRelationKeys(s.relations) { - c.relations[k] = s.relations[k] + if len(s.relations) > 0 { + c.relations = make(map[relationKey]int64, len(s.relations)) + for k, v := range s.relations { + c.relations[k] = v + } } - for _, k := range 
constraint.SortedPathKeys(s.lenRefs) { - c.lenRefs[k] = s.lenRefs[k] + if len(s.lenRefs) > 0 { + c.lenRefs = make(map[constraint.PathKey]lenRefBound, len(s.lenRefs)) + for k, v := range s.lenRefs { + c.lenRefs[k] = v + } } return c @@ -179,11 +181,10 @@ func Join(a, b *State) *State { return a.Clone() } - result := NewState() + result := &State{} // Bounds: keep only variables in both, intersect intervals. - for _, v := range constraint.SortedPathKeys(a.bounds) { - ai := a.bounds[v] + for v, ai := range a.bounds { if bi, ok := b.bounds[v]; ok { merged := intersectIntervals(ai, bi) if merged.Lower > merged.Upper { @@ -191,33 +192,42 @@ func Join(a, b *State) *State { } if merged != unboundedInterval { + if result.bounds == nil { + result.bounds = make(map[constraint.PathKey]Interval, minMapLen(len(a.bounds), len(b.bounds))) + } result.bounds[v] = merged } } } // Modular: keep only if identical in both. - for _, v := range constraint.SortedPathKeys(a.modular) { - am := a.modular[v] + for v, am := range a.modular { if bm, ok := b.modular[v]; ok { if am.Modulus == bm.Modulus && am.Residue == bm.Residue { + if result.modular == nil { + result.modular = make(map[constraint.PathKey]ModResidue, minMapLen(len(a.modular), len(b.modular))) + } result.modular[v] = am } } } // Relations: keep only if present in both, take maximum (loosest bound). - for _, k := range sortedRelationKeys(a.relations) { - av := a.relations[k] + for k, av := range a.relations { if bv, ok := b.relations[k]; ok { + if result.relations == nil { + result.relations = make(map[relationKey]int64, minMapLen(len(a.relations), len(b.relations))) + } result.relations[k] = maxInt64(av, bv) } } // LenRefs: keep only if identical in both. 
- for _, v := range constraint.SortedPathKeys(a.lenRefs) { - ref := a.lenRefs[v] + for v, ref := range a.lenRefs { if bref, ok := b.lenRefs[v]; ok && ref == bref { + if result.lenRefs == nil { + result.lenRefs = make(map[constraint.PathKey]lenRefBound, minMapLen(len(a.lenRefs), len(b.lenRefs))) + } result.lenRefs[v] = ref } } @@ -243,6 +253,37 @@ func intersectIntervals(a, b Interval) Interval { } } +func minMapLen(a, b int) int { + if a < b { + return a + } + return b +} + +func (s *State) ensureBounds(capacity int) { + if s.bounds == nil { + s.bounds = make(map[constraint.PathKey]Interval, capacity) + } +} + +func (s *State) ensureModular(capacity int) { + if s.modular == nil { + s.modular = make(map[constraint.PathKey]ModResidue, capacity) + } +} + +func (s *State) ensureRelations(capacity int) { + if s.relations == nil { + s.relations = make(map[relationKey]int64, capacity) + } +} + +func (s *State) ensureLenRefs(capacity int) { + if s.lenRefs == nil { + s.lenRefs = make(map[constraint.PathKey]lenRefBound, capacity) + } +} + // ApplyConstraintWithResolver refines the state with a numeric constraint. // // Uses the provided resolver to convert constraint paths to versioned PathKeys. 
@@ -350,6 +391,7 @@ func (s *State) applyLe(x, y constraint.PathKey) { func (s *State) applyLeWithConst(x, y constraint.PathKey, c int64) { // x - y <= c + s.ensureRelations(1) key := relationKey{X: x, Y: y} if old, ok := s.relations[key]; ok { s.relations[key] = minInt64(old, c) @@ -364,6 +406,7 @@ func (s *State) ApplyLt(x, y constraint.PathKey) { func (s *State) applyLt(x, y constraint.PathKey) { // x < y => x - y <= -1 + s.ensureRelations(1) key := relationKey{X: x, Y: y} if old, ok := s.relations[key]; ok { s.relations[key] = minInt64(old, -1) @@ -382,6 +425,7 @@ func (s *State) ApplyGe(x, y constraint.PathKey) { func (s *State) applyGe(x, y constraint.PathKey) { // x >= y => y - x <= 0 + s.ensureRelations(1) key := relationKey{X: y, Y: x} if old, ok := s.relations[key]; ok { s.relations[key] = minInt64(old, 0) @@ -396,6 +440,7 @@ func (s *State) ApplyGt(x, y constraint.PathKey) { func (s *State) applyGt(x, y constraint.PathKey) { // x > y => y - x <= -1 + s.ensureRelations(1) key := relationKey{X: y, Y: x} if old, ok := s.relations[key]; ok { s.relations[key] = minInt64(old, -1) @@ -419,6 +464,7 @@ func (s *State) ApplyEqConst(v constraint.PathKey, c int64) { } func (s *State) applyEqConst(v constraint.PathKey, c int64) { + s.ensureBounds(1) s.bounds[v] = Interval{Lower: c, Upper: c} } @@ -427,6 +473,7 @@ func (s *State) ApplyLeConst(v constraint.PathKey, c int64) { } func (s *State) applyLeConst(v constraint.PathKey, c int64) { + s.ensureBounds(1) if b, ok := s.bounds[v]; ok { b.Upper = minInt64(b.Upper, c) if b.Lower > b.Upper { @@ -444,6 +491,7 @@ func (s *State) ApplyGeConst(v constraint.PathKey, c int64) { } func (s *State) applyGeConst(v constraint.PathKey, c int64) { + s.ensureBounds(1) if b, ok := s.bounds[v]; ok { b.Lower = maxInt64(b.Lower, c) if b.Lower > b.Upper { @@ -461,6 +509,7 @@ func (s *State) ApplyModEq(v constraint.PathKey, m, r int64) { } func (s *State) applyModEq(v constraint.PathKey, m, r int64) { + s.ensureModular(1) if existing, ok 
:= s.modular[v]; ok { if existing.Modulus != m || existing.Residue != r { s.unsat = true @@ -479,6 +528,7 @@ func (s *State) ApplyLeLenOfWithOffset(v, arr constraint.PathKey, offset int64) } func (s *State) applyLeLenOf(v, arr constraint.PathKey, offset int64) { + s.ensureLenRefs(1) s.lenRefs[v] = lenRefBound{Array: arr, Offset: offset} } @@ -566,16 +616,17 @@ func (s *State) checkDifferenceConstraints() bool { return true } - // Collect variables. - vars := make(map[constraint.PathKey]bool) + relKeys := sortedRelationKeys(s.relations) - for _, k := range sortedRelationKeys(s.relations) { - vars[k.X] = true - vars[k.Y] = true + // Collect variables. + vars := make(map[constraint.PathKey]struct{}, len(s.relations)*2) + for _, k := range relKeys { + vars[k.X] = struct{}{} + vars[k.Y] = struct{}{} } // Initialize distances from virtual source. - dist := make(map[constraint.PathKey]int64) + dist := make(map[constraint.PathKey]int64, len(vars)) for _, v := range constraint.SortedPathKeys(vars) { dist[v] = 0 } @@ -585,7 +636,7 @@ func (s *State) checkDifferenceConstraints() bool { for i := 0; i < n; i++ { changed := false - for _, k := range sortedRelationKeys(s.relations) { + for _, k := range relKeys { w := s.relations[k] // x - y <= w => dist[x] <= dist[y] + w if dist[k.Y]+w < dist[k.X] { @@ -600,7 +651,7 @@ func (s *State) checkDifferenceConstraints() bool { } // Check for negative cycle (one more iteration). 
- for _, k := range sortedRelationKeys(s.relations) { + for _, k := range relKeys { w := s.relations[k] if dist[k.Y]+w < dist[k.X] { return false @@ -742,10 +793,11 @@ func (s *State) Relations() map[constraint.PathKey]map[constraint.PathKey]int64 } result := make(map[constraint.PathKey]map[constraint.PathKey]int64) for _, rel := range sortedRelationKeys(s.relations) { + c := s.relations[rel] if result[rel.X] == nil { result[rel.X] = make(map[constraint.PathKey]int64) } - result[rel.X][rel.Y] = s.relations[rel] + result[rel.X][rel.Y] = c } return result } @@ -785,12 +837,22 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { if s.unsat { return Bottom() } + if s.isTop() { + return s + } - result := &State{ - bounds: make(map[constraint.PathKey]Interval, len(s.bounds)), - modular: make(map[constraint.PathKey]ModResidue, len(s.modular)), - relations: make(map[relationKey]int64, len(s.relations)), - lenRefs: make(map[constraint.PathKey]lenRefBound, len(s.lenRefs)), + result := &State{} + if len(s.bounds) > 0 { + result.bounds = make(map[constraint.PathKey]Interval, len(s.bounds)) + } + if len(s.modular) > 0 { + result.modular = make(map[constraint.PathKey]ModResidue, len(s.modular)) + } + if len(s.relations) > 0 { + result.relations = make(map[relationKey]int64, len(s.relations)) + } + if len(s.lenRefs) > 0 { + result.lenRefs = make(map[constraint.PathKey]lenRefBound, len(s.lenRefs)) } // Remap bounds @@ -800,6 +862,12 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { if mapped, ok := remap[k]; ok { newKey = mapped } + if existing, ok := result.bounds[newKey]; ok { + v = intersectIntervals(existing, v) + if v.Lower > v.Upper { + return Bottom() + } + } result.bounds[newKey] = v } @@ -810,6 +878,9 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { if mapped, ok := remap[k]; ok { newKey = mapped } + if existing, ok := result.modular[newKey]; ok && existing != v { + return 
Bottom() + } result.modular[newKey] = v } @@ -824,7 +895,11 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { if mapped, ok := remap[rel.Y]; ok { newY = mapped } - result.relations[relationKey{X: newX, Y: newY}] = c + newRel := relationKey{X: newX, Y: newY} + if existing, ok := result.relations[newRel]; ok { + c = minInt64(existing, c) + } + result.relations[newRel] = c } // Remap length references (both variable and array keys) @@ -839,6 +914,10 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { newArr = mapped } ref.Array = newArr + if existing, ok := result.lenRefs[newK]; ok && existing != ref { + delete(result.lenRefs, newK) + continue + } result.lenRefs[newK] = ref } diff --git a/types/flow/propagate/propagate.go b/types/flow/propagate/propagate.go index db7ea994..4f56a775 100644 --- a/types/flow/propagate/propagate.go +++ b/types/flow/propagate/propagate.go @@ -143,10 +143,10 @@ func Propagate(inputs *Inputs) *Result { } g := inputs.Graph - pointConditions := make(map[cfg.Point]constraint.Condition) + worklist := g.RPO() + pointConditions := make(map[cfg.Point]constraint.Condition, len(worklist)) pointConditions[g.Entry()] = constraint.TrueCondition() - worklist := g.RPO() inQueue := make(map[cfg.Point]bool, len(worklist)) for _, p := range worklist { inQueue[p] = true @@ -217,7 +217,8 @@ func computeConditionAtPoint( return constraint.FalseCondition() } - var predConds []constraint.Condition + var result constraint.Condition + hasResult := false for _, pred := range preds { if inputs.DeadPoints != nil && inputs.DeadPoints[pred] { continue @@ -263,7 +264,13 @@ func computeConditionAtPoint( edgeCond, ok := inputs.EdgeConditions[EdgeKey{From: pred, To: p}] if !ok || (!edgeCond.HasConstraints() && !edgeCond.IsFalse()) { - edgeCond = constraint.TrueCondition() + if !hasResult { + result = predCond + hasResult = true + } else { + result = constraint.Or(result, predCond) + } + continue } var 
combinedCond constraint.Condition @@ -281,23 +288,18 @@ func computeConditionAtPoint( continue } - predConds = append(predConds, combinedCond) + if !hasResult { + result = combinedCond + hasResult = true + } else { + result = constraint.Or(result, combinedCond) + } } - if len(predConds) == 0 { + if !hasResult { return constraint.FalseCondition() } - var result constraint.Condition - if len(predConds) == 1 { - result = predConds[0] - } else { - result = predConds[0] - for i := 1; i < len(predConds); i++ { - result = constraint.Or(result, predConds[i]) - } - } - result = KillRedefinedConditions(result, p, inputs.Assignments) return result diff --git a/types/flow/solver.go b/types/flow/solver.go index 18c86862..bf589762 100644 --- a/types/flow/solver.go +++ b/types/flow/solver.go @@ -100,16 +100,10 @@ func Solve(inputs *Inputs, resolver narrow.Resolver) *Solution { } s := &Solution{ - inputs: inputs, - resolver: resolver, - pkResolver: pkRes, - values: make(map[string]typ.Type, size*2), - edgeConditions: make(map[edgeKey]constraint.Condition), - edgeNumericConstraints: make(map[edgeKey][]constraint.NumericConstraint), - unsatEdges: make(map[edgeKey]bool), - pointConditions: make(map[cfg.Point]constraint.Condition), - numericStates: make(map[cfg.Point]*numeric.State), - pathAliases: make(map[string]string, size), + inputs: inputs, + resolver: resolver, + pkResolver: pkRes, + values: make(map[string]typ.Type, estimateSolutionValueCapacity(inputs, size)), } if inputs != nil && len(inputs.DeclaredTypes) > 0 { s.declaredSyms = make([]cfg.SymbolID, 0, len(inputs.DeclaredTypes)) @@ -130,6 +124,26 @@ func Solve(inputs *Inputs, resolver narrow.Resolver) *Solution { return s } +func estimateSolutionValueCapacity(inputs *Inputs, graphSize int) int { + if inputs == nil { + return 0 + } + capacity := len(inputs.DeclaredTypes) + len(inputs.Assignments) + capacity += len(inputs.IndexerAssignments) + len(inputs.TableMutatorAssignments) + capacity += 
len(inputs.ContainerMutatorAssignments) + if capacity < len(inputs.ConstValues) { + capacity += len(inputs.ConstValues) + } + if capacity < 8 && graphSize > 0 { + capacity = 8 + } + maxByGraph := graphSize * 2 + if maxByGraph > 0 && capacity > maxByGraph { + return maxByGraph + } + return capacity +} + // runPropagation runs constraint propagation using the propagate package. // // Converts edge conditions to the propagate package format and runs @@ -140,9 +154,12 @@ func (s *Solution) runPropagation() { } // Convert edge conditions to propagate format - edgeConds := make(propagate.EdgeConditions, len(s.edgeConditions)) - for k, cond := range s.edgeConditions { - edgeConds[propagate.EdgeKey{From: k.from, To: k.to}] = cond + var edgeConds propagate.EdgeConditions + if len(s.edgeConditions) > 0 { + edgeConds = make(propagate.EdgeConditions, len(s.edgeConditions)) + for k, cond := range s.edgeConditions { + edgeConds[propagate.EdgeKey{From: k.from, To: k.to}] = cond + } } // Convert assignments to propagate format diff --git a/types/flow/solver_helpers.go b/types/flow/solver_helpers.go index 30da3b73..75a7bb14 100644 --- a/types/flow/solver_helpers.go +++ b/types/flow/solver_helpers.go @@ -98,9 +98,12 @@ func (s *Solution) initSymbolTypes(src symbolTypeSource) { } func (s *Solution) setValue(key string, t typ.Type) { - if s == nil || s.values == nil || key == "" { + if s == nil || key == "" { return } + if s.values == nil { + s.values = make(map[string]typ.Type, 1) + } s.values[key] = t if s.fieldOverlayCache == nil { return diff --git a/types/flow/transfer.go b/types/flow/transfer.go index 85b5804c..ba5e6dcf 100644 --- a/types/flow/transfer.go +++ b/types/flow/transfer.go @@ -114,6 +114,9 @@ func (s *Solution) processAssignmentReturnChangedKeys(p cfg.Point) []string { if assign.SourcePath.HasSymbol() { sourceKey := s.pkResolver.KeyAt(p, assign.SourcePath) if sourceKey != "" && sourceKey != targetKey { + if s.pathAliases == nil { + s.pathAliases = 
make(map[string]string) + } s.pathAliases[targetKeyStr] = string(sourceKey) } else { delete(s.pathAliases, targetKeyStr) diff --git a/types/typ/function.go b/types/typ/function.go index e07b0954..94897cf2 100644 --- a/types/typ/function.go +++ b/types/typ/function.go @@ -57,6 +57,17 @@ func Func() *FunctionBuilder { return &FunctionBuilder{} } +// ReserveParams avoids reallocating while appending known effective parameters. +func (b *FunctionBuilder) ReserveParams(n int) *FunctionBuilder { + if b == nil || n <= 1 || cap(b.params) >= n { + return b + } + params := make([]Param, len(b.params), n) + copy(params, b.params) + b.params = params + return b +} + // TypeParam adds a type parameter for generic functions. func (b *FunctionBuilder) TypeParam(name string, constraint Type) *FunctionBuilder { b.typeParams = append(b.typeParams, NewTypeParam(name, constraint)) diff --git a/types/typ/join/join.go b/types/typ/join/join.go index 85a1162f..903f92e0 100644 --- a/types/typ/join/join.go +++ b/types/typ/join/join.go @@ -49,7 +49,7 @@ func WithReturns(sig *typ.Function, returns []typ.Type) *typ.Function { return nil } - builder := typ.Func() + builder := typ.Func().ReserveParams(len(sig.Params)) for _, tp := range sig.TypeParams { builder = builder.TypeParam(tp.Name, tp.Constraint) } @@ -64,11 +64,18 @@ func WithReturns(sig *typ.Function, returns []typ.Type) *typ.Function { builder = builder.Variadic(sig.Variadic) } - normalized := make([]typ.Type, len(returns)) - copy(normalized, returns) - for i, t := range normalized { + normalized := returns + for i, t := range returns { if t == nil { + normalized = make([]typ.Type, len(returns)) + copy(normalized, returns) normalized[i] = typ.Unknown + for j := i + 1; j < len(normalized); j++ { + if normalized[j] == nil { + normalized[j] = typ.Unknown + } + } + break } } builder = builder.Returns(normalized...) 
@@ -101,6 +108,9 @@ func WithReturnsOrUnknown(sig *typ.Function, returns []typ.Type) *typ.Function { return WithReturns(sig, []typ.Type{typ.Unknown}) } if len(sig.Returns) == 0 || typ.IsUnknownOnlyOrEmpty(sig.Returns) { + if returnVectorsEqual(sig.Returns, returns) { + return sig + } return WithReturns(sig, returns) } if len(sig.Returns) == len(returns) { @@ -112,6 +122,9 @@ func WithReturnsOrUnknown(sig *typ.Function, returns []typ.Type) *typ.Function { } } if hasPlaceholder { + if returnVectorsEqual(sig.Returns, returns) { + return sig + } return WithReturns(sig, returns) } return sig @@ -121,8 +134,33 @@ func WithReturnsOrUnknown(sig *typ.Function, returns []typ.Type) *typ.Function { break } if ret != nil && ret.Kind().IsPlaceholder() { + if returnVectorsEqual(sig.Returns, returns) { + return sig + } return WithReturns(sig, returns) } } return sig } + +func returnVectorsEqual(existing, returns []typ.Type) bool { + if len(existing) != len(returns) { + return false + } + for i, right := range returns { + if right == nil { + right = typ.Unknown + } + left := existing[i] + if left == right { + continue + } + if left == nil || right == nil { + return false + } + if left.Hash() != right.Hash() || !typ.TypeEquals(left, right) { + return false + } + } + return true +} diff --git a/types/typ/policy.go b/types/typ/policy.go index 2a18a6b0..884eb811 100644 --- a/types/typ/policy.go +++ b/types/typ/policy.go @@ -306,11 +306,11 @@ func literalType(t Type) (*Literal, bool) { } // JoinBranchOutcome merges mutually-exclusive expression outcomes (for example, -// `a and b` / `a or b`) while preserving uncertainty. +// `a and b` / `a or b`) while preserving every runtime possibility. // -// Unlike JoinPreferNonSoft, this must not treat unknown as absent information: -// expression typing needs to preserve runtime uncertainty when one branch may -// still produce unknown-like values. 
+// Unlike inference joins, expression outcomes are value-level alternatives: +// a soft placeholder returned by one branch is still a real possible runtime +// value and must not be pruned just because the other branch is concrete. func JoinBranchOutcome(a, b Type) Type { if a == nil { return b @@ -318,28 +318,10 @@ func JoinBranchOutcome(a, b Type) Type { if b == nil { return a } - - a = PruneSoftUnionMembers(a) - b = PruneSoftUnionMembers(b) - - // Preserve runtime uncertainty for branch outcomes: - // unknown and nil means "value may be unknown or absent". - if (IsUnknown(a) && b.Kind() == kind.Nil) || (IsUnknown(b) && a.Kind() == kind.Nil) { - return NewOptional(Unknown) - } - - if IsSoft(a, SoftPlaceholderPolicy) && !IsSoft(b, SoftPlaceholderPolicy) && b.Kind() != kind.Nil { - return b - } - if IsSoft(b, SoftPlaceholderPolicy) && !IsSoft(a, SoftPlaceholderPolicy) && a.Kind() != kind.Nil { - return a - } - if TypeEquals(a, b) { return a } - - return PruneSoftUnionMembers(NewUnion(a, b)) + return NewUnion(a, b) } // IsRefinableAnnotation reports whether an explicit annotation should be diff --git a/types/typ/policy_test.go b/types/typ/policy_test.go index ab2c8c27..7c9ca1a8 100644 --- a/types/typ/policy_test.go +++ b/types/typ/policy_test.go @@ -46,12 +46,13 @@ func TestJoinBranchOutcome_PreservesUnknownWithNil(t *testing.T) { } } -func TestJoinBranchOutcome_PrefersConcreteOverSoft(t *testing.T) { +func TestJoinBranchOutcome_PreservesSoftRuntimeAlternative(t *testing.T) { left := NewOptional(NewArray(Any)) right := NewArray(Number) got := JoinBranchOutcome(left, right) - if got == nil || got.String() != "number[]" { - t.Fatalf("JoinBranchOutcome(%v, %v) = %v, want number[]", left, right, got) + want := NewUnion(left, right) + if !TypeEquals(got, want) { + t.Fatalf("JoinBranchOutcome(%v, %v) = %v, want %v", left, right, got, want) } } diff --git a/types/typ/rebuild.go b/types/typ/rebuild.go index 8c580461..fa068357 100644 --- a/types/typ/rebuild.go +++ 
b/types/typ/rebuild.go @@ -77,6 +77,9 @@ func buildRecordType(fields []Field, metatable, mapKey, mapValue Type, open bool if sorted[i].Type == nil { sorted[i].Type = Unknown } + if sorted[i].Optional { + sorted[i].Type = normalizeOptionalFieldType(sorted[i].Type) + } } if mapKey == nil && mapValue != nil { @@ -126,6 +129,41 @@ func buildRecordType(fields []Field, metatable, mapKey, mapValue Type, open bool } } +func normalizeOptionalFieldType(t Type) Type { + if t == nil { + return Unknown + } + switch v := t.(type) { + case *Annotated: + inner := normalizeOptionalFieldType(v.Inner) + if inner == v.Inner { + return t + } + return NewAnnotated(inner, v.Annotations) + case *Alias: + return t + case *Optional: + if v.Inner == nil || v.Inner.Kind() == kind.Never || v.Inner.Kind() == kind.Nil { + return t + } + return v.Inner + case *Union: + kept := make([]Type, 0, len(v.Members)) + for _, member := range v.Members { + if member == nil || member.Kind() == kind.Nil || member.Kind() == kind.Never { + continue + } + kept = append(kept, member) + } + if len(kept) == 0 { + return t + } + return NewUnion(kept...) + default: + return t + } +} + func fieldsSortedByName(fields []Field) bool { for i := 1; i < len(fields); i++ { if fields[i-1].Name > fields[i].Name { diff --git a/types/typ/record_test.go b/types/typ/record_test.go index 14aed80c..927c0032 100644 --- a/types/typ/record_test.go +++ b/types/typ/record_test.go @@ -67,6 +67,29 @@ func TestRecordFieldsSorted(t *testing.T) { } } +func TestRecordOptionalFieldNormalizesNestedOptionalType(t *testing.T) { + r := NewRecord(). + OptField("error", NewOptional(String)). + OptField("nil_only", Nil). 
+ Build() + + err := r.GetField("error") + if err == nil || !err.Optional { + t.Fatal("expected optional error field") + } + if !TypeEquals(err.Type, String) { + t.Fatalf("expected nested optional field type to normalize to string, got %v", err.Type) + } + + nilOnly := r.GetField("nil_only") + if nilOnly == nil || !nilOnly.Optional { + t.Fatal("expected optional nil_only field") + } + if !TypeEquals(nilOnly.Type, Nil) { + t.Fatalf("expected nil-only optional field type to remain nil, got %v", nilOnly.Type) + } +} + func TestRecordOptionalField(t *testing.T) { r := NewRecord(). Field("x", Number). diff --git a/types/typ/unwrap/unwrap.go b/types/typ/unwrap/unwrap.go index e057b703..cf03d720 100644 --- a/types/typ/unwrap/unwrap.go +++ b/types/typ/unwrap/unwrap.go @@ -5,7 +5,6 @@ package unwrap import ( - "github.com/wippyai/go-lua/internal" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/subst" @@ -15,63 +14,72 @@ import ( // Unwraps: Alias, Optional (to inner type). // Does NOT unwrap: Instantiated (requires type substitution), Union, Ref. 
func Underlying(t typ.Type) typ.Type { - return underlyingDepth(t, typ.NewGuard()) -} - -func underlyingDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { - return typ.Visitor[typ.Type]{ - Alias: func(a *typ.Alias) typ.Type { - return underlyingDepth(a.UnaliasedTarget(), next) - }, - Optional: func(o *typ.Optional) typ.Type { - return underlyingDepth(o.Inner, next) - }, - Default: func(t typ.Type) typ.Type { - return t - }, + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + switch tt := t.(type) { + case nil: + return nil + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { + return next + } + t = next + case *typ.Optional: + next := tt.Inner + if next == nil || next == t { + return next + } + t = next + default: + return t } - }) + } + return nil } // Alias unwraps only Alias wrappers, preserving Optional. func Alias(t typ.Type) typ.Type { - return unwrapAliasDepth(t, typ.NewGuard()) -} - -func unwrapAliasDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { - return typ.Visitor[typ.Type]{ - Alias: func(a *typ.Alias) typ.Type { - return unwrapAliasDepth(a.UnaliasedTarget(), next) - }, - Default: func(t typ.Type) typ.Type { - return t - }, + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + alias, ok := t.(*typ.Alias) + if !ok { + return t } - }) + next := alias.UnaliasedTarget() + if next == nil || next == t { + return next + } + t = next + } + return nil } // Optional unwraps Optional to get the inner non-nil type. // Also unwraps Alias. Returns nil if type is nil or Nil. 
func Optional(t typ.Type) typ.Type { - return unwrapOptionalDepth(t, typ.NewGuard()) -} - -func unwrapOptionalDepth(t typ.Type, guard internal.RecursionGuard) typ.Type { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { - return typ.Visitor[typ.Type]{ - Alias: func(a *typ.Alias) typ.Type { - return unwrapOptionalDepth(a.UnaliasedTarget(), next) - }, - Optional: func(o *typ.Optional) typ.Type { - return unwrapOptionalDepth(o.Inner, next) - }, - Default: func(t typ.Type) typ.Type { - return t - }, + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + switch tt := t.(type) { + case nil: + return nil + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { + return next + } + t = next + case *typ.Optional: + next := tt.Inner + if next == nil || next == t { + return next + } + t = next + default: + return t } - }) + } + return nil } // IsOptionalLike returns true if the type is Optional or contains nil. @@ -159,106 +167,116 @@ func IsBuiltinTableTop(t typ.Type) bool { // Function extracts a Function type, unwrapping Alias and Optional. 
func Function(t typ.Type) *typ.Function { - return unwrapFunctionDepth(t, typ.NewGuard()) -} - -func unwrapFunctionDepth(t typ.Type, guard internal.RecursionGuard) *typ.Function { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[*typ.Function] { - return typ.Visitor[*typ.Function]{ - Function: func(fn *typ.Function) *typ.Function { - return fn - }, - Optional: func(o *typ.Optional) *typ.Function { - return unwrapFunctionDepth(o.Inner, next) - }, - Recursive: func(rec *typ.Recursive) *typ.Function { - if rec.Body == nil || rec.Body == rec { - return nil - } - return unwrapFunctionDepth(rec.Body, next) - }, - Alias: func(a *typ.Alias) *typ.Function { - return unwrapFunctionDepth(a.UnaliasedTarget(), next) - }, - Default: func(t typ.Type) *typ.Function { + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + switch tt := t.(type) { + case nil: + return nil + case *typ.Function: + return tt + case *typ.Optional: + next := tt.Inner + if next == nil || next == t { + return nil + } + t = next + case *typ.Recursive: + next := tt.Body + if next == nil || next == t { + return nil + } + t = next + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { return nil - }, + } + t = next + default: + return nil } - }) + } + return nil } // Record extracts a Record type, unwrapping Alias and Optional. 
func Record(t typ.Type) *typ.Record { - return unwrapRecordDepth(t, typ.NewGuard()) -} - -func unwrapRecordDepth(t typ.Type, guard internal.RecursionGuard) *typ.Record { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[*typ.Record] { - return typ.Visitor[*typ.Record]{ - Record: func(rec *typ.Record) *typ.Record { - return rec - }, - Recursive: func(rec *typ.Recursive) *typ.Record { - if rec.Body == nil || rec.Body == rec { - return nil - } - return unwrapRecordDepth(rec.Body, next) - }, - Alias: func(a *typ.Alias) *typ.Record { - return unwrapRecordDepth(a.UnaliasedTarget(), next) - }, - Optional: func(o *typ.Optional) *typ.Record { - return unwrapRecordDepth(o.Inner, next) - }, - Instantiated: func(inst *typ.Instantiated) *typ.Record { - expanded := subst.ExpandInstantiated(inst) - if expanded == nil || expanded == t { - return nil - } - return unwrapRecordDepth(expanded, next) - }, - Default: func(t typ.Type) *typ.Record { + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + switch tt := t.(type) { + case nil: + return nil + case *typ.Record: + return tt + case *typ.Recursive: + next := tt.Body + if next == nil || next == t { + return nil + } + t = next + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { + return nil + } + t = next + case *typ.Optional: + next := tt.Inner + if next == nil || next == t { + return nil + } + t = next + case *typ.Instantiated: + next := subst.ExpandInstantiated(tt) + if next == nil || next == t { return nil - }, + } + t = next + default: + return nil } - }) + } + return nil } // Union extracts a Union type, unwrapping Alias and Optional. 
func Union(t typ.Type) *typ.Union { - return unwrapUnionDepth(t, typ.NewGuard()) -} - -func unwrapUnionDepth(t typ.Type, guard internal.RecursionGuard) *typ.Union { - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[*typ.Union] { - return typ.Visitor[*typ.Union]{ - Union: func(u *typ.Union) *typ.Union { - return u - }, - Recursive: func(rec *typ.Recursive) *typ.Union { - if rec.Body == nil || rec.Body == rec { - return nil - } - return unwrapUnionDepth(rec.Body, next) - }, - Alias: func(a *typ.Alias) *typ.Union { - return unwrapUnionDepth(a.UnaliasedTarget(), next) - }, - Optional: func(o *typ.Optional) *typ.Union { - return unwrapUnionDepth(o.Inner, next) - }, - Instantiated: func(inst *typ.Instantiated) *typ.Union { - expanded := subst.ExpandInstantiated(inst) - if expanded == nil || expanded == t { - return nil - } - return unwrapUnionDepth(expanded, next) - }, - Default: func(t typ.Type) *typ.Union { + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + switch tt := t.(type) { + case nil: + return nil + case *typ.Union: + return tt + case *typ.Recursive: + next := tt.Body + if next == nil || next == t { + return nil + } + t = next + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { return nil - }, + } + t = next + case *typ.Optional: + next := tt.Inner + if next == nil || next == t { + return nil + } + t = next + case *typ.Instantiated: + next := subst.ExpandInstantiated(tt) + if next == nil || next == t { + return nil + } + t = next + default: + return nil } - }) + } + return nil } // IsLiteralString returns true if the type is a string literal. @@ -275,32 +293,32 @@ func IsLiteralString(t typ.Type) bool { // Follows aliases to find the underlying type of the specified kind. // Returns nil if the requested kind is not found or if aliases form a cycle. 
func ToKind(t typ.Type, k kind.Kind) typ.Type { - return unwrapToKindDepth(t, k, typ.NewGuard()) -} - -func unwrapToKindDepth(t typ.Type, k kind.Kind, guard internal.RecursionGuard) typ.Type { - if t == nil { - return nil - } - if t.Kind() == k { - return t - } - return typ.VisitWithGuard(t, guard, nil, func(next internal.RecursionGuard) typ.Visitor[typ.Type] { - return typ.Visitor[typ.Type]{ - Recursive: func(rec *typ.Recursive) typ.Type { - if rec.Body == nil || rec.Body == rec { - return nil - } - return unwrapToKindDepth(rec.Body, k, next) - }, - Alias: func(a *typ.Alias) typ.Type { - return unwrapToKindDepth(a.UnaliasedTarget(), k, next) - }, - Default: func(t typ.Type) typ.Type { + for depth := 0; depth <= typ.DefaultRecursionDepth; depth++ { + t = transparent(t) + if t == nil { + return nil + } + if t.Kind() == k { + return t + } + switch tt := t.(type) { + case *typ.Recursive: + next := tt.Body + if next == nil || next == t { + return nil + } + t = next + case *typ.Alias: + next := tt.UnaliasedTarget() + if next == nil || next == t { return nil - }, + } + t = next + default: + return nil } - }) + } + return nil } // IsNilType returns true if the type is exactly nil. 
@@ -325,3 +343,16 @@ func Instantiated(t typ.Type) typ.Type { } return expanded } + +func transparent(t typ.Type) typ.Type { + for { + annotated, ok := t.(*typ.Annotated) + if !ok { + return t + } + if annotated.Inner == nil || annotated.Inner == t { + return t + } + t = annotated.Inner + } +} From 97ee4c827eb08df10a2bfd024c3cf688d81083d7 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:37:07 -0400 Subject: [PATCH 02/71] Refactor interproc fact domain --- compiler/check/api/doc.go | 2 +- compiler/check/api/facts.go | 32 +- compiler/check/api/facts_test.go | 8 +- .../check/erreffect/error_return_infer.go | 80 +- .../erreffect/error_return_infer_test.go | 11 +- compiler/check/flowbuild/assign/emit.go | 175 ++- compiler/check/flowbuild/assign/emit_test.go | 16 + .../flowbuild/assign/error_return_policy.go | 4 +- compiler/check/flowbuild/assign/infer.go | 670 ++++++++-- compiler/check/flowbuild/assign/infer_test.go | 10 + compiler/check/flowbuild/assign/precision.go | 166 +++ .../check/flowbuild/assign/precision_test.go | 45 + .../check/flowbuild/assign/preflow_synth.go | 154 ++- compiler/check/flowbuild/assign/visibility.go | 151 +++ compiler/check/flowbuild/guard/guard.go | 65 +- compiler/check/flowbuild/guard/guard_test.go | 29 + compiler/check/hooks/assign_check.go | 4 +- compiler/check/hooks/field_check.go | 52 +- compiler/check/hooks/table_check.go | 11 + compiler/check/hooks/table_check_test.go | 38 + compiler/check/infer/interproc/postflow.go | 33 +- compiler/check/infer/nested/processor.go | 18 + .../check/infer/paramhints/param_hints.go | 161 ++- .../infer/paramhints/param_hints_test.go | 311 ++++- compiler/check/infer/paramhints/project.go | 644 ++++++++++ compiler/check/infer/return/infer.go | 513 +++++++- .../check/infer/return/overlay_pipeline.go | 150 ++- compiler/check/infer/return/scc.go | 12 + compiler/check/nested/constructor.go | 11 +- compiler/check/nested/enrich.go | 87 +- compiler/check/nested/enrich_test.go | 70 + 
compiler/check/phase/scope.go | 118 +- compiler/check/pipeline/runner_stages.go | 60 +- compiler/check/returns/callgraph.go | 7 +- compiler/check/returns/callsite.go | 194 ++- compiler/check/returns/callsite_test.go | 129 +- .../returns/container_mutation_merge_test.go | 37 +- compiler/check/returns/domain_law_test.go | 26 + compiler/check/returns/equal_test.go | 20 + compiler/check/returns/function_facts.go | 5 +- compiler/check/returns/join.go | 217 +++- compiler/check/returns/join_test.go | 80 +- compiler/check/returns/signature.go | 1 + compiler/check/returns/types.go | 4 +- compiler/check/returns/widen.go | 347 ++++- compiler/check/returns/widen_test.go | 130 ++ compiler/check/siblings/overlay.go | 47 +- compiler/check/store/facts_clone.go | 162 +++ compiler/check/store/snapshot_inputs.go | 16 +- compiler/check/store/store_test.go | 34 + compiler/check/synth/literals.go | 75 +- compiler/check/synth/ops/logical.go | 14 +- compiler/check/synth/ops/logical_test.go | 18 +- compiler/check/synth/phase/core/params.go | 6 + compiler/check/synth/phase/extract/expr.go | 107 +- .../check/synth/phase/extract/function.go | 16 +- .../check/synth/phase/extract/pipeline.go | 83 +- compiler/check/synth/phase/extract/table.go | 36 +- .../check/synth/phase/extract/table_test.go | 53 + .../synth/phase/extract/union_expected.go | 2 + compiler/check/tests/modules/manifest_test.go | 47 + ...ssert_false_discriminant_narrowing_test.go | 675 ++++++++++ .../external_lint_regression_test.go | 586 +++++++++ .../regression/false_positives_unit_test.go | 1126 +++++++++++++++++ .../http_timeout_option_inference_test.go | 214 ++++ .../imported_record_helper_param_test.go | 44 + .../param_hint_depth_convergence_test.go | 49 + .../regression/deadlock-compiler-lua/main.lua | 30 +- .../deadlock-compiler-lua/manifest.json | 4 +- .../deadlock-dataflow-node/manifest.json | 4 +- types/flow/query.go | 12 +- types/flow/solver_test.go | 106 ++ types/flow/transfer.go | 40 +- types/io/manifest_lookup.go | 
4 +- types/query/core/field.go | 3 + types/query/core/field_test.go | 11 + types/query/core/index.go | 15 +- types/query/core/index_test.go | 15 + types/query/core/operator.go | 3 + types/subtype/subtype.go | 33 +- types/subtype/subtype_test.go | 24 + types/typ/container.go | 1 + types/typ/container_test.go | 7 + types/typ/policy.go | 63 +- types/typ/policy_test.go | 59 + types/typ/rebuild.go | 3 + types/typ/soft.go | 14 +- types/typ/soft_test.go | 3 + types/typ/table_key.go | 59 + 89 files changed, 8601 insertions(+), 430 deletions(-) create mode 100644 compiler/check/flowbuild/assign/visibility.go create mode 100644 compiler/check/hooks/table_check_test.go create mode 100644 compiler/check/infer/paramhints/project.go create mode 100644 compiler/check/store/facts_clone.go create mode 100644 compiler/check/tests/regression/external_lint_regression_test.go create mode 100644 types/typ/table_key.go diff --git a/compiler/check/api/doc.go b/compiler/check/api/doc.go index 7d893a88..bb13b0ef 100644 --- a/compiler/check/api/doc.go +++ b/compiler/check/api/doc.go @@ -31,7 +31,7 @@ // function graph: // // - [FunctionFacts]: Canonical per-function return/signature facts -// - [ParamHints]: Parameter types inferred from call sites +// - [ParamHints]: Effective parameter types inferred from call sites // - [LiteralSigs]: Signatures for anonymous function literals // - [CapturedTypes]: Flow-derived types for captured variables // - [CapturedFieldAssigns]: Field assignments to captured variables diff --git a/compiler/check/api/facts.go b/compiler/check/api/facts.go index 97d88094..8b83dcd8 100644 --- a/compiler/check/api/facts.go +++ b/compiler/check/api/facts.go @@ -12,9 +12,9 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// ParamHints maps function symbols to parameter type hints inferred from call sites. -// When a function is called with known argument types, those types are recorded -// as hints and propagated to the function's parameter declarations. 
+// ParamHints maps function symbols to effective-parameter type hints inferred +// from call sites. For method calls, slot 0 is the receiver/self argument and +// the remaining slots are the source arguments. type ParamHints = map[cfg.SymbolID][]typ.Type // FunctionFact is the canonical function-related interproc fact for one symbol. @@ -86,16 +86,40 @@ type CapturedTypes = map[cfg.SymbolID]typ.Type // to its captured variables, supporting constructor inference patterns. type CapturedFieldAssigns = map[cfg.SymbolID]map[cfg.SymbolID]map[string]typ.Type +// ContainerMutationKind describes the operator used for a captured container +// mutation. Different operators have different abstract transfer functions in +// the parent flow. +type ContainerMutationKind uint8 + +const ( + // ContainerMutationContainerElement widens generic container element types, + // such as channel:send(value) through a ContainerElementUnion effect. + ContainerMutationContainerElement ContainerMutationKind = iota + // ContainerMutationTableElement widens Lua table array/map-array element + // types, such as table.insert(t, value). + ContainerMutationTableElement +) + // ContainerMutation records a container element mutation on a captured variable. // Segments capture the path from the base symbol (e.g., .ch, ["queue"]). type ContainerMutation struct { + Kind ContainerMutationKind Segments []constraint.Segment ValueType typ.Type } // ContainerMutationKey returns the canonical path key for a container mutation. 
func ContainerMutationKey(m ContainerMutation) string { - return constraint.FormatSegments(m.Segments) + return containerMutationKindKey(m.Kind) + ":" + constraint.FormatSegments(m.Segments) +} + +func containerMutationKindKey(kind ContainerMutationKind) string { + switch kind { + case ContainerMutationTableElement: + return "table" + default: + return "container" + } } // CapturedContainerMutations maps nested function symbols to container mutations diff --git a/compiler/check/api/facts_test.go b/compiler/check/api/facts_test.go index dcefa7d1..8062aebe 100644 --- a/compiler/check/api/facts_test.go +++ b/compiler/check/api/facts_test.go @@ -148,8 +148,12 @@ func TestContainerMutationKey(t *testing.T) { }, ValueType: typ.String, } - if got, want := ContainerMutationKey(m), ".queue[\"jobs\"][2]"; got != want { - t.Fatalf("ContainerMutationKey() = %q, want %q", got, want) + if got, want := ContainerMutationKey(m), "container:.queue[\"jobs\"][2]"; got != want { + t.Fatalf("container key = %q, want %q", got, want) + } + m.Kind = ContainerMutationTableElement + if got, want := ContainerMutationKey(m), "table:.queue[\"jobs\"][2]"; got != want { + t.Fatalf("table key = %q, want %q", got, want) } } diff --git a/compiler/check/erreffect/error_return_infer.go b/compiler/check/erreffect/error_return_infer.go index 57acf709..2fb37143 100644 --- a/compiler/check/erreffect/error_return_infer.go +++ b/compiler/check/erreffect/error_return_infer.go @@ -13,36 +13,37 @@ import ( "github.com/wippyai/go-lua/types/typ/unwrap" ) -// ErrorReturnConvention describes a return layout where one slot carries the -// success value and another carries the error. ReturnCount is exact: a function -// with extra returned values is not inferred as this convention. +// ErrorReturnConvention describes a return relation where one slot carries the +// success value and another carries the error. 
The convention is a pair +// relation, not a complete return-vector shape: extra return slots do not affect +// whether the value/error pair can be proven. type ErrorReturnConvention struct { - ValueIndex int - ErrorIndex int - ReturnCount int + ValueIndex int + ErrorIndex int } // CanonicalLuaValueErrorConvention returns the canonical Lua `(value, err)` layout. func CanonicalLuaValueErrorConvention() ErrorReturnConvention { return ErrorReturnConvention{ - ValueIndex: 0, - ErrorIndex: 1, - ReturnCount: 2, + ValueIndex: 0, + ErrorIndex: 1, } } func (c ErrorReturnConvention) valid() bool { return c.ValueIndex >= 0 && c.ErrorIndex >= 0 && - c.ValueIndex != c.ErrorIndex && - c.ValueIndex < c.ReturnCount && - c.ErrorIndex < c.ReturnCount + c.ValueIndex != c.ErrorIndex } -// CanClassifyReturns reports whether returnTypes has the exact shape required -// by this convention before the expensive per-return inverse-pattern proof runs. +func (c ErrorReturnConvention) requiredReturnSlots() int { + return requiredReturnSlots(c.ValueIndex, c.ErrorIndex) +} + +// CanClassifyReturns reports whether returnTypes contains the slots required by +// this convention before the expensive per-return inverse-pattern proof runs. 
func (c ErrorReturnConvention) CanClassifyReturns(returnTypes []typ.Type) bool { - return c.valid() && len(returnTypes) == c.ReturnCount + return c.valid() && len(returnTypes) >= c.requiredReturnSlots() } func (c ErrorReturnConvention) canClassifyFunction(fn *typ.Function) bool { @@ -118,6 +119,10 @@ func HasStrictInverseReturnPattern( if graph == nil || synth == nil { return false } + needed := requiredReturnSlots(valueIdx, errorIdx) + if needed == 0 { + return false + } var sawSuccess bool var sawFailure bool var incompatible bool @@ -136,7 +141,14 @@ func HasStrictInverseReturnPattern( return } - values := synth.ExpandValues(info.Exprs, 2, p) + if delegatesErrorReturn(info, p, synth, valueIdx, errorIdx) { + classified = true + sawSuccess = true + sawFailure = true + return + } + + values := synth.ExpandValues(info.Exprs, needed, p) if valueIdx >= len(values) || errorIdx >= len(values) { incompatible = true return @@ -169,6 +181,42 @@ func HasStrictInverseReturnPattern( return classified && !incompatible && sawSuccess && sawFailure } +func delegatesErrorReturn( + info *cfg.ReturnInfo, + p cfg.Point, + synth api.BaseSynth, + valueIdx int, + errorIdx int, +) bool { + if info == nil || synth == nil || len(info.Exprs) != 1 { + return false + } + call, ok := info.Exprs[0].(*ast.FuncCallExpr) + if !ok || call == nil || call.Func == nil { + return false + } + fn := unwrap.Function(synth.TypeOf(call.Func, p)) + if fn == nil { + return false + } + spec := contract.ExtractSpec(fn) + if spec == nil { + return false + } + er := spec.Effects.GetErrorReturn(valueIdx) + return er != nil && er.ErrorIndex == errorIdx +} + +func requiredReturnSlots(valueIdx int, errorIdx int) int { + if valueIdx < 0 || errorIdx < 0 || valueIdx == errorIdx { + return 0 + } + if valueIdx > errorIdx { + return valueIdx + 1 + } + return errorIdx + 1 +} + func AttachErrorReturnSpec(fn *typ.Function, valueIndex, errorIndex int) *typ.Function { if fn == nil { return fn diff --git 
a/compiler/check/erreffect/error_return_infer_test.go b/compiler/check/erreffect/error_return_infer_test.go index 31e4cef2..62d26e58 100644 --- a/compiler/check/erreffect/error_return_infer_test.go +++ b/compiler/check/erreffect/error_return_infer_test.go @@ -11,13 +11,13 @@ func TestErrorReturnConventionCanClassifyReturns(t *testing.T) { convention := CanonicalLuaValueErrorConvention() if !convention.CanClassifyReturns([]typ.Type{typ.String, typ.Nil}) { - t.Fatal("canonical value/error convention should classify exactly two returns") + t.Fatal("canonical value/error convention should classify two return slots") } if convention.CanClassifyReturns([]typ.Type{typ.String}) { t.Fatal("canonical value/error convention should reject missing error slot") } - if convention.CanClassifyReturns([]typ.Type{typ.String, typ.Nil, typ.Boolean}) { - t.Fatal("canonical value/error convention should reject extra return slots") + if !convention.CanClassifyReturns([]typ.Type{typ.String, typ.Nil, typ.Boolean}) { + t.Fatal("canonical value/error convention should allow unrelated extra return slots") } } @@ -25,9 +25,8 @@ func TestErrorReturnConventionRejectsInvalidLayout(t *testing.T) { t.Parallel() convention := ErrorReturnConvention{ - ValueIndex: 0, - ErrorIndex: 0, - ReturnCount: 1, + ValueIndex: 0, + ErrorIndex: 0, } if convention.CanClassifyReturns([]typ.Type{typ.Nil}) { t.Fatal("convention with overlapping value/error slots should be invalid") diff --git a/compiler/check/flowbuild/assign/emit.go b/compiler/check/flowbuild/assign/emit.go index 8aa731a4..d46dfae8 100644 --- a/compiler/check/flowbuild/assign/emit.go +++ b/compiler/check/flowbuild/assign/emit.go @@ -87,7 +87,9 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect // Uses expandValues with SpecTypes overlay for method call synthesis. 
specNarrowed := CollectSpecNarrowedTypes(fc.Graph, fc.Scopes, synth, symResolver, fc.API, fc.ModuleBindings) preflowBranchSolution := buildPreflowBranchSolution(fc, inputs) - inferredTypes := collectInferredTypes(fc.Graph, fc.Scopes, synth, fc.API, symResolver, specNarrowed, inputs.AnnotatedVars, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, preflowBranchSolution, fc.Services) + inferenceSeeds := mergeSpecTypesInto(nil, inputs.DeclaredTypes) + inferenceSeeds = mergeSpecTypesInto(inferenceSeeds, specNarrowed) + inferredTypes := collectInferredTypes(fc.Graph, fc.Scopes, synth, fc.API, symResolver, inferenceSeeds, inputs.AnnotatedVars, inputs, fc.ModuleBindings, fc.CallCtx, fc.TypeOps, preflowBranchSolution, fc.Services) // Promote inferred parameter types into DeclaredTypes for unannotated params. // This enables bidirectional inference at call sites (e.g., custom assert helpers). if inputs.DeclaredTypes != nil { @@ -103,8 +105,8 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect continue } current := inputs.DeclaredTypes[sym] - if current == nil || current.Kind().IsPlaceholder() { - inputs.DeclaredTypes[sym] = inferred + if merged := mergeUnannotatedParamType(current, inferred); !typ.TypeEquals(current, merged) { + inputs.DeclaredTypes[sym] = merged } } } @@ -140,23 +142,78 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect if target.Kind != cfg.TargetIdent || target.Name == "" || target.Symbol == 0 { return } - if i < len(varTypes) && varTypes[i] != nil { + if i < len(varTypes) && informativeLoopVarType(varTypes[i]) { loopVarTypes[target.Symbol] = varTypes[i] } }) } }) - var overlayTypes api.SpecTypes - overlayTypes = mergeSpecTypesInto(overlayTypes, inputs.DeclaredTypes) - overlayTypes = mergeSpecTypesInto(overlayTypes, inferredTypes) - overlayTypes = mergeSpecTypesInto(overlayTypes, specNarrowed) - overlayTypes = mergeSpecTypesInto(overlayTypes, loopVarTypes) + paramSet := 
paramSymbolSet(fc.Graph) + valueDefs := collectValueDefinitionVersions(fc.Graph) + var overlayScratch api.SpecTypes + overlayTypesAt := func(p cfg.Point) api.SpecTypes { + size := len(inferredTypes) + len(specNarrowed) + len(loopVarTypes) + if inputs != nil { + size += len(inputs.DeclaredTypes) + } + if overlayScratch == nil { + overlayScratch = make(api.SpecTypes, size) + } else { + clear(overlayScratch) + } + if inputs != nil { + for sym, t := range inputs.DeclaredTypes { + overlayScratch[sym] = t + } + } + for sym, t := range loopVarTypes { + overlayScratch[sym] = t + } + for sym, t := range inferredTypes { + if overlayTypeVisibleAt(fc.Graph, valueDefs, paramSet, sym, p) { + overlayScratch[sym] = t + } + } + for sym, t := range specNarrowed { + overlayScratch[sym] = t + } + return overlayScratch + } + overlayTypeAt := func(sym cfg.SymbolID, p cfg.Point) (typ.Type, bool) { + if t, ok := specNarrowed[sym]; ok { + return t, true + } + var declared typ.Type + var hasDeclared bool + if inputs != nil && inputs.DeclaredTypes != nil { + if t, ok := inputs.DeclaredTypes[sym]; ok { + declared = t + hasDeclared = true + if inputs.AnnotatedVars != nil && inputs.AnnotatedVars[sym] { + return t, true + } + } + } + if t, ok := visibleInferredTypeAt(inferredTypes, fc.Graph, valueDefs, paramSet, sym, p); ok { + _, staleLoopVar := loopVarTypes[sym] + if staleLoopVar || inferredOverridesUnannotatedDeclared(t, declared) { + return t, true + } + } + if hasDeclared { + return declared, true + } + if t, ok := loopVarTypes[sym]; ok { + return t, true + } + return nil, false + } // Precompute truthy guards: map from CFG point to paths that are narrowed (non-nil) at that point. // Used during table literal synthesis to unwrap optional types. 
truthyGuards := guard.CollectTruthyGuards(fc.Graph, bindings) typeGuards := guard.CollectTypeGuards(fc.Graph, bindings) - baseSynth := synthWithOverlayAndPreflow(overlayTypes, bindings, inputs, fc.CallCtx, fc.TypeOps, preflowBranchSolution, synth) + baseSynth := synthWithOverlayAndPreflow(overlayTypeAt, bindings, inputs, fc.CallCtx, fc.TypeOps, preflowBranchSolution, synth) structuredWrites := indexStructuredWrites(fc.Graph) var idom map[cfg.Point]cfg.Point if len(structuredWrites) > 0 { @@ -230,7 +287,7 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect if len(info.IterExprs) > 0 && len(info.Targets) > 0 { var varTypes []typ.Type if fc.API != nil { - varTypes = fc.API.InferIterVarsWithSpecTypes(info.IterExprs, len(info.Targets), p, overlayTypes) + varTypes = fc.API.InferIterVarsWithSpecTypes(info.IterExprs, len(info.Targets), p, overlayTypesAt(p)) } // Build const resolver for iterator source extraction constResolver := predicate.BuildConstResolver(inputs, p) @@ -274,7 +331,7 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect } // Use pre-assignment symbol overlays for assignment targets so RHS // synthesis follows Lua evaluation order (`x = f(x, ...)`). 
- rhsOverlay := rhsSpecTypesAtAssignPoint(fc.Graph, info, p, overlayTypes, resolverWithSpec) + rhsOverlay := rhsSpecTypesAtAssignPoint(fc.Graph, info, p, overlayTypesAt(p), resolverWithSpec) rhsOverlay = enrichStructuredOverlayAtPoint(fc.Graph, idom, structuredWrites, p, rhsOverlay, resolverWithSpec, wrappedSynth) values = expandedAssignValues(fc.API, info, p, rhsOverlay) valuesComputed = true @@ -344,6 +401,12 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect assignedType = typ.Unknown } + if source != nil && info.IsLocal && (inputs == nil || inputs.AnnotatedVars == nil || !inputs.AnnotatedVars[sym]) { + if inferred := inferredTypes[sym]; sameExpressionHasMoreEvidence(inferred, assignedType) { + assignedType = inferred + } + } + // Use pre-collected spec-narrowed type if available (via SymbolID) if narrowed, ok := specNarrowed[sym]; ok { assignedType = narrowed @@ -527,15 +590,22 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect // Determine assigned type assignedType := typ.Unknown + if expected := assignmentTargetExpectedType(target, p, wrappedSynth); expected != nil { + if expectedType := synthAssignmentSourceWithExpected(fc.API, source, p, expected); expectedType != nil { + assignedType = expectedType + } + } // First check expanded values for multi-return assignments - ensureValues() - if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { - assignedType = value - } else if source != nil { - if tbl, ok := source.(*ast.TableExpr); ok && wrappedSynth != nil && !tblutil.TableHasFunctionField(tbl) { - assignedType = wrappedSynth(source, p) - } else if wrappedSynth != nil { - assignedType = wrappedSynth(source, p) + if typ.IsAbsentOrUnknown(assignedType) { + ensureValues() + if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { + assignedType = value + } else if source != nil { + if tbl, ok := source.(*ast.TableExpr); ok && wrappedSynth != nil && 
!tblutil.TableHasFunctionField(tbl) { + assignedType = wrappedSynth(source, p) + } else if wrappedSynth != nil { + assignedType = wrappedSynth(source, p) + } } } if assignedType == nil { @@ -588,15 +658,22 @@ func ExtractAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs, keysCollect } // Determine assigned type assignedType := typ.Unknown + if expected := assignmentTargetExpectedType(target, p, wrappedSynth); expected != nil { + if expectedType := synthAssignmentSourceWithExpected(fc.API, source, p, expected); expectedType != nil { + assignedType = expectedType + } + } // First check expanded values for multi-return assignments - ensureValues() - if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { - assignedType = value - } else if source != nil { - if tbl, ok := source.(*ast.TableExpr); ok && wrappedSynth != nil && !tblutil.TableHasFunctionField(tbl) { - assignedType = wrappedSynth(source, p) - } else if wrappedSynth != nil { - assignedType = wrappedSynth(source, p) + if typ.IsAbsentOrUnknown(assignedType) { + ensureValues() + if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { + assignedType = value + } else if source != nil { + if tbl, ok := source.(*ast.TableExpr); ok && wrappedSynth != nil && !tblutil.TableHasFunctionField(tbl) { + assignedType = wrappedSynth(source, p) + } else if wrappedSynth != nil { + assignedType = wrappedSynth(source, p) + } } } if assignedType == nil { @@ -1126,6 +1203,48 @@ func ExtractFuncDefAssignments(fc *fbcore.FlowContext, inputs *flow.Inputs) { }) } +func assignmentTargetExpectedType( + target cfg.AssignTarget, + p cfg.Point, + synth func(ast.Expr, cfg.Point) typ.Type, +) typ.Type { + if target.Expr == nil || synth == nil { + return nil + } + expected := synth(target.Expr, p) + if expected == nil || typ.IsAny(expected) || typ.IsUnknown(expected) || typ.IsSoft(expected, typ.SoftAnnotationPolicy) { + return nil + } + if inner, nilable := typ.SplitNilableFieldType(expected); nilable { + 
return inner + } + return expected +} + +type expectedAssignmentSynth interface { + TypeOfWithExpected(ast.Expr, cfg.Point, typ.Type) typ.Type +} + +func synthAssignmentSourceWithExpected(synthAPI api.SynthAPI, source ast.Expr, p cfg.Point, expected typ.Type) typ.Type { + if synthAPI == nil || source == nil || expected == nil { + return nil + } + switch source.(type) { + case *ast.TableExpr, *ast.FunctionExpr, *ast.LogicalOpExpr: + default: + return nil + } + withExpected, ok := synthAPI.(expectedAssignmentSynth) + if !ok { + return nil + } + inferred := withExpected.TypeOfWithExpected(source, p, expected) + if inferred == nil || typ.IsAbsentOrUnknown(inferred) { + return nil + } + return inferred +} + func isTopLikeResolvedAssignType(t typ.Type) bool { if t == nil { return true diff --git a/compiler/check/flowbuild/assign/emit_test.go b/compiler/check/flowbuild/assign/emit_test.go index df0f9a96..329166b3 100644 --- a/compiler/check/flowbuild/assign/emit_test.go +++ b/compiler/check/flowbuild/assign/emit_test.go @@ -705,6 +705,22 @@ func TestCorrelationsFromFunctionType_ImplicitLuaErrorConvention(t *testing.T) { } } +func TestCorrelationsFromFunctionType_ImplicitLuaErrorConventionWithExtraReturns(t *testing.T) { + fnType := typ.Func(). + Returns(typ.NewOptional(typ.String), typ.NewOptional(typ.LuaError), typ.NewOptional(typ.Boolean)). + Build() + inverse, co := correlationsFromFunctionType(fnType) + if len(co) != 0 { + t.Fatalf("expected no co-correlations, got %v", co) + } + if len(inverse) != 1 { + t.Fatalf("expected one convention-based correlation, got %v", inverse) + } + if inverse[0] != (flow.ReturnCorrelation{ValueIndex: 0, ErrorIndex: 1}) { + t.Fatalf("unexpected convention correlation: %+v", inverse[0]) + } +} + func TestCorrelationsFromFunctionType_ImplicitStringErrorConvention(t *testing.T) { fnType := typ.Func(). Returns(typ.NewOptional(typ.String), typ.NewOptional(typ.String)). 
diff --git a/compiler/check/flowbuild/assign/error_return_policy.go b/compiler/check/flowbuild/assign/error_return_policy.go index e5235dab..f9d6e39e 100644 --- a/compiler/check/flowbuild/assign/error_return_policy.go +++ b/compiler/check/flowbuild/assign/error_return_policy.go @@ -13,7 +13,7 @@ import ( // from a function signature when no explicit effect labels are present. // // Rule: -// - Signature must have exactly two returns. +// - Signature must expose at least the conventional value and error slots. // - Error slot is selected by conventional position with type-based precedence: // - Prefer return[1] when it is Optional or Optional. // - Otherwise allow return[0] only when return[1] is not error-like and @@ -24,7 +24,7 @@ import ( // policy centralized and deterministic. func InferErrorReturnConvention(fnType typ.Type) ([]flow.ReturnCorrelation, []flow.ReturnCorrelation) { fn := unwrap.Function(fnType) - if fn == nil || len(fn.Returns) != 2 { + if fn == nil || len(fn.Returns) < 2 { return nil, nil } diff --git a/compiler/check/flowbuild/assign/infer.go b/compiler/check/flowbuild/assign/infer.go index 6e3b70f0..978085fd 100644 --- a/compiler/check/flowbuild/assign/infer.go +++ b/compiler/check/flowbuild/assign/infer.go @@ -62,8 +62,10 @@ import ( synthpkg "github.com/wippyai/go-lua/compiler/check/synth" "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/internal" + "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/flow" + flowjoin "github.com/wippyai/go-lua/types/flow/join" "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -155,6 +157,7 @@ func collectInferredTypes( if len(structuredWrites) > 0 { idom = cfganalysis.ComputeImmediateDominators(graph.CFG()) } + valueDefs := collectValueDefinitionVersions(graph) bindings := graph.Bindings() if moduleBindings == nil { @@ -279,15 +282,25 @@ 
func collectInferredTypes( continue } for _, target := range entry.info.Targets { - if target.Kind != cfg.TargetIdent || target.Symbol == 0 { + var sym cfg.SymbolID + switch target.Kind { + case cfg.TargetIdent: + sym = target.Symbol + case cfg.TargetField: + if paramSet[target.BaseSymbol] { + sym = target.BaseSymbol + } + } + if sym == 0 { continue } - assignIdxByTargetSym[target.Symbol] = append(assignIdxByTargetSym[target.Symbol], idx) + assignIdxByTargetSym[sym] = append(assignIdxByTargetSym[sym], idx) } } callArgSymbolsByIdx := make([][]cfg.SymbolID, len(calls)) - callIdxByParamArgSym := make(map[cfg.SymbolID][]int) + callReceiverSymbolByIdx := make([]cfg.SymbolID, len(calls)) + callIdxByArgSym := make(map[cfg.SymbolID][]int) callIdxByRefSym := make(map[cfg.SymbolID][]int) for idx, entry := range calls { if entry.info == nil { @@ -295,12 +308,23 @@ func collectInferredTypes( } argSymbols := normalizedCallArgSymbols(entry.info, bindings) callArgSymbolsByIdx[idx] = argSymbols + receiverSym := normalizedCallReceiverSymbol(entry.info, bindings) + callReceiverSymbolByIdx[idx] = receiverSym + if receiverSym != 0 { + callIdxByArgSym[receiverSym] = append(callIdxByArgSym[receiverSym], idx) + } for _, sym := range argSymbols { - if sym == 0 || !paramSet[sym] { + if sym == 0 { continue } - callIdxByParamArgSym[sym] = append(callIdxByParamArgSym[sym], idx) + callIdxByArgSym[sym] = append(callIdxByArgSym[sym], idx) + } + for _, arg := range entry.info.Args { + argPath := path.FromExprWithBindings(arg, nil, bindings) + if argPath.Symbol != 0 && len(argPath.Segments) > 0 && paramSet[argPath.Symbol] { + callIdxByArgSym[argPath.Symbol] = append(callIdxByArgSym[argPath.Symbol], idx) + } } for _, sym := range callRefSymbols(entry.info, bindings) { @@ -324,10 +348,19 @@ func collectInferredTypes( for _, entry := range assigns { info := entry.info for _, target := range info.Targets { - if target.Kind != cfg.TargetIdent || target.Symbol == 0 { + var targetSymID cfg.SymbolID + 
switch target.Kind { + case cfg.TargetIdent: + targetSymID = target.Symbol + case cfg.TargetField: + if paramSet[target.BaseSymbol] { + targetSymID = target.BaseSymbol + } + } + if targetSymID == 0 { continue } - targetSym := uint64(target.Symbol) + targetSym := uint64(targetSymID) if deps[targetSym] == nil { deps[targetSym] = nil // ensure node exists } @@ -393,6 +426,43 @@ func collectInferredTypes( deps[targetKey] = append(deps[targetKey], uint64(ref)) } } + for _, entry := range calls { + info := entry.info + if info == nil { + continue + } + var calleeRefs []cfg.SymbolID + collectExprSymbols(info.Callee, bindings, &calleeRefs) + collectExprSymbols(info.Receiver, bindings, &calleeRefs) + calleeRefs = dedupeSymbolIDs(calleeRefs) + if len(calleeRefs) == 0 { + continue + } + addArgExpectationDeps := func(sym cfg.SymbolID) { + if sym == 0 { + return + } + targetKey := uint64(sym) + if deps[targetKey] == nil { + deps[targetKey] = nil + } + for _, ref := range calleeRefs { + if ref == 0 || ref == sym { + continue + } + deps[targetKey] = append(deps[targetKey], uint64(ref)) + } + } + for _, sym := range normalizedCallArgSymbols(info, bindings) { + addArgExpectationDeps(sym) + } + for _, arg := range info.Args { + argPath := path.FromExprWithBindings(arg, nil, bindings) + if argPath.Symbol != 0 && len(argPath.Segments) > 0 && paramSet[argPath.Symbol] { + addArgExpectationDeps(argPath.Symbol) + } + } + } // Deduplicate edges for sym, edges := range deps { @@ -416,7 +486,7 @@ func collectInferredTypes( sccs := internal.ComputeSCCs(deps) // Process each SCC in topological order - for sccIdx, scc := range sccs { + for _, scc := range sccs { if len(scc) == 0 { continue } @@ -432,7 +502,7 @@ func collectInferredTypes( markEpoch++ sccAssignIdx := make([]int, 0, len(scc)) - sccParamCallIdx := make([]int, 0, len(scc)) + sccArgCallIdx := make([]int, 0, len(scc)) sccMutatorCallIdx := make([]int, 0, len(scc)) for _, sym := range sccSyms { for _, idx := range 
assignIdxByTargetSym[sym] { @@ -442,12 +512,12 @@ func collectInferredTypes( assignIdxMarks[idx] = markEpoch sccAssignIdx = append(sccAssignIdx, idx) } - for _, idx := range callIdxByParamArgSym[sym] { + for _, idx := range callIdxByArgSym[sym] { if paramCallIdxMarks[idx] == markEpoch { continue } paramCallIdxMarks[idx] = markEpoch - sccParamCallIdx = append(sccParamCallIdx, idx) + sccArgCallIdx = append(sccArgCallIdx, idx) } for _, idx := range callIdxByRefSym[sym] { if mutatorCallIdxMarks[idx] == markEpoch { @@ -466,7 +536,7 @@ func collectInferredTypes( overlayScratch = mergeSpecTypesSoftInto(overlayScratch, inferred, specTypes) overlay := overlayScratch - wrappedSynth := synthWithInferenceOverlay(graph, overlay, funcSigTypes, paramSet, annotated, bindings, inputs, callCtx, typeOps, preflowBranchSolution, synth) + wrappedSynth := synthWithInferenceOverlay(graph, inferred, specTypes, funcSigTypes, valueDefs, paramSet, annotated, bindings, inputs, callCtx, typeOps, preflowBranchSolution, synth) callSynthFor := func(p cfg.Point, info *cfg.CallInfo) func(ast.Expr, cfg.Point) typ.Type { if info == nil { return wrappedSynth @@ -486,15 +556,16 @@ func collectInferredTypes( return t, ok } } - callOverlay := rhsSpecTypesAtAssignPoint(graph, owner, p, overlay, func(point cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { - if t, ok := overlay[sym]; ok && t != nil && !t.Kind().IsPlaceholder() { + callOverlayBase := inferenceOverlayAtPoint(graph, p, inferred, specTypes, funcSigTypes, valueDefs, paramSet) + callOverlay := rhsSpecTypesAtAssignPoint(graph, owner, p, callOverlayBase, func(point cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { + if t, ok := callOverlayBase[sym]; ok && t != nil && !t.Kind().IsPlaceholder() { return t, true } return rhsResolver(point, sym) }) callOverlay = enrichStructuredOverlayAtPoint(graph, idom, structuredWrites, p, callOverlay, rhsResolver, wrappedSynth) - return synthWithInferenceOverlay(graph, callOverlay, funcSigTypes, paramSet, annotated, 
bindings, inputs, callCtx, typeOps, preflowBranchSolution, synth) + return synthWithOverlayAndPreflow(mapOverlayTypeAt(callOverlay), bindings, inputs, callCtx, typeOps, preflowBranchSolution, wrappedBaseForInference(bindings, paramSet, annotated, synth)) } // Infer expected argument types for a call using the call inference pipeline. @@ -671,8 +742,9 @@ func collectInferredTypes( return t, ok } } - rhsOverlay := rhsSpecTypesAtAssignPoint(graph, info, p, overlay, func(point cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { - if t, ok := overlay[sym]; ok && t != nil && !t.Kind().IsPlaceholder() { + rhsOverlayBase := inferenceOverlayAtPoint(graph, p, inferred, specTypes, funcSigTypes, valueDefs, paramSet) + rhsOverlay := rhsSpecTypesAtAssignPoint(graph, info, p, rhsOverlayBase, func(point cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { + if t, ok := rhsOverlayBase[sym]; ok && t != nil && !t.Kind().IsPlaceholder() { return t, true } return rhsResolver(point, sym) @@ -698,12 +770,76 @@ func collectInferredTypes( inferred[target.Symbol] = joined changed = true } + case cfg.TargetField: + if target.BaseSymbol == 0 || len(target.FieldPath) == 0 { + continue + } + if !paramSet[target.BaseSymbol] { + continue + } + if !sccSet[target.BaseSymbol] { + continue + } + if annotated != nil && annotated[target.BaseSymbol] { + continue + } + assignedType := typ.Unknown + if !valuesComputed { + rhsResolver := symResolver + if rhsResolver == nil { + rhsResolver = func(_ cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { + t, ok := overlay[sym] + return t, ok + } + } + rhsOverlayBase := inferenceOverlayAtPoint(graph, p, inferred, specTypes, funcSigTypes, valueDefs, paramSet) + rhsOverlay := rhsSpecTypesAtAssignPoint(graph, info, p, rhsOverlayBase, func(point cfg.Point, sym cfg.SymbolID) (typ.Type, bool) { + if t, ok := rhsOverlayBase[sym]; ok && t != nil && !t.Kind().IsPlaceholder() { + return t, true + } + return rhsResolver(point, sym) + }) + rhsOverlay = 
enrichStructuredOverlayAtPoint(graph, idom, structuredWrites, p, rhsOverlay, rhsResolver, wrappedSynth) + values = expandedAssignValues(synthAPI, info, p, rhsOverlay) + valuesComputed = true + } + if value := assignValueAt(values, i); !typ.IsAbsentOrUnknown(value) { + assignedType = value + } else if wrappedSynth != nil && source != nil { + assignedType = wrappedSynth(source, p) + } + assignedType = resolve.Ref(assignedType, sc) + if typ.IsAbsentOrUnknown(assignedType) { + continue + } + segments := make([]constraint.Segment, 0, len(target.FieldPath)) + for _, field := range target.FieldPath { + if field == "" { + continue + } + segments = append(segments, constraint.Segment{Kind: constraint.SegmentField, Name: field}) + } + if len(segments) == 0 { + continue + } + old := inferred[target.BaseSymbol] + updated := mergeExpectedAtPath(old, segments, assignedType, paramSet[target.BaseSymbol]) + if updated == nil { + continue + } + if !typ.TypeEquals(old, updated) { + inferred[target.BaseSymbol] = updated + changed = true + } } } } - // Infer parameter types from call argument expectations. - for _, idx := range sccParamCallIdx { + // Infer unannotated symbol types from call argument expectations. + // Parameters keep the traditional bidirectional behavior. Locals only + // accept expected types while still top-like, so a concrete assignment + // is not hidden by a later incompatible call. 
+ for _, idx := range sccArgCallIdx { entry := calls[idx] p := entry.p info := entry.info @@ -716,40 +852,76 @@ func collectInferredTypes( sc = scopes[graph.Entry()] } expectedArgs, expectedVariadic := inferExpectedArgs(p, info, synthForCall) + + if receiverSym := callReceiverSymbolByIdx[idx]; receiverSym != 0 && sccSet[receiverSym] { + if annotated == nil || !annotated[receiverSym] { + expected := expectedReceiverTypeForMethod(callCtx, typeOps, info) + if expected != nil && !expected.Kind().IsPlaceholder() { + old := inferred[receiverSym] + if paramSet[receiverSym] || callExpectationCanRefineLocal(old) { + joined := mergeCallExpectation(old, expected, paramSet[receiverSym]) + if !typ.TypeEquals(old, joined) { + inferred[receiverSym] = joined + changed = true + } + } + } + } + } + callArgSymbols := callArgSymbolsByIdx[idx] for i := range info.Args { + expected := expectedArgAt(i, expectedArgs, expectedVariadic) var sym cfg.SymbolID if i < len(callArgSymbols) { sym = callArgSymbols[i] } - if sym == 0 || !sccSet[sym] { - continue - } - if !paramSet[sym] { - continue - } - if annotated != nil && annotated[sym] { - continue - } - expected := expectedArgAt(i, expectedArgs, expectedVariadic) - if typ.IsAbsentOrUnknown(expected) { - // Fall back to actual argument type when no expected type is available. - if i < len(info.Args) && info.Args[i] != nil { - actual := synthForCall(info.Args[i], p) - actual = resolve.Ref(actual, sc) - if actual != nil && !actual.Kind().IsPlaceholder() { - expected = actual + if sym != 0 && sccSet[sym] { + if annotated != nil && annotated[sym] { + continue + } + if typ.IsAbsentOrUnknown(expected) { + // Fall back to actual argument type when no expected type is available. 
+ if i < len(info.Args) && info.Args[i] != nil { + actual := synthForCall(info.Args[i], p) + actual = resolve.Ref(actual, sc) + if actual != nil && !actual.Kind().IsPlaceholder() { + expected = actual + } } } + if expected == nil || expected.Kind().IsPlaceholder() { + continue + } + old := inferred[sym] + if !paramSet[sym] && !callExpectationCanRefineLocal(old) { + continue + } + joined := mergeCallExpectation(old, expected, paramSet[sym]) + if !typ.TypeEquals(old, joined) { + inferred[sym] = joined + changed = true + } } - if expected == nil || expected.Kind().IsPlaceholder() { - continue - } - old := inferred[sym] - joined := joinInferredType(old, expected) - if !typ.TypeEquals(old, joined) { - inferred[sym] = joined - changed = true + if i < len(info.Args) && expected != nil && !expected.Kind().IsPlaceholder() { + argPath := path.FromExprWithBindings(info.Args[i], nil, bindings) + if argPath.Symbol != 0 && len(argPath.Segments) > 0 && sccSet[argPath.Symbol] { + if !paramSet[argPath.Symbol] { + continue + } + if annotated != nil && annotated[argPath.Symbol] { + continue + } + old := inferred[argPath.Symbol] + if !paramSet[argPath.Symbol] && !callExpectationCanRefineLocal(old) { + continue + } + joined := mergePathCallExpectation(old, argPath.Segments, expected, paramSet[argPath.Symbol]) + if !typ.TypeEquals(old, joined) { + inferred[argPath.Symbol] = joined + changed = true + } + } } } } @@ -782,18 +954,20 @@ func collectInferredTypes( // Handle indexed targets (t[k]) even when key is non-const. 
if attr, ok := targetExpr.(*ast.AttrGetExpr); ok { - baseSym := callsite.SymbolOrCreateFieldFromExpr(attr.Object, bindings) - if baseSym != 0 && sccSet[baseSym] { - keyType := wrappedSynth(attr.Key, p) - keyType = resolve.Ref(keyType, sc) - keyType = canonicalDynamicKeyType(keyType) - old := inferred[baseSym] - newType := flow.WidenMapValueArray(old, keyType, valueType) - if newType != nil && !typ.TypeEquals(old, newType) { - inferred[baseSym] = newType - changed = true + if _, static := path.StaticKeySegment(attr.Key); !static { + baseSym := callsite.SymbolOrCreateFieldFromExpr(attr.Object, bindings) + if baseSym != 0 && sccSet[baseSym] { + keyType := wrappedSynth(attr.Key, p) + keyType = resolve.Ref(keyType, sc) + keyType = canonicalDynamicKeyType(keyType) + old := inferred[baseSym] + newType := flow.WidenMapValueArray(old, keyType, valueType) + if newType != nil && !typ.TypeEquals(old, newType) { + inferred[baseSym] = newType + changed = true + } + continue } - continue } } @@ -803,8 +977,11 @@ func collectInferredTypes( if !sccSet[targetPath.Symbol] { continue } + if len(targetPath.Segments) > 0 && !paramSet[targetPath.Symbol] { + continue + } old := inferred[targetPath.Symbol] - newType := flow.WidenArrayElementType(old, valueType, typ.JoinPreferNonSoft) + newType := widenArrayElementAtPath(old, targetPath.Segments, valueType) if newType == nil || typ.TypeEquals(old, newType) { continue } @@ -819,20 +996,16 @@ func collectInferredTypes( } // Widen ALL symbols in non-converged SCC to Unknown (except annotated). - // This is sound: partial types may be under-approximations. + // This is sound: partial types may be under-approximations. Local + // preflow inference is a hint source, so the fallback is intentionally + // internal; surfacing it as a lint warning produces false positives for + // dynamic but valid Lua patterns. 
if !converged { for _, sym := range sccSyms { if annotated != nil && annotated[sym] { continue } inferred[sym] = typ.Unknown - if inputs != nil { - inputs.WideningEvents = append(inputs.WideningEvents, flow.WideningEvent{ - Symbol: sym, - SCCIndex: sccIdx, - SCC: sccSyms, - }) - } } } } @@ -845,6 +1018,11 @@ func collectInferredTypes( if annotated != nil && annotated[sym] { continue } + if inputs != nil && inputs.DeclaredTypes != nil { + if declared := inputs.DeclaredTypes[sym]; !typ.IsAbsentOrUnknown(declared) { + continue + } + } if t, ok := inferred[sym]; !ok || typ.IsAbsentOrUnknown(t) { inferred[sym] = typ.Any } @@ -869,6 +1047,37 @@ func normalizedCallArgSymbols(info *cfg.CallInfo, bindings *bind.BindingTable) [ return out } +func normalizedCallReceiverSymbol(info *cfg.CallInfo, bindings *bind.BindingTable) cfg.SymbolID { + if info == nil || info.Method == "" { + return 0 + } + if info.ReceiverSymbol != 0 { + return info.ReceiverSymbol + } + if bindings == nil { + return 0 + } + return callsite.SymbolFromExpr(info.Receiver, bindings) +} + +func expectedReceiverTypeForMethod(ctx *db.QueryContext, typeOps core.TypeOps, info *cfg.CallInfo) typ.Type { + if info == nil || info.Method == "" { + return nil + } + if typeOps == nil { + return nil + } + methodType, ok := typeOps.Method(ctx, typ.String, info.Method) + if !ok || methodType == nil { + return nil + } + fn, ok := methodType.(*typ.Function) + if !ok || len(fn.Params) == 0 || !typ.TypeEquals(fn.Params[0].Type, typ.String) { + return nil + } + return typ.String +} + func callRefSymbols(info *cfg.CallInfo, bindings *bind.BindingTable) []cfg.SymbolID { if info == nil || bindings == nil { return nil @@ -937,8 +1146,10 @@ func dedupeSymbolIDs(refs []cfg.SymbolID) []cfg.SymbolID { func synthWithInferenceOverlay( graph *cfg.Graph, - overlay map[cfg.SymbolID]typ.Type, + inferred map[cfg.SymbolID]typ.Type, + seedTypes map[cfg.SymbolID]typ.Type, funcSigTypes map[cfg.SymbolID]typ.Type, + valueDefs 
map[symbolVersionKey]struct{}, paramSet map[cfg.SymbolID]bool, annotated map[cfg.SymbolID]bool, bindings *bind.BindingTable, @@ -948,18 +1159,47 @@ func synthWithInferenceOverlay( preflow *flow.Solution, base func(ast.Expr, cfg.Point) typ.Type, ) func(ast.Expr, cfg.Point) typ.Type { - _ = graph - mergedOverlay := make(map[cfg.SymbolID]typ.Type, len(overlay)+len(funcSigTypes)) - for sym, t := range funcSigTypes { - if t != nil { - mergedOverlay[sym] = t + lookup := func(sym cfg.SymbolID, p cfg.Point) (typ.Type, bool) { + var seed typ.Type + var hasSeed bool + if t, ok := seedTypes[sym]; ok { + seed = t + hasSeed = true + if annotated != nil && annotated[sym] { + return t, true + } } + if _, ok := inferred[sym]; ok { + if t, visible := visibleInferredTypeAt(inferred, graph, valueDefs, paramSet, sym, p); visible { + if t == nil { + return nil, true + } + if inferredOverridesUnannotatedDeclared(t, seed) { + return t, true + } + } + } + if hasSeed { + return seed, true + } + if t, ok := funcSigTypes[sym]; ok { + if overlayTypeVisibleAt(graph, valueDefs, paramSet, sym, p) { + return t, true + } + } + return nil, false } - for sym, t := range overlay { - mergedOverlay[sym] = t - } - wrappedBase := func(expr ast.Expr, p cfg.Point) typ.Type { + return synthWithOverlayAndPreflow(lookup, bindings, inputs, callCtx, typeOps, preflow, wrappedBaseForInference(bindings, paramSet, annotated, base)) +} + +func wrappedBaseForInference( + bindings *bind.BindingTable, + paramSet map[cfg.SymbolID]bool, + annotated map[cfg.SymbolID]bool, + base func(ast.Expr, cfg.Point) typ.Type, +) func(ast.Expr, cfg.Point) typ.Type { + return func(expr ast.Expr, p cfg.Point) typ.Type { if ident, ok := expr.(*ast.IdentExpr); ok && bindings != nil { if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { if paramSet[sym] && (annotated == nil || !annotated[sym]) { @@ -972,8 +1212,6 @@ func synthWithInferenceOverlay( } return base(expr, p) } - - return synthWithOverlayAndPreflow(mergedOverlay, bindings, 
inputs, callCtx, typeOps, preflow, wrappedBase) } func assignmentOwningSourceCall(assigns []*cfg.AssignInfo, call *cfg.CallInfo) *cfg.AssignInfo { @@ -1096,7 +1334,285 @@ func joinInferredType(old, next typ.Type) typ.Type { } return subtype.WidenForInference(next) } - return typ.JoinPreferNonSoft(old, next) + return subtype.WidenForInference(flowjoin.Types(old, next)) +} + +func callExpectationCanRefineLocal(old typ.Type) bool { + return old == nil || + typ.IsUnknown(old) || + typ.IsSoft(old, typ.SoftAnnotationPolicy) +} + +func mergeCallExpectation(old, expected typ.Type, isParam bool) typ.Type { + if isParam { + if expectedParamTypeDominates(old, expected) { + return expected + } + return joinInferredType(old, expected) + } + if callExpectationCanRefineLocal(old) { + return expected + } + return joinInferredType(old, expected) +} + +func expectedParamTypeDominates(old, expected typ.Type) bool { + if typ.IsAbsentOrUnknown(old) || typ.IsAbsentOrUnknown(expected) { + return false + } + if typ.IsAny(old) || typ.IsAny(expected) || expected.Kind().IsPlaceholder() { + return false + } + if subtype.IsSubtype(old, expected) { + return true + } + oldRec := recordForPathMerge(old) + expectedRec := recordForPathMerge(expected) + if oldRec == nil || expectedRec == nil { + return false + } + return recordEvidenceCompatibleWithExpected(oldRec, expectedRec) +} + +func recordEvidenceCompatibleWithExpected(old, expected *typ.Record) bool { + if old == nil || expected == nil { + return false + } + for _, field := range old.Fields { + expectedField := expected.GetField(field.Name) + if expectedField == nil { + if expected.Open { + continue + } + return false + } + if fieldEvidenceIsUnresolved(field.Type) { + continue + } + expectedType := expectedField.Type + if expectedField.Optional { + expectedType = typ.NewOptional(expectedType) + } + fieldType := field.Type + if field.Optional { + fieldType = typ.NewOptional(fieldType) + } + if !subtype.IsSubtype(fieldType, expectedType) { + 
return false + } + } + if old.HasMapComponent() { + if !expected.HasMapComponent() { + return false + } + if !fieldEvidenceIsUnresolved(old.MapKey) && !subtype.IsSubtype(old.MapKey, expected.MapKey) { + return false + } + if !fieldEvidenceIsUnresolved(old.MapValue) && !subtype.IsSubtype(old.MapValue, expected.MapValue) { + return false + } + } + return true +} + +func fieldEvidenceIsUnresolved(t typ.Type) bool { + if typ.IsAbsentOrUnknown(t) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return fieldEvidenceIsUnresolved(v.Target) + case *typ.Record: + return len(v.Fields) == 0 && !v.HasMapComponent() + default: + return false + } +} + +func mergePathCallExpectation(old typ.Type, segments []constraint.Segment, expected typ.Type, isParam bool) typ.Type { + if len(segments) == 0 { + return mergeCallExpectation(old, expected, isParam) + } + if expected == nil || expected.Kind().IsPlaceholder() || typ.IsAbsentOrUnknown(expected) { + return old + } + return mergeExpectedAtPath(old, segments, expected, isParam) +} + +func mergeExpectedAtPath(base typ.Type, segments []constraint.Segment, expected typ.Type, isParam bool) typ.Type { + if len(segments) == 0 { + return mergeCallExpectation(base, expected, isParam) + } + seg := segments[0] + field, ok := segmentFieldName(seg) + if !ok { + return base + } + + rec := recordForPathMerge(base) + child := typ.Type(nil) + wasOptional := isParam + if rec != nil { + if existing := rec.GetField(field); existing != nil { + child = existing.Type + wasOptional = wasOptional || existing.Optional + } else if rec.HasMapComponent() && rec.MapValue != nil { + child = rec.MapValue + wasOptional = true + } + } + if child == nil { + child = typ.Unknown + } + mergedChild := mergeExpectedAtPath(child, segments[1:], expected, isParam) + if mergedChild == nil { + return base + } + return setRecordField(base, field, mergedChild, wasOptional) +} + +func widenArrayElementAtPath(base typ.Type, segments 
[]constraint.Segment, element typ.Type) typ.Type { + if len(segments) == 0 { + return flow.WidenArrayElementType(base, element, typ.JoinPreferNonSoft) + } + seg := segments[0] + field, ok := segmentFieldName(seg) + if !ok { + return base + } + + rec := recordForPathMerge(base) + child := typ.Type(nil) + optional := false + if rec != nil { + if existing := rec.GetField(field); existing != nil { + child = existing.Type + optional = existing.Optional + } else if rec.HasMapComponent() && rec.MapValue != nil { + child = rec.MapValue + optional = true + } + } + updated := widenArrayElementAtPath(child, segments[1:], element) + if updated == nil { + return base + } + return setRecordField(base, field, updated, optional) +} + +func segmentFieldName(seg constraint.Segment) (string, bool) { + switch seg.Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + return seg.Name, seg.Name != "" + default: + return "", false + } +} + +func recordForPathMerge(t typ.Type) *typ.Record { + for { + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + t = v.Target + case *typ.Optional: + t = v.Inner + case *typ.Record: + return v + default: + return nil + } + } +} + +func setRecordField(base typ.Type, field string, fieldType typ.Type, optional bool) typ.Type { + if field == "" || fieldType == nil { + return base + } + switch v := typ.UnwrapAnnotated(base).(type) { + case *typ.Alias: + updated := setRecordField(v.Target, field, fieldType, optional) + if updated == nil || typ.TypeEquals(updated, v.Target) { + return base + } + return typ.NewAlias(v.Name, updated) + case *typ.Union: + updated := make([]typ.Type, 0, len(v.Members)) + changed := false + for _, member := range v.Members { + if member == nil || typ.IsAny(member) || typ.TypeEquals(member, typ.Nil) { + updated = append(updated, member) + continue + } + next := setRecordField(member, field, fieldType, optional) + if next == nil { + next = member + } + if !typ.TypeEquals(member, next) { + changed = true 
+ } + updated = append(updated, next) + } + if !changed { + return base + } + return typ.NewUnion(updated...) + case *typ.Optional: + updated := setRecordField(v.Inner, field, fieldType, optional) + if updated == nil || typ.TypeEquals(updated, v.Inner) { + return base + } + return typ.NewOptional(updated) + case *typ.Record: + return rebuildRecordWithField(v, field, fieldType, optional) + default: + builder := typ.NewRecord().SetOpen(true) + if optional { + builder.OptField(field, fieldType) + } else { + builder.Field(field, fieldType) + } + return builder.Build() + } +} + +func rebuildRecordWithField(rec *typ.Record, field string, fieldType typ.Type, optional bool) typ.Type { + builder := typ.NewRecord() + if rec.Open { + builder.SetOpen(true) + } + if rec.Metatable != nil { + builder.Metatable(rec.Metatable) + } + if rec.HasMapComponent() { + builder.MapComponent(rec.MapKey, rec.MapValue) + } + + added := false + for _, f := range rec.Fields { + if f.Name != field { + addRecordField(builder, f.Name, f.Type, f.Optional, f.Readonly) + continue + } + addRecordField(builder, f.Name, fieldType, optional || f.Optional, f.Readonly) + added = true + } + if !added { + addRecordField(builder, field, fieldType, optional, false) + } + return builder.Build() +} + +func addRecordField(builder *typ.RecordBuilder, name string, fieldType typ.Type, optional, readonly bool) { + switch { + case optional && readonly: + builder.OptReadonlyField(name, fieldType) + case optional: + builder.OptField(name, fieldType) + case readonly: + builder.ReadonlyField(name, fieldType) + default: + builder.Field(name, fieldType) + } } func typeContains(haystack, needle typ.Type) bool { diff --git a/compiler/check/flowbuild/assign/infer_test.go b/compiler/check/flowbuild/assign/infer_test.go index 1248dcee..925f0ee7 100644 --- a/compiler/check/flowbuild/assign/infer_test.go +++ b/compiler/check/flowbuild/assign/infer_test.go @@ -381,7 +381,9 @@ func 
TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { synth := synthWithInferenceOverlay( nil, map[cfg.SymbolID]typ.Type{aSym: typ.String}, + nil, map[cfg.SymbolID]typ.Type{aSym: typ.Number}, + nil, paramSet, nil, bindings, @@ -396,9 +398,11 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, nil, nil, map[cfg.SymbolID]typ.Type{aSym: typ.Number}, + nil, paramSet, nil, bindings, @@ -413,6 +417,8 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, + nil, nil, nil, nil, @@ -430,6 +436,8 @@ func TestSynthWithInferenceOverlay_PriorityAndParamFallback(t *testing.T) { } synth = synthWithInferenceOverlay( + nil, + nil, nil, nil, nil, @@ -460,6 +468,8 @@ func TestSynthWithInferenceOverlay_PreservesNilOverlayEntries(t *testing.T) { nil, nil, nil, + nil, + nil, bindings, nil, nil, diff --git a/compiler/check/flowbuild/assign/precision.go b/compiler/check/flowbuild/assign/precision.go index f7ef8a60..3190d12f 100644 --- a/compiler/check/flowbuild/assign/precision.go +++ b/compiler/check/flowbuild/assign/precision.go @@ -46,6 +46,9 @@ func preferPreciseDirectSourceType( if preferNamedEquivalentDirectType(precise, assignedType) { return precise } + if sameExpressionHasMoreEvidence(precise, assignedType) { + return precise + } return assignedType } if typ.IsAny(assignedType) && !typ.IsAny(precise) { @@ -69,3 +72,166 @@ func isNamedIdentityType(t typ.Type) bool { return false } } + +func sameExpressionHasMoreEvidence(precise, assigned typ.Type) bool { + improved, ok := compareSameExpressionEvidence(precise, assigned, 0) + return ok && improved +} + +func mergeUnannotatedParamType(current, inferred typ.Type) typ.Type { + if typ.IsAbsentOrUnknown(inferred) || typ.IsAny(inferred) { + return current + } + if current == nil || current.Kind().IsPlaceholder() || typ.IsUnknown(current) { + return inferred + } + if typ.IsAny(current) 
|| subtype.IsSubtype(current, inferred) { + return current + } + return inferred +} + +func inferredOverridesUnannotatedDeclared(inferred, declared typ.Type) bool { + if typ.IsAbsentOrUnknown(inferred) { + return false + } + if declared == nil || typ.IsAbsentOrUnknown(declared) || declared.Kind().IsPlaceholder() || typ.IsSoft(declared, typ.SoftAnnotationPolicy) { + return true + } + if typ.IsAny(declared) { + return false + } + if subtype.IsSubtype(inferred, declared) && !subtype.IsSubtype(declared, inferred) { + return true + } + if sameExpressionHasMoreEvidence(inferred, declared) { + return true + } + return false +} + +func informativeLoopVarType(t typ.Type) bool { + return t != nil && !typ.IsAbsentOrUnknown(t) && !typ.IsAny(t) && !t.Kind().IsPlaceholder() +} + +func compareSameExpressionEvidence(precise, assigned typ.Type, depth int) (bool, bool) { + if depth > typ.DefaultRecursionDepth { + return false, false + } + if typ.TypeEquals(precise, assigned) { + return false, true + } + if typ.IsAbsentOrUnknown(assigned) { + return !typ.IsAbsentOrUnknown(precise), true + } + if typ.IsAbsentOrUnknown(precise) { + return false, false + } + + switch p := precise.(type) { + case *typ.Alias: + return compareSameExpressionEvidence(p.UnaliasedTarget(), assigned, depth+1) + case *typ.Ref: + if a, ok := assigned.(*typ.Alias); ok && a.Name == p.Name && p.Module == "" { + return false, true + } + } + switch a := assigned.(type) { + case *typ.Alias: + return compareSameExpressionEvidence(precise, a.UnaliasedTarget(), depth+1) + case *typ.Ref: + if p, ok := precise.(*typ.Alias); ok && p.Name == a.Name && a.Module == "" { + return false, true + } + } + + switch p := precise.(type) { + case *typ.Record: + a, ok := assigned.(*typ.Record) + if !ok { + return false, false + } + return compareRecordEvidence(p, a, depth+1) + case *typ.Optional: + a, ok := assigned.(*typ.Optional) + if !ok { + return false, false + } + return compareSameExpressionEvidence(p.Inner, a.Inner, depth+1) + 
case *typ.Tuple: + a, ok := assigned.(*typ.Tuple) + if !ok || len(p.Elements) != len(a.Elements) { + return false, false + } + improved := false + for i := range p.Elements { + fieldImproved, ok := compareSameExpressionEvidence(p.Elements[i], a.Elements[i], depth+1) + if !ok { + return false, false + } + improved = improved || fieldImproved + } + return improved, true + case *typ.Array: + a, ok := assigned.(*typ.Array) + if !ok { + return false, false + } + return compareSameExpressionEvidence(p.Element, a.Element, depth+1) + case *typ.Map: + a, ok := assigned.(*typ.Map) + if !ok { + return false, false + } + keyImproved, ok := compareSameExpressionEvidence(p.Key, a.Key, depth+1) + if !ok { + return false, false + } + valueImproved, ok := compareSameExpressionEvidence(p.Value, a.Value, depth+1) + if !ok { + return false, false + } + return keyImproved || valueImproved, true + default: + return false, false + } +} + +func compareRecordEvidence(precise, assigned *typ.Record, depth int) (bool, bool) { + if precise == nil || assigned == nil { + return false, false + } + if precise.Open != assigned.Open { + return false, false + } + if (precise.HasMapComponent()) != (assigned.HasMapComponent()) { + return false, false + } + improved := false + for _, assignedField := range assigned.Fields { + preciseField := precise.GetField(assignedField.Name) + if preciseField == nil { + return false, false + } + if preciseField.Optional != assignedField.Optional || preciseField.Readonly != assignedField.Readonly { + return false, false + } + fieldImproved, ok := compareSameExpressionEvidence(preciseField.Type, assignedField.Type, depth+1) + if !ok { + return false, false + } + improved = improved || fieldImproved + } + if assigned.HasMapComponent() { + keyImproved, ok := compareSameExpressionEvidence(precise.MapKey, assigned.MapKey, depth+1) + if !ok { + return false, false + } + valueImproved, ok := compareSameExpressionEvidence(precise.MapValue, assigned.MapValue, depth+1) + if !ok 
{ + return false, false + } + improved = improved || keyImproved || valueImproved + } + return improved, true +} diff --git a/compiler/check/flowbuild/assign/precision_test.go b/compiler/check/flowbuild/assign/precision_test.go index b4fdc423..4b9e1943 100644 --- a/compiler/check/flowbuild/assign/precision_test.go +++ b/compiler/check/flowbuild/assign/precision_test.go @@ -47,3 +47,48 @@ func TestPreferPreciseDirectSourceType_DoesNotReplaceNamedAssignedType(t *testin t.Fatalf("expected existing named assigned type to remain, got %s", typ.FormatShort(got)) } } + +func TestPreferPreciseDirectSourceType_RefinesUnknownRecordFieldFromSameExpression(t *testing.T) { + assigned := typ.NewRecord(). + Field("headers", typ.NewMap(typ.String, typ.String)). + Field("timeout", typ.Unknown). + Build() + precise := typ.NewRecord(). + Field("headers", typ.NewMap(typ.String, typ.String)). + Field("timeout", typ.Number). + Build() + + got := preferPreciseDirectSourceType( + assigned, + &ast.TableExpr{}, + 0, + nil, + func(ast.Expr, cfg.Point) typ.Type { return precise }, + true, + ) + if !typ.TypeEquals(got, precise) { + t.Fatalf("expected same-expression concrete field evidence to win, got %s", typ.FormatShort(got)) + } +} + +func TestPreferPreciseDirectSourceType_DoesNotDropAssignedRecordEvidence(t *testing.T) { + assigned := typ.NewRecord(). + Field("headers", typ.NewMap(typ.String, typ.String)). + Field("timeout", typ.Unknown). + Build() + precise := typ.NewRecord(). + Field("timeout", typ.Number). 
+ Build() + + got := preferPreciseDirectSourceType( + assigned, + &ast.TableExpr{}, + 0, + nil, + func(ast.Expr, cfg.Point) typ.Type { return precise }, + true, + ) + if !typ.TypeEquals(got, assigned) { + t.Fatalf("expected assigned evidence to remain when direct type drops fields, got %s", typ.FormatShort(got)) + } +} diff --git a/compiler/check/flowbuild/assign/preflow_synth.go b/compiler/check/flowbuild/assign/preflow_synth.go index 5f677926..a71fd618 100644 --- a/compiler/check/flowbuild/assign/preflow_synth.go +++ b/compiler/check/flowbuild/assign/preflow_synth.go @@ -8,11 +8,13 @@ import ( fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" fbpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" "github.com/wippyai/go-lua/compiler/check/flowbuild/predicate" + "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) type narrowResolverAdapter struct { @@ -36,6 +38,15 @@ func (r narrowResolverAdapter) Index(t typ.Type, key typ.Type) (typ.Type, bool) return r.ops.Index(r.ctx, t, key) } +type overlayTypeAt func(cfg.SymbolID, cfg.Point) (typ.Type, bool) + +func mapOverlayTypeAt(overlay map[cfg.SymbolID]typ.Type) overlayTypeAt { + return func(sym cfg.SymbolID, _ cfg.Point) (typ.Type, bool) { + t, ok := overlay[sym] + return t, ok + } +} + // buildPreflowBranchSolution solves only branch/numeric edge facts that are // already available before assignment extraction completes. // @@ -63,7 +74,7 @@ func buildPreflowBranchSolution(fc *fbcore.FlowContext, inputs *flow.Inputs) *fl // This keeps assignment inference on the canonical synthesis path while letting // recursive field/index expressions observe already-provable branch facts. 
func synthWithOverlayAndPreflow( - overlay map[cfg.SymbolID]typ.Type, + overlay overlayTypeAt, bindings *bind.BindingTable, inputs *flow.Inputs, callCtx *db.QueryContext, @@ -80,8 +91,10 @@ func synthWithOverlayAndPreflow( if ident, ok := expr.(*ast.IdentExpr); ok && bindings != nil { if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { - if t, exists := overlay[sym]; exists { - return t + if overlay != nil { + if t, exists := overlay(sym, p); exists { + return t + } } } } @@ -90,11 +103,22 @@ func synthWithOverlayAndPreflow( constResolver := predicate.BuildConstResolver(inputs, p) if path := fbpath.FromExprWithBindings(expr, constResolver, bindings); !path.IsEmpty() { if narrowed := preflow.NarrowedTypeAt(p, path); !typ.IsAbsentOrUnknown(narrowed) { + if attr, ok := expr.(*ast.AttrGetExpr); ok && typeOps != nil { + if declared := declaredAttrReadType(attr, p, synth, callCtx, typeOps); declared != nil { + refined, ok := refinePathFactWithDeclaredType(narrowed, declared, callCtx, typeOps) + if !ok { + goto skipPreflowPathFact + } + narrowed = refined + } + } return narrowed } } } + skipPreflowPathFact: + if attr, ok := expr.(*ast.AttrGetExpr); ok && typeOps != nil { objType := synth(attr.Object, p) if !typ.IsAbsentOrUnknown(objType) { @@ -117,6 +141,38 @@ func synthWithOverlayAndPreflow( } } + if call, ok := expr.(*ast.FuncCallExpr); ok && typeOps != nil { + if result := synthCallWithOverlay(call, p, synth, callCtx, typeOps); !typ.IsAbsentOrUnknown(result) { + return result + } + if base != nil { + if direct := base(expr, p); !typ.IsAbsentOrUnknown(direct) { + return direct + } + } + return typ.Unknown + } + + if logical, ok := expr.(*ast.LogicalOpExpr); ok { + left := synth(logical.Lhs, p) + right := synth(logical.Rhs, p) + var result typ.Type + switch logical.Operator { + case "and": + result = ops.LogicalAndTyped(left, right) + case "or": + result = ops.LogicalOrTyped(left, right) + default: + result = typ.Unknown + } + if (typ.IsAbsentOrUnknown(result) || 
typ.IsAny(result)) && base != nil { + if direct := base(expr, p); !typ.IsAbsentOrUnknown(direct) && !typ.IsAny(direct) { + return direct + } + } + return result + } + if base == nil { return nil } @@ -125,3 +181,95 @@ func synthWithOverlayAndPreflow( return synth } + +func synthCallWithOverlay( + call *ast.FuncCallExpr, + p cfg.Point, + synth func(ast.Expr, cfg.Point) typ.Type, + callCtx *db.QueryContext, + typeOps core.TypeOps, +) typ.Type { + if call == nil || synth == nil || typeOps == nil { + return nil + } + args := make([]typ.Type, len(call.Args)) + for i, arg := range call.Args { + args[i] = synth(arg, p) + } + def := ops.CallDef{ + Args: args, + Query: typeOps, + } + if call.Method != "" { + def.IsMethod = true + def.MethodName = call.Method + def.Receiver = synth(call.Receiver, p) + } else { + def.Callee = synth(call.Func, p) + } + result := ops.CallWithGenericInference(callCtx, def) + if len(result.Returns) > 0 { + return result.Returns[0] + } + return ops.ExtractFirstValue(result.Type) +} + +func declaredAttrReadType( + attr *ast.AttrGetExpr, + p cfg.Point, + synth func(ast.Expr, cfg.Point) typ.Type, + callCtx *db.QueryContext, + typeOps core.TypeOps, +) typ.Type { + if attr == nil || synth == nil || typeOps == nil { + return nil + } + objType := synth(attr.Object, p) + if typ.IsAbsentOrUnknown(objType) { + return nil + } + switch key := attr.Key.(type) { + case *ast.StringExpr: + if ft, ok := typeOps.Field(callCtx, objType, key.Value); ok { + return ft + } + if it, ok := typeOps.Index(callCtx, objType, typ.LiteralString(key.Value)); ok { + return it + } + default: + keyType := synth(attr.Key, p) + if !typ.IsAbsentOrUnknown(keyType) { + if it, ok := typeOps.Index(callCtx, objType, keyType); ok { + return it + } + } + } + return nil +} + +func refinePathFactWithDeclaredType(narrowed, declared typ.Type, callCtx *db.QueryContext, typeOps core.TypeOps) (typ.Type, bool) { + if narrowed == nil || declared == nil { + return narrowed, true + } + narrowed = 
unwrap.Alias(narrowed) + declared = unwrap.Alias(declared) + if narrowed == nil || declared == nil || declared.Kind().IsPlaceholder() { + return narrowed, true + } + if typeOps == nil { + return nil, false + } + if typeOps.IsSubtype(callCtx, narrowed, declared) { + return narrowed, true + } + declaredNonNil := narrow.RemoveNil(declared) + if !typ.IsNever(declaredNonNil) { + if typeOps.IsSubtype(callCtx, declaredNonNil, narrowed) { + return declaredNonNil, true + } + if unwrap.Function(declaredNonNil) != nil && unwrap.Function(narrowed) != nil { + return declaredNonNil, true + } + } + return nil, false +} diff --git a/compiler/check/flowbuild/assign/visibility.go b/compiler/check/flowbuild/assign/visibility.go new file mode 100644 index 00000000..f543a5f0 --- /dev/null +++ b/compiler/check/flowbuild/assign/visibility.go @@ -0,0 +1,151 @@ +package assign + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/types/typ" +) + +type symbolVersionKey struct { + sym cfg.SymbolID + id int +} + +func paramSymbolSet(graph *cfg.Graph) map[cfg.SymbolID]bool { + if graph == nil { + return nil + } + params := graph.ParamSymbols() + if len(params) == 0 { + return nil + } + out := make(map[cfg.SymbolID]bool, len(params)) + for _, sym := range params { + if sym != 0 { + out[sym] = true + } + } + return out +} + +func collectValueDefinitionVersions(graph *cfg.Graph) map[symbolVersionKey]struct{} { + if graph == nil { + return nil + } + out := make(map[symbolVersionKey]struct{}) + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info == nil { + return + } + info.EachTargetSource(func(_ int, target cfg.AssignTarget, source ast.Expr) { + if source == nil || target.Kind != cfg.TargetIdent || target.Symbol == 0 { + return + } + if ver := graph.VisibleVersion(p, target.Symbol); ver.Symbol != 0 && ver.ID != 0 { + out[symbolVersionKey{sym: target.Symbol, 
id: ver.ID}] = struct{}{} + } + }) + }) + graph.EachFuncDef(func(p cfg.Point, info *cfg.FuncDefInfo) { + if info == nil || info.Symbol == 0 || info.FuncExpr == nil { + return + } + if ver := graph.VisibleVersion(p, info.Symbol); ver.Symbol != 0 && ver.ID != 0 { + out[symbolVersionKey{sym: info.Symbol, id: ver.ID}] = struct{}{} + } + }) + if len(out) == 0 { + return nil + } + return out +} + +func overlayTypeVisibleAt( + graph *cfg.Graph, + valueDefs map[symbolVersionKey]struct{}, + paramSet map[cfg.SymbolID]bool, + sym cfg.SymbolID, + p cfg.Point, +) bool { + if sym == 0 { + return false + } + if graph == nil { + return true + } + if paramSet[sym] { + return true + } + ver := graph.VisibleVersion(p, sym) + if ver.Symbol == 0 || ver.ID == 0 { + return false + } + _, ok := valueDefs[symbolVersionKey{sym: sym, id: ver.ID}] + return ok +} + +func visibleInferredTypeAt( + inferred api.SpecTypes, + graph *cfg.Graph, + valueDefs map[symbolVersionKey]struct{}, + paramSet map[cfg.SymbolID]bool, + sym cfg.SymbolID, + p cfg.Point, +) (typ.Type, bool) { + t, ok := inferred[sym] + if !ok { + return nil, false + } + if !overlayTypeVisibleAt(graph, valueDefs, paramSet, sym, p) { + return nil, false + } + return t, true +} + +func mergeVisibleInferredTypes( + out api.SpecTypes, + inferred api.SpecTypes, + graph *cfg.Graph, + valueDefs map[symbolVersionKey]struct{}, + paramSet map[cfg.SymbolID]bool, + p cfg.Point, +) api.SpecTypes { + if len(inferred) == 0 { + return out + } + for sym, t := range inferred { + if !overlayTypeVisibleAt(graph, valueDefs, paramSet, sym, p) { + continue + } + if out == nil { + out = make(api.SpecTypes, len(inferred)) + } + out[sym] = t + } + return out +} + +func inferenceOverlayAtPoint( + graph *cfg.Graph, + p cfg.Point, + inferred api.SpecTypes, + seedTypes api.SpecTypes, + funcSigTypes map[cfg.SymbolID]typ.Type, + valueDefs map[symbolVersionKey]struct{}, + paramSet map[cfg.SymbolID]bool, +) api.SpecTypes { + var out api.SpecTypes + out = 
mergeSpecTypesInto(out, seedTypes) + for sym, t := range funcSigTypes { + if !overlayTypeVisibleAt(graph, valueDefs, paramSet, sym, p) { + continue + } + if out == nil { + out = make(api.SpecTypes, len(funcSigTypes)) + } + out[sym] = t + } + out = mergeVisibleInferredTypes(out, inferred, graph, valueDefs, paramSet, p) + return out +} diff --git a/compiler/check/flowbuild/guard/guard.go b/compiler/check/flowbuild/guard/guard.go index fc9b8d58..3c283cba 100644 --- a/compiler/check/flowbuild/guard/guard.go +++ b/compiler/check/flowbuild/guard/guard.go @@ -19,6 +19,44 @@ type TruthyPathKey struct { Field string } +// TypeProbe describes a builtin type(expr) equality check. +type TypeProbe struct { + Expr ast.Expr + Key narrow.TypeKey +} + +// ExtractTypeEqualityProbe extracts the runtime type predicate from a +// `type(expr) == "kind"` comparison. It is intentionally expression-only so +// synthesis, field validation, and flow guard collection share one parser. +func ExtractTypeEqualityProbe(expr ast.Expr) (TypeProbe, bool) { + rel, ok := expr.(*ast.RelationalOpExpr) + if !ok || rel == nil || rel.Operator != "==" { + return TypeProbe{}, false + } + if probe, ok := typeProbeSide(rel.Lhs, rel.Rhs); ok { + return probe, true + } + return typeProbeSide(rel.Rhs, rel.Lhs) +} + +// IsTypeCall reports whether call has builtin type(expr) shape. +func IsTypeCall(call *ast.FuncCallExpr) bool { + if call == nil || callsite.IsMethodLikeExpr(call) || len(call.Args) != 1 { + return false + } + ident, ok := call.Func.(*ast.IdentExpr) + return ok && ident != nil && ident.Value == "type" +} + +// TypeForTypeKey returns the broad runtime type represented by a builtin +// type() result key. +func TypeForTypeKey(key narrow.TypeKey) typ.Type { + if kind, ok := key.BuiltinKind(); ok { + return narrow.TypeForKind(kind) + } + return typ.Unknown +} + // CollectTruthyGuards scans the CFG for conditions that establish truthy guards // and propagates them to dominated points. 
Used to narrow optional types. func CollectTruthyGuards(graph *cfg.Graph, bindings *bind.BindingTable) map[cfg.Point]map[TruthyPathKey]bool { @@ -363,29 +401,32 @@ func extractTypeGuard(expr ast.Expr, bindings *bind.BindingTable) (TruthyPathKey } func typeGuardPathAndKey(typeExpr, keyExpr ast.Expr, bindings *bind.BindingTable) (TruthyPathKey, narrow.TypeKey, bool) { - call, ok := typeExpr.(*ast.FuncCallExpr) - if !ok || call == nil || callsite.IsMethodLikeExpr(call) || len(call.Args) != 1 { + probe, ok := typeProbeSide(typeExpr, keyExpr) + if !ok { return TruthyPathKey{}, narrow.TypeKey{}, false } - ident, ok := call.Func.(*ast.IdentExpr) - if !ok || ident.Value != "type" { + + key, ok := TruthyKeyFromExpr(probe.Expr, bindings) + if !ok || key.Field == "" { return TruthyPathKey{}, narrow.TypeKey{}, false } + return key, probe.Key, true +} +func typeProbeSide(typeExpr, keyExpr ast.Expr) (TypeProbe, bool) { + call, ok := typeExpr.(*ast.FuncCallExpr) + if !ok || !IsTypeCall(call) { + return TypeProbe{}, false + } typeName, ok := typeStringLiteral(keyExpr) if !ok { - return TruthyPathKey{}, narrow.TypeKey{}, false + return TypeProbe{}, false } typeKey, ok := narrow.KnownBuiltinTypeKey(typeName) if !ok { - return TruthyPathKey{}, narrow.TypeKey{}, false - } - - key, ok := TruthyKeyFromExpr(call.Args[0], bindings) - if !ok || key.Field == "" { - return TruthyPathKey{}, narrow.TypeKey{}, false + return TypeProbe{}, false } - return key, typeKey, true + return TypeProbe{Expr: call.Args[0], Key: typeKey}, true } func typeStringLiteral(expr ast.Expr) (string, bool) { diff --git a/compiler/check/flowbuild/guard/guard_test.go b/compiler/check/flowbuild/guard/guard_test.go index 57f2b145..4a8b154e 100644 --- a/compiler/check/flowbuild/guard/guard_test.go +++ b/compiler/check/flowbuild/guard/guard_test.go @@ -317,6 +317,35 @@ func TestCollectTypeGuards_TypeNotEqReturnPropagatesFallthrough(t *testing.T) { } } +func TestExtractTypeEqualityProbe(t *testing.T) { + target := 
&ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "page"}, + Key: &ast.StringExpr{Value: "placement"}, + } + expr := &ast.RelationalOpExpr{ + Operator: "==", + Lhs: &ast.FuncCallExpr{ + Func: &ast.IdentExpr{Value: "type"}, + Args: []ast.Expr{target}, + }, + Rhs: &ast.StringExpr{Value: "string"}, + } + + probe, ok := guard.ExtractTypeEqualityProbe(expr) + if !ok { + t.Fatal("expected type equality probe") + } + if probe.Expr != target { + t.Fatal("expected probe expression to be preserved") + } + if probe.Key != narrow.BuiltinTypeKey("string") { + t.Fatalf("probe key = %v, want string key", probe.Key) + } + if got := guard.TypeForTypeKey(probe.Key); !typ.TypeEquals(got, typ.String) { + t.Fatalf("probe type = %v, want string", got) + } +} + func TestNarrowTableFieldsByGuard_TypeGuardNarrowsAny(t *testing.T) { valueExpr := &ast.AttrGetExpr{ Object: &ast.IdentExpr{Value: "payload"}, diff --git a/compiler/check/hooks/assign_check.go b/compiler/check/hooks/assign_check.go index 046b5c8b..a6577a98 100644 --- a/compiler/check/hooks/assign_check.go +++ b/compiler/check/hooks/assign_check.go @@ -157,7 +157,9 @@ func CheckAssignments(graph *cfg.Graph, scopes map[cfg.Point]*scope.State, narro valueType := narrowSynth.SynthWithExpected(source, p, declaredType) if sourceUsesTarget { if pre := preAssignmentExprTypeForAssign(source, p, narrowSynth, graph, declaredType); pre != nil { - valueType = pre + if !typ.IsAbsentOrUnknown(pre) { + valueType = pre + } } } if valueType == nil { diff --git a/compiler/check/hooks/field_check.go b/compiler/check/hooks/field_check.go index 68be760d..8eaf1347 100644 --- a/compiler/check/hooks/field_check.go +++ b/compiler/check/hooks/field_check.go @@ -19,6 +19,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/flowbuild/guard" "github.com/wippyai/go-lua/compiler/check/flowbuild/path" 
"github.com/wippyai/go-lua/compiler/check/scope" checksynth "github.com/wippyai/go-lua/compiler/check/synth" @@ -182,6 +183,10 @@ func checkFieldExpr(expr ast.Expr, p cfg.Point, narrowView api.BaseSynth, resolv case *ast.FuncCallExpr: diags = append(diags, checkFieldExpr(e.Func, p, narrowView, resolver, seen, sourceName)...) for _, arg := range e.Args { + if guard.IsTypeCall(e) { + diags = append(diags, checkTypeProbeArg(arg, p, narrowView, resolver, seen, sourceName)...) + continue + } diags = append(diags, checkFieldExpr(arg, p, narrowView, resolver, seen, sourceName)...) } case *ast.TableExpr: @@ -189,6 +194,24 @@ func checkFieldExpr(expr ast.Expr, p cfg.Point, narrowView api.BaseSynth, resolv diags = append(diags, checkFieldExpr(f.Value, p, narrowView, resolver, seen, sourceName)...) } case *ast.LogicalOpExpr: + if e.Operator == "and" { + if probe, ok := guard.ExtractTypeEqualityProbe(e.Lhs); ok && resolver.bindings != nil { + probeType := guard.TypeForTypeKey(probe.Key) + diags = append(diags, checkTypeProbeArg(probe.Expr, p, narrowView, resolver, seen, sourceName)...) + probePath := path.FromExprWithBindings(probe.Expr, nil, resolver.bindings) + if !probePath.IsEmpty() { + localView := &localNarrowView{ + base: narrowView, + bindings: resolver.bindings, + overridePath: probePath, + overrideType: probeType, + } + localResolver := fieldResolverImpl{view: localView, synth: resolver.synth, bindings: resolver.bindings} + diags = append(diags, checkFieldExpr(e.Rhs, p, localView, localResolver, seen, sourceName)...) + return diags + } + } + } diags = append(diags, checkFieldExpr(e.Lhs, p, narrowView, resolver, seen, sourceName)...) 
lhsType := narrowView.TypeOf(e.Lhs, p) if e.Operator == "and" && ops.IsFalsy(lhsType) { @@ -357,7 +380,7 @@ func preAssignmentExprType(graph *cfg.Graph, expr ast.Expr, p cfg.Point, view ap func checkArithmetic(e *ast.ArithmeticOpExpr, p cfg.Point, narrowView api.BaseSynth, sourceName string) []diag.Diagnostic { check := func(expr ast.Expr) *diag.Diagnostic { t := narrowView.TypeOf(expr, p) - if t == nil || ops.IsNumeric(t) { + if t == nil || ops.IsNumeric(t) || typ.IsNever(t) { return nil } msg := "cannot perform arithmetic on " + typ.FormatShort(t) + ", expected number" @@ -534,6 +557,9 @@ func checkAttrGet(e *ast.AttrGetExpr, p cfg.Point, narrowView api.BaseSynth, res var diags []diag.Diagnostic diags = append(diags, checkFieldExpr(e.Object, p, narrowView, resolver, seen, sourceName)...) + if localViewOverridesExpr(narrowView, e) { + return diags + } objType := narrowView.TypeOf(e.Object, p) @@ -593,6 +619,30 @@ func checkAttrGet(e *ast.AttrGetExpr, p cfg.Point, narrowView api.BaseSynth, res return diags } +func checkTypeProbeArg(expr ast.Expr, p cfg.Point, narrowView api.BaseSynth, resolver fieldResolverImpl, seen map[ast.Expr]bool, sourceName string) []diag.Diagnostic { + attr, ok := expr.(*ast.AttrGetExpr) + if !ok || attr == nil { + return checkFieldExpr(expr, p, narrowView, resolver, seen, sourceName) + } + return checkFieldExpr(attr.Object, p, narrowView, resolver, seen, sourceName) +} + +func localViewOverridesExpr(view api.BaseSynth, expr ast.Expr) bool { + for { + localView, ok := view.(*localNarrowView) + if !ok || localView == nil { + return false + } + if localView.bindings != nil { + exprPath := path.FromExprWithBindings(expr, nil, localView.bindings) + if !exprPath.IsEmpty() && exprPath.Equal(localView.overridePath) { + return true + } + } + view = localView.base + } +} + func isStringKeyExpr(key ast.Expr) bool { if key == nil { return false diff --git a/compiler/check/hooks/table_check.go b/compiler/check/hooks/table_check.go index 
75a13a3e..a44c8e13 100644 --- a/compiler/check/hooks/table_check.go +++ b/compiler/check/hooks/table_check.go @@ -217,6 +217,9 @@ func checkTableWithOptionalRelax(fields []ops.FieldDef, arrayElems []typ.Type, e if err.Message == "missing required field" && unwrap.IsOptionalLike(err.Expected) { continue } + if err.Message == "field type mismatch" && unresolvedTableEvidence(err.Got) { + continue + } if err.Message == "unexpected field" { continue } @@ -240,6 +243,14 @@ func checkTableWithOptionalRelax(fields []ops.FieldDef, arrayElems []typ.Type, e return false, reason } +func unresolvedTableEvidence(t typ.Type) bool { + if typ.IsAbsentOrUnknown(t) { + return true + } + rec := unwrap.Record(t) + return rec != nil && len(rec.Fields) == 0 && !rec.HasMapComponent() +} + func unionAllRecordLike(u *typ.Union) bool { if u == nil { return false diff --git a/compiler/check/hooks/table_check_test.go b/compiler/check/hooks/table_check_test.go new file mode 100644 index 00000000..56cd2cf9 --- /dev/null +++ b/compiler/check/hooks/table_check_test.go @@ -0,0 +1,38 @@ +package hooks + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/synth/ops" + "github.com/wippyai/go-lua/types/typ" +) + +func TestCheckTableWithOptionalRelax_ContextualizesUnresolvedFieldEvidence(t *testing.T) { + expected := typ.NewRecord(). + Field("timeout", typ.NewOptional(typ.Number)). + Build() + + ok, reason := checkTableWithOptionalRelax( + []ops.FieldDef{{Name: "timeout", Type: typ.Unknown}}, + nil, + expected, + ) + if !ok { + t.Fatalf("expected unresolved field evidence to accept contextual type, got %q", reason) + } +} + +func TestCheckTableWithOptionalRelax_RejectsConcreteMismatch(t *testing.T) { + expected := typ.NewRecord(). + Field("timeout", typ.NewOptional(typ.Number)). 
+ Build() + + ok, _ := checkTableWithOptionalRelax( + []ops.FieldDef{{Name: "timeout", Type: typ.String}}, + nil, + expected, + ) + if ok { + t.Fatal("expected concrete mismatched field evidence to fail") + } +} diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index c78f6e91..8f545c79 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -299,7 +299,7 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc return sym != 0 && store.FunctionRefBySym(sym) != nil } collectCallHints := func(p cfg.Point, info *cfg.CallInfo) { - if info == nil || len(info.Args) == 0 { + if info == nil || checkcallsite.RuntimeArgCount(info) == 0 { return } callTargets := preAssignTargets[info] @@ -389,7 +389,30 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc } deltaHints := make(api.ParamHints) - hints := paramhints.EnsureHintCapacity(nil, len(info.Args)) + runtimeArgCount := checkcallsite.RuntimeArgCount(info) + hints := paramhints.EnsureHintCapacity(nil, runtimeArgCount) + for runtimeIdx := 0; runtimeIdx < runtimeArgCount; runtimeIdx++ { + arg := checkcallsite.RuntimeArgAt(info, runtimeIdx) + if arg == nil { + continue + } + var argType typ.Type + if checkcallsite.IsMethodCallInfo(info) && runtimeIdx == 0 { + argType = def.Receiver + } else { + argIdx := runtimeIdx + if checkcallsite.IsMethodCallInfo(info) { + argIdx-- + } + if argIdx >= 0 && argIdx < len(argTypes) { + argType = argTypes[argIdx] + } + } + if argType == nil { + argType = result.NarrowSynth.TypeOf(arg, p) + } + hints, _ = paramhints.MergeCallArgHintAt(hints, runtimeIdx, argType, typ.JoinPreferNonSoft, true) + } for i, arg := range info.Args { if arg == nil { continue @@ -413,12 +436,6 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc } } } - - argType := argTypes[i] - if argType == nil { - argType = 
result.NarrowSynth.TypeOf(arg, p) - } - hints, _ = paramhints.MergeCallArgHintAt(hints, i, argType, typ.JoinPreferNonSoft, true) } if len(hints) > 0 { deltaHints[calleeSym] = hints diff --git a/compiler/check/infer/nested/processor.go b/compiler/check/infer/nested/processor.go index 947b87c2..49bb2cc4 100644 --- a/compiler/check/infer/nested/processor.go +++ b/compiler/check/infer/nested/processor.go @@ -279,6 +279,24 @@ func (p *Processor) processNestedFunction( var synthFn func(ast.Expr, cfg.Point) typ.Type if result.NarrowSynth != nil { synthFn = result.NarrowSynth.TypeOf + if result.Graph != nil { + if bindings := result.Graph.Bindings(); bindings != nil { + baseSynth := synthFn + synthFn = func(expr ast.Expr, p cfg.Point) typ.Type { + if ident, ok := expr.(*ast.IdentExpr); ok { + if sym, found := bindings.SymbolOf(ident); found && sym != 0 { + if result.Facts != nil { + tv := result.Facts.EffectiveTypeAt(p, sym) + if tv.State == flow.StateResolved && !typ.IsAbsentOrUnknown(tv.Type) { + return tv.Type + } + } + } + } + return baseSynth(expr, p) + } + } + } } fields := nested.CollectConstructorFields(result.Graph, selfSym, synthFn) if len(fields) > 0 { diff --git a/compiler/check/infer/paramhints/param_hints.go b/compiler/check/infer/paramhints/param_hints.go index 7097f781..26101517 100644 --- a/compiler/check/infer/paramhints/param_hints.go +++ b/compiler/check/infer/paramhints/param_hints.go @@ -8,6 +8,7 @@ import ( "github.com/wippyai/go-lua/internal" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) type HintJoinFn func(prev, next typ.Type) typ.Type @@ -23,7 +24,7 @@ func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Functio if i >= len(hints) || hints[i] == nil { continue } - if i < len(fn.ParList.Types) && fn.ParList.Types[i] != nil { + if srcIdx, hasSource := signatureSourceParamIndex(fn, sig, i); hasSource && srcIdx < len(fn.ParList.Types) && 
fn.ParList.Types[srcIdx] != nil { if !typ.IsRefinableAnnotation(p.Type) { continue } @@ -40,7 +41,8 @@ func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Functio for i, p := range sig.Params { paramType := p.Type if i < len(hints) && hints[i] != nil { - annotated := i < len(fn.ParList.Types) && fn.ParList.Types[i] != nil + srcIdx, hasSource := signatureSourceParamIndex(fn, sig, i) + annotated := hasSource && srcIdx < len(fn.ParList.Types) && fn.ParList.Types[srcIdx] != nil if !annotated || typ.IsRefinableAnnotation(paramType) { paramType = hints[i] } @@ -69,6 +71,33 @@ func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Functio return builder.Build() } +func signatureSourceParamIndex(fn *ast.FunctionExpr, sig *typ.Function, paramIdx int) (int, bool) { + if fn == nil || fn.ParList == nil || sig == nil || paramIdx < 0 || paramIdx >= len(sig.Params) { + return 0, false + } + if signatureHasImplicitSelf(fn, sig) { + if paramIdx == 0 { + return 0, false + } + srcIdx := paramIdx - 1 + return srcIdx, srcIdx >= 0 && srcIdx < len(fn.ParList.Names) + } + return paramIdx, paramIdx < len(fn.ParList.Names) +} + +func signatureHasImplicitSelf(fn *ast.FunctionExpr, sig *typ.Function) bool { + if fn == nil || fn.ParList == nil || sig == nil || len(sig.Params) == 0 { + return false + } + if sig.Params[0].Name != "self" { + return false + } + if len(fn.ParList.Names) > 0 && fn.ParList.Names[0] == "self" { + return false + } + return len(sig.Params) == len(fn.ParList.Names)+1 +} + func WidenParamHintType(t typ.Type) typ.Type { if t == nil { return nil @@ -110,12 +139,7 @@ func WidenParamHintType(t typ.Type) typ.Type { case *typ.Record: builder := typ.NewRecord() changed := false - if !v.Open { - // Call-site table literals should not over-constrain unannotated params. - // Widen record hints to open records so optional field probes remain valid. 
- builder.SetOpen(true) - changed = true - } else { + if v.Open { builder.SetOpen(true) } for _, f := range v.Fields { @@ -149,7 +173,126 @@ func WidenParamHintType(t typ.Type) typ.Type { // NormalizeHintType applies canonical widening and soft-member pruning. func NormalizeHintType(t typ.Type) typ.Type { - return typ.PruneSoftUnionMembers(WidenParamHintType(t)) + return collapseTableTopHint(typ.PruneSoftUnionMembers(WidenParamHintType(t))) +} + +func collapseTableTopHint(t typ.Type) typ.Type { + if t == nil { + return nil + } + switch v := t.(type) { + case *typ.Alias: + target := collapseTableTopHint(v.Target) + if target != nil && !typ.TypeEquals(target, v.Target) { + return typ.NewAlias(v.Name, target) + } + return t + case *typ.Optional: + inner := collapseTableTopHint(v.Inner) + if inner != nil && !typ.TypeEquals(inner, v.Inner) { + return typ.NewOptional(inner) + } + return t + case *typ.Union: + return collapseTableTopUnion(v) + default: + return t + } +} + +func collapseTableTopUnion(u *typ.Union) typ.Type { + if u == nil { + return nil + } + tableTop := firstTableTopMember(u.Members) + members := make([]typ.Type, 0, len(u.Members)) + changed := false + + if tableTop == nil { + for _, member := range u.Members { + collapsed := collapseTableTopHint(member) + if !typ.TypeEquals(collapsed, member) { + changed = true + } + members = append(members, collapsed) + } + if !changed { + return u + } + return typ.NewUnion(members...) 
+ } + + tableAdded := false + for _, member := range u.Members { + if member == nil { + continue + } + if typ.UnwrapAnnotated(member).Kind() == kind.Nil { + members = append(members, member) + continue + } + collapsed := collapseTableTopHint(member) + if tableTopCoversHintMember(collapsed) { + if !tableAdded { + members = append(members, tableTop) + tableAdded = true + } + if !typ.TypeEquals(member, tableTop) { + changed = true + } + continue + } + if !typ.TypeEquals(collapsed, member) { + changed = true + } + members = append(members, collapsed) + } + if !changed { + return u + } + return typ.NewUnion(members...) +} + +func firstTableTopMember(members []typ.Type) typ.Type { + for _, member := range members { + if isBuiltinTableTopHint(member) { + return member + } + } + return nil +} + +func isBuiltinTableTopHint(t typ.Type) bool { + return unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) +} + +func tableTopCoversHintMember(t typ.Type) bool { + if t == nil { + return false + } + if isBuiltinTableTopHint(t) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return tableTopCoversHintMember(v.UnaliasedTarget()) + case *typ.Recursive: + return v.Body != nil && v.Body != v && tableTopCoversHintMember(v.Body) + case *typ.Union: + if len(v.Members) == 0 { + return false + } + for _, member := range v.Members { + if member == nil || typ.UnwrapAnnotated(member).Kind() == kind.Nil || !tableTopCoversHintMember(member) { + return false + } + } + return true + case *typ.Record, *typ.Map, *typ.Array, *typ.Tuple, *typ.Interface, *typ.Intersection: + return true + default: + return false + } } // EnsureHintCapacity grows hint vector to at least size. 
diff --git a/compiler/check/infer/paramhints/param_hints_test.go b/compiler/check/infer/paramhints/param_hints_test.go index f710ab35..f91114ad 100644 --- a/compiler/check/infer/paramhints/param_hints_test.go +++ b/compiler/check/infer/paramhints/param_hints_test.go @@ -3,6 +3,8 @@ package paramhints import ( "testing" + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/types/typ" ) @@ -86,7 +88,23 @@ func TestWidenParamHintType_Union(t *testing.T) { } } -func TestWidenParamHintType_RecordBecomesOpen(t *testing.T) { +func TestNormalizeHintType_TableTopAbsorbsPreciseTableMembers(t *testing.T) { + tableTop := typ.NewInterface("table", nil) + preciseA := typ.NewRecord(). + Field("name", typ.String). + Field("tools", typ.NewArray(typ.String)). + Build() + preciseB := typ.NewMap(typ.String, typ.Integer) + hint := typ.NewUnion(typ.NewOptional(tableTop), preciseA, preciseB, typ.String) + + got := NormalizeHintType(hint) + want := typ.NewUnion(typ.NewOptional(tableTop), typ.String) + if !typ.TypeEquals(got, want) { + t.Fatalf("expected table top to absorb precise table members as %v, got %v", want, got) + } +} + +func TestWidenParamHintType_RecordPreservesClosedShape(t *testing.T) { rec := typ.NewRecord(). Field("pid", typ.LiteralString("abc")). Field("topic", typ.LiteralString("test:update")). @@ -97,8 +115,8 @@ func TestWidenParamHintType_RecordBecomesOpen(t *testing.T) { if !ok { t.Fatalf("expected record result, got %T", result) } - if !widened.Open { - t.Fatalf("expected widened param hint record to be open, got closed: %v", widened) + if widened.Open { + t.Fatalf("expected param hint to preserve closed call-site shape, got open: %v", widened) } pid := widened.GetField("pid") @@ -118,6 +136,293 @@ func TestBuildParamHintSignatures_NilInputs(t *testing.T) { } } +func TestMergeIntoSignature_ImplicitSelfUsesEffectiveHintSlots(t *testing.T) { + fn := functionWithParams("name") + sig := typ.Func(). 
+ Param("self", typ.Unknown). + Param("name", typ.Unknown). + Build() + selfType := typ.NewRecord().Field("prefix", typ.String).Build() + + got := MergeIntoSignature(fn, []typ.Type{selfType, typ.String}, sig) + if got == nil || len(got.Params) != 2 { + t.Fatalf("unexpected merged signature: %v", got) + } + if !typ.TypeEquals(got.Params[0].Type, selfType) { + t.Fatalf("self hint should use effective slot 0, got %v", got.Params[0].Type) + } + if !typ.TypeEquals(got.Params[1].Type, typ.String) { + t.Fatalf("source parameter hint should use effective slot 1, got %v", got.Params[1].Type) + } +} + +func TestMergeIntoSignature_PreservesExplicitNilabilityOnOptionalSlot(t *testing.T) { + fn := functionWithParams("context") + context := typ.NewRecord(). + MapComponent(typ.String, typ.Any). + SetOpen(true). + Build() + sig := typ.Func().OptParam("context", typ.Any).Build() + + got := MergeIntoSignature(fn, []typ.Type{typ.NewOptional(context)}, sig) + if got == nil || len(got.Params) != 1 { + t.Fatalf("unexpected merged signature: %v", got) + } + if !got.Params[0].Optional { + t.Fatalf("expected parameter slot to remain optional: %v", got) + } + want := typ.NewOptional(context) + if !typ.TypeEquals(got.Params[0].Type, want) { + t.Fatalf("expected nilability to remain in the value type, got %v", got.Params[0].Type) + } +} + +func TestProjectHintsToParamUse_KeepsDemandedRecordFields(t *testing.T) { + fn := functionWithParams("client", "model_id") + fn.Stmts = []ast.Stmt{ + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.FuncCallExpr{ + Func: &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "client"}, + Key: &ast.StringExpr{Value: "invoke"}, + }, + Args: []ast.Expr{ + &ast.IdentExpr{Value: "model_id"}, + &ast.TableExpr{}, + &ast.TableExpr{}, + }, + }, + }}, + } + graph := cfg.Build(fn) + invoke := typ.Func().Param("model_id", typ.String).Returns(typ.Unknown).Build() + client := typ.NewRecord(). + Field("invoke", invoke). 
+ Field("process_converse_stream", typ.Func().Returns(typ.String).Build()). + Field("_credentials", typ.String). + Build() + + got := ProjectHintsToParamUse(graph, fn, []typ.Type{client, typ.String}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected client hint = %T, want record (%v)", got[0], got[0]) + } + if rec.GetField("invoke") == nil { + t.Fatalf("projected client hint lost demanded invoke field: %v", rec) + } + for _, unused := range []string{"process_converse_stream", "_credentials"} { + if rec.GetField(unused) != nil { + t.Fatalf("projected client hint kept unused field %q: %v", unused, rec) + } + } + if !typ.TypeEquals(got[1], typ.String) { + t.Fatalf("directly used scalar hint should stay intact, got %v", got[1]) + } +} + +func TestProjectHintsToParamUse_KeepsDemandedAbsentRecordFieldsAsNil(t *testing.T) { + fn := functionWithParams("options") + fn.Stmts = []ast.Stmt{ + &ast.IfStmt{ + Condition: &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "options"}, + Key: &ast.StringExpr{Value: "stream"}, + }, + }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ + &ast.AttrGetExpr{ + Object: &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "options"}, + Key: &ast.StringExpr{Value: "headers"}, + }, + Key: &ast.StringExpr{Value: "Accept"}, + }, + }, + Rhs: []ast.Expr{&ast.StringExpr{Value: "application/json"}}, + }, + } + graph := cfg.Build(fn) + hint := typ.NewRecord(). + Field("headers", typ.NewRecord().Build()). 
+ Build() + + got := ProjectHintsToParamUse(graph, fn, []typ.Type{hint}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected options hint = %T, want record (%v)", got[0], got[0]) + } + stream := rec.GetField("stream") + if stream == nil || !typ.TypeEquals(stream.Type, typ.Nil) { + t.Fatalf("demanded absent stream field should project as nil, got %v in %v", stream, rec) + } + headers := rec.GetField("headers") + if headers == nil { + t.Fatalf("projected options hint lost demanded headers field: %v", rec) + } +} + +func TestProjectSignatureToParamUse_CompletesDemandedAbsentFields(t *testing.T) { + fn := functionWithParams("info") + fn.Stmts = []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.IdentExpr{Value: "info"}}, + Rhs: []ast.Expr{&ast.LogicalOpExpr{ + Operator: "or", + Lhs: &ast.IdentExpr{Value: "info"}, + Rhs: &ast.TableExpr{}, + }}, + }, + &ast.LocalAssignStmt{Exprs: []ast.Expr{ + &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "info"}, + Key: &ast.StringExpr{Value: "message"}, + }, + }}, + &ast.LocalAssignStmt{Exprs: []ast.Expr{ + &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "info"}, + Key: &ast.StringExpr{Value: "status_code"}, + }, + }}, + } + graph := cfg.Build(fn) + info := typ.NewRecord(). + OptField("message", typ.String). + Build() + sig := typ.Func(). + Param("info", info). + Returns(typ.String). 
+ Build() + + got := ProjectSignatureToParamUse(graph, fn, sig) + rec, ok := got.Params[0].Type.(*typ.Record) + if !ok { + t.Fatalf("projected param = %T, want record (%v)", got.Params[0].Type, got.Params[0].Type) + } + if rec.GetField("message") == nil { + t.Fatalf("projected signature lost existing demanded message field: %v", rec) + } + status := rec.GetField("status_code") + if status == nil || !typ.TypeEquals(status.Type, typ.Nil) { + t.Fatalf("projected signature should include demanded absent status_code as nil, got %v in %v", status, rec) + } + if len(got.Returns) != 1 || !typ.TypeEquals(got.Returns[0], typ.String) { + t.Fatalf("projected signature lost returns: %v", got) + } +} + +func TestProjectHintsToParamUse_DedupsUnionAfterProjection(t *testing.T) { + fn := functionWithParams("client") + fn.Stmts = []ast.Stmt{ + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.FuncCallExpr{ + Func: &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "client"}, + Key: &ast.StringExpr{Value: "invoke"}, + }, + }, + }}, + } + graph := cfg.Build(fn) + invoke := typ.Func().Returns(typ.Unknown).Build() + broad := typ.NewRecord(). + Field("invoke", invoke). + Field("stream", typ.Func().Returns(typ.String).Build()). + Build() + narrow := typ.NewRecord(). + Field("invoke", invoke). + Field("stream", typ.Func().Returns(typ.LiteralString("invalid")).Build()). 
+ Build() + + got := ProjectHintsToParamUse(graph, fn, []typ.Type{typ.NewUnion(broad, narrow)}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected union hint = %T, want coalesced record (%v)", got[0], got[0]) + } + if rec.GetField("invoke") == nil || rec.GetField("stream") != nil { + t.Fatalf("projected union should keep only invoke, got %v", rec) + } +} + +func TestProjectHintsToParamUse_WholeParameterUseKeepsHint(t *testing.T) { + fn := functionWithParams("client") + fn.Stmts = []ast.Stmt{ + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.FuncCallExpr{ + Func: &ast.IdentExpr{Value: "use_client"}, + Args: []ast.Expr{&ast.IdentExpr{Value: "client"}}, + }, + }}, + } + graph := cfg.Build(fn, "use_client") + client := typ.NewRecord().Field("invoke", typ.Func().Returns(typ.Unknown).Build()).Build() + + got := ProjectHintsToParamUse(graph, fn, []typ.Type{client}) + if !typ.TypeEquals(got[0], client) { + t.Fatalf("whole-parameter use should keep full hint, got %v", got[0]) + } +} + +func TestProjectHintsToParamUse_RecursiveForwardingDoesNotKeepWholeHint(t *testing.T) { + recursiveIdent := &ast.IdentExpr{Value: "visit"} + selfIdent := &ast.IdentExpr{Value: "self"} + valueIdent := &ast.IdentExpr{Value: "value"} + fn := functionWithParams("self", "value") + fn.Stmts = []ast.Stmt{ + &ast.IfStmt{ + Condition: &ast.AttrGetExpr{ + Object: valueIdent, + Key: &ast.StringExpr{Value: "next"}, + }, + Then: []ast.Stmt{ + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.FuncCallExpr{ + Func: recursiveIdent, + Args: []ast.Expr{ + selfIdent, + &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "value"}, + Key: &ast.StringExpr{Value: "next"}, + }, + }, + }, + }}, + }, + }, + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "self"}, + Key: &ast.StringExpr{Value: "id"}, + }, + }}, + } + graph := cfg.Build(fn, "visit") + if sym, ok := graph.Bindings().SymbolOf(recursiveIdent); ok { + graph.Bindings().SetFuncLitSymbol(fn, sym) + } + selfHint := 
typ.NewRecord(). + Field("id", typ.String). + Field("command", typ.Func().Returns(typ.Nil, typ.String).Build()). + Build() + + got := ProjectHintsToParamUse(graph, fn, []typ.Type{selfHint, typ.NewRecord().Field("next", typ.Any).Build()}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected self hint = %T, want record (%v)", got[0], got[0]) + } + if rec.GetField("id") == nil { + t.Fatalf("projected self hint lost demanded id field: %v", rec) + } + if rec.GetField("command") != nil { + t.Fatalf("recursive forwarding should not keep unused command field: %v", rec) + } +} + +func functionWithParams(names ...string) *ast.FunctionExpr { + return &ast.FunctionExpr{ParList: &ast.ParList{Names: names}} +} + func TestIsInformativeHintType(t *testing.T) { tests := []struct { name string diff --git a/compiler/check/infer/paramhints/project.go b/compiler/check/infer/paramhints/project.go new file mode 100644 index 00000000..6bfc14ff --- /dev/null +++ b/compiler/check/infer/paramhints/project.go @@ -0,0 +1,644 @@ +package paramhints + +import ( + "sort" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" + "github.com/wippyai/go-lua/compiler/cfg" + flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" +) + +type paramUse struct { + whole bool + fields map[string]struct{} +} + +// ProjectHintsToParamUse trims structured call-site hints to the surface the +// function body actually reads from each unannotated parameter. Hints are +// evidence for analyzing a helper, not a promise that every unused field on the +// first argument shape is part of that helper's public contract. 
+func ProjectHintsToParamUse(graph *cfg.Graph, fn *ast.FunctionExpr, hints []typ.Type) []typ.Type { + if graph == nil || fn == nil || len(hints) == 0 { + return hints + } + + uses := collectParamUses(graph, fn) + if len(uses) == 0 { + return hints + } + + var out []typ.Type + for idx, slot := range graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 || idx < 0 || idx >= len(hints) { + continue + } + hint := hints[idx] + if hint == nil { + continue + } + projected := projectHintToUse(hint, uses[slot.Symbol]) + if typ.TypeEquals(hint, projected) { + continue + } + if out == nil { + out = make([]typ.Type, len(hints)) + copy(out, hints) + } + out[idx] = projected + } + + if out == nil { + return hints + } + return out +} + +// ProjectSignatureToParamUse completes a function signature's parameter slots +// against the fields the function body reads. Unlike ProjectHintsToParamUse it +// does not trim unused fields: a function fact is already a canonical signature +// observation, and same-body analysis only needs to ensure demanded fields are +// present even when the parameter is also used as a whole value. 
+func ProjectSignatureToParamUse(graph *cfg.Graph, fn *ast.FunctionExpr, sig *typ.Function) *typ.Function { + if sig == nil || len(sig.Params) == 0 { + return sig + } + uses := collectParamUses(graph, fn) + if len(uses) == 0 { + return sig + } + projected := make([]typ.Type, len(sig.Params)) + changed := false + for idx, slot := range graph.ParamSlotsReadOnly() { + if idx < 0 || idx >= len(sig.Params) || slot.Symbol == 0 { + continue + } + use := uses[slot.Symbol] + if len(use.fields) == 0 { + continue + } + completed, ok := completeTypeWithFields(sig.Params[idx].Type, use.fields) + if !ok || completed == nil { + continue + } + projected[idx] = completed + if !typ.TypeEquals(sig.Params[idx].Type, completed) { + changed = true + } + } + if !changed { + return sig + } + + builder := typ.Func().ReserveParams(len(sig.Params)) + for _, tp := range sig.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for i, p := range sig.Params { + paramType := p.Type + if i < len(projected) && projected[i] != nil { + paramType = projected[i] + } + if p.Optional { + builder = builder.OptParam(p.Name, paramType) + } else { + builder = builder.Param(p.Name, paramType) + } + } + if sig.Variadic != nil { + builder = builder.Variadic(sig.Variadic) + } + if len(sig.Returns) > 0 { + builder = builder.Returns(sig.Returns...) 
+ } + if sig.Effects != nil { + builder = builder.Effects(sig.Effects) + } + if sig.Spec != nil { + builder = builder.Spec(sig.Spec) + } + if sig.Refinement != nil { + builder = builder.WithRefinement(sig.Refinement) + } + return builder.Build() +} + +func completeTypeWithFields(t typ.Type, fields map[string]struct{}) (typ.Type, bool) { + if t == nil || len(fields) == 0 { + return t, false + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + completed, ok := completeTypeWithFields(v.Target, fields) + if !ok { + return t, false + } + return completed, true + case *typ.Optional: + inner, ok := completeTypeWithFields(v.Inner, fields) + if !ok { + return t, false + } + return typ.NewOptional(inner), true + case *typ.Union: + members := make([]typ.Type, 0, len(v.Members)) + changed := false + for _, member := range v.Members { + completed, ok := completeTypeWithFields(member, fields) + if !ok { + members = append(members, member) + continue + } + if !typ.TypeEquals(member, completed) { + changed = true + } + members = append(members, completed) + } + if !changed { + return t, false + } + return typ.NewUnion(members...), true + case *typ.Record: + return completeRecordWithFields(v, fields), true + default: + return t, false + } +} + +func completeRecordWithFields(r *typ.Record, fields map[string]struct{}) typ.Type { + builder := typ.NewRecord() + if r.Open { + builder.SetOpen(true) + } + if r.Metatable != nil { + builder.Metatable(r.Metatable) + } + for _, field := range r.Fields { + switch { + case field.Optional && field.Readonly: + builder.OptReadonlyField(field.Name, field.Type) + case field.Optional: + builder.OptField(field.Name, field.Type) + case field.Readonly: + builder.ReadonlyField(field.Name, field.Type) + default: + builder.Field(field.Name, field.Type) + } + } + if r.HasMapComponent() { + builder.MapComponent(r.MapKey, r.MapValue) + } + + names := make([]string, 0, len(fields)) + for name := range fields { + names = append(names, name) + 
} + sort.Strings(names) + for _, name := range names { + if r.GetField(name) != nil { + continue + } + if r.HasMapComponent() && subtype.IsSubtype(typ.LiteralString(name), r.MapKey) { + mapValue := r.MapValue + if mapValue == nil { + mapValue = typ.Unknown + } + builder.OptField(name, mapValue) + continue + } + if !r.Open { + builder.Field(name, typ.Nil) + } + } + return builder.Build() +} + +func collectParamUses(graph *cfg.Graph, fn *ast.FunctionExpr) map[cfg.SymbolID]paramUse { + paramSymbols := make(map[cfg.SymbolID]struct{}) + for _, slot := range graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 { + continue + } + paramSymbols[slot.Symbol] = struct{}{} + } + if len(paramSymbols) == 0 { + return nil + } + + collector := paramUseCollector{ + bindings: graph.Bindings(), + paramSymbols: paramSymbols, + currentFunctionSymbols: currentFunctionSymbols(graph, fn), + uses: make(map[cfg.SymbolID]paramUse), + } + for _, stmt := range fn.Stmts { + collector.stmt(stmt) + } + return collector.uses +} + +type paramUseCollector struct { + bindings *bind.BindingTable + paramSymbols map[cfg.SymbolID]struct{} + currentFunctionSymbols map[cfg.SymbolID]struct{} + uses map[cfg.SymbolID]paramUse +} + +func (c *paramUseCollector) stmt(stmt ast.Stmt) { + switch s := stmt.(type) { + case *ast.AssignStmt: + for _, lhs := range s.Lhs { + c.lvalue(lhs) + } + for _, rhs := range s.Rhs { + c.expr(rhs) + } + case *ast.LocalAssignStmt: + for _, expr := range s.Exprs { + c.expr(expr) + } + case *ast.FuncCallStmt: + c.expr(s.Expr) + case *ast.DoBlockStmt: + c.stmts(s.Stmts) + case *ast.WhileStmt: + c.condition(s.Condition) + c.stmts(s.Stmts) + case *ast.RepeatStmt: + c.stmts(s.Stmts) + c.condition(s.Condition) + case *ast.IfStmt: + c.condition(s.Condition) + c.stmts(s.Then) + c.stmts(s.Else) + case *ast.NumberForStmt: + c.expr(s.Init) + c.expr(s.Limit) + c.expr(s.Step) + c.stmts(s.Stmts) + case *ast.GenericForStmt: + for _, expr := range s.Exprs { + c.expr(expr) + } + c.stmts(s.Stmts) + case 
*ast.FuncDefStmt: + if s.Name != nil { + c.expr(s.Name.Func) + c.expr(s.Name.Receiver) + } + if s.Func != nil { + c.stmts(s.Func.Stmts) + } + case *ast.ReturnStmt: + for _, expr := range s.Exprs { + c.expr(expr) + } + } +} + +func (c *paramUseCollector) stmts(stmts []ast.Stmt) { + for _, stmt := range stmts { + c.stmt(stmt) + } +} + +func (c *paramUseCollector) condition(expr ast.Expr) { + switch e := expr.(type) { + case *ast.IdentExpr: + if c.isParamIdent(e) { + return + } + case *ast.UnaryNotOpExpr: + if ident, ok := e.Expr.(*ast.IdentExpr); ok && c.isParamIdent(ident) { + return + } + c.condition(e.Expr) + return + case *ast.RelationalOpExpr: + if isNilLiteral(e.Lhs) && c.isParamExpr(e.Rhs) { + return + } + if isNilLiteral(e.Rhs) && c.isParamExpr(e.Lhs) { + return + } + case *ast.LogicalOpExpr: + c.condition(e.Lhs) + c.condition(e.Rhs) + return + } + c.expr(expr) +} + +func (c *paramUseCollector) expr(expr ast.Expr) { + if expr == nil { + return + } + + switch e := expr.(type) { + case *ast.IdentExpr: + c.whole(e) + case *ast.AttrGetExpr: + if c.pathUse(expr) { + return + } + c.expr(e.Object) + c.expr(e.Key) + case *ast.TableExpr: + for _, field := range e.Fields { + if field == nil { + continue + } + c.expr(field.Key) + c.expr(field.Value) + } + case *ast.FuncCallExpr: + c.call(e) + case *ast.LogicalOpExpr: + c.expr(e.Lhs) + c.expr(e.Rhs) + case *ast.RelationalOpExpr: + c.expr(e.Lhs) + c.expr(e.Rhs) + case *ast.StringConcatOpExpr: + c.expr(e.Lhs) + c.expr(e.Rhs) + case *ast.ArithmeticOpExpr: + c.expr(e.Lhs) + c.expr(e.Rhs) + case *ast.UnaryMinusOpExpr: + c.expr(e.Expr) + case *ast.UnaryNotOpExpr: + c.expr(e.Expr) + case *ast.UnaryLenOpExpr: + c.expr(e.Expr) + case *ast.UnaryBNotOpExpr: + c.expr(e.Expr) + case *ast.FunctionExpr: + c.stmts(e.Stmts) + case *ast.CastExpr: + c.expr(e.Expr) + case *ast.NonNilAssertExpr: + c.expr(e.Expr) + } +} + +func (c *paramUseCollector) call(call *ast.FuncCallExpr) { + if call == nil { + return + } + recursive := 
c.isDirectRecursiveCall(call) + if call.Method != "" { + if recv := flowpath.FromExprWithBindings(call.Receiver, nil, c.bindings); c.isParamPath(recv) { + c.field(recv.Symbol, firstFieldOrMethod(recv, call.Method)) + } else { + c.expr(call.Receiver) + } + } else if callee := flowpath.FromExprWithBindings(call.Func, nil, c.bindings); c.isParamPath(callee) { + if len(callee.Segments) == 0 { + c.markWhole(callee.Symbol) + } else { + c.field(callee.Symbol, segmentFieldName(callee.Segments[0])) + } + } else { + c.expr(call.Func) + } + + for _, arg := range call.Args { + if recursive && c.isParamExpr(arg) { + continue + } + c.expr(arg) + } +} + +func (c *paramUseCollector) isDirectRecursiveCall(call *ast.FuncCallExpr) bool { + if call == nil || call.Method != "" || len(c.currentFunctionSymbols) == 0 { + return false + } + callee := flowpath.FromExprWithBindings(call.Func, nil, c.bindings) + if callee.Symbol == 0 || len(callee.Segments) != 0 { + return false + } + _, ok := c.currentFunctionSymbols[callee.Symbol] + return ok +} + +func (c *paramUseCollector) lvalue(expr ast.Expr) { + switch e := expr.(type) { + case *ast.AttrGetExpr: + if c.pathUse(expr) { + return + } + c.expr(e.Object) + c.expr(e.Key) + default: + c.expr(expr) + } +} + +func (c *paramUseCollector) pathUse(expr ast.Expr) bool { + p := flowpath.FromExprWithBindings(expr, nil, c.bindings) + if !c.isParamPath(p) { + return false + } + if len(p.Segments) == 0 { + c.markWhole(p.Symbol) + return true + } + c.field(p.Symbol, segmentFieldName(p.Segments[0])) + return true +} + +func (c *paramUseCollector) whole(expr ast.Expr) { + if c.bindings == nil || expr == nil { + return + } + ident, ok := expr.(*ast.IdentExpr) + if !ok { + return + } + sym, ok := c.bindings.SymbolOf(ident) + if !ok || sym == 0 { + return + } + if _, isParam := c.paramSymbols[sym]; !isParam { + return + } + c.markWhole(sym) +} + +func (c *paramUseCollector) isParamExpr(expr ast.Expr) bool { + ident, ok := expr.(*ast.IdentExpr) + return ok && 
c.isParamIdent(ident) +} + +func (c *paramUseCollector) isParamIdent(ident *ast.IdentExpr) bool { + if c.bindings == nil || ident == nil { + return false + } + sym, ok := c.bindings.SymbolOf(ident) + if !ok || sym == 0 { + return false + } + _, ok = c.paramSymbols[sym] + return ok +} + +func isNilLiteral(expr ast.Expr) bool { + _, ok := expr.(*ast.NilExpr) + return ok +} + +func (c *paramUseCollector) isParamPath(p constraint.Path) bool { + if p.IsEmpty() || p.Symbol == 0 { + return false + } + _, ok := c.paramSymbols[p.Symbol] + return ok +} + +func (c *paramUseCollector) markWhole(sym cfg.SymbolID) { + use := c.uses[sym] + use.whole = true + c.uses[sym] = use +} + +func (c *paramUseCollector) field(sym cfg.SymbolID, name string) { + if name == "" { + c.markWhole(sym) + return + } + use := c.uses[sym] + if use.fields == nil { + use.fields = make(map[string]struct{}, 1) + } + use.fields[name] = struct{}{} + c.uses[sym] = use +} + +func firstFieldOrMethod(p constraint.Path, method string) string { + if len(p.Segments) == 0 { + return method + } + return segmentFieldName(p.Segments[0]) +} + +func currentFunctionSymbols(graph *cfg.Graph, fn *ast.FunctionExpr) map[cfg.SymbolID]struct{} { + if graph == nil || fn == nil { + return nil + } + syms := make(map[cfg.SymbolID]struct{}, 1) + if bindings := graph.Bindings(); bindings != nil { + if sym, ok := bindings.FuncLitSymbol(fn); ok && sym != 0 { + syms[sym] = struct{}{} + } + } + for _, localFn := range graph.LocalFunctionAssignments() { + if localFn.Func == fn && localFn.Symbol != 0 { + syms[localFn.Symbol] = struct{}{} + } + } + graph.EachFuncDef(func(_ cfg.Point, info *cfg.FuncDefInfo) { + if info != nil && info.FuncExpr == fn && info.Symbol != 0 { + syms[info.Symbol] = struct{}{} + } + }) + if len(syms) == 0 { + return nil + } + return syms +} + +func segmentFieldName(seg constraint.Segment) string { + switch seg.Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + return seg.Name + default: + 
return "" + } +} + +func projectHintToUse(hint typ.Type, use paramUse) typ.Type { + if hint == nil || use.whole { + return hint + } + if len(use.fields) == 0 { + return nil + } + projected, ok := projectTypeToFields(hint, use.fields) + if !ok { + return hint + } + return projected +} + +func projectTypeToFields(t typ.Type, fields map[string]struct{}) (typ.Type, bool) { + if t == nil { + return nil, false + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return projectTypeToFields(v.Target, fields) + case *typ.Optional: + inner, ok := projectTypeToFields(v.Inner, fields) + if !ok { + return t, false + } + return typ.NewOptional(inner), true + case *typ.Union: + members := make([]typ.Type, 0, len(v.Members)) + for _, member := range v.Members { + projected, ok := projectTypeToFields(member, fields) + if !ok { + return t, false + } + members = append(members, projected) + } + return typ.NewUnion(members...), true + case *typ.Record: + return projectRecordToFields(v, fields), true + default: + return t, false + } +} + +func projectRecordToFields(r *typ.Record, fields map[string]struct{}) typ.Type { + builder := typ.NewRecord().SetOpen(true) + if r.Metatable != nil { + builder.Metatable(r.Metatable) + } + names := make([]string, 0, len(fields)) + for name := range fields { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + field := r.GetField(name) + if field == nil { + if r.HasMapComponent() && subtype.IsSubtype(typ.LiteralString(name), r.MapKey) { + mapValue := r.MapValue + if mapValue == nil { + mapValue = typ.Unknown + } + builder.OptField(name, mapValue) + } else if !r.Open { + builder.Field(name, typ.Nil) + } + continue + } + switch { + case field.Optional && field.Readonly: + builder.OptReadonlyField(field.Name, field.Type) + case field.Optional: + builder.OptField(field.Name, field.Type) + case field.Readonly: + builder.ReadonlyField(field.Name, field.Type) + default: + builder.Field(field.Name, 
field.Type) + } + } + return builder.Build() +} diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index a4d38553..02bec4c8 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -42,18 +42,22 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/callsite" + flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" "github.com/wippyai/go-lua/compiler/check/infer/paramhints" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth" + "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/io" "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -224,7 +228,7 @@ func (i *Inferencer) ComputeForGraph( continue } if hintVec, ok := hints[sym]; ok && len(hintVec) > 0 { - info.ParamHints = hintVec + info.ParamHints = paramhints.ProjectHintsToParamUse(info.Graph, info.Fn, hintVec) } } } @@ -389,6 +393,7 @@ func collectReturnTypes( fnGraph *cfg.Graph, synthEngine api.Synth, deadPoints map[cfg.Point]bool, + skipReturnExpr func(ast.Expr) bool, ) []typ.Type { if fnGraph == nil || synthEngine == nil { return nil @@ -402,7 +407,10 @@ func collectReturnTypes( } _ = deadPoints - types := synthesizeReturnExprs(synthEngine, retInfo, p) + if len(retInfo.Exprs) == 1 && skipReturnExpr != nil && skipReturnExpr(retInfo.Exprs[0]) { + return + } + types := 
synthesizeReturnExprs(synthEngine, retInfo, p, skipReturnExpr) if !seenReturn { seenReturn = true returnTypes = types @@ -420,11 +428,11 @@ func synthesizeReturnExprs( synthEngine api.Synth, retInfo *cfg.ReturnInfo, p cfg.Point, + skipReturnExpr func(ast.Expr) bool, ) []typ.Type { if len(retInfo.Exprs) == 0 { return nil } - types := make([]typ.Type, 0, len(retInfo.Exprs)) for i, expr := range retInfo.Exprs { if i == len(retInfo.Exprs)-1 && ast.CanProduceMultipleValues(expr) { @@ -473,7 +481,8 @@ func (i *Inferencer) inferReturnTypesFromBody( finalOverlay map[cfg.SymbolID]typ.Type, ) []typ.Type { state := i.runPhase2FlowNarrowing(ctx, finalOverlay) - narrowed := collectReturnTypes(ctx.info.Graph, state.synth, state.deadPoints) + skipUnresolvedLocalCall := i.skipUnresolvedLocalReturnCall(ctx) + narrowed := collectReturnTypes(ctx.info.Graph, state.synth, state.deadPoints, skipUnresolvedLocalCall) fnGraph := ctx.info.Graph if fnGraph == nil { @@ -494,11 +503,35 @@ func (i *Inferencer) inferReturnTypesFromBody( uniformFunctionScopes(fnGraph, ctx.resolveScope), declCheckCtx, ) - declared := collectReturnTypes(fnGraph, declSynth, nil) + declared := collectReturnTypes(fnGraph, declSynth, nil, skipUnresolvedLocalCall) return returns.MergeReturnSummary(declared, narrowed) } +func (i *Inferencer) skipUnresolvedLocalReturnCall(ctx *returnInferenceContext) func(ast.Expr) bool { + if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || len(ctx.localFuncs) == 0 { + return nil + } + bindings := ctx.bindings + if bindings == nil { + bindings = ctx.info.Graph.Bindings() + } + if bindings == nil { + return nil + } + return func(expr ast.Expr) bool { + call, ok := expr.(*ast.FuncCallExpr) + if !ok || call == nil || call.Method != "" { + return false + } + sym := callsite.SymbolFromExpr(call.Func, bindings) + if sym == 0 || ctx.localFuncs[sym] == nil { + return false + } + return typ.IsUnknownOnlyOrEmpty(returns.NormalizeReturnVector(ctx.returnVectors[sym])) + } +} + // 
inferReturnForFunction infers return types for one local function from the // current SCC return-vector state. // This is the core inference logic called by computeReturnVectorsForGroup for each function. @@ -593,14 +626,478 @@ func (i *Inferencer) inferReturnForFunction( i.enrichOverlayWithCaptured(ctx, overlay) // Add local declared types (annotations, loop variables) as overlay hints. - i.enrichOverlayWithLocalDeclarations(ctx, overlay) + localValueSeeds := i.enrichOverlayWithLocalDeclarations(ctx, overlay) + + // Body-derived parameter contracts are needed by local assignment inference + // in the same function. For example, a helper call may prove that a parameter + // field is string?, which then makes `param.field or "default"` synthesize as + // string without relying on a value-level fallback shortcut. + i.mergeParamHintsFromBodyUses(ctx, overlay) + i.applyParamHintsToOverlay(ctx, overlay) // Phase 1: Infer local variable types. - inferred, _, synthAdapter := i.inferLocalVariableTypes(ctx, overlay) + inferred, _, synthAdapter := i.inferLocalVariableTypes(ctx, overlay, localValueSeeds) // Collect field/indexer assignments and apply mutations. - finalOverlay := i.collectAndApplyMutations(ctx, overlay, inferred, synthAdapter) + finalOverlay := i.collectAndApplyMutations(ctx, overlay, inferred, synthAdapter, localValueSeeds) + i.mergeParamHintsFromOverlay(ctx, finalOverlay) // Phase 2: Infer return types from body. 
return i.inferReturnTypesFromBody(ctx, finalOverlay) } + +func (i *Inferencer) applyParamHintsToOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { + if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || len(ctx.info.ParamHints) == 0 || overlay == nil { + return + } + for idx, slot := range ctx.info.Graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 || idx >= len(ctx.info.ParamHints) { + continue + } + hint := ctx.info.ParamHints[idx] + if !paramhints.IsInformativeHintType(hint) { + continue + } + if slot.TypeAnnotation != nil && ctx.engine != nil { + resolved := ctx.engine.ResolveType(slot.TypeAnnotation, ctx.resolveScope) + if resolved != nil && !typ.IsRefinableAnnotation(resolved) { + continue + } + } + overlay[slot.Symbol] = hint + } +} + +func (i *Inferencer) mergeParamHintsFromOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { + if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || ctx.info.Fn == nil || len(overlay) == 0 { + return + } + for idx, slot := range ctx.info.Graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 { + continue + } + _, hasSource := slot.SourceParamIndex() + if hasSource && slot.TypeAnnotation != nil && ctx.engine != nil { + resolved := ctx.engine.ResolveType(slot.TypeAnnotation, ctx.resolveScope) + if resolved != nil && !typ.IsRefinableAnnotation(resolved) { + continue + } + } + t := overlay[slot.Symbol] + if !paramhints.IsInformativeHintType(t) { + continue + } + next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, t, typ.JoinPreferNonSoft) + if merged { + ctx.info.ParamHints = next + } + } + i.mergeParamHintsFromBodyUses(ctx, overlay) +} + +func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { + if i == nil || ctx == nil || ctx.info == nil || ctx.info.Graph == nil || ctx.info.Fn == nil { + return + } + bindings := ctx.info.Graph.Bindings() + if bindings == nil || i.types == nil { + return + } + 
paramIndexBySym := make(map[cfg.SymbolID]int) + for idx, slot := range ctx.info.Graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 { + continue + } + _, hasSource := slot.SourceParamIndex() + if hasSource && slot.TypeAnnotation != nil && ctx.engine != nil { + resolved := ctx.engine.ResolveType(slot.TypeAnnotation, ctx.resolveScope) + if resolved != nil && !typ.IsRefinableAnnotation(resolved) { + continue + } + } + paramIndexBySym[slot.Symbol] = idx + } + if len(paramIndexBySym) == 0 { + return + } + + var visitStmt func(ast.Stmt) + var visitExpr func(ast.Expr) + mergeReceiver := func(receiver ast.Expr, method string) { + if receiver == nil || method == "" { + return + } + ident, ok := receiver.(*ast.IdentExpr) + if !ok || ident == nil { + return + } + sym, ok := bindings.SymbolOf(ident) + if !ok || sym == 0 { + return + } + idx, ok := paramIndexBySym[sym] + if !ok { + return + } + hint := i.receiverHintForMethod(ctx, method) + if !paramhints.IsInformativeHintType(hint) { + return + } + next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, hint, typ.JoinPreferNonSoft) + if merged { + ctx.info.ParamHints = next + } + } + mergeParamFieldHint := func(sym cfg.SymbolID, field string, hint typ.Type, required bool) { + if sym == 0 || field == "" || !paramhints.IsInformativeHintType(hint) { + return + } + idx, ok := paramIndexBySym[sym] + if !ok { + return + } + builder := typ.NewRecord() + if required { + builder.Field(field, hint) + } else { + builder.OptField(field, hint) + } + rec := builder.Build() + next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, rec, typ.JoinPreferNonSoft) + if merged { + ctx.info.ParamHints = next + } + } + bodyContractJoin := func(prev, next typ.Type) typ.Type { + if next != nil { + return next + } + return prev + } + mergeParamHint := func(sym cfg.SymbolID, hint typ.Type) { + if sym == 0 || !paramhints.IsInformativeHintType(hint) { + return + } + idx, ok := paramIndexBySym[sym] + if !ok { + return + } + next, merged := 
paramhints.MergeHintAt(ctx.info.ParamHints, idx, hint, bodyContractJoin) + if merged { + ctx.info.ParamHints = next + } + } + paramSymbol := func(expr ast.Expr) (cfg.SymbolID, bool) { + ident, ok := expr.(*ast.IdentExpr) + if !ok || ident == nil { + return 0, false + } + sym, ok := bindings.SymbolOf(ident) + if !ok || sym == 0 { + return 0, false + } + if _, ok := paramIndexBySym[sym]; !ok { + return 0, false + } + return sym, true + } + paramFieldPath := func(expr ast.Expr) (cfg.SymbolID, string, bool) { + attr, ok := expr.(*ast.AttrGetExpr) + if !ok || attr == nil { + return 0, "", false + } + obj, ok := attr.Object.(*ast.IdentExpr) + if !ok || obj == nil { + return 0, "", false + } + key, ok := attr.Key.(*ast.StringExpr) + if !ok || key == nil || key.Value == "" { + return 0, "", false + } + sym, ok := bindings.SymbolOf(obj) + if !ok { + return 0, "", false + } + if _, ok := paramIndexBySym[sym]; !ok { + return 0, "", false + } + return sym, key.Value, true + } + typeAt := func(expr ast.Expr, p cfg.Point) typ.Type { + if expr == nil { + return typ.Unknown + } + if t, ok := overlayPathType(expr, overlay, bindings, i.types, ctx.run.Ctx); ok { + return t + } + if ctx.engine != nil { + if t := ctx.engine.TypeOf(expr, p); t != nil { + return t + } + } + return typ.Unknown + } + isDirectSelfRecursiveCall := func(info *cfg.CallInfo) bool { + if info == nil || ctx.info.Sym == 0 { + return false + } + for _, sym := range callsite.CallableCalleeSymbolCandidates(info, ctx.info.Graph, bindings, bindings) { + if sym == ctx.info.Sym { + return true + } + } + return false + } + var bodyParamContracts map[cfg.SymbolID]typ.Type + mergeParamContract := func(sym cfg.SymbolID, hint typ.Type) { + if sym == 0 || !paramhints.IsInformativeHintType(hint) { + return + } + if bodyParamContracts == nil { + bodyParamContracts = make(map[cfg.SymbolID]typ.Type) + } + if prev := bodyParamContracts[sym]; prev != nil { + bodyParamContracts[sym] = subtype.NormalizeIntersection(prev, hint) + 
return + } + bodyParamContracts[sym] = hint + } + mergeCallExpectedFieldHints := func(p cfg.Point, info *cfg.CallInfo) { + if info == nil || i.types == nil { + return + } + if isDirectSelfRecursiveCall(info) { + return + } + args := make([]typ.Type, len(info.Args)) + for idx, arg := range info.Args { + args[idx] = typeAt(arg, p) + } + def := ops.CallDef{ + Args: args, + Query: i.types, + } + if info.Method != "" { + def.IsMethod = true + def.MethodName = info.Method + def.Receiver = typeAt(info.Receiver, p) + def.ForceMethodReceiver = callsite.ForceMethodReceiver(bindings, ctx.info.Graph, info) + } else { + def.Callee = typeAt(info.Callee, p) + } + inferredCall := ops.InferCall(ctx.run.Ctx, def) + for idx, arg := range info.Args { + expected := inferredCall.ExpectedArgType(idx) + if !paramhints.IsInformativeHintType(expected) { + continue + } + if sym, ok := paramSymbol(arg); ok { + mergeParamContract(sym, expected) + continue + } + if sym, field, ok := paramFieldPath(arg); ok { + mergeParamFieldHint(sym, field, expected, true) + continue + } + } + } + ctx.info.Graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { + mergeCallExpectedFieldHints(p, info) + }) + for _, sym := range cfg.SortedSymbolIDs(bodyParamContracts) { + mergeParamHint(sym, bodyParamContracts[sym]) + } + defaultLiteralType := func(expr ast.Expr) typ.Type { + switch expr.(type) { + case *ast.StringExpr: + return typ.String + case *ast.NumberExpr: + return typ.Number + case *ast.TrueExpr, *ast.FalseExpr: + return typ.Boolean + default: + return nil + } + } + visitExpr = func(expr ast.Expr) { + switch e := expr.(type) { + case *ast.FuncCallExpr: + mergeReceiver(e.Receiver, e.Method) + visitExpr(e.Func) + visitExpr(e.Receiver) + for _, arg := range e.Args { + visitExpr(arg) + } + case *ast.AttrGetExpr: + visitExpr(e.Object) + visitExpr(e.Key) + case *ast.TableExpr: + for _, f := range e.Fields { + if f == nil { + continue + } + visitExpr(f.Key) + visitExpr(f.Value) + } + case *ast.LogicalOpExpr: 
+ if e.Operator == "or" { + if sym, field, ok := paramFieldPath(e.Lhs); ok { + mergeParamFieldHint(sym, field, defaultLiteralType(e.Rhs), false) + } + } + visitExpr(e.Lhs) + visitExpr(e.Rhs) + case *ast.RelationalOpExpr: + visitExpr(e.Lhs) + visitExpr(e.Rhs) + case *ast.StringConcatOpExpr: + visitExpr(e.Lhs) + visitExpr(e.Rhs) + case *ast.ArithmeticOpExpr: + visitExpr(e.Lhs) + visitExpr(e.Rhs) + case *ast.UnaryMinusOpExpr: + visitExpr(e.Expr) + case *ast.UnaryNotOpExpr: + visitExpr(e.Expr) + case *ast.UnaryLenOpExpr: + visitExpr(e.Expr) + case *ast.UnaryBNotOpExpr: + visitExpr(e.Expr) + case *ast.CastExpr: + visitExpr(e.Expr) + case *ast.NonNilAssertExpr: + visitExpr(e.Expr) + case *ast.FunctionExpr: + return + } + } + visitStmt = func(stmt ast.Stmt) { + switch s := stmt.(type) { + case *ast.AssignStmt: + for _, expr := range s.Lhs { + visitExpr(expr) + } + for _, expr := range s.Rhs { + visitExpr(expr) + } + case *ast.LocalAssignStmt: + for _, expr := range s.Exprs { + visitExpr(expr) + } + case *ast.FuncCallStmt: + visitExpr(s.Expr) + case *ast.DoBlockStmt: + for _, child := range s.Stmts { + visitStmt(child) + } + case *ast.WhileStmt: + visitExpr(s.Condition) + for _, child := range s.Stmts { + visitStmt(child) + } + case *ast.RepeatStmt: + for _, child := range s.Stmts { + visitStmt(child) + } + visitExpr(s.Condition) + case *ast.IfStmt: + visitExpr(s.Condition) + for _, child := range s.Then { + visitStmt(child) + } + for _, child := range s.Else { + visitStmt(child) + } + case *ast.NumberForStmt: + visitExpr(s.Init) + visitExpr(s.Limit) + visitExpr(s.Step) + for _, child := range s.Stmts { + visitStmt(child) + } + case *ast.GenericForStmt: + for _, expr := range s.Exprs { + visitExpr(expr) + } + for _, child := range s.Stmts { + visitStmt(child) + } + case *ast.FuncDefStmt: + if s.Name != nil { + visitExpr(s.Name.Func) + visitExpr(s.Name.Receiver) + } + case *ast.ReturnStmt: + for _, expr := range s.Exprs { + visitExpr(expr) + } + } + } + for _, stmt := range 
ctx.info.Fn.Stmts { + visitStmt(stmt) + } +} + +func (i *Inferencer) receiverHintForMethod(ctx *returnInferenceContext, method string) typ.Type { + if i == nil || i.types == nil || method == "" { + return nil + } + methodType, ok := i.types.Method(ctx.run.Ctx, typ.String, method) + if !ok || methodType == nil { + return nil + } + fn, ok := methodType.(*typ.Function) + if !ok || len(fn.Params) == 0 || !typ.TypeEquals(fn.Params[0].Type, typ.String) { + return nil + } + return typ.String +} + +func overlayPathType( + expr ast.Expr, + overlay map[cfg.SymbolID]typ.Type, + bindings *bind.BindingTable, + typeOps core.TypeOps, + ctx *db.QueryContext, +) (typ.Type, bool) { + if expr == nil || len(overlay) == 0 || bindings == nil { + return nil, false + } + p := flowpath.FromExprWithBindings(expr, nil, bindings) + if p.IsEmpty() || p.Symbol == 0 { + return nil, false + } + t, ok := overlay[p.Symbol] + if !ok || t == nil { + return nil, false + } + for _, seg := range p.Segments { + if typeOps == nil { + return nil, false + } + switch seg.Kind { + case constraint.SegmentField: + ft, ok := typeOps.Field(ctx, t, seg.Name) + if !ok { + return nil, false + } + t = ft + case constraint.SegmentIndexString: + ft, ok := typeOps.Index(ctx, t, typ.LiteralString(seg.Name)) + if !ok { + return nil, false + } + t = ft + case constraint.SegmentIndexInt: + ft, ok := typeOps.Index(ctx, t, typ.LiteralInt(int64(seg.Index))) + if !ok { + return nil, false + } + t = ft + default: + return nil, false + } + } + return t, true +} diff --git a/compiler/check/infer/return/overlay_pipeline.go b/compiler/check/infer/return/overlay_pipeline.go index 50a5d745..1c7e5dc7 100644 --- a/compiler/check/infer/return/overlay_pipeline.go +++ b/compiler/check/infer/return/overlay_pipeline.go @@ -2,17 +2,20 @@ package infer import ( "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" 
"github.com/wippyai/go-lua/compiler/check/flowbuild/assign" fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" + "github.com/wippyai/go-lua/compiler/check/infer/paramhints" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/siblings" "github.com/wippyai/go-lua/compiler/check/synth" + "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" @@ -22,23 +25,24 @@ func (i *Inferencer) buildParameterOverlay(ctx *returnInferenceContext) map[cfg. fnGraph := ctx.info.Graph paramSlots := fnGraph.ParamSlotsReadOnly() overlay := make(map[cfg.SymbolID]typ.Type, overlaySymbolCapacity(fnGraph, len(paramSlots))) - for _, slot := range paramSlots { + for paramIdx, slot := range paramSlots { if slot.Symbol == 0 { continue } // Binder/CFG-injected implicit self parameter. - srcIdx, hasSource := slot.SourceParamIndex() + _, hasSource := slot.SourceParamIndex() if !hasSource { if selfType := ctx.resolveScope.SelfType(); selfType != nil { overlay[slot.Symbol] = selfType + } else if ctx.info.ParamHints != nil && paramIdx < len(ctx.info.ParamHints) && ctx.info.ParamHints[paramIdx] != nil { + overlay[slot.Symbol] = ctx.info.ParamHints[paramIdx] } else { overlay[slot.Symbol] = typ.Unknown } continue } - i := srcIdx paramType := typ.Unknown if slot.Name == "self" { if selfType := ctx.resolveScope.SelfType(); selfType != nil { @@ -46,10 +50,13 @@ func (i *Inferencer) buildParameterOverlay(ctx *returnInferenceContext) map[cfg. 
} } if typ.IsAbsentOrUnknown(paramType) { - if ctx.info.ParamHints != nil && i < len(ctx.info.ParamHints) && ctx.info.ParamHints[i] != nil { - paramType = ctx.info.ParamHints[i] + if ctx.info.ParamHints != nil && paramIdx < len(ctx.info.ParamHints) && ctx.info.ParamHints[paramIdx] != nil { + paramType = ctx.info.ParamHints[paramIdx] } } + if typ.IsAbsentOrUnknown(paramType) && slot.TypeAnnotation == nil { + paramType = typ.Any + } if slot.TypeAnnotation != nil { resolved := ctx.engine.ResolveType(slot.TypeAnnotation, ctx.resolveScope) if resolved != nil { @@ -99,16 +106,34 @@ func (i *Inferencer) enrichOverlayWithSiblings( CurrentSym: ctx.info.Sym, Services: siblings.OverlayServicesFuncs{ SeedTypeFn: func(fn *ast.FunctionExpr) typ.Type { + var localInfo *returns.LocalFuncInfo + for _, sym := range cfg.SortedSymbolIDs(ctx.localFuncs) { + candidate := ctx.localFuncs[sym] + if candidate != nil && candidate.Fn == fn { + localInfo = candidate + break + } + } var bindings interface { ParamSymbols(*ast.FunctionExpr) []cfg.SymbolID Name(cfg.SymbolID) string } - if ctx.info != nil && ctx.info.Graph != nil { + if localInfo != nil && localInfo.Graph != nil { + if b := localInfo.Graph.Bindings(); b != nil { + bindings = b + } + } + if bindings == nil && ctx.info != nil && ctx.info.Graph != nil { if b := ctx.info.Graph.Bindings(); b != nil { bindings = b } } - return returns.BuildSeedFunctionTypeWithBindings(fn, ctx.engine, ctx.resolveScope, bindings) + seed := returns.BuildSeedFunctionTypeWithBindings(fn, ctx.engine, ctx.resolveScope, bindings) + fnType, _ := seed.(*typ.Function) + if localInfo != nil && len(localInfo.ParamHints) > 0 && fnType != nil { + return paramhints.MergeIntoSignature(fn, localInfo.ParamHints, fnType) + } + return seed }, }, }) @@ -231,6 +256,9 @@ func (i *Inferencer) enrichOverlayWithLocalFunctions( } returnVector := i.resolveLocalFunctionReturns(ctx, allReturnVectors, target.Symbol) sig := ctx.engine.ResolveFunctionSignature(fnExpr, ctx.resolveScope) + 
if localInfo := ctx.localFuncs[target.Symbol]; localInfo != nil && len(localInfo.ParamHints) > 0 && sig != nil { + sig = paramhints.MergeIntoSignature(fnExpr, localInfo.ParamHints, sig) + } if fnType := returns.WithSummaryOrUnknown(sig, returnVector); fnType != nil { overlay[target.Symbol] = fnType } @@ -303,6 +331,10 @@ func (i *Inferencer) enrichOverlayWithCaptured( if sym == 0 { continue } + if fnType := i.capturedFunctionFactType(ctx, sym); fnType != nil { + overlay[sym] = fnType + continue + } if existing, ok := overlay[sym]; ok && existing != nil && !typ.IsSoft(existing, typ.SoftAnnotationPolicy) { continue } @@ -316,17 +348,61 @@ func (i *Inferencer) enrichOverlayWithCaptured( } } +func (i *Inferencer) capturedFunctionFactType(ctx *returnInferenceContext, sym cfg.SymbolID) typ.Type { + if i == nil || i.store == nil || ctx == nil || sym == 0 { + return nil + } + ref := i.store.FunctionRefBySym(sym) + if ref == nil { + return nil + } + parentGraphID := ref.ParentGraphID + if parentGraphID == 0 { + parentGraphID = ref.GraphID + } + parentGraph := i.store.Graphs()[parentGraphID] + if parentGraph == nil { + return nil + } + parentScope := ctx.info.DefScope + if parentHash := i.store.GraphParentHashOf(parentGraphID); parentHash != 0 { + if scoped := i.store.Parents()[parentHash]; scoped != nil { + parentScope = scoped + } + } + facts := i.store.GetFunctionFactsSnapshot(parentGraph, parentScope) + return facts.FunctionType(sym) +} + // inferLocalVariableTypes runs phase 1 synthesis to infer local variable types. 
func (i *Inferencer) inferLocalVariableTypes( ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type, + localValueSeeds map[cfg.SymbolID]bool, ) (map[cfg.SymbolID]typ.Type, *synth.Engine, func(ast.Expr, cfg.Point) typ.Type) { fnGraph := ctx.info.Graph annotated := make(map[cfg.SymbolID]bool, len(overlay)) paramSet := paramSymbolSet(fnGraph) + explicitParamAnnotations := make(map[cfg.SymbolID]bool) + for _, slot := range fnGraph.ParamSlotsReadOnly() { + if slot.Symbol == 0 || slot.TypeAnnotation == nil || ctx.engine == nil { + continue + } + resolved := ctx.engine.ResolveType(slot.TypeAnnotation, ctx.resolveScope) + if resolved != nil && !typ.IsRefinableAnnotation(resolved) { + explicitParamAnnotations[slot.Symbol] = true + } + } for sym, tp := range overlay { if paramSet[sym] { - annotated[sym] = true + if explicitParamAnnotations[sym] { + annotated[sym] = true + } else if tp != nil && !typ.IsAny(tp) && !typ.IsAbsentOrUnknown(tp) && !typ.IsSoft(tp, typ.SoftAnnotationPolicy) { + annotated[sym] = true + } + continue + } + if localValueSeeds[sym] { continue } if tp != nil && !typ.IsSoft(tp, typ.SoftAnnotationPolicy) { @@ -356,13 +432,21 @@ func (i *Inferencer) inferLocalVariableTypes( } }) + inferenceOverlay := overlay + if len(localValueSeeds) > 0 { + inferenceOverlay = cloneOverlay(overlay, 0) + for sym := range localValueSeeds { + delete(inferenceOverlay, sym) + } + } + fnScopes := uniformFunctionScopes(fnGraph, ctx.resolveScope) prelimCtx := api.NewReturnInferenceEnv(api.ReturnInferenceEnvConfig{ Graph: fnGraph, Bindings: ctx.bindings, BaseScope: ctx.resolveScope, - DeclaredTypes: overlay, + DeclaredTypes: inferenceOverlay, GlobalTypes: i.globalTypes, ModuleAliases: ctx.moduleAliases, FunctionFacts: functionFactsFromReturnVectors(ctx.returnVectors), @@ -396,7 +480,7 @@ func (i *Inferencer) inferLocalVariableTypes( Derived: &fbcore.Derived{ SymResolver: symResolver, }, - }, overlay, annotated, nil) + }, inferenceOverlay, annotated, nil) return inferred, 
prelimEngine, synthAdapter } @@ -404,12 +488,13 @@ func (i *Inferencer) inferLocalVariableTypes( func (i *Inferencer) enrichOverlayWithLocalDeclarations( ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type, -) { +) map[cfg.SymbolID]bool { fnGraph := ctx.info.Graph if fnGraph == nil { - return + return nil } + localValueSeeds := make(map[cfg.SymbolID]bool) fnGraph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { if info == nil || !info.IsLocal { return @@ -461,12 +546,17 @@ func (i *Inferencer) enrichOverlayWithLocalDeclarations( if _, ok := info.Sources[idx].(*ast.TableExpr); ok { if seeded := ctx.engine.TypeOf(info.Sources[idx], p); seeded != nil { overlay[target.Symbol] = seeded + localValueSeeds[target.Symbol] = true } } } }) }) + if len(localValueSeeds) == 0 { + return nil + } + return localValueSeeds } type localSymbolLookup interface { @@ -476,6 +566,7 @@ type localSymbolLookup interface { type overlayMutationStage struct { fnGraph *cfg.Graph paramSyms map[cfg.SymbolID]bool + localValueSeeds map[cfg.SymbolID]bool finalOverlay map[cfg.SymbolID]typ.Type inferred map[cfg.SymbolID]typ.Type synthAdapter func(ast.Expr, cfg.Point) typ.Type @@ -488,10 +579,11 @@ func (i *Inferencer) collectAndApplyMutations( overlay map[cfg.SymbolID]typ.Type, inferred map[cfg.SymbolID]typ.Type, synthAdapter func(ast.Expr, cfg.Point) typ.Type, + localValueSeeds map[cfg.SymbolID]bool, ) map[cfg.SymbolID]typ.Type { - stage := newOverlayMutationStage(ctx, overlay, inferred, synthAdapter) - mergeInferredIntoOverlay(stage.finalOverlay, stage.inferred, stage.paramSyms) - stage.enrichedSynthAdapter = buildEnrichedSynthAdapter(stage.fnGraph.Bindings(), stage.inferred, stage.finalOverlay, stage.synthAdapter) + stage := newOverlayMutationStage(ctx, overlay, inferred, synthAdapter, localValueSeeds) + mergeInferredIntoOverlay(stage.finalOverlay, stage.inferred, stage.paramSyms, stage.localValueSeeds) + stage.enrichedSynthAdapter = buildEnrichedSynthAdapter(stage.fnGraph.Bindings(), 
stage.inferred, stage.finalOverlay, stage.localValueSeeds, i.types, ctx.run.Ctx, stage.synthAdapter) i.applyFieldMutations(ctx, &stage) i.applyIndexerMutations(&stage) @@ -505,17 +597,19 @@ func newOverlayMutationStage( overlay map[cfg.SymbolID]typ.Type, inferred map[cfg.SymbolID]typ.Type, synthAdapter func(ast.Expr, cfg.Point) typ.Type, + localValueSeeds map[cfg.SymbolID]bool, ) overlayMutationStage { fnGraph := (*cfg.Graph)(nil) if ctx != nil && ctx.info != nil { fnGraph = ctx.info.Graph } return overlayMutationStage{ - fnGraph: fnGraph, - paramSyms: paramSymbolSet(fnGraph), - finalOverlay: cloneOverlay(overlay, len(inferred)), - inferred: inferred, - synthAdapter: synthAdapter, + fnGraph: fnGraph, + paramSyms: paramSymbolSet(fnGraph), + localValueSeeds: localValueSeeds, + finalOverlay: cloneOverlay(overlay, len(inferred)), + inferred: inferred, + synthAdapter: synthAdapter, } } @@ -548,6 +642,7 @@ func mergeInferredIntoOverlay( finalOverlay map[cfg.SymbolID]typ.Type, inferred map[cfg.SymbolID]typ.Type, paramSyms map[cfg.SymbolID]bool, + localValueSeeds map[cfg.SymbolID]bool, ) { for sym, inferredType := range inferred { baseType := finalOverlay[sym] @@ -559,6 +654,10 @@ func mergeInferredIntoOverlay( } continue } + if localValueSeeds[sym] { + finalOverlay[sym] = inferredType + continue + } if typ.IsAbsentOrUnknown(baseType) { finalOverlay[sym] = inferredType continue @@ -683,15 +782,21 @@ func reconcileSoftAnnotatedInference(baseType, inferredType typ.Type) typ.Type { } func buildEnrichedSynthAdapter( - bindings localSymbolLookup, + bindings *bind.BindingTable, inferred map[cfg.SymbolID]typ.Type, finalOverlay map[cfg.SymbolID]typ.Type, + localValueSeeds map[cfg.SymbolID]bool, + typeOps core.TypeOps, + queryCtx *db.QueryContext, baseAdapter func(ast.Expr, cfg.Point) typ.Type, ) func(ast.Expr, cfg.Point) typ.Type { return func(expr ast.Expr, p cfg.Point) typ.Type { if ident, ok := expr.(*ast.IdentExpr); ok && bindings != nil { if sym, found := 
bindings.SymbolOf(ident); found && sym != 0 { if t, exists := inferred[sym]; exists && !typ.IsAbsentOrUnknown(t) { + if localValueSeeds[sym] { + return t + } if baseType := finalOverlay[sym]; baseType != nil && !typ.IsSoft(baseType, typ.SoftAnnotationPolicy) { return baseType } @@ -699,6 +804,9 @@ func buildEnrichedSynthAdapter( } } } + if t, ok := overlayPathType(expr, finalOverlay, bindings, typeOps, queryCtx); ok && !typ.IsAbsentOrUnknown(t) { + return t + } return baseAdapter(expr, p) } } diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go index 0689a60e..40e6d948 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -4,6 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/infer/paramhints" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/typ" @@ -31,6 +32,7 @@ func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns. // Propagate inter-procedural parameter hints across local call edges before // SCC return inference so unannotated params get stable callsite-driven seeds. returns.PropagateParamHintsFromCallGraph(localFuncs) + projectLocalFunctionParamHints(localFuncs) var moduleBindings *bind.BindingTable if i != nil && i.store != nil { @@ -40,6 +42,16 @@ func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns. 
return returns.ComputeSymbolSCCs(adj) } +func projectLocalFunctionParamHints(localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo) { + for _, sym := range cfg.SortedSymbolIDs(localFuncs) { + info := localFuncs[sym] + if info == nil || len(info.ParamHints) == 0 { + continue + } + info.ParamHints = paramhints.ProjectHintsToParamUse(info.Graph, info.Fn, info.ParamHints) + } +} + func seedReturnVectorsFromSeed( localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, seed map[cfg.SymbolID][]typ.Type, diff --git a/compiler/check/nested/constructor.go b/compiler/check/nested/constructor.go index 0cfae6dc..8b09f48b 100644 --- a/compiler/check/nested/constructor.go +++ b/compiler/check/nested/constructor.go @@ -224,7 +224,16 @@ func CollectConstructorFields(graph *cfg.Graph, selfSym cfg.SymbolID, synth func fields := assign.CollectFieldAssignments(graph, synth, filterSyms) if selfFields, ok := fields[selfSym]; ok && len(selfFields) > 0 { - return selfFields + filtered := make(map[string]typ.Type, len(selfFields)) + for name, t := range selfFields { + if typ.IsAbsentOrUnknown(t) { + continue + } + filtered[name] = t + } + if len(filtered) > 0 { + return filtered + } } return nil } diff --git a/compiler/check/nested/enrich.go b/compiler/check/nested/enrich.go index a258fed9..8d4081e4 100644 --- a/compiler/check/nested/enrich.go +++ b/compiler/check/nested/enrich.go @@ -2,6 +2,7 @@ package nested import ( "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" @@ -122,45 +123,61 @@ func CollectCapturedContainerMutations( return } - ceu := mutator.ContainerMutatorFromCall(info, p, synth, nil, nil, graph, bindings, nil) - if ceu == nil { - return - } - - targetExpr := callsite.RuntimeArgAt(info, ceu.Container.Index) - valueExpr := callsite.RuntimeArgAt(info, ceu.Value.Index) - if targetExpr == nil || valueExpr == nil { 
- return - } - - targetPath := flowpath.FromExprWithBindings(targetExpr, nil, bindings) - if targetPath.IsEmpty() || targetPath.Symbol == 0 { - return - } - if !capturedSyms[targetPath.Symbol] { - return + if ceu := mutator.ContainerMutatorFromCall(info, p, synth, nil, nil, graph, bindings, nil); ceu != nil { + targetExpr := callsite.RuntimeArgAt(info, ceu.Container.Index) + valueExpr := callsite.RuntimeArgAt(info, ceu.Value.Index) + collectCapturedContainerMutation(result, bindings, capturedSyms, targetExpr, valueExpr, p, synth, api.ContainerMutationContainerElement) } - var valueType typ.Type - if synth != nil { - valueType = synth(valueExpr, p) + if tm := mutator.TableMutatorFromCall(info, p, synth, nil, graph, bindings, nil); tm != nil { + targetExpr := callsite.RuntimeArgAt(info, tm.Target.Index) + valueExpr := callsite.RuntimeArgAt(info, tm.Value.Index) + collectCapturedContainerMutation(result, bindings, capturedSyms, targetExpr, valueExpr, p, synth, api.ContainerMutationTableElement) } - if valueType == nil { - valueType = typ.Unknown - } - valueType = subtype.WidenForInference(valueType) - - segs := make([]constraint.Segment, len(targetPath.Segments)) - copy(segs, targetPath.Segments) - result[targetPath.Symbol] = append(result[targetPath.Symbol], api.ContainerMutation{ - Segments: segs, - ValueType: valueType, - }) }) return result } +func collectCapturedContainerMutation( + result map[cfg.SymbolID][]api.ContainerMutation, + bindings *bind.BindingTable, + capturedSyms map[cfg.SymbolID]bool, + targetExpr ast.Expr, + valueExpr ast.Expr, + p cfg.Point, + synth func(ast.Expr, cfg.Point) typ.Type, + kind api.ContainerMutationKind, +) { + if result == nil || targetExpr == nil || valueExpr == nil { + return + } + targetPath := flowpath.FromExprWithBindings(targetExpr, nil, bindings) + if targetPath.IsEmpty() || targetPath.Symbol == 0 { + return + } + if !capturedSyms[targetPath.Symbol] { + return + } + + var valueType typ.Type + if synth != nil { + valueType = 
synth(valueExpr, p) + } + if valueType == nil { + valueType = typ.Unknown + } + valueType = subtype.WidenForInference(valueType) + + segs := make([]constraint.Segment, len(targetPath.Segments)) + copy(segs, targetPath.Segments) + result[targetPath.Symbol] = append(result[targetPath.Symbol], api.ContainerMutation{ + Kind: kind, + Segments: segs, + ValueType: valueType, + }) +} + // EnrichSelfTypeWithConstructorFields merges constructor instance fields into a self-type. // // When a method is defined on a class that has a constructor, the self-type should @@ -213,10 +230,14 @@ func mergeFieldsIntoSelfType(selfType typ.Type, fields map[string]typ.Type) typ. existingFields := make(map[string]bool) for _, f := range v.Fields { + fieldType := f.Type + if constructorType := fields[f.Name]; constructorType != nil && (typ.IsAbsentOrUnknown(fieldType) || typ.IsAny(fieldType)) { + fieldType = constructorType + } if f.Optional { - builder.OptField(f.Name, f.Type) + builder.OptField(f.Name, fieldType) } else { - builder.Field(f.Name, f.Type) + builder.Field(f.Name, fieldType) } existingFields[f.Name] = true } diff --git a/compiler/check/nested/enrich_test.go b/compiler/check/nested/enrich_test.go index b496afb4..5cb2f269 100644 --- a/compiler/check/nested/enrich_test.go +++ b/compiler/check/nested/enrich_test.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/parse" "github.com/wippyai/go-lua/types/contract" "github.com/wippyai/go-lua/types/effect" @@ -104,6 +105,48 @@ func TestCollectCapturedContainerMutations_AssignmentCallSite(t *testing.T) { if len(muts) != 1 { t.Fatalf("expected 1 container mutation for c, got %d", len(muts)) } + if muts[0].Kind != api.ContainerMutationContainerElement { + t.Fatalf("expected generic container mutation kind, got %v", muts[0].Kind) + } + if !typ.TypeEquals(muts[0].ValueType, typ.Integer) { + 
t.Fatalf("expected integer mutation value, got %v", muts[0].ValueType) + } +} + +func TestCollectCapturedContainerMutations_TableInsertCallSite(t *testing.T) { + code := ` + local c = {} + local _ = table.insert(c.items, 1) + ` + stmts, err := parse.ParseString(code, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + fn := &ast.FunctionExpr{ + ParList: &ast.ParList{HasVargs: true}, + Stmts: stmts, + } + graph := cfg.Build(fn, "table") + if graph == nil { + t.Fatal("expected graph") + } + symC, ok := graph.SymbolAt(graph.Exit(), "c") + if !ok || symC == 0 { + t.Fatal("expected symbol for c") + } + + captured := map[cfg.SymbolID]bool{symC: true} + result := CollectCapturedContainerMutations(graph, captured, nestedTableInsertSynth()) + muts := result[symC] + if len(muts) != 1 { + t.Fatalf("expected 1 table mutation for c, got %d", len(muts)) + } + if muts[0].Kind != api.ContainerMutationTableElement { + t.Fatalf("expected table mutation kind, got %v", muts[0].Kind) + } + if len(muts[0].Segments) != 1 || muts[0].Segments[0].Name != "items" { + t.Fatalf("expected .items mutation path, got %#v", muts[0].Segments) + } if !typ.TypeEquals(muts[0].ValueType, typ.Integer) { t.Fatalf("expected integer mutation value, got %v", muts[0].ValueType) } @@ -136,3 +179,30 @@ func nestedContainerSendSynth() func(ast.Expr, cfg.Point) typ.Type { return typ.Unknown } } + +func nestedTableInsertSynth() func(ast.Expr, cfg.Point) typ.Type { + spec := contract.NewSpec().WithEffects(effect.TableMutator{ + Target: effect.ParamRef{Index: 0}, + Value: effect.ParamRef{Index: 1}, + }) + insert := typ.Func(). + Param("target", typ.Any). + Param("value", typ.Any). + Returns(typ.Nil). + Spec(spec). 
+ Build() + + return func(expr ast.Expr, _ cfg.Point) typ.Type { + switch v := expr.(type) { + case *ast.AttrGetExpr: + obj, objOK := v.Object.(*ast.IdentExpr) + key, keyOK := v.Key.(*ast.StringExpr) + if objOK && keyOK && obj.Value == "table" && key.Value == "insert" { + return insert + } + case *ast.NumberExpr: + return typ.Integer + } + return typ.Unknown + } +} diff --git a/compiler/check/phase/scope.go b/compiler/check/phase/scope.go index fa4aadaa..e25d94cd 100644 --- a/compiler/check/phase/scope.go +++ b/compiler/check/phase/scope.go @@ -34,6 +34,7 @@ import ( basecfg "github.com/wippyai/go-lua/types/cfg" "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/flow" + "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -108,6 +109,7 @@ func RunScope(input ScopeInput) ScopeOutput { depthExceeded := false base := BuildFunctionScope(input.Fn, input.Parent, typeExprResolver, input.MaxScopeDepth, &depthExceeded) + base = normalizeBaseImplicitSelf(input.Graph, base) var synthSig *typ.Function if input.SynthesizedFunctionSig != nil { @@ -117,6 +119,7 @@ func RunScope(input ScopeInput) ScopeOutput { var hints []typ.Type if input.ParamHintSignatures != nil && input.Fn != nil { hints = input.ParamHintSignatures[input.Fn] + hints = paramhints.ProjectHintsToParamUse(input.Graph, input.Fn, hints) } paramTypes, paramAnnotated := ExtractParamTypes(input.Graph, input.Fn, typeExprResolver, synthSig, base, hints) @@ -132,7 +135,7 @@ func RunScope(input ScopeInput) ScopeOutput { } if i < len(synthSig.Params) && synthSig.Params[i].Type != nil { if name == "self" && base.SelfType() == nil { - base = base.WithSelf(synthSig.Params[i].Type) + base = base.WithSelf(widenImplicitSelfState(synthSig.Params[i].Type)) } } } @@ -189,7 +192,15 @@ func RunScope(input ScopeInput) ScopeOutput { exprSynth := func(expr ast.Expr, p cfg.Point, sc *scope.State) 
typ.Type { return typeResolutionEngine.SynthExprAt(expr, p, sc) } - fnSignatureResolver := buildFnSignatureResolver(input.FunctionLiteralSignatures, input.ParamHintSignatures, typeResolutionEngine) + paramHintSignatures := input.ParamHintSignatures + if input.Fn != nil && hints != nil && input.ParamHintSignatures != nil { + paramHintSignatures = make(map[*ast.FunctionExpr][]typ.Type, len(input.ParamHintSignatures)) + for fn, hintVec := range input.ParamHintSignatures { + paramHintSignatures[fn] = hintVec + } + paramHintSignatures[input.Fn] = hints + } + fnSignatureResolver := buildFnSignatureResolver(input.FunctionLiteralSignatures, paramHintSignatures, typeResolutionEngine) callMutator := buildCallMutator(input.Types, input.Ctx, exprSynth) services := ScopeServicesFuncs{ @@ -226,6 +237,19 @@ func RunScope(input ScopeInput) ScopeOutput { } } +func normalizeBaseImplicitSelf(graph *cfg.Graph, base *scope.State) *scope.State { + if graph == nil || base == nil || base.SelfType() == nil { + return base + } + for _, slot := range graph.ParamSlotsReadOnly() { + if slot.Name != "self" || slot.TypeAnnotation != nil { + continue + } + return base.WithSelf(widenImplicitSelfState(base.SelfType())) + } + return base +} + // buildFnSignatureResolver creates a function signature resolver that combines // pre-computed literal signatures, parameter hints, and annotation-based resolution. func buildFnSignatureResolver( @@ -275,28 +299,33 @@ func ExtractParamTypes( annotated = make(map[cfg.SymbolID]bool) slots := graph.ParamSlotsReadOnly() - for _, slot := range slots { + for paramIdx, slot := range slots { if slot.Symbol == 0 || slot.Name == "" { continue } // Binder/CFG-injected implicit self parameter has no source annotation. 
- srcIdx, hasSource := slot.SourceParamIndex() + _, hasSource := slot.SourceParamIndex() + var hint typ.Type + if paramHints != nil && paramIdx < len(paramHints) { + hint = paramHints[paramIdx] + } if !hasSource { if base != nil && base.SelfType() != nil { types[slot.Symbol] = base.SelfType() + } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { + types[slot.Symbol] = synthSig.Params[paramIdx].Type + } else if hint != nil { + types[slot.Symbol] = hint } else { types[slot.Symbol] = typ.Unknown } + if slot.Name == "self" { + types[slot.Symbol] = widenImplicitSelfState(types[slot.Symbol]) + } continue } - i := srcIdx - var paramType typ.Type - var hint typ.Type - if paramHints != nil && i < len(paramHints) { - hint = paramHints[i] - } var isAnnotated bool var hasExplicitAnnotation bool if slot.TypeAnnotation != nil { @@ -308,8 +337,8 @@ func ExtractParamTypes( if typ.IsRefinableAnnotation(paramType) { if hint != nil { paramType = hint - } else if synthSig != nil && i < len(synthSig.Params) && synthSig.Params[i].Type != nil { - paramType = synthSig.Params[i].Type + } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { + paramType = synthSig.Params[paramIdx].Type } } else { isAnnotated = true @@ -317,9 +346,8 @@ func ExtractParamTypes( } } else if hint != nil { paramType = hint - } else if synthSig != nil && i < len(synthSig.Params) && synthSig.Params[i].Type != nil { - paramType = synthSig.Params[i].Type - isAnnotated = true + } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { + paramType = synthSig.Params[paramIdx].Type } else if slot.Name == "self" && base != nil && base.SelfType() != nil { paramType = base.SelfType() } else { @@ -333,6 +361,9 @@ func ExtractParamTypes( paramType = base.SelfType() } } + if slot.Name == "self" && !hasExplicitAnnotation { + paramType = widenImplicitSelfState(paramType) + } 
types[slot.Symbol] = paramType if isAnnotated { @@ -346,6 +377,63 @@ func ExtractParamTypes( return types, annotated } +func widenImplicitSelfState(t typ.Type) typ.Type { + rec, ok := t.(*typ.Record) + if !ok { + return t + } + builder := typ.NewRecord() + if rec.Open { + builder.SetOpen(true) + } + for _, f := range rec.Fields { + fieldType := widenImplicitSelfField(f.Type) + switch { + case f.Optional && f.Readonly: + builder.OptReadonlyField(f.Name, fieldType) + case f.Optional: + builder.OptField(f.Name, fieldType) + case f.Readonly: + builder.ReadonlyField(f.Name, fieldType) + default: + builder.Field(f.Name, fieldType) + } + } + if rec.Metatable != nil { + builder.Metatable(rec.Metatable) + } + if rec.HasMapComponent() { + builder.MapComponent(rec.MapKey, rec.MapValue) + } + return builder.Build() +} + +func widenImplicitSelfField(t typ.Type) typ.Type { + if t == nil { + return typ.Unknown + } + unaliased := unwrap.Alias(t) + if unaliased == nil { + return typ.Unknown + } + if unaliased.Kind() == kind.Nil { + return typ.Unknown + } + if lit, ok := unaliased.(*typ.Literal); ok { + switch lit.Base { + case kind.Boolean: + return typ.Boolean + case kind.String: + return typ.String + case kind.Integer: + return typ.Integer + case kind.Number: + return typ.Number + } + } + return t +} + // buildDeclaredTypes builds declared types from annotations. 
func buildDeclaredTypes( graph *cfg.Graph, diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index 64c35bd3..071bfd3f 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -13,6 +13,7 @@ import ( "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) func (r *Runner) resolveSynthesizedSignature( @@ -27,14 +28,14 @@ func (r *Runner) resolveSynthesizedSignature( return nil } - // Prefer literal signature from parent graph for nested functions. + factSig := paramhints.ProjectSignatureToParamUse(graph, fn, r.functionFactSignatureForFunction(store, graph, fn)) synthSig := r.literalSignatureForFunction(store, graph, fn) if paramHintSigs == nil { - return synthSig + return mergeSynthesizedSignatureFact(synthSig, factSig) } hints := paramHintSigs[fn] if len(hints) == 0 { - return synthSig + return mergeSynthesizedSignatureFact(synthSig, factSig) } if synthSig == nil { engine := synth.New(synth.Config{ @@ -50,9 +51,50 @@ func (r *Runner) resolveSynthesizedSignature( } } if synthSig == nil { + return factSig + } + return mergeSynthesizedSignatureFact(paramhints.MergeIntoSignature(fn, hints, synthSig), factSig) +} + +func mergeSynthesizedSignatureFact(seed, fact *typ.Function) *typ.Function { + if seed == nil { + return fact + } + if fact == nil { + return seed + } + if merged := unwrap.Function(returns.MergeFunctionFactType(seed, fact)); merged != nil { + return merged + } + return seed +} + +func (r *Runner) functionFactSignatureForFunction( + store api.StoreReader, + graph *cfg.Graph, + fn *ast.FunctionExpr, +) *typ.Function { + if store == nil || graph == nil || fn == nil { + return nil + } + sym, ok := store.SymbolForFunc(fn) + if !ok || sym == 0 { + return nil + } + meta, ok := store.NestedMetaFor(graph.ID()) + if !ok || meta.ParentGraphID == 0 { + return nil + } + 
parentGraph := store.Graphs()[meta.ParentGraphID] + if parentGraph == nil { return nil } - return paramhints.MergeIntoSignature(fn, hints, synthSig) + parentScope := r.parentScopeForGraph(store, parentGraph) + if parentScope == nil { + return nil + } + facts := store.GetFunctionFactsSnapshot(parentGraph, parentScope) + return unwrap.Function(facts.FunctionType(sym)) } func (r *Runner) appendCapturedMutatorAssignments( @@ -102,11 +144,13 @@ func (r *Runner) appendCapturedMutatorAssignments( return resolve.CalleeType(info, p, synthEngine.TypeOf, symResolver, assignmentTypes, graph, bindings, env.ModuleBindings) } - extra := returns.CollectCalledNestedContainerMutatorAssignments(graph, bindings, capturedContainers, calleeTypeResolver) - if len(extra) == 0 { - return + extra := returns.CollectNestedMutatorAssignments(graph, bindings, capturedContainers, calleeTypeResolver) + if len(extra.Table) > 0 { + extractOut.Inputs.TableMutatorAssignments = append(extractOut.Inputs.TableMutatorAssignments, extra.Table...) + } + if len(extra.Container) > 0 { + extractOut.Inputs.ContainerMutatorAssignments = append(extractOut.Inputs.ContainerMutatorAssignments, extra.Container...) } - extractOut.Inputs.ContainerMutatorAssignments = append(extractOut.Inputs.ContainerMutatorAssignments, extra...) 
} func (r *Runner) runComputePasses(graph *cfg.Graph, scopes map[cfg.Point]*scope.State) map[string]any { diff --git a/compiler/check/returns/callgraph.go b/compiler/check/returns/callgraph.go index d3a48e30..ac336084 100644 --- a/compiler/check/returns/callgraph.go +++ b/compiler/check/returns/callgraph.go @@ -113,12 +113,11 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo if info.Graph == nil { continue } - for _, slot := range info.Graph.ParamSlotsReadOnly() { - srcIdx, hasSource := slot.SourceParamIndex() - if !hasSource || slot.Symbol == 0 { + for idx, slot := range info.Graph.ParamSlotsReadOnly() { + if slot.Symbol == 0 { continue } - paramOwner[slot.Symbol] = paramRef{owner: info, index: srcIdx} + paramOwner[slot.Symbol] = paramRef{owner: info, index: idx} } } diff --git a/compiler/check/returns/callsite.go b/compiler/check/returns/callsite.go index a0f73d3e..2daa6010 100644 --- a/compiler/check/returns/callsite.go +++ b/compiler/check/returns/callsite.go @@ -3,6 +3,7 @@ package returns import ( "sort" + "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" @@ -91,20 +92,31 @@ func CollectCalledNestedFieldAssignments( return result } -// CollectCalledNestedContainerMutatorAssignments collects container mutations recorded for -// called nested functions that target symbols from the parent graph (captured variables). +// CalledNestedMutatorAssignments is the flow replay payload for captured +// mutations made by called nested functions. +type CalledNestedMutatorAssignments struct { + Table []flow.TableMutatorAssignment + Container []flow.ContainerMutatorAssignment +} + +// CollectNestedMutatorAssignments collects captured mutations recorded for +// parent-visible nested functions and replays them through the matching flow +// operator. 
// -// This supports cases where a nested function mutates a captured container (e.g., channel.send) -// and the nested function is invoked directly or passed as a callback to a function with a -// callback spec (e.g., coroutine.spawn). -func CollectCalledNestedContainerMutatorAssignments( +// This supports cases where a nested function mutates a captured table +// (table.insert) or generic container (channel.send) and the nested function is: +// - invoked directly, +// - passed as a callback to a function with a callback spec, or +// - stored in a field/global position that can be called outside the parent +// graph before another exported function reads the captured state. +func CollectNestedMutatorAssignments( parent *cfg.Graph, bindings *bind.BindingTable, capturedByCallee api.CapturedContainerMutations, resolveCalleeType func(*cfg.CallInfo, cfg.Point) typ.Type, -) []flow.ContainerMutatorAssignment { +) CalledNestedMutatorAssignments { if parent == nil || len(capturedByCallee) == 0 { - return nil + return CalledNestedMutatorAssignments{} } parentSymbols := parent.AllSymbolIDs() @@ -112,7 +124,21 @@ func CollectCalledNestedContainerMutatorAssignments( for calleeSym := range capturedByCallee { trackedCallees[calleeSym] = true } - assignments := make([]flow.ContainerMutatorAssignment, 0) + assignments := CalledNestedMutatorAssignments{} + emitForCallee := func(p cfg.Point, sym cfg.SymbolID) { + nestedMutations := capturedByCallee[sym] + if len(nestedMutations) == 0 { + return + } + for _, targetSym := range cfg.SortedSymbolIDs(nestedMutations) { + mutations := nestedMutations[targetSym] + if !parentSymbols[targetSym] { + continue + } + root := resolve.RootNameFromGraphAndBindings(parent, bindings, targetSym, "") + appendNestedMutatorAssignments(&assignments, p, root, targetSym, mutations) + } + } parent.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil { @@ -127,36 +153,138 @@ func CollectCalledNestedContainerMutatorAssignments( } for _, sym := 
range cfg.SortedSymbolIDs(calledSyms) { - nestedMutations := capturedByCallee[sym] - if len(nestedMutations) == 0 { - continue - } - for _, targetSym := range cfg.SortedSymbolIDs(nestedMutations) { - mutations := nestedMutations[targetSym] - if !parentSymbols[targetSym] { - continue - } - root := resolve.RootNameFromGraphAndBindings(parent, bindings, targetSym, "") - for _, mutation := range mutations { - segs := make([]constraint.Segment, len(mutation.Segments)) - copy(segs, mutation.Segments) - assignments = append(assignments, flow.ContainerMutatorAssignment{ - Point: p, - Target: constraint.Path{ - Root: root, - Symbol: targetSym, - Segments: segs, - }, - ValueType: mutation.ValueType, - }) - } - } + emitForCallee(p, sym) } }) + for _, trigger := range escapedNestedMutationTriggers(parent, bindings, trackedCallees) { + emitForCallee(trigger.Point, trigger.Symbol) + } + return assignments } +func appendNestedMutatorAssignments( + assignments *CalledNestedMutatorAssignments, + p cfg.Point, + root string, + targetSym cfg.SymbolID, + mutations []api.ContainerMutation, +) { + if assignments == nil || targetSym == 0 || len(mutations) == 0 { + return + } + for _, mutation := range mutations { + segs := make([]constraint.Segment, len(mutation.Segments)) + copy(segs, mutation.Segments) + target := constraint.Path{ + Root: root, + Symbol: targetSym, + Segments: segs, + } + switch mutation.Kind { + case api.ContainerMutationTableElement: + assignments.Table = append(assignments.Table, flow.TableMutatorAssignment{ + Point: p, + Target: target, + ValueType: mutation.ValueType, + }) + default: + assignments.Container = append(assignments.Container, flow.ContainerMutatorAssignment{ + Point: p, + Target: target, + ValueType: mutation.ValueType, + }) + } + } +} + +type nestedMutationTrigger struct { + Point cfg.Point + Symbol cfg.SymbolID +} + +func escapedNestedMutationTriggers( + parent *cfg.Graph, + bindings *bind.BindingTable, + trackedCallees map[cfg.SymbolID]bool, +) 
[]nestedMutationTrigger { + if parent == nil || len(trackedCallees) == 0 { + return nil + } + var triggers []nestedMutationTrigger + seen := make(map[nestedMutationTrigger]bool) + appendTrigger := func(p cfg.Point, sym cfg.SymbolID) { + if sym == 0 || !trackedCallees[sym] { + return + } + trigger := nestedMutationTrigger{Point: p, Symbol: sym} + if seen[trigger] { + return + } + seen[trigger] = true + triggers = append(triggers, trigger) + } + + parent.EachFuncDef(func(p cfg.Point, info *cfg.FuncDefInfo) { + if info == nil || !funcDefEscapesParent(info.TargetKind) { + return + } + appendTrigger(p, info.Symbol) + }) + + parent.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info == nil { + return + } + info.EachTargetSource(func(i int, target cfg.AssignTarget, src ast.Expr) { + if !assignmentTargetEscapesFunction(target) { + return + } + appendTrigger(p, assignmentSourceFunctionSymbol(info, i, src, bindings)) + }) + }) + + return triggers +} + +func funcDefEscapesParent(kind cfg.FuncDefTargetKind) bool { + switch kind { + case cfg.FuncDefGlobal, cfg.FuncDefField, cfg.FuncDefMethod: + return true + default: + return false + } +} + +func assignmentTargetEscapesFunction(target cfg.AssignTarget) bool { + switch target.Kind { + case cfg.TargetField, cfg.TargetIndex: + return true + default: + return false + } +} + +func assignmentSourceFunctionSymbol( + info *cfg.AssignInfo, + i int, + src ast.Expr, + bindings *bind.BindingTable, +) cfg.SymbolID { + if info != nil && i >= 0 && i < len(info.SourceSymbols) { + if sym := info.SourceSymbols[i]; sym != 0 { + return sym + } + } + if fn, ok := src.(*ast.FunctionExpr); ok && bindings != nil { + if sym, found := bindings.FuncLitSymbol(fn); found { + return sym + } + } + return 0 +} + func calledSymbolsFromCall( info *cfg.CallInfo, p cfg.Point, diff --git a/compiler/check/returns/callsite_test.go b/compiler/check/returns/callsite_test.go index ef233ec1..79c60786 100644 --- a/compiler/check/returns/callsite_test.go +++ 
b/compiler/check/returns/callsite_test.go @@ -6,8 +6,11 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/parse" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/typ" ) func TestCollectCalledNestedFieldAssignments(t *testing.T) { @@ -19,15 +22,135 @@ func TestCollectCalledNestedFieldAssignments(t *testing.T) { }) } -func TestCollectCalledNestedContainerMutatorAssignments(t *testing.T) { +func TestCollectNestedMutatorAssignments(t *testing.T) { t.Run("nil graph returns empty slice", func(t *testing.T) { - result := CollectCalledNestedContainerMutatorAssignments(nil, nil, nil, nil) - if len(result) != 0 { + result := CollectNestedMutatorAssignments(nil, nil, nil, nil) + if len(result.Table) != 0 || len(result.Container) != 0 { t.Error("expected empty result") } }) } +func TestCollectNestedMutatorAssignments_SplitsOperatorKinds(t *testing.T) { + stmts, err := parse.ParseString(` + local state = {} + local function setup() + return nil + end + setup() + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + graph := cfg.Build(&ast.FunctionExpr{Stmts: stmts}) + if graph == nil { + t.Fatal("expected graph") + } + bindings := graph.Bindings() + if bindings == nil { + t.Fatal("expected bindings") + } + stateSym, ok := graph.SymbolAt(graph.Exit(), "state") + if !ok || stateSym == 0 { + t.Fatal("expected symbol for state") + } + setupSym, ok := graph.SymbolAt(graph.Exit(), "setup") + if !ok || setupSym == 0 { + t.Fatal("expected symbol for setup") + } + + captured := api.CapturedContainerMutations{ + setupSym: { + stateSym: { + { + Kind: api.ContainerMutationTableElement, + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "items"}}, + ValueType: typ.String, + }, + { + 
Kind: api.ContainerMutationContainerElement, + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "channel"}}, + ValueType: typ.Number, + }, + }, + }, + } + + got := CollectNestedMutatorAssignments(graph, bindings, captured, nil) + if len(got.Table) != 1 { + t.Fatalf("table assignments = %d, want 1", len(got.Table)) + } + if len(got.Container) != 1 { + t.Fatalf("container assignments = %d, want 1", len(got.Container)) + } + if got.Table[0].Target.Symbol != stateSym || len(got.Table[0].Target.Segments) != 1 || got.Table[0].Target.Segments[0].Name != "items" { + t.Fatalf("unexpected table target: %#v", got.Table[0].Target) + } + if got.Container[0].Target.Symbol != stateSym || len(got.Container[0].Target.Segments) != 1 || got.Container[0].Target.Segments[0].Name != "channel" { + t.Fatalf("unexpected container target: %#v", got.Container[0].Target) + } +} + +func TestCollectNestedMutatorAssignments_ReplaysExportedFieldFunction(t *testing.T) { + stmts, err := parse.ParseString(` + local api = {} + local state = {} + function api.add() + return nil + end + `, "test.lua") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + graph := cfg.Build(&ast.FunctionExpr{Stmts: stmts}) + if graph == nil { + t.Fatal("expected graph") + } + bindings := graph.Bindings() + if bindings == nil { + t.Fatal("expected bindings") + } + stateSym, ok := graph.SymbolAt(graph.Exit(), "state") + if !ok || stateSym == 0 { + t.Fatal("expected symbol for state") + } + + var addSym cfg.SymbolID + var addPoint cfg.Point + graph.EachFuncDef(func(p cfg.Point, info *cfg.FuncDefInfo) { + if info != nil && info.Name == "add" { + addSym = info.Symbol + addPoint = p + } + }) + if addSym == 0 { + t.Fatal("expected symbol for api.add") + } + + captured := api.CapturedContainerMutations{ + addSym: { + stateSym: { + { + Kind: api.ContainerMutationTableElement, + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "items"}}, + ValueType: typ.String, + }, + }, + }, + } + + 
got := CollectNestedMutatorAssignments(graph, bindings, captured, nil) + if len(got.Table) != 1 { + t.Fatalf("table assignments = %d, want 1", len(got.Table)) + } + if got.Table[0].Point != addPoint { + t.Fatalf("assignment point = %d, want exported definition point %d", got.Table[0].Point, addPoint) + } + if got.Table[0].Target.Symbol != stateSym || got.Table[0].Target.Segments[0].Name != "items" { + t.Fatalf("unexpected exported table target: %#v", got.Table[0].Target) + } +} + func TestRuntimeArgAt(t *testing.T) { t.Run("direct call positional mapping", func(t *testing.T) { a := &ast.NumberExpr{Value: "1"} diff --git a/compiler/check/returns/container_mutation_merge_test.go b/compiler/check/returns/container_mutation_merge_test.go index e2ff9af4..f9370695 100644 --- a/compiler/check/returns/container_mutation_merge_test.go +++ b/compiler/check/returns/container_mutation_merge_test.go @@ -37,11 +37,11 @@ func TestMergeContainerMutationSlices_DedupAndSorted(t *testing.T) { if len(got) != 2 { t.Fatalf("len(got) = %d, want 2", len(got)) } - if k := api.ContainerMutationKey(got[0]); k != ".a" { - t.Fatalf("first key = %q, want .a", k) + if k := api.ContainerMutationKey(got[0]); k != "container:.a" { + t.Fatalf("first key = %q, want container:.a", k) } - if k := api.ContainerMutationKey(got[1]); k != ".b" { - t.Fatalf("second key = %q, want .b", k) + if k := api.ContainerMutationKey(got[1]); k != "container:.b" { + t.Fatalf("second key = %q, want container:.b", k) } if !typ.TypeEquals(got[1].ValueType, typ.Number) { t.Fatalf(".b merged type = %v, want number", got[1].ValueType) @@ -79,7 +79,32 @@ func TestMergeCapturedContainerMutationMaps_MergeBySymbol(t *testing.T) { if len(got[1]) != 1 || len(got[2]) != 1 { t.Fatalf("unexpected per-symbol merge sizes: sym1=%d sym2=%d", len(got[1]), len(got[2])) } - if key := api.ContainerMutationKey(got[2][0]); key != ".y" { - t.Fatalf("sym2 key = %q, want .y", key) + if key := api.ContainerMutationKey(got[2][0]); key != 
"container:.y" { + t.Fatalf("sym2 key = %q, want container:.y", key) + } +} + +func TestMergeContainerMutationSlices_KeepsOperatorKindsDistinct(t *testing.T) { + existing := []api.ContainerMutation{ + { + Kind: api.ContainerMutationContainerElement, + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "items"}}, + ValueType: typ.Number, + }, + } + next := []api.ContainerMutation{ + { + Kind: api.ContainerMutationTableElement, + Segments: []constraint.Segment{{Kind: constraint.SegmentField, Name: "items"}}, + ValueType: typ.String, + }, + } + + got := MergeContainerMutationSlices(existing, next, nil) + if len(got) != 2 { + t.Fatalf("len(got) = %d, want 2 distinct operator facts", len(got)) + } + if got[0].Kind == got[1].Kind { + t.Fatalf("expected separate facts for same path with different operators, got %#v", got) } } diff --git a/compiler/check/returns/domain_law_test.go b/compiler/check/returns/domain_law_test.go index 307a92ab..c5ba39c6 100644 --- a/compiler/check/returns/domain_law_test.go +++ b/compiler/check/returns/domain_law_test.go @@ -130,6 +130,32 @@ func TestFactsDomain_WidenFunctionParamsIsVarianceAware(t *testing.T) { } } +func TestFactsDomain_PreservesArityAndNilabilityAsSeparateParamAxes(t *testing.T) { + sym := cfg.SymbolID(1) + context := typ.NewRecord(). + MapComponent(typ.String, typ.Any). + SetOpen(true). 
+ Build() + raw := api.Facts{ + FunctionFacts: api.FunctionFacts{ + sym: {Type: typ.Func().OptParam("context", typ.NewOptional(context)).Build()}, + }, + } + + widened := WidenFacts(api.Facts{}, raw) + fn := unwrapFunctionForDomainTest(t, widened.FunctionFacts.FunctionType(sym)) + if len(fn.Params) != 1 || !fn.Params[0].Optional { + t.Fatalf("expected optional parameter slot, got %v", fn) + } + want := typ.NewOptional(context) + if !typ.TypeEquals(fn.Params[0].Type, want) { + t.Fatalf("expected explicit nilability to remain in the value type, got %v", fn.Params[0].Type) + } + if !FactsEqual(widened, WidenFacts(widened, raw)) { + t.Fatalf("expected optional parameter product-domain representation to be idempotent") + } +} + func TestFactsDomain_WidenPreservesCapturedCallbackUnionMembers(t *testing.T) { sym := cfg.SymbolID(9) withPending := typ.NewUnion( diff --git a/compiler/check/returns/equal_test.go b/compiler/check/returns/equal_test.go index 06540da8..41312e55 100644 --- a/compiler/check/returns/equal_test.go +++ b/compiler/check/returns/equal_test.go @@ -196,6 +196,26 @@ func TestCapturedContainerMutationsEqual_Basic(t *testing.T) { } } +func TestCapturedContainerMutationsEqual_DifferentOperatorKind(t *testing.T) { + a := api.CapturedContainerMutations{ + cfg.SymbolID(1): { + cfg.SymbolID(2): { + {Kind: api.ContainerMutationContainerElement, ValueType: typ.Number}, + }, + }, + } + b := api.CapturedContainerMutations{ + cfg.SymbolID(1): { + cfg.SymbolID(2): { + {Kind: api.ContainerMutationTableElement, ValueType: typ.Number}, + }, + }, + } + if CapturedContainerMutationsEqual(a, b) { + t.Error("same path with different mutation operators should not be equal") + } +} + func TestCapturedFieldAssignsEqual_CanonicalizesOptionalFunctionValues(t *testing.T) { fn := typ.Func().Param("fn", typ.Unknown).Build() left := api.CapturedFieldAssigns{1: {2: {"after_all": typ.NewOptional(fn)}}} diff --git a/compiler/check/returns/function_facts.go 
b/compiler/check/returns/function_facts.go index 87d10ee7..dfcb61a5 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -72,11 +72,14 @@ func normalizeFunctionFactMap(facts api.FunctionFacts) api.FunctionFacts { } func writeFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.FunctionFact) { + writeNormalizedFunctionFactToFacts(facts, sym, NormalizeFunctionFact(ff)) +} + +func writeNormalizedFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.FunctionFact) { if facts == nil || sym == 0 { return } - ff = NormalizeFunctionFact(ff) if functionFactEmpty(ff) { if facts.FunctionFacts != nil { delete(facts.FunctionFacts, sym) diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index bd4ff769..1ebeb3b1 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -125,6 +125,12 @@ func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { if ReturnTypesRepairNever(b, a) { return b, true } + if ReturnTypesRefineSoftContainers(a, b) { + return a, true + } + if ReturnTypesRefineSoftContainers(b, a) { + return b, true + } if ReturnTypesStopRecursiveStructuralGrowth(a, b) { return a, true } @@ -170,11 +176,105 @@ func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { return nil, false } -// ReturnTypesRefineFalsyMapKeys reports whether candidate is the same map-like -// shape as baseline after removing stale falsy key members from baseline. This -// handles fixed-point rounds where an early branch-insensitive dynamic index -// observes a key as `string | false`, then the solved guard proves the actual -// write key is `string`. +// ReturnTypesRefineSoftContainers reports whether candidate preserves the same +// table shape while replacing soft placeholder element/value evidence with +// concrete evidence. This is a summary-lattice rule only; it does not weaken +// mutable map subtyping. 
+func ReturnTypesRefineSoftContainers(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + refines, changed := typeRefinesSoftContainer(candidate[i], baseline[i]) + if !refines { + return false + } + if changed { + strict = true + } + } + return strict +} + +func typeRefinesSoftContainer(candidate, baseline typ.Type) (bool, bool) { + candidate = unwrapStructuralShape(candidate) + baseline = unwrapStructuralShape(baseline) + if candidate == nil || baseline == nil { + return candidate == baseline, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + switch b := baseline.(type) { + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return typeRefinesSoftContainerSlot(c.Element, b.Element) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok || !equivalentParamValueType(c.Key, b.Key) { + return false, false + } + return typeRefinesSoftContainerSlot(c.Value, b.Value) + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok || !sameRecordFrame(c, b) { + return false, false + } + if !c.HasMapComponent() && !b.HasMapComponent() { + return true, false + } + if !c.HasMapComponent() || !b.HasMapComponent() || !equivalentParamValueType(c.MapKey, b.MapKey) { + return false, false + } + return typeRefinesSoftContainerSlot(c.MapValue, b.MapValue) + default: + return false, false + } +} + +func typeRefinesSoftContainerSlot(candidate, baseline typ.Type) (bool, bool) { + if typ.TypeEquals(candidate, baseline) { + return true, false + } + if (typ.IsAny(baseline) || typ.IsUnknown(baseline)) && typeCanSelfEmbed(candidate) { + return false, false + } + preferred, ok := preferConcreteOverSoftType(baseline, candidate) + return ok && typ.TypeEquals(preferred, candidate), ok +} + +func sameRecordFrame(a, b *typ.Record) bool { + if a == nil || b == nil || a.Open != b.Open || 
len(a.Fields) != len(b.Fields) { + return false + } + if (a.Metatable == nil) != (b.Metatable == nil) { + return false + } + if a.Metatable != nil && !typ.TypeEquals(a.Metatable, b.Metatable) { + return false + } + for i, field := range a.Fields { + other := b.Fields[i] + if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { + return false + } + if !typ.TypeEquals(field.Type, other.Type) { + return false + } + } + return true +} + +// ReturnTypesRefineFalsyMapKeys reports whether candidate is the same +// table-derived shape as baseline after removing stale falsy members from +// baseline. This handles fixed-point rounds where an early branch-insensitive +// dynamic index observes a key as `string | false`, then the solved guard proves +// the actual write key is `string`. func ReturnTypesRefineFalsyMapKeys(candidate, baseline []typ.Type) bool { if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { return false @@ -203,6 +303,12 @@ func typeRefinesFalsyMapKey(candidate, baseline typ.Type) (bool, bool) { } switch b := baseline.(type) { + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return truthyElementRefinement(c.Element, b.Element) case *typ.Map: c, ok := candidate.(*typ.Map) if !ok { @@ -255,6 +361,16 @@ func mapKeyTruthyRefinement(candidateKey, candidateValue, baselineKey, baselineV return false, false } +func truthyElementRefinement(candidate, baseline typ.Type) (bool, bool) { + if typ.TypeEquals(candidate, baseline) { + return true, false + } + if typeIsTruthyRefinement(candidate, baseline) { + return true, true + } + return false, false +} + // ReturnTypesNestedNilOnlyRegression reports whether candidate's apparent // refinement only adds nested nil facts over a more useful baseline shape. 
A // required `nil` field or `unknown -> nil` field does not help callers, but it @@ -1331,24 +1447,42 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { existing = typ.PruneSoftUnionMembers(existing) candidate = typ.PruneSoftUnionMembers(candidate) + if unwrap.IsNilType(existing) && !unwrap.IsNilType(candidate) { + return candidate + } + if unwrap.IsNilType(candidate) && !unwrap.IsNilType(existing) { + return existing + } if preferred, ok := preferStructuredRecordParam(existing, candidate); ok { return preferred } + if preferred, ok := preferConcreteOverSoftType(existing, candidate); ok { + return preferred + } if typ.IsUnknown(existing) { return candidate } if typ.IsUnknown(candidate) { return existing } - if typ.IsAny(existing) && !typ.IsAny(candidate) { + if typ.IsAny(existing) && typ.IsAny(candidate) { + return typ.Any + } + if typ.IsAny(existing) { return candidate } - if typ.IsAny(candidate) && !typ.IsAny(existing) { + if typ.IsAny(candidate) { return existing } if typ.TypeEquals(existing, candidate) { return existing } + if candidateRefinesFunctionParam(candidate, existing) { + return candidate + } + if candidateRefinesFunctionParam(existing, candidate) { + return existing + } if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { return candidate } @@ -1358,6 +1492,75 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } +func candidateRefinesFunctionParam(candidate, baseline typ.Type) bool { + return typeElidesOptional(candidate, baseline) || + typeIsTruthyRefinement(candidate, baseline) || + typeRefinesTableKeyByTruthiness(candidate, baseline) +} + +func typeRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { + if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { + return false + } + candidateInner, _ := splitNilableParamHint(candidate) + baselineInner, _ := 
splitNilableParamHint(baseline) + if candidateInner == nil || baselineInner == nil { + return false + } + return nonNilTypeRefinesTableKeyByTruthiness(candidateInner, baselineInner) +} + +func nonNilTypeRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { + candidate = unwrap.Alias(candidate) + baseline = unwrap.Alias(baseline) + switch b := baseline.(type) { + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok { + return false + } + return recordRefinesTableKeyByTruthiness(c, b) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false + } + return typeIsTruthyRefinement(c.Key, b.Key) && equivalentParamValueType(c.Value, b.Value) + default: + return false + } +} + +func recordRefinesTableKeyByTruthiness(candidate, baseline *typ.Record) bool { + if candidate == nil || baseline == nil || !candidate.HasMapComponent() || !baseline.HasMapComponent() { + return false + } + if candidate.Open != baseline.Open || len(candidate.Fields) != len(baseline.Fields) { + return false + } + if (candidate.Metatable == nil) != (baseline.Metatable == nil) { + return false + } + if candidate.Metatable != nil && !typ.TypeEquals(candidate.Metatable, baseline.Metatable) { + return false + } + for i, field := range candidate.Fields { + other := baseline.Fields[i] + if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { + return false + } + if !equivalentParamValueType(field.Type, other.Type) { + return false + } + } + return typeIsTruthyRefinement(candidate.MapKey, baseline.MapKey) && + equivalentParamValueType(candidate.MapValue, baseline.MapValue) +} + +func equivalentParamValueType(a, b typ.Type) bool { + return typ.TypeEquals(a, b) || (subtype.IsSubtype(a, b) && subtype.IsSubtype(b, a)) +} + func preferStructuredRecordParam(existing, candidate typ.Type) (typ.Type, bool) { existingRec, okExisting := unwrap.Alias(existing).(*typ.Record) candidateRec, okCandidate := unwrap.Alias(candidate).(*typ.Record) 
diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index 6c110940..9ba2f9f7 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -135,6 +135,16 @@ func TestReturnTypesRefine_DifferentLength(t *testing.T) { } } +func TestMergeReturnSummary_ReplacesStaleFalsyKeyArrayElement(t *testing.T) { + stale := []typ.Type{typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))} + current := []typ.Type{typ.NewArray(typ.String)} + + got := MergeReturnSummary(stale, current) + if !ReturnTypesEqual(got, current) { + t.Fatalf("expected truthy-refined key array %v, got %v", current, got) + } +} + func TestReturnTypesExtendRecord_Empty(t *testing.T) { if ReturnTypesExtendRecord(nil, nil) { t.Error("empty vectors should not extend") @@ -288,7 +298,7 @@ func TestMergeFunctionFactType_MergesSameShapeReturnsCanonically(t *testing.T) { } } -func TestMergeFunctionFactType_PrefersConcreteParamOverSoftAny(t *testing.T) { +func TestMergeFunctionFactType_PrefersConcreteParamOverTopObservation(t *testing.T) { existing := typ.Func(). Param("x", typ.Any). Returns(typ.String). @@ -372,6 +382,17 @@ func TestMergeReturnSummary_PrefersCurrentTruthyMapKeyRefinement(t *testing.T) { } } +func TestMergeReturnSummary_PrefersConcreteMapValueOverSoftPlaceholder(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + baseline := typ.NewMap(typ.String, typ.NewArray(typ.Any)) + candidate := typ.NewMap(typ.String, typ.NewArray(entry)) + + merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { + t.Fatalf("expected concrete map value evidence %v, got %v", candidate, merged) + } +} + func TestMergeReturnSummary_PrefersCurrentTruthyRecordMapKeyRefinement(t *testing.T) { entryArray := typ.NewArray(typ.Unknown) baseline := typ.NewRecord(). 
@@ -510,8 +531,61 @@ func TestMergeFunctionFactType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t t.Fatalf("expected function, got %T", merged) } want := typ.NewOptional(typ.NewArray(typ.Any)) - if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, want) { - t.Fatalf("expected param type %v, got %+v", want, fn.Params) + if len(fn.Params) != 1 || !fn.Params[0].Optional || !typ.TypeEquals(fn.Params[0].Type, want) { + t.Fatalf("expected optional param slot with type %v, got %+v", want, fn.Params) + } +} + +func TestMergeFunctionFactType_NilDoesNotDominateSoftOptionalParamShape(t *testing.T) { + softArray := typ.NewOptional(typ.NewUnion(typ.NewArray(typ.Any), typ.NewRecord().SetOpen(true).Build())) + preciseArray := typ.NewOptional(typ.NewArray(typ.String)) + + merged := MergeFunctionFactType( + typ.Func().OptParam("tests", typ.Nil).Returns(typ.Integer).Build(), + typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), + ) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, softArray) { + t.Fatalf("expected nil observation not to replace soft optional table shape, got %v", fn.Params[0].Type) + } + + merged = MergeFunctionFactType( + typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), + typ.Func().OptParam("tests", preciseArray).Returns(typ.Integer).Build(), + ) + fn, ok = merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, preciseArray) { + t.Fatalf("expected precise optional array evidence to replace soft shape, got %v", fn.Params[0].Type) + } +} + +func TestMergeFunctionFactType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + stale := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.Boolean, typ.String), typ.NewArray(entry)). 
+ SetOpen(true). + Build() + current := typ.NewRecord(). + MapComponent(typ.String, typ.NewArray(entry)). + SetOpen(true). + Build() + + merged := MergeFunctionFactType( + typ.Func().OptParam("t", stale).Returns(typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))).Build(), + typ.Func().OptParam("t", current).Returns(typ.NewArray(typ.String)).Build(), + ) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, current) { + t.Fatalf("expected truthy-refined map key param %v, got %v", current, fn.Params[0].Type) } } diff --git a/compiler/check/returns/signature.go b/compiler/check/returns/signature.go index 55129ac6..c51a005f 100644 --- a/compiler/check/returns/signature.go +++ b/compiler/check/returns/signature.go @@ -69,6 +69,7 @@ func BuildSeedFunctionTypeWithBindings( Expected: nil, ImplicitSelf: implicitSelf, ImplicitSelfType: implicitSelfType, + UntypedParamType: typ.Any, }) if len(fn.ReturnTypes) > 0 { diff --git a/compiler/check/returns/types.go b/compiler/check/returns/types.go index 6dc9460c..83293c2c 100644 --- a/compiler/check/returns/types.go +++ b/compiler/check/returns/types.go @@ -60,8 +60,8 @@ type LocalFuncInfo struct { ParentGraph *cfg.Graph ParentFn *ast.FunctionExpr DefPoint cfg.Point - // ParamHints holds inferred parameter types from call sites in the parent graph. - // Index corresponds to parameter position. + // ParamHints holds inferred effective-parameter types from call sites in the + // parent graph. For methods, index 0 is self. 
ParamHints []typ.Type } diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index 898c6360..82f6777d 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -8,7 +8,6 @@ import ( "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" - typjoin "github.com/wippyai/go-lua/types/typ/join" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -32,7 +31,7 @@ func WidenFacts(prev, next api.Facts) api.Facts { for _, sym := range symbols { prevFact := readFunctionFactFromFacts(&prev, sym) nextFact := readFunctionFactFromFacts(&next, sym) - writeFunctionFactToFacts(&out, sym, widenFunctionFactForConvergence(prevFact, nextFact)) + writeNormalizedFunctionFactToFacts(&out, sym, widenFunctionFactForConvergence(prevFact, nextFact)) } if len(out.FunctionFacts) == 0 { out.FunctionFacts = nil @@ -60,7 +59,7 @@ func JoinFacts(prev, next api.Facts) api.Facts { for _, sym := range symbols { prevFact := readFunctionFactFromFacts(&prev, sym) nextFact := readFunctionFactFromFacts(&next, sym) - writeFunctionFactToFacts(&out, sym, JoinFunctionFact(prevFact, nextFact)) + writeNormalizedFunctionFactToFacts(&out, sym, JoinFunctionFact(prevFact, nextFact)) } return out } @@ -493,13 +492,13 @@ func WidenParamHints(prev, next api.ParamHints) api.ParamHints { } merged := make(api.ParamHints, len(prev)+len(next)) for _, sym := range cfg.SortedSymbolIDs(prev) { - hints := prev[sym] + hints := normalizeParamHintVector(prev[sym]) if hasNonNilHint(hints) { merged[sym] = hints } } for _, sym := range cfg.SortedSymbolIDs(next) { - hints := next[sym] + hints := normalizeParamHintVector(next[sym]) if !hasNonNilHint(hints) { continue } @@ -518,7 +517,7 @@ func filterEmptyParamHints(hints api.ParamHints) api.ParamHints { } out := make(api.ParamHints, len(hints)) for _, sym := range cfg.SortedSymbolIDs(hints) { - v := hints[sym] + v := normalizeParamHintVector(hints[sym]) if hasNonNilHint(v) 
{ out[sym] = v } @@ -529,6 +528,26 @@ func filterEmptyParamHints(hints api.ParamHints) api.ParamHints { return out } +func normalizeParamHintVector(hints []typ.Type) []typ.Type { + var out []typ.Type + for i, hint := range hints { + normalized := paramhints.NormalizeHintType(hint) + if out != nil { + out[i] = normalized + continue + } + if !typ.TypeEquals(hint, normalized) { + out = make([]typ.Type, len(hints)) + copy(out, hints[:i]) + out[i] = normalized + } + } + if out != nil { + return out + } + return hints +} + func hasNonNilHint(hints []typ.Type) bool { for _, h := range hints { if h != nil { @@ -579,6 +598,67 @@ func joinParamHint(a, b typ.Type) typ.Type { if unwrap.IsNilType(b) && !unwrap.IsNilType(a) { return a } + if joined, ok := joinNilableParamHint(a, b); ok { + return joined + } + return joinNonNilParamHint(a, b) +} + +func joinNilableParamHint(a, b typ.Type) (typ.Type, bool) { + ai, anil := splitNilableParamHint(a) + bi, bnil := splitNilableParamHint(b) + if !anil && !bnil { + return nil, false + } + if ai == nil && bi == nil { + return typ.Nil, true + } + if ai == nil { + return typ.NewOptional(bi), true + } + if bi == nil { + return typ.NewOptional(ai), true + } + return typ.NewOptional(joinNonNilParamHint(ai, bi)), true +} + +func splitNilableParamHint(t typ.Type) (typ.Type, bool) { + t = unwrap.Alias(t) + switch v := t.(type) { + case nil: + return nil, false + case *typ.Optional: + return v.Inner, true + case *typ.Union: + members := make([]typ.Type, 0, len(v.Members)) + nilable := false + for _, member := range v.Members { + member = unwrap.Alias(member) + if unwrap.IsNilType(member) { + nilable = true + continue + } + members = append(members, member) + } + if !nilable { + return t, false + } + return typ.NewUnion(members...), true + default: + if unwrap.IsNilType(t) { + return nil, true + } + return t, false + } +} + +func joinNonNilParamHint(a, b typ.Type) typ.Type { + if upper, ok := selectParamHintTableUpperBound(a, b); ok { + return 
upper + } + if preferred, ok := preferConcreteOverSoftType(a, b); ok { + return preferred + } if typeCanSelfEmbed(a) && typeContainsEquivalent(b, a) && !typ.IsAbsentOrUnknown(a) { if typeContainsUnion(a) { return a @@ -597,13 +677,226 @@ func joinParamHint(a, b typ.Type) typ.Type { if typeIsTruthyRefinement(b, a) { return b } + if joined, ok := typ.JoinCompatibleRecords(a, b); ok { + return joined + } + if joined, ok := joinParamHintMapRecord(a, b); ok { + return joined + } if TypeExtendsRecord(a, b) { return a } if TypeExtendsRecord(b, a) { return b } - return typ.JoinPreferNonSoft(a, b) + if !typ.IsAbsentOrUnknown(a) && !typ.IsAbsentOrUnknown(b) { + if subtype.IsSubtype(a, b) { + return b + } + if subtype.IsSubtype(b, a) { + return a + } + } + return paramhints.NormalizeHintType(typ.JoinPreferNonSoft(a, b)) +} + +func preferConcreteOverSoftType(a, b typ.Type) (typ.Type, bool) { + aSoft := typ.IsSoft(a, typ.SoftPlaceholderPolicy) + bSoft := typ.IsSoft(b, typ.SoftPlaceholderPolicy) + switch { + case aSoft && !bSoft && !unwrap.IsNilType(b): + return b, true + case bSoft && !aSoft && !unwrap.IsNilType(a): + return a, true + } + if preferred, ok := preferConcreteOverNilableSoftType(a, b); ok { + return preferred, true + } + return nil, false +} + +func preferConcreteOverNilableSoftType(a, b typ.Type) (typ.Type, bool) { + if preferred, ok := preferConcreteOverNilableSoftTypeDirected(a, b); ok { + return preferred, true + } + return preferConcreteOverNilableSoftTypeDirected(b, a) +} + +func preferConcreteOverNilableSoftTypeDirected(softMaybeNil, concrete typ.Type) (typ.Type, bool) { + inner, nilable := splitNilableParamHint(softMaybeNil) + if !nilable || inner == nil || !typ.IsSoft(inner, typ.SoftPlaceholderPolicy) { + return nil, false + } + if concrete == nil || unwrap.IsNilType(concrete) { + return nil, false + } + concreteInner, concreteNilable := splitNilableParamHint(concrete) + if concreteInner == nil { + return nil, false + } + if typ.IsSoft(concreteInner, 
typ.SoftPlaceholderPolicy) { + return nil, false + } + if concreteNilable { + return concrete, true + } + return typ.NewOptional(concrete), true +} + +func joinParamHintMapRecord(a, b typ.Type) (typ.Type, bool) { + if joined, ok := joinParamHintMapRecordDirected(a, b); ok { + return joined, true + } + return joinParamHintMapRecordDirected(b, a) +} + +func joinParamHintMapRecordDirected(mapType, recordType typ.Type) (typ.Type, bool) { + m, ok := unwrap.Alias(mapType).(*typ.Map) + if !ok || m == nil { + return nil, false + } + r, ok := unwrap.Alias(recordType).(*typ.Record) + if !ok || r == nil || !r.HasMapComponent() { + return nil, false + } + + key := joinNonNilParamHint(m.Key, r.MapKey) + value := joinNonNilParamHint(m.Value, r.MapValue) + if len(r.Fields) == 0 && r.Metatable == nil { + return typ.NewMap(key, value), true + } + builder := typ.NewRecord() + if r.Open { + builder.SetOpen(true) + } + if r.Metatable != nil { + builder.Metatable(r.Metatable) + } + builder.MapComponent(key, value) + for _, field := range r.Fields { + fieldType := field.Type + optional := true + if subtype.IsSubtype(typ.LiteralString(field.Name), key) { + fieldType = joinNonNilParamHint(field.Type, value) + } else { + optional = field.Optional + } + switch { + case optional && field.Readonly: + builder.OptReadonlyField(field.Name, fieldType) + case optional: + builder.OptField(field.Name, fieldType) + case field.Readonly: + builder.ReadonlyField(field.Name, fieldType) + default: + builder.Field(field.Name, fieldType) + } + } + return builder.Build(), true +} + +func selectParamHintTableUpperBound(a, b typ.Type) (typ.Type, bool) { + if paramHintIsOnlyTableTop(a) && typ.IsAny(b) { + return a, true + } + if paramHintIsOnlyTableTop(b) && typ.IsAny(a) { + return b, true + } + if paramHintContainsTableTop(a) && paramHintCoveredByTableTop(b) && subtype.IsSubtype(b, a) { + return a, true + } + if paramHintContainsTableTop(b) && paramHintCoveredByTableTop(a) && subtype.IsSubtype(a, b) { + return 
b, true + } + return nil, false +} + +func paramHintContainsTableTop(t typ.Type) bool { + if t == nil { + return false + } + if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return paramHintContainsTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return paramHintContainsTableTop(v.Inner) + case *typ.Union: + for _, member := range v.Members { + if paramHintContainsTableTop(member) { + return true + } + } + } + return false +} + +func paramHintIsOnlyTableTop(t typ.Type) bool { + if t == nil { + return false + } + if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return paramHintIsOnlyTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return paramHintIsOnlyTableTop(v.Inner) + case *typ.Union: + if len(v.Members) == 0 { + return false + } + hasTableTop := false + for _, member := range v.Members { + if unwrap.IsNilType(member) { + continue + } + if !paramHintIsOnlyTableTop(member) { + return false + } + hasTableTop = true + } + return hasTableTop + default: + return false + } +} + +func paramHintCoveredByTableTop(t typ.Type) bool { + if t == nil { + return false + } + if typ.IsAny(t) { + return true + } + if unwrap.IsNilType(t) || unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return paramHintCoveredByTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return paramHintCoveredByTableTop(v.Inner) + case *typ.Recursive: + return v.Body != nil && v.Body != v && paramHintCoveredByTableTop(v.Body) + case *typ.Union: + if len(v.Members) == 0 { + return false + } + for _, member := range v.Members { + if !paramHintCoveredByTableTop(member) { + return false + } + } + return true + case *typ.Record, *typ.Map, *typ.Array, *typ.Tuple, *typ.Interface, *typ.Intersection: + return true + default: + return 
false + } } func typeIsTruthyRefinement(candidate, baseline typ.Type) bool { @@ -737,6 +1030,12 @@ func widenValueTypeForConvergence(existing, candidate typ.Type) typ.Type { if typ.TypeEquals(existing, candidate) { return existing } + if unwrap.IsNilType(existing) && !unwrap.IsNilType(candidate) { + return candidate + } + if unwrap.IsNilType(candidate) && !unwrap.IsNilType(existing) { + return existing + } if typ.IsAny(existing) || typ.IsUnknown(existing) { return existing } @@ -851,9 +1150,15 @@ func widenFunctionParamFactTypeForConvergence(existing, candidate typ.Type) typ. if typ.IsAny(candidate) || typ.IsUnknown(candidate) { return candidate } - if typeElidesOptional(candidate, existing) || typeIsTruthyRefinement(candidate, existing) { + if preferred, ok := preferConcreteOverSoftType(existing, candidate); ok { + return preferred + } + if candidateRefinesFunctionParam(candidate, existing) { return candidate } + if candidateRefinesFunctionParam(existing, candidate) { + return existing + } if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { return existing } @@ -1442,10 +1747,11 @@ func mergeFunctionReturnsIfSameShape(prevFn, nextFn *typ.Function) (typ.Type, bo allowedTypeParams[tp.Name] = true } } - normalizeReturn := func(t typ.Type) typ.Type { + normalizeReturn := func(t typ.Type) (typ.Type, bool) { if t == nil { - return nil + return nil, false } + leaked := false return typ.Rewrite(t, func(node typ.Type) (typ.Type, bool) { tp, ok := node.(*typ.TypeParam) if !ok { @@ -1455,17 +1761,30 @@ func mergeFunctionReturnsIfSameShape(prevFn, nextFn *typ.Function) (typ.Type, bo return node, false } // Free type params in non-generic function returns are unstable placeholders. 
+ leaked = true return typ.Unknown, true - }) + }), leaked } normalizedPrev := make([]typ.Type, len(prevFn.Returns)) normalizedNext := make([]typ.Type, len(nextFn.Returns)) + leakedPrev := make([]bool, len(prevFn.Returns)) + leakedNext := make([]bool, len(nextFn.Returns)) for i := range prevFn.Returns { - normalizedPrev[i] = normalizeReturn(prevFn.Returns[i]) - normalizedNext[i] = normalizeReturn(nextFn.Returns[i]) + normalizedPrev[i], leakedPrev[i] = normalizeReturn(prevFn.Returns[i]) + normalizedNext[i], leakedNext[i] = normalizeReturn(nextFn.Returns[i]) } - mergedReturns := typjoin.ReturnVectors(normalizedPrev, normalizedNext) + mergedReturns := make([]typ.Type, len(normalizedPrev)) + for i := range mergedReturns { + switch { + case leakedPrev[i] && !leakedNext[i]: + mergedReturns[i] = normalizedNext[i] + case leakedNext[i] && !leakedPrev[i]: + mergedReturns[i] = normalizedPrev[i] + default: + mergedReturns[i] = typ.JoinReturnSlot(normalizedPrev[i], normalizedNext[i]) + } + } if ReturnTypesEqual(prevFn.Returns, mergedReturns) { return prevFn, true } diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index 5b837420..13b974dd 100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/returns/widen_test.go @@ -313,6 +313,136 @@ func TestWidenParamHints_ReplacesStaleBroadHintWithCurrentRefinement(t *testing. 
} } +func TestWidenParamHints_ReplacesSoftContainerPlaceholderWithConcreteElementShape(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + stale := typ.NewUnion( + typ.NewArray(typ.Any), + typ.NewRecord().SetOpen(true).Build(), + ) + current := typ.NewArray(entry) + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{stale}}, + api.ParamHints{1: []typ.Type{current}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, current) { + t.Fatalf("expected concrete array hint %v to replace soft stale hint, got %v", current, got) + } +} + +func TestWidenParamHints_PreservesStructuredHintOverNilOnlyObservation(t *testing.T) { + context := typ.NewMap(typ.String, typ.Any) + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{typ.String, typ.Any, context}}, + api.ParamHints{1: []typ.Type{typ.String, typ.Any, typ.Nil}}, + ) + + got := merged[1][2] + if !typ.TypeEquals(got, context) { + t.Fatalf("expected nil-only observation to preserve structured hint %v, got %v", context, got) + } + + again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.String, typ.Any, typ.Nil}}) + if !symbolTypeVectorMapEqual(merged, again) { + t.Fatalf("expected idempotent nil-only observation widening, got %v then %v", merged, again) + } +} + +func TestWidenParamHints_PreservesMapHintOverOptionalOpenRecordObservation(t *testing.T) { + context := typ.NewMap(typ.String, typ.Any) + optionalContextRecord := typ.NewOptional(typ.NewRecord(). + MapComponent(typ.String, typ.Any). + SetOpen(true). 
+ Build()) + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{typ.String, typ.Any, context}}, + api.ParamHints{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}, + ) + + got := merged[1][2] + if got == nil || typ.TypeEquals(got, typ.Nil) { + t.Fatalf("expected optional structured observation to preserve context hint, got %v", got) + } + if !typ.TypeEquals(got, typ.NewOptional(context)) { + t.Fatalf("expected pure map observation to stay canonical, got %v", got) + } + + again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}) + if !symbolTypeVectorMapEqual(merged, again) { + t.Fatalf("expected idempotent optional structured observation widening, got %v then %v", merged, again) + } +} + +func TestWidenParamHints_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + canonical := typ.NewMap(typ.String, typ.NewArray(entry)) + staleRecordView := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.String, typ.False), typ.NewArray(entry)). + SetOpen(true). + Build() + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{staleRecordView}}, + api.ParamHints{1: []typ.Type{canonical}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, canonical) { + t.Fatalf("expected pure keyed table hint to canonicalize to %v, got %v", canonical, got) + } +} + +func TestWidenParamHints_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { + tableTop := typ.NewOptional(typ.NewInterface("table", nil)) + strategySpec := typ.NewRecord(). + Field("kind", typ.LiteralString("strategy")). + Field("tools", typ.NewTuple(typ.String, typ.String, typ.String)). + Build() + contextSpec := typ.NewRecord(). + Field("kind", typ.LiteralString("context")). + Field("scope", typ.String). 
+ Build() + nextHint := typ.NewUnion(strategySpec, contextSpec) + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{tableTop}}, + api.ParamHints{1: []typ.Type{nextHint}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, tableTop) { + t.Fatalf("expected table top upper bound %v, got %v", tableTop, got) + } + + again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{nextHint}}) + if !symbolTypeVectorMapEqual(merged, again) { + t.Fatalf("expected idempotent table-top widening, got %v then %v", merged, again) + } +} + +func TestWidenParamHints_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { + tableTop := typ.NewOptional(typ.NewInterface("table", nil)) + + merged := WidenParamHints( + api.ParamHints{1: []typ.Type{tableTop}}, + api.ParamHints{1: []typ.Type{typ.Any}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, tableTop) { + t.Fatalf("expected dynamic observation to preserve table top upper bound %v, got %v", tableTop, got) + } + + again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.Any}}) + if !symbolTypeVectorMapEqual(merged, again) { + t.Fatalf("expected idempotent table-top/any widening, got %v then %v", merged, again) + } +} + func TestWidenCapturedFieldAssigns_NormalizesOptionalFunctionValues(t *testing.T) { fn := typ.Func().Param("fn", typ.Unknown).Build() merged := WidenCapturedFieldAssigns(nil, api.CapturedFieldAssigns{ diff --git a/compiler/check/siblings/overlay.go b/compiler/check/siblings/overlay.go index 5a39f432..b35df311 100644 --- a/compiler/check/siblings/overlay.go +++ b/compiler/check/siblings/overlay.go @@ -64,6 +64,12 @@ func (o OverlayServicesFuncs) SeedType(fn *ast.FunctionExpr) typ.Type { // to make progress even when not all return types are known. 
func BuildOverlay(c OverlayConfig) map[cfg.SymbolID]typ.Type { overlay := make(map[cfg.SymbolID]typ.Type) + siblingFuncs := make(map[cfg.SymbolID]*ast.FunctionExpr, len(c.Siblings)) + for _, sib := range c.Siblings { + if sib.Symbol != 0 && sib.Func != nil { + siblingFuncs[sib.Symbol] = sib.Func + } + } // Add sibling function types with current return vectors. for sym, returnTypes := range c.ReturnVectors { @@ -71,7 +77,11 @@ func BuildOverlay(c OverlayConfig) map[cfg.SymbolID]typ.Type { continue } if len(returnTypes) > 0 { - overlay[sym] = buildFunctionFromReturns(returnTypes) + var seedType typ.Type + if c.Services != nil { + seedType = c.Services.SeedType(siblingFuncs[sym]) + } + overlay[sym] = buildFunctionFromSeedAndReturns(seedType, returnTypes) } } @@ -101,3 +111,38 @@ func buildFunctionFromReturns(returnTypes []typ.Type) typ.Type { } return typ.Func().Returns(returnTypes...).Build() } + +func buildFunctionFromSeedAndReturns(seed typ.Type, returnTypes []typ.Type) typ.Type { + if len(returnTypes) == 0 { + return seed + } + fn, ok := seed.(*typ.Function) + if !ok || fn == nil { + return buildFunctionFromReturns(returnTypes) + } + builder := typ.Func() + for _, tp := range fn.TypeParams { + builder.TypeParam(tp.Name, tp.Constraint) + } + for _, p := range fn.Params { + if p.Optional { + builder.OptParam(p.Name, p.Type) + } else { + builder.Param(p.Name, p.Type) + } + } + if fn.Variadic != nil { + builder.Variadic(fn.Variadic) + } + builder.Returns(returnTypes...) 
+ if fn.Effects != nil { + builder.Effects(fn.Effects) + } + if fn.Spec != nil { + builder.Spec(fn.Spec) + } + if fn.Refinement != nil { + builder.WithRefinement(fn.Refinement) + } + return builder.Build() +} diff --git a/compiler/check/store/facts_clone.go b/compiler/check/store/facts_clone.go new file mode 100644 index 00000000..7b6eb2f1 --- /dev/null +++ b/compiler/check/store/facts_clone.go @@ -0,0 +1,162 @@ +package store + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/types/typ" +) + +func cloneFacts(f api.Facts) api.Facts { + if factsEmpty(f) { + return api.Facts{} + } + return api.Facts{ + FunctionFacts: cloneFunctionFacts(f.FunctionFacts), + ParamHints: cloneParamHints(f.ParamHints), + LiteralSigs: cloneLiteralSigs(f.LiteralSigs), + CapturedTypes: cloneCapturedTypes(f.CapturedTypes), + CapturedFields: cloneCapturedFieldAssigns(f.CapturedFields), + CapturedContainers: cloneCapturedContainerMutations(f.CapturedContainers), + ConstructorFields: cloneConstructorFields(f.ConstructorFields), + } +} + +func cloneFunctionFacts(src api.FunctionFacts) api.FunctionFacts { + if len(src) == 0 { + return nil + } + out := make(api.FunctionFacts, len(src)) + for sym, fact := range src { + fact.Summary = cloneTypeSlice(fact.Summary) + fact.Narrow = cloneTypeSlice(fact.Narrow) + out[sym] = fact + } + return out +} + +func cloneParamHints(src api.ParamHints) api.ParamHints { + if len(src) == 0 { + return nil + } + out := make(api.ParamHints, len(src)) + for sym, hints := range src { + out[sym] = cloneTypeSlice(hints) + } + return out +} + +func cloneLiteralSigs(src api.LiteralSigs) api.LiteralSigs { + if len(src) == 0 { + return nil + } + out := make(map[*ast.FunctionExpr]*typ.Function, len(src)) + for fn, sig := range src { + out[fn] = sig + } + return out +} + +func cloneCapturedTypes(src api.CapturedTypes) api.CapturedTypes { + if len(src) == 0 { 
+ return nil + } + out := make(api.CapturedTypes, len(src)) + for sym, t := range src { + out[sym] = t + } + return out +} + +func cloneCapturedFieldAssigns(src api.CapturedFieldAssigns) api.CapturedFieldAssigns { + if len(src) == 0 { + return nil + } + out := make(api.CapturedFieldAssigns, len(src)) + for callee, bySym := range src { + if len(bySym) == 0 { + continue + } + bySymOut := make(map[cfg.SymbolID]map[string]typ.Type, len(bySym)) + for sym, fields := range bySym { + if len(fields) == 0 { + continue + } + fieldOut := make(map[string]typ.Type, len(fields)) + for name, t := range fields { + fieldOut[name] = t + } + bySymOut[sym] = fieldOut + } + if len(bySymOut) > 0 { + out[callee] = bySymOut + } + } + if len(out) == 0 { + return nil + } + return out +} + +func cloneCapturedContainerMutations(src api.CapturedContainerMutations) api.CapturedContainerMutations { + if len(src) == 0 { + return nil + } + out := make(api.CapturedContainerMutations, len(src)) + for callee, bySym := range src { + if len(bySym) == 0 { + continue + } + bySymOut := make(map[cfg.SymbolID][]api.ContainerMutation, len(bySym)) + for sym, muts := range bySym { + if len(muts) == 0 { + continue + } + mutsOut := make([]api.ContainerMutation, len(muts)) + copy(mutsOut, muts) + for i := range mutsOut { + if len(mutsOut[i].Segments) > 0 { + mutsOut[i].Segments = append(mutsOut[i].Segments[:0:0], mutsOut[i].Segments...) 
+ } + } + bySymOut[sym] = mutsOut + } + if len(bySymOut) > 0 { + out[callee] = bySymOut + } + } + if len(out) == 0 { + return nil + } + return out +} + +func cloneConstructorFields(src api.ConstructorFields) api.ConstructorFields { + if len(src) == 0 { + return nil + } + out := make(api.ConstructorFields, len(src)) + for sym, fields := range src { + if len(fields) == 0 { + continue + } + fieldOut := make(map[string]typ.Type, len(fields)) + for name, t := range fields { + fieldOut[name] = t + } + out[sym] = fieldOut + } + if len(out) == 0 { + return nil + } + return out +} + +func cloneTypeSlice(src []typ.Type) []typ.Type { + if len(src) == 0 { + return nil + } + out := make([]typ.Type, len(src)) + copy(out, src) + return out +} diff --git a/compiler/check/store/snapshot_inputs.go b/compiler/check/store/snapshot_inputs.go index 40b3d9ec..6bcb6233 100644 --- a/compiler/check/store/snapshot_inputs.go +++ b/compiler/check/store/snapshot_inputs.go @@ -62,7 +62,11 @@ func (in *snapshotInputs) factsFor(ctx *db.QueryContext, key api.GraphKey) (api. 
if in == nil || in.facts == nil { return api.Facts{}, false } - return in.facts.Get(ctx, key) + facts, ok := in.facts.Get(ctx, key) + if !ok { + return api.Facts{}, false + } + return cloneFacts(facts), true } func (in *snapshotInputs) setFacts(key api.GraphKey, facts api.Facts) { @@ -77,7 +81,7 @@ func (in *snapshotInputs) setFacts(key api.GraphKey, facts api.Facts) { in.facts.Set(key, api.Facts{}) return } - next := facts + next := cloneFacts(facts) if prev, ok := in.factValues[key]; ok && returns.FactsEqual(prev, next) { return } @@ -184,15 +188,15 @@ func (s *SessionStore) currentInterprocFacts(key api.GraphKey) api.Facts { if s.InterprocNext != nil && s.InterprocNext.Facts != nil { if next, ok := s.InterprocNext.Facts[key]; ok { if factsEmpty(prev) { - return next + return cloneFacts(next) } if factsEmpty(next) { - return prev + return cloneFacts(prev) } - return returns.JoinFacts(prev, next) + return cloneFacts(returns.JoinFacts(prev, next)) } } - return prev + return cloneFacts(prev) } func (s *SessionStore) syncFactsInput(key api.GraphKey) { diff --git a/compiler/check/store/store_test.go b/compiler/check/store/store_test.go index c67e78b8..be860f69 100644 --- a/compiler/check/store/store_test.go +++ b/compiler/check/store/store_test.go @@ -268,6 +268,40 @@ func TestGetInterprocFactsSnapshot_OverlaysCurrentIterationFacts(t *testing.T) { } } +func TestGetInterprocFactsSnapshot_ReturnsImmutableFactContainers(t *testing.T) { + graph := cfg.Build(&ast.FunctionExpr{}) + if graph == nil || graph.ID() == 0 { + t.Fatal("expected graph with stable ID") + } + + parent := scope.New().WithType("T", typ.String) + s := NewSessionStore() + s.SetGraphParentHash(graph.ID(), parent.Hash()) + s.SetParentScope(parent.Hash(), parent) + key := api.KeyForGraph(graph, parent.Hash()) + sym := cfg.SymbolID(7) + s.InterprocPrev.Facts[key] = api.Facts{ + ParamHints: api.ParamHints{ + sym: []typ.Type{typ.String, typ.NewMap(typ.String, typ.Any)}, + }, + FunctionFacts: api.FunctionFacts{ 
+ sym: {Summary: []typ.Type{typ.String}}, + }, + } + + snapshot := s.GetInterprocFactsSnapshot(graph, parent) + snapshot.ParamHints[sym][1] = typ.Nil + snapshot.FunctionFacts[sym] = api.FunctionFact{Summary: []typ.Type{typ.Number}} + + again := s.GetInterprocFactsSnapshot(graph, parent) + if got := again.ParamHints[sym][1]; !typ.TypeEquals(got, typ.NewMap(typ.String, typ.Any)) { + t.Fatalf("snapshot param hint mutation leaked into store: %v", got) + } + if got := again.FunctionFacts.Summary(sym); len(got) != 1 || !typ.TypeEquals(got[0], typ.String) { + t.Fatalf("snapshot function fact mutation leaked into store: %v", got) + } +} + func TestMergeInterprocFactsNext_ReconcilesDeltasWithinIteration(t *testing.T) { key := api.GraphKey{GraphID: 1, ParentHash: 2} sym := cfg.SymbolID(7) diff --git a/compiler/check/synth/literals.go b/compiler/check/synth/literals.go index a46b9f45..82d481ec 100644 --- a/compiler/check/synth/literals.go +++ b/compiler/check/synth/literals.go @@ -6,6 +6,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" phasecore "github.com/wippyai/go-lua/compiler/check/synth/phase/core" "github.com/wippyai/go-lua/types/flow" + "github.com/wippyai/go-lua/types/kind" querycore "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -148,6 +149,18 @@ func FunctionLiteralSignatures(graph *cfg.Graph, engine LiteralSynth, declaredRe out[fn] = sig } } + receiverSelfType := func(expr ast.Expr, p cfg.Point) typ.Type { + sc := scopes[p] + if sc == nil { + sc = scopes[entry] + } + if ident, ok := expr.(*ast.IdentExpr); ok && ident != nil && sc != nil { + if named, ok := sc.LookupType(ident.Value); ok && named != nil { + return named + } + } + return widenMutableReceiverState(engine.TypeOf(expr, p)) + } var collectExpr func(expr ast.Expr, p cfg.Point, expected typ.Type) var collectTable func(tbl *ast.TableExpr, p cfg.Point, expected typ.Type) @@ -194,7 +207,7 @@ func 
FunctionLiteralSignatures(graph *cfg.Graph, engine LiteralSynth, declaredRe } } if fieldCount > 0 { - selfType = selfBuilder.Build() + selfType = widenMutableReceiverState(selfBuilder.Build()) } } @@ -262,7 +275,7 @@ func FunctionLiteralSignatures(graph *cfg.Graph, engine LiteralSynth, declaredRe var expectedFn *typ.Function if info.TargetKind == cfg.FuncDefField || info.TargetKind == cfg.FuncDefMethod { if info.Receiver != nil { - recvType := engine.TypeOf(info.Receiver, p) + recvType := receiverSelfType(info.Receiver, p) if recvType != nil && phasecore.HasSelfParam(info.FuncExpr, graph.Bindings()) { expectedFn = typ.Func().Param("self", recvType).Build() } @@ -291,3 +304,61 @@ func FunctionLiteralSignatures(graph *cfg.Graph, engine LiteralSynth, declaredRe } return out } + +func widenMutableReceiverState(t typ.Type) typ.Type { + rec, ok := unwrap.Alias(t).(*typ.Record) + if !ok { + return t + } + + builder := typ.NewRecord() + if rec.Open { + builder.SetOpen(true) + } + for _, f := range rec.Fields { + fieldType := widenMutableReceiverField(f.Type) + switch { + case f.Optional && f.Readonly: + builder.OptReadonlyField(f.Name, fieldType) + case f.Optional: + builder.OptField(f.Name, fieldType) + case f.Readonly: + builder.ReadonlyField(f.Name, fieldType) + default: + builder.Field(f.Name, fieldType) + } + } + if rec.Metatable != nil { + builder.Metatable(rec.Metatable) + } + if rec.HasMapComponent() { + builder.MapComponent(rec.MapKey, rec.MapValue) + } + return builder.Build() +} + +func widenMutableReceiverField(t typ.Type) typ.Type { + if t == nil { + return typ.Unknown + } + unaliased := unwrap.Alias(t) + if unaliased == nil { + return typ.Unknown + } + if unaliased.Kind() == kind.Nil { + return typ.Unknown + } + if v, ok := unaliased.(*typ.Literal); ok { + switch v.Base { + case kind.Boolean: + return typ.Boolean + case kind.String: + return typ.String + case kind.Integer: + return typ.Integer + case kind.Number: + return typ.Number + } + } + return t +} diff 
--git a/compiler/check/synth/ops/logical.go b/compiler/check/synth/ops/logical.go index e3a2fd39..48fa7d21 100644 --- a/compiler/check/synth/ops/logical.go +++ b/compiler/check/synth/ops/logical.go @@ -33,10 +33,6 @@ func LogicalAndTyped(left, right typ.Type) typ.Type { return typ.Never } - if right != nil && right.Kind().IsNever() { - return typ.Never - } - // If left is definitely truthy (cannot be nil or false), result is right if !CanBeFalsy(left) { return right @@ -52,6 +48,9 @@ func LogicalAndTyped(left, right typ.Type) typ.Type { if falsyLeft == nil || falsyLeft.Kind().IsNever() { return right } + if right != nil && right.Kind().IsNever() { + return falsyLeft + } // Unknown/any right branch must remain dominant. Using plain union here can // collapse to falsy-only because typ.NewUnion treats unknown as non-informative. if typ.IsUnknown(right) { @@ -90,10 +89,6 @@ func LogicalOrTyped(left, right typ.Type) typ.Type { return typ.Never } - if right != nil && right.Kind().IsNever() { - return typ.Never - } - // If left is definitely truthy, result is left if !CanBeFalsy(left) { return left @@ -109,6 +104,9 @@ func LogicalOrTyped(left, right typ.Type) typ.Type { if truthyLeft == nil || truthyLeft.Kind().IsNever() { return right } + if right != nil && right.Kind().IsNever() { + return truthyLeft + } return typ.JoinBranchOutcome(truthyLeft, right) } diff --git a/compiler/check/synth/ops/logical_test.go b/compiler/check/synth/ops/logical_test.go index f94d31a7..cd993b4e 100644 --- a/compiler/check/synth/ops/logical_test.go +++ b/compiler/check/synth/ops/logical_test.go @@ -57,6 +57,13 @@ func TestLogicalAndTyped_Never(t *testing.T) { } } +func TestLogicalAndTyped_RightNeverPreservesFalsyShortCircuit(t *testing.T) { + result := LogicalAndTyped(typ.NewOptional(typ.Integer), typ.Never) + if !typ.TypeEquals(result, typ.Nil) { + t.Errorf("optional(integer) and never should preserve nil short-circuit, got %v", result) + } +} + func TestLogicalOrTyped_LeftTruthy(t 
*testing.T) { result := LogicalOrTyped(typ.Integer, typ.String) if result != typ.Integer { @@ -104,8 +111,15 @@ func TestLogicalOrTyped_Never(t *testing.T) { } result = LogicalOrTyped(typ.Integer, typ.Never) - if result.Kind() != kind.Never { - t.Errorf("X or never should return never, got %v", result) + if result != typ.Integer { + t.Errorf("truthy or never should preserve truthy short-circuit, got %v", result) + } +} + +func TestLogicalOrTyped_RightNeverPreservesTruthyShortCircuit(t *testing.T) { + result := LogicalOrTyped(typ.NewOptional(typ.Integer), typ.Never) + if !typ.TypeEquals(result, typ.Integer) { + t.Errorf("optional(integer) or never should preserve truthy short-circuit, got %v", result) } } diff --git a/compiler/check/synth/phase/core/params.go b/compiler/check/synth/phase/core/params.go index 916925ee..66273dcf 100644 --- a/compiler/check/synth/phase/core/params.go +++ b/compiler/check/synth/phase/core/params.go @@ -21,6 +21,9 @@ type ParamListConfig struct { // ImplicitSelfType is used for the prepended `self` parameter. // When nil, `unknown` is used. ImplicitSelfType typ.Type + // UntypedParamType is used for unannotated source parameters without an + // expected type. When nil, `unknown` is used. + UntypedParamType typ.Type } // ParamSymbolLookup exposes parameter symbol layout for a function expression. @@ -135,6 +138,9 @@ func ApplyParamList(builder *typ.FunctionBuilder, fn *ast.FunctionExpr, cfg Para isOptional = cfg.Expected.Params[expectedIdx].Optional } else { // Unannotated params are optional in Lua (missing args become nil). 
+ if cfg.UntypedParamType != nil { + paramType = cfg.UntypedParamType + } isOptional = true hasUntyped = true } diff --git a/compiler/check/synth/phase/extract/expr.go b/compiler/check/synth/phase/extract/expr.go index 0b128c35..ea4e7dbf 100644 --- a/compiler/check/synth/phase/extract/expr.go +++ b/compiler/check/synth/phase/extract/expr.go @@ -25,6 +25,7 @@ package extract import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/flowbuild/guard" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/cfg" @@ -106,6 +107,15 @@ func (s *Synthesizer) synthAttrGetCore(ex *ast.AttrGetExpr, p cfg.Point, sc *sco if typ.IsUnknown(unwrap.Alias(narrowed)) && typ.IsAny(unwrap.Alias(objType)) { goto skipNarrowedAttr } + if key, ok := ex.Key.(*ast.StringExpr); ok { + if declaredField, ok := s.deps.Types.Field(s.deps.Ctx, objType, key.Value); ok && declaredField != nil { + refined, ok := s.refineNarrowedFieldFact(narrowed, declaredField) + if !ok { + goto skipNarrowedAttr + } + narrowed = refined + } + } return narrowed } } @@ -230,6 +240,35 @@ skipNarrowedAttr: return typ.Unknown } +func (s *Synthesizer) refineNarrowedFieldFact(narrowed, declared typ.Type) (typ.Type, bool) { + if narrowed == nil || declared == nil { + return narrowed, true + } + declared = unwrap.Alias(declared) + narrowed = unwrap.Alias(narrowed) + if declared == nil || narrowed == nil { + return narrowed, true + } + if declared.Kind().IsPlaceholder() { + return narrowed, true + } + if s.deps.Types != nil { + if s.deps.Types.IsSubtype(s.deps.Ctx, narrowed, declared) { + return narrowed, true + } + declaredNonNil := narrow.RemoveNil(declared) + if !typ.IsNever(declaredNonNil) { + if s.deps.Types.IsSubtype(s.deps.Ctx, declaredNonNil, narrowed) { + return declaredNonNil, true + } + if unwrap.Function(declaredNonNil) != nil && unwrap.Function(narrowed) 
!= nil { + return declaredNonNil, true + } + } + } + return nil, false +} + func (s *Synthesizer) indexFromKeyOf(objType typ.Type, objExpr ast.Expr, key *ast.IdentExpr, p cfg.Point, sc *scope.State, narrower api.FlowOps) typ.Type { if s == nil || key == nil || narrower == nil || s.deps.Paths == nil || s.deps.CheckCtx == nil { return nil @@ -469,20 +508,24 @@ func (s *Synthesizer) synthLogicalOpCore(ex *ast.LogicalOpExpr, recurse ExprSynt func (s *Synthesizer) synthLogicalOpWithNarrowing(ex *ast.LogicalOpExpr, p cfg.Point, sc *scope.State, narrower api.FlowOps, recurse ExprSynth) typ.Type { left := recurse(ex.Lhs) - // Extract path for LHS expression - var lhsPath constraint.Path - if s.deps.Paths != nil { - lhsPath = s.deps.Paths(p, ex.Lhs, sc) - } else if ident, ok := ex.Lhs.(*ast.IdentExpr); ok { - if s.deps.CheckCtx != nil { - if bindings := s.deps.CheckCtx.Bindings(); bindings != nil { - if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { - lhsPath = constraint.Path{Root: ident.Value, Symbol: sym} + if ex.Operator == "and" { + if probe, ok := guard.ExtractTypeEqualityProbe(ex.Lhs); ok { + probePath := s.logicalNarrowPath(p, probe.Expr, sc) + if !probePath.IsEmpty() { + wrapped := &localNarrowOps{ + inner: narrower, + overridePath: probePath, + overrideType: guard.TypeForTypeKey(probe.Key), } + right := s.SynthExpr(ex.Rhs, p, wrapped) + return ops.LogicalAndTyped(left, right) } } } + // Extract path for LHS expression + lhsPath := s.logicalNarrowPath(p, ex.Lhs, sc) + if !lhsPath.IsEmpty() && ops.CanBeFalsy(left) { var narrowedType typ.Type switch ex.Operator { @@ -514,6 +557,52 @@ func (s *Synthesizer) synthLogicalOpWithNarrowing(ex *ast.LogicalOpExpr, p cfg.P return s.synthLogicalOpCore(ex, recurse) } +func (s *Synthesizer) logicalNarrowPath(p cfg.Point, expr ast.Expr, sc *scope.State) constraint.Path { + if s == nil { + return constraint.Path{} + } + if s.deps.Paths != nil { + return s.deps.Paths(p, expr, sc) + } + if ident, ok := expr.(*ast.IdentExpr); 
ok { + if s.deps.CheckCtx != nil { + if bindings := s.deps.CheckCtx.Bindings(); bindings != nil { + if sym, ok := bindings.SymbolOf(ident); ok && sym != 0 { + return constraint.Path{Root: ident.Value, Symbol: sym} + } + } + } + } + return constraint.Path{} +} + +func (s *Synthesizer) synthLogicalOpWithExpected(ex *ast.LogicalOpExpr, sc *scope.State, p cfg.Point, recurse ExprSynth, expected typ.Type) typ.Type { + if ex == nil { + return typ.Unknown + } + if expected == nil || ex.Operator != "or" && ex.Operator != "and" { + return s.synthLogicalOpCore(ex, recurse) + } + + branch := func(expr ast.Expr) typ.Type { + if expr == nil { + return typ.Unknown + } + return s.SynthExprWithExpectedCore(expr, sc, p, recurse, expected) + } + + left := recurse(ex.Lhs) + right := branch(ex.Rhs) + switch ex.Operator { + case "and": + return ops.LogicalAndTyped(left, right) + case "or": + return ops.LogicalOrTyped(left, right) + default: + return typ.Unknown + } +} + // synthArithmeticOpCore synthesizes type for arithmetic operators. 
func (s *Synthesizer) synthArithmeticOpCore(ex *ast.ArithmeticOpExpr, recurse ExprSynth) typ.Type { left := recurse(ex.Lhs) diff --git a/compiler/check/synth/phase/extract/function.go b/compiler/check/synth/phase/extract/function.go index d54788b8..81f5bc1f 100644 --- a/compiler/check/synth/phase/extract/function.go +++ b/compiler/check/synth/phase/extract/function.go @@ -619,7 +619,7 @@ func (s *Synthesizer) inferReturnTypesFromBody( } } - if typ.IsUnknownOnlyOrEmpty(returnTypes) && len(canonicalReturns) > 0 { + if len(returnTypes) == 0 && len(canonicalReturns) > 0 { return canonicalReturns, false } @@ -828,14 +828,17 @@ func (s *Synthesizer) buildFunctionTypeFromAvailableFacts( func (s *Synthesizer) buildParamOverlay(fnGraph *cfg.Graph, sc *scope.State, expected *typ.Function) map[cfg.SymbolID]typ.Type { paramSlots := fnGraph.ParamSlotsReadOnly() overlay := make(map[cfg.SymbolID]typ.Type, overlaySymbolCapacity(fnGraph, len(paramSlots))) - for _, slot := range paramSlots { + for paramIdx, slot := range paramSlots { if slot.Symbol == 0 { continue } - srcIdx, hasSource := slot.SourceParamIndex() + _, hasSource := slot.SourceParamIndex() if !hasSource { - if selfType := sc.SelfType(); selfType != nil { + if expected != nil && paramIdx < len(expected.Params) && expected.Params[paramIdx].Type != nil { + overlay[slot.Symbol] = expected.Params[paramIdx].Type + } else if sc != nil && sc.SelfType() != nil { + selfType := sc.SelfType() overlay[slot.Symbol] = selfType } else { overlay[slot.Symbol] = typ.Unknown @@ -843,12 +846,11 @@ func (s *Synthesizer) buildParamOverlay(fnGraph *cfg.Graph, sc *scope.State, exp continue } - i := srcIdx paramType := typ.Unknown if slot.TypeAnnotation != nil { paramType = s.ResolveType(slot.TypeAnnotation, sc) - } else if expected != nil && i < len(expected.Params) { - paramType = expected.Params[i].Type + } else if expected != nil && paramIdx < len(expected.Params) { + paramType = expected.Params[paramIdx].Type } else if slot.Name == "self" 
&& sc != nil && sc.SelfType() != nil { paramType = sc.SelfType() } diff --git a/compiler/check/synth/phase/extract/pipeline.go b/compiler/check/synth/phase/extract/pipeline.go index 11a27517..25876b7a 100644 --- a/compiler/check/synth/phase/extract/pipeline.go +++ b/compiler/check/synth/phase/extract/pipeline.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/cfg" "github.com/wippyai/go-lua/types/db" + "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -144,8 +145,88 @@ func FullArgReSynth( } return synthWithExpected(a, p, expected) case *ast.IdentExpr: - return synthWithExpected(a, p, expected) + inferred := synthWithExpected(a, p, expected) + if shouldRefineCallArgWithExpected(inferred, expected) { + return expected + } + return inferred + case *ast.AttrGetExpr: + inferred := synthWithExpected(a, p, expected) + if shouldRefineCallArgWithExpected(inferred, expected) { + return expected + } + return inferred } return nil } } + +func shouldRefineCallArgWithExpected(inferred, expected typ.Type) bool { + if inferred == nil || expected == nil { + return false + } + if typ.IsAny(inferred) || typ.IsAny(expected) || expected.Kind().IsPlaceholder() { + return false + } + if typ.IsUnknown(unwrap.Alias(inferred)) { + return true + } + if subtype.IsSubtype(inferred, expected) { + return true + } + inferredRec := unwrap.Record(inferred) + expectedRec := unwrap.Record(expected) + if inferredRec == nil || expectedRec == nil { + return false + } + return recordEvidenceMatchesExpected(inferredRec, expectedRec) +} + +func recordEvidenceMatchesExpected(inferred, expected *typ.Record) bool { + if inferred == nil || expected == nil { + return false + } + for _, field := range inferred.Fields { + expectedField := expected.GetField(field.Name) + if expectedField == nil { + if expected.Open { + continue + } + return false + } + if 
unresolvedRecordEvidence(field.Type) { + continue + } + inferredType := field.Type + if field.Optional { + inferredType = typ.NewOptional(inferredType) + } + expectedType := expectedField.Type + if expectedField.Optional { + expectedType = typ.NewOptional(expectedType) + } + if !subtype.IsSubtype(inferredType, expectedType) { + return false + } + } + if inferred.HasMapComponent() { + if !expected.HasMapComponent() { + return false + } + if !unresolvedRecordEvidence(inferred.MapKey) && !subtype.IsSubtype(inferred.MapKey, expected.MapKey) { + return false + } + if !unresolvedRecordEvidence(inferred.MapValue) && !subtype.IsSubtype(inferred.MapValue, expected.MapValue) { + return false + } + } + return true +} + +func unresolvedRecordEvidence(t typ.Type) bool { + if typ.IsAbsentOrUnknown(t) { + return true + } + rec := unwrap.Record(t) + return rec != nil && len(rec.Fields) == 0 && !rec.HasMapComponent() +} diff --git a/compiler/check/synth/phase/extract/table.go b/compiler/check/synth/phase/extract/table.go index a07f8c03..2f8cc2d9 100644 --- a/compiler/check/synth/phase/extract/table.go +++ b/compiler/check/synth/phase/extract/table.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth/ops" phasecore "github.com/wippyai/go-lua/compiler/check/synth/phase/core" + "github.com/wippyai/go-lua/types/narrow" querycore "github.com/wippyai/go-lua/types/query/core" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -37,6 +38,9 @@ func (s *Synthesizer) SynthTableCore(ex *ast.TableExpr, sc *scope.State, recurse // Empty tables return an open record (can have any additional fields assigned). 
func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, recurse ExprSynth, expected typ.Type) typ.Type { if len(ex.Fields) == 0 { + if result := emptyTableExpectedResult(expected); result != nil { + return result + } return typ.NewRecord().SetOpen(true).Build() } @@ -152,19 +156,47 @@ func (s *Synthesizer) SynthTableWithExpected(ex *ast.TableExpr, sc *scope.State, } else { result = typ.NewTuple(arrayElements...) } - if expected != nil && len(ops.CheckTable(nil, arrayElements, expected).Errors) == 0 { + if useExpectedTableResult(expected) && len(ops.CheckTable(nil, arrayElements, expected).Errors) == 0 { return expected } return result } result := builder.Build() - if expected != nil && len(ops.CheckTable(fieldDefs, arrayElements, expected).Errors) == 0 { + if useExpectedTableResult(expected) && len(ops.CheckTable(fieldDefs, arrayElements, expected).Errors) == 0 { return expected } return result } +func emptyTableExpectedResult(expected typ.Type) typ.Type { + if expected == nil { + return nil + } + nonNil := narrow.RemoveNil(expected) + if nonNil == nil || typ.IsNever(nonNil) || typ.IsAbsentOrUnknown(nonNil) || typ.IsAny(nonNil) { + return nil + } + if !useExpectedTableResult(nonNil) { + return nil + } + if len(ops.CheckTable(nil, nil, nonNil).Errors) != 0 { + return nil + } + return nonNil +} + +func useExpectedTableResult(expected typ.Type) bool { + if expected == nil { + return false + } + unwrapped := unwrap.Alias(expected) + if unwrapped == nil { + return false + } + return !unwrapped.Kind().IsPlaceholder() +} + // synthFieldValueWithExpected synthesizes type for a table field value with optional expected type. 
func (s *Synthesizer) synthFieldValueWithExpected(value ast.Expr, sc *scope.State, recurse ExprSynth, expected typ.Type, selfType typ.Type) typ.Type { if tbl, ok := value.(*ast.TableExpr); ok { diff --git a/compiler/check/synth/phase/extract/table_test.go b/compiler/check/synth/phase/extract/table_test.go index 4de65874..1bffe107 100644 --- a/compiler/check/synth/phase/extract/table_test.go +++ b/compiler/check/synth/phase/extract/table_test.go @@ -70,6 +70,59 @@ func TestSynthTableCore_ArrayLike(t *testing.T) { } } +func TestSynthTableWithExpectedAnyPreservesTuplePrecision(t *testing.T) { + s := newTestSynthesizer() + sc := scope.New() + recurse := func(ex ast.Expr) typ.Type { return s.TypeOf(ex, 0) } + + table := &ast.TableExpr{ + Fields: []*ast.Field{ + {Value: &ast.StringExpr{Value: "first"}}, + }, + } + result := s.SynthTableWithExpected(table, sc, recurse, typ.Any) + + tuple, ok := result.(*typ.Tuple) + if !ok { + t.Fatalf("got %T, want tuple", result) + } + if len(tuple.Elements) != 1 { + t.Fatalf("got %d elements, want 1", len(tuple.Elements)) + } +} + +func TestSynthTableWithExpectedEmptyMapUsesNonNilExpected(t *testing.T) { + s := newTestSynthesizer() + sc := scope.New() + recurse := func(ex ast.Expr) typ.Type { return s.TypeOf(ex, 0) } + + expected := typ.NewOptional(typ.NewMap(typ.String, typ.Any)) + table := &ast.TableExpr{} + result := s.SynthTableWithExpected(table, sc, recurse, expected) + + if !typ.TypeEquals(result, typ.NewMap(typ.String, typ.Any)) { + t.Fatalf("got %v, want non-nil expected map", result) + } +} + +func TestSynthTableWithExpectedEmptyRecordRequiresFields(t *testing.T) { + s := newTestSynthesizer() + sc := scope.New() + recurse := func(ex ast.Expr) typ.Type { return s.TypeOf(ex, 0) } + + expected := typ.NewRecord().Field("name", typ.String).Build() + table := &ast.TableExpr{} + result := s.SynthTableWithExpected(table, sc, recurse, expected) + + rec, ok := result.(*typ.Record) + if !ok { + t.Fatalf("got %T, want synthesized open 
record", result) + } + if !rec.Open || len(rec.Fields) != 0 { + t.Fatalf("got %v, want empty open record for missing required fields", result) + } +} + func TestSynthTableWithExpected_Record(t *testing.T) { s := newTestSynthesizer() sc := scope.New() diff --git a/compiler/check/synth/phase/extract/union_expected.go b/compiler/check/synth/phase/extract/union_expected.go index 6e581e67..a977a0fa 100644 --- a/compiler/check/synth/phase/extract/union_expected.go +++ b/compiler/check/synth/phase/extract/union_expected.go @@ -85,6 +85,8 @@ func (s *Synthesizer) synthExprWithExpectedSingle( return typ.Nil } return types[0] + case *ast.LogicalOpExpr: + return s.synthLogicalOpWithExpected(ex, sc, p, recurse, expected) case *ast.IdentExpr: if expectedFn, ok := unwrap.Alias(expected).(*typ.Function); ok { if fnExpr := s.functionLiteralForIdent(ex); fnExpr != nil { diff --git a/compiler/check/tests/modules/manifest_test.go b/compiler/check/tests/modules/manifest_test.go index 20ecfa28..8bd48035 100644 --- a/compiler/check/tests/modules/manifest_test.go +++ b/compiler/check/tests/modules/manifest_test.go @@ -3,10 +3,13 @@ package modules import ( "testing" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/tests/testutil" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/io" "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) // TestManifest_BasicExport tests basic manifest export types. 
@@ -182,6 +185,50 @@ func TestManifest_SoftLocalAnnotations(t *testing.T) { } t.Errorf("expected no errors with soft local annotations") } + + if result.Session == nil || result.Session.Store == nil || result.Session.RootResult == nil || result.Session.RootResult.Graph == nil { + t.Fatal("missing session data") + } + root := result.Session.RootResult.Graph + parentHash := result.Session.Store.GraphParentHashOf(root.ID()) + parent := result.Session.Store.Parents()[parentHash] + functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) + paramHints := result.Session.Store.GetParamHintsSnapshot(root, parent) + + groupSym := localFunctionSymbolByName(t, root, functionFacts, "group_by_suite") + runSuiteSym := localFunctionSymbolByName(t, root, functionFacts, "run_suite") + entryArray := typ.NewArray(entryType) + suiteMap := typ.NewMap(typ.String, entryArray) + + groupFact := functionFacts[groupSym] + if len(groupFact.Summary) != 2 || !typ.TypeEquals(groupFact.Summary[0], suiteMap) || !typ.TypeEquals(groupFact.Summary[1], entryArray) { + t.Fatalf("expected group_by_suite summary (%v, %v), got %v", suiteMap, entryArray, groupFact.Summary) + } + if len(groupFact.Narrow) != 2 || !typ.TypeEquals(groupFact.Narrow[0], suiteMap) || !typ.TypeEquals(groupFact.Narrow[1], entryArray) { + t.Fatalf("expected group_by_suite narrow summary (%v, %v), got %v", suiteMap, entryArray, groupFact.Narrow) + } + groupFn := unwrap.Function(groupFact.Type) + if groupFn == nil || len(groupFn.Returns) != 2 || !typ.TypeEquals(groupFn.Returns[0], suiteMap) || !typ.TypeEquals(groupFn.Returns[1], entryArray) { + t.Fatalf("expected group_by_suite function returns (%v, %v), got %v", suiteMap, entryArray, groupFact.Type) + } + runSuiteFn := unwrap.Function(functionFacts.FunctionType(runSuiteSym)) + if runSuiteFn == nil || len(runSuiteFn.Params) < 2 || !typ.TypeEquals(runSuiteFn.Params[1].Type, entryArray) { + t.Fatalf("expected run_suite tests param to refine to %v, got %v", entryArray, 
functionFacts.FunctionType(runSuiteSym)) + } + if hints := paramHints[runSuiteSym]; len(hints) < 2 || !typ.TypeEquals(hints[1], entryArray) { + t.Fatalf("expected run_suite param hint %v, got %v", entryArray, hints) + } +} + +func localFunctionSymbolByName(t *testing.T, graph *cfg.Graph, facts api.FunctionFacts, name string) cfg.SymbolID { + t.Helper() + for sym := range facts { + if graph.NameOf(sym) == name { + return sym + } + } + t.Fatalf("missing function fact for %s", name) + return 0 } // TestManifest_InterfaceExport tests manifest with interface types. diff --git a/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go b/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go index b8a63ef2..f1486852 100644 --- a/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go +++ b/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go @@ -1,9 +1,12 @@ package regression import ( + "strings" "testing" "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" ) // Reproduces llm test pattern: @@ -49,3 +52,675 @@ contains(response.error_message, "Model is required") t.Fatal("expected no errors for assert-based discriminant narrowing") } } + +func TestRegression_DefaultedAnyFieldDoesNotSilentlyAdoptFallbackType(t *testing.T) { + source := ` +local info = nil :: any +local error_message = info.message or "fallback" + +local function needs_string(value: string) + return value +end + +needs_string(error_message) +` + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected defaulted any field to remain dynamic, not become string") + } + found := false + for _, msg := range testutil.ErrorMessages(result.Diagnostics) { + if strings.Contains(msg, "expected string, got any") { + found = true + break + } + } + if !found { + t.Fatalf("expected any-to-string 
diagnostic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ImportedAssertFalseDiscriminantNarrowing(t *testing.T) { + testMod := testutil.CheckAndExport(` +local test = {} + +function test.is_false(val: any, msg: string?) + if val ~= false then + error(msg or "expected false") + end +end + +function test.contains(str: any, substr: string, msg: string?): string + if type(str) ~= "string" or not string.find(str, substr, 1, true) then + error(msg or "expected contains") + end + return str +end + +return test +`, "test_mod", testutil.WithStdlib()) + if testMod.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testMod.Errors)) + } + + containsField := unwrap.Record(testMod.Manifest.Export).GetField("contains") + if containsField == nil { + t.Fatal("expected exported contains function") + } + containsFn := unwrap.Function(containsField.Type) + if containsFn == nil || len(containsFn.Params) == 0 || !typ.TypeEquals(containsFn.Params[0].Type, typ.Any) { + t.Fatalf("contains first param = %v, want any", containsField.Type) + } + if summary, ok := testMod.Manifest.LookupSummary("contains"); ok && summary != nil && len(summary.Params) > 0 { + if !typ.TypeEquals(summary.Params[0], typ.Any) { + t.Fatalf("contains summary first param = %v, want any", summary.Params[0]) + } + } + + producer := testutil.CheckAndExport(` +local M = {} + +function M.handler() + if true then + return { + success = false, + error = "invalid_request", + error_message = "Model is required" + } + end + return { + success = true, + result = { content = "ok" } + } +end + +return M +`, "producer", testutil.WithStdlib()) + if producer.HasError() { + t.Fatalf("unexpected producer errors: %v", testutil.ErrorMessages(producer.Errors)) + } + + source := ` +local tests = require("test_mod") +local producer = require("producer") + +local response = producer.handler() +tests.is_false(response.success) +tests.contains(response.error_message, 
"Model is required") +` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testMod), + testutil.WithModule("producer", producer), + ) + if result.HasError() { + t.Fatalf("expected imported assert false to narrow discriminant, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ImportedDiscriminantThroughMultivalueHelper(t *testing.T) { + testMod := testutil.CheckAndExport(` +local test = {} + +function test.is_false(val: any, msg: string?) + if val ~= false then + error(msg or "expected false") + end +end + +function test.contains(str: string, substr: string, msg: string?): string + if type(str) ~= "string" or not string.find(str, substr, 1, true) then + error(msg or "expected contains") + end + return str +end + +return test +`, "test_mod", testutil.WithStdlib()) + if testMod.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testMod.Errors)) + } + + mapperMod := testutil.CheckAndExport(` +local mapper = {} + +local function map_error_type(_status_code, message) + if message then + local _lower = message:lower() + end + return "invalid_request" +end + +function mapper.map_error_response(info) + local error_message = info.message or "fallback" + local error_type = map_error_type(info.status_code, error_message) + return { + success = false, + error = error_type, + error_message = error_message, + metadata = {} + }, { message = error_message } +end + +function mapper.map_success_response(_response) + return { + success = true, + result = { content = "ok" }, + metadata = {} + } +end + +return mapper +`, "mapper_mod", testutil.WithStdlib()) + if mapperMod.HasError() { + t.Fatalf("unexpected mapper errors: %v", testutil.ErrorMessages(mapperMod.Errors)) + } + + generateMod := testutil.CheckAndExport(` +local mapper = require("mapper_mod") + +local generate = { + _mapper = mapper, +} + +function generate.handler(args) + if args.bad then + return 
generate._mapper.map_error_response({ + message = "bad request", + status_code = 400, + }) + end + if args.remote_bad then + local response = args.response + return generate._mapper.map_error_response(response) + end + return generate._mapper.map_success_response({}) +end + +return generate +`, "generate_mod", testutil.WithStdlib(), testutil.WithModule("mapper_mod", mapperMod)) + if generateMod.HasError() { + t.Fatalf("unexpected generate errors: %v", testutil.ErrorMessages(generateMod.Errors)) + } + + source := ` +local tests = require("test_mod") +local generate = require("generate_mod") + +local response = generate.handler({ bad = true }) +tests.is_false(response.success) +tests.contains(response.error_message, "bad request") +` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testMod), + testutil.WithModule("generate_mod", generateMod), + ) + if result.HasError() { + t.Fatalf("expected imported multivalue helper result to narrow by success=false, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_BDDCallbackLocalImportedDiscriminant(t *testing.T) { + testMod := testutil.CheckAndExport(` +local test = { _cases = {} } + +function test.is_false(val: any, msg: string?) + if val ~= false then + error(msg or "expected false") + end +end + +function test.is_true(val: any, msg: string?) 
+ if val ~= true then + error(msg or "expected true") + end +end + +function test.contains(str: string, substr: string, msg: string?): string + if type(str) ~= "string" or not string.find(str, substr, 1, true) then + error(msg or "expected contains") + end + return str +end + +function test.describe(_name: string, fn: fun()) + fn() +end + +function test.it(_name: string, fn: fun()) + table.insert(test._cases, fn) +end + +function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + define_cases_fn() + _G.describe = nil + _G.it = nil + end +end + +return test +`, "test_mod", testutil.WithStdlib()) + if testMod.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testMod.Errors)) + } + + mapperMod := testutil.CheckAndExport(` +local mapper = {} + +local function map_error_type(_status_code, message) + if message then + local _lower = message:lower() + end + return "invalid_request" +end + +function mapper.map_error_response(info) + local error_message = info.message or "fallback" + local error_type = map_error_type(info.status_code, error_message) + return { + success = false, + error = error_type, + error_message = error_message, + metadata = {} + }, { message = error_message } +end + +function mapper.map_success_response(_response) + return { + success = true, + result = { content = "ok" }, + metadata = {} + } +end + +return mapper +`, "mapper_mod", testutil.WithStdlib()) + if mapperMod.HasError() { + t.Fatalf("unexpected mapper errors: %v", testutil.ErrorMessages(mapperMod.Errors)) + } + + generateMod := testutil.CheckAndExport(` +local mapper = require("mapper_mod") + +local generate = { + _mapper = mapper, +} + +function generate.handler(args) + if args.bad then + return generate._mapper.map_error_response({ + message = "bad request", + status_code = 400, + }) + end + if args.remote_bad then + local response = args.response + return 
generate._mapper.map_error_response(response) + end + return generate._mapper.map_success_response({}) +end + +return generate +`, "generate_mod", testutil.WithStdlib(), testutil.WithModule("mapper_mod", mapperMod)) + if generateMod.HasError() { + t.Fatalf("unexpected generate errors: %v", testutil.ErrorMessages(generateMod.Errors)) + } + + source := ` +local tests = require("test_mod") +local generate = require("generate_mod") + +local function define_tests() + describe("generate", function() + it("error response", function() + generate._mapper = { + map_error_response = function(info) + return { + success = false, + error = "invalid_request", + error_message = info.message, + metadata = {} + } + end, + map_success_response = function() + return { + success = true, + result = { content = "ok" }, + metadata = {} + } + end, + } + + local response = generate.handler({ bad = true }) + tests.is_false(response.success) + tests.contains(response.error_message, "bad request") + end) + + it("success response", function() + generate._mapper = { + map_error_response = function(info) + return { + success = false, + error = "invalid_request", + error_message = info.message, + metadata = {} + } + end, + map_success_response = function() + return { + success = true, + result = { content = "ok" }, + metadata = {} + } + end, + } + + local response = generate.handler({}) + tests.is_true(response.success) + end) + end) +end + +return tests.run_cases(define_tests) +` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testMod), + testutil.WithModule("generate_mod", generateMod), + ) + if result.HasError() { + t.Fatalf("expected BDD callback-local imported discriminant to narrow, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ImportedHandlerUsesVisibleMapperOverrideContract(t *testing.T) { + testMod := testutil.CheckAndExport(` +local test = { _cases = {} } + +function test.is_false(val: any, msg: string?) 
+ if val ~= false then + error(msg or "expected false") + end +end + +function test.contains(str: string, substr: string, msg: string?): string + if type(str) ~= "string" or not string.find(str, substr, 1, true) then + error(msg or "expected contains") + end + return str +end + +function test.describe(_name: string, fn: fun()) + fn() +end + +function test.it(_name: string, fn: fun()) + table.insert(test._cases, fn) +end + +function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + define_cases_fn() + _G.describe = nil + _G.it = nil + end +end + +return test +`, "test_mod", testutil.WithStdlib()) + if testMod.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testMod.Errors)) + } + + mapperMod := testutil.CheckAndExport(` +local mapper = {} + +function mapper.map_error_response(error_info) + local error_message = error_info.message or "Google API error" + return { + success = false, + error = "server_error", + error_message = error_message, + metadata = {} + } +end + +return mapper +`, "mapper_mod", testutil.WithStdlib()) + if mapperMod.HasError() { + t.Fatalf("unexpected mapper errors: %v", testutil.ErrorMessages(mapperMod.Errors)) + } + + contractMod := testutil.CheckAndExport(` +local contract = {} + +function contract.get(_id) + return nil, "not found" +end + +return contract +`, "contract_mod", testutil.WithStdlib()) + if contractMod.HasError() { + t.Fatalf("unexpected contract errors: %v", testutil.ErrorMessages(contractMod.Errors)) + } + + generateMod := testutil.CheckAndExport(` +local mapper = require("mapper_mod") +local contract = require("contract_mod") + +local generate = { + _mapper = mapper, + _contract = contract, +} + +function generate.handler(args) + if not args.model then + return generate._mapper.map_error_response({ + message = "Model is required", + status_code = 400, + }) + end + + local _, err = generate._contract.get("client") + if err then + return 
generate._mapper.map_error_response({ + message = "Failed to get client contract: " .. tostring(err), + status_code = 500, + }) + end + + return { success = true } +end + +return generate +`, "generate_mod", testutil.WithStdlib(), + testutil.WithModule("mapper_mod", mapperMod), + testutil.WithModule("contract_mod", contractMod)) + if generateMod.HasError() { + t.Fatalf("unexpected generate errors: %v", testutil.ErrorMessages(generateMod.Errors)) + } + + source := ` +local tests = require("test_mod") +local generate = require("generate_mod") + +local function define_tests() + describe("generate", function() + it("contract failure", function() + generate._mapper = { + map_error_response = function(error_info) + return { + success = false, + error = "server_error", + error_message = error_info.message, + metadata = {} + } + end + } + + generate._contract = { + get = function(_contract_id) + return nil, "Contract not found" + end + } + + local response = generate.handler({ + model = "gemini-1.5-pro", + messages = { + { role = "user", content = {{ type = "text", text = "Test" }} } + } + }) + + tests.is_false(response.success) + tests.contains(response.error_message, "Failed to get client contract") + end) + end) +end + +return tests.run_cases(define_tests) +` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testMod), + testutil.WithModule("generate_mod", generateMod), + ) + if result.HasError() { + t.Fatalf("expected visible mapper override contract to prove error_message, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_ErrorMapperInfersDefaultedMessageField(t *testing.T) { + mapperMod := testutil.CheckAndExport(` +local output = { + ERROR_TYPE = { + SERVER_ERROR = "server_error", + }, + to_structured_error = function(_response) + return nil + end, +} + +local mapper = {} + +local function map_error_type(_status_code, message) + if message then + local lower_msg = message:lower() + if 
lower_msg:match("timeout") then + return "timeout" + end + end + return output.ERROR_TYPE.SERVER_ERROR +end + +function mapper.map_error_response(google_error) + if not google_error then + local response = { + success = false, + error = output.ERROR_TYPE.SERVER_ERROR, + error_message = "Unknown Google error", + metadata = {} + } + return response, output.to_structured_error(response) + end + + local error_message = google_error.message or "Google API error" + local error_type = map_error_type(google_error.status_code, error_message) + + local response = { + success = false, + error = error_type, + error_message = error_message, + metadata = google_error.metadata or {} + } + return response, output.to_structured_error(response) +end + +return mapper +`, "mapper_mod", testutil.WithStdlib()) + if mapperMod.HasError() { + t.Fatalf("unexpected mapper errors: %v", testutil.ErrorMessages(mapperMod.Errors)) + } + + field := unwrap.Record(mapperMod.Manifest.Export).GetField("map_error_response") + if field == nil { + t.Fatal("expected exported map_error_response") + } + fn := unwrap.Function(field.Type) + if fn == nil || len(fn.Returns) == 0 { + t.Fatalf("expected function return, got %v", field.Type) + } + rec := unwrap.Record(fn.Returns[0]) + if rec == nil { + t.Fatalf("expected record return, got %v", fn.Returns[0]) + } + errMsg := rec.GetField("error_message") + if errMsg == nil || !typ.TypeEquals(errMsg.Type, typ.String) { + t.Fatalf("error_message = %v, want string in %v", errMsg, fn.Returns[0]) + } +} + +func TestRegression_PartialRecordParamHintsBecomeOptionalFields(t *testing.T) { + source := ` +local mapper = {} + +function mapper.map_tokens(usage) + if not usage then + return nil + end + return { + prompt_tokens = usage.promptTokenCount or 0, + completion_tokens = usage.candidatesTokenCount or 0, + total_tokens = usage.totalTokenCount or 0, + thinking_tokens = usage.thoughtsTokenCount + } +end + +mapper.map_tokens({ promptTokenCount = 10 }) +mapper.map_tokens({ 
candidatesTokenCount = 20 }) +mapper.map_tokens({ totalTokenCount = 30 }) +mapper.map_tokens({ thoughtsTokenCount = 40 }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected partial record parameter observations to form optional fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestRegression_NestedPartialRecordParamHintsBecomeOptionalFields(t *testing.T) { + source := ` +local mapper = {} + +function mapper.map_tokens(usage) + if not usage then + return nil + end + return { + prompt_tokens = usage.promptTokenCount or 0, + completion_tokens = usage.candidatesTokenCount or 0, + total_tokens = usage.totalTokenCount or 0, + thinking_tokens = usage.thoughtsTokenCount + } +end + +function mapper.map_success_response(response) + return { + tokens = mapper.map_tokens(response.usageMetadata) + } +end + +mapper.map_success_response({ usageMetadata = { promptTokenCount = 10 } }) +mapper.map_success_response({ usageMetadata = { candidatesTokenCount = 20 } }) +mapper.map_success_response({ usageMetadata = { totalTokenCount = 30 } }) +mapper.map_success_response({ usageMetadata = { thoughtsTokenCount = 40 } }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected partial record parameter observations to form optional fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go new file mode 100644 index 00000000..10b2ac32 --- /dev/null +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -0,0 +1,586 @@ +package regression + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestExternalLint_OptionalResponseBodyDefaultIsStringAtCall(t *testing.T) { + source := ` +local json = {} +function json.decode(raw: string): any + return {} 
+end + +type Stream = { + read: (self: Stream, n: number?) -> (string?, string?), +} + +type Response = { + status_code: number, + body: string?, + stream: Stream?, +} + +local function get_response(): Response + local stream: Stream = { + read = function(self: Stream, n: number?) + return "chunk", nil + end, + } + return { status_code = 500, stream = stream } +end + +local response = get_response() +if response.status_code >= 300 then + if response.stream and not response.body then + local body_data = response.stream:read(4096) + response.body = body_data + end +end + +local parsed, parse_err = json.decode(response.body or "") +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected optional body fallback and guarded stream read to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedOptionsModelSurvivesProviderBranches(t *testing.T) { + source := ` +local models = { + get_by_name = function(model_id: string) + return { name = model_id, providers = { { id = "provider", provider_model = model_id, options = {} } } }, nil + end, + get_by_class = function(class_id: string) + return { { name = class_id, providers = { { id = "provider", provider_model = class_id, options = {} } } } }, nil + end, +} + +local providers = { + open = function(provider_id: string, options: table) + return { + generate = function(self, args) + return { success = true, result = args.model }, nil + end, + }, nil + end, +} + +local security = { + actor = function() + return nil + end, +} + +local llm = {} +llm._models = nil +llm._providers = nil + +local function resolve_model(model_identifier) + local models_module = llm._models or models + local class_name = model_identifier:match("^class:(.+)") + if class_name then + local class_models, err = models_module.get_by_class(class_name) + if err then + return nil, err + end + if class_models and #class_models > 0 then + return class_models[1] + end 
+ return nil, "No models found" + end + + local model_card, err = models_module.get_by_name(model_identifier) + if model_card then + return model_card + end + + local class_models, class_err = models_module.get_by_class(model_identifier) + if not class_err and class_models and #class_models > 0 then + return class_models[1] + end + + return nil, "Model not found" +end + +local function merge_user_options(contract_args, user_options, exclude_keys) + exclude_keys = exclude_keys or {} + for k, v in pairs(user_options) do + local should_exclude = false + for _, exclude_key in ipairs(exclude_keys) do + if k == exclude_key then + should_exclude = true + break + end + end + if not should_exclude then + contract_args.options[k] = v + end + end +end + +function llm.generate(prompt_input, options) + if not options or not options.model then + return nil, "Model is required in options" + end + + local actor = security.actor() + if actor then + options.user = actor:id() + end + + if options.provider_id then + local providers_module = llm._providers or providers + local provider_instance, err = providers_module.open(options.provider_id, {}) + if not provider_instance then + return nil, err + end + + local contract_args = { + messages = prompt_input, + model = options.model, + options = {}, + } + merge_user_options(contract_args, options, {"model", "provider_id"}) + return (provider_instance as any):generate(contract_args) + end + + local model_card, err = resolve_model(options.model) + if not model_card then + return nil, err + end + + local provider_info = model_card.providers[1] + local providers_module = llm._providers or providers + local provider_instance, open_err = providers_module.open(provider_info.id, provider_info.options or {}) + if not provider_instance then + return nil, open_err + end + + local contract_args = { + messages = prompt_input, + model = provider_info.provider_model, + options = {}, + } + merge_user_options(contract_args, options, {"model"}) + return 
(provider_instance as any):generate(contract_args) +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected guarded options.model/provider_id to satisfy helper/provider contracts, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedParamFieldFeedsKnownCall(t *testing.T) { + source := ` +local providers = { + open = function(provider_id: string, options: table) + return { id = provider_id }, nil + end, +} + +local function generate(options) + if options.provider_id then + return providers.open(options.provider_id, {}) + end + return nil, nil +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded parameter field to infer from known call expectation, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedParamFieldFeedsFallbackModuleCall(t *testing.T) { + source := ` +local providers = { + open = function(provider_id: string, options: table) + return { id = provider_id }, nil + end, +} + +local api = {} +api._providers = nil + +local function generate(options) + if options.provider_id then + local providers_module = api._providers or providers + return providers_module.open(options.provider_id, {}) + end + return nil, nil +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded parameter field to infer through fallback module alias, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedParamFieldSurvivesSiblingFieldMutation(t *testing.T) { + source := ` +local providers = { + open = function(provider_id: string, options: table) + return { id = provider_id }, nil + end, +} + +local security = { + actor = function() + return nil + end, +} + +local api = {} 
+api._providers = nil + +local function generate(options) + if not options or not options.model then + return nil, "model required" + end + local actor = security.actor() + if actor then + options.user = actor:id() + end + if options.provider_id then + local providers_module = api._providers or providers + return providers_module.open(options.provider_id, {}) + end + return nil, nil +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded parameter field to survive sibling field mutation, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_CastFieldExpressionFeedsCallArgument(t *testing.T) { + source := ` +local funcs = {} + +function funcs.new() + return { + call = function(self, name: string, context: table) + return { id = name }, nil + end, + } +end + +local function get_page_data(page) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + + local executor = funcs.new() + local result, err = executor:call(page.data_func :: string, {}) + return result, err +end + +local result, err = get_page_data({ data_func = true }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected explicit field cast to feed call argument checking, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_InsertedSuiteShapeSurvivesIpairs(t *testing.T) { + source := ` +type Suite = { + name: string, + tests: {any}, + children: {Suite}, + full_path: string, + before_all: any?, + after_all: any?, + before_each: any?, + after_each: any?, +} + +local test = {} +local _default_context = { + suites_hierarchy = {}, +} + +function test.suite(name: string): Suite + return { + name = name, + tests = {}, + children = {}, + full_path = name, + } +end + +local function run_suite_with_children(suite: Suite) + for _, child in ipairs(suite.children) do + run_suite_with_children(child) + end 
+end + +local suite: Suite = test.suite("top") +table.insert(_default_context.suites_hierarchy, suite) + +for _, top_suite in ipairs(_default_context.suites_hierarchy) do + run_suite_with_children(top_suite) +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected inserted suite shape to survive ipairs, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_CapturedTableInsertFeedsCleanupLoop(t *testing.T) { + source := ` +type Suite = { + name: string, + tests: {any}, + children: {Suite}, + full_path: string, + before_all: any?, + after_all: any?, + before_each: any?, + after_each: any?, +} + +local test = {} +local _default_context = { + tests = {}, + suites_hierarchy = {}, + results = { + tests = {}, + }, +} + +function test.suite(name: string): Suite + return { + name = name, + tests = {}, + children = {}, + full_path = name, + } +end + +function test.describe(name: string) + local new_suite = test.suite(name) + table.insert(_default_context.suites_hierarchy, new_suite) + table.insert(_default_context.tests, new_suite) + return new_suite +end + +local function clear_suite_references(suite: Suite) + if suite.tests then + for i, test_case in ipairs(suite.tests) do + suite.tests[i].fn = nil + end + end + suite.before_all = nil + suite.after_all = nil + suite.before_each = nil + suite.after_each = nil + suite.children = {} + for _, child in ipairs(suite.children or {}) do + clear_suite_references(child) + end +end + +local function cleanup_test_resources() + for _, suite in ipairs(_default_context.suites_hierarchy) do + clear_suite_references(suite) + end + _default_context.tests = {} + _default_context.suites_hierarchy = {} + _default_context.results.tests = {} +end + +test.describe("top") +cleanup_test_resources() +` + result := testutil.Check(source, 
testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected inserted suite shape to survive context resets, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ExportedDescribeFeedsLaterRunLoop(t *testing.T) { + source := ` +type Suite = { + name: string, + tests: {any}, + children: {Suite}, + full_path: string, + before_all: any?, + after_all: any?, + before_each: any?, + after_each: any?, +} + +local test = {} +local _default_context = { + suites_hierarchy = {}, + current_describe = nil, +} + +function test.suite(name: string): Suite + return { + name = name, + tests = {}, + children = {}, + full_path = name, + } +end + +function test.describe(name: string, fn: any) + local old_describe = _default_context.current_describe + local new_suite = test.suite(name) + if old_describe then + new_suite.parent = old_describe + table.insert(old_describe.children, new_suite) + new_suite.full_path = old_describe.full_path .. " > " .. 
name + else + table.insert(_default_context.suites_hierarchy, new_suite) + end + _default_context.current_describe = new_suite + fn() + _default_context.current_describe = old_describe + return new_suite +end + +function test.run() + local function clear_suite_references(suite: Suite) + suite.before_all = nil + suite.after_all = nil + suite.before_each = nil + suite.after_each = nil + suite.children = {} + end + for _, suite in ipairs(_default_context.suites_hierarchy) do + clear_suite_references(suite) + end +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected exported describe table insert to feed later run loop, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_BodyCallExpectationInfersWholeParameter(t *testing.T) { + source := ` +local http = { + get = function(url: string, options: {headers: {[string]: string}, stream?: boolean}) + return { status_code = 200, body = "{}" }, nil + end, +} + +local client = { + _http_client = http, +} + +function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + if http_options.stream then + url = url .. 
"?alt=sse" + end + return client._http_client.get(url, http_options) +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected body call expectations to infer whole parameter shape, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_TypeProbeAllowsOptionalDynamicFieldFallback(t *testing.T) { + source := ` +local page = { + id = "home", + name = "Home", +} + +local placement: string = type(page.placement) == "string" and page.placement or "default" +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected type() field probe fallback to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_NestedTableInsertFeedsIpairs(t *testing.T) { + source := ` +local state = { + items = {}, +} + +local value: string = "x" +table.insert(state.items, value) + +for _, item in ipairs(state.items) do + local s: string = item +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected nested table.insert to feed ipairs element type, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_DiscriminatedArrayElementFeedsBranchHelper(t *testing.T) { + source := ` +local function convert_image_to_converse(content_part) + if content_part.type == "image" and content_part.source then + return { image = content_part.source } + end + return nil +end + +local message = { + content = { + { type = "text", text = "hello" }, + { type = "image", source = { media_type = "image/png", data = "abc" } }, + }, +} + +local content_blocks = {} +for _, part in ipairs(message.content) do + if part.type == "text" and 
part.text and part.text ~= "" then + table.insert(content_blocks, { text = part.text }) + elseif part.type == "image" then + local img = convert_image_to_converse(part) + if img then + table.insert(content_blocks, img) + end + end +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected discriminated array element to feed image helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/false_positives_unit_test.go b/compiler/check/tests/regression/false_positives_unit_test.go index c3ee6dd1..c7fffb14 100644 --- a/compiler/check/tests/regression/false_positives_unit_test.go +++ b/compiler/check/tests/regression/false_positives_unit_test.go @@ -107,6 +107,1132 @@ func TestFalsePositive_TupleDynamicIndexing(t *testing.T) { } } +func TestFalsePositive_CapturedLocalHelperReceivesGuardedParamField(t *testing.T) { + source := ` + local api = {} + + local function resolve_model(model_identifier) + local class_name = model_identifier:match("^class:(.+)") + if class_name then + return { id = class_name }, nil + end + return { id = model_identifier }, nil + end + + function api.generate(options) + if not options or not options.model then + return nil, "model required" + end + + local model_card, err = resolve_model(options.model) + if not model_card then + return nil, err + end + return model_card.id + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded param field to satisfy captured helper parameter, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_OptionalMapDefaultEmptyTable(t *testing.T) { + source := ` +local function find(options: {[string]: any}?) 
+ options = options or {} + local criteria: {[string]: any} = {} + for k, v in pairs(options) do + criteria[k] = v + end +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected optional map default to empty table to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_GuardedTableElementKeepsRecordFields(t *testing.T) { + source := ` + local blocks = {} + local data = nil :: any + local event_type = nil :: any + + while true do + if event_type == "content_block_start" then + if data.index ~= nil and data.content_block then + if data.content_block.type == "thinking" then + blocks[data.index] = { + type = "thinking", + thinking = data.content_block.thinking or "", + signature = data.content_block.signature or "", + } + end + end + elseif event_type == "content_block_delta" then + local index = data.index or 0 + local delta = data.delta or {} + if delta.type == "thinking_delta" then + local thinking_chunk = delta.thinking or "" + if blocks[index] then + blocks[index].thinking = blocks[index].thinking .. thinking_chunk + end + elseif delta.type == "signature_delta" then + local signature_chunk = delta.signature or "" + if blocks[index] then + blocks[index].signature = blocks[index].signature .. 
signature_chunk + end + end + end + break + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Errorf("expected guarded table element fields to stay available, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_NestedAnyFieldFallbackInArithmetic(t *testing.T) { + source := ` + local stats = nil :: any + if stats.cpu_stats and stats.precpu_stats then + local cpu_delta = (stats.cpu_stats.cpu_usage and stats.cpu_stats.cpu_usage.total_usage or 0) - + (stats.precpu_stats.cpu_usage and stats.precpu_stats.cpu_usage.total_usage or 0) + local sys_delta = (stats.cpu_stats.system_cpu_usage or 0) - (stats.precpu_stats.system_cpu_usage or 0) + if sys_delta > 0 and cpu_delta > 0 then + local cpu_percent = (cpu_delta / sys_delta) * 100 + end + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Errorf("expected nested any field fallback arithmetic to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ParsedAnyBodyNestedFieldFallbackInArithmetic(t *testing.T) { + source := ` + local function parse_response(body: any) + if not body then + return nil + end + if type(body) == "table" then + return body + end + if type(body) == "string" then + return { raw = body } + end + return body + end + + local function container_stats() + local response = nil :: any + local result = { + body = parse_response(response.body), + } + return result.body, nil + end + + local stats, err = container_stats() + if err then + return + end + if stats.cpu_stats and stats.precpu_stats then + local cpu_delta = (stats.cpu_stats.cpu_usage and stats.cpu_stats.cpu_usage.total_usage or 0) - + (stats.precpu_stats.cpu_usage and stats.precpu_stats.cpu_usage.total_usage or 0) + local sys_delta = (stats.cpu_stats.system_cpu_usage or 0) - (stats.precpu_stats.system_cpu_usage or 0) + if sys_delta > 0 and cpu_delta > 0 then + local cpu_percent = 
(cpu_delta / sys_delta) * 100 + end + end + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Errorf("expected parsed any body arithmetic to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ErrorReturnSuccessWithImplicitNilErrorNarrowsSibling(t *testing.T) { + httpResponse := typ.NewRecord(). + Field("status_code", typ.Integer). + OptField("body", typ.String). + Build() + httpManifest := io.NewManifest("http_client") + httpManifest.SetExport(typ.NewRecord(). + Field("get", typ.Func(). + Param("url", typ.String). + OptParam("options", typ.Any). + Returns(typ.NewOptional(httpResponse), typ.NewOptional(typ.String)). + Spec(contract.NewSpec().WithEffects(effect.ErrorReturn{ValueIndex: 0, ErrorIndex: 1})). + Build()). + Build()) + + testModule := testutil.CheckAndExport(` + local test = {} + + function test.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "assertion failed") + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + jsonModule := testutil.CheckAndExport(` + local json = {} + + function json.decode(raw: string): any + if raw == "" then + return nil, "empty" + end + return { + candidates = { + { content = { parts = { { text = "Hello" } } } } + } + } + end + + return json + `, "json", testutil.WithStdlib()) + if jsonModule.HasError() { + t.Fatalf("unexpected json module errors: %v", testutil.ErrorMessages(jsonModule.Errors)) + } + + source := ` + local http_client = require("http_client") + local json = require("json") + local test = require("test_mod") + + local client = { + _http_client = http_client, + } + + local function parse_error_response(http_response) + return { + status_code = http_response.status_code, + message = "request failed", + } + end + + function client.request(method, url, http_options) + local 
response, err = client._http_client.get(url, http_options) + if not response then + return nil, { + status_code = 0, + message = tostring(err), + } + end + + if response.status_code < 200 or response.status_code >= 300 then + return nil, parse_error_response(response) + end + + local parsed, parse_err = json.decode(tostring(response.body or "")) + if parse_err then + return nil, { + status_code = response.status_code, + message = parse_err, + metadata = {}, + } + end + + parsed.metadata = {} + parsed.status_code = response.status_code + return parsed + end + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + test.is_nil(err) + local text = response.candidates[1].content.parts[1].text + return text + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithManifest("http_client", httpManifest), + testutil.WithModule("json", jsonModule), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected is_nil(err) to narrow implicit-success error return sibling, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_MethodReceiverParamHintInfersCapturedSelfFields(t *testing.T) { + source := ` + type Output = { + kind: "rendered", + label: string?, + } + + type HandlerBuilder = { + name: string?, + prefix: string?, + prefix_with: (self: HandlerBuilder, prefix: string) -> HandlerBuilder, + build: (self: HandlerBuilder) -> () -> Output, + } + + type Builder = HandlerBuilder + + local Builder = {} + Builder.__index = Builder + + local M = {} + + function M.new(): HandlerBuilder + local self: Builder = { + name = nil, + prefix = nil, + prefix_with = Builder.prefix_with, + build = Builder.build, + } + setmetatable(self, Builder) + return self + end + + function Builder:prefix_with(prefix: string): Builder + self.prefix = prefix + return self + end + + function Builder:build(): () -> Output + local name = self.name or "plugin" + local prefix = 
self.prefix or name + local check_prefix: string = prefix + + return function(): Output + return { + kind = "rendered", + label = prefix, + } + end + end + + local handler = M.new() + :prefix_with("render") + :build() + + local out: Output = handler() + ` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Errorf("expected method receiver hints to type captured builder fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ErrorReturnDelegatedHelperNarrowsSibling(t *testing.T) { + testModule := testutil.CheckAndExport(` + local test = {} + + function test.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "assertion failed") + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + source := ` + local test = require("test_mod") + + local function finish(ok) + if not ok then + return nil, { message = "failed" }, nil + end + return { + candidates = { + { content = { parts = { { text = "Hello" } } } } + } + }, nil, { source = "finish" } + end + + local function request(ok) + if not ok then + return nil, { message = "failed early" }, nil + end + return finish(ok) + end + + local response, err = request(true) + test.is_nil(err) + local text = response.candidates[1].content.parts[1].text + return text + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected is_nil(err) to narrow delegated error-return helper sibling, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ImportedClientMockResponseKeepsDecodedArrayPresence(t *testing.T) { + jsonManifest := io.NewManifest("json") + jsonManifest.SetExport(typ.NewRecord(). + Field("encode", typ.Func(). + Param("value", typ.Any). 
+ Returns(typ.String, typ.NewOptional(typ.LuaError)). + Build()). + Field("decode", typ.Func(). + Param("source", typ.String). + Returns(typ.Any, typ.NewOptional(typ.LuaError)). + Build()). + Build()) + + testModule := testutil.CheckAndExport(` + local test = {} + + function test.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "assertion failed") + end + end + + function test.eq(_actual: any, _expected: any, _msg: string?) + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + clientModule := testutil.CheckAndExport(` + local json = require("json") + + local client = { + _http_client = nil :: any, + } + + local function parse_error_response(http_response) + return { + status_code = http_response.status_code, + message = "request failed", + } + end + + function client.request(method, url, http_options) + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + response, err = client._http_client.post(url, http_options) + end + + if not response then + return nil, { + status_code = 0, + message = tostring(err), + } + end + + if response.status_code < 200 or response.status_code >= 300 then + return nil, parse_error_response(response) + end + + local parsed, parse_err = json.decode(tostring(response.body or "")) + if parse_err then + return nil, { + status_code = response.status_code, + message = tostring(parse_err), + metadata = {}, + } + end + + parsed.metadata = {} + parsed.status_code = response.status_code + return parsed + end + + return client + `, "client_mod", testutil.WithStdlib(), testutil.WithManifest("json", jsonManifest)) + if clientModule.HasError() { + t.Fatalf("unexpected client module errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local client = require("client_mod") + local json = 
require("json") + local test = require("test_mod") + + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = json.encode({ + candidates = { + { content = { parts = { { text = "Hello" } } } } + } + }) + } + end + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + test.is_nil(err) + test.eq(response.candidates[1].content.parts[1].text, "Hello") + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithManifest("json", jsonManifest), + testutil.WithModule("client_mod", clientModule), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected imported client mock response to preserve decoded array presence after err narrowing, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ImportedGoogleLikeClientSuiteKeepsCandidatesArray(t *testing.T) { + jsonManifest := io.NewManifest("json") + jsonManifest.SetExport(typ.NewRecord(). + Field("encode", typ.Func(). + Param("value", typ.Any). + Returns(typ.String, typ.NewOptional(typ.LuaError)). + Build()). + Field("decode", typ.Func(). + Param("source", typ.String). + Returns(typ.Any, typ.NewOptional(typ.LuaError)). + Build()). + Build()) + + streamReaderType := typ.NewInterface("http_client.StreamReader", []typ.Method{ + {Name: "read", Type: typ.Func().Param("self", typ.Self).OptParam("size", typ.Number).Returns(typ.String, typ.NewOptional(typ.LuaError)).Build()}, + }) + httpResponse := typ.NewRecord(). + Field("status_code", typ.Number). + OptField("body", typ.String). + OptField("stream", streamReaderType). + Build() + httpFn := typ.Func(). + Param("url", typ.String). + OptParam("options", typ.Any). + Returns(httpResponse, typ.NewOptional(typ.LuaError)). + Build() + httpManifest := io.NewManifest("http_client") + httpManifest.SetExport(typ.NewRecord(). + Field("get", httpFn). + Field("post", httpFn). 
+ Build()) + + outputModule := testutil.CheckAndExport(` + local output = {} + + function output.streamer(_pid: string?, _topic: string?, _buffer_size: any?) + return { + buffer_content = function(self, _text: string?) return true end, + send_tool_call = function(self, _name: string, _arguments: string, _id: string?) return true end, + send_thinking = function(self, _text: string) return true end, + send_error = function(self, _kind: string, _message: string, _code: any?) return true end, + flush = function(self) return true end, + }, nil + end + + return output + `, "output_mod", testutil.WithStdlib()) + if outputModule.HasError() { + t.Fatalf("unexpected output module errors: %v", testutil.ErrorMessages(outputModule.Errors)) + } + + testModule := testutil.CheckAndExport(` + local test = {} + + function test.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "assertion failed") + end + end + + function test.eq(_actual: any, _expected: any, _msg: string?) + end + + function test.not_nil(val: any, msg: string?): any + if val == nil then + error(msg or "assertion failed") + end + return val + end + + function test.describe(_name: string, fn: fun()) + fn() + end + + function test.it(_name: string, fn: fun()) + fn() + end + + function test.after_each(fn: fun()) + fn() + end + + function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + _G.after_each = test.after_each + define_cases_fn() + _G.describe = nil + _G.it = nil + _G.after_each = nil + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + clientModule := testutil.CheckAndExport(` + local json = require("json") + local http_client = require("http_client") + local output = require("output_mod") + + local client = { + _http_client = http_client, + } + + local function 
extract_response_metadata(response_body: any) + if not response_body then + return {} + end + return { + model_version = response_body.modelVersion, + response_id = response_body.responseId, + create_time = response_body.createTime, + } + end + + function client.process_stream(stream_response, callbacks) + callbacks = callbacks or {} + local on_done = callbacks.on_done or function(_result) end + local metadata = stream_response.metadata or {} + local result = { + content = "stream", + tool_calls = {}, + finish_reason = "stop", + usage = nil, + metadata = metadata, + } + on_done(result) + return "stream", nil, result + end + + local function handle_stream_response(response, http_options) + local streamer = output.streamer(http_options.stream_reply_to, http_options.stream_topic, http_options.stream_buffer_size or 10) + if not streamer then + return nil, { status_code = 500, message = "Failed to create streamer" } + end + + local full_content = "" + local tool_call_parts = {} + local finish_reason = nil + local usage_metadata = nil + local response_metadata = {} + local callbacks = { + on_content = function(chunk: string) + full_content = full_content .. 
chunk + streamer:buffer_content(chunk) + end, + on_tool_call = function(tool_part: any) + table.insert(tool_call_parts, tool_part) + end, + on_done = function(result) + streamer:flush() + finish_reason = result.finish_reason + usage_metadata = result.usage + response_metadata = result.metadata + end, + } + + local _, stream_err = client.process_stream({ stream = response.stream, metadata = {} }, callbacks) + if stream_err then + return nil, { status_code = 500, message = tostring(stream_err) } + end + + local parts = {} + if full_content ~= "" then + table.insert(parts, { text = full_content }) + end + for _, tc_part in ipairs(tool_call_parts) do + table.insert(parts, tc_part) + end + + return { + candidates = { + { + content = { parts = parts, role = "model" }, + finishReason = finish_reason, + }, + }, + usageMetadata = usage_metadata, + metadata = response_metadata, + status_code = response.status_code or 200, + } + end + + function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + if http_options.stream then + url = url .. 
"?alt=sse" + http_options.headers["Accept"] = "text/event-stream" + end + + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + http_options.headers["Content-Type"] = "application/json" + response, err = client._http_client.post(url, http_options) + end + + if not response then + return nil, { status_code = 0, message = tostring(err) } + end + + if response.status_code < 200 or response.status_code >= 300 then + return nil, { status_code = response.status_code, message = "bad" } + end + + if http_options.stream and response.stream then + return handle_stream_response(response, http_options) + end + + local parsed, parse_err = json.decode(response.body or "") + if parse_err then + return nil, { status_code = response.status_code, message = tostring(parse_err), metadata = {} } + end + + parsed.metadata = extract_response_metadata(parsed) + parsed.status_code = response.status_code + return parsed + end + + return client + `, "client_mod", + testutil.WithStdlib(), + testutil.WithManifest("json", jsonManifest), + testutil.WithManifest("http_client", httpManifest), + testutil.WithModule("output_mod", outputModule), + ) + if clientModule.HasError() { + t.Fatalf("unexpected client module errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local client = require("client_mod") + local json = require("json") + local tests = require("test_mod") + + local function define_tests() + describe("client", function() + after_each(function() + client._http_client = nil + end) + + it("data response", function() + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = json.encode({ data = "test" }), + } + end, + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + tests.is_nil(err) + tests.eq(response.data, "test") + end) + + it("post response", function() + client._http_client = { + post = 
function(_url, _options) + return { + status_code = 200, + body = json.encode({ data = "test" }), + } + end, + } + + local response, err = client.request("POST", "https://example.test", { headers = {}, body = json.encode({ test = "data" }) }) + tests.is_nil(err) + tests.eq(response.data, "test") + end) + + it("candidate response", function() + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = json.encode({ + candidates = { + { content = { parts = { { text = "Hello" } } } }, + }, + }), + } + end, + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + tests.is_nil(err) + tests.eq(response.candidates[1].content.parts[1].text, "Hello") + end) + end) + end + + return tests.run_cases(define_tests) + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithManifest("json", jsonManifest), + testutil.WithModule("client_mod", clientModule), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected imported Google-like client suite to preserve candidates as an array, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ImportedMutableClientFieldIsCallSiteSensitive(t *testing.T) { + testModule := testutil.CheckAndExport(` + local test = {} + + function test.is_nil(val: any, msg: string?) 
+ if val ~= nil then + error(msg or "assertion failed") + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + clientModule := testutil.CheckAndExport(` + local client = { + _http_client = nil :: any, + } + + function client.request() + local response, err = client._http_client.get() + if not response then + return nil, { message = tostring(err) } + end + return response.body + end + + return client + `, "client_mod", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("unexpected client module errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local client = require("client_mod") + local test = require("test_mod") + + local function candidate_case() + client._http_client = { + get = function() + return { + body = { + candidates = { + { content = { parts = { { text = "Hello" } } } } + } + } + } + end + } + + local response, err = client.request() + test.is_nil(err) + return response.candidates[1].content.parts[1].text + end + + local function data_case() + client._http_client = { + get = function() + return { + body = { + data = "other", + } + } + end + } + + local response, err = client.request() + test.is_nil(err) + return response.data + end + + return candidate_case(), data_case() + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("client_mod", clientModule), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected imported mutable client field calls to use the visible mock at each call site, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_ImportedMutableClientFieldIsCallbackLocal(t *testing.T) { + testModule := testutil.CheckAndExport(` + local test = { _cases = {} } + + function test.is_nil(val: any, msg: string?) 
+ if val ~= nil then + error(msg or "assertion failed") + end + end + + function test.eq(_actual: any, _expected: any, _msg: string?) + end + + function test.describe(_name: string, fn: fun()) + fn() + end + + function test.it(_name: string, fn: fun()) + table.insert(test._cases, fn) + end + + function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + define_cases_fn() + _G.describe = nil + _G.it = nil + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + clientModule := testutil.CheckAndExport(` + local client = { + _http_client = nil :: any, + } + + function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + + if http_options.stream then + return { + candidates = { + { content = { parts = { { text = "stream" } }, role = "model" } } + } + }, nil + end + + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + response, err = client._http_client.post(url, http_options) + end + + if not response then + return nil, { message = tostring(err) } + end + if response.status_code < 200 or response.status_code >= 300 then + return nil, { status_code = response.status_code, message = "bad" } + end + + return response.body, nil + end + + return client + `, "client_mod", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("unexpected client module errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local client = require("client_mod") + local tests = require("test_mod") + + local function define_tests() + describe("client", function() + it("candidate response", function() + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = { + candidates = { + { content = { parts = { 
{ text = "Hello" } }, role = "model" } } + } + } + } + end + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + tests.is_nil(err) + tests.eq(response.candidates[1].content.parts[1].text, "Hello") + end) + + it("data response", function() + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = { data = "test" } + } + end + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + tests.is_nil(err) + tests.eq(response.data, "test") + end) + end) + end + + return tests.run_cases(define_tests) + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("client_mod", clientModule), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected imported client mock fields to stay callback-local, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestFalsePositive_CallbackLocalDelegatedErrorReturnNarrowsSibling(t *testing.T) { + testModule := testutil.CheckAndExport(` + local test = { _cases = {} } + + function test.is_nil(val: any, msg: string?) + if val ~= nil then + error(msg or "assertion failed") + end + end + + function test.eq(_actual: any, _expected: any, _msg: string?) 
+ end + + function test.describe(_name: string, fn: fun()) + fn() + end + + function test.it(_name: string, fn: fun()) + table.insert(test._cases, fn) + end + + function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + define_cases_fn() + _G.describe = nil + _G.it = nil + end + end + + return test + `, "test_mod", testutil.WithStdlib()) + if testModule.HasError() { + t.Fatalf("unexpected test module errors: %v", testutil.ErrorMessages(testModule.Errors)) + } + + source := ` + local tests = require("test_mod") + + local client = { + _http_client = nil :: any, + } + + local function handle_stream_response(response, http_options) + if response.err then + return nil, { message = "stream failed" } + end + return { + candidates = { + { content = { parts = { { text = "stream" } }, role = "model" } } + } + } + end + + function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + if http_options.stream then + http_options.headers["Accept"] = "text/event-stream" + end + + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + response, err = client._http_client.post(url, http_options) + end + + if not response then + return nil, { message = tostring(err) } + end + if response.status_code < 200 or response.status_code >= 300 then + return nil, { status_code = response.status_code, message = "bad" } + end + if http_options.stream and response.stream then + return handle_stream_response(response, http_options) + end + return response.body + end + + local function define_tests() + describe("client", function() + it("data response", function() + client._http_client = { + get = function(_url, _options) + return { + status_code = 200, + body = { data = "test" } + } + end + } + + local response, err = client.request("GET", "https://example.test", { headers = {} }) + tests.is_nil(err) + 
tests.eq(response.data, "test") + end) + end) + end + + return tests.run_cases(define_tests) + ` + result := testutil.Check(source, + testutil.WithStdlib(), + testutil.WithModule("test_mod", testModule), + ) + if result.HasError() { + t.Errorf("expected delegated error-return relation to narrow inside callback, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestFalsePositive_ArithmeticOnOptionalAfterGuard(t *testing.T) { source := ` local values: {[integer]: number} = {10, 20, 30} diff --git a/compiler/check/tests/regression/http_timeout_option_inference_test.go b/compiler/check/tests/regression/http_timeout_option_inference_test.go index f9ba6636..d1dd67cc 100644 --- a/compiler/check/tests/regression/http_timeout_option_inference_test.go +++ b/compiler/check/tests/regression/http_timeout_option_inference_test.go @@ -59,6 +59,220 @@ local _ = http_client.post("https://example.local", http_options) } } +func TestRegression_TonumberDefaultLiteralIsNumber(t *testing.T) { + source := ` +local function resolve_string(_key: string, _default_env: string?): string? + return nil +end + +local timeout: number = tonumber(resolve_string("timeout", "HTTP_TIMEOUT")) or 600 +` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatal("expected tonumber default literal to infer number") + } +} + +func TestRegression_HttpTimeoutFromResolvedConfigRemainsNumber(t *testing.T) { + source := ` +type HttpOptions = { + headers?: {[string]: string}, + timeout?: number, + body?: string, + stream?: boolean +} + +local http_client = { + post = function(url: string, opts: HttpOptions): {status_code: number, body: string} + return {status_code = 200, body = ""} + end +} + +local client = { + _ctx = { + all = function(): {[string]: any} + return {} + end + }, + _env = { + get = function(_key: string): string? 
+ return nil + end + }, +} + +local function resolve_config() + local ctx_all = client._ctx.all() or {} + + local function resolve_string(key: string, default_env: string?): string? + if ctx_all[key] then + return tostring(ctx_all[key]) + end + if default_env then + local val = client._env.get(default_env) + if val and val ~= "" then return val end + end + return nil + end + + return { + timeout = tonumber(resolve_string("timeout", "HTTP_TIMEOUT")) or 600, + headers = ctx_all.headers, + } +end + +function client.request(options) + options = options or {} + local config = resolve_config() + local headers: {[string]: string} = {} + local http_options = { + headers = headers, + timeout = tonumber(options.timeout) or config.timeout, + } + http_options.body = "{}" + http_options.stream = true + local _ = http_client.post("https://example.local", http_options) +end +` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatal("expected resolved config timeout to remain number") + } +} + +func TestRegression_HttpOptionsParamProjectionCompatibleWithManifestCall(t *testing.T) { + source := ` +type HttpOptions = { + headers?: {[string]: string}, + body?: string, + stream?: boolean, + stream_buffer_size?: number, + stream_reply_to?: string, + stream_topic?: string, +} + +local http_client = { + get = function(url: string, opts: HttpOptions): ({status_code: number, body: string?}?, string?) + return {status_code = 200, body = ""}, nil + end, + post = function(url: string, opts: HttpOptions): ({status_code: number, body: string?}?, string?) + return {status_code = 200, body = ""}, nil + end, +} + +local client = { + _http_client = http_client, +} + +function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + + if http_options.stream then + url = url .. 
"?alt=sse" + http_options.headers["Accept"] = "text/event-stream" + end + + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + http_options.headers["Content-Type"] = "application/json" + response, err = client._http_client.post(url, http_options) + end + + return response, err +end + +client.request("GET", "https://example.local", { headers = {} }) +` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatal("expected projected http options parameter to remain compatible with manifest call") + } +} + +func TestRegression_HttpOptionsMultipleCallHintsStillUseBodyContract(t *testing.T) { + source := ` +type HttpOptions = { + headers?: {[string]: string}, + body?: string, + stream?: boolean, + stream_buffer_size?: number, + stream_reply_to?: string, + stream_topic?: string, +} + +local http_client = { + get = function(url: string, opts: HttpOptions): ({status_code: number, body: string?}?, string?) + return {status_code = 200, body = ""}, nil + end, + post = function(url: string, opts: HttpOptions): ({status_code: number, body: string?}?, string?) + return {status_code = 200, body = ""}, nil + end, +} + +local client = { _http_client = http_client } + +function client.request(method, url, http_options) + http_options.headers["Accept"] = "application/json" + if http_options.stream then + url = url .. 
"?alt=sse" + http_options.headers["Accept"] = "text/event-stream" + end + local response = nil + local err = nil + if method == "GET" then + response, err = client._http_client.get(url, http_options) + else + http_options.headers["Content-Type"] = "application/json" + response, err = client._http_client.post(url, http_options) + end + return response, err +end + +local function call_one() + return client.request("GET", "https://example.local", { + headers = {}, + stream_buffer_size = 4096, + }) +end + +local function call_two() + return client.request("POST", "https://example.local", { + headers = {}, + body = "{}", + stream = true, + stream_reply_to = "reply", + stream_topic = "topic", + }) +end + +call_one() +call_two() +` + + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatal("expected body contract to dominate compatible multi-call http option hints") + } +} + func TestRegression_HttpTimeoutValueNarrowing(t *testing.T) { source := ` local options: {[string]: any} = {} diff --git a/compiler/check/tests/regression/imported_record_helper_param_test.go b/compiler/check/tests/regression/imported_record_helper_param_test.go index 1489c0ae..f9046cd1 100644 --- a/compiler/check/tests/regression/imported_record_helper_param_test.go +++ b/compiler/check/tests/regression/imported_record_helper_param_test.go @@ -103,3 +103,47 @@ func TestRegression_ImportedRecordHelperRejectsAnyPassedToStringMethod(t *testin t.Fatalf("expected an error when any flows into imported string-only method") } } + +func TestRegression_ImportedRecordHelperWithTableStoredModule(t *testing.T) { + clientModule := testutil.CheckAndExport(` + local client = {} + client.SERVICE = "bedrock" + function client.invoke(model_id: string, payload: any, options: {timeout: number?}?) 
+ return {ok = true}, nil + end + function client.converse(model_id: string, payload: any, options: {timeout: number?}?) + return {ok = true}, nil + end + return client + `, "bedrock_client", testutil.WithStdlib()) + if clientModule.HasError() { + t.Fatalf("provider errors: %v", testutil.ErrorMessages(clientModule.Errors)) + } + + source := ` + local bedrock_client = require("bedrock_client") + + local handler = { + _client = bedrock_client, + } + + local function helper(client, model_id, input, options) + local payload = { input = input } + local response, err = client.invoke(model_id, payload, { timeout = options and options.timeout }) + if err then + return nil, err + end + return response + end + + local model_id = "model" + local input = "text" + local options = {} + local result, err + result, err = helper(handler._client, model_id, input, options) + ` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("bedrock_client", clientModule)) + if result.HasError() { + t.Fatalf("expected table-stored imported module helper call to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/param_hint_depth_convergence_test.go b/compiler/check/tests/regression/param_hint_depth_convergence_test.go index 024fd043..68c96a8f 100644 --- a/compiler/check/tests/regression/param_hint_depth_convergence_test.go +++ b/compiler/check/tests/regression/param_hint_depth_convergence_test.go @@ -146,6 +146,55 @@ func TestParamHints_NestedWrapperFeedback_NoInterprocNonConvergenceWarning(t *te } } +func TestParamHints_OptionalContextTableFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { + code := ` + local function merge_context(base, additions) + local out = {} + if base then + for k, v in pairs(base) do + out[k] = v + end + end + if additions then + for k, v in pairs(additions) do + out[k] = v + end + end + return out + end + + local function call_func(func_id: string, data: any, context: 
{[string]: any}?) + return data, nil + end + + local function run(items) + local result = {} + for index, item in ipairs(items) do + local ctx = merge_context(nil, { + current_item = item, + item_index = index, + }) + result[index] = call_func("item", item, ctx) + end + call_func("done", result) + return result + end + + return run({ "a", "b" }) + ` + + result := testutil.Check(code, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + + for _, d := range result.Diagnostics { + if d.Severity == diag.SeverityWarning && strings.Contains(d.Message, "inter-function fixpoint did not converge") { + t.Fatalf("unexpected non-convergence warning: %v", d.Message) + } + } +} + func TestReturnSummary_RecursiveDeepCopy_NoInterprocNonConvergenceWarning(t *testing.T) { code := ` local function deep_copy_table(original) diff --git a/testdata/fixtures/regression/deadlock-compiler-lua/main.lua b/testdata/fixtures/regression/deadlock-compiler-lua/main.lua index bccab3a4..a0891f77 100644 --- a/testdata/fixtures/regression/deadlock-compiler-lua/main.lua +++ b/testdata/fixtures/regression/deadlock-compiler-lua/main.lua @@ -368,21 +368,21 @@ function FlowGraph:detect_cycles() local edges = (self.edges[node_id] :: any) if edges then - for _, edge in ipairs(edges.targets) do - if edge.target_node_id then - local has_cycle, cycle_desc = dfs(edge.target_node_id, path) - if has_cycle then - return true, cycle_desc - end - end - end - for _, edge in ipairs(edges.error_targets) do - if edge.target_node_id then - local has_cycle, cycle_desc = dfs(edge.target_node_id, path) - if has_cycle then - return true, cycle_desc - end - end + for _, edge in ipairs(edges.targets) do + if edge.target_node_id then + local has_cycle, cycle_desc = dfs(edge.target_node_id :: string, path) + if has_cycle then + return true, cycle_desc + end + end + end + for _, edge in ipairs(edges.error_targets) do + if edge.target_node_id 
then + local has_cycle, cycle_desc = dfs(edge.target_node_id :: string, path) + if has_cycle then + return true, cycle_desc + end + end end end diff --git a/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json b/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json index bca5acdf..845f4d03 100644 --- a/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json +++ b/testdata/fixtures/regression/deadlock-compiler-lua/manifest.json @@ -1,6 +1,6 @@ { - "description": "Large dynamic compiler fixture must terminate; dynamic any-to-string calls are reported soundly", + "description": "Large dynamic compiler fixture must terminate; proven dynamic calls should not report stale false positives", "check": { - "errors": 2 + "errors": 0 } } diff --git a/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json b/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json index 4aab2baa..0dd29174 100644 --- a/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json +++ b/testdata/fixtures/regression/deadlock-dataflow-node/manifest.json @@ -1,6 +1,6 @@ { - "description": "Large dynamic dataflow-node fixture must terminate; dynamic record-shape errors are reported soundly", + "description": "Large dynamic dataflow-node fixture must terminate without recursive param-hint over-specialization", "check": { - "errors": 4 + "errors": 0 } } diff --git a/types/flow/query.go b/types/flow/query.go index 85030863..9b1675c2 100644 --- a/types/flow/query.go +++ b/types/flow/query.go @@ -402,16 +402,18 @@ func (s *Solution) baseTypeAt(p cfg.Point, path constraint.Path) typ.Type { return explicit } - // Both available: prefer the narrower one - // If derived is falsy (nil/false), prefer explicit to avoid narrowing to never - if derived.Kind() == kind.Nil || isFalseLiteral(derived) { + // Direct child-path facts are the product-domain authority for that path. 
+ // Parent-derived facts are a fallback: they describe the container shape, but + // may be stale after a direct field/index assignment or table mutator. + if !explicit.Kind().IsPlaceholder() { return explicit } - // Prefer concrete explicit child-path facts over placeholder parent-derived facts. - if derived.Kind().IsPlaceholder() && !explicit.Kind().IsPlaceholder() { + // If derived is falsy (nil/false), prefer explicit to avoid narrowing to never + if derived.Kind() == kind.Nil || isFalseLiteral(derived) { return explicit } + if explicit.Kind().IsPlaceholder() && !derived.Kind().IsPlaceholder() { return derived } diff --git a/types/flow/solver_test.go b/types/flow/solver_test.go index 47a20710..749679bb 100644 --- a/types/flow/solver_test.go +++ b/types/flow/solver_test.go @@ -281,6 +281,112 @@ func TestMergeFieldAssignments_IncludesCanonicalStringIndexKeys(t *testing.T) { } } +func TestWidenWithIndexer_CoalescesPartialRecordElementUpdates(t *testing.T) { + full := typ.NewRecord(). + Field("type", typ.LiteralString("thinking")). + Field("thinking", typ.Any). + Field("signature", typ.Any). + Build() + thinkingOnly := typ.NewRecord().Field("thinking", typ.String).Build() + signatureOnly := typ.NewRecord().Field("signature", typ.String).Build() + + got := widenWithIndexer(typ.NewMap(typ.Integer, full), typ.Integer, thinkingOnly) + got = widenWithIndexer(got, typ.Integer, signatureOnly) + + mp, ok := got.(*typ.Map) + if !ok { + t.Fatalf("widenWithIndexer returned %T, want *typ.Map", got) + } + rec, ok := mp.Value.(*typ.Record) + if !ok { + t.Fatalf("map value = %T, want coalesced *typ.Record (%v)", mp.Value, mp.Value) + } + for _, name := range []string{"type", "thinking", "signature"} { + if rec.GetField(name) == nil { + t.Fatalf("coalesced record missing field %q: %v", name, rec) + } + } +} + +func TestWidenWithIndexer_KeepsConflictingRecordElementsDiscriminated(t *testing.T) { + thinking := typ.NewRecord(). + Field("type", typ.LiteralString("thinking")). 
+ Field("thinking", typ.String). + Build() + tool := typ.NewRecord(). + Field("type", typ.LiteralString("tool_use")). + Field("partial_json", typ.String). + Build() + + got := widenWithIndexer(typ.NewMap(typ.Integer, thinking), typ.Integer, tool) + + mp, ok := got.(*typ.Map) + if !ok { + t.Fatalf("widenWithIndexer returned %T, want *typ.Map", got) + } + if _, ok := mp.Value.(*typ.Union); !ok { + t.Fatalf("conflicting discriminant records should remain a union, got %T (%v)", mp.Value, mp.Value) + } +} + +func TestWidenWithIndexer_CoalescesRecordMapComponentValues(t *testing.T) { + full := typ.NewRecord(). + Field("type", typ.LiteralString("thinking")). + Field("thinking", typ.Any). + Field("signature", typ.Any). + Build() + partial := typ.NewRecord().Field("thinking", typ.String).Build() + base := typ.NewRecord(). + Field("count", typ.Integer). + MapComponent(typ.Integer, full). + Build() + + got := widenWithIndexer(base, typ.Integer, partial) + + rec, ok := got.(*typ.Record) + if !ok { + t.Fatalf("widenWithIndexer returned %T, want *typ.Record", got) + } + mapValue, ok := rec.MapValue.(*typ.Record) + if !ok { + t.Fatalf("record map value = %T, want coalesced *typ.Record (%v)", rec.MapValue, rec.MapValue) + } + for _, name := range []string{"type", "thinking", "signature"} { + if mapValue.GetField(name) == nil { + t.Fatalf("coalesced map value missing field %q: %v", name, mapValue) + } + } +} + +func TestWidenMapValueArray_CoalescesPartialRecordElements(t *testing.T) { + full := typ.NewRecord(). + Field("type", typ.LiteralString("thinking")). + Field("thinking", typ.Any). + Field("signature", typ.Any). 
+ Build() + partial := typ.NewRecord().Field("signature", typ.String).Build() + + got := WidenMapValueArray(typ.NewMap(typ.String, typ.NewArray(full)), typ.String, partial) + + mp, ok := got.(*typ.Map) + if !ok { + t.Fatalf("WidenMapValueArray returned %T, want *typ.Map", got) + } + arr, ok := mp.Value.(*typ.Array) + if !ok { + t.Fatalf("map value = %T, want *typ.Array (%v)", mp.Value, mp.Value) + } + elem, ok := arr.Element.(*typ.Record) + if !ok { + t.Fatalf("array element = %T, want coalesced *typ.Record (%v)", arr.Element, arr.Element) + } + for _, name := range []string{"type", "thinking", "signature"} { + if elem.GetField(name) == nil { + t.Fatalf("coalesced array element missing field %q: %v", name, elem) + } + } +} + func TestMergeFieldAssignments_IncludesEscapedStringIndexKey(t *testing.T) { s := &Solution{ values: map[string]typ.Type{ diff --git a/types/flow/transfer.go b/types/flow/transfer.go index ba5e6dcf..9d917d01 100644 --- a/types/flow/transfer.go +++ b/types/flow/transfer.go @@ -763,10 +763,7 @@ func (s *Solution) processIndexerAssignmentReturnKey(p cfg.Point, ia IndexerAssi if currentType == nil { currentType = s.joinPredecessorRootTypes(p, ia.Symbol) } - currentType = preferDeclaredTemplateForWiden(currentType, s.lookupDeclaredType(constraint.Path{ - Root: ia.Root, - Symbol: ia.Symbol, - })) + currentType = preferDeclaredTemplateForWiden(currentType, s.declaredTemplateForPath(iaPath)) // Compute the widened type newType := widenWithIndexer(currentType, keyType, valueType) @@ -893,10 +890,7 @@ func (s *Solution) processTableMutatorAssignmentReturnKey(p cfg.Point, tm TableM } currentType := s.values[string(pathKey)] - currentType = preferDeclaredTemplateForWiden(currentType, s.lookupDeclaredType(constraint.Path{ - Root: tm.Target.Root, - Symbol: tm.Target.Symbol, - })) + currentType = preferDeclaredTemplateForWiden(currentType, s.declaredTemplateForPath(tm.Target)) var newType typ.Type if tm.KeySymbol != 0 || tm.KeyType != nil { @@ -1166,7 +1160,7 
@@ func WidenMapValueArray(mapType typ.Type, keyType, elementType typ.Type) typ.Typ }, Map: func(m *typ.Map) typ.Type { newKey := mergeMapKeyDomain(m.Key, keyType) - newVal := WidenArrayElementType(m.Value, elementType, typ.JoinPreferNonSoft) + newVal := WidenArrayElementType(m.Value, elementType, joinContainerValueTypes) if newVal == nil { return mapType } @@ -1187,7 +1181,7 @@ func WidenMapValueArray(mapType typ.Type, keyType, elementType typ.Type) typ.Typ for _, m := range u.Members { if mp, ok := m.(*typ.Map); ok && !found { newKey := mergeMapKeyDomain(mp.Key, keyType) - newVal := WidenArrayElementType(mp.Value, elementType, typ.JoinPreferNonSoft) + newVal := WidenArrayElementType(mp.Value, elementType, joinContainerValueTypes) if newVal == nil { updated = append(updated, m) } else { @@ -1230,6 +1224,14 @@ func mergeMapKeyDomain(existing, incoming typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, incoming) } +func joinContainerValueTypes(existing, incoming typ.Type) typ.Type { + joined := typ.JoinPreferNonSoft(existing, incoming) + if union, ok := joined.(*typ.Union); ok { + return join.Types(union.Members...) 
+ } + return joined +} + func preferDeclaredTemplateForWiden(current, declared typ.Type) typ.Type { if declared == nil { return current @@ -1240,6 +1242,20 @@ func preferDeclaredTemplateForWiden(current, declared typ.Type) typ.Type { return current } +func (s *Solution) declaredTemplateForPath(path constraint.Path) typ.Type { + if s == nil || path.Symbol == 0 { + return nil + } + root := s.lookupDeclaredType(constraint.Path{Root: path.Root, Symbol: path.Symbol}) + if root == nil || len(path.Segments) == 0 { + return root + } + if t, ok := s.deriveTypeFrom(root, path.Segments); ok { + return t + } + return nil +} + func isEmptyRecordNoMapType(t typ.Type) bool { switch v := t.(type) { case *typ.Alias: @@ -1301,7 +1317,7 @@ func widenWithIndexer(t typ.Type, keyType, valType typ.Type) typ.Type { // Record with fields: add or widen map component if r.HasMapComponent() { newKey := mergeMapKeyDomain(r.MapKey, keyType) - newVal := typ.JoinPreferNonSoft(r.MapValue, valType) + newVal := joinContainerValueTypes(r.MapValue, valType) if typ.TypeEquals(r.MapKey, newKey) && typ.TypeEquals(r.MapValue, newVal) { return t } @@ -1313,7 +1329,7 @@ func widenWithIndexer(t typ.Type, keyType, valType typ.Type) typ.Type { Map: func(m *typ.Map) typ.Type { // Widen existing map by unioning key/value types, preferring non-soft. newKey := mergeMapKeyDomain(m.Key, keyType) - newVal := typ.JoinPreferNonSoft(m.Value, valType) + newVal := joinContainerValueTypes(m.Value, valType) if typ.TypeEquals(m.Key, newKey) && typ.TypeEquals(m.Value, newVal) { return t } diff --git a/types/io/manifest_lookup.go b/types/io/manifest_lookup.go index 21974a42..430b150a 100644 --- a/types/io/manifest_lookup.go +++ b/types/io/manifest_lookup.go @@ -1,6 +1,8 @@ package io -import "github.com/wippyai/go-lua/types/typ" +import ( + "github.com/wippyai/go-lua/types/typ" +) // LookupManifest resolves a manifest by path. 
// diff --git a/types/query/core/field.go b/types/query/core/field.go index 2e114d09..501f3800 100644 --- a/types/query/core/field.go +++ b/types/query/core/field.go @@ -228,6 +228,9 @@ func fieldInUnion(u *typ.Union, name string, depth int) (typ.Type, bool) { } if len(types) == 0 { + if missingFromSome { + return typ.Nil, true + } return nil, false } diff --git a/types/query/core/field_test.go b/types/query/core/field_test.go index 60545e2c..59cd4e10 100644 --- a/types/query/core/field_test.go +++ b/types/query/core/field_test.go @@ -83,6 +83,17 @@ func TestFieldUnion(t *testing.T) { } }) + t.Run("field missing from all record members", func(t *testing.T) { + got, ok := Field(union, "missing") + if !ok { + t.Error("expected missing record-union field to resolve as nil") + return + } + if !typ.TypeEquals(got, typ.Nil) { + t.Errorf("expected nil, got %v", got) + } + }) + t.Run("empty union members", func(t *testing.T) { emptyUnion := &typ.Union{Members: []typ.Type{}} diff --git a/types/query/core/index.go b/types/query/core/index.go index 0829f87e..3add56af 100644 --- a/types/query/core/index.go +++ b/types/query/core/index.go @@ -34,25 +34,27 @@ func indexDepth(t, keyType typ.Type, depth int) (typ.Type, bool) { if stopDepth(t, depth) { return nil, false } + if keyType == nil { + keyType = typ.Unknown + } if top, ok := specialAccessType(t); ok { return top, true } res := typ.Visit(t, typ.Visitor[indexResult]{ Array: func(a *typ.Array) indexResult { - if keyType != nil && isNumeric(keyType) { + if isNumeric(keyType) { if a.Element == nil { return indexResult{t: typ.Nil, ok: true} } return indexResult{t: a.Element, ok: true} } + if keyType.Kind().IsPlaceholder() && a.Element != nil { + return indexResult{t: typ.NewOptional(a.Element), ok: true} + } return indexResult{} }, Map: func(m *typ.Map) indexResult { - if keyType == nil { - return indexResult{} - } - if keyType.Kind().IsPlaceholder() { if m.Value == nil { return indexResult{} @@ -85,6 +87,9 @@ func 
indexDepth(t, keyType typ.Type, depth int) (typ.Type, bool) { if isNumeric(keyType) && len(tup.Elements) > 0 { return indexResult{t: typ.NewOptional(typ.NewUnion(tup.Elements...)), ok: true} } + if keyType != nil && keyType.Kind().IsPlaceholder() && len(tup.Elements) > 0 { + return indexResult{t: typ.NewOptional(typ.NewUnion(tup.Elements...)), ok: true} + } return indexResult{} }, diff --git a/types/query/core/index_test.go b/types/query/core/index_test.go index bbe95d91..0f11ba19 100644 --- a/types/query/core/index_test.go +++ b/types/query/core/index_test.go @@ -26,6 +26,12 @@ func TestIndex(t *testing.T) { {"nil type", nil, typ.Integer, false, nil}, {"array with integer key", arr, typ.Integer, true, func(t typ.Type) bool { return t == typ.String }}, {"array with number key", arr, typ.Number, true, func(t typ.Type) bool { return t == typ.String }}, + {"array with unknown key placeholder", arr, typ.Unknown, true, func(t typ.Type) bool { + return ContainsNil(t) + }}, + {"array with nil key type", arr, nil, true, func(t typ.Type) bool { + return ContainsNil(t) + }}, {"array with string key", arr, typ.String, false, nil}, {"map with matching key", m, typ.String, true, func(t typ.Type) bool { _, ok := t.(*typ.Optional) @@ -52,11 +58,20 @@ func TestIndex(t *testing.T) { {"tuple with generic integer", tuple, typ.Integer, true, func(t typ.Type) bool { return ContainsNil(t) }}, + {"tuple with unknown key placeholder", tuple, typ.Unknown, true, func(t typ.Type) bool { + return ContainsNil(t) + }}, + {"tuple with nil key type", tuple, nil, true, func(t typ.Type) bool { + return ContainsNil(t) + }}, {"empty tuple with integer", typ.NewTuple(), typ.Integer, false, nil}, {"record with string literal key", rec, typ.LiteralString("a"), true, func(t typ.Type) bool { return t == typ.String }}, {"record with generic string key", rec, typ.String, true, func(t typ.Type) bool { return ContainsNil(t) }}, + {"record with nil key type", rec, nil, true, func(t typ.Type) bool { + return 
ContainsNil(t) + }}, {"empty record with string", typ.NewRecord().Build(), typ.String, true, func(t typ.Type) bool { return t == typ.Nil }}, {"builtin table marker", typ.NewInterface("table", nil), typ.String, true, func(t typ.Type) bool { return t == typ.Unknown }}, {"any type", typ.Any, typ.String, true, func(t typ.Type) bool { return t == typ.Any }}, diff --git a/types/query/core/operator.go b/types/query/core/operator.go index 0aee7672..4ffcdccb 100644 --- a/types/query/core/operator.go +++ b/types/query/core/operator.go @@ -719,6 +719,9 @@ func unaryOpUnion(op string, u *typ.Union) typ.Type { // isNumeric returns true if the type represents a numeric value. // Includes number, integer, and numeric literals. func isNumeric(t typ.Type) bool { + if t == nil { + return false + } switch t.Kind() { case kind.Number, kind.Integer: return true diff --git a/types/subtype/subtype.go b/types/subtype/subtype.go index 993e878e..8d0c302a 100644 --- a/types/subtype/subtype.go +++ b/types/subtype/subtype.go @@ -509,6 +509,9 @@ func (c *checker) checkRecord(sub, super *typ.Record, depth int) bool { continue } + if sf.Optional && subField.Type != nil && subField.Type.Kind() == kind.Nil { + continue + } if sf.Readonly { // Readonly in super: covariant check is sound (no writes through supertype) @@ -602,21 +605,6 @@ func canWidenTo(narrow, wide typ.Type) bool { } } - // Allow widening into unions when narrow fits at least one member. - if u, ok := wide.(*typ.Union); ok { - for _, m := range u.Members { - // Keep literal-tag unions invariant for mutable fields; only allow - // widening through non-literal branch members (for example number|string). - if m.Kind() == kind.Literal { - continue - } - if isSubtype(narrow, m) || canWidenTo(narrow, m) { - return true - } - } - return false - } - // Literal unions can widen to a primitive supertype when each branch widens. // Example: 0|8000 can widen to integer for mutable record fields. 
if u, ok := narrow.(*typ.Union); ok { @@ -632,6 +620,21 @@ func canWidenTo(narrow, wide typ.Type) bool { return true } + // Allow widening into unions when narrow fits at least one member. + if u, ok := wide.(*typ.Union); ok { + for _, m := range u.Members { + // Keep literal-tag unions invariant for mutable fields; only allow + // widening through non-literal branch members (for example number|string). + if m.Kind() == kind.Literal { + continue + } + if isSubtype(narrow, m) || canWidenTo(narrow, m) { + return true + } + } + return false + } + // Integer can widen to number if narrow.Kind() == kind.Integer && wide.Kind() == kind.Number { return true diff --git a/types/subtype/subtype_test.go b/types/subtype/subtype_test.go index e3541a88..132259a5 100644 --- a/types/subtype/subtype_test.go +++ b/types/subtype/subtype_test.go @@ -1373,6 +1373,15 @@ func TestRecordOptionalFields(t *testing.T) { } } +func TestRecordNilFieldSatisfiesOptionalField(t *testing.T) { + sub := typ.NewRecord().Field("headers", typ.Nil).Build() + super := typ.NewRecord().OptField("headers", typ.NewMap(typ.String, typ.String)).Build() + + if !IsSubtype(sub, super) { + t.Error("nil field represents Lua absence and should satisfy optional field") + } +} + func TestRecordFieldTypeSubtype(t *testing.T) { // {value: integer} <: {value: number} rec1 := typ.NewRecord().Field("value", typ.Integer).Build() @@ -2015,6 +2024,21 @@ func TestRecordMutableFieldWidening_LiteralBool(t *testing.T) { } } +func TestRecordMutableFieldWidening_OptionalLiteralUnionToPrimitiveUnion(t *testing.T) { + sub := typ.NewRecord(). + SetOpen(true). + OptField("data_func", typ.NewUnion(typ.String, typ.False)). + Build() + super := typ.NewRecord(). + SetOpen(true). + OptField("data_func", typ.NewUnion(typ.String, typ.Boolean)). 
+ Build() + + if !IsSubtype(sub, super) { + t.Fatalf("expected %s to be subtype of %s", typ.FormatShort(sub), typ.FormatShort(super)) + } +} + func TestRecordMutableFieldWidening_NilToOptional(t *testing.T) { sub := typ.NewRecord().Field("x", typ.Nil).Build() super := typ.NewRecord().Field("x", typ.NewOptional(typ.String)).Build() diff --git a/types/typ/container.go b/types/typ/container.go index 2cf09df2..a2970986 100644 --- a/types/typ/container.go +++ b/types/typ/container.go @@ -60,6 +60,7 @@ func NewMap(key, value Type) *Map { if key == nil { key = Unknown } + key = NormalizeTableKey(key) if value == nil { value = Unknown } diff --git a/types/typ/container_test.go b/types/typ/container_test.go index 843698b6..c51538f0 100644 --- a/types/typ/container_test.go +++ b/types/typ/container_test.go @@ -92,6 +92,13 @@ func TestMapNilKeyValueDefaultsToUnknown(t *testing.T) { } } +func TestMapKeyRemovesImpossibleNil(t *testing.T) { + m := NewMap(NewOptional(String), Number) + if !TypeEquals(m.Key, String) { + t.Fatalf("map key = %v, want string", m.Key) + } +} + func TestMapEquality(t *testing.T) { m1 := NewMap(String, Number) m2 := NewMap(String, Number) diff --git a/types/typ/policy.go b/types/typ/policy.go index 884eb811..ad300080 100644 --- a/types/typ/policy.go +++ b/types/typ/policy.go @@ -59,7 +59,7 @@ func JoinReturnSlot(a, b Type) Type { if (IsAny(a) && b.Kind() == kind.Nil) || (IsAny(b) && a.Kind() == kind.Nil) { return Any } - if (IsUnknown(a) && b.Kind() == kind.Nil) || (IsUnknown(b) && a.Kind() == kind.Nil) { + if IsUnknown(a) || IsUnknown(b) { return Unknown } return coalesceCompatibleRecordMembers(JoinPreferNonSoft(a, b)) @@ -167,10 +167,18 @@ func JoinCompatibleRecords(a, b Type) (Type, bool) { fieldType = fa.Type optional = true readonly = fa.Readonly + if tail, ok := recordTailFieldType(br, name); ok { + fieldType, optional = normalizeMergedRecordField(JoinReturnSlot(fa.Type, tail)) + readonly = false + } case okb: fieldType = fb.Type optional = true 
readonly = fb.Readonly + if tail, ok := recordTailFieldType(ar, name); ok { + fieldType, optional = normalizeMergedRecordField(JoinReturnSlot(tail, fb.Type)) + readonly = false + } } switch { @@ -188,6 +196,50 @@ func JoinCompatibleRecords(a, b Type) (Type, bool) { return builder.Build(), true } +func normalizeMergedRecordField(t Type) (Type, bool) { + if inner, optional := SplitNilableFieldType(t); optional { + return inner, true + } + return t, false +} + +func recordTailFieldType(r *Record, name string) (Type, bool) { + if r == nil { + return nil, false + } + if r.HasMapComponent() && mapComponentMayContainStringKey(r.MapKey, name) { + return NewOptional(r.MapValue), true + } + if r.Open { + return Unknown, true + } + return nil, false +} + +func mapComponentMayContainStringKey(key Type, name string) bool { + if key == nil { + return false + } + if IsAny(key) || IsUnknown(key) { + return true + } + switch k := key.(type) { + case *Alias: + return mapComponentMayContainStringKey(k.Target, name) + case *Union: + for _, member := range k.Members { + if mapComponentMayContainStringKey(member, name) { + return true + } + } + return false + case *Literal: + return k.Base == kind.String && k.Value == name + default: + return k.Kind() == kind.String + } +} + func unaliasRecord(t Type) *Record { for { a, ok := t.(*Alias) @@ -321,6 +373,15 @@ func JoinBranchOutcome(a, b Type) Type { if TypeEquals(a, b) { return a } + if IsAny(a) || IsAny(b) { + return Any + } + if IsUnknown(a) && b.Kind() != kind.Nil { + return Unknown + } + if IsUnknown(b) && a.Kind() != kind.Nil { + return Unknown + } return NewUnion(a, b) } diff --git a/types/typ/policy_test.go b/types/typ/policy_test.go index 7c9ca1a8..01d803ed 100644 --- a/types/typ/policy_test.go +++ b/types/typ/policy_test.go @@ -11,6 +11,16 @@ func TestJoinReturnSlot_PreservesUnknownOverNil(t *testing.T) { } } +func TestJoinReturnSlot_PreservesUnknownOverConcrete(t *testing.T) { + rec := NewRecord().Field("value", String).Build() 
+ if got := JoinReturnSlot(Unknown, rec); !TypeEquals(got, Unknown) { + t.Fatalf("JoinReturnSlot(unknown, record) = %v, want unknown", got) + } + if got := JoinReturnSlot(rec, Unknown); !TypeEquals(got, Unknown) { + t.Fatalf("JoinReturnSlot(record, unknown) = %v, want unknown", got) + } +} + func TestJoinReturnSlot_PreservesAnyOverNil(t *testing.T) { if got := JoinReturnSlot(Any, Nil); !TypeEquals(got, Any) { t.Fatalf("JoinReturnSlot(any, nil) = %v, want any", got) @@ -46,6 +56,15 @@ func TestJoinBranchOutcome_PreservesUnknownWithNil(t *testing.T) { } } +func TestJoinBranchOutcome_PreservesUnknownOverConcrete(t *testing.T) { + if got := JoinBranchOutcome(Unknown, String); !TypeEquals(got, Unknown) { + t.Fatalf("JoinBranchOutcome(unknown, string) = %v, want unknown", got) + } + if got := JoinBranchOutcome(String, Unknown); !TypeEquals(got, Unknown) { + t.Fatalf("JoinBranchOutcome(string, unknown) = %v, want unknown", got) + } +} + func TestJoinBranchOutcome_PreservesSoftRuntimeAlternative(t *testing.T) { left := NewOptional(NewArray(Any)) right := NewArray(Number) @@ -99,6 +118,46 @@ func TestJoinReturnSlot_MergesRecordFieldsAsOptional(t *testing.T) { } } +func TestJoinReturnSlot_MergesMissingOpenRecordFieldWithUnknownTail(t *testing.T) { + candidate := NewTuple(NewRecord(). + Field("content", NewRecord(). + Field("parts", NewTuple(NewRecord().Field("text", String).Build())). + Build()). + Build()) + stream := NewRecord(). + Field("candidates", candidate). + Field("status_code", Number). + Build() + decoded := NewRecord(). + Field("metadata", NewRecord().Build()). + Field("status_code", Number). + SetOpen(true). 
+ Build() + + got := JoinReturnSlot(stream, decoded) + rec, ok := got.(*Record) + if !ok { + t.Fatalf("JoinReturnSlot(stream, decoded) = %T, want *Record", got) + } + field := rec.GetField("candidates") + if field == nil { + t.Fatalf("merged record lost candidates field: %v", rec) + } + if field.Optional { + t.Fatalf("candidates should merge with the open row tail, not absence: %#v", field) + } + if !TypeEquals(field.Type, Unknown) { + t.Fatalf("candidates = %v, want unknown from open row tail", field.Type) + } +} + +func TestRecordMapKeyRemovesImpossibleNil(t *testing.T) { + rec := NewRecord().MapComponent(NewOptional(String), Number).Build() + if !TypeEquals(rec.MapKey, String) { + t.Fatalf("record map key = %v, want string", rec.MapKey) + } +} + func TestJoinReturnSlot_PreservesDiscriminatedRecordUnion(t *testing.T) { a := NewRecord(). Field("kind", LiteralString("a")). diff --git a/types/typ/rebuild.go b/types/typ/rebuild.go index fa068357..200e0116 100644 --- a/types/typ/rebuild.go +++ b/types/typ/rebuild.go @@ -85,6 +85,9 @@ func buildRecordType(fields []Field, metatable, mapKey, mapValue Type, open bool if mapKey == nil && mapValue != nil { mapKey = Unknown } + if mapKey != nil { + mapKey = NormalizeTableKey(mapKey) + } if mapValue == nil && mapKey != nil { mapValue = Unknown } diff --git a/types/typ/soft.go b/types/typ/soft.go index 3a9bb4b0..32e5ad86 100644 --- a/types/typ/soft.go +++ b/types/typ/soft.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/wippyai/go-lua/internal" + "github.com/wippyai/go-lua/types/kind" ) // SoftPolicy controls how soft-placeholder detection behaves. 
@@ -50,6 +51,9 @@ func isSoft(t Type, guard internal.RecursionGuard, policy SoftPolicy) bool { case *Map: return isSoft(tt.Value, next, policy) case *Record: + if tt.Open && len(tt.Fields) == 0 && !tt.HasMapComponent() { + return true + } if len(tt.Fields) == 0 && !tt.HasMapComponent() { return policy.AllowEmptyRecord } @@ -227,13 +231,19 @@ func pruneSoftUnionMembersMemo( members = rewritten } nonSoftMembers := make([]Type, 0, len(node.Members)-softCount) + hasNonNilConcreteMember := false for _, member := range members { if !isSoftWithMemo(member, SoftPlaceholderPolicy, softMemo) { nonSoftMembers = append(nonSoftMembers, member) + if member != nil && member.Kind() != kind.Nil { + hasNonNilConcreteMember = true + } } } - out = NewUnion(nonSoftMembers...) - break + if hasNonNilConcreteMember { + out = NewUnion(nonSoftMembers...) + break + } } if !changed { out = t diff --git a/types/typ/soft_test.go b/types/typ/soft_test.go index 8ccfa231..852a70e5 100644 --- a/types/typ/soft_test.go +++ b/types/typ/soft_test.go @@ -66,6 +66,7 @@ func TestPruneSoftUnionMembers(t *testing.T) { {"drop soft array", NewUnion(softArray, entryArray), entryArray}, {"drop empty record", NewUnion(emptyRecord, entryArray), entryArray}, {"all soft stays", NewUnion(Any, softArray), Any}, + {"nil does not erase optional soft table shape", NewUnion(Nil, softArray, NewRecord().SetOpen(true).Build()), NewUnion(Nil, softArray, NewRecord().SetOpen(true).Build())}, } for _, tt := range tests { @@ -137,6 +138,8 @@ func TestIsRefinableAnnotation(t *testing.T) { {"unknown", Unknown, false}, {"optional any", NewOptional(Any), false}, {"array any", NewArray(Any), true}, + {"open table top", NewRecord().SetOpen(true).Build(), true}, + {"array or open table top", NewUnion(NewArray(Any), NewRecord().SetOpen(true).Build()), true}, {"record map any", NewRecord().MapComponent(String, Any).Build(), true}, {"record", NewRecord().Field("id", String).Build(), false}, } diff --git a/types/typ/table_key.go 
b/types/typ/table_key.go new file mode 100644 index 00000000..67b2bc10 --- /dev/null +++ b/types/typ/table_key.go @@ -0,0 +1,59 @@ +package typ + +// NormalizeTableKey removes impossible nil alternatives from Lua table key +// domains. Nil from iterators is the termination sentinel; it is never an +// inhabited table key. +func NormalizeTableKey(t Type) Type { + if t == nil { + return Unknown + } + switch v := t.(type) { + case *Annotated: + inner := NormalizeTableKey(v.Inner) + if inner == v.Inner { + return t + } + return NewAnnotated(inner, v.Annotations) + case *Alias: + inner := NormalizeTableKey(v.Target) + if inner == nil || IsNever(inner) { + return inner + } + if inner == v.Target { + return t + } + return NewAlias(v.Name, inner) + case *Optional: + return NormalizeTableKey(v.Inner) + case *Union: + members := make([]Type, 0, len(v.Members)) + changed := false + for _, member := range v.Members { + if member == nil || member.Kind() == Nil.Kind() { + changed = true + continue + } + normalized := NormalizeTableKey(member) + if normalized == nil || IsNever(normalized) { + changed = true + continue + } + if normalized != member { + changed = true + } + members = append(members, normalized) + } + if len(members) == 0 { + return Never + } + if !changed { + return t + } + return NewUnion(members...) 
+ default: + if t.Kind() == Nil.Kind() { + return Never + } + return t + } +} From 51d9a72ed15c2eee12f52545719bddd8519ca0ea Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:44:26 -0400 Subject: [PATCH 03/71] Document checker domain target --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 691 ++++++++++++++++++++++++++++++ 1 file changed, 691 insertions(+) create mode 100644 INTERPROC_FACTS_DOMAIN_JOURNAL.md diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md new file mode 100644 index 00000000..03b8da38 --- /dev/null +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -0,0 +1,691 @@ +# Interproc Facts And Checker Domain Design Journal + +## 2026-05-19 Design Consolidation Checkpoint + +This document records the design model before the next implementation pass. It is +not an implementation plan for incremental bridges. The intended correction is a +flash migration: design the final shape, migrate directly to it, delete the old +helper clusters, and do not leave compatibility wrappers or fallback layers in +the production checker. + +## Goal + +The checker should read as one abstract interpreter over a product domain. + +The current implementation is already powerful: + +- it tracks flow-sensitive path facts, +- it narrows through guards and assertions, +- it propagates table and container mutations, +- it infers local and interprocedural function facts, +- it correlates value/error return slots, +- it handles soft annotation evidence, +- it uses Salsa-style query inputs for function-result invalidation. + +The design problem is that these capabilities are encoded by many local helper +clusters. That makes the system hard to reason about even when the behavior is +mostly correct. Helpers such as `typeRefinesTableKeyByTruthiness` are not just +helpers; they are domain laws living in the wrong place. + +The target is a smaller, clearer checker where each law has exactly one owner. 
+ +## Non-Negotiable Constraints + +- No production transition layer. +- No legacy mirror fact channels. +- No raising iteration caps to hide non-convergence. +- No external application-code edits as part of go-lua design correction. +- No weakening soundness by making `any` assignable to concrete contracts. +- No helper-specific exceptions for external lint targets. +- No pools as the first answer to performance; use structural ownership and + caching first. +- Every final abstraction must have law tests and paired positive/negative + behavioral tests. + +## Current Mental Model + +The checker is a multi-phase abstract interpreter: + +1. Scope and CFG construction establish symbols, lexical parents, control-flow + points, and function graph identity. +2. Declared-phase synthesis extracts initial types, table literal shapes, + function literal signatures, and call/effect evidence. +3. `flowbuild` lowers AST and synthesis facts into flow inputs: + declarations, assignments, table/index mutations, call effects, branch + predicates, return constraints, numeric constraints, aliases, and termination + facts. +4. `types/flow` solves a forward dataflow problem over canonical SSA path keys. + The persistent solved state is currently split across value maps, conditions, + numeric states, alias maps, field overlays, and local caches. +5. Narrowing queries are demand-side interpretation: read solved facts at a + point, apply propagated constraints, and answer refined path/type questions. +6. Return inference and local function SCC solving use the flow result plus + interprocedural snapshots to infer return vectors, parameter hints, function + facts, captured fields, and captured container mutations. +7. The interprocedural store combines same-iteration deltas with a precise join + and combines recursive fixpoint boundaries with widening. +8. Salsa-style snapshot inputs connect function-result queries to exact + interproc facts, refinements, and constructor-field snapshots. 
+ +This is the right high-level shape. The weakness is that the product domain is +not first-class enough in code. + +## Clean Abstract Interpreter Target + +The final checker should be explainable as: + +```text +AbstractInterpreter = CFG + AbstractState + Transfer + Join + Widen + Query +``` + +Where: + +- `CFG` owns control-flow order and dominance. +- `AbstractState` owns the full product of memory, value, numeric, relation, + effect, and termination facts. +- `Transfer` is the only way statements and expressions change state. +- `Join` is the only way same-phase branch/predecessor evidence combines. +- `Widen` is the only way recursive or interprocedural cycles are forced to + converge. +- `Query` reads solved state without inventing another analysis path. + +This is the mental model the code should expose. If a rule cannot be explained +as one of these operations, it is either orchestration or a design smell. + +The current checker has the right ingredients but not the right ownership. It +has preflow inference, flow solving, narrowing queries, return SCC inference, +overlay refresh, mutation replay, and interproc widening. Those should become +clients of the same abstract-state and domain APIs. They should not remain +separate places where local helpers decide what refinement means. + +## State-Of-The-Art Bar + +The target is not just cleaner Go packages. 
The target is a modern static +analysis engine with explicit theory: + +- monotone abstract domains with named `Normalize`, `Leq`, `Join`, `Meet`, and + `Widen` operations; +- transfer functions over a product state instead of helper-specific rewrites; +- a first-class memory model for paths, fields, indexes, aliases, mutations, + row tails, and dominance; +- relational facts for tuple slots and path correlations instead of hardcoded + error-return branches; +- principled distinction between `unknown`, `any`, `nil`, absent fields, soft + evidence, hard evidence, table top, and open row tails; +- explicit widening at recursive boundaries and optional narrowing only after a + post-fixpoint is reached; +- deterministic canonicalization and equality, never equality-time repair; +- cache keys derived from immutable inputs and domain snapshots, not incidental + phase call order; +- paired positive/negative law tests so the implementation cannot get faster by + becoming less sound. + +Anything less will keep producing local helper patches. The migration should +make the checker look like the theory it is implementing. + +## Foundational Diagnosis + +The checker has accumulated strong behavior before it acquired the right +vocabulary. + +Current scattered concepts: + +- value-type joins, +- return-slot joins, +- function-param fact joins, +- param-hint joins, +- table-top absorption, +- soft-placeholder replacement, +- open-record row-tail merging, +- recursive structural-growth cutoffs, +- truthiness refinements, +- error-return tuple correlations, +- captured table/container mutation replay, +- path identity and alias identity, +- body-derived parameter contracts, +- call-site observations, +- signature projection to body use. + +These concepts are real. The problem is not that they exist. The problem is that +they appear as local helpers in `returns`, `paramhints`, `flow`, `synth`, and +`typ`, with overlapping responsibilities. 
+ +That creates the "guacamole" feeling: behavior is strong, but the mental model +is not visible at the package boundary. + +## Canonical Product Domain + +The final checker should have these explicit products. + +### Value Domain + +Owns pure type operations that are independent of checker phase: + +- `NormalizeType` +- `JoinValue` +- `JoinReturnSlot` +- `Meet` +- `WidenShape` +- `Refines` +- `TruthinessRefinement` +- `Nilability` +- `SoftEvidence` +- open/closed record row-tail policy +- map/array/table-top classification + +Candidate home: + +```text +types/typ/domain +``` + +or, if it needs checker-only evidence policy: + +```text +compiler/check/domain/value +``` + +Rule: domain-level predicates such as "candidate refines baseline by removing a +falsy table key" cannot live in `compiler/check/returns/join.go`. + +### Memory And Path Domain + +Owns the question "what program location does this fact describe?" + +It must unify: + +- `constraint.Path` +- CFG symbol/version identity +- SSA path keys +- field/index segments +- aliases +- dynamic index writes +- table mutator paths +- captured mutation paths +- field overlays + +Candidate home: + +```text +compiler/check/memory +``` + +The public model should be: + +```go +type Location struct { ... } +type MemoryState struct { ... } +type Mutation struct { Kind, Target, Key, Value, Dominance } +``` + +Current scattered path helpers should collapse into this package. The solver +should not need to know whether a fact came from a table literal, field write, +alias replay, or captured mutation to apply the same path-law rules. + +### Flow State Domain + +Owns the persistent state of intraprocedural analysis. + +The final `AbstractState` should be a product: + +- memory facts, +- numeric facts, +- shape/presence facts, +- relation facts, +- termination facts, +- effect facts. + +Candidate home: + +```text +compiler/check/flowstate +``` + +or inside `types/flow` if it remains independent of checker-specific APIs. 
+ +Current weakness: + +`types/flow.ProductDomain` is the closest modern abstraction, but it is mostly +used transiently during narrowing queries. The main solver still stores raw +maps and side caches. That split should disappear. Query-time narrowing should +read from the same abstract state product that transfer functions update. + +### Relation Domain + +Owns facts that connect multiple paths or tuple slots: + +- error-return `(value, err)` correlation, +- sibling return-slot narrowing, +- predicate links, +- assertion links, +- type-test links, +- tuple-slot relation facts, +- custom error records. + +Candidate home: + +```text +compiler/check/domain/relation +``` + +This is where error-return convention should live. It should not be encoded as +scattered checks for exactly two return slots at call sites. The canonical +shape is a relation: + +```go +type TupleRelation struct { + Slots []SlotPredicate +} +``` + +The current `(value, err)` convention is then one predefined relation, not a +special checker behavior. + +### Function Fact Domain + +Owns all interprocedural facts about functions. + +The stored authority remains: + +```go +type FunctionFact struct { + Summary []typ.Type + Narrow []typ.Type + Type typ.Type +} +``` + +But its operations should move out of `returns`: + +```text +compiler/check/domain/functionfact +``` + +It owns: + +- same-shape function fact merge, +- param-slot merge, +- return-vector merge delegation, +- effect/spec/refinement merge, +- function fact widening, +- function fact normalization, +- function fact equality. + +The param-slot policy must be a named domain object, not scattered helpers: + +```go +type ParamSlotDomain struct { + Mode MergeMode // precise join or convergence widening +} +``` + +The current `candidateRefinesFunctionParam`, `typeRefinesTableKeyByTruthiness`, +`preferConcreteOverSoftType`, and related functions become methods or private +support functions of this domain. 
+ +### Return Summary Domain + +Owns return-vector shape and convergence: + +- arity normalization, +- nil-slot handling, +- `unknown` as unresolved runtime behavior, +- stale nil-only regression prevention, +- recursive structural-growth cutoff, +- concrete-over-soft container refinement, +- return-slot row-tail merging, +- function-return widening. + +Candidate home: + +```text +compiler/check/domain/returnsummary +``` + +The existing `returns` package can either become this package or stop owning +non-return policy. + +### Parameter Evidence Domain + +Owns all evidence about parameters: + +- call-site observations, +- body-derived contracts, +- signature facts, +- param-use projection, +- soft annotations, +- table-top absorption, +- nilability splitting, +- map/record joins, +- call graph propagation. + +Candidate home: + +```text +compiler/check/domain/paramevidence +``` + +Current split: + +- some policy lives in `compiler/check/infer/paramhints`, +- some lives in `compiler/check/returns/widen.go`, +- some lives in return SCC inference, +- some lives in interproc postflow. + +Final rule: + +Orchestration may stay in inference packages, but merge/canonicalization policy +belongs to the parameter evidence domain. + +### Interproc Fact Domain + +Owns the whole product: + +```go +type FactsDomain struct { + FunctionFacts FunctionFactDomain + ParamEvidence ParamEvidenceDomain + LiteralSigs LiteralSignatureDomain + Captures CaptureDomain + Constructors ConstructorDomain +} +``` + +Candidate home: + +```text +compiler/check/domain/interproc +``` + +It exposes only: + +```go +Normalize(facts) +Leq(a, b) +Join(a, b) +Widen(prev, next) +Equal(a, b) +``` + +The store calls this domain. Producers emit deltas. Producers do not call local +helper joins directly. 
+ +## Helper Cluster Ownership + +| Current Cluster | Current Location | Final Owner | +|---|---|---| +| `JoinFacts`, `WidenFacts`, fact equality | `compiler/check/returns` | `domain/interproc` | +| function fact type merge | `compiler/check/returns/join.go` | `domain/functionfact` | +| function param-slot refinement | `compiler/check/returns/join.go`, `widen.go` | `domain/functionfact.ParamSlotDomain` | +| return-vector merge/repair | `compiler/check/returns/join.go` | `domain/returnsummary` | +| table-top absorption | `infer/paramhints`, `returns/widen.go` | `domain/paramevidence` plus value-domain classifier | +| soft vs concrete evidence | `typ/soft.go`, `returns/widen.go`, return overlay | `domain/value` evidence policy | +| open-record row-tail merge | `types/typ/policy.go` | `domain/value` row-shape policy | +| path/query/alias identity | `constraint`, `flowbuild/path`, `flow/pathkey` | `memory` | +| table/container mutation replay | `nested`, `returns`, `flowbuild`, `flow` | `memory` mutation domain | +| error-return convention | `erreffect`, call/return inference | `domain/relation` | +| body parameter contracts | `infer/return`, `flowbuild/assign` | `domain/paramevidence` | +| Salsa snapshot inputs | `store/snapshot_inputs.go` | keep in store, but document as cache boundary | + +## Target Data Flow + +The final flow should be: + +```text +source + -> CFG + symbol graph + -> normalized checker IR + -> abstract transfer over AbstractState + -> queryable solved state + -> function result + -> interproc fact delta + -> FactsDomain.Join or FactsDomain.Widen + -> Salsa input update + -> dependent function-result query revalidation +``` + +Every arrow has one owner. + +No phase should secretly perform another local abstract interpretation unless +that interpretation is a named domain transfer over the same `AbstractState`. + +Preflow, local SCC inference, and return overlay currently exist for good +reasons. The design target is not to delete their semantics. 
The design target +is to make them clients of the same domain objects instead of separate local +machines. + +## Salsa And Cache Model + +Current good shape: + +- function-result keys are stable graph/parent identities, +- interproc snapshots are `db.Input`s, +- updating facts bumps dependent queries through the database, +- core type queries are Salsa-style pure queries. + +Current weak shape: + +- the checker still has several non-Salsa local caches with implicit lifetimes, +- flow solution caches are manually invalidated, +- some expensive shape scans are repeated because domain operations are not + centralized, +- param-use projection can rescan AST bodies instead of reading a graph-indexed + use summary. + +Final cache contracts: + +1. Source inputs are `db.Input`s: + - manifests, + - parent scope, + - CFG identity, + - interproc facts, + - constructor fields, + - function refinements. +2. Pure expensive computations are `db.Query`s: + - core type lookup/index/method/operator queries, + - function result, + - parameter-use summary by graph/function, + - shape classification for large recursive types if profiling confirms it. +3. Intraprocedural flow state remains per-function and ephemeral unless it is + keyed by the exact immutable input bundle. Do not put hot per-edge transfer + into Salsa if dependency recording costs more than recomputation. +4. Domain operations must be pure and deterministic so they can be memoized + safely when profiling justifies it. +5. Cache lifetime must be explicit in package docs. No cache should depend on + call order for correctness. + +Performance target: + +- fewer repeated shape scans, +- fewer temporary maps in hot merges, +- copy-on-write vectors and maps, +- immutable fact snapshots, +- stable interning/hash-consing where already available, +- no object pools until ownership is proved and structural wins are exhausted. + +## Weak Points To Fix In The Design + +### 1. 
Domain Laws Are Not Named + +The checker has laws such as: + +- hard evidence dominates soft evidence, +- `unknown` in return summaries is unresolved runtime behavior, +- open record absent field means row-tail, not nil, +- nil field can satisfy optional absence in record subtyping, +- table-top can absorb precise table evidence in parameter hints, +- truthy refinement can remove falsy key alternatives. + +Today many of these appear as function names buried in unrelated packages. They +must become named laws of specific domains. + +### 2. Too Many Local Abstract Interpreters + +`flowbuild`, `types/flow`, return SCC inference, preflow synthesis, return +overlay refresh, condition extraction, and interproc widening each perform part +of the abstract interpretation. + +The final design should have one abstract state model and several orchestration +phases. The orchestration may be complex; the lattice rules cannot be local. + +### 3. Memory Is Not First-Class Enough + +Field writes, table inserts, dynamic indexes, aliases, captured mutations, and +path queries all affect the same memory model. They are currently split across +multiple packages. + +This causes bugs where: + +- parent-derived structure outranks explicit child-path facts, +- captured table inserts replay through the wrong mutator kind, +- alias identity and dominance are checked locally, +- nil overwrite and optional absence need separate fixes. + +The final memory domain must own these rules. + +### 4. Parameter Evidence Has Multiple Authorities + +Parameter evidence currently comes from: + +- call sites, +- body contracts, +- function facts, +- literal signatures, +- soft source annotations, +- param-use projection. + +The final design needs one `ParameterEvidence` lattice with evidence provenance +and merge mode. The implementation should not need separate helpers for +"param hints" and "function param facts" that rediscover the same truthiness, +softness, and table-key laws. + +### 5. 
Relation Facts Are Under-Modeled + +The system supports powerful correlations, especially error-return behavior, but +the relation model is still too tied to known patterns. + +The final design should model tuple/path relations directly. `(value, err)` is +then one relation instance. This keeps the system extensible without hardcoded +branch helpers or return-slot checks. + +### 6. Tests Are Too Positive-Heavy + +Many external-lint regressions are "this must type-check" tests. Those are +useful, but insufficient. They can pass through accidental broadening. + +Every major law needs: + +- a positive test proving wanted inference, +- a negative test proving sound rejection, +- a domain law test proving normalize/join/widen idempotence and monotonicity. + +## Flash Migration Shape + +The implementation should be prepared privately but merged as a direct final +shape. The production branch should not pass through partial API compatibility. + +Flash migration means: + +1. Introduce final domain packages. +2. Move domain laws into those packages. +3. Replace all call sites in one migration. +4. Delete old helper clusters in the same migration. +5. Delete obsolete tests that asserted old helper behavior. +6. Add law-oriented tests for the new domain boundaries. +7. Run the global replay and classify remaining diagnostics. + +No step should leave: + +- old helper path plus new helper path, +- adapter projections like "legacy view from canonical facts", +- duplicate merge functions for the same semantic slot, +- fallback normalization in equality, +- broad `any` acceptance to clear lints. 
+ +## Proposed Final Package Map + +```text +compiler/check/domain/interproc +compiler/check/domain/functionfact +compiler/check/domain/returnsummary +compiler/check/domain/paramevidence +compiler/check/domain/relation +compiler/check/memory +compiler/check/flowstate +``` + +Existing packages remain as orchestration: + +```text +compiler/check/flowbuild +compiler/check/synth +compiler/check/infer/return +compiler/check/infer/interproc +compiler/check/store +compiler/check/pipeline +``` + +Low-level pure type mechanics remain under: + +```text +types/typ +types/subtype +types/query/core +types/db +``` + +The key rule: + +Orchestration packages may decide when a fact is produced. Domain packages +decide what that fact means and how it combines. + +## Verification Model For The Future Migration + +Required proof after the flash migration: + +```text +go test ./... +git diff --check +../scripts/verify-suite.sh +``` + +Required domain law tests: + +- `Normalize(Normalize(x)) == Normalize(x)` +- `Join(a, b) == Join(b, a)` where the domain is intended to be commutative +- `Join(Join(a, b), c) == Join(a, Join(b, c))` where applicable +- `Widen(Widen(a, b), b) == Widen(a, b)` +- `a <= Join(a, b)` +- `a <= Widen(a, b)` +- derived function type equals canonical function fact projection +- no equality-time normalization bridge + +Required behavior suites: + +- soft vs hard evidence, +- any vs unknown, +- nil vs absent, +- open vs closed records, +- table top vs precise table shapes, +- captured table/container mutations, +- alias and dominance, +- error-return tuple relations, +- local SCC parameter evidence, +- interproc non-convergence fixtures, +- external replay reductions. + +## Current Conclusion + +The checker is not fundamentally the wrong idea. It is closer to a serious +abstract interpreter than it looks from isolated helper functions. + +The foundational problem is organizational: the product domain exists in +behavior but not cleanly enough in code.
The next design correction should not +add more local helpers. It should move the existing laws into explicit domain +objects, make memory/path identity first-class, and make Salsa/cache boundaries +documented and deliberate. + +If this is done as a flash migration, the codebase should become smaller because +many helper clusters collapse into a few named domains. It should also become +easier to reason about because every merge/refinement/widening decision will +have one owner and one law-test suite. From 728c8ec720663b58de862fb3f3b4f475cb8042e7 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:46:32 -0400 Subject: [PATCH 04/71] Refine abstract interpreter design --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 371 ++++++++++++++++++++++++++++++ 1 file changed, 371 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 03b8da38..9ed91567 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -123,6 +123,157 @@ analysis engine with explicit theory: Anything less will keep producing local helper patches. The migration should make the checker look like the theory it is implementing. +## Core Moral Model + +The checker should be taught and reasoned about with one sentence: + +```text +Evidence is produced by transfer, combined by domains, stabilized by widening, +and observed by queries. +``` + +That sentence is the guardrail. + +- Extraction does not decide lattice policy. It only converts source syntax into + typed evidence and transfer instructions. +- Transfer does not decide cross-iteration convergence. It only updates the + current abstract state. +- Domains do not inspect AST. They only combine abstract values and facts. +- Widening does not recover precision. It only guarantees convergence. +- Queries do not produce new facts. They only read the solved state and apply + already-recorded constraints. +- Interprocedural producers do not mutate old state. They emit deltas. 
+ +If a function violates one of these rules, it is a design smell even if the +behavioral test passes. + +## Canonical Dataflow Contract + +The final dataflow should have explicit boundary objects. + +```text +Source + -> GraphBundle + -> CheckerIR + -> TransferProgram + -> AbstractState + -> QueryView + -> FunctionResult + -> InterprocDelta + -> FactsDomain + -> SnapshotInputs +``` + +### GraphBundle + +Owns: + +- AST function body, +- CFG, +- symbol table, +- parent scope identity, +- dominance/post-dominance indexes, +- local function indexes, +- parameter-use summaries. + +It is immutable after construction. Anything expensive and graph-derived should +be cached here or through a Salsa query keyed by graph identity. + +### CheckerIR + +Owns the normalized checker program: + +- declarations, +- assignments, +- branch predicates, +- calls, +- returns, +- table constructors, +- field/index writes, +- mutation effects, +- termination effects. + +It should be AST-free except for source spans and stable graph references. This +is where the checker stops being syntax-driven and becomes analysis-driven. + +### TransferProgram + +Owns executable transfer instructions over `AbstractState`. + +Examples: + +```text +Assign(Location, ValueExpr) +Assume(Condition) +Mutate(Mutation) +Call(CallSite) +Return(ReturnTuple) +Terminate(Reason) +``` + +Every statement-level fact should enter the solver through an instruction like +this. A table insert, captured mutation replay, field assignment, and dynamic +index write should not each invent their own path rules. + +### AbstractState + +Owns the whole product: + +```text +AbstractState = + MemoryState + x ValueFacts + x NumericFacts + x ShapeFacts + x RelationFacts + x EffectFacts + x TerminationFacts +``` + +This must be the persistent state of the intraprocedural solver. Query-time +`ProductDomain` construction should be replaced by reading this product, or by +creating a cheap view over it. 
The state product is the source of truth. + +### QueryView + +Owns read-only answers: + +- type at point, +- narrowed path type, +- field/index presence, +- tuple relation at call site, +- constant/numeric facts, +- reachability. + +It cannot write facts. It cannot perform fresh synthesis that changes the +answer independently from `AbstractState`. + +### InterprocDelta + +Owns facts emitted by a completed function analysis: + +- function fact, +- parameter evidence, +- literal signatures, +- captured field mutations, +- captured container/table mutations, +- constructor fields, +- relation summaries. + +The delta is immutable. The store combines it through `FactsDomain` only. + +## Phase Responsibility Table + +| Phase | May Create | May Combine | May Widen | May Query | Forbidden | +|---|---|---|---|---|---| +| Scope/CFG | graph identity, symbols | no type facts | no | no | type merge policy | +| Extract/IR | transfer instructions | no domain joins | no | declared-only queries | fixpoint repair | +| Flow solve | abstract state updates | domain joins at CFG joins | loop-local widening only if owned by flow domain | internal state reads | interproc fact writes | +| Narrow/query | read-only answers | no persistent joins | no | yes | producing facts | +| Return SCC | local return/param deltas | local domain joins | SCC widening through domain only | solved flow state | AST-specific merge laws | +| Interproc store | immutable deltas | `FactsDomain.Join` | `FactsDomain.Widen` | snapshot reads | producer-specific callbacks | +| Salsa | dependencies/cache | no semantic joins | no | query execution | hidden state mutation | + ## Foundational Diagnosis The checker has accumulated strong behavior before it acquired the right @@ -427,6 +578,180 @@ helper joins directly. 
| body parameter contracts | `infer/return`, `flowbuild/assign` | `domain/paramevidence` | | Salsa snapshot inputs | `store/snapshot_inputs.go` | keep in store, but document as cache boundary | +## Worked Consolidation Examples + +### Table-Key Truthiness Refinement + +Current smell: + +```go +candidateRefinesFunctionParam(candidate, baseline) +typeRefinesTableKeyByTruthiness(candidate, baseline) +recordRefinesTableKeyByTruthiness(candidate, baseline) +``` + +These helpers are trying to express one domain law: + +```text +A table-like parameter fact may refine its key domain by removing falsy key +members only if the table value domain and structural frame are preserved. +``` + +Final home: + +```text +domain/value.Refinement +domain/functionfact.ParamSlotDomain +domain/paramevidence +``` + +Final expression: + +```go +refinement := value.Refinement{ + Kind: value.RefineTruthyKey, + PreserveFrame: true, + PreserveValue: true, +} +paramSlot.Join(existing, candidate, refinement) +``` + +The check is no longer a local function-param helper. It is a value-domain +refinement rule reused by parameter evidence, function facts, and return +summary map-key refinement. + +### Soft Evidence Replacement + +Current smell: + +```go +preferConcreteOverSoftType(a, b) +typ.PruneSoftUnionMembers(t) +reconcileSoftAnnotatedInference(base, inferred) +``` + +These are fragments of one evidence-ordering law: + +```text +hard concrete evidence dominates soft placeholder evidence, but nil alone does +not erase soft structured evidence. +``` + +Final home: + +```text +domain/value.EvidenceOrder +``` + +Final expression: + +```go +EvidenceOrder.Select(existing, candidate) +``` + +Every caller gets the same policy: + +- soft annotation refinement, +- function parameter facts, +- parameter evidence, +- return-summary container refinement, +- flow assignment refinement. 
+ +### Open-Record Row Tail + +Current smell: + +Open-record behavior is split between record join, subtyping, table literal +contextualization, and external-regression fixes. + +Canonical law: + +```text +A missing field on an open record is row-tail evidence, not proof of nil. +A missing field on a closed record is absence. +``` + +Final home: + +```text +domain/value.RowShape +``` + +Final API: + +```go +RowShape.FieldEvidence(record, fieldName) FieldEvidence +``` + +The rest of the checker asks for field evidence. It does not rediscover whether +the record is open, closed, map-like, or table-top. + +### Captured Table Mutation Replay + +Current smell: + +Captured table inserts, generic container mutations, parent replay, direct +flow mutators, and nested function calls have separate paths. + +Canonical law: + +```text +A mutation has one semantic operator and one memory location. Replay is valid +only when alias identity, dominance, and operator kind are preserved. +``` + +Final home: + +```text +compiler/check/memory +``` + +Final expression: + +```go +MemoryState.Apply(Mutation{ + Kind: MutationTableElement, + Target: Location, + Value: Type, + Provenance: CapturedCall, +}) +``` + +The same apply path handles direct `table.insert`, nested captured insert, and +exported callback replay. + +### Error-Return Correlation + +Current smell: + +Several phases know about the `(value, err)` convention, arity checks, and +success/failure narrowing. + +Canonical law: + +```text +Error-return behavior is a tuple relation over return slots, not a special case +of a two-result function. +``` + +Final home: + +```text +domain/relation +``` + +Final expression: + +```go +RelationDomain.Attach(ReturnTupleRelation{ + Success: { ErrSlot: Nil, ValueSlot: NonNilOrUnknown }, + Failure: { ErrSlot: NonNil, ValueSlot: NilOrUnknown }, +}) +``` + +The canonical Lua `(value, err)` convention is one predefined relation. Future +relations do not require new helper clusters. 
+ ## Target Data Flow The final flow should be: @@ -579,6 +904,52 @@ Every major law needs: - a negative test proving sound rejection, - a domain law test proving normalize/join/widen idempotence and monotonicity. +## Edge-Case Matrix + +The migration must consider edge cases beyond the failures already seen. The +design is not complete until each row below has an owner domain and tests. + +| Area | Edge Cases To Model | +|---|---| +| `unknown` | branch join with concrete, return merge with concrete, exported summary, table field, array element, call argument, relation slot | +| `any` | explicit cast to any, imported dynamic data, any flowing to concrete param, any in record field, any as table key/value, any through relation facts | +| `nil` | nil as Lua value, nil as field deletion, nil satisfying optional absence, nil array slot, nil map value, nil return slot | +| absent field | closed record absence, open row-tail unknown, map-tail optional value, table-top field access, absence after mutation | +| soft evidence | soft table top, soft array element, soft map value, nil plus soft shape, hard evidence replacing soft evidence, soft evidence across imports | +| table top | `table`, `{...}`, `{[any]: any}`, arrays, maps, closed records, open records, unions with precise tables | +| row shape | open vs closed, readonly fields, optional fields, metatables, map component overlap, discriminant tags | +| truthiness | false/nil removal, literal false keys, `and`/`or` branch values, truthy field guards, truthy dynamic indexes | +| mutation | field write, nil overwrite, dynamic index write, table insert, container send, captured mutation, exported callback mutation | +| aliasing | local alias, field alias, imported alias, method receiver alias, self alias, cyclic alias, alias after reassignment | +| dominance | dominating writes, branch-local writes, loop-carried writes, post-dominated assertions, early returns, dead paths | +| functions | optional function values, union 
of function signatures, method `self`, varargs, higher-order callbacks, recursive locals | +| returns | zero returns, one return, two returns, more than two returns, tuple expansion, nil padding, recursive containers | +| relations | `(value, err)`, custom error record, multiple independent relations, swapped slots, relation through wrapper, relation through any | +| interproc | parent scope change, module boundary, literal signatures, captured fields, constructor fields, sibling overlay, stale snapshots | +| caching | stale query after fact change, query reuse after no-op fact change, cache key missing parent scope, cache key missing graph identity | +| performance | recursive structural scan, repeated AST projection, repeated map allocation, query dependency overhead, equality-time canonicalization | + +Adversarial cases must include both: + +- precision cases where the checker should infer the strongest provable type; +- soundness cases where similar-looking code must still fail. + +Examples: + +- guarded `options.model` should infer `string`; `provider_info as any` should + not become `string` without proof; +- `response.body or ""` should be `string`; `response.body` alone remains + `string?`; +- open row-tail field access is `unknown`; closed missing field is absent/nil + evidence depending on context; +- table insert before an `ipairs` loop should feed element type; branch-local + insert must not leak if the loop is not dominated by that branch; +- `test.is_nil(err)` may refine a related value slot only if a relation fact + proves the tuple contract. + +The suite should be generated around these matrices, not around the names of +the old helper functions. 
+ ## Flash Migration Shape The implementation should be prepared privately but merged as a direct final From 8ba2ab37a0cfe9f3d236839eeb06938d2ea7cb31 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:47:36 -0400 Subject: [PATCH 05/71] Extend checker inference design --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 182 +++++++++++++++++++++++++++++- 1 file changed, 181 insertions(+), 1 deletion(-) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 9ed91567..223cc831 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -262,6 +262,126 @@ Owns facts emitted by a completed function analysis: The delta is immutable. The store combines it through `FactsDomain` only. +## Inference Model + +Inference is not a separate magical subsystem. It is the process of solving for +unknown slots in the product domain under the evidence produced by transfer. + +The final model should distinguish these inference layers: + +### Local Value Inference + +Scope: + +- local variables, +- field/index reads, +- table literals, +- expression results, +- branch-local values, +- loop-carried values. + +Authority: + +```text +AbstractState.ValueFacts + MemoryState + RelationFacts +``` + +Rules: + +- local value inference reads declared types, transfer assignments, and + constraints; +- it never writes interprocedural facts directly; +- it must preserve the distinction between `unknown` and `any`; +- table literal contextualization is a transfer/type-domain operation, not a + one-off hook; +- logical `and`/`or` inference must preserve actual Lua branch values. + +### Parameter Inference + +Scope: + +- call-site argument observations, +- body-derived obligations, +- source annotations, +- soft annotations, +- current function facts, +- function literal expectations. 
+ +Authority: + +```text +ParameterEvidenceDomain +``` + +Rules: + +- call-site observations are evidence, not contracts; +- body obligations are contracts only when the function body proves it requires + that shape; +- explicit source annotations dominate inferred hints; +- soft annotations refine only when hard evidence proves the refinement; +- recursive parameter evidence must join/widen through the parameter domain. + +There should be no separate ad hoc policy for "param hints" versus "function +fact params". Both are parameter evidence with different provenance and merge +mode. + +### Return Inference + +Scope: + +- return statements, +- tuple/multivalue expansion, +- nil padding, +- recursive return vectors, +- summary and narrow summary slots, +- wrapper forwarding. + +Authority: + +```text +ReturnSummaryDomain +RelationDomain +``` + +Rules: + +- return arity is part of the tuple domain; +- `unknown` return evidence is unresolved runtime behavior, not bottom; +- recursive return vectors widen only at the SCC/fixpoint boundary; +- relation facts such as `(value, err)` attach to return tuples explicitly; +- wrapper forwarding propagates tuple and relation facts together. + +### Function Type Inference + +Scope: + +- local function literals, +- method receiver `self`, +- higher-order callbacks, +- literal signatures, +- exported functions, +- imported module functions. + +Authority: + +```text +FunctionFactDomain +ParameterEvidenceDomain +ReturnSummaryDomain +RelationDomain +``` + +Rules: + +- function type inference is a product of parameter evidence, return summary, + and relation/effect summaries; +- a same-body function fact may seed analysis only through non-narrowing domain + merge; +- higher-order signatures must use variance-aware merge rules; +- literal signatures are facts in the interproc product, not a second function + authority. 
+ ## Phase Responsibility Table | Phase | May Create | May Combine | May Widen | May Query | Forbidden | @@ -432,6 +552,48 @@ type TupleRelation struct { The current `(value, err)` convention is then one predefined relation, not a special checker behavior. +### Effect Domain + +Owns facts about what a function or call can do: + +- termination, +- error-return relation attachment, +- path refinements caused by assertions/predicates, +- table/container mutation effects, +- callback effects, +- key-collector effects, +- external contract effects. + +Candidate home: + +```text +compiler/check/domain/effect +``` + +Effect inference must be a normal abstract-interpretation output: + +```text +CallSite + CalleeSummary + AbstractState -> EffectDelta +``` + +The effect delta is then applied by transfer or stored in function facts. It +must not be an after-the-fact patch that rewrites types without going through +the memory/relation/effect domains. + +Effect summaries should be explicit: + +```go +type EffectSummary struct { + Mutations []memory.Mutation + Relations []relation.TupleRelation + Refinements []relation.PathRelation + Terminates TerminationEffect +} +``` + +Current effects such as error-return correlation, captured container mutation, +and key-collector propagation become instances of this summary. + ### Function Fact Domain Owns all interprocedural facts about functions. @@ -539,6 +701,7 @@ type FactsDomain struct { LiteralSigs LiteralSignatureDomain Captures CaptureDomain Constructors ConstructorDomain + Effects EffectDomain } ``` @@ -575,6 +738,7 @@ helper joins directly. 
| path/query/alias identity | `constraint`, `flowbuild/path`, `flow/pathkey` | `memory` | | table/container mutation replay | `nested`, `returns`, `flowbuild`, `flow` | `memory` mutation domain | | error-return convention | `erreffect`, call/return inference | `domain/relation` | +| effect inference | `effects`, `erreffect`, `flowbuild`, `nested`, `returns` | `domain/effect` | | body parameter contracts | `infer/return`, `flowbuild/assign` | `domain/paramevidence` | | Salsa snapshot inputs | `store/snapshot_inputs.go` | keep in store, but document as cache boundary | @@ -893,7 +1057,22 @@ The final design should model tuple/path relations directly. `(value, err)` is then one relation instance. This keeps the system extensible without hardcoded branch helpers or return-slot checks. -### 6. Tests Are Too Positive-Heavy +### 6. Effect Inference Is Too Distributed + +Effects are currently inferred and replayed from several places: + +- call specs, +- error-return inference, +- captured field/container mutation collection, +- nested mutator replay, +- key collector detection, +- predicate/assertion refinements. + +Those are all effect facts. They need one summary model and one application path +through transfer. Otherwise each new effect creates its own mini analysis and +its own invalidation/caching risks. + +### 7. Tests Are Too Positive-Heavy Many external-lint regressions are "this must type-check" tests. Those are useful, but insufficient. They can pass through accidental broadening. @@ -925,6 +1104,7 @@ design is not complete until each row below has an owner domain and tests. 
| functions | optional function values, union of function signatures, method `self`, varargs, higher-order callbacks, recursive locals | | returns | zero returns, one return, two returns, more than two returns, tuple expansion, nil padding, recursive containers | | relations | `(value, err)`, custom error record, multiple independent relations, swapped slots, relation through wrapper, relation through any | +| effects | termination, assertion refinements, callback effects, captured mutation effects, key collection, external contract effects | | interproc | parent scope change, module boundary, literal signatures, captured fields, constructor fields, sibling overlay, stale snapshots | | caching | stale query after fact change, query reuse after no-op fact change, cache key missing parent scope, cache key missing graph identity | | performance | recursive structural scan, repeated AST projection, repeated map allocation, query dependency overhead, equality-time canonicalization | From 79dfc212953938fc85a42e8ed1761c9477715e6d Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:49:26 -0400 Subject: [PATCH 06/71] Clarify checker dataflow contract --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 342 ++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 223cc831..6b2c49bf 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -262,6 +262,254 @@ Owns facts emitted by a completed function analysis: The delta is immutable. The store combines it through `FactsDomain` only. 
+## Evidence Lifecycle + +Every fact in the checker should have a visible lifecycle: + +```text +Observed -> Located -> Qualified -> Transferred -> Joined -> Widened -> Queried -> Published +``` + +### Observed + +Evidence starts from one of a small number of sources: + +- source annotation, +- literal syntax, +- assignment, +- guard/predicate/assertion, +- call argument, +- call return, +- effect spec, +- table/container mutation, +- imported manifest, +- previous interproc snapshot. + +Observation records provenance. It does not decide final authority. + +### Located + +Every observation must attach to a location: + +- symbol, +- field path, +- index path, +- tuple slot, +- function graph, +- parent scope, +- call site, +- return site. + +Location must be canonical before the evidence enters transfer. This prevents +one helper using AST paths while another uses SSA path keys. + +### Qualified + +The evidence is tagged with its authority: + +```text +explicit annotation > hard proof > body obligation > call observation > +soft annotation > unresolved evidence +``` + +`any` is not "very strong evidence." It is dynamic top. `unknown` is not "safe +to ignore." It is unresolved evidence. These two facts must remain distinct in +every domain. + +### Transferred + +Transfer applies evidence to the current `AbstractState`. + +Examples: + +- assignment writes memory/value facts, +- guard writes relation and shape facts, +- call writes return tuple and effect facts, +- table insert writes a mutation fact, +- error-return check reads a tuple relation and narrows linked slots. + +Transfer does not call interproc merge functions. Transfer does not widen. + +### Joined + +Control-flow joins combine same-phase predecessor states through domain `Join`. +This is where branch evidence meets. + +Branch joins must preserve runtime alternatives. 
For Lua, `x or y` and `x and y` +return actual operand values, so the value domain cannot prune a live branch just +because the other branch is more precise. + +### Widened + +Widening is allowed only at named recursive boundaries: + +- loop fixpoint, +- local function SCC, +- interprocedural fixpoint, +- recursive type/shape growth boundary. + +Widening must be visible in code. If a helper "prefers" one side to force +stability, it is a widening rule and belongs to the domain that owns that +cycle. + +### Queried + +Queries produce read-only views: + +- type at point, +- narrowed path, +- field/index evidence, +- relation state, +- effect summary. + +Queries cannot publish facts. If a query has to synthesize new evidence to +answer correctly, that evidence belongs in transfer or in a cached derived input +computed before solving. + +### Published + +Only completed function analysis publishes interproc deltas. Publication is a +data move: + +```text +FunctionResult -> InterprocDelta -> FactsDomain.Join/Widen -> SnapshotInputs +``` + +Publication is not another inference pass. + +## Required Domain API Shape + +Every domain should expose the same conceptual operations even if Go uses +concrete types instead of generics everywhere. + +```go +type Domain[T any] interface { + Normalize(T) T + Leq(a, b T) bool + Join(a, b T) T + Meet(a, b T) T + Widen(prev, next T) T +} +``` + +Transfer is separate: + +```go +type Transfer[I any, S any] interface { + Apply(input I, state S) S +} +``` + +Query is separate: + +```go +type Query[S any, Q any, A any] interface { + Answer(state S, question Q) A +} +``` + +This separation is important: + +- `Join` and `Widen` do not inspect AST. +- `Transfer` does not know interproc storage. +- `Query` does not mutate state. +- `Normalize` is explicit and not hidden in equality. 
+ +## Dataflow Walkthroughs + +### Guarded Field To Call Argument + +Pattern: + +```lua +if options.model then + provider.open(options.model) +end +``` + +Correct dataflow: + +1. `options.model` is observed as a field read. +2. The guard transfers a truthy relation for `Location(options, "model")`. +3. The call argument query reads that relation and answers `NonNil(modelType)`. +4. Parameter evidence records a call observation for the callee. +5. If the callee body requires `string`, body obligation and call observation + combine in `ParameterEvidenceDomain`. + +Wrong shape: + +- special-case `options.model` in call checking, +- make all truthy fields strings, +- accept `any` as string. + +### Table Insert To Later Iteration + +Pattern: + +```lua +table.insert(state.items, value) +for _, item in ipairs(state.items) do ... end +``` + +Correct dataflow: + +1. `state.items` resolves to one memory location. +2. `table.insert` transfers a `MutationTableElement` to that location. +3. Memory join preserves the element fact at the exact child path. +4. `ipairs` queries the array element evidence from memory. + +Wrong shape: + +- replay captured table insert through generic container mutation, +- let parent table literal shape override explicit child-path evidence, +- infer element type from the loop variable without memory provenance. + +### Error Return Correlation + +Pattern: + +```lua +local value, err = f() +test.is_nil(err) +value.field +``` + +Correct dataflow: + +1. `f()` returns a tuple with a relation summary. +2. Assignment binds tuple slots to locations. +3. `test.is_nil(err)` transfers a relation constraint on the error slot. +4. Relation query narrows the linked value slot. +5. Field access reads the narrowed value slot. + +Wrong shape: + +- hardcode `test.is_nil` as a value-slot refinement, +- assume every two-return function is `(value, err)`, +- drop tuple relation when a wrapper forwards returns. 
+ +### Unknown External Payload + +Pattern: + +```lua +local payload = json.decode(raw) +needs_string(payload.name) +``` + +Correct dataflow: + +1. `json.decode` returns dynamic/unresolved data. +2. `payload.name` is unresolved or `any` depending on API contract. +3. Passing it to `string` must fail unless a guard, schema, cast, or contract + proves it. + +Wrong shape: + +- treat unknown external fields as strings because most callers expect strings, +- let table shape contextualization silently rewrite explicit `any`, +- clear global lint by broadening assignability. + ## Inference Model Inference is not a separate magical subsystem. It is the process of solving for @@ -961,6 +1209,26 @@ Current weak shape: - param-use projection can rescan AST bodies instead of reading a graph-indexed use summary. +Canonical Salsa wiring: + +```text +db.Input[ManifestKey] -> module/type environment queries +db.Input[GraphKey] -> graph-derived summaries +db.Input[InterprocGraphKey] -> function-result queries +db.Input[SymbolKey] -> constructor/refinement/effect summaries + +FuncResultQuery(GraphID, ParentHash) + reads graph bundle + reads interproc snapshot inputs + builds transfer program + solves abstract state + publishes immutable result +``` + +The query key is stable identity. The dependency edges come from the exact +inputs read during analysis. There should be no artificial revision number in +the function key and no manual cache clearing for correctness. + Final cache contracts: 1. Source inputs are `db.Input`s: @@ -1083,6 +1351,80 @@ Every major law needs: - a negative test proving sound rejection, - a domain law test proving normalize/join/widen idempotence and monotonicity. +## Anti-Pattern Catalog + +These shapes should be rejected during the flash migration. + +### Local Domain Predicate In An Orchestration Package + +Example smell: + +```go +func typeRefinesTableKeyByTruthiness(...) +``` + +If the helper defines what refinement means, it belongs to a domain package. 
+Orchestration packages can ask a domain whether a refinement is valid; they +cannot define the refinement locally. + +### Equality-Time Repair + +If equality normalizes, rebuilds, or reconciles facts to make two states look +equal, convergence bugs become invisible. + +Correct shape: + +```text +write boundary -> Normalize +merge boundary -> Join/Widen +equality -> structural comparison of canonical state +``` + +### Query-Time Fact Production + +If a query discovers a fact that later code relies on as if it were stored +analysis state, the system has a hidden analysis path. + +Correct shape: + +```text +query can memoize an answer, but cannot publish evidence +``` + +### Producer-Specific Merge + +If one producer has its own merge rules for a fact family, the product domain is +not canonical. + +Correct shape: + +```text +producer emits delta +store calls FactsDomain.Join or FactsDomain.Widen +``` + +### Compatibility View As Authority + +A projection may exist for display or API response, but not as stored authority. +If production code writes through a view, it recreates the legacy mirror problem. + +### Soundness Shortcut + +Any change whose main effect is "fewer external diagnostics because `any` now +passes" is rejected unless a domain proof explains why that `any` was not truly +dynamic. + +### Cache Without Input Contract + +Every cache must state: + +- exact key, +- immutable inputs, +- invalidation mechanism, +- whether it is semantic or performance-only. + +If the cache depends on phase call order, it is not SOTA. + ## Edge-Case Matrix The migration must consider edge cases beyond the failures already seen. 
The From 6b46fb3b570d333700db3673dc3f96c6829002bf Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:51:09 -0400 Subject: [PATCH 07/71] Sharpen checker design doctrine --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 297 ++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 6b2c49bf..d0cb3970 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -147,6 +147,35 @@ That sentence is the guardrail. If a function violates one of these rules, it is a design smell even if the behavioral test passes. +## One-Page Doctrine + +The final checker should fit in this operational doctrine. + +```text +1. Source syntax is lowered once into graph-indexed transfer IR. +2. Transfer IR is interpreted over one product AbstractState. +3. AbstractState owns every persistent intraprocedural fact. +4. Domain objects own every combine/refine/widen law. +5. Queries are read-only views over solved AbstractState. +6. Function inference publishes immutable InterprocDelta values. +7. FactsDomain is the only interprocedural merge/widen authority. +8. Salsa tracks immutable inputs and query dependencies. +``` + +Everything else is implementation detail. + +The doctrine gives a direct review test: + +- If code lowers syntax, it belongs in graph/IR/extract. +- If code changes state, it is transfer. +- If code combines facts, it is a domain operation. +- If code forces convergence, it is widening. +- If code answers a question, it is a query. +- If code crosses function/module boundaries, it emits or consumes a delta. +- If code caches, it must name immutable inputs and invalidation. + +No rule should need to be implemented twice under different helper names. + ## Canonical Dataflow Contract The final dataflow should have explicit boundary objects. @@ -415,6 +444,78 @@ This separation is important: - `Query` does not mutate state. 
- `Normalize` is explicit and not hidden in equality. +## Domain Invariant Ledger + +Each domain needs invariants that can be tested independently from the full +checker. These are the invariants that should guide the flash migration. + +### Value Domain Invariants + +- `unknown` means unresolved evidence and must not be silently dropped at + return, branch, table, or relation joins. +- `any` means dynamic top and must not satisfy concrete contracts without an + explicit proof, guard, schema, or cast. +- `nil` is a Lua value; absent field is structural absence; optional field is a + type-level allowance for absence/nil depending on context. +- soft evidence is lower authority than hard evidence, but `nil` alone does not + erase a soft structured shape. +- open row-tail field access produces row-tail evidence; closed missing field + does not. +- table top absorbs table-like precision only in domains where table-likeness is + the intended upper bound, not as a general precision eraser. + +### Memory Domain Invariants + +- every fact has exactly one canonical location; +- child-path facts outrank parent-derived fallback evidence for the same path; +- alias replay preserves identity and dominance; +- mutation replay preserves operator kind; +- nil overwrite and field deletion are represented explicitly; +- branch-local mutation does not leak unless control-flow dominance proves it. + +### Relation Domain Invariants + +- tuple/path relations are first-class facts; +- relation facts survive assignment, wrapper forwarding, and module export only + when slot/path identity is preserved; +- relation narrowing is bidirectional only when the relation declares it; +- a guard helper such as `is_nil` can apply a relation but cannot invent one. 
+ +### Effect Domain Invariants + +- effects are summaries, not post-hoc type rewrites; +- effect application goes through transfer; +- captured effects preserve location, operator kind, and provenance; +- termination effects affect reachability before value queries; +- external contract effects are typed inputs, not hardcoded checker behavior. + +### Parameter Evidence Invariants + +- call observations are weaker than body obligations; +- body obligations are inferred only from actual body demand; +- source annotations remain authoritative; +- soft annotations can refine but not override hard proof; +- recursive parameter evidence widens at SCC/interproc boundaries only; +- function-fact params and param hints use the same evidence order. + +### Return Summary Invariants + +- tuple arity is explicit; +- nil padding is explicit; +- unknown return evidence is not bottom; +- relation summaries travel with tuple summaries; +- recursive container growth has one widening policy; +- narrow summary is derived from solved flow facts, not a second stored truth. + +### Interproc Facts Invariants + +- producers emit immutable deltas; +- store merge uses `FactsDomain.Join`; +- fixpoint boundary uses `FactsDomain.Widen`; +- equality compares canonical state only; +- derived views are not write targets; +- snapshot inputs mirror canonical read state exactly. + ## Dataflow Walkthroughs ### Guarded Field To Call Argument @@ -1425,6 +1526,120 @@ Every cache must state: If the cache depends on phase call order, it is not SOTA. +## Design Review Decision Tree + +Every future rule should be classified before code is written. + +### Is It About What A Type Means? + +Examples: + +- `unknown` vs `any`, +- open row tail, +- nilability, +- truthiness, +- soft evidence, +- table top. + +Owner: + +```text +ValueDomain +``` + +Reject if implemented in return inference, call checking, or postflow writer. + +### Is It About Where A Fact Lives? 
+ +Examples: + +- field path, +- dynamic index, +- alias target, +- tuple slot, +- captured mutation target, +- receiver `self`. + +Owner: + +```text +MemoryDomain / Location model +``` + +Reject if every producer computes its own path identity. + +### Is It About How Facts Combine? + +Examples: + +- branch join, +- parameter evidence merge, +- return vector merge, +- function fact merge, +- recursive shape cutoff. + +Owner: + +```text +The domain that owns that fact family +``` + +Reject if implemented as a producer-specific helper. + +### Is It About When Analysis Converges? + +Examples: + +- loop widening, +- local function SCC widening, +- interproc widening, +- recursive type growth. + +Owner: + +```text +Widen operation of the relevant domain +``` + +Reject if hidden inside equality, query, or local preference helpers. + +### Is It About What A Call Does? + +Examples: + +- mutates a table, +- narrows an argument, +- returns `(value, err)`, +- terminates, +- invokes a callback, +- collects keys. + +Owner: + +```text +EffectDomain + RelationDomain + MemoryDomain transfer +``` + +Reject if modeled as a one-off postprocessing pass. + +### Is It About Reusing Work? + +Examples: + +- graph summaries, +- parameter-use summaries, +- function result, +- type operator query, +- shape classification. + +Owner: + +```text +Salsa query or explicit local cache with named inputs +``` + +Reject if invalidation depends on call order or hidden mutable state. + ## Edge-Case Matrix The migration must consider edge cases beyond the failures already seen. The @@ -1532,6 +1747,88 @@ The key rule: Orchestration packages may decide when a fact is produced. Domain packages decide what that fact means and how it combines. +## Minimum Final-Shape API Sketch + +This is not a transitional API. It is the smallest final surface that should +exist after the flash migration. 
+ +```go +// compiler/check/analysis +type Engine struct { + Graphs GraphProvider + Domains Domains + Queries Queries +} + +func (e *Engine) AnalyzeFunction(input FunctionInput) FunctionResult +``` + +```go +// compiler/check/flowstate +type AbstractState struct { + Memory MemoryState + Values ValueFacts + Numeric NumericFacts + Shape ShapeFacts + Relations RelationFacts + Effects EffectFacts + Termination TerminationFacts +} + +func (s AbstractState) Join(other AbstractState, d Domains) AbstractState +func (s AbstractState) Widen(next AbstractState, d Domains) AbstractState +``` + +```go +// compiler/check/transfer +type Instruction interface { + Apply(state flowstate.AbstractState, d Domains) flowstate.AbstractState +} +``` + +```go +// compiler/check/domain +type Domains struct { + Value ValueDomain + Memory MemoryDomain + Relation RelationDomain + Effect EffectDomain + Parameter ParameterEvidenceDomain + Return ReturnSummaryDomain + Function FunctionFactDomain + Interproc InterprocFactsDomain +} +``` + +```go +// compiler/check/domain/interproc +type InterprocFactsDomain interface { + Normalize(api.Facts) api.Facts + Leq(a, b api.Facts) bool + Join(a, b api.Facts) api.Facts + Widen(prev, next api.Facts) api.Facts + Equal(a, b api.Facts) bool +} +``` + +```go +// compiler/check/query +type View interface { + TypeAt(point cfg.Point, loc memory.Location) typ.Type + RelationAt(point cfg.Point, rel relation.Query) relation.Answer + EffectAt(point cfg.Point, call CallSite) effect.Summary +} +``` + +The important part is not exact names. The important part is that: + +- state is one product; +- transfer mutates only that product; +- domains own all combination; +- query is read-only; +- interproc publication is delta-based; +- no package owns a shadow merge policy. 
+ ## Verification Model For The Future Migration Required proof after the flash migration: From e63447ea6335f631f2a4fd25fe3d10b0d4f8346a Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:52:44 -0400 Subject: [PATCH 08/71] Define checker semantic atoms --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 262 ++++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index d0cb3970..b351b02d 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -176,6 +176,128 @@ The doctrine gives a direct review test: No rule should need to be implemented twice under different helper names. +## Semantic Atoms + +The final design should use a small shared vocabulary. These words should have +one meaning everywhere in the checker. + +### Value + +A `Value` is an abstract runtime Lua value. + +It can be concrete, literal, structural, function-like, `nil`, `unknown`, or +`any`. It is not a source annotation and not a location. A value domain may say +how values combine; it may not decide where a value came from. + +### Location + +A `Location` is an abstract program place where evidence can attach. + +Examples: + +- symbol at SSA version, +- field path, +- index path, +- tuple slot, +- receiver slot, +- captured variable, +- return slot, +- graph/function identity. + +Locations are canonical before transfer. AST paths and SSA paths cannot both be +authoritative. + +### Evidence + +`Evidence` is a value plus provenance and authority. + +Examples: + +- explicit annotation, +- hard runtime proof, +- body obligation, +- call observation, +- soft annotation, +- unresolved observation, +- imported dynamic value. + +Evidence is not automatically truth. Domains decide how evidence combines. + +### Fact + +A `Fact` is evidence that has been accepted into a domain state. + +Facts are persistent inside `AbstractState` or inside an immutable +`InterprocDelta`. 
Raw observations are not facts until transfer/domain logic +accepts them. + +### Constraint + +A `Constraint` restricts possible facts along a control-flow path. + +Examples: + +- truthy/falsy, +- type test, +- nil/non-nil, +- has-field, +- numeric bound, +- relation branch. + +Constraints do not mutate storage by themselves. Transfer applies them to +`AbstractState`; queries read the result. + +### Relation + +A `Relation` connects multiple locations. + +Examples: + +- return slot 1 being nil implies return slot 0 is non-nil, +- assertion on one symbol narrows a sibling path, +- method receiver relation to `self`, +- callback argument relation to caller state. + +Relations are not encoded as special value types. They are first-class domain +facts. + +### Effect + +An `Effect` describes what execution of a call or instruction can do. + +Examples: + +- mutate memory, +- terminate, +- refine an argument, +- produce a tuple relation, +- call a callback, +- collect keys. + +Effects are applied by transfer. They do not rewrite types directly. + +### Delta + +A `Delta` is a completed analysis contribution to another scope or iteration. + +Examples: + +- function fact delta, +- parameter evidence delta, +- captured mutation delta, +- constructor field delta, +- relation summary delta. + +Deltas are immutable. The store never lets a producer mutate canonical state in +place. + +### Snapshot + +A `Snapshot` is the immutable state observed by a query. + +Snapshots are cache inputs. If a snapshot changes, dependent queries must +revalidate through Salsa or an explicitly documented cache invalidation rule. + ## Canonical Dataflow Contract The final dataflow should have explicit boundary objects. @@ -345,6 +467,19 @@ soft annotation > unresolved evidence to ignore." It is unresolved evidence. These two facts must remain distinct in every domain. +The authority order is partial, not a simple global priority. 
For example: + +- explicit annotation dominates inferred shape for assignment checking; +- hard branch proof dominates soft annotation for narrowing; +- body obligation dominates call observation for parameter contracts; +- explicit `any` remains dynamic top and does not become concrete because a + later call expects concrete; +- unresolved `unknown` can be refined by proof, but cannot be silently replaced + by unrelated precision. + +This should become an explicit `EvidenceOrder`, not a set of local `if` +statements. + ### Transferred Transfer applies evidence to the current `AbstractState`. @@ -1292,6 +1427,115 @@ reasons. The design target is not to delete their semantics. The design target is to make them clients of the same domain objects instead of separate local machines. +## Dataflow State Machine + +The checker should have one visible state machine. + +```text +Unbuilt + -> GraphBuilt + -> IRBuilt + -> Solving + -> Solved + -> Inferred + -> Published + -> Snapshotted +``` + +### Unbuilt -> GraphBuilt + +Input: + +- source AST, +- parent scope, +- manifest environment. + +Output: + +- immutable graph bundle. + +No type-domain merge is allowed here. + +### GraphBuilt -> IRBuilt + +Input: + +- graph bundle, +- declared type environment, +- known effect specs. + +Output: + +- transfer program. + +This stage may observe syntax and produce instructions. It may not decide +fixpoint policy. + +### IRBuilt -> Solving + +Input: + +- transfer program, +- initial abstract state, +- domain set. + +Output: + +- evolving abstract state. + +All state changes go through transfer and domain operations. + +### Solving -> Solved + +Input: + +- worklist convergence, +- loop widening if needed. + +Output: + +- solved abstract state plus query view. + +No interproc publication happens before this state. + +### Solved -> Inferred + +Input: + +- query view, +- function body, +- relation/effect summaries. + +Output: + +- function result and interproc delta. 
+ +Inference reads solved state. It does not create another path-sensitive solver. + +### Inferred -> Published + +Input: + +- immutable interproc delta. + +Output: + +- canonical fact product after join or widening. + +Only `FactsDomain` may combine this data. + +### Published -> Snapshotted + +Input: + +- canonical fact product. + +Output: + +- Salsa snapshot inputs and dependent query invalidation. + +No semantic repair is allowed here. Snapshotting is cache wiring only. + ## Salsa And Cache Model Current good shape: @@ -1864,6 +2108,24 @@ Required behavior suites: - interproc non-convergence fixtures, - external replay reductions. +## Review Checklist Before Coding + +Before implementing the flash migration, each proposed package should answer: + +- What domain or boundary object does this package own? +- What are the only mutable states in this package? +- Which operation is transfer, join, meet, widen, normalize, query, or publish? +- Which laws are tested at the package boundary? +- Which edge-case matrix rows does it cover? +- Which caches does it introduce, and what exact immutable inputs key them? +- Which old helper clusters will be deleted when this lands? +- Which production call sites will move directly to the final API? +- What negative tests prevent broadening `any`, erasing `unknown`, or treating + absence as nil in the wrong domain? + +If any answer is "handled by a fallback during migration", the design is not +ready. The next implementation must be flash migration, not coexistence. + ## Current Conclusion The checker is not fundamentally the wrong idea. 
It is closer to a serious From 57d9dd02265a11a2a4e2121fe1481a85b7456dd6 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:56:16 -0400 Subject: [PATCH 09/71] Refine checker dataflow ownership design --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 514 ++++++++++++++++++++++++++++++ 1 file changed, 514 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index b351b02d..494302aa 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -651,6 +651,252 @@ checker. These are the invariants that should guide the flash migration. - derived views are not write targets; - snapshot inputs mirror canonical read state exactly. +## Domain Interaction Protocol + +The product-domain design is only useful if packages interact through a small +set of verbs. These verbs are the mental model for the future implementation. + +```text +Syntax/Graph -> Instruction -> Transfer -> Domain operation -> AbstractState +AbstractState -> Query -> Answer +FunctionResult -> InterprocDelta -> FactsDomain -> Snapshot +``` + +### Transfer + +Transfer applies one semantic instruction to one abstract state. + +Allowed: + +- read the instruction payload, +- ask a domain for local semantic operations, +- produce a new abstract state. + +Forbidden: + +- scan unrelated AST, +- publish interproc facts, +- mutate Salsa inputs, +- call compatibility projections, +- repair old facts. + +If a transfer needs a type meaning question, it asks `ValueDomain`. If it needs +a location question, it asks `MemoryDomain`. If it needs a correlation question, +it asks `RelationDomain`. It does not inline those laws. + +### Domain Operation + +A domain operation defines what a fact means and how it combines. + +Allowed: + +- normalize owned values, +- compare owned values, +- join owned values, +- meet or refine owned values, +- widen owned values, +- answer pure owned-domain predicates. 
+ +Forbidden: + +- depend on source syntax, +- depend on checker phase order, +- allocate hidden facts in another domain, +- read mutable store state, +- perform query invalidation. + +Domain operations must be deterministic and law-testable without constructing a +whole checker. + +### Abstract State + +`AbstractState` is the one mutable semantic product during analysis. + +Allowed: + +- hold domain components, +- combine components through their domains, +- expose read-only query views after solving. + +Forbidden: + +- keep shadow facts that duplicate domain-owned facts, +- hide a second mini solver, +- let equality normalize, +- let queries write analysis evidence. + +### Query + +Queries answer questions against solved state. + +Allowed: + +- read state, +- memoize performance-only answers keyed by immutable input, +- project final user-facing answers. + +Forbidden: + +- create new evidence, +- change convergence, +- backfill facts into the store, +- call `Join` or `Widen`. + +If a query discovers that useful information is missing, the correct response is +to add a transfer/effect/domain fact that produces it before solving. The query +must not become a hidden analysis phase. + +### Publication + +Publication converts a solved function result into an immutable interproc delta. + +Allowed: + +- summarize returns, +- summarize parameter obligations, +- summarize captured effects, +- summarize relations, +- emit deltas. + +Forbidden: + +- merge deltas directly, +- reconcile legacy channels, +- mutate existing facts, +- apply caller-specific preferences. + +The only writer of canonical interproc state is `FactsDomain`. + +### Snapshot + +Snapshotting wires canonical facts into Salsa inputs. + +Allowed: + +- copy canonical facts into inputs, +- invalidate dependent queries through Salsa dependency tracking. + +Forbidden: + +- normalize, +- widen, +- infer, +- drop fields for compatibility, +- reconstruct projections that were not canonical facts. 
+ +Snapshotting is cache plumbing. It is not part of type semantics. + +## Layering And Import Rules + +The final code should make illegal designs difficult to express. Package +dependencies should encode the semantic architecture. + +### Domain Packages + +Domain packages may import: + +- low-level type structures, +- subtype/query primitives, +- small immutable domain-local helper packages. + +Domain packages must not import: + +- AST packages, +- flow builders, +- checker store, +- Salsa database handles, +- diagnostics emitters, +- compatibility view builders. + +Reason: a domain is a pure algebra over facts. If it can see syntax or mutable +store state, local helper logic will grow back. + +### Memory And Location Packages + +Memory/location packages may import: + +- symbol/location identity, +- type values needed to represent field and container facts, +- relation keys where tuple/path identity must be preserved. + +They must not import: + +- call checking, +- interproc store, +- return inference, +- diagnostics formatting. + +Reason: every producer must use the same path identity rules. No producer should +construct its own equivalent of "field path", "tuple slot", or "receiver self". + +### Transfer Packages + +Transfer packages may import: + +- normalized checker IR, +- abstract state, +- domain set, +- memory/location model. + +They must not import: + +- old fact bridges, +- compatibility projections, +- checker diagnostics as control flow, +- global interproc store mutation. + +Reason: transfer is the executable abstract semantics for one instruction. It +can create deltas inside state, but publication happens later. + +### Store And Pipeline Packages + +Store/pipeline packages may import: + +- domain interfaces, +- abstract interpreter engine, +- Salsa database handles, +- diagnostics/reporting. 
+ +They must not implement: + +- truthiness laws, +- soft/hard evidence ordering, +- return tuple relation semantics, +- path dominance rules, +- recursive type widening. + +Reason: orchestration controls when analysis runs. Domains control what analysis +means. + +### Query Packages + +Query packages may import: + +- read-only solved state, +- Salsa query APIs, +- pure domain predicates used for answering. + +They must not import: + +- mutable transfer state, +- publication writers, +- domain normalization writers. + +Reason: a query can be cached aggressively only when it is pure. + +### Test Packages + +Tests should mirror these boundaries: + +- domain law tests construct only domain values, +- transfer tests build small IR fragments and inspect abstract state, +- solver tests check convergence and widening, +- replay tests validate production programs, +- negative tests prove that convenience broadening did not happen. + +Tests that require a whole checker to prove a simple domain law are a signal +that the domain boundary is still too implicit. + ## Dataflow Walkthroughs ### Guarded Field To Call Argument @@ -866,6 +1112,102 @@ Rules: - literal signatures are facts in the interproc product, not a second function authority. +### Effect Inference + +Scope: + +- built-in and manifest call effects, +- assertion/predicate refinements, +- table and container mutations, +- callback invocation effects, +- termination and non-returning calls, +- return tuple relation attachment, +- captured mutation summaries, +- external contract effects. 
+ +Authority: + +```text +EffectDomain +MemoryDomain +RelationDomain +TerminationDomain +``` + +Rules: + +- an effect is an abstract transfer summary, not a postflow patch; +- applying an effect must produce the same state change as inlining its + corresponding transfer instructions would produce, up to the abstraction; +- effect summaries preserve target locations, tuple slots, operator kind, + dominance, and provenance; +- effects that refine values must emit relation/value constraints through the + owning domains; +- effects that mutate memory must emit memory mutations through the memory + domain; +- effects that terminate execution must update reachability before any value + query observes the post-call state; +- callback effects are higher-order summaries and must be applied at the call + edge that invokes the callback, not at publication time; +- external effects are typed inputs to the domain, not hardcoded names in call + checking. + +Wrong effect inference shapes: + +- "after this call, rewrite argument type" in call checking; +- "after this function, patch captured fields" in interproc merge; +- "if function name is `test.is_nil`, narrow slot" outside relation/effect + transfer; +- "if table mutator is seen later, replay as generic container mutation"; +- "if global harness fails, add a special accepted shape". + +Correct effect inference shape: + +```text +call instruction + -> resolve effect summary + -> EffectDomain.Apply(summary, state) + -> MemoryDomain/RelationDomain/ValueDomain/TerminationDomain operations + -> new AbstractState +``` + +Effect inference must be compositional. A user-defined wrapper around an effect +should publish the same kind of summary that the built-in effect uses, so callers +do not need wrapper-specific logic. 
+ +### Inference Soundness Boundary + +The checker should infer every property that is proven by: + +- source annotations, +- reachable transfer facts, +- memory/path identity, +- relation facts, +- effect summaries, +- interproc summaries, +- declared external contracts. + +The checker must not infer a property from: + +- the type expected by a later failing call, +- most callers preferring a shape, +- `any`, +- absent evidence, +- a compatibility projection, +- a cache hit whose input identity is incomplete. + +This boundary is the core soundness rule: + +```text +Expected type is a constraint to check against evidence. +It is not evidence unless a declared contract explicitly says so. +``` + +Contextual typing is still valid, but it must be represented as evidence with +provenance. For example, a table literal checked against an expected record can +receive contextual field types at the literal boundary. A dynamic payload flowing +through `any` cannot acquire those field types because a callee wanted them. + ## Phase Responsibility Table | Phase | May Create | May Combine | May Widen | May Query | Forbidden | @@ -1605,6 +1947,53 @@ Performance target: - stable interning/hash-consing where already available, - no object pools until ownership is proved and structural wins are exhausted. +### Cache Placement Decision Model + +Use Salsa when: + +- the computation is pure, +- the inputs are immutable identities, +- dependency tracking can precisely invalidate dependent queries, +- the result is reused across functions, modules, or fixpoint iterations, +- recomputation is more expensive than dependency tracking. + +Use a per-function local cache when: + +- the computation is hot inside one solve, +- the cache key is a small local identity, +- the result is invalid after the current function solve, +- Salsa dependency tracking would be more expensive than recomputation. 
+ +Use no cache when: + +- the operation is a cheap domain primitive, +- the input is already interned, +- the allocation is caused by poor ownership rather than repeated work, +- correctness would require observing mutable phase order. + +Do not use a pool until: + +- the allocation site remains hot after domain consolidation, +- ownership of each pooled object is single-phase and obvious, +- tests prove no retained result can observe a reused object, +- profiling shows the pool wins after synchronization and clearing costs. + +The main expected Salsa gains are: + +- graph-indexed parameter-use summaries instead of AST rescans, +- function-result queries keyed by graph and parent scope, +- pure type/operator queries, +- shape classification for large recursive types if profiling proves reuse, +- canonical interproc snapshots as inputs instead of manually invalidated maps. + +The main non-Salsa gains are: + +- domain operations that avoid rebuilding maps for no-op joins, +- path/location interning, +- copy-on-write fact vectors, +- removing compatibility projections from hot publication paths, +- making equality structural instead of repair-driven. + ## Weak Points To Fix In The Design ### 1. Domain Laws Are Not Named @@ -1770,6 +2159,57 @@ Every cache must state: If the cache depends on phase call order, it is not SOTA. +## Failure Taxonomy + +Future regressions should be classified by failed domain responsibility, not by +the helper function that happened to produce the symptom. + +| Symptom | Likely Owner | First Question | +|---|---|---| +| guarded field still nilable at call site | `RelationDomain` or `MemoryDomain` | Did the guard create a path relation for the same location queried by the call? | +| error-return refinement does not affect value slot | `RelationDomain` | Was the tuple relation preserved through return assignment and wrapper forwarding? 
| +| external dynamic value passes concrete parameter | `ValueDomain` or `ParameterEvidenceDomain` | Did `any` get treated as proof instead of dynamic top? | +| unknown disappears from return summary | `ReturnSummaryDomain` | Did join/widen erase unresolved evidence? | +| nil field write behaves like absent field | `MemoryDomain` | Was nil overwrite represented as a value fact instead of structural deletion? | +| closed missing field behaves like open row-tail | `ValueDomain` or `MemoryDomain` | Was openness carried on the record/map component being queried? | +| table insert lost before iteration | `MemoryDomain` and `EffectDomain` | Was mutation replay attached to the canonical child location and operator kind? | +| recursive type keeps growing | owning domain `Widen` | Is growth bounded at the correct SCC/fixpoint boundary? | +| result changes after no semantic input changed | Salsa/cache layer | Is a cache keyed by mutable state or phase order? | +| result does not change after facts changed | Salsa/cache layer | Did the query read the canonical snapshot input that changed? | +| lint clears by accepting too much | `ValueDomain` or assignability boundary | Which negative test proves the new acceptance is sound? | +| repeated performance hot spot after caching | domain/query boundary | Is the computation duplicated because the owner is unclear? | + +Classification rule: + +```text +If a symptom requires reading three unrelated helpers to understand why it +happened, the domain model is still wrong. +``` + +The fix should move the law to the owner, delete the scattered helpers, and add +domain law tests plus one production-shaped replay test. + +## Traceability Matrix + +Every high-value behavior should be traceable from syntax to proof. 
+ +| Behavior | Producer | Canonical Fact | Consumer | Proof | +|---|---|---|---|---| +| truthy field guard | condition transfer | path truthiness relation | call/type query | relation law + guarded-call fixture | +| `test.is_nil(err)` success branch | predicate effect transfer | tuple-slot relation constraint | value-slot query | relation law + error-return fixture | +| body demands parameter field | transfer over field read/use | parameter obligation | interproc fact join | parameter evidence law + SCC fixture | +| call observes argument type | call transfer | call observation | parameter evidence join | authority-order law + negative any fixture | +| table insert mutates array element | effect transfer | container element mutation | iteration query | memory law + dominance fixture | +| nil overwrite | assignment transfer | explicit nil value or deletion effect | field query | nil/absent law + record fixture | +| wrapper forwards returns | return transfer | tuple relation preservation | caller assignment | relation preservation law + wrapper fixture | +| imported dynamic payload | external contract transfer | `any` or `unknown` with provenance | assignability check | value law + negative concrete-param fixture | +| recursive local function | SCC solver | widened param/return evidence | function result query | widen law + convergence fixture | +| module export | publication | immutable interproc delta | dependent Salsa query | snapshot dependency test | + +This matrix is not a test list by itself. It is the audit trail showing that a +behavior has one producer, one canonical representation, one consumer path, and +one proof family. + ## Design Review Decision Tree Every future rule should be classified before code is written. @@ -1954,6 +2394,80 @@ No step should leave: - fallback normalization in equality, - broad `any` acceptance to clear lints. 
+## Flash Cutover Gate + +The flash migration should be reviewed as one semantic cutover, not as a chain +of transitional accommodations. The cutover is ready only when the following +artifacts can be listed before coding starts. + +### Deletion Map + +For each old helper cluster: + +- current file/package, +- semantic law it currently approximates, +- final domain owner, +- final API call site, +- tests that replace helper-specific tests, +- commit in which the helper disappears. + +If a helper cannot be mapped to a domain owner, the design is incomplete. If it +maps to more than one owner, the fact representation is probably mixed and must +be split before implementation. + +### Replacement Map + +For each production call site: + +- current call, +- final call, +- expected semantic output, +- changed cache dependency if any, +- changed diagnostic behavior if any. + +The migration should not introduce "temporary" calls that are expected to be +removed later. A call site either moves to the final API or stays unchanged until +the cutover is ready. + +### Proof Map + +For each domain law: + +- unit law test, +- one positive checker fixture, +- one negative checker fixture when soundness could be weakened, +- one replay/global-harness case if the law came from real code. + +No proof should depend only on external lint going quiet. The suite must show +both the precision gain and the rejection boundary. + +### Performance Map + +For each expensive operation touched: + +- current benchmark/profile location, +- final owner, +- expected cache key or no-cache reason, +- allocation behavior, +- invalidation story. + +Performance work should favor fewer repeated analyses and fewer duplicated data +structures before object pools. Pools are allowed only after ownership is clear +and tests prove no fact lifetime can leak across checks. 
+ +### Cutover Rejection Rules + +Reject the migration if it contains: + +- compatibility authority, +- fallback repair, +- two writers for one fact, +- query-time publication, +- equality-time normalization, +- broad assignability introduced only to clear production code, +- new cache without an immutable input contract, +- new helper whose name describes a case instead of a domain law. + ## Proposed Final Package Map ```text From f9f9500e5b559d076e5f2488603d9a1124c981a0 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:58:08 -0400 Subject: [PATCH 10/71] Document checker Salsa wiring plan --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 173 ++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 494302aa..d55d7371 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -1994,6 +1994,179 @@ The main non-Salsa gains are: - removing compatibility projections from hot publication paths, - making equality structural instead of repair-driven. +### Concrete Salsa Wiring Plan + +The final design should classify every current cache and summary producer before +implementation. The goal is not "put everything in Salsa". The goal is exact +incremental boundaries and no hidden semantic cache. 
+ +| Current Component | Final Role | Cache Kind | Owner | +|---|---|---|---| +| `api.FuncKey{GraphID, ParentHash}` | function analysis identity | Salsa query key | pipeline/analysis engine | +| `FuncResultQ` | analyze one function under one parent scope | Salsa query | analysis engine | +| `snapshotInputs.facts` | canonical interproc fact snapshot | Salsa input | store/facts domain boundary | +| `snapshotInputs.refinements` | function refinement/effect snapshot | Salsa input | effect/refinement boundary | +| `snapshotInputs.constructorFields` | constructor field snapshot | Salsa input | memory/constructor boundary | +| `types/query/core.Engine` | pure type operations | query engine cache | type-query layer | +| `types/flow.ProductDomain` | branch-local narrowing algebra | ephemeral domain state | abstract state / flow domain | +| `paramhints.collectParamUses` | body-demand summary | graph-derived Salsa query | graph summary layer | +| `ProjectHintsToParamUse` | parameter evidence projection | domain operation over cached body summary | parameter domain | +| `PreCache` / `NarrowCache` | repeated expression synthesis inside one solve | per-function local cache | transfer/query phase | +| `FunctionTypeCache` | local function specialization during one solve | per-function local cache unless key is immutable | function analysis | +| `StableFunctionSnapshot` | read canonical function fact snapshot | Salsa query/input read, not ad hoc map | function fact domain | +| flow solution narrow caches | repeated solved-state query | solved-state local cache | query view | +| path suffix/root caches | identity interning | local/global intern cache if immutable | memory/location layer | + +This table is a migration contract. If a component does not appear here or in a +successor table before coding, adding a cache for it should be rejected. + +### Query Dependency Contract + +`FuncResultQ` must read all semantic dependencies through tracked inputs or +tracked pure queries. 
+ +Required reads: + +- graph bundle by `GraphID`, +- parent scope by `ParentHash`, +- canonical interproc facts by `GraphKey`, +- function refinements/effects by owning symbol key, +- constructor fields by owning symbol key, +- manifest/module environment through manifest inputs, +- graph-derived body summaries through graph summary queries, +- pure type operations through the type-query layer. + +Forbidden reads: + +- mutable `InterprocPrev` maps without snapshot input tracking, +- current-iteration `InterprocNext` except through the canonical overlay input + contract, +- ad hoc stable snapshot maps inside synthesis, +- source AST rescans for reusable graph summaries, +- global variables whose mutation does not bump a tracked input. + +When a function reads a fact for a graph or symbol, the query database must know +that dependency. When the fact does not change semantically, the input should not +be rewritten. This gives both correctness and performance: no stale result, no +unnecessary invalidation. + +### Snapshot Update Protocol + +The store should be the only bridge from fixpoint state to Salsa inputs. + +```text +producer emits InterprocDelta + -> FactsDomain.Join/Widen into InterprocNext + -> iteration boundary computes canonical InterprocPrev + -> compare canonical old/new with structural equality + -> set only changed snapshot inputs + -> Salsa revalidates dependent FuncResultQ entries +``` + +Required properties: + +- `setFacts` receives canonical facts only; +- equality is structural and does not normalize; +- empty facts are represented explicitly enough to clear stale inputs; +- per-symbol inputs are used only for facts whose key is truly symbol-local; +- parent-scoped facts use `GraphKey` or `SymbolKey`, not raw `SymbolID`; +- current-iteration overlay is either part of the canonical input contract or is + not visible to `FuncResultQ`. + +This avoids manual cache clearing as a correctness mechanism. 
Clearing may still +exist as a memory-pressure tool, but a correct result must not depend on it. + +### Graph Summary Queries + +Several expensive operations are currently repeated because syntax-derived +summaries are computed by the consumer. These should become graph summary +queries. + +Recommended summaries: + +- parameter-use summary by `GraphID` and function symbol, +- return-site summary by `GraphID`, +- local function/call graph summary by `GraphID`, +- table mutator call summary by `GraphID`, +- key-collector summary by `GraphID`, +- captured variable/path summary by `GraphID`, +- normalized transfer program by `GraphID` plus declared environment identity. + +These queries read immutable graph/source data and produce immutable summaries. +They do not read interproc facts and they do not infer types. The analysis query +then combines those summaries with parent scope and interproc snapshots. + +### Hot Local Cache Contract + +Some caches should remain local because they are only useful during one solve. + +Local cache keys must include: + +- phase (`declared`, `preflow`, `narrow`, or final query), +- expression identity or normalized instruction identity, +- CFG point, +- parent scope identity when the answer can depend on scope, +- solved-state token when the answer depends on flow facts. + +Local caches must not: + +- survive across `FuncResultQ` computations unless the key is fully immutable, +- contain mutable domain state, +- publish facts, +- suppress dependency tracking by reading snapshots behind Salsa's back. + +This keeps hot expression synthesis fast without making it a second semantic +store. + +### Type Query Layer Contract + +The core type query engine is already the right kind of abstraction for +field/index/operator/subtype queries: pure inputs, stable type identities, and +memoized expensive structural work. 
+ +Final rules: + +- checker domains may call pure type queries; +- type queries must not read checker store state; +- type query caches are performance-only; +- type query answers must be invalidated or keyed by all external type-provider + inputs they depend on; +- domain law tests should not depend on query cache hit order. + +This means Salsa does not replace `types/query/core`. Salsa coordinates checker +analysis dependencies. The type query engine owns repeated structural type +operations. + +### Performance Proof Requirements + +A performance correction is accepted only with a before/after profile or +benchmark that names the reduced work. + +Required measurements for the flash migration: + +- large-function checker benchmark, +- representative interproc convergence fixture, +- production replay wall time, +- allocation profile for hot joins and expression synthesis, +- cache hit/miss counters for `FuncResultQ` and graph summary queries, +- number of snapshot inputs rewritten per fixpoint iteration. + +Expected improvements: + +- fewer `collectParamUses` rescans, +- fewer repeated local function snapshot syntheses, +- fewer map allocations in no-op fact joins, +- fewer invalidated function queries after no-op fact updates, +- fewer expression synthesis calls during narrow/final query phases. + +Regression rule: + +```text +If a performance win comes from accepting less precise facts, it is invalid. +If a precision win causes repeated semantic recomputation, the cache boundary is +wrong and must be fixed before the flash migration lands. +``` + ## Weak Points To Fix In The Design ### 1. 
Domain Laws Are Not Named From 88a4954dd648974f514d4d53d1b762a19c37b564 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 00:59:54 -0400 Subject: [PATCH 11/71] Define checker abstract machine model --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 302 ++++++++++++++++++++++++++++++ 1 file changed, 302 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index d55d7371..1bac0322 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -176,6 +176,178 @@ The doctrine gives a direct review test: No rule should need to be implemented twice under different helper names. +## Abstract Machine Specification + +The final checker should be specified as a small abstract machine. This gives the +code a single target shape and gives reviews a way to reject scattered helper +logic. + +```text +Machine = + Inputs + + Program + + State + + Domains + + Worklist + + QueryView + + Publisher +``` + +### Inputs + +Inputs are immutable during one function analysis query: + +- graph identity, +- parent scope identity, +- manifest/module environment, +- declared type environment, +- canonical interproc snapshot, +- constructor snapshot, +- effect/refinement snapshot, +- graph summaries, +- pure type-query engine. + +Inputs are the only values allowed to affect the answer besides the transfer +program. If an answer depends on something not listed here, the dependency model +is incomplete. + +### Program + +The program is normalized checker IR: + +- no source AST policy decisions, +- no hidden synthesis callbacks, +- no direct store mutation, +- no cache-dependent control flow. + +Each instruction has one meaning as a transfer over `AbstractState`. + +### State + +The state is the product: + +```text +State = + Memory + x Values + x Shapes + x NumericFacts + x Relations + x Effects + x Termination + x DiagnosticsEvidence +``` + +`DiagnosticsEvidence` is not user diagnostics. 
It is proof metadata such as +"this constraint failed here" or "this widening lost precision here". User +diagnostics are emitted after solving by querying this evidence. This keeps +diagnostic formatting out of domain semantics. + +### Domains + +Domains define the algebra: + +```text +Normalize, Leq, Join, Meet, Refine, Widen, Equal +``` + +Every operation must be local to its owned component or explicitly part of a +product operation. For example, relation transfer can ask value and memory +domains to interpret a path predicate, but it cannot create a private value +merge law. + +### Worklist + +The worklist owns traversal, not meaning. + +Allowed: + +- schedule CFG points, +- schedule SCC members, +- detect local stabilization, +- invoke loop/SCC widening at declared boundaries. + +Forbidden: + +- prefer one fact over another, +- normalize facts, +- publish interproc state, +- recover precision after widening. + +If the worklist needs semantic information to decide convergence, that +information must be exposed through `Leq` or `Equal` on the relevant domain. + +### QueryView + +The query view is a read-only projection over solved state. + +It answers: + +- type at location/point, +- relation at location/point, +- effect summary at call/function boundary, +- return tuple summary, +- parameter obligation summary, +- diagnostic projection. + +It must not write facts, widen, repair state, or backfill caches that later act +as analysis state. + +### Publisher + +The publisher converts solved state into immutable deltas: + +```text +State -> FunctionResult -> InterprocDelta +``` + +The publisher does not merge with previous results. It does not reconstruct +legacy channels. It emits the final product-domain representation expected by +`FactsDomain`. 
+ +### Machine Transition Rules + +The core machine transitions are: + +```text +step(instruction, state) = Transfer.Apply(instruction, state, domains) +join(predStates) = AbstractState.Join(predStates, domains) +widen(prev, next) = AbstractState.Widen(prev, next, domains) +query(state, question) = QueryView.Answer(state, question) +publish(state) = InterprocDelta +``` + +Every specialized feature should reduce to these transitions: + +- branch narrowing is transfer plus join; +- field writes are memory transfer; +- table mutators are effect transfer plus memory transfer; +- assertions are effect transfer plus relation/value refinement; +- error-return behavior is relation transfer over tuple slots; +- callback behavior is higher-order effect transfer; +- local function inference is an SCC over function-state summaries; +- interproc inference is a fixpoint over `InterprocDelta` values. + +If a feature cannot be expressed this way, either the machine is missing a +domain or the feature is implemented at the wrong layer. + +### Machine Laws + +The implementation should preserve these laws: + +- Transfer is monotone with respect to domain `Leq`. +- Join is least-upper-bound or a documented approximation. +- Meet/refine never invents evidence without provenance. +- Widen is only applied at explicit recursive boundaries. +- Normalize is idempotent and is not hidden in equality. +- Query is pure over solved state. +- Publication is deterministic. +- Cache hits do not change semantics. +- Diagnostics are projections of evidence, not sources of evidence. + +These laws should become test names. A regression that violates one of them is a +design regression, not a local bug. + ## Semantic Atoms The final design should use a small shared vocabulary. These words should have @@ -1878,6 +2050,136 @@ Output: No semantic repair is allowed here. Snapshotting is cache wiring only. 
+## Nested Fixed-Point Model + +The final checker has several fixed points, but they should all use the same +domain vocabulary. The existence of multiple schedules does not justify +multiple semantic models. + +### Level 0: Pure Graph Summaries + +Graph summaries are not fixpoints over types. They are immutable facts about +syntax and binding: + +- parameter uses, +- return sites, +- local function edges, +- call sites, +- mutator sites, +- captured path mentions, +- normalized transfer instructions. + +They can be cached by graph identity because they do not read interproc facts or +solved flow state. + +### Level 1: Intraprocedural CFG Fixpoint + +The local solver computes: + +```text +CFG x TransferProgram x InitialState -> SolvedState +``` + +Convergence boundary: + +- CFG joins use `AbstractState.Join`; +- loops use the relevant domain widen only when a loop-carried component grows + past the domain's finite-height fragment; +- dead/unreachable paths update termination/reachability before value queries. + +Forbidden: + +- AST rescans during solve, +- producer-specific joins, +- query-time narrowing that writes state, +- loop-specific precision hacks outside domain widening. + +### Level 2: Local Function SCC Fixpoint + +Local functions inside a graph can be mutually recursive. The final model should +treat their summaries as another domain product: + +```text +FunctionSummary = + Parameters + x Returns + x Relations + x Effects + x Captures +``` + +Convergence boundary: + +- recursive calls read the current SCC summary through the function fact domain; +- each function body emits a new summary delta; +- SCC join/widen uses the same parameter, return, relation, effect, capture, + and memory domains used elsewhere; +- when the SCC stabilizes, the solved summaries become ordinary evidence for + the enclosing function analysis. + +This replaces "return overlay", "preflow synthesis", and "local function +snapshot repair" as separate semantic concepts. 
Those may remain as scheduling +or performance techniques, but not as separate laws. + +### Level 3: Interprocedural Fixpoint + +The outer fixpoint computes canonical facts across function/module boundaries: + +```text +InterprocPrev + all FunctionResult deltas -> InterprocNext +InterprocPrev' = FactsDomain.Widen(InterprocPrev, InterprocNext) +``` + +Convergence boundary: + +- producers emit immutable deltas; +- `FactsDomain` is the only merge/widen authority; +- no producer reads its own just-emitted delta except through the declared + current-iteration overlay contract; +- equality checks canonical state only; +- snapshot inputs are updated only after the canonical product changes. + +Iteration caps are diagnostics, not semantics. If convergence requires raising a +cap for normal programs, the relevant `Widen` is missing or too precise. + +### Level 4: Incremental Revalidation + +Salsa does not define type semantics. It revalidates query results after inputs +change. + +Required dependency shape: + +```text +FuncResultQ + reads GraphSummaryQ + reads Manifest/Input queries + reads SnapshotInputs + reads TypeQuery caches + computes local fixed points + publishes deltas +``` + +When a snapshot input is unchanged, dependent results should revalidate without +re-solving. When a graph summary is unchanged, function queries should not rescan +the AST to rediscover it. When a type-query cache hits, it should only avoid +structural recomputation; it must not mask missing checker dependencies. + +### Fixed-Point Proof Obligations + +Each level needs a proof surface: + +- finite input identity, +- monotone transfer or documented approximation, +- explicit join/widen boundary, +- stable equality without repair, +- deterministic publication, +- cache invalidation by immutable dependency. + +Performance and soundness meet at these obligations. A cache that is missing a +dependency is unsound. A widen that erases too much precision causes false +positives. 
A join that keeps rebuilding equivalent maps causes unnecessary +invalidations. + ## Salsa And Cache Model Current good shape: From 3b4f8d8c5e64ad82c74161116d78e55e11653b07 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:01:24 -0400 Subject: [PATCH 12/71] Add checker dataflow ownership ledger --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 217 ++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 1bac0322..20c96e49 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -348,6 +348,223 @@ The implementation should preserve these laws: These laws should become test names. A regression that violates one of them is a design regression, not a local bug. +## Ownership Ledger + +Every semantic object should have one home. This table is the fastest review +tool for the future flash migration. + +| Object | Born In | Canonical State | Transformed By | Queried By | Published As | Cache Boundary | +|---|---|---|---|---|---|---| +| symbol identity | graph build | graph bundle | never semantically transformed | location resolver | graph key/symbol key | graph input | +| parent scope | scope build | immutable scope state | never semantically transformed | analysis key lookup | parent hash | `FuncKey`/`GraphKey` | +| field/index path | IR/path lowering | `Location` / `MemoryState` | memory transfer | query view | captured path/mutation delta | location interning | +| local value fact | transfer | `AbstractState.Values` | value domain | type-at query | return/param/capture delta when exported | per-function state | +| table shape fact | literal/assignment transfer | value + memory domains | value/memory domains | field/index query | function/captured/container delta | type query + local state | +| branch truthiness | condition transfer | relation/value constraints | relation/value domains | query view | relation summary if it crosses 
boundary | per-function state | +| nil/absent evidence | assignment/field transfer | memory + value domains | memory/value domains | field query | return/param/capture delta | per-function state | +| parameter observation | call transfer | parameter evidence domain | parameter domain | function summary query | function fact delta | interproc facts input | +| body obligation | body transfer | parameter evidence domain | parameter domain | function summary query | function fact delta | graph summary + state | +| return tuple | return transfer | return summary domain | return domain | return query | function fact delta | interproc facts input | +| tuple/path relation | predicate/effect/return transfer | relation domain | relation domain | relation query | relation summary delta | local state / interproc facts | +| table mutation | assignment/effect transfer | memory domain | memory domain | iteration/field query | captured container delta | local state / interproc facts | +| call effect | effect resolution | effect domain | effect domain | transfer/query view | refinement/effect delta | effect snapshot input | +| termination fact | transfer/effect transfer | termination domain | termination domain | reachability query | function effect delta | per-function state | +| diagnostic evidence | failed constraint transfer/query | diagnostics evidence state | diagnostic projection only | diagnostics pass | no semantic delta | result only | +| constructor field | constructor transfer/publication | constructor field domain | memory/value domains | constructor query | constructor snapshot | constructor input | +| external dynamic value | manifest/effect transfer | value evidence with provenance | value/domain checks | assignability query | only if exported with provenance | manifest/type input | + +Design rule: + +```text +If a row needs two canonical states, the model is split incorrectly. +If a row has no cache boundary, the implementation will invent one locally. 
+If a row has two publishers, legacy mirror channels are coming back. +``` + +## Dataflow Moral Rules + +The checker should be easy to explain because the direction of information never +reverses. + +### Syntax To Evidence + +Syntax can create observations. It cannot create authority by itself. + +Examples: + +- a table literal observes fields; +- a call observes arguments; +- a guard observes a branch condition; +- a return observes tuple slots. + +These observations become evidence only through transfer and domain +qualification. + +### Evidence To Fact + +Evidence becomes a fact when the owning domain accepts it into state. + +Examples: + +- a field observation becomes a memory fact at a canonical location; +- a truthy guard becomes a relation/value constraint; +- a body use becomes a parameter obligation; +- a call argument becomes a parameter observation. + +No producer decides global precedence. The evidence order belongs to the domain. + +### Fact To Answer + +Answers are read-only projections. + +Examples: + +- "what is the type here?", +- "does this path exclude nil?", +- "what does this function return?", +- "does this call terminate?", +- "which diagnostic should be emitted?". + +An answer cannot become a fact unless a later transfer explicitly observes it +and routes it through the owning domain. This prevents query-time analysis. + +### Fact To Delta + +Only solved facts that cross a function or module boundary become deltas. + +Examples: + +- local temporary narrowing does not publish; +- body obligation publishes as parameter evidence; +- return tuple publishes as return summary and relation summary; +- captured mutation publishes as memory/effect summary; +- external contract application does not rewrite the contract. + +The publisher emits a delta; `FactsDomain` combines it. + +### Delta To Snapshot + +Snapshots are cache inputs, not semantic repair points. 
+ +Examples: + +- changed canonical facts update snapshot inputs; +- unchanged canonical facts do not invalidate queries; +- empty canonical facts clear stale inputs; +- compatibility projections are not written. + +This keeps incremental revalidation honest: Salsa tracks dependencies, domains +track meaning. + +## Boundary Invariants + +Every boundary in the dataflow should have a small invariant that can be tested +or reviewed directly. + +### Graph Boundary + +Invariant: + +```text +Graph identity changes only when syntax/binding identity changes. +``` + +This boundary may cache syntax summaries. It may not depend on interproc facts, +solved flow state, or expected call types. + +### IR Boundary + +Invariant: + +```text +Checker IR contains operations, not answers. +``` + +The IR may say "apply this call effect" or "assign this value to this +location". It may not pre-decide the result type of an operation whose answer +depends on flow/interproc state. + +### Transfer Boundary + +Invariant: + +```text +Transfer is the only state-writing semantics inside a function. +``` + +All writes to memory, value, relation, effect, and termination state must be +visible as transfer operations. A helper that writes state outside transfer is a +hidden interpreter. + +### Join Boundary + +Invariant: + +```text +Branch merge uses domain Join and nothing else. +``` + +A branch-specific merge helper is allowed only if it is the domain's exported +join/meet/refine operation. If it knows about AST shape, it is in the wrong +layer. + +### Widen Boundary + +Invariant: + +```text +Widen happens only at named recursive boundaries. +``` + +Loop widening, local function SCC widening, and interproc widening may have +different schedules, but they must call the same domain-level widening laws for +the same fact family. + +### Query Boundary + +Invariant: + +```text +Query answers cannot become stored evidence. +``` + +Query caches are permitted only for answers. 
They must not publish facts or +change future convergence. + +### Publication Boundary + +Invariant: + +```text +Publication emits immutable deltas and never merges them. +``` + +The same solved state must always produce the same delta. If publication reads +previous facts to decide how to shape the delta, it is doing merge work in the +wrong layer. + +### Snapshot Boundary + +Invariant: + +```text +Snapshot updates are semantic no-ops except for dependency invalidation. +``` + +Setting a snapshot input can make queries rerun. It cannot normalize, widen, +infer, or delete evidence except by reflecting the already-canonical facts. + +### Diagnostic Boundary + +Invariant: + +```text +Diagnostics observe proof failure; they do not define type behavior. +``` + +A diagnostic pass may ask why a check failed. It may not make the check pass or +fail by changing evidence. + ## Semantic Atoms The final design should use a small shared vocabulary. These words should have From c81e8961e3618d655a798f08c5b0cebb12637ba6 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:03:05 -0400 Subject: [PATCH 13/71] Clarify checker evidence authority model --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 251 ++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 20c96e49..1da96589 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -565,6 +565,257 @@ Diagnostics observe proof failure; they do not define type behavior. A diagnostic pass may ask why a check failed. It may not make the check pass or fail by changing evidence. +## Evidence Authority Model + +The checker should be precise because it carries proof, not because it guesses. +Authority is therefore part of evidence. It is not a global total order; it is a +domain-specific partial order over a specific question. 
+ +Canonical evidence shape: + +```text +Evidence = + Location + + Value/Predicate/Effect + + Provenance + + Authority + + Scope + + Phase + + SourceSpan +``` + +`SourceSpan` may be absent for synthetic or imported evidence, but provenance +must not be absent. + +### Authority Classes + +The final design should name these authority classes explicitly. + +| Authority | Meaning | Can Prove Concrete Contract? | Can Be Weakened By Join? | Can Publish? | +|---|---|---|---|---| +| explicit contract | user/API annotation or manifest contract | yes | only through declared variance/summary abstraction | yes | +| hard runtime proof | guard, assertion, dominance-proven assignment | yes | yes at control-flow join | if it crosses boundary | +| relation proof | fact derived from tuple/path relation | yes for related locations | yes when relation path is lost | if relation crosses boundary | +| effect proof | applied call/effect summary | yes if effect declares it | yes at join/widen | yes as effect/summary | +| body obligation | function body requires a shape | yes for parameter contract inference | yes at recursive widen | yes | +| call observation | caller passed a shape | no by itself | yes | yes as weak evidence | +| contextual literal evidence | expected type applied at literal boundary | yes for that literal | yes | yes if literal escapes | +| soft annotation | low-authority annotation hint | no without compatible proof | yes | only as soft evidence | +| unresolved observation | `unknown` | no | yes but not erased silently | yes as unknown | +| dynamic top | `any` | no without explicit cast/contract | yes as dynamic top | yes as any | + +This table prevents the common mistake of treating all useful evidence as the +same. A call observation is useful for inference, but it is not proof that the +callee accepts that shape. An explicit `any` is useful information, but it is +not proof of a concrete field. 
+ +### Conflict Resolution + +Conflicts should be resolved by the owning domain, not by producer preference. + +| Conflict | Owner | Correct Resolution | +|---|---|---| +| hard proof vs soft annotation | evidence/value domain | hard proof wins for the proven path | +| explicit `any` vs expected concrete param | assignability/value domain | reject unless cast/contract proves concrete | +| unknown return vs concrete return | return domain | preserve unresolved behavior unless domain law proves refinement | +| call observation vs body obligation | parameter domain | body obligation is stronger contract evidence | +| parent table shape vs child-path write | memory domain | child-path fact wins for that path | +| closed missing field vs open row tail | value/memory domain | closed absence and open unknown tail stay distinct | +| relation proof vs unrelated assignment | relation/memory domain | relation survives only if location identity is preserved | +| widening precision loss vs later query | owning domain | query observes widened state; no post-widen repair | + +Conflict policy must be testable as a domain law. If the test has to construct a +whole checker to decide the conflict, the domain boundary is still too implicit. + +### Proof-Carrying Facts + +Every persistent fact should be explainable as: + +```text +fact = domain.accept(observation, provenance, authority, location) +``` + +Queries should be able to answer both: + +- the abstract answer, such as "this value is string"; +- the proof route, such as "truthy guard on this location removed nil". + +The proof route does not need to be exposed in normal diagnostics, but it must +exist in the design. Without it, the checker cannot distinguish real precision +from accidental broadening. + +### Precision And Soundness Contract + +Precision can increase only by proof. 
+ +Allowed precision gains: + +- guard removes nil/false from the exact guarded location; +- assertion effect narrows the declared target relation; +- body obligation records a parameter shape the body actually reads; +- table literal contextual typing applies at the literal boundary; +- relation summary narrows linked tuple slots after a predicate. + +Forbidden precision gains: + +- callee expected type rewrites caller evidence; +- repeated callers vote a parameter into a concrete contract; +- `any` becomes a concrete record because a later field is used; +- closed missing field becomes open unknown tail to avoid an error; +- cached answer is reused after an untracked dependency changed. + +Precision can decrease only at named abstraction boundaries: + +- branch join, +- loop widening, +- local function SCC widening, +- interproc widening, +- published summary abstraction. + +Precision must not decrease at: + +- equality, +- snapshot update, +- diagnostics, +- compatibility projection, +- query cache lookup. + +This is the soundness/performance contract. Faster analysis is valid only if it +computes the same evidence or a documented domain approximation at a named +boundary. + +### Absence Of Evidence + +Absence is not a proof. + +Rules: + +- no field evidence does not mean field is nil; +- no relation evidence does not mean slots are independent: the relation may + have been dropped by a bug; +- no return evidence does not mean zero returns unless arity is known; +- no effect evidence does not mean pure call unless the effect row is closed; +- no param evidence does not mean `any`; it means unresolved until declared or + inferred evidence exists. + +This is where many false positives and false negatives start. The final domains +should model absence explicitly instead of using nil maps as semantic answers. + +## Dataflow Proof Traces + +Every important inference should have a trace format. This is not a logging +requirement for the first implementation.
It is the mental model for proving the +checker did the right thing. + +Trace skeleton: + +```text +Observation + -> Location + -> Evidence + -> Domain acceptance + -> State fact + -> Join/Widen if any + -> Query answer + -> Publication if any +``` + +### Guarded Field Trace + +```text +Observation: if options.model then +Location: Location(options).field("model") +Evidence: truthy predicate, hard runtime proof +Domain: RelationDomain + ValueDomain +State: path excludes nil/false on true branch +Query: provider.open argument reads non-nil field type +Publish: none unless the relation escapes through a summary +``` + +Wrong trace: + +```text +provider.open expects string -> options.model becomes string +``` + +The wrong trace reverses dataflow. + +### Error Return Trace + +```text +Observation: local value, err = f() +Location: return tuple slots assigned to local locations +Evidence: f publishes tuple relation +Domain: RelationDomain accepts slot correlation +State: err nil branch relates value slot to success case +Query: value.field sees success-side value evidence +Publish: wrapper republishes tuple relation only if slot identity is preserved +``` + +Wrong trace: + +```text +function has two returns -> assume value/error convention +``` + +The wrong trace invents relation evidence from arity. + +### Dynamic Payload Trace + +```text +Observation: payload = json.decode(raw) +Location: payload +Evidence: imported dynamic value +Domain: ValueDomain records any/unknown with provenance +State: payload.name remains dynamic/unresolved +Query: needs_string(payload.name) requires proof +Publish: dynamic evidence only if exported +``` + +Wrong trace: + +```text +needs_string expects string -> payload.name becomes string +``` + +The wrong trace treats expected type as evidence. 
+ +### Captured Mutation Trace + +```text +Observation: nested function inserts into state.items +Location: canonical location for state.items +Evidence: mutation effect with captured provenance +Domain: EffectDomain applies MemoryDomain mutation +State: array element fact at state.items +Query: ipairs reads element fact if dominance/escape permits it +Publish: captured container mutation delta if it crosses function boundary +``` + +Wrong trace: + +```text +captured mutation replay builds a new parent table shape +``` + +The wrong trace loses operator kind and child-path authority. + +### Trace Review Rule + +For any new inference, a reviewer should be able to ask: + +- What was observed? +- What is the canonical location? +- What authority does the evidence have? +- Which domain accepted it? +- Where can it lose precision? +- Which query read it? +- Does it publish, and if so as which delta? +- Which cache boundary owns reuse? + +If the answer starts with "this helper checks whether...", the design likely +needs another domain operation instead of another helper. + ## Semantic Atoms The final design should use a small shared vocabulary. These words should have From 4f0587c0123b55c9e26d1639a738081b726c8246 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:04:51 -0400 Subject: [PATCH 14/71] Define checker location memory calculus --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 284 ++++++++++++++++++++++++++++++ 1 file changed, 284 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 1da96589..f025f126 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -847,6 +847,290 @@ Examples: Locations are canonical before transfer. AST paths and SSA paths cannot both be authoritative. +## Location And Memory Calculus + +The final checker needs one answer to the question: + +```text +Are these two pieces of evidence about the same runtime place? 
+``` + +If that answer is local to each helper, precision will stay fragile. Guarded +fields, captured mutations, alias replay, tuple relations, and table-key +refinements all depend on the same location calculus. + +### Location Shape + +A location should be a canonical structured value, not a string path and not an +AST node. + +```text +Location = + Root + + Version + + PathSegments + + ScopeIdentity + + ProvenanceClass +``` + +Roots: + +- local symbol root, +- parameter root, +- receiver `self` root, +- upvalue/captured root, +- return tuple root, +- temporary tuple result root, +- module export root, +- constructor instance root, +- external/imported value root. + +Segments: + +- named field, +- literal index, +- dynamic index with key evidence, +- array element, +- map value, +- tuple slot, +- metatable/member access when modeled, +- synthetic effect target. + +`Version` belongs to the root or to a versioned location identity. It should not +be smuggled into a string suffix. `ScopeIdentity` is required for parent-scoped +facts so two equal-looking symbols in different parent scopes do not collide. + +### Canonicalization Laws + +Location canonicalization should obey these laws: + +- resolving the same symbol/path at the same CFG point returns the same + canonical location; +- resolving different lexical symbols never collides, even when names match; +- aliases are explicit equivalence/forwarding facts, not path rewrites; +- field and index segments are interned/normalized before storage; +- dynamic index evidence is preserved and not collapsed to `string` unless a + proof refines it; +- tuple slots remain tuple slots until assignment or forwarding gives them a + concrete destination; +- captured locations retain lexical owner identity; +- module/export locations retain module identity; +- open row-tail access and closed missing-field access produce different + locations/evidence. + +These laws should be tested without a whole checker. 
A location unit test should +be able to prove whether two references alias, differ, or are unknown. + +### Memory State Shape + +Memory state should be the product of several maps with one owner: + +```text +MemoryState = + ValueAt(Location) + + PresenceAt(Location) + + Children(Location) + + AliasFacts + + MutationLog + + DominanceFacts + + EscapeFacts +``` + +`ValueAt` says what value evidence is known at a location. +`PresenceAt` distinguishes present, absent, nil value, unknown presence, and +open row-tail unknown. +`Children` records known child facts without forcing a parent table rewrite. +`AliasFacts` records location identity relations and their dominance. +`MutationLog` records effectful writes with operator kind. +`DominanceFacts` tells whether a write/guard reaches a query point. +`EscapeFacts` tells whether a local fact can publish across a boundary. + +None of these should be represented by "map missing means nil". Absence of a map +entry means no stored fact for that component. + +### Read Law + +A memory read answers by ordered evidence, not by helper preference. + +Read order for a path should be: + +1. exact dominated location fact; +2. exact relation-refined fact for the same location; +3. exact child-path mutation fact; +4. alias-forwarded fact whose alias is valid at the query point; +5. declared/constructed parent shape projected through the path; +6. open row-tail evidence; +7. unresolved evidence. + +Forbidden read behavior: + +- expected callee type becomes read evidence; +- parent table shape overwrites explicit child mutation; +- closed missing field becomes open row-tail unknown; +- dynamic index write broadens every named field without proof; +- stale query cache answers for a different location version. + +This read law is where many current helper clusters should collapse. + +### Write And Mutation Law + +A write is not just "join this type into a table". 
+ +Write shape: + +```text +Write = + Target Location + + OperatorKind + + ValueEvidence + + Dominance + + Provenance +``` + +Operator kinds: + +- assignment, +- field write, +- nil overwrite, +- deletion/absence write if Lua semantics or API effect establishes deletion, +- dynamic index write, +- array element insert, +- map value update, +- container send/receive, +- captured mutation replay. + +The operator kind is semantic. `table.insert(x, v)`, `x[k] = v`, and +`x.field = v` may all affect a table, but they do not have the same path law. +Captured replay must preserve the original operator kind. + +### Alias And Dominance Law + +Alias facts are valid only over a control-flow region. + +Rules: + +- alias created by assignment is valid until reassignment or invalidating + mutation; +- field alias preserves the exact field path it came from; +- dynamic index alias preserves key evidence; +- branch-local alias facts do not leak unless dominance proves they reach the + query point; +- loop-carried aliases widen at the loop boundary; +- captured aliases include lexical owner and escape information. + +Relation facts must reference canonical locations, not syntactic expressions. +If assignment preserves location identity, relations can transfer. If it copies +only a value and loses tuple/path identity, relation facts must not silently +survive. + +### Tuple Slot Law + +Tuple slots are locations, not just positions in a slice. + +Rules: + +- return arity is part of tuple identity; +- nil padding is explicit; +- wrapper forwarding preserves tuple-slot relation only when forwarding is + identity-preserving; +- assignment from tuple slot to local location records a relation edge from slot + to local; +- swapped or reordered returns update relation mapping explicitly; +- vararg expansion has its own location/evidence policy and cannot be treated + as fixed tuple identity without proof. + +This prevents the `(value, err)` convention from becoming an arity heuristic. 
+ +### Presence Law + +Presence is separate from value type. + +States: + +- present with value evidence, +- present with nil value, +- absent from closed structure, +- optional in declared structure, +- unknown via open row tail, +- unknown via dynamic table top. + +Important distinctions: + +- `field = nil` is not automatically the same as absent unless the domain rule + for that context says so; +- optional declared field is not the same as proven absence; +- open record tail gives unknown evidence, not nil evidence; +- map value may be nil even when key presence is unknown; +- table top preserves that a value is table-like without proving named fields. + +Presence should be tested as its own domain law. It is too important to hide in +record subtyping or field lookup helpers. + +### Publication Law + +Only memory facts that escape the local function become interproc deltas. + +Publishable memory evidence: + +- captured variable type, +- captured field assignment, +- captured container mutation, +- constructor field, +- return value/tuple slot, +- parameter obligation/effect, +- module export field. + +Non-publishable memory evidence: + +- branch-local narrowing, +- local alias that does not escape, +- temporary tuple slot after assignment unless relation summary requires it, +- diagnostic-only failure evidence, +- query cache answer. + +Publication should project from memory state. It should not reconstruct memory +facts by rescanning AST or replaying helper-specific summaries. + +### Performance Consequences + +The location calculus is also a performance boundary. + +Expected wins: + +- interned locations make map keys cheap and stable; +- path parsing disappears from hot query paths; +- child-path facts avoid rebuilding whole parent tables; +- alias and dominance checks become graph-indexed facts; +- relation queries compare location IDs instead of syntactic paths; +- captured mutation replay reuses the same mutation operator. 
+ +Rejected performance shapes: + +- stringifying paths to compare them in hot loops; +- reparsing path suffixes during every narrowed query; +- rebuilding parent records for each child write; +- using object pools before ownership of locations and memory facts is proven; +- caching read answers without a solved-state/location-version key. + +### Location Law Tests + +The flash migration should add focused tests for: + +- same expression at same point resolves to same location; +- same name in different scopes resolves to different locations; +- alias validity ends at reassignment; +- branch-local alias does not leak; +- dynamic index write does not overwrite unrelated named field; +- child field write outranks parent shape at that child; +- closed missing field differs from open row-tail field; +- tuple relation survives identity forwarding; +- tuple relation dies on reorder unless remapped; +- captured mutation preserves operator kind and target location; +- nil value and absence remain distinguishable. + +These tests are foundational. If they pass, many higher-level inference tests +become much simpler because they no longer need to encode location policy. + ### Evidence `Evidence` is a value plus provenance and authority. From ed897bb2e477d9c630f19e95064d09ea56a4532a Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:06:28 -0400 Subject: [PATCH 15/71] Define checker relation effect calculus --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 328 ++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index f025f126..727ecdb1 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -1200,6 +1200,334 @@ Examples: Effects are applied by transfer. They do not rewrite types directly. +## Relation And Effect Calculus + +Relations and effects are the bridge between local flow precision and +interprocedural power. 
They must be first-class domain facts, not names of known +functions. + +Core rule: + +```text +Relations describe conditional truth between locations. +Effects describe state transitions caused by execution. +``` + +An assertion, predicate, table mutator, callback, error-return convention, and +terminating function all fit this rule. + +### Relation Shape + +A relation should be represented as a structured fact: + +```text +Relation = + RelationID + + Participants + + Arms + + Directionality + + Validity + + Provenance +``` + +Participants are canonical locations: + +- tuple slots, +- locals, +- fields, +- indexes, +- receiver/self, +- callback arguments, +- captured paths. + +Arms describe conditional cases: + +- success/failure branch, +- true/false predicate branch, +- nil/non-nil branch, +- type-test branch, +- discriminant branch, +- custom effect branch. + +Directionality matters. Some relations are bidirectional; many are not. For +example, `err == nil` may imply success-side value evidence, but using a value +does not necessarily prove `err == nil` unless the relation declares that +reverse implication. + +Validity records when the relation is safe to apply: + +- CFG region, +- dominance/post-dominance requirement, +- location identity requirement, +- alias validity, +- tuple-slot identity, +- function summary boundary, +- effect precondition. + +### Relation Operations + +The relation domain should own these operations: + +```text +Attach(relation, state) +Assume(location predicate, state) +Remap(relation, location mapping) +Project(location, state) +Join(a, b) +Widen(prev, next) +Publish(relation, boundary) +``` + +`Attach` stores a relation after validating participants. +`Assume` applies a branch predicate and derives consequences. +`Remap` preserves a relation through assignment, wrapper forwarding, or tuple +reordering only when identity mapping is explicit. +`Project` answers what a relation proves about a queried location. 
+`Join` keeps only facts valid on all incoming paths or marks path-conditional +arms explicitly. +`Widen` bounds recursive relation growth. +`Publish` emits only relations that remain meaningful across the boundary. + +Forbidden relation operations: + +- infer relation from return arity alone; +- preserve relation after assignment without location mapping; +- treat a predicate function name as proof outside effect transfer; +- erase relation provenance during join; +- encode relation as a special `typ.Type`. + +### Tuple Relation Law + +The `(value, err)` convention is one tuple relation instance: + +```text +SuccessArm: err is nil -> value is success value +FailureArm: err is non-nil -> value is nil/unknown failure value +``` + +It is not: + +- any two-return function, +- any call followed by `test.is_nil`, +- a special return-summary vector, +- a call-checking hack. + +Custom error records, boolean-success APIs, result objects, and status-code +APIs should be expressible by defining different relation arms over locations. + +### Predicate Relation Law + +Predicate/assertion functions apply relations through effects. + +Examples: + +- `is_nil(x)` proves nil/non-nil branches for `x`; +- `is_string(x)` proves string/non-string branches for `x`; +- `assert_type(x, "string")` refines `x` or terminates; +- `has_field(x, "name")` proves presence for `x.name`; +- custom manifest predicate proves declared relation arms. + +The function name is only a lookup key for an effect summary. The effect summary +is the semantic object. 
+ +Wrong shape: + +```text +if call name == "test.is_nil" then patch value type +``` + +Correct shape: + +```text +call -> effect summary -> relation transfer -> query +``` + +### Effect Shape + +An effect summary should be a structured transition: + +```text +Effect = + EffectID + + Preconditions + + MemoryEffects + + RelationEffects + + ValueEffects + + TerminationEffect + + CallbackEffects + + PublicationPolicy + + Provenance +``` + +Preconditions decide when the effect is valid. +Memory effects mutate locations through `MemoryDomain`. +Relation effects attach or assume relations through `RelationDomain`. +Value effects refine or produce value evidence through `ValueDomain`. +Termination effects update reachability through `TerminationDomain`. +Callback effects describe higher-order execution. +Publication policy decides whether the summary can cross a function/module +boundary. + +### Effect Application Law + +Applying an effect is transfer: + +```text +Call instruction + -> resolve callee/effect summary + -> instantiate summary with actual argument/receiver/return locations + -> validate preconditions + -> apply memory effects + -> apply relation effects + -> apply value effects + -> apply termination effects + -> schedule callback effects if invoked +``` + +Every sub-step calls the owning domain. The effect domain coordinates; it does +not own memory, value, relation, or termination laws. + +### Callback Effect Law + +Callbacks are effectful calls whose callee is a parameter or field. + +Rules: + +- callback invocation has its own call site and locations; +- callback argument evidence flows as call observations; +- callback return/effect evidence flows back only through declared callback + summary; +- captured caller memory can be mutated only through explicit captured location + effects; +- unknown callback effects are not pure unless the effect row is closed. + +This prevents higher-order code from becoming a blind spot or a source of +unsound broadening. 
+ +### Termination Law + +Termination is an effect, not a diagnostic side channel. + +Examples: + +- `error()` terminates the current path; +- assertion failure terminates one branch; +- `return` terminates the current function path; +- infinite loop may terminate analysis reachability differently from runtime + non-return depending on proof. + +Reachability must update before value queries observe post-call state. Otherwise +the checker can report false positives from impossible paths or accept values +from dead branches. + +### Open And Closed Effect Rows + +Effects need the same open/closed discipline as structural types. + +Closed effect row: + +```text +This call has exactly these modeled effects. +``` + +Open effect row: + +```text +This call has at least these effects; unknown effects may remain. +``` + +Rules: + +- no effect summary does not mean pure call; +- closed pure summary can prove no mutation/termination/refinement; +- open summary cannot prove absence of unknown mutation; +- unknown external call must not refine values without a declared effect; +- manifest effects are typed inputs, not hardcoded behavior. + +### Relation/Effect Join And Widen + +Join: + +- keeps relations/effects valid on all joined paths; +- preserves path-conditional arms when the domain represents them explicitly; +- drops or weakens facts whose participant locations are no longer identical; +- never converts absence of relation into proof of independence. + +Widen: + +- bounds recursive relation chains; +- bounds callback/effect expansion; +- bounds recursive captured mutation growth; +- preserves sound top/unknown effects when precision is lost. + +Precision loss here must be visible as domain widening, not hidden in query or +publication. 
+ +### Publication Law + +Publishable relations/effects: + +- function return tuple relation, +- predicate/assertion function relation summary, +- captured memory mutation effect, +- callback invocation effect, +- termination/non-returning effect, +- external manifest effect, +- constructor/receiver mutation effect. + +Non-publishable relations/effects: + +- branch-local guard that does not escape; +- local assertion proof after the checked value dies; +- relation over temporary tuple slots unless remapped to exported locations; +- query-only refinement; +- diagnostic-only proof. + +Publication should remap local locations to boundary locations. If a relation or +effect cannot be remapped, it does not publish. + +### Performance Consequences + +The relation/effect calculus should improve performance by making reuse +structural. + +Expected wins: + +- relation queries index by participant location; +- effect summaries are cached by callee identity and manifest/source version; +- effect instantiation is local and cheap because locations are canonical; +- callback expansion is bounded by summary widening; +- wrapper forwarding remaps relation IDs instead of resynthesizing return + behavior; +- predicate handling uses one transfer path. + +Rejected shapes: + +- scanning all relations for every type query; +- recomputing effect summaries inside every call check; +- using string function names in hot semantic paths; +- replaying captured mutations by rebuilding table types; +- preserving all recursive callback effects without widening; +- clearing false positives by treating unknown effects as pure. 
+ +### Relation And Effect Law Tests + +The flash migration should add focused tests for: + +- tuple relation attaches only from declared summary, not arity; +- tuple relation survives identity wrapper forwarding; +- tuple relation remaps through swapped returns only with explicit mapping; +- predicate effect narrows only declared participants; +- assertion termination removes impossible paths before value query; +- unknown external call does not refine argument; +- closed pure effect proves no mutation; +- open effect row does not prove no mutation; +- callback call observation reaches callback parameter evidence; +- callback unknown effects do not mutate closed state without declaration; +- captured mutation effect preserves operator kind and target location; +- relation join does not invent independence; +- recursive relation/effect widening converges without erasing all useful proof. + ### Delta A `Delta` is a completed analysis contribution to another scope or iteration. From 66817004eeeb39ff92a5c798b4918117201fcc2b Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:08:03 -0400 Subject: [PATCH 16/71] Define checker function boundary calculus --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 341 ++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 727ecdb1..e851f373 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -1528,6 +1528,347 @@ The flash migration should add focused tests for: - relation join does not invent independence; - recursive relation/effect widening converges without erasing all useful proof. +## Function Boundary Summary Calculus + +A function boundary is where local abstract state becomes reusable evidence for +callers. This boundary must have one product-domain object. 
It should not be +spread across parameter hints, return summaries, narrow summaries, function +types, captured fields, captured containers, literal signatures, and effect +maps as independent authorities. + +Core rule: + +```text +FunctionSummary = abstraction(QueryView(SolvedState), BoundaryMap) +``` + +The summary is not a second analysis. It is a deterministic abstraction of the +solved state through the function boundary. + +### Boundary Map + +The boundary map explains how local locations become external locations. + +```text +BoundaryMap = + Parameters + + Receiver + + Returns + + Captures + + Exports + + Constructors + + CallbackSlots +``` + +Examples: + +- parameter location maps to parameter slot; +- receiver `self` maps to receiver slot; +- local return tuple slots map to return slots; +- captured upvalue paths map to captured locations; +- module fields map to export locations; +- constructor writes map to constructor instance fields; +- callback parameters map to callback function slots. + +Any summary fact that cannot be expressed through the boundary map is not +publishable. It remains local evidence. + +### Summary Product + +The canonical function summary should be a product: + +```text +FunctionSummary = + SignatureSurface + x ParameterEvidence + x ReturnTupleSummary + x RelationSummary + x EffectSummary + x CaptureSummary + x ConstructorSummary + x ExportSummary +``` + +`SignatureSurface` is the user-facing callable type projection. It is derived +from the product. It is not the stored authority. + +`ParameterEvidence` records annotations, body obligations, call observations, +soft evidence, contextual literal evidence, and recursive widening state. + +`ReturnTupleSummary` records explicit arity, nil padding, unknown slots, any +slots, multivalue expansion policy, and per-slot provenance. + +`RelationSummary` records tuple/path relations that survive the boundary map. 
+ +`EffectSummary` records memory, relation, value, termination, and callback +effects that callers must apply through transfer. + +`CaptureSummary` records captured value/path/mutation evidence that escaped the +function body. + +`ConstructorSummary` records constructor field facts only when construction +semantics prove them. + +`ExportSummary` records module-visible fields and functions. + +### Parameter Summary Law + +Parameters have several evidence sources, but one domain. + +Evidence sources: + +- explicit parameter annotation, +- manifest/API contract, +- body obligation, +- call observation, +- function literal expected type, +- soft annotation, +- recursive SCC seed, +- interproc snapshot. + +Merge policy: + +- explicit contracts define the checked surface; +- body obligations can infer required structure; +- call observations are weak evidence and cannot create a hard contract alone; +- soft evidence refines only when compatible proof exists; +- recursive evidence widens only at SCC/interproc boundaries; +- optionality and nilability are separate axes; +- `any` remains dynamic top unless explicit cast/contract changes the question; +- absence of parameter evidence is unresolved, not `any`. + +Wrong shape: + +```text +ParamHints merge differently from FunctionFacts.Params +``` + +Correct shape: + +```text +ParameterEvidenceDomain.Join(existing, candidate) +``` + +### Return Summary Law + +Returns are tuples with attached relations and effects. + +Rules: + +- arity is explicit; +- nil padding is explicit; +- zero returns differ from one nil return; +- unknown return evidence is not bottom; +- any return evidence remains dynamic top; +- recursive return growth widens at the return domain boundary; +- narrow/success returns are derived views over tuple relation state; +- wrapper forwarding preserves return relations only through explicit location + remapping; +- vararg return expansion has a distinct summary policy. 
+ +Wrong shape: + +```text +ReturnSummaries and NarrowReturns are stored as separate truths +``` + +Correct shape: + +```text +ReturnTupleSummary + RelationSummary -> projected narrow/success view +``` + +### Function Type Projection Law + +A function type is a projection, not an authority. + +Projection: + +```text +FunctionType = + params(ParameterEvidence) + + returns(ReturnTupleSummary) + + effects(EffectSummary) + + relation metadata if the surface type can carry it +``` + +Rules: + +- projection is deterministic and cacheable; +- projection does not write facts; +- projection does not reconcile legacy channels; +- projection must be invalidated by changes to the canonical summary product; +- two projections of the same summary must be equal. + +This removes the need for bridge shapes such as "function types from facts" as a +semantic layer. A projection function may exist as a read-only view, but it is +not a merge or fallback path. + +### Capture Summary Law + +Captures are memory/effect facts remapped through lexical ownership. + +Publishable capture facts: + +- captured variable value evidence; +- captured field write; +- captured nil overwrite/deletion when modeled; +- captured table/container mutation; +- captured relation over exported/captured locations; +- captured callback effect. + +Rules: + +- captured paths use canonical locations with lexical owner identity; +- mutation operator kind is preserved; +- dominance/escape controls whether the mutation publishes; +- parent-derived table shape cannot overwrite child captured mutation; +- captured facts are applied by transfer in the receiving context, not by + rebuilding parent table types. + +### Constructor And Export Summary Law + +Constructor and export facts are boundary memory facts. 
+ +Rules: + +- constructor fields are published only from construction evidence; +- module export fields are published only from export locations; +- local helper facts do not publish just because the name is visible; +- exported functions publish their function summary product; +- imports read snapshots and apply summaries through transfer/query, not through + local special cases. + +### Call Application Law + +Calling a function applies its summary to actual locations. + +```text +CallSite + + FunctionSummary + + ActualArgumentLocations + + ReturnDestinationLocations + -> Transfer over AbstractState +``` + +Application steps: + +1. check actuals against projected parameter contracts; +2. record call observations as weak parameter evidence; +3. instantiate effect summary over actual locations; +4. instantiate relation summary over return and argument locations; +5. bind return tuple summary to destination locations; +6. update termination/reachability; +7. publish caller-side deltas only after the caller solves. + +Forbidden: + +- expected parameter type rewrites actual evidence; +- callee summary mutates interproc store during call checking; +- caller synthesizes a new callee summary from local expectations; +- return arity heuristic creates relation summary; +- call application bypasses transfer. + +### Summary Join And Widen + +Function summaries combine through their domains. + +Join: + +- combines independent observations within one iteration; +- preserves provenance and authority; +- keeps tuple arity explicit; +- joins relations/effects only when participant remapping is compatible; +- avoids rebuilding equivalent maps or slices on no-op joins. + +Widen: + +- applies at local function SCC and interproc boundaries; +- bounds recursive parameter, return, capture, relation, and effect growth; +- preserves sound unknown/any distinction; +- emits precision-loss evidence for diagnostics/profiling; +- never hides convergence by equality-time normalization. 
+ +Leq/Equal: + +- compare canonical summary state only; +- do not rebuild projections; +- do not normalize as repair; +- are the basis for fixpoint convergence and snapshot invalidation. + +### Summary Storage Law + +The stored authority should be one canonical product. + +Allowed stored authority: + +```text +FunctionSummary product +``` + +Allowed derived views: + +- callable `typ.Function` surface; +- display signature; +- backward-compatible API response if needed outside production semantics; +- narrow/success return projection; +- parameter hint projection for UI/debugging. + +Forbidden stored authority: + +- param hints as separate merge truth; +- return summaries as separate merge truth; +- narrow returns as separate merge truth; +- function type cache as separate merge truth; +- captured mutation helper summaries with custom merge; +- legacy compatibility view written back into facts. + +The final flash migration should delete duplicate stored channels in the same +change that introduces the canonical product. + +### Performance Consequences + +The boundary summary calculus should make interproc faster because summaries +become smaller and more stable. + +Expected wins: + +- one summary hash/equality path instead of multiple channel comparisons; +- no function-type projection during convergence unless a caller asks for it; +- no return narrow projection during convergence unless a query asks for it; +- no-op joins can reuse previous summary components; +- snapshot inputs update only changed canonical summaries; +- wrapper forwarding remaps summaries instead of resynthesizing them; +- parameter-use graph summaries feed parameter evidence without AST rescans. 
+ +Rejected shapes: + +- rebuilding all derived views on every merge; +- writing projections back into canonical facts; +- comparing function summaries by formatting types; +- widening by dropping entire summary families; +- adding iteration caps instead of domain widening; +- clearing caches manually to repair stale summary dependencies. + +### Function Boundary Law Tests + +The flash migration should add focused tests for: + +- function type projection is deterministic from the same summary; +- parameter body obligation outranks call observation; +- call observation alone does not prove concrete callee contract; +- explicit `any` parameter does not become concrete from calls; +- zero returns differ from one nil return; +- unknown return survives merge with concrete return when unresolved; +- narrow/success return is derived from relation summary; +- wrapper forwarding preserves relation through explicit remap; +- captured field write and captured container mutation use same memory law; +- constructor field publishes only from constructor evidence; +- export summary does not include non-escaping locals; +- no-op summary join preserves equality and avoids snapshot rewrite; +- recursive function summary widens and converges without erasing all relation + proof. + ### Delta A `Delta` is a completed analysis contribution to another scope or iteration. 
From 2cfa142ab77bd69304a45121b27bfcc2fc6aab26 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 01:40:38 -0400 Subject: [PATCH 17/71] Collapse parameter evidence into function facts --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 85 ++++++- compiler/check/api/doc.go | 3 +- compiler/check/api/facts.go | 18 +- compiler/check/api/facts_test.go | 23 +- compiler/check/api/store.go | 1 - compiler/check/checker.go | 3 +- compiler/check/infer/interproc/doc.go | 2 +- compiler/check/infer/interproc/postflow.go | 42 ++-- compiler/check/infer/interproc/writer_test.go | 8 +- compiler/check/infer/paramevidence/doc.go | 28 +++ .../parameter_evidence.go} | 186 ++++++++------- .../parameter_evidence_test.go} | 221 +++++++++++++----- .../{paramhints => paramevidence}/project.go | 113 +++++++-- compiler/check/infer/paramhints/doc.go | 29 --- compiler/check/infer/return/infer.go | 125 +++++----- .../check/infer/return/overlay_pipeline.go | 20 +- compiler/check/infer/return/scc.go | 14 +- compiler/check/phase/scope.go | 58 ++--- compiler/check/phase/types.go | 6 +- compiler/check/phase/types_test.go | 70 +++--- compiler/check/pipeline/runner.go | 24 +- compiler/check/pipeline/runner_stages.go | 14 +- compiler/check/returns/callgraph.go | 48 ++-- compiler/check/returns/callgraph_test.go | 86 +++---- compiler/check/returns/doc.go | 2 +- compiler/check/returns/domain_law_test.go | 5 +- compiler/check/returns/equal.go | 20 +- compiler/check/returns/equal_test.go | 38 ++- compiler/check/returns/function_facts.go | 3 +- compiler/check/returns/join.go | 4 +- compiler/check/returns/kernel.go | 3 + compiler/check/returns/types.go | 16 +- compiler/check/returns/types_test.go | 22 +- compiler/check/returns/widen.go | 163 ++++++------- compiler/check/returns/widen_test.go | 134 ++++++----- compiler/check/session.go | 2 +- compiler/check/store/doc.go | 2 +- compiler/check/store/facts_clone.go | 13 +- compiler/check/store/store.go | 10 - compiler/check/store/store_test.go | 15 +- 
.../tests/core/return_field_merge_test.go | 10 +- .../tests/flow/fixpoint_unification_test.go | 33 +-- ...=> parameter_evidence_and_returns_test.go} | 0 compiler/check/tests/modules/manifest_test.go | 13 +- ...ssert_false_discriminant_narrowing_test.go | 4 +- .../explicit_any_param_contract_test.go | 2 +- .../regression/false_positives_unit_test.go | 4 +- .../http_timeout_option_inference_test.go | 2 +- ...=> nested_call_parameter_evidence_test.go} | 4 +- ...ameter_evidence_depth_convergence_test.go} | 8 +- ...py_sorted_keys_parameter_evidence_test.go} | 9 +- 51 files changed, 980 insertions(+), 788 deletions(-) create mode 100644 compiler/check/infer/paramevidence/doc.go rename compiler/check/infer/{paramhints/param_hints.go => paramevidence/parameter_evidence.go} (66%) rename compiler/check/infer/{paramhints/param_hints_test.go => paramevidence/parameter_evidence_test.go} (64%) rename compiler/check/infer/{paramhints => paramevidence}/project.go (84%) delete mode 100644 compiler/check/infer/paramhints/doc.go rename compiler/check/tests/inference/{param_hints_and_returns_test.go => parameter_evidence_and_returns_test.go} (100%) rename compiler/check/tests/regression/{nested_call_param_hints_test.go => nested_call_parameter_evidence_test.go} (86%) rename compiler/check/tests/regression/{param_hint_depth_convergence_test.go => parameter_evidence_depth_convergence_test.go} (94%) rename compiler/check/tests/regression/{wippy_sorted_keys_param_hints_test.go => wippy_sorted_keys_parameter_evidence_test.go} (97%) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index e851f373..0b07320a 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -8,6 +8,67 @@ flash migration: design the final shape, migrate directly to it, delete the old helper clusters, and do not leave compatibility wrappers or fallback layers in the production checker. 
+## 2026-05-19 Implementation Checkpoint + +First flash-migration slice landed for parameter evidence ownership. + +Changed production shape: + +- `api.FunctionFact` now owns canonical parameter evidence in `Params`. +- `api.Facts.ParamHints` was removed. +- `api.SnapshotStore.GetParamHintsSnapshot` was removed. +- same-iteration merge and interproc widening now combine parameter evidence + through `FunctionFacts`. +- post-flow call observation publication now emits `FunctionFacts` deltas with + `Params` instead of writing a side channel. +- return inference seeds local function parameter evidence from canonical + `FunctionFacts`. +- Salsa snapshot facts now track one canonical fact product for parameters, + returns, narrow returns, and function type projection. + +This is intentionally not a bridge. No production code reads a legacy +`ParamHints` fact channel and no compatibility writer reconstructs it from +`FunctionFacts`. + +Second cleanup slice in the same migration: + +- the local inference package was renamed from `infer/paramhints` to + `infer/paramevidence`; +- `LocalFuncInfo.ParamHints` became `LocalFuncInfo.ParameterEvidence`; +- phase input `ParamHintSignatures` became `ParameterEvidenceSignatures`; +- local call-graph propagation now exposes `PropagateParameterEvidence`; +- helper files and regression fixtures were renamed to parameter-evidence + terminology; +- production checker code no longer contains `ParamHint` or `paramhints` + identifiers; +- parameter-use projection now treats builtin `type(param)` checks and + `param = param or {}` self-default assignments as shape-neutral guard/default + operations instead of whole-parameter escapes. Those operations must not turn + a call-site record observation into a closed public contract. + +Verification notes: + +- `go test ./...` passes. +- `git diff --check` passes. 
+- `../scripts/verify-suite.sh` passes go-lua checker tests and builds the Wippy + binary, then exits non-zero in external lint targets while building Wippy + against `github.com/wippyai/go-lua v1.5.16`. +- A temp local-replace replay under `/tmp/wippy-golua-local-replace` builds + Wippy against this checkout without editing external code. It reduced the + projection-related false positives, but the full external sweep is still not + clean: tests/app 2 errors/4 warnings, session 20, actor/test 3, agent/src 12, + docker-demo 72, llm/src 10, llm/test 9, migration 1, views 1. + +Remaining cleanup after this parameter-evidence slice: + +- return/narrow/type projections still need the same treatment: read-only views + over the canonical function summary product, not separate authorities. +- Remaining local-replace external diagnostics must be classified in the next + engine slice. Some are soundness-preserving real-code issues (`any` flowing + into concrete contracts); some still expose missing checker power, especially + public functions that validate invalid input with `type(...)` guards and + should infer a wider accepted input domain without weakening the guarded body. + ## Goal The checker should read as one abstract interpreter over a product domain. @@ -60,7 +121,7 @@ The checker is a multi-phase abstract interpreter: 5. Narrowing queries are demand-side interpretation: read solved facts at a point, apply propagated constraints, and answer refined path/type questions. 6. Return inference and local function SCC solving use the flow result plus - interprocedural snapshots to infer return vectors, parameter hints, function + interprocedural snapshots to infer return vectors, parameter evidence, function facts, captured fields, and captured container mutations. 7. The interprocedural store combines same-iteration deltas with a precise join and combines recursive fixpoint boundaries with widening. 
@@ -1532,7 +1593,7 @@ The flash migration should add focused tests for: A function boundary is where local abstract state becomes reusable evidence for callers. This boundary must have one product-domain object. It should not be -spread across parameter hints, return summaries, narrow summaries, function +spread across parameter evidence, return summaries, narrow summaries, function types, captured fields, captured containers, literal signatures, and effect maps as independent authorities. @@ -1816,7 +1877,7 @@ Allowed derived views: Forbidden stored authority: -- param hints as separate merge truth; +- parameter evidence as separate merge truth; - return summaries as separate merge truth; - narrow returns as separate merge truth; - function type cache as separate merge truth; @@ -2224,7 +2285,7 @@ checker. These are the invariants that should guide the flash migration. - source annotations remain authoritative; - soft annotations can refine but not override hard proof; - recursive parameter evidence widens at SCC/interproc boundaries only; -- function-fact params and param hints use the same evidence order. +- function-fact params and locally inferred parameter evidence use the same evidence order. ### Return Summary Invariants @@ -2645,7 +2706,7 @@ Rules: - soft annotations refine only when hard evidence proves the refinement; - recursive parameter evidence must join/widen through the parameter domain. -There should be no separate ad hoc policy for "param hints" versus "function +There should be no separate ad hoc policy for "locally inferred parameter evidence" versus "function fact params". Both are parameter evidence with different provenance and merge mode. 
@@ -2823,7 +2884,7 @@ Current scattered concepts: - value-type joins, - return-slot joins, - function-param fact joins, -- param-hint joins, +- parameter-evidence joins, - table-top absorption, - soft-placeholder replacement, - open-record row-tail merging, @@ -2837,7 +2898,7 @@ Current scattered concepts: - signature projection to body use. These concepts are real. The problem is not that they exist. The problem is that -they appear as local helpers in `returns`, `paramhints`, `flow`, `synth`, and +they appear as local helpers in `returns`, `paramevidence`, `flow`, `synth`, and `typ`, with overlapping responsibilities. That creates the "guacamole" feeling: behavior is strong, but the mental model @@ -3099,7 +3160,7 @@ compiler/check/domain/paramevidence Current split: -- some policy lives in `compiler/check/infer/paramhints`, +- some policy lives in `compiler/check/infer/paramevidence`, - some lives in `compiler/check/returns/widen.go`, - some lives in return SCC inference, - some lives in interproc postflow. @@ -3151,7 +3212,7 @@ helper joins directly. | function fact type merge | `compiler/check/returns/join.go` | `domain/functionfact` | | function param-slot refinement | `compiler/check/returns/join.go`, `widen.go` | `domain/functionfact.ParamSlotDomain` | | return-vector merge/repair | `compiler/check/returns/join.go` | `domain/returnsummary` | -| table-top absorption | `infer/paramhints`, `returns/widen.go` | `domain/paramevidence` plus value-domain classifier | +| table-top absorption | `infer/paramevidence`, `returns/widen.go` | `domain/paramevidence` plus value-domain classifier | | soft vs concrete evidence | `typ/soft.go`, `returns/widen.go`, return overlay | `domain/value` evidence policy | | open-record row-tail merge | `types/typ/policy.go` | `domain/value` row-shape policy | | path/query/alias identity | `constraint`, `flowbuild/path`, `flow/pathkey` | `memory` | @@ -3732,7 +3793,7 @@ incremental boundaries and no hidden semantic cache. 
| `snapshotInputs.constructorFields` | constructor field snapshot | Salsa input | memory/constructor boundary | | `types/query/core.Engine` | pure type operations | query engine cache | type-query layer | | `types/flow.ProductDomain` | branch-local narrowing algebra | ephemeral domain state | abstract state / flow domain | -| `paramhints.collectParamUses` | body-demand summary | graph-derived Salsa query | graph summary layer | +| `paramevidence.collectParamUses` | body-demand summary | graph-derived Salsa query | graph summary layer | | `ProjectHintsToParamUse` | parameter evidence projection | domain operation over cached body summary | parameter domain | | `PreCache` / `NarrowCache` | repeated expression synthesis inside one solve | per-function local cache | transfer/query phase | | `FunctionTypeCache` | local function specialization during one solve | per-function local cache unless key is immutable | function analysis | @@ -3900,7 +3961,7 @@ The checker has laws such as: - `unknown` in return summaries is unresolved runtime behavior, - open record absent field means row-tail, not nil, - nil field can satisfy optional absence in record subtyping, -- table-top can absorb precise table evidence in parameter hints, +- table-top can absorb precise table evidence in parameter evidence, - truthy refinement can remove falsy key alternatives. Today many of these appear as function names buried in unrelated packages. They @@ -3943,7 +4004,7 @@ Parameter evidence currently comes from: The final design needs one `ParameterEvidence` lattice with evidence provenance and merge mode. The implementation should not need separate helpers for -"param hints" and "function param facts" that rediscover the same truthiness, +"parameter evidence" and "function param facts" that rediscover the same truthiness, softness, and table-key laws. ### 5. 
Relation Facts Are Under-Modeled diff --git a/compiler/check/api/doc.go b/compiler/check/api/doc.go index bb13b0ef..0b978c12 100644 --- a/compiler/check/api/doc.go +++ b/compiler/check/api/doc.go @@ -30,8 +30,7 @@ // The [Facts] type bundles interprocedural analysis results for a single // function graph: // -// - [FunctionFacts]: Canonical per-function return/signature facts -// - [ParamHints]: Effective parameter types inferred from call sites +// - [FunctionFacts]: Canonical per-function parameter/return/signature facts // - [LiteralSigs]: Signatures for anonymous function literals // - [CapturedTypes]: Flow-derived types for captured variables // - [CapturedFieldAssigns]: Field assignments to captured variables diff --git a/compiler/check/api/facts.go b/compiler/check/api/facts.go index 8b83dcd8..641c961c 100644 --- a/compiler/check/api/facts.go +++ b/compiler/check/api/facts.go @@ -12,14 +12,12 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// ParamHints maps function symbols to effective-parameter type hints inferred -// from call sites. For method calls, slot 0 is the receiver/self argument and -// the remaining slots are the source arguments. -type ParamHints = map[cfg.SymbolID][]typ.Type - // FunctionFact is the canonical function-related interproc fact for one symbol. // All return and local-function type evidence for a function converges here. type FunctionFact struct { + // Params is the canonical parameter evidence vector. For method calls, slot + // 0 is the receiver/self argument and the remaining slots are source args. + Params []typ.Type // Summary is the declared/pre-flow return vector. Summary []typ.Type // Narrow is the post-flow return vector. @@ -40,6 +38,15 @@ func (facts FunctionFacts) Fact(sym cfg.SymbolID) (FunctionFact, bool) { return ff, ok } +// Params returns the canonical parameter evidence vector for sym. 
+func (facts FunctionFacts) Params(sym cfg.SymbolID) []typ.Type { + ff, ok := facts.Fact(sym) + if !ok { + return nil + } + return ff.Params +} + // Summary returns the declared/pre-flow return vector for sym. func (facts FunctionFacts) Summary(sym cfg.SymbolID) []typ.Type { ff, ok := facts.Fact(sym) @@ -136,7 +143,6 @@ type ConstructorFields = map[cfg.SymbolID]map[string]typ.Type // These facts are computed during analysis and stored per (graph, parent) pair. type Facts struct { FunctionFacts FunctionFacts - ParamHints ParamHints LiteralSigs LiteralSigs CapturedTypes CapturedTypes CapturedFields CapturedFieldAssigns diff --git a/compiler/check/api/facts_test.go b/compiler/check/api/facts_test.go index 8062aebe..5fb90318 100644 --- a/compiler/check/api/facts_test.go +++ b/compiler/check/api/facts_test.go @@ -13,9 +13,6 @@ func TestFacts_Zero(t *testing.T) { if f.FunctionFacts != nil { t.Error("zero Facts should have nil FunctionFacts") } - if f.ParamHints != nil { - t.Error("zero Facts should have nil ParamHints") - } if f.LiteralSigs != nil { t.Error("zero Facts should have nil LiteralSigs") } @@ -41,17 +38,14 @@ func TestFunctionFacts_Summary(t *testing.T) { } } -func TestParamHints_Basic(t *testing.T) { - hints := make(ParamHints) +func TestFunctionFacts_Params(t *testing.T) { + facts := make(FunctionFacts) sym := cfg.SymbolID(1) - hints[sym] = []typ.Type{typ.Number, typ.String} + facts[sym] = FunctionFact{Params: []typ.Type{typ.Number, typ.String}} - params, ok := hints[sym] - if !ok { - t.Fatal("expected symbol to be in hints") - } + params := facts.Params(sym) if len(params) != 2 { - t.Errorf("expected 2 param hints, got %d", len(params)) + t.Errorf("expected 2 params, got %d", len(params)) } } @@ -161,20 +155,15 @@ func TestFacts_WithData(t *testing.T) { f := Facts{ FunctionFacts: FunctionFacts{ 4: { + Params: []typ.Type{typ.Number}, Summary: []typ.Type{typ.Boolean}, Narrow: []typ.Type{typ.Boolean}, Type: typ.Func().Returns(typ.Boolean).Build(), }, }, - 
ParamHints: ParamHints{ - 2: []typ.Type{typ.Number}, - }, } if len(f.FunctionFacts) != 1 { t.Error("expected 1 function fact") } - if len(f.ParamHints) != 1 { - t.Error("expected 1 param hint") - } } diff --git a/compiler/check/api/store.go b/compiler/check/api/store.go index 98bdc65b..6efca046 100644 --- a/compiler/check/api/store.go +++ b/compiler/check/api/store.go @@ -73,7 +73,6 @@ type NestedMetaStore interface { // SnapshotStore exposes stable interproc fact snapshots. type SnapshotStore interface { - GetParamHintsSnapshot(graph *cfg.Graph, parent *scope.State) ParamHints GetFunctionFactsSnapshot(graph *cfg.Graph, parent *scope.State) FunctionFacts GetCapturedTypesSnapshot(graph *cfg.Graph, parent *scope.State) CapturedTypes GetCapturedFieldAssignsSnapshot(graph *cfg.Graph, parent *scope.State) CapturedFieldAssigns diff --git a/compiler/check/checker.go b/compiler/check/checker.go index 4b5315c0..91232bfa 100644 --- a/compiler/check/checker.go +++ b/compiler/check/checker.go @@ -29,8 +29,7 @@ // // The checker supports interprocedural analysis through a unified interproc snapshot: // -// - FunctionFacts: Canonical return/narrow/signature facts for local functions -// - ParamHints: Inferred parameter types from call sites +// - FunctionFacts: Canonical parameter/return/narrow/signature facts // - LiteralSigs: Synthesized signatures for function literals // - Refinements: Function refinement summaries, stored per symbol // diff --git a/compiler/check/infer/interproc/doc.go b/compiler/check/infer/interproc/doc.go index a0f33a01..7f5b3217 100644 --- a/compiler/check/infer/interproc/doc.go +++ b/compiler/check/infer/interproc/doc.go @@ -7,7 +7,7 @@ // // After flow analysis completes for a function, this package: // - Extracts return type summaries -// - Computes parameter type hints from call sites +// - Computes parameter evidence from call sites // - Identifies captured variable assignments // - Propagates effect information // diff --git 
a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index 8f545c79..39ce124c 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -7,7 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/check/erreffect" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/nested" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" @@ -53,8 +53,8 @@ func StoreFactsFromResult( fnSym = resolvedSym } } - // Collect parameter hints regardless of whether the function has a symbol. - CollectParamHintsFromResult(store, result, parent) + // Collect parameter evidence regardless of whether the function has a symbol. + CollectParameterEvidenceFromResult(store, result, parent) if fnSym == 0 { return @@ -75,8 +75,8 @@ func StoreFactsFromResult( summaryFromSnapshot := returnSummarySnapshotForSymbol(store, result, parent, fnSym) candidateFunc := fnType - if hints := store.GetParamHintsSnapshot(result.Graph, parent); len(hints) > 0 { - if hinted := paramhints.MergeIntoSignature(fn, hints[fnSym], unwrap.Function(candidateFunc)); hinted != nil { + if facts := store.GetFunctionFactsSnapshot(result.Graph, parent); len(facts) > 0 { + if hinted := paramevidence.MergeIntoSignature(fn, facts.Params(fnSym), unwrap.Function(candidateFunc)); hinted != nil { candidateFunc = hinted } } @@ -281,9 +281,9 @@ func expectedFunctionFromResult(result *api.FuncResult) *typ.Function { return builder.Build() } -// CollectParamHintsFromResult records parameter hints based on call sites +// CollectParameterEvidenceFromResult records parameter evidence based on call sites // within the current function's graph using narrowed expression types. 
-func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *scope.State) { +func CollectParameterEvidenceFromResult(store Store, result *api.FuncResult, parent *scope.State) { if store == nil || result == nil || result.Graph == nil || result.NarrowSynth == nil { return } @@ -298,7 +298,7 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc hasFunctionRef := func(sym cfg.SymbolID) bool { return sym != 0 && store.FunctionRefBySym(sym) != nil } - collectCallHints := func(p cfg.Point, info *cfg.CallInfo) { + collectCallEvidence := func(p cfg.Point, info *cfg.CallInfo) { if info == nil || checkcallsite.RuntimeArgCount(info) == 0 { return } @@ -388,9 +388,9 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc return } - deltaHints := make(api.ParamHints) + deltaFacts := make(api.FunctionFacts) runtimeArgCount := checkcallsite.RuntimeArgCount(info) - hints := paramhints.EnsureHintCapacity(nil, runtimeArgCount) + evidence := paramevidence.EnsureCapacity(nil, runtimeArgCount) for runtimeIdx := 0; runtimeIdx < runtimeArgCount; runtimeIdx++ { arg := checkcallsite.RuntimeArgAt(info, runtimeIdx) if arg == nil { @@ -411,7 +411,7 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc if argType == nil { argType = result.NarrowSynth.TypeOf(arg, p) } - hints, _ = paramhints.MergeCallArgHintAt(hints, runtimeIdx, argType, typ.JoinPreferNonSoft, true) + evidence, _ = paramevidence.MergeCallArgAt(evidence, runtimeIdx, argType, typ.JoinPreferNonSoft, true) } for i, arg := range info.Args { if arg == nil { @@ -427,26 +427,26 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc hasFunctionRef, ) if argSym != 0 && hasFunctionRef(argSym) { - hintsForFn := deltaHints[argSym] + fnEvidence := deltaFacts.Params(argSym) for j, param := range expectedFn.Params { - hintsForFn, _ = paramhints.MergeHintAt(hintsForFn, j, param.Type, typ.JoinPreferNonSoft) + 
fnEvidence, _ = paramevidence.MergeAt(fnEvidence, j, param.Type, typ.JoinPreferNonSoft) } - if len(hintsForFn) > 0 { - deltaHints[argSym] = hintsForFn + if len(fnEvidence) > 0 { + deltaFacts[argSym] = returns.JoinFunctionFact(deltaFacts[argSym], api.FunctionFact{Params: fnEvidence}) } } } } - if len(hints) > 0 { - deltaHints[calleeSym] = hints + if len(evidence) > 0 { + deltaFacts[calleeSym] = returns.JoinFunctionFact(deltaFacts[calleeSym], api.FunctionFact{Params: evidence}) } - if len(deltaHints) > 0 { - store.MergeInterprocFactsNext(parentKey, api.Facts{ParamHints: deltaHints}) + if len(deltaFacts) > 0 { + store.MergeInterprocFactsNext(parentKey, api.Facts{FunctionFacts: deltaFacts}) } } graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { - collectCallHints(p, info) + collectCallEvidence(p, info) seenNested := make(map[*ast.FuncCallExpr]struct{}) for _, arg := range info.Args { @@ -457,7 +457,7 @@ func CollectParamHintsFromResult(store Store, result *api.FuncResult, parent *sc if nestedInfo == nil { nestedInfo = synthCallInfoFromExpr(nested, bindings) } - collectCallHints(p, nestedInfo) + collectCallEvidence(p, nestedInfo) } }) } diff --git a/compiler/check/infer/interproc/writer_test.go b/compiler/check/infer/interproc/writer_test.go index 811a4a28..a9eb880d 100644 --- a/compiler/check/infer/interproc/writer_test.go +++ b/compiler/check/infer/interproc/writer_test.go @@ -50,16 +50,16 @@ func TestInterprocFactWriter_MergeParentFactsForSymbol(t *testing.T) { writer := newInterprocFactWriter(stub) ok := writer.mergeParentFactsForSymbol(3, api.Facts{ - ParamHints: map[cfg.SymbolID][]typ.Type{ - 3: {typ.String}, + FunctionFacts: api.FunctionFacts{ + 3: {Params: []typ.Type{typ.String}}, }, }) if !ok { t.Fatal("expected update to succeed") } got := stub.factsByGraphKeyNext[key] - if len(got.ParamHints[3]) != 1 || !typ.TypeEquals(got.ParamHints[3][0], typ.String) { - t.Fatalf("unexpected parent facts update: %#v", got.ParamHints) + if params := 
got.FunctionFacts.Params(3); len(params) != 1 || !typ.TypeEquals(params[0], typ.String) { + t.Fatalf("unexpected parent facts update: %#v", got.FunctionFacts) } if writer.mergeParentFactsForSymbol(99, api.Facts{}) { diff --git a/compiler/check/infer/paramevidence/doc.go b/compiler/check/infer/paramevidence/doc.go new file mode 100644 index 00000000..f1e149dc --- /dev/null +++ b/compiler/check/infer/paramevidence/doc.go @@ -0,0 +1,28 @@ +// Package paramevidence computes parameter evidence from call-site arguments. +// +// This package analyzes function call sites and body uses to build effective +// parameter types for functions without explicit type annotations. +// +// # Evidence Collection +// +// For each call site: +// +// foo(123, "bar") -- evidence: param1=number, param2=string +// +// The package collects argument types and associates them with parameter +// positions. Multiple call sites contribute evidence that is joined. +// +// # Evidence Merging +// +// When multiple calls provide conflicting evidence: +// +// foo(1) -- evidence: param1=number +// foo("a") -- evidence: param1=string +// +// The evidence is joined to produce: param1 = number | string +// +// # Integration +// +// Parameter evidence feeds into function signature inference, providing +// types for parameters that lack explicit annotations. 
+package paramevidence diff --git a/compiler/check/infer/paramhints/param_hints.go b/compiler/check/infer/paramevidence/parameter_evidence.go similarity index 66% rename from compiler/check/infer/paramhints/param_hints.go rename to compiler/check/infer/paramevidence/parameter_evidence.go index 26101517..757d8e92 100644 --- a/compiler/check/infer/paramhints/param_hints.go +++ b/compiler/check/infer/paramevidence/parameter_evidence.go @@ -1,4 +1,4 @@ -package paramhints +package paramevidence import ( "github.com/wippyai/go-lua/compiler/ast" @@ -11,17 +11,17 @@ import ( "github.com/wippyai/go-lua/types/typ/unwrap" ) -type HintJoinFn func(prev, next typ.Type) typ.Type +type JoinFn func(prev, next typ.Type) typ.Type // MergeIntoSignature replaces unannotated parameter slots (and refinable -// top-like annotations) with call-site hints. -func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Function) *typ.Function { +// top-like annotations) with call-site evidence. +func MergeIntoSignature(fn *ast.FunctionExpr, evidence []typ.Type, sig *typ.Function) *typ.Function { if sig == nil || fn == nil || fn.ParList == nil { return sig } modified := false for i, p := range sig.Params { - if i >= len(hints) || hints[i] == nil { + if i >= len(evidence) || evidence[i] == nil { continue } if srcIdx, hasSource := signatureSourceParamIndex(fn, sig, i); hasSource && srcIdx < len(fn.ParList.Types) && fn.ParList.Types[srcIdx] != nil { @@ -29,7 +29,7 @@ func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Functio continue } } - if !typ.TypeEquals(p.Type, hints[i]) { + if !typ.TypeEquals(p.Type, evidence[i]) { modified = true } } @@ -40,11 +40,11 @@ func MergeIntoSignature(fn *ast.FunctionExpr, hints []typ.Type, sig *typ.Functio builder := typ.Func() for i, p := range sig.Params { paramType := p.Type - if i < len(hints) && hints[i] != nil { + if i < len(evidence) && evidence[i] != nil { srcIdx, hasSource := signatureSourceParamIndex(fn, sig, i) 
annotated := hasSource && srcIdx < len(fn.ParList.Types) && fn.ParList.Types[srcIdx] != nil if !annotated || typ.IsRefinableAnnotation(paramType) { - paramType = hints[i] + paramType = evidence[i] } } if p.Optional { @@ -98,7 +98,7 @@ func signatureHasImplicitSelf(fn *ast.FunctionExpr, sig *typ.Function) bool { return len(sig.Params) == len(fn.ParList.Names)+1 } -func WidenParamHintType(t typ.Type) typ.Type { +func WidenType(t typ.Type) typ.Type { if t == nil { return nil } @@ -115,19 +115,19 @@ func WidenParamHintType(t typ.Type) typ.Type { return typ.String } case *typ.Optional: - inner := WidenParamHintType(v.Inner) + inner := WidenType(v.Inner) if inner != v.Inner && inner != nil { return typ.NewOptional(inner) } case *typ.Alias: if v.Target != nil { - return WidenParamHintType(v.Target) + return WidenType(v.Target) } case *typ.Union: changed := false members := make([]typ.Type, 0, len(v.Members)) for _, m := range v.Members { - wm := WidenParamHintType(m) + wm := WidenType(m) if wm != m { changed = true } @@ -143,7 +143,7 @@ func WidenParamHintType(t typ.Type) typ.Type { builder.SetOpen(true) } for _, f := range v.Fields { - ft := WidenParamHintType(f.Type) + ft := WidenType(f.Type) if ft != f.Type { changed = true } @@ -154,8 +154,8 @@ func WidenParamHintType(t typ.Type) typ.Type { } } if v.MapKey != nil && v.MapValue != nil { - k := WidenParamHintType(v.MapKey) - val := WidenParamHintType(v.MapValue) + k := WidenType(v.MapKey) + val := WidenType(v.MapValue) if k != v.MapKey || val != v.MapValue { changed = true } @@ -171,24 +171,24 @@ func WidenParamHintType(t typ.Type) typ.Type { return t } -// NormalizeHintType applies canonical widening and soft-member pruning. -func NormalizeHintType(t typ.Type) typ.Type { - return collapseTableTopHint(typ.PruneSoftUnionMembers(WidenParamHintType(t))) +// NormalizeType applies canonical widening and soft-member pruning. 
+func NormalizeType(t typ.Type) typ.Type { + return collapseTableTopEvidence(typ.PruneSoftUnionMembers(WidenType(t))) } -func collapseTableTopHint(t typ.Type) typ.Type { +func collapseTableTopEvidence(t typ.Type) typ.Type { if t == nil { return nil } switch v := t.(type) { case *typ.Alias: - target := collapseTableTopHint(v.Target) + target := collapseTableTopEvidence(v.Target) if target != nil && !typ.TypeEquals(target, v.Target) { return typ.NewAlias(v.Name, target) } return t case *typ.Optional: - inner := collapseTableTopHint(v.Inner) + inner := collapseTableTopEvidence(v.Inner) if inner != nil && !typ.TypeEquals(inner, v.Inner) { return typ.NewOptional(inner) } @@ -210,7 +210,7 @@ func collapseTableTopUnion(u *typ.Union) typ.Type { if tableTop == nil { for _, member := range u.Members { - collapsed := collapseTableTopHint(member) + collapsed := collapseTableTopEvidence(member) if !typ.TypeEquals(collapsed, member) { changed = true } @@ -231,8 +231,8 @@ func collapseTableTopUnion(u *typ.Union) typ.Type { members = append(members, member) continue } - collapsed := collapseTableTopHint(member) - if tableTopCoversHintMember(collapsed) { + collapsed := collapseTableTopEvidence(member) + if tableTopCoversEvidenceMember(collapsed) { if !tableAdded { members = append(members, tableTop) tableAdded = true @@ -255,35 +255,35 @@ func collapseTableTopUnion(u *typ.Union) typ.Type { func firstTableTopMember(members []typ.Type) typ.Type { for _, member := range members { - if isBuiltinTableTopHint(member) { + if isBuiltinTableTopEvidence(member) { return member } } return nil } -func isBuiltinTableTopHint(t typ.Type) bool { +func isBuiltinTableTopEvidence(t typ.Type) bool { return unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) } -func tableTopCoversHintMember(t typ.Type) bool { +func tableTopCoversEvidenceMember(t typ.Type) bool { if t == nil { return false } - if isBuiltinTableTopHint(t) { + if isBuiltinTableTopEvidence(t) { return true } switch v := 
typ.UnwrapAnnotated(t).(type) { case *typ.Alias: - return tableTopCoversHintMember(v.UnaliasedTarget()) + return tableTopCoversEvidenceMember(v.UnaliasedTarget()) case *typ.Recursive: - return v.Body != nil && v.Body != v && tableTopCoversHintMember(v.Body) + return v.Body != nil && v.Body != v && tableTopCoversEvidenceMember(v.Body) case *typ.Union: if len(v.Members) == 0 { return false } for _, member := range v.Members { - if member == nil || typ.UnwrapAnnotated(member).Kind() == kind.Nil || !tableTopCoversHintMember(member) { + if member == nil || typ.UnwrapAnnotated(member).Kind() == kind.Nil || !tableTopCoversEvidenceMember(member) { return false } } @@ -295,65 +295,65 @@ func tableTopCoversHintMember(t typ.Type) bool { } } -// EnsureHintCapacity grows hint vector to at least size. -func EnsureHintCapacity(hints []typ.Type, size int) []typ.Type { - if size <= len(hints) { - return hints +// EnsureCapacity grows evidence vector to at least size. +func EnsureCapacity(evidence []typ.Type, size int) []typ.Type { + if size <= len(evidence) { + return evidence } expanded := make([]typ.Type, size) - copy(expanded, hints) + copy(expanded, evidence) return expanded } -// MergeHintAt normalizes and joins one hint into vector slot idx. -func MergeHintAt(hints []typ.Type, idx int, hint typ.Type, join HintJoinFn) ([]typ.Type, bool) { +// MergeAt normalizes and joins one observation into vector slot idx. 
+func MergeAt(vec []typ.Type, idx int, observed typ.Type, join JoinFn) ([]typ.Type, bool) { if idx < 0 { - return hints, false + return vec, false } - hint = NormalizeHintType(hint) - if !IsInformativeHintType(hint) { - return hints, false + observed = NormalizeType(observed) + if !IsInformative(observed) { + return vec, false } - hints = EnsureHintCapacity(hints, idx+1) + vec = EnsureCapacity(vec, idx+1) joinFn := join if joinFn == nil { joinFn = typ.JoinPreferNonSoft } - prev := hints[idx] - merged := joinFn(prev, hint) + prev := vec[idx] + merged := joinFn(prev, observed) if typ.TypeEquals(prev, merged) { - return hints, false + return vec, false } - hints[idx] = merged - return hints, true + vec[idx] = merged + return vec, true } -// MergeCallArgHintAt merges a call-argument observation into a parameter hint -// slot. Unlike MergeHintAt, unresolved/top-like argument observations are +// MergeCallArgAt merges a call-argument observation into a parameter evidence +// slot. Unlike MergeAt, unresolved/top-like argument observations are // preserved as uncertainty evidence so later literal calls cannot over-specialize // unannotated parameters. 
-func MergeCallArgHintAt(hints []typ.Type, idx int, argType typ.Type, join HintJoinFn, unknownOnNil bool) ([]typ.Type, bool) { +func MergeCallArgAt(evidence []typ.Type, idx int, argType typ.Type, join JoinFn, unknownOnNil bool) ([]typ.Type, bool) { if idx < 0 { - return hints, false + return evidence, false } - argType = NormalizeHintType(argType) + argType = NormalizeType(argType) if argType == nil { if !unknownOnNil { - return hints, false + return evidence, false } argType = typ.Unknown } - hints = EnsureHintCapacity(hints, idx+1) + evidence = EnsureCapacity(evidence, idx+1) joinFn := join if joinFn == nil { joinFn = typ.JoinPreferNonSoft } - prev := NormalizeHintType(hints[idx]) + prev := NormalizeType(evidence[idx]) if prev == nil { - prev = hints[idx] + prev = evidence[idx] } mergeTopAware := func(a, b typ.Type) typ.Type { @@ -376,29 +376,29 @@ func MergeCallArgHintAt(hints []typ.Type, idx int, argType typ.Type, join HintJo } topLikeArg := typ.IsAny(argType) || typ.IsUnknown(argType) - if !topLikeArg && !IsInformativeHintType(argType) { - return hints, false + if !topLikeArg && !IsInformative(argType) { + return evidence, false } merged := mergeTopAware(prev, argType) - if typ.TypeEquals(hints[idx], merged) { - return hints, false + if typ.TypeEquals(evidence[idx], merged) { + return evidence, false } - hints[idx] = merged - return hints, true + evidence[idx] = merged + return evidence, true } -// IsInformativeHintType reports whether a type carries useful call-site -// information for parameter hint propagation. +// IsInformative reports whether a type carries useful call-site +// information for parameter evidence propagation. // // It intentionally rejects top-like and empty placeholder shapes that tend to -// poison hints, while preserving structured hints such as maps/arrays with +// poison evidence, while preserving structured evidence such as maps/arrays with // partial information (for example `{[string]: any[]}`). 
-func IsInformativeHintType(t typ.Type) bool { - return isInformativeHintType(t, typ.NewGuard()) +func IsInformative(t typ.Type) bool { + return isInformativeEvidenceType(t, typ.NewGuard()) } -func isInformativeHintType(t typ.Type, guard internal.RecursionGuard) bool { +func isInformativeEvidenceType(t typ.Type, guard internal.RecursionGuard) bool { if t == nil { return false } @@ -418,10 +418,10 @@ func isInformativeHintType(t typ.Type, guard internal.RecursionGuard) bool { switch v := t.(type) { case *typ.Optional: - return isInformativeHintType(v.Inner, next) + return isInformativeEvidenceType(v.Inner, next) case *typ.Union: for _, m := range v.Members { - if isInformativeHintType(m, next) { + if isInformativeEvidenceType(m, next) { return true } } @@ -430,7 +430,7 @@ func isInformativeHintType(t typ.Type, guard internal.RecursionGuard) bool { if v.Target == nil { return false } - return isInformativeHintType(v.Target, next) + return isInformativeEvidenceType(v.Target, next) } if r, ok := t.(*typ.Record); ok { @@ -442,10 +442,9 @@ func isInformativeHintType(t typ.Type, guard internal.RecursionGuard) bool { return true } -// BuildParamHintSignatures builds a function-expression keyed hint map for this graph. -// It merges per-iteration scratch hints with symbol-based hints from the store. -// Scratch hints take precedence over symbol-derived hints. -func BuildParamHintSignatures( +// BuildSignatureMap builds a function-expression keyed parameter evidence +// map for this graph from canonical FunctionFacts. +func BuildSignatureMap( store api.StoreReader, graph *cfg.Graph, parent *scope.State, @@ -455,25 +454,24 @@ func BuildParamHintSignatures( return nil } - // Use stable snapshot param hints during analysis. 
- symHints := store.GetParamHintsSnapshot(graph, parent) + functionFacts := store.GetFunctionFactsSnapshot(graph, parent) out := make(map[*ast.FunctionExpr][]typ.Type) - if len(symHints) > 0 { - for _, sym := range cfg.SortedSymbolIDs(symHints) { - hints := symHints[sym] - if len(hints) == 0 { + if len(functionFacts) > 0 { + for _, sym := range cfg.SortedSymbolIDs(functionFacts) { + vec := functionFacts.Params(sym) + if len(vec) == 0 { continue } - hasHint := false - for _, hint := range hints { - if hint != nil { - hasHint = true + hasEvidence := false + for _, observed := range vec { + if observed != nil { + hasEvidence = true break } } - if !hasHint { + if !hasEvidence { continue } fn := store.FuncForSymbol(sym) @@ -481,13 +479,13 @@ func BuildParamHintSignatures( continue } if _, exists := out[fn]; !exists { - out[fn] = hints + out[fn] = vec } } } - // If this graph is a nested function, pull param hints from the parent graph - // and apply them to the current function signature. + // If this graph is a nested function, pull parameter evidence from the + // parent graph and apply it to the current function signature. 
if meta, ok := store.NestedMetaFor(graph.ID()); ok { parentGraph := store.Graphs()[meta.ParentGraphID] if parentGraph != nil { @@ -497,16 +495,16 @@ func BuildParamHintSignatures( } parentScope := api.ParentScopeForGraph(store, parentGraph.ID(), defaultScope) if parentScope != nil { - parentHints := store.GetParamHintsSnapshot(parentGraph, parentScope) - if len(parentHints) > 0 { + parentFacts := store.GetFunctionFactsSnapshot(parentGraph, parentScope) + if len(parentFacts) > 0 { fn := store.FuncForGraph(graph) if fn == nil { fn = graph.Func() } if fn != nil { if sym, ok := store.SymbolForFunc(fn); ok { - if hints := parentHints[sym]; len(hints) > 0 { - out[fn] = hints + if evidence := parentFacts.Params(sym); len(evidence) > 0 { + out[fn] = evidence } } } diff --git a/compiler/check/infer/paramhints/param_hints_test.go b/compiler/check/infer/paramevidence/parameter_evidence_test.go similarity index 64% rename from compiler/check/infer/paramhints/param_hints_test.go rename to compiler/check/infer/paramevidence/parameter_evidence_test.go index f91114ad..86f7f65d 100644 --- a/compiler/check/infer/paramhints/param_hints_test.go +++ b/compiler/check/infer/paramevidence/parameter_evidence_test.go @@ -1,4 +1,4 @@ -package paramhints +package paramevidence import ( "testing" @@ -8,64 +8,64 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestWidenParamHintType_Nil(t *testing.T) { - result := WidenParamHintType(nil) +func TestWidenType_Nil(t *testing.T) { + result := WidenType(nil) if result != nil { t.Errorf("expected nil, got %v", result) } } -func TestWidenParamHintType_BooleanLiteral(t *testing.T) { +func TestWidenType_BooleanLiteral(t *testing.T) { lit := typ.LiteralBool(true) - result := WidenParamHintType(lit) + result := WidenType(lit) if result != typ.Boolean { t.Errorf("expected Boolean, got %v", result) } } -func TestWidenParamHintType_IntegerLiteral(t *testing.T) { +func TestWidenType_IntegerLiteral(t *testing.T) { lit := typ.LiteralInt(42) - result := 
WidenParamHintType(lit) + result := WidenType(lit) if result != typ.Integer { t.Errorf("expected Integer, got %v", result) } } -func TestWidenParamHintType_NumberLiteral(t *testing.T) { +func TestWidenType_NumberLiteral(t *testing.T) { lit := typ.LiteralNumber(3.14) - result := WidenParamHintType(lit) + result := WidenType(lit) if result != typ.Number { t.Errorf("expected Number, got %v", result) } } -func TestWidenParamHintType_StringLiteral(t *testing.T) { +func TestWidenType_StringLiteral(t *testing.T) { lit := typ.LiteralString("hello") - result := WidenParamHintType(lit) + result := WidenType(lit) if result != typ.String { t.Errorf("expected String, got %v", result) } } -func TestWidenParamHintType_NonLiteral(t *testing.T) { - result := WidenParamHintType(typ.String) +func TestWidenType_NonLiteral(t *testing.T) { + result := WidenType(typ.String) if result != typ.String { t.Errorf("expected String unchanged, got %v", result) } } -func TestWidenParamHintType_Alias(t *testing.T) { +func TestWidenType_Alias(t *testing.T) { alias := typ.NewAlias("NumAlias", typ.Number) - result := WidenParamHintType(alias) + result := WidenType(alias) if result != typ.Number { t.Errorf("expected alias to widen to Number, got %v", result) } } -func TestWidenParamHintType_Optional(t *testing.T) { +func TestWidenType_Optional(t *testing.T) { lit := typ.LiteralString("hello") opt := typ.NewOptional(lit) - result := WidenParamHintType(opt) + result := WidenType(opt) if result == nil { t.Fatal("expected non-nil result") } @@ -78,45 +78,45 @@ func TestWidenParamHintType_Optional(t *testing.T) { } } -func TestWidenParamHintType_Union(t *testing.T) { +func TestWidenType_Union(t *testing.T) { lit1 := typ.LiteralString("a") lit2 := typ.LiteralNumber(1.0) union := typ.NewUnion(lit1, lit2) - result := WidenParamHintType(union) + result := WidenType(union) if result == nil { t.Fatal("expected non-nil result") } } -func TestNormalizeHintType_TableTopAbsorbsPreciseTableMembers(t *testing.T) { 
+func TestNormalizeType_TableTopAbsorbsPreciseTableMembers(t *testing.T) { tableTop := typ.NewInterface("table", nil) preciseA := typ.NewRecord(). Field("name", typ.String). Field("tools", typ.NewArray(typ.String)). Build() preciseB := typ.NewMap(typ.String, typ.Integer) - hint := typ.NewUnion(typ.NewOptional(tableTop), preciseA, preciseB, typ.String) + evidence := typ.NewUnion(typ.NewOptional(tableTop), preciseA, preciseB, typ.String) - got := NormalizeHintType(hint) + got := NormalizeType(evidence) want := typ.NewUnion(typ.NewOptional(tableTop), typ.String) if !typ.TypeEquals(got, want) { t.Fatalf("expected table top to absorb precise table members as %v, got %v", want, got) } } -func TestWidenParamHintType_RecordPreservesClosedShape(t *testing.T) { +func TestWidenType_RecordPreservesClosedShape(t *testing.T) { rec := typ.NewRecord(). Field("pid", typ.LiteralString("abc")). Field("topic", typ.LiteralString("test:update")). Build() - result := WidenParamHintType(rec) + result := WidenType(rec) widened, ok := result.(*typ.Record) if !ok { t.Fatalf("expected record result, got %T", result) } if widened.Open { - t.Fatalf("expected param hint to preserve closed call-site shape, got open: %v", widened) + t.Fatalf("expected parameter evidence to preserve closed call-site shape, got open: %v", widened) } pid := widened.GetField("pid") @@ -129,8 +129,8 @@ func TestWidenParamHintType_RecordPreservesClosedShape(t *testing.T) { } } -func TestBuildParamHintSignatures_NilInputs(t *testing.T) { - result := BuildParamHintSignatures(nil, nil, nil, nil) +func TestBuildSignatureMap_NilInputs(t *testing.T) { + result := BuildSignatureMap(nil, nil, nil, nil) if result != nil { t.Errorf("expected nil for nil inputs, got %v", result) } @@ -149,10 +149,10 @@ func TestMergeIntoSignature_ImplicitSelfUsesEffectiveHintSlots(t *testing.T) { t.Fatalf("unexpected merged signature: %v", got) } if !typ.TypeEquals(got.Params[0].Type, selfType) { - t.Fatalf("self hint should use effective slot 0, 
got %v", got.Params[0].Type) + t.Fatalf("self evidence should use effective slot 0, got %v", got.Params[0].Type) } if !typ.TypeEquals(got.Params[1].Type, typ.String) { - t.Fatalf("source parameter hint should use effective slot 1, got %v", got.Params[1].Type) + t.Fatalf("source parameter evidence should use effective slot 1, got %v", got.Params[1].Type) } } @@ -177,7 +177,7 @@ func TestMergeIntoSignature_PreservesExplicitNilabilityOnOptionalSlot(t *testing } } -func TestProjectHintsToParamUse_KeepsDemandedRecordFields(t *testing.T) { +func TestProjectToParameterUse_KeepsDemandedRecordFields(t *testing.T) { fn := functionWithParams("client", "model_id") fn.Stmts = []ast.Stmt{ &ast.ReturnStmt{Exprs: []ast.Expr{ @@ -202,25 +202,25 @@ func TestProjectHintsToParamUse_KeepsDemandedRecordFields(t *testing.T) { Field("_credentials", typ.String). Build() - got := ProjectHintsToParamUse(graph, fn, []typ.Type{client, typ.String}) + got := ProjectToParameterUse(graph, fn, []typ.Type{client, typ.String}) rec, ok := got[0].(*typ.Record) if !ok { - t.Fatalf("projected client hint = %T, want record (%v)", got[0], got[0]) + t.Fatalf("projected client evidence = %T, want record (%v)", got[0], got[0]) } if rec.GetField("invoke") == nil { - t.Fatalf("projected client hint lost demanded invoke field: %v", rec) + t.Fatalf("projected client evidence lost demanded invoke field: %v", rec) } for _, unused := range []string{"process_converse_stream", "_credentials"} { if rec.GetField(unused) != nil { - t.Fatalf("projected client hint kept unused field %q: %v", unused, rec) + t.Fatalf("projected client evidence kept unused field %q: %v", unused, rec) } } if !typ.TypeEquals(got[1], typ.String) { - t.Fatalf("directly used scalar hint should stay intact, got %v", got[1]) + t.Fatalf("directly used scalar evidence should stay intact, got %v", got[1]) } } -func TestProjectHintsToParamUse_KeepsDemandedAbsentRecordFieldsAsNil(t *testing.T) { +func 
TestProjectToParameterUse_KeepsDemandedAbsentRecordFieldsAsNil(t *testing.T) { fn := functionWithParams("options") fn.Stmts = []ast.Stmt{ &ast.IfStmt{ @@ -243,14 +243,14 @@ func TestProjectHintsToParamUse_KeepsDemandedAbsentRecordFieldsAsNil(t *testing. }, } graph := cfg.Build(fn) - hint := typ.NewRecord(). + evidence := typ.NewRecord(). Field("headers", typ.NewRecord().Build()). Build() - got := ProjectHintsToParamUse(graph, fn, []typ.Type{hint}) + got := ProjectToParameterUse(graph, fn, []typ.Type{evidence}) rec, ok := got[0].(*typ.Record) if !ok { - t.Fatalf("projected options hint = %T, want record (%v)", got[0], got[0]) + t.Fatalf("projected options evidence = %T, want record (%v)", got[0], got[0]) } stream := rec.GetField("stream") if stream == nil || !typ.TypeEquals(stream.Type, typ.Nil) { @@ -258,7 +258,7 @@ func TestProjectHintsToParamUse_KeepsDemandedAbsentRecordFieldsAsNil(t *testing. } headers := rec.GetField("headers") if headers == nil { - t.Fatalf("projected options hint lost demanded headers field: %v", rec) + t.Fatalf("projected options evidence lost demanded headers field: %v", rec) } } @@ -312,7 +312,100 @@ func TestProjectSignatureToParamUse_CompletesDemandedAbsentFields(t *testing.T) } } -func TestProjectHintsToParamUse_DedupsUnionAfterProjection(t *testing.T) { +func TestProjectToParameterUse_TypeGuardDoesNotKeepWholeRecord(t *testing.T) { + fn := functionWithParams("params") + fn.Stmts = []ast.Stmt{ + &ast.IfStmt{ + Condition: &ast.RelationalOpExpr{ + Operator: "~=", + Lhs: &ast.FuncCallExpr{ + Func: &ast.IdentExpr{Value: "type"}, + Args: []ast.Expr{&ast.IdentExpr{Value: "params"}}, + }, + Rhs: &ast.StringExpr{Value: "table"}, + }, + }, + &ast.IfStmt{ + Condition: &ast.UnaryNotOpExpr{ + Expr: &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "params"}, + Key: &ast.StringExpr{Value: "agent"}, + }, + }, + }, + &ast.ReturnStmt{Exprs: []ast.Expr{ + &ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "params"}, + Key: &ast.StringExpr{Value: "kind"}, 
+ }, + }}, + } + graph := cfg.Build(fn) + evidence := typ.NewRecord(). + OptField("kind", typ.String). + Build() + + got := ProjectToParameterUse(graph, fn, []typ.Type{evidence}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected params evidence = %T, want record (%v)", got[0], got[0]) + } + if rec.GetField("kind") == nil { + t.Fatalf("projected params evidence lost demanded kind field: %v", rec) + } + agent := rec.GetField("agent") + if agent == nil || !typ.TypeEquals(agent.Type, typ.Nil) { + t.Fatalf("type(params) should not force whole-record evidence; agent = %v in %v", agent, rec) + } +} + +func TestProjectToParameterUse_SelfDefaultDoesNotKeepWholeRecord(t *testing.T) { + fn := functionWithParams("options") + optionsIdent := &ast.IdentExpr{Value: "options"} + fn.Stmts = []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{optionsIdent}, + Rhs: []ast.Expr{&ast.LogicalOpExpr{ + Operator: "or", + Lhs: &ast.IdentExpr{Value: "options"}, + Rhs: &ast.TableExpr{}, + }}, + }, + &ast.LocalAssignStmt{ + Names: []string{"method"}, + Exprs: []ast.Expr{&ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "options"}, + Key: &ast.StringExpr{Value: "method"}, + }}, + }, + &ast.LocalAssignStmt{ + Names: []string{"timeout"}, + Exprs: []ast.Expr{&ast.AttrGetExpr{ + Object: &ast.IdentExpr{Value: "options"}, + Key: &ast.StringExpr{Value: "timeout"}, + }}, + }, + } + graph := cfg.Build(fn) + evidence := typ.NewRecord(). + OptField("method", typ.String). 
+ Build() + + got := ProjectToParameterUse(graph, fn, []typ.Type{evidence}) + rec, ok := got[0].(*typ.Record) + if !ok { + t.Fatalf("projected options evidence = %T, want record (%v)", got[0], got[0]) + } + if rec.GetField("method") == nil { + t.Fatalf("projected options evidence lost demanded method field: %v", rec) + } + timeout := rec.GetField("timeout") + if timeout == nil || !typ.TypeEquals(timeout.Type, typ.Nil) { + t.Fatalf("self-default assignment should not force whole-record evidence; timeout = %v in %v", timeout, rec) + } +} + +func TestProjectToParameterUse_DedupsUnionAfterProjection(t *testing.T) { fn := functionWithParams("client") fn.Stmts = []ast.Stmt{ &ast.ReturnStmt{Exprs: []ast.Expr{ @@ -335,17 +428,17 @@ func TestProjectHintsToParamUse_DedupsUnionAfterProjection(t *testing.T) { Field("stream", typ.Func().Returns(typ.LiteralString("invalid")).Build()). Build() - got := ProjectHintsToParamUse(graph, fn, []typ.Type{typ.NewUnion(broad, narrow)}) + got := ProjectToParameterUse(graph, fn, []typ.Type{typ.NewUnion(broad, narrow)}) rec, ok := got[0].(*typ.Record) if !ok { - t.Fatalf("projected union hint = %T, want coalesced record (%v)", got[0], got[0]) + t.Fatalf("projected union evidence = %T, want coalesced record (%v)", got[0], got[0]) } if rec.GetField("invoke") == nil || rec.GetField("stream") != nil { t.Fatalf("projected union should keep only invoke, got %v", rec) } } -func TestProjectHintsToParamUse_WholeParameterUseKeepsHint(t *testing.T) { +func TestProjectToParameterUse_WholeParameterUseKeepsEvidence(t *testing.T) { fn := functionWithParams("client") fn.Stmts = []ast.Stmt{ &ast.ReturnStmt{Exprs: []ast.Expr{ @@ -358,13 +451,13 @@ func TestProjectHintsToParamUse_WholeParameterUseKeepsHint(t *testing.T) { graph := cfg.Build(fn, "use_client") client := typ.NewRecord().Field("invoke", typ.Func().Returns(typ.Unknown).Build()).Build() - got := ProjectHintsToParamUse(graph, fn, []typ.Type{client}) + got := ProjectToParameterUse(graph, fn, 
[]typ.Type{client}) if !typ.TypeEquals(got[0], client) { - t.Fatalf("whole-parameter use should keep full hint, got %v", got[0]) + t.Fatalf("whole-parameter use should keep full evidence, got %v", got[0]) } } -func TestProjectHintsToParamUse_RecursiveForwardingDoesNotKeepWholeHint(t *testing.T) { +func TestProjectToParameterUse_RecursiveForwardingDoesNotKeepWholeEvidence(t *testing.T) { recursiveIdent := &ast.IdentExpr{Value: "visit"} selfIdent := &ast.IdentExpr{Value: "self"} valueIdent := &ast.IdentExpr{Value: "value"} @@ -401,18 +494,18 @@ func TestProjectHintsToParamUse_RecursiveForwardingDoesNotKeepWholeHint(t *testi if sym, ok := graph.Bindings().SymbolOf(recursiveIdent); ok { graph.Bindings().SetFuncLitSymbol(fn, sym) } - selfHint := typ.NewRecord(). + selfEvidence := typ.NewRecord(). Field("id", typ.String). Field("command", typ.Func().Returns(typ.Nil, typ.String).Build()). Build() - got := ProjectHintsToParamUse(graph, fn, []typ.Type{selfHint, typ.NewRecord().Field("next", typ.Any).Build()}) + got := ProjectToParameterUse(graph, fn, []typ.Type{selfEvidence, typ.NewRecord().Field("next", typ.Any).Build()}) rec, ok := got[0].(*typ.Record) if !ok { - t.Fatalf("projected self hint = %T, want record (%v)", got[0], got[0]) + t.Fatalf("projected self evidence = %T, want record (%v)", got[0], got[0]) } if rec.GetField("id") == nil { - t.Fatalf("projected self hint lost demanded id field: %v", rec) + t.Fatalf("projected self evidence lost demanded id field: %v", rec) } if rec.GetField("command") != nil { t.Fatalf("recursive forwarding should not keep unused command field: %v", rec) @@ -423,7 +516,7 @@ func functionWithParams(names ...string) *ast.FunctionExpr { return &ast.FunctionExpr{ParList: &ast.ParList{Names: names}} } -func TestIsInformativeHintType(t *testing.T) { +func TestIsInformative(t *testing.T) { tests := []struct { name string in typ.Type @@ -448,31 +541,31 @@ func TestIsInformativeHintType(t *testing.T) { } for _, tt := range tests { - if got := 
IsInformativeHintType(tt.in); got != tt.want { + if got := IsInformative(tt.in); got != tt.want { t.Errorf("%s: got %v, want %v", tt.name, got, tt.want) } } } -func TestEnsureHintCapacity(t *testing.T) { +func TestEnsureCapacity(t *testing.T) { base := []typ.Type{typ.String} - got := EnsureHintCapacity(base, 3) + got := EnsureCapacity(base, 3) if len(got) != 3 { - t.Fatalf("EnsureHintCapacity len = %d, want 3", len(got)) + t.Fatalf("EnsureCapacity len = %d, want 3", len(got)) } if got[0] != typ.String { - t.Fatalf("EnsureHintCapacity preserved value = %v, want string", got[0]) + t.Fatalf("EnsureCapacity preserved value = %v, want string", got[0]) } } -func TestMergeHintAt(t *testing.T) { +func TestMergeAt(t *testing.T) { join := func(prev, next typ.Type) typ.Type { return typ.JoinPreferNonSoft(prev, next) } t.Run("filters non-informative", func(t *testing.T) { - hints := []typ.Type{typ.String} - got, changed := MergeHintAt(hints, 1, typ.Unknown, join) + evidence := []typ.Type{typ.String} + got, changed := MergeAt(evidence, 1, typ.Unknown, join) if changed { - t.Fatal("expected no change for unknown hint") + t.Fatal("expected no change for unknown evidence") } if len(got) != 1 { t.Fatalf("expected unchanged slice len 1, got %d", len(got)) @@ -480,15 +573,15 @@ func TestMergeHintAt(t *testing.T) { }) t.Run("normalizes literal and merges", func(t *testing.T) { - got, changed := MergeHintAt(nil, 0, typ.LiteralString("x"), join) + got, changed := MergeAt(nil, 0, typ.LiteralString("x"), join) if !changed { t.Fatal("expected merge change for informative literal") } if len(got) != 1 { - t.Fatalf("expected one hint, got %d", len(got)) + t.Fatalf("expected one evidence, got %d", len(got)) } if !typ.TypeEquals(got[0], typ.String) { - t.Fatalf("expected normalized string hint, got %v", got[0]) + t.Fatalf("expected normalized string evidence, got %v", got[0]) } }) } diff --git a/compiler/check/infer/paramhints/project.go b/compiler/check/infer/paramevidence/project.go 
similarity index 84% rename from compiler/check/infer/paramhints/project.go rename to compiler/check/infer/paramevidence/project.go index 6bfc14ff..8a21ea3f 100644 --- a/compiler/check/infer/paramhints/project.go +++ b/compiler/check/infer/paramevidence/project.go @@ -1,4 +1,4 @@ -package paramhints +package paramevidence import ( "sort" @@ -17,48 +17,48 @@ type paramUse struct { fields map[string]struct{} } -// ProjectHintsToParamUse trims structured call-site hints to the surface the -// function body actually reads from each unannotated parameter. Hints are -// evidence for analyzing a helper, not a promise that every unused field on the +// ProjectToParameterUse trims structured call-site evidence to the surface the +// function body actually reads from each unannotated parameter. It is evidence +// for analyzing a helper, not a promise that every unused field on the // first argument shape is part of that helper's public contract. -func ProjectHintsToParamUse(graph *cfg.Graph, fn *ast.FunctionExpr, hints []typ.Type) []typ.Type { - if graph == nil || fn == nil || len(hints) == 0 { - return hints +func ProjectToParameterUse(graph *cfg.Graph, fn *ast.FunctionExpr, vec []typ.Type) []typ.Type { + if graph == nil || fn == nil || len(vec) == 0 { + return vec } uses := collectParamUses(graph, fn) if len(uses) == 0 { - return hints + return vec } var out []typ.Type for idx, slot := range graph.ParamSlotsReadOnly() { - if slot.Symbol == 0 || idx < 0 || idx >= len(hints) { + if slot.Symbol == 0 || idx < 0 || idx >= len(vec) { continue } - hint := hints[idx] - if hint == nil { + observed := vec[idx] + if observed == nil { continue } - projected := projectHintToUse(hint, uses[slot.Symbol]) - if typ.TypeEquals(hint, projected) { + projected := projectEvidenceToUse(observed, uses[slot.Symbol]) + if typ.TypeEquals(observed, projected) { continue } if out == nil { - out = make([]typ.Type, len(hints)) - copy(out, hints) + out = make([]typ.Type, len(vec)) + copy(out, vec) } 
out[idx] = projected } if out == nil { - return hints + return vec } return out } // ProjectSignatureToParamUse completes a function signature's parameter slots -// against the fields the function body reads. Unlike ProjectHintsToParamUse it +// against the fields the function body reads. Unlike ProjectToParameterUse it // does not trim unused fields: a function fact is already a canonical signature // observation, and same-body analysis only needs to ensure demanded fields are // present even when the parameter is also used as a whole value. @@ -250,10 +250,23 @@ type paramUseCollector struct { func (c *paramUseCollector) stmt(stmt ast.Stmt) { switch s := stmt.(type) { case *ast.AssignStmt: - for _, lhs := range s.Lhs { + var skipRHS map[int]struct{} + for i, lhs := range s.Lhs { + if i < len(s.Rhs) && c.isParamSelfDefault(lhs, s.Rhs[i]) { + if skipRHS == nil { + skipRHS = make(map[int]struct{}, 1) + } + skipRHS[i] = struct{}{} + continue + } c.lvalue(lhs) } - for _, rhs := range s.Rhs { + for i, rhs := range s.Rhs { + if skipRHS != nil { + if _, skip := skipRHS[i]; skip { + continue + } + } c.expr(rhs) } case *ast.LocalAssignStmt: @@ -410,10 +423,30 @@ func (c *paramUseCollector) call(call *ast.FuncCallExpr) { if recursive && c.isParamExpr(arg) { continue } + if c.isBuiltinTypeCall(call) && c.isParamExpr(arg) { + continue + } c.expr(arg) } } +func (c *paramUseCollector) isBuiltinTypeCall(call *ast.FuncCallExpr) bool { + if call == nil || call.Method != "" { + return false + } + ident, ok := call.Func.(*ast.IdentExpr) + if !ok || ident == nil || ident.Value != "type" { + return false + } + if c.bindings != nil { + if sym, ok := c.bindings.SymbolOf(ident); ok && sym != 0 { + kind, hasKind := c.bindings.Kind(sym) + return hasKind && kind == cfg.SymbolGlobal + } + } + return true +} + func (c *paramUseCollector) isDirectRecursiveCall(call *ast.FuncCallExpr) bool { if call == nil || call.Method != "" || len(c.currentFunctionSymbols) == 0 { return false @@ -428,6 
+461,8 @@ func (c *paramUseCollector) isDirectRecursiveCall(call *ast.FuncCallExpr) bool { func (c *paramUseCollector) lvalue(expr ast.Expr) { switch e := expr.(type) { + case *ast.IdentExpr: + return case *ast.AttrGetExpr: if c.pathUse(expr) { return @@ -439,6 +474,23 @@ func (c *paramUseCollector) lvalue(expr ast.Expr) { } } +func (c *paramUseCollector) isParamSelfDefault(lhs, rhs ast.Expr) bool { + lhsIdent, ok := lhs.(*ast.IdentExpr) + if !ok || !c.isParamIdent(lhsIdent) { + return false + } + op, ok := rhs.(*ast.LogicalOpExpr) + if !ok || op.Operator != "or" { + return false + } + rhsIdent, ok := op.Lhs.(*ast.IdentExpr) + if !ok || !c.sameParamIdent(lhsIdent, rhsIdent) { + return false + } + _, ok = op.Rhs.(*ast.TableExpr) + return ok +} + func (c *paramUseCollector) pathUse(expr ast.Expr) bool { p := flowpath.FromExprWithBindings(expr, nil, c.bindings) if !c.isParamPath(p) { @@ -487,6 +539,19 @@ func (c *paramUseCollector) isParamIdent(ident *ast.IdentExpr) bool { return ok } +func (c *paramUseCollector) sameParamIdent(a, b *ast.IdentExpr) bool { + if c.bindings == nil || a == nil || b == nil { + return false + } + asym, aok := c.bindings.SymbolOf(a) + bsym, bok := c.bindings.SymbolOf(b) + if !aok || !bok || asym == 0 || bsym == 0 || asym != bsym { + return false + } + _, ok := c.paramSymbols[asym] + return ok +} + func isNilLiteral(expr ast.Expr) bool { _, ok := expr.(*ast.NilExpr) return ok @@ -561,16 +626,16 @@ func segmentFieldName(seg constraint.Segment) string { } } -func projectHintToUse(hint typ.Type, use paramUse) typ.Type { - if hint == nil || use.whole { - return hint +func projectEvidenceToUse(observed typ.Type, use paramUse) typ.Type { + if observed == nil || use.whole { + return observed } if len(use.fields) == 0 { return nil } - projected, ok := projectTypeToFields(hint, use.fields) + projected, ok := projectTypeToFields(observed, use.fields) if !ok { - return hint + return observed } return projected } diff --git 
a/compiler/check/infer/paramhints/doc.go b/compiler/check/infer/paramhints/doc.go deleted file mode 100644 index b8a27c21..00000000 --- a/compiler/check/infer/paramhints/doc.go +++ /dev/null @@ -1,29 +0,0 @@ -// Package paramhints infers parameter types from call site arguments. -// -// This package analyzes function call sites to infer parameter types for -// functions without explicit type annotations. When a function is called -// with known-type arguments, those types hint at the parameter types. -// -// # Hint Collection -// -// For each call site: -// -// foo(123, "bar") -- hints: param1=number, param2=string -// -// The package collects argument types and associates them with parameter -// positions. Multiple call sites contribute hints that are joined. -// -// # Hint Merging -// -// When multiple calls provide conflicting hints: -// -// foo(1) -- hints: param1=number -// foo("a") -- hints: param1=string -// -// The hints are joined to produce: param1 = number | string -// -// # Integration -// -// Parameter hints feed into function signature inference, providing -// types for parameters that lack explicit annotations. -package paramhints diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index 02bec4c8..15e6bae1 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -22,12 +22,12 @@ // - Types can only grow (become more general), never shrink // - Bounded iteration with widening to unknown on non-convergence // -// # PARAM HINTS +// # PARAMETER EVIDENCE // -// Parameter type hints are collected from call sites: -// - When a() calls b(10), b's first param gets hint "number" -// - Hints from multiple call sites are joined -// - Hints propagate through the call graph (if a() calls b(), b() calls c()) +// Parameter evidence is collected from call sites: +// - When a() calls b(10), b's first param records number evidence. +// - Multiple call sites are joined. 
+// - Evidence propagates through the call graph (if a() calls b(), b() calls c()). // // # SEED PROPAGATION // @@ -44,7 +44,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" @@ -220,15 +220,15 @@ func (i *Inferencer) ComputeForGraph( return nil, nil } - // Apply param hints from the stable snapshot (deterministic order). - if hints := i.store.GetParamHintsSnapshot(graph, parentScope); len(hints) > 0 { + // Apply parameter evidence from the stable canonical function facts. + if facts := i.store.GetFunctionFactsSnapshot(graph, parentScope); len(facts) > 0 { for _, sym := range cfg.SortedSymbolIDs(localFuncs) { info := localFuncs[sym] if info == nil { continue } - if hintVec, ok := hints[sym]; ok && len(hintVec) > 0 { - info.ParamHints = paramhints.ProjectHintsToParamUse(info.Graph, info.Fn, hintVec) + if hintVec := facts.Params(sym); len(hintVec) > 0 { + info.ParameterEvidence = paramevidence.ProjectToParameterUse(info.Graph, info.Fn, hintVec) } } } @@ -242,7 +242,7 @@ func (i *Inferencer) ComputeForGraph( } returnVectors, diags := i.computeReturnVectorsForGroup(run, parentScope.GroupHash(), localFuncs, seed) functionTypes := i.buildLocalFunctionTypes(localFuncs, returnVectors, engine, parentScope) - return assembleFunctionFacts(returnVectors, functionTypes), diags + return assembleFunctionFacts(localFuncs, returnVectors, functionTypes), diags } func (i *Inferencer) buildLocalFunctionTypes( @@ -279,8 +279,8 @@ func (i *Inferencer) buildLocalFunctionTypes( if fnType == nil { continue } - if len(info.ParamHints) > 0 { - if merged := 
paramhints.MergeIntoSignature(info.Fn, info.ParamHints, fnType); merged != nil { + if len(info.ParameterEvidence) > 0 { + if merged := paramevidence.MergeIntoSignature(info.Fn, info.ParameterEvidence, fnType); merged != nil { fnType = merged } } @@ -298,14 +298,20 @@ func (i *Inferencer) buildLocalFunctionTypes( } func assembleFunctionFacts( + localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, returnVectors map[cfg.SymbolID][]typ.Type, funcs map[cfg.SymbolID]typ.Type, ) api.FunctionFacts { - total := len(returnVectors) + len(funcs) + total := len(localFuncs) + len(returnVectors) + len(funcs) if total == 0 { return nil } symbols := make(map[cfg.SymbolID]bool, total) + for sym := range localFuncs { + if sym != 0 { + symbols[sym] = true + } + } for sym := range returnVectors { if sym != 0 { symbols[sym] = true @@ -321,11 +327,16 @@ func assembleFunctionFacts( } out := make(api.FunctionFacts, len(symbols)) for _, sym := range cfg.SortedSymbolIDs(symbols) { + var params []typ.Type + if info := localFuncs[sym]; info != nil { + params = info.ParameterEvidence + } ff := returns.JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ + Params: params, Summary: returnVectors[sym], Type: funcs[sym], }) - if len(ff.Summary) == 0 && ff.Type == nil && len(ff.Narrow) == 0 { + if len(ff.Params) == 0 && len(ff.Summary) == 0 && ff.Type == nil && len(ff.Narrow) == 0 { continue } out[sym] = ff @@ -388,7 +399,7 @@ type returnInferenceContext struct { } // buildParameterOverlay creates the initial type overlay with parameter types. -// Parameters are typed from annotations, hints, or default to unknown. +// Parameters are typed from annotations, parameter evidence, or default to unknown. func collectReturnTypes( fnGraph *cfg.Graph, synthEngine api.Synth, @@ -540,7 +551,7 @@ func (i *Inferencer) skipUnresolvedLocalReturnCall(ctx *returnInferenceContext) // // Phase 1 (Preliminary): Collect inferred types for local variables within the function. 
// This uses a preliminary synthesis engine with: -// - Parameter types (from annotations or param hints) +// - Parameter types (from annotations or parameter evidence) // - Sibling function types (from return vectors) // - Captured variable types (from parent function result) // @@ -632,30 +643,30 @@ func (i *Inferencer) inferReturnForFunction( // in the same function. For example, a helper call may prove that a parameter // field is string?, which then makes `param.field or "default"` synthesize as // string without relying on a value-level fallback shortcut. - i.mergeParamHintsFromBodyUses(ctx, overlay) - i.applyParamHintsToOverlay(ctx, overlay) + i.mergeParameterEvidenceFromBodyUses(ctx, overlay) + i.applyParameterEvidenceToOverlay(ctx, overlay) // Phase 1: Infer local variable types. inferred, _, synthAdapter := i.inferLocalVariableTypes(ctx, overlay, localValueSeeds) // Collect field/indexer assignments and apply mutations. finalOverlay := i.collectAndApplyMutations(ctx, overlay, inferred, synthAdapter, localValueSeeds) - i.mergeParamHintsFromOverlay(ctx, finalOverlay) + i.mergeParameterEvidenceFromOverlay(ctx, finalOverlay) // Phase 2: Infer return types from body. 
return i.inferReturnTypesFromBody(ctx, finalOverlay) } -func (i *Inferencer) applyParamHintsToOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { - if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || len(ctx.info.ParamHints) == 0 || overlay == nil { +func (i *Inferencer) applyParameterEvidenceToOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { + if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || len(ctx.info.ParameterEvidence) == 0 || overlay == nil { return } for idx, slot := range ctx.info.Graph.ParamSlotsReadOnly() { - if slot.Symbol == 0 || idx >= len(ctx.info.ParamHints) { + if slot.Symbol == 0 || idx >= len(ctx.info.ParameterEvidence) { continue } - hint := ctx.info.ParamHints[idx] - if !paramhints.IsInformativeHintType(hint) { + evidence := ctx.info.ParameterEvidence[idx] + if !paramevidence.IsInformative(evidence) { continue } if slot.TypeAnnotation != nil && ctx.engine != nil { @@ -664,11 +675,11 @@ func (i *Inferencer) applyParamHintsToOverlay(ctx *returnInferenceContext, overl continue } } - overlay[slot.Symbol] = hint + overlay[slot.Symbol] = evidence } } -func (i *Inferencer) mergeParamHintsFromOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { +func (i *Inferencer) mergeParameterEvidenceFromOverlay(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { if ctx == nil || ctx.info == nil || ctx.info.Graph == nil || ctx.info.Fn == nil || len(overlay) == 0 { return } @@ -684,18 +695,18 @@ func (i *Inferencer) mergeParamHintsFromOverlay(ctx *returnInferenceContext, ove } } t := overlay[slot.Symbol] - if !paramhints.IsInformativeHintType(t) { + if !paramevidence.IsInformative(t) { continue } - next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, t, typ.JoinPreferNonSoft) + next, merged := paramevidence.MergeAt(ctx.info.ParameterEvidence, idx, t, typ.JoinPreferNonSoft) if merged { - ctx.info.ParamHints = next + ctx.info.ParameterEvidence = next } } - 
i.mergeParamHintsFromBodyUses(ctx, overlay) + i.mergeParameterEvidenceFromBodyUses(ctx, overlay) } -func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { +func (i *Inferencer) mergeParameterEvidenceFromBodyUses(ctx *returnInferenceContext, overlay map[cfg.SymbolID]typ.Type) { if i == nil || ctx == nil || ctx.info == nil || ctx.info.Graph == nil || ctx.info.Fn == nil { return } @@ -739,17 +750,17 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov if !ok { return } - hint := i.receiverHintForMethod(ctx, method) - if !paramhints.IsInformativeHintType(hint) { + evidence := i.receiverEvidenceForMethod(ctx, method) + if !paramevidence.IsInformative(evidence) { return } - next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, hint, typ.JoinPreferNonSoft) + next, merged := paramevidence.MergeAt(ctx.info.ParameterEvidence, idx, evidence, typ.JoinPreferNonSoft) if merged { - ctx.info.ParamHints = next + ctx.info.ParameterEvidence = next } } - mergeParamFieldHint := func(sym cfg.SymbolID, field string, hint typ.Type, required bool) { - if sym == 0 || field == "" || !paramhints.IsInformativeHintType(hint) { + mergeParamFieldEvidence := func(sym cfg.SymbolID, field string, evidence typ.Type, required bool) { + if sym == 0 || field == "" || !paramevidence.IsInformative(evidence) { return } idx, ok := paramIndexBySym[sym] @@ -758,14 +769,14 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov } builder := typ.NewRecord() if required { - builder.Field(field, hint) + builder.Field(field, evidence) } else { - builder.OptField(field, hint) + builder.OptField(field, evidence) } rec := builder.Build() - next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, rec, typ.JoinPreferNonSoft) + next, merged := paramevidence.MergeAt(ctx.info.ParameterEvidence, idx, rec, typ.JoinPreferNonSoft) if merged { - ctx.info.ParamHints = next + 
ctx.info.ParameterEvidence = next } } bodyContractJoin := func(prev, next typ.Type) typ.Type { @@ -774,17 +785,17 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov } return prev } - mergeParamHint := func(sym cfg.SymbolID, hint typ.Type) { - if sym == 0 || !paramhints.IsInformativeHintType(hint) { + mergeParameterEvidence := func(sym cfg.SymbolID, evidence typ.Type) { + if sym == 0 || !paramevidence.IsInformative(evidence) { return } idx, ok := paramIndexBySym[sym] if !ok { return } - next, merged := paramhints.MergeHintAt(ctx.info.ParamHints, idx, hint, bodyContractJoin) + next, merged := paramevidence.MergeAt(ctx.info.ParameterEvidence, idx, evidence, bodyContractJoin) if merged { - ctx.info.ParamHints = next + ctx.info.ParameterEvidence = next } } paramSymbol := func(expr ast.Expr) (cfg.SymbolID, bool) { @@ -849,20 +860,20 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov return false } var bodyParamContracts map[cfg.SymbolID]typ.Type - mergeParamContract := func(sym cfg.SymbolID, hint typ.Type) { - if sym == 0 || !paramhints.IsInformativeHintType(hint) { + mergeParamContract := func(sym cfg.SymbolID, evidence typ.Type) { + if sym == 0 || !paramevidence.IsInformative(evidence) { return } if bodyParamContracts == nil { bodyParamContracts = make(map[cfg.SymbolID]typ.Type) } if prev := bodyParamContracts[sym]; prev != nil { - bodyParamContracts[sym] = subtype.NormalizeIntersection(prev, hint) + bodyParamContracts[sym] = subtype.NormalizeIntersection(prev, evidence) return } - bodyParamContracts[sym] = hint + bodyParamContracts[sym] = evidence } - mergeCallExpectedFieldHints := func(p cfg.Point, info *cfg.CallInfo) { + mergeExpectedFieldEvidence := func(p cfg.Point, info *cfg.CallInfo) { if info == nil || i.types == nil { return } @@ -888,7 +899,7 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov inferredCall := ops.InferCall(ctx.run.Ctx, def) for idx, arg := range 
info.Args { expected := inferredCall.ExpectedArgType(idx) - if !paramhints.IsInformativeHintType(expected) { + if !paramevidence.IsInformative(expected) { continue } if sym, ok := paramSymbol(arg); ok { @@ -896,16 +907,16 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov continue } if sym, field, ok := paramFieldPath(arg); ok { - mergeParamFieldHint(sym, field, expected, true) + mergeParamFieldEvidence(sym, field, expected, true) continue } } } ctx.info.Graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { - mergeCallExpectedFieldHints(p, info) + mergeExpectedFieldEvidence(p, info) }) for _, sym := range cfg.SortedSymbolIDs(bodyParamContracts) { - mergeParamHint(sym, bodyParamContracts[sym]) + mergeParameterEvidence(sym, bodyParamContracts[sym]) } defaultLiteralType := func(expr ast.Expr) typ.Type { switch expr.(type) { @@ -942,7 +953,7 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov case *ast.LogicalOpExpr: if e.Operator == "or" { if sym, field, ok := paramFieldPath(e.Lhs); ok { - mergeParamFieldHint(sym, field, defaultLiteralType(e.Rhs), false) + mergeParamFieldEvidence(sym, field, defaultLiteralType(e.Rhs), false) } } visitExpr(e.Lhs) @@ -1039,7 +1050,7 @@ func (i *Inferencer) mergeParamHintsFromBodyUses(ctx *returnInferenceContext, ov } } -func (i *Inferencer) receiverHintForMethod(ctx *returnInferenceContext, method string) typ.Type { +func (i *Inferencer) receiverEvidenceForMethod(ctx *returnInferenceContext, method string) typ.Type { if i == nil || i.types == nil || method == "" { return nil } diff --git a/compiler/check/infer/return/overlay_pipeline.go b/compiler/check/infer/return/overlay_pipeline.go index 1c7e5dc7..b76194c6 100644 --- a/compiler/check/infer/return/overlay_pipeline.go +++ b/compiler/check/infer/return/overlay_pipeline.go @@ -9,7 +9,7 @@ import ( fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" 
"github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" @@ -35,8 +35,8 @@ func (i *Inferencer) buildParameterOverlay(ctx *returnInferenceContext) map[cfg. if !hasSource { if selfType := ctx.resolveScope.SelfType(); selfType != nil { overlay[slot.Symbol] = selfType - } else if ctx.info.ParamHints != nil && paramIdx < len(ctx.info.ParamHints) && ctx.info.ParamHints[paramIdx] != nil { - overlay[slot.Symbol] = ctx.info.ParamHints[paramIdx] + } else if ctx.info.ParameterEvidence != nil && paramIdx < len(ctx.info.ParameterEvidence) && ctx.info.ParameterEvidence[paramIdx] != nil { + overlay[slot.Symbol] = ctx.info.ParameterEvidence[paramIdx] } else { overlay[slot.Symbol] = typ.Unknown } @@ -50,8 +50,8 @@ func (i *Inferencer) buildParameterOverlay(ctx *returnInferenceContext) map[cfg. 
} } if typ.IsAbsentOrUnknown(paramType) { - if ctx.info.ParamHints != nil && paramIdx < len(ctx.info.ParamHints) && ctx.info.ParamHints[paramIdx] != nil { - paramType = ctx.info.ParamHints[paramIdx] + if ctx.info.ParameterEvidence != nil && paramIdx < len(ctx.info.ParameterEvidence) && ctx.info.ParameterEvidence[paramIdx] != nil { + paramType = ctx.info.ParameterEvidence[paramIdx] } } if typ.IsAbsentOrUnknown(paramType) && slot.TypeAnnotation == nil { @@ -130,8 +130,8 @@ func (i *Inferencer) enrichOverlayWithSiblings( } seed := returns.BuildSeedFunctionTypeWithBindings(fn, ctx.engine, ctx.resolveScope, bindings) fnType, _ := seed.(*typ.Function) - if localInfo != nil && len(localInfo.ParamHints) > 0 && fnType != nil { - return paramhints.MergeIntoSignature(fn, localInfo.ParamHints, fnType) + if localInfo != nil && len(localInfo.ParameterEvidence) > 0 && fnType != nil { + return paramevidence.MergeIntoSignature(fn, localInfo.ParameterEvidence, fnType) } return seed }, @@ -256,8 +256,8 @@ func (i *Inferencer) enrichOverlayWithLocalFunctions( } returnVector := i.resolveLocalFunctionReturns(ctx, allReturnVectors, target.Symbol) sig := ctx.engine.ResolveFunctionSignature(fnExpr, ctx.resolveScope) - if localInfo := ctx.localFuncs[target.Symbol]; localInfo != nil && len(localInfo.ParamHints) > 0 && sig != nil { - sig = paramhints.MergeIntoSignature(fnExpr, localInfo.ParamHints, sig) + if localInfo := ctx.localFuncs[target.Symbol]; localInfo != nil && len(localInfo.ParameterEvidence) > 0 && sig != nil { + sig = paramevidence.MergeIntoSignature(fnExpr, localInfo.ParameterEvidence, sig) } if fnType := returns.WithSummaryOrUnknown(sig, returnVector); fnType != nil { overlay[target.Symbol] = fnType @@ -646,7 +646,7 @@ func mergeInferredIntoOverlay( ) { for sym, inferredType := range inferred { baseType := finalOverlay[sym] - // Parameter domains are seeded from annotations/hints and must not be + // Parameter domains are seeded from annotations/evidence and must not be // 
rewritten by local variable inference artifacts. if paramSyms[sym] { if typ.IsAbsentOrUnknown(baseType) { diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go index 40e6d948..7dea8e36 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -4,7 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/typ" @@ -29,10 +29,10 @@ func (i *Inferencer) iterateSCCFixpoint( } func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo) [][]cfg.SymbolID { - // Propagate inter-procedural parameter hints across local call edges before + // Propagate inter-procedural parameter evidence across local call edges before // SCC return inference so unannotated params get stable callsite-driven seeds. - returns.PropagateParamHintsFromCallGraph(localFuncs) - projectLocalFunctionParamHints(localFuncs) + returns.PropagateParameterEvidence(localFuncs) + projectLocalParameterEvidence(localFuncs) var moduleBindings *bind.BindingTable if i != nil && i.store != nil { @@ -42,13 +42,13 @@ func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns. 
return returns.ComputeSymbolSCCs(adj) } -func projectLocalFunctionParamHints(localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo) { +func projectLocalParameterEvidence(localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo) { for _, sym := range cfg.SortedSymbolIDs(localFuncs) { info := localFuncs[sym] - if info == nil || len(info.ParamHints) == 0 { + if info == nil || len(info.ParameterEvidence) == 0 { continue } - info.ParamHints = paramhints.ProjectHintsToParamUse(info.Graph, info.Fn, info.ParamHints) + info.ParameterEvidence = paramevidence.ProjectToParameterUse(info.Graph, info.Fn, info.ParameterEvidence) } } diff --git a/compiler/check/phase/scope.go b/compiler/check/phase/scope.go index e25d94cd..a56aff9b 100644 --- a/compiler/check/phase/scope.go +++ b/compiler/check/phase/scope.go @@ -27,7 +27,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth" @@ -116,12 +116,12 @@ func RunScope(input ScopeInput) ScopeOutput { synthSig = input.SynthesizedFunctionSig } - var hints []typ.Type - if input.ParamHintSignatures != nil && input.Fn != nil { - hints = input.ParamHintSignatures[input.Fn] - hints = paramhints.ProjectHintsToParamUse(input.Graph, input.Fn, hints) + var parameterEvidence []typ.Type + if input.ParameterEvidenceSignatures != nil && input.Fn != nil { + parameterEvidence = input.ParameterEvidenceSignatures[input.Fn] + parameterEvidence = paramevidence.ProjectToParameterUse(input.Graph, input.Fn, parameterEvidence) } - paramTypes, paramAnnotated := ExtractParamTypes(input.Graph, input.Fn, typeExprResolver, synthSig, base, hints) + paramTypes, paramAnnotated := ExtractParamTypes(input.Graph, 
input.Fn, typeExprResolver, synthSig, base, parameterEvidence) // Inject synthesized self type into base scope only if base doesn't already // have a more specific self type (set by processNestedFunctions from field assignment context). @@ -192,15 +192,15 @@ func RunScope(input ScopeInput) ScopeOutput { exprSynth := func(expr ast.Expr, p cfg.Point, sc *scope.State) typ.Type { return typeResolutionEngine.SynthExprAt(expr, p, sc) } - paramHintSignatures := input.ParamHintSignatures - if input.Fn != nil && hints != nil && input.ParamHintSignatures != nil { - paramHintSignatures = make(map[*ast.FunctionExpr][]typ.Type, len(input.ParamHintSignatures)) - for fn, hintVec := range input.ParamHintSignatures { - paramHintSignatures[fn] = hintVec + parameterEvidenceSignatures := input.ParameterEvidenceSignatures + if input.Fn != nil && parameterEvidence != nil && input.ParameterEvidenceSignatures != nil { + parameterEvidenceSignatures = make(map[*ast.FunctionExpr][]typ.Type, len(input.ParameterEvidenceSignatures)) + for fn, evidence := range input.ParameterEvidenceSignatures { + parameterEvidenceSignatures[fn] = evidence } - paramHintSignatures[input.Fn] = hints + parameterEvidenceSignatures[input.Fn] = parameterEvidence } - fnSignatureResolver := buildFnSignatureResolver(input.FunctionLiteralSignatures, paramHintSignatures, typeResolutionEngine) + fnSignatureResolver := buildFnSignatureResolver(input.FunctionLiteralSignatures, parameterEvidenceSignatures, typeResolutionEngine) callMutator := buildCallMutator(input.Types, input.Ctx, exprSynth) services := ScopeServicesFuncs{ @@ -251,10 +251,10 @@ func normalizeBaseImplicitSelf(graph *cfg.Graph, base *scope.State) *scope.State } // buildFnSignatureResolver creates a function signature resolver that combines -// pre-computed literal signatures, parameter hints, and annotation-based resolution. +// pre-computed literal signatures, parameter evidence, and annotation-based resolution. 
func buildFnSignatureResolver( literalSigs LiteralSigsProvider, - paramHints map[*ast.FunctionExpr][]typ.Type, + parameterEvidence map[*ast.FunctionExpr][]typ.Type, engine *synth.Engine, ) FunctionSignatureResolver { return FunctionSignatureResolverFunc(func(fn *ast.FunctionExpr, sc *scope.State) *typ.Function { @@ -270,14 +270,14 @@ func buildFnSignatureResolver( if sig == nil { return nil } - if paramHints == nil { + if parameterEvidence == nil { return sig } - hints := paramHints[fn] - if len(hints) == 0 { + evidence := parameterEvidence[fn] + if len(evidence) == 0 { return sig } - return paramhints.MergeIntoSignature(fn, hints, sig) + return paramevidence.MergeIntoSignature(fn, evidence, sig) }) } @@ -289,7 +289,7 @@ func ExtractParamTypes( typeExprResolver TypeResolver, synthSig *typ.Function, base *scope.State, - paramHints []typ.Type, + parameterEvidence []typ.Type, ) (types map[cfg.SymbolID]typ.Type, annotated map[cfg.SymbolID]bool) { if fn == nil || fn.ParList == nil || graph == nil { return nil, nil @@ -306,17 +306,17 @@ func ExtractParamTypes( // Binder/CFG-injected implicit self parameter has no source annotation. 
_, hasSource := slot.SourceParamIndex() - var hint typ.Type - if paramHints != nil && paramIdx < len(paramHints) { - hint = paramHints[paramIdx] + var evidence typ.Type + if parameterEvidence != nil && paramIdx < len(parameterEvidence) { + evidence = parameterEvidence[paramIdx] } if !hasSource { if base != nil && base.SelfType() != nil { types[slot.Symbol] = base.SelfType() } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { types[slot.Symbol] = synthSig.Params[paramIdx].Type - } else if hint != nil { - types[slot.Symbol] = hint + } else if evidence != nil { + types[slot.Symbol] = evidence } else { types[slot.Symbol] = typ.Unknown } @@ -335,8 +335,8 @@ func ExtractParamTypes( paramType = typ.Unknown } if typ.IsRefinableAnnotation(paramType) { - if hint != nil { - paramType = hint + if evidence != nil { + paramType = evidence } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { paramType = synthSig.Params[paramIdx].Type } @@ -344,8 +344,8 @@ func ExtractParamTypes( isAnnotated = true hasExplicitAnnotation = true } - } else if hint != nil { - paramType = hint + } else if evidence != nil { + paramType = evidence } else if synthSig != nil && paramIdx < len(synthSig.Params) && synthSig.Params[paramIdx].Type != nil { paramType = synthSig.Params[paramIdx].Type } else if slot.Name == "self" && base != nil && base.SelfType() != nil { diff --git a/compiler/check/phase/types.go b/compiler/check/phase/types.go index b09c418b..41022780 100644 --- a/compiler/check/phase/types.go +++ b/compiler/check/phase/types.go @@ -173,9 +173,9 @@ type ScopeInput struct { // Read-only - populated from LiteralSigs channel during iteration. // Can be a map or LiteralSigsProvider interface for lazy lookup. FunctionLiteralSignatures LiteralSigsProvider - // ParamHintSignatures contains inferred param types from call sites. - // Read-only - populated from ParamHints channel during iteration. 
- ParamHintSignatures map[*ast.FunctionExpr][]typ.Type + // ParameterEvidenceSignatures contains function-expression keyed parameter evidence. + // Read-only - projected from canonical FunctionFacts during iteration. + ParameterEvidenceSignatures map[*ast.FunctionExpr][]typ.Type // FunctionFacts contains canonical facts for functions in this graph // context. Explicit input - not looked up from store during phase execution. FunctionFacts api.FunctionFacts diff --git a/compiler/check/phase/types_test.go b/compiler/check/phase/types_test.go index 4368178e..fcc9c227 100644 --- a/compiler/check/phase/types_test.go +++ b/compiler/check/phase/types_test.go @@ -7,7 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/typ" @@ -103,8 +103,8 @@ func TestScopeInput_Fields(t *testing.T) { if input.FunctionLiteralSignatures != nil { t.Error("FunctionLiteralSignatures should be nil by default") } - if input.ParamHintSignatures != nil { - t.Error("ParamHintSignatures should be nil by default") + if input.ParameterEvidenceSignatures != nil { + t.Error("ParameterEvidenceSignatures should be nil by default") } if input.FunctionFacts != nil { t.Error("FunctionFacts should be nil by default") @@ -419,60 +419,60 @@ func TestContextBuilder_Phases(t *testing.T) { }) } -func TestMergeParamHintsIntoSig_NilSig(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_NilSig(t *testing.T) { fn := &ast.FunctionExpr{} - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, nil) + result := paramevidence.MergeIntoSignature(fn, evidence, nil) if result != nil { t.Error("expected nil when sig is 
nil") } } -func TestMergeParamHintsIntoSig_NilFn(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_NilFn(t *testing.T) { sig := typ.Func().Param("x", typ.Any).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(nil, hints, sig) + result := paramevidence.MergeIntoSignature(nil, evidence, sig) if result != sig { t.Error("expected original sig when fn is nil") } } -func TestMergeParamHintsIntoSig_NilParList(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_NilParList(t *testing.T) { fn := &ast.FunctionExpr{ParList: nil} sig := typ.Func().Param("x", typ.Any).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result != sig { t.Error("expected original sig when ParList is nil") } } -func TestMergeParamHintsIntoSig_EmptyHints(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_EmptyHints(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{}} sig := typ.Func().Param("x", typ.Any).Build() - var hints []typ.Type + var evidence []typ.Type - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result != sig { - t.Error("expected original sig when hints are empty") + t.Error("expected original sig when evidence is empty") } } -func TestMergeParamHintsIntoSig_NilHintElement(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_NilHintElement(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{}} sig := typ.Func().Param("x", typ.Any).Build() - hints := []typ.Type{nil} + evidence := []typ.Type{nil} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result != sig { - t.Error("expected original sig when hint element is nil") + t.Error("expected original sig when evidence element is 
nil") } } -func TestMergeParamHintsIntoSig_AppliesHint(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_AppliesHint(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -480,9 +480,9 @@ func TestMergeParamHintsIntoSig_AppliesHint(t *testing.T) { }, } sig := typ.Func().Param("x", typ.Any).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result == nil { t.Fatal("expected non-nil result") } @@ -494,7 +494,7 @@ func TestMergeParamHintsIntoSig_AppliesHint(t *testing.T) { } } -func TestMergeParamHintsIntoSig_PreservesAnnotatedParam(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_PreservesAnnotatedParam(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -502,15 +502,15 @@ func TestMergeParamHintsIntoSig_PreservesAnnotatedParam(t *testing.T) { }, } sig := typ.Func().Param("x", typ.String).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result != sig { t.Error("expected original sig when param is annotated") } } -func TestMergeParamHintsIntoSig_PreservesVariadic(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_PreservesVariadic(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -518,9 +518,9 @@ func TestMergeParamHintsIntoSig_PreservesVariadic(t *testing.T) { }, } sig := typ.Func().Param("x", typ.Any).Variadic(typ.String).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result == nil { t.Fatal("expected non-nil result") } @@ -529,7 +529,7 @@ func 
TestMergeParamHintsIntoSig_PreservesVariadic(t *testing.T) { } } -func TestMergeParamHintsIntoSig_PreservesReturns(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_PreservesReturns(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -537,9 +537,9 @@ func TestMergeParamHintsIntoSig_PreservesReturns(t *testing.T) { }, } sig := typ.Func().Param("x", typ.Any).Returns(typ.Boolean).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result == nil { t.Fatal("expected non-nil result") } @@ -548,7 +548,7 @@ func TestMergeParamHintsIntoSig_PreservesReturns(t *testing.T) { } } -func TestMergeParamHintsIntoSig_PreservesOptionalParam(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_PreservesOptionalParam(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -556,9 +556,9 @@ func TestMergeParamHintsIntoSig_PreservesOptionalParam(t *testing.T) { }, } sig := typ.Func().OptParam("x", typ.Any).Build() - hints := []typ.Type{typ.Number} + evidence := []typ.Type{typ.Number} - result := paramhints.MergeIntoSignature(fn, hints, sig) + result := paramevidence.MergeIntoSignature(fn, evidence, sig) if result == nil { t.Fatal("expected non-nil result") } diff --git a/compiler/check/pipeline/runner.go b/compiler/check/pipeline/runner.go index 282cb76d..6671222f 100644 --- a/compiler/check/pipeline/runner.go +++ b/compiler/check/pipeline/runner.go @@ -15,13 +15,13 @@ // - Synthesis engine for expression type computation // - Flow solver for control flow analysis // - Effect propagation for side effect tracking -// - Parameter hint inference from call sites +// - Parameter evidence inference from call sites package pipeline import ( "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/infer/captured" - 
"github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/scope" @@ -106,8 +106,8 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { setter.SetGraphParentHash(graph.ID(), key.ParentHash) } - paramHintSigs := paramhints.BuildParamHintSignatures(store, graph, parent, r.stdlib) - synthSig := r.resolveSynthesizedSignature(ctx, store, graph, fn, parent, paramHintSigs) + parameterEvidenceSigs := paramevidence.BuildSignatureMap(store, graph, parent, r.stdlib) + synthSig := r.resolveSynthesizedSignature(ctx, store, graph, fn, parent, parameterEvidenceSigs) functionFacts := store.GetFunctionFactsSnapshot(graph, parent) @@ -138,14 +138,14 @@ func (r *Runner) Run(ctx *db.QueryContext, key api.FuncKey) *api.FuncResult { // Phase B: Build scopes and extract declared types. scopeOut := phase.RunScope(phase.ScopeInput{ - PhaseEnv: env, - Parent: parent, - MaxScopeDepth: r.maxScopeDepth, - Resolve: resolveOut, - SynthesizedFunctionSig: synthSig, - FunctionLiteralSignatures: literalSigs, - ParamHintSignatures: paramHintSigs, - FunctionFacts: functionFacts, + PhaseEnv: env, + Parent: parent, + MaxScopeDepth: r.maxScopeDepth, + Resolve: resolveOut, + SynthesizedFunctionSig: synthSig, + FunctionLiteralSignatures: literalSigs, + ParameterEvidenceSignatures: parameterEvidenceSigs, + FunctionFacts: functionFacts, }) // Declared is the default phase for scope/extract and interproc reads. 
diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index 071bfd3f..8b12bdaa 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -5,7 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" @@ -22,19 +22,19 @@ func (r *Runner) resolveSynthesizedSignature( graph *cfg.Graph, fn *ast.FunctionExpr, parent *scope.State, - paramHintSigs map[*ast.FunctionExpr][]typ.Type, + parameterEvidenceSigs map[*ast.FunctionExpr][]typ.Type, ) *typ.Function { if graph == nil || fn == nil { return nil } - factSig := paramhints.ProjectSignatureToParamUse(graph, fn, r.functionFactSignatureForFunction(store, graph, fn)) + factSig := paramevidence.ProjectSignatureToParamUse(graph, fn, r.functionFactSignatureForFunction(store, graph, fn)) synthSig := r.literalSignatureForFunction(store, graph, fn) - if paramHintSigs == nil { + if parameterEvidenceSigs == nil { return mergeSynthesizedSignatureFact(synthSig, factSig) } - hints := paramHintSigs[fn] - if len(hints) == 0 { + evidence := parameterEvidenceSigs[fn] + if len(evidence) == 0 { return mergeSynthesizedSignatureFact(synthSig, factSig) } if synthSig == nil { @@ -53,7 +53,7 @@ func (r *Runner) resolveSynthesizedSignature( if synthSig == nil { return factSig } - return mergeSynthesizedSignatureFact(paramhints.MergeIntoSignature(fn, hints, synthSig), factSig) + return mergeSynthesizedSignatureFact(paramevidence.MergeIntoSignature(fn, evidence, synthSig), factSig) } func mergeSynthesizedSignatureFact(seed, fact *typ.Function) *typ.Function { diff --git 
a/compiler/check/returns/callgraph.go b/compiler/check/returns/callgraph.go index ac336084..d9150885 100644 --- a/compiler/check/returns/callgraph.go +++ b/compiler/check/returns/callgraph.go @@ -7,7 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" synthresolve "github.com/wippyai/go-lua/compiler/check/synth/phase/resolve" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" @@ -81,23 +81,23 @@ func buildLocalSignatureResolver(localFuncs map[cfg.SymbolID]*LocalFuncInfo) fun } } -// PropagateParamHintsFromCallGraph propagates parameter type hints through +// PropagateParameterEvidence propagates parameter evidence through // inner function call graphs. // // This function implements inter-procedural parameter type inference. For each // local function, it scans call sites to identify argument types: // -// - Literal arguments (numbers, strings, booleans, nil) provide direct type hints -// - Identifier arguments that reference caller parameters with known hints -// propagate those hints transitively +// - Literal arguments (numbers, strings, booleans, nil) provide direct evidence. +// - Identifier arguments that reference caller parameters with known evidence +// propagate that evidence transitively. // // The algorithm iterates to fixpoint, bounded by the number of local functions. // This ensures that chains like f(x) -> g(x) -> h(x) are fully resolved even // if functions are processed in arbitrary order. // -// Hints are accumulated using typ.JoinPreferNonSoft, producing union types when a parameter -// is called with multiple different types across call sites. 
-func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo) { +// Evidence is accumulated with typ.JoinPreferNonSoft, producing union types when +// a parameter is called with multiple different types across call sites. +func PropagateParameterEvidence(localFuncs map[cfg.SymbolID]*LocalFuncInfo) { if len(localFuncs) == 0 { return } @@ -182,13 +182,13 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo } // For identifiers, check if the ident refers to a caller - // parameter with a known hint. + // parameter with known evidence. if argType == nil { if ident, ok := arg.(*ast.IdentExpr); ok && bindings != nil { if sym, found := bindings.SymbolOf(ident); found { if ref, isParam := paramOwner[sym]; isParam { - if ref.index < len(ref.owner.ParamHints) { - argType = ref.owner.ParamHints[ref.index] + if ref.index < len(ref.owner.ParameterEvidence) { + argType = ref.owner.ParameterEvidence[ref.index] } } } @@ -197,13 +197,13 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo // If a local function is passed as an argument and the callee has // a function-typed parameter annotation at this position, propagate - // those parameter types as hints to the passed local function. + // those parameter types as evidence to the passed local function. 
if calleeSig != nil && i < len(calleeSig.Params) { if expectedFn := unwrap.Function(calleeSig.Params[i].Type); expectedFn != nil { argSym := canonicalLocalSymbol(localFuncs, graph, moduleBindings, bindings, arg, 0) if argSym != 0 { if argLocal := localFuncs[argSym]; argLocal != nil { - if mergeFunctionParamHints(argLocal, expectedFn) { + if mergeExpectedFunctionEvidence(argLocal, expectedFn) { changed = true } } @@ -211,8 +211,8 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo } } - nextHints, merged := paramhints.MergeHintAt(callee.ParamHints, i, argType, typ.JoinPreferNonSoft) - callee.ParamHints = nextHints + nextEvidence, merged := paramevidence.MergeAt(callee.ParameterEvidence, i, argType, typ.JoinPreferNonSoft) + callee.ParameterEvidence = nextEvidence if merged { changed = true } @@ -220,7 +220,7 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo } // Parent-graph calls (e.g. chunk-level calls to local/nested functions) - // provide the first wave of hints into local function params. + // provide the first wave of evidence into local function params. 
for _, graphID := range parentGraphIDs { g := parentGraphs[graphID] if g == nil { @@ -249,22 +249,22 @@ func PropagateParamHintsFromCallGraph(localFuncs map[cfg.SymbolID]*LocalFuncInfo } } -func mergeFunctionParamHints(target *LocalFuncInfo, expectedFn *typ.Function) bool { +func mergeExpectedFunctionEvidence(target *LocalFuncInfo, expectedFn *typ.Function) bool { if target == nil || expectedFn == nil || len(expectedFn.Params) == 0 { return false } changed := false - if target.ParamHints == nil { - target.ParamHints = make([]typ.Type, len(expectedFn.Params)) - } else if len(expectedFn.Params) > len(target.ParamHints) { + if target.ParameterEvidence == nil { + target.ParameterEvidence = make([]typ.Type, len(expectedFn.Params)) + } else if len(expectedFn.Params) > len(target.ParameterEvidence) { expanded := make([]typ.Type, len(expectedFn.Params)) - copy(expanded, target.ParamHints) - target.ParamHints = expanded + copy(expanded, target.ParameterEvidence) + target.ParameterEvidence = expanded } for i, param := range expectedFn.Params { - nextHints, merged := paramhints.MergeHintAt(target.ParamHints, i, param.Type, typ.JoinPreferNonSoft) - target.ParamHints = nextHints + nextEvidence, merged := paramevidence.MergeAt(target.ParameterEvidence, i, param.Type, typ.JoinPreferNonSoft) + target.ParameterEvidence = nextEvidence if merged { changed = true } diff --git a/compiler/check/returns/callgraph_test.go b/compiler/check/returns/callgraph_test.go index 7eb2325c..4b9f06d4 100644 --- a/compiler/check/returns/callgraph_test.go +++ b/compiler/check/returns/callgraph_test.go @@ -10,29 +10,29 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -func TestPropagateParamHintsFromCallGraph_Empty(t *testing.T) { - PropagateParamHintsFromCallGraph(nil) - PropagateParamHintsFromCallGraph(map[cfg.SymbolID]*LocalFuncInfo{}) +func TestPropagateParameterEvidence_Empty(t *testing.T) { + PropagateParameterEvidence(nil) + PropagateParameterEvidence(map[cfg.SymbolID]*LocalFuncInfo{}) } -func 
TestPropagateParamHintsFromCallGraph_NilGraph(t *testing.T) { +func TestPropagateParameterEvidence_NilGraph(t *testing.T) { localFuncs := map[cfg.SymbolID]*LocalFuncInfo{ 1: {Sym: 1, Graph: nil}, } - PropagateParamHintsFromCallGraph(localFuncs) + PropagateParameterEvidence(localFuncs) } -func TestPropagateParamHintsFromCallGraph_SingleFuncNoArgs(t *testing.T) { +func TestPropagateParameterEvidence_SingleFuncNoArgs(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{Names: []string{"x"}}} graph := cfg.Build(fn) localFuncs := map[cfg.SymbolID]*LocalFuncInfo{ 1: {Sym: 1, Fn: fn, Graph: graph}, } - PropagateParamHintsFromCallGraph(localFuncs) + PropagateParameterEvidence(localFuncs) - if localFuncs[1].ParamHints != nil { - t.Error("expected nil ParamHints for function with no callers") + if localFuncs[1].ParameterEvidence != nil { + t.Error("expected nil ParameterEvidence for function with no callers") } } @@ -76,7 +76,7 @@ func TestBuildLocalCallGraph_SingleFunc(t *testing.T) { } } -func TestPropagateParamHintsFromCallGraph_LiteralArgTypes(t *testing.T) { +func TestPropagateParameterEvidence_LiteralArgTypes(t *testing.T) { // Test that literal arguments (number, string, bool, nil) are typed correctly tests := []struct { name string @@ -126,7 +126,7 @@ func TestPropagateParamHintsFromCallGraph_LiteralArgTypes(t *testing.T) { } } -func TestPropagateParamHintsFromCallGraph_UnknownArgSkipped(t *testing.T) { +func TestPropagateParameterEvidence_UnknownArgSkipped(t *testing.T) { fn := &ast.FunctionExpr{ParList: &ast.ParList{Names: []string{"x"}}} graph := cfg.Build(fn) @@ -136,9 +136,9 @@ func TestPropagateParamHintsFromCallGraph_UnknownArgSkipped(t *testing.T) { Graph: graph, } - // Unknown type args should be skipped (not create hints) - if info.ParamHints != nil { - t.Error("ParamHints should be nil initially") + // Unknown type args should be skipped (not create evidence) + if info.ParameterEvidence != nil { + t.Error("ParameterEvidence should be nil 
initially") } } @@ -153,44 +153,44 @@ func TestLocalFuncInfo_ZeroValue(t *testing.T) { if info.Graph != nil { t.Error("Graph should be nil") } - if info.ParamHints != nil { - t.Error("ParamHints should be nil") + if info.ParameterEvidence != nil { + t.Error("ParameterEvidence should be nil") } } -func TestLocalFuncInfo_ParamHintsExpansion(t *testing.T) { - // Test that ParamHints array expands correctly +func TestLocalFuncInfo_ParameterEvidenceExpansion(t *testing.T) { + // Test that ParameterEvidence array expands correctly info := &LocalFuncInfo{ - Sym: 1, - ParamHints: []typ.Type{typ.Number}, + Sym: 1, + ParameterEvidence: []typ.Type{typ.Number}, } // Verify initial state - if len(info.ParamHints) != 1 { - t.Fatalf("expected 1 hint, got %d", len(info.ParamHints)) + if len(info.ParameterEvidence) != 1 { + t.Fatalf("expected 1 evidence, got %d", len(info.ParameterEvidence)) } - if info.ParamHints[0] != typ.Number { - t.Errorf("expected Number, got %v", info.ParamHints[0]) + if info.ParameterEvidence[0] != typ.Number { + t.Errorf("expected Number, got %v", info.ParameterEvidence[0]) } - // Simulate expansion like PropagateParamHintsFromCallGraph does + // Simulate expansion like PropagateParameterEvidence does i := 2 - if i >= len(info.ParamHints) { + if i >= len(info.ParameterEvidence) { expanded := make([]typ.Type, i+1) - copy(expanded, info.ParamHints) - info.ParamHints = expanded + copy(expanded, info.ParameterEvidence) + info.ParameterEvidence = expanded } - if len(info.ParamHints) != 3 { - t.Fatalf("expected 3 hints after expansion, got %d", len(info.ParamHints)) + if len(info.ParameterEvidence) != 3 { + t.Fatalf("expected 3 evidence after expansion, got %d", len(info.ParameterEvidence)) } - if info.ParamHints[0] != typ.Number { - t.Error("original hint should be preserved") + if info.ParameterEvidence[0] != typ.Number { + t.Error("original evidence should be preserved") } - if info.ParamHints[1] != nil { + if info.ParameterEvidence[1] != nil { t.Error("gap 
should be nil") } - if info.ParamHints[2] != nil { + if info.ParameterEvidence[2] != nil { t.Error("new slot should be nil") } } @@ -264,7 +264,7 @@ func TestBuildLocalCallGraph_AddsCallbackFunctionEdges(t *testing.T) { } } -func TestPropagateParamHintsFromCallGraph_MethodRuntimeIndexing(t *testing.T) { +func TestPropagateParameterEvidence_MethodRuntimeIndexing(t *testing.T) { stmts, err := parse.ParseString(` local function callee(self, x) return x @@ -314,17 +314,17 @@ func TestPropagateParamHintsFromCallGraph_MethodRuntimeIndexing(t *testing.T) { t.Fatalf("expected symbols for callee/caller, got callee=%d caller=%d", calleeSym, callerSym) } - PropagateParamHintsFromCallGraph(localFuncs) + PropagateParameterEvidence(localFuncs) - hints := localFuncs[calleeSym].ParamHints - if len(hints) < 2 { - t.Fatalf("expected at least 2 param hints for callee(self,x), got %d", len(hints)) + evidence := localFuncs[calleeSym].ParameterEvidence + if len(evidence) < 2 { + t.Fatalf("expected at least 2 parameter evidence for callee(self,x), got %d", len(evidence)) } - if !typ.TypeEquals(hints[1], typ.Number) { - t.Fatalf("expected hint for x at index 1 to be number, got %v", hints[1]) + if !typ.TypeEquals(evidence[1], typ.Number) { + t.Fatalf("expected evidence for x at index 1 to be number, got %v", evidence[1]) } - if hints[0] != nil { - t.Fatalf("expected no informative hint for receiver at index 0, got %v", hints[0]) + if evidence[0] != nil { + t.Fatalf("expected no informative evidence for receiver at index 0, got %v", evidence[0]) } } diff --git a/compiler/check/returns/doc.go b/compiler/check/returns/doc.go index c67c76aa..21372f4f 100644 --- a/compiler/check/returns/doc.go +++ b/compiler/check/returns/doc.go @@ -41,6 +41,6 @@ // // # Signature Inference // -// [InferSignature] combines parameter hints and return types to produce +// [InferSignature] combines parameter evidence and return types to produce // complete function signatures for functions without annotations. 
package returns diff --git a/compiler/check/returns/domain_law_test.go b/compiler/check/returns/domain_law_test.go index c5ba39c6..17116992 100644 --- a/compiler/check/returns/domain_law_test.go +++ b/compiler/check/returns/domain_law_test.go @@ -19,10 +19,7 @@ func TestFactsDomain_ProductOperatorsAreIdempotentAcrossAllDomains(t *testing.T) fn := typ.Func().Param("name", typ.String).Returns(typ.String).Build() raw := api.Facts{ FunctionFacts: api.FunctionFacts{ - fnSym: {Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn}, - }, - ParamHints: api.ParamHints{ - fnSym: []typ.Type{typ.String}, + fnSym: {Params: []typ.Type{typ.String}, Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn}, }, LiteralSigs: api.LiteralSigs{ lit: typ.Func().Param("name", typ.String).Returns(typ.String).Build(), diff --git a/compiler/check/returns/equal.go b/compiler/check/returns/equal.go index 0bc2be0f..95c57bb5 100644 --- a/compiler/check/returns/equal.go +++ b/compiler/check/returns/equal.go @@ -11,9 +11,6 @@ func FactsEqual(a, b api.Facts) bool { if !FunctionFactsEqual(a.FunctionFacts, b.FunctionFacts) { return false } - if !symbolTypeVectorMapEqual(a.ParamHints, b.ParamHints) { - return false - } if !LiteralSigsEqual(a.LiteralSigs, b.LiteralSigs) { return false } @@ -43,6 +40,9 @@ func FunctionFactsEqual(a, b api.FunctionFacts) bool { if !ok { return false } + if !ReturnTypesEqual(af.Params, bf.Params) { + return false + } if !ReturnTypesEqual(af.Summary, bf.Summary) { return false } @@ -70,20 +70,6 @@ func LiteralSigsEqual(a, b api.LiteralSigs) bool { return true } -func symbolTypeVectorMapEqual(a map[cfg.SymbolID][]typ.Type, b map[cfg.SymbolID][]typ.Type) bool { - if len(a) != len(b) { - return false - } - for _, sym := range cfg.SortedSymbolIDs(a) { - left := a[sym] - right, ok := b[sym] - if !ok || !ReturnTypesEqual(left, right) { - return false - } - } - return true -} - func symbolTypeMapEqual(a map[cfg.SymbolID]typ.Type, b 
map[cfg.SymbolID]typ.Type) bool { if len(a) != len(b) { return false diff --git a/compiler/check/returns/equal_test.go b/compiler/check/returns/equal_test.go index 41312e55..269fb25a 100644 --- a/compiler/check/returns/equal_test.go +++ b/compiler/check/returns/equal_test.go @@ -95,37 +95,25 @@ func TestFactsEqual_DifferentCanonicalFunctionFacts(t *testing.T) { } } -func TestTypeVectorMapEqual_Empty(t *testing.T) { - if !symbolTypeVectorMapEqual(nil, nil) { - t.Error("nil summaries should be equal") - } -} - -func TestTypeVectorMapEqual_DifferentLength(t *testing.T) { - a := map[cfg.SymbolID][]typ.Type{1: {typ.String}} - b := map[cfg.SymbolID][]typ.Type{} - if symbolTypeVectorMapEqual(a, b) { - t.Error("summaries with different lengths should not be equal") - } -} - -func TestParamHintsEqual_Empty(t *testing.T) { - if !symbolTypeVectorMapEqual(nil, nil) { - t.Error("nil param hints should be equal") +func TestSymbolTypeMapEqual_Empty(t *testing.T) { + if !symbolTypeMapEqual(nil, nil) { + t.Error("nil func types should be equal") } } -func TestParamHintsEqual_Same(t *testing.T) { - a := api.ParamHints{1: []typ.Type{typ.String}} - b := api.ParamHints{1: []typ.Type{typ.String}} - if !symbolTypeVectorMapEqual(a, b) { - t.Error("same param hints should be equal") +func TestFunctionFactsEqual_Params(t *testing.T) { + a := api.FunctionFacts{1: {Params: []typ.Type{typ.String}}} + b := api.FunctionFacts{1: {Params: []typ.Type{typ.String}}} + if !FunctionFactsEqual(a, b) { + t.Error("same canonical parameter evidence should be equal") } } -func TestSymbolTypeMapEqual_Empty(t *testing.T) { - if !symbolTypeMapEqual(nil, nil) { - t.Error("nil func types should be equal") +func TestFunctionFactsEqual_DifferentParams(t *testing.T) { + a := api.FunctionFacts{1: {Params: []typ.Type{typ.String}}} + b := api.FunctionFacts{1: {Params: []typ.Type{typ.Number}}} + if FunctionFactsEqual(a, b) { + t.Error("different canonical parameter evidence should not be equal") } } diff --git 
a/compiler/check/returns/function_facts.go b/compiler/check/returns/function_facts.go index dfcb61a5..958cecc5 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -25,6 +25,7 @@ func markFunctionFactSymbols[T any](dst map[cfg.SymbolID]bool, src map[cfg.Symbo func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { return api.FunctionFact{ + Params: filterEmptyParameterEvidenceVector(ff.Params), Summary: canonicalReturnVector(ff.Summary), Narrow: canonicalReturnVector(ff.Narrow), Type: normalizeInterprocValueType(ff.Type), @@ -32,7 +33,7 @@ func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { } func functionFactEmpty(ff api.FunctionFact) bool { - return len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Type == nil + return len(ff.Params) == 0 && len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Type == nil } func readFunctionFactFromFacts(facts *api.Facts, sym cfg.SymbolID) api.FunctionFact { diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 1ebeb3b1..32247006 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -1502,8 +1502,8 @@ func typeRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { return false } - candidateInner, _ := splitNilableParamHint(candidate) - baselineInner, _ := splitNilableParamHint(baseline) + candidateInner, _ := splitNilableParameterEvidence(candidate) + baselineInner, _ := splitNilableParameterEvidence(baseline) if candidateInner == nil || baselineInner == nil { return false } diff --git a/compiler/check/returns/kernel.go b/compiler/check/returns/kernel.go index 9e73e24f..5895571f 100644 --- a/compiler/check/returns/kernel.go +++ b/compiler/check/returns/kernel.go @@ -12,6 +12,9 @@ func JoinFunctionFact(existing, candidate api.FunctionFact) api.FunctionFact { candidate = NormalizeFunctionFact(candidate) out := 
existing + if len(candidate.Params) > 0 { + out.Params = joinParameterEvidenceVectors(out.Params, candidate.Params) + } if len(candidate.Summary) > 0 { out.Summary = MergeReturnSummary(out.Summary, candidate.Summary) } diff --git a/compiler/check/returns/types.go b/compiler/check/returns/types.go index 83293c2c..59a0b117 100644 --- a/compiler/check/returns/types.go +++ b/compiler/check/returns/types.go @@ -29,11 +29,11 @@ // to avoid circular dependence, while canonical function facts are used for // functions outside the SCC (whose types are already known). // -// # Parameter Hint Propagation +// # Parameter Evidence Propagation // -// For unannotated parameters, the system propagates type hints from call sites. -// If function `f` is called as `f(42)`, the first parameter of `f` is hinted -// as `number`. Hints are joined across all call sites and propagated through +// For unannotated parameters, the system propagates evidence from call sites. +// If function `f` is called as `f(42)`, the first parameter of `f` records +// number evidence. Evidence is joined across all call sites and propagated through // the call graph until fixpoint. package returns @@ -48,7 +48,7 @@ import ( // // Each LocalFuncInfo represents a function that may participate in mutual // recursion with other local functions. The info includes the function's -// AST, CFG, definition context, and any parameter hints inferred from +// AST, CFG, definition context, and any parameter evidence inferred from // call sites. type LocalFuncInfo struct { Sym cfg.SymbolID @@ -56,13 +56,13 @@ type LocalFuncInfo struct { DefScope *scope.State Graph *cfg.Graph // ParentGraph is the graph where this local function is defined. - // Used for parent-scope callsite hint propagation. + // Used for parent-scope callsite evidence propagation. 
ParentGraph *cfg.Graph ParentFn *ast.FunctionExpr DefPoint cfg.Point - // ParamHints holds inferred effective-parameter types from call sites in the + // ParameterEvidence holds inferred effective-parameter types from call sites in the // parent graph. For methods, index 0 is self. - ParamHints []typ.Type + ParameterEvidence []typ.Type } // MaxReturnSummaryIterations limits fixpoint iterations for return-vector inference. diff --git a/compiler/check/returns/types_test.go b/compiler/check/returns/types_test.go index 99bee60a..78ca76d8 100644 --- a/compiler/check/returns/types_test.go +++ b/compiler/check/returns/types_test.go @@ -12,13 +12,13 @@ import ( func TestLocalFuncInfoStructure(t *testing.T) { t.Run("struct fields are accessible", func(t *testing.T) { info := LocalFuncInfo{ - Sym: cfg.SymbolID(1), - Fn: &ast.FunctionExpr{}, - DefScope: scope.New(), - Graph: &cfg.Graph{}, - ParentFn: nil, - DefPoint: cfg.Point(0), - ParamHints: []typ.Type{typ.String, typ.Number}, + Sym: cfg.SymbolID(1), + Fn: &ast.FunctionExpr{}, + DefScope: scope.New(), + Graph: &cfg.Graph{}, + ParentFn: nil, + DefPoint: cfg.Point(0), + ParameterEvidence: []typ.Type{typ.String, typ.Number}, } if info.Sym != cfg.SymbolID(1) { @@ -39,8 +39,8 @@ func TestLocalFuncInfoStructure(t *testing.T) { if info.DefPoint != cfg.Point(0) { t.Fatalf("expected DefPoint=0, got %v", info.DefPoint) } - if len(info.ParamHints) != 2 { - t.Fatalf("expected 2 ParamHints, got %d", len(info.ParamHints)) + if len(info.ParameterEvidence) != 2 { + t.Fatalf("expected 2 ParameterEvidence, got %d", len(info.ParameterEvidence)) } }) @@ -52,8 +52,8 @@ func TestLocalFuncInfoStructure(t *testing.T) { if info.Fn != nil { t.Fatal("expected nil Fn") } - if info.ParamHints != nil { - t.Fatal("expected nil ParamHints") + if info.ParameterEvidence != nil { + t.Fatal("expected nil ParameterEvidence") } }) } diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index 82f6777d..104b5e2a 100644 --- 
a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -3,7 +3,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/infer/paramhints" + "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/internal" "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/subtype" @@ -14,7 +14,6 @@ import ( // WidenFacts merges two interproc fact bundles. func WidenFacts(prev, next api.Facts) api.Facts { out := api.Facts{ - ParamHints: WidenParamHints(prev.ParamHints, next.ParamHints), LiteralSigs: WidenLiteralSigs(prev.LiteralSigs, next.LiteralSigs), CapturedTypes: WidenCapturedTypes(prev.CapturedTypes, next.CapturedTypes), CapturedFields: WidenCapturedFieldAssigns(prev.CapturedFields, next.CapturedFields), @@ -44,7 +43,6 @@ func WidenFacts(prev, next api.Facts) api.Facts { // inside one analysis round. Recursive fixpoint boundaries must use WidenFacts. func JoinFacts(prev, next api.Facts) api.Facts { out := api.Facts{ - ParamHints: JoinParamHints(prev.ParamHints, next.ParamHints), LiteralSigs: JoinLiteralSigs(prev.LiteralSigs, next.LiteralSigs), CapturedTypes: JoinCapturedTypes(prev.CapturedTypes, next.CapturedTypes), CapturedFields: JoinCapturedFieldAssigns(prev.CapturedFields, next.CapturedFields), @@ -66,6 +64,7 @@ func JoinFacts(prev, next api.Facts) api.Facts { func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFact { out := api.FunctionFact{ + Params: joinParameterEvidenceVectors(prev.Params, next.Params), Summary: widenReturnSummaryForConvergence(prev.Summary, next.Summary), Narrow: widenReturnSummaryForConvergence(prev.Narrow, next.Narrow), Type: widenFunctionFactTypeForConvergence(prev.Type, next.Type), @@ -479,46 +478,47 @@ func unionMembers(t typ.Type) []typ.Type { } } -// WidenParamHints merges two param hint maps using monotone union. 
-func WidenParamHints(prev, next api.ParamHints) api.ParamHints { +// WidenParameterEvidence merges two parameter evidence maps using the same +// vector law used by canonical FunctionFacts. +func WidenParameterEvidence(prev, next map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { if prev == nil && next == nil { return nil } if prev == nil { - return filterEmptyParamHints(next) + return filterEmptyParameterEvidence(next) } if next == nil { - return filterEmptyParamHints(prev) + return filterEmptyParameterEvidence(prev) } - merged := make(api.ParamHints, len(prev)+len(next)) + merged := make(map[cfg.SymbolID][]typ.Type, len(prev)+len(next)) for _, sym := range cfg.SortedSymbolIDs(prev) { - hints := normalizeParamHintVector(prev[sym]) - if hasNonNilHint(hints) { - merged[sym] = hints + evidence := normalizeParameterEvidenceVector(prev[sym]) + if hasNonNilEvidence(evidence) { + merged[sym] = evidence } } for _, sym := range cfg.SortedSymbolIDs(next) { - hints := normalizeParamHintVector(next[sym]) - if !hasNonNilHint(hints) { + evidence := normalizeParameterEvidenceVector(next[sym]) + if !hasNonNilEvidence(evidence) { continue } if existing := merged[sym]; existing != nil { - merged[sym] = joinParamHintVectors(existing, hints) + merged[sym] = joinParameterEvidenceVectors(existing, evidence) } else { - merged[sym] = hints + merged[sym] = evidence } } return merged } -func filterEmptyParamHints(hints api.ParamHints) api.ParamHints { - if hints == nil { +func filterEmptyParameterEvidence(evidence map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { + if evidence == nil { return nil } - out := make(api.ParamHints, len(hints)) - for _, sym := range cfg.SortedSymbolIDs(hints) { - v := normalizeParamHintVector(hints[sym]) - if hasNonNilHint(v) { + out := make(map[cfg.SymbolID][]typ.Type, len(evidence)) + for _, sym := range cfg.SortedSymbolIDs(evidence) { + v := filterEmptyParameterEvidenceVector(evidence[sym]) + if hasNonNilEvidence(v) { out[sym] = v } } @@ 
-528,37 +528,45 @@ func filterEmptyParamHints(hints api.ParamHints) api.ParamHints { return out } -func normalizeParamHintVector(hints []typ.Type) []typ.Type { +func filterEmptyParameterEvidenceVector(evidence []typ.Type) []typ.Type { + v := normalizeParameterEvidenceVector(evidence) + if !hasNonNilEvidence(v) { + return nil + } + return v +} + +func normalizeParameterEvidenceVector(evidence []typ.Type) []typ.Type { var out []typ.Type - for i, hint := range hints { - normalized := paramhints.NormalizeHintType(hint) + for i, observed := range evidence { + normalized := paramevidence.NormalizeType(observed) if out != nil { out[i] = normalized continue } - if !typ.TypeEquals(hint, normalized) { - out = make([]typ.Type, len(hints)) - copy(out, hints[:i]) + if !typ.TypeEquals(observed, normalized) { + out = make([]typ.Type, len(evidence)) + copy(out, evidence[:i]) out[i] = normalized } } if out != nil { return out } - return hints + return evidence } -func hasNonNilHint(hints []typ.Type) bool { - for _, h := range hints { - if h != nil { +func hasNonNilEvidence(evidence []typ.Type) bool { + for _, observed := range evidence { + if observed != nil { return true } } return false } -// joinParamHintVectors joins two parameter hint vectors element-wise. -func joinParamHintVectors(a, b []typ.Type) []typ.Type { +// joinParameterEvidenceVectors joins two parameter evidence vectors element-wise. 
+func joinParameterEvidenceVectors(a, b []typ.Type) []typ.Type { if len(a) == 0 { return b } @@ -578,14 +586,14 @@ func joinParamHintVectors(a, b []typ.Type) []typ.Type { if i < len(b) { bi = b[i] } - result[i] = joinParamHint(ai, bi) + result[i] = joinParameterEvidence(ai, bi) } return result } -func joinParamHint(a, b typ.Type) typ.Type { - a = paramhints.NormalizeHintType(a) - b = paramhints.NormalizeHintType(b) +func joinParameterEvidence(a, b typ.Type) typ.Type { + a = paramevidence.NormalizeType(a) + b = paramevidence.NormalizeType(b) if a == nil { return b } @@ -598,15 +606,15 @@ func joinParamHint(a, b typ.Type) typ.Type { if unwrap.IsNilType(b) && !unwrap.IsNilType(a) { return a } - if joined, ok := joinNilableParamHint(a, b); ok { + if joined, ok := joinNilableParameterEvidence(a, b); ok { return joined } - return joinNonNilParamHint(a, b) + return joinNonNilParameterEvidence(a, b) } -func joinNilableParamHint(a, b typ.Type) (typ.Type, bool) { - ai, anil := splitNilableParamHint(a) - bi, bnil := splitNilableParamHint(b) +func joinNilableParameterEvidence(a, b typ.Type) (typ.Type, bool) { + ai, anil := splitNilableParameterEvidence(a) + bi, bnil := splitNilableParameterEvidence(b) if !anil && !bnil { return nil, false } @@ -619,10 +627,10 @@ func joinNilableParamHint(a, b typ.Type) (typ.Type, bool) { if bi == nil { return typ.NewOptional(ai), true } - return typ.NewOptional(joinNonNilParamHint(ai, bi)), true + return typ.NewOptional(joinNonNilParameterEvidence(ai, bi)), true } -func splitNilableParamHint(t typ.Type) (typ.Type, bool) { +func splitNilableParameterEvidence(t typ.Type) (typ.Type, bool) { t = unwrap.Alias(t) switch v := t.(type) { case nil: @@ -652,8 +660,8 @@ func splitNilableParamHint(t typ.Type) (typ.Type, bool) { } } -func joinNonNilParamHint(a, b typ.Type) typ.Type { - if upper, ok := selectParamHintTableUpperBound(a, b); ok { +func joinNonNilParameterEvidence(a, b typ.Type) typ.Type { + if upper, ok := 
selectParameterEvidenceTableUpperBound(a, b); ok { return upper } if preferred, ok := preferConcreteOverSoftType(a, b); ok { @@ -680,7 +688,7 @@ func joinNonNilParamHint(a, b typ.Type) typ.Type { if joined, ok := typ.JoinCompatibleRecords(a, b); ok { return joined } - if joined, ok := joinParamHintMapRecord(a, b); ok { + if joined, ok := joinParameterEvidenceMapRecord(a, b); ok { return joined } if TypeExtendsRecord(a, b) { @@ -697,7 +705,7 @@ func joinNonNilParamHint(a, b typ.Type) typ.Type { return a } } - return paramhints.NormalizeHintType(typ.JoinPreferNonSoft(a, b)) + return paramevidence.NormalizeType(typ.JoinPreferNonSoft(a, b)) } func preferConcreteOverSoftType(a, b typ.Type) (typ.Type, bool) { @@ -723,14 +731,14 @@ func preferConcreteOverNilableSoftType(a, b typ.Type) (typ.Type, bool) { } func preferConcreteOverNilableSoftTypeDirected(softMaybeNil, concrete typ.Type) (typ.Type, bool) { - inner, nilable := splitNilableParamHint(softMaybeNil) + inner, nilable := splitNilableParameterEvidence(softMaybeNil) if !nilable || inner == nil || !typ.IsSoft(inner, typ.SoftPlaceholderPolicy) { return nil, false } if concrete == nil || unwrap.IsNilType(concrete) { return nil, false } - concreteInner, concreteNilable := splitNilableParamHint(concrete) + concreteInner, concreteNilable := splitNilableParameterEvidence(concrete) if concreteInner == nil { return nil, false } @@ -743,14 +751,14 @@ func preferConcreteOverNilableSoftTypeDirected(softMaybeNil, concrete typ.Type) return typ.NewOptional(concrete), true } -func joinParamHintMapRecord(a, b typ.Type) (typ.Type, bool) { - if joined, ok := joinParamHintMapRecordDirected(a, b); ok { +func joinParameterEvidenceMapRecord(a, b typ.Type) (typ.Type, bool) { + if joined, ok := joinParameterEvidenceMapRecordDirected(a, b); ok { return joined, true } - return joinParamHintMapRecordDirected(b, a) + return joinParameterEvidenceMapRecordDirected(b, a) } -func joinParamHintMapRecordDirected(mapType, recordType typ.Type) (typ.Type, 
bool) { +func joinParameterEvidenceMapRecordDirected(mapType, recordType typ.Type) (typ.Type, bool) { m, ok := unwrap.Alias(mapType).(*typ.Map) if !ok || m == nil { return nil, false @@ -760,8 +768,8 @@ func joinParamHintMapRecordDirected(mapType, recordType typ.Type) (typ.Type, boo return nil, false } - key := joinNonNilParamHint(m.Key, r.MapKey) - value := joinNonNilParamHint(m.Value, r.MapValue) + key := joinNonNilParameterEvidence(m.Key, r.MapKey) + value := joinNonNilParameterEvidence(m.Value, r.MapValue) if len(r.Fields) == 0 && r.Metatable == nil { return typ.NewMap(key, value), true } @@ -777,7 +785,7 @@ func joinParamHintMapRecordDirected(mapType, recordType typ.Type) (typ.Type, boo fieldType := field.Type optional := true if subtype.IsSubtype(typ.LiteralString(field.Name), key) { - fieldType = joinNonNilParamHint(field.Type, value) + fieldType = joinNonNilParameterEvidence(field.Type, value) } else { optional = field.Optional } @@ -795,23 +803,23 @@ func joinParamHintMapRecordDirected(mapType, recordType typ.Type) (typ.Type, boo return builder.Build(), true } -func selectParamHintTableUpperBound(a, b typ.Type) (typ.Type, bool) { - if paramHintIsOnlyTableTop(a) && typ.IsAny(b) { +func selectParameterEvidenceTableUpperBound(a, b typ.Type) (typ.Type, bool) { + if parameterEvidenceIsOnlyTableTop(a) && typ.IsAny(b) { return a, true } - if paramHintIsOnlyTableTop(b) && typ.IsAny(a) { + if parameterEvidenceIsOnlyTableTop(b) && typ.IsAny(a) { return b, true } - if paramHintContainsTableTop(a) && paramHintCoveredByTableTop(b) && subtype.IsSubtype(b, a) { + if parameterEvidenceContainsTableTop(a) && parameterEvidenceCoveredByTableTop(b) && subtype.IsSubtype(b, a) { return a, true } - if paramHintContainsTableTop(b) && paramHintCoveredByTableTop(a) && subtype.IsSubtype(a, b) { + if parameterEvidenceContainsTableTop(b) && parameterEvidenceCoveredByTableTop(a) && subtype.IsSubtype(a, b) { return b, true } return nil, false } -func paramHintContainsTableTop(t typ.Type) 
bool { +func parameterEvidenceContainsTableTop(t typ.Type) bool { if t == nil { return false } @@ -820,12 +828,12 @@ func paramHintContainsTableTop(t typ.Type) bool { } switch v := typ.UnwrapAnnotated(t).(type) { case *typ.Alias: - return paramHintContainsTableTop(v.UnaliasedTarget()) + return parameterEvidenceContainsTableTop(v.UnaliasedTarget()) case *typ.Optional: - return paramHintContainsTableTop(v.Inner) + return parameterEvidenceContainsTableTop(v.Inner) case *typ.Union: for _, member := range v.Members { - if paramHintContainsTableTop(member) { + if parameterEvidenceContainsTableTop(member) { return true } } @@ -833,7 +841,7 @@ func paramHintContainsTableTop(t typ.Type) bool { return false } -func paramHintIsOnlyTableTop(t typ.Type) bool { +func parameterEvidenceIsOnlyTableTop(t typ.Type) bool { if t == nil { return false } @@ -842,9 +850,9 @@ func paramHintIsOnlyTableTop(t typ.Type) bool { } switch v := typ.UnwrapAnnotated(t).(type) { case *typ.Alias: - return paramHintIsOnlyTableTop(v.UnaliasedTarget()) + return parameterEvidenceIsOnlyTableTop(v.UnaliasedTarget()) case *typ.Optional: - return paramHintIsOnlyTableTop(v.Inner) + return parameterEvidenceIsOnlyTableTop(v.Inner) case *typ.Union: if len(v.Members) == 0 { return false @@ -854,7 +862,7 @@ func paramHintIsOnlyTableTop(t typ.Type) bool { if unwrap.IsNilType(member) { continue } - if !paramHintIsOnlyTableTop(member) { + if !parameterEvidenceIsOnlyTableTop(member) { return false } hasTableTop = true @@ -865,7 +873,7 @@ func paramHintIsOnlyTableTop(t typ.Type) bool { } } -func paramHintCoveredByTableTop(t typ.Type) bool { +func parameterEvidenceCoveredByTableTop(t typ.Type) bool { if t == nil { return false } @@ -877,17 +885,17 @@ func paramHintCoveredByTableTop(t typ.Type) bool { } switch v := typ.UnwrapAnnotated(t).(type) { case *typ.Alias: - return paramHintCoveredByTableTop(v.UnaliasedTarget()) + return parameterEvidenceCoveredByTableTop(v.UnaliasedTarget()) case *typ.Optional: - return 
paramHintCoveredByTableTop(v.Inner) + return parameterEvidenceCoveredByTableTop(v.Inner) case *typ.Recursive: - return v.Body != nil && v.Body != v && paramHintCoveredByTableTop(v.Body) + return v.Body != nil && v.Body != v && parameterEvidenceCoveredByTableTop(v.Body) case *typ.Union: if len(v.Members) == 0 { return false } for _, member := range v.Members { - if !paramHintCoveredByTableTop(member) { + if !parameterEvidenceCoveredByTableTop(member) { return false } } @@ -1168,11 +1176,6 @@ func widenFunctionParamFactTypeForConvergence(existing, candidate typ.Type) typ. return typ.JoinPreferNonSoft(existing, candidate) } -// JoinParamHints merges parameter hints inside one analysis iteration. -func JoinParamHints(prev, next api.ParamHints) api.ParamHints { - return WidenParamHints(prev, next) -} - // WidenLiteralSigs merges two literal signature maps. func WidenLiteralSigs(prev, next api.LiteralSigs) api.LiteralSigs { if prev == nil && next == nil { diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index 13b974dd..04a1901a 100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/returns/widen_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -176,7 +177,7 @@ func TestMergeReturnSummary_KeepsNonRecursiveContainerRefinement(t *testing.T) { } } -func TestWidenParamHints_StopsSelfEmbeddingRecordGrowth(t *testing.T) { +func TestWidenParameterEvidence_StopsSelfEmbeddingRecordGrowth(t *testing.T) { prevHint := typ.NewUnion( typ.Number, typ.NewRecord(). @@ -189,18 +190,18 @@ func TestWidenParamHints_StopsSelfEmbeddingRecordGrowth(t *testing.T) { SetOpen(true). 
Build() - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{prevHint}}, - api.ParamHints{1: []typ.Type{nextHint}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{prevHint}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, ) got := merged[1][0] if !typ.TypeEquals(got, prevHint) { - t.Fatalf("expected stable previous hint, got %v", got) + t.Fatalf("expected stable previous evidence, got %v", got) } } -func TestWidenParamHints_StopsSelfEmbeddingContainerGrowth(t *testing.T) { +func TestWidenParameterEvidence_StopsSelfEmbeddingContainerGrowth(t *testing.T) { prevHint := typ.NewUnion( typ.Number, typ.NewRecord(). @@ -243,28 +244,28 @@ func TestWidenParamHints_StopsSelfEmbeddingContainerGrowth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{prevHint}}, - api.ParamHints{1: []typ.Type{tt.next}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{prevHint}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{tt.next}}, ) got := merged[1][0] if !typ.TypeEquals(got, prevHint) { - t.Fatalf("expected stable previous hint, got %v", got) + t.Fatalf("expected stable previous evidence, got %v", got) } }) } } -func TestWidenParamHints_KeepsFirstRecordWrapperObservation(t *testing.T) { +func TestWidenParameterEvidence_KeepsFirstRecordWrapperObservation(t *testing.T) { nextHint := typ.NewRecord(). Field("limit", typ.Number). SetOpen(true). 
Build() - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{typ.Number}}, - api.ParamHints{1: []typ.Type{nextHint}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Number}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, ) got := merged[1][0] @@ -272,11 +273,11 @@ func TestWidenParamHints_KeepsFirstRecordWrapperObservation(t *testing.T) { t.Fatalf("expected wrapper observation to be preserved, got %v", got) } if !typ.TypeEquals(got, typ.NewUnion(typ.Number, nextHint)) { - t.Fatalf("expected number | wrapper hint, got %v", got) + t.Fatalf("expected number | wrapper evidence, got %v", got) } } -func TestWidenParamHints_JoinsNestedRecordObservations(t *testing.T) { +func TestWidenParameterEvidence_JoinsNestedRecordObservations(t *testing.T) { nested := typ.NewRecord(). Field("routes", typ.NewRecord().Field("users", typ.Boolean).SetOpen(true).Build()). SetOpen(true). @@ -286,9 +287,9 @@ func TestWidenParamHints_JoinsNestedRecordObservations(t *testing.T) { SetOpen(true). 
Build() - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{outer}}, - api.ParamHints{1: []typ.Type{nested}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{outer}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{nested}}, ) got := merged[1][0] @@ -298,22 +299,22 @@ func TestWidenParamHints_JoinsNestedRecordObservations(t *testing.T) { } } -func TestWidenParamHints_ReplacesStaleBroadHintWithCurrentRefinement(t *testing.T) { +func TestWidenParameterEvidence_ReplacesStaleBroadHintWithCurrentRefinement(t *testing.T) { stale := typ.NewUnion(typ.String, typ.False) current := typ.String - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{stale}}, - api.ParamHints{1: []typ.Type{current}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{stale}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{current}}, ) got := merged[1][0] if !typ.TypeEquals(got, current) { - t.Fatalf("expected current refined hint %v to replace stale broad hint, got %v", current, got) + t.Fatalf("expected current refined evidence %v to replace stale broad evidence, got %v", current, got) } } -func TestWidenParamHints_ReplacesSoftContainerPlaceholderWithConcreteElementShape(t *testing.T) { +func TestWidenParameterEvidence_ReplacesSoftContainerPlaceholderWithConcreteElementShape(t *testing.T) { entry := typ.NewRecord().Field("id", typ.String).Build() stale := typ.NewUnion( typ.NewArray(typ.Any), @@ -321,63 +322,63 @@ func TestWidenParamHints_ReplacesSoftContainerPlaceholderWithConcreteElementShap ) current := typ.NewArray(entry) - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{stale}}, - api.ParamHints{1: []typ.Type{current}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{stale}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{current}}, ) got := merged[1][0] if !typ.TypeEquals(got, current) { - t.Fatalf("expected concrete array hint %v to replace soft stale hint, got %v", current, got) + 
t.Fatalf("expected concrete array evidence %v to replace soft stale evidence, got %v", current, got) } } -func TestWidenParamHints_PreservesStructuredHintOverNilOnlyObservation(t *testing.T) { +func TestWidenParameterEvidence_PreservesStructuredHintOverNilOnlyObservation(t *testing.T) { context := typ.NewMap(typ.String, typ.Any) - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{typ.String, typ.Any, context}}, - api.ParamHints{1: []typ.Type{typ.String, typ.Any, typ.Nil}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, context}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, typ.Nil}}, ) got := merged[1][2] if !typ.TypeEquals(got, context) { - t.Fatalf("expected nil-only observation to preserve structured hint %v, got %v", context, got) + t.Fatalf("expected nil-only observation to preserve structured evidence %v, got %v", context, got) } - again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.String, typ.Any, typ.Nil}}) - if !symbolTypeVectorMapEqual(merged, again) { + again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, typ.Nil}}) + if !parameterEvidenceEqual(merged, again) { t.Fatalf("expected idempotent nil-only observation widening, got %v then %v", merged, again) } } -func TestWidenParamHints_PreservesMapHintOverOptionalOpenRecordObservation(t *testing.T) { +func TestWidenParameterEvidence_PreservesMapHintOverOptionalOpenRecordObservation(t *testing.T) { context := typ.NewMap(typ.String, typ.Any) optionalContextRecord := typ.NewOptional(typ.NewRecord(). MapComponent(typ.String, typ.Any). SetOpen(true). 
Build()) - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{typ.String, typ.Any, context}}, - api.ParamHints{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, context}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}, ) got := merged[1][2] if got == nil || typ.TypeEquals(got, typ.Nil) { - t.Fatalf("expected optional structured observation to preserve context hint, got %v", got) + t.Fatalf("expected optional structured observation to preserve context evidence, got %v", got) } if !typ.TypeEquals(got, typ.NewOptional(context)) { t.Fatalf("expected pure map observation to stay canonical, got %v", got) } - again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}) - if !symbolTypeVectorMapEqual(merged, again) { + again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}) + if !parameterEvidenceEqual(merged, again) { t.Fatalf("expected idempotent optional structured observation widening, got %v then %v", merged, again) } } -func TestWidenParamHints_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) { +func TestWidenParameterEvidence_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) { entry := typ.NewRecord().Field("id", typ.String).Build() canonical := typ.NewMap(typ.String, typ.NewArray(entry)) staleRecordView := typ.NewRecord(). @@ -385,18 +386,18 @@ func TestWidenParamHints_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) SetOpen(true). 
Build() - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{staleRecordView}}, - api.ParamHints{1: []typ.Type{canonical}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{staleRecordView}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{canonical}}, ) got := merged[1][0] if !typ.TypeEquals(got, canonical) { - t.Fatalf("expected pure keyed table hint to canonicalize to %v, got %v", canonical, got) + t.Fatalf("expected pure keyed table evidence to canonicalize to %v, got %v", canonical, got) } } -func TestWidenParamHints_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { +func TestWidenParameterEvidence_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { tableTop := typ.NewOptional(typ.NewInterface("table", nil)) strategySpec := typ.NewRecord(). Field("kind", typ.LiteralString("strategy")). @@ -408,9 +409,9 @@ func TestWidenParamHints_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { Build() nextHint := typ.NewUnion(strategySpec, contextSpec) - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{tableTop}}, - api.ParamHints{1: []typ.Type{nextHint}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{tableTop}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, ) got := merged[1][0] @@ -418,18 +419,18 @@ func TestWidenParamHints_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { t.Fatalf("expected table top upper bound %v, got %v", tableTop, got) } - again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{nextHint}}) - if !symbolTypeVectorMapEqual(merged, again) { + again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}) + if !parameterEvidenceEqual(merged, again) { t.Fatalf("expected idempotent table-top widening, got %v then %v", merged, again) } } -func TestWidenParamHints_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { +func TestWidenParameterEvidence_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { tableTop := 
typ.NewOptional(typ.NewInterface("table", nil)) - merged := WidenParamHints( - api.ParamHints{1: []typ.Type{tableTop}}, - api.ParamHints{1: []typ.Type{typ.Any}}, + merged := WidenParameterEvidence( + map[cfg.SymbolID][]typ.Type{1: []typ.Type{tableTop}}, + map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Any}}, ) got := merged[1][0] @@ -437,12 +438,25 @@ func TestWidenParamHints_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { t.Fatalf("expected dynamic observation to preserve table top upper bound %v, got %v", tableTop, got) } - again := WidenParamHints(merged, api.ParamHints{1: []typ.Type{typ.Any}}) - if !symbolTypeVectorMapEqual(merged, again) { + again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Any}}) + if !parameterEvidenceEqual(merged, again) { t.Fatalf("expected idempotent table-top/any widening, got %v then %v", merged, again) } } +func parameterEvidenceEqual(a, b map[cfg.SymbolID][]typ.Type) bool { + if len(a) != len(b) { + return false + } + for _, sym := range cfg.SortedSymbolIDs(a) { + right, ok := b[sym] + if !ok || !ReturnTypesEqual(a[sym], right) { + return false + } + } + return true +} + func TestWidenCapturedFieldAssigns_NormalizesOptionalFunctionValues(t *testing.T) { fn := typ.Func().Param("fn", typ.Unknown).Build() merged := WidenCapturedFieldAssigns(nil, api.CapturedFieldAssigns{ diff --git a/compiler/check/session.go b/compiler/check/session.go index 6a746337..015d01ea 100644 --- a/compiler/check/session.go +++ b/compiler/check/session.go @@ -15,7 +15,7 @@ // cached function analysis when facts/effects actually change. // // - IterationScratch: Single-iteration state cleared at each boundary. -// Tracks which literals have been analyzed, pending parameter hints, +// Tracks which literals have been analyzed, pending parameter evidence, // and change detection flags. 
// // # SNAPSHOT PROTOCOL diff --git a/compiler/check/store/doc.go b/compiler/check/store/doc.go index b6d27010..76e392ae 100644 --- a/compiler/check/store/doc.go +++ b/compiler/check/store/doc.go @@ -10,7 +10,7 @@ // The store holds: // - Built CFGs indexed by graph ID // - Analysis results (types, flow facts, diagnostics) per function -// - Interprocedural facts (canonical function facts, parameter hints) +// - Interprocedural facts (canonical function facts, parameter evidence) // - Module-level bindings and alias maps // - Query-tracked interprocedural snapshot inputs for precise function-result // cache revalidation diff --git a/compiler/check/store/facts_clone.go b/compiler/check/store/facts_clone.go index 7b6eb2f1..0facdbd7 100644 --- a/compiler/check/store/facts_clone.go +++ b/compiler/check/store/facts_clone.go @@ -13,7 +13,6 @@ func cloneFacts(f api.Facts) api.Facts { } return api.Facts{ FunctionFacts: cloneFunctionFacts(f.FunctionFacts), - ParamHints: cloneParamHints(f.ParamHints), LiteralSigs: cloneLiteralSigs(f.LiteralSigs), CapturedTypes: cloneCapturedTypes(f.CapturedTypes), CapturedFields: cloneCapturedFieldAssigns(f.CapturedFields), @@ -28,6 +27,7 @@ func cloneFunctionFacts(src api.FunctionFacts) api.FunctionFacts { } out := make(api.FunctionFacts, len(src)) for sym, fact := range src { + fact.Params = cloneTypeSlice(fact.Params) fact.Summary = cloneTypeSlice(fact.Summary) fact.Narrow = cloneTypeSlice(fact.Narrow) out[sym] = fact @@ -35,17 +35,6 @@ func cloneFunctionFacts(src api.FunctionFacts) api.FunctionFacts { return out } -func cloneParamHints(src api.ParamHints) api.ParamHints { - if len(src) == 0 { - return nil - } - out := make(api.ParamHints, len(src)) - for sym, hints := range src { - out[sym] = cloneTypeSlice(hints) - } - return out -} - func cloneLiteralSigs(src api.LiteralSigs) api.LiteralSigs { if len(src) == 0 { return nil diff --git a/compiler/check/store/store.go b/compiler/check/store/store.go index aecbeba3..504e909b 100644 --- 
a/compiler/check/store/store.go +++ b/compiler/check/store/store.go @@ -542,7 +542,6 @@ func (s *SessionStore) ParentGraphKeyForSymbol(sym cfg.SymbolID) (api.GraphKey, func factsEmpty(f api.Facts) bool { return len(f.FunctionFacts) == 0 && - len(f.ParamHints) == 0 && len(f.LiteralSigs) == 0 && len(f.CapturedTypes) == 0 && len(f.CapturedFields) == 0 && @@ -807,15 +806,6 @@ func (s *SessionStore) GetInterprocFactsSnapshot( return s.currentInterprocFacts(key) } -// GetParamHintsSnapshot returns param hints from the stable interproc snapshot. -func (s *SessionStore) GetParamHintsSnapshot( - graph *cfg.Graph, - parent *scope.State, -) map[cfg.SymbolID][]typ.Type { - s.requirePhase(api.PhaseScopeCompute) - return s.GetInterprocFactsSnapshot(graph, parent).ParamHints -} - // GetFunctionFactsSnapshot returns canonical function facts from the stable // interproc snapshot. func (s *SessionStore) GetFunctionFactsSnapshot( diff --git a/compiler/check/store/store_test.go b/compiler/check/store/store_test.go index be860f69..e6970f6e 100644 --- a/compiler/check/store/store_test.go +++ b/compiler/check/store/store_test.go @@ -281,21 +281,22 @@ func TestGetInterprocFactsSnapshot_ReturnsImmutableFactContainers(t *testing.T) key := api.KeyForGraph(graph, parent.Hash()) sym := cfg.SymbolID(7) s.InterprocPrev.Facts[key] = api.Facts{ - ParamHints: api.ParamHints{ - sym: []typ.Type{typ.String, typ.NewMap(typ.String, typ.Any)}, - }, FunctionFacts: api.FunctionFacts{ - sym: {Summary: []typ.Type{typ.String}}, + sym: { + Params: []typ.Type{typ.String, typ.NewMap(typ.String, typ.Any)}, + Summary: []typ.Type{typ.String}, + }, }, } snapshot := s.GetInterprocFactsSnapshot(graph, parent) - snapshot.ParamHints[sym][1] = typ.Nil + snapshotFact := snapshot.FunctionFacts[sym] + snapshotFact.Params[1] = typ.Nil snapshot.FunctionFacts[sym] = api.FunctionFact{Summary: []typ.Type{typ.Number}} again := s.GetInterprocFactsSnapshot(graph, parent) - if got := again.ParamHints[sym][1]; !typ.TypeEquals(got, 
typ.NewMap(typ.String, typ.Any)) { - t.Fatalf("snapshot param hint mutation leaked into store: %v", got) + if got := again.FunctionFacts.Params(sym)[1]; !typ.TypeEquals(got, typ.NewMap(typ.String, typ.Any)) { + t.Fatalf("snapshot parameter evidence mutation leaked into store: %v", got) } if got := again.FunctionFacts.Summary(sym); len(got) != 1 || !typ.TypeEquals(got[0], typ.String) { t.Fatalf("snapshot function fact mutation leaked into store: %v", got) diff --git a/compiler/check/tests/core/return_field_merge_test.go b/compiler/check/tests/core/return_field_merge_test.go index bb4c3e87..10d209e1 100644 --- a/compiler/check/tests/core/return_field_merge_test.go +++ b/compiler/check/tests/core/return_field_merge_test.go @@ -254,10 +254,10 @@ func TestReturnFieldMerge_ModuleImport(t *testing.T) { } } -// TestParamHintsSeesEnrichedReturns verifies that param hints are computed +// TestParameterEvidenceSeesEnrichedReturns verifies that parameter evidence is computed // using enriched return types (with field merges applied), not raw returns. -// This test fails with the timing bug (param hints see {} instead of {value: number}). -func TestParamHintsSeesEnrichedReturns(t *testing.T) { +// This test fails with the timing bug (parameter evidence sees {} instead of {value: number}). +func TestParameterEvidenceSeesEnrichedReturns(t *testing.T) { code := ` local function make_obj() local obj = {} @@ -277,8 +277,8 @@ func TestParamHintsSeesEnrichedReturns(t *testing.T) { } } -// TestParamHintsSeesEnrichedReturns_Method verifies method calls work through param hints. -func TestParamHintsSeesEnrichedReturns_Method(t *testing.T) { +// TestParameterEvidenceSeesEnrichedReturns_Method verifies method calls work through parameter evidence. 
+func TestParameterEvidenceSeesEnrichedReturns_Method(t *testing.T) { code := ` local function make_obj() local obj = {} diff --git a/compiler/check/tests/flow/fixpoint_unification_test.go b/compiler/check/tests/flow/fixpoint_unification_test.go index c74e33b7..07026c3b 100644 --- a/compiler/check/tests/flow/fixpoint_unification_test.go +++ b/compiler/check/tests/flow/fixpoint_unification_test.go @@ -108,10 +108,10 @@ local result: string = tbl:process(42) // Literal signatures channel removed in canonical query architecture. } -// TestFixpointUnification_ParamHintPropagation verifies that parameter hints +// TestFixpointUnification_ParameterEvidencePropagation verifies that parameter evidence // from call sites propagate across iterations. In a chain A -> B -> C, where -// A calls B with a number and B calls C, param hints should stabilize. -func TestFixpointUnification_ParamHintPropagation(t *testing.T) { +// A calls B with a number and B calls C, parameter evidence should stabilize. +func TestFixpointUnification_ParameterEvidencePropagation(t *testing.T) { source := ` local function c(x) return x + 1 @@ -404,12 +404,12 @@ func TestFixpointUnification_EffectRowLabels(t *testing.T) { } } -// TestFixpointUnification_ParamHintNestedPropagation verifies that parameter -// hints propagate correctly through nested function calls within function bodies. -// This is a regression test for the early break bug where PropagateParamHintsFromCallGraph +// TestFixpointUnification_ParameterEvidenceNestedPropagation verifies that parameter +// evidence propagate correctly through nested function calls within function bodies. +// This is a regression test for the early break bug where PropagateParameterEvidence // would fail to resolve callee symbols from identifiers when CalleeSymbol was zero. -func TestFixpointUnification_ParamHintNestedPropagation(t *testing.T) { - // d calls c, c calls b, b has parameter x. Hints should flow d->c->b. 
+func TestFixpointUnification_ParameterEvidenceNestedPropagation(t *testing.T) { + // d calls c, c calls b, b has parameter x. Evidence should flow d->c->b. // The key is that inner calls (c calling b) need identifier resolution. source := ` local function b(x) @@ -456,18 +456,21 @@ local result: number = d() continue } if typ.TypeEquals(fact.Summary[0], typ.Unknown) { - t.Errorf("return type for %q is unknown, expected number (hints didn't propagate)", name) + t.Errorf("return type for %q is unknown, expected number (evidence didn't propagate)", name) } } } - // Verify that param hints were propagated to inner functions. - paramHintsFound := false - if hints := sess.Store.GetParamHintsSnapshot(sess.RootResult.Graph, parent); len(hints) > 0 { - paramHintsFound = true + // Verify that parameter evidence was propagated to inner functions. + parameterEvidenceFound := false + for _, fact := range functionFacts { + if len(fact.Params) > 0 { + parameterEvidenceFound = true + break + } } - if !paramHintsFound { - t.Log("no param hints found in ParamHintsPrev (propagation may have converged)") + if !parameterEvidenceFound { + t.Log("no parameter evidence found in canonical function facts (propagation may have converged)") } } diff --git a/compiler/check/tests/inference/param_hints_and_returns_test.go b/compiler/check/tests/inference/parameter_evidence_and_returns_test.go similarity index 100% rename from compiler/check/tests/inference/param_hints_and_returns_test.go rename to compiler/check/tests/inference/parameter_evidence_and_returns_test.go diff --git a/compiler/check/tests/modules/manifest_test.go b/compiler/check/tests/modules/manifest_test.go index 8bd48035..8a1d7cb8 100644 --- a/compiler/check/tests/modules/manifest_test.go +++ b/compiler/check/tests/modules/manifest_test.go @@ -88,9 +88,9 @@ func TestManifest_LocalRequireInFunction(t *testing.T) { } } -// TestManifest_SoftAnnotationParamHints ensures soft annotations like {any} -// are overridden by call-site 
param hints. -func TestManifest_SoftAnnotationParamHints(t *testing.T) { +// TestManifest_SoftAnnotationParameterEvidence ensures soft annotations like {any} +// are overridden by call-site parameter evidence. +func TestManifest_SoftAnnotationParameterEvidence(t *testing.T) { registryManifest := io.NewManifest("registry") entryType := typ.NewRecord().Field("id", typ.String).Build() findFn := typ.Func().Param("query", typ.Any).Returns(typ.NewArray(entryType)).Build() @@ -118,7 +118,7 @@ func TestManifest_SoftAnnotationParamHints(t *testing.T) { for _, d := range result.Errors { t.Logf("error: %s", d.Message) } - t.Errorf("expected no errors with soft annotation param hints") + t.Errorf("expected no errors with soft annotation parameter evidence") } } @@ -193,7 +193,6 @@ func TestManifest_SoftLocalAnnotations(t *testing.T) { parentHash := result.Session.Store.GraphParentHashOf(root.ID()) parent := result.Session.Store.Parents()[parentHash] functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) - paramHints := result.Session.Store.GetParamHintsSnapshot(root, parent) groupSym := localFunctionSymbolByName(t, root, functionFacts, "group_by_suite") runSuiteSym := localFunctionSymbolByName(t, root, functionFacts, "run_suite") @@ -215,8 +214,8 @@ func TestManifest_SoftLocalAnnotations(t *testing.T) { if runSuiteFn == nil || len(runSuiteFn.Params) < 2 || !typ.TypeEquals(runSuiteFn.Params[1].Type, entryArray) { t.Fatalf("expected run_suite tests param to refine to %v, got %v", entryArray, functionFacts.FunctionType(runSuiteSym)) } - if hints := paramHints[runSuiteSym]; len(hints) < 2 || !typ.TypeEquals(hints[1], entryArray) { - t.Fatalf("expected run_suite param hint %v, got %v", entryArray, hints) + if evidence := functionFacts.Params(runSuiteSym); len(evidence) < 2 || !typ.TypeEquals(evidence[1], entryArray) { + t.Fatalf("expected run_suite parameter evidence %v, got %v", entryArray, evidence) } } diff --git 
a/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go b/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go index f1486852..41828abe 100644 --- a/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go +++ b/compiler/check/tests/regression/assert_false_discriminant_narrowing_test.go @@ -665,7 +665,7 @@ return mapper } } -func TestRegression_PartialRecordParamHintsBecomeOptionalFields(t *testing.T) { +func TestRegression_PartialRecordParameterEvidenceBecomeOptionalFields(t *testing.T) { source := ` local mapper = {} @@ -692,7 +692,7 @@ mapper.map_tokens({ thoughtsTokenCount = 40 }) } } -func TestRegression_NestedPartialRecordParamHintsBecomeOptionalFields(t *testing.T) { +func TestRegression_NestedPartialRecordParameterEvidenceBecomeOptionalFields(t *testing.T) { source := ` local mapper = {} diff --git a/compiler/check/tests/regression/explicit_any_param_contract_test.go b/compiler/check/tests/regression/explicit_any_param_contract_test.go index a4c4ad54..aec624b4 100644 --- a/compiler/check/tests/regression/explicit_any_param_contract_test.go +++ b/compiler/check/tests/regression/explicit_any_param_contract_test.go @@ -7,7 +7,7 @@ import ( ) // Explicit `any` parameter annotations are contracts and must not be rewritten -// by call-site param hints. This mirrors wippy.test:runner wait_for(ch: any). +// by call-site parameter evidence. This mirrors wippy.test:runner wait_for(ch: any). 
func TestExplicitAnyParamAnnotation_IsNotRewrittenByHints(t *testing.T) { source := ` local function wait_for(ch: any, timeout: any) diff --git a/compiler/check/tests/regression/false_positives_unit_test.go b/compiler/check/tests/regression/false_positives_unit_test.go index c7fffb14..44d63d4c 100644 --- a/compiler/check/tests/regression/false_positives_unit_test.go +++ b/compiler/check/tests/regression/false_positives_unit_test.go @@ -363,7 +363,7 @@ func TestFalsePositive_ErrorReturnSuccessWithImplicitNilErrorNarrowsSibling(t *t } } -func TestFalsePositive_MethodReceiverParamHintInfersCapturedSelfFields(t *testing.T) { +func TestFalsePositive_MethodReceiverParameterEvidenceInfersCapturedSelfFields(t *testing.T) { source := ` type Output = { kind: "rendered", @@ -421,7 +421,7 @@ func TestFalsePositive_MethodReceiverParamHintInfersCapturedSelfFields(t *testin ` result := testutil.Check(source, testutil.WithStdlib()) if result.HasError() { - t.Errorf("expected method receiver hints to type captured builder fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + t.Errorf("expected method receiver evidence to type captured builder fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) } } diff --git a/compiler/check/tests/regression/http_timeout_option_inference_test.go b/compiler/check/tests/regression/http_timeout_option_inference_test.go index d1dd67cc..62d0297c 100644 --- a/compiler/check/tests/regression/http_timeout_option_inference_test.go +++ b/compiler/check/tests/regression/http_timeout_option_inference_test.go @@ -269,7 +269,7 @@ call_two() for _, e := range result.Errors { t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) } - t.Fatal("expected body contract to dominate compatible multi-call http option hints") + t.Fatal("expected body contract to dominate compatible multi-call http option evidence") } } diff --git a/compiler/check/tests/regression/nested_call_param_hints_test.go 
b/compiler/check/tests/regression/nested_call_parameter_evidence_test.go similarity index 86% rename from compiler/check/tests/regression/nested_call_param_hints_test.go rename to compiler/check/tests/regression/nested_call_parameter_evidence_test.go index 80835331..e01e68b7 100644 --- a/compiler/check/tests/regression/nested_call_param_hints_test.go +++ b/compiler/check/tests/regression/nested_call_parameter_evidence_test.go @@ -7,8 +7,8 @@ import ( ) // Regression guard: nested calls used as arguments must still contribute -// parameter hints to local helper functions. -func TestNestedCall_ParamHintsFlowIntoLocalHelper(t *testing.T) { +// parameter evidence to local helper functions. +func TestNestedCall_ParameterEvidenceFlowIntoLocalHelper(t *testing.T) { source := ` type Entry = { id: string, kind: string } diff --git a/compiler/check/tests/regression/param_hint_depth_convergence_test.go b/compiler/check/tests/regression/parameter_evidence_depth_convergence_test.go similarity index 94% rename from compiler/check/tests/regression/param_hint_depth_convergence_test.go rename to compiler/check/tests/regression/parameter_evidence_depth_convergence_test.go index 68c96a8f..54d31cde 100644 --- a/compiler/check/tests/regression/param_hint_depth_convergence_test.go +++ b/compiler/check/tests/regression/parameter_evidence_depth_convergence_test.go @@ -18,7 +18,7 @@ func numericAliasChain(depth int) string { return b.String() } -func TestParamHints_DeepAliasChain_NoInterprocNonConvergenceWarning(t *testing.T) { +func TestParameterEvidence_DeepAliasChain_NoInterprocNonConvergenceWarning(t *testing.T) { code := numericAliasChain(32) + ` local function g(x) return x + 1 @@ -43,7 +43,7 @@ func TestParamHints_DeepAliasChain_NoInterprocNonConvergenceWarning(t *testing.T } } -func TestParamHints_RecordWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { +func TestParameterEvidence_RecordWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { code := ` local 
repo = require("kb_repo") @@ -92,7 +92,7 @@ func TestParamHints_RecordWrapperFeedback_NoInterprocNonConvergenceWarning(t *te } } -func TestParamHints_NestedWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { +func TestParameterEvidence_NestedWrapperFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { code := ` local repo = require("kb_repo") @@ -146,7 +146,7 @@ func TestParamHints_NestedWrapperFeedback_NoInterprocNonConvergenceWarning(t *te } } -func TestParamHints_OptionalContextTableFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { +func TestParameterEvidence_OptionalContextTableFeedback_NoInterprocNonConvergenceWarning(t *testing.T) { code := ` local function merge_context(base, additions) local out = {} diff --git a/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go b/compiler/check/tests/regression/wippy_sorted_keys_parameter_evidence_test.go similarity index 97% rename from compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go rename to compiler/check/tests/regression/wippy_sorted_keys_parameter_evidence_test.go index 5a0fafa0..a285d67f 100644 --- a/compiler/check/tests/regression/wippy_sorted_keys_param_hints_test.go +++ b/compiler/check/tests/regression/wippy_sorted_keys_parameter_evidence_test.go @@ -10,7 +10,7 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// Regression: call-site param hints must keep informative soft map shapes. +// Regression: call-site parameter evidence must keep informative soft map shapes. // If {[string]: any[]} is dropped as "soft", sorted key iteration degrades to // `name: any`, which then breaks suites[name] and downstream run_test(entry.id). 
func TestWippyRunner_SortedKeysRetainsMapKeyHints(t *testing.T) { @@ -99,7 +99,7 @@ func TestWippyRunner_SortedKeysRetainsMapKeyHints(t *testing.T) { for _, d := range result.Errors { t.Logf("error: %s", d.Message) } - t.Fatal("expected no errors for sorted_keys/group_by_suite hint propagation") + t.Fatal("expected no errors for sorted_keys/group_by_suite evidence propagation") } } @@ -392,15 +392,14 @@ func TestWippyRunner_NearLiteralTestRunnerFlow(t *testing.T) { parentHash := result.Session.Store.GraphParentHashOf(root.ID()) parent := result.Session.Store.Parents()[parentHash] functionFacts := result.Session.Store.GetFunctionFactsSnapshot(root, parent) - hints := result.Session.Store.GetParamHintsSnapshot(root, parent) if bindings := result.Session.Store.ModuleBindings(); bindings != nil { for sym, fact := range functionFacts { name := bindings.Name(sym) if name == "sorted_keys" || name == "run_suite" || name == "run_test" || name == "group_by_suite" { fnType := fact.Type t.Logf("local-fn %q sym=%d type=%s", name, sym, typ.Format(fnType, typ.DefaultFormatOptions)) - if hv := hints[sym]; len(hv) > 0 { - t.Logf("param-hints %q: %v", name, hv) + if hv := fact.Params; len(hv) > 0 { + t.Logf("parameter-evidence %q: %v", name, hv) } } } From ca519170f3c246c0e28e5bcb1b5573eb94948238 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 02:07:39 -0400 Subject: [PATCH 18/71] Move parameter evidence into domain packages --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 84 ++- .../{infer => domain}/paramevidence/doc.go | 16 +- compiler/check/domain/paramevidence/merge.go | 438 ++++++++++++ .../check/domain/paramevidence/merge_test.go | 300 +++++++++ .../paramevidence/parameter_evidence.go | 0 .../paramevidence/parameter_evidence_test.go | 0 .../paramevidence/project.go | 0 compiler/check/domain/value/shape.go | 391 +++++++++++ compiler/check/domain/value/shape_test.go | 30 + compiler/check/infer/interproc/postflow.go | 2 +- compiler/check/infer/return/infer.go | 2 +- 
.../check/infer/return/overlay_pipeline.go | 2 +- compiler/check/infer/return/scc.go | 2 +- compiler/check/phase/scope.go | 2 +- compiler/check/phase/types_test.go | 2 +- compiler/check/pipeline/runner.go | 2 +- compiler/check/pipeline/runner_stages.go | 2 +- compiler/check/returns/callgraph.go | 2 +- compiler/check/returns/function_facts.go | 3 +- compiler/check/returns/join.go | 229 +------ compiler/check/returns/join_test.go | 26 - compiler/check/returns/kernel.go | 3 +- compiler/check/returns/widen.go | 631 +----------------- compiler/check/returns/widen_test.go | 281 -------- 24 files changed, 1285 insertions(+), 1165 deletions(-) rename compiler/check/{infer => domain}/paramevidence/doc.go (51%) create mode 100644 compiler/check/domain/paramevidence/merge.go create mode 100644 compiler/check/domain/paramevidence/merge_test.go rename compiler/check/{infer => domain}/paramevidence/parameter_evidence.go (100%) rename compiler/check/{infer => domain}/paramevidence/parameter_evidence_test.go (100%) rename compiler/check/{infer => domain}/paramevidence/project.go (100%) create mode 100644 compiler/check/domain/value/shape.go create mode 100644 compiler/check/domain/value/shape_test.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 0b07320a..f27c8273 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -33,7 +33,8 @@ This is intentionally not a bridge. 
No production code reads a legacy Second cleanup slice in the same migration: - the local inference package was renamed from `infer/paramhints` to - `infer/paramevidence`; + `infer/paramevidence`, then the domain was moved to + `domain/paramevidence` when its lattice laws were consolidated; - `LocalFuncInfo.ParamHints` became `LocalFuncInfo.ParameterEvidence`; - phase input `ParamHintSignatures` became `ParameterEvidenceSignatures`; - local call-graph propagation now exposes `PropagateParameterEvidence`; @@ -69,6 +70,60 @@ Remaining cleanup after this parameter-evidence slice: public functions that validate invalid input with `type(...)` guards and should infer a wider accepted input domain without weakening the guarded body. +## 2026-05-19 Domain Rectification Checkpoint + +The next flash-migration slice moved parameter evidence out of inference/return +orchestration and into a domain owner: + +- `compiler/check/infer/paramevidence` was moved to + `compiler/check/domain/paramevidence`; +- shared value-shape predicates that were duplicated during the first move were + factored into `compiler/check/domain/value`; +- parameter-evidence vector/map normalization, join, widening, table-top + absorption, nilability splitting, soft/concrete selection, and truthy-key + refinement now live under domain packages; +- `returns` no longer owns parameter evidence merge helpers. Function-fact + parameter slots delegate to `paramevidence.JoinVectors`, + `paramevidence.FilterEmptyVector`, and `paramevidence.RefinesFunctionParam`; +- return-summary and parameter-evidence code both call `domain/value` for + optional elision, truthy refinements, soft/concrete preference, recursive + structural scanning, and record-extension checks; +- parameter-evidence law tests moved with the domain, so the tests describe the + owner instead of the old return package. + +This is not a compatibility bridge. The old package path and old +`WidenParameterEvidence` API were deleted. 
Call sites moved directly to the +domain package. + +Verification for this slice so far: + +- `go test ./compiler/check/domain/value` passes. +- `go test ./compiler/check/domain/paramevidence` passes. +- `go test ./compiler/check/returns` passes. +- `go test ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. +- Standard `../scripts/verify-suite.sh` passes the go-lua checker tests and + Wippy binary build, then exits non-zero on external lint targets while the + Wippy checkout is still using its pinned go-lua module: session 8 errors, + agent/src 8 errors, docker-demo 21 errors and 2 warnings. +- Local-replace replay with + `WIPPY_DIR=/tmp/wippy-golua-local-replace GOFLAGS=-buildvcs=false` also + passes the go-lua checker tests and Wippy binary build, then exits non-zero + on known external diagnostics: tests/app 2 errors/4 warnings, session 20, + actor/test 3, agent/src 11, docker-demo 72, llm/src 9, llm/test 9, + migration 1, views 1. + +Design result: + +- orchestration still decides when evidence is collected from calls, body use, + post-flow observations, or signatures; +- the parameter-evidence domain now decides how evidence combines; +- the value domain owns shared structural predicates instead of duplicating them + under returns and parameter evidence; +- helper names that encode parameter-specific lattice laws are no longer local + return-package predicates. + ## Goal The checker should read as one abstract interpreter over a product domain. @@ -3112,9 +3167,11 @@ type ParamSlotDomain struct { } ``` -The current `candidateRefinesFunctionParam`, `typeRefinesTableKeyByTruthiness`, -`preferConcreteOverSoftType`, and related functions become methods or private -support functions of this domain. +The previous `candidateRefinesFunctionParam`, +`typeRefinesTableKeyByTruthiness`, `preferConcreteOverSoftType`, and related +return-package functions are being collapsed into domain-owned operations. 
+Parameter-specific pieces now live in `domain/paramevidence`; the remaining +function-fact merge should move to `domain/functionfact`. ### Return Summary Domain @@ -3158,12 +3215,13 @@ Candidate home: compiler/check/domain/paramevidence ``` -Current split: +Current state after the first domain slice: -- some policy lives in `compiler/check/infer/paramevidence`, -- some lives in `compiler/check/returns/widen.go`, -- some lives in return SCC inference, -- some lives in interproc postflow. +- merge/canonicalization policy lives in `compiler/check/domain/paramevidence`; +- return SCC inference and interproc postflow still collect observations, but + they call the domain to merge them; +- remaining work is to separate collection orchestration from the pure domain + surface where it improves readability without adding a bridge. Final rule: @@ -3210,10 +3268,10 @@ helper joins directly. |---|---|---| | `JoinFacts`, `WidenFacts`, fact equality | `compiler/check/returns` | `domain/interproc` | | function fact type merge | `compiler/check/returns/join.go` | `domain/functionfact` | -| function param-slot refinement | `compiler/check/returns/join.go`, `widen.go` | `domain/functionfact.ParamSlotDomain` | +| function param-slot refinement | `domain/paramevidence` plus `domain/value`, called by `returns/join.go`, `widen.go` | `domain/functionfact.ParamSlotDomain` delegating value refinements to `domain/paramevidence`/`domain/value` | | return-vector merge/repair | `compiler/check/returns/join.go` | `domain/returnsummary` | -| table-top absorption | `infer/paramevidence`, `returns/widen.go` | `domain/paramevidence` plus value-domain classifier | -| soft vs concrete evidence | `typ/soft.go`, `returns/widen.go`, return overlay | `domain/value` evidence policy | +| table-top absorption | `domain/paramevidence` | `domain/paramevidence` plus value-domain classifier | +| soft vs concrete evidence | `typ/soft.go`, `domain/value`, return overlay | `domain/value` evidence policy | | 
open-record row-tail merge | `types/typ/policy.go` | `domain/value` row-shape policy | | path/query/alias identity | `constraint`, `flowbuild/path`, `flow/pathkey` | `memory` | | table/container mutation replay | `nested`, `returns`, `flowbuild`, `flow` | `memory` mutation domain | @@ -3226,7 +3284,7 @@ helper joins directly. ### Table-Key Truthiness Refinement -Current smell: +Previous smell: ```go candidateRefinesFunctionParam(candidate, baseline) diff --git a/compiler/check/infer/paramevidence/doc.go b/compiler/check/domain/paramevidence/doc.go similarity index 51% rename from compiler/check/infer/paramevidence/doc.go rename to compiler/check/domain/paramevidence/doc.go index f1e149dc..51f15cfa 100644 --- a/compiler/check/infer/paramevidence/doc.go +++ b/compiler/check/domain/paramevidence/doc.go @@ -1,7 +1,9 @@ -// Package paramevidence computes parameter evidence from call-site arguments. +// Package paramevidence owns the parameter-evidence domain. // -// This package analyzes function call sites and body uses to build effective -// parameter types for functions without explicit type annotations. +// The domain canonicalizes, joins, and widens observations from call sites, +// body-derived contracts, and signature facts. Orchestration packages decide +// when an observation is produced; this package decides what that observation +// means and how it combines with prior evidence. // // # Evidence Collection // @@ -9,8 +11,8 @@ // // foo(123, "bar") -- evidence: param1=number, param2=string // -// The package collects argument types and associates them with parameter -// positions. Multiple call sites contribute evidence that is joined. +// Call-site analysis collects argument types and associates them with parameter +// positions. Multiple call sites contribute evidence that is joined here. 
// // # Evidence Merging // @@ -23,6 +25,6 @@ // // # Integration // -// Parameter evidence feeds into function signature inference, providing -// types for parameters that lack explicit annotations. +// Parameter evidence feeds into function signature inference, providing types +// for parameters that lack explicit annotations. package paramevidence diff --git a/compiler/check/domain/paramevidence/merge.go b/compiler/check/domain/paramevidence/merge.go new file mode 100644 index 00000000..e5900953 --- /dev/null +++ b/compiler/check/domain/paramevidence/merge.go @@ -0,0 +1,438 @@ +package paramevidence + +import ( + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/domain/value" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// WidenMap merges two parameter evidence maps with the same vector law used by +// canonical FunctionFacts. +func WidenMap(prev, next map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { + if prev == nil && next == nil { + return nil + } + if prev == nil { + return FilterEmptyMap(next) + } + if next == nil { + return FilterEmptyMap(prev) + } + merged := make(map[cfg.SymbolID][]typ.Type, len(prev)+len(next)) + for _, sym := range cfg.SortedSymbolIDs(prev) { + evidence := NormalizeVector(prev[sym]) + if hasNonNilEvidence(evidence) { + merged[sym] = evidence + } + } + for _, sym := range cfg.SortedSymbolIDs(next) { + evidence := NormalizeVector(next[sym]) + if !hasNonNilEvidence(evidence) { + continue + } + if existing := merged[sym]; existing != nil { + merged[sym] = JoinVectors(existing, evidence) + } else { + merged[sym] = evidence + } + } + return merged +} + +// FilterEmptyMap normalizes evidence and drops entries with no informative +// slots. 
+func FilterEmptyMap(evidence map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { + if evidence == nil { + return nil + } + out := make(map[cfg.SymbolID][]typ.Type, len(evidence)) + for _, sym := range cfg.SortedSymbolIDs(evidence) { + v := FilterEmptyVector(evidence[sym]) + if hasNonNilEvidence(v) { + out[sym] = v + } + } + if len(out) == 0 { + return nil + } + return out +} + +// FilterEmptyVector normalizes one evidence vector and returns nil when all +// slots are empty. +func FilterEmptyVector(evidence []typ.Type) []typ.Type { + v := NormalizeVector(evidence) + if !hasNonNilEvidence(v) { + return nil + } + return v +} + +// NormalizeVector canonicalizes all occupied evidence slots. +func NormalizeVector(evidence []typ.Type) []typ.Type { + var out []typ.Type + for i, observed := range evidence { + normalized := NormalizeType(observed) + if out != nil { + out[i] = normalized + continue + } + if !typ.TypeEquals(observed, normalized) { + out = make([]typ.Type, len(evidence)) + copy(out, evidence[:i]) + out[i] = normalized + } + } + if out != nil { + return out + } + return evidence +} + +func hasNonNilEvidence(evidence []typ.Type) bool { + for _, observed := range evidence { + if observed != nil { + return true + } + } + return false +} + +// JoinVectors joins two parameter evidence vectors element-wise. +func JoinVectors(a, b []typ.Type) []typ.Type { + if len(a) == 0 { + return b + } + if len(b) == 0 { + return a + } + maxLen := len(a) + if len(b) > maxLen { + maxLen = len(b) + } + result := make([]typ.Type, maxLen) + for i := 0; i < maxLen; i++ { + var ai, bi typ.Type + if i < len(a) { + ai = a[i] + } + if i < len(b) { + bi = b[i] + } + result[i] = Join(ai, bi) + } + return result +} + +// Join merges two parameter evidence observations. 
+func Join(a, b typ.Type) typ.Type { + a = NormalizeType(a) + b = NormalizeType(b) + if a == nil { + return b + } + if b == nil { + return a + } + if unwrap.IsNilType(a) && !unwrap.IsNilType(b) { + return b + } + if unwrap.IsNilType(b) && !unwrap.IsNilType(a) { + return a + } + if joined, ok := joinNilable(a, b); ok { + return joined + } + return joinNonNil(a, b) +} + +func joinNilable(a, b typ.Type) (typ.Type, bool) { + ai, anil := value.SplitNilable(a) + bi, bnil := value.SplitNilable(b) + if !anil && !bnil { + return nil, false + } + if ai == nil && bi == nil { + return typ.Nil, true + } + if ai == nil { + return typ.NewOptional(bi), true + } + if bi == nil { + return typ.NewOptional(ai), true + } + return typ.NewOptional(joinNonNil(ai, bi)), true +} + +func joinNonNil(a, b typ.Type) typ.Type { + if upper, ok := selectTableUpperBound(a, b); ok { + return upper + } + if preferred, ok := value.PreferConcreteOverSoft(a, b); ok { + return preferred + } + if value.CanSelfEmbed(a) && value.ContainsEquivalent(b, a) && !typ.IsAbsentOrUnknown(a) { + if value.ContainsUnion(a) { + return a + } + return typ.JoinPreferNonSoft(a, b) + } + if value.CanSelfEmbed(b) && value.ContainsEquivalent(a, b) && !typ.IsAbsentOrUnknown(b) { + if value.ContainsUnion(b) { + return b + } + return typ.JoinPreferNonSoft(a, b) + } + if value.IsTruthyRefinement(a, b) { + return a + } + if value.IsTruthyRefinement(b, a) { + return b + } + if joined, ok := typ.JoinCompatibleRecords(a, b); ok { + return joined + } + if joined, ok := joinMapRecord(a, b); ok { + return joined + } + if value.ExtendsRecord(a, b) { + return a + } + if value.ExtendsRecord(b, a) { + return b + } + if !typ.IsAbsentOrUnknown(a) && !typ.IsAbsentOrUnknown(b) { + if subtype.IsSubtype(a, b) { + return b + } + if subtype.IsSubtype(b, a) { + return a + } + } + return NormalizeType(typ.JoinPreferNonSoft(a, b)) +} + +func joinMapRecord(a, b typ.Type) (typ.Type, bool) { + if joined, ok := joinMapRecordDirected(a, b); ok { + return 
joined, true + } + return joinMapRecordDirected(b, a) +} + +func joinMapRecordDirected(mapType, recordType typ.Type) (typ.Type, bool) { + m, ok := unwrap.Alias(mapType).(*typ.Map) + if !ok || m == nil { + return nil, false + } + r, ok := unwrap.Alias(recordType).(*typ.Record) + if !ok || r == nil || !r.HasMapComponent() { + return nil, false + } + + key := joinNonNil(m.Key, r.MapKey) + value := joinNonNil(m.Value, r.MapValue) + if len(r.Fields) == 0 && r.Metatable == nil { + return typ.NewMap(key, value), true + } + builder := typ.NewRecord() + if r.Open { + builder.SetOpen(true) + } + if r.Metatable != nil { + builder.Metatable(r.Metatable) + } + builder.MapComponent(key, value) + for _, field := range r.Fields { + fieldType := field.Type + optional := true + if subtype.IsSubtype(typ.LiteralString(field.Name), key) { + fieldType = joinNonNil(field.Type, value) + } else { + optional = field.Optional + } + switch { + case optional && field.Readonly: + builder.OptReadonlyField(field.Name, fieldType) + case optional: + builder.OptField(field.Name, fieldType) + case field.Readonly: + builder.ReadonlyField(field.Name, fieldType) + default: + builder.Field(field.Name, fieldType) + } + } + return builder.Build(), true +} + +func selectTableUpperBound(a, b typ.Type) (typ.Type, bool) { + if isOnlyTableTop(a) && typ.IsAny(b) { + return a, true + } + if isOnlyTableTop(b) && typ.IsAny(a) { + return b, true + } + if containsTableTop(a) && coveredByTableTop(b) && subtype.IsSubtype(b, a) { + return a, true + } + if containsTableTop(b) && coveredByTableTop(a) && subtype.IsSubtype(a, b) { + return b, true + } + return nil, false +} + +func containsTableTop(t typ.Type) bool { + if t == nil { + return false + } + if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return containsTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return containsTableTop(v.Inner) + case *typ.Union: + for _, member 
:= range v.Members { + if containsTableTop(member) { + return true + } + } + } + return false +} + +func isOnlyTableTop(t typ.Type) bool { + if t == nil { + return false + } + if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return isOnlyTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return isOnlyTableTop(v.Inner) + case *typ.Union: + if len(v.Members) == 0 { + return false + } + hasTableTop := false + for _, member := range v.Members { + if unwrap.IsNilType(member) { + continue + } + if !isOnlyTableTop(member) { + return false + } + hasTableTop = true + } + return hasTableTop + default: + return false + } +} + +func coveredByTableTop(t typ.Type) bool { + if t == nil { + return false + } + if typ.IsAny(t) { + return true + } + if unwrap.IsNilType(t) || unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { + return true + } + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + return coveredByTableTop(v.UnaliasedTarget()) + case *typ.Optional: + return coveredByTableTop(v.Inner) + case *typ.Recursive: + return v.Body != nil && v.Body != v && coveredByTableTop(v.Body) + case *typ.Union: + if len(v.Members) == 0 { + return false + } + for _, member := range v.Members { + if !coveredByTableTop(member) { + return false + } + } + return true + case *typ.Record, *typ.Map, *typ.Array, *typ.Tuple, *typ.Interface, *typ.Intersection: + return true + default: + return false + } +} + +// RefinesFunctionParam reports whether candidate is a valid directional +// refinement of baseline for parameter-slot facts. 
+func RefinesFunctionParam(candidate, baseline typ.Type) bool { + return value.ElidesOptional(candidate, baseline) || + value.IsTruthyRefinement(candidate, baseline) || + refinesTableKeyByTruthiness(candidate, baseline) +} + +func refinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { + if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { + return false + } + candidateInner, _ := value.SplitNilable(candidate) + baselineInner, _ := value.SplitNilable(baseline) + if candidateInner == nil || baselineInner == nil { + return false + } + return nonNilRefinesTableKeyByTruthiness(candidateInner, baselineInner) +} + +func nonNilRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { + candidate = unwrap.Alias(candidate) + baseline = unwrap.Alias(baseline) + switch b := baseline.(type) { + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok { + return false + } + return recordRefinesTableKeyByTruthiness(c, b) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false + } + return value.IsTruthyRefinement(c.Key, b.Key) && value.Equivalent(c.Value, b.Value) + default: + return false + } +} + +func recordRefinesTableKeyByTruthiness(candidate, baseline *typ.Record) bool { + if candidate == nil || baseline == nil || !candidate.HasMapComponent() || !baseline.HasMapComponent() { + return false + } + if candidate.Open != baseline.Open || len(candidate.Fields) != len(baseline.Fields) { + return false + } + if (candidate.Metatable == nil) != (baseline.Metatable == nil) { + return false + } + if candidate.Metatable != nil && !typ.TypeEquals(candidate.Metatable, baseline.Metatable) { + return false + } + for i, field := range candidate.Fields { + other := baseline.Fields[i] + if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { + return false + } + if !value.Equivalent(field.Type, other.Type) { + return false + } + } + return value.IsTruthyRefinement(candidate.MapKey, 
baseline.MapKey) && + value.Equivalent(candidate.MapValue, baseline.MapValue) +} diff --git a/compiler/check/domain/paramevidence/merge_test.go b/compiler/check/domain/paramevidence/merge_test.go new file mode 100644 index 00000000..e1e9d1a6 --- /dev/null +++ b/compiler/check/domain/paramevidence/merge_test.go @@ -0,0 +1,300 @@ +package paramevidence + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/types/typ" +) + +func TestWidenMap_StopsSelfEmbeddingRecordGrowth(t *testing.T) { + prevHint := typ.NewUnion( + typ.Number, + typ.NewRecord(). + Field("limit", typ.Any). + SetOpen(true). + Build(), + ) + nextHint := typ.NewRecord(). + Field("limit", prevHint). + SetOpen(true). + Build() + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {prevHint}}, + map[cfg.SymbolID][]typ.Type{1: {nextHint}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, prevHint) { + t.Fatalf("expected stable previous evidence, got %v", got) + } +} + +func TestWidenMap_StopsSelfEmbeddingContainerGrowth(t *testing.T) { + prevHint := typ.NewUnion( + typ.Number, + typ.NewRecord(). + Field("limit", typ.Any). + SetOpen(true). + Build(), + ) + + tests := []struct { + name string + next typ.Type + }{ + { + name: "record", + next: typ.NewRecord(). + Field("value", prevHint). + SetOpen(true). + Build(), + }, + { + name: "array", + next: typ.NewArray(prevHint), + }, + { + name: "map", + next: typ.NewMap(typ.String, prevHint), + }, + { + name: "tuple", + next: typ.NewTuple(prevHint), + }, + { + name: "function", + next: typ.Func(). + Param("value", prevHint). + Returns(prevHint). 
+ Build(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {prevHint}}, + map[cfg.SymbolID][]typ.Type{1: {tt.next}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, prevHint) { + t.Fatalf("expected stable previous evidence, got %v", got) + } + }) + } +} + +func TestWidenMap_KeepsFirstRecordWrapperObservation(t *testing.T) { + nextHint := typ.NewRecord(). + Field("limit", typ.Number). + SetOpen(true). + Build() + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {typ.Number}}, + map[cfg.SymbolID][]typ.Type{1: {nextHint}}, + ) + + got := merged[1][0] + if typ.TypeEquals(got, typ.Number) { + t.Fatalf("expected wrapper observation to be preserved, got %v", got) + } + if !typ.TypeEquals(got, typ.NewUnion(typ.Number, nextHint)) { + t.Fatalf("expected number | wrapper evidence, got %v", got) + } +} + +func TestWidenMap_JoinsNestedRecordObservations(t *testing.T) { + nested := typ.NewRecord(). + Field("routes", typ.NewRecord().Field("users", typ.Boolean).SetOpen(true).Build()). + SetOpen(true). + Build() + outer := typ.NewRecord(). + Field("api", nested). + SetOpen(true). 
+ Build() + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {outer}}, + map[cfg.SymbolID][]typ.Type{1: {nested}}, + ) + + got := merged[1][0] + want := typ.NewUnion(outer, nested) + if !typ.TypeEquals(got, want) { + t.Fatalf("expected nested record observations to be joined as %v, got %v", want, got) + } +} + +func TestWidenMap_ReplacesStaleBroadHintWithCurrentRefinement(t *testing.T) { + stale := typ.NewUnion(typ.String, typ.False) + current := typ.String + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {stale}}, + map[cfg.SymbolID][]typ.Type{1: {current}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, current) { + t.Fatalf("expected current refined evidence %v to replace stale broad evidence, got %v", current, got) + } +} + +func TestWidenMap_ReplacesSoftContainerPlaceholderWithConcreteElementShape(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + stale := typ.NewUnion( + typ.NewArray(typ.Any), + typ.NewRecord().SetOpen(true).Build(), + ) + current := typ.NewArray(entry) + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {stale}}, + map[cfg.SymbolID][]typ.Type{1: {current}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, current) { + t.Fatalf("expected concrete array evidence %v to replace soft stale evidence, got %v", current, got) + } +} + +func TestWidenMap_PreservesStructuredHintOverNilOnlyObservation(t *testing.T) { + context := typ.NewMap(typ.String, typ.Any) + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, context}}, + map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, typ.Nil}}, + ) + + got := merged[1][2] + if !typ.TypeEquals(got, context) { + t.Fatalf("expected nil-only observation to preserve structured evidence %v, got %v", context, got) + } + + again := WidenMap(merged, map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, typ.Nil}}) + if !evidenceMapsEqual(merged, again) { + t.Fatalf("expected idempotent nil-only observation widening, got %v then %v", 
merged, again) + } +} + +func TestWidenMap_PreservesMapHintOverOptionalOpenRecordObservation(t *testing.T) { + context := typ.NewMap(typ.String, typ.Any) + optionalContextRecord := typ.NewOptional(typ.NewRecord(). + MapComponent(typ.String, typ.Any). + SetOpen(true). + Build()) + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, context}}, + map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, optionalContextRecord}}, + ) + + got := merged[1][2] + if got == nil || typ.TypeEquals(got, typ.Nil) { + t.Fatalf("expected optional structured observation to preserve context evidence, got %v", got) + } + if !typ.TypeEquals(got, typ.NewOptional(context)) { + t.Fatalf("expected pure map observation to stay canonical, got %v", got) + } + + again := WidenMap(merged, map[cfg.SymbolID][]typ.Type{1: {typ.String, typ.Any, optionalContextRecord}}) + if !evidenceMapsEqual(merged, again) { + t.Fatalf("expected idempotent optional structured observation widening, got %v then %v", merged, again) + } +} + +func TestWidenMap_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + canonical := typ.NewMap(typ.String, typ.NewArray(entry)) + staleRecordView := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.String, typ.False), typ.NewArray(entry)). + SetOpen(true). + Build() + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {staleRecordView}}, + map[cfg.SymbolID][]typ.Type{1: {canonical}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, canonical) { + t.Fatalf("expected pure keyed table evidence to canonicalize to %v, got %v", canonical, got) + } +} + +func TestWidenMap_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { + tableTop := typ.NewOptional(typ.NewInterface("table", nil)) + strategySpec := typ.NewRecord(). + Field("kind", typ.LiteralString("strategy")). + Field("tools", typ.NewTuple(typ.String, typ.String, typ.String)). + Build() + contextSpec := typ.NewRecord(). 
+ Field("kind", typ.LiteralString("context")). + Field("scope", typ.String). + Build() + nextHint := typ.NewUnion(strategySpec, contextSpec) + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {tableTop}}, + map[cfg.SymbolID][]typ.Type{1: {nextHint}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, tableTop) { + t.Fatalf("expected table top upper bound %v, got %v", tableTop, got) + } + + again := WidenMap(merged, map[cfg.SymbolID][]typ.Type{1: {nextHint}}) + if !evidenceMapsEqual(merged, again) { + t.Fatalf("expected idempotent table-top widening, got %v then %v", merged, again) + } +} + +func TestWidenMap_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { + tableTop := typ.NewOptional(typ.NewInterface("table", nil)) + + merged := WidenMap( + map[cfg.SymbolID][]typ.Type{1: {tableTop}}, + map[cfg.SymbolID][]typ.Type{1: {typ.Any}}, + ) + + got := merged[1][0] + if !typ.TypeEquals(got, tableTop) { + t.Fatalf("expected dynamic observation to preserve table top upper bound %v, got %v", tableTop, got) + } + + again := WidenMap(merged, map[cfg.SymbolID][]typ.Type{1: {typ.Any}}) + if !evidenceMapsEqual(merged, again) { + t.Fatalf("expected idempotent table-top/any widening, got %v then %v", merged, again) + } +} + +func evidenceMapsEqual(a, b map[cfg.SymbolID][]typ.Type) bool { + if len(a) != len(b) { + return false + } + for _, sym := range cfg.SortedSymbolIDs(a) { + right, ok := b[sym] + if !ok || !evidenceVectorsEqual(a[sym], right) { + return false + } + } + return true +} + +func evidenceVectorsEqual(a, b []typ.Type) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !typ.TypeEquals(a[i], b[i]) { + return false + } + } + return true +} diff --git a/compiler/check/infer/paramevidence/parameter_evidence.go b/compiler/check/domain/paramevidence/parameter_evidence.go similarity index 100% rename from compiler/check/infer/paramevidence/parameter_evidence.go rename to compiler/check/domain/paramevidence/parameter_evidence.go diff 
--git a/compiler/check/infer/paramevidence/parameter_evidence_test.go b/compiler/check/domain/paramevidence/parameter_evidence_test.go similarity index 100% rename from compiler/check/infer/paramevidence/parameter_evidence_test.go rename to compiler/check/domain/paramevidence/parameter_evidence_test.go diff --git a/compiler/check/infer/paramevidence/project.go b/compiler/check/domain/paramevidence/project.go similarity index 100% rename from compiler/check/infer/paramevidence/project.go rename to compiler/check/domain/paramevidence/project.go diff --git a/compiler/check/domain/value/shape.go b/compiler/check/domain/value/shape.go new file mode 100644 index 00000000..8a2beae6 --- /dev/null +++ b/compiler/check/domain/value/shape.go @@ -0,0 +1,391 @@ +package value + +import ( + "github.com/wippyai/go-lua/internal" + "github.com/wippyai/go-lua/types/narrow" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// Equivalent reports structural equality or mutual subtyping. +func Equivalent(a, b typ.Type) bool { + return typ.TypeEquals(a, b) || (subtype.IsSubtype(a, b) && subtype.IsSubtype(b, a)) +} + +// ElidesOptional reports whether candidate is inside baseline after nil is +// removed from baseline. +func ElidesOptional(candidate, baseline typ.Type) bool { + if candidate == nil || baseline == nil { + return false + } + nonNil := narrow.RemoveNil(baseline) + if nonNil == nil || typ.TypeEquals(nonNil, baseline) { + return false + } + return subtype.IsSubtype(candidate, nonNil) +} + +// SplitNilable separates the non-nil component from an optional/nilable type. 
+func SplitNilable(t typ.Type) (typ.Type, bool) {
+	t = unwrap.Alias(t)
+	if t == nil {
+		return nil, false
+	}
+	if opt, ok := t.(*typ.Optional); ok {
+		return opt.Inner, true
+	}
+	union, isUnion := t.(*typ.Union)
+	if !isUnion {
+		// Not a wrapper: the bare nil type splits into (nil, true), all else is kept.
+		if unwrap.IsNilType(t) {
+			return nil, true
+		}
+		return t, false
+	}
+	// Keep the alias-unwrapped non-nil members; a dropped member means nilable.
+	kept := make([]typ.Type, 0, len(union.Members))
+	for _, member := range union.Members {
+		if member = unwrap.Alias(member); !unwrap.IsNilType(member) {
+			kept = append(kept, member)
+		}
+	}
+	if len(kept) == len(union.Members) {
+		// No nil member was found: hand the union back untouched.
+		return t, false
+	}
+	return typ.NewUnion(kept...), true
+}
+
+// IsTruthyRefinement reports whether candidate lies within the truthy narrowing
+// of baseline, provided that narrowing actually changes baseline.
+func IsTruthyRefinement(candidate, baseline typ.Type) bool {
+	if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) {
+		return false
+	}
+	truthy := narrow.ToTruthy(baseline)
+	// The narrowing only counts when it is non-nil, not never, and not baseline itself.
+	narrowed := truthy != nil && !truthy.Kind().IsNever() && !typ.TypeEquals(truthy, baseline)
+	// candidate qualifies by equalling the narrowed type or by being a subtype of it.
+	return narrowed && (typ.TypeEquals(candidate, truthy) || subtype.IsSubtype(candidate, truthy))
+}
+
+// PreferConcreteOverSoft chooses the concrete observation when the other one is
+// a soft placeholder, preserving explicit nilability; ok is false if no rule applies.
+func PreferConcreteOverSoft(a, b typ.Type) (typ.Type, bool) {
+	softA := typ.IsSoft(a, typ.SoftPlaceholderPolicy)
+	softB := typ.IsSoft(b, typ.SoftPlaceholderPolicy)
+	if softA != softB { // exactly one side is a placeholder
+		concrete := a
+		if softA {
+			concrete = b
+		}
+		if !unwrap.IsNilType(concrete) {
+			return concrete, true
+		}
+	}
+	return preferConcreteOverNilableSoft(a, b)
+}
+
+// preferConcreteOverNilableSoft tries the directed rule with a as the soft side, then with b.
+func preferConcreteOverNilableSoft(a, b typ.Type) (typ.Type, bool) {
+	preferred, ok := preferConcreteOverNilableSoftDirected(a, b)
+	if !ok {
+		preferred, ok = preferConcreteOverNilableSoftDirected(b, a)
+	}
+	return preferred, ok
+}
+
+// preferConcreteOverNilableSoftDirected requires softMaybeNil to be a nilable soft placeholder.
+func preferConcreteOverNilableSoftDirected(softMaybeNil, concrete typ.Type) (typ.Type, bool) {
+	softInner, softNilable := SplitNilable(softMaybeNil)
+	if !softNilable || softInner == nil || !typ.IsSoft(softInner, typ.SoftPlaceholderPolicy) {
+		return nil, false
+	}
+	if concrete == nil || unwrap.IsNilType(concrete) {
+		return nil, false
+	}
+	payload, alreadyNilable := SplitNilable(concrete)
+	if payload == nil || typ.IsSoft(payload, typ.SoftPlaceholderPolicy) {
+		return nil, false
+	}
+	if !alreadyNilable {
+		return typ.NewOptional(concrete), true // re-attach the soft side's nilability
+	}
+	return concrete, true
+}
+
+// CanSelfEmbed reports whether t is, or wraps, a structural type that can hold
+// other types beneath it (array, map, tuple, record or function).
+func CanSelfEmbed(t typ.Type) bool {
+	switch v := t.(type) {
+	case nil:
+		return false
+	case *typ.Array, *typ.Map, *typ.Tuple, *typ.Record, *typ.Function:
+		// Structural types: these hold other types beneath them.
+		return true
+	case *typ.Annotated:
+		return CanSelfEmbed(v.Inner)
+	case *typ.Alias:
+		return CanSelfEmbed(v.Target)
+	case *typ.Optional:
+		return CanSelfEmbed(v.Inner)
+	case *typ.Union:
+		for i := range v.Members {
+			if CanSelfEmbed(v.Members[i]) {
+				return true
+			}
+		}
+	case *typ.Intersection:
+		for i := range v.Members {
+			if CanSelfEmbed(v.Members[i]) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// ContainsEquivalent reports whether haystack, or any node Scan reaches below
+// it, is typ.TypeEquals to needle. The comparison is structural equality only;
+// the mutual-subtyping arm of Equivalent is not consulted.
+func ContainsEquivalent(haystack, needle typ.Type) bool {
+	if haystack == nil || needle == nil {
+		return false
+	}
+	matchNeedle := func(node typ.Type) (stop bool, descend bool) {
+		found := typ.TypeEquals(node, needle)
+		return found, !found
+	}
+	return Scan(haystack, typ.NewGuard(), matchNeedle)
+}
+
+// ContainsUnion reports whether t is a union or has a union anywhere among the
+// nodes Scan visits beneath it.
+func ContainsUnion(t typ.Type) bool {
+	if t == nil {
+		return false
+	}
+	return Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) {
+		_, isUnion := node.(*typ.Union)
+		return isUnion, !isUnion
+	})
+}
+
+// Scan walks t depth-first, calling visit on every node it reaches after
+// stripping Annotated wrappers. visit returns (stop, descend): stop ends the
+// whole walk and makes Scan return true, while a false descend skips that
+// node's children. Scan returns false when no visit asked to stop, and treats
+// nil types and nodes refused by guard as empty.
+func Scan(
+	t typ.Type,
+	guard internal.RecursionGuard,
+	visit func(node typ.Type) (stop bool, descend bool),
+) bool {
+	if t == nil {
+		return false
+	}
+	// The guard decides whether t may be entered at all and hands back the
+	// guard value that is used for everything below t.
+	inner, entered := guard.Enter(t)
+	if !entered {
+		return false
+	}
+
+	// Peel Annotated wrappers, stopping at an empty or self-referential one.
+	node := t
+	for {
+		wrapper, isAnnotated := node.(*typ.Annotated)
+		if !isAnnotated || wrapper.Inner == nil || wrapper.Inner == node {
+			break
+		}
+		node = wrapper.Inner
+	}
+
+	// visit sees the unwrapped node and steers the rest of the walk.
+	stop, descend := visit(node)
+	if stop || !descend {
+		return stop
+	}
+
+	// Children are scanned in declaration order; the first hit ends the walk.
+	switch n := node.(type) {
+	case *typ.Optional:
+		return Scan(n.Inner, inner, visit)
+	case *typ.Alias:
+		return Scan(n.Target, inner, visit)
+	case *typ.Array:
+		return Scan(n.Element, inner, visit)
+	case *typ.Map:
+		return Scan(n.Key, inner, visit) || Scan(n.Value, inner, visit)
+	case *typ.Union:
+		for _, member := range n.Members {
+			if Scan(member, inner, visit) {
+				return true
+			}
+		}
+	case *typ.Intersection:
+		for _, member := range n.Members {
+			if Scan(member, inner, visit) {
+				return true
+			}
+		}
+	case *typ.Tuple:
+		for _, element := range n.Elements {
+			if Scan(element, inner, visit) {
+				return true
+			}
+		}
+	case *typ.Instantiated:
+		for _, arg := range n.TypeArgs {
+			if Scan(arg, inner, visit) {
+				return true
+			}
+		}
+	case *typ.Interface:
+		// Methods whose Type is nil are skipped.
+		for _, method := range n.Methods {
+			if method.Type != nil && Scan(method.Type, inner, visit) {
+				return true
+			}
+		}
+	case *typ.Function:
+		// Parameters first, then results, then the variadic tail if present.
+		for _, param := range n.Params {
+			if Scan(param.Type, inner, visit) {
+				return true
+			}
+		}
+		for _, result := range n.Returns {
+			if Scan(result, inner, visit) {
+				return true
+			}
+		}
+		return n.Variadic != nil && Scan(n.Variadic, inner, visit)
+	case *typ.Record:
+		// Fields, then the metatable, then the map component (key before value).
+		for _, field := range n.Fields {
+			if Scan(field.Type, inner, visit) {
+				return true
+			}
+		}
+		if n.Metatable != nil && Scan(n.Metatable, inner, visit) {
+			return true
+		}
+		if n.HasMapComponent() {
+			return Scan(n.MapKey, inner, visit) || Scan(n.MapValue, inner, visit)
+		}
+	}
+	return false
+}
+
+// ExtendsRecord reports whether a extends b by
adding record fields. This
+// treats record field supersets as refinements when b is a record or union of
+// records.
+func ExtendsRecord(a, b typ.Type) bool {
+	if a == nil || b == nil {
+		return false
+	}
+	// Only a bare *typ.Record can extend anything; no alias unwrapping is done here.
+	candidate, isRecord := a.(*typ.Record)
+	if !isRecord {
+		return false
+	}
+	if base, ok := b.(*typ.Record); ok {
+		return RecordSuperset(candidate, base)
+	}
+	if base, ok := b.(*typ.Union); ok {
+		return recordSupersetUnion(candidate, base)
+	}
+	return false
+}
+
+// RecordSuperset reports whether newRec keeps everything oldRec promises
+// (metatable, map component, every field) while being free to add fields.
+func RecordSuperset(newRec, oldRec *typ.Record) bool {
+	if newRec == nil || oldRec == nil {
+		return false
+	}
+	// An old metatable must be matched by a subtype-compatible new one.
+	if mt := oldRec.Metatable; mt != nil {
+		if newRec.Metatable == nil || !subtype.IsSubtype(newRec.Metatable, mt) {
+			return false
+		}
+	}
+	// Likewise for the map component: key and value must both be subtypes.
+	if oldRec.HasMapComponent() {
+		mapOK := newRec.HasMapComponent() &&
+			subtype.IsSubtype(newRec.MapKey, oldRec.MapKey) &&
+			subtype.IsSubtype(newRec.MapValue, oldRec.MapValue)
+		if !mapOK {
+			return false
+		}
+	}
+	// pending holds the old fields not yet matched by a field of newRec.
+	pending := make(map[string]typ.Field, len(oldRec.Fields))
+	for _, f := range oldRec.Fields {
+		pending[f.Name] = f
+	}
+	for _, added := range newRec.Fields {
+		prior, shared := pending[added.Name]
+		if !shared {
+			continue
+		}
+		// A required field may not become optional, nor a readonly one writable.
+		if (added.Optional && !prior.Optional) || (prior.Readonly && !added.Readonly) {
+			return false
+		}
+		if prior.Type != nil {
+			// Open-top table placeholders must not dominate structured
+			// collection/record fields when selecting preferred summaries.
+			if IsOpenTopRecord(added.Type) && IsStructuredTableShape(prior.Type) {
+				return false
+			}
+			if added.Type == nil || !subtype.IsSubtype(added.Type, prior.Type) {
+				return false
+			}
+		}
+		delete(pending, added.Name)
+	}
+	// Any old field left unmatched means newRec dropped it.
+	return len(pending) == 0
+}
+
+// recordSupersetUnion reports whether newRec is a RecordSuperset of every
+// member of oldUnion; an empty union or any non-record member yields false.
+func recordSupersetUnion(newRec *typ.Record, oldUnion *typ.Union) bool {
+	if newRec == nil || oldUnion == nil || len(oldUnion.Members) == 0 {
+		return false
+	}
+	for _, member := range oldUnion.Members {
+		oldRec, isRecord := member.(*typ.Record)
+		if !isRecord || !RecordSuperset(newRec, oldRec) {
+			return false
+		}
+	}
+	return true
+}
+
+// IsOpenTopRecord reports whether t, after alias unwrapping, is an open record
+// that declares no fields and no map component.
+func IsOpenTopRecord(t typ.Type) bool {
+	// The rec != nil test rejects a typed nil *typ.Record.
+	if rec, ok := unwrap.Alias(t).(*typ.Record); ok && rec != nil {
+		return rec.Open && len(rec.Fields) == 0 && !rec.HasMapComponent()
+	}
+	return false
+}
+
+// IsStructuredTableShape reports whether t, after alias unwrapping, carries
+// real table structure: an array, a map, or a record with at least one field
+// or a map component.
+func IsStructuredTableShape(t typ.Type) bool {
+	switch v := unwrap.Alias(t).(type) {
+	case *typ.Array, *typ.Map:
+		return true
+	case *typ.Record:
+		// A record qualifies only once it declares fields or a map component.
+		return v.HasMapComponent() || len(v.Fields) > 0
+	}
+	return false
+}
diff --git a/compiler/check/domain/value/shape_test.go b/compiler/check/domain/value/shape_test.go
new file mode 100644
index 00000000..5377885e
--- /dev/null
+++ b/compiler/check/domain/value/shape_test.go
@@ -0,0 +1,30 @@
+package value
+
+import (
+	"testing"
+
+	"github.com/wippyai/go-lua/types/typ"
+)
+
+func TestExtendsRecord_NilTypes(t *testing.T) {
+	if ExtendsRecord(nil, typ.String) {
+		t.Error("nil a should not extend")
+	}
+	if ExtendsRecord(typ.String, nil) {
+		t.Error("nil b should not extend")
+	}
+}
+
+func TestExtendsRecord_NotRecord(t *testing.T) {
+	if ExtendsRecord(typ.String, typ.String) {
+		t.Error("non-record should not extend")
+	}
+}
+
+func TestExtendsRecord_MapComponentConsistency(t *testing.T) {
+	oldRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Build()
+	newRec := typ.NewRecord().Field("x", typ.Number).Build()
+	if ExtendsRecord(newRec, oldRec) {
+		t.Error("record without map component should not extend record with map component")
+	}
+}
diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go
index 39ce124c..4e440627 100644
--- a/compiler/check/infer/interproc/postflow.go
+++ b/compiler/check/infer/interproc/postflow.go
@@ -6,8 +6,8 @@ import (
 	"github.com/wippyai/go-lua/compiler/cfg"
"github.com/wippyai/go-lua/compiler/check/api" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/erreffect" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/nested" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index 15e6bae1..c1fadceb 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -43,8 +43,8 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" diff --git a/compiler/check/infer/return/overlay_pipeline.go b/compiler/check/infer/return/overlay_pipeline.go index b76194c6..602ea8fe 100644 --- a/compiler/check/infer/return/overlay_pipeline.go +++ b/compiler/check/infer/return/overlay_pipeline.go @@ -5,11 +5,11 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/flowbuild/assign" fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" 
"github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go index 7dea8e36..b8852216 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -4,7 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/typ" diff --git a/compiler/check/phase/scope.go b/compiler/check/phase/scope.go index a56aff9b..2a065fb6 100644 --- a/compiler/check/phase/scope.go +++ b/compiler/check/phase/scope.go @@ -27,7 +27,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth" diff --git a/compiler/check/phase/types_test.go b/compiler/check/phase/types_test.go index fcc9c227..f4081170 100644 --- a/compiler/check/phase/types_test.go +++ b/compiler/check/phase/types_test.go @@ -7,7 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/typ" diff 
--git a/compiler/check/pipeline/runner.go b/compiler/check/pipeline/runner.go index 6671222f..37c781b4 100644 --- a/compiler/check/pipeline/runner.go +++ b/compiler/check/pipeline/runner.go @@ -20,8 +20,8 @@ package pipeline import ( "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/infer/captured" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/scope" diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index 8b12bdaa..ddf75739 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -4,8 +4,8 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" "github.com/wippyai/go-lua/compiler/check/phase" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" diff --git a/compiler/check/returns/callgraph.go b/compiler/check/returns/callgraph.go index d9150885..f5fd2ec8 100644 --- a/compiler/check/returns/callgraph.go +++ b/compiler/check/returns/callgraph.go @@ -7,7 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" synthresolve "github.com/wippyai/go-lua/compiler/check/synth/phase/resolve" "github.com/wippyai/go-lua/types/typ" 
"github.com/wippyai/go-lua/types/typ/unwrap" diff --git a/compiler/check/returns/function_facts.go b/compiler/check/returns/function_facts.go index 958cecc5..e301a316 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -3,6 +3,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" ) func collectCanonicalFunctionFactSymbols(factSets ...api.FunctionFacts) []cfg.SymbolID { @@ -25,7 +26,7 @@ func markFunctionFactSymbols[T any](dst map[cfg.SymbolID]bool, src map[cfg.Symbo func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { return api.FunctionFact{ - Params: filterEmptyParameterEvidenceVector(ff.Params), + Params: paramevidence.FilterEmptyVector(ff.Params), Summary: canonicalReturnVector(ff.Summary), Narrow: canonicalReturnVector(ff.Narrow), Type: normalizeInterprocValueType(ff.Type), diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 32247006..6b30b909 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -1,6 +1,8 @@ package returns import ( + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/subtype" @@ -72,20 +74,10 @@ func ReturnTypesExtendRecord(a, b []typ.Type) bool { return false } for i := range a { - ar, ok := a[i].(*typ.Record) - if !ok { + if _, ok := a[i].(*typ.Record); !ok { return false } - switch br := b[i].(type) { - case *typ.Record: - if !recordSuperset(ar, br) { - return false - } - case *typ.Union: - if !recordSupersetUnion(ar, br) { - return false - } - default: + if !value.ExtendsRecord(a[i], b[i]) { return false } } @@ -101,7 +93,7 @@ func ReturnTypesElideOptional(a, b []typ.Type) bool { return false } for i := 
range a { - if !typeElidesOptional(a[i], b[i]) { + if !value.ElidesOptional(a[i], b[i]) { return false } } @@ -216,7 +208,7 @@ func typeRefinesSoftContainer(candidate, baseline typ.Type) (bool, bool) { return typeRefinesSoftContainerSlot(c.Element, b.Element) case *typ.Map: c, ok := candidate.(*typ.Map) - if !ok || !equivalentParamValueType(c.Key, b.Key) { + if !ok || !value.Equivalent(c.Key, b.Key) { return false, false } return typeRefinesSoftContainerSlot(c.Value, b.Value) @@ -228,7 +220,7 @@ func typeRefinesSoftContainer(candidate, baseline typ.Type) (bool, bool) { if !c.HasMapComponent() && !b.HasMapComponent() { return true, false } - if !c.HasMapComponent() || !b.HasMapComponent() || !equivalentParamValueType(c.MapKey, b.MapKey) { + if !c.HasMapComponent() || !b.HasMapComponent() || !value.Equivalent(c.MapKey, b.MapKey) { return false, false } return typeRefinesSoftContainerSlot(c.MapValue, b.MapValue) @@ -241,10 +233,10 @@ func typeRefinesSoftContainerSlot(candidate, baseline typ.Type) (bool, bool) { if typ.TypeEquals(candidate, baseline) { return true, false } - if (typ.IsAny(baseline) || typ.IsUnknown(baseline)) && typeCanSelfEmbed(candidate) { + if (typ.IsAny(baseline) || typ.IsUnknown(baseline)) && value.CanSelfEmbed(candidate) { return false, false } - preferred, ok := preferConcreteOverSoftType(baseline, candidate) + preferred, ok := value.PreferConcreteOverSoft(baseline, candidate) return ok && typ.TypeEquals(preferred, candidate), ok } @@ -365,7 +357,7 @@ func truthyElementRefinement(candidate, baseline typ.Type) (bool, bool) { if typ.TypeEquals(candidate, baseline) { return true, false } - if typeIsTruthyRefinement(candidate, baseline) { + if value.IsTruthyRefinement(candidate, baseline) { return true, true } return false, false @@ -471,7 +463,7 @@ func ReturnTypesStopRecursiveStructuralGrowth(stable, growing []typ.Type) bool { if typ.TypeEquals(s, g) { continue } - if typ.IsAbsentOrUnknown(s) || !typeCanSelfEmbed(s) { + if typ.IsAbsentOrUnknown(s) 
|| !value.CanSelfEmbed(s) { return false } if !shallowStructuralShapeEquals(g, s) { @@ -741,7 +733,7 @@ func ReturnTypesFillNilSlots(a, b []typ.Type) bool { if typ.TypeEquals(ai, bi) { continue } - if subtype.IsSubtype(ai, bi) || TypeExtendsRecord(ai, bi) || typeElidesOptional(ai, bi) { + if subtype.IsSubtype(ai, bi) || value.ExtendsRecord(ai, bi) || value.ElidesOptional(ai, bi) { continue } return false @@ -773,26 +765,6 @@ func ReturnTypesRepairNever(candidate, baseline []typ.Type) bool { return strict } -// TypeExtendsRecord reports whether type a extends type b by adding record fields. -// This treats record field supersets as refinements when b is a record or union of records. -func TypeExtendsRecord(a, b typ.Type) bool { - if a == nil || b == nil { - return false - } - ar, ok := a.(*typ.Record) - if !ok { - return false - } - switch br := b.(type) { - case *typ.Record: - return recordSuperset(ar, br) - case *typ.Union: - return recordSupersetUnion(ar, br) - default: - return false - } -} - func typeRepairsNever(candidate, baseline typ.Type) bool { if candidate == nil || baseline == nil { return false @@ -1077,83 +1049,6 @@ func typeContainsNeverMemo(t typ.Type, seen map[typ.Type]bool) bool { }) } -func typeElidesOptional(a, b typ.Type) bool { - if a == nil || b == nil { - return false - } - nonNil := narrow.RemoveNil(b) - if nonNil == nil || typ.TypeEquals(nonNil, b) { - return false - } - return subtype.IsSubtype(a, nonNil) -} - -func recordSuperset(newRec, oldRec *typ.Record) bool { - if newRec == nil || oldRec == nil { - return false - } - if oldRec.Metatable != nil { - if newRec.Metatable == nil || !subtype.IsSubtype(newRec.Metatable, oldRec.Metatable) { - return false - } - } - if oldRec.HasMapComponent() { - if !newRec.HasMapComponent() { - return false - } - if !subtype.IsSubtype(newRec.MapKey, oldRec.MapKey) || !subtype.IsSubtype(newRec.MapValue, oldRec.MapValue) { - return false - } - } - oldFields := make(map[string]typ.Field, len(oldRec.Fields)) - 
for _, f := range oldRec.Fields { - oldFields[f.Name] = f - } - for _, nf := range newRec.Fields { - if of, ok := oldFields[nf.Name]; ok { - if of.Optional && !nf.Optional { - // ok: stronger requirement - } else if !of.Optional && nf.Optional { - return false - } - if of.Readonly && !nf.Readonly { - return false - } - if of.Type != nil { - if isOpenTopRecordType(nf.Type) && isStructuredTableShape(of.Type) { - // Open-top table placeholders must not dominate structured - // collection/record fields when selecting preferred summaries. - return false - } - if nf.Type == nil || !subtype.IsSubtype(nf.Type, of.Type) { - return false - } - } - delete(oldFields, nf.Name) - } - } - return len(oldFields) == 0 -} - -func recordSupersetUnion(newRec *typ.Record, oldUnion *typ.Union) bool { - if newRec == nil || oldUnion == nil { - return false - } - if len(oldUnion.Members) == 0 { - return false - } - for _, member := range oldUnion.Members { - oldRec, ok := member.(*typ.Record) - if !ok { - return false - } - if !recordSuperset(newRec, oldRec) { - return false - } - } - return true -} - // NormalizeReturnVector replaces nil slots with explicit nil types. 
func NormalizeReturnVector(rets []typ.Type) []typ.Type { if len(rets) == 0 { @@ -1456,7 +1351,7 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { if preferred, ok := preferStructuredRecordParam(existing, candidate); ok { return preferred } - if preferred, ok := preferConcreteOverSoftType(existing, candidate); ok { + if preferred, ok := value.PreferConcreteOverSoft(existing, candidate); ok { return preferred } if typ.IsUnknown(existing) { @@ -1477,10 +1372,10 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { if typ.TypeEquals(existing, candidate) { return existing } - if candidateRefinesFunctionParam(candidate, existing) { + if paramevidence.RefinesFunctionParam(candidate, existing) { return candidate } - if candidateRefinesFunctionParam(existing, candidate) { + if paramevidence.RefinesFunctionParam(existing, candidate) { return existing } if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { @@ -1492,75 +1387,6 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } -func candidateRefinesFunctionParam(candidate, baseline typ.Type) bool { - return typeElidesOptional(candidate, baseline) || - typeIsTruthyRefinement(candidate, baseline) || - typeRefinesTableKeyByTruthiness(candidate, baseline) -} - -func typeRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { - if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { - return false - } - candidateInner, _ := splitNilableParameterEvidence(candidate) - baselineInner, _ := splitNilableParameterEvidence(baseline) - if candidateInner == nil || baselineInner == nil { - return false - } - return nonNilTypeRefinesTableKeyByTruthiness(candidateInner, baselineInner) -} - -func nonNilTypeRefinesTableKeyByTruthiness(candidate, baseline typ.Type) bool { - candidate = unwrap.Alias(candidate) - baseline = unwrap.Alias(baseline) - switch b := 
baseline.(type) { - case *typ.Record: - c, ok := candidate.(*typ.Record) - if !ok { - return false - } - return recordRefinesTableKeyByTruthiness(c, b) - case *typ.Map: - c, ok := candidate.(*typ.Map) - if !ok { - return false - } - return typeIsTruthyRefinement(c.Key, b.Key) && equivalentParamValueType(c.Value, b.Value) - default: - return false - } -} - -func recordRefinesTableKeyByTruthiness(candidate, baseline *typ.Record) bool { - if candidate == nil || baseline == nil || !candidate.HasMapComponent() || !baseline.HasMapComponent() { - return false - } - if candidate.Open != baseline.Open || len(candidate.Fields) != len(baseline.Fields) { - return false - } - if (candidate.Metatable == nil) != (baseline.Metatable == nil) { - return false - } - if candidate.Metatable != nil && !typ.TypeEquals(candidate.Metatable, baseline.Metatable) { - return false - } - for i, field := range candidate.Fields { - other := baseline.Fields[i] - if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { - return false - } - if !equivalentParamValueType(field.Type, other.Type) { - return false - } - } - return typeIsTruthyRefinement(candidate.MapKey, baseline.MapKey) && - equivalentParamValueType(candidate.MapValue, baseline.MapValue) -} - -func equivalentParamValueType(a, b typ.Type) bool { - return typ.TypeEquals(a, b) || (subtype.IsSubtype(a, b) && subtype.IsSubtype(b, a)) -} - func preferStructuredRecordParam(existing, candidate typ.Type) (typ.Type, bool) { existingRec, okExisting := unwrap.Alias(existing).(*typ.Record) candidateRec, okCandidate := unwrap.Alias(candidate).(*typ.Record) @@ -1627,10 +1453,10 @@ func replaceOpenTopWithStructured(current, summary []typ.Type) ([]typ.Type, bool out := append([]typ.Type(nil), current...) 
changed := false for i := range out { - if !isOpenTopRecordType(out[i]) { + if !value.IsOpenTopRecord(out[i]) { continue } - if !isStructuredTableShape(summary[i]) { + if !value.IsStructuredTableShape(summary[i]) { continue } out[i] = summary[i] @@ -1663,24 +1489,3 @@ func WithSummaryOrUnknown(fn *typ.Function, summary []typ.Type) *typ.Function { } return typjoin.WithReturns(fn, normalizeAndPruneReturnVector(summary)) } - -func isOpenTopRecordType(t typ.Type) bool { - rec, ok := unwrap.Alias(t).(*typ.Record) - if !ok || rec == nil { - return false - } - return rec.Open && len(rec.Fields) == 0 && !rec.HasMapComponent() -} - -func isStructuredTableShape(t typ.Type) bool { - switch v := unwrap.Alias(t).(type) { - case *typ.Array: - return true - case *typ.Map: - return true - case *typ.Record: - return v.HasMapComponent() || len(v.Fields) > 0 - default: - return false - } -} diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index 9ba2f9f7..5d61ea70 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -644,21 +644,6 @@ func TestWithSummaryOrUnknown_DefaultsToUnknownWhenMissing(t *testing.T) { } } -func TestTypeExtendsRecord_NilTypes(t *testing.T) { - if TypeExtendsRecord(nil, typ.String) { - t.Error("nil a should not extend") - } - if TypeExtendsRecord(typ.String, nil) { - t.Error("nil b should not extend") - } -} - -func TestTypeExtendsRecord_NotRecord(t *testing.T) { - if TypeExtendsRecord(typ.String, typ.String) { - t.Error("non-record should not extend") - } -} - func TestNormalizeReturnVector_Empty(t *testing.T) { result := NormalizeReturnVector(nil) if result != nil { @@ -853,17 +838,6 @@ func TestRecordSuperset_IncompatibleMapComponent(t *testing.T) { } } -// Regression: recordSuperset should use && not || for map component check. -// This test verifies the fix by checking that the code uses HasMapComponent semantics. 
-func TestTypeExtendsRecord_MapComponentConsistency(t *testing.T) { - // When old has map component, new must have compatible map component - oldRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Build() - newRec := typ.NewRecord().Field("x", typ.Number).Build() - if TypeExtendsRecord(newRec, oldRec) { - t.Error("record without map component should not extend record with map component") - } -} - func TestMergeReturnSummary_PrefersStructuredCollectionOverOpenTopRecordField(t *testing.T) { weak := []typ.Type{ typ.NewRecord(). diff --git a/compiler/check/returns/kernel.go b/compiler/check/returns/kernel.go index 5895571f..48cb4f25 100644 --- a/compiler/check/returns/kernel.go +++ b/compiler/check/returns/kernel.go @@ -2,6 +2,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -13,7 +14,7 @@ func JoinFunctionFact(existing, candidate api.FunctionFact) api.FunctionFact { out := existing if len(candidate.Params) > 0 { - out.Params = joinParameterEvidenceVectors(out.Params, candidate.Params) + out.Params = paramevidence.JoinVectors(out.Params, candidate.Params) } if len(candidate.Summary) > 0 { out.Summary = MergeReturnSummary(out.Summary, candidate.Summary) diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index 104b5e2a..598c8786 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -3,9 +3,8 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/infer/paramevidence" - "github.com/wippyai/go-lua/internal" - "github.com/wippyai/go-lua/types/narrow" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" 
"github.com/wippyai/go-lua/types/typ/unwrap" @@ -64,7 +63,7 @@ func JoinFacts(prev, next api.Facts) api.Facts { func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFact { out := api.FunctionFact{ - Params: joinParameterEvidenceVectors(prev.Params, next.Params), + Params: paramevidence.JoinVectors(prev.Params, next.Params), Summary: widenReturnSummaryForConvergence(prev.Summary, next.Summary), Narrow: widenReturnSummaryForConvergence(prev.Narrow, next.Narrow), Type: widenFunctionFactTypeForConvergence(prev.Type, next.Type), @@ -111,7 +110,7 @@ func hasHigherOrderGrowthRisk(t typ.Type) bool { if t == nil { return false } - return scanType(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { switch n := node.(type) { case *typ.Function: for _, ret := range n.Returns { @@ -132,7 +131,7 @@ func typeContainsFunction(t typ.Type) bool { if t == nil { return false } - return scanType(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { // Interface method signatures are behavioral contracts, not first-class // returned function values. Ignore them for higher-order growth risk. if _, ok := node.(*typ.Interface); ok { @@ -164,7 +163,7 @@ func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { if t == nil || owner == nil { return false } - return scanType(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { // Interface method signatures are behavioral contracts, not concrete // record method bodies. Treating them as self-recursive growth risk // over-applies monotone widening and blocks valid summary refinement. 
@@ -180,7 +179,7 @@ func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { continue } if subtype.IsSubtype(ret, owner) || subtype.IsSubtype(owner, ret) || - TypeExtendsRecord(ret, owner) || TypeExtendsRecord(owner, ret) { + value.ExtendsRecord(ret, owner) || value.ExtendsRecord(owner, ret) { return true, false } } @@ -188,108 +187,6 @@ func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { }) } -func scanType( - t typ.Type, - guard internal.RecursionGuard, - visit func(node typ.Type) (stop bool, descend bool), -) bool { - if t == nil { - return false - } - next, ok := guard.Enter(t) - if !ok { - return false - } - - node := t - for { - ann, ok := node.(*typ.Annotated) - if !ok || ann.Inner == nil || ann.Inner == node { - break - } - node = ann.Inner - } - - if stop, descend := visit(node); stop { - return true - } else if !descend { - return false - } - - switch n := node.(type) { - case *typ.Optional: - return scanType(n.Inner, next, visit) - case *typ.Union: - for _, m := range n.Members { - if scanType(m, next, visit) { - return true - } - } - return false - case *typ.Intersection: - for _, m := range n.Members { - if scanType(m, next, visit) { - return true - } - } - return false - case *typ.Array: - return scanType(n.Element, next, visit) - case *typ.Map: - return scanType(n.Key, next, visit) || scanType(n.Value, next, visit) - case *typ.Tuple: - for _, e := range n.Elements { - if scanType(e, next, visit) { - return true - } - } - return false - case *typ.Function: - for _, p := range n.Params { - if scanType(p.Type, next, visit) { - return true - } - } - for _, r := range n.Returns { - if scanType(r, next, visit) { - return true - } - } - return n.Variadic != nil && scanType(n.Variadic, next, visit) - case *typ.Record: - for _, f := range n.Fields { - if scanType(f.Type, next, visit) { - return true - } - } - if n.Metatable != nil && scanType(n.Metatable, next, visit) { - return true - } - if n.HasMapComponent() { - 
return scanType(n.MapKey, next, visit) || scanType(n.MapValue, next, visit) - } - return false - case *typ.Alias: - return scanType(n.Target, next, visit) - case *typ.Instantiated: - for _, a := range n.TypeArgs { - if scanType(a, next, visit) { - return true - } - } - return false - case *typ.Interface: - for _, m := range n.Methods { - if m.Type != nil && scanType(m.Type, next, visit) { - return true - } - } - return false - default: - return false - } -} - func joinReturnVectorsMonotone(a, b []typ.Type) []typ.Type { if len(a) == 0 { return b @@ -326,10 +223,10 @@ func joinReturnTypeMonotone(a, b typ.Type) typ.Type { return a } // Keep widening monotone: if one side is already an upper bound, keep it. - if subtype.IsSubtype(a, b) || TypeExtendsRecord(a, b) || typeElidesOptional(a, b) { + if subtype.IsSubtype(a, b) || value.ExtendsRecord(a, b) || value.ElidesOptional(a, b) { return b } - if subtype.IsSubtype(b, a) || TypeExtendsRecord(b, a) || typeElidesOptional(b, a) { + if subtype.IsSubtype(b, a) || value.ExtendsRecord(b, a) || value.ElidesOptional(b, a) { return a } return typ.JoinPreferNonSoft(a, b) @@ -368,7 +265,7 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { if prev == nil || merged == nil || typ.TypeEquals(prev, merged) { return false } - if typeElidesOptional(merged, prev) { + if value.ElidesOptional(merged, prev) { return false } if refines, _ := typeRefinesFalsyMapKey(merged, prev); refines { @@ -435,7 +332,7 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { } } - if subtype.IsSubtype(merged, prev) && !subtype.IsSubtype(prev, merged) && !TypeExtendsRecord(merged, prev) { + if subtype.IsSubtype(merged, prev) && !subtype.IsSubtype(prev, merged) && !value.ExtendsRecord(merged, prev) { return true } return false @@ -478,502 +375,6 @@ func unionMembers(t typ.Type) []typ.Type { } } -// WidenParameterEvidence merges two parameter evidence maps using the same -// vector law used by canonical FunctionFacts. 
-func WidenParameterEvidence(prev, next map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { - if prev == nil && next == nil { - return nil - } - if prev == nil { - return filterEmptyParameterEvidence(next) - } - if next == nil { - return filterEmptyParameterEvidence(prev) - } - merged := make(map[cfg.SymbolID][]typ.Type, len(prev)+len(next)) - for _, sym := range cfg.SortedSymbolIDs(prev) { - evidence := normalizeParameterEvidenceVector(prev[sym]) - if hasNonNilEvidence(evidence) { - merged[sym] = evidence - } - } - for _, sym := range cfg.SortedSymbolIDs(next) { - evidence := normalizeParameterEvidenceVector(next[sym]) - if !hasNonNilEvidence(evidence) { - continue - } - if existing := merged[sym]; existing != nil { - merged[sym] = joinParameterEvidenceVectors(existing, evidence) - } else { - merged[sym] = evidence - } - } - return merged -} - -func filterEmptyParameterEvidence(evidence map[cfg.SymbolID][]typ.Type) map[cfg.SymbolID][]typ.Type { - if evidence == nil { - return nil - } - out := make(map[cfg.SymbolID][]typ.Type, len(evidence)) - for _, sym := range cfg.SortedSymbolIDs(evidence) { - v := filterEmptyParameterEvidenceVector(evidence[sym]) - if hasNonNilEvidence(v) { - out[sym] = v - } - } - if len(out) == 0 { - return nil - } - return out -} - -func filterEmptyParameterEvidenceVector(evidence []typ.Type) []typ.Type { - v := normalizeParameterEvidenceVector(evidence) - if !hasNonNilEvidence(v) { - return nil - } - return v -} - -func normalizeParameterEvidenceVector(evidence []typ.Type) []typ.Type { - var out []typ.Type - for i, observed := range evidence { - normalized := paramevidence.NormalizeType(observed) - if out != nil { - out[i] = normalized - continue - } - if !typ.TypeEquals(observed, normalized) { - out = make([]typ.Type, len(evidence)) - copy(out, evidence[:i]) - out[i] = normalized - } - } - if out != nil { - return out - } - return evidence -} - -func hasNonNilEvidence(evidence []typ.Type) bool { - for _, observed := range evidence { 
- if observed != nil { - return true - } - } - return false -} - -// joinParameterEvidenceVectors joins two parameter evidence vectors element-wise. -func joinParameterEvidenceVectors(a, b []typ.Type) []typ.Type { - if len(a) == 0 { - return b - } - if len(b) == 0 { - return a - } - maxLen := len(a) - if len(b) > maxLen { - maxLen = len(b) - } - result := make([]typ.Type, maxLen) - for i := 0; i < maxLen; i++ { - var ai, bi typ.Type - if i < len(a) { - ai = a[i] - } - if i < len(b) { - bi = b[i] - } - result[i] = joinParameterEvidence(ai, bi) - } - return result -} - -func joinParameterEvidence(a, b typ.Type) typ.Type { - a = paramevidence.NormalizeType(a) - b = paramevidence.NormalizeType(b) - if a == nil { - return b - } - if b == nil { - return a - } - if unwrap.IsNilType(a) && !unwrap.IsNilType(b) { - return b - } - if unwrap.IsNilType(b) && !unwrap.IsNilType(a) { - return a - } - if joined, ok := joinNilableParameterEvidence(a, b); ok { - return joined - } - return joinNonNilParameterEvidence(a, b) -} - -func joinNilableParameterEvidence(a, b typ.Type) (typ.Type, bool) { - ai, anil := splitNilableParameterEvidence(a) - bi, bnil := splitNilableParameterEvidence(b) - if !anil && !bnil { - return nil, false - } - if ai == nil && bi == nil { - return typ.Nil, true - } - if ai == nil { - return typ.NewOptional(bi), true - } - if bi == nil { - return typ.NewOptional(ai), true - } - return typ.NewOptional(joinNonNilParameterEvidence(ai, bi)), true -} - -func splitNilableParameterEvidence(t typ.Type) (typ.Type, bool) { - t = unwrap.Alias(t) - switch v := t.(type) { - case nil: - return nil, false - case *typ.Optional: - return v.Inner, true - case *typ.Union: - members := make([]typ.Type, 0, len(v.Members)) - nilable := false - for _, member := range v.Members { - member = unwrap.Alias(member) - if unwrap.IsNilType(member) { - nilable = true - continue - } - members = append(members, member) - } - if !nilable { - return t, false - } - return typ.NewUnion(members...), 
true - default: - if unwrap.IsNilType(t) { - return nil, true - } - return t, false - } -} - -func joinNonNilParameterEvidence(a, b typ.Type) typ.Type { - if upper, ok := selectParameterEvidenceTableUpperBound(a, b); ok { - return upper - } - if preferred, ok := preferConcreteOverSoftType(a, b); ok { - return preferred - } - if typeCanSelfEmbed(a) && typeContainsEquivalent(b, a) && !typ.IsAbsentOrUnknown(a) { - if typeContainsUnion(a) { - return a - } - return typ.JoinPreferNonSoft(a, b) - } - if typeCanSelfEmbed(b) && typeContainsEquivalent(a, b) && !typ.IsAbsentOrUnknown(b) { - if typeContainsUnion(b) { - return b - } - return typ.JoinPreferNonSoft(a, b) - } - if typeIsTruthyRefinement(a, b) { - return a - } - if typeIsTruthyRefinement(b, a) { - return b - } - if joined, ok := typ.JoinCompatibleRecords(a, b); ok { - return joined - } - if joined, ok := joinParameterEvidenceMapRecord(a, b); ok { - return joined - } - if TypeExtendsRecord(a, b) { - return a - } - if TypeExtendsRecord(b, a) { - return b - } - if !typ.IsAbsentOrUnknown(a) && !typ.IsAbsentOrUnknown(b) { - if subtype.IsSubtype(a, b) { - return b - } - if subtype.IsSubtype(b, a) { - return a - } - } - return paramevidence.NormalizeType(typ.JoinPreferNonSoft(a, b)) -} - -func preferConcreteOverSoftType(a, b typ.Type) (typ.Type, bool) { - aSoft := typ.IsSoft(a, typ.SoftPlaceholderPolicy) - bSoft := typ.IsSoft(b, typ.SoftPlaceholderPolicy) - switch { - case aSoft && !bSoft && !unwrap.IsNilType(b): - return b, true - case bSoft && !aSoft && !unwrap.IsNilType(a): - return a, true - } - if preferred, ok := preferConcreteOverNilableSoftType(a, b); ok { - return preferred, true - } - return nil, false -} - -func preferConcreteOverNilableSoftType(a, b typ.Type) (typ.Type, bool) { - if preferred, ok := preferConcreteOverNilableSoftTypeDirected(a, b); ok { - return preferred, true - } - return preferConcreteOverNilableSoftTypeDirected(b, a) -} - -func preferConcreteOverNilableSoftTypeDirected(softMaybeNil, 
concrete typ.Type) (typ.Type, bool) { - inner, nilable := splitNilableParameterEvidence(softMaybeNil) - if !nilable || inner == nil || !typ.IsSoft(inner, typ.SoftPlaceholderPolicy) { - return nil, false - } - if concrete == nil || unwrap.IsNilType(concrete) { - return nil, false - } - concreteInner, concreteNilable := splitNilableParameterEvidence(concrete) - if concreteInner == nil { - return nil, false - } - if typ.IsSoft(concreteInner, typ.SoftPlaceholderPolicy) { - return nil, false - } - if concreteNilable { - return concrete, true - } - return typ.NewOptional(concrete), true -} - -func joinParameterEvidenceMapRecord(a, b typ.Type) (typ.Type, bool) { - if joined, ok := joinParameterEvidenceMapRecordDirected(a, b); ok { - return joined, true - } - return joinParameterEvidenceMapRecordDirected(b, a) -} - -func joinParameterEvidenceMapRecordDirected(mapType, recordType typ.Type) (typ.Type, bool) { - m, ok := unwrap.Alias(mapType).(*typ.Map) - if !ok || m == nil { - return nil, false - } - r, ok := unwrap.Alias(recordType).(*typ.Record) - if !ok || r == nil || !r.HasMapComponent() { - return nil, false - } - - key := joinNonNilParameterEvidence(m.Key, r.MapKey) - value := joinNonNilParameterEvidence(m.Value, r.MapValue) - if len(r.Fields) == 0 && r.Metatable == nil { - return typ.NewMap(key, value), true - } - builder := typ.NewRecord() - if r.Open { - builder.SetOpen(true) - } - if r.Metatable != nil { - builder.Metatable(r.Metatable) - } - builder.MapComponent(key, value) - for _, field := range r.Fields { - fieldType := field.Type - optional := true - if subtype.IsSubtype(typ.LiteralString(field.Name), key) { - fieldType = joinNonNilParameterEvidence(field.Type, value) - } else { - optional = field.Optional - } - switch { - case optional && field.Readonly: - builder.OptReadonlyField(field.Name, fieldType) - case optional: - builder.OptField(field.Name, fieldType) - case field.Readonly: - builder.ReadonlyField(field.Name, fieldType) - default: - 
builder.Field(field.Name, fieldType) - } - } - return builder.Build(), true -} - -func selectParameterEvidenceTableUpperBound(a, b typ.Type) (typ.Type, bool) { - if parameterEvidenceIsOnlyTableTop(a) && typ.IsAny(b) { - return a, true - } - if parameterEvidenceIsOnlyTableTop(b) && typ.IsAny(a) { - return b, true - } - if parameterEvidenceContainsTableTop(a) && parameterEvidenceCoveredByTableTop(b) && subtype.IsSubtype(b, a) { - return a, true - } - if parameterEvidenceContainsTableTop(b) && parameterEvidenceCoveredByTableTop(a) && subtype.IsSubtype(a, b) { - return b, true - } - return nil, false -} - -func parameterEvidenceContainsTableTop(t typ.Type) bool { - if t == nil { - return false - } - if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { - return true - } - switch v := typ.UnwrapAnnotated(t).(type) { - case *typ.Alias: - return parameterEvidenceContainsTableTop(v.UnaliasedTarget()) - case *typ.Optional: - return parameterEvidenceContainsTableTop(v.Inner) - case *typ.Union: - for _, member := range v.Members { - if parameterEvidenceContainsTableTop(member) { - return true - } - } - } - return false -} - -func parameterEvidenceIsOnlyTableTop(t typ.Type) bool { - if t == nil { - return false - } - if unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { - return true - } - switch v := typ.UnwrapAnnotated(t).(type) { - case *typ.Alias: - return parameterEvidenceIsOnlyTableTop(v.UnaliasedTarget()) - case *typ.Optional: - return parameterEvidenceIsOnlyTableTop(v.Inner) - case *typ.Union: - if len(v.Members) == 0 { - return false - } - hasTableTop := false - for _, member := range v.Members { - if unwrap.IsNilType(member) { - continue - } - if !parameterEvidenceIsOnlyTableTop(member) { - return false - } - hasTableTop = true - } - return hasTableTop - default: - return false - } -} - -func parameterEvidenceCoveredByTableTop(t typ.Type) bool { - if t == nil { - return false - } - if typ.IsAny(t) { - return true - } - if unwrap.IsNilType(t) || 
unwrap.IsBuiltinTableTop(typ.UnwrapAnnotated(t)) { - return true - } - switch v := typ.UnwrapAnnotated(t).(type) { - case *typ.Alias: - return parameterEvidenceCoveredByTableTop(v.UnaliasedTarget()) - case *typ.Optional: - return parameterEvidenceCoveredByTableTop(v.Inner) - case *typ.Recursive: - return v.Body != nil && v.Body != v && parameterEvidenceCoveredByTableTop(v.Body) - case *typ.Union: - if len(v.Members) == 0 { - return false - } - for _, member := range v.Members { - if !parameterEvidenceCoveredByTableTop(member) { - return false - } - } - return true - case *typ.Record, *typ.Map, *typ.Array, *typ.Tuple, *typ.Interface, *typ.Intersection: - return true - default: - return false - } -} - -func typeIsTruthyRefinement(candidate, baseline typ.Type) bool { - if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { - return false - } - refined := narrow.ToTruthy(baseline) - if refined == nil || refined.Kind().IsNever() || typ.TypeEquals(refined, baseline) { - return false - } - return typ.TypeEquals(candidate, refined) || subtype.IsSubtype(candidate, refined) -} - -func typeCanSelfEmbed(t typ.Type) bool { - if t == nil { - return false - } - switch v := t.(type) { - case *typ.Annotated: - return typeCanSelfEmbed(v.Inner) - case *typ.Alias: - return typeCanSelfEmbed(v.Target) - case *typ.Optional: - return typeCanSelfEmbed(v.Inner) - case *typ.Union: - for _, member := range v.Members { - if typeCanSelfEmbed(member) { - return true - } - } - return false - case *typ.Intersection: - for _, member := range v.Members { - if typeCanSelfEmbed(member) { - return true - } - } - return false - case *typ.Array, *typ.Map, *typ.Tuple, *typ.Record, *typ.Function: - return true - default: - return false - } -} - -func typeContainsEquivalent(haystack, needle typ.Type) bool { - if haystack == nil || needle == nil { - return false - } - return scanType(haystack, typ.NewGuard(), func(node typ.Type) (bool, bool) { - if typ.TypeEquals(node, needle) { - 
return true, false - } - return false, true - }) -} - -func typeContainsUnion(t typ.Type) bool { - if t == nil { - return false - } - return scanType(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - if _, ok := node.(*typ.Union); ok { - return true, false - } - return false, true - }) -} - func canonicalInterprocValueType(t typ.Type) typ.Type { if t == nil { return nil @@ -1050,10 +451,10 @@ func widenValueTypeForConvergence(existing, candidate typ.Type) typ.Type { if typ.IsAny(candidate) || typ.IsUnknown(candidate) { return candidate } - if typeElidesOptional(candidate, existing) { + if value.ElidesOptional(candidate, existing) { return candidate } - if TypeExtendsRecord(candidate, existing) && !typeContainsNestedStructuralShape(candidate, existing) { + if value.ExtendsRecord(candidate, existing) && !typeContainsNestedStructuralShape(candidate, existing) { return candidate } if refines, _ := typeRefinesFalsyMapKey(candidate, existing); refines { @@ -1158,13 +559,13 @@ func widenFunctionParamFactTypeForConvergence(existing, candidate typ.Type) typ. 
if typ.IsAny(candidate) || typ.IsUnknown(candidate) { return candidate } - if preferred, ok := preferConcreteOverSoftType(existing, candidate); ok { + if preferred, ok := value.PreferConcreteOverSoft(existing, candidate); ok { return preferred } - if candidateRefinesFunctionParam(candidate, existing) { + if paramevidence.RefinesFunctionParam(candidate, existing) { return candidate } - if candidateRefinesFunctionParam(existing, candidate) { + if paramevidence.RefinesFunctionParam(existing, candidate) { return existing } if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index 04a1901a..ee68599d 100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/returns/widen_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/ast" - "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -177,286 +176,6 @@ func TestMergeReturnSummary_KeepsNonRecursiveContainerRefinement(t *testing.T) { } } -func TestWidenParameterEvidence_StopsSelfEmbeddingRecordGrowth(t *testing.T) { - prevHint := typ.NewUnion( - typ.Number, - typ.NewRecord(). - Field("limit", typ.Any). - SetOpen(true). - Build(), - ) - nextHint := typ.NewRecord(). - Field("limit", prevHint). - SetOpen(true). - Build() - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{prevHint}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, prevHint) { - t.Fatalf("expected stable previous evidence, got %v", got) - } -} - -func TestWidenParameterEvidence_StopsSelfEmbeddingContainerGrowth(t *testing.T) { - prevHint := typ.NewUnion( - typ.Number, - typ.NewRecord(). - Field("limit", typ.Any). - SetOpen(true). 
- Build(), - ) - - tests := []struct { - name string - next typ.Type - }{ - { - name: "record", - next: typ.NewRecord(). - Field("value", prevHint). - SetOpen(true). - Build(), - }, - { - name: "array", - next: typ.NewArray(prevHint), - }, - { - name: "map", - next: typ.NewMap(typ.String, prevHint), - }, - { - name: "tuple", - next: typ.NewTuple(prevHint), - }, - { - name: "function", - next: typ.Func(). - Param("value", prevHint). - Returns(prevHint). - Build(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{prevHint}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{tt.next}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, prevHint) { - t.Fatalf("expected stable previous evidence, got %v", got) - } - }) - } -} - -func TestWidenParameterEvidence_KeepsFirstRecordWrapperObservation(t *testing.T) { - nextHint := typ.NewRecord(). - Field("limit", typ.Number). - SetOpen(true). - Build() - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Number}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, - ) - - got := merged[1][0] - if typ.TypeEquals(got, typ.Number) { - t.Fatalf("expected wrapper observation to be preserved, got %v", got) - } - if !typ.TypeEquals(got, typ.NewUnion(typ.Number, nextHint)) { - t.Fatalf("expected number | wrapper evidence, got %v", got) - } -} - -func TestWidenParameterEvidence_JoinsNestedRecordObservations(t *testing.T) { - nested := typ.NewRecord(). - Field("routes", typ.NewRecord().Field("users", typ.Boolean).SetOpen(true).Build()). - SetOpen(true). - Build() - outer := typ.NewRecord(). - Field("api", nested). - SetOpen(true). 
- Build() - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{outer}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{nested}}, - ) - - got := merged[1][0] - want := typ.NewUnion(outer, nested) - if !typ.TypeEquals(got, want) { - t.Fatalf("expected nested record observations to be joined as %v, got %v", want, got) - } -} - -func TestWidenParameterEvidence_ReplacesStaleBroadHintWithCurrentRefinement(t *testing.T) { - stale := typ.NewUnion(typ.String, typ.False) - current := typ.String - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{stale}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{current}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, current) { - t.Fatalf("expected current refined evidence %v to replace stale broad evidence, got %v", current, got) - } -} - -func TestWidenParameterEvidence_ReplacesSoftContainerPlaceholderWithConcreteElementShape(t *testing.T) { - entry := typ.NewRecord().Field("id", typ.String).Build() - stale := typ.NewUnion( - typ.NewArray(typ.Any), - typ.NewRecord().SetOpen(true).Build(), - ) - current := typ.NewArray(entry) - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{stale}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{current}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, current) { - t.Fatalf("expected concrete array evidence %v to replace soft stale evidence, got %v", current, got) - } -} - -func TestWidenParameterEvidence_PreservesStructuredHintOverNilOnlyObservation(t *testing.T) { - context := typ.NewMap(typ.String, typ.Any) - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, context}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, typ.Nil}}, - ) - - got := merged[1][2] - if !typ.TypeEquals(got, context) { - t.Fatalf("expected nil-only observation to preserve structured evidence %v, got %v", context, got) - } - - again := 
WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, typ.Nil}}) - if !parameterEvidenceEqual(merged, again) { - t.Fatalf("expected idempotent nil-only observation widening, got %v then %v", merged, again) - } -} - -func TestWidenParameterEvidence_PreservesMapHintOverOptionalOpenRecordObservation(t *testing.T) { - context := typ.NewMap(typ.String, typ.Any) - optionalContextRecord := typ.NewOptional(typ.NewRecord(). - MapComponent(typ.String, typ.Any). - SetOpen(true). - Build()) - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, context}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}, - ) - - got := merged[1][2] - if got == nil || typ.TypeEquals(got, typ.Nil) { - t.Fatalf("expected optional structured observation to preserve context evidence, got %v", got) - } - if !typ.TypeEquals(got, typ.NewOptional(context)) { - t.Fatalf("expected pure map observation to stay canonical, got %v", got) - } - - again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.String, typ.Any, optionalContextRecord}}) - if !parameterEvidenceEqual(merged, again) { - t.Fatalf("expected idempotent optional structured observation widening, got %v then %v", merged, again) - } -} - -func TestWidenParameterEvidence_CollapsesPureOpenRecordMapToCanonicalMap(t *testing.T) { - entry := typ.NewRecord().Field("id", typ.String).Build() - canonical := typ.NewMap(typ.String, typ.NewArray(entry)) - staleRecordView := typ.NewRecord(). - MapComponent(typ.NewUnion(typ.String, typ.False), typ.NewArray(entry)). - SetOpen(true). 
- Build() - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{staleRecordView}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{canonical}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, canonical) { - t.Fatalf("expected pure keyed table evidence to canonicalize to %v, got %v", canonical, got) - } -} - -func TestWidenParameterEvidence_TableTopUpperBoundAbsorbsRecordUnion(t *testing.T) { - tableTop := typ.NewOptional(typ.NewInterface("table", nil)) - strategySpec := typ.NewRecord(). - Field("kind", typ.LiteralString("strategy")). - Field("tools", typ.NewTuple(typ.String, typ.String, typ.String)). - Build() - contextSpec := typ.NewRecord(). - Field("kind", typ.LiteralString("context")). - Field("scope", typ.String). - Build() - nextHint := typ.NewUnion(strategySpec, contextSpec) - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{tableTop}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, tableTop) { - t.Fatalf("expected table top upper bound %v, got %v", tableTop, got) - } - - again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{nextHint}}) - if !parameterEvidenceEqual(merged, again) { - t.Fatalf("expected idempotent table-top widening, got %v then %v", merged, again) - } -} - -func TestWidenParameterEvidence_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { - tableTop := typ.NewOptional(typ.NewInterface("table", nil)) - - merged := WidenParameterEvidence( - map[cfg.SymbolID][]typ.Type{1: []typ.Type{tableTop}}, - map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Any}}, - ) - - got := merged[1][0] - if !typ.TypeEquals(got, tableTop) { - t.Fatalf("expected dynamic observation to preserve table top upper bound %v, got %v", tableTop, got) - } - - again := WidenParameterEvidence(merged, map[cfg.SymbolID][]typ.Type{1: []typ.Type{typ.Any}}) - if !parameterEvidenceEqual(merged, again) { - t.Fatalf("expected idempotent 
table-top/any widening, got %v then %v", merged, again) - } -} - -func parameterEvidenceEqual(a, b map[cfg.SymbolID][]typ.Type) bool { - if len(a) != len(b) { - return false - } - for _, sym := range cfg.SortedSymbolIDs(a) { - right, ok := b[sym] - if !ok || !ReturnTypesEqual(a[sym], right) { - return false - } - } - return true -} - func TestWidenCapturedFieldAssigns_NormalizesOptionalFunctionValues(t *testing.T) { fn := typ.Func().Param("fn", typ.Unknown).Build() merged := WidenCapturedFieldAssigns(nil, api.CapturedFieldAssigns{ From 777579e6a6c82410c5f2aa15a110811cfe722a1e Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 02:15:09 -0400 Subject: [PATCH 19/71] Move value shape laws into domain --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 37 ++ compiler/check/domain/value/shape.go | 448 ++++++++++++++++++++++ compiler/check/domain/value/shape_test.go | 28 ++ compiler/check/returns/join.go | 440 +-------------------- compiler/check/returns/widen.go | 31 +- 5 files changed, 528 insertions(+), 456 deletions(-) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index f27c8273..2191215d 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -124,6 +124,43 @@ Design result: - helper names that encode parameter-specific lattice laws are no longer local return-package predicates. +## 2026-05-19 Value Shape Domain Checkpoint + +The follow-up rectification removed another cluster of domain laws from +`returns` and moved it into `compiler/check/domain/value`. + +Moved value-shape laws: + +- soft container refinement; +- stale falsy map-key refinement; +- nested nil-only regression detection; +- recursive structural-growth detection; +- structural-shape unwrapping and shallow shape equality; +- union member extraction after structural unwrapping. + +`returns` now keeps return-vector orchestration, but it asks `domain/value` for +value-shape facts. 
This preserves the current behavior while making the mental +model cleaner: + +```text +returns = return-vector policy and function-summary alignment +domain/value = reusable structural value relations +domain/paramevidence = parameter evidence lattice and parameter-slot refinement +``` + +This is a direct ownership move, not a bridge. The old local helpers were +deleted from `returns`. + +Verification for this slice so far: + +- `go test ./compiler/check/domain/value ./compiler/check/returns` passes. +- `go test ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. +- Standard `../scripts/verify-suite.sh` passes the go-lua checker tests and + Wippy binary build, then exits non-zero on the existing external lint targets: + session 8 errors, agent/src 8 errors, docker-demo 21 errors and 2 warnings. + ## Goal The checker should read as one abstract interpreter over a product domain. diff --git a/compiler/check/domain/value/shape.go b/compiler/check/domain/value/shape.go index 8a2beae6..4807feea 100644 --- a/compiler/check/domain/value/shape.go +++ b/compiler/check/domain/value/shape.go @@ -389,3 +389,451 @@ func IsStructuredTableShape(t typ.Type) bool { return false } } + +// RefinesSoftContainer reports whether candidate preserves the same table shape +// while replacing a soft placeholder element/value with concrete evidence. 
+func RefinesSoftContainer(candidate, baseline typ.Type) (bool, bool) { + candidate = UnwrapStructuralShape(candidate) + baseline = UnwrapStructuralShape(baseline) + if candidate == nil || baseline == nil { + return candidate == baseline, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + switch b := baseline.(type) { + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return refinesSoftContainerSlot(c.Element, b.Element) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok || !Equivalent(c.Key, b.Key) { + return false, false + } + return refinesSoftContainerSlot(c.Value, b.Value) + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok || !sameRecordFrame(c, b) { + return false, false + } + if !c.HasMapComponent() && !b.HasMapComponent() { + return true, false + } + if !c.HasMapComponent() || !b.HasMapComponent() || !Equivalent(c.MapKey, b.MapKey) { + return false, false + } + return refinesSoftContainerSlot(c.MapValue, b.MapValue) + default: + return false, false + } +} + +func refinesSoftContainerSlot(candidate, baseline typ.Type) (bool, bool) { + if typ.TypeEquals(candidate, baseline) { + return true, false + } + if (typ.IsAny(baseline) || typ.IsUnknown(baseline)) && CanSelfEmbed(candidate) { + return false, false + } + preferred, ok := PreferConcreteOverSoft(baseline, candidate) + return ok && typ.TypeEquals(preferred, candidate), ok +} + +func sameRecordFrame(a, b *typ.Record) bool { + if a == nil || b == nil || a.Open != b.Open || len(a.Fields) != len(b.Fields) { + return false + } + if (a.Metatable == nil) != (b.Metatable == nil) { + return false + } + if a.Metatable != nil && !typ.TypeEquals(a.Metatable, b.Metatable) { + return false + } + for i, field := range a.Fields { + other := b.Fields[i] + if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { + return false + } + if !typ.TypeEquals(field.Type, other.Type) { + return 
false + } + } + return true +} + +// RefinesFalsyMapKey reports whether candidate is the same table-derived shape +// as baseline after removing stale falsy key members from baseline. +func RefinesFalsyMapKey(candidate, baseline typ.Type) (bool, bool) { + candidate = UnwrapStructuralShape(candidate) + baseline = UnwrapStructuralShape(baseline) + if candidate == nil || baseline == nil { + return candidate == baseline, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + switch b := baseline.(type) { + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return truthyElementRefinement(c.Element, b.Element) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false, false + } + return mapKeyTruthyRefinement(c.Key, c.Value, b.Key, b.Value) + case *typ.Record: + if c, ok := candidate.(*typ.Map); ok { + if len(b.Fields) != 0 || b.Metatable != nil || !b.HasMapComponent() { + return false, false + } + return mapKeyTruthyRefinement(c.Key, c.Value, b.MapKey, b.MapValue) + } + c, ok := candidate.(*typ.Record) + if !ok || !c.HasMapComponent() || !b.HasMapComponent() { + return false, false + } + if c.Open && !b.Open { + return false, false + } + if len(c.Fields) != len(b.Fields) { + return false, false + } + for _, bf := range b.Fields { + cf := c.GetField(bf.Name) + if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly || !typ.TypeEquals(cf.Type, bf.Type) { + return false, false + } + } + if (c.Metatable == nil) != (b.Metatable == nil) || (c.Metatable != nil && !typ.TypeEquals(c.Metatable, b.Metatable)) { + return false, false + } + return mapKeyTruthyRefinement(c.MapKey, c.MapValue, b.MapKey, b.MapValue) + default: + return false, false + } +} + +func mapKeyTruthyRefinement(candidateKey, candidateValue, baselineKey, baselineValue typ.Type) (bool, bool) { + if !typ.TypeEquals(candidateValue, baselineValue) { + return false, false + } + if IsTruthyRefinement(candidateKey, 
baselineKey) { + return true, true + } + return false, false +} + +func truthyElementRefinement(candidate, baseline typ.Type) (bool, bool) { + if typ.TypeEquals(candidate, baseline) { + return true, false + } + if IsTruthyRefinement(candidate, baseline) { + return true, true + } + return false, false +} + +// NestedNilOnlyRegression reports whether candidate's apparent refinement only +// adds nested nil facts over a more useful baseline shape. +func NestedNilOnlyRegression(candidate, baseline typ.Type) bool { + candidate = UnwrapStructuralShape(candidate) + baseline = UnwrapStructuralShape(baseline) + if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { + return false + } + if unwrap.IsNilType(candidate) { + return typ.IsAny(baseline) || typ.IsUnknown(baseline) || unwrap.IsOptionalLike(baseline) + } + + switch c := candidate.(type) { + case *typ.Record: + b, ok := baseline.(*typ.Record) + if !ok { + return false + } + for _, cf := range c.Fields { + bf := b.GetField(cf.Name) + if bf == nil { + continue + } + if unwrap.IsNilType(cf.Type) && (bf.Optional || typ.IsAny(bf.Type) || typ.IsUnknown(bf.Type) || unwrap.IsOptionalLike(bf.Type)) { + return true + } + if NestedNilOnlyRegression(cf.Type, bf.Type) { + return true + } + } + if c.HasMapComponent() && b.HasMapComponent() { + return NestedNilOnlyRegression(c.MapValue, b.MapValue) + } + case *typ.Array: + if b, ok := baseline.(*typ.Array); ok { + return NestedNilOnlyRegression(c.Element, b.Element) + } + case *typ.Map: + if b, ok := baseline.(*typ.Map); ok { + return NestedNilOnlyRegression(c.Value, b.Value) + } + case *typ.Tuple: + b, ok := baseline.(*typ.Tuple) + if !ok || len(c.Elements) != len(b.Elements) { + return false + } + for i := range c.Elements { + if NestedNilOnlyRegression(c.Elements[i], b.Elements[i]) { + return true + } + } + case *typ.Function: + b, ok := baseline.(*typ.Function) + if !ok || len(c.Returns) != len(b.Returns) { + return false + } + for i := range c.Returns { 
+ if NestedNilOnlyRegression(c.Returns[i], b.Returns[i]) { + return true + } + } + } + return false +} + +// ContainsNestedStructuralShape reports whether haystack embeds the same +// shallow structural shape as needle below the root. +func ContainsNestedStructuralShape(haystack, needle typ.Type) bool { + return containsNestedStructuralShapeDepth(haystack, needle, make(map[typ.Type]bool), false) +} + +func containsNestedStructuralShapeDepth(haystack, needle typ.Type, seen map[typ.Type]bool, belowContainer bool) bool { + if haystack == nil || needle == nil { + return false + } + if seen[haystack] { + return false + } + seen[haystack] = true + + node := UnwrapStructuralShape(haystack) + if node == nil { + return false + } + if belowContainer && ShallowStructuralShapeEquals(node, needle) { + return true + } + + descend := func(child typ.Type, childBelowContainer bool) bool { + return containsNestedStructuralShapeDepth(child, needle, seen, childBelowContainer) + } + + switch n := node.(type) { + case *typ.Optional: + return descend(n.Inner, belowContainer) + case *typ.Union: + for _, member := range n.Members { + if descend(member, belowContainer) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range n.Members { + if descend(member, belowContainer) { + return true + } + } + return false + case *typ.Array: + return descend(n.Element, true) + case *typ.Map: + return descend(n.Key, true) || descend(n.Value, true) + case *typ.Tuple: + for _, elem := range n.Elements { + if descend(elem, true) { + return true + } + } + return false + case *typ.Record: + for _, field := range n.Fields { + if descend(field.Type, true) { + return true + } + } + if n.Metatable != nil && descend(n.Metatable, true) { + return true + } + if n.HasMapComponent() { + return descend(n.MapKey, true) || descend(n.MapValue, true) + } + return false + case *typ.Function: + for _, param := range n.Params { + if descend(param.Type, true) { + return true + } + } + if 
n.Variadic != nil && descend(n.Variadic, true) { + return true + } + for _, ret := range n.Returns { + if descend(ret, true) { + return true + } + } + return false + case *typ.Instantiated: + for _, arg := range n.TypeArgs { + if descend(arg, belowContainer) { + return true + } + } + return false + case *typ.Interface: + for _, method := range n.Methods { + if method.Type != nil && descend(method.Type, true) { + return true + } + } + return false + default: + return false + } +} + +// ShallowStructuralShapeEquals reports whether a and b have the same root +// structural container shape. +func ShallowStructuralShapeEquals(a, b typ.Type) bool { + a = UnwrapStructuralShape(a) + b = UnwrapStructuralShape(b) + if a == nil || b == nil { + return a == b + } + + switch av := a.(type) { + case *typ.Union: + for _, member := range av.Members { + if ShallowStructuralShapeEquals(member, b) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range av.Members { + if ShallowStructuralShapeEquals(member, b) { + return true + } + } + return false + } + switch bv := b.(type) { + case *typ.Union: + for _, member := range bv.Members { + if ShallowStructuralShapeEquals(a, member) { + return true + } + } + return false + case *typ.Intersection: + for _, member := range bv.Members { + if ShallowStructuralShapeEquals(a, member) { + return true + } + } + return false + } + + switch av := a.(type) { + case *typ.Array: + _, ok := b.(*typ.Array) + return ok + case *typ.Map: + bv, ok := b.(*typ.Map) + return ok && shallowMapKeyShapeEquals(av.Key, bv.Key) + case *typ.Tuple: + bv, ok := b.(*typ.Tuple) + return ok && len(av.Elements) == len(bv.Elements) + case *typ.Record: + bv, ok := b.(*typ.Record) + return ok && shallowRecordShapeEquals(av, bv) + default: + return typ.TypeEquals(a, b) + } +} + +// UnwrapStructuralShape strips transparent wrappers for structural comparison. 
+func UnwrapStructuralShape(t typ.Type) typ.Type { + for t != nil { + switch v := t.(type) { + case *typ.Annotated: + if v.Inner == nil || v.Inner == t { + return t + } + t = v.Inner + case *typ.Alias: + if v.Target == nil || v.Target == t { + return t + } + t = v.Target + case *typ.Optional: + if v.Inner == nil || v.Inner == t { + return t + } + t = v.Inner + default: + return t + } + } + return nil +} + +func shallowMapKeyShapeEquals(a, b typ.Type) bool { + if a == nil || b == nil { + return a == b + } + if typ.TypeEquals(a, b) { + return true + } + return typ.IsAny(a) || typ.IsAny(b) || typ.IsUnknown(a) || typ.IsUnknown(b) +} + +func shallowRecordShapeEquals(a, b *typ.Record) bool { + if a == nil || b == nil { + return a == b + } + if a.HasMapComponent() != b.HasMapComponent() { + return false + } + if a.HasMapComponent() && !shallowMapKeyShapeEquals(a.MapKey, b.MapKey) { + return false + } + if len(a.Fields) != len(b.Fields) { + return false + } + for _, field := range a.Fields { + if b.GetField(field.Name) == nil { + return false + } + } + return true +} + +// UnionMembers returns explicit union members after structural unwrapping. +func UnionMembers(t typ.Type) []typ.Type { + switch v := UnwrapStructuralShape(t).(type) { + case *typ.Union: + return v.Members + case *typ.Optional: + return append([]typ.Type{typ.Nil}, UnionMembers(v.Inner)...) 
+ default: + return nil + } +} diff --git a/compiler/check/domain/value/shape_test.go b/compiler/check/domain/value/shape_test.go index 5377885e..a6e55ade 100644 --- a/compiler/check/domain/value/shape_test.go +++ b/compiler/check/domain/value/shape_test.go @@ -28,3 +28,31 @@ func TestExtendsRecord_MapComponentConsistency(t *testing.T) { t.Error("record without map component should not extend record with map component") } } + +func TestRefinesFalsyMapKey(t *testing.T) { + candidate := typ.NewMap(typ.String, typ.Number) + baseline := typ.NewMap(typ.NewUnion(typ.String, typ.False), typ.Number) + + ok, changed := RefinesFalsyMapKey(candidate, baseline) + if !ok || !changed { + t.Fatalf("expected truthy key refinement, got ok=%v changed=%v", ok, changed) + } +} + +func TestNestedNilOnlyRegression(t *testing.T) { + candidate := typ.NewRecord().Field("value", typ.Nil).Build() + baseline := typ.NewRecord().OptField("value", typ.String).Build() + + if !NestedNilOnlyRegression(candidate, baseline) { + t.Fatalf("expected nested nil-only regression") + } +} + +func TestContainsNestedStructuralShape(t *testing.T) { + shape := typ.NewMap(typ.String, typ.Any) + growing := typ.NewMap(typ.String, typ.NewMap(typ.String, typ.Nil)) + + if !ContainsNestedStructuralShape(growing, shape) { + t.Fatalf("expected nested structural shape") + } +} diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 6b30b909..3ac0664f 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -4,7 +4,6 @@ import ( "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/kind" - "github.com/wippyai/go-lua/types/narrow" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" typjoin "github.com/wippyai/go-lua/types/typ/join" @@ -178,7 +177,7 @@ func ReturnTypesRefineSoftContainers(candidate, baseline []typ.Type) bool { } strict := false 
for i := range candidate { - refines, changed := typeRefinesSoftContainer(candidate[i], baseline[i]) + refines, changed := value.RefinesSoftContainer(candidate[i], baseline[i]) if !refines { return false } @@ -189,79 +188,6 @@ func ReturnTypesRefineSoftContainers(candidate, baseline []typ.Type) bool { return strict } -func typeRefinesSoftContainer(candidate, baseline typ.Type) (bool, bool) { - candidate = unwrapStructuralShape(candidate) - baseline = unwrapStructuralShape(baseline) - if candidate == nil || baseline == nil { - return candidate == baseline, false - } - if typ.TypeEquals(candidate, baseline) { - return true, false - } - - switch b := baseline.(type) { - case *typ.Array: - c, ok := candidate.(*typ.Array) - if !ok { - return false, false - } - return typeRefinesSoftContainerSlot(c.Element, b.Element) - case *typ.Map: - c, ok := candidate.(*typ.Map) - if !ok || !value.Equivalent(c.Key, b.Key) { - return false, false - } - return typeRefinesSoftContainerSlot(c.Value, b.Value) - case *typ.Record: - c, ok := candidate.(*typ.Record) - if !ok || !sameRecordFrame(c, b) { - return false, false - } - if !c.HasMapComponent() && !b.HasMapComponent() { - return true, false - } - if !c.HasMapComponent() || !b.HasMapComponent() || !value.Equivalent(c.MapKey, b.MapKey) { - return false, false - } - return typeRefinesSoftContainerSlot(c.MapValue, b.MapValue) - default: - return false, false - } -} - -func typeRefinesSoftContainerSlot(candidate, baseline typ.Type) (bool, bool) { - if typ.TypeEquals(candidate, baseline) { - return true, false - } - if (typ.IsAny(baseline) || typ.IsUnknown(baseline)) && value.CanSelfEmbed(candidate) { - return false, false - } - preferred, ok := value.PreferConcreteOverSoft(baseline, candidate) - return ok && typ.TypeEquals(preferred, candidate), ok -} - -func sameRecordFrame(a, b *typ.Record) bool { - if a == nil || b == nil || a.Open != b.Open || len(a.Fields) != len(b.Fields) { - return false - } - if (a.Metatable == nil) != 
(b.Metatable == nil) { - return false - } - if a.Metatable != nil && !typ.TypeEquals(a.Metatable, b.Metatable) { - return false - } - for i, field := range a.Fields { - other := b.Fields[i] - if field.Name != other.Name || field.Optional != other.Optional || field.Readonly != other.Readonly { - return false - } - if !typ.TypeEquals(field.Type, other.Type) { - return false - } - } - return true -} - // ReturnTypesRefineFalsyMapKeys reports whether candidate is the same // table-derived shape as baseline after removing stale falsy members from // baseline. This handles fixed-point rounds where an early branch-insensitive @@ -273,7 +199,7 @@ func ReturnTypesRefineFalsyMapKeys(candidate, baseline []typ.Type) bool { } strict := false for i := range candidate { - refines, changed := typeRefinesFalsyMapKey(candidate[i], baseline[i]) + refines, changed := value.RefinesFalsyMapKey(candidate[i], baseline[i]) if !refines { return false } @@ -284,85 +210,6 @@ func ReturnTypesRefineFalsyMapKeys(candidate, baseline []typ.Type) bool { return strict } -func typeRefinesFalsyMapKey(candidate, baseline typ.Type) (bool, bool) { - candidate = unwrapStructuralShape(candidate) - baseline = unwrapStructuralShape(baseline) - if candidate == nil || baseline == nil { - return candidate == baseline, false - } - if typ.TypeEquals(candidate, baseline) { - return true, false - } - - switch b := baseline.(type) { - case *typ.Array: - c, ok := candidate.(*typ.Array) - if !ok { - return false, false - } - return truthyElementRefinement(c.Element, b.Element) - case *typ.Map: - c, ok := candidate.(*typ.Map) - if !ok { - return false, false - } - return mapKeyTruthyRefinement(c.Key, c.Value, b.Key, b.Value) - case *typ.Record: - if c, ok := candidate.(*typ.Map); ok { - if len(b.Fields) != 0 || b.Metatable != nil || !b.HasMapComponent() { - return false, false - } - return mapKeyTruthyRefinement(c.Key, c.Value, b.MapKey, b.MapValue) - } - c, ok := candidate.(*typ.Record) - if !ok || 
!c.HasMapComponent() || !b.HasMapComponent() { - return false, false - } - if c.Open && !b.Open { - return false, false - } - if len(c.Fields) != len(b.Fields) { - return false, false - } - for _, bf := range b.Fields { - cf := c.GetField(bf.Name) - if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly || !typ.TypeEquals(cf.Type, bf.Type) { - return false, false - } - } - if (c.Metatable == nil) != (b.Metatable == nil) || (c.Metatable != nil && !typ.TypeEquals(c.Metatable, b.Metatable)) { - return false, false - } - return mapKeyTruthyRefinement(c.MapKey, c.MapValue, b.MapKey, b.MapValue) - default: - return false, false - } -} - -func mapKeyTruthyRefinement(candidateKey, candidateValue, baselineKey, baselineValue typ.Type) (bool, bool) { - if !typ.TypeEquals(candidateValue, baselineValue) { - return false, false - } - refinedKey := narrow.ToTruthy(baselineKey) - if refinedKey == nil || refinedKey.Kind().IsNever() || typ.TypeEquals(refinedKey, baselineKey) { - return false, false - } - if typ.TypeEquals(candidateKey, refinedKey) || subtype.IsSubtype(candidateKey, refinedKey) { - return true, true - } - return false, false -} - -func truthyElementRefinement(candidate, baseline typ.Type) (bool, bool) { - if typ.TypeEquals(candidate, baseline) { - return true, false - } - if value.IsTruthyRefinement(candidate, baseline) { - return true, true - } - return false, false -} - // ReturnTypesNestedNilOnlyRegression reports whether candidate's apparent // refinement only adds nested nil facts over a more useful baseline shape. 
A // required `nil` field or `unknown -> nil` field does not help callers, but it @@ -372,76 +219,13 @@ func ReturnTypesNestedNilOnlyRegression(candidate, baseline []typ.Type) bool { return false } for i := range candidate { - if typeNestedNilOnlyRegression(candidate[i], baseline[i]) { + if value.NestedNilOnlyRegression(candidate[i], baseline[i]) { return true } } return false } -func typeNestedNilOnlyRegression(candidate, baseline typ.Type) bool { - candidate = unwrapStructuralShape(candidate) - baseline = unwrapStructuralShape(baseline) - if candidate == nil || baseline == nil || typ.TypeEquals(candidate, baseline) { - return false - } - if unwrap.IsNilType(candidate) { - return typ.IsAny(baseline) || typ.IsUnknown(baseline) || unwrap.IsOptionalLike(baseline) - } - - switch c := candidate.(type) { - case *typ.Record: - b, ok := baseline.(*typ.Record) - if !ok { - return false - } - for _, cf := range c.Fields { - bf := b.GetField(cf.Name) - if bf == nil { - continue - } - if unwrap.IsNilType(cf.Type) && (bf.Optional || typ.IsAny(bf.Type) || typ.IsUnknown(bf.Type) || unwrap.IsOptionalLike(bf.Type)) { - return true - } - if typeNestedNilOnlyRegression(cf.Type, bf.Type) { - return true - } - } - if c.HasMapComponent() && b.HasMapComponent() { - return typeNestedNilOnlyRegression(c.MapValue, b.MapValue) - } - case *typ.Array: - if b, ok := baseline.(*typ.Array); ok { - return typeNestedNilOnlyRegression(c.Element, b.Element) - } - case *typ.Map: - if b, ok := baseline.(*typ.Map); ok { - return typeNestedNilOnlyRegression(c.Value, b.Value) - } - case *typ.Tuple: - b, ok := baseline.(*typ.Tuple) - if !ok || len(c.Elements) != len(b.Elements) { - return false - } - for i := range c.Elements { - if typeNestedNilOnlyRegression(c.Elements[i], b.Elements[i]) { - return true - } - } - case *typ.Function: - b, ok := baseline.(*typ.Function) - if !ok || len(c.Returns) != len(b.Returns) { - return false - } - for i := range c.Returns { - if 
typeNestedNilOnlyRegression(c.Returns[i], b.Returns[i]) { - return true - } - } - } - return false -} - // ReturnTypesStopRecursiveStructuralGrowth reports whether growing embeds the // same structural container shape as stable beneath its root. Recursive table // builders such as deep-copy helpers otherwise look like ever-more-specific @@ -466,10 +250,10 @@ func ReturnTypesStopRecursiveStructuralGrowth(stable, growing []typ.Type) bool { if typ.IsAbsentOrUnknown(s) || !value.CanSelfEmbed(s) { return false } - if !shallowStructuralShapeEquals(g, s) { + if !value.ShallowStructuralShapeEquals(g, s) { return false } - if !typeContainsNestedStructuralShape(g, s) { + if !value.ContainsNestedStructuralShape(g, s) { return false } strict = true @@ -477,220 +261,6 @@ func ReturnTypesStopRecursiveStructuralGrowth(stable, growing []typ.Type) bool { return strict } -func typeContainsNestedStructuralShape(haystack, needle typ.Type) bool { - return typeContainsNestedStructuralShapeDepth(haystack, needle, make(map[typ.Type]bool), false) -} - -func typeContainsNestedStructuralShapeDepth(haystack, needle typ.Type, seen map[typ.Type]bool, belowContainer bool) bool { - if haystack == nil || needle == nil { - return false - } - if seen[haystack] { - return false - } - seen[haystack] = true - - node := unwrapStructuralShape(haystack) - if node == nil { - return false - } - if belowContainer && shallowStructuralShapeEquals(node, needle) { - return true - } - - descend := func(child typ.Type, childBelowContainer bool) bool { - return typeContainsNestedStructuralShapeDepth(child, needle, seen, childBelowContainer) - } - - switch n := node.(type) { - case *typ.Optional: - return descend(n.Inner, belowContainer) - case *typ.Union: - for _, member := range n.Members { - if descend(member, belowContainer) { - return true - } - } - return false - case *typ.Intersection: - for _, member := range n.Members { - if descend(member, belowContainer) { - return true - } - } - return false - case 
*typ.Array: - return descend(n.Element, true) - case *typ.Map: - return descend(n.Key, true) || descend(n.Value, true) - case *typ.Tuple: - for _, elem := range n.Elements { - if descend(elem, true) { - return true - } - } - return false - case *typ.Record: - for _, field := range n.Fields { - if descend(field.Type, true) { - return true - } - } - if n.Metatable != nil && descend(n.Metatable, true) { - return true - } - if n.HasMapComponent() { - return descend(n.MapKey, true) || descend(n.MapValue, true) - } - return false - case *typ.Function: - for _, param := range n.Params { - if descend(param.Type, true) { - return true - } - } - if n.Variadic != nil && descend(n.Variadic, true) { - return true - } - for _, ret := range n.Returns { - if descend(ret, true) { - return true - } - } - return false - case *typ.Instantiated: - for _, arg := range n.TypeArgs { - if descend(arg, belowContainer) { - return true - } - } - return false - case *typ.Interface: - for _, method := range n.Methods { - if method.Type != nil && descend(method.Type, true) { - return true - } - } - return false - default: - return false - } -} - -func shallowStructuralShapeEquals(a, b typ.Type) bool { - a = unwrapStructuralShape(a) - b = unwrapStructuralShape(b) - if a == nil || b == nil { - return a == b - } - - switch av := a.(type) { - case *typ.Union: - for _, member := range av.Members { - if shallowStructuralShapeEquals(member, b) { - return true - } - } - return false - case *typ.Intersection: - for _, member := range av.Members { - if shallowStructuralShapeEquals(member, b) { - return true - } - } - return false - } - switch bv := b.(type) { - case *typ.Union: - for _, member := range bv.Members { - if shallowStructuralShapeEquals(a, member) { - return true - } - } - return false - case *typ.Intersection: - for _, member := range bv.Members { - if shallowStructuralShapeEquals(a, member) { - return true - } - } - return false - } - - switch av := a.(type) { - case *typ.Array: - _, ok := 
b.(*typ.Array) - return ok - case *typ.Map: - bv, ok := b.(*typ.Map) - return ok && shallowMapKeyShapeEquals(av.Key, bv.Key) - case *typ.Tuple: - bv, ok := b.(*typ.Tuple) - return ok && len(av.Elements) == len(bv.Elements) - case *typ.Record: - bv, ok := b.(*typ.Record) - return ok && shallowRecordShapeEquals(av, bv) - default: - return typ.TypeEquals(a, b) - } -} - -func unwrapStructuralShape(t typ.Type) typ.Type { - for t != nil { - switch v := t.(type) { - case *typ.Annotated: - if v.Inner == nil || v.Inner == t { - return t - } - t = v.Inner - case *typ.Alias: - if v.Target == nil || v.Target == t { - return t - } - t = v.Target - case *typ.Optional: - if v.Inner == nil || v.Inner == t { - return t - } - t = v.Inner - default: - return t - } - } - return nil -} - -func shallowMapKeyShapeEquals(a, b typ.Type) bool { - if a == nil || b == nil { - return a == b - } - if typ.TypeEquals(a, b) { - return true - } - return typ.IsAny(a) || typ.IsAny(b) || typ.IsUnknown(a) || typ.IsUnknown(b) -} - -func shallowRecordShapeEquals(a, b *typ.Record) bool { - if a == nil || b == nil { - return a == b - } - if a.HasMapComponent() != b.HasMapComponent() { - return false - } - if a.HasMapComponent() && !shallowMapKeyShapeEquals(a.MapKey, b.MapKey) { - return false - } - if len(a.Fields) != len(b.Fields) { - return false - } - for _, field := range a.Fields { - if b.GetField(field.Name) == nil { - return false - } - } - return true -} - // SelectRefiningReturnVector prefers candidate only when it is a directional // refinement of baseline. It never prefers baseline over candidate. 
// diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index 598c8786..50c67661 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -268,14 +268,14 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { if value.ElidesOptional(merged, prev) { return false } - if refines, _ := typeRefinesFalsyMapKey(merged, prev); refines { + if refines, _ := value.RefinesFalsyMapKey(merged, prev); refines { return false } if typ.IsAny(prev) || typ.IsUnknown(prev) { return true } - switch p := unwrapStructuralShape(prev).(type) { + switch p := value.UnwrapStructuralShape(prev).(type) { case *typ.Union: if unionStrictMemberSubset(merged, p) { return true @@ -284,7 +284,7 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { return true } case *typ.Record: - m, ok := unwrapStructuralShape(merged).(*typ.Record) + m, ok := value.UnwrapStructuralShape(merged).(*typ.Record) if !ok { break } @@ -298,15 +298,15 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { return true } case *typ.Array: - if m, ok := unwrapStructuralShape(merged).(*typ.Array); ok { + if m, ok := value.UnwrapStructuralShape(merged).(*typ.Array); ok { return typeUnsafePrecisionDrop(p.Element, m.Element) } case *typ.Map: - if m, ok := unwrapStructuralShape(merged).(*typ.Map); ok { + if m, ok := value.UnwrapStructuralShape(merged).(*typ.Map); ok { return typeUnsafePrecisionDrop(p.Key, m.Key) || typeUnsafePrecisionDrop(p.Value, m.Value) } case *typ.Tuple: - m, ok := unwrapStructuralShape(merged).(*typ.Tuple) + m, ok := value.UnwrapStructuralShape(merged).(*typ.Tuple) if !ok || len(p.Elements) != len(m.Elements) { break } @@ -316,7 +316,7 @@ func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { } } case *typ.Function: - m, ok := unwrapStructuralShape(merged).(*typ.Function) + m, ok := value.UnwrapStructuralShape(merged).(*typ.Function) if !ok { break } @@ -342,7 +342,7 @@ func unionStrictMemberSubset(candidate typ.Type, baseline 
*typ.Union) bool { if baseline == nil { return false } - candidateMembers := unionMembers(candidate) + candidateMembers := value.UnionMembers(candidate) if len(candidateMembers) == 0 { candidateMembers = []typ.Type{candidate} } @@ -364,17 +364,6 @@ func unionStrictMemberSubset(candidate typ.Type, baseline *typ.Union) bool { return true } -func unionMembers(t typ.Type) []typ.Type { - switch v := unwrapStructuralShape(t).(type) { - case *typ.Union: - return v.Members - case *typ.Optional: - return append([]typ.Type{typ.Nil}, unionMembers(v.Inner)...) - default: - return nil - } -} - func canonicalInterprocValueType(t typ.Type) typ.Type { if t == nil { return nil @@ -454,10 +443,10 @@ func widenValueTypeForConvergence(existing, candidate typ.Type) typ.Type { if value.ElidesOptional(candidate, existing) { return candidate } - if value.ExtendsRecord(candidate, existing) && !typeContainsNestedStructuralShape(candidate, existing) { + if value.ExtendsRecord(candidate, existing) && !value.ContainsNestedStructuralShape(candidate, existing) { return candidate } - if refines, _ := typeRefinesFalsyMapKey(candidate, existing); refines { + if refines, _ := value.RefinesFalsyMapKey(candidate, existing); refines { return candidate } if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { From 19ce145c829b131d5e3b7cf3de8c3e684ebc3943 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 02:33:24 -0400 Subject: [PATCH 20/71] Move return summaries into domain --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 62 ++ compiler/check/domain/paramevidence/merge.go | 14 + .../check/domain/paramevidence/merge_test.go | 24 +- compiler/check/domain/returnsummary/doc.go | 8 + .../check/domain/returnsummary/summary.go | 925 ++++++++++++++++++ .../domain/returnsummary/summary_test.go | 47 + compiler/check/infer/interproc/postflow.go | 7 +- compiler/check/infer/return/infer.go | 9 +- .../check/infer/return/overlay_pipeline.go | 11 +- compiler/check/infer/return/scc.go | 
5 +- compiler/check/phase/scope.go | 4 +- compiler/check/returns/domain_law_test.go | 5 +- compiler/check/returns/equal.go | 8 +- compiler/check/returns/function_facts.go | 5 +- compiler/check/returns/join.go | 802 +-------------- compiler/check/returns/join_test.go | 175 ++-- compiler/check/returns/kernel.go | 15 +- compiler/check/returns/kernel_test.go | 35 +- compiler/check/returns/widen.go | 161 +-- compiler/check/returns/widen_test.go | 49 +- .../tests/errors/error_correlation_test.go | 4 +- 21 files changed, 1233 insertions(+), 1142 deletions(-) create mode 100644 compiler/check/domain/returnsummary/doc.go create mode 100644 compiler/check/domain/returnsummary/summary.go create mode 100644 compiler/check/domain/returnsummary/summary_test.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 2191215d..1a567854 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -161,6 +161,68 @@ Verification for this slice so far: Wippy binary build, then exits non-zero on the existing external lint targets: session 8 errors, agent/src 8 errors, docker-demo 21 errors and 2 warnings. +## 2026-05-19 Return Summary Domain Checkpoint + +The next rectification slice moved return-vector policy and function-signature +return alignment out of `compiler/check/returns` and into +`compiler/check/domain/returnsummary`. 
+ +Moved domain laws: + +- return-vector equality and nil-slot canonicalization; +- return-vector normalization with soft-union pruning; +- directional refinement, optional elision, record extension, and nil-slot fill; +- concrete-over-soft summary preference and stale falsy map-key refinement; +- nested nil-only regression protection; +- recursive structural-growth stopping for table builders; +- nested `never` artifact repair; +- higher-order monotone summary merge for function-returning-function and + self-recursive method shapes; +- summary-to-function-return alignment and conservative unknown return + attachment for otherwise returnless callable values. + +Production callers now import `domain/returnsummary` directly. The old +`returns.ReturnTypes*`, `returns.MergeReturnSummary`, +`returns.NormalizeReturnVector*`, `returns.AlignFunctionTypeWithSummary`, +`returns.WithSummaryOrUnknown`, `canonicalReturnVector`, and +`normalizeAndPruneReturnVector` names were deleted instead of wrapped. + +Current package ownership: + +```text +domain/value = reusable structural value relations +domain/paramevidence = parameter evidence lattice, equality, and parameter-slot refinement +domain/returnsummary = return-vector lattice and function-return alignment +returns = function-fact product orchestration and interproc widening +``` + +This keeps one clear abstract-interpreter data flow: + +1. flow and return inference produce candidate return evidence; +2. `domain/returnsummary` decides how return vectors normalize, compare, merge, + and align to callable types; +3. `returns` only decides when function-fact products are joined or widened; +4. Salsa snapshots continue to observe the canonical fact product rather than a + compatibility mirror. + +This is a flash migration, not a bridge. Production code no longer calls the old +return-summary helpers through `returns`. 
+ +Verification for this slice so far: + +- `go test ./compiler/check/domain/returnsummary ./compiler/check/returns` + passes. +- `go test ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. +- `go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction + -benchmem -count=3` reports about 1.15 ms/op, 882 KB/op, and 9390 + allocs/op on this machine. +- Standard `../scripts/verify-suite.sh` passes go-lua checker tests and builds + the Wippy binary, then exits non-zero on the known external pinned lint + targets: session 8 errors, agent/src 9 errors, docker-demo 21 errors and + 2 warnings. + ## Goal The checker should read as one abstract interpreter over a product domain. diff --git a/compiler/check/domain/paramevidence/merge.go b/compiler/check/domain/paramevidence/merge.go index e5900953..b8bf31c9 100644 --- a/compiler/check/domain/paramevidence/merge.go +++ b/compiler/check/domain/paramevidence/merge.go @@ -91,6 +91,20 @@ func NormalizeVector(evidence []typ.Type) []typ.Type { return evidence } +// EqualVectors reports whether two normalized evidence vectors are structurally +// equal. 
+func EqualVectors(a, b []typ.Type) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !typ.TypeEquals(a[i], b[i]) { + return false + } + } + return true +} + func hasNonNilEvidence(evidence []typ.Type) bool { for _, observed := range evidence { if observed != nil { diff --git a/compiler/check/domain/paramevidence/merge_test.go b/compiler/check/domain/paramevidence/merge_test.go index e1e9d1a6..cad42837 100644 --- a/compiler/check/domain/paramevidence/merge_test.go +++ b/compiler/check/domain/paramevidence/merge_test.go @@ -274,25 +274,25 @@ func TestWidenMap_TableTopUpperBoundAbsorbsAnyObservation(t *testing.T) { } } -func evidenceMapsEqual(a, b map[cfg.SymbolID][]typ.Type) bool { - if len(a) != len(b) { - return false +func TestEqualVectors(t *testing.T) { + if !EqualVectors([]typ.Type{typ.String, typ.Nil}, []typ.Type{typ.String, typ.Nil}) { + t.Fatal("expected equal evidence vectors") } - for _, sym := range cfg.SortedSymbolIDs(a) { - right, ok := b[sym] - if !ok || !evidenceVectorsEqual(a[sym], right) { - return false - } + if EqualVectors([]typ.Type{typ.String}, []typ.Type{typ.String, typ.Nil}) { + t.Fatal("expected different lengths to be unequal") + } + if EqualVectors([]typ.Type{typ.String}, []typ.Type{typ.Number}) { + t.Fatal("expected different evidence slots to be unequal") } - return true } -func evidenceVectorsEqual(a, b []typ.Type) bool { +func evidenceMapsEqual(a, b map[cfg.SymbolID][]typ.Type) bool { if len(a) != len(b) { return false } - for i := range a { - if !typ.TypeEquals(a[i], b[i]) { + for _, sym := range cfg.SortedSymbolIDs(a) { + right, ok := b[sym] + if !ok || !EqualVectors(a[sym], right) { return false } } diff --git a/compiler/check/domain/returnsummary/doc.go b/compiler/check/domain/returnsummary/doc.go new file mode 100644 index 00000000..736c3eb4 --- /dev/null +++ b/compiler/check/domain/returnsummary/doc.go @@ -0,0 +1,8 @@ +// Package returnsummary owns the return-vector abstract domain. 
+// +// It canonicalizes, compares, joins, and widens return summaries produced by +// local return inference and interprocedural fact propagation. Orchestration +// packages decide when candidate summaries are produced; this package decides +// how those summaries normalize, refine, merge, and align back to function +// types. +package returnsummary diff --git a/compiler/check/domain/returnsummary/summary.go b/compiler/check/domain/returnsummary/summary.go new file mode 100644 index 00000000..579a0270 --- /dev/null +++ b/compiler/check/domain/returnsummary/summary.go @@ -0,0 +1,925 @@ +package returnsummary + +import ( + "github.com/wippyai/go-lua/compiler/check/domain/value" + "github.com/wippyai/go-lua/types/kind" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + typjoin "github.com/wippyai/go-lua/types/typ/join" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// Equal checks whether two return vectors are structurally equal. +func Equal(a, b []typ.Type) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !typ.TypeEquals(a[i], b[i]) { + return false + } + } + return true +} + +// AllNil reports whether every return slot is explicit nil. +func AllNil(rets []typ.Type) bool { + if len(rets) == 0 { + return false + } + for _, t := range rets { + if t == nil || t.Kind() != kind.Nil { + return false + } + } + return true +} + +// Refines reports whether a is an element-wise subtype refinement of b. +func Refines(a, b []typ.Type) bool { + if len(a) == 0 { + return false + } + if len(b) == 0 { + return true + } + if len(a) != len(b) { + return false + } + for i := range a { + ai := a[i] + bi := b[i] + if ai == nil || bi == nil { + if ai == nil && bi == nil { + continue + } + return false + } + if !subtype.IsSubtype(ai, bi) { + return false + } + } + return true +} + +// ExtendsRecord reports whether a refines b by adding record fields. 
+func ExtendsRecord(a, b []typ.Type) bool { + if len(a) == 0 || len(b) == 0 || len(a) != len(b) { + return false + } + for i := range a { + if _, ok := a[i].(*typ.Record); !ok { + return false + } + if !value.ExtendsRecord(a[i], b[i]) { + return false + } + } + return true +} + +// ElidesOptional reports whether a refines b by removing nil/optional parts. +func ElidesOptional(a, b []typ.Type) bool { + if len(a) == 0 || len(b) == 0 || len(a) != len(b) { + return false + } + for i := range a { + if !value.ElidesOptional(a[i], b[i]) { + return false + } + } + return true +} + +// SelectPreferred picks a canonical winner when one return vector is strictly +// preferable to the other without requiring a join. +func SelectPreferred(a, b []typ.Type) ([]typ.Type, bool) { + if RepairsNever(a, b) { + return a, true + } + if RepairsNever(b, a) { + return b, true + } + if RefinesSoftContainers(a, b) { + return a, true + } + if RefinesSoftContainers(b, a) { + return b, true + } + if StopsRecursiveStructuralGrowth(a, b) { + return a, true + } + if StopsRecursiveStructuralGrowth(b, a) { + return b, true + } + if RefinesFalsyMapKeys(a, b) { + return a, true + } + if RefinesFalsyMapKeys(b, a) { + return b, true + } + if Refines(a, b) { + if AllNil(a) && !AllNil(b) { + return b, true + } + if NestedNilOnlyRegression(a, b) { + return b, true + } + return a, true + } + if Refines(b, a) { + if AllNil(b) && !AllNil(a) { + return a, true + } + if NestedNilOnlyRegression(b, a) { + return a, true + } + return b, true + } + if FillsNilSlots(a, b) { + return a, true + } + if FillsNilSlots(b, a) { + return b, true + } + if (ExtendsRecord(a, b) || ElidesOptional(a, b)) && !NestedNilOnlyRegression(a, b) { + return a, true + } + if (ExtendsRecord(b, a) || ElidesOptional(b, a)) && !NestedNilOnlyRegression(b, a) { + return b, true + } + return nil, false +} + +// RefinesSoftContainers reports whether candidate preserves the same table +// shape while replacing soft placeholder element/value 
evidence with concrete +// evidence. +func RefinesSoftContainers(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + refines, changed := value.RefinesSoftContainer(candidate[i], baseline[i]) + if !refines { + return false + } + if changed { + strict = true + } + } + return strict +} + +// RefinesFalsyMapKeys reports whether candidate is the same table-derived shape +// as baseline after removing stale falsy members from baseline. +func RefinesFalsyMapKeys(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + refines, changed := value.RefinesFalsyMapKey(candidate[i], baseline[i]) + if !refines { + return false + } + if changed { + strict = true + } + } + return strict +} + +// NestedNilOnlyRegression reports whether candidate's apparent refinement only +// adds nested nil facts over a more useful baseline shape. +func NestedNilOnlyRegression(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + for i := range candidate { + if value.NestedNilOnlyRegression(candidate[i], baseline[i]) { + return true + } + } + return false +} + +// StopsRecursiveStructuralGrowth reports whether growing embeds the same +// structural container shape as stable beneath its root. 
+func StopsRecursiveStructuralGrowth(stable, growing []typ.Type) bool { + if len(stable) == 0 || len(growing) == 0 || len(stable) != len(growing) { + return false + } + + strict := false + for i := range stable { + s := stable[i] + g := growing[i] + if s == nil || g == nil { + return false + } + if typ.TypeEquals(s, g) { + continue + } + if typ.IsAbsentOrUnknown(s) || !value.CanSelfEmbed(s) { + return false + } + if !value.ShallowStructuralShapeEquals(g, s) { + return false + } + if !value.ContainsNestedStructuralShape(g, s) { + return false + } + strict = true + } + return strict +} + +// SelectRefining prefers candidate only when it directionally refines baseline. +func SelectRefining(candidate, baseline []typ.Type) ([]typ.Type, bool) { + if Refines(candidate, baseline) { + if AllNil(candidate) && !AllNil(baseline) { + return baseline, true + } + return candidate, true + } + if FillsNilSlots(candidate, baseline) { + return candidate, true + } + if ExtendsRecord(candidate, baseline) || ElidesOptional(candidate, baseline) { + return candidate, true + } + return nil, false +} + +// FillsNilSlots reports whether a improves b by replacing nil-only slots with +// concrete return evidence while staying compatible on other slots. +func FillsNilSlots(a, b []typ.Type) bool { + if len(a) == 0 || len(b) == 0 || len(a) != len(b) { + return false + } + strict := false + for i := range a { + ai := a[i] + bi := b[i] + if ai == nil || bi == nil { + return false + } + if unwrap.IsNilType(bi) && !unwrap.IsNilType(ai) { + strict = true + continue + } + if typ.TypeEquals(ai, bi) { + continue + } + if subtype.IsSubtype(ai, bi) || value.ExtendsRecord(ai, bi) || value.ElidesOptional(ai, bi) { + continue + } + return false + } + return strict +} + +// RepairsNever reports whether candidate is a runtime-possible repair of +// baseline by replacing nested never artifacts while otherwise widening +// compatibly. 
+func RepairsNever(candidate, baseline []typ.Type) bool { + if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { + return false + } + strict := false + for i := range candidate { + if candidate[i] == nil || baseline[i] == nil { + return false + } + if typ.TypeEquals(candidate[i], baseline[i]) { + continue + } + if !repairsNever(candidate[i], baseline[i]) { + return false + } + strict = true + } + return strict +} + +func repairsNever(candidate, baseline typ.Type) bool { + if candidate == nil || baseline == nil { + return false + } + if !containsNever(baseline) || containsNever(candidate) { + return false + } + ok, strict := neverRepairRelation(candidate, baseline) + return ok && strict +} + +func neverRepairRelation(candidate, baseline typ.Type) (bool, bool) { + if candidate == nil || baseline == nil { + return false, false + } + if typ.TypeEquals(candidate, baseline) { + return true, false + } + + candidate = unwrap.Alias(candidate) + baseline = unwrap.Alias(baseline) + if candidate == nil || baseline == nil { + return false, false + } + + if typ.IsNever(baseline) { + return !typ.IsNever(candidate), !typ.IsNever(candidate) + } + if !containsNever(baseline) { + return false, false + } + + switch b := baseline.(type) { + case *typ.Optional: + c, ok := candidate.(*typ.Optional) + if !ok { + return false, false + } + return neverRepairRelation(c.Inner, b.Inner) + case *typ.Union: + c, ok := candidate.(*typ.Union) + if !ok || len(c.Members) != len(b.Members) { + return false, false + } + used := make([]bool, len(c.Members)) + strict := false + for _, bm := range b.Members { + matched := false + for j, cm := range c.Members { + if used[j] || !typ.TypeEquals(cm, bm) { + continue + } + used[j] = true + matched = true + break + } + if matched { + continue + } + for j, cm := range c.Members { + if used[j] { + continue + } + ok, repaired := neverRepairRelation(cm, bm) + if !ok { + continue + } + used[j] = true + matched = true + if repaired { + 
strict = true + } + break + } + if !matched { + return false, false + } + } + return true, strict + case *typ.Record: + c, ok := candidate.(*typ.Record) + if !ok || c.Open != b.Open || c.HasMapComponent() != b.HasMapComponent() || len(c.Fields) != len(b.Fields) { + return false, false + } + strict := false + for _, bf := range b.Fields { + cf := c.GetField(bf.Name) + if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly { + return false, false + } + ok, repaired := neverRepairRelation(cf.Type, bf.Type) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + if b.HasMapComponent() { + ok, repaired := neverRepairRelation(c.MapKey, b.MapKey) + if !ok { + return false, false + } + if repaired { + strict = true + } + ok, repaired = neverRepairRelation(c.MapValue, b.MapValue) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + if b.Metatable != nil || c.Metatable != nil { + if b.Metatable == nil || c.Metatable == nil { + return false, false + } + ok, repaired := neverRepairRelation(c.Metatable, b.Metatable) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + case *typ.Array: + c, ok := candidate.(*typ.Array) + if !ok { + return false, false + } + return neverRepairRelation(c.Element, b.Element) + case *typ.Map: + c, ok := candidate.(*typ.Map) + if !ok { + return false, false + } + keyOK, keyStrict := neverRepairRelation(c.Key, b.Key) + if !keyOK { + return false, false + } + valOK, valStrict := neverRepairRelation(c.Value, b.Value) + if !valOK { + return false, false + } + return true, keyStrict || valStrict + case *typ.Tuple: + c, ok := candidate.(*typ.Tuple) + if !ok || len(c.Elements) != len(b.Elements) { + return false, false + } + strict := false + for i := range b.Elements { + ok, repaired := neverRepairRelation(c.Elements[i], b.Elements[i]) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + case 
*typ.Function: + c, ok := candidate.(*typ.Function) + if !ok || !sameFunctionShapeForRepair(c, b) || len(c.Returns) != len(b.Returns) { + return false, false + } + for i := range b.Params { + if c.Params[i].Name != b.Params[i].Name || + c.Params[i].Optional != b.Params[i].Optional || + !typ.TypeEquals(c.Params[i].Type, b.Params[i].Type) { + return false, false + } + } + switch { + case (c.Variadic == nil) != (b.Variadic == nil): + return false, false + case c.Variadic != nil && !typ.TypeEquals(c.Variadic, b.Variadic): + return false, false + } + strict := false + for i := range b.Returns { + ok, repaired := neverRepairRelation(c.Returns[i], b.Returns[i]) + if !ok { + return false, false + } + if repaired { + strict = true + } + } + return true, strict + default: + return false, false + } +} + +func sameFunctionShapeForRepair(a, b *typ.Function) bool { + if a == nil || b == nil { + return false + } + if !typeParamsEqual(a.TypeParams, b.TypeParams) { + return false + } + return len(a.Params) == len(b.Params) +} + +func typeParamsEqual(a, b []*typ.TypeParam) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] == nil || b[i] == nil { + if a[i] != b[i] { + return false + } + continue + } + if !a[i].Equals(b[i]) { + return false + } + } + return true +} + +func containsNever(t typ.Type) bool { + seen := make(map[typ.Type]bool) + return containsNeverMemo(t, seen) +} + +func containsNeverMemo(t typ.Type, seen map[typ.Type]bool) bool { + if t == nil { + return false + } + if seen[t] { + return false + } + seen[t] = true + t = unwrap.Alias(t) + if t == nil { + return false + } + if typ.IsNever(t) { + return true + } + return typ.Visit(t, typ.Visitor[bool]{ + Optional: func(o *typ.Optional) bool { + return containsNeverMemo(o.Inner, seen) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if containsNeverMemo(m, seen) { + return true + } + } + return false + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := 
range in.Members { + if containsNeverMemo(m, seen) { + return true + } + } + return false + }, + Tuple: func(tup *typ.Tuple) bool { + for _, e := range tup.Elements { + if containsNeverMemo(e, seen) { + return true + } + } + return false + }, + Array: func(a *typ.Array) bool { + return containsNeverMemo(a.Element, seen) + }, + Map: func(m *typ.Map) bool { + return containsNeverMemo(m.Key, seen) || containsNeverMemo(m.Value, seen) + }, + Record: func(r *typ.Record) bool { + for _, f := range r.Fields { + if containsNeverMemo(f.Type, seen) { + return true + } + } + if r.HasMapComponent() { + return containsNeverMemo(r.MapKey, seen) || containsNeverMemo(r.MapValue, seen) + } + return false + }, + Function: func(fn *typ.Function) bool { + for _, p := range fn.Params { + if containsNeverMemo(p.Type, seen) { + return true + } + } + if fn.Variadic != nil && containsNeverMemo(fn.Variadic, seen) { + return true + } + for _, ret := range fn.Returns { + if containsNeverMemo(ret, seen) { + return true + } + } + return false + }, + Default: func(typ.Type) bool { + return false + }, + }) +} + +// Normalize replaces nil slots with explicit nil types in a copied vector. +func Normalize(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } + out := make([]typ.Type, len(rets)) + copy(out, rets) + return NormalizeOwned(out) +} + +// NormalizeOwned replaces nil slots with explicit nil types in an owned vector. +func NormalizeOwned(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } + for i, t := range rets { + if t == nil { + rets[i] = typ.Nil + } + } + return rets +} + +// Canonical returns a vector with explicit nil slots, reusing the input when it +// is already canonical. 
+func Canonical(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } + for i, t := range rets { + if t != nil { + continue + } + out := make([]typ.Type, len(rets)) + copy(out, rets) + out[i] = typ.Nil + for j := i + 1; j < len(out); j++ { + if out[j] == nil { + out[j] = typ.Nil + } + } + return out + } + return rets +} + +// NormalizeAndPrune canonicalizes nil slots and removes soft union members. +func NormalizeAndPrune(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return nil + } + var out []typ.Type + for i, ret := range rets { + normalized := ret + if normalized == nil { + normalized = typ.Nil + } + pruned := typ.PruneSoftUnionMembers(normalized) + if out != nil { + out[i] = pruned + continue + } + if pruned == ret { + continue + } + out = make([]typ.Type, len(rets)) + copy(out, rets[:i]) + out[i] = pruned + } + if out != nil { + return out + } + return rets +} + +// Merge applies the canonical return-summary merge policy shared by iterative +// channels. +func Merge(existing, candidate []typ.Type) []typ.Type { + existing = NormalizeAndPrune(existing) + candidate = NormalizeAndPrune(candidate) + if len(existing) == 0 { + return candidate + } + if len(candidate) == 0 { + return existing + } + if replaced, ok := replaceOpenTopWithStructured(existing, candidate); ok { + existing = NormalizeAndPrune(replaced) + } + if RepairsNever(existing, candidate) { + return existing + } + if RepairsNever(candidate, existing) { + return candidate + } + if shouldUseMonotoneJoin(existing, candidate) { + return NormalizeAndPrune(joinMonotone(existing, candidate)) + } + if preferred, ok := SelectPreferred(existing, candidate); ok { + return NormalizeAndPrune(preferred) + } + return NormalizeAndPrune(typjoin.ReturnVectors(existing, candidate)) +} + +func shouldUseMonotoneJoin(a, b []typ.Type) bool { + for _, t := range a { + if HasHigherOrderGrowthRisk(t) { + return true + } + } + for _, t := range b { + if HasHigherOrderGrowthRisk(t) { + return true + } + } + 
return false +} + +// HasHigherOrderGrowthRisk reports whether a type can produce non-monotone +// higher-order structural growth across summary iterations. +func HasHigherOrderGrowthRisk(t typ.Type) bool { + if t == nil { + return false + } + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + switch n := node.(type) { + case *typ.Function: + for _, ret := range n.Returns { + if containsFunction(ret) { + return true, false + } + } + case *typ.Record: + if recordHasSelfRecursiveMethod(n) { + return true, false + } + } + return false, true + }) +} + +func containsFunction(t typ.Type) bool { + if t == nil { + return false + } + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if _, ok := node.(*typ.Interface); ok { + return false, false + } + if _, ok := node.(*typ.Function); ok { + return true, false + } + return false, true + }) +} + +func recordHasSelfRecursiveMethod(r *typ.Record) bool { + if r == nil { + return false + } + for _, f := range r.Fields { + if methodTypeHasSelfRecursiveReturn(f.Type, r) { + return true + } + } + return r.HasMapComponent() && methodTypeHasSelfRecursiveReturn(r.MapValue, r) +} + +func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { + if t == nil || owner == nil { + return false + } + return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if _, ok := node.(*typ.Interface); ok { + return false, false + } + fn, ok := node.(*typ.Function) + if !ok { + return false, true + } + for _, ret := range fn.Returns { + if ret == nil { + continue + } + if subtype.IsSubtype(ret, owner) || subtype.IsSubtype(owner, ret) || + value.ExtendsRecord(ret, owner) || value.ExtendsRecord(owner, ret) { + return true, false + } + } + return false, true + }) +} + +func joinMonotone(a, b []typ.Type) []typ.Type { + if len(a) == 0 { + return b + } + if len(b) == 0 { + return a + } + maxLen := len(a) + if len(b) > maxLen { + maxLen = len(b) + } + out := make([]typ.Type, maxLen) + 
for i := 0; i < maxLen; i++ { + var ai, bi typ.Type + if i < len(a) { + ai = a[i] + } + if i < len(b) { + bi = b[i] + } + out[i] = joinTypeMonotone(ai, bi) + } + return out +} + +func joinTypeMonotone(a, b typ.Type) typ.Type { + if a == nil { + return b + } + if b == nil { + return a + } + if typ.TypeEquals(a, b) { + return a + } + if subtype.IsSubtype(a, b) || value.ExtendsRecord(a, b) || value.ElidesOptional(a, b) { + return b + } + if subtype.IsSubtype(b, a) || value.ExtendsRecord(b, a) || value.ElidesOptional(b, a) { + return a + } + return typ.JoinPreferNonSoft(a, b) +} + +// AlignFunction applies the canonical return-summary winner to a function type. +func AlignFunction(fn *typ.Function, summary []typ.Type) (*typ.Function, bool) { + if fn == nil { + return nil, false + } + + normalizedSummary := NormalizeAndPrune(summary) + if len(normalizedSummary) == 0 { + return fn, false + } + + current := NormalizeAndPrune(fn.Returns) + if len(current) == 0 { + aligned := typjoin.WithReturns(fn, normalizedSummary) + return aligned, aligned != nil + } + merged := Merge(current, normalizedSummary) + if Equal(current, merged) { + return fn, false + } + + aligned := typjoin.WithReturns(fn, merged) + if aligned == nil { + return fn, false + } + return aligned, true +} + +func replaceOpenTopWithStructured(current, summary []typ.Type) ([]typ.Type, bool) { + if len(current) == 0 || len(summary) == 0 || len(current) != len(summary) { + return nil, false + } + out := append([]typ.Type(nil), current...) + changed := false + for i := range out { + if !value.IsOpenTopRecord(out[i]) { + continue + } + if !value.IsStructuredTableShape(summary[i]) { + continue + } + out[i] = summary[i] + changed = true + } + if !changed { + return nil, false + } + return out, true +} + +// ApplyToFunctionType applies summary-derived returns to a function signature. +// If both summary and signature returns are empty, it attaches unknown to keep +// call-site checking conservative. 
+func ApplyToFunctionType(fn *typ.Function, summary []typ.Type) *typ.Function { + if fn == nil { + return nil + } + if len(summary) == 0 { + if len(fn.Returns) > 0 { + return fn + } + return typjoin.WithReturns(fn, []typ.Type{typ.Unknown}) + } + if aligned, changed := AlignFunction(fn, summary); changed { + return aligned + } + if len(fn.Returns) > 0 { + return fn + } + return typjoin.WithReturns(fn, NormalizeAndPrune(summary)) +} diff --git a/compiler/check/domain/returnsummary/summary_test.go b/compiler/check/domain/returnsummary/summary_test.go new file mode 100644 index 00000000..145beb68 --- /dev/null +++ b/compiler/check/domain/returnsummary/summary_test.go @@ -0,0 +1,47 @@ +package returnsummary + +import ( + "testing" + + "github.com/wippyai/go-lua/types/typ" +) + +func TestHigherOrderGrowthRisk_DetectsFunctionReturningFunction(t *testing.T) { + tp := typ.Func(). + Returns(typ.Func().Returns(typ.String).Build()). + Build() + if !HasHigherOrderGrowthRisk(tp) { + t.Fatalf("expected higher-order growth risk to be detected") + } +} + +func TestContainsFunction_IgnoresInterfaceMethodSignatures(t *testing.T) { + iface := typ.NewInterface("Reader", []typ.Method{ + { + Name: "next", + Type: typ.Func(). + Param("self", typ.Self). + Returns(typ.Func().Returns(typ.String).Build()). + Build(), + }, + }) + if containsFunction(iface) { + t.Fatalf("expected interface method signatures to be ignored, got true") + } +} + +func TestMethodTypeHasSelfRecursiveReturn_IgnoresInterfaceMethods(t *testing.T) { + owner := typ.NewRecord().Field("id", typ.String).Build() + methodType := typ.NewInterface("HasBuild", []typ.Method{ + { + Name: "build", + Type: typ.Func(). + Param("self", typ.Self). + Returns(owner). 
+ Build(), + }, + }) + if methodTypeHasSelfRecursiveReturn(methodType, owner) { + t.Fatalf("expected interface method signatures to be ignored for self-recursive detection") + } +} diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index 4e440627..ffdf941d 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -7,6 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/erreffect" "github.com/wippyai/go-lua/compiler/check/nested" "github.com/wippyai/go-lua/compiler/check/returns" @@ -65,10 +66,10 @@ func StoreFactsFromResult( if fnType == nil { return } - narrowSummary := returns.NormalizeReturnVector(fnType.Returns) + narrowSummary := returnsummary.Normalize(fnType.Returns) if snapNarrow := narrowSummarySnapshotForSymbol(store, result, parent, fnSym); len(snapNarrow) > 0 { - narrowSummary = returns.MergeReturnSummary(narrowSummary, snapNarrow) - if aligned, changed := returns.AlignFunctionTypeWithSummary(fnType, narrowSummary); changed { + narrowSummary = returnsummary.Merge(narrowSummary, snapNarrow) + if aligned, changed := returnsummary.AlignFunction(fnType, narrowSummary); changed { fnType = aligned } } diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index c1fadceb..370f7d67 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -44,6 +44,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" flowpath 
"github.com/wippyai/go-lua/compiler/check/flowbuild/path" "github.com/wippyai/go-lua/compiler/check/modules" "github.com/wippyai/go-lua/compiler/check/phase" @@ -285,7 +286,7 @@ func (i *Inferencer) buildLocalFunctionTypes( } } if returnVector := returnVectors[sym]; len(returnVector) > 0 { - if withSummary := returns.WithSummaryOrUnknown(fnType, returnVector); withSummary != nil { + if withSummary := returnsummary.ApplyToFunctionType(fnType, returnVector); withSummary != nil { fnType = withSummary } } @@ -431,7 +432,7 @@ func collectReturnTypes( returnTypes = joinReturnTypes(returnTypes, types) }) - return returns.NormalizeReturnVectorInPlace(returnTypes) + return returnsummary.NormalizeOwned(returnTypes) } // synthesizeReturnExprs computes types for a single return statement's expressions. @@ -516,7 +517,7 @@ func (i *Inferencer) inferReturnTypesFromBody( ) declared := collectReturnTypes(fnGraph, declSynth, nil, skipUnresolvedLocalCall) - return returns.MergeReturnSummary(declared, narrowed) + return returnsummary.Merge(declared, narrowed) } func (i *Inferencer) skipUnresolvedLocalReturnCall(ctx *returnInferenceContext) func(ast.Expr) bool { @@ -539,7 +540,7 @@ func (i *Inferencer) skipUnresolvedLocalReturnCall(ctx *returnInferenceContext) if sym == 0 || ctx.localFuncs[sym] == nil { return false } - return typ.IsUnknownOnlyOrEmpty(returns.NormalizeReturnVector(ctx.returnVectors[sym])) + return typ.IsUnknownOnlyOrEmpty(returnsummary.Normalize(ctx.returnVectors[sym])) } } diff --git a/compiler/check/infer/return/overlay_pipeline.go b/compiler/check/infer/return/overlay_pipeline.go index 602ea8fe..75deedb2 100644 --- a/compiler/check/infer/return/overlay_pipeline.go +++ b/compiler/check/infer/return/overlay_pipeline.go @@ -6,6 +6,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" 
"github.com/wippyai/go-lua/compiler/check/flowbuild/assign" fbcore "github.com/wippyai/go-lua/compiler/check/flowbuild/core" "github.com/wippyai/go-lua/compiler/check/flowbuild/mutator" @@ -152,7 +153,7 @@ func (i *Inferencer) collectAllReturnVectors(ctx *returnInferenceContext) map[cf if sym == 0 { continue } - normalized := returns.NormalizeReturnVectorInPlace(ctx.returnVectors[sym]) + normalized := returnsummary.NormalizeOwned(ctx.returnVectors[sym]) if len(normalized) == 0 { continue } @@ -173,7 +174,7 @@ func (i *Inferencer) returnVectorFromSnapshot( if len(facts) == 0 { return nil } - normalized := returns.NormalizeReturnVector(facts.Summary(sym)) + normalized := returnsummary.Normalize(facts.Summary(sym)) if len(normalized) == 0 { return nil } @@ -190,7 +191,7 @@ func (i *Inferencer) resolveLocalFunctionReturns( } // Keep the current SCC-derived return vector unless it is still unknown-only. - returnVector := returns.NormalizeReturnVectorInPlace(allReturnVectors[sym]) + returnVector := returnsummary.NormalizeOwned(allReturnVectors[sym]) if !typ.IsUnknownOnlyOrEmpty(returnVector) { return returnVector } @@ -259,7 +260,7 @@ func (i *Inferencer) enrichOverlayWithLocalFunctions( if localInfo := ctx.localFuncs[target.Symbol]; localInfo != nil && len(localInfo.ParameterEvidence) > 0 && sig != nil { sig = paramevidence.MergeIntoSignature(fnExpr, localInfo.ParameterEvidence, sig) } - if fnType := returns.WithSummaryOrUnknown(sig, returnVector); fnType != nil { + if fnType := returnsummary.ApplyToFunctionType(sig, returnVector); fnType != nil { overlay[target.Symbol] = fnType } } @@ -987,7 +988,7 @@ func functionFactsFromReturnVectors(returnVectors map[cfg.SymbolID][]typ.Type) a if sym == 0 { continue } - returnVector := returns.NormalizeReturnVectorInPlace(returnVectors[sym]) + returnVector := returnsummary.NormalizeOwned(returnVectors[sym]) if len(returnVector) == 0 { continue } diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go 
index b8852216..9d3e1fb9 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/typ" @@ -104,9 +105,9 @@ func (i *Inferencer) runSCCIteration( } newReturn := i.inferReturnForFunction(run, info, returnVectors, localFuncs) oldReturn := returnVectors[sym] - merged := returns.MergeReturnSummary(oldReturn, newReturn) + merged := returnsummary.Merge(oldReturn, newReturn) next[sym] = merged - if !returns.ReturnTypesEqual(merged, oldReturn) { + if !returnsummary.Equal(merged, oldReturn) { changed = true } } diff --git a/compiler/check/phase/scope.go b/compiler/check/phase/scope.go index 2a065fb6..e6ea649d 100644 --- a/compiler/check/phase/scope.go +++ b/compiler/check/phase/scope.go @@ -28,7 +28,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth" basecfg "github.com/wippyai/go-lua/types/cfg" @@ -458,7 +458,7 @@ func buildDeclaredTypes( return fn } if summary := functionFacts.Summary(sym); len(summary) > 0 { - return returns.WithSummaryOrUnknown(fn, summary) + return returnsummary.ApplyToFunctionType(fn, summary) } return fn } diff --git a/compiler/check/returns/domain_law_test.go b/compiler/check/returns/domain_law_test.go index 17116992..080835d3 100644 --- a/compiler/check/returns/domain_law_test.go +++ b/compiler/check/returns/domain_law_test.go 
@@ -6,6 +6,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/typ" ) @@ -59,10 +60,10 @@ func TestFactsDomain_ProductOperatorsAreIdempotentAcrossAllDomains(t *testing.T) t.Fatalf("Join must be idempotent across the product domain") } - if got := normalized.FunctionFacts.Summary(fnSym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := normalized.FunctionFacts.Summary(fnSym); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("summary must come from canonical FunctionFacts, got %v", got) } - if got := normalized.FunctionFacts.NarrowSummary(fnSym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := normalized.FunctionFacts.NarrowSummary(fnSym); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("narrow summary must come from canonical FunctionFacts, got %v", got) } if got := normalized.FunctionFacts.FunctionType(fnSym); got == nil { diff --git a/compiler/check/returns/equal.go b/compiler/check/returns/equal.go index 95c57bb5..fda479b5 100644 --- a/compiler/check/returns/equal.go +++ b/compiler/check/returns/equal.go @@ -3,6 +3,8 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/typ" ) @@ -40,13 +42,13 @@ func FunctionFactsEqual(a, b api.FunctionFacts) bool { if !ok { return false } - if !ReturnTypesEqual(af.Params, bf.Params) { + if !paramevidence.EqualVectors(af.Params, bf.Params) { return false } - if !ReturnTypesEqual(af.Summary, bf.Summary) { + if !returnsummary.Equal(af.Summary, bf.Summary) { return false } - if !ReturnTypesEqual(af.Narrow, bf.Narrow) { + if 
!returnsummary.Equal(af.Narrow, bf.Narrow) { return false } if !typ.TypeEquals(af.Type, bf.Type) { diff --git a/compiler/check/returns/function_facts.go b/compiler/check/returns/function_facts.go index e301a316..11bf1366 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -4,6 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" ) func collectCanonicalFunctionFactSymbols(factSets ...api.FunctionFacts) []cfg.SymbolID { @@ -27,8 +28,8 @@ func markFunctionFactSymbols[T any](dst map[cfg.SymbolID]bool, src map[cfg.Symbo func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { return api.FunctionFact{ Params: paramevidence.FilterEmptyVector(ff.Params), - Summary: canonicalReturnVector(ff.Summary), - Narrow: canonicalReturnVector(ff.Narrow), + Summary: returnsummary.Canonical(ff.Summary), + Narrow: returnsummary.Canonical(ff.Narrow), Type: normalizeInterprocValueType(ff.Type), } } diff --git a/compiler/check/returns/join.go b/compiler/check/returns/join.go index 3ac0664f..f328f071 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/returns/join.go @@ -2,733 +2,13 @@ package returns import ( "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/domain/value" - "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" - typjoin "github.com/wippyai/go-lua/types/typ/join" "github.com/wippyai/go-lua/types/typ/unwrap" ) -// ReturnTypesEqual checks if two return vectors are structurally equal. 
-func ReturnTypesEqual(a, b []typ.Type) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if !typ.TypeEquals(a[i], b[i]) { - return false - } - } - return true -} - -// ReturnTypesAllNil reports whether all slots are explicit nil. -func ReturnTypesAllNil(rets []typ.Type) bool { - if len(rets) == 0 { - return false - } - for _, t := range rets { - if t == nil || t.Kind() != kind.Nil { - return false - } - } - return true -} - -// ReturnTypesRefine reports whether a refines b (element-wise subtype). -func ReturnTypesRefine(a, b []typ.Type) bool { - if len(a) == 0 { - return false - } - if len(b) == 0 { - return true - } - if len(a) != len(b) { - return false - } - for i := range a { - ai := a[i] - bi := b[i] - if ai == nil || bi == nil { - if ai == nil && bi == nil { - continue - } - return false - } - if !subtype.IsSubtype(ai, bi) { - return false - } - } - return true -} - -// ReturnTypesExtendRecord reports whether a extends b by adding record fields. -// This treats record field supersets as refinements for return vectors. -func ReturnTypesExtendRecord(a, b []typ.Type) bool { - if len(a) == 0 || len(b) == 0 { - return false - } - if len(a) != len(b) { - return false - } - for i := range a { - if _, ok := a[i].(*typ.Record); !ok { - return false - } - if !value.ExtendsRecord(a[i], b[i]) { - return false - } - } - return true -} - -// ReturnTypesElideOptional reports whether a refines b by removing nil/optional parts. -func ReturnTypesElideOptional(a, b []typ.Type) bool { - if len(a) == 0 || len(b) == 0 { - return false - } - if len(a) != len(b) { - return false - } - for i := range a { - if !value.ElidesOptional(a[i], b[i]) { - return false - } - } - return true -} - -// SelectPreferredReturnVector picks a canonical winner when one return vector -// is strictly preferable to the other without requiring a join. -// -// Preference order: -// 1. subtype refinement (with nil-only regression protection) -// 2. record extension -// 3. 
optional elision -// -// The nil-only guard prevents a refined-but-empty-looking update from -// regressing an already informative summary to just nil. -func SelectPreferredReturnVector(a, b []typ.Type) ([]typ.Type, bool) { - if ReturnTypesRepairNever(a, b) { - return a, true - } - if ReturnTypesRepairNever(b, a) { - return b, true - } - if ReturnTypesRefineSoftContainers(a, b) { - return a, true - } - if ReturnTypesRefineSoftContainers(b, a) { - return b, true - } - if ReturnTypesStopRecursiveStructuralGrowth(a, b) { - return a, true - } - if ReturnTypesStopRecursiveStructuralGrowth(b, a) { - return b, true - } - if ReturnTypesRefineFalsyMapKeys(a, b) { - return a, true - } - if ReturnTypesRefineFalsyMapKeys(b, a) { - return b, true - } - if ReturnTypesRefine(a, b) { - if ReturnTypesAllNil(a) && !ReturnTypesAllNil(b) { - return b, true - } - if ReturnTypesNestedNilOnlyRegression(a, b) { - return b, true - } - return a, true - } - if ReturnTypesRefine(b, a) { - if ReturnTypesAllNil(b) && !ReturnTypesAllNil(a) { - return a, true - } - if ReturnTypesNestedNilOnlyRegression(b, a) { - return a, true - } - return b, true - } - if ReturnTypesFillNilSlots(a, b) { - return a, true - } - if ReturnTypesFillNilSlots(b, a) { - return b, true - } - if (ReturnTypesExtendRecord(a, b) || ReturnTypesElideOptional(a, b)) && !ReturnTypesNestedNilOnlyRegression(a, b) { - return a, true - } - if (ReturnTypesExtendRecord(b, a) || ReturnTypesElideOptional(b, a)) && !ReturnTypesNestedNilOnlyRegression(b, a) { - return b, true - } - return nil, false -} - -// ReturnTypesRefineSoftContainers reports whether candidate preserves the same -// table shape while replacing soft placeholder element/value evidence with -// concrete evidence. This is a summary-lattice rule only; it does not weaken -// mutable map subtyping. 
-func ReturnTypesRefineSoftContainers(candidate, baseline []typ.Type) bool { - if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { - return false - } - strict := false - for i := range candidate { - refines, changed := value.RefinesSoftContainer(candidate[i], baseline[i]) - if !refines { - return false - } - if changed { - strict = true - } - } - return strict -} - -// ReturnTypesRefineFalsyMapKeys reports whether candidate is the same -// table-derived shape as baseline after removing stale falsy members from -// baseline. This handles fixed-point rounds where an early branch-insensitive -// dynamic index observes a key as `string | false`, then the solved guard proves -// the actual write key is `string`. -func ReturnTypesRefineFalsyMapKeys(candidate, baseline []typ.Type) bool { - if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { - return false - } - strict := false - for i := range candidate { - refines, changed := value.RefinesFalsyMapKey(candidate[i], baseline[i]) - if !refines { - return false - } - if changed { - strict = true - } - } - return strict -} - -// ReturnTypesNestedNilOnlyRegression reports whether candidate's apparent -// refinement only adds nested nil facts over a more useful baseline shape. A -// required `nil` field or `unknown -> nil` field does not help callers, but it -// can make iterative structural facts oscillate with later non-nil evidence. -func ReturnTypesNestedNilOnlyRegression(candidate, baseline []typ.Type) bool { - if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { - return false - } - for i := range candidate { - if value.NestedNilOnlyRegression(candidate[i], baseline[i]) { - return true - } - } - return false -} - -// ReturnTypesStopRecursiveStructuralGrowth reports whether growing embeds the -// same structural container shape as stable beneath its root. 
Recursive table -// builders such as deep-copy helpers otherwise look like ever-more-specific -// refinements: {[string]: any} -> {[string]: {[string]: nil}} -> ... . The -// existing top-level shape is already a sound upper bound, so keep it once the -// candidate starts feeding that shape back through one of its children. -func ReturnTypesStopRecursiveStructuralGrowth(stable, growing []typ.Type) bool { - if len(stable) == 0 || len(growing) == 0 || len(stable) != len(growing) { - return false - } - - strict := false - for i := range stable { - s := stable[i] - g := growing[i] - if s == nil || g == nil { - return false - } - if typ.TypeEquals(s, g) { - continue - } - if typ.IsAbsentOrUnknown(s) || !value.CanSelfEmbed(s) { - return false - } - if !value.ShallowStructuralShapeEquals(g, s) { - return false - } - if !value.ContainsNestedStructuralShape(g, s) { - return false - } - strict = true - } - return strict -} - -// SelectRefiningReturnVector prefers candidate only when it is a directional -// refinement of baseline. It never prefers baseline over candidate. -// -// This is used in iterative channels where an older baseline may be an -// under-constrained artifact; in those cases we must not lock in baseline just -// because it happens to be a subtype of the newer estimate. -func SelectRefiningReturnVector(candidate, baseline []typ.Type) ([]typ.Type, bool) { - if ReturnTypesRefine(candidate, baseline) { - if ReturnTypesAllNil(candidate) && !ReturnTypesAllNil(baseline) { - return baseline, true - } - return candidate, true - } - if ReturnTypesFillNilSlots(candidate, baseline) { - return candidate, true - } - if ReturnTypesExtendRecord(candidate, baseline) || ReturnTypesElideOptional(candidate, baseline) { - return candidate, true - } - return nil, false -} - -// ReturnTypesFillNilSlots reports whether a improves b by replacing nil-only -// slots with concrete return evidence while staying compatible on other slots. 
-func ReturnTypesFillNilSlots(a, b []typ.Type) bool { - if len(a) == 0 || len(b) == 0 || len(a) != len(b) { - return false - } - strict := false - for i := range a { - ai := a[i] - bi := b[i] - if ai == nil || bi == nil { - return false - } - if unwrap.IsNilType(bi) && !unwrap.IsNilType(ai) { - strict = true - continue - } - if typ.TypeEquals(ai, bi) { - continue - } - if subtype.IsSubtype(ai, bi) || value.ExtendsRecord(ai, bi) || value.ElidesOptional(ai, bi) { - continue - } - return false - } - return strict -} - -// ReturnTypesRepairNever reports whether candidate is a runtime-possible repair -// of baseline by replacing nested never artifacts while otherwise widening -// compatibly. This lets post-flow summaries correct pre-flow bottoms such as -// `{data?: never}` -> `{data?: unknown}`. -func ReturnTypesRepairNever(candidate, baseline []typ.Type) bool { - if len(candidate) == 0 || len(baseline) == 0 || len(candidate) != len(baseline) { - return false - } - strict := false - for i := range candidate { - if candidate[i] == nil || baseline[i] == nil { - return false - } - if typ.TypeEquals(candidate[i], baseline[i]) { - continue - } - if !typeRepairsNever(candidate[i], baseline[i]) { - return false - } - strict = true - } - return strict -} - -func typeRepairsNever(candidate, baseline typ.Type) bool { - if candidate == nil || baseline == nil { - return false - } - if !typeContainsNever(baseline) || typeContainsNever(candidate) { - return false - } - ok, strict := typeNeverRepairRelation(candidate, baseline) - return ok && strict -} - -func typeNeverRepairRelation(candidate, baseline typ.Type) (bool, bool) { - if candidate == nil || baseline == nil { - return false, false - } - if typ.TypeEquals(candidate, baseline) { - return true, false - } - - candidate = unwrap.Alias(candidate) - baseline = unwrap.Alias(baseline) - if candidate == nil || baseline == nil { - return false, false - } - - if typ.IsNever(baseline) { - return !typ.IsNever(candidate), 
!typ.IsNever(candidate) - } - if !typeContainsNever(baseline) { - return false, false - } - - switch b := baseline.(type) { - case *typ.Optional: - c, ok := candidate.(*typ.Optional) - if !ok { - return false, false - } - return typeNeverRepairRelation(c.Inner, b.Inner) - case *typ.Union: - c, ok := candidate.(*typ.Union) - if !ok || len(c.Members) != len(b.Members) { - return false, false - } - used := make([]bool, len(c.Members)) - strict := false - for _, bm := range b.Members { - matched := false - for j, cm := range c.Members { - if used[j] || !typ.TypeEquals(cm, bm) { - continue - } - used[j] = true - matched = true - break - } - if matched { - continue - } - for j, cm := range c.Members { - if used[j] { - continue - } - ok, repaired := typeNeverRepairRelation(cm, bm) - if !ok { - continue - } - used[j] = true - matched = true - if repaired { - strict = true - } - break - } - if !matched { - return false, false - } - } - return true, strict - case *typ.Record: - c, ok := candidate.(*typ.Record) - if !ok || c.Open != b.Open || c.HasMapComponent() != b.HasMapComponent() || len(c.Fields) != len(b.Fields) { - return false, false - } - strict := false - for _, bf := range b.Fields { - cf := c.GetField(bf.Name) - if cf == nil || cf.Optional != bf.Optional || cf.Readonly != bf.Readonly { - return false, false - } - ok, repaired := typeNeverRepairRelation(cf.Type, bf.Type) - if !ok { - return false, false - } - if repaired { - strict = true - } - } - if b.HasMapComponent() { - ok, repaired := typeNeverRepairRelation(c.MapKey, b.MapKey) - if !ok { - return false, false - } - if repaired { - strict = true - } - ok, repaired = typeNeverRepairRelation(c.MapValue, b.MapValue) - if !ok { - return false, false - } - if repaired { - strict = true - } - } - if b.Metatable != nil || c.Metatable != nil { - if b.Metatable == nil || c.Metatable == nil { - return false, false - } - ok, repaired := typeNeverRepairRelation(c.Metatable, b.Metatable) - if !ok { - return false, false - 
} - if repaired { - strict = true - } - } - return true, strict - case *typ.Array: - c, ok := candidate.(*typ.Array) - if !ok { - return false, false - } - return typeNeverRepairRelation(c.Element, b.Element) - case *typ.Map: - c, ok := candidate.(*typ.Map) - if !ok { - return false, false - } - keyOK, keyStrict := typeNeverRepairRelation(c.Key, b.Key) - if !keyOK { - return false, false - } - valOK, valStrict := typeNeverRepairRelation(c.Value, b.Value) - if !valOK { - return false, false - } - return true, keyStrict || valStrict - case *typ.Tuple: - c, ok := candidate.(*typ.Tuple) - if !ok || len(c.Elements) != len(b.Elements) { - return false, false - } - strict := false - for i := range b.Elements { - ok, repaired := typeNeverRepairRelation(c.Elements[i], b.Elements[i]) - if !ok { - return false, false - } - if repaired { - strict = true - } - } - return true, strict - case *typ.Function: - c, ok := candidate.(*typ.Function) - if !ok || !sameFunctionShapeForFactMerge(c, b) || len(c.Returns) != len(b.Returns) { - return false, false - } - for i := range b.Params { - if c.Params[i].Name != b.Params[i].Name || - c.Params[i].Optional != b.Params[i].Optional || - !typ.TypeEquals(c.Params[i].Type, b.Params[i].Type) { - return false, false - } - } - switch { - case (c.Variadic == nil) != (b.Variadic == nil): - return false, false - case c.Variadic != nil && !typ.TypeEquals(c.Variadic, b.Variadic): - return false, false - } - strict := false - for i := range b.Returns { - ok, repaired := typeNeverRepairRelation(c.Returns[i], b.Returns[i]) - if !ok { - return false, false - } - if repaired { - strict = true - } - } - return true, strict - default: - return false, false - } -} - -func typeContainsNever(t typ.Type) bool { - seen := make(map[typ.Type]bool) - return typeContainsNeverMemo(t, seen) -} - -func typeContainsNeverMemo(t typ.Type, seen map[typ.Type]bool) bool { - if t == nil { - return false - } - if seen[t] { - return false - } - seen[t] = true - t = 
unwrap.Alias(t) - if t == nil { - return false - } - if typ.IsNever(t) { - return true - } - return typ.Visit(t, typ.Visitor[bool]{ - Optional: func(o *typ.Optional) bool { - return typeContainsNeverMemo(o.Inner, seen) - }, - Union: func(u *typ.Union) bool { - for _, m := range u.Members { - if typeContainsNeverMemo(m, seen) { - return true - } - } - return false - }, - Intersection: func(in *typ.Intersection) bool { - for _, m := range in.Members { - if typeContainsNeverMemo(m, seen) { - return true - } - } - return false - }, - Tuple: func(tup *typ.Tuple) bool { - for _, e := range tup.Elements { - if typeContainsNeverMemo(e, seen) { - return true - } - } - return false - }, - Array: func(a *typ.Array) bool { - return typeContainsNeverMemo(a.Element, seen) - }, - Map: func(m *typ.Map) bool { - return typeContainsNeverMemo(m.Key, seen) || typeContainsNeverMemo(m.Value, seen) - }, - Record: func(r *typ.Record) bool { - for _, f := range r.Fields { - if typeContainsNeverMemo(f.Type, seen) { - return true - } - } - if r.HasMapComponent() { - return typeContainsNeverMemo(r.MapKey, seen) || typeContainsNeverMemo(r.MapValue, seen) - } - return false - }, - Function: func(fn *typ.Function) bool { - for _, p := range fn.Params { - if typeContainsNeverMemo(p.Type, seen) { - return true - } - } - if fn.Variadic != nil && typeContainsNeverMemo(fn.Variadic, seen) { - return true - } - for _, ret := range fn.Returns { - if typeContainsNeverMemo(ret, seen) { - return true - } - } - return false - }, - Default: func(typ.Type) bool { - return false - }, - }) -} - -// NormalizeReturnVector replaces nil slots with explicit nil types. -func NormalizeReturnVector(rets []typ.Type) []typ.Type { - if len(rets) == 0 { - return nil - } - out := make([]typ.Type, len(rets)) - copy(out, rets) - return NormalizeReturnVectorInPlace(out) -} - -// NormalizeReturnVectorInPlace replaces nil slots with explicit nil types in an -// owned return vector. 
-func NormalizeReturnVectorInPlace(rets []typ.Type) []typ.Type { - if len(rets) == 0 { - return nil - } - for i, t := range rets { - if t == nil { - rets[i] = typ.Nil - } - } - return rets -} - -func canonicalReturnVector(rets []typ.Type) []typ.Type { - if len(rets) == 0 { - return nil - } - for i, t := range rets { - if t != nil { - continue - } - out := make([]typ.Type, len(rets)) - copy(out, rets) - out[i] = typ.Nil - for j := i + 1; j < len(out); j++ { - if out[j] == nil { - out[j] = typ.Nil - } - } - return out - } - return rets -} - -func normalizeAndPruneReturnVector(rets []typ.Type) []typ.Type { - if len(rets) == 0 { - return nil - } - var out []typ.Type - for i, ret := range rets { - normalized := ret - if normalized == nil { - normalized = typ.Nil - } - pruned := typ.PruneSoftUnionMembers(normalized) - if out != nil { - out[i] = pruned - continue - } - if pruned == ret { - continue - } - out = make([]typ.Type, len(rets)) - copy(out, rets[:i]) - out[i] = pruned - } - if out != nil { - return out - } - return rets -} - -// MergeReturnSummary applies the canonical return-summary merge policy shared by -// all iterative channels (SCC return inference, interproc fact widening, and -// summary-to-signature alignment). Centralizing this logic prevents divergent -// local merge behavior across phases. -func MergeReturnSummary(existing, candidate []typ.Type) []typ.Type { - existing = normalizeAndPruneReturnVector(existing) - candidate = normalizeAndPruneReturnVector(candidate) - if len(existing) == 0 { - return candidate - } - if len(candidate) == 0 { - return existing - } - // Canonical promotion: open-top record placeholders should not dominate - // concrete structured return evidence (array/map/record with fields). 
- if replaced, ok := replaceOpenTopWithStructured(existing, candidate); ok { - existing = normalizeAndPruneReturnVector(replaced) - } - if ReturnTypesRepairNever(existing, candidate) { - return existing - } - if ReturnTypesRepairNever(candidate, existing) { - return candidate - } - - // Higher-order summaries are merged monotonically for fixpoint stability. - if shouldUseMonotoneReturnJoin(existing, candidate) { - return normalizeAndPruneReturnVector(joinReturnVectorsMonotone(existing, candidate)) - } - - if preferred, ok := SelectPreferredReturnVector(existing, candidate); ok { - return normalizeAndPruneReturnVector(preferred) - } - - return normalizeAndPruneReturnVector(typjoin.ReturnVectors(existing, candidate)) -} - // MergeFunctionFactType merges function-type facts through one canonical policy. // This ensures all channels agree on when to preserve shape and how to merge // returns, avoiding directional one-off behavior in individual phases. @@ -873,7 +153,7 @@ func mergeFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { builder = builder.Variadic(mergeFunctionParamFactType(existing.Variadic, candidate.Variadic)) } - if mergedReturns := MergeReturnSummary(existing.Returns, candidate.Returns); len(mergedReturns) > 0 { + if mergedReturns := returnsummary.Merge(existing.Returns, candidate.Returns); len(mergedReturns) > 0 { builder = builder.Returns(mergedReturns...) } @@ -981,81 +261,3 @@ func preferStructuredRecordParam(existing, candidate typ.Type) (typ.Type, bool) } return nil, false } - -// AlignFunctionTypeWithSummary applies the canonical return-summary winner to a -// function type. It updates function returns only when the summary is the -// preferred vector under SelectPreferredReturnVector (or when function returns -// are missing). Returns the aligned function and whether it changed. 
-func AlignFunctionTypeWithSummary(fn *typ.Function, summary []typ.Type) (*typ.Function, bool) { - if fn == nil { - return nil, false - } - - normalizedSummary := normalizeAndPruneReturnVector(summary) - if len(normalizedSummary) == 0 { - return fn, false - } - - current := normalizeAndPruneReturnVector(fn.Returns) - if len(current) == 0 { - aligned := typjoin.WithReturns(fn, normalizedSummary) - return aligned, aligned != nil - } - // Keep one canonical merge path for summary-to-signature alignment. - // MergeReturnSummary already handles structured promotion and refinement - // policy, so AlignFunctionTypeWithSummary should not duplicate local logic. - merged := MergeReturnSummary(current, normalizedSummary) - if ReturnTypesEqual(current, merged) { - return fn, false - } - - aligned := typjoin.WithReturns(fn, merged) - if aligned == nil { - return fn, false - } - return aligned, true -} - -func replaceOpenTopWithStructured(current, summary []typ.Type) ([]typ.Type, bool) { - if len(current) == 0 || len(summary) == 0 || len(current) != len(summary) { - return nil, false - } - out := append([]typ.Type(nil), current...) - changed := false - for i := range out { - if !value.IsOpenTopRecord(out[i]) { - continue - } - if !value.IsStructuredTableShape(summary[i]) { - continue - } - out[i] = summary[i] - changed = true - } - if !changed { - return nil, false - } - return out, true -} - -// WithSummaryOrUnknown applies summary-derived returns to a function signature. -// If summary is empty and the signature has no returns, a single unknown return -// is attached to preserve call-site conservatism. 
-func WithSummaryOrUnknown(fn *typ.Function, summary []typ.Type) *typ.Function { - if fn == nil { - return nil - } - if len(summary) == 0 { - if len(fn.Returns) > 0 { - return fn - } - return typjoin.WithReturns(fn, []typ.Type{typ.Unknown}) - } - if aligned, changed := AlignFunctionTypeWithSummary(fn, summary); changed { - return aligned - } - if len(fn.Returns) > 0 { - return fn - } - return typjoin.WithReturns(fn, normalizeAndPruneReturnVector(summary)) -} diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index 5d61ea70..f7d59312 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -3,6 +3,7 @@ package returns import ( "testing" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -54,14 +55,14 @@ func TestTypJoinReturnSlot_PreservesUnknownOverNil(t *testing.T) { } } -func TestReturnTypesAllNil(t *testing.T) { - if !ReturnTypesAllNil([]typ.Type{typ.Nil}) { +func TestReturnSummaryAllNil(t *testing.T) { + if !returnsummary.AllNil([]typ.Type{typ.Nil}) { t.Fatal("expected [nil] to be nil-only") } - if ReturnTypesAllNil([]typ.Type{typ.Nil, typ.Unknown}) { + if returnsummary.AllNil([]typ.Type{typ.Nil, typ.Unknown}) { t.Fatal("expected [nil, unknown] to not be nil-only") } - if ReturnTypesAllNil(nil) { + if returnsummary.AllNil(nil) { t.Fatal("expected empty return vector to not be nil-only") } } @@ -75,108 +76,108 @@ func TestJoinReturnVectors_DifferentLengths(t *testing.T) { } } -func TestReturnTypesEqual_Empty(t *testing.T) { - if !ReturnTypesEqual(nil, nil) { +func TestReturnSummaryEqual_Empty(t *testing.T) { + if !returnsummary.Equal(nil, nil) { t.Error("nil slices should be equal") } } -func TestReturnTypesEqual_DifferentLength(t *testing.T) { +func TestReturnSummaryEqual_DifferentLength(t *testing.T) { a := []typ.Type{typ.String} b := 
[]typ.Type{typ.String, typ.Number} - if ReturnTypesEqual(a, b) { + if returnsummary.Equal(a, b) { t.Error("different lengths should not be equal") } } -func TestReturnTypesEqual_Same(t *testing.T) { +func TestReturnSummaryEqual_Same(t *testing.T) { a := []typ.Type{typ.String, typ.Number} b := []typ.Type{typ.String, typ.Number} - if !ReturnTypesEqual(a, b) { + if !returnsummary.Equal(a, b) { t.Error("same types should be equal") } } -func TestReturnTypesEqual_Different(t *testing.T) { +func TestReturnSummaryEqual_Different(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.Number} - if ReturnTypesEqual(a, b) { + if returnsummary.Equal(a, b) { t.Error("different types should not be equal") } } -func TestReturnTypesRefine_EmptyA(t *testing.T) { +func TestReturnSummaryRefines_EmptyA(t *testing.T) { b := []typ.Type{typ.String} - if ReturnTypesRefine(nil, b) { + if returnsummary.Refines(nil, b) { t.Error("empty a should not refine b") } } -func TestReturnTypesRefine_EmptyB(t *testing.T) { +func TestReturnSummaryRefines_EmptyB(t *testing.T) { a := []typ.Type{typ.String} - if !ReturnTypesRefine(a, nil) { + if !returnsummary.Refines(a, nil) { t.Error("a should refine empty b") } } -func TestReturnTypesRefine_Same(t *testing.T) { +func TestReturnSummaryRefines_Same(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.String} - if !ReturnTypesRefine(a, b) { + if !returnsummary.Refines(a, b) { t.Error("same types should refine") } } -func TestReturnTypesRefine_DifferentLength(t *testing.T) { +func TestReturnSummaryRefines_DifferentLength(t *testing.T) { a := []typ.Type{typ.String, typ.Number} b := []typ.Type{typ.String} - if ReturnTypesRefine(a, b) { + if returnsummary.Refines(a, b) { t.Error("different lengths should not refine") } } -func TestMergeReturnSummary_ReplacesStaleFalsyKeyArrayElement(t *testing.T) { +func TestReturnSummaryMerge_ReplacesStaleFalsyKeyArrayElement(t *testing.T) { stale := []typ.Type{typ.NewArray(typ.NewUnion(typ.Boolean, 
typ.String))} current := []typ.Type{typ.NewArray(typ.String)} - got := MergeReturnSummary(stale, current) - if !ReturnTypesEqual(got, current) { + got := returnsummary.Merge(stale, current) + if !returnsummary.Equal(got, current) { t.Fatalf("expected truthy-refined key array %v, got %v", current, got) } } -func TestReturnTypesExtendRecord_Empty(t *testing.T) { - if ReturnTypesExtendRecord(nil, nil) { +func TestReturnSummaryExtendsRecord_Empty(t *testing.T) { + if returnsummary.ExtendsRecord(nil, nil) { t.Error("empty vectors should not extend") } } -func TestReturnTypesExtendRecord_NotRecords(t *testing.T) { +func TestReturnSummaryExtendsRecord_NotRecords(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.String} - if ReturnTypesExtendRecord(a, b) { + if returnsummary.ExtendsRecord(a, b) { t.Error("non-records should not extend") } } -func TestReturnTypesExtendRecord_RecordExtends(t *testing.T) { +func TestReturnSummaryExtendsRecord_RecordExtends(t *testing.T) { oldRec := typ.NewRecord().Field("x", typ.Number).Build() newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !ReturnTypesExtendRecord(a, b) { + if !returnsummary.ExtendsRecord(a, b) { t.Error("record with more fields should extend") } } -func TestReturnTypesElideOptional_Empty(t *testing.T) { - if ReturnTypesElideOptional(nil, nil) { +func TestReturnSummaryElidesOptional_Empty(t *testing.T) { + if returnsummary.ElidesOptional(nil, nil) { t.Error("empty vectors should not elide") } } -func TestSelectPreferredReturnVector_Refinement(t *testing.T) { - preferred, ok := SelectPreferredReturnVector([]typ.Type{typ.String}, []typ.Type{typ.NewOptional(typ.String)}) +func TestReturnSummarySelectPreferred_Refinement(t *testing.T) { + preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.String}, []typ.Type{typ.NewOptional(typ.String)}) if !ok { t.Fatal("expected preferred vector") } @@ -185,8 +186,8 @@ func 
TestSelectPreferredReturnVector_Refinement(t *testing.T) { } } -func TestSelectPreferredReturnVector_AvoidsNilOnlyRegression(t *testing.T) { - preferred, ok := SelectPreferredReturnVector([]typ.Type{typ.Nil}, []typ.Type{typ.NewOptional(typ.String)}) +func TestReturnSummarySelectPreferred_AvoidsNilOnlyRegression(t *testing.T) { + preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.Nil}, []typ.Type{typ.NewOptional(typ.String)}) if !ok { t.Fatal("expected preferred vector") } @@ -195,8 +196,8 @@ func TestSelectPreferredReturnVector_AvoidsNilOnlyRegression(t *testing.T) { } } -func TestSelectPreferredReturnVector_RejectsStaleNilOnly(t *testing.T) { - preferred, ok := SelectPreferredReturnVector([]typ.Type{typ.NewOptional(typ.String)}, []typ.Type{typ.Nil}) +func TestReturnSummarySelectPreferred_RejectsStaleNilOnly(t *testing.T) { + preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.NewOptional(typ.String)}, []typ.Type{typ.Nil}) if !ok { t.Fatal("expected preferred vector") } @@ -205,11 +206,11 @@ func TestSelectPreferredReturnVector_RejectsStaleNilOnly(t *testing.T) { } } -func TestSelectPreferredReturnVector_RecordExtension(t *testing.T) { +func TestReturnSummarySelectPreferred_RecordExtension(t *testing.T) { oldRec := typ.NewRecord().Field("x", typ.Number).Build() newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.String).Build() - preferred, ok := SelectPreferredReturnVector([]typ.Type{newRec}, []typ.Type{oldRec}) + preferred, ok := returnsummary.SelectPreferred([]typ.Type{newRec}, []typ.Type{oldRec}) if !ok { t.Fatal("expected preferred vector") } @@ -218,11 +219,11 @@ func TestSelectPreferredReturnVector_RecordExtension(t *testing.T) { } } -func TestSelectRefiningReturnVector_Refinement(t *testing.T) { +func TestReturnSummarySelectRefining_Refinement(t *testing.T) { refined := []typ.Type{typ.String} baseline := []typ.Type{typ.NewOptional(typ.String)} - got, ok := SelectRefiningReturnVector(refined, baseline) + got, ok := 
returnsummary.SelectRefining(refined, baseline) if !ok { t.Fatal("expected refinement to be selected") } @@ -231,35 +232,35 @@ func TestSelectRefiningReturnVector_Refinement(t *testing.T) { } } -func TestSelectRefiningReturnVector_DoesNotSelectOlderNarrowerBaseline(t *testing.T) { +func TestReturnSummarySelectRefining_DoesNotSelectOlderNarrowerBaseline(t *testing.T) { candidate := []typ.Type{typ.Any} baseline := []typ.Type{typ.False} - _, ok := SelectRefiningReturnVector(candidate, baseline) + _, ok := returnsummary.SelectRefining(candidate, baseline) if ok { t.Fatal("did not expect baseline-narrower relation to select candidate") } } -func TestReturnTypesFillNilSlots(t *testing.T) { +func TestReturnSummaryFillsNilSlots(t *testing.T) { candidate := []typ.Type{typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown)), typ.NewArray(typ.Unknown)} baseline := []typ.Type{typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown)), typ.Nil} - if !ReturnTypesFillNilSlots(candidate, baseline) { + if !returnsummary.FillsNilSlots(candidate, baseline) { t.Fatalf("expected candidate to fill nil slot: candidate=%v baseline=%v", candidate, baseline) } } -func TestMergeReturnSummary_PrefersCandidateRefinement(t *testing.T) { +func TestReturnSummaryMerge_PrefersCandidateRefinement(t *testing.T) { existing := []typ.Type{typ.NewOptional(typ.String)} candidate := []typ.Type{typ.String} - merged := MergeReturnSummary(existing, candidate) + merged := returnsummary.Merge(existing, candidate) if len(merged) != 1 || !typ.TypeEquals(merged[0], typ.String) { t.Fatalf("expected refined candidate return, got %v", merged) } } -func TestMergeReturnSummary_FillsNilSlotWithCandidateEvidence(t *testing.T) { +func TestReturnSummaryMerge_FillsNilSlotWithCandidateEvidence(t *testing.T) { existing := []typ.Type{ typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown)), typ.Nil, @@ -269,7 +270,7 @@ func TestMergeReturnSummary_FillsNilSlotWithCandidateEvidence(t *testing.T) { typ.NewArray(typ.Unknown), } - merged := 
MergeReturnSummary(existing, candidate) + merged := returnsummary.Merge(existing, candidate) if len(merged) != 2 { t.Fatalf("expected two return slots, got %v", merged) } @@ -372,28 +373,28 @@ func TestMergeFunctionFactType_KeepsBaselineOverNestedNilOnlyRegression(t *testi } } -func TestMergeReturnSummary_PrefersCurrentTruthyMapKeyRefinement(t *testing.T) { +func TestReturnSummaryMerge_PrefersCurrentTruthyMapKeyRefinement(t *testing.T) { baseline := typ.NewMap(typ.NewUnion(typ.String, typ.False), typ.Unknown) candidate := typ.NewMap(typ.String, typ.Unknown) - merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected stale falsy map key to refine to %v, got %v", candidate, merged) } } -func TestMergeReturnSummary_PrefersConcreteMapValueOverSoftPlaceholder(t *testing.T) { +func TestReturnSummaryMerge_PrefersConcreteMapValueOverSoftPlaceholder(t *testing.T) { entry := typ.NewRecord().Field("id", typ.String).Build() baseline := typ.NewMap(typ.String, typ.NewArray(typ.Any)) candidate := typ.NewMap(typ.String, typ.NewArray(entry)) - merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected concrete map value evidence %v, got %v", candidate, merged) } } -func TestMergeReturnSummary_PrefersCurrentTruthyRecordMapKeyRefinement(t *testing.T) { +func TestReturnSummaryMerge_PrefersCurrentTruthyRecordMapKeyRefinement(t *testing.T) { entryArray := typ.NewArray(typ.Unknown) baseline := typ.NewRecord(). MapComponent(typ.NewUnion(typ.Nil, typ.String, typ.False), entryArray). @@ -403,13 +404,13 @@ func TestMergeReturnSummary_PrefersCurrentTruthyRecordMapKeyRefinement(t *testin MapComponent(typ.String, entryArray). 
Build() - merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected stale falsy record map key to refine to %v, got %v", candidate, merged) } } -func TestMergeReturnSummary_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *testing.T) { +func TestReturnSummaryMerge_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *testing.T) { entryArray := typ.NewArray(typ.Unknown) baseline := typ.NewRecord(). MapComponent(typ.NewUnion(typ.Nil, typ.String, typ.False), entryArray). @@ -417,7 +418,7 @@ func TestMergeReturnSummary_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *tes Build() candidate := typ.NewMap(typ.String, entryArray) - merged := MergeReturnSummary([]typ.Type{baseline}, []typ.Type{candidate}) + merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected map to replace stale open record map %v, got %v", candidate, merged) } @@ -589,7 +590,7 @@ func TestMergeFunctionFactType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *t } } -func TestMergeReturnSummary_ElidesOptionalForInterfaceFieldRecords(t *testing.T) { +func TestReturnSummaryMerge_ElidesOptionalForInterfaceFieldRecords(t *testing.T) { txType := typ.NewInterface("sql.Tx", []typ.Method{ {Name: "rollback", Type: typ.Func().Param("self", typ.Self).Build()}, }) @@ -611,20 +612,20 @@ func TestMergeReturnSummary_ElidesOptionalForInterfaceFieldRecords(t *testing.T) Build(), } - merged := MergeReturnSummary(existing, candidate) + merged := returnsummary.Merge(existing, candidate) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate[0]) { t.Fatalf("expected candidate optional-elision to win, got %v", merged) } } -func TestWithSummaryOrUnknown_AppliesSummaryToPlaceholderReturns(t *testing.T) { +func 
TestReturnSummaryApplyToFunctionType_AppliesSummaryToPlaceholderReturns(t *testing.T) { fn := typ.Func(). Param("x", typ.String). Returns(typ.Unknown). Build() summary := []typ.Type{typ.Integer} - got := WithSummaryOrUnknown(fn, summary) + got := returnsummary.ApplyToFunctionType(fn, summary) if got == nil || len(got.Returns) != 1 { t.Fatalf("expected function return, got %v", got) } @@ -633,9 +634,9 @@ func TestWithSummaryOrUnknown_AppliesSummaryToPlaceholderReturns(t *testing.T) { } } -func TestWithSummaryOrUnknown_DefaultsToUnknownWhenMissing(t *testing.T) { +func TestReturnSummaryApplyToFunctionType_DefaultsToUnknownWhenMissing(t *testing.T) { fn := typ.Func().Param("x", typ.String).Build() - got := WithSummaryOrUnknown(fn, nil) + got := returnsummary.ApplyToFunctionType(fn, nil) if got == nil || len(got.Returns) != 1 { t.Fatalf("expected one default return, got %v", got) } @@ -644,16 +645,16 @@ func TestWithSummaryOrUnknown_DefaultsToUnknownWhenMissing(t *testing.T) { } } -func TestNormalizeReturnVector_Empty(t *testing.T) { - result := NormalizeReturnVector(nil) +func TestReturnSummaryNormalize_Empty(t *testing.T) { + result := returnsummary.Normalize(nil) if result != nil { t.Errorf("expected nil, got %v", result) } } -func TestNormalizeReturnVector_ReplacesNil(t *testing.T) { +func TestReturnSummaryNormalize_ReplacesNil(t *testing.T) { input := []typ.Type{typ.String, nil, typ.Number} - result := NormalizeReturnVector(input) + result := returnsummary.Normalize(input) if len(result) != 3 { t.Fatalf("expected length 3, got %d", len(result)) } @@ -675,7 +676,7 @@ func TestRecordSuperset_BothHaveMapComponent(t *testing.T) { newRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Field("x", typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !ReturnTypesExtendRecord(a, b) { + if !returnsummary.ExtendsRecord(a, b) { t.Error("record with same map component and additional fields should extend") } } @@ -685,12 +686,12 @@ func 
TestRecordSuperset_OldHasNoMapComponent(t *testing.T) { newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.String).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !ReturnTypesExtendRecord(a, b) { + if !returnsummary.ExtendsRecord(a, b) { t.Error("record with additional fields should extend record without map component") } } -func TestAlignFunctionTypeWithSummary_AppliesStrictRefinement(t *testing.T) { +func TestReturnSummaryAlignFunction_AppliesStrictRefinement(t *testing.T) { fn := typ.Func(). Param("entries", typ.Any). Returns(typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown))). @@ -703,7 +704,7 @@ func TestAlignFunctionTypeWithSummary_AppliesStrictRefinement(t *testing.T) { Build(), } - aligned, changed := AlignFunctionTypeWithSummary(fn, summary) + aligned, changed := returnsummary.AlignFunction(fn, summary) if !changed { t.Fatal("expected alignment to apply strict refinement summary") } @@ -720,12 +721,12 @@ func TestAlignFunctionTypeWithSummary_AppliesStrictRefinement(t *testing.T) { } } -func TestAlignFunctionTypeWithSummary_ReplacesOpenTopRecordWithStructuredSummary(t *testing.T) { +func TestReturnSummaryAlignFunction_ReplacesOpenTopRecordWithStructuredSummary(t *testing.T) { openTop := typ.NewRecord().SetOpen(true).Build() fn := typ.Func().Returns(openTop).Build() summary := []typ.Type{typ.NewArray(typ.Unknown)} - aligned, changed := AlignFunctionTypeWithSummary(fn, summary) + aligned, changed := returnsummary.AlignFunction(fn, summary) if !changed { t.Fatal("expected open-top placeholder to be replaced by structured summary") } @@ -737,12 +738,12 @@ func TestAlignFunctionTypeWithSummary_ReplacesOpenTopRecordWithStructuredSummary } } -func TestAlignFunctionTypeWithSummary_DoesNotDowngradeStructuredToPlaceholder(t *testing.T) { +func TestReturnSummaryAlignFunction_DoesNotDowngradeStructuredToPlaceholder(t *testing.T) { structured := typ.NewRecord().Field("get_x", typ.Func().Build()).Build() fn := 
typ.Func().Returns(structured).Build() summary := []typ.Type{typ.Any} - aligned, changed := AlignFunctionTypeWithSummary(fn, summary) + aligned, changed := returnsummary.AlignFunction(fn, summary) if changed { t.Fatalf("expected no downgrade change, got %v", aligned) } @@ -754,7 +755,7 @@ func TestAlignFunctionTypeWithSummary_DoesNotDowngradeStructuredToPlaceholder(t } } -func TestMergeReturnSummary_PrefersRuntimePossibleSummaryOverNeverArtifact(t *testing.T) { +func TestReturnSummaryMerge_PrefersRuntimePossibleSummaryOverNeverArtifact(t *testing.T) { bad := []typ.Type{ typ.NewUnion( typ.NewRecord(). @@ -780,13 +781,13 @@ func TestMergeReturnSummary_PrefersRuntimePossibleSummaryOverNeverArtifact(t *te ), } - got := MergeReturnSummary(bad, good) - if !ReturnTypesEqual(got, good) { - t.Fatalf("MergeReturnSummary(%v, %v) = %v, want %v", bad, good, got, good) + got := returnsummary.Merge(bad, good) + if !returnsummary.Equal(got, good) { + t.Fatalf("returnsummary.Merge(%v, %v) = %v, want %v", bad, good, got, good) } } -func TestAlignFunctionTypeWithSummary_RepairsNestedNeverArtifact(t *testing.T) { +func TestReturnSummaryAlignFunction_RepairsNestedNeverArtifact(t *testing.T) { bad := typ.NewUnion( typ.NewRecord(). Field("success", typ.True). 
@@ -809,7 +810,7 @@ func TestAlignFunctionTypeWithSummary_RepairsNestedNeverArtifact(t *testing.T) { ) fn := typ.Func().Returns(bad).Build() - aligned, changed := AlignFunctionTypeWithSummary(fn, []typ.Type{good}) + aligned, changed := returnsummary.AlignFunction(fn, []typ.Type{good}) if !changed { t.Fatal("expected never-artifact repair to update function returns") } @@ -823,7 +824,7 @@ func TestRecordSuperset_NewHasMapComponentOldDoesNot(t *testing.T) { newRec := typ.NewRecord().Field("x", typ.Number).MapComponent(typ.String, typ.Any).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !ReturnTypesExtendRecord(a, b) { + if !returnsummary.ExtendsRecord(a, b) { t.Error("record with additional map component should extend record without it") } } @@ -833,12 +834,12 @@ func TestRecordSuperset_IncompatibleMapComponent(t *testing.T) { newRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if ReturnTypesExtendRecord(a, b) { + if returnsummary.ExtendsRecord(a, b) { t.Error("record with incompatible map component should not extend") } } -func TestMergeReturnSummary_PrefersStructuredCollectionOverOpenTopRecordField(t *testing.T) { +func TestReturnSummaryMerge_PrefersStructuredCollectionOverOpenTopRecordField(t *testing.T) { weak := []typ.Type{ typ.NewRecord(). Field("messages", typ.NewRecord().SetOpen(true).Build()). 
@@ -854,7 +855,7 @@ func TestMergeReturnSummary_PrefersStructuredCollectionOverOpenTopRecordField(t Build(), } - merged := MergeReturnSummary(weak, strong) + merged := returnsummary.Merge(weak, strong) if len(merged) != 1 { t.Fatalf("expected one return slot, got %d", len(merged)) } @@ -872,7 +873,7 @@ func TestMergeReturnSummary_PrefersStructuredCollectionOverOpenTopRecordField(t } } -func TestMergeReturnSummary_PromotesTopLevelStructuredOverOpenTop(t *testing.T) { +func TestReturnSummaryMerge_PromotesTopLevelStructuredOverOpenTop(t *testing.T) { weak := []typ.Type{ typ.NewRecord().SetOpen(true).Build(), } @@ -880,7 +881,7 @@ func TestMergeReturnSummary_PromotesTopLevelStructuredOverOpenTop(t *testing.T) typ.NewArray(typ.Any), } - merged := MergeReturnSummary(weak, strong) + merged := returnsummary.Merge(weak, strong) if len(merged) != 1 { t.Fatalf("expected one return slot, got %d", len(merged)) } diff --git a/compiler/check/returns/kernel.go b/compiler/check/returns/kernel.go index 48cb4f25..19d9abc4 100644 --- a/compiler/check/returns/kernel.go +++ b/compiler/check/returns/kernel.go @@ -3,6 +3,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -17,36 +18,36 @@ func JoinFunctionFact(existing, candidate api.FunctionFact) api.FunctionFact { out.Params = paramevidence.JoinVectors(out.Params, candidate.Params) } if len(candidate.Summary) > 0 { - out.Summary = MergeReturnSummary(out.Summary, candidate.Summary) + out.Summary = returnsummary.Merge(out.Summary, candidate.Summary) } if len(candidate.Narrow) > 0 { - out.Narrow = MergeReturnSummary(out.Narrow, candidate.Narrow) + out.Narrow = returnsummary.Merge(out.Narrow, candidate.Narrow) } if candidate.Type != nil { out.Type = MergeFunctionFactType(out.Type, candidate.Type) } // Keep summary and post-flow narrow 
results mutually refining when narrow - // provides first-order information. MergeReturnSummary is the canonical + // provides first-order information. returnsummary.Merge is the canonical // policy and already encodes directional refinement preference. if len(out.Narrow) > 0 { if len(out.Summary) == 0 { - out.Summary = canonicalReturnVector(out.Narrow) + out.Summary = returnsummary.Canonical(out.Narrow) } else { - out.Summary = MergeReturnSummary(out.Summary, out.Narrow) + out.Summary = returnsummary.Merge(out.Summary, out.Narrow) } } if fn := unwrap.Function(out.Type); fn != nil { alignedSummary := out.Summary if len(alignedSummary) > 0 { - if aligned, changed := AlignFunctionTypeWithSummary(fn, alignedSummary); changed { + if aligned, changed := returnsummary.AlignFunction(fn, alignedSummary); changed { out.Type = aligned fn = aligned } } if len(out.Summary) == 0 && fn != nil && len(fn.Returns) > 0 { - out.Summary = canonicalReturnVector(fn.Returns) + out.Summary = returnsummary.Canonical(fn.Returns) } } diff --git a/compiler/check/returns/kernel_test.go b/compiler/check/returns/kernel_test.go index 0364c5ca..f6745c83 100644 --- a/compiler/check/returns/kernel_test.go +++ b/compiler/check/returns/kernel_test.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/typ" ) @@ -18,10 +19,10 @@ func TestJoinFunctionFact_InitialObservation(t *testing.T) { Type: fn, })}} - if got := facts.FunctionFacts.Summary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.Summary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("summary mismatch: got %v", got) } - if got := facts.FunctionFacts.NarrowSummary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.NarrowSummary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { 
t.Fatalf("narrow mismatch: got %v", got) } if got := facts.FunctionFacts.FunctionType(sym); !typ.TypeEquals(got, fn) { @@ -44,10 +45,10 @@ func TestJoinFunctionFact_MergesExistingAndCandidate(t *testing.T) { } got := JoinFunctionFact(existing, candidate) - if !ReturnTypesEqual(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + if !returnsummary.Equal(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { t.Fatalf("summary mismatch: got %v", got.Summary) } - if !ReturnTypesEqual(got.Narrow, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + if !returnsummary.Equal(got.Narrow, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { t.Fatalf("narrow mismatch: got %v", got.Narrow) } if got.Type == nil { @@ -75,10 +76,10 @@ func TestJoinFacts_BatchMergeFunctionFacts(t *testing.T) { }, ) - if got := facts.FunctionFacts.Summary(symSummary); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.Summary(symSummary); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("summary mismatch: got %v", got) } - if got := facts.FunctionFacts.NarrowSummary(symNarrow); !ReturnTypesEqual(got, []typ.Type{typ.Number}) { + if got := facts.FunctionFacts.NarrowSummary(symNarrow); !returnsummary.Equal(got, []typ.Type{typ.Number}) { t.Fatalf("narrow mismatch: got %v", got) } if got := facts.FunctionFacts.FunctionType(symFunc); !typ.TypeEquals(got, funcType) { @@ -97,7 +98,7 @@ func TestJoinFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) api.FunctionFact{Summary: []typ.Type{openTop}, Narrow: narrow, Type: candidateFunc}, ) - if !ReturnTypesEqual(normalizeAndPruneReturnVector(out.Summary), normalizeAndPruneReturnVector(narrow)) { + if !returnsummary.Equal(returnsummary.NormalizeAndPrune(out.Summary), returnsummary.NormalizeAndPrune(narrow)) { t.Fatalf("summary mismatch: got %v want %v", out.Summary, narrow) } @@ -105,7 +106,7 @@ func TestJoinFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) if !ok { 
t.Fatalf("expected function fact, got %T", out.Type) } - if !ReturnTypesEqual(normalizeAndPruneReturnVector(fn.Returns), normalizeAndPruneReturnVector(narrow)) { + if !returnsummary.Equal(returnsummary.NormalizeAndPrune(fn.Returns), returnsummary.NormalizeAndPrune(narrow)) { t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, narrow) } } @@ -142,17 +143,17 @@ func TestJoinFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { api.FunctionFact{Narrow: good}, ) - if !ReturnTypesEqual(out.Summary, good) { + if !returnsummary.Equal(out.Summary, good) { t.Fatalf("summary mismatch: got %v want %v", out.Summary, good) } - if !ReturnTypesEqual(out.Narrow, good) { + if !returnsummary.Equal(out.Narrow, good) { t.Fatalf("narrow mismatch: got %v want %v", out.Narrow, good) } fn, ok := out.Type.(*typ.Function) if !ok { t.Fatalf("expected function fact, got %T", out.Type) } - if !ReturnTypesEqual(fn.Returns, good) { + if !returnsummary.Equal(fn.Returns, good) { t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, good) } } @@ -172,14 +173,14 @@ func TestJoinFunctionFact_DoesNotAlignFunctionToNarrowFieldRegression(t *testing api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, ) - if !ReturnTypesEqual(out.Summary, []typ.Type{withCapturedMethod}) { + if !returnsummary.Equal(out.Summary, []typ.Type{withCapturedMethod}) { t.Fatalf("summary mismatch: got %v want %v", out.Summary, []typ.Type{withCapturedMethod}) } fn, ok := out.Type.(*typ.Function) if !ok { t.Fatalf("expected function fact, got %T", out.Type) } - if !ReturnTypesEqual(fn.Returns, []typ.Type{withCapturedMethod}) { + if !returnsummary.Equal(fn.Returns, []typ.Type{withCapturedMethod}) { t.Fatalf("func returns should preserve captured method summary, got %v", fn.Returns) } } @@ -199,10 +200,10 @@ func TestNormalizeFunctionFacts_CanonicalizesStoredFunctionFacts(t *testing.T) { if !ok { t.Fatal("expected canonical FunctionFacts entry") } - if 
!ReturnTypesEqual(ff.Summary, []typ.Type{typ.Nil}) { + if !returnsummary.Equal(ff.Summary, []typ.Type{typ.Nil}) { t.Fatalf("summary mismatch: got %v", ff.Summary) } - if !ReturnTypesEqual(ff.Narrow, []typ.Type{typ.Number}) { + if !returnsummary.Equal(ff.Narrow, []typ.Type{typ.Number}) { t.Fatalf("narrow mismatch: got %v", ff.Narrow) } if !typ.TypeEquals(ff.Type, fn) { @@ -219,11 +220,11 @@ func TestFunctionFactsAccessorsReadCanonicalFacts(t *testing.T) { }, } - if got := facts.FunctionFacts.Summary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.Summary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("summary mismatch: got %v", got) } - if got := facts.FunctionFacts.NarrowSummary(sym); !ReturnTypesEqual(got, []typ.Type{typ.String}) { + if got := facts.FunctionFacts.NarrowSummary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { t.Fatalf("narrow mismatch: got %v", got) } diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index 50c67661..c208b4f2 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -4,6 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -71,9 +72,9 @@ func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFa // Narrow summaries can refine optional/non-nil returns, but a nil-only // narrow observation must not erase an already-informative summary. 
- if len(out.Narrow) > 0 && !ReturnTypesAllNil(out.Narrow) { + if len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) { if len(out.Summary) == 0 { - out.Summary = canonicalReturnVector(out.Narrow) + out.Summary = returnsummary.Canonical(out.Narrow) } else { out.Summary = widenReturnSummaryForConvergence(out.Summary, out.Narrow) } @@ -81,7 +82,7 @@ func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFa if fn := unwrap.Function(out.Type); fn != nil { if len(out.Summary) > 0 { - if aligned, changed := AlignFunctionTypeWithSummary(fn, out.Summary); changed { + if aligned, changed := returnsummary.AlignFunction(fn, out.Summary); changed { out.Type = widenFunctionFactTypeForConvergence(fn, aligned) } } else if len(fn.Returns) > 0 { @@ -92,149 +93,9 @@ func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFa return out } -func shouldUseMonotoneReturnJoin(a, b []typ.Type) bool { - for _, t := range a { - if hasHigherOrderGrowthRisk(t) { - return true - } - } - for _, t := range b { - if hasHigherOrderGrowthRisk(t) { - return true - } - } - return false -} - -func hasHigherOrderGrowthRisk(t typ.Type) bool { - if t == nil { - return false - } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - switch n := node.(type) { - case *typ.Function: - for _, ret := range n.Returns { - if typeContainsFunction(ret) { - return true, false - } - } - case *typ.Record: - if recordHasSelfRecursiveMethod(n) { - return true, false - } - } - return false, true - }) -} - -func typeContainsFunction(t typ.Type) bool { - if t == nil { - return false - } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - // Interface method signatures are behavioral contracts, not first-class - // returned function values. Ignore them for higher-order growth risk. 
- if _, ok := node.(*typ.Interface); ok { - return false, false - } - if _, ok := node.(*typ.Function); ok { - return true, false - } - return false, true - }) -} - -func recordHasSelfRecursiveMethod(r *typ.Record) bool { - if r == nil { - return false - } - for _, f := range r.Fields { - if methodTypeHasSelfRecursiveReturn(f.Type, r) { - return true - } - } - if r.HasMapComponent() && methodTypeHasSelfRecursiveReturn(r.MapValue, r) { - return true - } - return false -} - -func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { - if t == nil || owner == nil { - return false - } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - // Interface method signatures are behavioral contracts, not concrete - // record method bodies. Treating them as self-recursive growth risk - // over-applies monotone widening and blocks valid summary refinement. - if _, ok := node.(*typ.Interface); ok { - return false, false - } - fn, ok := node.(*typ.Function) - if !ok { - return false, true - } - for _, ret := range fn.Returns { - if ret == nil { - continue - } - if subtype.IsSubtype(ret, owner) || subtype.IsSubtype(owner, ret) || - value.ExtendsRecord(ret, owner) || value.ExtendsRecord(owner, ret) { - return true, false - } - } - return false, true - }) -} - -func joinReturnVectorsMonotone(a, b []typ.Type) []typ.Type { - if len(a) == 0 { - return b - } - if len(b) == 0 { - return a - } - maxLen := len(a) - if len(b) > maxLen { - maxLen = len(b) - } - out := make([]typ.Type, maxLen) - for i := 0; i < maxLen; i++ { - var ai, bi typ.Type - if i < len(a) { - ai = a[i] - } - if i < len(b) { - bi = b[i] - } - out[i] = joinReturnTypeMonotone(ai, bi) - } - return out -} - -func joinReturnTypeMonotone(a, b typ.Type) typ.Type { - if a == nil { - return b - } - if b == nil { - return a - } - if typ.TypeEquals(a, b) { - return a - } - // Keep widening monotone: if one side is already an upper bound, keep it. 
- if subtype.IsSubtype(a, b) || value.ExtendsRecord(a, b) || value.ElidesOptional(a, b) { - return b - } - if subtype.IsSubtype(b, a) || value.ExtendsRecord(b, a) || value.ElidesOptional(b, a) { - return a - } - return typ.JoinPreferNonSoft(a, b) -} - func widenReturnSummaryForConvergence(prev, next []typ.Type) []typ.Type { - prev = normalizeAndPruneReturnVector(prev) - next = normalizeAndPruneReturnVector(next) + prev = returnsummary.NormalizeAndPrune(prev) + next = returnsummary.NormalizeAndPrune(next) if len(prev) == 0 { return widenReturnVectorForConvergence(next) } @@ -242,11 +103,11 @@ func widenReturnSummaryForConvergence(prev, next []typ.Type) []typ.Type { return widenReturnVectorForConvergence(prev) } - merged := MergeReturnSummary(prev, next) + merged := returnsummary.Merge(prev, next) if returnVectorUnsafePrecisionDrop(prev, merged) { merged = prev } - return widenReturnVectorForConvergence(normalizeAndPruneReturnVector(merged)) + return widenReturnVectorForConvergence(returnsummary.NormalizeAndPrune(merged)) } func returnVectorUnsafePrecisionDrop(prev, merged []typ.Type) bool { @@ -1178,10 +1039,10 @@ func mergeFunctionReturnsIfSameShape(prevFn, nextFn *typ.Function) (typ.Type, bo mergedReturns[i] = typ.JoinReturnSlot(normalizedPrev[i], normalizedNext[i]) } } - if ReturnTypesEqual(prevFn.Returns, mergedReturns) { + if returnsummary.Equal(prevFn.Returns, mergedReturns) { return prevFn, true } - if ReturnTypesEqual(nextFn.Returns, mergedReturns) { + if returnsummary.Equal(nextFn.Returns, mergedReturns) { return nextFn, true } @@ -1260,7 +1121,7 @@ func maybeWidenTypeForConvergence(t typ.Type) typ.Type { if t == nil { return nil } - if !hasHigherOrderGrowthRisk(t) { + if !returnsummary.HasHigherOrderGrowthRisk(t) { return t } return subtype.WidenForInference(t) diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index ee68599d..006c5fbd 100644 --- a/compiler/check/returns/widen_test.go +++ 
b/compiler/check/returns/widen_test.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" ) @@ -116,7 +117,7 @@ func TestWidenFacts_InterfaceMethodsDoNotBlockOptionalElision(t *testing.T) { } } -func TestMergeReturnSummary_StopsRecursiveContainerReturnGrowth(t *testing.T) { +func TestReturnSummaryMerge_StopsRecursiveContainerReturnGrowth(t *testing.T) { recordMap := func(value typ.Type) typ.Type { return typ.NewRecord().MapComponent(typ.String, value).Build() } @@ -158,7 +159,7 @@ func TestMergeReturnSummary_StopsRecursiveContainerReturnGrowth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - merged := MergeReturnSummary([]typ.Type{tt.stable}, []typ.Type{tt.growth}) + merged := returnsummary.Merge([]typ.Type{tt.stable}, []typ.Type{tt.growth}) if len(merged) != 1 || !typ.TypeEquals(merged[0], tt.stable) { t.Fatalf("expected stable recursive return shape, got %v", merged) } @@ -166,11 +167,11 @@ func TestMergeReturnSummary_StopsRecursiveContainerReturnGrowth(t *testing.T) { } } -func TestMergeReturnSummary_KeepsNonRecursiveContainerRefinement(t *testing.T) { +func TestReturnSummaryMerge_KeepsNonRecursiveContainerRefinement(t *testing.T) { stable := typ.NewMap(typ.String, typ.Any) refined := typ.NewMap(typ.String, typ.String) - merged := MergeReturnSummary([]typ.Type{stable}, []typ.Type{refined}) + merged := returnsummary.Merge([]typ.Type{stable}, []typ.Type{refined}) if len(merged) != 1 || !typ.TypeEquals(merged[0], refined) { t.Fatalf("expected non-recursive map refinement to survive, got %v", merged) } @@ -444,43 +445,3 @@ func TestWidenLiteralSigs_NormalizesNilBranch(t *testing.T) { t.Fatalf("expected nil-branch literal signature %v to be normalized to %v, got %v", sig, want, got) } } - -func 
TestTypeContainsFunction_IgnoresInterfaceMethodSignatures(t *testing.T) { - iface := typ.NewInterface("Reader", []typ.Method{ - { - Name: "next", - Type: typ.Func(). - Param("self", typ.Self). - Returns(typ.Func().Returns(typ.String).Build()). - Build(), - }, - }) - if typeContainsFunction(iface) { - t.Fatalf("expected interface method signatures to be ignored, got true") - } -} - -func TestHasHigherOrderGrowthRisk_DetectsFunctionReturningFunction(t *testing.T) { - tp := typ.Func(). - Returns(typ.Func().Returns(typ.String).Build()). - Build() - if !hasHigherOrderGrowthRisk(tp) { - t.Fatalf("expected higher-order growth risk to be detected") - } -} - -func TestMethodTypeHasSelfRecursiveReturn_IgnoresInterfaceMethods(t *testing.T) { - owner := typ.NewRecord().Field("id", typ.String).Build() - methodType := typ.NewInterface("HasBuild", []typ.Method{ - { - Name: "build", - Type: typ.Func(). - Param("self", typ.Self). - Returns(owner). - Build(), - }, - }) - if methodTypeHasSelfRecursiveReturn(methodType, owner) { - t.Fatalf("expected interface method signatures to be ignored for self-recursive detection") - } -} diff --git a/compiler/check/tests/errors/error_correlation_test.go b/compiler/check/tests/errors/error_correlation_test.go index adc39a1b..d17e742c 100644 --- a/compiler/check/tests/errors/error_correlation_test.go +++ b/compiler/check/tests/errors/error_correlation_test.go @@ -5,7 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/tests/testutil" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/contract" @@ -294,7 +294,7 @@ db:release() if !ok || sym == 0 { t.Fatalf("missing symbol for %s", name) } - rets := returns.NormalizeReturnVector(functionFacts.Summary(sym)) + rets := returnsummary.Normalize(functionFacts.Summary(sym)) 
if len(rets) == 0 { t.Fatalf("missing return summary for %s", name) } From cc5ebbfb053d075870e063110f3ed62c9d36dd13 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 02:44:29 -0400 Subject: [PATCH 21/71] Move function facts into domain --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 56 +++++ compiler/check/domain/functionfact/doc.go | 7 + .../join.go => domain/functionfact/fact.go} | 147 +++++++++--- .../check/domain/functionfact/fact_test.go | 211 ++++++++++++++++++ compiler/check/infer/interproc/postflow.go | 8 +- compiler/check/infer/nested/processor.go | 3 +- compiler/check/infer/return/infer.go | 3 +- compiler/check/pipeline/runner_stages.go | 3 +- compiler/check/returns/doc.go | 21 +- compiler/check/returns/function_facts.go | 52 +---- compiler/check/returns/join_test.go | 43 ++-- compiler/check/returns/kernel.go | 55 ----- compiler/check/returns/kernel_test.go | 25 ++- compiler/check/returns/widen.go | 7 +- compiler/check/returns/widen_test.go | 29 +-- compiler/check/siblings/siblings.go | 4 +- compiler/check/siblings/siblings_test.go | 12 +- 17 files changed, 479 insertions(+), 207 deletions(-) create mode 100644 compiler/check/domain/functionfact/doc.go rename compiler/check/{returns/join.go => domain/functionfact/fact.go} (59%) create mode 100644 compiler/check/domain/functionfact/fact_test.go delete mode 100644 compiler/check/returns/kernel.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 1a567854..de4c189e 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -223,6 +223,62 @@ Verification for this slice so far: targets: session 8 errors, agent/src 9 errors, docker-demo 21 errors and 2 warnings. +## 2026-05-19 Function Fact Domain Checkpoint + +The next rectification slice moved the per-function fact laws out of +`compiler/check/returns` and into `compiler/check/domain/functionfact`. 
+ +Moved domain laws: + +- canonicalization and emptiness for one `api.FunctionFact`; +- same-iteration join for one function fact; +- merge policy for function-type fact projections; +- compatible function-variant collapse inside unions while preserving residual + non-function union members; +- same-shape function merging across params, variadic params, returns, effects, + error-return specs, and refinements; +- parameter-slot fact merge policy that delegates to `domain/paramevidence`; +- return-slot fact merge policy that delegates to `domain/returnsummary`. + +Production callers now import `domain/functionfact` directly for individual +function facts. The old `returns.JoinFunctionFact`, +`returns.MergeFunctionFactType`, `returns.NormalizeFunctionFact`, and +`returns.NormalizeFunctionFacts` names were deleted instead of wrapped. + +Current package ownership: + +```text +domain/value = reusable structural value relations +domain/paramevidence = parameter evidence lattice, equality, and parameter-slot refinement +domain/returnsummary = return-vector lattice and function-return alignment +domain/functionfact = one-function fact normalization, join, and type projection +returns = function-fact maps, captured effects, local SCC orchestration, and interproc widening +``` + +The resulting data flow is now narrower: + +1. inference and post-flow code produce one-function deltas through + `functionfact.Join`; +2. `returns.JoinFacts` and `returns.WidenFacts` decide how those deltas combine + across symbol maps and fixpoint iterations; +3. convergence-specific widening remains in `returns` because it depends on the + whole interprocedural product and iteration boundary; +4. no production code calls legacy per-function fact helpers through `returns`. + +Verification for this slice so far: + +- `go test ./compiler/check/domain/functionfact ./compiler/check/returns + ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. 
+- `go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction + -benchmem -count=3` reports about 1.15-1.17 ms/op, 882 KB/op, and 9390 + allocs/op on this machine. +- Standard `../scripts/verify-suite.sh` passes go-lua checker tests and builds + the Wippy binary, then exits non-zero on the known external pinned lint + targets: session 8 errors, agent/src 10 errors, docker-demo 21 errors and + 2 warnings. + ## Goal The checker should read as one abstract interpreter over a product domain. diff --git a/compiler/check/domain/functionfact/doc.go b/compiler/check/domain/functionfact/doc.go new file mode 100644 index 00000000..6029d0fc --- /dev/null +++ b/compiler/check/domain/functionfact/doc.go @@ -0,0 +1,7 @@ +// Package functionfact owns the per-function fact abstract domain. +// +// It canonicalizes and joins one api.FunctionFact at a time: parameter +// evidence, return summaries, narrow summaries, and the projected function type. +// Product-level packages decide when facts are read from or written to maps; +// this package decides what one function fact means and how it combines. +package functionfact diff --git a/compiler/check/returns/join.go b/compiler/check/domain/functionfact/fact.go similarity index 59% rename from compiler/check/returns/join.go rename to compiler/check/domain/functionfact/fact.go index f328f071..85f90c89 100644 --- a/compiler/check/returns/join.go +++ b/compiler/check/domain/functionfact/fact.go @@ -1,6 +1,7 @@ -package returns +package functionfact import ( + "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/domain/value" @@ -9,10 +10,78 @@ import ( "github.com/wippyai/go-lua/types/typ/unwrap" ) -// MergeFunctionFactType merges function-type facts through one canonical policy. 
-// This ensures all channels agree on when to preserve shape and how to merge -// returns, avoiding directional one-off behavior in individual phases. -func MergeFunctionFactType(existing, candidate typ.Type) typ.Type { +// Normalize canonicalizes one stored function fact. +func Normalize(ff api.FunctionFact) api.FunctionFact { + return api.FunctionFact{ + Params: paramevidence.FilterEmptyVector(ff.Params), + Summary: returnsummary.Canonical(ff.Summary), + Narrow: returnsummary.Canonical(ff.Narrow), + Type: normalizeType(ff.Type), + } +} + +// Empty reports whether a canonical function fact contains no information. +func Empty(ff api.FunctionFact) bool { + return len(ff.Params) == 0 && len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Type == nil +} + +// Join precisely merges two observations for one local function during a single +// analysis iteration. +func Join(existing, candidate api.FunctionFact) api.FunctionFact { + existing = Normalize(existing) + candidate = Normalize(candidate) + out := existing + + if len(candidate.Params) > 0 { + out.Params = paramevidence.JoinVectors(out.Params, candidate.Params) + } + if len(candidate.Summary) > 0 { + out.Summary = returnsummary.Merge(out.Summary, candidate.Summary) + } + if len(candidate.Narrow) > 0 { + out.Narrow = returnsummary.Merge(out.Narrow, candidate.Narrow) + } + if candidate.Type != nil { + out.Type = MergeType(out.Type, candidate.Type) + } + + if len(out.Narrow) > 0 { + if len(out.Summary) == 0 { + out.Summary = returnsummary.Canonical(out.Narrow) + } else { + out.Summary = returnsummary.Merge(out.Summary, out.Narrow) + } + } + + if fn := unwrap.Function(out.Type); fn != nil { + alignedSummary := out.Summary + if len(alignedSummary) > 0 { + if aligned, changed := returnsummary.AlignFunction(fn, alignedSummary); changed { + out.Type = aligned + fn = aligned + } + } + if len(out.Summary) == 0 && fn != nil && len(fn.Returns) > 0 { + out.Summary = returnsummary.Canonical(fn.Returns) + } + } + + return out +} 
+ +func normalizeType(t typ.Type) typ.Type { + if t == nil { + return nil + } + if fn := unwrap.Function(t); fn != nil { + return fn + } + return typ.PruneSoftUnionMembers(t) +} + +// MergeType merges function-type facts through the canonical per-function fact +// policy. +func MergeType(existing, candidate typ.Type) typ.Type { if existing == nil { return candidate } @@ -22,12 +91,12 @@ func MergeFunctionFactType(existing, candidate typ.Type) typ.Type { existingFn := unwrap.Function(existing) candidateFn := unwrap.Function(candidate) - if mergedFromVariants, ok := mergeFunctionFactVariants(existing, candidate); ok { + if mergedFromVariants, ok := mergeVariants(existing, candidate); ok { return mergedFromVariants } if existingFn != nil && candidateFn != nil { - if sameFunctionShapeForFactMerge(existingFn, candidateFn) { - return mergeFunctionFactsByShape(existingFn, candidateFn) + if SameShape(existingFn, candidateFn) { + return mergeByShape(existingFn, candidateFn) } } @@ -40,14 +109,14 @@ func MergeFunctionFactType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } -type functionFactVariants struct { +type variants struct { funcs []*typ.Function residuals []typ.Type } -func mergeFunctionFactVariants(existing, candidate typ.Type) (typ.Type, bool) { - existingVariants := splitFunctionFactVariants(existing) - candidateVariants := splitFunctionFactVariants(candidate) +func mergeVariants(existing, candidate typ.Type) (typ.Type, bool) { + existingVariants := splitVariants(existing) + candidateVariants := splitVariants(candidate) if len(existingVariants.funcs) == 0 || len(candidateVariants.funcs) == 0 { return nil, false } @@ -56,14 +125,14 @@ func mergeFunctionFactVariants(existing, candidate typ.Type) (typ.Type, bool) { all = append(all, existingVariants.funcs...) all = append(all, candidateVariants.funcs...) 
for i := 1; i < len(all); i++ { - if !sameFunctionShapeForFactMerge(all[0], all[i]) { + if !SameShape(all[0], all[i]) { return nil, false } } merged := all[0] for i := 1; i < len(all); i++ { - next, _ := mergeFunctionFactsByShape(merged, all[i]).(*typ.Function) + next, _ := mergeByShape(merged, all[i]).(*typ.Function) if next == nil { return nil, false } @@ -80,20 +149,20 @@ func mergeFunctionFactVariants(existing, candidate typ.Type) (typ.Type, bool) { return typ.NewUnion(residuals...), true } -func splitFunctionFactVariants(t typ.Type) functionFactVariants { - var out functionFactVariants - collectFunctionFactVariants(t, &out) +func splitVariants(t typ.Type) variants { + var out variants + collectVariants(t, &out) return out } -func collectFunctionFactVariants(t typ.Type, out *functionFactVariants) { +func collectVariants(t typ.Type, out *variants) { if t == nil || out == nil { return } switch v := unwrap.Alias(t).(type) { case *typ.Union: for _, member := range v.Members { - collectFunctionFactVariants(member, out) + collectVariants(member, out) } return } @@ -104,7 +173,8 @@ func collectFunctionFactVariants(t typ.Type, out *functionFactVariants) { out.residuals = append(out.residuals, t) } -func sameFunctionShapeForFactMerge(a, b *typ.Function) bool { +// SameShape reports whether two function fact types can be merged slot-wise. +func SameShape(a, b *typ.Function) bool { if a == nil || b == nil { return false } @@ -114,15 +184,10 @@ func sameFunctionShapeForFactMerge(a, b *typ.Function) bool { if !typeParamsEqual(a.TypeParams, b.TypeParams) { return false } - if len(a.Params) != len(b.Params) { - return false - } - // Param type precision and optionality may differ across iterations. - // Treat those as mergeable slots and reconcile in mergeFunctionFactsByShape. 
- return true + return len(a.Params) == len(b.Params) } -func mergeFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { +func mergeByShape(existing, candidate *typ.Function) typ.Type { if existing == nil { return candidate } @@ -136,7 +201,7 @@ func mergeFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { } for i, p := range existing.Params { - paramType := mergeFunctionParamFactType(p.Type, candidate.Params[i].Type) + paramType := mergeParamType(p.Type, candidate.Params[i].Type) name := p.Name if name == "" { name = candidate.Params[i].Name @@ -150,7 +215,7 @@ func mergeFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { } if existing.Variadic != nil || candidate.Variadic != nil { - builder = builder.Variadic(mergeFunctionParamFactType(existing.Variadic, candidate.Variadic)) + builder = builder.Variadic(mergeParamType(existing.Variadic, candidate.Variadic)) } if mergedReturns := returnsummary.Merge(existing.Returns, candidate.Returns); len(mergedReturns) > 0 { @@ -182,7 +247,7 @@ func mergeFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { return builder.Build() } -func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { +func mergeParamType(existing, candidate typ.Type) typ.Type { if existing == nil { return candidate } @@ -198,7 +263,7 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { if unwrap.IsNilType(candidate) && !unwrap.IsNilType(existing) { return existing } - if preferred, ok := preferStructuredRecordParam(existing, candidate); ok { + if preferred, ok := preferStructuredRecord(existing, candidate); ok { return preferred } if preferred, ok := value.PreferConcreteOverSoft(existing, candidate); ok { @@ -237,7 +302,7 @@ func mergeFunctionParamFactType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } -func preferStructuredRecordParam(existing, candidate typ.Type) (typ.Type, bool) { +func 
preferStructuredRecord(existing, candidate typ.Type) (typ.Type, bool) { existingRec, okExisting := unwrap.Alias(existing).(*typ.Record) candidateRec, okCandidate := unwrap.Alias(candidate).(*typ.Record) if !okExisting || !okCandidate { @@ -261,3 +326,21 @@ func preferStructuredRecordParam(existing, candidate typ.Type) (typ.Type, bool) } return nil, false } + +func typeParamsEqual(a, b []*typ.TypeParam) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] == nil || b[i] == nil { + if a[i] != b[i] { + return false + } + continue + } + if !a[i].Equals(b[i]) { + return false + } + } + return true +} diff --git a/compiler/check/domain/functionfact/fact_test.go b/compiler/check/domain/functionfact/fact_test.go new file mode 100644 index 00000000..6d51f3e7 --- /dev/null +++ b/compiler/check/domain/functionfact/fact_test.go @@ -0,0 +1,211 @@ +package functionfact + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" + "github.com/wippyai/go-lua/types/kind" + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +func TestJoin_InitialObservation(t *testing.T) { + fn := typ.Func().Returns(typ.String).Build() + + got := Join(api.FunctionFact{}, api.FunctionFact{ + Summary: []typ.Type{typ.String}, + Narrow: []typ.Type{typ.String}, + Type: fn, + }) + + if !returnsummary.Equal(got.Summary, []typ.Type{typ.String}) { + t.Fatalf("summary mismatch: got %v", got.Summary) + } + if !returnsummary.Equal(got.Narrow, []typ.Type{typ.String}) { + t.Fatalf("narrow mismatch: got %v", got.Narrow) + } + if !typ.TypeEquals(got.Type, fn) { + t.Fatalf("func mismatch: got %v", got.Type) + } +} + +func TestJoin_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { + openTop := typ.NewRecord().SetOpen(true).Build() + existingFunc := typ.Func().Returns(openTop).Build() + candidateFunc := 
typ.Func().Returns(openTop).Build() + narrow := []typ.Type{typ.NewArray(typ.Unknown)} + + out := Join( + api.FunctionFact{Summary: []typ.Type{openTop}, Type: existingFunc}, + api.FunctionFact{Summary: []typ.Type{openTop}, Narrow: narrow, Type: candidateFunc}, + ) + + if !returnsummary.Equal(returnsummary.NormalizeAndPrune(out.Summary), returnsummary.NormalizeAndPrune(narrow)) { + t.Fatalf("summary mismatch: got %v want %v", out.Summary, narrow) + } + fn, ok := out.Type.(*typ.Function) + if !ok { + t.Fatalf("expected function fact, got %T", out.Type) + } + if !returnsummary.Equal(returnsummary.NormalizeAndPrune(fn.Returns), returnsummary.NormalizeAndPrune(narrow)) { + t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, narrow) + } +} + +func TestJoin_NarrowSummaryRepairsNeverArtifact(t *testing.T) { + bad := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Never).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). + Build(), + ), + } + good := []typ.Type{ + typ.NewUnion( + typ.NewRecord(). + Field("success", typ.True). + Field("result", typ.NewRecord().OptField("data", typ.Unknown).Build()). + Build(), + typ.NewRecord(). + Field("success", typ.False). + Field("error", typ.LiteralString("missing")). 
+ Build(), + ), + } + + out := Join( + api.FunctionFact{Summary: bad, Type: typ.Func().Returns(bad...).Build()}, + api.FunctionFact{Narrow: good}, + ) + + if !returnsummary.Equal(out.Summary, good) { + t.Fatalf("summary mismatch: got %v want %v", out.Summary, good) + } + if !returnsummary.Equal(out.Narrow, good) { + t.Fatalf("narrow mismatch: got %v want %v", out.Narrow, good) + } + fn, ok := out.Type.(*typ.Function) + if !ok { + t.Fatalf("expected function fact, got %T", out.Type) + } + if !returnsummary.Equal(fn.Returns, good) { + t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, good) + } +} + +func TestMergeType_MergesSameShapeReturnsCanonically(t *testing.T) { + existing := typ.Func(). + Param("x", typ.String). + Returns(typ.NewOptional(typ.Integer)). + Build() + candidate := typ.Func(). + Param("x", typ.String). + Returns(typ.Integer). + Build() + + merged := MergeType(existing, candidate) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Returns) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Returns[0], typ.Integer) { + t.Fatalf("expected refined return integer, got %v", fn.Returns[0]) + } +} + +func TestMergeType_WidensParamToCoverObservedCallsites(t *testing.T) { + existing := typ.Func(). + Param("t", typ.NewArray(typ.Any)). + Returns(typ.String). + Build() + candidate := typ.Func(). + Param("t", typ.NewMap(typ.String, typ.Any)). + Returns(typ.String). 
+ Build() + + merged := MergeType(existing, candidate) + fn, ok := merged.(*typ.Function) + if !ok { + t.Fatalf("expected merged function, got %T", merged) + } + if len(fn.Params) != 1 { + t.Fatalf("expected one param, got %+v", fn.Params) + } + if typ.TypeEquals(fn.Params[0].Type, typ.NewArray(typ.Any)) { + t.Fatalf("expected param widening beyond array-only shape, got %v", fn.Params[0].Type) + } + wantMap := typ.NewMap(typ.String, typ.Any) + if !subtype.IsSubtype(wantMap, fn.Params[0].Type) { + t.Fatalf("expected merged param to admit map callsite evidence, got %v", fn.Params[0].Type) + } +} + +func TestMergeType_CollapsesMixedFunctionUnionVariants(t *testing.T) { + base := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord().Field("full_path", typ.String).SetOpen(true).Build()). + Build() + withChildren := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + Field("children", typ.NewArray(typ.Unknown)). + SetOpen(true). + Build()). + Build() + withTests := typ.Func(). + Param("name", typ.Unknown). + Returns(typ.NewRecord(). + Field("full_path", typ.String). + Field("tests", typ.NewArray(typ.Unknown)). + SetOpen(true). + Build()). 
+ Build() + + merged := MergeType(typ.NewUnion(typ.Nil, base, withChildren), withTests) + if merged == nil { + t.Fatal("expected merged type") + } + fn := unwrap.Function(merged) + if fn == nil || len(fn.Returns) != 1 { + t.Fatalf("expected merged function variant, got %v", merged) + } + rec, ok := fn.Returns[0].(*typ.Record) + if !ok { + t.Fatalf("expected record return, got %T", fn.Returns[0]) + } + for _, field := range []string{"full_path", "children", "tests"} { + if rec.GetField(field) == nil { + t.Fatalf("expected merged field %q in %v", field, rec) + } + } + if merged.Kind() != kind.Optional { + t.Fatalf("expected nil residual to be preserved as optional, got %v", merged) + } +} + +func TestNormalize_CanonicalizesStoredFunctionFact(t *testing.T) { + fn := typ.Func().Returns(typ.Number).Build() + got := Normalize(api.FunctionFact{ + Summary: []typ.Type{nil}, + Narrow: []typ.Type{typ.Number}, + Type: fn, + }) + + if !returnsummary.Equal(got.Summary, []typ.Type{typ.Nil}) { + t.Fatalf("summary mismatch: got %v", got.Summary) + } + if !returnsummary.Equal(got.Narrow, []typ.Type{typ.Number}) { + t.Fatalf("narrow mismatch: got %v", got.Narrow) + } + if !typ.TypeEquals(got.Type, fn) { + t.Fatalf("func mismatch: got %v", got.Type) + } +} diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index ffdf941d..14220351 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -6,11 +6,11 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" checkcallsite "github.com/wippyai/go-lua/compiler/check/callsite" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/erreffect" "github.com/wippyai/go-lua/compiler/check/nested" - 
"github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/flow" @@ -82,7 +82,7 @@ func StoreFactsFromResult( } } delta := api.Facts{FunctionFacts: api.FunctionFacts{ - fnSym: returns.JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ + fnSym: functionfact.Join(api.FunctionFact{}, api.FunctionFact{ Summary: summaryFromSnapshot, Narrow: narrowSummary, Type: candidateFunc, @@ -433,13 +433,13 @@ func CollectParameterEvidenceFromResult(store Store, result *api.FuncResult, par fnEvidence, _ = paramevidence.MergeAt(fnEvidence, j, param.Type, typ.JoinPreferNonSoft) } if len(fnEvidence) > 0 { - deltaFacts[argSym] = returns.JoinFunctionFact(deltaFacts[argSym], api.FunctionFact{Params: fnEvidence}) + deltaFacts[argSym] = functionfact.Join(deltaFacts[argSym], api.FunctionFact{Params: fnEvidence}) } } } } if len(evidence) > 0 { - deltaFacts[calleeSym] = returns.JoinFunctionFact(deltaFacts[calleeSym], api.FunctionFact{Params: evidence}) + deltaFacts[calleeSym] = functionfact.Join(deltaFacts[calleeSym], api.FunctionFact{Params: evidence}) } if len(deltaFacts) > 0 { store.MergeInterprocFactsNext(parentKey, api.Facts{FunctionFacts: deltaFacts}) diff --git a/compiler/check/infer/nested/processor.go b/compiler/check/infer/nested/processor.go index 49bb2cc4..03b47f7d 100644 --- a/compiler/check/infer/nested/processor.go +++ b/compiler/check/infer/nested/processor.go @@ -23,6 +23,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/flowbuild/assign" "github.com/wippyai/go-lua/compiler/check/infer/captured" "github.com/wippyai/go-lua/compiler/check/nested" @@ -308,7 +309,7 @@ func (p *Processor) processNestedFunction( // Update sibling types with the 
fully-inferred function type. if info.IsLocal && info.FuncSym != 0 && result.NarrowSynth != nil { if inferredType := result.NarrowSynth.FunctionType(info.NF.Func, parentScope); inferredType != nil { - siblingFunctionTypes[info.FuncSym] = returns.MergeFunctionFactType(siblingFunctionTypes[info.FuncSym], inferredType) + siblingFunctionTypes[info.FuncSym] = functionfact.MergeType(siblingFunctionTypes[info.FuncSym], inferredType) } } } diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index 370f7d67..509c0ebf 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -43,6 +43,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" @@ -332,7 +333,7 @@ func assembleFunctionFacts( if info := localFuncs[sym]; info != nil { params = info.ParameterEvidence } - ff := returns.JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ + ff := functionfact.Join(api.FunctionFact{}, api.FunctionFact{ Params: params, Summary: returnVectors[sym], Type: funcs[sym], diff --git a/compiler/check/pipeline/runner_stages.go b/compiler/check/pipeline/runner_stages.go index ddf75739..8899ead6 100644 --- a/compiler/check/pipeline/runner_stages.go +++ b/compiler/check/pipeline/runner_stages.go @@ -4,6 +4,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/flowbuild/resolve" 
"github.com/wippyai/go-lua/compiler/check/phase" @@ -63,7 +64,7 @@ func mergeSynthesizedSignatureFact(seed, fact *typ.Function) *typ.Function { if fact == nil { return seed } - if merged := unwrap.Function(returns.MergeFunctionFactType(seed, fact)); merged != nil { + if merged := unwrap.Function(functionfact.MergeType(seed, fact)); merged != nil { return merged } return seed diff --git a/compiler/check/returns/doc.go b/compiler/check/returns/doc.go index 21372f4f..0829c293 100644 --- a/compiler/check/returns/doc.go +++ b/compiler/check/returns/doc.go @@ -1,8 +1,15 @@ -// Package returns provides interprocedural return type analysis. +// Package returns orchestrates local return inference and interprocedural fact +// products. // -// This package implements the fixpoint iteration for return type inference -// across mutually recursive function groups. It computes strongly connected -// components in the call graph and processes them in dependency order. +// It does not own the lattice laws for individual fact slots. Those live in +// domain packages: +// - domain/paramevidence owns parameter evidence; +// - domain/returnsummary owns return vectors and function-return alignment; +// - domain/functionfact owns one api.FunctionFact at a time; +// - domain/value owns reusable structural value relations. +// +// This package owns when those domains are applied across maps, SCCs, overlays, +// captured mutations, and recursive interprocedural fixpoint boundaries. 
// // # SCC-Based Analysis // @@ -18,8 +25,8 @@ // For each function: // - Collect return expressions from all return statements // - Synthesize types for return expressions -// - Join multiple return types into a union -// - Apply widening for recursive convergence +// - Merge candidate return vectors through domain/returnsummary +// - Apply product-level widening for recursive convergence // // # Type Widening // @@ -41,6 +48,6 @@ // // # Signature Inference // -// [InferSignature] combines parameter evidence and return types to produce +// Signature inference combines parameter evidence and return types to produce // complete function signatures for functions without annotations. package returns diff --git a/compiler/check/returns/function_facts.go b/compiler/check/returns/function_facts.go index 11bf1366..aaab7799 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/returns/function_facts.go @@ -3,8 +3,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" - "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" ) func collectCanonicalFunctionFactSymbols(factSets ...api.FunctionFacts) []cfg.SymbolID { @@ -25,19 +24,6 @@ func markFunctionFactSymbols[T any](dst map[cfg.SymbolID]bool, src map[cfg.Symbo } } -func NormalizeFunctionFact(ff api.FunctionFact) api.FunctionFact { - return api.FunctionFact{ - Params: paramevidence.FilterEmptyVector(ff.Params), - Summary: returnsummary.Canonical(ff.Summary), - Narrow: returnsummary.Canonical(ff.Narrow), - Type: normalizeInterprocValueType(ff.Type), - } -} - -func functionFactEmpty(ff api.FunctionFact) bool { - return len(ff.Params) == 0 && len(ff.Summary) == 0 && len(ff.Narrow) == 0 && ff.Type == nil -} - func readFunctionFactFromFacts(facts *api.Facts, sym cfg.SymbolID) api.FunctionFact { if facts == 
nil || sym == 0 { return api.FunctionFact{} @@ -49,41 +35,19 @@ func readFunctionFactFromFacts(facts *api.Facts, sym cfg.SymbolID) api.FunctionF if !ok { return api.FunctionFact{} } - canonical := NormalizeFunctionFact(ff) - if !functionFactEmpty(canonical) { + canonical := functionfact.Normalize(ff) + if !functionfact.Empty(canonical) { return canonical } return api.FunctionFact{} } -func normalizeFunctionFactMap(facts api.FunctionFacts) api.FunctionFacts { - if len(facts) == 0 { - return nil - } - out := make(api.FunctionFacts, len(facts)) - for _, sym := range cfg.SortedSymbolIDs(facts) { - canonical := NormalizeFunctionFact(facts[sym]) - if functionFactEmpty(canonical) { - continue - } - out[sym] = canonical - } - if len(out) == 0 { - return nil - } - return out -} - -func writeFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.FunctionFact) { - writeNormalizedFunctionFactToFacts(facts, sym, NormalizeFunctionFact(ff)) -} - func writeNormalizedFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff api.FunctionFact) { if facts == nil || sym == 0 { return } - if functionFactEmpty(ff) { + if functionfact.Empty(ff) { if facts.FunctionFacts != nil { delete(facts.FunctionFacts, sym) if len(facts.FunctionFacts) == 0 { @@ -97,11 +61,3 @@ func writeNormalizedFunctionFactToFacts(facts *api.Facts, sym cfg.SymbolID, ff a facts.FunctionFacts[sym] = ff } } - -// NormalizeFunctionFacts canonicalizes the stored FunctionFacts map. 
-func NormalizeFunctionFacts(facts *api.Facts) { - if facts == nil { - return - } - facts.FunctionFacts = normalizeFunctionFactMap(facts.FunctionFacts) -} diff --git a/compiler/check/returns/join_test.go b/compiler/check/returns/join_test.go index f7d59312..f4dbe494 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/returns/join_test.go @@ -3,6 +3,7 @@ package returns import ( "testing" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/kind" "github.com/wippyai/go-lua/types/subtype" @@ -279,7 +280,7 @@ func TestReturnSummaryMerge_FillsNilSlotWithCandidateEvidence(t *testing.T) { } } -func TestMergeFunctionFactType_MergesSameShapeReturnsCanonically(t *testing.T) { +func TestFunctionFactMergeType_MergesSameShapeReturnsCanonically(t *testing.T) { existing := typ.Func(). Param("x", typ.String). Returns(typ.NewOptional(typ.Integer)). @@ -289,7 +290,7 @@ func TestMergeFunctionFactType_MergesSameShapeReturnsCanonically(t *testing.T) { Returns(typ.Integer). Build() - merged := MergeFunctionFactType(existing, candidate) + merged := functionfact.MergeType(existing, candidate) fn, ok := merged.(*typ.Function) if !ok || len(fn.Returns) != 1 { t.Fatalf("expected merged function, got %T", merged) @@ -299,7 +300,7 @@ func TestMergeFunctionFactType_MergesSameShapeReturnsCanonically(t *testing.T) { } } -func TestMergeFunctionFactType_PrefersConcreteParamOverTopObservation(t *testing.T) { +func TestFunctionFactMergeType_PrefersConcreteParamOverTopObservation(t *testing.T) { existing := typ.Func(). Param("x", typ.Any). Returns(typ.String). @@ -309,7 +310,7 @@ func TestMergeFunctionFactType_PrefersConcreteParamOverTopObservation(t *testing Returns(typ.String). 
Build() - merged := MergeFunctionFactType(existing, candidate) + merged := functionfact.MergeType(existing, candidate) fn, ok := merged.(*typ.Function) if !ok { t.Fatalf("expected merged function, got %T", merged) @@ -319,7 +320,7 @@ func TestMergeFunctionFactType_PrefersConcreteParamOverTopObservation(t *testing } } -func TestMergeFunctionFactType_WidensParamToCoverObservedCallsites(t *testing.T) { +func TestFunctionFactMergeType_WidensParamToCoverObservedCallsites(t *testing.T) { existing := typ.Func(). Param("t", typ.NewArray(typ.Any)). Returns(typ.String). @@ -329,7 +330,7 @@ func TestMergeFunctionFactType_WidensParamToCoverObservedCallsites(t *testing.T) Returns(typ.String). Build() - merged := MergeFunctionFactType(existing, candidate) + merged := functionfact.MergeType(existing, candidate) fn, ok := merged.(*typ.Function) if !ok { t.Fatalf("expected merged function, got %T", merged) @@ -346,7 +347,7 @@ func TestMergeFunctionFactType_WidensParamToCoverObservedCallsites(t *testing.T) } } -func TestMergeFunctionFactType_KeepsBaselineOverNestedNilOnlyRegression(t *testing.T) { +func TestFunctionFactMergeType_KeepsBaselineOverNestedNilOnlyRegression(t *testing.T) { baselineReturn := typ.NewRecord(). Field("full_path", typ.String). Field("parent", typ.Unknown). 
@@ -363,7 +364,7 @@ func TestMergeFunctionFactType_KeepsBaselineOverNestedNilOnlyRegression(t *testi baseline := typ.Func().Param("name", typ.Unknown).Returns(baselineReturn).Build() candidate := typ.Func().Param("name", typ.Unknown).Returns(candidateReturn).Build() - merged := MergeFunctionFactType(baseline, candidate) + merged := functionfact.MergeType(baseline, candidate) fn, ok := merged.(*typ.Function) if !ok || len(fn.Returns) != 1 { t.Fatalf("expected merged function return, got %v", merged) @@ -424,7 +425,7 @@ func TestReturnSummaryMerge_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *tes } } -func TestMergeFunctionFactType_CollapsesMixedFunctionUnionVariants(t *testing.T) { +func TestFunctionFactMergeType_CollapsesMixedFunctionUnionVariants(t *testing.T) { base := typ.Func(). Param("name", typ.Unknown). Returns(typ.NewRecord().Field("full_path", typ.String).SetOpen(true).Build()). @@ -446,7 +447,7 @@ func TestMergeFunctionFactType_CollapsesMixedFunctionUnionVariants(t *testing.T) Build()). 
Build() - merged := MergeFunctionFactType(typ.NewUnion(typ.Nil, base, withChildren), withTests) + merged := functionfact.MergeType(typ.NewUnion(typ.Nil, base, withChildren), withTests) if merged == nil { t.Fatal("expected merged type") } @@ -470,12 +471,12 @@ func TestMergeFunctionFactType_CollapsesMixedFunctionUnionVariants(t *testing.T) } } -func TestMergeFunctionFactType_DoesNotDropNonFunctionUnionMembers(t *testing.T) { +func TestFunctionFactMergeType_DoesNotDropNonFunctionUnionMembers(t *testing.T) { fn := typ.Func().Param("x", typ.String).Returns(typ.String).Build() existing := typ.NewUnion(fn, typ.Number) candidate := typ.Func().Param("x", typ.String).Returns(typ.String).Build() - merged := MergeFunctionFactType(existing, candidate) + merged := functionfact.MergeType(existing, candidate) u, ok := merged.(*typ.Union) if !ok { t.Fatalf("expected union to be preserved, got %T", merged) @@ -492,7 +493,7 @@ func TestMergeFunctionFactType_DoesNotDropNonFunctionUnionMembers(t *testing.T) } } -func TestMergeFunctionFactType_CollapsesCompatibleFunctionVariants(t *testing.T) { +func TestFunctionFactMergeType_CollapsesCompatibleFunctionVariants(t *testing.T) { base := typ.Func(). OptParam("entries", typ.Any). Returns(typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown))). @@ -503,7 +504,7 @@ func TestMergeFunctionFactType_CollapsesCompatibleFunctionVariants(t *testing.T) Returns(typ.NewMap(typ.String, typ.NewArray(refinedEntry))). Build() - merged := MergeFunctionFactType(base, refined) + merged := functionfact.MergeType(base, refined) fn, ok := merged.(*typ.Function) if !ok { t.Fatalf("expected function after compatible-variant collapse, got %T", merged) @@ -516,7 +517,7 @@ func TestMergeFunctionFactType_CollapsesCompatibleFunctionVariants(t *testing.T) } } -func TestMergeFunctionFactType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t *testing.T) { +func TestFunctionFactMergeType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t *testing.T) { existing := typ.Func(). 
OptParam("tests", typ.Nil). Returns(typ.Integer). @@ -526,7 +527,7 @@ func TestMergeFunctionFactType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t Returns(typ.Integer). Build() - merged := MergeFunctionFactType(existing, candidate) + merged := functionfact.MergeType(existing, candidate) fn, ok := merged.(*typ.Function) if !ok { t.Fatalf("expected function, got %T", merged) @@ -537,11 +538,11 @@ func TestMergeFunctionFactType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t } } -func TestMergeFunctionFactType_NilDoesNotDominateSoftOptionalParamShape(t *testing.T) { +func TestFunctionFactMergeType_NilDoesNotDominateSoftOptionalParamShape(t *testing.T) { softArray := typ.NewOptional(typ.NewUnion(typ.NewArray(typ.Any), typ.NewRecord().SetOpen(true).Build())) preciseArray := typ.NewOptional(typ.NewArray(typ.String)) - merged := MergeFunctionFactType( + merged := functionfact.MergeType( typ.Func().OptParam("tests", typ.Nil).Returns(typ.Integer).Build(), typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), ) @@ -553,7 +554,7 @@ func TestMergeFunctionFactType_NilDoesNotDominateSoftOptionalParamShape(t *testi t.Fatalf("expected nil observation not to replace soft optional table shape, got %v", fn.Params[0].Type) } - merged = MergeFunctionFactType( + merged = functionfact.MergeType( typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), typ.Func().OptParam("tests", preciseArray).Returns(typ.Integer).Build(), ) @@ -566,7 +567,7 @@ func TestMergeFunctionFactType_NilDoesNotDominateSoftOptionalParamShape(t *testi } } -func TestMergeFunctionFactType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *testing.T) { +func TestFunctionFactMergeType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *testing.T) { entry := typ.NewRecord().Field("id", typ.String).Build() stale := typ.NewRecord(). MapComponent(typ.NewUnion(typ.Boolean, typ.String), typ.NewArray(entry)). 
@@ -577,7 +578,7 @@ func TestMergeFunctionFactType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *t SetOpen(true). Build() - merged := MergeFunctionFactType( + merged := functionfact.MergeType( typ.Func().OptParam("t", stale).Returns(typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))).Build(), typ.Func().OptParam("t", current).Returns(typ.NewArray(typ.String)).Build(), ) diff --git a/compiler/check/returns/kernel.go b/compiler/check/returns/kernel.go deleted file mode 100644 index 19d9abc4..00000000 --- a/compiler/check/returns/kernel.go +++ /dev/null @@ -1,55 +0,0 @@ -package returns - -import ( - "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" - "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" - "github.com/wippyai/go-lua/types/typ/unwrap" -) - -// JoinFunctionFact precisely merges two observations for one local function -// during a single analysis iteration. -func JoinFunctionFact(existing, candidate api.FunctionFact) api.FunctionFact { - existing = NormalizeFunctionFact(existing) - candidate = NormalizeFunctionFact(candidate) - out := existing - - if len(candidate.Params) > 0 { - out.Params = paramevidence.JoinVectors(out.Params, candidate.Params) - } - if len(candidate.Summary) > 0 { - out.Summary = returnsummary.Merge(out.Summary, candidate.Summary) - } - if len(candidate.Narrow) > 0 { - out.Narrow = returnsummary.Merge(out.Narrow, candidate.Narrow) - } - if candidate.Type != nil { - out.Type = MergeFunctionFactType(out.Type, candidate.Type) - } - - // Keep summary and post-flow narrow results mutually refining when narrow - // provides first-order information. returnsummary.Merge is the canonical - // policy and already encodes directional refinement preference. 
- if len(out.Narrow) > 0 { - if len(out.Summary) == 0 { - out.Summary = returnsummary.Canonical(out.Narrow) - } else { - out.Summary = returnsummary.Merge(out.Summary, out.Narrow) - } - } - - if fn := unwrap.Function(out.Type); fn != nil { - alignedSummary := out.Summary - if len(alignedSummary) > 0 { - if aligned, changed := returnsummary.AlignFunction(fn, alignedSummary); changed { - out.Type = aligned - fn = aligned - } - } - if len(out.Summary) == 0 && fn != nil && len(fn.Returns) > 0 { - out.Summary = returnsummary.Canonical(fn.Returns) - } - } - - return out -} diff --git a/compiler/check/returns/kernel_test.go b/compiler/check/returns/kernel_test.go index f6745c83..a6251572 100644 --- a/compiler/check/returns/kernel_test.go +++ b/compiler/check/returns/kernel_test.go @@ -5,15 +5,16 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/typ" ) -func TestJoinFunctionFact_InitialObservation(t *testing.T) { +func TestFunctionFactJoin_InitialObservation(t *testing.T) { sym := cfg.SymbolID(11) fn := typ.Func().Returns(typ.String).Build() - facts := api.Facts{FunctionFacts: api.FunctionFacts{sym: JoinFunctionFact(api.FunctionFact{}, api.FunctionFact{ + facts := api.Facts{FunctionFacts: api.FunctionFacts{sym: functionfact.Join(api.FunctionFact{}, api.FunctionFact{ Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn, @@ -30,7 +31,7 @@ func TestJoinFunctionFact_InitialObservation(t *testing.T) { } } -func TestJoinFunctionFact_MergesExistingAndCandidate(t *testing.T) { +func TestFunctionFactJoin_MergesExistingAndCandidate(t *testing.T) { existingFn := typ.Func().Returns(typ.Number).Build() candidateFn := typ.Func().Returns(typ.String).Build() existing := api.FunctionFact{ @@ -43,7 +44,7 @@ func TestJoinFunctionFact_MergesExistingAndCandidate(t 
*testing.T) { Narrow: []typ.Type{typ.String}, Type: candidateFn, } - got := JoinFunctionFact(existing, candidate) + got := functionfact.Join(existing, candidate) if !returnsummary.Equal(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { t.Fatalf("summary mismatch: got %v", got.Summary) @@ -87,13 +88,13 @@ func TestJoinFacts_BatchMergeFunctionFacts(t *testing.T) { } } -func TestJoinFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { +func TestFunctionFactJoin_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { openTop := typ.NewRecord().SetOpen(true).Build() existingFunc := typ.Func().Returns(openTop).Build() candidateFunc := typ.Func().Returns(openTop).Build() narrow := []typ.Type{typ.NewArray(typ.Unknown)} - out := JoinFunctionFact( + out := functionfact.Join( api.FunctionFact{Summary: []typ.Type{openTop}, Type: existingFunc}, api.FunctionFact{Summary: []typ.Type{openTop}, Narrow: narrow, Type: candidateFunc}, ) @@ -111,7 +112,7 @@ func TestJoinFunctionFact_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) } } -func TestJoinFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { +func TestFunctionFactJoin_NarrowSummaryRepairsNeverArtifact(t *testing.T) { bad := []typ.Type{ typ.NewUnion( typ.NewRecord(). @@ -138,7 +139,7 @@ func TestJoinFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { } existingFunc := typ.Func().Returns(bad...).Build() - out := JoinFunctionFact( + out := functionfact.Join( api.FunctionFact{Summary: bad, Type: existingFunc}, api.FunctionFact{Narrow: good}, ) @@ -158,7 +159,7 @@ func TestJoinFunctionFact_NarrowSummaryRepairsNeverArtifact(t *testing.T) { } } -func TestJoinFunctionFact_DoesNotAlignFunctionToNarrowFieldRegression(t *testing.T) { +func TestFunctionFactJoin_DoesNotAlignFunctionToNarrowFieldRegression(t *testing.T) { withCapturedMethod := typ.NewRecord(). Field("x", typ.Integer). Field("get_x", typ.Func().Param("self", typ.Unknown).Returns(typ.Number).Build()). 
@@ -168,7 +169,7 @@ func TestJoinFunctionFact_DoesNotAlignFunctionToNarrowFieldRegression(t *testing Build() existingFunc := typ.Func().Returns(flowOnly).Build() - out := JoinFunctionFact( + out := functionfact.Join( api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, ) @@ -185,7 +186,7 @@ func TestJoinFunctionFact_DoesNotAlignFunctionToNarrowFieldRegression(t *testing } } -func TestNormalizeFunctionFacts_CanonicalizesStoredFunctionFacts(t *testing.T) { +func TestFunctionFactNormalize_CanonicalizesStoredFunctionFacts(t *testing.T) { sym := cfg.SymbolID(77) fn := typ.Func().Returns(typ.Number).Build() facts := &api.Facts{ @@ -194,7 +195,7 @@ func TestNormalizeFunctionFacts_CanonicalizesStoredFunctionFacts(t *testing.T) { }, } - NormalizeFunctionFacts(facts) + facts.FunctionFacts[sym] = functionfact.Normalize(facts.FunctionFacts[sym]) ff, ok := facts.FunctionFacts[sym] if !ok { diff --git a/compiler/check/returns/widen.go b/compiler/check/returns/widen.go index c208b4f2..eab99f6d 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/returns/widen.go @@ -3,6 +3,7 @@ package returns import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/domain/value" @@ -57,7 +58,7 @@ func JoinFacts(prev, next api.Facts) api.Facts { for _, sym := range symbols { prevFact := readFunctionFactFromFacts(&prev, sym) nextFact := readFunctionFactFromFacts(&next, sym) - writeNormalizedFunctionFactToFacts(&out, sym, JoinFunctionFact(prevFact, nextFact)) + writeNormalizedFunctionFactToFacts(&out, sym, functionfact.Join(prevFact, nextFact)) } return 
out } @@ -270,7 +271,7 @@ func joinInterprocValueType(existing, candidate typ.Type) typ.Type { return existing } if unwrap.Function(existing) != nil || unwrap.Function(candidate) != nil { - return MergeFunctionFactType(existing, candidate) + return functionfact.MergeType(existing, candidate) } return typ.JoinPreferNonSoft(existing, candidate) } @@ -330,7 +331,7 @@ func widenFunctionFactTypeForConvergence(existing, candidate typ.Type) typ.Type } existingFn := unwrap.Function(existing) candidateFn := unwrap.Function(candidate) - if existingFn != nil && candidateFn != nil && sameFunctionShapeForFactMerge(existingFn, candidateFn) { + if existingFn != nil && candidateFn != nil && functionfact.SameShape(existingFn, candidateFn) { return maybeWidenTypeForConvergence(widenFunctionFactsByShape(existingFn, candidateFn)) } return widenValueTypeForConvergence(existing, candidate) diff --git a/compiler/check/returns/widen_test.go b/compiler/check/returns/widen_test.go index 006c5fbd..48bbe76a 100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/returns/widen_test.go @@ -5,6 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -277,7 +278,7 @@ func TestMergeFunctionReturnsIfSameShape_GenericTypeParamsMustMatch(t *testing.T } } -func TestMergeFunctionFactType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { +func TestFunctionFactMergeType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { prev := typ.Func(). Returns(typ.NewOptional(typ.Integer)). Build() @@ -285,7 +286,7 @@ func TestMergeFunctionFactType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { Returns(typ.Nil). 
Build() - merged := MergeFunctionFactType(prev, next) + merged := functionfact.MergeType(prev, next) fn, ok := merged.(*typ.Function) if !ok || len(fn.Returns) != 1 { t.Fatalf("expected merged function return, got %T", merged) @@ -316,19 +317,19 @@ func TestMergeFunctionReturnsIfSameShape_NormalizesLeakedTypeParams(t *testing.T } } -func TestMergeFunctionFactType_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { - merged := MergeFunctionFactType(typ.Integer, typ.Number) +func TestFunctionFactMergeType_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { + merged := functionfact.MergeType(typ.Integer, typ.Number) if !typ.TypeEquals(merged, typ.Number) { t.Fatalf("expected wider supertype number, got %v", merged) } - merged = MergeFunctionFactType(typ.Number, typ.Integer) + merged = functionfact.MergeType(typ.Number, typ.Integer) if !typ.TypeEquals(merged, typ.Number) { t.Fatalf("expected wider supertype number, got %v", merged) } } -func TestMergeFunctionFactType_IsCommutativeForIncomparableSignatures(t *testing.T) { +func TestFunctionFactMergeType_IsCommutativeForIncomparableSignatures(t *testing.T) { coarse := typ.Func(). Param("entries", typ.Any). Returns(typ.Integer). @@ -338,14 +339,14 @@ func TestMergeFunctionFactType_IsCommutativeForIncomparableSignatures(t *testing Returns(typ.Integer). Build() - forward := MergeFunctionFactType(coarse, refined) - reverse := MergeFunctionFactType(refined, coarse) + forward := functionfact.MergeType(coarse, refined) + reverse := functionfact.MergeType(refined, coarse) if !typ.TypeEquals(forward, reverse) { t.Fatalf("expected commutative merge result, got forward=%v reverse=%v", forward, reverse) } } -func TestMergeFunctionFactType_AliasInputsUseCanonicalJoin(t *testing.T) { +func TestFunctionFactMergeType_AliasInputsUseCanonicalJoin(t *testing.T) { coarse := typ.NewAlias("CoarseFn", typ.Func(). Param("entries", typ.Any). Returns(typ.Integer). 
@@ -355,14 +356,14 @@ func TestMergeFunctionFactType_AliasInputsUseCanonicalJoin(t *testing.T) { Returns(typ.Integer). Build()) - forward := MergeFunctionFactType(coarse, refined) - reverse := MergeFunctionFactType(refined, coarse) + forward := functionfact.MergeType(coarse, refined) + reverse := functionfact.MergeType(refined, coarse) if !typ.TypeEquals(forward, reverse) { t.Fatalf("expected commutative alias merge result, got forward=%v reverse=%v", forward, reverse) } } -func TestMergeFunctionFactType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { +func TestFunctionFactMergeType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { coarse := typ.Func(). Param("t", typ.NewRecord().SetOpen(true).Build()). Returns(typ.String). @@ -372,8 +373,8 @@ func TestMergeFunctionFactType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { Returns(typ.String). Build() - forward := MergeFunctionFactType(coarse, refined) - reverse := MergeFunctionFactType(refined, coarse) + forward := functionfact.MergeType(coarse, refined) + reverse := functionfact.MergeType(refined, coarse) if !typ.TypeEquals(forward, reverse) { t.Fatalf("expected commutative map/open-record merge result, got forward=%v reverse=%v", forward, reverse) } diff --git a/compiler/check/siblings/siblings.go b/compiler/check/siblings/siblings.go index 24263a0d..07c7b5ae 100644 --- a/compiler/check/siblings/siblings.go +++ b/compiler/check/siblings/siblings.go @@ -38,7 +38,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -183,7 +183,7 @@ func Build(c BuildConfig) map[cfg.SymbolID]typ.Type { if fnType == nil { continue } - result[entry.Symbol] = returns.MergeFunctionFactType(result[entry.Symbol], fnType) + result[entry.Symbol] = 
functionfact.MergeType(result[entry.Symbol], fnType) } if len(result) == 0 { diff --git a/compiler/check/siblings/siblings_test.go b/compiler/check/siblings/siblings_test.go index 9a82b982..a29ea30f 100644 --- a/compiler/check/siblings/siblings_test.go +++ b/compiler/check/siblings/siblings_test.go @@ -5,7 +5,7 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/types/typ" ) @@ -51,21 +51,21 @@ func TestBuild_WithPrev(t *testing.T) { } func TestMergeSiblingType_BothNilViaBuildAPI(t *testing.T) { - result := returns.MergeFunctionFactType(nil, nil) + result := functionfact.MergeType(nil, nil) if result != nil { t.Error("both nil should return nil") } } func TestMergeSiblingType_PrevNilViaBuildAPI(t *testing.T) { - result := returns.MergeFunctionFactType(nil, typ.String) + result := functionfact.MergeType(nil, typ.String) if result != typ.String { t.Error("prev nil should return next") } } func TestMergeSiblingType_NextNilViaBuildAPI(t *testing.T) { - result := returns.MergeFunctionFactType(typ.String, nil) + result := functionfact.MergeType(typ.String, nil) if result != typ.String { t.Error("next nil should return prev") } @@ -74,7 +74,7 @@ func TestMergeSiblingType_NextNilViaBuildAPI(t *testing.T) { func TestMergeSiblingType_FunctionsViaBuildAPI(t *testing.T) { prevFn := typ.Func().Build() nextFn := typ.Func().Returns(typ.String).Build() - result := returns.MergeFunctionFactType(prevFn, nextFn) + result := functionfact.MergeType(prevFn, nextFn) if result == nil { t.Fatal("should return merged function") } @@ -90,7 +90,7 @@ func TestMergeSiblingType_FunctionsViaBuildAPI(t *testing.T) { func TestMergeSiblingType_FunctionAliasesViaBuildAPI(t *testing.T) { prevFn := typ.NewAlias("Prev", typ.Func().Build()) nextFn := typ.NewAlias("Next", typ.Func().Returns(typ.String).Build()) - 
result := returns.MergeFunctionFactType(prevFn, nextFn) + result := functionfact.MergeType(prevFn, nextFn) if !typ.TypeEquals(result, nextFn) { t.Fatalf("expected function alias with returns to be preferred, got %v", result) } From ebc1050e3cf295dfaf13f6d97b0cbac0cc5cf4db Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 02:56:12 -0400 Subject: [PATCH 22/71] Move fact product into domain --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 66 ++++ compiler/check/api/facts_test.go | 11 + .../factproduct/captured_field.go} | 2 +- .../factproduct/container_mutation.go} | 2 +- .../factproduct/container_mutation_test.go} | 2 +- compiler/check/domain/factproduct/doc.go | 9 + .../factproduct}/domain_law_test.go | 2 +- .../{returns => domain/factproduct}/equal.go | 2 +- .../factproduct}/equal_test.go | 2 +- .../factproduct/function_fact_product_test.go | 41 ++ .../factproduct}/function_facts.go | 2 +- .../factproduct/product.go} | 2 +- .../factproduct/product_test.go} | 84 +--- .../check/domain/functionfact/fact_test.go | 301 ++++++++++++++ .../returnsummary}/join_test.go | 366 +++--------------- compiler/check/returns/doc.go | 18 +- compiler/check/returns/kernel_test.go | 235 ----------- compiler/check/store/snapshot_inputs.go | 8 +- compiler/check/store/store.go | 14 +- 19 files changed, 507 insertions(+), 662 deletions(-) rename compiler/check/{returns/captured_field_merge.go => domain/factproduct/captured_field.go} (98%) rename compiler/check/{returns/container_mutation_merge.go => domain/factproduct/container_mutation.go} (99%) rename compiler/check/{returns/container_mutation_merge_test.go => domain/factproduct/container_mutation_test.go} (99%) create mode 100644 compiler/check/domain/factproduct/doc.go rename compiler/check/{returns => domain/factproduct}/domain_law_test.go (99%) rename compiler/check/{returns => domain/factproduct}/equal.go (99%) rename compiler/check/{returns => domain/factproduct}/equal_test.go (99%) create mode 100644 
compiler/check/domain/factproduct/function_fact_product_test.go rename compiler/check/{returns => domain/factproduct}/function_facts.go (98%) rename compiler/check/{returns/widen.go => domain/factproduct/product.go} (99%) rename compiler/check/{returns/widen_test.go => domain/factproduct/product_test.go} (79%) rename compiler/check/{returns => domain/returnsummary}/join_test.go (59%) delete mode 100644 compiler/check/returns/kernel_test.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index de4c189e..60c2daa3 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -279,6 +279,72 @@ Verification for this slice so far: targets: session 8 errors, agent/src 10 errors, docker-demo 21 errors and 2 warnings. +## 2026-05-19 Fact Product Domain Checkpoint + +The next rectification slice moved the whole interprocedural fact product out of +`compiler/check/returns` and into `compiler/check/domain/factproduct`. + +Moved product laws: + +- `api.Facts` equality; +- same-iteration product join; +- recursive-boundary product widening; +- function-fact map canonicalization and deterministic symbol enumeration; +- literal signature join/widen; +- captured type join/widen; +- captured field assignment join/widen; +- captured container mutation join/widen; +- constructor-field join/widen; +- deterministic captured field/container equality and merge helpers. + +Production callers now import `domain/factproduct` directly. The old +`returns.WidenFacts`, `returns.JoinFacts`, `returns.FactsEqual`, +`returns.ConstructorFieldsEqual`, `returns.WidenLiteralSigs`, +`returns.JoinLiteralSigs`, captured-fact join/widen/equality helpers, and +captured merge helpers were deleted from `returns` instead of wrapped. 
+ +Test ownership was rectified at the same time: + +- return-vector and return-summary law tests moved to `domain/returnsummary`; +- one-function fact join/type-merge tests moved to `domain/functionfact`; +- whole-product tests moved to `domain/factproduct`; +- `returns` keeps only local return orchestration tests. + +Current package ownership: + +```text +domain/value = reusable structural value relations +domain/paramevidence = parameter evidence lattice, equality, and parameter-slot refinement +domain/returnsummary = return-vector lattice and function-return alignment +domain/functionfact = one-function fact normalization, join, and type projection +domain/factproduct = whole api.Facts product join, widening, equality, and map domains +returns = local return SCC orchestration, call graph, overlays, signature seeding +store = snapshot/Salsa wiring and fixpoint publication +``` + +This separates the abstract interpreter more cleanly: + +1. local inference produces function and mutation deltas; +2. `domain/functionfact` and the other slot domains define one-slot meaning; +3. `domain/factproduct` defines how the whole interprocedural product combines; +4. the store decides when to apply join or widening and when to publish Salsa + snapshot inputs; +5. `returns` no longer owns cross-graph product laws. + +Verification for this slice so far: + +- `go test ./compiler/check/domain/factproduct ./compiler/check/store + ./compiler/check/returns ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. +- `go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction + -benchmem -count=3` reports 1.17-1.19 ms/op, 881 KB/op, and 9390 allocs/op + on this machine. +- Standard `../scripts/verify-suite.sh` passes go-lua checker tests and builds + the Wippy binary, then exits non-zero on the known external pinned lint + targets: session 8 errors, agent/src 10 errors, docker-demo 21 errors and + 2 warnings. 
+ ## Goal The checker should read as one abstract interpreter over a product domain. diff --git a/compiler/check/api/facts_test.go b/compiler/check/api/facts_test.go index 5fb90318..06ad4473 100644 --- a/compiler/check/api/facts_test.go +++ b/compiler/check/api/facts_test.go @@ -38,6 +38,17 @@ func TestFunctionFacts_Summary(t *testing.T) { } } +func TestFunctionFacts_NarrowSummary(t *testing.T) { + facts := make(FunctionFacts) + sym := cfg.SymbolID(1) + facts[sym] = FunctionFact{Narrow: []typ.Type{typ.String}} + + rets := facts.NarrowSummary(sym) + if len(rets) != 1 || !typ.TypeEquals(rets[0], typ.String) { + t.Fatalf("expected narrow string return, got %v", rets) + } +} + func TestFunctionFacts_Params(t *testing.T) { facts := make(FunctionFacts) sym := cfg.SymbolID(1) diff --git a/compiler/check/returns/captured_field_merge.go b/compiler/check/domain/factproduct/captured_field.go similarity index 98% rename from compiler/check/returns/captured_field_merge.go rename to compiler/check/domain/factproduct/captured_field.go index 8cdeca2e..9e11283f 100644 --- a/compiler/check/returns/captured_field_merge.go +++ b/compiler/check/domain/factproduct/captured_field.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "github.com/wippyai/go-lua/compiler/cfg" diff --git a/compiler/check/returns/container_mutation_merge.go b/compiler/check/domain/factproduct/container_mutation.go similarity index 99% rename from compiler/check/returns/container_mutation_merge.go rename to compiler/check/domain/factproduct/container_mutation.go index 6fb3943b..cbf131d5 100644 --- a/compiler/check/returns/container_mutation_merge.go +++ b/compiler/check/domain/factproduct/container_mutation.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "github.com/wippyai/go-lua/compiler/cfg" diff --git a/compiler/check/returns/container_mutation_merge_test.go b/compiler/check/domain/factproduct/container_mutation_test.go similarity index 99% rename from 
compiler/check/returns/container_mutation_merge_test.go rename to compiler/check/domain/factproduct/container_mutation_test.go index f9370695..1d9f59d9 100644 --- a/compiler/check/returns/container_mutation_merge_test.go +++ b/compiler/check/domain/factproduct/container_mutation_test.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "testing" diff --git a/compiler/check/domain/factproduct/doc.go b/compiler/check/domain/factproduct/doc.go new file mode 100644 index 00000000..755e4651 --- /dev/null +++ b/compiler/check/domain/factproduct/doc.go @@ -0,0 +1,9 @@ +// Package factproduct owns the interprocedural facts product domain. +// +// It canonicalizes, joins, widens, and compares api.Facts bundles. Lower-level +// domains own individual slots: functionfact for one FunctionFact, +// returnsummary for return vectors, paramevidence for parameter evidence, and +// value for structural value relations. This package owns the product-level +// shape across graph facts, captured types, captured field writes, captured +// container mutations, constructor fields, and literal signatures. 
+package factproduct diff --git a/compiler/check/returns/domain_law_test.go b/compiler/check/domain/factproduct/domain_law_test.go similarity index 99% rename from compiler/check/returns/domain_law_test.go rename to compiler/check/domain/factproduct/domain_law_test.go index 080835d3..240f833e 100644 --- a/compiler/check/returns/domain_law_test.go +++ b/compiler/check/domain/factproduct/domain_law_test.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "testing" diff --git a/compiler/check/returns/equal.go b/compiler/check/domain/factproduct/equal.go similarity index 99% rename from compiler/check/returns/equal.go rename to compiler/check/domain/factproduct/equal.go index fda479b5..3152af34 100644 --- a/compiler/check/returns/equal.go +++ b/compiler/check/domain/factproduct/equal.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "github.com/wippyai/go-lua/compiler/cfg" diff --git a/compiler/check/returns/equal_test.go b/compiler/check/domain/factproduct/equal_test.go similarity index 99% rename from compiler/check/returns/equal_test.go rename to compiler/check/domain/factproduct/equal_test.go index 269fb25a..048b3e6c 100644 --- a/compiler/check/returns/equal_test.go +++ b/compiler/check/domain/factproduct/equal_test.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "testing" diff --git a/compiler/check/domain/factproduct/function_fact_product_test.go b/compiler/check/domain/factproduct/function_fact_product_test.go new file mode 100644 index 00000000..9ab9b988 --- /dev/null +++ b/compiler/check/domain/factproduct/function_fact_product_test.go @@ -0,0 +1,41 @@ +package factproduct + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" + "github.com/wippyai/go-lua/types/typ" +) + +func TestJoinFacts_BatchMergeFunctionFacts(t *testing.T) { + symSummary := cfg.SymbolID(21) + symNarrow := 
cfg.SymbolID(22) + symFunc := cfg.SymbolID(23) + funcType := typ.Func().Returns(typ.Boolean).Build() + + facts := JoinFacts( + api.Facts{ + FunctionFacts: api.FunctionFacts{ + symSummary: {Summary: []typ.Type{typ.String}}, + symNarrow: {Narrow: []typ.Type{typ.Number}}, + }, + }, + api.Facts{ + FunctionFacts: api.FunctionFacts{ + symFunc: {Type: funcType}, + }, + }, + ) + + if got := facts.FunctionFacts.Summary(symSummary); !returnsummary.Equal(got, []typ.Type{typ.String}) { + t.Fatalf("summary mismatch: got %v", got) + } + if got := facts.FunctionFacts.NarrowSummary(symNarrow); !returnsummary.Equal(got, []typ.Type{typ.Number}) { + t.Fatalf("narrow mismatch: got %v", got) + } + if got := facts.FunctionFacts.FunctionType(symFunc); !typ.TypeEquals(got, funcType) { + t.Fatalf("func mismatch: got %v", got) + } +} diff --git a/compiler/check/returns/function_facts.go b/compiler/check/domain/factproduct/function_facts.go similarity index 98% rename from compiler/check/returns/function_facts.go rename to compiler/check/domain/factproduct/function_facts.go index aaab7799..15b459e0 100644 --- a/compiler/check/returns/function_facts.go +++ b/compiler/check/domain/factproduct/function_facts.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "github.com/wippyai/go-lua/compiler/cfg" diff --git a/compiler/check/returns/widen.go b/compiler/check/domain/factproduct/product.go similarity index 99% rename from compiler/check/returns/widen.go rename to compiler/check/domain/factproduct/product.go index eab99f6d..0a3feea2 100644 --- a/compiler/check/returns/widen.go +++ b/compiler/check/domain/factproduct/product.go @@ -1,4 +1,4 @@ -package returns +package factproduct import ( "github.com/wippyai/go-lua/compiler/cfg" diff --git a/compiler/check/returns/widen_test.go b/compiler/check/domain/factproduct/product_test.go similarity index 79% rename from compiler/check/returns/widen_test.go rename to compiler/check/domain/factproduct/product_test.go index 48bbe76a..26903e9e 
100644 --- a/compiler/check/returns/widen_test.go +++ b/compiler/check/domain/factproduct/product_test.go @@ -1,11 +1,10 @@ -package returns +package factproduct import ( "testing" "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/domain/functionfact" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -278,24 +277,6 @@ func TestMergeFunctionReturnsIfSameShape_GenericTypeParamsMustMatch(t *testing.T } } -func TestFunctionFactMergeType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { - prev := typ.Func(). - Returns(typ.NewOptional(typ.Integer)). - Build() - next := typ.Func(). - Returns(typ.Nil). - Build() - - merged := functionfact.MergeType(prev, next) - fn, ok := merged.(*typ.Function) - if !ok || len(fn.Returns) != 1 { - t.Fatalf("expected merged function return, got %T", merged) - } - if !typ.TypeEquals(fn.Returns[0], typ.NewOptional(typ.Integer)) { - t.Fatalf("expected integer? return after merge, got %v", fn.Returns[0]) - } -} - func TestMergeFunctionReturnsIfSameShape_NormalizesLeakedTypeParams(t *testing.T) { prev := typ.Func(). Returns(typ.NewTypeParam("T", nil)). @@ -317,69 +298,6 @@ func TestMergeFunctionReturnsIfSameShape_NormalizesLeakedTypeParams(t *testing.T } } -func TestFunctionFactMergeType_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { - merged := functionfact.MergeType(typ.Integer, typ.Number) - if !typ.TypeEquals(merged, typ.Number) { - t.Fatalf("expected wider supertype number, got %v", merged) - } - - merged = functionfact.MergeType(typ.Number, typ.Integer) - if !typ.TypeEquals(merged, typ.Number) { - t.Fatalf("expected wider supertype number, got %v", merged) - } -} - -func TestFunctionFactMergeType_IsCommutativeForIncomparableSignatures(t *testing.T) { - coarse := typ.Func(). - Param("entries", typ.Any). - Returns(typ.Integer). 
- Build() - refined := typ.Func(). - Param("entries", typ.NewArray(typ.String)). - Returns(typ.Integer). - Build() - - forward := functionfact.MergeType(coarse, refined) - reverse := functionfact.MergeType(refined, coarse) - if !typ.TypeEquals(forward, reverse) { - t.Fatalf("expected commutative merge result, got forward=%v reverse=%v", forward, reverse) - } -} - -func TestFunctionFactMergeType_AliasInputsUseCanonicalJoin(t *testing.T) { - coarse := typ.NewAlias("CoarseFn", typ.Func(). - Param("entries", typ.Any). - Returns(typ.Integer). - Build()) - refined := typ.NewAlias("RefinedFn", typ.Func(). - Param("entries", typ.NewArray(typ.String)). - Returns(typ.Integer). - Build()) - - forward := functionfact.MergeType(coarse, refined) - reverse := functionfact.MergeType(refined, coarse) - if !typ.TypeEquals(forward, reverse) { - t.Fatalf("expected commutative alias merge result, got forward=%v reverse=%v", forward, reverse) - } -} - -func TestFunctionFactMergeType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { - coarse := typ.Func(). - Param("t", typ.NewRecord().SetOpen(true).Build()). - Returns(typ.String). - Build() - refined := typ.Func(). - Param("t", typ.NewMap(typ.String, typ.NewArray(typ.String))). - Returns(typ.String). 
- Build() - - forward := functionfact.MergeType(coarse, refined) - reverse := functionfact.MergeType(refined, coarse) - if !typ.TypeEquals(forward, reverse) { - t.Fatalf("expected commutative map/open-record merge result, got forward=%v reverse=%v", forward, reverse) - } -} - func TestWidenLiteralSigs_DoesNotNarrowComparableSignature(t *testing.T) { lit := &ast.FunctionExpr{} diff --git a/compiler/check/domain/functionfact/fact_test.go b/compiler/check/domain/functionfact/fact_test.go index 6d51f3e7..54648774 100644 --- a/compiler/check/domain/functionfact/fact_test.go +++ b/compiler/check/domain/functionfact/fact_test.go @@ -100,6 +100,59 @@ func TestJoin_NarrowSummaryRepairsNeverArtifact(t *testing.T) { } } +func TestJoin_MergesExistingAndCandidate(t *testing.T) { + existingFn := typ.Func().Returns(typ.Number).Build() + candidateFn := typ.Func().Returns(typ.String).Build() + existing := api.FunctionFact{ + Summary: []typ.Type{typ.Number}, + Narrow: []typ.Type{typ.Number}, + Type: existingFn, + } + candidate := api.FunctionFact{ + Summary: []typ.Type{typ.String}, + Narrow: []typ.Type{typ.String}, + Type: candidateFn, + } + got := Join(existing, candidate) + + if !returnsummary.Equal(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + t.Fatalf("summary mismatch: got %v", got.Summary) + } + if !returnsummary.Equal(got.Narrow, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { + t.Fatalf("narrow mismatch: got %v", got.Narrow) + } + if got.Type == nil { + t.Fatal("expected merged function type") + } +} + +func TestJoin_DoesNotAlignFunctionToNarrowFieldRegression(t *testing.T) { + withCapturedMethod := typ.NewRecord(). + Field("x", typ.Integer). + Field("get_x", typ.Func().Param("self", typ.Unknown).Returns(typ.Number).Build()). + Build() + flowOnly := typ.NewRecord(). + Field("x", typ.Integer). 
+ Build() + existingFunc := typ.Func().Returns(flowOnly).Build() + + out := Join( + api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, + api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, + ) + + if !returnsummary.Equal(out.Summary, []typ.Type{withCapturedMethod}) { + t.Fatalf("summary mismatch: got %v want %v", out.Summary, []typ.Type{withCapturedMethod}) + } + fn, ok := out.Type.(*typ.Function) + if !ok { + t.Fatalf("expected function fact, got %T", out.Type) + } + if !returnsummary.Equal(fn.Returns, []typ.Type{withCapturedMethod}) { + t.Fatalf("func returns should preserve captured method summary, got %v", fn.Returns) + } +} + func TestMergeType_MergesSameShapeReturnsCanonically(t *testing.T) { existing := typ.Func(). Param("x", typ.String). @@ -147,6 +200,53 @@ func TestMergeType_WidensParamToCoverObservedCallsites(t *testing.T) { } } +func TestMergeType_PrefersConcreteParamOverTopObservation(t *testing.T) { + existing := typ.Func(). + Param("x", typ.Any). + Returns(typ.String). + Build() + candidate := typ.Func(). + Param("x", typ.String). + Returns(typ.String). + Build() + + merged := MergeType(existing, candidate) + fn, ok := merged.(*typ.Function) + if !ok { + t.Fatalf("expected merged function, got %T", merged) + } + if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, typ.String) { + t.Fatalf("expected param refined to string, got %+v", fn.Params) + } +} + +func TestMergeType_KeepsBaselineOverNestedNilOnlyRegression(t *testing.T) { + baselineReturn := typ.NewRecord(). + Field("full_path", typ.String). + Field("parent", typ.Unknown). + OptField("after_all", typ.Nil). + SetOpen(true). + Build() + candidateReturn := typ.NewRecord(). + Field("full_path", typ.String). + Field("parent", typ.Nil). + Field("after_all", typ.Nil). + SetOpen(true). 
+ Build() + + baseline := typ.Func().Param("name", typ.Unknown).Returns(baselineReturn).Build() + candidate := typ.Func().Param("name", typ.Unknown).Returns(candidateReturn).Build() + + merged := MergeType(baseline, candidate) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Returns) != 1 { + t.Fatalf("expected merged function return, got %v", merged) + } + if !typ.TypeEquals(fn.Returns[0], baselineReturn) { + t.Fatalf("expected baseline record to survive nil-only refinement, got %v", fn.Returns[0]) + } +} + func TestMergeType_CollapsesMixedFunctionUnionVariants(t *testing.T) { base := typ.Func(). Param("name", typ.Unknown). @@ -191,6 +291,207 @@ func TestMergeType_CollapsesMixedFunctionUnionVariants(t *testing.T) { } } +func TestMergeType_DoesNotDropNonFunctionUnionMembers(t *testing.T) { + fn := typ.Func().Param("x", typ.String).Returns(typ.String).Build() + existing := typ.NewUnion(fn, typ.Number) + candidate := typ.Func().Param("x", typ.String).Returns(typ.String).Build() + + merged := MergeType(existing, candidate) + u, ok := merged.(*typ.Union) + if !ok { + t.Fatalf("expected union to be preserved, got %T", merged) + } + hasNumber := false + for _, m := range u.Members { + if typ.TypeEquals(m, typ.Number) { + hasNumber = true + break + } + } + if !hasNumber { + t.Fatalf("expected merged union to retain non-function member, got %v", merged) + } +} + +func TestMergeType_CollapsesCompatibleFunctionVariants(t *testing.T) { + base := typ.Func(). + OptParam("entries", typ.Any). + Returns(typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown))). + Build() + refinedEntry := typ.NewRecord().Field("id", typ.String).Build() + refined := typ.Func(). + OptParam("entries", typ.NewArray(refinedEntry)). + Returns(typ.NewMap(typ.String, typ.NewArray(refinedEntry))). 
+ Build() + + merged := MergeType(base, refined) + fn, ok := merged.(*typ.Function) + if !ok { + t.Fatalf("expected function after compatible-variant collapse, got %T", merged) + } + if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, typ.NewArray(refinedEntry)) { + t.Fatalf("expected refined param type to win, got %+v", fn.Params) + } + if len(fn.Returns) != 1 || !typ.TypeEquals(fn.Returns[0], typ.NewMap(typ.String, typ.NewArray(refinedEntry))) { + t.Fatalf("expected refined return map, got %v", fn.Returns) + } +} + +func TestMergeType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t *testing.T) { + existing := typ.Func(). + OptParam("tests", typ.Nil). + Returns(typ.Integer). + Build() + candidate := typ.Func(). + OptParam("tests", typ.NewOptional(typ.NewArray(typ.Any))). + Returns(typ.Integer). + Build() + + merged := MergeType(existing, candidate) + fn, ok := merged.(*typ.Function) + if !ok { + t.Fatalf("expected function, got %T", merged) + } + want := typ.NewOptional(typ.NewArray(typ.Any)) + if len(fn.Params) != 1 || !fn.Params[0].Optional || !typ.TypeEquals(fn.Params[0].Type, want) { + t.Fatalf("expected optional param slot with type %v, got %+v", want, fn.Params) + } +} + +func TestMergeType_NilDoesNotDominateSoftOptionalParamShape(t *testing.T) { + softArray := typ.NewOptional(typ.NewUnion(typ.NewArray(typ.Any), typ.NewRecord().SetOpen(true).Build())) + preciseArray := typ.NewOptional(typ.NewArray(typ.String)) + + merged := MergeType( + typ.Func().OptParam("tests", typ.Nil).Returns(typ.Integer).Build(), + typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), + ) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, softArray) { + t.Fatalf("expected nil observation not to replace soft optional table shape, got %v", fn.Params[0].Type) + } + + merged = MergeType( + typ.Func().OptParam("tests", 
softArray).Returns(typ.Integer).Build(), + typ.Func().OptParam("tests", preciseArray).Returns(typ.Integer).Build(), + ) + fn, ok = merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, preciseArray) { + t.Fatalf("expected precise optional array evidence to replace soft shape, got %v", fn.Params[0].Type) + } +} + +func TestMergeType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *testing.T) { + entry := typ.NewRecord().Field("id", typ.String).Build() + stale := typ.NewRecord(). + MapComponent(typ.NewUnion(typ.Boolean, typ.String), typ.NewArray(entry)). + SetOpen(true). + Build() + current := typ.NewRecord(). + MapComponent(typ.String, typ.NewArray(entry)). + SetOpen(true). + Build() + + merged := MergeType( + typ.Func().OptParam("t", stale).Returns(typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))).Build(), + typ.Func().OptParam("t", current).Returns(typ.NewArray(typ.String)).Build(), + ) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Params) != 1 { + t.Fatalf("expected merged function, got %T", merged) + } + if !typ.TypeEquals(fn.Params[0].Type, current) { + t.Fatalf("expected truthy-refined map key param %v, got %v", current, fn.Params[0].Type) + } +} + +func TestMergeType_DoesNotRegressToNarrowerNilReturn(t *testing.T) { + prev := typ.Func(). + Returns(typ.NewOptional(typ.Integer)). + Build() + next := typ.Func(). + Returns(typ.Nil). + Build() + + merged := MergeType(prev, next) + fn, ok := merged.(*typ.Function) + if !ok || len(fn.Returns) != 1 { + t.Fatalf("expected merged function return, got %T", merged) + } + if !typ.TypeEquals(fn.Returns[0], typ.NewOptional(typ.Integer)) { + t.Fatalf("expected integer? 
return after merge, got %v", fn.Returns[0]) + } +} + +func TestMergeType_PrefersWiderSupertypeOnSubtypeRelation(t *testing.T) { + merged := MergeType(typ.Integer, typ.Number) + if !typ.TypeEquals(merged, typ.Number) { + t.Fatalf("expected wider supertype number, got %v", merged) + } + + merged = MergeType(typ.Number, typ.Integer) + if !typ.TypeEquals(merged, typ.Number) { + t.Fatalf("expected wider supertype number, got %v", merged) + } +} + +func TestMergeType_IsCommutativeForIncomparableSignatures(t *testing.T) { + coarse := typ.Func(). + Param("entries", typ.Any). + Returns(typ.Integer). + Build() + refined := typ.Func(). + Param("entries", typ.NewArray(typ.String)). + Returns(typ.Integer). + Build() + + forward := MergeType(coarse, refined) + reverse := MergeType(refined, coarse) + if !typ.TypeEquals(forward, reverse) { + t.Fatalf("expected commutative merge result, got forward=%v reverse=%v", forward, reverse) + } +} + +func TestMergeType_AliasInputsUseCanonicalJoin(t *testing.T) { + coarse := typ.NewAlias("CoarseFn", typ.Func(). + Param("entries", typ.Any). + Returns(typ.Integer). + Build()) + refined := typ.NewAlias("RefinedFn", typ.Func(). + Param("entries", typ.NewArray(typ.String)). + Returns(typ.Integer). + Build()) + + forward := MergeType(coarse, refined) + reverse := MergeType(refined, coarse) + if !typ.TypeEquals(forward, reverse) { + t.Fatalf("expected commutative alias merge result, got forward=%v reverse=%v", forward, reverse) + } +} + +func TestMergeType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { + coarse := typ.Func(). + Param("t", typ.NewRecord().SetOpen(true).Build()). + Returns(typ.String). + Build() + refined := typ.Func(). + Param("t", typ.NewMap(typ.String, typ.NewArray(typ.String))). + Returns(typ.String). 
+ Build() + + forward := MergeType(coarse, refined) + reverse := MergeType(refined, coarse) + if !typ.TypeEquals(forward, reverse) { + t.Fatalf("expected commutative map/open-record merge result, got forward=%v reverse=%v", forward, reverse) + } +} + func TestNormalize_CanonicalizesStoredFunctionFact(t *testing.T) { fn := typ.Func().Returns(typ.Number).Build() got := Normalize(api.FunctionFact{ diff --git a/compiler/check/returns/join_test.go b/compiler/check/domain/returnsummary/join_test.go similarity index 59% rename from compiler/check/returns/join_test.go rename to compiler/check/domain/returnsummary/join_test.go index f4dbe494..8849a6b6 100644 --- a/compiler/check/returns/join_test.go +++ b/compiler/check/domain/returnsummary/join_test.go @@ -1,15 +1,9 @@ -package returns +package returnsummary import ( - "testing" - - "github.com/wippyai/go-lua/compiler/check/domain/functionfact" - "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" - "github.com/wippyai/go-lua/types/kind" - "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" typjoin "github.com/wippyai/go-lua/types/typ/join" - "github.com/wippyai/go-lua/types/typ/unwrap" + "testing" ) func TestJoinReturnVectors_Empty(t *testing.T) { @@ -57,13 +51,13 @@ func TestTypJoinReturnSlot_PreservesUnknownOverNil(t *testing.T) { } func TestReturnSummaryAllNil(t *testing.T) { - if !returnsummary.AllNil([]typ.Type{typ.Nil}) { + if !AllNil([]typ.Type{typ.Nil}) { t.Fatal("expected [nil] to be nil-only") } - if returnsummary.AllNil([]typ.Type{typ.Nil, typ.Unknown}) { + if AllNil([]typ.Type{typ.Nil, typ.Unknown}) { t.Fatal("expected [nil, unknown] to not be nil-only") } - if returnsummary.AllNil(nil) { + if AllNil(nil) { t.Fatal("expected empty return vector to not be nil-only") } } @@ -78,7 +72,7 @@ func TestJoinReturnVectors_DifferentLengths(t *testing.T) { } func TestReturnSummaryEqual_Empty(t *testing.T) { - if !returnsummary.Equal(nil, nil) { + if !Equal(nil, nil) { t.Error("nil 
slices should be equal") } } @@ -86,7 +80,7 @@ func TestReturnSummaryEqual_Empty(t *testing.T) { func TestReturnSummaryEqual_DifferentLength(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.String, typ.Number} - if returnsummary.Equal(a, b) { + if Equal(a, b) { t.Error("different lengths should not be equal") } } @@ -94,7 +88,7 @@ func TestReturnSummaryEqual_DifferentLength(t *testing.T) { func TestReturnSummaryEqual_Same(t *testing.T) { a := []typ.Type{typ.String, typ.Number} b := []typ.Type{typ.String, typ.Number} - if !returnsummary.Equal(a, b) { + if !Equal(a, b) { t.Error("same types should be equal") } } @@ -102,21 +96,21 @@ func TestReturnSummaryEqual_Same(t *testing.T) { func TestReturnSummaryEqual_Different(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.Number} - if returnsummary.Equal(a, b) { + if Equal(a, b) { t.Error("different types should not be equal") } } func TestReturnSummaryRefines_EmptyA(t *testing.T) { b := []typ.Type{typ.String} - if returnsummary.Refines(nil, b) { + if Refines(nil, b) { t.Error("empty a should not refine b") } } func TestReturnSummaryRefines_EmptyB(t *testing.T) { a := []typ.Type{typ.String} - if !returnsummary.Refines(a, nil) { + if !Refines(a, nil) { t.Error("a should refine empty b") } } @@ -124,7 +118,7 @@ func TestReturnSummaryRefines_EmptyB(t *testing.T) { func TestReturnSummaryRefines_Same(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.String} - if !returnsummary.Refines(a, b) { + if !Refines(a, b) { t.Error("same types should refine") } } @@ -132,7 +126,7 @@ func TestReturnSummaryRefines_Same(t *testing.T) { func TestReturnSummaryRefines_DifferentLength(t *testing.T) { a := []typ.Type{typ.String, typ.Number} b := []typ.Type{typ.String} - if returnsummary.Refines(a, b) { + if Refines(a, b) { t.Error("different lengths should not refine") } } @@ -141,14 +135,14 @@ func TestReturnSummaryMerge_ReplacesStaleFalsyKeyArrayElement(t *testing.T) { stale := 
[]typ.Type{typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))} current := []typ.Type{typ.NewArray(typ.String)} - got := returnsummary.Merge(stale, current) - if !returnsummary.Equal(got, current) { + got := Merge(stale, current) + if !Equal(got, current) { t.Fatalf("expected truthy-refined key array %v, got %v", current, got) } } func TestReturnSummaryExtendsRecord_Empty(t *testing.T) { - if returnsummary.ExtendsRecord(nil, nil) { + if ExtendsRecord(nil, nil) { t.Error("empty vectors should not extend") } } @@ -156,7 +150,7 @@ func TestReturnSummaryExtendsRecord_Empty(t *testing.T) { func TestReturnSummaryExtendsRecord_NotRecords(t *testing.T) { a := []typ.Type{typ.String} b := []typ.Type{typ.String} - if returnsummary.ExtendsRecord(a, b) { + if ExtendsRecord(a, b) { t.Error("non-records should not extend") } } @@ -166,19 +160,19 @@ func TestReturnSummaryExtendsRecord_RecordExtends(t *testing.T) { newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !returnsummary.ExtendsRecord(a, b) { + if !ExtendsRecord(a, b) { t.Error("record with more fields should extend") } } func TestReturnSummaryElidesOptional_Empty(t *testing.T) { - if returnsummary.ElidesOptional(nil, nil) { + if ElidesOptional(nil, nil) { t.Error("empty vectors should not elide") } } func TestReturnSummarySelectPreferred_Refinement(t *testing.T) { - preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.String}, []typ.Type{typ.NewOptional(typ.String)}) + preferred, ok := SelectPreferred([]typ.Type{typ.String}, []typ.Type{typ.NewOptional(typ.String)}) if !ok { t.Fatal("expected preferred vector") } @@ -188,7 +182,7 @@ func TestReturnSummarySelectPreferred_Refinement(t *testing.T) { } func TestReturnSummarySelectPreferred_AvoidsNilOnlyRegression(t *testing.T) { - preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.Nil}, []typ.Type{typ.NewOptional(typ.String)}) + preferred, ok := 
SelectPreferred([]typ.Type{typ.Nil}, []typ.Type{typ.NewOptional(typ.String)}) if !ok { t.Fatal("expected preferred vector") } @@ -198,7 +192,7 @@ func TestReturnSummarySelectPreferred_AvoidsNilOnlyRegression(t *testing.T) { } func TestReturnSummarySelectPreferred_RejectsStaleNilOnly(t *testing.T) { - preferred, ok := returnsummary.SelectPreferred([]typ.Type{typ.NewOptional(typ.String)}, []typ.Type{typ.Nil}) + preferred, ok := SelectPreferred([]typ.Type{typ.NewOptional(typ.String)}, []typ.Type{typ.Nil}) if !ok { t.Fatal("expected preferred vector") } @@ -211,7 +205,7 @@ func TestReturnSummarySelectPreferred_RecordExtension(t *testing.T) { oldRec := typ.NewRecord().Field("x", typ.Number).Build() newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.String).Build() - preferred, ok := returnsummary.SelectPreferred([]typ.Type{newRec}, []typ.Type{oldRec}) + preferred, ok := SelectPreferred([]typ.Type{newRec}, []typ.Type{oldRec}) if !ok { t.Fatal("expected preferred vector") } @@ -224,7 +218,7 @@ func TestReturnSummarySelectRefining_Refinement(t *testing.T) { refined := []typ.Type{typ.String} baseline := []typ.Type{typ.NewOptional(typ.String)} - got, ok := returnsummary.SelectRefining(refined, baseline) + got, ok := SelectRefining(refined, baseline) if !ok { t.Fatal("expected refinement to be selected") } @@ -237,7 +231,7 @@ func TestReturnSummarySelectRefining_DoesNotSelectOlderNarrowerBaseline(t *testi candidate := []typ.Type{typ.Any} baseline := []typ.Type{typ.False} - _, ok := returnsummary.SelectRefining(candidate, baseline) + _, ok := SelectRefining(candidate, baseline) if ok { t.Fatal("did not expect baseline-narrower relation to select candidate") } @@ -246,7 +240,7 @@ func TestReturnSummarySelectRefining_DoesNotSelectOlderNarrowerBaseline(t *testi func TestReturnSummaryFillsNilSlots(t *testing.T) { candidate := []typ.Type{typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown)), typ.NewArray(typ.Unknown)} baseline := []typ.Type{typ.NewMap(typ.Unknown, 
typ.NewArray(typ.Unknown)), typ.Nil} - if !returnsummary.FillsNilSlots(candidate, baseline) { + if !FillsNilSlots(candidate, baseline) { t.Fatalf("expected candidate to fill nil slot: candidate=%v baseline=%v", candidate, baseline) } } @@ -255,7 +249,7 @@ func TestReturnSummaryMerge_PrefersCandidateRefinement(t *testing.T) { existing := []typ.Type{typ.NewOptional(typ.String)} candidate := []typ.Type{typ.String} - merged := returnsummary.Merge(existing, candidate) + merged := Merge(existing, candidate) if len(merged) != 1 || !typ.TypeEquals(merged[0], typ.String) { t.Fatalf("expected refined candidate return, got %v", merged) } @@ -271,7 +265,7 @@ func TestReturnSummaryMerge_FillsNilSlotWithCandidateEvidence(t *testing.T) { typ.NewArray(typ.Unknown), } - merged := returnsummary.Merge(existing, candidate) + merged := Merge(existing, candidate) if len(merged) != 2 { t.Fatalf("expected two return slots, got %v", merged) } @@ -280,105 +274,11 @@ func TestReturnSummaryMerge_FillsNilSlotWithCandidateEvidence(t *testing.T) { } } -func TestFunctionFactMergeType_MergesSameShapeReturnsCanonically(t *testing.T) { - existing := typ.Func(). - Param("x", typ.String). - Returns(typ.NewOptional(typ.Integer)). - Build() - candidate := typ.Func(). - Param("x", typ.String). - Returns(typ.Integer). - Build() - - merged := functionfact.MergeType(existing, candidate) - fn, ok := merged.(*typ.Function) - if !ok || len(fn.Returns) != 1 { - t.Fatalf("expected merged function, got %T", merged) - } - if !typ.TypeEquals(fn.Returns[0], typ.Integer) { - t.Fatalf("expected refined return integer, got %v", fn.Returns[0]) - } -} - -func TestFunctionFactMergeType_PrefersConcreteParamOverTopObservation(t *testing.T) { - existing := typ.Func(). - Param("x", typ.Any). - Returns(typ.String). - Build() - candidate := typ.Func(). - Param("x", typ.String). - Returns(typ.String). 
- Build() - - merged := functionfact.MergeType(existing, candidate) - fn, ok := merged.(*typ.Function) - if !ok { - t.Fatalf("expected merged function, got %T", merged) - } - if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, typ.String) { - t.Fatalf("expected param refined to string, got %+v", fn.Params) - } -} - -func TestFunctionFactMergeType_WidensParamToCoverObservedCallsites(t *testing.T) { - existing := typ.Func(). - Param("t", typ.NewArray(typ.Any)). - Returns(typ.String). - Build() - candidate := typ.Func(). - Param("t", typ.NewMap(typ.String, typ.Any)). - Returns(typ.String). - Build() - - merged := functionfact.MergeType(existing, candidate) - fn, ok := merged.(*typ.Function) - if !ok { - t.Fatalf("expected merged function, got %T", merged) - } - if len(fn.Params) != 1 { - t.Fatalf("expected one param, got %+v", fn.Params) - } - if typ.TypeEquals(fn.Params[0].Type, typ.NewArray(typ.Any)) { - t.Fatalf("expected param widening beyond array-only shape, got %v", fn.Params[0].Type) - } - wantMap := typ.NewMap(typ.String, typ.Any) - if !subtype.IsSubtype(wantMap, fn.Params[0].Type) { - t.Fatalf("expected merged param to admit map callsite evidence, got %v", fn.Params[0].Type) - } -} - -func TestFunctionFactMergeType_KeepsBaselineOverNestedNilOnlyRegression(t *testing.T) { - baselineReturn := typ.NewRecord(). - Field("full_path", typ.String). - Field("parent", typ.Unknown). - OptField("after_all", typ.Nil). - SetOpen(true). - Build() - candidateReturn := typ.NewRecord(). - Field("full_path", typ.String). - Field("parent", typ.Nil). - Field("after_all", typ.Nil). - SetOpen(true). 
- Build() - - baseline := typ.Func().Param("name", typ.Unknown).Returns(baselineReturn).Build() - candidate := typ.Func().Param("name", typ.Unknown).Returns(candidateReturn).Build() - - merged := functionfact.MergeType(baseline, candidate) - fn, ok := merged.(*typ.Function) - if !ok || len(fn.Returns) != 1 { - t.Fatalf("expected merged function return, got %v", merged) - } - if !typ.TypeEquals(fn.Returns[0], baselineReturn) { - t.Fatalf("expected baseline record to survive nil-only refinement, got %v", fn.Returns[0]) - } -} - func TestReturnSummaryMerge_PrefersCurrentTruthyMapKeyRefinement(t *testing.T) { baseline := typ.NewMap(typ.NewUnion(typ.String, typ.False), typ.Unknown) candidate := typ.NewMap(typ.String, typ.Unknown) - merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) + merged := Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected stale falsy map key to refine to %v, got %v", candidate, merged) } @@ -389,7 +289,7 @@ func TestReturnSummaryMerge_PrefersConcreteMapValueOverSoftPlaceholder(t *testin baseline := typ.NewMap(typ.String, typ.NewArray(typ.Any)) candidate := typ.NewMap(typ.String, typ.NewArray(entry)) - merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) + merged := Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected concrete map value evidence %v, got %v", candidate, merged) } @@ -405,7 +305,7 @@ func TestReturnSummaryMerge_PrefersCurrentTruthyRecordMapKeyRefinement(t *testin MapComponent(typ.String, entryArray). 
Build() - merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) + merged := Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected stale falsy record map key to refine to %v, got %v", candidate, merged) } @@ -419,178 +319,12 @@ func TestReturnSummaryMerge_PrefersMapOverStaleOpenRecordMapKeyRefinement(t *tes Build() candidate := typ.NewMap(typ.String, entryArray) - merged := returnsummary.Merge([]typ.Type{baseline}, []typ.Type{candidate}) + merged := Merge([]typ.Type{baseline}, []typ.Type{candidate}) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate) { t.Fatalf("expected map to replace stale open record map %v, got %v", candidate, merged) } } -func TestFunctionFactMergeType_CollapsesMixedFunctionUnionVariants(t *testing.T) { - base := typ.Func(). - Param("name", typ.Unknown). - Returns(typ.NewRecord().Field("full_path", typ.String).SetOpen(true).Build()). - Build() - withChildren := typ.Func(). - Param("name", typ.Unknown). - Returns(typ.NewRecord(). - Field("full_path", typ.String). - Field("children", typ.NewArray(typ.Unknown)). - SetOpen(true). - Build()). - Build() - withTests := typ.Func(). - Param("name", typ.Unknown). - Returns(typ.NewRecord(). - Field("full_path", typ.String). - Field("tests", typ.NewArray(typ.Unknown)). - SetOpen(true). - Build()). 
- Build() - - merged := functionfact.MergeType(typ.NewUnion(typ.Nil, base, withChildren), withTests) - if merged == nil { - t.Fatal("expected merged type") - } - if fn := unwrap.Function(merged); fn == nil { - t.Fatalf("expected merged function variant, got %v", merged) - } else if len(fn.Returns) != 1 { - t.Fatalf("expected one return, got %v", fn.Returns) - } else { - rec, ok := fn.Returns[0].(*typ.Record) - if !ok { - t.Fatalf("expected record return, got %T", fn.Returns[0]) - } - for _, field := range []string{"full_path", "children", "tests"} { - if rec.GetField(field) == nil { - t.Fatalf("expected merged field %q in %v", field, rec) - } - } - } - if merged.Kind() != kind.Optional { - t.Fatalf("expected nil residual to be preserved as optional, got %v", merged) - } -} - -func TestFunctionFactMergeType_DoesNotDropNonFunctionUnionMembers(t *testing.T) { - fn := typ.Func().Param("x", typ.String).Returns(typ.String).Build() - existing := typ.NewUnion(fn, typ.Number) - candidate := typ.Func().Param("x", typ.String).Returns(typ.String).Build() - - merged := functionfact.MergeType(existing, candidate) - u, ok := merged.(*typ.Union) - if !ok { - t.Fatalf("expected union to be preserved, got %T", merged) - } - hasNumber := false - for _, m := range u.Members { - if typ.TypeEquals(m, typ.Number) { - hasNumber = true - break - } - } - if !hasNumber { - t.Fatalf("expected merged union to retain non-function member, got %v", merged) - } -} - -func TestFunctionFactMergeType_CollapsesCompatibleFunctionVariants(t *testing.T) { - base := typ.Func(). - OptParam("entries", typ.Any). - Returns(typ.NewMap(typ.Unknown, typ.NewArray(typ.Unknown))). - Build() - refinedEntry := typ.NewRecord().Field("id", typ.String).Build() - refined := typ.Func(). - OptParam("entries", typ.NewArray(refinedEntry)). - Returns(typ.NewMap(typ.String, typ.NewArray(refinedEntry))). 
- Build() - - merged := functionfact.MergeType(base, refined) - fn, ok := merged.(*typ.Function) - if !ok { - t.Fatalf("expected function after compatible-variant collapse, got %T", merged) - } - if len(fn.Params) != 1 || !typ.TypeEquals(fn.Params[0].Type, typ.NewArray(refinedEntry)) { - t.Fatalf("expected refined param type to win, got %+v", fn.Params) - } - if len(fn.Returns) != 1 || !typ.TypeEquals(fn.Returns[0], typ.NewMap(typ.String, typ.NewArray(refinedEntry))) { - t.Fatalf("expected refined return map, got %v", fn.Returns) - } -} - -func TestFunctionFactMergeType_DoesNotCollapseParamToNilWhenOptionalInfoExists(t *testing.T) { - existing := typ.Func(). - OptParam("tests", typ.Nil). - Returns(typ.Integer). - Build() - candidate := typ.Func(). - OptParam("tests", typ.NewOptional(typ.NewArray(typ.Any))). - Returns(typ.Integer). - Build() - - merged := functionfact.MergeType(existing, candidate) - fn, ok := merged.(*typ.Function) - if !ok { - t.Fatalf("expected function, got %T", merged) - } - want := typ.NewOptional(typ.NewArray(typ.Any)) - if len(fn.Params) != 1 || !fn.Params[0].Optional || !typ.TypeEquals(fn.Params[0].Type, want) { - t.Fatalf("expected optional param slot with type %v, got %+v", want, fn.Params) - } -} - -func TestFunctionFactMergeType_NilDoesNotDominateSoftOptionalParamShape(t *testing.T) { - softArray := typ.NewOptional(typ.NewUnion(typ.NewArray(typ.Any), typ.NewRecord().SetOpen(true).Build())) - preciseArray := typ.NewOptional(typ.NewArray(typ.String)) - - merged := functionfact.MergeType( - typ.Func().OptParam("tests", typ.Nil).Returns(typ.Integer).Build(), - typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), - ) - fn, ok := merged.(*typ.Function) - if !ok || len(fn.Params) != 1 { - t.Fatalf("expected merged function, got %T", merged) - } - if !typ.TypeEquals(fn.Params[0].Type, softArray) { - t.Fatalf("expected nil observation not to replace soft optional table shape, got %v", fn.Params[0].Type) - } - - merged = 
functionfact.MergeType( - typ.Func().OptParam("tests", softArray).Returns(typ.Integer).Build(), - typ.Func().OptParam("tests", preciseArray).Returns(typ.Integer).Build(), - ) - fn, ok = merged.(*typ.Function) - if !ok || len(fn.Params) != 1 { - t.Fatalf("expected merged function, got %T", merged) - } - if !typ.TypeEquals(fn.Params[0].Type, preciseArray) { - t.Fatalf("expected precise optional array evidence to replace soft shape, got %v", fn.Params[0].Type) - } -} - -func TestFunctionFactMergeType_ReplacesStaleFalsyMapKeyWithTruthyRefinement(t *testing.T) { - entry := typ.NewRecord().Field("id", typ.String).Build() - stale := typ.NewRecord(). - MapComponent(typ.NewUnion(typ.Boolean, typ.String), typ.NewArray(entry)). - SetOpen(true). - Build() - current := typ.NewRecord(). - MapComponent(typ.String, typ.NewArray(entry)). - SetOpen(true). - Build() - - merged := functionfact.MergeType( - typ.Func().OptParam("t", stale).Returns(typ.NewArray(typ.NewUnion(typ.Boolean, typ.String))).Build(), - typ.Func().OptParam("t", current).Returns(typ.NewArray(typ.String)).Build(), - ) - fn, ok := merged.(*typ.Function) - if !ok || len(fn.Params) != 1 { - t.Fatalf("expected merged function, got %T", merged) - } - if !typ.TypeEquals(fn.Params[0].Type, current) { - t.Fatalf("expected truthy-refined map key param %v, got %v", current, fn.Params[0].Type) - } -} - func TestReturnSummaryMerge_ElidesOptionalForInterfaceFieldRecords(t *testing.T) { txType := typ.NewInterface("sql.Tx", []typ.Method{ {Name: "rollback", Type: typ.Func().Param("self", typ.Self).Build()}, @@ -613,7 +347,7 @@ func TestReturnSummaryMerge_ElidesOptionalForInterfaceFieldRecords(t *testing.T) Build(), } - merged := returnsummary.Merge(existing, candidate) + merged := Merge(existing, candidate) if len(merged) != 1 || !typ.TypeEquals(merged[0], candidate[0]) { t.Fatalf("expected candidate optional-elision to win, got %v", merged) } @@ -626,7 +360,7 @@ func 
TestReturnSummaryApplyToFunctionType_AppliesSummaryToPlaceholderReturns(t * Build() summary := []typ.Type{typ.Integer} - got := returnsummary.ApplyToFunctionType(fn, summary) + got := ApplyToFunctionType(fn, summary) if got == nil || len(got.Returns) != 1 { t.Fatalf("expected function return, got %v", got) } @@ -637,7 +371,7 @@ func TestReturnSummaryApplyToFunctionType_AppliesSummaryToPlaceholderReturns(t * func TestReturnSummaryApplyToFunctionType_DefaultsToUnknownWhenMissing(t *testing.T) { fn := typ.Func().Param("x", typ.String).Build() - got := returnsummary.ApplyToFunctionType(fn, nil) + got := ApplyToFunctionType(fn, nil) if got == nil || len(got.Returns) != 1 { t.Fatalf("expected one default return, got %v", got) } @@ -647,7 +381,7 @@ func TestReturnSummaryApplyToFunctionType_DefaultsToUnknownWhenMissing(t *testin } func TestReturnSummaryNormalize_Empty(t *testing.T) { - result := returnsummary.Normalize(nil) + result := Normalize(nil) if result != nil { t.Errorf("expected nil, got %v", result) } @@ -655,7 +389,7 @@ func TestReturnSummaryNormalize_Empty(t *testing.T) { func TestReturnSummaryNormalize_ReplacesNil(t *testing.T) { input := []typ.Type{typ.String, nil, typ.Number} - result := returnsummary.Normalize(input) + result := Normalize(input) if len(result) != 3 { t.Fatalf("expected length 3, got %d", len(result)) } @@ -677,7 +411,7 @@ func TestRecordSuperset_BothHaveMapComponent(t *testing.T) { newRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Field("x", typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !returnsummary.ExtendsRecord(a, b) { + if !ExtendsRecord(a, b) { t.Error("record with same map component and additional fields should extend") } } @@ -687,7 +421,7 @@ func TestRecordSuperset_OldHasNoMapComponent(t *testing.T) { newRec := typ.NewRecord().Field("x", typ.Number).Field("y", typ.String).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !returnsummary.ExtendsRecord(a, b) { + if !ExtendsRecord(a, 
b) { t.Error("record with additional fields should extend record without map component") } } @@ -705,7 +439,7 @@ func TestReturnSummaryAlignFunction_AppliesStrictRefinement(t *testing.T) { Build(), } - aligned, changed := returnsummary.AlignFunction(fn, summary) + aligned, changed := AlignFunction(fn, summary) if !changed { t.Fatal("expected alignment to apply strict refinement summary") } @@ -727,7 +461,7 @@ func TestReturnSummaryAlignFunction_ReplacesOpenTopRecordWithStructuredSummary(t fn := typ.Func().Returns(openTop).Build() summary := []typ.Type{typ.NewArray(typ.Unknown)} - aligned, changed := returnsummary.AlignFunction(fn, summary) + aligned, changed := AlignFunction(fn, summary) if !changed { t.Fatal("expected open-top placeholder to be replaced by structured summary") } @@ -744,7 +478,7 @@ func TestReturnSummaryAlignFunction_DoesNotDowngradeStructuredToPlaceholder(t *t fn := typ.Func().Returns(structured).Build() summary := []typ.Type{typ.Any} - aligned, changed := returnsummary.AlignFunction(fn, summary) + aligned, changed := AlignFunction(fn, summary) if changed { t.Fatalf("expected no downgrade change, got %v", aligned) } @@ -782,9 +516,9 @@ func TestReturnSummaryMerge_PrefersRuntimePossibleSummaryOverNeverArtifact(t *te ), } - got := returnsummary.Merge(bad, good) - if !returnsummary.Equal(got, good) { - t.Fatalf("returnsummary.Merge(%v, %v) = %v, want %v", bad, good, got, good) + got := Merge(bad, good) + if !Equal(got, good) { + t.Fatalf("Merge(%v, %v) = %v, want %v", bad, good, got, good) } } @@ -811,7 +545,7 @@ func TestReturnSummaryAlignFunction_RepairsNestedNeverArtifact(t *testing.T) { ) fn := typ.Func().Returns(bad).Build() - aligned, changed := returnsummary.AlignFunction(fn, []typ.Type{good}) + aligned, changed := AlignFunction(fn, []typ.Type{good}) if !changed { t.Fatal("expected never-artifact repair to update function returns") } @@ -825,7 +559,7 @@ func TestRecordSuperset_NewHasMapComponentOldDoesNot(t *testing.T) { newRec := 
typ.NewRecord().Field("x", typ.Number).MapComponent(typ.String, typ.Any).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if !returnsummary.ExtendsRecord(a, b) { + if !ExtendsRecord(a, b) { t.Error("record with additional map component should extend record without it") } } @@ -835,7 +569,7 @@ func TestRecordSuperset_IncompatibleMapComponent(t *testing.T) { newRec := typ.NewRecord().MapComponent(typ.String, typ.Number).Build() a := []typ.Type{newRec} b := []typ.Type{oldRec} - if returnsummary.ExtendsRecord(a, b) { + if ExtendsRecord(a, b) { t.Error("record with incompatible map component should not extend") } } @@ -856,7 +590,7 @@ func TestReturnSummaryMerge_PrefersStructuredCollectionOverOpenTopRecordField(t Build(), } - merged := returnsummary.Merge(weak, strong) + merged := Merge(weak, strong) if len(merged) != 1 { t.Fatalf("expected one return slot, got %d", len(merged)) } @@ -882,7 +616,7 @@ func TestReturnSummaryMerge_PromotesTopLevelStructuredOverOpenTop(t *testing.T) typ.NewArray(typ.Any), } - merged := returnsummary.Merge(weak, strong) + merged := Merge(weak, strong) if len(merged) != 1 { t.Fatalf("expected one return slot, got %d", len(merged)) } diff --git a/compiler/check/returns/doc.go b/compiler/check/returns/doc.go index 0829c293..42eac1f0 100644 --- a/compiler/check/returns/doc.go +++ b/compiler/check/returns/doc.go @@ -1,15 +1,16 @@ -// Package returns orchestrates local return inference and interprocedural fact -// products. +// Package returns orchestrates local return inference. // // It does not own the lattice laws for individual fact slots. Those live in // domain packages: // - domain/paramevidence owns parameter evidence; // - domain/returnsummary owns return vectors and function-return alignment; // - domain/functionfact owns one api.FunctionFact at a time; +// - domain/factproduct owns whole api.Facts products; // - domain/value owns reusable structural value relations. 
// -// This package owns when those domains are applied across maps, SCCs, overlays, -// captured mutations, and recursive interprocedural fixpoint boundaries. +// This package owns local call graph traversal, SCC iteration, return overlays, +// signature seeding, and nested-call mutation replay. The interprocedural store +// applies product joins and widening through domain/factproduct. // // # SCC-Based Analysis // @@ -26,13 +27,12 @@ // - Collect return expressions from all return statements // - Synthesize types for return expressions // - Merge candidate return vectors through domain/returnsummary -// - Apply product-level widening for recursive convergence +// - Publish per-function facts through domain/functionfact // -// # Type Widening +// # Convergence // -// To ensure termination in recursive cases, types are widened: -// - After N iterations, recursive types are approximated -// - Widening preserves soundness while ensuring convergence +// Local SCC iteration has a bounded return-summary loop. Cross-graph +// interprocedural convergence is handled by the store through domain/factproduct. 
// // # Overlay System // diff --git a/compiler/check/returns/kernel_test.go b/compiler/check/returns/kernel_test.go deleted file mode 100644 index a6251572..00000000 --- a/compiler/check/returns/kernel_test.go +++ /dev/null @@ -1,235 +0,0 @@ -package returns - -import ( - "testing" - - "github.com/wippyai/go-lua/compiler/cfg" - "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/domain/functionfact" - "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" - "github.com/wippyai/go-lua/types/typ" -) - -func TestFunctionFactJoin_InitialObservation(t *testing.T) { - sym := cfg.SymbolID(11) - fn := typ.Func().Returns(typ.String).Build() - - facts := api.Facts{FunctionFacts: api.FunctionFacts{sym: functionfact.Join(api.FunctionFact{}, api.FunctionFact{ - Summary: []typ.Type{typ.String}, - Narrow: []typ.Type{typ.String}, - Type: fn, - })}} - - if got := facts.FunctionFacts.Summary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { - t.Fatalf("summary mismatch: got %v", got) - } - if got := facts.FunctionFacts.NarrowSummary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { - t.Fatalf("narrow mismatch: got %v", got) - } - if got := facts.FunctionFacts.FunctionType(sym); !typ.TypeEquals(got, fn) { - t.Fatalf("func mismatch: got %v", got) - } -} - -func TestFunctionFactJoin_MergesExistingAndCandidate(t *testing.T) { - existingFn := typ.Func().Returns(typ.Number).Build() - candidateFn := typ.Func().Returns(typ.String).Build() - existing := api.FunctionFact{ - Summary: []typ.Type{typ.Number}, - Narrow: []typ.Type{typ.Number}, - Type: existingFn, - } - candidate := api.FunctionFact{ - Summary: []typ.Type{typ.String}, - Narrow: []typ.Type{typ.String}, - Type: candidateFn, - } - got := functionfact.Join(existing, candidate) - - if !returnsummary.Equal(got.Summary, []typ.Type{typ.NewUnion(typ.Number, typ.String)}) { - t.Fatalf("summary mismatch: got %v", got.Summary) - } - if !returnsummary.Equal(got.Narrow, 
[]typ.Type{typ.NewUnion(typ.Number, typ.String)}) { - t.Fatalf("narrow mismatch: got %v", got.Narrow) - } - if got.Type == nil { - t.Fatal("expected merged function type") - } -} - -func TestJoinFacts_BatchMergeFunctionFacts(t *testing.T) { - symSummary := cfg.SymbolID(21) - symNarrow := cfg.SymbolID(22) - symFunc := cfg.SymbolID(23) - funcType := typ.Func().Returns(typ.Boolean).Build() - - facts := JoinFacts( - api.Facts{ - FunctionFacts: api.FunctionFacts{ - symSummary: {Summary: []typ.Type{typ.String}}, - symNarrow: {Narrow: []typ.Type{typ.Number}}, - }, - }, - api.Facts{ - FunctionFacts: api.FunctionFacts{ - symFunc: {Type: funcType}, - }, - }, - ) - - if got := facts.FunctionFacts.Summary(symSummary); !returnsummary.Equal(got, []typ.Type{typ.String}) { - t.Fatalf("summary mismatch: got %v", got) - } - if got := facts.FunctionFacts.NarrowSummary(symNarrow); !returnsummary.Equal(got, []typ.Type{typ.Number}) { - t.Fatalf("narrow mismatch: got %v", got) - } - if got := facts.FunctionFacts.FunctionType(symFunc); !typ.TypeEquals(got, funcType) { - t.Fatalf("func mismatch: got %v", got) - } -} - -func TestFunctionFactJoin_NarrowSummaryReplacesOpenTopPlaceholder(t *testing.T) { - openTop := typ.NewRecord().SetOpen(true).Build() - existingFunc := typ.Func().Returns(openTop).Build() - candidateFunc := typ.Func().Returns(openTop).Build() - narrow := []typ.Type{typ.NewArray(typ.Unknown)} - - out := functionfact.Join( - api.FunctionFact{Summary: []typ.Type{openTop}, Type: existingFunc}, - api.FunctionFact{Summary: []typ.Type{openTop}, Narrow: narrow, Type: candidateFunc}, - ) - - if !returnsummary.Equal(returnsummary.NormalizeAndPrune(out.Summary), returnsummary.NormalizeAndPrune(narrow)) { - t.Fatalf("summary mismatch: got %v want %v", out.Summary, narrow) - } - - fn, ok := out.Type.(*typ.Function) - if !ok { - t.Fatalf("expected function fact, got %T", out.Type) - } - if !returnsummary.Equal(returnsummary.NormalizeAndPrune(fn.Returns), 
returnsummary.NormalizeAndPrune(narrow)) { - t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, narrow) - } -} - -func TestFunctionFactJoin_NarrowSummaryRepairsNeverArtifact(t *testing.T) { - bad := []typ.Type{ - typ.NewUnion( - typ.NewRecord(). - Field("success", typ.True). - Field("result", typ.NewRecord().OptField("data", typ.Never).Build()). - Build(), - typ.NewRecord(). - Field("success", typ.False). - Field("error", typ.LiteralString("missing")). - Build(), - ), - } - good := []typ.Type{ - typ.NewUnion( - typ.NewRecord(). - Field("success", typ.True). - Field("result", typ.NewRecord().OptField("data", typ.Unknown).Build()). - Build(), - typ.NewRecord(). - Field("success", typ.False). - Field("error", typ.LiteralString("missing")). - Build(), - ), - } - existingFunc := typ.Func().Returns(bad...).Build() - - out := functionfact.Join( - api.FunctionFact{Summary: bad, Type: existingFunc}, - api.FunctionFact{Narrow: good}, - ) - - if !returnsummary.Equal(out.Summary, good) { - t.Fatalf("summary mismatch: got %v want %v", out.Summary, good) - } - if !returnsummary.Equal(out.Narrow, good) { - t.Fatalf("narrow mismatch: got %v want %v", out.Narrow, good) - } - fn, ok := out.Type.(*typ.Function) - if !ok { - t.Fatalf("expected function fact, got %T", out.Type) - } - if !returnsummary.Equal(fn.Returns, good) { - t.Fatalf("func returns mismatch: got %v want %v", fn.Returns, good) - } -} - -func TestFunctionFactJoin_DoesNotAlignFunctionToNarrowFieldRegression(t *testing.T) { - withCapturedMethod := typ.NewRecord(). - Field("x", typ.Integer). - Field("get_x", typ.Func().Param("self", typ.Unknown).Returns(typ.Number).Build()). - Build() - flowOnly := typ.NewRecord(). - Field("x", typ.Integer). 
- Build() - existingFunc := typ.Func().Returns(flowOnly).Build() - - out := functionfact.Join( - api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, - api.FunctionFact{Summary: []typ.Type{withCapturedMethod}, Narrow: []typ.Type{flowOnly}, Type: existingFunc}, - ) - - if !returnsummary.Equal(out.Summary, []typ.Type{withCapturedMethod}) { - t.Fatalf("summary mismatch: got %v want %v", out.Summary, []typ.Type{withCapturedMethod}) - } - fn, ok := out.Type.(*typ.Function) - if !ok { - t.Fatalf("expected function fact, got %T", out.Type) - } - if !returnsummary.Equal(fn.Returns, []typ.Type{withCapturedMethod}) { - t.Fatalf("func returns should preserve captured method summary, got %v", fn.Returns) - } -} - -func TestFunctionFactNormalize_CanonicalizesStoredFunctionFacts(t *testing.T) { - sym := cfg.SymbolID(77) - fn := typ.Func().Returns(typ.Number).Build() - facts := &api.Facts{ - FunctionFacts: api.FunctionFacts{ - sym: {Summary: []typ.Type{nil}, Narrow: []typ.Type{typ.Number}, Type: fn}, - }, - } - - facts.FunctionFacts[sym] = functionfact.Normalize(facts.FunctionFacts[sym]) - - ff, ok := facts.FunctionFacts[sym] - if !ok { - t.Fatal("expected canonical FunctionFacts entry") - } - if !returnsummary.Equal(ff.Summary, []typ.Type{typ.Nil}) { - t.Fatalf("summary mismatch: got %v", ff.Summary) - } - if !returnsummary.Equal(ff.Narrow, []typ.Type{typ.Number}) { - t.Fatalf("narrow mismatch: got %v", ff.Narrow) - } - if !typ.TypeEquals(ff.Type, fn) { - t.Fatalf("func mismatch: got %v", ff.Type) - } -} - -func TestFunctionFactsAccessorsReadCanonicalFacts(t *testing.T) { - sym := cfg.SymbolID(88) - fn := typ.Func().Returns(typ.String).Build() - facts := api.Facts{ - FunctionFacts: api.FunctionFacts{ - sym: {Summary: []typ.Type{typ.String}, Narrow: []typ.Type{typ.String}, Type: fn}, - }, - } - - if got := facts.FunctionFacts.Summary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { - t.Fatalf("summary mismatch: got 
%v", got) - } - - if got := facts.FunctionFacts.NarrowSummary(sym); !returnsummary.Equal(got, []typ.Type{typ.String}) { - t.Fatalf("narrow mismatch: got %v", got) - } - - if got := facts.FunctionFacts.FunctionType(sym); !typ.TypeEquals(got, fn) { - t.Fatalf("func mismatch: got %v", got) - } -} diff --git a/compiler/check/store/snapshot_inputs.go b/compiler/check/store/snapshot_inputs.go index 6bcb6233..71eea969 100644 --- a/compiler/check/store/snapshot_inputs.go +++ b/compiler/check/store/snapshot_inputs.go @@ -3,7 +3,7 @@ package store import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/factproduct" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/db" "github.com/wippyai/go-lua/types/typ" @@ -82,7 +82,7 @@ func (in *snapshotInputs) setFacts(key api.GraphKey, facts api.Facts) { return } next := cloneFacts(facts) - if prev, ok := in.factValues[key]; ok && returns.FactsEqual(prev, next) { + if prev, ok := in.factValues[key]; ok && factproduct.FactsEqual(prev, next) { return } in.factValues[key] = next @@ -160,7 +160,7 @@ func constructorFieldMapsEqual(sym cfg.SymbolID, a, b map[string]typ.Type) bool if len(a) == 0 && len(b) == 0 { return true } - return returns.ConstructorFieldsEqual( + return factproduct.ConstructorFieldsEqual( api.ConstructorFields{sym: a}, api.ConstructorFields{sym: b}, ) @@ -193,7 +193,7 @@ func (s *SessionStore) currentInterprocFacts(key api.GraphKey) api.Facts { if factsEmpty(next) { return cloneFacts(prev) } - return cloneFacts(returns.JoinFacts(prev, next)) + return cloneFacts(factproduct.JoinFacts(prev, next)) } } return cloneFacts(prev) diff --git a/compiler/check/store/store.go b/compiler/check/store/store.go index 504e909b..63932d12 100644 --- a/compiler/check/store/store.go +++ b/compiler/check/store/store.go @@ -7,7 +7,7 @@ import ( 
"github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" - "github.com/wippyai/go-lua/compiler/check/returns" + "github.com/wippyai/go-lua/compiler/check/domain/factproduct" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/db" @@ -131,7 +131,7 @@ func interprocFactsMapEqual(a, b map[api.GraphKey]api.Facts) bool { return false } for _, key := range api.SortedGraphKeys(a) { - if !returns.FactsEqual(a[key], b[key]) { + if !factproduct.FactsEqual(a[key], b[key]) { return false } } @@ -150,9 +150,9 @@ func widenInterprocFacts(prev, next map[api.GraphKey]api.Facts) map[api.GraphKey for _, key := range api.SortedGraphKeys(next) { facts := next[key] if existing, ok := out[key]; ok { - out[key] = returns.WidenFacts(existing, facts) + out[key] = factproduct.WidenFacts(existing, facts) } else { - out[key] = returns.WidenFacts(api.Facts{}, facts) + out[key] = factproduct.WidenFacts(api.Facts{}, facts) } } return out @@ -334,7 +334,7 @@ func (s *SessionStore) swapInterprocChannels() []string { func(_prev, next api.ConstructorFields) api.ConstructorFields { return next }, - returns.ConstructorFieldsEqual, + factproduct.ConstructorFieldsEqual, func() api.ConstructorFields { return make(api.ConstructorFields) }, @@ -557,7 +557,7 @@ func (s *SessionStore) MergeInterprocFactsNext(key api.GraphKey, delta api.Facts } s.ensureInterprocStates() existing := s.InterprocNext.Facts[key] - facts := returns.JoinFacts(existing, delta) + facts := factproduct.JoinFacts(existing, delta) if factsEmpty(facts) { if factsEmpty(existing) { return @@ -566,7 +566,7 @@ func (s *SessionStore) MergeInterprocFactsNext(key api.GraphKey, delta api.Facts s.syncFactsInput(key) return } - if returns.FactsEqual(existing, facts) { + if factproduct.FactsEqual(existing, facts) { return } s.InterprocNext.Facts[key] = facts From 78c105068ad382e6e44a386a2cee241cd76e71ef 
Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 03:08:59 -0400 Subject: [PATCH 23/71] Move convergence laws into domains --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 52 ++ .../domain/factproduct/domain_law_test.go | 19 - compiler/check/domain/factproduct/product.go | 535 +----------------- .../check/domain/factproduct/product_test.go | 70 +-- compiler/check/domain/functionfact/fact.go | 276 ++++++++- .../check/domain/functionfact/fact_test.go | 67 +++ .../check/domain/returnsummary/summary.go | 112 ++-- compiler/check/domain/value/convergence.go | 205 +++++++ .../check/domain/value/convergence_test.go | 26 + compiler/check/domain/value/doc.go | 8 + compiler/check/domain/value/growth.go | 81 +++ .../summary_test.go => value/growth_test.go} | 2 +- 12 files changed, 769 insertions(+), 684 deletions(-) create mode 100644 compiler/check/domain/value/convergence.go create mode 100644 compiler/check/domain/value/convergence_test.go create mode 100644 compiler/check/domain/value/doc.go create mode 100644 compiler/check/domain/value/growth.go rename compiler/check/domain/{returnsummary/summary_test.go => value/growth_test.go} (98%) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 60c2daa3..7de2c7dc 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -345,6 +345,58 @@ Verification for this slice so far: targets: session 8 errors, agent/src 10 errors, docker-demo 21 errors and 2 warnings. +## 2026-05-19 Convergence Law Ownership Checkpoint + +The next rectification slice removed convergence and structural value laws from +`domain/factproduct`. The fact-product domain now composes slot domains instead +of carrying private copies of their logic. 
+ +Moved laws: + +- higher-order recursive-growth detection moved to `domain/value`; +- convergence widening for one `typ.Type` moved to `domain/value`; +- unsafe precision-drop detection moved to `domain/value`; +- return-vector convergence widening moved to `domain/returnsummary`; +- one-function fact convergence widening moved to `domain/functionfact`; +- same-signature return-slot merging for function literals moved to + `domain/functionfact`; +- related tests moved to the packages that own the laws. + +The old local helper names are gone from production code: + +```text +mergeFunctionReturnsIfSameShape +widenFunctionFactTypeForConvergence +widenReturnSummaryForConvergence +maybeWidenTypeForConvergence +widenValueTypeForConvergence +typeUnsafePrecisionDrop +returnsummary.HasHigherOrderGrowthRisk +``` + +Current convergence flow: + +1. `domain/value` defines structural type relations and finite-height + convergence approximations. +2. `domain/returnsummary` widens return vectors using the value domain. +3. `domain/functionfact` widens one `api.FunctionFact` using parameter evidence, + return summaries, and value relations. +4. `domain/factproduct` widens maps and fact slots only by delegating to those + owners. + +Verification for this slice so far: + +- `go test ./...` passes. +- `git diff --check` passes. +- `go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction + -benchmem -count=3` reports 1.14-1.16 ms/op, 881 KB/op, and 9390 allocs/op + on this machine. +- Standard `../scripts/verify-suite.sh` passes go-lua checker tests and builds + the Wippy binary, then exits non-zero on the known external pinned lint + targets: session 8 errors, agent/src 10 errors, docker-demo 21 errors and + 2 warnings. One first run printed agent/src 12 errors; direct replay of that + target and a full rerun both returned 10. + ## Goal The checker should read as one abstract interpreter over a product domain. 
diff --git a/compiler/check/domain/factproduct/domain_law_test.go b/compiler/check/domain/factproduct/domain_law_test.go index 240f833e..0ceb4312 100644 --- a/compiler/check/domain/factproduct/domain_law_test.go +++ b/compiler/check/domain/factproduct/domain_law_test.go @@ -198,25 +198,6 @@ func TestFactsDomain_WidenPreservesCapturedCallbackUnionMembers(t *testing.T) { } } -func TestFactsDomain_UnsafeNestedUnionDropDetected(t *testing.T) { - withPending := typ.NewUnion( - typ.LiteralString("pass"), - typ.LiteralString("pending"), - typ.LiteralString("fail"), - typ.LiteralString("skip"), - ) - withoutPending := typ.NewUnion( - typ.LiteralString("pass"), - typ.LiteralString("fail"), - typ.LiteralString("skip"), - ) - prev := typ.NewRecord().Field("status", withPending).Build() - next := typ.NewRecord().Field("status", withoutPending).Build() - if !typeUnsafePrecisionDrop(prev, next) { - t.Fatalf("expected nested union member drop to be unsafe: prev=%v next=%v", prev, next) - } -} - func unwrapFunctionForDomainTest(t *testing.T, got typ.Type) *typ.Function { t.Helper() fn, ok := got.(*typ.Function) diff --git a/compiler/check/domain/factproduct/product.go b/compiler/check/domain/factproduct/product.go index 0a3feea2..6629cb48 100644 --- a/compiler/check/domain/factproduct/product.go +++ b/compiler/check/domain/factproduct/product.go @@ -4,8 +4,6 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/functionfact" - "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" - "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" @@ -31,7 +29,7 @@ func WidenFacts(prev, next api.Facts) api.Facts { for _, sym := range symbols { prevFact := readFunctionFactFromFacts(&prev, sym) nextFact := readFunctionFactFromFacts(&next, sym) - 
writeNormalizedFunctionFactToFacts(&out, sym, widenFunctionFactForConvergence(prevFact, nextFact)) + writeNormalizedFunctionFactToFacts(&out, sym, functionfact.WidenForConvergence(prevFact, nextFact)) } if len(out.FunctionFacts) == 0 { out.FunctionFacts = nil @@ -63,177 +61,14 @@ func JoinFacts(prev, next api.Facts) api.Facts { return out } -func widenFunctionFactForConvergence(prev, next api.FunctionFact) api.FunctionFact { - out := api.FunctionFact{ - Params: paramevidence.JoinVectors(prev.Params, next.Params), - Summary: widenReturnSummaryForConvergence(prev.Summary, next.Summary), - Narrow: widenReturnSummaryForConvergence(prev.Narrow, next.Narrow), - Type: widenFunctionFactTypeForConvergence(prev.Type, next.Type), - } - - // Narrow summaries can refine optional/non-nil returns, but a nil-only - // narrow observation must not erase an already-informative summary. - if len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) { - if len(out.Summary) == 0 { - out.Summary = returnsummary.Canonical(out.Narrow) - } else { - out.Summary = widenReturnSummaryForConvergence(out.Summary, out.Narrow) - } - } - - if fn := unwrap.Function(out.Type); fn != nil { - if len(out.Summary) > 0 { - if aligned, changed := returnsummary.AlignFunction(fn, out.Summary); changed { - out.Type = widenFunctionFactTypeForConvergence(fn, aligned) - } - } else if len(fn.Returns) > 0 { - out.Summary = widenReturnSummaryForConvergence(nil, fn.Returns) - } - } - - return out -} - -func widenReturnSummaryForConvergence(prev, next []typ.Type) []typ.Type { - prev = returnsummary.NormalizeAndPrune(prev) - next = returnsummary.NormalizeAndPrune(next) - if len(prev) == 0 { - return widenReturnVectorForConvergence(next) - } - if len(next) == 0 { - return widenReturnVectorForConvergence(prev) - } - - merged := returnsummary.Merge(prev, next) - if returnVectorUnsafePrecisionDrop(prev, merged) { - merged = prev - } - return widenReturnVectorForConvergence(returnsummary.NormalizeAndPrune(merged)) -} - -func 
returnVectorUnsafePrecisionDrop(prev, merged []typ.Type) bool { - if len(prev) == 0 || len(merged) == 0 || len(prev) != len(merged) { - return false - } - for i := range prev { - if typeUnsafePrecisionDrop(prev[i], merged[i]) { - return true - } - } - return false -} - -func typeUnsafePrecisionDrop(prev, merged typ.Type) bool { - if prev == nil || merged == nil || typ.TypeEquals(prev, merged) { - return false - } - if value.ElidesOptional(merged, prev) { - return false - } - if refines, _ := value.RefinesFalsyMapKey(merged, prev); refines { - return false - } - if typ.IsAny(prev) || typ.IsUnknown(prev) { - return true - } - - switch p := value.UnwrapStructuralShape(prev).(type) { - case *typ.Union: - if unionStrictMemberSubset(merged, p) { - return true - } - if subtype.IsSubtype(merged, p) && !subtype.IsSubtype(p, merged) { - return true - } - case *typ.Record: - m, ok := value.UnwrapStructuralShape(merged).(*typ.Record) - if !ok { - break - } - for _, pf := range p.Fields { - mf := m.GetField(pf.Name) - if mf != nil && typeUnsafePrecisionDrop(pf.Type, mf.Type) { - return true - } - } - if p.HasMapComponent() && m.HasMapComponent() && typeUnsafePrecisionDrop(p.MapValue, m.MapValue) { - return true - } - case *typ.Array: - if m, ok := value.UnwrapStructuralShape(merged).(*typ.Array); ok { - return typeUnsafePrecisionDrop(p.Element, m.Element) - } - case *typ.Map: - if m, ok := value.UnwrapStructuralShape(merged).(*typ.Map); ok { - return typeUnsafePrecisionDrop(p.Key, m.Key) || typeUnsafePrecisionDrop(p.Value, m.Value) - } - case *typ.Tuple: - m, ok := value.UnwrapStructuralShape(merged).(*typ.Tuple) - if !ok || len(p.Elements) != len(m.Elements) { - break - } - for i := range p.Elements { - if typeUnsafePrecisionDrop(p.Elements[i], m.Elements[i]) { - return true - } - } - case *typ.Function: - m, ok := value.UnwrapStructuralShape(merged).(*typ.Function) - if !ok { - break - } - for i := 0; i < len(p.Params) && i < len(m.Params); i++ { - if 
typeUnsafePrecisionDrop(p.Params[i].Type, m.Params[i].Type) { - return true - } - } - for i := 0; i < len(p.Returns) && i < len(m.Returns); i++ { - if typeUnsafePrecisionDrop(p.Returns[i], m.Returns[i]) { - return true - } - } - } - - if subtype.IsSubtype(merged, prev) && !subtype.IsSubtype(prev, merged) && !value.ExtendsRecord(merged, prev) { - return true - } - return false -} - -func unionStrictMemberSubset(candidate typ.Type, baseline *typ.Union) bool { - if baseline == nil { - return false - } - candidateMembers := value.UnionMembers(candidate) - if len(candidateMembers) == 0 { - candidateMembers = []typ.Type{candidate} - } - if len(candidateMembers) >= len(baseline.Members) { - return false - } - for _, member := range candidateMembers { - found := false - for _, baseMember := range baseline.Members { - if typ.TypeEquals(member, baseMember) { - found = true - break - } - } - if !found { - return false - } - } - return true -} - func canonicalInterprocValueType(t typ.Type) typ.Type { if t == nil { return nil } if fn := unwrap.Function(t); fn != nil { - return maybeWidenTypeForConvergence(fn) + return value.WidenForConvergence(fn) } - return maybeWidenTypeForConvergence(t) + return value.WidenForConvergence(t) } func mergeInterprocValueType(existing, candidate typ.Type) typ.Type { @@ -246,19 +81,13 @@ func mergeInterprocValueType(existing, candidate typ.Type) typ.Type { return existing } if unwrap.Function(existing) != nil || unwrap.Function(candidate) != nil { - return maybeWidenTypeForConvergence(widenFunctionFactTypeForConvergence(existing, candidate)) + return value.WidenForConvergence(functionfact.WidenTypeForConvergence(existing, candidate)) } - return maybeWidenTypeForConvergence(widenValueTypeForConvergence(existing, candidate)) + return value.WidenForConvergence(value.MergeForConvergence(existing, candidate)) } func normalizeInterprocValueType(t typ.Type) typ.Type { - if t == nil { - return nil - } - if fn := unwrap.Function(t); fn != nil { - return fn 
- } - return typ.PruneSoftUnionMembers(t) + return value.NormalizeFactType(t) } func joinInterprocValueType(existing, candidate typ.Type) typ.Type { @@ -273,159 +102,7 @@ func joinInterprocValueType(existing, candidate typ.Type) typ.Type { if unwrap.Function(existing) != nil || unwrap.Function(candidate) != nil { return functionfact.MergeType(existing, candidate) } - return typ.JoinPreferNonSoft(existing, candidate) -} - -func widenValueTypeForConvergence(existing, candidate typ.Type) typ.Type { - existing = normalizeInterprocValueType(existing) - candidate = normalizeInterprocValueType(candidate) - if existing == nil { - return maybeWidenTypeForConvergence(candidate) - } - if candidate == nil { - return maybeWidenTypeForConvergence(existing) - } - existing = maybeWidenTypeForConvergence(existing) - candidate = maybeWidenTypeForConvergence(candidate) - if typ.TypeEquals(existing, candidate) { - return existing - } - if unwrap.IsNilType(existing) && !unwrap.IsNilType(candidate) { - return candidate - } - if unwrap.IsNilType(candidate) && !unwrap.IsNilType(existing) { - return existing - } - if typ.IsAny(existing) || typ.IsUnknown(existing) { - return existing - } - if typ.IsAny(candidate) || typ.IsUnknown(candidate) { - return candidate - } - if value.ElidesOptional(candidate, existing) { - return candidate - } - if value.ExtendsRecord(candidate, existing) && !value.ContainsNestedStructuralShape(candidate, existing) { - return candidate - } - if refines, _ := value.RefinesFalsyMapKey(candidate, existing); refines { - return candidate - } - if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { - return existing - } - if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { - return candidate - } - return typ.JoinPreferNonSoft(existing, candidate) -} - -func widenFunctionFactTypeForConvergence(existing, candidate typ.Type) typ.Type { - existing = normalizeInterprocValueType(existing) - candidate = 
normalizeInterprocValueType(candidate) - if existing == nil { - return maybeWidenTypeForConvergence(candidate) - } - if candidate == nil { - return maybeWidenTypeForConvergence(existing) - } - existingFn := unwrap.Function(existing) - candidateFn := unwrap.Function(candidate) - if existingFn != nil && candidateFn != nil && functionfact.SameShape(existingFn, candidateFn) { - return maybeWidenTypeForConvergence(widenFunctionFactsByShape(existingFn, candidateFn)) - } - return widenValueTypeForConvergence(existing, candidate) -} - -func widenFunctionFactsByShape(existing, candidate *typ.Function) typ.Type { - if existing == nil { - return candidate - } - if candidate == nil { - return existing - } - - builder := typ.Func() - for _, tp := range existing.TypeParams { - builder = builder.TypeParam(tp.Name, tp.Constraint) - } - for i, p := range existing.Params { - paramType := widenFunctionParamFactTypeForConvergence(p.Type, candidate.Params[i].Type) - name := p.Name - if name == "" { - name = candidate.Params[i].Name - } - if p.Optional || candidate.Params[i].Optional { - builder = builder.OptParam(name, paramType) - } else { - builder = builder.Param(name, paramType) - } - } - if existing.Variadic != nil || candidate.Variadic != nil { - builder = builder.Variadic(widenFunctionParamFactTypeForConvergence(existing.Variadic, candidate.Variadic)) - } - if returns := widenReturnSummaryForConvergence(existing.Returns, candidate.Returns); len(returns) > 0 { - builder = builder.Returns(returns...) 
- } - - effects := existing.Effects - if effects == nil { - effects = candidate.Effects - } - if effects != nil { - builder = builder.Effects(effects) - } - spec := existing.Spec - if spec == nil { - spec = candidate.Spec - } - if spec != nil { - builder = builder.Spec(spec) - } - refinement := existing.Refinement - if refinement == nil { - refinement = candidate.Refinement - } - if refinement != nil { - builder = builder.WithRefinement(refinement) - } - return builder.Build() -} - -func widenFunctionParamFactTypeForConvergence(existing, candidate typ.Type) typ.Type { - existing = normalizeInterprocValueType(existing) - candidate = normalizeInterprocValueType(candidate) - if existing == nil { - return candidate - } - if candidate == nil { - return existing - } - if typ.TypeEquals(existing, candidate) { - return existing - } - if typ.IsAny(existing) || typ.IsUnknown(existing) { - return existing - } - if typ.IsAny(candidate) || typ.IsUnknown(candidate) { - return candidate - } - if preferred, ok := value.PreferConcreteOverSoft(existing, candidate); ok { - return preferred - } - if paramevidence.RefinesFunctionParam(candidate, existing) { - return candidate - } - if paramevidence.RefinesFunctionParam(existing, candidate) { - return existing - } - if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { - return existing - } - if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { - return candidate - } - return typ.JoinPreferNonSoft(existing, candidate) + return value.JoinPrecise(existing, candidate) } // WidenLiteralSigs merges two literal signature maps. 
@@ -441,13 +118,13 @@ func WidenLiteralSigs(prev, next api.LiteralSigs) api.LiteralSigs { } merged := make(api.LiteralSigs, len(prev)+len(next)) for fn, sig := range prev { - merged[fn] = maybeWidenFunctionForConvergence(sig) + merged[fn] = value.WidenFunctionForConvergence(sig) } for fn, sig := range next { if existing := merged[fn]; existing != nil { - merged[fn] = maybeWidenFunctionForConvergence(mergeLiteralSigForConvergence(existing, sig)) + merged[fn] = value.WidenFunctionForConvergence(mergeLiteralSigForConvergence(existing, sig)) } else { - merged[fn] = maybeWidenFunctionForConvergence(sig) + merged[fn] = value.WidenFunctionForConvergence(sig) } } return merged @@ -459,7 +136,7 @@ func normalizeLiteralSigs(sigs api.LiteralSigs) api.LiteralSigs { } out := make(api.LiteralSigs, len(sigs)) for fn, sig := range sigs { - out[fn] = maybeWidenFunctionForConvergence(sig) + out[fn] = value.WidenFunctionForConvergence(sig) } return out } @@ -477,13 +154,13 @@ func JoinLiteralSigs(prev, next api.LiteralSigs) api.LiteralSigs { } merged := make(api.LiteralSigs, len(prev)+len(next)) for fn, sig := range prev { - merged[fn] = maybeWidenFunctionForConvergence(sig) + merged[fn] = value.WidenFunctionForConvergence(sig) } for fn, sig := range next { if existing := merged[fn]; existing != nil { - merged[fn] = maybeWidenFunctionForConvergence(mergeLiteralSig(existing, sig)) + merged[fn] = value.WidenFunctionForConvergence(mergeLiteralSig(existing, sig)) } else { - merged[fn] = maybeWidenFunctionForConvergence(sig) + merged[fn] = value.WidenFunctionForConvergence(sig) } } return merged @@ -496,7 +173,7 @@ func mergeLiteralSig(prev, next *typ.Function) *typ.Function { if next == nil { return prev } - if merged, ok := mergeFunctionReturnsIfSameShape(prev, next); ok { + if merged, ok := functionfact.MergeReturnsForSameSignature(prev, next); ok { if fn, ok := merged.(*typ.Function); ok { return fn } @@ -513,7 +190,7 @@ func mergeLiteralSig(prev, next *typ.Function) *typ.Function { } 
func mergeLiteralSigForConvergence(prev, next *typ.Function) *typ.Function { - merged := widenFunctionFactTypeForConvergence(prev, next) + merged := functionfact.WidenTypeForConvergence(prev, next) if fn := unwrap.Function(merged); fn != nil { return fn } @@ -734,7 +411,7 @@ func WidenCapturedContainerMutations(prev, next api.CapturedContainerMutations) if prev != nil { next.ValueType = mergeInterprocValueType(prev.ValueType, next.ValueType) } else { - next.ValueType = maybeWidenTypeForConvergence(next.ValueType) + next.ValueType = value.WidenForConvergence(next.ValueType) } return next }) @@ -871,7 +548,7 @@ func WidenConstructorFields(prev, next api.ConstructorFields) api.ConstructorFie if prevType := out[name]; prevType != nil { out[name] = mergeInterprocValueType(prevType, t) } else { - out[name] = maybeWidenTypeForConvergence(t) + out[name] = value.WidenForConvergence(t) } } merged[sym] = out @@ -961,179 +638,3 @@ func normalizeConstructorFieldMapForJoin(fields map[string]typ.Type) map[string] } return out } - -func mergeFunctionReturnsIfSameShape(prevFn, nextFn *typ.Function) (typ.Type, bool) { - if prevFn == nil || nextFn == nil { - return nil, false - } - if len(prevFn.TypeParams) != len(nextFn.TypeParams) { - return nil, false - } - if !typeParamsEqual(prevFn.TypeParams, nextFn.TypeParams) { - return nil, false - } - if len(prevFn.Params) != len(nextFn.Params) { - return nil, false - } - if (prevFn.Variadic == nil) != (nextFn.Variadic == nil) { - return nil, false - } - if prevFn.Variadic != nil && !typ.TypeEquals(prevFn.Variadic, nextFn.Variadic) { - return nil, false - } - for i := range prevFn.Params { - if prevFn.Params[i].Optional != nextFn.Params[i].Optional { - return nil, false - } - if !typ.TypeEquals(prevFn.Params[i].Type, nextFn.Params[i].Type) { - return nil, false - } - } - if len(prevFn.Returns) == 0 && len(nextFn.Returns) == 0 { - return prevFn, true - } - if len(prevFn.Returns) != len(nextFn.Returns) || len(prevFn.Returns) == 0 { - return 
nil, false - } - - allowedTypeParams := make(map[string]bool, len(prevFn.TypeParams)) - for _, tp := range prevFn.TypeParams { - if tp != nil && tp.Name != "" { - allowedTypeParams[tp.Name] = true - } - } - normalizeReturn := func(t typ.Type) (typ.Type, bool) { - if t == nil { - return nil, false - } - leaked := false - return typ.Rewrite(t, func(node typ.Type) (typ.Type, bool) { - tp, ok := node.(*typ.TypeParam) - if !ok { - return node, false - } - if allowedTypeParams[tp.Name] { - return node, false - } - // Free type params in non-generic function returns are unstable placeholders. - leaked = true - return typ.Unknown, true - }), leaked - } - normalizedPrev := make([]typ.Type, len(prevFn.Returns)) - normalizedNext := make([]typ.Type, len(nextFn.Returns)) - leakedPrev := make([]bool, len(prevFn.Returns)) - leakedNext := make([]bool, len(nextFn.Returns)) - for i := range prevFn.Returns { - normalizedPrev[i], leakedPrev[i] = normalizeReturn(prevFn.Returns[i]) - normalizedNext[i], leakedNext[i] = normalizeReturn(nextFn.Returns[i]) - } - - mergedReturns := make([]typ.Type, len(normalizedPrev)) - for i := range mergedReturns { - switch { - case leakedPrev[i] && !leakedNext[i]: - mergedReturns[i] = normalizedNext[i] - case leakedNext[i] && !leakedPrev[i]: - mergedReturns[i] = normalizedPrev[i] - default: - mergedReturns[i] = typ.JoinReturnSlot(normalizedPrev[i], normalizedNext[i]) - } - } - if returnsummary.Equal(prevFn.Returns, mergedReturns) { - return prevFn, true - } - if returnsummary.Equal(nextFn.Returns, mergedReturns) { - return nextFn, true - } - - effects := prevFn.Effects - if effects == nil { - effects = nextFn.Effects - } - spec := prevFn.Spec - if spec == nil { - spec = nextFn.Spec - } - refinement := prevFn.Refinement - if refinement == nil { - refinement = nextFn.Refinement - } - - builder := typ.Func(). - Effects(effects). - Spec(spec). 
- WithRefinement(refinement) - for _, tp := range prevFn.TypeParams { - builder = builder.TypeParam(tp.Name, tp.Constraint) - } - for _, p := range prevFn.Params { - if p.Optional { - builder = builder.OptParam(p.Name, p.Type) - } else { - builder = builder.Param(p.Name, p.Type) - } - } - if prevFn.Variadic != nil { - builder = builder.Variadic(prevFn.Variadic) - } - builder = builder.Returns(mergedReturns...) - return builder.Build(), true -} - -func typeParamsEqual(a, b []*typ.TypeParam) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] == nil || b[i] == nil { - if a[i] != b[i] { - return false - } - continue - } - if !a[i].Equals(b[i]) { - return false - } - } - return true -} - -func widenReturnVectorForConvergence(rets []typ.Type) []typ.Type { - if len(rets) == 0 { - return rets - } - out := make([]typ.Type, len(rets)) - changed := false - for i, t := range rets { - wt := maybeWidenTypeForConvergence(t) - out[i] = wt - if wt != t { - changed = true - } - } - if !changed { - return rets - } - return out -} - -func maybeWidenTypeForConvergence(t typ.Type) typ.Type { - if t == nil { - return nil - } - if !returnsummary.HasHigherOrderGrowthRisk(t) { - return t - } - return subtype.WidenForInference(t) -} - -func maybeWidenFunctionForConvergence(fn *typ.Function) *typ.Function { - if fn == nil { - return nil - } - if widened, ok := maybeWidenTypeForConvergence(fn).(*typ.Function); ok { - return widened - } - return fn -} diff --git a/compiler/check/domain/factproduct/product_test.go b/compiler/check/domain/factproduct/product_test.go index 26903e9e..ef4c1326 100644 --- a/compiler/check/domain/factproduct/product_test.go +++ b/compiler/check/domain/factproduct/product_test.go @@ -6,6 +6,7 @@ import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/domain/returnsummary" + "github.com/wippyai/go-lua/compiler/check/domain/value" 
"github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" ) @@ -231,73 +232,6 @@ func TestWidenCapturedFieldAssigns_MergesSameShapeFunctionValues(t *testing.T) { } } -func TestMergeFunctionReturnsIfSameShape_GenericFunctions(t *testing.T) { - prev := typ.Func(). - TypeParam("T", nil). - Returns(typ.String). - Build() - next := typ.Func(). - TypeParam("T", nil). - Returns(typ.Integer). - Build() - - mergedType, ok := mergeFunctionReturnsIfSameShape(prev, next) - if !ok { - t.Fatal("expected generic same-shape functions to merge") - } - merged, ok := mergedType.(*typ.Function) - if !ok { - t.Fatalf("expected merged function type, got %T", mergedType) - } - if len(merged.TypeParams) != 1 || merged.TypeParams[0] == nil || merged.TypeParams[0].Name != "T" { - t.Fatalf("expected merged generic type parameter T, got %+v", merged.TypeParams) - } - if len(merged.Returns) != 1 { - t.Fatalf("expected one return, got %d", len(merged.Returns)) - } - want := typ.NewUnion(typ.String, typ.Integer) - if !typ.TypeEquals(merged.Returns[0], want) { - t.Fatalf("expected merged return %v, got %v", want, merged.Returns[0]) - } -} - -func TestMergeFunctionReturnsIfSameShape_GenericTypeParamsMustMatch(t *testing.T) { - prev := typ.Func(). - TypeParam("T", nil). - Returns(typ.String). - Build() - next := typ.Func(). - TypeParam("U", nil). - Returns(typ.Integer). - Build() - - _, ok := mergeFunctionReturnsIfSameShape(prev, next) - if ok { - t.Fatal("expected mismatched generic params not to merge") - } -} - -func TestMergeFunctionReturnsIfSameShape_NormalizesLeakedTypeParams(t *testing.T) { - prev := typ.Func(). - Returns(typ.NewTypeParam("T", nil)). - Build() - next := typ.Func(). - Returns(typ.Integer). 
- Build() - - mergedType, ok := mergeFunctionReturnsIfSameShape(prev, next) - if !ok { - t.Fatal("expected same-shape functions to merge") - } - merged, ok := mergedType.(*typ.Function) - if !ok || len(merged.Returns) != 1 { - t.Fatalf("expected merged function return, got %T", mergedType) - } - if !typ.TypeEquals(merged.Returns[0], typ.Integer) { - t.Fatalf("expected leaked type param to normalize to integer, got %v", merged.Returns[0]) - } -} - func TestWidenLiteralSigs_DoesNotNarrowComparableSignature(t *testing.T) { lit := &ast.FunctionExpr{} @@ -359,7 +293,7 @@ func TestWidenLiteralSigs_NormalizesNilBranch(t *testing.T) { merged := WidenLiteralSigs(nil, api.LiteralSigs{lit: sig}) got := merged[lit] - want := maybeWidenFunctionForConvergence(sig) + want := value.WidenFunctionForConvergence(sig) if got == nil || !typ.TypeEquals(got, want) { t.Fatalf("expected nil-branch literal signature %v to be normalized to %v, got %v", sig, want, got) } diff --git a/compiler/check/domain/functionfact/fact.go b/compiler/check/domain/functionfact/fact.go index 85f90c89..91804a8f 100644 --- a/compiler/check/domain/functionfact/fact.go +++ b/compiler/check/domain/functionfact/fact.go @@ -16,7 +16,7 @@ func Normalize(ff api.FunctionFact) api.FunctionFact { Params: paramevidence.FilterEmptyVector(ff.Params), Summary: returnsummary.Canonical(ff.Summary), Narrow: returnsummary.Canonical(ff.Narrow), - Type: normalizeType(ff.Type), + Type: value.NormalizeFactType(ff.Type), } } @@ -69,16 +69,6 @@ func Join(existing, candidate api.FunctionFact) api.FunctionFact { return out } -func normalizeType(t typ.Type) typ.Type { - if t == nil { - return nil - } - if fn := unwrap.Function(t); fn != nil { - return fn - } - return typ.PruneSoftUnionMembers(t) -} - // MergeType merges function-type facts through the canonical per-function fact // policy. 
func MergeType(existing, candidate typ.Type) typ.Type { @@ -109,6 +99,58 @@ func MergeType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } +// WidenForConvergence merges one function fact at a recursive fixpoint +// boundary. +func WidenForConvergence(prev, next api.FunctionFact) api.FunctionFact { + out := api.FunctionFact{ + Params: paramevidence.JoinVectors(prev.Params, next.Params), + Summary: returnsummary.WidenForConvergence(prev.Summary, next.Summary), + Narrow: returnsummary.WidenForConvergence(prev.Narrow, next.Narrow), + Type: WidenTypeForConvergence(prev.Type, next.Type), + } + + // Narrow summaries can refine optional/non-nil returns, but a nil-only + // narrow observation must not erase an already-informative summary. + if len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) { + if len(out.Summary) == 0 { + out.Summary = returnsummary.Canonical(out.Narrow) + } else { + out.Summary = returnsummary.WidenForConvergence(out.Summary, out.Narrow) + } + } + + if fn := unwrap.Function(out.Type); fn != nil { + if len(out.Summary) > 0 { + if aligned, changed := returnsummary.AlignFunction(fn, out.Summary); changed { + out.Type = WidenTypeForConvergence(fn, aligned) + } + } else if len(fn.Returns) > 0 { + out.Summary = returnsummary.WidenForConvergence(nil, fn.Returns) + } + } + + return out +} + +// WidenTypeForConvergence merges function-type facts at a recursive fixpoint +// boundary. 
+func WidenTypeForConvergence(existing, candidate typ.Type) typ.Type { + existing = value.NormalizeFactType(existing) + candidate = value.NormalizeFactType(candidate) + if existing == nil { + return value.WidenForConvergence(candidate) + } + if candidate == nil { + return value.WidenForConvergence(existing) + } + existingFn := unwrap.Function(existing) + candidateFn := unwrap.Function(candidate) + if existingFn != nil && candidateFn != nil && SameShape(existingFn, candidateFn) { + return value.WidenForConvergence(widenByShapeForConvergence(existingFn, candidateFn)) + } + return value.MergeForConvergence(existing, candidate) +} + type variants struct { funcs []*typ.Function residuals []typ.Type @@ -247,6 +289,61 @@ func mergeByShape(existing, candidate *typ.Function) typ.Type { return builder.Build() } +func widenByShapeForConvergence(existing, candidate *typ.Function) typ.Type { + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + + builder := typ.Func() + for _, tp := range existing.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for i, p := range existing.Params { + paramType := widenParamTypeForConvergence(p.Type, candidate.Params[i].Type) + name := p.Name + if name == "" { + name = candidate.Params[i].Name + } + if p.Optional || candidate.Params[i].Optional { + builder = builder.OptParam(name, paramType) + } else { + builder = builder.Param(name, paramType) + } + } + if existing.Variadic != nil || candidate.Variadic != nil { + builder = builder.Variadic(widenParamTypeForConvergence(existing.Variadic, candidate.Variadic)) + } + if returns := returnsummary.WidenForConvergence(existing.Returns, candidate.Returns); len(returns) > 0 { + builder = builder.Returns(returns...) 
+ } + + effects := existing.Effects + if effects == nil { + effects = candidate.Effects + } + if effects != nil { + builder = builder.Effects(effects) + } + spec := existing.Spec + if spec == nil { + spec = candidate.Spec + } + if spec != nil { + builder = builder.Spec(spec) + } + refinement := existing.Refinement + if refinement == nil { + refinement = candidate.Refinement + } + if refinement != nil { + builder = builder.WithRefinement(refinement) + } + return builder.Build() +} + func mergeParamType(existing, candidate typ.Type) typ.Type { if existing == nil { return candidate @@ -302,6 +399,42 @@ func mergeParamType(existing, candidate typ.Type) typ.Type { return typ.JoinPreferNonSoft(existing, candidate) } +func widenParamTypeForConvergence(existing, candidate typ.Type) typ.Type { + existing = value.NormalizeFactType(existing) + candidate = value.NormalizeFactType(candidate) + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + if typ.TypeEquals(existing, candidate) { + return existing + } + if typ.IsAny(existing) || typ.IsUnknown(existing) { + return existing + } + if typ.IsAny(candidate) || typ.IsUnknown(candidate) { + return candidate + } + if preferred, ok := value.PreferConcreteOverSoft(existing, candidate); ok { + return preferred + } + if paramevidence.RefinesFunctionParam(candidate, existing) { + return candidate + } + if paramevidence.RefinesFunctionParam(existing, candidate) { + return existing + } + if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { + return existing + } + if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { + return candidate + } + return typ.JoinPreferNonSoft(existing, candidate) +} + func preferStructuredRecord(existing, candidate typ.Type) (typ.Type, bool) { existingRec, okExisting := unwrap.Alias(existing).(*typ.Record) candidateRec, okCandidate := unwrap.Alias(candidate).(*typ.Record) @@ -327,6 +460,127 @@ func 
preferStructuredRecord(existing, candidate typ.Type) (typ.Type, bool) { return nil, false } +// MergeReturnsForSameSignature merges return slots for function signatures that +// already have identical call shapes. +func MergeReturnsForSameSignature(prevFn, nextFn *typ.Function) (typ.Type, bool) { + if prevFn == nil || nextFn == nil { + return nil, false + } + if len(prevFn.TypeParams) != len(nextFn.TypeParams) { + return nil, false + } + if !typeParamsEqual(prevFn.TypeParams, nextFn.TypeParams) { + return nil, false + } + if len(prevFn.Params) != len(nextFn.Params) { + return nil, false + } + if (prevFn.Variadic == nil) != (nextFn.Variadic == nil) { + return nil, false + } + if prevFn.Variadic != nil && !typ.TypeEquals(prevFn.Variadic, nextFn.Variadic) { + return nil, false + } + for i := range prevFn.Params { + if prevFn.Params[i].Optional != nextFn.Params[i].Optional { + return nil, false + } + if !typ.TypeEquals(prevFn.Params[i].Type, nextFn.Params[i].Type) { + return nil, false + } + } + if len(prevFn.Returns) == 0 && len(nextFn.Returns) == 0 { + return prevFn, true + } + if len(prevFn.Returns) != len(nextFn.Returns) || len(prevFn.Returns) == 0 { + return nil, false + } + + allowedTypeParams := make(map[string]bool, len(prevFn.TypeParams)) + for _, tp := range prevFn.TypeParams { + if tp != nil && tp.Name != "" { + allowedTypeParams[tp.Name] = true + } + } + normalizeReturn := func(t typ.Type) (typ.Type, bool) { + if t == nil { + return nil, false + } + leaked := false + return typ.Rewrite(t, func(node typ.Type) (typ.Type, bool) { + tp, ok := node.(*typ.TypeParam) + if !ok { + return node, false + } + if allowedTypeParams[tp.Name] { + return node, false + } + // Free type params in non-generic function returns are unstable placeholders. 
+ leaked = true + return typ.Unknown, true + }), leaked + } + normalizedPrev := make([]typ.Type, len(prevFn.Returns)) + normalizedNext := make([]typ.Type, len(nextFn.Returns)) + leakedPrev := make([]bool, len(prevFn.Returns)) + leakedNext := make([]bool, len(nextFn.Returns)) + for i := range prevFn.Returns { + normalizedPrev[i], leakedPrev[i] = normalizeReturn(prevFn.Returns[i]) + normalizedNext[i], leakedNext[i] = normalizeReturn(nextFn.Returns[i]) + } + + mergedReturns := make([]typ.Type, len(normalizedPrev)) + for i := range mergedReturns { + switch { + case leakedPrev[i] && !leakedNext[i]: + mergedReturns[i] = normalizedNext[i] + case leakedNext[i] && !leakedPrev[i]: + mergedReturns[i] = normalizedPrev[i] + default: + mergedReturns[i] = typ.JoinReturnSlot(normalizedPrev[i], normalizedNext[i]) + } + } + if returnsummary.Equal(prevFn.Returns, mergedReturns) { + return prevFn, true + } + if returnsummary.Equal(nextFn.Returns, mergedReturns) { + return nextFn, true + } + + effects := prevFn.Effects + if effects == nil { + effects = nextFn.Effects + } + spec := prevFn.Spec + if spec == nil { + spec = nextFn.Spec + } + refinement := prevFn.Refinement + if refinement == nil { + refinement = nextFn.Refinement + } + + builder := typ.Func(). + Effects(effects). + Spec(spec). + WithRefinement(refinement) + for _, tp := range prevFn.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for _, p := range prevFn.Params { + if p.Optional { + builder = builder.OptParam(p.Name, p.Type) + } else { + builder = builder.Param(p.Name, p.Type) + } + } + if prevFn.Variadic != nil { + builder = builder.Variadic(prevFn.Variadic) + } + builder = builder.Returns(mergedReturns...) 
+ return builder.Build(), true +} + func typeParamsEqual(a, b []*typ.TypeParam) bool { if len(a) != len(b) { return false diff --git a/compiler/check/domain/functionfact/fact_test.go b/compiler/check/domain/functionfact/fact_test.go index 54648774..a40b6387 100644 --- a/compiler/check/domain/functionfact/fact_test.go +++ b/compiler/check/domain/functionfact/fact_test.go @@ -492,6 +492,73 @@ func TestMergeType_MapVsOpenRecordUsesCanonicalJoin(t *testing.T) { } } +func TestMergeReturnsForSameSignature_GenericFunctions(t *testing.T) { + prev := typ.Func(). + TypeParam("T", nil). + Returns(typ.String). + Build() + next := typ.Func(). + TypeParam("T", nil). + Returns(typ.Integer). + Build() + + mergedType, ok := MergeReturnsForSameSignature(prev, next) + if !ok { + t.Fatal("expected generic same-shape functions to merge") + } + merged, ok := mergedType.(*typ.Function) + if !ok { + t.Fatalf("expected merged function type, got %T", mergedType) + } + if len(merged.TypeParams) != 1 || merged.TypeParams[0] == nil || merged.TypeParams[0].Name != "T" { + t.Fatalf("expected merged generic type parameter T, got %+v", merged.TypeParams) + } + if len(merged.Returns) != 1 { + t.Fatalf("expected one return, got %d", len(merged.Returns)) + } + want := typ.NewUnion(typ.String, typ.Integer) + if !typ.TypeEquals(merged.Returns[0], want) { + t.Fatalf("expected merged return %v, got %v", want, merged.Returns[0]) + } +} + +func TestMergeReturnsForSameSignature_GenericTypeParamsMustMatch(t *testing.T) { + prev := typ.Func(). + TypeParam("T", nil). + Returns(typ.String). + Build() + next := typ.Func(). + TypeParam("U", nil). + Returns(typ.Integer). + Build() + + _, ok := MergeReturnsForSameSignature(prev, next) + if ok { + t.Fatal("expected mismatched generic params not to merge") + } +} + +func TestMergeReturnsForSameSignature_NormalizesLeakedTypeParams(t *testing.T) { + prev := typ.Func(). + Returns(typ.NewTypeParam("T", nil)). + Build() + next := typ.Func(). + Returns(typ.Integer). 
+ Build() + + mergedType, ok := MergeReturnsForSameSignature(prev, next) + if !ok { + t.Fatal("expected same-shape functions to merge") + } + merged, ok := mergedType.(*typ.Function) + if !ok || len(merged.Returns) != 1 { + t.Fatalf("expected merged function return, got %T", mergedType) + } + if !typ.TypeEquals(merged.Returns[0], typ.Integer) { + t.Fatalf("expected leaked type param to normalize to integer, got %v", merged.Returns[0]) + } +} + func TestNormalize_CanonicalizesStoredFunctionFact(t *testing.T) { fn := typ.Func().Returns(typ.Number).Build() got := Normalize(api.FunctionFact{ diff --git a/compiler/check/domain/returnsummary/summary.go b/compiler/check/domain/returnsummary/summary.go index 579a0270..dceb5699 100644 --- a/compiler/check/domain/returnsummary/summary.go +++ b/compiler/check/domain/returnsummary/summary.go @@ -719,93 +719,69 @@ func Merge(existing, candidate []typ.Type) []typ.Type { return NormalizeAndPrune(typjoin.ReturnVectors(existing, candidate)) } -func shouldUseMonotoneJoin(a, b []typ.Type) bool { - for _, t := range a { - if HasHigherOrderGrowthRisk(t) { - return true - } +// WidenForConvergence merges return vectors at a recursive fixpoint boundary. +func WidenForConvergence(prev, next []typ.Type) []typ.Type { + prev = NormalizeAndPrune(prev) + next = NormalizeAndPrune(next) + if len(prev) == 0 { + return WidenVectorForConvergence(next) } - for _, t := range b { - if HasHigherOrderGrowthRisk(t) { - return true - } + if len(next) == 0 { + return WidenVectorForConvergence(prev) } - return false -} -// HasHigherOrderGrowthRisk reports whether a type can produce non-monotone -// higher-order structural growth across summary iterations. 
-func HasHigherOrderGrowthRisk(t typ.Type) bool { - if t == nil { - return false + merged := Merge(prev, next) + if UnsafePrecisionDrop(prev, merged) { + merged = prev } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - switch n := node.(type) { - case *typ.Function: - for _, ret := range n.Returns { - if containsFunction(ret) { - return true, false - } - } - case *typ.Record: - if recordHasSelfRecursiveMethod(n) { - return true, false - } - } - return false, true - }) + return WidenVectorForConvergence(NormalizeAndPrune(merged)) } -func containsFunction(t typ.Type) bool { - if t == nil { - return false +// WidenVectorForConvergence applies element-wise convergence widening. +func WidenVectorForConvergence(rets []typ.Type) []typ.Type { + if len(rets) == 0 { + return rets } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - if _, ok := node.(*typ.Interface); ok { - return false, false - } - if _, ok := node.(*typ.Function); ok { - return true, false + out := make([]typ.Type, len(rets)) + changed := false + for i, t := range rets { + wt := value.WidenForConvergence(t) + out[i] = wt + if wt != t { + changed = true } - return false, true - }) + } + if !changed { + return rets + } + return out } -func recordHasSelfRecursiveMethod(r *typ.Record) bool { - if r == nil { +// UnsafePrecisionDrop reports whether a merged vector lost prior evidence. 
+func UnsafePrecisionDrop(prev, merged []typ.Type) bool { + if len(prev) == 0 || len(merged) == 0 || len(prev) != len(merged) { return false } - for _, f := range r.Fields { - if methodTypeHasSelfRecursiveReturn(f.Type, r) { + for i := range prev { + if value.UnsafePrecisionDrop(prev[i], merged[i]) { return true } } - return r.HasMapComponent() && methodTypeHasSelfRecursiveReturn(r.MapValue, r) + return false } -func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { - if t == nil || owner == nil { - return false - } - return value.Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { - if _, ok := node.(*typ.Interface); ok { - return false, false - } - fn, ok := node.(*typ.Function) - if !ok { - return false, true +func shouldUseMonotoneJoin(a, b []typ.Type) bool { + for _, t := range a { + if value.HasHigherOrderGrowthRisk(t) { + return true } - for _, ret := range fn.Returns { - if ret == nil { - continue - } - if subtype.IsSubtype(ret, owner) || subtype.IsSubtype(owner, ret) || - value.ExtendsRecord(ret, owner) || value.ExtendsRecord(owner, ret) { - return true, false - } + } + for _, t := range b { + if value.HasHigherOrderGrowthRisk(t) { + return true } - return false, true - }) + } + return false } func joinMonotone(a, b []typ.Type) []typ.Type { diff --git a/compiler/check/domain/value/convergence.go b/compiler/check/domain/value/convergence.go new file mode 100644 index 00000000..6fdb4e16 --- /dev/null +++ b/compiler/check/domain/value/convergence.go @@ -0,0 +1,205 @@ +package value + +import ( + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// NormalizeFactType canonicalizes one type before it is stored in an +// interprocedural fact slot. 
+func NormalizeFactType(t typ.Type) typ.Type { + if t == nil { + return nil + } + if fn := unwrap.Function(t); fn != nil { + return fn + } + return typ.PruneSoftUnionMembers(t) +} + +// WidenForConvergence applies the finite-height approximation needed for +// higher-order structural growth. +func WidenForConvergence(t typ.Type) typ.Type { + if t == nil { + return nil + } + if !HasHigherOrderGrowthRisk(t) { + return t + } + return subtype.WidenForInference(t) +} + +// WidenFunctionForConvergence applies convergence widening to a function type. +func WidenFunctionForConvergence(fn *typ.Function) *typ.Function { + if fn == nil { + return nil + } + if widened, ok := WidenForConvergence(fn).(*typ.Function); ok { + return widened + } + return fn +} + +// JoinPrecise merges non-function value facts inside one analysis iteration. +func JoinPrecise(existing, candidate typ.Type) typ.Type { + existing = NormalizeFactType(existing) + candidate = NormalizeFactType(candidate) + if existing == nil { + return candidate + } + if candidate == nil { + return existing + } + return typ.JoinPreferNonSoft(existing, candidate) +} + +// MergeForConvergence merges non-function value facts at a fixpoint boundary. 
+func MergeForConvergence(existing, candidate typ.Type) typ.Type { + existing = NormalizeFactType(existing) + candidate = NormalizeFactType(candidate) + if existing == nil { + return WidenForConvergence(candidate) + } + if candidate == nil { + return WidenForConvergence(existing) + } + existing = WidenForConvergence(existing) + candidate = WidenForConvergence(candidate) + if typ.TypeEquals(existing, candidate) { + return existing + } + if unwrap.IsNilType(existing) && !unwrap.IsNilType(candidate) { + return candidate + } + if unwrap.IsNilType(candidate) && !unwrap.IsNilType(existing) { + return existing + } + if typ.IsAny(existing) || typ.IsUnknown(existing) { + return existing + } + if typ.IsAny(candidate) || typ.IsUnknown(candidate) { + return candidate + } + if ElidesOptional(candidate, existing) { + return candidate + } + if ExtendsRecord(candidate, existing) && !ContainsNestedStructuralShape(candidate, existing) { + return candidate + } + if refines, _ := RefinesFalsyMapKey(candidate, existing); refines { + return candidate + } + if subtype.IsSubtype(candidate, existing) && !subtype.IsSubtype(existing, candidate) { + return existing + } + if subtype.IsSubtype(existing, candidate) && !subtype.IsSubtype(candidate, existing) { + return candidate + } + return typ.JoinPreferNonSoft(existing, candidate) +} + +// UnsafePrecisionDrop reports whether merged lost a previously possible branch +// from prev while appearing as a subtype refinement. 
+func UnsafePrecisionDrop(prev, merged typ.Type) bool { + if prev == nil || merged == nil || typ.TypeEquals(prev, merged) { + return false + } + if ElidesOptional(merged, prev) { + return false + } + if refines, _ := RefinesFalsyMapKey(merged, prev); refines { + return false + } + if typ.IsAny(prev) || typ.IsUnknown(prev) { + return true + } + + switch p := UnwrapStructuralShape(prev).(type) { + case *typ.Union: + if unionStrictMemberSubset(merged, p) { + return true + } + if subtype.IsSubtype(merged, p) && !subtype.IsSubtype(p, merged) { + return true + } + case *typ.Record: + m, ok := UnwrapStructuralShape(merged).(*typ.Record) + if !ok { + break + } + for _, pf := range p.Fields { + mf := m.GetField(pf.Name) + if mf != nil && UnsafePrecisionDrop(pf.Type, mf.Type) { + return true + } + } + if p.HasMapComponent() && m.HasMapComponent() && UnsafePrecisionDrop(p.MapValue, m.MapValue) { + return true + } + case *typ.Array: + if m, ok := UnwrapStructuralShape(merged).(*typ.Array); ok { + return UnsafePrecisionDrop(p.Element, m.Element) + } + case *typ.Map: + if m, ok := UnwrapStructuralShape(merged).(*typ.Map); ok { + return UnsafePrecisionDrop(p.Key, m.Key) || UnsafePrecisionDrop(p.Value, m.Value) + } + case *typ.Tuple: + m, ok := UnwrapStructuralShape(merged).(*typ.Tuple) + if !ok || len(p.Elements) != len(m.Elements) { + break + } + for i := range p.Elements { + if UnsafePrecisionDrop(p.Elements[i], m.Elements[i]) { + return true + } + } + case *typ.Function: + m, ok := UnwrapStructuralShape(merged).(*typ.Function) + if !ok { + break + } + for i := 0; i < len(p.Params) && i < len(m.Params); i++ { + if UnsafePrecisionDrop(p.Params[i].Type, m.Params[i].Type) { + return true + } + } + for i := 0; i < len(p.Returns) && i < len(m.Returns); i++ { + if UnsafePrecisionDrop(p.Returns[i], m.Returns[i]) { + return true + } + } + } + + if subtype.IsSubtype(merged, prev) && !subtype.IsSubtype(prev, merged) && !ExtendsRecord(merged, prev) { + return true + } + return false +} + 
+func unionStrictMemberSubset(candidate typ.Type, baseline *typ.Union) bool { + if baseline == nil { + return false + } + candidateMembers := UnionMembers(candidate) + if len(candidateMembers) == 0 { + candidateMembers = []typ.Type{candidate} + } + if len(candidateMembers) >= len(baseline.Members) { + return false + } + for _, member := range candidateMembers { + found := false + for _, baseMember := range baseline.Members { + if typ.TypeEquals(member, baseMember) { + found = true + break + } + } + if !found { + return false + } + } + return true +} diff --git a/compiler/check/domain/value/convergence_test.go b/compiler/check/domain/value/convergence_test.go new file mode 100644 index 00000000..f8b20f76 --- /dev/null +++ b/compiler/check/domain/value/convergence_test.go @@ -0,0 +1,26 @@ +package value + +import ( + "testing" + + "github.com/wippyai/go-lua/types/typ" +) + +func TestUnsafePrecisionDrop_DetectsNestedUnionMemberDrop(t *testing.T) { + withPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("pending"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + withoutPending := typ.NewUnion( + typ.LiteralString("pass"), + typ.LiteralString("fail"), + typ.LiteralString("skip"), + ) + prev := typ.NewRecord().Field("status", withPending).Build() + next := typ.NewRecord().Field("status", withoutPending).Build() + if !UnsafePrecisionDrop(prev, next) { + t.Fatalf("expected nested union member drop to be unsafe: prev=%v next=%v", prev, next) + } +} diff --git a/compiler/check/domain/value/doc.go b/compiler/check/domain/value/doc.go new file mode 100644 index 00000000..4d59037f --- /dev/null +++ b/compiler/check/domain/value/doc.go @@ -0,0 +1,8 @@ +// Package value owns reusable structural relations over typ.Type values. 
+// +// These relations are below return summaries, function facts, and whole fact +// products: optional elision, soft-placeholder preference, table-key truthiness +// refinement, recursive-growth detection, convergence widening, and unsafe +// precision-drop checks live here so higher domains can compose them without +// reimplementing local helper clusters. +package value diff --git a/compiler/check/domain/value/growth.go b/compiler/check/domain/value/growth.go new file mode 100644 index 00000000..6933336d --- /dev/null +++ b/compiler/check/domain/value/growth.go @@ -0,0 +1,81 @@ +package value + +import ( + "github.com/wippyai/go-lua/types/subtype" + "github.com/wippyai/go-lua/types/typ" +) + +// HasHigherOrderGrowthRisk reports whether a type can produce non-monotone +// higher-order structural growth across abstract-interpretation iterations. +func HasHigherOrderGrowthRisk(t typ.Type) bool { + if t == nil { + return false + } + return Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + switch n := node.(type) { + case *typ.Function: + for _, ret := range n.Returns { + if containsFunction(ret) { + return true, false + } + } + case *typ.Record: + if recordHasSelfRecursiveMethod(n) { + return true, false + } + } + return false, true + }) +} + +func containsFunction(t typ.Type) bool { + if t == nil { + return false + } + return Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if _, ok := node.(*typ.Interface); ok { + return false, false + } + if _, ok := node.(*typ.Function); ok { + return true, false + } + return false, true + }) +} + +func recordHasSelfRecursiveMethod(r *typ.Record) bool { + if r == nil { + return false + } + for _, f := range r.Fields { + if methodTypeHasSelfRecursiveReturn(f.Type, r) { + return true + } + } + return r.HasMapComponent() && methodTypeHasSelfRecursiveReturn(r.MapValue, r) +} + +func methodTypeHasSelfRecursiveReturn(t typ.Type, owner *typ.Record) bool { + if t == nil || owner == nil { + return false + } + 
return Scan(t, typ.NewGuard(), func(node typ.Type) (bool, bool) { + if _, ok := node.(*typ.Interface); ok { + return false, false + } + fn, ok := node.(*typ.Function) + if !ok { + return false, true + } + for _, ret := range fn.Returns { + if ret == nil { + continue + } + if subtype.IsSubtype(ret, owner) || subtype.IsSubtype(owner, ret) || + ExtendsRecord(ret, owner) || ExtendsRecord(owner, ret) { + return true, false + } + } + return false, true + }) +} diff --git a/compiler/check/domain/returnsummary/summary_test.go b/compiler/check/domain/value/growth_test.go similarity index 98% rename from compiler/check/domain/returnsummary/summary_test.go rename to compiler/check/domain/value/growth_test.go index 145beb68..8608e62d 100644 --- a/compiler/check/domain/returnsummary/summary_test.go +++ b/compiler/check/domain/value/growth_test.go @@ -1,4 +1,4 @@ -package returnsummary +package value import ( "testing" From 1d602bc86fc4199eac9cd3ce4c3c08ab2fc5f359 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 07:28:53 -0400 Subject: [PATCH 24/71] Fix contextual narrowing false positives --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 102 ++++++++++ compiler/check/synth/ops/call.go | 72 ++++++- .../check/synth/phase/extract/pipeline.go | 2 + .../synth/phase/extract/pipeline_test.go | 87 +++++++- .../external_lint_regression_test.go | 189 ++++++++++++++++++ types/narrow/filter.go | 80 +++++++- types/narrow/filter_test.go | 57 ++++++ 7 files changed, 577 insertions(+), 12 deletions(-) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 7de2c7dc..eefa294f 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -397,6 +397,108 @@ Verification for this slice so far: 2 warnings. One first run printed agent/src 12 errors; direct replay of that target and a full rerun both returned 10. 
+## 2026-05-19 False-Positive Replay And Domain Refinement Checkpoint + +The next pass classified remaining local-replace lint failures and fixed the +ones that were checker false positives without weakening `any` soundness. + +Direct engine fixes: + +- the call pipeline now re-synthesizes every expected-sensitive argument form + that can change meaning under a concrete callee parameter expectation: + function literals, table literals, identifiers, attribute reads, explicit + casts, logical operators, call expressions, and non-nil assertions; +- intersection callees now publish contextual expected-argument vectors during + phase one, using the same merge law as union callees while still requiring + `FinishCall` to validate every intersection member; +- positive field-literal narrowing is now a domain meet for top/open table + shapes instead of only a union filter. A guard such as `part.type == "image"` + materializes the proven field on `any`, table top, maps, and open records with + row-tail evidence. Existing closed broad fields keep the previous "may match" + policy, so `field: string` does not collapse to a literal singleton merely + because one branch compared it. + +The key false-positive class was: + +```lua +for _, part in ipairs(content) do + if part.type == "text" and part.text and part.text ~= "" then + table.insert(content_blocks, { text = part.text }) + elseif part.type == "image" then + convert_image_to_converse(part) + end +end +``` + +When `content` came from `any`, the negative side of the text branch could +create an open `{text: ""}` shape. The later `part.type == "image"` check kept +that open shape because the open tail could contain `type`, but it failed to +record the hard proof that this branch's `type` field is present and equal to +`"image"`. The result was a false error when passing `part` to a helper that +requires a `type: string` field. 
+ +Correct abstract interpretation: + +```text +Observation: part.type == "image" +Location: Location(part).field("type") +Evidence: hard runtime proof, field-literal equality +Domain: value/shape meet +State: open row-tail shape plus explicit type = "image" +Query: helper parameter assignability sees required type field +``` + +Wrong interpretation: + +```text +open row-tail may contain type -> keep the old shape unchanged +``` + +That wrong interpretation lost proof. It was not a reason to let `any` flow +into concrete contracts generally. + +Regression coverage added: + +- imported optional response-body fallback into an imported string call; +- explicit cast of an imported unknown field into an imported method call; +- intersection callee expected-argument publication; +- logical/cast/call/non-nil expected-sensitive argument re-synthesis; +- discriminated array elements from typed and untyped sources; +- open-record field-literal meet commutativity and union refinement laws. + +Local-replace Wippy replay after this fix: + +- `wippy.llm.bedrock:mapper` line 240 is clean; the reproduced checker false + positive is gone. +- `wippy.llm.bedrock:mapper` still reports line 503 (`parse_text_tool_call(text, + tool_names)` with `text` from `text_blocks`). This is not fixed in go-lua + because `text_blocks` is populated from `block.text` on an untyped external + payload. `if block.text then` proves truthiness, not stringness. Treating that + as string would be an `any`-to-concrete unsoundness unless the engine grows an + explicit successful-operator refinement model for `..` and string methods. +- session dependency diagnostics such as `expected string, got string?` remain + tied to pinned/locked external source shapes without a local fallback or cast. +- larger local-replace sweeps still contain true strictness diagnostics where + `any`, `unknown`, optional values, or intentionally invalid test inputs flow + into concrete contracts. 
Those must not be hidden by changing go-lua + assignability. + +Verification for this pass: + +- `go test ./types/constraint ./types/flow ./types/narrow` passes. +- `go test ./compiler/check/synth/phase/extract ./compiler/check/synth/ops + ./compiler/check/tests/regression` passes. +- `go test ./compiler/check/...` passes. +- `go test ./...` passes. +- `git diff --check` passes. +- `go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction + -benchmem -count=3` reports about 1.13-1.15 ms/op, 881 KB/op, and 9390 + allocs/op on this machine. +- `../scripts/verify-suite.sh` passes the go-lua checker tests and Wippy binary + build, then exits non-zero on the external pinned lint targets: + session 8 errors, agent/src 8 errors, docker-demo 21 errors and 2 warnings. + The rest of the verify-suite lint targets report zero diagnostics. + ## Goal The checker should read as one abstract interpreter over a product domain. diff --git a/compiler/check/synth/ops/call.go b/compiler/check/synth/ops/call.go index 000772f1..c2f7451b 100644 --- a/compiler/check/synth/ops/call.go +++ b/compiler/check/synth/ops/call.go @@ -360,14 +360,7 @@ func InferCall(ctx *db.QueryContext, def CallDef) InferResult { } if callee.Kind() == kind.Intersection { - return InferResult{ - Kind: InferKindIntersection, - Callee: callee, - Receiver: receiver, - IsMethod: isMethod, - ForceMethodReceiver: def.ForceMethodReceiver, - Errors: errors, - } + return inferIntersection(ctx, callee.(*typ.Intersection), def, isMethod, receiver, errors) } fn, ok := callee.(*typ.Function) @@ -420,6 +413,69 @@ func inferFunction(ctx *db.QueryContext, fn *typ.Function, def CallDef, isMethod return result } +// inferIntersection aggregates contextual argument expectations from every +// callable intersection member. FinishCall still validates every member, so +// these expectations guide expression synthesis without weakening checking. 
+func inferIntersection(ctx *db.QueryContext, inter *typ.Intersection, def CallDef, isMethod bool, receiver typ.Type, errors []CallError) InferResult { + result := InferResult{ + Kind: InferKindIntersection, + Callee: inter, + Receiver: receiver, + IsMethod: isMethod, + ForceMethodReceiver: def.ForceMethodReceiver, + Errors: errors, + } + if inter == nil { + return result + } + + var ( + aggExpected []typ.Type + aggVariadic typ.Type + found bool + ) + for _, member := range inter.Members { + fn, ok := member.(*typ.Function) + if !ok { + continue + } + + instantiated := fn + typeArgs := []typ.Type(nil) + if len(fn.TypeParams) > 0 { + if len(def.TypeArgs) > 0 { + typeArgs = def.TypeArgs + } else { + var err error + typeArgs, err = InferTypeArgsWithExpectedAndMode(fn, def.Args, isMethod, receiver, def.ExpectedReturn, def.ForceMethodReceiver) + if err != nil { + continue + } + } + instantiated = InstantiateFunction(fn, typeArgs) + } + + expectedArgs, expectedVariadic := computeExpectedArgs(ctx, def.Query, instantiated, isMethod, receiver, def.ForceMethodReceiver) + if !found { + found = true + result.Function = fn + result.TypeArgs = typeArgs + result.Instantiated = instantiated + aggExpected = append([]typ.Type(nil), expectedArgs...) + aggVariadic = expectedVariadic + continue + } + aggExpected = mergeExpectedArgVectors(aggExpected, expectedArgs) + aggVariadic = typ.JoinPreferNonSoft(aggVariadic, expectedVariadic) + } + + if found { + result.ExpectedArgs = aggExpected + result.ExpectedVariadic = aggVariadic + } + return result +} + // inferUnion handles inference for union callees. // For unions, we attempt to infer each member separately and return // expected types from the first successful inference. 
diff --git a/compiler/check/synth/phase/extract/pipeline.go b/compiler/check/synth/phase/extract/pipeline.go index 25876b7a..4f08e5d2 100644 --- a/compiler/check/synth/phase/extract/pipeline.go +++ b/compiler/check/synth/phase/extract/pipeline.go @@ -156,6 +156,8 @@ func FullArgReSynth( return expected } return inferred + case *ast.CastExpr, *ast.LogicalOpExpr, *ast.FuncCallExpr, *ast.NonNilAssertExpr: + return synthWithExpected(a, p, expected) } return nil } diff --git a/compiler/check/synth/phase/extract/pipeline_test.go b/compiler/check/synth/phase/extract/pipeline_test.go index 8258cde2..b465529a 100644 --- a/compiler/check/synth/phase/extract/pipeline_test.go +++ b/compiler/check/synth/phase/extract/pipeline_test.go @@ -123,6 +123,49 @@ func TestCallPipeline_ExpectedArgType_OutOfRange(t *testing.T) { } } +func TestCallPipeline_ExpectedArgType_Intersection(t *testing.T) { + ctx := db.NewQueryContext(db.New()) + fnA := typ.Func().Param("x", typ.String).Returns(typ.Any).Build() + fnB := typ.Func().Param("x", typ.String).Returns(typ.Unknown).Build() + def := ops.CallDef{ + Callee: typ.NewIntersection(fnA, fnB), + Args: []typ.Type{typ.NewOptional(typ.String)}, + } + pipeline := NewCallPipeline(ctx, def, []ast.Expr{&ast.LogicalOpExpr{}}) + pipeline.Infer() + + arg0 := pipeline.ExpectedArgType(0) + if arg0 != typ.String { + t.Fatalf("got %v, want string", arg0) + } +} + +func TestCallPipeline_IntersectionReSynthesizesLogicalArg(t *testing.T) { + ctx := db.NewQueryContext(db.New()) + fnA := typ.Func().Param("x", typ.String).Returns(typ.Any).Build() + fnB := typ.Func().Param("x", typ.String).Returns(typ.Unknown).Build() + arg := &ast.LogicalOpExpr{} + def := ops.CallDef{ + Callee: typ.NewIntersection(fnA, fnB), + Args: []typ.Type{typ.NewOptional(typ.String)}, + } + pipeline := NewCallPipeline(ctx, def, []ast.Expr{arg}). 
+ WithReSynth(func(idx int, got ast.Expr, expected typ.Type) typ.Type { + if got != arg { + t.Fatalf("got arg %p, want %p", got, arg) + } + if expected != typ.String { + t.Fatalf("got expected %v, want string", expected) + } + return typ.String + }) + + result := pipeline.Run() + if len(result.Errors) != 0 { + t.Fatalf("expected no errors after contextual re-synthesis, got %v", result.Errors) + } +} + func TestCallPipeline_ReSynthAndReInfer_NoReSynth(t *testing.T) { ctx := db.NewQueryContext(db.New()) fn := typ.Func().Build() @@ -272,7 +315,49 @@ func TestFullArgReSynth_Other(t *testing.T) { result := reSynth(0, &ast.NumberExpr{}, typ.Integer) if result != nil { - t.Fatal("expected nil for non-function/table") + t.Fatal("expected nil for expression that does not benefit from contextual re-synthesis") + } +} + +func TestFullArgReSynth_Cast(t *testing.T) { + called := false + synthWithExpected := func(arg ast.Expr, p cfg.Point, expected typ.Type) typ.Type { + called = true + if _, ok := arg.(*ast.CastExpr); !ok { + t.Fatalf("got %T, want CastExpr", arg) + } + return typ.String + } + + reSynth := FullArgReSynth(synthWithExpected, nil, 0) + result := reSynth(0, &ast.CastExpr{}, typ.String) + + if !called { + t.Fatal("expected callback to be called for cast expression") + } + if result != typ.String { + t.Fatalf("got %v, want string", result) + } +} + +func TestFullArgReSynth_Logical(t *testing.T) { + called := false + synthWithExpected := func(arg ast.Expr, p cfg.Point, expected typ.Type) typ.Type { + called = true + if _, ok := arg.(*ast.LogicalOpExpr); !ok { + t.Fatalf("got %T, want LogicalOpExpr", arg) + } + return typ.String + } + + reSynth := FullArgReSynth(synthWithExpected, nil, 0) + result := reSynth(0, &ast.LogicalOpExpr{}, typ.String) + + if !called { + t.Fatal("expected callback to be called for logical expression") + } + if result != typ.String { + t.Fatalf("got %v, want string", result) } } diff --git 
a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go index 10b2ac32..e09dc9e1 100644 --- a/compiler/check/tests/regression/external_lint_regression_test.go +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -48,6 +48,44 @@ local parsed, parse_err = json.decode(response.body or "") } } +func TestExternalLint_ImportedOptionalResponseBodyDefaultIsStringAtCall(t *testing.T) { + jsonModule := testutil.CheckAndExport(` +local json = {} +function json.decode(raw: string): any + return {} +end +return json +`, "json", testutil.WithStdlib()) + if jsonModule.HasError() { + t.Fatalf("json module errors: %v", testutil.ErrorMessages(jsonModule.Errors)) + } + + source := ` +local json = require("json") + +type Response = { + status_code: number, + body: string?, +} + +local function request(): (Response?, string?) + return { status_code = 200 }, nil +end + +local response, err = request() +if not response then + return nil, err +end + +local parsed, parse_err = json.decode(response.body or "") +return parsed, parse_err +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("json", jsonModule)) + if result.HasError() { + t.Fatalf("expected imported optional body fallback to feed string call argument, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_GuardedOptionsModelSurvivesProviderBranches(t *testing.T) { source := ` local models = { @@ -292,6 +330,50 @@ local result, err = get_page_data({ data_func = true }) } } +func TestExternalLint_CastUnknownImportedFieldFeedsImportedMethodCall(t *testing.T) { + templatesModule := testutil.CheckAndExport(` +local templates = {} +function templates.get(id: string) + return { + render = function(self, name: string, context: table) + return name, nil + end, + release = function(self) + end, + }, nil +end +return templates +`, "templates", testutil.WithStdlib()) + if 
templatesModule.HasError() { + t.Fatalf("templates module errors: %v", testutil.ErrorMessages(templatesModule.Errors)) + } + + source := ` +local templates = require("templates") + +local function get_page() + return { + template_set = "main", + template_name = nil :: unknown, + } +end + +local page = get_page() +local tmpl, err = templates.get(page.template_set) +if err then + return nil, err +end + +local content, render_err = tmpl:render(page.template_name :: string, {}) +tmpl:release() +return content, render_err +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("templates", templatesModule)) + if result.HasError() { + t.Fatalf("expected explicit cast of imported field to feed method argument checking, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_InsertedSuiteShapeSurvivesIpairs(t *testing.T) { source := ` type Suite = { @@ -584,3 +666,110 @@ end t.Fatalf("expected discriminated array element to feed image helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) } } + +func TestExternalLint_DiscriminatedArrayElementDoesNotInheritAccumulatorShape(t *testing.T) { + source := ` +type ImagePart = { + type: string, + source: any?, + text: string?, +} + +local function convert_image_to_converse(content_part: ImagePart) + if content_part.type == "image" and content_part.source then + return { image = content_part.source } + end + return nil +end + +local message = { + content = { + { type = "text", text = "hello" }, + { type = "image", source = { media_type = "image/png", data = "abc" } }, + }, +} + +local content_blocks = {} +for _, part in ipairs(message.content) do + if part.type == "text" and part.text and part.text ~= "" then + table.insert(content_blocks, { text = part.text }) + elseif part.type == "image" then + local img = convert_image_to_converse(part) + if img then + table.insert(content_blocks, img) + end + end +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if 
result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected discriminated source array element not to inherit accumulator-only shapes, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_UntypedDiscriminatedArrayElementFeedsTypedBranchHelper(t *testing.T) { + source := ` +type ImagePart = { + type: string, + source: any?, + text: string?, +} + +local prompt = { + ROLE = { + ASSISTANT = "assistant", + } +} + +local function convert_image_to_converse(content_part: ImagePart) + if content_part.type == "image" and content_part.source then + return { image = content_part.source } + end + return nil +end + +local function map_messages(contract_messages) + local converse_messages = {} + for _, msg in ipairs(contract_messages) do + if msg.role == prompt.ROLE.ASSISTANT then + local content_blocks = {} + local content = msg.content + if type(content) == "string" then + if content ~= "" then + table.insert(content_blocks, { text = content }) + end + elseif type(content) == "table" then + for _, part in ipairs(content) do + if part.type == "text" and part.text and part.text ~= "" then + table.insert(content_blocks, { text = part.text }) + elseif part.type == "function_call" then + table.insert(content_blocks, { toolUse = { name = part.name or "" } }) + elseif part.type == "image" then + local img = convert_image_to_converse(part) + if img then + table.insert(content_blocks, img) + end + end + end + end + if #content_blocks > 0 then + table.insert(converse_messages, { role = "assistant", content = content_blocks }) + end + end + end + return converse_messages +end + +map_messages(nil :: any) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected untyped discriminated source 
element to feed typed image helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/types/narrow/filter.go b/types/narrow/filter.go index 2abf320f..89c4ec2f 100644 --- a/types/narrow/filter.go +++ b/types/narrow/filter.go @@ -325,6 +325,19 @@ func ByFieldLiteral(t typ.Type, field string, lit *typ.Literal, resolver Resolve if t == nil || field == "" || lit == nil || resolver == nil { return t } + return refineByFieldLiteral(t, field, lit, resolver) +} + +func refineByFieldLiteral(t typ.Type, field string, lit *typ.Literal, resolver Resolver) typ.Type { + if t == nil { + return t + } + if a, ok := t.(*typ.Alias); ok { + return refineByFieldLiteral(a.Target, field, lit, resolver) + } + if expanded := unwrap.Instantiated(t); expanded != t { + return refineByFieldLiteral(expanded, field, lit, resolver) + } if t.Kind().IsPlaceholder() || unwrap.IsBuiltinTableTop(t) { // Refining `table` by a field literal should materialize a structural // shape so downstream assignment/subtyping can use the discriminant. @@ -333,9 +346,70 @@ func ByFieldLiteral(t typ.Type, field string, lit *typ.Literal, resolver Resolve return typ.NewRecord().Field(field, lit).SetOpen(true).Build() } - return FilterByMatch(t, func(m typ.Type) bool { - return FieldMatchesLiteral(m, field, lit, resolver) - }, false) + switch v := t.(type) { + case *typ.Optional: + return refineByFieldLiteral(v.Inner, field, lit, resolver) + case *typ.Union: + kept := make([]typ.Type, 0, len(v.Members)) + for _, m := range v.Members { + refined := refineByFieldLiteral(m, field, lit, resolver) + if refined != nil && !refined.Kind().IsNever() { + kept = append(kept, refined) + } + } + if len(kept) == 0 { + return typ.Never + } + return typ.NewUnion(kept...) 
+ case *typ.Record: + return refineRecordByFieldLiteral(v, field, lit, resolver) + case *typ.Map: + return refineMapByFieldLiteral(v, field, lit, resolver) + case *typ.Intersection: + if FieldMatchesLiteral(v, field, lit, resolver) { + return v + } + return typ.Never + default: + if FieldMatchesLiteral(t, field, lit, resolver) { + return t + } + return typ.Never + } +} + +func refineRecordByFieldLiteral(r *typ.Record, field string, lit *typ.Literal, resolver Resolver) typ.Type { + if r == nil { + return typ.Never + } + if f := r.GetField(field); f != nil { + if f.Type == nil || !typ.TypeMatchesLiteral(f.Type, lit) { + return typ.Never + } + if f.Type.Kind().IsPlaceholder() { + return typ.ExtendRecordWithField(r, field, lit) + } + return r + } + fieldType, ok := resolver.Field(r, field) + if !ok || fieldType == nil || !typ.TypeMatchesLiteral(fieldType, lit) { + return typ.Never + } + return typ.ExtendRecordWithField(r, field, lit) +} + +func refineMapByFieldLiteral(m *typ.Map, field string, lit *typ.Literal, resolver Resolver) typ.Type { + if m == nil { + return typ.Never + } + fieldType, ok := resolver.Field(m, field) + if !ok || fieldType == nil || !typ.TypeMatchesLiteral(fieldType, lit) { + return typ.Never + } + return typ.NewRecord(). + Field(field, lit). + MapComponent(m.Key, m.Value). + Build() } // ExcludeByFieldLiteral excludes union members where a field exactly equals a literal. 
diff --git a/types/narrow/filter_test.go b/types/narrow/filter_test.go index c4bc0317..464c349d 100644 --- a/types/narrow/filter_test.go +++ b/types/narrow/filter_test.go @@ -137,6 +137,9 @@ func (r *mockResolver) Field(t typ.Type, name string) (typ.Type, bool) { if f := rec.GetField(name); f != nil { return f.Type, true } + if rec.Open { + return typ.Unknown, true + } } key := t.String() if fields, ok := r.fields[key]; ok { @@ -366,6 +369,60 @@ func TestByFieldLiteral_PlaceholderMaterializesRecord(t *testing.T) { } } +func TestByFieldLiteral_OpenRecordMissingFieldMaterializesLiteral(t *testing.T) { + resolver := newMockResolver() + lit := typ.LiteralString("image") + base := typ.NewRecord().Field("text", typ.LiteralString("")).SetOpen(true).Build() + + result := ByFieldLiteral(base, "type", lit, resolver) + want := typ.NewRecord(). + Field("text", typ.LiteralString("")). + Field("type", lit). + SetOpen(true). + Build() + if !typ.TypeEquals(result, want) { + t.Errorf("ByFieldLiteral(open record, type, \"image\") = %v, want %v", result, want) + } +} + +func TestByFieldLiteral_OpenRecordRefinementIsOrderIndependent(t *testing.T) { + resolver := newMockResolver() + image := typ.LiteralString("image") + empty := typ.LiteralString("") + + leftFirst := ByFieldLiteral(ByFieldLiteral(typ.Any, "text", empty, resolver), "type", image, resolver) + rightFirst := ByFieldLiteral(ByFieldLiteral(typ.Any, "type", image, resolver), "text", empty, resolver) + if !typ.TypeEquals(leftFirst, rightFirst) { + t.Fatalf("field literal refinements should commute, got %v and %v", leftFirst, rightFirst) + } + + want := typ.NewRecord(). + Field("text", empty). + Field("type", image). + SetOpen(true). 
+ Build() + if !typ.TypeEquals(leftFirst, want) { + t.Errorf("combined refinement = %v, want %v", leftFirst, want) + } +} + +func TestByFieldLiteral_UnionRefinesOpenMembers(t *testing.T) { + resolver := newMockResolver() + image := typ.LiteralString("image") + textOnly := typ.NewRecord().Field("text", typ.LiteralString("")).SetOpen(true).Build() + functionCall := typ.NewRecord().Field("type", typ.LiteralString("function_call")).SetOpen(true).Build() + + result := ByFieldLiteral(typ.NewUnion(textOnly, functionCall), "type", image, resolver) + want := typ.NewRecord(). + Field("text", typ.LiteralString("")). + Field("type", image). + SetOpen(true). + Build() + if !typ.TypeEquals(result, want) { + t.Errorf("ByFieldLiteral(union, type, \"image\") = %v, want %v", result, want) + } +} + func TestExcludeByFieldLiteral_EmptyField(t *testing.T) { resolver := newMockResolver() rec := typ.NewRecord().Field("kind", typ.LiteralString("a")).Build() From 97f222f5c916252c1bcc7ee5c57aac1fcdf26808 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 12:33:20 -0400 Subject: [PATCH 25/71] Rectify interproc function facts domain --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 72 ++ compiler/bind/fieldpath.go | 16 + compiler/bind/table.go | 38 + compiler/bind/table_test.go | 45 + compiler/check/api/facts.go | 34 + compiler/check/api/synth.go | 3 + compiler/check/api/synth_test.go | 3 + compiler/check/domain/functionfact/fact.go | 176 +++- .../paramevidence/parameter_evidence.go | 24 +- compiler/check/flowbuild/numconst/numconst.go | 38 + .../check/flowbuild/numconst/numconst_test.go | 26 + compiler/check/hooks/call_check.go | 115 ++- compiler/check/hooks/local_param_use.go | 196 +++++ compiler/check/infer/interproc/postflow.go | 37 + compiler/check/phase/types_test.go | 6 +- compiler/check/synth/engine_test.go | 4 + compiler/check/synth/intercept/chain.go | 1 + compiler/check/synth/intercept/intercept.go | 4 + .../check/synth/intercept/setmetatable.go | 128 +++ 
.../synth/intercept/setmetatable_test.go | 93 ++ compiler/check/synth/ops/call.go | 13 +- compiler/check/synth/phase/extract/call.go | 239 +++++- compiler/check/synth/phase/extract/expr.go | 169 ++++ .../synth/phase/extract/named_function.go | 199 ++++- .../external_lint_regression_test.go | 805 ++++++++++++++++++ .../regression/linter_false_positive_test.go | 33 + types/constraint/atom.go | 6 + types/constraint/numeric.go | 53 +- types/constraint/numeric_test.go | 41 +- types/constraint/visit.go | 40 +- types/flow/numeric/domain.go | 4 + types/flow/numeric/state.go | 153 +++- types/flow/numeric/state_test.go | 24 + types/flow/query.go | 19 + 34 files changed, 2792 insertions(+), 65 deletions(-) create mode 100644 compiler/check/hooks/local_param_use.go create mode 100644 compiler/check/synth/intercept/setmetatable.go create mode 100644 compiler/check/synth/intercept/setmetatable_test.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index eefa294f..b27acedd 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5045,3 +5045,75 @@ If this is done as a flash migration, the codebase should become smaller because many helper clusters collapse into a few named domains. It should also become easier to reason about because every merge/refinement/widening decision will have one owner and one law-test suite. + +## 2026-05-19 Engine Verification And Classification Checkpoint + +This pass removed the remaining parameter-count heuristic in call diagnostics. +The old shape was not a domain law: graph-local calls were relaxed based on +source arity. 
The replacement is a semantic boundary: + +- function facts remain the call contract authority; +- explicit `any` arguments are only ignored for graph-local parameter slots + whose value is never observed by the function body; +- observed parameters still enforce their declared or inferred contract; +- the unobserved-parameter mask is computed once per function symbol during the + call-check pass from binder symbol identity, so shadowing and captured uses are + handled by symbols, not names. + +Regression coverage added: + +- an internal `any` passed to an unobserved local parameter does not create a + false positive; +- the same `any` passed to an observed `string` parameter remains an error; +- imported/manifest call boundaries that require `string` still reject `any`; +- the external-lint reductions now also cover selected HTTP response body + fallback, error-guarded imported page field casts, and captured state-field + map iteration. + +Verification from this checkpoint: + +```text +go test ./... -count=1 +go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction -benchmem -count=3 +../scripts/verify-suite.sh +``` + +Results: + +- `go test ./... -count=1` passes. +- `BenchmarkCheck_LargeFunction` is about 1.83-1.86 ms/op, about 1.05 MB/op, + and 10,695 allocs/op on this machine. +- `../scripts/verify-suite.sh` passes go-lua checker tests and builds Wippy, but + the external lint section still exits non-zero because the script builds Wippy + against `github.com/wippyai/go-lua v1.5.16`, not this checkout. +- A temporary Wippy binary built with a `/tmp` `go.mod` replacement pointed at + this checkout was used for classification without editing external Wippy code. + +Current external-lint classification: + +- The official pinned verify output cannot prove current go-lua regressions. + It reported `session` 8, `framework/src/agent/src` 13 during the script run + and 8 on direct replay, and `docker-demo` 21 errors / 2 warnings. 
+- The local-replace replay is stricter than the pinned binary. It reports many + explicit `any` to concrete-contract errors. Those are soundness-preserving + external code or manifest contract issues unless reduced to a go-lua false + positive. +- High-confidence engine candidates were reduced where possible. The current + reduced go-lua fixtures for response-body fallback, page-field casts, + captured state map iteration, length guards, setmetatable prototypes, query + builder back-references, and imported assertions pass. +- Remaining unreduced candidates are mostly context-sensitive Wippy package + interactions: imported module manifests that expose `unknown`/`any`, generated + package cache shape, and real code paths that pass unchecked dynamic values + into concrete APIs. They should not be fixed by weakening `any` or erasing + `unknown`. + +Design rule retained for the next pass: + +- Do not add compatibility channels or fallback facts. +- Do not make `any` silently assignable to concrete types. +- If an external diagnostic is a false positive, first reduce it into a + go-lua regression that fails for the same semantic reason, then fix the + owning domain or transfer rule. +- If a diagnostic is true external code, keep it classified and do not edit + external Wippy sources from this go-lua PR. diff --git a/compiler/bind/fieldpath.go b/compiler/bind/fieldpath.go index 61bc4987..14c486bb 100644 --- a/compiler/bind/fieldpath.go +++ b/compiler/bind/fieldpath.go @@ -90,3 +90,19 @@ func displayFieldPathKey(path string) string { return path } + +// DirectFieldNameFromKey returns the field name for a one-segment string-like +// field key. Numeric indexes and nested paths do not describe direct prototype +// fields and are rejected. 
+func DirectFieldNameFromKey(path string) (string, bool) { + segs := pathkey.ParseSuffix(path) + if len(segs) != 1 { + return "", false + } + switch segs[0].Kind { + case constraint.SegmentField, constraint.SegmentIndexString: + return segs[0].Name, segs[0].Name != "" + default: + return "", false + } +} diff --git a/compiler/bind/table.go b/compiler/bind/table.go index cae7b488..72547c5c 100644 --- a/compiler/bind/table.go +++ b/compiler/bind/table.go @@ -97,6 +97,12 @@ type fieldPathKey struct { path string } +// FieldSymbolRef identifies a direct field symbol rooted at a base symbol. +type FieldSymbolRef struct { + Name string + Symbol cfg.SymbolID +} + // NewBindingTable creates an empty binding table with all maps initialized. func NewBindingTable() *BindingTable { return NewBindingTableWithHint(0, 0) @@ -409,6 +415,38 @@ func (t *BindingTable) FieldSymbol(baseSym cfg.SymbolID, path string) (cfg.Symbo return sym, ok } +// DirectFieldSymbols returns direct field symbols rooted at baseSym. +// +// Only one-segment string-like fields are returned; nested paths and numeric +// indexes are intentionally excluded because they are not fields on the base +// prototype itself. +func (t *BindingTable) DirectFieldSymbols(baseSym cfg.SymbolID) []FieldSymbolRef { + if t == nil || baseSym == 0 || len(t.fieldSymbols) == 0 { + return nil + } + out := make([]FieldSymbolRef, 0) + for key, sym := range t.fieldSymbols { + if key.base != baseSym || sym == 0 { + continue + } + name, ok := DirectFieldNameFromKey(key.path) + if !ok { + continue + } + out = append(out, FieldSymbolRef{Name: name, Symbol: sym}) + } + if len(out) == 0 { + return nil + } + sort.Slice(out, func(i, j int) bool { + if out[i].Name == out[j].Name { + return out[i].Symbol < out[j].Symbol + } + return out[i].Name < out[j].Name + }) + return out +} + // GetOrCreateFuncLitSymbol returns or creates a symbol for an anonymous function. 
// // Anonymous functions (function literals) need symbols for type assignment diff --git a/compiler/bind/table_test.go b/compiler/bind/table_test.go index 196d433e..2b01046a 100644 --- a/compiler/bind/table_test.go +++ b/compiler/bind/table_test.go @@ -524,6 +524,51 @@ func TestBindingTable_FieldSymbol_NormalizesLegacyBracketStringKey(t *testing.T) } } +func TestBindingTable_DirectFieldSymbols(t *testing.T) { + table := NewBindingTable() + baseSym := cfg.NextSymbolID() + otherSym := cfg.NextSymbolID() + + beta := table.GetOrCreateFieldSymbol(baseSym, "beta") + alpha := table.GetOrCreateFieldSymbol(baseSym, "alpha") + nested := table.GetOrCreateFieldSymbol(baseSym, "alpha.deep") + indexKey, ok := FieldPathKeyFromSegments([]constraint.Segment{ + {Kind: constraint.SegmentIndexString, Name: "quoted-key"}, + }) + if !ok { + t.Fatal("expected canonical string-index key") + } + quoted := table.GetOrCreateFieldSymbol(baseSym, indexKey) + numericKey, ok := FieldPathKeyFromSegments([]constraint.Segment{ + {Kind: constraint.SegmentIndexInt, Index: 1}, + }) + if !ok { + t.Fatal("expected canonical int-index key") + } + _ = table.GetOrCreateFieldSymbol(baseSym, numericKey) + _ = table.GetOrCreateFieldSymbol(otherSym, "alpha") + + got := table.DirectFieldSymbols(baseSym) + want := []FieldSymbolRef{ + {Name: "alpha", Symbol: alpha}, + {Name: "beta", Symbol: beta}, + {Name: "quoted-key", Symbol: quoted}, + } + if len(got) != len(want) { + t.Fatalf("DirectFieldSymbols length = %d, want %d; got %#v", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("DirectFieldSymbols[%d] = %#v, want %#v", i, got[i], want[i]) + } + } + for _, ref := range got { + if ref.Symbol == nested { + t.Fatal("nested field path should not be returned as a direct field") + } + } +} + func TestBindingTable_GetOrCreateFieldSymbol_InvalidPathRejected(t *testing.T) { table := NewBindingTable() baseSym := cfg.NextSymbolID() diff --git a/compiler/check/api/facts.go 
b/compiler/check/api/facts.go index 641c961c..892d6294 100644 --- a/compiler/check/api/facts.go +++ b/compiler/check/api/facts.go @@ -8,6 +8,7 @@ package api import ( "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/typ" ) @@ -74,6 +75,39 @@ func (facts FunctionFacts) FunctionType(sym cfg.SymbolID) typ.Type { return ff.Type } +// FunctionFactSnapshotForSymbol returns the stable fact snapshot for sym. +func FunctionFactSnapshotForSymbol(store StoreReader, sym cfg.SymbolID, defaultParent *scope.State) (FunctionFact, bool) { + if store == nil || sym == 0 { + return FunctionFact{}, false + } + ref := store.FunctionRefBySym(sym) + if ref == nil { + return FunctionFact{}, false + } + parentGraphID := ref.ParentGraphID + if parentGraphID == 0 { + parentGraphID = ref.GraphID + } + parentGraph := store.Graphs()[parentGraphID] + if parentGraph == nil { + return FunctionFact{}, false + } + parent := ParentScopeForGraph(store, parentGraph.ID(), defaultParent) + if parent == nil { + return FunctionFact{}, false + } + return store.GetFunctionFactsSnapshot(parentGraph, parent).Fact(sym) +} + +// FunctionTypeSnapshotForSymbol returns the stable function type fact for sym. +func FunctionTypeSnapshotForSymbol(store StoreReader, sym cfg.SymbolID, defaultParent *scope.State) typ.Type { + ff, ok := FunctionFactSnapshotForSymbol(store, sym, defaultParent) + if !ok { + return nil + } + return ff.Type +} + // LiteralSigs maps anonymous function literal expressions to their signatures. // Used when function literals are passed as arguments or assigned to variables // without explicit type annotations. 
diff --git a/compiler/check/api/synth.go b/compiler/check/api/synth.go index fe724007..fa11f156 100644 --- a/compiler/check/api/synth.go +++ b/compiler/check/api/synth.go @@ -179,6 +179,9 @@ type FlowOps interface { // varName <= len(array) + offset. ArrayLenBoundWithOffsetAt(p cfg.Point, varName string) (arrKey string, offset int64, ok bool) + // LengthBoundsAt returns numeric bounds for len(path) at a point. + LengthBoundsAt(p cfg.Point, path constraint.Path) (lower, upper int64, ok bool) + // IsPointDead returns whether a CFG point is unreachable. IsPointDead(p cfg.Point) bool diff --git a/compiler/check/api/synth_test.go b/compiler/check/api/synth_test.go index dffb9f91..cf5562e2 100644 --- a/compiler/check/api/synth_test.go +++ b/compiler/check/api/synth_test.go @@ -84,6 +84,9 @@ func (m *mockFlowOps) ArrayLenBoundAt(cfg.Point, string) (string, bool) { retu func (m *mockFlowOps) ArrayLenBoundWithOffsetAt(cfg.Point, string) (string, int64, bool) { return "", 0, false } +func (m *mockFlowOps) LengthBoundsAt(cfg.Point, constraint.Path) (int64, int64, bool) { + return 0, 0, false +} func (m *mockFlowOps) IsPointDead(cfg.Point) bool { return false } func (m *mockFlowOps) HasKeyOf(cfg.Point, constraint.Path, constraint.Path) bool { return false } diff --git a/compiler/check/domain/functionfact/fact.go b/compiler/check/domain/functionfact/fact.go index 91804a8f..c50957c0 100644 --- a/compiler/check/domain/functionfact/fact.go +++ b/compiler/check/domain/functionfact/fact.go @@ -7,6 +7,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/domain/value" "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" + typjoin "github.com/wippyai/go-lua/types/typ/join" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -45,6 +46,7 @@ func Join(existing, candidate api.FunctionFact) api.FunctionFact { out.Type = MergeType(out.Type, candidate.Type) } + summaryBeforeNarrow := out.Summary if len(out.Narrow) > 0 { if len(out.Summary) == 0 { out.Summary = 
returnsummary.Canonical(out.Narrow) @@ -54,11 +56,29 @@ func Join(existing, candidate api.FunctionFact) api.FunctionFact { } if fn := unwrap.Function(out.Type); fn != nil { - alignedSummary := out.Summary - if len(alignedSummary) > 0 { - if aligned, changed := returnsummary.AlignFunction(fn, alignedSummary); changed { - out.Type = aligned - fn = aligned + alignedReturns := out.Summary + usingNarrow := len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) + if usingNarrow { + repairBase := summaryBeforeNarrow + if len(repairBase) == 0 { + repairBase = out.Summary + } + alignedReturns = repairSummaryWithNarrow(repairBase, out.Narrow) + } + if len(alignedReturns) > 0 { + if usingNarrow { + if aligned := typjoin.WithReturns(fn, alignedReturns); aligned != nil { + if typ.IsAny(aligned.Variadic) { + aligned = stripVariadic(aligned) + } + out.Type = aligned + fn = aligned + } + } else { + if aligned, changed := returnsummary.AlignFunction(fn, alignedReturns); changed { + out.Type = aligned + fn = aligned + } } } if len(out.Summary) == 0 && fn != nil && len(fn.Returns) > 0 { @@ -109,6 +129,7 @@ func WidenForConvergence(prev, next api.FunctionFact) api.FunctionFact { Type: WidenTypeForConvergence(prev.Type, next.Type), } + summaryBeforeNarrow := out.Summary // Narrow summaries can refine optional/non-nil returns, but a nil-only // narrow observation must not erase an already-informative summary. 
if len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) { @@ -120,9 +141,27 @@ func WidenForConvergence(prev, next api.FunctionFact) api.FunctionFact { } if fn := unwrap.Function(out.Type); fn != nil { - if len(out.Summary) > 0 { - if aligned, changed := returnsummary.AlignFunction(fn, out.Summary); changed { - out.Type = WidenTypeForConvergence(fn, aligned) + alignedReturns := out.Summary + usingNarrow := len(out.Narrow) > 0 && !returnsummary.AllNil(out.Narrow) + if usingNarrow { + repairBase := summaryBeforeNarrow + if len(repairBase) == 0 { + repairBase = out.Summary + } + alignedReturns = repairSummaryWithNarrow(repairBase, out.Narrow) + } + if len(alignedReturns) > 0 { + if usingNarrow { + if aligned := typjoin.WithReturns(fn, alignedReturns); aligned != nil { + if nextFn := unwrap.Function(next.Type); (nextFn != nil && nextFn.Variadic == nil) || typ.IsAny(aligned.Variadic) { + aligned = stripVariadic(aligned) + } + out.Type = value.WidenForConvergence(aligned) + } + } else { + if aligned, changed := returnsummary.AlignFunction(fn, alignedReturns); changed { + out.Type = WidenTypeForConvergence(fn, aligned) + } } } else if len(fn.Returns) > 0 { out.Summary = returnsummary.WidenForConvergence(nil, fn.Returns) @@ -132,6 +171,127 @@ func WidenForConvergence(prev, next api.FunctionFact) api.FunctionFact { return out } +func repairSummaryWithNarrow(summary, narrow []typ.Type) []typ.Type { + if len(narrow) == 0 { + return summary + } + if len(summary) != len(narrow) || len(summary) == 0 { + return narrow + } + out := make([]typ.Type, len(summary)) + for i := range summary { + out[i] = repairTypeWithNarrow(summary[i], narrow[i], 0) + } + return out +} + +func stripVariadic(fn *typ.Function) *typ.Function { + if fn == nil || fn.Variadic == nil { + return fn + } + builder := typ.Func().ReserveParams(len(fn.Params)) + for _, tp := range fn.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for _, p := range fn.Params { + if p.Optional { + 
builder = builder.OptParam(p.Name, p.Type) + } else { + builder = builder.Param(p.Name, p.Type) + } + } + if len(fn.Returns) > 0 { + builder = builder.Returns(fn.Returns...) + } + if fn.Effects != nil { + builder = builder.Effects(fn.Effects) + } + if fn.Spec != nil { + builder = builder.Spec(fn.Spec) + } + if fn.Refinement != nil { + builder = builder.WithRefinement(fn.Refinement) + } + return builder.Build() +} + +func repairTypeWithNarrow(summary, narrow typ.Type, depth int) typ.Type { + if summary == nil || narrow == nil || depth > typ.DefaultRecursionDepth { + return narrow + } + if typ.IsAny(summary) && !typ.IsAny(narrow) { + return narrow + } + summary = unwrap.Alias(summary) + narrow = unwrap.Alias(narrow) + switch s := summary.(type) { + case *typ.Union: + n, ok := narrow.(*typ.Union) + if !ok { + members := make([]typ.Type, len(s.Members)) + for i, member := range s.Members { + members[i] = repairTypeWithNarrow(member, narrow, depth+1) + } + return typ.NewUnion(members...) + } + if len(s.Members) != len(n.Members) { + return summary + } + members := make([]typ.Type, len(s.Members)) + for i, member := range s.Members { + members[i] = repairTypeWithNarrow(member, bestNarrowUnionMember(member, n.Members), depth+1) + } + return typ.NewUnion(members...) 
+ case *typ.Record: + n, ok := narrow.(*typ.Record) + if !ok { + return narrow + } + builder := typ.NewRecord().SetOpen(s.Open) + if s.HasMapComponent() { + mapValue := s.MapValue + if n.HasMapComponent() { + mapValue = repairTypeWithNarrow(s.MapValue, n.MapValue, depth+1) + } + builder.MapComponent(s.MapKey, mapValue) + } + if s.Metatable != nil { + builder.Metatable(s.Metatable) + } + for _, field := range s.Fields { + fieldType := field.Type + if nf := n.GetField(field.Name); nf != nil { + fieldType = repairTypeWithNarrow(field.Type, nf.Type, depth+1) + } + switch { + case field.Optional && field.Readonly: + builder.OptReadonlyField(field.Name, fieldType) + case field.Optional: + builder.OptField(field.Name, fieldType) + case field.Readonly: + builder.ReadonlyField(field.Name, fieldType) + default: + builder.Field(field.Name, fieldType) + } + } + return builder.Build() + default: + return narrow + } +} + +func bestNarrowUnionMember(summary typ.Type, members []typ.Type) typ.Type { + for _, member := range members { + if subtype.IsSubtype(member, summary) || subtype.IsSubtype(summary, member) { + return member + } + } + if len(members) > 0 { + return members[0] + } + return summary +} + // WidenTypeForConvergence merges function-type facts at a recursive fixpoint // boundary. 
func WidenTypeForConvergence(existing, candidate typ.Type) typ.Type { diff --git a/compiler/check/domain/paramevidence/parameter_evidence.go b/compiler/check/domain/paramevidence/parameter_evidence.go index 757d8e92..45811060 100644 --- a/compiler/check/domain/paramevidence/parameter_evidence.go +++ b/compiler/check/domain/paramevidence/parameter_evidence.go @@ -40,14 +40,20 @@ func MergeIntoSignature(fn *ast.FunctionExpr, evidence []typ.Type, sig *typ.Func builder := typ.Func() for i, p := range sig.Params { paramType := p.Type + optional := p.Optional if i < len(evidence) && evidence[i] != nil { srcIdx, hasSource := signatureSourceParamIndex(fn, sig, i) annotated := hasSource && srcIdx < len(fn.ParList.Types) && fn.ParList.Types[srcIdx] != nil - if !annotated || typ.IsRefinableAnnotation(paramType) { + if !annotated { paramType = evidence[i] + if !unwrap.IsOptionalLike(evidence[i]) { + optional = false + } + } else if typ.IsRefinableAnnotation(paramType) { + paramType = mergeEvidenceIntoAnnotatedParam(paramType, evidence[i]) } } - if p.Optional { + if optional { builder = builder.OptParam(p.Name, paramType) } else { builder = builder.Param(p.Name, paramType) @@ -71,6 +77,20 @@ func MergeIntoSignature(fn *ast.FunctionExpr, evidence []typ.Type, sig *typ.Func return builder.Build() } +func mergeEvidenceIntoAnnotatedParam(annotation, evidence typ.Type) typ.Type { + if annotation == nil || evidence == nil { + return annotation + } + if unwrap.IsOptionalLike(annotation) { + inner := unwrap.Optional(evidence) + if inner == nil || unwrap.IsNilType(unwrap.Alias(evidence)) { + return annotation + } + return typ.NewOptional(inner) + } + return evidence +} + func signatureSourceParamIndex(fn *ast.FunctionExpr, sig *typ.Function, paramIdx int) (int, bool) { if fn == nil || fn.ParList == nil || sig == nil || paramIdx < 0 || paramIdx >= len(sig.Params) { return 0, false diff --git a/compiler/check/flowbuild/numconst/numconst.go b/compiler/check/flowbuild/numconst/numconst.go 
index 670b8ccb..a58e105e 100644 --- a/compiler/check/flowbuild/numconst/numconst.go +++ b/compiler/check/flowbuild/numconst/numconst.go @@ -28,12 +28,20 @@ func NegateConstraints(items []constraint.Constraint) []constraint.Constraint { func NumericConstraintFromComparisonWithBindings(op string, lhs, rhs ast.Expr, p cfg.Point, inputs *flow.Inputs, bindings *bind.BindingTable) constraint.NumericConstraint { leftPath := path.FromExprWithBindings(lhs, nil, bindings) rightPath := path.FromExprWithBindings(rhs, nil, bindings) + leftLenPath := lenPathFromExprWithBindings(lhs, bindings) + rightLenPath := lenPathFromExprWithBindings(rhs, bindings) leftConst, leftIsConst := IntConstFromExpr(lhs) rightConst, rightIsConst := IntConstFromExpr(rhs) switch op { case "<": + if !leftLenPath.IsEmpty() && rightIsConst { + return constraint.LenLeConst{Array: leftLenPath, C: rightConst - 1} + } + if leftIsConst && !rightLenPath.IsEmpty() { + return constraint.LenGeConst{Array: rightLenPath, C: leftConst + 1} + } if !leftPath.IsEmpty() && !rightPath.IsEmpty() { return constraint.Lt{X: leftPath, Y: rightPath} } @@ -44,6 +52,12 @@ func NumericConstraintFromComparisonWithBindings(op string, lhs, rhs ast.Expr, p return constraint.GeConst{X: rightPath, C: leftConst + 1} } case ">": + if !leftLenPath.IsEmpty() && rightIsConst { + return constraint.LenGeConst{Array: leftLenPath, C: rightConst + 1} + } + if leftIsConst && !rightLenPath.IsEmpty() { + return constraint.LenLeConst{Array: rightLenPath, C: leftConst - 1} + } if !leftPath.IsEmpty() && !rightPath.IsEmpty() { return constraint.Gt{X: leftPath, Y: rightPath} } @@ -54,6 +68,12 @@ func NumericConstraintFromComparisonWithBindings(op string, lhs, rhs ast.Expr, p return constraint.LeConst{X: rightPath, C: leftConst - 1} } case "<=": + if !leftLenPath.IsEmpty() && rightIsConst { + return constraint.LenLeConst{Array: leftLenPath, C: rightConst} + } + if leftIsConst && !rightLenPath.IsEmpty() { + return constraint.LenGeConst{Array: rightLenPath, 
C: leftConst} + } if !leftPath.IsEmpty() && !rightPath.IsEmpty() { return constraint.Le{X: leftPath, Y: rightPath, C: 0} } @@ -64,6 +84,12 @@ func NumericConstraintFromComparisonWithBindings(op string, lhs, rhs ast.Expr, p return constraint.GeConst{X: rightPath, C: leftConst} } case ">=": + if !leftLenPath.IsEmpty() && rightIsConst { + return constraint.LenGeConst{Array: leftLenPath, C: rightConst} + } + if leftIsConst && !rightLenPath.IsEmpty() { + return constraint.LenLeConst{Array: rightLenPath, C: leftConst} + } if !leftPath.IsEmpty() && !rightPath.IsEmpty() { return constraint.Ge{X: leftPath, Y: rightPath} } @@ -77,6 +103,14 @@ func NumericConstraintFromComparisonWithBindings(op string, lhs, rhs ast.Expr, p return nil } +func lenPathFromExprWithBindings(expr ast.Expr, bindings *bind.BindingTable) constraint.Path { + lenOp, ok := expr.(*ast.UnaryLenOpExpr) + if !ok || lenOp == nil { + return constraint.Path{} + } + return path.FromExprWithBindings(lenOp.Expr, nil, bindings) +} + // NegateNumericConstraint returns the negation of a numeric constraint. 
func NegateNumericConstraint(c constraint.NumericConstraint) constraint.NumericConstraint { if c == nil { @@ -95,6 +129,10 @@ func NegateNumericConstraint(c constraint.NumericConstraint) constraint.NumericC return constraint.GeConst{X: v.X, C: v.C + 1} case constraint.GeConst: return constraint.LeConst{X: v.X, C: v.C - 1} + case constraint.LenLeConst: + return constraint.LenGeConst{Array: v.Array, C: v.C + 1} + case constraint.LenGeConst: + return constraint.LenLeConst{Array: v.Array, C: v.C - 1} default: return nil } diff --git a/compiler/check/flowbuild/numconst/numconst_test.go b/compiler/check/flowbuild/numconst/numconst_test.go index 493ac9b7..24277b18 100644 --- a/compiler/check/flowbuild/numconst/numconst_test.go +++ b/compiler/check/flowbuild/numconst/numconst_test.go @@ -206,6 +206,32 @@ func TestNumericConstraintFromComparisonWithBindings_GePaths(t *testing.T) { } } +func TestNumericConstraintFromComparisonWithBindings_LenLowerBound(t *testing.T) { + lhs := &ast.UnaryLenOpExpr{Expr: &ast.IdentExpr{Value: "rows"}} + rhs := &ast.NumberExpr{Value: "0"} + result := numconst.NumericConstraintFromComparisonWithBindings(">", lhs, rhs, 0, nil, nil) + got, ok := result.(constraint.LenGeConst) + if !ok { + t.Fatalf("expected LenGeConst constraint, got %T", result) + } + if got.Array.Root != "rows" || got.C != 1 { + t.Fatalf("unexpected length lower bound: %#v", got) + } +} + +func TestNumericConstraintFromComparisonWithBindings_LenUpperBound(t *testing.T) { + lhs := &ast.NumberExpr{Value: "0"} + rhs := &ast.UnaryLenOpExpr{Expr: &ast.IdentExpr{Value: "rows"}} + result := numconst.NumericConstraintFromComparisonWithBindings(">=", lhs, rhs, 0, nil, nil) + got, ok := result.(constraint.LenLeConst) + if !ok { + t.Fatalf("expected LenLeConst constraint, got %T", result) + } + if got.Array.Root != "rows" || got.C != 0 { + t.Fatalf("unexpected length upper bound: %#v", got) + } +} + func TestNumericConstraintFromComparisonWithBindings_UnknownOp(t *testing.T) { lhs := 
&ast.IdentExpr{Value: "a"} rhs := &ast.IdentExpr{Value: "b"} diff --git a/compiler/check/hooks/call_check.go b/compiler/check/hooks/call_check.go index eb703add..6dcb5b37 100644 --- a/compiler/check/hooks/call_check.go +++ b/compiler/check/hooks/call_check.go @@ -30,6 +30,7 @@ import ( "github.com/wippyai/go-lua/types/diag" "github.com/wippyai/go-lua/types/effect" "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -49,12 +50,13 @@ func CheckCalls( var diags []diag.Diagnostic query := narrowSynth.CallQuery() bindings := graph.Bindings() + unobservedLocalParams := make(map[cfg.SymbolID][]bool) graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { if info == nil { return } - callDiags := checkSingleCall(p, info, scopes, narrowView, narrowSynth, query, sourceName, graph, bindings) + callDiags := checkSingleCall(p, info, scopes, narrowView, narrowSynth, query, sourceName, graph, bindings, unobservedLocalParams) diags = append(diags, callDiags...) 
}) @@ -71,6 +73,7 @@ func checkSingleCall( sourceName string, graph *cfg.Graph, bindings *bind.BindingTable, + unobservedLocalParams map[cfg.SymbolID][]bool, ) []diag.Diagnostic { if info.Method == "" && info.Callee != nil { if t := narrowView.TypeOf(info.Callee, p); hasCallableTypeEffect(t) { @@ -105,6 +108,7 @@ func checkSingleCall( args[i] = narrowView.TypeOf(arg, p) } + ctx := narrowSynth.Context() def := ops.CallDef{ Args: args, Query: query, @@ -117,10 +121,15 @@ func checkSingleCall( def.ForceMethodReceiver = callsite.ForceMethodReceiver(bindings, graph, info) } else if info.Callee != nil { def.Callee = narrowView.TypeOf(info.Callee, p) + if factType, unobservedParams := functionFactCalleeType(api.StoreFrom(ctx), info, graph, bindings, unobservedLocalParams); factType != nil { + if typ.IsUnknownOrNil(def.Callee) || canonicalFactHasWiderParams(def.Callee, factType) { + def.Callee = factType + } else if len(unobservedParams) > 0 { + def.Callee = callTypeWithUnobservedLocalAnyArgs(def.Callee, args, unobservedParams) + } + } } - ctx := narrowSynth.Context() - pipeline := extract.NewCallPipeline(ctx, def, info.Args). 
WithReSynth(extract.FullArgReSynth( func(arg ast.Expr, pt cfg.Point, expected typ.Type) typ.Type { @@ -135,6 +144,106 @@ func checkSingleCall( return callErrorsToDiags(result.Errors, info, sourceName) } +func functionFactCalleeType( + store api.StoreReader, + info *cfg.CallInfo, + graph *cfg.Graph, + bindings *bind.BindingTable, + unobservedLocalParams map[cfg.SymbolID][]bool, +) (typ.Type, []bool) { + if store == nil || info == nil { + return nil, nil + } + moduleBindings := store.ModuleBindings() + for _, sym := range callsite.CallableCalleeSymbolCandidates(info, graph, bindings, moduleBindings) { + if ff, ok := api.FunctionFactSnapshotForSymbol(store, sym, nil); ok { + fn := callsite.FunctionLiteralForGraphSymbol(graph, sym) + graphLocal := fn != nil + t := ff.Type + var unobservedParams []bool + if graphLocal { + unobservedParams = unobservedLocalParamMask(store, sym, fn, unobservedLocalParams) + } + if t != nil { + return t, unobservedParams + } + } + } + return nil, nil +} + +func callTypeWithUnobservedLocalAnyArgs(callee typ.Type, args []typ.Type, unobservedParams []bool) typ.Type { + fn := unwrap.Function(callee) + if fn == nil || len(args) == 0 || len(fn.Params) == 0 || len(unobservedParams) == 0 { + return callee + } + changed := false + builder := typ.Func().ReserveParams(len(fn.Params)) + for _, tp := range fn.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for i, p := range fn.Params { + paramType := p.Type + if i < len(args) && i < len(unobservedParams) && unobservedParams[i] && typ.IsAny(args[i]) && !typ.IsAny(paramType) { + paramType = typ.Any + changed = true + } + if p.Optional { + builder = builder.OptParam(p.Name, paramType) + } else { + builder = builder.Param(p.Name, paramType) + } + } + if !changed { + return callee + } + if fn.Variadic != nil { + builder = builder.Variadic(fn.Variadic) + } + if len(fn.Returns) > 0 { + builder = builder.Returns(fn.Returns...) 
+ } + if fn.Effects != nil { + builder = builder.Effects(fn.Effects) + } + if fn.Spec != nil { + builder = builder.Spec(fn.Spec) + } + if fn.Refinement != nil { + builder = builder.WithRefinement(fn.Refinement) + } + return builder.Build() +} + +func canonicalFactHasWiderParams(current, fact typ.Type) bool { + currentFn := unwrap.Function(current) + factFn := unwrap.Function(fact) + if currentFn == nil || factFn == nil || len(currentFn.Params) != len(factFn.Params) { + return false + } + wider := false + for i, currentParam := range currentFn.Params { + factParam := factFn.Params[i] + if currentParam.Optional != factParam.Optional { + if currentParam.Optional && !factParam.Optional { + return false + } + wider = true + } + if typ.TypeEquals(currentParam.Type, factParam.Type) { + continue + } + if typ.IsAny(factParam.Type) || typ.IsAny(unwrap.Optional(factParam.Type)) { + wider = true + continue + } + if subtype.IsSubtype(currentParam.Type, factParam.Type) { + wider = true + } + } + return wider +} + func hasCallableTypeEffect(t typ.Type) bool { fn := unwrap.Function(t) if fn == nil { diff --git a/compiler/check/hooks/local_param_use.go b/compiler/check/hooks/local_param_use.go new file mode 100644 index 00000000..32554605 --- /dev/null +++ b/compiler/check/hooks/local_param_use.go @@ -0,0 +1,196 @@ +package hooks + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" +) + +// unobservedLocalParamMask returns true for graph-local parameter slots whose +// value is never observed by the function body. Call diagnostics do not need to +// reject an explicit any argument for such a slot because no transfer can depend +// on the stricter annotation at runtime. 
+func unobservedLocalParamMask( + store api.StoreReader, + sym cfg.SymbolID, + fn *ast.FunctionExpr, + cache map[cfg.SymbolID][]bool, +) []bool { + if sym == 0 || fn == nil { + return nil + } + if cache != nil { + if mask, ok := cache[sym]; ok { + return mask + } + } + fnGraph := graphForFunctionSymbol(store, sym) + mask := computeUnobservedLocalParamMask(fn, fnGraph) + if cache != nil { + cache[sym] = mask + } + return mask +} + +func graphForFunctionSymbol(store api.StoreReader, sym cfg.SymbolID) *cfg.Graph { + if store == nil || sym == 0 { + return nil + } + ref := store.FunctionRefBySym(sym) + if ref == nil || ref.GraphID == 0 { + return nil + } + graphs := store.Graphs() + if graphs == nil { + return nil + } + return graphs[ref.GraphID] +} + +func computeUnobservedLocalParamMask(fn *ast.FunctionExpr, graph *cfg.Graph) []bool { + if fn == nil || graph == nil { + return nil + } + slots := graph.ParamSlotsReadOnly() + if len(slots) == 0 { + return nil + } + bindings := graph.Bindings() + if bindings == nil { + return nil + } + + paramIndex := make(map[cfg.SymbolID]int, len(slots)) + for i, slot := range slots { + if slot.Symbol != 0 { + paramIndex[slot.Symbol] = i + } + } + if len(paramIndex) == 0 { + return nil + } + + used := make([]bool, len(slots)) + markParamUsesInStmts(fn.Stmts, bindings, paramIndex, used) + + var mask []bool + for i, slot := range slots { + if slot.Symbol == 0 || used[i] { + continue + } + if mask == nil { + mask = make([]bool, len(slots)) + } + mask[i] = true + } + return mask +} + +func markParamUsesInStmts(stmts []ast.Stmt, bindings *bind.BindingTable, paramIndex map[cfg.SymbolID]int, used []bool) { + for _, stmt := range stmts { + markParamUsesInStmt(stmt, bindings, paramIndex, used) + } +} + +func markParamUsesInStmt(stmt ast.Stmt, bindings *bind.BindingTable, paramIndex map[cfg.SymbolID]int, used []bool) { + switch s := stmt.(type) { + case *ast.AssignStmt: + markParamUsesInExprs(s.Lhs, bindings, paramIndex, used) + 
markParamUsesInExprs(s.Rhs, bindings, paramIndex, used) + case *ast.LocalAssignStmt: + markParamUsesInExprs(s.Exprs, bindings, paramIndex, used) + case *ast.FuncCallStmt: + markParamUsesInExpr(s.Expr, bindings, paramIndex, used) + case *ast.DoBlockStmt: + markParamUsesInStmts(s.Stmts, bindings, paramIndex, used) + case *ast.WhileStmt: + markParamUsesInExpr(s.Condition, bindings, paramIndex, used) + markParamUsesInStmts(s.Stmts, bindings, paramIndex, used) + case *ast.RepeatStmt: + markParamUsesInStmts(s.Stmts, bindings, paramIndex, used) + markParamUsesInExpr(s.Condition, bindings, paramIndex, used) + case *ast.IfStmt: + markParamUsesInExpr(s.Condition, bindings, paramIndex, used) + markParamUsesInStmts(s.Then, bindings, paramIndex, used) + markParamUsesInStmts(s.Else, bindings, paramIndex, used) + case *ast.NumberForStmt: + markParamUsesInExpr(s.Init, bindings, paramIndex, used) + markParamUsesInExpr(s.Limit, bindings, paramIndex, used) + markParamUsesInExpr(s.Step, bindings, paramIndex, used) + markParamUsesInStmts(s.Stmts, bindings, paramIndex, used) + case *ast.GenericForStmt: + markParamUsesInExprs(s.Exprs, bindings, paramIndex, used) + markParamUsesInStmts(s.Stmts, bindings, paramIndex, used) + case *ast.FuncDefStmt: + if s.Name != nil { + markParamUsesInExpr(s.Name.Func, bindings, paramIndex, used) + markParamUsesInExpr(s.Name.Receiver, bindings, paramIndex, used) + } + if s.Func != nil { + markParamUsesInExpr(s.Func, bindings, paramIndex, used) + } + case *ast.ReturnStmt: + markParamUsesInExprs(s.Exprs, bindings, paramIndex, used) + } +} + +func markParamUsesInExprs(exprs []ast.Expr, bindings *bind.BindingTable, paramIndex map[cfg.SymbolID]int, used []bool) { + for _, expr := range exprs { + markParamUsesInExpr(expr, bindings, paramIndex, used) + } +} + +func markParamUsesInExpr(expr ast.Expr, bindings *bind.BindingTable, paramIndex map[cfg.SymbolID]int, used []bool) { + switch e := expr.(type) { + case nil: + return + case *ast.IdentExpr: + if sym, ok := 
bindings.SymbolOf(e); ok { + if idx, ok := paramIndex[sym]; ok && idx >= 0 && idx < len(used) { + used[idx] = true + } + } + case *ast.AttrGetExpr: + markParamUsesInExpr(e.Object, bindings, paramIndex, used) + markParamUsesInExpr(e.Key, bindings, paramIndex, used) + case *ast.TableExpr: + for _, field := range e.Fields { + if field == nil { + continue + } + markParamUsesInExpr(field.Key, bindings, paramIndex, used) + markParamUsesInExpr(field.Value, bindings, paramIndex, used) + } + case *ast.FuncCallExpr: + markParamUsesInExpr(e.Func, bindings, paramIndex, used) + markParamUsesInExpr(e.Receiver, bindings, paramIndex, used) + markParamUsesInExprs(e.Args, bindings, paramIndex, used) + case *ast.LogicalOpExpr: + markParamUsesInExpr(e.Lhs, bindings, paramIndex, used) + markParamUsesInExpr(e.Rhs, bindings, paramIndex, used) + case *ast.RelationalOpExpr: + markParamUsesInExpr(e.Lhs, bindings, paramIndex, used) + markParamUsesInExpr(e.Rhs, bindings, paramIndex, used) + case *ast.StringConcatOpExpr: + markParamUsesInExpr(e.Lhs, bindings, paramIndex, used) + markParamUsesInExpr(e.Rhs, bindings, paramIndex, used) + case *ast.ArithmeticOpExpr: + markParamUsesInExpr(e.Lhs, bindings, paramIndex, used) + markParamUsesInExpr(e.Rhs, bindings, paramIndex, used) + case *ast.UnaryMinusOpExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + case *ast.UnaryNotOpExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + case *ast.UnaryLenOpExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + case *ast.UnaryBNotOpExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + case *ast.FunctionExpr: + markParamUsesInStmts(e.Stmts, bindings, paramIndex, used) + case *ast.CastExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + case *ast.NonNilAssertExpr: + markParamUsesInExpr(e.Expr, bindings, paramIndex, used) + } +} diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index 14220351..23904668 
100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -15,6 +15,7 @@ import ( "github.com/wippyai/go-lua/compiler/check/synth/ops" "github.com/wippyai/go-lua/types/flow" "github.com/wippyai/go-lua/types/typ" + typjoin "github.com/wippyai/go-lua/types/typ/join" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -76,11 +77,17 @@ func StoreFactsFromResult( summaryFromSnapshot := returnSummarySnapshotForSymbol(store, result, parent, fnSym) candidateFunc := fnType + if len(narrowSummary) > 0 && !returnsummary.AllNil(narrowSummary) { + if aligned := typjoin.WithReturns(candidateFunc, narrowSummary); aligned != nil { + candidateFunc = aligned + } + } if facts := store.GetFunctionFactsSnapshot(result.Graph, parent); len(facts) > 0 { if hinted := paramevidence.MergeIntoSignature(fn, facts.Params(fnSym), unwrap.Function(candidateFunc)); hinted != nil { candidateFunc = hinted } } + candidateFunc = stripSyntheticVariadic(fn, unwrap.Function(candidateFunc)) delta := api.Facts{FunctionFacts: api.FunctionFacts{ fnSym: functionfact.Join(api.FunctionFact{}, api.FunctionFact{ Summary: summaryFromSnapshot, @@ -129,6 +136,36 @@ func storeCapturedFactsFromResult( } } +func stripSyntheticVariadic(fn *ast.FunctionExpr, sig *typ.Function) *typ.Function { + if fn == nil || fn.ParList == nil || fn.ParList.HasVargs || sig == nil || sig.Variadic == nil { + return sig + } + builder := typ.Func().ReserveParams(len(sig.Params)) + for _, tp := range sig.TypeParams { + builder = builder.TypeParam(tp.Name, tp.Constraint) + } + for _, p := range sig.Params { + if p.Optional { + builder = builder.OptParam(p.Name, p.Type) + } else { + builder = builder.Param(p.Name, p.Type) + } + } + if len(sig.Returns) > 0 { + builder = builder.Returns(sig.Returns...) 
+ } + if sig.Effects != nil { + builder = builder.Effects(sig.Effects) + } + if sig.Spec != nil { + builder = builder.Spec(sig.Spec) + } + if sig.Refinement != nil { + builder = builder.WithRefinement(sig.Refinement) + } + return builder.Build() +} + func bindingsForGraphOrModule(graph *cfg.Graph, store Store) *bind.BindingTable { if graph == nil { return nil diff --git a/compiler/check/phase/types_test.go b/compiler/check/phase/types_test.go index f4081170..a0d1c0f7 100644 --- a/compiler/check/phase/types_test.go +++ b/compiler/check/phase/types_test.go @@ -548,7 +548,7 @@ func TestMergeParameterEvidenceIntoSig_PreservesReturns(t *testing.T) { } } -func TestMergeParameterEvidenceIntoSig_PreservesOptionalParam(t *testing.T) { +func TestMergeParameterEvidenceIntoSig_ClearsSyntheticOptionalParam(t *testing.T) { fn := &ast.FunctionExpr{ ParList: &ast.ParList{ Names: []string{"x"}, @@ -565,7 +565,7 @@ func TestMergeParameterEvidenceIntoSig_PreservesOptionalParam(t *testing.T) { if len(result.Params) != 1 { t.Fatalf("expected 1 param, got %d", len(result.Params)) } - if !result.Params[0].Optional { - t.Error("expected param to remain optional") + if result.Params[0].Optional { + t.Error("expected non-nil evidence to clear synthetic optionality") } } diff --git a/compiler/check/synth/engine_test.go b/compiler/check/synth/engine_test.go index 02afd661..89b33199 100644 --- a/compiler/check/synth/engine_test.go +++ b/compiler/check/synth/engine_test.go @@ -87,6 +87,10 @@ func (m mockFlowOps) ArrayLenBoundWithOffsetAt(p cfg.Point, varName string) (arr return "", 0, false } +func (m mockFlowOps) LengthBoundsAt(p cfg.Point, path constraint.Path) (lower, upper int64, ok bool) { + return 0, 0, false +} + func (m mockFlowOps) IsPointDead(p cfg.Point) bool { return false } diff --git a/compiler/check/synth/intercept/chain.go b/compiler/check/synth/intercept/chain.go index e6651b5d..df84fd87 100644 --- a/compiler/check/synth/intercept/chain.go +++ 
b/compiler/check/synth/intercept/chain.go @@ -43,6 +43,7 @@ func (b *ChainBuilder) Build() *Chain { callIntercepts := []CallIntercept{ &SelectIntercept{VariadicResolver: b.variadicResolver}, &RequireIntercept{Manifests: b.manifests}, + &SetMetatableIntercept{}, &TypeCastIntercept{}, } diff --git a/compiler/check/synth/intercept/intercept.go b/compiler/check/synth/intercept/intercept.go index b830f467..25c6d296 100644 --- a/compiler/check/synth/intercept/intercept.go +++ b/compiler/check/synth/intercept/intercept.go @@ -63,6 +63,10 @@ type CallEnv struct { // For type names (Number, Point): returns synthetic callable function type. // Returns nil if the name is not a recognized function or type. TypeLookup func(name string) typ.Type + + // StableType resolves an expression to its graph-stable value shape when + // point-local synthesis is too early to see later field assignments. + StableType func(expr ast.Expr, current typ.Type) typ.Type } // CallIntercept handles AST-specific patterns in direct function calls. diff --git a/compiler/check/synth/intercept/setmetatable.go b/compiler/check/synth/intercept/setmetatable.go new file mode 100644 index 00000000..b931f478 --- /dev/null +++ b/compiler/check/synth/intercept/setmetatable.go @@ -0,0 +1,128 @@ +package intercept + +import ( + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/types/kind" + "github.com/wippyai/go-lua/types/typ" + "github.com/wippyai/go-lua/types/typ/unwrap" +) + +// SetMetatableIntercept models Lua's setmetatable(table, metatable) primitive. +// +// The normal stdlib signature can express the value identity of the first +// argument, but the abstract state also needs the metatable edge on the returned +// table value so method/field queries see the prototype chain. 
+type SetMetatableIntercept struct{} + +func (s *SetMetatableIntercept) InterceptCall(ex *ast.FuncCallExpr, ctx CallEnv) Result { + if ex == nil || len(ex.Args) < 2 || ctx.Recurse == nil { + return Result{} + } + ident, ok := ex.Func.(*ast.IdentExpr) + if !ok || ident.Value != "setmetatable" { + return Result{} + } + + tableType := ctx.Recurse(ex.Args[0]) + metaType := ctx.Recurse(ex.Args[1]) + if ctx.StableType != nil { + metaType = ctx.StableType(ex.Args[1], metaType) + } + if tableType == nil { + return Result{Skip: true, Types: []typ.Type{typ.Unknown}} + } + + return Result{Skip: true, Types: []typ.Type{withMetatable(tableType, metaType)}} +} + +func withMetatable(tableType, metaType typ.Type) typ.Type { + tableType = unwrap.Alias(tableType) + if tableType == nil { + return typ.Unknown + } + + switch t := tableType.(type) { + case *typ.Record: + return recordWithMetatableVariants(t, metaType) + case *typ.Optional: + return typ.NewOptional(withMetatable(t.Inner, metaType)) + case *typ.Union: + members := make([]typ.Type, 0, len(t.Members)) + for _, member := range t.Members { + if member == nil || member.Kind() == kind.Nil { + members = append(members, member) + continue + } + members = append(members, withMetatable(member, metaType)) + } + return typ.NewUnion(members...) + default: + return tableType + } +} + +func recordWithMetatableVariants(rec *typ.Record, metaType typ.Type) typ.Type { + var variants []typ.Type + for _, meta := range metatableVariants(metaType) { + variants = append(variants, rebuildRecordWithMetatable(rec, meta)) + } + if len(variants) == 0 { + return rebuildRecordWithMetatable(rec, nil) + } + if len(variants) == 1 { + return variants[0] + } + return typ.NewUnion(variants...) 
+} + +func metatableVariants(metaType typ.Type) []typ.Type { + metaType = unwrap.Alias(metaType) + if metaType == nil { + return []typ.Type{nil} + } + switch m := metaType.(type) { + case *typ.Optional: + return []typ.Type{nil, unwrap.Alias(m.Inner)} + case *typ.Union: + var variants []typ.Type + hasNil := false + for _, member := range m.Members { + member = unwrap.Alias(member) + if member == nil || member.Kind() == kind.Nil { + if !hasNil { + variants = append(variants, nil) + hasNil = true + } + continue + } + variants = append(variants, member) + } + return variants + default: + if metaType.Kind() == kind.Nil { + return []typ.Type{nil} + } + return []typ.Type{metaType} + } +} + +func rebuildRecordWithMetatable(rec *typ.Record, meta typ.Type) typ.Type { + if rec == nil { + return typ.Unknown + } + builder := typ.NewRecord() + for _, field := range rec.Fields { + if field.Optional { + builder.OptField(field.Name, field.Type) + } else { + builder.Field(field.Name, field.Type) + } + } + if rec.HasMapComponent() { + builder.MapComponent(rec.MapKey, rec.MapValue) + } + if meta != nil { + builder.Metatable(meta) + } + return builder.SetOpen(rec.Open).Build() +} diff --git a/compiler/check/synth/intercept/setmetatable_test.go b/compiler/check/synth/intercept/setmetatable_test.go new file mode 100644 index 00000000..fef2c583 --- /dev/null +++ b/compiler/check/synth/intercept/setmetatable_test.go @@ -0,0 +1,93 @@ +package intercept + +import ( + "testing" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/types/kind" + querycore "github.com/wippyai/go-lua/types/query/core" + "github.com/wippyai/go-lua/types/typ" +) + +func TestSetMetatableIntercept_AttachesMetatableToReturnedRecord(t *testing.T) { + table := typ.NewRecord().Field("nodes", typ.NewMap(typ.String, typ.Any)).Build() + method := typ.Func().Param("self", typ.Any).Returns(typ.Boolean).Build() + prototype := typ.NewRecord().Field("has_cycles", method).Build() + meta := 
typ.NewRecord().Field("__index", prototype).Build() + + ex := &ast.FuncCallExpr{ + Func: &ast.IdentExpr{Value: "setmetatable"}, + Args: []ast.Expr{ + &ast.IdentExpr{Value: "table"}, + &ast.IdentExpr{Value: "meta"}, + }, + } + result := (&SetMetatableIntercept{}).InterceptCall(ex, CallEnv{ + Recurse: func(expr ast.Expr) typ.Type { + if ident, ok := expr.(*ast.IdentExpr); ok && ident.Value == "meta" { + return meta + } + return table + }, + }) + + if !result.Skip || len(result.Types) != 1 { + t.Fatalf("expected intercepted single return, got %#v", result) + } + if _, ok := querycore.Method(result.Types[0], "has_cycles"); !ok { + t.Fatalf("expected returned record to expose metatable method, got %s", typ.FormatShort(result.Types[0])) + } +} + +func TestSetMetatableIntercept_OptionalMetatableKeepsNilVariantSound(t *testing.T) { + table := typ.NewRecord().Field("nodes", typ.NewMap(typ.String, typ.Any)).Build() + method := typ.Func().Param("self", typ.Any).Returns(typ.Boolean).Build() + prototype := typ.NewRecord().Field("has_cycles", method).Build() + meta := typ.NewRecord().Field("__index", prototype).Build() + + got := withMetatable(table, typ.NewOptional(meta)) + union, ok := got.(*typ.Union) + if !ok || len(union.Members) != 2 { + t.Fatalf("expected optional metatable to produce two variants, got %s", typ.FormatShort(got)) + } + + hasPlain := false + hasMeta := false + for _, member := range union.Members { + rec, ok := member.(*typ.Record) + if !ok { + t.Fatalf("expected record variants, got %T", member) + } + if rec.Metatable == nil { + hasPlain = true + continue + } + if _, ok := querycore.Method(rec, "has_cycles"); ok { + hasMeta = true + } + } + if !hasPlain || !hasMeta { + t.Fatalf("expected plain and metatabled variants, got %s", typ.FormatShort(got)) + } + if _, ok := querycore.Method(got, "has_cycles"); ok { + t.Fatal("optional metatable must not prove method exists on all variants") + } +} + +func TestSetMetatableIntercept_RemovesMetatableForNil(t 
*testing.T) { + method := typ.Func().Param("self", typ.Any).Returns(typ.Boolean).Build() + meta := typ.NewRecord().Field("has_cycles", method).Build() + table := typ.NewRecord().Metatable(meta).Build() + + got := withMetatable(table, typ.Nil) + rec, ok := got.(*typ.Record) + if !ok { + t.Fatalf("expected record, got %T", got) + } + if rec.Metatable != nil { + t.Fatalf("expected nil metatable to remove metatable, got %s", typ.FormatShort(rec.Metatable)) + } + if got.Kind() == kind.Never { + t.Fatal("setmetatable nil removal should not produce never") + } +} diff --git a/compiler/check/synth/ops/call.go b/compiler/check/synth/ops/call.go index c2f7451b..b0941fc5 100644 --- a/compiler/check/synth/ops/call.go +++ b/compiler/check/synth/ops/call.go @@ -583,7 +583,7 @@ func computeExpectedArgs(ctx *db.QueryContext, query core.TypeOps, fn *typ.Funct for i := 0; i < numArgs; i++ { paramIdx := i + paramOffset if paramIdx < len(fn.Params) { - expected[i] = fn.Params[paramIdx].Type + expected[i] = paramRuntimeType(fn.Params[paramIdx]) if isMethod && receiver != nil { expected[i] = subst.Self(expected[i], receiver) } @@ -873,7 +873,7 @@ func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, ar if methodHasReceiver { var expectedReceiver typ.Type if len(fn.Params) > 0 { - expectedReceiver = fn.Params[0].Type + expectedReceiver = paramRuntimeType(fn.Params[0]) } else if hasVariadic { expectedReceiver = fn.Variadic } @@ -894,7 +894,7 @@ func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, ar var expectedType typ.Type if paramIdx < len(fn.Params) { - expectedType = fn.Params[paramIdx].Type + expectedType = paramRuntimeType(fn.Params[paramIdx]) } else if hasVariadic { expectedType = fn.Variadic } else { @@ -929,6 +929,13 @@ func callFunction(ctx *db.QueryContext, query core.TypeOps, fn *typ.Function, ar return callResultFromReturns(returns, errors) } +func paramRuntimeType(param typ.Param) typ.Type { + if param.Type == nil || 
!param.Optional || unwrap.IsOptionalLike(param.Type) { + return param.Type + } + return typ.NewOptional(param.Type) +} + func normalizedCallReturns(result CallResult) []typ.Type { if len(result.Returns) > 0 { return copyTypeSlice(result.Returns) diff --git a/compiler/check/synth/phase/extract/call.go b/compiler/check/synth/phase/extract/call.go index e5a4a904..f2a448e5 100644 --- a/compiler/check/synth/phase/extract/call.go +++ b/compiler/check/synth/phase/extract/call.go @@ -1,7 +1,10 @@ package extract import ( + "sort" + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" compcfg "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/compiler/check/callsite" @@ -141,6 +144,9 @@ func (s *Synthesizer) synthCallCoreWithCaptureTypes( Scope: sc, Recurse: intercept.ExprSynth(recurse), TypeLookup: s.declaredTypeLookup(sc), + StableType: func(expr ast.Expr, current typ.Type) typ.Type { + return s.stablePrototypeType(expr, p, sc, current, recurse) + }, } chain := s.buildInterceptChain(sc) @@ -204,17 +210,40 @@ func (s *Synthesizer) specializedLocalFunctionCalleeType( return nil } info := graph.CallSiteAt(p, ex) - if info == nil { - return nil + candidates := callsite.CallableCalleeSymbolCandidates(info, graph, bindings, nil) + if len(candidates) == 0 { + if sym := callsite.SymbolFromExpr(ex.Func, bindings); sym != 0 { + candidates = append(candidates, sym) + } + if s.deps.ModuleBindings != nil && s.deps.ModuleBindings != bindings { + if sym := callsite.SymbolFromExpr(ex.Func, s.deps.ModuleBindings); sym != 0 { + candidates = append(candidates, sym) + } + } } - for _, sym := range callsite.CallableCalleeSymbolCandidates(info, graph, bindings, nil) { + for _, sym := range candidates { fn := callsite.FunctionLiteralForGraphSymbol(graph, sym) - if fn == nil { - continue + if fn != nil && !s.hasDominatingDirectFunctionRebind(sym, fn, p) { + factType := s.stableFunctionFactType(sym) + 
hasCallPointCaptureMutation := hasNonGlobalFunctionCaptures(bindings, fn) && s.hasDominatingCapturedMutation(fn, p) + if factType != nil && !hasCallPointCaptureMutation { + return factType + } + expectedFn, _ := unwrap.Optional(unwrap.Alias(current)).(*typ.Function) + if expectedFn == nil { + expectedFn, _ = unwrap.Optional(unwrap.Alias(factType)).(*typ.Function) + } + if fnType := s.synthFunctionTypeWithCapturePoint(fn, sc, expectedFn, p, captureTypes); fnType != nil { + return fnType + } + if factType != nil { + return factType + } } - expectedFn, _ := unwrap.Optional(unwrap.Alias(current)).(*typ.Function) - if fnType := s.synthFunctionTypeWithCapturePoint(fn, sc, expectedFn, p, captureTypes); fnType != nil { - return fnType + if typ.IsUnknownOrNil(current) { + if t := s.stableFunctionFactType(sym); t != nil { + return t + } } } return nil @@ -231,6 +260,9 @@ func (s *Synthesizer) synthMethodCallCoreWithExpected(ex *ast.FuncCallExpr, p cf Scope: sc, Recurse: intercept.ExprSynth(recurse), TypeLookup: s.declaredTypeLookup(sc), + StableType: func(expr ast.Expr, current typ.Type) typ.Type { + return s.stablePrototypeType(expr, p, sc, current, recurse) + }, } chain := s.buildInterceptChain(sc) @@ -273,6 +305,9 @@ func (s *Synthesizer) SynthCallWithReceiverType(ex *ast.FuncCallExpr, p cfg.Poin Scope: sc, Recurse: intercept.ExprSynth(recurse), TypeLookup: s.declaredTypeLookup(sc), + StableType: func(expr ast.Expr, current typ.Type) typ.Type { + return s.stablePrototypeType(expr, p, sc, current, recurse) + }, } chain := s.buildInterceptChain(sc) @@ -329,6 +364,194 @@ func (s *Synthesizer) declaredTypeLookup(sc *scope.State) func(string) typ.Type } } +func (s *Synthesizer) stablePrototypeType(expr ast.Expr, p cfg.Point, sc *scope.State, current typ.Type, recurse ExprSynth) typ.Type { + if s == nil || expr == nil || s.deps.CheckCtx == nil { + return current + } + ident, ok := expr.(*ast.IdentExpr) + if !ok || ident.Value == "" { + return current + } + graph, ok := 
s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return current + } + bindings := graph.Bindings() + if bindings == nil { + bindings = s.deps.ModuleBindings + } + if bindings == nil { + return current + } + sym, ok := bindings.SymbolOf(ident) + if !ok || sym == 0 { + return current + } + + fields := s.stablePrototypeFields(graph, sym, sc, recurse) + if len(fields) == 0 { + return current + } + + var base *typ.Record + if rec := unwrap.Record(current); rec != nil && !typ.IsUnknown(rec.Metatable) { + base = rec + } + builder := typ.NewRecord() + if base != nil { + for _, field := range base.Fields { + fields[field.Name] = typ.JoinPreferNonSoft(field.Type, fields[field.Name]) + } + if base.Metatable != nil { + builder.Metatable(base.Metatable) + } + if base.HasMapComponent() { + builder.MapComponent(base.MapKey, base.MapValue) + } + builder.SetOpen(base.Open) + } + + names := make([]string, 0, len(fields)) + for name := range fields { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + t := fields[name] + if t == nil { + t = typ.Unknown + } + builder.Field(name, t) + } + return builder.Build() +} + +func (s *Synthesizer) stablePrototypeFields(graph *compcfg.Graph, sym compcfg.SymbolID, sc *scope.State, recurse ExprSynth) map[string]typ.Type { + if graph == nil || sym == 0 { + return nil + } + bindings := graph.Bindings() + functionFacts := s.currentFunctionFacts() + var fields map[string]typ.Type + addField := func(name string, t typ.Type) { + if name == "" { + return + } + if t == nil { + t = typ.Unknown + } + if fields == nil { + fields = make(map[string]typ.Type) + } + if existing := fields[name]; existing != nil { + fields[name] = typ.JoinPreferNonSoft(existing, t) + } else { + fields[name] = t + } + } + graph.EachAssign(func(p compcfg.Point, info *compcfg.AssignInfo) { + if info == nil { + return + } + sources := info.Sources + for i, target := range info.Targets { + fieldName := 
stablePrototypeFieldName(target, sym) + if fieldName == "" { + continue + } + var source ast.Expr + if i < len(sources) { + source = sources[i] + } + addField(fieldName, s.stablePrototypeFieldType(source, p, sc, bindings, functionFacts, recurse)) + } + }) + graph.EachFuncDef(func(p compcfg.Point, info *compcfg.FuncDefInfo) { + fieldName := stablePrototypeFuncDefFieldName(info, sym) + if fieldName == "" { + return + } + addField(fieldName, s.stablePrototypeFuncDefType(info, p, sc, bindings, functionFacts, recurse)) + }) + for _, field := range bindings.DirectFieldSymbols(sym) { + if field.Symbol == 0 { + continue + } + if t := functionFacts.FunctionType(field.Symbol); t != nil { + addField(field.Name, t) + continue + } + if fn, ok := bindings.FuncLitBySymbol(field.Symbol); ok { + addField(field.Name, s.stablePrototypeFieldType(fn, graph.Entry(), sc, bindings, functionFacts, recurse)) + } + } + return fields +} + +func stablePrototypeFieldName(target compcfg.AssignTarget, sym compcfg.SymbolID) string { + if target.BaseSymbol != sym { + return "" + } + switch target.Kind { + case compcfg.TargetField: + if len(target.FieldPath) == 1 { + return target.FieldPath[0] + } + case compcfg.TargetIndex: + if key, ok := target.Key.(*ast.StringExpr); ok { + return key.Value + } + } + return "" +} + +func stablePrototypeFuncDefFieldName(info *compcfg.FuncDefInfo, sym compcfg.SymbolID) string { + if info == nil || info.ReceiverSymbol != sym || info.Name == "" { + return "" + } + switch info.TargetKind { + case compcfg.FuncDefField, compcfg.FuncDefMethod: + return info.Name + default: + return "" + } +} + +func (s *Synthesizer) stablePrototypeFuncDefType(info *compcfg.FuncDefInfo, p compcfg.Point, sc *scope.State, bindings *bind.BindingTable, functionFacts api.FunctionFacts, recurse ExprSynth) typ.Type { + if info == nil { + return nil + } + if info.Symbol != 0 { + if t := functionFacts.FunctionType(info.Symbol); t != nil { + return t + } + } + return 
s.stablePrototypeFieldType(info.FuncExpr, p, sc, bindings, functionFacts, recurse) +} + +func (s *Synthesizer) stablePrototypeFieldType(source ast.Expr, p compcfg.Point, sc *scope.State, bindings *bind.BindingTable, functionFacts api.FunctionFacts, recurse ExprSynth) typ.Type { + if source == nil { + return nil + } + if fn, ok := source.(*ast.FunctionExpr); ok && bindings != nil { + if sym, ok := bindings.FuncLitSymbol(fn); ok && sym != 0 { + if t := functionFacts.FunctionType(sym); t != nil { + return t + } + } + if s != nil { + expected := typ.Func().Param("self", typ.Self).Build() + if t := s.SynthFunctionTypeWithExpected(fn, sc, expected); t != nil { + return t + } + } + } + if recurse != nil { + return recurse(source) + } + return nil +} + // buildInterceptChain creates the intercept chain for call synthesis. func (s *Synthesizer) buildInterceptChain(sc *scope.State) *intercept.Chain { builder := intercept.NewChainBuilder() diff --git a/compiler/check/synth/phase/extract/expr.go b/compiler/check/synth/phase/extract/expr.go index ea4e7dbf..33603340 100644 --- a/compiler/check/synth/phase/extract/expr.go +++ b/compiler/check/synth/phase/extract/expr.go @@ -78,6 +78,13 @@ func (n *localNarrowOps) ArrayLenBoundWithOffsetAt(p cfg.Point, varName string) return "", 0, false } +func (n *localNarrowOps) LengthBoundsAt(p cfg.Point, path constraint.Path) (int64, int64, bool) { + if n.inner != nil { + return n.inner.LengthBoundsAt(p, path) + } + return 0, 0, false +} + func (n *localNarrowOps) IsPointDead(p cfg.Point) bool { if n.inner != nil { return n.inner.IsPointDead(p) @@ -101,6 +108,9 @@ func (s *Synthesizer) synthAttrGetCore(ex *ast.AttrGetExpr, p cfg.Point, sc *sco if !path.IsEmpty() { narrowed := narrower.NarrowedTypeAt(p, path) if narrowed != nil { + if key, ok := ex.Key.(*ast.NumberExpr); ok && nilPresenceIsOnlyFlowUncertainty(narrowed) && s.literalLengthBoundProvesIndex(objType, ex.Object, key.Value, p, sc, narrower) { + goto skipNarrowedAttr + } if 
specialized := s.stableLocalFunctionValueType(ex, p, sc, narrowed, nil); specialized != nil { return specialized } @@ -161,6 +171,11 @@ skipNarrowedAttr: case *ast.NumberExpr: keyType := ops.ParseNumber(key.Value) if it, ok := s.deps.Types.Index(s.deps.Ctx, objType, keyType); ok { + if narrower != nil { + if narrowedResult := s.narrowArrayIndexByLiteralLenBound(objType, it, ex.Object, key.Value, p, sc, narrower); narrowedResult != nil { + return narrowedResult + } + } if specialized := s.stableLocalFunctionValueType(ex, p, sc, it, nil); specialized != nil { return specialized } @@ -389,6 +404,160 @@ func (s *Synthesizer) narrowArrayIndexByLenBound(indexResult typ.Type, objExpr a return opt.Inner } +func (s *Synthesizer) narrowArrayIndexByLiteralLenBound(objType, indexResult typ.Type, objExpr ast.Expr, indexLiteral string, p cfg.Point, sc *scope.State, narrower api.FlowOps) typ.Type { + if !s.literalLengthBoundProvesIndex(objType, objExpr, indexLiteral, p, sc, narrower) { + return nil + } + narrowed := narrow.RemoveNil(indexResult) + if typ.IsNever(narrowed) || typ.TypeEquals(narrowed, indexResult) { + return nil + } + return narrowed +} + +func (s *Synthesizer) literalLengthBoundProvesIndex(objType typ.Type, objExpr ast.Expr, indexLiteral string, p cfg.Point, sc *scope.State, narrower api.FlowOps) bool { + if narrower == nil || s == nil || s.deps.Paths == nil { + return false + } + index, ok := numparse.ParseIntegerLiteral(indexLiteral) + if !ok || index < 1 { + return false + } + tablePath := s.deps.Paths(p, objExpr, sc) + if tablePath.IsEmpty() { + return false + } + lower, _, ok := narrower.LengthBoundsAt(p, tablePath) + return ok && lower >= index && lengthBoundProvesSequenceIndex(objType, index) +} + +func nilPresenceIsOnlyFlowUncertainty(t typ.Type) bool { + if t == nil { + return false + } + if t.Kind() == kind.Nil { + return true + } + narrowed := narrow.RemoveNil(t) + return !typ.IsNever(narrowed) && !typ.TypeEquals(narrowed, t) +} + +func 
lengthBoundProvesSequenceIndex(t typ.Type, index int64) bool { + return lengthBoundProvesSequenceIndexDepth(t, index, 0) +} + +func lengthBoundProvesSequenceIndexDepth(t typ.Type, index int64, depth int) bool { + if t == nil || typ.DepthExceeded(depth) { + return false + } + t = unwrap.Alias(t) + if inst, ok := t.(*typ.Instantiated); ok { + if resolved, err := querycore.ResolveInstantiated(inst); err == nil { + return lengthBoundProvesSequenceIndexDepth(resolved, index, depth+1) + } + } + return typ.Visit(t, typ.Visitor[bool]{ + Array: func(*typ.Array) bool { + return true + }, + Tuple: func(tuple *typ.Tuple) bool { + return tuple != nil && int64(len(tuple.Elements)) >= index + }, + Optional: func(o *typ.Optional) bool { + return lengthBoundProvesSequenceIndexDepth(o.Inner, index, depth+1) + }, + Union: func(u *typ.Union) bool { + found := false + for _, m := range u.Members { + if m == nil || m.Kind() == kind.Nil { + continue + } + if lengthBoundProvesSequenceIndexDepth(m, index, depth+1) { + found = true + continue + } + if typeMaxLenLessThanIndex(m, index, depth+1) { + continue + } + return false + } + return found + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if lengthBoundProvesSequenceIndexDepth(m, index, depth+1) { + return true + } + } + return false + }, + Recursive: func(r *typ.Recursive) bool { + if r.Body == nil || r.Body == r { + return false + } + return lengthBoundProvesSequenceIndexDepth(r.Body, index, depth+1) + }, + Default: func(typ.Type) bool { + return false + }, + }) +} + +func typeMaxLenLessThanIndex(t typ.Type, index int64, depth int) bool { + if t == nil || typ.DepthExceeded(depth) { + return false + } + t = unwrap.Alias(t) + if inst, ok := t.(*typ.Instantiated); ok { + if resolved, err := querycore.ResolveInstantiated(inst); err == nil { + return typeMaxLenLessThanIndex(resolved, index, depth+1) + } + } + return typ.Visit(t, typ.Visitor[bool]{ + Tuple: func(tuple *typ.Tuple) bool { + return tuple 
== nil || int64(len(tuple.Elements)) < index + }, + Record: func(rec *typ.Record) bool { + return rec != nil && + !rec.Open && + !rec.HasMapComponent() && + rec.Metatable == nil && + index > 0 + }, + Optional: func(o *typ.Optional) bool { + return o == nil || typeMaxLenLessThanIndex(o.Inner, index, depth+1) + }, + Union: func(u *typ.Union) bool { + for _, m := range u.Members { + if m == nil || m.Kind() == kind.Nil { + continue + } + if !typeMaxLenLessThanIndex(m, index, depth+1) { + return false + } + } + return true + }, + Intersection: func(in *typ.Intersection) bool { + for _, m := range in.Members { + if typeMaxLenLessThanIndex(m, index, depth+1) { + return true + } + } + return false + }, + Recursive: func(r *typ.Recursive) bool { + if r.Body == nil || r.Body == r { + return false + } + return typeMaxLenLessThanIndex(r.Body, index, depth+1) + }, + Default: func(typ.Type) bool { + return false + }, + }) +} + func indexVarOffsetFromExpr(expr ast.Expr) (string, int64, bool) { switch e := expr.(type) { case *ast.IdentExpr: diff --git a/compiler/check/synth/phase/extract/named_function.go b/compiler/check/synth/phase/extract/named_function.go index da80dc3a..8d7d27d9 100644 --- a/compiler/check/synth/phase/extract/named_function.go +++ b/compiler/check/synth/phase/extract/named_function.go @@ -2,6 +2,7 @@ package extract import ( "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" compcfg "github.com/wippyai/go-lua/compiler/cfg" cfganalysis "github.com/wippyai/go-lua/compiler/cfg/analysis" "github.com/wippyai/go-lua/compiler/check/api" @@ -9,7 +10,6 @@ import ( "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/cfg" "github.com/wippyai/go-lua/types/flow" - "github.com/wippyai/go-lua/types/subtype" "github.com/wippyai/go-lua/types/typ" "github.com/wippyai/go-lua/types/typ/unwrap" ) @@ -117,11 +117,33 @@ func (s *Synthesizer) graphLocalFunctionForExpr(expr ast.Expr) (compcfg.SymbolID if captureBindings 
== nil { captureBindings = moduleBindings } - hasCaptures := captureBindings != nil && len(captureBindings.CapturedSymbols(fn)) > 0 + hasCaptures := hasNonGlobalFunctionCaptures(captureBindings, fn) return sym, fn, hasCaptures } +func hasNonGlobalFunctionCaptures(bindings *bind.BindingTable, fn *ast.FunctionExpr) bool { + return len(nonGlobalFunctionCaptures(bindings, fn)) > 0 +} + +func nonGlobalFunctionCaptures(bindings *bind.BindingTable, fn *ast.FunctionExpr) map[cfg.SymbolID]struct{} { + captures := make(map[cfg.SymbolID]struct{}) + if bindings == nil || fn == nil { + return captures + } + for _, sym := range bindings.CapturedSymbols(fn) { + if sym == 0 { + continue + } + kind, ok := bindings.Kind(sym) + if ok && kind == cfg.SymbolGlobal { + continue + } + captures[sym] = struct{}{} + } + return captures +} + func (s *Synthesizer) graphLocalFunctionLiteralForExpr(expr ast.Expr) *ast.FunctionExpr { _, fn, _ := s.graphLocalFunctionForExpr(expr) return fn @@ -252,6 +274,27 @@ func (s *Synthesizer) stableGraphLocalFunctionSnapshotType(sym compcfg.SymbolID) return snapshotType } +func (s *Synthesizer) stableFunctionFactType(sym compcfg.SymbolID) typ.Type { + if s == nil || sym == 0 { + return nil + } + if t := s.currentFunctionFacts().FunctionType(sym); t != nil { + return t + } + if s.deps == nil || s.deps.Ctx == nil { + return nil + } + store := api.StoreFrom(s.deps.Ctx) + if store == nil { + return nil + } + defaultParent := s.deps.DefaultScope + if defaultParent == nil && s.deps.CheckCtx != nil { + defaultParent = s.deps.CheckCtx.TypeNames() + } + return api.FunctionTypeSnapshotForSymbol(store, sym, defaultParent) +} + func (s *Synthesizer) stableLocalFunctionValueType( expr ast.Expr, p cfg.Point, @@ -263,6 +306,9 @@ func (s *Synthesizer) stableLocalFunctionValueType( if fn == nil { return nil } + if s.hasDominatingDirectFunctionRebind(sym, fn, p) { + return nil + } authoritative := current if s.deps != nil && s.deps.CheckCtx != nil { @@ -278,33 +324,160 @@ 
func (s *Synthesizer) stableLocalFunctionValueType( facts := ctx.FunctionFacts() if factType := facts.FunctionType(sym); factType != nil { hasContextFact = true - if authoritative == nil || subtype.IsSubtype(factType, authoritative) { - authoritative = factType - } + authoritative = factType } } } if !hasContextFact { if snapshot := s.stableGraphLocalFunctionSnapshotType(sym); snapshot != nil { - if authoritative == nil || subtype.IsSubtype(snapshot, authoritative) { - authoritative = snapshot - } + authoritative = snapshot } } if !hasCaptures && authoritative != nil { return authoritative } + hasCallPointCaptureMutation := hasCaptures && s.hasDominatingCapturedMutation(fn, p) + if !hasCallPointCaptureMutation && authoritative != nil && !functionTypeNeedsBodyRepair(authoritative) { + return authoritative + } + expectedFn, _ := unwrap.Optional(unwrap.Alias(authoritative)).(*typ.Function) specialized := s.synthFunctionTypeWithCapturePoint(fn, sc, expectedFn, p, captureTypes) - if authoritative != nil && specialized != nil { - if subtype.IsSubtype(specialized, authoritative) { - return specialized - } - return authoritative + if specialized != nil { + return specialized } if authoritative != nil { return authoritative } return specialized } + +func (s *Synthesizer) hasDominatingCapturedMutation(fn *ast.FunctionExpr, p cfg.Point) bool { + if s == nil || fn == nil || p == 0 || s.deps == nil || s.deps.CheckCtx == nil { + return false + } + graph, ok := s.deps.CheckCtx.Graph().(*compcfg.Graph) + if !ok || graph == nil { + return false + } + bindings := graph.Bindings() + if bindings == nil { + bindings = s.deps.ModuleBindings + } + captures := nonGlobalFunctionCaptures(bindings, fn) + if len(captures) == 0 { + return false + } + + var defPoint cfg.Point + graph.EachFuncDef(func(point cfg.Point, info *compcfg.FuncDefInfo) { + if defPoint != 0 || info == nil || info.FuncExpr != fn { + return + } + defPoint = point + }) + if defPoint == 0 { + graph.EachAssign(func(point 
cfg.Point, info *compcfg.AssignInfo) { + if defPoint != 0 || info == nil { + return + } + info.EachTargetSource(func(_ int, _ compcfg.AssignTarget, source ast.Expr) { + if defPoint == 0 && source == fn { + defPoint = point + } + }) + }) + } + if defPoint == 0 { + return false + } + + idom := cfganalysis.ComputeImmediateDominators(graph.CFG()) + mutated := false + graph.EachAssign(func(point cfg.Point, info *compcfg.AssignInfo) { + if mutated || info == nil || point == defPoint { + return + } + if !cfganalysis.StrictlyDominates(idom, defPoint, point) || !cfganalysis.StrictlyDominates(idom, point, p) { + return + } + info.EachTarget(func(_ int, target compcfg.AssignTarget) { + if mutated { + return + } + if _, ok := captures[target.Symbol]; ok && target.Symbol != 0 { + mutated = true + return + } + if _, ok := captures[target.BaseSymbol]; ok && target.BaseSymbol != 0 { + mutated = true + } + }) + }) + return mutated +} + +func functionTypeNeedsBodyRepair(t typ.Type) bool { + fn := unwrap.Function(t) + if fn == nil { + return false + } + if typeContainsAny(fn.Variadic, 0) { + return true + } + for _, ret := range fn.Returns { + if typeContainsAny(ret, 0) { + return true + } + } + return false +} + +func typeContainsAny(t typ.Type, depth int) bool { + if t == nil || depth > typ.DefaultRecursionDepth { + return false + } + t = unwrap.Alias(t) + if typ.IsAny(t) { + return true + } + switch v := t.(type) { + case *typ.Optional: + return typeContainsAny(v.Inner, depth+1) + case *typ.Union: + for _, member := range v.Members { + if typeContainsAny(member, depth+1) { + return true + } + } + case *typ.Intersection: + for _, member := range v.Members { + if typeContainsAny(member, depth+1) { + return true + } + } + case *typ.Array: + return typeContainsAny(v.Element, depth+1) + case *typ.Map: + return typeContainsAny(v.Key, depth+1) || typeContainsAny(v.Value, depth+1) + case *typ.Tuple: + for _, elem := range v.Elements { + if typeContainsAny(elem, depth+1) { + return true + 
} + } + case *typ.Record: + if typeContainsAny(v.MapKey, depth+1) || typeContainsAny(v.MapValue, depth+1) || typeContainsAny(v.Metatable, depth+1) { + return true + } + for _, field := range v.Fields { + if typeContainsAny(field.Type, depth+1) { + return true + } + } + case *typ.Function: + return functionTypeNeedsBodyRepair(v) + } + return false +} diff --git a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go index e09dc9e1..2df29445 100644 --- a/compiler/check/tests/regression/external_lint_regression_test.go +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -86,6 +86,61 @@ return parsed, parse_err } } +func TestExternalLint_MethodSelectedOptionalResponseBodyDefaultIsStringAtCall(t *testing.T) { + jsonModule := testutil.CheckAndExport(` +local json = {} +function json.decode(raw: string): any + return {} +end +return json +`, "json", testutil.WithStdlib()) + if jsonModule.HasError() { + t.Fatalf("json module errors: %v", testutil.ErrorMessages(jsonModule.Errors)) + } + + httpModule := testutil.CheckAndExport(` +local http = {} +type Response = {status_code: number, body: string?} +function http.get(url: string, options: {[string]: any}?): (Response?, string?) + return { status_code = 200, body = nil }, nil +end +function http.post(url: string, options: {[string]: any}?): (Response?, string?) 
+ return { status_code = 200, body = nil }, nil +end +return http +`, "http_client", testutil.WithStdlib()) + if httpModule.HasError() { + t.Fatalf("http module errors: %v", testutil.ErrorMessages(httpModule.Errors)) + } + + source := ` +local json = require("json") +local http_client = require("http_client") + +local function request(method: string) + local response, err + if method == "GET" then + response, err = http_client.get("https://example.test", {}) + else + response, err = http_client.post("https://example.test", {}) + end + + if not response then + return nil, err + end + + local parsed, parse_err = json.decode(response.body or "") + return parsed, parse_err +end +` + result := testutil.Check(source, testutil.WithStdlib(), + testutil.WithModule("json", jsonModule), + testutil.WithModule("http_client", httpModule)) + if result.HasError() { + t.Fatalf("expected selected HTTP method body fallback to feed string call argument, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_GuardedOptionsModelSurvivesProviderBranches(t *testing.T) { source := ` local models = { @@ -374,6 +429,68 @@ return content, render_err } } +func TestExternalLint_ErrorGuardedImportedPageFieldCastFeedsMethodCall(t *testing.T) { + templatesModule := testutil.CheckAndExport(` +local templates = {} +function templates.get(id: string) + return { + render = function(self, name: string, context: table) + return name, nil + end, + release = function(self) + end, + }, nil +end +return templates +`, "templates", testutil.WithStdlib()) + if templatesModule.HasError() { + t.Fatalf("templates module errors: %v", testutil.ErrorMessages(templatesModule.Errors)) + } + + pageRegistryModule := testutil.CheckAndExport(` +local pages = {} +function pages.get(id: string) + if id == "" then + return nil, "missing" + end + return { + template_set = "main", + template_name = nil :: unknown, + }, nil +end +return pages +`, "page_registry", testutil.WithStdlib()) + if 
pageRegistryModule.HasError() { + t.Fatalf("page_registry module errors: %v", testutil.ErrorMessages(pageRegistryModule.Errors)) + } + + source := ` +local templates = require("templates") +local page_registry = require("page_registry") + +local page, err = page_registry.get("home") +if err then + return nil, err +end + +local template_set: string = page.template_set +local tmpl, tmpl_get_err = templates.get(template_set) +if tmpl_get_err then + return nil, tmpl_get_err +end + +local content, render_err = tmpl:render(page.template_name :: string, {}) +tmpl:release() +return content, render_err +` + result := testutil.Check(source, testutil.WithStdlib(), + testutil.WithModule("templates", templatesModule), + testutil.WithModule("page_registry", pageRegistryModule)) + if result.HasError() { + t.Fatalf("expected error guard plus field cast to feed imported method call, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_InsertedSuiteShapeSurvivesIpairs(t *testing.T) { source := ` type Suite = { @@ -593,6 +710,661 @@ end } } +func TestExternalLint_GuardedBodyUseDoesNotEraseOptionalParamBoundary(t *testing.T) { + source := ` +type Executor = { + with_context: (self: Executor, context: {[string]: any}) -> Executor, + call: (self: Executor, id: string, data: any) -> (any, string?), +} + +local funcs = { + new = function(): Executor + return { + with_context = function(self: Executor, context: {[string]: any}) + return self + end, + call = function(self: Executor, id: string, data: any) + return data, nil + end, + } + end, +} + +local function call_func(func_id: string, data: any, context: {[string]: any}?) + local executor = funcs.new() + if context ~= nil then + executor = executor:with_context(context) + end + return executor:call(func_id, data) +end + +local maybe_context = nil :: {[string]: any}? 
+call_func("map", {}, maybe_context) +call_func("filter", {}, nil) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected guarded body use to preserve optional parameter boundary, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_LocalMetatableInstanceKeepsLaterMethods(t *testing.T) { + source := ` +local module = {} +local Class = {} +local class_mt = { __index = Class } + +function module.new() + return setmetatable({ + nodes = {}, + }, class_mt) +end + +function Class:is_empty() + return next(self.nodes) == nil +end + +function Class:has_cycles() + return false, nil +end + +function module.build() + local graph = module.new() + if graph:is_empty() then + return graph, nil + end + local has_cycles, cycle_desc = graph:has_cycles() + if has_cycles then + return nil, cycle_desc + end + return graph, nil +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected local metatable instance to keep class methods, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ImportedMetatableQueryLengthGuardNarrowsFirstElement(t *testing.T) { + sessionSource := ` +local session = {} + +local context_query = {} +context_query.__index = context_query + +local session_reader = {} +session_reader.__index = session_reader + +function session.open() + return setmetatable({}, session_reader), nil +end + +function session_reader:contexts() + local query = setmetatable({}, context_query) + return query +end + +function context_query:type(_context_type) + return self +end + +function context_query:all() + local contexts, err = { { text = "summary", created_at = "now" } }, nil + if err 
then + return nil, err + end + return contexts or {}, nil +end + +return session +` + sessionModule := testutil.CheckAndExport(sessionSource, "session", testutil.WithStdlib()) + if sessionModule.HasError() { + t.Fatalf("session module should export cleanly, got: %v", testutil.ErrorMessages(sessionModule.Errors)) + } + + source := ` +local session = require("session") + +local reader, open_err = session.open() +if not reader then + return nil, open_err +end + +local existing_summaries, ctx_err = reader:contexts():type("conversation_summary"):all() +if ctx_err then + existing_summaries = {} +end + +local existing_summary = nil +if existing_summaries and #existing_summaries > 0 then + existing_summary = existing_summaries[1].text +end + +return existing_summary +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("session", sessionModule)) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected imported query length guard to prove first element, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_LengthGuardEliminatesEmptyTableFallback(t *testing.T) { + source := ` +type Context = { + text: string, + created_at: string, +} + +local repo = {} +function repo.list_by_type(): ({Context}?, string?) 
+ return { { text = "summary", created_at = "now" } }, nil +end + +local query = {} +function query:all() + local contexts, err = repo.list_by_type() + if err then + return nil, err + end + return contexts or {}, nil +end + +local existing_summaries, err = query:all() +if err then + existing_summaries = {} +end + +local existing_summary = nil +if existing_summaries and #existing_summaries > 0 then + existing_summary = existing_summaries[1].text +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected positive length guard to eliminate empty fallback before literal index, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ImportedUntypedRepositoryFallbackEliminatesNil(t *testing.T) { + sessionSource := ` +local session = {} + +local executor = {} +function executor:query(): any + return nil +end + +local repo = {} +function repo.list_by_type(_session_id, _context_type) + local contexts, err = executor:query() + if err then + return nil, err + end + return contexts +end + +local context_query = { + _session_id = nil :: string?, + _type_filter = nil :: string?, + _error = nil :: string?, +} +context_query.__index = context_query + +local session_reader = { + session_id = nil :: string?, +} +session_reader.__index = session_reader + +function session.open(session_id) + return setmetatable({ session_id = session_id }, session_reader), nil +end + +function session_reader:contexts() + local query = setmetatable({}, context_query) + query._session_id = self.session_id + query._type_filter = nil + query._error = nil + return query +end + +function context_query:type(context_type) + if not context_type then + self._error = "Context type is required" + return self + end + self._type_filter = context_type + return self +end + +function context_query:all() + if self._error then + return nil, self._error + end + local contexts, err = repo.list_by_type(self._session_id, self._type_filter) + 
if err then + return nil, err + end + return contexts or {}, nil +end + +return session +` + sessionModule := testutil.CheckAndExport(sessionSource, "session", testutil.WithStdlib()) + if sessionModule.HasError() { + t.Fatalf("session module should export cleanly, got: %v", testutil.ErrorMessages(sessionModule.Errors)) + } + + source := ` +local session = require("session") + +local session_reader, session_err = session.open("s1") +if not session_reader then + return nil, session_err +end + +local existing_summaries, ctx_err = session_reader:contexts():type("conversation_summary"):all() +if ctx_err then + existing_summaries = {} +end + +if existing_summaries and #existing_summaries > 0 then + local first = existing_summaries[1] +end +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("session", sessionModule)) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected imported nil fallback and error repair to eliminate nil index, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_QueryBuilderReaderBackReferenceSurvivesMethodChain(t *testing.T) { + source := ` +local session = {} + +local message_query = { + _session_id = nil :: string?, + _reader = nil :: any, + _after_message_id = nil :: string?, + _error = nil :: string?, +} +message_query.__index = message_query + +local session_reader = { + session_id = nil :: string?, +} +session_reader.__index = session_reader + +function session.open(session_id) + return setmetatable({ session_id = session_id }, session_reader), nil +end + +function session_reader:get_context(_key) + return "checkpoint", nil +end + +function session_reader:messages() + local query = setmetatable({}, message_query) + query._session_id = self.session_id + query._reader = self + query._after_message_id = nil + query._error = nil + return query +end + +function 
message_query:from_checkpoint() + if not self._reader then + self._error = "Reader reference missing" + return self + end + local checkpoint_id = self._reader:get_context("current_checkpoint") + if checkpoint_id then + self._after_message_id = checkpoint_id + end + return self +end + +function message_query:all() + if self._error then + return nil, self._error + end + return {}, nil +end + +local reader, err = session.open("s1") +if not reader then + return nil, err +end + +local messages_after_checkpoint, msg_err = reader:messages():from_checkpoint():all() +if msg_err then + return nil, msg_err +end + +local all_messages, all_err = reader:messages():all() +if all_err then + return nil, all_err +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected query builder reader back-reference to survive method chains, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ImportedQueryBuilderReaderBackReferenceSurvivesMethodChain(t *testing.T) { + sessionSource := ` +local session = {} + +local message_query = { + _session_id = nil :: string?, + _reader = nil :: any, + _after_message_id = nil :: string?, + _error = nil :: string?, +} +message_query.__index = message_query + +local session_reader = { + session_id = nil :: string?, +} +session_reader.__index = session_reader + +function session.open(session_id) + return setmetatable({ session_id = session_id }, session_reader), nil +end + +function session_reader:get_context(_key) + return "checkpoint", nil +end + +function session_reader:messages() + local query = setmetatable({}, message_query) + query._session_id = self.session_id + query._reader = self + query._after_message_id = nil + query._error = nil + return query +end + +function message_query:from_checkpoint() + if not self._reader then + self._error = "Reader 
reference missing" + return self + end + local checkpoint_id = self._reader:get_context("current_checkpoint") + if checkpoint_id then + self._after_message_id = checkpoint_id + end + return self +end + +function message_query:all() + if self._error then + return nil, self._error + end + return {}, nil +end + +return session +` + sessionModule := testutil.CheckAndExport(sessionSource, "session", testutil.WithStdlib()) + if sessionModule.HasError() { + t.Fatalf("session module should export cleanly, got: %v", testutil.ErrorMessages(sessionModule.Errors)) + } + + source := ` +local session = require("session") + +local reader, err = session.open("s1") +if not reader then + return nil, err +end + +local messages_after_checkpoint, msg_err = reader:messages():from_checkpoint():all() +if msg_err then + return nil, msg_err +end + +local all_messages, all_err = reader:messages():all() +if all_err then + return nil, all_err +end +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("session", sessionModule)) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected imported query builder reader back-reference to survive method chains, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_MultipleQueryBuilderPrototypesKeepMethodReceiversSeparate(t *testing.T) { + source := ` +local session = {} + +local message_query = { + _session_id = nil :: string?, + _reader = nil :: any, + _error = nil :: string?, +} +message_query.__index = message_query + +local artifact_query = { + _session_id = nil :: string?, + _error = nil :: string?, +} +artifact_query.__index = artifact_query + +local context_query = { + _session_id = nil :: string?, + _error = nil :: string?, +} +context_query.__index = context_query + +local session_reader = { + session_id = nil :: string?, + _session_data = nil :: any, + _primary_context_cache = 
nil :: any, +} +session_reader.__index = session_reader + +function session.open(session_id) + return setmetatable({ + session_id = session_id, + _session_data = {}, + _primary_context_cache = nil, + }, session_reader), nil +end + +function session_reader:get_context(_key) + return "checkpoint", nil +end + +function session_reader:messages() + local query = setmetatable({}, message_query) + query._session_id = self.session_id + query._reader = self + query._error = nil + return query +end + +function session_reader:artifacts() + local query = setmetatable({}, artifact_query) + query._session_id = self.session_id + query._error = nil + return query +end + +function session_reader:contexts() + local query = setmetatable({}, context_query) + query._session_id = self.session_id + query._error = nil + return query +end + +function message_query:from_checkpoint() + if not self._reader then + self._error = "Reader reference missing" + return self + end + local checkpoint_id = self._reader:get_context("current_checkpoint") + return self +end + +function message_query:all() + if self._error then + return nil, self._error + end + return {}, nil +end + +function artifact_query:all() + if self._error then + return nil, self._error + end + return {}, nil +end + +function context_query:all() + if self._error then + return nil, self._error + end + return {}, nil +end + +local reader, err = session.open("s1") +if not reader then + return nil, err +end + +local messages_after_checkpoint, msg_err = reader:messages():from_checkpoint():all() +if msg_err then + return nil, msg_err +end + +local all_messages, all_err = reader:messages():all() +if all_err then + return nil, all_err +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected multiple query builder prototypes to keep all() receiver facts separate, got: 
%v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_SessionReaderQueryBuilderRealShape(t *testing.T) { + source := ` +type Context = { + text: string, + created_at: string, + time: string?, +} + +local session = { + _session_contexts_repo = {}, +} + +function session._session_contexts_repo.list_by_type(session_id: string?, context_type: string?): ({Context}?, string?) + return { { text = "summary", created_at = "now" } }, nil +end + +local context_query = { + _session_id = nil :: string?, + _type_filter = nil :: string?, + _error = nil :: string?, +} +context_query.__index = context_query + +local session_reader = { + session_id = nil :: string?, +} +session_reader.__index = session_reader + +function session.open() + return setmetatable({ session_id = "s1" }, session_reader), nil +end + +function session_reader:contexts() + local query = setmetatable({}, context_query) + query._session_id = self.session_id + query._type_filter = nil + query._error = nil + return query +end + +function context_query:type(context_type) + if not context_type then + self._error = "Context type is required" + return self + end + self._type_filter = context_type + return self +end + +function context_query:all() + if self._error then + return nil, self._error + end + + local contexts, err + if self._type_filter then + contexts, err = session._session_contexts_repo.list_by_type(self._session_id, self._type_filter) + else + contexts, err = session._session_contexts_repo.list_by_type(self._session_id, self._type_filter) + end + + if err then + return nil, "Failed to fetch contexts: " .. 
err + end + + return contexts or {}, nil +end + +local session_reader, session_err = session.open() +if not session_reader then + return nil, session_err +end + +local existing_summaries, ctx_err = session_reader:contexts():type("conversation_summary"):all() +if ctx_err then + existing_summaries = {} +end + +local existing_summary = nil +if existing_summaries and #existing_summaries > 0 then + existing_summary = existing_summaries[1].text +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + for _, e := range result.Errors { + t.Logf("error: %s at %d:%d", e.Message, e.Position.Line, e.Position.Column) + } + t.Fatalf("expected real-shaped session query builder to preserve length-guarded element, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_TypeProbeAllowsOptionalDynamicFieldFallback(t *testing.T) { source := ` local page = { @@ -773,3 +1545,36 @@ map_messages(nil :: any) t.Fatalf("expected untyped discriminated source element to feed typed image helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) } } + +func TestExternalLint_CapturedStateFieldMapPairsPreservesValueShape(t *testing.T) { + source := ` +type Time = { + after: (self: Time, other: Time) -> boolean, +} + +type ActiveSession = { + pid: any, + created_at: Time, + last_activity: Time?, +} + +local state = { + active_sessions = {} :: {[string]: ActiveSession}, +} + +local function check() + local most_recent_time: Time? 
= nil + for sid, session_info in pairs(state.active_sessions) do + local last_activity: Time = session_info.last_activity or session_info.created_at + if not most_recent_time or last_activity:after(most_recent_time) then + most_recent_time = last_activity + end + end + return most_recent_time +end +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected captured state field map pairs to preserve ActiveSession values, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} diff --git a/compiler/check/tests/regression/linter_false_positive_test.go b/compiler/check/tests/regression/linter_false_positive_test.go index 37329489..5606bb36 100644 --- a/compiler/check/tests/regression/linter_false_positive_test.go +++ b/compiler/check/tests/regression/linter_false_positive_test.go @@ -356,6 +356,39 @@ end } } +func TestLinterFalsePositive_GraphLocalUnusedParamAllowsInternalAny(t *testing.T) { + source := ` +local function run_suite(name: string, tests: {any}) + return #tests +end + +local suite_name = nil :: any +local tests: {any} = {} +local count = run_suite(suite_name, tests) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected unused local parameter not to reject internal any, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestLinterFalsePositive_GraphLocalObservedParamRejectsAny(t *testing.T) { + source := ` +local function run_suite(name: string, tests: {any}) + local label = name .. 
"" + return label, #tests +end + +local suite_name = nil :: any +local tests: {any} = {} +local label, count = run_suite(suite_name, tests) +` + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatal("expected observed string parameter to reject explicit any") + } +} + // TestLinterFalsePositive_TestRunnerWithTypedEntries tests with explicitly typed entries // to better match real-world usage where entries come from a typed registry. func TestLinterFalsePositive_TestRunnerWithTypedEntries(t *testing.T) { diff --git a/types/constraint/atom.go b/types/constraint/atom.go index b2d1816b..4d3a6903 100644 --- a/types/constraint/atom.go +++ b/types/constraint/atom.go @@ -337,6 +337,12 @@ func NumericConstraintToAtom(c NumericConstraint) (Atom, bool) { LeLenOf: func(v LeLenOf) result { return result{atom: AtomLe(TermVar(v.X.Key()), TermLen(v.Array.Key())), ok: true} }, + LenLeConst: func(v LenLeConst) result { + return result{atom: AtomLe(TermLen(v.Array.Key()), TermConst(v.C)), ok: true} + }, + LenGeConst: func(v LenGeConst) result { + return result{atom: AtomGe(TermLen(v.Array.Key()), TermConst(v.C)), ok: true} + }, Default: func(NumericConstraint) result { return result{} }, diff --git a/types/constraint/numeric.go b/types/constraint/numeric.go index 16c9bf82..1c5caa5f 100644 --- a/types/constraint/numeric.go +++ b/types/constraint/numeric.go @@ -14,6 +14,7 @@ import "github.com/wippyai/go-lua/internal" // - Constants: [EqConst] (x==c), [LeConst] (x≤c), [GeConst] (x≥c) // - Modular: [ModEq] (x%m==r) // - Symbolic: [LeLenOf] (x≤len(arr)+c) +// - Length bounds: [LenLeConst], [LenGeConst] // // # Usage with Theory Solvers // @@ -27,17 +28,19 @@ import "github.com/wippyai/go-lua/internal" type NumKind uint8 const ( - NumInvalid NumKind = iota - NumLe // x - y <= c - NumLt // x < y - NumGe // x >= y - NumGt // x > y - NumEq // x == y - NumEqConst // x == c - NumLeConst // x <= c - NumGeConst // x >= c - NumModEq // x % m == r - NumLeLenOf 
// x <= len(arr) + offset + NumInvalid NumKind = iota + NumLe // x - y <= c + NumLt // x < y + NumGe // x >= y + NumGt // x > y + NumEq // x == y + NumEqConst // x == c + NumLeConst // x <= c + NumGeConst // x >= c + NumModEq // x % m == r + NumLeLenOf // x <= len(arr) + offset + NumLenLeConst // len(arr) <= c + NumLenGeConst // len(arr) >= c ) // NumericConstraint is a marker interface for numeric constraints. @@ -191,6 +194,34 @@ func (c LeLenOf) Equals(o NumericConstraint) bool { return ok && c.X.Equal(other.X) && c.Array.Equal(other.Array) && c.Offset == other.Offset } +// LenLeConst represents len(arr) <= c. +type LenLeConst struct { + Array Path + C int64 +} + +func (c LenLeConst) NumKind() NumKind { return NumLenLeConst } +func (c LenLeConst) Paths() []Path { return []Path{c.Array} } +func (c LenLeConst) Hash() uint64 { return hashNumConstraint(c.NumKind(), c.Array, Path{}, c.C) } +func (c LenLeConst) Equals(o NumericConstraint) bool { + other, ok := o.(LenLeConst) + return ok && c.Array.Equal(other.Array) && c.C == other.C +} + +// LenGeConst represents len(arr) >= c. 
+type LenGeConst struct { + Array Path + C int64 +} + +func (c LenGeConst) NumKind() NumKind { return NumLenGeConst } +func (c LenGeConst) Paths() []Path { return []Path{c.Array} } +func (c LenGeConst) Hash() uint64 { return hashNumConstraint(c.NumKind(), c.Array, Path{}, c.C) } +func (c LenGeConst) Equals(o NumericConstraint) bool { + other, ok := o.(LenGeConst) + return ok && c.Array.Equal(other.Array) && c.C == other.C +} + func hashNumConstraint(kind NumKind, a, b Path, extra ...int64) uint64 { h := internal.HashCombine(uint64(kind), a.Hash()) if !b.IsEmpty() { diff --git a/types/constraint/numeric_test.go b/types/constraint/numeric_test.go index bbd7bb46..881169c3 100644 --- a/types/constraint/numeric_test.go +++ b/types/constraint/numeric_test.go @@ -228,7 +228,7 @@ func TestNumKindValues(t *testing.T) { t.Error("NumInvalid should be 0") } - kinds := []NumKind{NumLe, NumLt, NumGe, NumGt, NumEq, NumEqConst, NumLeConst, NumGeConst, NumModEq, NumLeLenOf} + kinds := []NumKind{NumLe, NumLt, NumGe, NumGt, NumEq, NumEqConst, NumLeConst, NumGeConst, NumModEq, NumLeLenOf, NumLenLeConst, NumLenGeConst} seen := make(map[NumKind]bool) for _, k := range kinds { @@ -281,6 +281,43 @@ func TestLeLenOf(t *testing.T) { } } +func TestLenConstBounds(t *testing.T) { + arr := Path{Root: "arr"} + le := LenLeConst{Array: arr, C: 3} + if le.NumKind() != NumLenLeConst { + t.Errorf("expected NumLenLeConst, got %v", le.NumKind()) + } + if paths := le.Paths(); len(paths) != 1 || !paths[0].Equal(arr) { + t.Fatalf("unexpected LenLeConst paths: %#v", paths) + } + if le.Hash() == 0 { + t.Fatal("LenLeConst hash should be non-zero") + } + if !le.Equals(LenLeConst{Array: arr, C: 3}) { + t.Fatal("equal LenLeConst constraints should be equal") + } + if le.Equals(LenLeConst{Array: arr, C: 4}) { + t.Fatal("different LenLeConst constants should not be equal") + } + + ge := LenGeConst{Array: arr, C: 1} + if ge.NumKind() != NumLenGeConst { + t.Errorf("expected NumLenGeConst, got %v", ge.NumKind()) + } 
+ if paths := ge.Paths(); len(paths) != 1 || !paths[0].Equal(arr) { + t.Fatalf("unexpected LenGeConst paths: %#v", paths) + } + if ge.Hash() == 0 { + t.Fatal("LenGeConst hash should be non-zero") + } + if !ge.Equals(LenGeConst{Array: arr, C: 1}) { + t.Fatal("equal LenGeConst constraints should be equal") + } + if ge.Equals(LenGeConst{Array: Path{Root: "other"}, C: 1}) { + t.Fatal("different LenGeConst arrays should not be equal") + } +} + func TestNumericPathsMethod(t *testing.T) { x := Path{Root: "x"} y := Path{Root: "y"} @@ -300,6 +337,8 @@ func TestNumericPathsMethod(t *testing.T) { {"GeConst", GeConst{X: x, C: 0}, 1}, {"ModEq", ModEq{X: x, M: 3, R: 1}, 1}, {"LeLenOf", LeLenOf{X: x, Array: y}, 2}, + {"LenLeConst", LenLeConst{Array: y, C: 3}, 1}, + {"LenGeConst", LenGeConst{Array: y, C: 1}, 1}, } for _, tc := range tests { diff --git a/types/constraint/visit.go b/types/constraint/visit.go index e23784c0..d81f56d2 100644 --- a/types/constraint/visit.go +++ b/types/constraint/visit.go @@ -182,17 +182,19 @@ func VisitConstraint[R any](c Constraint, v ConstraintVisitor[R]) R { // NumericConstraintVisitor dispatches on numeric constraint variants. // Nil handlers fall back to Default when provided; otherwise return zero. type NumericConstraintVisitor[R any] struct { - Le func(Le) R - Lt func(Lt) R - Ge func(Ge) R - Gt func(Gt) R - Eq func(Eq) R - EqConst func(EqConst) R - LeConst func(LeConst) R - GeConst func(GeConst) R - ModEq func(ModEq) R - LeLenOf func(LeLenOf) R - Default func(NumericConstraint) R + Le func(Le) R + Lt func(Lt) R + Ge func(Ge) R + Gt func(Gt) R + Eq func(Eq) R + EqConst func(EqConst) R + LeConst func(LeConst) R + GeConst func(GeConst) R + ModEq func(ModEq) R + LeLenOf func(LeLenOf) R + LenLeConst func(LenLeConst) R + LenGeConst func(LenGeConst) R + Default func(NumericConstraint) R } // VisitNumericConstraint applies the first matching handler in v to c. 
@@ -278,6 +280,22 @@ func VisitNumericConstraint[R any](c NumericConstraint, v NumericConstraintVisit if v.LeLenOf != nil { return v.LeLenOf(*cc) } + case LenLeConst: + if v.LenLeConst != nil { + return v.LenLeConst(cc) + } + case *LenLeConst: + if v.LenLeConst != nil { + return v.LenLeConst(*cc) + } + case LenGeConst: + if v.LenGeConst != nil { + return v.LenGeConst(cc) + } + case *LenGeConst: + if v.LenGeConst != nil { + return v.LenGeConst(*cc) + } } if v.Default != nil { return v.Default(c) diff --git a/types/flow/numeric/domain.go b/types/flow/numeric/domain.go index 893f920d..de63a858 100644 --- a/types/flow/numeric/domain.go +++ b/types/flow/numeric/domain.go @@ -88,6 +88,8 @@ func (d *Domain) ApplyAtom(atom constraint.Atom) bool { d.theory.AddBounds(atom.Left.Path, -maxWeight, atom.Right.Const) } else if atom.Left.IsVar() && atom.Right.IsLen() { d.state.ApplyLeLenOf(atom.Left.Path, atom.Right.Path) + } else if atom.Left.IsLen() && atom.Right.IsConst() { + d.state.ApplyLenLeConst(atom.Left.Path, atom.Right.Const) } else if atom.Left.IsVar() && atom.Right.IsVar() { d.state.ApplyLe(atom.Left.Path, atom.Right.Path) d.theory.AddDifferenceConstraint(atom.Left.Path, atom.Right.Path, 0) @@ -96,6 +98,8 @@ func (d *Domain) ApplyAtom(atom constraint.Atom) bool { if atom.Left.IsVar() && atom.Right.IsConst() { d.state.ApplyGeConst(atom.Left.Path, atom.Right.Const) d.theory.AddBounds(atom.Left.Path, atom.Right.Const, maxWeight) + } else if atom.Left.IsLen() && atom.Right.IsConst() { + d.state.ApplyLenGeConst(atom.Left.Path, atom.Right.Const) } else if atom.Left.IsVar() && atom.Right.IsVar() { d.state.ApplyGe(atom.Left.Path, atom.Right.Path) d.theory.AddDifferenceConstraint(atom.Right.Path, atom.Left.Path, 0) diff --git a/types/flow/numeric/state.go b/types/flow/numeric/state.go index dfd7803b..b3b1fb0c 100644 --- a/types/flow/numeric/state.go +++ b/types/flow/numeric/state.go @@ -5,6 +5,7 @@ // - Modular residues: congruence relations (x % m == r) // - Difference 
constraints: relationships between pairs (x - y <= c) // - Symbolic length bounds: array length references +// - Length bounds: lower and upper limits for len(array) // // The solver uses Bellman-Ford to detect unsatisfiable constraint sets via // negative cycle detection in the difference constraint graph. @@ -52,6 +53,9 @@ type State struct { // Entry x -> {arr, off} means x <= len(arr) + off. lenRefs map[constraint.PathKey]lenRefBound + // lenBounds maps array PathKey to its known len(array) interval. + lenBounds map[constraint.PathKey]Interval + // unsat is true if the state is unsatisfiable. unsat bool } @@ -146,6 +150,13 @@ func (s *State) Clone() *State { } } + if len(s.lenBounds) > 0 { + c.lenBounds = make(map[constraint.PathKey]Interval, len(s.lenBounds)) + for k, v := range s.lenBounds { + c.lenBounds[k] = v + } + } + return c } @@ -166,6 +177,8 @@ func (s *State) Clone() *State { // // - LenRefs: Keep only identical length references from both states. // +// - LenBounds: Keep arrays present in both, take interval intersection. +// // If both states are nil or Bottom, returns the non-Bottom one. // If the result is Top (empty maps), returns nil to save memory. func Join(a, b *State) *State { @@ -232,6 +245,23 @@ func Join(a, b *State) *State { } } + // LenBounds: keep only arrays in both, intersect intervals. 
+ for arr, ai := range a.lenBounds { + if bi, ok := b.lenBounds[arr]; ok { + merged := intersectIntervals(ai, bi) + if merged.Lower > merged.Upper { + return Bottom() + } + + if merged != unboundedInterval { + if result.lenBounds == nil { + result.lenBounds = make(map[constraint.PathKey]Interval, minMapLen(len(a.lenBounds), len(b.lenBounds))) + } + result.lenBounds[arr] = merged + } + } + } + if result.isTop() { return nil } @@ -284,6 +314,12 @@ func (s *State) ensureLenRefs(capacity int) { } } +func (s *State) ensureLenBounds(capacity int) { + if s.lenBounds == nil { + s.lenBounds = make(map[constraint.PathKey]Interval, capacity) + } +} + // ApplyConstraintWithResolver refines the state with a numeric constraint. // // Uses the provided resolver to convert constraint paths to versioned PathKeys. @@ -382,6 +418,22 @@ func (s *State) ApplyConstraintWithResolver(c constraint.NumericConstraint, reso s.applyLeLenOf(xKey, arrKey, nc.Offset) return struct{}{} }, + LenLeConst: func(nc constraint.LenLeConst) struct{} { + arrKey := resolve(nc.Array) + if arrKey == "" { + return struct{}{} + } + s.applyLenLeConst(arrKey, nc.C) + return struct{}{} + }, + LenGeConst: func(nc constraint.LenGeConst) struct{} { + arrKey := resolve(nc.Array) + if arrKey == "" { + return struct{}{} + } + s.applyLenGeConst(arrKey, nc.C) + return struct{}{} + }, }) } @@ -532,6 +584,49 @@ func (s *State) applyLeLenOf(v, arr constraint.PathKey, offset int64) { s.lenRefs[v] = lenRefBound{Array: arr, Offset: offset} } +func (s *State) ApplyLenLeConst(arr constraint.PathKey, c int64) { + s.applyLenLeConst(arr, c) +} + +func (s *State) applyLenLeConst(arr constraint.PathKey, c int64) { + s.ensureLenBounds(1) + if b, ok := s.lenBounds[arr]; ok { + b.Upper = minInt64(b.Upper, c) + if b.Lower > b.Upper { + s.unsat = true + return + } + s.lenBounds[arr] = b + return + } + if c < 0 { + s.unsat = true + return + } + s.lenBounds[arr] = Interval{Lower: 0, Upper: c} +} + +func (s *State) ApplyLenGeConst(arr 
constraint.PathKey, c int64) { + s.applyLenGeConst(arr, c) +} + +func (s *State) applyLenGeConst(arr constraint.PathKey, c int64) { + if c < 0 { + c = 0 + } + s.ensureLenBounds(1) + if b, ok := s.lenBounds[arr]; ok { + b.Lower = maxInt64(b.Lower, c) + if b.Lower > b.Upper { + s.unsat = true + return + } + s.lenBounds[arr] = b + return + } + s.lenBounds[arr] = Interval{Lower: c, Upper: math.MaxInt64} +} + // BoundsFor returns the interval bounds for a PathKey. // // Returns (lower, upper, true) if the key has known bounds, or (0, 0, false) @@ -548,6 +643,18 @@ func (s *State) BoundsFor(key constraint.PathKey) (lower, upper int64, ok bool) return interval.Lower, interval.Upper, true } +// LenBoundsFor returns the interval bounds for len(key). +func (s *State) LenBoundsFor(key constraint.PathKey) (lower, upper int64, ok bool) { + if s == nil || s.lenBounds == nil { + return 0, 0, false + } + interval, found := s.lenBounds[key] + if !found { + return 0, 0, false + } + return interval.Lower, interval.Upper, true +} + // LenRefFor returns the array key if variable has a symbolic length bound. // // A length reference means "key <= #arrKey" (variable is bounded by array length). @@ -594,6 +701,18 @@ func (s *State) CheckSatisfiability() bool { } } + for _, key := range constraint.SortedPathKeys(s.lenBounds) { + b := s.lenBounds[key] + if b.Lower < 0 { + b.Lower = 0 + s.lenBounds[key] = b + } + if b.Lower > b.Upper { + s.unsat = true + return false + } + } + // Check relation consistency using Bellman-Ford on difference graph. if len(s.relations) > 0 { if !s.checkDifferenceConstraints() { @@ -664,7 +783,7 @@ func (s *State) checkDifferenceConstraints() bool { // Equals checks if two states are semantically equal. // // Two states are equal if they have the same unsat flag and identical maps -// for bounds, modular constraints, relations, and length references. +// for bounds, modular constraints, relations, length references, and length bounds. 
// nil and Top (empty maps) states are considered equal. func (s *State) Equals(other *State) bool { if s == nil && other == nil { @@ -703,6 +822,10 @@ func (s *State) Equals(other *State) bool { return false } + if len(s.lenBounds) != len(other.lenBounds) { + return false + } + for _, k := range constraint.SortedPathKeys(s.bounds) { v := s.bounds[k] if ov, ok := other.bounds[k]; !ok || v != ov { @@ -731,6 +854,13 @@ func (s *State) Equals(other *State) bool { } } + for _, k := range constraint.SortedPathKeys(s.lenBounds) { + v := s.lenBounds[k] + if ov, ok := other.lenBounds[k]; !ok || v != ov { + return false + } + } + return true } @@ -752,7 +882,7 @@ func (s *State) isTop() bool { return false } - return len(s.bounds) == 0 && len(s.modular) == 0 && len(s.relations) == 0 && len(s.lenRefs) == 0 + return len(s.bounds) == 0 && len(s.modular) == 0 && len(s.relations) == 0 && len(s.lenRefs) == 0 && len(s.lenBounds) == 0 } func minInt64(a, b int64) int64 { @@ -854,6 +984,9 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { if len(s.lenRefs) > 0 { result.lenRefs = make(map[constraint.PathKey]lenRefBound, len(s.lenRefs)) } + if len(s.lenBounds) > 0 { + result.lenBounds = make(map[constraint.PathKey]Interval, len(s.lenBounds)) + } // Remap bounds for _, k := range constraint.SortedPathKeys(s.bounds) { @@ -921,6 +1054,22 @@ func (s *State) Rekey(remap map[constraint.PathKey]constraint.PathKey) *State { result.lenRefs[newK] = ref } + // Remap length bounds. 
+ for _, k := range constraint.SortedPathKeys(s.lenBounds) { + v := s.lenBounds[k] + newKey := k + if mapped, ok := remap[k]; ok { + newKey = mapped + } + if existing, ok := result.lenBounds[newKey]; ok { + v = intersectIntervals(existing, v) + if v.Lower > v.Upper { + return Bottom() + } + } + result.lenBounds[newKey] = v + } + return result } diff --git a/types/flow/numeric/state_test.go b/types/flow/numeric/state_test.go index c54d9286..4f8734db 100644 --- a/types/flow/numeric/state_test.go +++ b/types/flow/numeric/state_test.go @@ -69,6 +69,30 @@ func TestState_ApplyBounds(t *testing.T) { } } +func TestState_ApplyLenBounds(t *testing.T) { + s := NewState() + s.ApplyLenGeConst("rows", 1) + s.ApplyLenLeConst("rows", 3) + + lower, upper, ok := s.LenBoundsFor("rows") + if !ok { + t.Fatal("expected length bounds") + } + if lower != 1 || upper != 3 { + t.Fatalf("expected len bounds [1, 3], got [%d, %d]", lower, upper) + } +} + +func TestState_ContradictoryLenBounds(t *testing.T) { + s := NewState() + s.ApplyLenGeConst("rows", 1) + s.ApplyLenLeConst("rows", 0) + + if !s.IsUnsat() { + t.Fatal("contradictory length bounds should make state unsat") + } +} + func TestState_ContradictoryBounds(t *testing.T) { s := NewState() s.ApplyGeConst("x", 10) diff --git a/types/flow/query.go b/types/flow/query.go index 9b1675c2..7a2210d3 100644 --- a/types/flow/query.go +++ b/types/flow/query.go @@ -308,6 +308,25 @@ func (s *Solution) ArrayLenBoundWithOffsetAt(p cfg.Point, varName string) (arrKe return string(pathKey), off, true } +// LengthBoundsAt returns numeric bounds for len(path) at a CFG point. 
+func (s *Solution) LengthBoundsAt(p cfg.Point, path constraint.Path) (lower, upper int64, ok bool) { + if s == nil || s.numericStates == nil { + return 0, 0, false + } + state := s.numericStates[p] + if state == nil { + return 0, 0, false + } + if s.pkResolver == nil || path.IsEmpty() { + return 0, 0, false + } + key := s.pkResolver.KeyAt(p, path) + if key == "" { + return 0, 0, false + } + return state.LenBoundsFor(key) +} + // NarrowedTypeAt returns the type at point p for path, narrowed by the DNF condition. // This is a pure query that composes: baseTypeAt + ConditionAt + applyCondition. func (s *Solution) NarrowedTypeAt(p cfg.Point, path constraint.Path) typ.Type { From 19dfc1700e3ca7257568889473f7e78f04dfdd5e Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 12:45:46 -0400 Subject: [PATCH 26/71] Classify external lint replay reductions --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 51 ++++ .../external_lint_regression_test.go | 260 ++++++++++++++++++ 2 files changed, 311 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index b27acedd..a5cc13fd 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5117,3 +5117,54 @@ Design rule retained for the next pass: owning domain or transfer rule. - If a diagnostic is true external code, keep it classified and do not edit external Wippy sources from this go-lua PR. + +## 2026-05-19 External Replay Classification Follow-Up + +The local-replace Wippy binary still reports diagnostics in external packages, +but the new reductions did not expose a go-lua engine regression. The important +distinction is that the checker is now refusing to erase dynamic source shapes +that are not proven by the Lua program. 
+ +New regression coverage added in `external_lint_regression_test.go`: + +- optional numeric fields defaulted with `or` become non-nil before arithmetic; +- exported model-card numeric defaults remain non-nil at a consumer; +- imported modules stored in table fields preserve those numeric defaults; +- registry-derived numeric defaults still feed arithmetic after the consumer + guards the optional return; +- guarded string field values inserted into an accumulator retain a string + element type when iterated into a helper call; +- a `type(x) == "table"` guard on an untyped value keeps dynamic field fallback + reads open. + +Classification of the remaining replay clusters: + +- `llm.lua` provider contract calls are true code issues under the current + soundness rule: `provider_info = model_card.providers[1] as any` explicitly + discards the proof that `provider_model` is `string`, then passes + `provider_info.provider_model` to contracts requiring `model: string`. +- Artifact/message metadata field errors are true code issues unless external + code adds a table guard or guaranteed decode. The repositories decode JSON + into `meta`/`metadata` on success but leave the original string when decode + fails, then downstream code accesses fields after only a truthiness guard. +- `json.decode(response.body or "")` and HTTP stream-read diagnostics are still + unreduced package-boundary candidates. The go-lua reductions for optional + response body fallback and guarded stream reads pass, so the observed replay + failures are not the simple `or` transfer rule. +- Bedrock text-block parsing is not a reproduced accumulator regression. The + guarded string accumulator reduction passes; the replay source receives + response blocks from a dynamic API shape, so the value is `any` unless the + external package or manifest proves the field type. 
+- Docker-demo fixture failures are mostly true fixture/source issues: examples + include `state.iteration_count` being initialized only on the first-iteration + branch before arithmetic, dynamic maps passed to stricter contracts, optional + method receivers called without guards, and generated/vendor stubs whose + contextual shapes do not declare fields they later read. + +Current rule: + +- Keep explicit `any` and `unknown` barriers sound. +- Do not suppress these diagnostics in go-lua without a failing go-lua + reduction that proves the checker lost information it already had. +- External Wippy fixes, if desired, should be explicit guards, casts at real + trust boundaries, or stronger manifests; they are outside this go-lua PR. diff --git a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go index 2df29445..3d62342a 100644 --- a/compiler/check/tests/regression/external_lint_regression_test.go +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/typ" ) func TestExternalLint_OptionalResponseBodyDefaultIsStringAtCall(t *testing.T) { @@ -1578,3 +1579,262 @@ end t.Fatalf("expected captured state field map pairs to preserve ActiveSession values, got: %v", testutil.ErrorMessages(result.Diagnostics)) } } + +func TestExternalLint_OptionalNumericFieldsDefaultBeforeArithmetic(t *testing.T) { + source := ` +local CONFIG = { + prompt_buffer_tokens = 256, + chars_per_token = 4, +} + +local function tokens_to_chars(tokens) + return math.floor(tokens * CONFIG.chars_per_token) +end + +local function budget(model_card: {max_tokens: integer?, output_tokens: integer?}) + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - 
CONFIG.prompt_buffer_tokens + return tokens_to_chars(usable_input_tokens) +end + +return budget({ max_tokens = nil, output_tokens = nil }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected numeric field defaults to remove nil before arithmetic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ExportedNumericDefaultsRemainNonNilAtConsumer(t *testing.T) { + modelsModule := testutil.CheckAndExport(` +local models = {} + +function models._build_model_card(entry) + return { + max_tokens = entry.data and entry.data.max_tokens or 0, + output_tokens = entry.data and entry.data.output_tokens or 0, + } +end + +function models.get_by_name(name) + if not name then + return nil, "name required" + end + return models._build_model_card({ data = {} }), nil +end + +return models +`, "models", testutil.WithStdlib()) + if modelsModule.HasError() { + t.Fatalf("models module errors: %v", testutil.ErrorMessages(modelsModule.Errors)) + } + + source := ` +local models = require("models") + +local function budget(model_name) + local model_card, err = models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + return max_context_tokens - max_output_tokens - 256 +end + +return budget("default") +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("models", modelsModule)) + if result.HasError() { + t.Fatalf("expected exported numeric defaults to remain non-nil at consumer arithmetic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_TableFieldModuleNumericDefaultsRemainNonNil(t *testing.T) { + modelsModule := testutil.CheckAndExport(` +local models = {} + +function models._build_model_card(entry) + return { + max_tokens = entry.data and entry.data.max_tokens or 0, + output_tokens = entry.data and 
entry.data.output_tokens or 0, + } +end + +function models.get_by_name(name) + if not name then + return nil, "name required" + end + return models._build_model_card({ data = {} }), nil +end + +return models +`, "models", testutil.WithStdlib()) + if modelsModule.HasError() { + t.Fatalf("models module errors: %v", testutil.ErrorMessages(modelsModule.Errors)) + } + + source := ` +local models = require("models") + +local compress = { + _models = models, +} + +local function budget(model_name) + local model_card, err = compress._models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + return max_context_tokens - max_output_tokens - 256 +end + +return budget("default") +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("models", modelsModule)) + if result.HasError() { + t.Fatalf("expected table-held imported module numeric defaults to remain non-nil at consumer arithmetic, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_RegistryDerivedNumericDefaultsRemainNonNil(t *testing.T) { + entryType := typ.NewRecord(). + Field("data", typ.NewOptional(typ.NewRecord(). + Field("max_tokens", typ.NewOptional(typ.Integer)). + Field("output_tokens", typ.NewOptional(typ.Integer)). + Build())). + Build() + registryType := typ.NewRecord(). + Field("find", typ.Func(). + Param("query", typ.NewMap(typ.String, typ.Any)). + Returns(typ.NewOptional(typ.NewArray(entryType)), typ.NewOptional(typ.String)). + Build()). 
+ Build() + + registryModule := testutil.CheckAndExport(` +local registry = {} +function registry.find(query) + return { { data = {} } }, nil +end +return registry +`, "registry", testutil.WithStdlib(), testutil.WithTypes(map[string]typ.Type{ + "registry": registryType, + })) + if registryModule.HasError() { + t.Fatalf("registry module errors: %v", testutil.ErrorMessages(registryModule.Errors)) + } + + modelsModule := testutil.CheckAndExport(` +local registry = require("registry") +local models = {} + +function models._build_model_card(entry) + return { + max_tokens = entry.data and entry.data.max_tokens or 0, + output_tokens = entry.data and entry.data.output_tokens or 0, + } +end + +function models.get_by_name(name) + if not name then + return nil, "name required" + end + local entries, err = registry.find({ name = name }) + if err then + return nil, err + end + if not entries or #entries == 0 then + return nil, "not found" + end + return models._build_model_card(entries[1]) +end + +return models +`, "models", testutil.WithStdlib(), testutil.WithModule("registry", registryModule)) + if modelsModule.HasError() { + t.Fatalf("models module errors: %v", testutil.ErrorMessages(modelsModule.Errors)) + } + + source := ` +local models = require("models") + +local compress = { + _models = models, +} + +local function budget(model_name) + local model_card, err = compress._models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + return max_context_tokens - max_output_tokens - 256 +end + +return budget("default") +` + result := testutil.Check(source, testutil.WithStdlib(), + testutil.WithModule("registry", registryModule), + testutil.WithModule("models", modelsModule)) + if result.HasError() { + t.Fatalf("expected registry-derived numeric defaults to remain non-nil at consumer arithmetic, got: %v", 
testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedStringFieldAccumulatorFeedsHelper(t *testing.T) { + source := ` +local function parse_text_tool_call(text: string?, tool_names) + if not text or not tool_names then + return nil + end + return { name = text } +end + +local function extract(converse_response: {output: {message: {content: {{text: string?}}}}}, tool_names) + local text_blocks = {} + for _, block in ipairs(converse_response.output.message.content) do + if block.text then + table.insert(text_blocks, block.text) + end + end + + for _, text in ipairs(text_blocks) do + local parsed = parse_text_tool_call(text, tool_names) + if parsed then + return parsed.name + end + end + return nil +end + +return extract({ output = { message = { content = { { text = "call" } } } } }, {}) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded string field accumulator to feed optional-string helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_TypeTableGuardKeepsDynamicFieldReadsOpen(t *testing.T) { + source := ` +local function run(stats_data) + if type(stats_data) == "table" then + return stats_data.sum or stats_data.count or 0 + end + return 0 +end + +return run({}) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected type(table) guard on untyped value to allow dynamic field fallback reads, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} From 82b686686c017c1917cf25fe35048fc2ae4f7218 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 13:25:10 -0400 Subject: [PATCH 27/71] Close expression call evidence gaps --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 84 +++ compiler/check/infer/interproc/postflow.go | 159 +++-- .../external_lint_regression_test.go | 610 ++++++++++++++++++ 3 files changed, 785 insertions(+), 68 deletions(-) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md 
b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index a5cc13fd..05d61225 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5168,3 +5168,87 @@ Current rule: reduction that proves the checker lost information it already had. - External Wippy fixes, if desired, should be explicit guards, casts at real trust boundaries, or stronger manifests; they are outside this go-lua PR. + +## 2026-05-19 Expression Call Evidence Closure + +One real engine regression remained in the external `compress` replay: local +helper calls nested under returned table fields were not always represented as +call sites for parameter evidence. The old collector handled statement calls, +top-level assignment/return source calls, and nested calls inside call +arguments, but missed expression positions like: + +- `return { field = helper(value) }`; +- `local t = { field = helper(value) }`; +- `if helper(value) then ... end`; +- numeric/generic loop header expressions; +- calls wrapped by casts or non-nil assertions. + +That was a domain bug, not a reason to weaken arithmetic or `any`/`unknown`. +The correction keeps `FunctionFacts.Params` as the only parameter-evidence +authority and expands the collector to visit every call expression that occurs +inside assignment sources, return expressions, branch conditions, and loop +headers. The final implementation walks each owned expression tree once from +its CFG point, so the collector does not need a compatibility call-site channel +or per-point dedupe maps. 
+ +Regression coverage added: + +- returned-table and assigned-table helper calls feed numeric parameter + evidence; +- branch-condition and numeric-for-bound helper calls feed numeric parameter + evidence; +- the original `compress`/test-DSL mutable resolver reduction no longer + pollutes `tokens_to_chars`; +- guarded config update reductions verify unrelated numeric config fields stay + non-optional when call evidence proves the updates are safe; +- existing compress/model-card reductions remain green. +- negative soundness reductions verify the checker does not accept untyped + model-card fields as numbers, explicit `any` provider models as strings, or + untyped response text as `string?` without a real guard, cast, or manifest. + +Verification from this pass: + +```text +go test ./compiler/check/infer/interproc ./compiler/check/tests/regression -count=1 +go test ./... -count=1 +git diff --check +go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction -benchmem -count=5 +``` + +Results: + +- `go test ./... -count=1` passes. +- `git diff --check` passes. +- `BenchmarkCheck_LargeFunction` is about 1.97-2.15 ms/op, about 1.054 MB/op, + and 10,699 allocs/op on this machine after the expression-call scan. +- A local-replace Wippy binary built from this checkout now reduces the full + `framework/src/llm/src` replay to 9 errors: the known 6 `llm.lua` contract + errors, 1 Bedrock dynamic text-block parser error, and 2 `compress.lua` + arithmetic errors. + +Updated classification: + +- The previous nested-call evidence bug is fixed in go-lua. +- The remaining `wippy.llm.util:compress` errors are now classified as an + external source/manifest proof gap. Replaying the real + `wippy.llm.discovery:models` export locally shows `get_by_name` exports + `max_tokens` and `output_tokens` as `unknown`, because registry `entry.data` + is not typed as numeric. `compress.lua` then uses those fields in arithmetic + after `or` defaults. 
go-lua must not invent numeric proof across that module + boundary; external code should either type the registry/model-card manifest or + coerce with `tonumber(...) or <default>` before arithmetic. +- The same soundness boundary covers the Bedrock and `llm.lua` diagnostics: + `block.text` comes from untyped provider JSON, and `provider_info` is + explicitly cast to `any` before being used to build contract args whose + `model` field must be `string`. +- The remaining global replay diagnostics are still true dynamic-boundary or + source-shape issues unless independently reduced to a failing go-lua engine + test. Current local-replace counts: `framework/src/llm/src` 9 errors, + `framework/src/agent/src` 11 errors, `session` 38 errors, and `docker-demo` + 60 errors. +- Standard `../scripts/verify-suite.sh` still exits non-zero because external + lint targets fail under the Wippy repo's pinned `github.com/wippyai/go-lua + v1.5.16` build, but the go-lua checker tests and Wippy binary build pass. + The external counts from that official path are currently `session` 8 errors, + `framework/src/agent/src` 6 errors, and `docker-demo` 21 errors plus + 2 warnings. 
diff --git a/compiler/check/infer/interproc/postflow.go b/compiler/check/infer/interproc/postflow.go index 23904668..44fce65f 100644 --- a/compiler/check/infer/interproc/postflow.go +++ b/compiler/check/infer/interproc/postflow.go @@ -483,19 +483,98 @@ func CollectParameterEvidenceFromResult(store Store, result *api.FuncResult, par } } - graph.EachCallSite(func(p cfg.Point, info *cfg.CallInfo) { - collectCallEvidence(p, info) + var collectExprCall func(cfg.Point, ast.Expr) + collectExprCalls := func(p cfg.Point, exprs []ast.Expr) { + if len(exprs) == 0 { + return + } + for _, expr := range exprs { + collectExprCall(p, expr) + } + } + collectCallExpr := func(p cfg.Point, call *ast.FuncCallExpr) { + if call == nil { + return + } + callInfo := graph.CallSiteAt(p, call) + if callInfo == nil { + callInfo = synthCallInfoFromExpr(call, bindings) + } + collectCallEvidence(p, callInfo) + collectExprCall(p, call.Func) + collectExprCall(p, call.Receiver) + collectExprCalls(p, call.Args) + } + collectExprCall = func(p cfg.Point, expr ast.Expr) { + if expr == nil { + return + } + switch e := expr.(type) { + case *ast.FuncCallExpr: + collectCallExpr(p, e) + case *ast.AttrGetExpr: + collectExprCall(p, e.Object) + collectExprCall(p, e.Key) + case *ast.TableExpr: + for _, field := range e.Fields { + if field == nil { + continue + } + collectExprCall(p, field.Key) + collectExprCall(p, field.Value) + } + case *ast.LogicalOpExpr: + collectExprCall(p, e.Lhs) + collectExprCall(p, e.Rhs) + case *ast.RelationalOpExpr: + collectExprCall(p, e.Lhs) + collectExprCall(p, e.Rhs) + case *ast.StringConcatOpExpr: + collectExprCall(p, e.Lhs) + collectExprCall(p, e.Rhs) + case *ast.ArithmeticOpExpr: + collectExprCall(p, e.Lhs) + collectExprCall(p, e.Rhs) + case *ast.UnaryMinusOpExpr: + collectExprCall(p, e.Expr) + case *ast.UnaryNotOpExpr: + collectExprCall(p, e.Expr) + case *ast.UnaryLenOpExpr: + collectExprCall(p, e.Expr) + case *ast.UnaryBNotOpExpr: + collectExprCall(p, e.Expr) + case 
*ast.CastExpr: + collectExprCall(p, e.Expr) + case *ast.NonNilAssertExpr: + collectExprCall(p, e.Expr) + } + } - seenNested := make(map[*ast.FuncCallExpr]struct{}) - for _, arg := range info.Args { - collectNestedFuncCalls(arg, seenNested) - } - for nested := range seenNested { - nestedInfo := graph.CallSiteAt(p, nested) - if nestedInfo == nil { - nestedInfo = synthCallInfoFromExpr(nested, bindings) + graph.EachStmtCall(func(p cfg.Point, info *cfg.CallInfo) { + if info == nil { + return + } + collectCallExpr(p, info.Call) + }) + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + if info != nil { + collectExprCalls(p, info.Sources) + collectExprCalls(p, info.IterExprs) + if info.NumericFor != nil { + collectExprCall(p, info.NumericFor.Init) + collectExprCall(p, info.NumericFor.Limit) + collectExprCall(p, info.NumericFor.Step) } - collectCallEvidence(p, nestedInfo) + } + }) + graph.EachReturn(func(p cfg.Point, info *cfg.ReturnInfo) { + if info != nil { + collectExprCalls(p, info.Exprs) + } + }) + graph.EachBranch(func(p cfg.Point, info *cfg.BranchInfo) { + if info != nil { + collectExprCall(p, info.Condition) } }) } @@ -504,17 +583,7 @@ func synthCallInfoFromExpr(ex *ast.FuncCallExpr, bindings *bind.BindingTable) *c if ex == nil { return nil } - info := &cfg.CallInfo{ - Call: ex, - Callee: ex.Func, - Args: ex.Args, - Method: ex.Method, - Receiver: ex.Receiver, - IsStmt: false, - } - if id, ok := ex.Func.(*ast.IdentExpr); ok { - info.CalleeName = id.Value - } + info := cfg.BuildCallInfo(ex, false) if bindings != nil { info.CalleeSymbol = checkcallsite.SymbolFromExpr(ex.Func, bindings) if ex.Receiver != nil { @@ -531,52 +600,6 @@ func synthCallInfoFromExpr(ex *ast.FuncCallExpr, bindings *bind.BindingTable) *c return info } -func collectNestedFuncCalls(expr ast.Expr, out map[*ast.FuncCallExpr]struct{}) { - if expr == nil || out == nil { - return - } - switch e := expr.(type) { - case *ast.FuncCallExpr: - out[e] = struct{}{} - collectNestedFuncCalls(e.Func, 
out) - collectNestedFuncCalls(e.Receiver, out) - for _, arg := range e.Args { - collectNestedFuncCalls(arg, out) - } - case *ast.AttrGetExpr: - collectNestedFuncCalls(e.Object, out) - collectNestedFuncCalls(e.Key, out) - case *ast.TableExpr: - for _, field := range e.Fields { - if field == nil { - continue - } - collectNestedFuncCalls(field.Key, out) - collectNestedFuncCalls(field.Value, out) - } - case *ast.LogicalOpExpr: - collectNestedFuncCalls(e.Lhs, out) - collectNestedFuncCalls(e.Rhs, out) - case *ast.RelationalOpExpr: - collectNestedFuncCalls(e.Lhs, out) - collectNestedFuncCalls(e.Rhs, out) - case *ast.StringConcatOpExpr: - collectNestedFuncCalls(e.Lhs, out) - collectNestedFuncCalls(e.Rhs, out) - case *ast.ArithmeticOpExpr: - collectNestedFuncCalls(e.Lhs, out) - collectNestedFuncCalls(e.Rhs, out) - case *ast.UnaryMinusOpExpr: - collectNestedFuncCalls(e.Expr, out) - case *ast.UnaryNotOpExpr: - collectNestedFuncCalls(e.Expr, out) - case *ast.UnaryLenOpExpr: - collectNestedFuncCalls(e.Expr, out) - case *ast.UnaryBNotOpExpr: - collectNestedFuncCalls(e.Expr, out) - } -} - func parentGraphKeyForCallee(store Store, result *api.FuncResult, parent *scope.State, calleeSym cfg.SymbolID) (api.GraphKey, bool) { if store == nil || result == nil || result.Graph == nil || calleeSym == 0 { return api.GraphKey{}, false diff --git a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go index 3d62342a..759153b8 100644 --- a/compiler/check/tests/regression/external_lint_regression_test.go +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -1,6 +1,7 @@ package regression import ( + "strings" "testing" "github.com/wippyai/go-lua/compiler/check/tests/testutil" @@ -1788,6 +1789,604 @@ return budget("default") } } +func TestExternalLint_CompressModelInfoNumericHelpersStayNonNil(t *testing.T) { + modelsModule := testutil.CheckAndExport(` +local models = {} + +function 
models.get_by_name(name) + if not name then + return nil, "name required" + end + return { + max_tokens = 8000, + output_tokens = 1000, + }, nil +end + +return models +`, "models", testutil.WithStdlib()) + if modelsModule.HasError() { + t.Fatalf("models module errors: %v", testutil.ErrorMessages(modelsModule.Errors)) + } + + source := ` +local models = require("models") + +local compress = { + _models = models, +} + +local CONFIG = { + chars_per_token = 4, + prompt_buffer_tokens = 500, + context_safety_margin = 0.1, + output_buffer_tokens = 200, +} + +local function tokens_to_chars(tokens) + return math.floor(tokens * CONFIG.chars_per_token) +end + +local function chars_to_tokens(chars) + return math.floor((tonumber(chars) or 0) / CONFIG.chars_per_token) +end + +local function get_model_info(model_name, mock_model_info) + if mock_model_info then + return mock_model_info, nil + end + + local model_card, err = compress._models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - CONFIG.prompt_buffer_tokens + local usable_input_chars = tokens_to_chars(usable_input_tokens) + local safe_input_chars = math.floor(usable_input_chars * (1 - CONFIG.context_safety_margin)) + local safe_output_chars = tokens_to_chars(max_output_tokens) + + return { + max_context_tokens = max_context_tokens, + max_output_tokens = max_output_tokens, + usable_input_chars = safe_input_chars, + usable_input_tokens = chars_to_tokens(safe_input_chars), + max_output_chars = safe_output_chars, + }, nil +end + +local function calculate_safe_max_tokens(target_chars, model_info) + local needed_tokens = chars_to_tokens(target_chars) + CONFIG.output_buffer_tokens + return math.min(needed_tokens, tonumber(model_info.max_output_tokens) or 1000) +end + +function compress.to_size(model_name, content, 
target_chars, options, mock_model_info) + options = options or {} + local model_info, err = get_model_info(model_name, mock_model_info) + if err then + return nil, err + end + model_info = assert(model_info) + return calculate_safe_max_tokens(target_chars, model_info) +end + +function compress.get_stats(model_name, content, target_chars, mock_model_info) + local model_info, err = get_model_info(model_name, mock_model_info) + if err then + return nil, err + end + + local content_chars = #content + return { + model_max_context_tokens = model_info.max_context_tokens, + model_max_output_tokens = model_info.max_output_tokens, + model_usable_input_chars = model_info.usable_input_chars, + model_max_output_chars = model_info.max_output_chars, + fits_in_context = content_chars <= model_info.usable_input_chars, + safe_max_tokens_for_target = calculate_safe_max_tokens(target_chars, model_info), + } +end + +return compress.to_size("model", "content", 1000, nil, nil) +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("models", modelsModule)) + if result.HasError() { + t.Fatalf("expected compress-style numeric helpers to keep defaulted tokens non-nil, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_UnionModelResolverGuardKeepsNumericDefaultsNonNil(t *testing.T) { + source := ` +local resolver +if unknown_condition then + resolver = { + get_by_name = function(model_name) + return { + max_tokens = 128000, + output_tokens = 16384, + }, nil + end, + } +else + resolver = { + get_by_name = function(model_name) + return nil, "Model not found" + end, + } +end + +local function get_model_info(model_name) + local model_card, err = resolver.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + return max_context_tokens - max_output_tokens - 500 +end + +return get_model_info("model") 
+` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded union resolver return to keep numeric defaults non-nil, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_MutableModelResolverFieldGuardKeepsNumericDefaultsNonNil(t *testing.T) { + source := ` +local compress = { + _models = { + get_by_name = function(model_name) + return { + max_tokens = 128000, + output_tokens = 16384, + }, nil + end, + }, +} + +if unknown_condition then + compress._models = { + get_by_name = function(model_name) + return nil, "Model not found" + end, + } +end + +local function tokens_to_chars(tokens) + return math.floor(tokens * 4) +end + +local function get_model_info(model_name) + local model_card, err = compress._models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - 500 + return tokens_to_chars(usable_input_tokens), tokens_to_chars(max_output_tokens) +end + +return get_model_info("model") +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected mutable resolver field guard to keep numeric defaults non-nil, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_NestedReturnTableCallFeedsHelperParamEvidence(t *testing.T) { + source := ` +local function tokens_to_chars(tokens) + return math.floor(tokens * 4) +end + +local function model_info() + local usable_input_tokens = 6500 + local max_output_tokens = 1000 + return { + usable_input_chars = tokens_to_chars(usable_input_tokens), + max_output_chars = tokens_to_chars(max_output_tokens), + } +end + +return model_info().usable_input_chars +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected calls 
nested in returned table fields to feed helper parameter evidence, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_NestedAssignedTableCallFeedsHelperParamEvidence(t *testing.T) { + source := ` +local function tokens_to_chars(tokens) + return math.floor(tokens * 4) +end + +local function model_info() + local usable_input_tokens = 6500 + local info = { + usable_input_chars = tokens_to_chars(usable_input_tokens), + } + return info +end + +return model_info().usable_input_chars +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected calls nested in assigned table fields to feed helper parameter evidence, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ConditionCallFeedsHelperParamEvidence(t *testing.T) { + source := ` +local function has_budget(tokens) + return math.floor(tokens) > 0 +end + +local function run() + local usable_input_tokens = 6500 + if has_budget(usable_input_tokens) then + return true + end + return false +end + +return run() +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected calls in branch conditions to feed helper parameter evidence, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_NumericForBoundCallFeedsHelperParamEvidence(t *testing.T) { + source := ` +local function clamp_bound(tokens) + return math.floor(tokens) +end + +local function run() + local max_tokens = 3 + local total = 0 + for i = 1, clamp_bound(max_tokens) do + total = total + i + end + return total +end + +return run() +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected calls in numeric for bounds to feed helper parameter evidence, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedConfigUpdateKeepsUnchangedNumericFields(t *testing.T) { + source := ` +local CONFIG = { + 
chars_per_token = 4, + prompt_buffer_tokens = 500, + default_temperature = 0.2, +} + +local function tokens_to_chars(tokens) + return math.floor(tokens * CONFIG.chars_per_token) +end + +local function usable_chars() + local max_context_tokens = 8000 + local max_output_tokens = 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - CONFIG.prompt_buffer_tokens + return tokens_to_chars(usable_input_tokens) +end + +local function configure(new_config) + for key, value in pairs(new_config) do + if CONFIG[key] ~= nil then + CONFIG[key] = value + end + end +end + +configure({ default_temperature = 0.8 }) +configure({ unknown_key = "value" }) +return usable_chars() +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded config updates not to optionalize unrelated numeric fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedExportedConfigUpdateKeepsUnchangedNumericFields(t *testing.T) { + source := ` +local compress = {} +local CONFIG = { + chars_per_token = 4, + prompt_buffer_tokens = 500, + default_temperature = 0.2, +} + +local function tokens_to_chars(tokens) + return math.floor(tokens * CONFIG.chars_per_token) +end + +local function usable_chars() + local max_context_tokens = 8000 + local max_output_tokens = 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - CONFIG.prompt_buffer_tokens + return tokens_to_chars(usable_input_tokens) +end + +function compress.configure(new_config) + for key, value in pairs(new_config) do + if CONFIG[key] ~= nil then + CONFIG[key] = value + end + end +end + +compress.configure({ default_temperature = 0.8 }) +compress.configure({ unknown_key = "value" }) +return usable_chars() +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded exported config updates not to optionalize unrelated numeric fields, got: %v", 
testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_GuardedConfigRoundTripKeepsUnchangedNumericFields(t *testing.T) { + source := ` +local compress = {} +local CONFIG = { + chars_per_token = 4, + prompt_buffer_tokens = 500, + default_temperature = 0.2, +} + +local function tokens_to_chars(tokens) + return math.floor(tokens * CONFIG.chars_per_token) +end + +local function usable_chars() + local max_context_tokens = 8000 + local max_output_tokens = 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - CONFIG.prompt_buffer_tokens + return tokens_to_chars(usable_input_tokens) +end + +function compress.configure(new_config) + for key, value in pairs(new_config) do + if CONFIG[key] ~= nil then + CONFIG[key] = value + end + end +end + +function compress.get_config() + local config_copy = {} + for key, value in pairs(CONFIG) do + config_copy[key] = value + end + return config_copy +end + +local original = compress.get_config().default_temperature +compress.configure({ default_temperature = 0.8 }) +compress.configure({ default_temperature = original }) +compress.configure({ unknown_key = "value" }) +return usable_chars() +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected guarded config round-trip not to optionalize unrelated numeric fields, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_TestDslMutableModelResolverDoesNotPolluteNumericHelper(t *testing.T) { + source := ` +local test = {} +function test.describe(_name: string, fn: fun()) fn() end +function test.it(_name: string, fn: fun()) fn() end +function test.run_cases(define_cases_fn: fun()) + return function() + _G.describe = test.describe + _G.it = test.it + define_cases_fn() + _G.describe = nil + _G.it = nil + end +end + +local compress = { + _models = { + get_by_name = function(model_name) + return { + max_tokens = 8000, + output_tokens = 1000, + }, nil + end, + }, +} + +local 
function tokens_to_chars(tokens) + return math.floor(tokens * 4) +end + +local function get_model_info(model_name) + local model_card, err = compress._models.get_by_name(model_name) + if not model_card then + return nil, err + end + + local max_context_tokens = model_card.max_tokens or 8000 + local max_output_tokens = model_card.output_tokens or 1000 + local usable_input_tokens = max_context_tokens - max_output_tokens - 500 + return { + usable_input_chars = tokens_to_chars(usable_input_tokens), + max_output_chars = tokens_to_chars(max_output_tokens), + }, nil +end + +function compress.to_size(model_name: string, content: string, target_chars: number) + local model_info, err = get_model_info(model_name) + if err then + return nil, err + end + if not model_info then + return nil, err + end + return model_info.usable_input_chars +end + +local function define_tests() + describe("compress", function() + it("uses a large model", function() + compress._models = { + get_by_name = function(model_name) + return { + max_tokens = 128000, + output_tokens = 16384, + }, nil + end, + } + return compress.to_size("gpt-4o-mini", "content", 100) + end) + + it("handles model not found", function() + compress._models = { + get_by_name = function(model_name) + return nil, "Model not found" + end, + } + return compress.to_size("unknown-model", "content", 100) + end) + end) +end + +return test.run_cases(define_tests) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected test DSL mutable model mocks not to pollute numeric helper, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ModelCardBuilderPreludeDoesNotOptionalizeNumericDefaults(t *testing.T) { + source := ` +local function build_model_card(entry) + local dimensions: number? 
= nil + if type(entry.data) == "table" then + local parsed_dimensions = tonumber(entry.data.dimensions) + if type(parsed_dimensions) == "number" then + dimensions = parsed_dimensions + end + end + + return { + max_tokens = entry.data and entry.data.max_tokens or 0, + output_tokens = entry.data and entry.data.output_tokens or 0, + dimensions = dimensions, + } +end + +local entry: {data: {max_tokens: integer?, output_tokens: integer?, dimensions: any?}?} = { + data = { max_tokens = 1000 }, +} +local model_card = build_model_card(entry) +local max_context_tokens = model_card.max_tokens or 8000 +local max_output_tokens = model_card.output_tokens or 1000 +return max_context_tokens - max_output_tokens - 500 +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected model-card builder prelude not to optionalize numeric defaults, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_DynamicModelCardNumericFieldsRequireProof(t *testing.T) { + source := ` +local function build_model_card(entry) + return { + max_tokens = entry.data and entry.data.max_tokens or 0, + output_tokens = entry.data and entry.data.output_tokens or 0, + } +end + +local entry = { + data = unknown_condition and { + max_tokens = "not numeric", + output_tokens = 1000, + } or { + max_tokens = 8000, + output_tokens = 1000, + }, +} + +local model_card = build_model_card(entry) +local max_context_tokens = model_card.max_tokens or 8000 +local max_output_tokens = model_card.output_tokens or 1000 +return max_context_tokens - max_output_tokens +` + result := testutil.Check(source, testutil.WithStdlib()) + requireExternalLintErrorContaining(t, result, "cannot perform arithmetic") +} + +func TestExternalLint_AnyProviderModelRequiresStringProof(t *testing.T) { + source := ` +local function merge_provider_options(contract_args: {model: string, options: table}, provider_info) + return contract_args +end + +local provider_info = { + 
provider_model = "gpt-4o-mini", +} as any + +local contract_args = { + model = provider_info.provider_model, + options = {}, +} + +return merge_provider_options(contract_args, provider_info) +` + result := testutil.Check(source, testutil.WithStdlib()) + requireExternalLintErrorContaining(t, result, "expected") +} + +func TestExternalLint_DynamicResponseTextRequiresStringProof(t *testing.T) { + source := ` +local function parse_text_tool_call(text: string?, tool_names) + return text +end + +local converse_response = unknown_response +local text_blocks = {} + +for _, block in ipairs(converse_response.output.message.content) do + if block.text then + table.insert(text_blocks, block.text) + end +end + +for _, text in ipairs(text_blocks) do + parse_text_tool_call(text, {}) +end +` + result := testutil.Check(source, testutil.WithStdlib()) + requireExternalLintErrorContaining(t, result, "expected string?") +} + func TestExternalLint_GuardedStringFieldAccumulatorFeedsHelper(t *testing.T) { source := ` local function parse_text_tool_call(text: string?, tool_names) @@ -1822,6 +2421,17 @@ return extract({ output = { message = { content = { { text = "call" } } } } }, { } } +func requireExternalLintErrorContaining(t *testing.T, result *testutil.Result, want string) { + t.Helper() + if !result.HasError() { + t.Fatalf("expected diagnostic containing %q, got no errors", want) + } + messages := strings.Join(testutil.ErrorMessages(result.Diagnostics), " | ") + if !strings.Contains(messages, want) { + t.Fatalf("expected diagnostic containing %q, got: %s", want, messages) + } +} + func TestExternalLint_TypeTableGuardKeepsDynamicFieldReadsOpen(t *testing.T) { source := ` local function run(stats_data) From 7eccaa780726f911d22d088a21c7e481cf2b3894 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 13:32:28 -0400 Subject: [PATCH 28/71] Classify remaining external lint errors --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 56 +++++++- .../external_lint_regression_test.go | 125 
++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 05d61225..c0c76cd6 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5250,5 +5250,59 @@ Updated classification: lint targets fail under the Wippy repo's pinned `github.com/wippyai/go-lua v1.5.16` build, but the go-lua checker tests and Wippy binary build pass. The external counts from that official path are currently `session` 8 errors, - `framework/src/agent/src` 6 errors, and `docker-demo` 21 errors plus + `framework/src/agent/src` 8 errors, and `docker-demo` 21 errors plus 2 warnings. + +## 2026-05-19 Remaining External Error Classification Pass + +The remaining official lint failures were replayed with the exact failing +targets and then reduced against the current go-lua checker. The purpose was to +separate real checker regressions from external source/manifest obligations. + +Additional reductions added in this pass: + +- stdlib `json.decode(response.body or "")` accepts a `string?` body fallback; +- a casted truthiness-guarded field feeds a method argument expecting `string`; +- a casted table-literal field satisfies an annotated record field; +- `#xs > 0` proves both `xs[1]` and `xs[#xs]` access in the reduced sequence + cases; +- an error-return guard narrows the successful value before field access. + +These reductions pass, so the remaining package-level errors are not the +generic transfer laws above. Current classification: + +- `json.decode(response.body or "")` diagnostics in the LLM packages are still + package-boundary issues. The local reductions for stdlib JSON, imported JSON, + and selected HTTP methods pass; the full packages depend on external + `http_client` response manifests and stream surfaces outside go-lua. +- `response.stream:read(4096)` is a native/manifest arity issue, not a checker + flow regression. 
+- `wippy.views:renderer` casted field calls and + `wippy.views.api:list_pages` casted table fields are covered by reductions. + Remaining full-package errors depend on the external page-registry export + shape and should be fixed with stronger manifests or source guards/casts in + the views package, not by weakening go-lua. +- Metadata field errors on `meta`/`metadata` are real source-shape problems: + empty strings are truthy in Lua, so a truthiness guard alone does not prove a + decoded table. +- Dynamic payload and provider diagnostics (`any`/`unknown` passed to string, + number, contract-argument, time, or typed-option APIs) remain true dynamic + boundary errors unless the external package provides a manifest, schema + decoder, guard, or cast. +- Docker/webscout timeout and header diagnostics are source/manifest issues: + `options.timeout = options.timeout or 30` preserves an existing truthy string, + so a sound checker cannot turn that into `number`. + +Verification after adding these reductions: + +```text +go test ./compiler/check/tests/regression -count=1 +go test ./... -count=1 +git diff --check +../scripts/verify-suite.sh +``` + +The go-lua tests and diff check pass. The official verify suite still exits +non-zero only on external lint targets: `session` 8 errors, +`framework/src/agent/src` 8 errors, and `docker-demo` 21 errors plus +2 warnings. 
diff --git a/compiler/check/tests/regression/external_lint_regression_test.go b/compiler/check/tests/regression/external_lint_regression_test.go index 759153b8..20af0be1 100644 --- a/compiler/check/tests/regression/external_lint_regression_test.go +++ b/compiler/check/tests/regression/external_lint_regression_test.go @@ -143,6 +143,23 @@ end } } +func TestExternalLint_StdlibJsonOptionalResponseBodyDefaultIsStringAtCall(t *testing.T) { + source := ` +local json = require("json") + +type Response = { + body: string?, +} + +local response: Response = {} +return json.decode(response.body or "") +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected stdlib json optional body fallback to feed string call argument, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_GuardedOptionsModelSurvivesProviderBranches(t *testing.T) { source := ` local models = { @@ -493,6 +510,114 @@ return content, render_err } } +func TestExternalLint_TruthinessGuardedFieldCastFeedsMethodCall(t *testing.T) { + source := ` +local funcs = {} +function funcs.new() + return { + call = function(self, id: string, context) + return context, nil + end, + } +end + +local function get_page_data(page) + if not page or not page.data_func or page.data_func == "" then + return {}, nil + end + local executor = funcs.new() + return executor:call(page.data_func :: string, {}) +end + +return get_page_data({ data_func = true }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected explicit cast of truthy guarded field to feed method argument checking, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_TableLiteralFieldCastSatisfiesRecordAssignment(t *testing.T) { + source := ` +type PageResponse = { + id: string, + configOverrides: {[string]: any}?, +} + +local page = { + id = "home", + config_overrides = dynamic_config, +} + +local page_info: 
PageResponse = { + id = type(page.id) == "string" and page.id or tostring(page.id), + configOverrides = page.config_overrides :: {[string]: any}?, +} + +return page_info +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected casted table-literal field to satisfy record assignment, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_LengthGuardNarrowsLastElementIndex(t *testing.T) { + source := ` +type YieldResult = { + content: string, +} + +local function latest_content(yield_result_data: {YieldResult}?) + if not yield_result_data or #yield_result_data == 0 then + return nil, "No yield result data found" + end + local latest_yield_result = yield_result_data[#yield_result_data] + return latest_yield_result.content +end + +return latest_content({ { content = "ok" } }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected positive length guard to prove last element index, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestExternalLint_ErrorReturnGuardNarrowsCurrentValue(t *testing.T) { + source := ` +type Content = { + metadata: {[string]: any}?, +} + +local content_repo = {} +function content_repo.get(content_id): (Content?, string?) + if content_id == "" then + return nil, "not found" + end + return { metadata = {} }, nil +end + +local function update_metadata(content_id, metadata) + local current, err = content_repo.get(content_id) + if err then + return nil, "Failed to get current metadata: " .. 
err + end + local current_metadata = current.metadata or {} + for k, v in pairs(metadata) do + current_metadata[k] = v + end + return current_metadata +end + +return update_metadata("id", { kind = "text" }) +` + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected error-return guard to narrow current value before field access, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + func TestExternalLint_InsertedSuiteShapeSurvivesIpairs(t *testing.T) { source := ` type Suite = { From ebb3639e51b89e83c39d6786273278aee79ae9bd Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 13:46:22 -0400 Subject: [PATCH 29/71] Add advanced type system stress regressions --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 47 ++ .../advanced_type_system_stress_test.go | 429 ++++++++++++++++++ .../advanced-type-system-stress/events.lua | 93 ++++ .../advanced-type-system-stress/main.lua | 50 ++ .../advanced-type-system-stress/manifest.json | 1 + .../advanced-type-system-stress/pipeline.lua | 32 ++ .../request_builder.lua | 78 ++++ .../advanced-type-system-stress/sessions.lua | 45 ++ 8 files changed, 775 insertions(+) create mode 100644 compiler/check/tests/regression/advanced_type_system_stress_test.go create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/events.lua create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/main.lua create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/manifest.json create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/pipeline.lua create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/request_builder.lua create mode 100644 testdata/fixtures/realworld/advanced-type-system-stress/sessions.lua diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index c0c76cd6..2106110a 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5306,3 
+5306,50 @@ The go-lua tests and diff check pass. The official verify suite still exits non-zero only on external lint targets: `session` 8 errors, `framework/src/agent/src` 8 errors, and `docker-demo` 21 errors plus 2 warnings. + +## 2026-05-19 Advanced Type-System Stress Regressions + +Added a focused regression suite and a real-world fixture whose purpose is to +stress the current abstract-interpreter model without weakening soundness. + +The Go regression suite covers: + +- dynamic decode into a discriminated `Event` union after explicit `type(...)` + guards, followed by variant-specific field access; +- `(value?, err?)` multi-return correlation through higher-order callbacks; +- fluent builder state preservation through explicit self-typed methods; +- manifest/module export of tagged results and callback parameter shapes; +- generic `Result` combinators that preserve payload type parameters across + `map`, `and_then`, nested callbacks, and discriminant narrowing; +- nested config builders with typed arrays and string maps; +- negative soundness cases where truthy string fallbacks must not become + numbers, and a truthiness guard over `string | record` must not prove record + field access because Lua strings are truthy. + +Added fixture: + +- `testdata/fixtures/realworld/advanced-type-system-stress` + +The fixture runs the same laws through the repository fixture harness with +separate modules for event decoding, session creation, a metatable-style +request builder, and pipeline config. The entrypoint validates cross-module +manifest exports and includes inline `expect-error` checks for the two +soundness boundaries. + +One attempted fixture assertion was intentionally tightened: assigning +`first.config.level` directly to `string` from a `{[string]: any}` config is not +sound. The fixture now proves the local value with `type(level) == "string"` +before claiming it. 
This is the right model boundary: the checker should infer +what is proven by control flow and manifests, not invent structure out of +dynamic `any`. + +Verification: + +```text +go test ./compiler/check/tests/regression -run 'TestAdvancedTypeSystem' -count=1 -v +go test . -run 'TestFixtures/realworld/advanced-type-system-stress/check' -count=1 -v +go test ./... -count=1 +git diff --check +``` + +All checks pass. diff --git a/compiler/check/tests/regression/advanced_type_system_stress_test.go b/compiler/check/tests/regression/advanced_type_system_stress_test.go new file mode 100644 index 00000000..9d87edcf --- /dev/null +++ b/compiler/check/tests/regression/advanced_type_system_stress_test.go @@ -0,0 +1,429 @@ +package regression + +import ( + "strings" + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestAdvancedTypeSystem_DiscriminatedEventPipelineWithDynamicDecode(t *testing.T) { + source := ` +type MessageEvent = {kind: "message", id: string, text: string, tags: {string}?} +type ToolEvent = {kind: "tool", id: string, name: string, arguments: {[string]: any}} +type ErrorEvent = {kind: "error", id: string, error: {code: string, message: string}} +type Event = MessageEvent | ToolEvent | ErrorEvent + +local function require_string(value, fallback: string): string + if type(value) == "string" then + return value + end + return fallback +end + +local function string_array(value): {string}? + if type(value) ~= "table" then + return nil + end + local out: {string} = {} + for _, item in ipairs(value) do + if type(item) == "string" then + table.insert(out, item) + end + end + return out +end + +local function decode_event(raw: any): (Event?, string?) 
+ if type(raw) ~= "table" then + return nil, "event must be a table" + end + + if raw.kind == "message" then + return { + kind = "message", + id = require_string(raw.id, ""), + text = require_string(raw.text, ""), + tags = string_array(raw.tags), + }, nil + end + + if raw.kind == "tool" then + return { + kind = "tool", + id = require_string(raw.id, ""), + name = require_string(raw.name, ""), + arguments = type(raw.arguments) == "table" and (raw.arguments :: {[string]: any}) or {}, + }, nil + end + + if raw.kind == "error" then + return { + kind = "error", + id = require_string(raw.id, ""), + error = { + code = require_string(raw.code, "unknown"), + message = require_string(raw.message, "failed"), + }, + }, nil + end + + return nil, "unknown event" +end + +local function render_event(event: Event): string + if event.kind == "message" then + return event.id .. ":" .. event.text + end + if event.kind == "tool" then + return event.id .. ":" .. event.name + end + return event.id .. ":" .. event.error.code .. ":" .. 
event.error.message +end + +local function render_all(raw_events: {any}): ({string}, {string}) + local rendered: {string} = {} + local errors: {string} = {} + for _, raw in ipairs(raw_events) do + local event, err = decode_event(raw) + if event then + table.insert(rendered, render_event(event)) + else + table.insert(errors, err or "unknown") + end + end + return rendered, errors +end + +local rendered, errors = render_all({ + { kind = "message", id = "m1", text = "hello" }, + { kind = "tool", id = "t1", name = "search", arguments = { query = "lua" } }, + { kind = "error", id = "e1", code = "E", message = "boom" }, +}) + +return rendered[1] or errors[1] or "" +` + assertNoAdvancedStressErrors(t, source) +} + +func TestAdvancedTypeSystem_ResultPipelineKeepsMultiReturnCorrelation(t *testing.T) { + source := ` +type Err = {kind: string, message: string} +type User = {id: string, name: string, roles: {string}} +type Session = {id: string, user: User, expires_at: number} + +local users: {[string]: User} = { + ["u1"] = { id = "u1", name = "Ada", roles = ({ "admin" } :: {string}) }, +} + +local function find_user(id: string): (User?, Err?) + local user = users[id] + if not user then + return nil, { kind = "not_found", message = id } + end + return user, nil +end + +local function create_session(user: User, now: number): (Session?, Err?) + if #user.roles == 0 then + return nil, { kind = "forbidden", message = user.id } + end + return { id = user.id .. ":" .. tostring(now), user = user, expires_at = now + 3600 }, nil +end + +local function with_user(id: string, now: number, fn: (User, number) -> (Session?, Err?)): (Session?, Err?) + local user, err = find_user(id) + if err then + return nil, err + end + return fn(user, now) +end + +local session, err = with_user("u1", 10, create_session) +if err then + return err.message +end +return session.user.name .. ":" .. 
tostring(session.expires_at) +` + assertNoAdvancedStressErrors(t, source) +} + +func TestAdvancedTypeSystem_FluentMetatableBuilderPreservesStateAcrossMethods(t *testing.T) { + source := ` +type Request = { + method: string, + path: string, + headers: {[string]: string}, + query: {[string]: string}, + timeout: number, +} + +local Builder = {} +Builder.__index = Builder + +function Builder.new() + return setmetatable({ + method = "GET", + path = "/", + headers = {} :: {[string]: string}, + query = {} :: {[string]: string}, + timeout = 30, + }, Builder) +end + +function Builder.with_method(self: Request, method: string): Request + self.method = method + return self +end + +function Builder.with_header(self: Request, key: string, value: string): Request + self.headers[key] = value + return self +end + +function Builder.with_query(self: Request, key: string, value: string?): Request + if value then + self.query[key] = value + end + return self +end + +function Builder.with_timeout(self: Request, timeout: number?): Request + self.timeout = timeout or self.timeout + return self +end + +function Builder.build(self: Request): Request + return { + method = self.method, + path = self.path, + headers = self.headers, + query = self.query, + timeout = self.timeout, + } +end + +local req = Builder.build( + Builder.with_timeout( + Builder.with_query( + Builder.with_header( + Builder.with_method(Builder.new() :: Request, "POST"), + "Accept", + "application/json" + ), + "q", + "lua" + ), + nil + ) +) + +return req.method .. ":" .. req.headers.Accept .. ":" .. 
tostring(req.timeout) +` + assertNoAdvancedStressErrors(t, source) +} + +func TestAdvancedTypeSystem_ModuleBoundaryPreservesTaggedResultAndCallbacks(t *testing.T) { + repoModule := testutil.CheckAndExport(` +local repo = {} + +type Row = {id: string, payload: string, metadata: {[string]: any}?} +type Found = {ok: true, row: Row} +type Missing = {ok: false, error: {kind: "missing", message: string}} +type Result = Found | Missing +type Mapper = (Row) -> string + +local rows = { + ["a"] = { id = "a", payload = "hello", metadata = { source = "test" } }, +} + +function repo.get(id: string): Result + local row = rows[id] + if row then + return { ok = true, row = row } + end + return { ok = false, error = { kind = "missing", message = id } } +end + +function repo.map(id: string, mapper: Mapper): (string?, string?) + local result = repo.get(id) + if result.ok then + return mapper(result.row), nil + end + return nil, result.error.message +end + +return repo +`, "repo", testutil.WithStdlib()) + if repoModule.HasError() { + t.Fatalf("repo module errors: %v", testutil.ErrorMessages(repoModule.Errors)) + } + + source := ` +local repo = require("repo") + +local value, err = repo.map("a", function(row) + local source = row.metadata and row.metadata.source or "none" + return row.id .. ":" .. row.payload .. ":" .. 
tostring(source) +end) + +if err then + return err +end +return value +` + result := testutil.Check(source, testutil.WithStdlib(), testutil.WithModule("repo", repoModule)) + if result.HasError() { + t.Fatalf("expected module boundary to preserve tagged result and callback parameter shape, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func TestAdvancedTypeSystem_GenericResultCombinatorsPreserveDiscriminantsAndPayloads(t *testing.T) { + source := ` +type Failure = {code: string, message: string} +type Result = {ok: true, value: T} | {ok: false, error: Failure} +type Envelope = {id: string, attrs: {[string]: string}, nested: {attempts: number}} +type View = {label: string, attempts: number} + +local function ok(value: T): Result + return { ok = true, value = value } +end + +local function fail(code: string, message: string): Result + return { ok = false, error = { code = code, message = message } } +end + +local function map(result: Result, fn: (T) -> U): Result + if result.ok then + return ok(fn(result.value)) + end + return { ok = false, error = result.error } +end + +local function and_then(result: Result, fn: (T) -> Result): Result + if result.ok then + return fn(result.value) + end + return { ok = false, error = result.error } +end + +local function decode(raw: any): Result + if type(raw) ~= "table" then + return fail("shape", "not a table") + end + if type(raw.id) ~= "string" then + return fail("shape", "missing id") + end + return ok({ + id = raw.id, + attrs = type(raw.attrs) == "table" and (raw.attrs :: {[string]: string}) or {}, + nested = { attempts = type(raw.attempts) == "number" and raw.attempts or 0 }, + }) +end + +local view = and_then(decode({ + id = "evt", + attrs = { source = "test" }, + attempts = 2, +}), function(env: Envelope): Result + return map(ok(env), function(inner: Envelope): View + return { + label = inner.id .. ":" .. 
inner.attrs.source, + attempts = inner.nested.attempts + 1, + } + end) +end) + +if view.ok then + local label: string = view.value.label + local attempts: number = view.value.attempts + return label .. ":" .. tostring(attempts) +end +return view.error.code .. ":" .. view.error.message +` + assertNoAdvancedStressErrors(t, source) +} + +func TestAdvancedTypeSystem_NestedConfigBuilderKeepsPreciseMapAndArrayShapes(t *testing.T) { + source := ` +type Plugin = {id: string, enabled: boolean, config: {[string]: any}} +type Pipeline = {name: string, plugins: {Plugin}, env: {[string]: string}} + +local function add_plugin(pipeline: Pipeline, plugin: Plugin): Pipeline + table.insert(pipeline.plugins, plugin) + return pipeline +end + +local function enable_defaults(pipeline: Pipeline, defaults: {[string]: string}?): Pipeline + for key, value in pairs(defaults or {}) do + pipeline.env[key] = value + end + return add_plugin(pipeline, { + id = "logger", + enabled = true, + config = { level = pipeline.env.LOG_LEVEL or "info" }, + }) +end + +local pipeline = enable_defaults({ + name = "deploy", + plugins = {}, + env = { LOG_LEVEL = "debug" }, +}, { REGION = "local" }) + +local first = pipeline.plugins[1] +if not first then + return pipeline.env.REGION +end + +return first.id .. ":" .. tostring(first.config.level) .. ":" .. 
pipeline.env.REGION +` + assertNoAdvancedStressErrors(t, source) +} + +func TestAdvancedTypeSystem_SoundnessRejectsTruthyStringFallbackToNumber(t *testing.T) { + source := ` +local function expect_number(value: number) + return value + 1 +end + +local options: {timeout: string?} = { timeout = "30s" } +local timeout = options.timeout or 30 +return expect_number(timeout) +` + assertAdvancedStressErrorContains(t, source, "expected number") +} + +func TestAdvancedTypeSystem_SoundnessRejectsMetadataFieldAfterTruthyString(t *testing.T) { + source := ` +local meta: string | {content_type: string} = "" +local artifact = { meta = meta } + +if artifact.meta then + local content_type: string = artifact.meta.content_type + return content_type +end +return "missing" +` + assertAdvancedStressErrorContains(t, source, "cannot assign") +} + +func assertNoAdvancedStressErrors(t *testing.T, source string) { + t.Helper() + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected advanced type-system stress case to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func assertAdvancedStressErrorContains(t *testing.T, source, want string) { + t.Helper() + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatalf("expected diagnostic containing %q, got no errors", want) + } + messages := strings.Join(testutil.ErrorMessages(result.Diagnostics), " | ") + if !strings.Contains(messages, want) { + t.Fatalf("expected diagnostic containing %q, got: %s", want, messages) + } +} diff --git a/testdata/fixtures/realworld/advanced-type-system-stress/events.lua b/testdata/fixtures/realworld/advanced-type-system-stress/events.lua new file mode 100644 index 00000000..f9698887 --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/events.lua @@ -0,0 +1,93 @@ +type MessageEvent = {kind: "message", id: string, text: string, tags: {string}?} +type ToolEvent = {kind: "tool", id: 
string, name: string, arguments: {[string]: any}} +type ErrorEvent = {kind: "error", id: string, error: {code: string, message: string}} +type Event = MessageEvent | ToolEvent | ErrorEvent + +local M = {} +M.MessageEvent = MessageEvent +M.ToolEvent = ToolEvent +M.ErrorEvent = ErrorEvent +M.Event = Event + +local function require_string(value, fallback: string): string + if type(value) == "string" then + return value + end + return fallback +end + +local function string_array(value): {string}? + if type(value) ~= "table" then + return nil + end + local out: {string} = {} + for _, item in ipairs(value) do + if type(item) == "string" then + table.insert(out, item) + end + end + return out +end + +function M.decode(raw: any): (Event?, string?) + if type(raw) ~= "table" then + return nil, "event must be a table" + end + + if raw.kind == "message" then + return { + kind = "message", + id = require_string(raw.id, ""), + text = require_string(raw.text, ""), + tags = string_array(raw.tags), + }, nil + end + + if raw.kind == "tool" then + return { + kind = "tool", + id = require_string(raw.id, ""), + name = require_string(raw.name, ""), + arguments = type(raw.arguments) == "table" and (raw.arguments :: {[string]: any}) or {}, + }, nil + end + + if raw.kind == "error" then + return { + kind = "error", + id = require_string(raw.id, ""), + error = { + code = require_string(raw.code, "unknown"), + message = require_string(raw.message, "failed"), + }, + }, nil + end + + return nil, "unknown event" +end + +function M.render(event: Event): string + if event.kind == "message" then + return event.id .. ":" .. event.text + end + if event.kind == "tool" then + return event.id .. ":" .. event.name + end + return event.id .. ":" .. event.error.code .. ":" .. 
event.error.message +end + +function M.collect(raw_events: {any}): ({string}, {string}) + local rendered: {string} = {} + local errors: {string} = {} + for _, raw in ipairs(raw_events) do + local event, err = M.decode(raw) + if event then + table.insert(rendered, M.render(event)) + else + table.insert(errors, err or "unknown") + end + end + return rendered, errors +end + +return M diff --git a/testdata/fixtures/realworld/advanced-type-system-stress/main.lua b/testdata/fixtures/realworld/advanced-type-system-stress/main.lua new file mode 100644 index 00000000..fbb41416 --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/main.lua @@ -0,0 +1,50 @@ +local events = require("events") +local sessions = require("sessions") +local request_builder = require("request_builder") +local pipeline = require("pipeline") + +local rendered, event_errors = events.collect({ + {kind = "message", id = "m1", text = "hello", tags = {"ui", "chat"}}, + {kind = "tool", id = "t1", name = "search", arguments = {query = "lua"}}, + {kind = "error", id = "e1", code = "E", message = "boom"}, +}) + +local session_text, session_err = sessions.describe("u1", 10) +if session_err then + return session_err.message +end + +local request = request_builder.new() + :with_method("POST") + :with_header("Accept", "application/json") + :with_query("q", rendered[1]) + :with_timeout(nil) + :build() + +local flow = pipeline.enable_defaults(pipeline.new("deploy"), { + REGION = "local", + LOG_LEVEL = "debug", +}) + +local first = flow.plugins[1] +if first then + local plugin_id: string = first.id + local level = first.config.level + if type(level) == "string" then + local plugin_level: string = level + end +end + +local summary: string = rendered[1] .. ":" .. session_text .. ":" .. request.headers.Accept .. ":" .. 
flow.env.REGION + +local options: {timeout: string?} = {timeout = "30s"} +local timeout = options.timeout or 30 +local bad_timeout: number = timeout -- expect-error + +local meta: string | {content_type: string} = "" +local artifact = {meta = meta} +if artifact.meta then + local content_type: string = artifact.meta.content_type -- expect-error +end + +return summary .. ":" .. tostring(event_errors[1]) diff --git a/testdata/fixtures/realworld/advanced-type-system-stress/manifest.json b/testdata/fixtures/realworld/advanced-type-system-stress/manifest.json new file mode 100644 index 00000000..b0bf6baf --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/manifest.json @@ -0,0 +1 @@ +{"files": ["events.lua", "sessions.lua", "request_builder.lua", "pipeline.lua", "main.lua"]} diff --git a/testdata/fixtures/realworld/advanced-type-system-stress/pipeline.lua b/testdata/fixtures/realworld/advanced-type-system-stress/pipeline.lua new file mode 100644 index 00000000..24b47205 --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/pipeline.lua @@ -0,0 +1,32 @@ +type Plugin = {id: string, enabled: boolean, config: {[string]: any}} +type Pipeline = {name: string, plugins: {Plugin}, env: {[string]: string}} + +local M = {} +M.Plugin = Plugin +M.Pipeline = Pipeline + +function M.new(name: string): Pipeline + return { + name = name, + plugins = {}, + env = {}, + } +end + +function M.add_plugin(pipeline: Pipeline, plugin: Plugin): Pipeline + table.insert(pipeline.plugins, plugin) + return pipeline +end + +function M.enable_defaults(pipeline: Pipeline, defaults: {[string]: string}?): Pipeline + for key, value in pairs(defaults or {}) do + pipeline.env[key] = value + end + return M.add_plugin(pipeline, { + id = "logger", + enabled = true, + config = {level = pipeline.env.LOG_LEVEL or "info"}, + }) +end + +return M diff --git a/testdata/fixtures/realworld/advanced-type-system-stress/request_builder.lua 
b/testdata/fixtures/realworld/advanced-type-system-stress/request_builder.lua new file mode 100644 index 00000000..6641db65 --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/request_builder.lua @@ -0,0 +1,78 @@ +type Request = { + method: string, + path: string, + headers: {[string]: string}, + query: {[string]: string}, + timeout: number, +} + +type Builder = { + method: string, + path: string, + headers: {[string]: string}, + query: {[string]: string}, + timeout: number, + with_method: (self: Builder, method: string) -> Builder, + with_header: (self: Builder, key: string, value: string) -> Builder, + with_query: (self: Builder, key: string, value: string?) -> Builder, + with_timeout: (self: Builder, timeout: number?) -> Builder, + build: (self: Builder) -> Request, +} + +local Builder = {} +Builder.__index = Builder + +function Builder:with_method(method: string): Builder + self.method = method + return self +end + +function Builder:with_header(key: string, value: string): Builder + self.headers[key] = value + return self +end + +function Builder:with_query(key: string, value: string?): Builder + if value then + self.query[key] = value + end + return self +end + +function Builder:with_timeout(timeout: number?): Builder + self.timeout = timeout or self.timeout + return self +end + +function Builder:build(): Request + return { + method = self.method, + path = self.path, + headers = self.headers, + query = self.query, + timeout = self.timeout, + } +end + +local M = {} +M.Request = Request +M.Builder = Builder + +function M.new(): Builder + local builder: Builder = { + method = "GET", + path = "/", + headers = {} :: {[string]: string}, + query = {} :: {[string]: string}, + timeout = 30, + with_method = Builder.with_method, + with_header = Builder.with_header, + with_query = Builder.with_query, + with_timeout = Builder.with_timeout, + build = Builder.build, + } + setmetatable(builder, Builder) + return builder +end + +return M diff --git 
a/testdata/fixtures/realworld/advanced-type-system-stress/sessions.lua b/testdata/fixtures/realworld/advanced-type-system-stress/sessions.lua new file mode 100644 index 00000000..44acd7a8 --- /dev/null +++ b/testdata/fixtures/realworld/advanced-type-system-stress/sessions.lua @@ -0,0 +1,45 @@ +type Err = {kind: string, message: string} +type User = {id: string, name: string, roles: {string}} +type Session = {id: string, user: User, expires_at: number} + +local users: {[string]: User} = { + ["u1"] = {id = "u1", name = "Ada", roles = ({"admin"} :: {string})}, +} + +local M = {} +M.Err = Err +M.User = User +M.Session = Session + +function M.find_user(id: string): (User?, Err?) + local user = users[id] + if not user then + return nil, {kind = "not_found", message = id} + end + return user, nil +end + +function M.create_session(user: User, now: number): (Session?, Err?) + if #user.roles == 0 then + return nil, {kind = "forbidden", message = user.id} + end + return {id = user.id .. ":" .. tostring(now), user = user, expires_at = now + 3600}, nil +end + +function M.with_user(id: string, now: number, fn: (User, number) -> (Session?, Err?)): (Session?, Err?) + local user, err = M.find_user(id) + if err then + return nil, err + end + return fn(user, now) +end + +function M.describe(id: string, now: number): (string?, Err?) + local session, err = M.with_user(id, now, M.create_session) + if err then + return nil, err + end + return session.user.name .. ":" .. 
tostring(session.expires_at), nil +end + +return M From 2c679fb7b16356b3ff4b42b55d0fe5badd64136b Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 14:01:28 -0400 Subject: [PATCH 30/71] Add adversarial gradual typing regressions --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 48 +++ .../gradual_type_system_adversarial_test.go | 327 ++++++++++++++++++ .../gradual-typing-adversarial/main.lua | 117 +++++++ .../gradual-typing-adversarial/manifest.json | 3 + 4 files changed, 495 insertions(+) create mode 100644 compiler/check/tests/regression/gradual_type_system_adversarial_test.go create mode 100644 testdata/fixtures/regression/gradual-typing-adversarial/main.lua create mode 100644 testdata/fixtures/regression/gradual-typing-adversarial/manifest.json diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 2106110a..8d1989ce 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5353,3 +5353,51 @@ git diff --check ``` All checks pass. + +## 2026-05-19 Adversarial Gradual-Typing Regressions + +Added a dedicated gradual-typing regression suite and fixture. The goal is to +prove the checker is permissive where the program supplies evidence, while +remaining strict at dynamic boundaries where the evidence is incomplete. 
+ +Added Go tests: + +- `TestGradualTyping_DecodesDynamicPayloadAfterStructuralProof` +- `TestGradualTyping_DispatchesGuardedUnionThroughTypedRegistry` +- `TestGradualTyping_GenericValidatedCollectionPreservesElementType` +- `TestGradualTyping_ExplicitBoundaryCastProvidesPreciseLocalType` +- `TestGradualTyping_RejectsUncheckedAnyRecordAssignment` +- `TestGradualTyping_RejectsTruthyGuardAsStructuralProof` +- `TestGradualTyping_RejectsPartiallyCheckedCollectionAsTypedArray` +- `TestGradualTyping_RejectsDynamicCallbackAtTypedFunctionBoundary` +- `TestGradualTyping_RejectsExtraFieldsAfterNarrowBoundaryCast` + +The positive cases cover dynamic payload decoding, discriminated command +dispatch through typed registries, generic validation/traversal over `{any}`, +and explicit boundary casts that produce a precise local type. The negative +cases pin the soundness laws: `any` cannot be assigned to a precise record +without proof, truthiness is not structural evidence, checking one array element +does not prove the whole array, dynamic callbacks cannot satisfy typed callback +contracts, and a narrowed cast type does not leak extra dynamic fields. + +Added fixture: + +- `testdata/fixtures/regression/gradual-typing-adversarial` + +The fixture exercises the same model through normal fixture checking and inline +`expect-error` comments. One fixture detail is intentional: generic `ok({})` +needs a typed empty-table cast (`{} :: {string}` or +`{} :: {[string]: string}`) so the empty table does not instantiate the +validation result as an unshaped table. This keeps inference strong without +guessing structure that is not present in the literal. + +Verification: + +```text +go test ./compiler/check/tests/regression -run 'TestGradualTyping' -count=1 -v +go test . -run 'TestFixtures/regression/gradual-typing-adversarial/check' -count=1 -v +go test ./... -count=1 +git diff --check +``` + +All checks pass. 
diff --git a/compiler/check/tests/regression/gradual_type_system_adversarial_test.go b/compiler/check/tests/regression/gradual_type_system_adversarial_test.go new file mode 100644 index 00000000..901b270f --- /dev/null +++ b/compiler/check/tests/regression/gradual_type_system_adversarial_test.go @@ -0,0 +1,327 @@ +package regression + +import ( + "strings" + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" +) + +func TestGradualTyping_DecodesDynamicPayloadAfterStructuralProof(t *testing.T) { + source := ` +type Address = {city: string, zip: number?} +type User = {id: string, name: string, active: boolean, tags: {string}, address: Address?} +type DecodeError = {kind: string, message: string} + +local function err(kind: string, message: string): DecodeError + return { kind = kind, message = message } +end + +local function read_string(value, field: string): (string?, DecodeError?) + if type(value) == "string" then + return value, nil + end + return nil, err("shape", field) +end + +local function read_boolean(value, field: string): (boolean?, DecodeError?) + if type(value) == "boolean" then + return value, nil + end + return nil, err("shape", field) +end + +local function read_tags(value): ({string}?, DecodeError?) + if value == nil then + return {}, nil + end + if type(value) ~= "table" then + return nil, err("shape", "tags") + end + local tags: {string} = {} + for _, item in ipairs(value) do + if type(item) ~= "string" then + return nil, err("shape", "tag") + end + table.insert(tags, item) + end + return tags, nil +end + +local function read_address(value): (Address?, DecodeError?) + if value == nil then + return nil, nil + end + if type(value) ~= "table" then + return nil, err("shape", "address") + end + local city, city_err = read_string(value.city, "city") + if city_err then + return nil, city_err + end + local zip: number? 
= nil + if type(value.zip) == "number" then + zip = value.zip + end + return { city = city, zip = zip }, nil +end + +local function decode_user(raw: any): (User?, DecodeError?) + if type(raw) ~= "table" then + return nil, err("shape", "root") + end + + local id, id_err = read_string(raw.id, "id") + if id_err then + return nil, id_err + end + local name, name_err = read_string(raw.name, "name") + if name_err then + return nil, name_err + end + local active, active_err = read_boolean(raw.active, "active") + if active_err then + return nil, active_err + end + local tags, tags_err = read_tags(raw.tags) + if tags_err then + return nil, tags_err + end + local address, address_err = read_address(raw.address) + if address_err then + return nil, address_err + end + + return { id = id, name = name, active = active, tags = tags, address = address }, nil +end + +local user, decode_err = decode_user({ + id = "u1", + name = "Ada", + active = true, + tags = { "admin", "founder" }, + address = { city = "London", zip = 12345 }, +}) + +if decode_err then + return decode_err.message +end +if user.address then + return user.name .. ":" .. user.address.city .. ":" .. user.tags[1] +end +return user.id +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_DispatchesGuardedUnionThroughTypedRegistry(t *testing.T) { + source := ` +type EmailCommand = {kind: "email", to: string, body: string} +type SmsCommand = {kind: "sms", phone: string, text: string} +type Command = EmailCommand | SmsCommand +type DispatchResult = {ok: true, id: string} | {ok: false, error: string} +type Handler = (Command) -> DispatchResult + +local handlers: {[string]: Handler} = {} + +handlers.email = function(command: Command): DispatchResult + if command.kind == "email" then + return { ok = true, id = command.to .. ":" .. 
command.body } + end + return { ok = false, error = "wrong handler" } +end + +handlers.sms = function(command: Command): DispatchResult + if command.kind == "sms" then + return { ok = true, id = command.phone .. ":" .. command.text } + end + return { ok = false, error = "wrong handler" } +end + +local function decode(raw: any): Command? + if type(raw) ~= "table" then + return nil + end + if raw.kind == "email" and type(raw.to) == "string" and type(raw.body) == "string" then + return { kind = "email", to = raw.to, body = raw.body } + end + if raw.kind == "sms" and type(raw.phone) == "string" and type(raw.text) == "string" then + return { kind = "sms", phone = raw.phone, text = raw.text } + end + return nil +end + +local function dispatch(raw: any): string + local command = decode(raw) + if not command then + return "bad" + end + local handler = handlers[command.kind] + if not handler then + return "missing" + end + local result = handler(command) + if result.ok then + return result.id + end + return result.error +end + +return dispatch({ kind = "email", to = "ops@example.com", body = "ready" }) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_GenericValidatedCollectionPreservesElementType(t *testing.T) { + source := ` +type Validation = {ok: true, value: T} | {ok: false, error: string} +type Point = {x: number, y: number} + +local function ok(value: T): Validation + return { ok = true, value = value } +end + +local function invalid(message: string): Validation + return { ok = false, error = message } +end + +local function traverse(items: {T}, fn: (T) -> Validation): Validation<{U}> + local out: {U} = {} + for _, item in ipairs(items) do + local next = fn(item) + if not next.ok then + return invalid(next.error) + end + table.insert(out, next.value) + end + return ok(out) +end + +local function parse_point(raw: any): Validation + if type(raw) ~= "table" then + return invalid("point") + end + if type(raw.x) == "number" and type(raw.y) == 
"number" then + return ok({ x = raw.x, y = raw.y }) + end + return invalid("coords") +end + +local parsed = traverse(({ { x = 1, y = 2 }, { x = 3, y = 4 } } :: {any}), parse_point) +if parsed.ok then + local first = parsed.value[1] + if first then + local total: number = first.x + first.y + return tostring(total) + end + return "empty" +end +return parsed.error +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_ExplicitBoundaryCastProvidesPreciseLocalType(t *testing.T) { + source := ` +type Metric = {name: string, count: number, tags: {[string]: string}} + +local raw: any = { + name = "requests", + count = 10, + tags = { route = "/v1" }, +} + +local metric = raw :: Metric +local next_count: number = metric.count + 1 +local route = metric.tags.route +if not route then + return "missing" +end +local route_name: string = route +return metric.name .. ":" .. route_name .. ":" .. tostring(next_count) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_RejectsUncheckedAnyRecordAssignment(t *testing.T) { + source := ` +type User = {id: string, name: string} + +local raw: any = { id = "u1", name = "Ada" } +local user: User = raw +return user.id +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + +func TestGradualTyping_RejectsTruthyGuardAsStructuralProof(t *testing.T) { + source := ` +local raw: any = { profile = "not a table" } + +if raw.profile then + local city: string = raw.profile.city + return city +end +return "missing" +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + +func TestGradualTyping_RejectsPartiallyCheckedCollectionAsTypedArray(t *testing.T) { + source := ` +local raw: any = { items = { "safe", 42 } } + +if type(raw.items) == "table" and type(raw.items[1]) == "string" then + local items: {string} = raw.items + return items[1] +end +return "missing" +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + +func 
TestGradualTyping_RejectsDynamicCallbackAtTypedFunctionBoundary(t *testing.T) { + source := ` +type User = {id: string} + +local callback: any = function(user) + return 42 +end + +local typed: (User) -> string = callback +return typed({ id = "u1" }) +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + +func TestGradualTyping_RejectsExtraFieldsAfterNarrowBoundaryCast(t *testing.T) { + source := ` +type Metric = {name: string, count: number} + +local raw: any = { name = "requests", count = 10, extra = true } +local metric = raw :: Metric +local extra: boolean = metric.extra +return tostring(extra) +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + +func assertNoGradualTypingErrors(t *testing.T, source string) { + t.Helper() + result := testutil.Check(source, testutil.WithStdlib()) + if result.HasError() { + t.Fatalf("expected gradual-typing adversarial case to type-check, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } +} + +func assertGradualTypingErrorContains(t *testing.T, source, want string) { + t.Helper() + result := testutil.Check(source, testutil.WithStdlib()) + if !result.HasError() { + t.Fatalf("expected diagnostic containing %q, got no errors", want) + } + messages := strings.Join(testutil.ErrorMessages(result.Diagnostics), " | ") + if !strings.Contains(messages, want) { + t.Fatalf("expected diagnostic containing %q, got: %s", want, messages) + } +} diff --git a/testdata/fixtures/regression/gradual-typing-adversarial/main.lua b/testdata/fixtures/regression/gradual-typing-adversarial/main.lua new file mode 100644 index 00000000..1ff6837f --- /dev/null +++ b/testdata/fixtures/regression/gradual-typing-adversarial/main.lua @@ -0,0 +1,117 @@ +type Config = { + id: string, + retries: number, + labels: {string}, + metadata: {[string]: string}, +} + +type Validation = {ok: true, value: T} | {ok: false, error: string} + +local function ok(value: T): Validation + return {ok = true, value = value} +end + +local 
function invalid(message: string): Validation + return {ok = false, error = message} +end + +local function read_labels(value): Validation<{string}> + if value == nil then + return ok({} :: {string}) + end + if type(value) ~= "table" then + return invalid("labels") + end + local labels: {string} = {} + for _, item in ipairs(value) do + if type(item) ~= "string" then + return invalid("label") + end + table.insert(labels, item) + end + return ok(labels) +end + +local function read_metadata(value): Validation<{[string]: string}> + if value == nil then + return ok({} :: {[string]: string}) + end + if type(value) ~= "table" then + return invalid("metadata") + end + local metadata: {[string]: string} = {} + for key, item in pairs(value) do + if type(key) == "string" and type(item) == "string" then + metadata[key] = item + end + end + return ok(metadata) +end + +local function decode_config(raw: any): Validation + if type(raw) ~= "table" then + return invalid("root") + end + if type(raw.id) ~= "string" then + return invalid("id") + end + if type(raw.retries) ~= "number" then + return invalid("retries") + end + + local labels = read_labels(raw.labels) + if not labels.ok then + return invalid(labels.error) + end + + local metadata = read_metadata(raw.metadata) + if not metadata.ok then + return invalid(metadata.error) + end + + return ok({ + id = raw.id, + retries = raw.retries, + labels = labels.value, + metadata = metadata.value, + }) +end + +local decoded = decode_config({ + id = "worker", + retries = 3, + labels = {"critical", "api"}, + metadata = {owner = "ops"}, +}) + +if decoded.ok then + local config: Config = decoded.value + local first = config.labels[1] + if first then + local label: string = first + end + local owner = config.metadata.owner + if owner then + local owner_name: string = owner + end +end + +local raw_config: any = {id = "worker", retries = 3} +local unchecked_config: Config = raw_config -- expect-error + +if raw_config.id then + local id: string = 
raw_config.id -- expect-error +end + +local raw_items: any = {items = {"ok", 99}} +if type(raw_items.items) == "table" and type(raw_items.items[1]) == "string" then + local labels: {string} = raw_items.items -- expect-error +end + +local callback: any = function(config) + return 1 +end + +local typed_callback: (Config) -> string = callback -- expect-error + +return "ok" diff --git a/testdata/fixtures/regression/gradual-typing-adversarial/manifest.json b/testdata/fixtures/regression/gradual-typing-adversarial/manifest.json new file mode 100644 index 00000000..b8fb3728 --- /dev/null +++ b/testdata/fixtures/regression/gradual-typing-adversarial/manifest.json @@ -0,0 +1,3 @@ +{ + "description": "Adversarial gradual typing fixture: dynamic data becomes precise only after structural proof or explicit cast, while incomplete proofs remain rejected." +} From 92a8d7b97422d629a945483a58e7b0d0a0d2740d Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 14:12:36 -0400 Subject: [PATCH 31/71] Add loop-carried gradual refinement regressions --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 40 ++++ .../gradual_type_system_adversarial_test.go | 201 ++++++++++++++++++ .../gradual-typing-adversarial/main.lua | 96 +++++++++ 3 files changed, 337 insertions(+) diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 8d1989ce..39038dcb 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5401,3 +5401,43 @@ git diff --check ``` All checks pass. + +## 2026-05-19 Loop-Carried Gradual Refinement Regressions + +Extended the adversarial gradual-typing coverage with loop-shaped programs +where precision is earned over several steps and then carried through typed +accumulators or loop state. + +Added Go cases: + +- `TestGradualTyping_LoopRefinesDynamicRecordsIntoTypedArray` validates a + dynamic array by stages (`table` guard, field guards, nested tag-loop guard) + before inserting precise `Item` records into a typed array. 
+- `TestGradualTyping_PairsLoopRefinesDynamicMapValuesInStages` validates + dynamic map keys, dynamic record values, nested header maps, and then stores + precise `Endpoint` records in a typed string-keyed map. +- `TestGradualTyping_WhileLoopCarriesOptionalRefinementThroughState` exercises + loop-carried optional state: a discriminated event loop writes `state.name` + and arithmetic state separately, then a post-loop nil guard proves the final + name before string use. +- `TestGradualTyping_NestedLoopsRefineMatrixCellsBeforeAggregation` covers + nested `ipairs` loops where row/column/value evidence builds precise cell + records. +- `TestGradualTyping_RejectsExistentialLoopProofAsSpecificElementProof` pins a + soundness boundary: seeing some string somewhere in a loop does not prove + that `raw.items[1]` is a string. + +The fixture `testdata/fixtures/regression/gradual-typing-adversarial` now +mirrors the staged `pairs` map refinement, nested matrix refinement, and +existential-loop negative case with inline `expect-error` coverage. + +Verification: + +```text +go test ./compiler/check/tests/regression -run 'TestGradualTyping' -count=1 -v +go test . -run 'TestFixtures/regression/gradual-typing-adversarial/check' -count=1 -v +go test ./... -count=1 +git diff --check +``` + +All checks pass. diff --git a/compiler/check/tests/regression/gradual_type_system_adversarial_test.go b/compiler/check/tests/regression/gradual_type_system_adversarial_test.go index 901b270f..c29dd7d4 100644 --- a/compiler/check/tests/regression/gradual_type_system_adversarial_test.go +++ b/compiler/check/tests/regression/gradual_type_system_adversarial_test.go @@ -243,6 +243,207 @@ return metric.name .. ":" .. route_name .. ":" .. tostring(next_count) assertNoGradualTypingErrors(t, source) } +func TestGradualTyping_LoopRefinesDynamicRecordsIntoTypedArray(t *testing.T) { + source := ` +type Item = {id: string, score: number, tags: {string}} + +local function read_tags(value): {string}? 
+ if type(value) ~= "table" then + return nil + end + local tags: {string} = {} + for _, tag in ipairs(value) do + if type(tag) ~= "string" then + return nil + end + table.insert(tags, tag) + end + return tags +end + +local function collect(raw_items: any): {Item} + local out: {Item} = {} + if type(raw_items) ~= "table" then + return out + end + for _, raw in ipairs(raw_items) do + if type(raw) == "table" then + local id = raw.id + if type(id) == "string" then + local score = raw.score + if type(score) == "number" then + local tags = read_tags(raw.tags) + if tags then + table.insert(out, { id = id, score = score, tags = tags }) + end + end + end + end + end + return out +end + +local items = collect({ + { id = "a", score = 10, tags = { "hot", "new" } }, + { id = false, score = "bad", tags = { 1 } }, + { id = "b", score = 20, tags = { "ok" } }, +}) + +local first = items[1] +if not first then + return "empty" +end +local label: string = first.id .. ":" .. first.tags[1] +local score: number = first.score + 1 +return label .. ":" .. 
tostring(score) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_PairsLoopRefinesDynamicMapValuesInStages(t *testing.T) { + source := ` +type Endpoint = {url: string, weight: number, headers: {[string]: string}} + +local function collect(raw: any): {[string]: Endpoint} + local endpoints: {[string]: Endpoint} = {} + if type(raw) ~= "table" then + return endpoints + end + for key, value in pairs(raw) do + if type(key) == "string" and type(value) == "table" then + local url = value.url + if type(url) == "string" then + local weight = value.weight + if type(weight) == "number" then + local headers: {[string]: string} = {} + if type(value.headers) == "table" then + for header_name, header_value in pairs(value.headers) do + if type(header_name) == "string" and type(header_value) == "string" then + headers[header_name] = header_value + end + end + end + endpoints[key] = { url = url, weight = weight, headers = headers } + end + end + end + end + return endpoints +end + +local endpoints = collect({ + primary = { url = "https://example.test", weight = 1, headers = { Accept = "application/json" } }, + secondary = { url = false, weight = "heavy" }, +}) + +local primary = endpoints.primary +if not primary then + return "missing" +end +local accept = primary.headers.Accept +if not accept then + return primary.url +end +local url: string = primary.url +local weight: number = primary.weight + 1 +return url .. ":" .. accept .. ":" .. 
tostring(weight) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_WhileLoopCarriesOptionalRefinementThroughState(t *testing.T) { + source := ` +type Event = {kind: "name", value: string} | {kind: "count", value: number} +type State = {name: string?, total: number} + +local events: {Event} = { + { kind = "count", value = 2 }, + { kind = "name", value = "worker" }, + { kind = "count", value = 3 }, +} + +local state: State = { total = 0 } +local i = 1 +while i <= #events do + local event = events[i] + if event.kind == "name" then + state.name = event.value + else + state.total = state.total + event.value + end + i = i + 1 +end + +local name = state.name +if not name then + return "missing" +end +local final_name: string = name +local final_total: number = state.total +return final_name .. ":" .. tostring(final_total) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_NestedLoopsRefineMatrixCellsBeforeAggregation(t *testing.T) { + source := ` +type Cell = {row: number, col: number, value: string} + +local function cells(raw_rows: any): {Cell} + local out: {Cell} = {} + if type(raw_rows) ~= "table" then + return out + end + for row_index, row in ipairs(raw_rows) do + if type(row) == "table" then + for col_index, value in ipairs(row) do + if type(value) == "string" then + table.insert(out, { row = row_index, col = col_index, value = value }) + end + end + end + end + return out +end + +local out = cells({ + { "a", false }, + { "b", "c" }, +}) + +local first = out[1] +if not first then + return "empty" +end +local pos: number = first.row + first.col +local value: string = first.value +return value .. ":" .. 
tostring(pos) +` + assertNoGradualTypingErrors(t, source) +} + +func TestGradualTyping_RejectsExistentialLoopProofAsSpecificElementProof(t *testing.T) { + source := ` +local raw: any = { items = { 42, "safe" } } +local saw_string = false + +if type(raw.items) == "table" then + for _, item in ipairs(raw.items) do + if type(item) == "string" then + saw_string = true + end + end +end + +if saw_string then + local first: string = raw.items[1] + return first +end +return "missing" +` + assertGradualTypingErrorContains(t, source, "cannot assign") +} + func TestGradualTyping_RejectsUncheckedAnyRecordAssignment(t *testing.T) { source := ` type User = {id: string, name: string} diff --git a/testdata/fixtures/regression/gradual-typing-adversarial/main.lua b/testdata/fixtures/regression/gradual-typing-adversarial/main.lua index 1ff6837f..cace9c94 100644 --- a/testdata/fixtures/regression/gradual-typing-adversarial/main.lua +++ b/testdata/fixtures/regression/gradual-typing-adversarial/main.lua @@ -5,6 +5,18 @@ type Config = { metadata: {[string]: string}, } +type Endpoint = { + url: string, + weight: number, + headers: {[string]: string}, +} + +type Cell = { + row: number, + col: number, + value: string, +} + type Validation = {ok: true, value: T} | {ok: false, error: string} local function ok(value: T): Validation @@ -114,4 +126,88 @@ end local typed_callback: (Config) -> string = callback -- expect-error +local function collect_endpoints(raw: any): {[string]: Endpoint} + local endpoints: {[string]: Endpoint} = {} + if type(raw) ~= "table" then + return endpoints + end + for key, value in pairs(raw) do + if type(key) == "string" and type(value) == "table" then + local url = value.url + if type(url) == "string" then + local weight = value.weight + if type(weight) == "number" then + local headers: {[string]: string} = {} + if type(value.headers) == "table" then + for header_name, header_value in pairs(value.headers) do + if type(header_name) == "string" and type(header_value) 
== "string" then + headers[header_name] = header_value + end + end + end + endpoints[key] = {url = url, weight = weight, headers = headers} + end + end + end + end + return endpoints +end + +local endpoints = collect_endpoints({ + primary = {url = "https://example.test", weight = 1, headers = {Accept = "application/json"}}, + bad = {url = false, weight = "heavy"}, +}) + +local primary = endpoints.primary +if primary then + local accept = primary.headers.Accept + if accept then + local endpoint_url: string = primary.url + local endpoint_weight: number = primary.weight + 1 + local endpoint_accept: string = accept + end +end + +local function collect_cells(raw_rows: any): {Cell} + local out: {Cell} = {} + if type(raw_rows) ~= "table" then + return out + end + for row_index, row in ipairs(raw_rows) do + if type(row) == "table" then + for col_index, value in ipairs(row) do + if type(value) == "string" then + table.insert(out, {row = row_index, col = col_index, value = value}) + end + end + end + end + return out +end + +local cells = collect_cells({ + {"a", false}, + {"b", "c"}, +}) + +local first_cell = cells[1] +if first_cell then + local position: number = first_cell.row + first_cell.col + local value: string = first_cell.value +end + +local raw_loop: any = {items = {42, "safe"}} +local saw_string = false +if type(raw_loop.items) == "table" then + for _, item in ipairs(raw_loop.items) do + if type(item) == "string" then + saw_string = true + end + end +end + +if saw_string then + local first_string: string = raw_loop.items[1] -- expect-error +end + return "ok" From b1b5aa1ed68d96f153650628f29c5713355d0fe0 Mon Sep 17 00:00:00 2001 From: Wolfy-J Date: Tue, 19 May 2026 16:31:16 -0400 Subject: [PATCH 32/71] Rectify checker convergence domains --- INTERPROC_FACTS_DOMAIN_JOURNAL.md | 246 +++++++++ compiler/check/checker.go | 17 - compiler/check/checker_test.go | 17 - .../domain/factproduct/container_mutation.go | 29 +- .../factproduct/container_mutation_test.go | 58 ++ 
compiler/check/domain/factproduct/product.go | 60 ++- compiler/check/flowbuild/assign/infer.go | 146 +++-- compiler/check/flowbuild/assign/infer_test.go | 36 ++ compiler/check/hooks/exhaustiveness_check.go | 498 ++++++++++++++++++ compiler/check/hooks/hooks.go | 12 + compiler/check/hooks/hooks_test.go | 11 +- compiler/check/infer/return/infer.go | 56 +- compiler/check/infer/return/scc.go | 48 +- compiler/check/pipeline/diagnostics.go | 46 -- compiler/check/pipeline/diagnostics_test.go | 66 --- compiler/check/pipeline/driver.go | 52 +- compiler/check/pipeline/driver_test.go | 8 - compiler/check/returns/callgraph_test.go | 9 - compiler/check/returns/types.go | 4 - compiler/check/returns/types_test.go | 14 - .../tests/flow/preflow_convergence_test.go | 91 +--- .../regression/exhaustiveness_warning_test.go | 244 +++++++++ .../main.lua | 20 + .../manifest.json | 3 + types/constraint/solver.go | 20 +- types/db/query.go | 60 +-- types/flow/flow.go | 16 - types/flow/normalize.go | 13 - types/flow/numeric.go | 26 +- types/flow/numeric/state.go | 69 +++ types/flow/numeric/state_test.go | 27 + types/flow/numeric_test.go | 24 + types/flow/solver.go | 4 - types/flow/transfer.go | 23 +- types/flow/transfer_test.go | 24 + types/narrow/discriminant_domain.go | 155 ++++++ types/narrow/discriminant_domain_test.go | 62 +++ types/typ/equals.go | 5 +- types/typ/equals_test.go | 16 + 39 files changed, 1795 insertions(+), 540 deletions(-) create mode 100644 compiler/check/hooks/exhaustiveness_check.go create mode 100644 compiler/check/tests/regression/exhaustiveness_warning_test.go create mode 100644 testdata/fixtures/narrowing/channel-select-case-exhaustiveness-warning/main.lua create mode 100644 testdata/fixtures/narrowing/channel-select-case-exhaustiveness-warning/manifest.json create mode 100644 types/narrow/discriminant_domain.go create mode 100644 types/narrow/discriminant_domain_test.go diff --git a/INTERPROC_FACTS_DOMAIN_JOURNAL.md b/INTERPROC_FACTS_DOMAIN_JOURNAL.md index 
39038dcb..fcd0be65 100644 --- a/INTERPROC_FACTS_DOMAIN_JOURNAL.md +++ b/INTERPROC_FACTS_DOMAIN_JOURNAL.md @@ -5354,6 +5354,119 @@ git diff --check All checks pass. +## 2026-05-19 Exhaustiveness Warnings for Closed Matches and Channel Select + +Added the checker warning the user asked for. The standard term is +**exhaustiveness checking**; the diagnostic is a warning for a +**non-exhaustive match**. + +Correction made during review: the real `channel.select` exhaustiveness target +is `result.channel`, not `result.value.kind`. `result.value` is the selected +receive payload. A payload discriminator such as `result.value.kind` only makes +sense after a channel guard has already proven which channel produced the +payload. It is a separate nested discriminated-union match, not the select-arm +match itself. + +Diagnostic boundary: + +- The diagnostic is warning-only: `diag.ErrNonExhaustive` with + `SeverityWarning`. It does not make type checking fail. +- Closed literal-tag proof lives in `types/narrow.ClosedDiscriminantDomain`. +- The checker hook recognizes match-like Lua `if/elseif` chains and delegates + closed literal-tag domain proof to the narrowing domain. +- The hook also recognizes real `channel.select` result arms by indexing + assignments of `channel.select { ch:case_receive(), ... }` and matching + `result.channel == ch` branches against the selected channel paths. +- The warning is emitted only for match-like chains with at least two explicit + arms and no final `else`. A single early-return guard stays silent because + fallthrough may intentionally handle the remaining case. +- Open or dynamic cases stay silent: `any`, `unknown`, `nil`, optional + discriminants, broad tags like `kind: string`, missing tags, non-record + members, unextractable select channels, and select calls with default cases. 
+ +Correct `channel.select` warning sample: + +```lua +local result = channel.select { + events_ch:case_receive(), + stop_ch:case_receive(), + timeout_ch:case_receive(), +} + +if result.channel == events_ch then + return result.value.kind +elseif result.channel == stop_ch then + return result.value.reason +end +``` + +Warning: + +```text +non-exhaustive match on result.channel; missing case: timeout_ch +``` + +Correct complete select sample: + +```lua +if result.channel == events_ch then + return result.value.kind +elseif result.channel == stop_ch then + return result.value.reason +elseif result.channel == timeout_ch then + return tostring(result.value.sec) +end +``` + +No warning is emitted there because every selected channel arm is represented. + +The nested payload-discriminant case is still supported separately: + +```lua +if result.channel == events_ch then + if result.value.kind == "message" then + ... + elseif result.value.kind == "tool" then + ... + end +end +``` + +That warning is about the closed `Event` payload union after the `events_ch` +guard, not about the `channel.select` arm set. + +Added coverage: + +- `types/narrow/discriminant_domain_test.go` + - closed string tag domains, + - closed numeric tag domains, + - broad tag rejection, + - optional tag rejection. +- `compiler/check/tests/regression/exhaustiveness_warning_test.go` + - plain discriminated-union missing case, + - real `channel.select` missing channel case, + - real `channel.select` all-cases-handled no-warning case, + - real `channel.select` single early-return guard no-warning case, + - final `else` suppresses warning, + - all literal variants handled suppresses warning, + - open discriminant suppresses warning, + - numeric discriminant missing case. +- `testdata/fixtures/narrowing/channel-select-case-exhaustiveness-warning` + pins the real fixture harness line-level `expect-warning` for the selected + channel case pattern. 
+ +Verification: + +```text +go test ./types/narrow -run TestClosedDiscriminantDomain -count=1 -v +go test ./compiler/check/tests/regression -run TestExhaustivenessWarning -count=1 -v +go test ./compiler/check/hooks -count=1 +go test . -run 'TestFixtures/narrowing/channel-select-case-exhaustiveness-warning/check' -count=1 -v +go test ./... -count=1 +``` + +All checks pass. + ## 2026-05-19 Adversarial Gradual-Typing Regressions Added a dedicated gradual-typing regression suite and fixture. The goal is to @@ -5441,3 +5554,136 @@ git diff --check ``` All checks pass. + +## 2026-05-19 Exhaustiveness Lint Wiring and Real-Code Probe + +The exhaustiveness checker is intentionally a configurable warning class for +Wippy lint, not a globally forced diagnostic. The Wippy runtime type-checker +already had `TypeCheckRules.Exhaustive` in its cache fingerprint; that bit is +now the single authority for installing `hooks.WithExhaustiveness()`. + +Design notes: + +- go-lua owns the semantic pass and exposes it as `hooks.WithExhaustiveness()`; +- Wippy lint exposes the policy switch as `wippy lint --exhaustiveness`; +- typecheck cache fingerprints already include `Rules.Exhaustive`, so cached + diagnostics cannot hide the opt-in warning state; +- default lint remains unchanged and does not emit exhaustiveness warnings unless + the flag is requested. + +Real-code proof: + +1. temporarily injected a third unhandled `channel.select` case into + `framework/src/llm/src/llm.lua`; +2. rebuilt the temporary Wippy binary against this checkout with the local + go-lua replace; +3. ran `wippy lint --cache-reset --json --exhaustiveness` in + `framework/src/llm/src`; +4. observed the expected warning: + `E0014 warning: non-exhaustive match on result.channel; missing case: c`; +5. restored `llm.lua` byte-for-byte and reran the same lint command; +6. confirmed `warning_count: 0`. 
+ +Added Wippy-side coverage: + +- `TestTypeChecker_ExhaustiveRuleOptIn` +- `TestTypeChecker_ExhaustiveRuleOffByDefault` +- `TestParseLintFlags_Exhaustiveness` + +Verification: + +```text +env GOFLAGS=-modfile=/tmp/wippy-local-replace.mod go test ./runtime/lua/code -run 'TestTypeChecker_ExhaustiveRule|TestChannelSelectNarrowing_ProcessEvent' -count=1 -v +env GOFLAGS=-modfile=/tmp/wippy-local-replace.mod go test ./cmd/wippy/cmd -run TestParseLintFlags_Exhaustiveness -count=1 -v +env GOFLAGS=-modfile=/tmp/wippy-local-replace.mod go build -o /tmp/wippy-local-replace-bin ./cmd/wippy +env GOFLAGS=-modfile=/tmp/wippy-local-replace.mod /tmp/wippy-local-replace-bin lint --cache-reset --json --exhaustiveness +``` + +The restored real-code lint run still reports the known nine LLM errors and no +warnings. Exhaustiveness did not backfire on real code after the temporary probe +was removed. + +## 2026-05-19 Flash Convergence Rectification: No Caps + +Removed the artificial convergence caps from the checker pipeline, return SCC +inference, assignment inference, query cycle handling, constraint solving, +flow solving, and numeric solving. Non-convergence is no longer handled by +"iterate N times then warn/fallback"; it must be handled by finite-height +abstract domains, idempotent transfer functions, and explicit widening. + +Key design decisions: + +- Interprocedural facts are a product-domain fixpoint. Captured container + mutations now canonicalize same-path writes and join element/value types on + the fact boundary instead of preserving duplicate mutation events. +- Return SCCs merge with `returnsummary.WidenForConvergence`, so recursive + return summaries stabilize through domain widening instead of an unknown + fallback. +- Assignment SCCs now test the actual SCC product state for stability after a + sweep. A transient update inside the sweep is not a semantic change unless + the final vector differs. 
+- `any` is treated as top in local inference joins and call-expectation merges. + This prevents `T -> any -> T` oscillation while preserving soundness. +- Numeric flow has per-point widening memory: once moving numeric facts widen + to Top, that point remains Top for the solve. This prevents `Top -> fact -> + Top -> fact` reintroduction caused by representing Top as an absent state. + +Important fixes found by real replays: + +- `types/typ.TypeEquals` no longer rejects structurally equal DAG-shaped types + just because one side shares a subnode and the other side duplicates it. The + equality proof now relies on pair-based coinduction for compound cycles. +- Array/map mutator widening is idempotent. Re-inserting an already-known array + element type returns the original abstract value instead of rebuilding an + equal value and causing false "changed" reports. +- Body-local parameter evidence is treated as evidence, not as a final declared + upper bound. A stronger whole-parameter call contract can dominate compatible + body evidence, including record evidence compatible with a string-keyed map. + +Regression coverage added: + +- same-iteration captured container mutation dedupe; +- captured container mutation joins for loop/table-insert patterns; +- assignment `any` top behavior; +- assignment SCC product-stability regressions from guarded options; +- structural equality for shared DAG-shaped records; +- array/map mutator idempotence; +- numeric widening-to-Top memory; +- body evidence plus whole-parameter call expectation; +- adversarial gradual-typing and loop-carried refinement cases; +- exhaustiveness opt-in warnings. + +Verification: + +```text +go test ./... 
-count=1 -timeout 180s +go test ./compiler/check -run '^$' -bench BenchmarkCheck_LargeFunction -benchmem -count=3 +env GOFLAGS=-modfile=/tmp/wippy-local-replace.mod go build -o /tmp/wippy-local-replace-verify ./cmd/wippy +timeout 60s /tmp/wippy-local-replace-verify lint --cache-reset --json --ns wippy.session.api +timeout 60s /tmp/wippy-local-replace-verify lint --cache-reset --json +``` + +The go-lua suite passes. The local-replace Wippy replays that previously hung +now terminate. `wippy.session.api` no longer times out; the full session target +also terminates. + +Final benchmark sample: + +```text +BenchmarkCheck_LargeFunction-32 382 3399067 ns/op 1084024 B/op 10938 allocs/op +BenchmarkCheck_LargeFunction-32 345 3236163 ns/op 1084162 B/op 10938 allocs/op +BenchmarkCheck_LargeFunction-32 385 3162319 ns/op 1084096 B/op 10938 allocs/op +``` + +Remaining verification boundary: + +- The stock `../scripts/verify-suite.sh` still cannot build Wippy without a + local replace because the Wippy checkout references the new + `hooks.WithExhaustiveness()` while its normal module graph resolves an older + published go-lua. +- Local-replace Wippy lint is not clean. The remaining diagnostics are finite + and must be classified separately as source/manifest issues or precision gaps; + this pass fixed the convergence class, not every external diagnostic. +- `tests/app` still reports an `E9999` internal-error diagnostic for + `app.test.types:lib_inner_types`; that is an engine-facing item and should be + investigated before claiming the global harness is clean. 
diff --git a/compiler/check/checker.go b/compiler/check/checker.go index 91232bfa..455a7587 100644 --- a/compiler/check/checker.go +++ b/compiler/check/checker.go @@ -137,7 +137,6 @@ type Checker struct { deps Deps passes []Pass computePasses []api.ComputePass - maxIterations int maxScopeDepth int emitScopeDepthDiagnostics bool } @@ -166,7 +165,6 @@ func NewChecker(database *db.DB, deps Deps, opts ...Option) *Checker { c := &Checker{ db: database, deps: deps, - maxIterations: 10, maxScopeDepth: 0, } @@ -193,24 +191,12 @@ func (c *Checker) newPipeline() *pipeline.Driver { GlobalTypes: c.deps.GlobalTypes, Stdlib: c.deps.Stdlib, Manifests: c.db, - MaxIterations: c.maxIterations, MaxScopeDepth: c.maxScopeDepth, EmitScopeDiag: c.emitScopeDepthDiagnostics, FuncResultQ: funcResultQ, }) } -// WithMaxIterations configures the maximum number of fixpoint iterations. -// Values less than 1 are clamped to 1. -func WithMaxIterations(n int) Option { - return func(c *Checker) { - if n < 1 { - n = 1 - } - c.maxIterations = n - } -} - // WithMaxScopeDepth configures a maximum lexical scope nesting depth. // A value <= 0 disables the limit. func WithMaxScopeDepth(n int) Option { @@ -327,9 +313,6 @@ func (c *Checker) runPasses(sess *Session) { diags := p(sess, fn, result) sess.Diagnostics = append(sess.Diagnostics, diags...) } - - // Emit widening diagnostics for preflow inference precision loss - sess.Diagnostics = append(sess.Diagnostics, pipeline.WideningDiagnostics(sess.SourceName, fn, result)...) 
} } diff --git a/compiler/check/checker_test.go b/compiler/check/checker_test.go index 44051dd4..a05e4364 100644 --- a/compiler/check/checker_test.go +++ b/compiler/check/checker_test.go @@ -26,9 +26,6 @@ func TestNewChecker(t *testing.T) { if c.db != database { t.Error("db not set") } - if c.maxIterations != 10 { - t.Fatalf("default maxIterations = %d, want 10", c.maxIterations) - } } func TestChecker_WithPass(t *testing.T) { @@ -44,20 +41,6 @@ func TestChecker_WithPass(t *testing.T) { } } -func TestChecker_WithMaxIterations(t *testing.T) { - c := NewChecker(db.New(), Deps{Types: core.NewEngine()}, WithMaxIterations(3)) - if c.maxIterations != 3 { - t.Fatalf("maxIterations = %d, want 3", c.maxIterations) - } -} - -func TestChecker_WithMaxIterationsClamp(t *testing.T) { - c := NewChecker(db.New(), Deps{Types: core.NewEngine()}, WithMaxIterations(0)) - if c.maxIterations != 1 { - t.Fatalf("maxIterations = %d, want 1", c.maxIterations) - } -} - func TestChecker_WithMaxScopeDepth(t *testing.T) { c := NewChecker(db.New(), Deps{Types: core.NewEngine()}, WithMaxScopeDepth(4)) if c.maxScopeDepth != 4 { diff --git a/compiler/check/domain/factproduct/container_mutation.go b/compiler/check/domain/factproduct/container_mutation.go index cbf131d5..ba3de704 100644 --- a/compiler/check/domain/factproduct/container_mutation.go +++ b/compiler/check/domain/factproduct/container_mutation.go @@ -16,11 +16,8 @@ func MergeContainerMutationSlices( next []api.ContainerMutation, merge ContainerMutationMerger, ) []api.ContainerMutation { - if len(existing) == 0 { - return next - } - if len(next) == 0 { - return existing + if len(existing) == 0 && len(next) == 0 { + return nil } mergeFn := merge @@ -29,18 +26,21 @@ func MergeContainerMutationSlices( } byKey := make(map[string]api.ContainerMutation, len(existing)+len(next)) - for _, m := range existing { - byKey[api.ContainerMutationKey(m)] = m - } - for _, m := range next { + add := func(m api.ContainerMutation) { key := 
api.ContainerMutationKey(m) if prev, ok := byKey[key]; ok { merged := mergeFn(&prev, m) byKey[key] = merged - continue + return } byKey[key] = mergeFn(nil, m) } + for _, m := range existing { + add(m) + } + for _, m := range next { + add(m) + } out := make([]api.ContainerMutation, 0, len(byKey)) for _, key := range cfg.SortedFieldNames(byKey) { @@ -55,15 +55,12 @@ func MergeCapturedContainerMutationMaps( next map[cfg.SymbolID][]api.ContainerMutation, merge ContainerMutationMerger, ) map[cfg.SymbolID][]api.ContainerMutation { - if existing == nil { - return next - } - if next == nil { - return existing + if len(existing) == 0 && len(next) == 0 { + return nil } merged := make(map[cfg.SymbolID][]api.ContainerMutation, len(existing)+len(next)) for _, sym := range cfg.SortedSymbolIDs(existing) { - merged[sym] = existing[sym] + merged[sym] = MergeContainerMutationSlices(nil, existing[sym], merge) } for _, sym := range cfg.SortedSymbolIDs(next) { merged[sym] = MergeContainerMutationSlices(merged[sym], next[sym], merge) diff --git a/compiler/check/domain/factproduct/container_mutation_test.go b/compiler/check/domain/factproduct/container_mutation_test.go index 1d9f59d9..03554052 100644 --- a/compiler/check/domain/factproduct/container_mutation_test.go +++ b/compiler/check/domain/factproduct/container_mutation_test.go @@ -108,3 +108,61 @@ func TestMergeContainerMutationSlices_KeepsOperatorKindsDistinct(t *testing.T) { t.Fatalf("expected separate facts for same path with different operators, got %#v", got) } } + +func TestWidenCapturedContainerMutations_JoinsSameContainerElement(t *testing.T) { + prevRecord := typ.NewRecord().Field("name", typ.Any).Build() + nextRecord := typ.NewRecord().Field("error", typ.String).Build() + + prev := api.CapturedContainerMutations{ + 10: { + 20: { + {Kind: api.ContainerMutationContainerElement, ValueType: prevRecord}, + }, + }, + } + next := api.CapturedContainerMutations{ + 10: { + 20: { + {Kind: api.ContainerMutationContainerElement, 
ValueType: nextRecord}, + }, + }, + } + + got := WidenCapturedContainerMutations(prev, next) + muts := got[10][20] + if len(muts) != 1 { + t.Fatalf("len(muts) = %d, want 1", len(muts)) + } + if typ.TypeEquals(muts[0].ValueType, prevRecord) || typ.TypeEquals(muts[0].ValueType, nextRecord) { + t.Fatalf("expected joined container element type, got %v", muts[0].ValueType) + } + if !typ.TypeEquals(got[10][20][0].ValueType, WidenCapturedContainerMutations(got, next)[10][20][0].ValueType) { + t.Fatalf("widened captured container mutation must be idempotent, got %v then %v", got[10][20][0].ValueType, WidenCapturedContainerMutations(got, next)[10][20][0].ValueType) + } +} + +func TestWidenCapturedContainerMutations_DedupesSameIterationMutations(t *testing.T) { + firstRecord := typ.NewRecord().Field("name", typ.Any).Build() + secondRecord := typ.NewRecord().Field("error", typ.String).Build() + + next := api.CapturedContainerMutations{ + 10: { + 20: { + {Kind: api.ContainerMutationContainerElement, ValueType: firstRecord}, + {Kind: api.ContainerMutationContainerElement, ValueType: secondRecord}, + }, + }, + } + + got := WidenCapturedContainerMutations(nil, next) + muts := got[10][20] + if len(muts) != 1 { + t.Fatalf("len(muts) = %d, want 1 canonical mutation per path", len(muts)) + } + if typ.TypeEquals(muts[0].ValueType, firstRecord) || typ.TypeEquals(muts[0].ValueType, secondRecord) { + t.Fatalf("expected same-iteration container writes to join, got %v", muts[0].ValueType) + } + if !CapturedContainerMutationsEqual(got, WidenCapturedContainerMutations(got, next)) { + t.Fatalf("widened captured container mutations must be idempotent") + } +} diff --git a/compiler/check/domain/factproduct/product.go b/compiler/check/domain/factproduct/product.go index 6629cb48..f70f1adf 100644 --- a/compiler/check/domain/factproduct/product.go +++ b/compiler/check/domain/factproduct/product.go @@ -409,7 +409,7 @@ func WidenCapturedContainerMutations(prev, next api.CapturedContainerMutations) 
existing := merged[sym] merged[sym] = MergeCapturedContainerMutationMaps(existing, muts, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { if prev != nil { - next.ValueType = mergeInterprocValueType(prev.ValueType, next.ValueType) + next.ValueType = widenContainerMutationValueType(prev.ValueType, next.ValueType) } else { next.ValueType = value.WidenForConvergence(next.ValueType) } @@ -440,11 +440,14 @@ func normalizeCapturedContainerMutationMap(muts map[cfg.SymbolID][]api.Container if len(entries) == 0 { continue } - normalized := make([]api.ContainerMutation, len(entries)) - for i, mut := range entries { - normalized[i] = mut - normalized[i].ValueType = canonicalInterprocValueType(mut.ValueType) - } + normalized := MergeContainerMutationSlices(nil, entries, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { + if prev != nil { + next.ValueType = widenContainerMutationValueType(prev.ValueType, next.ValueType) + } else { + next.ValueType = value.WidenForConvergence(next.ValueType) + } + return next + }) out[sym] = normalized } if len(out) == 0 { @@ -473,7 +476,7 @@ func JoinCapturedContainerMutations(prev, next api.CapturedContainerMutations) a existing := merged[sym] merged[sym] = MergeCapturedContainerMutationMaps(existing, muts, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { if prev != nil { - next.ValueType = joinInterprocValueType(prev.ValueType, next.ValueType) + next.ValueType = joinContainerMutationValueType(prev.ValueType, next.ValueType) } else { next.ValueType = normalizeInterprocValueType(next.ValueType) } @@ -504,11 +507,14 @@ func normalizeCapturedContainerMutationMapForJoin(muts map[cfg.SymbolID][]api.Co if len(entries) == 0 { continue } - normalized := make([]api.ContainerMutation, len(entries)) - for i, mut := range entries { - normalized[i] = mut - normalized[i].ValueType = normalizeInterprocValueType(mut.ValueType) - } + normalized := 
MergeContainerMutationSlices(nil, entries, func(prev *api.ContainerMutation, next api.ContainerMutation) api.ContainerMutation { + if prev != nil { + next.ValueType = joinContainerMutationValueType(prev.ValueType, next.ValueType) + } else { + next.ValueType = normalizeInterprocValueType(next.ValueType) + } + return next + }) out[sym] = normalized } if len(out) == 0 { @@ -517,6 +523,36 @@ func normalizeCapturedContainerMutationMapForJoin(muts map[cfg.SymbolID][]api.Co return out } +func widenContainerMutationValueType(prev, next typ.Type) typ.Type { + prev = canonicalInterprocValueType(prev) + next = canonicalInterprocValueType(next) + if prev == nil { + return value.WidenForConvergence(next) + } + if next == nil { + return value.WidenForConvergence(prev) + } + if typ.TypeEquals(prev, next) { + return prev + } + return value.WidenForConvergence(typ.JoinReturnSlot(prev, next)) +} + +func joinContainerMutationValueType(prev, next typ.Type) typ.Type { + prev = normalizeInterprocValueType(prev) + next = normalizeInterprocValueType(next) + if prev == nil { + return next + } + if next == nil { + return prev + } + if typ.TypeEquals(prev, next) { + return prev + } + return normalizeInterprocValueType(typ.JoinReturnSlot(prev, next)) +} + // WidenConstructorFields merges constructor field maps using monotone join. func WidenConstructorFields(prev, next api.ConstructorFields) api.ConstructorFields { if prev == nil && next == nil { diff --git a/compiler/check/flowbuild/assign/infer.go b/compiler/check/flowbuild/assign/infer.go index 978085fd..1431a8fa 100644 --- a/compiler/check/flowbuild/assign/infer.go +++ b/compiler/check/flowbuild/assign/infer.go @@ -19,8 +19,8 @@ // // - Stop when no changes occur // -// 4. Widening: If an SCC doesn't converge within maxInferIterations, widen all -// symbols in that SCC to Unknown. This ensures termination. +// 4. Convergence: recursive SCCs iterate until the widened abstract domain +// stabilizes; there is no caller-visible iteration cap. 
// // # SCC PROCESSING // @@ -71,9 +71,6 @@ import ( "github.com/wippyai/go-lua/types/typ" ) -// maxInferIterations limits fixpoint iterations per SCC. -const maxInferIterations = 10 - func mergeSpecTypesSoftInto(out, base, override api.SpecTypes) api.SpecTypes { if out == nil { out = make(api.SpecTypes, len(base)+len(override)) @@ -131,8 +128,7 @@ func CollectInferredTypes(fc *fbcore.FlowContext, specTypes api.SpecTypes, annot // Algorithm: // 1. Build dependency graph: symbol -> symbols referenced in RHS // 2. Compute SCCs in topological order -// 3. For each SCC, run bounded fixpoint iteration with monotone joins -// 4. If not converged by max iterations, widen to Unknown +// 3. For each SCC, run fixpoint iteration with monotone joins func collectInferredTypes( graph *cfg.Graph, scopes map[cfg.Point]*scope.State, @@ -528,10 +524,11 @@ func collectInferredTypes( } } - // Fixpoint iteration for this SCC - converged := false + // Fixpoint iteration for this SCC. var overlayScratch api.SpecTypes - for iter := 0; iter < maxInferIterations; iter++ { + snapshot := make([]typ.Type, len(sccSyms)) + for { + snapshotSCCTypes(snapshot, inferred, sccSyms) changed := false overlayScratch = mergeSpecTypesSoftInto(overlayScratch, inferred, specTypes) overlay := overlayScratch @@ -989,25 +986,10 @@ func collectInferredTypes( changed = true } - if !changed { - converged = true + if !changed || sccTypesStable(snapshot, inferred, sccSyms) { break } } - - // Widen ALL symbols in non-converged SCC to Unknown (except annotated). - // This is sound: partial types may be under-approximations. Local - // preflow inference is a hint source, so the fallback is intentionally - // internal; surfacing it as a lint warning produces false positives for - // dynamic but valid Lua patterns. - if !converged { - for _, sym := range sccSyms { - if annotated != nil && annotated[sym] { - continue - } - inferred[sym] = typ.Unknown - } - } } // Default unconstrained parameters to any. 
@@ -1031,6 +1013,32 @@ func collectInferredTypes( return inferred } +func snapshotSCCTypes(out []typ.Type, inferred api.SpecTypes, syms []cfg.SymbolID) { + for i, sym := range syms { + out[i] = inferred[sym] + } +} + +func sccTypesStable(prev []typ.Type, inferred api.SpecTypes, syms []cfg.SymbolID) bool { + for i, sym := range syms { + before := prev[i] + after := inferred[sym] + if before == after { + continue + } + if before == nil || after == nil { + return false + } + if before.Hash() != after.Hash() { + return false + } + if !typ.TypeEquals(before, after) { + return false + } + } + return true +} + func normalizedCallArgSymbols(info *cfg.CallInfo, bindings *bind.BindingTable) []cfg.SymbolID { if info == nil || len(info.Args) == 0 { return nil @@ -1328,6 +1336,9 @@ func joinInferredType(old, next typ.Type) typ.Type { if next == nil { return old } + if typ.IsAny(old) || typ.IsAny(next) { + return typ.Any + } if typeContains(next, old) { if !typ.IsAbsentOrUnknown(old) { return old @@ -1344,6 +1355,9 @@ func callExpectationCanRefineLocal(old typ.Type) bool { } func mergeCallExpectation(old, expected typ.Type, isParam bool) typ.Type { + if typ.IsAny(old) || typ.IsAny(expected) { + return typ.Any + } if isParam { if expectedParamTypeDominates(old, expected) { return expected @@ -1393,11 +1407,7 @@ func recordEvidenceCompatibleWithExpected(old, expected *typ.Record) bool { if expectedField.Optional { expectedType = typ.NewOptional(expectedType) } - fieldType := field.Type - if field.Optional { - fieldType = typ.NewOptional(fieldType) - } - if !subtype.IsSubtype(fieldType, expectedType) { + if !evidenceTypeCompatibleWithExpected(field.Type, expectedType) { return false } } @@ -1405,10 +1415,80 @@ func recordEvidenceCompatibleWithExpected(old, expected *typ.Record) bool { if !expected.HasMapComponent() { return false } - if !fieldEvidenceIsUnresolved(old.MapKey) && !subtype.IsSubtype(old.MapKey, expected.MapKey) { + if !fieldEvidenceIsUnresolved(old.MapKey) && 
!evidenceTypeCompatibleWithExpected(old.MapKey, expected.MapKey) { + return false + } + if !fieldEvidenceIsUnresolved(old.MapValue) && !evidenceTypeCompatibleWithExpected(old.MapValue, expected.MapValue) { + return false + } + } + return true +} + +func evidenceTypeCompatibleWithExpected(evidence, expected typ.Type) bool { + if fieldEvidenceIsUnresolved(evidence) { + return true + } + if evidence == nil || expected == nil { + return false + } + if subtype.IsSubtype(evidence, expected) { + return true + } + switch e := typ.UnwrapAnnotated(evidence).(type) { + case *typ.Alias: + return evidenceTypeCompatibleWithExpected(e.Target, expected) + case *typ.Union: + for _, member := range e.Members { + if !evidenceTypeCompatibleWithExpected(member, expected) { + return false + } + } + return true + case *typ.Record: + if expectedMap := mapForEvidenceExpected(expected); expectedMap != nil { + return recordEvidenceCompatibleWithExpectedMap(e, expectedMap) + } + } + if opt, ok := typ.UnwrapAnnotated(expected).(*typ.Optional); ok { + return evidenceTypeCompatibleWithExpected(evidence, opt.Inner) + } + return false +} + +func mapForEvidenceExpected(t typ.Type) *typ.Map { + for { + switch v := typ.UnwrapAnnotated(t).(type) { + case *typ.Alias: + t = v.Target + case *typ.Optional: + t = v.Inner + case *typ.Map: + return v + default: + return nil + } + } +} + +func recordEvidenceCompatibleWithExpectedMap(evidence *typ.Record, expected *typ.Map) bool { + if evidence == nil || expected == nil { + return false + } + for _, field := range evidence.Fields { + keyType := typ.LiteralString(field.Name) + if !evidenceTypeCompatibleWithExpected(keyType, expected.Key) { + return false + } + if !evidenceTypeCompatibleWithExpected(field.Type, expected.Value) { + return false + } + } + if evidence.HasMapComponent() { + if !evidenceTypeCompatibleWithExpected(evidence.MapKey, expected.Key) { return false } - if !fieldEvidenceIsUnresolved(old.MapValue) && !subtype.IsSubtype(old.MapValue, 
expected.MapValue) { + if !evidenceTypeCompatibleWithExpected(evidence.MapValue, expected.Value) { return false } } diff --git a/compiler/check/flowbuild/assign/infer_test.go b/compiler/check/flowbuild/assign/infer_test.go index 925f0ee7..2aad0b74 100644 --- a/compiler/check/flowbuild/assign/infer_test.go +++ b/compiler/check/flowbuild/assign/infer_test.go @@ -217,6 +217,42 @@ func TestJoinInferredType_StopsRecursiveNestingGrowth(t *testing.T) { } } +func TestJoinInferredType_TreatsAnyAsTop(t *testing.T) { + suite := typ.NewRecord().Field("name", typ.String).Build() + + got := joinInferredType(suite, typ.Any) + if !typ.TypeEquals(got, typ.Any) { + t.Fatalf("joinInferredType(Suite, any) = %v, want any", got) + } + + got = mergeCallExpectation(typ.Any, suite, true) + if !typ.TypeEquals(got, typ.Any) { + t.Fatalf("mergeCallExpectation(any, Suite) = %v, want any", got) + } +} + +func TestMergeCallExpectation_ParamDominatesCompatibleBodyEvidence(t *testing.T) { + headerMap := typ.NewMap(typ.String, typ.String) + bodyHeaderEvidence := typ.NewRecord(). + SetOpen(true). + OptField("Accept", typ.String). + Build() + old := typ.NewRecord(). + SetOpen(true). + OptField("headers", typ.NewUnion(headerMap, bodyHeaderEvidence)). + OptField("stream", typ.Unknown). + Build() + expected := typ.NewRecord(). + Field("headers", headerMap). + OptField("stream", typ.Boolean). 
+ Build() + + got := mergeCallExpectation(old, expected, true) + if !typ.TypeEquals(got, expected) { + t.Fatalf("mergeCallExpectation(body evidence, expected param) = %v, want %v", got, expected) + } +} + func TestTypeContains(t *testing.T) { base := typ.NewArray(typ.Unknown) outer := typ.NewArray(base) diff --git a/compiler/check/hooks/exhaustiveness_check.go b/compiler/check/hooks/exhaustiveness_check.go new file mode 100644 index 00000000..487b5155 --- /dev/null +++ b/compiler/check/hooks/exhaustiveness_check.go @@ -0,0 +1,498 @@ +package hooks + +import ( + "fmt" + "strings" + + "github.com/wippyai/go-lua/compiler/ast" + "github.com/wippyai/go-lua/compiler/bind" + "github.com/wippyai/go-lua/compiler/cfg" + "github.com/wippyai/go-lua/compiler/check/api" + flowpath "github.com/wippyai/go-lua/compiler/check/flowbuild/path" + "github.com/wippyai/go-lua/types/constraint" + "github.com/wippyai/go-lua/types/diag" + "github.com/wippyai/go-lua/types/narrow" + "github.com/wippyai/go-lua/types/numparse" + "github.com/wippyai/go-lua/types/typ" +) + +// CheckExhaustiveness warns when a match-like if/elseif chain misses variants +// from a provably closed discriminated union. 
+func CheckExhaustiveness(fn *ast.FunctionExpr, graph *cfg.Graph, synth api.BaseSynth, sourceName string) []diag.Diagnostic { + if fn == nil || graph == nil || synth == nil { + return nil + } + checker := exhaustivenessChecker{ + branchPoint: branchPointsByCondition(graph), + graph: graph, + bindings: graph.Bindings(), + selectCases: selectCasesByResult(graph), + synth: synth, + sourceName: sourceName, + } + checker.checkStmts(fn.Stmts) + return checker.diags +} + +type exhaustivenessChecker struct { + branchPoint map[ast.Expr]cfg.Point + graph *cfg.Graph + bindings *bind.BindingTable + selectCases map[string]selectCaseDomain + synth api.BaseSynth + sourceName string + diags []diag.Diagnostic +} + +type discriminantCheck struct { + object ast.Expr + objectPath constraint.Path + field string + path string + literal *typ.Literal + value ast.Expr + valuePath constraint.Path + valueName string + condition ast.Expr + point cfg.Point +} + +type selectCaseDomain struct { + cases []selectCase +} + +type selectCase struct { + path constraint.Path + name string +} + +func branchPointsByCondition(graph *cfg.Graph) map[ast.Expr]cfg.Point { + points := make(map[ast.Expr]cfg.Point) + graph.EachBranch(func(p cfg.Point, info *cfg.BranchInfo) { + if info != nil && info.Condition != nil { + points[info.Condition] = p + } + }) + return points +} + +func (c *exhaustivenessChecker) checkStmts(stmts []ast.Stmt) { + for _, stmt := range stmts { + c.checkStmt(stmt) + } +} + +func (c *exhaustivenessChecker) checkStmt(stmt ast.Stmt) { + switch s := stmt.(type) { + case *ast.IfStmt: + c.checkIf(s) + case *ast.WhileStmt: + c.checkStmts(s.Stmts) + case *ast.RepeatStmt: + c.checkStmts(s.Stmts) + case *ast.NumberForStmt: + c.checkStmts(s.Stmts) + case *ast.GenericForStmt: + c.checkStmts(s.Stmts) + case *ast.DoBlockStmt: + c.checkStmts(s.Stmts) + } +} + +func (c *exhaustivenessChecker) checkIf(stmt *ast.IfStmt) { + c.checkIfChain(stmt) + + for current := stmt; current != nil; { + 
c.checkStmts(current.Then) + next, ok := singleElseIf(current) + if !ok { + c.checkStmts(current.Else) + break + } + current = next + } +} + +func (c *exhaustivenessChecker) checkIfChain(stmt *ast.IfStmt) { + checks, hasElse, ok := c.collectDiscriminantChain(stmt) + if !ok || hasElse || len(checks) < 2 { + return + } + + first := checks[0] + for _, check := range checks[1:] { + if check.path != first.path || check.field != first.field { + return + } + } + + if first.literal == nil { + c.checkSelectCaseChain(stmt, checks) + return + } + + objectType := c.synth.TypeOf(first.object, first.point) + domain, ok := narrow.ClosedDiscriminantDomain(objectType, first.field) + if !ok { + return + } + + handled := make([]*typ.Literal, 0, len(checks)) + handledInDomain := false + for _, check := range checks { + handled = append(handled, check.literal) + if domain.Contains(check.literal) { + handledInDomain = true + } + } + if !handledInDomain { + return + } + + missing := domain.Missing(handled) + if len(missing) == 0 { + return + } + c.addNonExhaustiveWarning(stmt.Condition, first.path, missing) +} + +func (c *exhaustivenessChecker) checkSelectCaseChain(stmt *ast.IfStmt, checks []discriminantCheck) { + first := checks[0] + if first.field != "channel" || first.objectPath.IsEmpty() { + return + } + domain, ok := c.selectCases[pathKey(first.objectPath)] + if !ok || len(domain.cases) < 2 { + return + } + + handled := make(map[string]struct{}, len(checks)) + for _, check := range checks { + if check.valuePath.IsEmpty() { + return + } + key := pathKey(check.valuePath) + if !domain.contains(key) { + return + } + handled[key] = struct{}{} + } + + var missing []string + for _, candidate := range domain.cases { + if _, ok := handled[pathKey(candidate.path)]; !ok { + missing = append(missing, candidate.name) + } + } + if len(missing) == 0 { + return + } + c.addNonExhaustiveNamesWarning(stmt.Condition, first.path, missing) +} + +func (c *exhaustivenessChecker) 
collectDiscriminantChain(stmt *ast.IfStmt) ([]discriminantCheck, bool, bool) { + var checks []discriminantCheck + current := stmt + for current != nil { + check, ok := c.discriminantCheck(current.Condition) + if !ok { + return nil, false, false + } + checks = append(checks, check) + + next, ok := singleElseIf(current) + if !ok { + return checks, len(current.Else) > 0, true + } + current = next + } + return checks, false, true +} + +func (c *exhaustivenessChecker) discriminantCheck(condition ast.Expr) (discriminantCheck, bool) { + point, ok := c.branchPoint[condition] + if !ok { + return discriminantCheck{}, false + } + + check, ok := equalityDiscriminantCheck(condition) + if !ok { + return discriminantCheck{}, false + } + check.condition = condition + check.point = point + if check.object != nil && c.bindings != nil { + check.objectPath = flowpath.FromExprWithBindingsAt(check.object, nil, c.bindings, c.graph, point) + } + if check.value != nil && c.bindings != nil { + check.valuePath = flowpath.FromExprWithBindingsAt(check.value, nil, c.bindings, c.graph, point) + } + return check, true +} + +func (c *exhaustivenessChecker) addNonExhaustiveWarning(node ast.Expr, path string, missing []*typ.Literal) { + pos := diag.Position{File: c.sourceName} + span := diag.Span{} + if node != nil { + pos.Line = node.Line() + pos.Column = node.Column() + span = ast.SpanOf(node) + } + message := fmt.Sprintf("non-exhaustive match on %s; missing %s", path, formatMissingCases(missing)) + c.diags = append(c.diags, diag.Diagnostic{ + Severity: diag.SeverityWarning, + Code: diag.ErrNonExhaustive, + Position: pos, + Span: span, + Message: message, + Explanation: diag.ErrNonExhaustive.Info().Explanation, + Help: "Handle the missing case or add an else branch.", + }) +} + +func (c *exhaustivenessChecker) addNonExhaustiveNamesWarning(node ast.Expr, path string, missing []string) { + pos := diag.Position{File: c.sourceName} + span := diag.Span{} + if node != nil { + pos.Line = node.Line() + 
pos.Column = node.Column() + span = ast.SpanOf(node) + } + message := fmt.Sprintf("non-exhaustive match on %s; missing %s", path, formatMissingNames(missing)) + c.diags = append(c.diags, diag.Diagnostic{ + Severity: diag.SeverityWarning, + Code: diag.ErrNonExhaustive, + Position: pos, + Span: span, + Message: message, + Explanation: diag.ErrNonExhaustive.Info().Explanation, + Help: "Handle the missing case or add an else branch.", + }) +} + +func singleElseIf(stmt *ast.IfStmt) (*ast.IfStmt, bool) { + if stmt == nil || len(stmt.Else) != 1 { + return nil, false + } + next, ok := stmt.Else[0].(*ast.IfStmt) + return next, ok +} + +func equalityDiscriminantCheck(expr ast.Expr) (discriminantCheck, bool) { + rel, ok := expr.(*ast.RelationalOpExpr) + if !ok || rel.Operator != "==" { + return discriminantCheck{}, false + } + if check, ok := attrEqualsLiteral(rel.Lhs, rel.Rhs); ok { + return check, true + } + if check, ok := attrEqualsLiteral(rel.Rhs, rel.Lhs); ok { + return check, true + } + if check, ok := attrEqualsPath(rel.Lhs, rel.Rhs); ok { + return check, true + } + return attrEqualsPath(rel.Rhs, rel.Lhs) +} + +func attrEqualsLiteral(attrExpr, literalExpr ast.Expr) (discriminantCheck, bool) { + lit, ok := literalFromExpr(literalExpr) + if !ok { + return discriminantCheck{}, false + } + attr, ok := attrExpr.(*ast.AttrGetExpr) + if !ok { + return discriminantCheck{}, false + } + field := ast.KeyName(attr.Key) + if field == "" { + return discriminantCheck{}, false + } + objectPath, ok := exprPath(attr.Object) + if !ok { + return discriminantCheck{}, false + } + return discriminantCheck{ + object: attr.Object, + field: field, + path: objectPath + "." 
+ field, + literal: lit, + }, true +} + +func attrEqualsPath(attrExpr, valueExpr ast.Expr) (discriminantCheck, bool) { + attr, ok := attrExpr.(*ast.AttrGetExpr) + if !ok { + return discriminantCheck{}, false + } + field := ast.KeyName(attr.Key) + if field == "" { + return discriminantCheck{}, false + } + objectPath, ok := exprPath(attr.Object) + if !ok { + return discriminantCheck{}, false + } + valuePath, ok := exprPath(valueExpr) + if !ok { + return discriminantCheck{}, false + } + return discriminantCheck{ + object: attr.Object, + field: field, + path: objectPath + "." + field, + value: valueExpr, + valueName: valuePath, + }, true +} + +func literalFromExpr(expr ast.Expr) (*typ.Literal, bool) { + switch e := expr.(type) { + case *ast.StringExpr: + return typ.LiteralString(e.Value), true + case *ast.TrueExpr: + return typ.True, true + case *ast.FalseExpr: + return typ.False, true + case *ast.NumberExpr: + if i, ok := numparse.ParseIntegerLiteral(e.Value); ok { + return typ.LiteralInt(i), true + } + if f, ok := numparse.ParseFloatLiteral(e.Value); ok { + return typ.LiteralNumber(f), true + } + } + return nil, false +} + +func exprPath(expr ast.Expr) (string, bool) { + switch e := expr.(type) { + case *ast.IdentExpr: + if e.Value == "" { + return "", false + } + return e.Value, true + case *ast.AttrGetExpr: + base, ok := exprPath(e.Object) + if !ok { + return "", false + } + key := ast.KeyName(e.Key) + if key == "" { + return "", false + } + return base + "." 
+ key, true + default: + return "", false + } +} + +func formatMissingCases(missing []*typ.Literal) string { + values := make([]string, 0, len(missing)) + for _, lit := range missing { + if lit != nil { + values = append(values, lit.String()) + } + } + if len(values) == 1 { + return "case: " + values[0] + } + return "cases: " + strings.Join(values, ", ") +} + +func formatMissingNames(missing []string) string { + if len(missing) == 1 { + return "case: " + missing[0] + } + return "cases: " + strings.Join(missing, ", ") +} + +func selectCasesByResult(graph *cfg.Graph) map[string]selectCaseDomain { + if graph == nil || graph.Bindings() == nil { + return nil + } + bindings := graph.Bindings() + domains := make(map[string]selectCaseDomain) + graph.EachAssign(func(p cfg.Point, info *cfg.AssignInfo) { + target, ok := info.FirstTarget() + if !ok || target.Kind != cfg.TargetIdent || target.Symbol == 0 { + return + } + call := info.SingleSourceCall() + if !isChannelSelectCall(call) || len(call.Args) == 0 { + return + } + cases, ok := selectCaseChannels(call.Args[0], p, graph, bindings) + if !ok || len(cases) < 2 { + return + } + resultPath := constraint.Path{Root: target.Name, Symbol: target.Symbol} + if len(info.TargetVersions) > 0 && !info.TargetVersions[0].IsZero() { + resultPath.Version = info.TargetVersions[0].ID + } + domains[pathKey(resultPath)] = selectCaseDomain{cases: cases} + }) + return domains +} + +func isChannelSelectCall(call *cfg.CallInfo) bool { + if call == nil || call.Method != "" { + return false + } + attr, ok := call.Callee.(*ast.AttrGetExpr) + if !ok { + return false + } + key := ast.KeyName(attr.Key) + if key != "select" { + return false + } + root, ok := attr.Object.(*ast.IdentExpr) + return ok && root.Value == "channel" +} + +func selectCaseChannels(expr ast.Expr, p cfg.Point, graph *cfg.Graph, bindings *bind.BindingTable) ([]selectCase, bool) { + table, ok := expr.(*ast.TableExpr) + if !ok { + return nil, false + } + cases := make([]selectCase, 0, 
len(table.Fields)) + for _, field := range table.Fields { + if field == nil { + return nil, false + } + if key := ast.KeyName(field.Key); key == "default" { + return nil, false + } + call, ok := field.Value.(*ast.FuncCallExpr) + if !ok || call.Method != "case_receive" || call.Receiver == nil { + return nil, false + } + casePath := flowpath.FromExprWithBindingsAt(call.Receiver, nil, bindings, graph, p) + if casePath.IsEmpty() { + return nil, false + } + name, ok := exprPath(call.Receiver) + if !ok { + name = casePath.String() + } + cases = append(cases, selectCase{path: casePath, name: name}) + } + return cases, len(cases) > 0 +} + +func (d selectCaseDomain) contains(key string) bool { + for _, c := range d.cases { + if pathKey(c.path) == key { + return true + } + } + return false +} + +func pathKey(p constraint.Path) string { + if p.Symbol != 0 { + return fmt.Sprintf("#%d@%d%s", p.Symbol, p.Version, constraint.FormatSegments(p.Segments)) + } + return p.String() +} diff --git a/compiler/check/hooks/hooks.go b/compiler/check/hooks/hooks.go index c70ed869..213c5fef 100644 --- a/compiler/check/hooks/hooks.go +++ b/compiler/check/hooks/hooks.go @@ -19,6 +19,7 @@ // - WithCall: Argument type mismatches in function calls // - WithField: Invalid field access on types without the field // - WithControl: Unreachable code and control flow issues +// - WithExhaustiveness: Non-exhaustive discriminated union matches // - WithIdent: References to undefined identifiers // // # USAGE @@ -50,6 +51,7 @@ func All() []check.Option { WithCall(), WithField(), WithControl(), + WithExhaustiveness(), WithIdent(), } } @@ -107,6 +109,16 @@ func WithControl() check.Option { }) } +// WithExhaustiveness enables warnings for non-exhaustive discriminated union matches. 
+func WithExhaustiveness() check.Option { + return check.WithPass(func(sess *check.Session, fn *ast.FunctionExpr, result *api.FuncResult) []diag.Diagnostic { + if fn == nil || result.Graph == nil || result.NarrowSynth == nil { + return nil + } + return CheckExhaustiveness(fn, result.Graph, result.NarrowSynth.Narrow(), sess.SourceName) + }) +} + // WithIdent enables undefined identifier checking. func WithIdent() check.Option { return check.WithPass(func(sess *check.Session, _ *ast.FunctionExpr, result *api.FuncResult) []diag.Diagnostic { diff --git a/compiler/check/hooks/hooks_test.go b/compiler/check/hooks/hooks_test.go index b46c1518..2cc2a4f9 100644 --- a/compiler/check/hooks/hooks_test.go +++ b/compiler/check/hooks/hooks_test.go @@ -6,8 +6,8 @@ import ( func TestAll_ReturnsOptions(t *testing.T) { opts := All() - if len(opts) != 6 { - t.Errorf("All() returned %d options, expected 6", len(opts)) + if len(opts) != 7 { + t.Errorf("All() returned %d options, expected 7", len(opts)) } } @@ -46,6 +46,13 @@ func TestWithControl_NotNil(t *testing.T) { } } +func TestWithExhaustiveness_NotNil(t *testing.T) { + opt := WithExhaustiveness() + if opt == nil { + t.Error("WithExhaustiveness() returned nil") + } +} + func TestWithIdent_NotNil(t *testing.T) { opt := WithIdent() if opt == nil { diff --git a/compiler/check/infer/return/infer.go b/compiler/check/infer/return/infer.go index 509c0ebf..c788ce53 100644 --- a/compiler/check/infer/return/infer.go +++ b/compiler/check/infer/return/infer.go @@ -20,7 +20,8 @@ // Return type inference uses monotone union for convergence: // - New return types are joined with previous return types // - Types can only grow (become more general), never shrink -// - Bounded iteration with widening to unknown on non-convergence +// - Recursive SCCs use convergence widening, so iteration is governed by +// domain stabilization rather than an artificial budget // // # PARAMETER EVIDENCE // @@ -66,43 +67,36 @@ import ( // Config holds dependencies 
for return inference. type Config struct { - Types core.TypeOps - GlobalTypes map[string]typ.Type - Manifests io.ManifestQuerier - Stdlib *scope.State - Store api.StoreReader - Graphs api.GraphProvider - SourceName string - MaxIterations int + Types core.TypeOps + GlobalTypes map[string]typ.Type + Manifests io.ManifestQuerier + Stdlib *scope.State + Store api.StoreReader + Graphs api.GraphProvider + SourceName string } // Inferencer computes pre-flow return vectors for local functions. type Inferencer struct { - types core.TypeOps - globalTypes map[string]typ.Type - manifests io.ManifestQuerier - stdlib *scope.State - store api.StoreReader - graphs api.GraphProvider - sourceName string - maxIterations int + types core.TypeOps + globalTypes map[string]typ.Type + manifests io.ManifestQuerier + stdlib *scope.State + store api.StoreReader + graphs api.GraphProvider + sourceName string } // New creates a configured return inferencer. func New(cfg Config) *Inferencer { - maxIter := cfg.MaxIterations - if maxIter <= 0 { - maxIter = 10 - } return &Inferencer{ - types: cfg.Types, - globalTypes: cfg.GlobalTypes, - manifests: cfg.Manifests, - stdlib: cfg.Stdlib, - store: cfg.Store, - graphs: cfg.Graphs, - sourceName: cfg.SourceName, - maxIterations: maxIter, + types: cfg.Types, + globalTypes: cfg.GlobalTypes, + manifests: cfg.Manifests, + stdlib: cfg.Stdlib, + store: cfg.Store, + graphs: cfg.Graphs, + sourceName: cfg.SourceName, } } @@ -362,8 +356,8 @@ func assembleFunctionFacts( // - New types are joined with previous types via monotone union // - Iteration stops when no type changes // -// WIDENING: If SCC iteration exceeds MaxReturnSummaryIterations, types are widened -// to unknown to guarantee termination. A diagnostic is emitted for the non-convergence. +// WIDENING: Recursive SCCs merge through the convergence widening operator each +// round. The domain owns termination; callers do not cap iteration count. 
// // SEEDING: Initial return type estimates come from the seed map (previous fixpoint // iteration). This accelerates convergence for iteratively-refined modules. diff --git a/compiler/check/infer/return/scc.go b/compiler/check/infer/return/scc.go index 9d3e1fb9..3253698f 100644 --- a/compiler/check/infer/return/scc.go +++ b/compiler/check/infer/return/scc.go @@ -1,7 +1,6 @@ package infer import ( - "github.com/wippyai/go-lua/compiler/ast" "github.com/wippyai/go-lua/compiler/bind" "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/domain/paramevidence" @@ -12,21 +11,20 @@ import ( ) // iterateSCCFixpoint runs fixpoint iteration for a single SCC until convergence. -// Returns true if types stabilized within the iteration limit. +// Returns once the widened return-vector product stabilizes. func (i *Inferencer) iterateSCCFixpoint( run RunContext, scc []cfg.SymbolID, localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, returnVectors map[cfg.SymbolID][]typ.Type, ) bool { - for iter := 0; iter < i.maxIterations; iter++ { + for { next, changed := i.runSCCIteration(run, scc, localFuncs, returnVectors) applySCCIterationUpdates(returnVectors, scc, next) if !changed { return true } } - return false } func (i *Inferencer) planLocalFunctionSCCs(localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo) [][]cfg.SymbolID { @@ -75,19 +73,13 @@ func (i *Inferencer) processSCCReturnVectors( localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, returnVectors map[cfg.SymbolID][]typ.Type, ) []diag.Diagnostic { - var diags []diag.Diagnostic for _, scc := range sccs { if len(scc) == 0 { continue } - if i.iterateSCCFixpoint(run, scc, localFuncs, returnVectors) { - continue - } - if warn := i.widenSCCToUnknown(scc, localFuncs, returnVectors); warn != nil { - diags = append(diags, *warn) - } + i.iterateSCCFixpoint(run, scc, localFuncs, returnVectors) } - return diags + return nil } func (i *Inferencer) runSCCIteration( @@ -105,7 +97,7 @@ func (i *Inferencer) 
runSCCIteration( } newReturn := i.inferReturnForFunction(run, info, returnVectors, localFuncs) oldReturn := returnVectors[sym] - merged := returnsummary.Merge(oldReturn, newReturn) + merged := returnsummary.WidenForConvergence(oldReturn, newReturn) next[sym] = merged if !returnsummary.Equal(merged, oldReturn) { changed = true @@ -125,33 +117,3 @@ func applySCCIterationUpdates( } } } - -// widenSCCToUnknown widens all SCC members to unknown when fixpoint did not converge. -// Preserves return arity while replacing type slots with unknown. -func (i *Inferencer) widenSCCToUnknown( - scc []cfg.SymbolID, - localFuncs map[cfg.SymbolID]*returns.LocalFuncInfo, - returnVectors map[cfg.SymbolID][]typ.Type, -) *diag.Diagnostic { - for _, sym := range scc { - existing := returnVectors[sym] - if len(existing) == 0 { - returnVectors[sym] = []typ.Type{typ.Unknown} - } else { - widened := make([]typ.Type, len(existing)) - for i := range widened { - widened[i] = typ.Unknown - } - returnVectors[sym] = widened - } - } - if info := localFuncs[scc[0]]; info != nil && info.Fn != nil { - return &diag.Diagnostic{ - Position: diag.Position{File: i.sourceName, Line: info.Fn.Line(), Column: info.Fn.Column()}, - Span: ast.SpanOf(info.Fn), - Severity: diag.SeverityWarning, - Message: "return type fixpoint did not converge; using unknown", - } - } - return nil -} diff --git a/compiler/check/pipeline/diagnostics.go b/compiler/check/pipeline/diagnostics.go index 45084abd..2d651c53 100644 --- a/compiler/check/pipeline/diagnostics.go +++ b/compiler/check/pipeline/diagnostics.go @@ -1,7 +1,6 @@ // This package handles post-analysis diagnostic operations: // - Sorting functions by source position for deterministic pass execution // - Sorting diagnostics for stable output ordering -// - Generating widening diagnostics when type inference doesn't converge // // Deterministic ordering is essential for reproducible builds and test stability. 
// All sorting uses stable tie-breakers (graph ID, message content) to ensure @@ -9,7 +8,6 @@ package pipeline import ( - "fmt" "sort" "github.com/wippyai/go-lua/compiler/ast" @@ -116,50 +114,6 @@ func SortDiagnostics(diags []diag.Diagnostic) { }) } -// WideningDiagnostics reports symbols that were widened to unknown during preflow inference. -func WideningDiagnostics(sourceName string, fn *ast.FunctionExpr, result *api.FuncResult) []diag.Diagnostic { - if result == nil || result.FlowInputs == nil || len(result.FlowInputs.WideningEvents) == 0 { - return nil - } - - seenSCC := make(map[int]bool) - var diags []diag.Diagnostic - for _, event := range result.FlowInputs.WideningEvents { - if seenSCC[event.SCCIndex] { - continue - } - seenSCC[event.SCCIndex] = true - - symName := "" - if result.Graph != nil { - symName = result.Graph.NameOf(event.Symbol) - } - if symName == "" { - symName = "" - } - - sccSize := len(event.SCC) - msg := fmt.Sprintf("type inference did not converge for '%s' (SCC size %d); widened to unknown", symName, sccSize) - - pos := diag.Position{File: sourceName} - span := diag.Span{} - if fn != nil { - pos.Line = fn.Line() - pos.Column = fn.Column() - span = ast.SpanOf(fn) - } - - diags = append(diags, diag.Diagnostic{ - Position: pos, - Span: span, - Severity: diag.SeverityWarning, - Message: msg, - }) - } - - return diags -} - // ResolveSymbolName provides a stable name for diagnostics when CFG data is available. 
func ResolveSymbolName(graph *cfg.Graph, sym cfg.SymbolID) string { if graph == nil { diff --git a/compiler/check/pipeline/diagnostics_test.go b/compiler/check/pipeline/diagnostics_test.go index 0ff0e6a3..e0b6c1c1 100644 --- a/compiler/check/pipeline/diagnostics_test.go +++ b/compiler/check/pipeline/diagnostics_test.go @@ -7,7 +7,6 @@ import ( "github.com/wippyai/go-lua/compiler/cfg" "github.com/wippyai/go-lua/compiler/check/api" "github.com/wippyai/go-lua/types/diag" - "github.com/wippyai/go-lua/types/flow" ) func TestSortedResultFunctions_Empty(t *testing.T) { @@ -155,71 +154,6 @@ func TestSortDiagnostics_ByMessage(t *testing.T) { } } -func TestWideningDiagnostics_NilResult(t *testing.T) { - result := WideningDiagnostics("test.lua", nil, nil) - if result != nil { - t.Error("expected nil for nil result") - } -} - -func TestWideningDiagnostics_NilFlowInputs(t *testing.T) { - result := WideningDiagnostics("test.lua", nil, &api.FuncResult{}) - if result != nil { - t.Error("expected nil for nil flow inputs") - } -} - -func TestWideningDiagnostics_NoEvents(t *testing.T) { - result := WideningDiagnostics("test.lua", nil, &api.FuncResult{ - FlowInputs: &flow.Inputs{}, - }) - if result != nil { - t.Error("expected nil for no widening events") - } -} - -func TestWideningDiagnostics_WithEvents(t *testing.T) { - fn := &ast.FunctionExpr{} - fn.SetLine(10) - fn.SetColumn(5) - - result := WideningDiagnostics("test.lua", fn, &api.FuncResult{ - FlowInputs: &flow.Inputs{ - WideningEvents: []flow.WideningEvent{ - {Symbol: 1, SCCIndex: 0, SCC: []cfg.SymbolID{1, 2}}, - }, - }, - }) - - if len(result) != 1 { - t.Fatalf("expected 1 diagnostic, got %d", len(result)) - } - if result[0].Position.File != "test.lua" { - t.Error("wrong file") - } - if result[0].Position.Line != 10 { - t.Error("wrong line") - } - if result[0].Severity != diag.SeverityWarning { - t.Error("expected warning severity") - } -} - -func TestWideningDiagnostics_DeduplicatesSCC(t *testing.T) { - result := 
WideningDiagnostics("test.lua", nil, &api.FuncResult{ - FlowInputs: &flow.Inputs{ - WideningEvents: []flow.WideningEvent{ - {Symbol: 1, SCCIndex: 0, SCC: []cfg.SymbolID{1, 2}}, - {Symbol: 2, SCCIndex: 0, SCC: []cfg.SymbolID{1, 2}}, - }, - }, - }) - - if len(result) != 1 { - t.Errorf("expected 1 diagnostic (deduplicated), got %d", len(result)) - } -} - func TestResolveSymbolName_NilGraph(t *testing.T) { name := ResolveSymbolName(nil, 1) if name != "" { diff --git a/compiler/check/pipeline/driver.go b/compiler/check/pipeline/driver.go index 96c48b9b..e46fb8cd 100644 --- a/compiler/check/pipeline/driver.go +++ b/compiler/check/pipeline/driver.go @@ -7,7 +7,7 @@ // 3. Execute the memoized function analysis pipeline // 4. Propagate effects and interprocedural facts // 5. Process nested functions recursively -// 6. Repeat until fixpoint (no channel changes) or max iterations +// 6. Repeat until fixpoint (no channel changes) // // The driver coordinates several inference subsystems: // - Return inference: Computes return types for local functions @@ -28,7 +28,6 @@ import ( nestedinfer "github.com/wippyai/go-lua/compiler/check/infer/nested" returninfer "github.com/wippyai/go-lua/compiler/check/infer/return" "github.com/wippyai/go-lua/compiler/check/modules" - "github.com/wippyai/go-lua/compiler/check/returns" "github.com/wippyai/go-lua/compiler/check/scope" "github.com/wippyai/go-lua/types/constraint" "github.com/wippyai/go-lua/types/db" @@ -44,7 +43,6 @@ type Config struct { GlobalTypes map[string]typ.Type Stdlib *scope.State Manifests *db.DB - MaxIterations int MaxScopeDepth int EmitScopeDiag bool FuncResultQ *db.Query[api.FuncKey, *api.FuncResult] @@ -93,36 +91,12 @@ func (d *Driver) Run(sess api.AnalysisSession, chunk []ast.Stmt) { } func (d *Driver) runFixpoint(sess api.AnalysisSession, fn *ast.FunctionExpr, parent *scope.State) { - maxIterations := d.cfg.MaxIterations - if maxIterations < 1 { - maxIterations = 1 - } - - converged := false - for iter := 0; iter < 
maxIterations; iter++ { + for { d.prepareIterationState(sess) d.checkFunctionFixpoint(sess, fn, parent) if d.advanceFixpoint(sess.StoreHandle()) { - converged = true - break - } - } - - if !converged { - store := sess.StoreHandle() - diffs := []string(nil) - if store != nil { - diffs = store.FixpointChannelDiffs() - } - msg := "inter-function fixpoint did not converge" - if len(diffs) > 0 { - msg += "; unstable channels: " + fmt.Sprintf("%v", diffs) + return } - sess.AppendDiagnostics(diag.Diagnostic{ - Position: diag.Position{File: sess.Source()}, - Severity: diag.SeverityWarning, - Message: msg, - }) } } @@ -139,10 +113,7 @@ func (d *Driver) advanceFixpoint(store api.IterationStore) bool { if store == nil { return true } - if !store.FixpointSwap() { - return true - } - return false + return !store.FixpointSwap() } func (d *Driver) checkFunctionFixpoint(sess api.AnalysisSession, fn *ast.FunctionExpr, parent *scope.State) { @@ -220,14 +191,13 @@ func (d *Driver) runReturnInference( } inferencer := returninfer.New(returninfer.Config{ - Types: d.cfg.Types, - GlobalTypes: d.cfg.GlobalTypes, - Manifests: d.cfg.Manifests, - Stdlib: d.cfg.Stdlib, - Store: store, - Graphs: sess, - SourceName: sess.Source(), - MaxIterations: returns.MaxReturnSummaryIterations, + Types: d.cfg.Types, + GlobalTypes: d.cfg.GlobalTypes, + Manifests: d.cfg.Manifests, + Stdlib: d.cfg.Stdlib, + Store: store, + Graphs: sess, + SourceName: sess.Source(), }) var refinementLookup constraint.RefinementLookupBySym diff --git a/compiler/check/pipeline/driver_test.go b/compiler/check/pipeline/driver_test.go index 4047f923..04d05f17 100644 --- a/compiler/check/pipeline/driver_test.go +++ b/compiler/check/pipeline/driver_test.go @@ -8,15 +8,11 @@ import ( func TestNew(t *testing.T) { d := New(Config{ - MaxIterations: 5, MaxScopeDepth: 10, }) if d == nil { t.Fatal("expected non-nil driver") } - if d.cfg.MaxIterations != 5 { - t.Error("MaxIterations not set") - } if d.cfg.MaxScopeDepth != 10 { 
t.Error("MaxScopeDepth not set") } @@ -29,14 +25,10 @@ func TestDriver_Run_NilSession(t *testing.T) { func TestConfig_Fields(t *testing.T) { cfg := Config{ - MaxIterations: 3, MaxScopeDepth: 8, EmitScopeDiag: true, GlobalTypes: map[string]typ.Type{"foo": typ.String}, } - if cfg.MaxIterations != 3 { - t.Error("MaxIterations not set") - } if cfg.MaxScopeDepth != 8 { t.Error("MaxScopeDepth not set") } diff --git a/compiler/check/returns/callgraph_test.go b/compiler/check/returns/callgraph_test.go index 4b9f06d4..5dc21795 100644 --- a/compiler/check/returns/callgraph_test.go +++ b/compiler/check/returns/callgraph_test.go @@ -195,15 +195,6 @@ func TestLocalFuncInfo_ParameterEvidenceExpansion(t *testing.T) { } } -func TestMaxReturnSummaryIterations_Value(t *testing.T) { - if MaxReturnSummaryIterations < 1 { - t.Error("MaxReturnSummaryIterations should be positive") - } - if MaxReturnSummaryIterations > 100 { - t.Error("MaxReturnSummaryIterations seems too high") - } -} - func TestBuildLocalCallGraph_AddsCallbackFunctionEdges(t *testing.T) { stmts, err := parse.ParseString(` local function wrapper(cb: fun(): number): number diff --git a/compiler/check/returns/types.go b/compiler/check/returns/types.go index 59a0b117..39d11f22 100644 --- a/compiler/check/returns/types.go +++ b/compiler/check/returns/types.go @@ -64,7 +64,3 @@ type LocalFuncInfo struct { // parent graph. For methods, index 0 is self. ParameterEvidence []typ.Type } - -// MaxReturnSummaryIterations limits fixpoint iterations for return-vector inference. -// Exceeding this indicates a bug (non-monotonic merge) or pathological recursion. 
-const MaxReturnSummaryIterations = 10 diff --git a/compiler/check/returns/types_test.go b/compiler/check/returns/types_test.go index 78ca76d8..5cfe6b81 100644 --- a/compiler/check/returns/types_test.go +++ b/compiler/check/returns/types_test.go @@ -57,17 +57,3 @@ func TestLocalFuncInfoStructure(t *testing.T) { } }) } - -func TestMaxReturnSummaryIterations(t *testing.T) { - t.Run("constant value", func(t *testing.T) { - if MaxReturnSummaryIterations != 10 { - t.Fatalf("expected MaxReturnSummaryIterations=10, got %d", MaxReturnSummaryIterations) - } - }) - - t.Run("constant is positive", func(t *testing.T) { - if MaxReturnSummaryIterations <= 0 { - t.Fatal("expected positive constant") - } - }) -} diff --git a/compiler/check/tests/flow/preflow_convergence_test.go b/compiler/check/tests/flow/preflow_convergence_test.go index 2851381e..0d0f6a86 100644 --- a/compiler/check/tests/flow/preflow_convergence_test.go +++ b/compiler/check/tests/flow/preflow_convergence_test.go @@ -214,11 +214,7 @@ local c: number = b } } -// TestPreflowConvergence_WideningSoundness tests that when an SCC doesn't converge, -// ALL members are widened to unknown, not just missing entries. -func TestPreflowConvergence_WideningSoundness(t *testing.T) { - // This test verifies that partial types don't leak through when widening occurs. - // The key property is that if widening triggers, all affected symbols get unknown. +func TestPreflowConvergence_RecursiveRecordCycleConverges(t *testing.T) { source := ` local a = {x = 1} local b = {y = a} @@ -236,10 +232,7 @@ local n: number = a.x } } -// TestPreflowConvergence_WideningReported tests that widening events are recorded. -func TestPreflowConvergence_WideningReported(t *testing.T) { - // Create a case that triggers widening by exceeding max iterations. - // Deeply recursive mutual dependencies that don't stabilize quickly. 
+func TestPreflowConvergence_RecursiveFunctionCycleConverges(t *testing.T) { source := ` local a, b, c, d, e @@ -251,66 +244,10 @@ e = function() return a() end ` result := testutil.Check(source, testutil.WithStdlib()) - - // Access widening events from flow inputs - if result.Session == nil || result.Session.RootResult == nil { - t.Fatal("expected session with root result") - } - - inputs := result.Session.RootResult.FlowInputs - if inputs == nil { - t.Fatal("expected flow inputs") - } - - // Even if no widening occurs in this simple case, verify the field exists - // and the API works. A true non-converging case is hard to construct - // without artificial iteration limits. - t.Logf("widening events count: %d", len(inputs.WideningEvents)) -} - -// TestPreflowConvergence_WideningDiagnosticEmitted tests that widening diagnostics -// are emitted when preflow inference doesn't converge. -func TestPreflowConvergence_WideningDiagnosticEmitted(t *testing.T) { - // This test verifies the diagnostic plumbing works. - // Note: Most real code converges within the iteration limit, - // so widening diagnostics are rare in practice. 
- source := ` -local a, b, c, d, e - -a = function() return b() end -b = function() return c() end -c = function() return d() end -d = function() return e() end -e = function() return a() end -` - - result := testutil.Check(source, testutil.WithStdlib()) - - if result.Session == nil || result.Session.RootResult == nil { - t.Fatal("expected session with root result") - } - - // Count widening diagnostics (if any) - wideningDiagCount := 0 - for _, d := range result.Session.Diagnostics { - if d.Severity == diag.SeverityWarning { - if len(d.Message) > 0 && (contains(d.Message, "widened to unknown") || contains(d.Message, "type inference did not converge")) { - wideningDiagCount++ - t.Logf("Widening diagnostic: %s", d.Message) - } - } - } - - // Log whether widening occurred - inputs := result.Session.RootResult.FlowInputs - if inputs != nil { - t.Logf("widening events: %d, widening diagnostics: %d", len(inputs.WideningEvents), wideningDiagCount) - - // If widening events occurred, diagnostics should be emitted - if len(inputs.WideningEvents) > 0 && wideningDiagCount == 0 { - t.Error("widening events occurred but no diagnostics were emitted") - } + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) } + assertNoConvergenceWarnings(t, result.Diagnostics) } // TestPreflowConvergence_MapEntryFallbackCounters_NoWarnings reproduces @@ -354,15 +291,27 @@ mark_passed("suite:a") } for _, d := range result.Diagnostics { - if d.Severity != diag.SeverityWarning { - continue + if isConvergenceWarning(d) { + t.Fatalf("unexpected convergence warning: %q", d.Message) } - if contains(d.Message, "type inference did not converge") || d.Message == "inter-function fixpoint did not converge" { + } +} + +func assertNoConvergenceWarnings(t *testing.T, diags []diag.Diagnostic) { + t.Helper() + for _, d := range diags { + if isConvergenceWarning(d) { t.Fatalf("unexpected convergence warning: %q", d.Message) } } } +func 
isConvergenceWarning(d diag.Diagnostic) bool { + return d.Severity == diag.SeverityWarning && + (contains(d.Message, "type inference did not converge") || + contains(d.Message, "inter-function fixpoint did not converge")) +} + // contains is a simple substring check helper. func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(substr) == 0 || diff --git a/compiler/check/tests/regression/exhaustiveness_warning_test.go b/compiler/check/tests/regression/exhaustiveness_warning_test.go new file mode 100644 index 00000000..f48bee8c --- /dev/null +++ b/compiler/check/tests/regression/exhaustiveness_warning_test.go @@ -0,0 +1,244 @@ +package regression + +import ( + "strings" + "testing" + + "github.com/wippyai/go-lua/compiler/check/tests/testutil" + "github.com/wippyai/go-lua/types/diag" +) + +func TestExhaustivenessWarning_DiscriminatedUnionMissingVariant(t *testing.T) { + source := ` +type Message = {kind: "message", text: string} +type Tool = {kind: "tool", name: string} +type Timeout = {kind: "timeout", at: number} +type Event = Message | Tool | Timeout + +local function render(event: Event): string + if event.kind == "message" then + return event.text + elseif event.kind == "tool" then + return event.name + end + return "unknown" +end + +return render({kind = "message", text = "hi"}) +` + + result := testutil.Check(source) + if result.HasError() { + t.Fatalf("expected no errors, got: %v", testutil.ErrorMessages(result.Diagnostics)) + } + assertNonExhaustiveWarning(t, result.Diagnostics, "event.kind", `"timeout"`) +} + +func TestExhaustivenessWarning_ChannelSelectMissingCase(t *testing.T) { + source := ` +type Event = {kind: string} +type Stop = {reason: string} +type Time = {sec: number, nsec: number} + +local function handle(events_ch: Channel, stop_ch: Channel, timeout_ch: Channel