async def _count_files(self, args: dict[str, Any]) -> dict[str, Any]:
    """Count the caller's active, latest-version files under a folder.

    Args:
        args: Raw tool input. Recognised keys are ``folderId`` (or
            ``parentFolderId``; defaults to ``"root"``), ``recursive``
            (defaults to true), ``category`` and ``search``.

    Returns:
        A JSON-ready dict with ``totalItems``, the normalised ``category``,
        ``recursive``, the resolved ``folderId``, a ``byMimeType`` histogram
        sorted by MIME type, and up to five ``sampleItems`` in file-name order.

    Raises:
        ApiError: Propagated from ``_resolve_folder_id`` when the folder id is
            malformed or the folder cannot be found for this user.
    """
    folder_id = str(_first_value(args, "folderId", "parentFolderId") or "root")
    recursive = _bool_arg(args.get("recursive"), default=True)
    category = _normalize_category(args.get("category"))
    search = str(args.get("search") or "").strip().lower()
    root_folder_id = await _resolve_folder_id(
        self.db,
        user_id=self.user_id,
        folder_id=folder_id,
    )
    if recursive:
        folder_ids = await _active_descendant_folder_ids(
            self.db,
            user_id=self.user_id,
            root_folder_id=root_folder_id,
        )
    else:
        folder_ids = [root_folder_id]

    statement = select(
        File.file_id,
        File.file_name,
        File.file_size,
        File.mime_type,
        File.file_ext,
        File.folder_id,
    ).where(
        and_(
            File.owner_id == self.user_id,
            File.folder_id.in_(folder_ids),
            File.status == FileStatus.ACTIVE,
            File.is_latest.is_(True),
        )
    )
    if search:
        # The search text is a literal substring, not a LIKE pattern: escape the
        # wildcard characters so "100%" or "a_b" do not over-match.
        escaped = (
            search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        statement = statement.where(
            File.file_name.ilike(f"%{escaped}%", escape="\\")
        )
    statement = statement.order_by(File.file_name.asc())

    rows = (await self.db.execute(statement)).all()
    by_mime_type: dict[str, int] = {}
    sample_items: list[dict[str, Any]] = []
    total_items = 0
    for file_id, file_name, file_size, mime_type, file_ext, row_folder_id in rows:
        resolved_mime = resolve_file_mime_type(
            mime_type=mime_type,
            file_ext=file_ext,
            file_name=file_name,
        )
        # Category filtering happens in Python because it depends on the
        # resolved MIME type, which is not a stored column.
        if category is not None and _category_for_file(
            mime_type=resolved_mime,
            file_ext=file_ext,
            file_name=file_name,
        ) != category:
            continue

        total_items += 1
        by_mime_type[resolved_mime] = by_mime_type.get(resolved_mime, 0) + 1
        if len(sample_items) < 5:
            sample_items.append(
                {
                    "id": str(file_id),
                    "name": str(file_name),
                    "size": int(file_size or 0),
                    "mimeType": resolved_mime,
                    "folderId": str(row_folder_id),
                }
            )

    return {
        "totalItems": total_items,
        "category": category,
        "recursive": recursive,
        "folderId": str(root_folder_id),
        "byMimeType": dict(sorted(by_mime_type.items())),
        "sampleItems": sample_items,
    }
async def _active_descendant_folder_ids(
    db: AsyncSession,
    *,
    user_id: int,
    root_folder_id: int,
) -> list[int]:
    """Return the id of ``root_folder_id`` plus every active folder below it.

    The tree is walked with a recursive CTE. Only folders owned by ``user_id``
    with ACTIVE status are followed, so a non-active folder and everything
    beneath it are left out of the result.
    """
    # Anchor member: the starting folder itself.
    anchor = select(Folder.folder_id).where(
        and_(
            Folder.folder_id == root_folder_id,
            Folder.owner_id == user_id,
            Folder.status == FolderStatus.ACTIVE,
        )
    )
    tree = anchor.cte(name="agent_count_descendants", recursive=True)

    # Recursive member: active children of any folder already in the CTE.
    children = select(Folder.folder_id).where(
        and_(
            Folder.parent_folder_id == tree.c.folder_id,
            Folder.owner_id == user_id,
            Folder.status == FolderStatus.ACTIVE,
        )
    )
    tree = tree.union_all(children)

    found = await db.scalars(select(tree.c.folder_id))
    return [int(value) for value in found]
return None + + +def _category_for_file(*, mime_type: str, file_ext: str | None, file_name: str | None) -> str: + mime = (mime_type or "").lower() + ext = _normalized_extension(file_ext) or _filename_extension(file_name) + if mime.startswith("video/") or ext in {"mp4", "mov", "avi", "mkv", "webm", "m4v"}: + return "video" + if mime.startswith("audio/") or ext in {"mp3", "wav", "flac", "m4a", "aac", "ogg"}: + return "audio" + if mime.startswith("image/") or ext in {"jpg", "jpeg", "png", "gif", "webp", "svg", "bmp"}: + return "image" + if mime in {"application/pdf"} or ext in {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "txt", "md"}: + return "document" + if ext in {"zip", "rar", "7z", "tar", "gz", "bz2", "xz"}: + return "archive" + return "other" + + +def _normalized_extension(value: str | None) -> str: + return str(value or "").strip().lower().lstrip(".") + + +def _filename_extension(value: str | None) -> str: + name = str(value or "").strip().lower() + if "." not in name: + return "" + return name.rsplit(".", 1)[-1] diff --git a/app/src/fileflash/agents/runtime/execute_runner.py b/app/src/fileflash/agents/runtime/execute_runner.py index 1e045d6..499b489 100644 --- a/app/src/fileflash/agents/runtime/execute_runner.py +++ b/app/src/fileflash/agents/runtime/execute_runner.py @@ -146,10 +146,12 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe warnings.append(f"{skipped} action(s) were skipped.") await work_sessions.close_session(job_id=int(job.job_id), status="closed") await db.commit() + answer = _build_execution_answer(actions=actions, step_outputs=step_outputs) return AgentExecutionResult( plan_job_id=str(plan_job_id), execute_job_id=str(job.job_id), summary=f"Execution completed with {applied} applied action(s).", + answer=answer, applied_actions=applied, skipped_actions=skipped, warnings=warnings, @@ -222,3 +224,37 @@ def _resolve_references( for key, item in value.items() } return value + + +def _build_execution_answer( 
+ *, + actions: list[AgentProposedAction], + step_outputs: dict[int, dict[str, Any]], +) -> str | None: + for action in actions: + if action.tool != "drive.countFiles": + continue + output = step_outputs.get(action.step) + if not isinstance(output, dict): + continue + return _count_files_answer(output) + + if actions and all(action.side_effect == "read" for action in actions): + return f"已完成 {len(step_outputs)} 个只读操作。" + return None + + +def _count_files_answer(output: dict[str, Any]) -> str: + total_items = int(output.get("totalItems") or 0) + category = str(output.get("category") or "").strip().lower() + if category == "video": + return f"你上传了 {total_items} 部电影(按视频文件统计)。" + if category == "audio": + return f"你上传了 {total_items} 个音频文件。" + if category == "image": + return f"你上传了 {total_items} 张图片。" + if category == "document": + return f"你上传了 {total_items} 个文档。" + if category == "archive": + return f"你上传了 {total_items} 个压缩包。" + return f"你上传了 {total_items} 个文件。" diff --git a/app/src/fileflash/agents/runtime/llm.py b/app/src/fileflash/agents/runtime/llm.py index d4b157f..4a268ee 100644 --- a/app/src/fileflash/agents/runtime/llm.py +++ b/app/src/fileflash/agents/runtime/llm.py @@ -46,8 +46,24 @@ async def create_plan( "timeout": 60.0, } request_kwargs.update(_reasoning_params(reasoning_effort)) + message = await self._request_plan(api_key=api_key, request_kwargs=request_kwargs) try: - message = await self._get_client(api_key).messages.create(**request_kwargs) + parsed, usage = _parse_plan_message(message) + except ApiError as exc: + if not _is_retryable_output_error(exc): + raise + degraded_kwargs = dict(request_kwargs) + degraded_kwargs.pop("thinking", None) + degraded_kwargs.pop("output_config", None) + message = await self._request_plan(api_key=api_key, request_kwargs=degraded_kwargs) + parsed, usage = _parse_plan_message(message) + if isinstance(usage, dict): + parsed["_usage"] = usage + return parsed + + async def _request_plan(self, *, api_key: str, 
def _extract_text(message: Any) -> str:
    """Return the stripped text carried by an LLM response message.

    ``message.content`` may be a plain string or a list of content blocks; the
    text parts of all blocks are joined with newlines.

    Raises:
        ApiError: 502 when ``content`` is neither a string nor a list, or when
            no non-empty text can be extracted.
    """
    content = getattr(message, "content", None)
    if isinstance(content, str):
        joined = content.strip()
    elif isinstance(content, list):
        pieces = [piece for block in content for piece in _extract_text_parts(block)]
        joined = "\n".join(filter(None, pieces)).strip()
    else:
        raise ApiError(status_code=502, code=502, message="Agent LLM returned an invalid response")

    if joined:
        return joined
    raise ApiError(status_code=502, code=502, message="Agent LLM returned an empty response")
_extract_text_parts_from_mapping( + { + "type": getattr(chunk, "type", None), + "text": getattr(chunk, "text", None), + "output_text": getattr(chunk, "output_text", None), + "content": getattr(chunk, "content", None), + } + ) + + +def _extract_text_parts_from_mapping(payload: dict[str, Any]) -> list[str]: + out: list[str] = [] + if payload.get("type") == "text": + out.extend(_flatten_text_value(payload.get("text"))) + for key in ("text", "output_text", "content"): + out.extend(_flatten_text_value(payload.get(key))) + deduped: list[str] = [] + seen: set[str] = set() + for item in out: + if not item or item in seen: + continue + deduped.append(item) + seen.add(item) + return deduped + + +def _flatten_text_value(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + candidate = value.strip() + return [candidate] if candidate else [] + if isinstance(value, list): + out: list[str] = [] + for item in value: + out.extend(_flatten_text_value(item)) + return out + if isinstance(value, dict): + if value.get("type") == "text": + return _flatten_text_value(value.get("text")) + if "text" in value: + return _flatten_text_value(value.get("text")) + return [] + return [] + + def _usage_payload(message: Any) -> dict[str, Any] | None: usage = getattr(message, "usage", None) if usage is None: @@ -144,6 +211,24 @@ def _response_details(error: anthropic.APIStatusError) -> str: return str(text or "")[:800] +def _parse_plan_message(message: Any) -> tuple[dict[str, Any], dict[str, Any] | None]: + text = _extract_text(message) + parsed = _parse_json_text(text) + usage = _usage_payload(message) + return parsed, usage + + +def _is_retryable_output_error(error: ApiError) -> bool: + if error.status_code != 502: + return False + return error.message in { + "Agent LLM returned an invalid response", + "Agent LLM returned an empty response", + "Agent LLM did not return valid JSON", + "Agent LLM JSON must be an object", + } + + def _parse_json_text(text: str) -> 
dict[str, Any]: candidate = text.strip() if candidate.startswith("```"): diff --git a/app/src/fileflash/agents/runtime/plan_runner.py b/app/src/fileflash/agents/runtime/plan_runner.py index 7181cf1..e3421d0 100644 --- a/app/src/fileflash/agents/runtime/plan_runner.py +++ b/app/src/fileflash/agents/runtime/plan_runner.py @@ -29,6 +29,7 @@ DEFAULT_AGENT_TOOLS = ( "drive.listFolder", + "drive.countFiles", "drive.createFolder", "drive.moveFile", "drive.moveFolder", @@ -63,17 +64,26 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentPlanResult: ) metadata = await _collect_context_metadata(db, user_id=user_id, request=request) allowed_tools = _skill_tool_whitelist(skill) - llm_payload = await self.planner_client.create_plan( - system_prompt=_system_prompt(), - user_prompt=_user_prompt( + try: + llm_payload = await self.planner_client.create_plan( + system_prompt=_system_prompt(), + user_prompt=_user_prompt( + request=request, + skill=skill, + allowed_tools=allowed_tools, + metadata=metadata, + ), + max_tokens=request.hints.budget_tokens, + reasoning_effort=request.hints.reasoning_effort, + ) + except ApiError as exc: + if exc.status_code != 502: + raise + llm_payload = _safe_fallback_payload( request=request, - skill=skill, - allowed_tools=allowed_tools, metadata=metadata, - ), - max_tokens=request.hints.budget_tokens, - reasoning_effort=request.hints.reasoning_effort, - ) + allowed_tools=allowed_tools, + ) actions = _normalize_actions( llm_payload=llm_payload, @@ -182,7 +192,10 @@ def _skill_tool_whitelist(skill: AgentSkill | AgentSkillCatalogEntry | None) -> elif skill is not None: raw = skill.tool_whitelist_json if isinstance(raw, list) and raw: - return tuple(str(item) for item in raw if str(item).strip()) + tools = tuple(str(item) for item in raw if str(item).strip()) + if "drive.countFiles" not in tools: + return (*tools, "drive.countFiles") + return tools return DEFAULT_AGENT_TOOLS @@ -356,7 +369,8 @@ def _folder_metadata(row: Folder) -> 
dict[str, Any]: def _system_prompt() -> str: return ( "You are FileFlash Agent Planner. Return only JSON. " - "Plan file-organization actions using the provided tools and metadata. " + "Plan file-management actions or read-only answers using the provided tools and metadata. " + "For count/how many questions, prefer drive.countFiles over listing folders. " "Do not read or infer file contents. Deletions are high risk and must be explicit. " "Cross-step dependencies must use '$stepN.field' references only and never symbolic placeholders " "like 'newFolderId'." @@ -437,6 +451,10 @@ def _skill_payload(skill: AgentSkill | AgentSkillCatalogEntry | None) -> dict[st def _tool_schemas(allowed_tools: tuple[str, ...]) -> list[dict[str, Any]]: descriptions = { "drive.listFolder": "List direct folder contents by folderId.", + "drive.countFiles": ( + "Count files under folderId. Supports recursive=true and category values " + "video, audio, image, document, archive, other. Use category=video for movie/电影 questions." 
+ ), "drive.createFolder": "Create a folder under parentFolderId with name.", "drive.moveFile": "Move fileId into targetFolderId.", "drive.moveFolder": "Move folderId into targetParentId.", @@ -448,6 +466,74 @@ def _tool_schemas(allowed_tools: tuple[str, ...]) -> list[dict[str, Any]]: return [{"tool": tool, "description": descriptions.get(tool, "")} for tool in allowed_tools] +def _safe_fallback_payload( + *, + request: PlanAgentRequest, + metadata: dict[str, Any], + allowed_tools: tuple[str, ...], +) -> dict[str, Any]: + fallback_actions: list[dict[str, Any]] = [] + if "drive.countFiles" in allowed_tools and _looks_like_count_question(request.input): + fallback_actions.append( + { + "step": 1, + "tool": "drive.countFiles", + "input": { + "folderId": metadata.get("rootFolderId") or request.context.root_folder_id or "root", + "recursive": True, + "category": _fallback_count_category(request.input), + }, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + ) + return { + "summary": "Planner fallback mode: generated a safe read-only count plan.", + "proposedActions": fallback_actions, + } + if "drive.listFolder" in allowed_tools: + root_folder_id = str( + metadata.get("rootFolderId") + or request.context.root_folder_id + or "root" + ) + fallback_actions.append( + { + "step": 1, + "tool": "drive.listFolder", + "input": {"folderId": root_folder_id}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + ) + return { + "summary": "Planner fallback mode: generated a safe read-only plan.", + "proposedActions": fallback_actions, + } + + +def _looks_like_count_question(text: str) -> bool: + normalized = text.lower() + return any(token in normalized for token in ("多少", "几个", "几部", "count", "how many", "number of")) + + +def _fallback_count_category(text: str) -> str | None: + normalized = text.lower() + if any(token in normalized for token in ("电影", "影片", "视频", "movie", "film", "video")): + return "video" + if 
any(token in normalized for token in ("图片", "照片", "image", "photo", "picture")): + return "image" + if any(token in normalized for token in ("音频", "音乐", "audio", "music")): + return "audio" + if any(token in normalized for token in ("文档", "document", "doc")): + return "document" + if any(token in normalized for token in ("压缩", "archive", "zip")): + return "archive" + return None + + def _normalize_actions( *, llm_payload: dict[str, Any], diff --git a/app/src/fileflash/agents/worker.py b/app/src/fileflash/agents/worker.py index 871c214..6fd4e89 100644 --- a/app/src/fileflash/agents/worker.py +++ b/app/src/fileflash/agents/worker.py @@ -143,6 +143,8 @@ async def _mark_succeeded(self, *, job_id: int, result: dict[str, Any], phase: s ) if job is None: return + if job.status == "canceled" or job.cancel_requested_at is not None: + return now = datetime.now(UTC) job.status = "succeeded" job.result = jsonable_encoder(result) @@ -161,6 +163,8 @@ async def _mark_failed(self, *, job_id: int, error: Exception) -> None: ) if job is None: return + if job.status == "canceled" or job.cancel_requested_at is not None: + return now = datetime.now(UTC) job.status = "failed" job.agent_phase = "failed" diff --git a/app/src/fileflash/routers/agent.py b/app/src/fileflash/routers/agent.py index 218976a..2e56108 100644 --- a/app/src/fileflash/routers/agent.py +++ b/app/src/fileflash/routers/agent.py @@ -71,11 +71,10 @@ async def cancel_agent_job( canceled_at = datetime.now(UTC) if job.status not in {"succeeded", "failed", "canceled"}: job.cancel_requested_at = canceled_at + job.status = "canceled" + job.agent_phase = "canceled" + job.finished_at = canceled_at job.updated_at = canceled_at - if job.status in {"pending", "retrying"}: - job.status = "canceled" - job.agent_phase = "canceled" - job.finished_at = canceled_at await db.commit() await db.refresh(job) diff --git a/app/src/fileflash/schemas/agent.py b/app/src/fileflash/schemas/agent.py index 96ce47b..e253be0 100644 --- 
a/app/src/fileflash/schemas/agent.py +++ b/app/src/fileflash/schemas/agent.py @@ -116,6 +116,7 @@ class AgentExecutionResult(CamelModel): plan_job_id: str execute_job_id: str summary: str + answer: str | None = None applied_actions: int = Field(ge=0) skipped_actions: int = Field(ge=0) warnings: list[str] = Field(default_factory=list) diff --git a/app/src/fileflash/services/agent/execute_service.py b/app/src/fileflash/services/agent/execute_service.py index b06ace5..e3f7c41 100644 --- a/app/src/fileflash/services/agent/execute_service.py +++ b/app/src/fileflash/services/agent/execute_service.py @@ -69,10 +69,31 @@ async def enqueue_execute( data={"highRiskActions": high_risk_actions}, ) + idempotency_key = f"agent.execute:{plan_job_id}" + existing_execute = await self.db.scalar( + select(BackgroundJob).where( + and_( + BackgroundJob.task_type == "agent.execute", + BackgroundJob.idempotency_key == idempotency_key, + ) + ) + ) + if existing_execute is not None: + raise ApiError( + status_code=409, + code=409, + message="Plan has already been executed", + data={ + "jobId": str(existing_execute.job_id), + "status": str(existing_execute.status), + }, + ) + job = await self.jobs.enqueue( self.db, task_type="agent.execute", payload=payload.model_dump(by_alias=True, mode="json"), + idempotency_key=idempotency_key, requested_by=user_id, max_attempts=1, priority=100, diff --git a/app/tests/test_agent_plan_execute_runtime.py b/app/tests/test_agent_plan_execute_runtime.py index a07d5f7..46fd8ef 100644 --- a/app/tests/test_agent_plan_execute_runtime.py +++ b/app/tests/test_agent_plan_execute_runtime.py @@ -197,6 +197,48 @@ async def create(self, **kwargs): # noqa: ANN003 assert fake_messages.kwargs["output_config"] == {"effort": "xhigh"} +@pytest.mark.asyncio +async def test_anthropic_planner_client_retries_without_reasoning_on_empty_output(): + class FakeMessages: + def __init__(self) -> None: + self.calls: list[dict[str, object]] = [] + + async def create(self, **kwargs): # 
noqa: ANN003 + self.calls.append(dict(kwargs)) + if len(self.calls) == 1: + return SimpleNamespace( + content=[SimpleNamespace(type="thinking", thinking="...")], + usage={}, + ) + return SimpleNamespace( + content=[SimpleNamespace(type="text", text='{"summary":"fallback","proposedActions":[]}')], + usage={}, + ) + + fake_messages = FakeMessages() + client = AnthropicPlannerClient( + settings=settings( + agent_llm_api_key="test-key", + agent_llm_model="claude-test", + ), + client=SimpleNamespace(messages=fake_messages), # type: ignore[arg-type] + ) + + result = await client.create_plan( + system_prompt="system", + user_prompt="user", + max_tokens=1024, + reasoning_effort="high", + ) + + assert len(fake_messages.calls) == 2 + assert "thinking" in fake_messages.calls[0] + assert "output_config" in fake_messages.calls[0] + assert "thinking" not in fake_messages.calls[1] + assert "output_config" not in fake_messages.calls[1] + assert result["summary"] == "fallback" + + def test_anthropic_planner_client_uses_configured_base_url(monkeypatch: pytest.MonkeyPatch): captured: dict[str, object] = {} @@ -272,7 +314,7 @@ async def test_execute_rejects_high_risk_plan_without_confirmation(): @pytest.mark.asyncio async def test_execute_enqueue_serializes_approval_datetime_as_json_string(): db = DummyDb() - db.scalar.return_value = BackgroundJob( + plan_job = BackgroundJob( job_id=99, task_type="agent.plan", status="succeeded", @@ -283,6 +325,7 @@ async def test_execute_enqueue_serializes_approval_datetime_as_json_string(): created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ) + db.scalar = AsyncMock(side_effect=[plan_job, None]) plans = AgentPlanRepository(db) # type: ignore[arg-type] plans.get_for_execute_binding = AsyncMock(return_value=SimpleNamespace(proposed_actions_json=[])) jobs = FakeJobs() @@ -306,6 +349,61 @@ async def test_execute_enqueue_serializes_approval_datetime_as_json_string(): approval_payload = jobs.kwargs["payload"]["approval"] assert 
@pytest.mark.asyncio
async def test_execute_rejects_repeat_when_existing_execute_job_exists():
    """A second execute request for the same plan is rejected with HTTP 409.

    The first ``db.scalar`` call returns the plan job and the second returns an
    existing execute job, so the service must raise before enqueueing anything.
    """
    now = datetime.now(UTC)
    db = DummyDb()
    plan_job = BackgroundJob(
        job_id=99,
        task_type="agent.plan",
        status="succeeded",
        payload={},
        result={},
        requested_by=7,
        scheduled_at=now,
        created_at=now,
        updated_at=now,
    )
    existing_execute = BackgroundJob(
        job_id=200,
        task_type="agent.execute",
        status="running",
        payload={},
        result={},
        requested_by=7,
        idempotency_key="agent.execute:99",
        scheduled_at=now,
        created_at=now,
        updated_at=now,
    )
    db.scalar = AsyncMock(side_effect=[plan_job, existing_execute])
    plans = AgentPlanRepository(db)  # type: ignore[arg-type]
    plans.get_for_execute_binding = AsyncMock(return_value=SimpleNamespace(proposed_actions_json=[]))
    jobs = FakeJobs()
    service = ExecuteService(
        db=db,
        settings=settings(),
        jobs=jobs,  # type: ignore[arg-type]
        plans=plans,
        work_sessions=AgentWorkSessionRepository(db),  # type: ignore[arg-type]
    )
    payload = ExecuteAgentRequest.model_validate(
        {
            "planJobId": "99",
            "planHash": "sha256:test",
            "approval": {"confirmedBy": "7", "confirmedAt": "2026-05-25T10:00:00Z"},
        }
    )

    with pytest.raises(ApiError) as exc:
        await service.enqueue_execute(user_id=7, payload=payload)

    assert exc.value.status_code == 409
    # One precise check; the previous `... or "already" in ...` alternative
    # made the specific phrase irrelevant.
    assert "already been executed" in exc.value.message.lower()
    assert exc.value.data["jobId"] == "200"
    assert exc.value.data["status"] == "running"
    # Nothing may have been enqueued.
    assert jobs.kwargs == {}
test_plan_runner_uses_safe_read_only_fallback_when_planner_returns_invalid_output( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) + monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + + planner = AsyncMock( + side_effect=ApiError( + status_code=502, + code=502, + message="Agent LLM returned an empty response", + ) + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "organize", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "autopilot", + } + ) + job = BackgroundJob( + job_id=333, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + db = DummyDb() + + result = await runner.run(db=db, job=job) # type: ignore[arg-type] + + assert "fallback mode" in result.summary.lower() + assert len(result.proposed_actions) == 1 + assert result.proposed_actions[0].tool == "drive.listFolder" + assert result.proposed_actions[0].side_effect == "read" + assert result.requires_confirmation is False + + +@pytest.mark.asyncio +async def test_plan_runner_fallback_uses_count_files_for_movie_count_question( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(plan_module, "_choose_skill", AsyncMock(return_value=None)) + monkeypatch.setattr( + plan_module, + "_collect_context_metadata", + AsyncMock(return_value={"scope": "currentFolder", "rootFolderId": "root", "files": [], "folders": []}), + ) 
+ monkeypatch.setattr(plan_module, "_upsert_agent_plan", AsyncMock(return_value=None)) + + planner = AsyncMock( + side_effect=ApiError( + status_code=502, + code=502, + message="Agent LLM returned an empty response", + ) + ) + runner = PlanRunner( + settings=settings(), + planner_client=SimpleNamespace(create_plan=planner), # type: ignore[arg-type] + ) + request = PlanAgentRequest.model_validate( + { + "input": "我上传了多少部电影?", + "context": { + "rootFolderId": "root", + "selectedFileIds": [], + "selectedFolderIds": [], + "currentPath": "/My Files", + }, + "executionPolicy": "confirm", + } + ) + job = BackgroundJob( + job_id=334, + task_type="agent.plan", + status="running", + payload=request.model_dump(by_alias=True), + result={}, + requested_by=7, + scheduled_at=datetime.now(UTC), + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + result = await runner.run(db=DummyDb(), job=job) # type: ignore[arg-type] + + assert len(result.proposed_actions) == 1 + assert result.proposed_actions[0].tool == "drive.countFiles" + assert result.proposed_actions[0].input["category"] == "video" + assert result.proposed_actions[0].side_effect == "read" + + def test_normalize_actions_rejects_symbolic_placeholder_target_folder(): with pytest.raises(ApiError) as exc: plan_module._normalize_actions( @@ -594,6 +804,39 @@ async def test_tool_router_dispatches_move_file(): router.file_service.move_file.assert_awaited_once() +@pytest.mark.asyncio +async def test_tool_router_count_files_counts_recursive_videos(): + db = DummyDb() + db.scalar = AsyncMock(return_value=1) + db.scalars = AsyncMock(return_value=[1, 2]) + db.execute = AsyncMock( + return_value=SimpleNamespace( + all=lambda: [ + (10, "movie.mp4", 100, "application/octet-stream", "mp4", 1), + (11, "clip.mkv", 200, "video/x-matroska", "mkv", 2), + (12, "notes.txt", 10, "text/plain", "txt", 1), + ] + ) + ) + router = ToolRouter(db=db, user_id=7) # type: ignore[arg-type] + + result = await router.dispatch( + ToolCall( + 
tool_name="drive.countFiles", + arguments={"folderId": "root", "recursive": True, "category": "video"}, + ) + ) + + assert result["totalItems"] == 2 + assert result["category"] == "video" + assert result["recursive"] is True + assert result["byMimeType"] == {"video/mp4": 1, "video/x-matroska": 1} + assert [item["name"] for item in result["sampleItems"]] == ["movie.mp4", "clip.mkv"] + executed_statement = str(db.execute.await_args.args[0]) + assert "file.status" in executed_statement + assert "file.is_latest" in executed_statement + + @pytest.mark.asyncio async def test_execute_runner_normalizes_tool_output_before_action_log(monkeypatch: pytest.MonkeyPatch): started = datetime.now(UTC) @@ -672,3 +915,81 @@ async def test_execute_runner_normalizes_tool_output_before_action_log(monkeypat assert isinstance(outputs_json, dict) assert isinstance(outputs_json["createdAt"], str) assert outputs_json["createdAt"] == output_time.isoformat() + + +@pytest.mark.asyncio +async def test_execute_runner_returns_count_files_answer(monkeypatch: pytest.MonkeyPatch): + started = datetime.now(UTC) + job = BackgroundJob( + job_id=601, + task_type="agent.execute", + status="running", + payload={ + "planJobId": "501", + "planHash": "sha256:test", + "approval": { + "confirmedBy": "7", + "confirmedAt": started.isoformat(), + "highRiskConfirmed": False, + }, + }, + result={}, + requested_by=7, + scheduled_at=started, + created_at=started, + updated_at=started, + ) + action = { + "step": 1, + "tool": "drive.countFiles", + "input": {"folderId": "root", "recursive": True, "category": "video"}, + "sideEffect": "read", + "riskLevel": "low", + "requiresConfirmation": False, + } + db = DummyDb() + db.refresh = AsyncMock() + + mock_plan_repo = SimpleNamespace( + get_for_execute_binding=AsyncMock( + return_value=SimpleNamespace( + proposed_actions_json=[action], + ) + ) + ) + monkeypatch.setattr(execute_module, "AgentPlanRepository", lambda _db: mock_plan_repo) + + mock_work_sessions = SimpleNamespace( 
+ create_for_job=AsyncMock(return_value=None), + close_session=AsyncMock(return_value=None), + ) + monkeypatch.setattr(execute_module, "AgentWorkSessionRepository", lambda _db: mock_work_sessions) + monkeypatch.setattr( + execute_module, + "AgentActionLogRepository", + lambda _db: SimpleNamespace( + append_step=AsyncMock(return_value=None), + finish_step=AsyncMock(return_value=None), + ), + ) + monkeypatch.setattr( + execute_module, + "ToolRouter", + lambda **kwargs: SimpleNamespace( + dispatch=AsyncMock( + return_value={ + "totalItems": 3, + "category": "video", + "recursive": True, + "folderId": "1", + "byMimeType": {"video/mp4": 3}, + "sampleItems": [], + } + ) + ), + ) + + result = await ExecuteRunner().run(db=db, job=job) # type: ignore[arg-type] + + assert result.answer == "你上传了 3 部电影(按视频文件统计)。" + assert result.applied_actions == 1 diff --git a/app/tests/test_agent_routes.py b/app/tests/test_agent_routes.py index f2a0506..6b50f90 100644 --- a/app/tests/test_agent_routes.py +++ b/app/tests/test_agent_routes.py @@ -49,6 +49,12 @@ async def refresh(self, _job: BackgroundJob) -> None: return None +class RunningJobDb(StubDb): + def __init__(self) -> None: + super().__init__() + self.job.status = "running" + + def _user() -> User: return User(user_id=7, username="u7", email="u7@example.com", password_hash="x") @@ -64,6 +70,17 @@ def _client() -> TestClient: return TestClient(app) +def _client_with_running_job() -> TestClient: + app = FastAPI() + app.include_router(router, prefix="/api/v1") + app.add_exception_handler(ApiError, api_error_handler) + app.dependency_overrides[get_current_user] = _user + app.dependency_overrides[get_agent_plan_service] = lambda: StubPlanService() + app.dependency_overrides[get_agent_execute_service] = lambda: StubExecuteService() + app.dependency_overrides[get_db] = lambda: RunningJobDb() + return TestClient(app) + + def test_plan_route_returns_response_shell(): response = _client().post( "/api/v1/agent/plan", @@ -123,3 +140,13 @@ def 
test_cancel_route_returns_response_shell(): assert body["data"]["jobId"] == "12" assert body["data"]["status"] == "canceled" assert body["data"]["canceledAt"] + + +def test_cancel_route_marks_running_job_as_canceled(): + response = _client_with_running_job().post("/api/v1/agent/cancel/12") + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["jobId"] == "12" + assert body["data"]["status"] == "canceled" diff --git a/app/tests/test_agent_worker.py b/app/tests/test_agent_worker.py new file mode 100644 index 0000000..65b15b0 --- /dev/null +++ b/app/tests/test_agent_worker.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from fileflash.agents.worker import AgentWorkerConsumer +from fileflash.models import BackgroundJob + + +class _AsyncContextManager: + def __init__(self, value): + self._value = value + + async def __aenter__(self): + return self._value + + async def __aexit__(self, exc_type, exc, tb): + return None + + +class DummySession: + def __init__(self, job: BackgroundJob): + self._job = job + + def begin(self): + return _AsyncContextManager(SimpleNamespace()) + + async def scalar(self, _query): # noqa: ANN001 + return self._job + + +def _job(*, status: str, cancel_requested_at: datetime | None) -> BackgroundJob: + now = datetime.now(UTC) + return BackgroundJob( + job_id=65, + task_type="agent.plan", + status=status, + payload={}, + result={}, + requested_by=7, + cancel_requested_at=cancel_requested_at, + scheduled_at=now, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.asyncio +async def test_mark_succeeded_does_not_override_canceled_job(monkeypatch: pytest.MonkeyPatch): + canceled_at = datetime.now(UTC) + job = _job(status="canceled", cancel_requested_at=canceled_at) + session = DummySession(job) + consumer = AgentWorkerConsumer( + 
queue=SimpleNamespace(), + session_factory=lambda: _AsyncContextManager(session), # type: ignore[arg-type] + ) + monkeypatch.setattr("fileflash.agents.worker.apply_local_lock_timeout", AsyncMock(return_value=None)) + + await consumer._mark_succeeded(job_id=65, result={"summary": "ok"}, phase="completed") + + assert job.status == "canceled" + assert job.cancel_requested_at == canceled_at + assert job.result == {} + + +@pytest.mark.asyncio +async def test_mark_failed_does_not_override_job_with_cancel_request(monkeypatch: pytest.MonkeyPatch): + canceled_at = datetime.now(UTC) + job = _job(status="running", cancel_requested_at=canceled_at) + session = DummySession(job) + consumer = AgentWorkerConsumer( + queue=SimpleNamespace(), + session_factory=lambda: _AsyncContextManager(session), # type: ignore[arg-type] + ) + monkeypatch.setattr("fileflash.agents.worker.apply_local_lock_timeout", AsyncMock(return_value=None)) + + await consumer._mark_failed(job_id=65, error=RuntimeError("boom")) + + assert job.status == "running" + assert job.cancel_requested_at == canceled_at + assert job.error_message is None diff --git a/web/src/components/layout/StorageStatus.spec.ts b/web/src/components/layout/StorageStatus.spec.ts index 7a8db2d..d19666f 100644 --- a/web/src/components/layout/StorageStatus.spec.ts +++ b/web/src/components/layout/StorageStatus.spec.ts @@ -66,4 +66,19 @@ describe('layout/StorageStatus', () => { const style = wrapper.find('.progress-bar-fill').attributes('style'); expect(style).toContain('width: 100%;'); }); + + it('converts ratio-like storagePercentage to percent when value is in (0, 1]', () => { + const wrapper = mount(StorageStatus, { + props: { + stats: buildStats({ + storageUsed: 426, + storageLimit: 1000, + storageAvailable: 574, + storagePercentage: 0.426, + }), + }, + }); + const style = wrapper.find('.progress-bar-fill').attributes('style'); + expect(style).toContain('width: 42.6%;'); + }); }); diff --git a/web/src/components/layout/StorageStatus.vue 
b/web/src/components/layout/StorageStatus.vue index b1cbc51..0eaa011 100644 --- a/web/src/components/layout/StorageStatus.vue +++ b/web/src/components/layout/StorageStatus.vue @@ -17,10 +17,23 @@ const isLoading = ref(false); const storageData = computed(() => props.stats || localStats.value); const progressPercentage = computed(() => { - if (!storageData.value) return 0; - const raw = Number(storageData.value.storagePercentage); - if (!Number.isFinite(raw)) return 0; - return Math.min(100, Math.max(0, raw)); + const stats = storageData.value; + if (!stats) return 0; + + const raw = Number(stats.storagePercentage); + let normalizedPercentage: number; + + if (Number.isFinite(raw)) { + normalizedPercentage = raw > 0 && raw <= 1 && stats.storageUsed > 0 + ? raw * 100 + : raw; + } else if (stats.storageLimit > 0) { + normalizedPercentage = (stats.storageUsed / stats.storageLimit) * 100; + } else { + normalizedPercentage = 0; + } + + return Math.min(100, Math.max(0, normalizedPercentage)); }); const progressWidthPercentage = computed(() => { const stats = storageData.value; @@ -105,17 +118,19 @@ onUnmounted(() => { padding: var(--spacing-sm) 0; } .progress-bar-wrapper { + margin-bottom: var(--spacing-sm); +} +.progress-bar { width: 100%; height: 12px; background-color: var(--color-bg-tertiary); border-radius: 6px; overflow: hidden; - margin-bottom: var(--spacing-sm); } .progress-bar-fill { height: 100%; background-color: var(--color-primary); - border-radius: 6px; + border-radius: inherit; transition: width 0.5s ease-in-out; } .stats-text { diff --git a/web/src/components/molecules/Avatar.vue b/web/src/components/molecules/Avatar.vue index 9b947b4..643f675 100644 --- a/web/src/components/molecules/Avatar.vue +++ b/web/src/components/molecules/Avatar.vue @@ -34,7 +34,7 @@ const initials = computed(() => { overflow: hidden; flex-shrink: 0; } -.ff-avatar--md { width: 28px; height: 28px; font-size: 11px; } -.ff-avatar--sm { width: 20px; height: 20px; font-size: 9px; } 
+.ff-avatar--md { width: 28px; height: 28px; font-size: var(--text-small); } +.ff-avatar--sm { width: 20px; height: 20px; font-size: var(--text-label); } .ff-avatar img { width: 100%; height: 100%; object-fit: cover; } diff --git a/web/src/components/molecules/Button.vue b/web/src/components/molecules/Button.vue index 5ed969f..a9a20ae 100644 --- a/web/src/components/molecules/Button.vue +++ b/web/src/components/molecules/Button.vue @@ -50,7 +50,7 @@ defineEmits<{ click: [event: MouseEvent] }>(); .ff-btn:disabled { opacity: 0.5; cursor: not-allowed; } .ff-btn--md { height: 32px; } -.ff-btn--sm { height: 24px; padding: 0 10px; font-size: 9px; } +.ff-btn--sm { height: 24px; padding: 0 10px; font-size: var(--text-label); } .ff-btn--primary { background: var(--ac); color: var(--ac-fg); } .ff-btn--primary:hover:not(:disabled) { filter: brightness(1.1); box-shadow: var(--mo-hover-bloom); } diff --git a/web/src/components/molecules/Select.spec.ts b/web/src/components/molecules/Select.spec.ts index 04af91b..48aa7c0 100644 --- a/web/src/components/molecules/Select.spec.ts +++ b/web/src/components/molecules/Select.spec.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, afterEach } from 'vitest'; +import { describe, it, expect, afterEach, vi } from 'vitest'; import { mount } from '../../test/mount'; import Select from './Select.vue'; @@ -50,4 +50,48 @@ describe('molecules/Select', () => { await w.find('.ff-select').trigger('keydown', { key: 'Escape' }); expect(w.find('.ff-select__menu').exists()).toBe(false); }); + + it('opens upward when there is not enough room below the trigger', async () => { + const originalInner = window.innerHeight; + Object.defineProperty(window, 'innerHeight', { value: 600, configurable: true }); + const w = mount(Select, { + props: { modelValue: 'a', options: OPTS }, + attachTo: document.body, + }); + const root = w.find('.ff-select').element as HTMLElement; + vi.spyOn(root, 'getBoundingClientRect').mockReturnValue({ + x: 0, y: 580, top: 580, bottom: 
596, + left: 0, right: 120, width: 120, height: 16, + toJSON: () => undefined, + } as DOMRect); + + await w.find('.ff-select__trigger').trigger('click'); + const menu = w.find('.ff-select__menu'); + expect(menu.exists()).toBe(true); + expect(menu.classes()).toContain('ff-select__menu--up'); + + Object.defineProperty(window, 'innerHeight', { value: originalInner, configurable: true }); + }); + + it('opens downward when there is room below the trigger', async () => { + const originalInner = window.innerHeight; + Object.defineProperty(window, 'innerHeight', { value: 800, configurable: true }); + const w = mount(Select, { + props: { modelValue: 'a', options: OPTS }, + attachTo: document.body, + }); + const root = w.find('.ff-select').element as HTMLElement; + vi.spyOn(root, 'getBoundingClientRect').mockReturnValue({ + x: 0, y: 100, top: 100, bottom: 132, + left: 0, right: 120, width: 120, height: 32, + toJSON: () => undefined, + } as DOMRect); + + await w.find('.ff-select__trigger').trigger('click'); + const menu = w.find('.ff-select__menu'); + expect(menu.exists()).toBe(true); + expect(menu.classes()).toContain('ff-select__menu--down'); + + Object.defineProperty(window, 'innerHeight', { value: originalInner, configurable: true }); + }); }); diff --git a/web/src/components/molecules/Select.vue b/web/src/components/molecules/Select.vue index a0a65ad..324e53b 100644 --- a/web/src/components/molecules/Select.vue +++ b/web/src/components/molecules/Select.vue @@ -1,5 +1,5 @@