Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions codeflash/languages/javascript/treesitter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,8 +1227,6 @@ def has_return_statement(self, function_node: FunctionNode, source: str) -> bool
True if the function has a return statement.

"""
source_bytes = source.encode("utf8")

# Generator functions always implicitly return a Generator/Iterator
if function_node.is_generator:
return True
Expand All @@ -1244,20 +1242,32 @@ def has_return_statement(self, function_node: FunctionNode, source: str) -> bool

def _node_has_return(self, node: Node) -> bool:
"""Recursively check if a node contains a return statement."""
if node.type == "return_statement":
return True
# Use an explicit stack to avoid recursion overhead while preserving traversal order.
func_types = ("function_declaration", "function_expression", "arrow_function", "method_definition")
stack = [node]
while stack:
current = stack.pop()
# Direct return statement check
if current.type == "return_statement":
return True

# If this node is a function-like node, only traverse its body children
if current.type in func_types:
body_node = current.child_by_field_name("body")
if body_node:
# Push children in reverse so they are processed in original order
children = body_node.children
if children:
stack.extend(reversed(children))
# Do not traverse other parts of the function node
continue

# General case: traverse all children
children = current.children
if children:
stack.extend(reversed(children))

# Don't recurse into nested function definitions
if node.type in ("function_declaration", "function_expression", "arrow_function", "method_definition"):
# Only check the current function, not nested ones
body_node = node.child_by_field_name("body")
if body_node:
for child in body_node.children:
if self._node_has_return(child):
return True
return False

return any(self._node_has_return(child) for child in node.children)
return False

def extract_type_annotations(self, source: str, function_name: str, function_line: int) -> set[str]:
"""Extract type annotation names from a function's parameters and return type.
Expand Down