diff --git a/components/runners/claude-code-runner/adapter.py b/components/runners/claude-code-runner/adapter.py index a96a907e..96fce69f 100644 --- a/components/runners/claude-code-runner/adapter.py +++ b/components/runners/claude-code-runner/adapter.py @@ -70,9 +70,9 @@ def __init__(self): self._skip_resume_on_restart = False self._turn_count = 0 - # AG-UI streaming state - self._current_message_id: Optional[str] = None - self._current_tool_id: Optional[str] = None + # AG-UI streaming state (per-run, not instance state) + # NOTE: _current_message_id and _current_tool_id are now local variables + # in _run_claude_agent_sdk to avoid race conditions with concurrent runs self._current_run_id: Optional[str] = None self._current_thread_id: Optional[str] = None @@ -289,6 +289,9 @@ async def _run_claude_agent_sdk( thread_id: AG-UI thread identifier run_id: AG-UI run identifier """ + # Per-run state - NOT instance variables to avoid race conditions with concurrent runs + current_message_id: Optional[str] = None + logger.info(f"_run_claude_agent_sdk called with prompt length={len(prompt)}, will create fresh client") try: # Check for authentication method @@ -331,6 +334,7 @@ async def _run_claude_agent_sdk( ToolResultBlock, ) from claude_agent_sdk.types import StreamEvent + from claude_agent_sdk import tool as sdk_tool, create_sdk_mcp_server from observability import ObservabilityManager @@ -401,6 +405,31 @@ async def _run_claude_agent_sdk( # Load MCP server configuration (webfetch is included in static .mcp.json) mcp_servers = self._load_mcp_config(cwd_path) or {} + # Create custom session control tools + # Capture self reference for the restart tool closure + adapter_ref = self + + @sdk_tool("restart_session", "Restart the Claude session to recover from issues, clear state, or get a fresh connection. Use this if you detect you're in a broken state or need to reset.", {}) + async def restart_session_tool(args: dict) -> dict: + """Tool that allows Claude to request a session restart.""" + adapter_ref._restart_requested = True + logger.info("🔄 Session restart requested by Claude via MCP tool") + return { + "content": [{ + "type": "text", + "text": "Session restart has been requested. The current run will complete and a fresh session will be established. Your conversation context will be preserved on disk." + }] + } + + # Create SDK MCP server for session tools + session_tools_server = create_sdk_mcp_server( + name="session", + version="1.0.0", + tools=[restart_session_tool] + ) + mcp_servers["session"] = session_tools_server + logger.info("Added custom session control MCP tools (restart_session)") + # Disable built-in WebFetch in favor of WebFetch.MCP from config allowed_tools = ["Read", "Write", "Bash", "Glob", "Grep", "Edit", "MultiEdit", "WebSearch"] if mcp_servers: @@ -428,20 +457,6 @@ async def _run_claude_agent_sdk( include_partial_messages=True, ) - # Enable continue_conversation for session resumption - if not self._first_run or is_continuation: - try: - options.continue_conversation = True - logger.info("Enabled continue_conversation for session resumption") - yield RawEvent( - type=EventType.RAW, - thread_id=thread_id, - run_id=run_id, - event={"type": "system_log", "message": "🔄 Continuing conversation from previous state"} - ) - except Exception as e: - logger.warning(f"Failed to set continue_conversation: {e}") - if self._skip_resume_on_restart: self._skip_resume_on_restart = False @@ -481,9 +496,24 @@ def create_sdk_client(opts, disable_continue=False): opts.continue_conversation = False return ClaudeSDKClient(options=opts) - # Always create a fresh client for each run (simple and reliable) + # Create fresh client for each run + # (Python SDK has issues with client reuse despite docs suggesting it should work) logger.info("Creating new ClaudeSDKClient for this run...") + # Enable continue_conversation to resume from disk state + if not self._first_run or is_continuation: + try: + options.continue_conversation = True + logger.info("Enabled continue_conversation (will resume from disk state)") + yield RawEvent( + type=EventType.RAW, + thread_id=thread_id, + run_id=run_id, + event={"type": "system_log", "message": "🔄 Resuming conversation from disk state"} + ) + except Exception as e: + logger.warning(f"Failed to set continue_conversation: {e}") + try: logger.info("Creating ClaudeSDKClient...") client = create_sdk_client(options) @@ -508,15 +538,6 @@ def create_sdk_client(opts, disable_continue=False): try: # Store client reference for interrupt support self._active_client = client - - if not self._first_run: - yield RawEvent( - type=EventType.RAW, - thread_id=thread_id, - run_id=run_id, - event={"type": "system_log", "message": "✅ Continuing conversation"} - ) - logger.info("SDK continuing conversation from local state") # Process the prompt step_id = str(uuid.uuid4()) @@ -533,8 +554,12 @@ def create_sdk_client(opts, disable_continue=False): logger.info("Query sent, waiting for response stream...") # Process response stream + logger.info("Starting to consume receive_response() iterator...") + message_count = 0 + async for message in client.receive_response(): - logger.info(f"[ClaudeSDKClient]: {message}") + message_count += 1 + logger.info(f"[ClaudeSDKClient Message #{message_count}]: {message}") # Handle StreamEvent for real-time streaming chunks if isinstance(message, StreamEvent): @@ -542,12 +567,12 @@ def create_sdk_client(opts, disable_continue=False): event_type = event_data.get('type') if event_type == 'message_start': - self._current_message_id = str(uuid.uuid4()) + current_message_id = str(uuid.uuid4()) yield TextMessageStartEvent( type=EventType.TEXT_MESSAGE_START, thread_id=thread_id, run_id=run_id, - message_id=self._current_message_id, + message_id=current_message_id, role="assistant", ) @@ -555,12 +580,12 @@ def create_sdk_client(opts, disable_continue=False): delta_data = event_data.get('delta', {}) if delta_data.get('type') == 'text_delta': text_chunk = delta_data.get('text', '') - if text_chunk: + if text_chunk and current_message_id: yield TextMessageContentEvent( type=EventType.TEXT_MESSAGE_CONTENT, thread_id=thread_id, run_id=run_id, - message_id=self._current_message_id, + message_id=current_message_id, delta=text_chunk, ) continue @@ -654,14 +679,14 @@ def create_sdk_client(opts, disable_continue=False): ) # End text message after processing all blocks - if getattr(message, 'content', []) and self._current_message_id: + if getattr(message, 'content', []) and current_message_id: yield TextMessageEndEvent( type=EventType.TEXT_MESSAGE_END, thread_id=thread_id, run_id=run_id, - message_id=self._current_message_id, + message_id=current_message_id, ) - self._current_message_id = None + current_message_id = None elif isinstance(message, SystemMessage): text = getattr(message, 'text', None) @@ -724,15 +749,31 @@ def create_sdk_client(opts, disable_continue=False): step_id=step_id, step_name="processing_prompt", ) + + logger.info(f"Response iterator fully consumed ({message_count} messages total)") # Mark first run complete self._first_run = False + + # Check if restart was requested by Claude + if self._restart_requested: + logger.info("🔄 Restart was requested, emitting restart event") + self._restart_requested = False # Reset flag + yield RawEvent( + type=EventType.RAW, + thread_id=thread_id, + run_id=run_id, + event={ + "type": "session_restart_requested", + "message": "Claude requested a session restart. Reconnecting..." + } + ) finally: - # Clear active client reference (interrupt no longer valid for this run) + # Clear active client reference self._active_client = None - # Always disconnect client at end of run (no persistence) + # Always disconnect client at end of run if client is not None: logger.info("Disconnecting client (end of run)") await client.disconnect() @@ -761,7 +802,6 @@ async def interrupt(self) -> None: except Exception as e: logger.error(f"Failed to interrupt Claude SDK: {e}") - def _setup_workflow_paths(self, active_workflow_url: str, repos_cfg: list) -> tuple[str, list, str]: """Setup paths for workflow mode.""" add_dirs = [] @@ -1300,25 +1340,34 @@ def _build_workspace_context_prompt(self, repos_cfg, workflow_name, artifacts_pa # Repositories if repos_cfg: + session_id = os.getenv('AGENTIC_SESSION_NAME', '').strip() + feature_branch = f"ambient/{session_id}" if session_id else None + repo_names = [repo.get('name', f'repo-{i}') for i, repo in enumerate(repos_cfg)] if len(repo_names) <= 5: - prompt += f"**Repositories**: {', '.join([f'repos/{name}/' for name in repo_names])}\n\n" + prompt += f"**Repositories**: {', '.join([f'repos/{name}/' for name in repo_names])}\n" else: - prompt += f"**Repositories** ({len(repo_names)} total): {', '.join([f'repos/{name}/' for name in repo_names[:5]])}, and {len(repo_names) - 5} more\n\n" + prompt += f"**Repositories** ({len(repo_names)} total): {', '.join([f'repos/{name}/' for name in repo_names[:5]])}, and {len(repo_names) - 5} more\n" + + if feature_branch: + prompt += f"**Working Branch**: `{feature_branch}` (all repos are on this feature branch)\n\n" + else: + prompt += "\n" # Add git push instructions for repos with autoPush enabled auto_push_repos = [repo for repo in repos_cfg if repo.get('autoPush', False)] if auto_push_repos: + push_branch = feature_branch or "ambient/" + prompt += "## Git Push Instructions\n\n" prompt += "The following repositories have auto-push enabled. When you make changes to these repositories, you MUST commit and push your changes:\n\n" for repo in auto_push_repos: repo_name = repo.get('name', 'unknown') - repo_branch = repo.get('branch', 'main') - prompt += f"- **repos/{repo_name}/** (branch: {repo_branch})\n" + prompt += f"- **repos/{repo_name}/**\n" prompt += "\nAfter making changes to any auto-push repository:\n" prompt += "1. Use `git add` to stage your changes\n" prompt += "2. Use `git commit -m \"description\"` to commit with a descriptive message\n" - prompt += "3. Use `git push origin ` to push to the remote repository\n\n" + prompt += f"3. Use `git push origin {push_branch}` to push to the remote repository\n\n" # MCP Integration Setup Instructions prompt += "## MCP Integrations\n" diff --git a/components/runners/claude-code-runner/main.py b/components/runners/claude-code-runner/main.py index f31f1d8a..079d4001 100644 --- a/components/runners/claude-code-runner/main.py +++ b/components/runners/claude-code-runner/main.py @@ -542,12 +542,13 @@ async def change_workflow(request: Request): return {"message": "Workflow updated", "gitUrl": git_url, "branch": branch, "path": path} -async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[bool, str]: +async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[bool, str, bool]: """ - Clone a repository at runtime. + Clone a repository at runtime and create a feature branch. This mirrors the logic in hydrate.sh but runs when repos are added - after the pod has started. + after the pod has started. After cloning, creates and checks out a + feature branch named 'ambient/'. Args: git_url: Git repository URL @@ -555,14 +556,17 @@ async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[b name: Name for the cloned directory (derived from URL if empty) Returns: - (success, repo_dir_path) tuple + (success, repo_dir_path, was_newly_cloned) tuple + - success: True if repo is available (either newly cloned or already existed) + - repo_dir_path: Path to the repo directory + - was_newly_cloned: True only if the repo was actually cloned this time """ import tempfile import shutil from pathlib import Path if not git_url: - return False, "" + return False, "", False # Derive repo name from URL if not provided if not name: @@ -576,10 +580,10 @@ async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[b logger.info(f"Cloning repo '{name}' from {git_url}@{branch}") - # Skip if already cloned + # Skip if already cloned - not newly cloned if repo_final.exists(): logger.info(f"Repo '{name}' already exists at {repo_final}, skipping clone") - return True, str(repo_final) + return True, str(repo_final), False # Already existed, not newly cloned # Create temp directory for clone temp_dir = Path(tempfile.mkdtemp(prefix="repo-clone-")) @@ -600,9 +604,9 @@ async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[b clone_url = git_url.replace("https://", f"https://oauth2:{gitlab_token}@") logger.info("Using GITLAB_TOKEN for authentication") - # Clone the repository + # Clone the repository (no --depth 1 to allow full branch operations) process = await asyncio.create_subprocess_exec( - "git", "clone", "--branch", branch, "--single-branch", "--depth", "1", + "git", "clone", "--branch", branch, "--single-branch", clone_url, str(temp_dir), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE @@ -617,20 +621,41 @@ async def clone_repo_at_runtime(git_url: str, branch: str, name: str) -> tuple[b if gitlab_token: error_msg = error_msg.replace(gitlab_token, "***REDACTED***") logger.error(f"Failed to clone repo: {error_msg}") - return False, "" - - logger.info("Clone successful, moving to final location...") + return False, "", False + + logger.info("Clone successful, creating feature branch...") + + # Create and checkout feature branch: ambient/ + session_id = os.getenv("AGENTIC_SESSION_NAME", "").strip() + if session_id: + feature_branch = f"ambient/{session_id}" + checkout_process = await asyncio.create_subprocess_exec( + "git", "checkout", "-b", feature_branch, + cwd=str(temp_dir), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + checkout_stdout, checkout_stderr = await checkout_process.communicate() + + if checkout_process.returncode != 0: + logger.warning(f"Failed to create feature branch '{feature_branch}': {checkout_stderr.decode()}") + # Continue anyway - repo is still usable on the original branch + else: + logger.info(f"Created and checked out feature branch: {feature_branch}") + else: + logger.warning("AGENTIC_SESSION_NAME not set, skipping feature branch creation") # Move to final location + logger.info("Moving to final location...") repo_final.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(temp_dir), str(repo_final)) logger.info(f"Repo '{name}' ready at {repo_final}") - return True, str(repo_final) + return True, str(repo_final), True # Newly cloned except Exception as e: logger.error(f"Error cloning repo: {e}") - return False, "" + return False, "", False finally: # Cleanup temp directory if it still exists if temp_dir.exists(): @@ -725,38 +750,43 @@ async def add_repo(request: Request): name = url.split("/")[-1].removesuffix(".git") # Clone the repository at runtime - success, repo_path = await clone_repo_at_runtime(url, branch, name) + success, repo_path, was_newly_cloned = await clone_repo_at_runtime(url, branch, name) if not success: raise HTTPException(status_code=500, detail=f"Failed to clone repository: {url}") - # Update REPOS_JSON env var - repos_json = os.getenv("REPOS_JSON", "[]") - try: - repos = json.loads(repos_json) if repos_json else [] - except: - repos = [] - - # Add new repo - repos.append({ - "name": name, - "input": { - "url": url, - "branch": branch - } - }) - - os.environ["REPOS_JSON"] = json.dumps(repos) - - # Reset adapter state to force reinitialization on next run - _adapter_initialized = False - adapter._first_run = True - - logger.info(f"Repo '{name}' added and cloned, adapter will reinitialize on next run") - - # Trigger a notification to Claude about the new repository - asyncio.create_task(trigger_repo_added_notification(name, url)) + # Only update state and trigger notification if repo was newly cloned + # This prevents duplicate notifications when both backend and operator call this endpoint + if was_newly_cloned: + # Update REPOS_JSON env var + repos_json = os.getenv("REPOS_JSON", "[]") + try: + repos = json.loads(repos_json) if repos_json else [] + except: + repos = [] + + # Add new repo + repos.append({ + "name": name, + "input": { + "url": url, + "branch": branch + } + }) + + os.environ["REPOS_JSON"] = json.dumps(repos) + + # Reset adapter state to force reinitialization on next run + _adapter_initialized = False + adapter._first_run = True + + logger.info(f"Repo '{name}' added and cloned, adapter will reinitialize on next run") + + # Trigger a notification to Claude about the new repository + asyncio.create_task(trigger_repo_added_notification(name, url)) + else: + logger.info(f"Repo '{name}' already existed, skipping notification (idempotent call)") - return {"message": "Repository added", "name": name, "path": repo_path} + return {"message": "Repository added", "name": name, "path": repo_path, "newly_cloned": was_newly_cloned} async def trigger_repo_added_notification(repo_name: str, repo_url: str):