From 9e2579bdd1782a92063bc4e26438178dc3bd04ff Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 1 Apr 2026 17:15:22 -0700 Subject: [PATCH 1/3] Fix Java instrumentation compilation error for generic methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When instrumenting tests for generic methods like `> List mergeSorted(...)`, the extracted return type `List` was used as a cast in the instrumented test class where `T` is not in scope, causing "cannot find symbol: class T" compilation errors. Added `_erase_method_type_params()` to detect method-level type parameters via tree-sitter and erase them from the return type: bare type variables become `Object`, parameterized uses become wildcards (e.g. `List` → `List`). Co-Authored-By: Claude Opus 4.6 --- codeflash/languages/java/instrumentation.py | 52 +++++++- .../test_java/test_instrumentation.py | 113 ++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index c1a8ae683..720c286af 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -508,6 +508,56 @@ def _infer_array_cast_type(line: str) -> str | None: return None +def _erase_method_type_params(return_type: str, method_node: Any) -> str: + """Erase method-level type parameters from the return type. + + Generic methods like ``> List mergeSorted(...)`` + declare type parameters that only exist within the method scope. When the + return type (e.g. ``List``) is used as a cast in an instrumented test + class, those type variables are not in scope and cause compilation errors. + + This function detects method-level type parameters via the tree-sitter AST + and replaces any occurrences in the return type with the wildcard ``?``. + If the return type *itself* is a bare type variable (e.g. ``T``), the type + is erased to ``Object``. + """ + # Find method-level type_parameters node via tree-sitter AST + ts_node = getattr(method_node, "node", None) + if ts_node is None: + return return_type + + type_params_node = None + for child in ts_node.children: + if child.type == "type_parameters": + type_params_node = child + break + + if type_params_node is None: + return return_type + + # Collect declared type variable names (e.g. T, E, K, V) + type_var_names: set[str] = set() + for child in type_params_node.children: + if child.type == "type_parameter": + name_node = child.child_by_field_name("name") or (child.children[0] if child.children else None) + if name_node: + type_var_names.add(name_node.text.decode("utf8") if isinstance(name_node.text, bytes) else str(name_node.text)) + + if not type_var_names: + return return_type + + # If the entire return type is a bare type variable, erase to Object + if return_type.strip() in type_var_names: + return "Object" + + # Replace type variables used as generic arguments with '?' + # Match whole-word type variable names that appear as generic type arguments + for tv in type_var_names: + return_type = re.sub(rf'\b{re.escape(tv)}\b', '?', return_type) + + return return_type + + def _extract_return_type(function_to_optimize: Any) -> str: """Extract the return type of a Java function from its source file using tree-sitter.""" file_path = getattr(function_to_optimize, "file_path", None) @@ -522,7 +572,7 @@ def _extract_return_type(function_to_optimize: Any) -> str: methods = analyzer.find_methods(source_text) for method in methods: if method.name == func_name and method.return_type: - return method.return_type + return _erase_method_type_params(method.return_type, method) except Exception: logger.debug("Could not extract return type for %s", func_name) return "" diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 4290766db..0cb32f1f7 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -27,6 +27,8 @@ from codeflash.languages.java.instrumentation import ( _add_behavior_instrumentation, _add_timing_instrumentation, + _erase_method_type_params, + _extract_return_type, create_benchmark_test, instrument_existing_test, instrument_for_behavior, @@ -3485,3 +3487,114 @@ def __init__(self, path): assert math.isclose(duration, 100_000_000, rel_tol=0.15), ( f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)" ) + + +class TestEraseMethodTypeParams: + """Tests for _erase_method_type_params — erasing method-level type variables from return types.""" + + def test_generic_return_type_list(self): + """Generic method List should have T erased to ? in return type.""" + source = """public class CollectionUtils { + public static > List mergeSorted(List a, List b) { + return null; + } +} +""" + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + methods = analyzer.find_methods(source) + assert len(methods) == 1 + result = _erase_method_type_params(methods[0].return_type, methods[0]) + assert result == "List", f"Expected 'List' but got '{result}'" + + def test_bare_type_variable_erased_to_object(self): + """Generic method T max(...) should erase bare T to Object.""" + source = """public class Utils { + public static > T max(T a, T b) { + return a.compareTo(b) >= 0 ? a : b; + } +} +""" + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + methods = analyzer.find_methods(source) + assert len(methods) == 1 + result = _erase_method_type_params(methods[0].return_type, methods[0]) + assert result == "Object", f"Expected 'Object' but got '{result}'" + + def test_multiple_type_params(self): + """Generic method Map should erase both K and V.""" + source = """public class Utils { + public static Map combine(Map a, Map b) { + return null; + } +} +""" + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + methods = analyzer.find_methods(source) + assert len(methods) == 1 + result = _erase_method_type_params(methods[0].return_type, methods[0]) + assert result == "Map", f"Expected 'Map' but got '{result}'" + + def test_non_generic_method_unchanged(self): + """Non-generic method return type should be unchanged.""" + source = """public class Calculator { + public int add(int a, int b) { + return a + b; + } +} +""" + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + methods = analyzer.find_methods(source) + assert len(methods) == 1 + result = _erase_method_type_params("int", methods[0]) + assert result == "int", f"Expected 'int' but got '{result}'" + + def test_class_level_generics_not_erased(self): + """Class-level type params should NOT be erased (only method-level ones).""" + source = """public class Box { + public T getValue() { + return null; + } +} +""" + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + methods = analyzer.find_methods(source) + assert len(methods) == 1 + # T is a class-level param, not method-level — should not be erased + result = _erase_method_type_params("T", methods[0]) + assert result == "T", f"Expected 'T' (class-level generic unchanged) but got '{result}'" + + +class TestExtractReturnTypeGeneric: + """Test that _extract_return_type erases method-level type params.""" + + def test_extract_return_type_generic_method(self, tmp_path): + """_extract_return_type should return erased type for generic methods.""" + java_file = tmp_path / "CollectionUtils.java" + java_file.write_text("""package com.example; +import java.util.List; + +public class CollectionUtils { + public static > List mergeSorted(List a, List b) { + return null; + } +} +""") + + class FakeFunc: + file_path = java_file + function_name = "mergeSorted" + qualified_name = "CollectionUtils.mergeSorted" + parents = [] + + result = _extract_return_type(FakeFunc()) + assert result == "List", f"Expected 'List' but got '{result}'" From 5c5c17eb90d7ed1cad858b02ac0a2454e65451e8 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 1 Apr 2026 17:17:07 -0700 Subject: [PATCH 2/3] Fix ruff lint: use double quotes and wrap long line Co-Authored-By: Claude Opus 4.6 --- codeflash/languages/java/instrumentation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 720c286af..0e2e556a1 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -541,7 +541,9 @@ def _erase_method_type_params(return_type: str, method_node: Any) -> str: if child.type == "type_parameter": name_node = child.child_by_field_name("name") or (child.children[0] if child.children else None) if name_node: - type_var_names.add(name_node.text.decode("utf8") if isinstance(name_node.text, bytes) else str(name_node.text)) + type_var_names.add( + name_node.text.decode("utf8") if isinstance(name_node.text, bytes) else str(name_node.text) + ) if not type_var_names: return return_type @@ -553,7 +555,7 @@ def _erase_method_type_params(return_type: str, method_node: Any) -> str: # Replace type variables used as generic arguments with '?' # Match whole-word type variable names that appear as generic type arguments for tv in type_var_names: - return_type = re.sub(rf'\b{re.escape(tv)}\b', '?', return_type) + return_type = re.sub(rf"\b{re.escape(tv)}\b", "?", return_type) return return_type From 9544e4b1ccdb6fd41410294efd3b64f629e8255f Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 1 Apr 2026 17:27:07 -0700 Subject: [PATCH 3/3] full string check tests --- .../test_java/test_instrumentation.py | 373 ++++++++++++++---- 1 file changed, 297 insertions(+), 76 deletions(-) diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 0cb32f1f7..304b0a75c 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -22,13 +22,10 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.current import set_current_language -from codeflash.languages.java.maven_strategy import MavenStrategy from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.instrumentation import ( _add_behavior_instrumentation, _add_timing_instrumentation, - _erase_method_type_params, - _extract_return_type, create_benchmark_test, instrument_existing_test, instrument_for_behavior, @@ -36,6 +33,7 @@ instrument_generated_java_test, remove_instrumentation, ) +from codeflash.languages.java.maven_strategy import MavenStrategy class TestInstrumentForBehavior: @@ -2179,7 +2177,7 @@ def test_instrument_with_multibyte_in_comment(self, tmp_path: Path): # Skip all E2E tests if Maven is not available requires_maven = pytest.mark.skipif( - MavenStrategy().find_executable(Path(".")) is None, reason="Maven not found - skipping execution tests" + MavenStrategy().find_executable(Path()) is None, reason="Maven not found - skipping execution tests" ) @@ -3489,98 +3487,253 @@ def __init__(self, path): ) -class TestEraseMethodTypeParams: - """Tests for _erase_method_type_params — erasing method-level type variables from return types.""" +class TestGenericMethodTypeErasureInstrumentation: + """Tests that generic method type parameters are erased in instrumented output.""" + + def test_generic_list_return_type_erased_in_behavior_cast(self, tmp_path): + """Generic List return type should produce (List)cast in behavior mode.""" + src_file = (tmp_path / "CollectionUtils.java").resolve() + src_file.write_text( + """package com.example; +import java.util.List; - def test_generic_return_type_list(self): - """Generic method List should have T erased to ? in return type.""" - source = """public class CollectionUtils { +public class CollectionUtils { public static > List mergeSorted(List a, List b) { return null; } } -""" - from codeflash.languages.java.parser import get_java_analyzer +""", + encoding="utf-8", + ) - analyzer = get_java_analyzer() - methods = analyzer.find_methods(source) - assert len(methods) == 1 - result = _erase_method_type_params(methods[0].return_type, methods[0]) - assert result == "List", f"Expected 'List' but got '{result}'" + test_source = """package com.example; - def test_bare_type_variable_erased_to_object(self): - """Generic method T max(...) should erase bare T to Object.""" - source = """public class Utils { - public static > T max(T a, T b) { - return a.compareTo(b) >= 0 ? a : b; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +public class CollectionUtilsTest { + @Test + public void testMergeSorted() { + assertEquals(Arrays.asList(1, 2, 3, 4), CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4))); } } """ - from codeflash.languages.java.parser import get_java_analyzer + test_file = (tmp_path / "CollectionUtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") - analyzer = get_java_analyzer() - methods = analyzer.find_methods(source) - assert len(methods) == 1 - result = _erase_method_type_params(methods[0].return_type, methods[0]) - assert result == "Object", f"Expected 'Object' but got '{result}'" + func_info = FunctionToOptimize( + function_name="mergeSorted", + file_path=src_file, + starting_line=5, + ending_line=7, + parents=[], + is_method=False, + language="java", + ) - def test_multiple_type_params(self): - """Generic method Map should erase both K and V.""" - source = """public class Utils { - public static Map combine(Map a, Map b) { - return null; + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file + ) + assert success + + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; + +@SuppressWarnings("CheckReturnValue") +public class CollectionUtilsTest__perfinstrumented { + @Test + public void testMergeSorted() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "CollectionUtilsTest__perfinstrumented"; + String _cf_cls1 = "CollectionUtilsTest__perfinstrumented"; + String _cf_fn1 = "mergeSorted"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testMergeSorted"; + Object _cf_result1_1 = null; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L15_1" + "######$!"); + try { + _cf_start1_1 = System.nanoTime(); + _cf_result1_1 = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); + } finally { + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L15_1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, "L15_1"); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + assertEquals(Arrays.asList(1, 2, 3, 4), (List)_cf_result1_1); } } """ - from codeflash.languages.java.parser import get_java_analyzer + assert instrumented == expected_instrumented - analyzer = get_java_analyzer() - methods = analyzer.find_methods(source) - assert len(methods) == 1 - result = _erase_method_type_params(methods[0].return_type, methods[0]) - assert result == "Map", f"Expected 'Map' but got '{result}'" + def test_bare_type_variable_erased_to_object_in_behavior_cast(self, tmp_path): + """Generic T return type should produce (Object)cast in behavior mode.""" + src_file = (tmp_path / "Utils.java").resolve() + src_file.write_text( + """package com.example; - def test_non_generic_method_unchanged(self): - """Non-generic method return type should be unchanged.""" - source = """public class Calculator { - public int add(int a, int b) { - return a + b; +public class Utils { + public static > T max(T a, T b) { + return a.compareTo(b) >= 0 ? a : b; } } -""" - from codeflash.languages.java.parser import get_java_analyzer +""", + encoding="utf-8", + ) - analyzer = get_java_analyzer() - methods = analyzer.find_methods(source) - assert len(methods) == 1 - result = _erase_method_type_params("int", methods[0]) - assert result == "int", f"Expected 'int' but got '{result}'" + test_source = """package com.example; - def test_class_level_generics_not_erased(self): - """Class-level type params should NOT be erased (only method-level ones).""" - source = """public class Box { - public T getValue() { - return null; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class UtilsTest { + @Test + public void testMax() { + assertEquals(5, Utils.max(3, 5)); } } """ - from codeflash.languages.java.parser import get_java_analyzer + test_file = (tmp_path / "UtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="max", + file_path=src_file, + starting_line=4, + ending_line=6, + parents=[], + is_method=False, + language="java", + ) + + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file + ) + assert success - analyzer = get_java_analyzer() - methods = analyzer.find_methods(source) - assert len(methods) == 1 - # T is a class-level param, not method-level — should not be erased - result = _erase_method_type_params("T", methods[0]) - assert result == "T", f"Expected 'T' (class-level generic unchanged) but got '{result}'" + expected_instrumented = """package com.example; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; -class TestExtractReturnTypeGeneric: - """Test that _extract_return_type erases method-level type params.""" +@SuppressWarnings("CheckReturnValue") +public class UtilsTest__perfinstrumented { + @Test + public void testMax() { + // Codeflash behavior instrumentation + int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_iter1 = 1; + String _cf_mod1 = "UtilsTest__perfinstrumented"; + String _cf_cls1 = "UtilsTest__perfinstrumented"; + String _cf_fn1 = "max"; + String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE"); + String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION"); + if (_cf_testIteration1 == null) _cf_testIteration1 = "0"; + String _cf_test1 = "testMax"; + Object _cf_result1_1 = null; + long _cf_end1_1 = -1; + long _cf_start1_1 = 0; + byte[] _cf_serializedResult1_1 = null; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L13_1" + "######$!"); + try { + _cf_start1_1 = System.nanoTime(); + _cf_result1_1 = Utils.max(3, 5); + _cf_end1_1 = System.nanoTime(); + _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1); + } finally { + long _cf_end1_1_finally = System.nanoTime(); + long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L13_1" + "######!"); + // Write to SQLite if output file is set + if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) { + try { + Class.forName("org.sqlite.JDBC"); + try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) { + try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) { + _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) { + _cf_pstmt1_1.setString(1, _cf_mod1); + _cf_pstmt1_1.setString(2, _cf_cls1); + _cf_pstmt1_1.setString(3, _cf_test1); + _cf_pstmt1_1.setString(4, _cf_fn1); + _cf_pstmt1_1.setInt(5, _cf_loop1); + _cf_pstmt1_1.setString(6, "L13_1"); + _cf_pstmt1_1.setLong(7, _cf_dur1_1); + _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1); + _cf_pstmt1_1.setString(9, "function_call"); + _cf_pstmt1_1.executeUpdate(); + } + } + } catch (Exception _cf_e1_1) { + System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage()); + } + } + } + assertEquals(5, (Object)_cf_result1_1); + } +} +""" + assert instrumented == expected_instrumented - def test_extract_return_type_generic_method(self, tmp_path): - """_extract_return_type should return erased type for generic methods.""" - java_file = tmp_path / "CollectionUtils.java" - java_file.write_text("""package com.example; + def test_generic_return_type_performance_mode(self, tmp_path): + """Generic method in performance mode should compile without type variable errors.""" + src_file = (tmp_path / "CollectionUtils.java").resolve() + src_file.write_text( + """package com.example; import java.util.List; public class CollectionUtils { @@ -3588,13 +3741,81 @@ def test_extract_return_type_generic_method(self, tmp_path): return null; } } -""") +""", + encoding="utf-8", + ) - class FakeFunc: - file_path = java_file - function_name = "mergeSorted" - qualified_name = "CollectionUtils.mergeSorted" - parents = [] + test_source = """package com.example; - result = _extract_return_type(FakeFunc()) - assert result == "List", f"Expected 'List' but got '{result}'" +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +public class CollectionUtilsTest { + @Test + public void testMergeSorted() { + List result = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + assertEquals(Arrays.asList(1, 2, 3, 4), result); + } +} +""" + test_file = (tmp_path / "CollectionUtilsTest.java").resolve() + test_file.write_text(test_source, encoding="utf-8") + + func_info = FunctionToOptimize( + function_name="mergeSorted", + file_path=src_file, + starting_line=5, + ending_line=7, + parents=[], + is_method=False, + language="java", + ) + + success, instrumented = instrument_existing_test( + test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file + ) + assert success + + expected_instrumented = """package com.example; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import java.util.List; +import java.util.Arrays; + +@SuppressWarnings("CheckReturnValue") +public class CollectionUtilsTest__perfonlyinstrumented { + @Test + public void testMergeSorted() { + // Codeflash timing instrumentation with inner loop for JIT warmup + int _cf_outerLoop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX")); + int _cf_maxInnerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10")); + int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10")); + String _cf_mod1 = "CollectionUtilsTest__perfonlyinstrumented"; + String _cf_cls1 = "CollectionUtilsTest__perfonlyinstrumented"; + String _cf_test1 = "testMergeSorted"; + String _cf_fn1 = "mergeSorted"; + + List result = null; + for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) { + int _cf_loopId1 = _cf_outerLoop1 * _cf_maxInnerIterations1 + _cf_i1; + System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L12_1" + "######$!"); + long _cf_end1 = -1; + long _cf_start1 = 0; + try { + _cf_start1 = System.nanoTime(); + result = CollectionUtils.mergeSorted(Arrays.asList(1, 3), Arrays.asList(2, 4)); + _cf_end1 = System.nanoTime(); + } finally { + long _cf_end1_finally = System.nanoTime(); + long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1; + System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L12_1" + ":" + _cf_dur1 + "######!"); + } + } + assertEquals(Arrays.asList(1, 2, 3, 4), result); + } +} +""" + assert instrumented == expected_instrumented