From a4ed93e22958deea736eb90c4d5f8aa9f46e0e99 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Feb 2026 14:26:27 -0800 Subject: [PATCH 1/4] feat: add gpu flag for CUDA event-based timing Add a `gpu` parameter to instrument tests with torch.cuda.Event timing instead of time.perf_counter_ns() for measuring GPU kernel execution time. Falls back to CPU timing when CUDA is not available/initialized. Co-Authored-By: Claude Opus 4.5 --- .../code_utils/instrument_existing_tests.py | 478 +++++++++++++++--- .../test_inject_profiling_used_frameworks.py | 432 ++++++++++++++++ 2 files changed, 829 insertions(+), 81 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 4366468d0..f3e929688 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -636,6 +636,7 @@ def inject_async_profiling_into_existing_test( function_to_optimize: FunctionToOptimize, tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, + gpu: bool = False, ) -> tuple[bool, str | None]: """Inject profiling for async function calls by setting environment variables before each call.""" with test_path.open(encoding="utf8") as f: @@ -708,6 +709,7 @@ def inject_profiling_into_existing_test( function_to_optimize: FunctionToOptimize, tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, + gpu: bool = False, ) -> tuple[bool, str | None]: if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( @@ -752,7 +754,7 @@ def inject_profiling_into_existing_test( else: # If there's an alias, use it (e.g., "import torch as th") new_imports.append(ast.Import(names=[ast.alias(name=framework_name, asname=framework_alias)])) - additional_functions = [create_wrapper_function(mode, used_frameworks)] + additional_functions = [create_wrapper_function(mode, used_frameworks, gpu)] tree.body = [*new_imports, *additional_functions, *tree.body] return True, sort_imports(ast.unparse(tree), float_to_top=True) @@ -908,6 +910,60 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] | return precompute_statements +def _create_gpu_event_timing_precompute_statements(used_frameworks: dict[str, str] | None) -> list[ast.stmt]: + """Create AST statements to pre-compute GPU event timing conditions. + + This generates: + _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + + Args: + used_frameworks: Dict mapping framework names to their import aliases + + Returns: + List of AST statements that pre-compute GPU timer availability + + """ + if not used_frameworks or "torch" not in used_frameworks: + return [] + + torch_alias = used_frameworks["torch"] + + # _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + return [ + ast.Assign( + targets=[ast.Name(id="_codeflash_use_gpu_timer", ctx=ast.Store())], + value=ast.BoolOp( + op=ast.And(), + values=[ + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="is_initialized", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ], + ), + lineno=1, + ) + ] + + def _create_device_sync_statements( used_frameworks: dict[str, str] | None, for_return_value: bool = False ) -> list[ast.stmt]: @@ -1030,8 +1086,338 @@ def _create_device_sync_statements( return sync_statements +def _create_gpu_timing_try_body(torch_alias: str) -> list[ast.stmt]: + """Create AST statements for the GPU event timing try body. + + Generates: + _codeflash_start_event = torch.cuda.Event(enable_timing=True) + _codeflash_end_event = torch.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + torch.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1_000_000) + + Args: + torch_alias: The import alias for torch (e.g., "torch" or "th") + + Returns: + List of AST statements for GPU event timing + + """ + return [ + # _codeflash_start_event = torch.cuda.Event(enable_timing=True) + ast.Assign( + targets=[ast.Name(id="_codeflash_start_event", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="Event", + ctx=ast.Load(), + ), + args=[], + keywords=[ast.keyword(arg="enable_timing", value=ast.Constant(value=True))], + ), + lineno=1, + ), + # _codeflash_end_event = torch.cuda.Event(enable_timing=True) + ast.Assign( + targets=[ast.Name(id="_codeflash_end_event", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="Event", + ctx=ast.Load(), + ), + args=[], + keywords=[ast.keyword(arg="enable_timing", value=ast.Constant(value=True))], + ), + lineno=1, + ), + # _codeflash_start_event.record() + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_codeflash_start_event", ctx=ast.Load()), attr="record", ctx=ast.Load() + ), + args=[], + keywords=[], + ) + ), + # return_value = codeflash_wrapped(*args, **kwargs) + ast.Assign( + targets=[ast.Name(id="return_value", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), + args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], + keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], + ), + lineno=1, + ), + # _codeflash_end_event.record() + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_codeflash_end_event", ctx=ast.Load()), attr="record", ctx=ast.Load() + ), + args=[], + keywords=[], + ) + ), + # torch.cuda.synchronize() + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ), + # codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1_000_000) + ast.Assign( + targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="int", ctx=ast.Load()), + args=[ + ast.BinOp( + left=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_codeflash_start_event", ctx=ast.Load()), + attr="elapsed_time", + ctx=ast.Load(), + ), + args=[ast.Name(id="_codeflash_end_event", ctx=ast.Load())], + keywords=[], + ), + op=ast.Mult(), + right=ast.Constant(value=1_000_000), + ) + ], + keywords=[], + ), + lineno=1, + ), + ] + + +def _create_gpu_timing_except_body(torch_alias: str) -> list[ast.stmt]: + """Create AST statements for the GPU event timing exception handler. + + Generates: + torch.cuda.synchronize() + codeflash_duration = 0 + exception = e + + Args: + torch_alias: The import alias for torch (e.g., "torch" or "th") + + Returns: + List of AST statements for GPU timing exception handling + + """ + return [ + # torch.cuda.synchronize() + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ), + # codeflash_duration = 0 + ast.Assign(targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], value=ast.Constant(value=0), lineno=1), + # exception = e + ast.Assign( + targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Name(id="e", ctx=ast.Load()), lineno=1 + ), + ] + + +def _create_cpu_timing_try_body(used_frameworks: dict[str, str] | None) -> list[ast.stmt]: + """Create AST statements for the CPU timing try body. + + Generates standard time.perf_counter_ns() timing with device sync. + + Args: + used_frameworks: Dict mapping framework names to their import aliases + + Returns: + List of AST statements for CPU timing + + """ + return [ + # Pre-sync: synchronize device before starting timer + *_create_device_sync_statements(used_frameworks, for_return_value=False), + # counter = time.perf_counter_ns() + ast.Assign( + targets=[ast.Name(id="counter", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()), + args=[], + keywords=[], + ), + lineno=1, + ), + # return_value = codeflash_wrapped(*args, **kwargs) + ast.Assign( + targets=[ast.Name(id="return_value", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), + args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], + keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], + ), + lineno=1, + ), + # Post-sync: synchronize device after function call + *_create_device_sync_statements(used_frameworks, for_return_value=True), + # codeflash_duration = time.perf_counter_ns() - counter + ast.Assign( + targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], + value=ast.BinOp( + left=ast.Call( + func=ast.Attribute( + value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() + ), + args=[], + keywords=[], + ), + op=ast.Sub(), + right=ast.Name(id="counter", ctx=ast.Load()), + ), + lineno=1, + ), + ] + + +def _create_cpu_timing_except_body() -> list[ast.stmt]: + """Create AST statements for the CPU timing exception handler. + + Generates: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + + Returns: + List of AST statements for CPU timing exception handling + + """ + return [ + # codeflash_duration = time.perf_counter_ns() - counter + ast.Assign( + targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], + value=ast.BinOp( + left=ast.Call( + func=ast.Attribute( + value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() + ), + args=[], + keywords=[], + ), + op=ast.Sub(), + right=ast.Name(id="counter", ctx=ast.Load()), + ), + lineno=1, + ), + # exception = e + ast.Assign( + targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Name(id="e", ctx=ast.Load()), lineno=1 + ), + ] + + +def _create_timing_try_block(used_frameworks: dict[str, str] | None, gpu: bool, lineno: int) -> list[ast.stmt]: + """Create the timing try block, handling both GPU and CPU timing modes. + + When gpu=True and torch is available, generates an if/else structure: + if _codeflash_use_gpu_timer: + # GPU event timing path + else: + # CPU timing fallback path + + Otherwise, generates standard CPU timing. + + Args: + used_frameworks: Dict mapping framework names to their import aliases + gpu: Whether to use GPU event timing when possible + lineno: Current line number for AST nodes + + Returns: + List containing the try statement(s) for timing + + """ + use_gpu_timing = gpu and used_frameworks and "torch" in used_frameworks + + if use_gpu_timing: + torch_alias = used_frameworks["torch"] + + # Create GPU timing try block + gpu_try = ast.Try( + body=_create_gpu_timing_try_body(torch_alias), + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="Exception", ctx=ast.Load()), + name="e", + body=_create_gpu_timing_except_body(torch_alias), + lineno=lineno + 14, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno + 11, + ) + + # Create CPU timing try block (fallback) + cpu_try = ast.Try( + body=_create_cpu_timing_try_body(used_frameworks), + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="Exception", ctx=ast.Load()), + name="e", + body=_create_cpu_timing_except_body(), + lineno=lineno + 14, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno + 11, + ) + + # Wrap in if/else based on _codeflash_use_gpu_timer + return [ + ast.If( + test=ast.Name(id="_codeflash_use_gpu_timer", ctx=ast.Load()), + body=[gpu_try], + orelse=[cpu_try], + lineno=lineno + 11, + ) + ] + # Standard CPU timing + return [ + ast.Try( + body=_create_cpu_timing_try_body(used_frameworks), + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="Exception", ctx=ast.Load()), + name="e", + body=_create_cpu_timing_except_body(), + lineno=lineno + 14, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno + 11, + ) + ] + + def create_wrapper_function( - mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None + mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None, gpu: bool = False ) -> ast.FunctionDef: lineno = 1 wrapper_body: list[ast.stmt] = [ @@ -1193,8 +1579,14 @@ def create_wrapper_function( ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), - # Pre-compute device sync conditions before profiling to avoid overhead during timing - *_create_device_sync_precompute_statements(used_frameworks), + # Pre-compute conditions before profiling to avoid overhead during timing + *( + # When gpu=True with torch, we need both the GPU timer check AND device sync conditions for the fallback + _create_gpu_event_timing_precompute_statements(used_frameworks) + + _create_device_sync_precompute_statements(used_frameworks) + if gpu and used_frameworks and "torch" in used_frameworks + else _create_device_sync_precompute_statements(used_frameworks) + ), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), @@ -1203,83 +1595,7 @@ def create_wrapper_function( ), lineno=lineno + 9, ), - ast.Try( - body=[ - # Pre-sync: synchronize device before starting timer - *_create_device_sync_statements(used_frameworks, for_return_value=False), - ast.Assign( - targets=[ast.Name(id="counter", ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - lineno=lineno + 11, - ), - ast.Assign( - targets=[ast.Name(id="return_value", ctx=ast.Store())], - value=ast.Call( - func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ), - lineno=lineno + 12, - ), - # Post-sync: synchronize device after function call to ensure all device work is complete - *_create_device_sync_statements(used_frameworks, for_return_value=True), - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 13, - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="Exception", ctx=ast.Load()), - name="e", - body=[ - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), - attr="perf_counter_ns", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 15, - ), - ast.Assign( - targets=[ast.Name(id="exception", ctx=ast.Store())], - value=ast.Name(id="e", ctx=ast.Load()), - lineno=lineno + 13, - ), - ], - lineno=lineno + 14, - ) - ], - orelse=[], - finalbody=[], - lineno=lineno + 11, - ), + *_create_timing_try_block(used_frameworks, gpu, lineno), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()), diff --git a/tests/test_inject_profiling_used_frameworks.py b/tests/test_inject_profiling_used_frameworks.py index 826be09c8..ede5559df 100644 --- a/tests/test_inject_profiling_used_frameworks.py +++ b/tests/test_inject_profiling_used_frameworks.py @@ -1492,3 +1492,435 @@ def test_my_function(): result = normalize_instrumented_code(instrumented_code) expected = EXPECTED_ALL_FRAMEWORKS_PERFORMANCE assert result == expected + + +# ============================================================================ +# Expected instrumented code for GPU timing mode +# ============================================================================ + +EXPECTED_TORCH_GPU_BEHAVIOR = """import gc +import inspect +import os +import sqlite3 +import time + +import dill as pickle +import torch +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = torch.cuda.Event(enable_timing=True) + _codeflash_end_event = torch.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + torch.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + torch.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}######!') + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_con.commit() + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] + codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}') + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') + _call__bound__arguments = inspect.signature(my_function).bind(1, 2) + _call__bound__arguments.apply_defaults() + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) + assert result == 3 + codeflash_con.close() +""" + +EXPECTED_TORCH_GPU_PERFORMANCE = """import gc +import os +import time + +import torch +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = torch.cuda.Event(enable_timing=True) + _codeflash_end_event = torch.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + torch.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + torch.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}:{codeflash_duration}######!') + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2) + assert result == 3 +""" + +EXPECTED_TORCH_ALIASED_GPU_BEHAVIOR = """import gc +import inspect +import os +import sqlite3 +import time + +import dill as pickle +import torch as th +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = th.cuda.is_available() and th.cuda.is_initialized() + _codeflash_should_sync_cuda = th.cuda.is_available() and th.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(th.backends, 'mps') and th.backends.mps.is_available() and hasattr(th.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = th.cuda.Event(enable_timing=True) + _codeflash_end_event = th.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + th.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + th.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + th.cuda.synchronize() + elif _codeflash_should_sync_mps: + th.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + th.cuda.synchronize() + elif _codeflash_should_sync_mps: + th.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}######!') + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_con.commit() + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] + codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}') + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') + _call__bound__arguments = inspect.signature(my_function).bind(1, 2) + _call__bound__arguments.apply_defaults() + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) + assert result == 3 + codeflash_con.close() +""" + + +# ============================================================================ +# Tests for GPU timing mode +# ============================================================================ + + +class TestInjectProfilingGpuTimingMode: + """Tests for inject_profiling_into_existing_test with gpu=True.""" + + def test_torch_gpu_behavior_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch and gpu=True in BEHAVIOR mode.""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.BEHAVIOR, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_GPU_BEHAVIOR + assert result == expected + + def test_torch_gpu_performance_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch and gpu=True in PERFORMANCE mode.""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_GPU_PERFORMANCE + assert result == expected + + def test_torch_aliased_gpu_behavior_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch alias and gpu=True in BEHAVIOR mode.""" + code = """import torch as th +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.BEHAVIOR, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_ALIASED_GPU_BEHAVIOR + assert result == expected + + def test_no_torch_gpu_flag_uses_cpu_timing(self, tmp_path: Path) -> None: + """Test that gpu=True without torch uses standard CPU timing.""" + code = """from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(4, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + # gpu=True without torch should produce the same result as gpu=False + expected = EXPECTED_NO_FRAMEWORKS_PERFORMANCE + assert result == expected + + def test_gpu_false_with_torch_uses_device_sync(self, tmp_path: Path) -> None: + """Test that gpu=False with torch uses device sync (existing behavior).""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=False, + ) + + result = normalize_instrumented_code(instrumented_code) + # gpu=False with torch should produce device sync code + expected = EXPECTED_TORCH_PERFORMANCE + assert result == expected + + def test_torch_submodule_import_gpu_mode(self, tmp_path: Path) -> None: + """Test that gpu=True works with torch submodule imports like 'from torch import nn'.""" + code = """from torch import nn +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + assert success + # Verify GPU timing code is present (torch detected from submodule import) + assert "_codeflash_use_gpu_timer = torch.cuda.is_available()" in instrumented_code + assert "torch.cuda.Event(enable_timing=True)" in instrumented_code + assert "elapsed_time" in instrumented_code + + def test_torch_dotted_import_gpu_mode(self, tmp_path: Path) -> None: + """Test that gpu=True works with torch dotted imports like 'import torch.nn'.""" + code = """import torch.nn +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + assert success + # Verify GPU timing code is present (torch detected from dotted import) + assert "_codeflash_use_gpu_timer = torch.cuda.is_available()" in instrumented_code + assert "torch.cuda.Event(enable_timing=True)" in instrumented_code + assert "elapsed_time" in instrumented_code From a4e0fb469e95b021bc63ad7af57b86d23b143469 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Feb 2026 14:55:43 -0800 Subject: [PATCH 2/4] fix: resolve ruff lint errors for pre-commit Fix unused variables, single-item membership tests, unnecessary lambdas, and ternary expressions that can use `or` operator. Co-Authored-By: Claude Opus 4.5 --- .../languages/javascript/find_references.py | 14 ++++++------- codeflash/languages/treesitter_utils.py | 8 ++++---- codeflash/optimization/function_optimizer.py | 20 +++++++++---------- codeflash/verification/parse_test_output.py | 10 +++++----- 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index 812f7c4a7..43bde84a5 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -168,7 +168,7 @@ def find_references( if import_info: # Found an import - mark as visited and search for calls context.visited_files.add(file_path) - import_name, original_import = import_info + import_name, _original_import = import_info file_refs = self._find_references_in_file( file_path, file_code, function_name, import_name, file_analyzer, include_self=True ) @@ -213,7 +213,7 @@ def find_references( if import_info: context.visited_files.add(file_path) - import_name, original_import = import_info + import_name, _original_import = import_info file_refs = self._find_references_in_file( file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True ) @@ -317,7 +317,7 @@ def _find_matching_import( export_name = exported.export_name or exported.function_name for name, alias in imp.named_imports: if name == export_name: - return (alias if alias else name, imp) + return (alias or name, imp) # Check namespace import if imp.namespace_import: @@ -360,7 +360,7 @@ def _find_references_in_file( lines = source_code.splitlines() # The name to search for (either imported name or original) - search_name = import_name if import_name else function_name + search_name = import_name or function_name # Handle namespace imports (e.g., "utils.helper") if "." in search_name: @@ -404,7 +404,7 @@ def _find_identifier_references( name_node = node.child_by_field_name("name") if name_node: new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - elif node.type in ("variable_declarator",): + elif node.type == "variable_declarator": # Arrow function or function expression assigned to variable name_node = node.child_by_field_name("name") value_node = node.child_by_field_name("value") @@ -673,7 +673,7 @@ def _find_reexports( end_column=0, context=context_line.strip(), reference_type="reexport", - import_name=alias if alias else name, + import_name=alias or name, caller_function=None, ) references.append(ref) @@ -745,7 +745,7 @@ def _find_reexports_direct( end_column=0, context=context_line.strip(), reference_type="reexport", - import_name=alias if alias else name, + import_name=alias or name, caller_function=None, ) references.append(ref) diff --git a/codeflash/languages/treesitter_utils.py b/codeflash/languages/treesitter_utils.py index f4b7ead43..75792be6f 100644 --- a/codeflash/languages/treesitter_utils.py +++ b/codeflash/languages/treesitter_utils.py @@ -899,7 +899,7 @@ def is_function_exported( # Check named exports for name, alias in export.exported_names: if name == function_name: - return (True, alias if alias else name) + return (True, alias or name) # For class methods, check if the containing class is exported if class_name: @@ -911,7 +911,7 @@ def is_function_exported( # Check if class is in named exports for name, alias in export.exported_names: if name == class_name: - return (True, alias if alias else name) + return (True, alias or name) return (False, None) @@ -1580,9 +1580,9 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer: """ suffix = file_path.suffix.lower() - if suffix in (".ts",): + if suffix == ".ts": return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT) - if suffix in (".tsx",): + if suffix == ".tsx": return TreeSitterAnalyzer(TreeSitterLanguage.TSX) # Default to JavaScript for .js, .jsx, .mjs, .cjs return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 095731f9f..1a7387247 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -315,7 +315,7 @@ def _handle_empty_queue(self) -> CandidateNode | None: self.future_all_code_repair, "Repairing {0} candidates", "Added {0} candidates from repair, total candidates now: {1}", - lambda: self.future_all_code_repair.clear(), + self.future_all_code_repair.clear, ) if self.line_profiler_done and not self.refinement_done: return self._process_candidates( @@ -330,7 +330,7 @@ def _handle_empty_queue(self) -> CandidateNode | None: self.future_adaptive_optimizations, "Applying adaptive optimizations to {0} candidates", "Added {0} candidates from adaptive optimization, total candidates now: {1}", - lambda: self.future_adaptive_optimizations.clear(), + self.future_adaptive_optimizations.clear, ) return None # All done @@ -440,12 +440,10 @@ def __init__( ) -> None: self.project_root = test_cfg.project_root_path self.test_cfg = test_cfg - self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient() + self.aiservice_client = aiservice_client or AiServiceClient() self.function_to_optimize = function_to_optimize self.function_to_optimize_source_code = ( - function_to_optimize_source_code - if function_to_optimize_source_code - else function_to_optimize.file_path.read_text(encoding="utf8") + function_to_optimize_source_code or function_to_optimize.file_path.read_text(encoding="utf8") ) self.language_support = current_language_support() if not function_to_optimize_ast: @@ -459,7 +457,7 @@ def __init__( ) else: self.function_to_optimize_ast = function_to_optimize_ast - self.function_to_tests = function_to_tests if function_to_tests else {} + self.function_to_tests = function_to_tests or {} self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None @@ -476,9 +474,9 @@ def __init__( tests_root=test_cfg.tests_root, ) - self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} - self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} - self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + self.function_benchmark_timings = function_benchmark_timings or {} + self.total_benchmark_timings = total_benchmark_timings or {} + self.replay_tests_dir = replay_tests_dir or None n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort) self.executor = concurrent.futures.ThreadPoolExecutor( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 @@ -2083,7 +2081,7 @@ def process_review( formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n" - existing_tests, replay_tests, concolic_tests = existing_tests_source_for( + existing_tests, replay_tests, _concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), function_to_all_tests, test_cfg=self.test_cfg, diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 59b4f0acc..00ee82e19 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -429,8 +429,8 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes for val in data: try: test_module_path = val[0] - test_class_name = val[1] if val[1] else None - test_function_name = val[2] if val[2] else None + test_class_name = val[1] or None + test_function_name = val[2] or None function_getting_tested = val[3] # For Jest tests, test_module_path could be: @@ -1152,7 +1152,7 @@ def merge_test_results( for result_bin in bin_results: # Prefer XML runtime (from stdout markers) if bin runtime is None/0 # This is important for Jest perf tests which output timing to stdout, not SQLite - merged_runtime = result_bin.runtime if result_bin.runtime else xml_result.runtime + merged_runtime = result_bin.runtime or xml_result.runtime merged_test_results.add( FunctionTestInvocation( loop_index=xml_result.loop_index, @@ -1183,7 +1183,7 @@ def merge_test_results( continue # Prefer XML runtime (from stdout markers) if bin runtime is None/0 # This is important for Jest perf tests which output timing to stdout, not SQLite - merged_runtime = bin_result.runtime if bin_result.runtime else xml_result.runtime + merged_runtime = bin_result.runtime or xml_result.runtime merged_test_results.add( FunctionTestInvocation( loop_index=xml_result.loop_index, @@ -1215,7 +1215,7 @@ def merge_test_results( continue # Prefer XML runtime (from stdout markers) if bin runtime is None/0 # This is important for Jest perf tests which output timing to stdout, not SQLite - merged_runtime = bin_result.runtime if bin_result.runtime else xml_result.runtime + merged_runtime = bin_result.runtime or xml_result.runtime merged_test_results.add( FunctionTestInvocation( loop_index=bin_result.loop_index, From 805e612b3be2a117bc5127737e8358e74957adc6 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Feb 2026 14:56:35 -0800 Subject: [PATCH 3/4] linter fixes --- codeflash/api/aiservice.py | 4 ++-- codeflash/code_utils/code_extractor.py | 4 ++-- codeflash/code_utils/code_replacer.py | 2 +- codeflash/code_utils/codeflash_wrap_decorator.py | 2 +- codeflash/code_utils/config_js.py | 2 +- codeflash/code_utils/git_utils.py | 10 +++++----- codeflash/code_utils/instrument_existing_tests.py | 6 +++--- codeflash/code_utils/line_profile_utils.py | 4 ++-- codeflash/code_utils/normalizers/python.py | 4 ++-- codeflash/context/code_context_extractor.py | 8 ++++---- codeflash/context/unused_definition_remover.py | 4 ++-- codeflash/discovery/discover_unit_tests.py | 6 +++--- codeflash/github/PrComment.py | 2 +- codeflash/languages/javascript/import_resolver.py | 2 +- codeflash/languages/javascript/support.py | 4 ++-- codeflash/languages/javascript/test_runner.py | 6 +++--- codeflash/languages/javascript/vitest_runner.py | 6 +++--- codeflash/models/models.py | 2 +- codeflash/result/explanation.py | 2 +- codeflash/verification/codeflash_capture.py | 2 +- codeflash/version.py | 2 +- 21 files changed, 42 insertions(+), 42 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 157bf24e6..5610dcd59 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -328,7 +328,7 @@ def optimize_python_code_line_profiler( console.rule() # Set python_version for backward compatibility with Python, or use language_version - python_version = language_version if language_version else platform.python_version() + python_version = language_version or platform.python_version() payload = { "source_code": source_code, @@ -868,7 +868,7 @@ def get_optimization_review( "replay_tests": replay_tests, "speedup": f"{(100 * float(explanation.speedup)):.2f}%", "loop_count": explanation.winning_benchmarking_test_results.number_of_loops(), - "benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None, + "benchmark_details": explanation.benchmark_details or None, "optimized_runtime": humanize_runtime(explanation.best_runtime_ns), "original_runtime": humanize_runtime(explanation.original_runtime_ns), "codeflash_version": codeflash_version, diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4e19f53be..beee82e46 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1436,7 +1436,7 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]: module_root = alias.name.split(".")[0] if module_root in NUMERICAL_MODULES: # Use the alias if present, otherwise the module name - name = alias.asname if alias.asname else alias.name.split(".")[0] + name = alias.asname or alias.name.split(".")[0] numerical_names.add(name) modules_used.add(module_root) elif isinstance(node, ast.ImportFrom) and node.module: @@ -1448,7 +1448,7 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]: # Can't track star imports, but mark the module as numerical numerical_names.add(module_root) else: - name = alias.asname if alias.asname else alias.name + name = alias.asname or alias.name numerical_names.add(name) modules_used.add(module_root) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index e543d184d..049602436 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -686,7 +686,7 @@ def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyze if imp.default_import: existing_names.add(imp.default_import) for name, alias in imp.named_imports: - existing_names.add(alias if alias else name) + existing_names.add(alias or name) if imp.namespace_import: existing_names.add(imp.namespace_import) diff --git a/codeflash/code_utils/codeflash_wrap_decorator.py b/codeflash/code_utils/codeflash_wrap_decorator.py index a6b6d339f..a33ed1ebf 100644 --- a/codeflash/code_utils/codeflash_wrap_decorator.py +++ b/codeflash/code_utils/codeflash_wrap_decorator.py @@ -37,7 +37,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]: test_function = os.environ["CODEFLASH_TEST_FUNCTION"] if test_module and test_function: - return (test_module, test_class if test_class else None, test_function) + return (test_module, test_class or None, test_function) raise RuntimeError( "Test context environment variables not set - ensure tests are run through codeflash test runner" diff --git a/codeflash/code_utils/config_js.py b/codeflash/code_utils/config_js.py index b2e827f26..9039f13e2 100644 --- a/codeflash/code_utils/config_js.py +++ b/codeflash/code_utils/config_js.py @@ -292,7 +292,7 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any], config["formatter_cmds"] = codeflash_config["formatterCmds"] else: detected_formatter = detect_formatter(project_root, package_data) - config["formatter_cmds"] = detected_formatter if detected_formatter else [] + config["formatter_cmds"] = detected_formatter or [] # Parse optional config values from codeflash section if codeflash_config.get("benchmarksRoot"): diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index ee8b7dbc3..a67c3acb2 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -74,7 +74,7 @@ def get_current_branch(repo: Repo | None = None) -> str: :return: The name of the current branch, or "main" if HEAD is detached or the branch cannot be determined. """ - repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + repository: Repo = repo or git.Repo(search_parent_directories=True) # Check if HEAD is detached (active_branch will be None) if repository.head.is_detached: @@ -106,12 +106,12 @@ def get_current_branch(repo: Repo | None = None) -> str: def get_remote_url(repo: Repo | None = None, git_remote: str | None = "origin") -> str: - repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + repository: Repo = repo or git.Repo(search_parent_directories=True) return repository.remote(name=git_remote).url def get_git_remotes(repo: Repo) -> list[str]: - repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + repository: Repo = repo or git.Repo(search_parent_directories=True) return [remote.name for remote in repository.remotes] @@ -128,7 +128,7 @@ def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = " def git_root_dir(repo: Repo | None = None) -> Path: - repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + repository: Repo = repo or git.Repo(search_parent_directories=True) return Path(repository.working_dir) @@ -199,7 +199,7 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None: if "PR_NUMBER" not in os.environ: return None try: - repository: Repo = repo if repo else git.Repo(search_parent_directories=True) + repository: Repo = repo or git.Repo(search_parent_directories=True) last_commit = repository.head.commit except Exception: logger.exception("Failed to get last commit author.") diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index f3e929688..15949957e 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -686,11 +686,11 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]: module_name = alias.name.split(".")[0] if module_name == "torch": # Use asname if available, otherwise use the module name - frameworks["torch"] = alias.asname if alias.asname else module_name + frameworks["torch"] = alias.asname or module_name elif module_name == "tensorflow": - frameworks["tensorflow"] = alias.asname if alias.asname else module_name + frameworks["tensorflow"] = alias.asname or module_name elif module_name == "jax": - frameworks["jax"] = alias.asname if alias.asname else module_name + frameworks["jax"] = alias.asname or module_name elif isinstance(node, ast.ImportFrom) and node.module: module_name = node.module.split(".")[0] if module_name == "torch" and "torch" not in frameworks: diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 93997b2c6..68c639ea2 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -41,7 +41,7 @@ def visit_Import(self, node: ast.Import) -> None: """Track regular imports like 'import numba' or 'import numba as nb'.""" for alias in node.names: # alias.name is the module name, alias.asname is the alias (or None) - local_name = alias.asname if alias.asname else alias.name + local_name = alias.asname or alias.name # For module imports, we store (module_name, None) to indicate it's a module import self.import_aliases[local_name] = (alias.name, None) self.generic_visit(node) @@ -53,7 +53,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: return for alias in node.names: - local_name = alias.asname if alias.asname else alias.name + local_name = alias.asname or alias.name # For from imports, we store (module_name, imported_name) self.import_aliases[local_name] = (node.module, alias.name) self.generic_visit(node) diff --git a/codeflash/code_utils/normalizers/python.py b/codeflash/code_utils/normalizers/python.py index c5c7986cb..59fdb32ea 100644 --- a/codeflash/code_utils/normalizers/python.py +++ b/codeflash/code_utils/normalizers/python.py @@ -56,14 +56,14 @@ def get_normalized_name(self, name: str) -> str: def visit_Import(self, node: ast.Import) -> ast.Import: """Track imported names.""" for alias in node.names: - name = alias.asname if alias.asname else alias.name + name = alias.asname or alias.name self.imports.add(name.split(".")[0]) return node def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom: """Track imported names from modules.""" for alias in node.names: - name = alias.asname if alias.asname else alias.name + name = alias.asname or alias.name self.imports.add(name) return node diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 18db28856..92aa43b44 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -607,7 +607,7 @@ class definitions for any classes imported from project modules. This helps if isinstance(node, ast.ImportFrom) and node.module: for alias in node.names: if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name imported_names[imported_name] = node.module if not imported_names: @@ -751,7 +751,7 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo if isinstance(node, ast.ImportFrom) and node.module: for alias in node.names: if alias.name != "*": - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name imported_names[imported_name] = node.module elif isinstance(node, ast.ClassDef): for base in node.bases: @@ -869,14 +869,14 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, for node in module_tree.body: if isinstance(node, ast.Import): for alias in node.names: - name = alias.asname if alias.asname else alias.name.split(".")[0] + name = alias.asname or alias.name.split(".")[0] if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) added_imports.add(node.lineno) break elif isinstance(node, ast.ImportFrom): for alias in node.names: - name = alias.asname if alias.asname else alias.name + name = alias.asname or alias.name if name in needed_names and node.lineno not in added_imports: import_lines.append(source_lines[node.lineno - 1]) added_imports.add(node.lineno) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index f4eec94e8..37fc0e757 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -646,7 +646,7 @@ def _analyze_imports_in_optimized_code( file_entry = helpers_by_file_and_func.get(module_name) if file_entry: for alias in node.names: - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name original_name = alias.name helpers = file_entry.get(original_name) if helpers: @@ -658,7 +658,7 @@ def _analyze_imports_in_optimized_code( elif isinstance(node, ast.Import): # Handle "import module" statements for alias in node.names: - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name module_name = alias.name helpers = helpers_by_file.get(module_name) if helpers: diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index cd0a82605..96bafc504 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -244,7 +244,7 @@ def visit_Import(self, node: ast.Import) -> None: return for alias in node.names: - module_name = alias.asname if alias.asname else alias.name + module_name = alias.asname or alias.name self.imported_modules.add(module_name) # Check for dynamic import modules @@ -305,7 +305,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: self.wildcard_modules.add(mod) continue - imported_name = alias.asname if alias.asname else aname + imported_name = alias.asname or aname self.imported_modules.add(imported_name) if alias.asname: @@ -656,7 +656,7 @@ def discover_unit_tests( # Existing Python logic framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} - strategy = framework_strategies.get(cfg.test_framework, None) + strategy = framework_strategies.get(cfg.test_framework) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index fe0ff095e..e8e742432 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -41,7 +41,7 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B "speedup_pct": self.speedup_pct, "loop_count": self.winning_benchmarking_test_results.number_of_loops(), "report_table": report_table, - "benchmark_details": self.benchmark_details if self.benchmark_details else None, + "benchmark_details": self.benchmark_details or None, } if self.original_async_throughput is not None and self.best_async_throughput is not None: diff --git a/codeflash/languages/javascript/import_resolver.py b/codeflash/languages/javascript/import_resolver.py index 4e237b8d6..ec9c6c839 100644 --- a/codeflash/languages/javascript/import_resolver.py +++ b/codeflash/languages/javascript/import_resolver.py @@ -92,7 +92,7 @@ def _build_resolved_import(self, import_info: ImportInfo, resolved_path: Path) - # Collect named imports for name, alias in import_info.named_imports: - imported_names.append(alias if alias else name) + imported_names.append(alias or name) # Add default import if present if import_info.default_import: diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index eecf11064..33c726ba9 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -675,7 +675,7 @@ def _find_referenced_globals( if imp.namespace_import: imported_names.add(imp.namespace_import) for name, alias in imp.named_imports: - imported_names.add(alias if alias else name) + imported_names.add(alias or name) # Build a map of declaration name -> declaration info decl_map: dict[str, Any] = {} @@ -903,7 +903,7 @@ def _find_imported_type_definitions( # Check if any of our type names are imported from this module for name, alias in imp.named_imports: # The type could be imported with an alias - local_name = alias if alias else name + local_name = alias or name if local_name in type_names: type_import_map[local_name] = (imp, name) # (ImportInfo, original_name) diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index c65adfa7b..c58b3c1ab 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -542,7 +542,7 @@ def run_jest_behavioral_tests( project_root = _find_node_project_root(first_test_file) # Use the project root, or fall back to provided cwd - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Jest working directory: {effective_cwd}") # Ensure the codeflash npm package is installed @@ -780,7 +780,7 @@ def run_jest_benchmarking_tests( first_test_file = Path(test_files[0]) project_root = _find_node_project_root(first_test_file) - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Jest benchmarking working directory: {effective_cwd}") # Ensure the codeflash npm package is installed @@ -927,7 +927,7 @@ def run_jest_line_profile_tests( first_test_file = Path(test_files[0]) project_root = _find_node_project_root(first_test_file) - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Jest line profiling working directory: {effective_cwd}") # Ensure the codeflash npm package is installed diff --git a/codeflash/languages/javascript/vitest_runner.py b/codeflash/languages/javascript/vitest_runner.py index 47a529dae..b16d43609 100644 --- a/codeflash/languages/javascript/vitest_runner.py +++ b/codeflash/languages/javascript/vitest_runner.py @@ -202,7 +202,7 @@ def run_vitest_behavioral_tests( project_root = _find_vitest_project_root(test_files[0]) # Use the project root, or fall back to provided cwd - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Vitest working directory: {effective_cwd}") # Ensure the codeflash npm package is installed @@ -317,7 +317,7 @@ def run_vitest_benchmarking_tests( if project_root is None and test_files: project_root = _find_vitest_project_root(test_files[0]) - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Vitest benchmarking working directory: {effective_cwd}") # Ensure the codeflash npm package is installed @@ -420,7 +420,7 @@ def run_vitest_line_profile_tests( if project_root is None and test_files: project_root = _find_vitest_project_root(test_files[0]) - effective_cwd = project_root if project_root else cwd + effective_cwd = project_root or cwd logger.debug(f"Vitest line profiling working directory: {effective_cwd}") # Ensure the codeflash npm package is installed diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 5a5b0c5b5..a48c50552 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -784,7 +784,7 @@ def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId test_class_name=test_class_name, test_function_name=test_function_name, function_getting_tested=components[2], - iteration_id=iteration_id if iteration_id else components[3], + iteration_id=iteration_id or components[3], ) diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index f0aff73d0..1afff9d58 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -148,7 +148,7 @@ def __str__(self) -> str: f"Optimized {self.function_name} in {self.file_path}\n" f"{self.perf_improvement_line}\n" + performance_description - + (benchmark_info if benchmark_info else "") + + (benchmark_info or "") + self.raw_explanation_message + " \n\n" + ( diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index 1c49f5515..fe7d13a99 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -94,7 +94,7 @@ def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "") if not test_class_name: env_class = os.environ.get("CODEFLASH_TEST_CLASS") - test_class_name = env_class if env_class else None + test_class_name = env_class or None return test_module_name, test_class_name, test_name, line_id diff --git a/codeflash/version.py b/codeflash/version.py index 6225467e3..3f984fa54 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0" +__version__ = "0.20.0.post402.dev0+dce74b16" From ffdbd4da89d229a95e10e80e2184271f169dff7b Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 01:22:35 +0000 Subject: [PATCH 4/4] Optimize ReferenceFinder._find_references_in_file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This optimization achieves a **313% speedup** (from 5.05ms to 1.22ms) by eliminating redundant string decoding operations during AST traversal. The key improvements are: **What was optimized:** 1. **Node text caching**: Added `_node_text_cache` and `_node_bytes_cache` dictionaries to store decoded text and byte slices for each tree-sitter node, keyed by node ID 2. **Lazy decoding**: Introduced `_get_node_text()` and `_get_node_bytes()` helper methods that cache results on first access 3. **Byte-level comparisons**: Changed identifier matching from string equality (`name == search_name`) to byte equality (`node_bytes == search_bytes`), avoiding UTF-8 decoding unless necessary 4. **Pre-encoded search term**: The `search_name` is encoded once per file as `search_bytes` rather than repeatedly during comparisons **Why this is faster:** The original code repeatedly sliced and decoded the same AST node text during recursive traversal. Line profiler shows `_find_identifier_references` spent 52.1% of time in `child_by_field_name("function")` and 13.9% checking node types, with additional time decoding node text multiple times. The optimization eliminates this redundancy—each node's text is decoded at most once and cached. Byte comparisons are faster than string comparisons in Python and skip decoding entirely when names don't match. **Impact:** - The line profiler shows `_find_references_in_file` total time dropped from 21.5ms to 6.6ms (69% reduction) - The recursive `_find_identifier_references` becomes dramatically faster by avoiding repeated decode operations on the same nodes - Memory overhead is minimal—caches are cleared per file and only store node IDs and their decoded text - This optimization particularly benefits files with many function calls or deep AST nesting where the same parent/child nodes are accessed repeatedly The caching strategy is safe because tree-sitter nodes are immutable within a parse tree, and the caches are explicitly cleared between files to prevent memory leaks or cross-file contamination. --- .../languages/javascript/find_references.py | 64 ++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index 43bde84a5..a09a563d3 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -98,6 +98,10 @@ def __init__(self, project_root: Path, exclude_patterns: list[str] | None = None self.exclude_patterns = exclude_patterns or ["node_modules", "dist", "build", ".git", "coverage", "__pycache__"] self._file_cache: dict[Path, str] = {} + # Per-parse caches to avoid repeated slice+decode work. Cleared per file parse. + self._node_text_cache: dict[int, str] = {} + self._node_bytes_cache: dict[int, bytes] = {} + def find_references( self, function_to_optimize: FunctionToOptimize, include_definition: bool = False, max_files: int = 1000 ) -> list[Reference]: @@ -356,12 +360,19 @@ def _find_references_in_file( """ references: list[Reference] = [] source_bytes = source_code.encode("utf8") + # Clear per-parse caches to avoid cross-file contamination + self._node_text_cache.clear() + self._node_bytes_cache.clear() + tree = analyzer.parse(source_bytes) lines = source_code.splitlines() # The name to search for (either imported name or original) search_name = import_name or function_name + # Handle namespace imports (e.g., "utils.helper") + search_bytes = search_name.encode("utf8") + # Handle namespace imports (e.g., "utils.helper") if "." in search_name: namespace, member = search_name.split(".", 1) @@ -369,7 +380,7 @@ def _find_references_in_file( else: # Find direct calls and other reference types self._find_identifier_references( - tree.root_node, source_bytes, lines, file_path, search_name, function_name, references, None + tree.root_node, source_bytes, lines, file_path, search_name, search_bytes, function_name, references, None ) return references @@ -403,27 +414,29 @@ def _find_identifier_references( if node.type in ("function_declaration", "method_definition"): name_node = node.child_by_field_name("name") if name_node: - new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") + new_current_function = self._get_node_text(name_node, source_bytes) elif node.type == "variable_declarator": # Arrow function or function expression assigned to variable name_node = node.child_by_field_name("name") value_node = node.child_by_field_name("value") if name_node and value_node and value_node.type in ("arrow_function", "function_expression"): - new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") + new_current_function = self._get_node_text(name_node, source_bytes) + + # Check for call expressions # Check for call expressions if node.type == "call_expression": func_node = node.child_by_field_name("function") if func_node and func_node.type == "identifier": - name = source_bytes[func_node.start_byte : func_node.end_byte].decode("utf8") - if name == search_name: + # Compare bytes to avoid decode unless necessary + if self._get_node_bytes(func_node, source_bytes) == search_bytes: ref = self._create_reference(file_path, func_node, lines, "call", search_name, current_function) references.append(ref) # Check for identifiers used as callbacks or passed as arguments elif node.type == "identifier": - name = source_bytes[node.start_byte : node.end_byte].decode("utf8") - if name == search_name: + # Compare bytes to avoid decode unless matching + if self._get_node_bytes(node, source_bytes) == search_bytes: parent = node.parent # Determine reference type based on context ref_type = self._determine_reference_type(node, parent, source_bytes) @@ -434,7 +447,7 @@ def _find_identifier_references( # Recurse into children for child in node.children: self._find_identifier_references( - child, source_bytes, lines, file_path, search_name, original_name, references, new_current_function + child, source_bytes, lines, file_path, search_name, search_bytes, original_name, references, new_current_function ) def _find_member_calls( @@ -466,7 +479,9 @@ def _find_member_calls( if node.type in ("function_declaration", "method_definition"): name_node = node.child_by_field_name("name") if name_node: - new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") + new_current_function = self._get_node_text(name_node, source_bytes) + + # Check for call expressions with member access # Check for call expressions with member access if node.type == "call_expression": @@ -476,10 +491,11 @@ def _find_member_calls( prop_node = func_node.child_by_field_name("property") if obj_node and prop_node: - obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8") - prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8") + # Use cached bytes comparisons where possible + obj_name_bytes = self._get_node_bytes(obj_node, source_bytes) + prop_name_bytes = self._get_node_bytes(prop_node, source_bytes) - if obj_name == namespace and prop_name == member: + if obj_name_bytes == namespace.encode("utf8") and prop_name_bytes == member.encode("utf8"): ref = self._create_reference( file_path, func_node, lines, "call", f"{namespace}.{member}", current_function ) @@ -491,10 +507,10 @@ def _find_member_calls( prop_node = node.child_by_field_name("property") if obj_node and prop_node: - obj_name = source_bytes[obj_node.start_byte : obj_node.end_byte].decode("utf8") - prop_name = source_bytes[prop_node.start_byte : prop_node.end_byte].decode("utf8") + obj_name_bytes = self._get_node_bytes(obj_node, source_bytes) + prop_name_bytes = self._get_node_bytes(prop_node, source_bytes) - if obj_name == namespace and prop_name == member: + if obj_name_bytes == namespace.encode("utf8") and prop_name_bytes == member.encode("utf8"): parent = node.parent if parent and parent.type != "call_expression": ref_type = self._determine_reference_type(node, parent, source_bytes) @@ -804,6 +820,24 @@ def _read_file(self, file_path: Path) -> str | None: logger.debug("Could not read file %s: %s", file_path, e) return None + def _get_node_bytes(self, node: Node, source_bytes: bytes) -> bytes: + """Return the raw bytes for the node, caching to avoid repeated slicing.""" + nid = node.id + b = self._node_bytes_cache.get(nid) + if b is None: + b = source_bytes[node.start_byte : node.end_byte] + self._node_bytes_cache[nid] = b + return b + + def _get_node_text(self, node: Node, source_bytes: bytes) -> str: + """Return the decoded text for the node, caching to avoid repeated decoding.""" + nid = node.id + s = self._node_text_cache.get(nid) + if s is None: + s = source_bytes[node.start_byte : node.end_byte].decode("utf8") + self._node_text_cache[nid] = s + return s + def find_references( function_to_optimize: FunctionToOptimize, project_root: Path | None = None, max_files: int = 1000