diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_provider.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_provider.yaml
new file mode 100644
index 00000000..e5dbcd18
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_provider.yaml
@@ -0,0 +1,55 @@
+name: sangfor_af
+service_id: sangfor_af
+version: "8.0.48"
+description: >
+  Sangfor AF (Application Firewall) v8.0.48 REST API service.
+  Uses session-based cookie authentication: call login to obtain
+  a token, then supply it via Cookie header in subsequent requests.
+description_cn: >
+  深信服 AF 下一代防火墙 v8.0.48 REST API 服务。
+  使用基于 Session 的 Cookie 认证:首先调用 login 获取 token,
+  后续请求在 Cookie 中携带 token。基础 URL 为设备管理地址(如 https://192.168.1.1)。
+auth:
+  type: custom
+  secret: sangfor_af_v8_0_48_username
+  secret_secret: sangfor_af_v8_0_48_password
+credential_fields:
+  - key: username
+    label: 管理员用户名
+    storage: secret
+    config_key: username
+    secret_id: sangfor_af_v8_0_48_username
+    input_type: text
+    required: true
+  - key: password
+    label: 管理员密码
+    storage: secret
+    config_key: password
+    secret_id: sangfor_af_v8_0_48_password
+    input_type: password
+    required: true
+  - key: base_url
+    label: 设备地址 (Base URL)
+    storage: config
+    config_key: base_url
+    input_type: url
+    default: "https://192.168.1.1"
+defaults:
+  base_url: "https://192.168.1.1"
+  timeout: 60
+category: custom
+product_version: "8.0.48"
+notes: |
+  深信服 AF v8.0.48 API 认证流程:
+  1. 在 AF WebUI「系统 → 管理员账号」勾选 WEBAPI 权限。
+  2. Handler 自动调用 POST /api/v1/namespaces/public/login,
+     用配置的用户名/密码换取 token 并缓存(带 keepalive 自动续期)。
+  3. 后续所有请求由 Handler 自动注入 Cookie: token=<token>。
+  4. token 默认 10 分钟无操作后失效,缓存命中失败时会自动重新登录。
+
+  `verify_ssl` 由表单底部「SSL 验证」开关控制,下列字段都会被
+  识别为 SSL 验证开关,按以下优先级取值:
+  1. `verify_ssl` (主键)
+  2. `ssl_verify` (兼容别名)
+  3. `custom_settings.verify_ssl` (WebUI 表单写入位置)
+  4. 兜底默认值 `False`(**默认关闭证书验证**,AF 设备通常使用自签名证书)
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_test.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_test.yaml
new file mode 100644
index 00000000..18f46895
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/_test.yaml
@@ -0,0 +1,103 @@
+schema_version: 1
+provider: sangfor_af
+
+# Service-level connectivity probe.
+# get_system_version is a lightweight read-only GET that verifies
+# authentication and basic reachability.
+connectivity:
+  tool: sangfor_af_v48_status
+  params:
+    action: get_system_version
+
+# Tool-level test samples shown in the WebUI ToolDetailDrawer drop-down.
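+#
+# Shape of each fixture entry (inferred from the samples below, not a
+# formal schema): `label`/`label_cn` name the sample, `tags` group it
+# for filtering, `params` are passed verbatim to the tool, and an
+# optional `assert` block is checked against the result
+# (e.g. `success: true`).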
+fixtures: + sangfor_af_v48_auth: + - label: "Session keepalive" + label_cn: "刷新 Session 保活" + tags: [smoke] + params: + action: keepalive + assert: + success: true + + sangfor_af_v48_status: + - label: "Get system version" + label_cn: "获取系统版本信息" + tags: [smoke] + params: + action: get_system_version + assert: + success: true + + - label: "Get CPU usage" + label_cn: "获取 CPU 使用率" + tags: [smoke] + params: + action: get_cpu_usage + + - label: "Get memory usage" + label_cn: "获取内存使用率" + tags: [smoke] + params: + action: get_memory_usage + + sangfor_af_v48_ops: + - label: "List blacklist entries" + label_cn: "查询黑名单列表" + tags: [smoke, blacklist] + params: + action: get_blackwhitelist + type: BLACK + assert: + success: true + + - label: "List whitelist entries" + label_cn: "查询白名单列表" + tags: [smoke, whitelist] + params: + action: get_blackwhitelist + type: WHITE + + - label: "List blocked attacker IPs" + label_cn: "查询封锁攻击者 IP 列表" + tags: [smoke, blockip] + params: + action: get_blockip_list + assert: + success: true + + sangfor_af_v48_objects: + - label: "List IP groups" + label_cn: "查询 IP 地址组列表" + tags: [smoke] + params: + action: get_ipgroups + assert: + success: true + + - label: "List services" + label_cn: "查询服务对象列表" + tags: [smoke] + params: + action: get_services + assert: + success: true + + sangfor_af_v48_network: + - label: "List routing table (all routes)" + label_cn: "查询完整路由表" + tags: [smoke] + params: + action: get_routes + routeType: ALL_ROUTE + assert: + success: true + + sangfor_af_v48_system: + - label: "List admin accounts" + label_cn: "查询管理员账户列表" + tags: [smoke] + params: + action: get_accounts + assert: + success: true diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af.handler.py b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af.handler.py new file mode 100644 index 00000000..d6c01d89 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af.handler.py @@ -0,0 +1,689 @@ +""" +Sangfor AF (Application Firewall) v8.0.48 API Handler. 
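+
+Typical login exchange (illustrative; request/response shapes follow the
+_login helper below):
+
+    POST /api/v1/namespaces/public/login
+        {"name": "<username>", "password": "<password>"}
+    →   {"code": 0, "data": {"loginResult": {"token": "..."}}}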
+
+Authentication:
+  - Session-based: POST /api/v1/namespaces/public/login → token
+  - All subsequent requests: Cookie: token=<token>
+  - Token expires after ~10 min of inactivity (keepalive resets timer)
+
+API base URL: https://<device-address>
+Namespace: /api/v1/namespaces/public/
+Batch ops: /api/batch/v1/namespaces/public/
+"""
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Optional
+
+import aiohttp
+
+from flocks.config.config_writer import ConfigWriter
+from flocks.tool.registry import ToolContext, ToolResult
+
+# ── Constants ────────────────────────────────────────────────────────────────
+
+SERVICE_ID = "sangfor_af_v8_0_48"
+DEFAULT_BASE_URL = "https://192.168.1.1"
+DEFAULT_TIMEOUT = 60
+NAMESPACE = "public"
+
+API_V1 = f"/api/v1/namespaces/{NAMESPACE}"
+API_BATCH = f"/api/batch/v1/namespaces/{NAMESPACE}"
+
+# In-process token cache: {base_url: token}
+_TOKEN_CACHE: dict[str, str] = {}
+
+
+# ── Secret / Config helpers ───────────────────────────────────────────────────
+
+def _get_secret_manager():
+    from flocks.security import get_secret_manager
+    return get_secret_manager()
+
+
+def _resolve_ref(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    if not isinstance(value, str):
+        return str(value)
+    if value.startswith("{secret:") and value.endswith("}"):
+        return _get_secret_manager().get(value[len("{secret:"): -1])
+    if value.startswith("{env:") and value.endswith("}"):
+        return os.getenv(value[len("{env:"): -1])
+    return value
+
+
+def _service_config() -> dict[str, Any]:
+    raw = ConfigWriter.get_api_service_raw(SERVICE_ID)
+    return raw if isinstance(raw, dict) else {}
+
+
+def _resolve_verify_ssl(raw: dict[str, Any]) -> bool:
+    """Read verify_ssl with the same priority as sangfor_sip / onesec:
+    verify_ssl > ssl_verify > custom_settings.verify_ssl > False.
+    AF devices commonly use self-signed certs, so default is False.
+    """
+    value = raw.get("verify_ssl")
+    if value is None:
+        value = raw.get("ssl_verify")
+    if value is None:
+        custom = raw.get("custom_settings")
+        if isinstance(custom, dict):
+            value = custom.get("verify_ssl")
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, str):
+        return value.strip().lower() in {"1", "true", "yes", "on"}
+    return False
+
+
+def _resolve_runtime_config() -> tuple[str, int, str, str, bool]:
+    """Returns (base_url, timeout, username, password, verify_ssl)."""
+    raw = _service_config()
+    base_url = (
+        _resolve_ref(raw.get("base_url")) or DEFAULT_BASE_URL
+    ).rstrip("/")
+    timeout = raw.get("timeout", DEFAULT_TIMEOUT)
+    try:
+        timeout = int(timeout)
+    except (TypeError, ValueError):
+        timeout = DEFAULT_TIMEOUT
+
+    sm = _get_secret_manager()
+
+    username = (
+        _resolve_ref(raw.get("username"))
+        or sm.get("sangfor_af_v8_0_48_username")
+        or os.getenv("AF_USERNAME")
+    )
+    password = (
+        _resolve_ref(raw.get("password"))
+        or sm.get("sangfor_af_v8_0_48_password")
+        or os.getenv("AF_PASSWORD")
+    )
+
+    if not username or not password:
+        raise ValueError(
+            "AF API credentials not configured. "
+            "Please set username and password in the sangfor_af_v8_0_48 service configuration."
+ ) + return base_url, timeout, username, password, _resolve_verify_ssl(raw) + + +# ── Session / Token management ──────────────────────────────────────────────── + +async def _login( + session: aiohttp.ClientSession, + base_url: str, + username: str, + password: str, + verify_ssl: bool, +) -> tuple[Optional[str], Optional[str]]: + """Login and return (token, error_message).""" + url = f"{base_url}{API_V1}/login" + try: + async with session.post( + url, + json={"name": username, "password": password}, + ssl=verify_ssl, + ) as resp: + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return None, f"AF login request failed: {exc}" + + code = data.get("code") + if code != 0: + msg = data.get("message", "Unknown error") + return None, f"AF login failed (code={code}): {msg}" + + token = ( + data.get("data", {}).get("loginResult", {}).get("token") + ) + if not token: + return None, "AF login succeeded but no token returned" + return token, None + + +async def _get_token( + session: aiohttp.ClientSession, + base_url: str, + username: str, + password: str, + verify_ssl: bool, +) -> tuple[Optional[str], Optional[str]]: + """Return cached token or obtain a new one.""" + cached = _TOKEN_CACHE.get(base_url) + if cached: + # Validate by keepalive + try: + async with session.get( + f"{base_url}{API_V1}/keepalive", + headers={"Cookie": f"token={cached}"}, + ssl=verify_ssl, + ) as resp: + ka_data = await resp.json(content_type=None) + if ka_data.get("code") == 0: + return cached, None + except Exception: + pass + + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return None, err + _TOKEN_CACHE[base_url] = token + return token, None + + +# ── Low-level HTTP ──────────────────────────────────────────────────────────── + +def _pick(params: dict[str, Any], *keys: str) -> dict[str, Any]: + return {k: params[k] for k in keys if k in params and params[k] is not None} + + +def _af_result(action: str, payload: Any) -> ToolResult: + metadata = {"source": "Sangfor AF", "api": action, "version": "8.0.48"} + if isinstance(payload, dict): + code = payload.get("code") + if code not in (None, 0): + msg = payload.get("message", "Unknown error") + return ToolResult( + success=False, + error=f"AF API error (code={code}): {msg}", + metadata=metadata, + ) + return ToolResult( + success=True, + output=payload.get("data", payload), + metadata=metadata, + ) + return ToolResult(success=True, output=payload, metadata=metadata) + + +async def _call( + method: str, + path: str, + params: Optional[dict[str, Any]] = None, + json: Optional[Any] = None, + action: str = "", +) -> ToolResult: + """Execute an authenticated AF API request.""" + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except ValueError as exc: + return ToolResult(success=False, error=str(exc)) + + headers = {"Content-Type": "application/json"} + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + token, err = await _get_token(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + + headers["Cookie"] = f"token={token}" + url = f"{base_url}{path}" + + try: + async with session.request( + method.upper(), + url, + params=params, + json=json, + headers=headers, + ssl=verify_ssl, + ) as resp: + if resp.status >= 400: + text = await resp.text() + return ToolResult( + success=False, + error=f"HTTP {resp.status}: {text[:500]}", + ) + data = await 
resp.json(content_type=None) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + except Exception as exc: + return ToolResult(success=False, error=f"Unexpected error: {exc}") + + return _af_result(action or path.rsplit("/", 1)[-1], data) + + +# ── Action specs ───────────────────────────────────────────────────────────── + +class ActionSpec: + def __init__( + self, + method: str, + path_template: str, + param_builder: Callable[[dict[str, Any]], tuple[ + Optional[dict], Optional[Any] + ]], + required: tuple[str, ...] = (), + ) -> None: + self.method = method + self.path_template = path_template + self.param_builder = param_builder + self.required = required + + def build_path(self, params: dict[str, Any]) -> str: + try: + return self.path_template.format(**params) + except KeyError: + return self.path_template + + +# ── Auth actions ────────────────────────────────────────────────────────────── + +async def _do_login(ctx: ToolContext, **params: Any) -> ToolResult: + """Explicitly login and refresh the cached token.""" + del ctx + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except ValueError as exc: + return ToolResult(success=False, error=str(exc)) + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout) + ) as session: + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + _TOKEN_CACHE[base_url] = token + return ToolResult( + success=True, + output={"token": token, "message": "Login successful"}, + metadata={"source": "Sangfor AF", "api": "login", "version": "8.0.48"}, + ) + + +async def _do_logout(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + result = await _call("POST", f"{API_V1}/logout", action="logout") + try: + base_url, *_ = _resolve_runtime_config() + _TOKEN_CACHE.pop(base_url, None) + except ValueError: + pass + return result + + +async def _do_keepalive(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/keepalive", action="keepalive") + + +# ── Objects actions ────────────────────────────────────────────────────────── + +async def _do_get_ipgroups(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick( + params, + "_start", "_length", "businessType", "__nameprefix", "important", + "_search", "_order", "_sortby", "addressType", + ) + return await _call("GET", f"{API_V1}/ipgroups", params=query, action="get_ipgroups") + + +async def _do_get_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + uuid = params.get("uuid", "") + return await _call("GET", f"{API_V1}/ipgroups/{uuid}", action="get_ipgroup") + + +async def _do_create_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick( + params, + "name", "businessType", "description", "addressType", "important", + "ipRanges", "creator", + ) + return await _call("POST", f"{API_V1}/ipgroups", json={"obj": body}, action="create_ipgroup") + + +async def _do_update_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + uuid = params.get("uuid", "") + body = _pick( + params, + "name", "businessType", "description", "addressType", "important", + "ipRanges", "creator", + ) + return await _call("PATCH", f"{API_V1}/ipgroups/{uuid}", json={"obj": body}, action="update_ipgroup") + + +async def _do_delete_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + uuid = params.get("uuid", "") + return await 
_call("DELETE", f"{API_V1}/ipgroups/{uuid}", action="delete_ipgroup") + + +async def _do_get_services(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "_order", "_sortby", "serviceType") + return await _call("GET", f"{API_V1}/services", params=query, action="get_services") + + +async def _do_get_service(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + uuid = params.get("uuid", "") + return await _call("GET", f"{API_V1}/services/{uuid}", action="get_service") + + +# ── Operations center actions ───────────────────────────────────────────────── + +async def _do_get_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "type", "_start", "_length", "_search", "_order", "description") + return await _call("GET", f"{API_V1}/whiteblacklist", params=query, action="get_blackwhitelist") + + +async def _do_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "url", "type", "enable", "description", "domain") + return await _call("POST", f"{API_V1}/whiteblacklist", json={"obj": body}, action="add_blackwhitelist") + + +async def _do_batch_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + items = params.get("items", []) + return await _call( + "POST", + f"{API_BATCH}/whiteblacklist", + json=items, + action="batch_add_blackwhitelist", + ) + + +async def _do_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + url_param = params.get("url", "") + list_type = params.get("type", "") + query = {"type": list_type} if list_type else None + return await _call( + "DELETE", + f"{API_V1}/whiteblacklist/{url_param}", + params=query, + action="delete_blackwhitelist", + ) + + +async def _do_batch_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + items = params.get("items", []) + return await _call( + "POST", + f"{API_BATCH}/whiteblacklist", + params={"_method": "DELETE"}, + json=items, + action="batch_delete_blackwhitelist", + ) + + +async def _do_get_blockip_list(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_sortby", "_order", "creator", "fuzzyIP") + return await _call("GET", f"{API_V1}/blockip", params=query, action="get_blockip_list") + + +async def _do_batch_add_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + items = params.get("items", []) + query = _pick(params, "aifwType") + return await _call( + "POST", + f"{API_BATCH}/blockip", + params=query or None, + json=items, + action="batch_add_blockip", + ) + + +async def _do_batch_delete_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + items = params.get("items", []) + return await _call( + "POST", + f"{API_BATCH}/blockip", + params={"_method": "DELETE"}, + json=items, + action="batch_delete_blockip", + ) + + +async def _do_clear_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "creator") + return await _call( + "DELETE", + f"{API_V1}/blockip", + params=query or None, + action="clear_blockip", + ) + + +async def _do_get_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/blockip/autoconfig", action="get_blockip_auto_config") + + +async def _do_set_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "blockTime") + return await _call( + "PUT", + 
f"{API_V1}/blockip/autoconfig", + json={"obj": body}, + action="set_blockip_auto_config", + ) + + +# ── Status / device info actions ────────────────────────────────────────────── + +async def _do_get_memory_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/memoryusage", action="get_memory_usage") + + +async def _do_get_cpu_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/cpuusage", action="get_cpu_usage") + + +async def _do_get_disk_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/diskusage", action="get_disk_usage") + + +async def _do_get_system_version(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "filter") + return await _call( + "GET", f"{API_V1}/systemversion", + params=query or None, + action="get_system_version", + ) + + +async def _do_get_interface_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # AF8.0.x: /interfacestatus returns 1002; use /interfaces (list) or + # /interfaces/status?interfaceName= (single interface query). + iface = params.get("interfaceNames") or params.get("interfaceName") or "" + if iface: + return await _call( + "GET", f"{API_V1}/interfaces/status", + params={"interfaceName": iface}, + action="get_interface_status", + ) + return await _call("GET", f"{API_V1}/interfaces", action="get_interface_status") + + +async def _do_get_runtime_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/runtimestatus", action="get_runtime_status") + + +async def _do_get_current_time(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/currenttime", action="get_current_time") + + +# ── Network / routing actions ───────────────────────────────────────────────── + +async def _do_get_routes(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes", params=query or None, action="get_routes") + + +async def _do_get_routes_ipv6(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes/ipv6", params=query or None, action="get_routes_ipv6") + + +# ── Admin account actions ───────────────────────────────────────────────────── + +async def _do_get_accounts(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "enable") + return await _call("GET", f"{API_V1}/account", params=query or None, action="get_accounts") + + +async def _do_get_account(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + name = params.get("name", "") + return await _call("GET", f"{API_V1}/account/{name}", action="get_account") + + +# ── Action dispatch ─────────────────────────────────────────────────────────── + +_ACTION_MAP: dict[str, Callable] = { + # Auth + "login": _do_login, + "logout": _do_logout, + "keepalive": _do_keepalive, + # Objects + "get_ipgroups": _do_get_ipgroups, + "get_ipgroup": _do_get_ipgroup, + "create_ipgroup": _do_create_ipgroup, + "update_ipgroup": _do_update_ipgroup, + "delete_ipgroup": _do_delete_ipgroup, + "get_services": _do_get_services, + "get_service": _do_get_service, + # Operations center - blacklist/whitelist + "get_blackwhitelist": _do_get_blackwhitelist, + "add_blackwhitelist": 
_do_add_blackwhitelist, + "batch_add_blackwhitelist": _do_batch_add_blackwhitelist, + "delete_blackwhitelist": _do_delete_blackwhitelist, + "batch_delete_blackwhitelist": _do_batch_delete_blackwhitelist, + # Operations center - blocked IPs + "get_blockip_list": _do_get_blockip_list, + "batch_add_blockip": _do_batch_add_blockip, + "batch_delete_blockip": _do_batch_delete_blockip, + "clear_blockip": _do_clear_blockip, + "get_blockip_auto_config": _do_get_blockip_auto_config, + "set_blockip_auto_config": _do_set_blockip_auto_config, + # Status + "get_memory_usage": _do_get_memory_usage, + "get_cpu_usage": _do_get_cpu_usage, + "get_disk_usage": _do_get_disk_usage, + "get_system_version": _do_get_system_version, + "get_interface_status": _do_get_interface_status, + "get_runtime_status": _do_get_runtime_status, + "get_current_time": _do_get_current_time, + # Network + "get_routes": _do_get_routes, + "get_routes_ipv6": _do_get_routes_ipv6, + # System + "get_accounts": _do_get_accounts, + "get_account": _do_get_account, +} + +GROUP_ACTIONS: dict[str, set[str]] = { + "auth": {"login", "logout", "keepalive"}, + "objects": {"get_ipgroups", "get_ipgroup", "create_ipgroup", "update_ipgroup", "delete_ipgroup", "get_services", "get_service"}, + "ops": { + "get_blackwhitelist", "add_blackwhitelist", "batch_add_blackwhitelist", + "delete_blackwhitelist", "batch_delete_blackwhitelist", + "get_blockip_list", "batch_add_blockip", "batch_delete_blockip", + "clear_blockip", "get_blockip_auto_config", "set_blockip_auto_config", + }, + "status": { + "get_memory_usage", "get_cpu_usage", "get_disk_usage", + "get_system_version", "get_interface_status", + "get_runtime_status", "get_current_time", + }, + "network": {"get_routes", "get_routes_ipv6"}, + "system": {"get_accounts", "get_account"}, +} + +_CONNECTIVITY_TEST_ACTIONS: dict[str, str] = { + "auth": "keepalive", + "objects": "get_ipgroups", + "ops": "get_blackwhitelist", + "status": "get_system_version", + "network": "get_routes", + "system": "get_accounts", +} + + +async def unified_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + handler = _ACTION_MAP.get(action) + if handler is None: + available = ", ".join(sorted(_ACTION_MAP)) + return ToolResult( + success=False, + error=f"Unknown action: {action}. Available: {available}", + ) + return await handler(ctx, **params) + + +async def _dispatch_group(ctx: ToolContext, group: str, action: str, **params: Any) -> ToolResult: + if action == "test": + test_action = _CONNECTIVITY_TEST_ACTIONS.get(group, "get_system_version") + return await unified_ops(ctx, action=test_action, **params) + if action not in GROUP_ACTIONS[group]: + available = ", ".join(sorted(GROUP_ACTIONS[group])) + return ToolResult( + success=False, + error=f"Unsupported {group} action: {action}. 
Available: {available}", + ) + return await unified_ops(ctx, action=action, **params) + + +async def auth(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "auth", action, **params) + + +async def objects(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "objects", action, **params) + + +async def ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "ops", action, **params) + + +async def status(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "status", action, **params) + + +async def network(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "network", action, **params) + + +async def system(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "system", action, **params) + + +def _make_action_function(action: str): + async def _tool(ctx: ToolContext, **kwargs: Any) -> ToolResult: + return await unified_ops(ctx, action=action, **kwargs) + _tool.__name__ = action + return _tool + + +for _action_name in _ACTION_MAP: + globals()[_action_name] = _make_action_function(_action_name) + +del _action_name diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_auth.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_auth.yaml new file mode 100644 index 00000000..ecb99db6 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_auth.yaml @@ -0,0 +1,44 @@ +name: sangfor_af_v48_auth +description: > + Sangfor AF v8.0.48 authentication tool. Use the `action` parameter to + login, logout, or keep the session alive. Token is cached automatically + after a successful login. +description_cn: > + 深信服 AF v8.0.48 认证工具。通过 `action` 参数调用登录、注销或 token 保活接口。 + 登录成功后 token 会自动缓存,后续调用无需手动传 token。 +category: custom +enabled: true +requires_confirmation: false +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 认证动作名,可选值: + - login + 用途: 登录设备,获取 session token(token 自动缓存) + 必填: 无(用户名/密码从服务配置读取) + 风险提示: 只读认证接口 + 是否任务型: 否 + - logout + 用途: 注销当前登录 session,清除 token 缓存 + 必填: 无 + 风险提示: 写操作,注销后需重新登录 + 是否任务型: 否 + - keepalive + 用途: 刷新 token 超时计时器,保持 session 活跃 + 必填: 无 + 风险提示: 只读接口 + 是否任务型: 否 + enum: + - login + - logout + - keepalive + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: auth diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_network.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_network.yaml new file mode 100644 index 00000000..5dbdd334 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_network.yaml @@ -0,0 +1,65 @@ +name: sangfor_af_v48_network +description: > + Sangfor AF v8.0.48 network tool. Query routing tables (IPv4 and IPv6) + and network-related status information. 
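+# Example call (illustrative values): action=get_routes,
+# routeType=STATIC_ROUTE, _start=0, _length=100; see the inputSchema below.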
+description_cn: > + 深信服 AF v8.0.48 网络工具。通过 `action` 参数查询路由表(IPv4/IPv6) + 及网络相关状态信息。 +category: custom +enabled: true +requires_confirmation: false +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 网络查询动作名,可选值: + - get_routes + 用途: 获取后台 IPv4 路由信息列表 + 必填: 无 + 常用: routeType(ALL_ROUTE/STATIC_ROUTE/DIRECT_ROUTE 等)、_start、_length + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_routes_ipv6 + 用途: 获取后台 IPv6 路由信息列表 + 必填: 无 + 常用: routeType、_start、_length + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - get_routes + - get_routes_ipv6 + routeType: + type: string + description: > + 路由类型过滤:ALL_ROUTE=所有路由,STATIC_ROUTE=静态路由, + DIRECT_ROUTE=直连路由,OSPF_ROUTE=OSPF路由,RIP_ROUTE=RIP路由, + VPN_ROUTE=VPN路由,SSL_VPN_ROUTE=SSL VPN路由, + IBGP_ROUTE=IBGP路由,EBGP_ROUTE=EBGP路由 + enum: + - ALL_ROUTE + - STATIC_ROUTE + - DIRECT_ROUTE + - OSPF_ROUTE + - RIP_ROUTE + - VPN_ROUTE + - SSL_VPN_ROUTE + - IBGP_ROUTE + - EBGP_ROUTE + _start: + type: integer + description: 分页起始位置(从0开始) + _length: + type: integer + description: 每页最大返回数量(最大200,默认100) + _search: + type: string + description: 模糊搜索关键字 + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: network diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_objects.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_objects.yaml new file mode 100644 index 00000000..69ec70cf --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_objects.yaml @@ -0,0 +1,152 @@ +name: sangfor_af_v48_objects +description: > + Sangfor AF v8.0.48 objects management tool. Query, create, update, and + delete network IP group objects and services (protocol/port definitions) + used in firewall policies. +description_cn: > + 深信服 AF v8.0.48 对象管理工具。通过 `action` 参数查询、创建、修改和删除 + IP 地址组对象及服务对象(协议/端口定义),这些对象被防火墙策略引用。 +category: custom +enabled: true +requires_confirmation: true +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 对象管理动作名,可选值: + + ## IP 地址组 + - get_ipgroups + 用途: 查询符合条件的 IP 地址组列表 + 必填: 无 + 常用: _start、_length、businessType、__nameprefix、important、_search + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_ipgroup + 用途: 获取单个 IP 地址组详情 + 必填: uuid + 风险提示: 只读查询接口 + 是否任务型: 否 + - create_ipgroup + 用途: 创建新的 IP 地址组 + 必填: name、businessType + 常用: ipRanges、addressType、description、important + 风险提示: 写操作;创建后可被防火墙策略引用 + 是否任务型: 否 + - update_ipgroup + 用途: 增量更新(PATCH)指定 IP 地址组 + 必填: uuid + 常用: name、ipRanges、description + 风险提示: 写操作;修改 IP 组会影响引用该组的所有策略 + 是否任务型: 否 + - delete_ipgroup + 用途: 删除指定 IP 地址组 + 必填: uuid + 风险提示: 高风险写操作;如有策略引用该组将删除失败 + 是否任务型: 否 + + ## 服务对象 + - get_services + 用途: 查询服务或服务组列表(预定义或自定义) + 必填: 无 + 常用: _start、_length、_search、serviceType + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_service + 用途: 获取单个服务或服务组详情 + 必填: uuid + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - get_ipgroups + - get_ipgroup + - create_ipgroup + - update_ipgroup + - delete_ipgroup + - get_services + - get_service + + uuid: + type: string + description: IP地址组或服务对象的唯一标识符(32字符UUID) + name: + type: string + description: 对象名称(最大95字符) + businessType: + type: string + description: > + IP地址组业务类型:IP=IP地址,ADDRGROUP=地址组, + USER=用户地址,BUSINESS=业务地址 + enum: + - IP + - ADDRGROUP + - USER + - BUSINESS + addressType: + type: string + description: "IP协议版本:IPV4 或 IPV6" + enum: + - IPV4 + - IPV6 + important: + type: string + description: "重要级别:COMMON=普通,CORE=核心" + enum: + - COMMON + - CORE + ipRanges: + type: array + items: + type: object + properties: + start: + type: string + 
description: IP范围起始地址(如 192.168.1.1) + end: + type: string + description: IP范围结束地址(如 192.168.1.254) + description: IP地址范围列表 + description: + type: string + description: 对象描述(最大95字符) + creator: + type: string + description: 创建者名称 + serviceType: + type: string + description: "服务类型过滤:SERVICE=单个服务,SERVICEGROUP=服务组" + enum: + - SERVICE + - SERVICEGROUP + + # Pagination + _start: + type: integer + description: 分页起始位置(从0开始) + _length: + type: integer + description: 每页最大返回数量(最大200,默认100) + __nameprefix: + type: string + description: 按名称前缀过滤(最大95字符) + _search: + type: string + description: 模糊搜索关键字(最大95字符) + _order: + type: string + description: "排序方向:asc 或 desc" + enum: + - asc + - desc + _sortby: + type: string + description: 排序字段名 + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: objects diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_ops.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_ops.yaml new file mode 100644 index 00000000..6ae193b4 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_ops.yaml @@ -0,0 +1,165 @@ +name: sangfor_af_v48_ops +description: > + Sangfor AF v8.0.48 operations center tool. Manages blacklist/whitelist + entries (IPs, domains, URLs) and blocked attacker IPs via the `action` + parameter. Key security triage actions for SOC workflows. +description_cn: > + 深信服 AF v8.0.48 运营中心工具。通过 `action` 参数管理黑白名单(IP/域名/URL) + 和封锁攻击者 IP。是 SOC 安全处置的核心接口。 +category: custom +enabled: true +requires_confirmation: true +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 运营中心动作名,可选值: + + ## 黑白名单管理 + - get_blackwhitelist + 用途: 查询黑白名单列表(IP/域名/URL) + 必填: 无 + 常用: type(BLACK/WHITE)、_start、_length + 风险提示: 只读查询接口 + 是否任务型: 否 + - add_blackwhitelist + 用途: 添加单条黑白名单 + 必填: url(IP/域名/URL)、type(BLACK/WHITE) + 常用: enable、description、domain(0=IP,1=域名,2=URL) + 风险提示: 写操作;添加黑名单会拦截对应流量 + 是否任务型: 否 + - batch_add_blackwhitelist + 用途: 批量添加黑白名单 + 必填: items(数组,每项含 url/type 字段) + 风险提示: 写操作,批量添加黑名单影响面大 + 是否任务型: 否 + - delete_blackwhitelist + 用途: 删除单条黑白名单 + 必填: url(条目的 IP/域名/URL) + 常用: type(BLACK/WHITE) + 风险提示: 写操作,删除白名单可能导致误拦截 + 是否任务型: 否 + - batch_delete_blackwhitelist + 用途: 批量删除黑白名单 + 必填: items(数组,每项含 url 字段) + 风险提示: 写操作,批量删除影响面大 + 是否任务型: 否 + + ## 封锁攻击者 IP + - get_blockip_list + 用途: 查询当前封锁攻击者 IP 列表 + 必填: 无 + 常用: _start、_length、fuzzyIP(模糊搜索)、creator(AF/SIP) + 风险提示: 只读查询接口 + 是否任务型: 否 + - batch_add_blockip + 用途: 批量封锁攻击者 IP + 必填: items(数组,每项含 srcIP、dstIP 等字段) + 常用: aifwType(MANUAL/AUTO) + 风险提示: 高风险写操作;封锁 IP 会拦截其所有流量 + 是否任务型: 否 + - batch_delete_blockip + 用途: 批量解封攻击者 IP + 必填: items(数组,每项含 srcIP、dstIP 等字段) + 风险提示: 写操作,解封恶意 IP 存在安全风险 + 是否任务型: 否 + - clear_blockip + 用途: 清空封锁攻击者 IP 列表 + 必填: 无 + 常用: creator(AF/SIP,指定清除哪类封锁) + 风险提示: 高风险写操作;会清除所有封锁 IP + 是否任务型: 否 + - get_blockip_auto_config + 用途: 获取自动封锁攻击者时长配置 + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - set_blockip_auto_config + 用途: 修改自动封锁攻击者时长 + 必填: blockTime(封锁时长,单位秒) + 风险提示: 写操作,影响自动封锁策略 + 是否任务型: 否 + enum: + - get_blackwhitelist + - add_blackwhitelist + - batch_add_blackwhitelist + - delete_blackwhitelist + - batch_delete_blackwhitelist + - get_blockip_list + - batch_add_blockip + - batch_delete_blockip + - clear_blockip + - get_blockip_auto_config + - set_blockip_auto_config + + # Blacklist/whitelist params + url: + type: string + description: IP地址、域名或URL(黑白名单条目值) + type: + type: string + description: "名单类型:BLACK(黑名单)或 WHITE(白名单)" + enum: + - BLACK + - WHITE + enable: + type: boolean + 
description: 是否启用该条目,默认 true + description: + type: string + description: 条目描述信息(最大95字符) + domain: + type: integer + description: "条目类型:0=IP地址,1=域名,2=URL" + enum: [0, 1, 2] + items: + type: array + items: + type: object + description: 批量操作时的条目数组,每项至少包含 url(黑白名单)或 srcIP/dstIP(封锁IP) + + # Block IP params + fuzzyIP: + type: string + description: 模糊搜索IP关键字(最大15字符) + creator: + type: string + description: "封锁来源身份:AF(防火墙自身)或 SIP(安全感知平台)" + enum: + - AF + - SIP + aifwType: + type: string + description: "添加封锁IP的类型:MANUAL(手动)或 AUTO(自动,需要 creator=SIP)" + enum: + - MANUAL + - AUTO + blockTime: + type: integer + description: 自动封锁时长(秒) + + # Pagination + _start: + type: integer + description: 分页起始位置(从0开始) + _length: + type: integer + description: 每页最大返回数量(最大200,默认100) + _sortby: + type: string + description: 排序字段名 + _order: + type: string + description: "排序方向:asc(升序)或 desc(降序)" + enum: + - asc + - desc + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: ops diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_status.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_status.yaml new file mode 100644 index 00000000..f09276a1 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_status.yaml @@ -0,0 +1,91 @@ +name: sangfor_af_v48_status +description: > + Sangfor AF v8.0.48 device status tool. Query system resource usage + (CPU, memory, disk), firmware version, network interface status, + current time, and system uptime. +description_cn: > + 深信服 AF v8.0.48 状态中心工具。通过 `action` 参数查询系统资源(CPU/内存/磁盘)、 + 固件版本、网口状态、当前时间及系统运行时长等信息。 +category: custom +enabled: true +requires_confirmation: false +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 状态查询动作名,可选值: + - get_memory_usage + 用途: 获取当前内存使用率(百分比) + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_cpu_usage + 用途: 获取当前 CPU 使用率(百分比) + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_disk_usage + 用途: 获取当前磁盘使用率(百分比) + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_system_version + 用途: 获取 AF 系统固件版本信息 + 必填: 无 + 常用: filter(ALL/FULL/MAJOR/MINOR 等) + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_interface_status + 用途: 获取指定网口或全部网口的状态(流速、连接状态) + 必填: 无 + 常用: interfaceNames(如 eth0,不传则获取全部) + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_runtime_status + 用途: 获取系统运行时长(uptime) + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_current_time + 用途: 获取设备当前时间 + 必填: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - get_memory_usage + - get_cpu_usage + - get_disk_usage + - get_system_version + - get_interface_status + - get_runtime_status + - get_current_time + filter: + type: string + description: > + 版本信息过滤(仅用于 get_system_version): + ALL=显示所有,FULL=完整版本号,MAJOR=主版本号,MINOR=次版本号, + INCREASE=增版本号,BUILD=创建日期,EN=是否英文版,HF=是否HF版,B=是否Beta版 + enum: + - ALL + - FULL + - MAJOR + - MINOR + - INCREASE + - BUILD + - EN + - HF + - B + - R + - ADD + interfaceNames: + type: string + description: 网口名称(如 eth0),用于 get_interface_status;不填则获取全部接口 + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: status diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_system.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_system.yaml new file mode 100644 index 00000000..e41c6c29 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_48/sangfor_af_v48_system.yaml @@ -0,0 +1,53 @@ +name: sangfor_af_v48_system +description: > + Sangfor AF v8.0.48 system management tool. 
Query and manage administrator + accounts on the AF device. +description_cn: > + 深信服 AF v8.0.48 系统管理工具。通过 `action` 参数查询和管理 AF 设备上的 + 管理员账户信息。 +category: custom +enabled: true +requires_confirmation: true +provider: sangfor_af +inputSchema: + type: object + properties: + action: + type: string + description: | + 系统管理动作名,可选值: + - get_accounts + 用途: 查询所有管理员账户列表 + 必填: 无 + 常用: _start、_length、enable + 风险提示: 只读查询接口 + 是否任务型: 否 + - get_account + 用途: 查询指定管理员账户详情 + 必填: name(账户名) + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - get_accounts + - get_account + name: + type: string + description: 管理员账户名(用于 get_account) + enable: + type: boolean + description: 按启用/禁用状态过滤账户 + _start: + type: integer + description: 分页起始位置(从0开始) + _length: + type: integer + description: 每页最大返回数量 + _search: + type: string + description: 模糊搜索关键字 + required: + - action +handler: + type: script + script_file: sangfor_af.handler.py + function: system diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_provider.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_provider.yaml new file mode 100644 index 00000000..b3566602 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_provider.yaml @@ -0,0 +1,52 @@ +name: sangfor_af +service_id: sangfor_af +version: "8.0.85" +description: > + Sangfor AF (Application Firewall) v8.0.85 REST API service. + Uses session-based cookie authentication: call login to obtain + a token, then supply it via Cookie header in subsequent requests. + This version adds monitoring (session/traffic/statistics) APIs + compared to v8.0.48. +description_cn: > + 深信服 AF 下一代防火墙 v8.0.85 REST API 服务(对应 AF8.0.95 发布版文档)。 + 使用基于 Session 的 Cookie 认证:首先调用 login 获取 token, + 后续请求在 Cookie 中携带 token。在 v8.0.48 基础上新增监控(会话/流量/统计)相关 API。 +auth: + type: custom + secret: sangfor_af_v8_0_85_username + secret_secret: sangfor_af_v8_0_85_password +credential_fields: + - key: username + label: 管理员用户名 + storage: secret + config_key: username + secret_id: sangfor_af_v8_0_85_username + input_type: text + required: true + - key: password + label: 管理员密码 + storage: secret + config_key: password + secret_id: sangfor_af_v8_0_85_password + input_type: password + required: true + - key: base_url + label: 设备地址 (Base URL) + storage: config + config_key: base_url + input_type: url + default: "https://192.168.1.1" +defaults: + base_url: "https://192.168.1.1" + timeout: 60 + category: custom + product_version: "8.0.85" +notes: | + 深信服 AF v8.0.85 API 认证流程(同 v8.0.48): + 1. 在 AF WebUI「系统 → 管理员账号」勾选 WEBAPI 权限。 + 2. Handler 自动用用户名/密码换取 token 并缓存(带 keepalive 自动续期)。 + 3. 后续所有请求由 Handler 自动注入 Cookie: token=。 + 4. token 默认 10 分钟无操作后失效,缓存失效会自动重新登录。 + 5. 
本版本在 v8.0.48 基础上新增了监控相关 API(会话/流量统计等)。 + + `verify_ssl` 由表单底部「SSL 验证」开关控制(默认关闭,与 sangfor_sip 一致)。 diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_test.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_test.yaml new file mode 100644 index 00000000..2e70f0c4 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/_test.yaml @@ -0,0 +1,100 @@ +schema_version: 1 +provider: sangfor_af + +connectivity: + tool: sangfor_af_v85_status + params: + action: get_system_version + +fixtures: + sangfor_af_v85_auth: + - label: "Session keepalive" + label_cn: "刷新 Session 保活" + tags: [smoke] + params: + action: keepalive + assert: + success: true + + sangfor_af_v85_status: + - label: "Get system version" + label_cn: "获取系统版本信息" + tags: [smoke] + params: + action: get_system_version + assert: + success: true + + - label: "Get CPU usage" + label_cn: "获取 CPU 使用率" + tags: [smoke] + params: + action: get_cpu_usage + + sangfor_af_v85_monitor: + - label: "Get session summary" + label_cn: "获取会话概要信息" + tags: [smoke, monitor] + params: + action: get_session_summary + assert: + success: true + + - label: "Get daily new sessions" + label_cn: "获取每日新建会话信息" + tags: [monitor] + params: + action: get_session_dailys + + - label: "Get user traffic top 10" + label_cn: "获取用户流量排行前10名" + tags: [monitor, traffic] + params: + action: get_user_traffic_rank + topNumber: 10 + + - label: "Get app traffic ranking" + label_cn: "获取应用流量排行" + tags: [monitor, traffic] + params: + action: get_app_traffic_rank + + - label: "Get active sessions" + label_cn: "获取实时活跃会话列表" + tags: [monitor, session] + params: + action: get_sessions + + sangfor_af_v85_ops: + - label: "List blocked attacker IPs" + label_cn: "查询封锁攻击者 IP 列表" + tags: [smoke, blockip] + params: + action: get_blockip_list + assert: + success: true + + - label: "List blacklist entries" + label_cn: "查询黑名单列表" + tags: [smoke, blacklist] + params: + action: get_blackwhitelist + type: BLACK + + sangfor_af_v85_objects: + - label: "List IP groups" + label_cn: "查询 IP 地址组列表" + tags: [smoke] + params: + action: get_ipgroups + assert: + success: true + + sangfor_af_v85_network: + - label: "List routing table" + label_cn: "查询路由表" + tags: [smoke] + params: + action: get_routes + assert: + success: true diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af.handler.py b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af.handler.py new file mode 100644 index 00000000..3a5793bd --- /dev/null +++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af.handler.py @@ -0,0 +1,681 @@ +""" +Sangfor AF (Application Firewall) v8.0.85 API Handler. + +Extends v8.0.48 with additional monitoring APIs: + - Session monitoring (traffic ranking, session counts, session list) + - Statistics (packet loss, buffer, hash table, etc.) + - Log/alarm settings + +Authentication: same as v8.0.48 (session-based Cookie token). 
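+
+Note: some monitoring endpoints (/topusertraffics, /sessions) are queried
+as POST with ?_method=GET plus a JSON filter body; see the corresponding
+handlers below.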
+
+API base URL: https://<device-address>
+Namespace: /api/v1/namespaces/public/
+Batch ops: /api/batch/v1/namespaces/public/
+"""
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Optional
+
+import aiohttp
+
+from flocks.config.config_writer import ConfigWriter
+from flocks.tool.registry import ToolContext, ToolResult
+
+# ── Constants ────────────────────────────────────────────────────────────────
+
+SERVICE_ID = "sangfor_af_v8_0_85"
+DEFAULT_BASE_URL = "https://192.168.1.1"
+DEFAULT_TIMEOUT = 60
+NAMESPACE = "public"
+
+API_V1 = f"/api/v1/namespaces/{NAMESPACE}"
+API_BATCH = f"/api/batch/v1/namespaces/{NAMESPACE}"
+
+_TOKEN_CACHE: dict[str, str] = {}
+
+
+# ── Secret / Config helpers ───────────────────────────────────────────────────
+
+def _get_secret_manager():
+    from flocks.security import get_secret_manager
+    return get_secret_manager()
+
+
+def _resolve_ref(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    if not isinstance(value, str):
+        return str(value)
+    if value.startswith("{secret:") and value.endswith("}"):
+        return _get_secret_manager().get(value[len("{secret:"): -1])
+    if value.startswith("{env:") and value.endswith("}"):
+        return os.getenv(value[len("{env:"): -1])
+    return value
+
+
+def _service_config() -> dict[str, Any]:
+    raw = ConfigWriter.get_api_service_raw(SERVICE_ID)
+    return raw if isinstance(raw, dict) else {}
+
+
+def _resolve_verify_ssl(raw: dict[str, Any]) -> bool:
+    """Read verify_ssl with the same priority as sangfor_sip / onesec:
+    verify_ssl > ssl_verify > custom_settings.verify_ssl > False.
+    AF devices commonly use self-signed certs, so default is False.
+    """
+    value = raw.get("verify_ssl")
+    if value is None:
+        value = raw.get("ssl_verify")
+    if value is None:
+        custom = raw.get("custom_settings")
+        if isinstance(custom, dict):
+            value = custom.get("verify_ssl")
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, str):
+        return value.strip().lower() in {"1", "true", "yes", "on"}
+    return False
+
+
+def _resolve_runtime_config() -> tuple[str, int, str, str, bool]:
+    raw = _service_config()
+    base_url = (_resolve_ref(raw.get("base_url")) or DEFAULT_BASE_URL).rstrip("/")
+    timeout = raw.get("timeout", DEFAULT_TIMEOUT)
+    try:
+        timeout = int(timeout)
+    except (TypeError, ValueError):
+        timeout = DEFAULT_TIMEOUT
+
+    sm = _get_secret_manager()
+    username = (
+        _resolve_ref(raw.get("username"))
+        or sm.get("sangfor_af_v8_0_85_username")
+        or os.getenv("AF_USERNAME")
+    )
+    password = (
+        _resolve_ref(raw.get("password"))
+        or sm.get("sangfor_af_v8_0_85_password")
+        or os.getenv("AF_PASSWORD")
+    )
+    if not username or not password:
+        raise ValueError(
+            "AF API credentials not configured. "
+            "Please set username and password in the sangfor_af_v8_0_85 service configuration."
+ ) + return base_url, timeout, username, password, _resolve_verify_ssl(raw) + + +# ── Session / Token management ──────────────────────────────────────────────── + +async def _login(session, base_url, username, password, verify_ssl): + url = f"{base_url}{API_V1}/login" + try: + async with session.post( + url, + json={"name": username, "password": password}, + ssl=verify_ssl, + ) as resp: + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return None, f"AF login request failed: {exc}" + code = data.get("code") + if code != 0: + return None, f"AF login failed (code={code}): {data.get('message', 'Unknown error')}" + token = data.get("data", {}).get("loginResult", {}).get("token") + if not token: + return None, "AF login succeeded but no token returned" + return token, None + + +async def _get_token(session, base_url, username, password, verify_ssl): + cached = _TOKEN_CACHE.get(base_url) + if cached: + try: + async with session.get( + f"{base_url}{API_V1}/keepalive", + headers={"Cookie": f"token={cached}"}, + ssl=verify_ssl, + ) as resp: + ka = await resp.json(content_type=None) + if ka.get("code") == 0: + return cached, None + except Exception: + pass + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return None, err + _TOKEN_CACHE[base_url] = token + return token, None + + +# ── Low-level HTTP ──────────────────────────────────────────────────────────── + +def _pick(params: dict[str, Any], *keys: str) -> dict[str, Any]: + return {k: params[k] for k in keys if k in params and params[k] is not None} + + +def _af_result(action: str, payload: Any, version: str = "8.0.85") -> ToolResult: + metadata = {"source": "Sangfor AF", "api": action, "version": version} + if isinstance(payload, dict): + code = payload.get("code") + if code not in (None, 0): + msg = payload.get("message", "Unknown error") + return ToolResult(success=False, error=f"AF API error (code={code}): {msg}", metadata=metadata) + return ToolResult(success=True, output=payload.get("data", payload), metadata=metadata) + return ToolResult(success=True, output=payload, metadata=metadata) + + +async def _call( + method: str, + path: str, + params: Optional[dict[str, Any]] = None, + json: Optional[Any] = None, + action: str = "", +) -> ToolResult: + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except ValueError as exc: + return ToolResult(success=False, error=str(exc)) + + headers = {"Content-Type": "application/json"} + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + token, err = await _get_token(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + headers["Cookie"] = f"token={token}" + url = f"{base_url}{path}" + try: + async with session.request( + method.upper(), url, params=params, json=json, headers=headers, ssl=verify_ssl, + ) as resp: + if resp.status >= 400: + text = await resp.text() + return ToolResult(success=False, error=f"HTTP {resp.status}: {text[:500]}") + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + except Exception as exc: + return ToolResult(success=False, error=f"Unexpected error: {exc}") + return _af_result(action or path.rsplit("/", 1)[-1], data) + + +# ── Auth actions ────────────────────────────────────────────────────────────── + +async def _do_login(ctx: ToolContext, **params: Any) -> ToolResult: + 
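+    """Explicitly login and refresh the cached token (same flow as v8.0.48)."""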
del ctx + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except ValueError as exc: + return ToolResult(success=False, error=str(exc)) + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + _TOKEN_CACHE[base_url] = token + return ToolResult( + success=True, + output={"token": token, "message": "Login successful"}, + metadata={"source": "Sangfor AF", "api": "login", "version": "8.0.85"}, + ) + + +async def _do_logout(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + result = await _call("POST", f"{API_V1}/logout", action="logout") + try: + base_url, *_ = _resolve_runtime_config() + _TOKEN_CACHE.pop(base_url, None) + except ValueError: + pass + return result + + +async def _do_keepalive(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/keepalive", action="keepalive") + + +# ── Objects actions ────────────────────────────────────────────────────────── + +async def _do_get_ipgroups(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "businessType", "__nameprefix", "important", "_search", "_order", "_sortby", "addressType") + return await _call("GET", f"{API_V1}/ipgroups", params=query, action="get_ipgroups") + + +async def _do_get_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/ipgroups/{params.get('uuid', '')}", action="get_ipgroup") + + +async def _do_create_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "name", "businessType", "description", "addressType", "important", "ipRanges", "creator") + return await _call("POST", f"{API_V1}/ipgroups", json={"obj": body}, action="create_ipgroup") + + +async def _do_update_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "name", "businessType", "description", "addressType", "important", "ipRanges") + return await _call("PATCH", f"{API_V1}/ipgroups/{params.get('uuid', '')}", json={"obj": body}, action="update_ipgroup") + + +async def _do_delete_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("DELETE", f"{API_V1}/ipgroups/{params.get('uuid', '')}", action="delete_ipgroup") + + +async def _do_get_services(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "_order", "_sortby", "serviceType") + return await _call("GET", f"{API_V1}/services", params=query, action="get_services") + + +async def _do_get_service(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/services/{params.get('uuid', '')}", action="get_service") + + +# ── Monitoring actions (new in v8.0.85) ────────────────────────────────────── + +async def _do_get_user_traffic_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "topNumber", "vsys", "line", "applicationType", "filterObject") + return await _call( + "POST", + f"{API_V1}/topusertraffics", + params={"_method": "GET"}, + json=body or {}, + action="get_user_traffic_rank", + ) + + +async def _do_get_ip_traffic_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # /iptraffics is not a paged endpoint; _start/_length must not be sent. 
+ # topNumber must be int — AF returns code=1001 for any non-int value. + query = _pick(params, "vsys", "topNumber", "unit", "minutes") + if "topNumber" in query: + try: + query["topNumber"] = int(query["topNumber"]) + except (TypeError, ValueError): + pass + return await _call("GET", f"{API_V1}/iptraffics", params=query or None, action="get_ip_traffic_trend") + + +async def _do_get_app_traffic_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "line", "topNumber") + if "topNumber" in query: + try: + query["topNumber"] = int(query["topNumber"]) + except (TypeError, ValueError): + pass + return await _call("GET", f"{API_V1}/apptrafficrank", params=query or None, action="get_app_traffic_rank") + + +async def _do_get_session_dailys(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "ip") + return await _call("GET", f"{API_V1}/sessiondailys", params=query or None, action="get_session_dailys") + + +async def _do_get_session_details(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # Endpoint needs explicit filters; without them AF returns 1004 "没有返回值". + query = _pick(params, "vsys", "srcIP", "dstIP", "protocol", "srcPort", "dstPort") + return await _call("GET", f"{API_V1}/sessiondetails", params=query or None, action="get_session_details") + + +async def _do_get_session_count_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "minutes") + return await _call("GET", f"{API_V1}/sessioncounttrend", params=query or None, action="get_session_count_trend") + + +async def _do_get_session_src_ip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # srcIP is required; AF returns 1004 "没有返回值" when omitted. + query = _pick(params, "vsys", "srcIP") + return await _call("GET", f"{API_V1}/sessionsrcip", params=query or None, action="get_session_src_ip") + + +async def _do_get_session_count_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "vsys", "topNumber") + return await _call("GET", f"{API_V1}/sessioncountrank", params=query or None, action="get_session_count_rank") + + +async def _do_get_session_summary(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys") + return await _call("GET", f"{API_V1}/sessionsummary", params=query or None, action="get_session_summary") + + +async def _do_get_monitor_ips(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/monitorips", params=query or None, action="get_monitor_ips") + + +async def _do_get_sessions(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "_start", "_length", "vsys", "srcIP", "dstIP", "protocol", "srcPort", "dstPort") + # AF8.0.x requires POST + ?_method=GET for /sessions; plain GET returns 1002. 
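+    # Illustrative request (field names follow _pick above; example values
+    # are hypothetical): POST .../sessions?_method=GET with JSON body
+    # {"srcIP": "10.0.0.1", "_start": 0, "_length": 100}.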
+ return await _call("POST", f"{API_V1}/sessions", params={"_method": "GET"}, json=body or {}, action="get_sessions") + + +# Statistics (monitoring sub-section) +async def _do_get_packet_drop_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/mbufdroppointstatistics", params=query or None, action="get_packet_drop_stats") + + +async def _do_clear_packet_drop_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("DELETE", f"{API_V1}/mbufdroppointstatistics", action="clear_packet_drop_stats") + + +async def _do_get_mbuf_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/mbufstatistics", action="get_mbuf_stats") + + +async def _do_get_hash_table_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/hashtablestatistics", params=query or None, action="get_hash_table_stats") + + +# ── Operations center actions ───────────────────────────────────────────────── + +async def _do_get_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "type", "_start", "_length", "_search", "_order", "description") + return await _call("GET", f"{API_V1}/whiteblacklist", params=query, action="get_blackwhitelist") + + +async def _do_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "url", "type", "enable", "description", "domain") + return await _call("POST", f"{API_V1}/whiteblacklist", json={"obj": body}, action="add_blackwhitelist") + + +async def _do_batch_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/whiteblacklist", json=params.get("items", []), action="batch_add_blackwhitelist") + + +async def _do_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + url_param = params.get("url", "") + list_type = params.get("type", "") + query = {"type": list_type} if list_type else None + return await _call("DELETE", f"{API_V1}/whiteblacklist/{url_param}", params=query, action="delete_blackwhitelist") + + +async def _do_batch_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/whiteblacklist", params={"_method": "DELETE"}, json=params.get("items", []), action="batch_delete_blackwhitelist") + + +async def _do_get_blockip_list(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_sortby", "_order", "creator", "fuzzyIP") + return await _call("GET", f"{API_V1}/blockip", params=query, action="get_blockip_list") + + +async def _do_batch_add_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "aifwType") + return await _call("POST", f"{API_BATCH}/blockip", params=query or None, json=params.get("items", []), action="batch_add_blockip") + + +async def _do_batch_delete_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/blockip", params={"_method": "DELETE"}, json=params.get("items", []), action="batch_delete_blockip") + + +async def _do_clear_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "creator") + return await _call("DELETE", f"{API_V1}/blockip", params=query or None, action="clear_blockip") + + +async def 
_do_get_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/blockip/autoconfig", action="get_blockip_auto_config") + + +async def _do_set_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("PUT", f"{API_V1}/blockip/autoconfig", json={"obj": _pick(params, "blockTime")}, action="set_blockip_auto_config") + + +# ── Status actions ──────────────────────────────────────────────────────────── + +async def _do_get_memory_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/memoryusage", action="get_memory_usage") + + +async def _do_get_cpu_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/cpuusage", action="get_cpu_usage") + + +async def _do_get_disk_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/diskusage", action="get_disk_usage") + + +async def _do_get_system_version(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "filter") + return await _call("GET", f"{API_V1}/systemversion", params=query or None, action="get_system_version") + + +async def _do_get_interface_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # AF8.0.x: /interfacestatus returns 1002; use /interfaces (list) or + # /interfaces/status?interfaceName= (single interface query). + iface = params.get("interfaceNames") or params.get("interfaceName") or "" + if iface: + return await _call( + "GET", f"{API_V1}/interfaces/status", + params={"interfaceName": iface}, + action="get_interface_status", + ) + return await _call("GET", f"{API_V1}/interfaces", action="get_interface_status") + + +async def _do_get_runtime_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/runtimestatus", action="get_runtime_status") + + +async def _do_get_current_time(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/currenttime", action="get_current_time") + + +# ── Network actions ─────────────────────────────────────────────────────────── + +async def _do_get_routes(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes", params=query or None, action="get_routes") + + +async def _do_get_routes_ipv6(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes/ipv6", params=query or None, action="get_routes_ipv6") + + +# ── System actions ──────────────────────────────────────────────────────────── + +async def _do_get_accounts(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "enable") + return await _call("GET", f"{API_V1}/account", params=query or None, action="get_accounts") + + +async def _do_get_account(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/account/{params.get('name', '')}", action="get_account") + + +# ── Action dispatch ─────────────────────────────────────────────────────────── + +_ACTION_MAP: dict[str, Callable] = { + # Auth + "login": _do_login, + "logout": _do_logout, + "keepalive": _do_keepalive, + # Objects + "get_ipgroups": _do_get_ipgroups, + "get_ipgroup": _do_get_ipgroup, + 
"create_ipgroup": _do_create_ipgroup, + "update_ipgroup": _do_update_ipgroup, + "delete_ipgroup": _do_delete_ipgroup, + "get_services": _do_get_services, + "get_service": _do_get_service, + # Monitoring (new in v8.0.85) + "get_user_traffic_rank": _do_get_user_traffic_rank, + "get_ip_traffic_trend": _do_get_ip_traffic_trend, + "get_app_traffic_rank": _do_get_app_traffic_rank, + "get_session_dailys": _do_get_session_dailys, + "get_session_details": _do_get_session_details, + "get_session_count_trend": _do_get_session_count_trend, + "get_session_src_ip": _do_get_session_src_ip, + "get_session_count_rank": _do_get_session_count_rank, + "get_session_summary": _do_get_session_summary, + "get_monitor_ips": _do_get_monitor_ips, + "get_sessions": _do_get_sessions, + "get_packet_drop_stats": _do_get_packet_drop_stats, + "clear_packet_drop_stats": _do_clear_packet_drop_stats, + "get_mbuf_stats": _do_get_mbuf_stats, + "get_hash_table_stats": _do_get_hash_table_stats, + # Operations center + "get_blackwhitelist": _do_get_blackwhitelist, + "add_blackwhitelist": _do_add_blackwhitelist, + "batch_add_blackwhitelist": _do_batch_add_blackwhitelist, + "delete_blackwhitelist": _do_delete_blackwhitelist, + "batch_delete_blackwhitelist": _do_batch_delete_blackwhitelist, + "get_blockip_list": _do_get_blockip_list, + "batch_add_blockip": _do_batch_add_blockip, + "batch_delete_blockip": _do_batch_delete_blockip, + "clear_blockip": _do_clear_blockip, + "get_blockip_auto_config": _do_get_blockip_auto_config, + "set_blockip_auto_config": _do_set_blockip_auto_config, + # Status + "get_memory_usage": _do_get_memory_usage, + "get_cpu_usage": _do_get_cpu_usage, + "get_disk_usage": _do_get_disk_usage, + "get_system_version": _do_get_system_version, + "get_interface_status": _do_get_interface_status, + "get_runtime_status": _do_get_runtime_status, + "get_current_time": _do_get_current_time, + # Network + "get_routes": _do_get_routes, + "get_routes_ipv6": _do_get_routes_ipv6, + # System + "get_accounts": _do_get_accounts, + "get_account": _do_get_account, +} + +GROUP_ACTIONS: dict[str, set[str]] = { + "auth": {"login", "logout", "keepalive"}, + "objects": {"get_ipgroups", "get_ipgroup", "create_ipgroup", "update_ipgroup", "delete_ipgroup", "get_services", "get_service"}, + "monitor": { + "get_user_traffic_rank", "get_ip_traffic_trend", "get_app_traffic_rank", + "get_session_dailys", "get_session_details", "get_session_count_trend", + "get_session_src_ip", "get_session_count_rank", "get_session_summary", + "get_monitor_ips", "get_sessions", + "get_packet_drop_stats", "clear_packet_drop_stats", + "get_mbuf_stats", "get_hash_table_stats", + }, + "ops": { + "get_blackwhitelist", "add_blackwhitelist", "batch_add_blackwhitelist", + "delete_blackwhitelist", "batch_delete_blackwhitelist", + "get_blockip_list", "batch_add_blockip", "batch_delete_blockip", + "clear_blockip", "get_blockip_auto_config", "set_blockip_auto_config", + }, + "status": { + "get_memory_usage", "get_cpu_usage", "get_disk_usage", + "get_system_version", "get_interface_status", + "get_runtime_status", "get_current_time", + }, + "network": {"get_routes", "get_routes_ipv6"}, + "system": {"get_accounts", "get_account"}, +} + +_CONNECTIVITY_TEST_ACTIONS: dict[str, str] = { + "auth": "keepalive", + "objects": "get_ipgroups", + "monitor": "get_session_summary", + "ops": "get_blackwhitelist", + "status": "get_system_version", + "network": "get_routes", + "system": "get_accounts", +} + + +async def unified_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: 
+    handler = _ACTION_MAP.get(action)
+    if handler is None:
+        available = ", ".join(sorted(_ACTION_MAP))
+        return ToolResult(success=False, error=f"Unknown action: {action}. Available: {available}")
+    return await handler(ctx, **params)
+
+
+async def _dispatch_group(ctx: ToolContext, group: str, action: str, **params: Any) -> ToolResult:
+    if action == "test":
+        return await unified_ops(ctx, action=_CONNECTIVITY_TEST_ACTIONS.get(group, "get_system_version"), **params)
+    if action not in GROUP_ACTIONS[group]:
+        available = ", ".join(sorted(GROUP_ACTIONS[group]))
+        return ToolResult(success=False, error=f"Unsupported {group} action: {action}. Available: {available}")
+    return await unified_ops(ctx, action=action, **params)
+
+
+async def auth(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "auth", action, **params)
+
+
+async def objects(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "objects", action, **params)
+
+
+async def monitor(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "monitor", action, **params)
+
+
+async def ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "ops", action, **params)
+
+
+async def status(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "status", action, **params)
+
+
+async def network(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "network", action, **params)
+
+
+async def system(ctx: ToolContext, action: str, **params: Any) -> ToolResult:
+    return await _dispatch_group(ctx, "system", action, **params)
+
+
+def _make_action_function(action: str):
+    async def _tool(ctx: ToolContext, **kwargs: Any) -> ToolResult:
+        return await unified_ops(ctx, action=action, **kwargs)
+    _tool.__name__ = action
+    return _tool
+
+
+for _action_name in _ACTION_MAP:
+    globals()[_action_name] = _make_action_function(_action_name)
+
+del _action_name
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_auth.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_auth.yaml
new file mode 100644
index 00000000..69718131
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_auth.yaml
@@ -0,0 +1,44 @@
+name: sangfor_af_v85_auth
+description: >
+  Sangfor AF v8.0.85 authentication tool. Use the `action` parameter to
+  login, logout, or keep the session alive. Token is cached automatically
+  after a successful login.
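+# Usage sketch: call action=login once to cache the token, then periodic
+# action=keepalive to keep the session from timing out (see action docs below).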
+description_cn: >
+  深信服 AF v8.0.85 认证工具。通过 `action` 参数调用登录、注销或 token 保活接口。
+  登录成功后 token 会自动缓存,后续调用无需手动传 token。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        认证动作名,可选值:
+        - login
+          用途: 登录设备,获取 session token(token 自动缓存)
+          必填: 无(用户名/密码从服务配置读取)
+          风险提示: 只读认证接口
+          是否任务型: 否
+        - logout
+          用途: 注销当前登录 session,清除 token 缓存
+          必填: 无
+          风险提示: 写操作,注销后需重新登录
+          是否任务型: 否
+        - keepalive
+          用途: 刷新 token 超时计时器,保持 session 活跃
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+      enum:
+        - login
+        - logout
+        - keepalive
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: auth
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_monitor.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_monitor.yaml
new file mode 100644
index 00000000..374fac6c
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_monitor.yaml
@@ -0,0 +1,184 @@
+name: sangfor_af_v85_monitor
+description: >
+  Sangfor AF v8.0.85 monitoring tool. Provides real-time and historical
+  session data, traffic rankings, network statistics, and packet diagnostics.
+  These APIs are new in v8.0.85 and not available in v8.0.48.
+description_cn: >
+  深信服 AF v8.0.85 监控工具(v8.0.48 中不含此功能)。通过 `action` 参数
+  查询实时/历史会话数据、流量排行、网络统计及报文诊断信息。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        监控动作名,可选值:
+
+        ## 流量排行
+        - get_user_traffic_rank
+          用途: 获取用户流量排行(Top N 用户)
+          必填: 无
+          常用: topNumber(前N名,默认10)、vsys、line、applicationType
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_ip_traffic_trend
+          用途: 获取 IP 流量趋势曲线(指定前5或10名IP)
+          必填: 无
+          常用: topNumber、vsys、unit、minutes
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_app_traffic_rank
+          用途: 获取应用流量排行(Top N 应用)
+          必填: 无
+          常用: topNumber、vsys、line
+          风险提示: 只读接口
+          是否任务型: 否
+
+        ## 会话排行与统计
+        - get_session_dailys
+          用途: 获取每日新建会话信息
+          必填: 无
+          常用: vsys、ip、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_details
+          用途: 获取会话详情列表(含5层信息)
+          必填: 无
+          常用: vsys、srcIP、dstIP、protocol、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_count_trend
+          用途: 获取会话数量趋势折线图数据
+          必填: 无
+          常用: vsys、minutes(最近N分钟,默认60)
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_src_ip
+          用途: 获取指定源IP的会话详情(按目的IP分组)
+          必填: 无
+          常用: srcIP、vsys、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_count_rank
+          用途: 获取会话数量排行(Top N 源IP)
+          必填: 无
+          常用: topNumber、vsys
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_summary
+          用途: 获取会话概要信息(总数、协议分布等)
+          必填: 无
+          常用: vsys
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_monitor_ips
+          用途: 获取配置中心监听列表IP范围
+          必填: 无
+          常用: _start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_sessions
+          用途: 获取实时会话列表(当前活跃连接)
+          必填: 无
+          常用: vsys、srcIP、dstIP、protocol、srcPort、dstPort、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+
+        ## 统计与诊断
+        - get_packet_drop_stats
+          用途: 获取 mbuf 丢包点统计信息列表
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - clear_packet_drop_stats
+          用途: 清除后台丢包统计信息
+          必填: 无
+          风险提示: 写操作,清除统计数据不可恢复
+          是否任务型: 否
+        - get_mbuf_stats
+          用途: 获取 mbuf 内存统计信息
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_hash_table_stats
+          用途: 获取哈希表统计列表
+          必填: 无
+          常用: _start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+      enum:
+        - get_user_traffic_rank
+        - get_ip_traffic_trend
+        - get_app_traffic_rank
+        - get_session_dailys
+        - get_session_details
+        - get_session_count_trend
+        - get_session_src_ip
+        - get_session_count_rank
+        - get_session_summary
+        - get_monitor_ips
+        - get_sessions
+        - get_packet_drop_stats
+        - clear_packet_drop_stats
+        - get_mbuf_stats
+        - get_hash_table_stats
+
+    topNumber:
+      type: integer
+      description: 排行榜取前N名(如5或10)
+    vsys:
+      type: string
+      description: 虚拟系统名称(通常为 public,可省略)
+    line:
+      type: integer
+      description: "线路编号过滤,0=全部(范围0-256)"
+    applicationType:
+      type: array
+      items:
+        type: string
+      description: 应用类型过滤列表
+    filterObject:
+      type: object
+      description: >
+        用户流量排行过滤对象:
+        objectType=GROUP/USER/IP,对应 groups/users/ip 数组
+    unit:
+      type: string
+      description: "流量单位(如 bps, Kbps, Mbps)"
+    minutes:
+      type: integer
+      description: 查询最近N分钟的数据(默认60)
+    ip:
+      type: string
+      description: IP地址过滤(格式:IPv4/IPv6)
+    srcIP:
+      type: string
+      description: 源IP地址过滤(格式:IPv4/IPv6)
+    dstIP:
+      type: string
+      description: 目的IP地址过滤
+    protocol:
+      type: string
+      description: "协议过滤:TCP/UDP/ICMP/OTHER"
+    srcPort:
+      type: integer
+      description: 源端口过滤(0-65535)
+    dstPort:
+      type: integer
+      description: 目的端口过滤(0-65535)
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: monitor
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_network.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_network.yaml
new file mode 100644
index 00000000..149f1aee
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_network.yaml
@@ -0,0 +1,65 @@
+name: sangfor_af_v85_network
+description: >
+  Sangfor AF v8.0.85 network tool. Query routing tables (IPv4 and IPv6)
+  and network-related status information.
+description_cn: >
+  深信服 AF v8.0.85 网络工具。通过 `action` 参数查询路由表(IPv4/IPv6)
+  及网络相关状态信息。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        网络查询动作名,可选值:
+        - get_routes
+          用途: 获取后台 IPv4 路由信息列表
+          必填: 无
+          常用: routeType(ALL_ROUTE/STATIC_ROUTE/DIRECT_ROUTE 等)、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_routes_ipv6
+          用途: 获取后台 IPv6 路由信息列表
+          必填: 无
+          常用: routeType、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_routes
+        - get_routes_ipv6
+    routeType:
+      type: string
+      description: >
+        路由类型过滤:ALL_ROUTE=所有路由,STATIC_ROUTE=静态路由,
+        DIRECT_ROUTE=直连路由,OSPF_ROUTE=OSPF路由,RIP_ROUTE=RIP路由,
+        VPN_ROUTE=VPN路由,SSL_VPN_ROUTE=SSL VPN路由,
+        IBGP_ROUTE=IBGP路由,EBGP_ROUTE=EBGP路由
+      enum:
+        - ALL_ROUTE
+        - STATIC_ROUTE
+        - DIRECT_ROUTE
+        - OSPF_ROUTE
+        - RIP_ROUTE
+        - VPN_ROUTE
+        - SSL_VPN_ROUTE
+        - IBGP_ROUTE
+        - EBGP_ROUTE
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    _search:
+      type: string
+      description: 模糊搜索关键字
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: network
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_objects.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_objects.yaml
new file mode 100644
index 00000000..180b4777
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_objects.yaml
@@ -0,0 +1,152 @@
+name: sangfor_af_v85_objects
+description: >
+  Sangfor AF v8.0.85 objects management tool. Query, create, update, and
+  delete network IP group objects and services (protocol/port definitions)
+  used in firewall policies.
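+# Illustrative create call (values are examples only): action=create_ipgroup,
+# name=<group-name>, businessType=IP,
+# ipRanges=[{start: 192.168.1.1, end: 192.168.1.254}] — see the schema below.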
+description_cn: >
+  深信服 AF v8.0.85 对象管理工具。通过 `action` 参数查询、创建、修改和删除
+  IP 地址组对象及服务对象(协议/端口定义),这些对象被防火墙策略引用。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        对象管理动作名,可选值:
+
+        ## IP 地址组
+        - get_ipgroups
+          用途: 查询符合条件的 IP 地址组列表
+          必填: 无
+          常用: _start、_length、businessType、__nameprefix、important、_search
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_ipgroup
+          用途: 获取单个 IP 地址组详情
+          必填: uuid
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - create_ipgroup
+          用途: 创建新的 IP 地址组
+          必填: name、businessType
+          常用: ipRanges、addressType、description、important
+          风险提示: 写操作;创建后可被防火墙策略引用
+          是否任务型: 否
+        - update_ipgroup
+          用途: 增量更新(PATCH)指定 IP 地址组
+          必填: uuid
+          常用: name、ipRanges、description
+          风险提示: 写操作;修改 IP 组会影响引用该组的所有策略
+          是否任务型: 否
+        - delete_ipgroup
+          用途: 删除指定 IP 地址组
+          必填: uuid
+          风险提示: 高风险写操作;如有策略引用该组将删除失败
+          是否任务型: 否
+
+        ## 服务对象
+        - get_services
+          用途: 查询服务或服务组列表(预定义或自定义)
+          必填: 无
+          常用: _start、_length、_search、serviceType
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_service
+          用途: 获取单个服务或服务组详情
+          必填: uuid
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_ipgroups
+        - get_ipgroup
+        - create_ipgroup
+        - update_ipgroup
+        - delete_ipgroup
+        - get_services
+        - get_service
+
+    uuid:
+      type: string
+      description: IP地址组或服务对象的唯一标识符(32字符UUID)
+    name:
+      type: string
+      description: 对象名称(最大95字符)
+    businessType:
+      type: string
+      description: >
+        IP地址组业务类型:IP=IP地址,ADDRGROUP=地址组,
+        USER=用户地址,BUSINESS=业务地址
+      enum:
+        - IP
+        - ADDRGROUP
+        - USER
+        - BUSINESS
+    addressType:
+      type: string
+      description: "IP协议版本:IPV4 或 IPV6"
+      enum:
+        - IPV4
+        - IPV6
+    important:
+      type: string
+      description: "重要级别:COMMON=普通,CORE=核心"
+      enum:
+        - COMMON
+        - CORE
+    ipRanges:
+      type: array
+      items:
+        type: object
+        properties:
+          start:
+            type: string
+            description: IP范围起始地址(如 192.168.1.1)
+          end:
+            type: string
+            description: IP范围结束地址(如 192.168.1.254)
+      description: IP地址范围列表
+    description:
+      type: string
+      description: 对象描述(最大95字符)
+    creator:
+      type: string
+      description: 创建者名称
+    serviceType:
+      type: string
+      description: "服务类型过滤:SERVICE=单个服务,SERVICEGROUP=服务组"
+      enum:
+        - SERVICE
+        - SERVICEGROUP
+
+    # Pagination
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    __nameprefix:
+      type: string
+      description: 按名称前缀过滤(最大95字符)
+    _search:
+      type: string
+      description: 模糊搜索关键字(最大95字符)
+    _order:
+      type: string
+      description: "排序方向:asc 或 desc"
+      enum:
+        - asc
+        - desc
+    _sortby:
+      type: string
+      description: 排序字段名
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: objects
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_ops.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_ops.yaml
new file mode 100644
index 00000000..a46585d4
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_ops.yaml
@@ -0,0 +1,165 @@
+name: sangfor_af_v85_ops
+description: >
+  Sangfor AF v8.0.85 operations center tool. Manages blacklist/whitelist
+  entries (IPs, domains, URLs) and blocked attacker IPs via the `action`
+  parameter. Key security triage actions for SOC workflows.
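+# Illustrative triage call (example item shape): action=batch_add_blockip,
+# items=[{"srcIP": ..., "dstIP": ...}], aifwType=MANUAL — a high-risk write
+# gated by requires_confirmation.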
+description_cn: >
+  深信服 AF v8.0.85 运营中心工具。通过 `action` 参数管理黑白名单(IP/域名/URL)
+  和封锁攻击者 IP。是 SOC 安全处置的核心接口。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        运营中心动作名,可选值:
+
+        ## 黑白名单管理
+        - get_blackwhitelist
+          用途: 查询黑白名单列表(IP/域名/URL)
+          必填: 无
+          常用: type(BLACK/WHITE)、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - add_blackwhitelist
+          用途: 添加单条黑白名单
+          必填: url(IP/域名/URL)、type(BLACK/WHITE)
+          常用: enable、description、domain(0=IP,1=域名,2=URL)
+          风险提示: 写操作;添加黑名单会拦截对应流量
+          是否任务型: 否
+        - batch_add_blackwhitelist
+          用途: 批量添加黑白名单
+          必填: items(数组,每项含 url/type 字段)
+          风险提示: 写操作,批量添加黑名单影响面大
+          是否任务型: 否
+        - delete_blackwhitelist
+          用途: 删除单条黑白名单
+          必填: url(条目的 IP/域名/URL)
+          常用: type(BLACK/WHITE)
+          风险提示: 写操作,删除白名单可能导致误拦截
+          是否任务型: 否
+        - batch_delete_blackwhitelist
+          用途: 批量删除黑白名单
+          必填: items(数组,每项含 url 字段)
+          风险提示: 写操作,批量删除影响面大
+          是否任务型: 否
+
+        ## 封锁攻击者 IP
+        - get_blockip_list
+          用途: 查询当前封锁攻击者 IP 列表
+          必填: 无
+          常用: _start、_length、fuzzyIP(模糊搜索)、creator(AF/SIP)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - batch_add_blockip
+          用途: 批量封锁攻击者 IP
+          必填: items(数组,每项含 srcIP、dstIP 等字段)
+          常用: aifwType(MANUAL/AUTO)
+          风险提示: 高风险写操作;封锁 IP 会拦截其所有流量
+          是否任务型: 否
+        - batch_delete_blockip
+          用途: 批量解封攻击者 IP
+          必填: items(数组,每项含 srcIP、dstIP 等字段)
+          风险提示: 写操作,解封恶意 IP 存在安全风险
+          是否任务型: 否
+        - clear_blockip
+          用途: 清空封锁攻击者 IP 列表
+          必填: 无
+          常用: creator(AF/SIP,指定清除哪类封锁)
+          风险提示: 高风险写操作;会清除所有封锁 IP
+          是否任务型: 否
+        - get_blockip_auto_config
+          用途: 获取自动封锁攻击者时长配置
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - set_blockip_auto_config
+          用途: 修改自动封锁攻击者时长
+          必填: blockTime(封锁时长,单位秒)
+          风险提示: 写操作,影响自动封锁策略
+          是否任务型: 否
+      enum:
+        - get_blackwhitelist
+        - add_blackwhitelist
+        - batch_add_blackwhitelist
+        - delete_blackwhitelist
+        - batch_delete_blackwhitelist
+        - get_blockip_list
+        - batch_add_blockip
+        - batch_delete_blockip
+        - clear_blockip
+        - get_blockip_auto_config
+        - set_blockip_auto_config
+
+    # Blacklist/whitelist params
+    url:
+      type: string
+      description: IP地址、域名或URL(黑白名单条目值)
+    type:
+      type: string
+      description: "名单类型:BLACK(黑名单)或 WHITE(白名单)"
+      enum:
+        - BLACK
+        - WHITE
+    enable:
+      type: boolean
+      description: 是否启用该条目,默认 true
+    description:
+      type: string
+      description: 条目描述信息(最大95字符)
+    domain:
+      type: integer
+      description: "条目类型:0=IP地址,1=域名,2=URL"
+      enum: [0, 1, 2]
+    items:
+      type: array
+      items:
+        type: object
+      description: 批量操作时的条目数组,每项至少包含 url(黑白名单)或 srcIP/dstIP(封锁IP)
+
+    # Block IP params
+    fuzzyIP:
+      type: string
+      description: 模糊搜索IP关键字(最大15字符)
+    creator:
+      type: string
+      description: "封锁来源身份:AF(防火墙自身)或 SIP(安全感知平台)"
+      enum:
+        - AF
+        - SIP
+    aifwType:
+      type: string
+      description: "添加封锁IP的类型:MANUAL(手动)或 AUTO(自动,需要 creator=SIP)"
+      enum:
+        - MANUAL
+        - AUTO
+    blockTime:
+      type: integer
+      description: 自动封锁时长(秒)
+
+    # Pagination
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    _sortby:
+      type: string
+      description: 排序字段名
+    _order:
+      type: string
+      description: "排序方向:asc(升序)或 desc(降序)"
+      enum:
+        - asc
+        - desc
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: ops
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_status.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_status.yaml
new file mode 100644
index 00000000..a1c01a44
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_status.yaml
@@ -0,0 +1,91 @@
+name: sangfor_af_v85_status
+description: >
+  Sangfor AF v8.0.85 device status tool. Query system resource usage
+  (CPU, memory, disk), firmware version, network interface status,
+  current time, and system uptime.
+description_cn: >
+  深信服 AF v8.0.85 状态中心工具。通过 `action` 参数查询系统资源(CPU/内存/磁盘)、
+  固件版本、网口状态、当前时间及系统运行时长等信息。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        状态查询动作名,可选值:
+        - get_memory_usage
+          用途: 获取当前内存使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_cpu_usage
+          用途: 获取当前 CPU 使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_disk_usage
+          用途: 获取当前磁盘使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_system_version
+          用途: 获取 AF 系统固件版本信息
+          必填: 无
+          常用: filter(ALL/FULL/MAJOR/MINOR 等)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_interface_status
+          用途: 获取指定网口或全部网口的状态(流速、连接状态)
+          必填: 无
+          常用: interfaceNames(如 eth0,不传则获取全部)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_runtime_status
+          用途: 获取系统运行时长(uptime)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_current_time
+          用途: 获取设备当前时间
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_memory_usage
+        - get_cpu_usage
+        - get_disk_usage
+        - get_system_version
+        - get_interface_status
+        - get_runtime_status
+        - get_current_time
+    filter:
+      type: string
+      description: >
+        版本信息过滤(仅用于 get_system_version):
+        ALL=显示所有,FULL=完整版本号,MAJOR=主版本号,MINOR=次版本号,
+        INCREASE=增版本号,BUILD=创建日期,EN=是否英文版,HF=是否HF版,B=是否Beta版
+      enum:
+        - ALL
+        - FULL
+        - MAJOR
+        - MINOR
+        - INCREASE
+        - BUILD
+        - EN
+        - HF
+        - B
+        - R
+        - ADD
+    interfaceNames:
+      type: string
+      description: 网口名称(如 eth0),用于 get_interface_status;不填则获取全部接口
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: status
diff --git a/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_system.yaml b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_system.yaml
new file mode 100644
index 00000000..1ad154a9
--- /dev/null
+++ b/.flocks/flockshub/plugins/tools/api/sangfor_af_v8_0_85/sangfor_af_v85_system.yaml
@@ -0,0 +1,53 @@
+name: sangfor_af_v85_system
+description: >
+  Sangfor AF v8.0.85 system management tool. Query and manage administrator
+  accounts on the AF device.
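+# Read-only surface: only get_accounts / get_account are exposed; this tool
+# defines no account create/update/delete actions.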
+description_cn: >
+  深信服 AF v8.0.85 系统管理工具。通过 `action` 参数查询和管理 AF 设备上的
+  管理员账户信息。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        系统管理动作名,可选值:
+        - get_accounts
+          用途: 查询所有管理员账户列表
+          必填: 无
+          常用: _start、_length、enable
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_account
+          用途: 查询指定管理员账户详情
+          必填: name(账户名)
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_accounts
+        - get_account
+    name:
+      type: string
+      description: 管理员账户名(用于 get_account)
+    enable:
+      type: boolean
+      description: 按启用/禁用状态过滤账户
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量
+    _search:
+      type: string
+      description: 模糊搜索关键字
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: system
diff --git a/.flocks/mcp_list.json.example b/.flocks/mcp_list.json.example
index a8008aee..62dc48a3 100644
--- a/.flocks/mcp_list.json.example
+++ b/.flocks/mcp_list.json.example
@@ -74,6 +74,16 @@
       "license": "Proprietary",
       "stars": 0,
       "transport": "remote",
+      "remote": {
+        "url": "https://mcp.threatbook.cn/mcp",
+        "transport": "auto",
+        "auth": {
+          "type": "apikey",
+          "location": "query",
+          "param_name": "apikey",
+          "value": "{secret:threatbook_mcp_key}"
+        }
+      },
       "env_vars": {
         "THREATBOOK_MCP_KEY": {
           "required": true,
@@ -99,6 +109,16 @@
       "license": "Proprietary",
       "stars": 0,
       "transport": "remote",
+      "remote": {
+        "url": "https://mcp.nsfocus.cn/mcp",
+        "transport": "auto",
+        "auth": {
+          "type": "apikey",
+          "location": "query",
+          "param_name": "apikey",
+          "value": "{secret:nsfocus_mcp_key}"
+        }
+      },
       "env_vars": {
         "NSFOCUS_MCP_KEY": {
           "required": true,
@@ -124,6 +144,16 @@
       "license": "Proprietary",
       "stars": 0,
       "transport": "remote",
+      "remote": {
+        "url": "https://mcp.ti.qianxin.com/ti-stream-mcp",
+        "transport": "auto",
+        "auth": {
+          "type": "apikey",
+          "location": "query",
+          "param_name": "apikey",
+          "value": "{secret:qianxin_mcp_key}"
+        }
+      },
       "env_vars": {
         "QIANXIN_MCP_KEY": {
           "required": true,
diff --git a/.flocks/plugins/skills/onesec-use/SKILL.md b/.flocks/plugins/skills/onesec-use/SKILL.md
index eebaa433..f34dc887 100644
--- a/.flocks/plugins/skills/onesec-use/SKILL.md
+++ b/.flocks/plugins/skills/onesec-use/SKILL.md
@@ -31,9 +31,16 @@ description: 用于处理 OneSEC 终端安全平台相关任务,适合通过AP
 ## API模式使用指南
+- 时间字段硬规则:凡是本次 API 调用涉及 `time_from` / `time_to` / `begin_time` / `end_time` 等时间字段,必须先在 bash 中执行 `uv run python` 动态计算,再调用对应 tool;禁止手动估算、禁止硬编码、禁止把“今天”“最近 7 天”“最近 24 小时”等自然语言时间直接脑补成数字。
+- OneSEC 的时间入参默认按 Unix 秒级时间戳处理;如果需要“今天”“本周”“最近 N 天”这类窗口,也必须先用 `uv run python` 算出准确边界后再传参。
 - 用户说“威胁事件”“最近有什么事件”“事件详情”“有哪些事件”时,优先走 `onesec_edr` 的事件类 action
 - 用户说“终端告警”“告警日志”“查某终端的告警”“查某进程行为”“行为记录”“时间线”“IOC”“恶意文件”时,优先走 `onesec_edr` 的告警/日志类 action
 - 用户说“DNS 拦截”“DNS 告警”“解析日志”“受威胁终端”“域名放行/阻断”时,优先走 `onesec_dns`
+- 涉及任何时间查询时,必须动态计算 `time_from` / `time_to` / `begin_time` / `end_time`,禁止手动估算时间戳
+- 查询单个域名是否被 DNS 拦截时,优先构造 `dns_search_blocked_queries + domain + time_from + time_to`;未显式给 `keyword` 时,工具会默认复用 `domain`
+- DNS 查询优先使用 Unix 秒级时间戳;如果用户给的是常见日期字符串(如 `2026-05-08 00:00:00`),工具会自动换算
+- 查询指定域名或关键字的 DNS 拦截明细时,优先使用 `dns_search_blocked_queries`
+- `dns_get_recent_blocked_queries` 只用于最近 24 小时增量拉取,不支持 `domain` / `keyword` / `private_ip` / `threat_type` / 分页参数
 - 用户说“安装了什么软件”“哪些终端装了某软件”时,优先走 `onesec_software`
 - 用户说“终端管理”“任务列表”“任务执行进度”“审计日志”“策略范围”时,优先走 `onesec_ops`
 - 用户说“病毒库版本”“病毒扫描”“停止扫描”“升级病毒库”时,优先走 `onesec_threat`
diff --git a/.flocks/plugins/skills/onesec-use/references/api-reference.md
b/.flocks/plugins/skills/onesec-use/references/api-reference.md index a1f14e39..7c4d1881 100644 --- a/.flocks/plugins/skills/onesec-use/references/api-reference.md +++ b/.flocks/plugins/skills/onesec-use/references/api-reference.md @@ -9,7 +9,7 @@ OneSEC 当前优先复用 grouped tool,而不是直接从页面做数据获取 | 查威胁事件 | `onesec_edr` | `edr_get_incidents`(默认) / `edr_get_recent_incidents`(仅最近 24 小时增量) | 建议显式带 `time_from`、`time_to` | | 查终端告警 | `onesec_edr` | `edr_get_endpoint_alerts` / `edr_get_recent_endpoint_alerts` | 常见至少带 `time_from`、`time_to` | | 查恶意文件 / 威胁行为 / 时间线 / IOC | `onesec_edr` | 对应 `edr_get_*` action | 时间范围、分页、筛选字段 | -| 查 DNS 拦截 / 解析日志 / 受威胁终端 | `onesec_dns` | `dns_search_blocked_queries` / `dns_search_queries` / `dns_search_threatened_endpoint` | 多数需要 `time_from`、`time_to` | +| 查 DNS 拦截 / 解析日志 / 受威胁终端 | `onesec_dns` | `dns_search_blocked_queries` / `dns_get_recent_blocked_queries` / `dns_search_queries` / `dns_search_threatened_endpoint` | 多数需要 `time_from`、`time_to` | | 查软件清单或安装终端 | `onesec_software` | `software_query_page_list` / `software_query_agent_list` | 软件终端查询要 `name` + `publisher` | | 查终端、任务、审计 | `onesec_ops` | `ops_query_agent_page_list` / `ops_query_task_page_list` / `ops_query_audit_log` | 审计/任务通常要时间范围 | | 查病毒库 / 下发扫描 | `onesec_threat` | `threat_query_bd_version` / `threat_virus_scan` | 查询通常空参,写操作需明确目标终端 | @@ -21,32 +21,46 @@ OneSEC 当前优先复用 grouped tool,而不是直接从页面做数据获取 - 时间字段多数为秒级 Unix 时间戳 - 查询类 action 默认优先;写操作只有在用户明确授权时才执行 -## 时间参数规则 +## 时间参数注意事项(重点) -OneSEC 高频查询动作通常使用: +### 时间参数计算 +调用任何时间相关 OneSEC API 时,必须**动态计算**时间戳,禁止手动估算。 -- `time_from` -- `time_to` +**错误方法(禁止)** -单位: +```python +# 手动估算,硬编码 +time_from = 1741536000 +time_to = 1741622400 +``` + +**正确方法** -- 秒级时间戳,不是毫秒 +```python +import datetime + +# 按执行时刻动态计算最近 24 小时窗口 +now = datetime.datetime.now() +time_to = int(now.timestamp()) +time_from = int((now - datetime.timedelta(hours=24)).timestamp()) +``` + +使用动态计算的时间戳调用工具执行 +``` +onesec_dns(action="dns_get_recent_blocked_queries", time_from=time_from, time_to=time_to) +``` +如果要查“今天”或“最近 7 天”,也必须动态计算,不要手填固定时间戳。 + +### 时间参数规则 +- 秒级时间戳,不是毫秒,UTC+8 时区 - 分页接口建议显式传 `time_from`、`time_to` - `recent` 系列只适合最近 24 小时的增量查询 +- DNS `dns_search_queries` 也只支持最近 24 小时内的数据 +- `edr_get_threat_files`、`edr_get_threat_activities`、`edr_get_incidents`、`edr_get_endpoint_alerts` 的时间窗口最长三个月 +- `ops_query_audit_log` 仅支持最近 30 天内的审计日志 - 未传时间时,返回范围由服务端默认窗口决定,仅作兜底,不推荐依赖 -示例: - -```json -{ - "action": "edr_get_incidents", - "time_from": 1741536000, - "time_to": 1741622400, - "cur_page": 1, - "page_size": 20 -} -``` ## 时间窗口选择表 @@ -107,8 +121,8 @@ OneSEC 中几个相邻页面经常被混用,建议先按语义路由: ```json { "action": "edr_get_incidents", - "time_from": 1741536000, - "time_to": 1741622400, + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, "cur_page": 1, "page_size": 20 } @@ -142,8 +156,8 @@ OneSEC 中几个相邻页面经常被混用,建议先按语义路由: ```json { "action": "edr_get_endpoint_alerts", - "time_from": 1741536000, - "time_to": 1741622400, + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, "sql": "threat.level = 'attack'", "cur_page": 1, "page_size": 20 @@ -183,8 +197,8 @@ OneSEC 中几个相邻页面经常被混用,建议先按语义路由: ```json { "action": "edr_get_threat_files", - "time_from": 1741536000, - "time_to": 1741622400, + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, "cur_page": 1, "page_size": 20 } @@ -202,6 +216,7 @@ OneSEC 中几个相邻页面经常被混用,建议先按语义路由: - `edr_get_recent_threat_timeline` 也需要 `incident_id` - 如果还没有 `incident_id`,应先调用 `edr_get_incidents` - `edr_get_recent_threat_timeline` 仅适合最近 24 小时内的增量时间线查询 +- `edr_get_threat_files`、`edr_get_threat_activities` 等分页接口按文档时间窗口最长三个月 ### 4. 
查询威胁处置清单 @@ -238,6 +253,7 @@ OneSEC 中几个相邻页面经常被混用,建议先按语义路由: 高频 action: - `dns_search_blocked_queries` +- `dns_get_recent_blocked_queries` - `dns_search_queries` - `dns_search_threatened_endpoint` - `dns_get_public_ip_list` @@ -247,10 +263,28 @@ DNS 拦截记录示例: ```json { "action": "dns_search_blocked_queries", - "time_from": 1741536000, - "time_to": 1741622400, + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, "domain": "evil.com", - "keyword": "evil" + "keyword": "evil", + "show_unblocked_threat": 1 +} +``` + +如果用户只明确给了一个完整域名,优先把 `domain` 和 `keyword` 都设成该域名;当前工具在缺少 `keyword` 时也会默认复用 `domain`。 + +如果用户没有给域名/关键字,只有 `public_ip` + 时间范围,并且查询目标是最近 24 小时拦截记录,优先改用 `dns_get_recent_blocked_queries`;不要硬套 `dns_search_blocked_queries`。 + +DNS 近期拦截增量示例: + +```json +{ + "action": "dns_get_recent_blocked_queries", + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, + "block_reason": "threat", + "show_unblocked_threat": 1, + "threat_level": [2, 3, 4] } ``` @@ -259,8 +293,8 @@ DNS 解析日志示例: ```json { "action": "dns_search_queries", - "time_from": 1741536000, - "time_to": 1741622400, + "time_from": 动态计算秒级时间戳, + "time_to": 动态计算秒级时间戳, "domain": "evil.com", "page_items_num": 20 } @@ -269,7 +303,12 @@ DNS 解析日志示例: 注意: - 有些 DNS action 对时间窗口要求严格 +- DNS 查询优先使用 Unix 秒级时间戳;当前工具也会兼容常见日期字符串,如 `YYYY-MM-DD HH:MM:SS` +- `public_ip` 按文档是数组;当前工具也会兼容单个字符串并自动包装成单元素数组 +- 查询具体域名或关键字的 DNS 拦截明细时,优先使用 `dns_search_blocked_queries` +- `dns_get_recent_blocked_queries` 仅适合最近 24 小时增量拉取,不支持 `domain`、`keyword`、`private_ip`、`threat_type` 和分页参数 - `page_items_num` 与 `page_size` 不是同一个字段 +- DNS 拦截结果优先读取 `result` 字段;当前工具也会补充 `is_blocked` - 目标地址列表增删改是写操作,不要误用 ### 6. 查询软件资产 @@ -326,8 +365,8 @@ DNS 解析日志示例: ```json { "action": "ops_query_audit_log", - "begin_time": 1741536000, - "end_time": 1741622400, + "begin_time": 动态计算的 begin_time 秒级时间戳, + "end_time": 动态计算的 end_time 秒级时间戳, "cur_page": 1, "page_size": 20 } @@ -339,8 +378,8 @@ DNS 解析日志示例: { "action": "ops_query_task_page_list", "time_type": "create_time", - "begin_time": 1741536000, - "end_time": 1741622400, + "begin_time": 动态计算的 begin_time 秒级时间戳, + "end_time": 动态计算的 end_time 秒级时间戳, "auto": 0, "cur_page": 1, "page_size": 20 @@ -350,7 +389,10 @@ DNS 解析日志示例: 注意: - 审计和任务查询通常要求时间范围 +- `ops_query_audit_log` 仅支持最近 30 天内的日志 - `ops_query_task_page_list` 还要求 `time_type` 和 `auto` +- `time_type` 仅支持 `create_time`、`update_time` +- `auto` 仅支持 `0`(人工响应)和 `1`(自动响应) ### 8. 
病毒库与扫描任务 @@ -381,6 +423,9 @@ DNS 解析日志示例: - 这是写操作 - 扫描范围越大,对终端影响越大 +- `task_type` 仅支持 `10110`(快速扫描)、`10120`(全盘扫描)、`10130`(自定义扫描) +- `scanmode` 仅支持 `1`(极速)、`2`(均衡)、`3`(低耗) +- `threat_update_bd_version` 的 `os_platform` 仅支持 `windows/macos`,macOS 架构仅支持 `Apple Silicon/Intel Chip` ## 高风险写操作清单 diff --git a/.flocks/plugins/skills/qingteng-use/SKILL.md b/.flocks/plugins/skills/qingteng-use/SKILL.md index ba82be8f..c8c007dc 100644 --- a/.flocks/plugins/skills/qingteng-use/SKILL.md +++ b/.flocks/plugins/skills/qingteng-use/SKILL.md @@ -31,6 +31,8 @@ description: 用于处理青藤云安全平台相关任务,适合通过API或 ## API 模式使用指南 +- 时间字段硬规则:凡是本次 API 调用涉及 `time_range` / `begin_time` / `end_time` / `logTime` / `loginTime` / `createTime` / `time` 等时间字段,必须先在 bash 中执行 `uv run python` 动态计算,再调用对应 tool;禁止手动估算、禁止硬编码、禁止把“今天”“最近 7 天”等自然语言时间直接脑补成参数。 +- 青藤查询类时间字段要按对应 schema 要求传值;`DateRange` 类型必须先用 `uv run python` 生成 `yyyy-MM-dd HH:mm:ss - yyyy-MM-dd HH:mm:ss` 格式字符串后再传入,不能手写。 - 用户说“主机资产”“进程”“账号”“端口”“网站”“数据库”“安装包”时,优先走 `qingteng_assets` - 用户说“可疑操作”“暴力破解”“异常登录”“WebShell”“后门”“蜜罐”时,优先走 `qingteng_detect` - 用户说“补丁”“风险”“弱密码”“风险文件”“漏洞扫描”“作业管理”时,优先走 `qingteng_risk` diff --git a/.flocks/plugins/skills/skyeye-use/SKILL.md b/.flocks/plugins/skills/skyeye-use/SKILL.md index 284d793c..8a49c897 100644 --- a/.flocks/plugins/skills/skyeye-use/SKILL.md +++ b/.flocks/plugins/skills/skyeye-use/SKILL.md @@ -30,6 +30,8 @@ description: 用于处理 SkyEye/天眼/网神分析平台相关任务,适合 ## API模式使用指南 +- 时间字段硬规则:凡是本次 API 调用涉及 `start_time` / `end_time` 等时间字段,必须先在 bash 中执行 `uv run python` 动态计算,再调用对应 tool;禁止手动估算、禁止硬编码、禁止把“今天”“最近 7 天”“告警当天”这类自然语言时间直接脑补成数字。 +- SkyEye 的时间入参默认按 Unix 毫秒级时间戳处理;即使只是下载报告、PCAP 或样本,也必须先用 `uv run python` 算出准确的毫秒级时间窗口后再传参。 - 用户说“告警列表”“最近告警”“按威胁级别筛选告警”时,优先走 `skyeye_alarm_list` - 用户说“告警字段有哪些”“攻击阶段有哪些”“需要枚举值”时,优先走 `skyeye_alarm_params` - 用户说“看板”“概览”“趋势”“整体视图”“系统状态”时,优先走 `skyeye_dashboard_view` diff --git a/.flocks/plugins/skills/tdp-use/SKILL.md b/.flocks/plugins/skills/tdp-use/SKILL.md index 2ac89aac..719970fe 100644 --- a/.flocks/plugins/skills/tdp-use/SKILL.md +++ b/.flocks/plugins/skills/tdp-use/SKILL.md @@ -33,6 +33,8 @@ description: 用于处理 TDP 威胁检测平台相关任务,适合通过API > 必须阅读: API 参数和适用场景见 [references/api-reference.md](references/api-reference.md)。 +- 时间字段硬规则:凡是本次 API 调用涉及 `time_from` / `time_to` 等时间字段,必须先在 bash 中执行 `uv run python` 动态计算,再调用对应 tool;禁止手动估算、禁止硬编码、禁止把“今天”“最近 7 天”“最近 1 小时”等自然语言时间直接脑补成数字。 +- TDP 的时间入参默认按 Unix 秒级时间戳处理;无论是日志、事件、看板还是风险类查询,只要要传时间范围,都必须先用 `uv run python` 算出准确边界后再传参。 - API 调用必须以当前 tool schema 为准,优先使用 schema 暴露的顶层语义化参数;列表类工具常见 `keyword`、`severity`、`cur_page`、`page_size`、`sort_by`,但 `tdp_log_search.sql` 是过滤表达式不是完整 SQL,禁止 `SELECT/FROM`,控制返回数量用 `size`,`terms` 可不传 `sql`,外部攻击结果筛选用 `result_list`。 - 用户说“告警”“告警记录”“告警日志”“明细记录”“查某 IP 的告警”时,默认走 `tdp_log_search` - 用户说“看板”“概览”“趋势”“统计”时,先用 `tdp_dashboard_status` diff --git a/.flocks/plugins/skills/tool-builder/SKILL.md b/.flocks/plugins/skills/tool-builder/SKILL.md index f0917228..4eb49e46 100644 --- a/.flocks/plugins/skills/tool-builder/SKILL.md +++ b/.flocks/plugins/skills/tool-builder/SKILL.md @@ -414,40 +414,88 @@ async def my_tool(ctx: ToolContext, query: str, limit: int = 10) -> ToolResult: **You MUST run these steps after creating any tool. 
Do NOT skip or defer.** -### Step 1: Static Validation +### Step 0: Metadata & Handler Audit (MUST run first) -Run via bash to confirm file syntax: +This step exists because the loader silently accepts many degraded +configurations — invalid `category` is coerced to `custom`, missing +`type` on a parameter falls back to `string`, undeclared placeholders in +the URL substitute to an empty string, etc. By the time the smoke test +runs, the symptoms (404, "missing field", empty result) no longer point +back to the missing piece of metadata. + +Run the bundled validator before anything else. It is self-contained +(stdlib + pyyaml only) and inspects the tool file *plus* its +`_provider.yaml` and any script handler: -**YAML-HTTP tools (Mode A):** ```bash -uv run python -c " -import yaml, sys -from pathlib import Path -path = Path('$TOOL_PATH').expanduser() -data = yaml.safe_load(path.read_text(encoding='utf-8')) -assert 'name' in data, 'Missing name field' -assert 'handler' in data, 'Missing handler section' -handler = data['handler'] -assert handler.get('type') in ('http', 'script'), 'handler.type must be http or script' -print(f'PASS: YAML valid, name={data[\"name\"]}, handler.type={handler[\"type\"]}') -" +SKILL_DIR="$(realpath ~/.flocks/plugins/skills/tool-builder)" +uv run python "$SKILL_DIR/validator.py" "$TOOL_PATH" ``` -**Python tools (Mode B):** +The validator checks (this list is enforced, not aspirational): + +**Metadata (every mode)** +- `name` is present, snake_case, not colliding with a built-in tool +- `description` is present and long enough to be useful +- `category` is one of `file | terminal | browser | code | search | system | custom` + — the loader silently coerces invalid values to `custom` +- `enabled: true` is set explicitly so the tool is active immediately +- For tools under `api/`: a `provider` field or a `_provider.yaml` is reachable + +**Parameters / inputSchema** +- `inputSchema` or `parameters` is declared (not both) +- Every property has a `type` and a `description` +- Every name listed in `required:` is also defined in `properties` +- A required parameter never also has a `default` + +**YAML-HTTP handler** +- `handler.type` is `http` (or `script`) +- `handler.url` is present +- Every `{placeholder}` in url / headers / query_params / body matches a + declared parameter (or `{base_url}` when a `_provider.yaml` provides it) +- `response.error_mapping` keys are integers +- `{secret:xxx}` references are surfaced so you can confirm they exist + +**YAML-script handler** +- `script_file` resolves to an existing file under `~/.flocks/plugins/` +- `function` exists in that file as `async def` +- The function signature accepts `(ctx, ...)` and every YAML parameter + is either a named arg or `**kwargs` +- The script imports `ToolResult` + +**`_provider.yaml`** (when the tool lives under `api/{provider}/`) +- File exists in the expected location +- `name`, `description` are present; `description_cn` is recommended +- `defaults.base_url` is set (otherwise `{base_url}` substitution silently + produces `/path` and every request 404s) +- `auth.secret` and `auth.inject_as` are set when an `auth:` block exists + +**Python tools (Mode B)** +- `from flocks.tool.registry import ...` is present +- `@ToolRegistry.register_function` is on at least one function +- The decorator carries `name`, `description`, `category`, `parameters` +- Every `ToolParameter(name=...)` matches an actual function argument + (and the function is `async def`, with `ctx` as the first parameter) +- The function 
returns a `ToolResult(...)`
+
+The output is a per-section report ending with
+`Summary: N FAIL, M WARN`. **Do not proceed past this step until
+`FAIL` is `0`.** Fix the file and re-run the validator. WARN items
+should also be addressed unless you have a deliberate reason to leave
+them — note the reason when reporting back.
+
+For a CI-style check that fails on warnings too:
+
 ```bash
-uv run python -c "
-import py_compile
-py_compile.compile('$TOOL_PATH', doraise=True)
-print('PASS: Python syntax valid')
-"
+uv run python "$SKILL_DIR/validator.py" --strict "$TOOL_PATH"
 ```
-If validation fails: **fix the file immediately** and re-validate.
-
-### Step 2: Load Test
+### Step 1: Load Test
 Attempt to load the tool into the registry to catch schema/handler errors:
+A tool is not considered complete unless it can be successfully discovered, loaded, and registered by the tool system — not just written to disk or passing static validation.
+
 **YAML-HTTP tools (Mode A):**
 ```bash
 uv run python -c "
@@ -496,7 +544,7 @@ else:
 If load fails: **read the error, fix the root cause**, and re-run.
-### Step 3: Smoke Test
+### Step 2: Smoke Test
 Execute the tool with **safe, minimal test parameters** to confirm end-to-end functionality:
@@ -568,7 +616,7 @@
 - **Python tools**: use the simplest valid input that exercises the happy path. For destructive tools (delete, write), create a temp file first.
 - If the tool requires an API key that hasn't been configured yet, the smoke test may return an auth error — that's **acceptable**. Report it to the user and note the tool structure is correct.
-### Step 4: Report Results
+### Step 3: Report Results
 After all steps, summarize to the user:
@@ -576,8 +624,9 @@
 Tool created: {name}
 Mode: {A/B}
 Path: {file_path}
-  Static validation: PASS
+  Metadata & handler audit: {PASS/WARN/FAIL} — {N FAIL, M WARN}
   Load test: PASS
+  Tool system registration: PASS
   Smoke test: {PASS/WARN/FAIL} — {details}
   Hot-reload: automatic (file watcher active — no restart or manual refresh needed)
@@ -608,3 +657,4 @@
 10. For API integrations: endpoint inventory completed; every discovered endpoint is implemented or explicitly skipped
 11. Tool name / filename / function name preserve endpoint vocabulary
 12. API tool output does not drop upstream fields unless the user explicitly asked for a reduced result
+13. **Step 0 of the Verification Protocol (validator.py) was run and ended with `0 FAIL`** — every WARN is either fixed or explicitly justified in your report
diff --git a/.flocks/plugins/skills/tool-builder/validator.py b/.flocks/plugins/skills/tool-builder/validator.py
new file mode 100644
index 00000000..c8b29f54
--- /dev/null
+++ b/.flocks/plugins/skills/tool-builder/validator.py
@@ -0,0 +1,1037 @@
+#!/usr/bin/env python3
+"""
+Tool plugin validator for the tool-builder skill.
+
+Audits a freshly created Flocks tool for missing or inconsistent
+metadata and handler information. Designed to be invoked by the
+tool-builder skill's Verification Protocol *before* declaring a tool
+ready for use.
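+
+The audit is purely static: files are read and parsed (``yaml.safe_load``,
+``ast.parse``); the tool under test is never imported or executed.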
+ +Detects, among other things: + * Missing tool name / description / category / enabled + * Invalid category (silently coerced to ``custom`` by the loader) + * Missing inputSchema / parameters + * Parameter properties without ``type`` or ``description`` + * For YAML-HTTP tools: missing handler section, missing url/method, + URL/header/body placeholders that are not declared as parameters, + secret references, ``response.error_mapping`` keys that are not int + * For YAML-script tools: missing ``script_file``, file not on disk, + missing or non-async / non-callable ``function`` symbol + * For tools under ``api/{provider}/``: missing or incomplete + ``_provider.yaml`` (name / description / description_cn / + defaults.base_url, auth.secret/inject_as) + * For Python tools: missing ``@ToolRegistry.register_function`` + decorator, missing decorator kwargs (name/description/parameters), + parameter list that does not match the function signature, function + that is not ``async def`` + +Usage: + uv run python validator.py + +Exit codes: + 0 — no FAIL items (WARN allowed) + 1 — at least one FAIL item OR validator could not run +""" + +from __future__ import annotations + +import argparse +import ast +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Set + +try: + import yaml +except ImportError: + print("FAIL: pyyaml not installed. Run `uv pip install pyyaml`.", file=sys.stderr) + sys.exit(1) + + +# --------------------------------------------------------------------------- +# Constants mirrored from flocks.tool.registry / flocks.tool.tool_loader +# --------------------------------------------------------------------------- + +VALID_CATEGORIES = { + "file", "terminal", "browser", "code", + "search", "system", "custom", +} + +VALID_PARAMETER_TYPES = { + "string", "integer", "number", "boolean", "array", "object", +} + +# Tool name collisions with built-ins or reserved words. 
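+# Membership in this set is a hard FAIL in _validate_yaml_metadata below;
+# keep it in sync with the built-in tool registry.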
+RESERVED_TOOL_NAMES = { + "read", "write", "edit", "multiedit", "apply_patch", "glob", "list_tool", + "file_search", "doc_parser", + "bash", "grep", "codesearch", "lsp_tool", + "webfetch", "websearch", + "delegate_task", "call_omo_agent", + "task", "task_center", "todo", "plan", + "run_workflow", "run_workflow_node", + "echo", "get_time", + "skill", "question", +} + +PARAM_PATTERN = re.compile(r"\{([^}]+)\}") +SECRET_PATTERN = re.compile(r"\{secret:([^}]+)\}") +SNAKE_CASE_RE = re.compile(r"^[a-z][a-z0-9_]*$") + + +# --------------------------------------------------------------------------- +# Diagnostic record +# --------------------------------------------------------------------------- + + +@dataclass +class Issue: + level: str # "FAIL" | "WARN" | "PASS" + section: str + message: str + + +@dataclass +class Report: + target: Path + mode: str = "" + issues: List[Issue] = field(default_factory=list) + + def add(self, level: str, section: str, message: str) -> None: + self.issues.append(Issue(level=level, section=section, message=message)) + + def fail(self, section: str, message: str) -> None: + self.add("FAIL", section, message) + + def warn(self, section: str, message: str) -> None: + self.add("WARN", section, message) + + def ok(self, section: str, message: str) -> None: + self.add("PASS", section, message) + + @property + def fail_count(self) -> int: + return sum(1 for i in self.issues if i.level == "FAIL") + + @property + def warn_count(self) -> int: + return sum(1 for i in self.issues if i.level == "WARN") + + def render(self) -> str: + lines = [f"=== Validation report: {self.target} ==="] + if self.mode: + lines.append(f"Mode: {self.mode}") + lines.append("") + + sections: Dict[str, List[Issue]] = {} + for issue in self.issues: + sections.setdefault(issue.section, []).append(issue) + + for section, items in sections.items(): + lines.append(f"[{section}]") + for it in items: + lines.append(f" {it.level:<4} {it.message}") + lines.append("") + + lines.append( + f"Summary: {self.fail_count} FAIL, {self.warn_count} WARN" + ) + if self.fail_count == 0 and self.warn_count == 0: + lines[-1] += " — looks good." + elif self.fail_count == 0: + lines[-1] += " — fix WARN items if you want a clean report." + else: + lines[-1] += " — fix FAIL items before declaring the tool ready." + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# YAML tool validation +# --------------------------------------------------------------------------- + + +def _looks_like_yaml(path: Path) -> bool: + return path.suffix.lower() in {".yaml", ".yml"} + + +def _tool_under_api(yaml_path: Path) -> bool: + """Return True if this YAML lives under an ``api/`` directory. + + Accepts both the canonical install location + (``~/.flocks/plugins/tools/api/...``) and ad-hoc layouts during + development where the user runs the validator directly on a file + that simply has ``api`` as one of its parent directories. + """ + for parent in yaml_path.resolve().parents: + if parent.name == "api": + return True + return False + + +def _provider_dir(yaml_path: Path) -> Optional[Path]: + """If the YAML is under ``api/{provider}/foo.yaml``, return the provider dir.""" + parent = yaml_path.parent + if parent.name == "api": + # Standalone tool directly under api/ — no provider dir. 
+ return None + grandparent = parent.parent + if grandparent.name == "api": + return parent + return None + + +def validate_yaml_tool(yaml_path: Path) -> Report: + report = Report(target=yaml_path) + try: + raw_text = yaml_path.read_text(encoding="utf-8") + except OSError as e: + report.fail("File", f"Cannot read file: {e}") + return report + + try: + data = yaml.safe_load(raw_text) + except yaml.YAMLError as e: + report.fail("Syntax", f"YAML parse error: {e}") + return report + + if not isinstance(data, dict): + report.fail("Syntax", "Top-level YAML must be a mapping/dict") + return report + + handler_raw = data.get("handler") + handler_type = ( + handler_raw.get("type", "http").lower() + if isinstance(handler_raw, dict) else None + ) + report.mode = ( + f"YAML-HTTP (Mode A, handler.type={handler_type})" + if handler_type else "YAML (handler missing)" + ) + + _validate_yaml_metadata(data, report, yaml_path) + parameter_names = _validate_yaml_parameters(data, report) + _validate_yaml_handler(data, report, yaml_path, parameter_names) + _validate_provider_yaml(yaml_path, data, report) + + return report + + +def _validate_yaml_metadata( + data: Dict[str, Any], report: Report, yaml_path: Path, +) -> None: + section = "Metadata" + + name = data.get("name") + if not name or not isinstance(name, str): + report.fail(section, "missing required field 'name'") + else: + if not SNAKE_CASE_RE.match(name): + report.warn(section, f"name '{name}' is not snake_case") + if name in RESERVED_TOOL_NAMES: + report.fail( + section, + f"name '{name}' collides with a built-in tool — pick another", + ) + stem = yaml_path.stem + if stem != name: + report.warn( + section, + f"YAML filename '{stem}.yaml' does not match name '{name}'", + ) + report.ok(section, f"name = {name}") + + description = data.get("description") + if not description or not str(description).strip(): + report.fail(section, "missing or empty 'description'") + elif len(str(description).strip()) < 20: + report.warn( + section, + f"description is only {len(str(description).strip())} chars — " + "the LLM uses this to decide when to invoke the tool", + ) + else: + report.ok(section, f"description present ({len(str(description))} chars)") + + category = data.get("category") + if category is None: + # API tools commonly inherit category from _provider.yaml; only warn. + report.warn( + section, + "no 'category' set — loader will fall back to 'custom' " + "(or provider defaults if present)", + ) + elif category not in VALID_CATEGORIES: + report.fail( + section, + f"category '{category}' is invalid; loader silently coerces " + f"to 'custom'. Valid: {sorted(VALID_CATEGORIES)}", + ) + else: + report.ok(section, f"category = {category}") + + enabled = data.get("enabled") + if enabled is None: + report.warn( + section, + "no 'enabled' field — defaults to true; set explicitly so " + "the tool is unambiguously activated", + ) + elif enabled is not True: + report.warn( + section, + f"enabled = {enabled!r} — tool will NOT be active immediately. 
" + "Set 'enabled: true' unless the user asked for it disabled.", + ) + else: + report.ok(section, "enabled = true") + + # provider field — required for API service card display + if _tool_under_api(yaml_path): + provider = data.get("provider") + prov_dir = _provider_dir(yaml_path) + if not provider and not prov_dir: + report.warn( + section, + "tool is under api/ but neither 'provider' field nor " + "a provider subdirectory with _provider.yaml is present " + "— it will not appear as an API service card", + ) + elif provider: + report.ok(section, f"provider = {provider}") + + +def _validate_yaml_parameters(data: Dict[str, Any], report: Report) -> Set[str]: + """Validate inputSchema/parameters and return the set of declared param names.""" + section = "Parameters" + declared: Set[str] = set() + + input_schema = data.get("inputSchema") + params_list = data.get("parameters") + + if input_schema is None and params_list is None: + report.warn( + section, + "no inputSchema or parameters declared — " + "the LLM will not be able to pass arguments", + ) + return declared + + if input_schema is not None and params_list is not None: + report.warn( + section, + "both 'inputSchema' and 'parameters' are present; " + "'inputSchema' wins — drop 'parameters' to avoid confusion", + ) + + if isinstance(input_schema, dict): + if input_schema.get("type") != "object": + report.warn( + section, + f"inputSchema.type = {input_schema.get('type')!r}; " + "should be 'object' for tool inputs", + ) + properties = input_schema.get("properties") or {} + if not isinstance(properties, dict) or not properties: + report.warn(section, "inputSchema.properties is empty") + else: + required = set(input_schema.get("required") or []) + for pname, pinfo in properties.items(): + declared.add(pname) + _validate_param_entry( + pname, pinfo, in_input_schema=True, + is_required=pname in required, report=report, + ) + unknown_required = required - declared + for pname in unknown_required: + report.fail( + section, + f"required parameter '{pname}' is not defined in properties", + ) + if declared: + report.ok(section, f"inputSchema declares: {sorted(declared)}") + return declared + + if isinstance(params_list, list): + if not params_list: + report.warn(section, "'parameters' list is empty") + for item in params_list: + if not isinstance(item, dict): + report.fail(section, f"parameter entry is not a dict: {item!r}") + continue + pname = item.get("name") + if not pname: + report.fail(section, f"parameter entry missing 'name': {item!r}") + continue + declared.add(pname) + _validate_param_entry( + pname, item, in_input_schema=False, + is_required=item.get("required", True), report=report, + ) + if declared: + report.ok(section, f"parameters declares: {sorted(declared)}") + return declared + + report.fail( + section, + f"inputSchema/parameters has unexpected type: " + f"{type(input_schema or params_list).__name__}", + ) + return declared + + +def _validate_param_entry( + pname: str, + pinfo: Dict[str, Any], + in_input_schema: bool, + is_required: bool, + report: Report, +) -> None: + section = "Parameters" + if not isinstance(pinfo, dict): + report.fail(section, f"parameter '{pname}' is not a mapping") + return + + ptype = pinfo.get("type") + if not ptype: + report.warn( + section, + f"parameter '{pname}' missing 'type' " + "(loader falls back to 'string')", + ) + elif ptype not in VALID_PARAMETER_TYPES: + report.fail( + section, + f"parameter '{pname}' has invalid type '{ptype}'. 
" + f"Valid: {sorted(VALID_PARAMETER_TYPES)}", + ) + + description = pinfo.get("description") + if not description or not str(description).strip(): + report.warn( + section, + f"parameter '{pname}' missing 'description' — " + "the LLM cannot reliably fill it in", + ) + + if is_required and "default" in pinfo: + report.warn( + section, + f"parameter '{pname}' is required but also has a default — " + "default is ignored when required=true", + ) + + +def _validate_yaml_handler( + data: Dict[str, Any], + report: Report, + yaml_path: Path, + parameter_names: Set[str], +) -> None: + section = "Handler" + handler = data.get("handler") + execution = data.get("execution") + + if not isinstance(handler, dict): + if isinstance(execution, dict): + report.fail( + section, + "uses inline 'execution' block — disabled for safety. " + "Use handler.type=script with a separate handler file.", + ) + else: + report.fail( + section, + "missing 'handler' section — loader will refuse to register " + "the tool", + ) + return + + htype = handler.get("type", "http") + if htype not in {"http", "script"}: + report.fail( + section, + f"handler.type = {htype!r}; must be 'http' or 'script'", + ) + return + + if htype == "http": + _validate_http_handler(handler, report, parameter_names, yaml_path) + else: + _validate_script_handler(handler, report, parameter_names, yaml_path) + + +def _validate_http_handler( + handler: Dict[str, Any], + report: Report, + parameter_names: Set[str], + yaml_path: Path, +) -> None: + section = "Handler" + + method = handler.get("method") + if not method: + report.warn(section, "no 'method' set — loader defaults to GET") + elif str(method).upper() not in {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}: + report.warn(section, f"unusual HTTP method: {method!r}") + + url = handler.get("url") + if not url: + report.fail(section, "handler.url is empty — request would target ''") + else: + report.ok(section, f"handler.url = {url}") + prov_dir = _provider_dir(yaml_path) + # When {base_url} is used, _provider.yaml MUST supply defaults.base_url. 
+ if "{base_url}" in url and prov_dir is None: + report.fail( + section, + "url uses {base_url} but the tool is not under " + "api/{provider}/ — there is no _provider.yaml to inject it", + ) + + # Collect placeholders across url/headers/query_params/body + declarable: Set[str] = set(parameter_names) + declarable.add("base_url") + referenced = set() + + def _scan(value: Any) -> None: + if isinstance(value, str): + for m in PARAM_PATTERN.findall(value): + if not m.startswith("secret:"): + referenced.add(m) + elif isinstance(value, dict): + for v in value.values(): + _scan(v) + elif isinstance(value, list): + for v in value: + _scan(v) + + _scan(handler.get("url")) + _scan(handler.get("headers")) + _scan(handler.get("query_params")) + _scan(handler.get("body")) + + undeclared = referenced - declarable + for name in sorted(undeclared): + report.fail( + section, + f"placeholder '{{{name}}}' is referenced in url/headers/" + f"query_params/body but not declared as a parameter — " + "loader will substitute an empty string", + ) + + unused = parameter_names - referenced - {"base_url"} + for name in sorted(unused): + report.warn( + section, + f"parameter '{name}' is declared but never used in " + "url/headers/query_params/body", + ) + + response = handler.get("response") + if isinstance(response, dict): + error_mapping = response.get("error_mapping") or {} + if isinstance(error_mapping, dict): + for k in error_mapping.keys(): + try: + int(k) + except (TypeError, ValueError): + report.fail( + section, + f"response.error_mapping key {k!r} is not an int", + ) + + # Detect secret refs and remind user. + secret_refs: Set[str] = set() + + def _scan_secret(value: Any) -> None: + if isinstance(value, str): + for m in SECRET_PATTERN.findall(value): + secret_refs.add(m) + elif isinstance(value, dict): + for v in value.values(): + _scan_secret(v) + elif isinstance(value, list): + for v in value: + _scan_secret(v) + + _scan_secret(handler) + for s in sorted(secret_refs): + report.warn( + "Secrets", + f"references {{secret:{s}}} — confirm it exists in " + "~/.flocks/config/.secret.json", + ) + + +def _validate_script_handler( + handler: Dict[str, Any], + report: Report, + parameter_names: Set[str], + yaml_path: Path, +) -> None: + section = "Handler" + + script_file = handler.get("script_file") + function_name = handler.get("function") or "handle" + + if not script_file: + report.fail(section, "handler.script_file is missing") + return + + script_path = (yaml_path.parent / script_file).resolve() + if not script_path.is_file(): + report.fail( + section, + f"handler script file not found: {script_path}", + ) + return + report.ok(section, f"script_file resolved to {script_path}") + + try: + source = script_path.read_text(encoding="utf-8") + tree = ast.parse(source) + except (OSError, SyntaxError) as e: + report.fail(section, f"cannot parse script file: {e}") + return + + target_fn: Optional[ast.AST] = None + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name == function_name: + target_fn = node + break + + if target_fn is None: + report.fail( + section, + f"function '{function_name}' not found in {script_path.name}", + ) + return + + if not isinstance(target_fn, ast.AsyncFunctionDef): + report.fail( + section, + f"function '{function_name}' must be 'async def' so the loader " + "can await it", + ) + else: + report.ok(section, f"function = {function_name} (async)") + + args = target_fn.args + pos_args = [a.arg for a in args.args] + if not pos_args 
or pos_args[0] != "ctx": + report.fail( + section, + f"function '{function_name}' first parameter must be 'ctx' " + f"(got {pos_args!r})", + ) + + kwarg_names = set(pos_args[1:]) | {a.arg for a in args.kwonlyargs} + has_var_kw = args.kwarg is not None + + if not has_var_kw: + missing = parameter_names - kwarg_names + for name in sorted(missing): + report.fail( + section, + f"parameter '{name}' is declared in the YAML but not " + f"accepted by '{function_name}({', '.join(pos_args)})'. " + "Add it to the signature or accept **kwargs.", + ) + + # Detect imports of ToolResult — warn if missing. + has_toolresult_import = any( + isinstance(node, ast.ImportFrom) + and node.module == "flocks.tool.registry" + and any(alias.name == "ToolResult" for alias in node.names) + for node in ast.walk(tree) + ) + if not has_toolresult_import: + report.warn( + section, + "script does not import ToolResult from flocks.tool.registry; " + "the loader will wrap the return value but explicit ToolResult " + "is recommended", + ) + + +def _validate_provider_yaml( + yaml_path: Path, data: Dict[str, Any], report: Report, +) -> None: + section = "Provider" + prov_dir = _provider_dir(yaml_path) + if prov_dir is None: + return + + provider_file = prov_dir / "_provider.yaml" + if not provider_file.is_file(): + report.fail( + section, + f"_provider.yaml is missing at {provider_file} — required for " + "the API service card to render", + ) + return + report.ok(section, f"_provider.yaml found at {provider_file}") + + try: + prov_data = yaml.safe_load(provider_file.read_text(encoding="utf-8")) + except yaml.YAMLError as e: + report.fail(section, f"_provider.yaml parse error: {e}") + return + if not isinstance(prov_data, dict): + report.fail(section, "_provider.yaml must be a mapping/dict") + return + + if not prov_data.get("name"): + report.fail(section, "_provider.yaml missing 'name'") + if not prov_data.get("description"): + report.fail(section, "_provider.yaml missing 'description' (English)") + if not prov_data.get("description_cn"): + report.warn( + section, + "_provider.yaml missing 'description_cn' — Chinese UI will fall " + "back to English", + ) + + defaults = prov_data.get("defaults") or {} + if not isinstance(defaults, dict): + report.fail(section, "_provider.yaml.defaults must be a mapping") + defaults = {} + if not defaults.get("base_url"): + report.fail( + section, + "_provider.yaml.defaults.base_url is missing — handler urls " + "using {base_url} will resolve to '/path'", + ) + if "category" not in defaults and not data.get("category"): + report.warn( + section, + "_provider.yaml.defaults.category is missing and the tool also " + "has no category — loader falls back to 'custom'", + ) + + auth = prov_data.get("auth") + if auth is None: + report.warn( + section, + "_provider.yaml has no 'auth' block — that is fine for " + "open APIs, but most providers need a credential", + ) + elif isinstance(auth, dict): + if not auth.get("secret"): + report.fail(section, "_provider.yaml.auth.secret is missing") + inject_as = auth.get("inject_as") + if inject_as not in {"header", "query_param", None}: + report.fail( + section, + f"_provider.yaml.auth.inject_as = {inject_as!r}; " + "must be 'header' or 'query_param'", + ) + if inject_as == "query_param" and not auth.get("param_name"): + report.warn( + section, + "_provider.yaml.auth.param_name missing — defaults to 'api_key'", + ) + + +# --------------------------------------------------------------------------- +# Python tool validation +# 
--------------------------------------------------------------------------- + + +def validate_python_tool(py_path: Path) -> Report: + report = Report(target=py_path, mode="Python (Mode B)") + try: + source = py_path.read_text(encoding="utf-8") + except OSError as e: + report.fail("File", f"Cannot read file: {e}") + return report + + try: + tree = ast.parse(source) + except SyntaxError as e: + report.fail("Syntax", f"SyntaxError: {e}") + return report + + has_registry_import = any( + isinstance(node, ast.ImportFrom) + and node.module == "flocks.tool.registry" + and any( + alias.name in {"ToolRegistry", "ToolResult", "ToolContext"} + for alias in node.names + ) + for node in ast.walk(tree) + ) + if not has_registry_import: + report.fail( + "Imports", + "missing `from flocks.tool.registry import ...` " + "(need at least ToolRegistry, ToolResult, ToolContext)", + ) + else: + report.ok("Imports", "imports flocks.tool.registry") + + decorated = list(_find_register_function_targets(tree)) + if not decorated: + report.fail( + "Decorator", + "no @ToolRegistry.register_function decorator found — " + "the tool will not be registered on import", + ) + return report + report.ok( + "Decorator", + f"found {len(decorated)} @ToolRegistry.register_function " + f"target(s)", + ) + + for fn_node, decorator_call in decorated: + _validate_python_decorated_function(fn_node, decorator_call, report) + + return report + + +def _find_register_function_targets(tree: ast.AST) -> Iterable: + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Call): + continue + func = decorator.func + if ( + isinstance(func, ast.Attribute) + and func.attr == "register_function" + and isinstance(func.value, ast.Name) + and func.value.id == "ToolRegistry" + ): + yield node, decorator + + +def _kwarg_value(call: ast.Call, name: str) -> Optional[ast.AST]: + for kw in call.keywords: + if kw.arg == name: + return kw.value + return None + + +def _const_str(node: Optional[ast.AST]) -> Optional[str]: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + return None + + +def _validate_python_decorated_function( + fn_node: ast.AST, call: ast.Call, report: Report, +) -> None: + section = f"Function {getattr(fn_node, 'name', '')}" + + name = _const_str(_kwarg_value(call, "name")) + description = _const_str(_kwarg_value(call, "description")) + category_node = _kwarg_value(call, "category") + parameters_node = _kwarg_value(call, "parameters") + + if not name: + report.fail(section, "decorator missing or non-string 'name='") + else: + if not SNAKE_CASE_RE.match(name): + report.warn(section, f"name '{name}' is not snake_case") + if name in RESERVED_TOOL_NAMES: + report.fail( + section, + f"name '{name}' collides with a built-in tool — pick another", + ) + report.ok(section, f"name = {name}") + + if not description or not description.strip(): + report.fail(section, "decorator missing or empty 'description='") + elif len(description.strip()) < 20: + report.warn( + section, + f"description is only {len(description.strip())} chars — " + "the LLM uses this to decide when to invoke the tool", + ) + else: + report.ok(section, f"description present ({len(description.strip())} chars)") + + if category_node is None: + report.warn( + section, + "no 'category=' set — defaults to ToolCategory.CUSTOM", + ) + elif isinstance(category_node, ast.Attribute): + # ToolCategory.SOMETHING — accept; full 
validation requires runtime. + report.ok(section, f"category = ToolCategory.{category_node.attr}") + elif isinstance(category_node, ast.Constant) and isinstance(category_node.value, str): + if category_node.value not in VALID_CATEGORIES: + report.fail( + section, + f"category={category_node.value!r} is not in " + f"{sorted(VALID_CATEGORIES)}", + ) + else: + report.ok(section, f"category = {category_node.value!r}") + + declared_params: List[str] = [] + if parameters_node is None: + report.warn( + section, + "no 'parameters=' provided — the tool exposes zero arguments", + ) + elif isinstance(parameters_node, ast.List): + if not parameters_node.elts: + report.warn(section, "'parameters=[]' is empty") + for elt in parameters_node.elts: + if isinstance(elt, ast.Call) and _is_tool_parameter_call(elt): + pname = _const_str(_kwarg_value(elt, "name")) + if not pname: + # Try positional first arg. + if elt.args and isinstance(elt.args[0], ast.Constant): + pname = elt.args[0].value + if not pname: + report.fail(section, "ToolParameter() entry without 'name'") + continue + declared_params.append(pname) + + ptype_node = _kwarg_value(elt, "type") + if ptype_node is None: + report.warn( + section, + f"parameter '{pname}' missing 'type=' " + "(defaults will not work — type is required)", + ) + + pdesc = _const_str(_kwarg_value(elt, "description")) + if not pdesc or not pdesc.strip(): + report.warn( + section, + f"parameter '{pname}' missing 'description=' — " + "the LLM cannot reliably fill it in", + ) + if declared_params: + report.ok(section, f"parameters = {declared_params}") + + # Function signature checks + args = fn_node.args + pos_args = [a.arg for a in args.args] + is_async = isinstance(fn_node, ast.AsyncFunctionDef) + if not is_async: + report.fail( + section, + f"function '{fn_node.name}' must be 'async def'", + ) + else: + report.ok(section, f"function '{fn_node.name}' is async def") + if not pos_args or pos_args[0] != "ctx": + report.fail( + section, + f"function '{fn_node.name}' first parameter must be 'ctx' " + f"(got {pos_args!r})", + ) + else: + report.ok(section, f"signature = ({', '.join(pos_args)})") + + kwarg_names = set(pos_args[1:]) | {a.arg for a in args.kwonlyargs} + has_var_kw = args.kwarg is not None + if not has_var_kw and declared_params: + missing = set(declared_params) - kwarg_names + for name in sorted(missing): + report.fail( + section, + f"parameter '{name}' is declared in the decorator but not " + f"accepted by '{fn_node.name}({', '.join(pos_args)})'", + ) + unused = kwarg_names - set(declared_params) + for name in sorted(unused): + if name in {"self", "cls"}: + continue + report.warn( + section, + f"function arg '{name}' is not declared as a ToolParameter", + ) + + # Detect ToolResult return. 
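+    # Best-effort static scan: it recognises `return ToolResult(...)` and
+    # attribute forms such as `return registry.ToolResult(...)`. A ToolResult
+    # built elsewhere and returned through a variable is invisible to this
+    # walk, which is why a miss below is only a warning, never a failure.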
+ returns_toolresult = False + for node in ast.walk(fn_node): + if isinstance(node, ast.Return) and isinstance(node.value, ast.Call): + func = node.value.func + if isinstance(func, ast.Name) and func.id == "ToolResult": + returns_toolresult = True + break + if ( + isinstance(func, ast.Attribute) + and func.attr == "ToolResult" + ): + returns_toolresult = True + break + if not returns_toolresult: + report.warn( + section, + "no 'return ToolResult(...)' detected — loader will wrap the " + "return value, but explicit ToolResult is recommended", + ) + else: + report.ok(section, "returns ToolResult(...)") + + +def _is_tool_parameter_call(call: ast.Call) -> bool: + func = call.func + if isinstance(func, ast.Name) and func.id == "ToolParameter": + return True + if isinstance(func, ast.Attribute) and func.attr == "ToolParameter": + return True + return False + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser( + description=( + "Validate a Flocks tool YAML or Python file for missing " + "metadata and handler information." + ) + ) + parser.add_argument( + "path", + help="Path to the tool YAML or Python file to validate", + ) + parser.add_argument( + "--strict", + action="store_true", + help="Treat WARN as failure (exit code 1 when any WARN exists)", + ) + args = parser.parse_args(argv) + + target = Path(args.path).expanduser() + if not target.exists(): + print(f"FAIL: file not found: {target}", file=sys.stderr) + return 1 + + if _looks_like_yaml(target): + report = validate_yaml_tool(target) + elif target.suffix == ".py": + report = validate_python_tool(target) + else: + print( + f"FAIL: unsupported file type {target.suffix!r}; " + "expected .yaml/.yml/.py", + file=sys.stderr, + ) + return 1 + + print(report.render()) + if report.fail_count > 0: + return 1 + if args.strict and report.warn_count > 0: + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.flocks/plugins/skills/web2cli/SKILL.md b/.flocks/plugins/skills/web2cli/SKILL.md index bf0b2232..8dbf1218 100644 --- a/.flocks/plugins/skills/web2cli/SKILL.md +++ b/.flocks/plugins/skills/web2cli/SKILL.md @@ -1,6 +1,6 @@ --- name: web2cli -description: 使用统一的 Web2CLI 流程捕获网站的 XHR/Fetch 请求,并生成可复用的 CLI、Markdown 文档和 Postman 集合。支持 `agent-browser` 与 `cdp-direct` 两种模式:前者适合独立浏览器会话,后者复用用户 Chromium 系浏览器登录态与 CDP 能力。适用于复现登录后操作、沉淀接口调用样例,或基于页面操作生成自动化工具时。 +description: 使用统一的 Web2CLI 流程捕获网站的 XHR/Fetch 请求,并生成可复用的 CLI、Markdown 文档。支持 `agent-browser` 与 `cdp-direct` 两种模式:前者适合独立浏览器会话,后者复用用户 Chromium 系浏览器登录态与 CDP 能力。适用于复现登录后操作、沉淀接口调用样例,或基于页面操作生成自动化工具时。 required: browser-use --- @@ -63,8 +63,10 @@ mkdir -p "$CAPTURE_ROOT/captures" - 浏览器内存中的原始捕获数据:`window.__capturedRequests` - 导出的接口抓包 JSON:`$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json` - 浏览器认证状态:`$CAPTURE_ROOT/auth-state.json` +- 操作适配规格:`$CAPTURE_ROOT/web2cli-spec.json` - 站点自适应 Hook(仅当 base 失败时创建):`$CAPTURE_ROOT/hook.js` - 生成的 CLI 工具:`$CAPTURE_ROOT/_cli.py`,`generate-cli.py` 会把 `-` 等非 Python 模块名字符替换为 `_` +- 生成的验证材料:`$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json` - 生成的接口文档:`$CAPTURE_ROOT/${CAPTURE_NAME}_api.md` - 生成的 Postman 集合:`$CAPTURE_ROOT/${CAPTURE_NAME}_postman.json` @@ -323,35 +325,73 @@ jq -r '.[].method' "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" | sort | un jq '.[] | select(.method == "POST") | {url: .url, body: .requestBody}' 
"$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" ``` -### 8. 生成 CLI 工具 +### 8. 生成 web2cli-spec 规格 -基于 `"$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json"` 与 `"$CAPTURE_ROOT/auth-state.json"` 生成新的 CLI 工具。 +先基于 `"$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json"` 生成中间契约层 `web2cli-spec.json`。 ```bash -uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ +uv run python .flocks/plugins/skills/web2cli/scripts/generate-spec.py \ "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" \ - --format python \ --base-url "https://example.com" \ + --output "$CAPTURE_ROOT/web2cli-spec.json" +``` + +`web2cli-spec.json` 是抓包结果到最终 CLI 之间的可编辑契约,包含: + +- 目标站点与命令名 +- 鉴权策略(如 `PUBLIC` / `COOKIE` / `HEADER`) +- 主请求的 method、endpoint、query/body 模板 +- CLI 参数定义 +- 固定输出列定义 +- 验证材料初稿 + +生成后必须检查并按需修正: + +- `strategy` 是否正确 +- `args` 是否符合实际操作意图 +- `columns` 与字段路径是否对应目标数据 +- `verify` 的最少行数、必填列是否合理 + +### 9. 基于 spec 生成 CLI 工具 + +从 `"$CAPTURE_ROOT/web2cli-spec.json"` 生成最终 CLI。 + +```bash +uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ + --format python \ --output "$CAPTURE_ROOT/${CAPTURE_NAME}_cli.py" ``` 如果 `CAPTURE_NAME` 包含 `-` 等不能作为 Python 模块名的字符,生成器会自动规范化输出文件名,例如 `test-domain_cli.py` 会写为 `test_domain_cli.py`,并在命令输出中打印实际路径。 -如需同时产出文档可继续执行: +生成验证文件: ```bash uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ - "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ + --format verify \ + --output "$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json" +``` + +生成接口文档: + +```bash +uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ --format markdown \ --title "${CAPTURE_NAME} API Documentation" \ --output "$CAPTURE_ROOT/${CAPTURE_NAME}_api.md" ``` -### 9. CLI工具验证 和浏览器关闭 +### 10. CLI工具验证 和浏览器关闭 根据生成的 CLI ,任意选择一个接口调用测试可用性 - CLI 工具可用性 - 认证状态可用性 +- `verify.json` 的输出约束是否满足 + +推荐先查看 `"$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json"`,再用生成的 CLI 以默认参数执行一次,确认固定输出列与认证状态都正确。 当验证完成,确保 CLI 可用后关闭浏览器或 Tab @@ -382,12 +422,12 @@ else: `cdp-direct` 必须保留用户原有的 tab 不受影响。 -### 10. summary +### 11. 
+
+生成后必须检查并按需修正:
+
+- `strategy` 是否正确
+- `args` 是否符合实际操作意图
+- `columns` 与字段路径是否对应目标数据
+- `verify` 的最少行数、必填列是否合理
+
+### 9. 基于 spec 生成 CLI 工具
+
+从 `"$CAPTURE_ROOT/web2cli-spec.json"` 生成最终 CLI。
+
+```bash
+uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \
+  --spec "$CAPTURE_ROOT/web2cli-spec.json" \
+  --format python \
  --output "$CAPTURE_ROOT/${CAPTURE_NAME}_cli.py"
```

如果 `CAPTURE_NAME` 包含 `-` 等不能作为 Python 模块名的字符,生成器会自动规范化输出文件名,例如 `test-domain_cli.py` 会写为 `test_domain_cli.py`,并在命令输出中打印实际路径。

-如需同时产出文档可继续执行:
+生成验证文件:

```bash
uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \
-  "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" \
+  --spec "$CAPTURE_ROOT/web2cli-spec.json" \
+  --format verify \
+  --output "$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json"
+```
+
+生成接口文档:
+
+```bash
+uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \
+  --spec "$CAPTURE_ROOT/web2cli-spec.json" \
  --format markdown \
  --title "${CAPTURE_NAME} API Documentation" \
  --output "$CAPTURE_ROOT/${CAPTURE_NAME}_api.md"
```

-### 9. CLI工具验证 和浏览器关闭
+### 10. CLI 工具验证和浏览器关闭

根据生成的 CLI ,任意选择一个接口调用测试可用性

- CLI 工具可用性
- 认证状态可用性
+- `verify.json` 的输出约束是否满足
+
+推荐先查看 `"$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json"`,再用生成的 CLI 以默认参数执行一次,确认固定输出列与认证状态都正确。

当验证完成,确保 CLI 可用后关闭浏览器或 Tab

@@ -382,12 +422,12 @@ else:

`cdp-direct` 必须保留用户原有的 tab 不受影响。

-### 10. summary
+### 11. summary

总结当前 生成 的CLI 工具有哪些能力,然后可提示用户下一步操作:

-- 保存为对应的 skill 方便后续操作
-- 精简 CLI
+- 精简或修正 CLI
- 进一步丰富 CLI 工具,重新开始 web2cli标准流程
+- 保存为对应的 skill 方便后续操作(进入此操作后,需要先阅读 references)

## 故障处理

@@ -419,3 +459,6 @@ else:
- `agent-browser`:重新登录后再次执行保存状态命令。
- `cdp-direct`:重新登录后再次执行保存认证状态。
+
+## Reference
+- `references/cli-in-skill.md`:将生成的 CLI 集成到 skill 中使用
diff --git a/.flocks/plugins/skills/web2cli/references/cli-in-skill.md b/.flocks/plugins/skills/web2cli/references/cli-in-skill.md
new file mode 100644
index 00000000..6210c364
--- /dev/null
+++ b/.flocks/plugins/skills/web2cli/references/cli-in-skill.md
@@ -0,0 +1,171 @@
+# 生成后的 CLI 如何接入 Skill
+
+> 本文只说明一件事:`web2cli` 已经生成出 CLI 之后,怎样把它整理成可长期维护的 skill 资产。
+
+## 命名约定
+
+生成阶段的文件名通常来自抓包名,例如 `<capture-name>_cli.py`。这个名字适合临时验证,不适合直接沉淀到 skill。
+
+落到 skill 时,统一改成**稳定的产品名**:
+
+- skill 目录:`$HOME/.flocks/plugins/skills/<product>-use/`
+- CLI 主脚本:`$HOME/.flocks/plugins/skills/<product>-use/scripts/<product>_cli.py`
+- 默认认证状态:`~/.flocks/browser/<product>/auth-state.json`
+
+约定说明:
+
+- `<product>` 用产品或系统的稳定标识,不用一次性任务名
+- 目录名可以保留 `-`,例如 `tdp-use`
+- Python 脚本名统一用 `_`,例如 `tdp_cli.py`
+- 不要把最终 CLI 保留成 `export_data_cli.py`、`test_capture_cli.py` 这类临时名字
+
+## 放到已有产品 Skill
+
+如果仓库里已经有对应产品 skill,直接把生成结果并入现有 skill:
+
+```bash
+SKILL_ROOT="$HOME/.flocks/plugins/skills/<product>-use"
+
+mkdir -p "$SKILL_ROOT/scripts"
+mkdir -p "$HOME/.flocks/browser/<product>"
+
+cp "$CAPTURE_ROOT/<capture-name>_cli.py" \
+  "$SKILL_ROOT/scripts/<product>_cli.py"
+
+cp "$CAPTURE_ROOT/auth-state.json" \
+  "$HOME/.flocks/browser/<product>/auth-state.json"
+```
+
+然后补齐这几项:
+
+1. 在 `scripts/config.py` 中把认证状态默认值指向 `~/.flocks/browser/<product>/auth-state.json`
+2. 在 `references/cli-reference.md` 中写清楚 CLI 用法、环境变量和示例
+3. 在 `references/browser-workflow.md` 中写清楚浏览器登录与保存 state 的流程
+4. 在 `SKILL.md` 中说明什么时候优先走 CLI,什么时候退回浏览器
+
+推荐的配置写法:
+
+```python
+import os
+from pathlib import Path
+
+AUTH_STATE_FILE = Path(
+    os.getenv(
+        "<PRODUCT>_AUTH_STATE",
+        Path.home() / ".flocks" / "browser" / "<product>" / "auth-state.json",
+    )
+)
+```
+
+这样做的好处是:
+
+- 默认行为统一,和现有产品 skill 保持一致
+- 允许用户用环境变量覆盖
+- 生成阶段的临时产物和最终长期使用的认证文件分离
+
+## 生成新的 Skill
+
+如果当前仓库里还没有对应产品 skill,就按下面的最小结构创建:
+
+```text
+$HOME/.flocks/plugins/skills/<product>-use/
+├── SKILL.md
+├── scripts/
+│   ├── <product>_cli.py
+│   └── config.py
+└── references/
+    ├── browser-workflow.md
+    └── cli-reference.md
+```
+
+其中 `SKILL.md` 必须遵守 Flocks 的标准 skill 格式:
+
+- 文件开头必须是 YAML frontmatter,第一行必须为 `---`
+- frontmatter 至少包含 `name` 和 `description`
+- `name` 使用稳定的 skill 标识,推荐与目录名一致,例如 `<product>-use`
+- frontmatter 结束后,再写正文标题、触发条件、模式判断和使用说明
+
+最小模板示例:
+
+```md
+---
+name: test-use
+description: 用于查询 Test 测试平台数据,支持通过 CLI 快速查询,认证失效时退回浏览器模式。
+---
+
+# Test Use
+
+## 触发条件
+
+- 用户提到 Test 平台
+- 用户需要查询 Test 数据
+
+## 模式判断
+
+### CLI 模式(默认)
+
+- 适用于快速查询和批量读取数据
+
+### 浏览器模式
+
+- 适用于需要页面交互、导出或重新登录的场景
+```
+
+不要把 `SKILL.md` 直接写成普通 Markdown 文档,例如下面这种缺少 frontmatter 的写法是无效的:
+
+```md
+# Test Use
+```
+
+各文件职责:
+
+- `SKILL.md`:定义触发条件、模式判断、总入口说明
+- `scripts/<product>_cli.py`:承载生成并整理后的 CLI 能力
+- `scripts/config.py`:集中管理 `BASE_URL`、`AUTH_STATE_FILE`、超时、SSL 等默认配置
+- `references/browser-workflow.md`:写浏览器登录、保存 state、认证恢复流程
+- `references/cli-reference.md`:写 CLI 参数、命令示例、常见查询
+
+新 skill 的原则也一样:先把生成的 CLI 改成稳定文件名,再把临时 `auth-state.json` 切换到全局默认位置 `~/.flocks/browser/<product>/auth-state.json`。
+
+## 认证失败怎么处理
+
+CLI 调用出现以下情况时,优先按认证失效处理:
+
+- 返回 `401` 或 `403`
+- 返回内容出现 `Unauthorized`、`login`、未登录、无权限
+- `auth-state.json` 已存在,但请求仍然被重定向到登录页
+
+处理原则:
+
+1. 不要无限重试 CLI
+2. 请求用户重新通过浏览器登录
+3. 登录完成后,重新保存认证状态到默认路径
+4. 再重试一次 CLI
+
+默认认证文件路径固定为:
+
+```bash
+~/.flocks/browser/<product>/auth-state.json
+```
+
+保存方式示例:
+
+```bash
+mkdir -p "$HOME/.flocks/browser/<product>"
+
+# agent-browser 模式
+agent-browser state save "$HOME/.flocks/browser/<product>/auth-state.json"
+
+# 或 cdp-direct / flocks browser 模式
+flocks browser state save "$HOME/.flocks/browser/<product>/auth-state.json"
+```
+
+如果用户重新登录并保存 state 后,CLI 仍然失败,再继续排查:
+
+- `BASE_URL` 是否写错
+- 当前账号是否确实有接口权限
+- 站点是否还有额外 header / token / csrf 依赖
+
+## 一句话原则
+
+`web2cli` 产出的 `<capture-name>_cli.py` 是临时结果;真正沉淀到 skill 时,要改成稳定产品名脚本,并把认证状态统一落到 `~/.flocks/browser/<product>/auth-state.json`。
diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py
index f6909d2b..3e41fff4 100644
--- a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py
+++ b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py
@@ -411,37 +411,515 @@ def generate_postman_collection(requests: List[Dict], base_url: str) -> Dict:
    return collection
+
+
+def load_spec(spec_path: str) -> Dict[str, Any]:
+    """Load a web2cli spec from disk."""
+    with open(spec_path, encoding="utf-8") as f:
+        payload = json.load(f)
+    if not isinstance(payload, dict):
+        raise ValueError("Spec file must contain a JSON object")
+    return payload
+
+
+def generate_verify_materials_from_spec(spec: Dict[str, Any]) -> Dict[str, Any]:
+    """Generate verify metadata from a web2cli spec."""
+    verify = spec.get("verify", {}) if isinstance(spec.get("verify"), dict) else {}
+    columns = spec.get("columns", [])
+    column_names = [column.get("name") for column in columns if isinstance(column, dict) and column.get("name")]
+
+    return {
+        "site": spec.get("site", ""),
+        "command": spec.get("command", ""),
+        "args": verify.get("args", {}),
+        "expect": {
+            "rowCount": verify.get("rowCount", {"min": 1}),
+            "columns": verify.get("columns", column_names),
+            "types": verify.get(
+                "types",
+                {
+                    column.get("name"): column.get("type", "string")
+                    for column in columns
+                    if isinstance(column, dict) and column.get("name")
+                },
+            ),
+            "notEmpty": verify.get("notEmpty", column_names[: min(3, len(column_names))]),
+            "patterns": verify.get("patterns", {}),
+        },
+    }
+
+
+def generate_markdown_docs_from_spec(spec: Dict[str, Any], title: str = "API Documentation") -> str:
+    """Generate Markdown documentation from a web2cli spec."""
+    operation = spec.get("operation", {})
+    args = spec.get("args", [])
+    columns = spec.get("columns", [])
+    verify = generate_verify_materials_from_spec(spec)
+
+    md = f"""# {title}
+
+> Auto-generated Web2CLI Specification
+> Site: `{spec.get("site", "")}`
+> Command: `{spec.get("command", "")}`
+
+## 概览
+
+- **描述**: {spec.get("description", "N/A")}
+- **策略**: `{spec.get("strategy", "PUBLIC")}`
+- **Base URL**: `{spec.get("baseUrl", "")}`
+- **Method**: `{operation.get("method", "GET")}`
+- **Endpoint**: `{operation.get("endpoint", "/")}`
+
+## 参数
+
+"""
+
+    if args:
+        md += "| 参数 | 类型 | 默认值 | 说明 |\n"
+        md += "|------|------|--------|------|\n"
+        for arg in args:
+            md += f"| `{arg.get('name', '')}` | `{arg.get('type', 'string')}` | `{arg.get('default', '')}` | {arg.get('help', '')} |\n"
+        md += "\n"
+    else:
+        md += "无参数。\n\n"
+
+    md += "## 输出列\n\n"
+    md += "| 列名 | 类型 | 路径 |\n"
+    md += "|------|------|------|\n"
+    for column in columns:
+        md += f"| `{column.get('name', '')}` | `{column.get('type', 'string')}` | `{column.get('path', '')}` |\n"
+
+    md += "\n## 验证建议\n\n"
+    md += f"- 默认参数: `{json.dumps(verify['args'], ensure_ascii=False)}`\n"
+    md += f"- 最少行数: `{verify['expect']['rowCount'].get('min', 0)}`\n"
+    md += f"- 
必填列: `{', '.join(verify['expect']['notEmpty'])}`\n" + + return md + + +def generate_postman_collection_from_spec(spec: Dict[str, Any]) -> Dict[str, Any]: + """Generate a minimal Postman collection from a web2cli spec.""" + operation = spec.get("operation", {}) + headers = operation.get("headers", {}) if isinstance(operation.get("headers"), dict) else {} + body_template = operation.get("bodyTemplate", {}) if isinstance(operation.get("bodyTemplate"), dict) else {} + endpoint = operation.get("endpoint", "/") + path_parts = endpoint.lstrip("/").split("/") if endpoint.lstrip("/") else [] + + request = { + "method": operation.get("method", "GET"), + "url": { + "raw": f"{{{{base_url}}}}{endpoint}", + "host": ["{{base_url}}"], + "path": path_parts, + }, + "header": [{"key": key, "value": value} for key, value in headers.items()], + } + if body_template: + request["body"] = { + "mode": "raw", + "raw": json.dumps(body_template, ensure_ascii=False), + "options": {"raw": {"language": "json"}}, + } + + return { + "info": { + "name": f"{spec.get('site', 'captured')} {spec.get('command', 'command')}", + "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json", + }, + "item": [ + { + "name": spec.get("command", endpoint), + "request": request, + } + ], + "variable": [{"key": "base_url", "value": spec.get("baseUrl", "")}], + } + + +def generate_python_cli_from_spec(spec: Dict[str, Any]) -> str: + """Generate a fixed command CLI script from a web2cli spec.""" + spec_json = json.dumps(spec, indent=2, ensure_ascii=False) + return '''#!/usr/bin/env python3 +""" +Auto-generated Web2CLI command script. +Generated from web2cli-spec.json +""" + +import argparse +import csv +import json +import sys +from typing import Any, Dict, List + +import requests + + +SPEC = ''' + spec_json + ''' + + +def _load_json(path: str) -> Dict[str, Any]: + if not path: + return {} + try: + with open(path, encoding="utf-8") as f: + payload = json.load(f) + except FileNotFoundError: + return {} + except json.JSONDecodeError: + return {} + return payload if isinstance(payload, dict) else {} + + +def _coerce_bool(value: str) -> bool: + normalized = str(value).strip().lower() + if normalized in {"1", "true", "yes", "y", "on"}: + return True + if normalized in {"0", "false", "no", "n", "off"}: + return False + raise argparse.ArgumentTypeError(f"invalid boolean value: {value}") + + +def _type_name(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, int) and not isinstance(value, bool): + return "int" + if isinstance(value, float): + return "float" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return "string" + + +class APIClient: + """Fixed command client generated from a web2cli spec.""" + + @staticmethod + def _load_cookie_items(auth_state_path: str) -> List[Dict[str, Any]]: + payload = _load_json(auth_state_path) + cookies = payload.get("cookies", []) + if isinstance(cookies, list): + return [cookie for cookie in cookies if isinstance(cookie, dict)] + return [] + + @staticmethod + def _load_storage_map(payload: Dict[str, Any]) -> Dict[str, str]: + values = {} + for origin_entry in payload.get("origins", []): + if not isinstance(origin_entry, dict): + continue + for item in origin_entry.get("localStorage", []): + if isinstance(item, dict) and item.get("name"): + values[item["name"]] = item.get("value", "") + return values + + @staticmethod + def _resolve_header_value(payload: Dict[str, Any], 
rule: Dict[str, Any]) -> str | None: + source = rule.get("source") + key = rule.get("key") + if source == "cookie": + for cookie in payload.get("cookies", []): + if isinstance(cookie, dict) and cookie.get("name") == key: + return str(cookie.get("value", "")) + if source == "localStorage": + return APIClient._load_storage_map(payload).get(str(key)) + return None + + @staticmethod + def _resolve_template(value: Any, args: Dict[str, Any]) -> Any: + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + return args.get(value[2:-1], value) + if isinstance(value, dict): + return {key: APIClient._resolve_template(item, args) for key, item in value.items()} + if isinstance(value, list): + return [APIClient._resolve_template(item, args) for item in value] + return value + + @staticmethod + def _tokenize_path(path: str) -> List[str]: + if not path or path == "$": + return [] + normalized = path + if normalized.startswith("$."): + normalized = normalized[2:] + elif normalized.startswith("$"): + normalized = normalized[1:] + normalized = normalized.replace("[]", ".[]") + return [token for token in normalized.split(".") if token] + + @classmethod + def _extract_many(cls, value: Any, path: str) -> List[Any]: + tokens = cls._tokenize_path(path) + current = [value] + for token in tokens: + next_values = [] + if token == "[]": + for item in current: + if isinstance(item, list): + next_values.extend(item) + else: + for item in current: + if isinstance(item, dict) and token in item: + next_values.append(item[token]) + current = next_values + if not current: + break + return current + + @classmethod + def _extract_first(cls, value: Any, path: str) -> Any: + if not path or path == "$": + return value + values = cls._extract_many(value, path) + return values[0] if values else None + + def __init__(self, base_url: str = SPEC.get("baseUrl", ""), auth_state: str = "auth-state.json"): + self.base_url = (base_url or SPEC.get("baseUrl", "")).rstrip("/") + self.auth_state_path = auth_state + self.auth_state = _load_json(auth_state) if auth_state else {} + self.session = requests.Session() + self._apply_auth_state() + + def _apply_auth_state(self) -> None: + strategy = SPEC.get("strategy", "PUBLIC") + auth = SPEC.get("auth", {}) + headers = SPEC.get("operation", {}).get("headers", {}) + if isinstance(headers, dict) and headers: + self.session.headers.update(headers) + + if strategy in {"COOKIE", "HEADER"}: + for cookie in self._load_cookie_items(self.auth_state_path): + name = cookie.get("name") + if not name: + continue + kwargs = {} + if cookie.get("domain"): + kwargs["domain"] = cookie["domain"] + if cookie.get("path"): + kwargs["path"] = cookie["path"] + self.session.cookies.set(name, cookie.get("value", ""), **kwargs) + + if strategy == "HEADER": + for rule in auth.get("requiredHeaders", []): + if not isinstance(rule, dict) or not rule.get("name"): + continue + value = self._resolve_header_value(self.auth_state, rule) + if value: + self.session.headers[str(rule["name"])] = value + + def build_request(self, args: Dict[str, Any]) -> Dict[str, Any]: + operation = SPEC.get("operation", {}) + endpoint = operation.get("endpoint", "/") + query = self._resolve_template(operation.get("queryTemplate", {}), args) + body = self._resolve_template(operation.get("bodyTemplate", {}), args) + return { + "method": operation.get("method", "GET"), + "url": f"{self.base_url}{endpoint}", + "params": query or None, + "json": body or None, + } + + def _project_rows(self, payload: Any) -> List[Dict[str, Any]]: + 
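+        """Project the response payload into flat rows per the spec.
+
+        Illustrative sketch (field names hypothetical): with
+        rowSource.collectionPath = "$.data.list[]" and a column
+        {"name": "id", "relativePath": "id"}, a payload of
+        {"data": {"list": [{"id": 1}, {"id": 2}]}} projects to
+        [{"id": 1}, {"id": 2}]. Column paths that start with "$."
+        are resolved against the whole payload; "__index__" yields
+        the 1-based row number.
+        """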
row_source = SPEC.get("rowSource", {}) + collection_path = row_source.get("collectionPath") or row_source.get("path") or "$" + collection = self._extract_many(payload, collection_path) if collection_path != "$" else [payload] + if not collection: + return [] + + rows = [] + columns = SPEC.get("columns", []) + for index, row in enumerate(collection, start=1): + projected = {} + for column in columns: + if not isinstance(column, dict) or not column.get("name"): + continue + rel_path = column.get("relativePath") or column.get("path") or "$" + if rel_path == "__index__": + value = index + elif rel_path.startswith("$."): + value = self._extract_first(payload, rel_path) + else: + value = self._extract_first(row, rel_path) + projected[column["name"]] = value + rows.append(projected) + return rows + + def run(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + request_options = self.build_request(args) + response = self.session.request( + request_options["method"], + request_options["url"], + params=request_options["params"], + json=request_options["json"], + ) + response.raise_for_status() + return self._project_rows(response.json()) + + +def verify_rows(rows: List[Dict[str, Any]], verify_spec: Dict[str, Any]) -> List[str]: + errors = [] + expect = verify_spec.get("expect", verify_spec) + row_count = expect.get("rowCount", {}) + min_rows = row_count.get("min") + max_rows = row_count.get("max") + + if min_rows is not None and len(rows) < min_rows: + errors.append(f"rowCount too small: expected >= {min_rows}, got {len(rows)}") + if max_rows is not None and len(rows) > max_rows: + errors.append(f"rowCount too large: expected <= {max_rows}, got {len(rows)}") + + columns = expect.get("columns", []) + types = expect.get("types", {}) + not_empty = expect.get("notEmpty", []) + patterns = expect.get("patterns", {}) + + for row in rows: + for column in columns: + if column not in row: + errors.append(f"missing column: {column}") + for column in not_empty: + if row.get(column) in (None, "", [], {}): + errors.append(f"empty required column: {column}") + for column, expected_type in types.items(): + if column in row and row[column] is not None and _type_name(row[column]) != expected_type: + errors.append( + f"type mismatch for {column}: expected {expected_type}, got {_type_name(row[column])}" + ) + for column, pattern in patterns.items(): + if column in row and row[column] is not None: + import re + if not re.search(pattern, str(row[column])): + errors.append(f"pattern mismatch for {column}: {pattern}") + + return errors + + +def _print_rows(rows: List[Dict[str, Any]], output_format: str) -> None: + if output_format == "json": + print(json.dumps(rows, ensure_ascii=False, indent=2)) + return + if not rows: + return + columns = list(rows[0].keys()) + if output_format == "csv": + writer = csv.DictWriter(sys.stdout, fieldnames=columns) + writer.writeheader() + writer.writerows(rows) + return + print("\\t".join(columns)) + for row in rows: + print("\\t".join("" if row.get(column) is None else str(row.get(column)) for column in columns)) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=SPEC.get("description", "Generated Web2CLI command")) + parser.add_argument("--base-url", default=SPEC.get("baseUrl", ""), help="Override base URL") + parser.add_argument( + "--auth-state", + default=(SPEC.get("auth", {}) or {}).get("stateFile", "auth-state.json"), + help="Path to auth state JSON", + ) + parser.add_argument("--format", choices=["json", "csv", "table"], 
default="json", help="Output format") + parser.add_argument("--verify", action="store_true", help="Validate rows against embedded or external verify spec") + parser.add_argument("--verify-spec", help="Optional verify JSON path") + for arg in SPEC.get("args", []): + if not isinstance(arg, dict) or not arg.get("name"): + continue + option = "--" + str(arg["name"]).replace("_", "-") + arg_type = arg.get("type", "string") + kwargs = { + "dest": arg["name"], + "default": arg.get("default"), + "help": arg.get("help", ""), + } + if arg_type == "int": + kwargs["type"] = int + elif arg_type == "float": + kwargs["type"] = float + elif arg_type == "bool": + kwargs["type"] = _coerce_bool + else: + kwargs["type"] = str + parser.add_argument(option, **kwargs) + return parser + + +def main() -> None: + parser = build_parser() + parsed = parser.parse_args() + runtime_args = { + item["name"]: getattr(parsed, item["name"]) + for item in SPEC.get("args", []) + if isinstance(item, dict) and item.get("name") + } + client = APIClient(base_url=parsed.base_url, auth_state=parsed.auth_state) + rows = client.run(runtime_args) + + if parsed.verify: + verify_spec = _load_json(parsed.verify_spec) if parsed.verify_spec else SPEC.get("verify", {}) + errors = verify_rows(rows, verify_spec) + if errors: + raise SystemExit("\\n".join(errors)) + + _print_rows(rows, parsed.format) + + +if __name__ == "__main__": + main() +''' + + def main(): - parser = argparse.ArgumentParser(description='Generate CLI/docs from captured APIs') - parser.add_argument('input', help='Input JSON file with captured requests') + parser = argparse.ArgumentParser(description='Generate CLI/docs from captured APIs or a web2cli spec') + parser.add_argument('input', nargs='?', help='Input JSON file with captured requests') + parser.add_argument('--spec', help='Input web2cli-spec.json file') parser.add_argument('--output', '-o', help='Output file') parser.add_argument('--base-url', '-u', default='https://example.com', help='Base URL') - parser.add_argument('--format', '-f', choices=['python', 'markdown', 'postman'], + parser.add_argument('--format', '-f', choices=['python', 'markdown', 'postman', 'verify'], default='markdown', help='Output format') parser.add_argument('--title', '-t', default='API Documentation', help='Document title') args = parser.parse_args() - # Load input - with open(args.input) as f: - data = json.load(f) - - # Handle both array and object formats - requests = data if isinstance(data, list) else data.get('requests', []) - - if not requests: - print("No requests found in input file", file=sys.stderr) - sys.exit(1) - - print(f"Processing {len(requests)} requests, {len(group_endpoints(requests))} unique endpoints...") - - # Generate output - if args.format == 'python': - output = generate_python_client(requests, args.base_url) - elif args.format == 'postman': - output = json.dumps(generate_postman_collection(requests, args.base_url), indent=2, ensure_ascii=False) + if not args.input and not args.spec: + parser.error('either input or --spec is required') + + if args.spec: + spec = load_spec(args.spec) + if args.format == 'python': + output = generate_python_cli_from_spec(spec) + elif args.format == 'verify': + output = json.dumps(generate_verify_materials_from_spec(spec), indent=2, ensure_ascii=False) + elif args.format == 'postman': + output = json.dumps(generate_postman_collection_from_spec(spec), indent=2, ensure_ascii=False) + else: + output = generate_markdown_docs_from_spec(spec, args.title) else: - output = 
generate_markdown_docs(requests, args.title) + # Load input + with open(args.input, encoding='utf-8') as f: + data = json.load(f) + + # Handle both array and object formats + requests = data if isinstance(data, list) else data.get('requests', []) + + if not requests: + print("No requests found in input file", file=sys.stderr) + sys.exit(1) + + print(f"Processing {len(requests)} requests, {len(group_endpoints(requests))} unique endpoints...") + + # Generate output + if args.format == 'python': + output = generate_python_client(requests, args.base_url) + elif args.format == 'postman': + output = json.dumps(generate_postman_collection(requests, args.base_url), indent=2, ensure_ascii=False) + elif args.format == 'verify': + print("verify output requires --spec", file=sys.stderr) + sys.exit(1) + else: + output = generate_markdown_docs(requests, args.title) # Write output output_path = args.output diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-spec.py b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py new file mode 100644 index 00000000..c7b01fdc --- /dev/null +++ b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +"""Generate a web2cli spec from captured API requests.""" + +from __future__ import annotations + +import argparse +import json +import keyword +import re +import sys +from pathlib import Path +from typing import Any +from urllib.parse import parse_qsl, urlparse + + +PAGE_PARAM_NAMES = {"page", "pageNo", "pageNum", "current", "pageIndex", "curPage"} +LIMIT_PARAM_NAMES = {"limit", "size", "pageSize", "page_size", "page_limit", "rows"} + + +def sanitize_name(name: str) -> str: + """Convert text to a valid Python/CLI-friendly identifier.""" + value = re.sub(r"\?.*$", "", name) + value = re.sub(r"[^a-zA-Z0-9_]", "_", value) + value = re.sub(r"_+", "_", value) + value = value.strip("_") + if value and value[0].isdigit(): + value = f"_{value}" + value = value.lower() or "endpoint" + if keyword.iskeyword(value): + value = f"{value}_" + return value + + +def load_requests(input_path: str) -> list[dict[str, Any]]: + """Load captured request list from disk.""" + with open(input_path, encoding="utf-8") as f: + payload = json.load(f) + + requests = payload if isinstance(payload, list) else payload.get("requests", []) + return [item for item in requests if isinstance(item, dict)] + + +def parse_json_text(text: str) -> Any: + """Parse a response/request body string when possible.""" + if not text: + return {} + + value = text.strip() + if value.endswith("...[truncated]"): + value = value[: -len("...[truncated]")] + + try: + return json.loads(value) + except json.JSONDecodeError: + return {"raw": text} + + +def infer_type(value: Any) -> str: + """Return a compact type name for spec/verify output.""" + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, int) and not isinstance(value, bool): + return "int" + if isinstance(value, float): + return "float" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return "string" + + +def normalize_url_info(request: dict[str, Any]) -> dict[str, Any]: + """Return normalized URL parts from capture metadata or raw URL.""" + url = ( + request.get("normalizedUrl") + or request.get("url") + or "" + ) + parsed = urlparse(url) + query_items = dict(parse_qsl(parsed.query, keep_blank_values=True)) + return { + "url": url, + "origin": request.get("origin") or (f"{parsed.scheme}://{parsed.netloc}" if 
parsed.scheme and parsed.netloc else ""), + "pathname": request.get("pathname") or (parsed.path or "/"), + "query": request.get("query") or query_items, + "queryKeys": request.get("queryKeys") or list(query_items.keys()), + "host": parsed.netloc, + } + + +def score_request(request: dict[str, Any], index: int) -> tuple[int, int]: + """Score a captured request to decide which one should become the spec.""" + score = 0 + response = parse_json_text(str(request.get("response", ""))) + action = ((request.get("actionContext") or {}).get("lastAction") or {}).get("action") + + status = request.get("status") + if isinstance(status, int) and 200 <= status < 300: + score += 30 + elif status == "error": + score -= 20 + + if request.get("captureReason") in {"nonGet", "captureModeAll", "includePattern"}: + score += 15 + if action: + score += 12 + if isinstance(response, dict) and "raw" not in response: + score += 20 + + collection = find_best_collection(response) + if collection is not None: + score += 20 + min(collection["length"], 20) + + return score, index + + +def choose_primary_request(requests: list[dict[str, Any]]) -> dict[str, Any]: + """Pick the best request candidate from the captured request list.""" + if not requests: + raise ValueError("No captured requests available") + ranked = sorted( + ((score_request(req, index), req) for index, req in enumerate(requests)), + key=lambda item: (item[0][0], item[0][1]), + reverse=True, + ) + return ranked[0][1] + + +def find_collections(value: Any, path: str = "$") -> list[dict[str, Any]]: + """Find likely row collections inside a JSON response.""" + results: list[dict[str, Any]] = [] + + if isinstance(value, list): + item = value[0] if value else None + score = 10 + if isinstance(item, dict): + score += 25 + elif item is not None: + score += 10 + results.append( + { + "collectionPath": path + "[]", + "path": path, + "length": len(value), + "item": item, + "score": score, + } + ) + if isinstance(item, dict): + for key, child in item.items(): + results.extend(find_collections(child, path + "[]." + key)) + return results + + if isinstance(value, dict): + for key, child in value.items(): + next_path = path + "." + key if path != "$" else "$." + key + results.extend(find_collections(child, next_path)) + + return results + + +def find_best_collection(value: Any) -> dict[str, Any] | None: + """Return the highest scoring collection candidate from the response.""" + candidates = find_collections(value) + if not candidates: + return None + candidates.sort( + key=lambda item: ( + item["score"], + item["length"], + -len(item["path"]), + ), + reverse=True, + ) + return candidates[0] + + +def collect_columns(item: Any) -> list[dict[str, Any]]: + """Infer a compact column list from a sample row.""" + columns: list[dict[str, Any]] = [] + + if isinstance(item, dict): + for key, value in item.items(): + if isinstance(value, (dict, list)): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + if isinstance(nested_value, (dict, list)): + continue + columns.append( + { + "name": sanitize_name(f"{key}_{nested_key}"), + "path": "$." + key + "." + nested_key, + "relativePath": key + "." + nested_key, + "sourceField": nested_key, + "type": infer_type(nested_value), + } + ) + continue + columns.append( + { + "name": sanitize_name(key), + "path": "$." 
+ key, + "relativePath": key, + "sourceField": key, + "type": infer_type(value), + } + ) + if len(columns) >= 8: + break + elif item is not None: + columns.append( + { + "name": "value", + "path": "$", + "relativePath": "$", + "sourceField": "value", + "type": infer_type(item), + } + ) + + if not columns: + columns.append( + { + "name": "value", + "path": "$", + "relativePath": "$", + "sourceField": "value", + "type": "string", + } + ) + return columns + + +def build_templates(request: dict[str, Any], url_info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], list[dict[str, Any]]]: + """Build query/body templates and CLI arg definitions.""" + args: list[dict[str, Any]] = [] + seen_args: set[str] = set() + + def add_arg(name: str, default: Any, help_text: str) -> None: + if name in seen_args: + return + seen_args.add(name) + arg_type = "int" if isinstance(default, int) else "string" + args.append({"name": name, "type": arg_type, "default": default, "help": help_text}) + + def transform_mapping(data: dict[str, Any]) -> dict[str, Any]: + result: dict[str, Any] = {} + for key, value in data.items(): + if key in PAGE_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 1 + result[key] = "${page}" + add_arg("page", default, "Page number") + elif key in LIMIT_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 20 + result[key] = "${limit}" + add_arg("limit", default, "Page size") + else: + result[key] = value + return result + + body = parse_json_text(str(request.get("requestBody", ""))) + if not isinstance(body, dict) or "raw" in body: + body = {} + + query_template = transform_mapping(url_info["query"]) + body_template = transform_mapping(body) + + args.sort(key=lambda item: (0 if item["name"] == "page" else 1 if item["name"] == "limit" else 2, item["name"])) + return query_template, body_template, args + + +def build_strategy(request: dict[str, Any]) -> tuple[str, dict[str, Any]]: + """Infer auth strategy and auth metadata from request headers.""" + headers = request.get("requestHeaders", {}) or request.get("request_headers", {}) + normalized = {str(key).lower(): value for key, value in headers.items()} + strategy = "PUBLIC" + required_headers: list[dict[str, Any]] = [] + + if "authorization" in normalized: + strategy = "HEADER" + required_headers.append({"name": "Authorization", "source": "manual", "key": "authorization"}) + elif "cookie" in normalized: + strategy = "COOKIE" + + for header_name in ("x-csrf-token", "x-xsrf-token", "x-auth-token"): + if header_name in normalized: + strategy = "HEADER" + required_headers.append({"name": header_name, "source": "manual", "key": header_name}) + + return strategy, {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": required_headers} + + +def safe_headers(request: dict[str, Any]) -> dict[str, Any]: + """Return non-sensitive request headers that can be replayed safely.""" + headers = request.get("requestHeaders", {}) or request.get("request_headers", {}) + result = {} + for key, value in headers.items(): + if str(key).lower() in {"cookie", "authorization", "x-csrf-token", "x-xsrf-token", "x-auth-token"}: + continue + result[key] = value + return result + + +def site_name_from_host(host: str) -> str: + """Return a readable site name from a host.""" + cleaned = host.split(":")[0] + parts = [part for part in cleaned.split(".") if part not in {"www", "api", "m"}] + if len(parts) >= 2: + return sanitize_name(parts[-2]) + if parts: + return sanitize_name(parts[0]) + return "captured_site" + + +def 
command_name_from_path(pathname: str) -> str: + """Return a command name from an API pathname.""" + parts = [part for part in pathname.split("/") if part] + return sanitize_name(parts[-1] if parts else "command") + + +def generate_spec_from_requests(requests: list[dict[str, Any]], *, base_url: str | None = None) -> dict[str, Any]: + """Build a web2cli spec object from captured request data.""" + request = choose_primary_request(requests) + url_info = normalize_url_info(request) + response = parse_json_text(str(request.get("response", ""))) + collection = find_best_collection(response) + row_item = collection["item"] if collection is not None else response + query_template, body_template, args = build_templates(request, url_info) + strategy, auth = build_strategy(request) + columns = collect_columns(row_item) + + defaults = {item["name"]: item["default"] for item in args} + verify_types = {column["name"]: column["type"] for column in columns} + verify_not_empty = [column["name"] for column in columns[: min(3, len(columns))]] + row_count = {"min": 1} + if collection and collection["length"]: + row_count["max"] = collection["length"] + + purpose = request.get("apiPurpose", {}) if isinstance(request.get("apiPurpose"), dict) else {} + host_origin = base_url or url_info["origin"] or "https://example.com" + pathname = url_info["pathname"] or "/" + site = site_name_from_host(urlparse(host_origin).netloc or url_info["host"]) + command = purpose.get("name") or command_name_from_path(pathname) + + return { + "schemaVersion": "1.0", + "site": site, + "command": sanitize_name(command), + "description": purpose.get("desc") or f"Generated from {request.get('method', 'GET')} {pathname}", + "baseUrl": host_origin, + "strategy": strategy, + "auth": auth, + "operation": { + "method": request.get("method", "GET"), + "endpoint": pathname, + "queryTemplate": query_template, + "bodyTemplate": body_template, + "headers": safe_headers(request), + "captureSource": request.get("captureSource", "pageHook"), + "captureReason": request.get("captureReason", ""), + "sourceRequestId": request.get("timestamp", ""), + }, + "rowSource": { + "path": collection["collectionPath"] if collection else "$", + "collectionPath": collection["collectionPath"] if collection else "$", + }, + "args": args, + "columns": columns, + "verify": { + "args": defaults, + "rowCount": row_count, + "columns": [column["name"] for column in columns], + "types": verify_types, + "notEmpty": verify_not_empty, + "patterns": {}, + }, + } + + +def main() -> None: + """CLI entrypoint.""" + parser = argparse.ArgumentParser(description="Generate a web2cli spec from captured APIs") + parser.add_argument("input", help="Input JSON file with captured requests") + parser.add_argument("--output", "-o", help="Output spec path") + parser.add_argument("--base-url", help="Optional base URL override") + args = parser.parse_args() + + requests = load_requests(args.input) + if not requests: + print("No requests found in input file", file=sys.stderr) + sys.exit(1) + + spec = generate_spec_from_requests(requests, base_url=args.base_url) + rendered = json.dumps(spec, indent=2, ensure_ascii=False) + + if args.output: + output_path = Path(args.output) + output_path.write_text(rendered, encoding="utf-8") + print(f"Written to {output_path}") + else: + print(rendered) + + +if __name__ == "__main__": + main() diff --git a/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js b/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js index 2a009cac..5e11d574 100644 --- 
a/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js +++ b/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js @@ -1,7 +1,7 @@ /** - * API Capture Hook - Simple Version (ES5 compatible) + * API Capture Hook - Base Version (ES5 compatible) */ -(function(){ +(function() { if (window.__apiCapture) { console.log('[API Capture] Already installed'); return; @@ -9,9 +9,10 @@ window.__capturedRequests = []; - // Configuration var CONFIG = { - maxResponseLength: 50000, + maxResponseLength: 2000, + maxRequestBodyLength: 2000, + maxRecentActions: 20, captureMode: 'smart', // 'smart' | 'all' sameOriginOnly: true, includePatterns: [], @@ -24,13 +25,38 @@ ] }; + var recentActions = []; + var navigationState = { + lastNavigation: null, + currentUrl: window.location.href + }; + + function truncateText(text, limit) { + var value = text == null ? '' : String(text); + if (value.length <= limit) { + return value; + } + return value.substring(0, limit) + '...[truncated]'; + } + + function safeTrim(text) { + return String(text || '').replace(/\s+/g, ' ').replace(/^\s+|\s+$/g, ''); + } + + function cloneSimple(value) { + if (value == null || typeof value !== 'object') { + return value; + } + return JSON.parse(JSON.stringify(value)); + } + function normalizeHeaders(headers) { var result = {}; var key; if (!headers) { return result; } - if (typeof Headers !== 'undefined' && headers instanceof Headers) { + if (typeof Headers !== 'undefined' && headers instanceof Headers && headers.forEach) { headers.forEach(function(value, name) { result[name] = value; }); @@ -53,7 +79,9 @@ } function getHeader(headers, name) { - if (!headers) return ''; + if (!headers) { + return ''; + } return headers[name] || headers[name.toLowerCase()] || ''; } @@ -61,54 +89,198 @@ return /\.[a-z0-9]{1,8}$/i.test(pathname || ''); } - function shouldCapture(url, method, headers) { - var u; - var m = (method || 'GET').toUpperCase(); - var normalizedHeaders = normalizeHeaders(headers); - var accept = ''; - var contentType = ''; - var looksJson = false; - var isIgnored = false; - var isIncluded = false; - + function normalizeUrl(url) { + var parsed; + var query = {}; + var queryKeys = []; try { - u = new URL(url, window.location.href); - } catch (e) { - return false; + parsed = new URL(url, window.location.href); + parsed.searchParams.forEach(function(value, key) { + if (!Object.prototype.hasOwnProperty.call(query, key)) { + queryKeys.push(key); + } + query[key] = value; + }); + return { + normalizedUrl: parsed.href, + origin: parsed.origin, + pathname: parsed.pathname, + query: query, + queryKeys: queryKeys + }; + } catch (error) { + return { + normalizedUrl: String(url || ''), + origin: '', + pathname: '', + query: query, + queryKeys: queryKeys + }; } + } + + function inferShape(value, path, out, depth) { + var currentPath = path || '$'; + var nextDepth = depth || 0; + var keys; + var i; - if (CONFIG.sameOriginOnly && u.origin !== window.location.origin) { - return false; + if (nextDepth > 4) { + out[currentPath] = 'depthLimit'; + return; } + if (value === null) { + out[currentPath] = 'null'; + return; + } + if (typeof value === 'undefined') { + out[currentPath] = 'undefined'; + return; + } + if (Array.isArray(value)) { + out[currentPath] = 'array(' + value.length + ')'; + if (value.length > 0) { + inferShape(value[0], currentPath + '[]', out, nextDepth + 1); + } + return; + } + if (typeof value === 'object') { + out[currentPath] = 'object'; + keys = Object.keys(value); + for (i = 0; i < keys.length && i < 20; i++) { + 
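+        // Walk at most 20 keys per object; together with the depth cap
+        // above (nextDepth > 4) this keeps the shape map small even for
+        // very wide or deeply nested payloads.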
inferShape(value[keys[i]], currentPath + '.' + keys[i], out, nextDepth + 1); + } + return; + } + out[currentPath] = typeof value; + } - isIgnored = CONFIG.ignorePatterns.some(function(p) { return p.test(u.href); }); - if (isIgnored) { - return false; + function detectGraphQL(payload) { + var text; + var parsed; + var operationType = ''; + if (!payload) { + return null; + } + text = typeof payload === 'string' ? payload : ''; + try { + parsed = typeof payload === 'string' ? JSON.parse(payload) : payload; + } catch (error) { + parsed = null; } + if (!parsed || typeof parsed !== 'object') { + return null; + } + if (!parsed.query) { + return null; + } + if (/mutation\s/i.test(parsed.query)) { + operationType = 'mutation'; + } else if (/query\s/i.test(parsed.query)) { + operationType = 'query'; + } else { + operationType = 'graphql'; + } + return { + operationName: parsed.operationName || '', + operationType: operationType, + variablesShape: parsed.variables && typeof parsed.variables === 'object' + ? (function() { + var shape = {}; + inferShape(parsed.variables, '$', shape, 0); + return shape; + })() + : {} + }; + } + + function summarizeBody(body) { + var result = { + kind: 'empty', + display: '', + parsed: null, + shape: {}, + graphql: null + }; + var asObject = {}; - isIncluded = CONFIG.includePatterns.some(function(p) { return p.test(u.href); }); - if (isIncluded) { - return true; + if (body == null || body === '') { + return result; } - if (CONFIG.captureMode === 'all') { - return true; + if (typeof URLSearchParams !== 'undefined' && body instanceof URLSearchParams) { + body.forEach(function(value, key) { + asObject[key] = value; + }); + result.kind = 'urlencoded'; + result.parsed = asObject; + result.display = truncateText(JSON.stringify(asObject, null, 2), CONFIG.maxRequestBodyLength); + inferShape(asObject, '$', result.shape, 0); + return result; } - accept = getHeader(normalizedHeaders, 'accept'); - contentType = getHeader(normalizedHeaders, 'content-type'); - looksJson = /application\/json|text\/plain|application\/x-www-form-urlencoded/i - .test(accept + ' ' + contentType); + if (typeof FormData !== 'undefined' && body instanceof FormData) { + result.kind = 'formData'; + if (typeof body.forEach === 'function') { + body.forEach(function(value, key) { + asObject[key] = Object.prototype.toString.call(value) === '[object File]' ? 
'[file]' : String(value); + }); + } + result.parsed = asObject; + result.display = truncateText(JSON.stringify(asObject, null, 2), CONFIG.maxRequestBodyLength); + inferShape(asObject, '$', result.shape, 0); + return result; + } - if (m !== 'GET') { - return true; + if (typeof body === 'string') { + result.display = truncateText(body, CONFIG.maxRequestBodyLength); + try { + result.parsed = JSON.parse(body); + result.kind = 'json'; + result.display = truncateText(JSON.stringify(result.parsed, null, 2), CONFIG.maxRequestBodyLength); + inferShape(result.parsed, '$', result.shape, 0); + } catch (error) { + result.kind = 'text'; + result.graphql = detectGraphQL(body); + } + if (!result.graphql && result.parsed) { + result.graphql = detectGraphQL(result.parsed); + } + return result; } - if (!hasStaticExtension(u.pathname || '/')) { - return true; + if (typeof body === 'object') { + result.kind = 'object'; + result.parsed = body; + result.display = truncateText(JSON.stringify(body, null, 2), CONFIG.maxRequestBodyLength); + inferShape(body, '$', result.shape, 0); + result.graphql = detectGraphQL(body); + return result; } - return looksJson; + result.kind = typeof body; + result.display = truncateText(String(body), CONFIG.maxRequestBodyLength); + return result; + } + + function summarizeResponse(text) { + var result = { + display: '', + parsed: null, + shape: {} + }; + if (!text) { + return result; + } + try { + result.parsed = JSON.parse(text); + result.display = truncateText(JSON.stringify(result.parsed, null, 2), CONFIG.maxResponseLength); + inferShape(result.parsed, '$', result.shape, 0); + return result; + } catch (error) { + result.display = truncateText(text, CONFIG.maxResponseLength); + return result; + } } function getPageContext() { @@ -120,19 +292,212 @@ }; } - // Hook XMLHttpRequest + function describeElement(target) { + var tag = target && target.tagName ? String(target.tagName).toUpperCase() : 'UNKNOWN'; + var text = safeTrim(target && target.textContent ? target.textContent : ''); + var label = text || safeTrim(target && target.value ? target.value : ''); + if (!label && target && typeof target.getAttribute === 'function') { + label = safeTrim( + target.getAttribute('aria-label') || + target.getAttribute('title') || + target.getAttribute('name') || + target.getAttribute('placeholder') || + '' + ); + } + if (!label) { + label = (target && target.id) || (target && target.className) || tag; + } + return { + action: label, + tagName: tag, + id: target && target.id ? String(target.id) : '', + className: target && target.className ? String(target.className) : '' + }; + } + + function pushRecentAction(action) { + recentActions.push(action); + if (recentActions.length > CONFIG.maxRecentActions) { + recentActions.shift(); + } + } + + function recordAction(type, detail) { + pushRecentAction({ + type: type, + detail: detail || {}, + action: detail && detail.action ? detail.action : '', + url: window.location.href, + timestamp: new Date().toISOString() + }); + } + + function snapshotActionContext() { + return { + lastAction: recentActions.length ? 
cloneSimple(recentActions[recentActions.length - 1]) : null, + recentActions: cloneSimple(recentActions), + navigation: cloneSimple(navigationState) + }; + } + + function installActionListeners() { + if (document && document.addEventListener) { + document.addEventListener('click', function(event) { + recordAction('click', describeElement(event && event.target)); + }, true); + document.addEventListener('input', function(event) { + recordAction('input', describeElement(event && event.target)); + }, true); + document.addEventListener('change', function(event) { + recordAction('change', describeElement(event && event.target)); + }, true); + document.addEventListener('submit', function(event) { + recordAction('submit', describeElement(event && event.target)); + }, true); + document.addEventListener('keydown', function(event) { + recordAction('keydown', { + action: event && event.key ? String(event.key) : 'keydown' + }); + }, true); + } + + if (window && window.addEventListener) { + window.addEventListener('popstate', function() { + navigationState.lastNavigation = { + type: 'popstate', + url: window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('popstate', { action: window.location.href }); + }); + } + + if (window.history && window.history.pushState) { + var originalPushState = window.history.pushState; + window.history.pushState = function() { + var result = originalPushState.apply(this, arguments); + navigationState.lastNavigation = { + type: 'pushState', + url: arguments.length >= 3 ? String(arguments[2]) : window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('pushState', { action: navigationState.lastNavigation.url }); + return result; + }; + } + + if (window.history && window.history.replaceState) { + var originalReplaceState = window.history.replaceState; + window.history.replaceState = function() { + var result = originalReplaceState.apply(this, arguments); + navigationState.lastNavigation = { + type: 'replaceState', + url: arguments.length >= 3 ? 
String(arguments[2]) : window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('replaceState', { action: navigationState.lastNavigation.url }); + return result; + }; + } + } + + function getCaptureDecision(url, method, headers) { + var m = (method || 'GET').toUpperCase(); + var normalizedHeaders = normalizeHeaders(headers); + var accept = getHeader(normalizedHeaders, 'accept'); + var contentType = getHeader(normalizedHeaders, 'content-type'); + var looksJson = /application\/json|text\/plain|application\/x-www-form-urlencoded/i + .test(accept + ' ' + contentType); + var urlInfo = normalizeUrl(url); + var i; + + if (CONFIG.sameOriginOnly && urlInfo.origin && urlInfo.origin !== window.location.origin) { + return { capture: false, reason: 'crossOrigin', urlInfo: urlInfo }; + } + + for (i = 0; i < CONFIG.ignorePatterns.length; i++) { + if (CONFIG.ignorePatterns[i].test(urlInfo.normalizedUrl)) { + return { capture: false, reason: 'ignorePattern', urlInfo: urlInfo }; + } + } + + for (i = 0; i < CONFIG.includePatterns.length; i++) { + if (CONFIG.includePatterns[i].test(urlInfo.normalizedUrl)) { + return { capture: true, reason: 'includePattern', urlInfo: urlInfo }; + } + } + + if (CONFIG.captureMode === 'all') { + return { capture: true, reason: 'captureModeAll', urlInfo: urlInfo }; + } + + if (m !== 'GET') { + return { capture: true, reason: 'nonGet', urlInfo: urlInfo }; + } + + if (!hasStaticExtension(urlInfo.pathname || '/')) { + return { capture: true, reason: 'nonStaticPath', urlInfo: urlInfo }; + } + + if (looksJson) { + return { capture: true, reason: 'jsonLike', urlInfo: urlInfo }; + } + + return { capture: false, reason: 'filteredOut', urlInfo: urlInfo }; + } + + function buildCaptureRecord(base) { + var requestBody = summarizeBody(base.requestBody); + var responseBody = summarizeResponse(base.responseText); + var requestContentType = getHeader(base.requestHeaders, 'content-type'); + var responseContentType = base.responseContentType || ''; + var actionContext = snapshotActionContext(); + return { + captureSource: 'pageHook', + type: base.type, + method: base.method, + url: base.url, + normalizedUrl: base.urlInfo.normalizedUrl, + origin: base.urlInfo.origin, + pathname: base.urlInfo.pathname, + query: base.urlInfo.query, + queryKeys: base.urlInfo.queryKeys, + status: base.status, + requestHeaders: base.requestHeaders, + requestBody: requestBody.display, + requestBodyKind: requestBody.kind, + requestShape: requestBody.shape, + requestContentType: requestContentType, + graphql: requestBody.graphql, + response: responseBody.display, + responseShape: responseBody.shape, + responseContentType: responseContentType, + pageContext: base.pageContext, + actionContext: actionContext, + captureReason: base.captureReason, + duration: base.duration, + timestamp: new Date().toISOString() + }; + } + + installActionListeners(); + var originalXHROpen = XMLHttpRequest.prototype.open; var originalXHRSend = XMLHttpRequest.prototype.send; var originalXHRSetHeader = XMLHttpRequest.prototype.setRequestHeader; XMLHttpRequest.prototype.open = function(method, url) { this._capture = { - method: method.toUpperCase(), + method: (method || 'GET').toUpperCase(), url: typeof url === 'string' ? 
url : String(url), startTime: Date.now(), - headers: {} + headers: {}, + pageContext: getPageContext() }; - this._pageContext = getPageContext(); return originalXHROpen.apply(this, arguments); }; @@ -144,164 +509,173 @@ }; XMLHttpRequest.prototype.send = function(body) { - var self = this; - if (!this._capture || !shouldCapture(this._capture.url, this._capture.method, this._capture.headers)) { + var capture = this._capture; + var decision = capture ? getCaptureDecision(capture.url, capture.method, capture.headers) : null; + if (!capture || !decision || !decision.capture) { return originalXHRSend.apply(this, arguments); } - var capture = this._capture; - var pageContext = this._pageContext; capture.requestBody = body; - var requestBodyDisplay = ''; - if (body) { - try { - var parsed = JSON.parse(body); - requestBodyDisplay = JSON.stringify(parsed, null, 2).substring(0, 2000); - } catch (e) { - requestBodyDisplay = String(body).substring(0, 1000); - } - } - this.addEventListener('load', function() { - var responseDisplay = ''; - try { - var parsed = JSON.parse(this.responseText); - responseDisplay = JSON.stringify(parsed, null, 2).substring(0, CONFIG.maxResponseLength); - } catch (e) { - responseDisplay = this.responseText.substring(0, 2000); - } - - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'XHR', method: capture.method, url: capture.url, + urlInfo: decision.urlInfo, status: this.status, - requestHeaders: capture.headers, - requestBody: requestBodyDisplay, - response: responseDisplay, - pageContext: pageContext, - duration: Date.now() - capture.startTime, - timestamp: new Date().toISOString() + requestHeaders: normalizeHeaders(capture.headers), + requestBody: capture.requestBody, + responseText: this.responseText || '', + responseContentType: typeof this.getResponseHeader === 'function' + ? (this.getResponseHeader('Content-Type') || '') + : '', + pageContext: capture.pageContext, + captureReason: decision.reason, + duration: Date.now() - capture.startTime }); - console.log('[API Capture] XHR:', capture.method, capture.url, '->', this.status); + window.__capturedRequests.push(record); + console.log( + '[API Capture] XHR:', + capture.method, + record.normalizedUrl, + '->', + this.status, + 'action=' + (record.actionContext.lastAction ? record.actionContext.lastAction.action : 'none') + ); }); this.addEventListener('error', function() { - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'XHR', method: capture.method, url: capture.url, + urlInfo: decision.urlInfo, status: 'error', - requestHeaders: capture.headers, - requestBody: requestBodyDisplay, - error: 'Network error', - pageContext: pageContext, - duration: Date.now() - capture.startTime, - timestamp: new Date().toISOString() + requestHeaders: normalizeHeaders(capture.headers), + requestBody: capture.requestBody, + responseText: '', + responseContentType: '', + pageContext: capture.pageContext, + captureReason: decision.reason, + duration: Date.now() - capture.startTime }); + record.error = 'Network error'; + window.__capturedRequests.push(record); }); return originalXHRSend.apply(this, arguments); }; - // Hook Fetch var originalFetch = window.fetch; window.fetch = function(url, options) { options = options || {}; var startTime = Date.now(); var method = (options.method || 'GET').toUpperCase(); - var urlStr = typeof url === 'string' ? url : (url.url || String(url)); - var requestHeaders = normalizeHeaders(options.headers || {}); + var urlStr = typeof url === 'string' ? 
url : (url && url.url ? url.url : String(url)); + var decision = getCaptureDecision(urlStr, method, requestHeaders); - if (!shouldCapture(urlStr, method, requestHeaders)) { + if (!decision.capture) { return originalFetch.apply(this, arguments); } - var pageContext = getPageContext(); - var requestBodyDisplay = ''; - if (options.body) { - try { - if (typeof options.body === 'string') { - requestBodyDisplay = options.body.substring(0, 2000); - } else { - requestBodyDisplay = JSON.stringify(options.body).substring(0, 2000); - } - } catch (e) { - requestBodyDisplay = '[body unreadable]'; - } - } - return originalFetch.apply(this, arguments).then(function(response) { var cloned = response.clone(); - return cloned.text().then(function(text) { - var responseBody = ''; - try { - var parsed = JSON.parse(text); - responseBody = JSON.stringify(parsed, null, 2).substring(0, CONFIG.maxResponseLength); - } catch (e) { - responseBody = text.substring(0, 2000); - } - - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'Fetch', method: method, url: urlStr, + urlInfo: decision.urlInfo, status: response.status, requestHeaders: requestHeaders, - requestBody: requestBodyDisplay, - response: responseBody, - pageContext: pageContext, - duration: Date.now() - startTime, - timestamp: new Date().toISOString() + requestBody: options.body, + responseText: text || '', + responseContentType: response.headers && typeof response.headers.get === 'function' + ? (response.headers.get('content-type') || '') + : '', + pageContext: getPageContext(), + captureReason: decision.reason, + duration: Date.now() - startTime }); - console.log('[API Capture] Fetch:', method, urlStr, '->', response.status); + window.__capturedRequests.push(record); + console.log( + '[API Capture] Fetch:', + method, + record.normalizedUrl, + '->', + response.status, + 'action=' + (record.actionContext.lastAction ? record.actionContext.lastAction.action : 'none') + ); return response; }); }).catch(function(error) { - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'Fetch', method: method, url: urlStr, + urlInfo: decision.urlInfo, status: 'error', requestHeaders: requestHeaders, - requestBody: requestBodyDisplay, - error: error.message, - pageContext: pageContext, - duration: Date.now() - startTime, - timestamp: new Date().toISOString() + requestBody: options.body, + responseText: '', + responseContentType: '', + pageContext: getPageContext(), + captureReason: decision.reason, + duration: Date.now() - startTime }); + record.error = error && error.message ? error.message : String(error); + window.__capturedRequests.push(record); throw error; }); }; window.__apiCapture = { - version: '3.0-simple', + version: 'web2cli-base', installed: new Date().toISOString(), + config: CONFIG, getAll: function() { return window.__capturedRequests; }, clear: function() { window.__capturedRequests = []; + recentActions = []; console.log('[API Capture] Cleared'); }, + getRecentActions: function() { + return cloneSimple(recentActions); + }, + getDebugState: function() { + return { + version: this.version, + installed: this.installed, + config: cloneSimple(CONFIG), + requestCount: window.__capturedRequests.length, + recentActions: cloneSimple(recentActions), + navigation: cloneSimple(navigationState), + lastRequest: window.__capturedRequests.length + ? 
cloneSimple(window.__capturedRequests[window.__capturedRequests.length - 1])
+          : null
+      };
+    },
     summary: function() {
-      console.log('=== API Capture Summary ===');
-      console.log('Total requests:', window.__capturedRequests.length);
       var groups = {};
-      window.__capturedRequests.forEach(function(r) {
-        var path = r.url.split('?')[0];
-        groups[path] = (groups[path] || 0) + 1;
+      window.__capturedRequests.forEach(function(record) {
+        groups[record.pathname] = (groups[record.pathname] || 0) + 1;
       });
+      console.log('=== API Capture Summary ===');
+      console.log('Total requests:', window.__capturedRequests.length);
       console.log('Endpoints:', Object.keys(groups));
+      console.log('Recent actions:', recentActions.length);
+      console.log('window.__apiCapture.getDebugState() - inspect capture internals');
     }
   };
 
-  console.log('[API Capture] v3.0-simple installed');
+  console.log('[API Capture] web2cli-base installed');
   console.log('  window.__capturedRequests - captured data');
-  console.log('  window.__apiCapture.summary() - show summary');
+  console.log('  window.__apiCapture.getRecentActions() - recent user interactions');
+  console.log('  window.__apiCapture.getDebugState() - current capture state');
 })();
\ No newline at end of file
diff --git a/.flocks/plugins/skills/workflow-builder/SKILL.md b/.flocks/plugins/skills/workflow-builder/SKILL.md
index 970a333e..36475ffc 100644
--- a/.flocks/plugins/skills/workflow-builder/SKILL.md
+++ b/.flocks/plugins/skills/workflow-builder/SKILL.md
@@ -542,13 +542,11 @@ Body: { "inputs": <样例数据> }
 ### 创建路径(写入)
 
-新建/修改工作流时,写入路径须与 **`flocks/workflow/center.py`**、**`flocks/server/routes/workflow.py`** 采用的规范目录一致(**`plugins/workflows/`**),不要用已退居兼容用途的旧目录作为主落点:
+新建工作流时,写入路径必须在用户级(全局)目录下:
 
 - **用户级(全局)**:`~/.flocks/plugins/workflows/<name>/`(`workflow.json`、`workflow.md`;`meta.json` 由 API 写入时可能位于同目录)
-- **项目级**:`<workspace_root>/.flocks/plugins/workflows/<name>/`,其中 `<workspace_root>` 为从当前工作目录向上查找时**第一个**含有 `.flocks` 的目录(与服务端 `_find_workspace_root()` 一致);若需用户级、与仓库无关的工作流,则用上面的「用户级」路径。
 - ⚠️ 任务输出(报告、artifacts)**不**写入此目录,统一写入 `~/.flocks/workspace/outputs/<task_id>/`(见全局文件输出约定)
 
-> **说明**:历史上 workflow-builder 曾统一写 `~/.flocks/workflow/<name>/`,因此你会看到旧数据在该路径下;**当前代码在扫描与 API 落盘上以 `plugins/workflows/` 为规范路径**(`~/.flocks/workflow/`、`plugins/workflow/` 等仍为兼容扫描路径,且优先级更低)。
 
 ### 读取路径(扫描)
@@ -559,16 +557,14 @@ Body: { "inputs": <样例数据> }
 | 优先级(低→高) | 路径 | 说明 |
 |---|---|---|
 | 1 | `~/.flocks/plugins/workflow/` | 全局 legacy |
-| 2 | `~/.flocks/workflow/` | 全局旧主路径(兼容) |
-| 3 | `~/.flocks/plugins/workflows/` | **全局规范路径(推荐新工作流落地)** |
+| 2 | `~/.flocks/plugins/workflows/` | **全局规范路径(推荐新工作流落地)** |
 
 **项目(workspace 下)**
 
 | 优先级(低→高) | 路径 | 说明 |
 |---|---|---|
 | 1 | `<workspace_root>/.flocks/plugins/workflow/` | 项目 legacy |
-| 2 | `<workspace_root>/.flocks/workflow/` | 项目旧路径(兼容) |
-| 3 | `<workspace_root>/.flocks/plugins/workflows/` | **项目规范路径(推荐新工作流落地)** |
+| 2 | `<workspace_root>/.flocks/plugins/workflows/` | **项目规范路径(推荐新工作流落地)** |
 
 ### ⚠️ 绝对路径规范(重要)
@@ -588,7 +584,6 @@ python3 -c "from pathlib import Path; p=Path.cwd(); ws=next((x for x in [p,*p.pa
 **正确示例**:
 - `<home>/.flocks/plugins/workflows/alert_triage/workflow.json` ✅(用户级)
-- `<workspace_root>/.flocks/plugins/workflows/phishing-email-detection/workflow.json` ✅(项目级)
 
 **错误示例**:
 - `.flocks/plugins/workflows/alert_triage/workflow.json` ❌(未展开相对路径,易写错磁盘位置)
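The workflow path rules above reduce to a small amount of path logic. A minimal sketch of the documented lookup order (helper names are illustrative only; the authoritative implementation is the server's `_find_workspace_root()` and scanner):

```python
# Illustrative only: reproduces the documented scan order (low -> high
# priority) and the user-level write target for new workflows.
from pathlib import Path

def scan_dirs(workspace_root: Path | None) -> list[Path]:
    home = Path.home()
    dirs = [
        home / ".flocks" / "plugins" / "workflow",    # global legacy
        home / ".flocks" / "plugins" / "workflows",   # global canonical
    ]
    if workspace_root is not None:
        dirs += [
            workspace_root / ".flocks" / "plugins" / "workflow",   # project legacy
            workspace_root / ".flocks" / "plugins" / "workflows",  # project canonical
        ]
    return dirs

def workflow_write_dir(name: str) -> Path:
    # Per the updated rule: new workflows always land in the user-level
    # canonical directory, written as an absolute, expanded path.
    return Path.home() / ".flocks" / "plugins" / "workflows" / name

print(workflow_write_dir("alert_triage") / "workflow.json")
```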
diff --git a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec.handler.py b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec.handler.py
index 6738c06a..a10c0a93 100644
--- a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec.handler.py
+++ b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec.handler.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import base64
+import datetime as dt
 import hmac
 import os
 import time
@@ -488,6 +489,42 @@ def _action_task_content(
     return [content]
 
 
+def _normalize_dns_search_blocked_queries(payload: Any) -> Any:
+    if not isinstance(payload, dict):
+        return payload
+
+    items = payload.get("items")
+    if not isinstance(items, list):
+        return payload
+
+    normalized_items: list[Any] = []
+    changed = False
+    for item in items:
+        if not isinstance(item, dict):
+            normalized_items.append(item)
+            continue
+
+        normalized_item = dict(item)
+        if "result" not in normalized_item:
+            normalized_item["result"] = "block"
+            changed = True
+        if "is_blocked" not in normalized_item:
+            normalized_item["is_blocked"] = True
+            changed = True
+        normalized_items.append(normalized_item)
+
+    if not changed:
+        return payload
+
+    return {**payload, "items": normalized_items}
+
+
+def _normalize_output(action: str, payload: Any) -> Any:
+    if action in {"searchBlockedQueries", "dns_search_blocked_queries"}:
+        return _normalize_dns_search_blocked_queries(payload)
+    return payload
+
+
 def _json_result(action: str, payload: Any) -> ToolResult:
     metadata = {"source": "OneSEC", "api": action}
     if isinstance(payload, dict):
@@ -495,7 +532,11 @@ def _json_result(action: str, payload: Any) -> ToolResult:
         if response_code not in (None, 0, 200):
             error_msg = payload.get("verbose_msg") or payload.get("msg") or "Unknown error"
             return ToolResult(success=False, error=f"OneSEC API error: {error_msg}", metadata=metadata)
-        return ToolResult(success=True, output=payload.get("data", payload), metadata=metadata)
+        return ToolResult(
+            success=True,
+            output=_normalize_output(action, payload.get("data", payload)),
+            metadata=metadata,
+        )
     return ToolResult(success=True, output=payload, metadata=metadata)
 
@@ -513,6 +554,101 @@ def _require_fields(params: dict[str, Any], *fields: str) -> list[str]:
     return [field for field in fields if not _has_value(params.get(field))]
 
 
+def _normalize_unix_seconds(value: Any, field_name: str) -> tuple[Optional[int], Optional[str]]:
+    if value is None:
+        return None, None
+    if isinstance(value, bool):
+        return None, f"{field_name} ({value}) 必须是 Unix 秒级时间戳。"
+    if isinstance(value, int):
+        return value, None
+    if isinstance(value, float):
+        if value.is_integer():
+            return int(value), None
+        return None, f"{field_name} ({value}) 必须是 Unix 秒级时间戳。"
+    if not isinstance(value, str):
+        return None, f"{field_name} ({value}) 必须是 Unix 秒级时间戳。"
+
+    stripped = value.strip()
+    if not stripped:
+        return None, None
+    if stripped.lstrip("-").isdigit():
+        return int(stripped), None
+
+    try:
+        parsed = dt.datetime.fromisoformat(stripped)
+    except ValueError:
+        return (
+            None,
+            f"{field_name} ({value}) 必须是 Unix 秒级时间戳。"
+            " 当前工具支持自动转换常见日期时间格式,如 `YYYY-MM-DD HH:MM:SS`。",
+        )
+
+    if parsed.tzinfo is None:
+        parsed = parsed.replace(tzinfo=dt.datetime.now().astimezone().tzinfo)
+    return int(parsed.timestamp()), None
+
+
+def _normalize_action_params(action: str, params: dict[str, Any]) -> tuple[dict[str, Any], Optional[str]]:
+    normalized = dict(params)
+
+    for field_name in ("time_from", "time_to", "begin_time", "end_time"):
+        if not _has_value(normalized.get(field_name)):
+            continue
+        normalized_value, error = _normalize_unix_seconds(normalized.get(field_name), field_name)
+        if error:
+            return params, error
+        normalized[field_name] = normalized_value
+
+    if action in {"dns_search_blocked_queries", "dns_get_recent_blocked_queries"}:
+        public_ip = normalized.get("public_ip")
+        if isinstance(public_ip, str) and public_ip.strip():
+            normalized["public_ip"] = [public_ip.strip()]
+        if isinstance(normalized.get("block_reason"), str):
+            normalized["block_reason"] = normalized["block_reason"].strip().lower()
+        if action == "dns_search_blocked_queries":
+            if not _has_value(normalized.get("keyword")) and _has_value(normalized.get("domain")):
+                normalized["keyword"] = normalized["domain"]
+    elif action == "dns_search_queries":
+        if isinstance(normalized.get("qType"), str):
+            normalized["qType"] = normalized["qType"].strip().upper()
+        if isinstance(normalized.get("rcode"), str):
+            normalized["rcode"] = normalized["rcode"].strip().upper()
+    elif action == "dns_get_all_destination_list":
+        if isinstance(normalized.get("policy_type"), str):
+            normalized["policy_type"] = normalized["policy_type"].strip().lower()
+    elif action == "threat_virus_scan":
+        for field_name in ("task_type", "scan_type", "scanmode"):
+            value = normalized.get(field_name)
+            if isinstance(value, str) and value.strip().lstrip("-").isdigit():
+                normalized[field_name] = int(value.strip())
+    elif action == "threat_upgrade_bd_version_task":
+        value = normalized.get("bd_upgrade_type")
+        if isinstance(value, str) and value.strip().lstrip("-").isdigit():
+            normalized["bd_upgrade_type"] = int(value.strip())
+    elif action == "threat_update_bd_version":
+        if isinstance(normalized.get("os_platform"), str):
+            normalized["os_platform"] = normalized["os_platform"].strip().lower()
+        if isinstance(normalized.get("os_arch"), str):
+            arch_map = {
+                "apple silicon": "Apple Silicon",
+                "intel chip": "Intel Chip",
+            }
+            normalized["os_arch"] = arch_map.get(normalized["os_arch"].strip().lower(), normalized["os_arch"])
+    elif action == "ops_query_task_page_list":
+        if isinstance(normalized.get("time_type"), str):
+            normalized["time_type"] = normalized["time_type"].strip()
+        auto_value = normalized.get("auto")
+        if isinstance(auto_value, str) and auto_value.strip().lstrip("-").isdigit():
+            normalized["auto"] = int(auto_value.strip())
+
+    return normalized, None
+
+
+def _reject_present_fields(params: dict[str, Any], *fields: str) -> list[str]:
+    return [field for field in fields if _has_value(params.get(field))]
+
+
 def _validate_non_empty_aliases(params: dict[str, Any], aliases: tuple[str, ...], label: str) -> list[str]:
     if any(_has_value(params.get(alias)) for alias in aliases):
         return []
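A standalone illustration of the normalisation above (it does not import the handler): a bare `public_ip` string becomes a one-element list, `block_reason` is canonicalised to lower case, and a missing `keyword` is back-filled from `domain`.

```python
# Standalone re-implementation for illustration; the handler's
# _normalize_action_params applies the same rules per action.
params = {"public_ip": " 1.2.3.4 ", "domain": "evil.example", "block_reason": "THREAT"}

normalized = dict(params)
if isinstance(normalized.get("public_ip"), str) and normalized["public_ip"].strip():
    normalized["public_ip"] = [normalized["public_ip"].strip()]
if isinstance(normalized.get("block_reason"), str):
    normalized["block_reason"] = normalized["block_reason"].strip().lower()
if not normalized.get("keyword") and normalized.get("domain"):
    normalized["keyword"] = normalized["domain"]

assert normalized == {
    "public_ip": ["1.2.3.4"],
    "domain": "evil.example",
    "keyword": "evil.example",
    "block_reason": "threat",
}
```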
@@ -539,23 +675,104 @@
 }
 
 _ONE_DAY_SECS = 86400
+_THIRTY_DAY_SECS = 30 * _ONE_DAY_SECS
+_THREE_MONTH_SECS = 90 * _ONE_DAY_SECS
+
+_SPAN_LIMIT_RULES: dict[str, tuple[str, str, int, str]] = {
+    "dns_search_blocked_queries": (
+        "time_from",
+        "time_to",
+        _ONE_DAY_SECS,
+        "按 OneSEC API 文档,`dns_search_blocked_queries` 的时间窗口最多 24 小时。请缩小 time_from/time_to 范围。",
+    ),
+    "dns_search_queries": (
+        "time_from",
+        "time_to",
+        _ONE_DAY_SECS,
+        "按 OneSEC API 文档,`dns_search_queries` 的时间窗口最多 24 小时。",
+    ),
+    "edr_get_threat_files": (
+        "time_from",
+        "time_to",
+        _THREE_MONTH_SECS,
+        "按 OneSEC API 文档,`edr_get_threat_files` 的时间窗口最长三个月。请缩小 time_from/time_to 范围。",
+    ),
+    "edr_get_threat_activities": (
+        "time_from",
+        "time_to",
+        _THREE_MONTH_SECS,
+        "按 OneSEC API 文档,`edr_get_threat_activities` 的时间窗口最长三个月。请缩小 time_from/time_to 范围。",
+    ),
+    "edr_get_incidents": (
+        "time_from",
+        "time_to",
+        _THREE_MONTH_SECS,
+        "按 OneSEC API 文档,`edr_get_incidents` 的时间窗口最长三个月。请缩小 time_from/time_to 范围。",
+    ),
+    "edr_get_endpoint_alerts": (
+        "time_from",
+        "time_to",
+        _THREE_MONTH_SECS,
+        "按 OneSEC API 文档,`edr_get_endpoint_alerts` 的时间窗口最长三个月。请缩小 time_from/time_to 范围。",
+    ),
+    "ops_query_audit_log": (
+        "begin_time",
+        "end_time",
+        _THIRTY_DAY_SECS,
+        "按 OneSEC API 文档,`ops_query_audit_log` 的查询窗口最多 30 天。",
+    ),
+}
+
+_AGE_LIMIT_RULES: dict[str, tuple[str, int, str]] = {
+    "dns_search_queries": (
+        "time_from",
+        _ONE_DAY_SECS,
+        "按 OneSEC API 文档,`dns_search_queries` 仅支持最近 24 小时内的数据。请将 time_from 设置在最近 24 小时内。",
+    ),
+    "ops_query_audit_log": (
+        "begin_time",
+        _THIRTY_DAY_SECS,
+        "按 OneSEC API 文档,`ops_query_audit_log` 仅支持最近 30 天内的审计日志。请调整 begin_time。",
+    ),
+}
+
+_DNS_QTYPE_VALUES = {"A", "AAAA", "CNAME", "MX", "TXT", "PTR", "NS", "CERT", "SRV", "SOA", "DS"}
+_DNS_RCODE_VALUES = {"NOERROR", "NXDOMAIN", "FORMERR", "SERVFAIL", "YXDOMAIN"}
+_DNS_BLOCK_REASON_VALUES = {"threat", "custom"}
+_DNS_POLICY_TYPE_VALUES = {"block", "pass"}
+_THREAT_SCAN_TASK_TYPES = {10110, 10120, 10130}
+_THREAT_SCANMODES = {1, 2, 3}
+_THREAT_BD_UPGRADE_TYPES = {1, 2}
+_THREAT_OS_PLATFORMS = {"windows", "macos"}
+_THREAT_MAC_ARCHES = {"Apple Silicon", "Intel Chip"}
+_OPS_TASK_TIME_TYPES = {"create_time", "update_time"}
+_OPS_TASK_AUTO_VALUES = {0, 1}
 
 
 def _validate_time_params(action: str, params: dict[str, Any]) -> Optional[str]:
-    """Check time_from/time_to consistency and recent-API 24-hour window."""
+    """Check time order and documented time-window limits."""
+    for start_field, end_field in (("time_from", "time_to"), ("begin_time", "end_time")):
+        start_value = params.get(start_field)
+        end_value = params.get(end_field)
+        if start_value is None or end_value is None:
+            continue
+        try:
+            start_int, end_int = int(start_value), int(end_value)
+        except (TypeError, ValueError):
+            continue
+        if start_int >= end_int:
+            return (
+                f"{start_field} ({start_int}) 必须小于 {end_field} ({end_int})。"
+                f" 请确认时间范围:`{start_field}` 为开始时间,`{end_field}` 为结束时间。"
+            )
+
     tf = params.get("time_from")
     tt = params.get("time_to")
     if tf is not None and tt is not None:
         try:
             tf_int, tt_int = int(tf), int(tt)
         except (TypeError, ValueError):
             return None
-        if tf_int >= tt_int:
-            return (
-                f"time_from ({tf_int}) 必须小于 time_to ({tt_int})。"
-                " 请确认时间范围:time_from 为开始时间,time_to 为结束时间。"
-            )
         if action in _RECENT_ACTIONS:
             span = tt_int - tf_int
             if span > _ONE_DAY_SECS:
@@ -580,6 +797,114 @@ def _validate_time_params(action: str, params: dict[str, Any]) -> Optional[str]:
                 f"{action} 属于 recent 接口,仅支持最近 24 小时的数据。"
                 f" 传入的 time_from ({tf_int}) 距当前时间已超过 {age // 3600} 小时。{alt_hint}"
             )
+
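Before the table-driven span and age rules run, ordering is checked for both timestamp pairs. A standalone sketch of that ordering rule:

```python
# Standalone sketch of the ordering check above: both (time_from, time_to)
# and (begin_time, end_time) must satisfy start < end; the span rules in
# _SPAN_LIMIT_RULES are then inclusive caps in seconds on (end - start).
def check_order(params: dict) -> str | None:
    for start_field, end_field in (("time_from", "time_to"), ("begin_time", "end_time")):
        start, end = params.get(start_field), params.get(end_field)
        if start is not None and end is not None and int(start) >= int(end):
            return f"{start_field} must be smaller than {end_field}"
    return None

assert check_order({"time_from": 100, "time_to": 200}) is None
assert check_order({"begin_time": 200, "end_time": 100}) is not None
```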
+    span_rule = _SPAN_LIMIT_RULES.get(action)
+    if span_rule is not None:
+        start_field, end_field, max_span, message = span_rule
+        start_value = params.get(start_field)
+        end_value = params.get(end_field)
+        if start_value is not None and end_value is not None:
+            try:
+                start_int, end_int = int(start_value), int(end_value)
+            except (TypeError, ValueError):
+                start_int = end_int = None
+            if start_int is not None and end_int - start_int > max_span:
+                return message
+
+    age_rule = _AGE_LIMIT_RULES.get(action)
+    if age_rule is not None:
+        field_name, max_age, message = age_rule
+        field_value = params.get(field_name)
+        if field_value is not None:
+            try:
+                field_int = int(field_value)
+            except (TypeError, ValueError):
+                field_int = None
+            if field_int is not None and int(time.time()) - field_int > max_age + 3600:
+                return message
     return None
 
 
+def _validate_enum_params(action: str, params: dict[str, Any]) -> Optional[str]:
+    if action == "dns_search_queries":
+        qtype = params.get("qType")
+        if _has_value(qtype) and str(qtype) not in _DNS_QTYPE_VALUES:
+            allowed = ", ".join(sorted(_DNS_QTYPE_VALUES))
+            return f"`qType` 取值无效:{qtype}。按 OneSEC API 文档仅支持:{allowed}。"
+        rcode = params.get("rcode")
+        if _has_value(rcode) and str(rcode) not in _DNS_RCODE_VALUES:
+            allowed = ", ".join(sorted(_DNS_RCODE_VALUES))
+            return f"`rcode` 取值无效:{rcode}。按 OneSEC API 文档仅支持:{allowed}。"
+
+    if action in {"dns_search_blocked_queries", "dns_get_recent_blocked_queries"}:
+        block_reason = params.get("block_reason")
+        if _has_value(block_reason) and str(block_reason) not in _DNS_BLOCK_REASON_VALUES:
+            allowed = ", ".join(sorted(_DNS_BLOCK_REASON_VALUES))
+            return f"`block_reason` 取值无效:{block_reason}。按 OneSEC API 文档仅支持:{allowed}。"
+
+    if action == "dns_get_all_destination_list":
+        policy_type = params.get("policy_type")
+        if _has_value(policy_type) and str(policy_type) not in _DNS_POLICY_TYPE_VALUES:
+            allowed = ", ".join(sorted(_DNS_POLICY_TYPE_VALUES))
+            return f"`policy_type` 取值无效:{policy_type}。按 OneSEC API 文档仅支持:{allowed}。"
+
+    if action == "threat_virus_scan":
+        task_type = params.get("task_type", params.get("scan_type"))
+        if _has_value(task_type):
+            try:
+                task_type_int = int(task_type)
+            except (TypeError, ValueError):
+                return f"`task_type`/`scan_type` 取值无效:{task_type}。按 OneSEC API 文档应为整数枚举值。"
+            if task_type_int not in _THREAT_SCAN_TASK_TYPES:
+                allowed = ", ".join(str(item) for item in sorted(_THREAT_SCAN_TASK_TYPES))
+                return f"`task_type`/`scan_type` 取值无效:{task_type_int}。按 OneSEC API 文档仅支持:{allowed}。"
+        scanmode = params.get("scanmode")
+        if _has_value(scanmode):
+            try:
+                scanmode_int = int(scanmode)
+            except (TypeError, ValueError):
+                return f"`scanmode` 取值无效:{scanmode}。按 OneSEC API 文档应为整数枚举值。"
+            if scanmode_int not in _THREAT_SCANMODES:
+                allowed = ", ".join(str(item) for item in sorted(_THREAT_SCANMODES))
+                return f"`scanmode` 取值无效:{scanmode_int}。按 OneSEC API 文档仅支持:{allowed}。"
+
+    if action == "threat_upgrade_bd_version_task":
+        bd_upgrade_type = params.get("bd_upgrade_type")
+        if _has_value(bd_upgrade_type):
+            try:
+                upgrade_int = int(bd_upgrade_type)
+            except (TypeError, ValueError):
+                return f"`bd_upgrade_type` 取值无效:{bd_upgrade_type}。按 OneSEC API 文档应为整数枚举值。"
+            if upgrade_int not in _THREAT_BD_UPGRADE_TYPES:
+                allowed = ", ".join(str(item) for item in sorted(_THREAT_BD_UPGRADE_TYPES))
+                return f"`bd_upgrade_type` 取值无效:{upgrade_int}。按 OneSEC API 文档仅支持:{allowed}。"
+
+    if action == "threat_update_bd_version":
+        os_platform = params.get("os_platform")
+        if _has_value(os_platform) and str(os_platform) not in _THREAT_OS_PLATFORMS:
+            allowed = ", ".join(sorted(_THREAT_OS_PLATFORMS))
+            return
f"`os_platform` 取值无效:{os_platform}。按 OneSEC API 文档仅支持:{allowed}。" + if str(os_platform) == "macos": + os_arch = params.get("os_arch") + if _has_value(os_arch) and str(os_arch) not in _THREAT_MAC_ARCHES: + allowed = ", ".join(sorted(_THREAT_MAC_ARCHES)) + return f"`os_arch` 取值无效:{os_arch}。当 `os_platform=macos` 时仅支持:{allowed}。" + + if action == "ops_query_task_page_list": + time_type = params.get("time_type") + if _has_value(time_type) and str(time_type) not in _OPS_TASK_TIME_TYPES: + allowed = ", ".join(sorted(_OPS_TASK_TIME_TYPES)) + return f"`time_type` 取值无效:{time_type}。按 OneSEC API 文档仅支持:{allowed}。" + auto = params.get("auto") + if _has_value(auto): + try: + auto_int = int(auto) + except (TypeError, ValueError): + return f"`auto` 取值无效:{auto}。按 OneSEC API 文档应为整数枚举值。" + if auto_int not in _OPS_TASK_AUTO_VALUES: + allowed = ", ".join(str(item) for item in sorted(_OPS_TASK_AUTO_VALUES)) + return f"`auto` 取值无效:{auto_int}。按 OneSEC API 文档仅支持:{allowed}。" + return None @@ -587,13 +912,38 @@ def _validate_action_params(action: str, params: dict[str, Any]) -> Optional[str time_err = _validate_time_params(action, params) if time_err: return time_err + enum_err = _validate_enum_params(action, params) + if enum_err: + return enum_err missing: list[str] = [] + unsupported: list[str] = [] if action == "dns_search_blocked_queries": missing.extend(_require_fields(params, "time_from", "time_to", "domain", "keyword")) + if { + "domain", + "keyword", + }.issubset(set(missing)) and _has_value(params.get("public_ip")): + return ( + "`dns_search_blocked_queries` 按 OneSEC API 文档要求必须传 `domain` 和 `keyword`。" + " 如果你当前只有 `public_ip` + 时间范围,且要查询最近 24 小时拦截记录," + " 请改用 `dns_get_recent_blocked_queries`。" + ) elif action == "dns_get_recent_blocked_queries": missing.extend(_require_fields(params, "time_from", "time_to")) + unsupported.extend( + _reject_present_fields( + params, + "domain", + "keyword", + "private_ip", + "threat_type", + "cur_page", + "pageitemsnum", + "page_items_num", + ) + ) elif action == "dns_search_queries": missing.extend(_require_fields(params, "time_from", "time_to")) elif action == "dns_search_threatened_endpoint": @@ -668,13 +1018,24 @@ def _validate_action_params(action: str, params: dict[str, Any]) -> Optional[str elif action == "software_query_agent_list": missing.extend(_require_fields(params, "name", "publisher")) - if not missing: - return None deduped: list[str] = [] for item in missing: if item not in deduped: deduped.append(item) - return f"Missing required parameters for {action}: {', '.join(deduped)}" + if deduped: + return f"Missing required parameters for {action}: {', '.join(deduped)}" + + deduped_unsupported: list[str] = [] + for item in unsupported: + if item not in deduped_unsupported: + deduped_unsupported.append(item) + if deduped_unsupported: + fields = ", ".join(deduped_unsupported) + return ( + f"{action} 按 OneSEC API 文档不支持以下参数: {fields}。" + " 若需要按域名或关键字筛选 DNS 拦截记录,请改用 `dns_search_blocked_queries`。" + ) + return None class ActionSpec: @@ -702,6 +1063,7 @@ def __init__( "domain", "keyword", "block_reason", + "show_unblocked_threat", "threat_level", "threat_type", "cur_page", @@ -710,7 +1072,15 @@ def __init__( "dns_get_recent_blocked_queries": ActionSpec( "POST", "/open/api/client/getRecentBlockedQueries", - lambda p: _pick(p, "time_from", "time_to", "public_ip", "block_reason", "threat_level"), + lambda p: _pick( + p, + "time_from", + "time_to", + "public_ip", + "block_reason", + "show_unblocked_threat", + "threat_level", + ), ), "dns_search_queries": ActionSpec("POST", 
"/open/api/client/searchQueries", _dns_search_queries_payload), "dns_search_threatened_endpoint": ActionSpec( @@ -1053,10 +1423,13 @@ async def unified_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResul success=False, error=f"Unknown action: {action}. Available actions: {available}", ) - validation_error = _validate_action_params(action, params) + normalized_params, normalize_error = _normalize_action_params(action, params) + if normalize_error: + return ToolResult(success=False, error=normalize_error) + validation_error = _validate_action_params(action, normalized_params) if validation_error: return ToolResult(success=False, error=validation_error) - result = await _call_onesec_api(spec.method, spec.path, spec.payload_builder(params)) + result = await _call_onesec_api(spec.method, spec.path, spec.payload_builder(normalized_params)) if result.success: metadata = dict(result.metadata or {}) metadata["api"] = action diff --git a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_dns.yaml b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_dns.yaml index 910f7aca..6f664291 100644 --- a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_dns.yaml +++ b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_dns.yaml @@ -20,14 +20,15 @@ inputSchema: - dns_search_blocked_queries 用途: 分页查询 DNS 拦截记录 必填: `time_from`、`time_to`、`domain`、`keyword` - 常用: `public_ip`、`private_ip`、`block_reason`、`threat_level`、`threat_type`、`cur_page` - 风险提示: 查询窗口通常不超过 24 小时;缺少 `domain` 或 `keyword` 会被工具层拦截 + 常用: `public_ip`、`private_ip`、`block_reason`、`show_unblocked_threat`、`threat_level`、`threat_type`、`cur_page` + 风险提示: 推荐使用 Unix 秒级时间戳;工具会兼容常见 `YYYY-MM-DD HH:MM:SS` 日期字符串。若只传 `domain`,工具会默认将 `keyword` 补成相同值;若仅按 `public_ip` + 时间范围查询最近 24 小时数据,请优先改用 `dns_get_recent_blocked_queries` 是否任务型: 否 - dns_get_recent_blocked_queries 用途: 获取近期 DNS 拦截记录(仅适用于增量同步,有严格时间限制) 必填: `time_from`、`time_to` - 常用: `public_ip`、`block_reason`、`threat_level` + 常用: `public_ip`、`block_reason`、`show_unblocked_threat`、`threat_level` ⚠️ 限制: 仅支持最近 24 小时的数据;time_from 超出 24 小时将被服务器拒绝 + ⚠️ 限制: 不支持 `domain`、`keyword`、`private_ip`、`threat_type` 和分页参数;如需按域名/关键字查询,请改用 `dns_search_blocked_queries` ⚠️ 推荐: 通用查询请优先使用 `dns_search_blocked_queries` 是否任务型: 否 - dns_search_queries @@ -86,15 +87,18 @@ inputSchema: type: integer description: > 开始时间戳,Unix 秒级时间戳。必须小于 time_to。可以使用python datetime动态计算。 + 工具会兼容常见日期时间字符串,如 `YYYY-MM-DD HH:MM:SS`。 对于 recent 系列接口,time_from 距当前时间不能超过 24 小时。 time_to: type: integer - description: 结束时间戳,Unix 秒级时间戳。必须大于 time_from。可以使用python datetime动态计算。 + description: > + 结束时间戳,Unix 秒级时间戳。必须大于 time_from。可以使用python datetime动态计算。 + 工具会兼容常见日期时间字符串,如 `YYYY-MM-DD HH:MM:SS`。 public_ip: type: array items: type: string - description: 公网 IP 列表 + description: 公网 IP 列表;当前工具会兼容单个字符串并自动包装为单元素数组 private_ip: type: string description: 私网 IP @@ -107,6 +111,9 @@ inputSchema: block_reason: type: string description: 拦截原因 + show_unblocked_threat: + type: integer + description: 是否返回未拦截的威胁记录,常见取值 `1` threat_level: type: array items: diff --git a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_edr.yaml b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_edr.yaml index f45a8ccb..a17264cd 100644 --- a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_edr.yaml +++ b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_edr.yaml @@ -24,7 +24,7 @@ inputSchema: 用途: 分页查询恶意文件(通用查询首选,适合历史回溯) 必填: 无 常用: `time_from`、`time_to`、`group_list`、`umid_list`、`cur_page`、`page_size` - 风险提示: 更适合全量回溯;如需排序扩展需结合后端实现 + 风险提示: 更适合全量回溯;按 OneSEC API 文档时间窗口最长三个月 是否任务型: 否 - edr_get_recent_threat_files 用途: 查询近期恶意文件(仅适用于增量同步,有严格时间限制) 
@@ -37,7 +37,7 @@ inputSchema: 用途: 分页查询威胁行为(通用查询首选,适合历史回溯) 必填: 无 常用: `time_from`、`time_to`、`group_list`、`threat_phase_list`、`cur_page`、`page_size` - 风险提示: 行为筛选条件较多,优先显式传时间范围 + 风险提示: 行为筛选条件较多,优先显式传时间范围;按 OneSEC API 文档时间窗口最长三个月 是否任务型: 否 - edr_get_recent_threat_activities 用途: 查询近期威胁行为(仅适用于增量同步,有严格时间限制) @@ -50,7 +50,7 @@ inputSchema: 用途: 分页查询威胁事件(通用查询首选,适合历史回溯) 必填: 无 常用: `time_from`、`time_to`、`params`、`cur_page`、`page_size` - 风险提示: 返回事件对象较复杂,后续常用于提取 `incident_id` + 风险提示: 返回事件对象较复杂,后续常用于提取 `incident_id`;按 OneSEC API 文档时间窗口最长三个月 是否任务型: 否 - edr_get_recent_incidents 用途: 查询近期威胁事件(仅适用于增量同步,有严格时间限制) @@ -63,7 +63,7 @@ inputSchema: 用途: 分页查询终端告警日志(通用查询首选,适合历史回溯) 必填: 无 常用: `time_from`、`time_to`、`sql`、`search_fields`、`cur_page`、`page_size` - 风险提示: 若返回字段过多可能导致接口耗时增加 + 风险提示: 若返回字段过多可能导致接口耗时增加;按 OneSEC API 文档时间窗口最长三个月 是否任务型: 否 - edr_get_recent_endpoint_alerts 用途: 查询近期终端告警日志(仅适用于增量同步,有严格时间限制) diff --git a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_ops.yaml b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_ops.yaml index e2f7b1f5..d7671369 100644 --- a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_ops.yaml +++ b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_ops.yaml @@ -33,13 +33,13 @@ inputSchema: 用途: 分页查询审计日志 必填: `begin_time`、`end_time` 常用: `cur_page`、`page_size`、`operate_list`、`role_list`、`group_list`、`api_access_type_list` - 风险提示: 建议限定时间范围,避免返回量过大 + 风险提示: 建议限定时间范围,避免返回量过大;按 OneSEC API 文档仅支持最近 30 天内日志 是否任务型: 否 - ops_query_task_page_list 用途: 分页查询任务列表 必填: `time_type`、`begin_time`、`end_time`、`auto` 常用: `cur_page`、`page_size`、`sort_by`、`sort_order`、`task_status_list`、`task_type_list`、`group_id` - 风险提示: 分页接口通常要求排序对象;缺少关键时间条件会被工具层拦截 + 风险提示: 分页接口通常要求排序对象;`time_type` 仅支持 `create_time/update_time`,`auto` 仅支持 `0/1` 是否任务型: 否 - ops_query_task_execute_list 用途: 分页查询任务执行进度明细 @@ -88,7 +88,7 @@ inputSchema: description: 结构化条件列表 time_type: type: string - description: 时间字段类型 + description: 时间字段类型,仅支持 `create_time`、`update_time` cur_page: type: integer default: 1 @@ -158,7 +158,7 @@ inputSchema: description: API 访问来源列表 auto: type: integer - description: 任务来源 + description: 任务来源,仅支持 `0` 人工响应、`1` 自动响应 task_type_list: type: array items: diff --git a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_threat.yaml b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_threat.yaml index d5e1e619..6fd7ca33 100644 --- a/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_threat.yaml +++ b/.flocks/plugins/tools/api/onesec_v2_8_2/onesec_threat.yaml @@ -27,7 +27,7 @@ inputSchema: 用途: 下发病毒扫描任务 必填: `agent_list`、`task_type`/`scan_type`、`scanmode`;当 `task_type=10130` 时还需 `scan_paths` 常用: `scan_paths`、`scanmode`、`task_type` - 风险提示: 高风险任务;扫描范围较大时可能影响终端性能 + 风险提示: 高风险任务;`task_type` 仅支持 `10110/10120/10130`,`scanmode` 仅支持 `1/2/3` 是否任务型: 是 - threat_stop_virus_scan 用途: 下发终止病毒扫描任务 @@ -39,13 +39,13 @@ inputSchema: 用途: 下发病毒库升级任务 必填: `agent_list`、`bd_upgrade_type` 常用: `bd_upgrade_type` - 风险提示: 写操作;会触发终端更新病毒库 + 风险提示: 写操作;会触发终端更新病毒库,`bd_upgrade_type` 仅支持 `1/2` 是否任务型: 是 - threat_update_bd_version 用途: 修改病毒库应用版本 必填: `os_platform`;当 `os_platform=macos` 时还需 `os_arch` 常用: `os_arch` - 风险提示: 写操作;平台参数不匹配会被工具层拦截 + 风险提示: 写操作;`os_platform` 仅支持 `windows/macos`,macOS 架构仅支持 `Apple Silicon/Intel Chip` 是否任务型: 否 enum: - threat_query_bd_version @@ -60,7 +60,7 @@ inputSchema: description: 终端 UMID 列表 task_type: type: integer - description: 病毒扫描任务类型 + description: 病毒扫描任务类型,支持 `10110` 快速扫描、`10120` 全盘扫描、`10130` 自定义扫描 scan_paths: type: array items: @@ -68,22 +68,22 @@ inputSchema: description: 自定义扫描路径 scanmode: type: integer - description: 扫描模式 + description: 扫描模式,支持 `1` 极速、`2` 均衡、`3` 低耗 
        scan_type:
          type: integer
          description: 兼容旧调用的病毒扫描任务类型别名
        bd_upgrade_type:
          type: integer
-         description: 病毒库升级类型
+         description: 病毒库升级类型,支持 `1` 应用版本、`2` 云端最新版本
        issue_time:
          type: integer
          description: 下发时间
        os_platform:
          type: string
-         description: 操作系统平台
+         description: 操作系统平台,仅支持 `windows`、`macos`
        os_arch:
          type: string
-         description: 系统架构
+         description: 系统架构;当 `os_platform=macos` 时仅支持 `Apple Silicon`、`Intel Chip`
      required:
        - action
    handler:
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_provider.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_provider.yaml
new file mode 100644
index 00000000..084e3bd7
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_provider.yaml
@@ -0,0 +1,52 @@
+name: sangfor_af
+service_id: sangfor_af
+version: "8.0.106"
+description: >
+  Sangfor AF (Application Firewall) v8.0.106 REST API service (latest).
+  Includes all features from v8.0.85 plus extended session management
+  (recheck, export, block), alarm notifications, CPU/memory trend curves,
+  and link probe (SLA) APIs.
+description_cn: >
+  深信服 AF 下一代防火墙 v8.0.106 REST API 服务(最新版本)。
+  在 v8.0.85 基础上新增:会话精细化管理(recheck/导出/阻断)、
+  告警通知(在册事件、服务商信息、信息模板)、CPU/内存曲线、
+  链路探测(SLA)等接口。
+auth:
+  type: custom
+  secret: sangfor_af_v8_0_106_username
+  secret_secret: sangfor_af_v8_0_106_password
+credential_fields:
+  - key: username
+    label: 管理员用户名
+    storage: secret
+    config_key: username
+    secret_id: sangfor_af_v8_0_106_username
+    input_type: text
+    required: true
+  - key: password
+    label: 管理员密码
+    storage: secret
+    config_key: password
+    secret_id: sangfor_af_v8_0_106_password
+    input_type: password
+    required: true
+  - key: base_url
+    label: 设备地址 (Base URL)
+    storage: config
+    config_key: base_url
+    input_type: url
+    default: "https://192.168.1.1"
+defaults:
+  base_url: "https://192.168.1.1"
+  timeout: 60
+  category: custom
+  product_version: "8.0.106"
+notes: |
+  深信服 AF v8.0.106 API 认证流程(同 v8.0.85):
+  1. 在 AF WebUI「系统 → 管理员账号」勾选 WEBAPI 权限。
+  2. Handler 自动用用户名/密码换取 token 并缓存(带 keepalive 自动续期)。
+  3. 后续所有请求由 Handler 自动注入 Cookie: token=<token>。
+  4. token 默认 10 分钟无操作后失效,缓存失效会自动重新登录。
+  5. 
新增:会话清除/阻断、告警通知、CPU/内存趋势曲线、链路探测等 API。 + + `verify_ssl` 由表单底部「SSL 验证」开关控制(默认关闭,与 sangfor_sip 一致)。 diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_test.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_test.yaml new file mode 100644 index 00000000..04916949 --- /dev/null +++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/_test.yaml @@ -0,0 +1,144 @@ +schema_version: 1 +provider: sangfor_af + +connectivity: + tool: sangfor_af_v106_status + params: + action: get_system_version + +fixtures: + sangfor_af_v106_auth: + - label: "Session keepalive" + label_cn: "刷新 Session 保活" + tags: [smoke] + params: + action: keepalive + assert: + success: true + + sangfor_af_v106_status: + - label: "Get system version" + label_cn: "获取系统版本信息" + tags: [smoke] + params: + action: get_system_version + assert: + success: true + + - label: "Get CPU usage" + label_cn: "获取 CPU 使用率" + tags: [smoke] + params: + action: get_cpu_usage + + - label: "Get CPU usage trend (last 60 min)" + label_cn: "获取近60分钟 CPU 使用率趋势" + tags: [smoke, trend] + params: + action: get_cpu_trend + minutes: 60 + + - label: "Get memory usage trend (last 60 min)" + label_cn: "获取近60分钟内存使用率趋势" + tags: [trend] + params: + action: get_memory_trend + minutes: 60 + + sangfor_af_v106_monitor: + - label: "Get session summary" + label_cn: "获取会话概要信息" + tags: [smoke, monitor] + params: + action: get_session_summary + assert: + success: true + + - label: "Get daily new sessions" + label_cn: "获取每日新建会话信息" + tags: [monitor] + params: + action: get_session_dailys + + - label: "Get user traffic top 10" + label_cn: "获取用户流量排行前10名" + tags: [monitor, traffic] + params: + action: get_user_traffic_rank + topNumber: 10 + + - label: "Get active sessions" + label_cn: "获取实时活跃会话列表" + tags: [monitor, session] + params: + action: get_sessions + + - label: "Get session recheck config" + label_cn: "获取 Session Recheck 配置" + tags: [monitor, config] + params: + action: get_session_recheck + + sangfor_af_v106_alarm: + - label: "Get alarm events config" + label_cn: "获取告警事件配置" + tags: [smoke, alarm] + params: + action: get_alarm_events_config + assert: + success: true + + - label: "Get alarm notifications config" + label_cn: "获取告警通知渠道配置" + tags: [alarm] + params: + action: get_alarm_notifications + + - label: "Get alarm messages (recent)" + label_cn: "获取历史告警消息列表" + tags: [alarm] + params: + action: get_alarm_messages + _length: 20 + _order: desc + + sangfor_af_v106_ops: + - label: "List blocked attacker IPs" + label_cn: "查询封锁攻击者 IP 列表" + tags: [smoke, blockip] + params: + action: get_blockip_list + assert: + success: true + + - label: "List blacklist entries" + label_cn: "查询黑名单列表" + tags: [smoke, blacklist] + params: + action: get_blackwhitelist + type: BLACK + + sangfor_af_v106_objects: + - label: "List IP groups" + label_cn: "查询 IP 地址组列表" + tags: [smoke] + params: + action: get_ipgroups + assert: + success: true + + - label: "List link probes (SLA objects)" + label_cn: "查询链路探测对象列表" + tags: [smoke, sla] + params: + action: get_link_probes + + sangfor_af_v106_network: + - label: "List routing table (all routes)" + label_cn: "查询完整路由表" + tags: [smoke] + params: + action: get_routes + routeType: ALL_ROUTE + assert: + success: true diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af.handler.py b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af.handler.py new file mode 100644 index 00000000..dbb73fdd --- /dev/null +++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af.handler.py @@ -0,0 +1,790 @@ +""" +Sangfor AF (Application Firewall) v8.0.106 
API Handler (latest version). + +Extends v8.0.85 with: + - Session recheck config (3.1.4.8/9) + - Session export (3.1.4.10) + - Block session (3.1.4.6) + - Batch clear sessions (3.1.4.7) + - CPU/Memory trend curves (8.x) + - Alarm notifications: events, providers, templates, messages + - Extended session count/summary info + - Interface statistics (enhanced) + - TOP N settings (3.4.2) + - Link probe/SLA object APIs (4.2.x) + +Authentication: same as v8.0.48/85 (session-based Cookie token). +""" +from __future__ import annotations + +import os +from typing import Any, Callable, Optional + +import aiohttp + +from flocks.config.config_writer import ConfigWriter +from flocks.tool.registry import ToolContext, ToolResult + +# ── Constants ──────────────────────────────────────────────────────────────── + +SERVICE_ID = "sangfor_af_v8_0_106" +DEFAULT_BASE_URL = "https://192.168.1.1" +DEFAULT_TIMEOUT = 60 +NAMESPACE = "public" + +API_V1 = f"/api/v1/namespaces/{NAMESPACE}" +API_BATCH = f"/api/batch/v1/namespaces/{NAMESPACE}" + +_TOKEN_CACHE: dict[str, str] = {} + + +# ── Secret / Config helpers ─────────────────────────────────────────────────── + +def _get_secret_manager(): + from flocks.security import get_secret_manager + return get_secret_manager() + + +def _resolve_ref(value: Any) -> Optional[str]: + if value is None: + return None + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return _get_secret_manager().get(value[len("{secret:"): -1]) + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:"): -1]) + return value + + +def _service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + return raw if isinstance(raw, dict) else {} + + +def _resolve_verify_ssl(raw: dict[str, Any]) -> bool: + """Read verify_ssl with the same priority as sangfor_sip / onesec: + verify_ssl > ssl_verify > custom_settings.verify_ssl > False. + AF devices commonly use self-signed certs, so default is False. + """ + value = raw.get("verify_ssl") + if value is None: + value = raw.get("ssl_verify") + if value is None: + custom = raw.get("custom_settings") + if isinstance(custom, dict): + value = custom.get("verify_ssl") + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return False + + +def _resolve_runtime_config() -> tuple[str, int, str, str, bool]: + raw = _service_config() + base_url = (_resolve_ref(raw.get("base_url")) or DEFAULT_BASE_URL).rstrip("/") + timeout = raw.get("timeout", DEFAULT_TIMEOUT) + try: + timeout = int(timeout) + except (TypeError, ValueError): + timeout = DEFAULT_TIMEOUT + + sm = _get_secret_manager() + username = ( + _resolve_ref(raw.get("username")) + or sm.get("sangfor_af_v8_0_106_username") + or os.getenv("AF_USERNAME") + ) + password = ( + _resolve_ref(raw.get("password")) + or sm.get("sangfor_af_v8_0_106_password") + or os.getenv("AF_PASSWORD") + ) + if not username or not password: + raise ValueError( + "AF API credentials not configured. " + "Please set username and password in the sangfor_af_v8_0_106 service configuration." 
+ ) + return base_url, timeout, username, password, _resolve_verify_ssl(raw) + + +# ── Session / Token management ──────────────────────────────────────────────── + +async def _login(session, base_url, username, password, verify_ssl): + url = f"{base_url}{API_V1}/login" + try: + async with session.post( + url, + json={"name": username, "password": password}, + ssl=verify_ssl, + ) as resp: + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return None, f"AF login request failed: {exc}" + code = data.get("code") + if code != 0: + return None, f"AF login failed (code={code}): {data.get('message', 'Unknown')}" + token = data.get("data", {}).get("loginResult", {}).get("token") + if not token: + return None, "AF login succeeded but no token returned" + return token, None + + +async def _get_token(session, base_url, username, password, verify_ssl): + cached = _TOKEN_CACHE.get(base_url) + if cached: + try: + async with session.get( + f"{base_url}{API_V1}/keepalive", + headers={"Cookie": f"token={cached}"}, + ssl=verify_ssl, + ) as resp: + ka = await resp.json(content_type=None) + if ka.get("code") == 0: + return cached, None + except Exception: + pass + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return None, err + _TOKEN_CACHE[base_url] = token + return token, None + + +# ── Low-level HTTP ──────────────────────────────────────────────────────────── + +def _pick(params: dict[str, Any], *keys: str) -> dict[str, Any]: + return {k: params[k] for k in keys if k in params and params[k] is not None} + + +def _af_result(action: str, payload: Any) -> ToolResult: + metadata = {"source": "Sangfor AF", "api": action, "version": "8.0.106"} + if isinstance(payload, dict): + code = payload.get("code") + if code not in (None, 0): + msg = payload.get("message", "Unknown error") + return ToolResult(success=False, error=f"AF API error (code={code}): {msg}", metadata=metadata) + return ToolResult(success=True, output=payload.get("data", payload), metadata=metadata) + return ToolResult(success=True, output=payload, metadata=metadata) + + +async def _call(method, path, params=None, json=None, action="") -> ToolResult: + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except ValueError as exc: + return ToolResult(success=False, error=str(exc)) + headers = {"Content-Type": "application/json"} + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + token, err = await _get_token(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + headers["Cookie"] = f"token={token}" + url = f"{base_url}{path}" + try: + async with session.request( + method.upper(), url, params=params, json=json, headers=headers, ssl=verify_ssl, + ) as resp: + if resp.status >= 400: + text = await resp.text() + return ToolResult(success=False, error=f"HTTP {resp.status}: {text[:500]}") + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + except Exception as exc: + return ToolResult(success=False, error=f"Unexpected error: {exc}") + return _af_result(action or path.rsplit("/", 1)[-1], data) + + +# ── Auth ────────────────────────────────────────────────────────────────────── + +async def _do_login(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + try: + base_url, timeout, username, password, verify_ssl = _resolve_runtime_config() + except 
ValueError as exc: + return ToolResult(success=False, error=str(exc)) + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + token, err = await _login(session, base_url, username, password, verify_ssl) + if err: + return ToolResult(success=False, error=err) + _TOKEN_CACHE[base_url] = token + return ToolResult(success=True, output={"token": token, "message": "Login successful"}, metadata={"source": "Sangfor AF", "api": "login", "version": "8.0.106"}) + + +async def _do_logout(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + result = await _call("POST", f"{API_V1}/logout", action="logout") + try: + base_url, *_ = _resolve_runtime_config() + _TOKEN_CACHE.pop(base_url, None) + except ValueError: + pass + return result + + +async def _do_keepalive(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/keepalive", action="keepalive") + + +# ── Objects ─────────────────────────────────────────────────────────────────── + +async def _do_get_ipgroups(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "businessType", "__nameprefix", "important", "_search", "_order", "_sortby", "addressType") + return await _call("GET", f"{API_V1}/ipgroups", params=query, action="get_ipgroups") + + +async def _do_get_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/ipgroups/{params.get('uuid', '')}", action="get_ipgroup") + + +async def _do_create_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "name", "businessType", "description", "addressType", "important", "ipRanges", "creator") + return await _call("POST", f"{API_V1}/ipgroups", json={"obj": body}, action="create_ipgroup") + + +async def _do_update_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "name", "businessType", "description", "addressType", "important", "ipRanges") + return await _call("PATCH", f"{API_V1}/ipgroups/{params.get('uuid', '')}", json={"obj": body}, action="update_ipgroup") + + +async def _do_delete_ipgroup(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("DELETE", f"{API_V1}/ipgroups/{params.get('uuid', '')}", action="delete_ipgroup") + + +async def _do_get_services(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "_order", "_sortby", "serviceType") + return await _call("GET", f"{API_V1}/services", params=query, action="get_services") + + +async def _do_get_service(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/services/{params.get('uuid', '')}", action="get_service") + + +# Link probe (new in v8.0.106 objects chapter 4.2) +async def _do_get_link_probes(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "_order", "_sortby") + return await _call("GET", f"{API_V1}/linkprobes", params=query or None, action="get_link_probes") + + +async def _do_get_link_probe(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/linkprobes/{params.get('uuid', '')}", action="get_link_probe") + + +# ── Monitoring ──────────────────────────────────────────────────────────────── + +async def _do_get_user_traffic_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "topNumber", "vsys", "line", "applicationType", 
"filterObject") + return await _call("POST", f"{API_V1}/topusertraffics", params={"_method": "GET"}, json=body or {}, action="get_user_traffic_rank") + + +async def _do_get_ip_traffic_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # /iptraffics is not a paged endpoint; _start/_length must not be sent. + # topNumber must be int — AF returns code=1001 for any non-int value. + query = _pick(params, "vsys", "topNumber", "unit", "minutes") + if "topNumber" in query: + try: + query["topNumber"] = int(query["topNumber"]) + except (TypeError, ValueError): + pass + return await _call("GET", f"{API_V1}/iptraffics", params=query or None, action="get_ip_traffic_trend") + + +async def _do_get_app_traffic_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "line", "topNumber") + if "topNumber" in query: + try: + query["topNumber"] = int(query["topNumber"]) + except (TypeError, ValueError): + pass + return await _call("GET", f"{API_V1}/apptrafficrank", params=query or None, action="get_app_traffic_rank") + + +async def _do_get_session_dailys(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "ip") + return await _call("GET", f"{API_V1}/sessiondailys", params=query or None, action="get_session_dailys") + + +async def _do_get_session_details(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # Endpoint needs explicit filters; without them AF returns 1004 "没有返回值". + query = _pick(params, "vsys", "srcIP", "dstIP", "protocol", "srcPort", "dstPort") + return await _call("GET", f"{API_V1}/sessiondetails", params=query or None, action="get_session_details") + + +async def _do_get_session_count_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys", "minutes") + return await _call("GET", f"{API_V1}/sessioncounttrend", params=query or None, action="get_session_count_trend") + + +async def _do_get_session_src_ip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # srcIP is required; AF returns 1004 "没有返回值" when omitted. + query = _pick(params, "vsys", "srcIP") + return await _call("GET", f"{API_V1}/sessionsrcip", params=query or None, action="get_session_src_ip") + + +async def _do_get_session_count_rank(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "vsys", "topNumber") + return await _call("GET", f"{API_V1}/sessioncountrank", params=query or None, action="get_session_count_rank") + + +async def _do_get_session_summary(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys") + return await _call("GET", f"{API_V1}/sessionsummary", params=query or None, action="get_session_summary") + + +async def _do_get_sessions(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "_start", "_length", "vsys", "srcIP", "dstIP", "protocol", "srcPort", "dstPort") + # AF8.0.x requires POST + ?_method=GET for /sessions; plain GET returns 1002. 
+ return await _call("POST", f"{API_V1}/sessions", params={"_method": "GET"}, json=body or {}, action="get_sessions") + + +async def _do_get_vsys_session_summary(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "vsys") + return await _call("GET", f"{API_V1}/vsyssessionsummary", params=query or None, action="get_vsys_session_summary") + + +# Session management actions (new in v8.0.106) +async def _do_clear_session(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "srcIP", "dstIP", "srcPort", "dstPort", "protocol", "vsys") + return await _call("DELETE", f"{API_V1}/sessions", json={"obj": body}, action="clear_session") + + +async def _do_block_session(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "srcIP", "dstIP", "srcPort", "dstPort", "protocol", "vsys") + return await _call("POST", f"{API_V1}/sessions/block", json={"obj": body}, action="block_session") + + +async def _do_batch_clear_sessions(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + items = params.get("items", []) + return await _call("POST", f"{API_BATCH}/sessions", params={"_method": "DELETE"}, json=items, action="batch_clear_sessions") + + +async def _do_get_session_recheck(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/sessionrecheck", action="get_session_recheck") + + +async def _do_set_session_recheck(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "enable", "interval") + return await _call("PUT", f"{API_V1}/sessionrecheck", json={"obj": body}, action="set_session_recheck") + + +# Statistics +async def _do_get_packet_drop_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/mbufdroppointstatistics", params=query or None, action="get_packet_drop_stats") + + +async def _do_clear_packet_drop_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("DELETE", f"{API_V1}/mbufdroppointstatistics", action="clear_packet_drop_stats") + + +async def _do_get_mbuf_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/mbufstatistics", action="get_mbuf_stats") + + +async def _do_get_hash_table_stats(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/hashtablestatistics", params=query or None, action="get_hash_table_stats") + + +async def _do_get_monitor_ips(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length") + return await _call("GET", f"{API_V1}/monitorips", params=query or None, action="get_monitor_ips") + + +# ── Alarm/notification actions (new in v8.0.106, section 3.4.3) ────────────── + +async def _do_get_alarm_events_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/alarm/events", action="get_alarm_events_config") + + +async def _do_get_alarm_notifications(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/alarm/notifications", action="get_alarm_notifications") + + +async def _do_get_alarm_messages(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "_order", "_sortby") + return await _call("GET", f"{API_V1}/alarm/messages", params=query or None, action="get_alarm_messages") + + 
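+# Response envelope note (shape taken from _af_result above; the payload
+# values below are hypothetical examples):
+#   {"code": 0, "message": "ok", "data": {...}}   -> success, "data" is returned
+#   {"code": 1004, "message": "没有返回值"}        -> surfaced as an AF API error
+
+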
+async def _do_get_topn_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/topnconfig", action="get_topn_config") + + +# ── Operations center ───────────────────────────────────────────────────────── + +async def _do_get_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "type", "_start", "_length", "_search", "_order", "description") + return await _call("GET", f"{API_V1}/whiteblacklist", params=query, action="get_blackwhitelist") + + +async def _do_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + body = _pick(params, "url", "type", "enable", "description", "domain") + return await _call("POST", f"{API_V1}/whiteblacklist", json={"obj": body}, action="add_blackwhitelist") + + +async def _do_batch_add_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/whiteblacklist", json=params.get("items", []), action="batch_add_blackwhitelist") + + +async def _do_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + url_param = params.get("url", "") + list_type = params.get("type", "") + query = {"type": list_type} if list_type else None + return await _call("DELETE", f"{API_V1}/whiteblacklist/{url_param}", params=query, action="delete_blackwhitelist") + + +async def _do_batch_delete_blackwhitelist(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/whiteblacklist", params={"_method": "DELETE"}, json=params.get("items", []), action="batch_delete_blackwhitelist") + + +async def _do_get_blockip_list(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_sortby", "_order", "creator", "fuzzyIP") + return await _call("GET", f"{API_V1}/blockip", params=query, action="get_blockip_list") + + +async def _do_batch_add_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "aifwType") + return await _call("POST", f"{API_BATCH}/blockip", params=query or None, json=params.get("items", []), action="batch_add_blockip") + + +async def _do_batch_delete_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("POST", f"{API_BATCH}/blockip", params={"_method": "DELETE"}, json=params.get("items", []), action="batch_delete_blockip") + + +async def _do_clear_blockip(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "creator") + return await _call("DELETE", f"{API_V1}/blockip", params=query or None, action="clear_blockip") + + +async def _do_get_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/blockip/autoconfig", action="get_blockip_auto_config") + + +async def _do_set_blockip_auto_config(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("PUT", f"{API_V1}/blockip/autoconfig", json={"obj": _pick(params, "blockTime")}, action="set_blockip_auto_config") + + +# ── Status ──────────────────────────────────────────────────────────────────── + +async def _do_get_memory_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/memoryusage", action="get_memory_usage") + + +async def _do_get_memory_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "minutes") + return await _call("GET", f"{API_V1}/memoryusagetrend", params=query or None, 
action="get_memory_trend") + + +async def _do_get_cpu_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/cpuusage", action="get_cpu_usage") + + +async def _do_get_cpu_trend(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "minutes") + return await _call("GET", f"{API_V1}/cpuusagetrend", params=query or None, action="get_cpu_trend") + + +async def _do_get_disk_usage(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/diskusage", action="get_disk_usage") + + +async def _do_get_system_version(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "filter") + return await _call("GET", f"{API_V1}/systemversion", params=query or None, action="get_system_version") + + +async def _do_get_interface_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + # AF8.0.x: /interfacestatus returns 1002; use /interfaces (list) or + # /interfaces/status?interfaceName= (single interface query). + iface = params.get("interfaceNames") or params.get("interfaceName") or "" + if iface: + return await _call( + "GET", f"{API_V1}/interfaces/status", + params={"interfaceName": iface}, + action="get_interface_status", + ) + return await _call("GET", f"{API_V1}/interfaces", action="get_interface_status") + + +async def _do_get_interface_throughput(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "ifname", "minutes") + return await _call("GET", f"{API_V1}/interfacethroughput", params=query or None, action="get_interface_throughput") + + +async def _do_get_runtime_status(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/runtimestatus", action="get_runtime_status") + + +async def _do_get_current_time(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/currenttime", action="get_current_time") + + +# ── Network ─────────────────────────────────────────────────────────────────── + +async def _do_get_routes(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes", params=query or None, action="get_routes") + + +async def _do_get_routes_ipv6(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "routeType", "_search") + return await _call("GET", f"{API_V1}/routes/ipv6", params=query or None, action="get_routes_ipv6") + + +# ── System ──────────────────────────────────────────────────────────────────── + +async def _do_get_accounts(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + query = _pick(params, "_start", "_length", "_search", "enable") + return await _call("GET", f"{API_V1}/account", params=query or None, action="get_accounts") + + +async def _do_get_account(ctx: ToolContext, **params: Any) -> ToolResult: + del ctx + return await _call("GET", f"{API_V1}/account/{params.get('name', '')}", action="get_account") + + +# ── Action dispatch ─────────────────────────────────────────────────────────── + +_ACTION_MAP: dict[str, Callable] = { + # Auth + "login": _do_login, + "logout": _do_logout, + "keepalive": _do_keepalive, + # Objects + "get_ipgroups": _do_get_ipgroups, + "get_ipgroup": _do_get_ipgroup, + "create_ipgroup": _do_create_ipgroup, + "update_ipgroup": _do_update_ipgroup, + "delete_ipgroup": _do_delete_ipgroup, + "get_services": _do_get_services, + 
"get_service": _do_get_service, + "get_link_probes": _do_get_link_probes, + "get_link_probe": _do_get_link_probe, + # Monitoring + "get_user_traffic_rank": _do_get_user_traffic_rank, + "get_ip_traffic_trend": _do_get_ip_traffic_trend, + "get_app_traffic_rank": _do_get_app_traffic_rank, + "get_session_dailys": _do_get_session_dailys, + "get_session_details": _do_get_session_details, + "get_session_count_trend": _do_get_session_count_trend, + "get_session_src_ip": _do_get_session_src_ip, + "get_session_count_rank": _do_get_session_count_rank, + "get_session_summary": _do_get_session_summary, + "get_vsys_session_summary": _do_get_vsys_session_summary, + "get_sessions": _do_get_sessions, + "get_monitor_ips": _do_get_monitor_ips, + # Session management + "clear_session": _do_clear_session, + "block_session": _do_block_session, + "batch_clear_sessions": _do_batch_clear_sessions, + "get_session_recheck": _do_get_session_recheck, + "set_session_recheck": _do_set_session_recheck, + # Statistics + "get_packet_drop_stats": _do_get_packet_drop_stats, + "clear_packet_drop_stats": _do_clear_packet_drop_stats, + "get_mbuf_stats": _do_get_mbuf_stats, + "get_hash_table_stats": _do_get_hash_table_stats, + # Alarm/notification + "get_alarm_events_config": _do_get_alarm_events_config, + "get_alarm_notifications": _do_get_alarm_notifications, + "get_alarm_messages": _do_get_alarm_messages, + "get_topn_config": _do_get_topn_config, + # Operations center + "get_blackwhitelist": _do_get_blackwhitelist, + "add_blackwhitelist": _do_add_blackwhitelist, + "batch_add_blackwhitelist": _do_batch_add_blackwhitelist, + "delete_blackwhitelist": _do_delete_blackwhitelist, + "batch_delete_blackwhitelist": _do_batch_delete_blackwhitelist, + "get_blockip_list": _do_get_blockip_list, + "batch_add_blockip": _do_batch_add_blockip, + "batch_delete_blockip": _do_batch_delete_blockip, + "clear_blockip": _do_clear_blockip, + "get_blockip_auto_config": _do_get_blockip_auto_config, + "set_blockip_auto_config": _do_set_blockip_auto_config, + # Status + "get_memory_usage": _do_get_memory_usage, + "get_memory_trend": _do_get_memory_trend, + "get_cpu_usage": _do_get_cpu_usage, + "get_cpu_trend": _do_get_cpu_trend, + "get_disk_usage": _do_get_disk_usage, + "get_system_version": _do_get_system_version, + "get_interface_status": _do_get_interface_status, + "get_interface_throughput": _do_get_interface_throughput, + "get_runtime_status": _do_get_runtime_status, + "get_current_time": _do_get_current_time, + # Network + "get_routes": _do_get_routes, + "get_routes_ipv6": _do_get_routes_ipv6, + # System + "get_accounts": _do_get_accounts, + "get_account": _do_get_account, +} + +GROUP_ACTIONS: dict[str, set[str]] = { + "auth": {"login", "logout", "keepalive"}, + "objects": { + "get_ipgroups", "get_ipgroup", "create_ipgroup", "update_ipgroup", "delete_ipgroup", + "get_services", "get_service", + "get_link_probes", "get_link_probe", + }, + "monitor": { + "get_user_traffic_rank", "get_ip_traffic_trend", "get_app_traffic_rank", + "get_session_dailys", "get_session_details", "get_session_count_trend", + "get_session_src_ip", "get_session_count_rank", "get_session_summary", + "get_vsys_session_summary", "get_sessions", "get_monitor_ips", + "clear_session", "block_session", "batch_clear_sessions", + "get_session_recheck", "set_session_recheck", + "get_packet_drop_stats", "clear_packet_drop_stats", + "get_mbuf_stats", "get_hash_table_stats", + }, + "alarm": { + "get_alarm_events_config", "get_alarm_notifications", + "get_alarm_messages", "get_topn_config", 
+ }, + "ops": { + "get_blackwhitelist", "add_blackwhitelist", "batch_add_blackwhitelist", + "delete_blackwhitelist", "batch_delete_blackwhitelist", + "get_blockip_list", "batch_add_blockip", "batch_delete_blockip", + "clear_blockip", "get_blockip_auto_config", "set_blockip_auto_config", + }, + "status": { + "get_memory_usage", "get_memory_trend", "get_cpu_usage", "get_cpu_trend", + "get_disk_usage", "get_system_version", "get_interface_status", + "get_interface_throughput", "get_runtime_status", "get_current_time", + }, + "network": {"get_routes", "get_routes_ipv6"}, + "system": {"get_accounts", "get_account"}, +} + +_CONNECTIVITY_TEST_ACTIONS: dict[str, str] = { + "auth": "keepalive", + "objects": "get_ipgroups", + "monitor": "get_session_summary", + "alarm": "get_alarm_notifications", + "ops": "get_blackwhitelist", + "status": "get_system_version", + "network": "get_routes", + "system": "get_accounts", +} + + +async def unified_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + handler = _ACTION_MAP.get(action) + if handler is None: + available = ", ".join(sorted(_ACTION_MAP)) + return ToolResult(success=False, error=f"Unknown action: {action}. Available: {available}") + return await handler(ctx, **params) + + +async def _dispatch_group(ctx: ToolContext, group: str, action: str, **params: Any) -> ToolResult: + if action == "test": + return await unified_ops(ctx, action=_CONNECTIVITY_TEST_ACTIONS.get(group, "get_system_version"), **params) + if action not in GROUP_ACTIONS[group]: + available = ", ".join(sorted(GROUP_ACTIONS[group])) + return ToolResult(success=False, error=f"Unsupported {group} action: {action}. Available: {available}") + return await unified_ops(ctx, action=action, **params) + + +async def auth(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "auth", action, **params) + + +async def objects(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "objects", action, **params) + + +async def monitor(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "monitor", action, **params) + + +async def alarm(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "alarm", action, **params) + + +async def ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "ops", action, **params) + + +async def status(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "status", action, **params) + + +async def network(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "network", action, **params) + + +async def system(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "system", action, **params) + + +def _make_action_function(action: str): + async def _tool(ctx: ToolContext, **kwargs: Any) -> ToolResult: + return await unified_ops(ctx, action=action, **kwargs) + _tool.__name__ = action + return _tool + + +for _action_name in _ACTION_MAP: + globals()[_action_name] = _make_action_function(_action_name) + +del _action_name diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_alarm.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_alarm.yaml new file mode 100644 index 00000000..24d98cb5 --- /dev/null +++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_alarm.yaml @@ -0,0 +1,70 
@@
+name: sangfor_af_v106_alarm
+description: >
+  Sangfor AF v8.0.106 alarm and notification tool. Query alarm event
+  configurations, notification channels, registered events, provider
+  information, message templates, and alarm message logs.
+  These APIs are available from v8.0.106 onwards.
+description_cn: >
+  深信服 AF v8.0.106 告警通知工具(v8.0.85 及以下不含此功能)。通过 `action` 参数
+  查询告警事件配置、告警通知渠道、在册事件、服务商信息、信息模板及告警消息日志。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        告警通知动作名,可选值:
+        - get_alarm_events_config
+          用途: 获取告警事件配置(触发条件、阈值等)
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_alarm_notifications
+          用途: 获取告警通知渠道配置(邮件/短信/企业微信等)
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_alarm_messages
+          用途: 获取历史告警消息列表(已触发的告警)
+          必填: 无
+          常用: _start、_length、_search、_order
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_topn_config
+          用途: 获取 TOP N 监控配置(流量排行展示数量等)
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+      enum:
+        - get_alarm_events_config
+        - get_alarm_notifications
+        - get_alarm_messages
+        - get_topn_config
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量
+    _search:
+      type: string
+      description: 模糊搜索关键字
+    _order:
+      type: string
+      description: "排序方向:asc 或 desc"
+      enum:
+        - asc
+        - desc
+    _sortby:
+      type: string
+      description: 排序字段名
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: alarm
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_auth.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_auth.yaml
new file mode 100644
index 00000000..3e8e7f5f
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_auth.yaml
@@ -0,0 +1,44 @@
+name: sangfor_af_v106_auth
+description: >
+  Sangfor AF v8.0.106 authentication tool. Use the `action` parameter to
+  log in, log out, or keep the session alive. The token is cached
+  automatically after a successful login.
+description_cn: >
+  深信服 AF v8.0.106 认证工具。通过 `action` 参数调用登录、注销或 token 保活接口。
+  登录成功后 token 会自动缓存,后续调用无需手动传 token。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        认证动作名,可选值:
+        - login
+          用途: 登录设备,获取 session token(token 自动缓存)
+          必填: 无(用户名/密码从服务配置读取)
+          风险提示: 只读认证接口
+          是否任务型: 否
+        - logout
+          用途: 注销当前登录 session,清除 token 缓存
+          必填: 无
+          风险提示: 写操作,注销后需重新登录
+          是否任务型: 否
+        - keepalive
+          用途: 刷新 token 超时计时器,保持 session 活跃
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+      enum:
+        - login
+        - logout
+        - keepalive
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: auth
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_monitor.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_monitor.yaml
new file mode 100644
index 00000000..4f4a0fd9
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_monitor.yaml
@@ -0,0 +1,240 @@
+name: sangfor_af_v106_monitor
+description: >
+  Sangfor AF v8.0.106 monitoring tool. Provides real-time and historical
+  session data, traffic rankings, network statistics, packet diagnostics,
+  and session management actions (clear, block, batch-clear, recheck config).
+  Monitoring APIs require v8.0.85+; session management actions require v8.0.106+.
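+# Usage sketch (hypothetical values): every action below is dispatched through
+# the shared handler's `monitor` entry point, e.g.
+#   monitor(action="get_sessions", srcIP="10.0.0.8", protocol="TCP", _length=50)
+# where srcIP/protocol/_length are the optional filters defined further down.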
+description_cn: >
+  深信服 AF v8.0.106 监控工具。通过 `action` 参数查询实时/历史会话数据、
+  流量排行、网络统计及报文诊断信息,并支持会话管理(清除/阻断/批量清除/recheck配置)。
+  监控查询接口需要 v8.0.85+,会话管理接口需要 v8.0.106+。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        监控动作名,可选值:
+
+        ## 流量排行
+        - get_user_traffic_rank
+          用途: 获取用户流量排行(Top N 用户)
+          必填: 无
+          常用: topNumber(前N名,默认10)、vsys、line、applicationType
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_ip_traffic_trend
+          用途: 获取 IP 流量趋势曲线(指定前5或10名IP)
+          必填: 无
+          常用: topNumber、vsys、unit、minutes
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_app_traffic_rank
+          用途: 获取应用流量排行(Top N 应用)
+          必填: 无
+          常用: topNumber、vsys、line
+          风险提示: 只读接口
+          是否任务型: 否
+
+        ## 会话排行与统计
+        - get_session_dailys
+          用途: 获取每日新建会话信息
+          必填: 无
+          常用: vsys、ip、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_details
+          用途: 获取会话详情列表(含五元组信息)
+          必填: 无(但需至少提供一个过滤条件,否则设备返回 1004)
+          常用: vsys、srcIP、dstIP、protocol、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_count_trend
+          用途: 获取会话数量趋势折线图数据
+          必填: 无
+          常用: vsys、minutes(最近N分钟,默认60)
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_src_ip
+          用途: 获取指定源IP的会话详情(按目的IP分组)
+          必填: srcIP(缺省时设备返回 1004)
+          常用: vsys、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_count_rank
+          用途: 获取会话数量排行(Top N 源IP)
+          必填: 无
+          常用: topNumber、vsys
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_session_summary
+          用途: 获取会话概要信息(总数、协议分布等)
+          必填: 无
+          常用: vsys
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_monitor_ips
+          用途: 获取配置中心监听列表IP范围
+          必填: 无
+          常用: _start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_sessions
+          用途: 获取实时会话列表(当前活跃连接)
+          必填: 无
+          常用: vsys、srcIP、dstIP、protocol、srcPort、dstPort、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+
+        ## 统计与诊断
+        - get_packet_drop_stats
+          用途: 获取 mbuf 丢包点统计信息列表
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - clear_packet_drop_stats
+          用途: 清除后台丢包统计信息
+          必填: 无
+          风险提示: 写操作,清除统计数据不可恢复
+          是否任务型: 否
+        - get_mbuf_stats
+          用途: 获取 mbuf 内存统计信息
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - get_hash_table_stats
+          用途: 获取哈希表统计列表
+          必填: 无
+          常用: _start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+
+        ## 会话管理(v8.0.106 新增)
+        - clear_session
+          用途: 清除指定条件的会话
+          必填: 无(不填则清除全部)
+          常用: srcIP、dstIP、protocol
+          风险提示: 高风险写操作;会强制断开匹配的会话连接
+          是否任务型: 否
+        - block_session
+          用途: 阻断指定会话(源IP发起的连接)
+          必填: srcIP
+          常用: dstIP、srcPort、dstPort、protocol
+          风险提示: 高风险写操作;立即中断指定会话
+          是否任务型: 否
+        - batch_clear_sessions
+          用途: 批量清除会话
+          必填: items(数组,每项含会话筛选条件)
+          风险提示: 高风险写操作,批量断开连接
+          是否任务型: 否
+        - get_session_recheck
+          用途: 获取 session recheck 配置信息
+          必填: 无
+          风险提示: 只读接口
+          是否任务型: 否
+        - set_session_recheck
+          用途: 设置 session recheck 配置(定期检查策略变更)
+          必填: 无
+          常用: enable(true/false)、interval(秒)
+          风险提示: 写操作,影响策略命中重检频率
+          是否任务型: 否
+        - get_vsys_session_summary
+          用途: 获取虚拟系统(vsys)会话概要信息
+          必填: 无
+          常用: vsys、_start、_length
+          风险提示: 只读接口
+          是否任务型: 否
+      enum:
+        - get_user_traffic_rank
+        - get_ip_traffic_trend
+        - get_app_traffic_rank
+        - get_session_dailys
+        - get_session_details
+        - get_session_count_trend
+        - get_session_src_ip
+        - get_session_count_rank
+        - get_session_summary
+        - get_monitor_ips
+        - get_sessions
+        - get_packet_drop_stats
+        - clear_packet_drop_stats
+        - get_mbuf_stats
+        - get_hash_table_stats
+        - clear_session
+        - block_session
+        - batch_clear_sessions
+        - get_session_recheck
+        - set_session_recheck
+        - get_vsys_session_summary
+
+    enable:
+      type: boolean
+      description: 是否启用(用于 set_session_recheck)
+    interval:
+      type: integer
+      description: 检查间隔(秒,用于 set_session_recheck)
+    items:
+      type: array
+      items:
+        type: object
+      description: 批量操作的条目数组(用于 batch_clear_sessions)
+
+    topNumber:
+      type: integer
+      description: 排行榜取前N名(如5或10)
+    vsys:
+      type: string
+      description: 虚拟系统名称(通常为 public,可省略)
+    line:
+      type: integer
+      description: "线路编号过滤,0=全部(范围0-256)"
+    applicationType:
+      type: array
+      items:
+        type: string
+      description: 应用类型过滤列表
+    filterObject:
+      type: object
+      description: >
+        用户流量排行过滤对象:
+        objectType=GROUP/USER/IP,对应 groups/users/ip 数组
+    unit:
+      type: string
+      description: "流量单位(如 bps, Kbps, Mbps)"
+    minutes:
+      type: integer
+      description: 查询最近N分钟的数据(默认60)
+    ip:
+      type: string
+      description: IP地址过滤(格式:IPv4/IPv6)
+    srcIP:
+      type: string
+      description: 源IP地址过滤(格式:IPv4/IPv6)
+    dstIP:
+      type: string
+      description: 目的IP地址过滤
+    protocol:
+      type: string
+      description: "协议过滤:TCP/UDP/ICMP/OTHER"
+    srcPort:
+      type: integer
+      description: 源端口过滤(0-65535)
+    dstPort:
+      type: integer
+      description: 目的端口过滤(0-65535)
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: monitor
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_network.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_network.yaml
new file mode 100644
index 00000000..0d7f31d1
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_network.yaml
@@ -0,0 +1,65 @@
+name: sangfor_af_v106_network
+description: >
+  Sangfor AF v8.0.106 network tool. Query routing tables (IPv4 and IPv6)
+  and network-related status information.
+description_cn: >
+  深信服 AF v8.0.106 网络工具。通过 `action` 参数查询路由表(IPv4/IPv6)
+  及网络相关状态信息。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        网络查询动作名,可选值:
+        - get_routes
+          用途: 获取后台 IPv4 路由信息列表
+          必填: 无
+          常用: routeType(ALL_ROUTE/STATIC_ROUTE/DIRECT_ROUTE 等)、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_routes_ipv6
+          用途: 获取后台 IPv6 路由信息列表
+          必填: 无
+          常用: routeType、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_routes
+        - get_routes_ipv6
+    routeType:
+      type: string
+      description: >
+        路由类型过滤:ALL_ROUTE=所有路由,STATIC_ROUTE=静态路由,
+        DIRECT_ROUTE=直连路由,OSPF_ROUTE=OSPF路由,RIP_ROUTE=RIP路由,
+        VPN_ROUTE=VPN路由,SSL_VPN_ROUTE=SSL VPN路由,
+        IBGP_ROUTE=IBGP路由,EBGP_ROUTE=EBGP路由
+      enum:
+        - ALL_ROUTE
+        - STATIC_ROUTE
+        - DIRECT_ROUTE
+        - OSPF_ROUTE
+        - RIP_ROUTE
+        - VPN_ROUTE
+        - SSL_VPN_ROUTE
+        - IBGP_ROUTE
+        - EBGP_ROUTE
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    _search:
+      type: string
+      description: 模糊搜索关键字
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: network
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_objects.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_objects.yaml
new file mode 100644
index 00000000..f901ca8a
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_objects.yaml
@@ -0,0 +1,152 @@
+name: sangfor_af_v106_objects
+description: >
+  Sangfor AF v8.0.106 objects management tool. Query, create, update, and
+  delete network IP group objects and services (protocol/port definitions)
+  used in firewall policies.
+description_cn: >
+  深信服 AF v8.0.106 对象管理工具。通过 `action` 参数查询、创建、修改和删除
+  IP 地址组对象及服务对象(协议/端口定义),这些对象被防火墙策略引用。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        对象管理动作名,可选值:
+
+        ## IP 地址组
+        - get_ipgroups
+          用途: 查询符合条件的 IP 地址组列表
+          必填: 无
+          常用: _start、_length、businessType、__nameprefix、important、_search
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_ipgroup
+          用途: 获取单个 IP 地址组详情
+          必填: uuid
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - create_ipgroup
+          用途: 创建新的 IP 地址组
+          必填: name、businessType
+          常用: ipRanges、addressType、description、important
+          风险提示: 写操作;创建后可被防火墙策略引用
+          是否任务型: 否
+        - update_ipgroup
+          用途: 增量更新(PATCH)指定 IP 地址组
+          必填: uuid
+          常用: name、ipRanges、description
+          风险提示: 写操作;修改 IP 组会影响引用该组的所有策略
+          是否任务型: 否
+        - delete_ipgroup
+          用途: 删除指定 IP 地址组
+          必填: uuid
+          风险提示: 高风险写操作;如有策略引用该组将删除失败
+          是否任务型: 否
+
+        ## 服务对象
+        - get_services
+          用途: 查询服务或服务组列表(预定义或自定义)
+          必填: 无
+          常用: _start、_length、_search、serviceType
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_service
+          用途: 获取单个服务或服务组详情
+          必填: uuid
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_ipgroups
+        - get_ipgroup
+        - create_ipgroup
+        - update_ipgroup
+        - delete_ipgroup
+        - get_services
+        - get_service
+
+    uuid:
+      type: string
+      description: IP地址组或服务对象的唯一标识符(32字符UUID)
+    name:
+      type: string
+      description: 对象名称(最大95字符)
+    businessType:
+      type: string
+      description: >
+        IP地址组业务类型:IP=IP地址,ADDRGROUP=地址组,
+        USER=用户地址,BUSINESS=业务地址
+      enum:
+        - IP
+        - ADDRGROUP
+        - USER
+        - BUSINESS
+    addressType:
+      type: string
+      description: "IP协议版本:IPV4 或 IPV6"
+      enum:
+        - IPV4
+        - IPV6
+    important:
+      type: string
+      description: "重要级别:COMMON=普通,CORE=核心"
+      enum:
+        - COMMON
+        - CORE
+    ipRanges:
+      type: array
+      items:
+        type: object
+        properties:
+          start:
+            type: string
+            description: IP范围起始地址(如 192.168.1.1)
+          end:
+            type: string
+            description: IP范围结束地址(如 192.168.1.254)
+      description: IP地址范围列表
+    description:
+      type: string
+      description: 对象描述(最大95字符)
+    creator:
+      type: string
+      description: 创建者名称
+    serviceType:
+      type: string
+      description: "服务类型过滤:SERVICE=单个服务,SERVICEGROUP=服务组"
+      enum:
+        - SERVICE
+        - SERVICEGROUP
+
+    # Pagination
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    __nameprefix:
+      type: string
+      description: 按名称前缀过滤(最大95字符)
+    _search:
+      type: string
+      description: 模糊搜索关键字(最大95字符)
+    _order:
+      type: string
+      description: "排序方向:asc 或 desc"
+      enum:
+        - asc
+        - desc
+    _sortby:
+      type: string
+      description: 排序字段名
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: objects
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_ops.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_ops.yaml
new file mode 100644
index 00000000..b63b3a6d
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_ops.yaml
@@ -0,0 +1,165 @@
+name: sangfor_af_v106_ops
+description: >
+  Sangfor AF v8.0.106 operations center tool. Manages blacklist/whitelist
+  entries (IPs, domains, URLs) and blocked attacker IPs via the `action`
+  parameter. These are the key security triage actions for SOC workflows.
+description_cn: >
+  深信服 AF v8.0.106 运营中心工具。通过 `action` 参数管理黑白名单(IP/域名/URL)
+  和封锁攻击者 IP。是 SOC 安全处置的核心接口。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        运营中心动作名,可选值:
+
+        ## 黑白名单管理
+        - get_blackwhitelist
+          用途: 查询黑白名单列表(IP/域名/URL)
+          必填: 无
+          常用: type(BLACK/WHITE)、_start、_length
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - add_blackwhitelist
+          用途: 添加单条黑白名单
+          必填: url(IP/域名/URL)、type(BLACK/WHITE)
+          常用: enable、description、domain(0=IP,1=域名,2=URL)
+          风险提示: 写操作;添加黑名单会拦截对应流量
+          是否任务型: 否
+        - batch_add_blackwhitelist
+          用途: 批量添加黑白名单
+          必填: items(数组,每项含 url/type 字段)
+          风险提示: 写操作,批量添加黑名单影响面大
+          是否任务型: 否
+        - delete_blackwhitelist
+          用途: 删除单条黑白名单
+          必填: url(条目的 IP/域名/URL)
+          常用: type(BLACK/WHITE)
+          风险提示: 写操作,删除白名单可能导致误拦截
+          是否任务型: 否
+        - batch_delete_blackwhitelist
+          用途: 批量删除黑白名单
+          必填: items(数组,每项含 url 字段)
+          风险提示: 写操作,批量删除影响面大
+          是否任务型: 否
+
+        ## 封锁攻击者 IP
+        - get_blockip_list
+          用途: 查询当前封锁攻击者 IP 列表
+          必填: 无
+          常用: _start、_length、fuzzyIP(模糊搜索)、creator(AF/SIP)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - batch_add_blockip
+          用途: 批量封锁攻击者 IP
+          必填: items(数组,每项含 srcIP、dstIP 等字段)
+          常用: aifwType(MANUAL/AUTO)
+          风险提示: 高风险写操作;封锁 IP 会拦截其所有流量
+          是否任务型: 否
+        - batch_delete_blockip
+          用途: 批量解封攻击者 IP
+          必填: items(数组,每项含 srcIP、dstIP 等字段)
+          风险提示: 写操作,解封恶意 IP 存在安全风险
+          是否任务型: 否
+        - clear_blockip
+          用途: 清空封锁攻击者 IP 列表
+          必填: 无
+          常用: creator(AF/SIP,指定清除哪类封锁)
+          风险提示: 高风险写操作;会清除所有封锁 IP
+          是否任务型: 否
+        - get_blockip_auto_config
+          用途: 获取自动封锁攻击者时长配置
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - set_blockip_auto_config
+          用途: 修改自动封锁攻击者时长
+          必填: blockTime(封锁时长,单位秒)
+          风险提示: 写操作,影响自动封锁策略
+          是否任务型: 否
+      enum:
+        - get_blackwhitelist
+        - add_blackwhitelist
+        - batch_add_blackwhitelist
+        - delete_blackwhitelist
+        - batch_delete_blackwhitelist
+        - get_blockip_list
+        - batch_add_blockip
+        - batch_delete_blockip
+        - clear_blockip
+        - get_blockip_auto_config
+        - set_blockip_auto_config
+
+    # Blacklist/whitelist params
+    url:
+      type: string
+      description: IP地址、域名或URL(黑白名单条目值)
+    type:
+      type: string
+      description: "名单类型:BLACK(黑名单)或 WHITE(白名单)"
+      enum:
+        - BLACK
+        - WHITE
+    enable:
+      type: boolean
+      description: 是否启用该条目,默认 true
+    description:
+      type: string
+      description: 条目描述信息(最大95字符)
+    domain:
+      type: integer
+      description: "条目类型:0=IP地址,1=域名,2=URL"
+      enum: [0, 1, 2]
+    items:
+      type: array
+      items:
+        type: object
+      description: 批量操作时的条目数组,每项至少包含 url(黑白名单)或 srcIP/dstIP(封锁IP)
+
+    # Block IP params
+    fuzzyIP:
+      type: string
+      description: 模糊搜索IP关键字(最大15字符)
+    creator:
+      type: string
+      description: "封锁来源身份:AF(防火墙自身)或 SIP(安全感知平台)"
+      enum:
+        - AF
+        - SIP
+    aifwType:
+      type: string
+      description: "添加封锁IP的类型:MANUAL(手动)或 AUTO(自动,需要 creator=SIP)"
+      enum:
+        - MANUAL
+        - AUTO
+    blockTime:
+      type: integer
+      description: 自动封锁时长(秒)
+
+    # Pagination
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量(最大200,默认100)
+    _sortby:
+      type: string
+      description: 排序字段名
+    _order:
+      type: string
+      description: "排序方向:asc(升序)或 desc(降序)"
+      enum:
+        - asc
+        - desc
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: ops
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_status.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_status.yaml
new file mode 100644
index 00000000..9ea4d2c6
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_status.yaml
@@ -0,0 +1,120 @@
+name: sangfor_af_v106_status
+description: >
+  Sangfor AF v8.0.106 device status tool. Query system resource usage
+  (CPU, memory, disk), firmware version, network interface status,
+  current time, and system uptime.
+description_cn: >
+  深信服 AF v8.0.106 状态中心工具。通过 `action` 参数查询系统资源(CPU/内存/磁盘)、
+  固件版本、网口状态、当前时间及系统运行时长等信息。
+category: custom
+enabled: true
+requires_confirmation: false
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        状态查询动作名,可选值:
+        - get_memory_usage
+          用途: 获取当前内存使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_cpu_usage
+          用途: 获取当前 CPU 使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_disk_usage
+          用途: 获取当前磁盘使用率(百分比)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_system_version
+          用途: 获取 AF 系统固件版本信息
+          必填: 无
+          常用: filter(ALL/FULL/MAJOR/MINOR 等)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_interface_status
+          用途: 获取指定网口或全部网口的状态(流速、连接状态)
+          必填: 无
+          常用: interfaceNames(如 eth0,不传则获取全部)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_runtime_status
+          用途: 获取系统运行时长(uptime)
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_current_time
+          用途: 获取设备当前时间
+          必填: 无
+          风险提示: 只读查询接口
+          是否任务型: 否
+
+        ## 趋势数据(v8.0.106 新增)
+        - get_memory_trend
+          用途: 获取内存使用率变化曲线数据
+          必填: 无
+          常用: minutes(最近N分钟,默认60)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_cpu_trend
+          用途: 获取 CPU 使用率变化曲线数据
+          必填: 无
+          常用: minutes(最近N分钟,默认60)
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_interface_throughput
+          用途: 获取指定接口的吞吐率折线图数据
+          必填: 无
+          常用: ifname(接口名,如 eth0)、minutes
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_memory_usage
+        - get_cpu_usage
+        - get_disk_usage
+        - get_system_version
+        - get_interface_status
+        - get_runtime_status
+        - get_current_time
+        - get_memory_trend
+        - get_cpu_trend
+        - get_interface_throughput
+    filter:
+      type: string
+      description: >
+        版本信息过滤(仅用于 get_system_version):
+        ALL=显示所有,FULL=完整版本号,MAJOR=主版本号,MINOR=次版本号,
+        INCREASE=增版本号,BUILD=创建日期,EN=是否英文版,HF=是否HF版,B=是否Beta版
+      enum:
+        - ALL
+        - FULL
+        - MAJOR
+        - MINOR
+        - INCREASE
+        - BUILD
+        - EN
+        - HF
+        - B
+        - R
+        - ADD
+    interfaceNames:
+      type: string
+      description: 网口名称(如 eth0),用于 get_interface_status;不填则获取全部接口
+    ifname:
+      type: string
+      description: 接口名称(如 eth0),用于 get_interface_throughput
+    minutes:
+      type: integer
+      description: 获取最近N分钟的趋势数据(用于 get_memory_trend/get_cpu_trend/get_interface_throughput)
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: status
diff --git a/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_system.yaml b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_system.yaml
new file mode 100644
index 00000000..1acda1d2
--- /dev/null
+++ b/.flocks/plugins/tools/api/sangfor_af_v8_0_106/sangfor_af_v106_system.yaml
@@ -0,0 +1,53 @@
+name: sangfor_af_v106_system
+description: >
+  Sangfor AF v8.0.106 system management tool. Query and manage administrator
+  accounts on the AF device.
+description_cn: >
+  深信服 AF v8.0.106 系统管理工具。通过 `action` 参数查询和管理 AF 设备上的
+  管理员账户信息。
+category: custom
+enabled: true
+requires_confirmation: true
+provider: sangfor_af
+inputSchema:
+  type: object
+  properties:
+    action:
+      type: string
+      description: |
+        系统管理动作名,可选值:
+        - get_accounts
+          用途: 查询所有管理员账户列表
+          必填: 无
+          常用: _start、_length、enable
+          风险提示: 只读查询接口
+          是否任务型: 否
+        - get_account
+          用途: 查询指定管理员账户详情
+          必填: name(账户名)
+          风险提示: 只读查询接口
+          是否任务型: 否
+      enum:
+        - get_accounts
+        - get_account
+    name:
+      type: string
+      description: 管理员账户名(用于 get_account)
+    enable:
+      type: boolean
+      description: 按启用/禁用状态过滤账户
+    _start:
+      type: integer
+      description: 分页起始位置(从0开始)
+    _length:
+      type: integer
+      description: 每页最大返回数量
+    _search:
+      type: string
+      description: 模糊搜索关键字
+  required:
+    - action
+handler:
+  type: script
+  script_file: sangfor_af.handler.py
+  function: system
diff --git a/.flocks/plugins/tools/api/tdp_v3_3_10/_test.yaml b/.flocks/plugins/tools/api/tdp_v3_3_10/_test.yaml
index ef69108c..d6ff6eed 100644
--- a/.flocks/plugins/tools/api/tdp_v3_3_10/_test.yaml
+++ b/.flocks/plugins/tools/api/tdp_v3_3_10/_test.yaml
@@ -2,12 +2,13 @@ schema_version: 1
 provider: tdp_api
 
 # Service-level connectivity probe.
-# `system_status` with `action: all` is a lightweight multi-endpoint health
-# check; it verifies HMAC-SHA256 signing and basic platform reachability.
+# Use a single lightweight system-status endpoint here instead of the stricter
+# multi-endpoint aggregate. This keeps the connectivity check focused on basic
+# service health without coupling it to every subsystem state.
 connectivity:
   tool: tdp_system_status
   params:
-    action: all
+    action: service
 
 # Tool-level test samples shown in the WebUI ToolDetailDrawer drop-down.
 # `label` is the default (English) display string; `label_cn` is the
diff --git a/README.md b/README.md
index d794a97d..94e25de0 100644
--- a/README.md
+++ b/README.md
@@ -230,11 +230,9 @@ Initial setup:
 
 Non-browser clients (TUI, SDKs, scripts):
 
-- **Local loopback** (`127.0.0.1` / `::1` / `localhost`, no
-  `x-forwarded-for` header) is auto-trusted as `local-service` admin. This
-  covers TUI, plugin sub-processes, and CLI calls running on the same host.
-- **Remote** clients must present an API token. The token lives in
-  `~/.flocks/config/.secret.json` under the secret id `server_api_token`.
+- All non-browser clients, including local loopback clients, must present an
+  API token. The token lives in `~/.flocks/config/.secret.json` under the
+  secret id `server_api_token`.
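+
+  For example, a script client might read that secret file and attach the
+  token to each request. This is only a sketch: the flat-JSON layout of
+  `.secret.json`, the `Authorization: Bearer` header, and the endpoint/port
+  are assumptions; check `flocks/server/auth.py` for the exact scheme your
+  build expects.
+
+  ```python
+  import json
+  from pathlib import Path
+  import urllib.request
+
+  # Assumes the secret store is flat JSON keyed by secret id.
+  secrets = json.loads((Path.home() / ".flocks/config/.secret.json").read_text())
+
+  req = urllib.request.Request(
+      "http://127.0.0.1:8000/api/health",  # hypothetical server address/route
+      headers={"Authorization": f"Bearer {secrets['server_api_token']}"},
+  )
+  print(urllib.request.urlopen(req).status)
+  ```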
On the **server**, generate (or rotate) the token — it is persisted on the server's local secret store: diff --git a/README_zh.md b/README_zh.md index 1392c1e4..a0d88a8d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -218,8 +218,7 @@ flocks start --server-host 127.0.0.1 --webui-host 0.0.0.0 非浏览器客户端(TUI / SDK / 脚本): -- **本机回环**(`127.0.0.1` / `::1` / `localhost`,且请求不带 `x-forwarded-for` 头)会被自动识别为 `local-service` 管理员,满足同机的 TUI、插件子进程、CLI 调用。 -- **远程**调用必须携带 API Token。Token 存放于 `~/.flocks/config/.secret.json`,secret id 为 `server_api_token`。 +- 所有非浏览器客户端(包括本机回环调用)都必须携带 API Token。Token 存放于 `~/.flocks/config/.secret.json`,secret id 为 `server_api_token`。 在 **服务端** 生成(或轮换)token,会持久化到服务端本机的 secret store: diff --git a/flocks/agent/agents/rex/prompt_builder.py b/flocks/agent/agents/rex/prompt_builder.py index e3185779..42f03ea2 100644 --- a/flocks/agent/agents/rex/prompt_builder.py +++ b/flocks/agent/agents/rex/prompt_builder.py @@ -176,11 +176,23 @@ def build_dynamic_rex_prompt( Should I proceed with your original request, or try the alternative? ``` -### Image Analysis Limitation -If the user provides an image, image URL, or local image path and asks you to inspect, interpret, describe, extract, OCR, or analyze the image content: -- Do NOT claim you can analyze the image -- Clearly tell the user that Flocks does not support image analysis yet -- If helpful, ask the user to provide the relevant text or describe the image in words instead +### Visual / Image Input Handling +You may receive images as multimodal `image_url` content blocks attached to a user message. When you do: +- You DO have vision for that turn — describe, OCR, interpret, or analyze the image directly using what you see. Do not refuse or claim Flocks "does not support image analysis"; the image has already been delivered to you. +- Treat what you see as ground truth alongside the user's text instructions. +- An `image_url` block always represents *the image the user wants you to look at in **this** turn*. +- Do NOT confuse the current image(s) with anything from earlier turns. Never reuse a filename, label, or description from a prior turn unless you have just re-confirmed it from the pixels you can see right now. + +**Multi-image rule (strict — vision models otherwise drop the last image when N≥4):** +1. Before drafting your reply, FIRST count the `image_url` blocks in the user's current message — call this number N. +2. Begin your response with an opener that explicitly states the count, e.g. `您发送了 N 张图片,逐一解读如下:` (or `I will analyze all N images one by one:`). Anchoring N up front prevents the model from stopping early. +3. Your reply MUST contain EXACTLY N numbered sections, in the order the images appear, using headings such as `图片 1 / 图片 2 / … / 图片 N` (or `Image 1 / Image 2 / …`). Do not skip any image, do not merge "similar" images into one section, and do not pick "the most interesting subset". +4. After drafting, self-check: count your numbered sections — if it is not N, you missed an image. Add the missing section(s) before finalizing. + +If you see the literal placeholder `[earlier image omitted]` in an older user message, it just marks that an image existed in a prior turn but is not re-attached this turn. Treat it as opaque — you cannot re-inspect it. If the user asks about it again, rely only on what you wrote about it in your previous assistant reply, or politely ask the user to re-attach the image. 
+ +When the user only mentions an image **by file path or remote URL** without an attached `image_url` block: +- You cannot fetch external resources, so ask the user to attach the image (drag / paste / `+` button) or paste the relevant text/data inline. --- diff --git a/flocks/browser/admin.py b/flocks/browser/admin.py index 44fa4ee6..ca115dff 100644 --- a/flocks/browser/admin.py +++ b/flocks/browser/admin.py @@ -9,6 +9,7 @@ from . import BROWSER_LABEL, PROJECT_ROOT, get_browser_version from . import _ipc as ipc +from .utils import load_env_file NAME = os.environ.get("BU_NAME", "default") @@ -25,17 +26,7 @@ def _load_env() -> None: for path in env_paths: if not path.exists(): continue - _load_env_file(path) - - -def _load_env_file(path: Path) -> None: - for line in path.read_text().splitlines(): - line = line.strip() - if not line or line.startswith("#") or "=" not in line: - continue - key, value = line.split("=", 1) - os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) - + load_env_file(path) _load_env() diff --git a/flocks/browser/daemon.py b/flocks/browser/daemon.py index cbe7b285..8ac95f85 100644 --- a/flocks/browser/daemon.py +++ b/flocks/browser/daemon.py @@ -14,6 +14,7 @@ from . import DEFAULT_AGENT_WORKSPACE, INTERNAL_URL_PREFIXES from . import _ipc as ipc +from .utils import load_env_file AGENT_WORKSPACE = Path(os.environ.get("BH_AGENT_WORKSPACE", DEFAULT_AGENT_WORKSPACE)).expanduser() @@ -56,17 +57,7 @@ def _load_env() -> None: for path in (Path(__file__).resolve().parents[2] / ".env", AGENT_WORKSPACE / ".env"): if not path.exists(): continue - _load_env_file(path) - - -def _load_env_file(path: Path) -> None: - for line in path.read_text().splitlines(): - line = line.strip() - if not line or line.startswith("#") or "=" not in line: - continue - key, value = line.split("=", 1) - os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) - + load_env_file(path) _load_env() diff --git a/flocks/browser/helpers.py b/flocks/browser/helpers.py index 782ea799..f6932d84 100644 --- a/flocks/browser/helpers.py +++ b/flocks/browser/helpers.py @@ -16,6 +16,7 @@ from . import DEFAULT_AGENT_WORKSPACE, INTERNAL_URL_PREFIXES from . 
import _ipc as ipc +from .utils import load_env_file AGENT_WORKSPACE = Path(os.environ.get("BH_AGENT_WORKSPACE", DEFAULT_AGENT_WORKSPACE)).expanduser() @@ -44,17 +45,7 @@ def _load_env() -> None: for path in (Path(__file__).resolve().parents[2] / ".env", AGENT_WORKSPACE / ".env"): if not path.exists(): continue - _load_env_file(path) - - -def _load_env_file(path: Path) -> None: - for line in path.read_text().splitlines(): - line = line.strip() - if not line or line.startswith("#") or "=" not in line: - continue - key, value = line.split("=", 1) - os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) - + load_env_file(path) _load_env() diff --git a/flocks/browser/utils.py b/flocks/browser/utils.py new file mode 100644 index 00000000..fad3b36e --- /dev/null +++ b/flocks/browser/utils.py @@ -0,0 +1,23 @@ +"""Shared browser utility helpers.""" + +import locale +import os +from pathlib import Path + + +def read_env_text(path: Path) -> str: + """Read an env file as UTF-8 first, then fall back to the local encoding.""" + try: + return path.read_text(encoding="utf-8-sig") + except UnicodeDecodeError: + return path.read_text(encoding=locale.getpreferredencoding(False)) + + +def load_env_file(path: Path) -> None: + """Populate ``os.environ`` from a simple ``KEY=VALUE`` env file.""" + for line in read_env_text(path).splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) diff --git a/flocks/cli/commands/update.py b/flocks/cli/commands/update.py index 46448304..bd59cd69 100644 --- a/flocks/cli/commands/update.py +++ b/flocks/cli/commands/update.py @@ -37,13 +37,23 @@ def update_command( async def _update(check: bool, yes: bool, force: bool = False, region: str | None = None) -> None: from flocks.updater import check_update, perform_update, detect_deploy_mode - with console.status("[cyan]正在检查版本...[/cyan]", spinner="dots"): - info = await check_update(region=region) + if not yes and not check and region is None: + use_cn_mirror = typer.confirm("\n是否使用中国镜像进行升级?", default=False) + if use_cn_mirror: + region = "cn" + + async def _load_update_info(selected_region: str | None): + with console.status("[cyan]正在检查版本...[/cyan]", spinner="dots"): + info = await check_update(region=selected_region) + + if info.error: + append_upgrade_text_log(f"ERROR version_check: {info.error}") + console.print(f"[red]检查失败:{info.error}[/red]") + raise typer.Exit(1) - if info.error: - append_upgrade_text_log(f"ERROR version_check: {info.error}") - console.print(f"[red]检查失败:{info.error}[/red]") - raise typer.Exit(1) + return info + + info = await _load_update_info(region) _print_version_table(info) diff --git a/flocks/cli/main.py b/flocks/cli/main.py index 7b13faa9..bdfa8d8d 100644 --- a/flocks/cli/main.py +++ b/flocks/cli/main.py @@ -6,6 +6,7 @@ import asyncio import os +import secrets as secrets_lib import sys from pathlib import Path from typing import Any, Optional @@ -74,6 +75,19 @@ console = Console() +def _ensure_server_api_token() -> bool: + """Ensure local non-browser clients such as `flocks tui` can authenticate.""" + from flocks.security import get_secret_manager + from flocks.server.auth import API_TOKEN_SECRET_ID + + secrets = get_secret_manager() + if secrets.get(API_TOKEN_SECRET_ID): + return False + + secrets.set(API_TOKEN_SECRET_ID, secrets_lib.token_urlsafe(32)) + return True + + def version_callback(value: bool): """Print version and exit""" 
if value: @@ -451,6 +465,8 @@ def tui( # Start server process env = os.environ.copy() + if _ensure_server_api_token(): + console.print("[dim]Initialized local API token for TUI access[/dim]") # Set auto-approve environment variable for TUI mode if auto_approve: diff --git a/flocks/cli/service_manager.py b/flocks/cli/service_manager.py index 9f0f3a0a..5bf2a234 100644 --- a/flocks/cli/service_manager.py +++ b/flocks/cli/service_manager.py @@ -1431,7 +1431,7 @@ def tail_lines(path: Path, lines: int) -> list[str]: return [line.rstrip("\n") for line in deque(handle, maxlen=max(lines, 0))] -def _emit_service_log_tail(console, log_path: Path, service_label: str, lines: int = 40) -> None: +def _emit_service_log_tail(console, log_path: Path, service_label: str, lines: int = 10) -> None: """Print the last *lines* lines of *log_path* to help diagnose failed daemon startups.""" if lines <= 0: return diff --git a/flocks/config/api_versioning.py b/flocks/config/api_versioning.py index 965136c9..1aacfd21 100644 --- a/flocks/config/api_versioning.py +++ b/flocks/config/api_versioning.py @@ -287,7 +287,9 @@ def migrate_api_services(*, backup: bool = True) -> Dict[str, str]: for desc in pending: # Deep-copy via JSON round-trip; api_services blocks are pure JSON. - services[desc.storage_key] = json.loads(json.dumps(services[desc.service_id])) + copied = json.loads(json.dumps(services[desc.service_id])) + from flocks.config.config_writer import ConfigWriter + services[desc.storage_key] = ConfigWriter._normalize_api_service_config(copied) actions[desc.storage_key] = "copied" log.info("versioning.migrated", { "service_id": desc.service_id, diff --git a/flocks/config/config.py b/flocks/config/config.py index dbaabbbe..fd2f05b5 100644 --- a/flocks/config/config.py +++ b/flocks/config/config.py @@ -206,6 +206,7 @@ class McpRemoteConfig(BaseModel): type: Literal["remote", "sse"] url: str enabled: Optional[bool] = None + transport: Optional[Literal["auto", "sse", "http"]] = "auto" headers: Optional[Dict[str, str]] = None oauth: Optional[Union[McpOAuthConfig, Literal[False]]] = None timeout: Optional[int] = Field(None, gt=0) diff --git a/flocks/config/config_writer.py b/flocks/config/config_writer.py index 74c609ef..ef5032ed 100644 --- a/flocks/config/config_writer.py +++ b/flocks/config/config_writer.py @@ -469,6 +469,22 @@ def update_mcp_server_field(cls, name: str, field: str, value: Any) -> bool: # API Services CRUD (api_services section) # ------------------------------------------------------------------ + @staticmethod + def _normalize_api_service_config(service_config: Dict[str, Any]) -> Dict[str, Any]: + """Canonicalize API service config keys before persisting. + + ``verify_ssl`` is the canonical field. ``ssl_verify`` remains a + read-time compatibility alias, but writers should never persist both + fields at once or keep writing the legacy alias forward. + """ + normalized = dict(service_config) + if "verify_ssl" in normalized: + normalized.pop("ssl_verify", None) + return normalized + if "ssl_verify" in normalized: + normalized["verify_ssl"] = normalized.pop("ssl_verify") + return normalized + @classmethod def get_api_service_raw(cls, service_id: str) -> Optional[Dict[str, Any]]: """Read a single ``api_services`` entry (raw, secrets unresolved). 
@@ -513,7 +529,7 @@ def set_api_service(cls, service_id: str, service_config: Dict[str, Any]) -> Non """ data = cls._read_raw() services = data.setdefault("api_services", {}) - services[service_id] = service_config + services[service_id] = cls._normalize_api_service_config(service_config) cls._write_raw(data) log.info("config_writer.api_service_set", {"service_id": service_id}) diff --git a/flocks/hub/catalog.py b/flocks/hub/catalog.py index 063f3907..9448fe0f 100644 --- a/flocks/hub/catalog.py +++ b/flocks/hub/catalog.py @@ -285,10 +285,13 @@ def _tool_tags(plugin_id: str, description: str) -> list[str]: manual_tags: dict[str, list[str]] = { "fofa": ["network-mapping", "network", "threat-intelligence"], "greynoise": ["threat-intelligence", "network", "ioc"], + "ngsoc": ["siem", "network", "threat-intelligence"], "ngtip_api": ["threat-intelligence", "ioc"], "onesec": ["edr", "hids", "threat-intelligence"], - "onesig": ["network", "web-security", "threat-intelligence"], + # onesig is a Secure Internet Gateway — network only, not web-security/EDR + "onesig": ["network", "threat-intelligence"], "qingteng": ["hids", "edr", "vulnerability"], + "sangfor_af": ["waf", "network"], "sangfor_sip": ["siem", "network", "threat-intelligence"], "sangfor_xdr": ["xdr", "edr", "ndr"], "skyeye_api": ["ndr", "network", "threat-intelligence"], @@ -300,8 +303,10 @@ def _tool_tags(plugin_id: str, description: str) -> list[str]: "mcp": ["integration"], "python": ["integration"], } - if plugin_id in manual_tags: - return _safe_tags(manual_tags[plugin_id]) + # Support versioned plugin IDs: "sangfor_af_v8_0_48" matches key "sangfor_af". + for key, tags in manual_tags.items(): + if plugin_id == key or plugin_id.startswith(f"{key}_"): + return _safe_tags(tags) text = (f"{plugin_id} {description}").lower() inferred: list[str] = [] diff --git a/flocks/mcp/catalog.py b/flocks/mcp/catalog.py index f89c0ed8..92800bb4 100644 --- a/flocks/mcp/catalog.py +++ b/flocks/mcp/catalog.py @@ -7,7 +7,7 @@ import json from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Literal from pydantic import BaseModel, Field @@ -78,6 +78,19 @@ class InstallSpec(BaseModel): note: Optional[str] = None +class RemoteConfigSpec(BaseModel): + """Remote MCP configuration template stored in the catalog.""" + + model_config = {"extra": "allow"} + + url: str = "" + transport: Literal["auto", "sse", "http"] = "auto" + headers: Optional[Dict[str, str]] = None + auth: Optional[Dict[str, Any]] = None + oauth: Optional[Any] = None + timeout: Optional[int] = None + + class CatalogEntry(BaseModel): """A single MCP server entry in the catalog.""" @@ -94,6 +107,7 @@ class CatalogEntry(BaseModel): stars: int = 0 transport: str = "local" install: InstallSpec = Field(default_factory=InstallSpec) + remote: Optional[RemoteConfigSpec] = None env_vars: Dict[str, EnvVarSpec] = Field(default_factory=dict) system_deps: List[str] = Field(default_factory=list) tags: List[str] = Field(default_factory=list) @@ -179,11 +193,24 @@ def to_mcp_config( return config elif self.transport == "remote": - return { + config: Dict[str, Any] = { "type": "remote", - "url": env.get("url", ""), + "url": env.get("url", self.remote.url if self.remote else ""), "enabled": False, } + if self.remote: + remote_template = self.remote.model_dump(exclude_none=True) + if "transport" in remote_template: + config["transport"] = remote_template["transport"] + if remote_template.get("headers"): + config["headers"] = remote_template["headers"] 
+ if remote_template.get("auth"): + config["auth"] = remote_template["auth"] + if "oauth" in remote_template: + config["oauth"] = remote_template["oauth"] + if remote_template.get("timeout") is not None: + config["timeout"] = remote_template["timeout"] + return config return {} diff --git a/flocks/mcp/client.py b/flocks/mcp/client.py index 3e6febea..8b14d9a2 100644 --- a/flocks/mcp/client.py +++ b/flocks/mcp/client.py @@ -8,7 +8,7 @@ import os import tempfile from pathlib import Path -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Literal from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.client.sse import sse_client @@ -63,6 +63,7 @@ def __init__( headers: Optional[Dict[str, str]] = None, env: Optional[Dict[str, str]] = None, auth_config: Optional[Dict[str, Any]] = None, + transport: Literal["auto", "sse", "http"] = "auto", timeout: float = 30.0 ): """ @@ -76,6 +77,7 @@ def __init__( headers: Extra HTTP headers for remote MCP connections env: Extra environment variables for local server subprocess auth_config: Authentication configuration + transport: Preferred remote transport (auto | sse | http) timeout: Timeout in seconds """ self.name = name @@ -85,6 +87,7 @@ def __init__( self.headers = headers self.env = env self.auth_config = auth_config + self.transport = transport self.timeout = timeout self.session: Optional[ClientSession] = None @@ -116,53 +119,108 @@ async def connect(self) -> None: raise ValueError(f"Unknown server type: {self.server_type}") async def _connect_remote(self) -> None: - """Connect to remote server (try Streamable HTTP first, fall back to SSE)""" + """Connect to a remote server using the configured transport strategy.""" full_url = build_mcp_url(self.url, self.auth_config) request_headers = build_mcp_headers(self.headers, self.auth_config) - + + if self.transport == "http": + log.info("mcp.client.connecting", { + "server": self.name, + "type": "remote", + "strategy": "streamable_http_only", + }) + await self._connect_streamable_http_only(full_url, request_headers) + return + + if self.transport == "sse": + log.info("mcp.client.connecting", { + "server": self.name, + "type": "remote", + "strategy": "sse_only", + }) + await self._connect_sse_only(full_url, request_headers) + return + log.info("mcp.client.connecting", { "server": self.name, "type": "remote", - "strategy": "streamable_http_then_sse" + "strategy": "streamable_http_then_sse", }) - - # Try Streamable HTTP first + await self._connect_auto(full_url, request_headers) + + async def _connect_streamable_http_only( + self, full_url: str, headers: Optional[Dict[str, str]] + ) -> None: + """Connect using only Streamable HTTP.""" try: - await self._do_connect_streamable_http(full_url, request_headers) + await self._do_connect_streamable_http(full_url, headers) + self._transport_type = "streamable_http" + except asyncio.TimeoutError: + await self._cleanup_connection() + log.error("mcp.client.timeout", { + "server": self.name, + "transport": "streamable_http", + }) + raise RuntimeError(f"Connection timeout: {self.name}") + except Exception as e: + root_cause = _extract_root_cause(e) + await self._cleanup_connection() + raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") + + async def _connect_sse_only( + self, full_url: str, headers: Optional[Dict[str, str]] + ) -> None: + """Connect using only SSE.""" + try: + await self._do_connect_sse(full_url, headers) + self._transport_type = "sse" + except 
asyncio.TimeoutError: + await self._cleanup_connection() + log.error("mcp.client.timeout", { + "server": self.name, + "transport": "sse", + }) + raise RuntimeError(f"Connection timeout: {self.name}") + except Exception as e: + root_cause = _extract_root_cause(e) + await self._cleanup_connection() + raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") + + async def _connect_auto( + self, full_url: str, headers: Optional[Dict[str, str]] + ) -> None: + """Connect using auto-detection: HTTP first, then SSE.""" + try: + await self._do_connect_streamable_http(full_url, headers) self._transport_type = "streamable_http" return except asyncio.TimeoutError: - # Timeout means server is reachable but slow — don't fall back await self._cleanup_connection() log.error("mcp.client.timeout", { "server": self.name, - "transport": "streamable_http" + "transport": "streamable_http", }) raise RuntimeError(f"Connection timeout: {self.name}") except Exception as e: log.info("mcp.client.streamable_http_failed", { "server": self.name, "error": str(e), - "fallback": "sse" + "fallback": "sse", }) - # Clean up failed attempt before trying SSE await self._cleanup_connection() - - # Fall back to SSE + try: - await self._do_connect_sse(full_url, request_headers) + await self._do_connect_sse(full_url, headers) self._transport_type = "sse" return except Exception as e: root_cause = _extract_root_cause(e) log.error("mcp.client.all_transports_failed", { "server": self.name, - "error": root_cause + "error": root_cause, }) await self._cleanup_connection() - raise RuntimeError( - f"Connection failed: {self.name}: {root_cause}" - ) + raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") async def _do_connect_streamable_http( self, full_url: str, headers: Optional[Dict[str, str]] = None diff --git a/flocks/mcp/server.py b/flocks/mcp/server.py index a5527a87..65b51161 100644 --- a/flocks/mcp/server.py +++ b/flocks/mcp/server.py @@ -200,6 +200,7 @@ async def _connect_and_register( headers=config.get('headers'), env=config.get('environment'), auth_config=config.get('auth'), + transport=config.get('transport', 'auto'), timeout=config.get('timeout', 30.0) ) diff --git a/flocks/mcp/utils.py b/flocks/mcp/utils.py index 16b4c60a..cc4e3e18 100644 --- a/flocks/mcp/utils.py +++ b/flocks/mcp/utils.py @@ -43,6 +43,17 @@ REMOTE_MCP_TYPES = frozenset({"remote", "sse"}) LOCAL_MCP_TYPES = frozenset({"local", "stdio"}) +MCP_MASKED_SECRET_VALUE = "***" + + +def _is_secret_placeholder(value: str) -> bool: + """Return True when the value is already secret-backed or intentionally blank.""" + stripped = value.strip() + return ( + not stripped + or stripped.startswith("{secret:") + or stripped.startswith("${") + ) def resolve_url_template(url: str) -> str: @@ -86,11 +97,23 @@ def _replace(match: re.Match) -> str: def normalize_mcp_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]: """Normalize transport aliases to canonical backend config values.""" normalized = dict(config) - transport = str(normalized.get("type", "")).strip().lower() - if transport == "sse": + server_type = str(normalized.get("type", "")).strip().lower() + if server_type == "sse": normalized["type"] = "remote" - elif transport == "stdio": + normalized.setdefault("transport", "sse") + elif server_type == "stdio": normalized["type"] = "local" + + transport = str(normalized.get("transport", "")).strip().lower() + if normalized.get("type") in REMOTE_MCP_TYPES: + if transport in ("", "auto"): + normalized["transport"] = "auto" + elif transport in 
("streamablehttp", "streamable_http", "http"): + normalized["transport"] = "http" + elif transport == "sse": + normalized["transport"] = "sse" + else: + normalized["transport"] = "auto" return normalized @@ -204,6 +227,12 @@ def build_mcp_headers( if auth_config and auth_config.get("location", "header") == "header": param_name = str(auth_config.get("param_name", "Authorization")) param_value = resolve_env_var(str(auth_config.get("value", ""))) + if ( + str(auth_config.get("scheme", "")).strip().lower() == "bearer" + and param_value + and not param_value.lower().startswith("bearer ") + ): + param_value = f"Bearer {param_value}" if param_value: headers.setdefault(param_name, param_value) @@ -319,6 +348,167 @@ def extract_api_key_from_mcp_url(server_name: str, config: Dict[str, Any]) -> Di return {**config, "url": new_url} +def extract_auth_value_from_mcp_config(server_name: str, config: Dict[str, Any]) -> Dict[str, Any]: + """Move plain-text ``auth.value`` into SecretManager and keep a secret reference.""" + auth_config = config.get("auth") + if not isinstance(auth_config, dict): + return dict(config) + + auth_value = auth_config.get("value") + if not isinstance(auth_value, str): + return dict(config) + + auth_value = auth_value.strip() + if not auth_value or auth_value.startswith("{secret:") or auth_value.startswith("${"): + return dict(config) + + updated_auth = dict(auth_config) + scheme = str(updated_auth.get("scheme", "")).strip().lower() + if ( + not scheme + and str(updated_auth.get("location", "")).strip().lower() == "header" + and str(updated_auth.get("param_name", "")).strip().lower() == "authorization" + and auth_value.lower().startswith("bearer ") + ): + scheme = "bearer" + if scheme == "bearer": + updated_auth["scheme"] = "bearer" + if auth_value.lower().startswith("bearer "): + auth_value = auth_value[7:].strip() + elif "scheme" in updated_auth and not scheme: + updated_auth.pop("scheme", None) + + secret_key = str(auth_config.get("secret_id") or f"{server_name}_mcp_key") + from flocks.security import get_secret_manager + + get_secret_manager().set(secret_key, auth_value) + + updated_auth["value"] = f"{{secret:{secret_key}}}" + updated_auth.pop("secret_id", None) + + updated_config = dict(config) + updated_config["auth"] = updated_auth + return updated_config + + +def extract_sensitive_headers_from_mcp_config( + server_name: str, config: Dict[str, Any] +) -> Dict[str, Any]: + """Move plain-text sensitive headers into SecretManager.""" + headers = config.get("headers") + if config.get("type") not in REMOTE_MCP_TYPES or not isinstance(headers, dict): + return dict(config) + + updated_headers = dict(headers) + secrets = None + extracted = False + + for header_name, header_value in headers.items(): + header_key = str(header_name).strip() + if header_key.lower() not in _SENSITIVE_HEADER_NAMES: + continue + if not isinstance(header_value, str): + continue + + normalized_value = header_value.strip() + if _is_secret_placeholder(normalized_value): + continue + + if secrets is None: + from flocks.security import get_secret_manager + + secrets = get_secret_manager() + + secret_key = f"{server_name}_{sanitize_name(header_key)}_header" + secrets.set(secret_key, normalized_value) + updated_headers[header_name] = f"{{secret:{secret_key}}}" + extracted = True + + if not extracted: + return dict(config) + + updated_config = dict(config) + updated_config["headers"] = updated_headers + return updated_config + + +def mask_sensitive_mcp_config_for_frontend( + config: Dict[str, Any] +) -> Dict[str, 
Any]: + """Mask plain-text secrets before returning MCP config to the frontend.""" + masked_config = dict(config) + + auth_config = config.get("auth") + if isinstance(auth_config, dict): + auth_value = auth_config.get("value") + if isinstance(auth_value, str) and not _is_secret_placeholder(auth_value): + masked_auth = dict(auth_config) + masked_auth["value"] = MCP_MASKED_SECRET_VALUE + masked_config["auth"] = masked_auth + + headers = config.get("headers") + if isinstance(headers, dict): + masked_headers = dict(headers) + changed = False + for header_name, header_value in headers.items(): + if str(header_name).strip().lower() not in _SENSITIVE_HEADER_NAMES: + continue + if not isinstance(header_value, str): + continue + if _is_secret_placeholder(header_value): + continue + masked_headers[header_name] = MCP_MASKED_SECRET_VALUE + changed = True + if changed: + masked_config["headers"] = masked_headers + + return masked_config + + +def restore_masked_mcp_config_secrets( + previous_config: Dict[str, Any], updated_config: Dict[str, Any] +) -> Dict[str, Any]: + """Restore masked frontend sentinel values back to their previous secrets.""" + restored_config = dict(updated_config) + + previous_auth = previous_config.get("auth") + next_auth = updated_config.get("auth") + if ( + isinstance(previous_auth, dict) + and isinstance(next_auth, dict) + and next_auth.get("value") == MCP_MASKED_SECRET_VALUE + and isinstance(previous_auth.get("value"), str) + ): + restored_auth = dict(next_auth) + restored_auth["value"] = previous_auth["value"] + restored_config["auth"] = restored_auth + + previous_headers = previous_config.get("headers") + next_headers = updated_config.get("headers") + if isinstance(previous_headers, dict) and isinstance(next_headers, dict): + previous_by_name = { + str(header_name).strip().lower(): header_value + for header_name, header_value in previous_headers.items() + } + restored_headers = dict(next_headers) + changed = False + for header_name, header_value in next_headers.items(): + normalized_header = str(header_name).strip().lower() + if ( + normalized_header not in _SENSITIVE_HEADER_NAMES + or header_value != MCP_MASKED_SECRET_VALUE + ): + continue + if normalized_header not in previous_by_name: + continue + restored_headers[header_name] = previous_by_name[normalized_header] + changed = True + if changed: + restored_config["headers"] = restored_headers + + return restored_config + + def resolve_env_var(value: str) -> str: """ Resolve environment variable or secret placeholder. 
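The masking and restore helpers above are designed as a round trip across the WebUI edit flow. A sketch under the assumption that `_SENSITIVE_HEADER_NAMES` (defined elsewhere in this module) includes `authorization`:

    from flocks.mcp.utils import (
        MCP_MASKED_SECRET_VALUE,
        mask_sensitive_mcp_config_for_frontend,
        restore_masked_mcp_config_secrets,
    )

    stored = {
        "type": "remote",
        "url": "https://mcp.example.invalid",
        "auth": {"location": "header", "param_name": "Authorization", "value": "topsecret"},
    }

    # GET: plain-text secrets never reach the browser.
    shown = mask_sensitive_mcp_config_for_frontend(stored)
    assert shown["auth"]["value"] == MCP_MASKED_SECRET_VALUE

    # PUT: the form echoes the "***" sentinel back unchanged; restore swaps it
    # for the previously stored value before anything is persisted.
    edited = dict(shown)
    edited["url"] = "https://mcp2.example.invalid"
    restored = restore_masked_mcp_config_secrets(stored, edited)
    assert restored["auth"]["value"] == "topsecret"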
@@ -464,15 +654,21 @@ def resolve_conflict(tool_name: str, attempt: int = 0) -> str: __all__ = [ + 'MCP_MASKED_SECRET_VALUE', 'REMOTE_MCP_TYPES', 'LOCAL_MCP_TYPES', 'build_mcp_url', 'build_mcp_headers', 'config_has_pending_credentials', + 'extract_api_key_from_mcp_url', + 'extract_auth_value_from_mcp_config', + 'extract_sensitive_headers_from_mcp_config', 'get_connect_block_reason', 'is_auth_related_error', + 'mask_sensitive_mcp_config_for_frontend', 'normalize_mcp_config', 'normalize_mcp_config_aliases', + 'restore_masked_mcp_config_secrets', 'resolve_url_template', 'resolve_env_var', 'sanitize_name', diff --git a/flocks/provider/sdk/openai.py b/flocks/provider/sdk/openai.py index b7746694..a6d9a2ec 100644 --- a/flocks/provider/sdk/openai.py +++ b/flocks/provider/sdk/openai.py @@ -17,8 +17,11 @@ StreamChunk, ) from flocks.provider.sdk.openai_base import ( + DEFAULT_HTTP_TIMEOUT, _coerce_bool, extract_reasoning_content, + format_openai_content, + format_openai_messages, resolve_verify_ssl, ) from flocks.utils.log import Log @@ -80,7 +83,7 @@ def _get_client(self): http_client = httpx.AsyncClient( trust_env=trust_env, verify=verify_ssl, - timeout=120.0, + timeout=DEFAULT_HTTP_TIMEOUT, ) if base_url: @@ -116,41 +119,13 @@ def get_models(self) -> List[ModelInfo]: """ return list(getattr(self, "_config_models", [])) - @staticmethod - def _format_content(content: Any) -> Any: - if not isinstance(content, list): - return content - - formatted: list[dict[str, Any]] = [] - for block in content: - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type == "text" and isinstance(block.get("text"), str): - formatted.append({"type": "text", "text": block["text"]}) - elif block_type == "image" and block.get("data") and block.get("mimeType"): - formatted.append({ - "type": "image_url", - "image_url": { - "url": f"data:{block['mimeType']};base64,{block['data']}", - }, - }) - return formatted + # Delegated to the shared canonical implementations in ``openai_base``. + _format_content = staticmethod(format_openai_content) @staticmethod def _format_messages(messages: List[ChatMessage]) -> list: """Convert ChatMessage list to OpenAI API dicts, preserving tool_calls / tool results.""" - formatted = [] - for msg in messages: - m: dict = {"role": msg.role, "content": OpenAIProvider._format_content(msg.content)} - if msg.tool_calls: - m["tool_calls"] = msg.tool_calls - if msg.tool_call_id: - m["tool_call_id"] = msg.tool_call_id - if msg.name: - m["name"] = msg.name - formatted.append(m) - return formatted + return format_openai_messages(messages) async def chat( self, diff --git a/flocks/provider/sdk/openai_base.py b/flocks/provider/sdk/openai_base.py index 2212457b..f95d9874 100644 --- a/flocks/provider/sdk/openai_base.py +++ b/flocks/provider/sdk/openai_base.py @@ -10,6 +10,8 @@ import os from typing import Any, AsyncIterator, Dict, List, Optional +import httpx + from flocks.provider.provider import ( BaseProvider, ChatMessage, @@ -22,6 +24,182 @@ log = Log.create(service="provider.openai_base") +# Shared HTTP timeout used by every OpenAI-style provider (OpenAIProvider, +# OpenAICompatibleProvider, OpenAIBaseProvider). Centralised here so a single +# change covers all three providers. Granular values (instead of a flat +# timeout) let small control-plane requests fail fast while multimodal +# (image) uploads get the headroom they need on slow links. 
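+# For example (illustrative of httpx timeout semantics, not extra
+# configuration): with these values a dead host fails the connect phase after
+# 30 s instead of consuming the whole read budget, while a response that keeps
+# streaming chunks may run well past 600 s in total; only a 600 s gap between
+# received chunks (or a matching stall on a send) aborts the request.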
+DEFAULT_HTTP_TIMEOUT = httpx.Timeout(connect=30.0, read=600.0, write=600.0, pool=60.0)
+
+
+# Canonical OpenAI-style content translation, shared by every provider that
+# talks the OpenAI chat-completions wire format. Kept as a module-level
+# function (instead of a method on a single class) because the three OpenAI
+# implementations — ``OpenAIProvider``, ``OpenAICompatibleProvider``, and
+# ``OpenAIBaseProvider`` — sit in parallel class hierarchies. Putting the
+# logic here gives all of them one point of truth: change the schema once,
+# every provider follows.
+#
+# Recognised input block schema (Flocks-internal):
+#   {"type": "text", "text": "..."}
+#   {"type": "image", "mimeType": "image/png", "data": "<base64>"}
+# Plus already-OpenAI-native blocks (image_url / input_audio / audio /
+# refusal / file) which are passed through unchanged so callers that
+# pre-format won't lose them.
+_OPENAI_NATIVE_BLOCK_TYPES = frozenset({
+    "image_url", "input_audio", "audio", "refusal", "file",
+})
+
+# Flocks-internal block types that the translation logic knows about. Used to
+# distinguish "known type with missing/invalid fields" from "genuinely unknown
+# type" when logging dropped blocks.
+_FLOCKS_INTERNAL_BLOCK_TYPES = frozenset({"text", "image"})
+
+
+def _summarise_block(block: Any) -> Dict[str, Any]:
+    """Describe a content block for diagnostic logs without leaking base64.
+
+    Multimodal requests can carry several MB of base64 image data. Logging the
+    full payload would dwarf every other line in the journal *and* expose
+    user-uploaded data, so for ``image_url`` blocks we only record the URL
+    scheme and length. ``text`` blocks record character count only.
+    """
+    if isinstance(block, dict):
+        btype = block.get("type")
+        if btype == "image_url":
+            img = block.get("image_url") or {}
+            url = img.get("url") if isinstance(img, dict) else ""
+            scheme = (
+                url.split(":", 1)[0]
+                if isinstance(url, str) and ":" in url
+                else ""
+            )
+            return {
+                "type": btype,
+                "url_scheme": scheme,
+                "url_chars": len(url) if isinstance(url, str) else 0,
+            }
+        if btype == "text":
+            txt = block.get("text") or ""
+            return {"type": btype, "text_chars": len(txt)}
+        return {"type": btype}
+    return {"type": type(block).__name__}
+
+
+def _summarise_messages(openai_messages: List[Any]) -> List[Dict[str, Any]]:
+    """Compute a redacted ``message_shapes`` for diagnostic logging.
+
+    See :func:`_summarise_block`. Used by both streaming and non-streaming
+    request paths so multimodal regressions surface uniformly in the log.
+    """
+    out: List[Dict[str, Any]] = []
+    for m in openai_messages:
+        if not isinstance(m, dict):
+            out.append({"type": type(m).__name__, "skipped": True})
+            continue
+        content = m.get("content")
+        if isinstance(content, list):
+            out.append({
+                "role": m.get("role"),
+                "blocks": [_summarise_block(b) for b in content],
+            })
+        else:
+            out.append({
+                "role": m.get("role"),
+                "content_chars": len(content) if isinstance(content, str) else None,
+            })
+    return out
+
+
+def format_openai_content(content: Any) -> Any:
+    """Translate Flocks-internal content blocks to OpenAI chat.completions schema.
+
+    Plain string content (the common case for assistant/tool messages) is
+    passed through untouched. List content is rewritten so each block is in
+    the schema OpenAI's chat.completions API expects:
+
+    * ``{"type": "image", "mimeType": ..., "data": "<base64>"}``
+      → ``{"type": "image_url", "image_url": {"url": "data:...;base64,..."}}``
+    * ``{"type": "text", "text": ...}`` is preserved.
+ * Already-OpenAI-native block types are passed through unchanged. + * Unknown block types are dropped and logged at DEBUG level to avoid + sending malformed payloads that would otherwise trigger a 400 from + the gateway. Callers should monitor for ``unknown_block_dropped`` log + events when adding new block kinds. + """ + if not isinstance(content, list): + return content + + formatted: list[dict[str, Any]] = [] + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type") + if block_type == "text" and isinstance(block.get("text"), str): + formatted.append({"type": "text", "text": block["text"]}) + elif block_type == "image" and block.get("data") and block.get("mimeType"): + formatted.append({ + "type": "image_url", + "image_url": { + "url": f"data:{block['mimeType']};base64,{block['data']}", + }, + }) + elif block_type in _OPENAI_NATIVE_BLOCK_TYPES: + formatted.append(block) + elif block_type in _FLOCKS_INTERNAL_BLOCK_TYPES: + # Known type but missing required fields (e.g. image without data/mimeType, + # or text with a non-string value). Log at debug with the actual keys present + # to make it easy to diagnose upstream encoding bugs. + log.debug("openai_base.malformed_block_dropped", { + "type": block_type, + "keys": sorted(block.keys()), + }) + else: + log.debug("openai_base.unknown_block_dropped", {"type": block_type}) + # Return the translated list as-is (possibly empty). Callers that need to + # omit the ``content`` field entirely (e.g. assistant messages with + # tool_calls only) are responsible for detecting the empty-list case and + # dropping the key themselves. Returning ``None`` here risks silencing a + # 400 for user/system roles where ``content=null`` is not permitted. + return formatted + + +def format_openai_messages(messages: List["ChatMessage"]) -> list: + """Convert a ``ChatMessage`` list to the OpenAI chat-completions wire format. + + Shared by all three OpenAI-style providers (``OpenAIProvider``, + ``OpenAICompatibleProvider``, ``OpenAIBaseProvider``) so the role-aware + content-null guard lives in exactly one place. + + Content-null rules (OpenAI chat completions spec): + * ``user`` / ``system`` / ``tool`` roles MUST have non-null content. + * ``assistant`` messages with ``tool_calls`` MAY omit ``content`` (null is + accepted) — we omit the key entirely for cleaner serialisation. + * An empty translated list (all blocks were dropped/malformed) is treated + the same as no content: safe fallback is ``""`` for non-assistant roles + and key-omission for assistant-with-tool-calls. 
+ """ + formatted = [] + for m in messages: + role = m.role if isinstance(m.role, str) else m.role.value + content = format_openai_content(m.content) + d: Dict[str, Any] = {"role": role} + if isinstance(content, list) and not content: + if role == "assistant" and m.tool_calls: + pass # omit content key — null is valid for tool-call-only turns + else: + d["content"] = "" + else: + d["content"] = content + if m.tool_calls: + d["tool_calls"] = m.tool_calls + if m.tool_call_id: + d["tool_call_id"] = m.tool_call_id + if m.name: + d["name"] = m.name + formatted.append(d) + return formatted + def _coerce_bool(value: Any, default: bool) -> bool: """Coerce loosely-typed config values to bool.""" @@ -417,13 +595,34 @@ def _get_client(self): custom_settings = getattr(self._config, "custom_settings", None) or {} verify_ssl = resolve_verify_ssl(custom_settings, default=True) - http_client = httpx.AsyncClient(verify=verify_ssl, timeout=120.0) + # Honour the same env-var contract as ``OpenAIProvider``: by default + # we follow ambient HTTP_PROXY / HTTPS_PROXY / NO_PROXY settings + # (``trust_env=True``) so that corporate egress works out of the + # box. Operators can opt out globally via FLOCKS_HTTP_TRUST_ENV=0 + # or per-provider via ``custom_settings.trust_env``. + trust_env = _coerce_bool( + os.getenv("FLOCKS_HTTP_TRUST_ENV"), True + ) + if isinstance(custom_settings, dict) and "trust_env" in custom_settings: + trust_env = _coerce_bool(custom_settings.get("trust_env"), trust_env) + timeout = DEFAULT_HTTP_TIMEOUT + http_client = httpx.AsyncClient( + trust_env=trust_env, + verify=verify_ssl, + timeout=timeout, + ) self._client = AsyncOpenAI( api_key=api_key, base_url=base_url, http_client=http_client, ) + log.info("openai_base.client.created", { + "provider_id": getattr(self._config, "id", None), + "base_url": base_url, + "trust_env": trust_env, + "verify_ssl": verify_ssl, + }) return self._client # ==================== Catalog Integration ==================== @@ -447,20 +646,16 @@ def get_models(self) -> List[ModelInfo]: # ==================== Chat ==================== + # Thin static-method wrapper so existing call-sites + # (``OpenAIBaseProvider._format_content``) keep working. Real logic lives + # in the module-level :func:`format_openai_content` so all OpenAI-style + # providers share one implementation. + _format_content = staticmethod(format_openai_content) + @staticmethod def _format_messages(messages: List[ChatMessage]) -> list: """Convert ChatMessage list to OpenAI API dicts, preserving tool_calls / tool results.""" - formatted = [] - for m in messages: - d: Dict[str, Any] = {"role": m.role, "content": m.content} - if m.tool_calls: - d["tool_calls"] = m.tool_calls - if m.tool_call_id: - d["tool_call_id"] = m.tool_call_id - if m.name: - d["name"] = m.name - formatted.append(d) - return formatted + return format_openai_messages(messages) async def chat( self, model_id: str, messages: List[ChatMessage], **kwargs @@ -491,6 +686,19 @@ async def chat( if kwargs.get("tools"): params["tools"] = kwargs["tools"] + # Mirror ``chat_stream``'s diagnostic log so non-streaming multimodal + # regressions are equally visible. Never logs raw base64 — see + # ``_summarise_block``. 
+ log.info("openai_base.chat.request", { + "model": model_id, + "thinking_enabled": bool(thinking), + "has_extra_body": "extra_body" in params, + "has_tools": bool(kwargs.get("tools")), + "max_tokens": kwargs.get("max_tokens"), + "has_temperature": "temperature" in params, + "message_shapes": _summarise_messages(openai_messages), + }) + response = await client.chat.completions.create(**params) if not response.choices: extra = getattr(response, "model_extra", {}) or {} @@ -553,6 +761,8 @@ async def chat_stream( if kwargs.get("tools"): params["tools"] = kwargs["tools"] + # Inspect content shape so multimodal regressions surface in the log. + # We *never* log full base64 payloads — see ``_summarise_block``. log.info("openai_base.stream.request", { "model": model_id, "thinking_enabled": bool(thinking), @@ -561,6 +771,7 @@ async def chat_stream( "max_tokens": kwargs.get("max_tokens"), "has_temperature": "temperature" in params, "include_usage": True, + "message_shapes": _summarise_messages(openai_messages), }) try: diff --git a/flocks/provider/sdk/openai_compatible.py b/flocks/provider/sdk/openai_compatible.py index bc8d62de..3214163c 100644 --- a/flocks/provider/sdk/openai_compatible.py +++ b/flocks/provider/sdk/openai_compatible.py @@ -22,10 +22,14 @@ StreamChunk, ) from flocks.provider.sdk.openai_base import ( + DEFAULT_HTTP_TIMEOUT, ThinkTagExtractor, + _coerce_bool, _normalize_stream_usage, _supports_include_usage_fallback, extract_reasoning_content, + format_openai_content, + format_openai_messages, resolve_verify_ssl, ) from flocks.utils.log import Log @@ -92,7 +96,20 @@ def _get_client(self): custom_settings = getattr(self._config, "custom_settings", None) or {} verify_ssl = resolve_verify_ssl(custom_settings, default=True) - http_client = httpx.AsyncClient(verify=verify_ssl, timeout=120.0) + # Honour the same env-var / per-provider trust_env contract as + # OpenAIProvider and OpenAIBaseProvider. Timeout is shared via + # DEFAULT_HTTP_TIMEOUT from openai_base so all three providers + # stay in sync. + trust_env = _coerce_bool( + os.getenv("FLOCKS_HTTP_TRUST_ENV"), True + ) + if isinstance(custom_settings, dict) and "trust_env" in custom_settings: + trust_env = _coerce_bool(custom_settings.get("trust_env"), trust_env) + http_client = httpx.AsyncClient( + trust_env=trust_env, + verify=verify_ssl, + timeout=DEFAULT_HTTP_TIMEOUT, + ) # Create client self._client = AsyncOpenAI( @@ -102,7 +119,7 @@ def _get_client(self): ) self.log.info( "openai_compatible.client.created", - {"base_url": base_url, "verify_ssl": verify_ssl}, + {"base_url": base_url, "trust_env": trust_env, "verify_ssl": verify_ssl}, ) except ImportError: @@ -143,44 +160,13 @@ async def _sleep_before_minimax_empty_retry(self, model_id: str, stage: str, has }) await asyncio.sleep(delay_seconds) - @staticmethod - def _format_content(content: Any) -> Any: - if not isinstance(content, list): - return content - - formatted: list[dict] = [] - for block in content: - if not isinstance(block, dict): - continue - block_type = block.get("type") - if block_type == "text" and isinstance(block.get("text"), str): - formatted.append({"type": "text", "text": block["text"]}) - elif block_type == "image" and block.get("data") and block.get("mimeType"): - formatted.append({ - "type": "image_url", - "image_url": { - "url": f"data:{block['mimeType']};base64,{block['data']}", - }, - }) - return formatted + # Delegated to the shared canonical implementations in ``openai_base``. 
+ _format_content = staticmethod(format_openai_content) @staticmethod def _format_messages(messages: List[ChatMessage]) -> list: """Convert ChatMessage list to OpenAI API dicts, preserving tool_calls / tool results.""" - formatted = [] - for msg in messages: - m: dict = { - "role": msg.role, - "content": OpenAICompatibleProvider._format_content(msg.content), - } - if msg.tool_calls: - m["tool_calls"] = msg.tool_calls - if msg.tool_call_id: - m["tool_call_id"] = msg.tool_call_id - if msg.name: - m["name"] = msg.name - formatted.append(m) - return formatted + return format_openai_messages(messages) async def chat( self, diff --git a/flocks/pty/pty.py b/flocks/pty/pty.py index 26fccec3..cc2dcf85 100644 --- a/flocks/pty/pty.py +++ b/flocks/pty/pty.py @@ -24,6 +24,37 @@ # Buffer configuration matching Flocks BUFFER_LIMIT = 1024 * 1024 * 2 # 2MB BUFFER_CHUNK = 64 * 1024 # 64KB +_ALLOWED_SHELL_NAMES = { + "ash", + "bash", + "csh", + "cmd", + "cmd.exe", + "dash", + "fish", + "ksh", + "ksh93", + "mksh", + "powershell", + "powershell.exe", + "pwsh", + "pwsh.exe", + "sh", + "tcsh", + "zsh", +} +_ALLOWED_SHELL_ARGS = {"-i", "-l", "--login"} +_LOGIN_FLAG_SHELL_NAMES = {"bash", "fish", "ksh", "ksh93", "mksh", "sh", "zsh"} +_BLOCKED_PTY_ENV_NAMES = { + "BASH_ENV", + "ENV", + "LD_LIBRARY_PATH", + "LD_PRELOAD", + "PROMPT_COMMAND", + "PYTHONSTARTUP", + "ZDOTDIR", +} +_BLOCKED_PTY_ENV_PREFIXES = ("DYLD_",) class PtyStatus(str, Enum): @@ -99,6 +130,49 @@ def _get_shell(cls) -> str: return shell_path return "sh" + + @classmethod + def _validate_interactive_shell(cls, command: str, args: List[str]) -> None: + """Allow PTY creation only for interactive shell sessions.""" + if not command or "\x00" in command: + raise ValueError("Invalid PTY command") + + shell_name = os.path.basename(command).lower() + if shell_name not in _ALLOWED_SHELL_NAMES: + raise ValueError("PTY command must be an approved interactive shell") + + for arg in args: + if not isinstance(arg, str) or "\x00" in arg or arg not in _ALLOWED_SHELL_ARGS: + raise ValueError("PTY command arguments are restricted to interactive shell flags") + + @classmethod + def _is_blocked_env_name(cls, name: str) -> bool: + normalized = name.upper() + return normalized in _BLOCKED_PTY_ENV_NAMES or any( + normalized.startswith(prefix) for prefix in _BLOCKED_PTY_ENV_PREFIXES + ) + + @classmethod + def _prepare_environment(cls, input_env: Optional[Dict[str, str]]) -> Dict[str, str]: + """Build a PTY environment without shell/linker startup injection hooks.""" + env = { + key: value + for key, value in os.environ.items() + if not cls._is_blocked_env_name(key) + } + + if input_env: + for key, value in input_env.items(): + if not isinstance(key, str) or not key or "\x00" in key: + raise ValueError("Invalid PTY environment variable name") + if cls._is_blocked_env_name(key): + raise ValueError(f"PTY environment variable is not allowed: {key}") + if not isinstance(value, str) or "\x00" in value: + raise ValueError(f"Invalid PTY environment variable value: {key}") + env[key] = value + + env["TERM"] = "xterm-256color" + return env @classmethod def list(cls) -> List[PtyInfo]: @@ -125,18 +199,18 @@ async def create(cls, input_data: CreateInput) -> PtyInfo: pty_id = Identifier.create("pty") command = input_data.command or cls._get_shell() args = list(input_data.args) if input_data.args else [] + cls._validate_interactive_shell(command, args) - # Add login flag for shells - if command.endswith("sh") and "-l" not in args: + # Add login flag only for shells known to accept it. 
Some approved + # POSIX-compatible shells (e.g. dash/ash) reject ``-l``. + shell_name = os.path.basename(command).lower() + if shell_name in _LOGIN_FLAG_SHELL_NAMES and "-l" not in args and "--login" not in args: args.append("-l") cwd = input_data.cwd or os.getcwd() # Prepare environment - env = dict(os.environ) - if input_data.env: - env.update(input_data.env) - env["TERM"] = "xterm-256color" + env = cls._prepare_environment(input_data.env) log.info("pty.creating", { "id": pty_id, diff --git a/flocks/server/app.py b/flocks/server/app.py index 9cb3a1d7..a3edd361 100644 --- a/flocks/server/app.py +++ b/flocks/server/app.py @@ -366,8 +366,6 @@ def _should_log_request(path: str, status_code: int) -> bool: # ``_FLOCKS_WEBUI_*`` origin inferred from the current CLI launch. # 2. Explicit ``server.cors`` in flocks.json → append user-configured # origins without discarding the runtime ones. -# 3. Fallback → only localhost (any port) via regex. -# # We deliberately do NOT auto-whitelist wildcard binds such as ``0.0.0.0``: # matching ``[^/]+:`` would accept every host on that port, effectively # disabling CORS. Remote deployments that bind to wildcard hosts must keep @@ -380,16 +378,10 @@ def _should_log_request(path: str, status_code: int) -> bool: # import time — which would otherwise cache ``HOME`` before test harnesses # can monkey-patch it. -_LOCALHOST_ORIGIN_RE = r"^https?://(127\.0\.0\.1|localhost)(:\d+)?$" - -_LOCALHOST_HOSTS = {"127.0.0.1", "localhost", "::1"} +_LOOPBACK_ORIGIN_HOSTS = {"127.0.0.1", "localhost", "::1"} _WILDCARD_HOSTS = {"0.0.0.0", "::"} -def _is_localhost(host: str) -> bool: - return host in _LOCALHOST_HOSTS - - def _format_host_for_url(host: str) -> str: """Wrap IPv6 literals in brackets before composing origins.""" if ":" in host and not host.startswith("["): @@ -398,11 +390,13 @@ def _format_host_for_url(host: str) -> str: def _append_origin(origins: list[str], host: str, port: str) -> None: - if not host or not port or _is_localhost(host) or host in _WILDCARD_HOSTS: + if not host or not port or host in _WILDCARD_HOSTS: return - origin = f"http://{_format_host_for_url(host)}:{port}" - if origin not in origins: - origins.append(origin) + hosts = sorted(_LOOPBACK_ORIGIN_HOSTS) if host in _LOOPBACK_ORIGIN_HOSTS else [host] + for candidate_host in hosts: + origin = f"http://{_format_host_for_url(candidate_host)}:{port}" + if origin not in origins: + origins.append(origin) def _read_cors_config() -> tuple[list[str], Optional[str]]: @@ -435,7 +429,7 @@ def _read_cors_config() -> tuple[list[str], Optional[str]]: except Exception: pass - return origins, _LOCALHOST_ORIGIN_RE + return origins, None class _DeferredCORSMiddleware: @@ -629,9 +623,8 @@ async def general_exception_handler(request: Request, exc: Exception): return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={ - "error": type(exc).__name__, - "message": str(exc), - "traceback": tb, + "error": "InternalServerError", + "message": "Internal server error", } ) diff --git a/flocks/server/auth.py b/flocks/server/auth.py index 8e55932a..5f642290 100644 --- a/flocks/server/auth.py +++ b/flocks/server/auth.py @@ -10,6 +10,7 @@ from typing import Optional from fastapi import HTTPException, Request, Response, status +from starlette.requests import HTTPConnection from flocks.auth.context import AuthUser, reset_current_auth_user, set_current_auth_user from flocks.auth.service import AuthService @@ -149,12 +150,12 @@ def password_reset_exempt(path: str) -> bool: return path in _PASSWORD_RESET_ALLOWED -def 
_has_session_cookie(request: Request) -> bool: +def _has_session_cookie(request: HTTPConnection) -> bool: session_id = request.cookies.get(SESSION_COOKIE_NAME) return bool(session_id and session_id.strip()) -def _is_browser_like_request(request: Request) -> bool: +def _is_browser_like_request(request: HTTPConnection) -> bool: """ Identify browser-originated traffic (must keep strict login checks). @@ -186,25 +187,7 @@ def _is_browser_like_request(request: Request) -> bool: return False -def _loopback_hosts() -> frozenset[str]: - hosts = {"127.0.0.1", "::1", "localhost"} - # FastAPI TestClient reports client host as "testclient"; only trust it in tests. - if os.getenv("PYTEST_CURRENT_TEST"): - hosts.add("testclient") - return frozenset(hosts) - - -def _is_loopback_direct_request(request: Request) -> bool: - """ - Trust only local direct requests (no proxy forwarding headers). - """ - if request.headers.get("x-forwarded-for"): - return False - client_host = request.client.host if request.client else None - return client_host in _loopback_hosts() - - -def _read_api_token_from_request(request: Request) -> Optional[str]: +def _read_api_token_from_request(request: HTTPConnection) -> Optional[str]: """ Read API token from Authorization Bearer or x-flocks-api-token header. """ @@ -248,18 +231,7 @@ def _build_api_token_user() -> AuthUser: ) -def _build_local_service_user() -> AuthUser: - """Synthetic local service identity for loopback non-browser clients.""" - return AuthUser( - id="local-service", - username="local-service", - role="admin", - status="active", - must_reset_password=False, - ) - - -async def apply_auth_for_request(request: Request): +async def apply_auth_for_request(request: HTTPConnection): """ Resolve user from cookie and bind context var. Returns (response_if_blocked, token, user). @@ -297,7 +269,8 @@ async def apply_auth_for_request(request: Request): ) return None, token, auth_user - # Non-browser clients: local loopback can run without token; remote requires API token. + # Non-browser clients must authenticate with an API token because + # localhost is not a reliable auth boundary. 
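+    # Illustrative call shape for such clients (header names as handled by
+    # _read_api_token_from_request; host, port and route are deployment-specific):
+    #
+    #     httpx.get(
+    #         "http://127.0.0.1:8000/api/...",
+    #         headers={"Authorization": f"Bearer {api_token}"},
+    #     )
+    #
+    # where api_token is the value stored under API_TOKEN_SECRET_ID in
+    # .secret.json; the x-flocks-api-token header is accepted as well.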
if not _is_browser_like_request(request): provided = _read_api_token_from_request(request) if provided: @@ -311,21 +284,15 @@ async def apply_auth_for_request(request: Request): token = set_current_auth_user(token_user) return None, token, token_user - if _is_loopback_direct_request(request): - local_user = _build_local_service_user() - request.state.auth_user = local_user - token = set_current_auth_user(local_user) - return None, token, local_user - expected = _get_expected_api_token() if not expected: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail=f"远程非浏览器请求需要 API Token,请先在 .secret.json 中配置 {API_TOKEN_SECRET_ID}", + detail=f"非浏览器请求需要 API Token,请先在 .secret.json 中配置 {API_TOKEN_SECRET_ID}", ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="远程非浏览器请求鉴权失败,请在 Authorization 中携带 Bearer API Token", + detail="非浏览器请求鉴权失败,请在 Authorization 中携带 Bearer API Token", ) bootstrapped = await AuthService.has_users() diff --git a/flocks/server/routes/find.py b/flocks/server/routes/find.py index 2a07208b..efa93b89 100644 --- a/flocks/server/routes/find.py +++ b/flocks/server/routes/find.py @@ -8,9 +8,11 @@ import subprocess from typing import Optional, List -from fastapi import APIRouter, Query +from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel +from flocks.config.config import Config +from flocks.utils.http_file_read_guard import resolve_path_for_http_file_access from flocks.utils.log import Log @@ -26,6 +28,25 @@ class FindResult(BaseModel): content: Optional[str] = None +async def _resolve_search_directory(directory: Optional[str]) -> str: + """Resolve a requested search root to an allowed readable directory.""" + cfg = await Config.get() + requested = directory or os.getcwd() + try: + cwd = await resolve_path_for_http_file_access(requested, cfg) + except PermissionError as exc: + raise HTTPException(status_code=403, detail="Access denied") from exc + + if not os.path.isdir(cwd): + raise HTTPException(status_code=400, detail="Search directory must be a directory") + return cwd + + +def _validate_search_input(value: str, *, label: str, max_length: int = 500) -> None: + if not value or len(value) > max_length or "\x00" in value: + raise HTTPException(status_code=400, detail=f"Invalid {label}") + + @router.get( "", summary="Find text", @@ -36,15 +57,17 @@ async def find_text( directory: Optional[str] = Query(None, description="Project directory"), ) -> List[FindResult]: """Search for text in files""" - cwd = directory or os.getcwd() + _validate_search_input(pattern, label="search pattern") + cwd = await _resolve_search_directory(directory) try: # Use ripgrep if available result = subprocess.run( - ["rg", "--json", "--max-count", "100", pattern], + ["rg", "--json", "--max-count", "100", "--", pattern], cwd=cwd, capture_output=True, text=True, + timeout=10, ) results = [] @@ -72,10 +95,11 @@ async def find_text( # ripgrep not available, use grep try: result = subprocess.run( - ["grep", "-rn", pattern, "."], + ["grep", "-rn", "--", pattern, "."], cwd=cwd, capture_output=True, text=True, + timeout=10, ) results = [] @@ -106,10 +130,11 @@ async def find_files( directory: Optional[str] = Query(None, description="Project directory"), dirs: Optional[str] = Query(None, description="Include directories"), type: Optional[str] = Query(None, description="Filter type: file or directory"), - limit: Optional[int] = Query(50, description="Max results"), + limit: Optional[int] = Query(50, ge=1, le=200, description="Max results"), ) -> List[str]: 
"""Search for files by name""" - cwd = directory or os.getcwd() + _validate_search_input(query, label="file query", max_length=200) + cwd = await _resolve_search_directory(directory) try: # Use fd if available @@ -118,20 +143,21 @@ async def find_files( cmd.extend(["--type", "d"]) elif type == "file": cmd.extend(["--type", "f"]) - cmd.append(query) + cmd.extend(["--", query]) result = subprocess.run( cmd, cwd=cwd, capture_output=True, text=True, + timeout=10, ) return result.stdout.strip().splitlines() if result.stdout else [] except FileNotFoundError: # fd not available, use find try: - cmd = ["find", ".", "-name", f"*{query}*", "-maxdepth", "10"] + cmd = ["find", ".", "-maxdepth", "10", "-name", f"*{query}*"] if type == "directory": cmd.extend(["-type", "d"]) elif type == "file": @@ -142,6 +168,7 @@ async def find_files( cwd=cwd, capture_output=True, text=True, + timeout=10, ) files = result.stdout.strip().splitlines() if result.stdout else [] diff --git a/flocks/server/routes/mcp.py b/flocks/server/routes/mcp.py index b606e68d..57166665 100644 --- a/flocks/server/routes/mcp.py +++ b/flocks/server/routes/mcp.py @@ -29,9 +29,13 @@ LOCAL_MCP_TYPES, REMOTE_MCP_TYPES, extract_api_key_from_mcp_url, + extract_auth_value_from_mcp_config, + extract_sensitive_headers_from_mcp_config, get_connect_block_reason, + mask_sensitive_mcp_config_for_frontend, normalize_mcp_config, normalize_mcp_config_aliases, + restore_masked_mcp_config_secrets, should_allow_unconnected_add, should_skip_connect_on_add, ) @@ -45,6 +49,7 @@ def _to_frontend_mcp_config(server_config: Dict[str, Any]) -> Dict[str, Any]: """Normalize backend transport names for the frontend form.""" + server_config = mask_sensitive_mcp_config_for_frontend(server_config) transport = str(server_config.get("type", "sse")).strip().lower() if transport in LOCAL_MCP_TYPES: transport = "stdio" @@ -79,10 +84,14 @@ def _to_frontend_mcp_config(server_config: Dict[str, Any]) -> Dict[str, Any]: "url": server_config.get("url"), "command": command, "args": args, + "transport": server_config.get("transport", "auto"), + "headers": server_config.get("headers"), + "auth": server_config.get("auth"), + "oauth": server_config.get("oauth"), } async def _load_mcp_server_config(name: str) -> Optional[Dict[str, Any]]: - """Load a server config from resolved config, falling back to raw JSON.""" + """Load a server config with secrets resolved for runtime connect/test paths.""" config = await Config.get() mcp_config = getattr(config, "mcp", None) or {} @@ -112,6 +121,21 @@ async def _load_mcp_server_config(name: str) -> Optional[Dict[str, Any]]: return normalize_mcp_config(server_config) +def _load_raw_mcp_server_config(name: str) -> Optional[Dict[str, Any]]: + """Load a server config without resolving secret placeholders.""" + server_config = ConfigWriter.get_mcp_server(name) + if hasattr(server_config, "model_dump"): + server_config = server_config.model_dump() + elif hasattr(server_config, "dict"): + server_config = server_config.dict() + elif server_config is not None and not isinstance(server_config, dict): + server_config = dict(server_config) + + if not isinstance(server_config, dict): + return None + return normalize_mcp_config(server_config) + + async def _build_mcp_status_response() -> Dict[str, Any]: """Merge runtime state with configured-but-not-connected MCP servers.""" status = await MCP.status() @@ -139,6 +163,14 @@ def _persist_mcp_server_config(name: str, config: Dict[str, Any]) -> None: save_mcp_config(name, config) +def 
_prepare_mcp_config_for_save(name: str, config: Dict[str, Any]) -> Dict[str, Any]: + """Normalize config and move any plain-text remote secrets into SecretManager.""" + clean_config = extract_api_key_from_mcp_url(name, normalize_mcp_config(config)) + clean_config = extract_auth_value_from_mcp_config(name, clean_config) + clean_config = extract_sensitive_headers_from_mcp_config(name, clean_config) + return clean_config + + # Request/Response models class McpAddRequest(BaseModel): @@ -202,9 +234,7 @@ async def add_mcp_server(request: McpAddRequest): # Extract any API key embedded in the URL and move it to .secret.json. # The URL is rewritten to use a {secret:...} reference so that the # plain-text credential is never written to flocks.json. - clean_config = extract_api_key_from_mcp_url( - request.name, normalize_mcp_config(request.config) - ) + clean_config = _prepare_mcp_config_for_save(request.name, request.config) if should_skip_connect_on_add(clean_config): _persist_mcp_server_config(request.name, clean_config) @@ -332,6 +362,7 @@ async def test_existing_mcp_connection(name: str, request: McpUpdateRequest): merged_config = dict(base_config) merged_config.update(normalize_mcp_config(request.config)) + merged_config = restore_masked_mcp_config_secrets(base_config, merged_config) success = await MCP.connect(temp_name, merged_config) if not success: @@ -432,13 +463,16 @@ async def remove_mcp_server(name: str): async def update_mcp_server(name: str, request: McpUpdateRequest): """Update an existing MCP server configuration and clear stale runtime state.""" try: - existing_config = await _load_mcp_server_config(name) + existing_config = _load_raw_mcp_server_config(name) if not existing_config: raise HTTPException(status_code=404, detail=f"MCP server not found: {name}") updated_config = dict(existing_config) updated_config.update(normalize_mcp_config(request.config)) - clean_config = extract_api_key_from_mcp_url(name, updated_config) + updated_config = restore_masked_mcp_config_secrets( + existing_config, updated_config + ) + clean_config = _prepare_mcp_config_for_save(name, updated_config) _persist_mcp_server_config(name, clean_config) status = await MCP.status() @@ -467,7 +501,7 @@ async def get_mcp_server_info(name: str): """Get info for a specific MCP server, with config from flocks.json when available.""" try: info = await MCP.get_server_info(name) - server_config = await _load_mcp_server_config(name) + server_config = _load_raw_mcp_server_config(name) if not info and not server_config: raise HTTPException(status_code=404, detail=f"MCP server not found: {name}") if not info: diff --git a/flocks/server/routes/provider.py b/flocks/server/routes/provider.py index bf13f016..53d0082b 100644 --- a/flocks/server/routes/provider.py +++ b/flocks/server/routes/provider.py @@ -2320,9 +2320,15 @@ async def test_provider_credentials(provider_id: str, body: Optional[TestCredent requested_model_id = body.model_id if body else None test_model_id = requested_model_id or models[0].id - # Validate model belongs to this provider + # Validate model belongs to this provider. Azure OpenAI is the + # exception: users may test a deployment name before saving it. 
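+    # Illustrative: an Azure user can type a brand-new deployment name such
+    # as "gpt-4o-prod" (hypothetical) and press test before the deployment is
+    # saved; the probe should reach the endpoint instead of failing locally.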
valid_ids = {m.id for m in models} - if test_model_id not in valid_ids: + is_unsaved_azure_deployment = ( + requested_model_id + and provider_id in {"azure-openai", "azure"} + and test_model_id not in valid_ids + ) + if test_model_id not in valid_ids and not is_unsaved_azure_deployment: response = { "success": False, "message": f"模型 '{test_model_id}' 不属于该 Provider", @@ -2519,6 +2525,11 @@ def _tool_sort_key(t: ToolInfo) -> tuple[int, int, str]: # actions (which would each trigger a fresh login attempt). if _is_login_probe(t) or _is_action_dispatch_login_probe(t): priority = -1 + # OneSEC grouped tools all look equally generic to the default + # heuristic, but the threat probe maps to a safer read-only + # version query than the legacy DNS probe. + elif provider_id == "onesec_api" and name_lower == "onesec_threat": + priority = 0 # Prefer query/scan style tools first, and push upload/file tools last. elif "ip" in name_lower: priority = 0 diff --git a/flocks/server/routes/pty.py b/flocks/server/routes/pty.py index 4b15d9be..1ddf3ccd 100644 --- a/flocks/server/routes/pty.py +++ b/flocks/server/routes/pty.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import BaseModel, Field +from flocks.server.auth import apply_auth_for_request, clear_auth_context from flocks.utils.log import Log from flocks.pty.pty import Pty, PtyInfo, CreateInput, UpdateInput, PtyStatus @@ -134,14 +135,23 @@ async def connect_session(websocket: WebSocket, pty_id: str): Establish a WebSocket connection to interact with a PTY session in real-time. """ - # Check if session exists first - if not Pty.get(pty_id): - await websocket.close(code=4004, reason="Session not found") + token = None + try: + _blocked, token, _user = await apply_auth_for_request(websocket) + except HTTPException as exc: + close_code = 4403 if exc.status_code == status.HTTP_403_FORBIDDEN else 4401 + await websocket.close(code=close_code, reason=str(exc.detail)) return - - await websocket.accept() - + try: + # Check only after authentication so unauthenticated callers cannot + # probe for active PTY identifiers. + if not Pty.get(pty_id): + await websocket.close(code=4004, reason="Session not found") + return + + await websocket.accept() + # Connect to PTY handlers = await Pty.connect(pty_id, websocket) if not handlers: @@ -164,3 +174,6 @@ async def connect_session(websocket: WebSocket, pty_id: str): except Exception as e: log.error("pty.ws.connect.error", {"id": pty_id, "error": str(e)}) + finally: + if token is not None: + clear_auth_context(token) diff --git a/flocks/server/routes/session.py b/flocks/server/routes/session.py index f59634c6..db40d397 100644 --- a/flocks/server/routes/session.py +++ b/flocks/server/routes/session.py @@ -27,6 +27,13 @@ # Default agent name constant DEFAULT_AGENT = "rex" +# File extensions that are safe to persist when materialising data-URL uploads. +# Intentionally narrow: any extension outside this set is rejected to prevent +# OS tools (Finder, ``open``) from misidentifying content based on the +# extension (e.g. a PNG named report.pdf.exe whose tail would otherwise be +# ".exe"). 
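+# Illustrative outcomes of the gate (derived from the materialisation helper
+# further down): "diagram.png" keeps ".png", while "report.pdf.exe"
+# contributes no extension because "exe" is not in the set, so the stored
+# file name is just the generated upload identifier.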
+_UPLOAD_SAFE_EXTS = frozenset({"png", "jpg", "jpeg", "gif", "webp", "bmp", "pdf"}) + # Import monitor for metrics endpoint from flocks.utils.monitor import get_monitor @@ -442,6 +449,31 @@ async def delete_session(sessionID: str, request: Request) -> bool: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="仅管理员或会话所有者可删除会话") await Session.delete(session.project_id, sessionID) + + # Best-effort cleanup of any image/file uploads materialised for this + # session via ``_materialize_data_url_to_disk`` (see prompt_async). + # The session DB row is gone, so the on-disk bytes are now orphaned — + # remove them to keep the workspace tidy. We deliberately swallow any + # filesystem errors: deletion of the session record is the contract, + # the upload cleanup is incidental. + try: + import shutil + from flocks.workspace.manager import WorkspaceManager + + ws = WorkspaceManager.get_instance() + uploads_root = ws.resolve_workspace_path(f"uploads/{sessionID}") + if uploads_root.exists() and uploads_root.is_dir(): + shutil.rmtree(uploads_root, ignore_errors=True) + log.info("session.uploads.cleaned", { + "session_id": sessionID, + "path": str(uploads_root), + }) + except Exception as exc: + log.warn("session.uploads.cleanup_failed", { + "session_id": sessionID, + "error": str(exc), + }) + log.info("session.deleted", {"session_id": sessionID}) return True @@ -933,6 +965,10 @@ class MessagePartInfo(BaseModel): state: Optional[Dict[str, Any]] = None callID: Optional[str] = None metadata: Optional[Dict[str, Any]] = None + # File / image attachment fields (populated when ``type == "file"``). + url: Optional[str] = None + mime: Optional[str] = None + filename: Optional[str] = None class MessageWithParts(BaseModel): @@ -1054,7 +1090,7 @@ async def get_session_messages( state_value = raw_state.model_dump() elif isinstance(raw_state, dict): state_value = raw_state - + part_info = MessagePartInfo( id=part.id if hasattr(part, 'id') else f"{msg.id}_part_{i}", messageID=msg.id, @@ -1066,6 +1102,9 @@ async def get_session_messages( state=state_value, callID=getattr(part, 'callID', None) if part.type == "tool" else None, metadata=getattr(part, 'metadata', None), + url=getattr(part, 'url', None) if part.type == "file" else None, + mime=getattr(part, 'mime', None) if part.type == "file" else None, + filename=getattr(part, 'filename', None) if part.type == "file" else None, ) parts.append(part_info) result.append(MessageWithParts(info=info, parts=parts)) @@ -1157,6 +1196,9 @@ async def get_message(sessionID: str, messageID: str) -> MessageWithParts: state=getattr(part, 'state', None) if part.type == "tool" else None, callID=getattr(part, 'callID', None) if part.type == "tool" else None, metadata=getattr(part, 'metadata', None), + url=getattr(part, 'url', None) if part.type == "file" else None, + mime=getattr(part, 'mime', None) if part.type == "file" else None, + filename=getattr(part, 'filename', None) if part.type == "file" else None, ) parts.append(part_info) return MessageWithParts(info=info, parts=parts) @@ -2006,19 +2048,25 @@ async def _process_session_message( # 1. Extract text content # ------------------------------------------------------------------ text_content = "" + has_non_text_parts = False for part in request.parts: - if part.get("type") == "text": + part_type = part.get("type") + if part_type == "text": text_content += part.get("text", "") - - if not text_content: + elif part_type: + has_non_text_parts = True + + # Allow messages that only contain attachments (e.g. 
an image with no caption) + if not text_content and not has_non_text_parts: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="No text content in message" + detail="Message must contain text or at least one attachment" ) - + log.info("session.message.received", { "sessionID": sessionID, "content_length": len(text_content), + "has_non_text_parts": has_non_text_parts, }) # ------------------------------------------------------------------ @@ -2114,6 +2162,104 @@ async def _process_session_message( _part_event["synthetic"] = True await publish_event("message.part.updated", {"part": _part_event}) + # ------------------------------------------------------------------ + # 3a. Persist any non-text parts (file/image attachments) so the + # SessionLoop sees them when building the LLM request. Without + # this, file parts sent from clients would be silently dropped. + # + # For ``data:`` URLs we materialize the bytes to disk and store + # a ``file://`` reference instead. Keeping the raw base64 string + # in the message database is dangerous: any code path that later + # stringifies the part (legacy LLM adapters, logging, compaction) + # would tokenize hundreds of KB of base64 and blow past the + # model's context window. + # ------------------------------------------------------------------ + from flocks.session.message import FilePart + + def _materialize_data_url_to_disk( + data_url: str, mime_hint: str, filename_hint: Optional[str] + ) -> str: + """Decode a ``data:`` URL to ``~/.flocks/workspace/uploads/<sessionID>/...``. + + Returns a ``file://`` URL pointing at the persisted file. On failure + the original ``data:`` URL is returned unchanged (older code paths + still cope with that, just with the now-known token-cost penalty). + """ + try: + import base64 + from flocks.workspace.manager import WorkspaceManager + + header, _, encoded = data_url.partition(",") + if not encoded: + return data_url + raw_bytes = base64.b64decode(encoded) + + ws = WorkspaceManager.get_instance() + # Use resolve_workspace_path to guard against path traversal if + # sessionID were ever user-controlled (e.g. ../../../tmp/x). + uploads_root = ws.resolve_workspace_path(f"uploads/{sessionID}") + uploads_root.mkdir(parents=True, exist_ok=True) + + ext_map = { + "image/png": ".png", "image/jpeg": ".jpg", "image/jpg": ".jpg", + "image/gif": ".gif", "image/webp": ".webp", "image/bmp": ".bmp", + "application/pdf": ".pdf", + } + ext = ext_map.get(mime_hint, "") + if not ext and filename_hint: + _, _, tail = filename_hint.rpartition(".") + if tail.lower() in _UPLOAD_SAFE_EXTS: + ext = "." + tail.lower() + unique_name = f"{Identifier.create('upload')}{ext}" + target = uploads_root / unique_name + target.write_bytes(raw_bytes) + return f"file://{target.resolve()}" + except Exception as exc: + log.warn("session.message.file_part.materialize_failed", { + "sessionID": sessionID, + "error": str(exc), + }) + return data_url + + for raw_part in request.parts or []: + part_type = raw_part.get("type") + if part_type == "text": + continue # Already stored as the message's TextPart above + if part_type == "file": + url = raw_part.get("url") or "" + mime = raw_part.get("mime") or "" + if not url or not mime: + log.warn("session.message.file_part.skipped", { + "sessionID": sessionID, + "reason": "missing url or mime", + }) + continue + # Materialize ``data:`` URLs to disk before persisting the part.
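For readers unfamiliar with the format: a ``data:`` URL splits on the first comma into a header (media type plus encoding) and the payload. A rough sketch of the decomposition performed above, with a hypothetical tiny payload:

    import base64

    data_url = "data:image/png;base64,iVBORw0KGgo="   # hypothetical payload
    header, _, encoded = data_url.partition(",")
    mime = header.removeprefix("data:").split(";", 1)[0]
    raw_bytes = base64.b64decode(encoded)
    assert mime == "image/png"
    assert raw_bytes.startswith(b"\x89PNG")  # PNG magic bytes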
+ if url.startswith("data:"): + url = _materialize_data_url_to_disk(url, mime, raw_part.get("filename")) + file_part_id = raw_part.get("id") or Identifier.create("part") + file_part = FilePart( + id=file_part_id, + sessionID=sessionID, + messageID=user_message_id, + mime=mime, + filename=raw_part.get("filename"), + url=url, + ) + await Message.add_part(sessionID, user_message_id, file_part) + await publish_event("message.part.updated", { + "part": { + "id": file_part_id, + "messageID": user_message_id, + "sessionID": sessionID, + "type": "file", + "mime": mime, + "filename": raw_part.get("filename"), + "url": url, + "time": {"start": now_ms}, + } + }) + # ------------------------------------------------------------------ # noReply: store message only, skip AI loop # ------------------------------------------------------------------ diff --git a/flocks/server/routes/task_entities.py b/flocks/server/routes/task_entities.py index 13d7681d..62e01abb 100644 --- a/flocks/server/routes/task_entities.py +++ b/flocks/server/routes/task_entities.py @@ -51,6 +51,7 @@ class SchedulerUpdateRequest(BaseModel): cron_description: Optional[str] = Field(None, alias="cronDescription") timezone: Optional[str] = None user_prompt: Optional[str] = Field(None, alias="userPrompt") + context: Optional[dict] = None workspace_directory: Optional[str] = Field(None, alias="workspaceDirectory") diff --git a/flocks/server/routes/tool.py b/flocks/server/routes/tool.py index b152a160..6fcc5f30 100644 --- a/flocks/server/routes/tool.py +++ b/flocks/server/routes/tool.py @@ -4,9 +4,10 @@ import asyncio from typing import List, Optional, Dict, Any -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field +from flocks.server.auth import require_admin from flocks.utils.log import Log from flocks.config.config_writer import ConfigWriter from flocks.permission.next import DeniedError, PermissionNext @@ -493,7 +494,7 @@ async def get_tool(tool_name: str): response_model=ToolInfoResponse, summary="Update tool settings", ) -async def update_tool(tool_name: str, request: ToolUpdateRequest): +async def update_tool(tool_name: str, request: ToolUpdateRequest, _admin: object = Depends(require_admin)): """ Update tool settings (e.g., enable or disable). @@ -559,7 +560,7 @@ async def update_tool(tool_name: str, request: ToolUpdateRequest): response_model=ToolInfoResponse, summary="Reset a tool to its YAML/registration default", ) -async def reset_tool_setting(tool_name: str): +async def reset_tool_setting(tool_name: str, _admin: object = Depends(require_admin)): """Remove the user setting for ``tool_name`` and restore the default. Restores the registration-time ``enabled`` value from the registry's @@ -756,7 +757,7 @@ class RefreshResponse(BaseModel): response_model=RefreshResponse, summary="Refresh all plugin and dynamic tools", ) -async def refresh_tools(): +async def refresh_tools(_admin: object = Depends(require_admin)): """ Reload all plugin tools (YAML + Python) and dynamically generated tools from disk without restarting the service. @@ -958,7 +959,7 @@ class PluginToolListResponse(BaseModel): status_code=status.HTTP_201_CREATED, summary="Create a YAML plugin tool", ) -async def create_tool(request: CreateToolRequest): +async def create_tool(request: CreateToolRequest, _admin: object = Depends(require_admin)): """ Create a new tool via YAML plugin. 
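The recurring _admin parameter in these route signatures is FastAPI's dependency-injection idiom for authorization. A self-contained sketch of the shape; require_admin's real implementation lives in flocks.server.auth and is not shown in this diff, so the check below is a placeholder:

    from fastapi import Depends, FastAPI, HTTPException, status

    app = FastAPI()

    async def require_admin() -> object:
        is_admin = False  # placeholder: the real dependency inspects the request's auth context
        if not is_admin:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="admin only")
        return object()

    @app.post("/demo")
    async def demo(_admin: object = Depends(require_admin)):
        # Reaching this body implies require_admin did not raise.
        return {"ok": True}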
@@ -1034,7 +1035,7 @@ async def create_tool(request: CreateToolRequest): response_model=ToolInfoResponse, summary="Update a YAML plugin tool", ) -async def update_plugin_tool(name: str, request: UpdateToolRequest): +async def update_plugin_tool(name: str, request: UpdateToolRequest, _admin: object = Depends(require_admin)): """ Update an existing YAML plugin tool. @@ -1095,7 +1096,7 @@ async def update_plugin_tool(name: str, request: UpdateToolRequest): "/{name}", summary="Delete a plugin tool", ) -async def delete_tool(name: str): +async def delete_tool(name: str, _admin: object = Depends(require_admin)): """ Delete a plugin tool. @@ -1141,7 +1142,7 @@ async def delete_tool(name: str): response_model=ToolInfoResponse, summary="Reload a YAML plugin tool", ) -async def reload_tool(name: str): +async def reload_tool(name: str, _admin: object = Depends(require_admin)): """ Hot-reload a single YAML plugin tool. diff --git a/flocks/server/routes/workflow.py b/flocks/server/routes/workflow.py index 99bf56be..a20875fe 100644 --- a/flocks/server/routes/workflow.py +++ b/flocks/server/routes/workflow.py @@ -44,6 +44,13 @@ read_workflow_from_fs as shared_read_workflow_from_fs, workflow_scan_dirs as _all_scan_dirs, ) +from flocks.workflow.execution_store import ( + create_execution_record, + normalize_execution_status as _normalize_execution_status, + record_execution_result as _record_execution_result, + resolve_execution_outcome as _resolve_execution_outcome, + workflow_execution_key as _workflow_execution_key, +) from flocks.workflow.io import load_workflow, dump_workflow from flocks.workflow.tools import get_tool_registry from flocks.config.config import Config @@ -81,6 +88,10 @@ class WorkflowCreateRequest(BaseModel): category: Optional[str] = Field("default", description="Workflow category") workflow_json: Dict[str, Any] = Field(..., alias="workflowJson", description="Workflow JSON definition") created_by: Optional[str] = Field(None, alias="createdBy", description="Creator") + source: Optional[Literal["project", "global"]] = Field( + "global", + description="Storage location: 'project' or 'global'; defaults to global user storage", + ) class WorkflowUpdateRequest(BaseModel): @@ -138,6 +149,10 @@ class WorkflowExecutionResponse(BaseModel): duration: Optional[float] = Field(None, description="Duration (seconds)") executionLog: List[Dict[str, Any]] = Field(default_factory=list, description="Execution log") errorMessage: Optional[str] = Field(None, description="Error message") + currentNodeId: Optional[str] = Field(None, description="Current running node ID") + currentNodeType: Optional[str] = Field(None, description="Current running node type") + currentPhase: Optional[str] = Field(None, description="Current execution phase") + currentStepIndex: Optional[int] = Field(None, description="Current step index") class WorkflowCenterPublishRequest(BaseModel): @@ -418,65 +433,6 @@ def _workflow_stats_key(workflow_id: str) -> str: return f"workflow/{workflow_id}/stats" -def _workflow_execution_key(exec_id: str) -> str: - return f"workflow_execution/{exec_id}" - - -def _normalize_execution_status(status: str) -> str: - """Map runner status values to API status values.""" - normalized = (status or "").strip().upper() - if normalized == "SUCCEEDED": - return "success" - if normalized == "FAILED": - return "error" - if normalized == "TIMED_OUT": - return "timeout" - if normalized == "CANCELLED": - return "cancelled" - return (status or "error").strip().lower() or "error" - - -def 
_extract_business_failure_message(outputs: Dict[str, Any]) -> Optional[str]: - """Return a user-facing failure reason from workflow outputs.""" - for key in ("reason", "error_message", "errorMessage", "message"): - value = outputs.get(key) - if isinstance(value, str) and value.strip(): - return value.strip() - return None - - -def _resolve_execution_outcome(result: RunWorkflowResult) -> tuple[str, Optional[str]]: - """Resolve API execution status from runner status and workflow outputs.""" - status_value = _normalize_execution_status(result.status) - error_message = result.error - - if status_value != "success" or not isinstance(result.outputs, dict): - return status_value, error_message - - if result.outputs.get("workflow_success") is False: - return ( - "error", - error_message - or _extract_business_failure_message(result.outputs) - or "Workflow reported business failure.", - ) - - return status_value, error_message - - -async def _record_execution_result(workflow_id: str, exec_id: str, exec_data: Dict[str, Any]) -> None: - """Persist the final execution record and audit trail.""" - await Storage.write(_workflow_execution_key(exec_id), exec_data) - try: - await Recorder.record_workflow_execution( - exec_id=exec_id, - workflow_id=workflow_id, - run_result=exec_data, - ) - except Exception: - pass - - async def _run_workflow_execution_task( *, workflow_id: str, @@ -491,22 +447,38 @@ async def _run_workflow_execution_task( start_time = time.time() step_history: list[dict[str, Any]] = [] loop = asyncio.get_running_loop() - def _on_step_complete(step_result) -> None: - step_dict = step_result.model_dump(mode="json") - step_history.append(step_dict) + + def _write_progress(update_fields: Dict[str, Any]) -> None: try: current = asyncio.run_coroutine_threadsafe(Storage.read(exec_key), loop).result(timeout=5) - update = { - **current, - "executionLog": list(step_history), - } - asyncio.run_coroutine_threadsafe(Storage.write(exec_key, update), loop).result(timeout=5) + current.update(update_fields) + asyncio.run_coroutine_threadsafe(Storage.write(exec_key, current), loop).result(timeout=5) except Exception as exc: log.warning("workflow.step_progress.write_failed", { "exec_id": exec_id, "error": str(exc), }) + def _on_step_start(_run_id, step_index, node, _inputs): + _write_progress({ + "currentNodeId": getattr(node, "id", None), + "currentNodeType": getattr(node, "type", None), + "currentPhase": "running", + "currentStepIndex": step_index, + }) + return step_index + + def _on_step_complete(step_result) -> None: + step_dict = step_result.model_dump(mode="json") + step_history.append(step_dict) + _write_progress({ + "executionLog": list(step_history), + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": len(step_history), + }) + try: result: RunWorkflowResult = await asyncio.to_thread( run_workflow, @@ -514,6 +486,7 @@ def _on_step_complete(step_result) -> None: inputs=req.inputs or {}, timeout_s=req.timeout_s, trace=req.trace, + on_step_start=_on_step_start, on_step_complete=_on_step_complete, cancel=cancel_event.is_set, tool_context=tool_context, @@ -529,6 +502,10 @@ def _on_step_complete(step_result) -> None: "duration": duration, "executionLog": result.history or list(step_history), "errorMessage": error_message, + "currentNodeId": result.last_node_id, + "currentNodeType": current_data.get("currentNodeType"), + "currentPhase": status_value, + "currentStepIndex": result.steps, }) 
if status_value == "success": @@ -552,6 +529,7 @@ def _on_step_complete(step_result) -> None: "duration": duration, "errorMessage": str(exc), "executionLog": list(step_history), + "currentPhase": "cancelled" if cancel_event.is_set() else "error", }) if current_data["status"] == "error": await _update_workflow_stats(workflow_id, False, duration) @@ -673,6 +651,7 @@ async def create_workflow(req: WorkflowCreateRequest): workflow_id = str(uuid.uuid4()) now_ms = int(time.time() * 1000) + source = req.source or "global" meta = { "id": workflow_id, "name": req.name, @@ -684,10 +663,16 @@ "updatedAt": now_ms, } - _write_workflow_to_fs(workflow_id, req.workflow_json, meta) + _write_workflow_to_fs(workflow_id, req.workflow_json, meta, global_store=(source == "global")) stats = await _get_workflow_stats(workflow_id) - data = {**meta, "workflowJson": req.workflow_json, "markdownContent": None, "stats": stats} + data = { + **meta, + "workflowJson": req.workflow_json, + "markdownContent": None, + "stats": stats, + "source": source, + } log.info("workflow.created", {"id": workflow_id, "name": req.name}) await publish_event("workflow.created", {"id": workflow_id, "name": req.name}) @@ -847,21 +832,11 @@ async def run_workflow_endpoint(workflow_id: str, req: WorkflowRunRequest): agent=req.agent, ) - # Create execution record - exec_id = str(uuid.uuid4()) - start_ms = int(time.time() * 1000) - - exec_data = { - "id": exec_id, - "workflowId": workflow_id, - "inputParams": req.inputs or {}, - "status": "running", - "startedAt": start_ms, - "executionLog": [], - } - - # Save initial execution record - await Storage.write(_workflow_execution_key(exec_id), exec_data) + exec_data = await create_execution_record( + workflow_id, + input_params=req.inputs or {}, + ) + exec_id = str(exec_data["id"]) cancel_event = threading.Event() task = asyncio.create_task( @@ -1272,10 +1247,16 @@ async def import_workflow(workflow_json: Dict[str, Any]): "updatedAt": now_ms, } - _write_workflow_to_fs(workflow_id, workflow_json, meta) + _write_workflow_to_fs(workflow_id, workflow_json, meta, global_store=True) stats = await _get_workflow_stats(workflow_id) - data = {**meta, "workflowJson": workflow_json, "markdownContent": None, "stats": stats} + data = { + **meta, + "workflowJson": workflow_json, + "markdownContent": None, + "stats": stats, + "source": "global", + } log.info("workflow.imported", {"id": workflow_id, "name": name}) await publish_event("workflow.created", {"id": workflow_id, "name": name}) diff --git a/flocks/session/runner.py b/flocks/session/runner.py index 424d9b0d..b1ca375b 100644 --- a/flocks/session/runner.py +++ b/flocks/session/runner.py @@ -12,6 +12,7 @@ import asyncio import json import os +import re import sys import time from datetime import datetime @@ -95,6 +96,38 @@ def _annotate_with_provider_version(tool_info: Any, description: Optional[str]) # spurious failures in those cases. LLM_STREAM_ONGOING_CHUNK_TIMEOUT_S = 300 +_WORKFLOW_NODE_REF_RE = re.compile(r"^@@node:([^|\n]+)\|([^\n]+)\n?([\s\S]*)$") + + +def _expand_workflow_node_ref(text: str) -> str: + """Translate the web UI's node-ref marker into model-readable text. + + WorkflowDetail chat prefixes a user turn with ``@@node:<node_id>|<node_type>`` when + the user picks a node from the canvas. Before this fix the marker was only + rendered/decorated in the UI; the backend passed it through verbatim, so + the model saw an opaque token instead of an explicit instruction to focus + on that node.
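    Illustrative expansion (the node id and type are made-up values):

        "@@node:node_42|llm\nMake this step retry twice"

    becomes:

        Selected workflow node context:
        - node_id: node_42
        - node_type: llm
        - Focus the requested workflow modification on this node unless the user explicitly asks for broader workflow changes.

        User request:
        Make this step retry twice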
+ """ + if not text: + return text + match = _WORKFLOW_NODE_REF_RE.match(text) + if not match: + return text + + node_id = match.group(1).strip() + node_type = match.group(2).strip() + user_request = match.group(3).lstrip("\n") + + parts = [ + "Selected workflow node context:", + f"- node_id: {node_id}", + f"- node_type: {node_type}", + "- Focus the requested workflow modification on this node unless the user explicitly asks for broader workflow changes.", + ] + if user_request.strip(): + parts.extend(["", "User request:", user_request]) + return "\n".join(parts) + async def _iter_with_chunk_timeout( aiter, @@ -360,8 +393,64 @@ async def _apply_tool_result_budget( return {"compacted": compacted, "persisted": persisted} + # Provider IDs whose adapters are known to translate Flocks' internal + # ``{"type": "image", "mimeType": ..., "data": }`` block into + # the provider-native multimodal format (e.g. OpenAI ``image_url`` or + # Anthropic vision blocks). + # + # Matched with exact equality (no substring matching) to prevent false + # positives such as a user-configured "not-openai" or an internal + # "xxx-llm-gateway" id being mistakenly classified as multimodal-capable. + # + # ``custom-`` providers (created via ``POST /api/custom/providers``) are + # checked separately via ``startswith`` because their ids follow the + # pattern ``custom-`` and they always use the + # ``@ai-sdk/openai-compatible`` adapter that handles vision blocks. + _MULTIMODAL_PROVIDER_NAMES = frozenset({ + "anthropic", "openai", "azure", + "vertex", "bedrock", "openrouter", + }) + + def _model_supports_vision(self) -> bool: + """Best-effort vision capability lookup from the model definition. + + Returns ``True`` only when explicitly declared on the model entry via + ``capabilities.supports_vision``. Defaults to a safe ``False`` for + unknown configurations. + """ + try: + from flocks.provider.provider import Provider as _Provider + + provider = _Provider.get(self.provider_id) + if provider is not None: + for model in getattr(provider, "_config_models", []) or []: + if model.id == self.model_id: + caps = getattr(model, "capabilities", None) + if caps and getattr(caps, "supports_vision", False): + return True + break + except Exception as exc: + log.debug("runner.vision_lookup.failed", {"error": str(exc)}) + return False + def _supports_multimodal_user_content(self) -> bool: - return self.provider_id in {"anthropic", "openai", "openai-compatible"} + """Whether the active provider/model can accept image content blocks. + + Decision order: + 1. The model definition explicitly advertises vision support + (``capabilities.supports_vision`` on the model entry). + 2. The provider id is an exact match against the known multimodal + provider name set, or starts with ``"custom-"`` (user-registered + OpenAI-compatible providers that inherit vision capability from + their underlying model). + """ + if self._model_supports_vision(): + return True + provider_id = (self.provider_id or "").lower() + return ( + provider_id in self._MULTIMODAL_PROVIDER_NAMES + or provider_id.startswith("custom-") + ) def _append_file_content_block( self, @@ -375,10 +464,33 @@ def _append_file_content_block( """Append an appropriate content block for *url* into *blocks*. Images are embedded as base64 for multimodal-capable providers; other - file types are extracted as text. Falls back to a plain markdown link - when extraction is not possible. + file types are extracted as text. 
We *never* spill a raw ``data:`` URL + into the text fallback — doing so previously caused OpenAI to tokenize + the entire base64 payload (~250k tokens for a single screenshot, + blowing past the model's context window). """ - if self._supports_multimodal_user_content() and mime.startswith("image/"): + is_image = mime.startswith("image/") + multimodal_ok = self._supports_multimodal_user_content() + + log.info("runner.file_part.dispatch", { + "provider_id": self.provider_id, + "model_id": self.model_id, + "mime": mime, + "filename": filename, + "is_image": is_image, + "multimodal_supported": multimodal_ok, + "url_scheme": url.split(":", 1)[0] if url else None, + "url_size": len(url), + }) + + # For images we ALWAYS try the multimodal path first, regardless of + # provider whitelist. The whitelist guards against silently degrading + # to a text fallback that would tokenize the entire base64 payload — + # but since fallback now never embeds the URL, "force-try multimodal" + # is strictly safer: providers that genuinely cannot handle ``image`` + # blocks will surface a precise error, which is far better UX than a + # 250k-token context_length_exceeded error. + if is_image: import base64 as _b64 data = read_file_part_bytes(url) if data: @@ -388,16 +500,44 @@ "data": _b64.b64encode(data).decode("utf-8"), }) return + log.warn("runner.file_part.image_decode_failed", { + "provider_id": self.provider_id, + "filename": filename, + }) + # Image bytes could not be read — fall through to placeholder + # (which is intentionally tiny, never the raw URL). - extracted_text = extract_file_text(mime=mime, filename=filename, url=url) + extracted_text = extract_file_text(mime=mime, filename=filename, url=url) if not is_image else None if extracted_text: text_fallbacks.append(extracted_text) blocks.append({"type": "text", "text": extracted_text}) return - text_fallbacks.append( - f"[File: {filename}]({url})" if url else f"[File: {filename}]" - ) + # Final fallback — for images we either lack vision support or could + # not decode the bytes. Either way, refuse to embed the raw URL when + # it is a data URI; the base64 payload would otherwise be sent to the + # LLM as plain text and explode the prompt token count. Also clamp the + # placeholder to a hard character cap as a belt-and-braces measure.
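The multimodal path above emits Flocks' internal image block. A sketch of its shape, inferred from this function and the _MULTIMODAL_PROVIDER_NAMES comment (adapters later translate it into each provider's native format; the byte literal is a stand-in):

    import base64

    png_bytes = b"\x89PNG\r\n\x1a\n"  # stand-in for real image bytes
    block = {
        "type": "image",
        "mimeType": "image/png",
        "data": base64.b64encode(png_bytes).decode("utf-8"),
    }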
+ MAX_PLACEHOLDER_CHARS = 200 + if is_image: + placeholder = ( + f"[Image: {filename} — model does not support image input; " + f"the image was omitted from the prompt]" + ) + log.info("runner.file_part.image_skipped", { + "provider_id": self.provider_id, + "model_id": self.model_id, + "reason": "multimodal_unsupported" if not multimodal_ok else "decode_failed", + "filename": filename, + }) + else: + safe_url = url if url and not url.startswith("data:") else "" + placeholder = ( + f"[File: {filename}]({safe_url})" if safe_url else f"[File: {filename}]" + ) + if len(placeholder) > MAX_PLACEHOLDER_CHARS: + placeholder = placeholder[:MAX_PLACEHOLDER_CHARS] + "…" + text_fallbacks.append(placeholder) @classmethod async def loop(cls, session_id: str) -> Optional['MessageInfo']: @@ -829,15 +969,51 @@ async def _process_step( if last_finished: from flocks.session.prompt_strings import SYNTHETIC_MESSAGE_MARKERS for chat_msg in chat_messages: - if chat_msg.role == "user": - if not any(marker in chat_msg.content for marker in SYNTHETIC_MESSAGE_MARKERS): - # Wrap with reminder - chat_msg.content = f""" -The user sent the following message: -{chat_msg.content} + if chat_msg.role != "user": + continue + + content = chat_msg.content -Please address this message and continue with your tasks. -""" + if isinstance(content, str): + if any(marker in content for marker in SYNTHETIC_MESSAGE_MARKERS): + continue + chat_msg.content = ( + "\n" + "The user sent the following message:\n" + f"{content}\n\n" + "Please address this message and continue with your tasks.\n" + "" + ) + elif isinstance(content, list): + # Multimodal user content (e.g. image_url blocks). + # Naively f-stringing the whole list would call + # ``str(list)`` and serialize every image block — base64 + # data and all — into plain text, which both blows up + # the token count AND makes vision-capable models + # respond with "I see only base64 text". Wrap *only* + # the first text block instead, leaving image blocks + # untouched. If there is no text block at all (rare — + # an image-only turn), skip wrapping entirely. + first_text_idx: Optional[int] = None + for idx, block in enumerate(content): + if isinstance(block, dict) and block.get("type") == "text": + first_text_idx = idx + break + if first_text_idx is None: + continue + text_val = content[first_text_idx].get("text") or "" + if any(marker in text_val for marker in SYNTHETIC_MESSAGE_MARKERS): + continue + content[first_text_idx] = { + "type": "text", + "text": ( + "\n" + "The user sent the following message:\n" + f"{text_val}\n\n" + "Please address this message and continue with your tasks.\n" + "" + ), + } # Add max steps warning if this is the last step (matching Flocks) if is_last_step: @@ -1541,7 +1717,20 @@ async def _to_chat_messages( ctx_window_tokens = self._get_context_window_tokens() tool_result_refs: List[Dict[str, Any]] = [] turn_index = 0 - + + # Identify the last USER message — only that one keeps real image + # bytes in its content blocks. Earlier turns get a short text + # placeholder so we don't ship hundreds of KB of base64 back to the + # model on every follow-up. Even providers that count vision tokens + # natively (OpenAI proper) charge per resent image, and gateways that + # tokenize the data URL as plain text (e.g. some Azure proxies) will + # blow past the context window after the second turn otherwise. 
+ last_user_msg_id: Optional[str] = None + for _msg in messages: + _role = _msg.role if isinstance(_msg.role, str) else getattr(_msg.role, "value", None) + if _role == "user": + last_user_msg_id = _msg.id + # Add system prompts if system_prompts: chat_messages.append(ChatMessage( @@ -1553,6 +1742,7 @@ async def _to_chat_messages( for msg in messages: if msg.role == MessageRole.USER or (isinstance(msg.role, str) and msg.role == "user"): turn_index += 1 + is_latest_user_turn = msg.id == last_user_msg_id # Get message parts parts = await Message.parts(msg.id, self.session.id) @@ -1562,7 +1752,7 @@ async def _to_chat_messages( if content.strip(): chat_messages.append(ChatMessage( role=msg.role if isinstance(msg.role, str) else msg.role.value, - content=content, + content=_expand_workflow_node_ref(content), )) continue @@ -1574,23 +1764,41 @@ async def _to_chat_messages( if hasattr(part, 'type'): if part.type == "text" and hasattr(part, 'text'): if not getattr(part, 'ignored', False) and part.text.strip(): - user_content_parts.append(part.text) + normalized_text = _expand_workflow_node_ref(part.text) + user_content_parts.append(normalized_text) user_content_blocks.append({ "type": "text", - "text": part.text, + "text": normalized_text, }) elif part.type == "file" and hasattr(part, 'mime'): mime = getattr(part, 'mime', '') if mime != 'application/x-directory': filename = getattr(part, 'filename', 'file') url = getattr(part, 'url', '') - self._append_file_content_block( - user_content_blocks, - user_content_parts, - mime=mime, - filename=filename, - url=url, - ) + # Image bytes only ride on the latest user + # turn — older turns are reduced to a short, + # opaque placeholder. Crucially the placeholder + # does NOT include the filename: leaking + # earlier filenames was making the model + # confidently misidentify the *current* image + # as one of the older ones (it would echo back + # the older filename instead of describing the + # newly-attached picture). + if mime.startswith("image/") and not is_latest_user_turn: + stub = "[earlier image omitted]" + user_content_parts.append(stub) + user_content_blocks.append({ + "type": "text", + "text": stub, + }) + else: + self._append_file_content_block( + user_content_blocks, + user_content_parts, + mime=mime, + filename=filename, + url=url, + ) elif part.type == "compaction": user_content_parts.append("What did we do so far?") user_content_blocks.append({ diff --git a/flocks/tool/task/run_workflow.py b/flocks/tool/task/run_workflow.py index 2741d5d7..433d6cb2 100644 --- a/flocks/tool/task/run_workflow.py +++ b/flocks/tool/task/run_workflow.py @@ -10,13 +10,23 @@ import inspect import time from pathlib import Path +from types import SimpleNamespace from typing import Optional, Dict, Any, Union from flocks.tool.registry import ( ToolRegistry, ToolCategory, ToolParameter, ParameterType, ToolResult, ToolContext ) +from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.session.recorder import Recorder +from flocks.workflow.execution_store import ( + create_execution_record, + normalize_execution_status, + record_execution_result, + resolve_execution_outcome, + workflow_execution_key, +) +from flocks.workflow.fs_store import read_workflow_from_fs, resolve_workflow_id_from_source log = Log.create(service="tool.run_workflow") @@ -347,18 +357,23 @@ async def run_workflow_tool( try: workflow_source = json.loads(raw) except json.JSONDecodeError: - # Otherwise treat it as a file path. 
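A toy illustration of the "only the latest user turn keeps image bytes" policy from flocks/session/runner.py above; the dicts are simplified stand-ins for real message parts:

    turns = [
        {"id": "m1", "role": "user", "image": "a.png"},
        {"id": "m2", "role": "assistant", "image": None},
        {"id": "m3", "role": "user", "image": "b.png"},
    ]
    last_user_id = next(t["id"] for t in reversed(turns) if t["role"] == "user")
    for t in turns:
        if t["role"] == "user" and t["image"] and t["id"] != last_user_id:
            t["image"] = "[earlier image omitted]"  # no filename leaked
    assert turns[0]["image"] == "[earlier image omitted]"
    assert turns[2]["image"] == "b.png"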
- p = Path(raw).expanduser() - if p.exists() and p.is_file(): - workflow_source = p + existing_workflow = read_workflow_from_fs(raw) + if existing_workflow is not None: + workflow_source = existing_workflow["workflowJson"] + raw = existing_workflow["id"] else: - return ToolResult( - success=False, - error=( - "Unsupported workflow string. Provide a workflow JSON string, " - "or a valid workflow JSON file path." + # Otherwise treat it as a file path. + p = Path(raw).expanduser() + if p.exists() and p.is_file(): + workflow_source = p + else: + return ToolResult( + success=False, + error=( + "Unsupported workflow string. Provide a workflow ID, workflow JSON string, " + "or a valid workflow JSON file path." + ) ) - ) elif isinstance(workflow, dict): workflow_source = workflow else: @@ -375,6 +390,95 @@ async def run_workflow_tool( else: workflow_name = workflow_source.name workflow_id = str(workflow_source) + + workflow_inputs = inputs or {} + canonical_workflow_id = resolve_workflow_id_from_source(workflow_source) + display_workflow_id = canonical_workflow_id or workflow_id + tracked_execution: Optional[Dict[str, Any]] = None + tracked_history: list[Dict[str, Any]] = [] + tracked_exec_key: Optional[str] = None + loop = asyncio.get_running_loop() + + def _emit_metadata(metadata: Dict[str, Any]) -> None: + loop.call_soon_threadsafe(ctx.metadata, metadata) + + def _update_execution_progress(update_fields: Dict[str, Any]) -> None: + if not tracked_exec_key: + return + try: + current = asyncio.run_coroutine_threadsafe( + Storage.read(tracked_exec_key), + loop, + ).result(timeout=5) + current.update(update_fields) + asyncio.run_coroutine_threadsafe( + Storage.write(tracked_exec_key, current), + loop, + ).result(timeout=5) + except Exception as exc: + log.warning("run_workflow.execution_progress.write_failed", { + "workflow_id": display_workflow_id, + "exec_id": tracked_execution["id"] if tracked_execution else None, + "error": str(exc), + }) + + def _on_step_start( + run_id: Optional[str], + step_index: int, + node: Any, + _inputs: Dict[str, Any], + ) -> int: + current_node_id = getattr(node, "id", None) + current_node_type = getattr(node, "type", None) + _update_execution_progress({ + "currentNodeId": current_node_id, + "currentNodeType": current_node_type, + "currentPhase": "running", + "currentStepIndex": step_index, + }) + _emit_metadata({ + "title": f"Running workflow: {workflow_name}", + "metadata": { + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "run_id": run_id, + "status": "running", + "phase": "running", + "current_node_id": current_node_id, + "current_node_type": current_node_type, + "step_index": step_index, + }, + }) + return step_index + + def _on_step_complete(step_result: Any) -> None: + if hasattr(step_result, "model_dump"): + step_dict = step_result.model_dump(mode="json") + elif isinstance(step_result, dict): + step_dict = dict(step_result) + else: + step_dict = {"node_id": None, "outputs": {}, "error": str(step_result)} + tracked_history.append(step_dict) + _update_execution_progress({ + "executionLog": list(tracked_history), + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": len(tracked_history), + }) + _emit_metadata({ + "title": f"Running workflow: {workflow_name}", + "metadata": { + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if 
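Both callbacks above follow the runner's hook contract: on_step_start receives (run_id, step_index, node, inputs) and its return value is threaded through as a token to the end-of-step hook, while on_step_complete receives the finished step result. A minimal standalone sketch of a consumer (the progress dict is illustrative):

    from typing import Any, Dict

    progress: Dict[str, Any] = {"steps_done": 0}

    def on_step_start(run_id: Any, step_index: int, node: Any, inputs: Dict[str, Any]) -> int:
        progress["current_step"] = step_index
        return step_index  # token handed back alongside the step result

    def on_step_complete(step_result: Any) -> None:
        progress["steps_done"] += 1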
tracked_execution else None, + "status": "running", + "phase": "running", + "current_node_id": step_dict.get("node_id"), + "current_node_type": step_dict.get("node_type") or step_dict.get("type"), + "step_index": len(tracked_history), + "completed_steps": len(tracked_history), + }, + }) await ctx.ask( permission="run_workflow", @@ -388,18 +492,25 @@ async def run_workflow_tool( } ) + if canonical_workflow_id: + tracked_execution = await create_execution_record( + canonical_workflow_id, + input_params=workflow_inputs, + ) + tracked_exec_key = workflow_execution_key(tracked_execution["id"]) + # Update metadata to show workflow is running - ctx.metadata({ + _emit_metadata({ "title": f"Running workflow: {workflow_name}", "metadata": { - "workflow_id": workflow_id, - "status": "running" - } + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": "running", + "phase": "queued", + "step_index": 0, + }, }) - - # Prepare inputs - workflow_inputs = inputs or {} - + try: # Execute workflow log.info("run_workflow.execute.start", { @@ -407,6 +518,7 @@ async def run_workflow_tool( "workflow_name": workflow_name, "ensure_requirements": ensure_requirements, }) + execution_started_at = time.time() call_kwargs: Dict[str, Any] = { "workflow": workflow_source, @@ -422,18 +534,27 @@ async def run_workflow_tool( # Backward-compatibility: older runtimes may not accept `use_llm`. supports_use_llm = False + supports_step_start = False try: sig = inspect.signature(_run_workflow_fn) supports_use_llm = ( "use_llm" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) ) + supports_step_start = ( + "on_step_start" in sig.parameters + or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + ) except Exception: # Best-effort: assume supported. 
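The pattern used here generalises: probe the callee's signature and pass an optional kwarg only when the signature (or a **kwargs catch-all) can accept it. A standalone sketch, with supports_kwarg as an illustrative helper name:

    import inspect
    from typing import Any, Callable

    def supports_kwarg(fn: Callable[..., Any], name: str) -> bool:
        try:
            sig = inspect.signature(fn)
        except (TypeError, ValueError):
            return True  # best-effort: assume supported
        return name in sig.parameters or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
        )

    def run(workflow: str, *, use_llm: bool = False) -> None:
        ...

    assert supports_kwarg(run, "use_llm")
    assert not supports_kwarg(run, "on_step_start")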
supports_use_llm = True + supports_step_start = True if supports_use_llm: call_kwargs["use_llm"] = use_llm + if supports_step_start: + call_kwargs["on_step_start"] = _on_step_start + call_kwargs["on_step_complete"] = _on_step_complete try: result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) @@ -442,6 +563,9 @@ async def run_workflow_tool( if supports_use_llm and "use_llm" in str(te): call_kwargs.pop("use_llm", None) result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) + elif supports_step_start and "on_step_start" in str(te): + call_kwargs.pop("on_step_start", None) + result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) else: raise @@ -467,8 +591,48 @@ async def run_workflow_tool( }) # Append-only recording for audit/replay - await _record_workflow_tool_result(workflow_id, result_dict) - + await _record_workflow_tool_result(display_workflow_id, result_dict) + + status_value = normalize_execution_status(status) + if tracked_execution and canonical_workflow_id and tracked_exec_key: + current_data = await Storage.read(tracked_exec_key) + outcome_result = result + if not hasattr(outcome_result, "status"): + outcome_result = SimpleNamespace( + status=result_dict.get("status"), + outputs=result_dict.get("outputs", {}), + error=result_dict.get("error"), + ) + status_value, error_message = resolve_execution_outcome(outcome_result) # type: ignore[arg-type] + current_data.update({ + "outputResults": result_dict.get("outputs"), + "status": status_value, + "finishedAt": int(time.time() * 1000), + "duration": time.time() - execution_started_at, + "executionLog": result_dict.get("history") or list(tracked_history), + "errorMessage": error_message, + "currentNodeId": result_dict.get("last_node_id"), + "currentPhase": status_value, + "currentStepIndex": result_dict.get("steps", len(tracked_history)), + }) + await record_execution_result( + canonical_workflow_id, + tracked_execution["id"], + current_data, + ) + _emit_metadata({ + "title": f"Workflow: {workflow_name}", + "metadata": { + "workflow_id": canonical_workflow_id, + "workflow_execution_id": tracked_execution["id"], + "run_id": result_dict.get("run_id"), + "status": status_value, + "phase": status_value, + "current_node_id": result_dict.get("last_node_id"), + "step_index": result_dict.get("steps", len(tracked_history)), + }, + }) + # If workflow failed, include error in ToolResult if not success and error: return ToolResult( @@ -477,8 +641,9 @@ async def run_workflow_tool( output=output, # Also include formatted output for context title=f"Workflow: {workflow_name}", metadata={ - "workflow_id": workflow_id, - "status": status, + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": status_value, "steps": result_dict.get("steps", 0), "run_id": result_dict.get("run_id"), "last_node_id": result_dict.get("last_node_id"), @@ -492,8 +657,9 @@ async def run_workflow_tool( output=output, title=f"Workflow: {workflow_name}", metadata={ - "workflow_id": workflow_id, - "status": status, + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": status_value, "steps": result_dict.get("steps", 0), "run_id": result_dict.get("run_id"), "last_node_id": result_dict.get("last_node_id"), @@ -508,13 +674,39 @@ async def run_workflow_tool( "workflow_id": workflow_id, "error": error_msg, }) + if tracked_execution and canonical_workflow_id and tracked_exec_key: + current_data = await 
Storage.read(tracked_exec_key) + current_data.update({ + "status": "error", + "finishedAt": int(time.time() * 1000), + "errorMessage": error_msg, + "executionLog": list(tracked_history), + "currentPhase": "error", + "currentStepIndex": len(tracked_history), + }) + await record_execution_result( + canonical_workflow_id, + tracked_execution["id"], + current_data, + ) + _emit_metadata({ + "title": f"Workflow: {workflow_name}", + "metadata": { + "workflow_id": canonical_workflow_id, + "workflow_execution_id": tracked_execution["id"], + "status": "error", + "phase": "error", + "step_index": len(tracked_history), + }, + }) return ToolResult( success=False, error=f"Workflow execution failed: {error_msg}", title=f"Workflow: {workflow_name}", metadata={ - "workflow_id": workflow_id, + "workflow_id": display_workflow_id, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, "status": "FAILED", } ) diff --git a/flocks/tool/tool_loader.py b/flocks/tool/tool_loader.py index e88ba25d..8df57f71 100644 --- a/flocks/tool/tool_loader.py +++ b/flocks/tool/tool_loader.py @@ -504,29 +504,14 @@ def _build_execution_handler(cfg: dict, yaml_path: Path) -> ToolHandler: if not code or not code.strip(): raise ValueError(f"Empty execution code in {yaml_path}") - code_body = code.rstrip() - wrapper_lines = ["async def _tool_exec(**_kw_):", " import asyncio"] - for line in code_body.splitlines(): - wrapper_lines.append(f" {line}") - wrapper_source = "\n".join(wrapper_lines) - - compiled = compile(wrapper_source, str(yaml_path), "exec") - async def handler(ctx: ToolContext, **kwargs: Any) -> ToolResult: - ns: Dict[str, Any] = {} - exec(compiled, ns) - _tool_exec = ns["_tool_exec"] - try: - result = await _tool_exec(**kwargs) - if isinstance(result, ToolResult): - return result - if isinstance(result, dict): - success = result.pop("success", True) - error = result.pop("error", None) - return ToolResult(success=success, output=result, error=error) - return ToolResult(success=True, output=result) - except Exception as e: - return ToolResult(success=False, error=str(e)) + return ToolResult( + success=False, + error=( + "Inline YAML execution is disabled for safety. " + "Use handler.type=script for trusted Python tool handlers." 
+ ), + ) return handler diff --git a/flocks/tool/web/webfetch.py b/flocks/tool/web/webfetch.py index b7c18fb9..86c6fe07 100644 --- a/flocks/tool/web/webfetch.py +++ b/flocks/tool/web/webfetch.py @@ -11,7 +11,6 @@ import re from typing import Optional from html.parser import HTMLParser -from urllib.parse import urlparse from flocks.tool.registry import ( ToolRegistry, ToolCategory, ToolParameter, ParameterType, ToolResult, ToolContext @@ -297,20 +296,20 @@ async def webfetch_tool( except ImportError: # Fallback to urllib if aiohttp not available - import urllib.request import urllib.error - + import urllib.request + try: req = urllib.request.Request(url, headers={ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" }) - + with urllib.request.urlopen(req, timeout=timeout_sec) as response: content = response.read().decode('utf-8', errors='replace') content_type = response.headers.get("Content-Type", "") - + title = f"{url} ({content_type})" - + if format == "markdown": if "text/html" in content_type: output = html_to_markdown(content) @@ -323,14 +322,14 @@ async def webfetch_tool( output = content else: output = content - + return ToolResult( success=True, output=output, title=title, metadata={} ) - + except urllib.error.HTTPError as e: return ToolResult( success=False, diff --git a/flocks/updater/updater.py b/flocks/updater/updater.py index 2b72723f..4703dab2 100644 --- a/flocks/updater/updater.py +++ b/flocks/updater/updater.py @@ -51,6 +51,8 @@ _CURL_USER_AGENT = "curl/8.7.1" _FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS = 300 _FRONTEND_BUILD_TIMEOUT_SECONDS = 300 +_DEPENDENCY_SYNC_TIMEOUT_SECONDS = 180 +_WINDOWS_DEPENDENCY_SYNC_TIMEOUT_SECONDS = 300 _PRESERVE_NAMES: set[str] = { ".venv", @@ -364,6 +366,13 @@ def _build_frontend_subprocess_env(*, npm_registry: str | None = None) -> dict[s return env or None +def _dependency_sync_timeout_seconds() -> int: + """Return the timeout budget for ``uv sync`` during self-update.""" + if sys.platform == "win32": + return _WINDOWS_DEPENDENCY_SYNC_TIMEOUT_SECONDS + return _DEPENDENCY_SYNC_TIMEOUT_SECONDS + + # ------------------------------------------------------------------ # # Async subprocess helpers # ------------------------------------------------------------------ # @@ -2150,10 +2159,29 @@ async def _restore_after_apply_failure() -> None: uv_cmd.extend(["--default-index", profile.uv_default_index]) sync_env = _build_uv_sync_env() + sync_timeout = _dependency_sync_timeout_seconds() retried_after_managed_python_repair = False - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + + async def _run_uv_sync(cmd: list[str]) -> tuple[int, str, str]: + return await _run_async( + cmd, + cwd=install_root, + timeout=sync_timeout, + env=sync_env, + ) + + def _dependency_sync_timeout_message() -> str: + return f"Dependency sync timed out after {sync_timeout}s while running uv sync." 
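Each uv sync attempt below follows the same recover-or-abort shape; a condensed sketch of the pattern, where rollback and report stand in for _restore_after_apply_failure and the journal/progress plumbing:

    import subprocess
    from typing import Callable

    def sync_with_rollback(
        cmd: list[str],
        timeout_s: int,
        rollback: Callable[[], None],
        report: Callable[[str], None],
    ) -> bool:
        try:
            subprocess.run(cmd, timeout=timeout_s)
            return True
        except subprocess.TimeoutExpired:
            rollback()  # restore the previous install before surfacing the error
            report(f"Dependency sync timed out after {timeout_s}s while running uv sync.")
            return False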
+ + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if ( code != 0 and sys.platform == "win32" @@ -2173,9 +2201,15 @@ async def _restore_after_apply_failure() -> None: {"error": err}, ) await asyncio.sleep(2) - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0 and profile.uv_default_index: log.warning( "updater.dependencies.sync_retry_default_index", @@ -2186,15 +2220,27 @@ async def _restore_after_apply_failure() -> None: ) await asyncio.sleep(3) uv_cmd = [uv_path, "sync"] - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0: log.warning("updater.dependencies.sync_retry", {"first_error": err}) await asyncio.sleep(3) - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0: shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/flocks/workflow/execution_store.py b/flocks/workflow/execution_store.py new file mode 100644 index 00000000..04ffdc61 --- /dev/null +++ b/flocks/workflow/execution_store.py @@ -0,0 +1,110 @@ +"""Shared helpers for workflow execution history persistence.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any, Dict, Optional + +from flocks.session.recorder import Recorder +from flocks.storage.storage import Storage +from flocks.workflow.runner import RunWorkflowResult + + +def workflow_execution_key(exec_id: str) -> str: + """Return the storage key for one workflow execution.""" + return f"workflow_execution/{exec_id}" + + +def normalize_execution_status(status: str) -> str: + """Map runner status values to API status values.""" + normalized = (status or "").strip().upper() + if normalized == "SUCCEEDED": + return "success" + if normalized == "FAILED": + return "error" + if normalized == "TIMED_OUT": + return "timeout" + if normalized == "CANCELLED": + return "cancelled" + return (status or "error").strip().lower() or "error" + + +def _extract_business_failure_message(outputs: Dict[str, Any]) -> Optional[str]: + """Return a user-facing failure reason from workflow outputs.""" + for key 
in ("reason", "error_message", "errorMessage", "message"): + value = outputs.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def resolve_execution_outcome(result: RunWorkflowResult) -> tuple[str, Optional[str]]: + """Resolve API execution status from runner status and workflow outputs.""" + status_value = normalize_execution_status(result.status) + error_message = result.error + + if status_value != "success" or not isinstance(result.outputs, dict): + return status_value, error_message + + if result.outputs.get("workflow_success") is False: + return ( + "error", + error_message + or _extract_business_failure_message(result.outputs) + or "Workflow reported business failure.", + ) + + return status_value, error_message + + +def build_initial_execution_record( + workflow_id: str, + *, + input_params: Optional[Dict[str, Any]] = None, + exec_id: Optional[str] = None, +) -> Dict[str, Any]: + """Build the initial running execution payload.""" + return { + "id": exec_id or str(uuid.uuid4()), + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": int(time.time() * 1000), + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + } + + +async def create_execution_record( + workflow_id: str, + *, + input_params: Optional[Dict[str, Any]] = None, + exec_id: Optional[str] = None, +) -> Dict[str, Any]: + """Create and persist a running workflow execution record.""" + exec_data = build_initial_execution_record( + workflow_id, + input_params=input_params, + exec_id=exec_id, + ) + await Storage.write(workflow_execution_key(exec_data["id"]), exec_data) + return exec_data + + +async def record_execution_result( + workflow_id: str, + exec_id: str, + exec_data: Dict[str, Any], +) -> None: + """Persist the final execution record and audit trail.""" + await Storage.write(workflow_execution_key(exec_id), exec_data) + try: + await Recorder.record_workflow_execution( + exec_id=exec_id, + workflow_id=workflow_id, + run_result=exec_data, + ) + except Exception: + pass diff --git a/flocks/workflow/fs_store.py b/flocks/workflow/fs_store.py index 08186b6b..366119fa 100644 --- a/flocks/workflow/fs_store.py +++ b/flocks/workflow/fs_store.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import os from pathlib import Path from typing import Any, Dict, Optional @@ -61,26 +62,32 @@ def read_workflow_dir( try: workflow_json = json.loads(json_file.read_text(encoding="utf-8")) + json_mtime_ms = int(json_file.stat().st_mtime * 1000) meta_file = wf_dir / "meta.json" if meta_file.is_file(): meta = json.loads(meta_file.read_text(encoding="utf-8")) else: - mtime_ms = int(json_file.stat().st_mtime * 1000) meta = { "name": workflow_json.get("name", workflow_id), "description": workflow_json.get("description"), "category": workflow_json.get("category", "default"), "status": "active", "createdBy": None, - "createdAt": mtime_ms, - "updatedAt": mtime_ms, + "createdAt": json_mtime_ms, + "updatedAt": json_mtime_ms, } md_file = wf_dir / "workflow.md" markdown_content: Optional[str] = None + updated_candidates = [json_mtime_ms] if md_file.is_file(): markdown_content = md_file.read_text(encoding="utf-8") + updated_candidates.append(int(md_file.stat().st_mtime * 1000)) + if meta_file.is_file(): + updated_candidates.append(int(meta_file.stat().st_mtime * 1000)) + updated_candidates.append(int(meta.get("updatedAt") or 0)) + meta = {**meta, "updatedAt": max(updated_candidates)} return { **meta, @@ -105,3 +112,50 @@ 
def read_workflow_from_fs(workflow_id: str) -> Optional[Dict[str, Any]]: if data is not None: result = data return result + + +def resolve_workflow_id_from_source(workflow: Any) -> Optional[str]: + """Resolve a canonical workflow ID from a tool/runtime workflow argument. + + This is intentionally conservative: only return an ID when it maps cleanly to + a workflow already discoverable from the filesystem. + """ + if isinstance(workflow, dict): + candidate = workflow.get("id") + if isinstance(candidate, str) and candidate.strip(): + workflow_id = candidate.strip() + if read_workflow_from_fs(workflow_id) is not None: + return workflow_id + return None + + if isinstance(workflow, Path): + workflow_path = workflow.expanduser() + elif isinstance(workflow, str): + raw = workflow.strip() + if not raw: + return None + if read_workflow_from_fs(raw) is not None: + return raw + workflow_path = Path(raw).expanduser() + else: + return None + + if not workflow_path.is_file(): + return None + + try: + resolved = workflow_path.resolve() + except OSError: + return None + + for root, _source in workflow_scan_dirs(): + try: + relative = resolved.relative_to(root) + except ValueError: + continue + parts = relative.parts + if len(parts) == 2 and parts[1] == "workflow.json": + workflow_id = parts[0] + if read_workflow_from_fs(workflow_id) is not None: + return workflow_id + return None diff --git a/flocks/workflow/runner.py b/flocks/workflow/runner.py index 0770af3e..144227ce 100644 --- a/flocks/workflow/runner.py +++ b/flocks/workflow/runner.py @@ -250,6 +250,7 @@ def run_workflow( ensure_requirements: bool = True, requirements_installer: Optional[RequirementsInstaller] = None, sandbox_requirements_installer: Optional[SandboxRequirementsInstaller] = None, + on_step_start: Optional[Any] = None, on_step_complete: Optional[Any] = None, max_parallel_workers: int = 4, cancel: Optional[Callable[[], bool]] = None, @@ -398,8 +399,13 @@ def run_workflow( _on_step_start = None _on_step_end = None - if on_step_complete is not None: + if on_step_start is not None: + _on_step_start = lambda _rid, _step, _node, _inp: on_step_start( + _rid, _step, _node, _inp + ) + elif on_step_complete is not None: _on_step_start = lambda _rid, _step, _node, _inp: True + if on_step_complete is not None: _on_step_end = lambda _token, step_result: on_step_complete(step_result) try: diff --git a/pyproject.toml b/pyproject.toml index 26953367..95d63ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flocks" -version = "v2026.5.7" +version = "v2026.5.9" description = "AI-Native SecOps platform with multi-agent collaboration" authors = [ {name = "Flocks Team", email = "team@example.com"} diff --git a/scripts/dev.sh b/scripts/dev.sh index 0c15ee03..cc0a2e26 100644 --- a/scripts/dev.sh +++ b/scripts/dev.sh @@ -204,6 +204,8 @@ cleanup() { start_backend() { echo -e "${GREEN}🔧 启动后端服务: http://${BACKEND_HOST}:${BACKEND_PORT}${NC}" cd "${PROJECT_ROOT}" + _FLOCKS_WEBUI_HOST="${FRONTEND_HOST}" \ + _FLOCKS_WEBUI_PORT="${FRONTEND_PORT}" \ uv run uvicorn flocks.server.app:app \ --host "${BACKEND_HOST}" \ --port "${BACKEND_PORT}" \ @@ -222,6 +224,8 @@ start_frontend() { start_all() { echo -e "${BLUE}🚀 同时启动前后端开发环境...${NC}" cd "${PROJECT_ROOT}" + _FLOCKS_WEBUI_HOST="${FRONTEND_HOST}" \ + _FLOCKS_WEBUI_PORT="${FRONTEND_PORT}" \ uv run uvicorn flocks.server.app:app \ --host "${BACKEND_HOST}" \ --port "${BACKEND_PORT}" \ diff --git a/tests/browser/test_admin.py b/tests/browser/test_admin.py index b620bab5..6ac1a374 100644 --- 
a/tests/browser/test_admin.py +++ b/tests/browser/test_admin.py @@ -37,6 +37,24 @@ def test_handshake_403_needs_chrome_remote_debugging_prompt() -> None: assert admin._needs_chrome_remote_debugging_prompt(msg) +def test_load_env_uses_shared_loader_for_existing_files(tmp_path, monkeypatch) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + repo_env = tmp_path / ".env" + workspace_env = workspace / ".env" + repo_env.write_text("TOKEN=repo\n", encoding="utf-8") + workspace_env.write_text("TOKEN=workspace\n", encoding="utf-8") + loaded_paths = [] + + monkeypatch.setattr(admin, "PROJECT_ROOT", tmp_path) + monkeypatch.setenv("BH_AGENT_WORKSPACE", str(workspace)) + monkeypatch.setattr(admin, "load_env_file", lambda path: loaded_paths.append(path)) + + admin._load_env() + + assert loaded_paths == [repo_env, workspace_env] + + def test_stale_websocket_does_not_open_chrome_inspect() -> None: msg = "no close frame received or sent" assert not admin._needs_chrome_remote_debugging_prompt(msg) diff --git a/tests/browser/test_daemon.py b/tests/browser/test_daemon.py index 777ff3fd..f1f37e41 100644 --- a/tests/browser/test_daemon.py +++ b/tests/browser/test_daemon.py @@ -7,3 +7,31 @@ def test_is_real_page_filters_edge_internal_pages() -> None: def test_is_real_page_accepts_normal_https_pages() -> None: assert daemon.is_real_page({"type": "page", "url": "https://example.com"}) + + +def test_load_env_uses_shared_loader_for_existing_files(tmp_path, monkeypatch) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + repo_root = tmp_path / "repo" + repo_root.mkdir() + repo_env = repo_root / ".env" + workspace_env = workspace / ".env" + repo_env.write_text("TOKEN=repo\n", encoding="utf-8") + workspace_env.write_text("TOKEN=workspace\n", encoding="utf-8") + loaded_paths = [] + + class _FakeModulePath: + def resolve(self): + return self + + @property + def parents(self): + return [None, None, repo_root] + + monkeypatch.setattr(daemon, "AGENT_WORKSPACE", workspace) + monkeypatch.setattr(daemon, "Path", lambda _value: _FakeModulePath()) + monkeypatch.setattr(daemon, "load_env_file", lambda path: loaded_paths.append(path)) + + daemon._load_env() + + assert loaded_paths == [repo_env, workspace_env] diff --git a/tests/browser/test_helpers.py b/tests/browser/test_helpers.py index b28b1762..0503063d 100644 --- a/tests/browser/test_helpers.py +++ b/tests/browser/test_helpers.py @@ -31,6 +31,34 @@ def test_max_dim_default_is_no_resize(fake_png) -> None: assert _run(fake_png, 4592, 2286) == (4592, 2286) +def test_load_env_uses_shared_loader_for_existing_files(tmp_path, monkeypatch) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + repo_root = tmp_path / "repo" + repo_root.mkdir() + repo_env = repo_root / ".env" + workspace_env = workspace / ".env" + repo_env.write_text("TOKEN=repo\n", encoding="utf-8") + workspace_env.write_text("TOKEN=workspace\n", encoding="utf-8") + loaded_paths = [] + + class _FakeModulePath: + def resolve(self): + return self + + @property + def parents(self): + return [None, None, repo_root] + + monkeypatch.setattr(helpers, "AGENT_WORKSPACE", workspace) + monkeypatch.setattr(helpers, "Path", lambda _value: _FakeModulePath()) + monkeypatch.setattr(helpers, "load_env_file", lambda path: loaded_paths.append(path)) + + helpers._load_env() + + assert loaded_paths == [repo_env, workspace_env] + + def test_page_info_raises_clear_error_on_js_exception() -> None: def fake_send(req): return {} diff --git a/tests/browser/test_utils.py b/tests/browser/test_utils.py 
new file mode 100644 index 00000000..390020b6 --- /dev/null +++ b/tests/browser/test_utils.py @@ -0,0 +1,30 @@ +import os + +from flocks.browser import utils + + +def test_read_env_text_supports_utf8(tmp_path) -> None: + env_file = tmp_path / ".env" + env_file.write_text("TOKEN=中文\n", encoding="utf-8") + + assert utils.read_env_text(env_file) == "TOKEN=中文\n" + + +def test_load_env_file_supports_utf8_bom(tmp_path, monkeypatch) -> None: + env_file = tmp_path / ".env" + env_file.write_bytes('TOKEN="中文"\nNAME=test\n'.encode("utf-8-sig")) + monkeypatch.delenv("TOKEN", raising=False) + monkeypatch.delenv("NAME", raising=False) + + utils.load_env_file(env_file) + + assert os.environ["TOKEN"] == "中文" + assert os.environ["NAME"] == "test" + + +def test_read_env_text_falls_back_to_local_encoding(tmp_path, monkeypatch) -> None: + env_file = tmp_path / ".env" + env_file.write_bytes("TOKEN=中文\n".encode("gbk")) + monkeypatch.setattr(utils.locale, "getpreferredencoding", lambda _do_setlocale=False: "gbk") + + assert utils.read_env_text(env_file) == "TOKEN=中文\n" diff --git a/tests/cli/test_service_commands.py b/tests/cli/test_service_commands.py index 81e8b2fd..99a267c1 100644 --- a/tests/cli/test_service_commands.py +++ b/tests/cli/test_service_commands.py @@ -8,6 +8,7 @@ from typer.testing import CliRunner import flocks.cli.main as cli_main +import flocks.security as security_module runner = CliRunner() @@ -121,6 +122,14 @@ def fake_exists(path: Path) -> bool: popen_calls = {} run_calls = [] + saved_secrets = {} + + class FakeSecrets: + def get(self, secret_id: str): + return saved_secrets.get(secret_id) + + def set(self, secret_id: str, value: str) -> None: + saved_secrets[secret_id] = value class DummyProcess: pid = 4321 @@ -149,6 +158,7 @@ def fake_popen(args, **kwargs): monkeypatch.setattr(cli_main.Path, "cwd", staticmethod(lambda: tmp_path)) monkeypatch.setattr(subprocess, "run", fake_run) monkeypatch.setattr(subprocess, "Popen", fake_popen) + monkeypatch.setattr(security_module, "get_secret_manager", lambda: FakeSecrets()) monkeypatch.setattr(httpx, "get", lambda *_args, **_kwargs: SimpleNamespace(status_code=200)) monkeypatch.setattr(time, "sleep", lambda *_args, **_kwargs: None) monkeypatch.setattr(Path, "exists", fake_exists) @@ -158,3 +168,4 @@ def fake_popen(args, **kwargs): assert result.exit_code == 0 assert popen_calls["args"][:4] == [cli_main.sys.executable, "-m", "flocks.cli.main", "serve"] assert ["bun", "--version"] == run_calls[0][0] + assert saved_secrets["server_api_token"] diff --git a/tests/cli/test_update_command.py b/tests/cli/test_update_command.py index c3542448..9786405f 100644 --- a/tests/cli/test_update_command.py +++ b/tests/cli/test_update_command.py @@ -35,6 +35,78 @@ async def fake_update(*, check: bool, yes: bool, force: bool, region: str | None assert captured == {"check": False, "yes": True, "force": True, "region": "cn"} +def test_update_prompts_for_cn_mirror_before_upgrade_confirmation(monkeypatch) -> None: + output = StringIO() + monkeypatch.setattr( + update_cmd, + "console", + Console(file=output, force_terminal=False, color_system=None, width=120), + ) + + check_regions: list[str | None] = [] + confirm_prompts: list[str] = [] + captured: dict[str, object] = {} + answers = iter([True, True]) + + async def fake_check_update(*, locale: str | None = None, region: str | None = None) -> VersionInfo: + check_regions.append(region) + zipball_url = "https://example.com/flocks.zip" + tarball_url = "https://example.com/flocks.tar.gz" + if region == "cn": + zipball_url = 
"https://gitee.example.com/flocks.zip" + tarball_url = "https://gitee.example.com/flocks.tar.gz" + return VersionInfo( + current_version="2026.4.1", + latest_version="2026.4.2", + has_update=True, + zipball_url=zipball_url, + tarball_url=tarball_url, + deploy_mode="source", + update_allowed=True, + ) + + async def fake_perform_update( + latest_tag: str, + *, + zipball_url: str | None = None, + tarball_url: str | None = None, + restart: bool = True, + locale: str | None = None, + region: str | None = None, + ): + captured["latest_tag"] = latest_tag + captured["zipball_url"] = zipball_url + captured["tarball_url"] = tarball_url + captured["perform_region"] = region + captured["restart"] = restart + async for step in _fake_progress(): + yield step + + def fake_confirm(prompt: str, default: bool = False) -> bool: + confirm_prompts.append(prompt) + return next(answers) + + monkeypatch.setattr(updater_pkg, "check_update", fake_check_update) + monkeypatch.setattr(updater_pkg, "perform_update", fake_perform_update) + monkeypatch.setattr(updater_pkg, "detect_deploy_mode", lambda: "source") + monkeypatch.setattr(update_cmd.typer, "confirm", fake_confirm) + + import asyncio + + asyncio.run(update_cmd._update(check=False, yes=False, force=False, region=None)) + + assert check_regions == ["cn"] + assert confirm_prompts == ["\n是否使用中国镜像进行升级?", "\n是否立即升级?"] + assert captured == { + "latest_tag": "2026.4.2", + "zipball_url": "https://gitee.example.com/flocks.zip", + "tarball_url": "https://gitee.example.com/flocks.tar.gz", + "perform_region": "cn", + "restart": False, + } + assert "已切换为中国镜像源" not in output.getvalue() + + async def _fake_progress(): yield UpdateProgress(stage="fetching", message="fetching") yield UpdateProgress(stage="done", message="done", success=True) diff --git a/tests/config/test_api_versioning.py b/tests/config/test_api_versioning.py index b29aec43..faa1baa3 100644 --- a/tests/config/test_api_versioning.py +++ b/tests/config/test_api_versioning.py @@ -302,6 +302,27 @@ def test_copies_legacy_to_storage_key(self, isolated_env, api_root): assert services["tdp_api_v3_3_10"]["base_url"] == "https://tdp.example" assert services["tdp_api_v3_3_10"]["apiKey"] == "{secret:tdp_key}" + def test_migration_canonicalizes_legacy_ssl_verify_alias(self, isolated_env, api_root): + _, user_config = isolated_env + _write_provider_yaml(api_root / "tdp_v3_3_10", service_id="tdp_api", version="3.3.10") + config_path = _write_flocks_json(user_config, { + "api_services": { + "tdp_api": { + "enabled": True, + "base_url": "https://tdp.example", + "ssl_verify": True, + }, + }, + }) + + actions = migrate_api_services() + + assert actions == {"tdp_api_v3_3_10": "copied"} + services = json.loads(config_path.read_text(encoding="utf-8"))["api_services"] + assert services["tdp_api"]["ssl_verify"] is True + assert services["tdp_api_v3_3_10"]["verify_ssl"] is True + assert "ssl_verify" not in services["tdp_api_v3_3_10"] + def test_idempotent_when_storage_key_exists(self, isolated_env, api_root): _, user_config = isolated_env _write_provider_yaml(api_root / "tdp_v3_3_10", service_id="tdp_api", version="3.3.10") @@ -505,6 +526,41 @@ def test_set_api_service_no_warning_for_storage_key( captured = capsys.readouterr() assert "api_service.write.shadowed_legacy" not in captured.err + def test_set_api_service_promotes_ssl_verify_alias_to_verify_ssl(self, isolated_env, api_root): + _, user_config = isolated_env + from flocks.config.config_writer import ConfigWriter + + _write_provider_yaml(api_root / "tdp_v3_3_10", 
service_id="tdp_api", version="3.3.10") + _write_flocks_json(user_config, {"api_services": {}}) + versioning._reset_descriptor_cache() + + ConfigWriter.set_api_service("tdp_api_v3_3_10", { + "base_url": "https://tdp.example", + "ssl_verify": False, + }) + + services = json.loads((user_config / "flocks.json").read_text(encoding="utf-8"))["api_services"] + assert services["tdp_api_v3_3_10"]["verify_ssl"] is False + assert "ssl_verify" not in services["tdp_api_v3_3_10"] + + def test_set_api_service_drops_alias_when_verify_ssl_already_present(self, isolated_env, api_root): + _, user_config = isolated_env + from flocks.config.config_writer import ConfigWriter + + _write_provider_yaml(api_root / "tdp_v3_3_10", service_id="tdp_api", version="3.3.10") + _write_flocks_json(user_config, {"api_services": {}}) + versioning._reset_descriptor_cache() + + ConfigWriter.set_api_service("tdp_api_v3_3_10", { + "base_url": "https://tdp.example", + "ssl_verify": True, + "verify_ssl": False, + }) + + services = json.loads((user_config / "flocks.json").read_text(encoding="utf-8"))["api_services"] + assert services["tdp_api_v3_3_10"]["verify_ssl"] is False + assert "ssl_verify" not in services["tdp_api_v3_3_10"] + def test_get_api_service_raw_handles_null_api_services(self, isolated_env, api_root): """``"api_services": null`` in flocks.json must not crash the reader.""" _, user_config = isolated_env diff --git a/tests/mcp/test_mcp_catalog.py b/tests/mcp/test_mcp_catalog.py index 7e39d83d..b8a857b8 100644 --- a/tests/mcp/test_mcp_catalog.py +++ b/tests/mcp/test_mcp_catalog.py @@ -13,6 +13,7 @@ EnvVarSpec, InstallSpec, McpCatalog, + RemoteConfigSpec, ) from flocks.mcp.installer import managed_python_bin_dir, managed_python_executable @@ -126,10 +127,26 @@ def test_to_mcp_config_with_overrides(self): assert config["environment"]["API_KEY"] == "my-secret-key" def test_to_mcp_config_remote(self): - entry = self._make_entry(transport="remote") + entry = self._make_entry( + transport="remote", + remote=RemoteConfigSpec( + url="https://example.com/mcp?apikey={secret:api_key}", + transport="sse", + auth={ + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer {secret:api_key}", + }, + oauth=False, + ), + ) config = entry.to_mcp_config({"url": "https://example.com/mcp"}) assert config["type"] == "remote" assert config["url"] == "https://example.com/mcp" + assert config["transport"] == "sse" + assert config["auth"]["param_name"] == "Authorization" + assert config["oauth"] is False def test_to_mcp_config_local_console_script_uses_managed_venv(self): entry = self._make_entry( diff --git a/tests/mcp/test_mcp_client.py b/tests/mcp/test_mcp_client.py new file mode 100644 index 00000000..5eff5df2 --- /dev/null +++ b/tests/mcp/test_mcp_client.py @@ -0,0 +1,82 @@ +import pytest + +from flocks.mcp.client import McpClient + + +class TestMcpClientTransportSelection: + @pytest.mark.asyncio + async def test_connect_uses_sse_only_when_transport_is_sse(self, monkeypatch: pytest.MonkeyPatch): + calls: list[str] = [] + + async def fake_http(*args, **kwargs): + calls.append("http") + + async def fake_sse(*args, **kwargs): + calls.append("sse") + + client = McpClient( + name="demo", + server_type="remote", + url="https://example.com/sse", + transport="sse", + ) + monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) + monkeypatch.setattr(client, "_do_connect_sse", fake_sse) + + await client.connect() + + assert calls == ["sse"] + assert client._transport_type == "sse" + + 
@pytest.mark.asyncio + async def test_connect_uses_http_only_when_transport_is_http(self, monkeypatch: pytest.MonkeyPatch): + calls: list[str] = [] + + async def fake_http(*args, **kwargs): + calls.append("http") + + async def fake_sse(*args, **kwargs): + calls.append("sse") + + client = McpClient( + name="demo", + server_type="remote", + url="https://example.com/mcp", + transport="http", + ) + monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) + monkeypatch.setattr(client, "_do_connect_sse", fake_sse) + + await client.connect() + + assert calls == ["http"] + assert client._transport_type == "streamable_http" + + @pytest.mark.asyncio + async def test_connect_auto_falls_back_to_sse_after_http_failure(self, monkeypatch: pytest.MonkeyPatch): + calls: list[str] = [] + + async def fake_http(*args, **kwargs): + calls.append("http") + raise RuntimeError("HTTP 405") + + async def fake_sse(*args, **kwargs): + calls.append("sse") + + async def fake_cleanup(): + return None + + client = McpClient( + name="demo", + server_type="remote", + url="https://example.com/mcp", + transport="auto", + ) + monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) + monkeypatch.setattr(client, "_do_connect_sse", fake_sse) + monkeypatch.setattr(client, "_cleanup_connection", fake_cleanup) + + await client.connect() + + assert calls == ["http", "sse"] + assert client._transport_type == "sse" diff --git a/tests/mcp/test_mcp_utils.py b/tests/mcp/test_mcp_utils.py index 1e0b35df..bfa1aec6 100644 --- a/tests/mcp/test_mcp_utils.py +++ b/tests/mcp/test_mcp_utils.py @@ -9,13 +9,18 @@ from pathlib import Path from unittest import mock from flocks.mcp.utils import ( + MCP_MASKED_SECRET_VALUE, build_mcp_headers, build_mcp_url, config_has_pending_credentials, extract_api_key_from_mcp_url, + extract_auth_value_from_mcp_config, + extract_sensitive_headers_from_mcp_config, get_connect_block_reason, + mask_sensitive_mcp_config_for_frontend, normalize_mcp_config, resolve_env_var, + restore_masked_mcp_config_secrets, sanitize_name, generate_tool_name, calculate_schema_hash, @@ -94,6 +99,28 @@ def test_with_header_auth(self): "Authorization": "Bearer token123", } + def test_with_bearer_scheme_prefixes_secret_value(self): + """Bearer auth should prepend the scheme after resolving the secret.""" + with tempfile.TemporaryDirectory() as tmpdir: + secret_file = Path(tmpdir) / ".secret.json" + secret_file.write_text(json.dumps({"demo_mcp_key": "token123"})) + secret_file.chmod(0o600) + + from flocks.security.secrets import SecretManager + sm = SecretManager(secret_file=secret_file) + with mock.patch("flocks.security.secrets.get_secret_manager", return_value=sm): + headers = build_mcp_headers( + None, + { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "scheme": "bearer", + "value": "{secret:demo_mcp_key}", + }, + ) + assert headers == {"Authorization": "Bearer token123"} + def test_header_secret_is_resolved(self): """Header values should resolve {secret:KEY} placeholders""" with tempfile.TemporaryDirectory() as tmpdir: @@ -124,6 +151,39 @@ def test_normalizes_stdio_alias_and_combines_args(self): "command": ["uvx", "mcp-server", "--port", "8080"], } + def test_normalizes_sse_alias_to_remote_transport(self): + config = normalize_mcp_config( + { + "type": "sse", + "url": "https://example.com/sse", + } + ) + assert config == { + "type": "remote", + "url": "https://example.com/sse", + "transport": "sse", + } + + def test_normalizes_streamable_http_transport_alias(self): + config = 
normalize_mcp_config( + { + "type": "remote", + "url": "https://example.com/mcp", + "transport": "streamable_http", + } + ) + assert config["transport"] == "http" + + def test_falls_back_to_auto_for_unknown_remote_transport(self): + config = normalize_mcp_config( + { + "type": "remote", + "url": "https://example.com/mcp", + "transport": "weird-protocol", + } + ) + assert config["transport"] == "auto" + class TestCredentialStateHelpers: """Test credential gating helpers""" @@ -183,6 +243,191 @@ def test_keeps_existing_secret_reference_unchanged(self): assert updated["url"] == "https://example.com/mcp?apikey={secret:demo-mcp_mcp_key}" +class TestExtractAuthValueFromMcpConfig: + """Test secret extraction from MCP auth config.""" + + def test_extracts_plain_auth_value_to_secret_reference(self): + saved_secrets: dict[str, str] = {} + + class SecretManagerStub: + def set(self, key: str, value: str) -> None: + saved_secrets[key] = value + + with mock.patch("flocks.security.get_secret_manager", return_value=SecretManagerStub()): + updated = extract_auth_value_from_mcp_config( + "demo-mcp", + { + "type": "remote", + "url": "https://example.com/sse", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "scheme": "bearer", + "value": "Bearer token123", + }, + }, + ) + + assert saved_secrets == {"demo-mcp_mcp_key": "token123"} + assert updated["auth"]["value"] == "{secret:demo-mcp_mcp_key}" + assert updated["auth"]["scheme"] == "bearer" + + def test_infers_bearer_scheme_from_authorization_header(self): + saved_secrets: dict[str, str] = {} + + class SecretManagerStub: + def set(self, key: str, value: str) -> None: + saved_secrets[key] = value + + with mock.patch("flocks.security.get_secret_manager", return_value=SecretManagerStub()): + updated = extract_auth_value_from_mcp_config( + "demo-mcp", + { + "type": "remote", + "url": "https://example.com/sse", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer token123", + }, + }, + ) + + assert saved_secrets == {"demo-mcp_mcp_key": "token123"} + assert updated["auth"]["value"] == "{secret:demo-mcp_mcp_key}" + assert updated["auth"]["scheme"] == "bearer" + + def test_keeps_existing_secret_reference_unchanged(self): + with mock.patch("flocks.security.get_secret_manager") as get_secret_manager: + updated = extract_auth_value_from_mcp_config( + "demo-mcp", + { + "type": "remote", + "url": "https://example.com/sse", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "{secret:demo-mcp_mcp_key}", + }, + }, + ) + + get_secret_manager.assert_not_called() + assert updated["auth"]["value"] == "{secret:demo-mcp_mcp_key}" + + +class TestExtractSensitiveHeadersFromMcpConfig: + """Test secret extraction from MCP headers.""" + + def test_extracts_sensitive_header_value_to_secret_reference(self): + saved_secrets: dict[str, str] = {} + + class SecretManagerStub: + def set(self, key: str, value: str) -> None: + saved_secrets[key] = value + + with mock.patch("flocks.security.get_secret_manager", return_value=SecretManagerStub()): + updated = extract_sensitive_headers_from_mcp_config( + "demo-mcp", + { + "type": "remote", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "Bearer token123", + "X-Client": "flocks", + }, + }, + ) + + assert saved_secrets == {"demo-mcp_authorization_header": "Bearer token123"} + assert updated["headers"]["Authorization"] == "{secret:demo-mcp_authorization_header}" + assert 
updated["headers"]["X-Client"] == "flocks" + + def test_keeps_existing_sensitive_header_secret_reference_unchanged(self): + with mock.patch("flocks.security.get_secret_manager") as get_secret_manager: + updated = extract_sensitive_headers_from_mcp_config( + "demo-mcp", + { + "type": "remote", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "{secret:demo-mcp_authorization_header}", + }, + }, + ) + + get_secret_manager.assert_not_called() + assert ( + updated["headers"]["Authorization"] + == "{secret:demo-mcp_authorization_header}" + ) + + +class TestMaskSensitiveMcpConfigForFrontend: + """Test frontend masking helpers for legacy plain-text configs.""" + + def test_masks_plain_text_auth_and_headers(self): + masked = mask_sensitive_mcp_config_for_frontend( + { + "type": "remote", + "url": "https://example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer token123", + }, + "headers": { + "Authorization": "Bearer token123", + "X-Client": "flocks", + }, + } + ) + + assert masked["auth"]["value"] == MCP_MASKED_SECRET_VALUE + assert masked["headers"]["Authorization"] == MCP_MASKED_SECRET_VALUE + assert masked["headers"]["X-Client"] == "flocks" + + def test_restores_masked_values_from_existing_config(self): + restored = restore_masked_mcp_config_secrets( + { + "type": "remote", + "url": "https://example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer token123", + }, + "headers": { + "Authorization": "Bearer token123", + "X-Client": "flocks", + }, + }, + { + "type": "remote", + "url": "https://new.example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": MCP_MASKED_SECRET_VALUE, + }, + "headers": { + "Authorization": MCP_MASKED_SECRET_VALUE, + "X-Client": "flocks-web", + }, + }, + ) + + assert restored["auth"]["value"] == "Bearer token123" + assert restored["headers"]["Authorization"] == "Bearer token123" + assert restored["headers"]["X-Client"] == "flocks-web" + + class TestResolveEnvVar: """Test environment variable resolution""" diff --git a/tests/provider/test_azure_provider.py b/tests/provider/test_azure_provider.py new file mode 100644 index 00000000..d23d426d --- /dev/null +++ b/tests/provider/test_azure_provider.py @@ -0,0 +1,33 @@ +from flocks.provider.provider import ModelCapabilities, ModelInfo +from flocks.provider.sdk.azure import AzureProvider + + +def test_azure_provider_returns_configured_deployment_models(): + provider = AzureProvider() + provider._config_models = [ + ModelInfo( + id="customer-prod-deployment", + name="Customer Production Deployment", + provider_id="azure", + capabilities=ModelCapabilities( + supports_tools=True, + supports_streaming=True, + context_window=128000, + max_tokens=4096, + ), + ) + ] + + models = provider.get_models() + + assert [m.id for m in models] == ["customer-prod-deployment"] + assert models[0].name == "Customer Production Deployment" + + +def test_azure_provider_returns_fallback_models_without_config(): + provider = AzureProvider() + + models = provider.get_models() + + assert {m.id for m in models} == {"gpt-5.4", "gpt-5-mini"} + assert all(m.provider_id == "azure" for m in models) diff --git a/tests/provider/test_openai_base_provider.py b/tests/provider/test_openai_base_provider.py index 2394c68c..b8d9bcab 100644 --- a/tests/provider/test_openai_base_provider.py +++ b/tests/provider/test_openai_base_provider.py @@ -352,7 +352,21 
@@ def test_get_client_respects_verify_ssl_false(self, mock_async_openai, mock_http provider._get_client() - mock_http_client.assert_called_once_with(verify=False, timeout=120.0) + # Granular timeout supports multimodal payloads; ``trust_env`` defaults + # to True so corporate egress proxies work out of the box. + # See ``OpenAIBaseProvider._get_client``. + assert mock_http_client.call_count == 1 + kwargs = mock_http_client.call_args.kwargs + assert kwargs["verify"] is False + assert kwargs["trust_env"] is True + timeout_arg = kwargs["timeout"] + # Either an httpx.Timeout instance or compatible object: assert the + # connect/read/write components rather than equality so future tweaks + # to non-essential pool/write durations don't break the test. + assert getattr(timeout_arg, "connect", None) == 30.0 + assert getattr(timeout_arg, "read", None) == 600.0 + assert getattr(timeout_arg, "write", None) == 600.0 + mock_async_openai.assert_called_once_with( api_key="test-api-key", base_url="https://gateway.internal/v1", diff --git a/tests/provider/test_openai_compatible_provider.py b/tests/provider/test_openai_compatible_provider.py index 2bb03a9a..bab862dd 100644 --- a/tests/provider/test_openai_compatible_provider.py +++ b/tests/provider/test_openai_compatible_provider.py @@ -48,7 +48,17 @@ def test_get_client_respects_verify_ssl_false(self, mock_async_openai, mock_http provider._get_client() - mock_http_client.assert_called_once_with(verify=False, timeout=120.0) + # Granular timeout supports multimodal payloads; assert semantically + # so minor adjustments to non-critical values don't break the test. + assert mock_http_client.call_count == 1 + kwargs = mock_http_client.call_args.kwargs + assert kwargs["verify"] is False + assert kwargs["trust_env"] is True + timeout_arg = kwargs["timeout"] + assert getattr(timeout_arg, "connect", None) == 30.0 + assert getattr(timeout_arg, "read", None) == 600.0 + assert getattr(timeout_arg, "write", None) == 600.0 + mock_async_openai.assert_called_once_with( api_key="test-api-key", base_url="https://gateway.internal/v1", diff --git a/tests/provider/test_openai_provider.py b/tests/provider/test_openai_provider.py index b54ef447..cc78a697 100644 --- a/tests/provider/test_openai_provider.py +++ b/tests/provider/test_openai_provider.py @@ -24,11 +24,17 @@ def test_get_client_respects_verify_ssl_false(self, mock_async_openai, mock_http provider._get_client() - mock_http_client.assert_called_once_with( - trust_env=True, - verify=False, - timeout=120.0, - ) + # Granular timeout supports multimodal payloads; verify fields + # semantically so minor adjustments to non-critical values don't break. 
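# The assertions in these three provider tests pin _get_client to roughly the
# shape below. A hedged sketch only: attribute names such as self._api_key and
# the custom_settings lookup are assumptions inferred from the tests, not the
# actual OpenAIBaseProvider source.
import httpx
from openai import AsyncOpenAI


def _get_client(self) -> AsyncOpenAI:
    # Granular per-phase timeouts instead of a single 120s budget: connect
    # fails fast, while read/write leave room for large multimodal payloads.
    timeout = httpx.Timeout(
        connect=30.0,
        read=600.0,
        write=600.0,
        pool=30.0,  # not pinned by the tests; only connect/read/write are asserted
    )
    http_client = httpx.AsyncClient(
        verify=self._config.custom_settings.get("verify_ssl", True),
        trust_env=True,  # honour HTTP(S)_PROXY so corporate egress works out of the box
        timeout=timeout,
    )
    return AsyncOpenAI(
        api_key=self._api_key,
        base_url=self._base_url,
        http_client=http_client,
    )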
+ assert mock_http_client.call_count == 1 + kwargs = mock_http_client.call_args.kwargs + assert kwargs["trust_env"] is True + assert kwargs["verify"] is False + timeout_arg = kwargs["timeout"] + assert getattr(timeout_arg, "connect", None) == 30.0 + assert getattr(timeout_arg, "read", None) == 600.0 + assert getattr(timeout_arg, "write", None) == 600.0 + mock_async_openai.assert_called_once_with( api_key="test-api-key", base_url="https://gateway.internal/v1", diff --git a/tests/provider/test_test_credentials.py b/tests/provider/test_test_credentials.py index e01d5db7..783692de 100644 --- a/tests/provider/test_test_credentials.py +++ b/tests/provider/test_test_credentials.py @@ -228,10 +228,28 @@ async def test_service_prefers_lightweight_query_tool_over_file_upload(self): ) @pytest.mark.asyncio - async def test_service_uses_enum_action_instead_of_placeholder_string(self): - """Connectivity checks should use enum-backed action values, not the generic 'test' placeholder.""" + async def test_onesec_service_prefers_threat_probe_and_uses_enum_action(self): + """OneSEC should prefer the read-only threat probe over the older DNS probe.""" from flocks.server.routes.provider import test_provider_credentials + onesec_threat_tool = ToolInfo( + name="onesec_threat", + description="OneSEC threat grouped tool", + category=ToolCategory.CUSTOM, + parameters=[ + ToolParameter( + name="action", + type=ParameterType.STRING, + description="Threat action", + required=True, + enum=[ + "threat_query_bd_version", + "threat_virus_scan", + "threat_update_bd_version", + ], + ) + ], + ) onesec_dns_tool = ToolInfo( name="onesec_dns", description="OneSEC DNS grouped tool", @@ -265,9 +283,9 @@ async def test_service_uses_enum_action_instead_of_placeholder_string(self): mock_provider_cls.get.return_value = None mock_tr.init = MagicMock() - mock_tr.list_tools.return_value = [onesec_dns_tool] + mock_tr.list_tools.return_value = [onesec_dns_tool, onesec_threat_tool] mock_tr._dynamic_tools_by_module = { - "flocks.tool.generated.onesec": ["onesec_dns"], + "flocks.tool.generated.onesec": ["onesec_dns", "onesec_threat"], } mock_tr.execute = AsyncMock(return_value=ToolResult( success=True, @@ -277,10 +295,10 @@ async def test_service_uses_enum_action_instead_of_placeholder_string(self): result = await test_provider_credentials("onesec_api") assert result["success"] is True, result - assert result["tool_tested"] == "onesec_dns" + assert result["tool_tested"] == "onesec_threat" mock_tr.execute.assert_awaited_once_with( - tool_name="onesec_dns", - action="dns_get_public_ip_list", + tool_name="onesec_threat", + action="threat_query_bd_version", ) @pytest.mark.asyncio @@ -344,6 +362,54 @@ async def test_service_prefers_login_probe_over_action_dispatch_tool(self): assert result["tool_tested"] == "qingteng_login" mock_tr.execute.assert_awaited_once_with(tool_name="qingteng_login") + @pytest.mark.asyncio + async def test_declared_manifest_probe_is_used_before_heuristic_tool_selection(self): + """A declared connectivity probe should bypass heuristic tool sorting.""" + from flocks.server.routes.provider import test_provider_credentials + from flocks.tool.probe_loader import ConnectivitySpec + + heuristic_tool = _make_tool_info("tdp_assets_domain_list") + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "valid-creds" + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_PROVIDER) as mock_provider_cls, + patch(_PATCH_TOOL_REGISTRY) as mock_tr, + patch(_PATCH_TOOL_SOURCE, return_value=("api", 
"tdp_api_v3_3_10")), + patch( + "flocks.tool.probe_loader.get_connectivity_spec", + return_value=ConnectivitySpec( + tool="tdp_system_status", + params={"action": "service"}, + ), + ), + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = None + + mock_tr.init = MagicMock() + mock_tr.list_tools.return_value = [heuristic_tool] + mock_tr._dynamic_tools_by_module = { + "flocks.tool.generated.tdp_api": ["tdp_assets_domain_list"], + } + mock_tr.execute = AsyncMock(return_value=ToolResult( + success=True, + output={"status": "ok"}, + )) + + result = await test_provider_credentials("tdp_api_v3_3_10") + + assert result["success"] is True, result + assert result["tool_tested"] == "tdp_system_status" + assert result["probe_source"] == "manifest" + mock_tr.execute.assert_awaited_once_with( + tool_name="tdp_system_status", + action="service", + ) + @pytest.mark.asyncio async def test_login_probe_does_not_overmatch_business_tools(self): """`_is_login_probe` must only match dedicated probes — not arbitrary @@ -754,3 +820,83 @@ async def test_existing_custom_settings_are_preserved_during_provider_test(self) assert configured.api_key == "gateway-api-key" assert configured.base_url == "https://gateway.internal/v1" assert configured.custom_settings["verify_ssl"] is False + + @pytest.mark.asyncio + async def test_requested_azure_deployment_model_is_used_for_provider_test(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + model = MagicMock() + model.id = "customer-prod-deployment" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="customer-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "customer-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "customer-prod-deployment" + + @pytest.mark.asyncio + async def test_unsaved_azure_deployment_can_be_tested_without_model_definition(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + catalog_model = MagicMock() + catalog_model.id = "gpt-5.4" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + 
mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [catalog_model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="unsaved-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "unsaved-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "unsaved-prod-deployment" diff --git a/tests/pty/test_pty_security.py b/tests/pty/test_pty_security.py new file mode 100644 index 00000000..c2eb5996 --- /dev/null +++ b/tests/pty/test_pty_security.py @@ -0,0 +1,52 @@ +import pytest + +from flocks.pty.pty import Pty + + +def test_pty_rejects_shell_command_execution_flag(): + with pytest.raises(ValueError, match="arguments"): + Pty._validate_interactive_shell("/bin/sh", ["-c", "id"]) + + +def test_pty_rejects_non_shell_command(): + with pytest.raises(ValueError, match="approved interactive shell"): + Pty._validate_interactive_shell("/usr/bin/python3", []) + + +def test_pty_allows_interactive_shell_flags(): + Pty._validate_interactive_shell("/bin/zsh", ["-l"]) + + +@pytest.mark.parametrize( + "shell", + [ + "ash", + "dash", + "ksh", + "ksh93", + "mksh", + "csh", + "tcsh", + ], +) +def test_pty_allows_common_interactive_shells(shell: str): + Pty._validate_interactive_shell(f"/bin/{shell}", []) + + +def test_pty_rejects_shell_startup_environment_injection(): + with pytest.raises(ValueError, match="not allowed"): + Pty._prepare_environment({"BASH_ENV": "/tmp/payload.sh"}) + + +def test_pty_filters_inherited_shell_startup_environment(monkeypatch): + monkeypatch.setenv("BASH_ENV", "/tmp/payload.sh") + monkeypatch.setenv("DYLD_INSERT_LIBRARIES", "/tmp/libevil.dylib") + monkeypatch.setenv("SAFE_VAR", "ok") + + env = Pty._prepare_environment({"CUSTOM_VAR": "custom"}) + + assert "BASH_ENV" not in env + assert "DYLD_INSERT_LIBRARIES" not in env + assert env["SAFE_VAR"] == "ok" + assert env["CUSTOM_VAR"] == "custom" + assert env["TERM"] == "xterm-256color" diff --git a/tests/server/routes/conftest.py b/tests/server/routes/conftest.py index 6c72978f..8071b9b8 100644 --- a/tests/server/routes/conftest.py +++ b/tests/server/routes/conftest.py @@ -18,10 +18,34 @@ async def client() -> AsyncGenerator[AsyncClient, None]: """Async HTTP test client for the FastAPI app.""" from flocks.server.app import app transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as ac: + headers = {"Authorization": "Bearer abc123", "User-Agent": "curl/8.0"} + async with AsyncClient( + transport=transport, + base_url="http://test", + headers=headers, + ) as ac: yield ac +@pytest.fixture(autouse=True) +def _route_test_api_token(monkeypatch: pytest.MonkeyPatch) -> None: + """Provide a valid API token for route tests.""" + from flocks.server import auth as auth_module + + class SecretManagerStub: + def __init__(self, values: dict[str, str]): + self._values = values + + def get(self, key: str): + return self._values.get(key) + + monkeypatch.setattr( + auth_module, + "get_secret_manager", + lambda: SecretManagerStub({auth_module.API_TOKEN_SECRET_ID: "abc123"}), + ) + + @pytest.fixture def mock_workspace(tmp_path: Path) -> Path: """Patch WorkspaceManager root to a temp directory.""" diff --git a/tests/server/routes/test_custom_provider_runtime.py 
b/tests/server/routes/test_custom_provider_runtime.py index 0e8baaee..4899ce6e 100644 --- a/tests/server/routes/test_custom_provider_runtime.py +++ b/tests/server/routes/test_custom_provider_runtime.py @@ -1,4 +1,5 @@ from flocks.provider.provider import ModelCapabilities, ModelInfo, Provider +from flocks.provider.sdk.azure import AzureProvider from flocks.server.routes.custom_provider import CreateModelReq, _add_model_to_runtime @@ -48,3 +49,35 @@ class DummyProvider: assert provider._config_models[0].capabilities.supports_reasoning is True finally: Provider._models = original_models + + +def test_add_azure_deployment_to_runtime_config_models(monkeypatch): + provider = AzureProvider() + provider.id = "azure-openai" + provider._config_models = [] + body = CreateModelReq( + model_id="customer-prod-deployment", + name="Customer Production Deployment", + context_window=128000, + max_output_tokens=4096, + supports_vision=False, + supports_tools=True, + supports_streaming=True, + supports_reasoning=False, + input_price=0.0, + output_price=0.0, + currency="USD", + ) + + original_models = Provider._models + Provider._models = {} + monkeypatch.setattr(Provider, "get", classmethod(lambda cls, provider_id: provider)) + + try: + _add_model_to_runtime("azure-openai", body) + + assert Provider._models[body.model_id].provider_id == "azure-openai" + assert provider._config_models[0].id == "customer-prod-deployment" + assert provider._config_models[0].name == "Customer Production Deployment" + finally: + Provider._models = original_models diff --git a/tests/server/routes/test_find_routes.py b/tests/server/routes/test_find_routes.py new file mode 100644 index 00000000..ae40dd0a --- /dev/null +++ b/tests/server/routes/test_find_routes.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +from fastapi import HTTPException + +from flocks.server.routes import find as find_routes + + +class _RunResult: + def __init__(self, stdout: str = "") -> None: + self.stdout = stdout + self.returncode = 0 + + +def _make_project(tmp_path: Path) -> Path: + project = tmp_path / "project" + project.mkdir() + (project / ".flocks").mkdir() + (project / "README.md").write_text("hello\n", encoding="utf-8") + return project + + +@pytest.mark.asyncio +async def test_find_text_rejects_directory_outside_project(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + project = _make_project(tmp_path) + outside = tmp_path / "outside" + outside.mkdir() + monkeypatch.chdir(project) + + with pytest.raises(HTTPException) as exc_info: + await find_routes.find_text(pattern="secret", directory=str(outside)) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_find_text_passes_leading_dash_pattern_after_separator( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + project = _make_project(tmp_path) + monkeypatch.chdir(project) + commands: list[list[str]] = [] + + def _fake_run(cmd, **kwargs): + commands.append(cmd) + assert kwargs["cwd"] == str(project.resolve()) + match = { + "type": "match", + "data": { + "path": {"text": "README.md"}, + "line_number": 1, + "lines": {"text": "hello\n"}, + }, + } + return _RunResult(stdout=json.dumps(match)) + + monkeypatch.setattr(find_routes.subprocess, "run", _fake_run) + + results = await find_routes.find_text(pattern="--pre=/tmp/evil.sh", directory=str(project)) + + assert commands[0][-2:] == ["--", "--pre=/tmp/evil.sh"] + assert results[0].file == "README.md" + + +@pytest.mark.asyncio +async def 
test_find_files_passes_leading_dash_query_after_separator( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + project = _make_project(tmp_path) + monkeypatch.chdir(project) + commands: list[list[str]] = [] + + def _fake_run(cmd, **kwargs): + commands.append(cmd) + assert kwargs["cwd"] == str(project.resolve()) + return _RunResult(stdout="README.md\n") + + monkeypatch.setattr(find_routes.subprocess, "run", _fake_run) + + results = await find_routes.find_files(query="--exec", directory=str(project)) + + assert commands[0][-2:] == ["--", "--exec"] + assert results == ["README.md"] diff --git a/tests/server/routes/test_mcp_routes.py b/tests/server/routes/test_mcp_routes.py index 70224872..8d0b25b3 100644 --- a/tests/server/routes/test_mcp_routes.py +++ b/tests/server/routes/test_mcp_routes.py @@ -261,6 +261,49 @@ async def fake_config_get(cls): assert data["config"]["type"] == "sse" assert data["config"]["url"] == "https://example.com/mcp" + @pytest.mark.asyncio + async def test_get_mcp_server_info_masks_plaintext_sensitive_values( + self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch + ): + async def fake_get_server_info(name: str): + return None + + async def fake_config_get(cls): + return type("ConfigStub", (), {"mcp": {}})() + + monkeypatch.setattr(mcp_routes.MCP, "get_server_info", fake_get_server_info) + monkeypatch.setattr( + mcp_routes.Config, + "get", + classmethod(fake_config_get), + ) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "get_mcp_server", + lambda name: { + "type": "remote", + "url": "https://example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer token123", + }, + "headers": { + "Authorization": "Bearer token123", + "X-Client": "flocks", + }, + }, + ) + + resp = await client.get("/api/mcp/demo-remote") + + assert resp.status_code == 200, resp.text + data = resp.json() + assert data["config"]["auth"]["value"] == "***" + assert data["config"]["headers"]["Authorization"] == "***" + assert data["config"]["headers"]["X-Client"] == "flocks" + @pytest.mark.asyncio async def test_test_mcp_connection_normalizes_sse_alias_to_remote( self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch @@ -339,6 +382,15 @@ async def fake_remove(name: str) -> bool: "get", classmethod(fake_config_get), ) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "get_mcp_server", + lambda name: { + "type": "remote", + "url": "https://old.example.com/mcp", + "headers": {"Api-Key": "{secret:qianxin_mcp_key}"}, + }, + ) monkeypatch.setattr(mcp_routes.MCP, "status", fake_status) monkeypatch.setattr(mcp_routes.MCP, "remove", fake_remove) monkeypatch.setattr( @@ -397,6 +449,15 @@ async def fake_remove(name: str) -> bool: "get", classmethod(fake_config_get), ) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "get_mcp_server", + lambda name: { + "type": "local", + "command": ["python", "-m", "mcp_panther"], + "enabled": True, + }, + ) monkeypatch.setattr(mcp_routes.MCP, "status", fake_status) monkeypatch.setattr(mcp_routes.MCP, "remove", fake_remove) monkeypatch.setattr( @@ -444,6 +505,14 @@ async def fake_status() -> dict[str, McpStatusInfo]: "get", classmethod(fake_config_get), ) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "get_mcp_server", + lambda name: { + "type": "remote", + "url": "https://old.example.com/mcp", + }, + ) monkeypatch.setattr(mcp_routes.MCP, "status", fake_status) monkeypatch.setattr( mcp_routes.ConfigWriter, @@ -473,6 +542,83 @@ def set(self, key: str, value: str) -> None: == 
"https://example.com/mcp?apikey={secret:demo-mcp_mcp_key}" ) + @pytest.mark.asyncio + async def test_update_mcp_server_restores_masked_sensitive_values( + self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch + ): + stored_configs: dict[str, dict] = {} + saved_secrets: dict[str, str] = {} + + async def fake_status() -> dict[str, McpStatusInfo]: + return {} + + monkeypatch.setattr(mcp_routes.MCP, "status", fake_status) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "get_mcp_server", + lambda name: { + "type": "remote", + "url": "https://old.example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "Bearer token123", + }, + "headers": { + "Authorization": "Bearer token123", + "X-Client": "flocks", + }, + }, + ) + monkeypatch.setattr( + mcp_routes.ConfigWriter, + "add_mcp_server", + lambda name, config: stored_configs.__setitem__(name, config), + ) + monkeypatch.setattr(tool_loader, "save_mcp_config", lambda name, config: None) + + class SecretManagerStub: + def set(self, key: str, value: str) -> None: + saved_secrets[key] = value + + monkeypatch.setattr( + "flocks.security.get_secret_manager", + lambda: SecretManagerStub(), + ) + + resp = await client.put( + "/api/mcp/demo-mcp", + json={ + "config": { + "url": "https://new.example.com/mcp", + "auth": { + "type": "apikey", + "location": "header", + "param_name": "Authorization", + "value": "***", + }, + "headers": { + "Authorization": "***", + "X-Client": "flocks-web", + }, + } + }, + ) + + assert resp.status_code == 200, resp.text + assert saved_secrets == { + "demo-mcp_mcp_key": "token123", + "demo-mcp_authorization_header": "Bearer token123", + } + assert stored_configs["demo-mcp"]["url"] == "https://new.example.com/mcp" + assert stored_configs["demo-mcp"]["auth"]["value"] == "{secret:demo-mcp_mcp_key}" + assert ( + stored_configs["demo-mcp"]["headers"]["Authorization"] + == "{secret:demo-mcp_authorization_header}" + ) + assert stored_configs["demo-mcp"]["headers"]["X-Client"] == "flocks-web" + @pytest.mark.asyncio async def test_catalog_install_defaults_to_disabled_without_connecting( self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch diff --git a/tests/server/routes/test_pty_routes.py b/tests/server/routes/test_pty_routes.py new file mode 100644 index 00000000..82bccc31 --- /dev/null +++ b/tests/server/routes/test_pty_routes.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest +from fastapi import HTTPException + +from flocks.server.routes import pty as pty_routes + + +class _FakeWebSocket: + def __init__(self) -> None: + self.close_code = None + self.close_reason = None + self.accepted = False + + async def close(self, code: int, reason: str = "") -> None: + self.close_code = code + self.close_reason = reason + + async def accept(self) -> None: + self.accepted = True + + +@pytest.mark.asyncio +async def test_pty_websocket_authenticates_before_session_lookup(monkeypatch: pytest.MonkeyPatch): + websocket = _FakeWebSocket() + + async def _reject(_websocket): + raise HTTPException(status_code=401, detail="missing auth") + + get_session = Mock() + monkeypatch.setattr(pty_routes, "apply_auth_for_request", _reject) + monkeypatch.setattr(pty_routes.Pty, "get", get_session) + + await pty_routes.connect_session(websocket, "pty_missing") + + assert websocket.close_code == 4401 + assert websocket.close_reason == "missing auth" + assert websocket.accepted is False + get_session.assert_not_called() diff --git 
a/tests/server/routes/test_remaining_routes.py b/tests/server/routes/test_remaining_routes.py index b293f0d6..86314783 100644 --- a/tests/server/routes/test_remaining_routes.py +++ b/tests/server/routes/test_remaining_routes.py @@ -81,6 +81,8 @@ def isolated_workflow_filesystem(tmp_path: Path, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(workflow_routes, "_workspace_root", workspace_root, raising=False) monkeypatch.setattr(workflow_routes, "_find_workspace_root", lambda: workspace_root) + monkeypatch.setattr(workflow_routes, "_workflow_dir", lambda workflow_id: project_root / workflow_id) + monkeypatch.setattr(workflow_routes, "_global_workflow_dir", lambda workflow_id: global_root / workflow_id) monkeypatch.setattr(fs_store, "_workspace_root", workspace_root, raising=False) monkeypatch.setattr(fs_store, "find_workspace_root", lambda: workspace_root) monkeypatch.setattr( @@ -96,7 +98,11 @@ def isolated_workflow_filesystem(tmp_path: Path, monkeypatch: pytest.MonkeyPatch raising=False, ) - yield + yield { + "workspace_root": workspace_root, + "project_root": project_root, + "global_root": global_root, + } # =========================================================================== @@ -113,9 +119,28 @@ async def test_list_workflows_returns_array(self, client: AsyncClient): assert isinstance(resp.json(), list) @pytest.mark.asyncio - async def test_create_workflow(self, client: AsyncClient): + async def test_create_workflow( + self, + client: AsyncClient, + isolated_workflow_filesystem, + monkeypatch: pytest.MonkeyPatch, + ): """POST /api/workflow creates a workflow and returns it.""" - resp = await client.post("/api/workflow", json=_WORKFLOW_PAYLOAD) + from flocks.server import auth as auth_module + + class _SecretManagerStub: + def get(self, key: str): + if key == auth_module.API_TOKEN_SECRET_ID: + return "abc123" + return None + + monkeypatch.setattr(auth_module, "get_secret_manager", lambda: _SecretManagerStub()) + + resp = await client.post( + "/api/workflow", + json=_WORKFLOW_PAYLOAD, + headers={"Authorization": "Bearer abc123"}, + ) assert resp.status_code in ( status.HTTP_200_OK, status.HTTP_201_CREATED, @@ -123,6 +148,52 @@ async def test_create_workflow(self, client: AsyncClient): data = resp.json() assert data["name"] == "test-workflow" assert "id" in data + assert data["source"] == "global" + assert (isolated_workflow_filesystem["global_root"] / data["id"] / "workflow.json").is_file() + assert not (isolated_workflow_filesystem["project_root"] / data["id"] / "workflow.json").exists() + + @pytest.mark.asyncio + async def test_import_workflow_defaults_to_global_storage( + self, + client: AsyncClient, + isolated_workflow_filesystem, + monkeypatch: pytest.MonkeyPatch, + ): + """POST /api/workflow/import stores imported workflows under global user storage by default.""" + from flocks.server import auth as auth_module + + class _SecretManagerStub: + def get(self, key: str): + if key == auth_module.API_TOKEN_SECRET_ID: + return "abc123" + return None + + monkeypatch.setattr(auth_module, "get_secret_manager", lambda: _SecretManagerStub()) + + payload = { + **_WORKFLOW_JSON, + "name": "imported-workflow", + "metadata": { + "description": "Imported workflow", + "category": "default", + }, + } + + resp = await client.post( + "/api/workflow/import", + json=payload, + headers={"Authorization": "Bearer abc123"}, + ) + assert resp.status_code in ( + status.HTTP_200_OK, + status.HTTP_201_CREATED, + ), resp.text + + data = resp.json() + assert data["name"] == "imported-workflow" + assert 
data["source"] == "global" + assert (isolated_workflow_filesystem["global_root"] / data["id"] / "workflow.json").is_file() + assert not (isolated_workflow_filesystem["project_root"] / data["id"] / "workflow.json").exists() @pytest.mark.asyncio async def test_get_workflow(self, client: AsyncClient): diff --git a/tests/server/routes/test_task_scheduler_context_route.py b/tests/server/routes/test_task_scheduler_context_route.py new file mode 100644 index 00000000..8df880d7 --- /dev/null +++ b/tests/server/routes/test_task_scheduler_context_route.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import pytest + +from flocks.task.manager import TaskManager +from flocks.task.models import ExecutionMode, ExecutionTriggerType, SchedulerMode, TaskTrigger + + +@pytest.mark.asyncio +async def test_update_scheduler_accepts_context_for_workflow_inputs(client): + scheduler = await TaskManager.create_scheduler( + title="工作流定时任务", + mode=SchedulerMode.CRON, + trigger=TaskTrigger(cron="0 9 * * *", timezone="Asia/Shanghai"), + execution_mode=ExecutionMode.WORKFLOW, + workflow_id="demo-workflow", + context={"keyword": "before"}, + ) + + response = await client.put( + f"/api/task-schedulers/{scheduler.id}", + json={ + "context": {"keyword": "after", "limit": 5}, + "workflowID": "demo-workflow", + }, + ) + + assert response.status_code == 200 + assert response.json()["context"] == {"keyword": "after", "limit": 5} + + updated = await TaskManager.get_scheduler(scheduler.id) + assert updated is not None + assert updated.context == {"keyword": "after", "limit": 5} + + execution = await TaskManager.create_execution_from_scheduler( + updated, + trigger_type=ExecutionTriggerType.SCHEDULED, + enqueue=False, + ) + assert execution.execution_input_snapshot["context"] == { + "keyword": "after", + "limit": 5, + } diff --git a/tests/server/routes/test_tool_routes.py b/tests/server/routes/test_tool_routes.py index d927c7ef..898b03d6 100644 --- a/tests/server/routes/test_tool_routes.py +++ b/tests/server/routes/test_tool_routes.py @@ -8,6 +8,7 @@ import pytest from httpx import AsyncClient +from flocks.auth.context import AuthUser from flocks.session.message import Message, MessageRole from flocks.session.session import Session from flocks.tool.registry import Tool, ToolCategory, ToolInfo, ToolRegistry, ToolResult @@ -28,6 +29,33 @@ def _temporary_tool(tool: Tool) -> Iterator[None]: ToolRegistry._tools.pop(tool.info.name, None) +class _FakeSessionUser: + def __init__(self, role: str) -> None: + self.role = role + + def to_auth_user(self) -> AuthUser: + return AuthUser( + id=f"usr_{self.role}", + username=f"{self.role}-user", + role=self.role, + status="active", + must_reset_password=False, + ) + + +def _patch_session_user(monkeypatch: pytest.MonkeyPatch, role: str) -> None: + from flocks.server import auth as auth_module + + async def _has_users(): + return True + + async def _get_user_by_session_id(_session_id: str): + return _FakeSessionUser(role) + + monkeypatch.setattr(auth_module.AuthService, "has_users", _has_users) + monkeypatch.setattr(auth_module.AuthService, "get_user_by_session_id", _get_user_by_session_id) + + async def _create_session_and_message(title: str) -> tuple[str, str]: session = await Session.create( project_id="default", @@ -45,6 +73,49 @@ async def _create_session_and_message(title: str) -> tuple[str, str]: class TestToolRouteSecurity: + @pytest.mark.asyncio + async def test_viewer_cannot_create_plugin_tool(self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch): + 
_patch_session_user(monkeypatch, "viewer") + + response = await client.post( + "/api/tools", + headers={"cookie": "flocks_session=viewer-session"}, + json={ + "name": "viewer_created_tool", + "description": "should be rejected", + "handler": { + "type": "http", + "method": "GET", + "url": "https://example.com", + }, + }, + ) + + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_viewer_cannot_update_plugin_tool(self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch): + _patch_session_user(monkeypatch, "viewer") + + response = await client.put( + "/api/tools/existing_tool", + headers={"cookie": "flocks_session=viewer-session"}, + json={"description": "nope"}, + ) + + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_viewer_cannot_delete_plugin_tool(self, client: AsyncClient, monkeypatch: pytest.MonkeyPatch): + _patch_session_user(monkeypatch, "viewer") + + response = await client.delete( + "/api/tools/existing_tool", + headers={"cookie": "flocks_session=viewer-session"}, + ) + + assert response.status_code == 403 + @pytest.mark.asyncio async def test_execute_blocks_direct_bash_access(self, client: AsyncClient): response = await client.post( diff --git a/tests/server/test_app_cors.py b/tests/server/test_app_cors.py index 98f19f8a..295a2266 100644 --- a/tests/server/test_app_cors.py +++ b/tests/server/test_app_cors.py @@ -28,7 +28,7 @@ def test_read_cors_config_merges_runtime_and_configured_origins(monkeypatch, tmp "http://10.0.0.9:5173", "https://configured.example", ] - assert allow_origin_regex == app_module._LOCALHOST_ORIGIN_RE + assert allow_origin_regex is None def test_read_cors_config_ignores_localhost_and_wildcard_runtime_hosts(monkeypatch, tmp_path) -> None: @@ -39,8 +39,12 @@ def test_read_cors_config_ignores_localhost_and_wildcard_runtime_hosts(monkeypat allow_origins, allow_origin_regex = app_module._read_cors_config() - assert allow_origins == [] - assert allow_origin_regex == app_module._LOCALHOST_ORIGIN_RE + assert allow_origins == [ + "http://127.0.0.1:5173", + "http://[::1]:5173", + "http://localhost:5173", + ] + assert allow_origin_regex is None def test_read_cors_config_brackets_ipv6_webui_origin(monkeypatch, tmp_path) -> None: @@ -52,4 +56,14 @@ def test_read_cors_config_brackets_ipv6_webui_origin(monkeypatch, tmp_path) -> N allow_origins, allow_origin_regex = app_module._read_cors_config() assert allow_origins == ["http://[2001:db8::2]:5173"] - assert allow_origin_regex == app_module._LOCALHOST_ORIGIN_RE + assert allow_origin_regex is None + + +def test_read_cors_config_does_not_allow_any_localhost(monkeypatch, tmp_path) -> None: + config_file = tmp_path / "missing.json" + monkeypatch.setattr(app_module.Config, "get_config_file", lambda: config_file) + + allow_origins, allow_origin_regex = app_module._read_cors_config() + + assert allow_origins == [] + assert allow_origin_regex is None diff --git a/tests/server/test_app_errors.py b/tests/server/test_app_errors.py new file mode 100644 index 00000000..673cd89a --- /dev/null +++ b/tests/server/test_app_errors.py @@ -0,0 +1,33 @@ +import json + +import pytest +from starlette.requests import Request + +from flocks.server import app as app_module + + +def _request(path: str = "/api/test") -> Request: + return Request( + { + "type": "http", + "method": "GET", + "path": path, + "headers": [], + "query_string": b"", + "server": ("127.0.0.1", 8000), + "client": ("127.0.0.1", 1234), + "scheme": "http", + } + ) + + +@pytest.mark.asyncio +async def 
test_general_exception_response_does_not_expose_traceback(): + response = await app_module.general_exception_handler(_request(), RuntimeError("secret path detail")) + body = json.loads(response.body) + + assert response.status_code == 500 + assert body == { + "error": "InternalServerError", + "message": "Internal server error", + } diff --git a/tests/server/test_auth_compat.py b/tests/server/test_auth_compat.py index 85c30607..a013e849 100644 --- a/tests/server/test_auth_compat.py +++ b/tests/server/test_auth_compat.py @@ -83,15 +83,27 @@ async def test_apply_auth_for_request_non_browser_accepts_valid_token(monkeypatc @pytest.mark.asyncio -async def test_apply_auth_for_request_non_browser_loopback_allows_without_token(monkeypatch): +async def test_apply_auth_for_request_non_browser_loopback_rejects_without_token_by_default(monkeypatch): monkeypatch.setattr(auth_module, "get_secret_manager", lambda: _FakeSecrets({auth_module.API_TOKEN_SECRET_ID: "abc123"})) + monkeypatch.delenv("PYTEST_CURRENT_TEST", raising=False) request = _make_request(headers={"user-agent": "curl/8.0"}) - _, token, user = await auth_module.apply_auth_for_request(request) - try: - assert user is not None - assert user.username == "local-service" - finally: - auth_module.clear_auth_context(token) + with pytest.raises(HTTPException) as exc_info: + await auth_module.apply_auth_for_request(request) + assert exc_info.value.status_code == 401 + assert "Bearer API Token" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_apply_auth_for_request_ignores_pytest_current_test_for_loopback(monkeypatch): + monkeypatch.setattr(auth_module, "get_secret_manager", lambda: _FakeSecrets({auth_module.API_TOKEN_SECRET_ID: "abc123"})) + monkeypatch.setenv("PYTEST_CURRENT_TEST", "tests/example.py::test_name (call)") + request = _make_request(headers={"user-agent": "curl/8.0"}) + + with pytest.raises(HTTPException) as exc_info: + await auth_module.apply_auth_for_request(request) + + assert exc_info.value.status_code == 401 + assert "Bearer API Token" in str(exc_info.value.detail) @pytest.mark.asyncio diff --git a/tests/server/test_tool_setting_routes.py b/tests/server/test_tool_setting_routes.py index 6d11c1b0..1476c3c8 100644 --- a/tests/server/test_tool_setting_routes.py +++ b/tests/server/test_tool_setting_routes.py @@ -19,6 +19,8 @@ from fastapi import FastAPI from fastapi.testclient import TestClient +from flocks.auth.context import AuthUser +from flocks.server.auth import require_admin from flocks.tool.registry import ( Tool, ToolCategory, @@ -83,6 +85,13 @@ def tool_client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): from flocks.server.routes.tool import router app = FastAPI() + app.dependency_overrides[require_admin] = lambda: AuthUser( + id="admin-test", + username="admin-test", + role="admin", + status="active", + must_reset_password=False, + ) app.include_router(router, prefix="/api/tools") client = TestClient(app, raise_server_exceptions=True) @@ -106,8 +115,41 @@ def _read_settings() -> dict: return ConfigWriter.list_tool_settings() +def _viewer_client() -> TestClient: + from flocks.server.routes.tool import router + + app = FastAPI() + + @app.middleware("http") + async def _viewer_auth(request, call_next): + request.state.auth_user = AuthUser( + id="viewer-test", + username="viewer-test", + role="viewer", + status="active", + must_reset_password=False, + ) + return await call_next(request) + + app.include_router(router, prefix="/api/tools") + return TestClient(app, raise_server_exceptions=True) + + # ─── 
Tests ─────────────────────────────────────────────────────────────────── +def test_tool_mutation_routes_require_admin(): + client = _viewer_client() + + responses = [ + client.patch("/api/tools/anything", json={"enabled": True}), + client.post("/api/tools/anything/reset"), + client.post("/api/tools/refresh"), + client.post("/api/tools/anything/reload"), + ] + + assert [response.status_code for response in responses] == [403, 403, 403, 403] + + class TestToolInfoResponse: def test_lists_factory_default_and_no_setting_initially(self, tool_client): client, _, disabled_tool = tool_client diff --git a/tests/session/test_runner_step.py b/tests/session/test_runner_step.py index d3bff186..44e97eff 100644 --- a/tests/session/test_runner_step.py +++ b/tests/session/test_runner_step.py @@ -918,3 +918,37 @@ async def test_record_usage_if_available_swallows_runtime_error(): ) with patch.dict("sys.modules", {"flocks.provider.usage_service": fake_module}): await runner._record_usage_if_available(usage) + + +@pytest.mark.asyncio +async def test_to_chat_messages_expands_workflow_node_ref_marker(monkeypatch): + runner = _make_runner("ses_runner_node_ref") + user_message = UserMessageInfo( + id="msg_user_node_ref", + sessionID=runner.session.id, + role="user", + time={"created": 1_000}, + agent="rex", + model={"providerID": "anthropic", "modelID": "claude-sonnet"}, + ) + + monkeypatch.setattr( + runner_mod.Message, + "parts", + AsyncMock(return_value=[ + SimpleNamespace( + type="text", + text="@@node:query_fofa|python\n只修改这个节点的代码并保留其他节点不变", + ), + ]), + ) + + chat_messages = await runner._to_chat_messages([user_message], []) + + assert len(chat_messages) == 1 + assert chat_messages[0].role == "user" + assert isinstance(chat_messages[0].content, str) + assert "Selected workflow node context:" in chat_messages[0].content + assert "node_id: query_fofa" in chat_messages[0].content + assert "node_type: python" in chat_messages[0].content + assert "只修改这个节点的代码并保留其他节点不变" in chat_messages[0].content diff --git a/tests/skills/__init__.py b/tests/skills/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/skills/test_tool_validator.py b/tests/skills/test_tool_validator.py new file mode 100644 index 00000000..ed1b4f98 --- /dev/null +++ b/tests/skills/test_tool_validator.py @@ -0,0 +1,206 @@ +"""Tests for .flocks/plugins/skills/tool-builder/validator.py.""" +import sys +import textwrap +from pathlib import Path + +import pytest + +# Make the validator importable without installing it. +SKILL_DIR = Path(__file__).parent.parent.parent / ".flocks" / "plugins" / "skills" / "tool-builder" +sys.path.insert(0, str(SKILL_DIR)) + +from validator import main, validate_yaml_tool, validate_python_tool # noqa: E402 + + +# ── helpers ──────────────────────────────────────────────────────────────── + +def write(tmp_path: Path, filename: str, content: str) -> Path: + p = tmp_path / filename + p.write_text(textwrap.dedent(content), encoding="utf-8") + return p + + +# ── YAML-HTTP mode ───────────────────────────────────────────────────────── + +class TestYamlHttpTool: + def test_valid_minimal_http_tool_passes(self, tmp_path): + p = write(tmp_path, "my_tool.yaml", """\ + name: my_tool + description: A well-described tool that does something useful for the agent. 
+ category: custom + enabled: true + handler: + type: http + method: GET + url: https://example.com/api + inputSchema: + type: object + properties: + q: + type: string + description: Search query + """) + report = validate_yaml_tool(p) + assert report.fail_count == 0, report.issues + + def test_missing_name_is_a_failure(self, tmp_path): + p = write(tmp_path, "no_name.yaml", """\ + description: A fine description. + category: custom + enabled: true + handler: + type: http + url: https://example.com + """) + report = validate_yaml_tool(p) + assert report.fail_count > 0 + assert any("name" in i.message.lower() for i in report.issues if i.level == "FAIL") + + def test_undeclared_url_placeholder_is_a_failure(self, tmp_path): + p = write(tmp_path, "bad_url.yaml", """\ + name: bad_url + description: Long enough description for the validator to be happy here. + category: custom + enabled: true + handler: + type: http + method: GET + url: https://example.com/{undeclared_param} + inputSchema: + type: object + properties: {} + """) + report = validate_yaml_tool(p) + assert report.fail_count > 0 + assert any("undeclared_param" in i.message for i in report.issues if i.level == "FAIL") + + def test_invalid_category_is_a_failure(self, tmp_path): + p = write(tmp_path, "bad_cat.yaml", """\ + name: bad_cat + description: Some long enough description that passes the length check ok. + category: nonsense + enabled: true + handler: + type: http + url: https://example.com + """) + report = validate_yaml_tool(p) + assert report.fail_count > 0 + + +# ── YAML-script mode ─────────────────────────────────────────────────────── + +class TestYamlScriptTool: + def test_valid_script_tool_passes(self, tmp_path): + handler_py = write(tmp_path, "my_handler.py", """\ + async def handle(ctx, q: str) -> dict: + return {"result": q} + """) + p = write(tmp_path, "my_script_tool.yaml", f"""\ + name: my_script_tool + description: Script-based tool with a proper handler that does something. + category: custom + enabled: true + handler: + type: script + script_file: {handler_py.name} + function: handle + inputSchema: + type: object + properties: + q: + type: string + description: Input query + """) + report = validate_yaml_tool(p) + assert report.fail_count == 0, report.issues + + def test_missing_script_file_is_a_failure(self, tmp_path): + p = write(tmp_path, "no_script.yaml", """\ + name: no_script + description: Script tool whose script_file does not exist on disk. 
+ category: custom + enabled: true + handler: + type: script + script_file: nonexistent_handler.py + """) + report = validate_yaml_tool(p) + assert report.fail_count > 0 + + +# ── Python tool mode ─────────────────────────────────────────────────────── + +class TestPythonTool: + def test_valid_python_tool_passes(self, tmp_path): + p = write(tmp_path, "my_python_tool.py", """\ + from flocks.tool.registry import ToolRegistry, ToolResult + + @ToolRegistry.register_function( + name="my_python_tool", + description="Does something useful for the agent in a local context.", + category="custom", + parameters=[ + {"name": "text", "type": "string", "description": "Input text"}, + ], + ) + async def my_python_tool(ctx, text: str) -> ToolResult: + return ToolResult(output=text) + """) + report = validate_python_tool(p) + assert report.fail_count == 0, report.issues + + def test_missing_register_decorator_is_a_failure(self, tmp_path): + p = write(tmp_path, "no_decorator.py", """\ + async def handle(text: str) -> dict: + return {"result": text} + """) + report = validate_python_tool(p) + assert report.fail_count > 0 + + def test_non_async_function_is_a_failure(self, tmp_path): + p = write(tmp_path, "sync_fn.py", """\ + from flocks.tool.registry import ToolRegistry, ToolResult + + @ToolRegistry.register_function( + name="sync_fn", + description="A synchronous function that should be async.", + category="custom", + parameters=[], + ) + def sync_fn() -> ToolResult: + return ToolResult(output="ok") + """) + report = validate_python_tool(p) + assert report.fail_count > 0 + + +# ── CLI --strict mode ────────────────────────────────────────────────────── + +class TestCliStrictMode: + def test_strict_exits_nonzero_on_warnings(self, tmp_path): + # A tool with no parameters triggers a WARN. + p = write(tmp_path, "warn_tool.yaml", """\ + name: warn_tool + description: Decent description but has no parameters at all. + category: custom + enabled: true + handler: + type: http + url: https://example.com + """) + exit_code = main(["--strict", str(p)]) + assert exit_code != 0 + + def test_non_strict_exits_zero_on_warnings_only(self, tmp_path): + p = write(tmp_path, "warn_tool2.yaml", """\ + name: warn_tool2 + description: Decent description but has no parameters at all here. 
+ category: custom + enabled: true + handler: + type: http + url: https://example.com + """) + exit_code = main([str(p)]) + assert exit_code == 0 diff --git a/tests/tool/test_onesec_api_tool.py b/tests/tool/test_onesec_api_tool.py index 7cadfd0f..01ccfe6f 100644 --- a/tests/tool/test_onesec_api_tool.py +++ b/tests/tool/test_onesec_api_tool.py @@ -1,4 +1,5 @@ import base64 +import datetime as dt import hmac from hashlib import sha1 from pathlib import Path @@ -191,6 +192,303 @@ async def test_onesec_dns_get_public_ip_list_honors_verify_ssl_true(): assert kwargs["ssl"] is True +@pytest.mark.asyncio +async def test_onesec_dns_search_blocked_queries_normalizes_block_status(): + tool = _load_tool("onesec_dns.yaml") + fake_session = _FakeSession( + [ + _FakeResponse( + json_payload={ + "data": { + "total_num": 1, + "items": [ + { + "domain": "blocked.example", + "block_reason": "threat", + } + ], + } + } + ) + ] + ) + mock_secret_manager = MagicMock() + mock_secret_manager.get.return_value = "api-key-1|secret-1" + + with ( + patch("flocks.security.get_secret_manager", return_value=mock_secret_manager), + patch( + "flocks.config.config_writer.ConfigWriter.get_api_service_raw", + return_value={"apiKey": "{secret:onesec_credentials}"}, + ), + patch("aiohttp.ClientSession", return_value=fake_session), + patch("time.time", return_value=1700000000), + ): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_search_blocked_queries", + time_from=1699990000, + time_to=1700000000, + domain="blocked.example", + keyword="blocked.example", + show_unblocked_threat=1, + cur_page=2, + ) + + assert result.success is True + assert result.output["items"] == [ + { + "domain": "blocked.example", + "block_reason": "threat", + "result": "block", + "is_blocked": True, + } + ] + + method, url, kwargs = fake_session.calls[0] + assert method == "POST" + assert url == "https://console.onesec.net/open/api/client/searchBlockedQueries" + assert kwargs["json"] == { + "time_from": 1699990000, + "time_to": 1700000000, + "domain": "blocked.example", + "keyword": "blocked.example", + "show_unblocked_threat": 1, + "cur_page": 2, + } + + +@pytest.mark.asyncio +async def test_onesec_dns_search_blocked_queries_defaults_keyword_and_parses_datetime_strings(): + tool = _load_tool("onesec_dns.yaml") + fake_session = _FakeSession( + [ + _FakeResponse( + json_payload={ + "data": { + "total_num": 1, + "items": [{"domain": "bilibili.com", "result": "block"}], + } + } + ) + ] + ) + mock_secret_manager = MagicMock() + mock_secret_manager.get.return_value = "api-key-1|secret-1" + local_tz = dt.datetime.now().astimezone().tzinfo + + with ( + patch("flocks.security.get_secret_manager", return_value=mock_secret_manager), + patch( + "flocks.config.config_writer.ConfigWriter.get_api_service_raw", + return_value={"apiKey": "{secret:onesec_credentials}"}, + ), + patch("aiohttp.ClientSession", return_value=fake_session), + ): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_search_blocked_queries", + domain="bilibili.com", + time_from="2026-05-08 00:00:00", + time_to="2026-05-08 23:59:59", + ) + + assert result.success is True + + expected_time_from = int(dt.datetime(2026, 5, 8, 0, 0, 0, tzinfo=local_tz).timestamp()) + expected_time_to = int(dt.datetime(2026, 5, 8, 23, 59, 59, tzinfo=local_tz).timestamp()) + + method, url, kwargs = fake_session.calls[0] + assert method == "POST" + assert url == "https://console.onesec.net/open/api/client/searchBlockedQueries" + 
assert kwargs["json"] == { + "time_from": expected_time_from, + "time_to": expected_time_to, + "domain": "bilibili.com", + "keyword": "bilibili.com", + } + + +@pytest.mark.asyncio +async def test_onesec_dns_search_blocked_queries_rejects_invalid_datetime_string(): + tool = _load_tool("onesec_dns.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_search_blocked_queries", + domain="bilibili.com", + keyword="bilibili.com", + time_from="tomorrow morning", + time_to=1700000000, + ) + + assert result.success is False + assert ( + result.error + == "time_from (tomorrow morning) 必须是 Unix 秒级时间戳。" + " 当前工具支持自动转换常见日期时间格式,如 `YYYY-MM-DD HH:MM:SS`。" + ) + + +@pytest.mark.asyncio +async def test_onesec_dns_search_queries_rejects_time_from_older_than_24_hours(): + tool = _load_tool("onesec_dns.yaml") + + with patch("time.time", return_value=1700000000): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_search_queries", + time_from=1699900000, + time_to=1699950000, + domain="example.com", + ) + + assert result.success is False + assert ( + result.error + == "按 OneSEC API 文档,`dns_search_queries` 仅支持最近 24 小时内的数据。请将 time_from 设置在最近 24 小时内。" + ) + + +@pytest.mark.asyncio +async def test_onesec_dns_search_blocked_queries_suggests_recent_for_public_ip_only(): + tool = _load_tool("onesec_dns.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_search_blocked_queries", + public_ip="203.0.113.10", + time_from="2026-05-07 17:30:00", + time_to="2026-05-08 17:30:00", + ) + + assert result.success is False + assert ( + result.error + == "`dns_search_blocked_queries` 按 OneSEC API 文档要求必须传 `domain` 和 `keyword`。" + " 如果你当前只有 `public_ip` + 时间范围,且要查询最近 24 小时拦截记录," + " 请改用 `dns_get_recent_blocked_queries`。" + ) + + +@pytest.mark.asyncio +async def test_onesec_dns_get_recent_blocked_queries_rejects_doc_unsupported_filters(): + tool = _load_tool("onesec_dns.yaml") + + with patch("time.time", return_value=1700000000): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_get_recent_blocked_queries", + time_from=1699990000, + time_to=1700000000, + domain="blocked.example", + keyword="blocked", + ) + + assert result.success is False + assert ( + result.error + == "dns_get_recent_blocked_queries 按 OneSEC API 文档不支持以下参数: domain, keyword。" + " 若需要按域名或关键字筛选 DNS 拦截记录,请改用 `dns_search_blocked_queries`。" + ) + + +@pytest.mark.asyncio +async def test_onesec_dns_get_recent_blocked_queries_passes_doc_example_fields(): + tool = _load_tool("onesec_dns.yaml") + fake_session = _FakeSession( + [ + _FakeResponse( + json_payload={ + "data": { + "total_num": 1, + "items": [{"result": "block"}], + } + } + ) + ] + ) + mock_secret_manager = MagicMock() + mock_secret_manager.get.return_value = "api-key-1|secret-1" + + with ( + patch("flocks.security.get_secret_manager", return_value=mock_secret_manager), + patch( + "flocks.config.config_writer.ConfigWriter.get_api_service_raw", + return_value={"apiKey": "{secret:onesec_credentials}"}, + ), + patch("aiohttp.ClientSession", return_value=fake_session), + patch("time.time", return_value=1700000000), + ): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_get_recent_blocked_queries", + time_from=1699990000, + time_to=1700000000, + public_ip=["1.1.1.1"], + block_reason="threat", + show_unblocked_threat=1, + threat_level=[2, 3], + ) + + assert result.success is True 
+ + method, url, kwargs = fake_session.calls[0] + assert method == "POST" + assert url == "https://console.onesec.net/open/api/client/getRecentBlockedQueries" + assert kwargs["json"] == { + "time_from": 1699990000, + "time_to": 1700000000, + "public_ip": ["1.1.1.1"], + "block_reason": "threat", + "show_unblocked_threat": 1, + "threat_level": [2, 3], + } + + +@pytest.mark.asyncio +async def test_onesec_dns_get_recent_blocked_queries_wraps_single_public_ip_string(): + tool = _load_tool("onesec_dns.yaml") + fake_session = _FakeSession( + [ + _FakeResponse( + json_payload={ + "data": { + "total_num": 1, + "items": [{"result": "block"}], + } + } + ) + ] + ) + mock_secret_manager = MagicMock() + mock_secret_manager.get.return_value = "api-key-1|secret-1" + + with ( + patch("flocks.security.get_secret_manager", return_value=mock_secret_manager), + patch( + "flocks.config.config_writer.ConfigWriter.get_api_service_raw", + return_value={"apiKey": "{secret:onesec_credentials}"}, + ), + patch("aiohttp.ClientSession", return_value=fake_session), + patch("time.time", return_value=1700000000), + ): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="dns_get_recent_blocked_queries", + public_ip="203.0.113.10", + time_from="2026-05-07 17:30:00", + time_to="2026-05-08 17:30:00", + ) + + assert result.success is True + + method, url, kwargs = fake_session.calls[0] + assert method == "POST" + assert url == "https://console.onesec.net/open/api/client/getRecentBlockedQueries" + assert kwargs["json"]["public_ip"] == ["203.0.113.10"] + + @pytest.mark.asyncio async def test_onesec_edr_get_threat_files_uses_doc_page_structure(): tool = _load_tool("onesec_edr.yaml") @@ -291,6 +589,43 @@ async def test_onesec_threat_virus_scan_returns_integer_task_id(): } +@pytest.mark.asyncio +async def test_onesec_threat_virus_scan_rejects_invalid_task_type(): + tool = _load_tool("onesec_threat.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="threat_virus_scan", + agent_list=["umid-1"], + task_type=99999, + scanmode=1, + ) + + assert result.success is False + assert ( + result.error + == "`task_type`/`scan_type` 取值无效:99999。按 OneSEC API 文档仅支持:10110, 10120, 10130。" + ) + + +@pytest.mark.asyncio +async def test_onesec_threat_update_bd_version_rejects_invalid_os_arch(): + tool = _load_tool("onesec_threat.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="threat_update_bd_version", + os_platform="macos", + os_arch="ARM64", + ) + + assert result.success is False + assert ( + result.error + == "`os_arch` 取值无效:ARM64。当 `os_platform=macos` 时仅支持:Apple Silicon, Intel Chip。" + ) + + @pytest.mark.asyncio async def test_onesec_ops_query_agent_page_list_uses_sort_object_payload(): tool = _load_tool("onesec_ops.yaml") @@ -436,6 +771,45 @@ async def test_onesec_ops_query_task_page_list_uses_sort_object_payload(): } +@pytest.mark.asyncio +async def test_onesec_ops_query_task_page_list_rejects_invalid_time_type(): + tool = _load_tool("onesec_ops.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="ops_query_task_page_list", + time_type="last_seen", + begin_time=1699990000, + end_time=1700000000, + auto=0, + ) + + assert result.success is False + assert ( + result.error + == "`time_type` 取值无效:last_seen。按 OneSEC API 文档仅支持:create_time, update_time。" + ) + + +@pytest.mark.asyncio +async def 
test_onesec_ops_query_audit_log_rejects_begin_time_older_than_30_days(): + tool = _load_tool("onesec_ops.yaml") + + with patch("time.time", return_value=1700000000): + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="ops_query_audit_log", + begin_time=1690000000, + end_time=1690100000, + ) + + assert result.success is False + assert ( + result.error + == "按 OneSEC API 文档,`ops_query_audit_log` 仅支持最近 30 天内的审计日志。请调整 begin_time。" + ) + + @pytest.mark.asyncio async def test_onesec_edr_get_ioc_list_uses_doc_payload(): tool = _load_tool("onesec_edr.yaml") @@ -552,6 +926,24 @@ async def test_onesec_edr_get_threat_disposals_uses_incident_and_sort_payload(): } +@pytest.mark.asyncio +async def test_onesec_edr_get_threat_files_rejects_span_over_three_months(): + tool = _load_tool("onesec_edr.yaml") + + result = await tool.handler( + ToolContext(session_id="test", message_id="test"), + action="edr_get_threat_files", + time_from=1690000000, + time_to=1700000000, + ) + + assert result.success is False + assert ( + result.error + == "按 OneSEC API 文档,`edr_get_threat_files` 的时间窗口最长三个月。请缩小 time_from/time_to 范围。" + ) + + @pytest.mark.asyncio async def test_onesec_edr_recent_incidents_rejects_window_over_24_hours(): tool = _load_tool("onesec_edr.yaml") diff --git a/tests/tool/test_tool_plugin.py b/tests/tool/test_tool_plugin.py index 7d81d4cc..02232bca 100644 --- a/tests/tool/test_tool_plugin.py +++ b/tests/tool/test_tool_plugin.py @@ -11,6 +11,7 @@ import yaml from flocks.tool.tool_loader import ( + _build_execution_handler, _build_http_handler, _extract_response, _json_schema_to_params, @@ -939,3 +940,34 @@ async def test_response_extract(self): assert result.success is True assert result.output == [1, 2] + + +class TestExecutionHandler: + @pytest.mark.asyncio + async def test_inline_yaml_execution_loads_but_refuses_to_run_by_default( + self, + tmp_path: Path, + ): + handler = _build_execution_handler( + {"type": "python", "code": "return {'success': True}"}, + tmp_path / "tool.yaml", + ) + result = await handler(ToolContext(session_id="test", message_id="test")) + + assert result.success is False + assert "Inline YAML execution is disabled" in result.error + + @pytest.mark.asyncio + async def test_inline_yaml_execution_stays_disabled( + self, + tmp_path: Path, + ): + handler = _build_execution_handler( + {"type": "python", "code": "return {'success': True, 'value': _kw_['name']}"}, + tmp_path / "tool.yaml", + ) + + result = await handler(ToolContext(session_id="test", message_id="test"), name="after") + + assert result.success is False + assert "handler.type=script" in result.error diff --git a/tests/tool/test_web2cli_generate_cli.py b/tests/tool/test_web2cli_generate_cli.py index d21f80c6..644ab2b8 100644 --- a/tests/tool/test_web2cli_generate_cli.py +++ b/tests/tool/test_web2cli_generate_cli.py @@ -190,3 +190,143 @@ def test_generated_client_still_supports_plain_cookie_list(tmp_path, monkeypatch {"name": "sid", "value": "cookie-123"}, {"name": "api", "value": "cookie-456", "path": "/"}, ] + + +def _sample_spec(): + return { + "schemaVersion": "1.0", + "site": "example", + "command": "list_items", + "description": "List items from example API", + "baseUrl": "https://example.com", + "strategy": "COOKIE", + "auth": {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": []}, + "operation": { + "method": "POST", + "endpoint": "/api/items/list", + "queryTemplate": {}, + "bodyTemplate": {"page": "${page}", "size": "${limit}"}, + "headers": 
{"Content-Type": "application/json"}, + }, + "rowSource": {"path": "$.data.items[]", "collectionPath": "$.data.items[]"}, + "args": [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "limit", "type": "int", "default": 20, "help": "Page size"}, + ], + "columns": [ + {"name": "id", "path": "$.data.items[].id", "relativePath": "id", "type": "string"}, + {"name": "title", "path": "$.data.items[].title", "relativePath": "title", "type": "string"}, + ], + "verify": { + "args": {"page": 1, "limit": 20}, + "rowCount": {"min": 1, "max": 2}, + "columns": ["id", "title"], + "types": {"id": "string", "title": "string"}, + "notEmpty": ["id", "title"], + "patterns": {}, + }, + } + + +class _FakeResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _FakeRequestSession(_FakeSession): + def __init__(self, payload) -> None: + super().__init__() + self._payload = payload + self.request_calls = [] + + def request(self, method, url, json=None, params=None): + self.request_calls.append({"method": method, "url": url, "json": json, "params": params}) + return _FakeResponse(self._payload) + + +def test_generate_verify_materials_from_spec_uses_spec_contract(): + module = _load_module() + + verify = module.generate_verify_materials_from_spec(_sample_spec()) + + assert verify["site"] == "example" + assert verify["command"] == "list_items" + assert verify["expect"]["columns"] == ["id", "title"] + assert verify["expect"]["rowCount"]["max"] == 2 + + +def test_generate_python_cli_from_spec_supports_argparse_and_verify(): + module = _load_module() + + output = module.generate_python_cli_from_spec(_sample_spec()) + + assert 'parser.add_argument("--format", choices=["json", "csv", "table"]' in output + assert 'parser.add_argument("--verify", action="store_true"' in output + assert 'SPEC = {' in output + assert 'def verify_rows(rows: List[Dict[str, Any]], verify_spec: Dict[str, Any])' in output + + +def test_generated_spec_cli_executes_request_and_projects_rows(tmp_path, monkeypatch): + module = _load_module() + auth_state = tmp_path / "auth-state.json" + auth_state.write_text( + json.dumps({"cookies": [{"name": "sid", "value": "cookie-123", "domain": ".example.com", "path": "/"}]}), + encoding="utf-8", + ) + + fake_session = _FakeRequestSession( + {"data": {"items": [{"id": "1", "title": "Alpha"}, {"id": "2", "title": "Beta"}]}} + ) + fake_requests = types.SimpleNamespace(Session=lambda: fake_session) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + namespace = {"__name__": "generated_spec_cli"} + exec(module.generate_python_cli_from_spec(_sample_spec()), namespace) + + client = namespace["APIClient"](auth_state=str(auth_state)) + rows = client.run({"page": 3, "limit": 5}) + errors = namespace["verify_rows"](rows, {"expect": _sample_spec()["verify"]}) + + assert rows == [{"id": "1", "title": "Alpha"}, {"id": "2", "title": "Beta"}] + assert errors == [] + assert fake_session.request_calls == [ + { + "method": "POST", + "url": "https://example.com/api/items/list", + "json": {"page": 3, "size": 5}, + "params": None, + } + ] + assert fake_session.cookies.set_calls == [ + {"name": "sid", "value": "cookie-123", "domain": ".example.com", "path": "/"} + ] + + +def test_main_supports_spec_verify_output(tmp_path, monkeypatch, capsys): + module = _load_module() + spec_path = tmp_path / "web2cli-spec.json" + spec_path.write_text(json.dumps(_sample_spec()), encoding="utf-8") 
+ monkeypatch.setattr( + sys, + "argv", + [ + "generate-cli.py", + "--spec", + str(spec_path), + "--format", + "verify", + ], + ) + + module.main() + captured = capsys.readouterr() + payload = json.loads(captured.out) + + assert payload["site"] == "example" + assert payload["expect"]["types"]["id"] == "string" diff --git a/tests/tool/test_web2cli_generate_spec.py b/tests/tool/test_web2cli_generate_spec.py new file mode 100644 index 00000000..4e4f2309 --- /dev/null +++ b/tests/tool/test_web2cli_generate_spec.py @@ -0,0 +1,106 @@ +import importlib.util +import json +import sys +from pathlib import Path + + +SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] + / ".flocks" + / "plugins" + / "skills" + / "web2cli" + / "scripts" + / "generate-spec.py" +) + + +def _load_module(): + spec = importlib.util.spec_from_file_location("web2cli_generate_spec", SCRIPT_PATH) + assert spec is not None + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _sample_requests(): + return [ + { + "type": "XHR", + "method": "POST", + "url": "https://example.com/api/ignore", + "status": 200, + "response": '{"ok": true}', + "requestHeaders": {"Content-Type": "application/json"}, + }, + { + "type": "Fetch", + "method": "POST", + "url": "https://example.com/api/items/list?page=1", + "normalizedUrl": "https://example.com/api/items/list?page=1", + "origin": "https://example.com", + "pathname": "/api/items/list", + "query": {"page": "1"}, + "queryKeys": ["page"], + "status": 200, + "captureReason": "nonGet", + "actionContext": {"lastAction": {"action": "Load data"}}, + "requestHeaders": { + "Content-Type": "application/json", + "Cookie": "sid=cookie-123", + "X-Requested-With": "XMLHttpRequest", + }, + "requestBody": '{"page": 1, "size": 20}', + "response": '{"data":{"items":[{"id":"1","title":"Alpha","count":2},{"id":"2","title":"Beta","count":3}]}}', + }, + ] + + +def test_generate_spec_from_requests_picks_primary_collection_endpoint(): + module = _load_module() + + spec = module.generate_spec_from_requests(_sample_requests()) + + assert spec["site"] == "example" + assert spec["command"] == "list" + assert spec["strategy"] == "COOKIE" + assert spec["operation"]["endpoint"] == "/api/items/list" + assert spec["operation"]["bodyTemplate"] == {"page": "${page}", "size": "${limit}"} + assert spec["args"] == [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "limit", "type": "int", "default": 20, "help": "Page size"}, + ] + assert spec["rowSource"]["collectionPath"] == "$.data.items[]" + assert spec["columns"][:2] == [ + {"name": "id", "path": "$.id", "relativePath": "id", "sourceField": "id", "type": "string"}, + {"name": "title", "path": "$.title", "relativePath": "title", "sourceField": "title", "type": "string"}, + ] + + +def test_main_writes_spec_file(tmp_path, monkeypatch, capsys): + module = _load_module() + input_path = tmp_path / "captured.json" + output_path = tmp_path / "web2cli-spec.json" + input_path.write_text(json.dumps(_sample_requests()), encoding="utf-8") + + monkeypatch.setattr( + sys, + "argv", + [ + "generate-spec.py", + str(input_path), + "--output", + str(output_path), + ], + ) + + module.main() + + payload = json.loads(output_path.read_text(encoding="utf-8")) + captured = capsys.readouterr() + + assert payload["verify"]["columns"][:2] == ["id", "title"] + assert payload["verify"]["rowCount"]["max"] == 2 + assert f"Written to {output_path}" in captured.out diff --git 
a/tests/tool/test_web2cli_hook_base.py b/tests/tool/test_web2cli_hook_base.py index 701524e3..46f73096 100644 --- a/tests/tool/test_web2cli_hook_base.py +++ b/tests/tool/test_web2cli_hook_base.py @@ -187,8 +187,12 @@ def test_hook_base_captures_recent_action_context_for_xhr(): """ ) - assert result["version"] == "3.1-base" + assert result["version"] == "web2cli-base" assert result["request"]["pageContext"]["path"] == "/dashboard" + assert result["request"]["normalizedUrl"] == "https://example.com/api/items/list" + assert result["request"]["pathname"] == "/api/items/list" + assert result["request"]["captureReason"] == "nonGet" + assert result["request"]["requestShape"]["$.page"] == "number" assert result["request"]["actionContext"]["lastAction"]["action"] == "Load data" assert result["recentActions"][0]["type"] == "click" assert any("action=Load data" in line for line in result["logs"]) @@ -216,4 +220,6 @@ def test_hook_base_exposes_debug_state_and_truncates_large_responses(): assert result["response"].endswith("...[truncated]") assert any(action["type"] == "pushState" for action in result["debugState"]["recentActions"]) + assert result["debugState"]["lastRequest"]["response"] == result["response"] + assert result["debugState"]["lastRequest"]["pathname"] == "/api/debug" assert any("window.__apiCapture.getDebugState()" in line for line in result["logs"]) diff --git a/tests/updater/test_updater.py b/tests/updater/test_updater.py index 0672ad2e..43f7fbf5 100644 --- a/tests/updater/test_updater.py +++ b/tests/updater/test_updater.py @@ -1745,6 +1745,69 @@ async def fake_sleep(_s): assert call_count == 2 +@pytest.mark.asyncio +async def test_perform_update_rolls_back_when_windows_uv_sync_times_out( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + archive_path = tmp_path / "flocks.zip" + archive_path.write_text("archive", encoding="utf-8") + staged_root = tmp_path / "staged" + staged_webui = staged_root / "webui" + staged_webui.mkdir(parents=True) + (staged_webui / "package.json").write_text("{}", encoding="utf-8") + (staged_webui / "dist").mkdir() + (staged_webui / "dist" / "index.html").write_text("", encoding="utf-8") + + events: list[str] = [] + + async def fake_get_updater_config(): + return SimpleNamespace( + archive_format="zip", + sources=["github"], + repo="AgentFlocks/Flocks", + token=None, + gitee_token=None, + backup_retain_count=3, + base_url=None, + gitee_repo=None, + ) + + async def fake_download(**_kw): + return archive_path + + async def fake_run_async(cmd, cwd=None, timeout=None, env=None): + if "sync" in cmd: + raise subprocess.TimeoutExpired(cmd=cmd, timeout=timeout or 0) + return 0, "", "" + + monkeypatch.setattr(updater.sys, "platform", "win32") + monkeypatch.setattr(updater, "_get_updater_config", fake_get_updater_config) + monkeypatch.setattr(updater, "_get_repo_root", lambda: tmp_path / "install-root") + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.3.31") + monkeypatch.setattr(updater, "_download_with_fallback", fake_download) + monkeypatch.setattr(updater, "_backup_current_version", lambda *_a, **_kw: tmp_path / "backup.tar.gz") + monkeypatch.setattr(updater, "_extract_archive", lambda *_a, **_kw: staged_root) + monkeypatch.setattr(updater, "_run_async", fake_run_async) + monkeypatch.setattr( + updater, + "_find_executable", + lambda name: "/usr/bin/npm" if name in {"npm", "npm.cmd"} else r"C:\tools\uv.exe", + ) + monkeypatch.setattr(updater, "_build_uv_sync_env", lambda: None) + monkeypatch.setattr(updater, "_replace_install_dir", 
lambda *_a, **_kw: None) + monkeypatch.setattr(updater, "_restore_backup_if_possible", lambda *_a: events.append("restore")) + + progresses = [step async for step in updater.perform_update("2026.4.1", restart=False)] + + assert progresses[-1].stage == "error" + expected_timeout = updater._dependency_sync_timeout_seconds() + assert progresses[-1].message == ( + f"Dependency sync timed out after {expected_timeout}s while running uv sync." + ) + assert events == ["restore"] + + @pytest.mark.asyncio async def test_perform_update_fails_after_uv_sync_retry_exhausted( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/workflow/test_fs_store.py b/tests/workflow/test_fs_store.py index 3750a42d..80994b14 100644 --- a/tests/workflow/test_fs_store.py +++ b/tests/workflow/test_fs_store.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import os from pathlib import Path import pytest @@ -50,3 +51,41 @@ def test_read_workflow_from_fs_refreshes_cached_workspace_root( assert first["workflowJson"]["name"] == "workspace-a" assert second["workflowJson"]["name"] == "workspace-b" assert fs_store.find_workspace_root() == second_workspace + + +def test_read_workflow_dir_uses_latest_file_mtime_when_meta_is_stale( + tmp_path: Path, +): + workspace = tmp_path / "workspace" + workflow_id = "mtime-sync-demo" + _write_workflow(workspace, workflow_id, "mtime-demo") + workflow_dir = workspace / ".flocks" / "plugins" / "workflows" / workflow_id + + meta_file = workflow_dir / "meta.json" + meta_file.write_text( + json.dumps( + { + "name": "mtime-demo", + "description": "demo", + "category": "default", + "status": "draft", + "createdBy": None, + "createdAt": 1000, + "updatedAt": 1000, + } + ), + encoding="utf-8", + ) + md_file = workflow_dir / "workflow.md" + md_file.write_text("# demo\n", encoding="utf-8") + + json_file = workflow_dir / "workflow.json" + os.utime(meta_file, (1, 1)) + os.utime(json_file, (5, 5)) + os.utime(md_file, (9, 9)) + + data = fs_store.read_workflow_dir(workflow_dir, workflow_id, "project") + + assert data is not None + assert data["updatedAt"] == 9000 + assert data["markdownContent"] == "# demo\n" diff --git a/tests/workflow/test_tool_run_workflow.py b/tests/workflow/test_tool_run_workflow.py index 5fa07ed8..7b37ef16 100644 --- a/tests/workflow/test_tool_run_workflow.py +++ b/tests/workflow/test_tool_run_workflow.py @@ -10,7 +10,7 @@ """ import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import AsyncMock, Mock, patch, MagicMock from typing import Dict, Any # Import the tool system @@ -263,12 +263,79 @@ async def test_run_workflow_success(self, tool_context_with_permission, simple_w assert "SUCCEEDED" in result.output assert "run-123" in result.output assert "Steps executed: 1" in result.output - assert result.metadata["status"] == "SUCCEEDED" + assert result.metadata["status"] == "success" assert result.metadata["steps"] == 1 assert result.metadata["run_id"] == "run-123" # Check that permission was requested assert len(tool_context_with_permission._permissions_requested) > 0 + + @pytest.mark.anyio + async def test_run_workflow_registered_id_updates_execution_history( + self, + tool_context_with_permission, + simple_workflow, + ): + metadata_updates: list[dict[str, Any]] = [] + tool_context_with_permission._metadata_callback = metadata_updates.append + + def run_side_effect(**kwargs): + kwargs["on_step_start"]("run-registered", 1, MagicMock(id="node-1", type="python"), {}) + kwargs["on_step_complete"]({ + "node_id": "node-1", + "node_type": "python", 
+ "outputs": {"message": "ok"}, + }) + return FakeRunWorkflowResult( + status="SUCCEEDED", + run_id="run-registered", + steps=1, + last_node_id="node-1", + outputs={"message": "ok"}, + history=[{"node_id": "node-1", "node_type": "python", "outputs": {"message": "ok"}}], + error=None, + ) + + mock_run = Mock(name="run_workflow", side_effect=run_side_effect) + create_execution = AsyncMock(return_value={ + "id": "exec-registered", + "workflowId": "test-workflow-001", + "inputParams": {"name": "Flocks"}, + "status": "running", + "startedAt": 1, + "executionLog": [], + }) + storage_read = AsyncMock(return_value={ + "id": "exec-registered", + "workflowId": "test-workflow-001", + "inputParams": {"name": "Flocks"}, + "status": "running", + "startedAt": 1, + "executionLog": [], + }) + storage_write = AsyncMock(return_value=None) + record_result = AsyncMock(return_value=None) + + with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), \ + patch.object(run_workflow_module, "read_workflow_from_fs", return_value={"id": "test-workflow-001", "workflowJson": simple_workflow}), \ + patch.object(run_workflow_module, "resolve_workflow_id_from_source", return_value="test-workflow-001"), \ + patch.object(run_workflow_module, "create_execution_record", create_execution), \ + patch.object(run_workflow_module.Storage, "read", storage_read), \ + patch.object(run_workflow_module.Storage, "write", storage_write), \ + patch.object(run_workflow_module, "record_execution_result", record_result): + result = await ToolRegistry.execute( + "run_workflow", + ctx=tool_context_with_permission, + workflow="test-workflow-001", + inputs={"name": "Flocks"}, + ) + + assert result.success is True + assert result.metadata["workflow_execution_id"] == "exec-registered" + create_execution.assert_awaited_once() + record_result.assert_awaited_once() + assert storage_write.await_count >= 1 + assert any(update.get("workflow_execution_id") == "exec-registered" for update in metadata_updates) @pytest.mark.anyio async def test_run_workflow_with_inputs(self, tool_context_with_permission, workflow_with_inputs): @@ -447,7 +514,7 @@ async def test_run_workflow_failed_status(self, tool_context_with_permission, si assert result.success is False assert "FAILED" in result.output - assert result.metadata["status"] == "FAILED" + assert result.metadata["status"] == "error" # ============================================================================= diff --git a/tui/flocks/cli/cmd/tui/component/dialog-skill-install.tsx b/tui/flocks/cli/cmd/tui/component/dialog-skill-install.tsx index a4a7f37e..dbd146b8 100644 --- a/tui/flocks/cli/cmd/tui/component/dialog-skill-install.tsx +++ b/tui/flocks/cli/cmd/tui/component/dialog-skill-install.tsx @@ -25,7 +25,7 @@ export function DialogSkillInstall() { setIsError(false) try { - const res = await fetch(`${sdk.url}/skill/install`, { + const res = await sdk.fetch(`${sdk.url}/skill/install`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ source: src, scope: scope() }), diff --git a/tui/flocks/cli/cmd/tui/component/dialog-skill.tsx b/tui/flocks/cli/cmd/tui/component/dialog-skill.tsx index 549bbd35..99cdd02b 100644 --- a/tui/flocks/cli/cmd/tui/component/dialog-skill.tsx +++ b/tui/flocks/cli/cmd/tui/component/dialog-skill.tsx @@ -26,7 +26,7 @@ export function DialogSkill() { setError(null) // Use status endpoint to get eligibility info try { - const res = await fetch(`${sdk.url}/skill/status`) + const res = await 
sdk.fetch(`${sdk.url}/skill/status`)
       if (res.ok) {
         const data = await res.json()
         setSkills(Array.isArray(data) ? data : [])
diff --git a/tui/flocks/cli/cmd/tui/context/sdk.tsx b/tui/flocks/cli/cmd/tui/context/sdk.tsx
index 7e947655..ef49fc95 100644
--- a/tui/flocks/cli/cmd/tui/context/sdk.tsx
+++ b/tui/flocks/cli/cmd/tui/context/sdk.tsx
@@ -1,4 +1,4 @@
-import { createFlocksClient, type Event } from "@flocks-ai/sdk/v2"
+import { createFlocksClient, getFlocksAuthHeaders, type Event } from "@flocks-ai/sdk/v2"
 import { createSimpleContext } from "./helper"
 import { createGlobalEmitter } from "@solid-primitives/event-bus"
 import { batch, onCleanup, onMount } from "solid-js"
@@ -18,6 +18,11 @@ export const { use: useSDK, provider: SDKProvider } = createSimpleContext({
       fetch: props.fetch,
     })
 
+    const authenticatedFetch: typeof fetch = (input, init = {}) => {
+      const headers = getFlocksAuthHeaders(init.headers)
+      return (props.fetch ?? fetch)(input, { ...init, headers })
+    }
+
     const emitter = createGlobalEmitter<{
       [key in Event["type"]]: Extract<Event, { type: key }>
     }>()
@@ -91,6 +96,6 @@
       if (timer) clearTimeout(timer)
     })
 
-    return { client: sdk, event: emitter, url: props.url }
+    return { client: sdk, event: emitter, url: props.url, fetch: authenticatedFetch }
   },
 })
diff --git a/tui/sdk/v2/client.ts b/tui/sdk/v2/client.ts
index 4581e644..a19da37b 100644
--- a/tui/sdk/v2/client.ts
+++ b/tui/sdk/v2/client.ts
@@ -37,8 +37,8 @@ function getStoredApiToken(): string | undefined {
   }
 }
 
-function withAuthHeaders(config?: Config & { directory?: string }) {
-  const headers = new Headers(config?.headers as HeadersInit | undefined)
+export function getFlocksAuthHeaders(headersInit?: HeadersInit) {
+  const headers = new Headers(headersInit)
   const apiToken = getStoredApiToken()
   const hasAuth = headers.has("authorization") || headers.has("x-flocks-api-token")
   if (apiToken && !hasAuth) {
@@ -60,7 +60,7 @@ export function createFlocksClient(config?: Config & { directory?: string }) {
     }
   }
 
-  const headers = withAuthHeaders(config)
+  const headers = getFlocksAuthHeaders(config?.headers as HeadersInit | undefined)
 
   if (config?.directory) {
     const isNonASCII = /[^\x00-\x7F]/.test(config.directory)
diff --git a/uv.lock b/uv.lock
index ae679487..e5ad83bf 100644
--- a/uv.lock
+++ b/uv.lock
@@ -496,7 +496,7 @@ wheels = [
 
 [[package]]
 name = "flocks"
-version = "2026.5.7"
+version = "2026.5.9"
 source = { editable = "."
} dependencies = [ { name = "aiofiles" }, diff --git a/webui/package-lock.json b/webui/package-lock.json index 232a16ce..628f6e97 100644 --- a/webui/package-lock.json +++ b/webui/package-lock.json @@ -1381,9 +1381,6 @@ "arm" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1398,9 +1395,6 @@ "arm" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1415,9 +1409,6 @@ "arm64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1432,9 +1423,6 @@ "arm64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1449,9 +1437,6 @@ "loong64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1466,9 +1451,6 @@ "loong64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1483,9 +1465,6 @@ "ppc64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1500,9 +1479,6 @@ "ppc64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1517,9 +1493,6 @@ "riscv64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1534,9 +1507,6 @@ "riscv64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ @@ -1551,9 +1521,6 @@ "s390x" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1568,9 +1535,6 @@ "x64" ], "dev": true, - "libc": [ - "glibc" - ], "license": "MIT", "optional": true, "os": [ @@ -1585,9 +1549,6 @@ "x64" ], "dev": true, - "libc": [ - "musl" - ], "license": "MIT", "optional": true, "os": [ diff --git a/webui/src/api/stats.test.ts b/webui/src/api/stats.test.ts new file mode 100644 index 00000000..c9947417 --- /dev/null +++ b/webui/src/api/stats.test.ts @@ -0,0 +1,65 @@ +import { describe, expect, it, vi, beforeEach } from 'vitest'; + +const mockGet = vi.fn(); + +vi.mock('./client', () => ({ + apiClient: { get: (...args: unknown[]) => mockGet(...args) }, +})); + +// Helper: build a default mock for every endpoint except /api/skills. +function defaultMock(skillsData: unknown[]) { + mockGet.mockImplementation((url: string) => { + if (url === '/api/skills') return Promise.resolve({ data: skillsData }); + if (url === '/api/task-system/dashboard') return Promise.resolve({ data: {} }); + if (url === '/api/agent') return Promise.resolve({ data: [] }); + if (url === '/api/workflow') return Promise.resolve({ data: [] }); + if (url === '/api/tools') return Promise.resolve({ data: [] }); + if (url === '/api/provider') return Promise.resolve({ data: { all: [], connected: [] } }); + if (url === '/api/health') return Promise.resolve({ data: { status: 'healthy' } }); + return Promise.resolve({ data: [] }); + }); +} + +describe('statsApi.getSystemStats', () => { + beforeEach(() => vi.clearAllMocks()); + + it('counts only non-system skills', async () => { + defaultMock([ + { category: 'custom' }, + { category: 'system' }, + { category: 'system' }, + { category: 'search' }, + ]); + + const { statsApi } = await import('./stats'); + const result = await statsApi.getSystemStats(); + + // 4 skills total, 2 are 'system' — only 2 should be counted. 
+ expect(result.skills.total).toBe(2); + }); + + it('handles an all-system skill list gracefully (returns 0)', async () => { + defaultMock([{ category: 'system' }, { category: 'system' }]); + + const { statsApi } = await import('./stats'); + const result = await statsApi.getSystemStats(); + expect(result.skills.total).toBe(0); + }); + + it('handles skills API failure gracefully (returns 0)', async () => { + mockGet.mockImplementation((url: string) => { + if (url === '/api/skills') return Promise.reject(new Error('network')); + if (url === '/api/task-system/dashboard') return Promise.resolve({ data: {} }); + if (url === '/api/agent') return Promise.resolve({ data: [] }); + if (url === '/api/workflow') return Promise.resolve({ data: [] }); + if (url === '/api/tools') return Promise.resolve({ data: [] }); + if (url === '/api/provider') return Promise.resolve({ data: { all: [], connected: [] } }); + if (url === '/api/health') return Promise.resolve({ data: { status: 'healthy' } }); + return Promise.resolve({ data: [] }); + }); + + const { statsApi } = await import('./stats'); + const result = await statsApi.getSystemStats(); + expect(result.skills.total).toBe(0); + }); +}); diff --git a/webui/src/api/stats.ts b/webui/src/api/stats.ts index 37f88357..3b69e726 100644 --- a/webui/src/api/stats.ts +++ b/webui/src/api/stats.ts @@ -42,7 +42,11 @@ export const statsApi = { const dash = taskDash.data || {}; const agentList = Array.isArray(agents.data) ? agents.data : []; const workflowList = Array.isArray(workflows.data) ? workflows.data : []; - const skillList = Array.isArray(skills.data) ? skills.data : []; + // Exclude `system` category skills so the count matches the Skills page, + // which hides system skills (e.g. find-skills, onboarding) from the user. + const skillList = (Array.isArray(skills.data) ? skills.data : []).filter( + (s: any) => s?.category !== 'system' + ); const toolList = Array.isArray(tools.data) ? tools.data : []; const providerData = providers.data ?? {}; const providerAll: any[] = providerData.all ?? (Array.isArray(providers.data) ? 
providers.data : []);
diff --git a/webui/src/api/task.ts b/webui/src/api/task.ts
index 6873e566..18ef4fec 100644
--- a/webui/src/api/task.ts
+++ b/webui/src/api/task.ts
@@ -134,6 +134,7 @@ export interface TaskUpdateParams {
   description?: string;
   priority?: TaskPriority;
   tags?: string[];
+  context?: Record<string, unknown>;
   executionMode?: ExecutionMode;
   agentName?: string;
   workflowID?: string;
diff --git a/webui/src/api/workflow.ts b/webui/src/api/workflow.ts
index d0fa1940..8899c2aa 100644
--- a/webui/src/api/workflow.ts
+++ b/webui/src/api/workflow.ts
@@ -120,6 +120,10 @@ export interface WorkflowExecution {
   duration?: number;
   executionLog: WorkflowExecutionStep[];
   errorMessage?: string;
+  currentNodeId?: string;
+  currentNodeType?: string;
+  currentPhase?: string;
+  currentStepIndex?: number;
 }
 
 export interface WorkflowNodeExecution {
@@ -156,6 +160,7 @@ export const workflowAPI = {
     category?: string;
     workflowJson: WorkflowJSON;
     createdBy?: string;
+    source?: 'project' | 'global';
   }) =>
     client.post('/api/workflow', data),
diff --git a/webui/src/components/common/ChatDialog.tsx b/webui/src/components/common/ChatDialog.tsx
index 685ef2fc..89608690 100644
--- a/webui/src/components/common/ChatDialog.tsx
+++ b/webui/src/components/common/ChatDialog.tsx
@@ -4,11 +4,12 @@
  * Creates a session via useSessionChat; SessionChat renders the conversation and supports follow-ups.
  * The session is created only when the user sends the first message, so empty sessions never pollute the session list.
  */
-import { useEffect, useCallback } from 'react';
+import { useEffect } from 'react';
 import { X, Sparkles } from 'lucide-react';
 import { useTranslation } from 'react-i18next';
 import SessionChat from './SessionChat';
 import { useSessionChat } from '@/hooks/useSessionChat';
+import { useDefaultModelVision } from '@/hooks/useDefaultModelVision';
 
 interface ChatDialogProps {
   open: boolean;
@@ -35,23 +36,18 @@ export default function ChatDialog({
   width = 'max-w-2xl',
 }: ChatDialogProps) {
   const { t } = useTranslation('common');
+  const supportsVision = useDefaultModelVision();
   const { sessionId, createAndSend, reset } = useSessionChat({
     title,
   });
 
   useEffect(() => {
     if (open && initialPrompt) {
-      createAndSend(initialPrompt).catch(() => {});
+      createAndSend({ text: initialPrompt }).catch(() => {});
     }
     if (!open) reset();
   }, [open, reset, initialPrompt, createAndSend]);
 
-  const handleCreateAndSend = useCallback(
-    async (text: string) => {
-      await createAndSend(text);
-    },
-    [createAndSend],
-  );
 
   if (!open) return null;
@@ -87,7 +83,8 @@ export default function ChatDialog({
           className="flex-1 min-h-0 rounded-b-xl"
           emptyText={t('chat.starting')}
           suggestions={suggestions}
-          onCreateAndSend={!sessionId ? handleCreateAndSend : undefined}
+          supportsVision={supportsVision}
+          onCreateAndSend={!sessionId ? (text, imageParts) => createAndSend({ text, imageParts }) : undefined}
           welcomeContent={!sessionId ? (
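
Both ChatDialog above and EntitySheet below now gate image input on `useDefaultModelVision()`, a hook that is not part of this diff. As a minimal sketch only, the tri-state contract SessionChat expects (true allows images, false blocks them with a UI warning, null means unknown and is treated as allow) could be satisfied like this; the hook's implementation, the `/api/provider` endpoint, and the `defaultModel.vision` payload field are assumptions rather than confirmed API:

```ts
// Hypothetical sketch of @/hooks/useDefaultModelVision; the real hook is not shown in this diff.
import { useEffect, useState } from 'react';
import client from '@/api/client';

export function useDefaultModelVision(): boolean | null {
  // null = capability unknown (still loading or fetch failed); SessionChat treats it as "allow".
  const [supportsVision, setSupportsVision] = useState<boolean | null>(null);

  useEffect(() => {
    let cancelled = false;
    client
      .get('/api/provider') // endpoint and response shape are assumed for illustration
      .then(({ data }) => {
        if (!cancelled) setSupportsVision(Boolean(data?.defaultModel?.vision));
      })
      .catch(() => {
        // Leave the capability unknown on failure rather than blocking uploads outright.
        if (!cancelled) setSupportsVision(null);
      });
    return () => {
      cancelled = true;
    };
  }, []);

  return supportsVision;
}
```
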
diff --git a/webui/src/components/common/CopyButton.tsx b/webui/src/components/common/CopyButton.tsx
index ff93712a..368ddcd7 100644
--- a/webui/src/components/common/CopyButton.tsx
+++ b/webui/src/components/common/CopyButton.tsx
@@ -8,9 +8,16 @@ interface CopyButtonProps {
   text: string;
   /** Icon size class, e.g. "w-3 h-3" or "w-3.5 h-3.5". Defaults to "w-3.5 h-3.5". */
   size?: string;
+  label?: string;
+  className?: string;
 }
 
-export default function CopyButton({ text, size = 'w-3.5 h-3.5' }: CopyButtonProps) {
+export default function CopyButton({
+  text,
+  size = 'w-3.5 h-3.5',
+  label,
+  className,
+}: CopyButtonProps) {
   const { t } = useTranslation('common');
   const [copied, setCopied] = useState(false);
   const toast = useToast();
@@ -31,12 +38,16 @@ export default function CopyButton({ text, size = 'w-3.5 h-3.5' }: CopyButtonPro
   return (
-    <button
-      type="button"
-      onClick={handleCopy}
-      className="p-1 rounded text-muted-foreground hover:text-foreground transition-colors"
-      title={copied ? t('copied') : t('copy')}
-    >
-      {copied ? <Check className={size} /> : <Copy className={size} />}
-    </button>
+    <button
+      type="button"
+      onClick={handleCopy}
+      className={className ?? 'p-1 rounded text-muted-foreground hover:text-foreground transition-colors'}
+      title={copied ? t('copied') : t('copy')}
+    >
+      {copied ? <Check className={size} /> : <Copy className={size} />}
+      {label && <span>{label}</span>}
+    </button>
   );
 }
diff --git a/webui/src/components/common/EntitySheet.tsx b/webui/src/components/common/EntitySheet.tsx
index 68233428..9868f723 100644
--- a/webui/src/components/common/EntitySheet.tsx
+++ b/webui/src/components/common/EntitySheet.tsx
@@ -30,7 +30,7 @@ import { useTranslation } from 'react-i18next';
 import client from '@/api/client';
 import SessionChat from './SessionChat';
 import { useSessionChat } from '@/hooks/useSessionChat';
-
+import { useDefaultModelVision } from '@/hooks/useDefaultModelVision';
 // ─── Context ──────────────────────────────────────────────────────────────────
 
 interface EntitySheetCtx {
@@ -133,6 +133,7 @@ export default function EntitySheet({
   footerLeft,
 }: EntitySheetProps) {
   const { t } = useTranslation('common');
+  const supportsVision = useDefaultModelVision();
   const showTabs = !(hideRex && hideTest);
   const hasFormTab = !hideForm;
   const title =
@@ -260,7 +261,8 @@
           parts: [{ type: 'text', text: msg }],
         });
       } else if (msg) {
-        createAndSendRex(msg).catch(() => {});
+        createAndSendRex({ text: msg }).catch(() => {});
       }
     },
     [sessionId, createAndSendRex],
@@ -491,7 +492,8 @@
               className="flex-1"
               emptyText={t('entity.rexReady')}
               initialMessage={rexInitialMessage}
-              onCreateAndSend={!sessionId ? async (text: string) => { await createAndSendRex(text); } : undefined}
+              supportsVision={supportsVision}
+              onCreateAndSend={!sessionId ? (text, imageParts) => createAndSendRex({ text, imageParts }) : undefined}
               welcomeContent={!sessionId ? (
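
Both dialogs now call `createAndSend` with an object payload instead of a bare string, matching the `(text, imageParts)` shape that SessionChat hands to `onCreateAndSend`. The `useSessionChat` hook itself is outside this diff; a hedged sketch of the payload it would need to accept (the type names here are hypothetical):

```ts
// Hypothetical payload for useSessionChat().createAndSend; not taken from this diff.
import type { ImagePartData } from '@/utils/imageUpload';

export interface CreateAndSendInput {
  text: string;
  /** Inline data-URL images; callers without image support simply omit this. */
  imageParts?: ImagePartData[];
}

// Per the SessionChatProps comment further down, createAndSend resolves to the
// id of the newly created session, which is why onCreateAndSend accepts it directly.
export type CreateAndSend = (input: CreateAndSendInput) => Promise<string | undefined>;
```
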
diff --git a/webui/src/components/common/ImageLightbox.tsx b/webui/src/components/common/ImageLightbox.tsx
new file mode 100644
index 00000000..5a16971d
--- /dev/null
+++ b/webui/src/components/common/ImageLightbox.tsx
@@ -0,0 +1,103 @@
+/**
+ * ImageLightbox — full-screen overlay for previewing chat images.
+ *
+ * Why a custom lightbox instead of `window.open(url, '_blank')`:
+ * the image URLs flowing through the chat composer are inline base64
+ * `data:` URLs. Modern browsers (Chrome / Edge / Firefox) block top-level
+ * navigation to `data:` URLs for phishing-protection reasons, so opening
+ * one in a new tab silently produces a blank page. Rendering the image
+ * inside a same-origin overlay sidesteps the restriction and matches the
+ * mental model the user expects ("click to enlarge in place").
+ */
+
+import { useEffect, useRef } from 'react';
+import { useTranslation } from 'react-i18next';
+import { X } from 'lucide-react';
+
+interface ImageLightboxProps {
+  /** Image source URL — supports both `data:` URLs and remote http(s) URLs. */
+  src: string;
+  /** Optional filename / alt text shown for screen readers. */
+  alt?: string;
+  /** Called when the user dismisses the lightbox (Escape, click outside, ✕). */
+  onClose: () => void;
+}
+
+// Module-level reference count + saved baseline so two concurrently-mounted
+// lightboxes don't fight over ``body.style.overflow``. Without this, the
+// first lightbox to unmount would prematurely restore the baseline while
+// the second is still open, and the second's restore would then "lock in"
+// the previous (already-hidden) value.
+let scrollLockCount = 0;
+let savedBodyOverflow = '';
+
+function acquireScrollLock(): void {
+  if (scrollLockCount === 0) {
+    savedBodyOverflow = document.body.style.overflow;
+    document.body.style.overflow = 'hidden';
+  }
+  scrollLockCount += 1;
+}
+
+function releaseScrollLock(): void {
+  scrollLockCount = Math.max(0, scrollLockCount - 1);
+  if (scrollLockCount === 0) {
+    document.body.style.overflow = savedBodyOverflow;
+  }
+}
+
+export default function ImageLightbox({ src, alt, onClose }: ImageLightboxProps) {
+  const { t } = useTranslation('common');
+  const closeButtonRef = useRef<HTMLButtonElement | null>(null);
+
+  // Close on Escape so the overlay behaves like a normal modal.
+  useEffect(() => {
+    const handler = (e: KeyboardEvent) => {
+      if (e.key === 'Escape') onClose();
+    };
+    window.addEventListener('keydown', handler);
+    return () => window.removeEventListener('keydown', handler);
+  }, [onClose]);
+
+  // Lock the body scroll while the lightbox is mounted.
+  useEffect(() => {
+    acquireScrollLock();
+    return releaseScrollLock;
+  }, []);
+
+  // Move focus to the close button so keyboard users can dismiss immediately.
+  useEffect(() => {
+    closeButtonRef.current?.focus();
+  }, []);
+
+  const label = alt || t('image.preview');
+
+  return (
+    <div
+      className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 p-4"
+      role="dialog"
+      aria-modal="true"
+      aria-label={label}
+      onClick={onClose}
+    >
+      <button
+        ref={closeButtonRef}
+        type="button"
+        className="absolute top-4 right-4 rounded-full p-2 text-white/80 hover:text-white hover:bg-white/10"
+        onClick={onClose}
+        aria-label={label}
+      >
+        <X className="w-5 h-5" />
+      </button>
+      <img
+        src={src}
+        alt={alt}
+        className="max-h-full max-w-full object-contain"
+        onClick={(e) => e.stopPropagation()}
+      />
+    </div>
+  );
+}
diff --git a/webui/src/components/common/SessionChat.test.ts b/webui/src/components/common/SessionChat.test.ts
index 257c1539..bc06d3ac 100644
--- a/webui/src/components/common/SessionChat.test.ts
+++ b/webui/src/components/common/SessionChat.test.ts
@@ -1,6 +1,19 @@
 import { describe, expect, it } from 'vitest';
 
-import { getMessageBubbleClassName } from './SessionChat';
+import type { Message } from '@/types';
+
+import { getMessageBubbleClassName, getRegenerateTruncateTarget } from './SessionChat';
+
+function makeMessage(overrides: Partial<Message> & { id: string }): Message {
+  return {
+    id: overrides.id,
+    sessionID: 'sess-1',
+    role: 'assistant',
+    parts: [],
+    timestamp: 0,
+    ...overrides,
+  } as Message;
+}
 
 describe('getMessageBubbleClassName', () => {
   it('keeps non-editing user bubbles auto-sized in full layout', () => {
@@ -34,3 +47,23 @@
     expect(className).toContain('max-w-2xl w-full');
   });
 });
+
+describe('getRegenerateTruncateTarget', () => {
+  it('truncates back to the parent user message for assistant regenerations', () => {
+    const target = getRegenerateTruncateTarget([
+      makeMessage({ id: 'user-1', role: 'user' }),
+      makeMessage({ id: 'assistant-1', role: 'assistant', parentID: 'user-1' }),
+      makeMessage({ id: 'assistant-2', role: 'assistant', parentID: 'user-1' }),
+    ], 'assistant-2');
+
+    expect(target).toEqual({ messageId: 'user-1' });
+  });
+
+  it('falls back to removing the target message when parent linkage is unavailable', () => {
+    const target = getRegenerateTruncateTarget([
+      makeMessage({ id: 'assistant-1', role: 'assistant' }),
+    ], 'assistant-1');
+
+    expect(target).toEqual({ messageId: 'assistant-1', includeTarget: true });
+  });
+});
diff --git a/webui/src/components/common/SessionChat.tsx b/webui/src/components/common/SessionChat.tsx
index efc48688..517beb82 100644
--- a/webui/src/components/common/SessionChat.tsx
+++ b/webui/src/components/common/SessionChat.tsx
@@ -17,7 +17,7 @@
  */
 
 import { useState, useCallback, useRef, useEffect, useMemo, memo } from 'react';
-import { Send, Loader2, ChevronDown, Square, Copy, User, Plus, FileText, AlertCircle, X, RefreshCw, Pencil, Save } from 'lucide-react';
+import { Send, Loader2, ChevronDown, Square, Copy, User, Plus, FileText, AlertCircle, X, RefreshCw, Pencil, Save, ImageIcon } from 'lucide-react';
 import { StreamingMarkdown } from './StreamingMarkdown';
 import { useTranslation } from 'react-i18next';
 import LoadingSpinner from './LoadingSpinner';
@@ -25,6 +25,7 @@ import { useToast } from './Toast';
 import { QuestionTool } from './QuestionTool';
 import DelegateTaskCard, { isDelegateTool, shouldRenderDelegateTaskCard } from './DelegateTaskCard';
 import CommandDropdown, { parseSlashCommand } from './CommandDropdown';
+import ImageLightbox from './ImageLightbox';
 import { useSessionMessages } from '@/hooks/useSessions';
 import { useSSE, type SSEConnectionStatus } from '@/hooks/useSSE';
 import { useReasoningToggle } from '@/hooks/useReasoningToggle';
@@ -35,6 +36,16 @@ import { commandAPI, type Command } from '@/api/skill';
 import { workspaceAPI } from '@/api/workspace';
 import { copyText } from '@/utils/clipboard';
 import { formatSmartTime } from '@/utils/time';
+import {
+  FILE_INPUT_ACCEPT_IMAGES,
+  batchCompressOptions,
+  buildPromptParts,
+  compressImageFile,
+  getFileExtension,
+  isImageFile,
+  readFileAsDataUrl,
+  type ImagePartData,
+} from '@/utils/imageUpload';
 import type { Message, MessagePart, ToolState } from '@/types';
 
 export { formatSmartTime };
 
@@ -107,9 +118,23 @@ export interface SessionChatProps {
   onError?: (message: string) => void;
   /**
    * Called when the user sends a message but sessionId is not yet available.
-   * The parent should create a session and update sessionId + initialMessage props.
+   * The parent should create a session and dispatch the prompt (with the
+   * provided text and any image attachments) to the new session.
+   *
+   * `imageParts` carries inline image data URLs — parents that don't yet
+   * support image input can ignore the second argument.
+   *
+   * The return value is intentionally typed as ``unknown`` so callers can
+   * pass ``useSessionChat().createAndSend`` (which resolves to the new
+   * session id) directly without an empty ``async (..) => { await ... }``
+   * shim.
+   */
+  onCreateAndSend?: (text: string, imageParts?: ImagePartData[]) => Promise<unknown> | unknown;
+  /**
+   * Whether the current model supports vision/image analysis.
+   * true = allow images; false = block images with a UI warning; null/undefined = allow (unknown).
    */
-  onCreateAndSend?: (text: string) => Promise<void> | void;
+  supportsVision?: boolean | null;
 }
 
 type AttachmentStatus = 'uploading' | 'success' | 'error';
@@ -119,10 +144,21 @@ interface ComposerAttachment {
   file: File;
   name: string;
   status: AttachmentStatus;
+  /** For document attachments: the workspace-relative path after upload */
   workspacePath?: string;
+  /** For image attachments: the base64 data URL (no server upload needed) */
+  dataUrl?: string;
+  /** True if this attachment is an image file */
+  isImage?: boolean;
   error?: string;
 }
 
+// Composer drafts are persisted to ``localStorage`` so navigating away from
+// the page (e.g. clicking the sidebar to open Agents / Workflows) and coming
+// back doesn't lose the half-typed message. Keyed per session so two sessions
+// don't share a draft, and namespaced to avoid colliding with other features.
+import { readChatDraft, writeChatDraft } from '@/utils/chatDraft';
+
 // Backend stages emitted by ``SessionCompaction.process`` /
 // ``summarize_chunked`` via the ``session.compaction_progress`` SSE event.
 // Keep in sync with ``flocks/session/lifecycle/compaction/{compaction,summary}.py``.
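Note: `readChatDraft` / `writeChatDraft` come from `@/utils/chatDraft`, which this diff never shows. A minimal sketch of what that helper plausibly looks like, consistent with the comment above; the key prefix and exact bodies are assumptions, not the shipped implementation:

// chatDraft.ts: illustrative sketch only, not part of this change.
const DRAFT_KEY_PREFIX = 'flocks:chat-draft:'; // assumed namespace

function draftKey(sessionId?: string): string {
  // One slot per session; drafts typed before a session exists share 'new'.
  return `${DRAFT_KEY_PREFIX}${sessionId ?? 'new'}`;
}

export function readChatDraft(sessionId?: string): string {
  try {
    return localStorage.getItem(draftKey(sessionId)) ?? '';
  } catch {
    return ''; // localStorage unavailable (private mode, storage disabled)
  }
}

export function writeChatDraft(sessionId: string | undefined, input: string): void {
  try {
    if (input) localStorage.setItem(draftKey(sessionId), input);
    else localStorage.removeItem(draftKey(sessionId)); // empty input clears the draft
  } catch {
    // Best-effort persistence; ignore quota and availability errors.
  }
}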
@@ -267,6 +303,17 @@ export function getMessageBubbleClassName({
   }`;
 }
 
+export function getRegenerateTruncateTarget(
+  messages: Message[],
+  messageId: string,
+): { messageId: string; includeTarget?: boolean } {
+  const targetMessage = messages.find((message) => message.id === messageId);
+  if (targetMessage?.role === 'assistant' && targetMessage.parentID) {
+    return { messageId: targetMessage.parentID };
+  }
+  return { messageId, includeTarget: true };
+}
+
 // ============================================================================
 // Main component
 // ============================================================================
@@ -275,18 +322,13 @@
 const ABORT_SSE_SETTLE_DELAY = 2000;
 const SCROLL_BOTTOM_THRESHOLD_PX = 80;
 const FALLBACK_POLL_MS = 5_000;
 const WORKSPACE_UPLOAD_DEST = 'uploads';
-const FILE_INPUT_ACCEPT = '.txt,.md,.json,.yaml,.yml,.xml,.csv,.pdf,.doc,.docx,.html,.htm,.ppt,.pptx,.xls,.xlsx';
+const FILE_INPUT_ACCEPT_DOCS = '.txt,.md,.json,.yaml,.yml,.xml,.csv,.pdf,.doc,.docx,.html,.htm,.ppt,.pptx,.xls,.xlsx';
+const FILE_INPUT_ACCEPT_ALL = `${FILE_INPUT_ACCEPT_DOCS},${FILE_INPUT_ACCEPT_IMAGES}`;
 const ALLOWED_UPLOAD_EXTENSIONS = new Set([
   'txt', 'md', 'json', 'yaml', 'yml', 'xml', 'csv', 'pdf', 'doc', 'docx',
   'html', 'htm', 'ppt', 'pptx', 'xls', 'xlsx',
 ]);
 
-function getFileExtension(filename: string): string {
-  const normalized = filename.toLowerCase();
-  const idx = normalized.lastIndexOf('.');
-  return idx >= 0 ? normalized.slice(idx + 1) : '';
-}
-
 function isAllowedUploadFile(file: File): boolean {
   return ALLOWED_UPLOAD_EXTENSIONS.has(getFileExtension(file.name));
 }
@@ -311,6 +353,7 @@ export default function SessionChat({
   onError,
   onCreateAndSend,
   onInitialMessageConsumed,
+  supportsVision,
 }: SessionChatProps) {
   const { t } = useTranslation('session');
   const { t: tCommon } = useTranslation('common');
@@ -320,11 +363,18 @@
   const showTimestamp = display?.showTimestamp ?? false;
   const effectivePlaceholder = placeholder ?? t('chat.placeholder');
   const effectiveEmptyText = emptyText ?? t('chat.emptyText');
-  const [input, setInput] = useState('');
+  // Restore any persisted draft on first mount so navigating away (e.g.
+  // sidebar → Agents → back to Sessions) doesn't wipe the user's half-typed
+  // message. Subsequent session changes are re-hydrated by the effect below.
+  const [input, setInput] = useState(() => readChatDraft(sessionId));
   const [sending, setSending] = useState(false);
   const [isStreaming, setIsStreaming] = useState(false);
   const [attachments, setAttachments] = useState<ComposerAttachment[]>([]);
   const [isDragOver, setIsDragOver] = useState(false);
+  // Lightbox preview for composer thumbnails. Shares the same overlay
+  // component used by message bubbles so the click-to-enlarge gesture is
+  // consistent across the upload tray and the rendered chat history.
+ const [composerPreview, setComposerPreview] = useState<{ url: string; alt?: string } | null>(null); const [isCompacting, setIsCompacting] = useState(false); const [compactingMessage, setCompactingMessage] = useState(''); // Live compaction progress, populated by ``session.compaction_progress`` SSE @@ -416,12 +466,22 @@ export default function SessionChat({ const [commandQuery, setCommandQuery] = useState(''); const [selectedCommandIndex, setSelectedCommandIndex] = useState(0); const commandsLoadedRef = useRef(false); - const successfulAttachments = useMemo( - () => attachments.filter((attachment) => attachment.status === 'success' && attachment.workspacePath), + const successfulDocAttachments = useMemo( + () => attachments.filter((a) => a.status === 'success' && a.workspacePath && !a.isImage), [attachments], ); + const successfulImageAttachments = useMemo( + () => attachments.filter((a) => a.status === 'success' && a.isImage && a.dataUrl), + [attachments], + ); + // Keep backward-compat alias (used in slash-command guard) + const successfulAttachments = useMemo( + () => [...successfulDocAttachments, ...successfulImageAttachments], + [successfulDocAttachments, successfulImageAttachments], + ); const hasUploadingFiles = attachments.some((attachment) => attachment.status === 'uploading'); - const canSend = !sending && !isStreaming && !hasUploadingFiles && (!!input.trim() || successfulAttachments.length > 0); + const canSend = !sending && !isStreaming && !hasUploadingFiles && + (!!input.trim() || successfulDocAttachments.length > 0 || successfulImageAttachments.length > 0); const scrollToBottom = useCallback(() => { if (!isAtBottomRef.current) return; @@ -637,8 +697,20 @@ export default function SessionChat({ statusCheckedRef.current = null; isAtBottomRef.current = true; clearPendingQuestions(); + // Swap the draft when the session changes — needed for callers that + // don't force a remount (Session/index.tsx does, but other consumers + // such as WorkflowDetail/ChatTab may swap sessionId without a remount). + setInput(readChatDraft(sessionId)); }, [sessionId, clearPendingQuestions]); + // Persist the draft on every keystroke. localStorage writes are synchronous + // and cheap, so debouncing isn't worth the added latency on send (which + // depends on the draft being flushed). Drafts are removed when ``input`` + // becomes empty (e.g. after a successful send). 
+ useEffect(() => { + writeChatDraft(sessionId, input); + }, [sessionId, input]); + // Recover streaming state after page refresh / session switch useEffect(() => { if (!sessionId || loading) return; @@ -780,13 +852,29 @@ export default function SessionChat({ } }, [t]); - const queueFilesForUpload = useCallback((files: File[]) => { + const queueFilesForUpload = useCallback((files: File[], { imageBlocked = false }: { imageBlocked?: boolean } = {}) => { if (files.length === 0) return; - const validEntries: Array<{ id: string; file: File }> = []; + const validDocEntries: Array<{ id: string; file: File }> = []; + const validImageFiles: Array<{ id: string; file: File }> = []; const invalidAttachments: ComposerAttachment[] = []; + let imageRejectedToastShown = false; files.forEach((file, index) => { const id = `attachment-${Date.now()}-${index}-${Math.random().toString(36).slice(2, 8)}`; + + if (isImageFile(file)) { + if (imageBlocked || supportsVision === false) { + // Show a toast once for the whole batch of rejected images + if (!imageRejectedToastShown) { + imageRejectedToastShown = true; + toast.error(t('chat.upload.imageNotSupported')); + } + } else { + validImageFiles.push({ id, file }); + } + return; + } + if (!isAllowedUploadFile(file)) { invalidAttachments.push({ id, @@ -797,27 +885,63 @@ export default function SessionChat({ }); return; } - validEntries.push({ id, file }); + validDocEntries.push({ id, file }); }); if (invalidAttachments.length > 0) { setAttachments((prev) => [...prev, ...invalidAttachments]); } - if (validEntries.length === 0) return; - - setAttachments((prev) => [ - ...prev, - ...validEntries.map(({ id, file }) => ({ - id, - file, - name: file.name, - status: 'uploading' as const, - })), - ]); + // Handle document uploads (server upload) + if (validDocEntries.length > 0) { + setAttachments((prev) => [ + ...prev, + ...validDocEntries.map(({ id, file }) => ({ + id, + file, + name: file.name, + status: 'uploading' as const, + })), + ]); + void uploadSelectedFiles(validDocEntries); + } - void uploadSelectedFiles(validEntries); - }, [t, uploadSelectedFiles]); + // Handle image files (read as base64, no server upload) + if (validImageFiles.length > 0) { + setAttachments((prev) => [ + ...prev, + ...validImageFiles.map(({ id, file }) => ({ + id, + file, + name: file.name, + status: 'uploading' as const, + isImage: true, + })), + ]); + // Pick compression aggressiveness from how many images are arriving + // together. A 4-image drop gets a tighter cap than a single image so + // the combined base64 body still fits inside upstream gateway limits. + const batchOpts = batchCompressOptions(validImageFiles.length); + validImageFiles.forEach(({ id, file }) => { + compressImageFile(file, batchOpts) + .then((compressed) => readFileAsDataUrl(compressed).then((dataUrl) => ({ compressed, dataUrl }))) + .then(({ compressed, dataUrl }) => { + setAttachments((prev) => prev.map((a) => + a.id === id + ? { ...a, file: compressed, name: compressed.name, status: 'success' as const, dataUrl, isImage: true } + : a + )); + }) + .catch(() => { + setAttachments((prev) => prev.map((a) => + a.id === id + ? 
{ ...a, status: 'error' as const, error: t('chat.upload.errorGeneric') } + : a + )); + }); + }); + } + }, [t, toast, uploadSelectedFiles, supportsVision]); const handleFileSelection = useCallback((fileList: FileList | null) => { if (!fileList || fileList.length === 0) return; @@ -832,8 +956,27 @@ export default function SessionChat({ status: 'uploading', error: undefined, })); - void uploadSelectedFiles([{ id: attachment.id, file: attachment.file }]); - }, [attachments, updateAttachment, uploadSelectedFiles]); + if (attachment.isImage) { + compressImageFile(attachment.file) + .then((compressed) => readFileAsDataUrl(compressed).then((dataUrl) => ({ compressed, dataUrl }))) + .then(({ compressed, dataUrl }) => { + setAttachments((prev) => prev.map((a) => + a.id === attachmentId + ? { ...a, file: compressed, name: compressed.name, status: 'success' as const, dataUrl, error: undefined } + : a + )); + }) + .catch(() => { + setAttachments((prev) => prev.map((a) => + a.id === attachmentId + ? { ...a, status: 'error' as const, error: t('chat.upload.errorGeneric') } + : a + )); + }); + } else { + void uploadSelectedFiles([{ id: attachment.id, file: attachment.file }]); + } + }, [attachments, updateAttachment, uploadSelectedFiles, t]); const handleRemoveAttachment = useCallback((attachmentId: string) => { setAttachments((prev) => prev.filter((attachment) => attachment.id !== attachmentId)); @@ -846,6 +989,7 @@ export default function SessionChat({ queueFilesForUpload(files); }, [queueFilesForUpload]); + const handleComposerDragOver = useCallback((event: React.DragEvent) => { if (!Array.from(event.dataTransfer?.types ?? []).includes('Files')) return; event.preventDefault(); @@ -912,7 +1056,7 @@ export default function SessionChat({ }; /** Core send logic */ - const sendText = async (text: string) => { + const sendText = async (text: string, imageParts: ImagePartData[] = []) => { if (!sessionId) return; // Clear abort state immediately so SSE events for the new stream are not suppressed abortingRef.current = false; @@ -922,17 +1066,23 @@ export default function SessionChat({ setIsStreaming(true); const tempId = `temp-${Date.now()}`; + const tempParts: MessagePart[] = []; + if (text) tempParts.push({ id: `${tempId}-text`, type: 'text', text }); + imageParts.forEach((img, i) => { + tempParts.push({ id: `${tempId}-img-${i}`, type: 'file', url: img.url, mime: img.mime, filename: img.filename }); + }); + addMessage({ id: tempId, sessionID: sessionId, role: 'user', - parts: [{ id: `${tempId}-part`, type: 'text', text }], + parts: tempParts.length > 0 ? 
tempParts : [{ id: `${tempId}-part`, type: 'text', text }],
       timestamp: Date.now(),
     } as Message);
 
     try {
       const payload: Record<string, unknown> = {
-        parts: [{ type: 'text', text }],
+        parts: buildPromptParts(text, imageParts),
       };
       if (agentName) payload.agent = agentName;
@@ -954,15 +1104,25 @@
   const handleSend = async () => {
     if (!canSend) return;
     const rawText = input.trim();
-    const attachmentsToSend = [...successfulAttachments];
-    const text = buildMessageText(rawText, attachmentsToSend);
-    if (!text) return;
+    const docAttachmentsToSend = [...successfulDocAttachments];
+    const imageAttachmentsToSend = [...successfulImageAttachments];
+    const text = buildMessageText(rawText, docAttachmentsToSend);
+
+    // Need either text content or image attachments
+    if (!text && imageAttachmentsToSend.length === 0) return;
     setInput('');
     setShowCommandDropdown(false);
 
-    // Route slash commands through the command API (requires an active session)
-    const parsed = attachmentsToSend.length === 0 ? parseSlashCommand(rawText) : null;
+    const imageParts: ImagePartData[] = imageAttachmentsToSend.map((a) => ({
+      url: a.dataUrl!,
+      mime: a.file.type,
+      filename: a.name,
+    }));
+
+    // Route slash commands through the command API (requires an active session, no images)
+    const parsed = docAttachmentsToSend.length === 0 && imageAttachmentsToSend.length === 0
+      ? parseSlashCommand(rawText) : null;
     if (parsed) {
       if (!sessionId) {
         // Slash commands need an existing session; restore input and do nothing
@@ -981,10 +1141,14 @@
     if (onCreateAndSend) {
       setSending(true);
       try {
-        await onCreateAndSend(text);
+        await onCreateAndSend(text, imageParts);
         setAttachments([]);
       } catch {
+        // Restore both the text and the attachment list so the user can
+        // retry without re-uploading images. Image data URLs are already
+        // in memory, so restoring the array is safe and cheap.
         setInput(rawText);
+        setAttachments(imageAttachmentsToSend);
       } finally {
         setSending(false);
       }
@@ -993,10 +1157,11 @@
     }
 
     try {
-      await sendText(text);
+      await sendText(text, imageParts);
       setAttachments([]);
     } catch {
       setInput(rawText);
+      setAttachments(imageAttachmentsToSend);
     }
   };
@@ -1216,7 +1381,11 @@
     setActionMessageId(messageId);
     try {
       await sessionApi.regenerateMessage(sessionId, messageId);
-      truncateAfterMessage(messageId, { includeTarget: true });
+      const truncateTarget = getRegenerateTruncateTarget(messagesRef.current, messageId);
+      truncateAfterMessage(
+        truncateTarget.messageId,
+        truncateTarget.includeTarget ? { includeTarget: true } : undefined,
+      );
       setIsStreaming(true);
       if (editingMessageId === messageId) {
         resetEditingState();
@@ -1448,7 +1617,7 @@
               ref={fileInputRef}
               type="file"
               className="hidden"
-              accept={FILE_INPUT_ACCEPT}
+              accept={FILE_INPUT_ACCEPT_ALL}
               multiple
               onChange={(event) => {
                 handleFileSelection(event.target.files);
@@ -1459,7 +1628,7 @@
               type="button"
               onClick={() => fileInputRef.current?.click()}
               disabled={sending || isStreaming}
-              title={t('chat.upload.select')}
+              title={t('chat.upload.selectWithImage')}
              className={`flex-shrink-0 rounded-lg border border-gray-300 bg-white text-gray-600 hover:bg-gray-50 hover:text-gray-900 disabled:opacity-40 disabled:cursor-not-allowed transition-colors ${
               compact ?
'w-10 h-[40px]' : 'w-12 h-[52px] rounded-xl' } inline-flex items-center justify-center`} @@ -1517,6 +1686,43 @@ export default function SessionChat({ const isUploading = attachment.status === 'uploading'; const isError = attachment.status === 'error'; const attachmentPath = attachment.workspacePath ?? null; + + // Image thumbnail display + if (attachment.isImage && attachment.dataUrl && !isError) { + return ( +
+ {isUploading ? ( +
+ +
+ ) : ( + {attachment.name} + setComposerPreview({ url: attachment.dataUrl!, alt: attachment.name }) + } + /> + )} + +
+ ); + } + return (
) : isError ? ( + ) : attachment.isImage ? ( + ) : ( )} @@ -1544,7 +1752,7 @@ export default function SessionChat({
{attachment.error}
)}
- {isError && ( + {isError && !attachment.isImage && (
)} + {composerPreview && ( + setComposerPreview(null)} + /> + )} ); } @@ -1696,6 +1911,11 @@ function ChatMessageBubbleInner({ const isUser = message.role === 'user'; const parts: MessagePart[] = Array.isArray(message.parts) ? message.parts : []; const { getPartExpanded, togglePart, isReasoningDone } = useReasoningToggle(parts, message.finish); + // Lightbox state for inline image previews. Browsers block top-level + // navigation to ``data:`` URLs (the format we send for chat images), so a + // ``window.open`` would land on a blank page. We open an in-app overlay + // instead — same UX, no popup blocker / data-URL restriction headaches. + const [previewImage, setPreviewImage] = useState<{ url: string; alt?: string } | null>(null); if (message.finish === 'summary') { const hasArchived = compactedMessages && compactedMessages.length > 0; return ( @@ -1802,32 +2022,67 @@ function ChatMessageBubbleInner({ /> ) : ( - parts.map((part: MessagePart, i: number) => ( -
- {/* Text */} - {part.type === 'text' && part.text && (() => { - const nodeRefMatch = isUser - ? part.text.match(/^@@node:([^|\n]+)\|([^\n]+)\n([\s\S]*)$/) - : null; - const displayText = nodeRefMatch ? nodeRefMatch[3] : part.text; - return ( - <> - {nodeRefMatch && ( -
- - {nodeRefMatch[1]} - {nodeRefMatch[2]} -
- )} - - - ); - })()} + (() => { + // Render attachments (file/image parts) first so the bubble shows + // image previews above the textual prompt — matches typical chat + // UX for "look at this image and …" style messages. + const fileParts = parts.filter((p) => p.type === 'file' && p.url); + const otherParts = parts.filter((p) => !(p.type === 'file' && p.url)); + return ( + <> + {fileParts.length > 0 && ( +
+ {fileParts.map((part, i) => { + const isImage = (part.mime || '').startsWith('image/'); + if (isImage && part.url) { + return ( + {part.filename setPreviewImage({ url: part.url!, alt: part.filename })} + /> + ); + } + return ( +
+ + {part.filename || 'file'} +
+ ); + })} +
+ )} + {otherParts.map((part: MessagePart, i: number) => ( +
+ {/* Text */} + {part.type === 'text' && part.text && (() => { + const nodeRefMatch = isUser + ? part.text.match(/^@@node:([^|\n]+)\|([^\n]+)\n([\s\S]*)$/) + : null; + const displayText = nodeRefMatch ? nodeRefMatch[3] : part.text; + return ( + <> + {nodeRefMatch && ( +
+ + {nodeRefMatch[1]} + {nodeRefMatch[2]} +
+ )} + + + ); + })()} - {/* Tool call */} + {/* Tool call */} {part.type === 'tool' && ( ); })()} -
- )) +
+          ))}
+          </>
+        );
+      })()
       )}
 
       {/* Streaming indicator */}
@@ -1986,6 +2244,13 @@
         )}
+      {previewImage && (
+        <ImageLightbox
+          url={previewImage.url}
+          alt={previewImage.alt}
+          onClose={() => setPreviewImage(null)}
+        />
+      )}
   );
 }
diff --git a/webui/src/hooks/useDefaultModelVision.test.ts b/webui/src/hooks/useDefaultModelVision.test.ts
new file mode 100644
index 00000000..d61ee8a3
--- /dev/null
+++ b/webui/src/hooks/useDefaultModelVision.test.ts
@@ -0,0 +1,126 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+
+// ── module under test ──────────────────────────────────────────────────────
+// We use dynamic imports so vi.mock() is hoisted correctly.
+vi.mock('@/api/provider', () => ({
+  defaultModelAPI: { getResolved: vi.fn() },
+  modelV2API: { getDefinition: vi.fn() },
+}));
+
+import { defaultModelAPI, modelV2API } from '@/api/provider';
+import { __resetVisionCacheForTesting, MODEL_CHANGED_EVENT, useDefaultModelVision } from './useDefaultModelVision';
+import { renderHook, act, waitFor } from '@testing-library/react';
+
+const mockResolved = defaultModelAPI.getResolved as ReturnType<typeof vi.fn>;
+const mockDefinition = modelV2API.getDefinition as ReturnType<typeof vi.fn>;
+
+function makeResolvedResp(provider_id = 'openai', model_id = 'gpt-4o') {
+  return { data: { provider_id, model_id } };
+}
+
+function makeDefResp(caps: Record<string, unknown>, fetchFrom: 'predefined' | 'customizable' = 'customizable') {
+  return { data: { fetch_from: fetchFrom, capabilities: caps } };
+}
+
+describe('useDefaultModelVision', () => {
+  beforeEach(() => {
+    __resetVisionCacheForTesting();
+    vi.clearAllMocks();
+  });
+
+  afterEach(() => {
+    __resetVisionCacheForTesting();
+  });
+
+  it('returns null initially then true for a vision model', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: true }));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    expect(result.current).toBeNull();
+
+    await waitFor(() => expect(result.current).toBe(true));
+  });
+
+  it('returns false for a non-vision customizable model', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: false }));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBe(false));
+  });
+
+  it('returns false for a predefined (built-in) model even when it declares vision support', async () => {
+    // Built-in models must explicitly reject image uploads (not return
+    // null / unknown) so SessionChat shows the "model does not support
+    // images" toast instead of silently letting them through.
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: true }, 'predefined'));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBe(false));
+  });
+
+  it('returns null when capabilities are absent', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue({ data: {} });
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBeNull());
+  });
+
+  it('module-level cache: API called only once for multiple concurrent hooks', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: true }));
+
+    renderHook(() => useDefaultModelVision());
+    renderHook(() => useDefaultModelVision());
+    renderHook(() => useDefaultModelVision());
+
+    await waitFor(() => expect(mockResolved).toHaveBeenCalledTimes(1));
+    expect(mockDefinition).toHaveBeenCalledTimes(1);
+  });
+
+  it('MODEL_CHANGED_EVENT invalidates cache and notifies subscribers', async () => {
+    // First resolve: non-vision
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: false }));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBe(false));
+
+    // Change to a vision model and dispatch the event
+    mockDefinition.mockResolvedValue(makeDefResp({ supports_vision: true }));
+
+    act(() => {
+      window.dispatchEvent(new CustomEvent(MODEL_CHANGED_EVENT));
+    });
+
+    await waitFor(() => expect(result.current).toBe(true));
+    // After invalidation, the API was called a second time
+    expect(mockDefinition).toHaveBeenCalledTimes(2);
+  });
+
+  it('detects vision via modalities.input', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ modalities: { input: ['text', 'image'] } }));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBe(true));
+  });
+
+  it('detects vision via features array', async () => {
+    mockResolved.mockResolvedValue(makeResolvedResp());
+    mockDefinition.mockResolvedValue(makeDefResp({ features: ['vision', 'tools'] }));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBe(true));
+  });
+
+  it('returns null on API error', async () => {
+    mockResolved.mockRejectedValue(new Error('network error'));
+
+    const { result } = renderHook(() => useDefaultModelVision());
+    await waitFor(() => expect(result.current).toBeNull());
+  });
+});
diff --git a/webui/src/hooks/useDefaultModelVision.ts b/webui/src/hooks/useDefaultModelVision.ts
new file mode 100644
index 00000000..cd0c0a62
--- /dev/null
+++ b/webui/src/hooks/useDefaultModelVision.ts
@@ -0,0 +1,124 @@
+import { useEffect, useState } from 'react';
+import { defaultModelAPI, modelV2API } from '@/api/provider';
+
+/**
+ * Detect whether the resolved default LLM model supports image (vision) input.
+ *
+ * Returns:
+ *   - `true`  — model is multimodal / supports images
+ *   - `false` — model explicitly does not support images (UI should block image
+ *     uploads with a warning)
+ *   - `null`  — unknown / unable to determine (UI should allow uploads as a
+ *     best-effort fallback)
+ *
+ * Centralised so every place that hosts a chat composer (Session, Agent /
+ * Workflow / Skill / Tool creation drawers, generic ChatDialog, etc.) gets
+ * the same logic and the same UX. Without this, only the Session page
+ * showed the "current model does not support images" hint, while uploading
+ * an image in the other composers would silently fail (or send through to
+ * a non-vision model).
+ *
+ * Caching:
+ *   The resolved capability is cached at module scope so each newly mounted
+ *   composer (sidebar drawer, dialog, etc.) reuses the in-flight or
+ *   completed lookup instead of firing a fresh `getResolved + getDefinition`
+ *   pair. The cache is invalidated when ``MODEL_CHANGED_EVENT`` fires —
+ *   pages that change the default model (see ``Model/index.tsx``) dispatch
+ *   that event after a successful update so this hook re-resolves.
+ */
+
+/** Window event other code can dispatch to invalidate the cached vision capability. */
+export const MODEL_CHANGED_EVENT = 'flocks:default-model-changed';
+
+type VisionState = boolean | null;
+
+let cachedPromise: Promise<VisionState> | null = null;
+const subscribers = new Set<(state: VisionState) => void>();
+
+async function resolveVisionSupport(): Promise<VisionState> {
+  try {
+    const resolvedResp = await defaultModelAPI.getResolved();
+    const { provider_id, model_id } = resolvedResp.data;
+    if (!provider_id || !model_id) return null;
+    const defResp = await modelV2API.getDefinition(provider_id, model_id);
+    const def: any = defResp.data;
+    if (!def) return null;
+    // Predefined (catalog/SDK) models are treated as **explicitly non-vision**
+    // (return false, not null) so the chat composer actively *rejects* image
+    // uploads with the "model does not support images" hint. Returning null
+    // would fall through to the best-effort "allow upload" branch in
+    // SessionChat (`supportsVision === false` is the rejection trigger), which
+    // is exactly the bug we're avoiding here. Vision is only unlocked by
+    // user-added (customizable) models that have explicitly enabled it.
+    if (def.fetch_from !== 'customizable') return false;
+    const caps = def.capabilities;
+    if (!caps) return null;
+    if (
+      caps.supports_vision === true ||
+      caps.modalities?.input?.includes('image') ||
+      (caps.features ?? []).includes('vision')
+    ) {
+      return true;
+    }
+    if (caps.supports_vision === false) {
+      return false;
+    }
+    return null;
+  } catch {
+    return null;
+  }
+}
+
+function getVisionPromise(): Promise<VisionState> {
+  if (cachedPromise === null) {
+    cachedPromise = resolveVisionSupport();
+  }
+  return cachedPromise;
+}
+
+function invalidateAndRefetch(): void {
+  // Capture the new promise locally so a *second* rapid invalidation that
+  // races ahead of this one cannot deliver a stale value to subscribers.
+  // We only notify if our promise is still the current cached one by the
+  // time it resolves.
+  const next = resolveVisionSupport();
+  cachedPromise = next;
+  next.then((value) => {
+    if (cachedPromise === next) {
+      subscribers.forEach((cb) => cb(value));
+    }
+  });
+}
+
+if (typeof window !== 'undefined') {
+  window.addEventListener(MODEL_CHANGED_EVENT, invalidateAndRefetch);
+}
+
+/**
+ * Test-only escape hatch: clear the module-level cache and subscriber set.
+ * Vitest runs all specs in the same module instance, so without this a
+ * stubbed API response from one test could leak into the next. Production
+ * code should never call this.
+ */
+export function __resetVisionCacheForTesting(): void {
+  cachedPromise = null;
+  subscribers.clear();
+}
+
+export function useDefaultModelVision(): VisionState {
+  const [supportsVision, setSupportsVision] = useState<VisionState>(null);
+
+  useEffect(() => {
+    let cancelled = false;
+    getVisionPromise().then((next) => {
+      if (!cancelled) setSupportsVision(next);
+    });
+    subscribers.add(setSupportsVision);
+    return () => {
+      cancelled = true;
+      subscribers.delete(setSupportsVision);
+    };
+  }, []);
+
+  return supportsVision;
+}
diff --git a/webui/src/hooks/useSessionChat.test.ts b/webui/src/hooks/useSessionChat.test.ts
new file mode 100644
index 00000000..674e1899
--- /dev/null
+++ b/webui/src/hooks/useSessionChat.test.ts
@@ -0,0 +1,123 @@
+/**
+ * Regression tests for "non-Session entry first message with images".
+ *
+ * These cover the chain:
+ *   onCreateAndSend(text, imageParts) → useSessionChat.createAndSend({ text, imageParts })
+ *   → /api/session/{id}/prompt_async with parts[]
+ *
+ * The key regression being guarded: before the fix, imageParts were silently
+ * dropped when the first message was sent through non-Session chat composers
+ * (CreateAgentChat, WorkflowCreate/CreateChatTab, WorkflowDetail/ChatTab,
+ * EntitySheet, ChatDialog). Now createAndSend forwards them into the payload.
+ */
+import { describe, expect, it, vi, beforeEach } from 'vitest';
+
+const mockPost = vi.fn();
+vi.mock('@/api/client', () => ({
+  default: { post: (...args: unknown[]) => mockPost(...args) },
+}));
+
+import { renderHook, act } from '@testing-library/react';
+import { useSessionChat } from './useSessionChat';
+import type { ImagePartData } from '@/utils/imageUpload';
+
+const SESSION_ID = 'sess-abc';
+
+beforeEach(() => {
+  vi.clearAllMocks();
+  // /api/session creates a new session
+  mockPost.mockImplementation((url: string) => {
+    if (url === '/api/session') return Promise.resolve({ data: { id: SESSION_ID } });
+    return Promise.resolve({ data: {} });
+  });
+});
+
+describe('useSessionChat.createAndSend — image forwarding', () => {
+  it('includes imageParts in the prompt_async payload', async () => {
+    const { result } = renderHook(() =>
+      useSessionChat({ title: 'Test', autoCreate: false }),
+    );
+
+    const img: ImagePartData = {
+      url: 'data:image/png;base64,abc',
+      mime: 'image/png',
+      filename: 'screenshot.png',
+    };
+
+    await act(async () => {
+      await result.current.createAndSend({ text: 'describe this', imageParts: [img] });
+    });
+
+    // Find the prompt_async call
+    const promptCall = mockPost.mock.calls.find(([url]: string[]) =>
+      url === `/api/session/${SESSION_ID}/prompt_async`,
+    );
+    expect(promptCall).toBeDefined();
+
+    const payload = promptCall![1] as { parts: unknown[] };
+    expect(payload.parts).toEqual([
+      { type: 'text', text: 'describe this' },
+      { type: 'file', url: img.url, mime: img.mime, filename: img.filename },
+    ]);
+  });
+
+  it('works for image-only messages (no text)', async () => {
+    const { result } = renderHook(() =>
+      useSessionChat({ title: 'Test', autoCreate: false }),
+    );
+
+    const img: ImagePartData = {
+      url: 'data:image/jpeg;base64,xyz',
+      mime: 'image/jpeg',
+      filename: 'photo.jpg',
+    };
+
+    await act(async () => {
+      await result.current.createAndSend({ text: '', imageParts: [img] });
+    });
+
+    const promptCall = mockPost.mock.calls.find(([url]: string[]) =>
+      url === `/api/session/${SESSION_ID}/prompt_async`,
+    );
+    expect(promptCall).toBeDefined();
+
+    const payload = promptCall![1] as { parts: unknown[] };
+    // No text part when text is empty; only the file part.
+    expect(payload.parts).toEqual([
+      { type: 'file', url: img.url, mime: img.mime, filename: img.filename },
+    ]);
+  });
+
+  it('works for text-only messages (backward compat — no imageParts arg)', async () => {
+    const { result } = renderHook(() =>
+      useSessionChat({ title: 'Test', autoCreate: false }),
+    );
+
+    await act(async () => {
+      await result.current.createAndSend({ text: 'hello' });
+    });
+
+    const promptCall = mockPost.mock.calls.find(([url]: string[]) =>
+      url === `/api/session/${SESSION_ID}/prompt_async`,
+    );
+    expect(promptCall).toBeDefined();
+
+    const payload = promptCall![1] as { parts: unknown[] };
+    expect(payload.parts).toEqual([{ type: 'text', text: 'hello' }]);
+  });
+
+  it('forwards the agent field when provided', async () => {
+    const { result } = renderHook(() =>
+      useSessionChat({ title: 'Test', autoCreate: false }),
+    );
+
+    await act(async () => {
+      await result.current.createAndSend({ text: 'hi', agent: 'my-agent' });
+    });
+
+    const promptCall = mockPost.mock.calls.find(([url]: string[]) =>
+      url === `/api/session/${SESSION_ID}/prompt_async`,
+    );
+    expect(promptCall![1]).toMatchObject({ agent: 'my-agent' });
+  });
+});
diff --git a/webui/src/hooks/useSessionChat.ts b/webui/src/hooks/useSessionChat.ts
index 0c309e09..f5e005fc 100644
--- a/webui/src/hooks/useSessionChat.ts
+++ b/webui/src/hooks/useSessionChat.ts
@@ -1,5 +1,6 @@
 import { useState, useCallback, useRef, useEffect } from 'react';
 import client from '@/api/client';
+import { buildPromptParts, type ImagePartData } from '@/utils/imageUpload';
 
 export interface UseSessionChatOptions {
   title: string;
@@ -12,6 +13,13 @@
   autoCreate?: boolean;
 }
 
+/** Options accepted by {@link useSessionChat} `createAndSend`. */
+export interface CreateAndSendOptions {
+  text: string;
+  imageParts?: ImagePartData[];
+  agent?: string;
+}
+
 export function useSessionChat({
   title,
   category,
@@ -93,10 +101,14 @@
   }, []);
 
   const createAndSend = useCallback(
-    async (text: string, agent?: string): Promise<string> => {
+    async ({
+      text,
+      imageParts,
+      agent,
+    }: CreateAndSendOptions): Promise<string> => {
       const sid = await create();
       const payload: Record<string, unknown> = {
-        parts: [{ type: 'text', text }],
+        parts: buildPromptParts(text, imageParts),
       };
       if (agent) payload.agent = agent;
       client.post(`/api/session/${sid}/prompt_async`, payload).catch(() => {});
diff --git a/webui/src/hooks/useSessions.test.ts b/webui/src/hooks/useSessions.test.ts
index f37d5180..906ca89d 100644
--- a/webui/src/hooks/useSessions.test.ts
+++ b/webui/src/hooks/useSessions.test.ts
@@ -1,6 +1,7 @@
 import { describe, expect, it, vi, afterEach } from 'vitest';
 import { renderHook, act } from '@testing-library/react';
 import { applyMessagePartUpdate, useSessionMessages } from './useSessions';
+import client from '@/api/client';
 import type { Message } from '@/types';
 
 // ---------------------------------------------------------------------------
@@ -139,6 +140,28 @@ describe('updateMessagePart scheduling', () => {
     vi.clearAllMocks();
   });
 
+  it('keeps parentID from fetched messages for regenerate truncation', async () => {
+    vi.mocked(client.get).mockResolvedValueOnce({
+      data: [{
+        info: {
+          id: 'msg-2',
+          sessionID: 'sess-1',
+          role: 'assistant',
+          parentID: 'msg-1',
+          time: { created: 123 },
+        },
+        parts: [],
+      }],
+    } as any);
+
+    const { result } = renderHook(() => useSessionMessages('sess-1'));
+
+    await act(async () => {});
+
+    expect(result.current.messages).toHaveLength(1);
+    expect(result.current.messages[0].parentID).toBe('msg-1');
+  });
+
   it('first appearance of a new part updates messages state immediately', async () => {
     const { result } = renderHook(() => useSessionMessages('sess-1'));
     // Wait for the initial fetchMessages effect to settle so it doesn't wipe state
diff --git a/webui/src/hooks/useSessions.ts b/webui/src/hooks/useSessions.ts
index 70af390d..478748de 100644
--- a/webui/src/hooks/useSessions.ts
+++ b/webui/src/hooks/useSessions.ts
@@ -177,6 +177,7 @@
         sessionID: msg.info.sessionID,
         role: msg.info.role,
         parts: msg.parts || [],
+        parentID: msg.info.parentID,
         agent: msg.info.agent,
         model: msg.info.model,
         timestamp: msg.info.time?.created || Date.now(),
@@ -226,6 +227,7 @@
         updated[existingIndex] = {
           ...existing,
           ...messageInfo,
+          parentID: messageInfo.parentID ??
existing.parentID, timestamp: messageInfo.time?.created || existing.timestamp, // Preserve compacted/finish from the authoritative refetch data — // SSE events never carry these fields, so a naive spread would @@ -273,6 +275,7 @@ export function useSessionMessages(sessionId?: string) { sessionID: messageInfo.sessionID, role: messageInfo.role, parts: [], + parentID: messageInfo.parentID, agent: messageInfo.agent, model: messageInfo.model, timestamp: messageInfo.time?.created || Date.now(), diff --git a/webui/src/locales/en-US/common.json b/webui/src/locales/en-US/common.json index 7dcefe1b..9737a9f1 100644 --- a/webui/src/locales/en-US/common.json +++ b/webui/src/locales/en-US/common.json @@ -1,4 +1,7 @@ { + "image": { + "preview": "Image Preview" + }, "button": { "save": "Save", "saving": "Saving...", diff --git a/webui/src/locales/en-US/model.json b/webui/src/locales/en-US/model.json index c2480436..fbd1d228 100644 --- a/webui/src/locales/en-US/model.json +++ b/webui/src/locales/en-US/model.json @@ -100,6 +100,7 @@ "capabilities": "Capabilities", "toolCall": "Tool Calling", "vision": "Vision", + "visionPredefinedHint": "Vision capability for built-in models is set by the model definition and cannot be changed manually. Add a custom model to override it.", "streaming": "Streaming", "reasoning": "Reasoning", "pricing": "Pricing (per 1M tokens)", @@ -119,7 +120,16 @@ "loadFailed": "Failed to load provider catalog", "noModelsToTest": "No enabled models to test", "batchTestDone": "Batch test complete", - "batchTestSummary": "{{success}} succeeded, {{failed}} failed" + "batchTestSummary": "{{success}} succeeded, {{failed}} failed", + "azureDeploymentName": "Azure Deployment Name", + "azureDeploymentPlaceholder": "e.g. my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI requests use the deployment name, not a fixed model name. The preset models are examples; enter your own deployment name here.", + "azureDeploymentDisplayName": "Display Name (optional)", + "azureDeploymentDisplayPlaceholder": "e.g. GPT-4o Production", + "azureDeploymentRequired": "Select at least one preset model or enter an Azure deployment name", + "azureModelIdHint": "For Azure OpenAI, Model ID should be the deployment name from Azure Portal.", + "azureCustomDeployments": "Custom Azure Deployments", + "azureNoCustomDeployments": "No custom Azure deployment has been added yet." }, "wizard": { "providerSaved": "Provider Saved", diff --git a/webui/src/locales/en-US/session.json b/webui/src/locales/en-US/session.json index b7db2854..88a419c6 100644 --- a/webui/src/locales/en-US/session.json +++ b/webui/src/locales/en-US/session.json @@ -89,12 +89,15 @@ }, "upload": { "select": "Upload documents", + "selectWithImage": "Upload files or images", "remove": "Remove file", "retry": "Retry upload", "dropHint": "Drop document files to upload", "waiting": "Files are still uploading. Please wait before sending.", "invalidType": "Only txt, md, json, yaml, yml, xml, csv, pdf, doc, docx, html, htm, ppt, pptx, xls, and xlsx files are supported", - "errorGeneric": "Upload failed. Please try again." + "errorGeneric": "Upload failed. Please try again.", + "imageNotSupported": "The current model does not support image analysis. 
Please select a vision-capable model in the model settings.", + "imageNotSupportedBanner": "Image analysis is not supported by the current model" }, "tool": { "pending": "Pending", diff --git a/webui/src/locales/en-US/task.json b/webui/src/locales/en-US/task.json index 61dd327d..a0829d48 100644 --- a/webui/src/locales/en-US/task.json +++ b/webui/src/locales/en-US/task.json @@ -181,6 +181,10 @@ "editTitle": "Edit Scheduled Task", "createTitle": "Create Task", "agentName": "Agent Name", + "workflowParamsLabel": "Workflow Params", + "workflowParamsHint": "(JSON object passed as workflow inputs)", + "workflowParamsPlaceholder": "{\n \"keyword\": \"example\"\n}", + "workflowParamsInvalid": "Workflow params must be a valid JSON object", "urgentLabel": "Urgent", "highLabel": "High", "normalLabel": "Normal", diff --git a/webui/src/locales/en-US/tool.json b/webui/src/locales/en-US/tool.json index ba19fac3..c2499306 100644 --- a/webui/src/locales/en-US/tool.json +++ b/webui/src/locales/en-US/tool.json @@ -109,6 +109,26 @@ "oneArgPerLine": "One argument per line", "serviceUrl": "Service URL", "serviceUrlPlaceholder": "e.g. http://localhost:3000/sse", + "transport": "Transport", + "transportAuto": "Auto detect (HTTP -> SSE)", + "transportSSE": "SSE only", + "transportHTTP": "Streamable HTTP only", + "authMethod": "Auth Method", + "authNone": "No auth", + "authBearer": "Bearer Token", + "authHeader": "Custom header", + "authQuery": "Query parameter", + "authToken": "Token", + "authTokenPlaceholder": "Enter token or {secret:secret_id}", + "authHeaderName": "Header Name", + "authHeaderValue": "Header Value", + "authHeaderValuePlaceholder": "Enter header value or {secret:secret_id}", + "authQueryName": "Query Parameter Name", + "authQueryValue": "Query Parameter Value", + "authQueryValuePlaceholder": "Enter query value or {secret:secret_id}", + "extraHeaders": "Extra Headers", + "extraHeadersPlaceholder": "{\n \"X-Client\": \"flocks\"\n}", + "extraHeadersHint": "Must be a JSON object. 
Prefer the auth fields above for sensitive credentials.", "hintTitle": "Not sure how to fill in?", "hintDesc": "You can switch to the \"Chat Assistant\" tab, tell Rex what MCP service you want to connect, and it will automatically complete the configuration.", "chatIntegration": "Chat Integration", @@ -420,6 +440,7 @@ "addFailed": "Add failed: {{error}}", "mcpNameRequired": "Please enter MCP service name", "apiKeyRequired": "Please enter API Key", + "invalidHeaders": "Extra headers must be a valid JSON object", "credSaved": "Credentials saved", "saveFailed": "Save failed: {{error}}", "testFailed": "Test failed: {{error}}", diff --git a/webui/src/locales/en-US/workflow.json b/webui/src/locales/en-US/workflow.json index f53a49c0..6bffeb0c 100644 --- a/webui/src/locales/en-US/workflow.json +++ b/webui/src/locales/en-US/workflow.json @@ -2,6 +2,10 @@ "pageTitle": "Workflows", "pageDescription": "Manage and execute workflows", "createWorkflow": "Create Workflow", + "section": { + "custom": "Custom Workflows", + "builtin": "Built-in Workflows" + }, "emptyState": { "title": "No Workflows", "description": "Create your first workflow" @@ -39,7 +43,16 @@ }, "topBar": { "collapsePanel": "Collapse Panel", - "expandPanel": "Expand Panel" + "expandPanel": "Expand Panel", + "runningStage": "Current stage: {{phase}} · Node: {{node}}", + "phase": { + "queued": "Queued", + "running": "Running", + "success": "Completed", + "error": "Failed", + "timeout": "Timed out", + "cancelled": "Cancelled" + } }, "rightPanel": { "tabOverview": "Overview", @@ -115,16 +128,20 @@ "runtimeStatus": "Execution status: {{status}}", "runtimeInputs": "Runtime Inputs", "runtimeOutputs": "Runtime Outputs", + "expandJsonBlock": "Expand {{label}}", + "collapseJsonBlock": "Collapse {{label}}", "runNodeSection": "Run Node", "runNodeHint": "Execute this node in isolation", "runNodeUnsupported": "This node type is not supported yet", "runNodeUnsupportedDesc": "Branch and Loop nodes are not supported for isolated execution yet.", "runNodeInputs": "Execution Inputs", + "copyInput": "Copy input", "useLatestInputs": "Use latest inputs", "restoreSuggestedInputs": "Restore suggested inputs", "runNodeAction": "Run Node", "runningNode": "Running...", "runNodeSuccess": "Run Succeeded", + "copyOutput": "Copy output", "runNodeError": "Run Error", "runNodeStdout": "Stdout", "runNodeTraceback": "Traceback", diff --git a/webui/src/locales/zh-CN/common.json b/webui/src/locales/zh-CN/common.json index cf01881d..7e2b9bb4 100644 --- a/webui/src/locales/zh-CN/common.json +++ b/webui/src/locales/zh-CN/common.json @@ -1,4 +1,7 @@ { + "image": { + "preview": "图片预览" + }, "button": { "save": "保存", "saving": "保存中...", diff --git a/webui/src/locales/zh-CN/model.json b/webui/src/locales/zh-CN/model.json index 29cb71e4..a88592e0 100644 --- a/webui/src/locales/zh-CN/model.json +++ b/webui/src/locales/zh-CN/model.json @@ -100,6 +100,7 @@ "capabilities": "能力", "toolCall": "工具调用", "vision": "视觉", + "visionPredefinedHint": "内置模型的视觉能力由模型定义决定,不可手动修改。如需自定义,请新增模型。", "streaming": "流式输出", "reasoning": "推理", "pricing": "价格(每百万 Token)", @@ -119,7 +120,16 @@ "loadFailed": "加载 Provider 目录失败", "noModelsToTest": "没有已启用的模型可测试", "batchTestDone": "批量测试完成", - "batchTestSummary": "{{success}} 成功, {{failed}} 失败" + "batchTestSummary": "{{success}} 成功, {{failed}} 失败", + "azureDeploymentName": "Azure 部署名称", + "azureDeploymentPlaceholder": "例如 my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI 请求使用 deployment name,而不是固定模型名。预设模型只是常用示例,你可以在这里填写自己的部署名称。", + "azureDeploymentDisplayName": 
"显示名称(可选)", + "azureDeploymentDisplayPlaceholder": "例如 GPT-4o Production", + "azureDeploymentRequired": "请至少选择一个预设模型,或填写 Azure deployment name", + "azureModelIdHint": "对于 Azure OpenAI,模型 ID 请填写 Azure Portal 中的 deployment name。", + "azureCustomDeployments": "自定义 Azure Deployments", + "azureNoCustomDeployments": "尚未添加自定义 Azure deployment。" }, "wizard": { "providerSaved": "Provider 已保存", diff --git a/webui/src/locales/zh-CN/session.json b/webui/src/locales/zh-CN/session.json index 874d29c5..05228be1 100644 --- a/webui/src/locales/zh-CN/session.json +++ b/webui/src/locales/zh-CN/session.json @@ -89,12 +89,15 @@ }, "upload": { "select": "上传文档", + "selectWithImage": "上传文件或图片", "remove": "移除文件", "retry": "重试上传", "dropHint": "松开以上传文档文件", "waiting": "文件上传中,请等待完成后发送", "invalidType": "仅支持 txt、md、json、yaml、yml、xml、csv、pdf、doc、docx、html、htm、ppt、pptx、xls、xlsx", - "errorGeneric": "上传失败,请重试" + "errorGeneric": "上传失败,请重试", + "imageNotSupported": "当前模型不支持图片分析,请在模型配置中选择支持视觉的模型", + "imageNotSupportedBanner": "当前模型不支持图片分析" }, "tool": { "pending": "等待中", diff --git a/webui/src/locales/zh-CN/task.json b/webui/src/locales/zh-CN/task.json index 7ce6c327..c04be9f7 100644 --- a/webui/src/locales/zh-CN/task.json +++ b/webui/src/locales/zh-CN/task.json @@ -180,6 +180,10 @@ "editTitle": "编辑定时任务", "createTitle": "创建任务", "agentName": "Agent 名称", + "workflowParamsLabel": "Workflow 参数", + "workflowParamsHint": "(JSON 对象,会作为 workflow inputs 传入)", + "workflowParamsPlaceholder": "{\n \"keyword\": \"example\"\n}", + "workflowParamsInvalid": "Workflow 参数必须是合法的 JSON 对象", "urgentLabel": "紧急", "highLabel": "高", "normalLabel": "普通", diff --git a/webui/src/locales/zh-CN/tool.json b/webui/src/locales/zh-CN/tool.json index 865978af..40149100 100644 --- a/webui/src/locales/zh-CN/tool.json +++ b/webui/src/locales/zh-CN/tool.json @@ -109,6 +109,26 @@ "oneArgPerLine": "每行一个参数", "serviceUrl": "服务地址", "serviceUrlPlaceholder": "例如 http://localhost:3000/sse", + "transport": "传输协议", + "transportAuto": "自动探测 (HTTP -> SSE)", + "transportSSE": "仅 SSE", + "transportHTTP": "仅 Streamable HTTP", + "authMethod": "认证方式", + "authNone": "无认证", + "authBearer": "Bearer Token", + "authHeader": "自定义请求头", + "authQuery": "Query 参数", + "authToken": "Token", + "authTokenPlaceholder": "输入 Token 或 {secret:secret_id}", + "authHeaderName": "请求头名称", + "authHeaderValue": "请求头值", + "authHeaderValuePlaceholder": "输入请求头值或 {secret:secret_id}", + "authQueryName": "Query 参数名", + "authQueryValue": "Query 参数值", + "authQueryValuePlaceholder": "输入 Query 参数值或 {secret:secret_id}", + "extraHeaders": "额外请求头", + "extraHeadersPlaceholder": "{\n \"X-Client\": \"flocks\"\n}", + "extraHeadersHint": "填写 JSON 对象。敏感认证建议优先使用上方认证方式字段。", "hintTitle": "不确定如何填写?", "hintDesc": "可切换到「对话助手」选项卡,告诉 Rex 你要接入什么 MCP 服务,它会自动完成配置。", "chatIntegration": "对话接入", @@ -420,6 +440,7 @@ "addFailed": "添加失败: {{error}}", "mcpNameRequired": "请输入 MCP 服务名称", "apiKeyRequired": "请输入 API Key", + "invalidHeaders": "额外请求头必须是合法的 JSON 对象", "credSaved": "凭证已保存", "saveFailed": "保存失败: {{error}}", "testFailed": "测试失败: {{error}}", diff --git a/webui/src/locales/zh-CN/workflow.json b/webui/src/locales/zh-CN/workflow.json index 885f8f6b..74e85232 100644 --- a/webui/src/locales/zh-CN/workflow.json +++ b/webui/src/locales/zh-CN/workflow.json @@ -2,6 +2,10 @@ "pageTitle": "工作流", "pageDescription": "管理和执行工作流", "createWorkflow": "创建工作流", + "section": { + "custom": "自定义工作流", + "builtin": "内置工作流" + }, "emptyState": { "title": "暂无工作流", "description": "创建您的第一个工作流" @@ -39,7 +43,16 @@ }, "topBar": { "collapsePanel": "收起面板", - "expandPanel": 
"展开面板" + "expandPanel": "展开面板", + "runningStage": "当前阶段:{{phase}} · 节点:{{node}}", + "phase": { + "queued": "排队中", + "running": "执行中", + "success": "已完成", + "error": "执行失败", + "timeout": "超时", + "cancelled": "已取消" + } }, "rightPanel": { "tabOverview": "概览", @@ -115,16 +128,20 @@ "runtimeStatus": "执行状态:{{status}}", "runtimeInputs": "真实输入", "runtimeOutputs": "真实输出", + "expandJsonBlock": "展开{{label}}", + "collapseJsonBlock": "收起{{label}}", "runNodeSection": "单节点执行", "runNodeHint": "隔离执行当前节点", "runNodeUnsupported": "当前节点类型暂不支持", "runNodeUnsupportedDesc": "Branch 和 Loop 节点暂不支持单节点执行。", "runNodeInputs": "执行输入", + "copyInput": "复制输入", "useLatestInputs": "使用最近一次输入", "restoreSuggestedInputs": "恢复建议输入", "runNodeAction": "执行节点", "runningNode": "执行中...", "runNodeSuccess": "执行成功", + "copyOutput": "复制输出", "runNodeError": "执行错误", "runNodeStdout": "标准输出", "runNodeTraceback": "错误堆栈", diff --git a/webui/src/pages/Agent/CreateAgentChat.tsx b/webui/src/pages/Agent/CreateAgentChat.tsx index 3bec3b6e..887af8e3 100644 --- a/webui/src/pages/Agent/CreateAgentChat.tsx +++ b/webui/src/pages/Agent/CreateAgentChat.tsx @@ -1,8 +1,9 @@ -import { useEffect, useCallback } from 'react'; +import { useEffect } from 'react'; import { X, Bot } from 'lucide-react'; import { useTranslation } from 'react-i18next'; import SessionChat from '@/components/common/SessionChat'; import { useSessionChat } from '@/hooks/useSessionChat'; +import { useDefaultModelVision } from '@/hooks/useDefaultModelVision'; const SUGGESTIONS = [ '创建一个威胁情报分析 Agent,能够查询 IP/域名/哈希的信誉并输出分析报告', @@ -47,6 +48,7 @@ interface CreateAgentChatProps { export default function CreateAgentChat({ open, onClose }: CreateAgentChatProps) { const { t } = useTranslation(['agent', 'common']); + const supportsVision = useDefaultModelVision(); const { sessionId, createAndSend, reset } = useSessionChat({ title: t('agent:chat.createTitle'), @@ -59,12 +61,6 @@ export default function CreateAgentChat({ open, onClose }: CreateAgentChatProps) if (!open) reset(); }, [open, reset]); - const handleCreateAndSend = useCallback( - async (text: string) => { - await createAndSend(text); - }, - [createAndSend], - ); if (!open) return null; @@ -100,7 +96,8 @@ export default function CreateAgentChat({ open, onClose }: CreateAgentChatProps) placeholder={t('agent:chat.placeholder')} className="flex-1 min-h-0" suggestions={SUGGESTIONS} - onCreateAndSend={!sessionId ? handleCreateAndSend : undefined} + supportsVision={supportsVision} + onCreateAndSend={!sessionId ? (text, imageParts) => createAndSend({ text, imageParts }) : undefined} welcomeContent={!sessionId ? (
diff --git a/webui/src/pages/Login/index.tsx b/webui/src/pages/Login/index.tsx index e199ad50..a809d29e 100644 --- a/webui/src/pages/Login/index.tsx +++ b/webui/src/pages/Login/index.tsx @@ -82,7 +82,7 @@ export default function LoginPage() {
{t('login.recoverPassword')} {' '} - flocks admin generate-one-time-password --username admin_user_name + flocks admin generate-one-time-password --username admin
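The Model page diff below imports `MODEL_CHANGED_EVENT`, but the dispatch site sits outside the hunks shown. Per the hook's doc comment, the contract is simply to fire the event after the backend confirms the new default model; the function and API call below are hypothetical stand-ins:

import { MODEL_CHANGED_EVENT } from '@/hooks/useDefaultModelVision';

// Hypothetical save path: dispatch only after a successful update so the
// cached vision capability is invalidated exactly once per confirmed change.
async function saveDefaultModel(providerId: string, modelId: string): Promise<void> {
  await defaultModelAPI.setDefault(providerId, modelId); // assumed API shape
  window.dispatchEvent(new CustomEvent(MODEL_CHANGED_EVENT));
}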
diff --git a/webui/src/pages/Model/index.tsx b/webui/src/pages/Model/index.tsx
index cb15baf3..8415e966 100644
--- a/webui/src/pages/Model/index.tsx
+++ b/webui/src/pages/Model/index.tsx
@@ -16,6 +16,7 @@
 import { useToast } from '@/components/common/Toast';
 import EntitySheet from '@/components/common/EntitySheet';
 import { useProviders, type EnrichedProvider } from '@/hooks/useProviders';
 import { useSSE } from '@/hooks/useSSE';
+import { MODEL_CHANGED_EVENT } from '@/hooks/useDefaultModelVision';
 import {
   providerAPI, modelV2API, usageAPI, customAPI,
   modelSettingsAPI, catalogAPI, defaultModelAPI,
@@ -55,6 +56,12 @@ function providerAllowsEmptyApiKey(providerId: string): boolean {
   );
 }
 
+const AZURE_PROVIDER_IDS = new Set(['azure-openai', 'azure']);
+
+function isAzureProviderId(providerId: string): boolean {
+  return AZURE_PROVIDER_IDS.has(providerId);
+}
+
 // ==================== Connection Cache ====================
 
 const CONNECTION_CACHE_KEY = 'flocks_provider_connection_cache';
@@ -1088,6 +1095,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: {
   const [baseUrl, setBaseUrl] = useState('');
   const [description, setDescription] = useState('');
   const [providerName, setProviderName] = useState('');
+  const [azureDeploymentName, setAzureDeploymentName] = useState('');
+  const [azureDeploymentDisplayName, setAzureDeploymentDisplayName] = useState('');
 
   // Model selection (for catalog providers)
   const [selectedModelIds, setSelectedModelIds] = useState<Set<string>>(new Set());
@@ -1172,6 +1181,8 @@
       setDescription(provider.description || '');
       setSelectedModelIds(new Set(provider.models.map(m => m.id)));
       setProviderName('');
+      setAzureDeploymentName('');
+      setAzureDeploymentDisplayName('');
     }
   };
@@ -1212,7 +1223,8 @@
         base_url: baseUrl.trim() || undefined,
         provider_name: selectedCatalogId === 'openai-compatible' && providerName.trim() ? providerName.trim() : undefined,
       });
-      const res = await providerAPI.testCredentials(selectedCatalogId);
+      const azureModelId = isAzureProviderId(selectedCatalogId) ? azureDeploymentName.trim() : '';
+      const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId || undefined);
       setTestResult({
         success: res.data.success,
         message: res.data.message || (res.data.success ? t('status.connected') : t('form.testFailed')),
@@ -1235,6 +1247,11 @@
       toast.warning('Please enter API Key');
       return;
     }
+    const azureModelId = isAzureProviderId(selectedCatalogId) ?
azureDeploymentName.trim() : ''; + if (isAzureProviderId(selectedCatalogId) && selectedModelIds.size === 0 && !azureModelId) { + toast.warning(t('form.azureDeploymentRequired')); + return; + } try { setSaving(true); if (selectedCatalogId === 'openai-compatible') { @@ -1259,6 +1276,20 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { const unselected = selectedCatalog.models.filter(m => !selectedModelIds.has(m.id)).map(m => m.id); await Promise.all(unselected.map(id => modelV2API.deleteDefinition(selectedCatalogId, id).catch(() => {}))); } + if (azureModelId) { + await modelV2API.createDefinition(selectedCatalogId, { + model_id: azureModelId, + name: azureDeploymentDisplayName.trim() || azureModelId, + }); + try { + const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId); + if (!res.data.success) { + toast.warning(t('form.testFailed'), res.data.error || res.data.message); + } + } catch (testErr: any) { + toast.warning(t('form.testFailed'), testErr.response?.data?.detail || testErr.message); + } + } toast.success(t('providerAdded'), displayName); setSavedProviderId(selectedCatalogId); setSavedProviderName(displayName); @@ -1600,6 +1631,36 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { )} + {isAzureProviderId(selectedCatalogId) && ( +
+
+ + setAzureDeploymentName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={t('form.azureDeploymentPlaceholder')} + /> +

{t('form.azureDeploymentHint')}

+
+
+ + setAzureDeploymentDisplayName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={azureDeploymentName.trim() || t('form.azureDeploymentDisplayPlaceholder')} + /> +
+
+ )} + {selectedCatalog.models.length > 0 && (
@@ -1712,7 +1773,13 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: {

{t('wizard.modelsAdded', { count: addedModelCount })}

)} - +
)} @@ -1791,10 +1858,12 @@ function useModelForm() { }; } -function ModelFormFields({ form, testResult, testing }: { +function ModelFormFields({ form, testResult, testing, modelIdPlaceholder, modelIdHint }: { form: ReturnType; testResult: { success: boolean; message: string; latency?: number } | null; testing: boolean; + modelIdPlaceholder?: string; + modelIdHint?: string; }) { const { t } = useTranslation('model'); return ( @@ -1809,8 +1878,9 @@ function ModelFormFields({ form, testResult, testing }: { value={form.modelId} onChange={e => form.setModelId(e.target.value)} className="w-full px-3 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-slate-400 text-sm" - placeholder="gpt-4o-custom" + placeholder={modelIdPlaceholder || 'gpt-4o-custom'} /> + {modelIdHint &&

{modelIdHint}

}
- + ); } -function ToggleField({ label, checked, onChange }: { - label: string; checked: boolean; onChange: (v: boolean) => void; +function ToggleField({ label, checked, onChange, disabled, disabledHint }: { + label: string; + checked: boolean; + onChange: (v: boolean) => void; + disabled?: boolean; + /** Tooltip text shown when the toggle is disabled. */ + disabledHint?: string; }) { return ( -
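For context, the new `disabled` / `disabledHint` props pair with the `visionPredefinedHint` locale string added earlier: built-in models are meant to render a read-only vision toggle. A hedged usage sketch; the call-site names and locale namespace are assumptions, not code from this change:

// Hypothetical call site inside the model capability editor:
<ToggleField
  label={t('detail.vision')}
  checked={capabilities.supports_vision === true}
  onChange={(v) => setCapability('supports_vision', v)}
  disabled={model.fetch_from === 'predefined'} // built-in definitions are read-only
  disabledHint={t('detail.visionPredefinedHint')}
/>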