Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 195 additions & 57 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,73 +1497,207 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca
return False


class AsyncDecoratorImportAdder(cst.CSTTransformer):
"""Transformer that adds the import for async decorators."""
ASYNC_HELPER_INLINE_CODE = """import asyncio
import gc
import os
import sqlite3
import time
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory

import dill as pickle

def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
    """Store the testing mode and start with no matching import recorded."""
    # Set to True by visit_ImportFrom once the decorator import is seen.
    self.has_import = False
    self.mode = mode

def _get_decorator_name(self) -> str:
    """Return the async decorator name that corresponds to ``self.mode``."""
    mode = self.mode
    if mode == TestingMode.BEHAVIOR:
        name = "codeflash_behavior_async"
    elif mode == TestingMode.CONCURRENCY:
        name = "codeflash_concurrency_async"
    else:
        # Every other mode falls back to the performance decorator.
        name = "codeflash_performance_async"
    return name

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
    """Record whether the mode's decorator is already imported.

    Only ``from codeflash.code_utils.codeflash_wrap_decorator import ...``
    statements are inspected; star imports are ignored.
    """
    module = node.module
    if not isinstance(module, cst.Attribute):
        return
    package = module.value
    if not isinstance(package, cst.Attribute):
        return
    if not isinstance(package.value, cst.Name):
        return
    # Compare the dotted path one segment at a time, left to right.
    if package.value.value != "codeflash":
        return
    if package.attr.value != "code_utils":
        return
    if module.attr.value != "codeflash_wrap_decorator":
        return
    if isinstance(node.names, cst.ImportStar):
        return

    wanted = self._get_decorator_name()
    for alias in node.names:
        if alias.name.value == wanted:
            self.has_import = True

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
    """Prepend the decorator import unless ``visit_ImportFrom`` already found it."""
    if not self.has_import:
        # Build the import for the decorator selected by the current mode.
        import_stmt = cst.parse_statement(
            f"from codeflash.code_utils.codeflash_wrap_decorator import {self._get_decorator_name()}"
        )
        # The new import goes first; the existing statements follow unchanged.
        updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
    return updated_node

def get_run_tmp_file(file_path):
    # Return file_path placed inside a temporary directory that is created on
    # the first call and cached on the function object for later calls.
    try:
        run_dir = get_run_tmp_file.tmpdir
    except AttributeError:
        run_dir = get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
    return Path(run_dir.name).joinpath(file_path)


def extract_test_context_from_env():
    # Read the current test's identity from the environment.
    #
    # Returns a (module, class_or_None, function) tuple. The class entry is
    # None when CODEFLASH_TEST_CLASS is unset or empty.
    #
    # Raises RuntimeError when CODEFLASH_TEST_MODULE or CODEFLASH_TEST_FUNCTION
    # is missing or empty. Lookups use .get() so that a missing variable
    # produces the descriptive RuntimeError below instead of a bare KeyError.
    test_module = os.environ.get("CODEFLASH_TEST_MODULE")
    test_class = os.environ.get("CODEFLASH_TEST_CLASS")
    test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")
    if not (test_module and test_function):
        raise RuntimeError(
            "Test context environment variables not set - ensure tests are run through codeflash test runner"
        )
    return (test_module, test_class or None, test_function)


def codeflash_behavior_async(func):
    # Decorator for behavior runs of an async function.
    #
    # Each call prints a start and an end marker on stdout, times the awaited
    # call, and stores the pickled outcome (the raised exception, or the tuple
    # (args, kwargs, return_value)) as one row of the test_results table in a
    # per-iteration sqlite file. The original exception is re-raised and the
    # original return value is returned.
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        # Per-wrapper counter: the first call for a given test_id gets index 0,
        # each later call gets the next integer.
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        async_wrapper.index[test_id] = async_wrapper.index.get(test_id, -1) + 1
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        class_prefix = (test_class_name + ".") if test_class_name else ""
        test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
        codeflash_con = sqlite3.connect(db_path)
        # The try/finally guarantees the connection is closed even when the
        # awaited call is cancelled (asyncio.CancelledError is not an Exception
        # subclass) or when pickling / inserting the result fails.
        try:
            codeflash_cur = codeflash_con.cursor()
            codeflash_cur.execute(
                "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
                "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
                "runtime INTEGER, return_value BLOB, verification_type TEXT)"
            )
            exception = None
            return_value = None
            counter = loop.time()
            gc.disable()
            try:
                ret = func(*args, **kwargs)
                # Restart the clock once the awaitable exists so that only the
                # await itself is timed on the success path.
                counter = loop.time()
                return_value = await ret
                codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            except Exception as e:
                codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
                exception = e
            finally:
                gc.enable()
            print(f"!######{test_stdout_tag}######!")
            # Compare against None explicitly: an exception object may be falsy.
            if exception is not None:
                pickled_return_value = pickle.dumps(exception)
            else:
                pickled_return_value = pickle.dumps((args, kwargs, return_value))
            codeflash_cur.execute(
                "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    test_module_name,
                    test_class_name,
                    test_name,
                    function_name,
                    loop_index,
                    invocation_id,
                    codeflash_duration,
                    pickled_return_value,
                    "function_call",
                ),
            )
            codeflash_con.commit()
        finally:
            codeflash_con.close()
        if exception is not None:
            raise exception
        return return_value
    return async_wrapper


def codeflash_performance_async(func):
    # Decorator for performance runs of an async function.
    #
    # Each call prints a start marker, times the awaited call, then prints an
    # end marker that carries the measured duration. The original exception is
    # re-raised and the original return value is returned.
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        # Per-wrapper counter: the first call for a given test_id gets index 0,
        # each later call gets the next integer.
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        async_wrapper.index[test_id] = async_wrapper.index.get(test_id, -1) + 1
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        class_prefix = (test_class_name + ".") if test_class_name else ""
        test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        exception = None
        return_value = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            # Restart the clock once the awaitable exists so that only the
            # await itself is timed on the success path.
            counter = loop.time()
            return_value = await ret
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
        # Compare against None explicitly: an exception object may be falsy,
        # and a truthiness check would then skip the re-raise.
        if exception is not None:
            raise exception
        return return_value
    return async_wrapper


def codeflash_concurrency_async(func):
    # Decorator for concurrency benchmarks of an async function.
    #
    # Each call awaits func concurrency_factor times one after another, then
    # awaits the same number of calls together through asyncio.gather. Both
    # wall-clock totals (time.perf_counter_ns) are printed in a CONC marker.
    # The value returned is the result of the last sequential call.
    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        function_name = func.__name__
        # Clamp to at least one round: with a factor of 0 or less the function
        # would never run and result below would be unbound.
        concurrency_factor = max(1, int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")))
        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
        # Sequential phase: one call at a time.
        gc.disable()
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()
        # Concurrent phase: all calls are created first, then awaited together.
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()
        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
        print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
        return result
    return async_wrapper
"""

ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"


def get_decorator_name_for_mode(mode: TestingMode) -> str:
    """Return the name of the async decorator to apply for ``mode``.

    Behavior and concurrency modes have dedicated decorators; every other
    mode uses the performance decorator.
    """
    known_modes = (
        (TestingMode.BEHAVIOR, "codeflash_behavior_async"),
        (TestingMode.CONCURRENCY, "codeflash_concurrency_async"),
    )
    for candidate, decorator_name in known_modes:
        if mode == candidate:
            return decorator_name
    return "codeflash_performance_async"


def write_async_helper_file(target_dir: Path) -> Path:
    """Write the async decorator helper file to the target directory.

    The file is (re)written whenever it is missing or its content differs from
    ``ASYNC_HELPER_INLINE_CODE``, so a stale helper left behind by an earlier
    run is refreshed instead of being silently reused. An up-to-date file is
    left untouched.

    Args:
        target_dir: Directory that should contain the helper module.

    Returns:
        Path of the helper file inside ``target_dir``.
    """
    helper_path = target_dir / ASYNC_HELPER_FILENAME
    try:
        up_to_date = helper_path.read_text("utf-8") == ASYNC_HELPER_INLINE_CODE
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable file: fall through and write it.
        up_to_date = False
    if not up_to_date:
        helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
    return helper_path


def add_async_decorator_to_function(
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
source_path: Path,
function: FunctionToOptimize,
mode: TestingMode = TestingMode.BEHAVIOR,
project_root: Path | None = None,
) -> bool:
"""Add async decorator to an async function definition and write back to file.

Args:
----
source_path: Path to the source file to modify in-place.
function: The FunctionToOptimize object representing the target async function.
mode: The testing mode to determine which decorator to apply.

Returns:
-------
Boolean indicating whether the decorator was successfully added.
Writes a helper file containing the decorator implementation to project_root (or source directory
as fallback) and adds a standard import + decorator to the source file.

"""
if not function.is_async:
return False

try:
# Read source code
with source_path.open(encoding="utf8") as f:
source_code = f.read()

Expand All @@ -1573,10 +1707,14 @@ def add_async_decorator_to_function(
decorator_transformer = AsyncDecoratorAdder(function, mode)
module = module.visit(decorator_transformer)

# Add the import if decorator was added
if decorator_transformer.added_decorator:
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)
# Write the helper file to project_root (on sys.path) or source dir as fallback
helper_dir = project_root if project_root is not None else source_path.parent
write_async_helper_file(helper_dir)
# Add the import via CST so sort_imports can place it correctly
decorator_name = get_decorator_name_for_mode(mode)
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
module = module.with_changes(body=[import_node, *list(module.body)])

modified_code = sort_imports(code=module.code, float_to_top=True)
except Exception as e:
Expand Down
34 changes: 29 additions & 5 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,6 +1897,7 @@ def setup_and_establish_baseline(
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure(baseline_result.failure())

original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
Expand All @@ -1908,6 +1909,7 @@ def setup_and_establish_baseline(
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure("The threshold for test confidence was not met.")

return Success(
Expand Down Expand Up @@ -2279,6 +2281,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
self.cleanup_async_helper_file()

def cleanup_async_helper_file(self) -> None:
    """Delete the generated async helper file from ``self.project_root`` if it exists."""
    # Imported inside the method, as in the original.
    from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME

    # missing_ok=True makes this a no-op when the file is not there.
    (self.project_root / ASYNC_HELPER_FILENAME).unlink(missing_ok=True)

def establish_original_code_baseline(
self,
Expand All @@ -2296,7 +2305,10 @@ def establish_original_code_baseline(
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

success = add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)

# Instrument codeflash capture
Expand Down Expand Up @@ -2361,7 +2373,10 @@ def establish_original_code_baseline(
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)

try:
Expand Down Expand Up @@ -2535,7 +2550,10 @@ def run_optimized_candidate(
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)

try:
Expand Down Expand Up @@ -2611,7 +2629,10 @@ def run_optimized_candidate(
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)

try:
Expand Down Expand Up @@ -2974,7 +2995,10 @@ def run_concurrency_benchmark(
try:
# Add concurrency decorator to the source function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.CONCURRENCY,
project_root=self.project_root,
)

# Run the concurrency benchmark tests
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
CoverageExpectation(
function_name="retry_with_backoff",
expected_coverage=100.0,
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
)
],
)
Expand Down
Loading
Loading