diff --git a/codeflash/languages/java/config.py b/codeflash/languages/java/config.py index 408dcecaf..53041280e 100644 --- a/codeflash/languages/java/config.py +++ b/codeflash/languages/java/config.py @@ -152,16 +152,20 @@ def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[s except Exception: pass - # Determine primary framework (prefer JUnit 5) + # Determine primary framework (prefer JUnit 5 if explicitly found) if has_junit5: + logger.debug("Selected JUnit 5 as test framework") return "junit5", has_junit5, has_junit4, has_testng if has_junit4: + logger.debug("Selected JUnit 4 as test framework") return "junit4", has_junit5, has_junit4, has_testng if has_testng: + logger.debug("Selected TestNG as test framework") return "testng", has_junit5, has_junit4, has_testng - # Default to JUnit 5 if nothing detected - return "junit5", has_junit5, has_junit4, has_testng + # Default to JUnit 4 if nothing detected (more common in legacy projects) + logger.debug("No test framework detected, defaulting to JUnit 4") + return "junit4", has_junit5, has_junit4, has_testng def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]: @@ -179,6 +183,36 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]: has_junit4 = False has_testng = False + def check_dependencies(deps_element, ns): + """Check dependencies element for test frameworks.""" + nonlocal has_junit5, has_junit4, has_testng + + if deps_element is None: + return + + for dep_path in ["dependency", "m:dependency"]: + deps_list = deps_element.findall(dep_path, ns) if "m:" in dep_path else deps_element.findall(dep_path) + for dep in deps_list: + artifact_id = None + group_id = None + + for child in dep: + tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "") + if tag == "artifactId": + artifact_id = child.text + elif tag == "groupId": + group_id = child.text + + if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id): + has_junit5 = True + logger.debug(f"Found JUnit 5 dependency: {group_id}:{artifact_id}") + elif group_id == "junit" and artifact_id == "junit": + has_junit4 = True + logger.debug(f"Found JUnit 4 dependency: {group_id}:{artifact_id}") + elif group_id == "org.testng": + has_testng = True + logger.debug(f"Found TestNG dependency: {group_id}:{artifact_id}") + try: tree = ET.parse(pom_path) root = tree.getroot() @@ -186,35 +220,44 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]: # Handle namespace ns = {"m": "http://maven.apache.org/POM/4.0.0"} - # Search for dependencies + logger.debug(f"Checking pom.xml at {pom_path}") + + # Search for direct dependencies for deps_path in ["dependencies", "m:dependencies"]: deps = root.find(deps_path, ns) if "m:" in deps_path else root.find(deps_path) - if deps is None: - continue - - for dep_path in ["dependency", "m:dependency"]: - deps_list = deps.findall(dep_path, ns) if "m:" in dep_path else deps.findall(dep_path) - for dep in deps_list: - artifact_id = None - group_id = None - - for child in dep: - tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "") - if tag == "artifactId": - artifact_id = child.text - elif tag == "groupId": - group_id = child.text - - if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id): - has_junit5 = True - elif group_id == "junit" and artifact_id == "junit": - has_junit4 = True - elif group_id == "org.testng": - has_testng = True + if deps is not None: + logger.debug(f"Found dependencies section in {pom_path}") + check_dependencies(deps, ns) + + # Also check dependencyManagement section (for multi-module projects) + for dep_mgmt_path in ["dependencyManagement", "m:dependencyManagement"]: + dep_mgmt = root.find(dep_mgmt_path, ns) if "m:" in dep_mgmt_path else root.find(dep_mgmt_path) + if dep_mgmt is not None: + logger.debug(f"Found dependencyManagement section in {pom_path}") + for deps_path in ["dependencies", "m:dependencies"]: + deps = dep_mgmt.find(deps_path, ns) if "m:" in deps_path else dep_mgmt.find(deps_path) + if deps is not None: + check_dependencies(deps, ns) except ET.ParseError: - pass - + logger.debug("Failed to parse pom.xml at %s", pom_path) + + # For multi-module projects, also check submodule pom.xml files + if not (has_junit5 or has_junit4 or has_testng): + logger.debug("No test deps in root pom, checking submodules") + # Check common submodule locations + for submodule_name in ["test", "tests", "src/test", "testing"]: + submodule_pom = project_root / submodule_name / "pom.xml" + if submodule_pom.exists(): + logger.debug(f"Checking submodule pom at {submodule_pom}") + sub_junit5, sub_junit4, sub_testng = _detect_test_deps_from_pom(project_root / submodule_name) + has_junit5 = has_junit5 or sub_junit5 + has_junit4 = has_junit4 or sub_junit4 + has_testng = has_testng or sub_testng + if has_junit5 or has_junit4 or has_testng: + break + + logger.debug(f"Test framework detection result: junit5={has_junit5}, junit4={has_junit4}, testng={has_testng}") return has_junit5, has_junit4, has_testng diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 18fdb1409..b09811468 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -14,6 +14,7 @@ from __future__ import annotations +import bisect import logging import re from typing import TYPE_CHECKING @@ -26,6 +27,30 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.java.parser import JavaAnalyzer +_STATEMENT_BOUNDARIES = frozenset( + { + "method_declaration", + "block", + "if_statement", + "for_statement", + "while_statement", + "try_statement", + "expression_statement", + } +) + +_COMPLEX_EXPRESSIONS = frozenset( + { + "cast_expression", + "ternary_expression", + "array_access", + "binary_expression", + "unary_expression", + "parenthesized_expression", + "instanceof_expression", + } +) + logger = logging.getLogger(__name__) @@ -39,6 +64,24 @@ def _get_function_name(func: Any) -> str: raise AttributeError(msg) +_METHOD_SIG_PATTERN = re.compile( + r"\b(?:public|private|protected)?\s*(?:static)?\s*(?:final)?\s*" + r"(?:void|String|int|long|boolean|double|float|char|byte|short|\w+(?:\[\])?)\s+(\w+)\s*\(" +) +_FALLBACK_METHOD_PATTERN = re.compile(r"\b(\w+)\s*\(") + + +def _extract_test_method_name(method_lines: list[str]) -> str: + method_sig = " ".join(method_lines).strip() + match = _METHOD_SIG_PATTERN.search(method_sig) + if match: + return match.group(1) + fallback_match = _FALLBACK_METHOD_PATTERN.search(method_sig) + if fallback_match: + return fallback_match.group(1) + return "unknown" + + # Pattern to detect primitive array types in assertions _PRIMITIVE_ARRAY_PATTERN = re.compile(r"new\s+(int|long|double|float|short|byte|char|boolean)\s*\[\s*\]") @@ -73,6 +116,33 @@ def _is_inside_lambda(node) -> bool: return False +def _is_inside_complex_expression(node) -> bool: + """Check if a tree-sitter node is inside a complex expression that shouldn't be instrumented directly. + + This includes: + - Cast expressions: (Long)list.get(2) + - Ternary expressions: condition ? func() : other + - Array access: arr[func()] + - Binary operations: func() + 1 + + Returns True if the node should not be directly instrumented. + """ + current = node.parent + while current is not None: + current_type = current.type + + # Stop at statement boundaries + if current_type in _STATEMENT_BOUNDARIES: + return False + + # These are complex expressions that shouldn't have instrumentation inserted in the middle + if current_type in _COMPLEX_EXPRESSIONS: + return True + + current = current.parent + return False + + _TS_BODY_PREFIX = "class _D { void _m() {\n" _TS_BODY_SUFFIX = "\n}}" _TS_BODY_PREFIX_BYTES = _TS_BODY_PREFIX.encode("utf8") @@ -113,10 +183,11 @@ def wrap_target_calls_with_treesitter( line_byte_starts.append(offset) offset += len(line.encode("utf8")) + 1 # +1 for \n from join - # Group non-lambda calls by their line index + # Group non-lambda and non-complex-expression calls by their line index calls_by_line: dict[int, list] = {} for call in calls: - if call["in_lambda"]: + if call["in_lambda"] or call.get("in_complex", False): + logger.debug("Skipping behavior instrumentation for call in lambda or complex expression") continue line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts) calls_by_line.setdefault(line_idx, []).append(call) @@ -220,6 +291,7 @@ def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analy "full_call": analyzer.get_node_text(node, wrapper_bytes), "parent_type": parent_type, "in_lambda": _is_inside_lambda(node), + "in_complex": _is_inside_complex_expression(node), "es_start_byte": es_start, "es_end_byte": es_end, } @@ -230,10 +302,8 @@ def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analy def _byte_to_line_index(byte_offset: int, line_byte_starts: list[int]) -> int: """Map a byte offset in body_text to a body_lines index.""" - for i in range(len(line_byte_starts) - 1, -1, -1): - if byte_offset >= line_byte_starts[i]: - return i - return 0 + idx = bisect.bisect_right(line_byte_starts, byte_offset) - 1 + return max(0, idx) def _infer_array_cast_type(line: str) -> str | None: @@ -495,6 +565,9 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) result.append(ml) i += 1 + # Extract the test method name from the method signature + test_method_name = _extract_test_method_name(method_lines) + # We're now inside the method body iteration_counter += 1 iter_id = iteration_counter @@ -540,6 +613,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) f'{indent}String _cf_outputFile{iter_id} = System.getenv("CODEFLASH_OUTPUT_FILE");', f'{indent}String _cf_testIteration{iter_id} = System.getenv("CODEFLASH_TEST_ITERATION");', f'{indent}if (_cf_testIteration{iter_id} == null) _cf_testIteration{iter_id} = "0";', + f'{indent}String _cf_test{iter_id} = "{test_method_name}";', f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");', f"{indent}byte[] _cf_serializedResult{iter_id} = null;", f"{indent}long _cf_end{iter_id} = -1;", @@ -577,7 +651,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) f"{indent} try (PreparedStatement _cf_pstmt{iter_id} = _cf_conn{iter_id}.prepareStatement(_cf_sql{iter_id})) {{", f"{indent} _cf_pstmt{iter_id}.setString(1, _cf_mod{iter_id});", f"{indent} _cf_pstmt{iter_id}.setString(2, _cf_cls{iter_id});", - f'{indent} _cf_pstmt{iter_id}.setString(3, "{class_name}Test");', + f"{indent} _cf_pstmt{iter_id}.setString(3, _cf_test{iter_id});", f"{indent} _cf_pstmt{iter_id}.setString(4, _cf_fn{iter_id});", f"{indent} _cf_pstmt{iter_id}.setInt(5, _cf_loop{iter_id});", f'{indent} _cf_pstmt{iter_id}.setString(6, _cf_iter{iter_id} + "_" + _cf_testIteration{iter_id});', @@ -664,8 +738,12 @@ def collect_test_methods(node, out) -> None: def collect_target_calls(node, wrapper_bytes: bytes, func: str, out) -> None: if node.type == "method_invocation": name_node = node.child_by_field_name("name") - if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func and not _is_inside_lambda(node): - out.append(node) + if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func: + # Skip if inside lambda or complex expression + if not _is_inside_lambda(node) and not _is_inside_complex_expression(node): + out.append(node) + else: + logger.debug(f"Skipping instrumentation of {func} inside lambda or complex expression") for child in node.children: collect_target_calls(child, wrapper_bytes, func, out) diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index 5ca2f2f8f..ca8b1b2c7 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -562,6 +562,26 @@ def _get_test_classpath( if main_classes.exists(): cp_parts.append(str(main_classes)) + # For multi-module projects, also include target/classes from all modules + # This is needed because the test module may depend on other modules + if test_module: + # Find all target/classes directories in sibling modules + for module_dir in project_root.iterdir(): + if module_dir.is_dir() and module_dir.name != test_module: + module_classes = module_dir / "target" / "classes" + if module_classes.exists(): + logger.debug(f"Adding multi-module classpath: {module_classes}") + cp_parts.append(str(module_classes)) + + # Add JUnit Platform Console Standalone JAR if not already on classpath. + # This is required for direct JVM execution with ConsoleLauncher, + # which is NOT included in the standard junit-jupiter dependency tree. + if "console-standalone" not in classpath and "ConsoleLauncher" not in classpath: + console_jar = _find_junit_console_standalone() + if console_jar: + logger.debug("Adding JUnit Console Standalone to classpath: %s", console_jar) + cp_parts.append(str(console_jar)) + return os.pathsep.join(cp_parts) except subprocess.TimeoutExpired: @@ -576,6 +596,54 @@ def _get_test_classpath( cp_file.unlink() +def _find_junit_console_standalone() -> Path | None: + """Find the JUnit Platform Console Standalone JAR in the local Maven repository. + + This JAR contains ConsoleLauncher which is required for direct JVM test execution + with JUnit 5. It is NOT included in the standard junit-jupiter dependency tree. + + Returns: + Path to the console standalone JAR, or None if not found. + + """ + m2_base = Path.home() / ".m2" / "repository" / "org" / "junit" / "platform" / "junit-platform-console-standalone" + if not m2_base.exists(): + # Try to download it via Maven + mvn = find_maven_executable() + if mvn: + logger.debug("Console standalone not found in cache, downloading via Maven") + try: + subprocess.run( + [ + mvn, + "dependency:get", + "-Dartifact=org.junit.platform:junit-platform-console-standalone:1.10.0", + "-q", + "-B", + ], + check=False, + capture_output=True, + text=True, + timeout=30, + ) + except (subprocess.TimeoutExpired, Exception): + pass + if not m2_base.exists(): + return None + + # Find the latest version available + try: + versions = sorted([d for d in m2_base.iterdir() if d.is_dir()], key=lambda d: d.name, reverse=True) + for version_dir in versions: + jar = version_dir / f"junit-platform-console-standalone-{version_dir.name}.jar" + if jar.exists(): + return jar + except Exception: + pass + + return None + + def _run_tests_direct( classpath: str, test_classes: list[str], @@ -605,49 +673,97 @@ def _run_tests_direct( java = _find_java_executable() or "java" - # Build command using JUnit Platform Console Launcher - # The launcher is included in junit-platform-console-standalone or junit-jupiter - cmd = [ - str(java), - # Java 16+ module system: Kryo needs reflective access to internal JDK classes - "--add-opens", - "java.base/java.util=ALL-UNNAMED", - "--add-opens", - "java.base/java.lang=ALL-UNNAMED", - "--add-opens", - "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", - "java.base/java.io=ALL-UNNAMED", - "--add-opens", - "java.base/java.math=ALL-UNNAMED", - "--add-opens", - "java.base/java.net=ALL-UNNAMED", - "--add-opens", - "java.base/java.util.zip=ALL-UNNAMED", - "-cp", - classpath, - "org.junit.platform.console.ConsoleLauncher", - "--disable-banner", - "--disable-ansi-colors", - # Use 'none' details to avoid duplicate output - # Timing markers are captured in XML via stdout capture config - "--details=none", - # Enable stdout/stderr capture in XML reports - # This ensures timing markers are included in the XML system-out element - "--config=junit.platform.output.capture.stdout=true", - "--config=junit.platform.output.capture.stderr=true", - ] - - # Add reports directory if specified (for XML output) - if reports_dir: - reports_dir.mkdir(parents=True, exist_ok=True) - cmd.extend(["--reports-dir", str(reports_dir)]) - - # Add test classes to select - for test_class in test_classes: - cmd.extend(["--select-class", test_class]) - - logger.debug("Running tests directly: java -cp ... ConsoleLauncher --select-class %s", test_classes) + # Detect JUnit version from the classpath string. + # We check for junit-jupiter (the JUnit 5 test API) as the indicator of JUnit 5 tests. + # Note: console-standalone and junit-platform are NOT reliable indicators because + # we inject console-standalone ourselves in _get_test_classpath(), so it's always present. + # ConsoleLauncher can run both JUnit 5 and JUnit 4 tests (via vintage engine), + # so we prefer it when available and only fall back to JUnitCore for pure JUnit 4 + # projects without ConsoleLauncher on the classpath. + has_junit5_tests = "junit-jupiter" in classpath + has_console_launcher = "console-standalone" in classpath or "ConsoleLauncher" in classpath + # Use ConsoleLauncher if available (works for both JUnit 4 via vintage and JUnit 5). + # Only use JUnitCore when ConsoleLauncher is not on the classpath at all. + is_junit4 = not has_console_launcher + if is_junit4: + logger.debug("JUnit 4 project, no ConsoleLauncher available, using JUnitCore") + elif has_junit5_tests: + logger.debug("JUnit 5 project, using ConsoleLauncher") + else: + logger.debug("JUnit 4 project, using ConsoleLauncher (via vintage engine)") + + if is_junit4: + # Use JUnit 4's JUnitCore runner + cmd = [ + str(java), + # Java 16+ module system: Kryo needs reflective access to internal JDK classes + "--add-opens", + "java.base/java.util=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", + "java.base/java.io=ALL-UNNAMED", + "--add-opens", + "java.base/java.math=ALL-UNNAMED", + "--add-opens", + "java.base/java.net=ALL-UNNAMED", + "--add-opens", + "java.base/java.util.zip=ALL-UNNAMED", + "-cp", + classpath, + "org.junit.runner.JUnitCore", + ] + # Add test classes + cmd.extend(test_classes) + else: + # Build command using JUnit Platform Console Launcher (JUnit 5) + # The launcher is included in junit-platform-console-standalone or junit-jupiter + cmd = [ + str(java), + # Java 16+ module system: Kryo needs reflective access to internal JDK classes + "--add-opens", + "java.base/java.util=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", + "java.base/java.io=ALL-UNNAMED", + "--add-opens", + "java.base/java.math=ALL-UNNAMED", + "--add-opens", + "java.base/java.net=ALL-UNNAMED", + "--add-opens", + "java.base/java.util.zip=ALL-UNNAMED", + "-cp", + classpath, + "org.junit.platform.console.ConsoleLauncher", + "--disable-banner", + "--disable-ansi-colors", + # Use 'none' details to avoid duplicate output + # Timing markers are captured in XML via stdout capture config + "--details=none", + # Enable stdout/stderr capture in XML reports + # This ensures timing markers are included in the XML system-out element + "--config=junit.platform.output.capture.stdout=true", + "--config=junit.platform.output.capture.stderr=true", + ] + + # Add reports directory if specified (for XML output) + if reports_dir: + reports_dir.mkdir(parents=True, exist_ok=True) + cmd.extend(["--reports-dir", str(reports_dir)]) + + # Add test classes to select + for test_class in test_classes: + cmd.extend(["--select-class", test_class]) + + if is_junit4: + logger.debug("Running tests directly: java -cp ... JUnitCore %s", test_classes) + else: + logger.debug("Running tests directly: java -cp ... ConsoleLauncher --select-class %s", test_classes) try: return subprocess.run( @@ -982,6 +1098,10 @@ def run_benchmarking_tests( logger.debug("Loop %d completed in %.2fs (returncode=%d)", loop_idx, loop_time, result.returncode) + # Log stderr if direct JVM execution failed (for debugging) + if result.returncode != 0 and result.stderr: + logger.debug("Direct JVM stderr: %s", result.stderr[:500]) + # Check if direct JVM execution failed on the first loop. # Fall back to Maven-based execution for: # - JUnit 4 projects (ConsoleLauncher not on classpath or no tests discovered) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index e00c3a827..a662cd2e6 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -143,6 +143,10 @@ def parse_concurrency_metrics(test_results: TestResults, function_name: str) -> ) +# Cache for resolved test file paths to avoid repeated rglob calls +_test_file_path_cache: dict[tuple[str, Path], Path | None] = {} + + def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None: """Resolve test file path from pytest's test class path or Java class path. @@ -164,6 +168,13 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P >>> # Should find: /path/to/tests/unittest/test_file.py """ + # Check cache first + cache_key = (test_class_path, base_dir) + if cache_key in _test_file_path_cache: + cached_result = _test_file_path_cache[cache_key] + logger.debug(f"[RESOLVE] Cache hit for {test_class_path}: {cached_result}") + return cached_result + # Handle Java class paths (convert dots to path and add .java extension) # Java class paths look like "com.example.TestClass" and should map to # src/test/java/com/example/TestClass.java @@ -178,6 +189,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P logger.debug(f"[RESOLVE] Attempt 1: checking {potential_path}") if potential_path.exists(): logger.debug(f"[RESOLVE] Attempt 1 SUCCESS: found {potential_path}") + _test_file_path_cache[cache_key] = potential_path return potential_path # 2. Under src/test/java relative to project root @@ -189,6 +201,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P logger.debug(f"[RESOLVE] Attempt 2: checking {potential_path} (project_root={project_root})") if potential_path.exists(): logger.debug(f"[RESOLVE] Attempt 2 SUCCESS: found {potential_path}") + _test_file_path_cache[cache_key] = potential_path return potential_path # 3. Search for the file in base_dir and its subdirectories @@ -196,9 +209,11 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P logger.debug(f"[RESOLVE] Attempt 3: rglob for {file_name} in {base_dir}") for java_file in base_dir.rglob(file_name): logger.debug(f"[RESOLVE] Attempt 3 SUCCESS: rglob found {java_file}") + _test_file_path_cache[cache_key] = java_file return java_file logger.warning(f"[RESOLVE] FAILED to resolve {test_class_path} in base_dir {base_dir}") + _test_file_path_cache[cache_key] = None # Cache negative results too return None # Handle file paths (contain slashes and extensions like .js/.ts) @@ -207,6 +222,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P # Try the path as-is if it's absolute potential_path = Path(test_class_path) if potential_path.is_absolute() and potential_path.exists(): + _test_file_path_cache[cache_key] = potential_path return potential_path # Try to resolve relative to base_dir's parent (project root) @@ -216,6 +232,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P try: potential_path = potential_path.resolve() if potential_path.exists(): + _test_file_path_cache[cache_key] = potential_path return potential_path except (OSError, RuntimeError): pass @@ -225,10 +242,12 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P try: potential_path = potential_path.resolve() if potential_path.exists(): + _test_file_path_cache[cache_key] = potential_path return potential_path except (OSError, RuntimeError): pass + _test_file_path_cache[cache_key] = None # Cache negative results return None # First try the full path (Python module path) @@ -259,6 +278,8 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P if test_file_path: break + # Cache the result (could be None) + _test_file_path_cache[cache_key] = test_file_path return test_file_path @@ -1228,6 +1249,21 @@ def parse_test_results( results = merge_test_results(test_results_xml, test_results_data, test_config.test_framework) + # Bug #10 Fix: For Java performance tests, preserve subprocess stdout containing timing markers + # This is needed for calculate_function_throughput_from_test_results to work correctly + if is_java() and testing_type == TestingMode.PERFORMANCE and run_result is not None: + try: + # Extract stdout from subprocess result containing timing markers + if isinstance(run_result.stdout, bytes): + results.perf_stdout = run_result.stdout.decode("utf-8", errors="replace") + elif isinstance(run_result.stdout, str): + results.perf_stdout = run_result.stdout + logger.debug( + f"Bug #10 Fix: Set perf_stdout for Java performance tests ({len(results.perf_stdout or '')} chars)" + ) + except Exception as e: + logger.debug(f"Bug #10 Fix: Failed to set perf_stdout: {e}") + all_args = False coverage = None if coverage_database_file and source_file and code_context and function_name: diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index c6650ef99..0a613c1fe 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -179,6 +179,7 @@ class TestConfig: use_cache: bool = True _language: Optional[str] = None # Language identifier for multi-language support js_project_root: Optional[Path] = None # JavaScript project root (directory containing package.json) + _test_framework: Optional[str] = None # Cached test framework detection result def __post_init__(self) -> None: self.tests_root = self.tests_root.resolve() @@ -191,14 +192,19 @@ def test_framework(self) -> str: For JavaScript/TypeScript: uses the configured framework (vitest, jest, or mocha). For Python: uses pytest as default. + Result is cached after first detection to avoid repeated pom.xml parsing. """ + if self._test_framework is not None: + return self._test_framework if is_javascript(): from codeflash.languages.test_framework import get_js_test_framework_or_default - return get_js_test_framework_or_default() - if is_java(): - return self._detect_java_test_framework() - return "pytest" + self._test_framework = get_js_test_framework_or_default() + elif is_java(): + self._test_framework = self._detect_java_test_framework() + else: + self._test_framework = "pytest" + return self._test_framework def _detect_java_test_framework(self) -> str: """Detect the Java test framework from the project configuration. @@ -232,7 +238,7 @@ def _detect_java_test_framework(self) -> str: return config.test_framework except Exception: pass - return "junit5" # Default fallback + return "junit4" # Default fallback (JUnit 4 is more common in legacy projects) def set_language(self, language: str) -> None: """Set the language for this test config. diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index c07340ec4..5a2c5ba91 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -145,6 +145,7 @@ def test_instrument_behavior_mode_simple(self, tmp_path: Path): String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testAdd"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -175,7 +176,7 @@ def test_instrument_behavior_mode_simple(self, tmp_path: Path): try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "CalculatorTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1); @@ -256,6 +257,7 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testNegativeInput_ThrowsIllegalArgumentException"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -281,7 +283,7 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "FibonacciTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1); @@ -309,6 +311,7 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path String _cf_outputFile2 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration2 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration2 == null) _cf_testIteration2 = "0"; + String _cf_test2 = "testZeroInput_ReturnsZero"; System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_iter2 + "######$!"); byte[] _cf_serializedResult2 = null; long _cf_end2 = -1; @@ -338,7 +341,7 @@ def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path try (PreparedStatement _cf_pstmt2 = _cf_conn2.prepareStatement(_cf_sql2)) { _cf_pstmt2.setString(1, _cf_mod2); _cf_pstmt2.setString(2, _cf_cls2); - _cf_pstmt2.setString(3, "FibonacciTestTest"); + _cf_pstmt2.setString(3, _cf_test2); _cf_pstmt2.setString(4, _cf_fn2); _cf_pstmt2.setInt(5, _cf_loop2); _cf_pstmt2.setString(6, _cf_iter2 + "_" + _cf_testIteration2); @@ -420,6 +423,7 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testNegativeInput_ThrowsIllegalArgumentException"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -447,7 +451,7 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "FibonacciTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1); @@ -475,6 +479,7 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat String _cf_outputFile2 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration2 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration2 == null) _cf_testIteration2 = "0"; + String _cf_test2 = "testZeroInput_ReturnsZero"; System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_iter2 + "######$!"); byte[] _cf_serializedResult2 = null; long _cf_end2 = -1; @@ -504,7 +509,7 @@ def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Pat try (PreparedStatement _cf_pstmt2 = _cf_conn2.prepareStatement(_cf_sql2)) { _cf_pstmt2.setString(1, _cf_mod2); _cf_pstmt2.setString(2, _cf_cls2); - _cf_pstmt2.setString(3, "FibonacciTestTest"); + _cf_pstmt2.setString(3, _cf_test2); _cf_pstmt2.setString(4, _cf_fn2); _cf_pstmt2.setInt(5, _cf_loop2); _cf_pstmt2.setString(6, _cf_iter2 + "_" + _cf_testIteration2); @@ -816,6 +821,7 @@ class TestKryoSerializerUsage: String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testFoo"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -844,7 +850,7 @@ class TestKryoSerializerUsage: try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "MyTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1); @@ -1317,6 +1323,7 @@ def test_instrument_generated_test_behavior_mode(self): String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testAdd"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -1346,7 +1353,7 @@ def test_instrument_generated_test_behavior_mode(self): try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "CalculatorTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1); @@ -2522,6 +2529,7 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testIncrement"; System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_iter1 + "######$!"); byte[] _cf_serializedResult1 = null; long _cf_end1 = -1; @@ -2552,7 +2560,7 @@ def test_behavior_mode_writes_to_sqlite(self, java_project): try (PreparedStatement _cf_pstmt1 = _cf_conn1.prepareStatement(_cf_sql1)) { _cf_pstmt1.setString(1, _cf_mod1); _cf_pstmt1.setString(2, _cf_cls1); - _cf_pstmt1.setString(3, "CounterTestTest"); + _cf_pstmt1.setString(3, _cf_test1); _cf_pstmt1.setString(4, _cf_fn1); _cf_pstmt1.setInt(5, _cf_loop1); _cf_pstmt1.setString(6, _cf_iter1 + "_" + _cf_testIteration1);