diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index daee371d7..9e492fa46 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -130,9 +130,18 @@ def parse_args() -> Namespace: "--reset-config", action="store_true", help="Remove codeflash configuration from project config file." ) parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).") + parser.add_argument( + "--subagent", + action="store_true", + help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.", + ) args, unknown_args = parser.parse_known_args() sys.argv[:] = [sys.argv[0], *unknown_args] + if args.subagent: + args.yes = True + args.no_pr = True + args.worktree = True return process_and_validate_cmd_args(args) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index fdc5a420a..8c6a9af9d 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -21,7 +21,7 @@ from codeflash.cli_cmds.console_constants import SPINNER_TYPES from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT -from codeflash.lsp.helpers import is_LSP_enabled +from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode from codeflash.lsp.lsp_logger import enhanced_log from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage @@ -34,33 +34,60 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import DependencyResolver, IndexResult from codeflash.lsp.lsp_message import LspMessage + from codeflash.models.models import TestResults DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG console = Console() -if is_LSP_enabled(): +if is_LSP_enabled() or is_subagent_mode(): console.quiet = True -logging.basicConfig( - level=logging.INFO, - handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], - 
format=BARE_LOGGING_FORMAT, -) +if is_subagent_mode(): + import re + import sys + + _lsp_prefix_re = re.compile(r"^(?:!?lsp,?|h[2-4]|loading)\|") + _subagent_drop_patterns = ( + "Test log -", + "Test failed to load", + "Examining file ", + "Generated ", + "Add custom marker", + "Disabling all autouse", + "Reverting code and helpers", + ) + + class _AgentLogFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + record.msg = _lsp_prefix_re.sub("", str(record.msg)) + msg = record.getMessage() + return not any(msg.startswith(p) for p in _subagent_drop_patterns) + + _agent_handler = logging.StreamHandler(sys.stderr) + _agent_handler.addFilter(_AgentLogFilter()) + logging.basicConfig(level=logging.INFO, handlers=[_agent_handler], format="%(levelname)s: %(message)s") +else: + logging.basicConfig( + level=logging.INFO, + handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], + format=BARE_LOGGING_FORMAT, + ) logger = logging.getLogger("rich") logging.getLogger("parso").setLevel(logging.WARNING) # override the logger to reformat the messages for the lsp -for level in ("info", "debug", "warning", "error"): - real_fn = getattr(logger, level) - setattr( - logger, - level, - lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log( - msg, _real_fn, _level, *args, **kwargs - ), - ) +if not is_subagent_mode(): + for level in ("info", "debug", "warning", "error"): + real_fn = getattr(logger, level) + setattr( + logger, + level, + lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log( + msg, _real_fn, _level, *args, **kwargs + ), + ) class DummyTask: @@ -87,6 +114,8 @@ def paneled_text( text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None ) -> None: """Print text in a panel.""" + if is_subagent_mode(): + return from rich.panel import Panel from rich.text import Text @@ -115,6 +144,8 @@ def code_print( language: 
Programming language for syntax highlighting ('python', 'javascript', 'typescript') """ + if is_subagent_mode(): + return if is_LSP_enabled(): lsp_log( LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id) @@ -152,6 +183,10 @@ def progress_bar( """ global _progress_bar_active + if is_subagent_mode(): + yield DummyTask().id + return + if is_LSP_enabled(): lsp_log(LspTextMessage(text=message, takes_time=True)) yield @@ -183,6 +218,10 @@ def progress_bar( @contextmanager def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]: """Progress bar for test files.""" + if is_subagent_mode(): + yield DummyProgress(), DummyTask().id + return + if is_LSP_enabled(): lsp_log(LspTextMessage(text=description, takes_time=True)) dummy_progress = DummyProgress() @@ -216,6 +255,10 @@ def call_graph_live_display( from rich.text import Text from rich.tree import Tree + if is_subagent_mode(): + yield lambda _: None + return + if is_LSP_enabled(): lsp_log(LspTextMessage(text="Building call graph", takes_time=True)) yield lambda _: None @@ -323,6 +366,9 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path, if not total_functions: return + if is_subagent_mode(): + return + # Build the mapping expected by the dependency resolver file_items = file_to_funcs.items() mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items} @@ -349,3 +395,87 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path, return console.print(Panel(summary, title="Call Graph Summary", border_style="cyan")) + + +def subagent_log_optimization_result( + function_name: str, + file_path: Path, + perf_improvement_line: str, + original_runtime_ns: int, + best_runtime_ns: int, + raw_explanation: str, + original_code: dict[Path, str], + new_code: dict[Path, str], + review: str, + test_results: TestResults, +) -> None: + import sys 
+ from xml.sax.saxutils import escape + + from codeflash.code_utils.code_utils import unified_diff_strings + from codeflash.code_utils.time_utils import humanize_runtime + from codeflash.models.test_type import TestType + + diff_parts = [] + for path in original_code: + old = original_code.get(path, "") + new = new_code.get(path, "") + if old != new: + diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path)) + if diff: + diff_parts.append(diff) + + diff_str = "\n".join(diff_parts) + + original_runtime = humanize_runtime(original_runtime_ns) + optimized_runtime = humanize_runtime(best_runtime_ns) + + report = test_results.get_test_pass_fail_report_by_type() + verification_rows = [] + for test_type in TestType: + if test_type is TestType.INIT_STATE_TEST: + continue + name = test_type.to_name() + if not name: + continue + passed = report[test_type]["passed"] + failed = report[test_type]["failed"] + if passed == 0 and failed == 0: + status = "None Found" + elif failed > 0: + status = f"{failed} Failed, {passed} Passed" + else: + status = f"{passed} Passed" + verification_rows.append(f' ') + + xml = [ + "", + f" {escape(function_name)}", + f" {escape(str(file_path))}", + f" {escape(perf_improvement_line)}", + f" {escape(original_runtime)}", + f" {escape(optimized_runtime)}", + ] + if review: + xml.append(f" {escape(review)}") + xml.append(f" {escape(raw_explanation.strip())}") + xml.append(" ") + xml.extend(verification_rows) + xml.append(" ") + if diff_str: + xml.append(f" {escape(diff_str)}") + for path in new_code: + if new_code[path] != original_code.get(path, ""): + xml.append(f' {escape(new_code[path])}') + xml.append(" ") + xml.append(" Present this optimization to the user using AskUserQuestion.") + xml.append(" Show the function name, performance improvement, and explanation as the question text.") + xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).") + xml.append( + " Put the full 
diff in the 'Apply' option's markdown preview so the user can review the exact changes." + ) + xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.") + xml.append(" ") + xml.append("") + + sys.stdout.write("\n".join(xml) + "\n") diff --git a/codeflash/cli_cmds/logging_config.py b/codeflash/cli_cmds/logging_config.py index 09dc0f1f2..53a0b49fb 100644 --- a/codeflash/cli_cmds/logging_config.py +++ b/codeflash/cli_cmds/logging_config.py @@ -5,8 +5,18 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: import logging + import sys import time + from codeflash.lsp.helpers import is_subagent_mode + + if is_subagent_mode(): + logging.basicConfig( + level=level, handlers=[logging.StreamHandler(sys.stderr)], format="%(levelname)s: %(message)s", force=True + ) + logging.getLogger().setLevel(level) + return + from rich.logging import RichHandler from codeflash.cli_cmds.console import console diff --git a/codeflash/code_utils/checkpoint.py b/codeflash/code_utils/checkpoint.py index 1160bf2e0..367e150b7 100644 --- a/codeflash/code_utils/checkpoint.py +++ b/codeflash/code_utils/checkpoint.py @@ -141,12 +141,18 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]: previous_checkpoint_functions = None + if getattr(args, "subagent", False): + console.rule() + return None if args.all and codeflash_temp_dir.is_dir(): previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir) - if previous_checkpoint_functions and Confirm.ask( - "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?", - default=True, - console=console, + if previous_checkpoint_functions and ( + getattr(args, "yes", False) + or Confirm.ask( + "Previous Checkpoint detected from an incomplete optimization run, 
shall I continue the optimization from that point?", + default=True, + console=console, + ) ): console.rule() else: diff --git a/codeflash/lsp/helpers.py b/codeflash/lsp/helpers.py index b8840e046..14121ec68 100644 --- a/codeflash/lsp/helpers.py +++ b/codeflash/lsp/helpers.py @@ -18,6 +18,11 @@ def is_LSP_enabled() -> bool: return os.getenv("CODEFLASH_LSP", default="false").lower() == "true" +@lru_cache(maxsize=1) +def is_subagent_mode() -> bool: + return os.getenv("CODEFLASH_SUBAGENT_MODE", default="false").lower() == "true" + + def tree_to_markdown(tree: Tree, level: int = 0) -> str: """Convert a rich Tree into a Markdown bullet list.""" indent = " " * level diff --git a/codeflash/main.py b/codeflash/main.py index 690c1ae98..32ae9c66c 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -11,6 +11,12 @@ from pathlib import Path from typing import TYPE_CHECKING +if "--subagent" in sys.argv: + os.environ["CODEFLASH_SUBAGENT_MODE"] = "true" + import warnings + + warnings.filterwarnings("ignore") + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test from codeflash.cli_cmds.console import paneled_text diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index dd8e41dd8..efccd9b57 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -24,7 +24,14 @@ from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data -from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar +from codeflash.cli_cmds.console import ( + code_print, + console, + logger, + lsp_log, + progress_bar, + subagent_log_optimization_result, +) from 
codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import ( choose_weights, @@ -78,7 +85,7 @@ ) from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast -from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown +from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( @@ -1349,6 +1356,8 @@ def repair_optimization( def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: + if is_subagent_mode(): + return if is_LSP_enabled(): md_lines = [ "### ⚡️ Optimization Summary", @@ -1739,14 +1748,24 @@ def generate_tests( self.executor, testgen_context.markdown, helper_fqns, generated_test_paths, generated_perf_test_paths ) - future_concolic_tests = self.executor.submit( - generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast - ) + if is_subagent_mode(): + future_concolic_tests = None + else: + future_concolic_tests = self.executor.submit( + generate_concolic_tests, + self.test_cfg, + self.args, + self.function_to_optimize, + self.function_to_optimize_ast, + ) if not self.args.no_gen_tests: # Wait for test futures to complete - concurrent.futures.wait([*future_tests, future_concolic_tests]) - else: + futures_to_wait = [*future_tests] + if future_concolic_tests is not None: + futures_to_wait.append(future_concolic_tests) + concurrent.futures.wait(futures_to_wait) + elif future_concolic_tests is not None: concurrent.futures.wait([future_concolic_tests]) # Process test generation 
results tests: list[GeneratedTests] = [] @@ -1775,7 +1794,10 @@ def generate_tests( logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}") return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}") - function_to_concolic_tests, concolic_test_str = future_concolic_tests.result() + if future_concolic_tests is not None: + function_to_concolic_tests, concolic_test_str = future_concolic_tests.result() + else: + function_to_concolic_tests, concolic_test_str = {}, None count_tests = len(tests) if concolic_test_str: count_tests += 1 @@ -2198,7 +2220,20 @@ def process_review( self.optimization_review = opt_review_result.review # Display the reviewer result to the user - if opt_review_result.review: + if is_subagent_mode(): + subagent_log_optimization_result( + function_name=new_explanation.function_name, + file_path=new_explanation.file_path, + perf_improvement_line=new_explanation.perf_improvement_line, + original_runtime_ns=new_explanation.original_runtime_ns, + best_runtime_ns=new_explanation.best_runtime_ns, + raw_explanation=new_explanation.raw_explanation_message, + original_code=original_code_combined, + new_code=new_code_combined, + review=opt_review_result.review, + test_results=new_explanation.winning_behavior_test_results, + ) + elif opt_review_result.review: review_display = { "high": ("[bold green]High[/bold green]", "green", "Recommended to merge"), "medium": ("[bold yellow]Medium[/bold yellow]", "yellow", "Review recommended before merging"), diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 5527a0567..1db66d1b7 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -31,6 +31,7 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.either import is_successful from codeflash.languages import current_language_support, is_javascript, set_current_language +from 
codeflash.lsp.helpers import is_subagent_mode from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig @@ -603,7 +604,7 @@ def run(self) -> None: return function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize) - if self.args.all: + if self.args.all and not getattr(self.args, "subagent", False): self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) # GLOBAL RANKING: Rank all functions together before optimizing @@ -657,7 +658,7 @@ def run(self) -> None: if is_successful(best_optimization): optimizations_found += 1 # create a diff patch for successful optimization - if self.current_worktree: + if self.current_worktree and not is_subagent_mode(): best_opt = best_optimization.unwrap() read_writable_code = best_opt.code_context.read_writable_code relative_file_paths = [ @@ -690,7 +691,12 @@ def run(self) -> None: self.functions_checkpoint.cleanup() if hasattr(self.args, "command") and self.args.command == "optimize": self.cleanup_replay_tests() - if optimizations_found == 0: + if is_subagent_mode(): + if optimizations_found == 0: + import sys + + sys.stdout.write("No optimizations found.\n") + elif optimizations_found == 0: logger.info("❌ No optimizations found.") elif self.args.all: logger.info("✨ All functions have been optimized! ✨") diff --git a/uv.lock b/uv.lock index 05b79c606..b5222447b 100644 --- a/uv.lock +++ b/uv.lock @@ -605,7 +605,6 @@ tests = [ [[package]] name = "codeflash-benchmark" -version = "0.3.0" source = { editable = "codeflash-benchmark" } dependencies = [ { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },