diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 4366468d0..9486fc677 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1497,73 +1497,207 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca return False -class AsyncDecoratorImportAdder(cst.CSTTransformer): - """Transformer that adds the import for async decorators.""" +ASYNC_HELPER_INLINE_CODE = """import asyncio +import gc +import os +import sqlite3 +import time +from functools import wraps +from pathlib import Path +from tempfile import TemporaryDirectory + +import dill as pickle - def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None: - self.mode = mode - self.has_import = False - - def _get_decorator_name(self) -> str: - """Get the decorator name based on the testing mode.""" - if self.mode == TestingMode.BEHAVIOR: - return "codeflash_behavior_async" - if self.mode == TestingMode.CONCURRENCY: - return "codeflash_concurrency_async" - return "codeflash_performance_async" - - def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - # Check if the async decorator import is already present - if ( - isinstance(node.module, cst.Attribute) - and isinstance(node.module.value, cst.Attribute) - and isinstance(node.module.value.value, cst.Name) - and node.module.value.value.value == "codeflash" - and node.module.value.attr.value == "code_utils" - and node.module.attr.value == "codeflash_wrap_decorator" - and not isinstance(node.names, cst.ImportStar) - ): - decorator_name = self._get_decorator_name() - for import_alias in node.names: - if import_alias.name.value == decorator_name: - self.has_import = True - - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - # If the import is already there, don't add it again - if self.has_import: - return updated_node - - # Choose import based on mode - decorator_name = self._get_decorator_name() - - # Parse the import statement into a CST node - import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}") - - # Add the import to the module's body - return updated_node.with_changes(body=[import_node, *list(updated_node.body)]) + +def get_run_tmp_file(file_path): + if not hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") + return Path(get_run_tmp_file.tmpdir.name) / file_path + + +def extract_test_context_from_env(): + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) + + +def codeflash_behavior_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0") + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + codeflash_con = sqlite3.connect(db_path) + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, " + "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + "runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}######!") + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value)) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) + codeflash_con.commit() + codeflash_con.close() + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_performance_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + function_name = func.__name__ + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] + loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + test_module_name, test_class_name, test_name = extract_test_context_from_env() + test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" + if not hasattr(async_wrapper, "index"): + async_wrapper.index = {} + if test_id in async_wrapper.index: + async_wrapper.index[test_id] += 1 + else: + async_wrapper.index[test_id] = 0 + codeflash_test_index = async_wrapper.index[test_id] + invocation_id = f"{line_id}_{codeflash_test_index}" + class_prefix = (test_class_name + ".") if test_class_name else "" + test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}" + print(f"!$######{test_stdout_tag}######$!") + exception = None + counter = loop.time() + gc.disable() + try: + ret = func(*args, **kwargs) + counter = loop.time() + return_value = await ret + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + except Exception as e: + codeflash_duration = int((loop.time() - counter) * 1_000_000_000) + exception = e + finally: + gc.enable() + print(f"!######{test_stdout_tag}:{codeflash_duration}######!") + if exception: + raise exception + return return_value + return async_wrapper + + +def codeflash_concurrency_async(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + function_name = func.__name__ + concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")) + test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "") + test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "") + test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "") + loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0") + gc.disable() + try: + seq_start = time.perf_counter_ns() + for _ in range(concurrency_factor): + result = await func(*args, **kwargs) + sequential_time = time.perf_counter_ns() - seq_start + finally: + gc.enable() + gc.disable() + try: + conc_start = time.perf_counter_ns() + tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)] + await asyncio.gather(*tasks) + concurrent_time = time.perf_counter_ns() - conc_start + finally: + gc.enable() + tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}" + print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!") + return result + return async_wrapper +""" + +ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" + + +def get_decorator_name_for_mode(mode: TestingMode) -> str: + if mode == TestingMode.BEHAVIOR: + return "codeflash_behavior_async" + if mode == TestingMode.CONCURRENCY: + return "codeflash_concurrency_async" + return "codeflash_performance_async" + + +def write_async_helper_file(target_dir: Path) -> Path: + """Write the async decorator helper file to the target directory.""" + helper_path = target_dir / ASYNC_HELPER_FILENAME + if not helper_path.exists(): + helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8") + return helper_path def add_async_decorator_to_function( - source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR + source_path: Path, + function: FunctionToOptimize, + mode: TestingMode = TestingMode.BEHAVIOR, + project_root: Path | None = None, ) -> bool: """Add async decorator to an async function definition and write back to file. - Args: - ---- - source_path: Path to the source file to modify in-place. - function: The FunctionToOptimize object representing the target async function. - mode: The testing mode to determine which decorator to apply. - - Returns: - ------- - Boolean indicating whether the decorator was successfully added. + Writes a helper file containing the decorator implementation to project_root (or source directory + as fallback) and adds a standard import + decorator to the source file. """ if not function.is_async: return False try: - # Read source code with source_path.open(encoding="utf8") as f: source_code = f.read() @@ -1573,10 +1707,14 @@ def add_async_decorator_to_function( decorator_transformer = AsyncDecoratorAdder(function, mode) module = module.visit(decorator_transformer) - # Add the import if decorator was added if decorator_transformer.added_decorator: - import_transformer = AsyncDecoratorImportAdder(mode) - module = module.visit(import_transformer) + # Write the helper file to project_root (on sys.path) or source dir as fallback + helper_dir = project_root if project_root is not None else source_path.parent + write_async_helper_file(helper_dir) + # Add the import via CST so sort_imports can place it correctly + decorator_name = get_decorator_name_for_mode(mode) + import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}") + module = module.with_changes(body=[import_node, *list(module.body)]) modified_code = sort_imports(code=module.code, float_to_top=True) except Exception as e: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ed7f7f1fe..0a515076c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1897,6 +1897,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure(baseline_result.failure()) original_code_baseline, test_functions_to_remove = baseline_result.unwrap() @@ -1908,6 +1909,7 @@ def setup_and_establish_baseline( if self.args.override_fixtures: restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) + self.cleanup_async_helper_file() return Failure("The threshold for test confidence was not met.") return Success( @@ -2279,6 +2281,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + self.cleanup_async_helper_file() + + def cleanup_async_helper_file(self) -> None: + from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME + + helper_path = self.project_root / ASYNC_HELPER_FILENAME + helper_path.unlink(missing_ok=True) def establish_original_code_baseline( self, @@ -2296,7 +2305,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function success = add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) # Instrument codeflash capture @@ -2361,7 +2373,10 @@ def establish_original_code_baseline( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -2535,7 +2550,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.BEHAVIOR, + project_root=self.project_root, ) try: @@ -2611,7 +2629,10 @@ def run_optimized_candidate( from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.PERFORMANCE, + project_root=self.project_root, ) try: @@ -2974,7 +2995,10 @@ def run_concurrency_benchmark( try: # Add concurrency decorator to the source function add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY + self.function_to_optimize.file_path, + self.function_to_optimize, + TestingMode.CONCURRENCY, + project_root=self.project_root, ) # Run the concurrency benchmark tests diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py index 0b4bf8957..0e38ae797 100644 --- a/tests/scripts/end_to_end_test_async.py +++ b/tests/scripts/end_to_end_test_async.py @@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool: CoverageExpectation( function_name="retry_with_backoff", expected_coverage=100.0, - expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], ) ], ) diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 1eb667b3f..1777a1c73 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -8,7 +8,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -55,16 +57,23 @@ async def test_async_sort(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # For async functions, instrument the source module directly with decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = fto_path.read_text("utf-8") - assert ( - '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' - in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + decorated_original = original_code.replace( + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() # Add codeflash capture instrument_codeflash_capture(func, {}, tests_root) @@ -142,6 +151,9 @@ async def test_async_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -182,7 +194,9 @@ async def test_async_class_sort(): is_async=True, ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -264,6 +278,9 @@ async def test_async_class_sort(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -294,16 +311,23 @@ async def test_async_perf(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) # Instrument the source module with async performance decorators - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path + ) assert source_success # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - assert ( - instrumented_source - == '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + decorated_original = original_code.replace( + "async def async_sorter", f"@{decorator_name}\nasync def async_sorter" ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) @@ -359,6 +383,9 @@ async def test_async_perf(): # Clean up test files if test_path.exists(): test_path.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -404,68 +431,24 @@ async def async_error_function(lst): function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True ) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success # Verify the file was modified instrumented_source = fto_path.read_text("utf-8") - expected_instrumented_source = """import asyncio -from typing import List, Union - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async + from codeflash.code_utils.formatter import sort_imports - -async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation for testing. - \"\"\" - print("codeflash stdout: Async sorting list") - - await asyncio.sleep(0.01) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - print(f"result: {result}") - return result - - -class AsyncBubbleSorter: - \"\"\"Class with async sorting method for testing.\"\"\" - - async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]: - \"\"\" - Async bubble sort implementation within a class. - \"\"\" - print("codeflash stdout: AsyncBubbleSorter.sorter() called") - - # Add some async delay - await asyncio.sleep(0.005) - - n = len(lst) - for i in range(n): - for j in range(0, n - i - 1): - if lst[j] > lst[j + 1]: - lst[j], lst[j + 1] = lst[j + 1], lst[j] - - result = lst.copy() - return result - - -@codeflash_behavior_async -async def async_error_function(lst): - \"\"\"Async function that raises an error for testing.\"\"\" - await asyncio.sleep(0.001) # Small delay - raise ValueError("Test error") -""" - assert expected_instrumented_source == instrumented_source + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + decorated_modified = modified_code.replace( + "async def async_error_function", f"@{decorator_name}\nasync def async_error_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() instrument_codeflash_capture(func, {}, tests_root) opt = Optimizer( @@ -526,6 +509,9 @@ async def async_error_function(lst): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -563,7 +549,9 @@ async def test_async_multi(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -636,6 +624,9 @@ async def test_async_multi(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -678,7 +669,9 @@ async def test_async_edge_cases(): func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) - source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success instrument_codeflash_capture(func, {}, tests_root) @@ -753,6 +746,9 @@ async def test_async_edge_cases(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -987,7 +983,9 @@ async def test_mixed_sorting(): function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True ) - source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR) + source_success = add_async_decorator_to_function( + mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path + ) assert source_success @@ -1060,3 +1058,6 @@ async def test_mixed_sorting(): test_path.unlink() if test_path_perf.exists(): test_path_perf.unlink() + helper_path = project_root_path / ASYNC_HELPER_FILENAME + if helper_path.exists(): + helper_path.unlink() diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index 29e65ad06..0e57ec209 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -6,7 +6,9 @@ import pytest from codeflash.code_utils.instrument_existing_tests import ( + ASYNC_HELPER_FILENAME, add_async_decorator_to_function, + get_decorator_name_for_mode, inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -@codeflash_behavior_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_performance_async - - -@codeflash_performance_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir): async_function_code = ''' import asyncio -async def async_function(x: int, y: int) -> int: - """Simple async function for testing.""" - await asyncio.sleep(0.01) - return x * y -''' - - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_concurrency_async - - -@codeflash_concurrency_async async def async_function(x: int, y: int) -> int: """Simple async function for testing.""" await asyncio.sleep(0.01) @@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY) + code_with_decorator = async_function_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") @@ -182,27 +169,6 @@ def sync_method(self, a: int, b: int) -> int: return a - b ''' - expected_decorated_code = ''' -import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class Calculator: - """Test class with async methods.""" - - @codeflash_behavior_async - async def async_method(self, a: int, b: int) -> int: - """Async method in class.""" - await asyncio.sleep(0.005) - return a ** b - - def sync_method(self, a: int, b: int) -> int: - """Sync method in class.""" - return a - b -''' - test_file = temp_dir / "test_async.py" test_file.write_text(async_class_code) @@ -217,11 +183,21 @@ def sync_method(self, a: int, b: int) -> int: assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_decorated_code.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = async_class_code.replace( + " async def async_method", f" @{decorator_name}\n async def async_method" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_async_decorator_no_duplicate_application(temp_dir): + # Case 1: Old-style import already present — injector should detect and skip already_decorated_code = ''' from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async import asyncio @@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int: # Should not add duplicate decorator assert not decorator_added + # Case 2: Inline definition already present — injector should detect and skip + already_inline_code = ''' +import asyncio + +def codeflash_behavior_async(func): + return func + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Already decorated async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + test_file2 = temp_dir / "test_async2.py" + test_file2.write_text(already_inline_code) + + func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True) + + decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR) + + # Should not add duplicate decorator + assert not decorator_added2 + @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows") def test_inject_profiling_async_function_behavior_mode(temp_dir): @@ -285,11 +285,18 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR @@ -340,12 +347,18 @@ async def test_async_function(): assert source_success is True - # Verify the file was modified + # Verify the file was modified with exact expected output instrumented_source = source_file.read_text() - assert "@codeflash_performance_async" in instrumented_source - # Check for the import with line continuation formatting - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_performance_async" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() # Now test the full pipeline with source module path success, instrumented_test_code = inject_profiling_into_existing_test( @@ -406,11 +419,16 @@ async def test_mixed_functions(): # Verify the file was modified instrumented_source = source_file.read_text() - assert "@codeflash_behavior_async" in instrumented_source - assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source - assert "codeflash_behavior_async" in instrumented_source - # Sync function should remain unchanged - assert "def sync_function(x: int, y: int) -> int:" in instrumented_source + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = source_module_code.replace( + "async def async_function", f"@{decorator_name}\nasync def async_function" + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert instrumented_source.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() success, instrumented_test_code = inject_profiling_into_existing_test( test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR @@ -446,24 +464,19 @@ async def nested_async_method(self, x: int) -> int: decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR) - expected_output = """import asyncio - -from codeflash.code_utils.codeflash_wrap_decorator import \\ - codeflash_behavior_async - - -class OuterClass: - class InnerClass: - @codeflash_behavior_async - async def nested_async_method(self, x: int) -> int: - \"\"\"Nested async method.\"\"\" - await asyncio.sleep(0.001) - return x * 2 -""" - assert decorator_added modified_code = test_file.read_text() - assert modified_code.strip() == expected_output.strip() + from codeflash.code_utils.formatter import sort_imports + + decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR) + code_with_decorator = nested_async_code.replace( + " async def nested_async_method", + f" @{decorator_name}\n async def nested_async_method", + ) + code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}" + expected = sort_imports(code=code_with_import, float_to_top=True) + assert modified_code.strip() == expected.strip() + assert (temp_dir / ASYNC_HELPER_FILENAME).exists() @pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")