diff --git a/app/.env.example b/app/.env.example index 91a3052..55c8a0e 100644 --- a/app/.env.example +++ b/app/.env.example @@ -14,6 +14,11 @@ JWT_SECRET_KEY=please-set-at-least-32-bytes-secret-key # TOKEN_HASH_SECRET=please-set-at-least-32-bytes-and-different-from-jwt-secret ACCESS_TOKEN_EXPIRE_MINUTES=4320 REFRESH_TOKEN_EXPIRE_DAYS=7 +# Auth risk-control defaults (can be overridden per environment) +MAX_FAILED_LOGIN_ATTEMPTS=8 +ACCOUNT_LOCK_MINUTES=5 +REGISTER_RATE_LIMIT=12 +LOGIN_RATE_LIMIT=30 REDIS_URL=redis://localhost:6379/0 # RabbitMQ is optional in this stage, kept for future publisher swap. @@ -60,10 +65,11 @@ AGENT_COMPACT_THRESHOLD=0.75 AGENT_USER_DAILY_LIMIT=50 AGENT_USER_CONCURRENT_LIMIT=2 AGENT_STAGING_TTL_SEC=86400 -AGENT_SSE_ENABLED=false +AGENT_SSE_ENABLED=true +AGENT_LLM_PROVIDER=anthropic AGENT_LLM_MODEL=claude-sonnet-4-6 -# Configure provider compatibility only via base URL + API key token. -# Example (DeepSeek Anthropic-compatible endpoint): https://api.deepseek.com/anthropic +AGENT_LLM_PLAN_MAX_TOKENS=8192 +# For Anthropic-compatible providers, e.g. DeepSeek: https://api.deepseek.com/anthropic # AGENT_LLM_BASE_URL= # AGENT_LLM_API_KEY= AGENT_MCP_ENDPOINTS=[] diff --git a/app/src/fileflash/agents/harness/__init__.py b/app/src/fileflash/agents/harness/__init__.py index 87e106a..acdcb08 100644 --- a/app/src/fileflash/agents/harness/__init__.py +++ b/app/src/fileflash/agents/harness/__init__.py @@ -6,6 +6,7 @@ from .policy import PolicyDecision, PolicyGuard from .prompt import PromptBuildRequest, PromptBuilder from .router import ToolCall, ToolRouter +from .tool_registry import REGISTRY, ToolContext, ToolRegistry, ToolSpec __all__ = [ "AgentEvent", @@ -20,6 +21,10 @@ "PolicyGuard", "PromptBuildRequest", "PromptBuilder", + "REGISTRY", "ToolCall", + "ToolContext", + "ToolRegistry", "ToolRouter", + "ToolSpec", ] diff --git a/app/src/fileflash/agents/harness/ask.py b/app/src/fileflash/agents/harness/ask.py new file mode 100644 index 0000000..258f4ac --- /dev/null +++ b/app/src/fileflash/agents/harness/ask.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import asyncio +import contextlib +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from ...repositories import AgentInboxMessageRepository +from .event_bus import AgentEventBus, AgentEventEnvelope, AgentEventStream + + +class AskTimedOut(Exception): + def __init__(self, *, ask_id: int) -> None: + super().__init__(f"Ask {ask_id} timed out") + self.ask_id = ask_id + + +class AskProtocol: + def __init__( + self, + *, + db: AsyncSession, + event_bus: AgentEventBus, + job_id: int, + ) -> None: + self._db = db + self._bus = event_bus + self._job_id = job_id + self._repo = AgentInboxMessageRepository(db) + self._waiters: dict[int, asyncio.Future[Any]] = {} + self._sub_ctx = None + self._sub_stream: AgentEventStream | None = None + self._sub_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + self._sub_ctx = self._bus.subscribe(job_id=self._job_id) + self._sub_stream = await self._sub_ctx.__aenter__() + self._sub_task = asyncio.create_task(self._listen()) + + async def aclose(self) -> None: + if self._sub_task is not None: + self._sub_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._sub_task + if self._sub_ctx is not None: + await self._sub_ctx.__aexit__(None, None, None) + for future in self._waiters.values(): + if not future.done(): + future.cancel() + + async def ask( + self, + *, + prompt: str, + schema: dict[str, Any], + timeout_sec: float, + ) -> Any: + msg = await self._repo.create_ask( + job_id=self._job_id, + payload={"prompt": prompt, "schema": schema, "timeoutSec": timeout_sec}, + ) + await self._db.commit() + + ask_id = int(msg.inbox_message_id) + loop = asyncio.get_running_loop() + future: asyncio.Future[Any] = loop.create_future() + self._waiters[ask_id] = future + + await self._bus.publish( + AgentEventEnvelope( + job_id=self._job_id, + event_type="agent.ask", + payload={ + "messageId": str(ask_id), + "prompt": prompt, + "schema": schema, + "timeoutSec": timeout_sec, + }, + emitted_at=datetime.now(UTC), + ) + ) + + try: + value = await asyncio.wait_for(future, timeout=timeout_sec) + except TimeoutError as exc: + await self._repo.mark_timed_out( + inbox_message_id=ask_id, + answered_at=datetime.now(UTC), + ) + await self._db.commit() + raise AskTimedOut(ask_id=ask_id) from exc + finally: + self._waiters.pop(ask_id, None) + + await self._repo.mark_answered( + inbox_message_id=ask_id, + answered_at=datetime.now(UTC), + ) + await self._db.commit() + return value + + async def _listen(self) -> None: + assert self._sub_stream is not None + while True: + try: + envelope = await self._sub_stream.next(timeout=None) + except asyncio.CancelledError: + raise + except Exception: + continue + if envelope.event_type != "agent.inbox.reply": + continue + reply_to = envelope.payload.get("replyTo") + if reply_to is None: + continue + try: + ask_id = int(reply_to) + except (TypeError, ValueError): + continue + future = self._waiters.get(ask_id) + if future is None or future.done(): + continue + future.set_result(envelope.payload.get("value")) + + +__all__ = ["AskProtocol", "AskTimedOut"] diff --git a/app/src/fileflash/agents/harness/event_bus.py b/app/src/fileflash/agents/harness/event_bus.py new file mode 100644 index 0000000..05c73d0 --- /dev/null +++ b/app/src/fileflash/agents/harness/event_bus.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Any, Protocol + +from fastapi.encoders import jsonable_encoder +from redis.asyncio import Redis + +from ...core.settings import Settings, get_settings + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class AgentEventEnvelope: + job_id: int + event_type: str + payload: dict[str, Any] + emitted_at: datetime + event_id: str | None = None + + def to_json(self) -> str: + body = jsonable_encoder(asdict(self)) + return json.dumps(body, ensure_ascii=False, separators=(",", ":")) + + @classmethod + def from_json(cls, raw: str) -> AgentEventEnvelope: + data = json.loads(raw) + return cls( + job_id=int(data["job_id"]), + event_type=str(data["event_type"]), + payload=dict(data.get("payload") or {}), + emitted_at=datetime.fromisoformat(data["emitted_at"]), + event_id=data.get("event_id"), + ) + + +class AgentEventStream(Protocol): + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: ... + async def aclose(self) -> None: ... + + +class AgentEventBus(Protocol): + async def publish(self, envelope: AgentEventEnvelope) -> None: ... + + def subscribe( + self, + *, + job_id: int, + ) -> AbstractAsyncContextManager[AgentEventStream]: ... + + +@dataclass(slots=True) +class _InMemoryStream: + queue: asyncio.Queue[AgentEventEnvelope] + + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: + if timeout is None: + return await self.queue.get() + return await asyncio.wait_for(self.queue.get(), timeout=timeout) + + async def aclose(self) -> None: + return None + + +class InMemoryAgentEventBus: + def __init__(self, *, buffer_size: int = 64) -> None: + self._buffer = buffer_size + self._subscribers: dict[int, list[asyncio.Queue[AgentEventEnvelope]]] = {} + + async def publish(self, envelope: AgentEventEnvelope) -> None: + queues = list(self._subscribers.get(envelope.job_id, [])) + for queue in queues: + if queue.full(): + logger.warning( + "InMemoryAgentEventBus dropped event: queue full job_id=%s", + envelope.job_id, + ) + continue + await queue.put(envelope) + + @contextlib.asynccontextmanager + async def subscribe(self, *, job_id: int) -> AsyncIterator[_InMemoryStream]: + queue: asyncio.Queue[AgentEventEnvelope] = asyncio.Queue(maxsize=self._buffer) + self._subscribers.setdefault(job_id, []).append(queue) + try: + yield _InMemoryStream(queue=queue) + finally: + subscribers = self._subscribers.get(job_id) + if subscribers is not None: + subscribers.remove(queue) + if not subscribers: + del self._subscribers[job_id] + + +class RedisAgentEventBus: + def __init__( + self, + *, + redis: Redis, + channel_prefix: str, + buffer_size: int = 64, + ) -> None: + self._redis = redis + self._channel_prefix = channel_prefix + self._buffer = buffer_size + + def _channel(self, job_id: int) -> str: + return f"{self._channel_prefix}:{job_id}:events" + + async def publish(self, envelope: AgentEventEnvelope) -> None: + await self._redis.publish(self._channel(envelope.job_id), envelope.to_json()) + + @contextlib.asynccontextmanager + async def subscribe(self, *, job_id: int) -> AsyncIterator[_RedisStream]: + pubsub = self._redis.pubsub() + channel = self._channel(job_id) + await pubsub.subscribe(channel) + stream = _RedisStream(pubsub=pubsub) + try: + yield stream + finally: + await pubsub.unsubscribe(channel) + await pubsub.aclose() + + +@dataclass(slots=True) +class _RedisStream: + pubsub: Any + + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: + if timeout is None: + async for message in self.pubsub.listen(): + envelope = _envelope_from_redis_message(message) + if envelope is not None: + return envelope + else: + message = await self.pubsub.get_message( + ignore_subscribe_messages=True, + timeout=timeout, + ) + envelope = _envelope_from_redis_message(message) + if envelope is not None: + return envelope + raise TimeoutError("No event within timeout") + + async def aclose(self) -> None: + await self.pubsub.aclose() + + +def _envelope_from_redis_message(message: Any) -> AgentEventEnvelope | None: + if message is None: + return None + message_type = message.get("type") + if message_type not in {"message", "pmessage"}: + return None + data = message.get("data") + if isinstance(data, bytes): + data = data.decode("utf-8") + return AgentEventEnvelope.from_json(str(data)) + + +def build_agent_event_bus( + *, + settings: Settings | None = None, + redis: Redis | None = None, +) -> AgentEventBus: + cfg = settings or get_settings() + if redis is None: + if not cfg.redis_url: + return InMemoryAgentEventBus(buffer_size=cfg.agent_event_bus_buffer_size) + redis = Redis.from_url(cfg.redis_url, decode_responses=True) + return RedisAgentEventBus( + redis=redis, + channel_prefix=cfg.agent_event_channel_prefix, + buffer_size=cfg.agent_event_bus_buffer_size, + ) + + +__all__ = [ + "AgentEventBus", + "AgentEventEnvelope", + "AgentEventStream", + "InMemoryAgentEventBus", + "RedisAgentEventBus", + "build_agent_event_bus", +] diff --git a/app/src/fileflash/agents/harness/events.py b/app/src/fileflash/agents/harness/events.py index 8bf7a94..3fbfbb2 100644 --- a/app/src/fileflash/agents/harness/events.py +++ b/app/src/fileflash/agents/harness/events.py @@ -1,15 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any +from .event_bus import AgentEventBus as EventBus +from .event_bus import AgentEventEnvelope as AgentEvent - -@dataclass(slots=True) -class AgentEvent: - event_type: str - payload: dict[str, Any] - - -class EventBus: - async def publish(self, event: AgentEvent) -> None: - raise NotImplementedError("EventBus is scaffolded only in this stage") +__all__ = ["AgentEvent", "EventBus"] diff --git a/app/src/fileflash/agents/harness/inbox.py b/app/src/fileflash/agents/harness/inbox.py new file mode 100644 index 0000000..ac7f45c --- /dev/null +++ b/app/src/fileflash/agents/harness/inbox.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from ...models.enums import AgentInboxKind, AgentInboxStatus +from ...repositories import AgentInboxMessageRepository +from .event_bus import AgentEventBus, AgentEventEnvelope + +_INBOX_EVENT_TYPES: dict[AgentInboxKind, str] = { + AgentInboxKind.REPLY: "agent.inbox.reply", + AgentInboxKind.CONTROL_PAUSE: "agent.inbox.control", + AgentInboxKind.CONTROL_RESUME: "agent.inbox.control", + AgentInboxKind.CONTROL_APPROVE: "agent.inbox.control", + AgentInboxKind.CONTROL_DENY: "agent.inbox.control", + AgentInboxKind.CONTROL_SKIP: "agent.inbox.control", + AgentInboxKind.CONTROL_CANCEL: "agent.inbox.control", +} + + +class AgentInbox: + def __init__(self, *, db: AsyncSession, event_bus: AgentEventBus) -> None: + self._db = db + self._bus = event_bus + self._repo = AgentInboxMessageRepository(db) + + async def handle( + self, + *, + job_id: int, + kind: AgentInboxKind, + payload: dict[str, Any], + reply_to_id: int | None = None, + ): + if kind == AgentInboxKind.REPLY: + if reply_to_id is None: + raise ValueError("reply requires reply_to_id") + ask = await self._repo.get_ask(inbox_message_id=reply_to_id) + if ask is None: + raise ValueError(f"ask {reply_to_id} not found") + if int(ask.job_id) != job_id: + raise ValueError(f"ask {reply_to_id} belongs to a different job") + if ask.status != AgentInboxStatus.WAITING: + raise ValueError(f"ask {reply_to_id} is not waiting") + + msg = await self._repo.record_user_message( + job_id=job_id, + kind=kind, + payload=payload, + reply_to_id=reply_to_id, + ) + event_type = _INBOX_EVENT_TYPES[kind] + envelope_payload: dict[str, Any] = { + "kind": kind.value, + "messageId": str(msg.inbox_message_id), + } + if reply_to_id is not None: + envelope_payload["replyTo"] = str(reply_to_id) + if "value" in payload: + envelope_payload["value"] = payload["value"] + if "metadata" in payload: + envelope_payload["metadata"] = payload["metadata"] + await self._bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload=envelope_payload, + emitted_at=datetime.now(UTC), + ) + ) + return msg + + +__all__ = ["AgentInbox"] diff --git a/app/src/fileflash/agents/harness/policy.py b/app/src/fileflash/agents/harness/policy.py index 027af98..730322b 100644 --- a/app/src/fileflash/agents/harness/policy.py +++ b/app/src/fileflash/agents/harness/policy.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from ...schemas.agent import AgentProposedAction +from .tool_registry import REGISTRY @dataclass(slots=True) @@ -11,38 +12,18 @@ class PolicyDecision: reasons: list[str] = field(default_factory=list) -HIGH_RISK_TOOLS = frozenset( - { - "drive.deleteFile", - "drive.deleteFolder", - "drive.batchDelete", - "recycle.clear", - "recycle.permanentDelete", - } -) - -WRITE_TOOLS = frozenset( - { - "drive.createFolder", - "drive.moveFile", - "drive.moveFolder", - "drive.renameFile", - "drive.renameFolder", - *HIGH_RISK_TOOLS, - } -) - - def classify_tool_side_effect(tool_name: str) -> str: - return "write" if tool_name in WRITE_TOOLS else "read" + try: + return REGISTRY.get(tool_name).side_effect + except KeyError: + return "write" def classify_tool_risk(tool_name: str) -> str: - if tool_name in HIGH_RISK_TOOLS or "delete" in tool_name.lower(): + try: + return REGISTRY.get(tool_name).risk_level + except KeyError: return "high" - if classify_tool_side_effect(tool_name) == "write": - return "medium" - return "low" def normalize_action_risk(action: AgentProposedAction) -> AgentProposedAction: @@ -70,6 +51,13 @@ async def evaluate_tool_call( tool_name: str, high_risk_confirmed: bool = False, ) -> PolicyDecision: + try: + REGISTRY.get(tool_name) + except KeyError: + return PolicyDecision( + allowed=False, + reasons=[f"Unsupported agent tool: {tool_name}"], + ) if classify_tool_risk(tool_name) == "high" and not high_risk_confirmed: return PolicyDecision( allowed=False, diff --git a/app/src/fileflash/agents/harness/router.py b/app/src/fileflash/agents/harness/router.py index 3c46f39..2b3b639 100644 --- a/app/src/fileflash/agents/harness/router.py +++ b/app/src/fileflash/agents/harness/router.py @@ -3,23 +3,12 @@ from dataclasses import dataclass from typing import Any -from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession from ...core.errors import ApiError -from ...core.mime import resolve_file_mime_type -from ...models import File, Folder -from ...models.enums import FileStatus, FolderStatus, FolderType -from ...schemas.file import ( - CreateFolderRequest, - GetFolderContentsQuery, - MoveFileRequest, - MoveFolderRequest, - RenameFileRequest, - RenameFolderRequest, -) from ...services.file import FileService from ...services.folder import FolderService +from .tool_registry import REGISTRY, ToolContext @dataclass(slots=True) @@ -36,323 +25,23 @@ def __init__(self, *, db: AsyncSession, user_id: int) -> None: self.folder_service = FolderService(db=db) async def dispatch(self, call: ToolCall) -> dict[str, Any]: - tool = call.tool_name - args = dict(call.arguments or {}) - - if tool == "drive.listFolder": - folder_id = _first_value(args, "folderId", "parentFolderId") or "root" - query = GetFolderContentsQuery( - folder_id=str(folder_id), - page=int(args.get("page") or 1), - per_page=min(200, int(args.get("perPage") or 200)), - ) - if str(folder_id) == "root": - result = await self.folder_service.get_root_contents( - user_id=self.user_id, - query=query, - ) - else: - result = await self.folder_service.get_folder_contents( - user_id=self.user_id, - query=query, - ) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.countFiles": - return await self._count_files(args) - - if tool == "drive.createFolder": - name = _required_text(args, "name", "folderName") - parent_id = _first_value(args, "parentFolderId", "targetParentId", "folderId") or "root" - result = await self.folder_service.create_folder( - user_id=self.user_id, - payload=CreateFolderRequest(folder_name=name, parent_folder_id=str(parent_id)), - ) - data = result.model_dump(by_alias=True, mode="json") - data.setdefault("folderId", data.get("id")) - return data - - if tool == "drive.moveFile": - file_id = _required_text(args, "fileId", "id") - target_folder_id = _required_text(args, "targetFolderId", "targetParentId") - result = await self.file_service.move_file( - user_id=self.user_id, - file_id=file_id, - payload=MoveFileRequest( - target_folder_id=target_folder_id, - share_handling=str(args.get("shareHandling") or "keep"), - ), - ) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.moveFolder": - folder_id = _required_text(args, "folderId", "id") - target_parent_id = _required_text(args, "targetParentId", "targetFolderId") - result = await self.folder_service.move_folder( - user_id=self.user_id, - folder_id=folder_id, - payload=MoveFolderRequest( - target_parent_id=target_parent_id, - share_handling=str(args.get("shareHandling") or "keep"), - ), - ) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.renameFile": - file_id = _required_text(args, "fileId", "id") - file_name = _required_text(args, "fileName", "name") - result = await self.file_service.rename_file( - user_id=self.user_id, - file_id=file_id, - payload=RenameFileRequest(file_name=file_name), - ) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.renameFolder": - folder_id = _required_text(args, "folderId", "id") - folder_name = _required_text(args, "folderName", "name") - result = await self.folder_service.rename_folder( - user_id=self.user_id, - folder_id=folder_id, - payload=RenameFolderRequest(folder_name=folder_name), - ) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.deleteFile": - file_id = _required_text(args, "fileId", "id") - result = await self.file_service.delete_file(user_id=self.user_id, file_id=file_id) - return result.model_dump(by_alias=True, mode="json") - - if tool == "drive.deleteFolder": - folder_id = _required_text(args, "folderId", "id") - result = await self.folder_service.delete_folder( - user_id=self.user_id, - folder_id=folder_id, - ) - return result.model_dump(by_alias=True, mode="json") - - raise ApiError(status_code=400, code=400, message=f"Unsupported agent tool: {tool}") - - async def _count_files(self, args: dict[str, Any]) -> dict[str, Any]: - folder_id = str(_first_value(args, "folderId", "parentFolderId") or "root") - recursive = _bool_arg(args.get("recursive"), default=True) - category = _normalize_category(args.get("category")) - search = str(args.get("search") or "").strip().lower() - root_folder_id = await _resolve_folder_id( - self.db, + tool_name = str(call.tool_name or "").strip() + try: + spec = REGISTRY.get(tool_name) + except KeyError as exc: + raise ApiError( + status_code=400, + code=400, + message=f"Unsupported agent tool: {tool_name}", + ) from exc + + ctx = ToolContext( + db=self.db, user_id=self.user_id, - folder_id=folder_id, - ) - folder_ids = ( - await _active_descendant_folder_ids( - self.db, - user_id=self.user_id, - root_folder_id=root_folder_id, - ) - if recursive - else [root_folder_id] + file_service=self.file_service, + folder_service=self.folder_service, ) - - statement = select( - File.file_id, - File.file_name, - File.file_size, - File.mime_type, - File.file_ext, - File.folder_id, - ).where( - and_( - File.owner_id == self.user_id, - File.folder_id.in_(folder_ids), - File.status == FileStatus.ACTIVE, - File.is_latest.is_(True), - ) - ) - if search: - statement = statement.where(File.file_name.ilike(f"%{search}%")) - statement = statement.order_by(File.file_name.asc()) - - rows = (await self.db.execute(statement)).all() - by_mime_type: dict[str, int] = {} - sample_items: list[dict[str, Any]] = [] - total_items = 0 - for row in rows: - file_id, file_name, file_size, mime_type, file_ext, row_folder_id = row - resolved_mime = resolve_file_mime_type( - mime_type=mime_type, - file_ext=file_ext, - file_name=file_name, - ) - if category is not None and _category_for_file( - mime_type=resolved_mime, - file_ext=file_ext, - file_name=file_name, - ) != category: - continue - - total_items += 1 - by_mime_type[resolved_mime] = by_mime_type.get(resolved_mime, 0) + 1 - if len(sample_items) < 5: - sample_items.append( - { - "id": str(file_id), - "name": str(file_name), - "size": int(file_size or 0), - "mimeType": resolved_mime, - "folderId": str(row_folder_id), - } - ) - - return { - "totalItems": total_items, - "category": category, - "recursive": recursive, - "folderId": str(root_folder_id), - "byMimeType": dict(sorted(by_mime_type.items())), - "sampleItems": sample_items, - } - - -def _first_value(args: dict[str, Any], *keys: str) -> Any: - for key in keys: - value = args.get(key) - if value not in (None, ""): - return value - return None - - -def _required_text(args: dict[str, Any], *keys: str) -> str: - value = _first_value(args, *keys) - if value is None: - raise ApiError(status_code=400, code=400, message=f"Missing required tool input: {keys[0]}") - text = str(value).strip() - if not text: - raise ApiError(status_code=400, code=400, message=f"Missing required tool input: {keys[0]}") - return text - - -def _bool_arg(value: Any, *, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - text = str(value).strip().lower() - if text in {"1", "true", "yes", "y"}: - return True - if text in {"0", "false", "no", "n"}: - return False - return default - - -async def _resolve_folder_id(db: AsyncSession, *, user_id: int, folder_id: str) -> int: - if not folder_id or folder_id == "root": - root_id = await db.scalar( - select(Folder.folder_id).where( - and_( - Folder.owner_id == user_id, - Folder.parent_folder_id.is_(None), - Folder.folder_type == FolderType.ROOT, - Folder.status == FolderStatus.ACTIVE, - ) - ) - ) - if root_id is None: - raise ApiError(status_code=404, code=404, message="Root folder not found") - return int(root_id) - try: - parsed = int(folder_id) - except ValueError as exc: - raise ApiError(status_code=400, code=400, message="Invalid folderId") from exc - exists = await db.scalar( - select(Folder.folder_id).where( - and_( - Folder.folder_id == parsed, - Folder.owner_id == user_id, - Folder.status == FolderStatus.ACTIVE, - ) - ) - ) - if exists is None: - raise ApiError(status_code=404, code=404, message="Folder not found") - return parsed - - -async def _active_descendant_folder_ids( - db: AsyncSession, - *, - user_id: int, - root_folder_id: int, -) -> list[int]: - descendants = ( - select(Folder.folder_id) - .where( - and_( - Folder.folder_id == root_folder_id, - Folder.owner_id == user_id, - Folder.status == FolderStatus.ACTIVE, - ) - ) - .cte(name="agent_count_descendants", recursive=True) - ) - descendants = descendants.union_all( - select(Folder.folder_id).where( - and_( - Folder.parent_folder_id == descendants.c.folder_id, - Folder.owner_id == user_id, - Folder.status == FolderStatus.ACTIVE, - ) - ) - ) - folder_ids = list(await db.scalars(select(descendants.c.folder_id))) - return [int(folder_id) for folder_id in folder_ids] - - -def _normalize_category(value: Any) -> str | None: - text = str(value or "").strip().lower() - aliases = { - "movies": "video", - "movie": "video", - "film": "video", - "films": "video", - "视频": "video", - "影片": "video", - "电影": "video", - "videos": "video", - "documents": "document", - "docs": "document", - "images": "image", - "pictures": "image", - "archives": "archive", - "compressed": "archive", - } - text = aliases.get(text, text) - if text in {"video", "audio", "image", "document", "archive", "other"}: - return text - return None - - -def _category_for_file(*, mime_type: str, file_ext: str | None, file_name: str | None) -> str: - mime = (mime_type or "").lower() - ext = _normalized_extension(file_ext) or _filename_extension(file_name) - if mime.startswith("video/") or ext in {"mp4", "mov", "avi", "mkv", "webm", "m4v"}: - return "video" - if mime.startswith("audio/") or ext in {"mp3", "wav", "flac", "m4a", "aac", "ogg"}: - return "audio" - if mime.startswith("image/") or ext in {"jpg", "jpeg", "png", "gif", "webp", "svg", "bmp"}: - return "image" - if mime in {"application/pdf"} or ext in {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "txt", "md"}: - return "document" - if ext in {"zip", "rar", "7z", "tar", "gz", "bz2", "xz"}: - return "archive" - return "other" - - -def _normalized_extension(value: str | None) -> str: - return str(value or "").strip().lower().lstrip(".") + return await spec.handler(ctx, dict(call.arguments or {})) -def _filename_extension(value: str | None) -> str: - name = str(value or "").strip().lower() - if "." not in name: - return "" - return name.rsplit(".", 1)[-1] +__all__ = ["ToolCall", "ToolRouter"] diff --git a/app/src/fileflash/agents/harness/tool_registry.py b/app/src/fileflash/agents/harness/tool_registry.py new file mode 100644 index 0000000..c508a86 --- /dev/null +++ b/app/src/fileflash/agents/harness/tool_registry.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import re +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any, Literal + +from sqlalchemy.ext.asyncio import AsyncSession + +ToolSideEffect = Literal["read", "write"] +ToolRiskLevel = Literal["low", "medium", "high"] + + +@dataclass(slots=True) +class ToolContext: + db: AsyncSession + user_id: int + file_service: Any + folder_service: Any + + +ToolHandler = Callable[[ToolContext, dict[str, Any]], Awaitable[dict[str, Any]]] +ToolAnswerFormatter = Callable[[dict[str, Any]], str | None] + + +@dataclass(frozen=True, slots=True) +class ToolSpec: + name: str + description: str + input_schema: dict[str, Any] + side_effect: ToolSideEffect + risk_level: ToolRiskLevel + requires_confirmation: bool + handler: ToolHandler + anthropic_name: str | None = None + answer_formatter: ToolAnswerFormatter | None = None + + def __post_init__(self) -> None: + if self.anthropic_name is None: + object.__setattr__(self, "anthropic_name", _to_provider_tool_name(self.name)) + + def to_anthropic_tool(self) -> dict[str, Any]: + return { + "name": self.anthropic_name, + "description": self.description, + "input_schema": self.input_schema, + "internalName": self.name, + } + + def to_planner_schema(self) -> dict[str, Any]: + return { + "tool": self.name, + "providerTool": self.anthropic_name, + "description": self.description, + "inputSchema": self.input_schema, + "sideEffect": self.side_effect, + "riskLevel": self.risk_level, + "requiresConfirmation": self.requires_confirmation, + } + + +class ToolRegistry: + def __init__(self) -> None: + self._by_name: dict[str, ToolSpec] = {} + self._by_provider_name: dict[str, ToolSpec] = {} + + def register(self, spec: ToolSpec) -> ToolSpec: + if spec.name in self._by_name: + raise ValueError(f"Agent tool already registered: {spec.name}") + provider_name = str(spec.anthropic_name or "") + if provider_name in self._by_provider_name: + raise ValueError(f"Agent provider tool already registered: {provider_name}") + self._by_name[spec.name] = spec + self._by_provider_name[provider_name] = spec + return spec + + def get(self, name: str) -> ToolSpec: + ensure_builtin_tools_registered() + return self._by_name[name] + + def get_by_provider_name(self, name: str) -> ToolSpec: + ensure_builtin_tools_registered() + return self._by_provider_name[name] + + def all(self) -> list[ToolSpec]: + ensure_builtin_tools_registered() + return list(self._by_name.values()) + + def all_names(self) -> tuple[str, ...]: + return tuple(spec.name for spec in self.all()) + + def unknown_names(self, names: list[str] | tuple[str, ...]) -> list[str]: + ensure_builtin_tools_registered() + return [name for name in names if name not in self._by_name] + + def validate_names(self, names: list[str] | tuple[str, ...]) -> tuple[str, ...]: + unknown = self.unknown_names(names) + if unknown: + raise ValueError(f"Unknown agent tools: {', '.join(sorted(unknown))}") + return tuple(names) + + def schemas_for(self, names: list[str] | tuple[str, ...]) -> list[dict[str, Any]]: + return [self.get(name).to_planner_schema() for name in names] + + def anthropic_tools_for(self, names: list[str] | tuple[str, ...]) -> list[dict[str, Any]]: + return [self.get(name).to_anthropic_tool() for name in names] + + +REGISTRY = ToolRegistry() +_BUILTINS_REGISTERED = False + + +def ensure_builtin_tools_registered() -> None: + global _BUILTINS_REGISTERED + if _BUILTINS_REGISTERED: + return + from .. import tools as _tools # noqa: F401 + _BUILTINS_REGISTERED = True + + +def _to_provider_tool_name(name: str) -> str: + value = name.replace(".", "_").replace("-", "_") + value = re.sub(r"(? None: + def __init__( + self, + *, + settings: Settings | None = None, + policy_guard: PolicyGuard | None = None, + event_bus: AgentEventBus | None = None, + answer_client: AnswerClient | None = None, + ) -> None: + self.settings = settings or get_settings() self.policy_guard = policy_guard or PolicyGuard() + self.event_bus = event_bus + self.answer_client = answer_client or AnthropicPlannerClient(settings=self.settings) async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionResult: + ask: AskProtocol | None = None + if self.event_bus is not None: + ask = AskProtocol(db=db, event_bus=self.event_bus, job_id=int(job.job_id)) + await ask.start() + try: + return await self._run(db=db, job=job, ask=ask) + finally: + if ask is not None: + await ask.aclose() + + async def _run( + self, + *, + db: AsyncSession, + job: BackgroundJob, + ask: AskProtocol | None, + ) -> AgentExecutionResult: if job.requested_by is None: raise ApiError(status_code=400, code=400, message="Agent job is missing requestedBy") request = ExecuteAgentRequest.model_validate(dict(job.payload or {})) @@ -59,10 +97,23 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe ) await db.commit() + inbox_repo = AgentInboxMessageRepository(db) if self.event_bus is not None else None + paused = False for action in actions: await db.refresh(job) if job.cancel_requested_at is not None: raise AgentJobCanceled() + if inbox_repo is not None: + paused, skip_current = await self._handle_step_boundary_controls( + db=db, + job=job, + inbox_repo=inbox_repo, + action=action, + warnings=warnings, + paused=paused, + ) + if skip_current: + continue decision = await self.policy_guard.evaluate_tool_call( tool_name=action.tool, @@ -98,6 +149,13 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe error_message=f"{type(exc).__name__}: {exc}"[:2000], ) await db.commit() + await self._publish_tool( + "tool.failed", + job_id=int(job.job_id), + step=action.step, + tool=action.tool, + payload={"errorMessage": f"{type(exc).__name__}: {exc}"[:2000]}, + ) raise await action_logs.append_step( @@ -109,6 +167,14 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe started_at=started, ) await db.commit() + await self._publish_tool( + "tool.started", + job_id=int(job.job_id), + step=action.step, + tool=action.tool, + payload={"input": resolved_input}, + emitted_at=started, + ) try: output = await router.dispatch( @@ -126,6 +192,13 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe error_message=f"{type(exc).__name__}: {exc}"[:2000], ) await db.commit() + await self._publish_tool( + "tool.failed", + job_id=int(job.job_id), + step=action.step, + tool=action.tool, + payload={"errorMessage": f"{type(exc).__name__}: {exc}"[:2000]}, + ) raise safe_output = jsonable_encoder(output) @@ -138,6 +211,13 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe duration_ms=duration_ms, ) await db.commit() + await self._publish_tool( + "tool.succeeded", + job_id=int(job.job_id), + step=action.step, + tool=action.tool, + payload={"output": safe_output, "durationMs": duration_ms}, + ) step_outputs[action.step] = safe_output applied += 1 @@ -146,7 +226,12 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe warnings.append(f"{skipped} action(s) were skipped.") await work_sessions.close_session(job_id=int(job.job_id), status="closed") await db.commit() - answer = _build_execution_answer(actions=actions, step_outputs=step_outputs) + answer = await _build_execution_answer( + task_input=str(getattr(plan, "input_text", "") or ""), + actions=actions, + step_outputs=step_outputs, + answer_client=self.answer_client, + ) return AgentExecutionResult( plan_job_id=str(plan_job_id), execute_job_id=str(job.job_id), @@ -158,6 +243,96 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe finished_at=datetime.now(UTC), ) + async def _handle_step_boundary_controls( + self, + *, + db: AsyncSession, + job: BackgroundJob, + inbox_repo: AgentInboxMessageRepository, + action: AgentProposedAction, + warnings: list[str], + paused: bool, + ) -> tuple[bool, bool]: + while True: + skip_current = False + pending = await inbox_repo.list_pending_controls(job_id=int(job.job_id)) + for ctrl in pending: + kind = AgentInboxKind(ctrl.kind) + if kind == AgentInboxKind.CONTROL_CANCEL: + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + job.cancel_requested_at = datetime.now(UTC) + await db.commit() + raise AgentJobCanceled() + if kind == AgentInboxKind.CONTROL_PAUSE: + paused = True + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await self._publish_state("agent.paused", job_id=int(job.job_id)) + elif kind == AgentInboxKind.CONTROL_RESUME: + paused = False + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await self._publish_state("agent.resumed", job_id=int(job.job_id)) + elif kind == AgentInboxKind.CONTROL_SKIP: + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + warnings.append(f"Step {action.step} skipped by user") + skip_current = True + else: + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await db.commit() + if skip_current: + return paused, True + if not paused: + return paused, False + await asyncio.sleep(0.1) + + async def _publish_state(self, event_type: str, *, job_id: int) -> None: + if self.event_bus is None: + return + try: + await self.event_bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload={}, + emitted_at=datetime.now(UTC), + ) + ) + except Exception: + logger.exception( + "Failed to publish state event jobId=%s eventType=%s", + job_id, + event_type, + ) + + async def _publish_tool( + self, + event_type: str, + *, + job_id: int, + step: int, + tool: str, + payload: dict[str, Any], + emitted_at: datetime | None = None, + ) -> None: + if self.event_bus is None: + return + try: + await self.event_bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload={"step": int(step), "tool": str(tool), **payload}, + emitted_at=emitted_at or datetime.now(UTC), + ) + ) + except Exception: + logger.exception( + "Failed to publish tool event jobId=%s eventType=%s step=%s tool=%s", + job_id, + event_type, + step, + tool, + ) + def _parse_job_id(raw: str) -> int: try: @@ -226,35 +401,89 @@ def _resolve_references( return value -def _build_execution_answer( +async def _build_execution_answer( *, + task_input: str = "", actions: list[AgentProposedAction], step_outputs: dict[int, dict[str, Any]], + answer_client: AnswerClient, ) -> str | None: - for action in actions: - if action.tool != "drive.countFiles": - continue - output = step_outputs.get(action.step) - if not isinstance(output, dict): - continue - return _count_files_answer(output) - - if actions and all(action.side_effect == "read" for action in actions): - return f"已完成 {len(step_outputs)} 个只读操作。" - return None - - -def _count_files_answer(output: dict[str, Any]) -> str: - total_items = int(output.get("totalItems") or 0) - category = str(output.get("category") or "").strip().lower() - if category == "video": - return f"你上传了 {total_items} 部电影(按视频文件统计)。" - if category == "audio": - return f"你上传了 {total_items} 个音频文件。" - if category == "image": - return f"你上传了 {total_items} 张图片。" - if category == "document": - return f"你上传了 {total_items} 个文档。" - if category == "archive": - return f"你上传了 {total_items} 个压缩包。" - return f"你上传了 {total_items} 个文件。" + if not actions: + return None + user_prompt = _answer_user_prompt( + task_input=task_input, + actions=actions, + step_outputs=step_outputs, + ) + text = await answer_client.create_answer( + system_prompt=_answer_system_prompt(), + user_prompt=user_prompt, + max_tokens=640, + reasoning_effort="low", + ) + answer = _normalize_answer(text) + if answer is None: + raise ApiError(status_code=502, code=502, message="Agent answer model returned empty response") + return answer + + +def _answer_system_prompt() -> str: + return ( + "You are FileFlash execution answer generator. " + "Only describe results that are present in tool outputs. " + "Do not invent filenames, counts, or paths. " + "Keep the response concise and user-facing in the same language as the user input." + ) + + +def _answer_user_prompt( + *, + task_input: str, + actions: list[AgentProposedAction], + step_outputs: dict[int, dict[str, Any]], +) -> str: + payload_actions: list[dict[str, Any]] = [] + for action in sorted(actions, key=lambda item: item.step): + payload_actions.append( + { + "step": action.step, + "tool": action.tool, + "sideEffect": action.side_effect, + "input": action.input, + "output": _compact_output(step_outputs.get(action.step)), + } + ) + payload = { + "task": task_input, + "actions": payload_actions, + "responseGuidance": { + "includeNamesWhenAvailable": True, + "mentionTruncationWhenProvided": True, + "ifAmbiguous": "state candidate count and ask for clarification", + }, + } + return json.dumps(payload, ensure_ascii=False, sort_keys=True) + + +def _compact_output(value: dict[str, Any] | None) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + text = json.dumps(value, ensure_ascii=False, separators=(",", ":")) + if len(text) <= 12_000: + return value + compact = dict(value) + compact["truncated"] = True + compact["truncatedFields"] = sorted(compact.keys())[:16] + compact.pop("items", None) + compact.pop("sampleItems", None) + return compact + + +def _normalize_answer(text: str) -> str | None: + candidate = str(text or "").strip() + if not candidate: + return None + candidate = " ".join(candidate.split()) + if len(candidate) > 1200: + candidate = candidate[:1200].rstrip() + "…" + return candidate diff --git a/app/src/fileflash/agents/runtime/llm.py b/app/src/fileflash/agents/runtime/llm.py index 4a268ee..1596728 100644 --- a/app/src/fileflash/agents/runtime/llm.py +++ b/app/src/fileflash/agents/runtime/llm.py @@ -1,6 +1,8 @@ from __future__ import annotations import json +import logging +from collections.abc import Awaitable, Callable from typing import Any, Protocol import anthropic @@ -9,6 +11,9 @@ from ...core.errors import ApiError from ...core.settings import Settings +ToolExecutor = Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]] +logger = logging.getLogger(__name__) + class PlannerClient(Protocol): async def create_plan( @@ -18,9 +23,23 @@ async def create_plan( user_prompt: str, max_tokens: int, reasoning_effort: str = "adaptive", + tools: list[dict[str, Any]] | None = None, + tool_executor: ToolExecutor | None = None, + max_tool_roundtrips: int = 4, ) -> dict[str, Any]: ... +class AnswerClient(Protocol): + async def create_answer( + self, + *, + system_prompt: str, + user_prompt: str, + max_tokens: int, + reasoning_effort: str = "adaptive", + ) -> str: ... + + class AnthropicPlannerClient: def __init__(self, *, settings: Settings, client: AsyncAnthropic | None = None) -> None: self.settings = settings @@ -33,14 +52,94 @@ async def create_plan( user_prompt: str, max_tokens: int, reasoning_effort: str = "adaptive", + tools: list[dict[str, Any]] | None = None, + tool_executor: ToolExecutor | None = None, + max_tool_roundtrips: int = 4, ) -> dict[str, Any]: api_key = (self.settings.agent_llm_api_key or "").strip() if not api_key: raise ApiError(status_code=503, code=503, message="Agent LLM API key is not configured") + plan_token_cap = _safe_plan_token_cap(self.settings) + + request_kwargs: dict[str, Any] = { + "model": self.settings.agent_llm_model, + "max_tokens": min(max_tokens, plan_token_cap), + "system": system_prompt, + "messages": [{"role": "user", "content": user_prompt}], + "timeout": 60.0, + } + tool_name_map = _tool_name_map(tools or []) + if tools: + request_kwargs["tools"] = _anthropic_tools_payload(tools) + request_kwargs["tool_choice"] = {"type": "auto"} + request_kwargs.update(_reasoning_params(reasoning_effort)) + + try: + parsed, usage = await self._request_and_parse_plan( + api_key=api_key, + request_kwargs=request_kwargs, + tool_name_map=tool_name_map, + tool_executor=tool_executor, + max_tool_roundtrips=max_tool_roundtrips, + ) + except ApiError as first_error: + if not _is_retryable_output_error(first_error): + raise + logger.warning( + "Planner LLM retrying attempt=%s reason=%s degraded=%s jsonOnly=%s", + 2, + first_error.message, + True, + False, + ) + degraded_kwargs = _degraded_plan_request_kwargs(request_kwargs) + try: + parsed, usage = await self._request_and_parse_plan( + api_key=api_key, + request_kwargs=degraded_kwargs, + tool_name_map=tool_name_map, + tool_executor=tool_executor, + max_tool_roundtrips=max_tool_roundtrips, + ) + except ApiError as second_error: + if not _is_retryable_output_error(second_error): + raise + logger.warning( + "Planner LLM retrying attempt=%s reason=%s degraded=%s jsonOnly=%s", + 3, + second_error.message, + True, + True, + ) + strict_kwargs = _strict_json_retry_kwargs( + degraded_kwargs, + max_tokens_cap=plan_token_cap, + ) + parsed, usage = await self._request_and_parse_plan( + api_key=api_key, + request_kwargs=strict_kwargs, + tool_name_map=tool_name_map, + tool_executor=tool_executor, + max_tool_roundtrips=max_tool_roundtrips, + ) + if isinstance(usage, dict): + parsed["_usage"] = usage + return parsed + async def create_answer( + self, + *, + system_prompt: str, + user_prompt: str, + max_tokens: int, + reasoning_effort: str = "adaptive", + ) -> str: + api_key = (self.settings.agent_llm_api_key or "").strip() + if not api_key: + raise ApiError(status_code=503, code=503, message="Agent LLM API key is not configured") request_kwargs: dict[str, Any] = { "model": self.settings.agent_llm_model, - "max_tokens": min(max_tokens, 4096), + "max_tokens": min(max_tokens, 1024), "system": system_prompt, "messages": [{"role": "user", "content": user_prompt}], "timeout": 60.0, @@ -48,7 +147,7 @@ async def create_plan( request_kwargs.update(_reasoning_params(reasoning_effort)) message = await self._request_plan(api_key=api_key, request_kwargs=request_kwargs) try: - parsed, usage = _parse_plan_message(message) + return _extract_text(message) except ApiError as exc: if not _is_retryable_output_error(exc): raise @@ -56,10 +155,7 @@ async def create_plan( degraded_kwargs.pop("thinking", None) degraded_kwargs.pop("output_config", None) message = await self._request_plan(api_key=api_key, request_kwargs=degraded_kwargs) - parsed, usage = _parse_plan_message(message) - if isinstance(usage, dict): - parsed["_usage"] = usage - return parsed + return _extract_text(message) async def _request_plan(self, *, api_key: str, request_kwargs: dict[str, Any]) -> Any: try: @@ -93,6 +189,98 @@ def _get_client(self, api_key: str) -> AsyncAnthropic: ) return self._client + async def _parse_plan_response( + self, + *, + api_key: str, + request_kwargs: dict[str, Any], + message: Any, + tool_name_map: dict[str, str], + tool_executor: ToolExecutor | None, + max_tool_roundtrips: int, + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + if tool_executor is None: + return _parse_plan_message(message, tool_name_map=tool_name_map) + tool_calls = _extract_tool_use_calls(message=message, tool_name_map=tool_name_map) + if not tool_calls: + return _parse_plan_message(message, tool_name_map=tool_name_map) + return await self._run_tool_loop( + api_key=api_key, + request_kwargs=request_kwargs, + initial_message=message, + tool_name_map=tool_name_map, + tool_executor=tool_executor, + max_tool_roundtrips=max_tool_roundtrips, + ) + + async def _request_and_parse_plan( + self, + *, + api_key: str, + request_kwargs: dict[str, Any], + tool_name_map: dict[str, str], + tool_executor: ToolExecutor | None, + max_tool_roundtrips: int, + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + message = await self._request_plan(api_key=api_key, request_kwargs=request_kwargs) + return await self._parse_plan_response( + api_key=api_key, + request_kwargs=request_kwargs, + message=message, + tool_name_map=tool_name_map, + tool_executor=tool_executor, + max_tool_roundtrips=max_tool_roundtrips, + ) + + async def _run_tool_loop( + self, + *, + api_key: str, + request_kwargs: dict[str, Any], + initial_message: Any, + tool_name_map: dict[str, str], + tool_executor: ToolExecutor, + max_tool_roundtrips: int, + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + max_rounds = max(1, min(int(max_tool_roundtrips or 0), 12)) + base_messages = request_kwargs.get("messages") + if not isinstance(base_messages, list): + raise ApiError(status_code=502, code=502, message="Agent LLM returned an invalid response") + messages: list[dict[str, Any]] = list(base_messages) + usage_total: dict[str, int] = {} + current_message = initial_message + + for _ in range(max_rounds): + usage_total = _merge_usage_totals(usage_total, _usage_payload(current_message)) + tool_calls = _extract_tool_use_calls(message=current_message, tool_name_map=tool_name_map) + if not tool_calls: + parsed, _ = _parse_plan_message(current_message, tool_name_map=tool_name_map) + return parsed, usage_total or None + + assistant_content = _content_block_mappings(current_message) + if assistant_content: + messages.append({"role": "assistant", "content": assistant_content}) + tool_results: list[dict[str, Any]] = [] + for call in tool_calls: + tool_output = await tool_executor(call["tool"], call["input"]) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": call["id"], + "content": _tool_result_content(tool_output), + } + ) + messages.append({"role": "user", "content": tool_results}) + loop_kwargs = dict(request_kwargs) + loop_kwargs["messages"] = messages + current_message = await self._request_plan(api_key=api_key, request_kwargs=loop_kwargs) + + raise ApiError( + status_code=502, + code=502, + message="Agent LLM exceeded planning tool rounds", + ) + def _extract_text(message: Any) -> str: chunks = getattr(message, "content", None) @@ -112,6 +300,69 @@ def _extract_text(message: Any) -> str: return text +def _content_block_mappings(message: Any) -> list[dict[str, Any]]: + chunks = getattr(message, "content", None) + if isinstance(chunks, str): + return [{"type": "text", "text": chunks}] + if not isinstance(chunks, list): + return [] + blocks: list[dict[str, Any]] = [] + for chunk in chunks: + if isinstance(chunk, dict): + blocks.append(chunk) + continue + if hasattr(chunk, "model_dump"): + dumped = chunk.model_dump() + if isinstance(dumped, dict): + blocks.append(dumped) + continue + blocks.append( + { + "type": getattr(chunk, "type", None), + "text": getattr(chunk, "text", None), + "name": getattr(chunk, "name", None), + "input": getattr(chunk, "input", None), + "id": getattr(chunk, "id", None), + } + ) + return blocks + + +def _extract_tool_use_payload( + message: Any, + *, + tool_name_map: dict[str, str], +) -> tuple[list[dict[str, Any]], str | None]: + actions: list[dict[str, Any]] = [] + text_parts: list[str] = [] + for block in _content_block_mappings(message): + block_type = str(block.get("type") or "") + if block_type == "tool_use": + provider_name = str(block.get("name") or "").strip() + action_input = _coerce_mapping(block.get("input")) + actions.append( + { + "step": len(actions) + 1, + "tool": tool_name_map.get(provider_name, provider_name), + "input": action_input, + } + ) + continue + text_parts.extend(_extract_text_parts_from_mapping(block)) + summary = "\n".join(part for part in text_parts if part).strip() + return actions, summary or None + + +def _coerce_mapping(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + if hasattr(value, "model_dump"): + dumped = value.model_dump() + if isinstance(dumped, dict): + return dumped + return {} + + def _extract_text_parts(chunk: Any) -> list[str]: if chunk is None: return [] @@ -193,6 +444,34 @@ def _usage_payload(message: Any) -> dict[str, Any] | None: return payload or None +def _merge_usage_totals(base: dict[str, int], extra: dict[str, Any] | None) -> dict[str, int]: + merged = dict(base) + if not isinstance(extra, dict): + return merged + for key in ( + "input_tokens", + "output_tokens", + "cache_creation_input_tokens", + "cache_read_input_tokens", + ): + value = extra.get(key) + if value is None: + continue + try: + parsed = int(value) + except (TypeError, ValueError): + continue + merged[key] = int(merged.get(key) or 0) + parsed + return merged + + +def _tool_result_content(payload: dict[str, Any]) -> list[dict[str, str]]: + text = json.dumps(payload, ensure_ascii=False) + if len(text) > 12_000: + text = text[:12_000] + "…" + return [{"type": "text", "text": text}] + + def _reasoning_params(reasoning_effort: str) -> dict[str, Any]: effort = (reasoning_effort or "adaptive").strip().lower() if effort == "adaptive": @@ -211,13 +490,83 @@ def _response_details(error: anthropic.APIStatusError) -> str: return str(text or "")[:800] -def _parse_plan_message(message: Any) -> tuple[dict[str, Any], dict[str, Any] | None]: +def _parse_plan_message( + message: Any, + *, + tool_name_map: dict[str, str] | None = None, +) -> tuple[dict[str, Any], dict[str, Any] | None]: + tool_actions, summary = _extract_tool_use_payload( + message, + tool_name_map=tool_name_map or {}, + ) + usage = _usage_payload(message) + if tool_actions: + return { + "summary": summary or f"Prepared {len(tool_actions)} file action(s).", + "proposedActions": tool_actions, + }, usage text = _extract_text(message) parsed = _parse_json_text(text) - usage = _usage_payload(message) return parsed, usage +def _extract_tool_use_calls( + message: Any, + *, + tool_name_map: dict[str, str], +) -> list[dict[str, Any]]: + calls: list[dict[str, Any]] = [] + for block in _content_block_mappings(message): + if str(block.get("type") or "") != "tool_use": + continue + provider_name = str(block.get("name") or "").strip() + if not provider_name: + continue + tool_use_id = str(block.get("id") or "").strip() + if not tool_use_id: + tool_use_id = f"tool_use_{len(calls) + 1}" + calls.append( + { + "id": tool_use_id, + "tool": tool_name_map.get(provider_name, provider_name), + "input": _coerce_mapping(block.get("input")), + } + ) + return calls + + +def _tool_name_map(tools: list[dict[str, Any]]) -> dict[str, str]: + mapping: dict[str, str] = {} + for tool in tools: + provider_name = str(tool.get("name") or "").strip() + if not provider_name: + continue + internal_name = str( + tool.get("internalName") + or tool.get("internal_name") + or tool.get("tool") + or provider_name + ).strip() + mapping[provider_name] = internal_name or provider_name + return mapping + + +def _anthropic_tools_payload(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + payload: list[dict[str, Any]] = [] + for tool in tools: + name = str(tool.get("name") or "").strip() + if not name: + continue + payload.append( + { + "name": name, + "description": str(tool.get("description") or ""), + "input_schema": dict(tool.get("input_schema") or {"type": "object"}), + } + ) + return payload + + def _is_retryable_output_error(error: ApiError) -> bool: if error.status_code != 502: return False @@ -229,26 +578,137 @@ def _is_retryable_output_error(error: ApiError) -> bool: } +def _safe_plan_token_cap(settings: Settings) -> int: + raw_value = getattr(settings, "agent_llm_plan_max_tokens", 8192) + try: + parsed = int(raw_value) + except (TypeError, ValueError): + return 8192 + return max(1, parsed) + + +def _degraded_plan_request_kwargs(request_kwargs: dict[str, Any]) -> dict[str, Any]: + degraded_kwargs = dict(request_kwargs) + degraded_kwargs.pop("thinking", None) + degraded_kwargs.pop("output_config", None) + return degraded_kwargs + + +def _strict_json_retry_kwargs( + request_kwargs: dict[str, Any], + *, + max_tokens_cap: int, +) -> dict[str, Any]: + strict_kwargs = dict(request_kwargs) + strict_kwargs["max_tokens"] = max(1, int(max_tokens_cap)) + messages = request_kwargs.get("messages") + strict_kwargs["messages"] = _append_json_only_retry_instruction(messages) + return strict_kwargs + + +def _append_json_only_retry_instruction(messages: Any) -> list[dict[str, Any]]: + instruction = ( + "Return ONLY one valid JSON object that matches outputSchema. " + "Do not include markdown fences, prose, or extra text." + ) + if not isinstance(messages, list): + return [{"role": "user", "content": instruction}] + cloned: list[dict[str, Any]] = [] + for item in messages: + if isinstance(item, dict): + cloned.append(dict(item)) + for idx in range(len(cloned) - 1, -1, -1): + if cloned[idx].get("role") != "user": + continue + content = cloned[idx].get("content") + if isinstance(content, str): + merged = content.rstrip() + if merged: + merged = f"{merged}\n\n{instruction}" + else: + merged = instruction + cloned[idx]["content"] = merged + return cloned + cloned.append({"role": "user", "content": instruction}) + return cloned + + def _parse_json_text(text: str) -> dict[str, Any]: - candidate = text.strip() - if candidate.startswith("```"): - lines = candidate.splitlines() - if lines and lines[0].startswith("```"): - lines = lines[1:] - if lines and lines[-1].startswith("```"): - lines = lines[:-1] - candidate = "\n".join(lines).strip() + candidate = _strip_code_fences(text) try: - parsed = json.loads(candidate) - except json.JSONDecodeError as exc: + return _decode_json_object(candidate) + except ApiError: + raise + except json.JSONDecodeError as decode_error: + extracted = _extract_balanced_json_object(candidate) + if extracted is not None: + try: + return _decode_json_object(extracted) + except ApiError: + raise + except json.JSONDecodeError: + pass raise ApiError( status_code=502, code=502, message="Agent LLM did not return valid JSON", - ) from exc + ) from decode_error + + +def _decode_json_object(candidate: str) -> dict[str, Any]: + parsed = json.loads(candidate) if not isinstance(parsed, dict): raise ApiError(status_code=502, code=502, message="Agent LLM JSON must be an object") return parsed -__all__ = ["AnthropicPlannerClient", "PlannerClient"] +def _strip_code_fences(text: str) -> str: + candidate = text.strip() + if not candidate.startswith("```"): + return candidate + lines = candidate.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].startswith("```"): + lines = lines[:-1] + return "\n".join(lines).strip() + + +def _extract_balanced_json_object(text: str) -> str | None: + start = -1 + depth = 0 + in_string = False + escaped = False + for idx, ch in enumerate(text): + if start < 0: + if ch == "{": + start = idx + depth = 1 + in_string = False + escaped = False + continue + if in_string: + if escaped: + escaped = False + elif ch == "\\": + escaped = True + elif ch == '"': + in_string = False + continue + if ch == '"': + in_string = True + continue + if ch == "{": + depth += 1 + continue + if ch == "}": + depth -= 1 + if depth == 0: + return text[start : idx + 1] + if depth < 0: + start = -1 + depth = 0 + return None + + +__all__ = ["AnswerClient", "AnthropicPlannerClient", "PlannerClient", "ToolExecutor"] diff --git a/app/src/fileflash/agents/runtime/plan_runner.py b/app/src/fileflash/agents/runtime/plan_runner.py index e3421d0..6caaa75 100644 --- a/app/src/fileflash/agents/runtime/plan_runner.py +++ b/app/src/fileflash/agents/runtime/plan_runner.py @@ -1,6 +1,7 @@ from __future__ import annotations import hashlib +import inspect import json from datetime import UTC, datetime from typing import Any @@ -19,26 +20,19 @@ from ...schemas.agent import ( AgentChosenSkill, AgentCostEstimate, + AgentPlanningEvidence, AgentPlanResult, AgentProposedAction, PlanAgentRequest, ) +from ..harness.ask import AskProtocol +from ..harness.event_bus import AgentEventBus from ..harness.policy import classify_tool_side_effect, normalize_action_risk +from ..harness.router import ToolCall, ToolRouter +from ..harness.tool_registry import REGISTRY from .llm import AnthropicPlannerClient, PlannerClient from .reference_rules import is_symbolic_id_placeholder, parse_step_reference -DEFAULT_AGENT_TOOLS = ( - "drive.listFolder", - "drive.countFiles", - "drive.createFolder", - "drive.moveFile", - "drive.moveFolder", - "drive.renameFile", - "drive.renameFolder", - "drive.deleteFile", - "drive.deleteFolder", -) - class PlanRunner: def __init__( @@ -46,11 +40,30 @@ def __init__( *, settings: Settings | None = None, planner_client: PlannerClient | None = None, + event_bus: AgentEventBus | None = None, ) -> None: self.settings = settings or get_settings() self.planner_client = planner_client or AnthropicPlannerClient(settings=self.settings) + self.event_bus = event_bus async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: + ask: AskProtocol | None = None + if self.event_bus is not None: + ask = AskProtocol(db=db, event_bus=self.event_bus, job_id=int(job.job_id)) + await ask.start() + try: + return await self._run(db=db, job=job, ask=ask) + finally: + if ask is not None: + await ask.aclose() + + async def _run( + self, + *, + db: AsyncSession, + job: BackgroundJob, + ask: AskProtocol | None, + ) -> AgentPlanResult: if job.requested_by is None: raise ApiError(status_code=400, code=400, message="Agent job is missing requestedBy") @@ -64,26 +77,60 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: ) metadata = await _collect_context_metadata(db, user_id=user_id, request=request) allowed_tools = _skill_tool_whitelist(skill) - try: - llm_payload = await self.planner_client.create_plan( - system_prompt=_system_prompt(), - user_prompt=_user_prompt( - request=request, - skill=skill, - allowed_tools=allowed_tools, - metadata=metadata, - ), - max_tokens=request.hints.budget_tokens, - reasoning_effort=request.hints.reasoning_effort, - ) - except ApiError as exc: - if exc.status_code != 502: - raise - llm_payload = _safe_fallback_payload( + allowed_tool_set = set(allowed_tools) + planner_router = ToolRouter(db=db, user_id=user_id) + tool_call_budget = min(self.settings.agent_job_max_tool_calls, 32) + planned_tool_calls = 0 + planning_evidence: list[AgentPlanningEvidence] = [] + + async def _planning_tool_executor(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + nonlocal planned_tool_calls + if tool_name not in allowed_tool_set: + raise ApiError( + status_code=400, + code=400, + message=f"Planner attempted disallowed tool: {tool_name}", + ) + spec = REGISTRY.get(tool_name) + if spec.side_effect != "read": + raise ApiError( + status_code=400, + code=400, + message=f"Planner exploratory tool call must be read-only: {tool_name}", + ) + planned_tool_calls += 1 + if planned_tool_calls > tool_call_budget: + raise ApiError( + status_code=400, + code=400, + message="Planner exceeded exploratory tool-call budget", + ) + output = await planner_router.dispatch(ToolCall(tool_name=tool_name, arguments=args)) + if len(planning_evidence) < 12: + planning_evidence.append( + AgentPlanningEvidence( + step=planned_tool_calls, + tool=tool_name, + input=_evidence_mapping(args), + output_preview=_evidence_preview(output), + ) + ) + return output + + llm_payload = await self.planner_client.create_plan( + system_prompt=_system_prompt(), + user_prompt=_user_prompt( request=request, - metadata=metadata, + skill=skill, allowed_tools=allowed_tools, - ) + metadata=metadata, + ), + max_tokens=request.hints.budget_tokens, + reasoning_effort=request.hints.reasoning_effort, + tools=REGISTRY.anthropic_tools_for(allowed_tools), + tool_executor=_planning_tool_executor, + max_tool_roundtrips=6, + ) actions = _normalize_actions( llm_payload=llm_payload, @@ -91,11 +138,19 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: max_steps=min(request.hints.max_steps, self.settings.agent_job_max_tool_calls), ) chosen_skill = _chosen_skill(skill) - summary = str( + llm_summary = str( llm_payload.get("summary") or f"Prepared {len(actions)} file action(s)." ).strip() - if not summary: - summary = f"Prepared {len(actions)} file action(s)." + if not llm_summary: + llm_summary = f"Prepared {len(actions)} file action(s)." + summary = llm_summary + if _has_write_actions(actions): + summary = await _grounded_write_summary( + db, + user_id=user_id, + actions=actions, + fallback_summary=llm_summary, + ) requires_confirmation = ( request.execution_policy != "autopilot" @@ -115,6 +170,7 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: summary=summary, requires_confirmation=requires_confirmation, cost_estimate=cost_estimate, + planning_evidence=planning_evidence or None, ) await _upsert_agent_plan( db, @@ -129,6 +185,21 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: raise return result + async def _ask( + self, + *, + ask: AskProtocol | None, + prompt: str, + schema: dict[str, Any], + ) -> Any | None: + if ask is None: + return None + return await ask.ask( + prompt=prompt, + schema=schema, + timeout_sec=float(self.settings.agent_inbox_ask_timeout_sec), + ) + async def _choose_skill( db: AsyncSession, @@ -193,10 +264,16 @@ def _skill_tool_whitelist(skill: AgentSkill | AgentSkillCatalogEntry | None) -> raw = skill.tool_whitelist_json if isinstance(raw, list) and raw: tools = tuple(str(item) for item in raw if str(item).strip()) - if "drive.countFiles" not in tools: - return (*tools, "drive.countFiles") + unknown = REGISTRY.unknown_names(tools) + if unknown: + raise ApiError( + status_code=422, + code=422, + message="Unknown agent tool in selected skill", + data={"unknownTools": sorted(unknown)}, + ) return tools - return DEFAULT_AGENT_TOOLS + return REGISTRY.all_names() def _chosen_skill(skill: AgentSkill | AgentSkillCatalogEntry | None) -> AgentChosenSkill | None: @@ -368,9 +445,8 @@ def _folder_metadata(row: Folder) -> dict[str, Any]: def _system_prompt() -> str: return ( - "You are FileFlash Agent Planner. Return only JSON. " - "Plan file-management actions or read-only answers using the provided tools and metadata. " - "For count/how many questions, prefer drive.countFiles over listing folders. " + "You are FileFlash Agent Planner. Build plans from tool-grounded facts, not assumptions. " + "If you need facts, first call read-only tools; then output one final JSON object that matches outputSchema. " "Do not read or infer file contents. Deletions are high risk and must be explicit. " "Cross-step dependencies must use '$stepN.field' references only and never symbolic placeholders " "like 'newFolderId'." @@ -392,6 +468,17 @@ def _user_prompt( "skill": _skill_payload(skill), "allowedTools": list(allowed_tools), "toolSchemas": _tool_schemas(allowed_tools), + "toolUseMode": ( + "Use read-only tools first when facts are missing. " + "Write tools must appear only in final proposedActions, not exploratory tool_use steps." + ), + "plannerDefaults": { + "organizeRequest": { + "scope": "root recursive unless user constrained the scope", + "folderNaming": "reuse existing folders first; create english category folders if missing", + "writePlanPolicy": "generate executable write actions by default", + } + }, "referenceContract": { "syntax": "$stepN.field", "rules": [ @@ -449,89 +536,7 @@ def _skill_payload(skill: AgentSkill | AgentSkillCatalogEntry | None) -> dict[st def _tool_schemas(allowed_tools: tuple[str, ...]) -> list[dict[str, Any]]: - descriptions = { - "drive.listFolder": "List direct folder contents by folderId.", - "drive.countFiles": ( - "Count files under folderId. Supports recursive=true and category values " - "video, audio, image, document, archive, other. Use category=video for movie/电影 questions." - ), - "drive.createFolder": "Create a folder under parentFolderId with name.", - "drive.moveFile": "Move fileId into targetFolderId.", - "drive.moveFolder": "Move folderId into targetParentId.", - "drive.renameFile": "Rename fileId to fileName.", - "drive.renameFolder": "Rename folderId to folderName.", - "drive.deleteFile": "Soft-delete fileId into recycle bin. High risk.", - "drive.deleteFolder": "Soft-delete folderId into recycle bin. High risk.", - } - return [{"tool": tool, "description": descriptions.get(tool, "")} for tool in allowed_tools] - - -def _safe_fallback_payload( - *, - request: PlanAgentRequest, - metadata: dict[str, Any], - allowed_tools: tuple[str, ...], -) -> dict[str, Any]: - fallback_actions: list[dict[str, Any]] = [] - if "drive.countFiles" in allowed_tools and _looks_like_count_question(request.input): - fallback_actions.append( - { - "step": 1, - "tool": "drive.countFiles", - "input": { - "folderId": metadata.get("rootFolderId") or request.context.root_folder_id or "root", - "recursive": True, - "category": _fallback_count_category(request.input), - }, - "sideEffect": "read", - "riskLevel": "low", - "requiresConfirmation": False, - } - ) - return { - "summary": "Planner fallback mode: generated a safe read-only count plan.", - "proposedActions": fallback_actions, - } - if "drive.listFolder" in allowed_tools: - root_folder_id = str( - metadata.get("rootFolderId") - or request.context.root_folder_id - or "root" - ) - fallback_actions.append( - { - "step": 1, - "tool": "drive.listFolder", - "input": {"folderId": root_folder_id}, - "sideEffect": "read", - "riskLevel": "low", - "requiresConfirmation": False, - } - ) - return { - "summary": "Planner fallback mode: generated a safe read-only plan.", - "proposedActions": fallback_actions, - } - - -def _looks_like_count_question(text: str) -> bool: - normalized = text.lower() - return any(token in normalized for token in ("多少", "几个", "几部", "count", "how many", "number of")) - - -def _fallback_count_category(text: str) -> str | None: - normalized = text.lower() - if any(token in normalized for token in ("电影", "影片", "视频", "movie", "film", "video")): - return "video" - if any(token in normalized for token in ("图片", "照片", "image", "photo", "picture")): - return "image" - if any(token in normalized for token in ("音频", "音乐", "audio", "music")): - return "audio" - if any(token in normalized for token in ("文档", "document", "doc")): - return "document" - if any(token in normalized for token in ("压缩", "archive", "zip")): - return "archive" - return None + return REGISTRY.schemas_for(allowed_tools) def _normalize_actions( @@ -668,6 +673,300 @@ def _validate_action_input_value( ) +def _has_write_actions(actions: list[AgentProposedAction]) -> bool: + return any(action.side_effect == "write" for action in actions) + + +async def _grounded_write_summary( + db: AsyncSession, + *, + user_id: int, + actions: list[AgentProposedAction], + fallback_summary: str, +) -> str: + write_actions = [action for action in actions if action.side_effect == "write"] + if not write_actions: + return fallback_summary + + created_folder_names: list[str] = [] + created_folder_by_step: dict[int, str] = {} + move_file_actions: list[AgentProposedAction] = [] + move_folder_actions: list[AgentProposedAction] = [] + file_ids: set[int] = set() + folder_ids: set[int] = set() + + for action in write_actions: + if action.tool == "drive.createFolder": + folder_name = str(action.input.get("name") or action.input.get("folderName") or "").strip() + if folder_name: + created_folder_names.append(folder_name) + created_folder_by_step[action.step] = folder_name + continue + if action.tool == "drive.moveFile": + move_file_actions.append(action) + file_id = _coerce_positive_int(action.input.get("fileId")) + if file_id is not None: + file_ids.add(file_id) + target_folder_id = action.input.get("targetFolderId") + if isinstance(target_folder_id, str): + parsed_folder_id = _coerce_positive_int(target_folder_id) + if parsed_folder_id is not None: + folder_ids.add(parsed_folder_id) + continue + if action.tool == "drive.moveFolder": + move_folder_actions.append(action) + source_folder_id = _coerce_positive_int(action.input.get("folderId")) + if source_folder_id is not None: + folder_ids.add(source_folder_id) + target_parent_id = action.input.get("targetParentId", action.input.get("targetFolderId")) + if isinstance(target_parent_id, str): + parsed_parent_id = _coerce_positive_int(target_parent_id) + if parsed_parent_id is not None: + folder_ids.add(parsed_parent_id) + + file_name_map = await _safe_fetch_file_names(db, user_id=user_id, file_ids=file_ids) + folder_name_map = await _safe_fetch_folder_names(db, user_id=user_id, folder_ids=folder_ids) + + moved_file_names: list[str] = [] + destination_folder_names: list[str] = [] + for action in move_file_actions: + file_id = _coerce_positive_int(action.input.get("fileId")) + if file_id is not None: + file_name = file_name_map.get(file_id) + if file_name: + moved_file_names.append(file_name) + destination = _resolve_destination_folder_name( + action.input.get("targetFolderId"), + created_folder_by_step=created_folder_by_step, + folder_name_map=folder_name_map, + ) + if destination: + destination_folder_names.append(destination) + + moved_folder_count = len(move_folder_actions) + moved_file_count = len(move_file_actions) + clauses: list[str] = [] + + if created_folder_names: + unique_created = _unique_preserve_order(created_folder_names) + if len(unique_created) == 1: + clauses.append(f"创建“{unique_created[0]}”文件夹") + else: + clauses.append(f"创建 {len(unique_created)} 个文件夹") + + if moved_file_count > 0: + clauses.append( + f"将{_format_moved_file_subject(moved_file_names, moved_file_count)}移动到" + f"{_format_destination_folder(destination_folder_names)}" + ) + + if moved_folder_count > 0: + clauses.append(f"移动 {moved_folder_count} 个文件夹") + + if not clauses: + return fallback_summary + return ",并".join(clauses) + "。" + + +async def _safe_fetch_file_names( + db: AsyncSession, + *, + user_id: int, + file_ids: set[int], +) -> dict[int, str]: + if not file_ids: + return {} + try: + result = await db.execute( + select(File.file_id, File.file_name).where( + and_( + File.owner_id == user_id, + File.file_id.in_(sorted(file_ids)), + File.status == FileStatus.ACTIVE, + File.is_latest.is_(True), + ) + ) + ) + rows = result.all() if hasattr(result, "all") else [] + if inspect.isawaitable(rows): + rows = await rows + except Exception: + return {} + out: dict[int, str] = {} + if not isinstance(rows, list): + return out + for row in rows: + try: + file_id = row[0] + file_name = row[1] + except Exception: + continue + parsed = _coerce_positive_int(file_id) + name = str(file_name or "").strip() + if parsed is None or not name: + continue + out[parsed] = name + return out + + +async def _safe_fetch_folder_names( + db: AsyncSession, + *, + user_id: int, + folder_ids: set[int], +) -> dict[int, str]: + if not folder_ids: + return {} + try: + result = await db.execute( + select(Folder.folder_id, Folder.folder_name).where( + and_( + Folder.owner_id == user_id, + Folder.folder_id.in_(sorted(folder_ids)), + Folder.status == FolderStatus.ACTIVE, + ) + ) + ) + rows = result.all() if hasattr(result, "all") else [] + if inspect.isawaitable(rows): + rows = await rows + except Exception: + return {} + out: dict[int, str] = {} + if not isinstance(rows, list): + return out + for row in rows: + try: + folder_id = row[0] + folder_name = row[1] + except Exception: + continue + parsed = _coerce_positive_int(folder_id) + name = str(folder_name or "").strip() + if parsed is None or not name: + continue + out[parsed] = name + return out + + +def _resolve_destination_folder_name( + raw_value: Any, + *, + created_folder_by_step: dict[int, str], + folder_name_map: dict[int, str], +) -> str | None: + if not isinstance(raw_value, str): + return None + value = raw_value.strip() + if not value: + return None + reference = parse_step_reference(value) + if reference is not None: + step, path = reference + if path and path[0].lower() in {"folderid", "id"}: + return created_folder_by_step.get(step) + return None + parsed = _coerce_positive_int(value) + if parsed is None: + return None + return folder_name_map.get(parsed) + + +def _format_moved_file_subject(file_names: list[str], total_count: int) -> str: + unique_names = _unique_preserve_order( + [name.strip() for name in file_names if isinstance(name, str) and name.strip()] + ) + if not unique_names: + return f"{total_count} 个文件" + if len(unique_names) == 1 and total_count == 1: + return f"“{unique_names[0]}”" + preview = unique_names[:3] + quoted = "、".join(f"“{name}”" for name in preview) + if total_count > len(preview): + return f"{quoted}等 {total_count} 个文件" + if total_count > len(unique_names): + return f"{quoted}共 {total_count} 个文件" + return f"{quoted}共 {total_count} 个文件" + + +def _format_destination_folder(folder_names: list[str]) -> str: + unique_names = _unique_preserve_order( + [name.strip() for name in folder_names if isinstance(name, str) and name.strip()] + ) + if not unique_names: + return "目标文件夹" + if len(unique_names) == 1: + return f"“{unique_names[0]}”文件夹" + return "多个目标文件夹" + + +def _unique_preserve_order(items: list[str]) -> list[str]: + out: list[str] = [] + seen: set[str] = set() + for item in items: + if item in seen: + continue + seen.add(item) + out.append(item) + return out + + +def _coerce_positive_int(value: Any) -> int | None: + text = str(value or "").strip() + if not text.isdigit(): + return None + parsed = int(text) + if parsed <= 0: + return None + return parsed + + +def _evidence_mapping(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + preview = _evidence_value_preview(value, depth=0) + if isinstance(preview, dict): + return preview + return {} + + +def _evidence_preview(value: Any) -> dict[str, Any]: + preview = _evidence_value_preview(value, depth=0) + if isinstance(preview, dict): + return preview + return {"value": preview} + + +def _evidence_value_preview(value: Any, *, depth: int) -> Any: + if depth >= 3: + if isinstance(value, str): + return value[:120] + ("…" if len(value) > 120 else "") + if isinstance(value, (int, float, bool)) or value is None: + return value + return str(value)[:120] + + if isinstance(value, dict): + out: dict[str, Any] = {} + items = list(value.items()) + for index, (key, item) in enumerate(items): + if index >= 12: + out["_truncatedKeys"] = len(items) - 12 + break + out[str(key)] = _evidence_value_preview(item, depth=depth + 1) + return out + + if isinstance(value, list): + preview_items = [_evidence_value_preview(item, depth=depth + 1) for item in value[:6]] + if len(value) > 6: + preview_items.append(f"...({len(value) - 6} more)") + return preview_items + + if isinstance(value, str): + return value[:200] + ("…" if len(value) > 200 else "") + if isinstance(value, (int, float, bool)) or value is None: + return value + return str(value)[:200] + + def _cost_estimate( *, llm_payload: dict[str, Any], diff --git a/app/src/fileflash/agents/tools/__init__.py b/app/src/fileflash/agents/tools/__init__.py new file mode 100644 index 0000000..7ea9a9d --- /dev/null +++ b/app/src/fileflash/agents/tools/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from . import drive as drive + +__all__ = ["drive"] diff --git a/app/src/fileflash/agents/tools/drive.py b/app/src/fileflash/agents/tools/drive.py new file mode 100644 index 0000000..f74d369 --- /dev/null +++ b/app/src/fileflash/agents/tools/drive.py @@ -0,0 +1,878 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import and_, select + +from ...core.errors import ApiError +from ...core.mime import resolve_file_mime_type +from ...models import File, Folder +from ...models.enums import FileStatus, FolderStatus, FolderType +from ...models.tables_storage import StorageObject +from ...schemas.file import ( + CreateFolderRequest, + GetFolderContentsQuery, + MoveFileRequest, + MoveFolderRequest, + RenameFileRequest, + RenameFolderRequest, +) +from ..harness.tool_registry import REGISTRY, ToolContext, ToolSpec + +_CATEGORIES = ("video", "audio", "image", "document", "archive", "other") +_COUNT_FILE_NAME_LIMIT = 12 + + +async def _list_folder(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = _first_value(args, "folderId", "parentFolderId") or "root" + query = GetFolderContentsQuery( + folder_id=str(folder_id), + page=_int_arg(args.get("page"), default=1, minimum=1), + per_page=_int_arg(args.get("perPage"), default=200, minimum=1, maximum=200), + search=_optional_text(args.get("search")), + ) + if str(folder_id) == "root": + result = await ctx.folder_service.get_root_contents(user_id=ctx.user_id, query=query) + else: + result = await ctx.folder_service.get_folder_contents(user_id=ctx.user_id, query=query) + return result.model_dump(by_alias=True, mode="json") + + +async def _count_files(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = str(_first_value(args, "folderId", "parentFolderId") or "root") + recursive = _bool_arg(args.get("recursive"), default=True) + category = _normalize_category(args.get("category")) + search = str(args.get("search") or "").strip() + root_folder_id = await _resolve_folder_id(ctx, folder_id=folder_id) + folder_ids = await _folder_scope_ids(ctx, root_folder_id=root_folder_id, recursive=recursive) + + statement = _active_files_query(ctx, folder_ids=folder_ids) + if search: + statement = statement.where(File.file_name.ilike(f"%{search}%")) + statement = statement.order_by(File.file_name.asc()) + + rows = list(await ctx.db.scalars(statement)) + by_mime_type: dict[str, int] = {} + sample_items: list[dict[str, Any]] = [] + item_names: list[str] = [] + names_truncated = False + total_items = 0 + for row in rows: + resolved_mime = _resolved_mime(row) + if category is not None and _category_for_file(row) != category: + continue + + total_items += 1 + by_mime_type[resolved_mime] = by_mime_type.get(resolved_mime, 0) + 1 + file_name = str(row.file_name or "").strip() + if file_name: + if len(item_names) < _COUNT_FILE_NAME_LIMIT: + item_names.append(file_name) + else: + names_truncated = True + if len(sample_items) < 5: + sample_items.append(await _file_payload(ctx, row, include_path=False)) + + return { + "totalItems": total_items, + "category": category, + "recursive": recursive, + "folderId": str(root_folder_id), + "search": search or None, + "byMimeType": dict(sorted(by_mime_type.items())), + "itemNames": item_names, + "itemNamesTruncated": names_truncated, + "sampleItems": sample_items, + } + + +async def _create_folder(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + name = _required_text(args, "name", "folderName") + parent_id = _first_value(args, "parentFolderId", "targetParentId", "folderId") or "root" + result = await ctx.folder_service.create_folder( + user_id=ctx.user_id, + payload=CreateFolderRequest(folder_name=name, parent_folder_id=str(parent_id)), + ) + data = result.model_dump(by_alias=True, mode="json") + data.setdefault("folderId", data.get("id")) + return data + + +async def _move_file(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + file_id = _required_text(args, "fileId", "id") + target_folder_id = _required_text(args, "targetFolderId", "targetParentId") + result = await ctx.file_service.move_file( + user_id=ctx.user_id, + file_id=file_id, + payload=MoveFileRequest( + target_folder_id=target_folder_id, + share_handling=str(args.get("shareHandling") or "keep"), + ), + ) + return result.model_dump(by_alias=True, mode="json") + + +async def _move_folder(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = _required_text(args, "folderId", "id") + target_parent_id = _required_text(args, "targetParentId", "targetFolderId") + result = await ctx.folder_service.move_folder( + user_id=ctx.user_id, + folder_id=folder_id, + payload=MoveFolderRequest( + target_parent_id=target_parent_id, + share_handling=str(args.get("shareHandling") or "keep"), + ), + ) + return result.model_dump(by_alias=True, mode="json") + + +async def _rename_file(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + file_id = _required_text(args, "fileId", "id") + file_name = _required_text(args, "fileName", "name") + result = await ctx.file_service.rename_file( + user_id=ctx.user_id, + file_id=file_id, + payload=RenameFileRequest(file_name=file_name), + ) + return result.model_dump(by_alias=True, mode="json") + + +async def _rename_folder(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = _required_text(args, "folderId", "id") + folder_name = _required_text(args, "folderName", "name") + result = await ctx.folder_service.rename_folder( + user_id=ctx.user_id, + folder_id=folder_id, + payload=RenameFolderRequest(folder_name=folder_name), + ) + return result.model_dump(by_alias=True, mode="json") + + +async def _delete_file(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + file_id = _required_text(args, "fileId", "id") + result = await ctx.file_service.delete_file(user_id=ctx.user_id, file_id=file_id) + return result.model_dump(by_alias=True, mode="json") + + +async def _delete_folder(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = _required_text(args, "folderId", "id") + result = await ctx.folder_service.delete_folder(user_id=ctx.user_id, folder_id=folder_id) + return result.model_dump(by_alias=True, mode="json") + + +async def _search_files(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + query = _optional_text(args.get("query")) or _optional_text(args.get("search")) or "" + folder_id = str(args.get("folderId") or "root") + recursive = _bool_arg(args.get("recursive"), default=True) + category = _normalize_category(args.get("category")) + mime_prefix = _optional_text(args.get("mimePrefix")) + modified_after = _parse_datetime_arg(args.get("modifiedAfter")) + limit = _int_arg(args.get("limit"), default=50, minimum=1, maximum=200) + + root_folder_id = await _resolve_folder_id(ctx, folder_id=folder_id) + folder_ids = await _folder_scope_ids(ctx, root_folder_id=root_folder_id, recursive=recursive) + statement = _active_files_query(ctx, folder_ids=folder_ids) + if query: + statement = statement.where(File.file_name.ilike(f"%{query}%")) + if modified_after is not None: + statement = statement.where(File.updated_at >= modified_after) + statement = statement.order_by(File.file_name.asc()) + + items: list[dict[str, Any]] = [] + for row in list(await ctx.db.scalars(statement)): + mime_type = _resolved_mime(row) + if mime_prefix and not mime_type.lower().startswith(mime_prefix.lower()): + continue + if category is not None and _category_for_file(row) != category: + continue + items.append(await _file_payload(ctx, row, include_path=True)) + if len(items) >= limit: + break + + return { + "items": items, + "totalItems": len(items), + "query": query or None, + "folderId": str(root_folder_id), + "recursive": recursive, + "category": category, + "mimePrefix": mime_prefix, + "modifiedAfter": modified_after.isoformat() if modified_after else None, + } + + +async def _get_file_info(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + file_id = _parse_positive_int(_required_text(args, "fileId", "id"), "fileId") + row = await ctx.db.scalar( + select(File).where( + and_( + File.file_id == file_id, + File.owner_id == ctx.user_id, + File.status == FileStatus.ACTIVE, + File.is_latest.is_(True), + ) + ) + ) + if row is None: + raise ApiError(status_code=404, code=404, message="File not found") + storage = await ctx.db.get(StorageObject, int(row.storage_object_id)) + payload = await _file_payload(ctx, row, include_path=True) + payload.update( + { + "objectHash": str(storage.object_hash) if storage and storage.object_hash else None, + "hashAlgorithm": str(storage.hash_algorithm) if storage else None, + "storageObjectId": str(row.storage_object_id), + "category": _category_for_file(row), + } + ) + return payload + + +async def _list_recent(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + limit = _int_arg(args.get("limit"), default=20, minimum=1, maximum=50) + since = _parse_datetime_arg(args.get("since")) + statement = _active_files_query(ctx, folder_ids=None) + if since is not None: + statement = statement.where(File.updated_at >= since) + statement = statement.order_by(File.updated_at.desc(), File.file_id.desc()).limit(limit) + rows = list(await ctx.db.scalars(statement)) + return { + "items": [await _file_payload(ctx, row, include_path=True) for row in rows], + "totalItems": len(rows), + "limit": limit, + "since": since.isoformat() if since else None, + } + + +async def _stats_by_category(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = str(args.get("folderId") or "root") + recursive = _bool_arg(args.get("recursive"), default=True) + root_folder_id = await _resolve_folder_id(ctx, folder_id=folder_id) + folder_ids = await _folder_scope_ids(ctx, root_folder_id=root_folder_id, recursive=recursive) + rows = list(await ctx.db.scalars(_active_files_query(ctx, folder_ids=folder_ids))) + + categories = { + category: {"count": 0, "totalSize": 0} + for category in _CATEGORIES + } + total_size = 0 + for row in rows: + category = _category_for_file(row) + size = int(row.file_size or 0) + categories[category]["count"] += 1 + categories[category]["totalSize"] += size + total_size += size + + return { + "folderId": str(root_folder_id), + "recursive": recursive, + "totalItems": len(rows), + "totalSize": total_size, + "categories": categories, + "video": categories["video"]["count"], + "audio": categories["audio"]["count"], + "image": categories["image"]["count"], + "document": categories["document"]["count"], + "archive": categories["archive"]["count"], + "other": categories["other"]["count"], + } + + +async def _find_duplicates(ctx: ToolContext, args: dict[str, Any]) -> dict[str, Any]: + folder_id = str(args.get("folderId") or "root") + recursive = _bool_arg(args.get("recursive"), default=True) + by = str(args.get("by") or "hash").strip() or "hash" + if by not in {"hash", "nameSize"}: + raise ApiError(status_code=400, code=400, message="Invalid duplicate mode") + + root_folder_id = await _resolve_folder_id(ctx, folder_id=folder_id) + folder_ids = await _folder_scope_ids(ctx, root_folder_id=root_folder_id, recursive=recursive) + rows = ( + await ctx.db.execute( + _active_files_query(ctx, folder_ids=folder_ids) + .join(StorageObject, StorageObject.object_id == File.storage_object_id) + .add_columns(StorageObject.object_hash, StorageObject.hash_algorithm) + .order_by(File.file_name.asc()) + ) + ).all() + + groups: dict[str, dict[str, Any]] = {} + for row in rows: + file_row: File = row[0] + object_hash = row[1] + hash_algorithm = row[2] + if by == "hash": + if not object_hash: + continue + key = f"{hash_algorithm}:{object_hash}:{int(file_row.file_size or 0)}" + else: + key = f"{file_row.file_name.lower()}:{int(file_row.file_size or 0)}" + group = groups.setdefault( + key, + { + "key": key, + "by": by, + "hash": str(object_hash) if object_hash else None, + "hashAlgorithm": str(hash_algorithm) if hash_algorithm else None, + "size": int(file_row.file_size or 0), + "files": [], + }, + ) + group["files"].append(await _file_payload(ctx, file_row, include_path=True)) + + duplicate_groups = [group for group in groups.values() if len(group["files"]) > 1] + return { + "folderId": str(root_folder_id), + "recursive": recursive, + "by": by, + "groups": duplicate_groups, + "totalGroups": len(duplicate_groups), + "totalFiles": sum(len(group["files"]) for group in duplicate_groups), + } + + +def _active_files_query(ctx: ToolContext, *, folder_ids: list[int] | None): + statement = select(File).where( + and_( + File.owner_id == ctx.user_id, + File.status == FileStatus.ACTIVE, + File.is_latest.is_(True), + ) + ) + if folder_ids is not None: + statement = statement.where(File.folder_id.in_(folder_ids)) + return statement + + +async def _file_payload( + ctx: ToolContext, + row: File, + *, + include_path: bool, +) -> dict[str, Any]: + payload = { + "id": str(row.file_id), + "fileId": str(row.file_id), + "name": str(row.file_name), + "size": int(row.file_size or 0), + "mimeType": _resolved_mime(row), + "folderId": str(row.folder_id), + "createdAt": row.created_at.isoformat() if row.created_at else None, + "updatedAt": row.updated_at.isoformat() if row.updated_at else None, + } + if include_path: + folder_path = await _folder_path(ctx, folder_id=int(row.folder_id)) + payload["path"] = f"{folder_path}/{row.file_name}" if folder_path else str(row.file_name) + return payload + + +async def _resolve_folder_id(ctx: ToolContext, *, folder_id: str) -> int: + if not folder_id or folder_id == "root": + root_id = await ctx.db.scalar( + select(Folder.folder_id).where( + and_( + Folder.owner_id == ctx.user_id, + Folder.parent_folder_id.is_(None), + Folder.folder_type == FolderType.ROOT, + Folder.status == FolderStatus.ACTIVE, + ) + ) + ) + if root_id is None: + raise ApiError(status_code=404, code=404, message="Root folder not found") + return int(root_id) + parsed = _parse_positive_int(folder_id, "folderId") + exists = await ctx.db.scalar( + select(Folder.folder_id).where( + and_( + Folder.folder_id == parsed, + Folder.owner_id == ctx.user_id, + Folder.status == FolderStatus.ACTIVE, + ) + ) + ) + if exists is None: + raise ApiError(status_code=404, code=404, message="Folder not found") + return parsed + + +async def _folder_scope_ids( + ctx: ToolContext, + *, + root_folder_id: int, + recursive: bool, +) -> list[int]: + if not recursive: + return [root_folder_id] + return await _active_descendant_folder_ids(ctx, root_folder_id=root_folder_id) + + +async def _active_descendant_folder_ids(ctx: ToolContext, *, root_folder_id: int) -> list[int]: + descendants = ( + select(Folder.folder_id) + .where( + and_( + Folder.folder_id == root_folder_id, + Folder.owner_id == ctx.user_id, + Folder.status == FolderStatus.ACTIVE, + ) + ) + .cte(name="agent_tool_descendants", recursive=True) + ) + descendants = descendants.union_all( + select(Folder.folder_id).where( + and_( + Folder.parent_folder_id == descendants.c.folder_id, + Folder.owner_id == ctx.user_id, + Folder.status == FolderStatus.ACTIVE, + ) + ) + ) + folder_ids = list(await ctx.db.scalars(select(descendants.c.folder_id))) + return [int(folder_id) for folder_id in folder_ids] + + +async def _folder_path(ctx: ToolContext, *, folder_id: int) -> str: + parts: list[str] = [] + current_id: int | None = folder_id + while current_id is not None: + folder = await ctx.db.scalar( + select(Folder).where( + and_( + Folder.folder_id == current_id, + Folder.owner_id == ctx.user_id, + Folder.status == FolderStatus.ACTIVE, + ) + ) + ) + if folder is None: + break + parts.append(str(folder.folder_name)) + current_id = int(folder.parent_folder_id) if folder.parent_folder_id is not None else None + parts.reverse() + return "/" + "/".join(parts) if parts else "" + + +def _first_value(args: dict[str, Any], *keys: str) -> Any: + for key in keys: + value = args.get(key) + if value not in (None, ""): + return value + return None + + +def _required_text(args: dict[str, Any], *keys: str) -> str: + value = _first_value(args, *keys) + if value is None: + raise ApiError(status_code=400, code=400, message=f"Missing required tool input: {keys[0]}") + text = str(value).strip() + if not text: + raise ApiError(status_code=400, code=400, message=f"Missing required tool input: {keys[0]}") + return text + + +def _optional_text(value: Any) -> str | None: + text = str(value or "").strip() + return text or None + + +def _bool_arg(value: Any, *, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + text = str(value).strip().lower() + if text in {"1", "true", "yes", "y"}: + return True + if text in {"0", "false", "no", "n"}: + return False + return default + + +def _int_arg( + value: Any, + *, + default: int, + minimum: int, + maximum: int | None = None, +) -> int: + try: + parsed = int(value if value is not None else default) + except (TypeError, ValueError): + parsed = default + parsed = max(minimum, parsed) + if maximum is not None: + parsed = min(maximum, parsed) + return parsed + + +def _parse_positive_int(value: str, field_name: str) -> int: + try: + parsed = int(value) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message=f"Invalid {field_name}") from exc + if parsed <= 0: + raise ApiError(status_code=400, code=400, message=f"Invalid {field_name}") + return parsed + + +def _parse_datetime_arg(value: Any) -> datetime | None: + text = str(value or "").strip() + if not text: + return None + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message="Invalid datetime") from exc + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed + + +def _normalize_category(value: Any) -> str | None: + text = str(value or "").strip().lower() + aliases = { + "movies": "video", + "movie": "video", + "film": "video", + "films": "video", + "videos": "video", + "anime": "video", + "animation": "video", + "视频": "video", + "影片": "video", + "电影": "video", + "动漫": "video", + "番剧": "video", + "documents": "document", + "docs": "document", + "images": "image", + "pictures": "image", + "archives": "archive", + "compressed": "archive", + } + text = aliases.get(text, text) + if text in _CATEGORIES: + return text + return None + + +def _resolved_mime(row: File) -> str: + return resolve_file_mime_type( + mime_type=row.mime_type, + file_ext=row.file_ext, + file_name=row.file_name, + ) + + +def _category_for_file(row: File) -> str: + mime = _resolved_mime(row).lower() + ext = _normalized_extension(row.file_ext) or _filename_extension(row.file_name) + if mime.startswith("video/") or ext in {"mp4", "mov", "avi", "mkv", "webm", "m4v"}: + return "video" + if mime.startswith("audio/") or ext in {"mp3", "wav", "flac", "m4a", "aac", "ogg"}: + return "audio" + if mime.startswith("image/") or ext in {"jpg", "jpeg", "png", "gif", "webp", "svg", "bmp"}: + return "image" + if mime in {"application/pdf"} or ext in {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "txt", "md"}: + return "document" + if ext in {"zip", "rar", "7z", "tar", "gz", "bz2", "xz"}: + return "archive" + return "other" + + +def _count_files_answer(output: dict[str, Any]) -> str: + total_items = int(output.get("totalItems") or 0) + category = str(output.get("category") or "").strip().lower() + qualifier = _search_qualifier(output) + if category == "video": + return f"你上传了 {total_items} 部{qualifier}电影(按视频文件统计)。" + if category == "audio": + return f"你上传了 {total_items} 个{qualifier}音频文件。" + if category == "image": + return f"你上传了 {total_items} 张{qualifier}图片。" + if category == "document": + return f"你上传了 {total_items} 个{qualifier}文档。" + if category == "archive": + return f"你上传了 {total_items} 个{qualifier}压缩包。" + return f"你上传了 {total_items} 个{qualifier}文件。" + + +def _search_qualifier(output: dict[str, Any]) -> str: + search = str(output.get("search") or "").strip() + if not search: + return "" + return f"名称包含“{search}”的" + + +def _normalized_extension(value: str | None) -> str: + return str(value or "").strip().lower().lstrip(".") + + +def _filename_extension(value: str | None) -> str: + name = str(value or "").strip().lower() + if "." not in name: + return "" + return name.rsplit(".", 1)[-1] + + +def _schema(properties: dict[str, Any], required: list[str] | None = None) -> dict[str, Any]: + return { + "type": "object", + "properties": properties, + "required": required or [], + } + + +_FOLDER_ID = {"type": "string", "description": "Folder id, or root for the user's root folder."} +_FILE_ID = {"type": "string", "description": "File id owned by the current user."} +_CATEGORY = {"type": "string", "enum": list(_CATEGORIES)} +_SHARE_HANDLING = {"type": "string", "enum": ["keep", "revoke"], "default": "keep"} + + +REGISTRY.register( + ToolSpec( + name="drive.listFolder", + description="List direct files and folders inside a folder.", + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "page": {"type": "integer", "minimum": 1, "default": 1}, + "perPage": {"type": "integer", "minimum": 1, "maximum": 200, "default": 200}, + "search": {"type": "string"}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_list_folder, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.countFiles", + description=( + "Count files under a folder. Supports recursive counts, broad file categories, " + "and filename contains search." + ), + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "recursive": {"type": "boolean", "default": True}, + "category": _CATEGORY, + "search": {"type": "string"}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_count_files, + answer_formatter=_count_files_answer, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.createFolder", + description="Create a folder under parentFolderId with name.", + input_schema=_schema( + { + "parentFolderId": _FOLDER_ID, + "name": {"type": "string", "minLength": 1, "maxLength": 255}, + }, + required=["name"], + ), + side_effect="write", + risk_level="medium", + requires_confirmation=False, + handler=_create_folder, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.moveFile", + description="Move fileId into targetFolderId.", + input_schema=_schema( + { + "fileId": _FILE_ID, + "targetFolderId": _FOLDER_ID, + "shareHandling": _SHARE_HANDLING, + }, + required=["fileId", "targetFolderId"], + ), + side_effect="write", + risk_level="medium", + requires_confirmation=False, + handler=_move_file, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.moveFolder", + description="Move folderId into targetParentId.", + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "targetParentId": _FOLDER_ID, + "shareHandling": _SHARE_HANDLING, + }, + required=["folderId", "targetParentId"], + ), + side_effect="write", + risk_level="medium", + requires_confirmation=False, + handler=_move_folder, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.renameFile", + description="Rename fileId to fileName.", + input_schema=_schema( + { + "fileId": _FILE_ID, + "fileName": {"type": "string", "minLength": 1, "maxLength": 255}, + }, + required=["fileId", "fileName"], + ), + side_effect="write", + risk_level="medium", + requires_confirmation=False, + handler=_rename_file, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.renameFolder", + description="Rename folderId to folderName.", + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "folderName": {"type": "string", "minLength": 1, "maxLength": 255}, + }, + required=["folderId", "folderName"], + ), + side_effect="write", + risk_level="medium", + requires_confirmation=False, + handler=_rename_folder, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.deleteFile", + description="Soft-delete fileId into the recycle bin. This is high risk.", + input_schema=_schema({"fileId": _FILE_ID}, required=["fileId"]), + side_effect="write", + risk_level="high", + requires_confirmation=True, + handler=_delete_file, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.deleteFolder", + description="Soft-delete folderId into the recycle bin. This is high risk.", + input_schema=_schema({"folderId": _FOLDER_ID}, required=["folderId"]), + side_effect="write", + risk_level="high", + requires_confirmation=True, + handler=_delete_folder, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.searchFiles", + description="Search active files by filename, folder scope, category, MIME prefix, and update time.", + input_schema=_schema( + { + "query": {"type": "string"}, + "folderId": _FOLDER_ID, + "recursive": {"type": "boolean", "default": True}, + "category": _CATEGORY, + "mimePrefix": {"type": "string", "description": "MIME type prefix such as video/."}, + "modifiedAfter": {"type": "string", "format": "date-time"}, + "limit": {"type": "integer", "minimum": 1, "maximum": 200, "default": 50}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_search_files, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.getFileInfo", + description="Return detailed metadata for one active file.", + input_schema=_schema({"fileId": _FILE_ID}, required=["fileId"]), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_get_file_info, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.listRecent", + description="List recently updated active files.", + input_schema=_schema( + { + "limit": {"type": "integer", "minimum": 1, "maximum": 50, "default": 20}, + "since": {"type": "string", "format": "date-time"}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_list_recent, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.statsByCategory", + description="Compute file counts and sizes by broad category under a folder.", + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "recursive": {"type": "boolean", "default": True}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_stats_by_category, + ) +) + +REGISTRY.register( + ToolSpec( + name="drive.findDuplicates", + description="Find duplicate active files by content hash or by name plus size.", + input_schema=_schema( + { + "folderId": _FOLDER_ID, + "recursive": {"type": "boolean", "default": True}, + "by": {"type": "string", "enum": ["hash", "nameSize"], "default": "hash"}, + } + ), + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_find_duplicates, + ) +) + +__all__ = [] diff --git a/app/src/fileflash/agents/worker.py b/app/src/fileflash/agents/worker.py index 6fd4e89..a122e31 100644 --- a/app/src/fileflash/agents/worker.py +++ b/app/src/fileflash/agents/worker.py @@ -17,6 +17,7 @@ from ..models import BackgroundJob from ..services.job_queue import RedisStreamJobQueue from ..workers.contracts import WorkerJobMessage +from .harness.event_bus import AgentEventBus, AgentEventEnvelope, build_agent_event_bus from .runtime import AgentJobCanceled, ExecuteRunner, PlanRunner logger = logging.getLogger(__name__) @@ -28,10 +29,12 @@ def __init__( *, queue: RedisStreamJobQueue, session_factory: async_sessionmaker[AsyncSession] = SessionLocal, + event_bus: AgentEventBus | None = None, ) -> None: self._settings = get_settings() self._queue = queue self._session_factory = session_factory + self._event_bus = event_bus or build_agent_event_bus(settings=self._settings) async def run(self) -> None: logger.info( @@ -103,12 +106,15 @@ async def _run_job(self, *, job: BackgroundJob) -> tuple[dict[str, Any], str]: if fresh_job is None: raise ApiError(status_code=404, code=404, message="Job not found") if fresh_job.task_type == "agent.plan": - result = await PlanRunner(settings=self._settings).run(db=db, job=fresh_job) + result = await PlanRunner( + settings=self._settings, + event_bus=self._event_bus, + ).run(db=db, job=fresh_job) phase = "awaiting_confirm" if result.requires_confirmation else "completed" - return result.model_dump(by_alias=True), phase + return result.model_dump(by_alias=True, mode="json"), phase if fresh_job.task_type == "agent.execute": - result = await ExecuteRunner().run(db=db, job=fresh_job) - return result.model_dump(by_alias=True), "completed" + result = await ExecuteRunner(event_bus=self._event_bus).run(db=db, job=fresh_job) + return result.model_dump(by_alias=True, mode="json"), "completed" raise ApiError( status_code=400, code=400, @@ -135,6 +141,8 @@ async def _mark_running(self, message: WorkerJobMessage) -> BackgroundJob | None return job async def _mark_succeeded(self, *, job_id: int, result: dict[str, Any], phase: str) -> None: + safe_result = jsonable_encoder(result) + should_publish = False async with self._session_factory() as db: async with db.begin(): await apply_local_lock_timeout(db) @@ -147,14 +155,22 @@ async def _mark_succeeded(self, *, job_id: int, result: dict[str, Any], phase: s return now = datetime.now(UTC) job.status = "succeeded" - job.result = jsonable_encoder(result) + job.result = safe_result job.error_message = None job.agent_phase = phase job.finished_at = now job.updated_at = now + should_publish = True + if should_publish: + await self._publish_terminal( + job_id=job_id, + event_type="job.succeeded", + payload={"status": "succeeded", "agentPhase": phase, "data": {"result": safe_result}}, + ) async def _mark_failed(self, *, job_id: int, error: Exception) -> None: message = _error_message(error) + should_publish = False async with self._session_factory() as db: async with db.begin(): await apply_local_lock_timeout(db) @@ -171,8 +187,21 @@ async def _mark_failed(self, *, job_id: int, error: Exception) -> None: job.error_message = message[:2000] job.finished_at = now job.updated_at = now + should_publish = True + if should_publish: + await self._publish_terminal( + job_id=job_id, + event_type="job.failed", + payload={ + "status": "failed", + "agentPhase": "failed", + "message": message[:2000], + "data": {"errorMessage": message[:2000]}, + }, + ) async def _mark_canceled(self, *, job_id: int) -> None: + should_publish = False async with self._session_factory() as db: async with db.begin(): await apply_local_lock_timeout(db) @@ -187,6 +216,36 @@ async def _mark_canceled(self, *, job_id: int) -> None: job.cancel_requested_at = job.cancel_requested_at or now job.finished_at = now job.updated_at = now + should_publish = True + if should_publish: + await self._publish_terminal( + job_id=job_id, + event_type="job.canceled", + payload={"status": "canceled", "agentPhase": "canceled"}, + ) + + async def _publish_terminal( + self, + *, + job_id: int, + event_type: str, + payload: dict[str, Any] | None = None, + ) -> None: + try: + await self._event_bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload=payload or {}, + emitted_at=datetime.now(UTC), + ) + ) + except Exception: + logger.exception( + "Failed to publish terminal event jobId=%s eventType=%s", + job_id, + event_type, + ) def _error_message(error: Exception) -> str: diff --git a/app/src/fileflash/core/deps.py b/app/src/fileflash/core/deps.py index c4fd5be..5ac1f45 100644 --- a/app/src/fileflash/core/deps.py +++ b/app/src/fileflash/core/deps.py @@ -5,9 +5,10 @@ from jwt import InvalidTokenError from sqlalchemy.ext.asyncio import AsyncSession +from ..agents.harness.event_bus import AgentEventBus, build_agent_event_bus from ..db.deps import get_db -from ..models.tables_identity import User from ..models.enums import UserRole +from ..models.tables_identity import User from ..repositories import ( AgentActionLogRepository, AgentMcpRepository, @@ -17,17 +18,27 @@ AgentSkillRepository, AgentWorkSessionRepository, ) -from ..services.archive import ArchiveService -from ..services.agent import ExecuteService, McpService, MemoryService, PlanService, SessionService, SettingsService, SkillService -from ..services.admin.users import AdminUsersService -from ..services.admin.storage import AdminStorageService +from ..s3 import MinioObjectStorageClient from ..services.admin.files import AdminFilesService -from ..services.admin.moderation import AdminModerationService from ..services.admin.logs import AdminLogsService +from ..services.admin.moderation import AdminModerationService from ..services.admin.notifications import AdminNotificationsService +from ..services.admin.storage import AdminStorageService from ..services.admin.system import AdminSystemService +from ..services.admin.users import AdminUsersService +from ..services.agent import ( + ExecuteService, + McpService, + MemoryService, + PlanService, + SessionService, + SettingsService, + SkillService, +) +from ..services.archive import ArchiveService from ..services.auth import AuthService from ..services.background_jobs import BackgroundJobService +from ..services.download_rate_limit import DownloadRateLimitService from ..services.email_delivery import VerificationEmailDeliveryService from ..services.file import FileService from ..services.folder import FolderService @@ -37,7 +48,6 @@ from ..services.registration_email_domain_rule import RegistrationEmailDomainRuleService from ..services.share import ShareService from ..services.upload import UploadService -from ..s3 import MinioObjectStorageClient from .errors import ApiError from .security import decode_access_token from .settings import Settings, get_settings @@ -55,6 +65,7 @@ redis_url=_settings.redis_url, stream_key=_settings.agent_queue_stream, ) +_agent_event_bus_singleton: AgentEventBus | None = None def get_rate_limiter() -> RedisRateLimiter: @@ -77,6 +88,13 @@ def get_agent_job_queue_publisher() -> RedisStreamJobQueue: return _agent_job_queue_publisher +def get_agent_event_bus() -> AgentEventBus: + global _agent_event_bus_singleton + if _agent_event_bus_singleton is None: + _agent_event_bus_singleton = build_agent_event_bus(settings=_settings) + return _agent_event_bus_singleton + + def get_background_job_service( queue_publisher: RedisStreamJobQueue = Depends(get_job_queue_publisher), ) -> BackgroundJobService: @@ -93,6 +111,14 @@ def get_settings_dep() -> Settings: return _settings +def get_download_rate_limit_service( + db: AsyncSession = Depends(get_db), + settings: Settings = Depends(get_settings_dep), + rate_limiter: RedisRateLimiter = Depends(get_rate_limiter), +) -> DownloadRateLimitService: + return DownloadRateLimitService(db=db, settings=settings, rate_limiter=rate_limiter) + + def get_client_ip(request: Request) -> str: forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: diff --git a/app/src/fileflash/core/settings.py b/app/src/fileflash/core/settings.py index 80316a2..71dde62 100644 --- a/app/src/fileflash/core/settings.py +++ b/app/src/fileflash/core/settings.py @@ -56,6 +56,22 @@ class Settings(BaseSettings): cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173", "http://localhost:8080"]) redis_url: str | None = Field(default=None, alias="REDIS_URL") + agent_inbox_ask_timeout_sec: int = Field( + default=1800, + alias="AGENT_INBOX_ASK_TIMEOUT_SEC", + ) + agent_event_channel_prefix: str = Field( + default="agent:job", + alias="AGENT_EVENT_CHANNEL_PREFIX", + ) + agent_inbox_channel_prefix: str = Field( + default="agent:inbox", + alias="AGENT_INBOX_CHANNEL_PREFIX", + ) + agent_event_bus_buffer_size: int = Field( + default=64, + alias="AGENT_EVENT_BUS_BUFFER_SIZE", + ) rabbitmq_url: str | None = Field(default=None, alias="RABBITMQ_URL") email_verify_base_url: str = Field(default="", alias="EMAIL_VERIFY_BASE_URL") @@ -95,19 +111,25 @@ class Settings(BaseSettings): upload_temp_prefix: str = Field(default="tmp", alias="UPLOAD_TEMP_PREFIX") upload_object_prefix: str = Field(default="objects", alias="UPLOAD_OBJECT_PREFIX") - max_failed_login_attempts: int = 5 - account_lock_minutes: int = 15 + max_failed_login_attempts: int = 8 + account_lock_minutes: int = 5 email_verification_expire_minutes: int = 60 password_reset_expire_minutes: int = 30 - register_rate_limit: int = 5 + register_rate_limit: int = 12 register_rate_window_seconds: int = 600 - login_rate_limit: int = 10 + login_rate_limit: int = 30 login_rate_window_seconds: int = 300 forgot_password_rate_limit: int = 5 forgot_password_rate_window_seconds: int = 600 resend_verification_rate_limit: int = 5 resend_verification_rate_window_seconds: int = 600 + download_rate_window_seconds: int = Field(default=600, alias="DOWNLOAD_RATE_WINDOW_SECONDS") + download_rate_limit_requests: int = Field(default=120, alias="DOWNLOAD_RATE_LIMIT_REQUESTS") + download_rate_limit_bytes: int = Field( + default=2 * 1024 * 1024 * 1024, + alias="DOWNLOAD_RATE_LIMIT_BYTES", + ) worker_poll_interval_seconds: float = Field( default=2.0, @@ -145,6 +167,7 @@ class Settings(BaseSettings): agent_llm_model: str = Field(default="claude-sonnet-4-6", alias="AGENT_LLM_MODEL") agent_llm_base_url: str | None = Field(default=None, alias="AGENT_LLM_BASE_URL") agent_llm_api_key: str | None = Field(default=None, alias="AGENT_LLM_API_KEY") + agent_llm_plan_max_tokens: int = Field(default=8192, alias="AGENT_LLM_PLAN_MAX_TOKENS") agent_mcp_endpoints_raw: str = Field(default="[]", alias="AGENT_MCP_ENDPOINTS") ffmpeg_binary: str = Field(default="ffmpeg", alias="FFMPEG_BINARY") diff --git a/app/src/fileflash/models/__init__.py b/app/src/fileflash/models/__init__.py index 575fa97..21cd3a5 100644 --- a/app/src/fileflash/models/__init__.py +++ b/app/src/fileflash/models/__init__.py @@ -2,6 +2,7 @@ from .tables import ( Acl, AgentActionLog, + AgentInboxMessage, AgentMcpServer, AgentMemory, AgentPlan, @@ -40,6 +41,7 @@ __all__ = [ "Acl", "AgentActionLog", + "AgentInboxMessage", "AgentMcpServer", "AgentMemory", "AgentPlan", diff --git a/app/src/fileflash/models/enums.py b/app/src/fileflash/models/enums.py index 0ce1e32..7228c78 100644 --- a/app/src/fileflash/models/enums.py +++ b/app/src/fileflash/models/enums.py @@ -147,6 +147,29 @@ class AgentMcpVisibility(BaseStrEnum): PRIVATE = "private" +class AgentInboxRole(BaseStrEnum): + AGENT = "agent" + USER = "user" + + +class AgentInboxKind(BaseStrEnum): + ASK = "ask" + REPLY = "reply" + CONTROL_PAUSE = "control.pause" + CONTROL_RESUME = "control.resume" + CONTROL_APPROVE = "control.approve" + CONTROL_DENY = "control.deny" + CONTROL_SKIP = "control.skip" + CONTROL_CANCEL = "control.cancel" + + +class AgentInboxStatus(BaseStrEnum): + WAITING = "waiting" + ANSWERED = "answered" + TIMED_OUT = "timed_out" + DROPPED = "dropped" + + __all__ = [ "BaseStrEnum", "UploadStatus", @@ -172,5 +195,8 @@ class AgentMcpVisibility(BaseStrEnum): "AgentMemoryKind", "AgentSkillVisibility", "AgentMcpVisibility", + "AgentInboxRole", + "AgentInboxKind", + "AgentInboxStatus", ] diff --git a/app/src/fileflash/models/tables.py b/app/src/fileflash/models/tables.py index 0040f7b..5f4930c 100644 --- a/app/src/fileflash/models/tables.py +++ b/app/src/fileflash/models/tables.py @@ -11,6 +11,7 @@ ) from .tables_agent import ( AgentActionLog, + AgentInboxMessage, AgentMcpServer, AgentMemory, AgentPlan, @@ -50,6 +51,7 @@ __all__ = [ "Acl", "AgentActionLog", + "AgentInboxMessage", "AgentMcpServer", "AgentMemory", "AgentPlan", diff --git a/app/src/fileflash/models/tables_agent.py b/app/src/fileflash/models/tables_agent.py index 33b4e79..3ad9591 100644 --- a/app/src/fileflash/models/tables_agent.py +++ b/app/src/fileflash/models/tables_agent.py @@ -21,6 +21,9 @@ from .base import Base from .enums import ( AgentExecutionPolicy, + AgentInboxKind, + AgentInboxRole, + AgentInboxStatus, AgentMcpVisibility, AgentMemoryKind, AgentMemoryScope, @@ -345,8 +348,55 @@ class AgentWorkSession(Base): closed_at: Mapped[datetime | None] = mapped_column(DateTime) +class AgentInboxMessage(Base): + __tablename__ = "agent_inbox_message" + __table_args__ = ( + Index("idx_agent_inbox_message_job_created", "job_id", "created_at"), + Index( + "idx_agent_inbox_message_job_status", + "job_id", + "status", + postgresql_where=text("status IS NOT NULL"), + ), + ) + + inbox_message_id: Mapped[int] = mapped_column(BigInteger, Identity(), primary_key=True) + job_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey("background_job.job_id", ondelete="CASCADE"), + nullable=False, + ) + role: Mapped[AgentInboxRole] = mapped_column( + pg_enum(AgentInboxRole, "agent_inbox_role_enum"), + nullable=False, + ) + kind: Mapped[AgentInboxKind] = mapped_column( + pg_enum(AgentInboxKind, "agent_inbox_kind_enum"), + nullable=False, + ) + payload_json: Mapped[dict[str, Any]] = mapped_column( + JSONB, + nullable=False, + server_default=text("'{}'::jsonb"), + ) + reply_to_id: Mapped[int | None] = mapped_column( + BigInteger, + ForeignKey("agent_inbox_message.inbox_message_id", ondelete="SET NULL"), + ) + status: Mapped[AgentInboxStatus | None] = mapped_column( + pg_enum(AgentInboxStatus, "agent_inbox_status_enum"), + ) + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ) + answered_at: Mapped[datetime | None] = mapped_column(DateTime) + + __all__ = [ "AgentActionLog", + "AgentInboxMessage", "AgentMcpServer", "AgentMemory", "AgentPlan", diff --git a/app/src/fileflash/repositories/__init__.py b/app/src/fileflash/repositories/__init__.py index 84b1407..c063ea1 100644 --- a/app/src/fileflash/repositories/__init__.py +++ b/app/src/fileflash/repositories/__init__.py @@ -1,5 +1,6 @@ from .agent import ( AgentActionLogRepository, + AgentInboxMessageRepository, AgentMcpCatalogEntry, AgentMcpRepository, AgentMemoryActiveEntry, @@ -13,6 +14,7 @@ __all__ = [ "AgentActionLogRepository", + "AgentInboxMessageRepository", "AgentMcpCatalogEntry", "AgentMcpRepository", "AgentMemoryActiveEntry", diff --git a/app/src/fileflash/repositories/agent/__init__.py b/app/src/fileflash/repositories/agent/__init__.py index ac9837d..c551d06 100644 --- a/app/src/fileflash/repositories/agent/__init__.py +++ b/app/src/fileflash/repositories/agent/__init__.py @@ -1,5 +1,6 @@ from .action_log import AgentActionLogRepository from .contracts import AgentMcpCatalogEntry, AgentMemoryActiveEntry, AgentSkillCatalogEntry +from .inbox import AgentInboxMessageRepository from .mcp import AgentMcpRepository from .memory import AgentMemoryRepository from .plan import AgentPlanRepository @@ -9,6 +10,7 @@ __all__ = [ "AgentActionLogRepository", + "AgentInboxMessageRepository", "AgentMcpCatalogEntry", "AgentMcpRepository", "AgentMemoryActiveEntry", diff --git a/app/src/fileflash/repositories/agent/inbox.py b/app/src/fileflash/repositories/agent/inbox.py new file mode 100644 index 0000000..d6996b2 --- /dev/null +++ b/app/src/fileflash/repositories/agent/inbox.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from ...models import AgentInboxMessage +from ...models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus + +_CONTROL_KINDS = frozenset( + { + AgentInboxKind.CONTROL_PAUSE, + AgentInboxKind.CONTROL_RESUME, + AgentInboxKind.CONTROL_APPROVE, + AgentInboxKind.CONTROL_DENY, + AgentInboxKind.CONTROL_SKIP, + AgentInboxKind.CONTROL_CANCEL, + } +) + + +class AgentInboxMessageRepository: + def __init__(self, db: AsyncSession) -> None: + self._db = db + + async def create_ask( + self, + *, + job_id: int, + payload: dict[str, Any], + ) -> AgentInboxMessage: + msg = AgentInboxMessage( + job_id=job_id, + role=AgentInboxRole.AGENT, + kind=AgentInboxKind.ASK, + payload_json=payload, + status=AgentInboxStatus.WAITING, + created_at=datetime.now(UTC), + ) + self._db.add(msg) + await self._db.flush() + return msg + + async def record_user_message( + self, + *, + job_id: int, + kind: AgentInboxKind, + payload: dict[str, Any], + reply_to_id: int | None = None, + ) -> AgentInboxMessage: + msg = AgentInboxMessage( + job_id=job_id, + role=AgentInboxRole.USER, + kind=kind, + payload_json=payload, + reply_to_id=reply_to_id, + status=None, + created_at=datetime.now(UTC), + ) + self._db.add(msg) + await self._db.flush() + return msg + + async def mark_answered( + self, + *, + inbox_message_id: int, + answered_at: datetime, + ) -> AgentInboxMessage: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None: + raise ValueError(f"AgentInboxMessage {inbox_message_id} not found") + msg.status = AgentInboxStatus.ANSWERED + msg.answered_at = answered_at + await self._db.flush() + return msg + + async def mark_timed_out( + self, + *, + inbox_message_id: int, + answered_at: datetime, + ) -> AgentInboxMessage: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None: + raise ValueError(f"AgentInboxMessage {inbox_message_id} not found") + msg.status = AgentInboxStatus.TIMED_OUT + msg.answered_at = answered_at + await self._db.flush() + return msg + + async def mark_dropped(self, *, inbox_message_id: int) -> None: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None: + return + if msg.kind in _CONTROL_KINDS: + msg.status = AgentInboxStatus.DROPPED + msg.answered_at = datetime.now(UTC) + await self._db.flush() + + async def get_ask(self, *, inbox_message_id: int) -> AgentInboxMessage | None: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None or msg.kind != AgentInboxKind.ASK: + return None + return msg + + async def get_reply_for(self, *, ask_id: int) -> AgentInboxMessage | None: + return await self._db.scalar( + select(AgentInboxMessage).where( + and_( + AgentInboxMessage.reply_to_id == ask_id, + AgentInboxMessage.kind == AgentInboxKind.REPLY, + ) + ) + ) + + async def list_pending_controls(self, *, job_id: int) -> list[AgentInboxMessage]: + rows = await self._db.scalars( + select(AgentInboxMessage) + .where( + and_( + AgentInboxMessage.job_id == job_id, + AgentInboxMessage.role == AgentInboxRole.USER, + AgentInboxMessage.kind.in_(list(_CONTROL_KINDS)), + AgentInboxMessage.status.is_(None), + ) + ) + .order_by(AgentInboxMessage.created_at.asc()) + ) + return list(rows) + + +__all__ = ["AgentInboxMessageRepository"] diff --git a/app/src/fileflash/routers/admin_users.py b/app/src/fileflash/routers/admin_users.py index 09278c0..0eb8619 100644 --- a/app/src/fileflash/routers/admin_users.py +++ b/app/src/fileflash/routers/admin_users.py @@ -11,9 +11,19 @@ router = APIRouter(prefix="/admin/users", tags=["admin"]) +def get_list_admin_users_query(query: ListAdminUsersQuery = Depends()) -> ListAdminUsersQuery: + try: + query.resolve_usage_window() + except ValueError as exc: + from ..core.errors import ApiError + + raise ApiError(status_code=400, code=400, message=str(exc)) from exc + return query + + @router.get("") async def list_admin_users( - query: ListAdminUsersQuery = Depends(), + query: ListAdminUsersQuery = Depends(get_list_admin_users_query), _: User = Depends(require_admin), service: AdminUsersService = Depends(get_admin_users_service), ): diff --git a/app/src/fileflash/routers/agent.py b/app/src/fileflash/routers/agent.py index 2e56108..4cc869f 100644 --- a/app/src/fileflash/routers/agent.py +++ b/app/src/fileflash/routers/agent.py @@ -1,18 +1,34 @@ from __future__ import annotations -from datetime import UTC, datetime -from typing import Annotated +import json +from typing import Annotated, Any, get_args from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession -from ..core.deps import get_agent_execute_service, get_agent_plan_service, get_current_user +from ..agents.harness.event_bus import AgentEventBus, AgentEventEnvelope +from ..agents.harness.inbox import AgentInbox +from ..core.deps import ( + get_agent_event_bus, + get_agent_execute_service, + get_agent_plan_service, + get_current_user, +) from ..core.errors import ApiError, api_success from ..db.deps import get_db -from ..models import BackgroundJob +from ..models import AgentActionLog, BackgroundJob +from ..models.enums import AgentInboxKind from ..models.tables_identity import User -from ..schemas.agent import CancelAgentResponse, ExecuteAgentRequest, PlanAgentRequest +from ..schemas.agent import ( + AgentInboxMessageRequest, + AgentInboxMessageResponse, + AgentJobEvent, + AgentJobEventType, + ExecuteAgentRequest, + PlanAgentRequest, +) from ..services.agent import ExecuteService, PlanService router = APIRouter(prefix="/agent", tags=["agent"]) @@ -44,16 +60,63 @@ async def execute_agent_plan( ) -@router.post("/cancel/{job_id}") -async def cancel_agent_job( +@router.get("/jobs/{job_id}/events") +async def stream_agent_job_events( job_id: str, current_user: Annotated[User, Depends(get_current_user)], db: Annotated[AsyncSession, Depends(get_db)], + event_bus: Annotated[AgentEventBus, Depends(get_agent_event_bus)], ): - try: - parsed_job_id = int(job_id) - except ValueError as exc: - raise ApiError(status_code=400, code=400, message="Invalid jobId") from exc + parsed_job_id = _parse_job_id(job_id) + initial_events, initial_terminal = await _agent_job_events_for_job( + db=db, + job_id=parsed_job_id, + user_id=int(current_user.user_id), + ) + + async def event_stream(): + seen: set[str] = set() + for event in initial_events: + seen.add(event.id) + yield _format_sse_event(event) + if initial_terminal: + return + async with event_bus.subscribe(job_id=parsed_job_id) as stream: + while True: + try: + envelope = await stream.next(timeout=30.0) + except TimeoutError: + yield ": keep-alive\n\n" + continue + event = _envelope_to_job_event(envelope) + if event is None: + continue + if event.id in seen: + continue + seen.add(event.id) + yield _format_sse_event(event) + if event.type in {"job.succeeded", "job.failed", "job.canceled"}: + break + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + +@router.post("/jobs/{job_id}/messages") +async def post_agent_job_message( + job_id: str, + payload: AgentInboxMessageRequest, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[AsyncSession, Depends(get_db)], + event_bus: Annotated[AgentEventBus, Depends(get_agent_event_bus)], +): + parsed_job_id = _parse_job_id(job_id) job = await db.scalar( select(BackgroundJob) .where( @@ -63,27 +126,260 @@ async def cancel_agent_job( BackgroundJob.task_type.in_(["agent.plan", "agent.execute"]), ) ) - .with_for_update() ) if job is None: raise ApiError(status_code=404, code=404, message="Job not found") - canceled_at = datetime.now(UTC) - if job.status not in {"succeeded", "failed", "canceled"}: - job.cancel_requested_at = canceled_at - job.status = "canceled" - job.agent_phase = "canceled" - job.finished_at = canceled_at - job.updated_at = canceled_at + kind = AgentInboxKind(payload.kind) + reply_to_id: int | None = None + if payload.reply_to is not None: + try: + reply_to_id = int(payload.reply_to) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message="Invalid replyTo") from exc + + inbox = AgentInbox(db=db, event_bus=event_bus) + try: + msg = await inbox.handle( + job_id=parsed_job_id, + kind=kind, + payload=_inbox_payload_from_request(payload), + reply_to_id=reply_to_id, + ) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message=str(exc)) from exc await db.commit() - await db.refresh(job) - data = CancelAgentResponse( + data = AgentInboxMessageResponse( + inbox_message_id=str(msg.inbox_message_id), + kind=payload.kind, + accepted_at=msg.created_at, + ) + return api_success(data=data.model_dump(by_alias=True), message="Message accepted") + + +def _inbox_payload_from_request(req: AgentInboxMessageRequest) -> dict[str, Any]: + body: dict[str, Any] = {} + if req.value is not None: + body["value"] = req.value + if req.metadata: + body["metadata"] = req.metadata + return body + + +def _parse_job_id(raw: str) -> int: + try: + parsed_job_id = int(raw) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message="Invalid jobId") from exc + if parsed_job_id <= 0: + raise ApiError(status_code=400, code=400, message="Invalid jobId") + return parsed_job_id + + +async def _agent_job_events_for_job( + *, + db: AsyncSession, + job_id: int, + user_id: int, +) -> tuple[list[AgentJobEvent], bool]: + job = await db.scalar( + select(BackgroundJob).where( + and_( + BackgroundJob.job_id == job_id, + BackgroundJob.requested_by == user_id, + BackgroundJob.task_type.in_(["agent.plan", "agent.execute"]), + ) + ) + ) + if job is None: + raise ApiError(status_code=404, code=404, message="Job not found") + + terminal = str(job.status) in {"succeeded", "failed", "canceled"} + events: list[AgentJobEvent] = [] + if job.task_type == "agent.plan" and job.status == "succeeded" and job.result: + events.append(_plan_ready_event(job)) + events.append(_job_status_event(job)) + elif job.task_type != "agent.execute" or not terminal: + events.append(_job_status_event(job)) + + if job.task_type == "agent.execute": + action_logs = list( + await db.scalars( + select(AgentActionLog) + .where(AgentActionLog.job_id == job_id) + .order_by(AgentActionLog.step_no.asc()) + ) + ) + for action_log in action_logs: + events.extend(_tool_events(job=job, action_log=action_log)) + if terminal: + events.append(_job_status_event(job)) + + return events, terminal + + +def _job_status_event(job: BackgroundJob) -> AgentJobEvent: + status = str(job.status) + event_type = { + "pending": "job.queued", + "running": "job.running", + "succeeded": "job.succeeded", + "failed": "job.failed", + "canceled": "job.canceled", + }.get(status, "job.running") + timestamp = job.updated_at or job.created_at + return AgentJobEvent( + id=f"{job.job_id}:job:{status}:{timestamp.isoformat()}", job_id=str(job.job_id), + task_type=str(job.task_type), + type=event_type, # type: ignore[arg-type] + status=status, + agent_phase=job.agent_phase, + message=_job_status_message(job), + data=_job_status_data(job), + timestamp=timestamp, + ) + + +def _plan_ready_event(job: BackgroundJob) -> AgentJobEvent: + timestamp = job.finished_at or job.updated_at or job.created_at + return AgentJobEvent( + id=f"{job.job_id}:plan-ready", + job_id=str(job.job_id), + task_type=str(job.task_type), + type="plan.ready", status=str(job.status), - canceled_at=job.cancel_requested_at or canceled_at, + agent_phase=job.agent_phase, + message="计划已生成。", + data={"result": dict(job.result or {})}, + timestamp=timestamp, + ) + + +def _tool_events(*, job: BackgroundJob, action_log: AgentActionLog) -> list[AgentJobEvent]: + events = [ + AgentJobEvent( + id=f"{job.job_id}:tool:{action_log.action_log_id}:started", + job_id=str(job.job_id), + task_type=str(job.task_type), + type="tool.started", + status=str(job.status), + agent_phase=job.agent_phase, + message=_tool_started_message(action_log), + data=_tool_event_data(action_log, include_output=False), + timestamp=action_log.started_at, + ) + ] + if action_log.status in {"succeeded", "failed"} and action_log.finished_at is not None: + events.append( + AgentJobEvent( + id=f"{job.job_id}:tool:{action_log.action_log_id}:{action_log.status}", + job_id=str(job.job_id), + task_type=str(job.task_type), + type="tool.succeeded" if action_log.status == "succeeded" else "tool.failed", + status=str(job.status), + agent_phase=job.agent_phase, + message=_tool_finished_message(action_log), + data=_tool_event_data(action_log, include_output=True), + timestamp=action_log.finished_at, + ) + ) + return events + + +def _job_status_message(job: BackgroundJob) -> str: + status = str(job.status) + if status == "pending": + return "任务已排队。" + if status == "running": + return "正在规划任务。" if job.task_type == "agent.plan" else "正在执行计划。" + if status == "succeeded": + result = dict(job.result or {}) + answer = result.get("answer") + if isinstance(answer, str) and answer.strip(): + return "答案已生成。" + return "任务已完成。" + if status == "failed": + return str(job.error_message or "任务失败。") + if status == "canceled": + return "任务已取消。" + return "任务状态已更新。" + + +def _job_status_data(job: BackgroundJob) -> dict[str, object]: + data: dict[str, object] = {} + if job.status in {"succeeded", "failed", "canceled"}: + data["result"] = dict(job.result or {}) + if job.error_message: + data["errorMessage"] = job.error_message + return data + + +def _tool_started_message(action_log: AgentActionLog) -> str: + if action_log.tool_name == "drive.countFiles": + inputs = dict(action_log.inputs_json or {}) + search = str(inputs.get("search") or "").strip() + category = str(inputs.get("category") or "").strip() + target = "视频文件" if category == "video" else "文件" + if search: + return f"正在读取名称包含“{search}”的{target}数量。" + return f"正在读取{target}数量。" + return f"正在调用 {action_log.tool_name}。" + + +def _tool_finished_message(action_log: AgentActionLog) -> str: + if action_log.status == "failed": + return str(action_log.error_message or f"{action_log.tool_name} 调用失败。") + if action_log.tool_name == "drive.countFiles": + outputs = dict(action_log.outputs_json or {}) + total_items = int(outputs.get("totalItems") or 0) + return f"读取完成,匹配 {total_items} 个文件。" + return f"{action_log.tool_name} 已完成。" + + +def _tool_event_data(action_log: AgentActionLog, *, include_output: bool) -> dict[str, object]: + data: dict[str, object] = { + "step": int(action_log.step_no), + "tool": str(action_log.tool_name), + "input": dict(action_log.inputs_json or {}), + } + if include_output: + data["output"] = dict(action_log.outputs_json or {}) + if action_log.duration_ms is not None: + data["durationMs"] = int(action_log.duration_ms) + if action_log.error_message: + data["errorMessage"] = action_log.error_message + return data + + +def _envelope_to_job_event(env: AgentEventEnvelope) -> AgentJobEvent | None: + if env.event_type.startswith("agent.inbox."): + return None + if env.event_type not in get_args(AgentJobEventType): + return None + payload = dict(env.payload or {}) + data = payload.get("data") + return AgentJobEvent( + id=env.event_id or f"{env.job_id}:{env.event_type}:{env.emitted_at.isoformat()}", + job_id=str(env.job_id), + task_type=str(payload.get("taskType") or "agent.execute"), + type=env.event_type, # type: ignore[arg-type] + status=str(payload.get("status") or "running"), + agent_phase=payload.get("agentPhase"), + message=str(payload.get("message") or ""), + data=dict(data) if isinstance(data, dict) else payload, + timestamp=env.emitted_at, + ) + + +def _format_sse_event(event: AgentJobEvent) -> str: + payload = event.model_dump(by_alias=True, mode="json") + return ( + f"id: {event.id}\n" + f"event: {event.type}\n" + f"data: {json.dumps(payload, ensure_ascii=False, separators=(',', ':'))}\n\n" ) - return api_success(data=data.model_dump(by_alias=True), message="Job canceled") __all__ = ["router"] diff --git a/app/src/fileflash/routers/files.py b/app/src/fileflash/routers/files.py index e54c99a..8938b4d 100644 --- a/app/src/fileflash/routers/files.py +++ b/app/src/fileflash/routers/files.py @@ -9,7 +9,13 @@ from jwt import InvalidTokenError from starlette.background import BackgroundTask -from ..core.deps import get_archive_service, get_current_user, get_file_service, get_settings_dep +from ..core.deps import ( + get_archive_service, + get_current_user, + get_download_rate_limit_service, + get_file_service, + get_settings_dep, +) from ..core.errors import ApiError, api_success from ..core.security import create_file_preview_token, decode_file_preview_token from ..core.settings import Settings @@ -26,11 +32,19 @@ ) from ..schemas.job import to_background_job_response from ..services.archive import ArchiveService +from ..services.download_rate_limit import DownloadRateLimitService from ..services.file import FileService router = APIRouter(prefix="/files", tags=["files"]) +def _content_length(headers: dict[str, str]) -> int: + try: + return max(0, int(headers.get("Content-Length") or 0)) + except ValueError: + return 0 + + @router.get("") async def list_files( folder_id: str | None = Query(None, alias="folderId"), @@ -112,11 +126,14 @@ async def batch_download_files( payload: BatchDownloadRequest, current_user: User = Depends(get_current_user), file_service: FileService = Depends(get_file_service), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): - archive_path, archive_name = await file_service.create_batch_download_archive( + plan = await file_service.create_batch_download_plan( user_id=current_user.user_id, payload=payload, ) + await download_limiter.enforce_user(user=current_user, bytes_count=plan.estimated_bytes) + archive_path, archive_name = await file_service.create_batch_download_archive_from_plan(plan=plan) return FileResponse( archive_path, media_type="application/zip", @@ -146,12 +163,14 @@ async def download_file( range_header: str | None = Header(default=None, alias="Range"), current_user: User = Depends(get_current_user), file_service: FileService = Depends(get_file_service), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): result = await file_service.get_download_stream( user_id=current_user.user_id, file_id=file_id, range_header=range_header, ) + await download_limiter.enforce_user(user=current_user, bytes_count=_content_length(result.headers)) return StreamingResponse( result.stream, media_type=result.content_type, @@ -166,12 +185,14 @@ async def preview_file( range_header: str | None = Header(default=None, alias="Range"), current_user: User = Depends(get_current_user), file_service: FileService = Depends(get_file_service), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): result = await file_service.get_preview_stream( user_id=current_user.user_id, file_id=file_id, range_header=range_header, ) + await download_limiter.enforce_user(user=current_user, bytes_count=_content_length(result.headers)) return StreamingResponse( result.stream, media_type=result.content_type, @@ -216,6 +237,7 @@ async def preview_file_stream( range_header: str | None = Header(default=None, alias="Range"), file_service: FileService = Depends(get_file_service), settings: Settings = Depends(get_settings_dep), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): try: payload = decode_file_preview_token(token, settings) @@ -232,6 +254,7 @@ async def preview_file_stream( file_id=file_id, range_header=range_header, ) + await download_limiter.enforce_user_id(user_id=user_id, bytes_count=_content_length(result.headers)) return StreamingResponse( result.stream, media_type=result.content_type, diff --git a/app/src/fileflash/routers/shares.py b/app/src/fileflash/routers/shares.py index 4341854..4ac18c0 100644 --- a/app/src/fileflash/routers/shares.py +++ b/app/src/fileflash/routers/shares.py @@ -3,7 +3,14 @@ from fastapi import APIRouter, Depends, Header from fastapi.responses import StreamingResponse -from ..core.deps import get_client_ip, get_share_service, get_user_agent, get_current_user, require_verified_user +from ..core.deps import ( + get_client_ip, + get_current_user, + get_download_rate_limit_service, + get_share_service, + get_user_agent, + require_verified_user, +) from ..core.errors import api_success from ..core.http_headers import build_content_disposition from ..models.tables_identity import User @@ -14,6 +21,7 @@ SaveShareRequest, UpdateShareSettingsRequest, ) +from ..services.download_rate_limit import DownloadRateLimitService from ..services.share import ShareService router = APIRouter(prefix="/shares", tags=["shares"]) @@ -85,6 +93,19 @@ def _sanitize_stream_headers( return sanitized +def _content_length(headers: dict[str, str] | None) -> int: + if not headers: + return 0 + for key, value in headers.items(): + if key.lower() != "content-length": + continue + try: + return max(0, int(value)) + except ValueError: + return 0 + return 0 + + @router.post("") async def create_share( payload: CreateShareRequest, @@ -195,6 +216,7 @@ async def download_shared_file( client_ip: str = Depends(get_client_ip), user_agent: str | None = Depends(get_user_agent), share_service: ShareService = Depends(get_share_service), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): token = _extract_bearer_token(authorization) if not token: @@ -210,6 +232,10 @@ async def download_shared_file( range_header=range_header, ip_address=client_ip, user_agent=user_agent, + rate_limit_check=lambda bytes_count: download_limiter.enforce_share_ip( + client_ip=client_ip, + bytes_count=bytes_count, + ), ) else: raw = await share_service.get_shared_file_stream( @@ -220,6 +246,8 @@ async def download_shared_file( user_agent=user_agent, ) stream, filename, content_type, status_code, headers = _extract_share_stream(tuple(raw)) + if not hasattr(share_service, "get_shared_file_download_stream_response"): + await download_limiter.enforce_share_ip(client_ip=client_ip, bytes_count=_content_length(headers)) response_headers = _sanitize_stream_headers(headers=headers, filename=filename, disposition="attachment") return StreamingResponse(stream, media_type=content_type, headers=response_headers, status_code=status_code) @@ -232,6 +260,7 @@ async def preview_shared_file( client_ip: str = Depends(get_client_ip), user_agent: str | None = Depends(get_user_agent), share_service: ShareService = Depends(get_share_service), + download_limiter: DownloadRateLimitService = Depends(get_download_rate_limit_service), ): token = _extract_bearer_token(authorization) if not token: @@ -247,6 +276,10 @@ async def preview_shared_file( range_header=range_header, ip_address=client_ip, user_agent=user_agent, + rate_limit_check=lambda bytes_count: download_limiter.enforce_share_ip( + client_ip=client_ip, + bytes_count=bytes_count, + ), ) else: raw = await share_service.get_shared_file_stream( @@ -257,6 +290,8 @@ async def preview_shared_file( user_agent=user_agent, ) stream, filename, content_type, status_code, headers = _extract_share_stream(tuple(raw)) + if not hasattr(share_service, "get_shared_file_download_stream_response"): + await download_limiter.enforce_share_ip(client_ip=client_ip, bytes_count=_content_length(headers)) response_headers = _sanitize_stream_headers(headers=headers, filename=filename, disposition="inline") return StreamingResponse(stream, media_type=content_type, headers=response_headers, status_code=status_code) diff --git a/app/src/fileflash/schemas/__init__.py b/app/src/fileflash/schemas/__init__.py index c970aae..1e577d9 100644 --- a/app/src/fileflash/schemas/__init__.py +++ b/app/src/fileflash/schemas/__init__.py @@ -1,22 +1,10 @@ -from .auth import ( - ForgotPasswordRequest, - ForgotPasswordResponse, - LoginRequest, - RegisterResponseData, - RegisterRequest, - ResetPasswordRequest, - TokenResponse, - VerifyEmailRequest, -) -from .agent_skill import ( - AgentSkillItem, - CreateAgentSkillRequest, - ImportAgentSkillItem, - ImportAgentSkillResult, - ImportAgentSkillsRequest, - ImportAgentSkillsResponse, - ListAgentSkillsQuery, - UpdateAgentSkillRequest, +from .admin.files import ( + AdminFileAuditDetail, + AdminFileAuditItem, + AdminFileAuditOwner, + AdminFileLatestScan, + ListAdminFilesQuery, + RescanResponse, ) from .agent import ( AgentApproval, @@ -25,7 +13,12 @@ AgentDataPolicy, AgentExecutionResult, AgentHints, + AgentInboxMessageKind, + AgentInboxMessageRequest, + AgentInboxMessageResponse, + AgentJobEvent, AgentPlanContext, + AgentPlanningEvidence, AgentPlanResult, AgentProposedAction, AgentReasoningEffort, @@ -35,20 +28,32 @@ PlanAgentRequest, PlanAgentResponse, ) -from .common import ApiResponse, CamelModel, PageQuery, PaginatedData, PaginationMeta -from .admin.files import ( - AdminFileAuditDetail, - AdminFileAuditItem, - AdminFileAuditOwner, - AdminFileLatestScan, - ListAdminFilesQuery, - RescanResponse, +from .agent_skill import ( + AgentSkillItem, + CreateAgentSkillRequest, + ImportAgentSkillItem, + ImportAgentSkillResult, + ImportAgentSkillsRequest, + ImportAgentSkillsResponse, + ListAgentSkillsQuery, + UpdateAgentSkillRequest, +) +from .auth import ( + ForgotPasswordRequest, + ForgotPasswordResponse, + LoginRequest, + RegisterRequest, + RegisterResponseData, + ResetPasswordRequest, + TokenResponse, + VerifyEmailRequest, ) +from .common import ApiResponse, CamelModel, PageQuery, PaginatedData, PaginationMeta from .file import ( BatchDownloadRequest, - BatchMoveItemResult, BatchFilesRequest, BatchFilesResponse, + BatchMoveItemResult, ContentItem, CopyFileRequest, CopyFileResponse, @@ -97,12 +102,6 @@ PermissionItem, UpdatePermissionRequest, ) -from .registration_email_domain_rule import ( - CreateRegistrationEmailDomainRuleRequest, - ListRegistrationEmailDomainRulesQuery, - RegistrationEmailDomainRuleItem, - UpdateRegistrationEmailDomainRuleRequest, -) from .recycle import ( ClearRecycleBinResponse, GetRecycleBinQuery, @@ -111,11 +110,17 @@ RestoreRecycleItemRequest, RestoreRecycleItemResponse, ) +from .registration_email_domain_rule import ( + CreateRegistrationEmailDomainRuleRequest, + ListRegistrationEmailDomainRulesQuery, + RegistrationEmailDomainRuleItem, + UpdateRegistrationEmailDomainRuleRequest, +) from .share import ( + AcceptSharedItemResponse, AccessShareRequest, AccessShareResponseData, AccessUrls, - AcceptSharedItemResponse, CreateShareRequest, DeleteShareResponse, GetSharedItemsQuery, @@ -178,7 +183,12 @@ "AgentDataPolicy", "AgentExecutionResult", "AgentHints", + "AgentInboxMessageKind", + "AgentInboxMessageRequest", + "AgentInboxMessageResponse", + "AgentJobEvent", "AgentPlanContext", + "AgentPlanningEvidence", "AgentPlanResult", "AgentProposedAction", "AgentReasoningEffort", diff --git a/app/src/fileflash/schemas/admin/users.py b/app/src/fileflash/schemas/admin/users.py index a62d4f8..9eab779 100644 --- a/app/src/fileflash/schemas/admin/users.py +++ b/app/src/fileflash/schemas/admin/users.py @@ -1,12 +1,28 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime, timedelta from typing import Literal +from pydantic import Field + from ..common import CamelModel, PageQuery ExternalUserStatus = Literal["active", "suspended", "pending_verification"] +DEFAULT_USAGE_WINDOW = timedelta(days=7) +MAX_USAGE_WINDOW = timedelta(days=90) + + +def _normalize_datetime(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value.astimezone(UTC) + + +class AdminUserUsageStats(CamelModel): + traffic_bytes: int = Field(ge=0) + agent_tokens: int = Field(ge=0) + class AdminUserItem(CamelModel): user_id: str @@ -22,6 +38,7 @@ class AdminUserItem(CamelModel): last_login_at: datetime | None = None last_active_at: datetime | None = None created_at: datetime + usage_stats: AdminUserUsageStats class ListAdminUsersQuery(PageQuery): @@ -30,6 +47,23 @@ class ListAdminUsersQuery(PageQuery): role: Literal["USER", "ADMIN"] | None = None sort: Literal["username", "createdAt", "storageUsed"] = "createdAt" order: Literal["asc", "desc"] = "desc" + usage_from: datetime | None = None + usage_to: datetime | None = None + + def resolve_usage_window(self, *, now: datetime | None = None) -> tuple[datetime, datetime]: + resolved_now = _normalize_datetime(now or datetime.now(UTC)) + if self.usage_from is None and self.usage_to is None: + return resolved_now - DEFAULT_USAGE_WINDOW, resolved_now + if self.usage_from is None or self.usage_to is None: + raise ValueError("usageFrom and usageTo must be provided together") + + usage_from = _normalize_datetime(self.usage_from) + usage_to = _normalize_datetime(self.usage_to) + if usage_from > usage_to: + raise ValueError("usageFrom must be earlier than or equal to usageTo") + if usage_to - usage_from > MAX_USAGE_WINDOW: + raise ValueError("usage window must not exceed 90 days") + return usage_from, usage_to class UpdateUserStatusRequest(CamelModel): @@ -44,6 +78,7 @@ class UpdateUserStatusResponse(CamelModel): __all__ = [ "AdminUserItem", + "AdminUserUsageStats", "ListAdminUsersQuery", "UpdateUserStatusRequest", "UpdateUserStatusResponse", diff --git a/app/src/fileflash/schemas/agent.py b/app/src/fileflash/schemas/agent.py index e253be0..97a25dd 100644 --- a/app/src/fileflash/schemas/agent.py +++ b/app/src/fileflash/schemas/agent.py @@ -20,6 +20,23 @@ "failed", "canceled", ] +AgentJobEventType = Literal[ + "job.queued", + "job.running", + "plan.ready", + "tool.started", + "tool.succeeded", + "tool.failed", + "tool.partial", + "agent.thinking", + "agent.progress", + "agent.ask", + "agent.paused", + "agent.resumed", + "job.succeeded", + "job.failed", + "job.canceled", +] class AgentDataPolicy(CamelModel): @@ -77,6 +94,13 @@ class AgentChosenSkill(CamelModel): name: str +class AgentPlanningEvidence(CamelModel): + step: int = Field(ge=1) + tool: str = Field(min_length=1, max_length=120) + input: dict[str, Any] = Field(default_factory=dict) + output_preview: dict[str, Any] = Field(default_factory=dict) + + class AgentPlanResult(CamelModel): plan_job_id: str plan_hash: str @@ -85,6 +109,7 @@ class AgentPlanResult(CamelModel): summary: str requires_confirmation: bool cost_estimate: AgentCostEstimate + planning_evidence: list[AgentPlanningEvidence] | None = None class AgentApproval(CamelModel): @@ -123,6 +148,42 @@ class AgentExecutionResult(CamelModel): finished_at: datetime +class AgentJobEvent(CamelModel): + id: str + job_id: str + task_type: str + type: AgentJobEventType + status: str + agent_phase: str | None = None + message: str + data: dict[str, Any] = Field(default_factory=dict) + timestamp: datetime + + +AgentInboxMessageKind = Literal[ + "reply", + "control.pause", + "control.resume", + "control.approve", + "control.deny", + "control.skip", + "control.cancel", +] + + +class AgentInboxMessageRequest(CamelModel): + kind: AgentInboxMessageKind + reply_to: str | None = None + value: Any = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AgentInboxMessageResponse(CamelModel): + inbox_message_id: str + kind: AgentInboxMessageKind + accepted_at: datetime + + __all__ = [ "AgentActionSideEffect", "AgentApproval", @@ -132,8 +193,14 @@ class AgentExecutionResult(CamelModel): "AgentExecutionPolicy", "AgentExecutionResult", "AgentHints", + "AgentInboxMessageKind", + "AgentInboxMessageRequest", + "AgentInboxMessageResponse", "AgentJobPhase", + "AgentJobEvent", + "AgentJobEventType", "AgentPlanContext", + "AgentPlanningEvidence", "AgentPlanResult", "AgentProposedAction", "AgentReasoningEffort", diff --git a/app/src/fileflash/services/__init__.py b/app/src/fileflash/services/__init__.py index e523bba..551c022 100644 --- a/app/src/fileflash/services/__init__.py +++ b/app/src/fileflash/services/__init__.py @@ -2,6 +2,7 @@ from .agent import ExecuteService, McpService, MemoryService, PlanService, SessionService, SettingsService, SkillService from .auth import AuthService from .background_jobs import BackgroundJobService +from .download_rate_limit import DownloadRateLimitService from .email_delivery import VerificationEmailDeliveryService from .file import FileService from .folder import FolderService @@ -17,6 +18,7 @@ "AuthService", "ArchiveService", "BackgroundJobService", + "DownloadRateLimitService", "VerificationEmailDeliveryService", "ExecuteService", "FileService", diff --git a/app/src/fileflash/services/admin/users.py b/app/src/fileflash/services/admin/users.py index e53a881..7c86a9a 100644 --- a/app/src/fileflash/services/admin/users.py +++ b/app/src/fileflash/services/admin/users.py @@ -6,9 +6,16 @@ from sqlalchemy.ext.asyncio import AsyncSession from ...core.errors import ApiError -from ...models.enums import UserRole, UserStatus +from ...models.enums import UploadTaskStatus, UserRole, UserStatus from ...models.tables_identity import User, UserSession -from ...schemas.admin.users import AdminUserItem, ListAdminUsersQuery, UpdateUserStatusResponse +from ...models.tables_storage import UploadTask +from ...models.tables_worker import BackgroundJob +from ...schemas.admin.users import ( + AdminUserItem, + AdminUserUsageStats, + ListAdminUsersQuery, + UpdateUserStatusResponse, +) from ...schemas.common import PaginatedData, PaginationMeta from ._status import external_to_internal, internal_to_external @@ -42,8 +49,25 @@ async def list_users(self, *, query: ListAdminUsersQuery) -> PaginatedData[Admin offset = (query.page - 1) * query.per_page rows = list(await self.db.scalars(statement.offset(offset).limit(query.per_page))) - last_seen_map = await self._collect_last_seen([int(row.user_id) for row in rows]) - items = [self._to_item(row, last_seen_map.get(int(row.user_id))) for row in rows] + user_ids = [int(row.user_id) for row in rows] + last_seen_map = await self._collect_last_seen(user_ids) + usage_from, usage_to = self._resolve_usage_window(query) + usage_map = await self._collect_usage_stats( + user_ids=user_ids, + usage_from=usage_from, + usage_to=usage_to, + ) + items = [ + self._to_item( + row, + last_seen_map.get(int(row.user_id)), + usage_map.get( + int(row.user_id), + AdminUserUsageStats(traffic_bytes=0, agent_tokens=0), + ), + ) + for row in rows + ] return PaginatedData( items=items, pagination=PaginationMeta( @@ -56,6 +80,71 @@ async def list_users(self, *, query: ListAdminUsersQuery) -> PaginatedData[Admin ), ) + @staticmethod + def _resolve_usage_window(query: ListAdminUsersQuery) -> tuple[datetime, datetime]: + try: + return query.resolve_usage_window() + except ValueError as exc: + raise ApiError(status_code=400, code=400, message=str(exc)) from exc + + async def _collect_usage_stats( + self, + *, + user_ids: list[int], + usage_from: datetime, + usage_to: datetime, + ) -> dict[int, AdminUserUsageStats]: + if not user_ids: + return {} + + traffic_rows = await self.db.execute( + select(UploadTask.user_id, func.coalesce(func.sum(UploadTask.total_size), 0)) + .where( + and_( + UploadTask.user_id.in_(user_ids), + UploadTask.status == UploadTaskStatus.COMPLETED, + UploadTask.completed_at.is_not(None), + UploadTask.completed_at >= usage_from, + UploadTask.completed_at <= usage_to, + ) + ) + .group_by(UploadTask.user_id) + ) + stats: dict[int, AdminUserUsageStats] = { + int(user_id): AdminUserUsageStats(traffic_bytes=int(total or 0), agent_tokens=0) + for user_id, total in traffic_rows.all() + } + + token_expr = BackgroundJob.result["costEstimate"]["tokens"].as_integer() + agent_rows = await self.db.execute( + select( + BackgroundJob.requested_by, + func.coalesce(func.sum(func.coalesce(token_expr, 0)), 0), + ) + .where( + and_( + BackgroundJob.requested_by.in_(user_ids), + BackgroundJob.task_type == "agent.plan", + BackgroundJob.status == "succeeded", + BackgroundJob.finished_at.is_not(None), + BackgroundJob.finished_at >= usage_from, + BackgroundJob.finished_at <= usage_to, + ) + ) + .group_by(BackgroundJob.requested_by) + ) + for user_id, total in agent_rows.all(): + if user_id is None: + continue + key = int(user_id) + current = stats.get(key, AdminUserUsageStats(traffic_bytes=0, agent_tokens=0)) + stats[key] = AdminUserUsageStats( + traffic_bytes=current.traffic_bytes, + agent_tokens=int(total or 0), + ) + + return stats + async def set_status(self, *, user_id: int, external_status: str) -> UpdateUserStatusResponse: target = await self.db.get(User, user_id) if target is None or target.deleted_at is not None: @@ -113,7 +202,11 @@ async def _collect_last_seen(self, user_ids: list[int]) -> dict[int, datetime]: return {int(user_id): seen for user_id, seen in rows.all()} @staticmethod - def _to_item(row: User, last_active_at: datetime | None) -> AdminUserItem: + def _to_item( + row: User, + last_active_at: datetime | None, + usage_stats: AdminUserUsageStats, + ) -> AdminUserItem: limit = max(int(row.storage_limit), 1) return AdminUserItem( user_id=str(row.user_id), @@ -129,6 +222,7 @@ def _to_item(row: User, last_active_at: datetime | None) -> AdminUserItem: last_login_at=row.last_login_at, last_active_at=last_active_at, created_at=row.created_at, + usage_stats=usage_stats, ) diff --git a/app/src/fileflash/services/agent/skill_service.py b/app/src/fileflash/services/agent/skill_service.py index cc06582..634c2d9 100644 --- a/app/src/fileflash/services/agent/skill_service.py +++ b/app/src/fileflash/services/agent/skill_service.py @@ -23,6 +23,7 @@ UpdateAgentSkillRequest, ) from ...schemas.common import PaginatedData, PaginationMeta +from ...agents.harness.tool_registry import REGISTRY class SkillService: @@ -66,6 +67,19 @@ def _coerce_tool_whitelist(raw: Any) -> list[str]: return [str(item) for item in raw if isinstance(item, (str, int, float))] return [] + @staticmethod + def _validate_tool_whitelist(raw: list[str]) -> list[str]: + tools = [str(item).strip() for item in raw if str(item).strip()] + unknown = REGISTRY.unknown_names(tools) + if unknown: + raise ApiError( + status_code=422, + code=422, + message="Unknown agent tool in toolWhitelist", + data={"unknownTools": sorted(unknown)}, + ) + return tools + @classmethod def _to_item(cls, entity: AgentSkill) -> AgentSkillItem: visibility_value = ( @@ -140,13 +154,14 @@ async def get_skill(self, *, user_id: int, skill_key: str) -> AgentSkillItem: async def create_custom_skill(self, *, user_id: int, payload: CreateAgentSkillRequest) -> AgentSkillItem: skill_key = await self._generate_unique_user_skill_key(user_id=user_id, name=payload.name) + tool_whitelist = self._validate_tool_whitelist(payload.tool_whitelist) entity = await self.skills.create( values={ "skill_key": skill_key, "name": payload.name, "description": payload.description, "triggers_text": payload.triggers_text, - "tool_whitelist_json": payload.tool_whitelist, + "tool_whitelist_json": tool_whitelist, "plan_template_json": payload.plan_template, "inputs_schema_json": payload.inputs_schema, "outputs_schema_json": payload.outputs_schema, @@ -178,7 +193,9 @@ async def update_custom_skill(self, *, user_id: int, skill_key: str, payload: Up if "triggers_text" in fields_set: values["triggers_text"] = payload.triggers_text if "tool_whitelist" in fields_set: - values["tool_whitelist_json"] = payload.tool_whitelist or [] + values["tool_whitelist_json"] = self._validate_tool_whitelist( + payload.tool_whitelist or [] + ) if "plan_template" in fields_set: values["plan_template_json"] = payload.plan_template or {} if "inputs_schema" in fields_set: @@ -232,12 +249,13 @@ async def import_global_skills(self, *, payload: ImportAgentSkillsRequest) -> Im results: list[ImportAgentSkillResult] = [] for item in payload.items: + tool_whitelist = self._validate_tool_whitelist(item.tool_whitelist) existing = existing_by_key.get(item.skill_key) values = { "name": item.name, "description": item.description, "triggers_text": item.triggers_text, - "tool_whitelist_json": item.tool_whitelist, + "tool_whitelist_json": tool_whitelist, "plan_template_json": item.plan_template, "inputs_schema_json": item.inputs_schema, "outputs_schema_json": item.outputs_schema, diff --git a/app/src/fileflash/services/download_rate_limit.py b/app/src/fileflash/services/download_rate_limit.py new file mode 100644 index 0000000..7f1870b --- /dev/null +++ b/app/src/fileflash/services/download_rate_limit.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from ..core.errors import ApiError +from ..core.settings import Settings +from ..models.enums import UserRole +from ..models.tables_identity import User +from .rate_limiter import RedisRateLimiter + + +class DownloadRateLimitService: + def __init__( + self, + *, + db: AsyncSession, + settings: Settings, + rate_limiter: RedisRateLimiter, + ) -> None: + self.db = db + self.settings = settings + self.rate_limiter = rate_limiter + + async def enforce_user(self, *, user: User, bytes_count: int) -> None: + if user.role == UserRole.ADMIN: + return + await self._enforce(scope=f"user:{int(user.user_id)}", bytes_count=bytes_count) + + async def enforce_user_id(self, *, user_id: int, bytes_count: int) -> None: + user = await self.db.get(User, user_id) + if user is not None and user.role == UserRole.ADMIN: + return + await self._enforce(scope=f"user:{int(user_id)}", bytes_count=bytes_count) + + async def enforce_share_ip(self, *, client_ip: str, bytes_count: int) -> None: + await self._enforce(scope=f"share-ip:{client_ip}", bytes_count=bytes_count) + + async def _enforce(self, *, scope: str, bytes_count: int) -> None: + window_seconds = max(1, int(self.settings.download_rate_window_seconds)) + request_limit = max(1, int(self.settings.download_rate_limit_requests)) + byte_limit = max(1, int(self.settings.download_rate_limit_bytes)) + normalized_bytes = max(0, int(bytes_count)) + + request_allowed = await self.rate_limiter.allow( + key=f"download-rate:{scope}:requests", + limit=request_limit, + window_seconds=window_seconds, + ) + if not request_allowed: + raise ApiError(status_code=429, code=429, message="Download rate limit exceeded") + + bytes_allowed = await self.rate_limiter.allow_weighted( + key=f"download-rate:{scope}:bytes", + limit=byte_limit, + window_seconds=window_seconds, + weight=normalized_bytes, + ) + if not bytes_allowed: + raise ApiError(status_code=429, code=429, message="Download bandwidth limit exceeded") + + +__all__ = ["DownloadRateLimitService"] diff --git a/app/src/fileflash/services/file.py b/app/src/fileflash/services/file.py index 453abdb..7d18814 100644 --- a/app/src/fileflash/services/file.py +++ b/app/src/fileflash/services/file.py @@ -74,6 +74,12 @@ class DownloadStreamResult: headers: dict[str, str] +@dataclass(slots=True) +class BatchDownloadPlan: + files: list[tuple[File, StorageObject, str]] + estimated_bytes: int + + @dataclass(slots=True) class ResolvedStreamObject: storage_object: StorageObject @@ -358,6 +364,15 @@ async def create_batch_download_archive( user_id: int, payload: BatchDownloadRequest, ) -> tuple[str, str]: + plan = await self.create_batch_download_plan(user_id=user_id, payload=payload) + return await self.create_batch_download_archive_from_plan(plan=plan) + + async def create_batch_download_plan( + self, + *, + user_id: int, + payload: BatchDownloadRequest, + ) -> BatchDownloadPlan: if self.storage is None: raise ApiError(status_code=503, code=503, message="Object storage is unavailable") @@ -427,6 +442,31 @@ async def create_batch_download_archive( if not files_with_storage: raise ApiError(status_code=404, code=404, message="No downloadable files found") + files = [ + ( + file_row, + storage_object, + self._safe_zip_path(file_paths.get(int(file_row.file_id), file_row.file_name)), + ) + for file_row, storage_object in files_with_storage + ] + estimated_bytes = sum( + int(storage_object.object_size or file_row.file_size or 0) + for file_row, storage_object, _zip_path in files + ) + return BatchDownloadPlan(files=files, estimated_bytes=max(0, estimated_bytes)) + + async def create_batch_download_archive_from_plan( + self, + *, + plan: BatchDownloadPlan, + ) -> tuple[str, str]: + if self.storage is None: + raise ApiError(status_code=503, code=503, message="Object storage is unavailable") + + if not plan.files: + raise ApiError(status_code=404, code=404, message="No downloadable files found") + archive_name = f"fileflash-download-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}.zip" tmp = tempfile.NamedTemporaryFile(prefix="fileflash-download-", suffix=".zip", delete=False) tmp_path = tmp.name @@ -434,8 +474,7 @@ async def create_batch_download_archive( try: with zipfile.ZipFile(tmp_path, mode="w", compression=zipfile.ZIP_DEFLATED, allowZip64=True) as archive: - for file_row, storage_object in files_with_storage: - zip_path = self._safe_zip_path(file_paths.get(int(file_row.file_id), file_row.file_name)) + for _file_row, storage_object, zip_path in plan.files: with archive.open(zip_path, mode="w") as entry: async for chunk in self.storage.iter_object( bucket_name=storage_object.bucket_name, diff --git a/app/src/fileflash/services/rate_limiter.py b/app/src/fileflash/services/rate_limiter.py index 9ced38d..00450a9 100644 --- a/app/src/fileflash/services/rate_limiter.py +++ b/app/src/fileflash/services/rate_limiter.py @@ -21,13 +21,17 @@ async def _client(self) -> Redis | None: return self._redis async def allow(self, key: str, limit: int, window_seconds: int) -> bool: + return await self.allow_weighted(key=key, limit=limit, window_seconds=window_seconds, weight=1) + + async def allow_weighted(self, key: str, limit: int, window_seconds: int, weight: int) -> bool: client = await self._client() if client is None: return True + normalized_weight = max(0, int(weight)) try: - current = await client.incr(key) - if current == 1: + current = await client.incrby(key, normalized_weight) + if current == normalized_weight: await client.expire(key, window_seconds) return current <= limit except RedisError: diff --git a/app/src/fileflash/services/share.py b/app/src/fileflash/services/share.py index 4d0ce35..a2b0755 100644 --- a/app/src/fileflash/services/share.py +++ b/app/src/fileflash/services/share.py @@ -2,7 +2,7 @@ import logging import secrets -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable from datetime import UTC, datetime from math import ceil from pathlib import Path @@ -422,6 +422,7 @@ async def get_shared_file_download_stream_response( range_header: str | None, ip_address: str, user_agent: str | None, + rate_limit_check: Callable[[int], Awaitable[None]] | None = None, ) -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: async def _operation() -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: await apply_local_lock_timeout(self.db) @@ -449,6 +450,11 @@ async def _operation() -> tuple[AsyncIterator[bytes], str, str, int, dict[str, s if object_size <= 0: raise ApiError(status_code=404, code=404, message="Shared file content not found") + byte_range = self._parse_range_header(range_header=range_header, file_size=object_size) + bytes_to_send = object_size if byte_range is None else byte_range[1] - byte_range[0] + 1 + if rate_limit_check is not None: + await rate_limit_check(bytes_to_send) + if action == "download": await self.db.execute( update(Share) @@ -479,7 +485,6 @@ async def _operation() -> tuple[AsyncIterator[bytes], str, str, int, dict[str, s ), } - byte_range = self._parse_range_header(range_header=range_header, file_size=object_size) if byte_range is None: headers["Content-Length"] = str(object_size) return ( diff --git a/app/tests/test_admin_system_service.py b/app/tests/test_admin_system_service.py index f5cbe7c..2c94b41 100644 --- a/app/tests/test_admin_system_service.py +++ b/app/tests/test_admin_system_service.py @@ -35,3 +35,19 @@ async def test_health_hash_computation_enabled_follows_settings() -> None: ) enabled_health = await enabled_service.health() assert enabled_health.hash_computation_enabled is True + + +@pytest.mark.asyncio +async def test_rate_limit_status_uses_auth_default_limits() -> None: + service = AdminSystemService( + db=DummySession(), + settings=make_settings(), + ) + + status = await service.rate_limit_status() + rules_by_scope = {rule.scope: rule for rule in status.rules} + + assert rules_by_scope["auth.login"].limit == 30 + assert rules_by_scope["auth.login"].window_seconds == 300 + assert rules_by_scope["auth.register"].limit == 12 + assert rules_by_scope["auth.register"].window_seconds == 600 diff --git a/app/tests/test_admin_users_routes.py b/app/tests/test_admin_users_routes.py index b5cccae..bff1702 100644 --- a/app/tests/test_admin_users_routes.py +++ b/app/tests/test_admin_users_routes.py @@ -9,7 +9,7 @@ from fileflash.core.deps import get_admin_users_service, require_admin from fileflash.core.errors import ApiError, api_error_handler from fileflash.routers.admin_users import router as admin_users_router -from fileflash.schemas.admin.users import AdminUserItem, UpdateUserStatusResponse +from fileflash.schemas.admin.users import AdminUserItem, AdminUserUsageStats, UpdateUserStatusResponse from fileflash.schemas.common import PaginatedData, PaginationMeta @@ -29,6 +29,7 @@ async def list_users(self, *, query): # noqa: ANN001 last_login_at=None, last_active_at=None, created_at=datetime.now(UTC), + usage_stats=AdminUserUsageStats(traffic_bytes=1024, agent_tokens=42), ) return PaginatedData( items=[item], @@ -73,6 +74,7 @@ def test_admin_can_list_users() -> None: body = resp.json() assert body["success"] is True assert body["data"]["items"][0]["username"] == "alice" + assert body["data"]["items"][0]["usageStats"] == {"trafficBytes": 1024, "agentTokens": 42} def test_non_admin_gets_403() -> None: @@ -81,6 +83,32 @@ def test_non_admin_gets_403() -> None: assert resp.status_code == 403 +def test_usage_window_requires_both_bounds() -> None: + with _client(admin=True) as c: + resp = c.get("/api/v1/admin/users?usageFrom=2026-01-01T00:00:00Z") + assert resp.status_code == 400 + + +def test_usage_window_rejects_reversed_bounds() -> None: + with _client(admin=True) as c: + resp = c.get( + "/api/v1/admin/users" + "?usageFrom=2026-02-01T00:00:00Z" + "&usageTo=2026-01-01T00:00:00Z" + ) + assert resp.status_code == 400 + + +def test_usage_window_rejects_more_than_90_days() -> None: + with _client(admin=True) as c: + resp = c.get( + "/api/v1/admin/users" + "?usageFrom=2026-01-01T00:00:00Z" + "&usageTo=2026-04-02T00:00:00Z" + ) + assert resp.status_code == 400 + + def test_admin_can_patch_status() -> None: with _client(admin=True) as c: resp = c.patch("/api/v1/admin/users/42/status", json={"status": "suspended"}) diff --git a/app/tests/test_admin_users_service.py b/app/tests/test_admin_users_service.py index 9e455a8..1370023 100644 --- a/app/tests/test_admin_users_service.py +++ b/app/tests/test_admin_users_service.py @@ -42,6 +42,14 @@ def __init__(self) -> None: self.execute = AsyncMock() +class ResultRows: + def __init__(self, rows) -> None: # noqa: ANN001 + self._rows = rows + + def all(self): # noqa: ANN201 + return self._rows + + @pytest.mark.asyncio async def test_list_users_returns_paginated_items() -> None: session = DummySession() @@ -55,6 +63,41 @@ async def test_list_users_returns_paginated_items() -> None: assert result.pagination.total_items == 1 assert result.items[0].username == "alice" assert result.items[0].status == "active" + assert result.items[0].usage_stats.traffic_bytes == 0 + assert result.items[0].usage_stats.agent_tokens == 0 + + +def test_list_users_query_default_usage_window() -> None: + now = datetime(2026, 5, 26, 12, 0, tzinfo=UTC) + usage_from, usage_to = ListAdminUsersQuery().resolve_usage_window(now=now) + + assert usage_to == now + assert (usage_to - usage_from).days == 7 + + +@pytest.mark.asyncio +async def test_collect_usage_stats_aggregates_traffic_and_tokens() -> None: + session = DummySession() + session.execute = AsyncMock( + side_effect=[ + ResultRows([(1, 2048), (2, 4096)]), + ResultRows([(1, 1500), (3, None)]), + ] + ) + service = AdminUsersService(db=session) # type: ignore[arg-type] + + stats = await service._collect_usage_stats( + user_ids=[1, 2, 3], + usage_from=datetime(2026, 5, 1, tzinfo=UTC), + usage_to=datetime(2026, 5, 26, tzinfo=UTC), + ) + + assert stats[1].traffic_bytes == 2048 + assert stats[1].agent_tokens == 1500 + assert stats[2].traffic_bytes == 4096 + assert stats[2].agent_tokens == 0 + assert stats[3].traffic_bytes == 0 + assert stats[3].agent_tokens == 0 @pytest.mark.asyncio diff --git a/app/tests/test_agent_a_end_to_end.py b/app/tests/test_agent_a_end_to_end.py new file mode 100644 index 0000000..377fcbc --- /dev/null +++ b/app/tests/test_agent_a_end_to_end.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from test_agent_inbox_repository import InboxSession + +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.agents.runtime import execute_runner as execute_module +from fileflash.agents.runtime.execute_runner import AgentJobCanceled, ExecuteRunner +from fileflash.models import BackgroundJob +from fileflash.models.enums import AgentInboxKind + + +class RuntimeInboxSession(InboxSession): + async def refresh(self, _job: BackgroundJob) -> None: + return None + + async def rollback(self) -> None: + return None + + +def _execute_job() -> BackgroundJob: + now = datetime.now(UTC) + return BackgroundJob( + job_id=800, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "500", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": now.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=now, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.asyncio +async def test_user_pause_then_cancel_via_inbox(monkeypatch: pytest.MonkeyPatch): + action = { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace(proposed_actions_json=[action]) + ) + ), + ) + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), + ) + + session = RuntimeInboxSession() + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=session, event_bus=bus) # type: ignore[arg-type] + job = _execute_job() + seen_events: list[str] = [] + + await inbox.handle(job_id=int(job.job_id), kind=AgentInboxKind.CONTROL_PAUSE, payload={}) + await session.commit() + + async def cancel_when_paused() -> None: + async with bus.subscribe(job_id=int(job.job_id)) as stream: + while True: + event = await stream.next(timeout=2.0) + seen_events.append(event.event_type) + if event.event_type == "agent.paused": + await inbox.handle( + job_id=int(job.job_id), + kind=AgentInboxKind.CONTROL_CANCEL, + payload={}, + ) + await session.commit() + return + + listener = asyncio.create_task(cancel_when_paused()) + await asyncio.sleep(0) + with pytest.raises(AgentJobCanceled): + await ExecuteRunner(event_bus=bus).run(db=session, job=job) # type: ignore[arg-type] + await listener + + assert "agent.paused" in seen_events + assert job.cancel_requested_at is not None diff --git a/app/tests/test_agent_ask_protocol.py b/app/tests/test_agent_ask_protocol.py new file mode 100644 index 0000000..d935075 --- /dev/null +++ b/app/tests/test_agent_ask_protocol.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import asyncio + +import pytest +from test_agent_inbox_repository import InboxSession + +from fileflash.agents.harness.ask import AskProtocol, AskTimedOut +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.models.enums import AgentInboxKind, AgentInboxStatus + + +@pytest.mark.asyncio +async def test_ask_returns_when_reply_arrives() -> None: + session = InboxSession() + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=session, event_bus=bus) # type: ignore[arg-type] + protocol = AskProtocol(db=session, event_bus=bus, job_id=1) # type: ignore[arg-type] + + await protocol.start() + try: + async def reply_later() -> None: + for _ in range(20): + asks = [msg for msg in session.messages if msg.kind == AgentInboxKind.ASK] + if asks: + ask = asks[-1] + await inbox.handle( + job_id=1, + kind=AgentInboxKind.REPLY, + payload={"value": "A"}, + reply_to_id=int(ask.inbox_message_id), + ) + await session.commit() + return + await asyncio.sleep(0.01) + raise AssertionError("ask message was not created") + + replier = asyncio.create_task(reply_later()) + result = await protocol.ask( + prompt="choose", + schema={"choice": ["A", "B"]}, + timeout_sec=2.0, + ) + await replier + finally: + await protocol.aclose() + + assert result == "A" + + +@pytest.mark.asyncio +async def test_ask_times_out() -> None: + session = InboxSession() + bus = InMemoryAgentEventBus() + protocol = AskProtocol(db=session, event_bus=bus, job_id=1) # type: ignore[arg-type] + + await protocol.start() + try: + with pytest.raises(AskTimedOut): + await protocol.ask(prompt="?", schema={}, timeout_sec=0.1) + finally: + await protocol.aclose() + + asks = [msg for msg in session.messages if msg.kind == AgentInboxKind.ASK] + assert asks + assert asks[-1].status == AgentInboxStatus.TIMED_OUT diff --git a/app/tests/test_agent_event_bus.py b/app/tests/test_agent_event_bus.py new file mode 100644 index 0000000..6214640 --- /dev/null +++ b/app/tests/test_agent_event_bus.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from fileflash.agents.harness.event_bus import AgentEventEnvelope, InMemoryAgentEventBus + + +@pytest.mark.asyncio +async def test_subscriber_receives_published_event() -> None: + bus = InMemoryAgentEventBus() + envelope = AgentEventEnvelope( + job_id=42, + event_type="agent.ask", + payload={"prompt": "choose"}, + emitted_at=datetime.now(UTC), + ) + + async with bus.subscribe(job_id=42) as stream: + await bus.publish(envelope) + received = await stream.next(timeout=1.0) + + assert received == envelope + + +@pytest.mark.asyncio +async def test_only_subscribers_of_same_job_receive() -> None: + bus = InMemoryAgentEventBus() + own = AgentEventEnvelope( + job_id=1, + event_type="job.running", + payload={}, + emitted_at=datetime.now(UTC), + ) + other = AgentEventEnvelope( + job_id=2, + event_type="job.running", + payload={}, + emitted_at=datetime.now(UTC), + ) + + async with bus.subscribe(job_id=1) as stream: + await bus.publish(other) + await bus.publish(own) + first = await stream.next(timeout=1.0) + + assert first == own + + +@pytest.mark.asyncio +async def test_empty_subscriber_times_out() -> None: + bus = InMemoryAgentEventBus() + async with bus.subscribe(job_id=7) as stream: + with pytest.raises(TimeoutError): + await stream.next(timeout=0.1) + + +def test_event_envelope_json_serializes_nested_datetime_payload() -> None: + now = datetime.now(UTC).replace(microsecond=0) + envelope = AgentEventEnvelope( + job_id=9, + event_type="job.succeeded", + payload={ + "data": { + "result": { + "finishedAt": now, + "steps": [{"completedAt": now}], + } + } + }, + emitted_at=now, + event_id="evt-1", + ) + + decoded = AgentEventEnvelope.from_json(envelope.to_json()) + + assert decoded.job_id == envelope.job_id + assert decoded.event_type == envelope.event_type + assert decoded.event_id == "evt-1" + assert decoded.payload["data"]["result"]["finishedAt"] == now.isoformat() + assert decoded.payload["data"]["result"]["steps"][0]["completedAt"] == now.isoformat() diff --git a/app/tests/test_agent_inbox.py b/app/tests/test_agent_inbox.py new file mode 100644 index 0000000..7520ffa --- /dev/null +++ b/app/tests/test_agent_inbox.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import pytest +from test_agent_inbox_repository import InboxSession + +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.models import AgentInboxMessage +from fileflash.models.enums import AgentInboxKind +from fileflash.repositories import AgentInboxMessageRepository + + +@pytest.mark.asyncio +async def test_handle_reply_persists_and_publishes() -> None: + session = InboxSession() + repo = AgentInboxMessageRepository(session) # type: ignore[arg-type] + ask = await repo.create_ask(job_id=1, payload={"prompt": "?"}) + await session.commit() + + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=session, event_bus=bus) # type: ignore[arg-type] + + async with bus.subscribe(job_id=1) as stream: + msg = await inbox.handle( + job_id=1, + kind=AgentInboxKind.REPLY, + payload={"value": "yes"}, + reply_to_id=int(ask.inbox_message_id), + ) + await session.commit() + event = await stream.next(timeout=1.0) + + assert isinstance(msg, AgentInboxMessage) + assert msg.kind == AgentInboxKind.REPLY + assert event.event_type == "agent.inbox.reply" + assert event.payload["replyTo"] == str(ask.inbox_message_id) + assert event.payload["value"] == "yes" + + +@pytest.mark.asyncio +async def test_reply_with_unknown_ask_raises() -> None: + session = InboxSession() + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=session, event_bus=bus) # type: ignore[arg-type] + + with pytest.raises(ValueError): + await inbox.handle( + job_id=1, + kind=AgentInboxKind.REPLY, + payload={"value": "yes"}, + reply_to_id=999999, + ) diff --git a/app/tests/test_agent_inbox_model.py b/app/tests/test_agent_inbox_model.py new file mode 100644 index 0000000..b9fae17 --- /dev/null +++ b/app/tests/test_agent_inbox_model.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from fileflash.models import AgentInboxMessage +from fileflash.models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus + + +def test_agent_inbox_message_model_fields() -> None: + msg = AgentInboxMessage( + job_id=1, + role=AgentInboxRole.AGENT, + kind=AgentInboxKind.ASK, + payload_json={"prompt": "which one?", "schema": {}}, + status=AgentInboxStatus.WAITING, + created_at=datetime.now(UTC), + ) + + assert AgentInboxMessage.__tablename__ == "agent_inbox_message" + assert msg.kind == AgentInboxKind.ASK + assert msg.status == AgentInboxStatus.WAITING + assert msg.payload_json["prompt"] == "which one?" diff --git a/app/tests/test_agent_inbox_repository.py b/app/tests/test_agent_inbox_repository.py new file mode 100644 index 0000000..7f8f62d --- /dev/null +++ b/app/tests/test_agent_inbox_repository.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest + +from fileflash.models import AgentInboxMessage +from fileflash.models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus +from fileflash.repositories import AgentInboxMessageRepository + + +class InboxSession: + def __init__(self) -> None: + self.messages: list[AgentInboxMessage] = [] + self._next_id = 1 + self.commits = 0 + + def add(self, msg: AgentInboxMessage) -> None: + msg.inbox_message_id = self._next_id + self._next_id += 1 + self.messages.append(msg) + + async def flush(self) -> None: + return None + + async def commit(self) -> None: + self.commits += 1 + + async def get(self, _model, inbox_message_id: int): # noqa: ANN001 + for msg in self.messages: + if msg.inbox_message_id == inbox_message_id: + return msg + return None + + async def scalar(self, _query): # noqa: ANN001 + return None + + async def scalars(self, _query): # noqa: ANN001 + controls = { + AgentInboxKind.CONTROL_PAUSE, + AgentInboxKind.CONTROL_RESUME, + AgentInboxKind.CONTROL_APPROVE, + AgentInboxKind.CONTROL_DENY, + AgentInboxKind.CONTROL_SKIP, + AgentInboxKind.CONTROL_CANCEL, + } + return [ + msg + for msg in self.messages + if msg.role == AgentInboxRole.USER + and msg.kind in controls + and msg.status is None + ] + + +@pytest.mark.asyncio +async def test_create_ask_then_record_reply() -> None: + session = InboxSession() + repo = AgentInboxMessageRepository(session) # type: ignore[arg-type] + + ask = await repo.create_ask( + job_id=1, + payload={"prompt": "choose", "schema": {"choice": ["A", "B"]}}, + ) + await session.commit() + assert ask.status == AgentInboxStatus.WAITING + assert ask.role == AgentInboxRole.AGENT + assert ask.kind == AgentInboxKind.ASK + + reply = await repo.record_user_message( + job_id=1, + kind=AgentInboxKind.REPLY, + payload={"value": "A"}, + reply_to_id=int(ask.inbox_message_id), + ) + await session.commit() + assert reply.role == AgentInboxRole.USER + assert reply.reply_to_id == ask.inbox_message_id + + answered = await repo.mark_answered( + inbox_message_id=int(ask.inbox_message_id), + answered_at=datetime.now(UTC), + ) + await session.commit() + assert answered.status == AgentInboxStatus.ANSWERED + assert answered.answered_at is not None + + +@pytest.mark.asyncio +async def test_pending_controls_excludes_consumed() -> None: + session = InboxSession() + repo = AgentInboxMessageRepository(session) # type: ignore[arg-type] + pause = await repo.record_user_message( + job_id=1, + kind=AgentInboxKind.CONTROL_PAUSE, + payload={}, + ) + await session.commit() + + pending = await repo.list_pending_controls(job_id=1) + assert [msg.inbox_message_id for msg in pending] == [pause.inbox_message_id] + + await repo.mark_dropped(inbox_message_id=int(pause.inbox_message_id)) + await session.commit() + pending_after = await repo.list_pending_controls(job_id=1) + assert pending_after == [] diff --git a/app/tests/test_agent_inbox_schema.py b/app/tests/test_agent_inbox_schema.py new file mode 100644 index 0000000..f6f7cd2 --- /dev/null +++ b/app/tests/test_agent_inbox_schema.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from fileflash.schemas.agent import AgentInboxMessageRequest + + +def test_reply_with_value_validates() -> None: + msg = AgentInboxMessageRequest.model_validate( + {"kind": "reply", "replyTo": "42", "value": "yes"} + ) + + assert msg.kind == "reply" + assert msg.reply_to == "42" + assert msg.value == "yes" + + +def test_unknown_kind_rejected() -> None: + with pytest.raises(ValidationError): + AgentInboxMessageRequest.model_validate({"kind": "control.explode"}) diff --git a/app/tests/test_agent_plan_execute_runtime.py b/app/tests/test_agent_plan_execute_runtime.py index 46fd8ef..256bdeb 100644 --- a/app/tests/test_agent_plan_execute_runtime.py +++ b/app/tests/test_agent_plan_execute_runtime.py @@ -1,20 +1,25 @@ from __future__ import annotations +import asyncio +from contextlib import asynccontextmanager from datetime import UTC, datetime from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock import pytest +from fileflash.agents.harness.event_bus import AgentEventEnvelope from fileflash.agents.harness.policy import PolicyGuard, classify_tool_risk from fileflash.agents.harness.router import ToolCall, ToolRouter from fileflash.agents.runtime import execute_runner as execute_module -from fileflash.agents.runtime.execute_runner import ExecuteRunner from fileflash.agents.runtime import plan_runner as plan_module +from fileflash.agents.runtime.execute_runner import AgentJobCanceled, ExecuteRunner from fileflash.agents.runtime.llm import AnthropicPlannerClient from fileflash.agents.runtime.plan_runner import PlanRunner from fileflash.core.errors import ApiError from fileflash.models import BackgroundJob +from fileflash.models.enums import AgentInboxKind from fileflash.repositories import ( AgentPlanRepository, AgentSettingsRepository, @@ -67,6 +72,7 @@ def settings(**overrides): "agent_user_concurrent_limit": 2, "agent_user_daily_limit": 50, "agent_llm_base_url": None, + "agent_llm_plan_max_tokens": 8192, } base.update(overrides) return SimpleNamespace(**base) @@ -154,7 +160,7 @@ async def create(self, **kwargs): # noqa: ANN003 ) assert fake_messages.kwargs["model"] == "claude-test" - assert fake_messages.kwargs["max_tokens"] == 4096 + assert fake_messages.kwargs["max_tokens"] == 8192 assert fake_messages.kwargs["system"] == "system" assert fake_messages.kwargs["messages"] == [{"role": "user", "content": "user"}] assert fake_messages.kwargs["thinking"] == {"type": "adaptive"} @@ -239,6 +245,272 @@ async def create(self, **kwargs): # noqa: ANN003 assert result["summary"] == "fallback" +@pytest.mark.asyncio +async def test_anthropic_planner_client_retries_with_json_only_prompt_on_invalid_json(): + class FakeMessages: + def __init__(self) -> None: + self.calls: list[dict[str, object]] = [] + + async def create(self, **kwargs): # noqa: ANN003 + self.calls.append(dict(kwargs)) + if len(self.calls) == 1: + return SimpleNamespace( + content=[SimpleNamespace(type="text", text="I think we should move files.")], + usage={}, + ) + if len(self.calls) == 2: + return SimpleNamespace( + content=[SimpleNamespace(type="text", text="summary: move files")], + usage={}, + ) + return SimpleNamespace( + content=[SimpleNamespace(type="text", text='{"summary":"strict","proposedActions":[]}')], + usage={}, + ) + + fake_messages = FakeMessages() + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=fake_messages), # type: ignore[arg-type] + ) + + result = await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1024, + reasoning_effort="high", + ) + + assert len(fake_messages.calls) == 3 + assert "thinking" in fake_messages.calls[0] + assert "output_config" in fake_messages.calls[0] + assert "thinking" not in fake_messages.calls[1] + assert "output_config" not in fake_messages.calls[1] + assert "thinking" not in fake_messages.calls[2] + assert "output_config" not in fake_messages.calls[2] + assert fake_messages.calls[1]["max_tokens"] == 1024 + assert fake_messages.calls[2]["max_tokens"] == 8192 + third_messages = fake_messages.calls[2]["messages"] + assert isinstance(third_messages, list) + assert "Return ONLY one valid JSON object" in third_messages[0]["content"] + assert result["summary"] == "strict" + + +@pytest.mark.asyncio +async def test_anthropic_planner_client_parses_json_from_wrapped_text(): + class FakeMessages: + async def create(self, **kwargs): # noqa: ANN003 + return SimpleNamespace( + content=[ + SimpleNamespace( + type="text", + text='Here is the result:\n{"summary":"ok","proposedActions":[]}\nThanks!', + ) + ], + usage={}, + ) + + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=FakeMessages()), # type: ignore[arg-type] + ) + + result = await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1000, + ) + + assert result["summary"] == "ok" + assert result["proposedActions"] == [] + + +@pytest.mark.asyncio +async def test_anthropic_planner_client_raises_after_three_invalid_json_responses(): + class FakeMessages: + def __init__(self) -> None: + self.calls: list[dict[str, object]] = [] + + async def create(self, **kwargs): # noqa: ANN003 + self.calls.append(dict(kwargs)) + return SimpleNamespace( + content=[SimpleNamespace(type="text", text="not valid json")], + usage={}, + ) + + fake_messages = FakeMessages() + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=fake_messages), # type: ignore[arg-type] + ) + + with pytest.raises(ApiError) as exc: + await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1024, + reasoning_effort="high", + ) + + assert exc.value.status_code == 502 + assert exc.value.message == "Agent LLM did not return valid JSON" + assert len(fake_messages.calls) == 3 + + +@pytest.mark.asyncio +async def test_anthropic_planner_client_uses_tools_and_parses_tool_use_blocks(): + class FakeMessages: + def __init__(self) -> None: + self.kwargs = {} + + async def create(self, **kwargs): # noqa: ANN003 + self.kwargs = kwargs + return SimpleNamespace( + content=[ + SimpleNamespace(type="text", text="Count matching videos."), + SimpleNamespace( + type="tool_use", + id="toolu_1", + name="drive_count_files", + input={"folderId": "root", "category": "video"}, + ), + ], + usage={"input_tokens": 5, "output_tokens": 6}, + ) + + fake_messages = FakeMessages() + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=fake_messages), # type: ignore[arg-type] + ) + + result = await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1024, + tools=[ + { + "name": "drive_count_files", + "description": "Count files.", + "input_schema": {"type": "object"}, + "internalName": "drive.countFiles", + } + ], + ) + + assert fake_messages.kwargs["tool_choice"] == {"type": "auto"} + assert fake_messages.kwargs["tools"] == [ + { + "name": "drive_count_files", + "description": "Count files.", + "input_schema": {"type": "object"}, + } + ] + assert result["summary"] == "Count matching videos." + assert result["proposedActions"] == [ + { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "category": "video"}, + } + ] + assert result["_usage"] == {"input_tokens": 5, "output_tokens": 6} + + +@pytest.mark.asyncio +async def test_anthropic_planner_client_executes_tool_loop_before_final_plan(): + class FakeMessages: + def __init__(self) -> None: + self.calls: list[dict[str, Any]] = [] + + async def create(self, **kwargs): # noqa: ANN003 + self.calls.append(dict(kwargs)) + if len(self.calls) == 1: + return SimpleNamespace( + content=[ + SimpleNamespace(type="text", text="Need to inspect candidates first."), + SimpleNamespace( + type="tool_use", + id="toolu_1", + name="drive_search_files", + input={"folderId": "root", "query": "银翼杀手", "category": "video"}, + ), + ], + usage={"input_tokens": 10, "output_tokens": 5}, + ) + return SimpleNamespace( + content=[ + SimpleNamespace( + type="text", + text=( + '{"summary":"move matched file","proposedActions":[' + '{"step":1,"tool":"drive.moveFile","input":{"fileId":"11","targetFolderId":"21"}}' + "]}" + ), + ) + ], + usage={"input_tokens": 8, "output_tokens": 9}, + ) + + fake_messages = FakeMessages() + tool_executor = AsyncMock(return_value={"items": [{"id": "11", "name": "银翼杀手2049.mp4"}], "totalItems": 1}) + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=fake_messages), # type: ignore[arg-type] + ) + + result = await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1024, + tools=[ + { + "name": "drive_search_files", + "description": "Search files.", + "input_schema": {"type": "object"}, + "internalName": "drive.searchFiles", + }, + { + "name": "drive_move_file", + "description": "Move file.", + "input_schema": {"type": "object"}, + "internalName": "drive.moveFile", + }, + ], + tool_executor=tool_executor, + max_tool_roundtrips=4, + ) + + tool_executor.assert_awaited_once_with( + "drive.searchFiles", + {"folderId": "root", "query": "银翼杀手", "category": "video"}, + ) + assert len(fake_messages.calls) == 2 + second_messages = fake_messages.calls[1]["messages"] + assert isinstance(second_messages, list) + tool_result_blocks = second_messages[-1]["content"] + assert isinstance(tool_result_blocks, list) + assert tool_result_blocks[0]["type"] == "tool_result" + assert result["summary"] == "move matched file" + assert result["proposedActions"][0]["tool"] == "drive.moveFile" + assert result["_usage"] == {"input_tokens": 18, "output_tokens": 14} + + def test_anthropic_planner_client_uses_configured_base_url(monkeypatch: pytest.MonkeyPatch): captured: dict[str, object] = {} @@ -571,7 +843,7 @@ async def test_plan_runner_rolls_back_when_commit_fails(monkeypatch: pytest.Monk @pytest.mark.asyncio -async def test_plan_runner_uses_safe_read_only_fallback_when_planner_returns_invalid_output( +async def test_plan_runner_propagates_llm_output_errors_without_fallback( monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) @@ -618,17 +890,13 @@ async def test_plan_runner_uses_safe_read_only_fallback_when_planner_returns_inv ) db = DummyDb() - result = await runner.run(db=db, job=job) # type: ignore[arg-type] - - assert "fallback mode" in result.summary.lower() - assert len(result.proposed_actions) == 1 - assert result.proposed_actions[0].tool == "drive.listFolder" - assert result.proposed_actions[0].side_effect == "read" - assert result.requires_confirmation is False + with pytest.raises(ApiError) as exc: + await runner.run(db=db, job=job) # type: ignore[arg-type] + assert exc.value.status_code == 502 @pytest.mark.asyncio -async def test_plan_runner_fallback_uses_count_files_for_movie_count_question( +async def test_plan_runner_uses_planner_returned_count_action_for_movie_question( monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) @@ -640,11 +908,16 @@ async def test_plan_runner_fallback_uses_count_files_for_movie_count_question( monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) planner = AsyncMock( - side_effect=ApiError( - status_code=502, - code=502, - message="Agent LLM returned an empty response", - ) + return_value={ + "summary": "count videos", + "proposedActions": [ + { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + } + ], + } ) runner = PlanRunner( settings=settings(), @@ -682,171 +955,1321 @@ async def test_plan_runner_fallback_uses_count_files_for_movie_count_question( assert result.proposed_actions[0].side_effect == "read" -def test_normalize_actions_rejects_symbolic_placeholder_target_folder(): - with pytest.raises(ApiError) as exc: - plan_module._normalize_actions( - llm_payload={ - "summary": "organize movies", - "proposedActions": [ - { - "step": 1, - "tool": "drive.createFolder", - "input": {"parentFolderId": "root", "name": "Movies"}, - }, - { - "step": 2, - "tool": "drive.moveFile", - "input": {"fileId": "13", "targetFolderId": "newFolderId"}, - }, - ], - }, - allowed_tools=("drive.createFolder", "drive.moveFile"), - max_steps=10, - ) - - assert exc.value.status_code == 400 - assert "step 2" in exc.value.message - assert "targetFolderId" in exc.value.message - assert "newFolderId" in exc.value.message - +@pytest.mark.asyncio +async def test_plan_runner_uses_planner_returned_count_action_for_anime_question( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) -def test_normalize_actions_accepts_previous_step_reference(): - actions = plan_module._normalize_actions( - llm_payload={ - "summary": "organize movies", + planner = AsyncMock( + return_value={ + "summary": "count anime videos", "proposedActions": [ { "step": 1, - "tool": "drive.createFolder", - "input": {"parentFolderId": "root", "name": "Movies"}, - }, - { - "step": 2, - "tool": "drive.moveFile", - "input": {"fileId": "13", "targetFolderId": "$step1.folderId"}, - }, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + } ], - }, - allowed_tools=("drive.createFolder", "drive.moveFile"), - max_steps=10, + } ) - - assert len(actions) == 2 - assert actions[1].input["targetFolderId"] == "$step1.folderId" - - -def test_normalize_actions_rejects_future_step_reference(): - with pytest.raises(ApiError) as exc: - plan_module._normalize_actions( - llm_payload={ - "summary": "organize movies", - "proposedActions": [ - { - "step": 3, - "tool": "drive.moveFile", - "input": {"fileId": "13", "targetFolderId": "$step4.folderId"}, - }, - { - "step": 4, - "tool": "drive.createFolder", - "input": {"parentFolderId": "root", "name": "Movies"}, + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "我上传了多少动漫?", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=336, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert len(result.proposed_actions) == 1 + action = result.proposed_actions[0] + assert action.tool == "drive.countFiles" + assert action.input["category"] == "video" + assert action.input.get("search") in (None, "") + + +@pytest.mark.asyncio +async def test_plan_runner_delegates_count_question_with_search_term_to_planner( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + + planner = AsyncMock( + return_value={ + "summary": "count matching movies", + "proposedActions": [ + { + "step": 1, + "tool": "drive.countFiles", + "input": { + "folderId": "root", + "recursive": True, + "category": "video", + "search": "银翼杀手", }, - ], + } + ], + } + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "我上传了几部银翼杀手?", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", }, - allowed_tools=("drive.createFolder", "drive.moveFile"), - max_steps=10, - ) + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=335, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + planner.assert_awaited_once() + assert len(result.proposed_actions) == 1 + action = result.proposed_actions[0] + assert action.tool == "drive.countFiles" + assert action.input["category"] == "video" + assert action.input["search"] == "银翼杀手" + assert action.side_effect == "read" + + +@pytest.mark.asyncio +async def test_plan_runner_rejects_write_tool_in_exploratory_loop( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + + async def fake_create_plan(**kwargs): # noqa: ANN003 + tool_executor = kwargs["tool_executor"] + await tool_executor("drive.moveFile", {"fileId": "1", "targetFolderId": "2"}) + return {"summary": "should not reach", "proposedActions": []} + + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=fake_create_plan), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "整理文件", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=340, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + with pytest.raises(ApiError) as exc: + await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] assert exc.value.status_code == 400 - assert "future step 4" in exc.value.message - assert "$step4.folderId" in exc.value.message + assert "read-only" in exc.value.message + +@pytest.mark.asyncio +async def test_plan_runner_uses_planner_returned_move_action_when_unique_matches( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + planner = AsyncMock( + return_value={ + "summary": "move one matching file", + "proposedActions": [ + { + "step": 1, + "tool": "drive.moveFile", + "input": {"fileId": "11", "targetFolderId": "21", "shareHandling": "keep"}, + } + ], + } + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "把银翼杀手电影放到银翼杀手文件夹下", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=337, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert len(result.proposed_actions) == 1 + action = result.proposed_actions[0] + assert action.tool == "drive.moveFile" + assert action.input["fileId"] == "11" + assert action.input["targetFolderId"] == "21" + assert action.side_effect == "write" + + +@pytest.mark.asyncio +async def test_plan_runner_uses_planner_returned_create_then_move_when_target_missing( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + planner = AsyncMock( + return_value={ + "summary": "create missing folder then move", + "proposedActions": [ + { + "step": 1, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "银翼杀手"}, + }, + { + "step": 2, + "tool": "drive.moveFile", + "input": {"fileId": "11", "targetFolderId": "$step1.folderId"}, + }, + ], + } + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "把银翼杀手电影放到银翼杀手文件夹下", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=338, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert len(result.proposed_actions) == 2 + assert result.proposed_actions[0].tool == "drive.createFolder" + assert result.proposed_actions[0].input["name"] == "银翼杀手" + assert result.proposed_actions[1].tool == "drive.moveFile" + assert result.proposed_actions[1].input["fileId"] == "11" + assert result.proposed_actions[1].input["targetFolderId"] == "$step1.folderId" + + +@pytest.mark.asyncio +async def test_plan_runner_rewrites_write_summary_with_grounded_facts( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + planner = AsyncMock( + return_value={ + "summary": "创建银翼杀手文件夹,然后将 V字仇杀队 文件夹中的2部银翼杀手电影移入该文件夹。", + "proposedActions": [ + { + "step": 1, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "银翼杀手"}, + }, + { + "step": 2, + "tool": "drive.moveFile", + "input": {"fileId": "19", "targetFolderId": "$step1.folderId"}, + }, + { + "step": 3, + "tool": "drive.moveFile", + "input": {"fileId": "20", "targetFolderId": "$step1.folderId"}, + }, + ], + } + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "把银翼杀手两部,移到银翼杀手文件夹里", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=348, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + db = DummyDb() + + class _Rows: + def __init__(self, rows: list[tuple[Any, Any]]) -> None: + self._rows = rows + + def all(self) -> list[tuple[Any, Any]]: + return list(self._rows) + + db.execute = AsyncMock(return_value=_Rows([(19, "银翼杀手1982.mp4"), (20, "银翼杀手2049.mp4")])) + + result = await runner.run(db=db, job=job) # type: ignore[arg-type] + + assert "V字仇杀队" not in result.summary + assert "创建“银翼杀手”文件夹" in result.summary + assert "2 个文件" in result.summary + + +@pytest.mark.asyncio +async def test_plan_runner_records_planning_evidence_from_read_tools( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + + class _FakeToolRouter: + async def dispatch(self, _call): # noqa: ANN001 + return { + "items": [ + {"fileId": "19", "name": "银翼杀手1982.mp4", "path": "/My Files/videos/银翼杀手1982.mp4"}, + {"fileId": "20", "name": "银翼杀手2049.mp4", "path": "/My Files/videos/银翼杀手2049.mp4"}, + {"fileId": "21", "name": "Blade Runner Trailer.mp4", "path": "/My Files/videos/Blade Runner Trailer.mp4"}, + {"fileId": "22", "name": "x1.mp4", "path": "/My Files/videos/x1.mp4"}, + {"fileId": "23", "name": "x2.mp4", "path": "/My Files/videos/x2.mp4"}, + {"fileId": "24", "name": "x3.mp4", "path": "/My Files/videos/x3.mp4"}, + {"fileId": "25", "name": "x4.mp4", "path": "/My Files/videos/x4.mp4"}, + ], + "totalItems": 7, + "query": "银翼杀手", + "folderId": "1", + "recursive": True, + "category": "video", + } + + monkeypatch.setattr(plan_module, "ToolRouter", lambda **kwargs: _FakeToolRouter()) + + async def _planner_with_read_tool(**kwargs): # noqa: ANN003 + tool_executor = kwargs["tool_executor"] + await tool_executor( + "drive.searchFiles", + {"folderId": "root", "query": "银翼杀手", "category": "video"}, + ) + return { + "summary": "search first", + "proposedActions": [ + { + "step": 1, + "tool": "drive.searchFiles", + "input": {"folderId": "root", "query": "银翼杀手", "category": "video"}, + } + ], + } + + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=_planner_with_read_tool), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "找出银翼杀手视频文件", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=349, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert result.planning_evidence is not None + assert len(result.planning_evidence) == 1 + evidence = result.planning_evidence[0] + assert evidence.step == 1 + assert evidence.tool == "drive.searchFiles" + assert evidence.input["query"] == "银翼杀手" + assert evidence.output_preview["totalItems"] == 7 + assert isinstance(evidence.output_preview.get("items"), list) + assert "..." in str(evidence.output_preview["items"][-1]) + + +@pytest.mark.asyncio +async def test_plan_runner_uses_planner_returned_read_only_candidates_when_ambiguous( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + planner = AsyncMock( + return_value={ + "summary": "ambiguous, return candidates", + "proposedActions": [ + { + "step": 1, + "tool": "drive.searchFiles", + "input": {"folderId": "root", "query": "银翼杀手", "category": "video"}, + } + ], + } + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "把银翼杀手电影放到银翼杀手文件夹下", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=339, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert len(result.proposed_actions) == 1 + action = result.proposed_actions[0] + assert action.tool == "drive.searchFiles" + assert action.side_effect == "read" + assert "ambiguous" in result.summary + + +def test_normalize_actions_rejects_symbolic_placeholder_target_folder(): + with pytest.raises(ApiError) as exc: + plan_module._normalize_actions( + llm_payload={ + "summary": "organize movies", + "proposedActions": [ + { + "step": 1, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "Movies"}, + }, + { + "step": 2, + "tool": "drive.moveFile", + "input": {"fileId": "13", "targetFolderId": "newFolderId"}, + }, + ], + }, + allowed_tools=("drive.createFolder", "drive.moveFile"), + max_steps=10, + ) + + assert exc.value.status_code == 400 + assert "step 2" in exc.value.message + assert "targetFolderId" in exc.value.message + assert "newFolderId" in exc.value.message + + +def test_normalize_actions_accepts_previous_step_reference(): + actions = plan_module._normalize_actions( + llm_payload={ + "summary": "organize movies", + "proposedActions": [ + { + "step": 1, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "Movies"}, + }, + { + "step": 2, + "tool": "drive.moveFile", + "input": {"fileId": "13", "targetFolderId": "$step1.folderId"}, + }, + ], + }, + allowed_tools=("drive.createFolder", "drive.moveFile"), + max_steps=10, + ) + + assert len(actions) == 2 + assert actions[1].input["targetFolderId"] == "$step1.folderId" + + +def test_normalize_actions_rejects_future_step_reference(): + with pytest.raises(ApiError) as exc: + plan_module._normalize_actions( + llm_payload={ + "summary": "organize movies", + "proposedActions": [ + { + "step": 3, + "tool": "drive.moveFile", + "input": {"fileId": "13", "targetFolderId": "$step4.folderId"}, + }, + { + "step": 4, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "Movies"}, + }, + ], + }, + allowed_tools=("drive.createFolder", "drive.moveFile"), + max_steps=10, + ) + + assert exc.value.status_code == 400 + assert "future step 4" in exc.value.message + assert "$step4.folderId" in exc.value.message + + +def test_execute_reference_resolution_rejects_symbolic_placeholder(): + with pytest.raises(ApiError) as exc: + execute_module._resolve_references( + {"targetFolderId": "newFolderId"}, + step_outputs={}, + ) + + assert exc.value.status_code == 409 + assert "targetFolderId" in exc.value.message + assert "$stepN.field" in exc.value.message + + +@pytest.mark.asyncio +async def test_policy_guard_blocks_delete_without_confirmation(): + decision = await PolicyGuard().evaluate_tool_call( + tool_name="drive.deleteFile", + high_risk_confirmed=False, + ) + assert decision.allowed is False + assert classify_tool_risk("drive.deleteFolder") == "high" + + +@pytest.mark.asyncio +async def test_tool_router_dispatches_move_file(): + router = ToolRouter(db=DummyDb(), user_id=7) # type: ignore[arg-type] + router.file_service.move_file = AsyncMock( + return_value=SimpleNamespace( + model_dump=lambda **kwargs: {"fileId": "1", "targetFolderId": "2"} + ) + ) + + result = await router.dispatch( + ToolCall( + tool_name="drive.moveFile", + arguments={"fileId": "1", "targetFolderId": "2"}, + ) + ) + + assert result == {"fileId": "1", "targetFolderId": "2"} + router.file_service.move_file.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_tool_router_count_files_counts_recursive_videos(): + db = DummyDb() + db.scalar = AsyncMock(return_value=1) + db.scalars = AsyncMock( + side_effect=[ + [1, 2], + [ + SimpleNamespace( + file_id=10, + file_name="movie.mp4", + file_size=100, + mime_type="application/octet-stream", + file_ext="mp4", + folder_id=1, + created_at=None, + updated_at=None, + ), + SimpleNamespace( + file_id=11, + file_name="clip.mkv", + file_size=200, + mime_type="video/x-matroska", + file_ext="mkv", + folder_id=2, + created_at=None, + updated_at=None, + ), + SimpleNamespace( + file_id=12, + file_name="notes.txt", + file_size=10, + mime_type="text/plain", + file_ext="txt", + folder_id=1, + created_at=None, + updated_at=None, + ), + ], + ] + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + tool_name="drive.countFiles", + arguments={"folderId": "root", "recursive": True, "category": "video"}, + ) + ) + + assert result["totalItems"] == 2 + assert result["category"] == "video" + assert result["recursive"] is True + assert result["byMimeType"] == {"video/mp4": 1, "video/x-matroska": 1} + assert [item["name"] for item in result["sampleItems"]] == ["movie.mp4", "clip.mkv"] + executed_statement = str(db.scalars.await_args_list[-1].args[0]) + assert "file.status" in executed_statement + assert "file.is_latest" in executed_statement + + +@pytest.mark.asyncio +async def test_tool_router_count_files_filters_by_search_term(): + db = DummyDb() + db.scalar = AsyncMock(return_value=1) + db.scalars = AsyncMock( + side_effect=[ + [1], + [ + SimpleNamespace( + file_id=10, + file_name="银翼杀手.mp4", + file_size=100, + mime_type="video/mp4", + file_ext="mp4", + folder_id=1, + created_at=None, + updated_at=None, + ), + ], + ] + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + tool_name="drive.countFiles", + arguments={ + "folderId": "root", + "recursive": True, + "category": "video", + "search": "银翼杀手", + }, + ) + ) + + assert result["totalItems"] == 1 + assert result["search"] == "银翼杀手" + executed_statement = str(db.scalars.await_args_list[-1].args[0]) + assert "file_name" in executed_statement + + +@pytest.mark.asyncio +async def test_execute_runner_normalizes_tool_output_before_action_log(monkeypatch: pytest.MonkeyPatch): + started = datetime.now(UTC) + output_time = datetime.now(UTC) + job = BackgroundJob( + job_id=600, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "500", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.createFolder", + "input": {"parentFolderId": "root", "name": "Movies"}, + "sideEffect": "write", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + + mock_plan_repo = SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + proposed_actions_json=[action], + ) + ) + ) + monkeypatch.setattr(execute_module, "AgentPlanRepository", lambda _db: mock_plan_repo) + + mock_work_sessions = SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ) + monkeypatch.setattr(execute_module, "AgentWorkSessionRepository", lambda _db: mock_work_sessions) + + captured_outputs: list[dict[str, object]] = [] + mock_action_logs = SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock( + side_effect=lambda **kwargs: captured_outputs.append(dict(kwargs)) or None + ), + ) + monkeypatch.setattr(execute_module, "AgentActionLogRepository", lambda _db: mock_action_logs) + + mock_router = SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "id": "9", + "createdAt": output_time, + "updatedAt": output_time, + } + ) + ) + monkeypatch.setattr(execute_module, "ToolRouter", lambda **kwargs: mock_router) + + result = await ExecuteRunner( + answer_client=SimpleNamespace(create_answer=AsyncMock(return_value="ok")) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] + + assert result.applied_actions == 1 + assert captured_outputs + success_call = next(item for item in captured_outputs if item.get("status") == "succeeded") + outputs_json = success_call["outputs_json"] + assert isinstance(outputs_json, dict) + assert isinstance(outputs_json["createdAt"], str) + assert outputs_json["createdAt"] == output_time.isoformat() + + +@pytest.mark.asyncio +async def test_execute_runner_propagates_answer_model_errors(monkeypatch: pytest.MonkeyPatch): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=610, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "510", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock(return_value=SimpleNamespace(proposed_actions_json=[action])) + ), + ) + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "totalItems": 3, + "category": "video", + "recursive": True, + "folderId": "1", + "byMimeType": {"video/mp4": 3}, + } + ) + ), + ) + + with pytest.raises(ApiError) as exc: + await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock( + side_effect=ApiError(status_code=503, code=503, message="Agent LLM API key is not configured") + ) + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] + + assert exc.value.status_code == 503 + + +@pytest.mark.asyncio +async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.MonkeyPatch): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=601, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "501", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + + mock_plan_repo = SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + proposed_actions_json=[action], + ) + ) + ) + monkeypatch.setattr(execute_module, "AgentPlanRepository", lambda _db: mock_plan_repo) + + mock_work_sessions = SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ) + monkeypatch.setattr(execute_module, "AgentWorkSessionRepository", lambda _db: mock_work_sessions) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "totalItems": 3, + "category": "video", + "recursive": True, + "folderId": "1", + "byMimeType": {"video/mp4": 3}, + "sampleItems": [], + } + ) + ), + ) + + result = await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock(return_value="你上传了 3 部电影(按视频文件统计)。") + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] + + assert result.answer == "你上传了 3 部电影(按视频文件统计)。" + assert result.applied_actions == 1 + + +@pytest.mark.asyncio +async def test_execute_runner_returns_count_files_answer_with_search_term( + monkeypatch: pytest.MonkeyPatch, +): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=602, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "502", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.countFiles", + "input": { + "folderId": "root", + "recursive": True, + "category": "video", + "search": "银翼杀手", + }, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace(proposed_actions_json=[action]) + ) + ), + ) + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "totalItems": 2, + "category": "video", + "recursive": True, + "folderId": "1", + "search": "银翼杀手", + "byMimeType": {"video/mp4": 2}, + "sampleItems": [], + } + ) + ), + ) + + result = await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock(return_value="你上传了 2 部名称包含“银翼杀手”的电影(按视频文件统计)。") + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] + + assert result.answer == "你上传了 2 部名称包含“银翼杀手”的电影(按视频文件统计)。" + assert "只读操作" not in (result.answer or "") + + +@pytest.mark.asyncio +async def test_execute_runner_returns_count_files_answer_with_names_when_asked( + monkeypatch: pytest.MonkeyPatch, +): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=604, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "504", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "archive"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + input_text="我上传了多少压缩包,叫什么名字", + proposed_actions_json=[action], + ) + ) + ), + ) + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "totalItems": 2, + "category": "archive", + "recursive": True, + "folderId": "1", + "itemNames": ["photos.zip", "backup.7z"], + "itemNamesTruncated": False, + "byMimeType": {"application/zip": 1, "application/x-7z-compressed": 1}, + "sampleItems": [], + } + ) + ), + ) -def test_execute_reference_resolution_rejects_symbolic_placeholder(): - with pytest.raises(ApiError) as exc: - execute_module._resolve_references( - {"targetFolderId": "newFolderId"}, - step_outputs={}, - ) + result = await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock(return_value="你上传了 2 个压缩包,名字是:photos.zip、backup.7z。") + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] - assert exc.value.status_code == 409 - assert "targetFolderId" in exc.value.message - assert "$stepN.field" in exc.value.message + assert result.answer == "你上传了 2 个压缩包,名字是:photos.zip、backup.7z。" @pytest.mark.asyncio -async def test_policy_guard_blocks_delete_without_confirmation(): - decision = await PolicyGuard().evaluate_tool_call( - tool_name="drive.deleteFile", - high_risk_confirmed=False, +async def test_execute_runner_lists_archive_names_for_read_only_archive_question( + monkeypatch: pytest.MonkeyPatch, +): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=603, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "503", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, ) - assert decision.allowed is False - assert classify_tool_risk("drive.deleteFolder") == "high" - + action = { + "step": 1, + "tool": "drive.listFolder", + "input": {"folderId": "root"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() -@pytest.mark.asyncio -async def test_tool_router_dispatches_move_file(): - router = ToolRouter(db=DummyDb(), user_id=7) # type: ignore[arg-type] - router.file_service.move_file = AsyncMock( - return_value=SimpleNamespace( - model_dump=lambda **kwargs: {"fileId": "1", "targetFolderId": "2"} - ) + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + input_text="我上传了哪些压缩包", + proposed_actions_json=[action], + ) + ) + ), ) - - result = await router.dispatch( - ToolCall( - tool_name="drive.moveFile", - arguments={"fileId": "1", "targetFolderId": "2"}, - ) + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), ) - - assert result == {"fileId": "1", "targetFolderId": "2"} - router.file_service.move_file.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_tool_router_count_files_counts_recursive_videos(): - db = DummyDb() - db.scalar = AsyncMock(return_value=1) - db.scalars = AsyncMock(return_value=[1, 2]) - db.execute = AsyncMock( - return_value=SimpleNamespace( - all=lambda: [ - (10, "movie.mp4", 100, "application/octet-stream", "mp4", 1), - (11, "clip.mkv", 200, "video/x-matroska", "mkv", 2), - (12, "notes.txt", 10, "text/plain", "txt", 1), - ] - ) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), ) - router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] - - result = await router.dispatch( - ToolCall( - tool_name="drive.countFiles", - arguments={"folderId": "root", "recursive": True, "category": "video"}, - ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "items": [ + { + "itemType": "file", + "id": "1", + "name": "photos.zip", + "size": 10, + "mimeType": "application/zip", + }, + { + "itemType": "file", + "id": "2", + "name": "movie.mp4", + "size": 20, + "mimeType": "video/mp4", + }, + { + "itemType": "file", + "id": "3", + "name": "backup.7z", + "size": 30, + "mimeType": "application/octet-stream", + }, + ], + "pagination": { + "totalItems": 3, + "totalPages": 1, + "perPage": 200, + "currentPage": 1, + "hasPrev": False, + "hasNext": False, + }, + } + ) + ), ) - assert result["totalItems"] == 2 - assert result["category"] == "video" - assert result["recursive"] is True - assert result["byMimeType"] == {"video/mp4": 1, "video/x-matroska": 1} - assert [item["name"] for item in result["sampleItems"]] == ["movie.mp4", "clip.mkv"] - executed_statement = str(db.execute.await_args.args[0]) - assert "file.status" in executed_statement - assert "file.is_latest" in executed_statement + result = await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock(return_value="当前文件夹中的压缩包有 2 个:photos.zip、backup.7z。") + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] + + assert result.answer == "当前文件夹中的压缩包有 2 个:photos.zip、backup.7z。" @pytest.mark.asyncio -async def test_execute_runner_normalizes_tool_output_before_action_log(monkeypatch: pytest.MonkeyPatch): +async def test_execute_runner_returns_search_files_candidate_answer( + monkeypatch: pytest.MonkeyPatch, +): started = datetime.now(UTC) - output_time = datetime.now(UTC) job = BackgroundJob( - job_id=600, + job_id=605, task_type="agent.execute", status="running", payload={ - "planJobId": "500", + "planJobId": "505", "planHash": "sha256:test", "approval": { "confirmedBy": "7", @@ -862,70 +2285,102 @@ async def test_execute_runner_normalizes_tool_output_before_action_log(monkeypat ) action = { "step": 1, - "tool": "drive.createFolder", - "input": {"parentFolderId": "root", "name": "Movies"}, - "sideEffect": "write", + "tool": "drive.searchFiles", + "input": {"folderId": "root", "query": "银翼杀手", "category": "video"}, + "sideEffect": "read", "riskLevel": "low", "requiresConfirmation": False, } db = DummyDb() db.refresh = AsyncMock() - mock_plan_repo = SimpleNamespace( - get_for_execute_binding=AsyncMock( - return_value=SimpleNamespace( - proposed_actions_json=[action], + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + input_text="把银翼杀手电影放到银翼杀手文件夹下", + proposed_actions_json=[action], + ) ) - ) + ), ) - monkeypatch.setattr(execute_module, "AgentPlanRepository", lambda _db: mock_plan_repo) - - mock_work_sessions = SimpleNamespace( - create_for_job=AsyncMock(return_value=None), - close_session=AsyncMock(return_value=None), + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), ) - monkeypatch.setattr(execute_module, "AgentWorkSessionRepository", lambda _db: mock_work_sessions) - - captured_outputs: list[dict[str, object]] = [] - mock_action_logs = SimpleNamespace( - append_step=AsyncMock(return_value=None), - finish_step=AsyncMock( - side_effect=lambda **kwargs: captured_outputs.append(dict(kwargs)) or None + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), ), ) - monkeypatch.setattr(execute_module, "AgentActionLogRepository", lambda _db: mock_action_logs) - - mock_router = SimpleNamespace( - dispatch=AsyncMock( - return_value={ - "id": "9", - "createdAt": output_time, - "updatedAt": output_time, - } - ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "items": [ + {"id": "11", "name": "银翼杀手1982.mp4", "mimeType": "video/mp4"}, + {"id": "12", "name": "银翼杀手2049.mp4", "mimeType": "video/mp4"}, + ], + "totalItems": 2, + "query": "银翼杀手", + "folderId": "1", + "recursive": False, + "category": "video", + } + ) + ), ) - monkeypatch.setattr(execute_module, "ToolRouter", lambda **kwargs: mock_router) - result = await ExecuteRunner().run(db=db, job=job) # type: ignore[arg-type] + result = await ExecuteRunner( + answer_client=SimpleNamespace( + create_answer=AsyncMock( + return_value="找到 2 个名称包含“银翼杀手”的文件:银翼杀手1982.mp4、银翼杀手2049.mp4。" + ) + ) # type: ignore[arg-type] + ).run(db=db, job=job) # type: ignore[arg-type] - assert result.applied_actions == 1 - assert captured_outputs - success_call = next(item for item in captured_outputs if item.get("status") == "succeeded") - outputs_json = success_call["outputs_json"] - assert isinstance(outputs_json, dict) - assert isinstance(outputs_json["createdAt"], str) - assert outputs_json["createdAt"] == output_time.isoformat() + assert result.answer == "找到 2 个名称包含“银翼杀手”的文件:银翼杀手1982.mp4、银翼杀手2049.mp4。" -@pytest.mark.asyncio -async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.MonkeyPatch): +class _NeverStream: + async def next(self, *, timeout=None): # noqa: ANN001 + await asyncio.Future() + + async def aclose(self) -> None: + return None + + +class _CaptureBus: + def __init__(self) -> None: + self.events: list[AgentEventEnvelope] = [] + + async def publish(self, envelope: AgentEventEnvelope) -> None: + self.events.append(envelope) + + @asynccontextmanager + async def subscribe(self, *, job_id: int): # noqa: ARG002 + yield _NeverStream() + + +def _execute_job_for_controls() -> BackgroundJob: started = datetime.now(UTC) - job = BackgroundJob( - job_id=601, + return BackgroundJob( + job_id=700, task_type="agent.execute", status="running", payload={ - "planJobId": "501", + "planJobId": "500", "planHash": "sha256:test", "approval": { "confirmedBy": "7", @@ -939,6 +2394,14 @@ async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.Mon created_at=started, updated_at=started, ) + + +def _patch_execute_dependencies( + monkeypatch: pytest.MonkeyPatch, + *, + controls: list[list[SimpleNamespace]], + dropped: list[int], +) -> None: action = { "step": 1, "tool": "drive.countFiles", @@ -947,23 +2410,24 @@ async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.Mon "riskLevel": "low", "requiresConfirmation": False, } - db = DummyDb() - db.refresh = AsyncMock() - mock_plan_repo = SimpleNamespace( - get_for_execute_binding=AsyncMock( - return_value=SimpleNamespace( - proposed_actions_json=[action], + monkeypatch.setattr( + execute_module, + "AgentPlanRepository", + lambda _db: SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace(proposed_actions_json=[action]) ) - ) + ), ) - monkeypatch.setattr(execute_module, "AgentPlanRepository", lambda _db: mock_plan_repo) - - mock_work_sessions = SimpleNamespace( - create_for_job=AsyncMock(return_value=None), - close_session=AsyncMock(return_value=None), + monkeypatch.setattr( + execute_module, + "AgentWorkSessionRepository", + lambda _db: SimpleNamespace( + create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ), ) - monkeypatch.setattr(execute_module, "AgentWorkSessionRepository", lambda _db: mock_work_sessions) monkeypatch.setattr( execute_module, "AgentActionLogRepository", @@ -978,18 +2442,103 @@ async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.Mon lambda **kwargs: SimpleNamespace( dispatch=AsyncMock( return_value={ - "totalItems": 3, + "totalItems": 1, "category": "video", "recursive": True, "folderId": "1", - "byMimeType": {"video/mp4": 3}, + "byMimeType": {"video/mp4": 1}, "sampleItems": [], } ) ), ) - result = await ExecuteRunner().run(db=db, job=job) # type: ignore[arg-type] + class FakeInboxRepository: + def __init__(self, _db) -> None: # noqa: ANN001 + return None + + async def list_pending_controls(self, *, job_id: int): # noqa: ARG002 + if controls: + return controls.pop(0) + return [] + + async def mark_dropped(self, *, inbox_message_id: int) -> None: + dropped.append(inbox_message_id) + + monkeypatch.setattr(execute_module, "AgentInboxMessageRepository", FakeInboxRepository) + + +@pytest.mark.asyncio +async def test_execute_runner_pauses_then_resumes_at_step_boundary( + monkeypatch: pytest.MonkeyPatch, +): + controls = [ + [SimpleNamespace(inbox_message_id=1, kind=AgentInboxKind.CONTROL_PAUSE)], + [SimpleNamespace(inbox_message_id=2, kind=AgentInboxKind.CONTROL_RESUME)], + ] + dropped: list[int] = [] + _patch_execute_dependencies(monkeypatch, controls=controls, dropped=dropped) + db = DummyDb() + db.refresh = AsyncMock() + bus = _CaptureBus() + + result = await ExecuteRunner( + event_bus=bus, + answer_client=SimpleNamespace(create_answer=AsyncMock(return_value="ok")), # type: ignore[arg-type] + ).run(db=db, job=_execute_job_for_controls()) # type: ignore[arg-type] - assert result.answer == "你上传了 3 部电影(按视频文件统计)。" assert result.applied_actions == 1 + assert dropped == [1, 2] + event_types = [event.event_type for event in bus.events] + assert "agent.paused" in event_types + assert "agent.resumed" in event_types + assert "tool.started" in event_types + assert "tool.succeeded" in event_types + + +@pytest.mark.asyncio +async def test_execute_runner_canceled_via_inbox_at_step_boundary( + monkeypatch: pytest.MonkeyPatch, +): + controls = [[SimpleNamespace(inbox_message_id=1, kind=AgentInboxKind.CONTROL_CANCEL)]] + dropped: list[int] = [] + _patch_execute_dependencies(monkeypatch, controls=controls, dropped=dropped) + db = DummyDb() + db.refresh = AsyncMock() + + with pytest.raises(AgentJobCanceled): + await ExecuteRunner( + event_bus=_CaptureBus(), + answer_client=SimpleNamespace(create_answer=AsyncMock(return_value="ok")), # type: ignore[arg-type] + ).run( # type: ignore[arg-type] + db=db, + job=_execute_job_for_controls(), + ) + + assert dropped == [1] + + +@pytest.mark.asyncio +async def test_execute_runner_publish_state_ignores_event_bus_failures(): + bus = SimpleNamespace(publish=AsyncMock(side_effect=RuntimeError("boom"))) + runner = ExecuteRunner(event_bus=bus) # type: ignore[arg-type] + + await runner._publish_state("agent.paused", job_id=1) + + assert bus.publish.await_count == 1 + + +@pytest.mark.asyncio +async def test_execute_runner_publish_tool_ignores_event_bus_failures(): + bus = SimpleNamespace(publish=AsyncMock(side_effect=RuntimeError("boom"))) + runner = ExecuteRunner(event_bus=bus) # type: ignore[arg-type] + + await runner._publish_tool( + "tool.started", + job_id=1, + step=1, + tool="drive.listFolder", + payload={"input": {"folderId": "root"}}, + ) + + assert bus.publish.await_count == 1 diff --git a/app/tests/test_agent_routes.py b/app/tests/test_agent_routes.py index 6b50f90..68d225e 100644 --- a/app/tests/test_agent_routes.py +++ b/app/tests/test_agent_routes.py @@ -1,14 +1,22 @@ from __future__ import annotations +from contextlib import asynccontextmanager from datetime import UTC, datetime from fastapi import FastAPI from fastapi.testclient import TestClient -from fileflash.core.deps import get_agent_execute_service, get_agent_plan_service, get_current_user +from fileflash.agents.harness.event_bus import AgentEventEnvelope, InMemoryAgentEventBus +from fileflash.core.deps import ( + get_agent_event_bus, + get_agent_execute_service, + get_agent_plan_service, + get_current_user, +) from fileflash.core.errors import ApiError, api_error_handler from fileflash.db.deps import get_db -from fileflash.models import BackgroundJob +from fileflash.models import AgentActionLog, AgentInboxMessage, BackgroundJob +from fileflash.models.enums import AgentInboxRole from fileflash.models.tables_identity import User from fileflash.routers.agent import router from fileflash.schemas.agent import ExecuteAgentResponse, PlanAgentResponse @@ -38,16 +46,32 @@ def __init__(self) -> None: created_at=now, updated_at=now, ) + self.messages: list[AgentInboxMessage] = [] + self._next_inbox_id = 1 async def scalar(self, _query): # noqa: ANN001 return self.job + async def scalars(self, _query): # noqa: ANN001 + return [] + + def add(self, msg: AgentInboxMessage) -> None: + msg.inbox_message_id = self._next_inbox_id + self._next_inbox_id += 1 + self.messages.append(msg) + + async def flush(self) -> None: + return None + async def commit(self) -> None: return None async def refresh(self, _job: BackgroundJob) -> None: return None + async def get(self, _model, _id: int): # noqa: ANN001 + return None + class RunningJobDb(StubDb): def __init__(self) -> None: @@ -55,6 +79,40 @@ def __init__(self) -> None: self.job.status = "running" +class EventsDb(StubDb): + def __init__(self) -> None: + super().__init__() + now = datetime.now(UTC) + self.job.status = "succeeded" + self.job.result = { + "planJobId": "10", + "executeJobId": "12", + "summary": "done", + "answer": "你上传了 2 部名称包含“银翼杀手”的电影(按视频文件统计)。", + "appliedActions": 1, + "skippedActions": 0, + "warnings": [], + "finishedAt": now.isoformat(), + } + self.job.finished_at = now + self.job.updated_at = now + self.action_log = AgentActionLog( + action_log_id=1, + job_id=12, + step_no=1, + tool_name="drive.countFiles", + inputs_json={"folderId": "root", "category": "video", "search": "银翼杀手"}, + outputs_json={"totalItems": 2, "category": "video", "search": "银翼杀手"}, + status="succeeded", + duration_ms=12, + started_at=now, + finished_at=now, + ) + + async def scalars(self, _query): # noqa: ANN001 + return [self.action_log] + + def _user() -> User: return User(user_id=7, username="u7", email="u7@example.com", password_hash="x") @@ -81,6 +139,17 @@ def _client_with_running_job() -> TestClient: return TestClient(app) +def _client_with_events() -> TestClient: + app = FastAPI() + app.include_router(router, prefix="/api/v1") + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_current_user] = _user + app.dependency_overrides[get_agent_plan_service] = lambda: StubPlanService() + app.dependency_overrides[get_agent_execute_service] = lambda: StubExecuteService() + app.dependency_overrides[get_db] = lambda: EventsDb() + return TestClient(app) + + def test_plan_route_returns_response_shell(): response = _client().post( "/api/v1/agent/plan", @@ -131,22 +200,85 @@ def test_execute_route_returns_response_shell(): assert body["data"]["taskType"] == "agent.execute" -def test_cancel_route_returns_response_shell(): - response = _client().post("/api/v1/agent/cancel/12") +def test_post_message_control_pause_returns_response_shell(): + bus = InMemoryAgentEventBus() + db = StubDb() + app = FastAPI() + app.include_router(router, prefix="/api/v1") + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_current_user] = _user + app.dependency_overrides[get_db] = lambda: db + app.dependency_overrides[get_agent_event_bus] = lambda: bus + client = TestClient(app) + + response = client.post("/api/v1/agent/jobs/12/messages", json={"kind": "control.pause"}) assert response.status_code == 200 body = response.json() assert body["success"] is True - assert body["data"]["jobId"] == "12" - assert body["data"]["status"] == "canceled" - assert body["data"]["canceledAt"] + assert body["data"]["kind"] == "control.pause" + assert body["data"]["inboxMessageId"] == "1" + assert db.messages[0].role == AgentInboxRole.USER -def test_cancel_route_marks_running_job_as_canceled(): - response = _client_with_running_job().post("/api/v1/agent/cancel/12") +def test_job_events_route_streams_tool_and_final_answer_events(): + response = _client_with_events().get("/api/v1/agent/jobs/12/events") assert response.status_code == 200 - body = response.json() - assert body["success"] is True - assert body["data"]["jobId"] == "12" - assert body["data"]["status"] == "canceled" + assert response.headers["content-type"].startswith("text/event-stream") + body = response.text + assert "event: tool.started" in body + assert "event: tool.succeeded" in body + assert "event: job.succeeded" in body + assert "正在读取名称包含" in body + assert "银翼杀手" in body + assert "answer" in body + + +def test_job_events_route_streams_event_bus_events_after_initial_replay(): + now = datetime.now(UTC) + events = [ + AgentEventEnvelope( + job_id=12, + event_type="agent.progress", + payload={"step": 1, "total": 3, "message": "halfway"}, + emitted_at=now, + ), + AgentEventEnvelope( + job_id=12, + event_type="job.succeeded", + payload={"status": "succeeded"}, + emitted_at=now, + ), + ] + + class StaticStream: + async def next(self, *, timeout=None): # noqa: ANN001 + if not events: + raise TimeoutError + return events.pop(0) + + async def aclose(self) -> None: + return None + + class StaticBus: + async def publish(self, envelope): # noqa: ANN001 + return None + + @asynccontextmanager + async def subscribe(self, *, job_id: int): # noqa: ARG002 + yield StaticStream() + + app = FastAPI() + app.include_router(router, prefix="/api/v1") + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_current_user] = _user + app.dependency_overrides[get_db] = lambda: RunningJobDb() + app.dependency_overrides[get_agent_event_bus] = lambda: StaticBus() + client = TestClient(app) + + response = client.get("/api/v1/agent/jobs/12/events") + + assert response.status_code == 200 + assert "event: agent.progress" in response.text + assert "event: job.succeeded" in response.text diff --git a/app/tests/test_agent_skill_service.py b/app/tests/test_agent_skill_service.py index 79d8ffe..bc4d13f 100644 --- a/app/tests/test_agent_skill_service.py +++ b/app/tests/test_agent_skill_service.py @@ -76,6 +76,29 @@ async def test_update_custom_skill_requires_owner_private(): session.commit.assert_not_awaited() +@pytest.mark.asyncio +async def test_create_custom_skill_rejects_unknown_tool(): + session = DummySession() + session.scalar.return_value = None + + repo = AgentSkillRepository(session) + service = SkillService(db=session, skills=repo) + + with pytest.raises(ApiError) as exc: + await service.create_custom_skill( + user_id=7, + payload=CreateAgentSkillRequest( + name="Unsafe", + description="bad tool", + tool_whitelist=["drive.listFolder", "files.list"], + ), + ) + + assert exc.value.status_code == 422 + assert exc.value.data == {"unknownTools": ["files.list"]} + session.commit.assert_not_awaited() + + @pytest.mark.asyncio async def test_delete_custom_skill_requires_owner_private(): session = DummySession() @@ -91,6 +114,33 @@ async def test_delete_custom_skill_requires_owner_private(): session.commit.assert_not_awaited() +@pytest.mark.asyncio +async def test_import_rejects_unknown_tool(): + session = DummySession() + session.scalars.return_value = [] + + repo = AgentSkillRepository(session) + service = SkillService(db=session, skills=repo) + + payload = ImportAgentSkillsRequest( + items=[ + ImportAgentSkillItem( + skill_key="builtin:bad", + name="bad", + description="bad", + tool_whitelist=["drive.missing"], + ) + ], + ) + + with pytest.raises(ApiError) as exc: + await service.import_global_skills(payload=payload) + + assert exc.value.status_code == 422 + assert exc.value.data == {"unknownTools": ["drive.missing"]} + session.commit.assert_not_awaited() + + @pytest.mark.asyncio async def test_import_insert_only_conflict_raises_409(): session = DummySession() diff --git a/app/tests/test_agent_tools.py b/app/tests/test_agent_tools.py new file mode 100644 index 0000000..36287ec --- /dev/null +++ b/app/tests/test_agent_tools.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from fileflash.agents.harness.router import ToolCall, ToolRouter +from fileflash.agents.harness.tool_registry import REGISTRY, ToolRegistry, ToolSpec +from fileflash.models.enums import FileStatus + + +async def _noop_handler(_ctx, _args): # noqa: ANN001 + return {"ok": True} + + +def test_tool_registry_registers_and_maps_provider_names(): + registry = ToolRegistry() + registry.register( + ToolSpec( + name="drive.testTool", + description="test", + input_schema={"type": "object"}, + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_noop_handler, + ) + ) + + assert registry.all_names() == ("drive.testTool",) + assert registry.get("drive.testTool").anthropic_name == "drive_test_tool" + assert registry.get_by_provider_name("drive_test_tool").name == "drive.testTool" + assert registry.anthropic_tools_for(["drive.testTool"])[0]["internalName"] == "drive.testTool" + + +def test_tool_registry_rejects_duplicate_names(): + registry = ToolRegistry() + spec = ToolSpec( + name="drive.testTool", + description="test", + input_schema={"type": "object"}, + side_effect="read", + risk_level="low", + requires_confirmation=False, + handler=_noop_handler, + ) + registry.register(spec) + + with pytest.raises(ValueError): + registry.register(spec) + + +def test_builtin_registry_contains_new_query_tools(): + names = set(REGISTRY.all_names()) + + assert { + "drive.searchFiles", + "drive.getFileInfo", + "drive.listRecent", + "drive.statsByCategory", + "drive.findDuplicates", + }.issubset(names) + assert REGISTRY.get("drive.deleteFile").risk_level == "high" + + +class DummyDb: + def __init__(self) -> None: + self.scalar = AsyncMock(return_value=1) + self.scalars = AsyncMock(return_value=[]) + self.execute = AsyncMock() + self.get = AsyncMock() + + +@pytest.mark.asyncio +async def test_tool_router_dispatches_new_search_files_tool(): + db = DummyDb() + db.scalar = AsyncMock( + side_effect=[ + 1, + SimpleNamespace(folder_name="My Files", parent_folder_id=None), + ] + ) + db.scalars = AsyncMock( + side_effect=[ + [1], + [ + SimpleNamespace( + file_id=10, + file_name="movie.mp4", + file_size=100, + mime_type="video/mp4", + file_ext="mp4", + folder_id=1, + storage_object_id=20, + status=FileStatus.ACTIVE, + is_latest=True, + created_at=None, + updated_at=None, + ) + ], + ] + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + tool_name="drive.searchFiles", + arguments={"folderId": "root", "query": "movie", "category": "video"}, + ) + ) + + assert result["totalItems"] == 1 + assert result["items"][0]["name"] == "movie.mp4" + + +@pytest.mark.asyncio +async def test_tool_router_dispatches_stats_by_category_tool(): + db = DummyDb() + db.scalars = AsyncMock( + side_effect=[ + [1], + [ + SimpleNamespace( + file_id=10, + file_name="movie.mp4", + file_size=100, + mime_type="video/mp4", + file_ext="mp4", + folder_id=1, + storage_object_id=20, + created_at=None, + updated_at=None, + ), + SimpleNamespace( + file_id=11, + file_name="notes.txt", + file_size=10, + mime_type="text/plain", + file_ext="txt", + folder_id=1, + storage_object_id=21, + created_at=None, + updated_at=None, + ), + ], + ] + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall(tool_name="drive.statsByCategory", arguments={"folderId": "root"}) + ) + + assert result["video"] == 1 + assert result["document"] == 1 + assert result["totalSize"] == 110 + + +@pytest.mark.asyncio +async def test_tool_router_count_files_accepts_anime_alias_and_returns_item_names(): + db = DummyDb() + db.scalar = AsyncMock(return_value=1) + db.scalars = AsyncMock( + side_effect=[ + [1], + [ + SimpleNamespace( + file_id=10, + file_name="银河动漫番剧.mp4", + file_size=100, + mime_type="video/mp4", + file_ext="mp4", + folder_id=1, + storage_object_id=20, + created_at=None, + updated_at=None, + ), + SimpleNamespace( + file_id=11, + file_name="notes.txt", + file_size=10, + mime_type="text/plain", + file_ext="txt", + folder_id=1, + storage_object_id=21, + created_at=None, + updated_at=None, + ), + ], + ] + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + tool_name="drive.countFiles", + arguments={"folderId": "root", "recursive": True, "category": "anime"}, + ) + ) + + assert result["category"] == "video" + assert result["totalItems"] == 1 + assert result["itemNames"] == ["银河动漫番剧.mp4"] + assert result["itemNamesTruncated"] is False + + +@pytest.mark.asyncio +async def test_tool_router_count_files_truncates_item_names_at_limit(): + db = DummyDb() + db.scalar = AsyncMock(return_value=1) + rows = [ + SimpleNamespace( + file_id=100 + index, + file_name=f"video-{index:02d}.mp4", + file_size=100, + mime_type="video/mp4", + file_ext="mp4", + folder_id=1, + storage_object_id=200 + index, + created_at=None, + updated_at=None, + ) + for index in range(13) + ] + db.scalars = AsyncMock(side_effect=[[1], rows]) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + tool_name="drive.countFiles", + arguments={"folderId": "root", "recursive": True, "category": "video"}, + ) + ) + + assert result["totalItems"] == 13 + assert len(result["itemNames"]) == 12 + assert result["itemNames"][0] == "video-00.mp4" + assert result["itemNames"][-1] == "video-11.mp4" + assert result["itemNamesTruncated"] is True diff --git a/app/tests/test_agent_worker.py b/app/tests/test_agent_worker.py index 65b15b0..1275bfe 100644 --- a/app/tests/test_agent_worker.py +++ b/app/tests/test_agent_worker.py @@ -6,6 +6,7 @@ import pytest +from fileflash.agents.harness.event_bus import AgentEventEnvelope from fileflash.agents.worker import AgentWorkerConsumer from fileflash.models import BackgroundJob @@ -32,6 +33,19 @@ async def scalar(self, _query): # noqa: ANN001 return self._job +class CaptureBus: + def __init__(self) -> None: + self.events: list[AgentEventEnvelope] = [] + + async def publish(self, envelope: AgentEventEnvelope) -> None: + self.events.append(envelope) + + +class FailingBus: + async def publish(self, envelope: AgentEventEnvelope) -> None: # noqa: ARG002 + raise RuntimeError("publish unavailable") + + def _job(*, status: str, cancel_requested_at: datetime | None) -> BackgroundJob: now = datetime.now(UTC) return BackgroundJob( @@ -82,3 +96,45 @@ async def test_mark_failed_does_not_override_job_with_cancel_request(monkeypatch assert job.status == "running" assert job.cancel_requested_at == canceled_at assert job.error_message is None + + +@pytest.mark.asyncio +async def test_mark_succeeded_publishes_terminal_event(monkeypatch: pytest.MonkeyPatch): + job = _job(status="running", cancel_requested_at=None) + session = DummySession(job) + bus = CaptureBus() + consumer = AgentWorkerConsumer( + queue=SimpleNamespace(), + session_factory=lambda: _AsyncContextManager(session), # type: ignore[arg-type] + event_bus=bus, + ) + monkeypatch.setattr("fileflash.agents.worker.apply_local_lock_timeout", AsyncMock(return_value=None)) + + finished_at = datetime.now(UTC).replace(microsecond=0) + await consumer._mark_succeeded( + job_id=65, + result={"summary": "ok", "finishedAt": finished_at}, + phase="completed", + ) + + assert [event.event_type for event in bus.events] == ["job.succeeded"] + assert bus.events[0].payload["status"] == "succeeded" + assert bus.events[0].payload["data"]["result"]["finishedAt"] == finished_at.isoformat() + assert job.result["finishedAt"] == finished_at.isoformat() + + +@pytest.mark.asyncio +async def test_mark_succeeded_ignores_publish_failures(monkeypatch: pytest.MonkeyPatch): + job = _job(status="running", cancel_requested_at=None) + session = DummySession(job) + consumer = AgentWorkerConsumer( + queue=SimpleNamespace(), + session_factory=lambda: _AsyncContextManager(session), # type: ignore[arg-type] + event_bus=FailingBus(), + ) + monkeypatch.setattr("fileflash.agents.worker.apply_local_lock_timeout", AsyncMock(return_value=None)) + + await consumer._mark_succeeded(job_id=65, result={"summary": "ok"}, phase="completed") + + assert job.status == "succeeded" + assert job.result["summary"] == "ok" diff --git a/app/tests/test_file_download_recycle_service.py b/app/tests/test_file_download_recycle_service.py index 3d1caba..e0b6031 100644 --- a/app/tests/test_file_download_recycle_service.py +++ b/app/tests/test_file_download_recycle_service.py @@ -1,15 +1,29 @@ from __future__ import annotations +import tempfile from datetime import UTC, datetime +from types import SimpleNamespace from unittest.mock import AsyncMock import pytest - -from fileflash.core.errors import ApiError -from fileflash.models.enums import FileStatus, FolderStatus, FolderType, UploadStatus +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from fileflash.core.deps import get_current_user, get_download_rate_limit_service, get_file_service +from fileflash.core.errors import ApiError, api_error_handler +from fileflash.models.enums import ( + FileStatus, + FolderStatus, + FolderType, + UploadStatus, + UserRole, + UserStatus, +) +from fileflash.models.tables_identity import User from fileflash.models.tables_storage import File, FileMediaMetadata, Folder, StorageObject -from fileflash.schemas.file import BatchFilesRequest -from fileflash.services.file import FileService +from fileflash.routers.files import router as files_router +from fileflash.schemas.file import BatchDownloadRequest, BatchFilesRequest +from fileflash.services.file import BatchDownloadPlan, DownloadStreamResult, FileService class DummyStorage: @@ -41,6 +55,14 @@ def __init__(self) -> None: self.delete = AsyncMock() +class ResultRows: + def __init__(self, rows) -> None: # noqa: ANN001 + self._rows = rows + + def all(self): # noqa: ANN201 + return self._rows + + def make_file_row(*, file_id: int = 1, file_name: str = "demo.txt", folder_id: int = 10) -> File: return File( file_id=file_id, @@ -73,6 +95,22 @@ def make_folder_row(*, folder_id: int = 10, folder_name: str = "Docs") -> Folder ) +def make_user(*, role: UserRole = UserRole.USER) -> User: + return User( + user_id=1, + username="alice", + email="alice@example.com", + password_hash="x", + role=role, + status=UserStatus.ACTIVE, + email_verified=True, + storage_limit=1024, + storage_used=0, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + @pytest.mark.asyncio async def test_get_download_stream_supports_single_range(monkeypatch: pytest.MonkeyPatch): session = DummySession() @@ -375,6 +413,145 @@ async def test_get_preview_stream_falls_back_to_source_when_transcoded_missing(m assert result.headers["Content-Length"] == "512" +@pytest.mark.asyncio +async def test_batch_download_plan_estimates_source_file_size() -> None: + session = DummySession() + storage = DummyStorage() + service = FileService(db=session, storage=storage) + file_row = make_file_row(file_id=7, file_name="archive.bin") + file_row.storage_object_id = 99 + storage_object = StorageObject( + object_id=99, + bucket_name="fileflash", + object_key="objects/u1/archive", + object_size=512, + upload_status=UploadStatus.ACTIVE, + content_type="application/octet-stream", + ) + session.scalars = AsyncMock(return_value=[file_row]) + session.execute = AsyncMock(return_value=ResultRows([(file_row, storage_object)])) + + plan = await service.create_batch_download_plan( + user_id=1, + payload=BatchDownloadRequest(fileIds=["7"]), + ) + + assert plan.estimated_bytes == 512 + assert plan.files[0][2] == "archive.bin" + + +class StubDownloadLimiter: + def __init__(self, *, deny: bool = False) -> None: + self.deny = deny + self.calls: list[tuple[str, int]] = [] + + async def enforce_user(self, *, user: User, bytes_count: int) -> None: + self.calls.append((f"user:{user.user_id}", bytes_count)) + if self.deny and user.role != UserRole.ADMIN: + raise ApiError(status_code=429, code=429, message="Download rate limit exceeded") + + async def enforce_user_id(self, *, user_id: int, bytes_count: int) -> None: + self.calls.append((f"user:{user_id}", bytes_count)) + if self.deny: + raise ApiError(status_code=429, code=429, message="Download rate limit exceeded") + + +class StubFileRouteService: + async def get_download_stream( + self, + *, + user_id: int, # noqa: ARG002 + file_id: str, # noqa: ARG002 + range_header: str | None, + ) -> DownloadStreamResult: + async def _stream(): + yield b"0123456789" + + headers = {"Content-Length": "4" if range_header else "10", "Accept-Ranges": "bytes"} + if range_header: + headers["Content-Range"] = "bytes 0-3/10" + return DownloadStreamResult( + stream=_stream(), + filename="demo.txt", + content_type="text/plain", + status_code=206 if range_header else 200, + headers=headers, + ) + + async def get_preview_stream(self, **kwargs) -> DownloadStreamResult: # noqa: ANN003 + return await self.get_download_stream(**kwargs) + + async def create_batch_download_plan( + self, + *, + user_id: int, # noqa: ARG002 + payload: BatchDownloadRequest, # noqa: ARG002 + ) -> BatchDownloadPlan: + return SimpleNamespace(estimated_bytes=10, files=[object()]) # type: ignore[return-value] + + async def create_batch_download_archive_from_plan(self, *, plan: BatchDownloadPlan): # noqa: ANN201, ARG002 + tmp = tempfile.NamedTemporaryFile(prefix="fileflash-test-", suffix=".zip", delete=False) + tmp.write(b"zip") + tmp.close() + return tmp.name, "test.zip" + + +def _files_client(*, role: UserRole, limiter: StubDownloadLimiter) -> TestClient: + app = FastAPI() + app.add_exception_handler(ApiError, api_error_handler) + app.include_router(files_router, prefix="/api/v1") + app.dependency_overrides[get_current_user] = lambda: make_user(role=role) + app.dependency_overrides[get_file_service] = lambda: StubFileRouteService() + app.dependency_overrides[get_download_rate_limit_service] = lambda: limiter + return TestClient(app) + + +def test_download_route_returns_429_when_limiter_rejects_user() -> None: + limiter = StubDownloadLimiter(deny=True) + with _files_client(role=UserRole.USER, limiter=limiter) as client: + response = client.get("/api/v1/files/1/download") + + assert response.status_code == 429 + assert limiter.calls == [("user:1", 10)] + + +def test_download_route_preserves_range_response_when_allowed() -> None: + limiter = StubDownloadLimiter() + with _files_client(role=UserRole.USER, limiter=limiter) as client: + response = client.get("/api/v1/files/1/download", headers={"Range": "bytes=0-3"}) + + assert response.status_code == 206 + assert response.headers["content-range"] == "bytes 0-3/10" + assert limiter.calls == [("user:1", 4)] + + +def test_admin_download_route_is_not_rejected_by_user_limiter() -> None: + limiter = StubDownloadLimiter(deny=True) + with _files_client(role=UserRole.ADMIN, limiter=limiter) as client: + response = client.get("/api/v1/files/1/download") + + assert response.status_code == 200 + assert limiter.calls == [("user:1", 10)] + + +def test_batch_download_route_returns_429_before_archive_when_limited() -> None: + limiter = StubDownloadLimiter(deny=True) + with _files_client(role=UserRole.USER, limiter=limiter) as client: + response = client.post("/api/v1/files/batch-download", json={"fileIds": ["1"]}) + + assert response.status_code == 429 + assert limiter.calls == [("user:1", 10)] + + +def test_admin_batch_download_route_is_not_rejected_by_user_limiter() -> None: + limiter = StubDownloadLimiter(deny=True) + with _files_client(role=UserRole.ADMIN, limiter=limiter) as client: + response = client.post("/api/v1/files/batch-download", json={"fileIds": ["1"]}) + + assert response.status_code == 200 + assert response.content == b"zip" + + @pytest.mark.asyncio async def test_delete_file_marks_record_deleted(monkeypatch: pytest.MonkeyPatch): session = DummySession() diff --git a/app/tests/test_rate_limiter.py b/app/tests/test_rate_limiter.py new file mode 100644 index 0000000..386c160 --- /dev/null +++ b/app/tests/test_rate_limiter.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest +from redis.exceptions import RedisError + +from fileflash.core.settings import Settings +from fileflash.models.enums import UserRole, UserStatus +from fileflash.models.tables_identity import User +from fileflash.services.download_rate_limit import DownloadRateLimitService +from fileflash.services.rate_limiter import RedisRateLimiter + + +class FakeRedis: + def __init__(self, *, fail: bool = False) -> None: + self.fail = fail + self.values: dict[str, int] = {} + self.expired: list[tuple[str, int]] = [] + + async def incrby(self, key: str, amount: int) -> int: + if self.fail: + raise RedisError("down") + self.values[key] = self.values.get(key, 0) + amount + return self.values[key] + + async def expire(self, key: str, window_seconds: int) -> None: + self.expired.append((key, window_seconds)) + + +@pytest.mark.asyncio +async def test_allow_weighted_uses_incrby_and_sets_ttl() -> None: + limiter = RedisRateLimiter("redis://example") + fake = FakeRedis() + limiter._redis = fake # type: ignore[assignment] + + allowed = await limiter.allow_weighted("k", limit=10, window_seconds=60, weight=4) + + assert allowed is True + assert fake.values["k"] == 4 + assert fake.expired == [("k", 60)] + + +@pytest.mark.asyncio +async def test_allow_weighted_rejects_over_limit() -> None: + limiter = RedisRateLimiter("redis://example") + fake = FakeRedis() + limiter._redis = fake # type: ignore[assignment] + + assert await limiter.allow_weighted("k", limit=5, window_seconds=60, weight=4) is True + assert await limiter.allow_weighted("k", limit=5, window_seconds=60, weight=2) is False + + +@pytest.mark.asyncio +async def test_allow_weighted_degrades_open_when_redis_fails() -> None: + limiter = RedisRateLimiter("redis://example") + limiter._redis = FakeRedis(fail=True) # type: ignore[assignment] + + assert await limiter.allow_weighted("k", limit=1, window_seconds=60, weight=10) is True + + +class FakeRateLimiter: + def __init__(self) -> None: + self.allow = AsyncMock(return_value=False) + self.allow_weighted = AsyncMock(return_value=False) + + +@pytest.mark.asyncio +async def test_download_rate_limiter_skips_admin_user_id() -> None: + admin = User( + user_id=1, + username="admin", + email="admin@example.com", + password_hash="x", + role=UserRole.ADMIN, + status=UserStatus.ACTIVE, + email_verified=True, + storage_limit=1024, + storage_used=0, + ) + db = type("Db", (), {"get": AsyncMock(return_value=admin)})() + rate_limiter = FakeRateLimiter() + service = DownloadRateLimitService( + db=db, # type: ignore[arg-type] + settings=Settings(DATABASE_URL="sqlite+aiosqlite:///:memory:"), + rate_limiter=rate_limiter, # type: ignore[arg-type] + ) + + await service.enforce_user_id(user_id=1, bytes_count=100) + + rate_limiter.allow.assert_not_awaited() + rate_limiter.allow_weighted.assert_not_awaited() diff --git a/app/tests/test_settings.py b/app/tests/test_settings.py index 738f3ff..52356ea 100644 --- a/app/tests/test_settings.py +++ b/app/tests/test_settings.py @@ -32,9 +32,18 @@ def test_agent_related_settings_defaults(): assert settings.agent_queue_stream == "fileflash:agents" assert settings.agent_job_timeout_sec == 600 assert settings.agent_tool_timeout_sec == 30 + assert settings.agent_llm_plan_max_tokens == 8192 assert settings.agent_mcp_endpoints == () +def test_auth_risk_control_defaults(): + settings = make_settings() + assert settings.register_rate_limit == 12 + assert settings.login_rate_limit == 30 + assert settings.max_failed_login_attempts == 8 + assert settings.account_lock_minutes == 5 + + def test_app_env_detection(): dev = make_settings(APP_ENV="development") assert dev.is_development_env is True diff --git a/app/tests/test_share_routes.py b/app/tests/test_share_routes.py index 48cb008..e6d5b6f 100644 --- a/app/tests/test_share_routes.py +++ b/app/tests/test_share_routes.py @@ -5,7 +5,13 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from fileflash.core.deps import get_client_ip, get_share_service, get_user_agent +from fileflash.core.deps import ( + get_client_ip, + get_download_rate_limit_service, + get_share_service, + get_user_agent, +) +from fileflash.core.errors import ApiError, api_error_handler from fileflash.routers.shares import router as shares_router @@ -19,6 +25,7 @@ async def get_shared_file_download_stream_response( range_header: str | None, # noqa: ARG002 ip_address: str, # noqa: ARG002 user_agent: str | None, # noqa: ARG002 + rate_limit_check=None, # noqa: ANN001 ) -> tuple[AsyncIterator[bytes], str, str, int, dict[str, str]]: async def _stream() -> AsyncIterator[bytes]: yield b"data" @@ -32,15 +39,30 @@ async def _stream() -> AsyncIterator[bytes]: "Accept-Ranges": "bytes", "Content-Length": "4", } + if rate_limit_check is not None: + await rate_limit_check(4) return _stream(), "测试文档.pdf", "application/pdf", 200, headers -def _build_client() -> TestClient: +class StubDownloadLimiter: + def __init__(self, *, deny: bool = False) -> None: + self.deny = deny + self.calls: list[tuple[str, int]] = [] + + async def enforce_share_ip(self, *, client_ip: str, bytes_count: int) -> None: + self.calls.append((client_ip, bytes_count)) + if self.deny: + raise ApiError(status_code=429, code=429, message="Download rate limit exceeded") + + +def _build_client(limiter: StubDownloadLimiter | None = None) -> TestClient: app = FastAPI() + app.add_exception_handler(ApiError, api_error_handler) app.include_router(shares_router, prefix="/api/v1") app.dependency_overrides[get_share_service] = lambda: StubShareService() app.dependency_overrides[get_client_ip] = lambda: "127.0.0.1" app.dependency_overrides[get_user_agent] = lambda: "pytest" + app.dependency_overrides[get_download_rate_limit_service] = lambda: limiter or StubDownloadLimiter() return TestClient(app) @@ -71,3 +93,27 @@ def test_shared_preview_handles_unicode_filename_header() -> None: assert 'filename*=UTF-8\'\'' in header header.encode("latin-1") assert response.content == b"data" + + +def test_shared_download_returns_429_when_ip_limited() -> None: + limiter = StubDownloadLimiter(deny=True) + with _build_client(limiter) as client: + response = client.get( + "/api/v1/shares/ABCD/download", + headers={"Authorization": "Bearer test-share-token"}, + ) + + assert response.status_code == 429 + assert limiter.calls == [("127.0.0.1", 4)] + + +def test_shared_preview_returns_429_when_ip_limited() -> None: + limiter = StubDownloadLimiter(deny=True) + with _build_client(limiter) as client: + response = client.get( + "/api/v1/shares/ABCD/preview", + headers={"Authorization": "Bearer test-share-token"}, + ) + + assert response.status_code == 429 + assert limiter.calls == [("127.0.0.1", 4)] diff --git a/docker/flyway/migrations/V14__agent_inbox.sql b/docker/flyway/migrations/V14__agent_inbox.sql new file mode 100644 index 0000000..16ab9e0 --- /dev/null +++ b/docker/flyway/migrations/V14__agent_inbox.sql @@ -0,0 +1,49 @@ +-- ========================= +-- Domain: agent inbox +-- ========================= + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_role_enum') THEN + CREATE TYPE agent_inbox_role_enum AS ENUM ('agent', 'user'); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_kind_enum') THEN + CREATE TYPE agent_inbox_kind_enum AS ENUM ( + 'ask', + 'reply', + 'control.pause', + 'control.resume', + 'control.approve', + 'control.deny', + 'control.skip', + 'control.cancel' + ); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_status_enum') THEN + CREATE TYPE agent_inbox_status_enum AS ENUM ('waiting', 'answered', 'timed_out', 'dropped'); + END IF; +END +$$; + +CREATE TABLE IF NOT EXISTS agent_inbox_message ( + inbox_message_id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + job_id BIGINT NOT NULL, + role agent_inbox_role_enum NOT NULL, + kind agent_inbox_kind_enum NOT NULL, + payload_json JSONB NOT NULL DEFAULT '{}'::jsonb, + reply_to_id BIGINT NULL, + status agent_inbox_status_enum NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + answered_at TIMESTAMP NULL, + CONSTRAINT fk_agent_inbox_message_job + FOREIGN KEY (job_id) REFERENCES background_job(job_id) ON DELETE CASCADE, + CONSTRAINT fk_agent_inbox_message_reply_to + FOREIGN KEY (reply_to_id) REFERENCES agent_inbox_message(inbox_message_id) ON DELETE SET NULL +); + +CREATE INDEX IF NOT EXISTS idx_agent_inbox_message_job_created + ON agent_inbox_message (job_id, created_at); + +CREATE INDEX IF NOT EXISTS idx_agent_inbox_message_job_status + ON agent_inbox_message (job_id, status) + WHERE status IS NOT NULL; diff --git a/docker/flyway/migrations/V15__agent_tool_registry_skills.sql b/docker/flyway/migrations/V15__agent_tool_registry_skills.sql new file mode 100644 index 0000000..f1a9079 --- /dev/null +++ b/docker/flyway/migrations/V15__agent_tool_registry_skills.sql @@ -0,0 +1,19 @@ +UPDATE agent_skill +SET tool_whitelist_json = '[ + "drive.listFolder", + "drive.countFiles", + "drive.searchFiles", + "drive.getFileInfo", + "drive.listRecent", + "drive.statsByCategory", + "drive.findDuplicates", + "drive.createFolder", + "drive.moveFile", + "drive.moveFolder", + "drive.renameFile", + "drive.renameFolder", + "drive.deleteFile", + "drive.deleteFolder" +]'::jsonb, +updated_at = CURRENT_TIMESTAMP +WHERE skill_key = 'builtin:organizeByType'; diff --git a/docs/superpowers/plans/2026-05-24-admin-console-backend.md b/docs/superpowers/plans/2026-05-24-admin-console-backend.md index 5551456..ebdb421 100644 --- a/docs/superpowers/plans/2026-05-24-admin-console-backend.md +++ b/docs/superpowers/plans/2026-05-24-admin-console-backend.md @@ -89,6 +89,7 @@ app/src/fileflash/ ## Task 0: 基础设施(admin 包 + 状态映射) **Files:** + - Create: `app/src/fileflash/schemas/admin/__init__.py` - Create: `app/src/fileflash/services/admin/__init__.py` - Create: `app/src/fileflash/services/admin/_status.py` @@ -214,6 +215,7 @@ git commit -m "feat(admin): scaffold admin packages and user status mapping" ## Task 1: Admin Users — schemas + service **Files:** + - Create: `app/src/fileflash/schemas/admin/users.py` - Create: `app/src/fileflash/services/admin/users.py` - Test: `app/tests/test_admin_users_service.py` @@ -602,6 +604,7 @@ git commit -m "feat(admin): users service with list and set_status (last-admin g ## Task 2: Admin Users — router + deps + 注册 **Files:** + - Create: `app/src/fileflash/routers/admin_users.py` - Modify: `app/src/fileflash/core/deps.py` - Modify: `app/src/fileflash/routers/__init__.py` @@ -785,6 +788,7 @@ git commit -m "feat(admin): /admin/users list + /admin/users/{id}/status routes" ## Task 3: Admin Storage(summary / users / quota / usage-trend) **Files:** + - Create: `app/src/fileflash/schemas/admin/storage.py` - Create: `app/src/fileflash/services/admin/storage.py` - Create: `app/src/fileflash/routers/admin_storage.py` @@ -1370,6 +1374,7 @@ git commit -m "feat(admin): /admin/storage summary, users, quota, usage-trend" ## Task 4: Admin Files(list + rescan) **Files:** + - Create: `app/src/fileflash/schemas/admin/files.py` - Create: `app/src/fileflash/services/admin/files.py` - Create: `app/src/fileflash/routers/admin_files.py` @@ -1836,6 +1841,7 @@ git commit -m "feat(admin): /admin/files list + /admin/files/{id}/rescan with ev ## Task 5: Admin Moderation(violations list + resolve) **Files:** + - Create: `app/src/fileflash/schemas/admin/moderation.py` - Create: `app/src/fileflash/services/admin/moderation.py` - Create: `app/src/fileflash/routers/admin_moderation.py` @@ -2201,6 +2207,7 @@ git commit -m "feat(admin): /admin/violations list + resolve via ModerationCase" ## Task 6: Admin Logs **Files:** + - Create: `app/src/fileflash/schemas/admin/logs.py` - Create: `app/src/fileflash/services/admin/logs.py` - Create: `app/src/fileflash/routers/admin_logs.py` @@ -2421,6 +2428,7 @@ git commit -m "feat(admin): /admin/logs list with filters" ## Task 7: Admin Notifications(list / broadcast / read / archive) **Files:** + - Create: `app/src/fileflash/schemas/admin/notifications.py` - Create: `app/src/fileflash/services/admin/notifications.py` - Create: `app/src/fileflash/routers/admin_notifications.py` @@ -2873,6 +2881,7 @@ git commit -m "feat(admin): /admin/notifications list, broadcast, archive" ## Task 8: Admin System(health + rate-limit) **Files:** + - Create: `app/src/fileflash/schemas/admin/system.py` - Create: `app/src/fileflash/services/admin/system.py` - Create: `app/src/fileflash/routers/admin_system.py` diff --git a/docs/superpowers/plans/2026-05-24-admin-console-frontend.md b/docs/superpowers/plans/2026-05-24-admin-console-frontend.md index 99b597c..f236394 100644 --- a/docs/superpowers/plans/2026-05-24-admin-console-frontend.md +++ b/docs/superpowers/plans/2026-05-24-admin-console-frontend.md @@ -87,6 +87,7 @@ web/src/pages/dashboard/index.ts ## Task 0: api / types / mock 对齐 Plan A 契约 **Files:** + - Modify: `web/src/api/storage.ts`, `web/src/api/log.ts`, `web/src/api/notification.ts` - Modify: `web/src/types/log.d.ts`, `web/src/types/notification.d.ts` - Modify: `web/src/mock/handlers/log.ts`, `web/src/mock/handlers/notification.ts`, `web/src/mock/handlers/storage.ts` @@ -275,6 +276,7 @@ git commit -m "feat(web): align api+mock with Plan A admin contracts" ## Task 1: ConsoleLayout + Sidebar + 路由 **Files:** + - Create: `web/src/pages/console/ConsoleLayout.vue`, `ConsoleSidebar.vue`, `index.ts` - Modify: `web/src/router/routes.ts` @@ -472,6 +474,7 @@ git commit -m "feat(web): scaffold Console layout, sidebar, and 9 subpage routes ## Task 2: 共享组件(components/console/) **Files:** + - Create: `web/src/components/console/KpiCard.vue`, `StatusBadge.vue`, `FilterBar.vue`, `AdminTable.vue`, `TrendChart.vue`, `BroadcastComposer.vue`, `QuotaEditor.vue`, `index.ts` - [ ] **Step 2.1: KpiCard.vue** @@ -840,6 +843,7 @@ git commit -m "feat(web): add Console shared components (KpiCard, AdminTable, et ## Task 3: Overview 子页 **Files:** + - Replace placeholder: `web/src/pages/console/overview/OverviewPage.vue` - [ ] **Step 3.1: 实现 OverviewPage** @@ -1825,6 +1829,7 @@ git commit -m "feat(web): Console Rules page (registration email domains)" ## Task 12: 主框架集成、删除旧 Dashboard、i18n **Files:** + - Modify: `web/src/components/organisms/shell/UserMenu.vue` - Modify: `web/src/i18n/messages.ts` - Delete: `web/src/pages/dashboard/Dashboard.vue`, `web/src/pages/dashboard/index.ts` @@ -1856,9 +1861,10 @@ Edit `web/src/i18n/messages.ts`: | 'console.nav.rules' ``` -2. 中英 messages 表中,把原 `'header.menu.dashboard': '仪表盘'` 改为 `'header.menu.console': '控制台'`(en-US 改为 `'Console'`)。追加: +1. 中英 messages 表中,把原 `'header.menu.dashboard': '仪表盘'` 改为 `'header.menu.console': '控制台'`(en-US 改为 `'Console'`)。追加: 中文表: + ``` 'console.title': '控制台', 'console.nav.overview': '概览', @@ -1873,6 +1879,7 @@ Edit `web/src/i18n/messages.ts`: ``` 英文表: + ``` 'console.title': 'Console', 'console.nav.overview': 'Overview', @@ -1924,6 +1931,7 @@ git commit -m "feat(web): wire Console into MainLayout, drop legacy Dashboard, a ```bash cd web && bun run check ``` + Expected: 0 errors。 - [ ] **Step 13.2: 全量构建** @@ -1931,6 +1939,7 @@ Expected: 0 errors。 ```bash cd web && bun run build ``` + Expected: dist/ 生成、无错误。 - [ ] **Step 13.3: dev server 手动巡检** diff --git a/docs/superpowers/plans/2026-05-26-agent-A-interaction-backend.md b/docs/superpowers/plans/2026-05-26-agent-A-interaction-backend.md new file mode 100644 index 0000000..ab86d63 --- /dev/null +++ b/docs/superpowers/plans/2026-05-26-agent-A-interaction-backend.md @@ -0,0 +1,2284 @@ +# Agent 子项目 A(交互/反馈层)— 后端实现计划 + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** 把 agent 后端从单向 SSE + DB 轮询升级为 Redis pub/sub 推送 + POST inbox 双向通道,支持 `agent.ask` / `agent.progress` / `agent.thinking` / `tool.partial` 等新事件,以及 pause/resume/skip/approve/cancel 等控制信号在 step 边界生效。 + +**Architecture:** 新增 `AgentEventBus`(Redis pub/sub 封装)、`AgentInboxMessage` 表 + repository、`AgentInbox` 服务、`AskProtocol` 协议三大单元;SSE 端点从轮询 DB 改为订阅 Redis channel;ExecuteRunner / PlanRunner 在 step 边界检查 inbox。前端最小接入留给前端 plan。 + +**Tech Stack:** Python 3.12 + FastAPI + SQLAlchemy async + Redis pub/sub (`redis.asyncio`) + Flyway SQL 迁移 + pytest + 既有 stub 风格测试(不引入 fakeredis)。 + +**Spec:** `docs/superpowers/specs/2026-05-26-agent-improvements-design.md` 子项目 A 部分 + +--- + +## File Structure + +**新建(src)** + +- `app/src/fileflash/agents/harness/event_bus.py` — `AgentEventBus`(publish + subscribe)+ 内存 stub 用于测试 +- `app/src/fileflash/agents/harness/inbox.py` — `AgentInbox` 服务(写表 + publish) +- `app/src/fileflash/agents/harness/ask.py` — `AskProtocol` 协议(创建 ask 消息、阻塞等回答) +- `app/src/fileflash/repositories/agent/inbox.py` — `AgentInboxMessageRepository` + +**新建(迁移)** + +- `docker/flyway/migrations/V14__agent_inbox.sql` — `AgentInboxMessage` 表 + enum 类型 + +**修改** + +- `app/src/fileflash/models/enums.py` — 新增 `AgentInboxRole` / `AgentInboxKind` / `AgentInboxStatus` +- `app/src/fileflash/models/tables_agent.py` — 新增 `AgentInboxMessage` ORM model +- `app/src/fileflash/models/__init__.py` — 导出 `AgentInboxMessage` +- `app/src/fileflash/repositories/__init__.py` — 导出 `AgentInboxMessageRepository` +- `app/src/fileflash/schemas/agent.py` — 新事件类型字面量 + 上行 message 类型 +- `app/src/fileflash/routers/agent.py` — 新增 `POST /agent/jobs/{id}/messages`、改 SSE 实现、删除 `POST /agent/cancel/{job_id}` +- `app/src/fileflash/agents/runtime/execute_runner.py` — step 边界检查 inbox(pause/resume/cancel/skip/approve)+ publish 工具事件 +- `app/src/fileflash/agents/runtime/plan_runner.py` — 接入 ask 协议 +- `app/src/fileflash/agents/worker.py` — 创建 EventBus 单例并下发到 runner +- `app/src/fileflash/core/settings.py` — 增 `agent_inbox_ask_timeout_sec`(默认 1800)+ Redis pub/sub channel 配置项 +- `app/src/fileflash/core/deps.py` — 注入 `AgentEventBus` 依赖 + +**测试** + +- `app/tests/test_agent_event_bus.py` — 新 +- `app/tests/test_agent_inbox.py` — 新 +- `app/tests/test_agent_ask_protocol.py` — 新 +- `app/tests/test_agent_routes.py` — 扩展 +- `app/tests/test_agent_plan_execute_runtime.py` — 扩展 + +**前端** + +不在本 plan 范围。参见 `2026-05-26-agent-A-interaction-frontend.md`。本 plan 完成后,后端通过 curl/httpx 集成测试可以独立验证。 + +--- + +## Sequencing + +``` +Task 1 (settings) ──► Task 2 (enums) ──► Task 3 (SQL 迁移) ──► Task 4 (ORM model) + │ + ┌────────────────────────┘ + ▼ + Task 5 (repository) + │ + ┌──────────────────────────────────────┼─────────────────────────┐ + ▼ ▼ ▼ +Task 6 (schemas) Task 7 (EventBus) Task 8 (Inbox service) + │ + ▼ + Task 9 (Ask protocol) + │ + ┌─────────────────────────┘ + ▼ + Task 10 (POST /messages 路由) + │ + ▼ + Task 11 (SSE 改 EventBus subscribe) + │ + ▼ + Task 12 (删 POST /cancel) + │ + ▼ + Task 13 (ExecuteRunner 接 inbox) + │ + ▼ + Task 14 (PlanRunner 接 ask) + │ + ▼ + Task 15 (worker 装配) + │ + ▼ + Task 16 (端到端集成测试) +``` + +--- + +## Task 1: 配置项与依赖 + +**Files:** + +- Modify: `app/src/fileflash/core/settings.py` + +- [ ] **Step 1: 在 `Settings` 类合适位置(紧跟 `redis_url` 之后)增加 4 个配置项** + +```python + agent_inbox_ask_timeout_sec: int = Field( + default=1800, + alias="AGENT_INBOX_ASK_TIMEOUT_SEC", + ) + agent_event_channel_prefix: str = Field( + default="agent:job", + alias="AGENT_EVENT_CHANNEL_PREFIX", + ) + agent_inbox_channel_prefix: str = Field( + default="agent:inbox", + alias="AGENT_INBOX_CHANNEL_PREFIX", + ) + agent_event_bus_buffer_size: int = Field( + default=64, + alias="AGENT_EVENT_BUS_BUFFER_SIZE", + ) +``` + +- [ ] **Step 2: 删除既有 `routers/agent.py` 顶部的 `AGENT_EVENT_POLL_INTERVAL_SEC` 常量(line 22)。如果还有其它文件引用此常量,先用 Grep 确认无引用再删。** + +Run: `grep -rn "AGENT_EVENT_POLL_INTERVAL_SEC" app/src/ app/tests/` +Expected: 仅 `routers/agent.py:22` 一处定义、`_event_stream` 内一处引用。 + +- [ ] **Step 3: Commit** + +```bash +git add app/src/fileflash/core/settings.py +git commit -m "feat(agent): add inbox + event bus settings" +``` + +--- + +## Task 2: 新增 inbox 相关枚举 + +**Files:** + +- Modify: `app/src/fileflash/models/enums.py` + +- [ ] **Step 1: 在 `AgentMcpVisibility` 之后追加三个枚举** + +```python +class AgentInboxRole(BaseStrEnum): + AGENT = "agent" + USER = "user" + + +class AgentInboxKind(BaseStrEnum): + ASK = "ask" + REPLY = "reply" + CONTROL_PAUSE = "control.pause" + CONTROL_RESUME = "control.resume" + CONTROL_APPROVE = "control.approve" + CONTROL_DENY = "control.deny" + CONTROL_SKIP = "control.skip" + CONTROL_CANCEL = "control.cancel" + + +class AgentInboxStatus(BaseStrEnum): + WAITING = "waiting" + ANSWERED = "answered" + TIMED_OUT = "timed_out" + DROPPED = "dropped" +``` + +- [ ] **Step 2: 把上面三个名字加到 `__all__` 末尾** + +```python +__all__ = [ + # ... existing entries ... + "AgentInboxRole", + "AgentInboxKind", + "AgentInboxStatus", +] +``` + +- [ ] **Step 3: Commit** + +```bash +git add app/src/fileflash/models/enums.py +git commit -m "feat(agent): add inbox role/kind/status enums" +``` + +--- + +## Task 3: Flyway 迁移 V14(新表 + pg enums) + +**Files:** + +- Create: `docker/flyway/migrations/V14__agent_inbox.sql` + +- [ ] **Step 1: 写完整 SQL 迁移** + +```sql +-- ========================= +-- Domain: agent inbox +-- ========================= + +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_role_enum') THEN + CREATE TYPE agent_inbox_role_enum AS ENUM ('agent', 'user'); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_kind_enum') THEN + CREATE TYPE agent_inbox_kind_enum AS ENUM ( + 'ask', + 'reply', + 'control.pause', + 'control.resume', + 'control.approve', + 'control.deny', + 'control.skip', + 'control.cancel' + ); + END IF; + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'agent_inbox_status_enum') THEN + CREATE TYPE agent_inbox_status_enum AS ENUM ('waiting', 'answered', 'timed_out', 'dropped'); + END IF; +END +$$; + +CREATE TABLE IF NOT EXISTS agent_inbox_message ( + inbox_message_id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + job_id BIGINT NOT NULL, + role agent_inbox_role_enum NOT NULL, + kind agent_inbox_kind_enum NOT NULL, + payload_json JSONB NOT NULL DEFAULT '{}'::jsonb, + reply_to_id BIGINT NULL, + status agent_inbox_status_enum NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + answered_at TIMESTAMP NULL, + CONSTRAINT fk_agent_inbox_message_job + FOREIGN KEY (job_id) REFERENCES background_job(job_id) ON DELETE CASCADE, + CONSTRAINT fk_agent_inbox_message_reply_to + FOREIGN KEY (reply_to_id) REFERENCES agent_inbox_message(inbox_message_id) ON DELETE SET NULL +); + +CREATE INDEX IF NOT EXISTS idx_agent_inbox_message_job_created + ON agent_inbox_message (job_id, created_at); + +CREATE INDEX IF NOT EXISTS idx_agent_inbox_message_job_status + ON agent_inbox_message (job_id, status) + WHERE status IS NOT NULL; +``` + +- [ ] **Step 2: 在本地 PostgreSQL 应用迁移(按既有 Flyway 流程)** + +Run: `docker compose -f docker/compose.yml up flyway --build` 或项目既有 migration 命令。 +Expected: V14 标记为 success;`\dt agent_inbox_message` 在 psql 中能看到新表。 + +- [ ] **Step 3: Commit** + +```bash +git add docker/flyway/migrations/V14__agent_inbox.sql +git commit -m "feat(agent): V14 add agent_inbox_message table" +``` + +--- + +## Task 4: ORM model `AgentInboxMessage` + +**Files:** + +- Modify: `app/src/fileflash/models/tables_agent.py` +- Modify: `app/src/fileflash/models/__init__.py` + +- [ ] **Step 1: 在 `tables_agent.py` 顶部导入区追加** + +```python +from .enums import ( + AgentExecutionPolicy, + AgentInboxKind, + AgentInboxRole, + AgentInboxStatus, + AgentMcpVisibility, + AgentMemoryKind, + AgentMemoryScope, + AgentSkillVisibility, +) +``` + +- [ ] **Step 2: 在 `AgentWorkSession` 类之后追加 `AgentInboxMessage` 类** + +```python +class AgentInboxMessage(Base): + __tablename__ = "agent_inbox_message" + __table_args__ = ( + Index("idx_agent_inbox_message_job_created", "job_id", "created_at"), + Index( + "idx_agent_inbox_message_job_status", + "job_id", + "status", + postgresql_where=text("status IS NOT NULL"), + ), + ) + + inbox_message_id: Mapped[int] = mapped_column(BigInteger, Identity(), primary_key=True) + job_id: Mapped[int] = mapped_column( + BigInteger, + ForeignKey("background_job.job_id", ondelete="CASCADE"), + nullable=False, + ) + role: Mapped[AgentInboxRole] = mapped_column( + pg_enum(AgentInboxRole, "agent_inbox_role_enum"), + nullable=False, + ) + kind: Mapped[AgentInboxKind] = mapped_column( + pg_enum(AgentInboxKind, "agent_inbox_kind_enum"), + nullable=False, + ) + payload_json: Mapped[dict[str, Any]] = mapped_column( + JSONB, + nullable=False, + server_default=text("'{}'::jsonb"), + ) + reply_to_id: Mapped[int | None] = mapped_column( + BigInteger, + ForeignKey("agent_inbox_message.inbox_message_id", ondelete="SET NULL"), + ) + status: Mapped[AgentInboxStatus | None] = mapped_column( + pg_enum(AgentInboxStatus, "agent_inbox_status_enum"), + ) + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ) + answered_at: Mapped[datetime | None] = mapped_column(DateTime) +``` + +- [ ] **Step 3: 把 `AgentInboxMessage` 加到 `__all__` 末尾,并在 `app/src/fileflash/models/__init__.py` 导出** + +```python +# tables_agent.py __all__ +__all__ = [ + "AgentActionLog", + "AgentInboxMessage", + "AgentMcpServer", + "AgentMemory", + "AgentPlan", + "AgentSkill", + "AgentUserSetting", + "AgentWorkSession", +] +``` + +- [ ] **Step 4: 写最小 sanity 测试,确认 model 能与 DB 通信** + +新建 `app/tests/test_agent_inbox_model.py`: + +```python +from datetime import UTC, datetime + +import pytest +from sqlalchemy import select + +from fileflash.models import AgentInboxMessage, BackgroundJob +from fileflash.models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus + + +@pytest.mark.asyncio +async def test_insert_ask_message_round_trip(db_session, sample_background_job): # noqa: ANN001 + msg = AgentInboxMessage( + job_id=sample_background_job.job_id, + role=AgentInboxRole.AGENT, + kind=AgentInboxKind.ASK, + payload_json={"prompt": "which one?", "schema": {}}, + status=AgentInboxStatus.WAITING, + created_at=datetime.now(UTC), + ) + db_session.add(msg) + await db_session.commit() + fetched = await db_session.scalar(select(AgentInboxMessage).where( + AgentInboxMessage.inbox_message_id == msg.inbox_message_id + )) + assert fetched is not None + assert fetched.kind == AgentInboxKind.ASK + assert fetched.status == AgentInboxStatus.WAITING + assert fetched.payload_json["prompt"] == "which one?" +``` + +> 注:`db_session` / `sample_background_job` 是项目既有 pytest fixture(参见 `app/tests/test_agent_repositories.py`)。如名字不一致,沿用该测试文件里的 fixture 名。 + +- [ ] **Step 5: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_inbox_model.py -v` +Expected: PASS + +- [ ] **Step 6: Commit** + +```bash +git add app/src/fileflash/models/tables_agent.py app/src/fileflash/models/__init__.py app/tests/test_agent_inbox_model.py +git commit -m "feat(agent): add AgentInboxMessage ORM model" +``` + +--- + +## Task 5: `AgentInboxMessageRepository` + +**Files:** + +- Create: `app/src/fileflash/repositories/agent/inbox.py` +- Modify: `app/src/fileflash/repositories/__init__.py` +- Create: `app/tests/test_agent_inbox_repository.py` + +- [ ] **Step 1: 写测试(先 fail)** + +```python +# app/tests/test_agent_inbox_repository.py +from datetime import UTC, datetime + +import pytest + +from fileflash.models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus +from fileflash.repositories import AgentInboxMessageRepository + + +@pytest.mark.asyncio +async def test_create_ask_then_record_reply(db_session, sample_background_job): # noqa: ANN001 + repo = AgentInboxMessageRepository(db_session) + ask = await repo.create_ask( + job_id=int(sample_background_job.job_id), + payload={"prompt": "choose", "schema": {"choice": ["A", "B"]}}, + ) + await db_session.commit() + assert ask.status == AgentInboxStatus.WAITING + assert ask.role == AgentInboxRole.AGENT + assert ask.kind == AgentInboxKind.ASK + + reply = await repo.record_user_message( + job_id=int(sample_background_job.job_id), + kind=AgentInboxKind.REPLY, + payload={"value": "A"}, + reply_to_id=int(ask.inbox_message_id), + ) + await db_session.commit() + assert reply.role == AgentInboxRole.USER + assert reply.reply_to_id == ask.inbox_message_id + + answered = await repo.mark_answered( + inbox_message_id=int(ask.inbox_message_id), + answered_at=datetime.now(UTC), + ) + await db_session.commit() + assert answered.status == AgentInboxStatus.ANSWERED + assert answered.answered_at is not None + + +@pytest.mark.asyncio +async def test_pending_controls_excludes_consumed(db_session, sample_background_job): # noqa: ANN001 + repo = AgentInboxMessageRepository(db_session) + pause = await repo.record_user_message( + job_id=int(sample_background_job.job_id), + kind=AgentInboxKind.CONTROL_PAUSE, + payload={}, + ) + await db_session.commit() + + pending = await repo.list_pending_controls(job_id=int(sample_background_job.job_id)) + assert [m.inbox_message_id for m in pending] == [pause.inbox_message_id] + + await repo.mark_dropped(inbox_message_id=int(pause.inbox_message_id)) + await db_session.commit() + pending_after = await repo.list_pending_controls(job_id=int(sample_background_job.job_id)) + assert pending_after == [] +``` + +- [ ] **Step 2: 运行测试,确认 fail** + +Run: `cd app && uv run pytest tests/test_agent_inbox_repository.py -v` +Expected: FAIL — `AgentInboxMessageRepository` not exported. + +- [ ] **Step 3: 实现 repository** + +`app/src/fileflash/repositories/agent/inbox.py`: + +```python +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from ...models import AgentInboxMessage +from ...models.enums import AgentInboxKind, AgentInboxRole, AgentInboxStatus + +_CONTROL_KINDS = frozenset( + { + AgentInboxKind.CONTROL_PAUSE, + AgentInboxKind.CONTROL_RESUME, + AgentInboxKind.CONTROL_APPROVE, + AgentInboxKind.CONTROL_DENY, + AgentInboxKind.CONTROL_SKIP, + AgentInboxKind.CONTROL_CANCEL, + } +) + + +class AgentInboxMessageRepository: + def __init__(self, db: AsyncSession) -> None: + self._db = db + + async def create_ask( + self, + *, + job_id: int, + payload: dict[str, Any], + ) -> AgentInboxMessage: + msg = AgentInboxMessage( + job_id=job_id, + role=AgentInboxRole.AGENT, + kind=AgentInboxKind.ASK, + payload_json=payload, + status=AgentInboxStatus.WAITING, + created_at=datetime.now(UTC), + ) + self._db.add(msg) + await self._db.flush() + return msg + + async def record_user_message( + self, + *, + job_id: int, + kind: AgentInboxKind, + payload: dict[str, Any], + reply_to_id: int | None = None, + ) -> AgentInboxMessage: + msg = AgentInboxMessage( + job_id=job_id, + role=AgentInboxRole.USER, + kind=kind, + payload_json=payload, + reply_to_id=reply_to_id, + status=None, + created_at=datetime.now(UTC), + ) + self._db.add(msg) + await self._db.flush() + return msg + + async def mark_answered( + self, + *, + inbox_message_id: int, + answered_at: datetime, + ) -> AgentInboxMessage: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None: + raise ValueError(f"AgentInboxMessage {inbox_message_id} not found") + msg.status = AgentInboxStatus.ANSWERED + msg.answered_at = answered_at + await self._db.flush() + return msg + + async def mark_dropped(self, *, inbox_message_id: int) -> None: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None: + return + if msg.kind in _CONTROL_KINDS: + msg.status = AgentInboxStatus.DROPPED + msg.answered_at = datetime.now(UTC) + await self._db.flush() + + async def get_ask(self, *, inbox_message_id: int) -> AgentInboxMessage | None: + msg = await self._db.get(AgentInboxMessage, inbox_message_id) + if msg is None or msg.kind != AgentInboxKind.ASK: + return None + return msg + + async def get_reply_for(self, *, ask_id: int) -> AgentInboxMessage | None: + return await self._db.scalar( + select(AgentInboxMessage).where( + and_( + AgentInboxMessage.reply_to_id == ask_id, + AgentInboxMessage.kind == AgentInboxKind.REPLY, + ) + ) + ) + + async def list_pending_controls(self, *, job_id: int) -> list[AgentInboxMessage]: + rows = await self._db.scalars( + select(AgentInboxMessage) + .where( + and_( + AgentInboxMessage.job_id == job_id, + AgentInboxMessage.role == AgentInboxRole.USER, + AgentInboxMessage.kind.in_(list(_CONTROL_KINDS)), + AgentInboxMessage.status.is_(None), + ) + ) + .order_by(AgentInboxMessage.created_at.asc()) + ) + return list(rows) +``` + +> 注:control 消息以"`status IS NULL` 表示未消费"为约定;worker 处理完后调 `mark_dropped`(命名只表"已消费、不再有效",不代表用户错误)。Reply 消息保持 `status IS NULL`,由 `mark_answered` 处理对应的 ask。 + +- [ ] **Step 4: 导出** + +`app/src/fileflash/repositories/__init__.py`:在合适位置新增 + +```python +from .agent.inbox import AgentInboxMessageRepository + +__all__ = [ + # ... existing entries ... + "AgentInboxMessageRepository", +] +``` + +- [ ] **Step 5: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_inbox_repository.py -v` +Expected: PASS(2 个用例) + +- [ ] **Step 6: Commit** + +```bash +git add app/src/fileflash/repositories/agent/inbox.py app/src/fileflash/repositories/__init__.py app/tests/test_agent_inbox_repository.py +git commit -m "feat(agent): add AgentInboxMessageRepository" +``` + +--- + +## Task 6: 新事件类型与上行 message schemas + +**Files:** + +- Modify: `app/src/fileflash/schemas/agent.py` + +- [ ] **Step 1: 扩展 `AgentJobEventType` 字面量与新增上行 message 模型** + +把 `AgentJobEventType` 改为: + +```python +AgentJobEventType = Literal[ + "job.queued", + "job.running", + "plan.ready", + "tool.started", + "tool.succeeded", + "tool.failed", + "tool.partial", + "agent.thinking", + "agent.progress", + "agent.ask", + "agent.paused", + "agent.resumed", + "job.succeeded", + "job.failed", + "job.canceled", +] +``` + +在文件末尾(`__all__` 之前)新增: + +```python +AgentInboxMessageKind = Literal[ + "reply", + "control.pause", + "control.resume", + "control.approve", + "control.deny", + "control.skip", + "control.cancel", +] + + +class AgentInboxMessageRequest(CamelModel): + kind: AgentInboxMessageKind + reply_to: str | None = None # ask 的 inbox_message_id(str-encoded) + value: Any = None # reply 时为用户回答;control 时通常 None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AgentInboxMessageResponse(CamelModel): + inbox_message_id: str + kind: AgentInboxMessageKind + accepted_at: datetime +``` + +把这两个名字加入 `__all__`。 + +- [ ] **Step 2: 写最小验证测试** + +新建 `app/tests/test_agent_inbox_schema.py`: + +```python +import pytest +from pydantic import ValidationError + +from fileflash.schemas.agent import AgentInboxMessageRequest + + +def test_reply_with_value_validates() -> None: + msg = AgentInboxMessageRequest.model_validate( + {"kind": "reply", "replyTo": "42", "value": "yes"} + ) + assert msg.kind == "reply" + assert msg.reply_to == "42" + assert msg.value == "yes" + + +def test_unknown_kind_rejected() -> None: + with pytest.raises(ValidationError): + AgentInboxMessageRequest.model_validate({"kind": "control.explode"}) +``` + +- [ ] **Step 3: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_inbox_schema.py -v` +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add app/src/fileflash/schemas/agent.py app/tests/test_agent_inbox_schema.py +git commit -m "feat(agent): extend job event types and add inbox message schemas" +``` + +--- + +## Task 7: `AgentEventBus`(Redis pub/sub 封装) + +**Files:** + +- Create: `app/src/fileflash/agents/harness/event_bus.py` +- Modify: `app/src/fileflash/agents/harness/events.py` — 保留 `AgentEvent`,删除 `EventBus` scaffold +- Create: `app/tests/test_agent_event_bus.py` + +- [ ] **Step 1: 写测试(先 fail)** + +```python +# app/tests/test_agent_event_bus.py +from datetime import UTC, datetime + +import pytest + +from fileflash.agents.harness.event_bus import ( + AgentEventEnvelope, + InMemoryAgentEventBus, +) + + +@pytest.mark.asyncio +async def test_subscriber_receives_published_event() -> None: + bus = InMemoryAgentEventBus() + envelope = AgentEventEnvelope( + job_id=42, + event_type="agent.ask", + payload={"prompt": "choose"}, + emitted_at=datetime.now(UTC), + ) + async with bus.subscribe(job_id=42) as stream: + await bus.publish(envelope) + received = await stream.next(timeout=1.0) + assert received == envelope + + +@pytest.mark.asyncio +async def test_only_subscribers_of_same_job_receive() -> None: + bus = InMemoryAgentEventBus() + own = AgentEventEnvelope(job_id=1, event_type="job.running", payload={}, emitted_at=datetime.now(UTC)) + other = AgentEventEnvelope(job_id=2, event_type="job.running", payload={}, emitted_at=datetime.now(UTC)) + async with bus.subscribe(job_id=1) as stream: + await bus.publish(other) + await bus.publish(own) + first = await stream.next(timeout=1.0) + assert first == own + + +@pytest.mark.asyncio +async def test_close_subscriber_unblocks() -> None: + bus = InMemoryAgentEventBus() + async with bus.subscribe(job_id=7) as stream: + with pytest.raises(TimeoutError): + await stream.next(timeout=0.1) +``` + +- [ ] **Step 2: 运行测试,确认 fail** + +Run: `cd app && uv run pytest tests/test_agent_event_bus.py -v` +Expected: FAIL — module missing. + +- [ ] **Step 3: 实现 `event_bus.py`** + +```python +# app/src/fileflash/agents/harness/event_bus.py +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +from collections.abc import AsyncIterator +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Any, Protocol + +from redis.asyncio import Redis + +from ...core.settings import Settings, get_settings + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class AgentEventEnvelope: + job_id: int + event_type: str + payload: dict[str, Any] + emitted_at: datetime + event_id: str | None = None + + def to_json(self) -> str: + body = asdict(self) + body["emitted_at"] = self.emitted_at.isoformat() + return json.dumps(body, ensure_ascii=False, separators=(",", ":")) + + @classmethod + def from_json(cls, raw: str) -> "AgentEventEnvelope": + data = json.loads(raw) + return cls( + job_id=int(data["job_id"]), + event_type=str(data["event_type"]), + payload=dict(data.get("payload") or {}), + emitted_at=datetime.fromisoformat(data["emitted_at"]), + event_id=data.get("event_id"), + ) + + +class AgentEventStream(Protocol): + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: ... + async def aclose(self) -> None: ... + + +class AgentEventBus(Protocol): + async def publish(self, envelope: AgentEventEnvelope) -> None: ... + def subscribe(self, *, job_id: int) -> "AgentEventSubscription": ... + + +@dataclass(slots=True) +class _InMemoryStream: + queue: asyncio.Queue[AgentEventEnvelope] + + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: + if timeout is None: + return await self.queue.get() + return await asyncio.wait_for(self.queue.get(), timeout=timeout) + + async def aclose(self) -> None: + return None + + +class InMemoryAgentEventBus: + """同进程实现,用于单元测试和单进程开发。生产用 RedisAgentEventBus。""" + + def __init__(self, *, buffer_size: int = 64) -> None: + self._buffer = buffer_size + self._subscribers: dict[int, list[asyncio.Queue[AgentEventEnvelope]]] = {} + + async def publish(self, envelope: AgentEventEnvelope) -> None: + queues = list(self._subscribers.get(envelope.job_id, [])) + for q in queues: + if q.full(): + logger.warning("InMemoryAgentEventBus drop: queue full job_id=%s", envelope.job_id) + continue + await q.put(envelope) + + @contextlib.asynccontextmanager + async def subscribe(self, *, job_id: int) -> AsyncIterator[_InMemoryStream]: + q: asyncio.Queue[AgentEventEnvelope] = asyncio.Queue(maxsize=self._buffer) + self._subscribers.setdefault(job_id, []).append(q) + try: + yield _InMemoryStream(queue=q) + finally: + self._subscribers[job_id].remove(q) + if not self._subscribers[job_id]: + del self._subscribers[job_id] + + +class RedisAgentEventBus: + """生产实现:worker 进程 publish 到 channel,web 进程 subscribe。""" + + def __init__( + self, + *, + redis: Redis, + channel_prefix: str, + buffer_size: int = 64, + ) -> None: + self._redis = redis + self._channel_prefix = channel_prefix + self._buffer = buffer_size + + def _channel(self, job_id: int) -> str: + return f"{self._channel_prefix}:{job_id}:events" + + async def publish(self, envelope: AgentEventEnvelope) -> None: + await self._redis.publish(self._channel(envelope.job_id), envelope.to_json()) + + @contextlib.asynccontextmanager + async def subscribe(self, *, job_id: int) -> AsyncIterator["_RedisStream"]: + pubsub = self._redis.pubsub() + await pubsub.subscribe(self._channel(job_id)) + stream = _RedisStream(pubsub=pubsub) + try: + yield stream + finally: + await pubsub.unsubscribe(self._channel(job_id)) + await pubsub.aclose() + + +@dataclass(slots=True) +class _RedisStream: + pubsub: Any + + async def next(self, *, timeout: float | None = None) -> AgentEventEnvelope: + message = await self.pubsub.get_message( + ignore_subscribe_messages=True, + timeout=timeout if timeout is not None else 0, + ) + if message is None: + raise TimeoutError("No event within timeout") + data = message.get("data") + if isinstance(data, bytes): + data = data.decode("utf-8") + return AgentEventEnvelope.from_json(str(data)) + + async def aclose(self) -> None: + await self.pubsub.aclose() + + +def build_agent_event_bus(*, settings: Settings | None = None, redis: Redis | None = None) -> AgentEventBus: + cfg = settings or get_settings() + if redis is None: + if not cfg.redis_url: + return InMemoryAgentEventBus(buffer_size=cfg.agent_event_bus_buffer_size) + from redis.asyncio import Redis as RedisClient # local import to avoid hard dep at import time + + redis = RedisClient.from_url(cfg.redis_url, decode_responses=True) + return RedisAgentEventBus( + redis=redis, + channel_prefix=cfg.agent_event_channel_prefix, + buffer_size=cfg.agent_event_bus_buffer_size, + ) +``` + +- [ ] **Step 4: 清理 events.py scaffold** + +`app/src/fileflash/agents/harness/events.py` 改为: + +```python +# Kept as a re-export shim until callers migrate to event_bus.py. +from .event_bus import AgentEventEnvelope as AgentEvent + +__all__ = ["AgentEvent"] +``` + +> 此 shim 后续 PR 删除。本 plan 暂不删,避免外部 import 路径同时变更。 + +- [ ] **Step 5: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_event_bus.py -v` +Expected: PASS(3 个用例) + +- [ ] **Step 6: Commit** + +```bash +git add app/src/fileflash/agents/harness/event_bus.py app/src/fileflash/agents/harness/events.py app/tests/test_agent_event_bus.py +git commit -m "feat(agent): add AgentEventBus with in-memory and Redis impls" +``` + +--- + +## Task 8: `AgentInbox` 服务(写表 + publish) + +**Files:** + +- Create: `app/src/fileflash/agents/harness/inbox.py` +- Create: `app/tests/test_agent_inbox.py` + +- [ ] **Step 1: 写测试** + +```python +# app/tests/test_agent_inbox.py +from datetime import UTC, datetime + +import pytest + +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.models.enums import AgentInboxKind +from fileflash.repositories import AgentInboxMessageRepository + + +@pytest.mark.asyncio +async def test_handle_reply_persists_and_publishes(db_session, sample_background_job): # noqa: ANN001 + repo = AgentInboxMessageRepository(db_session) + ask = await repo.create_ask( + job_id=int(sample_background_job.job_id), + payload={"prompt": "?"}, + ) + await db_session.commit() + + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + + async with bus.subscribe(job_id=int(sample_background_job.job_id)) as stream: + msg = await inbox.handle( + job_id=int(sample_background_job.job_id), + kind=AgentInboxKind.REPLY, + payload={"value": "yes"}, + reply_to_id=int(ask.inbox_message_id), + ) + await db_session.commit() + evt = await stream.next(timeout=1.0) + + assert msg.kind == AgentInboxKind.REPLY + assert evt.event_type == "agent.inbox.reply" + assert evt.payload["replyTo"] == str(ask.inbox_message_id) + assert evt.payload["value"] == "yes" + + +@pytest.mark.asyncio +async def test_reply_with_unknown_ask_raises(db_session, sample_background_job): # noqa: ANN001 + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + with pytest.raises(ValueError): + await inbox.handle( + job_id=int(sample_background_job.job_id), + kind=AgentInboxKind.REPLY, + payload={"value": "yes"}, + reply_to_id=999999, + ) +``` + +- [ ] **Step 2: 运行测试,确认 fail** + +Run: `cd app && uv run pytest tests/test_agent_inbox.py -v` +Expected: FAIL — module missing. + +- [ ] **Step 3: 实现 `AgentInbox`** + +```python +# app/src/fileflash/agents/harness/inbox.py +from __future__ import annotations + +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from ...models.enums import AgentInboxKind +from ...repositories import AgentInboxMessageRepository +from .event_bus import AgentEventBus, AgentEventEnvelope + + +_INBOX_EVENT_TYPES: dict[AgentInboxKind, str] = { + AgentInboxKind.REPLY: "agent.inbox.reply", + AgentInboxKind.CONTROL_PAUSE: "agent.inbox.control", + AgentInboxKind.CONTROL_RESUME: "agent.inbox.control", + AgentInboxKind.CONTROL_APPROVE: "agent.inbox.control", + AgentInboxKind.CONTROL_DENY: "agent.inbox.control", + AgentInboxKind.CONTROL_SKIP: "agent.inbox.control", + AgentInboxKind.CONTROL_CANCEL: "agent.inbox.control", +} + + +class AgentInbox: + def __init__(self, *, db: AsyncSession, event_bus: AgentEventBus) -> None: + self._db = db + self._bus = event_bus + self._repo = AgentInboxMessageRepository(db) + + async def handle( + self, + *, + job_id: int, + kind: AgentInboxKind, + payload: dict[str, Any], + reply_to_id: int | None = None, + ): + if kind == AgentInboxKind.REPLY: + if reply_to_id is None: + raise ValueError("reply requires reply_to_id") + ask = await self._repo.get_ask(inbox_message_id=reply_to_id) + if ask is None: + raise ValueError(f"ask {reply_to_id} not found") + if ask.job_id != job_id: + raise ValueError(f"ask {reply_to_id} belongs to a different job") + + msg = await self._repo.record_user_message( + job_id=job_id, + kind=kind, + payload=payload, + reply_to_id=reply_to_id, + ) + event_type = _INBOX_EVENT_TYPES[kind] + envelope_payload: dict[str, Any] = {"kind": kind.value, "messageId": str(msg.inbox_message_id)} + if reply_to_id is not None: + envelope_payload["replyTo"] = str(reply_to_id) + if "value" in payload: + envelope_payload["value"] = payload["value"] + await self._bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload=envelope_payload, + emitted_at=datetime.now(UTC), + ) + ) + return msg +``` + +- [ ] **Step 4: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_inbox.py -v` +Expected: PASS(2 个用例) + +- [ ] **Step 5: Commit** + +```bash +git add app/src/fileflash/agents/harness/inbox.py app/tests/test_agent_inbox.py +git commit -m "feat(agent): add AgentInbox service" +``` + +--- + +## Task 9: `AskProtocol`(worker 等用户回答) + +**Files:** + +- Create: `app/src/fileflash/agents/harness/ask.py` +- Create: `app/tests/test_agent_ask_protocol.py` + +- [ ] **Step 1: 写测试** + +```python +# app/tests/test_agent_ask_protocol.py +import asyncio +from datetime import UTC, datetime + +import pytest + +from fileflash.agents.harness.ask import AskProtocol, AskTimedOut +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.models.enums import AgentInboxKind, AgentInboxStatus +from fileflash.repositories import AgentInboxMessageRepository + + +@pytest.mark.asyncio +async def test_ask_returns_when_reply_arrives(db_session, sample_background_job): # noqa: ANN001 + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + protocol = AskProtocol( + db=db_session, + event_bus=bus, + job_id=int(sample_background_job.job_id), + ) + await protocol.start() + try: + async def reply_later(): + await asyncio.sleep(0.05) + # 找到刚创建的 ask + repo = AgentInboxMessageRepository(db_session) + from sqlalchemy import select + from fileflash.models import AgentInboxMessage + ask = await db_session.scalar( + select(AgentInboxMessage) + .where(AgentInboxMessage.kind == AgentInboxKind.ASK) + .order_by(AgentInboxMessage.inbox_message_id.desc()) + ) + await inbox.handle( + job_id=int(sample_background_job.job_id), + kind=AgentInboxKind.REPLY, + payload={"value": "A"}, + reply_to_id=int(ask.inbox_message_id), + ) + await db_session.commit() + + replier = asyncio.create_task(reply_later()) + result = await protocol.ask( + prompt="choose", + schema={"choice": ["A", "B"]}, + timeout_sec=2.0, + ) + await replier + finally: + await protocol.aclose() + + assert result == "A" + + +@pytest.mark.asyncio +async def test_ask_times_out(db_session, sample_background_job): # noqa: ANN001 + bus = InMemoryAgentEventBus() + protocol = AskProtocol( + db=db_session, + event_bus=bus, + job_id=int(sample_background_job.job_id), + ) + await protocol.start() + try: + with pytest.raises(AskTimedOut): + await protocol.ask(prompt="?", schema={}, timeout_sec=0.1) + finally: + await protocol.aclose() + + # 验证 ask 已被标 timed_out + from sqlalchemy import select + from fileflash.models import AgentInboxMessage + asks = list( + await db_session.scalars( + select(AgentInboxMessage).where(AgentInboxMessage.kind == AgentInboxKind.ASK) + ) + ) + assert asks + assert asks[-1].status == AgentInboxStatus.TIMED_OUT +``` + +- [ ] **Step 2: 运行测试,确认 fail** + +Run: `cd app && uv run pytest tests/test_agent_ask_protocol.py -v` +Expected: FAIL — module missing. + +- [ ] **Step 3: 实现 `AskProtocol`** + +```python +# app/src/fileflash/agents/harness/ask.py +from __future__ import annotations + +import asyncio +import contextlib +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from ...models.enums import AgentInboxKind +from ...repositories import AgentInboxMessageRepository +from .event_bus import AgentEventBus, AgentEventEnvelope + + +class AskTimedOut(Exception): + def __init__(self, *, ask_id: int) -> None: + super().__init__(f"Ask {ask_id} timed out") + self.ask_id = ask_id + + +class AskProtocol: + """worker 端:创建 ask 表条目、publish agent.ask 事件、阻塞等 reply 经 inbox channel 唤醒。 + + 生命周期绑定单个 job_id。`start()` 后开始订阅;`aclose()` 释放订阅。 + """ + + def __init__( + self, + *, + db: AsyncSession, + event_bus: AgentEventBus, + job_id: int, + ) -> None: + self._db = db + self._bus = event_bus + self._job_id = job_id + self._repo = AgentInboxMessageRepository(db) + self._waiters: dict[int, asyncio.Future[Any]] = {} + self._sub_ctx = None + self._sub_stream = None + self._sub_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + self._sub_ctx = self._bus.subscribe(job_id=self._job_id) + self._sub_stream = await self._sub_ctx.__aenter__() + self._sub_task = asyncio.create_task(self._listen()) + + async def aclose(self) -> None: + if self._sub_task is not None: + self._sub_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._sub_task + if self._sub_ctx is not None: + await self._sub_ctx.__aexit__(None, None, None) + for fut in self._waiters.values(): + if not fut.done(): + fut.cancel() + + async def ask( + self, + *, + prompt: str, + schema: dict[str, Any], + timeout_sec: float, + ) -> Any: + msg = await self._repo.create_ask( + job_id=self._job_id, + payload={"prompt": prompt, "schema": schema, "timeoutSec": timeout_sec}, + ) + await self._db.commit() + + await self._bus.publish( + AgentEventEnvelope( + job_id=self._job_id, + event_type="agent.ask", + payload={ + "messageId": str(msg.inbox_message_id), + "prompt": prompt, + "schema": schema, + "timeoutSec": timeout_sec, + }, + emitted_at=datetime.now(UTC), + ) + ) + + loop = asyncio.get_running_loop() + fut: asyncio.Future[Any] = loop.create_future() + self._waiters[int(msg.inbox_message_id)] = fut + try: + value = await asyncio.wait_for(fut, timeout=timeout_sec) + except asyncio.TimeoutError as exc: + from ...models.enums import AgentInboxStatus + ask = await self._repo.get_ask(inbox_message_id=int(msg.inbox_message_id)) + if ask is not None: + ask.status = AgentInboxStatus.TIMED_OUT + ask.answered_at = datetime.now(UTC) + await self._db.commit() + raise AskTimedOut(ask_id=int(msg.inbox_message_id)) from exc + finally: + self._waiters.pop(int(msg.inbox_message_id), None) + + await self._repo.mark_answered( + inbox_message_id=int(msg.inbox_message_id), + answered_at=datetime.now(UTC), + ) + await self._db.commit() + return value + + async def _listen(self) -> None: + assert self._sub_stream is not None + while True: + try: + envelope = await self._sub_stream.next(timeout=None) + except asyncio.CancelledError: + raise + except Exception: # noqa: BLE001 + continue + if envelope.event_type != "agent.inbox.reply": + continue + reply_to = envelope.payload.get("replyTo") + if reply_to is None: + continue + try: + ask_id = int(reply_to) + except (TypeError, ValueError): + continue + fut = self._waiters.get(ask_id) + if fut is None or fut.done(): + continue + fut.set_result(envelope.payload.get("value")) +``` + +- [ ] **Step 4: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_ask_protocol.py -v` +Expected: PASS(2 个用例) + +- [ ] **Step 5: Commit** + +```bash +git add app/src/fileflash/agents/harness/ask.py app/tests/test_agent_ask_protocol.py +git commit -m "feat(agent): add AskProtocol for worker-to-user blocking ask" +``` + +--- + +## Task 10: `POST /agent/jobs/{job_id}/messages` 路由 + +**Files:** + +- Modify: `app/src/fileflash/routers/agent.py` +- Modify: `app/src/fileflash/core/deps.py` +- Modify: `app/tests/test_agent_routes.py` + +- [ ] **Step 1: 在 `deps.py` 增加 EventBus 依赖** + +```python +# app/src/fileflash/core/deps.py — 在文件末尾增加 +from ..agents.harness.event_bus import AgentEventBus, build_agent_event_bus + +_event_bus_singleton: AgentEventBus | None = None + + +def get_agent_event_bus() -> AgentEventBus: + global _event_bus_singleton + if _event_bus_singleton is None: + _event_bus_singleton = build_agent_event_bus() + return _event_bus_singleton +``` + +> 注:如果 deps.py 已有 module-level singleton 模式,沿用既有写法;否则用上述简单单例。 + +- [ ] **Step 2: 在 `routers/agent.py` 新增路由** + +在 `cancel_agent_job` 之前插入: + +```python +from ..agents.harness.event_bus import AgentEventBus +from ..agents.harness.inbox import AgentInbox +from ..core.deps import get_agent_event_bus +from ..models.enums import AgentInboxKind +from ..schemas.agent import AgentInboxMessageRequest, AgentInboxMessageResponse + + +@router.post("/jobs/{job_id}/messages") +async def post_agent_job_message( + job_id: str, + payload: AgentInboxMessageRequest, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[AsyncSession, Depends(get_db)], + event_bus: Annotated[AgentEventBus, Depends(get_agent_event_bus)], +): + parsed_job_id = _parse_job_id(job_id) + job = await db.scalar( + select(BackgroundJob).where( + and_( + BackgroundJob.job_id == parsed_job_id, + BackgroundJob.requested_by == current_user.user_id, + BackgroundJob.task_type.in_(["agent.plan", "agent.execute"]), + ) + ) + ) + if job is None: + raise ApiError(status_code=404, code=404, message="Job not found") + + kind = AgentInboxKind(payload.kind) + reply_to_id: int | None = None + if payload.reply_to is not None: + try: + reply_to_id = int(payload.reply_to) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message="Invalid replyTo") from exc + + inbox = AgentInbox(db=db, event_bus=event_bus) + try: + msg = await inbox.handle( + job_id=parsed_job_id, + kind=kind, + payload=_inbox_payload_from_request(payload), + reply_to_id=reply_to_id, + ) + except ValueError as exc: + raise ApiError(status_code=400, code=400, message=str(exc)) from exc + await db.commit() + + data = AgentInboxMessageResponse( + inbox_message_id=str(msg.inbox_message_id), + kind=payload.kind, + accepted_at=msg.created_at, + ) + return api_success(data=data.model_dump(by_alias=True), message="Message accepted") + + +def _inbox_payload_from_request(req: AgentInboxMessageRequest) -> dict[str, Any]: + body: dict[str, Any] = {} + if req.value is not None: + body["value"] = req.value + if req.metadata: + body["metadata"] = req.metadata + return body +``` + +> 注:`Any` 已经在 typing 中;如未 import,在文件顶部 `from typing import Annotated, Any`。 + +- [ ] **Step 3: 扩展 `test_agent_routes.py`,新增一组用例** + +```python +# app/tests/test_agent_routes.py — 在文件末尾追加 +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.core.deps import get_agent_event_bus +from fileflash.models import AgentInboxMessage +from fileflash.models.enums import AgentInboxKind + + +def _build_app_with_bus(bus: InMemoryAgentEventBus, db_stub) -> FastAPI: # noqa: ANN001 + app = FastAPI() + app.include_router(router) + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_db] = lambda: db_stub + app.dependency_overrides[get_current_user] = lambda: User(user_id=7) + app.dependency_overrides[get_agent_event_bus] = lambda: bus + return app + + +def test_post_message_control_pause_accepted(db_session, sample_background_job): # noqa: ANN001 + bus = InMemoryAgentEventBus() + # 真 DB session 测;不再用 StubDb + app = FastAPI() + app.include_router(router) + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_db] = lambda: db_session + app.dependency_overrides[get_current_user] = lambda: User( + user_id=int(sample_background_job.requested_by) + ) + app.dependency_overrides[get_agent_event_bus] = lambda: bus + + client = TestClient(app) + resp = client.post( + f"/agent/jobs/{sample_background_job.job_id}/messages", + json={"kind": "control.pause"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["kind"] == "control.pause" +``` + +> 注:`sample_background_job` fixture 应该指向 `requested_by=7` 的 BackgroundJob;如 fixture 不一致,按既有 fixture 命名调整。 + +- [ ] **Step 4: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_routes.py -v -k "post_message"` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add app/src/fileflash/core/deps.py app/src/fileflash/routers/agent.py app/tests/test_agent_routes.py +git commit -m "feat(agent): add POST /agent/jobs/{id}/messages upstream channel" +``` + +--- + +## Task 11: SSE 端点改为订阅 EventBus + +**Files:** + +- Modify: `app/src/fileflash/routers/agent.py` + +- [ ] **Step 1: 替换 `stream_agent_job_events` 与 `event_stream` 内部逻辑** + +把现有 `stream_agent_job_events` 整体替换为: + +```python +@router.get("/jobs/{job_id}/events") +async def stream_agent_job_events( + job_id: str, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[AsyncSession, Depends(get_db)], + event_bus: Annotated[AgentEventBus, Depends(get_agent_event_bus)], +): + parsed_job_id = _parse_job_id(job_id) + initial_events, initial_terminal = await _agent_job_events_for_job( + db=db, + job_id=parsed_job_id, + user_id=int(current_user.user_id), + ) + + async def event_stream(): + seen: set[str] = set() + for event in initial_events: + seen.add(event.id) + yield _format_sse_event(event) + if initial_terminal: + return + async with event_bus.subscribe(job_id=parsed_job_id) as stream: + while True: + try: + envelope = await stream.next(timeout=30.0) + except TimeoutError: + # 30s 心跳,避免代理断连 + yield ": keep-alive\n\n" + continue + event = _envelope_to_job_event(envelope) + if event.id in seen: + continue + seen.add(event.id) + yield _format_sse_event(event) + if event.type in {"job.succeeded", "job.failed", "job.canceled"}: + break + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) +``` + +新增辅助函数: + +```python +def _envelope_to_job_event(env: "AgentEventEnvelope") -> AgentJobEvent: + return AgentJobEvent( + id=env.event_id or f"{env.job_id}:{env.event_type}:{env.emitted_at.isoformat()}", + job_id=str(env.job_id), + task_type="agent.execute", # type 在 envelope payload 里冗余,简化 + type=env.event_type, # type: ignore[arg-type] + status=str(env.payload.get("status") or ""), + agent_phase=env.payload.get("agentPhase"), + message=str(env.payload.get("message") or ""), + data=dict(env.payload.get("data") or env.payload), + timestamp=env.emitted_at, + ) +``` + +把 `AGENT_EVENT_POLL_INTERVAL_SEC` 常量删掉,把 `asyncio.sleep(...)` 调用一并删除。在文件顶部加入: + +```python +from ..agents.harness.event_bus import AgentEventBus, AgentEventEnvelope +``` + +> 注:保留 `_agent_job_events_for_job` 用作 initial replay(连接刚建立时拉一次历史),因为 pub/sub 不持久化。 + +- [ ] **Step 2: 扩展 SSE 测试** + +在 `test_agent_routes.py` 已有的 SSE 用例之外加一个: + +```python +def test_sse_streams_published_events(db_session, sample_background_job): # noqa: ANN001 + bus = InMemoryAgentEventBus() + app = FastAPI() + app.include_router(router) + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_db] = lambda: db_session + app.dependency_overrides[get_current_user] = lambda: User( + user_id=int(sample_background_job.requested_by) + ) + app.dependency_overrides[get_agent_event_bus] = lambda: bus + + import asyncio + from datetime import UTC, datetime + + async def producer(): + await asyncio.sleep(0.1) + await bus.publish( + AgentEventEnvelope( + job_id=int(sample_background_job.job_id), + event_type="agent.progress", + payload={"step": 1, "total": 3, "message": "halfway"}, + emitted_at=datetime.now(UTC), + ) + ) + await asyncio.sleep(0.05) + await bus.publish( + AgentEventEnvelope( + job_id=int(sample_background_job.job_id), + event_type="job.succeeded", + payload={"status": "succeeded"}, + emitted_at=datetime.now(UTC), + ) + ) + + client = TestClient(app) + # TestClient 不直接支持 async producer 并发;用线程 + import threading + + def run_producer() -> None: + asyncio.run(producer()) + + t = threading.Thread(target=run_producer) + t.start() + with client.stream("GET", f"/agent/jobs/{sample_background_job.job_id}/events") as resp: + lines = [] + for chunk in resp.iter_lines(): + if chunk: + lines.append(chunk) + if any("job.succeeded" in line for line in lines): + break + t.join() + + assert any("agent.progress" in line for line in lines) + assert any("job.succeeded" in line for line in lines) +``` + +> 注:原有依赖 0.6s DB 轮询的 SSE 测试如失败,按新模型重写为"测 initial replay + 测订阅流"两段。 + +- [ ] **Step 3: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_routes.py -v -k "sse"` +Expected: PASS + +- [ ] **Step 4: Commit** + +```bash +git add app/src/fileflash/routers/agent.py app/tests/test_agent_routes.py +git commit -m "feat(agent): replace SSE polling with EventBus subscription" +``` + +--- + +## Task 12: 删除 `POST /agent/cancel/{job_id}` + +**Files:** + +- Modify: `app/src/fileflash/routers/agent.py` +- Modify: `app/tests/test_agent_routes.py` + +- [ ] **Step 1: 删除 `cancel_agent_job` 函数与对应的 `from ..schemas.agent import ... CancelAgentResponse` 引用** + +确认 import 调整后无 unused。 + +- [ ] **Step 2: 删除 `test_agent_routes.py` 中所有 `POST /agent/cancel` 测试用例** + +Run: `grep -n "cancel_agent\|/agent/cancel" app/tests/test_agent_routes.py` +逐条删除。 + +- [ ] **Step 3: 全仓搜索其他引用并清理** + +Run: `grep -rn "agent/cancel\|cancelAgentJob" app/ web/` +Expected: 仅 `web/` 下有引用(前端 plan 处理),后端无引用。如果后端有,一并删除。 + +- [ ] **Step 4: 运行全部 agent 测试** + +Run: `cd app && uv run pytest tests/test_agent_routes.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add app/src/fileflash/routers/agent.py app/tests/test_agent_routes.py +git commit -m "refactor(agent): drop legacy POST /agent/cancel route" +``` + +--- + +## Task 13: `ExecuteRunner` 接入 inbox(pause/resume/skip/approve/cancel) + +**Files:** + +- Modify: `app/src/fileflash/agents/runtime/execute_runner.py` +- Modify: `app/tests/test_agent_plan_execute_runtime.py` + +- [ ] **Step 1: 在 `ExecuteRunner.__init__` 增加 EventBus 依赖与状态** + +```python +class ExecuteRunner: + def __init__( + self, + *, + policy_guard: PolicyGuard | None = None, + event_bus: AgentEventBus | None = None, + ) -> None: + self.policy_guard = policy_guard or PolicyGuard() + self.event_bus = event_bus +``` + +> 默认 `None` 时退化为静默(不 publish),用于单元测试与旧调用兼容。 + +- [ ] **Step 2: 在 `ExecuteRunner.run` 顶部新增 step 边界控制处理** + +把原 line 62-66 的循环顶部替换为: + +```python + from ...repositories import AgentInboxMessageRepository + + inbox_repo = AgentInboxMessageRepository(db) + paused = False + + for action in actions: + await db.refresh(job) + if job.cancel_requested_at is not None: + raise AgentJobCanceled() + + # ---- step 边界 inbox 处理 ---- + while True: + pending = await inbox_repo.list_pending_controls(job_id=int(job.job_id)) + for ctrl in pending: + if ctrl.kind == AgentInboxKind.CONTROL_CANCEL: + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + job.cancel_requested_at = datetime.now(UTC) + await db.commit() + raise AgentJobCanceled() + if ctrl.kind == AgentInboxKind.CONTROL_PAUSE: + paused = True + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await self._publish_state("agent.paused", job_id=int(job.job_id)) + elif ctrl.kind == AgentInboxKind.CONTROL_RESUME: + paused = False + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await self._publish_state("agent.resumed", job_id=int(job.job_id)) + elif ctrl.kind == AgentInboxKind.CONTROL_SKIP: + # 标记跳过当前 step;继续外层 for + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + warnings.append(f"Step {action.step} skipped by user") + applied -= 0 # 不计入 applied + await db.commit() + break # break inner pending loop + else: + # approve / deny — 单工具实时审批,由 policy_guard 读取,这里仅消费 + await inbox_repo.mark_dropped(inbox_message_id=int(ctrl.inbox_message_id)) + await db.commit() + if not paused: + break + # paused: 等 100ms 再轮询 + await asyncio.sleep(0.1) + # ---- 结束 inbox 处理 ---- +``` + +并在顶部 import 处新增: + +```python +import asyncio + +from ...models.enums import AgentInboxKind +from ..harness.event_bus import AgentEventBus, AgentEventEnvelope +``` + +新增 `_publish_state` 实例方法(与 `run` 同一类): + +```python + async def _publish_state(self, event_type: str, *, job_id: int) -> None: + if self.event_bus is None: + return + await self.event_bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload={}, + emitted_at=datetime.now(UTC), + ) + ) +``` + +- [ ] **Step 3: 把工具调用的事件 publish 也接上 EventBus** + +在 line 103-110(`append_step` running 之后)、line 130-140(`finish_step` succeeded 之后)、以及 failure 分支,分别插入: + +```python + # 工具开始 + if self.event_bus is not None: + await self.event_bus.publish( + AgentEventEnvelope( + job_id=int(job.job_id), + event_type="tool.started", + payload={ + "step": int(action.step), + "tool": str(action.tool), + "input": resolved_input, + }, + emitted_at=started, + ) + ) + + # 工具成功 + if self.event_bus is not None: + await self.event_bus.publish( + AgentEventEnvelope( + job_id=int(job.job_id), + event_type="tool.succeeded", + payload={ + "step": int(action.step), + "tool": str(action.tool), + "output": safe_output, + "durationMs": duration_ms, + }, + emitted_at=datetime.now(UTC), + ) + ) + + # 工具失败(含 resolve / dispatch 两个分支) + if self.event_bus is not None: + await self.event_bus.publish( + AgentEventEnvelope( + job_id=int(job.job_id), + event_type="tool.failed", + payload={ + "step": int(action.step), + "tool": str(action.tool), + "errorMessage": f"{type(exc).__name__}: {exc}"[:2000], + }, + emitted_at=datetime.now(UTC), + ) + ) +``` + +- [ ] **Step 4: 写 / 改测试,覆盖 pause-resume、cancel-via-inbox** + +在 `test_agent_plan_execute_runtime.py` 末尾追加: + +```python +import asyncio + +from fileflash.agents.harness.event_bus import InMemoryAgentEventBus +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.agents.runtime.execute_runner import AgentJobCanceled, ExecuteRunner +from fileflash.models.enums import AgentInboxKind + + +@pytest.mark.asyncio +async def test_execute_runner_pauses_then_resumes( + db_session, executable_job_with_two_steps, # noqa: ANN001 +): + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + runner = ExecuteRunner(event_bus=bus) + + async def control_later(): + await asyncio.sleep(0.05) + await inbox.handle( + job_id=int(executable_job_with_two_steps.job_id), + kind=AgentInboxKind.CONTROL_PAUSE, + payload={}, + ) + await db_session.commit() + await asyncio.sleep(0.2) + await inbox.handle( + job_id=int(executable_job_with_two_steps.job_id), + kind=AgentInboxKind.CONTROL_RESUME, + payload={}, + ) + await db_session.commit() + + sender = asyncio.create_task(control_later()) + result = await runner.run(db=db_session, job=executable_job_with_two_steps) + await sender + + assert result.applied_actions == 2 + + +@pytest.mark.asyncio +async def test_execute_runner_canceled_via_inbox( + db_session, executable_job_with_two_steps, # noqa: ANN001 +): + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + runner = ExecuteRunner(event_bus=bus) + + async def cancel_later(): + await asyncio.sleep(0.05) + await inbox.handle( + job_id=int(executable_job_with_two_steps.job_id), + kind=AgentInboxKind.CONTROL_CANCEL, + payload={}, + ) + await db_session.commit() + + sender = asyncio.create_task(cancel_later()) + with pytest.raises(AgentJobCanceled): + await runner.run(db=db_session, job=executable_job_with_two_steps) + await sender +``` + +> 注:`executable_job_with_two_steps` 是新 fixture。如项目已有可执行 job 的 fixture(参见 test_agent_plan_execute_runtime.py),沿用即可,否则在 `conftest.py` 加一个简单 fixture 构造含两步 read-only 计划的 job。 + +- [ ] **Step 5: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_plan_execute_runtime.py -v -k "pause or canceled_via_inbox"` +Expected: PASS + +- [ ] **Step 6: Commit** + +```bash +git add app/src/fileflash/agents/runtime/execute_runner.py app/tests/test_agent_plan_execute_runtime.py +git commit -m "feat(agent): wire ExecuteRunner to inbox controls and event bus" +``` + +--- + +## Task 14: `PlanRunner` 接入 ask(基础占位 + 接口暴露) + +**Files:** + +- Modify: `app/src/fileflash/agents/runtime/plan_runner.py` +- Modify: `app/tests/test_agent_plan_execute_runtime.py` + +> 本 Task 不引入 LLM 触发 ask 的判断逻辑(那需要改 prompt 与 tool-use 模板,留给后续)。本 Task 仅把 `AskProtocol` 注入到 `PlanRunner` 与 `ExecuteRunner`,并提供 `await self._ask(...)` 辅助方法,供后续 prompt 模板调用。 + +- [ ] **Step 1: 在 `PlanRunner.__init__` 增加 EventBus + ask 启停** + +```python +class PlanRunner: + def __init__( + self, + *, + settings: Settings | None = None, + planner_client: PlannerClient | None = None, + event_bus: AgentEventBus | None = None, + ) -> None: + self.settings = settings or get_settings() + self.planner_client = planner_client or AnthropicPlannerClient(settings=self.settings) + self.event_bus = event_bus +``` + +在 import 处新增: + +```python +from ..harness.event_bus import AgentEventBus +from ..harness.ask import AskProtocol +``` + +在 `run` 方法的开头: + +```python + ask: AskProtocol | None = None + if self.event_bus is not None: + ask = AskProtocol(db=db, event_bus=self.event_bus, job_id=int(job.job_id)) + await ask.start() + try: + # ... 现有 run 逻辑 ... + return result + finally: + if ask is not None: + await ask.aclose() +``` + +在类中新增辅助: + +```python + async def _ask( + self, + *, + ask: AskProtocol | None, + prompt: str, + schema: dict[str, Any], + ) -> Any | None: + if ask is None: + return None + return await ask.ask( + prompt=prompt, + schema=schema, + timeout_sec=float(self.settings.agent_inbox_ask_timeout_sec), + ) +``` + +> 后续 prompt 模板里若决定需要澄清,调 `await self._ask(ask=ask, prompt=..., schema=...)`。本 plan 仅做接线;触发逻辑留到后续。 + +- [ ] **Step 2: ExecuteRunner 同样增加 ask 接线(已 publish 事件,但当前不主动调 ask)** + +只是为了对称,方便后续 prompt 模板复用。改 `ExecuteRunner.run` 顶部: + +```python + ask: AskProtocol | None = None + if self.event_bus is not None: + ask = AskProtocol(db=db, event_bus=self.event_bus, job_id=int(job.job_id)) + await ask.start() + try: + # ... 现有 run 逻辑 ... + return result + finally: + if ask is not None: + await ask.aclose() +``` + +> 当前 ExecuteRunner 不调用 `ask.ask()`;这只是接线。 + +- [ ] **Step 3: 跳过端到端 ask 触发用例** + +> 触发 LLM 调 ask 取决于后续 prompt 模板的改动,超出本 plan 范围。`AskProtocol` 自身的行为(成功回答 + 超时 + status 写回)已在 Task 9 的两个用例完整覆盖。本 Task 不再写额外测试。 + +- [ ] **Step 4: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_plan_execute_runtime.py -v` +Expected: 既有用例 PASS(接线为可选注入,不破坏旧调用方) + +- [ ] **Step 5: Commit** + +```bash +git add app/src/fileflash/agents/runtime/plan_runner.py app/src/fileflash/agents/runtime/execute_runner.py app/tests/test_agent_plan_execute_runtime.py +git commit -m "feat(agent): wire AskProtocol into PlanRunner and ExecuteRunner" +``` + +--- + +## Task 15: worker 装配 EventBus 与 runner 注入 + +**Files:** + +- Modify: `app/src/fileflash/agents/worker.py` + +- [ ] **Step 1: 在 `AgentWorkerConsumer.__init__` 中创建 EventBus 单例并下发** + +```python +class AgentWorkerConsumer: + def __init__( + self, + *, + queue: RedisStreamJobQueue, + session_factory: async_sessionmaker[AsyncSession] = SessionLocal, + event_bus: AgentEventBus | None = None, + ) -> None: + self._settings = get_settings() + self._queue = queue + self._session_factory = session_factory + self._event_bus = event_bus or build_agent_event_bus(settings=self._settings) +``` + +在 imports 处新增: + +```python +from ..agents.harness.event_bus import AgentEventBus, build_agent_event_bus +``` + +- [ ] **Step 2: 在 `_run_job` / `_process_message` 中创建 runner 时传入 event_bus** + +找到现有 `PlanRunner()` / `ExecuteRunner(...)` 实例化点(在 `_run_job` 内),替换为: + +```python + if message.task_type == "agent.plan": + runner = PlanRunner(event_bus=self._event_bus) + result = await runner.run(db=db, job=fresh_job) + ... + elif message.task_type == "agent.execute": + runner = ExecuteRunner(event_bus=self._event_bus) + result = await runner.run(db=db, job=fresh_job) + ... +``` + +> 注:以现有代码的实例化位置为准;保持依赖注入路径一致。 + +- [ ] **Step 3: 在 `_mark_canceled` / `_mark_failed` / `_mark_succeeded` 中也 publish 终态事件** + +```python + async def _publish_terminal( + self, + *, + job_id: int, + event_type: str, + payload: dict[str, Any] | None = None, + ) -> None: + await self._event_bus.publish( + AgentEventEnvelope( + job_id=job_id, + event_type=event_type, + payload=payload or {}, + emitted_at=datetime.now(UTC), + ) + ) +``` + +并在三个 mark 函数末尾分别 `await self._publish_terminal(...)`,事件类型对应 `job.canceled` / `job.failed` / `job.succeeded`。 + +- [ ] **Step 4: 加最小验证测试** + +`app/tests/test_agent_worker.py` 已存在;在末尾追加: + +```python +@pytest.mark.asyncio +async def test_worker_publishes_terminal_event(...): # 沿用既有 test_agent_worker.py 的 fixture + ... + # 注入 InMemoryAgentEventBus,跑一个 succeed 流,断言收到 job.succeeded envelope +``` + +> 注:如 test_agent_worker.py 现有结构难以注入 event_bus,跳过此 step,依赖 Task 16 的端到端验证。 + +- [ ] **Step 5: 运行测试** + +Run: `cd app && uv run pytest tests/test_agent_worker.py -v` +Expected: PASS + +- [ ] **Step 6: Commit** + +```bash +git add app/src/fileflash/agents/worker.py app/tests/test_agent_worker.py +git commit -m "feat(agent): inject EventBus into worker and publish terminal events" +``` + +--- + +## Task 16: 端到端集成测试(POST 消息 → worker 收到 → publish → SSE 收到) + +**Files:** + +- Create: `app/tests/test_agent_a_end_to_end.py` + +- [ ] **Step 1: 写端到端测试** + +```python +# app/tests/test_agent_a_end_to_end.py +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime + +import pytest + +from fileflash.agents.harness.event_bus import ( + AgentEventEnvelope, + InMemoryAgentEventBus, +) +from fileflash.agents.harness.inbox import AgentInbox +from fileflash.agents.runtime.execute_runner import ( + AgentJobCanceled, + ExecuteRunner, +) +from fileflash.models.enums import AgentInboxKind + + +@pytest.mark.asyncio +async def test_user_pause_then_cancel_via_inbox( + db_session, executable_job_with_two_steps, # noqa: ANN001 +): + bus = InMemoryAgentEventBus() + inbox = AgentInbox(db=db_session, event_bus=bus) + runner = ExecuteRunner(event_bus=bus) + + seen_events: list[str] = [] + + async def consumer(): + async with bus.subscribe(job_id=int(executable_job_with_two_steps.job_id)) as stream: + for _ in range(8): + try: + env = await stream.next(timeout=2.0) + except TimeoutError: + break + seen_events.append(env.event_type) + if env.event_type == "agent.paused": + await inbox.handle( + job_id=int(executable_job_with_two_steps.job_id), + kind=AgentInboxKind.CONTROL_CANCEL, + payload={}, + ) + await db_session.commit() + + listener = asyncio.create_task(consumer()) + + async def pause_soon(): + await asyncio.sleep(0.05) + await inbox.handle( + job_id=int(executable_job_with_two_steps.job_id), + kind=AgentInboxKind.CONTROL_PAUSE, + payload={}, + ) + await db_session.commit() + + nudger = asyncio.create_task(pause_soon()) + + with pytest.raises(AgentJobCanceled): + await runner.run(db=db_session, job=executable_job_with_two_steps) + + await nudger + listener.cancel() + with pytest.raises(asyncio.CancelledError): + await listener + + assert "agent.paused" in seen_events + assert "tool.started" in seen_events or "tool.failed" in seen_events +``` + +- [ ] **Step 2: 运行** + +Run: `cd app && uv run pytest tests/test_agent_a_end_to_end.py -v` +Expected: PASS + +- [ ] **Step 3: 全部 agent 测试 smoke** + +Run: `cd app && uv run pytest tests/ -k "agent" -v` +Expected: 全部 PASS(含旧的 test_agent_routes.py / test_agent_repositories.py / test_agent_plan_execute_runtime.py) + +- [ ] **Step 4: Commit** + +```bash +git add app/tests/test_agent_a_end_to_end.py +git commit -m "test(agent): end-to-end pause + cancel via inbox" +``` + +--- + +## Acceptance Checklist(实施完成判定) + +- [ ] `app/src/fileflash/agents/harness/event_bus.py` 提供 `InMemoryAgentEventBus` 与 `RedisAgentEventBus`,并通过 `build_agent_event_bus` 工厂自动选择 +- [ ] `AgentInboxMessage` 表通过 V14 Flyway 迁移创建,ORM model + repository 已接入 +- [ ] `POST /agent/jobs/{job_id}/messages` 接受 7 种 kind(reply + 6 种 control) +- [ ] `POST /agent/cancel/{job_id}` 已删除;取消统一走 inbox `control.cancel` +- [ ] SSE 端点不再轮询 DB;初始 replay 后纯订阅 EventBus,包含 30s 心跳 +- [ ] `ExecuteRunner` 在 step 边界处理 pause/resume/skip/approve/deny/cancel +- [ ] `PlanRunner` 与 `ExecuteRunner` 都启动了 `AskProtocol`;后续 prompt 模板可调 `_ask` +- [ ] `worker.py` 注入 EventBus 并 publish `job.succeeded` / `job.failed` / `job.canceled` 终态 +- [ ] 端到端集成测试覆盖 pause → cancel via inbox 全链路 + +**注意(不在本 plan 范围):** + +- 前端接入新事件类型与上行通道 — 留给 `2026-05-26-agent-A-interaction-frontend.md` +- prompt 模板里何时调 `ask` — 留给后续;本 plan 只提供接口 +- worker 多副本下"等用户回答的 worker 被杀"恢复机制 — 仅靠 `agent_inbox_ask_timeout_sec` 兜底;进一步的 owner 恢复留给后续 plan diff --git a/docs/superpowers/plans/2026-05-26-agent-A-interaction-frontend.md b/docs/superpowers/plans/2026-05-26-agent-A-interaction-frontend.md new file mode 100644 index 0000000..49fa467 --- /dev/null +++ b/docs/superpowers/plans/2026-05-26-agent-A-interaction-frontend.md @@ -0,0 +1,1579 @@ +# Agent 子项目 A(交互/反馈层)— 前端实现计划 + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** 把前端从"单向 SSE + 整个 job 取消"升级为"双向交互 + step 级 pause/resume/skip/approve + agent 中途提问 + 实时进度/思考/工具增量输出",匹配 A-backend 已落地的事件类型与 `POST /agent/jobs/{id}/messages` 上行通道。 + +**Architecture:** 在 `types/agent.d.ts` 扩展事件字面量与上行消息类型;`api/agent.ts` 用 `sendAgentMessage` + 6 个 control helper 取代 `cancelAgentJob`;`useAgentSession.ts` 引入 `waiting_for_user`/`paused` 状态、ask 缓存、`pauseTurn/resumeTurn/replyToAsk/...` 方法,cancel 改走 `control.cancel`;新增 `AskPrompt.vue` + `ControlBar.vue` 两个原子组件,TurnEntry 内嵌;TaskInputDock 在 waiting_for_user/paused 时切换主输入框 disable。 + +**Tech Stack:** Vue 3 + TypeScript + Vitest + bun(**不用 npm**)+ 既有 i18n 体系 + 既有 atomic 设计语言(Industrial Dashboard,参见 `frontend_aesthetic.md`)。 + +**Spec:** `docs/superpowers/specs/2026-05-26-agent-improvements-design.md` 子项目 A 部分(前端章节 A.7、A.8、A.10) + +**前置条件:** A-backend plan 已落地(commit 包含 `AgentInboxMessage` 模型、`POST /agent/jobs/{id}/messages`、SSE event_bus 推送、新 14 种 `AgentJobEventType`) + +--- + +## File Structure + +**新建** + +- `web/src/components/organisms/agent/AskPrompt.vue` — 渲染单条 agent.ask 的输入气泡(选择型 / 自由文本,含 timeout 倒计时) +- `web/src/components/organisms/agent/ControlBar.vue` — 渲染单 turn 的 pause/resume/skip/cancel 按钮组 +- `web/src/composables/useAskTimeout.ts` — ask 倒计时小工具(含 i18n 友好的 mm:ss 格式化) + +**修改** + +- `web/src/types/agent.d.ts` — 扩展 `AgentJobEventType`;新增 `MsgStatus`-相关、`AgentInboxMessageKind` / `AgentInboxMessageRequest` / `AgentInboxMessageResponse`、`AgentAskPayload` / `AgentProgressPayload` / `AgentThinkingPayload` / `AgentToolPartialPayload` +- `web/src/api/agent.ts` — 删除 `cancelAgentJob`;新增 `sendAgentMessage` 与 6 个 helper:`sendAgentReply` / `pauseAgentJob` / `resumeAgentJob` / `approveAgentStep` / `denyAgentStep` / `skipAgentStep` / `cancelAgentTurn` +- `web/src/composables/useAgentSession.ts` — `MsgStatus` 加 `waiting_for_user` / `paused`;`ChatMessage` 加 `pendingAsk`、`pauseRequestedAt`、`progress`、`thinking`、`partials`;`applyAgentEvent` 覆盖新 6 种事件;`cancel(msg)` 改走 `control.cancel`;新增方法 +- `web/src/composables/useAgentSession.spec.ts` — 覆盖新事件与新方法 +- `web/src/components/organisms/agent/TurnEntry.vue` — 内嵌 `AskPrompt` / `ControlBar`;新增 progress 条与 thinking 折叠区;扩展 `activityEvents` 过滤规则 +- `web/src/components/organisms/agent/TaskInputDock.vue` — `disabled` prop 范围扩大(waiting_for_user / paused 时锁主输入) +- `web/src/i18n/messages.ts` — 13 条新 key(ask/progress/thinking/控制按钮/状态文案)+ 中英文翻译 + +**测试** + +- `web/src/composables/useAgentSession.spec.ts` — 既有文件追加 +- (可选) `web/src/components/organisms/agent/AskPrompt.spec.ts` — 新(@vue/test-utils 风格如项目已用,否则跳过) + +--- + +## Sequencing + +``` +Task 1 (types) ──► Task 2 (api helpers) ──► Task 3 (i18n keys) + │ + ┌─────────────────────┴──────────────────────┐ + ▼ ▼ + Task 4 (useAgentSession state + cancel rewire) Task 5 (useAskTimeout) + │ │ + ▼ │ + Task 6 (useAgentSession ask handlers + control methods) │ + │ │ + ▼ ▼ + ▼─────────────► Task 7 (AskPrompt.vue) ◄─────┘ + │ + ▼ + Task 8 (ControlBar.vue) + │ + ▼ + Task 9 (TurnEntry.vue 集成) + │ + ▼ + Task 10 (TaskInputDock.vue 锁定) + │ + ▼ + Task 11 (端到端 spec:ask → reply → resume → cancel) + │ + ▼ + Task 12 (手测脚本 + dev server 真跑一次) +``` + +--- + +## Task 1: 扩展 types/agent.d.ts + +**Files:** + +- Modify: `web/src/types/agent.d.ts` + +- [ ] **Step 1: 扩展 `AgentJobEventType` 字面量与新增上行消息 / 事件 payload 类型** + +把 `AgentJobEventType` 替换为: + +```ts +export type AgentJobEventType = + | 'job.queued' + | 'job.running' + | 'plan.ready' + | 'tool.started' + | 'tool.succeeded' + | 'tool.failed' + | 'tool.partial' + | 'agent.thinking' + | 'agent.progress' + | 'agent.ask' + | 'agent.paused' + | 'agent.resumed' + | 'job.succeeded' + | 'job.failed' + | 'job.canceled'; +``` + +在文件末尾追加: + +```ts +// ----------------- Inbox (upstream channel) ----------------- + +export type AgentInboxMessageKind = + | 'reply' + | 'control.pause' + | 'control.resume' + | 'control.approve' + | 'control.deny' + | 'control.skip' + | 'control.cancel'; + +export interface AgentInboxMessageRequest { + kind: AgentInboxMessageKind; + replyTo?: string; // ask 的 inboxMessageId(string-encoded) + value?: unknown; // reply 时为用户回答 + metadata?: Record; +} + +export interface AgentInboxMessageResponse { + inboxMessageId: string; + kind: AgentInboxMessageKind; + acceptedAt: string; +} + +// ----------------- New event payloads ----------------- + +export interface AgentAskPayload { + messageId: string; + prompt: string; + schema: Record; // 自由形式;例如 {"choice":["A","B"]} + timeoutSec: number; +} + +export interface AgentProgressPayload { + step: number; + total: number; + message?: string; + percent?: number; +} + +export interface AgentThinkingPayload { + text: string; +} + +export interface AgentToolPartialPayload { + step: number; + tool: string; + chunk: unknown; +} +``` + +- [ ] **Step 2: 删除 `CancelAgentResponse` 接口** + +后端已删除 `POST /agent/cancel`;前端 type 也清理。同步 `web/src/api/agent.ts` 的 import(Task 2 处理)。 + +- [ ] **Step 3: typecheck** + +Run: `cd web && bun run typecheck` +Expected: 仅出现"`CancelAgentResponse` 仍被 import in api/agent.ts"的错误——Task 2 修复。 + +- [ ] **Step 4: Commit** + +```bash +git add web/src/types/agent.d.ts +git commit -m "feat(agent): extend frontend types for inbox + new event payloads" +``` + +--- + +## Task 2: api/agent.ts 引入 sendAgentMessage + 6 个 helper + +**Files:** + +- Modify: `web/src/api/agent.ts` + +- [ ] **Step 1: 删除 `cancelAgentJob` 与对应 import** + +```ts +// 删除: +import type { ... CancelAgentResponse ... } from '../types/agent'; +export const cancelAgentJob = (jobId: string) => { ... }; +``` + +- [ ] **Step 2: 新增 `sendAgentMessage` 与 6 个 helper** + +在 `streamAgentJobEvents` 之上插入: + +```ts +import type { + AgentBackgroundJob, + AgentInboxMessageRequest, + AgentInboxMessageResponse, + AgentJobEvent, + ExecuteAgentRequest, + ExecuteAgentResponse, + PlanAgentRequest, + PlanAgentResponse, +} from '../types/agent'; + +// ----------------- inbox upstream ----------------- + +export const sendAgentMessage = ( + jobId: string, + body: AgentInboxMessageRequest, +) => { + return http.post( + `/agent/jobs/${encodeURIComponent(jobId)}/messages`, + body, + ); +}; + +export const sendAgentReply = ( + jobId: string, + replyTo: string, + value: unknown, +) => sendAgentMessage(jobId, { kind: 'reply', replyTo, value }); + +export const pauseAgentJob = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.pause' }); + +export const resumeAgentJob = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.resume' }); + +export const approveAgentStep = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.approve' }); + +export const denyAgentStep = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.deny' }); + +export const skipAgentStep = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.skip' }); + +export const cancelAgentTurn = (jobId: string) => + sendAgentMessage(jobId, { kind: 'control.cancel' }); +``` + +> 注:`cancelAgentTurn` 命名故意区别于历史的 `cancelAgentJob`,提示这是"通过 inbox 取消当前 turn"。所有引用 `cancelAgentJob` 的地方在 Task 4 改成 `cancelAgentTurn`。 + +- [ ] **Step 3: 全仓搜索旧 import** + +Run: `grep -rn "cancelAgentJob\|CancelAgentResponse" web/src/` +Expected: 仅 `web/src/composables/useAgentSession.ts` 几处需要 Task 4 处理。 + +- [ ] **Step 4: typecheck** + +Run: `cd web && bun run typecheck` +Expected: 仅剩 useAgentSession.ts 的 import 错误(Task 4 修)。 + +- [ ] **Step 5: Commit** + +```bash +git add web/src/api/agent.ts +git commit -m "feat(agent): add sendAgentMessage + 6 control helpers, drop cancelAgentJob" +``` + +--- + +## Task 3: 新增 i18n key(中英文) + +**Files:** + +- Modify: `web/src/i18n/messages.ts` + +- [ ] **Step 1: 在 `LocaleKey` union(约 line 480-580)的 agent.v2 区块插入新 key** + +按字母序插在 `agent.v2.turn.cancel` 附近: + +```ts + | 'agent.v2.turn.status.waiting_for_user' + | 'agent.v2.turn.status.paused' + | 'agent.v2.turn.controls.pause' + | 'agent.v2.turn.controls.resume' + | 'agent.v2.turn.controls.skip' + | 'agent.v2.turn.controls.approve' + | 'agent.v2.turn.controls.deny' + | 'agent.v2.turn.ask.placeholder' + | 'agent.v2.turn.ask.send' + | 'agent.v2.turn.ask.timeout' + | 'agent.v2.turn.progress.label' + | 'agent.v2.turn.thinking.label' + | 'agent.v2.turn.thinking.toggle' +``` + +- [ ] **Step 2: 在 zh-CN 翻译 map 中添加(约 line 1066-1160)** + +```ts + 'agent.v2.turn.status.waiting_for_user': '等待你回复', + 'agent.v2.turn.status.paused': '已暂停', + 'agent.v2.turn.controls.pause': '暂停', + 'agent.v2.turn.controls.resume': '继续', + 'agent.v2.turn.controls.skip': '跳过此步', + 'agent.v2.turn.controls.approve': '批准', + 'agent.v2.turn.controls.deny': '拒绝', + 'agent.v2.turn.ask.placeholder': '输入回答…', + 'agent.v2.turn.ask.send': '发送', + 'agent.v2.turn.ask.timeout': '剩余 {value}', + 'agent.v2.turn.progress.label': '进度', + 'agent.v2.turn.thinking.label': '思考过程', + 'agent.v2.turn.thinking.toggle': '展开 / 收起', +``` + +- [ ] **Step 3: 在 en 翻译 map 中添加(约 line 1641-1730)** + +```ts + 'agent.v2.turn.status.waiting_for_user': 'WAITING FOR YOU', + 'agent.v2.turn.status.paused': 'PAUSED', + 'agent.v2.turn.controls.pause': 'Pause', + 'agent.v2.turn.controls.resume': 'Resume', + 'agent.v2.turn.controls.skip': 'Skip step', + 'agent.v2.turn.controls.approve': 'Approve', + 'agent.v2.turn.controls.deny': 'Deny', + 'agent.v2.turn.ask.placeholder': 'Type your answer…', + 'agent.v2.turn.ask.send': 'Send', + 'agent.v2.turn.ask.timeout': '{value} left', + 'agent.v2.turn.progress.label': 'PROGRESS', + 'agent.v2.turn.thinking.label': 'THINKING', + 'agent.v2.turn.thinking.toggle': 'Expand / Collapse', +``` + +- [ ] **Step 4: typecheck** + +Run: `cd web && bun run typecheck` +Expected: PASS(LocaleKey union 与 map 一致) + +- [ ] **Step 5: Commit** + +```bash +git add web/src/i18n/messages.ts +git commit -m "feat(agent): add i18n keys for ask/pause/progress/controls" +``` + +--- + +## Task 4: useAgentSession.ts — 扩展状态 + 改 cancel 走 inbox + +**Files:** + +- Modify: `web/src/composables/useAgentSession.ts` + +- [ ] **Step 1: 扩展 `MsgStatus` 与 `ChatMessage`** + +把现有 `MsgStatus` 类型改为: + +```ts +export type MsgStatus = + | 'pending' + | 'running' + | 'succeeded' + | 'failed' + | 'canceled' + | 'waiting_for_user' + | 'paused'; +``` + +把 `ChatMessage` 接口扩展为: + +```ts +export interface PendingAsk { + messageId: string; + prompt: string; + schema: Record; + timeoutSec: number; + askedAt: string; +} + +export interface ToolPartial { + step: number; + tool: string; + chunks: unknown[]; +} + +export interface ChatMessage { + id: string; + role: 'user' | 'agent'; + content: string; + status: MsgStatus; + planJobId?: string; + planHash?: string; + planResult?: AgentPlanResult; + executeJobId?: string; + executeResult?: AgentExecutionResult; + events: AgentJobEvent[]; + errorMessage?: string; + timestamp: string; + // —— 新增(A 前端)—— + pendingAsk?: PendingAsk; + pauseRequestedAt?: string; + progress?: { step: number; total: number; message?: string; percent?: number }; + thinking?: string; // 累积的 thinking 文本 + partials?: Record; +} +``` + +- [ ] **Step 2: 调整 `applyAgentEvent` 覆盖新事件** + +替换现有 `applyAgentEvent` 为: + +```ts +const applyAgentEvent = (msg: ChatMessage, event: AgentJobEvent, kind: 'plan' | 'execute') => { + appendAgentEvent(msg, event); + + // 终态 / 既有事件 + if (event.type === 'job.queued') { + msg.status = 'pending'; + } else if (event.type === 'job.running' || event.type === 'tool.started') { + if (msg.status !== 'waiting_for_user' && msg.status !== 'paused') { + msg.status = 'running'; + } + } else if (event.type === 'job.failed' || event.type === 'tool.failed') { + msg.status = 'failed'; + const errorMessage = event.data?.errorMessage; + msg.errorMessage = typeof errorMessage === 'string' ? errorMessage : event.message; + } else if (event.type === 'job.canceled') { + msg.status = 'canceled'; + } else if (event.type === 'job.succeeded') { + msg.status = 'succeeded'; + msg.pendingAsk = undefined; + msg.pauseRequestedAt = undefined; + } + + // 新事件 + if (event.type === 'agent.ask') { + const payload = event.data as AgentAskPayload; + msg.pendingAsk = { + messageId: payload.messageId, + prompt: payload.prompt, + schema: payload.schema, + timeoutSec: payload.timeoutSec, + askedAt: event.timestamp, + }; + msg.status = 'waiting_for_user'; + } else if (event.type === 'agent.paused') { + msg.status = 'paused'; + msg.pauseRequestedAt = event.timestamp; + } else if (event.type === 'agent.resumed') { + msg.status = 'running'; + msg.pauseRequestedAt = undefined; + } else if (event.type === 'agent.progress') { + const payload = event.data as AgentProgressPayload; + msg.progress = { + step: payload.step, + total: payload.total, + message: payload.message, + percent: payload.percent, + }; + } else if (event.type === 'agent.thinking') { + const payload = event.data as AgentThinkingPayload; + msg.thinking = (msg.thinking || '') + (payload.text || ''); + } else if (event.type === 'tool.partial') { + const payload = event.data as AgentToolPartialPayload; + msg.partials = msg.partials || {}; + const slot = msg.partials[payload.step] || { step: payload.step, tool: payload.tool, chunks: [] }; + slot.chunks = [...slot.chunks, payload.chunk]; + msg.partials[payload.step] = slot; + } + + const result = event.data?.result; + if (event.type === 'plan.ready' && result) { + msg.planResult = result as AgentPlanResult; + msg.planHash = msg.planResult.planHash; + } + if (event.type === 'job.succeeded' && result) { + if (kind === 'plan') { + msg.planResult = result as AgentPlanResult; + msg.planHash = msg.planResult.planHash; + } else { + msg.executeResult = result as AgentExecutionResult; + } + } +}; +``` + +在 imports 顶部追加: + +```ts +import type { + AgentAskPayload, + AgentExecutionPolicy, + AgentExecutionResult, + AgentJobEvent, + AgentPlanResult, + AgentProgressPayload, + AgentReasoningEffort, + AgentThinkingPayload, + AgentToolPartialPayload, + PlanAgentRequest, +} from '../types/agent'; +``` + +- [ ] **Step 3: 把 `cancel(msg)` 改走 `control.cancel`** + +替换 `cancel` 函数 + 删除顶部 `cancelAgentJob` import: + +```ts +import { + cancelAgentTurn, + executeAgentPlan, + getAgentJob, + planAgentTask, + streamAgentJobEvents, +} from '../api/agent'; + +// ... + +async function cancel(msg: ChatMessage): Promise { + markTurnCanceled(msg); + msg.status = 'canceled'; + msg.pendingAsk = undefined; + msg.pauseRequestedAt = undefined; + stopPolling(`${msg.id}:plan`); + stopPolling(`${msg.id}:execute`); + stopStream(`${msg.id}:plan`); + stopStream(`${msg.id}:execute`); + const jobId = msg.executeJobId || msg.planJobId; + if (!jobId) return; + try { + await cancelAgentTurn(jobId); + } catch (error) { + msg.errorMessage = extractErrorMessage(error, 'Cancel failed.'); + } +} +``` + +同时把 `sendMessage` / `runExecute` 内的旧 `cancelAgentJob(res.jobId)` 替换为 `cancelAgentTurn(res.jobId)`: + +```ts +// sendMessage 内: +if (isTurnCanceled(reactiveAgent) || reactiveAgent.status === 'canceled') { + try { + await cancelAgentTurn(res.jobId); + } catch { /* ignore */ } + return; +} + +// runExecute 内: +if (!ensureTurnNotCanceled(msg)) { + try { + await cancelAgentTurn(res.jobId); + } catch { /* ignore */ } + return; +} +``` + +- [ ] **Step 4: 全仓 grep** + +Run: `grep -rn "cancelAgentJob" web/src/` +Expected: 无匹配。 + +- [ ] **Step 5: typecheck + 跑既有测试** + +Run: `cd web && bun run typecheck && bun run test useAgentSession` +Expected: typecheck PASS;既有 spec 大多 PASS。如果某些用例断言"取消时调用 cancelAgentJob",改断言为 `cancelAgentTurn`(即 `sendAgentMessage(..., {kind:'control.cancel'})`)。 + +- [ ] **Step 6: Commit** + +```bash +git add web/src/composables/useAgentSession.ts +git commit -m "feat(agent): extend session state for ask/pause/progress, rewire cancel via inbox" +``` + +--- + +## Task 5: 新建 `useAskTimeout.ts` + +**Files:** + +- Create: `web/src/composables/useAskTimeout.ts` +- Create: `web/src/composables/useAskTimeout.spec.ts` + +- [ ] **Step 1: 写测试** + +```ts +// web/src/composables/useAskTimeout.spec.ts +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { ref } from 'vue'; +import { useAskTimeout } from './useAskTimeout'; + +describe('useAskTimeout', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + afterEach(() => { + vi.useRealTimers(); + }); + + it('counts down from askedAt + timeoutSec', () => { + const askedAt = ref('2026-05-26T12:00:00.000Z'); + const timeoutSec = ref(120); + vi.setSystemTime(new Date('2026-05-26T12:00:30.000Z')); + const { remainingSec, formatted, expired } = useAskTimeout(askedAt, timeoutSec); + expect(remainingSec.value).toBe(90); + expect(formatted.value).toBe('01:30'); + expect(expired.value).toBe(false); + + vi.setSystemTime(new Date('2026-05-26T12:02:01.000Z')); + vi.advanceTimersByTime(1000); + expect(expired.value).toBe(true); + expect(remainingSec.value).toBe(0); + expect(formatted.value).toBe('00:00'); + }); + + it('returns expired immediately when askedAt is missing', () => { + const askedAt = ref(undefined); + const timeoutSec = ref(60); + const { expired, formatted } = useAskTimeout(askedAt, timeoutSec); + expect(expired.value).toBe(true); + expect(formatted.value).toBe('00:00'); + }); +}); +``` + +- [ ] **Step 2: 运行测试,确认 fail** + +Run: `cd web && bun run test useAskTimeout` +Expected: FAIL — module missing. + +- [ ] **Step 3: 实现** + +```ts +// web/src/composables/useAskTimeout.ts +import { computed, onScopeDispose, ref, watchEffect, type Ref } from 'vue'; + +export function useAskTimeout( + askedAt: Ref, + timeoutSec: Ref, +) { + const now = ref(Date.now()); + let timer: ReturnType | null = null; + + watchEffect(() => { + if (timer) clearInterval(timer); + if (!askedAt.value || timeoutSec.value <= 0) return; + timer = setInterval(() => { + now.value = Date.now(); + }, 1000); + }); + + onScopeDispose(() => { + if (timer) clearInterval(timer); + }); + + const deadline = computed(() => { + if (!askedAt.value) return null; + const base = Date.parse(askedAt.value); + if (Number.isNaN(base)) return null; + return base + timeoutSec.value * 1000; + }); + + const remainingSec = computed(() => { + if (deadline.value === null) return 0; + return Math.max(0, Math.ceil((deadline.value - now.value) / 1000)); + }); + + const expired = computed(() => deadline.value === null || remainingSec.value <= 0); + + const formatted = computed(() => { + const total = remainingSec.value; + const mm = String(Math.floor(total / 60)).padStart(2, '0'); + const ss = String(total % 60).padStart(2, '0'); + return `${mm}:${ss}`; + }); + + return { remainingSec, formatted, expired }; +} +``` + +- [ ] **Step 4: 运行测试** + +Run: `cd web && bun run test useAskTimeout` +Expected: PASS(2 个用例) + +- [ ] **Step 5: Commit** + +```bash +git add web/src/composables/useAskTimeout.ts web/src/composables/useAskTimeout.spec.ts +git commit -m "feat(agent): add useAskTimeout countdown composable" +``` + +--- + +## Task 6: useAgentSession.ts — ask reply 与控制方法 + +**Files:** + +- Modify: `web/src/composables/useAgentSession.ts` +- Modify: `web/src/composables/useAgentSession.spec.ts` + +- [ ] **Step 1: 更新 spec 顶部的 `vi.mock` 工厂,注册新 helper** + +`useAgentSession.spec.ts` 顶部已有: + +```ts +vi.mock('../api/agent', () => ({ + planAgentTask: vi.fn(), + executeAgentPlan: vi.fn(), + cancelAgentJob: vi.fn(), + getAgentJob: vi.fn(), + streamAgentJobEvents: vi.fn(), +})); +``` + +替换为: + +```ts +vi.mock('../api/agent', () => ({ + planAgentTask: vi.fn(), + executeAgentPlan: vi.fn(), + cancelAgentTurn: vi.fn(), // 替代旧 cancelAgentJob + getAgentJob: vi.fn(), + streamAgentJobEvents: vi.fn(), + sendAgentMessage: vi.fn(), + sendAgentReply: vi.fn(), + pauseAgentJob: vi.fn(), + resumeAgentJob: vi.fn(), + approveAgentStep: vi.fn(), + denyAgentStep: vi.fn(), + skipAgentStep: vi.fn(), +})); +``` + +把既有用到 `agentApi.cancelAgentJob` 的断言全部改为 `agentApi.cancelAgentTurn`(Task 4 已经改了源;spec 这里同步)。 + +Run: `grep -n "cancelAgentJob" web/src/composables/useAgentSession.spec.ts` +Expected: 找到的每一处都要改成 `cancelAgentTurn`。 + +- [ ] **Step 2: 在 spec 末尾追加 inbox-controls 用例** + +```ts +import * as agentApi from '../api/agent'; + +describe('useAgentSession — inbox controls', () => { + beforeEach(() => { + vi.mocked(agentApi.sendAgentReply).mockResolvedValue({ + inboxMessageId: '42', kind: 'reply', acceptedAt: '2026-05-26T00:00:00Z', + }); + vi.mocked(agentApi.pauseAgentJob).mockResolvedValue({ + inboxMessageId: '50', kind: 'control.pause', acceptedAt: '2026-05-26T00:00:00Z', + }); + }); + + it('replyToAsk sends reply to backend and clears pendingAsk', async () => { + const { default: useAgentSession } = await loadComposable(); + const { createSession, replyToAsk } = useAgentSession(); + const session = createSession(); + const msg: ChatMessage = { + id: 'msg-1', + role: 'agent', + content: '', + status: 'waiting_for_user', + events: [], + timestamp: new Date().toISOString(), + executeJobId: '77', + pendingAsk: { + messageId: '101', + prompt: 'choose', + schema: { choice: ['A', 'B'] }, + timeoutSec: 60, + askedAt: new Date().toISOString(), + }, + }; + session.messages.push(msg); + + await replyToAsk(msg, 'A'); + + expect(agentApi.sendAgentReply).toHaveBeenCalledWith('77', '101', 'A'); + expect(msg.pendingAsk).toBeUndefined(); + expect(msg.status).toBe('running'); + }); + + it('pauseTurn sends control.pause and records pauseRequestedAt', async () => { + const { default: useAgentSession } = await loadComposable(); + const { createSession, pauseTurn } = useAgentSession(); + const session = createSession(); + const msg: ChatMessage = { + id: 'msg-2', role: 'agent', content: '', status: 'running', + events: [], timestamp: new Date().toISOString(), executeJobId: '88', + }; + session.messages.push(msg); + + await pauseTurn(msg); + expect(agentApi.pauseAgentJob).toHaveBeenCalledWith('88'); + // 本地不立即翻 paused,等 agent.paused 事件 + expect(msg.pauseRequestedAt).toBeTruthy(); + }); +}); +``` + +> 注:上述用例用项目既有 `vi.mock + vi.mocked + loadComposable()` 风格(参见同文件其它用例),不引入 `vi.spyOn(await import(...))` 写法。`ChatMessage` 类型可能需要从 `'../composables/useAgentSession'` 导入(既有用例如何引用就跟随)。 + +- [ ] **Step 2: 实现 5 个新方法** + +在 `useAgentSession.ts` 内(`cancel` 函数附近)新增: + +```ts +import { + approveAgentStep, + cancelAgentTurn, + denyAgentStep, + executeAgentPlan, + getAgentJob, + pauseAgentJob, + planAgentTask, + resumeAgentJob, + sendAgentReply, + skipAgentStep, + streamAgentJobEvents, +} from '../api/agent'; + +// ... + +const activeJobId = (msg: ChatMessage): string | undefined => + msg.executeJobId || msg.planJobId; + +async function replyToAsk(msg: ChatMessage, value: unknown): Promise { + const jobId = activeJobId(msg); + if (!jobId || !msg.pendingAsk) return; + const replyTo = msg.pendingAsk.messageId; + msg.pendingAsk = undefined; + msg.status = 'running'; + try { + await sendAgentReply(jobId, replyTo, value); + } catch (error) { + msg.status = 'waiting_for_user'; + msg.pendingAsk = { + messageId: replyTo, + prompt: msg.pendingAsk?.prompt || '', + schema: msg.pendingAsk?.schema || {}, + timeoutSec: msg.pendingAsk?.timeoutSec || 0, + askedAt: msg.pendingAsk?.askedAt || new Date().toISOString(), + }; + msg.errorMessage = extractErrorMessage(error, 'Reply failed.'); + } +} + +async function pauseTurn(msg: ChatMessage): Promise { + const jobId = activeJobId(msg); + if (!jobId) return; + msg.pauseRequestedAt = new Date().toISOString(); + try { + await pauseAgentJob(jobId); + } catch (error) { + msg.pauseRequestedAt = undefined; + msg.errorMessage = extractErrorMessage(error, 'Pause failed.'); + } +} + +async function resumeTurn(msg: ChatMessage): Promise { + const jobId = activeJobId(msg); + if (!jobId) return; + try { + await resumeAgentJob(jobId); + } catch (error) { + msg.errorMessage = extractErrorMessage(error, 'Resume failed.'); + } +} + +async function approveStep(msg: ChatMessage): Promise { + const jobId = activeJobId(msg); + if (!jobId) return; + try { + await approveAgentStep(jobId); + } catch (error) { + msg.errorMessage = extractErrorMessage(error, 'Approve failed.'); + } +} + +async function denyStep(msg: ChatMessage): Promise { + const jobId = activeJobId(msg); + if (!jobId) return; + try { + await denyAgentStep(jobId); + } catch (error) { + msg.errorMessage = extractErrorMessage(error, 'Deny failed.'); + } +} + +async function skipStep(msg: ChatMessage): Promise { + const jobId = activeJobId(msg); + if (!jobId) return; + try { + await skipAgentStep(jobId); + } catch (error) { + msg.errorMessage = extractErrorMessage(error, 'Skip failed.'); + } +} +``` + +把 6 个方法加到 return 对象末尾: + +```ts + return { + sessions: s.sessions, + activeSessionId: s.activeSessionId, + activeSession, + activeTurns, + policy: s.policy, + reasoningEffort: s.reasoningEffort, + taskInput: s.taskInput, + isSending: s.isSending, + createSession, + switchSession, + deleteSession, + resetActiveSession, + sendMessage, + runExecute, + cancel, + // —— 新增 —— + replyToAsk, + pauseTurn, + resumeTurn, + approveStep, + denyStep, + skipStep, + }; +``` + +- [ ] **Step 3: 运行测试** + +Run: `cd web && bun run test useAgentSession` +Expected: PASS(含新追加的两个用例) + +- [ ] **Step 4: Commit** + +```bash +git add web/src/composables/useAgentSession.ts web/src/composables/useAgentSession.spec.ts +git commit -m "feat(agent): add replyToAsk + pause/resume/skip/approve/deny composables" +``` + +--- + +## Task 7: `AskPrompt.vue` + +**Files:** + +- Create: `web/src/components/organisms/agent/AskPrompt.vue` + +- [ ] **Step 1: 实现组件** + +```vue + + + +