Skip to content
Open
12 changes: 12 additions & 0 deletions code_to_optimize/sample_code.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from functools import partial
from typing import Any

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import torch
from jax import lax
from torch import nn


class AlexNet(nn.Module):
    """Minimal stand-in model: a single 5 -> 10 linear layer.

    ``num_classes`` is stored but not otherwise used by the layer itself;
    extra positional/keyword arguments are forwarded to ``nn.Module``.
    """

    def __init__(self, num_classes=10, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.layer = nn.Linear(5, 10)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.layer(x)

def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
n = len(b)

Expand Down
63 changes: 63 additions & 0 deletions code_to_optimize/tests/pytest/test_alexnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch

from code_to_optimize.sample_code import AlexNet

def test_models():
    """AlexNet forward output is deterministic under a fixed seed."""
    torch.manual_seed(42)
    net = AlexNet(num_classes=10)
    batch = torch.randn(2, 5)
    expected = torch.Tensor([
        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
         0.3680166304, 0.3558489084],
        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
         0.2874411345, -0.4801278412],
    ])
    assert torch.allclose(net(batch), expected)

def test_models1():
    """Re-seeding reproduces the exact same forward output (duplicate of test_models)."""
    torch.manual_seed(42)
    model = AlexNet(num_classes=10)
    inputs = torch.randn(2, 5)
    want = torch.Tensor([
        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
         0.3680166304, 0.3558489084],
        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
         0.2874411345, -0.4801278412],
    ])
    got = model(inputs)
    assert torch.allclose(got, want)

def test_models2():
    """Seeded AlexNet output matches the recorded reference tensor (duplicate of test_models)."""
    torch.manual_seed(42)
    net = AlexNet(num_classes=10)
    x = torch.randn(2, 5)
    reference = torch.Tensor([
        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
         0.3680166304, 0.3558489084],
        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
         0.2874411345, -0.4801278412],
    ])
    assert torch.allclose(net(x), reference)

def test_models3():
    """Seeded AlexNet output check (duplicate of test_models)."""
    torch.manual_seed(42)
    module = AlexNet(num_classes=10)
    sample = torch.randn(2, 5)
    target = torch.Tensor([
        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
         0.3680166304, 0.3558489084],
        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
         0.2874411345, -0.4801278412],
    ])
    result = module(sample)
    assert torch.allclose(result, target)

def test_models4():
    """Seeded AlexNet output check (duplicate of test_models)."""
    torch.manual_seed(42)
    net = AlexNet(num_classes=10)
    data = torch.randn(2, 5)
    expected_rows = [
        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
         0.3680166304, 0.3558489084],
        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
         0.2874411345, -0.4801278412],
    ]
    assert torch.allclose(net(data), torch.Tensor(expected_rows))
61 changes: 59 additions & 2 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,53 @@ def __init__(
self.only_function_name = function.function_name
self.module_path = module_path
self.call_positions = call_positions
# Track instance variables when optimizing forward methods (PyTorch nn.Module pattern)
self.instance_variable_names: set[str] = set()
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name

def collect_instance_variables(self, func_node: ast.FunctionDef) -> None:
"""Collect variable names that are instances of the target class.

This handles the PyTorch nn.Module pattern where:
model = AlexNet(...)
model(input_data) # calls __call__ which invokes forward()

When optimizing ClassName.forward, we need to track variables assigned
from ClassName(...) so we can instrument calls to those variables.
"""
if self.class_name is None or self.only_function_name != "forward":
return

class_name = self.class_name
instance_vars = self.instance_variable_names

# Manually traverse only assignment nodes instead of walking entire tree
nodes_to_check = list(func_node.body)
while nodes_to_check:
node = nodes_to_check.pop()

# Look for assignments like: model = ClassName(...)
if isinstance(node, ast.Assign):
value = node.value
if isinstance(value, ast.Call):
func = value.func
if isinstance(func, ast.Name) and func.id == class_name:
for target in node.targets:
if isinstance(target, ast.Name):
instance_vars.add(target.id)

# Add nested statements to check
if hasattr(node, "body"):
nodes_to_check.extend(node.body)
if hasattr(node, "orelse"):
nodes_to_check.extend(node.orelse)
if hasattr(node, "finalbody"):
nodes_to_check.extend(node.finalbody)
if hasattr(node, "handlers"):
for handler in node.handlers:
nodes_to_check.extend(handler.body)

def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
Expand Down Expand Up @@ -122,7 +166,16 @@ def iter_ast_calls(node):
codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())

for node in iter_ast_calls(test_node):
if not node_in_call_position(node, self.call_positions):
# Check if this call is at a known position OR is an instance variable call
# for forward methods (PyTorch nn.Module pattern)
is_at_call_position = node_in_call_position(node, self.call_positions)
is_instance_call = (
isinstance(node.func, ast.Name)
and node.func.id in self.instance_variable_names
and self.only_function_name == "forward"
)

if not is_at_call_position and not is_instance_call:
continue

call_node = node
Expand All @@ -134,7 +187,8 @@ def iter_ast_calls(node):
function_name = node_func.id

# Check if this is the function we want to instrument
if function_name != fn_obj.function_name:
# Also match instance variable calls for forward methods
if function_name != fn_obj.function_name and function_name not in self.instance_variable_names:
continue

if fn_obj.is_async:
Expand Down Expand Up @@ -325,6 +379,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:

def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
# Collect instance variables for forward method instrumentation (PyTorch pattern)
self.collect_instance_variables(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.

Consider clearing the set at the start of each test function:

Suggested change
self.collect_instance_variables(node)
self.instance_variable_names.clear()
self.collect_instance_variables(node)


did_update = False
i = len(node.body) - 1
while i >= 0:
Expand Down
153 changes: 141 additions & 12 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,27 +265,22 @@ def visit_Import(self, node: ast.Import) -> None:

def visit_Assign(self, node: ast.Assign) -> None:
"""Track variable assignments, especially class instantiations."""
if self.found_any_target_function:
return

# Check if the assignment is a class instantiation
# Always track instance assignments, even if we've found a target function
# This is needed for the PyTorch nn.Module pattern where model(x) calls forward(x)
value = node.value
if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
class_name = value.func.id
if class_name in self.imported_modules:
# Map the variable to the actual class name (handling aliases)
original_class = self.alias_mapping.get(class_name, class_name)
# Use list comprehension for direct assignment to instance_mapping, reducing loop overhead
targets = node.targets
instance_mapping = self.instance_mapping
# since ast.Name nodes are heavily used, avoid local lookup for isinstance
# and reuse locals for faster attribute access
for target in targets:
if isinstance(target, ast.Name):
instance_mapping[target.id] = original_class
self.instance_mapping[target.id] = original_class

# Continue visiting child nodes
self.generic_visit(node)
# Continue visiting child nodes if we haven't found a target function yet
if not self.found_any_target_function:
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
Expand Down Expand Up @@ -405,7 +400,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
ast.NodeVisitor.generic_visit(self, node)

def visit_Call(self, node: ast.Call) -> None:
"""Handle function calls, particularly __import__."""
"""Handle function calls, particularly __import__ and instance calls for nn.Module.forward."""
if self.found_any_target_function:
return

Expand All @@ -415,6 +410,19 @@ def visit_Call(self, node: ast.Call) -> None:
# When __import__ is used, any target function could potentially be imported
# Be conservative and assume it might import target functions

# Check if this is a call on an instance variable (PyTorch nn.Module pattern)
# When model = AlexNet(...) and we call model(input_data), this invokes forward()
if isinstance(node.func, ast.Name):
instance_name = node.func.id
if instance_name in self.instance_mapping:
class_name = self.instance_mapping[instance_name]
# Check if ClassName.forward is in our target functions
roots_possible = self._dot_methods.get("forward")
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, "forward")]
return

self.generic_visit(node)

def visit_Name(self, node: ast.Name) -> None:
Expand Down Expand Up @@ -495,6 +503,68 @@ def _fast_generic_visit(self, node: ast.AST) -> None:
append((value._fields, value))


class InstanceMappingExtractor(ast.NodeVisitor):
    """Lightweight visitor that records instance-to-class mappings in a file.

    Needed to detect the PyTorch ``nn.Module.forward`` pattern, where
    ``model(x)`` dispatches to ``forward(x)``.
    """

    def __init__(self) -> None:
        # Local names brought into scope by import statements.
        self.imported_modules: set[str] = set()
        # Local alias -> original imported name.
        self.alias_mapping: dict[str, str] = {}
        # Variable name -> class it was instantiated from.
        self.instance_mapping: dict[str, str] = {}

    def visit_Import(self, node: ast.Import) -> None:
        """Record ``import mod`` / ``import mod as alias`` local names."""
        for alias in node.names:
            self.imported_modules.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record ``from mod import name [as alias]`` local names and aliases."""
        if not node.module:
            # Relative imports such as ``from . import x`` are not tracked.
            return
        for alias in node.names:
            if alias.name == "*":
                continue
            local_name = alias.asname or alias.name
            self.imported_modules.add(local_name)
            if alias.asname:
                self.alias_mapping[local_name] = alias.name
        self.generic_visit(node)
Comment on lines +526 to +533
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 363% (3.63x) speedup for InstanceMappingExtractor.visit_ImportFrom in codeflash/discovery/discover_unit_tests.py

⏱️ Runtime : 1.52 milliseconds 329 microseconds (best of 129 runs)

📝 Explanation and details

The optimized code achieves a 363% speedup (from 1.52ms to 329μs) by eliminating an expensive, unnecessary call to generic_visit() that was consuming 80% of the original runtime.

Key Optimization

The original code unconditionally called self.generic_visit(node) at the end of visit_ImportFrom(), which triggers a full AST traversal of all child nodes. However, ImportFrom nodes only contain alias children that were already processed in the explicit for alias in node.names loop above.

The optimized version:

  1. Checks if a subclass has overridden generic_visit() - If so, calls it to preserve custom traversal behavior
  2. Otherwise, skips the heavy traversal entirely and only manually visits alias children if a visit_alias() handler exists

This change is safe because:

  • The InstanceMappingExtractor class doesn't override generic_visit() or define visit_alias()
  • All alias processing is already done explicitly in the loop
  • The redundant traversal was performing no additional work beyond what the loop already accomplished

Performance Impact

The line profiler shows the generic_visit() call dropped from 15.6ms (80% of runtime) to essentially zero. Test results demonstrate consistent speedups across all scenarios:

  • Simple imports: 165-225% faster
  • Large-scale imports (1000+ aliases): 372-386% faster
  • The optimization is most impactful for files with many import statements

The few test cases showing minor slowdowns (2-8%) involve trivial early-return paths (e.g., module=None) where the new conditional checks add overhead that exceeds the already-minimal original cost. This is a reasonable trade-off given the dramatic improvements in all realistic use cases.

Context

Since InstanceMappingExtractor is used for AST analysis (extracting import mappings for PyTorch module detection), it likely processes many Python files with numerous imports. This optimization significantly reduces the cost of analyzing import-heavy codebases.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 55 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast

import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor

def test_single_import_without_alias_adds_module_name_to_imported_modules():
    # Create a visitor instance to test instance state changes.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node: from mypkg import MyClass
    node = ast.ImportFrom(module="mypkg", names=[ast.alias(name="MyClass", asname=None)], level=0)
    # Call the specific visitor method directly as required.
    extractor.visit_ImportFrom(node) # 4.57μs -> 1.48μs (208% faster)

def test_single_import_with_alias_creates_alias_mapping_and_imported_modules_entry():
    # Create a fresh visitor instance.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node with an alias: from pkg import RealName as AliasName
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="RealName", asname="AliasName")], level=0)
    # Invoke the method under test.
    extractor.visit_ImportFrom(node) # 4.64μs -> 1.75μs (165% faster)

def test_star_import_is_ignored_and_does_not_modify_mappings():
    extractor = InstanceMappingExtractor()
    # from pkg import *
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="*", asname=None)], level=0)
    extractor.visit_ImportFrom(node) # 4.20μs -> 1.29μs (225% faster)

def test_none_module_is_ignored_no_changes_made():
    extractor = InstanceMappingExtractor()
    # An ImportFrom with module set to None should be ignored.
    node = ast.ImportFrom(module=None, names=[ast.alias(name="Name", asname=None)], level=0)
    extractor.visit_ImportFrom(node) # 461ns -> 471ns (2.12% slower)

def test_empty_names_list_does_not_raise_and_makes_no_changes():
    extractor = InstanceMappingExtractor()
    # from pkg import  (no names) -> names list empty
    node = ast.ImportFrom(module="pkg", names=[], level=0)
    # Should simply return without error and without modifications.
    extractor.visit_ImportFrom(node) # 2.25μs -> 1.24μs (81.6% faster)

def test_empty_asname_string_treated_as_no_alias():
    extractor = InstanceMappingExtractor()
    # asname is empty string: from pkg import Name as ""
    # Empty string is falsy; the implementation checks `if alias.asname:`
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="Name", asname="")], level=0)
    extractor.visit_ImportFrom(node) # 4.54μs -> 1.63μs (178% faster)

def test_special_characters_in_names_and_aliases_handled_correctly():
    extractor = InstanceMappingExtractor()
    # Use names with underscores and digits and unusual but valid identifier-like strings.
    aliases = [
        ast.alias(name="Cls_1", asname=None),
        ast.alias(name="RealName2", asname="alias_2"),
        ast.alias(name="ÎnvalidUnicode", asname="alias_unicode"),  # unicode inside Python identifier context
    ]
    node = ast.ImportFrom(module="some_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node) # 7.10μs -> 2.33μs (204% faster)

def test_duplicate_aliases_do_not_raise_and_result_in_set_behavior_for_imported_modules():
    extractor = InstanceMappingExtractor()
    # Two identical asnames and real names repeated; set semantics means only one stored.
    aliases = [
        ast.alias(name="A", asname="X"),
        ast.alias(name="A", asname="X"),
        ast.alias(name="B", asname="Y"),
    ]
    node = ast.ImportFrom(module="dup_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node) # 7.22μs -> 2.40μs (202% faster)

def test_large_scale_many_aliases_performance_and_correctness():
    extractor = InstanceMappingExtractor()
    # Construct 1000 aliases where every even one has an asname, odds do not.
    num = 1000
    names = []
    for i in range(num):
        real = f"Real{i}"
        # Give half of them an alias, half not.
        if i % 2 == 0:
            # alias name will be Alias{i}
            names.append(ast.alias(name=real, asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=real, asname=None))
    node = ast.ImportFrom(module="big_pkg", names=names, level=0)
    # Ensure this runs quickly and does not raise.
    extractor.visit_ImportFrom(node) # 986μs -> 208μs (372% faster)
    # Validate sizes: imported_modules should contain the aliased names (num/2) plus the non-aliased original names (num/2).
    expected_imported = set()
    for i in range(num):
        if i % 2 == 0:
            expected_imported.add(f"Alias{i}")
        else:
            expected_imported.add(f"Real{i}")
    # alias_mapping should contain only the even indices mapping alias->real.
    expected_alias_map = {f"Alias{i}": f"Real{i}" for i in range(0, num, 2)}

def test_mixed_star_and_regular_imports_large_scale():
    extractor = InstanceMappingExtractor()
    # Build a large list mixing star imports, aliased imports and non-aliased imports.
    names = []
    num = 500
    for i in range(num):
        if i % 10 == 0:
            # include a star import occasionally
            names.append(ast.alias(name="*", asname=None))
        elif i % 3 == 0:
            names.append(ast.alias(name=f"Real{i}", asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=f"Real{i}", asname=None))
    node = ast.ImportFrom(module="mix_pkg", names=names, level=0)
    extractor.visit_ImportFrom(node) # 480μs -> 98.9μs (386% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

# imports
import pytest
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor

def test_visit_import_from_single_import():
    """Test visiting a simple import statement with a single name."""
    code = "from module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_with_alias():
    """Test visiting an import statement with an alias."""
    code = "from module import name as alias_name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_multiple_names():
    """Test visiting an import statement with multiple names."""
    code = "from module import name1, name2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_mixed_aliases():
    """Test visiting an import statement with both aliased and non-aliased names."""
    code = "from module import name1, name2 as alias2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_star_import():
    """Test that star imports are skipped and not added to imported_modules."""
    code = "from module import *"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_no_module():
    """Test that ImportFrom with no module attribute is handled gracefully."""
    # Create an ImportFrom node with module=None (relative import)
    node = ast.ImportFrom(module=None, names=[ast.alias(name="name", asname=None)], level=1)
    extractor = InstanceMappingExtractor()
    
    # Should return None without error
    codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output # 571ns -> 602ns (5.15% slower)

def test_visit_import_from_torch_nn_module():
    """Test a realistic PyTorch import scenario."""
    code = "from torch import nn"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_multiple_imports_same_module():
    """Test visiting multiple ImportFrom statements from different modules."""
    code = """from module1 import name1
from module2 import name2 as alias2
from module3 import name3"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_empty_module_name():
    """Test ImportFrom with an empty string as module name."""
    # Create an ImportFrom node with empty module
    node = ast.ImportFrom(module="", names=[ast.alias(name="name", asname=None)], level=0)
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 561ns -> 611ns (8.18% slower)

def test_visit_import_from_no_names():
    """Test ImportFrom with empty names list."""
    node = ast.ImportFrom(module="module", names=[], level=0)
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 2.52μs -> 1.23μs (104% faster)

def test_visit_import_from_special_characters_in_name():
    """Test import names with underscores and numbers."""
    code = "from module import _private, __dunder__, name123"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_long_module_path():
    """Test ImportFrom with a deeply nested module path."""
    code = "from package.subpackage.module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_relative_import_with_level():
    """Test relative imports with various levels."""
    code = "from . import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_star_with_other_names():
    """Test that when star is mixed with other names, only star is skipped."""
    # Create a node with star and another name
    node = ast.ImportFrom(
        module="module",
        names=[ast.alias(name="*", asname=None), ast.alias(name="name", asname=None)],
        level=0
    )
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 5.86μs -> 1.92μs (205% faster)

def test_visit_import_from_alias_overwrites_same_name():
    """Test that importing the same name twice with different aliases uses the latest."""
    code = """from module import name as alias1
from module import name as alias2"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_single_char_names():
    """Test importing single character names."""
    code = "from module import a, b, c"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_very_long_names():
    """Test importing very long names."""
    long_name = "a" * 100
    code = f"from module import {long_name}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_alias_same_as_original():
    """Test when alias is the same as the original name."""
    code = "from module import name as name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_preserves_state():
    """Test that the extractor preserves state across multiple visits."""
    code1 = "from module1 import name1"
    code2 = "from module2 import name2"
    
    tree1 = ast.parse(code1)
    tree2 = ast.parse(code2)
    
    extractor = InstanceMappingExtractor()
    extractor.visit(tree1)
    extractor.visit(tree2)

def test_visit_import_from_many_imports_single_statement():
    """Test importing many names from a single module."""
    # Generate a large import statement
    names = [f"name{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_many_imports_multiple_statements():
    """Test processing many separate ImportFrom statements."""
    code_lines = [f"from module{i} import name{i}" for i in range(500)]
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_many_aliases():
    """Test importing many names with aliases."""
    names = [f"name{i} as alias{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_complex_mixed_imports():
    """Test complex scenario with many mixed imports."""
    code_lines = []
    for i in range(250):
        if i % 3 == 0:
            # Non-aliased import
            code_lines.append(f"from module{i} import name{i}")
        elif i % 3 == 1:
            # Aliased import
            code_lines.append(f"from module{i} import name{i} as alias{i}")
        else:
            # Multiple imports
            code_lines.append(f"from module{i} import name{i}a, name{i}b, name{i}c")
    
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_deeply_nested_code():
    """Test ImportFrom statements within nested code structures."""
    code = """
def func():
    if True:
        from module1 import name1
        for i in range(10):
            from module2 import name2
            try:
                from module3 import name3
            except:
                pass
    from module4 import name4
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_instance_mapping_remains_empty():
    """Test that instance_mapping remains empty after visiting ImportFrom nodes."""
    code = "from module import name1, name2 as alias2"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_set_behavior_duplicates():
    """Test that imported_modules set prevents duplicates."""
    # Manually create scenario where same name could be added twice
    node1 = ast.ImportFrom(module="module1", names=[ast.alias(name="name", asname=None)], level=0)
    node2 = ast.ImportFrom(module="module2", names=[ast.alias(name="name", asname=None)], level=0)
    
    extractor = InstanceMappingExtractor()
    extractor.visit_ImportFrom(node1) # 4.74μs -> 1.77μs (167% faster)
    initial_size = len(extractor.imported_modules)
    extractor.visit_ImportFrom(node2) # 2.77μs -> 892ns (211% faster)
    final_size = len(extractor.imported_modules)

def test_visit_import_from_generic_visit_called():
    """Test that generic_visit is properly called for child traversal."""
    # Create a nested AST structure that would require generic_visit
    code = """
from module import (
    name1,
    name2,
    name3
)
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_return_value_is_none():
    """Test that visit_ImportFrom returns None (standard NodeVisitor behavior)."""
    code = "from module import name"
    tree = ast.parse(code)
    
    # Get the ImportFrom node
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            extractor = InstanceMappingExtractor()
            codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output
            break
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1418-2026-02-17T13.07.22

Click to see suggested changes
Suggested change
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
if alias.asname:
self.alias_mapping[imported_name] = alias.name
self.generic_visit(node)
# Preserve original behavior: if a subclass has overridden generic_visit,
# call it to allow that custom traversal to run. Otherwise, avoid the
# heavy generic traversal and only visit alias children if a specific
# visit_alias handler exists.
if getattr(type(self), "generic_visit", ast.NodeVisitor.generic_visit) is not ast.NodeVisitor.generic_visit:
self.generic_visit(node)
return
has_visit_alias = hasattr(self, "visit_alias")
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
if alias.asname:
self.alias_mapping[imported_name] = alias.name
if has_visit_alias:
# If a visitor for alias nodes exists, invoke it to match
# the behavior of the default generic_visit which would have
# dispatched to visit_alias.
self.visit(alias)

Static Badge


def visit_Assign(self, node: ast.Assign) -> None:
    """Record ``var = ImportedClass(...)`` as var -> original class name."""
    rhs = node.value
    if isinstance(rhs, ast.Call) and isinstance(rhs.func, ast.Name):
        called = rhs.func.id
        if called in self.imported_modules:
            # Resolve a local alias back to the real class name.
            resolved = self.alias_mapping.get(called, called)
            for tgt in node.targets:
                if isinstance(tgt, ast.Name):
                    self.instance_mapping[tgt.id] = resolved
    self.generic_visit(node)


def extract_instance_mapping(test_file_path: Path) -> dict[str, str]:
    """Extract instance-to-class mappings from a test file.

    Used for the PyTorch nn.Module pattern: when ``model = AlexNet(...)``
    appears in a test, calling ``model(x)`` dispatches to ``AlexNet.forward``.

    Args:
        test_file_path: Path to the test file.

    Returns:
        Dictionary mapping instance variable names to class names; empty when
        the file cannot be read or parsed.

    """
    try:
        source_code = test_file_path.read_text(encoding="utf-8")
        tree = ast.parse(source_code, filename=str(test_file_path))
    # OSError covers FileNotFoundError plus PermissionError/IsADirectoryError;
    # UnicodeDecodeError handles non-UTF-8 files; ValueError handles source
    # containing null bytes (raised by ast.parse). The original narrower
    # except clause let those crash test discovery.
    except (OSError, SyntaxError, UnicodeDecodeError, ValueError):
        return {}
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    return extractor.instance_mapping


def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
"""Analyze a test file to see if it imports any of the target functions."""
try:
Expand Down Expand Up @@ -879,6 +949,10 @@ def process_test_files(
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}

# Get instance-to-class mappings for PyTorch nn.Module.forward detection
# When model = AlexNet(...) and model(x) is called, it invokes forward(x)
instance_to_class_mapping = extract_instance_mapping(test_file) if functions_to_optimize else {}

except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
progress.advance(task_id)
Expand Down Expand Up @@ -1017,6 +1091,61 @@ def process_test_files(
num_discovered_replay_tests += 1

num_discovered_tests += 1

# Also check for PyTorch nn.Module pattern: model(x) -> forward(x)
# When an instance variable is called, it invokes the forward method
if name.name in instance_to_class_mapping:
class_name = instance_to_class_mapping[name.name]
for func_to_opt in functions_to_optimize:
# Check if the target is ClassName.forward
if (
func_to_opt.function_name == "forward"
and func_to_opt.top_level_parent_name == class_name
):
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
project_root_path
)

for test_func in test_functions_by_name[scope]:
if test_func.parameters is not None:
if test_framework == "pytest":
scope_test_function = (
f"{test_func.function_name}[{test_func.parameters}]"
)
else: # unittest
scope_test_function = (
f"{test_func.function_name}_{test_func.parameters}"
)
else:
scope_test_function = test_func.function_name

function_to_test_map[qualified_name_with_modules].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=test_func.test_class,
test_function=scope_test_function,
test_type=test_func.test_type,
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name_with_modules,
function_name=scope,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=name.line,
col_number=name.column,
)

if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1

num_discovered_tests += 1
continue
definition_obj = definition[0]
definition_path = str(definition_obj.module_path)
Expand Down
Loading
Loading