Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions codeflash/languages/python/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ast
import hashlib
import os
from collections import defaultdict
from collections import defaultdict, deque
from itertools import chain
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -746,33 +746,49 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:

def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
class_node = None
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
class_node = node
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
q: deque[ast.AST] = deque([module_tree])
while q:
candidate = q.popleft()
if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
class_node = candidate
break
q.extend(ast.iter_child_nodes(candidate))

if class_node is None:
return None

lines = module_source.splitlines()
relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
for item in class_node.body:
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
if item.name in ("__init__", "__post_init__") or any(
(isinstance(d, ast.Name) and d.id == "property")
or (isinstance(d, ast.Attribute) and d.attr == "property")
for d in item.decorator_list
):
is_relevant = False
if item.name in ("__init__", "__post_init__"):
is_relevant = True
else:
# Check decorators explicitly to avoid generator overhead
for d in item.decorator_list:
if (isinstance(d, ast.Name) and d.id == "property") or (
isinstance(d, ast.Attribute) and d.attr == "property"
):
is_relevant = True
break
if is_relevant:
relevant_nodes.append(item)

if not relevant_nodes:
return None

snippets: list[str] = []
for node in relevant_nodes:
start = node.lineno
if node.decorator_list:
start = min(d.lineno for d in node.decorator_list)
snippets.append("\n".join(lines[start - 1 : node.end_lineno]))
for fn_node in relevant_nodes:
start = fn_node.lineno
if fn_node.decorator_list:
# Compute minimum decorator lineno with an explicit loop (avoids generator/min overhead)
m = start
for d in fn_node.decorator_list:
m = min(m, d.lineno)
start = m
snippets.append("\n".join(lines[start - 1 : fn_node.end_lineno]))

return f"class {class_name}:\n" + "\n".join(snippets)

Expand Down