diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index e7c51bb38..9e71c6279 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -178,14 +178,18 @@ def wrap_target_calls_with_treesitter( # Build line byte-start offsets for mapping calls to body_lines indices line_byte_starts = [] offset = 0 - for line in body_lines: + # Precompute encoded lines to avoid repeated encoding + encoded_lines = [line.encode("utf8") for line in body_lines] + for encoded_line in encoded_lines: line_byte_starts.append(offset) - offset += len(line.encode("utf8")) + 1 # +1 for \n from join + offset += len(encoded_line) + 1 # +1 for \n from join + + # Group non-lambda and non-complex-expression calls by their line index # Group non-lambda and non-complex-expression calls by their line index calls_by_line: dict[int, list[dict[str, Any]]] = {} for call in calls: - if call["in_lambda"] or call.get("in_complex", False): + if call["skip_instrumentation"]: logger.debug("Skipping behavior instrumentation for call in lambda or complex expression") continue line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts) @@ -202,7 +206,7 @@ def wrap_target_calls_with_treesitter( line_calls = sorted(calls_by_line[line_idx], key=lambda c: c["start_byte"], reverse=True) line_indent_str = " " * (len(body_line) - len(body_line.lstrip())) line_byte_start = line_byte_starts[line_idx] - line_bytes = body_line.encode("utf8") + line_bytes = encoded_lines[line_idx] new_line = body_line # Track cumulative char shift from earlier edits on this line @@ -291,14 +295,17 @@ def _collect_calls( if parent_type == "expression_statement": es_start = parent.start_byte - prefix_len es_end = parent.end_byte - prefix_len + + # Compute skip flags once during collection + skip_instrumentation = _should_skip_instrumentation(node) + out.append( { "start_byte": start, "end_byte": end, "full_call": analyzer.get_node_text(node, wrapper_bytes), "parent_type": parent_type, - "in_lambda": _is_inside_lambda(node), - "in_complex": _is_inside_complex_expression(node), + "skip_instrumentation": skip_instrumentation, "es_start_byte": es_start, "es_end_byte": es_end, } @@ -328,8 +335,7 @@ def _infer_array_cast_type(line: str) -> str | None: """ # Only apply to assertion methods that take arrays - assertion_methods = ("assertArrayEquals", "assertArrayNotEquals") - if not any(method in line for method in assertion_methods): + if "assertArrayEquals" not in line and "assertArrayNotEquals" not in line: return None # Look for primitive array type in the line (usually the first/expected argument) @@ -1191,3 +1197,42 @@ def _add_import(source: str, import_statement: str) -> str: lines.insert(insert_idx, import_statement + "\n") return "".join(lines) + + +def _should_skip_instrumentation(node: Any) -> bool: + """Check if a node should skip instrumentation (in lambda or complex expression).""" + current = node.parent + while current is not None: + node_type = current.type + + # Stop at statement boundaries + if node_type in { + "method_declaration", + "block", + "if_statement", + "for_statement", + "while_statement", + "try_statement", + "expression_statement", + }: + return False + + # Lambda check + if node_type == "lambda_expression": + return True + + # Complex expression check + if node_type in { + "cast_expression", + "ternary_expression", + "array_access", + "binary_expression", + "unary_expression", + "parenthesized_expression", + "instanceof_expression", + }: + logger.debug("Found complex expression parent: %s", node_type) + return True + + current = current.parent + return False