diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index a5597351c..701dddd96 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -413,18 +413,20 @@ def _extract_type_body_context( enum_constant_parts: list[str] = [] for child in body_node.children: + child_type = child.type + # Skip braces, semicolons, and commas - if child.type in ("{", "}", ";", ","): + if child_type in ("{", "}", ";", ","): continue # Handle enum constants (only for enums) # Extract just the constant name/text, not the whole line - if child.type == "enum_constant" and type_kind == "enum": + if child_type == "enum_constant" and type_kind == "enum": constant_text = source_bytes[child.start_byte : child.end_byte].decode("utf8") enum_constant_parts.append(constant_text) # Handle field declarations - elif child.type == "field_declaration": + elif child_type == "field_declaration": start_line = child.start_point[0] end_line = child.end_point[0] @@ -436,18 +438,16 @@ def _extract_type_body_context( if comment_text.strip().startswith("/**"): javadoc_start = prev_sibling.start_point[0] - field_lines = lines[javadoc_start : end_line + 1] - field_parts.append("".join(field_lines)) + field_parts.extend(lines[javadoc_start : end_line + 1]) # Handle constant declarations (for interfaces) - elif child.type == "constant_declaration" and type_kind == "interface": + elif child_type == "constant_declaration" and type_kind == "interface": start_line = child.start_point[0] end_line = child.end_point[0] - constant_lines = lines[start_line : end_line + 1] - field_parts.append("".join(constant_lines)) + field_parts.extend(lines[start_line : end_line + 1]) # Handle constructor declarations - elif child.type == "constructor_declaration": + elif child_type == "constructor_declaration": start_line = child.start_point[0] end_line = child.end_point[0] @@ -459,8 +459,8 @@ def _extract_type_body_context( if comment_text.strip().startswith("/**"): javadoc_start = prev_sibling.start_point[0] - constructor_lines = lines[javadoc_start : end_line + 1] - constructor_parts.append("".join(constructor_lines)) + constructor_parts.extend(lines[javadoc_start : end_line + 1]) + fields_code = "".join(field_parts) constructors_code = "".join(constructor_parts)