diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 2ccfd34bf..559d4c811 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -20,6 +20,12 @@ if TYPE_CHECKING: from tree_sitter import Node +_TYPE_DECLARATIONS = { + "class_declaration": "class", + "interface_declaration": "interface", + "enum_declaration": "enum", +} + logger = logging.getLogger(__name__) @@ -254,18 +260,13 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum". """ - type_declarations = { - "class_declaration": "class", - "interface_declaration": "interface", - "enum_declaration": "enum", - } - - if node.type in type_declarations: + if node.type in _TYPE_DECLARATIONS: name_node = node.child_by_field_name("name") if name_node: node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") if node_name == type_name: - return node, type_declarations[node.type] + return node, _TYPE_DECLARATIONS[node.type] + for child in node.children: result, kind = _find_type_node(child, type_name, source_bytes)