|
3 | 3 | import ast |
4 | 4 | import hashlib |
5 | 5 | import os |
6 | | -from collections import defaultdict |
| 6 | +from collections import defaultdict, deque |
7 | 7 | from itertools import chain |
8 | 8 | from typing import TYPE_CHECKING |
9 | 9 |
|
@@ -746,33 +746,49 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]: |
746 | 746 |
|
747 | 747 | def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None: |
748 | 748 | class_node = None |
749 | | - for node in ast.walk(module_tree): |
750 | | - if isinstance(node, ast.ClassDef) and node.name == class_name: |
751 | | - class_node = node |
| 749 | + # Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order) |
| 750 | + q: deque[ast.AST] = deque([module_tree]) |
| 751 | + while q: |
| 752 | + candidate = q.popleft() |
| 753 | + if isinstance(candidate, ast.ClassDef) and candidate.name == class_name: |
| 754 | + class_node = candidate |
752 | 755 | break |
| 756 | + q.extend(ast.iter_child_nodes(candidate)) |
| 757 | + |
753 | 758 | if class_node is None: |
754 | 759 | return None |
755 | 760 |
|
756 | 761 | lines = module_source.splitlines() |
757 | 762 | relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] |
758 | 763 | for item in class_node.body: |
759 | 764 | if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): |
760 | | - if item.name in ("__init__", "__post_init__") or any( |
761 | | - (isinstance(d, ast.Name) and d.id == "property") |
762 | | - or (isinstance(d, ast.Attribute) and d.attr == "property") |
763 | | - for d in item.decorator_list |
764 | | - ): |
| 765 | + is_relevant = False |
| 766 | + if item.name in ("__init__", "__post_init__"): |
| 767 | + is_relevant = True |
| 768 | + else: |
| 769 | + # Check decorators explicitly to avoid generator overhead |
| 770 | + for d in item.decorator_list: |
| 771 | + if (isinstance(d, ast.Name) and d.id == "property") or ( |
| 772 | + isinstance(d, ast.Attribute) and d.attr == "property" |
| 773 | + ): |
| 774 | + is_relevant = True |
| 775 | + break |
| 776 | + if is_relevant: |
765 | 777 | relevant_nodes.append(item) |
766 | 778 |
|
767 | 779 | if not relevant_nodes: |
768 | 780 | return None |
769 | 781 |
|
770 | 782 | snippets: list[str] = [] |
771 | | - for node in relevant_nodes: |
772 | | - start = node.lineno |
773 | | - if node.decorator_list: |
774 | | - start = min(d.lineno for d in node.decorator_list) |
775 | | - snippets.append("\n".join(lines[start - 1 : node.end_lineno])) |
| 783 | + for fn_node in relevant_nodes: |
| 784 | + start = fn_node.lineno |
| 785 | + if fn_node.decorator_list: |
| 786 | + # Compute minimum decorator lineno with an explicit loop (avoids generator/min overhead) |
| 787 | + m = start |
| 788 | + for d in fn_node.decorator_list: |
| 789 | + m = min(m, d.lineno) |
| 790 | + start = m |
| 791 | + snippets.append("\n".join(lines[start - 1 : fn_node.end_lineno])) |
776 | 792 |
|
777 | 793 | return f"class {class_name}:\n" + "\n".join(snippets) |
778 | 794 |
|
|
0 commit comments