diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index daee371d7..36120ec77 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -352,32 +352,52 @@ def _handle_show_config() -> None: from codeflash.setup.detector import detect_project, has_existing_config project_root = Path.cwd() - detected = detect_project(project_root) + config_exists, _ = has_existing_config(project_root) - # Check if config exists or is auto-detected - config_exists, config_file = has_existing_config(project_root) - status = "Saved config" if config_exists else "Auto-detected (not saved)" + if config_exists: + from codeflash.code_utils.config_parser import parse_config_file - console.print() - console.print(f"[bold]Codeflash Configuration[/bold] ({status})") - if config_exists and config_file: - console.print(f"[dim]Config file: {project_root / config_file}[/dim]") - console.print() + config, config_file_path = parse_config_file() + status = "Saved config" - table = Table(show_header=True, header_style="bold cyan") - table.add_column("Setting", style="dim") - table.add_column("Value") - - table.add_row("Language", detected.language) - table.add_row("Project root", str(detected.project_root)) - table.add_row("Module root", str(detected.module_root)) - table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)") - table.add_row("Test runner", detected.test_runner or "(not detected)") - table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)") - table.add_row( - "Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)" - ) - table.add_row("Confidence", f"{detected.confidence:.0%}") + console.print() + console.print(f"[bold]Codeflash Configuration[/bold] ({status})") + console.print(f"[dim]Config file: {config_file_path}[/dim]") + console.print() + + table = Table(show_header=True, header_style="bold cyan") + table.add_column("Setting", style="dim") + table.add_column("Value") + + table.add_row("Project root", str(project_root)) + table.add_row("Module root", config.get("module_root", "(not set)")) + table.add_row("Tests root", config.get("tests_root", "(not set)")) + table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)"))) + table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)") + ignore_paths = config.get("ignore_paths", []) + table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)") + else: + detected = detect_project(project_root) + status = "Auto-detected (not saved)" + + console.print() + console.print(f"[bold]Codeflash Configuration[/bold] ({status})") + console.print() + + table = Table(show_header=True, header_style="bold cyan") + table.add_column("Setting", style="dim") + table.add_column("Value") + + table.add_row("Language", detected.language) + table.add_row("Project root", str(detected.project_root)) + table.add_row("Module root", str(detected.module_root)) + table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)") + table.add_row("Test runner", detected.test_runner or "(not detected)") + table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)") + table.add_row( + "Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)" + ) + table.add_row("Confidence", f"{detected.confidence:.0%}") console.print(table) console.print() diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index ff04b5037..42cfa9703 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from codeflash.result.critic import performance_gain + def humanize_runtime(time_in_ns: int) -> str: runtime_human: str = str(time_in_ns) @@ -89,3 +91,13 @@ def format_perf(percentage: float) -> str: if abs_perc >= 1: return f"{percentage:.2f}" return f"{percentage:.3f}" + + +def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str: + perf_gain = format_perf( + abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100) + ) + status = "slower" if optimized_time_ns > original_time_ns else "faster" + return ( + f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})" + ) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 1a032ec36..5c2afe40d 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -728,6 +728,10 @@ def discover_tests_pytest( logger.debug(f"Pytest collection exit code: {exitcode}") if pytest_rootdir is not None: cfg.tests_project_rootdir = Path(pytest_rootdir) + if discover_only_these_tests: + resolved_discover_only = {p.resolve() for p in discover_only_these_tests} + else: + resolved_discover_only = None file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list) for test in tests: if "__replay_test" in test["test_file"]: @@ -737,13 +741,14 @@ def discover_tests_pytest( else: test_type = TestType.EXISTING_UNIT_TEST + test_file_path = Path(test["test_file"]).resolve() test_obj = TestsInFile( - test_file=Path(test["test_file"]), + test_file=test_file_path, test_class=test["test_class"], test_function=test["test_function"], test_type=test_type, ) - if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests: + if resolved_discover_only and test_obj.test_file not in resolved_discover_only: continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations diff --git a/codeflash/languages/javascript/edit_tests.py b/codeflash/languages/javascript/edit_tests.py index 00ba04f9c..601da3cda 100644 --- a/codeflash/languages/javascript/edit_tests.py +++ b/codeflash/languages/javascript/edit_tests.py @@ -11,27 +11,8 @@ from pathlib import Path from codeflash.cli_cmds.console import logger -from codeflash.code_utils.time_utils import format_perf, format_time +from codeflash.code_utils.time_utils import format_runtime_comment from codeflash.models.models import GeneratedTests, GeneratedTestsList -from codeflash.result.critic import performance_gain - - -def format_runtime_comment(original_time: int, optimized_time: int) -> str: - """Format a runtime comparison comment for JavaScript. - - Args: - original_time: Original runtime in nanoseconds. - optimized_time: Optimized runtime in nanoseconds. - - Returns: - Formatted comment string with // prefix. - - """ - perf_gain = format_perf( - abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100) - ) - status = "slower" if optimized_time > original_time else "faster" - return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})" def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str: @@ -120,7 +101,7 @@ def find_matching_test(test_description: str) -> str | None: # Only add comment if line has a function call and doesn't already have a comment if func_call_pattern.search(line) and "//" not in line and "expect(" in line: orig_time, opt_time = timing_by_full_name[current_matched_full_name] - comment = format_runtime_comment(orig_time, opt_time) + comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//") logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}") # Add comment at end of line line = f"{line.rstrip()} {comment}" diff --git a/codeflash/languages/python/static_analysis/code_replacer.py b/codeflash/languages/python/static_analysis/code_replacer.py index 4e100a230..fd607d975 100644 --- a/codeflash/languages/python/static_analysis/code_replacer.py +++ b/codeflash/languages/python/static_analysis/code_replacer.py @@ -239,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None: test_path.write_text(modified_module.code, encoding="utf-8") -class OptimFunctionCollector(cst.CSTVisitor): - METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) - - def __init__( - self, - preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None, - function_names: set[tuple[str | None, str]] | None = None, - ) -> None: - super().__init__() - self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set() - - self.function_names = function_names # set of (class_name, function_name) - self.modified_functions: dict[ - tuple[str | None, str], cst.FunctionDef - ] = {} # keys are (class_name, function_name) - self.new_functions: list[cst.FunctionDef] = [] - self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) - self.new_classes: list[cst.ClassDef] = [] - self.current_class = None - self.modified_init_functions: dict[str, cst.FunctionDef] = {} - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - if (self.current_class, node.name.value) in self.function_names: - self.modified_functions[(self.current_class, node.name.value)] = node - elif self.current_class and node.name.value == "__init__": - self.modified_init_functions[self.current_class] = node - elif ( - self.preexisting_objects - and (node.name.value, ()) not in self.preexisting_objects - and self.current_class is None - ): - self.new_functions.append(node) - return False - - def visit_ClassDef(self, node: cst.ClassDef) -> bool: - if self.current_class: - return False # If already in a class, do not recurse deeper - self.current_class = node.name.value - - parents = (FunctionParent(name=node.name.value, type="ClassDef"),) - - if (node.name.value, ()) not in self.preexisting_objects: - self.new_classes.append(node) - - for child_node in node.body.body: - if ( - self.preexisting_objects - and isinstance(child_node, cst.FunctionDef) - and (child_node.name.value, parents) not in self.preexisting_objects - ): - self.new_class_functions[node.name.value].append(child_node) - - return True - - def leave_ClassDef(self, node: cst.ClassDef) -> None: - if self.current_class: - self.current_class = None - - -class OptimFunctionReplacer(cst.CSTTransformer): - def __init__( - self, - modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None, - new_classes: Optional[list[cst.ClassDef]] = None, - new_functions: Optional[list[cst.FunctionDef]] = None, - new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None, - modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None, - ) -> None: - super().__init__() - self.modified_functions = modified_functions if modified_functions is not None else {} - self.new_functions = new_functions if new_functions is not None else [] - self.new_classes = new_classes if new_classes is not None else [] - self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list) - self.modified_init_functions: dict[str, cst.FunctionDef] = ( - modified_init_functions if modified_init_functions is not None else {} - ) - self.current_class = None - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - return False - - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - if (self.current_class, original_node.name.value) in self.modified_functions: - node = self.modified_functions[(self.current_class, original_node.name.value)] - return updated_node.with_changes(body=node.body, decorators=node.decorators) - if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions: - return self.modified_init_functions[self.current_class] - - return updated_node - - def visit_ClassDef(self, node: cst.ClassDef) -> bool: - if self.current_class: - return False # If already in a class, do not recurse deeper - self.current_class = node.name.value - return True - - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: - if self.current_class and self.current_class == original_node.name.value: - self.current_class = None - if original_node.name.value in self.new_class_functions: - return updated_node.with_changes( - body=updated_node.body.with_changes( - body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value])) - ) - ) - return updated_node - - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - node = updated_node - max_function_index = None - max_class_index = None - for index, _node in enumerate(node.body): - if isinstance(_node, cst.FunctionDef): - max_function_index = index - if isinstance(_node, cst.ClassDef): - max_class_index = index - - if self.new_classes: - existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)} - - unique_classes = [ - new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names - ] - if unique_classes: - new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node) - new_body = list( - chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:]) - ) - node = node.with_changes(body=new_body) - - if max_function_index is not None: - node = node.with_changes( - body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :]) - ) - elif max_class_index is not None: - node = node.with_changes( - body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :]) - ) - else: - node = node.with_changes(body=(*self.new_functions, *node.body)) - return node - - def replace_functions_in_file( source_code: str, original_function_names: list[str], diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 77d9108ab..8024d7b88 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -11,7 +11,6 @@ from codeflash.languages.python.static_analysis.code_replacer import ( AddRequestArgument, AutouseFixtureModifier, - OptimFunctionCollector, PytestMarkAdder, is_zero_diff, replace_functions_and_add_imports, @@ -3476,142 +3475,6 @@ def hydrate_input_text_actions_with_field_names( assert new_code == expected -# OptimFunctionCollector async function tests -def test_optim_function_collector_with_async_functions(): - """Test OptimFunctionCollector correctly collects async functions.""" - import libcst as cst - - source_code = """ -def sync_function(): - return "sync" - -async def async_function(): - return "async" - -class TestClass: - def sync_method(self): - return "sync_method" - - async def async_method(self): - return "async_method" -""" - - tree = cst.parse_module(source_code) - collector = OptimFunctionCollector( - function_names={ - (None, "sync_function"), - (None, "async_function"), - ("TestClass", "sync_method"), - ("TestClass", "async_method"), - }, - preexisting_objects=None, - ) - tree.visit(collector) - - # Should collect both sync and async functions - assert len(collector.modified_functions) == 4 - assert (None, "sync_function") in collector.modified_functions - assert (None, "async_function") in collector.modified_functions - assert ("TestClass", "sync_method") in collector.modified_functions - assert ("TestClass", "async_method") in collector.modified_functions - - -def test_optim_function_collector_new_async_functions(): - """Test OptimFunctionCollector identifies new async functions not in preexisting objects.""" - import libcst as cst - - source_code = """ -def existing_function(): - return "existing" - -async def new_async_function(): - return "new_async" - -def new_sync_function(): - return "new_sync" - -class ExistingClass: - async def new_class_async_method(self): - return "new_class_async" -""" - - # Only existing_function is in preexisting objects - preexisting_objects = {("existing_function", ())} - - tree = cst.parse_module(source_code) - collector = OptimFunctionCollector( - function_names=set(), # Not looking for specific functions - preexisting_objects=preexisting_objects, - ) - tree.visit(collector) - - # Should identify new functions (both sync and async) - assert len(collector.new_functions) == 2 - function_names = [func.name.value for func in collector.new_functions] - assert "new_async_function" in function_names - assert "new_sync_function" in function_names - - # Should identify new class methods - assert "ExistingClass" in collector.new_class_functions - assert len(collector.new_class_functions["ExistingClass"]) == 1 - assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method" - - -def test_optim_function_collector_mixed_scenarios(): - """Test OptimFunctionCollector with complex mix of sync/async functions and classes.""" - import libcst as cst - - source_code = """ -# Global functions -def global_sync(): - pass - -async def global_async(): - pass - -class ParentClass: - def __init__(self): - pass - - def sync_method(self): - pass - - async def async_method(self): - pass - -class ChildClass: - async def child_async_method(self): - pass - - def child_sync_method(self): - pass -""" - - # Looking for specific functions - function_names = { - (None, "global_sync"), - (None, "global_async"), - ("ParentClass", "sync_method"), - ("ParentClass", "async_method"), - ("ChildClass", "child_async_method"), - } - - tree = cst.parse_module(source_code) - collector = OptimFunctionCollector(function_names=function_names, preexisting_objects=None) - tree.visit(collector) - - # Should collect all specified functions (mix of sync and async) - assert len(collector.modified_functions) == 5 - assert (None, "global_sync") in collector.modified_functions - assert (None, "global_async") in collector.modified_functions - assert ("ParentClass", "sync_method") in collector.modified_functions - assert ("ParentClass", "async_method") in collector.modified_functions - assert ("ChildClass", "child_async_method") in collector.modified_functions - - # Should collect __init__ method - assert "ParentClass" in collector.modified_init_functions - - def test_is_zero_diff_async_sleep(): original_code = """ import time