diff --git a/CLAUDE.md b/CLAUDE.md index 622351db4..c4628e91a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,10 @@ Discovery → Ranking → Context Extraction → Test Gen + Optimization → Bas - **Tracer**: Profiling system that records function call trees and timings (`tracing/`, `tracer.py`) - **Worktree mode**: Git worktree-based parallel optimization (`--worktree` flag) +## PR Reviews + +- GitHub PR comments and review feedback can be stale — they may reference issues already fixed by a later commit. Before acting on review feedback, verify it still applies to the current code. If the issue no longer exists, resolve the conversation in the GitHub UI. + # Agent Rules diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 8542547a4..3e10da319 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -15,7 +15,7 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import FunctionSource + from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId from codeflash.languages.language_enum import Language from codeflash.models.function_types import FunctionParent @@ -538,6 +538,87 @@ def remove_test_functions(self, test_source: str, functions_to_remove: list[str] """ ... + def postprocess_generated_tests( + self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path + ) -> GeneratedTestsList: + """Apply language-specific postprocessing to generated tests. + + Args: + generated_tests: Generated tests to update. + test_framework: Test framework used for the project. + project_root: Project root directory. + source_file_path: Path to the source file under optimization. + + Returns: + Updated generated tests. + + """ + ... + + def remove_test_functions_from_generated_tests( + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] + ) -> GeneratedTestsList: + """Remove specific test functions from generated tests. + + Args: + generated_tests: Generated tests to update. + functions_to_remove: List of function names to remove. + + Returns: + Updated generated tests. + + """ + ... + + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, + ) -> GeneratedTestsList: + """Add runtime comments to generated tests. + + Args: + generated_tests: Generated tests to update. + original_runtimes: Mapping of invocation IDs to original runtimes. + optimized_runtimes: Mapping of invocation IDs to optimized runtimes. + tests_project_rootdir: Root directory for tests (if applicable). + + Returns: + Updated generated tests. + + """ + ... + + def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: + """Add new global declarations from optimized code to original source. + + Args: + optimized_code: The optimized code that may contain new declarations. + original_source: The original source code. + module_abspath: Path to the module file (for parser selection). + + Returns: + Original source with new declarations added. + + """ + ... + + def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None: + """Extract the source code of a calling function. + + Args: + source_code: Full source code of the file. + function_name: Name of the function to extract. + ref_line: Line number where the reference is. + + Returns: + Source code of the function, or None if not found. + + """ + ... + # === Test Result Comparison === def compare_test_results( diff --git a/codeflash/languages/javascript/code_replacer.py b/codeflash/languages/javascript/code_replacer.py new file mode 100644 index 000000000..83c96ec6a --- /dev/null +++ b/codeflash/languages/javascript/code_replacer.py @@ -0,0 +1,217 @@ +"""JavaScript/TypeScript code replacement helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import logger + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.languages.base import Language + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + + +# Author: ali +def _add_global_declarations_for_language( + optimized_code: str, original_source: str, module_abspath: Path, language: Language +) -> str: + """Add new global declarations from optimized code to original source. + + Finds module-level declarations (const, let, var, class, type, interface, enum) + in the optimized code that don't exist in the original source and adds them. + + New declarations are inserted after any existing declarations they depend on. + For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO` + is already declared in the original source, `_has` will be inserted after `FOO`. + + Args: + optimized_code: The optimized code that may contain new declarations. + original_source: The original source code. + module_abspath: Path to the module file (for parser selection). + language: The language of the code. + + Returns: + Original source with new declarations added in dependency order. + + """ + from codeflash.languages.base import Language + + if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT): + return original_source + + try: + from codeflash.languages.javascript.treesitter import get_analyzer_for_file + + analyzer = get_analyzer_for_file(module_abspath) + + original_declarations = analyzer.find_module_level_declarations(original_source) + optimized_declarations = analyzer.find_module_level_declarations(optimized_code) + + if not optimized_declarations: + return original_source + + existing_names = _get_existing_names(original_declarations, analyzer, original_source) + new_declarations = _filter_new_declarations(optimized_declarations, existing_names) + + if not new_declarations: + return original_source + + # Build a map of existing declaration names to their end lines (1-indexed) + existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations} + + # Insert each new declaration after its dependencies + result = original_source + for decl in new_declarations: + result = _insert_declaration_after_dependencies( + result, decl, existing_decl_end_lines, analyzer, module_abspath + ) + # Update the map with the newly inserted declaration for subsequent insertions + # Re-parse to get accurate line numbers after insertion + updated_declarations = analyzer.find_module_level_declarations(result) + existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations} + + return result + + except Exception as e: + logger.debug(f"Error adding global declarations: {e}") + return original_source + + +# Author: ali +def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]: + """Get all names that already exist in the original source (declarations + imports).""" + existing_names = {decl.name for decl in original_declarations} + + original_imports = analyzer.find_imports(original_source) + for imp in original_imports: + if imp.default_import: + existing_names.add(imp.default_import) + for name, alias in imp.named_imports: + existing_names.add(alias if alias else name) + if imp.namespace_import: + existing_names.add(imp.namespace_import) + + return existing_names + + +# Author: ali +def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list: + """Filter declarations to only those that don't exist in the original source.""" + new_declarations = [] + seen_sources: set[str] = set() + + # Sort by line number to maintain order from optimized code + sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line) + + for decl in sorted_declarations: + if decl.name not in existing_names and decl.source_code not in seen_sources: + new_declarations.append(decl) + seen_sources.add(decl.source_code) + + return new_declarations + + +# Author: ali +def _insert_declaration_after_dependencies( + source: str, + declaration, + existing_decl_end_lines: dict[str, int], + analyzer: TreeSitterAnalyzer, + module_abspath: Path, +) -> str: + """Insert a declaration after the last existing declaration it depends on. + + Args: + source: Current source code. + declaration: The declaration to insert. + existing_decl_end_lines: Map of existing declaration names to their end lines. + analyzer: TreeSitter analyzer. + module_abspath: Path to the module file. + + Returns: + Source code with the declaration inserted at the correct position. + + """ + # Find identifiers referenced in this declaration + referenced_names = analyzer.find_referenced_identifiers(declaration.source_code) + + # Find the latest end line among all referenced declarations + insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer) + + lines = source.splitlines(keepends=True) + + # Ensure proper spacing + decl_code = declaration.source_code + if not decl_code.endswith("\n"): + decl_code += "\n" + + # Add blank line before if inserting after content + if insertion_line > 0 and lines[insertion_line - 1].strip(): + decl_code = "\n" + decl_code + + before = lines[:insertion_line] + after = lines[insertion_line:] + + return "".join([*before, decl_code, *after]) + + +# Author: ali +def _find_insertion_line_for_declaration( + source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer +) -> int: + """Find the line where a declaration should be inserted based on its dependencies. + + Args: + source: Source code. + referenced_names: Names referenced by the declaration. + existing_decl_end_lines: Map of declaration names to their end lines (1-indexed). + analyzer: TreeSitter analyzer. + + Returns: + Line index (0-based) where the declaration should be inserted. + + """ + # Find the maximum end line among referenced declarations + max_dependency_line = 0 + for name in referenced_names: + if name in existing_decl_end_lines: + max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name]) + + if max_dependency_line > 0: + # Insert after the last dependency (end_line is 1-indexed, we need 0-indexed) + return max_dependency_line + + # No dependencies found - insert after imports + lines = source.splitlines(keepends=True) + return _find_line_after_imports(lines, analyzer, source) + + +# Author: ali +def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int: + """Find the line index after all imports. + + Args: + lines: Source lines. + analyzer: TreeSitter analyzer. + source: Full source code. + + Returns: + Line index (0-based) for insertion after imports. + + """ + try: + imports = analyzer.find_imports(source) + if imports: + return max(imp.end_line for imp in imports) + except Exception as exc: + logger.debug(f"Exception in _find_line_after_imports: {exc}") + + # Default: insert at beginning (after shebang/directive comments) + for i, line in enumerate(lines): + stripped = line.strip() + if stripped and not stripped.startswith("//") and not stripped.startswith("#!"): + return i + + return 0 diff --git a/codeflash/languages/javascript/edit_tests.py b/codeflash/languages/javascript/edit_tests.py index a4523e83b..00ba04f9c 100644 --- a/codeflash/languages/javascript/edit_tests.py +++ b/codeflash/languages/javascript/edit_tests.py @@ -6,10 +6,13 @@ from __future__ import annotations +import os import re +from pathlib import Path from codeflash.cli_cmds.console import logger from codeflash.code_utils.time_utils import format_perf, format_time +from codeflash.models.models import GeneratedTests, GeneratedTestsList from codeflash.result.critic import performance_gain @@ -130,6 +133,165 @@ def find_matching_test(test_description: str) -> str | None: return "\n".join(modified_lines) +JS_TEST_EXTENSIONS = ( + ".test.ts", + ".test.js", + ".test.tsx", + ".test.jsx", + ".spec.ts", + ".spec.js", + ".spec.tsx", + ".spec.jsx", + ".ts", + ".js", + ".tsx", + ".jsx", + ".mjs", + ".mts", +) + + +# TODO:{self} Needs cleanup for jest logic in else block +# Author: Sarthak Agarwal +def is_js_test_module_path(test_module_path: str) -> bool: + """Return True when the module path looks like a JS/TS test path.""" + return any(test_module_path.endswith(ext) for ext in JS_TEST_EXTENSIONS) + + +# Author: Sarthak Agarwal +def resolve_js_test_module_path(test_module_path: str, tests_project_rootdir: Path) -> Path: + """Resolve a JS/TS test module path to a concrete file path.""" + if "/" in test_module_path or "\\" in test_module_path: + return tests_project_rootdir / Path(test_module_path) + + matched_ext = None + for ext in JS_TEST_EXTENSIONS: + if test_module_path.endswith(ext): + matched_ext = ext + break + + if matched_ext: + base_path = test_module_path[: -len(matched_ext)] + file_path = base_path.replace(".", os.sep) + matched_ext + tests_dir_name = tests_project_rootdir.name + if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")): + return tests_project_rootdir.parent / Path(file_path) + return tests_project_rootdir / Path(file_path) + + return tests_project_rootdir / Path(test_module_path) + + +# Patterns for normalizing codeflash imports (legacy -> npm package) +# Author: Sarthak Agarwal +_CODEFLASH_REQUIRE_PATTERN = re.compile( + r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)" +) +_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]") + + +# Author: Sarthak Agarwal +def normalize_codeflash_imports(source: str) -> str: + """Normalize codeflash imports to use the npm package. + + Replaces legacy local file imports: + const codeflash = require('./codeflash-jest-helper') + import codeflash from './codeflash-jest-helper' + + With npm package imports: + const codeflash = require('codeflash') + + Args: + source: JavaScript/TypeScript source code. + + Returns: + Source code with normalized imports. + + """ + # Replace CommonJS require + source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source) + # Replace ES module import + return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source) + + +# Author: ali +def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList: + # TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window + """Inject test globals into all generated tests. + + Args: + generated_tests: List of generated tests. + test_framework: The test framework being used ("jest", "vitest", or "mocha"). + + Returns: + Generated tests with test globals injected. + + """ + # we only inject test globals for esm modules + # Use vitest imports for vitest projects, jest imports for jest projects + if test_framework == "vitest": + global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n" + else: + # Default to jest imports for jest and other frameworks + global_import = ( + "import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n" + ) + + for test in generated_tests.generated_tests: + test.generated_original_test_source = global_import + test.generated_original_test_source + test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source + test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source + return generated_tests + + +# Author: ali +def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList: + """Disable TypeScript type checking in all generated tests. + + Args: + generated_tests: List of generated tests. + + Returns: + Generated tests with TypeScript type checking disabled. + + """ + # we only inject test globals for esm modules + ts_nocheck = "// @ts-nocheck\n" + + for test in generated_tests.generated_tests: + test.generated_original_test_source = ts_nocheck + test.generated_original_test_source + test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source + test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source + return generated_tests + + +# Author: Sarthak Agarwal +def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList: + """Normalize codeflash imports in all generated tests. + + Args: + generated_tests: List of generated tests. + + Returns: + Generated tests with normalized imports. + + """ + normalized_tests = [] + for test in generated_tests.generated_tests: + # Only normalize JS/TS files + if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"): + normalized_test = GeneratedTests( + generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source), + instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source), + instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source), + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + normalized_tests.append(normalized_test) + else: + normalized_tests.append(test) + return GeneratedTestsList(generated_tests=normalized_tests) + + def remove_test_functions(source: str, functions_to_remove: list[str]) -> str: """Remove specific test functions from JavaScript test source code. diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index cde098cab..e0111c634 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -23,6 +23,7 @@ from codeflash.languages.base import ReferenceInfo from codeflash.languages.javascript.treesitter import TypeDefinition + from codeflash.models.models import GeneratedTestsList, InvocationId logger = logging.getLogger(__name__) @@ -1778,6 +1779,116 @@ def remove_test_functions(self, test_source: str, functions_to_remove: list[str] return remove_test_functions(test_source, functions_to_remove) + def postprocess_generated_tests( + self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path + ) -> GeneratedTestsList: + """Apply language-specific postprocessing to generated tests.""" + from codeflash.languages.javascript.edit_tests import ( + disable_ts_check, + inject_test_globals, + normalize_generated_tests_imports, + ) + from codeflash.languages.javascript.module_system import detect_module_system + + module_system = detect_module_system(project_root, source_file_path) + if module_system == "esm": + generated_tests = inject_test_globals(generated_tests, test_framework) + if self.language == Language.TYPESCRIPT: + generated_tests = disable_ts_check(generated_tests) + return normalize_generated_tests_imports(generated_tests) + + def remove_test_functions_from_generated_tests( + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] + ) -> GeneratedTestsList: + """Remove specific test functions from generated tests.""" + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + updated_tests: list[GeneratedTests] = [] + for test in generated_tests.generated_tests: + updated_tests.append( + GeneratedTests( + generated_original_test_source=self.remove_test_functions( + test.generated_original_test_source, functions_to_remove + ), + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + ) + return GeneratedTestsList(generated_tests=updated_tests) + + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, + ) -> GeneratedTestsList: + """Add runtime comments to generated tests.""" + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + tests_root = tests_project_rootdir or Path() + original_runtimes_dict = self._build_runtime_map(original_runtimes, tests_root) + optimized_runtimes_dict = self._build_runtime_map(optimized_runtimes, tests_root) + + modified_tests: list[GeneratedTests] = [] + for test in generated_tests.generated_tests: + modified_source = self.add_runtime_comments( + test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict + ) + modified_tests.append( + GeneratedTests( + generated_original_test_source=modified_source, + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + ) + return GeneratedTestsList(generated_tests=modified_tests) + + def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: + from codeflash.languages.javascript.code_replacer import _add_global_declarations_for_language + + return _add_global_declarations_for_language(optimized_code, original_source, module_abspath, self.language) + + def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None: + from codeflash.languages.javascript.treesitter import extract_calling_function_source + + return extract_calling_function_source(source_code, function_name, ref_line) + + def _build_runtime_map( + self, inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path + ) -> dict[str, int]: + from codeflash.languages.javascript.edit_tests import resolve_js_test_module_path + + unique_inv_ids: dict[str, int] = {} + for inv_id, runtimes in inv_id_runtimes.items(): + test_qualified_name = ( + inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator] + if inv_id.test_class_name + else inv_id.test_function_name + ) + if not test_qualified_name: + continue + abs_path = resolve_js_test_module_path(inv_id.test_module_path, tests_project_rootdir) + + abs_path_str = str(abs_path.resolve().with_suffix("")) + if "__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str: + continue + + key = test_qualified_name + "#" + abs_path_str + parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr] + cur_invid = ( + inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) + ) # type: ignore[union-attr] + match_key = key + "#" + cur_invid + if match_key not in unique_inv_ids: + unique_inv_ids[match_key] = 0 + unique_inv_ids[match_key] += min(runtimes) + return unique_inv_ids + # === Test Result Comparison === def compare_test_results( diff --git a/codeflash/languages/javascript/treesitter.py b/codeflash/languages/javascript/treesitter.py index 32d2431ac..c00cb228e 100644 --- a/codeflash/languages/javascript/treesitter.py +++ b/codeflash/languages/javascript/treesitter.py @@ -1788,3 +1788,39 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer: return TreeSitterAnalyzer(TreeSitterLanguage.TSX) # Default to JavaScript for .js, .jsx, .mjs, .cjs return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) + + +# Author: Saurabh Misra +def extract_calling_function_source(source_code: str, function_name: str, ref_line: int) -> str | None: + """Extract the source code of a calling function in JavaScript/TypeScript. + + Args: + source_code: Full source code of the file. + function_name: Name of the function to extract. + ref_line: Line number where the reference is (helps identify the right function). + + Returns: + Source code of the function, or None if not found. + + """ + try: + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage + + # Try TypeScript first, fall back to JavaScript + for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]: + try: + analyzer = TreeSitterAnalyzer(lang) + functions = analyzer.find_functions(source_code, include_methods=True) + + for func in functions: + if func.name == function_name: + # Check if the reference line is within this function + if func.start_line <= ref_line <= func.end_line: + return func.source_text + break + except Exception: + continue + + return None + except Exception: + return None diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 0e42022f6..0752f91b8 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -10,7 +10,6 @@ import libcst as cst from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 @@ -24,6 +23,10 @@ recurse_sections, remove_unused_definitions_by_function_names, ) +from codeflash.languages.python.static_analysis.code_extractor import ( + add_needed_imports_from_module, + find_preexisting_objects, +) from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index 3cc7c173a..e70dcad29 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -10,8 +10,8 @@ import libcst as cst from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_replacer import replace_function_definitions_in_module -from codeflash.languages import is_javascript +from codeflash.languages import is_python +from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module from codeflash.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: @@ -747,8 +747,8 @@ def detect_unused_helper_functions( """ # Skip this analysis for non-Python languages since we use Python's ast module - if is_javascript(): - logger.debug("Skipping unused helper function detection for JavaScript/TypeScript") + if not is_python(): + logger.debug("Skipping unused helper function detection for non-Python languages") return [] if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0: diff --git a/codeflash/languages/python/static_analysis/__init__.py b/codeflash/languages/python/static_analysis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py similarity index 96% rename from codeflash/code_utils/code_extractor.py rename to codeflash/languages/python/static_analysis/code_extractor.py index c4434c3ae..704f9e3db 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -1659,6 +1659,13 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro refs_by_file[ref.file_path] = [] refs_by_file[ref.file_path].append(ref) + from codeflash.languages.registry import get_language_support + + try: + lang_support = get_language_support(language) + except Exception: + lang_support = None + fn_call_context = "" context_len = 0 @@ -1700,7 +1707,11 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro # Extract context around the reference if ref.caller_function: # Try to extract the full calling function - func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language) + func_code = None + if lang_support is not None: + func_code = lang_support.extract_calling_function_source( + file_content, ref.caller_function, ref.line + ) if func_code: caller_contexts.append(func_code) context_len += len(func_code) @@ -1718,77 +1729,3 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro fn_call_context += "\n```\n" return fn_call_context - - -def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None: - """Extract the source code of a calling function. - - Args: - source_code: Full source code of the file. - function_name: Name of the function to extract. - ref_line: Line number where the reference is. - language: The programming language. - - Returns: - Source code of the function, or None if not found. - - """ - if language == Language.PYTHON: - return _extract_calling_function_python(source_code, function_name, ref_line) - return _extract_calling_function_js(source_code, function_name, ref_line) - - -def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None: - """Extract the source code of a calling function in Python.""" - try: - import ast - - tree = ast.parse(source_code) - lines = source_code.splitlines() - - for node in ast.walk(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - if node.name == function_name: - # Check if the reference line is within this function - start_line = node.lineno - end_line = node.end_lineno or start_line - if start_line <= ref_line <= end_line: - return "\n".join(lines[start_line - 1 : end_line]) - return None - except Exception: - return None - - -def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None: - """Extract the source code of a calling function in JavaScript/TypeScript. - - Args: - source_code: Full source code of the file. - function_name: Name of the function to extract. - ref_line: Line number where the reference is (helps identify the right function). - - Returns: - Source code of the function, or None if not found. - - """ - try: - from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage - - # Try TypeScript first, fall back to JavaScript - for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]: - try: - analyzer = TreeSitterAnalyzer(lang) - functions = analyzer.find_functions(source_code, include_methods=True) - - for func in functions: - if func.name == function_name: - # Check if the reference line is within this function - if func.start_line <= ref_line <= func.end_line: - return func.source_text - break - except Exception: - continue - - return None - except Exception: - return None diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/languages/python/static_analysis/code_replacer.py similarity index 78% rename from codeflash/code_utils/code_replacer.py rename to codeflash/languages/python/static_analysis/code_replacer.py index 3ad5eba2d..4e100a230 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/languages/python/static_analysis/code_replacer.py @@ -10,23 +10,22 @@ from libcst.metadata import PositionProvider from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import ( +from codeflash.code_utils.config_parser import find_conftest_files +from codeflash.code_utils.formatter import sort_imports +from codeflash.languages import is_python +from codeflash.languages.python.static_analysis.code_extractor import ( add_global_assignments, add_needed_imports_from_module, find_insertion_index_after_imports, ) -from codeflash.code_utils.config_parser import find_conftest_files -from codeflash.code_utils.formatter import sort_imports -from codeflash.code_utils.line_profile_utils import ImportAdder -from codeflash.languages import is_python +from codeflash.languages.python.static_analysis.line_profile_utils import ImportAdder from codeflash.models.models import FunctionParent if TYPE_CHECKING: from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.languages.base import Language, LanguageSupport - from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + from codeflash.languages.base import LanguageSupport from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST) @@ -401,23 +400,114 @@ def replace_functions_in_file( return source_code parsed_function_names.append((class_name, function_name)) - # Collect functions we want to modify from the optimized code - optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code)) + # Collect functions from optimized code without using MetadataWrapper + optimized_module = cst.parse_module(optimized_code) + modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {} + new_functions: list[cst.FunctionDef] = [] + new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) + new_classes: list[cst.ClassDef] = [] + modified_init_functions: dict[str, cst.FunctionDef] = {} + + function_names_set = set(parsed_function_names) + + for node in optimized_module.body: + if isinstance(node, cst.FunctionDef): + key = (None, node.name.value) + if key in function_names_set: + modified_functions[key] = node + elif preexisting_objects and (node.name.value, ()) not in preexisting_objects: + new_functions.append(node) + + elif isinstance(node, cst.ClassDef): + class_name = node.name.value + parents = (FunctionParent(name=class_name, type="ClassDef"),) + + if (class_name, ()) not in preexisting_objects: + new_classes.append(node) + + for child in node.body.body: + if isinstance(child, cst.FunctionDef): + method_key = (class_name, child.name.value) + if method_key in function_names_set: + modified_functions[method_key] = child + elif ( + child.name.value == "__init__" + and preexisting_objects + and (class_name, ()) in preexisting_objects + ): + modified_init_functions[class_name] = child + elif preexisting_objects and (child.name.value, parents) not in preexisting_objects: + new_class_functions[class_name].append(child) + original_module = cst.parse_module(source_code) - visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names)) - optimized_module.visit(visitor) + max_function_index = None + max_class_index = None + for index, _node in enumerate(original_module.body): + if isinstance(_node, cst.FunctionDef): + max_function_index = index + if isinstance(_node, cst.ClassDef): + max_class_index = index + + new_body: list[cst.CSTNode] = [] + existing_class_names = set() + + for node in original_module.body: + if isinstance(node, cst.FunctionDef): + key = (None, node.name.value) + if key in modified_functions: + modified_func = modified_functions[key] + new_body.append(node.with_changes(body=modified_func.body, decorators=modified_func.decorators)) + else: + new_body.append(node) + + elif isinstance(node, cst.ClassDef): + class_name = node.name.value + existing_class_names.add(class_name) + + new_members: list[cst.CSTNode] = [] + for child in node.body.body: + if isinstance(child, cst.FunctionDef): + key = (class_name, child.name.value) + if key in modified_functions: + modified_func = modified_functions[key] + new_members.append( + child.with_changes(body=modified_func.body, decorators=modified_func.decorators) + ) + elif child.name.value == "__init__" and class_name in modified_init_functions: + new_members.append(modified_init_functions[class_name]) + else: + new_members.append(child) + else: + new_members.append(child) + + if class_name in new_class_functions: + new_members.extend(new_class_functions[class_name]) + + new_body.append(node.with_changes(body=node.body.with_changes(body=new_members))) + else: + new_body.append(node) - # Replace these functions in the original code - transformer = OptimFunctionReplacer( - modified_functions=visitor.modified_functions, - new_classes=visitor.new_classes, - new_functions=visitor.new_functions, - new_class_functions=visitor.new_class_functions, - modified_init_functions=visitor.modified_init_functions, - ) - modified_tree = original_module.visit(transformer) - return modified_tree.code + if new_classes: + unique_classes = [nc for nc in new_classes if nc.name.value not in existing_class_names] + if unique_classes: + new_classes_insertion_idx = ( + max_class_index if max_class_index is not None else find_insertion_index_after_imports(original_module) + ) + new_body = list( + chain(new_body[:new_classes_insertion_idx], unique_classes, new_body[new_classes_insertion_idx:]) + ) + + if new_functions: + if max_function_index is not None: + new_body = [*new_body[: max_function_index + 1], *new_functions, *new_body[max_function_index + 1 :]] + elif max_class_index is not None: + new_body = [*new_body[: max_class_index + 1], *new_functions, *new_body[max_class_index + 1 :]] + else: + new_body = [*new_functions, *new_body] + + updated_module = original_module.with_changes(body=new_body) + return updated_module.code def replace_functions_and_add_imports( @@ -509,11 +599,8 @@ def replace_function_definitions_for_language( lang_support = get_language_support(language) # Add any new global declarations from the optimized code to the original source - original_source_code = _add_global_declarations_for_language( - optimized_code=code_to_apply, - original_source=original_source_code, - module_abspath=module_abspath, - language=language, + original_source_code = lang_support.add_global_declarations( + optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath ) # If we have function_to_optimize with line info and this is the main file, use it for precise replacement @@ -612,204 +699,6 @@ def _extract_function_from_code( return None -def _add_global_declarations_for_language( - optimized_code: str, original_source: str, module_abspath: Path, language: Language -) -> str: - """Add new global declarations from optimized code to original source. - - Finds module-level declarations (const, let, var, class, type, interface, enum) - in the optimized code that don't exist in the original source and adds them. - - New declarations are inserted after any existing declarations they depend on. - For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO` - is already declared in the original source, `_has` will be inserted after `FOO`. - - Args: - optimized_code: The optimized code that may contain new declarations. - original_source: The original source code. - module_abspath: Path to the module file (for parser selection). - language: The language of the code. - - Returns: - Original source with new declarations added in dependency order. - - """ - from codeflash.languages.base import Language - - if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT): - return original_source - - try: - from codeflash.languages.javascript.treesitter import get_analyzer_for_file - - analyzer = get_analyzer_for_file(module_abspath) - - original_declarations = analyzer.find_module_level_declarations(original_source) - optimized_declarations = analyzer.find_module_level_declarations(optimized_code) - - if not optimized_declarations: - return original_source - - existing_names = _get_existing_names(original_declarations, analyzer, original_source) - new_declarations = _filter_new_declarations(optimized_declarations, existing_names) - - if not new_declarations: - return original_source - - # Build a map of existing declaration names to their end lines (1-indexed) - existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations} - - # Insert each new declaration after its dependencies - result = original_source - for decl in new_declarations: - result = _insert_declaration_after_dependencies( - result, decl, existing_decl_end_lines, analyzer, module_abspath - ) - # Update the map with the newly inserted declaration for subsequent insertions - # Re-parse to get accurate line numbers after insertion - updated_declarations = analyzer.find_module_level_declarations(result) - existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations} - - return result - - except Exception as e: - logger.debug(f"Error adding global declarations: {e}") - return original_source - - -def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]: - """Get all names that already exist in the original source (declarations + imports).""" - existing_names = {decl.name for decl in original_declarations} - - original_imports = analyzer.find_imports(original_source) - for imp in original_imports: - if imp.default_import: - existing_names.add(imp.default_import) - for name, alias in imp.named_imports: - existing_names.add(alias if alias else name) - if imp.namespace_import: - existing_names.add(imp.namespace_import) - - return existing_names - - -def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list: - """Filter declarations to only those that don't exist in the original source.""" - new_declarations = [] - seen_sources: set[str] = set() - - # Sort by line number to maintain order from optimized code - sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line) - - for decl in sorted_declarations: - if decl.name not in existing_names and decl.source_code not in seen_sources: - new_declarations.append(decl) - seen_sources.add(decl.source_code) - - return new_declarations - - -def _insert_declaration_after_dependencies( - source: str, - declaration, - existing_decl_end_lines: dict[str, int], - analyzer: TreeSitterAnalyzer, - module_abspath: Path, -) -> str: - """Insert a declaration after the last existing declaration it depends on. - - Args: - source: Current source code. - declaration: The declaration to insert. - existing_decl_end_lines: Map of existing declaration names to their end lines. - analyzer: TreeSitter analyzer. - module_abspath: Path to the module file. - - Returns: - Source code with the declaration inserted at the correct position. - - """ - # Find identifiers referenced in this declaration - referenced_names = analyzer.find_referenced_identifiers(declaration.source_code) - - # Find the latest end line among all referenced declarations - insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer) - - lines = source.splitlines(keepends=True) - - # Ensure proper spacing - decl_code = declaration.source_code - if not decl_code.endswith("\n"): - decl_code += "\n" - - # Add blank line before if inserting after content - if insertion_line > 0 and lines[insertion_line - 1].strip(): - decl_code = "\n" + decl_code - - before = lines[:insertion_line] - after = lines[insertion_line:] - - return "".join([*before, decl_code, *after]) - - -def _find_insertion_line_for_declaration( - source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer -) -> int: - """Find the line where a declaration should be inserted based on its dependencies. - - Args: - source: Source code. - referenced_names: Names referenced by the declaration. - existing_decl_end_lines: Map of declaration names to their end lines (1-indexed). - analyzer: TreeSitter analyzer. - - Returns: - Line index (0-based) where the declaration should be inserted. - - """ - # Find the maximum end line among referenced declarations - max_dependency_line = 0 - for name in referenced_names: - if name in existing_decl_end_lines: - max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name]) - - if max_dependency_line > 0: - # Insert after the last dependency (end_line is 1-indexed, we need 0-indexed) - return max_dependency_line - - # No dependencies found - insert after imports - lines = source.splitlines(keepends=True) - return _find_line_after_imports(lines, analyzer, source) - - -def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int: - """Find the line index after all imports. - - Args: - lines: Source lines. - analyzer: TreeSitter analyzer. - source: Full source code. - - Returns: - Line index (0-based) for insertion after imports. - - """ - try: - imports = analyzer.find_imports(source) - if imports: - return max(imp.end_line for imp in imports) - except Exception as exc: - logger.debug(f"Exception in _find_line_after_imports: {exc}") - - # Default: insert at beginning (after shebang/directive comments) - for i, line in enumerate(lines): - stripped = line.strip() - if stripped and not stripped.startswith("//") and not stripped.startswith("#!"): - return i - - return 0 - - def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str: file_to_code_context = optimized_code.file_to_path() module_optimized_code = file_to_code_context.get(str(relative_path)) diff --git a/codeflash/code_utils/concolic_utils.py b/codeflash/languages/python/static_analysis/concolic_utils.py similarity index 100% rename from codeflash/code_utils/concolic_utils.py rename to codeflash/languages/python/static_analysis/concolic_utils.py diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/languages/python/static_analysis/coverage_utils.py similarity index 100% rename from codeflash/code_utils/coverage_utils.py rename to codeflash/languages/python/static_analysis/coverage_utils.py diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/languages/python/static_analysis/edit_generated_tests.py similarity index 61% rename from codeflash/code_utils/edit_generated_tests.py rename to codeflash/languages/python/static_analysis/edit_generated_tests.py index 7ec303b7a..c4aed07de 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/languages/python/static_analysis/edit_generated_tests.py @@ -12,7 +12,6 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.time_utils import format_perf, format_time -from codeflash.languages.registry import get_language_support from codeflash.models.models import GeneratedTests, GeneratedTestsList from codeflash.result.critic import performance_gain @@ -155,7 +154,6 @@ def _is_python_file(file_path: Path) -> bool: return file_path.suffix == ".py" -# TODO:{self} Needs cleanup for jest logic in else block def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]: unique_inv_ids: dict[str, int] = {} logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs") @@ -166,53 +164,11 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_ else inv_id.test_function_name ) - # Detect if test_module_path is a file path (like in js tests) or a Python module name - # File paths contain slashes, module names use dots test_module_path = inv_id.test_module_path if "/" in test_module_path or "\\" in test_module_path: - # Already a file path - use directly abs_path = tests_project_rootdir / Path(test_module_path) else: - # Check for Jest test file extensions (e.g., tests.fibonacci.test.ts) - # These need special handling to avoid converting .test.ts -> /test/ts - jest_test_extensions = ( - ".test.ts", - ".test.js", - ".test.tsx", - ".test.jsx", - ".spec.ts", - ".spec.js", - ".spec.tsx", - ".spec.jsx", - ".ts", - ".js", - ".tsx", - ".jsx", - ".mjs", - ".mts", - ) - matched_ext = None - for ext in jest_test_extensions: - if test_module_path.endswith(ext): - matched_ext = ext - break - - if matched_ext: - # JavaScript/TypeScript: convert module-style path to file path - # "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts" - base_path = test_module_path[: -len(matched_ext)] - file_path = base_path.replace(".", os.sep) + matched_ext - # Check if the module path includes the tests directory name - tests_dir_name = tests_project_rootdir.name - if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")): - # Module path includes "tests." - use parent directory - abs_path = tests_project_rootdir.parent / Path(file_path) - else: - # Module path doesn't include tests dir - use tests root directly - abs_path = tests_project_rootdir / Path(file_path) - else: - # Python module name - convert dots to path separators and add .py - abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py") + abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py") abs_path_str = str(abs_path.resolve().with_suffix("")) # Include both unit test and perf test paths for runtime annotations @@ -268,22 +224,7 @@ def add_runtime_comments_to_generated_tests( logger.debug(f"Failed to add runtime comments to test: {e}") modified_tests.append(test) else: - try: - language_support = get_language_support(test.behavior_file_path) - modified_source = language_support.add_runtime_comments( - test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict - ) - modified_test = GeneratedTests( - generated_original_test_source=modified_source, - instrumented_behavior_test_source=test.instrumented_behavior_test_source, - instrumented_perf_test_source=test.instrumented_perf_test_source, - behavior_file_path=test.behavior_file_path, - perf_file_path=test.perf_file_path, - ) - modified_tests.append(modified_test) - except Exception as e: - logger.debug(f"Failed to add runtime comments to test: {e}") - modified_tests.append(test) + modified_tests.append(test) return GeneratedTestsList(generated_tests=modified_tests) @@ -329,109 +270,3 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P ) for func in test_functions_to_remove ] - - -# Patterns for normalizing codeflash imports (legacy -> npm package) -_CODEFLASH_REQUIRE_PATTERN = re.compile( - r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)" -) -_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]") - - -def normalize_codeflash_imports(source: str) -> str: - """Normalize codeflash imports to use the npm package. - - Replaces legacy local file imports: - const codeflash = require('./codeflash-jest-helper') - import codeflash from './codeflash-jest-helper' - - With npm package imports: - const codeflash = require('codeflash') - - Args: - source: JavaScript/TypeScript source code. - - Returns: - Source code with normalized imports. - - """ - # Replace CommonJS require - source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source) - # Replace ES module import - return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source) - - -def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList: - # TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window - """Inject test globals into all generated tests. - - Args: - generated_tests: List of generated tests. - test_framework: The test framework being used ("jest", "vitest", or "mocha"). - - Returns: - Generated tests with test globals injected. - - """ - # we only inject test globals for esm modules - # Use vitest imports for vitest projects, jest imports for jest projects - if test_framework == "vitest": - global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n" - else: - # Default to jest imports for jest and other frameworks - global_import = ( - "import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n" - ) - - for test in generated_tests.generated_tests: - test.generated_original_test_source = global_import + test.generated_original_test_source - test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source - test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source - return generated_tests - - -def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList: - """Disable TypeScript type checking in all generated tests. - - Args: - generated_tests: List of generated tests. - - Returns: - Generated tests with TypeScript type checking disabled. - - """ - # we only inject test globals for esm modules - ts_nocheck = "// @ts-nocheck\n" - - for test in generated_tests.generated_tests: - test.generated_original_test_source = ts_nocheck + test.generated_original_test_source - test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source - test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source - return generated_tests - - -def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList: - """Normalize codeflash imports in all generated tests. - - Args: - generated_tests: List of generated tests. - - Returns: - Generated tests with normalized imports. - - """ - normalized_tests = [] - for test in generated_tests.generated_tests: - # Only normalize JS/TS files - if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"): - normalized_test = GeneratedTests( - generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source), - instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source), - instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source), - behavior_file_path=test.behavior_file_path, - perf_file_path=test.perf_file_path, - ) - normalized_tests.append(normalized_test) - else: - normalized_tests.append(test) - return GeneratedTestsList(generated_tests=normalized_tests) diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/languages/python/static_analysis/line_profile_utils.py similarity index 100% rename from codeflash/code_utils/line_profile_utils.py rename to codeflash/languages/python/static_analysis/line_profile_utils.py diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/languages/python/static_analysis/static_analysis.py similarity index 100% rename from codeflash/code_utils/static_analysis.py rename to codeflash/languages/python/static_analysis/static_analysis.py diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index b026e99e5..cf55e6f61 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -22,7 +22,7 @@ from collections.abc import Sequence from codeflash.languages.base import DependencyResolver - from codeflash.models.models import FunctionSource + from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId logger = logging.getLogger(__name__) @@ -379,7 +379,7 @@ def replace_function(self, source: str, function: FunctionToOptimize, new_source Modified source code with function replaced. """ - from codeflash.code_utils.code_replacer import replace_functions_in_file + from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file try: # Determine the function names to replace @@ -657,6 +657,59 @@ def leave_FunctionDef( except Exception: return test_source + def postprocess_generated_tests( + self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path + ) -> GeneratedTestsList: + """Apply language-specific postprocessing to generated tests.""" + _ = test_framework, project_root, source_file_path + return generated_tests + + def remove_test_functions_from_generated_tests( + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] + ) -> GeneratedTestsList: + """Remove specific test functions from generated tests.""" + from codeflash.languages.python.static_analysis.edit_generated_tests import ( + remove_functions_from_generated_tests, + ) + + return remove_functions_from_generated_tests(generated_tests, functions_to_remove) + + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, + ) -> GeneratedTestsList: + """Add runtime comments to generated tests.""" + from codeflash.languages.python.static_analysis.edit_generated_tests import ( + add_runtime_comments_to_generated_tests, + ) + + return add_runtime_comments_to_generated_tests( + generated_tests, original_runtimes, optimized_runtimes, tests_project_rootdir + ) + + def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: + _ = optimized_code, module_abspath + return original_source + + def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None: + """Extract the source code of a calling function in Python.""" + try: + import ast + + lines = source_code.splitlines() + tree = ast.parse(source_code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name: + end_line = node.end_lineno or node.lineno + if node.lineno <= ref_line <= end_line: + return "\n".join(lines[node.lineno - 1 : end_line]) + except Exception: + return None + return None + # === Test Result Comparison === def compare_test_results( diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 55dfa314e..dd8e41dd8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -26,12 +26,6 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code -from codeflash.code_utils.code_replacer import ( - add_custom_marker_to_all_tests, - modify_autouse_fixture, - replace_function_definitions_in_module, -) from codeflash.code_utils.code_utils import ( choose_weights, cleanup_paths, @@ -59,33 +53,31 @@ get_effort_value, ) from codeflash.code_utils.deduplicate_code import normalize_code -from codeflash.code_utils.edit_generated_tests import ( - add_runtime_comments_to_generated_tests, - disable_ts_check, - inject_test_globals, - normalize_generated_tests_imports, - remove_functions_from_generated_tests, -) from codeflash.code_utils.env_utils import get_pr_number from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports from codeflash.code_utils.git_utils import git_root_dir from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test -from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator from codeflash.code_utils.shell_utils import make_env_with_project_root -from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime from codeflash.discovery.functions_to_optimize import was_function_previously_optimized from codeflash.either import Failure, Success, is_successful from codeflash.languages import is_python from codeflash.languages.base import Language -from codeflash.languages.current import current_language_support, is_typescript -from codeflash.languages.javascript.module_system import detect_module_system +from codeflash.languages.current import current_language_support from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files from codeflash.languages.python.context import code_context_extractor from codeflash.languages.python.context.unused_definition_remover import ( detect_unused_helper_functions, revert_unused_helper_functions, ) +from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code +from codeflash.languages.python.static_analysis.code_replacer import ( + add_custom_marker_to_all_tests, + modify_autouse_fixture, + replace_function_definitions_in_module, +) +from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator +from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata @@ -596,16 +588,13 @@ def generate_and_instrument_tests( count_tests, generated_tests, function_to_concolic_tests, concolic_test_str = test_results.unwrap() - # Normalize codeflash imports in JS/TS tests to use npm package - if not is_python(): - module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path) - if module_system == "esm": - generated_tests = inject_test_globals(generated_tests, self.test_cfg.test_framework) - if is_typescript(): - # disable ts check for typescript tests - generated_tests = disable_ts_check(generated_tests) - - generated_tests = normalize_generated_tests_imports(generated_tests) + # Language-specific postprocessing for generated tests + generated_tests = self.language_support.postprocess_generated_tests( + generated_tests, + test_framework=self.test_cfg.test_framework, + project_root=self.project_root, + source_file_path=self.function_to_optimize.file_path, + ) logger.debug(f"[PIPELINE] Processing {count_tests} generated tests") for i, generated_test in enumerate(generated_tests.generated_tests): @@ -2067,8 +2056,8 @@ def process_review( else "Coverage data not available" ) - generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + generated_tests = self.language_support.remove_test_functions_from_generated_tests( + generated_tests, test_functions_to_remove ) map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests( test_functions_to_remove @@ -2079,7 +2068,7 @@ def process_review( best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() ) - generated_tests = add_runtime_comments_to_generated_tests( + generated_tests = self.language_support.add_runtime_comments_to_generated_tests( generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir ) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 3211ab59b..5527a0567 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -250,7 +250,9 @@ def create_function_optimizer( original_module_path: Path | None = None, call_graph: DependencyResolver | None = None, ) -> FunctionOptimizer | None: - from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast + from codeflash.languages.python.static_analysis.static_analysis import ( + get_first_top_level_function_or_method_ast, + ) from codeflash.optimization.function_optimizer import FunctionOptimizer if function_to_optimize_ast is None and original_module_ast is not None: @@ -293,8 +295,8 @@ def create_function_optimizer( def prepare_module_for_optimization( self, original_module_path: Path ) -> tuple[dict[Path, ValidCode], ast.Module | None] | None: - from codeflash.code_utils.code_replacer import normalize_code, normalize_node - from codeflash.code_utils.static_analysis import analyze_imported_modules + from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node + from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules logger.info(f"loading|Examining file {original_module_path!s}") console.rule() diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index b276725f2..cbde5399a 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -9,12 +9,12 @@ from codeflash.api import cfapi from codeflash.cli_cmds.console import console, logger from codeflash.code_utils import env_utils -from codeflash.code_utils.code_replacer import is_zero_diff from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name from codeflash.code_utils.github_utils import github_pr_url from codeflash.code_utils.tabulate import tabulate from codeflash.code_utils.time_utils import format_perf, format_time from codeflash.github.PrComment import FileDiffContent, PrComment +from codeflash.languages.python.static_analysis.code_replacer import is_zero_diff from codeflash.result.critic import performance_gain if TYPE_CHECKING: diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 05cad9f7a..8fa43de7e 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -10,11 +10,11 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE -from codeflash.code_utils.concolic_utils import clean_concolic_tests, is_valid_concolic_test from codeflash.code_utils.shell_utils import make_env_with_project_root -from codeflash.code_utils.static_analysis import has_typed_parameters from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.languages import is_python +from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests, is_valid_concolic_test +from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters from codeflash.lsp.helpers import is_LSP_enabled from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index 20535b9f7..08490914e 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -7,7 +7,7 @@ from coverage.exceptions import NoDataError from codeflash.cli_cmds.console import logger -from codeflash.code_utils.coverage_utils import ( +from codeflash.languages.python.static_analysis.coverage_utils import ( build_fully_qualified_name, extract_dependent_function, generate_candidates, diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index c0ecdc03d..a64bdd8e1 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -13,9 +13,9 @@ from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE -from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args from codeflash.languages import is_python +from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files from codeflash.languages.registry import get_language_support, get_language_support_by_framework from codeflash.models.models import TestFiles, TestType diff --git a/mypy_allowlist.txt b/mypy_allowlist.txt index e08b14e22..378d89675 100644 --- a/mypy_allowlist.txt +++ b/mypy_allowlist.txt @@ -28,8 +28,8 @@ codeflash/code_utils/__init__.py codeflash/code_utils/time_utils.py codeflash/code_utils/env_utils.py codeflash/code_utils/config_consts.py -codeflash/code_utils/static_analysis.py -codeflash/code_utils/edit_generated_tests.py +codeflash/languages/python/static_analysis/static_analysis.py +codeflash/languages/python/static_analysis/edit_generated_tests.py codeflash/cli_cmds/console_constants.py codeflash/cli_cmds/logging_config.py codeflash/cli_cmds/__init__.py diff --git a/tests/code_utils/test_concolic_utils.py b/tests/code_utils/test_concolic_utils.py index 672bad38a..3a873fde0 100644 --- a/tests/code_utils/test_concolic_utils.py +++ b/tests/code_utils/test_concolic_utils.py @@ -2,7 +2,7 @@ import pytest -from codeflash.code_utils.concolic_utils import AssertCleanup, is_valid_concolic_test +from codeflash.languages.python.static_analysis.concolic_utils import AssertCleanup, is_valid_concolic_test class TestFirstTopLevelArg: diff --git a/tests/code_utils/test_coverage_utils.py b/tests/code_utils/test_coverage_utils.py index d637bac5e..3ca28e898 100644 --- a/tests/code_utils/test_coverage_utils.py +++ b/tests/code_utils/test_coverage_utils.py @@ -2,7 +2,7 @@ from typing import Any -from codeflash.code_utils.coverage_utils import build_fully_qualified_name, extract_dependent_function +from codeflash.languages.python.static_analysis.coverage_utils import build_fully_qualified_name, extract_dependent_function from codeflash.models.function_types import FunctionParent from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown from codeflash.verification.coverage_utils import CoverageUtils diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index efb2a254c..03d62cdc8 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -3,13 +3,13 @@ import libcst as cst -from codeflash.code_utils.code_extractor import ( +from codeflash.languages.python.static_analysis.code_extractor import ( DottedImportCollector, add_needed_imports_from_module, find_preexisting_objects, resolve_star_import, ) -from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports from codeflash.models.models import FunctionParent @@ -22,7 +22,7 @@ def test_add_needed_imports_from_module0() -> None: import tiktoken from jedi.api.classes import Name from pydantic.dataclasses import dataclass -from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton +from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton from codeflash.code_utils.code_utils import path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize @@ -76,7 +76,7 @@ def test_add_needed_imports_from_module() -> None: from jedi.api.classes import Name from pydantic.dataclasses import dataclass -from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton +from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton from codeflash.code_utils.code_utils import path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py index c79e379ce..c70187aa5 100644 --- a/tests/test_add_runtime_comments.py +++ b/tests/test_add_runtime_comments.py @@ -4,7 +4,7 @@ import pytest -from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests +from codeflash.languages.python.static_analysis.edit_generated_tests import add_runtime_comments_to_generated_tests from codeflash.models.models import ( FunctionTestInvocation, GeneratedTests, diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 4dfddb4f7..2d87fbf24 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -8,8 +8,8 @@ import pytest -from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments -from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments +from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.python.context.code_context_extractor import ( collect_names_from_annotation, @@ -2870,7 +2870,7 @@ def test_global_function_collector(): """Test GlobalFunctionCollector correctly collects module-level function definitions.""" import libcst as cst - from codeflash.code_utils.code_extractor import GlobalFunctionCollector + from codeflash.languages.python.static_analysis.code_extractor import GlobalFunctionCollector source_code = """ # Module-level functions diff --git a/tests/test_code_extractor_none_aliases_exact.py b/tests/test_code_extractor_none_aliases_exact.py index e212de857..464680bd2 100644 --- a/tests/test_code_extractor_none_aliases_exact.py +++ b/tests/test_code_extractor_none_aliases_exact.py @@ -1,7 +1,7 @@ import tempfile from pathlib import Path -from codeflash.code_utils.code_extractor import add_needed_imports_from_module +from codeflash.languages.python.static_analysis.code_extractor import add_needed_imports_from_module def test_add_needed_imports_with_none_aliases(): diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index fbca6d71e..77d9108ab 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -7,8 +7,8 @@ import libcst as cst -from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects -from codeflash.code_utils.code_replacer import ( +from codeflash.languages.python.static_analysis.code_extractor import delete___future___aliased_imports, find_preexisting_objects +from codeflash.languages.python.static_analysis.code_replacer import ( AddRequestArgument, AutouseFixtureModifier, OptimFunctionCollector, diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 6844a16a1..3b794e59c 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -19,8 +19,8 @@ path_belongs_to_site_packages, validate_python_code, ) -from codeflash.code_utils.concolic_utils import clean_concolic_tests -from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files +from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests +from codeflash.languages.python.static_analysis.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files from codeflash.models.models import CodeStringsMarkdown from codeflash.verification.parse_test_output import resolve_test_file_from_class_path @@ -36,7 +36,7 @@ def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]: @pytest.fixture def mock_get_run_tmp_file() -> Generator[MagicMock, None, None]: - with patch("codeflash.code_utils.coverage_utils.get_run_tmp_file") as mock: + with patch("codeflash.languages.python.static_analysis.coverage_utils.get_run_tmp_file") as mock: yield mock diff --git a/tests/test_get_code.py b/tests/test_get_code.py index 50ac349cb..6f50ca44e 100644 --- a/tests/test_get_code.py +++ b/tests/test_get_code.py @@ -3,7 +3,7 @@ import pytest -from codeflash.code_utils.code_extractor import get_code +from codeflash.languages.python.static_analysis.code_extractor import get_code from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent diff --git a/tests/test_instrument_line_profiler.py b/tests/test_instrument_line_profiler.py index a355905e7..9b1716481 100644 --- a/tests/test_instrument_line_profiler.py +++ b/tests/test_instrument_line_profiler.py @@ -2,7 +2,7 @@ from pathlib import Path from tempfile import TemporaryDirectory -from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator +from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeOptimizationContext from codeflash.optimization.function_optimizer import FunctionOptimizer diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index a8cd75b70..1e2b6073e 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -15,7 +15,7 @@ FunctionImportedAsVisitor, inject_profiling_into_existing_test, ) -from codeflash.code_utils.line_profile_utils import add_decorator_imports +from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( CodeOptimizationContext, diff --git a/tests/test_is_numerical_code.py b/tests/test_is_numerical_code.py index 831d9c97e..a13a627ac 100644 --- a/tests/test_is_numerical_code.py +++ b/tests/test_is_numerical_code.py @@ -2,10 +2,10 @@ from unittest.mock import patch -from codeflash.code_utils.code_extractor import is_numerical_code +from codeflash.languages.python.static_analysis.code_extractor import is_numerical_code -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestBasicNumpyUsage: """Test basic numpy library detection (with numba available).""" @@ -50,7 +50,7 @@ def func(x): assert is_numerical_code(code, "func") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestNumpySubmodules: """Test numpy submodule imports (with numba available).""" @@ -265,7 +265,7 @@ def func(x): assert is_numerical_code(code, "func") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestScipyUsage: """Test SciPy library detection (with numba available).""" @@ -302,7 +302,7 @@ def func(f, x0): assert is_numerical_code(code, "func") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestMathUsage: """Test math standard library detection (with numba available).""" @@ -331,7 +331,7 @@ def calculate(x): assert is_numerical_code(code, "calculate") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestClassMethods: """Test detection in class methods, staticmethods, and classmethods (with numba available).""" @@ -472,7 +472,7 @@ def func(): assert is_numerical_code(code, "func") is False -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestEdgeCases: """Test edge cases and special scenarios (with numba available).""" @@ -535,7 +535,7 @@ async def async_process(x): assert is_numerical_code(code, "async_process") is False -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestStarImports: """Test handling of star imports (with numba available). @@ -575,7 +575,7 @@ def func(x): assert is_numerical_code(code, "func") is False -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestNestedUsage: """Test nested numerical library usage patterns (with numba available).""" @@ -618,7 +618,7 @@ def func(x): assert is_numerical_code(code, "func") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestMultipleLibraries: """Test code using multiple numerical libraries (with numba available).""" @@ -643,7 +643,7 @@ def analyze(data): assert is_numerical_code(code, "analyze") is True -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestQualifiedNames: """Test various qualified name patterns (with numba available).""" @@ -689,7 +689,7 @@ def method(self): assert is_numerical_code(code, "ClassB.method") is False -@patch("codeflash.code_utils.code_extractor.has_numba", True) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True) class TestEmptyFunctionName: """Test behavior when function_name is empty/None. @@ -807,7 +807,7 @@ def broken( assert is_numerical_code(code, "") is False -@patch("codeflash.code_utils.code_extractor.has_numba", False) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False) class TestEmptyFunctionNameWithoutNumba: """Test empty function_name behavior when numba is NOT available. @@ -886,7 +886,7 @@ def test_empty_string_math_and_scipy_returns_false_without_numba(self): assert is_numerical_code(code, "") is False -@patch("codeflash.code_utils.code_extractor.has_numba", False) +@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False) class TestNumbaNotAvailable: """Test behavior when numba is NOT available in the environment. diff --git a/tests/test_languages/test_find_references.py b/tests/test_languages/test_find_references.py index 537e3ef0b..88701d3d0 100644 --- a/tests/test_languages/test_find_references.py +++ b/tests/test_languages/test_find_references.py @@ -22,7 +22,7 @@ find_references, ) from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo -from codeflash.code_utils.code_extractor import _format_references_as_markdown +from codeflash.languages.python.static_analysis.code_extractor import _format_references_as_markdown from codeflash.models.models import FunctionParent diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index 9e251804a..5700c4bfd 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -14,7 +14,7 @@ import pytest -from codeflash.code_utils.code_replacer import replace_function_definitions_for_language +from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language from codeflash.languages.base import Language from codeflash.languages.current import set_current_language from codeflash.languages.javascript.module_system import ( diff --git a/tests/test_remove_functions_from_generated_tests.py b/tests/test_remove_functions_from_generated_tests.py index 9bb0b4c48..505f09a83 100644 --- a/tests/test_remove_functions_from_generated_tests.py +++ b/tests/test_remove_functions_from_generated_tests.py @@ -2,7 +2,7 @@ import pytest -from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests +from codeflash.languages.python.static_analysis.edit_generated_tests import remove_functions_from_generated_tests from codeflash.models.models import GeneratedTests, GeneratedTestsList diff --git a/tests/test_static_analysis.py b/tests/test_static_analysis.py index b997edeab..78790da69 100644 --- a/tests/test_static_analysis.py +++ b/tests/test_static_analysis.py @@ -1,7 +1,7 @@ import ast from pathlib import Path -from codeflash.code_utils.static_analysis import ( +from codeflash.languages.python.static_analysis.static_analysis import ( FunctionKind, ImportedInternalModuleAnalysis, analyze_imported_modules, @@ -23,10 +23,10 @@ def test_analyze_imported_modules() -> None: from typing import TYPE_CHECKING if TYPE_CHECKING: - from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis + from codeflash.languages.python.static_analysis.static_analysis import ImportedInternalModuleAnalysis def a_function(): - from codeflash.code_utils.static_analysis import analyze_imported_modules + from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules from returns.result import Failure, Success pass """ @@ -37,8 +37,8 @@ def a_function(): expected_imported_module_analysis = [ ImportedInternalModuleAnalysis( name="static_analysis", - full_name="codeflash.code_utils.static_analysis", - file_path=project_root / Path("codeflash/code_utils/static_analysis.py"), + full_name="codeflash.languages.python.static_analysis.static_analysis", + file_path=project_root / Path("codeflash/languages/python/static_analysis/static_analysis.py"), ), ImportedInternalModuleAnalysis( name="mymodule", full_name="tests.mymodule", file_path=project_root / Path("tests/mymodule.py") diff --git a/uv.lock b/uv.lock index ef5f11206..05b79c606 100644 --- a/uv.lock +++ b/uv.lock @@ -605,7 +605,7 @@ tests = [ [[package]] name = "codeflash-benchmark" -version = "0.2.0" +version = "0.3.0" source = { editable = "codeflash-benchmark" } dependencies = [ { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },