Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions app/src/fileflash/agents/harness/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from dataclasses import dataclass
from typing import Any

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from ...core.errors import ApiError
from ...core.mime import resolve_file_mime_type
from ...models import File, Folder
from ...models.enums import FileStatus, FolderStatus, FolderType
from ...schemas.file import (
CreateFolderRequest,
GetFolderContentsQuery,
Expand Down Expand Up @@ -54,6 +58,9 @@ async def dispatch(self, call: ToolCall) -> dict[str, Any]:
)
return result.model_dump(by_alias=True, mode="json")

if tool == "drive.countFiles":
return await self._count_files(args)

if tool == "drive.createFolder":
name = _required_text(args, "name", "folderName")
parent_id = _first_value(args, "parentFolderId", "targetParentId", "folderId") or "root"
Expand Down Expand Up @@ -126,6 +133,85 @@ async def dispatch(self, call: ToolCall) -> dict[str, Any]:

raise ApiError(status_code=400, code=400, message=f"Unsupported agent tool: {tool}")

async def _count_files(self, args: dict[str, Any]) -> dict[str, Any]:
folder_id = str(_first_value(args, "folderId", "parentFolderId") or "root")
recursive = _bool_arg(args.get("recursive"), default=True)
category = _normalize_category(args.get("category"))
search = str(args.get("search") or "").strip().lower()
root_folder_id = await _resolve_folder_id(
self.db,
user_id=self.user_id,
folder_id=folder_id,
)
folder_ids = (
await _active_descendant_folder_ids(
self.db,
user_id=self.user_id,
root_folder_id=root_folder_id,
)
if recursive
else [root_folder_id]
)

statement = select(
File.file_id,
File.file_name,
File.file_size,
File.mime_type,
File.file_ext,
File.folder_id,
).where(
and_(
File.owner_id == self.user_id,
File.folder_id.in_(folder_ids),
File.status == FileStatus.ACTIVE,
File.is_latest.is_(True),
)
)
if search:
statement = statement.where(File.file_name.ilike(f"%{search}%"))
statement = statement.order_by(File.file_name.asc())

rows = (await self.db.execute(statement)).all()
by_mime_type: dict[str, int] = {}
sample_items: list[dict[str, Any]] = []
total_items = 0
for row in rows:
file_id, file_name, file_size, mime_type, file_ext, row_folder_id = row
resolved_mime = resolve_file_mime_type(
mime_type=mime_type,
file_ext=file_ext,
file_name=file_name,
)
if category is not None and _category_for_file(
mime_type=resolved_mime,
file_ext=file_ext,
file_name=file_name,
) != category:
continue

total_items += 1
by_mime_type[resolved_mime] = by_mime_type.get(resolved_mime, 0) + 1
if len(sample_items) < 5:
sample_items.append(
{
"id": str(file_id),
"name": str(file_name),
"size": int(file_size or 0),
"mimeType": resolved_mime,
"folderId": str(row_folder_id),
}
)

return {
"totalItems": total_items,
"category": category,
"recursive": recursive,
"folderId": str(root_folder_id),
"byMimeType": dict(sorted(by_mime_type.items())),
"sampleItems": sample_items,
}


def _first_value(args: dict[str, Any], *keys: str) -> Any:
for key in keys:
Expand All @@ -143,3 +229,130 @@ def _required_text(args: dict[str, Any], *keys: str) -> str:
if not text:
raise ApiError(status_code=400, code=400, message=f"Missing required tool input: {keys[0]}")
return text


def _bool_arg(value: Any, *, default: bool) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
text = str(value).strip().lower()
if text in {"1", "true", "yes", "y"}:
return True
if text in {"0", "false", "no", "n"}:
return False
return default


async def _resolve_folder_id(db: AsyncSession, *, user_id: int, folder_id: str) -> int:
if not folder_id or folder_id == "root":
root_id = await db.scalar(
select(Folder.folder_id).where(
and_(
Folder.owner_id == user_id,
Folder.parent_folder_id.is_(None),
Folder.folder_type == FolderType.ROOT,
Folder.status == FolderStatus.ACTIVE,
)
)
)
if root_id is None:
raise ApiError(status_code=404, code=404, message="Root folder not found")
return int(root_id)
try:
parsed = int(folder_id)
except ValueError as exc:
raise ApiError(status_code=400, code=400, message="Invalid folderId") from exc
exists = await db.scalar(
select(Folder.folder_id).where(
and_(
Folder.folder_id == parsed,
Folder.owner_id == user_id,
Folder.status == FolderStatus.ACTIVE,
)
)
)
if exists is None:
raise ApiError(status_code=404, code=404, message="Folder not found")
return parsed


async def _active_descendant_folder_ids(
db: AsyncSession,
*,
user_id: int,
root_folder_id: int,
) -> list[int]:
descendants = (
select(Folder.folder_id)
.where(
and_(
Folder.folder_id == root_folder_id,
Folder.owner_id == user_id,
Folder.status == FolderStatus.ACTIVE,
)
)
.cte(name="agent_count_descendants", recursive=True)
)
descendants = descendants.union_all(
select(Folder.folder_id).where(
and_(
Folder.parent_folder_id == descendants.c.folder_id,
Folder.owner_id == user_id,
Folder.status == FolderStatus.ACTIVE,
)
)
)
folder_ids = list(await db.scalars(select(descendants.c.folder_id)))
return [int(folder_id) for folder_id in folder_ids]


def _normalize_category(value: Any) -> str | None:
text = str(value or "").strip().lower()
aliases = {
"movies": "video",
"movie": "video",
"film": "video",
"films": "video",
"视频": "video",
"影片": "video",
"电影": "video",
"videos": "video",
"documents": "document",
"docs": "document",
"images": "image",
"pictures": "image",
"archives": "archive",
"compressed": "archive",
}
text = aliases.get(text, text)
if text in {"video", "audio", "image", "document", "archive", "other"}:
return text
return None


def _category_for_file(*, mime_type: str, file_ext: str | None, file_name: str | None) -> str:
mime = (mime_type or "").lower()
ext = _normalized_extension(file_ext) or _filename_extension(file_name)
if mime.startswith("video/") or ext in {"mp4", "mov", "avi", "mkv", "webm", "m4v"}:
return "video"
if mime.startswith("audio/") or ext in {"mp3", "wav", "flac", "m4a", "aac", "ogg"}:
return "audio"
if mime.startswith("image/") or ext in {"jpg", "jpeg", "png", "gif", "webp", "svg", "bmp"}:
return "image"
if mime in {"application/pdf"} or ext in {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "txt", "md"}:
return "document"
if ext in {"zip", "rar", "7z", "tar", "gz", "bz2", "xz"}:
return "archive"
return "other"


def _normalized_extension(value: str | None) -> str:
return str(value or "").strip().lower().lstrip(".")


def _filename_extension(value: str | None) -> str:
name = str(value or "").strip().lower()
if "." not in name:
return ""
return name.rsplit(".", 1)[-1]
36 changes: 36 additions & 0 deletions app/src/fileflash/agents/runtime/execute_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ async def run(self, *, db: AsyncSession, job: BackgroundJob) -> AgentExecutionRe
warnings.append(f"{skipped} action(s) were skipped.")
await work_sessions.close_session(job_id=int(job.job_id), status="closed")
await db.commit()
answer = _build_execution_answer(actions=actions, step_outputs=step_outputs)
return AgentExecutionResult(
plan_job_id=str(plan_job_id),
execute_job_id=str(job.job_id),
summary=f"Execution completed with {applied} applied action(s).",
answer=answer,
applied_actions=applied,
skipped_actions=skipped,
warnings=warnings,
Expand Down Expand Up @@ -222,3 +224,37 @@ def _resolve_references(
for key, item in value.items()
}
return value


def _build_execution_answer(
*,
actions: list[AgentProposedAction],
step_outputs: dict[int, dict[str, Any]],
) -> str | None:
for action in actions:
if action.tool != "drive.countFiles":
continue
output = step_outputs.get(action.step)
if not isinstance(output, dict):
continue
return _count_files_answer(output)

if actions and all(action.side_effect == "read" for action in actions):
return f"已完成 {len(step_outputs)} 个只读操作。"
return None


def _count_files_answer(output: dict[str, Any]) -> str:
total_items = int(output.get("totalItems") or 0)
category = str(output.get("category") or "").strip().lower()
if category == "video":
return f"你上传了 {total_items} 部电影(按视频文件统计)。"
if category == "audio":
return f"你上传了 {total_items} 个音频文件。"
if category == "image":
return f"你上传了 {total_items} 张图片。"
if category == "document":
return f"你上传了 {total_items} 个文档。"
if category == "archive":
return f"你上传了 {total_items} 个压缩包。"
return f"你上传了 {total_items} 个文件。"
Loading
Loading