Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 199 additions & 5 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -18,6 +19,201 @@

from codeflash.models.models import CodePosition

# Source text of the behavior-mode async decorator, selected through
# _INLINE_CODE_MAP for TestingMode.BEHAVIOR.
#
# The embedded code defines ``codeflash_behavior_async``. Its wrapper prints
# start/end stdout markers around the awaited call, times the await with the
# running loop's clock while garbage collection is disabled, pickles either the
# raised exception or ``(args, kwargs, return_value)`` with dill, stores one row
# in the ``test_results`` SQLite table, then re-raises or returns.
#
# The text must stay valid, correctly indented Python: it is program source,
# not documentation.
_BEHAVIOR_ASYNC_INLINE_CODE = """import asyncio
import gc
import os
import sqlite3
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory

import dill as pickle


def get_run_tmp_file(file_path):
    if not hasattr(get_run_tmp_file, "tmpdir"):
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
    return Path(get_run_tmp_file.tmpdir.name) / file_path


def extract_test_context_from_env():
    test_module = os.environ["CODEFLASH_TEST_MODULE"]
    test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
    if test_module and test_function:
        return (test_module, test_class if test_class else None, test_function)
    raise RuntimeError(
        "Test context environment variables not set - ensure tests are run through codeflash test runner"
    )


def codeflash_behavior_async(func):
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        class_prefix = (test_class_name + ".") if test_class_name else ""
        test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
        codeflash_con = sqlite3.connect(db_path)
        codeflash_cur = codeflash_con.cursor()
        codeflash_cur.execute(
            "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
            "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
            "runtime INTEGER, return_value BLOB, verification_type TEXT)"
        )
        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()
            return_value = await ret
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        print(f"!######{test_stdout_tag}######!")
        pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
        codeflash_cur.execute(
            "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                codeflash_duration,
                pickled_return_value,
                "function_call",
            ),
        )
        codeflash_con.commit()
        codeflash_con.close()
        if exception:
            raise exception
        return return_value
    return async_wrapper
"""

# Source text of the performance-mode async decorator, selected through
# _INLINE_CODE_MAP for TestingMode.PERFORMANCE and used by
# get_async_inline_code as the fallback for any other mode.
#
# The embedded code defines ``codeflash_performance_async``. Its wrapper prints
# a start marker, times the awaited call with the running loop's clock while
# garbage collection is disabled, prints an end marker that carries the
# duration in nanoseconds, then re-raises or returns.
#
# The text must stay valid, correctly indented Python: it is program source,
# not documentation.
_PERFORMANCE_ASYNC_INLINE_CODE = """import asyncio
import gc
import os
from functools import wraps


def extract_test_context_from_env():
    test_module = os.environ["CODEFLASH_TEST_MODULE"]
    test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
    if test_module and test_function:
        return (test_module, test_class if test_class else None, test_function)
    raise RuntimeError(
        "Test context environment variables not set - ensure tests are run through codeflash test runner"
    )


def codeflash_performance_async(func):
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        class_prefix = (test_class_name + ".") if test_class_name else ""
        test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()
            return_value = await ret
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
        if exception:
            raise exception
        return return_value
    return async_wrapper
"""

# Source text of the concurrency-mode async decorator, selected through
# _INLINE_CODE_MAP for TestingMode.CONCURRENCY.
#
# The embedded code defines ``codeflash_concurrency_async``. Its wrapper awaits
# the function CODEFLASH_CONCURRENCY_FACTOR times in sequence (default 10),
# then awaits the same number of calls together via ``asyncio.gather``. Both
# phases are timed with ``time.perf_counter_ns`` while garbage collection is
# disabled. It prints one ``CONC`` marker carrying both timings and returns the
# result of the last sequential call.
#
# The text must stay valid, correctly indented Python: it is program source,
# not documentation.
_CONCURRENCY_ASYNC_INLINE_CODE = """import asyncio
import gc
import os
import time
from functools import wraps


def codeflash_concurrency_async(func):
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        function_name = func.__name__
        concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
        gc.disable()
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()
        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
        print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
        return result
    return async_wrapper
"""

# Inline async-decorator source keyed by testing mode; get_async_inline_code
# looks modes up here and falls back to the performance source for any mode
# that has no entry.
_INLINE_CODE_MAP = {
    TestingMode.BEHAVIOR: _BEHAVIOR_ASYNC_INLINE_CODE,
    TestingMode.PERFORMANCE: _PERFORMANCE_ASYNC_INLINE_CODE,
    TestingMode.CONCURRENCY: _CONCURRENCY_ASYNC_INLINE_CODE,
}


@dataclass(frozen=True)
class FunctionCallNodeArguments:
Expand Down Expand Up @@ -1692,12 +1888,10 @@ async def async_wrapper(*args, **kwargs):
"""


@cache
def get_async_inline_code(mode: TestingMode) -> str:
    """Return the inline async-decorator source for ``mode``.

    The source text is looked up in ``_INLINE_CODE_MAP``. Any mode without an
    entry there falls back to the performance-mode source, which matches the
    routing of the previous if-chain (behavior, concurrency, otherwise
    performance).

    Args:
        mode: Testing mode whose decorator source is requested.

    Returns:
        The Python source text of the decorator for that mode.
    """
    # A single table lookup replaces the legacy getter calls, which made the
    # lookup below them unreachable.
    return _INLINE_CODE_MAP.get(mode, _PERFORMANCE_ASYNC_INLINE_CODE)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: Dead code / string duplication

The get_async_inline_code function now uses _INLINE_CODE_MAP (module-level constants), but the three original functions (get_behavior_async_inline_code, get_performance_async_inline_code, get_concurrency_async_inline_code) still exist and return identical strings. This means each inline code string is stored twice in memory.

Consider either:

  • Removing the old functions and keeping only the module-level constants, or
  • Having the old functions reference the constants (e.g., return _BEHAVIOR_ASYNC_INLINE_CODE)

Not a bug — just unnecessary duplication.



class AsyncInlineCodeInjector(cst.CSTTransformer):
Expand Down
Loading