diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index a9050c7ca..68c977f0d 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -649,11 +649,16 @@ def _find_top_level_arg_node(self, target_node: Node, wrapper_bytes: bytes) -> N """ current = target_node passed_through_regular_call = False - while current.parent is not None: + + while True: parent = current.parent - if parent.type == "argument_list" and parent.parent is not None: + if parent is None: + return None + + parent_type = parent.type + if parent_type == "argument_list": grandparent = parent.parent - if grandparent.type == "method_invocation": + if grandparent is not None and grandparent.type == "method_invocation": gp_name = grandparent.child_by_field_name("name") if gp_name: name = self.analyzer.get_node_text(gp_name, wrapper_bytes) @@ -663,8 +668,7 @@ def _find_top_level_arg_node(self, target_node: Node, wrapper_bytes: bytes) -> N return None if not name.startswith("assert"): passed_through_regular_call = True - current = current.parent - return None + current = parent def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]: """Check if assertion is assigned to a variable. diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 38688cab6..53debd074 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -7,7 +7,10 @@ from __future__ import annotations +import contextlib +import importlib import logging +import sys from pathlib import Path from typing import TYPE_CHECKING @@ -47,16 +50,18 @@ def _ensure_languages_registered() -> None: # Import support modules to trigger registration # These imports are deferred to avoid circular imports - import contextlib - - with contextlib.suppress(ImportError): - from codeflash.languages.python import support as _ - - with contextlib.suppress(ImportError): - from codeflash.languages.javascript import support as _ - - with contextlib.suppress(ImportError): - from codeflash.languages.java import support as _ + module_names = ( + "codeflash.languages.python.support", + "codeflash.languages.javascript.support", + "codeflash.languages.java.support", + ) + + for name in module_names: + # Avoid the cost of importlib.import_module when the module is already loaded. + if name in sys.modules: + continue + with contextlib.suppress(ImportError): + importlib.import_module(name) _languages_registered = True