diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index b635fd529..7b024b910 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1301,38 +1301,39 @@ def _find_and_extract_body(self, source: str, function_name: str, analyzer: Tree def find_function_node(node: Any, target_name: str) -> Any: """Recursively find a function/method with the given name.""" - # Check method definitions - if node.type == "method_definition": - name_node = node.child_by_field_name("name") - if name_node: - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return node - - # Check function declarations - if node.type in ("function_declaration", "function"): - name_node = node.child_by_field_name("name") - if name_node: - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return node - - # Check arrow functions assigned to variables - if node.type == "lexical_declaration": - for child in node.children: - if child.type == "variable_declarator": - name_node = child.child_by_field_name("name") - value_node = child.child_by_field_name("value") - if name_node and value_node and value_node.type == "arrow_function": - name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if name == target_name: - return value_node + target_bytes = target_name.encode("utf8") + stack = [node] + while stack: + n = stack.pop() + + # Check method definitions + if n.type == "method_definition": + name_node = n.child_by_field_name("name") + if name_node: + if source_bytes[name_node.start_byte : name_node.end_byte] == target_bytes: + return n + + # Check function declarations + if n.type in ("function_declaration", "function"): + name_node = n.child_by_field_name("name") + if name_node: + if source_bytes[name_node.start_byte : name_node.end_byte] == target_bytes: + return n + + # Check arrow functions assigned to variables + if n.type == "lexical_declaration": + for child in n.children: + if child.type == "variable_declarator": + name_node = child.child_by_field_name("name") + value_node = child.child_by_field_name("value") + if name_node and value_node and value_node.type == "arrow_function": + if source_bytes[name_node.start_byte : name_node.end_byte] == target_bytes: + return value_node + + # Recurse into children (stack push) + for child in n.children: + stack.append(child) - # Recurse into children - for child in node.children: - result = find_function_node(child, target_name) - if result: - return result return None