From 5d009b9aa94625b50b24322fd199fae0e33fc45a Mon Sep 17 00:00:00 2001 From: Stephen Amar Date: Fri, 6 Mar 2026 04:57:30 +0000 Subject: [PATCH] Add unified --profile flag with text and flamegraph output Replaces the --flamegraph flag with --profile (text output, default) or --profile flamegraph: (folded stack format). Ports the expression-level profiler from universe/jsonnet/profiler, which measures self-time per expression via visitExpr instrumentation, and merges it with the existing flamegraph stack recorder. The text format shows top expressions by self-time, aggregated by file and by expression type, with System.nanoTime overhead estimation. The flamegraph format filters to Apply/ApplyBuiltin/comprehension frames and includes filenames for non-builtin calls. Deletes the old bench/ ProfilingEvaluator and RunProfiler which are superseded by the new CLI-integrated profiler. --- .../sjsonnet/bench/ProfilingEvaluator.scala | 197 ---------- bench/src/sjsonnet/bench/RunProfiler.scala | 88 ----- sjsonnet/src-jvm-native/sjsonnet/Config.scala | 6 +- .../sjsonnet/SjsonnetMainBase.scala | 29 +- sjsonnet/src/sjsonnet/Evaluator.scala | 87 ++--- .../src/sjsonnet/FlameGraphProfiler.scala | 52 --- sjsonnet/src/sjsonnet/Profiler.scala | 352 ++++++++++++++++++ 7 files changed, 420 insertions(+), 391 deletions(-) delete mode 100644 bench/src/sjsonnet/bench/ProfilingEvaluator.scala delete mode 100644 bench/src/sjsonnet/bench/RunProfiler.scala delete mode 100644 sjsonnet/src/sjsonnet/FlameGraphProfiler.scala create mode 100644 sjsonnet/src/sjsonnet/Profiler.scala diff --git a/bench/src/sjsonnet/bench/ProfilingEvaluator.scala b/bench/src/sjsonnet/bench/ProfilingEvaluator.scala deleted file mode 100644 index 143e5ce4..00000000 --- a/bench/src/sjsonnet/bench/ProfilingEvaluator.scala +++ /dev/null @@ -1,197 +0,0 @@ -package sjsonnet.bench - -import sjsonnet.* - -import java.util -import scala.collection.mutable -import scala.jdk.CollectionConverters.* -import 
scala.language.existentials - -class ProfilingEvaluator( - r: CachedResolver, - e: String => Option[Expr], - w: Path, - s: Settings, - wa: (Boolean, String) => Unit) - extends Evaluator(r, e, w, s, wa) { - - trait Box { - def name: String - var time: Long = 0 - var count: Long = 0 - } - - class ExprBox(val expr: Expr) extends Box { - var children: Seq[Expr] = Nil - var totalTime: Long = 0 - lazy val prettyOffset: String = - prettyIndex(expr.pos).map { case (l, c) => s"$l:$c" }.getOrElse("?:?") - lazy val prettyPos = s"${expr.pos.currentFile.asInstanceOf[OsPath].p}:$prettyOffset" - lazy val name: String = { - val exprName = expr.getClass.getName.split('.').last.split('$').last - val funcOrOptName: Option[String] = expr match { - case a: Expr.ApplyBuiltin => Some(a.func.functionName) - case a: Expr.ApplyBuiltin1 => Some(a.func.functionName) - case a: Expr.ApplyBuiltin2 => Some(a.func.functionName) - case a: Expr.ApplyBuiltin3 => Some(a.func.functionName) - case a: Expr.ApplyBuiltin4 => Some(a.func.functionName) - case u: Expr.UnaryOp => Some(Expr.UnaryOp.name(u.op)) - case b: Expr.BinaryOp => Some(Expr.BinaryOp.name(b.op)) - case _ => None - } - exprName + funcOrOptName.map(n => s" ($n)").getOrElse("") - } - } - - class OpBox(val op: Int, val name: String) extends Box - - class ExprTypeBox(val name: String) extends Box - - class BuiltinBox(val name: String) extends Box - - private val data = new util.IdentityHashMap[Expr, ExprBox] - private var parent: ExprBox = _ - - private def getOrCreate(e: Expr): ExprBox = { - var box = data.get(e) - if (box == null) { - box = new ExprBox(e) - data.put(e, box) - } - box - } - - override def visitExpr(e: Expr)(implicit scope: ValScope): Val = { - val pt0 = System.nanoTime() - val box = getOrCreate(e) - val prevParent = parent - parent = box - val t0 = System.nanoTime() - try super.visitExpr(e) - finally { - box.time += (System.nanoTime() - t0) - box.count += 1 - parent = prevParent - if (parent != null) { - parent.time -= 
(System.nanoTime() - pt0) - } - } - } - - def clear(): Unit = data.clear() - - def get(e: Expr): ExprBox = data.get(e) - - def accumulate(e: Expr): Long = { - val box = getOrCreate(e) - box.children = getChildren(e) - box.totalTime = box.time + box.children.map(accumulate).sum - box.totalTime - } - - private def getChildren(e: Expr): Seq[Expr] = e match { - case p: Product => - (0 until p.productArity).iterator - .map(p.productElement) - .flatMap { - case _: Expr => Seq(e) - case a: Array[Expr] => a.toSeq - case a: Array[Expr.CompSpec] => a.toSeq - case p: Expr.Params => getChildren(p) - case a: Expr.Member.AssertStmt => Seq(a.value, a.msg) - case a: Array[Expr.Bind] => a.iterator.flatMap(getChildren).toSeq - case a: Array[Expr.Member.Field] => a.iterator.flatMap(getChildren).toSeq - case a: Array[Expr.Member.AssertStmt] => a.iterator.flatMap(getChildren).toSeq - case s: Seq[?] => s.collect { case ee: Expr => ee } - case s: Some[?] => s.collect { case ee: Expr => ee } - case _ => Nil - } - .filter(_ != null) - .toSeq - case _ => Nil - } - - private def getChildren(p: Expr.Params): Seq[Expr] = - if (p == null || p.defaultExprs == null) Nil else p.defaultExprs.toSeq - - private def getChildren(b: Expr.Bind): Seq[Expr] = - getChildren(b.args) :+ b.rhs - - private def getChildren(m: Expr.Member): Seq[Expr] = m match { - case b: Expr.Bind => getChildren(b.args) :+ b.rhs - case a: Expr.Member.AssertStmt => Seq(a.value, a.msg) - case f: Expr.Member.Field => getChildren(f.fieldName) ++ getChildren(f.args) :+ f.rhs - case null => Nil - } - - private def getChildren(f: Expr.FieldName): Seq[Expr] = f match { - case f: Expr.FieldName.Dyn => Seq(f.expr) - case _ => Nil - } - - def all: Seq[ExprBox] = data.values().asScala.toSeq - - def binaryOperators(): Seq[OpBox] = { - val m = new mutable.HashMap[Int, OpBox] - all.foreach { b => - b.expr match { - case Expr.BinaryOp(_, _, op, _) => - val ob = m.getOrElseUpdate(op, new OpBox(op, Expr.BinaryOp.name(op))) - ob.time += b.time - 
ob.count += b.count - case _ => - } - } - m.valuesIterator.toSeq - } - - def unaryOperators(): Seq[OpBox] = { - val m = new mutable.HashMap[Int, OpBox] - all.foreach { b => - b.expr match { - case Expr.UnaryOp(_, op, _) => - val ob = m.getOrElseUpdate(op, new OpBox(op, Expr.UnaryOp.name(op))) - ob.time += b.time - ob.count += b.count - case _ => - } - } - m.valuesIterator.toSeq - } - - def exprTypes(): Seq[ExprTypeBox] = { - val m = new mutable.HashMap[String, ExprTypeBox] - all.foreach { b => - val cl = b.expr match { - case _: Val => classOf[Val] - case ee => ee.getClass - } - val n = cl.getName.replaceAll("^sjsonnet\\.", "").replace('$', '.') - val eb = m.getOrElseUpdate(n, new ExprTypeBox(n)) - eb.time += b.time - eb.count += b.count - } - m.valuesIterator.toSeq - } - - def builtins(): Seq[BuiltinBox] = { - val names = new util.IdentityHashMap[Val.Func, String]() - sjsonnet.stdlib.StdLibModule.Default.functions.foreachEntry((n, f) => names.put(f, n)) - val m = new mutable.HashMap[String, BuiltinBox] - def add(b: ExprBox, func: Val.Builtin): Unit = { - val n = names.getOrDefault(func, func.functionName) - val bb = m.getOrElseUpdate(n, new BuiltinBox(n)) - bb.time += b.time - bb.count += b.count - } - all.foreach { b => - b.expr match { - case a: Expr.ApplyBuiltin1 => add(b, a.func) - case a: Expr.ApplyBuiltin2 => add(b, a.func) - case a: Expr.ApplyBuiltin => add(b, a.func) - case _ => - } - } - m.valuesIterator.toSeq - } -} diff --git a/bench/src/sjsonnet/bench/RunProfiler.scala b/bench/src/sjsonnet/bench/RunProfiler.scala deleted file mode 100644 index 8414612d..00000000 --- a/bench/src/sjsonnet/bench/RunProfiler.scala +++ /dev/null @@ -1,88 +0,0 @@ -package sjsonnet.bench - -import sjsonnet.* - -import java.io.StringWriter - -object RunProfiler extends App { - val parser = mainargs.ParserForClass[Config] - val config = parser - .constructEither(MainBenchmark.mainArgs.toIndexedSeq, autoPrintHelpAndExit = None) - .toOption - .get - val file = config.file - val 
wd = os.pwd - val path = OsPath(os.Path(file, wd)) - val parseCache = new DefaultParseCache - val interp = new Interpreter( - Map.empty[String, String], - Map.empty[String, String], - OsPath(wd), - importer = new SjsonnetMainBase.SimpleImporter( - config.getOrderedJpaths.map(os.Path(_, wd)).map(OsPath(_)).toIndexedSeq, - None - ), - parseCache = parseCache - ) { - override def createEvaluator( - resolver: CachedResolver, - extVars: String => Option[Expr], - wd: Path, - settings: Settings): Evaluator = - new ProfilingEvaluator(resolver, extVars, wd, settings, null) - } - val profiler = interp.evaluator.asInstanceOf[ProfilingEvaluator] - - def run(): Long = { - val renderer = new Renderer(new StringWriter, indent = 3) - val start = interp.resolver.read(path, binaryData = false).get.readString() - val t0 = System.nanoTime() - interp.interpret0(start, path, renderer).toOption.get - System.nanoTime() - t0 - } - - println("\nWarming up...") - profiler.clear() - for (i <- 1 to 10) run() - - println("\nProfiling...") - profiler.clear() - val total = (for (i <- 1 to 5) yield run()).sum - - val roots = parseCache.valuesIterator.map(_.toOption.get).map(_._1).toSeq - roots.foreach(profiler.accumulate) - - println(s"\nTop 20 by time:") - profiler.all.sortBy(-_.time).take(20).foreach { b => show(b.time, b, "- ", rec = false) } - - val cutoff = 0.02 - println(s"\nTrees with >= $cutoff time:") - showAll(roots, "") - - def showAll(es: Seq[Expr], indent: String): Unit = { - val timed = es.iterator - .map(profiler.get) - .filter(_.totalTime.toDouble / total.toDouble >= cutoff) - .toSeq - .sortBy(-_.time) - timed.foreach { b => show(b.totalTime, b, indent, rec = true) } - } - - def show(time: Long, box: profiler.ExprBox, indent: String, rec: Boolean): Unit = { - println(s"$indent${time / 1000000L}ms ${box.name} ${box.prettyPos}") - if (rec) showAll(box.children, indent + " ") - } - - def show(n: String, bs: Seq[profiler.Box]): Unit = { - println(n) - bs.filter(_.count > 
0).sortBy(-_.time).foreach { ob => - val avg = if (ob.count == 0) 0 else ob.time / ob.count - println(s"- ${ob.time / 1000000L}ms\t${ob.count}\t${avg}ns\t${ob.name}") - } - } - - show(s"\nBinary operators:", profiler.binaryOperators()) - show(s"\nUnary operators:", profiler.unaryOperators()) - show(s"\nBuilt-in functions:", profiler.builtins()) - show(s"\nExpr types:", profiler.exprTypes()) -} diff --git a/sjsonnet/src-jvm-native/sjsonnet/Config.scala b/sjsonnet/src-jvm-native/sjsonnet/Config.scala index 8f3948da..5cee58c9 100644 --- a/sjsonnet/src-jvm-native/sjsonnet/Config.scala +++ b/sjsonnet/src-jvm-native/sjsonnet/Config.scala @@ -168,11 +168,11 @@ final case class Config( ) maxStack: Int = 500, @arg( - name = "flamegraph", + name = "profile", doc = - "Write a flame graph profile in folded stack format to the given file. Use with https://github.com/brendangregg/FlameGraph" + "Profile evaluation and write results to a file. Format: --profile or --profile : where format is 'text' (default) or 'flamegraph'" ) - flamegraph: Option[String] = None, + profile: Option[String] = None, @arg( doc = "The jsonnet file you wish to evaluate", positional = true diff --git a/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala b/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala index a0568a6b..3ec5798b 100644 --- a/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala +++ b/sjsonnet/src-jvm-native/sjsonnet/SjsonnetMainBase.scala @@ -170,7 +170,7 @@ object SjsonnetMainBase { warn, std, debugStats = debugStats, - flamegraphFile = config.flamegraph + profileOpt = config.profile ) res <- { if (hasWarnings && config.fatalWarnings.value) Left("") @@ -320,7 +320,7 @@ object SjsonnetMainBase { std: Val.Obj, evaluatorOverride: Option[Evaluator] = None, debugStats: DebugStats = null, - flamegraphFile: Option[String] = None): Either[String, String] = { + profileOpt: Option[String] = None): Either[String, String] = { val (jsonnetCode, path) = if (config.exec.value) (file, wd / 
Util.wrapInLessThanGreaterThan("exec")) @@ -349,8 +349,19 @@ object SjsonnetMainBase { wd ) + val (profileFormat, profileFile) = profileOpt match { + case Some(s) if s.startsWith("flamegraph:") => + (Some(ProfileOutputFormat.FlameGraph), Some(s.stripPrefix("flamegraph:"))) + case Some(s) if s.startsWith("text:") => + (Some(ProfileOutputFormat.Text), Some(s.stripPrefix("text:"))) + case Some(s) => + (Some(ProfileOutputFormat.Text), Some(s)) + case None => + (None, None) + } + var currentPos: Position = null - var profiler: FlameGraphProfiler = null + var profilerInstance: Profiler = null val interp = new Interpreter( queryExtVar = (key: String) => extBinding.get(key).map(ExternalVariable.code), queryTlaVar = (key: String) => tlaBinding.get(key).map(ExternalVariable.code), @@ -372,9 +383,9 @@ object SjsonnetMainBase { val ev = evaluatorOverride.getOrElse( super.createEvaluator(resolver, extVars, wd, settings) ) - if (flamegraphFile.isDefined) { - profiler = new FlameGraphProfiler - ev.flameGraphProfiler = profiler + profileFormat.foreach { fmt => + profilerInstance = new Profiler(fmt, wd) + ev.profiler = profilerInstance } ev } @@ -448,8 +459,10 @@ object SjsonnetMainBase { case _ => renderNormal(config, interp, jsonnetCode, path, wd, () => currentPos) } - if (profiler != null) - flamegraphFile.foreach(profiler.writeTo) + if (profilerInstance != null) + profileFile.foreach(f => + profilerInstance.writeTo(f, pos => interp.evaluator.prettyIndex(pos)) + ) result } diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 37261f2d..dd823ee3 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -30,7 +30,7 @@ class Evaluator( private[this] var stackDepth: Int = 0 private[this] val maxStack: Int = settings.maxStack - private[sjsonnet] var flameGraphProfiler: FlameGraphProfiler = _ + private[sjsonnet] var profiler: Profiler = _ @inline private[sjsonnet] final def checkStackDepth(pos: 
Position): Unit = { stackDepth += 1 @@ -42,21 +42,18 @@ class Evaluator( @inline private[sjsonnet] final def checkStackDepth(pos: Position, expr: Expr): Unit = { stackDepth += 1 - if (flameGraphProfiler != null) flameGraphProfiler.push(expr.exprErrorString) if (stackDepth > maxStack) Error.fail("Max stack frames exceeded.", pos) } @inline private[sjsonnet] final def checkStackDepth(pos: Position, name: String): Unit = { stackDepth += 1 - if (flameGraphProfiler != null) flameGraphProfiler.push(name) if (stackDepth > maxStack) Error.fail("Max stack frames exceeded.", pos) } @inline private[sjsonnet] final def decrementStackDepth(): Unit = { stackDepth -= 1 - if (flameGraphProfiler != null) flameGraphProfiler.pop() } def materialize(v: Val): Value = Materializer.apply(v) @@ -64,45 +61,49 @@ class Evaluator( collection.mutable.HashMap.empty[Path, Val] override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { - e match { - case e: ValidId => visitValidId(e) - case e: BinaryOp => visitBinaryOp(e) - case e: Select => visitSelect(e) - case e: Val => e - case e: ApplyBuiltin0 => visitApplyBuiltin0(e) - case e: ApplyBuiltin1 => visitApplyBuiltin1(e) - case e: ApplyBuiltin2 => visitApplyBuiltin2(e) - case e: ApplyBuiltin3 => visitApplyBuiltin3(e) - case e: ApplyBuiltin4 => visitApplyBuiltin4(e) - case e: And => visitAnd(e) - case e: Or => visitOr(e) - case e: UnaryOp => visitUnaryOp(e) - case e: Apply1 => visitApply1(e) - case e: Lookup => visitLookup(e) - case e: Function => visitMethod(e.body, e.params, e.pos) - case e: LocalExpr => visitLocalExpr(e) - case e: Apply => visitApply(e) - case e: IfElse => visitIfElse(e) - case e: Apply3 => visitApply3(e) - case e: ObjBody.MemberList => visitMemberList(e.pos, e, null) - case e: Apply2 => visitApply2(e) - case e: AssertExpr => visitAssert(e) - case e: ApplyBuiltin => visitApplyBuiltin(e) - case e: Comp => visitComp(e) - case e: Arr => visitArr(e) - case e: SelectSuper => visitSelectSuper(e) - case e: LookupSuper => 
visitLookupSuper(e) - case e: InSuper => visitInSuper(e) - case e: ObjExtend => visitObjExtend(e) - case e: ObjBody.ObjComp => visitObjComp(e, null) - case e: Slice => visitSlice(e) - case e: Import => visitImport(e) - case e: Apply0 => visitApply0(e) - case e: ImportStr => visitImportStr(e) - case e: ImportBin => visitImportBin(e) - case e: Expr.Error => visitError(e) - case e => visitInvalid(e) - } + val p = profiler + val saved: (AnyRef, Int) = if (p != null) p.enter(e) else null + try { + e match { + case e: ValidId => visitValidId(e) + case e: BinaryOp => visitBinaryOp(e) + case e: Select => visitSelect(e) + case e: Val => e + case e: ApplyBuiltin0 => visitApplyBuiltin0(e) + case e: ApplyBuiltin1 => visitApplyBuiltin1(e) + case e: ApplyBuiltin2 => visitApplyBuiltin2(e) + case e: ApplyBuiltin3 => visitApplyBuiltin3(e) + case e: ApplyBuiltin4 => visitApplyBuiltin4(e) + case e: And => visitAnd(e) + case e: Or => visitOr(e) + case e: UnaryOp => visitUnaryOp(e) + case e: Apply1 => visitApply1(e) + case e: Lookup => visitLookup(e) + case e: Function => visitMethod(e.body, e.params, e.pos) + case e: LocalExpr => visitLocalExpr(e) + case e: Apply => visitApply(e) + case e: IfElse => visitIfElse(e) + case e: Apply3 => visitApply3(e) + case e: ObjBody.MemberList => visitMemberList(e.pos, e, null) + case e: Apply2 => visitApply2(e) + case e: AssertExpr => visitAssert(e) + case e: ApplyBuiltin => visitApplyBuiltin(e) + case e: Comp => visitComp(e) + case e: Arr => visitArr(e) + case e: SelectSuper => visitSelectSuper(e) + case e: LookupSuper => visitLookupSuper(e) + case e: InSuper => visitInSuper(e) + case e: ObjExtend => visitObjExtend(e) + case e: ObjBody.ObjComp => visitObjComp(e, null) + case e: Slice => visitSlice(e) + case e: Import => visitImport(e) + case e: Apply0 => visitApply0(e) + case e: ImportStr => visitImportStr(e) + case e: ImportBin => visitImportBin(e) + case e: Expr.Error => visitError(e) + case e => visitInvalid(e) + } + } finally if (p != null) 
p.exit(saved) } catch { Error.withStackFrame(e) } diff --git a/sjsonnet/src/sjsonnet/FlameGraphProfiler.scala b/sjsonnet/src/sjsonnet/FlameGraphProfiler.scala deleted file mode 100644 index 018e7687..00000000 --- a/sjsonnet/src/sjsonnet/FlameGraphProfiler.scala +++ /dev/null @@ -1,52 +0,0 @@ -package sjsonnet - -import java.io.{BufferedWriter, FileWriter} - -/** - * Collects stack samples during Jsonnet evaluation and writes them in Brendan Gregg's folded stack - * format, suitable for generating flame graphs with https://github.com/brendangregg/FlameGraph. - * - * Each call to [[push]] records a new frame on the current stack. Each call to [[pop]] removes the - * top frame. A sample (incrementing the count for the current stack) is taken on every [[push]], so - * deeper call trees contribute proportionally more samples. - */ -final class FlameGraphProfiler { - private val stack = new java.util.ArrayDeque[String]() - private val counts = new java.util.HashMap[String, java.lang.Long]() - - def push(name: String): Unit = { - stack.push(name) - val key = foldedStack() - val prev = counts.get(key) - counts.put(key, if (prev == null) 1L else prev + 1L) - } - - def pop(): Unit = - if (!stack.isEmpty) stack.pop() - - private def foldedStack(): String = { - val sb = new StringBuilder - val it = stack.descendingIterator() - var first = true - while (it.hasNext) { - if (!first) sb.append(';') - sb.append(it.next()) - first = false - } - sb.toString - } - - def writeTo(path: String): Unit = { - val w = new BufferedWriter(new FileWriter(path)) - try { - val it = counts.entrySet().iterator() - while (it.hasNext) { - val e = it.next() - w.write(e.getKey) - w.write(' ') - w.write(e.getValue.toString) - w.newLine() - } - } finally w.close() - } -} diff --git a/sjsonnet/src/sjsonnet/Profiler.scala b/sjsonnet/src/sjsonnet/Profiler.scala new file mode 100644 index 00000000..65b407b7 --- /dev/null +++ b/sjsonnet/src/sjsonnet/Profiler.scala @@ -0,0 +1,352 @@ +package sjsonnet + +import 
java.io.{BufferedWriter, FileWriter} +import java.util + +import scala.collection.mutable.ListBuffer + +final case class BoxId(name: String, fileName: String, offset: Int) + +final case class Box(id: BoxId, lineCol: () => String, count: Long, selfTimeNs: Long) + +final case class ProfilingResult(boxes: Seq[Box]) + +object ProfilingResult { + def merge(results: Seq[ProfilingResult]): ProfilingResult = { + if (results.size == 1) return results.head + val mergedBoxes = results + .flatMap(_.boxes) + .groupBy(_.id) + .values + .map { boxes => + Box( + id = boxes.head.id, + lineCol = boxes.head.lineCol, + count = boxes.map(_.count).sum, + selfTimeNs = boxes.map(_.selfTimeNs).sum + ) + } + .toSeq + ProfilingResult(mergedBoxes) + } +} + +sealed abstract class ProfileOutputFormat +object ProfileOutputFormat { + case object Text extends ProfileOutputFormat + case object FlameGraph extends ProfileOutputFormat +} + +/** + * Profiler that measures self-time for every evaluated expression, and optionally records call + * stacks for flame graph output. + * + * Hooks into `Evaluator.visitExpr` to time each expression. The profiler is single-threaded and not + * thread-safe. 
+ */ +final class Profiler(val format: ProfileOutputFormat, wd: Path) { + + private class ExprBox(val expr: Expr) { + val id: BoxId = { + val fileName = Option(expr.pos) + .map(_.currentFile.relativeToString(wd)) + .getOrElse("") + + val name = { + val exprName = expr.getClass.getName.split('.').last.split('$').last + val detail: Option[String] = expr match { + case a: Expr.Apply => Some(Expr.callTargetName(a.value)) + case a: Expr.Apply0 => Some(Expr.callTargetName(a.value)) + case a: Expr.Apply1 => Some(Expr.callTargetName(a.value)) + case a: Expr.Apply2 => Some(Expr.callTargetName(a.value)) + case a: Expr.Apply3 => Some(Expr.callTargetName(a.value)) + case a: Expr.ApplyBuiltin => Some(a.func.functionName) + case a: Expr.ApplyBuiltin0 => Some(a.func.functionName) + case a: Expr.ApplyBuiltin1 => Some(a.func.functionName) + case a: Expr.ApplyBuiltin2 => Some(a.func.functionName) + case a: Expr.ApplyBuiltin3 => Some(a.func.functionName) + case a: Expr.ApplyBuiltin4 => Some(a.func.functionName) + case u: Expr.UnaryOp => Some(Expr.UnaryOp.name(u.op)) + case b: Expr.BinaryOp => Some(Expr.BinaryOp.name(b.op)) + case _: Expr.Comp => Some("array comprehension") + case _: Expr.ObjBody.ObjComp => Some("object comprehension") + case _ => None + } + exprName + detail.map(n => s" ($n)").getOrElse("") + } + + BoxId(name, fileName, if (expr.pos == null) 0 else expr.pos.offset) + } + + var lineCol: () => String = _ + var count: Long = 0 + var selfTimeNs: Long = 0 + + def toBox: Box = Box(id, lineCol, count, selfTimeNs) + } + + private val data = new util.IdentityHashMap[Expr, ExprBox] + private var current: ExprBox = _ + + // Flame graph state (only used when format == FlameGraph) + private val fgStack = new util.ArrayDeque[String]() + private val fgCounts = new util.HashMap[String, java.lang.Long]() + private var fgDepth: Int = 0 + + private def isFlameGraphFrame(e: Expr): Boolean = e match { + case _: Expr.Apply | _: Expr.Apply0 | _: Expr.Apply1 | _: Expr.Apply2 | _: 
Expr.Apply3 |
+ _: Expr.ApplyBuiltin | _: Expr.ApplyBuiltin0 | _: Expr.ApplyBuiltin1 |
+ _: Expr.ApplyBuiltin2 | _: Expr.ApplyBuiltin3 | _: Expr.ApplyBuiltin4 | _: Expr.Comp |
+ _: Expr.ObjBody.ObjComp =>
+ true
+ case _ => false
+ }
+
+ private def isBuiltin(e: Expr): Boolean = e match {
+ case _: Expr.ApplyBuiltin | _: Expr.ApplyBuiltin0 | _: Expr.ApplyBuiltin1 |
+ _: Expr.ApplyBuiltin2 | _: Expr.ApplyBuiltin3 | _: Expr.ApplyBuiltin4 =>
+ true
+ case _ => false
+ }
+
+ private def flameGraphFrameName(e: Expr): String = {
+ val name = e.exprErrorString
+ if (isBuiltin(e) || e.pos == null) name
+ else {
+ val file = e.pos.currentFile.relativeToString(wd)
+ s"$name ($file)"
+ }
+ }
+
+ private def getOrCreate(e: Expr): ExprBox = {
+ var box = data.get(e)
+ if (box == null) {
+ box = new ExprBox(e)
+ data.put(e, box)
+ }
+ box
+ }
+
+ /**
+ * Must be called before evaluating `e`. Returns a token for [[exit]]: the previous `current`
+ * box (the parent `ExprBox`, or null at the root) and the flamegraph stack depth saved
+ * before this call.
+ */
+ def enter(e: Expr): (AnyRef, Int) = {
+ val box = getOrCreate(e)
+ box.count += 1
+ val parent = current
+ current = box
+
+ val now = System.nanoTime()
+ box.selfTimeNs -= now
+ if (parent != null) parent.selfTimeNs += now
+
+ val prevFgDepth = fgDepth
+ if (format == ProfileOutputFormat.FlameGraph && isFlameGraphFrame(e)) {
+ fgStack.push(flameGraphFrameName(e))
+ fgDepth += 1
+ val key = foldedStack()
+ val prev = fgCounts.get(key)
+ fgCounts.put(key, if (prev == null) 1L else prev + 1L)
+ }
+
+ (parent, prevFgDepth)
+ }
+
+ /** Must be called after evaluating `e`, with the token returned by [[enter]]. 
*/ + def exit(saved: (AnyRef, Int)): Unit = { + val now = System.nanoTime() + val box = current + if (box != null) box.selfTimeNs += now + + val (parentRef, prevFgDepth) = saved + if (parentRef != null) { + val parent = parentRef.asInstanceOf[ExprBox] + current = parent + parent.selfTimeNs -= now + } else { + current = null + } + + while (fgDepth > prevFgDepth) { + fgStack.pop() + fgDepth -= 1 + } + } + + def initLineColLookup(prettyIndex: Position => Option[(Int, Int)]): Unit = { + val it = data.values().iterator() + while (it.hasNext) { + val box = it.next() + box.lineCol = () => { + Option(box.expr.pos) + .flatMap(prettyIndex) + .map { case (l, c) => s"${l + 1}:$c" } + .getOrElse("?:?") + } + } + } + + def collectResult(): ProfilingResult = { + val boxes = new scala.collection.mutable.ArrayBuffer[Box](data.size()) + val it = data.values().iterator() + while (it.hasNext) boxes += it.next().toBox + new ProfilingResult(boxes.toSeq) + } + + def writeTo(path: String, prettyIndex: Position => Option[(Int, Int)]): Unit = { + initLineColLookup(prettyIndex) + format match { + case ProfileOutputFormat.Text => + val result = collectResult() + val lines = ProfilePrinter.formatResult(result) + val w = new BufferedWriter(new FileWriter(path)) + try + lines.foreach { line => + w.write(line); w.newLine() + } + finally w.close() + + case ProfileOutputFormat.FlameGraph => + val w = new BufferedWriter(new FileWriter(path)) + try { + val it = fgCounts.entrySet().iterator() + while (it.hasNext) { + val e = it.next() + w.write(e.getKey) + w.write(' ') + w.write(e.getValue.toString) + w.newLine() + } + } finally w.close() + } + } + + private def foldedStack(): String = { + val sb = new StringBuilder + val it = fgStack.descendingIterator() + var first = true + while (it.hasNext) { + if (!first) sb.append(';') + sb.append(it.next()) + first = false + } + sb.toString + } +} + +object ProfilePrinter { + private val TopExpressionsCount = 50 + private val TopAggregatedCount = 10 + + def 
formatResult(result: ProfilingResult): Seq[String] = { + val lines = new ListBuffer[String] + + val overhead = estimateSystemNanoTimeOverhead + val boxes = result.boxes.map(box => + box.copy(selfTimeNs = math.max(0, box.selfTimeNs - box.count * overhead)) + ) + + val totalTimeNs = boxes.map(_.selfTimeNs).sum + lines += s"Total evaluation time: ${nsToSec(totalTimeNs)}" + + val evaluations = boxes.map(_.count).sum + lines += s"Unique expressions: ${result.boxes.size}, expression evaluations: $evaluations" + lines += s"Times adjusted based on estimated System.nanoTime overhead: $overhead ns" + + lines += "" + lines += s"Top $TopExpressionsCount by self time:" + lines += "" + + lines += formatRow("count", "self", "%", "cumul", "%", "location") + lines ++= format(boxes, totalTimeNs, agg = None, TopExpressionsCount) + + lines += "" + lines += s"Top $TopAggregatedCount aggregated by file name:" + lines += "" + + lines += formatRow("count", "self", "%", "cumul", "%", "file") + lines ++= format(boxes, totalTimeNs, agg = Some(box => box.id.fileName), TopAggregatedCount) + + lines += "" + lines += s"Top $TopAggregatedCount aggregated by expression:" + lines += "" + + lines += formatRow("count", "self", "%", "cumul", "%", "expression") + lines ++= format(boxes, totalTimeNs, agg = Some(box => box.id.name), TopAggregatedCount) + + lines.toList + } + + private def format( + boxes: Seq[Box], + totalTimeNs: Long, + agg: Option[Box => String], + top: Int): Seq[String] = { + case class Stats(name: () => String, count: Long, selfTimeNs: Long) + + val aggregated = agg match { + case Some(agg) => + boxes + .groupBy(agg) + .map { case (aggregatedBy, boxes) => + Stats( + name = () => aggregatedBy, + count = boxes.map(_.count).sum, + selfTimeNs = boxes.map(_.selfTimeNs).sum + ) + } + + case _ => + boxes.map { box => + Stats( + () => f"${box.id.fileName}:${box.lineCol()} ${box.id.name}", + box.count, + box.selfTimeNs + ) + } + } + + val totalNs = math.max(1, totalTimeNs) + var sum = 0L + + 
aggregated.toSeq + .sortBy(-_.selfTimeNs) + .take(top) + .map { stats => + sum = sum + stats.selfTimeNs + formatRow( + count = stats.count.toString, + selfTime = nsToSec(stats.selfTimeNs), + selfPerc = f"${stats.selfTimeNs * 100.0 / totalNs}%.1f", + selfSum = nsToSec(sum), + sumPerc = f"${sum * 100.0 / totalNs}%.1f", + name = stats.name() + ) + } + } + + private def estimateSystemNanoTimeOverhead: Long = { + val samples: Array[Long] = new Array[Long](1000 * 1000) + for (i <- samples.indices) { + val start = System.nanoTime() + samples(i) = System.nanoTime() - start + } + samples.sorted.apply(samples.length / 2) + } + + private def formatRow( + count: String, + selfTime: String, + selfPerc: String, + selfSum: String, + sumPerc: String, + name: String): String = + f"$count%12s $selfTime%8s $selfPerc%4s $selfSum%8s $sumPerc%4s $name" + + private def nsToSec(ns: Long): String = { + val sec = ns / 1000000000.0 + f"$sec%.2fs" + } +}