diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 01ea24661..65648e5fe 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -727,24 +727,30 @@ def has_test_annotation(method_node: Any) -> bool: return False def collect_test_methods(node: Any, out: list[tuple[Any, Any]]) -> None: - if node.type == "method_declaration" and has_test_annotation(node): - body_node = node.child_by_field_name("body") - if body_node is not None: - out.append((node, body_node)) - for child in node.children: - collect_test_methods(child, out) + stack = [node] + while stack: + current = stack.pop() + if current.type == "method_declaration" and has_test_annotation(current): + body_node = current.child_by_field_name("body") + if body_node is not None: + out.append((current, body_node)) + continue + if current.children: + stack.extend(reversed(current.children)) def collect_target_calls(node: Any, wrapper_bytes: bytes, func: str, out: list[Any]) -> None: - if node.type == "method_invocation": - name_node = node.child_by_field_name("name") - if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func: - # Skip if inside lambda or complex expression - if not _is_inside_lambda(node) and not _is_inside_complex_expression(node): - out.append(node) - else: - logger.debug("Skipping instrumentation of %s inside lambda or complex expression", func) - for child in node.children: - collect_target_calls(child, wrapper_bytes, func, out) + stack = [node] + while stack: + current = stack.pop() + if current.type == "method_invocation": + name_node = current.child_by_field_name("name") + if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func: + if not _is_inside_lambda(current) and not _is_inside_complex_expression(current): + out.append(current) + else: + logger.debug("Skipping instrumentation of %s inside lambda or complex expression", func) + if current.children: + stack.extend(reversed(current.children)) def reindent_block(text: str, target_indent: str) -> str: lines = text.splitlines()