⚡️ Speed up function wrap_target_calls_with_treesitter by 20% in PR #1596 (codeflash/optimize-pr1580-2026-02-20T10.00.27)#1598

Closed
codeflash-ai[bot] wants to merge 2 commits into fix/java-direct-jvm-and-bugs from codeflash/optimize-pr1596-2026-02-20T10.23.33

@codeflash-ai codeflash-ai bot commented Feb 20, 2026

⚡️ This pull request contains optimizations for PR #1596

If you approve this dependent PR, these changes will be merged into the original PR branch codeflash/optimize-pr1580-2026-02-20T10.00.27.

This PR will be automatically closed if the original PR is merged.


📄 20% (0.20x) speedup for wrap_target_calls_with_treesitter in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 32.9 milliseconds → 27.4 milliseconds (best of 52 runs)
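
The headline percentage appears consistent with measuring the time saved relative to the optimized runtime. The sketch below works through that arithmetic; the helper name `speedup` and the formula are assumptions inferred from the numbers reported here, not taken from the Codeflash source.

```python
def speedup(original: float, optimized: float) -> float:
    """Relative speedup, assuming it is defined as (original - optimized) / optimized."""
    return (original - optimized) / optimized


# Runtimes reported above, in milliseconds.
headline = speedup(32.9, 27.4)
print(f"{headline:.0%} ({headline:.2f}x)")
```

Per-test percentages in the listing below may differ from this formula in the last digit, presumably because the displayed runtimes are rounded.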

📝 Explanation and details

Optimization Explanation:

Profiling shows that `_collect_calls` consumes 75% of the total execution time, with significant overhead from repeated calls to `_is_inside_lambda` and `_is_inside_complex_expression` (together about 22% of `_collect_calls` time). Each of these functions traverses the AST upward for every matched node. I've optimized this by computing the parent chain flags once during collection instead of storing them in the call dictionary, and by precomputing the `line.encode("utf8")` results that were being recomputed repeatedly in loops. Regex compilation stays at module level (this was already done). Finally, I've replaced the `any()` iteration in `_infer_array_cast_type` with a simple early-exit loop, which is faster for the common case of no match.
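
A minimal sketch of two of the ideas above, under stated assumptions: the `Node` dataclass is a hypothetical stand-in for a tree-sitter node (only `.type` and `.parent`), the node-type sets are abbreviated for illustration, and the helper names `should_skip`, `line_byte_starts`, and `byte_to_line_index` are invented here rather than copied from `instrumentation.py`.

```python
from __future__ import annotations

from bisect import bisect_right
from dataclasses import dataclass
from typing import Optional

# Abbreviated, illustrative node-type sets (assumptions, not the full lists).
STATEMENT_BOUNDARIES = {"block", "expression_statement", "method_declaration"}
COMPLEX_EXPRESSIONS = {"cast_expression", "ternary_expression", "array_access"}


@dataclass
class Node:
    """Hypothetical stand-in for a tree-sitter node."""
    type: str
    parent: Optional["Node"] = None


def should_skip(node: Node) -> bool:
    """Answer the lambda and complex-expression questions in one upward walk."""
    current = node.parent
    while current is not None:
        if current.type in STATEMENT_BOUNDARIES:
            return False  # reached a statement boundary first
        if current.type == "lambda_expression" or current.type in COMPLEX_EXPRESSIONS:
            return True
        current = current.parent
    return False


def line_byte_starts(lines: list[str]) -> list[int]:
    """Encode each line once and record the byte offset where it starts."""
    starts, offset = [], 0
    for line in lines:
        starts.append(offset)
        offset += len(line.encode("utf8")) + 1  # +1 for the joining '\n'
    return starts


def byte_to_line_index(byte_offset: int, starts: list[int]) -> int:
    """Map a byte offset to its line index via binary search over the starts."""
    return bisect_right(starts, byte_offset) - 1
```

Merging the two ancestor checks means each matched call walks its parent chain at most once, and the byte offsets are computed once per body rather than once per lookup.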

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 89 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 97.4% |
🌀 Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.languages.java import \
    instrumentation as instr  # module under test
from codeflash.languages.java.instrumentation import \
    wrap_target_calls_with_treesitter

def test_returns_copy_and_zero_when_func_not_present_simple():
    # Basic case: a couple of simple lines without the target function name.
    body = ["int x = 1;", "x++;", "System.out.println(x);"]
    # Call function with a func_name that does not occur in the body.
    wrapped, count = instr.wrap_target_calls_with_treesitter(body, func_name="nonexistentFunc", iter_id=0) # 3.19μs -> 3.21μs (0.624% slower)

def test_empty_body_lines_returns_empty_and_zero():
    # Edge case: empty input list should be handled gracefully.
    body = []
    wrapped, count = instr.wrap_target_calls_with_treesitter(body, func_name="anything", iter_id=1) # 2.60μs -> 2.52μs (3.17% faster)

def test_unicode_and_non_ascii_lines_preserved():
    # Ensure UTF-8 / multi-byte characters do not break the early-return path.
    body = ["// comment with emoji 😊", 'String s = "café";', "final int π = 3;"]
    wrapped, count = instr.wrap_target_calls_with_treesitter(body, func_name="missingFunc", iter_id=2) # 3.50μs -> 3.40μs (2.95% faster)
    # Make sure modifying the return does not affect the original (copied list).
    wrapped.append("new line")

def test__byte_to_line_index_basic_and_boundaries():
    # Build line byte starts the same way the function under test does.
    lines = ["a", "bb", "ccc"]
    line_byte_starts = []
    offset = 0
    for line in lines:
        line_byte_starts.append(offset)
        # Each joined line is followed by a '\n' when computing offsets in the code.
        offset += len(line.encode("utf8")) + 1

def test__byte_to_line_index_with_multibyte_characters():
    # Non-ASCII characters affect byte lengths; ensure mapping still correct.
    lines = ["λ", "汉字", "emoji 😊"]
    line_byte_starts = []
    offset = 0
    for line in lines:
        line_byte_starts.append(offset)
        offset += len(line.encode("utf8")) + 1
    # offset at start of second line:
    start_second = line_byte_starts[1]
    # offset in third line:
    start_third = line_byte_starts[2]

def test__infer_array_cast_type_detects_primitive_arrays_and_none():
    # Positive matches for assertArrayEquals with primitive array patterns
    line_int = "assertArrayEquals(new int[] , result);"

    line_long = "   assertArrayEquals( new    long[]   , foo());"

    # assertArrayNotEquals should also be recognized
    line_not_equals = "assertArrayNotEquals(new short[] , something);"

    # If no assertion method name, return None even if "new int[]" appears
    line_no_assert = "someMethod(new int[]);"

    # If assertion exists but no primitive 'new TYPE[]' pattern, return None
    line_assert_but_object_array = "assertArrayEquals(new String[] {\"a\"}, got);"

def test_wrap_target_calls_with_treesitter_large_scale_no_func_present():
    # Large-scale test: create 1000 lines without the function name to ensure performance and correctness.
    body = [f"int v{idx} = {idx};" for idx in range(1000)]
    wrapped, count = instr.wrap_target_calls_with_treesitter(body, func_name="targetFunc", iter_id=42) # 19.2μs -> 18.6μs (2.85% faster)
    # Mutating returned should not affect original.
    wrapped.append("extra")

def test__infer_array_cast_type_edge_cases():
    # Edge: whitespace variations inside brackets should still match
    line_space = "assertArrayEquals(new    byte[   ]  , x);"

    # Edge: assertion name appears as substring elsewhere - only exact method names are checked via 'in'
    line_false_pos = "someComment assertArrayEqualsInText(new int[] , x);"

    # Edge: different casing -> should not match (case-sensitive)
    line_case = "ASSERTARRAYEQUALS(new int[] , x);"
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from codeflash.languages.java.instrumentation import \
    wrap_target_calls_with_treesitter
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer

class TestBasicFunctionality:
    """Test basic, normal-case usage of wrap_target_calls_with_treesitter."""

    def test_no_matching_calls_returns_unchanged_body(self):
        """When func_name is not in body, return body unchanged with zero call count."""
        body_lines = ["int x = 5;", "return x;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "otherFunc", iter_id=1) # 3.04μs -> 2.99μs (1.37% faster)

    def test_single_simple_method_call_in_expression_statement(self):
        """Single method call as standalone expression statement is instrumented."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 58.4μs -> 56.3μs (3.84% faster)

    def test_method_call_with_arguments(self):
        """Method call with arguments is properly captured."""
        body_lines = ["int x = getValue(42);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 60.9μs -> 58.6μs (3.93% faster)

    def test_multiple_calls_same_line(self):
        """Multiple calls to same function on same line are both captured."""
        body_lines = ["foo(foo());"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 62.7μs -> 56.3μs (11.3% faster)

    def test_calls_across_multiple_lines(self):
        """Calls on different lines are all captured with correct numbering."""
        body_lines = ["foo();", "bar();", "foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 65.7μs -> 61.7μs (6.56% faster)

    def test_call_in_return_statement(self):
        """Call in return statement is instrumented correctly."""
        body_lines = ["return getValue();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 48.4μs -> 46.8μs (3.36% faster)

    def test_call_in_variable_assignment(self):
        """Call in variable assignment captures the result."""
        body_lines = ["int x = compute(5);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "compute", iter_id=1) # 54.0μs -> 53.2μs (1.64% faster)

    def test_chained_method_calls(self):
        """Chained method calls like obj.method().method() are instrumented."""
        body_lines = ["result = obj.getValue().toString();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=2) # 61.8μs -> 59.1μs (4.45% faster)

    def test_precise_call_timing_flag(self):
        """When precise_call_timing=True, timing statements are added."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(
            body_lines, "foo", iter_id=1, precise_call_timing=True
        ) # 46.4μs -> 45.1μs (2.89% faster)
        result_text = " ".join(result)

    def test_iter_id_affects_variable_names(self):
        """Different iter_id values produce different variable name prefixes."""
        body_lines = ["foo();"]
        result1, _ = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 45.5μs -> 42.4μs (7.42% faster)
        result2, _ = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=2) # 31.3μs -> 31.2μs (0.578% faster)
        result1_text = " ".join(result1)
        result2_text = " ".join(result2)

class TestEdgeCases:
    """Test edge cases and boundary conditions."""

    def test_empty_body_lines(self):
        """Empty body returns empty list with zero call count."""
        body_lines = []
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 2.64μs -> 2.64μs (0.000% faster)

    def test_single_empty_line(self):
        """Single line containing only whitespace."""
        body_lines = ["   "]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 2.57μs -> 2.46μs (4.46% faster)

    def test_func_name_as_substring_not_matched(self):
        """Function name that is substring of another name should not match."""
        body_lines = ["fooBar();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 36.3μs -> 36.4μs (0.247% slower)

    def test_func_name_in_string_literal_ignored(self):
        """Function name appearing in string literal should not be captured."""
        body_lines = ['String msg = "call foo()";']
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 36.8μs -> 36.8μs (0.133% slower)

    def test_func_name_in_comment_ignored(self):
        """Function name in comment should not be captured."""
        body_lines = ["// This calls foo()", "bar();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 35.1μs -> 35.4μs (0.904% slower)

    def test_zero_iter_id(self):
        """iter_id of 0 should work and produce valid variable names."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=0) # 47.9μs -> 45.5μs (5.14% faster)

    def test_large_iter_id(self):
        """Large iter_id values should work correctly."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=999999) # 45.8μs -> 44.4μs (3.25% faster)

    def test_call_with_no_arguments(self):
        """Method call with empty parentheses is captured."""
        body_lines = ["getValue();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 44.6μs -> 43.5μs (2.70% faster)

    def test_call_with_null_argument(self):
        """Method call with null argument is captured."""
        body_lines = ["process(null);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 46.9μs -> 45.7μs (2.54% faster)

    def test_call_with_string_argument_containing_parens(self):
        """Method call with string containing parentheses."""
        body_lines = ['foo("text with (parens)");']
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 49.0μs -> 46.7μs (5.04% faster)

    def test_lines_with_leading_whitespace(self):
        """Lines with indentation are preserved and handled correctly."""
        body_lines = ["    foo();", "        bar();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 53.3μs -> 51.5μs (3.54% faster)

    def test_lines_with_tabs(self):
        """Lines with tab indentation are handled correctly."""
        body_lines = ["\tfoo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 44.3μs -> 43.0μs (2.98% faster)

    def test_very_long_method_name(self):
        """Long method names are matched correctly."""
        long_name = "verylongmethodnamewithlotsofcharacters" * 2
        body_lines = [f"{long_name}();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, long_name, iter_id=1) # 47.8μs -> 44.5μs (7.47% faster)

    def test_special_characters_in_method_arguments(self):
        """Method call with special character arguments."""
        body_lines = ['foo("@#$%");']
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 48.0μs -> 47.4μs (1.29% faster)

    def test_method_call_in_if_condition(self):
        """Method call inside if condition is captured."""
        body_lines = ["if (check()) {", "    doSomething();", "}"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "check", iter_id=1) # 52.7μs -> 51.7μs (1.86% faster)

    def test_method_call_in_while_condition(self):
        """Method call in while loop condition."""
        body_lines = ["while (hasNext()) {", "    process();", "}"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "hasNext", iter_id=1) # 50.4μs -> 50.3μs (0.117% faster)

    def test_method_call_in_for_loop(self):
        """Method call in for loop condition."""
        body_lines = ["for (int i = 0; i < count(); i++) {", "    doWork();", "}"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "count", iter_id=1) # 70.5μs -> 69.4μs (1.59% faster)

    def test_nested_method_calls(self):
        """Nested method calls are all captured."""
        body_lines = ["result = outer(inner(5));"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "inner", iter_id=1) # 58.8μs -> 56.3μs (4.49% faster)

    def test_method_call_with_array_argument(self):
        """Method call with array literal argument."""
        body_lines = ["process(new int[]{1, 2, 3});"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 62.1μs -> 60.5μs (2.56% faster)

    def test_method_call_with_lambda_argument(self):
        """Method call with lambda expression argument."""
        body_lines = ["execute(() -> foo());"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "execute", iter_id=1) # 53.6μs -> 52.2μs (2.65% faster)

    def test_empty_function_name(self):
        """Empty function name returns unchanged body."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "", iter_id=1) # 32.5μs -> 32.7μs (0.642% slower)

    def test_func_name_with_spaces(self):
        """Function names with spaces are not valid and should not match."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo bar", iter_id=1) # 2.49μs -> 2.58μs (3.52% slower)

    def test_call_in_ternary_expression(self):
        """Call inside ternary expression (complex expression context)."""
        body_lines = ["result = condition ? getValue() : null;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 54.3μs -> 50.7μs (7.07% faster)

    def test_call_in_cast_expression(self):
        """Call inside cast expression."""
        body_lines = ["Object obj = (Object) getValue();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 54.4μs -> 52.7μs (3.25% faster)

    def test_call_in_binary_operation(self):
        """Call as operand in binary operation."""
        body_lines = ["int sum = getValue() + 5;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 50.0μs -> 48.7μs (2.63% faster)

    def test_multiline_method_invocation(self):
        """Method call split across multiple lines."""
        body_lines = [
            "result = getValue(",
            "    arg1,",
            "    arg2",
            ");",
        ]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 57.8μs -> 55.2μs (4.63% faster)

    def test_static_method_call(self):
        """Static method call like Math.abs() is captured."""
        body_lines = ["int x = Math.abs(-5);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "abs", iter_id=1) # 62.1μs -> 60.1μs (3.28% faster)

    def test_instance_method_call(self):
        """Instance method call like obj.method() is captured."""
        body_lines = ["String s = obj.toString();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "toString", iter_id=1) # 56.3μs -> 54.1μs (4.02% faster)

    def test_super_method_call(self):
        """Call to superclass method via super."""
        body_lines = ["super.initialize();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "initialize", iter_id=1) # 48.1μs -> 46.9μs (2.58% faster)

    def test_this_method_call(self):
        """Call via this."""
        body_lines = ["this.initialize();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "initialize", iter_id=1) # 47.6μs -> 46.0μs (3.51% faster)

class TestLargeScale:
    """Test performance and scalability with large inputs."""

    def test_many_lines_no_calls(self):
        """Large body with no matching calls returns quickly."""
        body_lines = [f"int x{i} = {i};" for i in range(1000)]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 24.6μs -> 25.0μs (1.32% slower)

    def test_many_calls_same_line(self):
        """Single line with many calls to same function."""
        # Build a line with multiple calls: foo(foo(foo(...)))
        call_chain = "foo("
        for i in range(10):
            call_chain += "foo("
        for i in range(10):
            call_chain += ")"
        call_chain += ")"
        body_lines = [call_chain]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 314μs -> 229μs (36.8% faster)

    def test_many_calls_multiple_lines(self):
        """Many lines each with a single call."""
        body_lines = [f"func{i}();" if i % 2 == 0 else f"process();" for i in range(500)]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 3.82ms -> 3.39ms (12.7% faster)

    def test_mixed_calls_large_file(self):
        """Large body with mixed calls to different functions."""
        body_lines = []
        for i in range(100):
            body_lines.append(f"int x{i} = getValue{i}();")
            body_lines.append(f"process{i}();")
            body_lines.append(f"int y{i} = getValue{i}();")
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue1", iter_id=1) # 2.05ms -> 2.06ms (0.136% slower)

    def test_deeply_nested_expressions(self):
        """Deeply nested method calls."""
        nested = "getValue()"
        for i in range(50):
            nested = f"process({nested})"
        body_lines = [nested + ";"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 7.69ms -> 4.06ms (89.3% faster)

    def test_many_different_methods_single_line(self):
        """Single line calling many different methods."""
        body_lines = [" ".join([f"m{i}();" for i in range(100)])]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "m50", iter_id=1) # 472μs -> 473μs (0.269% slower)

    def test_large_body_with_comments(self):
        """Large body with many comments and calls."""
        body_lines = []
        for i in range(500):
            body_lines.append(f"// Comment {i}")
            if i % 3 == 0:
                body_lines.append(f"foo();")
            body_lines.append(f"int x{i} = getValue{i}();")
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 6.32ms -> 6.04ms (4.66% faster)

    def test_large_indentation_levels(self):
        """Body with many levels of indentation."""
        body_lines = []
        indent = ""
        for level in range(100):
            indent += "    "
            body_lines.append(f"{indent}if (cond{level}) {{")
            body_lines.append(f"{indent}    foo();")
        body_lines.append("    " * 100 + "}")
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 2.74ms -> 2.48ms (10.4% faster)

    def test_return_statements_bulk(self):
        """Many return statements with method calls."""
        body_lines = []
        for i in range(100):
            if i % 2 == 0:
                body_lines.append(f"if (cond{i}) return getValue{i}();")
            else:
                body_lines.append(f"return process{i}();")
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue50", iter_id=1) # 742μs -> 750μs (1.02% slower)

    def test_assignment_bulk(self):
        """Many assignment statements with calls."""
        body_lines = [f"int x{i} = getValue();" for i in range(200)]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 2.97ms -> 2.54ms (17.0% faster)

    def test_try_catch_with_calls(self):
        """Calls inside try-catch blocks."""
        body_lines = [
            "try {",
            "    foo();",
            "    bar();",
            "} catch (Exception e) {",
            "    foo();",
            "    handleError();",
            "}",
        ]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 98.2μs -> 90.7μs (8.24% faster)

    def test_switch_with_calls(self):
        """Calls inside switch statement cases."""
        body_lines = [
            "switch (value) {",
            "    case 1: foo(); break;",
            "    case 2: foo(); break;",
            "    default: bar();",
            "}",
        ]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 100μs -> 92.6μs (8.41% faster)

    def test_long_lines_with_multiple_calls(self):
        """Single very long line with multiple method calls."""
        long_line = "; ".join([f"foo()" for _ in range(20)]) + ";"
        body_lines = [long_line]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 236μs -> 204μs (15.8% faster)

    def test_lines_with_very_long_arguments(self):
        """Method calls with very long argument lists."""
        args = ", ".join([f"arg{i}" for i in range(100)])
        body_lines = [f"process({args});"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 229μs -> 231μs (0.801% slower)

    def test_unicode_in_comments_and_strings(self):
        """Body with unicode characters in comments and strings."""
        body_lines = [
            '// Comment with unicode: 你好 🚀',
            'String s = "Unicode string: 日本語";',
            "foo();",
            "// More unicode: Ñoño",
        ]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 67.1μs -> 65.5μs (2.49% faster)

    def test_sequential_calls_same_method(self):
        """Many sequential calls to same method."""
        body_lines = ["foo();"] * 100
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 1.03ms -> 855μs (20.4% faster)

    def test_alternating_different_methods(self):
        """Alternating calls between two methods."""
        body_lines = []
        for i in range(100):
            body_lines.append("foo();" if i % 2 == 0 else "bar();")
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 751μs -> 667μs (12.6% faster)

class TestSpecialPatterns:
    """Test special patterns and real-world scenarios."""

    def test_assertion_methods_with_primitive_arrays(self):
        """Assertion method with primitive array cast inference."""
        body_lines = ["assertEquals(new int[]{1, 2}, getValue());"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 71.8μs -> 69.8μs (2.94% faster)
        # The variable cast should include int[] type based on the assertion context
        result_text = " ".join(result)

    def test_call_result_used_multiple_times(self):
        """When result is used multiple times, single capture stores it."""
        body_lines = ["if (getValue() > 0 && getValue() < 100) { }"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 66.9μs -> 61.5μs (8.81% faster)

    def test_complex_receiver_chain(self):
        """Complex receiver chain like obj.getService().getMethod()."""
        body_lines = ["result = obj.getService().getMethod();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getMethod", iter_id=1) # 63.2μs -> 61.0μs (3.56% faster)

    def test_call_with_this_receiver(self):
        """Method call with explicit this receiver."""
        body_lines = ["this.getValue();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 48.8μs -> 47.8μs (2.12% faster)

    def test_generic_method_call(self):
        """Generic method call with type parameters."""
        body_lines = ["List<String> items = service.<String>get();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "get", iter_id=1) # 75.7μs -> 74.2μs (1.96% faster)

    def test_varargs_method_call(self):
        """Method call with variable arguments."""
        body_lines = ["process(1, 2, 3, 4, 5);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 56.8μs -> 55.7μs (1.98% faster)

    def test_spread_operator_in_method_call(self):
        """Method call with an array variable argument (Java has no spread operator)."""
        body_lines = ["process(array);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 47.4μs -> 45.8μs (3.39% faster)

    def test_call_in_object_initialization(self):
        """Call in object initializer block."""
        body_lines = [
            "Object obj = new Object() {",
            "    { init(); }",
            "};",
        ]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "init", iter_id=1) # 66.5μs -> 62.7μs (6.05% faster)

    def test_call_result_immediately_cast(self):
        """Call result immediately cast to another type."""
        body_lines = ["String s = (String) getValue();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 53.4μs -> 51.9μs (2.93% faster)

    def test_call_in_assertion(self):
        """Method call inside assertion statement."""
        body_lines = ["assert getValue() > 0 : \"Value must be positive\";"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getValue", iter_id=1) # 51.4μs -> 49.7μs (3.38% faster)

    def test_call_followed_by_semicolon_alternatives(self):
        """Different statement terminators and whitespace."""
        body_lines = ["foo() ; ", "bar(  )  ;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 53.0μs -> 51.4μs (3.14% faster)

    def test_method_reference_not_invocation(self):
        """Method reference (:: operator) should not be treated as invocation."""
        body_lines = ["Consumer<String> c = String::valueOf;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "valueOf", iter_id=1) # 54.0μs -> 53.3μs (1.16% faster)

    def test_multiple_statements_per_line_with_semicolons(self):
        """Multiple statements on same line separated by semicolons."""
        body_lines = ["int x = 5; foo(); int y = 10;"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 62.6μs -> 60.6μs (3.39% faster)

    def test_escaped_characters_in_string_arguments(self):
        """Method call with escaped characters in string arguments."""
        body_lines = ['foo("line1\\nline2\\t\\ttab");']
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 54.6μs -> 53.3μs (2.31% faster)

    def test_call_with_long_numeric_literal(self):
        """Method call with long numeric literal argument."""
        body_lines = ["process(999999999999999999L);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "process", iter_id=1) # 47.5μs -> 46.0μs (3.22% faster)

    def test_call_with_float_literal(self):
        """Method call with float literal argument."""
        body_lines = ["calculate(3.14159f);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "calculate", iter_id=1) # 46.5μs -> 44.5μs (4.48% faster)

    def test_call_with_boolean_literal(self):
        """Method call with boolean literal argument."""
        body_lines = ["setValue(true);", "setValue(false);"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "setValue", iter_id=1) # 62.4μs -> 59.0μs (5.66% faster)

    def test_call_in_enhanced_for_loop(self):
        """Method call in enhanced for loop."""
        body_lines = ["for (String item : getItems()) { }"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getItems", iter_id=1) # 55.2μs -> 53.6μs (2.95% faster)

    def test_call_in_resource_try(self):
        """Method call in try-with-resources."""
        body_lines = ["try (Reader r = getReader()) { }"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "getReader", iter_id=1) # 57.6μs -> 55.0μs (4.70% faster)

    def test_serializer_import_present_in_output(self):
        """Generated code references Serializer class."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 44.8μs -> 43.4μs (3.21% faster)
        result_text = " ".join(result)

    def test_system_nanotime_in_precise_timing(self):
        """Precise timing uses System.nanoTime()."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(
            body_lines, "foo", iter_id=1, precise_call_timing=True
        ) # 44.9μs -> 43.9μs (2.33% faster)
        result_text = " ".join(result)

    def test_variable_naming_uniqueness(self):
        """Each captured call gets unique variable name."""
        body_lines = ["foo(); foo(); foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=1) # 68.4μs -> 63.0μs (8.58% faster)
        result_text = " ".join(result)

    def test_serialized_result_variable_naming(self):
        """Serialized result variables are properly named."""
        body_lines = ["foo();"]
        result, call_count = wrap_target_calls_with_treesitter(body_lines, "foo", iter_id=5) # 43.9μs -> 42.7μs (2.87% faster)
        result_text = " ".join(result)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-pr1596-2026-02-20T10.23.33`, commit, and push.


**Optimization Explanation:**

The profiling reveals that `_collect_calls` consumes 75% of the total execution time, with significant overhead from repeated calls to `_is_inside_lambda` and `_is_inside_complex_expression` (combining for ~22% of `_collect_calls` time). These functions traverse the AST upward for every matched node. I've optimized this by computing parent chain flags once during collection instead of storing them in the call dictionary, and by precomputing `line.encode("utf8")` operations that were being called repeatedly in loops. Additionally, I've moved regex compilation to module level (already done) and eliminated redundant `any()` iteration in `_infer_array_cast_type` by using early-exit short-circuit evaluation with a simple loop that's faster for the common case of no match.
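
A minimal sketch of the two micro-optimizations named above, not the code from this PR. The helper names `line_byte_offsets` and `contains_any` are hypothetical, and the assumption that lines are joined with a single newline before parsing is mine; only the idea of calling `line.encode("utf8")` once per line and replacing `any()` with an early-exit loop comes from the explanation.

```python
from itertools import accumulate


def line_byte_offsets(lines):
    """Encode each line once; return (encoded_lines, start_byte_of_each_line)."""
    encoded = [line.encode("utf8") for line in lines]
    # Assumption: lines are joined with a single newline before parsing,
    # so each line occupies len(bytes) + 1 bytes in the parsed buffer.
    ends = list(accumulate(len(b) + 1 for b in encoded))
    starts = [0, *ends[:-1]] if encoded else []
    return encoded, starts


def contains_any(text, needles):
    """Plain loop with early exit; returns on the first hit."""
    for needle in needles:
        if needle in text:
            return True
    return False


body_lines = ["foo();", 'bar("é");']
encoded, starts = line_byte_offsets(body_lines)
print(starts)  # [0, 7]
```

The encoded lines and offsets can then be reused by every loop that needs to map a byte offset back to a line, instead of re-encoding inside each loop. The loop in `contains_any` returns the same result as `any()` over a generator; whether it is measurably faster depends on the workload.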
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Feb 20, 2026
Comment on lines +1202 to +1238
def _should_skip_instrumentation(node: Any) -> bool:
    """Check if a node should skip instrumentation (in lambda or complex expression)."""
    current = node.parent
    while current is not None:
        node_type = current.type

        # Stop at statement boundaries
        if node_type in {
            "method_declaration",
            "block",
            "if_statement",
            "for_statement",
            "while_statement",
            "try_statement",
            "expression_statement",
        }:
            return False

        # Lambda check
        if node_type == "lambda_expression":
            return True

        # Complex expression check
        if node_type in {
            "cast_expression",
            "ternary_expression",
            "array_access",
            "binary_expression",
            "unary_expression",
            "parenthesized_expression",
            "instanceof_expression",
        }:
            logger.debug("Found complex expression parent: %s", node_type)
            return True

        current = current.parent
    return False

Bug: Lambda detection broken for block-bodied lambdas

The combined _should_skip_instrumentation checks stop boundaries (including expression_statement, block) before checking for lambda_expression. This means for common Java patterns like:

list.forEach(item -> {
    targetMethod(item);  // expression_statement → block → lambda_expression
});

The walk hits expression_statement first and returns False, failing to detect that the call is inside a lambda.

The original _is_inside_lambda only stopped at method_declaration, allowing it to walk through blocks/statements to find enclosing lambdas. The combined function inherits the aggressive stop boundaries from _is_inside_complex_expression, which breaks lambda detection.

Fix: Check for lambda_expression before checking stop boundaries and, when a stop boundary is reached, keep looking for an enclosing lambda; alternatively, separate the lambda and complex-expression checks:

Suggested change
-def _should_skip_instrumentation(node: Any) -> bool:
-    """Check if a node should skip instrumentation (in lambda or complex expression)."""
-    current = node.parent
-    while current is not None:
-        node_type = current.type
-        # Stop at statement boundaries
-        if node_type in {
-            "method_declaration",
-            "block",
-            "if_statement",
-            "for_statement",
-            "while_statement",
-            "try_statement",
-            "expression_statement",
-        }:
-            return False
-        # Lambda check
-        if node_type == "lambda_expression":
-            return True
-        # Complex expression check
-        if node_type in {
-            "cast_expression",
-            "ternary_expression",
-            "array_access",
-            "binary_expression",
-            "unary_expression",
-            "parenthesized_expression",
-            "instanceof_expression",
-        }:
-            logger.debug("Found complex expression parent: %s", node_type)
-            return True
-        current = current.parent
-    return False
+def _should_skip_instrumentation(node: Any) -> bool:
+    """Check if a node should skip instrumentation (in lambda or complex expression)."""
+    current = node.parent
+    while current is not None:
+        node_type = current.type
+        # Lambda check must come before stop boundaries since lambdas contain blocks
+        if node_type == "lambda_expression":
+            return True
+        # Stop at statement boundaries (for complex expression check only)
+        if node_type in {
+            "method_declaration",
+            "block",
+            "if_statement",
+            "for_statement",
+            "while_statement",
+            "try_statement",
+            "expression_statement",
+        }:
+            # Still need to check for lambdas above this point
+            return _is_inside_lambda(node)
+        # Complex expression check
+        if node_type in {
+            "cast_expression",
+            "ternary_expression",
+            "array_access",
+            "binary_expression",
+            "unary_expression",
+            "parenthesized_expression",
+            "instanceof_expression",
+        }:
+            logger.debug("Found complex expression parent: %s", node_type)
+            return True
+        current = current.parent
+    return False
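
The ordering problem can be reproduced without tree-sitter. The sketch below is illustrative only: `FakeNode` and `chain` are hypothetical stand-ins for tree-sitter nodes (only `.type` and `.parent` are modeled), the node-type sets are copied from the function under review, and `is_inside_lambda` is written from the description of the original helper (stop only at `method_declaration`), not copied from the repository.

```python
from dataclasses import dataclass
from typing import Optional

STOP = frozenset({
    "method_declaration", "block", "if_statement", "for_statement",
    "while_statement", "try_statement", "expression_statement",
})
COMPLEX = frozenset({
    "cast_expression", "ternary_expression", "array_access", "binary_expression",
    "unary_expression", "parenthesized_expression", "instanceof_expression",
})


@dataclass
class FakeNode:  # hypothetical stand-in for a tree-sitter node
    type: str
    parent: Optional["FakeNode"] = None


def chain(*types):
    """Build a parent chain from outermost to innermost; return the innermost node."""
    node = None
    for node_type in types:
        node = FakeNode(node_type, node)
    return node


def is_inside_lambda(node):
    """Walk up until a lambda is found or the enclosing method is reached."""
    current = node.parent
    while current is not None:
        if current.type == "lambda_expression":
            return True
        if current.type == "method_declaration":
            return False
        current = current.parent
    return False


def skip_boundaries_first(node):
    """Check order from the PR: stop boundaries are tested before the lambda."""
    current = node.parent
    while current is not None:
        if current.type in STOP:
            return False
        if current.type == "lambda_expression" or current.type in COMPLEX:
            return True
        current = current.parent
    return False


def skip_with_fallback(node):
    """Check order from the suggestion: lambda first, fallback walk at boundaries."""
    current = node.parent
    while current is not None:
        if current.type == "lambda_expression":
            return True
        if current.type in STOP:
            return is_inside_lambda(node)
        if current.type in COMPLEX:
            return True
        current = current.parent
    return False


# list.forEach(item -> { targetMethod(item); });
call = chain(
    "method_declaration", "block", "expression_statement", "method_invocation",
    "argument_list", "lambda_expression", "block", "expression_statement",
    "method_invocation",
)
print(skip_boundaries_first(call), skip_with_fallback(call))  # False True
```

The fallback costs a second upward walk whenever a statement boundary is reached, which gives back part of the saving that motivated merging the two helpers.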

Comment on lines +187 to 189
# Group non-lambda and non-complex-expression calls by their line index

# Group non-lambda and non-complex-expression calls by their line index

Duplicate comment. Remove one of these lines.

Suggested change
-# Group non-lambda and non-complex-expression calls by their line index
-
-# Group non-lambda and non-complex-expression calls by their line index
+# Group non-lambda and non-complex-expression calls by their line index

Base automatically changed from codeflash/optimize-pr1580-2026-02-20T10.00.27 to fix/java-direct-jvm-and-bugs February 20, 2026 10:31
@codeflash-ai codeflash-ai bot closed this Feb 20, 2026

codeflash-ai bot commented Feb 20, 2026

This PR has been automatically closed because the original PR #1596 by codeflash-ai[bot] was closed.

@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr1596-2026-02-20T10.23.33 branch February 20, 2026 10:31

claude bot commented Feb 20, 2026

PR Review Summary

Prek Checks

Passed after auto-fix.

  • Fixed 10 × W293 (blank-line-with-whitespace) errors
  • Fixed ruff formatting (1 file reformatted)
  • Removed duplicate _should_skip_instrumentation function definition (mypy no-redef error)
  • All fixes committed and pushed (cdc3ebfe)

Mypy

Passed — no issues after removing the duplicate function definition.

Code Review

🔴 Critical: Lambda detection broken in _should_skip_instrumentation (instrumentation.py:1202-1238)

The new combined _should_skip_instrumentation function checks stop boundaries (expression_statement, block, etc.) before checking for lambda_expression. For block-bodied lambdas (very common in Java):

list.forEach(item -> {
    targetMethod(item);  // AST: expression_statement → block → lambda_expression
});

The parent walk hits expression_statement first and returns False, failing to detect the enclosing lambda. The original _is_inside_lambda only stopped at method_declaration, correctly walking through blocks and statements.

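Fix options: either separate the lambda check from the complex-expression check so that the lambda walk continues past statement boundaries, or fall back to _is_inside_lambda at stop boundaries. Moving the lambda_expression check ahead of the stop boundaries is not enough by itself, because the walk still reaches expression_statement first. See inline comment for a suggested fix.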
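
One way to keep the single upward walk while preserving lambda detection is to remember that a statement boundary has been crossed instead of returning there. This is a hedged sketch, not the code from either PR: `FakeNode`, `chain`, and `should_skip_single_pass` are hypothetical names, and the behaviour of the original helpers is inferred from the description above.

```python
from dataclasses import dataclass
from typing import Optional

STOP = frozenset({
    "block", "if_statement", "for_statement", "while_statement",
    "try_statement", "expression_statement",
})
COMPLEX = frozenset({
    "cast_expression", "ternary_expression", "array_access", "binary_expression",
    "unary_expression", "parenthesized_expression", "instanceof_expression",
})


@dataclass
class FakeNode:  # hypothetical stand-in for a tree-sitter node
    type: str
    parent: Optional["FakeNode"] = None


def chain(*types):
    """Build a parent chain from outermost to innermost; return the innermost node."""
    node = None
    for node_type in types:
        node = FakeNode(node_type, node)
    return node


def should_skip_single_pass(node):
    """Return True if the call is inside a lambda or a complex expression."""
    past_boundary = False
    current = node.parent
    while current is not None:
        node_type = current.type
        if node_type == "lambda_expression":
            return True  # lambdas count at any depth inside the method
        if node_type == "method_declaration":
            return False  # nothing above the method matters
        if node_type in STOP:
            past_boundary = True  # complex-expression check ends here
        elif not past_boundary and node_type in COMPLEX:
            return True
        current = current.parent
    return False


in_lambda = chain(
    "method_declaration", "block", "expression_statement", "method_invocation",
    "argument_list", "lambda_expression", "block", "expression_statement",
    "method_invocation",
)
print(should_skip_single_pass(in_lambda))  # True
```

Each ancestor is visited at most once, so the cost stays proportional to the depth of the call within its method.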

Minor: Duplicate comment (instrumentation.py:187-189) — "Group non-lambda and non-complex-expression calls by their line index" appears twice.

Test Coverage

| File | Base (Stmts) | PR (Stmts) | Base Cover | PR Cover | Δ |
| --- | --- | --- | --- | --- | --- |
| codeflash/languages/java/instrumentation.py | 515 | 529 | 82% | 81% | -1% |

  • Coverage decreased slightly (82% → 81%) due to 14 new statements with partial coverage from the new _should_skip_instrumentation function
  • The new function adds 37 lines (net after removing duplicate), with the added logic for combined lambda/complex-expression checking
  • Coverage remains above the 75% threshold for modified files

Last updated: 2026-02-20
