Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,32 +352,52 @@ def _handle_show_config() -> None:
from codeflash.setup.detector import detect_project, has_existing_config

project_root = Path.cwd()
detected = detect_project(project_root)
config_exists, _ = has_existing_config(project_root)

# Check if config exists or is auto-detected
config_exists, config_file = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
if config_exists:
from codeflash.code_utils.config_parser import parse_config_file

console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
if config_exists and config_file:
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
console.print()
config, config_file_path = parse_config_file()
status = "Saved config"

table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")

table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print(f"[dim]Config file: {config_file_path}[/dim]")
console.print()

table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")

table.add_row("Project root", str(project_root))
table.add_row("Module root", config.get("module_root", "(not set)"))
table.add_row("Tests root", config.get("tests_root", "(not set)"))
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
ignore_paths = config.get("ignore_paths", [])
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
else:
detected = detect_project(project_root)
status = "Auto-detected (not saved)"

console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print()

table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")

table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")

console.print(table)
console.print()
Expand Down
12 changes: 12 additions & 0 deletions codeflash/code_utils/time_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from codeflash.result.critic import performance_gain


def humanize_runtime(time_in_ns: int) -> str:
runtime_human: str = str(time_in_ns)
Expand Down Expand Up @@ -89,3 +91,13 @@ def format_perf(percentage: float) -> str:
if abs_perc >= 1:
return f"{percentage:.2f}"
return f"{percentage:.3f}"


def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str:
perf_gain = format_perf(
abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100)
)
status = "slower" if optimized_time_ns > original_time_ns else "faster"
return (
f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})"
)
9 changes: 7 additions & 2 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,10 @@ def discover_tests_pytest(
logger.debug(f"Pytest collection exit code: {exitcode}")
if pytest_rootdir is not None:
cfg.tests_project_rootdir = Path(pytest_rootdir)
if discover_only_these_tests:
resolved_discover_only = {p.resolve() for p in discover_only_these_tests}
else:
resolved_discover_only = None
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
for test in tests:
if "__replay_test" in test["test_file"]:
Expand All @@ -737,13 +741,14 @@ def discover_tests_pytest(
else:
test_type = TestType.EXISTING_UNIT_TEST

test_file_path = Path(test["test_file"]).resolve()
test_obj = TestsInFile(
test_file=Path(test["test_file"]),
test_file=test_file_path,
test_class=test["test_class"],
test_function=test["test_function"],
test_type=test_type,
)
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
if resolved_discover_only and test_obj.test_file not in resolved_discover_only:
continue
file_to_test_map[test_obj.test_file].append(test_obj)
# Within these test files, find the project functions they are referring to and return their names/locations
Expand Down
23 changes: 2 additions & 21 deletions codeflash/languages/javascript/edit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,8 @@
from pathlib import Path

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.code_utils.time_utils import format_runtime_comment
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from codeflash.result.critic import performance_gain


def format_runtime_comment(original_time: int, optimized_time: int) -> str:
"""Format a runtime comparison comment for JavaScript.

Args:
original_time: Original runtime in nanoseconds.
optimized_time: Optimized runtime in nanoseconds.

Returns:
Formatted comment string with // prefix.

"""
perf_gain = format_perf(
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
)
status = "slower" if optimized_time > original_time else "faster"
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"


def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
Expand Down Expand Up @@ -120,7 +101,7 @@ def find_matching_test(test_description: str) -> str | None:
# Only add comment if line has a function call and doesn't already have a comment
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
comment = format_runtime_comment(orig_time, opt_time)
comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//")
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
# Add comment at end of line
line = f"{line.rstrip()} {comment}"
Expand Down
143 changes: 0 additions & 143 deletions codeflash/languages/python/static_analysis/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
test_path.write_text(modified_module.code, encoding="utf-8")


class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

def __init__(
self,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
function_names: set[tuple[str | None, str]] | None = None,
) -> None:
super().__init__()
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()

self.function_names = function_names # set of (class_name, function_name)
self.modified_functions: dict[
tuple[str | None, str], cst.FunctionDef
] = {} # keys are (class_name, function_name)
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.new_classes: list[cst.ClassDef] = []
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
if (self.current_class, node.name.value) in self.function_names:
self.modified_functions[(self.current_class, node.name.value)] = node
elif self.current_class and node.name.value == "__init__":
self.modified_init_functions[self.current_class] = node
elif (
self.preexisting_objects
and (node.name.value, ()) not in self.preexisting_objects
and self.current_class is None
):
self.new_functions.append(node)
return False

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
if self.current_class:
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value

parents = (FunctionParent(name=node.name.value, type="ClassDef"),)

if (node.name.value, ()) not in self.preexisting_objects:
self.new_classes.append(node)

for child_node in node.body.body:
if (
self.preexisting_objects
and isinstance(child_node, cst.FunctionDef)
and (child_node.name.value, parents) not in self.preexisting_objects
):
self.new_class_functions[node.name.value].append(child_node)

return True

def leave_ClassDef(self, node: cst.ClassDef) -> None:
if self.current_class:
self.current_class = None


class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
new_classes: Optional[list[cst.ClassDef]] = None,
new_functions: Optional[list[cst.FunctionDef]] = None,
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
) -> None:
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
self.new_functions = new_functions if new_functions is not None else []
self.new_classes = new_classes if new_classes is not None else []
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
self.modified_init_functions: dict[str, cst.FunctionDef] = (
modified_init_functions if modified_init_functions is not None else {}
)
self.current_class = None

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
return False

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if (self.current_class, original_node.name.value) in self.modified_functions:
node = self.modified_functions[(self.current_class, original_node.name.value)]
return updated_node.with_changes(body=node.body, decorators=node.decorators)
if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions:
return self.modified_init_functions[self.current_class]

return updated_node

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
if self.current_class:
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value
return True

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
if self.current_class and self.current_class == original_node.name.value:
self.current_class = None
if original_node.name.value in self.new_class_functions:
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
)
)
return updated_node

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
node = updated_node
max_function_index = None
max_class_index = None
for index, _node in enumerate(node.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
max_class_index = index

if self.new_classes:
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}

unique_classes = [
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
]
if unique_classes:
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
new_body = list(
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
)
node = node.with_changes(body=new_body)

if max_function_index is not None:
node = node.with_changes(
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
)
elif max_class_index is not None:
node = node.with_changes(
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.new_functions, *node.body))
return node


def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
Expand Down
Loading
Loading