diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 00b077f63..3547623ae 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -638,7 +638,23 @@ def _analyze_imports_in_optimized_code( helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper) helpers_by_file[module_name].append(helper) - for node in ast.walk(optimized_ast): + # Collect only import nodes to avoid per-node isinstance checks across the whole AST + class _ImportCollector(ast.NodeVisitor): + def __init__(self) -> None: + self.nodes: list[ast.AST] = [] + + def visit_Import(self, node: ast.Import) -> None: + self.nodes.append(node) + # No need to recurse further for import nodes + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + self.nodes.append(node) + # No need to recurse further for import-from nodes + + collector = _ImportCollector() + collector.visit(optimized_ast) + + for node in collector.nodes: if isinstance(node, ast.ImportFrom): # Handle "from module import function" statements module_name = node.module