Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 71 additions & 28 deletions codeflash/languages/java/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,20 @@ def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[s
except Exception:
pass

# Determine primary framework (prefer JUnit 5)
# Determine primary framework (prefer JUnit 5 if explicitly found)
if has_junit5:
logger.debug("Selected JUnit 5 as test framework")
return "junit5", has_junit5, has_junit4, has_testng
if has_junit4:
logger.debug("Selected JUnit 4 as test framework")
return "junit4", has_junit5, has_junit4, has_testng
if has_testng:
logger.debug("Selected TestNG as test framework")
return "testng", has_junit5, has_junit4, has_testng

# Default to JUnit 5 if nothing detected
return "junit5", has_junit5, has_junit4, has_testng
# Default to JUnit 4 if nothing detected (more common in legacy projects)
logger.debug("No test framework detected, defaulting to JUnit 4")
return "junit4", has_junit5, has_junit4, has_testng


def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
Expand All @@ -179,42 +183,81 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
has_junit4 = False
has_testng = False

def check_dependencies(deps_element, ns):
"""Check dependencies element for test frameworks."""
nonlocal has_junit5, has_junit4, has_testng

if deps_element is None:
return

for dep_path in ["dependency", "m:dependency"]:
deps_list = deps_element.findall(dep_path, ns) if "m:" in dep_path else deps_element.findall(dep_path)
for dep in deps_list:
artifact_id = None
group_id = None

for child in dep:
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
if tag == "artifactId":
artifact_id = child.text
elif tag == "groupId":
group_id = child.text

if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id):
has_junit5 = True
logger.debug(f"Found JUnit 5 dependency: {group_id}:{artifact_id}")
elif group_id == "junit" and artifact_id == "junit":
has_junit4 = True
logger.debug(f"Found JUnit 4 dependency: {group_id}:{artifact_id}")
elif group_id == "org.testng":
has_testng = True
logger.debug(f"Found TestNG dependency: {group_id}:{artifact_id}")

try:
tree = ET.parse(pom_path)
root = tree.getroot()

# Handle namespace
ns = {"m": "http://maven.apache.org/POM/4.0.0"}

# Search for dependencies
logger.debug(f"Checking pom.xml at {pom_path}")

# Search for direct dependencies
for deps_path in ["dependencies", "m:dependencies"]:
deps = root.find(deps_path, ns) if "m:" in deps_path else root.find(deps_path)
if deps is None:
continue

for dep_path in ["dependency", "m:dependency"]:
deps_list = deps.findall(dep_path, ns) if "m:" in dep_path else deps.findall(dep_path)
for dep in deps_list:
artifact_id = None
group_id = None

for child in dep:
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
if tag == "artifactId":
artifact_id = child.text
elif tag == "groupId":
group_id = child.text

if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id):
has_junit5 = True
elif group_id == "junit" and artifact_id == "junit":
has_junit4 = True
elif group_id == "org.testng":
has_testng = True
if deps is not None:
logger.debug(f"Found dependencies section in {pom_path}")
check_dependencies(deps, ns)

# Also check dependencyManagement section (for multi-module projects)
for dep_mgmt_path in ["dependencyManagement", "m:dependencyManagement"]:
dep_mgmt = root.find(dep_mgmt_path, ns) if "m:" in dep_mgmt_path else root.find(dep_mgmt_path)
if dep_mgmt is not None:
logger.debug(f"Found dependencyManagement section in {pom_path}")
for deps_path in ["dependencies", "m:dependencies"]:
deps = dep_mgmt.find(deps_path, ns) if "m:" in deps_path else dep_mgmt.find(deps_path)
if deps is not None:
check_dependencies(deps, ns)

except ET.ParseError:
pass

logger.debug("Failed to parse pom.xml at %s", pom_path)

# For multi-module projects, also check submodule pom.xml files
if not (has_junit5 or has_junit4 or has_testng):
logger.debug("No test deps in root pom, checking submodules")
# Check common submodule locations
for submodule_name in ["test", "tests", "src/test", "testing"]:
submodule_pom = project_root / submodule_name / "pom.xml"
if submodule_pom.exists():
logger.debug(f"Checking submodule pom at {submodule_pom}")
sub_junit5, sub_junit4, sub_testng = _detect_test_deps_from_pom(project_root / submodule_name)
has_junit5 = has_junit5 or sub_junit5
has_junit4 = has_junit4 or sub_junit4
has_testng = has_testng or sub_testng
if has_junit5 or has_junit4 or has_testng:
break

logger.debug(f"Test framework detection result: junit5={has_junit5}, junit4={has_junit4}, testng={has_testng}")
return has_junit5, has_junit4, has_testng


Expand Down
96 changes: 87 additions & 9 deletions codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import bisect
import logging
import re
from typing import TYPE_CHECKING
Expand All @@ -26,6 +27,30 @@
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer

_STATEMENT_BOUNDARIES = frozenset(
{
"method_declaration",
"block",
"if_statement",
"for_statement",
"while_statement",
"try_statement",
"expression_statement",
}
)

_COMPLEX_EXPRESSIONS = frozenset(
{
"cast_expression",
"ternary_expression",
"array_access",
"binary_expression",
"unary_expression",
"parenthesized_expression",
"instanceof_expression",
}
)

logger = logging.getLogger(__name__)


Expand All @@ -39,6 +64,24 @@ def _get_function_name(func: Any) -> str:
raise AttributeError(msg)


_METHOD_SIG_PATTERN = re.compile(
r"\b(?:public|private|protected)?\s*(?:static)?\s*(?:final)?\s*"
r"(?:void|String|int|long|boolean|double|float|char|byte|short|\w+(?:\[\])?)\s+(\w+)\s*\("
)
_FALLBACK_METHOD_PATTERN = re.compile(r"\b(\w+)\s*\(")


def _extract_test_method_name(method_lines: list[str]) -> str:
method_sig = " ".join(method_lines).strip()
match = _METHOD_SIG_PATTERN.search(method_sig)
if match:
return match.group(1)
fallback_match = _FALLBACK_METHOD_PATTERN.search(method_sig)
if fallback_match:
return fallback_match.group(1)
return "unknown"


# Pattern to detect primitive array types in assertions
_PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]")

Expand Down Expand Up @@ -73,6 +116,33 @@ def _is_inside_lambda(node) -> bool:
return False


def _is_inside_complex_expression(node) -> bool:
    """Report whether ``node`` sits inside an expression that should not be instrumented in place.

    Ancestors are visited from the immediate parent upwards. The first ancestor
    whose type is a statement boundary ends the search with ``False``; the first
    ancestor whose type is a complex expression ends it with ``True``. Examples
    of complex expressions:

    - Cast expressions: (Long)list.get(2)
    - Ternary expressions: condition ? func() : other
    - Array access: arr[func()]
    - Binary operations: func() + 1

    Returns True if the node should not be directly instrumented, and False
    when the root is reached without meeting either kind of ancestor.
    """

    def _ancestors(start):
        # Yield each parent in turn until the tree root (parent is None).
        cursor = start.parent
        while cursor is not None:
            yield cursor
            cursor = cursor.parent

    for ancestor in _ancestors(node):
        kind = ancestor.type
        # The boundary check comes first, so a boundary type always wins.
        if kind in _STATEMENT_BOUNDARIES:
            return False
        if kind in _COMPLEX_EXPRESSIONS:
            return True
    return False


_TS_BODY_PREFIX = "class _D { void _m() {\n"
_TS_BODY_SUFFIX = "\n}}"
_TS_BODY_PREFIX_BYTES = _TS_BODY_PREFIX.encode("utf8")
Expand Down Expand Up @@ -113,10 +183,11 @@ def wrap_target_calls_with_treesitter(
line_byte_starts.append(offset)
offset += len(line.encode("utf8")) + 1 # +1 for \n from join

# Group non-lambda calls by their line index
# Group non-lambda and non-complex-expression calls by their line index
calls_by_line: dict[int, list] = {}
for call in calls:
if call["in_lambda"]:
if call["in_lambda"] or call.get("in_complex", False):
logger.debug("Skipping behavior instrumentation for call in lambda or complex expression")
continue
line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
calls_by_line.setdefault(line_idx, []).append(call)
Expand Down Expand Up @@ -220,6 +291,7 @@ def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analy
"full_call": analyzer.get_node_text(node, wrapper_bytes),
"parent_type": parent_type,
"in_lambda": _is_inside_lambda(node),
"in_complex": _is_inside_complex_expression(node),
"es_start_byte": es_start,
"es_end_byte": es_end,
}
Expand All @@ -230,10 +302,8 @@ def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analy

def _byte_to_line_index(byte_offset: int, line_byte_starts: list[int]) -> int:
"""Map a byte offset in body_text to a body_lines index."""
for i in range(len(line_byte_starts) - 1, -1, -1):
if byte_offset >= line_byte_starts[i]:
return i
return 0
idx = bisect.bisect_right(line_byte_starts, byte_offset) - 1
return max(0, idx)


def _infer_array_cast_type(line: str) -> str | None:
Expand Down Expand Up @@ -495,6 +565,9 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
result.append(ml)
i += 1

# Extract the test method name from the method signature
test_method_name = _extract_test_method_name(method_lines)

# We're now inside the method body
iteration_counter += 1
iter_id = iteration_counter
Expand Down Expand Up @@ -540,6 +613,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");',
f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");',
f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";',
f'{indent}String _cf_test{iter_id} = "{test_method_name}";',
f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");',
f"{indent}byte[] _cf_serializedResult{iter_id} = null;",
f"{indent}long _cf_end{iter_id} = -1;",
Expand Down Expand Up @@ -577,7 +651,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{",
f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});",
f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});",
f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");',
f"{indent} _cf_pstmt{iter_id}.setString(3, _cf_test{iter_id});",
f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});",
f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});",
f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});',
Expand Down Expand Up @@ -664,8 +738,12 @@ def collect_test_methods(node, out) -> None:
def collect_target_calls(node, wrapper_bytes: bytes, func: str, out) -> None:
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func and not _is_inside_lambda(node):
out.append(node)
if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func:
# Skip if inside lambda or complex expression
if not _is_inside_lambda(node) and not _is_inside_complex_expression(node):
out.append(node)
else:
logger.debug(f"Skipping instrumentation of {func} inside lambda or complex expression")
for child in node.children:
collect_target_calls(child, wrapper_bytes, func, out)

Expand Down
Loading
Loading