diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index a9050c7ca..2d01f83d7 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -193,6 +193,9 @@ def __init__( # Precompile the assignment-detection regex to avoid recompiling on each call. self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$") + # Precompile regex to find next special character (single-quote, double-quote, brace). + self._special_re = re.compile(r"[\"'{}]") + def transform(self, source: str) -> str: """Remove assertions from source code, preserving target function calls. @@ -843,30 +846,42 @@ def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | N depth = 1 pos = open_brace_pos + 1 - in_string = False - string_char = None - in_char = False + code_len = len(code) + special_re = self._special_re + + while pos < code_len and depth > 0: + m = special_re.search(code, pos) + if m is None: + return None, -1 + + idx = m.start() + char = m.group() + prev_char = code[idx - 1] if idx > 0 else "" + + if char == "'" and prev_char != "\\": + j = code.find("'", idx + 1) + while j != -1 and j > 0 and code[j - 1] == "\\": + j = code.find("'", j + 1) + if j == -1: + return None, -1 + pos = j + 1 + continue - while pos < len(code) and depth > 0: - char = code[pos] - prev_char = code[pos - 1] if pos > 0 else "" + if char == '"' and prev_char != "\\": + j = code.find('"', idx + 1) + while j != -1 and j > 0 and code[j - 1] == "\\": + j = code.find('"', j + 1) + if j == -1: + return None, -1 + pos = j + 1 + continue - if char == "'" and not in_string and prev_char != "\\": - in_char = not in_char - elif char == '"' and not in_char and prev_char != "\\": - if not in_string: - in_string = True - string_char = char - elif char == string_char: - in_string = False - string_char = None - elif not in_string and not in_char: - if char == "{": - depth += 1 - elif char == "}": - depth -= 1 + if char == "{": + depth += 1 + elif char == "}": + depth -= 1 - pos += 1 + pos = idx + 1 if depth != 0: return None, -1