-
Notifications
You must be signed in to change notification settings - Fork 21
fix: instrument PyTorch nn.Module forward method calls via instance #1418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
39c8610
bb932ab
e6717b2
69295b0
db95204
5292d67
6e3d6e7
ee6f901
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| import torch | ||
|
|
||
| from code_to_optimize.sample_code import AlexNet | ||
|
|
||
| def test_models(): | ||
| torch.manual_seed(42) | ||
| model = AlexNet(num_classes=10) | ||
| input_data = torch.randn(2,5) | ||
| assert torch.allclose(model(input_data), torch.Tensor([ | ||
| [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381, | ||
| 0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071, | ||
| 0.3680166304, 0.3558489084], | ||
| [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507, | ||
| -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662, | ||
| 0.2874411345, -0.4801278412]])) | ||
|
|
||
| def test_models1(): | ||
| torch.manual_seed(42) | ||
| model = AlexNet(num_classes=10) | ||
| input_data = torch.randn(2,5) | ||
| assert torch.allclose(model(input_data), torch.Tensor([ | ||
| [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381, | ||
| 0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071, | ||
| 0.3680166304, 0.3558489084], | ||
| [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507, | ||
| -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662, | ||
| 0.2874411345, -0.4801278412]])) | ||
|
|
||
| def test_models2(): | ||
| torch.manual_seed(42) | ||
| model = AlexNet(num_classes=10) | ||
| input_data = torch.randn(2,5) | ||
| assert torch.allclose(model(input_data), torch.Tensor([ | ||
| [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381, | ||
| 0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071, | ||
| 0.3680166304, 0.3558489084], | ||
| [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507, | ||
| -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662, | ||
| 0.2874411345, -0.4801278412]])) | ||
|
|
||
| def test_models3(): | ||
| torch.manual_seed(42) | ||
| model = AlexNet(num_classes=10) | ||
| input_data = torch.randn(2,5) | ||
| assert torch.allclose(model(input_data), torch.Tensor([ | ||
| [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381, | ||
| 0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071, | ||
| 0.3680166304, 0.3558489084], | ||
| [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507, | ||
| -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662, | ||
| 0.2874411345, -0.4801278412]])) | ||
|
|
||
| def test_models4(): | ||
| torch.manual_seed(42) | ||
| model = AlexNet(num_classes=10) | ||
| input_data = torch.randn(2,5) | ||
| assert torch.allclose(model(input_data), torch.Tensor([ | ||
| [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381, | ||
| 0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071, | ||
| 0.3680166304, 0.3558489084], | ||
| [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507, | ||
| -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662, | ||
| 0.2874411345, -0.4801278412]])) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -265,27 +265,22 @@ def visit_Import(self, node: ast.Import) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_Assign(self, node: ast.Assign) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Track variable assignments, especially class instantiations.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.found_any_target_function: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if the assignment is a class instantiation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Always track instance assignments, even if we've found a target function | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # This is needed for the PyTorch nn.Module pattern where model(x) calls forward(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| value = node.value | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class_name = value.func.id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if class_name in self.imported_modules: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Map the variable to the actual class name (handling aliases) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_class = self.alias_mapping.get(class_name, class_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use list comprehension for direct assignment to instance_mapping, reducing loop overhead | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| targets = node.targets | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| instance_mapping = self.instance_mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # since ast.Name nodes are heavily used, avoid local lookup for isinstance | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # and reuse locals for faster attribute access | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for target in targets: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(target, ast.Name): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| instance_mapping[target.id] = original_class | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.instance_mapping[target.id] = original_class | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Continue visiting child nodes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.generic_visit(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Continue visiting child nodes if we haven't found a target function yet | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.found_any_target_function: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.generic_visit(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_ImportFrom(self, node: ast.ImportFrom) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Handle 'from module import name' statements.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -405,7 +400,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ast.NodeVisitor.generic_visit(self, node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_Call(self, node: ast.Call) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Handle function calls, particularly __import__.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Handle function calls, particularly __import__ and instance calls for nn.Module.forward.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.found_any_target_function: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -415,6 +410,19 @@ def visit_Call(self, node: ast.Call) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When __import__ is used, any target function could potentially be imported | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Be conservative and assume it might import target functions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if this is a call on an instance variable (PyTorch nn.Module pattern) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When model = AlexNet(...) and we call model(input_data), this invokes forward() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(node.func, ast.Name): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| instance_name = node.func.id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if instance_name in self.instance_mapping: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class_name = self.instance_mapping[instance_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if ClassName.forward is in our target functions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| roots_possible = self._dot_methods.get("forward") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if roots_possible and class_name in roots_possible: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.found_any_target_function = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.found_qualified_name = self._class_method_to_target[(class_name, "forward")] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.generic_visit(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_Name(self, node: ast.Name) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -495,6 +503,68 @@ def _fast_generic_visit(self, node: ast.AST) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| append((value._fields, value)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class InstanceMappingExtractor(ast.NodeVisitor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Simple visitor to extract instance-to-class mappings from a file. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This is needed for detecting PyTorch nn.Module.forward calls where model(x) calls forward(x). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.imported_modules: set[str] = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alias_mapping: dict[str, str] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.instance_mapping: dict[str, str] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_Import(self, node: ast.Import) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for alias in node.names: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| module_name = alias.asname if alias.asname else alias.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.imported_modules.add(module_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.generic_visit(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_ImportFrom(self, node: ast.ImportFrom) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not node.module: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for alias in node.names: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if alias.name == "*": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| imported_name = alias.asname if alias.asname else alias.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.imported_modules.add(imported_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if alias.asname: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alias_mapping[imported_name] = alias.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.generic_visit(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+526
to
+533
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ⚡️ Codeflash found 363% (3.63x) speedup for `InstanceMappingExtractor.visit_ImportFrom` in `codeflash/discovery/discover_unit_tests.py`
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 55 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
import ast
import pytest # used for our unit tests
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor
def test_single_import_without_alias_adds_module_name_to_imported_modules():
# Create a visitor instance to test instance state changes.
extractor = InstanceMappingExtractor()
# Create an ImportFrom node: from mypkg import MyClass
node = ast.ImportFrom(module="mypkg", names=[ast.alias(name="MyClass", asname=None)], level=0)
# Call the specific visitor method directly as required.
extractor.visit_ImportFrom(node) # 4.57μs -> 1.48μs (208% faster)
def test_single_import_with_alias_creates_alias_mapping_and_imported_modules_entry():
# Create a fresh visitor instance.
extractor = InstanceMappingExtractor()
# Create an ImportFrom node with an alias: from pkg import RealName as AliasName
node = ast.ImportFrom(module="pkg", names=[ast.alias(name="RealName", asname="AliasName")], level=0)
# Invoke the method under test.
extractor.visit_ImportFrom(node) # 4.64μs -> 1.75μs (165% faster)
def test_star_import_is_ignored_and_does_not_modify_mappings():
extractor = InstanceMappingExtractor()
# from pkg import *
node = ast.ImportFrom(module="pkg", names=[ast.alias(name="*", asname=None)], level=0)
extractor.visit_ImportFrom(node) # 4.20μs -> 1.29μs (225% faster)
def test_none_module_is_ignored_no_changes_made():
extractor = InstanceMappingExtractor()
# An ImportFrom with module set to None should be ignored.
node = ast.ImportFrom(module=None, names=[ast.alias(name="Name", asname=None)], level=0)
extractor.visit_ImportFrom(node) # 461ns -> 471ns (2.12% slower)
def test_empty_names_list_does_not_raise_and_makes_no_changes():
extractor = InstanceMappingExtractor()
# from pkg import (no names) -> names list empty
node = ast.ImportFrom(module="pkg", names=[], level=0)
# Should simply return without error and without modifications.
extractor.visit_ImportFrom(node) # 2.25μs -> 1.24μs (81.6% faster)
def test_empty_asname_string_treated_as_no_alias():
extractor = InstanceMappingExtractor()
# asname is empty string: from pkg import Name as ""
# Empty string is falsy; the implementation checks `if alias.asname:`
node = ast.ImportFrom(module="pkg", names=[ast.alias(name="Name", asname="")], level=0)
extractor.visit_ImportFrom(node) # 4.54μs -> 1.63μs (178% faster)
def test_special_characters_in_names_and_aliases_handled_correctly():
extractor = InstanceMappingExtractor()
# Use names with underscores and digits and unusual but valid identifier-like strings.
aliases = [
ast.alias(name="Cls_1", asname=None),
ast.alias(name="RealName2", asname="alias_2"),
ast.alias(name="ÎnvalidUnicode", asname="alias_unicode"), # unicode inside Python identifier context
]
node = ast.ImportFrom(module="some_pkg", names=aliases, level=0)
extractor.visit_ImportFrom(node) # 7.10μs -> 2.33μs (204% faster)
def test_duplicate_aliases_do_not_raise_and_result_in_set_behavior_for_imported_modules():
extractor = InstanceMappingExtractor()
# Two identical asnames and real names repeated; set semantics means only one stored.
aliases = [
ast.alias(name="A", asname="X"),
ast.alias(name="A", asname="X"),
ast.alias(name="B", asname="Y"),
]
node = ast.ImportFrom(module="dup_pkg", names=aliases, level=0)
extractor.visit_ImportFrom(node) # 7.22μs -> 2.40μs (202% faster)
def test_large_scale_many_aliases_performance_and_correctness():
extractor = InstanceMappingExtractor()
# Construct 1000 aliases where every even one has an asname, odds do not.
num = 1000
names = []
for i in range(num):
real = f"Real{i}"
# Give half of them an alias, half not.
if i % 2 == 0:
# alias name will be Alias{i}
names.append(ast.alias(name=real, asname=f"Alias{i}"))
else:
names.append(ast.alias(name=real, asname=None))
node = ast.ImportFrom(module="big_pkg", names=names, level=0)
# Ensure this runs quickly and does not raise.
extractor.visit_ImportFrom(node) # 986μs -> 208μs (372% faster)
# Validate sizes: imported_modules should contain the aliased names (num/2) plus the non-aliased original names (num/2).
expected_imported = set()
for i in range(num):
if i % 2 == 0:
expected_imported.add(f"Alias{i}")
else:
expected_imported.add(f"Real{i}")
# alias_mapping should contain only the even indices mapping alias->real.
expected_alias_map = {f"Alias{i}": f"Real{i}" for i in range(0, num, 2)}
def test_mixed_star_and_regular_imports_large_scale():
extractor = InstanceMappingExtractor()
# Build a large list mixing star imports, aliased imports and non-aliased imports.
names = []
num = 500
for i in range(num):
if i % 10 == 0:
# include a star import occasionally
names.append(ast.alias(name="*", asname=None))
elif i % 3 == 0:
names.append(ast.alias(name=f"Real{i}", asname=f"Alias{i}"))
else:
names.append(ast.alias(name=f"Real{i}", asname=None))
node = ast.ImportFrom(module="mix_pkg", names=names, level=0)
extractor.visit_ImportFrom(node) # 480μs -> 98.9μs (386% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast
# imports
import pytest
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor
def test_visit_import_from_single_import():
"""Test visiting a simple import statement with a single name."""
code = "from module import name"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_with_alias():
"""Test visiting an import statement with an alias."""
code = "from module import name as alias_name"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_multiple_names():
"""Test visiting an import statement with multiple names."""
code = "from module import name1, name2, name3"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_mixed_aliases():
"""Test visiting an import statement with both aliased and non-aliased names."""
code = "from module import name1, name2 as alias2, name3"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_star_import():
"""Test that star imports are skipped and not added to imported_modules."""
code = "from module import *"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_no_module():
"""Test that ImportFrom with no module attribute is handled gracefully."""
# Create an ImportFrom node with module=None (relative import)
node = ast.ImportFrom(module=None, names=[ast.alias(name="name", asname=None)], level=1)
extractor = InstanceMappingExtractor()
# Should return None without error
codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output # 571ns -> 602ns (5.15% slower)
def test_visit_import_from_torch_nn_module():
"""Test a realistic PyTorch import scenario."""
code = "from torch import nn"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_multiple_imports_same_module():
"""Test visiting multiple ImportFrom statements from different modules."""
code = """from module1 import name1
from module2 import name2 as alias2
from module3 import name3"""
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_empty_module_name():
"""Test ImportFrom with an empty string as module name."""
# Create an ImportFrom node with empty module
node = ast.ImportFrom(module="", names=[ast.alias(name="name", asname=None)], level=0)
extractor = InstanceMappingExtractor()
# Process the node
extractor.visit_ImportFrom(node) # 561ns -> 611ns (8.18% slower)
def test_visit_import_from_no_names():
"""Test ImportFrom with empty names list."""
node = ast.ImportFrom(module="module", names=[], level=0)
extractor = InstanceMappingExtractor()
# Process the node
extractor.visit_ImportFrom(node) # 2.52μs -> 1.23μs (104% faster)
def test_visit_import_from_special_characters_in_name():
"""Test import names with underscores and numbers."""
code = "from module import _private, __dunder__, name123"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_long_module_path():
"""Test ImportFrom with a deeply nested module path."""
code = "from package.subpackage.module import name"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_relative_import_with_level():
"""Test relative imports with various levels."""
code = "from . import name"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_star_with_other_names():
"""Test that when star is mixed with other names, only star is skipped."""
# Create a node with star and another name
node = ast.ImportFrom(
module="module",
names=[ast.alias(name="*", asname=None), ast.alias(name="name", asname=None)],
level=0
)
extractor = InstanceMappingExtractor()
# Process the node
extractor.visit_ImportFrom(node) # 5.86μs -> 1.92μs (205% faster)
def test_visit_import_from_alias_overwrites_same_name():
"""Test that importing the same name twice with different aliases uses the latest."""
code = """from module import name as alias1
from module import name as alias2"""
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_single_char_names():
"""Test importing single character names."""
code = "from module import a, b, c"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_very_long_names():
"""Test importing very long names."""
long_name = "a" * 100
code = f"from module import {long_name}"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_alias_same_as_original():
"""Test when alias is the same as the original name."""
code = "from module import name as name"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_preserves_state():
"""Test that the extractor preserves state across multiple visits."""
code1 = "from module1 import name1"
code2 = "from module2 import name2"
tree1 = ast.parse(code1)
tree2 = ast.parse(code2)
extractor = InstanceMappingExtractor()
extractor.visit(tree1)
extractor.visit(tree2)
def test_visit_import_from_many_imports_single_statement():
"""Test importing many names from a single module."""
# Generate a large import statement
names = [f"name{i}" for i in range(500)]
code = f"from module import {', '.join(names)}"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
for i in range(500):
pass
def test_visit_import_from_many_imports_multiple_statements():
"""Test processing many separate ImportFrom statements."""
code_lines = [f"from module{i} import name{i}" for i in range(500)]
code = "\n".join(code_lines)
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
for i in range(500):
pass
def test_visit_import_from_many_aliases():
"""Test importing many names with aliases."""
names = [f"name{i} as alias{i}" for i in range(500)]
code = f"from module import {', '.join(names)}"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
for i in range(500):
pass
def test_visit_import_from_complex_mixed_imports():
"""Test complex scenario with many mixed imports."""
code_lines = []
for i in range(250):
if i % 3 == 0:
# Non-aliased import
code_lines.append(f"from module{i} import name{i}")
elif i % 3 == 1:
# Aliased import
code_lines.append(f"from module{i} import name{i} as alias{i}")
else:
# Multiple imports
code_lines.append(f"from module{i} import name{i}a, name{i}b, name{i}c")
code = "\n".join(code_lines)
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_deeply_nested_code():
"""Test ImportFrom statements within nested code structures."""
code = """
def func():
if True:
from module1 import name1
for i in range(10):
from module2 import name2
try:
from module3 import name3
except:
pass
from module4 import name4
"""
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_instance_mapping_remains_empty():
"""Test that instance_mapping remains empty after visiting ImportFrom nodes."""
code = "from module import name1, name2 as alias2"
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_set_behavior_duplicates():
"""Test that imported_modules set prevents duplicates."""
# Manually create scenario where same name could be added twice
node1 = ast.ImportFrom(module="module1", names=[ast.alias(name="name", asname=None)], level=0)
node2 = ast.ImportFrom(module="module2", names=[ast.alias(name="name", asname=None)], level=0)
extractor = InstanceMappingExtractor()
extractor.visit_ImportFrom(node1) # 4.74μs -> 1.77μs (167% faster)
initial_size = len(extractor.imported_modules)
extractor.visit_ImportFrom(node2) # 2.77μs -> 892ns (211% faster)
final_size = len(extractor.imported_modules)
def test_visit_import_from_generic_visit_called():
"""Test that generic_visit is properly called for child traversal."""
# Create a nested AST structure that would require generic_visit
code = """
from module import (
name1,
name2,
name3
)
"""
tree = ast.parse(code)
extractor = InstanceMappingExtractor()
extractor.visit(tree)
def test_visit_import_from_return_value_is_none():
"""Test that visit_ImportFrom returns None (standard NodeVisitor behavior)."""
code = "from module import name"
tree = ast.parse(code)
# Get the ImportFrom node
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom):
extractor = InstanceMappingExtractor()
codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output
break
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To test or edit this optimization locally, run `git merge codeflash/optimize-pr1418-2026-02-17T13.07.22`
Click to see suggested changes
| for alias in node.names: | |
| if alias.name == "*": | |
| continue | |
| imported_name = alias.asname if alias.asname else alias.name | |
| self.imported_modules.add(imported_name) | |
| if alias.asname: | |
| self.alias_mapping[imported_name] = alias.name | |
| self.generic_visit(node) | |
| # Preserve original behavior: if a subclass has overridden generic_visit, | |
| # call it to allow that custom traversal to run. Otherwise, avoid the | |
| # heavy generic traversal and only visit alias children if a specific | |
| # visit_alias handler exists. | |
| if getattr(type(self), "generic_visit", ast.NodeVisitor.generic_visit) is not ast.NodeVisitor.generic_visit: | |
| self.generic_visit(node) | |
| return | |
| has_visit_alias = hasattr(self, "visit_alias") | |
| for alias in node.names: | |
| if alias.name == "*": | |
| continue | |
| imported_name = alias.asname if alias.asname else alias.name | |
| self.imported_modules.add(imported_name) | |
| if alias.asname: | |
| self.alias_mapping[imported_name] = alias.name | |
| if has_visit_alias: | |
| # If a visitor for alias nodes exists, invoke it to match | |
| # the behavior of the default generic_visit which would have | |
| # dispatched to visit_alias. | |
| self.visit(alias) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit (non-blocking):
`instance_variable_names` is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from `test_a` will persist when processing `test_b`. This could cause false-positive instrumentation if a variable name from one test happens to be called in another. Consider clearing the set at the start of each test function: