diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 9486fc677..e8cf3fe8a 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -18,6 +18,11 @@ from codeflash.models.models import CodePosition +_MODE_TO_DECORATOR = { + TestingMode.BEHAVIOR: "codeflash_behavior_async", + TestingMode.CONCURRENCY: "codeflash_concurrency_async", +} + @dataclass(frozen=True) class FunctionCallNodeArguments: @@ -1667,11 +1672,7 @@ async def async_wrapper(*args, **kwargs): def get_decorator_name_for_mode(mode: TestingMode) -> str: - if mode == TestingMode.BEHAVIOR: - return "codeflash_behavior_async" - if mode == TestingMode.CONCURRENCY: - return "codeflash_concurrency_async" - return "codeflash_performance_async" + return _MODE_TO_DECORATOR.get(mode, "codeflash_performance_async") def write_async_helper_file(target_dir: Path) -> Path: