diff --git a/codeflash/languages/javascript/treesitter_utils.py b/codeflash/languages/javascript/treesitter_utils.py index 23c3cdfb1..3b9910343 100644 --- a/codeflash/languages/javascript/treesitter_utils.py +++ b/codeflash/languages/javascript/treesitter_utils.py @@ -1227,8 +1227,6 @@ def has_return_statement(self, function_node: FunctionNode, source: str) -> bool True if the function has a return statement. """ - source_bytes = source.encode("utf8") - # Generator functions always implicitly return a Generator/Iterator if function_node.is_generator: return True @@ -1244,20 +1242,32 @@ def has_return_statement(self, function_node: FunctionNode, source: str) -> bool def _node_has_return(self, node: Node) -> bool: """Recursively check if a node contains a return statement.""" - if node.type == "return_statement": - return True + # Use an explicit stack to avoid recursion overhead while preserving traversal order. + func_types = ("function_declaration", "function_expression", "arrow_function", "method_definition") + stack = [node] + while stack: + current = stack.pop() + # Direct return statement check + if current.type == "return_statement": + return True + + # If this node is a function-like node, only traverse its body children + if current.type in func_types: + body_node = current.child_by_field_name("body") + if body_node: + # Push children in reverse so they are processed in original order + children = body_node.children + if children: + stack.extend(reversed(children)) + # Do not traverse other parts of the function node + continue + + # General case: traverse all children + children = current.children + if children: + stack.extend(reversed(children)) - # Don't recurse into nested function definitions - if node.type in ("function_declaration", "function_expression", "arrow_function", "method_definition"): - # Only check the current function, not nested ones - body_node = node.child_by_field_name("body") - if body_node: - for child in body_node.children: - if self._node_has_return(child): - return True - return False - - return any(self._node_has_return(child) for child in node.children) + return False def extract_type_annotations(self, source: str, function_name: str, function_line: int) -> set[str]: """Extract type annotation names from a function's parameters and return type.