Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions codeflash/languages/javascript/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def split_call_args(args_str: str) -> tuple[str, str]:
s = args_str
s_len = len(s)

for i in range(s_len):
i = 0
while i < s_len:
char = s[i]

if char in "\"'`" and (i == 0 or s[i - 1] != "\\"):
Expand All @@ -129,9 +130,11 @@ def split_call_args(args_str: str) -> tuple[str, str]:
elif char == string_char:
in_string = False
string_char = None
i += 1
continue

if in_string:
i += 1
continue

if char in "([{":
Expand All @@ -141,6 +144,8 @@ def split_call_args(args_str: str) -> tuple[str, str]:
elif char == "," and depth == 0:
return s[:i].strip(), s[i + 1 :].strip()

i += 1

return s.strip(), ""


Expand Down Expand Up @@ -599,6 +604,9 @@ def __init__(
rf"(\s*)expect\s*\(\s*((?:\w+\.)*){re.escape(self.func_name)}\.call\s*\("
)

# Cache whitespace characters for faster checking
self._whitespace = frozenset(" \t\n\r")

def transform(self, code: str) -> str:
"""Transform all expect calls in the code."""
result: list[str] = []
Expand Down Expand Up @@ -752,10 +760,11 @@ def _parse_expect_dot_call(self, code: str, match: re.Match[str]) -> ExpectCallM
return None

# Find closing ) of expect(
code_len = len(code)
expect_close_pos = call_close_pos
while expect_close_pos < len(code) and code[expect_close_pos].isspace():
while expect_close_pos < code_len and code[expect_close_pos] in self._whitespace:
expect_close_pos += 1
if expect_close_pos >= len(code) or code[expect_close_pos] != ")":
if expect_close_pos >= code_len or code[expect_close_pos] != ")":
return None
expect_close_pos += 1

Expand All @@ -764,7 +773,7 @@ def _parse_expect_dot_call(self, code: str, match: re.Match[str]) -> ExpectCallM
if assertion_chain is None:
return None

has_trailing_semicolon = chain_end_pos < len(code) and code[chain_end_pos] == ";"
has_trailing_semicolon = chain_end_pos < code_len and code[chain_end_pos] == ";"
if has_trailing_semicolon:
chain_end_pos += 1

Expand All @@ -790,15 +799,16 @@ def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | N
Tuple of (content inside parens, position after closing paren) or (None, -1)

"""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
code_len = len(code)
if open_paren_pos >= code_len or code[open_paren_pos] != "(":
return None, -1

depth = 1
pos = open_paren_pos + 1
in_string = False
string_char = None

while pos < len(code) and depth > 0:
while pos < code_len and depth > 0:
char = code[pos]

# Handle string literals
Expand Down Expand Up @@ -841,32 +851,33 @@ def _parse_assertion_chain(self, code: str, start_pos: int) -> tuple[str | None,
"""
pos = start_pos
chain_parts: list[str] = []
code_len = len(code)

# Skip any leading whitespace (for multi-line)
while pos < len(code) and code[pos] in " \t\n\r":
while pos < code_len and code[pos] in self._whitespace:
pos += 1

# Must start with a dot
if pos >= len(code) or code[pos] != ".":
if pos >= code_len or code[pos] != ".":
return None, -1

while pos < len(code):
while pos < code_len:
# Skip whitespace between chain elements
while pos < len(code) and code[pos] in " \t\n\r":
while pos < code_len and code[pos] in self._whitespace:
pos += 1

if pos >= len(code) or code[pos] != ".":
if pos >= code_len or code[pos] != ".":
break

pos += 1 # Skip the dot

# Skip whitespace after dot
while pos < len(code) and code[pos] in " \t\n\r":
while pos < code_len and code[pos] in self._whitespace:
pos += 1

# Parse the method name
method_start = pos
while pos < len(code) and (code[pos].isalnum() or code[pos] == "_"):
while pos < code_len and (code[pos].isalnum() or code[pos] == "_"):
pos += 1

if pos == method_start:
Expand All @@ -875,11 +886,11 @@ def _parse_assertion_chain(self, code: str, start_pos: int) -> tuple[str | None,
method_name = code[method_start:pos]

# Skip whitespace before potential parens
while pos < len(code) and code[pos] in " \t\n\r":
while pos < code_len and code[pos] in self._whitespace:
pos += 1

# Check for parentheses (method call)
if pos < len(code) and code[pos] == "(":
if pos < code_len and code[pos] == "(":
args_content, after_paren = self._find_balanced_parens(code, pos)
if args_content is None:
return None, -1
Expand Down
Loading