Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,18 @@ def wrap_target_calls_with_treesitter(
# Build line byte-start offsets for mapping calls to body_lines indices
line_byte_starts = []
offset = 0
for line in body_lines:
# Precompute encoded lines to avoid repeated encoding
encoded_lines = [line.encode("utf8") for line in body_lines]
for encoded_line in encoded_lines:
line_byte_starts.append(offset)
offset += len(line.encode("utf8")) + 1 # +1 for \n from join
offset += len(encoded_line) + 1 # +1 for \n from join

# Group non-lambda and non-complex-expression calls by their line index

# Group non-lambda and non-complex-expression calls by their line index
Comment on lines +187 to 189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate comment. Remove one of these lines.

Suggested change
# Group non-lambda and non-complex-expression calls by their line index
# Group non-lambda and non-complex-expression calls by their line index
# Group non-lambda and non-complex-expression calls by their line index

calls_by_line: dict[int, list[dict[str, Any]]] = {}
for call in calls:
if call["in_lambda"] or call.get("in_complex", False):
if call["skip_instrumentation"]:
logger.debug("Skipping behavior instrumentation for call in lambda or complex expression")
continue
line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
Expand All @@ -202,7 +206,7 @@ def wrap_target_calls_with_treesitter(
line_calls = sorted(calls_by_line[line_idx], key=lambda c: c["start_byte"], reverse=True)
line_indent_str = " " * (len(body_line) - len(body_line.lstrip()))
line_byte_start = line_byte_starts[line_idx]
line_bytes = body_line.encode("utf8")
line_bytes = encoded_lines[line_idx]

new_line = body_line
# Track cumulative char shift from earlier edits on this line
Expand Down Expand Up @@ -291,14 +295,17 @@ def _collect_calls(
if parent_type == "expression_statement":
es_start = parent.start_byte - prefix_len
es_end = parent.end_byte - prefix_len

# Compute skip flags once during collection
skip_instrumentation = _should_skip_instrumentation(node)

out.append(
{
"start_byte": start,
"end_byte": end,
"full_call": analyzer.get_node_text(node, wrapper_bytes),
"parent_type": parent_type,
"in_lambda": _is_inside_lambda(node),
"in_complex": _is_inside_complex_expression(node),
"skip_instrumentation": skip_instrumentation,
"es_start_byte": es_start,
"es_end_byte": es_end,
}
Expand Down Expand Up @@ -328,8 +335,7 @@ def _infer_array_cast_type(line: str) -> str | None:

"""
# Only apply to assertion methods that take arrays
assertion_methods = ("assertArrayEquals", "assertArrayNotEquals")
if not any(method in line for method in assertion_methods):
if "assertArrayEquals" not in line and "assertArrayNotEquals" not in line:
return None

# Look for primitive array type in the line (usually the first/expected argument)
Expand Down Expand Up @@ -1191,3 +1197,42 @@ def _add_import(source: str, import_statement: str) -> str:

lines.insert(insert_idx, import_statement + "\n")
return "".join(lines)


def _should_skip_instrumentation(node: Any) -> bool:
"""Check if a node should skip instrumentation (in lambda or complex expression)."""
current = node.parent
while current is not None:
node_type = current.type

# Stop at statement boundaries
if node_type in {
"method_declaration",
"block",
"if_statement",
"for_statement",
"while_statement",
"try_statement",
"expression_statement",
}:
return False

# Lambda check
if node_type == "lambda_expression":
return True

# Complex expression check
if node_type in {
"cast_expression",
"ternary_expression",
"array_access",
"binary_expression",
"unary_expression",
"parenthesized_expression",
"instanceof_expression",
}:
logger.debug("Found complex expression parent: %s", node_type)
return True

current = current.parent
return False
Comment on lines +1202 to +1238
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Lambda detection broken for block-bodied lambdas

The combined _should_skip_instrumentation checks stop boundaries (including expression_statement, block) before checking for lambda_expression. This means for common Java patterns like:

list.forEach(item -> {
    targetMethod(item);  // expression_statement → block → lambda_expression
});

The walk hits expression_statement first and returns False, failing to detect that the call is inside a lambda.

The original _is_inside_lambda only stopped at method_declaration, allowing it to walk through blocks/statements to find enclosing lambdas. The combined function inherits the aggressive stop boundaries from _is_inside_complex_expression, which breaks lambda detection.

Fix: Check for lambda_expression before checking stop boundaries, or separate the lambda and complex-expression checks:

Suggested change
def _should_skip_instrumentation(node: Any) -> bool:
"""Check if a node should skip instrumentation (in lambda or complex expression)."""
current = node.parent
while current is not None:
node_type = current.type
# Stop at statement boundaries
if node_type in {
"method_declaration",
"block",
"if_statement",
"for_statement",
"while_statement",
"try_statement",
"expression_statement",
}:
return False
# Lambda check
if node_type == "lambda_expression":
return True
# Complex expression check
if node_type in {
"cast_expression",
"ternary_expression",
"array_access",
"binary_expression",
"unary_expression",
"parenthesized_expression",
"instanceof_expression",
}:
logger.debug("Found complex expression parent: %s", node_type)
return True
current = current.parent
return False
def _should_skip_instrumentation(node: Any) -> bool:
"""Check if a node should skip instrumentation (in lambda or complex expression)."""
current = node.parent
while current is not None:
node_type = current.type
# Lambda check must come before stop boundaries since lambdas contain blocks
if node_type == "lambda_expression":
return True
# Stop at statement boundaries (for complex expression check only)
if node_type in {
"method_declaration",
"block",
"if_statement",
"for_statement",
"while_statement",
"try_statement",
"expression_statement",
}:
# Still need to check for lambdas above this point
return _is_inside_lambda(node)
# Complex expression check
if node_type in {
"cast_expression",
"ternary_expression",
"array_access",
"binary_expression",
"unary_expression",
"parenthesized_expression",
"instanceof_expression",
}:
logger.debug("Found complex expression parent: %s", node_type)
return True
current = current.parent
return False

Loading