diff --git a/codeflash-java-runtime/pom.xml b/codeflash-java-runtime/pom.xml index cb95732dd..a309a6c0d 100644 --- a/codeflash-java-runtime/pom.xml +++ b/codeflash-java-runtime/pom.xml @@ -48,6 +48,18 @@ 3.45.0.0 + + + org.ow2.asm + asm + 9.7.1 + + + org.ow2.asm + asm-commons + 9.7.1 + + org.junit.jupiter @@ -100,9 +112,19 @@ shade + + + org.objectweb.asm + com.codeflash.asm + + com.codeflash.Comparator + + com.codeflash.profiler.ProfilerAgent + true + diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingClassVisitor.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingClassVisitor.java new file mode 100644 index 000000000..a2473ed97 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingClassVisitor.java @@ -0,0 +1,41 @@ +package com.codeflash.profiler; + +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; + +/** + * ASM ClassVisitor that filters methods and wraps target methods with + * {@link LineProfilingMethodVisitor} for line-level profiling. 
+ */ +public class LineProfilingClassVisitor extends ClassVisitor { + + private final String internalClassName; + private final ProfilerConfig config; + private String sourceFile; + + public LineProfilingClassVisitor(ClassVisitor classVisitor, String internalClassName, ProfilerConfig config) { + super(Opcodes.ASM9, classVisitor); + this.internalClassName = internalClassName; + this.config = config; + } + + @Override + public void visitSource(String source, String debug) { + super.visitSource(source, debug); + // Resolve the absolute source file path from the config + this.sourceFile = config.resolveSourceFile(internalClassName); + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, + String signature, String[] exceptions) { + MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions); + + if (config.shouldInstrumentMethod(internalClassName, name)) { + return new LineProfilingMethodVisitor(mv, access, name, descriptor, + internalClassName, sourceFile); + } + return mv; + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingMethodVisitor.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingMethodVisitor.java new file mode 100644 index 000000000..c7cd580d2 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingMethodVisitor.java @@ -0,0 +1,154 @@ +package com.codeflash.profiler; + +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; +import org.objectweb.asm.commons.AdviceAdapter; + +/** + * ASM MethodVisitor that injects line-level profiling probes. + * + *

At each {@code LineNumber} table entry within the target method: + *

    + *
  1. Registers the line with {@link ProfilerRegistry} (happens once at class-load time)
  2. + *
  3. Injects bytecode: {@code LDC globalId; INVOKESTATIC ProfilerData.hit(I)V}
  4. + *
+ * + *

At method entry: injects a warmup self-call loop (if warmup is configured) followed by + * {@code ProfilerData.enterMethod(entryLineId)}. + *

At method exit (every RETURN/ATHROW): injects {@code ProfilerData.exitMethod()}. + */ +public class LineProfilingMethodVisitor extends AdviceAdapter { + + private static final String PROFILER_DATA = "com/codeflash/profiler/ProfilerData"; + + private final String internalClassName; + private final String sourceFile; + private final String methodName; + private boolean firstLineVisited = false; + + protected LineProfilingMethodVisitor( + MethodVisitor mv, int access, String name, String descriptor, + String internalClassName, String sourceFile) { + super(Opcodes.ASM9, mv, access, name, descriptor); + this.internalClassName = internalClassName; + this.sourceFile = sourceFile; + this.methodName = name; + } + + /** + * Inject a warmup self-call loop at method entry. + * + *

Generated bytecode equivalent: + *

+     * if (ProfilerData.isWarmupNeeded()) {
+     *     ProfilerData.startWarmup();
+     *     for (int i = 0; i < ProfilerData.getWarmupThreshold(); i++) {
+     *         thisMethod(originalArgs);
+     *     }
+     *     ProfilerData.finishWarmup();
+     * }
+     * 
+ * + *

Recursive warmup calls re-enter this method but {@code isWarmupNeeded()} returns + * {@code false} (guard flag set by {@code startWarmup()}), so they execute the normal + * instrumented body. After the loop, {@code finishWarmup()} zeros all counters so the + * next real execution records clean data. + */ + @Override + protected void onMethodEnter() { + Label skipWarmup = new Label(); + + // if (!ProfilerData.isWarmupNeeded()) goto skipWarmup + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "isWarmupNeeded", "()Z", false); + mv.visitJumpInsn(IFEQ, skipWarmup); + + // ProfilerData.startWarmup() + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "startWarmup", "()V", false); + + // int _warmupIdx = 0 + int counterLocal = newLocal(Type.INT_TYPE); + mv.visitInsn(ICONST_0); + mv.visitVarInsn(ISTORE, counterLocal); + + Label loopCheck = new Label(); + Label loopBody = new Label(); + + mv.visitJumpInsn(GOTO, loopCheck); + + // loop body: call self with original arguments + mv.visitLabel(loopBody); + + boolean isStatic = (methodAccess & Opcodes.ACC_STATIC) != 0; + if (!isStatic) { + loadThis(); + } + loadArgs(); + + int invokeOp; + if (isStatic) { + invokeOp = INVOKESTATIC; + } else if ((methodAccess & Opcodes.ACC_PRIVATE) != 0) { + invokeOp = INVOKESPECIAL; + } else { + invokeOp = INVOKEVIRTUAL; + } + mv.visitMethodInsn(invokeOp, internalClassName, methodName, methodDesc, false); + + // Discard return value + Type returnType = Type.getReturnType(methodDesc); + switch (returnType.getSort()) { + case Type.VOID: + break; + case Type.LONG: + case Type.DOUBLE: + mv.visitInsn(POP2); + break; + default: + mv.visitInsn(POP); + break; + } + + // _warmupIdx++ + mv.visitIincInsn(counterLocal, 1); + + // loop check: _warmupIdx < ProfilerData.getWarmupThreshold() + mv.visitLabel(loopCheck); + mv.visitVarInsn(ILOAD, counterLocal); + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "getWarmupThreshold", "()I", false); + mv.visitJumpInsn(IF_ICMPLT, loopBody); + + // 
ProfilerData.finishWarmup() + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "finishWarmup", "()V", false); + + mv.visitLabel(skipWarmup); + } + + @Override + public void visitLineNumber(int line, Label start) { + super.visitLineNumber(line, start); + + // Register this line and get its global ID (happens once at class-load time) + String dotClassName = internalClassName.replace('/', '.'); + int globalId = ProfilerRegistry.register(sourceFile, dotClassName, methodName, line); + + if (!firstLineVisited) { + firstLineVisited = true; + // Inject enterMethod call at the first line of the method + mv.visitLdcInsn(globalId); + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "enterMethod", "(I)V", false); + } + + // Inject: ProfilerData.hit(globalId) + mv.visitLdcInsn(globalId); + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "hit", "(I)V", false); + } + + @Override + protected void onMethodExit(int opcode) { + // Before every RETURN or ATHROW, flush timing for the last line + // This fixes the "last line always shows 0ms" bug + mv.visitMethodInsn(INVOKESTATIC, PROFILER_DATA, "exitMethod", "()V", false); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingTransformer.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingTransformer.java new file mode 100644 index 000000000..39fbe9d97 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/LineProfilingTransformer.java @@ -0,0 +1,46 @@ +package com.codeflash.profiler; + +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassWriter; + +import java.lang.instrument.ClassFileTransformer; +import java.security.ProtectionDomain; + +/** + * {@link ClassFileTransformer} that instruments target classes with line profiling. + * + *

When a class matches the profiler configuration, it is run through ASM + * to inject {@link ProfilerData#hit(int)} calls at each line number. + */ +public class LineProfilingTransformer implements ClassFileTransformer { + + private final ProfilerConfig config; + + public LineProfilingTransformer(ProfilerConfig config) { + this.config = config; + } + + @Override + public byte[] transform(ClassLoader loader, String className, + Class classBeingRedefined, ProtectionDomain protectionDomain, + byte[] classfileBuffer) { + if (className == null || !config.shouldInstrumentClass(className)) { + return null; // null = don't transform + } + + try { + return instrumentClass(className, classfileBuffer); + } catch (Exception e) { + System.err.println("[codeflash-profiler] Failed to instrument " + className + ": " + e.getMessage()); + return null; + } + } + + private byte[] instrumentClass(String internalClassName, byte[] bytecode) { + ClassReader cr = new ClassReader(bytecode); + ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + LineProfilingClassVisitor cv = new LineProfilingClassVisitor(cw, internalClassName, config); + cr.accept(cv, ClassReader.EXPAND_FRAMES); + return cw.toByteArray(); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerAgent.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerAgent.java new file mode 100644 index 000000000..572803f78 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerAgent.java @@ -0,0 +1,53 @@ +package com.codeflash.profiler; + +import java.lang.instrument.Instrumentation; + +/** + * Java agent entry point for the CodeFlash line profiler. + * + *

Loaded via {@code -javaagent:codeflash-profiler-agent.jar=config=/path/to/config.json}. + * + *

The agent: + *

    + *
  1. Parses the config file specifying which classes/methods to profile
  2. + *
  3. Registers a {@link LineProfilingTransformer} to instrument target classes at load time
  4. + *
  5. Registers a shutdown hook to write profiling results to JSON
  6. + *
+ */ +public class ProfilerAgent { + + /** + * Called by the JVM before {@code main()} when the agent is loaded. + * + * @param agentArgs comma-separated key=value pairs (e.g., {@code config=/path/to/config.json}) + * @param inst the JVM instrumentation interface + */ + public static void premain(String agentArgs, Instrumentation inst) { + ProfilerConfig config = ProfilerConfig.parse(agentArgs); + + if (config.getTargetClasses().isEmpty()) { + System.err.println("[codeflash-profiler] No target classes configured, profiler inactive"); + return; + } + + // Pre-allocate registry with estimated capacity + ProfilerRegistry.initialize(config.getExpectedLineCount()); + + // Configure warmup phase + ProfilerData.setWarmupThreshold(config.getWarmupIterations()); + + // Register the bytecode transformer + inst.addTransformer(new LineProfilingTransformer(config), true); + + // Register shutdown hook to write results on JVM exit + String outputFile = config.getOutputFile(); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + ProfilerReporter.writeResults(outputFile, config); + }, "codeflash-profiler-shutdown")); + + int warmup = config.getWarmupIterations(); + String warmupMsg = warmup > 0 ? 
", warmup=" + warmup + " calls" : ""; + System.err.println("[codeflash-profiler] Agent loaded, profiling " + + config.getTargetClasses().size() + " class(es)" + warmupMsg); + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerConfig.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerConfig.java new file mode 100644 index 000000000..6846b7945 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerConfig.java @@ -0,0 +1,424 @@ +package com.codeflash.profiler; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Configuration for the profiler agent, parsed from a JSON file. + * + *

The JSON is generated by Python ({@code JavaLineProfiler.generate_agent_config()}). + * Uses a hand-rolled JSON parser to avoid external dependencies (keeps the agent JAR small). + */ +public final class ProfilerConfig { + + private String outputFile = ""; + private int warmupIterations = 10; + private final Map> targets = new HashMap<>(); + private final Map lineContents = new HashMap<>(); + private final Set targetClassNames = new HashSet<>(); + + public static class MethodTarget { + public final String name; + public final int startLine; + public final int endLine; + public final String sourceFile; + + public MethodTarget(String name, int startLine, int endLine, String sourceFile) { + this.name = name; + this.startLine = startLine; + this.endLine = endLine; + this.sourceFile = sourceFile; + } + } + + /** + * Parse agent arguments and load the config file. + * + *

Expected format: {@code config=/path/to/config.json} + */ + public static ProfilerConfig parse(String agentArgs) { + ProfilerConfig config = new ProfilerConfig(); + if (agentArgs == null || agentArgs.isEmpty()) { + return config; + } + + String configPath = null; + for (String part : agentArgs.split(",")) { + String trimmed = part.trim(); + if (trimmed.startsWith("config=")) { + configPath = trimmed.substring("config=".length()); + } + } + + if (configPath == null) { + System.err.println("[codeflash-profiler] No config= in agent args: " + agentArgs); + return config; + } + + try { + String json = new String(Files.readAllBytes(Paths.get(configPath)), StandardCharsets.UTF_8); + config.parseJson(json); + } catch (IOException e) { + System.err.println("[codeflash-profiler] Failed to read config: " + e.getMessage()); + } + + return config; + } + + public String getOutputFile() { + return outputFile; + } + + public int getWarmupIterations() { + return warmupIterations; + } + + public Set getTargetClasses() { + return Collections.unmodifiableSet(targetClassNames); + } + + public List getMethodsForClass(String internalClassName) { + return targets.getOrDefault(internalClassName, Collections.emptyList()); + } + + public Map getLineContents() { + return Collections.unmodifiableMap(lineContents); + } + + public int getExpectedLineCount() { + int count = 0; + for (List methods : targets.values()) { + for (MethodTarget m : methods) { + count += Math.max(m.endLine - m.startLine + 1, 1); + } + } + return Math.max(count, 256); + } + + /** + * Check if a class should be instrumented. Uses JVM internal names (slash-separated). + */ + public boolean shouldInstrumentClass(String internalClassName) { + return targetClassNames.contains(internalClassName); + } + + /** + * Check if a specific method in a class should be instrumented. 
+ */ + public boolean shouldInstrumentMethod(String internalClassName, String methodName) { + List methods = targets.get(internalClassName); + if (methods == null) return false; + for (MethodTarget m : methods) { + if (m.name.equals(methodName)) { + return true; + } + } + return false; + } + + /** + * Resolve the absolute source file path for a given class and its source file attribute. + */ + public String resolveSourceFile(String internalClassName) { + List methods = targets.get(internalClassName); + if (methods != null && !methods.isEmpty()) { + return methods.get(0).sourceFile; + } + return internalClassName.replace('/', '.') + ".java"; + } + + // ---- Minimal JSON parser ---- + + private void parseJson(String json) { + json = json.trim(); + if (!json.startsWith("{") || !json.endsWith("}")) return; + + int[] pos = {1}; // mutable position cursor + skipWhitespace(json, pos); + + while (pos[0] < json.length() - 1) { + String key = readString(json, pos); + skipWhitespace(json, pos); + expect(json, pos, ':'); + skipWhitespace(json, pos); + + switch (key) { + case "outputFile": + this.outputFile = readString(json, pos); + break; + case "warmupIterations": + this.warmupIterations = readInt(json, pos); + break; + case "targets": + parseTargets(json, pos); + break; + case "lineContents": + parseLineContents(json, pos); + break; + default: + skipValue(json, pos); + break; + } + + skipWhitespace(json, pos); + if (pos[0] < json.length() && json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + } + + private void parseTargets(String json, int[] pos) { + expect(json, pos, '['); + skipWhitespace(json, pos); + + while (pos[0] < json.length() && json.charAt(pos[0]) != ']') { + parseTargetObject(json, pos); + skipWhitespace(json, pos); + if (pos[0] < json.length() && json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + pos[0]++; // skip ']' + } + + private void parseTargetObject(String json, int[] pos) { + expect(json, pos, 
'{'); + skipWhitespace(json, pos); + + String className = ""; + List methods = new ArrayList<>(); + + while (pos[0] < json.length() && json.charAt(pos[0]) != '}') { + String key = readString(json, pos); + skipWhitespace(json, pos); + expect(json, pos, ':'); + skipWhitespace(json, pos); + + switch (key) { + case "className": + className = readString(json, pos); + break; + case "methods": + methods = parseMethodsArray(json, pos); + break; + default: + skipValue(json, pos); + break; + } + + skipWhitespace(json, pos); + if (pos[0] < json.length() && json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + pos[0]++; // skip '}' + + if (!className.isEmpty()) { + targets.put(className, methods); + targetClassNames.add(className); + } + } + + private List parseMethodsArray(String json, int[] pos) { + List methods = new ArrayList<>(); + expect(json, pos, '['); + skipWhitespace(json, pos); + + while (pos[0] < json.length() && json.charAt(pos[0]) != ']') { + methods.add(parseMethodTarget(json, pos)); + skipWhitespace(json, pos); + if (pos[0] < json.length() && json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + pos[0]++; // skip ']' + return methods; + } + + private MethodTarget parseMethodTarget(String json, int[] pos) { + expect(json, pos, '{'); + skipWhitespace(json, pos); + + String name = ""; + int startLine = 0; + int endLine = 0; + String sourceFile = ""; + + while (pos[0] < json.length() && json.charAt(pos[0]) != '}') { + String key = readString(json, pos); + skipWhitespace(json, pos); + expect(json, pos, ':'); + skipWhitespace(json, pos); + + switch (key) { + case "name": + name = readString(json, pos); + break; + case "startLine": + startLine = readInt(json, pos); + break; + case "endLine": + endLine = readInt(json, pos); + break; + case "sourceFile": + sourceFile = readString(json, pos); + break; + default: + skipValue(json, pos); + break; + } + + skipWhitespace(json, pos); + if (pos[0] < json.length() && 
json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + pos[0]++; // skip '}' + + return new MethodTarget(name, startLine, endLine, sourceFile); + } + + private void parseLineContents(String json, int[] pos) { + expect(json, pos, '{'); + skipWhitespace(json, pos); + + while (pos[0] < json.length() && json.charAt(pos[0]) != '}') { + String key = readString(json, pos); + skipWhitespace(json, pos); + expect(json, pos, ':'); + skipWhitespace(json, pos); + String value = readString(json, pos); + lineContents.put(key, value); + + skipWhitespace(json, pos); + if (pos[0] < json.length() && json.charAt(pos[0]) == ',') { + pos[0]++; + skipWhitespace(json, pos); + } + } + pos[0]++; // skip '}' + } + + private static String readString(String json, int[] pos) { + if (pos[0] >= json.length() || json.charAt(pos[0]) != '"') return ""; + pos[0]++; // skip opening quote + + StringBuilder sb = new StringBuilder(); + while (pos[0] < json.length()) { + char c = json.charAt(pos[0]); + if (c == '\\' && pos[0] + 1 < json.length()) { + pos[0]++; + char escaped = json.charAt(pos[0]); + switch (escaped) { + case '"': sb.append('"'); break; + case '\\': sb.append('\\'); break; + case '/': sb.append('/'); break; + case 'n': sb.append('\n'); break; + case 't': sb.append('\t'); break; + case 'r': sb.append('\r'); break; + default: sb.append('\\').append(escaped); break; + } + } else if (c == '"') { + pos[0]++; // skip closing quote + return sb.toString(); + } else { + sb.append(c); + } + pos[0]++; + } + return sb.toString(); + } + + private static int readInt(String json, int[] pos) { + int start = pos[0]; + boolean negative = false; + if (pos[0] < json.length() && json.charAt(pos[0]) == '-') { + negative = true; + pos[0]++; + } + while (pos[0] < json.length() && Character.isDigit(json.charAt(pos[0]))) { + pos[0]++; + } + String numStr = json.substring(start, pos[0]); + try { + return Integer.parseInt(numStr); + } catch (NumberFormatException e) { + return 0; + } + } + + 
private static void skipValue(String json, int[] pos) { + if (pos[0] >= json.length()) return; + char c = json.charAt(pos[0]); + if (c == '"') { + readString(json, pos); + } else if (c == '{') { + skipBraced(json, pos, '{', '}'); + } else if (c == '[') { + skipBraced(json, pos, '[', ']'); + } else if (c == 'n' && json.startsWith("null", pos[0])) { + pos[0] += 4; + } else if (c == 't' && json.startsWith("true", pos[0])) { + pos[0] += 4; + } else if (c == 'f' && json.startsWith("false", pos[0])) { + pos[0] += 5; + } else { + // number + while (pos[0] < json.length() && "0123456789.eE+-".indexOf(json.charAt(pos[0])) >= 0) { + pos[0]++; + } + } + } + + private static void skipBraced(String json, int[] pos, char open, char close) { + int depth = 0; + boolean inString = false; + while (pos[0] < json.length()) { + char c = json.charAt(pos[0]); + if (inString) { + if (c == '\\') { + pos[0]++; // skip escaped char + } else if (c == '"') { + inString = false; + } + } else { + if (c == '"') inString = true; + else if (c == open) depth++; + else if (c == close) { + depth--; + if (depth == 0) { + pos[0]++; + return; + } + } + } + pos[0]++; + } + } + + private static void skipWhitespace(String json, int[] pos) { + while (pos[0] < json.length() && Character.isWhitespace(json.charAt(pos[0]))) { + pos[0]++; + } + } + + private static void expect(String json, int[] pos, char expected) { + if (pos[0] < json.length() && json.charAt(pos[0]) == expected) { + pos[0]++; + } + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerData.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerData.java new file mode 100644 index 000000000..03363a7be --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerData.java @@ -0,0 +1,273 @@ +package com.codeflash.profiler; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +/** + * Zero-allocation, zero-contention 
per-line profiling data storage. + * + *

Each thread gets its own primitive {@code long[]} arrays for hit counts and self-time. + * The hot path ({@link #hit(int)}) performs only an array-index increment and a single + * {@link System#nanoTime()} call — no object allocations, no locks, no shared-state contention. + * + *

A per-thread call stack tracks method entry/exit to: + *

    + *
  • Attribute time to the last line of a function (fixes the "last line 0ms" bug)
  • + *
  • Pause parent-line timing during callee execution (fixes cross-function timing)
  • + *
  • Handle recursion correctly (each stack frame is independent)
  • + *
+ */ +public final class ProfilerData { + + private static final int INITIAL_CAPACITY = 4096; + private static final int MAX_CALL_DEPTH = 256; + + // Thread-local arrays — each thread gets its own, no contention + private static final ThreadLocal hitCounts = + ThreadLocal.withInitial(() -> registerArray(new long[INITIAL_CAPACITY])); + private static final ThreadLocal selfTimeNs = + ThreadLocal.withInitial(() -> registerTimeArray(new long[INITIAL_CAPACITY])); + + // Per-thread "last line" tracking for time attribution + // Using int[1] and long[1] to avoid boxing + private static final ThreadLocal lastLineId = + ThreadLocal.withInitial(() -> new int[]{-1}); + private static final ThreadLocal lastLineTime = + ThreadLocal.withInitial(() -> new long[]{0L}); + + // Per-thread call stack for method entry/exit + private static final ThreadLocal callStackLineIds = + ThreadLocal.withInitial(() -> new int[MAX_CALL_DEPTH]); + private static final ThreadLocal callStackDepth = + ThreadLocal.withInitial(() -> new int[]{0}); + + // Global references to all thread-local arrays for harvesting at shutdown + private static final List allHitArrays = new CopyOnWriteArrayList<>(); + private static final List allTimeArrays = new CopyOnWriteArrayList<>(); + + // Warmup state: the method visitor injects a self-calling warmup loop, + // warmupInProgress guards against recursive re-entry into the warmup block. + private static volatile int warmupThreshold = 0; + private static volatile boolean warmupComplete = false; + private static volatile boolean warmupInProgress = false; + + private ProfilerData() {} + + private static long[] registerArray(long[] arr) { + allHitArrays.add(arr); + return arr; + } + + private static long[] registerTimeArray(long[] arr) { + allTimeArrays.add(arr); + return arr; + } + + /** + * Set the number of self-call warmup iterations before measurement begins. + * Called once from {@link ProfilerAgent#premain} before any classes are loaded. 
+ * + * @param threshold number of warmup iterations (0 = no warmup) + */ + public static void setWarmupThreshold(int threshold) { + warmupThreshold = threshold; + warmupComplete = (threshold <= 0); + } + + /** + * Check whether warmup is still needed. Called by injected bytecode at target method entry. + * Returns {@code true} only on the very first call — subsequent calls (including recursive + * warmup calls) return {@code false}. + */ + public static boolean isWarmupNeeded() { + return !warmupComplete && !warmupInProgress && warmupThreshold > 0; + } + + /** + * Enter warmup phase. Sets a guard flag so recursive warmup calls skip the warmup block. + */ + public static void startWarmup() { + warmupInProgress = true; + } + + /** + * Return the configured warmup iteration count. + */ + public static int getWarmupThreshold() { + return warmupThreshold; + } + + /** + * End warmup: zero all profiling counters, mark warmup complete, clear the guard flag. + * The next execution of the method body is the clean measurement. + */ + public static void finishWarmup() { + resetAll(); + warmupComplete = true; + warmupInProgress = false; + System.err.println("[codeflash-profiler] Warmup complete after " + warmupThreshold + + " iterations, measurement started"); + } + + /** + * Reset all profiling counters across all threads. + * Called once when warmup phase completes to discard warmup data. + */ + private static void resetAll() { + for (long[] arr : allHitArrays) { + Arrays.fill(arr, 0L); + } + for (long[] arr : allTimeArrays) { + Arrays.fill(arr, 0L); + } + } + + /** + * Record a hit on a profiled line. This is the HOT PATH. + * + *

Called at every instrumented line number. Must not allocate after the initial
+     * thread-local array expansion.
+     *
+     * @param globalId the line's registered ID from {@link ProfilerRegistry}
+     */
+    public static void hit(int globalId) {
+        // One timestamp per hit: it closes the previous line's interval and opens this line's.
+        long now = System.nanoTime();
+
+        // Count the hit, growing this thread's array first if the ID is beyond its current size.
+        long[] hits = hitCounts.get();
+        if (globalId >= hits.length) {
+            hits = ensureCapacity(hitCounts, allHitArrays, globalId);
+        }
+        hits[globalId]++;
+
+        // Attribute elapsed time to the PREVIOUS line (the one that was executing)
+        int[] lastId = lastLineId.get();
+        long[] lastTime = lastLineTime.get();
+        // A negative last ID means no line is currently being timed on this thread.
+        if (lastId[0] >= 0) {
+            long[] times = selfTimeNs.get();
+            if (lastId[0] >= times.length) {
+                times = ensureCapacity(selfTimeNs, allTimeArrays, lastId[0]);
+            }
+            times[lastId[0]] += now - lastTime[0];
+        }
+
+        // This line becomes the one being timed from 'now' on.
+        lastId[0] = globalId;
+        lastTime[0] = now;
+    }
+
+    /**
+     * Called at method entry to push a call-stack frame.
+     *
+     *

Attributes any pending time to the previous line (the call site), then + * saves the caller's line state onto the stack so it can be restored in + * {@link #exitMethod()}. + * + * @param entryLineId the globalId of the first line in the entering method (unused for stack, + * but may be used for future total-time tracking) + */ + public static void enterMethod(int entryLineId) { + long now = System.nanoTime(); + + // Flush pending time to the line that made the call + int[] lastId = lastLineId.get(); + long[] lastTime = lastLineTime.get(); + if (lastId[0] >= 0) { + long[] times = selfTimeNs.get(); + if (lastId[0] >= times.length) { + times = ensureCapacity(selfTimeNs, allTimeArrays, lastId[0]); + } + times[lastId[0]] += now - lastTime[0]; + } + + // Push caller's line ID onto the stack + int[] depth = callStackDepth.get(); + int[] stack = callStackLineIds.get(); + if (depth[0] < stack.length) { + stack[depth[0]] = lastId[0]; + } + depth[0]++; + + // Reset for the new method scope + lastId[0] = -1; + lastTime[0] = now; + } + + /** + * Called at method exit (before RETURN or ATHROW) to pop the call stack. + * + *

Attributes remaining time to the last line of the exiting method (fixes the + * "last line always 0ms" bug), then restores the caller's timing state. + */ + public static void exitMethod() { + long now = System.nanoTime(); + + // Attribute remaining time to the last line of the exiting method + int[] lastId = lastLineId.get(); + long[] lastTime = lastLineTime.get(); + if (lastId[0] >= 0) { + long[] times = selfTimeNs.get(); + if (lastId[0] >= times.length) { + times = ensureCapacity(selfTimeNs, allTimeArrays, lastId[0]); + } + times[lastId[0]] += now - lastTime[0]; + } + + // Pop the call stack and restore parent's timing state + int[] depth = callStackDepth.get(); + if (depth[0] > 0) { + depth[0]--; + int[] stack = callStackLineIds.get(); + int parentLineId = stack[depth[0]]; + + lastId[0] = parentLineId; + lastTime[0] = now; // Self-time: exclude callee duration + } else { + lastId[0] = -1; + lastTime[0] = 0L; + } + } + + /** + * Sum hit counts across all threads. Called once at shutdown for reporting. + */ + public static long[] getGlobalHitCounts() { + int maxId = ProfilerRegistry.getMaxId(); + long[] global = new long[maxId]; + for (long[] threadHits : allHitArrays) { + int limit = Math.min(threadHits.length, maxId); + for (int i = 0; i < limit; i++) { + global[i] += threadHits[i]; + } + } + return global; + } + + /** + * Sum self-time across all threads. Called once at shutdown for reporting. 
+ */ + public static long[] getGlobalSelfTimeNs() { + int maxId = ProfilerRegistry.getMaxId(); + long[] global = new long[maxId]; + for (long[] threadTimes : allTimeArrays) { + int limit = Math.min(threadTimes.length, maxId); + for (int i = 0; i < limit; i++) { + global[i] += threadTimes[i]; + } + } + return global; + } + + private static long[] ensureCapacity(ThreadLocal tl, List registry, int minIndex) { + long[] old = tl.get(); + int newSize = Math.max((minIndex + 1) * 2, INITIAL_CAPACITY); + long[] expanded = new long[newSize]; + System.arraycopy(old, 0, expanded, 0, old.length); + + // Update the registry: remove old, add new + registry.remove(old); + registry.add(expanded); + + tl.set(expanded); + return expanded; + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerRegistry.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerRegistry.java new file mode 100644 index 000000000..f4e4f3b22 --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerRegistry.java @@ -0,0 +1,120 @@ +package com.codeflash.profiler; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Maps (sourceFile, lineNumber) pairs to compact integer IDs at class-load time. + * + *

Registration happens once per unique line during class transformation (not on the hot path). + * The integer IDs are used as direct array indices in {@link ProfilerData} for zero-allocation + * hit recording at runtime. + */ +public final class ProfilerRegistry { + + private static final AtomicInteger nextId = new AtomicInteger(0); + private static final ConcurrentHashMap lineToId = new ConcurrentHashMap<>(); + + private static volatile String[] idToFile; + private static volatile int[] idToLine; + private static volatile String[] idToClassName; + private static volatile String[] idToMethodName; + + private static int capacity; + private static final Object growLock = new Object(); + + private ProfilerRegistry() {} + + /** + * Pre-allocate reverse-lookup arrays with the given capacity. + * Called once from {@link ProfilerAgent#premain} before any classes are loaded. + */ + public static void initialize(int expectedLines) { + capacity = Math.max(expectedLines * 2, 4096); + idToFile = new String[capacity]; + idToLine = new int[capacity]; + idToClassName = new String[capacity]; + idToMethodName = new String[capacity]; + } + + /** + * Register a source line and return its global ID. + * + *

Thread-safe. Called during class loading by the ASM visitor. If the same + * (className, lineNumber) pair has already been registered, returns the existing ID. + * + * @param sourceFile absolute path of the source file + * @param className dot-separated class name (e.g. "com.example.Calculator") + * @param methodName method name + * @param lineNumber 1-indexed line number in the source file + * @return compact integer ID usable as an array index + */ + public static int register(String sourceFile, String className, String methodName, int lineNumber) { + // Pack className hash + lineNumber into a 64-bit key for fast lookup + long key = ((long) className.hashCode() << 32) | (lineNumber & 0xFFFFFFFFL); + Integer existing = lineToId.get(key); + if (existing != null) { + return existing; + } + + int id = nextId.getAndIncrement(); + if (id >= capacity) { + grow(id + 1); + } + + Integer winner = lineToId.putIfAbsent(key, id); + if (winner != null) { + // Another thread registered first — use its ID + return winner; + } + + idToFile[id] = sourceFile; + idToLine[id] = lineNumber; + idToClassName[id] = className; + idToMethodName[id] = methodName; + return id; + } + + private static void grow(int minCapacity) { + synchronized (growLock) { + if (minCapacity <= capacity) return; + + int newCapacity = Math.max(minCapacity * 2, capacity * 2); + String[] newFiles = new String[newCapacity]; + int[] newLines = new int[newCapacity]; + String[] newClasses = new String[newCapacity]; + String[] newMethods = new String[newCapacity]; + + System.arraycopy(idToFile, 0, newFiles, 0, capacity); + System.arraycopy(idToLine, 0, newLines, 0, capacity); + System.arraycopy(idToClassName, 0, newClasses, 0, capacity); + System.arraycopy(idToMethodName, 0, newMethods, 0, capacity); + + idToFile = newFiles; + idToLine = newLines; + idToClassName = newClasses; + idToMethodName = newMethods; + capacity = newCapacity; + } + } + + public static int getMaxId() { + return nextId.get(); + } + + public 
static String getFile(int id) { + return idToFile[id]; + } + + public static int getLine(int id) { + return idToLine[id]; + } + + public static String getClassName(int id) { + return idToClassName[id]; + } + + public static String getMethodName(int id) { + return idToMethodName[id]; + } +} diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerReporter.java b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerReporter.java new file mode 100644 index 000000000..71f05c34c --- /dev/null +++ b/codeflash-java-runtime/src/main/java/com/codeflash/profiler/ProfilerReporter.java @@ -0,0 +1,92 @@ +package com.codeflash.profiler; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; + +/** + * Writes profiling results to a JSON file in the same format as the old source-injected profiler. + * + *

Output format (consumed by {@code JavaLineProfiler.parse_results()} in Python): + *

+ * {
+ *   "/path/to/File.java:10": {
+ *     "hits": 100,
+ *     "time": 5000000,
+ *     "file": "/path/to/File.java",
+ *     "line": 10,
+ *     "content": "int x = compute();"
+ *   },
+ *   ...
+ * }
+ * 
+ */ +public final class ProfilerReporter { + + private ProfilerReporter() {} + + /** + * Write profiling results to the output file. Called once from a JVM shutdown hook. + */ + public static void writeResults(String outputFile, ProfilerConfig config) { + if (outputFile == null || outputFile.isEmpty()) return; + + long[] globalHits = ProfilerData.getGlobalHitCounts(); + long[] globalTimes = ProfilerData.getGlobalSelfTimeNs(); + int maxId = ProfilerRegistry.getMaxId(); + Map lineContents = config.getLineContents(); + + StringBuilder json = new StringBuilder(Math.max(maxId * 128, 256)); + json.append("{\n"); + + boolean first = true; + for (int id = 0; id < maxId; id++) { + long hits = (id < globalHits.length) ? globalHits[id] : 0; + long timeNs = (id < globalTimes.length) ? globalTimes[id] : 0; + if (hits == 0 && timeNs == 0) continue; + + String file = ProfilerRegistry.getFile(id); + int line = ProfilerRegistry.getLine(id); + if (file == null) continue; + + String key = file + ":" + line; + String content = lineContents.getOrDefault(key, ""); + + if (!first) json.append(",\n"); + first = false; + + json.append(" \"").append(escapeJson(key)).append("\": {\n"); + json.append(" \"hits\": ").append(hits).append(",\n"); + json.append(" \"time\": ").append(timeNs).append(",\n"); + json.append(" \"file\": \"").append(escapeJson(file)).append("\",\n"); + json.append(" \"line\": ").append(line).append(",\n"); + json.append(" \"content\": \"").append(escapeJson(content)).append("\"\n"); + json.append(" }"); + } + + json.append("\n}"); + + try { + Path path = Paths.get(outputFile); + Path parent = path.getParent(); + if (parent != null) { + Files.createDirectories(parent); + } + Files.write(path, json.toString().getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + System.err.println("[codeflash-profiler] Failed to write results: " + e.getMessage()); + } + } + + private static String escapeJson(String s) { + if (s == null) return ""; + return s.replace("\\", 
"\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/codeflash/code_utils/tabulate.py b/codeflash/code_utils/tabulate.py index 1024afc4b..0d5004f6c 100644 --- a/codeflash/code_utils/tabulate.py +++ b/codeflash/code_utils/tabulate.py @@ -705,7 +705,7 @@ def tabulate( # format rows and columns, convert numeric values to strings cols = list(izip_longest(*list_of_lists)) numparses = _expand_numparse(disable_numparse, len(cols)) - coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] + coltypes = [_column_type(col, has_invisible, numparse=np) for col, np in zip(cols, numparses)] if isinstance(floatfmt, str): # old version float_formats = len(cols) * [floatfmt] # just duplicate the string to use in each column else: # if floatfmt is list, tuple etc we have one per column diff --git a/codeflash/languages/java/line_profiler.py b/codeflash/languages/java/line_profiler.py index 0f4f5f3ed..2b3ff3fb5 100644 --- a/codeflash/languages/java/line_profiler.py +++ b/codeflash/languages/java/line_profiler.py @@ -1,8 +1,12 @@ -"""Line profiler instrumentation for Java. +"""Line profiler for Java via bytecode instrumentation agent. -This module provides functionality to instrument Java code with line-level -profiling similar to Python's line_profiler and JavaScript's profiler. -It tracks execution counts and timing for each line in instrumented functions. +This module generates configuration for the CodeFlash profiler Java agent, which +instruments bytecode at class-load time using ASM. The agent uses zero-allocation +thread-local arrays for hit counting and a per-thread call stack for accurate +self-time attribution. + +No source code modification is needed — the agent intercepts class loading via +-javaagent and injects probes at each LineNumber table entry. 
""" from __future__ import annotations @@ -10,49 +14,38 @@ import json import logging import re -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from pathlib import Path - from tree_sitter import Node from codeflash.languages.base import FunctionInfo logger = logging.getLogger(__name__) +AGENT_JAR_NAME = "codeflash-runtime-1.0.0.jar" +DEFAULT_WARMUP_ITERATIONS = 100 -class JavaLineProfiler: - """Instruments Java code for line-level profiling. - This class adds profiling code to Java functions to track: - - How many times each line executes - - How much time is spent on each line (in nanoseconds) - - Total execution time per function +class JavaLineProfiler: + """Configures the Java profiler agent for line-level profiling. Example: profiler = JavaLineProfiler(output_file=Path("profile.json")) - instrumented = profiler.instrument_source(source, file_path, functions) - # Run instrumented code + config_path = profiler.generate_agent_config(source, file_path, functions, config_path) + jvm_arg = profiler.build_javaagent_arg(config_path) + # Run Java with: java -cp ... ClassName results = JavaLineProfiler.parse_results(Path("profile.json")) """ - def __init__(self, output_file: Path) -> None: - """Initialize the line profiler. - - Args: - output_file: Path where profiling results will be written (JSON format). - - """ + def __init__(self, output_file: Path, warmup_iterations: int = DEFAULT_WARMUP_ITERATIONS) -> None: self.output_file = output_file + self.warmup_iterations = warmup_iterations self.profiler_class = "CodeflashLineProfiler" - self.profiler_var = "__codeflashProfiler__" - self.line_contents: dict[str, str] = {} - # Java executable statement types - # Moved to an instance-level frozenset to avoid rebuilding this set on every call. 
- self._executable_types = frozenset( + self.executable_types = frozenset( { "expression_statement", "return_statement", @@ -76,44 +69,106 @@ def __init__(self, output_file: Path) -> None: } ) - def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer=None) -> str: - """Instrument Java source code with line profiling. + # === Agent-based profiling (bytecode instrumentation) === - Adds profiling instrumentation to track line-level execution for the - specified functions. + def generate_agent_config( + self, source: str, file_path: Path, functions: list[FunctionInfo], config_output_path: Path + ) -> Path: + """Generate config JSON for the profiler agent. + + Reads the source to extract line contents and resolves the JVM internal + class name, then writes a config JSON that the agent uses to know which + classes/methods to instrument at class-load time. Args: - source: Original Java source code. - file_path: Path to the source file. - functions: List of functions to instrument. - analyzer: Optional JavaAnalyzer instance. + source: Java source code of the file. + file_path: Absolute path to the source file. + functions: Functions to profile. + config_output_path: Where to write the config JSON. Returns: - Instrumented source code with profiling. + Path to the written config file. 
""" - if not functions: - return source + class_name = resolve_internal_class_name(file_path, source) + lines = source.splitlines() + line_contents: dict[str, str] = {} + method_targets = [] + + for func in functions: + for line_num in range(func.starting_line, func.ending_line + 1): + if 1 <= line_num <= len(lines): + content = lines[line_num - 1].strip() + if ( + content + and not content.startswith("//") + and not content.startswith("/*") + and not content.startswith("*") + ): + key = f"{file_path.as_posix()}:{line_num}" + line_contents[key] = content + + method_targets.append( + { + "name": func.function_name, + "startLine": func.starting_line, + "endLine": func.ending_line, + "sourceFile": file_path.as_posix(), + } + ) + + config = { + "outputFile": str(self.output_file), + "warmupIterations": self.warmup_iterations, + "targets": [{"className": class_name, "methods": method_targets}], + "lineContents": line_contents, + } + + config_output_path.parent.mkdir(parents=True, exist_ok=True) + config_output_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + return config_output_path + + def build_javaagent_arg(self, config_path: Path) -> str: + """Return the -javaagent JVM argument string.""" + agent_jar = find_agent_jar() + if agent_jar is None: + msg = f"{AGENT_JAR_NAME} not found in resources or dev build directory" + raise FileNotFoundError(msg) + return f"-javaagent:{agent_jar}=config={config_path}" + + # === Source-level instrumentation === + + def instrument_source( + self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer: Any = None + ) -> str: + """Instrument Java source code with line profiling. - if analyzer is None: - from codeflash.languages.java.parser import get_java_analyzer + Injects a profiler class and per-line hit() calls directly into the source. - analyzer = get_java_analyzer() + Args: + source: Java source code of the file. + file_path: Absolute path to the source file. + functions: Functions to instrument. 
+ analyzer: JavaAnalyzer instance for parsing/validation. + Returns: + Instrumented source code, or original source if instrumentation fails. + + """ # Initialize line contents map - self.line_contents = {} + self.line_contents: dict[str, str] = {} lines = source.splitlines(keepends=True) # Process functions in reverse order to preserve line numbers for func in sorted(functions, key=lambda f: f.starting_line, reverse=True): - func_lines = self._instrument_function(func, lines, file_path, analyzer) + func_lines = self.instrument_function(func, lines, file_path, analyzer) start_idx = func.starting_line - 1 end_idx = func.ending_line lines = lines[:start_idx] + func_lines + lines[end_idx:] # Add profiler class and initialization - profiler_class_code = self._generate_profiler_class() + profiler_class_code = self.generate_profiler_class() # Insert profiler class before the package's first class # Find the first class/interface/enum/record declaration @@ -136,10 +191,10 @@ def instrument_source(self, source: str, file_path: Path, functions: list[Functi return source return result - def _generate_profiler_class(self) -> str: + def generate_profiler_class(self) -> str: """Generate Java code for profiler class.""" # Store line contents as a simple map (embedded directly in code) - line_contents_code = self._generate_line_contents_map() + line_contents_code = self.generate_line_contents_map() return f""" /** @@ -269,7 +324,7 @@ class {self.profiler_class} {{ }} """ - def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer) -> list[str]: + def instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer: Any) -> list[str]: """Instrument a single function with line profiling. 
Args: @@ -290,7 +345,7 @@ def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: try: tree = analyzer.parse(source.encode("utf8")) - executable_lines = self._find_executable_lines(tree.root_node) + executable_lines = self.find_executable_lines(tree.root_node) except Exception as e: logger.warning("Failed to parse function %s: %s", func.function_name, e) return func_lines @@ -345,7 +400,7 @@ def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: return instrumented_lines - def _generate_line_contents_map(self) -> str: + def generate_line_contents_map(self) -> str: """Generate Java code to initialize line contents map.""" lines = [] for key, content in self.line_contents.items(): @@ -354,7 +409,7 @@ def _generate_line_contents_map(self) -> str: lines.append(f' map.put("{key}", "{escaped}");') return "\n".join(lines) - def _find_executable_lines(self, node: Node) -> set[int]: + def find_executable_lines(self, node: Node) -> set[int]: """Find lines that contain executable statements. Args: @@ -368,7 +423,7 @@ def _find_executable_lines(self, node: Node) -> set[int]: # Use an explicit stack to avoid recursion overhead on deep ASTs. stack = [node] - types = self._executable_types + types = self.executable_types add_line = executable_lines.add while stack: @@ -385,113 +440,189 @@ def _find_executable_lines(self, node: Node) -> set[int]: return executable_lines - @staticmethod - def parse_results(profile_file: Path) -> dict: - """Parse line profiling results from output file. + # === Result parsing (shared by both approaches) === - Args: - profile_file: Path to profiling results JSON file. + @staticmethod + def parse_results(profile_file: Path) -> dict[str, Any]: + """Parse line profiling results from the agent's JSON output. 
- Returns: - Dictionary with profiling statistics: - { - "timings": { - "file_path": { - line_num: { - "hits": int, - "time_ns": int, - "time_ms": float, - "content": str - } - } - }, - "unit": 1e-9, - "raw_data": {...} - } + Returns the same format as parse_line_profile_test_output.parse_line_profile_results() + for non-Python languages: + { + "timings": {(filename, start_lineno, func_name): [(lineno, hits, time_ns), ...]}, + "unit": 1e-9, + "str_out": "" + } """ if not profile_file.exists(): - return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""} + return {"timings": {}, "unit": 1e-9, "str_out": ""} try: with profile_file.open("r") as f: data = json.load(f) - # Group by file - timings = {} - for key, stats in data.items(): - file_path, line_num_str = key.rsplit(":", 1) - line_num = int(line_num_str) - time_ns = int(stats["time"]) # nanoseconds - time_ms = time_ns / 1e6 # convert to milliseconds - hits = stats["hits"] - content = stats.get("content", "") - - if file_path not in timings: - timings[file_path] = {} - - timings[file_path][line_num] = { - "hits": hits, - "time_ns": time_ns, - "time_ms": time_ms, - "content": content, - } - - result = { - "timings": timings, - "unit": 1e-9, # nanoseconds - "raw_data": data, - } - result["str_out"] = format_line_profile_results(result) + # Load method ranges and line contents from config file + method_ranges, config_line_contents = load_method_ranges(profile_file) + + line_contents: dict[tuple[str, int], str] = {} + + if method_ranges: + # Group lines by method using config ranges + grouped_timings: dict[tuple[str, int, str], list[tuple[int, int, int]]] = {} + for key, stats in data.items(): + fp = stats.get("file") + line_num = stats.get("line") + if fp is None or line_num is None: + fp, line_str = key.rsplit(":", 1) + line_num = int(line_str) + line_num = int(line_num) + + line_contents[(fp, line_num)] = stats.get("content", "") + entry = (line_num, int(stats.get("hits", 0)), int(stats.get("time", 0))) + 
+ method_name, method_start = find_method_for_line(fp, line_num, method_ranges) + group_key = (fp, method_start, method_name) + grouped_timings.setdefault(group_key, []).append(entry) + + # Fill in missing lines from config (closing braces, etc.) + for config_key, content in config_line_contents.items(): + fp, line_str = config_key.rsplit(":", 1) + line_num = int(line_str) + if (fp, line_num) not in line_contents: + line_contents[(fp, line_num)] = content + method_name, method_start = find_method_for_line(fp, line_num, method_ranges) + group_key = (fp, method_start, method_name) + grouped_timings.setdefault(group_key, []).append((line_num, 0, 0)) + + for group_key in grouped_timings: + grouped_timings[group_key].sort(key=lambda t: t[0]) + else: + # No config — fall back to grouping all lines by file + lines_by_file: dict[str, list[tuple[int, int, int]]] = {} + for key, stats in data.items(): + fp = stats.get("file") + line_num = stats.get("line") + if fp is None or line_num is None: + fp, line_str = key.rsplit(":", 1) + line_num = int(line_str) + line_num = int(line_num) + + lines_by_file.setdefault(fp, []).append( + (line_num, int(stats.get("hits", 0)), int(stats.get("time", 0))) + ) + line_contents[(fp, line_num)] = stats.get("content", "") + + grouped_timings = {} + for fp, line_stats in lines_by_file.items(): + sorted_stats = sorted(line_stats, key=lambda t: t[0]) + if sorted_stats: + grouped_timings[(fp, sorted_stats[0][0], Path(fp).name)] = sorted_stats + + result: dict[str, Any] = {"timings": grouped_timings, "unit": 1e-9, "line_contents": line_contents} + result["str_out"] = format_line_profile_results(result, line_contents) return result except Exception: logger.exception("Failed to parse line profile results") - return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""} + return {"timings": {}, "unit": 1e-9, "str_out": ""} -def format_line_profile_results(results: dict, file_path: Path | None = None) -> str: - """Format line profiling results for 
display. - - Args: - results: Results from parse_results(). - file_path: Optional file path to filter results. +def load_method_ranges(profile_file: Path) -> tuple[list[tuple[str, str, int, int]], dict[str, str]]: + """Load method ranges and line contents from the agent config file. Returns: - Formatted string showing per-line statistics. + (method_ranges, config_line_contents) where method_ranges is a list of + (source_file, method_name, start_line, end_line) and config_line_contents + is the lineContents dict from the config (key: "file:line", value: source text). """ - if not results or not results.get("timings"): - return "No profiling data available" + config_path = profile_file.with_suffix(".config.json") + if not config_path.exists(): + return [], {} + try: + config = json.loads(config_path.read_text(encoding="utf-8")) + ranges = [] + for target in config.get("targets", []): + for method in target.get("methods", []): + ranges.append((method.get("sourceFile", ""), method["name"], method["startLine"], method["endLine"])) + return ranges, config.get("lineContents", {}) + except Exception: + return [], {} + + +def find_method_for_line( + file_path: str, line_num: int, method_ranges: list[tuple[str, str, int, int]] +) -> tuple[str, int]: + """Find which method a line belongs to based on config ranges. + + Returns (method_name, method_start_line). Falls back to (basename, line_num) + if no matching method range is found. + """ + for source_file, method_name, start_line, end_line in method_ranges: + if file_path == source_file and start_line <= line_num <= end_line: + return method_name, start_line + return Path(file_path).name, line_num - output = [] - output.append("Line Profiling Results") - output.append("=" * 80) - timings = results["timings"] +def find_agent_jar() -> Path | None: + """Locate the profiler agent JAR file (now bundled in codeflash-runtime). + + Checks local Maven repo, package resources, and development build directory. 
+ """ + # Check local Maven repository first (fastest) + m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / AGENT_JAR_NAME + if m2_jar.exists(): + return m2_jar + + # Check bundled JAR in package resources + resources_jar = Path(__file__).parent / "resources" / AGENT_JAR_NAME + if resources_jar.exists(): + return resources_jar - # Filter to specific file if requested - if file_path: - file_key = str(file_path) - timings = {file_key: timings.get(file_key, {})} + # Check development build directory + dev_jar = Path(__file__).parent.parent.parent.parent / "codeflash-java-runtime" / "target" / AGENT_JAR_NAME + if dev_jar.exists(): + return dev_jar - for file, lines in sorted(timings.items()): - if not lines: - continue + return None - output.append(f"\nFile: {file}") - output.append("-" * 80) - output.append(f"{'Line':>6} | {'Hits':>10} | {'Time (ms)':>12} | {'Avg (ms)':>12} | Code") - output.append("-" * 80) - # Sort by line number - for line_num in sorted(lines.keys()): - stats = lines[line_num] - hits = stats["hits"] - time_ms = stats["time_ms"] - avg_ms = time_ms / hits if hits > 0 else 0 - content = stats.get("content", "")[:50] # Truncate long lines +def resolve_internal_class_name(file_path: Path, source: str) -> str: + """Resolve the JVM internal class name (slash-separated) from source. + + Parses the package statement and combines with the filename stem. + e.g. "package com.example;" + "Calculator.java" → "com/example/Calculator" + """ + for line in source.splitlines(): + stripped = line.strip() + if stripped.startswith("package "): + package = stripped[8:].rstrip(";").strip() + return f"{package.replace('.', '/')}/{file_path.stem}" + # No package — default package + return file_path.stem + + +def format_line_profile_results( + results: dict[str, Any], line_contents: dict[tuple[str, int], str] | None = None +) -> str: + """Format line profiling results using the same tabulate pipe format as Python. 
+ + Args: + results: Parsed results with timings in grouped format: + {(filename, start_lineno, func_name): [(lineno, hits, time_ns), ...]} + line_contents: Mapping of (filename, lineno) to source line content. + + Returns: + Formatted string matching the Python line_profiler output format. + + """ + if not results or not results.get("timings"): + return "" + + if line_contents is None: + line_contents = results.get("line_contents", {}) - output.append(f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}") + from codeflash.verification.parse_line_profile_test_output import show_text_non_python - return "\n".join(output) + return show_text_non_python(results, line_contents) diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar new file mode 100644 index 000000000..848663294 Binary files /dev/null and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py index f56a0dab5..273f39385 100644 --- a/codeflash/languages/java/support.py +++ b/codeflash/languages/java/support.py @@ -62,6 +62,8 @@ class JavaSupport(LanguageSupport): def __init__(self) -> None: """Initialize Java support.""" self._analyzer = get_java_analyzer() + self.line_profiler_agent_arg: str | None = None + self.line_profiler_warmup_iterations: int = 0 @property def language(self) -> Language: @@ -272,7 +274,7 @@ def _build_runtime_map(self, inv_id_runtimes: dict[InvocationId, list[int]]) -> def compare_test_results( self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None - ) -> tuple[bool, list]: + ) -> tuple[bool, list[Any]]: """Compare test results between original and candidate code.""" return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root) @@ -384,35 +386,41 @@ def instrument_existing_test( def 
instrument_source_for_line_profiler( self, func_info: FunctionToOptimize, line_profiler_output_file: Path ) -> bool: - """Instrument source code for line profiling. + """Prepare line profiling via the bytecode-instrumentation agent. + + Generates a config JSON that the Java agent uses at class-load time to + know which methods to instrument. The agent is loaded via -javaagent + when the JVM starts. The config includes warmup iterations so the agent + discards JIT warmup data before measurement. Args: - func_info: Function to instrument. - line_profiler_output_file: Path where profiling results will be written. + func_info: Function to profile. + line_profiler_output_file: Path where profiling results will be written by the agent. Returns: - True if instrumentation succeeded, False otherwise. + True if preparation succeeded, False otherwise. """ from codeflash.languages.java.line_profiler import JavaLineProfiler try: - # Read source file source = func_info.file_path.read_text(encoding="utf-8") - # Instrument with line profiler profiler = JavaLineProfiler(output_file=line_profiler_output_file) - instrumented = profiler.instrument_source(source, func_info.file_path, [func_info], self._analyzer) - # Write instrumented source back - func_info.file_path.write_text(instrumented, encoding="utf-8") + config_path = line_profiler_output_file.with_suffix(".config.json") + profiler.generate_agent_config( + source=source, file_path=func_info.file_path, functions=[func_info], config_output_path=config_path + ) + self.line_profiler_agent_arg = profiler.build_javaagent_arg(config_path) + self.line_profiler_warmup_iterations = profiler.warmup_iterations return True except Exception: - logger.exception("Failed to instrument %s for line profiling", func_info.function_name) + logger.exception("Failed to prepare line profiling for %s", func_info.function_name) return False - def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict: + def 
parse_line_profile_results(self, line_profiler_output_file: Path) -> dict[str, Any]: """Parse line profiler output for Java. Args: @@ -473,7 +481,7 @@ def run_line_profile_tests( project_root: Path | None = None, line_profile_output_file: Path | None = None, ) -> tuple[Path, Any]: - """Run tests with line profiling enabled. + """Run tests with the profiler agent attached. Args: test_paths: TestFiles object containing test file information. @@ -496,6 +504,7 @@ def run_line_profile_tests( timeout=timeout, project_root=project_root, line_profile_output_file=line_profile_output_file, + javaagent_arg=self.line_profiler_agent_arg, ) diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index 1ebc2bc8f..033681c08 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -1311,6 +1311,7 @@ def _run_maven_tests( mode: str = "behavior", enable_coverage: bool = False, test_module: str | None = None, + javaagent_arg: str | None = None, ) -> subprocess.CompletedProcess: """Run Maven tests with Surefire. @@ -1367,7 +1368,10 @@ def _run_maven_tests( " --add-opens java.base/java.net=ALL-UNNAMED" " --add-opens java.base/java.util.zip=ALL-UNNAMED" ) - cmd.append(f"-DargLine={add_opens_flags}") + if javaagent_arg: + cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}") + else: + cmd.append(f"-DargLine={add_opens_flags}") # For performance mode, disable Surefire's file-based output redirection. # By default, Surefire captures System.out.println() to .txt report files, @@ -1793,11 +1797,12 @@ def run_line_profile_tests( timeout: int | None = None, project_root: Path | None = None, line_profile_output_file: Path | None = None, + javaagent_arg: str | None = None, ) -> tuple[Path, Any]: - """Run tests with line profiling enabled. + """Run tests with the profiler agent attached. - Runs the instrumented tests once to collect line profiling data. 
- The profiler will save results to line_profile_output_file on JVM exit. + The agent instruments bytecode at class-load time — no source modification needed. + Profiling results are written to line_profile_output_file on JVM exit. Args: test_paths: TestFiles object or list of test file paths. @@ -1806,6 +1811,7 @@ def run_line_profile_tests( timeout: Optional timeout in seconds. project_root: Project root directory. line_profile_output_file: Path where profiling results will be written. + javaagent_arg: Optional -javaagent:... JVM argument for the profiler agent. Returns: Tuple of (result_file_path, subprocess_result). @@ -1833,7 +1839,13 @@ def run_line_profile_tests( effective_timeout = max(timeout or min_timeout, min_timeout) logger.debug("Running line profiling tests (single run) with timeout=%ds", effective_timeout) result = _run_maven_tests( - maven_root, test_paths, run_env, timeout=effective_timeout, mode="line_profile", test_module=test_module + maven_root, + test_paths, + run_env, + timeout=effective_timeout, + mode="line_profile", + test_module=test_module, + javaagent_arg=javaagent_arg, ) # Get result XML path diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 38688cab6..17c2539a7 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -47,16 +47,6 @@ def _ensure_languages_registered() -> None: # Import support modules to trigger registration # These imports are deferred to avoid circular imports - import contextlib - - with contextlib.suppress(ImportError): - from codeflash.languages.python import support as _ - - with contextlib.suppress(ImportError): - from codeflash.languages.javascript import support as _ - - with contextlib.suppress(ImportError): - from codeflash.languages.java import support as _ _languages_registered = True diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d6d310f55..526edfd30 100644 --- 
a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -3237,9 +3237,6 @@ def line_profiler_step( if self.language_support is not None and hasattr(self.language_support, "instrument_source_for_line_profiler"): try: line_profiler_output_path = get_run_tmp_file(Path("line_profiler_output.json")) - # NOTE: currently this handles single file only, add support to multi file instrumentation (or should it be kept for the main file only) - original_source = Path(self.function_to_optimize.file_path).read_text() - # Instrument source code success = self.language_support.instrument_source_for_line_profiler( func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path ) @@ -3268,9 +3265,6 @@ def line_profiler_step( except Exception as e: logger.warning(f"Failed to run line profiling: {e}") return {"timings": {}, "unit": 0, "str_out": ""} - finally: - # restore original source - Path(self.function_to_optimize.file_path).write_text(original_source) logger.warning(f"Language support for {self.language_support.language} doesn't support line profiling") return {"timings": {}, "unit": 0, "str_out": ""} diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 4ef799425..ebe8f7296 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -95,6 +95,9 @@ def show_text_non_python(stats: dict, line_contents: dict[tuple[str, int], str]) default_column_sizes = {"hits": 9, "time": 12, "perhit": 8, "percent": 8} table_rows = [] for lineno, nhits, time in timings: + if nhits == 0: + table_rows.append(("", "", "", "", line_contents.get((fn, lineno), ""))) + continue percent = "" if total_time == 0 else "%5.1f" % (100 * time / total_time) time_disp = f"{time:5.1f}" if len(time_disp) > default_column_sizes["time"]: diff --git 
a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 588f803a3..ab4886160 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -1890,6 +1890,159 @@ class InnerTests { assert result == expected +class TestMultiByteUtf8Instrumentation: + """Tests that timing instrumentation handles multi-byte UTF-8 source correctly. + + The instrumentation uses tree-sitter byte offsets which must be converted to + character offsets for Python string slicing (instrumentation.py:782). + Multi-byte characters (CJK, accented chars) shift byte positions + relative to character positions, so incorrect conversion corrupts the output. + """ + + def test_instrument_with_cjk_in_string_literal(self, tmp_path: Path): + """Target function call after a string literal containing CJK characters.""" + test_file = tmp_path / "Utf8Test.java" + source = """import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class Utf8Test { + @Test + public void testWithCjk() { + String label = "テスト名前"; + assertEquals(42, compute(21)); + } +} +""" + test_file.write_text(source, encoding="utf-8") + + func = FunctionToOptimize( + function_name="compute", + file_path=tmp_path / "Target.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + ) + + success, result = instrument_existing_test( + test_string=source, + function_to_optimize=func, + mode="performance", + test_path=test_file, + ) + + # The blank line between _cf_fn1 and the prefix body has 8 trailing spaces + # (the indent level) — this is the f"{indent}\n" separator in the instrumentation code. 
+ expected = ( + 'import org.junit.jupiter.api.Test;\n' + 'import static org.junit.jupiter.api.Assertions.*;\n' + '\n' + 'public class Utf8Test__perfonlyinstrumented {\n' + ' @Test\n' + ' public void testWithCjk() {\n' + ' // Codeflash timing instrumentation with inner loop for JIT warmup\n' + ' int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n' + ' int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));\n' + ' String _cf_mod1 = "Utf8Test";\n' + ' String _cf_cls1 = "Utf8Test";\n' + ' String _cf_test1 = "testWithCjk";\n' + ' String _cf_fn1 = "compute";\n' + ' \n' + ' String label = "\u30c6\u30b9\u30c8\u540d\u524d";\n' + ' for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {\n' + ' System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");\n' + ' long _cf_end1 = -1;\n' + ' long _cf_start1 = 0;\n' + ' try {\n' + ' _cf_start1 = System.nanoTime();\n' + ' assertEquals(42, compute(21));\n' + ' _cf_end1 = System.nanoTime();\n' + ' } finally {\n' + ' long _cf_end1_finally = System.nanoTime();\n' + ' long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1;\n' + ' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." 
+ _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");\n' + ' }\n' + ' }\n' + ' }\n' + '}\n' + ) + assert success is True + assert result == expected + + def test_instrument_with_multibyte_in_comment(self, tmp_path: Path): + """Target function call after a comment with accented characters (multi-byte UTF-8).""" + test_file = tmp_path / "AccentTest.java" + source = """import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class AccentTest { + @Test + public void testWithAccent() { + // R\u00e9sum\u00e9 processing test with accented chars + String name = "caf\u00e9"; + assertEquals(10, calculate(5)); + } +} +""" + test_file.write_text(source, encoding="utf-8") + + func = FunctionToOptimize( + function_name="calculate", + file_path=tmp_path / "Target.java", + starting_line=1, + ending_line=5, + parents=[], + is_method=True, + language="java", + ) + + success, result = instrument_existing_test( + test_string=source, + function_to_optimize=func, + mode="performance", + test_path=test_file, + ) + + assert success is True + + expected = ( + 'import org.junit.jupiter.api.Test;\n' + 'import static org.junit.jupiter.api.Assertions.*;\n' + '\n' + 'public class AccentTest__perfonlyinstrumented {\n' + ' @Test\n' + ' public void testWithAccent() {\n' + ' // Codeflash timing instrumentation with inner loop for JIT warmup\n' + ' int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n' + ' int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));\n' + ' String _cf_mod1 = "AccentTest";\n' + ' String _cf_cls1 = "AccentTest";\n' + ' String _cf_test1 = "testWithAccent";\n' + ' String _cf_fn1 = "calculate";\n' + ' \n' + ' // R\u00e9sum\u00e9 processing test with accented chars\n' + ' String name = "caf\u00e9";\n' + ' for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {\n' + ' System.out.println("!$######" + _cf_mod1 + ":" + 
_cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");\n' + ' long _cf_end1 = -1;\n' + ' long _cf_start1 = 0;\n' + ' try {\n' + ' _cf_start1 = System.nanoTime();\n' + ' assertEquals(10, calculate(5));\n' + ' _cf_end1 = System.nanoTime();\n' + ' } finally {\n' + ' long _cf_end1_finally = System.nanoTime();\n' + ' long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1;\n' + ' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");\n' + ' }\n' + ' }\n' + ' }\n' + '}\n' + ) + assert result == expected + + # Skip all E2E tests if Maven is not available requires_maven = pytest.mark.skipif( find_maven_executable() is None, @@ -2945,3 +3098,207 @@ def __init__(self, path): assert iter_0_count == 2, f"Expected 2 markers for iteration 0, got {iter_0_count}" assert iter_1_count == 2, f"Expected 2 markers for iteration 1, got {iter_1_count}" + + def test_time_correction_instrumentation(self, java_project): + """Test timing accuracy of performance instrumentation with known durations. + + Mirrors Python's test_time_correction_instrumentation — uses a busy-wait + function (SpinWait) with known nanosecond durations and verifies that: + 1. Instrumented source matches exactly (full string equality) + 2. Pipeline produces correct number of timing results + 3. Measured runtimes match expected durations within tolerance + + Python equivalent uses accurate_sleepfunc(0.01) → 100ms and accurate_sleepfunc(0.02) → 200ms + with rel_tol=0.01. Java uses System.nanoTime() busy-wait with 50ms and 100ms durations. 
+ """ + import math + + project_root, src_dir, test_dir = java_project + + # Create SpinWait class — Java equivalent of Python's accurate_sleepfunc + (src_dir / "SpinWait.java").write_text("""package com.example; + +public class SpinWait { + public static long spinWait(long durationNs) { + long start = System.nanoTime(); + while (System.nanoTime() - start < durationNs) { + } + return durationNs; + } +} +""", encoding="utf-8") + + # Two test methods with known durations — mirrors Python's parametrize with + # (0.01, 0.010) and (0.02, 0.020) which map to 100ms and 200ms + test_source = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class SpinWaitTest { + @Test + public void testSpinShort() { + assertEquals(50_000_000L, SpinWait.spinWait(50_000_000L)); + } + + @Test + public void testSpinLong() { + assertEquals(100_000_000L, SpinWait.spinWait(100_000_000L)); + } +} +""" + test_file = test_dir / "SpinWaitTest.java" + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="spinWait", + file_path=src_dir / "SpinWait.java", + starting_line=4, + ending_line=9, + parents=[], + is_method=True, + language="java", + ) + + # Instrument for performance mode + success, instrumented = instrument_existing_test( + test_string=test_source, + function_to_optimize=func_info, + mode="performance", + test_path=test_file, + ) + assert success, "Instrumentation should succeed" + + # Assert exact instrumented source (full string equality) — mirrors Python's + # assert new_test.replace('"', "'") == expected.format(...).replace('"', "'") + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class SpinWaitTest__perfonlyinstrumented { + @Test + public void testSpinShort() { + // Codeflash timing instrumentation with inner loop for JIT warmup + int _cf_loop1 = 
Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100")); + String _cf_mod1 = "SpinWaitTest"; + String _cf_cls1 = "SpinWaitTest"; + String _cf_test1 = "testSpinShort"; + String _cf_fn1 = "spinWait"; + + for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) { + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!"); + long _cf_end1 = -1; + long _cf_start1 = 0; + try { + _cf_start1 = System.nanoTime(); + assertEquals(50_000_000L, SpinWait.spinWait(50_000_000L)); + _cf_end1 = System.nanoTime(); + } finally { + long _cf_end1_finally = System.nanoTime(); + long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!"); + } + } + } + + @Test + public void testSpinLong() { + // Codeflash timing instrumentation with inner loop for JIT warmup + int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100")); + String _cf_mod2 = "SpinWaitTest"; + String _cf_cls2 = "SpinWaitTest"; + String _cf_test2 = "testSpinLong"; + String _cf_fn2 = "spinWait"; + + for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) { + System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!"); + long _cf_end2 = -1; + long _cf_start2 = 0; + try { + _cf_start2 = System.nanoTime(); + assertEquals(100_000_000L, SpinWait.spinWait(100_000_000L)); + _cf_end2 = System.nanoTime(); + } finally { + long _cf_end2_finally = System.nanoTime(); + long _cf_dur2 = (_cf_end2 != -1 ? 
_cf_end2 : _cf_end2_finally) - _cf_start2; + System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + "." + _cf_test2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!"); + } + } + } +} +""" + assert instrumented == expected_instrumented + + instrumented_file = test_dir / "SpinWaitTest__perfonlyinstrumented.java" + instrumented_file.write_text(instrumented, encoding="utf-8") + + # Run benchmarking with inner_iterations=2 — mirrors Python's + # pytest_min_loops=2, pytest_max_loops=2 which produces 4 results + from codeflash.languages.java.test_runner import run_benchmarking_tests + + test_env = os.environ.copy() + + class MockTestFiles: + def __init__(self, files): + self.test_files = files + + class MockTestFile: + def __init__(self, path): + self.benchmarking_file_path = path + self.instrumented_behavior_file_path = path + + test_files = MockTestFiles([MockTestFile(instrumented_file)]) + + result_xml_path, result = run_benchmarking_tests( + test_paths=test_files, + test_env=test_env, + cwd=project_root, + timeout=120, + project_root=project_root, + min_loops=1, + max_loops=1, + target_duration_seconds=1.0, + inner_iterations=2, + ) + + assert result.returncode == 0, f"Maven test failed: {result.stderr}" + + # Parse timing markers from stdout + stdout = result.stdout + end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!") + end_matches = end_pattern.findall(stdout) + + # Should have 4 timing markers (2 test methods × 2 inner iterations) + # Mirrors Python's: assert len(test_results) == 4 + assert len(end_matches) == 4, ( + f"Expected 4 end markers (2 methods × 2 inner iterations), got {len(end_matches)}: {end_matches}" + ) + + # Verify all tests passed and timing accuracy — mirrors Python's: + # assert math.isclose(test_result.runtime, ((i % 2) + 1) * 100_000_000, rel_tol=0.01) + short_durations = [] + long_durations = [] + for match in end_matches: + duration_ns = int(match[5]) + assert duration_ns 
> 0 + + if duration_ns < 75_000_000: + short_durations.append(duration_ns) + else: + long_durations.append(duration_ns) + + assert len(short_durations) == 2, f"Expected 2 short results, got {len(short_durations)}" + assert len(long_durations) == 2, f"Expected 2 long results, got {len(long_durations)}" + + for duration in short_durations: + assert math.isclose(duration, 50_000_000, rel_tol=0.15), ( + f"Short spin measured {duration}ns, expected ~50_000_000ns (15% tolerance)" + ) + + for duration in long_durations: + assert math.isclose(duration, 100_000_000, rel_tol=0.15), ( + f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)" + ) diff --git a/tests/test_languages/test_java/test_line_profiler.py b/tests/test_languages/test_java/test_line_profiler.py index fd42acad7..3a73bd9f4 100644 --- a/tests/test_languages/test_java/test_line_profiler.py +++ b/tests/test_languages/test_java/test_line_profiler.py @@ -1,21 +1,28 @@ -"""Tests for Java line profiler.""" +"""Tests for Java line profiler (agent-based).""" import json import tempfile from pathlib import Path +from unittest.mock import patch import pytest -from codeflash.languages.base import FunctionInfo, Language -from codeflash.languages.java.line_profiler import JavaLineProfiler, format_line_profile_results -from codeflash.languages.java.parser import get_java_analyzer +from codeflash.languages.java.line_profiler import ( + DEFAULT_WARMUP_ITERATIONS, + JavaLineProfiler, + find_agent_jar, + format_line_profile_results, + resolve_internal_class_name, +) -class TestJavaLineProfilerInstrumentation: - """Tests for line profiler instrumentation.""" +class TestAgentConfigGeneration: + """Tests for agent config generation.""" + + def test_simple_method(self): + """Test config generation for a simple method.""" + from codeflash.languages.base import FunctionInfo, Language - def test_instrument_simple_method(self): - """Test instrumenting a simple method.""" source = """package com.example; public 
class Calculator { @@ -39,29 +46,86 @@ def test_instrument_simple_method(self): language=Language.JAVA, ) - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - output_file = Path(tmp.name) + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func], config_path) + + assert config_path.exists() + config = json.loads(config_path.read_text()) + + assert config == { + "outputFile": str(output_file), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "com/example/Calculator", + "methods": [ + { + "name": "add", + "startLine": 4, + "endLine": 7, + "sourceFile": file_path.as_posix(), + } + ], + } + ], + "lineContents": { + f"{file_path.as_posix()}:4": "public static int add(int a, int b) {", + f"{file_path.as_posix()}:5": "int result = a + b;", + f"{file_path.as_posix()}:6": "return result;", + f"{file_path.as_posix()}:7": "}", + }, + } + + def test_line_contents_extraction(self): + """Test that line contents are extracted correctly.""" + from codeflash.languages.base import FunctionInfo, Language - profiler = JavaLineProfiler(output_file=output_file) - analyzer = get_java_analyzer() + source = """public class Test { + public void method() { + int x = 1; + // just a comment + return; + } +} +""" + file_path = Path("/tmp/Test.java") + func = FunctionInfo( + function_name="method", + file_path=file_path, + starting_line=2, + ending_line=6, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) - instrumented = profiler.instrument_source(source, file_path, [func], analyzer) + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" - # Verify profiler class is added - assert "class 
CodeflashLineProfiler" in instrumented - assert "public static void hit(String file, int line)" in instrumented + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func], config_path) - # Verify enterFunction() is called - assert "CodeflashLineProfiler.enterFunction()" in instrumented + config = json.loads(config_path.read_text()) - # Verify hit() calls are added for executable lines - assert 'CodeflashLineProfiler.hit("/tmp/Calculator.java"' in instrumented + assert config["lineContents"] == { + f"{file_path.as_posix()}:2": "public void method() {", + f"{file_path.as_posix()}:3": "int x = 1;", + f"{file_path.as_posix()}:5": "return;", + f"{file_path.as_posix()}:6": "}", + } - # Cleanup - output_file.unlink(missing_ok=True) + def test_multiple_functions(self): + """Test config with multiple target functions.""" + from codeflash.languages.base import FunctionInfo, Language - def test_instrument_preserves_non_instrumented_code(self): - """Test that non-instrumented parts are preserved.""" source = """public class Test { public void method1() { int x = 1; @@ -73,7 +137,7 @@ def test_instrument_preserves_non_instrumented_code(self): } """ file_path = Path("/tmp/Test.java") - func = FunctionInfo( + func1 = FunctionInfo( function_name="method1", file_path=file_path, starting_line=2, @@ -85,92 +149,141 @@ def test_instrument_preserves_non_instrumented_code(self): is_method=True, language=Language.JAVA, ) + func2 = FunctionInfo( + function_name="method2", + file_path=file_path, + starting_line=6, + ending_line=8, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - output_file = Path(tmp.name) - - profiler = JavaLineProfiler(output_file=output_file) - analyzer = get_java_analyzer() - - instrumented = profiler.instrument_source(source, file_path, [func], analyzer) - - # method2 should 
not be instrumented - lines = instrumented.split("\n") - method2_lines = [l for l in lines if "method2" in l or "int y = 2" in l] - - # Should have method2 declaration and body, but no profiler calls in method2 - assert any("method2" in l for l in method2_lines) - assert any("int y = 2" in l for l in method2_lines) - # Profiler calls should not be in method2's body - method2_start = None - for i, l in enumerate(lines): - if "method2" in l: - method2_start = i - break - - if method2_start: - # Check the few lines after method2 declaration - method2_body = lines[method2_start : method2_start + 5] - profiler_in_method2 = any("CodeflashLineProfiler.hit" in l for l in method2_body) - # There might be profiler class code before method2, but not in its body - # Actually, since we only instrument method1, method2 should be unchanged - - # Cleanup - output_file.unlink(missing_ok=True) - - def test_find_executable_lines(self): - """Test finding executable lines in Java code.""" - source = """public static int fibonacci(int n) { - if (n <= 1) return n; - return fibonacci(n-1) + fibonacci(n-2); -} -""" - analyzer = get_java_analyzer() - tree = analyzer.parse(source.encode("utf8")) + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func1, func2], config_path) + + config = json.loads(config_path.read_text()) + + assert config["targets"][0]["methods"] == [ + { + "name": "method1", + "startLine": 2, + "endLine": 4, + "sourceFile": file_path.as_posix(), + }, + { + "name": "method2", + "startLine": 6, + "endLine": 8, + "sourceFile": file_path.as_posix(), + }, + ] + + def test_empty_function_list(self): + """Test with no functions produces valid config.""" + source = "public class Test {}" + file_path = Path("/tmp/Test.java") - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) 
as tmp: - output_file = Path(tmp.name) + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" - profiler = JavaLineProfiler(output_file=output_file) - executable_lines = profiler._find_executable_lines(tree.root_node) + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [], config_path) - # Should find the if statement and return statements - assert len(executable_lines) >= 2 + config = json.loads(config_path.read_text()) + assert config["targets"][0]["methods"] == [] - # Cleanup - output_file.unlink(missing_ok=True) +class TestResolveInternalClassName: + """Tests for JVM class name resolution.""" -class TestJavaLineProfilerExecution: - """Tests for line profiler execution (requires compilation).""" + def test_with_package(self): + source = "package com.example;\npublic class Calculator {}" + result = resolve_internal_class_name(Path("/tmp/Calculator.java"), source) + assert result == "com/example/Calculator" - @pytest.mark.skipif( - True, # Skip for now - compilation test requires full Java env - reason="Java compiler test skipped - requires javac and dependencies", - ) - def test_instrumented_code_compiles(self): - """Test that instrumented code compiles successfully.""" - source = """package com.example; + def test_without_package(self): + source = "public class Calculator {}" + result = resolve_internal_class_name(Path("/tmp/Calculator.java"), source) + assert result == "Calculator" -public class Factorial { - public static long factorial(int n) { - if (n < 0) { - throw new IllegalArgumentException("Negative input"); - } - long result = 1; - for (int i = 1; i <= n; i++) { - result *= i; - } - return result; - } -} -""" - file_path = Path("/tmp/test_profiler/Factorial.java") + def test_nested_package(self): + source = "package org.apache.commons.lang3;\npublic class StringUtils {}" + result = 
resolve_internal_class_name(Path("/tmp/StringUtils.java"), source) + assert result == "org/apache/commons/lang3/StringUtils" + + +class TestAgentJarLocator: + """Tests for finding the agent JAR.""" + + def test_find_agent_jar(self): + jar = find_agent_jar() + # Should find it in either resources or dev build + assert jar is not None + assert jar.exists() + assert jar.name == "codeflash-runtime-1.0.0.jar" + + def test_build_javaagent_arg(self): + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + config_path.write_text("{}") + + profiler = JavaLineProfiler(output_file=output_file) + arg = profiler.build_javaagent_arg(config_path) + + agent_jar = find_agent_jar() + assert arg == f"-javaagent:{agent_jar}=config={config_path}" + + def test_build_javaagent_arg_missing_jar(self): + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + config_path.write_text("{}") + + profiler = JavaLineProfiler(output_file=output_file) + + with patch("codeflash.languages.java.line_profiler.find_agent_jar", return_value=None): + with pytest.raises(FileNotFoundError): + profiler.build_javaagent_arg(config_path) + + +class TestWarmupConfig: + """Tests for warmup configuration in agent config generation.""" + + def test_default_warmup_iterations(self): + """Test that default warmup iterations matches the module constant.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + profiler = JavaLineProfiler(output_file=output_file) + assert profiler.warmup_iterations == DEFAULT_WARMUP_ITERATIONS + + def test_custom_warmup_iterations(self): + """Test setting custom warmup iterations.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + profiler = JavaLineProfiler(output_file=output_file, warmup_iterations=10) + assert 
profiler.warmup_iterations == 10 + + def test_warmup_disabled(self): + """Test warmup can be disabled by setting to 0.""" + from codeflash.languages.base import FunctionInfo, Language + + source = "public class Test {\n public void method() {\n return;\n }\n}" + file_path = Path("/tmp/Test.java") func = FunctionInfo( - function_name="factorial", + function_name="method", file_path=file_path, - starting_line=4, - ending_line=12, + starting_line=2, + ending_line=4, starting_col=0, ending_col=0, parents=(), @@ -179,71 +292,201 @@ def test_instrumented_code_compiles(self): language=Language.JAVA, ) - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - output_file = Path(tmp.name) - - profiler = JavaLineProfiler(output_file=output_file) - analyzer = get_java_analyzer() + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" - instrumented = profiler.instrument_source(source, file_path, [func], analyzer) + profiler = JavaLineProfiler(output_file=output_file, warmup_iterations=0) + profiler.generate_agent_config(source, file_path, [func], config_path) - # Write instrumented source - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(instrumented, encoding="utf-8") + config = json.loads(config_path.read_text()) + assert config["warmupIterations"] == 0 - # Try to compile - import subprocess + def test_warmup_in_config_json(self): + """Test that warmupIterations appears in the generated config JSON.""" + from codeflash.languages.base import FunctionInfo, Language - result = subprocess.run( - ["javac", str(file_path)], - capture_output=True, - text=True, + source = "package com.example;\npublic class Calc {\n public int add(int a, int b) {\n return a + b;\n }\n}" + file_path = Path("/tmp/Calc.java") + func = FunctionInfo( + function_name="add", + file_path=file_path, + starting_line=3, + ending_line=5, + starting_col=0, + ending_col=0, + parents=(), 
+ is_async=False, + is_method=True, + language=Language.JAVA, ) - # Check compilation - if result.returncode != 0: - print(f"Compilation failed:\n{result.stderr}") - # For now, we expect compilation to fail due to missing Gson dependency - # This is expected - we're just testing that the instrumentation syntax is valid - # In real usage, Gson will be available via Maven/Gradle - assert "package com.google.gson does not exist" in result.stderr or "cannot find symbol" in result.stderr - else: - assert result.returncode == 0, f"Compilation failed: {result.stderr}" + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + profiler = JavaLineProfiler(output_file=output_file, warmup_iterations=7) + profiler.generate_agent_config(source, file_path, [func], config_path) + + config = json.loads(config_path.read_text()) + assert config["warmupIterations"] == 7 + + +class TestAgentConfigBoundaryConditions: + """Tests for boundary conditions in agent config generation.""" + + def test_start_line_beyond_end_line(self): + """When starting_line > ending_line, no lines are extracted but config is still valid.""" + from codeflash.languages.base import FunctionInfo, Language + + source = "public class Test {\n public void foo() { return; }\n}\n" + file_path = Path("/tmp/Test.java") + + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + func = FunctionInfo( + function_name="foo", + file_path=file_path, + starting_line=5, + ending_line=2, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) + + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func], config_path) + + config = json.loads(config_path.read_text()) + assert config["lineContents"] == {} + assert config["targets"][0]["methods"] == [ + 
{"name": "foo", "startLine": 5, "endLine": 2, "sourceFile": file_path.as_posix()} + ] + + def test_line_numbers_beyond_source_length(self): + """Line numbers beyond the source length are silently skipped.""" + from codeflash.languages.base import FunctionInfo, Language + + source = "public class Test {\n public void foo() { return; }\n}\n" + file_path = Path("/tmp/Test.java") + + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + func = FunctionInfo( + function_name="foo", + file_path=file_path, + starting_line=100, + ending_line=200, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) + + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func], config_path) + + config = json.loads(config_path.read_text()) + assert config == { + "outputFile": str(output_file), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "Test", + "methods": [ + { + "name": "foo", + "startLine": 100, + "endLine": 200, + "sourceFile": file_path.as_posix(), + } + ], + } + ], + "lineContents": {}, + } - # Cleanup - output_file.unlink(missing_ok=True) - file_path.unlink(missing_ok=True) - class_file = file_path.with_suffix(".class") - class_file.unlink(missing_ok=True) + def test_negative_line_numbers(self): + """Negative line numbers produce no line contents (range is empty or out of bounds).""" + from codeflash.languages.base import FunctionInfo, Language + + source = "public class Test {\n public void foo() { return; }\n}\n" + file_path = Path("/tmp/Test.java") + + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "profile.json" + config_path = Path(tmpdir) / "config.json" + + func = FunctionInfo( + function_name="foo", + file_path=file_path, + starting_line=-5, + ending_line=-1, + starting_col=0, + ending_col=0, + parents=(), + 
is_async=False, + is_method=True, + language=Language.JAVA, + ) + + profiler = JavaLineProfiler(output_file=output_file) + profiler.generate_agent_config(source, file_path, [func], config_path) + + config = json.loads(config_path.read_text()) + assert config == { + "outputFile": str(output_file), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "Test", + "methods": [ + { + "name": "foo", + "startLine": -5, + "endLine": -1, + "sourceFile": file_path.as_posix(), + } + ], + } + ], + "lineContents": {}, + } class TestLineProfileResultsParsing: """Tests for parsing line profile results.""" def test_parse_results_empty_file(self): - """Test parsing when file doesn't exist.""" results = JavaLineProfiler.parse_results(Path("/tmp/nonexistent.json")) - assert results["timings"] == {} - assert results["unit"] == 1e-9 + assert results == {"timings": {}, "unit": 1e-9, "str_out": ""} def test_parse_results_valid_data(self): - """Test parsing valid profiling data.""" data = { "/tmp/Test.java:10": { "hits": 100, - "time": 5000000, # 5ms in nanoseconds + "time": 5000000, "file": "/tmp/Test.java", "line": 10, - "content": "int x = compute();" + "content": "int x = compute();", }, "/tmp/Test.java:11": { "hits": 100, - "time": 95000000, # 95ms in nanoseconds + "time": 95000000, "file": "/tmp/Test.java", "line": 11, - "content": "result = slowOperation(x);" - } + "content": "result = slowOperation(x);", + }, } with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: @@ -252,36 +495,34 @@ def test_parse_results_valid_data(self): results = JavaLineProfiler.parse_results(profile_file) - assert "/tmp/Test.java" in results["timings"] - assert 10 in results["timings"]["/tmp/Test.java"] - assert 11 in results["timings"]["/tmp/Test.java"] - - line10 = results["timings"]["/tmp/Test.java"][10] - assert line10["hits"] == 100 - assert line10["time_ns"] == 5000000 - assert line10["time_ms"] == 5.0 - - line11 = 
results["timings"]["/tmp/Test.java"][11] - assert line11["hits"] == 100 - assert line11["time_ns"] == 95000000 - assert line11["time_ms"] == 95.0 - - # Line 11 is the hotspot (95% of time) - total_time = line10["time_ms"] + line11["time_ms"] - assert line11["time_ms"] / total_time > 0.9 # 95% of time + assert results["unit"] == 1e-9 + assert results["timings"] == { + ("/tmp/Test.java", 10, "Test.java"): [(10, 100, 5000000), (11, 100, 95000000)] + } + assert results["line_contents"] == { + ("/tmp/Test.java", 10): "int x = compute();", + ("/tmp/Test.java", 11): "result = slowOperation(x);", + } + assert results["str_out"] == ( + "# Timer unit: 1e-09 s\n" + "## Function: Test.java\n" + "## Total time: 0.1 s\n" + "| Hits | Time | Per Hit | % Time | Line Contents |\n" + "|-------:|--------:|----------:|---------:|:---------------------------|\n" + "| 100 | 5e+06 | 50000 | 5 | int x = compute(); |\n" + "| 100 | 9.5e+07 | 950000 | 95 | result = slowOperation(x); |\n" + ) - # Cleanup profile_file.unlink() def test_format_results(self): - """Test formatting results for display.""" data = { "/tmp/Test.java:10": { "hits": 10, - "time": 1000000, # 1ms + "time": 1000000, "file": "/tmp/Test.java", "line": 10, - "content": "int x = 1;" + "content": "int x = 1;", } } @@ -292,78 +533,88 @@ def test_format_results(self): results = JavaLineProfiler.parse_results(profile_file) formatted = format_line_profile_results(results) - assert "Line Profiling Results" in formatted - assert "/tmp/Test.java" in formatted - assert "10" in formatted # Line number - assert "10" in formatted # Hits - assert "int x = 1" in formatted # Code content + expected = ( + "# Timer unit: 1e-09 s\n" + "## Function: Test.java\n" + "## Total time: 0.001 s\n" + "| Hits | Time | Per Hit | % Time | Line Contents |\n" + "|-------:|-------:|----------:|---------:|:----------------|\n" + "| 10 | 1e+06 | 100000 | 100 | int x = 1; |\n" + ) + assert formatted == expected - # Cleanup profile_file.unlink() + def 
test_parse_results_corrupted_json(self): + """Corrupted/truncated JSON returns empty results instead of crashing.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + tmp.write('{"incomplete": true, "data": [') # truncated JSON + profile_file = Path(tmp.name) -class TestLineProfilerEdgeCases: - """Tests for edge cases in line profiling.""" - - def test_empty_function_list(self): - """Test with no functions to instrument.""" - source = "public class Test {}" - file_path = Path("/tmp/Test.java") - - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - output_file = Path(tmp.name) - - profiler = JavaLineProfiler(output_file=output_file) - - instrumented = profiler.instrument_source(source, file_path, [], None) - - # Should return source unchanged - assert instrumented == source + results = JavaLineProfiler.parse_results(profile_file) - # Cleanup - output_file.unlink(missing_ok=True) + assert results == {"timings": {}, "unit": 1e-9, "str_out": ""} - def test_function_with_only_comments(self): - """Test instrumenting a function with only comments.""" - source = """public class Test { - public void method() { - // Just a comment - /* Another comment */ - } -} -""" - file_path = Path("/tmp/Test.java") - func = FunctionInfo( - function_name="method", - file_path=file_path, - starting_line=2, - ending_line=5, - starting_col=0, - ending_col=0, - parents=(), - is_async=False, - is_method=True, - language=Language.JAVA, - ) + profile_file.unlink() - with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: - output_file = Path(tmp.name) + def test_parse_results_not_a_dict(self): + """Profile file containing a JSON array instead of object returns empty results.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + json.dump([1, 2, 3], tmp) + profile_file = Path(tmp.name) - profiler = JavaLineProfiler(output_file=output_file) - analyzer = get_java_analyzer() + results = 
JavaLineProfiler.parse_results(profile_file) - instrumented = profiler.instrument_source(source, file_path, [func], analyzer) + assert results == {"timings": {}, "unit": 1e-9, "str_out": ""} - # Should add profiler class and enterFunction, but no hit() calls for comments - assert "CodeflashLineProfiler" in instrumented - assert "enterFunction()" in instrumented + profile_file.unlink() - # Should not add hit() for comment lines - lines = instrumented.split("\n") - comment_line_has_hit = any( - "// Just a comment" in l and "hit(" in l for l in lines - ) - assert not comment_line_has_hit + def test_parse_results_no_config_file_fallback(self): + """When config.json is missing, parse_results falls back to grouping by file.""" + data = { + "/tmp/Sorter.java:5": { + "hits": 10, + "time": 2000000, + "file": "/tmp/Sorter.java", + "line": 5, + "content": "int n = arr.length;", + }, + "/tmp/Sorter.java:6": { + "hits": 10, + "time": 8000000, + "file": "/tmp/Sorter.java", + "line": 6, + "content": "for (int i = 0; i < n; i++) {", + }, + } - # Cleanup - output_file.unlink(missing_ok=True) + with tempfile.TemporaryDirectory() as tmpdir: + profile_file = Path(tmpdir) / "profile.json" + profile_file.write_text(json.dumps(data), encoding="utf-8") + + # Deliberately do NOT create profile.config.json + + config_path = profile_file.with_suffix(".config.json") + assert not config_path.exists() + + results = JavaLineProfiler.parse_results(profile_file) + + assert results == { + "unit": 1e-9, + "timings": { + ("/tmp/Sorter.java", 5, "Sorter.java"): [(5, 10, 2000000), (6, 10, 8000000)] + }, + "line_contents": { + ("/tmp/Sorter.java", 5): "int n = arr.length;", + ("/tmp/Sorter.java", 6): "for (int i = 0; i < n; i++) {", + }, + "str_out": ( + "# Timer unit: 1e-09 s\n" + "## Function: Sorter.java\n" + "## Total time: 0.01 s\n" + "| Hits | Time | Per Hit | % Time | Line Contents |\n" + "|-------:|-------:|----------:|---------:|:------------------------------|\n" + "| 10 | 2e+06 | 200000 | 20 
| int n = arr.length; |\n" + "| 10 | 8e+06 | 800000 | 80 | for (int i = 0; i < n; i++) { |\n" + ), + } diff --git a/tests/test_languages/test_java/test_line_profiler_integration.py b/tests/test_languages/test_java/test_line_profiler_integration.py index 14b4c8426..46662b2d5 100644 --- a/tests/test_languages/test_java/test_line_profiler_integration.py +++ b/tests/test_languages/test_java/test_line_profiler_integration.py @@ -1,21 +1,27 @@ -"""Integration tests for Java line profiler with JavaSupport.""" +"""Integration tests for Java line profiler with JavaSupport. +""" import json +import math +import shutil +import subprocess import tempfile from pathlib import Path import pytest from codeflash.languages.base import FunctionInfo, Language +from codeflash.languages.java.line_profiler import DEFAULT_WARMUP_ITERATIONS, JavaLineProfiler, find_agent_jar from codeflash.languages.java.support import get_java_support -class TestLineProfilerIntegration: - """Integration tests for line profiler with JavaSupport.""" +class TestLineProfilerInstrumentation: + """Integration tests for line profiler instrumentation through JavaSupport. + """ - def test_instrument_and_parse_results(self): - """Test full workflow: instrument, parse results.""" - # Create a temporary Java file + def test_instrument_with_package(self): + """Test instrumentation for a class with a package declaration. 
+ """ source = """package com.example; public class Calculator { @@ -27,13 +33,9 @@ def test_instrument_and_parse_results(self): """ with tempfile.TemporaryDirectory() as tmpdir: tmppath = Path(tmpdir) - src_dir = tmppath / "src" - src_dir.mkdir() - - java_file = src_dir / "Calculator.java" + java_file = tmppath / "Calculator.java" java_file.write_text(source, encoding="utf-8") - # Create profile output file profile_output = tmppath / "profile.json" func = FunctionInfo( @@ -49,105 +51,176 @@ def test_instrument_and_parse_results(self): language=Language.JAVA, ) - # Get JavaSupport and instrument support = get_java_support() success = support.instrument_source_for_line_profiler(func, profile_output) - # Should succeed - assert success, "Instrumentation should succeed" - - # Verify file was modified - instrumented = java_file.read_text(encoding="utf-8") - assert "CodeflashLineProfiler" in instrumented - assert "enterFunction()" in instrumented - assert "hit(" in instrumented - - def test_parse_empty_results(self): - """Test parsing results when file doesn't exist.""" - support = get_java_support() - - # Parse non-existent file - results = support.parse_line_profile_results(Path("/tmp/nonexistent_profile.json")) - - # Should return empty results - assert results["timings"] == {} - assert results["unit"] == 1e-9 - - def test_parse_valid_results(self): - """Test parsing valid profiling results.""" - # Create sample profiling data - data = { - "/tmp/Test.java:5": { - "hits": 100, - "time": 5000000, # 5ms in nanoseconds - "file": "/tmp/Test.java", - "line": 5, - "content": "int x = compute();" - }, - "/tmp/Test.java:6": { - "hits": 100, - "time": 95000000, # 95ms in nanoseconds - "file": "/tmp/Test.java", - "line": 6, - "content": "result = slowOperation(x);" + assert success, "Profiler config generation should succeed" + + # Source file must NOT be modified (Java uses agent, not source rewriting) + assert java_file.read_text(encoding="utf-8") == source + + # Config JSON 
should have been created with correct content + config_path = profile_output.with_suffix(".config.json") + assert config_path.exists() + config = json.loads(config_path.read_text(encoding="utf-8")) + + assert config == { + "outputFile": str(profile_output), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "com/example/Calculator", + "methods": [ + { + "name": "add", + "startLine": 4, + "endLine": 7, + "sourceFile": java_file.as_posix(), + } + ], + } + ], + "lineContents": { + f"{java_file.as_posix()}:4": "public static int add(int a, int b) {", + f"{java_file.as_posix()}:5": "int result = a + b;", + f"{java_file.as_posix()}:6": "return result;", + f"{java_file.as_posix()}:7": "}", + }, + } + + # javaagent arg should be set on the support instance + agent_jar = find_agent_jar() + assert support.line_profiler_agent_arg == f"-javaagent:{agent_jar}=config={config_path}" + + # Warmup iterations should be stored + assert support.line_profiler_warmup_iterations == DEFAULT_WARMUP_ITERATIONS + + def test_instrument_without_package(self): + """Test instrumentation for a class without a package declaration. + + Mirrors Python's test_add_decorator_imports_nodeps — simple function with + no external dependencies. 
+ """ + source = """public class Sorter { + public static int[] sort(int[] arr) { + int n = arr.length; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n - i - 1; j++) { + if (arr[j] > arr[j + 1]) { + int temp = arr[j]; + arr[j] = arr[j + 1]; + arr[j + 1] = temp; + } } } + return arr; + } +} +""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + java_file = tmppath / "Sorter.java" + java_file.write_text(source, encoding="utf-8") - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: - json.dump(data, tmp) - profile_file = Path(tmp.name) + profile_output = tmppath / "profile.json" + + func = FunctionInfo( + function_name="sort", + file_path=java_file, + starting_line=2, + ending_line=14, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) - try: support = get_java_support() - results = support.parse_line_profile_results(profile_file) - - # Verify structure - assert "/tmp/Test.java" in results["timings"] - assert 5 in results["timings"]["/tmp/Test.java"] - assert 6 in results["timings"]["/tmp/Test.java"] - - # Verify line 5 data - line5 = results["timings"]["/tmp/Test.java"][5] - assert line5["hits"] == 100 - assert line5["time_ns"] == 5000000 - assert line5["time_ms"] == 5.0 - - # Verify line 6 is the hotspot (95% of time) - line6 = results["timings"]["/tmp/Test.java"][6] - assert line6["hits"] == 100 - assert line6["time_ns"] == 95000000 - assert line6["time_ms"] == 95.0 - - # Line 6 should be much slower - assert line6["time_ms"] > line5["time_ms"] * 10 - - finally: - profile_file.unlink() - - def test_instrument_multiple_functions(self): - """Test instrumenting multiple functions in same file.""" - source = """public class Test { - public void method1() { - int x = 1; + success = support.instrument_source_for_line_profiler(func, profile_output) + + assert success + + # Source not modified + assert java_file.read_text(encoding="utf-8") == source 
+ + config_path = profile_output.with_suffix(".config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + + assert config == { + "outputFile": str(profile_output), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "Sorter", + "methods": [ + { + "name": "sort", + "startLine": 2, + "endLine": 14, + "sourceFile": java_file.as_posix(), + } + ], + } + ], + "lineContents": { + f"{java_file.as_posix()}:2": "public static int[] sort(int[] arr) {", + f"{java_file.as_posix()}:3": "int n = arr.length;", + f"{java_file.as_posix()}:4": "for (int i = 0; i < n; i++) {", + f"{java_file.as_posix()}:5": "for (int j = 0; j < n - i - 1; j++) {", + f"{java_file.as_posix()}:6": "if (arr[j] > arr[j + 1]) {", + f"{java_file.as_posix()}:7": "int temp = arr[j];", + f"{java_file.as_posix()}:8": "arr[j] = arr[j + 1];", + f"{java_file.as_posix()}:9": "arr[j + 1] = temp;", + f"{java_file.as_posix()}:10": "}", + f"{java_file.as_posix()}:11": "}", + f"{java_file.as_posix()}:12": "}", + f"{java_file.as_posix()}:13": "return arr;", + f"{java_file.as_posix()}:14": "}", + }, + } + + def test_instrument_multiple_methods(self): + """Test instrumentation with multiple target methods in the same class. + + Mirrors Python's test_add_decorator_imports_helper_outside — multiple + functions that all need to be profiled. 
+ """ + source = """public class StringProcessor { + public static String reverse(String s) { + char[] chars = s.toCharArray(); + int left = 0; + int right = chars.length - 1; + while (left < right) { + char temp = chars[left]; + chars[left] = chars[right]; + chars[right] = temp; + left++; + right--; + } + return new String(chars); } - public void method2() { - int y = 2; + public static boolean isPalindrome(String s) { + String cleaned = s.toLowerCase().replaceAll("[^a-z0-9]", ""); + String reversed = reverse(cleaned); + return cleaned.equals(reversed); } } """ with tempfile.TemporaryDirectory() as tmpdir: tmppath = Path(tmpdir) - java_file = tmppath / "Test.java" + java_file = tmppath / "StringProcessor.java" java_file.write_text(source, encoding="utf-8") profile_output = tmppath / "profile.json" - func1 = FunctionInfo( - function_name="method1", + func_reverse = FunctionInfo( + function_name="reverse", file_path=java_file, starting_line=2, - ending_line=4, + ending_line=14, starting_col=0, ending_col=0, parents=(), @@ -155,12 +228,11 @@ def test_instrument_multiple_functions(self): is_method=True, language=Language.JAVA, ) - - func2 = FunctionInfo( - function_name="method2", + func_palindrome = FunctionInfo( + function_name="isPalindrome", file_path=java_file, - starting_line=6, - ending_line=8, + starting_line=16, + ending_line=20, starting_col=0, ending_col=0, parents=(), @@ -169,14 +241,329 @@ def test_instrument_multiple_functions(self): language=Language.JAVA, ) + support = get_java_support() # Instrument first function + success = support.instrument_source_for_line_profiler(func_reverse, profile_output) + assert success + + # Source not modified + assert java_file.read_text(encoding="utf-8") == source + + config_path = profile_output.with_suffix(".config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + + # Both methods should appear as targets when generated together + profiler = JavaLineProfiler(output_file=profile_output) + 
profiler.generate_agent_config( + source, java_file, [func_reverse, func_palindrome], config_path + ) + config = json.loads(config_path.read_text(encoding="utf-8")) + + assert config == { + "outputFile": str(profile_output), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "StringProcessor", + "methods": [ + { + "name": "reverse", + "startLine": 2, + "endLine": 14, + "sourceFile": java_file.as_posix(), + }, + { + "name": "isPalindrome", + "startLine": 16, + "endLine": 20, + "sourceFile": java_file.as_posix(), + }, + ], + } + ], + "lineContents": { + f"{java_file.as_posix()}:2": "public static String reverse(String s) {", + f"{java_file.as_posix()}:3": "char[] chars = s.toCharArray();", + f"{java_file.as_posix()}:4": "int left = 0;", + f"{java_file.as_posix()}:5": "int right = chars.length - 1;", + f"{java_file.as_posix()}:6": "while (left < right) {", + f"{java_file.as_posix()}:7": "char temp = chars[left];", + f"{java_file.as_posix()}:8": "chars[left] = chars[right];", + f"{java_file.as_posix()}:9": "chars[right] = temp;", + f"{java_file.as_posix()}:10": "left++;", + f"{java_file.as_posix()}:11": "right--;", + f"{java_file.as_posix()}:12": "}", + f"{java_file.as_posix()}:13": "return new String(chars);", + f"{java_file.as_posix()}:14": "}", + f"{java_file.as_posix()}:16": "public static boolean isPalindrome(String s) {", + f"{java_file.as_posix()}:17": 'String cleaned = s.toLowerCase().replaceAll("[^a-z0-9]", "");', + f"{java_file.as_posix()}:18": "String reversed = reverse(cleaned);", + f"{java_file.as_posix()}:19": "return cleaned.equals(reversed);", + f"{java_file.as_posix()}:20": "}", + }, + } + + def test_instrument_nested_package(self): + """Test instrumentation for a deeply nested package. + + Mirrors Python's test_add_decorator_imports_helper_in_nested_class — + verifies correct class name resolution with deep package nesting. 
+ """ + source = """package org.apache.commons.lang3; + +public class StringUtils { + public static boolean isEmpty(String s) { + return s == null || s.length() == 0; + } +} +""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + java_file = tmppath / "StringUtils.java" + java_file.write_text(source, encoding="utf-8") + + profile_output = tmppath / "profile.json" + + func = FunctionInfo( + function_name="isEmpty", + file_path=java_file, + starting_line=4, + ending_line=6, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) + support = get_java_support() - success1 = support.instrument_source_for_line_profiler(func1, profile_output) - assert success1 + success = support.instrument_source_for_line_profiler(func, profile_output) - # Re-read source and instrument second function - # Note: In real usage, you'd instrument both at once, but this tests the flow - source2 = java_file.read_text(encoding="utf-8") + assert success + + # Source not modified + assert java_file.read_text(encoding="utf-8") == source + + config_path = profile_output.with_suffix(".config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + + assert config == { + "outputFile": str(profile_output), + "warmupIterations": DEFAULT_WARMUP_ITERATIONS, + "targets": [ + { + "className": "org/apache/commons/lang3/StringUtils", + "methods": [ + { + "name": "isEmpty", + "startLine": 4, + "endLine": 6, + "sourceFile": java_file.as_posix(), + } + ], + } + ], + "lineContents": { + f"{java_file.as_posix()}:4": "public static boolean isEmpty(String s) {", + f"{java_file.as_posix()}:5": "return s == null || s.length() == 0;", + f"{java_file.as_posix()}:6": "}", + }, + } + + def test_instrument_verifies_line_contents(self): + """Test that line contents are extracted correctly, skipping comment-only lines. 
- # Write back original to test multiple instrumentations - # (In practice, the profiler instruments all functions at once) + Mirrors Python's test_add_decorator_imports_helper_in_dunder_class — + verifies that instrumentation handles all content in the function body. + """ + source = """public class Fibonacci { + public static long fib(int n) { + if (n <= 1) { + return n; + } + // iterative approach + long a = 0; + long b = 1; + for (int i = 2; i <= n; i++) { + long temp = b; + b = a + b; + a = temp; + } + return b; + } +} +""" + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + java_file = tmppath / "Fibonacci.java" + java_file.write_text(source, encoding="utf-8") + + profile_output = tmppath / "profile.json" + + func = FunctionInfo( + function_name="fib", + file_path=java_file, + starting_line=2, + ending_line=15, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) + + support = get_java_support() + success = support.instrument_source_for_line_profiler(func, profile_output) + + assert success + + config_path = profile_output.with_suffix(".config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + + line_contents = config["lineContents"] + p = java_file.as_posix() + + # Comment-only line 6 ("// iterative approach") should be excluded + assert f"{p}:6" not in line_contents + + # Code lines should be present with correct content + assert line_contents[f"{p}:2"] == "public static long fib(int n) {" + assert line_contents[f"{p}:3"] == "if (n <= 1) {" + assert line_contents[f"{p}:4"] == "return n;" + assert line_contents[f"{p}:7"] == "long a = 0;" + assert line_contents[f"{p}:9"] == "for (int i = 2; i <= n; i++) {" + assert line_contents[f"{p}:14"] == "return b;" + assert line_contents[f"{p}:15"] == "}" + + +def build_spin_timer_source(spin_durations_ns: list[int]) -> str: + """Build a SpinTimer Java source that calls spinWait with each given duration.""" + calls = 
"\n".join(f" spinWait({d}L);" for d in spin_durations_ns) + return f"""\ +public class SpinTimer {{ + public static long spinWait(long durationNs) {{ + long start = System.nanoTime(); + while (System.nanoTime() - start < durationNs) {{ + }} + return durationNs; + }} + + public static void main(String[] args) {{ +{calls} + }} +}} +""" + + +def run_spin_timer_profiled(tmppath: Path, spin_durations_ns: list[int]) -> dict: + """Compile and run SpinTimer with the profiler agent, return parsed results.""" + source = build_spin_timer_source(spin_durations_ns) + java_file = tmppath / "SpinTimer.java" + java_file.write_text(source, encoding="utf-8") + + profile_output = tmppath / "profile.json" + config_path = profile_output.with_suffix(".config.json") + + func = FunctionInfo( + function_name="spinWait", + file_path=java_file, + starting_line=2, + ending_line=7, + starting_col=0, + ending_col=0, + parents=(), + is_async=False, + is_method=True, + language=Language.JAVA, + ) + + profiler = JavaLineProfiler(output_file=profile_output, warmup_iterations=0) + profiler.generate_agent_config(source, java_file, [func], config_path) + agent_arg = profiler.build_javaagent_arg(config_path) + + result = subprocess.run( + ["javac", str(java_file)], + capture_output=True, + text=True, + cwd=str(tmppath), + ) + assert result.returncode == 0, f"javac failed: {result.stderr}" + + result = subprocess.run( + ["java", agent_arg, "-cp", str(tmppath), "SpinTimer"], + capture_output=True, + text=True, + cwd=str(tmppath), + timeout=30, + ) + assert result.returncode == 0, f"java failed: {result.stderr}" + assert profile_output.exists(), "Profile output not written" + + return JavaLineProfiler.parse_results(profile_output) + + +@pytest.mark.skipif(not shutil.which("javac"), reason="Java compiler not available") +class TestSpinTimerProfiling: + """End-to-end spin-timer tests validating profiler timing accuracy. 
+ + Calls spinWait multiple times with known durations, then verifies the + profiler-reported total time matches the expected sum of all spin durations. + """ + + @pytest.mark.parametrize( + "spin_durations_ns", + [ + [50_000_000, 100_000_000], + [30_000_000, 40_000_000, 80_000_000], + ], + ) + def test_total_time_matches_expected(self, spin_durations_ns): + """Profiler total time should match the sum of all spin durations.""" + expected_ns = sum(spin_durations_ns) + + with tempfile.TemporaryDirectory() as tmpdir: + results = run_spin_timer_profiled(Path(tmpdir), spin_durations_ns) + + assert results["timings"], "No timing data produced" + + line_data = next(iter(results["timings"].values())) + total_time_ns = sum(t for _, _, t in line_data) + + assert math.isclose(total_time_ns, expected_ns, rel_tol=0.25), ( + f"Measured {total_time_ns}ns, expected ~{expected_ns}ns (25% tolerance)" + ) + + def test_while_line_dominates(self): + """The while-loop line should account for the majority of self-time.""" + with tempfile.TemporaryDirectory() as tmpdir: + results = run_spin_timer_profiled(Path(tmpdir), [50_000_000, 100_000_000]) + + assert results["timings"] + + line_data = next(iter(results["timings"].values())) + line_times = {lineno: t for lineno, _, t in line_data} + total_time = sum(line_times.values()) + + while_line_time = line_times.get(4, 0) + assert while_line_time / total_time > 0.80, ( + f"While line has {while_line_time / total_time:.1%} of total time, expected >80%" + ) + + def test_hit_counts_match_call_count(self): + """Each line in spinWait should have hits equal to the number of calls.""" + spin_durations = [20_000_000, 30_000_000, 50_000_000] + + with tempfile.TemporaryDirectory() as tmpdir: + results = run_spin_timer_profiled(Path(tmpdir), spin_durations) + + assert results["timings"] + + line_data = next(iter(results["timings"].values())) + line_hits = {lineno: h for lineno, h, _ in line_data} + + # Lines 3 and 6 (start assignment and return) execute 
once per call + assert line_hits.get(3, 0) == len(spin_durations), ( + f"Line 3 hits: {line_hits.get(3, 0)}, expected {len(spin_durations)}" + )