Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ trait SjsonnetCrossModule extends CrossScalaModule with ScalafmtModule {
"-feature",
"-opt-inline-from:sjsonnet.*,sjsonnet.**",
"-Xsource:3",
"-Xlint:_",
) ++ (if (scalaVersion().startsWith("2.13")) Seq("-Wopt", "-Wconf:origin=scala.collection.compat.*:s")
"-Xlint:_"
) ++ (if (scalaVersion().startsWith("2.13"))
Seq("-Wopt", "-Wconf:origin=scala.collection.compat.*:s")
else Seq("-Xfatal-warnings", "-Ywarn-unused:-nowarn"))
else Seq[String]("-Wconf:origin=scala.collection.compat.*:s", "-Xlint:all")
)
Expand Down
9 changes: 8 additions & 1 deletion sjsonnet/src/sjsonnet/Settings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ final case class Settings(
brokenAssertionLogic: Boolean = false,
maxMaterializeDepth: Int = 1000,
materializeRecursiveDepthLimit: Int = 128,
maxStack: Int = 500
maxStack: Int = 500,
/**
* Enable aggressive static optimizations in the optimization phase, including: constant folding
* for arithmetic, comparison, bitwise, and shift operators; branch elimination for if-else with
* constant conditions; short-circuit elimination for And/Or with constant lhs. These reduce AST
* node count, benefiting long-running Jsonnet programs.
*/
aggressiveStaticOptimization: Boolean = false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What the risk with having that by default? are you worried about correctness?

)

object Settings {
Expand Down
223 changes: 222 additions & 1 deletion sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ import ScopedExprTransform.*
* StaticOptimizer performs necessary transformations for the evaluator (assigning ValScope indices)
* plus additional optimizations (post-order) and static checking (pre-order).
*
* When `aggressiveStaticOptimization` is enabled, the optimizer additionally performs during the
* optimization phase:
* - Constant folding for arithmetic (+, -, *, /, %), comparison (<, >, <=, >=, ==, !=), bitwise
* (&, ^, |), shift (<<, >>), and unary (!, -, ~, +) operators.
* - Branch elimination for if-else with constant conditions.
* - Short-circuit elimination for And/Or with constant lhs operands.
*
* @param variableResolver
* a function that resolves variable names to expressions, only called if the variable is not
* found in the scope.
Expand All @@ -25,6 +32,8 @@ class StaticOptimizer(
extends ScopedExprTransform {
def optimize(e: Expr): Expr = transform(e)

private val aggressiveOptimization = ev.settings.aggressiveStaticOptimization

override def transform(_e: Expr): Expr = super.transform(check(_e)) match {
case a: Apply => transformApply(a)

Expand Down Expand Up @@ -96,7 +105,42 @@ class StaticOptimizer(
Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings)
else m

case e => e
// Aggressive optimizations: constant folding, branch elimination, short-circuit elimination.
// These reduce AST node count at parse time, benefiting long-running Jsonnet programs.
case e => if (aggressiveOptimization) tryAggressiveOptimize(e) else e
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only do this after a flag

}

/**
* Aggressive static optimizations that benefit long-running programs by reducing AST size.
* Includes: branch elimination, short-circuit elimination, constant folding for arithmetic,
* comparison, bitwise, and shift operators.
*/
private def tryAggressiveOptimize(e: Expr): Expr = e match {
// Constant folding: BinaryOp with two constant operands (most common case first)
case e @ BinaryOp(pos, lhs: Val, op, rhs: Val) => tryFoldBinaryOp(pos, lhs, op, rhs, e)

// Constant folding: UnaryOp with constant operand
case e @ UnaryOp(pos, op, v: Val) => tryFoldUnaryOp(pos, op, v, e)

// Branch elimination: constant condition in if-else
case IfElse(_, _: Val.True, thenExpr, _) => thenExpr
case IfElse(pos, _: Val.False, _, elseExpr) =>
if (elseExpr == null) Val.Null(pos) else elseExpr

// Short-circuit elimination for And/Or with constant lhs.
//
// IMPORTANT: rhs MUST be guarded as `Val.Bool` — do NOT relax this to arbitrary Expr.
// The Evaluator's visitAnd/visitOr enforces that rhs evaluates to Bool, throwing
// "binary operator && does not operate on <type>s" otherwise. If we fold `true && rhs`
// into just `rhs` without the Bool guard, we silently remove that runtime type check,
// causing programs like `true && "hello"` to return "hello" instead of erroring.
// See: Evaluator.visitAnd / Evaluator.visitOr for the authoritative runtime semantics.
case And(_, _: Val.True, rhs: Val.Bool) => rhs
case And(pos, _: Val.False, _) => Val.False(pos)
case Or(pos, _: Val.True, _) => Val.True(pos)
case Or(_, _: Val.False, rhs: Val.Bool) => rhs

case _ => e
}

private object ValidSuper {
Expand Down Expand Up @@ -258,4 +302,181 @@ class StaticOptimizer(
}
target
}

private def tryFoldUnaryOp(pos: Position, op: Int, v: Val, fallback: Expr): Expr =
try {
op match {
case Expr.UnaryOp.OP_! =>
v match {
case _: Val.True => Val.False(pos)
case _: Val.False => Val.True(pos)
case _ => fallback
}
case Expr.UnaryOp.OP_- =>
v match {
case Val.Num(_, n) => Val.Num(pos, -n)
case _ => fallback
}
case Expr.UnaryOp.OP_~ =>
v match {
case Val.Num(_, n) => Val.Num(pos, (~n.toLong).toDouble)
case _ => fallback
}
case Expr.UnaryOp.OP_+ =>
v match {
case Val.Num(_, n) => Val.Num(pos, n)
case _ => fallback
}
case _ => fallback
}
} catch { case _: Exception => fallback }

private def tryFoldBinaryOp(pos: Position, lhs: Val, op: Int, rhs: Val, fallback: Expr): Expr =
try {
op match {
case BinaryOp.OP_+ =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l + r)
case (Val.Str(_, l), Val.Str(_, r)) => Val.Str(pos, l + r)
case (l: Val.Arr, r: Val.Arr) => l.concat(pos, r)
case _ => fallback
}
case BinaryOp.OP_- =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l - r)
case _ => fallback
}
case BinaryOp.OP_* =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l * r)
case _ => fallback
}
case BinaryOp.OP_/ =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) if r != 0 => Val.Num(pos, l / r)
case _ => fallback
}
case BinaryOp.OP_% =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) => Val.Num(pos, l % r)
case _ => fallback
}
case BinaryOp.OP_< =>
tryFoldComparison(pos, lhs, BinaryOp.OP_<, rhs, fallback)
case BinaryOp.OP_> =>
tryFoldComparison(pos, lhs, BinaryOp.OP_>, rhs, fallback)
case BinaryOp.OP_<= =>
tryFoldComparison(pos, lhs, BinaryOp.OP_<=, rhs, fallback)
case BinaryOp.OP_>= =>
tryFoldComparison(pos, lhs, BinaryOp.OP_>=, rhs, fallback)
case BinaryOp.OP_== =>
tryFoldEquality(pos, lhs, rhs, negate = false, fallback)
case BinaryOp.OP_!= =>
tryFoldEquality(pos, lhs, rhs, negate = true, fallback)
case BinaryOp.OP_in =>
(lhs, rhs) match {
case (Val.Str(_, l), o: Val.Obj) => Val.bool(pos, o.containsKey(l))
case _ => fallback
}
case BinaryOp.OP_<< =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) =>
val ll = lhs.asInstanceOf[Val.Num].asSafeLong
val rr = rhs.asInstanceOf[Val.Num].asSafeLong
if (rr < 0) fallback // negative shift → runtime error
else if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr)))
fallback // overflow → runtime error
else Val.Num(pos, (ll << rr).toDouble)
case _ => fallback
}
case BinaryOp.OP_>> =>
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) =>
val ll = lhs.asInstanceOf[Val.Num].asSafeLong
val rr = rhs.asInstanceOf[Val.Num].asSafeLong
if (rr < 0) fallback // negative shift → runtime error
else Val.Num(pos, (ll >> rr).toDouble)
case _ => fallback
}
case BinaryOp.OP_& =>
(lhs, rhs) match {
case (Val.Num(_, _), Val.Num(_, _)) =>
Val.Num(
pos,
(lhs.asInstanceOf[Val.Num].asSafeLong & rhs
.asInstanceOf[Val.Num]
.asSafeLong).toDouble
)
case _ => fallback
}
case BinaryOp.OP_^ =>
(lhs, rhs) match {
case (Val.Num(_, _), Val.Num(_, _)) =>
Val.Num(pos, (lhs.asLong ^ rhs.asLong).toDouble)
case _ => fallback
}
case BinaryOp.OP_| =>
(lhs, rhs) match {
case (Val.Num(_, _), Val.Num(_, _)) =>
Val.Num(pos, (lhs.asLong | rhs.asLong).toDouble)
case _ => fallback
}
case _ => fallback
}
} catch { case _: Exception => fallback }

private def tryFoldComparison(
pos: Position,
lhs: Val,
op: Int,
rhs: Val,
fallback: Expr): Expr = {
// Use IEEE 754 operators directly for Num, not java.lang.Double.compare,
// because compare(-0.0, 0.0) == -1 while IEEE 754 treats -0.0 == 0.0.
(lhs, rhs) match {
case (Val.Num(_, l), Val.Num(_, r)) if !l.isNaN && !r.isNaN =>
val result = op match {
case BinaryOp.OP_< => l < r
case BinaryOp.OP_> => l > r
case BinaryOp.OP_<= => l <= r
case BinaryOp.OP_>= => l >= r
case _ => return fallback
}
Val.bool(pos, result)
case (Val.Str(_, l), Val.Str(_, r)) =>
val cmp = Util.compareStringsByCodepoint(l, r)
val result = op match {
case BinaryOp.OP_< => cmp < 0
case BinaryOp.OP_> => cmp > 0
case BinaryOp.OP_<= => cmp <= 0
case BinaryOp.OP_>= => cmp >= 0
case _ => return fallback
}
Val.bool(pos, result)
case _ => fallback
}
}

private def tryFoldEquality(
pos: Position,
lhs: Val,
rhs: Val,
negate: Boolean,
fallback: Expr): Expr = {
def isSimpleLiteral(v: Val): Boolean = v match {
case _: Val.Bool | _: Val.Null | _: Val.Str | _: Val.Num => true
case _ => false
}
if (!isSimpleLiteral(lhs) || !isSimpleLiteral(rhs)) return fallback
val result = (lhs, rhs) match {
case (_: Val.True, _: Val.True) | (_: Val.False, _: Val.False) | (_: Val.Null, _: Val.Null) =>
true
case (Val.Num(_, l), Val.Num(_, r)) if !l.isNaN && !r.isNaN =>
l == r
case (Val.Str(_, l), Val.Str(_, r)) =>
l == r
case _ => false // different simple types are never equal
}
Val.bool(pos, if (negate) !result else result)
}
}
Loading