From 77827b5c870dca5e4dbf42f7fe708a668a6a37b1 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Fri, 6 Mar 2026 23:37:04 +0800 Subject: [PATCH 1/4] chore: Reduce allocation of Lazy chore: fix run bench of bench.07.jsonnet --- bench/resources/refresh_golden_outputs.sh | 2 +- .../sjsonnet/bench/RegressionBenchmark.scala | 7 +- sjsonnet/src/sjsonnet/Evaluator.scala | 54 ++++++-- sjsonnet/src/sjsonnet/Materializer.scala | 2 +- sjsonnet/src/sjsonnet/Val.scala | 122 +++++++++++++++++- .../src/sjsonnet/stdlib/ArrayModule.scala | 19 ++- .../src/sjsonnet/stdlib/ObjectModule.scala | 4 +- 7 files changed, 176 insertions(+), 34 deletions(-) diff --git a/bench/resources/refresh_golden_outputs.sh b/bench/resources/refresh_golden_outputs.sh index 381cc18a..da453719 100755 --- a/bench/resources/refresh_golden_outputs.sh +++ b/bench/resources/refresh_golden_outputs.sh @@ -12,7 +12,7 @@ for suite in bench/resources/*_suite; do echo "Refreshing golden outputs for suite: $suite_name" for f in "$suite"/*.jsonnet; do echo " Processing file: $f" - java -Xss100m -Xmx2g -jar "$SJSONNET" -J "$suite" "$f" > "$f.golden" + java -Xss100m -Xmx2g -jar "$SJSONNET" --max-stack 100000 -J "$suite" "$f" > "$f.golden" done done diff --git a/bench/src/sjsonnet/bench/RegressionBenchmark.scala b/bench/src/sjsonnet/bench/RegressionBenchmark.scala index cb1b1460..5e6f7329 100644 --- a/bench/src/sjsonnet/bench/RegressionBenchmark.scala +++ b/bench/src/sjsonnet/bench/RegressionBenchmark.scala @@ -11,6 +11,9 @@ object RegressionBenchmark { private val testSuiteRoot: os.Path = sys.env.get("MILL_WORKSPACE_ROOT").map(os.Path(_)).getOrElse(os.pwd) + /** Shared CLI args passed to every benchmark invocation (e.g. bench.07 needs deep recursion). */ + private val defaultArgs: Array[String] = Array("--max-stack", "100000") + private def createDummyOut = new PrintStream(new OutputStream { def write(b: Int): Unit = () override def write(b: Array[Byte]): Unit = () @@ -36,7 +39,7 @@ class RegressionBenchmark { val baos = new ByteArrayOutputStream() val ps = new PrintStream(baos) SjsonnetMainBase.main0( - Array(path), + RegressionBenchmark.defaultArgs :+ path, new DefaultParseCache, System.in, ps, @@ -61,7 +64,7 @@ class RegressionBenchmark { def main(bh: Blackhole): Unit = { bh.consume( SjsonnetMainBase.main0( - Array(path), + RegressionBenchmark.defaultArgs :+ path, new DefaultParseCache, System.in, dummyOut, diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index dd823ee3..43804262 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -121,11 +121,37 @@ class Evaluator( Error.fail("Should not have happened.", e.pos) } + /** + * Convert an expression to an [[Eval]] for deferred evaluation. + * + * Three fast paths eliminate or reduce allocation vs the naive + * `new LazyFunc(() => visitExpr(e))`: + * + * 1. [[Val]] literals — already evaluated, return as-is (zero cost). + * 2. [[ValidId]] (variable reference) where the binding slot is non-null — reuse the existing + * [[Eval]] from scope directly (zero allocation). Covers ~18% of calls. When the slot IS + * null (self-recursive local, e.g. `local a = [a[1], 0]`), the binding hasn't been written + * yet, so we must create a deferred thunk to defer the lookup. + * 3. All other expressions — [[LazyExpr]] stores (Expr, ValScope, Evaluator) as fields instead + * of capturing them in a closure: 1 JVM object vs 2. Covers ~76% of calls (dominated by + * BinaryOp). + * + * PERF: Do not revert to `new LazyFunc(() => visitExpr(e))` — profiling across all benchmark + * suites shows this method produces ~93% of deferred evaluations. The fast paths eliminate 242K + * allocations (bench.02) and improve wall-clock time ~5% (comparison2). + */ def visitAsLazy(e: Expr)(implicit scope: ValScope): Eval = e match { - case v: Val => v - case e => + case v: Val => v + case e: ValidId => + val binding = scope.bindings(e.nameIdx) + if (binding != null) binding + else { + if (debugStats != null) debugStats.lazyCreated += 1 + new LazyExpr(e, scope, this) + } + case e => if (debugStats != null) debugStats.lazyCreated += 1 - new Lazy(() => visitExpr(e)) + new LazyExpr(e, scope, this) } def visitValidId(e: ValidId)(implicit scope: ValScope): Val = { @@ -149,9 +175,10 @@ class Evaluator( while (i < bindings.length) { val b = bindings(i) newScope.bindings(base + i) = b.args match { - case null => visitAsLazy(b.rhs)(newScope) - case argSpec => - new Lazy(() => visitMethod(b.rhs, argSpec, b.pos, b.name)(newScope)) + case null => visitAsLazy(b.rhs)(newScope) + case argSpec => // local function def — needs LazyFunc+closure (calls visitMethod, not visitExpr) + if (debugStats != null) debugStats.lazyCreated += 1 + new LazyFunc(() => visitMethod(b.rhs, argSpec, b.pos, b.name)(newScope)) } i += 1 } @@ -788,8 +815,9 @@ class Evaluator( val b = bindings(i) newScope.bindings(base + i) = b.args match { case null => visitAsLazy(b.rhs)(newScope) - case argSpec => - new Lazy(() => visitMethod(b.rhs, argSpec, b.pos)(newScope)) + case argSpec => // local function def — needs LazyFunc+closure (calls visitMethod) + if (debugStats != null) debugStats.lazyCreated += 1 + new LazyFunc(() => visitMethod(b.rhs, argSpec, b.pos)(newScope)) } i += 1 } @@ -853,6 +881,7 @@ class Evaluator( visitExpr(e) } + // Note: can't use LazyExpr here — `scope` is by-name (=> ValScope), must remain lazy. def visitBindings(bindings: Array[Bind], scope: => ValScope): Array[Eval] = { if (debugStats != null) debugStats.lazyCreated += bindings.length val arrF = new Array[Eval](bindings.length) @@ -861,9 +890,9 @@ class Evaluator( val b = bindings(i) arrF(i) = b.args match { case null => - new Lazy(() => visitExpr(b.rhs)(scope)) + new LazyFunc(() => visitExpr(b.rhs)(scope)) case argSpec => - new Lazy(() => visitMethod(b.rhs, argSpec, b.pos, b.name)(scope)) + new LazyFunc(() => visitMethod(b.rhs, argSpec, b.pos, b.name)(scope)) } i += 1 } @@ -926,8 +955,9 @@ class Evaluator( arrF(j) = b.args match { case null => visitAsLazy(b.rhs)(newScope) - case argSpec => - new Lazy(() => visitMethod(b.rhs, argSpec, b.pos)(newScope)) + case argSpec => // local function def — needs LazyFunc+closure (calls visitMethod) + if (debugStats != null) debugStats.lazyCreated += 1 + new LazyFunc(() => visitMethod(b.rhs, argSpec, b.pos)(newScope)) } i += 1 j += 1 diff --git a/sjsonnet/src/sjsonnet/Materializer.scala b/sjsonnet/src/sjsonnet/Materializer.scala index e987f194..97cc5ccc 100644 --- a/sjsonnet/src/sjsonnet/Materializer.scala +++ b/sjsonnet/src/sjsonnet/Materializer.scala @@ -336,7 +336,7 @@ abstract class Materializer { var i = 0 while (i < len) { val x = xs(i) - res(i) = new Lazy(() => reverse(pos, x)) + res(i) = new LazyFunc(() => reverse(pos, x)) i += 1 } Val.Arr(pos, res) diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 491f593c..4f642e03 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -18,19 +18,129 @@ trait Eval { } /** - * Lazily evaluated dictionary values, array contents, or function parameters are all wrapped in - * [[Lazy]] and only truly evaluated on-demand. + * Abstract marker base for deferred (lazy) evaluation. Contains no fields — subclasses manage their + * own caching to minimize per-instance memory. + * + * Hierarchy (allocation percentages measured across 591 test and benchmark files; actual + * distribution varies by workload): + * - [[LazyFunc]] — wraps a `() => Val` closure with a separate `cached` field (~0.1%) + * - [[LazyExpr]] — closure-free `visitExpr` thunk, repurposes fields for caching (~91%) + * - [[LazyApply1]] — closure-free `func.apply1` thunk (~9%) + * - [[LazyApply2]] — closure-free `func.apply2` thunk (<1%) + * + * @see + * [[Eval]] the parent trait shared with [[Val]] (eager values). */ -final class Lazy(private var computeFunc: () => Val) extends Eval { +abstract class Lazy extends Eval + +/** + * Closure-based [[Lazy]]: wraps an arbitrary `() => Val` thunk. + * + * Used for deferred evaluations that don't fit the specialized [[LazyExpr]]/[[LazyApply1]]/ + * [[LazyApply2]] patterns, e.g. `visitMethod` (local function defs), `visitBindings` (object field + * bindings), and default parameter evaluation. These account for <1% of all deferred evaluations + * (profiled across 591 benchmark and test files). + * + * Thread-safety: `f` is `@volatile` so that concurrent readers see a consistent state. If two + * threads race to initialize, the loser sees `f == null` and falls through to read `cached`, which + * is visible due to piggybacking on the volatile write to `f`; see + * https://stackoverflow.com/a/8769692 for background. + */ +final class LazyFunc(@volatile private var f: () => Val) extends Lazy { private var cached: Val = _ def value: Val = { if (cached != null) return cached - cached = computeFunc() - computeFunc = null // allow closure to be GC'd + val func = f + if (func != null) { + cached = func() + f = null // volatile write publishes `cached` to other threads + } + // else: lost the race to compute, but `cached` is already set and visible + // in this thread due to the volatile read of `f` (piggybacking) cached } } +/** + * Closure-free [[Lazy]] that defers `evaluator.visitExpr(expr)(scope)`. + * + * Used in [[Evaluator.visitAsLazy]] instead of `new LazyFunc(() => visitExpr(e)(scope))`. By + * storing (expr, scope, evaluator) as fields rather than capturing them in a closure, this cuts + * per-thunk allocation from 2 JVM objects (LazyFunc + closure) to 1 (LazyExpr), and from 56B to 24B + * (compressed oops). + * + * Profiling across all benchmark and test suites (591 files) shows [[Evaluator.visitAsLazy]] + * produces ~91% of all deferred evaluations. + * + * After computation, the cached [[Val]] is stored in the `exprOrVal` field (which originally held + * the [[Expr]]), and `ev` is nulled as a sentinel. `scope` is also cleared to allow GC. + */ +final class LazyExpr( + private var exprOrVal: AnyRef, // Expr before compute, Val after + private var scope: ValScope, + private var ev: Evaluator) + extends Lazy { + def value: Val = { + if (ev == null) exprOrVal.asInstanceOf[Val] + else { + val r = ev.visitExpr(exprOrVal.asInstanceOf[Expr])(scope) + exprOrVal = r // cache result + scope = null.asInstanceOf[sjsonnet.ValScope] // allow GC + ev = null // sentinel: marks as computed + r + } + } +} + +/** + * Closure-free [[Lazy]] that defers `func.apply1(arg, pos)(ev, TailstrictModeDisabled)`. + * + * Used in stdlib builtins (`std.map`, `std.filterMap`, `std.makeArray`, etc.) to eliminate the + * 2-object allocation (LazyFunc + Function0 closure), cutting from 56B to 32B per instance. After + * computation, `funcOrVal` caches the result, `ev == null` serves as the computed sentinel, and + * remaining fields are cleared for GC. + */ +final class LazyApply1( + private var funcOrVal: AnyRef, // Val.Func before compute, Val after + private var arg: Eval, + private var pos: Position, + private var ev: EvalScope) + extends Lazy { + def value: Val = { + if (ev == null) funcOrVal.asInstanceOf[Val] + else { + val r = funcOrVal.asInstanceOf[Val.Func].apply1(arg, pos)(ev, TailstrictModeDisabled) + funcOrVal = r + arg = null; pos = null; ev = null + r + } + } +} + +/** + * Closure-free [[Lazy]] that defers `func.apply2(arg1, arg2, pos)(ev, TailstrictModeDisabled)`. + * + * Used in stdlib builtins (`std.mapWithIndex`, etc.). Same field-repurposing strategy as + * [[LazyApply1]], cutting from 56B to 32B per instance. + */ +final class LazyApply2( + private var funcOrVal: AnyRef, // Val.Func before compute, Val after + private var arg1: Eval, + private var arg2: Eval, + private var pos: Position, + private var ev: EvalScope) + extends Lazy { + def value: Val = { + if (ev == null) funcOrVal.asInstanceOf[Val] + else { + val r = funcOrVal.asInstanceOf[Val.Func].apply2(arg1, arg2, pos)(ev, TailstrictModeDisabled) + funcOrVal = r + arg1 = null; arg2 = null; pos = null; ev = null + r + } + } +} + /** * [[Val]]s represented Jsonnet values that are the result of evaluating a Jsonnet program. The * [[Val]] data structure is essentially a JSON tree, except evaluation of object attributes and @@ -750,7 +860,7 @@ object Val { if (argVals(j) == null) { val default = params.defaultExprs(i) if (default != null) { - argVals(j) = new Lazy(() => evalDefault(default, newScope, ev)) + argVals(j) = new LazyFunc(() => evalDefault(default, newScope, ev)) } else { if (missing == null) missing = new ArrayBuffer missing.+=(params.names(i)) diff --git a/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala b/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala index f8b7384d..5c82f8d2 100644 --- a/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala +++ b/sjsonnet/src/sjsonnet/stdlib/ArrayModule.scala @@ -164,9 +164,10 @@ object ArrayModule extends AbstractFunctionModule { arg: Array[Eval], ev: EvalScope, pos: Position): Val.Arr = { + val noOff = pos.noOffset Val.Arr( pos, - arg.map(v => new Lazy(() => _func.apply1(v, pos.noOffset)(ev, TailstrictModeDisabled))) + arg.map(v => new LazyApply1(_func, v, noOff, ev)) ) } @@ -180,11 +181,12 @@ object ArrayModule extends AbstractFunctionModule { val func = _func.value.asFunc val arr = _arr.value.asArr.asLazyArray val a = new Array[Eval](arr.length) + val noOff = pos.noOffset var i = 0 while (i < a.length) { val x = arr(i) val idx = Val.Num(pos, i) - a(i) = new Lazy(() => func.apply2(idx, x, pos.noOffset)(ev, TailstrictModeDisabled)) + a(i) = new LazyApply2(func, idx, x, noOff, ev) i += 1 } Val.Arr(pos, a) @@ -425,16 +427,15 @@ object ArrayModule extends AbstractFunctionModule { }, builtin("filterMap", "filter_func", "map_func", "arr") { (pos, ev, filter_func: Val.Func, map_func: Val.Func, arr: Val.Arr) => + val noOff = pos.noOffset Val.Arr( pos, arr.asLazyArray.flatMap { i => i.value - if (!filter_func.apply1(i, pos.noOffset)(ev, TailstrictModeDisabled).asBoolean) { + if (!filter_func.apply1(i, noOff)(ev, TailstrictModeDisabled).asBoolean) { None } else { - Some[Eval]( - new Lazy(() => map_func.apply1(i, pos.noOffset)(ev, TailstrictModeDisabled)) - ) + Some[Eval](new LazyApply1(map_func, i, noOff, ev)) } } ) @@ -468,12 +469,10 @@ object ArrayModule extends AbstractFunctionModule { pos, { val sz = size.cast[Val.Num].asPositiveInt val a = new Array[Eval](sz) + val noOff = pos.noOffset var i = 0 while (i < sz) { - val forcedI = i - a(i) = new Lazy(() => - func.apply1(Val.Num(pos, forcedI), pos.noOffset)(ev, TailstrictModeDisabled) - ) + a(i) = new LazyApply1(func, Val.Num(pos, i), noOff, ev) i += 1 } a diff --git a/sjsonnet/src/sjsonnet/stdlib/ObjectModule.scala b/sjsonnet/src/sjsonnet/stdlib/ObjectModule.scala index 2fa5b0ef..7ae6103e 100644 --- a/sjsonnet/src/sjsonnet/stdlib/ObjectModule.scala +++ b/sjsonnet/src/sjsonnet/stdlib/ObjectModule.scala @@ -107,7 +107,7 @@ object ObjectModule extends AbstractFunctionModule { def invoke(self: Val.Obj, sup: Val.Obj, fs: FileScope, ev: EvalScope): Val = func.apply2( Val.Str(pos, k), - new Lazy(() => obj.value(k, pos.noOffset)(ev)), + new LazyFunc(() => obj.value(k, pos.noOffset)(ev)), pos.noOffset )( ev, @@ -139,7 +139,7 @@ object ObjectModule extends AbstractFunctionModule { Val.Arr( pos, keys.map { k => - new Lazy(() => v1.value(k, pos.noOffset)(ev)) + new LazyFunc(() => v1.value(k, pos.noOffset)(ev)) } ) From 6e0223093fcaf5c9d4048aaafe071b13678c65ba Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 8 Mar 2026 01:20:05 +0800 Subject: [PATCH 2/4] chore: reduce allocation of Num --- sjsonnet/src/sjsonnet/Evaluator.scala | 309 +++++++++++------- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 292 +++++++++++------ sjsonnet/src/sjsonnet/Val.scala | 4 +- .../go_test_suite/binaryNot2.jsonnet.golden | 2 +- .../go_test_suite/bitwise_or10.jsonnet.golden | 2 +- .../number_divided_by_string.jsonnet.golden | 2 +- .../number_times_string.jsonnet.golden | 2 +- .../string_divided_by_number.jsonnet.golden | 2 +- .../string_minus_number.jsonnet.golden | 2 +- .../string_times_number.jsonnet.golden | 2 +- .../go_test_suite/unary_minus4.jsonnet.golden | 2 +- .../go_test_suite/unary_object.jsonnet.golden | 2 +- 12 files changed, 394 insertions(+), 229 deletions(-) diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 43804262..cc6a43d8 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -4,6 +4,8 @@ import sjsonnet.Expr.Member.Visibility import sjsonnet.Expr.{Error as _, *} import ujson.Value +import sjsonnet.Evaluator.SafeDoubleOps + import scala.annotation.{switch, tailrec} /** @@ -234,36 +236,110 @@ class Evaluator( } def visitUnaryOp(e: UnaryOp)(implicit scope: ValScope): Val = { - val v = visitExpr(e.value) val pos = e.pos - def fail() = - Error.fail(s"Unknown unary operation: ${Expr.UnaryOp.name(e.op)} ${v.prettyName}", pos) - e.op match { + (e.op: @switch) match { + case Expr.UnaryOp.OP_+ => Val.Num(pos, visitExprAsDouble(e.value)) + case Expr.UnaryOp.OP_- => Val.Num(pos, -visitExprAsDouble(e.value)) + case Expr.UnaryOp.OP_~ => Val.Num(pos, (~visitExprAsDouble(e.value).toSafeLong(pos)).toDouble) case Expr.UnaryOp.OP_! => - v match { + visitExpr(e.value) match { case Val.True(_) => Val.False(pos) case Val.False(_) => Val.True(pos) - case _ => fail() - } - case Expr.UnaryOp.OP_- => - v match { - case Val.Num(_, v) => Val.Num(pos, -v) - case _ => fail() + case v => + Error.fail(s"Unknown unary operation: ! ${v.prettyName}", pos) } - case Expr.UnaryOp.OP_~ => - v match { - case Val.Num(_, v) => Val.Num(pos, (~v.toLong).toDouble) - case _ => fail() + case _ => + val v = visitExpr(e.value) + Error.fail(s"Unknown unary operation: ${Expr.UnaryOp.name(e.op)} ${v.prettyName}", pos) + } + } + + /** + * Fast path: evaluate an expression expected to produce a Double, avoiding intermediate + * [[Val.Num]] allocation. When a numeric expression chain like `a + b * c - d` is evaluated, + * intermediate results stay as raw JVM `double` primitives (zero allocation) instead of being + * boxed into `Val.Num` objects (~24-32 bytes each) at every step. + * + * Only the outermost operation (in [[visitBinaryOp]]) boxes the final result into a `Val.Num`. + */ + private def visitExprAsDouble(e: Expr)(implicit scope: ValScope): Double = try { + e match { + case v: Val.Num => v.asDouble + case v: Val => Error.fail("Expected Number, got " + v.prettyName, e.pos) + case e: ValidId => + scope.bindings(e.nameIdx).value match { + case n: Val.Num => n.asDouble + case v => Error.fail("Expected Number, got " + v.prettyName, e.pos) } - case Expr.UnaryOp.OP_+ => - v match { - case Val.Num(_, v) => Val.Num(pos, v) - case _ => fail() + case e: BinaryOp => visitBinaryOpAsDouble(e) + case e: UnaryOp => visitUnaryOpAsDouble(e) + case e => + visitExpr(e) match { + case n: Val.Num => n.asDouble + case v => Error.fail("Expected Number, got " + v.prettyName, e.pos) } - case _ => fail() } + } catch { + Error.withStackFrame(e) } + private def visitBinaryOpAsDouble(e: BinaryOp)(implicit scope: ValScope): Double = { + val pos = e.pos + (e.op: @switch) match { + case Expr.BinaryOp.OP_* => + val r = visitExprAsDouble(e.lhs) * visitExprAsDouble(e.rhs) + if (r.isInfinite) Error.fail("overflow", pos); r + case Expr.BinaryOp.OP_/ => + val l = visitExprAsDouble(e.lhs) + val r = visitExprAsDouble(e.rhs) + if (r == 0) Error.fail("division by zero", pos) + val result = l / r + if (result.isInfinite) Error.fail("overflow", pos); result + case Expr.BinaryOp.OP_% => + visitExprAsDouble(e.lhs) % visitExprAsDouble(e.rhs) + case Expr.BinaryOp.OP_+ => + val r = visitExprAsDouble(e.lhs) + visitExprAsDouble(e.rhs) + if (r.isInfinite) Error.fail("overflow", pos); r + case Expr.BinaryOp.OP_- => + val r = visitExprAsDouble(e.lhs) - visitExprAsDouble(e.rhs) + if (r.isInfinite) Error.fail("overflow", pos); r + case Expr.BinaryOp.OP_<< => + val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) + val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + if (rr < 0) Error.fail("shift by negative exponent", pos) + if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) + Error.fail("numeric value outside safe integer range for bitwise operation", pos) + (ll << rr).toDouble + case Expr.BinaryOp.OP_>> => + val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) + val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + if (rr < 0) Error.fail("shift by negative exponent", pos) + (ll >> rr).toDouble + case Expr.BinaryOp.OP_& => + (visitExprAsDouble(e.lhs).toSafeLong(pos) & visitExprAsDouble(e.rhs).toSafeLong( + pos + )).toDouble + case Expr.BinaryOp.OP_^ => + (visitExprAsDouble(e.lhs).toSafeLong(pos) ^ visitExprAsDouble(e.rhs).toSafeLong( + pos + )).toDouble + case Expr.BinaryOp.OP_| => + (visitExprAsDouble(e.lhs).toSafeLong(pos) | visitExprAsDouble(e.rhs).toSafeLong( + pos + )).toDouble + case _ => + visitBinaryOp(e).asDouble + } + } + + private def visitUnaryOpAsDouble(e: UnaryOp)(implicit scope: ValScope): Double = + (e.op: @switch) match { + case Expr.UnaryOp.OP_- => -visitExprAsDouble(e.value) + case Expr.UnaryOp.OP_+ => visitExprAsDouble(e.value) + case Expr.UnaryOp.OP_~ => (~visitExprAsDouble(e.value).toSafeLong(e.pos)).toDouble + case _ => visitUnaryOp(e).asDouble + } + /** * Function application entry points (visitApply/visitApply0-3 for user functions, * visitApplyBuiltin/visitApplyBuiltin0-4 for built-in functions). @@ -601,28 +677,31 @@ class Evaluator( } def visitBinaryOp(e: BinaryOp)(implicit scope: ValScope): Val.Literal = { - val l = visitExpr(e.lhs) - val r = visitExpr(e.rhs) val pos = e.pos - def fail() = Error.fail( - s"Unknown binary operation: ${l.prettyName} ${Expr.BinaryOp.name(e.op)} ${r.prettyName}", - pos - ) - e.op match { - - case Expr.BinaryOp.OP_== => - if (l.isInstanceOf[Val.Func] && r.isInstanceOf[Val.Func]) { - Error.fail("cannot test equality of functions", pos) - } - Val.bool(pos, equal(l, r)) - - case Expr.BinaryOp.OP_!= => - if (l.isInstanceOf[Val.Func] && r.isInstanceOf[Val.Func]) { - Error.fail("cannot test equality of functions", pos) + (e.op: @switch) match { + // Pure numeric fast path: avoid intermediate Val.Num allocation + case Expr.BinaryOp.OP_* => + Val.Num(pos, visitExprAsDouble(e.lhs) * visitExprAsDouble(e.rhs)) + case Expr.BinaryOp.OP_- => + Val.Num(pos, visitExprAsDouble(e.lhs) - visitExprAsDouble(e.rhs)) + case Expr.BinaryOp.OP_/ => + val l = visitExprAsDouble(e.lhs) + val r = visitExprAsDouble(e.rhs) + if (r == 0) Error.fail("division by zero", pos) + Val.Num(pos, l / r) + // Polymorphic ops: need visitExpr for type dispatch + case Expr.BinaryOp.OP_% => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) + (l, r) match { + case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l % r) + case (Val.Str(_, l), r) => Val.Str(pos, Format.format(l, r, pos)) + case _ => failBinOp(l, e.op, r, pos) } - Val.bool(pos, !equal(l, r)) case Expr.BinaryOp.OP_+ => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l + r) case (Val.Str(_, l), Val.Str(_, r)) => Val.Str(pos, l + r) @@ -630,130 +709,128 @@ class Evaluator( case (l, Val.Str(_, r)) => Val.Str(pos, Materializer.stringify(l) + r) case (l: Val.Obj, r: Val.Obj) => r.addSuper(pos, l) case (l: Val.Arr, r: Val.Arr) => l.concat(pos, r) - case _ => fail() - } - - case Expr.BinaryOp.OP_- => - (l, r) match { - case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l - r) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } - case Expr.BinaryOp.OP_* => - (l, r) match { - case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l * r) - case _ => fail() - } - - case Expr.BinaryOp.OP_/ => - (l, r) match { - case (Val.Num(_, l), Val.Num(_, r)) => - if (r == 0) Error.fail("division by zero", pos) - Val.Num(pos, l / r) - case _ => fail() - } + // Shift ops: pure numeric with safe-integer range check + case Expr.BinaryOp.OP_<< => + val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) + val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + if (rr < 0) Error.fail("shift by negative exponent", pos) + if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) + Error.fail("numeric value outside safe integer range for bitwise operation", pos) + else + Val.Num(pos, (ll << rr).toDouble) - case Expr.BinaryOp.OP_% => - (l, r) match { - case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l % r) - case (Val.Str(_, l), r) => Val.Str(pos, Format.format(l, r, pos)) - case _ => fail() - } + case Expr.BinaryOp.OP_>> => + val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) + val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + if (rr < 0) Error.fail("shift by negative exponent", pos) + Val.Num(pos, (ll >> rr).toDouble) + // Comparison ops: polymorphic (Num/Str/Arr) case Expr.BinaryOp.OP_< => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Str(_, l), Val.Str(_, r)) => Val.bool(pos, Util.compareStringsByCodepoint(l, r) < 0) case (Val.Num(_, l), Val.Num(_, r)) => Val.bool(pos, l < r) case (x: Val.Arr, y: Val.Arr) => Val.bool(pos, compare(x, y) < 0) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } case Expr.BinaryOp.OP_> => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Str(_, l), Val.Str(_, r)) => Val.bool(pos, Util.compareStringsByCodepoint(l, r) > 0) case (Val.Num(_, l), Val.Num(_, r)) => Val.bool(pos, l > r) case (x: Val.Arr, y: Val.Arr) => Val.bool(pos, compare(x, y) > 0) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } case Expr.BinaryOp.OP_<= => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Str(_, l), Val.Str(_, r)) => Val.bool(pos, Util.compareStringsByCodepoint(l, r) <= 0) case (Val.Num(_, l), Val.Num(_, r)) => Val.bool(pos, l <= r) case (x: Val.Arr, y: Val.Arr) => Val.bool(pos, compare(x, y) <= 0) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } case Expr.BinaryOp.OP_>= => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Str(_, l), Val.Str(_, r)) => Val.bool(pos, Util.compareStringsByCodepoint(l, r) >= 0) case (Val.Num(_, l), Val.Num(_, r)) => Val.bool(pos, l >= r) case (x: Val.Arr, y: Val.Arr) => Val.bool(pos, compare(x, y) >= 0) - case _ => fail() - } - - case Expr.BinaryOp.OP_<< => - (l, r) match { - case (l: Val.Num, r: Val.Num) => - val ll = l.asSafeLong - val rr = r.asSafeLong - if (rr < 0) { - Error.fail("shift by negative exponent", pos) - } - if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) - Error.fail("numeric value outside safe integer range for bitwise operation", pos) - else - Val.Num(pos, (ll << rr).toDouble) - case _ => fail() - } - - case Expr.BinaryOp.OP_>> => - (l, r) match { - case (l: Val.Num, r: Val.Num) => - val ll = l.asSafeLong - val rr = r.asSafeLong - if (rr < 0) { - Error.fail("shift by negative exponent", pos) - } - Val.Num(pos, (ll >> rr).toDouble) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } case Expr.BinaryOp.OP_in => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) (l, r) match { case (Val.Str(_, l), o: Val.Obj) => Val.bool(pos, o.containsKey(l)) - case _ => fail() + case _ => failBinOp(l, e.op, r, pos) } + // Equality ops + case Expr.BinaryOp.OP_== => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) + if (l.isInstanceOf[Val.Func] && r.isInstanceOf[Val.Func]) + Error.fail("cannot test equality of functions", pos) + Val.bool(pos, equal(l, r)) + + case Expr.BinaryOp.OP_!= => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) + if (l.isInstanceOf[Val.Func] && r.isInstanceOf[Val.Func]) + Error.fail("cannot test equality of functions", pos) + Val.bool(pos, !equal(l, r)) + + // Bitwise ops: pure numeric with safe-integer range check case Expr.BinaryOp.OP_& => - (l, r) match { - case (l: Val.Num, r: Val.Num) => - Val.Num(pos, (l.asSafeLong & r.asSafeLong).toDouble) - case _ => fail() - } + Val.Num( + pos, + (visitExprAsDouble(e.lhs).toSafeLong(pos) & + visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble + ) case Expr.BinaryOp.OP_^ => - (l, r) match { - case (l: Val.Num, r: Val.Num) => - Val.Num(pos, (l.asSafeLong ^ r.asSafeLong).toDouble) - case _ => fail() - } + Val.Num( + pos, + (visitExprAsDouble(e.lhs).toSafeLong(pos) ^ + visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble + ) case Expr.BinaryOp.OP_| => - (l, r) match { - case (l: Val.Num, r: Val.Num) => - Val.Num(pos, (l.asSafeLong | r.asSafeLong).toDouble) - case _ => fail() - } + Val.Num( + pos, + (visitExprAsDouble(e.lhs).toSafeLong(pos) | + visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble + ) - case _ => fail() + case _ => + val l = visitExpr(e.lhs) + val r = visitExpr(e.rhs) + failBinOp(l, e.op, r, pos) } } + @inline private def failBinOp(l: Val, op: Int, r: Val, pos: Position): Nothing = + Error.fail( + s"Unknown binary operation: ${l.prettyName} ${Expr.BinaryOp.name(op)} ${r.prettyName}", + pos + ) + def visitFieldName(fieldName: FieldName, pos: Position)(implicit scope: ValScope): String = { fieldName match { case FieldName.Fixed(s) => s @@ -1240,6 +1317,14 @@ class NewEvaluator( object Evaluator { + implicit class SafeDoubleOps(private val d: Double) extends AnyVal { + @inline def toSafeLong(pos: Position)(implicit ev: EvalErrorScope): Long = { + if (d < Val.DOUBLE_MIN_SAFE_INTEGER || d > Val.DOUBLE_MAX_SAFE_INTEGER) + Error.fail("numeric value outside safe integer range for bitwise operation", pos) + d.toLong + } + } + /** * Logger, used for warnings and trace. The first argument is true if the message is a trace * emitted by std.trace diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 1ed8ad77..2240966d 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -1,8 +1,10 @@ package sjsonnet +import scala.annotation.switch import scala.collection.mutable import Expr.* +import Evaluator.SafeDoubleOps import ScopedExprTransform.* /** @@ -31,106 +33,184 @@ class StaticOptimizer( extends ScopedExprTransform { def optimize(e: Expr): Expr = transform(e) - override def transform(_e: Expr): Expr = super.transform(check(_e)) match { - case a: Apply => transformApply(a) - - case e @ Select(p, obj: Val.Obj, name) if obj.containsKey(name) => - try obj.value(name, p)(ev).asInstanceOf[Expr] - catch { case _: Exception => e } - - case Select(pos, ValidSuper(_, selfIdx), name) => - SelectSuper(pos, selfIdx, name) - - case Lookup(pos, ValidSuper(_, selfIdx), index) => - LookupSuper(pos, selfIdx, index) - - case BinaryOp(pos, lhs, BinaryOp.OP_in, ValidSuper(_, selfIdx)) => - InSuper(pos, lhs, selfIdx) - case b2 @ BinaryOp(pos, lhs: Val.Str, BinaryOp.OP_%, rhs) => - try ApplyBuiltin1(pos, new Format.PartialApplyFmt(lhs.str), rhs, tailstrict = false) - catch { case _: Exception => b2 } - - case e @ Id(pos, name) => - scope.get(name) match { - case ScopedVal(v: Val with Expr, _, _) => v - case ScopedVal(_, _, idx) => ValidId(pos, name, idx) - case null if name == f"$$std" => std - case null if name == "std" => std - case null => - variableResolver(name) match { - case Some(v) => v // additional variable resolution - case None => - StaticError.fail( - "Unknown variable: " + name, - e - )(ev) + override def transform(_e: Expr): Expr = { + // Fast path: fold pure numeric literal chains as raw doubles before the bottom-up transform. + // This avoids intermediate Val.Num + BinaryOp allocations for chains like `60 * 60 * 24`. + _e match { + case _: BinaryOp | _: UnaryOp => + val d = tryFoldAsDouble(_e) + if (!d.isNaN) return Val.Num(_e.pos, d) + case _ => + } + super.transform(check(_e)) match { + case a: Apply => transformApply(a) + + case e @ Select(p, obj: Val.Obj, name) if obj.containsKey(name) => + try obj.value(name, p)(ev).asInstanceOf[Expr] + catch { case _: Exception => e } + + case Select(pos, ValidSuper(_, selfIdx), name) => + SelectSuper(pos, selfIdx, name) + + case Lookup(pos, ValidSuper(_, selfIdx), index) => + LookupSuper(pos, selfIdx, index) + + case BinaryOp(pos, lhs, BinaryOp.OP_in, ValidSuper(_, selfIdx)) => + InSuper(pos, lhs, selfIdx) + case b2 @ BinaryOp(pos, lhs: Val.Str, BinaryOp.OP_%, rhs) => + try { + rhs match { + case r: Val => + val partial = new Format.PartialApplyFmt(lhs.str) + try partial.evalRhs(r, ev, pos).asInstanceOf[Expr] + catch { + case _: Exception => + ApplyBuiltin1(pos, partial, rhs, tailstrict = false) + } + case _ => + ApplyBuiltin1(pos, new Format.PartialApplyFmt(lhs.str), rhs, tailstrict = false) } - } + } catch { case _: Exception => b2 } + + case e @ Id(pos, name) => + scope.get(name) match { + case ScopedVal(v: Val with Expr, _, _) => v + case ScopedVal(_, _, idx) => ValidId(pos, name, idx) + case null if name == f"$$std" => std + case null if name == "std" => std + case null => + variableResolver(name) match { + case Some(v) => v // additional variable resolution + case None => + StaticError.fail( + "Unknown variable: " + name, + e + )(ev) + } + } - case e @ Self(pos) => - scope.get("self") match { - case ScopedVal(v, _, idx) if v != null => ValidId(pos, "self", idx) - case _ => StaticError.fail("Can't use self outside of an object", e)(ev) - } + case e @ Self(pos) => + scope.get("self") match { + case ScopedVal(v, _, idx) if v != null => ValidId(pos, "self", idx) + case _ => StaticError.fail("Can't use self outside of an object", e)(ev) + } - case e @ $(pos) => - scope.get("$") match { - case ScopedVal(v, _, idx) if v != null => ValidId(pos, "$", idx) - case _ => StaticError.fail("Can't use $ outside of an object", e)(ev) - } + case e @ $(pos) => + scope.get("$") match { + case ScopedVal(v, _, idx) if v != null => ValidId(pos, "$", idx) + case _ => StaticError.fail("Can't use $ outside of an object", e)(ev) + } - case e @ Super(_) if !scope.contains("super") => - StaticError.fail("Can't use super outside of an object", e)(ev) - - case a: Arr if a.value.forall(_.isInstanceOf[Val]) => - Val.Arr(a.pos, a.value.map(e => e.asInstanceOf[Val])) - - case m @ ObjBody.MemberList(pos, binds, fields, asserts) => - // If static optimization has constant-folded originally-dynamic field names - // into fixed names, it's possible that we might now have duplicate names. - // In that case, we keep the object as a MemberList and leave it to the - // Evaluator to throw an error if/when the object is evaluated (in order - // to preserve proper laziness semantics). - def allFieldsStaticAndUniquelyNamed: Boolean = { - val seen = mutable.Set.empty[String] - fields.forall { f => - f.isStatic && seen.add(f.fieldName.asInstanceOf[FieldName.Fixed].value) + case e @ Super(_) if !scope.contains("super") => + StaticError.fail("Can't use super outside of an object", e)(ev) + + case a: Arr if a.value.forall(_.isInstanceOf[Val]) => + Val.Arr(a.pos, a.value.map(e => e.asInstanceOf[Val])) + + case m @ ObjBody.MemberList(pos, binds, fields, asserts) => + // If static optimization has constant-folded originally-dynamic field names + // into fixed names, it's possible that we might now have duplicate names. + // In that case, we keep the object as a MemberList and leave it to the + // Evaluator to throw an error if/when the object is evaluated (in order + // to preserve proper laziness semantics). + def allFieldsStaticAndUniquelyNamed: Boolean = { + val seen = mutable.Set.empty[String] + fields.forall { f => + f.isStatic && seen.add(f.fieldName.asInstanceOf[FieldName.Fixed].value) + } } - } - if (binds == null && asserts == null && allFieldsStaticAndUniquelyNamed) - Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings) - else m - // Aggressive optimizations: constant folding, branch elimination, short-circuit elimination. - // These reduce AST node count at parse time, benefiting long-running Jsonnet programs. - // Constant folding: BinaryOp with two constant operands (most common case first) - case e @ BinaryOp(pos, lhs: Val, op, rhs: Val) => tryFoldBinaryOp(pos, lhs, op, rhs, e) - - // Constant folding: UnaryOp with constant operand - case e @ UnaryOp(pos, op, v: Val) => tryFoldUnaryOp(pos, op, v, e) - - // Branch elimination: constant condition in if-else - case IfElse(pos, _: Val.True, thenExpr, _) => - thenExpr.pos = pos; thenExpr - case IfElse(pos, _: Val.False, _, elseExpr) => - if (elseExpr == null) Val.Null(pos) - else { elseExpr.pos = pos; elseExpr } - - // Short-circuit elimination for And/Or with constant lhs. - // - // IMPORTANT: rhs MUST be guarded as `Val.Bool` — do NOT relax this to arbitrary Expr. - // The Evaluator's visitAnd/visitOr enforces that rhs evaluates to Bool, throwing - // "binary operator && does not operate on s" otherwise. If we fold `true && rhs` - // into just `rhs` without the Bool guard, we silently remove that runtime type check, - // causing programs like `true && "hello"` to return "hello" instead of erroring. - // See: Evaluator.visitAnd / Evaluator.visitOr for the authoritative runtime semantics. - case And(pos, _: Val.True, rhs: Val.Bool) => rhs.pos = pos; rhs - case And(pos, _: Val.False, _) => Val.False(pos) - case Or(pos, _: Val.True, _) => Val.True(pos) - case Or(pos, _: Val.False, rhs: Val.Bool) => rhs.pos = pos; rhs - case e => e + if (binds == null && asserts == null && allFieldsStaticAndUniquelyNamed) + Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings) + else m + // Aggressive optimizations: constant folding, branch elimination, short-circuit elimination. + // These reduce AST node count at parse time, benefiting long-running Jsonnet programs. + // Constant folding: BinaryOp with two constant operands (most common case first) + case e @ BinaryOp(pos, lhs: Val, op, rhs: Val) => tryFoldBinaryOp(pos, lhs, op, rhs, e) + + // Constant folding: UnaryOp with constant operand + case e @ UnaryOp(pos, op, v: Val) => tryFoldUnaryOp(pos, op, v, e) + + // Branch elimination: constant condition in if-else + case IfElse(pos, _: Val.True, thenExpr, _) => + thenExpr.pos = pos; thenExpr + case IfElse(pos, _: Val.False, _, elseExpr) => + if (elseExpr == null) Val.Null(pos) + else { elseExpr.pos = pos; elseExpr } + + // Short-circuit elimination for And/Or with constant lhs. + // + // IMPORTANT: rhs MUST be guarded as `Val.Bool` — do NOT relax this to arbitrary Expr. + // The Evaluator's visitAnd/visitOr enforces that rhs evaluates to Bool, throwing + // "binary operator && does not operate on s" otherwise. If we fold `true && rhs` + // into just `rhs` without the Bool guard, we silently remove that runtime type check, + // causing programs like `true && "hello"` to return "hello" instead of erroring. + // See: Evaluator.visitAnd / Evaluator.visitOr for the authoritative runtime semantics. + case And(pos, _: Val.True, rhs: Val.Bool) => rhs.pos = pos; rhs + case And(pos, _: Val.False, _) => Val.False(pos) + case Or(pos, _: Val.True, _) => Val.True(pos) + case Or(pos, _: Val.False, rhs: Val.Bool) => rhs.pos = pos; rhs + case e => e + } } + /** + * Try to fold a pure constant numeric expression chain as a raw double, bypassing the bottom-up + * tree transformer. Only handles trees of BinaryOp/UnaryOp/Val.Num with numeric-only ops. + * + * Returns `NaN` if the expression cannot be folded (non-numeric leaf, polymorphic op, error). + * This avoids intermediate `Val.Num` and `BinaryOp` allocations in chains like `60 * 60 * 24`. + */ + private def tryFoldAsDouble(e: Expr): Double = + try { + e match { + case Val.Num(_, n) => n + case BinaryOp(pos, lhs, op, rhs) => + val l = tryFoldAsDouble(lhs) + if (l.isNaN) return Double.NaN + val r = tryFoldAsDouble(rhs) + if (r.isNaN) return Double.NaN + (op: @switch) match { + case BinaryOp.OP_+ => + val res = l + r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_- => + val res = l - r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_* => + val res = l * r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_/ => + if (r == 0) return Double.NaN + val res = l / r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_% => l % r + case BinaryOp.OP_<< => + val ll = l.toSafeLong(pos)(ev); val rr = r.toSafeLong(pos)(ev) + if (rr < 0) return Double.NaN + if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) return Double.NaN + (ll << rr).toDouble + case BinaryOp.OP_>> => + val ll = l.toSafeLong(pos)(ev); val rr = r.toSafeLong(pos)(ev) + if (rr < 0) return Double.NaN + (ll >> rr).toDouble + case BinaryOp.OP_& => + (l.toSafeLong(pos)(ev) & r.toSafeLong(pos)(ev)).toDouble + case BinaryOp.OP_^ => + (l.toSafeLong(pos)(ev) ^ r.toSafeLong(pos)(ev)).toDouble + case BinaryOp.OP_| => + (l.toSafeLong(pos)(ev) | r.toSafeLong(pos)(ev)).toDouble + case _ => Double.NaN // non-numeric op (comparison, equality, etc.) + } + case UnaryOp(pos, op, v) => + val d = tryFoldAsDouble(v) + if (d.isNaN) return Double.NaN + (op: @switch) match { + case Expr.UnaryOp.OP_- => -d + case Expr.UnaryOp.OP_+ => d + case Expr.UnaryOp.OP_~ => (~d.toSafeLong(pos)(ev)).toDouble + case _ => Double.NaN + } + case _ => Double.NaN + } + } catch { case _: Exception => Double.NaN } + private object ValidSuper { def unapply(s: Super): Option[(Position, Int)] = scope.get("self") match { @@ -293,7 +373,7 @@ class StaticOptimizer( private def tryFoldUnaryOp(pos: Position, op: Int, v: Val, fallback: Expr): Expr = try { - op match { + (op: @switch) match { case Expr.UnaryOp.OP_! => v match { case _: Val.True => Val.False(pos) @@ -307,13 +387,13 @@ class StaticOptimizer( } case Expr.UnaryOp.OP_~ => v match { - case Val.Num(_, n) => Val.Num(pos, (~n.toLong).toDouble) - case _ => fallback + case n: Val.Num => Val.Num(pos, (~n.asSafeLong).toDouble) + case _ => fallback } case Expr.UnaryOp.OP_+ => v match { - case Val.Num(_, n) => Val.Num(pos, n) - case _ => fallback + case n: Val.Num => n.pos = pos; n.asInstanceOf[Expr] + case _ => fallback } case _ => fallback } @@ -321,7 +401,7 @@ class StaticOptimizer( private def tryFoldBinaryOp(pos: Position, lhs: Val, op: Int, rhs: Val, fallback: Expr): Expr = try { - op match { + (op: @switch) match { case BinaryOp.OP_+ => (lhs, rhs) match { case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l + r) @@ -368,9 +448,9 @@ class StaticOptimizer( } case BinaryOp.OP_<< => (lhs, rhs) match { - case (Val.Num(_, l), Val.Num(_, r)) => - val ll = lhs.asInstanceOf[Val.Num].asSafeLong - val rr = rhs.asInstanceOf[Val.Num].asSafeLong + case (l: Val.Num, r: Val.Num) => + val ll = l.asSafeLong + val rr = r.asSafeLong if (rr < 0) fallback // negative shift → runtime error else if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) fallback // overflow → runtime error @@ -379,9 +459,9 @@ class StaticOptimizer( } case BinaryOp.OP_>> => (lhs, rhs) match { - case (Val.Num(_, l), Val.Num(_, r)) => - val ll = lhs.asInstanceOf[Val.Num].asSafeLong - val rr = rhs.asInstanceOf[Val.Num].asSafeLong + case (l: Val.Num, r: Val.Num) => + val ll = l.asSafeLong + val rr = r.asSafeLong if (rr < 0) fallback // negative shift → runtime error else Val.Num(pos, (ll >> rr).toDouble) case _ => fallback @@ -418,7 +498,7 @@ class StaticOptimizer( // because compare(-0.0, 0.0) == -1 while IEEE 754 treats -0.0 == 0.0. (lhs, rhs) match { case (Val.Num(_, l), Val.Num(_, r)) if !l.isNaN && !r.isNaN => - val result = op match { + val result = (op: @switch) match { case BinaryOp.OP_< => l < r case BinaryOp.OP_> => l > r case BinaryOp.OP_<= => l <= r @@ -428,7 +508,7 @@ class StaticOptimizer( Val.bool(pos, result) case (Val.Str(_, l), Val.Str(_, r)) => val cmp = Util.compareStringsByCodepoint(l, r) - val result = op match { + val result = (op: @switch) match { case BinaryOp.OP_< => cmp < 0 case BinaryOp.OP_> => cmp > 0 case BinaryOp.OP_<= => cmp <= 0 diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 4f642e03..fa1090e4 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -182,8 +182,8 @@ object PrettyNamed { object Val { // Constants for safe double-to-int conversion // IEEE 754 doubles precisely represent integers up to 2^53, beyond which precision is lost - private val DOUBLE_MAX_SAFE_INTEGER = (1L << 53) - 1 - private val DOUBLE_MIN_SAFE_INTEGER = -((1L << 53) - 1) + private[sjsonnet] final val DOUBLE_MAX_SAFE_INTEGER = (1L << 53) - 1 + private[sjsonnet] final val DOUBLE_MIN_SAFE_INTEGER = -((1L << 53) - 1) abstract class Literal extends Val with Expr { final override private[sjsonnet] def tag = ExprTags.`Val.Literal` diff --git a/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden index 356e6fb7..11f034e7 100644 --- a/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown unary operation: ~ string +sjsonnet.Error: Expected Number, got string at [].(binaryNot2.jsonnet:1:1) diff --git a/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden index 036db1b8..50f106df 100644 --- a/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: string | number +sjsonnet.Error: Expected Number, got string at [].(bitwise_or10.jsonnet:1:7) diff --git a/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden index 8b1f93a6..cb927039 100644 --- a/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: number / string +sjsonnet.Error: Expected Number, got string at [].(number_divided_by_string.jsonnet:1:4) diff --git a/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden index 5175744d..68d168bd 100644 --- a/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: number * string +sjsonnet.Error: Expected Number, got string at [].(number_times_string.jsonnet:1:4) diff --git a/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden index c1baf4e5..b6090100 100644 --- a/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: string / number +sjsonnet.Error: Expected Number, got string at [].(string_divided_by_number.jsonnet:1:7) diff --git a/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden index 802a3733..d724c6db 100644 --- a/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: string - number +sjsonnet.Error: Expected Number, got string at [].(string_minus_number.jsonnet:1:5) diff --git a/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden index 6b05e85b..6ba00d2d 100644 --- a/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown binary operation: string * number +sjsonnet.Error: Expected Number, got string at [].(string_times_number.jsonnet:1:5) diff --git a/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden index 0886fb00..eb0b0518 100644 --- a/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown unary operation: - string +sjsonnet.Error: Expected Number, got string at [].(unary_minus4.jsonnet:1:1) diff --git a/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden index bde9b3fe..5b965a44 100644 --- a/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Unknown unary operation: + object +sjsonnet.Error: Expected Number, got object at [].(unary_object.jsonnet:1:1) From e47f47f852070276f52cacdf01eca838e8df17df Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 8 Mar 2026 03:11:30 +0800 Subject: [PATCH 3/4] refactor static optimizer dispatch Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../sjsonnet/bench/OptimizerBenchmark.scala | 146 ++- sjsonnet/src/sjsonnet/Evaluator.scala | 262 +++-- sjsonnet/src/sjsonnet/ExprTransform.scala | 303 ------ .../src/sjsonnet/ScopedExprTransform.scala | 194 ---- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 919 ++++++++++++++---- sjsonnet/src/sjsonnet/Val.scala | 1 + .../go_test_suite/binaryNot2.jsonnet.golden | 2 +- .../go_test_suite/bitwise_or10.jsonnet.golden | 2 +- .../number_divided_by_string.jsonnet.golden | 2 +- .../number_times_string.jsonnet.golden | 2 +- .../string_divided_by_number.jsonnet.golden | 2 +- .../string_minus_number.jsonnet.golden | 2 +- .../string_times_number.jsonnet.golden | 2 +- .../go_test_suite/unary_minus4.jsonnet.golden | 2 +- .../go_test_suite/unary_object.jsonnet.golden | 2 +- .../AggressiveStaticOptimizationTests.scala | 41 +- 16 files changed, 1126 insertions(+), 758 deletions(-) delete mode 100644 sjsonnet/src/sjsonnet/ExprTransform.scala delete mode 100644 sjsonnet/src/sjsonnet/ScopedExprTransform.scala diff --git a/bench/src/sjsonnet/bench/OptimizerBenchmark.scala b/bench/src/sjsonnet/bench/OptimizerBenchmark.scala index b8000578..65668a22 100644 --- a/bench/src/sjsonnet/bench/OptimizerBenchmark.scala +++ b/bench/src/sjsonnet/bench/OptimizerBenchmark.scala @@ -69,12 +69,13 @@ class OptimizerBenchmark { }) } - class Counter extends ExprTransform { + class Counter { var total, vals, exprs, arrVals, staticArrExprs, otherArrExprs, staticObjs, missedStaticObjs, otherObjs, namedApplies, applies, arityApplies, builtin = 0 val applyArities = new mutable.LongMap[Int]() val ifElseChains = new mutable.LongMap[Int]() val selectChains = new mutable.LongMap[Int]() + def transform(e: Expr): Expr = { total += 1 if (e.isInstanceOf[Val]) vals += 1 @@ -95,7 +96,6 @@ class OptimizerBenchmark { val a = e.args.length applyArities.put(a.toLong, applyArities.getOrElse(a.toLong, 0) + 1) } else namedApplies += 1 - case _: Expr.Apply0 | _: Expr.Apply1 | _: Expr.Apply2 | _: Expr.Apply3 => arityApplies += 1 case _: Expr.ApplyBuiltin | _: Expr.ApplyBuiltin1 | _: Expr.ApplyBuiltin2 => builtin += 1 case _ => @@ -119,7 +119,149 @@ class OptimizerBenchmark { ) } rec(e) + e + } + + private def rec(e: Expr): Unit = e match { + case Expr.Select(_, x, _) => transform(x) + case Expr.Apply(_, x, y, _, _) => + transform(x) + transformArr(y) + case Expr.Apply0(_, x, _) => + transform(x) + case Expr.Apply1(_, x, y, _) => + transform(x) + transform(y) + case Expr.Apply2(_, x, y, z, _) => + transform(x) + transform(y) + transform(z) + case Expr.Apply3(_, x, y, z, a, _) => + transform(x) + transform(y) + transform(z) + transform(a) + case Expr.ApplyBuiltin(_, _, x, _) => + transformArr(x) + case Expr.ApplyBuiltin1(_, _, x, _) => + transform(x) + case Expr.ApplyBuiltin2(_, _, x, y, _) => + transform(x) + transform(y) + case Expr.ApplyBuiltin3(_, _, x, y, z, _) => + transform(x) + transform(y) + transform(z) + case Expr.ApplyBuiltin4(_, _, x, y, z, a, _) => + transform(x) + transform(y) + transform(z) + transform(a) + case Expr.UnaryOp(_, _, x) => + transform(x) + case Expr.BinaryOp(_, x, _, y) => + transform(x) + transform(y) + case Expr.And(_, x, y) => + transform(x) + transform(y) + case Expr.Or(_, x, y) => + transform(x) + transform(y) + case Expr.InSuper(_, x, _) => + transform(x) + case Expr.Lookup(_, x, y) => + transform(x) + transform(y) + case Expr.LookupSuper(_, _, x) => + transform(x) + case Expr.Function(_, params, body) => + transformParams(params) + transform(body) + case Expr.LocalExpr(_, binds, returned) => + transformBinds(binds) + transform(returned) + case Expr.IfElse(_, cond, thenExpr, elseExpr) => + transform(cond) + transform(thenExpr) + transform(elseExpr) + case Expr.ObjBody.MemberList(_, binds, fields, asserts) => + transformBinds(binds) + transformFields(fields) + transformAsserts(asserts) + case Expr.AssertExpr(_, assertion, returned) => + transform(assertion.value) + if (assertion.msg != null) transform(assertion.msg) + transform(returned) + case Expr.Comp(_, value, first, rest) => + transform(value) + transform(first) + transformArr(rest) + case Expr.Arr(_, values) => + transformArr(values) + case Expr.ObjExtend(_, base, ext) => + transform(base) + transform(ext) + case Expr.ObjBody.ObjComp(_, preLocals, key, value, _, postLocals, first, rest) => + transformBinds(preLocals) + transform(key) + transform(value) + transformBinds(postLocals) + transform(first) + transformList(rest) + case Expr.Slice(_, value, start, end, stride) => + transform(value) + transformOption(start) + transformOption(end) + transformOption(stride) + case Expr.IfSpec(_, cond) => + transform(cond) + case Expr.ForSpec(_, _, cond) => + transform(cond) + case Expr.Error(_, value) => + transform(value) + case _ => + } + + private def transformArr[T <: Expr](values: Array[T]): Unit = { + if (values != null) values.foreach(transform) + } + + private def transformOption(value: Option[Expr]): Unit = value.foreach(transform) + + private def transformList(values: List[Expr]): Unit = values.foreach(transform) + + private def transformParams(params: Expr.Params): Unit = { + if (params != null && params.defaultExprs != null) transformArr(params.defaultExprs) + } + + private def transformBinds(binds: Array[Expr.Bind]): Unit = { + if (binds != null) binds.foreach { bind => + transformParams(bind.args) + transform(bind.rhs) + } + } + + private def transformFieldName(fieldName: Expr.FieldName): Unit = fieldName match { + case Expr.FieldName.Dyn(expr) => transform(expr) + case _ => } + + private def transformFields(fields: Array[Expr.Member.Field]): Unit = { + if (fields != null) fields.foreach { field => + transformFieldName(field.fieldName) + transformParams(field.args) + transform(field.rhs) + } + } + + private def transformAsserts(asserts: Array[Expr.Member.AssertStmt]): Unit = { + if (asserts != null) asserts.foreach { assertion => + transform(assertion.value) + if (assertion.msg != null) transform(assertion.msg) + } + } + def countIfElse(e: Expr): Int = e match { case Expr.IfElse(_, _, _, else0) => countIfElse(else0) + 1 diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index cc6a43d8..d4ee0271 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -238,9 +238,9 @@ class Evaluator( def visitUnaryOp(e: UnaryOp)(implicit scope: ValScope): Val = { val pos = e.pos (e.op: @switch) match { - case Expr.UnaryOp.OP_+ => Val.Num(pos, visitExprAsDouble(e.value)) - case Expr.UnaryOp.OP_- => Val.Num(pos, -visitExprAsDouble(e.value)) - case Expr.UnaryOp.OP_~ => Val.Num(pos, (~visitExprAsDouble(e.value).toSafeLong(pos)).toDouble) + case Expr.UnaryOp.OP_+ => Val.Num(pos, visitUnaryOpAsDouble(e)) + case Expr.UnaryOp.OP_- => Val.Num(pos, visitUnaryOpAsDouble(e)) + case Expr.UnaryOp.OP_~ => Val.Num(pos, visitUnaryOpAsDouble(e)) case Expr.UnaryOp.OP_! => visitExpr(e.value) match { case Val.True(_) => Val.False(pos) @@ -264,80 +264,222 @@ class Evaluator( */ private def visitExprAsDouble(e: Expr)(implicit scope: ValScope): Double = try { e match { - case v: Val.Num => v.asDouble - case v: Val => Error.fail("Expected Number, got " + v.prettyName, e.pos) + case v: Val.Num => v.rawDouble + case v: Val => throw new Evaluator.NonNumericValue(v) case e: ValidId => scope.bindings(e.nameIdx).value match { - case n: Val.Num => n.asDouble - case v => Error.fail("Expected Number, got " + v.prettyName, e.pos) + case n: Val.Num => n.rawDouble + case v => throw new Evaluator.NonNumericValue(v) } case e: BinaryOp => visitBinaryOpAsDouble(e) case e: UnaryOp => visitUnaryOpAsDouble(e) case e => visitExpr(e) match { - case n: Val.Num => n.asDouble - case v => Error.fail("Expected Number, got " + v.prettyName, e.pos) + case n: Val.Num => n.rawDouble + case v => throw new Evaluator.NonNumericValue(v) } } } catch { Error.withStackFrame(e) } + @inline private def nonNumericOperand(expr: Expr, numeric: Double, value: Val): Val = + if (value == null) Val.Num(expr.pos, numeric) else value + private def visitBinaryOpAsDouble(e: BinaryOp)(implicit scope: ValScope): Double = { val pos = e.pos (e.op: @switch) match { case Expr.BinaryOp.OP_* => - val r = visitExprAsDouble(e.lhs) * visitExprAsDouble(e.rhs) + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + val r = lNum * rNum if (r.isInfinite) Error.fail("overflow", pos); r case Expr.BinaryOp.OP_/ => - val l = visitExprAsDouble(e.lhs) - val r = visitExprAsDouble(e.rhs) - if (r == 0) Error.fail("division by zero", pos) - val result = l / r + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + if (rNum == 0) Error.fail("division by zero", pos) + val result = lNum / rNum if (result.isInfinite) Error.fail("overflow", pos); result case Expr.BinaryOp.OP_% => - visitExprAsDouble(e.lhs) % visitExprAsDouble(e.rhs) - case Expr.BinaryOp.OP_+ => - val r = visitExprAsDouble(e.lhs) + visitExprAsDouble(e.rhs) - if (r.isInfinite) Error.fail("overflow", pos); r + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + lNum % rNum case Expr.BinaryOp.OP_- => - val r = visitExprAsDouble(e.lhs) - visitExprAsDouble(e.rhs) + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + val r = lNum - rNum if (r.isInfinite) Error.fail("overflow", pos); r case Expr.BinaryOp.OP_<< => - val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) - val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + val ll = lNum.toSafeLong(pos) + val rr = rNum.toSafeLong(pos) if (rr < 0) Error.fail("shift by negative exponent", pos) if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) Error.fail("numeric value outside safe integer range for bitwise operation", pos) (ll << rr).toDouble case Expr.BinaryOp.OP_>> => - val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) - val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + val ll = lNum.toSafeLong(pos) + val rr = rNum.toSafeLong(pos) if (rr < 0) Error.fail("shift by negative exponent", pos) (ll >> rr).toDouble case Expr.BinaryOp.OP_& => - (visitExprAsDouble(e.lhs).toSafeLong(pos) & visitExprAsDouble(e.rhs).toSafeLong( - pos - )).toDouble + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + (lNum.toSafeLong(pos) & rNum.toSafeLong(pos)).toDouble case Expr.BinaryOp.OP_^ => - (visitExprAsDouble(e.lhs).toSafeLong(pos) ^ visitExprAsDouble(e.rhs).toSafeLong( - pos - )).toDouble + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + (lNum.toSafeLong(pos) ^ rNum.toSafeLong(pos)).toDouble case Expr.BinaryOp.OP_| => - (visitExprAsDouble(e.lhs).toSafeLong(pos) | visitExprAsDouble(e.rhs).toSafeLong( - pos - )).toDouble + var lNum = 0.0 + var rNum = 0.0 + var lVal: Val = null + var rVal: Val = null + try lNum = visitExprAsDouble(e.lhs) + catch { case n: Evaluator.NonNumericValue => lVal = n.value } + try rNum = visitExprAsDouble(e.rhs) + catch { case n: Evaluator.NonNumericValue => rVal = n.value } + if ((lVal ne null) || (rVal ne null)) + failBinOp( + nonNumericOperand(e.lhs, lNum, lVal), + e.op, + nonNumericOperand(e.rhs, rNum, rVal), + pos + ) + (lNum.toSafeLong(pos) | rNum.toSafeLong(pos)).toDouble case _ => - visitBinaryOp(e).asDouble + visitExpr(e) match { + case n: Val.Num => n.rawDouble + case v => throw new Evaluator.NonNumericValue(v) + } } } private def visitUnaryOpAsDouble(e: UnaryOp)(implicit scope: ValScope): Double = - (e.op: @switch) match { - case Expr.UnaryOp.OP_- => -visitExprAsDouble(e.value) - case Expr.UnaryOp.OP_+ => visitExprAsDouble(e.value) - case Expr.UnaryOp.OP_~ => (~visitExprAsDouble(e.value).toSafeLong(e.pos)).toDouble - case _ => visitUnaryOp(e).asDouble + try { + (e.op: @switch) match { + case Expr.UnaryOp.OP_- => -visitExprAsDouble(e.value) + case Expr.UnaryOp.OP_+ => visitExprAsDouble(e.value) + case Expr.UnaryOp.OP_~ => (~visitExprAsDouble(e.value).toLong).toDouble + case _ => + visitExpr(e) match { + case n: Val.Num => n.rawDouble + case v => throw new Evaluator.NonNumericValue(v) + } + } + } catch { + case n: Evaluator.NonNumericValue => + Error.fail( + s"Unknown unary operation: ${Expr.UnaryOp.name(e.op)} ${n.value.prettyName}", + e.pos + ) } /** @@ -681,14 +823,11 @@ class Evaluator( (e.op: @switch) match { // Pure numeric fast path: avoid intermediate Val.Num allocation case Expr.BinaryOp.OP_* => - Val.Num(pos, visitExprAsDouble(e.lhs) * visitExprAsDouble(e.rhs)) + Val.Num(pos, visitBinaryOpAsDouble(e)) case Expr.BinaryOp.OP_- => - Val.Num(pos, visitExprAsDouble(e.lhs) - visitExprAsDouble(e.rhs)) + Val.Num(pos, visitBinaryOpAsDouble(e)) case Expr.BinaryOp.OP_/ => - val l = visitExprAsDouble(e.lhs) - val r = visitExprAsDouble(e.rhs) - if (r == 0) Error.fail("division by zero", pos) - Val.Num(pos, l / r) + Val.Num(pos, visitBinaryOpAsDouble(e)) // Polymorphic ops: need visitExpr for type dispatch case Expr.BinaryOp.OP_% => val l = visitExpr(e.lhs) @@ -714,19 +853,10 @@ class Evaluator( // Shift ops: pure numeric with safe-integer range check case Expr.BinaryOp.OP_<< => - val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) - val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) - if (rr < 0) Error.fail("shift by negative exponent", pos) - if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) - Error.fail("numeric value outside safe integer range for bitwise operation", pos) - else - Val.Num(pos, (ll << rr).toDouble) + Val.Num(pos, visitBinaryOpAsDouble(e)) case Expr.BinaryOp.OP_>> => - val ll = visitExprAsDouble(e.lhs).toSafeLong(pos) - val rr = visitExprAsDouble(e.rhs).toSafeLong(pos) - if (rr < 0) Error.fail("shift by negative exponent", pos) - Val.Num(pos, (ll >> rr).toDouble) + Val.Num(pos, visitBinaryOpAsDouble(e)) // Comparison ops: polymorphic (Num/Str/Arr) case Expr.BinaryOp.OP_< => @@ -798,25 +928,13 @@ class Evaluator( // Bitwise ops: pure numeric with safe-integer range check case Expr.BinaryOp.OP_& => - Val.Num( - pos, - (visitExprAsDouble(e.lhs).toSafeLong(pos) & - visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble - ) + Val.Num(pos, visitBinaryOpAsDouble(e)) case Expr.BinaryOp.OP_^ => - Val.Num( - pos, - (visitExprAsDouble(e.lhs).toSafeLong(pos) ^ - visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble - ) + Val.Num(pos, visitBinaryOpAsDouble(e)) case Expr.BinaryOp.OP_| => - Val.Num( - pos, - (visitExprAsDouble(e.lhs).toSafeLong(pos) | - visitExprAsDouble(e.rhs).toSafeLong(pos)).toDouble - ) + Val.Num(pos, visitBinaryOpAsDouble(e)) case _ => val l = visitExpr(e.lhs) @@ -1317,8 +1435,12 @@ class NewEvaluator( object Evaluator { + final class NonNumericValue(val value: Val) extends scala.util.control.ControlThrowable + implicit class SafeDoubleOps(private val d: Double) extends AnyVal { @inline def toSafeLong(pos: Position)(implicit ev: EvalErrorScope): Long = { + if (d.isInfinite || d.isNaN) + Error.fail("numeric value is not finite", pos) if (d < Val.DOUBLE_MIN_SAFE_INTEGER || d > Val.DOUBLE_MAX_SAFE_INTEGER) Error.fail("numeric value outside safe integer range for bitwise operation", pos) d.toLong diff --git a/sjsonnet/src/sjsonnet/ExprTransform.scala b/sjsonnet/src/sjsonnet/ExprTransform.scala deleted file mode 100644 index 38d4a051..00000000 --- a/sjsonnet/src/sjsonnet/ExprTransform.scala +++ /dev/null @@ -1,303 +0,0 @@ -package sjsonnet - -import Expr._ - -/** Simple tree transformer for the AST. */ -abstract class ExprTransform { - - def transform(expr: Expr): Expr - - def rec(expr: Expr): Expr = { - expr match { - case Select(pos, x, name) => - val x2 = transform(x) - if (x2 eq x) expr - else Select(pos, x2, name) - - case Apply(pos, x, y, namedNames, tailstrict) => - val x2 = transform(x) - val y2 = transformArr(y) - if ((x2 eq x) && (y2 eq y)) expr - else Apply(pos, x2, y2, namedNames, tailstrict) - - case Apply0(pos, x, tailstrict) => - val x2 = transform(x) - if (x2 eq x) expr - else Apply0(pos, x2, tailstrict) - - case Apply1(pos, x, y, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else Apply1(pos, x2, y2, tailstrict) - - case Apply2(pos, x, y, z, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - val z2 = transform(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Apply2(pos, x2, y2, z2, tailstrict) - - case Apply3(pos, x, y, z, a, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - val z2 = transform(z) - val a2 = transform(a) - if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else Apply3(pos, x2, y2, z2, a2, tailstrict) - - case ApplyBuiltin(pos, func, x, tailstrict) => - val x2 = transformArr(x) - if (x2 eq x) expr - else ApplyBuiltin(pos, func, x2, tailstrict) - - case ApplyBuiltin1(pos, func, x, tailstrict) => - val x2 = transform(x) - if (x2 eq x) expr - else ApplyBuiltin1(pos, func, x2, tailstrict) - - case ApplyBuiltin2(pos, func, x, y, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else ApplyBuiltin2(pos, func, x2, y2, tailstrict) - - case ApplyBuiltin3(pos, func, x, y, z, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - val z2 = transform(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else ApplyBuiltin3(pos, func, x2, y2, z2, tailstrict) - - case ApplyBuiltin4(pos, func, x, y, z, a, tailstrict) => - val x2 = transform(x) - val y2 = transform(y) - val z2 = transform(z) - val a2 = transform(a) - if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else ApplyBuiltin4(pos, func, x2, y2, z2, a2, tailstrict) - - case UnaryOp(pos, op, x) => - val x2 = transform(x) - if (x2 eq x) expr - else UnaryOp(pos, op, x2) - - case BinaryOp(pos, x, op, y) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else BinaryOp(pos, x2, op, y2) - - case And(pos, x, y) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else And(pos, x2, y2) - - case Or(pos, x, y) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else Or(pos, x2, y2) - - case InSuper(pos, x, selfIdx) => - val x2 = transform(x) - if (x2 eq x) expr - else InSuper(pos, x2, selfIdx) - - case Lookup(pos, x, y) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else Lookup(pos, x2, y2) - - case LookupSuper(pos, selfIdx, x) => - val x2 = transform(x) - if (x2 eq x) expr - else LookupSuper(pos, selfIdx, x2) - - case Function(pos, x, y) => - val x2 = transformParams(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else Function(pos, x2, y2) - - case LocalExpr(pos, x, y) => - val x2 = transformBinds(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else LocalExpr(pos, x2, y2) - - case IfElse(pos, x, y, z) => - val x2 = transform(x) - val y2 = transform(y) - val z2 = transform(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else IfElse(pos, x2, y2, z2) - - case ObjBody.MemberList(pos, x, y, z) => - val x2 = transformBinds(x) - val y2 = transformFields(y) - val z2 = transformAsserts(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else ObjBody.MemberList(pos, x2, y2, z2) - - case AssertExpr(pos, x, y) => - val x2 = transformAssert(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else AssertExpr(pos, x2, y2) - - case Comp(pos, x, y, z) => - val x2 = transform(x) - val y2 = transform(y).asInstanceOf[ForSpec] - val z2 = transformArr(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Comp(pos, x2, y2, z2) - - case Arr(pos, x) => - val x2 = transformArr(x) - if (x2 eq x) expr - else Arr(pos, x2) - - case ObjExtend(superPos, x, y) => - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) expr - else ObjExtend(superPos, x2, y2.asInstanceOf[ObjBody]) - - case ObjBody.ObjComp(pos, p, k, v, pl, o, f, r) => - val p2 = transformBinds(p) - val k2 = transform(k) - val v2 = transform(v) - val o2 = transformBinds(o) - val f2 = transform(f).asInstanceOf[ForSpec] - val r2 = transformList(r).asInstanceOf[List[CompSpec]] - if ((p2 eq p) && (k2 eq k) && (v2 eq v) && (o2 eq o) && (f2 eq f) && (r2 eq r)) expr - else ObjBody.ObjComp(pos, p2, k2, v2, pl, o2, f2, r2) - - case Slice(pos, v, x, y, z) => - val v2 = transform(v) - val x2 = transformOption(x) - val y2 = transformOption(y) - val z2 = transformOption(z) - if ((v2 eq v) && (x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Slice(pos, v2, x2, y2, z2) - - case IfSpec(pos, x) => - val x2 = transform(x) - if (x2 eq x) expr - else IfSpec(pos, x2) - - case ForSpec(pos, name, x) => - val x2 = transform(x) - if (x2 eq x) expr - else ForSpec(pos, name, x2) - - case Expr.Error(pos, x) => - val x2 = transform(x) - if (x2 eq x) expr - else Expr.Error(pos, x2) - - case other => other - } - } - - protected def transformArr[T <: Expr](a: Array[T]): Array[T] = - transformGenericArr(a)((transform(_)).asInstanceOf[T => T]) - - protected def transformParams(p: Params): Params = { - if (p == null) return null - val defs = p.defaultExprs - if (defs == null) p - else { - val defs2 = transformArr(defs) - if (defs2 eq defs) p - else p.copy(defaultExprs = defs2) - } - } - - protected def transformBinds(a: Array[Bind]): Array[Bind] = - transformGenericArr(a)(transformBind) - - protected def transformFields(a: Array[Member.Field]): Array[Member.Field] = - transformGenericArr(a)(transformField) - - protected def transformAsserts(a: Array[Member.AssertStmt]): Array[Member.AssertStmt] = - transformGenericArr(a)(transformAssert) - - protected def transformBind(b: Bind): Bind = { - val args = b.args - val rhs = b.rhs - val args2 = transformParams(args) - val rhs2 = transform(rhs) - if ((args2 eq args) && (rhs2 eq rhs)) b - else b.copy(args = args2, rhs = rhs2) - } - - protected def transformField(f: Member.Field): Member.Field = { - val x = f.fieldName - val y = f.args - val z = f.rhs - val x2 = transformFieldName(x) - val y2 = transformParams(y) - val z2 = transform(z) - if ((x2 eq x) && (y2 eq y) && (z2 eq z)) f - else f.copy(fieldName = x2, args = y2, rhs = z2) - } - - protected def transformFieldName(f: FieldName): FieldName = f match { - case FieldName.Dyn(x) => - val x2 = transform(x) - if (x2 eq x) f else FieldName.Dyn(x2) - case _ => f - } - - protected def transformAssert(a: Member.AssertStmt): Member.AssertStmt = { - val x = a.value - val y = a.msg - val x2 = transform(x) - val y2 = transform(y) - if ((x2 eq x) && (y2 eq y)) a - else a.copy(value = x2, msg = y2) - } - - protected def transformOption(o: Option[Expr]): Option[Expr] = o match { - case Some(e) => - val e2 = transform(e) - if (e2 eq e) o else Some(e2) - case None => o - } - - protected def transformList(l: List[Expr]): List[Expr] = { - val lb = List.newBuilder[Expr] - var diff = false - l.foreach { e => - val e2 = transform(e) - lb.+=(e2) - if (e2 ne e) diff = true - } - if (diff) lb.result() else l - } - - protected def transformGenericArr[T <: AnyRef](a: Array[T])(f: T => T): Array[T] = { - if (a == null) return null - var i = 0 - while (i < a.length) { - val x1 = a(i) - val x2 = f(x1) - if (x1 ne x2) { - val a2 = a.clone() - a2(i) = x2 - i += 1 - while (i < a2.length) { - a2(i) = f(a2(i)) - i += 1 - } - return a2 - } - i += 1 - } - a - } -} diff --git a/sjsonnet/src/sjsonnet/ScopedExprTransform.scala b/sjsonnet/src/sjsonnet/ScopedExprTransform.scala deleted file mode 100644 index c7d58bde..00000000 --- a/sjsonnet/src/sjsonnet/ScopedExprTransform.scala +++ /dev/null @@ -1,194 +0,0 @@ -package sjsonnet - -import sjsonnet.Expr.ObjBody.{MemberList, ObjComp} -import sjsonnet.Expr._ - -import scala.annotation.nowarn -import scala.collection.immutable.HashMap - -/** Tree transformer that keeps track of the bindings in the static scope. */ -class ScopedExprTransform extends ExprTransform { - import ScopedExprTransform._ - var scope: Scope = emptyScope - - // Marker for Exprs in the scope that should not be used because they need to be evaluated in a different scope - val dynamicExpr: Expr = new Expr { - var pos: Position = null; override def toString = "dynamicExpr" - } - - def transform(e: Expr): Expr = e match { - case LocalExpr(pos, bindings, returned) => - val (b2, r2) = nestedConsecutiveBindings(bindings)(transformBind)(transform(returned)) - if ((b2 eq bindings) && (r2 eq returned)) e - else LocalExpr(pos, b2, r2) - - case MemberList(pos, binds, fields, asserts) => - val fields2 = transformGenericArr(fields)(transformFieldNameOnly) - val (binds2, (fields3, asserts2)) = nestedObject(dynamicExpr, dynamicExpr) { - nestedConsecutiveBindings(binds)(transformBind) { - val fields3 = transformGenericArr(fields2)(transformFieldNoName) - val asserts2 = transformAsserts(asserts) - (fields3, asserts2) - } - } - if ((binds2 eq binds) && (fields3 eq fields) && (asserts2 eq asserts)) e - else ObjBody.MemberList(pos, binds2, fields3, asserts2) - - case Function(pos, params, body) => - nestedNames(params.names)(rec(e)) - - case ObjComp(pos, preLocals, key, value, plus, postLocals, first, rest) => - val (f2 :: r2, (k2, (pre2, post2, v2))) = compSpecs( - first :: rest, - { () => - ( - transform(key), - nestedBindings(dynamicExpr, dynamicExpr, preLocals ++ postLocals) { - (transformBinds(preLocals), transformBinds(postLocals), transform(value)) - } - ) - } - ): @unchecked - if ( - (f2 eq first) && (k2 eq key) && (v2 eq value) && (pre2 eq preLocals) && (post2 eq postLocals) && ( - r2, - rest - ).zipped.forall(_ eq _): @nowarn - ) e - else ObjComp(pos, pre2, k2, v2, plus, post2, f2.asInstanceOf[ForSpec], r2) - - case Comp(pos, value, first, rest) => - val (f2 :: r2, v2) = compSpecs(first :: rest.toList, () => transform(value)): @unchecked - if ((f2 eq first) && (v2 eq value) && (r2, rest).zipped.forall(_ eq _): @nowarn) e - else Comp(pos, v2, f2.asInstanceOf[ForSpec], r2.toArray) - - case e => rec(e) - } - - override def transformBind(b: Bind): Bind = { - val args = b.args - val rhs = b.rhs - nestedNames(if (args == null) null else args.names) { - val args2 = transformParams(args) - val rhs2 = transform(rhs) - if ((args2 eq args) && (rhs2 eq rhs)) b - else b.copy(args = args2, rhs = rhs2) - } - } - - protected def transformFieldNameOnly(f: Member.Field): Member.Field = { - val x = f.fieldName - val x2 = transformFieldName(x) - if (x2 eq x) f else f.copy(fieldName = x2) - } - - protected def transformFieldNoName(f: Member.Field): Member.Field = { - def g = { - val y = f.args - val z = f.rhs - val y2 = transformParams(y) - val z2 = transform(z) - if ((y2 eq y) && (z2 eq z)) f else f.copy(args = y2, rhs = z2) - } - if (f.args == null) g - else nestedNames(f.args.names)(g) - } - - override protected def transformField(f: Member.Field): Member.Field = ??? - - protected def compSpecs[T](a: List[CompSpec], value: () => T): (List[CompSpec], T) = a match { - case (c @ ForSpec(pos, name, cond)) :: cs => - val c2 = rec(c).asInstanceOf[ForSpec] - nestedWith(c2.name, dynamicExpr) { - val (cs2, value2) = compSpecs(cs, value) - (c2 :: cs2, value2) - } - case (c @ IfSpec(pos, cond)) :: cs => - val c2 = rec(c).asInstanceOf[CompSpec] - val (cs2, value2) = compSpecs(cs, value) - (c2 :: cs2, value2) - case Nil => - (Nil, value()) - } - - protected def nestedNew[T](sc: Scope)(f: => T): T = { - val oldScope = scope - scope = sc - try f - finally { scope = oldScope } - } - - protected def nestedWith[T](n: String, e: Expr)(f: => T): T = - nestedNew( - new Scope(scope.mappings.updated(n, ScopedVal(e, scope, scope.size)), scope.size + 1) - )(f) - - protected def nestedFileScope[T](fs: FileScope)(f: => T): T = - nestedNew(emptyScope)(f) - - protected def nestedConsecutiveBindings[T](a: Array[Bind])(f: => Bind => Bind)( - g: => T): (Array[Bind], T) = { - if (a == null || a.length == 0) (a, g) - else { - val oldScope = scope - try { - val mappings = a.zipWithIndex.map { case (b, idx) => - (b.name, ScopedVal(if (b.args == null) b.rhs else b, scope, scope.size + idx)) - } - scope = new Scope(oldScope.mappings ++ mappings, oldScope.size + a.length) - var changed = false - val a2 = a.zipWithIndex.map { case (b, idx) => - val b2 = f(b) - val sv = mappings(idx)._2.copy(v = if (b2.args == null) b2.rhs else b2) - scope = new Scope(scope.mappings.updated(b.name, sv), scope.size) - if (b2 ne b) changed = true - b2 - } - (if (changed) a2 else a, g) - } finally { scope = oldScope } - } - } - - protected def nestedBindings[T](a: Array[Bind])(f: => T): T = { - if (a == null || a.length == 0) f - else { - val newm = a.zipWithIndex.map { case (b, idx) => - // println(s"Binding ${b.name} to ${scope.size + idx}") - (b.name, ScopedVal(if (b.args == null) b.rhs else b, scope, scope.size + idx)) - } - nestedNew(new Scope(scope.mappings ++ newm, scope.size + a.length))(f) - } - } - - protected def nestedObject[T](self0: Expr, super0: Expr)(f: => T): T = { - val self = ScopedVal(self0, scope, scope.size) - val sup = ScopedVal(super0, scope, scope.size + 1) - val newm = { - val m1 = scope.mappings + (("self", self)) + (("super", sup)) - if (scope.contains("self")) m1 else m1 + (("$", self)) - } - nestedNew(new Scope(newm, scope.size + 2))(f) - } - - protected def nestedBindings[T](self0: Expr, super0: Expr, a: Array[Bind])(f: => T): T = - nestedObject(self0, super0)(nestedBindings(a)(f)) - - protected def nestedNames[T](a: Array[String])(f: => T): T = { - if (a == null || a.length == 0) f - else { - val newm = a.zipWithIndex.map { case (n, idx) => - (n, ScopedVal(dynamicExpr, scope, scope.size + idx)) - } - nestedNew(new Scope(scope.mappings ++ newm, scope.size + a.length))(f) - } - } -} - -object ScopedExprTransform { - final case class ScopedVal(v: AnyRef, sc: Scope, idx: Int) - final class Scope(val mappings: HashMap[String, ScopedVal], val size: Int) { - def get(s: String): ScopedVal = mappings.getOrElse(s, null) - def contains(s: String): Boolean = mappings.contains(s) - } - def emptyScope: Scope = new Scope(HashMap.empty, 0) -} diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 2240966d..b0413213 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -1,11 +1,11 @@ package sjsonnet -import scala.annotation.switch +import scala.annotation.{nowarn, switch} +import scala.collection.immutable.HashMap import scala.collection.mutable import Expr.* import Evaluator.SafeDoubleOps -import ScopedExprTransform.* /** * StaticOptimizer performs necessary transformations for the evaluator (assigning ValScope indices) @@ -29,128 +29,696 @@ class StaticOptimizer( internedStaticFieldSets: mutable.HashMap[ Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean] - ]) - extends ScopedExprTransform { + ]) { + import StaticOptimizer.* + + private[this] var scope: Scope = emptyScope + + // Marker for Exprs in the scope that should not be used because they need to be evaluated in a different scope + private[this] val dynamicExpr: Expr = new Expr { + var pos: Position = null + override def toString = "dynamicExpr" + } + def optimize(e: Expr): Expr = transform(e) - override def transform(_e: Expr): Expr = { - // Fast path: fold pure numeric literal chains as raw doubles before the bottom-up transform. - // This avoids intermediate Val.Num + BinaryOp allocations for chains like `60 * 60 * 24`. - _e match { - case _: BinaryOp | _: UnaryOp => - val d = tryFoldAsDouble(_e) - if (!d.isNaN) return Val.Num(_e.pos, d) - case _ => + def transform(e: Expr): Expr = { + if (e == null) return null + val tag = e.tag + if (tag == ExprTags.BinaryOp || tag == ExprTags.UnaryOp) { + val d = tryFoldAsDouble(e) + if (!d.isNaN) return Val.Num(e.pos, d) } - super.transform(check(_e)) match { - case a: Apply => transformApply(a) - - case e @ Select(p, obj: Val.Obj, name) if obj.containsKey(name) => - try obj.value(name, p)(ev).asInstanceOf[Expr] - catch { case _: Exception => e } - - case Select(pos, ValidSuper(_, selfIdx), name) => - SelectSuper(pos, selfIdx, name) - - case Lookup(pos, ValidSuper(_, selfIdx), index) => - LookupSuper(pos, selfIdx, index) - - case BinaryOp(pos, lhs, BinaryOp.OP_in, ValidSuper(_, selfIdx)) => - InSuper(pos, lhs, selfIdx) - case b2 @ BinaryOp(pos, lhs: Val.Str, BinaryOp.OP_%, rhs) => - try { - rhs match { - case r: Val => - val partial = new Format.PartialApplyFmt(lhs.str) - try partial.evalRhs(r, ev, pos).asInstanceOf[Expr] - catch { - case _: Exception => - ApplyBuiltin1(pos, partial, rhs, tailstrict = false) - } - case _ => - ApplyBuiltin1(pos, new Format.PartialApplyFmt(lhs.str), rhs, tailstrict = false) - } - } catch { case _: Exception => b2 } - - case e @ Id(pos, name) => - scope.get(name) match { - case ScopedVal(v: Val with Expr, _, _) => v - case ScopedVal(_, _, idx) => ValidId(pos, name, idx) - case null if name == f"$$std" => std - case null if name == "std" => std - case null => - variableResolver(name) match { - case Some(v) => v // additional variable resolution - case None => - StaticError.fail( - "Unknown variable: " + name, - e - )(ev) + optimizeTransformed(transformChildren(check(e))) + } + + private def transformChildren(e: Expr): Expr = (e.tag: @switch) match { + case ExprTags.LocalExpr => + val local = e.asInstanceOf[LocalExpr] + val bindings = local.bindings + val returned = local.returned + val (bindings2, returned2) = + nestedConsecutiveBindings(bindings)(transformBind)(transform(returned)) + if ((bindings2 eq bindings) && (returned2 eq returned)) e + else LocalExpr(local.pos, bindings2, returned2) + + case ExprTags.`ObjBody.MemberList` => + val memberList = e.asInstanceOf[ObjBody.MemberList] + val binds = memberList.binds + val fields = memberList.fields + val asserts = memberList.asserts + val fields2 = transformGenericArr(fields)(transformFieldNameOnly) + val (binds2, (fields3, asserts2)) = nestedObject(dynamicExpr, dynamicExpr) { + nestedConsecutiveBindings(binds)(transformBind) { + val fields3 = transformGenericArr(fields2)(transformFieldNoName) + val asserts2 = transformAsserts(asserts) + (fields3, asserts2) + } + } + if ((binds2 eq binds) && (fields3 eq fields) && (asserts2 eq asserts)) e + else ObjBody.MemberList(memberList.pos, binds2, fields3, asserts2) + + case ExprTags.Function => + val function = e.asInstanceOf[Function] + nestedNames(function.params.names)(rec(e)) + + case ExprTags.`ObjBody.ObjComp` => + val objComp = e.asInstanceOf[ObjBody.ObjComp] + val specs = objComp.first :: objComp.rest + val (first2 :: rest2, (key2, (pre2, post2, value2))) = compSpecs( + specs, + { () => + ( + transform(objComp.key), + nestedBindings(dynamicExpr, dynamicExpr, objComp.preLocals ++ objComp.postLocals) { + ( + transformBinds(objComp.preLocals), + transformBinds(objComp.postLocals), + transform(objComp.value) + ) } + ) } + ): @unchecked + if ( + (first2 eq objComp.first) && (key2 eq objComp.key) && (value2 eq objComp.value) && + (pre2 eq objComp.preLocals) && (post2 eq objComp.postLocals) && + (rest2, objComp.rest).zipped.forall(_ eq _): @nowarn + ) e + else + ObjBody.ObjComp( + objComp.pos, + pre2, + key2, + value2, + objComp.plus, + post2, + first2.asInstanceOf[ForSpec], + rest2 + ) + + case ExprTags.Comp => + val comp = e.asInstanceOf[Comp] + val (first2 :: rest2, value2) = + compSpecs(comp.first :: comp.rest.toList, () => transform(comp.value)): @unchecked + if ( + (first2 eq comp.first) && (value2 eq comp.value) && (rest2, comp.rest).zipped.forall( + _ eq _ + ): @nowarn + ) e + else Comp(comp.pos, value2, first2.asInstanceOf[ForSpec], rest2.toArray) + + case _ => + rec(e) + } - case e @ Self(pos) => - scope.get("self") match { - case ScopedVal(v, _, idx) if v != null => ValidId(pos, "self", idx) - case _ => StaticError.fail("Can't use self outside of an object", e)(ev) + private def optimizeTransformed(e: Expr): Expr = { + if (e == null) return null + (e.tag: @switch) match { + case ExprTags.UNTAGGED => + e match { + case id @ Id(pos, name) => + scope.get(name) match { + case ScopedVal(v: Val with Expr, _, _) => v + case ScopedVal(_, _, idx) => ValidId(pos, name, idx) + case null if name == f"$$std" => std + case null if name == "std" => std + case null => + variableResolver(name) match { + case Some(v) => v + case None => + StaticError.fail( + "Unknown variable: " + name, + id + )(ev) + } + } + case _ => e } - case e @ $(pos) => - scope.get("$") match { - case ScopedVal(v, _, idx) if v != null => ValidId(pos, "$", idx) - case _ => StaticError.fail("Can't use $ outside of an object", e)(ev) + case ExprTags.ValidId => + e match { + case self @ Self(pos) => + scope.get("self") match { + case ScopedVal(v, _, idx) if v != null => ValidId(pos, "self", idx) + case _ => StaticError.fail("Can't use self outside of an object", self)(ev) + } + case _ => e } - case e @ Super(_) if !scope.contains("super") => - StaticError.fail("Can't use super outside of an object", e)(ev) - - case a: Arr if a.value.forall(_.isInstanceOf[Val]) => - Val.Arr(a.pos, a.value.map(e => e.asInstanceOf[Val])) - - case m @ ObjBody.MemberList(pos, binds, fields, asserts) => - // If static optimization has constant-folded originally-dynamic field names - // into fixed names, it's possible that we might now have duplicate names. - // In that case, we keep the object as a MemberList and leave it to the - // Evaluator to throw an error if/when the object is evaluated (in order - // to preserve proper laziness semantics). - def allFieldsStaticAndUniquelyNamed: Boolean = { - val seen = mutable.Set.empty[String] - fields.forall { f => - f.isStatic && seen.add(f.fieldName.asInstanceOf[FieldName.Fixed].value) - } + case ExprTags.BinaryOp => + e match { + case root @ $(pos) => + scope.get("$") match { + case ScopedVal(v, _, idx) if v != null => ValidId(pos, "$", idx) + case _ => StaticError.fail("Can't use $ outside of an object", root)(ev) + } + case binary: BinaryOp => + binary.rhs match { + case ValidSuper(_, selfIdx) if binary.op == BinaryOp.OP_in => + InSuper(binary.pos, binary.lhs, selfIdx) + case rhs if binary.op == BinaryOp.OP_% => + binary.lhs match { + case lhs: Val.Str => + try { + rhs match { + case r: Val => + val partial = new Format.PartialApplyFmt(lhs.str) + try partial.evalRhs(r, ev, binary.pos).asInstanceOf[Expr] + catch { + case _: Exception => + ApplyBuiltin1(binary.pos, partial, rhs, tailstrict = false) + } + case _ => + ApplyBuiltin1( + binary.pos, + new Format.PartialApplyFmt(lhs.str), + rhs, + tailstrict = false + ) + } + } catch { case _: Exception => binary } + case lhs: Val => + rhs match { + case r: Val => tryFoldBinaryOp(binary.pos, lhs, binary.op, r, binary) + case _ => binary + } + case _ => binary + } + case rhs: Val => + binary.lhs match { + case lhs: Val => tryFoldBinaryOp(binary.pos, lhs, binary.op, rhs, binary) + case _ => binary + } + case _ => binary + } + case _ => e + } + + case ExprTags.Select => + e match { + case sup @ Super(_) if !scope.contains("super") => + StaticError.fail("Can't use super outside of an object", sup)(ev) + case select: Select => + select.value match { + case obj: Val.Obj if obj.containsKey(select.name) => + try obj.value(select.name, select.pos)(ev).asInstanceOf[Expr] + catch { case _: Exception => select } + case ValidSuper(_, selfIdx) => + SelectSuper(select.pos, selfIdx, select.name) + case _ => + select + } + case _ => e + } + + case ExprTags.Apply => + transformApply(e.asInstanceOf[Apply]) + + case ExprTags.Lookup => + val lookup = e.asInstanceOf[Lookup] + lookup.value match { + case ValidSuper(_, selfIdx) => LookupSuper(lookup.pos, selfIdx, lookup.index) + case _ => lookup + } + + case ExprTags.Arr => + val arr = e.asInstanceOf[Arr] + if (arr.value.forall(_.isInstanceOf[Val])) + Val.Arr(arr.pos, arr.value.map(_.asInstanceOf[Val])) + else e + + case ExprTags.`ObjBody.MemberList` => + optimizeMemberList(e.asInstanceOf[ObjBody.MemberList]) + + case ExprTags.UnaryOp => + val unary = e.asInstanceOf[UnaryOp] + unary.value match { + case v: Val => tryFoldUnaryOp(unary.pos, unary.op, v, e) + case _ => e } - if (binds == null && asserts == null && allFieldsStaticAndUniquelyNamed) - Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings) - else m - // Aggressive optimizations: constant folding, branch elimination, short-circuit elimination. - // These reduce AST node count at parse time, benefiting long-running Jsonnet programs. - // Constant folding: BinaryOp with two constant operands (most common case first) - case e @ BinaryOp(pos, lhs: Val, op, rhs: Val) => tryFoldBinaryOp(pos, lhs, op, rhs, e) - - // Constant folding: UnaryOp with constant operand - case e @ UnaryOp(pos, op, v: Val) => tryFoldUnaryOp(pos, op, v, e) - - // Branch elimination: constant condition in if-else - case IfElse(pos, _: Val.True, thenExpr, _) => - thenExpr.pos = pos; thenExpr - case IfElse(pos, _: Val.False, _, elseExpr) => - if (elseExpr == null) Val.Null(pos) - else { elseExpr.pos = pos; elseExpr } - - // Short-circuit elimination for And/Or with constant lhs. - // - // IMPORTANT: rhs MUST be guarded as `Val.Bool` — do NOT relax this to arbitrary Expr. - // The Evaluator's visitAnd/visitOr enforces that rhs evaluates to Bool, throwing - // "binary operator && does not operate on s" otherwise. If we fold `true && rhs` - // into just `rhs` without the Bool guard, we silently remove that runtime type check, - // causing programs like `true && "hello"` to return "hello" instead of erroring. - // See: Evaluator.visitAnd / Evaluator.visitOr for the authoritative runtime semantics. - case And(pos, _: Val.True, rhs: Val.Bool) => rhs.pos = pos; rhs - case And(pos, _: Val.False, _) => Val.False(pos) - case Or(pos, _: Val.True, _) => Val.True(pos) - case Or(pos, _: Val.False, rhs: Val.Bool) => rhs.pos = pos; rhs - case e => e + case ExprTags.IfElse => + val ifElse = e.asInstanceOf[IfElse] + ifElse.cond match { + case _: Val.True => + ifElse.`then`.pos = ifElse.pos + ifElse.`then` + case _: Val.False => + val elseExpr = ifElse.`else` + if (elseExpr == null) Val.Null(ifElse.pos) + else { + elseExpr.pos = ifElse.pos + elseExpr + } + case _ => e + } + + case ExprTags.And => + val and = e.asInstanceOf[And] + and.lhs match { + case _: Val.True => + and.rhs match { + case rhs: Val.Bool => + rhs.pos = and.pos + rhs + case _ => e + } + case _: Val.False => Val.False(and.pos) + case _ => e + } + + case ExprTags.Or => + val or = e.asInstanceOf[Or] + or.lhs match { + case _: Val.True => Val.True(or.pos) + case _: Val.False => + or.rhs match { + case rhs: Val.Bool => + rhs.pos = or.pos + rhs + case _ => e + } + case _ => e + } + + case _ => + e + } + } + + private def optimizeMemberList(m: ObjBody.MemberList): Expr = { + def allFieldsStaticAndUniquelyNamed: Boolean = { + val seen = mutable.Set.empty[String] + m.fields.forall { f => + f.isStatic && seen.add(f.fieldName.asInstanceOf[FieldName.Fixed].value) + } + } + + if (m.binds == null && m.asserts == null && allFieldsStaticAndUniquelyNamed) + Val.staticObject(m.pos, m.fields, internedStaticFieldSets, internedStrings) + else m + } + + private def rec(expr: Expr): Expr = expr match { + case Select(pos, x, name) => + val x2 = transform(x) + if (x2 eq x) expr + else Select(pos, x2, name) + + case Apply(pos, x, y, namedNames, tailstrict) => + val x2 = transform(x) + val y2 = transformArr(y) + if ((x2 eq x) && (y2 eq y)) expr + else Apply(pos, x2, y2, namedNames, tailstrict) + + case Apply0(pos, x, tailstrict) => + val x2 = transform(x) + if (x2 eq x) expr + else Apply0(pos, x2, tailstrict) + + case Apply1(pos, x, y, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else Apply1(pos, x2, y2, tailstrict) + + case Apply2(pos, x, y, z, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + val z2 = transform(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else Apply2(pos, x2, y2, z2, tailstrict) + + case Apply3(pos, x, y, z, a, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + val z2 = transform(z) + val a2 = transform(a) + if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr + else Apply3(pos, x2, y2, z2, a2, tailstrict) + + case ApplyBuiltin(pos, func, x, tailstrict) => + val x2 = transformArr(x) + if (x2 eq x) expr + else ApplyBuiltin(pos, func, x2, tailstrict) + + case ApplyBuiltin1(pos, func, x, tailstrict) => + val x2 = transform(x) + if (x2 eq x) expr + else ApplyBuiltin1(pos, func, x2, tailstrict) + + case ApplyBuiltin2(pos, func, x, y, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else ApplyBuiltin2(pos, func, x2, y2, tailstrict) + + case ApplyBuiltin3(pos, func, x, y, z, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + val z2 = transform(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else ApplyBuiltin3(pos, func, x2, y2, z2, tailstrict) + + case ApplyBuiltin4(pos, func, x, y, z, a, tailstrict) => + val x2 = transform(x) + val y2 = transform(y) + val z2 = transform(z) + val a2 = transform(a) + if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr + else ApplyBuiltin4(pos, func, x2, y2, z2, a2, tailstrict) + + case UnaryOp(pos, op, x) => + val x2 = transform(x) + if (x2 eq x) expr + else UnaryOp(pos, op, x2) + + case BinaryOp(pos, x, op, y) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else BinaryOp(pos, x2, op, y2) + + case And(pos, x, y) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else And(pos, x2, y2) + + case Or(pos, x, y) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else Or(pos, x2, y2) + + case InSuper(pos, x, selfIdx) => + val x2 = transform(x) + if (x2 eq x) expr + else InSuper(pos, x2, selfIdx) + + case Lookup(pos, x, y) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else Lookup(pos, x2, y2) + + case LookupSuper(pos, selfIdx, x) => + val x2 = transform(x) + if (x2 eq x) expr + else LookupSuper(pos, selfIdx, x2) + + case Function(pos, x, y) => + val x2 = transformParams(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else Function(pos, x2, y2) + + case LocalExpr(pos, x, y) => + val x2 = transformBinds(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else LocalExpr(pos, x2, y2) + + case IfElse(pos, x, y, z) => + val x2 = transform(x) + val y2 = transform(y) + val z2 = transform(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else IfElse(pos, x2, y2, z2) + + case ObjBody.MemberList(pos, x, y, z) => + val x2 = transformBinds(x) + val y2 = transformFields(y) + val z2 = transformAsserts(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else ObjBody.MemberList(pos, x2, y2, z2) + + case AssertExpr(pos, x, y) => + val x2 = transformAssert(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else AssertExpr(pos, x2, y2) + + case Comp(pos, x, y, z) => + val x2 = transform(x) + val y2 = transform(y).asInstanceOf[ForSpec] + val z2 = transformArr(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else Comp(pos, x2, y2, z2) + + case Arr(pos, x) => + val x2 = transformArr(x) + if (x2 eq x) expr + else Arr(pos, x2) + + case ObjExtend(pos, x, y) => + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) expr + else ObjExtend(pos, x2, y2.asInstanceOf[ObjBody]) + + case ObjBody.ObjComp(pos, p, k, v, pl, o, f, r) => + val p2 = transformBinds(p) + val k2 = transform(k) + val v2 = transform(v) + val o2 = transformBinds(o) + val f2 = transform(f).asInstanceOf[ForSpec] + val r2 = transformList(r).asInstanceOf[List[CompSpec]] + if ((p2 eq p) && (k2 eq k) && (v2 eq v) && (o2 eq o) && (f2 eq f) && (r2 eq r)) expr + else ObjBody.ObjComp(pos, p2, k2, v2, pl, o2, f2, r2) + + case Slice(pos, v, x, y, z) => + val v2 = transform(v) + val x2 = transformOption(x) + val y2 = transformOption(y) + val z2 = transformOption(z) + if ((v2 eq v) && (x2 eq x) && (y2 eq y) && (z2 eq z)) expr + else Slice(pos, v2, x2, y2, z2) + + case IfSpec(pos, x) => + val x2 = transform(x) + if (x2 eq x) expr + else IfSpec(pos, x2) + + case ForSpec(pos, name, x) => + val x2 = transform(x) + if (x2 eq x) expr + else ForSpec(pos, name, x2) + + case Expr.Error(pos, x) => + val x2 = transform(x) + if (x2 eq x) expr + else Expr.Error(pos, x2) + + case other => other + } + + private def transformArr[T <: Expr](a: Array[T]): Array[T] = + transformGenericArr(a)((transform(_)).asInstanceOf[T => T]) + + private def transformParams(p: Params): Params = { + if (p == null) return null + val defs = p.defaultExprs + if (defs == null) p + else { + val defs2 = transformArr(defs) + if (defs2 eq defs) p + else p.copy(defaultExprs = defs2) + } + } + + private def transformBinds(a: Array[Bind]): Array[Bind] = + transformGenericArr(a)(transformBind) + + private def transformFields(a: Array[Member.Field]): Array[Member.Field] = + transformGenericArr(a)(transformField) + + private def transformAsserts(a: Array[Member.AssertStmt]): Array[Member.AssertStmt] = + transformGenericArr(a)(transformAssert) + + private def transformBind(b: Bind): Bind = { + val args = b.args + val rhs = b.rhs + nestedNames(if (args == null) null else args.names) { + val args2 = transformParams(args) + val rhs2 = transform(rhs) + if ((args2 eq args) && (rhs2 eq rhs)) b + else b.copy(args = args2, rhs = rhs2) + } + } + + private def transformField(f: Member.Field): Member.Field = { + val x = f.fieldName + val y = f.args + val z = f.rhs + val x2 = transformFieldName(x) + val y2 = transformParams(y) + val z2 = transform(z) + if ((x2 eq x) && (y2 eq y) && (z2 eq z)) f + else f.copy(fieldName = x2, args = y2, rhs = z2) + } + + private def transformFieldNameOnly(f: Member.Field): Member.Field = { + val x = f.fieldName + val x2 = transformFieldName(x) + if (x2 eq x) f else f.copy(fieldName = x2) + } + + private def transformFieldNoName(f: Member.Field): Member.Field = { + def transformed = { + val y = f.args + val z = f.rhs + val y2 = transformParams(y) + val z2 = transform(z) + if ((y2 eq y) && (z2 eq z)) f else f.copy(args = y2, rhs = z2) + } + if (f.args == null) transformed + else nestedNames(f.args.names)(transformed) + } + + private def transformFieldName(f: FieldName): FieldName = f match { + case FieldName.Dyn(x) => + transform(x) match { + case x2: Val.Str => + FieldName.Fixed(x2.str) + case x2 if x2 eq x => f + case x2 => FieldName.Dyn(x2) + } + case _ => f + } + + private def transformAssert(a: Member.AssertStmt): Member.AssertStmt = { + val x = a.value + val y = a.msg + val x2 = transform(x) + val y2 = transform(y) + if ((x2 eq x) && (y2 eq y)) a + else a.copy(value = x2, msg = y2) + } + + private def transformOption(o: Option[Expr]): Option[Expr] = o match { + case Some(e) => + val e2 = transform(e) + if (e2 eq e) o else Some(e2) + case None => o + } + + private def transformList(l: List[Expr]): List[Expr] = { + val lb = List.newBuilder[Expr] + var diff = false + l.foreach { e => + val e2 = transform(e) + lb += e2 + if (e2 ne e) diff = true + } + if (diff) lb.result() else l + } + + private def transformGenericArr[T <: AnyRef](a: Array[T])(f: T => T): Array[T] = { + if (a == null) return null + var i = 0 + while (i < a.length) { + val x1 = a(i) + val x2 = f(x1) + if (x1 ne x2) { + val a2 = a.clone() + a2(i) = x2 + i += 1 + while (i < a2.length) { + a2(i) = f(a2(i)) + i += 1 + } + return a2 + } + i += 1 + } + a + } + + private def compSpecs[T](a: List[CompSpec], value: () => T): (List[CompSpec], T) = a match { + case (c @ ForSpec(_, _, _)) :: cs => + val c2 = rec(c).asInstanceOf[ForSpec] + nestedWith(c2.name, dynamicExpr) { + val (cs2, value2) = compSpecs(cs, value) + (c2 :: cs2, value2) + } + case (c @ IfSpec(_, _)) :: cs => + val c2 = rec(c).asInstanceOf[CompSpec] + val (cs2, value2) = compSpecs(cs, value) + (c2 :: cs2, value2) + case Nil => + (Nil, value()) + } + + private def nestedNew[T](sc: Scope)(f: => T): T = { + val oldScope = scope + scope = sc + try f + finally scope = oldScope + } + + private def nestedWith[T](n: String, e: Expr)(f: => T): T = + nestedNew( + new Scope(scope.mappings.updated(n, ScopedVal(e, scope, scope.size)), scope.size + 1) + )(f) + + private def nestedFileScope[T](fs: FileScope)(f: => T): T = + nestedNew(emptyScope)(f) + + private def nestedConsecutiveBindings[T](a: Array[Bind])(f: => Bind => Bind)( + g: => T): (Array[Bind], T) = { + if (a == null || a.length == 0) (a, g) + else { + val oldScope = scope + try { + val mappings = a.zipWithIndex.map { case (b, idx) => + (b.name, ScopedVal(if (b.args == null) b.rhs else b, scope, scope.size + idx)) + } + scope = new Scope(oldScope.mappings ++ mappings, oldScope.size + a.length) + var changed = false + val a2 = a.zipWithIndex.map { case (b, idx) => + val b2 = f(b) + val sv = mappings(idx)._2.copy(v = if (b2.args == null) b2.rhs else b2) + scope = new Scope(scope.mappings.updated(b.name, sv), scope.size) + if (b2 ne b) changed = true + b2 + } + (if (changed) a2 else a, g) + } finally scope = oldScope + } + } + + private def nestedBindings[T](a: Array[Bind])(f: => T): T = { + if (a == null || a.length == 0) f + else { + val newMappings = a.zipWithIndex.map { case (b, idx) => + (b.name, ScopedVal(if (b.args == null) b.rhs else b, scope, scope.size + idx)) + } + nestedNew(new Scope(scope.mappings ++ newMappings, scope.size + a.length))(f) + } + } + + private def nestedObject[T](self0: Expr, super0: Expr)(f: => T): T = { + val self = ScopedVal(self0, scope, scope.size) + val sup = ScopedVal(super0, scope, scope.size + 1) + val newMappings = { + val withSelf = scope.mappings + (("self", self)) + (("super", sup)) + if (scope.contains("self")) withSelf else withSelf + (("$", self)) + } + nestedNew(new Scope(newMappings, scope.size + 2))(f) + } + + private def nestedBindings[T](self0: Expr, super0: Expr, a: Array[Bind])(f: => T): T = + nestedObject(self0, super0)(nestedBindings(a)(f)) + + private def nestedNames[T](a: Array[String])(f: => T): T = { + if (a == null || a.length == 0) f + else { + val newMappings = a.zipWithIndex.map { case (n, idx) => + (n, ScopedVal(dynamicExpr, scope, scope.size + idx)) + } + nestedNew(new Scope(scope.mappings ++ newMappings, scope.size + a.length))(f) } } @@ -163,51 +731,59 @@ class StaticOptimizer( */ private def tryFoldAsDouble(e: Expr): Double = try { - e match { - case Val.Num(_, n) => n - case BinaryOp(pos, lhs, op, rhs) => - val l = tryFoldAsDouble(lhs) - if (l.isNaN) return Double.NaN - val r = tryFoldAsDouble(rhs) - if (r.isNaN) return Double.NaN - (op: @switch) match { - case BinaryOp.OP_+ => - val res = l + r; if (res.isInfinite) return Double.NaN; res - case BinaryOp.OP_- => - val res = l - r; if (res.isInfinite) return Double.NaN; res - case BinaryOp.OP_* => - val res = l * r; if (res.isInfinite) return Double.NaN; res - case BinaryOp.OP_/ => - if (r == 0) return Double.NaN - val res = l / r; if (res.isInfinite) return Double.NaN; res - case BinaryOp.OP_% => l % r - case BinaryOp.OP_<< => - val ll = l.toSafeLong(pos)(ev); val rr = r.toSafeLong(pos)(ev) - if (rr < 0) return Double.NaN - if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) return Double.NaN - (ll << rr).toDouble - case BinaryOp.OP_>> => - val ll = l.toSafeLong(pos)(ev); val rr = r.toSafeLong(pos)(ev) - if (rr < 0) return Double.NaN - (ll >> rr).toDouble - case BinaryOp.OP_& => - (l.toSafeLong(pos)(ev) & r.toSafeLong(pos)(ev)).toDouble - case BinaryOp.OP_^ => - (l.toSafeLong(pos)(ev) ^ r.toSafeLong(pos)(ev)).toDouble - case BinaryOp.OP_| => - (l.toSafeLong(pos)(ev) | r.toSafeLong(pos)(ev)).toDouble - case _ => Double.NaN // non-numeric op (comparison, equality, etc.) - } - case UnaryOp(pos, op, v) => - val d = tryFoldAsDouble(v) - if (d.isNaN) return Double.NaN - (op: @switch) match { - case Expr.UnaryOp.OP_- => -d - case Expr.UnaryOp.OP_+ => d - case Expr.UnaryOp.OP_~ => (~d.toSafeLong(pos)(ev)).toDouble - case _ => Double.NaN - } - case _ => Double.NaN + val tag = e.tag + if (tag == ExprTags.`Val.Literal`) { + e match { + case Val.Num(_, n) => n + case _ => Double.NaN + } + } else if (tag == ExprTags.BinaryOp && e.isInstanceOf[BinaryOp]) { + val binary = e.asInstanceOf[BinaryOp] + val l = tryFoldAsDouble(binary.lhs) + if (l.isNaN) return Double.NaN + val r = tryFoldAsDouble(binary.rhs) + if (r.isNaN) return Double.NaN + (binary.op: @switch) match { + case BinaryOp.OP_+ => + val res = l + r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_- => + val res = l - r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_* => + val res = l * r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_/ => + if (r == 0) return Double.NaN + val res = l / r; if (res.isInfinite) return Double.NaN; res + case BinaryOp.OP_% => + l % r + case BinaryOp.OP_<< => + val ll = l.toSafeLong(binary.pos)(ev); val rr = r.toSafeLong(binary.pos)(ev) + if (rr < 0) return Double.NaN + if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr))) return Double.NaN + (ll << rr).toDouble + case BinaryOp.OP_>> => + val ll = l.toSafeLong(binary.pos)(ev); val rr = r.toSafeLong(binary.pos)(ev) + if (rr < 0) return Double.NaN + (ll >> rr).toDouble + case BinaryOp.OP_& => + (l.toSafeLong(binary.pos)(ev) & r.toSafeLong(binary.pos)(ev)).toDouble + case BinaryOp.OP_^ => + (l.toSafeLong(binary.pos)(ev) ^ r.toSafeLong(binary.pos)(ev)).toDouble + case BinaryOp.OP_| => + (l.toSafeLong(binary.pos)(ev) | r.toSafeLong(binary.pos)(ev)).toDouble + case _ => Double.NaN + } + } else if (tag == ExprTags.UnaryOp) { + val unary = e.asInstanceOf[UnaryOp] + val d = tryFoldAsDouble(unary.value) + if (d.isNaN) return Double.NaN + (unary.op: @switch) match { + case Expr.UnaryOp.OP_- => -d + case Expr.UnaryOp.OP_+ => d + case Expr.UnaryOp.OP_~ => (~d.toLong).toDouble + case _ => Double.NaN + } + } else { + Double.NaN } } catch { case _: Exception => Double.NaN } @@ -239,18 +815,6 @@ class StaticOptimizer( case _ => false } - override protected def transformFieldName(f: FieldName): FieldName = f match { - case FieldName.Dyn(x) => - transform(x) match { - case x2: Val.Str => - // println(s"----- Fixing FieldName: "+x2.value) - FieldName.Fixed(x2.str) - case x2 if x2 eq x => f - case x2 => FieldName.Dyn(x2) - } - case _ => f - } - private def transformApply(a: Apply): Expr = { val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match { case null => a @@ -387,7 +951,7 @@ class StaticOptimizer( } case Expr.UnaryOp.OP_~ => v match { - case n: Val.Num => Val.Num(pos, (~n.asSafeLong).toDouble) + case n: Val.Num => Val.Num(pos, (~n.rawDouble.toLong).toDouble) case _ => fallback } case Expr.UnaryOp.OP_+ => @@ -426,8 +990,8 @@ class StaticOptimizer( } case BinaryOp.OP_% => (lhs, rhs) match { - case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l % r) - case _ => fallback + case (Val.Num(_, l), Val.Num(_, r)) if r != 0 => Val.Num(pos, l % r) + case _ => fallback } case BinaryOp.OP_< => tryFoldComparison(pos, lhs, BinaryOp.OP_<, rhs, fallback) @@ -543,3 +1107,14 @@ class StaticOptimizer( Val.bool(pos, if (negate) !result else result) } } + +object StaticOptimizer { + final case class ScopedVal(v: AnyRef, sc: Scope, idx: Int) + + final class Scope(val mappings: HashMap[String, ScopedVal], val size: Int) { + def get(s: String): ScopedVal = mappings.getOrElse(s, null) + def contains(s: String): Boolean = mappings.contains(s) + } + + val emptyScope: Scope = new Scope(HashMap.empty, 0) +} diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index fa1090e4..7485ff52 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -213,6 +213,7 @@ object Val { } def prettyName = "number" + private[sjsonnet] def rawDouble: Double = num override def asInt: Int = num.toInt def asPositiveInt: Int = { diff --git a/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden index 11f034e7..356e6fb7 100644 --- a/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/binaryNot2.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown unary operation: ~ string at [].(binaryNot2.jsonnet:1:1) diff --git a/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden index 50f106df..036db1b8 100644 --- a/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/bitwise_or10.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: string | number at [].(bitwise_or10.jsonnet:1:7) diff --git a/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden index cb927039..8b1f93a6 100644 --- a/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/number_divided_by_string.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: number / string at [].(number_divided_by_string.jsonnet:1:4) diff --git a/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden index 68d168bd..5175744d 100644 --- a/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/number_times_string.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: number * string at [].(number_times_string.jsonnet:1:4) diff --git a/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden index b6090100..c1baf4e5 100644 --- a/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_divided_by_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: string / number at [].(string_divided_by_number.jsonnet:1:7) diff --git a/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden index d724c6db..802a3733 100644 --- a/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_minus_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: string - number at [].(string_minus_number.jsonnet:1:5) diff --git a/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden index 6ba00d2d..6b05e85b 100644 --- a/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/string_times_number.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown binary operation: string * number at [].(string_times_number.jsonnet:1:5) diff --git a/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden index eb0b0518..0886fb00 100644 --- a/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/unary_minus4.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got string +sjsonnet.Error: Unknown unary operation: - string at [].(unary_minus4.jsonnet:1:1) diff --git a/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden b/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden index 5b965a44..bde9b3fe 100644 --- a/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden +++ b/sjsonnet/test/resources/go_test_suite/unary_object.jsonnet.golden @@ -1,3 +1,3 @@ -sjsonnet.Error: Expected Number, got object +sjsonnet.Error: Unknown unary operation: + object at [].(unary_object.jsonnet:1:1) diff --git a/sjsonnet/test/src/sjsonnet/AggressiveStaticOptimizationTests.scala b/sjsonnet/test/src/sjsonnet/AggressiveStaticOptimizationTests.scala index 6acc9365..c7826fd3 100644 --- a/sjsonnet/test/src/sjsonnet/AggressiveStaticOptimizationTests.scala +++ b/sjsonnet/test/src/sjsonnet/AggressiveStaticOptimizationTests.scala @@ -217,27 +217,52 @@ object AggressiveStaticOptimizationTests extends TestSuite { // Error cases: runtime errors must still be raised correctly // ------------------------------------------------------------------------- test("runtimeErrorsPreserved") { + def assertErrContainsBoth(input: String, expected: String): Unit = { + val err = evalErr(input) + assert(err.contains(expected)) + + val newEvaluatorErr = evalErr(input, useNewEvaluator = true) + assert(newEvaluatorErr.contains(expected)) + } + test("divisionByZeroNotFolded") { // Division by zero: the optimizer must NOT fold `1 / 0` into a value; // it should fall back to the runtime error path. - val err = evalErr("1 / 0") - assert(err.contains("sjsonnet.Error")) + assertErrContainsBoth("1 / 0", "sjsonnet.Error") } test("negativeShiftNotFolded") { // Negative shift amounts must not be constant-folded; runtime error expected. - val err = evalErr("1 << -1") - assert(err.contains("sjsonnet.Error")) + assertErrContainsBoth("1 << -1", "sjsonnet.Error") + } + test("moduloByZeroNotFolded") { + assertErrContainsBoth("1 % 0", "not a number") } test("andWithNonBoolRhsStillErrors") { // `true && "hello"` must still error: the optimizer only short-circuits when // rhs is a Val.Bool. If rhs is not a Bool, the BinaryOp is left intact and // the runtime type-check fires. - val err = evalErr(""" true && "hello" """) - assert(err.contains("binary operator &&")) + assertErrContainsBoth(""" true && "hello" """, "binary operator &&") } test("orWithNonBoolRhsStillErrors") { - val err = evalErr(""" false || "hello" """) - assert(err.contains("binary operator ||")) + assertErrContainsBoth(""" false || "hello" """, "binary operator ||") + } + test("constantBitwiseNaNStillErrors") { + assertErrContainsBoth("0 & (0 % 0)", "numeric value is not finite") + } + test("dynamicBitwiseNaNStillErrors") { + assertErrContainsBoth("local x = 0; 0 & (x % x)", "numeric value is not finite") + } + test("constantShiftNaNStillErrors") { + assertErrContainsBoth("1 << (0 % 0)", "numeric value is not finite") + } + test("rhsErrorsStillWinOverTypeMismatch") { + assertErrContainsBoth(""" "a" * error "boom" """, "boom") + } + test("binaryTypeErrorsStayOperatorSpecific") { + assertErrContainsBoth(""" "a" * 1 """, "Unknown binary operation: string * number") + } + test("unaryTypeErrorsStayOperatorSpecific") { + assertErrContainsBoth(""" +{} """, "Unknown unary operation: + object") } } From 52ee6da9cee2a9146fd2545cfe5822d7b39ec07d Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 8 Mar 2026 04:16:54 +0800 Subject: [PATCH 4/4] add tailrec profiling checkpoint Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- bench/AST_VISIT_COUNTS.md | 237 +++++++++++ bench/OPTIMIZATION_LOG.md | 36 ++ .../sjsonnet/bench/OptimizerBenchmark.scala | 12 +- .../sjsonnet/AstVisitCorpusRunner.scala | 399 ++++++++++++++++++ sjsonnet/src/sjsonnet/AstVisitProfiler.scala | 181 ++++++++ sjsonnet/src/sjsonnet/DebugStats.scala | 2 + sjsonnet/src/sjsonnet/Evaluator.scala | 84 ++-- sjsonnet/src/sjsonnet/Expr.scala | 32 +- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 123 ++++-- sjsonnet/src/sjsonnet/Val.scala | 18 +- .../src/sjsonnet/AstVisitProfilerTests.scala | 79 ++++ .../test/src/sjsonnet/EvaluatorTests.scala | 2 +- .../sjsonnet/TailCallOptimizationTests.scala | 152 +++++++ 13 files changed, 1275 insertions(+), 82 deletions(-) create mode 100644 bench/AST_VISIT_COUNTS.md create mode 100644 bench/OPTIMIZATION_LOG.md create mode 100644 sjsonnet/src-jvm-native/sjsonnet/AstVisitCorpusRunner.scala create mode 100644 sjsonnet/src/sjsonnet/AstVisitProfiler.scala create mode 100644 sjsonnet/test/src/sjsonnet/AstVisitProfilerTests.scala diff --git a/bench/AST_VISIT_COUNTS.md b/bench/AST_VISIT_COUNTS.md new file mode 100644 index 00000000..865895c3 --- /dev/null +++ b/bench/AST_VISIT_COUNTS.md @@ -0,0 +1,237 @@ +# AST Visit Counts + +Generated by `sjsonnet.AstVisitCorpusRunner` on 2026-03-08. + +## Corpus + +Primary corpus: success-only files from the JVM/native test corpus under `sjsonnet/test/resources/{test_suite,go_test_suite,new_test_suite}`. +Expected-error files were intentionally excluded from counting so the report reflects steady-state successful evaluation paths; JS-only files and the existing `go_test_suite/builtinBase64_string_high_codepoint.jsonnet` skip were also excluded. + +### Corpus summary + +| Suite | Success files run | Expected-error files skipped | Other skipped | +| --- | ---: | ---: | ---: | +| `test_suite` | 73 | 110 | 0 | +| `go_test_suite` | 489 | 221 | 1 | +| `new_test_suite` | 4 | 4 | 1 | +| **Total** | **566** | **335** | **2** | + +### Explicit skips + +- `test_suite`: none +- `go_test_suite`: `builtinBase64_string_high_codepoint.jsonnet (excluded to match FileTests corpus)` +- `new_test_suite`: `regex-js.jsonnet (JS-only corpus file)` + +### Cross-check + +- Every success-only file was executed with both evaluators and required to produce identical successful output. +- Aggregate counter arrays match between evaluators: `true`. +- Old evaluator elapsed time across the corpus: `1070 ms`. +- New evaluator elapsed time across the corpus: `268 ms`. + +## Old evaluator summary + +- Total `visitExpr` calls: `190469` +- Top old-dispatch arm: `ValidId (73011)` +- Top Expr tag: `ValidId (73011)` +- Top invalid-node case: `$ (0)` + +### Old evaluator dispatch-arm counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `ValidId` | 73011 | 38.33% | +| 2 | `BinaryOp` | 56856 | 29.85% | +| 3 | `Val` | 37473 | 19.67% | +| 4 | `Select` | 6689 | 3.51% | +| 5 | `Apply1` | 4107 | 2.16% | +| 6 | `Lookup` | 3422 | 1.80% | +| 7 | `IfElse` | 2095 | 1.10% | +| 8 | `ObjExtend` | 2029 | 1.07% | +| 9 | `And` | 830 | 0.44% | +| 10 | `ObjBody.MemberList` | 785 | 0.41% | +| 11 | `Function` | 734 | 0.39% | +| 12 | `ApplyBuiltin1` | 587 | 0.31% | +| 13 | `ApplyBuiltin2` | 508 | 0.27% | +| 14 | `Arr` | 468 | 0.25% | +| 15 | `LocalExpr` | 288 | 0.15% | +| 16 | `Apply2` | 131 | 0.07% | +| 17 | `ObjBody.ObjComp` | 62 | 0.03% | +| 18 | `Slice` | 57 | 0.03% | +| 19 | `Apply` | 54 | 0.03% | +| 20 | `SelectSuper` | 51 | 0.03% | +| 21 | `ApplyBuiltin3` | 46 | 0.02% | +| 22 | `ApplyBuiltin` | 34 | 0.02% | +| 23 | `Apply0` | 30 | 0.02% | +| 24 | `Comp` | 27 | 0.01% | +| 25 | `Import` | 26 | 0.01% | +| 26 | `UnaryOp` | 23 | 0.01% | +| 27 | `Apply3` | 15 | 0.01% | +| 28 | `InSuper` | 9 | 0.00% | +| 29 | `ImportStr` | 8 | 0.00% | +| 30 | `AssertExpr` | 6 | 0.00% | +| 31 | `ImportBin` | 3 | 0.00% | +| 32 | `ApplyBuiltin4` | 2 | 0.00% | +| 33 | `Or` | 2 | 0.00% | +| 34 | `LookupSuper` | 1 | 0.00% | +| 35 | `ApplyBuiltin0` | 0 | 0.00% | +| 36 | `Error` | 0 | 0.00% | +| 37 | `Invalid` | 0 | 0.00% | + +### Old evaluator normal Expr tag counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `ValidId` | 73011 | 38.33% | +| 2 | `BinaryOp` | 56856 | 29.85% | +| 3 | `Val.Literal` | 37473 | 19.67% | +| 4 | `Select` | 6689 | 3.51% | +| 5 | `Apply1` | 4107 | 2.16% | +| 6 | `Lookup` | 3422 | 1.80% | +| 7 | `IfElse` | 2095 | 1.10% | +| 8 | `ObjExtend` | 2029 | 1.07% | +| 9 | `And` | 830 | 0.44% | +| 10 | `ObjBody.MemberList` | 785 | 0.41% | +| 11 | `Function` | 734 | 0.39% | +| 12 | `ApplyBuiltin1` | 587 | 0.31% | +| 13 | `ApplyBuiltin2` | 508 | 0.27% | +| 14 | `Arr` | 468 | 0.25% | +| 15 | `LocalExpr` | 288 | 0.15% | +| 16 | `Apply2` | 131 | 0.07% | +| 17 | `ObjBody.ObjComp` | 62 | 0.03% | +| 18 | `Slice` | 57 | 0.03% | +| 19 | `Apply` | 54 | 0.03% | +| 20 | `SelectSuper` | 51 | 0.03% | +| 21 | `ApplyBuiltin3` | 46 | 0.02% | +| 22 | `ApplyBuiltin` | 34 | 0.02% | +| 23 | `Apply0` | 30 | 0.02% | +| 24 | `Comp` | 27 | 0.01% | +| 25 | `Import` | 26 | 0.01% | +| 26 | `UnaryOp` | 23 | 0.01% | +| 27 | `Apply3` | 15 | 0.01% | +| 28 | `InSuper` | 9 | 0.00% | +| 29 | `ImportStr` | 8 | 0.00% | +| 30 | `AssertExpr` | 6 | 0.00% | +| 31 | `ImportBin` | 3 | 0.00% | +| 32 | `ApplyBuiltin4` | 2 | 0.00% | +| 33 | `Or` | 2 | 0.00% | +| 34 | `LookupSuper` | 1 | 0.00% | +| 35 | `ApplyBuiltin0` | 0 | 0.00% | +| 36 | `Error` | 0 | 0.00% | +| 37 | `Val.Func` | 0 | 0.00% | + +### Old evaluator invalid-node counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `$` | 0 | 0.00% | +| 2 | `Id` | 0 | 0.00% | +| 3 | `Other` | 0 | 0.00% | +| 4 | `Self` | 0 | 0.00% | +| 5 | `Super` | 0 | 0.00% | + +## New evaluator summary + +- Total `visitExpr` calls: `190469` +- Top old-dispatch arm: `ValidId (73011)` +- Top Expr tag: `ValidId (73011)` +- Top invalid-node case: `$ (0)` + +### New evaluator dispatch-arm counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `ValidId` | 73011 | 38.33% | +| 2 | `BinaryOp` | 56856 | 29.85% | +| 3 | `Val` | 37473 | 19.67% | +| 4 | `Select` | 6689 | 3.51% | +| 5 | `Apply1` | 4107 | 2.16% | +| 6 | `Lookup` | 3422 | 1.80% | +| 7 | `IfElse` | 2095 | 1.10% | +| 8 | `ObjExtend` | 2029 | 1.07% | +| 9 | `And` | 830 | 0.44% | +| 10 | `ObjBody.MemberList` | 785 | 0.41% | +| 11 | `Function` | 734 | 0.39% | +| 12 | `ApplyBuiltin1` | 587 | 0.31% | +| 13 | `ApplyBuiltin2` | 508 | 0.27% | +| 14 | `Arr` | 468 | 0.25% | +| 15 | `LocalExpr` | 288 | 0.15% | +| 16 | `Apply2` | 131 | 0.07% | +| 17 | `ObjBody.ObjComp` | 62 | 0.03% | +| 18 | `Slice` | 57 | 0.03% | +| 19 | `Apply` | 54 | 0.03% | +| 20 | `SelectSuper` | 51 | 0.03% | +| 21 | `ApplyBuiltin3` | 46 | 0.02% | +| 22 | `ApplyBuiltin` | 34 | 0.02% | +| 23 | `Apply0` | 30 | 0.02% | +| 24 | `Comp` | 27 | 0.01% | +| 25 | `Import` | 26 | 0.01% | +| 26 | `UnaryOp` | 23 | 0.01% | +| 27 | `Apply3` | 15 | 0.01% | +| 28 | `InSuper` | 9 | 0.00% | +| 29 | `ImportStr` | 8 | 0.00% | +| 30 | `AssertExpr` | 6 | 0.00% | +| 31 | `ImportBin` | 3 | 0.00% | +| 32 | `ApplyBuiltin4` | 2 | 0.00% | +| 33 | `Or` | 2 | 0.00% | +| 34 | `LookupSuper` | 1 | 0.00% | +| 35 | `ApplyBuiltin0` | 0 | 0.00% | +| 36 | `Error` | 0 | 0.00% | +| 37 | `Invalid` | 0 | 0.00% | + +### New evaluator normal Expr tag counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `ValidId` | 73011 | 38.33% | +| 2 | `BinaryOp` | 56856 | 29.85% | +| 3 | `Val.Literal` | 37473 | 19.67% | +| 4 | `Select` | 6689 | 3.51% | +| 5 | `Apply1` | 4107 | 2.16% | +| 6 | `Lookup` | 3422 | 1.80% | +| 7 | `IfElse` | 2095 | 1.10% | +| 8 | `ObjExtend` | 2029 | 1.07% | +| 9 | `And` | 830 | 0.44% | +| 10 | `ObjBody.MemberList` | 785 | 0.41% | +| 11 | `Function` | 734 | 0.39% | +| 12 | `ApplyBuiltin1` | 587 | 0.31% | +| 13 | `ApplyBuiltin2` | 508 | 0.27% | +| 14 | `Arr` | 468 | 0.25% | +| 15 | `LocalExpr` | 288 | 0.15% | +| 16 | `Apply2` | 131 | 0.07% | +| 17 | `ObjBody.ObjComp` | 62 | 0.03% | +| 18 | `Slice` | 57 | 0.03% | +| 19 | `Apply` | 54 | 0.03% | +| 20 | `SelectSuper` | 51 | 0.03% | +| 21 | `ApplyBuiltin3` | 46 | 0.02% | +| 22 | `ApplyBuiltin` | 34 | 0.02% | +| 23 | `Apply0` | 30 | 0.02% | +| 24 | `Comp` | 27 | 0.01% | +| 25 | `Import` | 26 | 0.01% | +| 26 | `UnaryOp` | 23 | 0.01% | +| 27 | `Apply3` | 15 | 0.01% | +| 28 | `InSuper` | 9 | 0.00% | +| 29 | `ImportStr` | 8 | 0.00% | +| 30 | `AssertExpr` | 6 | 0.00% | +| 31 | `ImportBin` | 3 | 0.00% | +| 32 | `ApplyBuiltin4` | 2 | 0.00% | +| 33 | `Or` | 2 | 0.00% | +| 34 | `LookupSuper` | 1 | 0.00% | +| 35 | `ApplyBuiltin0` | 0 | 0.00% | +| 36 | `Error` | 0 | 0.00% | +| 37 | `Val.Func` | 0 | 0.00% | + +### New evaluator invalid-node counts + +| Rank | Name | Count | Share | +| ---: | --- | ---: | ---: | +| 1 | `$` | 0 | 0.00% | +| 2 | `Id` | 0 | 0.00% | +| 3 | `Other` | 0 | 0.00% | +| 4 | `Self` | 0 | 0.00% | +| 5 | `Super` | 0 | 0.00% | + +## Reproduction + +- Command: `./mill sjsonnet.jvm[3.3.7].runMain sjsonnet.AstVisitCorpusRunner /Users/hepin/IdeaProjects/sjsonnet/bench/AST_VISIT_COUNTS.md` +- Notes: this todo only adds instrumentation, corpus execution, and markdown output. It intentionally does **not** reorder evaluator dispatch or renumber `Expr.tag` values yet. diff --git a/bench/OPTIMIZATION_LOG.md b/bench/OPTIMIZATION_LOG.md new file mode 100644 index 00000000..eccc0946 --- /dev/null +++ b/bench/OPTIMIZATION_LOG.md @@ -0,0 +1,36 @@ +# Optimization Log + +## Wave 1: direct self-tailrec detection +- Scope: mark direct self-tail-calls in `StaticOptimizer` without requiring explicit `tailstrict`. +- Outcome: kept. +- Validation: + - `./mill 'sjsonnet.jvm[3.3.7]'.test.testOnly sjsonnet.TailCallOptimizationTests sjsonnet.EvaluatorTests sjsonnet.AstVisitProfilerTests` + - `./mill 'sjsonnet.jvm[3.3.7]'.test` + - `./mill bench.runJmh -i 1 -wi 1 -f 1 'sjsonnet.bench.OptimizerBenchmark.main'` +- Notes: + - semantic feature is correct and verified + - optimizer benchmark stayed effectively flat (`~0.552-0.564 ms/op` across short JMH runs) + +## Wave 2: AST visit profiler + corpus runner +- Scope: instrument both evaluators, run the success-path corpus, and persist measured counts. +- Outcome: kept. +- Artifacts: + - `bench/AST_VISIT_COUNTS.md` + - `sjsonnet/src/sjsonnet/AstVisitProfiler.scala` + - `sjsonnet/src-jvm-native/sjsonnet/AstVisitCorpusRunner.scala` +- Key corpus findings: + - hottest old-dispatch arms: `ValidId`, `BinaryOp`, `Val`, `Select` + - hottest normal tags: `ValidId`, `BinaryOp`, `Val.Literal`, `Select` + +## Wave 3: data-driven dispatch/tag reordering +- Scope: reorder old `Evaluator.visitExpr` pattern matches and renumber `Expr.tag` values from the measured corpus. +- Outcome: reverted. +- Reason: + - benchmark results regressed despite the frequency-guided order + - `Select`/`Self`/`$`/`Super` share low-tag fast paths in `StaticOptimizer`, which constrains safe tag reshuffling + - the naive reorder appears to fight Scala/JVM codegen rather than help it +- Rejected measurements: + - `MainBenchmark.main`: `2.788 ms/op -> 3.222 ms/op` + - `OptimizerBenchmark.main`: `0.555 ms/op -> 0.569 ms/op` + - corpus runner new evaluator time: `235 ms -> 261 ms` +- Resolution: restore the original dispatch/tag order and keep the profiler data for later, more targeted optimizations. diff --git a/bench/src/sjsonnet/bench/OptimizerBenchmark.scala b/bench/src/sjsonnet/bench/OptimizerBenchmark.scala index 65668a22..c2368281 100644 --- a/bench/src/sjsonnet/bench/OptimizerBenchmark.scala +++ b/bench/src/sjsonnet/bench/OptimizerBenchmark.scala @@ -123,20 +123,20 @@ class OptimizerBenchmark { } private def rec(e: Expr): Unit = e match { - case Expr.Select(_, x, _) => transform(x) - case Expr.Apply(_, x, y, _, _) => + case Expr.Select(_, x, _) => transform(x) + case Expr.Apply(_, x, y, _, _, _) => transform(x) transformArr(y) - case Expr.Apply0(_, x, _) => + case Expr.Apply0(_, x, _, _) => transform(x) - case Expr.Apply1(_, x, y, _) => + case Expr.Apply1(_, x, y, _, _) => transform(x) transform(y) - case Expr.Apply2(_, x, y, z, _) => + case Expr.Apply2(_, x, y, z, _, _) => transform(x) transform(y) transform(z) - case Expr.Apply3(_, x, y, z, a, _) => + case Expr.Apply3(_, x, y, z, a, _, _) => transform(x) transform(y) transform(z) diff --git a/sjsonnet/src-jvm-native/sjsonnet/AstVisitCorpusRunner.scala b/sjsonnet/src-jvm-native/sjsonnet/AstVisitCorpusRunner.scala new file mode 100644 index 00000000..a3883b7c --- /dev/null +++ b/sjsonnet/src-jvm-native/sjsonnet/AstVisitCorpusRunner.scala @@ -0,0 +1,399 @@ +package sjsonnet + +import java.time.LocalDate + +import sjsonnet.stdlib.NativeRegex +import ujson.Value + +object AstVisitCorpusRunner { + private val workspaceRoot: os.Path = + sys.env.get("MILL_WORKSPACE_ROOT").map(os.Path(_)).getOrElse(os.pwd) + private val testSuiteRoot: os.Path = workspaceRoot / "sjsonnet" / "test" / "resources" + private val defaultOutputPath: os.Path = workspaceRoot / "bench" / "AST_VISIT_COUNTS.md" + + private val extVars = Map( + "var1" -> "\"test\"", + "var2" -> """local f(a, b) = {[a]: b, "y": 2}; f("x", 1)""", + "stringVar" -> "\"2 + 2\"", + "codeVar" -> "3 + 3", + "errorVar" -> "error 'xxx'", + "staticErrorVar" -> ")", + "UndeclaredX" -> "x", + "selfRecursiveVar" -> """[42, std.extVar("selfRecursiveVar")[0] + 1]""", + "mutuallyRecursiveVar1" -> """[42, std.extVar("mutuallyRecursiveVar2")[0] + 1]""", + "mutuallyRecursiveVar2" -> """[42, std.extVar("mutuallyRecursiveVar1")[0] + 1]""" + ) + private val tlaVars = Map( + "var1" -> "\"test\"", + "var2" -> """{"x": 1, "y": 2}""" + ) + + private val stdModule: Val.Obj = new sjsonnet.stdlib.StdLibModule( + nativeFunctions = Map( + "jsonToString" -> new Val.Builtin1("jsonToString", "x") { + override def evalRhs(arg1: Eval, ev: EvalScope, pos: Position): Val = { + Val.Str( + pos, + Materializer + .apply0( + arg1.value, + MaterializeJsonRenderer(indent = -1, newline = "", keyValueSeparator = ":") + )(ev) + .toString + ) + } + }, + "nativeError" -> new Val.Builtin0("nativeError") { + override def evalRhs(ev: EvalScope, pos: Position): Val = + Error.fail("native function error") + }, + "nativePanic" -> new Val.Builtin0("nativePanic") { + override def evalRhs(ev: EvalScope, pos: Position): Val = + throw new RuntimeException("native function panic") + } + ) ++ NativeRegex.functions + ).module + + private val goTestDataSkippedTests: Set[String] = Set( + "builtinBase64_string_high_codepoint.jsonnet" + ) + + private val successGoldenPrefixes = Seq( + "sjsonnet.Error", + "sjsonnet.ParseError", + "sjsonnet.StaticError", + "RUNTIME ERROR" + ) + + final case class CorpusFile(suite: String, path: os.Path) + final case class SuiteSummary( + suite: String, + successFiles: Int, + expectedErrorFiles: Int, + skippedFiles: Seq[String]) + final case class CountRow(name: String, count: Long) + final case class CorpusRunSummary( + oldSnapshot: AstVisitProfileSnapshot, + newSnapshot: AstVisitProfileSnapshot, + suiteSummaries: Seq[SuiteSummary], + totalSuccessFiles: Int, + totalExpectedErrorFiles: Int, + totalSkippedFiles: Int, + oldDurationMs: Long, + newDurationMs: Long, + countsMatch: Boolean) + + def main(args: Array[String]): Unit = { + val outputPath = args.toList match { + case Nil => defaultOutputPath + case path :: Nil => resolveOutputPath(path) + case _ => + throw new IllegalArgumentException( + "Usage: AstVisitCorpusRunner [output-markdown-path]" + ) + } + + val (corpus, suiteSummaries) = discoverCorpus() + val summary = executeCorpus(corpus, suiteSummaries) + val markdown = renderMarkdown(summary, outputPath) + os.write.over(outputPath, markdown) + + println( + s"AST visit corpus run complete: ${summary.totalSuccessFiles} success files, ${summary.totalExpectedErrorFiles} expected-error files skipped, ${summary.totalSkippedFiles} additional skips." + ) + println( + s"Old evaluator total visits: ${summary.oldSnapshot.totalVisits}, new evaluator total visits: ${summary.newSnapshot.totalVisits}, counts match: ${summary.countsMatch}" + ) + println(s"Wrote $outputPath") + } + + private def resolveOutputPath(path: String): os.Path = + os.Path(path, workspaceRoot) + + private def discoverCorpus(): (Seq[CorpusFile], Seq[SuiteSummary]) = { + val suiteNames = Seq("test_suite", "go_test_suite", "new_test_suite") + val summaries = scala.collection.mutable.ArrayBuffer.empty[SuiteSummary] + val files = scala.collection.mutable.ArrayBuffer.empty[CorpusFile] + + suiteNames.foreach { suite => + val dir = testSuiteRoot / suite + val skipped = scala.collection.mutable.ArrayBuffer.empty[String] + var successFiles = 0 + var expectedErrorFiles = 0 + + os.list(dir) + .filter(p => p.ext == "jsonnet") + .sortBy(_.last) + .foreach { file => + val skip = skipReason(suite, file) + if (skip != null) skipped += s"${file.last} (${skip})" + else if (isExpectedErrorFile(file)) expectedErrorFiles += 1 + else { + successFiles += 1 + files += CorpusFile(suite, file) + } + } + + summaries += SuiteSummary(suite, successFiles, expectedErrorFiles, skipped.toSeq) + } + + (files.toSeq, summaries.toSeq) + } + + private def skipReason(suite: String, file: os.Path): String = { + if (suite == "go_test_suite" && goTestDataSkippedTests.contains(file.last)) { + "excluded to match FileTests corpus" + } else if (suite == "new_test_suite" && file.last.contains("-js")) { + "JS-only corpus file" + } else null + } + + private def isExpectedErrorFile(file: os.Path): Boolean = { + val goldenContent = os.read(os.Path(file.toString + ".golden")).stripLineEnd + goldenContent.contains("java.lang.StackOverflowError") || + successGoldenPrefixes.exists(goldenContent.startsWith) + } + + private def executeCorpus( + corpus: Seq[CorpusFile], + suiteSummaries: Seq[SuiteSummary]): CorpusRunSummary = { + val oldProfiler = new AstVisitProfiler + val newProfiler = new AstVisitProfiler + val oldParseCache = new DefaultParseCache + val newParseCache = new DefaultParseCache + val oldStats = new DebugStats + val newStats = new DebugStats + oldStats.astVisitProfiler = oldProfiler + newStats.astVisitProfiler = newProfiler + + var oldDurationNs = 0L + var newDurationNs = 0L + + corpus.foreach { file => + println(s"Running ${file.suite}/${file.path.last}") + + val oldStart = System.nanoTime() + val oldResult = eval(file, useNewEvaluator = false, oldParseCache, oldStats) + oldDurationNs += System.nanoTime() - oldStart + + val newStart = System.nanoTime() + val newResult = eval(file, useNewEvaluator = true, newParseCache, newStats) + newDurationNs += System.nanoTime() - newStart + + if (oldResult != newResult) { + throw new IllegalStateException( + s"Evaluator mismatch for ${file.suite}/${file.path.last}: old=${preview(oldResult)}, new=${preview(newResult)}" + ) + } + } + + val oldSnapshot = oldProfiler.snapshot() + val newSnapshot = newProfiler.snapshot() + CorpusRunSummary( + oldSnapshot, + newSnapshot, + suiteSummaries, + totalSuccessFiles = corpus.length, + totalExpectedErrorFiles = suiteSummaries.map(_.expectedErrorFiles).sum, + totalSkippedFiles = suiteSummaries.map(_.skippedFiles.length).sum, + oldDurationMs = oldDurationNs / 1000000L, + newDurationMs = newDurationNs / 1000000L, + countsMatch = snapshotsEqual(oldSnapshot, newSnapshot) + ) + } + + private def eval( + corpusFile: CorpusFile, + useNewEvaluator: Boolean, + parseCache: ParseCache, + debugStats: DebugStats): Either[String, Value] = { + val interp = new Interpreter( + key => extVars.get(key).map(ExternalVariable.code), + key => tlaVars.get(key).map(ExternalVariable.code), + OsPath(testSuiteRoot / corpusFile.suite), + importer = new sjsonnet.SjsonnetMainBase.SimpleImporter(Array.empty[Path].toIndexedSeq), + parseCache = parseCache, + settings = new Settings(useNewEvaluator = useNewEvaluator), + storePos = null, + logger = null, + std = stdModule, + variableResolver = _ => None, + debugStats = debugStats + ) + interp.interpret(os.read(corpusFile.path), OsPath(corpusFile.path)) + } + + private def snapshotsEqual(a: AstVisitProfileSnapshot, b: AstVisitProfileSnapshot): Boolean = { + java.util.Arrays.equals(a.oldDispatchArmCounts, b.oldDispatchArmCounts) && + java.util.Arrays.equals(a.normalTagCounts, b.normalTagCounts) && + java.util.Arrays.equals(a.invalidNodeCounts, b.invalidNodeCounts) + } + + private def preview(result: Either[String, Value]): String = result match { + case Left(err) => err.linesIterator.take(1).mkString + case Right(v) => v.render().take(120) + } + + private def renderMarkdown(summary: CorpusRunSummary, outputPath: os.Path): String = { + val oldDispatchTop = + topRows(summary.oldSnapshot.oldDispatchArmCounts, AstVisitProfiler.DispatchArmNames) + val oldTagTop = topRows(summary.oldSnapshot.normalTagCounts, AstVisitProfiler.NormalTagNames, 1) + val oldInvalidTop = + topRows(summary.oldSnapshot.invalidNodeCounts, AstVisitProfiler.InvalidNodeNames) + val newDispatchTop = + topRows(summary.newSnapshot.oldDispatchArmCounts, AstVisitProfiler.DispatchArmNames) + val newTagTop = topRows(summary.newSnapshot.normalTagCounts, AstVisitProfiler.NormalTagNames, 1) + val newInvalidTop = + topRows(summary.newSnapshot.invalidNodeCounts, AstVisitProfiler.InvalidNodeNames) + + val lines = scala.collection.mutable.ArrayBuffer.empty[String] + lines += "# AST Visit Counts" + lines += "" + lines += s"Generated by `sjsonnet.AstVisitCorpusRunner` on ${LocalDate.now()}." + lines += "" + lines += "## Corpus" + lines += "" + lines += "Primary corpus: success-only files from the JVM/native test corpus under `sjsonnet/test/resources/{test_suite,go_test_suite,new_test_suite}`." + lines += "Expected-error files were intentionally excluded from counting so the report reflects steady-state successful evaluation paths; JS-only files and the existing `go_test_suite/builtinBase64_string_high_codepoint.jsonnet` skip were also excluded." + lines += "" + lines += "### Corpus summary" + lines += "" + lines ++= Seq( + "| Suite | Success files run | Expected-error files skipped | Other skipped |", + "| --- | ---: | ---: | ---: |" + ) + summary.suiteSummaries.foreach { suiteSummary => + lines += s"| `${suiteSummary.suite}` | ${suiteSummary.successFiles} | ${suiteSummary.expectedErrorFiles} | ${suiteSummary.skippedFiles.length} |" + } + lines += s"| **Total** | **${summary.totalSuccessFiles}** | **${summary.totalExpectedErrorFiles}** | **${summary.totalSkippedFiles}** |" + lines += "" + lines += "### Explicit skips" + lines += "" + summary.suiteSummaries.foreach { suiteSummary => + lines += s"- `${suiteSummary.suite}`: ${formatSkipped(suiteSummary.skippedFiles)}" + } + lines += "" + lines += "### Cross-check" + lines += "" + lines += s"- Every success-only file was executed with both evaluators and required to produce identical successful output." + lines += s"- Aggregate counter arrays match between evaluators: `${summary.countsMatch}`." + lines += s"- Old evaluator elapsed time across the corpus: `${summary.oldDurationMs} ms`." + lines += s"- New evaluator elapsed time across the corpus: `${summary.newDurationMs} ms`." + lines += "" + lines += "## Old evaluator summary" + lines += "" + lines += s"- Total `visitExpr` calls: `${summary.oldSnapshot.totalVisits}`" + lines += s"- Top old-dispatch arm: `${formatTop(oldDispatchTop.headOption)}`" + lines += s"- Top Expr tag: `${formatTop(oldTagTop.headOption)}`" + lines += s"- Top invalid-node case: `${formatTop(oldInvalidTop.headOption)}`" + lines += "" + lines += renderSection( + "Old evaluator dispatch-arm counts", + summary.oldSnapshot.oldDispatchArmCounts, + AstVisitProfiler.DispatchArmNames, + 0, + summary.oldSnapshot.totalVisits + ) + lines += "" + lines += renderSection( + "Old evaluator normal Expr tag counts", + summary.oldSnapshot.normalTagCounts, + AstVisitProfiler.NormalTagNames, + 1, + summary.oldSnapshot.totalVisits + ) + lines += "" + lines += renderSection( + "Old evaluator invalid-node counts", + summary.oldSnapshot.invalidNodeCounts, + AstVisitProfiler.InvalidNodeNames, + 0, + summary.oldSnapshot.totalVisits + ) + lines += "" + lines += "## New evaluator summary" + lines += "" + lines += s"- Total `visitExpr` calls: `${summary.newSnapshot.totalVisits}`" + lines += s"- Top old-dispatch arm: `${formatTop(newDispatchTop.headOption)}`" + lines += s"- Top Expr tag: `${formatTop(newTagTop.headOption)}`" + lines += s"- Top invalid-node case: `${formatTop(newInvalidTop.headOption)}`" + lines += "" + lines += renderSection( + "New evaluator dispatch-arm counts", + summary.newSnapshot.oldDispatchArmCounts, + AstVisitProfiler.DispatchArmNames, + 0, + summary.newSnapshot.totalVisits + ) + lines += "" + lines += renderSection( + "New evaluator normal Expr tag counts", + summary.newSnapshot.normalTagCounts, + AstVisitProfiler.NormalTagNames, + 1, + summary.newSnapshot.totalVisits + ) + lines += "" + lines += renderSection( + "New evaluator invalid-node counts", + summary.newSnapshot.invalidNodeCounts, + AstVisitProfiler.InvalidNodeNames, + 0, + summary.newSnapshot.totalVisits + ) + lines += "" + lines += "## Reproduction" + lines += "" + lines += s"- Command: `./mill sjsonnet.jvm[3.3.7].runMain sjsonnet.AstVisitCorpusRunner ${outputPath}`" + lines += "- Notes: this todo only adds instrumentation, corpus execution, and markdown output. It intentionally does **not** reorder evaluator dispatch or renumber `Expr.tag` values yet." + + lines.mkString("\n") + "\n" + } + + private def renderSection( + title: String, + counts: Array[Long], + names: Array[String], + startIndex: Int, + totalVisits: Long): String = { + val rows = topRows(counts, names, startIndex, counts.length) + val lines = scala.collection.mutable.ArrayBuffer.empty[String] + lines += s"### ${title}" + lines += "" + lines += "| Rank | Name | Count | Share |" + lines += "| ---: | --- | ---: | ---: |" + rows.zipWithIndex.foreach { case (row, idx) => + lines += s"| ${idx + 1} | `${row.name}` | ${row.count} | ${formatShare(row.count, totalVisits)} |" + } + lines.mkString("\n") + } + + private def topRows( + counts: Array[Long], + names: Array[String], + startIndex: Int = 0, + limit: Int = 10): Seq[CountRow] = { + counts.indices + .drop(startIndex) + .filter(i => i < names.length) + .map(i => CountRow(names(i), counts(i))) + .sortBy(row => (-row.count, row.name)) + .take(limit) + .toSeq + } + + private def formatShare(count: Long, total: Long): String = { + if (total == 0) "0.00%" + else f"${count.toDouble * 100.0 / total}%.2f%%" + } + + private def formatSkipped(skipped: Seq[String]): String = { + if (skipped.isEmpty) "none" + else skipped.map(s => s"`$s`").mkString(", ") + } + + private def formatTop(row: Option[CountRow]): String = row match { + case Some(value) => s"${value.name} (${value.count})" + case None => "none" + } +} diff --git a/sjsonnet/src/sjsonnet/AstVisitProfiler.scala b/sjsonnet/src/sjsonnet/AstVisitProfiler.scala new file mode 100644 index 00000000..b8afab67 --- /dev/null +++ b/sjsonnet/src/sjsonnet/AstVisitProfiler.scala @@ -0,0 +1,181 @@ +package sjsonnet + +import scala.annotation.switch + +final class AstVisitProfiler { + val oldDispatchArmCounts = new Array[Long](AstVisitProfiler.DispatchArmNames.length) + val normalTagCounts = new Array[Long](ExprTags.Error + 1) + val invalidNodeCounts = new Array[Long](AstVisitProfiler.InvalidNodeNames.length) + + @inline final def countVisit(e: Expr): Unit = { + val tag = e.tag + if (e.isInvalidVisitTag) { + oldDispatchArmCounts(AstVisitProfiler.DispatchArm.Invalid) += 1 + invalidNodeCounts(AstVisitProfiler.invalidNodeIndex(tag)) += 1 + } else { + oldDispatchArmCounts(AstVisitProfiler.dispatchArmIndex(tag)) += 1 + normalTagCounts(tag) += 1 + } + } + + def snapshot(): AstVisitProfileSnapshot = + AstVisitProfileSnapshot( + oldDispatchArmCounts.clone(), + normalTagCounts.clone(), + invalidNodeCounts.clone() + ) +} + +final case class AstVisitProfileSnapshot( + oldDispatchArmCounts: Array[Long], + normalTagCounts: Array[Long], + invalidNodeCounts: Array[Long]) { + def totalVisits: Long = oldDispatchArmCounts.foldLeft(0L)(_ + _) +} + +object AstVisitProfiler { + object DispatchArm { + final val ValidId = 0 + final val BinaryOp = 1 + final val Select = 2 + final val Val = 3 + final val ApplyBuiltin0 = 4 + final val ApplyBuiltin1 = 5 + final val ApplyBuiltin2 = 6 + final val ApplyBuiltin3 = 7 + final val ApplyBuiltin4 = 8 + final val And = 9 + final val Or = 10 + final val UnaryOp = 11 + final val Apply1 = 12 + final val Lookup = 13 + final val Function = 14 + final val LocalExpr = 15 + final val Apply = 16 + final val IfElse = 17 + final val Apply3 = 18 + final val MemberList = 19 + final val Apply2 = 20 + final val AssertExpr = 21 + final val ApplyBuiltin = 22 + final val Comp = 23 + final val Arr = 24 + final val SelectSuper = 25 + final val LookupSuper = 26 + final val InSuper = 27 + final val ObjExtend = 28 + final val ObjComp = 29 + final val Slice = 30 + final val Import = 31 + final val Apply0 = 32 + final val ImportStr = 33 + final val ImportBin = 34 + final val Error = 35 + final val Invalid = 36 + } + + val DispatchArmNames: Array[String] = Array( + "ValidId", + "BinaryOp", + "Select", + "Val", + "ApplyBuiltin0", + "ApplyBuiltin1", + "ApplyBuiltin2", + "ApplyBuiltin3", + "ApplyBuiltin4", + "And", + "Or", + "UnaryOp", + "Apply1", + "Lookup", + "Function", + "LocalExpr", + "Apply", + "IfElse", + "Apply3", + "ObjBody.MemberList", + "Apply2", + "AssertExpr", + "ApplyBuiltin", + "Comp", + "Arr", + "SelectSuper", + "LookupSuper", + "InSuper", + "ObjExtend", + "ObjBody.ObjComp", + "Slice", + "Import", + "Apply0", + "ImportStr", + "ImportBin", + "Error", + "Invalid" + ) + + val NormalTagNames: Array[String] = Array( + "", + "ValidId", + "BinaryOp", + "Select", + "Val.Literal", + "Val.Func", + "ApplyBuiltin0", + "ApplyBuiltin1", + "ApplyBuiltin2", + "ApplyBuiltin3", + "ApplyBuiltin4", + "And", + "Or", + "UnaryOp", + "Apply1", + "Lookup", + "Function", + "LocalExpr", + "Apply", + "IfElse", + "Apply3", + "ObjBody.MemberList", + "Apply2", + "AssertExpr", + "ApplyBuiltin", + "Comp", + "Arr", + "SelectSuper", + "LookupSuper", + "InSuper", + "ObjExtend", + "ObjBody.ObjComp", + "Slice", + "Import", + "Apply0", + "ImportStr", + "ImportBin", + "Error" + ) + + @inline private[sjsonnet] final def dispatchArmIndex(tag: Int): Int = { + if (tag <= ExprTags.Select) tag - 1 + else if (tag <= ExprTags.`Val.Func`) DispatchArm.Val + else tag - 2 + } + + object InvalidNode { + final val Id = 0 + final val Self = 1 + final val Dollar = 2 + final val Super = 3 + final val Other = 4 + } + + val InvalidNodeNames: Array[String] = Array("Id", "Self", "$", "Super", "Other") + + @inline private[sjsonnet] final def invalidNodeIndex(tag: Int): Int = (tag: @switch) match { + case ExprTags.Id => InvalidNode.Id + case ExprTags.Self => InvalidNode.Self + case ExprTags.`$` => InvalidNode.Dollar + case ExprTags.Super => InvalidNode.Super + case _ => InvalidNode.Other + } +} diff --git a/sjsonnet/src/sjsonnet/DebugStats.scala b/sjsonnet/src/sjsonnet/DebugStats.scala index cadbd7e2..84bd63b2 100644 --- a/sjsonnet/src/sjsonnet/DebugStats.scala +++ b/sjsonnet/src/sjsonnet/DebugStats.scala @@ -9,6 +9,8 @@ package sjsonnet */ final class DebugStats { + var astVisitProfiler: AstVisitProfiler = _ + // -- Lazy -- var lazyCreated: Long = 0 diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index d4ee0271..c11fc489 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -32,6 +32,8 @@ class Evaluator( private[this] var stackDepth: Int = 0 private[this] val maxStack: Int = settings.maxStack + private[sjsonnet] val astVisitProfiler: AstVisitProfiler = + if (debugStats == null) null else debugStats.astVisitProfiler private[sjsonnet] var profiler: Profiler = _ @inline private[sjsonnet] final def checkStackDepth(pos: Position): Unit = { @@ -63,6 +65,8 @@ class Evaluator( collection.mutable.HashMap.empty[Path, Val] override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { + val av = astVisitProfiler + if (av != null) av.countVisit(e) val p = profiler val saved: (AnyRef, Int) = if (p != null) p.enter(e) else null try { @@ -1028,8 +1032,9 @@ class Evaluator( } } visitExprWithTailCallSupport(e.returned) - // Tail-position tailstrict calls: match TailstrictableExpr to unify the tailstrict guard, - // then dispatch by concrete type. + // Tail-position calls eligible for TCO: explicit `tailstrict` calls and optimizer-marked direct + // self-tail-calls. The latter preserve default Jsonnet laziness by carrying lazy args through + // the TailCall sentinel. // // - Apply* (user function calls): construct a TailCall sentinel that the caller's // TailCall.resolve loop will resolve iteratively, avoiding JVM stack growth for @@ -1038,40 +1043,47 @@ class Evaluator( // visitApplyBuiltin*. Those methods already wrap their result in TailCall.resolve() when // tailstrict=true, resolving any TailCall that a user-defined callback (e.g. the function // argument to std.makeArray or std.sort) may have returned. + case e: Apply if e.tailstrict || e.tailrec => + try { + val func = visitExpr(e.value).cast[Val.Func] + val mode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + val args = + if (e.tailstrict) e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]] + else e.args.map(visitAsLazy(_)) + new TailCall(func, args, e.namedNames, e, mode) + } catch Error.withStackFrame(e) + case e: Apply0 if e.tailstrict || e.tailrec => + try { + val func = visitExpr(e.value).cast[Val.Func] + val mode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + new TailCall(func, Evaluator.emptyLazyArray, null, e, mode) + } catch Error.withStackFrame(e) + case e: Apply1 if e.tailstrict || e.tailrec => + try { + val func = visitExpr(e.value).cast[Val.Func] + val mode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + val arg1: Eval = if (e.tailstrict) visitExpr(e.a1) else visitAsLazy(e.a1) + new TailCall(func, Array[Eval](arg1), null, e, mode) + } catch Error.withStackFrame(e) + case e: Apply2 if e.tailstrict || e.tailrec => + try { + val func = visitExpr(e.value).cast[Val.Func] + val mode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + val arg1: Eval = if (e.tailstrict) visitExpr(e.a1) else visitAsLazy(e.a1) + val arg2: Eval = if (e.tailstrict) visitExpr(e.a2) else visitAsLazy(e.a2) + new TailCall(func, Array[Eval](arg1, arg2), null, e, mode) + } catch Error.withStackFrame(e) + case e: Apply3 if e.tailstrict || e.tailrec => + try { + val func = visitExpr(e.value).cast[Val.Func] + val mode = if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + val arg1: Eval = if (e.tailstrict) visitExpr(e.a1) else visitAsLazy(e.a1) + val arg2: Eval = if (e.tailstrict) visitExpr(e.a2) else visitAsLazy(e.a2) + val arg3: Eval = if (e.tailstrict) visitExpr(e.a3) else visitAsLazy(e.a3) + new TailCall(func, Array[Eval](arg1, arg2, arg3), null, e, mode) + } catch Error.withStackFrame(e) case e: TailstrictableExpr if e.tailstrict => - e match { - case e: Apply => - try { - val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]], e.namedNames, e) - } catch Error.withStackFrame(e) - case e: Apply0 => - try { - val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Evaluator.emptyLazyArray, null, e) - } catch Error.withStackFrame(e) - case e: Apply1 => - try { - val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Array[Eval](visitExpr(e.a1)), null, e) - } catch Error.withStackFrame(e) - case e: Apply2 => - try { - val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Array[Eval](visitExpr(e.a1), visitExpr(e.a2)), null, e) - } catch Error.withStackFrame(e) - case e: Apply3 => - try { - val func = visitExpr(e.value).cast[Val.Func] - new TailCall( - func, - Array[Eval](visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3)), - null, - e - ) - } catch Error.withStackFrame(e) - case _ => visitExpr(e) - } + visitExpr(e) case _ => visitExpr(e) } @@ -1372,6 +1384,8 @@ class NewEvaluator( extends Evaluator(r, e, w, s, wa, ds) { override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { + val av = astVisitProfiler + if (av != null) av.countVisit(e) (e.tag: @switch) match { case ExprTags.ValidId => visitValidId(e.asInstanceOf[ValidId]) case ExprTags.BinaryOp => visitBinaryOp(e.asInstanceOf[BinaryOp]) diff --git a/sjsonnet/src/sjsonnet/Expr.scala b/sjsonnet/src/sjsonnet/Expr.scala index 16d5da0d..af0753bb 100644 --- a/sjsonnet/src/sjsonnet/Expr.scala +++ b/sjsonnet/src/sjsonnet/Expr.scala @@ -24,6 +24,7 @@ trait Expr { */ var pos: Position private[sjsonnet] def tag: Byte = ExprTags.UNTAGGED + private[sjsonnet] def isInvalidVisitTag: Boolean = false /** The name of this expression type to be shown in error messages */ def exprErrorString: String = { @@ -72,16 +73,20 @@ object Expr { final case class Self(var pos: Position) extends Expr { final override private[sjsonnet] def tag = ExprTags.Self + final override private[sjsonnet] def isInvalidVisitTag = true } final case class Super(var pos: Position) extends Expr { final override private[sjsonnet] def tag = ExprTags.Super + final override private[sjsonnet] def isInvalidVisitTag = true } final case class $(var pos: Position) extends Expr { final override private[sjsonnet] def tag = ExprTags.`$` + final override private[sjsonnet] def isInvalidVisitTag = true } final case class Id(var pos: Position, name: String) extends Expr { final override private[sjsonnet] def tag = ExprTags.Id + final override private[sjsonnet] def isInvalidVisitTag = true override def exprErrorString: String = s"${super.exprErrorString} $name" } @@ -231,22 +236,38 @@ object Expr { value: Expr, args: Array[Expr], namedNames: Array[String], - tailstrict: Boolean) + tailstrict: Boolean, + tailrec: Boolean = false) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply override def exprErrorString: String = Expr.callTargetName(value) } - final case class Apply0(var pos: Position, value: Expr, tailstrict: Boolean) + final case class Apply0( + var pos: Position, + value: Expr, + tailstrict: Boolean, + tailrec: Boolean = false) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply0 override def exprErrorString: String = Expr.callTargetName(value) } - final case class Apply1(var pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) + final case class Apply1( + var pos: Position, + value: Expr, + a1: Expr, + tailstrict: Boolean, + tailrec: Boolean = false) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply1 override def exprErrorString: String = Expr.callTargetName(value) } - final case class Apply2(var pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) + final case class Apply2( + var pos: Position, + value: Expr, + a1: Expr, + a2: Expr, + tailstrict: Boolean, + tailrec: Boolean = false) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply2 override def exprErrorString: String = Expr.callTargetName(value) @@ -257,7 +278,8 @@ object Expr { a1: Expr, a2: Expr, a3: Expr, - tailstrict: Boolean) + tailstrict: Boolean, + tailrec: Boolean = false) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply3 override def exprErrorString: String = Expr.callTargetName(value) diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index b0413213..66cf1b4b 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -79,8 +79,7 @@ class StaticOptimizer( else ObjBody.MemberList(memberList.pos, binds2, fields3, asserts2) case ExprTags.Function => - val function = e.asInstanceOf[Function] - nestedNames(function.params.names)(rec(e)) + transformFunction(e.asInstanceOf[Function], -1) case ExprTags.`ObjBody.ObjComp` => val objComp = e.asInstanceOf[ObjBody.ObjComp] @@ -326,37 +325,37 @@ class StaticOptimizer( if (x2 eq x) expr else Select(pos, x2, name) - case Apply(pos, x, y, namedNames, tailstrict) => + case Apply(pos, x, y, namedNames, tailstrict, tailrec) => val x2 = transform(x) val y2 = transformArr(y) if ((x2 eq x) && (y2 eq y)) expr - else Apply(pos, x2, y2, namedNames, tailstrict) + else Apply(pos, x2, y2, namedNames, tailstrict, tailrec) - case Apply0(pos, x, tailstrict) => + case Apply0(pos, x, tailstrict, tailrec) => val x2 = transform(x) if (x2 eq x) expr - else Apply0(pos, x2, tailstrict) + else Apply0(pos, x2, tailstrict, tailrec) - case Apply1(pos, x, y, tailstrict) => + case Apply1(pos, x, y, tailstrict, tailrec) => val x2 = transform(x) val y2 = transform(y) if ((x2 eq x) && (y2 eq y)) expr - else Apply1(pos, x2, y2, tailstrict) + else Apply1(pos, x2, y2, tailstrict, tailrec) - case Apply2(pos, x, y, z, tailstrict) => + case Apply2(pos, x, y, z, tailstrict, tailrec) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Apply2(pos, x2, y2, z2, tailstrict) + else Apply2(pos, x2, y2, z2, tailstrict, tailrec) - case Apply3(pos, x, y, z, a, tailstrict) => + case Apply3(pos, x, y, z, a, tailstrict, tailrec) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) val a2 = transform(a) if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else Apply3(pos, x2, y2, z2, a2, tailstrict) + else Apply3(pos, x2, y2, z2, a2, tailstrict, tailrec) case ApplyBuiltin(pos, func, x, tailstrict) => val x2 = transformArr(x) @@ -538,13 +537,27 @@ class StaticOptimizer( transformGenericArr(a)(transformAssert) private def transformBind(b: Bind): Bind = { - val args = b.args - val rhs = b.rhs - nestedNames(if (args == null) null else args.names) { - val args2 = transformParams(args) - val rhs2 = transform(rhs) - if ((args2 eq args) && (rhs2 eq rhs)) b - else b.copy(args = args2, rhs = rhs2) + val selfTailrecNameIdx = scope.get(b.name).idx + b.args match { + case null => + b.rhs match { + case function: Function => + val rhs2 = transformFunction(function, selfTailrecNameIdx) + if (rhs2 eq function) b + else b.copy(rhs = rhs2) + case rhs => + val rhs2 = transform(rhs) + if (rhs2 eq rhs) b + else b.copy(rhs = rhs2) + } + case args => + val rhs = b.rhs + nestedNames(args.names) { + val args2 = transformParams(args) + val rhs2 = markDirectSelfTailCalls(transform(rhs), selfTailrecNameIdx) + if ((args2 eq args) && (rhs2 eq rhs)) b + else b.copy(args = args2, rhs = rhs2) + } } } @@ -816,7 +829,7 @@ class StaticOptimizer( } private def transformApply(a: Apply): Expr = { - val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match { + val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict, a.tailrec) match { case null => a case a => a } @@ -843,10 +856,10 @@ class StaticOptimizer( if (a.namedNames != null) a else a.args.length match { - case 0 => Apply0(a.pos, a.value, a.tailstrict) - case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict) - case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict) - case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict) + case 0 => Apply0(a.pos, a.value, a.tailstrict, a.tailrec) + case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict, a.tailrec) + case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict, a.tailrec) + case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict, a.tailrec) case _ => a } } @@ -856,7 +869,8 @@ class StaticOptimizer( lhs: Expr, args: Array[Expr], names: Array[String], - tailstrict: Boolean): Expr = lhs match { + tailstrict: Boolean, + tailrec: Boolean): Expr = lhs match { case f: Val.Builtin => rebind(args, names, f.params) match { case null => null @@ -891,12 +905,12 @@ class StaticOptimizer( case ScopedVal(Function(_, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict, tailrec) } case ScopedVal(Bind(_, _, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict, tailrec) } case _ => null } @@ -935,6 +949,61 @@ class StaticOptimizer( target } + private def transformFunction(function: Function, selfTailrecNameIdx: Int): Function = + nestedNames(function.params.names) { + val params = function.params + val body = function.body + val params2 = transformParams(params) + val body2 = + if (selfTailrecNameIdx == -1) transform(body) + else markDirectSelfTailCalls(transform(body), selfTailrecNameIdx) + if ((params2 eq params) && (body2 eq body)) function + else Function(function.pos, params2, body2) + } + + private def markDirectSelfTailCalls(e: Expr, selfTailrecNameIdx: Int): Expr = e match { + case IfElse(pos, cond, thenExpr, elseExpr) => + val thenExpr2 = markDirectSelfTailCalls(thenExpr, selfTailrecNameIdx) + val elseExpr2 = + if (elseExpr == null) null + else markDirectSelfTailCalls(elseExpr, selfTailrecNameIdx) + if ((thenExpr2 eq thenExpr) && (elseExpr2 eq elseExpr)) e + else IfElse(pos, cond, thenExpr2, elseExpr2) + + case LocalExpr(pos, bindings, returned) => + val returned2 = markDirectSelfTailCalls(returned, selfTailrecNameIdx) + if (returned2 eq returned) e + else LocalExpr(pos, bindings, returned2) + + case AssertExpr(pos, asserted, returned) => + val returned2 = markDirectSelfTailCalls(returned, selfTailrecNameIdx) + if (returned2 eq returned) e + else AssertExpr(pos, asserted, returned2) + + case apply @ Apply(_, value, _, _, tailstrict, _) + if !tailstrict && isDirectSelfCall(value, selfTailrecNameIdx) => + apply.copy(tailrec = true) + case apply @ Apply0(_, value, tailstrict, _) + if !tailstrict && isDirectSelfCall(value, selfTailrecNameIdx) => + apply.copy(tailrec = true) + case apply @ Apply1(_, value, _, tailstrict, _) + if !tailstrict && isDirectSelfCall(value, selfTailrecNameIdx) => + apply.copy(tailrec = true) + case apply @ Apply2(_, value, _, _, tailstrict, _) + if !tailstrict && isDirectSelfCall(value, selfTailrecNameIdx) => + apply.copy(tailrec = true) + case apply @ Apply3(_, value, _, _, _, tailstrict, _) + if !tailstrict && isDirectSelfCall(value, selfTailrecNameIdx) => + apply.copy(tailrec = true) + + case _ => e + } + + private def isDirectSelfCall(value: Expr, selfTailrecNameIdx: Int): Boolean = value match { + case ValidId(_, _, nameIdx) => nameIdx == selfTailrecNameIdx + case _ => false + } + private def tryFoldUnaryOp(pos: Position, op: Int, v: Val, fallback: Expr): Expr = try { (op: @switch) match { diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 7485ff52..ee5fdfba 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -1166,10 +1166,11 @@ case object TailstrictModeEnabled extends TailstrictMode case object TailstrictModeDisabled extends TailstrictMode /** - * Sentinel value for tail call optimization of `tailstrict` calls. When a function body's tail - * position is a `tailstrict` call, the evaluator returns a [[TailCall]] instead of recursing into - * the callee. [[TailCall.resolve]] then re-invokes the target function iteratively, eliminating - * native stack growth. + * Sentinel value for tail call optimization. When a function body's tail position is either an + * explicit `tailstrict` call or an optimizer-marked direct self-tail-call, the evaluator returns a + * [[TailCall]] instead of recursing into the callee. [[TailCall.resolve]] then re-invokes the + * target function iteratively, eliminating native stack growth while preserving the call's original + * argument-evaluation mode. * * This is an internal protocol value and must never escape to user-visible code paths (e.g. * materialization, object field access). Every call site that may produce a TailCall must either @@ -1180,7 +1181,8 @@ final class TailCall( val func: Val.Func, val args: Array[Eval], val namedNames: Array[String], - val callSiteExpr: Expr) + val callSiteExpr: Expr, + val tailstrictMode: TailstrictMode) extends Val { def pos: Position = callSiteExpr.pos def prettyName = "tailcall" @@ -1191,8 +1193,8 @@ object TailCall { /** * Iteratively resolve a [[TailCall]] chain (trampoline loop). If `current` is not a TailCall, it - * is returned immediately. Otherwise, each TailCall's target function is re-invoked with - * `TailstrictModeEnabled` until a non-TailCall result is produced. + * is returned immediately. Otherwise, each TailCall's target function is re-invoked with the + * original call's tailstrict mode until a non-TailCall result is produced. * * Error frames preserve the original call-site expression name (e.g. "Apply2") so that TCO does * not alter user-visible stack traces. @@ -1200,7 +1202,7 @@ object TailCall { @tailrec def resolve(current: Val)(implicit ev: EvalScope): Val = current match { case tc: TailCall => - implicit val tailstrictMode: TailstrictMode = TailstrictModeEnabled + implicit val tailstrictMode: TailstrictMode = tc.tailstrictMode val next = try { tc.func.apply(tc.args, tc.namedNames, tc.callSiteExpr.pos) diff --git a/sjsonnet/test/src/sjsonnet/AstVisitProfilerTests.scala b/sjsonnet/test/src/sjsonnet/AstVisitProfilerTests.scala new file mode 100644 index 00000000..62e95309 --- /dev/null +++ b/sjsonnet/test/src/sjsonnet/AstVisitProfilerTests.scala @@ -0,0 +1,79 @@ +package sjsonnet + +import utest._ + +object AstVisitProfilerTests extends TestSuite { + private val dummyPos = new Position(new FileScope(DummyPath("ast-visit-profiler-tests")), 0) + private val emptyVars: String => Option[ExternalVariable[_]] = _ => None + + private def runExpr(expr: Expr, useNewEvaluator: Boolean): AstVisitProfileSnapshot = { + val stats = new DebugStats + val profiler = new AstVisitProfiler + stats.astVisitProfiler = profiler + val interpreter = new Interpreter( + emptyVars, + emptyVars, + DummyPath("ast-visit-profiler-tests"), + Importer.empty, + new DefaultParseCache, + new Settings(useNewEvaluator = useNewEvaluator), + storePos = null, + logger = null, + std = sjsonnet.stdlib.StdLibModule.Default.module, + variableResolver = _ => None, + debugStats = stats + ) + try interpreter.evaluator.visitExpr(expr)(ValScope.empty) + catch { + case _: Error => () + } + profiler.snapshot() + } + + private def assertSimpleBinary(snapshot: AstVisitProfileSnapshot): Unit = { + assert(snapshot.totalVisits == 3) + assert(snapshot.oldDispatchArmCounts(AstVisitProfiler.DispatchArm.BinaryOp) == 1) + assert(snapshot.oldDispatchArmCounts(AstVisitProfiler.DispatchArm.Val) == 2) + assert(snapshot.normalTagCounts(ExprTags.BinaryOp) == 1) + assert(snapshot.normalTagCounts(ExprTags.`Val.Literal`) == 2) + assert(snapshot.invalidNodeCounts(AstVisitProfiler.InvalidNode.Self) == 0) + } + + private def assertInvalidSelf(snapshot: AstVisitProfileSnapshot): Unit = { + assert(snapshot.totalVisits == 1) + assert(snapshot.oldDispatchArmCounts(AstVisitProfiler.DispatchArm.Invalid) == 1) + assert(snapshot.invalidNodeCounts(AstVisitProfiler.InvalidNode.Self) == 1) + assert(snapshot.normalTagCounts(ExprTags.ValidId) == 0) + assert(snapshot.normalTagCounts(ExprTags.BinaryOp) == 0) + } + + val tests: Tests = Tests { + test("old evaluator counts normal tags and dispatch arms") { + val expr = Expr.BinaryOp( + dummyPos, + Val.Num(dummyPos, 1), + Expr.BinaryOp.OP_+, + Val.Num(dummyPos, 2) + ) + assertSimpleBinary(runExpr(expr, useNewEvaluator = false)) + } + + test("new evaluator counts normal tags and dispatch arms") { + val expr = Expr.BinaryOp( + dummyPos, + Val.Num(dummyPos, 1), + Expr.BinaryOp.OP_+, + Val.Num(dummyPos, 2) + ) + assertSimpleBinary(runExpr(expr, useNewEvaluator = true)) + } + + test("old evaluator counts invalid node namespace separately") { + assertInvalidSelf(runExpr(Expr.Self(dummyPos), useNewEvaluator = false)) + } + + test("new evaluator counts invalid node namespace separately") { + assertInvalidSelf(runExpr(Expr.Self(dummyPos), useNewEvaluator = true)) + } + } +} diff --git a/sjsonnet/test/src/sjsonnet/EvaluatorTests.scala b/sjsonnet/test/src/sjsonnet/EvaluatorTests.scala index 88347900..217143e0 100644 --- a/sjsonnet/test/src/sjsonnet/EvaluatorTests.scala +++ b/sjsonnet/test/src/sjsonnet/EvaluatorTests.scala @@ -944,7 +944,7 @@ object EvaluatorTests extends TestSuite { test("maxStack") { test("recursiveFunction") { val err = evalErr( - "local f(x) = f(x + 1); f(0)", + "local f(x) = 1 + f(x + 1); f(0)", useNewEvaluator = useNewEvaluator, maxStack = 10 ) diff --git a/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala b/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala index addb9497..1ac227af 100644 --- a/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala +++ b/sjsonnet/test/src/sjsonnet/TailCallOptimizationTests.scala @@ -1,10 +1,162 @@ package sjsonnet +import scala.collection.mutable import utest._ +import Expr.* import TestUtils.{eval, evalErr} object TailCallOptimizationTests extends TestSuite { + private def optimize(s: String): Expr = { + val interpreter = new Interpreter(Map(), Map(), DummyPath(), Importer.empty, new DefaultParseCache) + val parsed = fastparse + .parse( + s, + new Parser(DummyPath("(memory)"), mutable.HashMap.empty, mutable.HashMap.empty).document(_) + ) + .get + .value + ._1 + new StaticOptimizer( + interpreter.evaluator, + _ => None, + sjsonnet.stdlib.StdLibModule.Default.module, + mutable.HashMap.empty, + mutable.HashMap.empty + ).optimize(parsed) + } + val tests: Tests = Tests { + test("optimizerMarksDirectSelfTailCall") { + val optimized = optimize( + """ + |local f(n) = + | if n <= 0 then 0 + | else f(n - 1); + | + |f(10) + |""".stripMargin + ).asInstanceOf[LocalExpr] + + val bind = optimized.bindings(0) + val recursiveCall = bind.rhs.asInstanceOf[IfElse].`else`.asInstanceOf[Apply1] + val outerCall = optimized.returned.asInstanceOf[Apply1] + + assert(recursiveCall.tailrec) + assert(!recursiveCall.tailstrict) + assert(!outerCall.tailrec) + assert(recursiveCall.value.asInstanceOf[ValidId].name == "f") + } + + test("optimizerDoesNotMarkNonTailSelfCall") { + val optimized = optimize( + """ + |local f(n) = + | if n <= 0 then 0 + | else 1 + f(n - 1); + | + |f(10) + |""".stripMargin + ).asInstanceOf[LocalExpr] + + val recursiveCall = optimized + .bindings(0) + .rhs + .asInstanceOf[IfElse] + .`else` + .asInstanceOf[BinaryOp] + .rhs + .asInstanceOf[Apply1] + + assert(!recursiveCall.tailrec) + assert(!recursiveCall.tailstrict) + } + + test("directSelfTailrecWithoutTailstrict") { + eval( + """ + |local countdown(n) = + | if n <= 0 then 0 + | else countdown(n - 1); + | + |countdown(10000) + |""".stripMargin, + maxStack = 100 + ) ==> ujson.Num(0) + } + + test("directSelfTailrecPreservesLazyArgs") { + eval( + """ + |local loop(n, ignored=0) = + | if n <= 0 then 0 + | else loop(n - 1, error "kaboom"); + | + |loop(10000) + |""".stripMargin, + maxStack = 100 + ) ==> ujson.Num(0) + } + + test("directSelfTailrecWithNamedAndDefaultArgs") { + eval( + """ + |local f(n, step=1, accum=0) = + | if n <= 0 then accum + | else f(accum=accum + n, n=n - step); + | + |f(100) + |""".stripMargin, + maxStack = 100 + ) ==> ujson.Num(5050) + } + + test("directSelfTailrecInsideNestedFunction") { + eval( + """ + |local outer(n) = + | local inner(remaining, accum=0) = + | if remaining <= 0 then accum + | else + | local next = remaining - 1; + | inner(next, accum + remaining); + | inner(n); + | + |outer(1000) + |""".stripMargin, + maxStack = 100 + ) ==> ujson.Num(500500) + } + + test("nonTailSelfRecursionStillOverflows") { + val err = evalErr( + """ + |local f(n) = + | if n <= 0 then 0 + | else 1 + f(n - 1); + | + |f(1000) + |""".stripMargin, + maxStack = 100 + ) + assert(err.contains("Max stack frames exceeded.")) + } + + test("objectMethodSelfCallIsNotImplicitTailrec") { + val err = evalErr( + """ + |local fns = { + | countdown(n):: + | if n <= 0 then 0 + | else self.countdown(n - 1), + |}; + | + |fns.countdown(1000) + |""".stripMargin, + maxStack = 100 + ) + assert(err.contains("Max stack frames exceeded.")) + } + test("tailstrictFactorialSmall") { eval( """