From 73d6fe346251e08e8041a13dc2174521636c2758 Mon Sep 17 00:00:00 2001 From: xiami Date: Fri, 8 May 2026 11:38:19 +0800 Subject: [PATCH 01/27] feat(web2cli): add generate-spec and spec-driven CLI outputs (#230) * feat(web2cli): add generate-spec and spec-driven CLI outputs - Add generate-spec.py to build web2cli spec from captured requests - Extend generate-cli with spec loading and verify/md/postman formats - Refresh inject-hook-base.js capture hook (web2cli-base version) - Update SKILL.md workflow for verify materials - Add tests for generate-spec, CLI formats, and hook base * docs(web2cli): add CLI-to-skill integration reference - Add references/cli-in-skill.md for integrating generated CLI into skills - Update SKILL.md description, summary steps, and Reference section --- .flocks/plugins/skills/web2cli/SKILL.md | 65 +- .../skills/web2cli/references/cli-in-skill.md | 171 +++++ .../skills/web2cli/scripts/generate-cli.py | 522 ++++++++++++++- .../skills/web2cli/scripts/generate-spec.py | 408 ++++++++++++ .../web2cli/scripts/inject-hook-base.js | 628 ++++++++++++++---- tests/tool/test_web2cli_generate_cli.py | 140 ++++ tests/tool/test_web2cli_generate_spec.py | 106 +++ tests/tool/test_web2cli_hook_base.py | 8 +- 8 files changed, 1887 insertions(+), 161 deletions(-) create mode 100644 .flocks/plugins/skills/web2cli/references/cli-in-skill.md create mode 100644 .flocks/plugins/skills/web2cli/scripts/generate-spec.py create mode 100644 tests/tool/test_web2cli_generate_spec.py diff --git a/.flocks/plugins/skills/web2cli/SKILL.md b/.flocks/plugins/skills/web2cli/SKILL.md index bf0b2232..8dbf1218 100644 --- a/.flocks/plugins/skills/web2cli/SKILL.md +++ b/.flocks/plugins/skills/web2cli/SKILL.md @@ -1,6 +1,6 @@ --- name: web2cli -description: 使用统一的 Web2CLI 流程捕获网站的 XHR/Fetch 请求,并生成可复用的 CLI、Markdown 文档和 Postman 集合。支持 `agent-browser` 与 `cdp-direct` 两种模式:前者适合独立浏览器会话,后者复用用户 Chromium 系浏览器登录态与 CDP 能力。适用于复现登录后操作、沉淀接口调用样例,或基于页面操作生成自动化工具时。 +description: 使用统一的 Web2CLI 流程捕获网站的 
XHR/Fetch 请求,并生成可复用的 CLI、Markdown 文档。支持 `agent-browser` 与 `cdp-direct` 两种模式:前者适合独立浏览器会话,后者复用用户 Chromium 系浏览器登录态与 CDP 能力。适用于复现登录后操作、沉淀接口调用样例,或基于页面操作生成自动化工具时。 required: browser-use --- @@ -63,8 +63,10 @@ mkdir -p "$CAPTURE_ROOT/captures" - 浏览器内存中的原始捕获数据:`window.__capturedRequests` - 导出的接口抓包 JSON:`$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json` - 浏览器认证状态:`$CAPTURE_ROOT/auth-state.json` +- 操作适配规格:`$CAPTURE_ROOT/web2cli-spec.json` - 站点自适应 Hook(仅当 base 失败时创建):`$CAPTURE_ROOT/hook.js` - 生成的 CLI 工具:`$CAPTURE_ROOT/_cli.py`,`generate-cli.py` 会把 `-` 等非 Python 模块名字符替换为 `_` +- 生成的验证材料:`$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json` - 生成的接口文档:`$CAPTURE_ROOT/${CAPTURE_NAME}_api.md` - 生成的 Postman 集合:`$CAPTURE_ROOT/${CAPTURE_NAME}_postman.json` @@ -323,35 +325,73 @@ jq -r '.[].method' "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" | sort | un jq '.[] | select(.method == "POST") | {url: .url, body: .requestBody}' "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" ``` -### 8. 生成 CLI 工具 +### 8. 生成 web2cli-spec 规格 -基于 `"$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json"` 与 `"$CAPTURE_ROOT/auth-state.json"` 生成新的 CLI 工具。 +先基于 `"$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json"` 生成中间契约层 `web2cli-spec.json`。 ```bash -uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ +uv run python .flocks/plugins/skills/web2cli/scripts/generate-spec.py \ "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" \ - --format python \ --base-url "https://example.com" \ + --output "$CAPTURE_ROOT/web2cli-spec.json" +``` + +`web2cli-spec.json` 是抓包结果到最终 CLI 之间的可编辑契约,包含: + +- 目标站点与命令名 +- 鉴权策略(如 `PUBLIC` / `COOKIE` / `HEADER`) +- 主请求的 method、endpoint、query/body 模板 +- CLI 参数定义 +- 固定输出列定义 +- 验证材料初稿 + +生成后必须检查并按需修正: + +- `strategy` 是否正确 +- `args` 是否符合实际操作意图 +- `columns` 与字段路径是否对应目标数据 +- `verify` 的最少行数、必填列是否合理 + +### 9. 
基于 spec 生成 CLI 工具 + +从 `"$CAPTURE_ROOT/web2cli-spec.json"` 生成最终 CLI。 + +```bash +uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ + --format python \ --output "$CAPTURE_ROOT/${CAPTURE_NAME}_cli.py" ``` 如果 `CAPTURE_NAME` 包含 `-` 等不能作为 Python 模块名的字符,生成器会自动规范化输出文件名,例如 `test-domain_cli.py` 会写为 `test_domain_cli.py`,并在命令输出中打印实际路径。 -如需同时产出文档可继续执行: +生成验证文件: ```bash uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ - "$CAPTURE_ROOT/captures/${CAPTURE_NAME}_api.json" \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ + --format verify \ + --output "$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json" +``` + +生成接口文档: + +```bash +uv run python .flocks/plugins/skills/web2cli/scripts/generate-cli.py \ + --spec "$CAPTURE_ROOT/web2cli-spec.json" \ --format markdown \ --title "${CAPTURE_NAME} API Documentation" \ --output "$CAPTURE_ROOT/${CAPTURE_NAME}_api.md" ``` -### 9. CLI工具验证 和浏览器关闭 +### 10. CLI工具验证 和浏览器关闭 根据生成的 CLI ,任意选择一个接口调用测试可用性 - CLI 工具可用性 - 认证状态可用性 +- `verify.json` 的输出约束是否满足 + +推荐先查看 `"$CAPTURE_ROOT/${CAPTURE_NAME}_verify.json"`,再用生成的 CLI 以默认参数执行一次,确认固定输出列与认证状态都正确。 当验证完成,确保 CLI 可用后关闭浏览器或 Tab @@ -382,12 +422,12 @@ else: `cdp-direct` 必须保留用户原有的 tab 不受影响。 -### 10. summary +### 11. 
summary 总结当前 生成 的CLI 工具有哪些能力,然后可提示用户下一步操作: -- 保存为对应的 skill 方便后续操作 -- 精简 CLI +- 精简或修正CLI - 进一步丰富 CLI 工具,重新开始 web2cli标准流程 +- 保存为对应的 skill 方便后续操作(进入此操作后,需要阅读references) ## 故障处理 @@ -419,3 +459,6 @@ else: - `agent-browser`:重新登录后再次执行保存状态命令。 - `cdp-direct`:重新登录后再次执行保存认证状态。 + +## Reference +- references/cli-in-skill.md 将生成的 CLI 集成到 skill 中使用 diff --git a/.flocks/plugins/skills/web2cli/references/cli-in-skill.md b/.flocks/plugins/skills/web2cli/references/cli-in-skill.md new file mode 100644 index 00000000..6210c364 --- /dev/null +++ b/.flocks/plugins/skills/web2cli/references/cli-in-skill.md @@ -0,0 +1,171 @@ +# 生成后的 CLI 如何接入 Skill + +> 本文只说明一件事:`web2cli` 已经生成出 CLI 之后,怎样把它整理成可长期维护的 skill 资产。 + +## 命名约定 + +生成阶段的文件名通常来自抓包名,例如 `_cli.py`。这个名字适合临时验证,不适合直接沉淀到 skill。 + +落到 skill 时,统一改成**稳定的产品名**: + +- skill 目录:`$HOME/.flocks/plugins/skills/-use/` +- CLI 主脚本:`$HOME/.flocks/plugins/skills/-use/scripts/_cli.py` +- 默认认证状态:`~/.flocks/browser//auth-state.json` + +约定说明: + +- `` 用产品或系统的稳定标识,不用一次性任务名 +- 目录名可以保留 `-`,例如 `tdp-use` +- Python 脚本名统一用 `_`,例如 `tdp_cli.py` +- 不要把最终 CLI 保留成 `export_data_cli.py`、`test_capture_cli.py` 这类临时名字 + +## 放到已有产品 Skill + +如果仓库里已经有对应产品 skill,直接把生成结果并入现有 skill: + +```bash +SKILL_ROOT="$HOME/.flocks/plugins/skills/-use" + +mkdir -p "$SKILL_ROOT/scripts" +mkdir -p "$HOME/.flocks/browser/" + +cp "$CAPTURE_ROOT/_cli.py" \ + "$SKILL_ROOT/scripts/_cli.py" + +cp "$CAPTURE_ROOT/auth-state.json" \ + "$HOME/.flocks/browser//auth-state.json" +``` + +然后补齐这几项: + +1. 在 `scripts/config.py` 中把认证状态默认值指向 `~/.flocks/browser//auth-state.json` +2. 在 `references/cli-reference.md` 中写清楚 CLI 用法、环境变量和示例 +3. 在 `references/browser-workflow.md` 中写清楚浏览器登录与保存 state 的流程 +4. 
在 `SKILL.md` 中说明什么时候优先走 CLI,什么时候退回浏览器 + +推荐的配置写法: + +```python +import os +from pathlib import Path + +AUTH_STATE_FILE = Path( + os.getenv( + "_AUTH_STATE", + Path.home() / ".flocks" / "browser" / "" / "auth-state.json", + ) +) +``` + +这样做的好处是: + +- 默认行为统一,和现有产品 skill 保持一致 +- 允许用户用环境变量覆盖 +- 生成阶段的临时产物和最终长期使用的认证文件分离 + +## 生成新的 Skill + +如果当前仓库里还没有对应产品 skill,就按下面的最小结构创建: + +```text +$HOME/.flocks/plugins/skills/-use/ +├── SKILL.md +├── scripts/ +│ ├── _cli.py +│ └── config.py +└── references/ + ├── browser-workflow.md + └── cli-reference.md +``` + +其中 `SKILL.md` 必须遵守 Flocks 的标准 skill 格式: + +- 文件开头必须是 YAML frontmatter,第一行必须为 `---` +- frontmatter 至少包含 `name` 和 `description` +- `name` 使用稳定的 skill 标识,推荐与目录名一致,例如 `-use` +- frontmatter 结束后,再写正文标题、触发条件、模式判断和使用说明 + +最小模板示例: + +```md +--- +name: test-use +description: 用于查询 Test 测试平台数据,支持通过 CLI 快速查询,认证失效时退回浏览器模式。 +--- + +# Test Use + +## 触发条件 + +- 用户提到 Test 平台 +- 用户需要查询 Test 数据 + +## 模式判断 + +### CLI 模式(默认) + +- 适用于快速查询和批量读取数据 + +### 浏览器模式 + +- 适用于需要页面交互、导出或重新登录的场景 +``` + +不要把 `SKILL.md` 直接写成普通 Markdown 文档,例如下面这种格式是无效的: + +```md +# Test Use +``` + +各文件职责: + +- `SKILL.md`:定义触发条件、模式判断、总入口说明 +- `scripts/_cli.py`:承载生成并整理后的 CLI 能力 +- `scripts/config.py`:集中管理 `BASE_URL`、`AUTH_STATE_FILE`、超时、SSL 等默认配置 +- `references/browser-workflow.md`:写浏览器登录、保存 state、认证恢复流程 +- `references/cli-reference.md`:写 CLI 参数、命令示例、常见查询 + +新 skill 的原则也一样:先把生成的 CLI 改成稳定文件名,再把临时 `auth-state.json` 切换到全局默认位置 `~/.flocks/browser//auth-state.json`。 + +## 认证失败怎么处理 + +CLI 调用出现以下情况时,优先按认证失效处理: + +- 返回 `401` 或 `403` +- 返回内容出现 `Unauthorized`、`login`、未登录、无权限 +- `auth-state.json` 已存在,但请求仍然被重定向到登录页 + +处理原则: + +1. 不要无限重试 CLI +2. 请求用户重新通过浏览器登录 +3. 登录完成后,重新保存认证状态到默认路径 +4. 
再重试一次 CLI + +默认认证文件路径固定为: + +```bash +~/.flocks/browser//auth-state.json +``` + +保存方式示例: + +```bash +mkdir -p "$HOME/.flocks/browser/" + +# agent-browser 模式 +agent-browser state save "$HOME/.flocks/browser//auth-state.json" + +# 或 cdp-direct / flocks browser 模式 +flocks browser state save "$HOME/.flocks/browser//auth-state.json" +``` + +如果用户重新登录并保存 state 后,CLI 仍然失败,再继续排查: + +- `BASE_URL` 是否写错 +- 当前账号是否确实有接口权限 +- 站点是否还有额外 header / token / csrf 依赖 + +## 一句话原则 + +`web2cli` 产出的 `_cli.py` 是临时结果;真正沉淀到 skill 时,要改成稳定产品名脚本,并把认证状态统一落到 `~/.flocks/browser//auth-state.json`。 diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py index f6909d2b..3e41fff4 100644 --- a/.flocks/plugins/skills/web2cli/scripts/generate-cli.py +++ b/.flocks/plugins/skills/web2cli/scripts/generate-cli.py @@ -411,37 +411,515 @@ def generate_postman_collection(requests: List[Dict], base_url: str) -> Dict: return collection +def load_spec(spec_path: str) -> Dict[str, Any]: + """Load a web2cli spec from disk.""" + with open(spec_path, encoding="utf-8") as f: + payload = json.load(f) + if not isinstance(payload, dict): + raise ValueError("Spec file must contain a JSON object") + return payload + + +def generate_verify_materials_from_spec(spec: Dict[str, Any]) -> Dict[str, Any]: + """Generate verify metadata from a web2cli spec.""" + verify = spec.get("verify", {}) if isinstance(spec.get("verify"), dict) else {} + columns = spec.get("columns", []) + column_names = [column.get("name") for column in columns if isinstance(column, dict) and column.get("name")] + + return { + "site": spec.get("site", ""), + "command": spec.get("command", ""), + "args": verify.get("args", {}), + "expect": { + "rowCount": verify.get("rowCount", {"min": 1}), + "columns": verify.get("columns", column_names), + "types": verify.get( + "types", + { + column.get("name"): column.get("type", "string") + for column in columns + if isinstance(column, dict) and 
column.get("name") + }, + ), + "notEmpty": verify.get("notEmpty", column_names[: min(3, len(column_names))]), + "patterns": verify.get("patterns", {}), + }, + } + + +def generate_markdown_docs_from_spec(spec: Dict[str, Any], title: str = "API Documentation") -> str: + """Generate Markdown documentation from a web2cli spec.""" + operation = spec.get("operation", {}) + args = spec.get("args", []) + columns = spec.get("columns", []) + verify = generate_verify_materials_from_spec(spec) + + md = f"""# {title} + +> Auto-generated Web2CLI Specification +> Site: `{spec.get("site", "")}` +> Command: `{spec.get("command", "")}` + +## 概览 + +- **描述**: {spec.get("description", "N/A")} +- **策略**: `{spec.get("strategy", "PUBLIC")}` +- **Base URL**: `{spec.get("baseUrl", "")}` +- **Method**: `{operation.get("method", "GET")}` +- **Endpoint**: `{operation.get("endpoint", "/")}` + +## 参数 + +""" + + if args: + md += "| 参数 | 类型 | 默认值 | 说明 |\n" + md += "|------|------|--------|------|\n" + for arg in args: + md += f"| `{arg.get('name', '')}` | `{arg.get('type', 'string')}` | `{arg.get('default', '')}` | {arg.get('help', '')} |\n" + md += "\n" + else: + md += "无参数。\n\n" + + md += "## 输出列\n\n" + md += "| 列名 | 类型 | 路径 |\n" + md += "|------|------|------|\n" + for column in columns: + md += f"| `{column.get('name', '')}` | `{column.get('type', 'string')}` | `{column.get('path', '')}` |\n" + + md += "\n## 验证建议\n\n" + md += f"- 默认参数: `{json.dumps(verify['args'], ensure_ascii=False)}`\n" + md += f"- 最少行数: `{verify['expect']['rowCount'].get('min', 0)}`\n" + md += f"- 必填列: `{', '.join(verify['expect']['notEmpty'])}`\n" + + return md + + +def generate_postman_collection_from_spec(spec: Dict[str, Any]) -> Dict[str, Any]: + """Generate a minimal Postman collection from a web2cli spec.""" + operation = spec.get("operation", {}) + headers = operation.get("headers", {}) if isinstance(operation.get("headers"), dict) else {} + body_template = operation.get("bodyTemplate", {}) if 
isinstance(operation.get("bodyTemplate"), dict) else {} + endpoint = operation.get("endpoint", "/") + path_parts = endpoint.lstrip("/").split("/") if endpoint.lstrip("/") else [] + + request = { + "method": operation.get("method", "GET"), + "url": { + "raw": f"{{{{base_url}}}}{endpoint}", + "host": ["{{base_url}}"], + "path": path_parts, + }, + "header": [{"key": key, "value": value} for key, value in headers.items()], + } + if body_template: + request["body"] = { + "mode": "raw", + "raw": json.dumps(body_template, ensure_ascii=False), + "options": {"raw": {"language": "json"}}, + } + + return { + "info": { + "name": f"{spec.get('site', 'captured')} {spec.get('command', 'command')}", + "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json", + }, + "item": [ + { + "name": spec.get("command", endpoint), + "request": request, + } + ], + "variable": [{"key": "base_url", "value": spec.get("baseUrl", "")}], + } + + +def generate_python_cli_from_spec(spec: Dict[str, Any]) -> str: + """Generate a fixed command CLI script from a web2cli spec.""" + spec_json = json.dumps(spec, indent=2, ensure_ascii=False) + return '''#!/usr/bin/env python3 +""" +Auto-generated Web2CLI command script. 
+Generated from web2cli-spec.json +""" + +import argparse +import csv +import json +import sys +from typing import Any, Dict, List + +import requests + + +SPEC = ''' + spec_json + ''' + + +def _load_json(path: str) -> Dict[str, Any]: + if not path: + return {} + try: + with open(path, encoding="utf-8") as f: + payload = json.load(f) + except FileNotFoundError: + return {} + except json.JSONDecodeError: + return {} + return payload if isinstance(payload, dict) else {} + + +def _coerce_bool(value: str) -> bool: + normalized = str(value).strip().lower() + if normalized in {"1", "true", "yes", "y", "on"}: + return True + if normalized in {"0", "false", "no", "n", "off"}: + return False + raise argparse.ArgumentTypeError(f"invalid boolean value: {value}") + + +def _type_name(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, int) and not isinstance(value, bool): + return "int" + if isinstance(value, float): + return "float" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return "string" + + +class APIClient: + """Fixed command client generated from a web2cli spec.""" + + @staticmethod + def _load_cookie_items(auth_state_path: str) -> List[Dict[str, Any]]: + payload = _load_json(auth_state_path) + cookies = payload.get("cookies", []) + if isinstance(cookies, list): + return [cookie for cookie in cookies if isinstance(cookie, dict)] + return [] + + @staticmethod + def _load_storage_map(payload: Dict[str, Any]) -> Dict[str, str]: + values = {} + for origin_entry in payload.get("origins", []): + if not isinstance(origin_entry, dict): + continue + for item in origin_entry.get("localStorage", []): + if isinstance(item, dict) and item.get("name"): + values[item["name"]] = item.get("value", "") + return values + + @staticmethod + def _resolve_header_value(payload: Dict[str, Any], rule: Dict[str, Any]) -> str | None: + source = rule.get("source") + key = 
rule.get("key") + if source == "cookie": + for cookie in payload.get("cookies", []): + if isinstance(cookie, dict) and cookie.get("name") == key: + return str(cookie.get("value", "")) + if source == "localStorage": + return APIClient._load_storage_map(payload).get(str(key)) + return None + + @staticmethod + def _resolve_template(value: Any, args: Dict[str, Any]) -> Any: + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + return args.get(value[2:-1], value) + if isinstance(value, dict): + return {key: APIClient._resolve_template(item, args) for key, item in value.items()} + if isinstance(value, list): + return [APIClient._resolve_template(item, args) for item in value] + return value + + @staticmethod + def _tokenize_path(path: str) -> List[str]: + if not path or path == "$": + return [] + normalized = path + if normalized.startswith("$."): + normalized = normalized[2:] + elif normalized.startswith("$"): + normalized = normalized[1:] + normalized = normalized.replace("[]", ".[]") + return [token for token in normalized.split(".") if token] + + @classmethod + def _extract_many(cls, value: Any, path: str) -> List[Any]: + tokens = cls._tokenize_path(path) + current = [value] + for token in tokens: + next_values = [] + if token == "[]": + for item in current: + if isinstance(item, list): + next_values.extend(item) + else: + for item in current: + if isinstance(item, dict) and token in item: + next_values.append(item[token]) + current = next_values + if not current: + break + return current + + @classmethod + def _extract_first(cls, value: Any, path: str) -> Any: + if not path or path == "$": + return value + values = cls._extract_many(value, path) + return values[0] if values else None + + def __init__(self, base_url: str = SPEC.get("baseUrl", ""), auth_state: str = "auth-state.json"): + self.base_url = (base_url or SPEC.get("baseUrl", "")).rstrip("/") + self.auth_state_path = auth_state + self.auth_state = _load_json(auth_state) if 
auth_state else {} + self.session = requests.Session() + self._apply_auth_state() + + def _apply_auth_state(self) -> None: + strategy = SPEC.get("strategy", "PUBLIC") + auth = SPEC.get("auth", {}) + headers = SPEC.get("operation", {}).get("headers", {}) + if isinstance(headers, dict) and headers: + self.session.headers.update(headers) + + if strategy in {"COOKIE", "HEADER"}: + for cookie in self._load_cookie_items(self.auth_state_path): + name = cookie.get("name") + if not name: + continue + kwargs = {} + if cookie.get("domain"): + kwargs["domain"] = cookie["domain"] + if cookie.get("path"): + kwargs["path"] = cookie["path"] + self.session.cookies.set(name, cookie.get("value", ""), **kwargs) + + if strategy == "HEADER": + for rule in auth.get("requiredHeaders", []): + if not isinstance(rule, dict) or not rule.get("name"): + continue + value = self._resolve_header_value(self.auth_state, rule) + if value: + self.session.headers[str(rule["name"])] = value + + def build_request(self, args: Dict[str, Any]) -> Dict[str, Any]: + operation = SPEC.get("operation", {}) + endpoint = operation.get("endpoint", "/") + query = self._resolve_template(operation.get("queryTemplate", {}), args) + body = self._resolve_template(operation.get("bodyTemplate", {}), args) + return { + "method": operation.get("method", "GET"), + "url": f"{self.base_url}{endpoint}", + "params": query or None, + "json": body or None, + } + + def _project_rows(self, payload: Any) -> List[Dict[str, Any]]: + row_source = SPEC.get("rowSource", {}) + collection_path = row_source.get("collectionPath") or row_source.get("path") or "$" + collection = self._extract_many(payload, collection_path) if collection_path != "$" else [payload] + if not collection: + return [] + + rows = [] + columns = SPEC.get("columns", []) + for index, row in enumerate(collection, start=1): + projected = {} + for column in columns: + if not isinstance(column, dict) or not column.get("name"): + continue + rel_path = 
column.get("relativePath") or column.get("path") or "$" + if rel_path == "__index__": + value = index + elif rel_path.startswith("$."): + value = self._extract_first(payload, rel_path) + else: + value = self._extract_first(row, rel_path) + projected[column["name"]] = value + rows.append(projected) + return rows + + def run(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: + request_options = self.build_request(args) + response = self.session.request( + request_options["method"], + request_options["url"], + params=request_options["params"], + json=request_options["json"], + ) + response.raise_for_status() + return self._project_rows(response.json()) + + +def verify_rows(rows: List[Dict[str, Any]], verify_spec: Dict[str, Any]) -> List[str]: + errors = [] + expect = verify_spec.get("expect", verify_spec) + row_count = expect.get("rowCount", {}) + min_rows = row_count.get("min") + max_rows = row_count.get("max") + + if min_rows is not None and len(rows) < min_rows: + errors.append(f"rowCount too small: expected >= {min_rows}, got {len(rows)}") + if max_rows is not None and len(rows) > max_rows: + errors.append(f"rowCount too large: expected <= {max_rows}, got {len(rows)}") + + columns = expect.get("columns", []) + types = expect.get("types", {}) + not_empty = expect.get("notEmpty", []) + patterns = expect.get("patterns", {}) + + for row in rows: + for column in columns: + if column not in row: + errors.append(f"missing column: {column}") + for column in not_empty: + if row.get(column) in (None, "", [], {}): + errors.append(f"empty required column: {column}") + for column, expected_type in types.items(): + if column in row and row[column] is not None and _type_name(row[column]) != expected_type: + errors.append( + f"type mismatch for {column}: expected {expected_type}, got {_type_name(row[column])}" + ) + for column, pattern in patterns.items(): + if column in row and row[column] is not None: + import re + if not re.search(pattern, str(row[column])): + 
errors.append(f"pattern mismatch for {column}: {pattern}") + + return errors + + +def _print_rows(rows: List[Dict[str, Any]], output_format: str) -> None: + if output_format == "json": + print(json.dumps(rows, ensure_ascii=False, indent=2)) + return + if not rows: + return + columns = list(rows[0].keys()) + if output_format == "csv": + writer = csv.DictWriter(sys.stdout, fieldnames=columns) + writer.writeheader() + writer.writerows(rows) + return + print("\\t".join(columns)) + for row in rows: + print("\\t".join("" if row.get(column) is None else str(row.get(column)) for column in columns)) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=SPEC.get("description", "Generated Web2CLI command")) + parser.add_argument("--base-url", default=SPEC.get("baseUrl", ""), help="Override base URL") + parser.add_argument( + "--auth-state", + default=(SPEC.get("auth", {}) or {}).get("stateFile", "auth-state.json"), + help="Path to auth state JSON", + ) + parser.add_argument("--format", choices=["json", "csv", "table"], default="json", help="Output format") + parser.add_argument("--verify", action="store_true", help="Validate rows against embedded or external verify spec") + parser.add_argument("--verify-spec", help="Optional verify JSON path") + for arg in SPEC.get("args", []): + if not isinstance(arg, dict) or not arg.get("name"): + continue + option = "--" + str(arg["name"]).replace("_", "-") + arg_type = arg.get("type", "string") + kwargs = { + "dest": arg["name"], + "default": arg.get("default"), + "help": arg.get("help", ""), + } + if arg_type == "int": + kwargs["type"] = int + elif arg_type == "float": + kwargs["type"] = float + elif arg_type == "bool": + kwargs["type"] = _coerce_bool + else: + kwargs["type"] = str + parser.add_argument(option, **kwargs) + return parser + + +def main() -> None: + parser = build_parser() + parsed = parser.parse_args() + runtime_args = { + item["name"]: getattr(parsed, item["name"]) + for item 
in SPEC.get("args", []) + if isinstance(item, dict) and item.get("name") + } + client = APIClient(base_url=parsed.base_url, auth_state=parsed.auth_state) + rows = client.run(runtime_args) + + if parsed.verify: + verify_spec = _load_json(parsed.verify_spec) if parsed.verify_spec else SPEC.get("verify", {}) + errors = verify_rows(rows, verify_spec) + if errors: + raise SystemExit("\\n".join(errors)) + + _print_rows(rows, parsed.format) + + +if __name__ == "__main__": + main() +''' + + def main(): - parser = argparse.ArgumentParser(description='Generate CLI/docs from captured APIs') - parser.add_argument('input', help='Input JSON file with captured requests') + parser = argparse.ArgumentParser(description='Generate CLI/docs from captured APIs or a web2cli spec') + parser.add_argument('input', nargs='?', help='Input JSON file with captured requests') + parser.add_argument('--spec', help='Input web2cli-spec.json file') parser.add_argument('--output', '-o', help='Output file') parser.add_argument('--base-url', '-u', default='https://example.com', help='Base URL') - parser.add_argument('--format', '-f', choices=['python', 'markdown', 'postman'], + parser.add_argument('--format', '-f', choices=['python', 'markdown', 'postman', 'verify'], default='markdown', help='Output format') parser.add_argument('--title', '-t', default='API Documentation', help='Document title') args = parser.parse_args() - # Load input - with open(args.input) as f: - data = json.load(f) - - # Handle both array and object formats - requests = data if isinstance(data, list) else data.get('requests', []) - - if not requests: - print("No requests found in input file", file=sys.stderr) - sys.exit(1) - - print(f"Processing {len(requests)} requests, {len(group_endpoints(requests))} unique endpoints...") - - # Generate output - if args.format == 'python': - output = generate_python_client(requests, args.base_url) - elif args.format == 'postman': - output = json.dumps(generate_postman_collection(requests, 
args.base_url), indent=2, ensure_ascii=False) + if not args.input and not args.spec: + parser.error('either input or --spec is required') + + if args.spec: + spec = load_spec(args.spec) + if args.format == 'python': + output = generate_python_cli_from_spec(spec) + elif args.format == 'verify': + output = json.dumps(generate_verify_materials_from_spec(spec), indent=2, ensure_ascii=False) + elif args.format == 'postman': + output = json.dumps(generate_postman_collection_from_spec(spec), indent=2, ensure_ascii=False) + else: + output = generate_markdown_docs_from_spec(spec, args.title) else: - output = generate_markdown_docs(requests, args.title) + # Load input + with open(args.input, encoding='utf-8') as f: + data = json.load(f) + + # Handle both array and object formats + requests = data if isinstance(data, list) else data.get('requests', []) + + if not requests: + print("No requests found in input file", file=sys.stderr) + sys.exit(1) + + print(f"Processing {len(requests)} requests, {len(group_endpoints(requests))} unique endpoints...") + + # Generate output + if args.format == 'python': + output = generate_python_client(requests, args.base_url) + elif args.format == 'postman': + output = json.dumps(generate_postman_collection(requests, args.base_url), indent=2, ensure_ascii=False) + elif args.format == 'verify': + print("verify output requires --spec", file=sys.stderr) + sys.exit(1) + else: + output = generate_markdown_docs(requests, args.title) # Write output output_path = args.output diff --git a/.flocks/plugins/skills/web2cli/scripts/generate-spec.py b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py new file mode 100644 index 00000000..c7b01fdc --- /dev/null +++ b/.flocks/plugins/skills/web2cli/scripts/generate-spec.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +"""Generate a web2cli spec from captured API requests.""" + +from __future__ import annotations + +import argparse +import json +import keyword +import re +import sys +from pathlib import Path 
+from typing import Any +from urllib.parse import parse_qsl, urlparse + + +PAGE_PARAM_NAMES = {"page", "pageNo", "pageNum", "current", "pageIndex", "curPage"} +LIMIT_PARAM_NAMES = {"limit", "size", "pageSize", "page_size", "page_limit", "rows"} + + +def sanitize_name(name: str) -> str: + """Convert text to a valid Python/CLI-friendly identifier.""" + value = re.sub(r"\?.*$", "", name) + value = re.sub(r"[^a-zA-Z0-9_]", "_", value) + value = re.sub(r"_+", "_", value) + value = value.strip("_") + if value and value[0].isdigit(): + value = f"_{value}" + value = value.lower() or "endpoint" + if keyword.iskeyword(value): + value = f"{value}_" + return value + + +def load_requests(input_path: str) -> list[dict[str, Any]]: + """Load captured request list from disk.""" + with open(input_path, encoding="utf-8") as f: + payload = json.load(f) + + requests = payload if isinstance(payload, list) else payload.get("requests", []) + return [item for item in requests if isinstance(item, dict)] + + +def parse_json_text(text: str) -> Any: + """Parse a response/request body string when possible.""" + if not text: + return {} + + value = text.strip() + if value.endswith("...[truncated]"): + value = value[: -len("...[truncated]")] + + try: + return json.loads(value) + except json.JSONDecodeError: + return {"raw": text} + + +def infer_type(value: Any) -> str: + """Return a compact type name for spec/verify output.""" + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, int) and not isinstance(value, bool): + return "int" + if isinstance(value, float): + return "float" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return "string" + + +def normalize_url_info(request: dict[str, Any]) -> dict[str, Any]: + """Return normalized URL parts from capture metadata or raw URL.""" + url = ( + request.get("normalizedUrl") + or request.get("url") + or "" + ) + parsed = urlparse(url) + query_items = 
dict(parse_qsl(parsed.query, keep_blank_values=True)) + return { + "url": url, + "origin": request.get("origin") or (f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else ""), + "pathname": request.get("pathname") or (parsed.path or "/"), + "query": request.get("query") or query_items, + "queryKeys": request.get("queryKeys") or list(query_items.keys()), + "host": parsed.netloc, + } + + +def score_request(request: dict[str, Any], index: int) -> tuple[int, int]: + """Score a captured request to decide which one should become the spec.""" + score = 0 + response = parse_json_text(str(request.get("response", ""))) + action = ((request.get("actionContext") or {}).get("lastAction") or {}).get("action") + + status = request.get("status") + if isinstance(status, int) and 200 <= status < 300: + score += 30 + elif status == "error": + score -= 20 + + if request.get("captureReason") in {"nonGet", "captureModeAll", "includePattern"}: + score += 15 + if action: + score += 12 + if isinstance(response, dict) and "raw" not in response: + score += 20 + + collection = find_best_collection(response) + if collection is not None: + score += 20 + min(collection["length"], 20) + + return score, index + + +def choose_primary_request(requests: list[dict[str, Any]]) -> dict[str, Any]: + """Pick the best request candidate from the captured request list.""" + if not requests: + raise ValueError("No captured requests available") + ranked = sorted( + ((score_request(req, index), req) for index, req in enumerate(requests)), + key=lambda item: (item[0][0], item[0][1]), + reverse=True, + ) + return ranked[0][1] + + +def find_collections(value: Any, path: str = "$") -> list[dict[str, Any]]: + """Find likely row collections inside a JSON response.""" + results: list[dict[str, Any]] = [] + + if isinstance(value, list): + item = value[0] if value else None + score = 10 + if isinstance(item, dict): + score += 25 + elif item is not None: + score += 10 + results.append( + { + 
"collectionPath": path + "[]", + "path": path, + "length": len(value), + "item": item, + "score": score, + } + ) + if isinstance(item, dict): + for key, child in item.items(): + results.extend(find_collections(child, path + "[]." + key)) + return results + + if isinstance(value, dict): + for key, child in value.items(): + next_path = path + "." + key if path != "$" else "$." + key + results.extend(find_collections(child, next_path)) + + return results + + +def find_best_collection(value: Any) -> dict[str, Any] | None: + """Return the highest scoring collection candidate from the response.""" + candidates = find_collections(value) + if not candidates: + return None + candidates.sort( + key=lambda item: ( + item["score"], + item["length"], + -len(item["path"]), + ), + reverse=True, + ) + return candidates[0] + + +def collect_columns(item: Any) -> list[dict[str, Any]]: + """Infer a compact column list from a sample row.""" + columns: list[dict[str, Any]] = [] + + if isinstance(item, dict): + for key, value in item.items(): + if isinstance(value, (dict, list)): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + if isinstance(nested_value, (dict, list)): + continue + columns.append( + { + "name": sanitize_name(f"{key}_{nested_key}"), + "path": "$." + key + "." + nested_key, + "relativePath": key + "." + nested_key, + "sourceField": nested_key, + "type": infer_type(nested_value), + } + ) + continue + columns.append( + { + "name": sanitize_name(key), + "path": "$." 
+ key, + "relativePath": key, + "sourceField": key, + "type": infer_type(value), + } + ) + if len(columns) >= 8: + break + elif item is not None: + columns.append( + { + "name": "value", + "path": "$", + "relativePath": "$", + "sourceField": "value", + "type": infer_type(item), + } + ) + + if not columns: + columns.append( + { + "name": "value", + "path": "$", + "relativePath": "$", + "sourceField": "value", + "type": "string", + } + ) + return columns + + +def build_templates(request: dict[str, Any], url_info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], list[dict[str, Any]]]: + """Build query/body templates and CLI arg definitions.""" + args: list[dict[str, Any]] = [] + seen_args: set[str] = set() + + def add_arg(name: str, default: Any, help_text: str) -> None: + if name in seen_args: + return + seen_args.add(name) + arg_type = "int" if isinstance(default, int) else "string" + args.append({"name": name, "type": arg_type, "default": default, "help": help_text}) + + def transform_mapping(data: dict[str, Any]) -> dict[str, Any]: + result: dict[str, Any] = {} + for key, value in data.items(): + if key in PAGE_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 1 + result[key] = "${page}" + add_arg("page", default, "Page number") + elif key in LIMIT_PARAM_NAMES: + default = int(value) if str(value).isdigit() else 20 + result[key] = "${limit}" + add_arg("limit", default, "Page size") + else: + result[key] = value + return result + + body = parse_json_text(str(request.get("requestBody", ""))) + if not isinstance(body, dict) or "raw" in body: + body = {} + + query_template = transform_mapping(url_info["query"]) + body_template = transform_mapping(body) + + args.sort(key=lambda item: (0 if item["name"] == "page" else 1 if item["name"] == "limit" else 2, item["name"])) + return query_template, body_template, args + + +def build_strategy(request: dict[str, Any]) -> tuple[str, dict[str, Any]]: + """Infer auth strategy and auth metadata from request 
headers.""" + headers = request.get("requestHeaders", {}) or request.get("request_headers", {}) + normalized = {str(key).lower(): value for key, value in headers.items()} + strategy = "PUBLIC" + required_headers: list[dict[str, Any]] = [] + + if "authorization" in normalized: + strategy = "HEADER" + required_headers.append({"name": "Authorization", "source": "manual", "key": "authorization"}) + elif "cookie" in normalized: + strategy = "COOKIE" + + for header_name in ("x-csrf-token", "x-xsrf-token", "x-auth-token"): + if header_name in normalized: + strategy = "HEADER" + required_headers.append({"name": header_name, "source": "manual", "key": header_name}) + + return strategy, {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": required_headers} + + +def safe_headers(request: dict[str, Any]) -> dict[str, Any]: + """Return non-sensitive request headers that can be replayed safely.""" + headers = request.get("requestHeaders", {}) or request.get("request_headers", {}) + result = {} + for key, value in headers.items(): + if str(key).lower() in {"cookie", "authorization", "x-csrf-token", "x-xsrf-token", "x-auth-token"}: + continue + result[key] = value + return result + + +def site_name_from_host(host: str) -> str: + """Return a readable site name from a host.""" + cleaned = host.split(":")[0] + parts = [part for part in cleaned.split(".") if part not in {"www", "api", "m"}] + if len(parts) >= 2: + return sanitize_name(parts[-2]) + if parts: + return sanitize_name(parts[0]) + return "captured_site" + + +def command_name_from_path(pathname: str) -> str: + """Return a command name from an API pathname.""" + parts = [part for part in pathname.split("/") if part] + return sanitize_name(parts[-1] if parts else "command") + + +def generate_spec_from_requests(requests: list[dict[str, Any]], *, base_url: str | None = None) -> dict[str, Any]: + """Build a web2cli spec object from captured request data.""" + request = choose_primary_request(requests) + 
url_info = normalize_url_info(request) + response = parse_json_text(str(request.get("response", ""))) + collection = find_best_collection(response) + row_item = collection["item"] if collection is not None else response + query_template, body_template, args = build_templates(request, url_info) + strategy, auth = build_strategy(request) + columns = collect_columns(row_item) + + defaults = {item["name"]: item["default"] for item in args} + verify_types = {column["name"]: column["type"] for column in columns} + verify_not_empty = [column["name"] for column in columns[: min(3, len(columns))]] + row_count = {"min": 1} + if collection and collection["length"]: + row_count["max"] = collection["length"] + + purpose = request.get("apiPurpose", {}) if isinstance(request.get("apiPurpose"), dict) else {} + host_origin = base_url or url_info["origin"] or "https://example.com" + pathname = url_info["pathname"] or "/" + site = site_name_from_host(urlparse(host_origin).netloc or url_info["host"]) + command = purpose.get("name") or command_name_from_path(pathname) + + return { + "schemaVersion": "1.0", + "site": site, + "command": sanitize_name(command), + "description": purpose.get("desc") or f"Generated from {request.get('method', 'GET')} {pathname}", + "baseUrl": host_origin, + "strategy": strategy, + "auth": auth, + "operation": { + "method": request.get("method", "GET"), + "endpoint": pathname, + "queryTemplate": query_template, + "bodyTemplate": body_template, + "headers": safe_headers(request), + "captureSource": request.get("captureSource", "pageHook"), + "captureReason": request.get("captureReason", ""), + "sourceRequestId": request.get("timestamp", ""), + }, + "rowSource": { + "path": collection["collectionPath"] if collection else "$", + "collectionPath": collection["collectionPath"] if collection else "$", + }, + "args": args, + "columns": columns, + "verify": { + "args": defaults, + "rowCount": row_count, + "columns": [column["name"] for column in columns], + "types": 
verify_types, + "notEmpty": verify_not_empty, + "patterns": {}, + }, + } + + +def main() -> None: + """CLI entrypoint.""" + parser = argparse.ArgumentParser(description="Generate a web2cli spec from captured APIs") + parser.add_argument("input", help="Input JSON file with captured requests") + parser.add_argument("--output", "-o", help="Output spec path") + parser.add_argument("--base-url", help="Optional base URL override") + args = parser.parse_args() + + requests = load_requests(args.input) + if not requests: + print("No requests found in input file", file=sys.stderr) + sys.exit(1) + + spec = generate_spec_from_requests(requests, base_url=args.base_url) + rendered = json.dumps(spec, indent=2, ensure_ascii=False) + + if args.output: + output_path = Path(args.output) + output_path.write_text(rendered, encoding="utf-8") + print(f"Written to {output_path}") + else: + print(rendered) + + +if __name__ == "__main__": + main() diff --git a/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js b/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js index 2a009cac..5e11d574 100644 --- a/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js +++ b/.flocks/plugins/skills/web2cli/scripts/inject-hook-base.js @@ -1,7 +1,7 @@ /** - * API Capture Hook - Simple Version (ES5 compatible) + * API Capture Hook - Base Version (ES5 compatible) */ -(function(){ +(function() { if (window.__apiCapture) { console.log('[API Capture] Already installed'); return; @@ -9,9 +9,10 @@ window.__capturedRequests = []; - // Configuration var CONFIG = { - maxResponseLength: 50000, + maxResponseLength: 2000, + maxRequestBodyLength: 2000, + maxRecentActions: 20, captureMode: 'smart', // 'smart' | 'all' sameOriginOnly: true, includePatterns: [], @@ -24,13 +25,38 @@ ] }; + var recentActions = []; + var navigationState = { + lastNavigation: null, + currentUrl: window.location.href + }; + + function truncateText(text, limit) { + var value = text == null ? 
'' : String(text); + if (value.length <= limit) { + return value; + } + return value.substring(0, limit) + '...[truncated]'; + } + + function safeTrim(text) { + return String(text || '').replace(/\s+/g, ' ').replace(/^\s+|\s+$/g, ''); + } + + function cloneSimple(value) { + if (value == null || typeof value !== 'object') { + return value; + } + return JSON.parse(JSON.stringify(value)); + } + function normalizeHeaders(headers) { var result = {}; var key; if (!headers) { return result; } - if (typeof Headers !== 'undefined' && headers instanceof Headers) { + if (typeof Headers !== 'undefined' && headers instanceof Headers && headers.forEach) { headers.forEach(function(value, name) { result[name] = value; }); @@ -53,7 +79,9 @@ } function getHeader(headers, name) { - if (!headers) return ''; + if (!headers) { + return ''; + } return headers[name] || headers[name.toLowerCase()] || ''; } @@ -61,54 +89,198 @@ return /\.[a-z0-9]{1,8}$/i.test(pathname || ''); } - function shouldCapture(url, method, headers) { - var u; - var m = (method || 'GET').toUpperCase(); - var normalizedHeaders = normalizeHeaders(headers); - var accept = ''; - var contentType = ''; - var looksJson = false; - var isIgnored = false; - var isIncluded = false; - + function normalizeUrl(url) { + var parsed; + var query = {}; + var queryKeys = []; try { - u = new URL(url, window.location.href); - } catch (e) { - return false; + parsed = new URL(url, window.location.href); + parsed.searchParams.forEach(function(value, key) { + if (!Object.prototype.hasOwnProperty.call(query, key)) { + queryKeys.push(key); + } + query[key] = value; + }); + return { + normalizedUrl: parsed.href, + origin: parsed.origin, + pathname: parsed.pathname, + query: query, + queryKeys: queryKeys + }; + } catch (error) { + return { + normalizedUrl: String(url || ''), + origin: '', + pathname: '', + query: query, + queryKeys: queryKeys + }; } + } + + function inferShape(value, path, out, depth) { + var currentPath = path || '$'; + var 
nextDepth = depth || 0; + var keys; + var i; - if (CONFIG.sameOriginOnly && u.origin !== window.location.origin) { - return false; + if (nextDepth > 4) { + out[currentPath] = 'depthLimit'; + return; } + if (value === null) { + out[currentPath] = 'null'; + return; + } + if (typeof value === 'undefined') { + out[currentPath] = 'undefined'; + return; + } + if (Array.isArray(value)) { + out[currentPath] = 'array(' + value.length + ')'; + if (value.length > 0) { + inferShape(value[0], currentPath + '[]', out, nextDepth + 1); + } + return; + } + if (typeof value === 'object') { + out[currentPath] = 'object'; + keys = Object.keys(value); + for (i = 0; i < keys.length && i < 20; i++) { + inferShape(value[keys[i]], currentPath + '.' + keys[i], out, nextDepth + 1); + } + return; + } + out[currentPath] = typeof value; + } - isIgnored = CONFIG.ignorePatterns.some(function(p) { return p.test(u.href); }); - if (isIgnored) { - return false; + function detectGraphQL(payload) { + var text; + var parsed; + var operationType = ''; + if (!payload) { + return null; + } + text = typeof payload === 'string' ? payload : ''; + try { + parsed = typeof payload === 'string' ? JSON.parse(payload) : payload; + } catch (error) { + parsed = null; } + if (!parsed || typeof parsed !== 'object') { + return null; + } + if (!parsed.query) { + return null; + } + if (/mutation\s/i.test(parsed.query)) { + operationType = 'mutation'; + } else if (/query\s/i.test(parsed.query)) { + operationType = 'query'; + } else { + operationType = 'graphql'; + } + return { + operationName: parsed.operationName || '', + operationType: operationType, + variablesShape: parsed.variables && typeof parsed.variables === 'object' + ? 
(function() { + var shape = {}; + inferShape(parsed.variables, '$', shape, 0); + return shape; + })() + : {} + }; + } + + function summarizeBody(body) { + var result = { + kind: 'empty', + display: '', + parsed: null, + shape: {}, + graphql: null + }; + var asObject = {}; - isIncluded = CONFIG.includePatterns.some(function(p) { return p.test(u.href); }); - if (isIncluded) { - return true; + if (body == null || body === '') { + return result; } - if (CONFIG.captureMode === 'all') { - return true; + if (typeof URLSearchParams !== 'undefined' && body instanceof URLSearchParams) { + body.forEach(function(value, key) { + asObject[key] = value; + }); + result.kind = 'urlencoded'; + result.parsed = asObject; + result.display = truncateText(JSON.stringify(asObject, null, 2), CONFIG.maxRequestBodyLength); + inferShape(asObject, '$', result.shape, 0); + return result; } - accept = getHeader(normalizedHeaders, 'accept'); - contentType = getHeader(normalizedHeaders, 'content-type'); - looksJson = /application\/json|text\/plain|application\/x-www-form-urlencoded/i - .test(accept + ' ' + contentType); + if (typeof FormData !== 'undefined' && body instanceof FormData) { + result.kind = 'formData'; + if (typeof body.forEach === 'function') { + body.forEach(function(value, key) { + asObject[key] = Object.prototype.toString.call(value) === '[object File]' ? 
'[file]' : String(value); + }); + } + result.parsed = asObject; + result.display = truncateText(JSON.stringify(asObject, null, 2), CONFIG.maxRequestBodyLength); + inferShape(asObject, '$', result.shape, 0); + return result; + } - if (m !== 'GET') { - return true; + if (typeof body === 'string') { + result.display = truncateText(body, CONFIG.maxRequestBodyLength); + try { + result.parsed = JSON.parse(body); + result.kind = 'json'; + result.display = truncateText(JSON.stringify(result.parsed, null, 2), CONFIG.maxRequestBodyLength); + inferShape(result.parsed, '$', result.shape, 0); + } catch (error) { + result.kind = 'text'; + result.graphql = detectGraphQL(body); + } + if (!result.graphql && result.parsed) { + result.graphql = detectGraphQL(result.parsed); + } + return result; } - if (!hasStaticExtension(u.pathname || '/')) { - return true; + if (typeof body === 'object') { + result.kind = 'object'; + result.parsed = body; + result.display = truncateText(JSON.stringify(body, null, 2), CONFIG.maxRequestBodyLength); + inferShape(body, '$', result.shape, 0); + result.graphql = detectGraphQL(body); + return result; } - return looksJson; + result.kind = typeof body; + result.display = truncateText(String(body), CONFIG.maxRequestBodyLength); + return result; + } + + function summarizeResponse(text) { + var result = { + display: '', + parsed: null, + shape: {} + }; + if (!text) { + return result; + } + try { + result.parsed = JSON.parse(text); + result.display = truncateText(JSON.stringify(result.parsed, null, 2), CONFIG.maxResponseLength); + inferShape(result.parsed, '$', result.shape, 0); + return result; + } catch (error) { + result.display = truncateText(text, CONFIG.maxResponseLength); + return result; + } } function getPageContext() { @@ -120,19 +292,212 @@ }; } - // Hook XMLHttpRequest + function describeElement(target) { + var tag = target && target.tagName ? String(target.tagName).toUpperCase() : 'UNKNOWN'; + var text = safeTrim(target && target.textContent ? 
target.textContent : ''); + var label = text || safeTrim(target && target.value ? target.value : ''); + if (!label && target && typeof target.getAttribute === 'function') { + label = safeTrim( + target.getAttribute('aria-label') || + target.getAttribute('title') || + target.getAttribute('name') || + target.getAttribute('placeholder') || + '' + ); + } + if (!label) { + label = (target && target.id) || (target && target.className) || tag; + } + return { + action: label, + tagName: tag, + id: target && target.id ? String(target.id) : '', + className: target && target.className ? String(target.className) : '' + }; + } + + function pushRecentAction(action) { + recentActions.push(action); + if (recentActions.length > CONFIG.maxRecentActions) { + recentActions.shift(); + } + } + + function recordAction(type, detail) { + pushRecentAction({ + type: type, + detail: detail || {}, + action: detail && detail.action ? detail.action : '', + url: window.location.href, + timestamp: new Date().toISOString() + }); + } + + function snapshotActionContext() { + return { + lastAction: recentActions.length ? cloneSimple(recentActions[recentActions.length - 1]) : null, + recentActions: cloneSimple(recentActions), + navigation: cloneSimple(navigationState) + }; + } + + function installActionListeners() { + if (document && document.addEventListener) { + document.addEventListener('click', function(event) { + recordAction('click', describeElement(event && event.target)); + }, true); + document.addEventListener('input', function(event) { + recordAction('input', describeElement(event && event.target)); + }, true); + document.addEventListener('change', function(event) { + recordAction('change', describeElement(event && event.target)); + }, true); + document.addEventListener('submit', function(event) { + recordAction('submit', describeElement(event && event.target)); + }, true); + document.addEventListener('keydown', function(event) { + recordAction('keydown', { + action: event && event.key ? 
String(event.key) : 'keydown' + }); + }, true); + } + + if (window && window.addEventListener) { + window.addEventListener('popstate', function() { + navigationState.lastNavigation = { + type: 'popstate', + url: window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('popstate', { action: window.location.href }); + }); + } + + if (window.history && window.history.pushState) { + var originalPushState = window.history.pushState; + window.history.pushState = function() { + var result = originalPushState.apply(this, arguments); + navigationState.lastNavigation = { + type: 'pushState', + url: arguments.length >= 3 ? String(arguments[2]) : window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('pushState', { action: navigationState.lastNavigation.url }); + return result; + }; + } + + if (window.history && window.history.replaceState) { + var originalReplaceState = window.history.replaceState; + window.history.replaceState = function() { + var result = originalReplaceState.apply(this, arguments); + navigationState.lastNavigation = { + type: 'replaceState', + url: arguments.length >= 3 ? 
String(arguments[2]) : window.location.href, + timestamp: new Date().toISOString() + }; + navigationState.currentUrl = window.location.href; + recordAction('replaceState', { action: navigationState.lastNavigation.url }); + return result; + }; + } + } + + function getCaptureDecision(url, method, headers) { + var m = (method || 'GET').toUpperCase(); + var normalizedHeaders = normalizeHeaders(headers); + var accept = getHeader(normalizedHeaders, 'accept'); + var contentType = getHeader(normalizedHeaders, 'content-type'); + var looksJson = /application\/json|text\/plain|application\/x-www-form-urlencoded/i + .test(accept + ' ' + contentType); + var urlInfo = normalizeUrl(url); + var i; + + if (CONFIG.sameOriginOnly && urlInfo.origin && urlInfo.origin !== window.location.origin) { + return { capture: false, reason: 'crossOrigin', urlInfo: urlInfo }; + } + + for (i = 0; i < CONFIG.ignorePatterns.length; i++) { + if (CONFIG.ignorePatterns[i].test(urlInfo.normalizedUrl)) { + return { capture: false, reason: 'ignorePattern', urlInfo: urlInfo }; + } + } + + for (i = 0; i < CONFIG.includePatterns.length; i++) { + if (CONFIG.includePatterns[i].test(urlInfo.normalizedUrl)) { + return { capture: true, reason: 'includePattern', urlInfo: urlInfo }; + } + } + + if (CONFIG.captureMode === 'all') { + return { capture: true, reason: 'captureModeAll', urlInfo: urlInfo }; + } + + if (m !== 'GET') { + return { capture: true, reason: 'nonGet', urlInfo: urlInfo }; + } + + if (!hasStaticExtension(urlInfo.pathname || '/')) { + return { capture: true, reason: 'nonStaticPath', urlInfo: urlInfo }; + } + + if (looksJson) { + return { capture: true, reason: 'jsonLike', urlInfo: urlInfo }; + } + + return { capture: false, reason: 'filteredOut', urlInfo: urlInfo }; + } + + function buildCaptureRecord(base) { + var requestBody = summarizeBody(base.requestBody); + var responseBody = summarizeResponse(base.responseText); + var requestContentType = getHeader(base.requestHeaders, 'content-type'); + var 
responseContentType = base.responseContentType || ''; + var actionContext = snapshotActionContext(); + return { + captureSource: 'pageHook', + type: base.type, + method: base.method, + url: base.url, + normalizedUrl: base.urlInfo.normalizedUrl, + origin: base.urlInfo.origin, + pathname: base.urlInfo.pathname, + query: base.urlInfo.query, + queryKeys: base.urlInfo.queryKeys, + status: base.status, + requestHeaders: base.requestHeaders, + requestBody: requestBody.display, + requestBodyKind: requestBody.kind, + requestShape: requestBody.shape, + requestContentType: requestContentType, + graphql: requestBody.graphql, + response: responseBody.display, + responseShape: responseBody.shape, + responseContentType: responseContentType, + pageContext: base.pageContext, + actionContext: actionContext, + captureReason: base.captureReason, + duration: base.duration, + timestamp: new Date().toISOString() + }; + } + + installActionListeners(); + var originalXHROpen = XMLHttpRequest.prototype.open; var originalXHRSend = XMLHttpRequest.prototype.send; var originalXHRSetHeader = XMLHttpRequest.prototype.setRequestHeader; XMLHttpRequest.prototype.open = function(method, url) { this._capture = { - method: method.toUpperCase(), + method: (method || 'GET').toUpperCase(), url: typeof url === 'string' ? url : String(url), startTime: Date.now(), - headers: {} + headers: {}, + pageContext: getPageContext() }; - this._pageContext = getPageContext(); return originalXHROpen.apply(this, arguments); }; @@ -144,164 +509,173 @@ }; XMLHttpRequest.prototype.send = function(body) { - var self = this; - if (!this._capture || !shouldCapture(this._capture.url, this._capture.method, this._capture.headers)) { + var capture = this._capture; + var decision = capture ? 
getCaptureDecision(capture.url, capture.method, capture.headers) : null; + if (!capture || !decision || !decision.capture) { return originalXHRSend.apply(this, arguments); } - var capture = this._capture; - var pageContext = this._pageContext; capture.requestBody = body; - var requestBodyDisplay = ''; - if (body) { - try { - var parsed = JSON.parse(body); - requestBodyDisplay = JSON.stringify(parsed, null, 2).substring(0, 2000); - } catch (e) { - requestBodyDisplay = String(body).substring(0, 1000); - } - } - this.addEventListener('load', function() { - var responseDisplay = ''; - try { - var parsed = JSON.parse(this.responseText); - responseDisplay = JSON.stringify(parsed, null, 2).substring(0, CONFIG.maxResponseLength); - } catch (e) { - responseDisplay = this.responseText.substring(0, 2000); - } - - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'XHR', method: capture.method, url: capture.url, + urlInfo: decision.urlInfo, status: this.status, - requestHeaders: capture.headers, - requestBody: requestBodyDisplay, - response: responseDisplay, - pageContext: pageContext, - duration: Date.now() - capture.startTime, - timestamp: new Date().toISOString() + requestHeaders: normalizeHeaders(capture.headers), + requestBody: capture.requestBody, + responseText: this.responseText || '', + responseContentType: typeof this.getResponseHeader === 'function' + ? (this.getResponseHeader('Content-Type') || '') + : '', + pageContext: capture.pageContext, + captureReason: decision.reason, + duration: Date.now() - capture.startTime }); - console.log('[API Capture] XHR:', capture.method, capture.url, '->', this.status); + window.__capturedRequests.push(record); + console.log( + '[API Capture] XHR:', + capture.method, + record.normalizedUrl, + '->', + this.status, + 'action=' + (record.actionContext.lastAction ? 
record.actionContext.lastAction.action : 'none') + ); }); this.addEventListener('error', function() { - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'XHR', method: capture.method, url: capture.url, + urlInfo: decision.urlInfo, status: 'error', - requestHeaders: capture.headers, - requestBody: requestBodyDisplay, - error: 'Network error', - pageContext: pageContext, - duration: Date.now() - capture.startTime, - timestamp: new Date().toISOString() + requestHeaders: normalizeHeaders(capture.headers), + requestBody: capture.requestBody, + responseText: '', + responseContentType: '', + pageContext: capture.pageContext, + captureReason: decision.reason, + duration: Date.now() - capture.startTime }); + record.error = 'Network error'; + window.__capturedRequests.push(record); }); return originalXHRSend.apply(this, arguments); }; - // Hook Fetch var originalFetch = window.fetch; window.fetch = function(url, options) { options = options || {}; var startTime = Date.now(); var method = (options.method || 'GET').toUpperCase(); - var urlStr = typeof url === 'string' ? url : (url.url || String(url)); - var requestHeaders = normalizeHeaders(options.headers || {}); + var urlStr = typeof url === 'string' ? url : (url && url.url ? 
url.url : String(url));
+      var requestHeaders = normalizeHeaders(options.headers || {}); var decision = getCaptureDecision(urlStr, method, requestHeaders);
 
-    if (!shouldCapture(urlStr, method, requestHeaders)) {
+    if (!decision.capture) {
       return originalFetch.apply(this, arguments);
     }
 
-    var pageContext = getPageContext();
-    var requestBodyDisplay = '';
-    if (options.body) {
-      try {
-        if (typeof options.body === 'string') {
-          requestBodyDisplay = options.body.substring(0, 2000);
-        } else {
-          requestBodyDisplay = JSON.stringify(options.body).substring(0, 2000);
-        }
-      } catch (e) {
-        requestBodyDisplay = '[body unreadable]';
-      }
-    }
-
     return originalFetch.apply(this, arguments).then(function(response) {
       var cloned = response.clone();
-
       return cloned.text().then(function(text) {
-        var responseBody = '';
-        try {
-          var parsed = JSON.parse(text);
-          responseBody = JSON.stringify(parsed, null, 2).substring(0, CONFIG.maxResponseLength);
-        } catch (e) {
-          responseBody = text.substring(0, 2000);
-        }
-
-        window.__capturedRequests.push({
+        var record = buildCaptureRecord({
           type: 'Fetch',
           method: method,
           url: urlStr,
+          urlInfo: decision.urlInfo,
           status: response.status,
           requestHeaders: requestHeaders,
-          requestBody: requestBodyDisplay,
-          response: responseBody,
-          pageContext: pageContext,
-          duration: Date.now() - startTime,
-          timestamp: new Date().toISOString()
+          requestBody: options.body,
+          responseText: text || '',
+          responseContentType: response.headers && typeof response.headers.get === 'function'
+            ? (response.headers.get('content-type') || '')
+            : '',
+          pageContext: getPageContext(),
+          captureReason: decision.reason,
+          duration: Date.now() - startTime
         });
-        console.log('[API Capture] Fetch:', method, urlStr, '->', response.status);
+        window.__capturedRequests.push(record);
+        console.log(
+          '[API Capture] Fetch:',
+          method,
+          record.normalizedUrl,
+          '->',
+          response.status,
+          'action=' + (record.actionContext.lastAction ? 
record.actionContext.lastAction.action : 'none') + ); return response; }); }).catch(function(error) { - window.__capturedRequests.push({ + var record = buildCaptureRecord({ type: 'Fetch', method: method, url: urlStr, + urlInfo: decision.urlInfo, status: 'error', requestHeaders: requestHeaders, - requestBody: requestBodyDisplay, - error: error.message, - pageContext: pageContext, - duration: Date.now() - startTime, - timestamp: new Date().toISOString() + requestBody: options.body, + responseText: '', + responseContentType: '', + pageContext: getPageContext(), + captureReason: decision.reason, + duration: Date.now() - startTime }); + record.error = error && error.message ? error.message : String(error); + window.__capturedRequests.push(record); throw error; }); }; window.__apiCapture = { - version: '3.0-simple', + version: 'web2cli-base', installed: new Date().toISOString(), + config: CONFIG, getAll: function() { return window.__capturedRequests; }, clear: function() { window.__capturedRequests = []; + recentActions = []; console.log('[API Capture] Cleared'); }, + getRecentActions: function() { + return cloneSimple(recentActions); + }, + getDebugState: function() { + return { + version: this.version, + installed: this.installed, + config: cloneSimple(CONFIG), + requestCount: window.__capturedRequests.length, + recentActions: cloneSimple(recentActions), + navigation: cloneSimple(navigationState), + lastRequest: window.__capturedRequests.length + ? 
cloneSimple(window.__capturedRequests[window.__capturedRequests.length - 1]) + : null + }; + }, summary: function() { - console.log('=== API Capture Summary ==='); - console.log('Total requests:', window.__capturedRequests.length); var groups = {}; - window.__capturedRequests.forEach(function(r) { - var path = r.url.split('?')[0]; - groups[path] = (groups[path] || 0) + 1; + window.__capturedRequests.forEach(function(record) { + groups[record.pathname] = (groups[record.pathname] || 0) + 1; }); + console.log('=== API Capture Summary ==='); + console.log('Total requests:', window.__capturedRequests.length); console.log('Endpoints:', Object.keys(groups)); + console.log('Recent actions:', recentActions.length); + console.log('window.__apiCapture.getDebugState() - inspect capture internals'); } }; - console.log('[API Capture] v3.0-simple installed'); + console.log('[API Capture] web2cli-base installed'); console.log(' window.__capturedRequests - captured data'); - console.log(' window.__apiCapture.summary() - show summary'); + console.log(' window.__apiCapture.getRecentActions() - recent user interactions'); + console.log(' window.__apiCapture.getDebugState() - current capture state'); })(); \ No newline at end of file diff --git a/tests/tool/test_web2cli_generate_cli.py b/tests/tool/test_web2cli_generate_cli.py index d21f80c6..644ab2b8 100644 --- a/tests/tool/test_web2cli_generate_cli.py +++ b/tests/tool/test_web2cli_generate_cli.py @@ -190,3 +190,143 @@ def test_generated_client_still_supports_plain_cookie_list(tmp_path, monkeypatch {"name": "sid", "value": "cookie-123"}, {"name": "api", "value": "cookie-456", "path": "/"}, ] + + +def _sample_spec(): + return { + "schemaVersion": "1.0", + "site": "example", + "command": "list_items", + "description": "List items from example API", + "baseUrl": "https://example.com", + "strategy": "COOKIE", + "auth": {"stateFile": "auth-state.json", "requiredCookies": [], "requiredHeaders": []}, + "operation": { + "method": "POST", + 
"endpoint": "/api/items/list", + "queryTemplate": {}, + "bodyTemplate": {"page": "${page}", "size": "${limit}"}, + "headers": {"Content-Type": "application/json"}, + }, + "rowSource": {"path": "$.data.items[]", "collectionPath": "$.data.items[]"}, + "args": [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "limit", "type": "int", "default": 20, "help": "Page size"}, + ], + "columns": [ + {"name": "id", "path": "$.data.items[].id", "relativePath": "id", "type": "string"}, + {"name": "title", "path": "$.data.items[].title", "relativePath": "title", "type": "string"}, + ], + "verify": { + "args": {"page": 1, "limit": 20}, + "rowCount": {"min": 1, "max": 2}, + "columns": ["id", "title"], + "types": {"id": "string", "title": "string"}, + "notEmpty": ["id", "title"], + "patterns": {}, + }, + } + + +class _FakeResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _FakeRequestSession(_FakeSession): + def __init__(self, payload) -> None: + super().__init__() + self._payload = payload + self.request_calls = [] + + def request(self, method, url, json=None, params=None): + self.request_calls.append({"method": method, "url": url, "json": json, "params": params}) + return _FakeResponse(self._payload) + + +def test_generate_verify_materials_from_spec_uses_spec_contract(): + module = _load_module() + + verify = module.generate_verify_materials_from_spec(_sample_spec()) + + assert verify["site"] == "example" + assert verify["command"] == "list_items" + assert verify["expect"]["columns"] == ["id", "title"] + assert verify["expect"]["rowCount"]["max"] == 2 + + +def test_generate_python_cli_from_spec_supports_argparse_and_verify(): + module = _load_module() + + output = module.generate_python_cli_from_spec(_sample_spec()) + + assert 'parser.add_argument("--format", choices=["json", "csv", "table"]' in output + assert 
'parser.add_argument("--verify", action="store_true"' in output + assert 'SPEC = {' in output + assert 'def verify_rows(rows: List[Dict[str, Any]], verify_spec: Dict[str, Any])' in output + + +def test_generated_spec_cli_executes_request_and_projects_rows(tmp_path, monkeypatch): + module = _load_module() + auth_state = tmp_path / "auth-state.json" + auth_state.write_text( + json.dumps({"cookies": [{"name": "sid", "value": "cookie-123", "domain": ".example.com", "path": "/"}]}), + encoding="utf-8", + ) + + fake_session = _FakeRequestSession( + {"data": {"items": [{"id": "1", "title": "Alpha"}, {"id": "2", "title": "Beta"}]}} + ) + fake_requests = types.SimpleNamespace(Session=lambda: fake_session) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + namespace = {"__name__": "generated_spec_cli"} + exec(module.generate_python_cli_from_spec(_sample_spec()), namespace) + + client = namespace["APIClient"](auth_state=str(auth_state)) + rows = client.run({"page": 3, "limit": 5}) + errors = namespace["verify_rows"](rows, {"expect": _sample_spec()["verify"]}) + + assert rows == [{"id": "1", "title": "Alpha"}, {"id": "2", "title": "Beta"}] + assert errors == [] + assert fake_session.request_calls == [ + { + "method": "POST", + "url": "https://example.com/api/items/list", + "json": {"page": 3, "size": 5}, + "params": None, + } + ] + assert fake_session.cookies.set_calls == [ + {"name": "sid", "value": "cookie-123", "domain": ".example.com", "path": "/"} + ] + + +def test_main_supports_spec_verify_output(tmp_path, monkeypatch, capsys): + module = _load_module() + spec_path = tmp_path / "web2cli-spec.json" + spec_path.write_text(json.dumps(_sample_spec()), encoding="utf-8") + monkeypatch.setattr( + sys, + "argv", + [ + "generate-cli.py", + "--spec", + str(spec_path), + "--format", + "verify", + ], + ) + + module.main() + captured = capsys.readouterr() + payload = json.loads(captured.out) + + assert payload["site"] == "example" + assert 
payload["expect"]["types"]["id"] == "string" diff --git a/tests/tool/test_web2cli_generate_spec.py b/tests/tool/test_web2cli_generate_spec.py new file mode 100644 index 00000000..4e4f2309 --- /dev/null +++ b/tests/tool/test_web2cli_generate_spec.py @@ -0,0 +1,106 @@ +import importlib.util +import json +import sys +from pathlib import Path + + +SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] + / ".flocks" + / "plugins" + / "skills" + / "web2cli" + / "scripts" + / "generate-spec.py" +) + + +def _load_module(): + spec = importlib.util.spec_from_file_location("web2cli_generate_spec", SCRIPT_PATH) + assert spec is not None + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _sample_requests(): + return [ + { + "type": "XHR", + "method": "POST", + "url": "https://example.com/api/ignore", + "status": 200, + "response": '{"ok": true}', + "requestHeaders": {"Content-Type": "application/json"}, + }, + { + "type": "Fetch", + "method": "POST", + "url": "https://example.com/api/items/list?page=1", + "normalizedUrl": "https://example.com/api/items/list?page=1", + "origin": "https://example.com", + "pathname": "/api/items/list", + "query": {"page": "1"}, + "queryKeys": ["page"], + "status": 200, + "captureReason": "nonGet", + "actionContext": {"lastAction": {"action": "Load data"}}, + "requestHeaders": { + "Content-Type": "application/json", + "Cookie": "sid=cookie-123", + "X-Requested-With": "XMLHttpRequest", + }, + "requestBody": '{"page": 1, "size": 20}', + "response": '{"data":{"items":[{"id":"1","title":"Alpha","count":2},{"id":"2","title":"Beta","count":3}]}}', + }, + ] + + +def test_generate_spec_from_requests_picks_primary_collection_endpoint(): + module = _load_module() + + spec = module.generate_spec_from_requests(_sample_requests()) + + assert spec["site"] == "example" + assert spec["command"] == "list" + assert spec["strategy"] == "COOKIE" + assert 
spec["operation"]["endpoint"] == "/api/items/list" + assert spec["operation"]["bodyTemplate"] == {"page": "${page}", "size": "${limit}"} + assert spec["args"] == [ + {"name": "page", "type": "int", "default": 1, "help": "Page number"}, + {"name": "limit", "type": "int", "default": 20, "help": "Page size"}, + ] + assert spec["rowSource"]["collectionPath"] == "$.data.items[]" + assert spec["columns"][:2] == [ + {"name": "id", "path": "$.id", "relativePath": "id", "sourceField": "id", "type": "string"}, + {"name": "title", "path": "$.title", "relativePath": "title", "sourceField": "title", "type": "string"}, + ] + + +def test_main_writes_spec_file(tmp_path, monkeypatch, capsys): + module = _load_module() + input_path = tmp_path / "captured.json" + output_path = tmp_path / "web2cli-spec.json" + input_path.write_text(json.dumps(_sample_requests()), encoding="utf-8") + + monkeypatch.setattr( + sys, + "argv", + [ + "generate-spec.py", + str(input_path), + "--output", + str(output_path), + ], + ) + + module.main() + + payload = json.loads(output_path.read_text(encoding="utf-8")) + captured = capsys.readouterr() + + assert payload["verify"]["columns"][:2] == ["id", "title"] + assert payload["verify"]["rowCount"]["max"] == 2 + assert f"Written to {output_path}" in captured.out diff --git a/tests/tool/test_web2cli_hook_base.py b/tests/tool/test_web2cli_hook_base.py index 701524e3..46f73096 100644 --- a/tests/tool/test_web2cli_hook_base.py +++ b/tests/tool/test_web2cli_hook_base.py @@ -187,8 +187,12 @@ def test_hook_base_captures_recent_action_context_for_xhr(): """ ) - assert result["version"] == "3.1-base" + assert result["version"] == "web2cli-base" assert result["request"]["pageContext"]["path"] == "/dashboard" + assert result["request"]["normalizedUrl"] == "https://example.com/api/items/list" + assert result["request"]["pathname"] == "/api/items/list" + assert result["request"]["captureReason"] == "nonGet" + assert result["request"]["requestShape"]["$.page"] == "number" 
assert result["request"]["actionContext"]["lastAction"]["action"] == "Load data" assert result["recentActions"][0]["type"] == "click" assert any("action=Load data" in line for line in result["logs"]) @@ -216,4 +220,6 @@ def test_hook_base_exposes_debug_state_and_truncates_large_responses(): assert result["response"].endswith("...[truncated]") assert any(action["type"] == "pushState" for action in result["debugState"]["recentActions"]) + assert result["debugState"]["lastRequest"]["response"] == result["response"] + assert result["debugState"]["lastRequest"]["pathname"] == "/api/debug" assert any("window.__apiCapture.getDebugState()" in line for line in result["logs"]) From 12e27975c62974227656e59f2f62dd33f9b75f43 Mon Sep 17 00:00:00 2001 From: xiami Date: Fri, 8 May 2026 15:51:09 +0800 Subject: [PATCH 02/27] feat(update): CN mirror prompt and uv sync timeout handling (#233) - Prompt for China mirror before version check when region is unset - Refactor version check into _load_update_info after mirror choice - Use longer uv sync timeout on Windows; catch TimeoutExpired and rollback - Reduce default service log tail lines for failed daemon startups - Add tests for mirror prompt and Windows sync timeout rollback --- flocks/cli/commands/update.py | 22 +++++++--- flocks/cli/service_manager.py | 2 +- flocks/updater/updater.py | 70 +++++++++++++++++++++++++------ tests/cli/test_update_command.py | 72 ++++++++++++++++++++++++++++++++ tests/updater/test_updater.py | 63 ++++++++++++++++++++++++++++ 5 files changed, 210 insertions(+), 19 deletions(-) diff --git a/flocks/cli/commands/update.py b/flocks/cli/commands/update.py index 46448304..bd59cd69 100644 --- a/flocks/cli/commands/update.py +++ b/flocks/cli/commands/update.py @@ -37,13 +37,23 @@ def update_command( async def _update(check: bool, yes: bool, force: bool = False, region: str | None = None) -> None: from flocks.updater import check_update, perform_update, detect_deploy_mode - with 
console.status("[cyan]正在检查版本...[/cyan]", spinner="dots"): - info = await check_update(region=region) + if not yes and not check and region is None: + use_cn_mirror = typer.confirm("\n是否使用中国镜像进行升级?", default=False) + if use_cn_mirror: + region = "cn" + + async def _load_update_info(selected_region: str | None): + with console.status("[cyan]正在检查版本...[/cyan]", spinner="dots"): + info = await check_update(region=selected_region) + + if info.error: + append_upgrade_text_log(f"ERROR version_check: {info.error}") + console.print(f"[red]检查失败:{info.error}[/red]") + raise typer.Exit(1) - if info.error: - append_upgrade_text_log(f"ERROR version_check: {info.error}") - console.print(f"[red]检查失败:{info.error}[/red]") - raise typer.Exit(1) + return info + + info = await _load_update_info(region) _print_version_table(info) diff --git a/flocks/cli/service_manager.py b/flocks/cli/service_manager.py index 9f0f3a0a..5bf2a234 100644 --- a/flocks/cli/service_manager.py +++ b/flocks/cli/service_manager.py @@ -1431,7 +1431,7 @@ def tail_lines(path: Path, lines: int) -> list[str]: return [line.rstrip("\n") for line in deque(handle, maxlen=max(lines, 0))] -def _emit_service_log_tail(console, log_path: Path, service_label: str, lines: int = 40) -> None: +def _emit_service_log_tail(console, log_path: Path, service_label: str, lines: int = 10) -> None: """Print the last *lines* lines of *log_path* to help diagnose failed daemon startups.""" if lines <= 0: return diff --git a/flocks/updater/updater.py b/flocks/updater/updater.py index 2b72723f..4703dab2 100644 --- a/flocks/updater/updater.py +++ b/flocks/updater/updater.py @@ -51,6 +51,8 @@ _CURL_USER_AGENT = "curl/8.7.1" _FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS = 300 _FRONTEND_BUILD_TIMEOUT_SECONDS = 300 +_DEPENDENCY_SYNC_TIMEOUT_SECONDS = 180 +_WINDOWS_DEPENDENCY_SYNC_TIMEOUT_SECONDS = 300 _PRESERVE_NAMES: set[str] = { ".venv", @@ -364,6 +366,13 @@ def _build_frontend_subprocess_env(*, npm_registry: str | None = None) -> dict[s return env 
or None +def _dependency_sync_timeout_seconds() -> int: + """Return the timeout budget for ``uv sync`` during self-update.""" + if sys.platform == "win32": + return _WINDOWS_DEPENDENCY_SYNC_TIMEOUT_SECONDS + return _DEPENDENCY_SYNC_TIMEOUT_SECONDS + + # ------------------------------------------------------------------ # # Async subprocess helpers # ------------------------------------------------------------------ # @@ -2150,10 +2159,29 @@ async def _restore_after_apply_failure() -> None: uv_cmd.extend(["--default-index", profile.uv_default_index]) sync_env = _build_uv_sync_env() + sync_timeout = _dependency_sync_timeout_seconds() retried_after_managed_python_repair = False - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + + async def _run_uv_sync(cmd: list[str]) -> tuple[int, str, str]: + return await _run_async( + cmd, + cwd=install_root, + timeout=sync_timeout, + env=sync_env, + ) + + def _dependency_sync_timeout_message() -> str: + return f"Dependency sync timed out after {sync_timeout}s while running uv sync." 
+ + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if ( code != 0 and sys.platform == "win32" @@ -2173,9 +2201,15 @@ async def _restore_after_apply_failure() -> None: {"error": err}, ) await asyncio.sleep(2) - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0 and profile.uv_default_index: log.warning( "updater.dependencies.sync_retry_default_index", @@ -2186,15 +2220,27 @@ async def _restore_after_apply_failure() -> None: ) await asyncio.sleep(3) uv_cmd = [uv_path, "sync"] - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0: log.warning("updater.dependencies.sync_retry", {"first_error": err}) await asyncio.sleep(3) - code, _, err = await _run_async( - uv_cmd, cwd=install_root, timeout=180, env=sync_env, - ) + try: + code, _, err = await _run_uv_sync(uv_cmd) + except subprocess.TimeoutExpired: + shutil.rmtree(tmp_dir, 
ignore_errors=True) + await _restore_after_apply_failure() + timeout_message = _dependency_sync_timeout_message() + _record_update_journal(f"ERROR {timeout_message}") + yield UpdateProgress(stage="error", message=timeout_message, success=False) + return if code != 0: shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/tests/cli/test_update_command.py b/tests/cli/test_update_command.py index c3542448..9786405f 100644 --- a/tests/cli/test_update_command.py +++ b/tests/cli/test_update_command.py @@ -35,6 +35,78 @@ async def fake_update(*, check: bool, yes: bool, force: bool, region: str | None assert captured == {"check": False, "yes": True, "force": True, "region": "cn"} +def test_update_prompts_for_cn_mirror_before_upgrade_confirmation(monkeypatch) -> None: + output = StringIO() + monkeypatch.setattr( + update_cmd, + "console", + Console(file=output, force_terminal=False, color_system=None, width=120), + ) + + check_regions: list[str | None] = [] + confirm_prompts: list[str] = [] + captured: dict[str, object] = {} + answers = iter([True, True]) + + async def fake_check_update(*, locale: str | None = None, region: str | None = None) -> VersionInfo: + check_regions.append(region) + zipball_url = "https://example.com/flocks.zip" + tarball_url = "https://example.com/flocks.tar.gz" + if region == "cn": + zipball_url = "https://gitee.example.com/flocks.zip" + tarball_url = "https://gitee.example.com/flocks.tar.gz" + return VersionInfo( + current_version="2026.4.1", + latest_version="2026.4.2", + has_update=True, + zipball_url=zipball_url, + tarball_url=tarball_url, + deploy_mode="source", + update_allowed=True, + ) + + async def fake_perform_update( + latest_tag: str, + *, + zipball_url: str | None = None, + tarball_url: str | None = None, + restart: bool = True, + locale: str | None = None, + region: str | None = None, + ): + captured["latest_tag"] = latest_tag + captured["zipball_url"] = zipball_url + captured["tarball_url"] = tarball_url + 
captured["perform_region"] = region + captured["restart"] = restart + async for step in _fake_progress(): + yield step + + def fake_confirm(prompt: str, default: bool = False) -> bool: + confirm_prompts.append(prompt) + return next(answers) + + monkeypatch.setattr(updater_pkg, "check_update", fake_check_update) + monkeypatch.setattr(updater_pkg, "perform_update", fake_perform_update) + monkeypatch.setattr(updater_pkg, "detect_deploy_mode", lambda: "source") + monkeypatch.setattr(update_cmd.typer, "confirm", fake_confirm) + + import asyncio + + asyncio.run(update_cmd._update(check=False, yes=False, force=False, region=None)) + + assert check_regions == ["cn"] + assert confirm_prompts == ["\n是否使用中国镜像进行升级?", "\n是否立即升级?"] + assert captured == { + "latest_tag": "2026.4.2", + "zipball_url": "https://gitee.example.com/flocks.zip", + "tarball_url": "https://gitee.example.com/flocks.tar.gz", + "perform_region": "cn", + "restart": False, + } + assert "已切换为中国镜像源" not in output.getvalue() + + async def _fake_progress(): yield UpdateProgress(stage="fetching", message="fetching") yield UpdateProgress(stage="done", message="done", success=True) diff --git a/tests/updater/test_updater.py b/tests/updater/test_updater.py index 0672ad2e..43f7fbf5 100644 --- a/tests/updater/test_updater.py +++ b/tests/updater/test_updater.py @@ -1745,6 +1745,69 @@ async def fake_sleep(_s): assert call_count == 2 +@pytest.mark.asyncio +async def test_perform_update_rolls_back_when_windows_uv_sync_times_out( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + archive_path = tmp_path / "flocks.zip" + archive_path.write_text("archive", encoding="utf-8") + staged_root = tmp_path / "staged" + staged_webui = staged_root / "webui" + staged_webui.mkdir(parents=True) + (staged_webui / "package.json").write_text("{}", encoding="utf-8") + (staged_webui / "dist").mkdir() + (staged_webui / "dist" / "index.html").write_text("", encoding="utf-8") + + events: list[str] = [] + + async def 
fake_get_updater_config(): + return SimpleNamespace( + archive_format="zip", + sources=["github"], + repo="AgentFlocks/Flocks", + token=None, + gitee_token=None, + backup_retain_count=3, + base_url=None, + gitee_repo=None, + ) + + async def fake_download(**_kw): + return archive_path + + async def fake_run_async(cmd, cwd=None, timeout=None, env=None): + if "sync" in cmd: + raise subprocess.TimeoutExpired(cmd=cmd, timeout=timeout or 0) + return 0, "", "" + + monkeypatch.setattr(updater.sys, "platform", "win32") + monkeypatch.setattr(updater, "_get_updater_config", fake_get_updater_config) + monkeypatch.setattr(updater, "_get_repo_root", lambda: tmp_path / "install-root") + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.3.31") + monkeypatch.setattr(updater, "_download_with_fallback", fake_download) + monkeypatch.setattr(updater, "_backup_current_version", lambda *_a, **_kw: tmp_path / "backup.tar.gz") + monkeypatch.setattr(updater, "_extract_archive", lambda *_a, **_kw: staged_root) + monkeypatch.setattr(updater, "_run_async", fake_run_async) + monkeypatch.setattr( + updater, + "_find_executable", + lambda name: "/usr/bin/npm" if name in {"npm", "npm.cmd"} else r"C:\tools\uv.exe", + ) + monkeypatch.setattr(updater, "_build_uv_sync_env", lambda: None) + monkeypatch.setattr(updater, "_replace_install_dir", lambda *_a, **_kw: None) + monkeypatch.setattr(updater, "_restore_backup_if_possible", lambda *_a: events.append("restore")) + + progresses = [step async for step in updater.perform_update("2026.4.1", restart=False)] + + assert progresses[-1].stage == "error" + expected_timeout = updater._dependency_sync_timeout_seconds() + assert progresses[-1].message == ( + f"Dependency sync timed out after {expected_timeout}s while running uv sync." 
+ ) + assert events == ["restore"] + + @pytest.mark.asyncio async def test_perform_update_fails_after_uv_sync_retry_exhausted( monkeypatch: pytest.MonkeyPatch, From 180980272dc49a7d6fcf68acd08d665789b76101 Mon Sep 17 00:00:00 2001 From: JohnYin Date: Fri, 8 May 2026 16:52:19 +0800 Subject: [PATCH 03/27] Fix/azure custom deployment name (#234) * fix(provider): support custom Azure deployments Allow Azure OpenAI users to configure deployment names directly in the provider setup flow and cover the runtime/test-credentials path so custom deployments are usable. * fix(provider): show Azure custom deployments in settings Keep saved Azure deployment names visible when editing provider settings and separate catalog model counts from custom deployments. * fix(provider): address Azure deployment review feedback Avoid persisting Azure deployment names from connection tests, keep save feedback consistent, and tighten Azure-specific UI/test coverage. --- flocks/server/routes/provider.py | 10 +- tests/provider/test_azure_provider.py | 33 ++++ tests/provider/test_test_credentials.py | 80 +++++++++ .../routes/test_custom_provider_runtime.py | 33 ++++ webui/src/locales/en-US/model.json | 11 +- webui/src/locales/zh-CN/model.json | 11 +- webui/src/pages/Model/index.tsx | 169 +++++++++++++++++- 7 files changed, 335 insertions(+), 12 deletions(-) create mode 100644 tests/provider/test_azure_provider.py diff --git a/flocks/server/routes/provider.py b/flocks/server/routes/provider.py index bf13f016..ef00afcb 100644 --- a/flocks/server/routes/provider.py +++ b/flocks/server/routes/provider.py @@ -2320,9 +2320,15 @@ async def test_provider_credentials(provider_id: str, body: Optional[TestCredent requested_model_id = body.model_id if body else None test_model_id = requested_model_id or models[0].id - # Validate model belongs to this provider + # Validate model belongs to this provider. Azure OpenAI is the + # exception: users may test a deployment name before saving it. 
valid_ids = {m.id for m in models} - if test_model_id not in valid_ids: + is_unsaved_azure_deployment = ( + requested_model_id + and provider_id in {"azure-openai", "azure"} + and test_model_id not in valid_ids + ) + if test_model_id not in valid_ids and not is_unsaved_azure_deployment: response = { "success": False, "message": f"模型 '{test_model_id}' 不属于该 Provider", diff --git a/tests/provider/test_azure_provider.py b/tests/provider/test_azure_provider.py new file mode 100644 index 00000000..d23d426d --- /dev/null +++ b/tests/provider/test_azure_provider.py @@ -0,0 +1,33 @@ +from flocks.provider.provider import ModelCapabilities, ModelInfo +from flocks.provider.sdk.azure import AzureProvider + + +def test_azure_provider_returns_configured_deployment_models(): + provider = AzureProvider() + provider._config_models = [ + ModelInfo( + id="customer-prod-deployment", + name="Customer Production Deployment", + provider_id="azure", + capabilities=ModelCapabilities( + supports_tools=True, + supports_streaming=True, + context_window=128000, + max_tokens=4096, + ), + ) + ] + + models = provider.get_models() + + assert [m.id for m in models] == ["customer-prod-deployment"] + assert models[0].name == "Customer Production Deployment" + + +def test_azure_provider_returns_fallback_models_without_config(): + provider = AzureProvider() + + models = provider.get_models() + + assert {m.id for m in models} == {"gpt-5.4", "gpt-5-mini"} + assert all(m.provider_id == "azure" for m in models) diff --git a/tests/provider/test_test_credentials.py b/tests/provider/test_test_credentials.py index e01d5db7..aa7e7dd2 100644 --- a/tests/provider/test_test_credentials.py +++ b/tests/provider/test_test_credentials.py @@ -754,3 +754,83 @@ async def test_existing_custom_settings_are_preserved_during_provider_test(self) assert configured.api_key == "gateway-api-key" assert configured.base_url == "https://gateway.internal/v1" assert configured.custom_settings["verify_ssl"] is False + + 
@pytest.mark.asyncio + async def test_requested_azure_deployment_model_is_used_for_provider_test(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + model = MagicMock() + model.id = "customer-prod-deployment" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="customer-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "customer-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "customer-prod-deployment" + + @pytest.mark.asyncio + async def test_unsaved_azure_deployment_can_be_tested_without_model_definition(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + catalog_model = MagicMock() + catalog_model.id = "gpt-5.4" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + 
patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [catalog_model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="unsaved-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "unsaved-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "unsaved-prod-deployment" diff --git a/tests/server/routes/test_custom_provider_runtime.py b/tests/server/routes/test_custom_provider_runtime.py index 0e8baaee..4899ce6e 100644 --- a/tests/server/routes/test_custom_provider_runtime.py +++ b/tests/server/routes/test_custom_provider_runtime.py @@ -1,4 +1,5 @@ from flocks.provider.provider import ModelCapabilities, ModelInfo, Provider +from flocks.provider.sdk.azure import AzureProvider from flocks.server.routes.custom_provider import CreateModelReq, _add_model_to_runtime @@ -48,3 +49,35 @@ class DummyProvider: assert provider._config_models[0].capabilities.supports_reasoning is True finally: Provider._models = original_models + + +def test_add_azure_deployment_to_runtime_config_models(monkeypatch): + provider = AzureProvider() + provider.id = "azure-openai" + provider._config_models = [] + body = CreateModelReq( + model_id="customer-prod-deployment", + name="Customer Production Deployment", + context_window=128000, + max_output_tokens=4096, + supports_vision=False, + supports_tools=True, + supports_streaming=True, + supports_reasoning=False, + input_price=0.0, + output_price=0.0, + currency="USD", + ) + + original_models = Provider._models + Provider._models = {} + 
monkeypatch.setattr(Provider, "get", classmethod(lambda cls, provider_id: provider)) + + try: + _add_model_to_runtime("azure-openai", body) + + assert Provider._models[body.model_id].provider_id == "azure-openai" + assert provider._config_models[0].id == "customer-prod-deployment" + assert provider._config_models[0].name == "Customer Production Deployment" + finally: + Provider._models = original_models diff --git a/webui/src/locales/en-US/model.json b/webui/src/locales/en-US/model.json index c2480436..abdf84f6 100644 --- a/webui/src/locales/en-US/model.json +++ b/webui/src/locales/en-US/model.json @@ -119,7 +119,16 @@ "loadFailed": "Failed to load provider catalog", "noModelsToTest": "No enabled models to test", "batchTestDone": "Batch test complete", - "batchTestSummary": "{{success}} succeeded, {{failed}} failed" + "batchTestSummary": "{{success}} succeeded, {{failed}} failed", + "azureDeploymentName": "Azure Deployment Name", + "azureDeploymentPlaceholder": "e.g. my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI requests use the deployment name, not a fixed model name. The preset models are examples; enter your own deployment name here.", + "azureDeploymentDisplayName": "Display Name (optional)", + "azureDeploymentDisplayPlaceholder": "e.g. GPT-4o Production", + "azureDeploymentRequired": "Select at least one preset model or enter an Azure deployment name", + "azureModelIdHint": "For Azure OpenAI, Model ID should be the deployment name from Azure Portal.", + "azureCustomDeployments": "Custom Azure Deployments", + "azureNoCustomDeployments": "No custom Azure deployment has been added yet." 
}, "wizard": { "providerSaved": "Provider Saved", diff --git a/webui/src/locales/zh-CN/model.json b/webui/src/locales/zh-CN/model.json index 29cb71e4..768f69be 100644 --- a/webui/src/locales/zh-CN/model.json +++ b/webui/src/locales/zh-CN/model.json @@ -119,7 +119,16 @@ "loadFailed": "加载 Provider 目录失败", "noModelsToTest": "没有已启用的模型可测试", "batchTestDone": "批量测试完成", - "batchTestSummary": "{{success}} 成功, {{failed}} 失败" + "batchTestSummary": "{{success}} 成功, {{failed}} 失败", + "azureDeploymentName": "Azure 部署名称", + "azureDeploymentPlaceholder": "例如 my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI 请求使用 deployment name,而不是固定模型名。预设模型只是常用示例,你可以在这里填写自己的部署名称。", + "azureDeploymentDisplayName": "显示名称(可选)", + "azureDeploymentDisplayPlaceholder": "例如 GPT-4o Production", + "azureDeploymentRequired": "请至少选择一个预设模型,或填写 Azure deployment name", + "azureModelIdHint": "对于 Azure OpenAI,模型 ID 请填写 Azure Portal 中的 deployment name。", + "azureCustomDeployments": "自定义 Azure Deployments", + "azureNoCustomDeployments": "尚未添加自定义 Azure deployment。" }, "wizard": { "providerSaved": "Provider 已保存", diff --git a/webui/src/pages/Model/index.tsx b/webui/src/pages/Model/index.tsx index cb15baf3..eb59fb6b 100644 --- a/webui/src/pages/Model/index.tsx +++ b/webui/src/pages/Model/index.tsx @@ -55,6 +55,12 @@ function providerAllowsEmptyApiKey(providerId: string): boolean { ); } +const AZURE_PROVIDER_IDS = new Set(['azure-openai', 'azure']); + +function isAzureProviderId(providerId: string): boolean { + return AZURE_PROVIDER_IDS.has(providerId); +} + // ==================== Connection Cache ==================== const CONNECTION_CACHE_KEY = 'flocks_provider_connection_cache'; @@ -1088,6 +1094,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { const [baseUrl, setBaseUrl] = useState(''); const [description, setDescription] = useState(''); const [providerName, setProviderName] = useState(''); + const [azureDeploymentName, setAzureDeploymentName] = useState(''); + const 
[azureDeploymentDisplayName, setAzureDeploymentDisplayName] = useState(''); // Model selection (for catalog providers) const [selectedModelIds, setSelectedModelIds] = useState>(new Set()); @@ -1172,6 +1180,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { setDescription(provider.description || ''); setSelectedModelIds(new Set(provider.models.map(m => m.id))); setProviderName(''); + setAzureDeploymentName(''); + setAzureDeploymentDisplayName(''); } }; @@ -1212,7 +1222,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { base_url: baseUrl.trim() || undefined, provider_name: selectedCatalogId === 'openai-compatible' && providerName.trim() ? providerName.trim() : undefined, }); - const res = await providerAPI.testCredentials(selectedCatalogId); + const azureModelId = isAzureProviderId(selectedCatalogId) ? azureDeploymentName.trim() : ''; + const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId || undefined); setTestResult({ success: res.data.success, message: res.data.message || (res.data.success ? t('status.connected') : t('form.testFailed')), @@ -1235,6 +1246,11 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { toast.warning('Please enter API Key'); return; } + const azureModelId = isAzureProviderId(selectedCatalogId) ? 
azureDeploymentName.trim() : ''; + if (isAzureProviderId(selectedCatalogId) && selectedModelIds.size === 0 && !azureModelId) { + toast.warning(t('form.azureDeploymentRequired')); + return; + } try { setSaving(true); if (selectedCatalogId === 'openai-compatible') { @@ -1259,6 +1275,20 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { const unselected = selectedCatalog.models.filter(m => !selectedModelIds.has(m.id)).map(m => m.id); await Promise.all(unselected.map(id => modelV2API.deleteDefinition(selectedCatalogId, id).catch(() => {}))); } + if (azureModelId) { + await modelV2API.createDefinition(selectedCatalogId, { + model_id: azureModelId, + name: azureDeploymentDisplayName.trim() || azureModelId, + }); + try { + const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId); + if (!res.data.success) { + toast.warning(t('form.testFailed'), res.data.error || res.data.message); + } + } catch (testErr: any) { + toast.warning(t('form.testFailed'), testErr.response?.data?.detail || testErr.message); + } + } toast.success(t('providerAdded'), displayName); setSavedProviderId(selectedCatalogId); setSavedProviderName(displayName); @@ -1600,6 +1630,36 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { )} + {isAzureProviderId(selectedCatalogId) && ( +
+
+ + setAzureDeploymentName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={t('form.azureDeploymentPlaceholder')} + /> +

{t('form.azureDeploymentHint')}

+
+
+ + setAzureDeploymentDisplayName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={azureDeploymentName.trim() || t('form.azureDeploymentDisplayPlaceholder')} + /> +
+
+ )} + {selectedCatalog.models.length > 0 && (
@@ -1712,7 +1772,13 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: {

{t('wizard.modelsAdded', { count: addedModelCount })}

)} - +
)} @@ -1791,10 +1857,12 @@ function useModelForm() { }; } -function ModelFormFields({ form, testResult, testing }: { +function ModelFormFields({ form, testResult, testing, modelIdPlaceholder, modelIdHint }: { form: ReturnType; testResult: { success: boolean; message: string; latency?: number } | null; testing: boolean; + modelIdPlaceholder?: string; + modelIdHint?: string; }) { const { t } = useTranslation('model'); return ( @@ -1809,8 +1877,9 @@ function ModelFormFields({ form, testResult, testing }: { value={form.modelId} onChange={e => form.setModelId(e.target.value)} className="w-full px-3 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-slate-400 text-sm" - placeholder="gpt-4o-custom" + placeholder={modelIdPlaceholder || 'gpt-4o-custom'} /> + {modelIdHint &&

{modelIdHint}

}
- + ); @@ -2085,7 +2160,20 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos // Catalog model management const [catalogModels, setCatalogModels] = useState([]); + const [catalogModelsLoaded, setCatalogModelsLoaded] = useState(false); const [selectedModelIds, setSelectedModelIds] = useState>(new Set(models.map(m => m.id))); + const [newAzureDeploymentName, setNewAzureDeploymentName] = useState(''); + const [newAzureDeploymentDisplayName, setNewAzureDeploymentDisplayName] = useState(''); + const isAzureProvider = isAzureProviderId(provider.id); + const catalogModelIds = useMemo(() => new Set(catalogModels.map(m => m.id)), [catalogModels]); + const selectedCatalogModelCount = useMemo( + () => catalogModels.filter(m => selectedModelIds.has(m.id)).length, + [catalogModels, selectedModelIds] + ); + const azureCustomModels = useMemo( + () => isAzureProvider && catalogModelsLoaded ? models.filter(m => !catalogModelIds.has(m.id)) : [], + [catalogModelIds, catalogModelsLoaded, isAzureProvider, models] + ); useEffect(() => { setApiKey(existingKey); @@ -2103,10 +2191,19 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos }, [provider.id, models]); useEffect(() => { + setCatalogModelsLoaded(false); catalogAPI.list().then(res => { const found = res.data.providers.find(p => p.id === provider.id); - if (found) setCatalogModels(found.models); - }).catch(() => {}); + if (found) { + setCatalogModels(found.models); + setCatalogModelsLoaded(true); + } else { + setCatalogModels([]); + } + }).catch(() => { + setCatalogModels([]); + setCatalogModelsLoaded(false); + }); }, [provider.id]); const handleToggleCatalogModel = (modelId: string) => { @@ -2148,6 +2245,13 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos ...toAdd.map(m => modelV2API.createDefinition(provider.id, { model_id: m.id, name: m.name }).catch(() => {})), ]); } + const azureModelId = newAzureDeploymentName.trim(); + if 
(isAzureProvider && azureModelId) { + await modelV2API.createDefinition(provider.id, { + model_id: azureModelId, + name: newAzureDeploymentDisplayName.trim() || azureModelId, + }); + } toast.success(t('credentialsSaved')); onConfigured(); @@ -2343,7 +2447,7 @@ ${hasExisting ? '你已有凭证配置,可以更新或测试连接。' : '请 + + ); + } + return (
) : isError ? ( + ) : attachment.isImage ? ( + ) : ( )} @@ -1559,7 +1714,7 @@ export default function SessionChat({
{attachment.error}
)}
- {isError && ( + {isError && !attachment.isImage && ( + + + +
+ + +
+ +
+ + +
+ + {formData.authType === 'bearer' && ( +
+ + update({ authValue: e.target.value })} + readOnly={isFieldReadOnly('authValue')} + placeholder={t('addMCP.authTokenPlaceholder')} + className={`${inputClassFor('authValue')} font-mono`} + /> +
+ )} + + {formData.authType === 'header' && ( +
+
+ + update({ authHeaderName: e.target.value })} + readOnly={isFieldReadOnly('authHeaderName')} + placeholder="X-API-Key" + className={`${inputClassFor('authHeaderName')} font-mono`} + /> +
+
+ + update({ authValue: e.target.value })} + readOnly={isFieldReadOnly('authValue')} + placeholder={t('addMCP.authHeaderValuePlaceholder')} + className={`${inputClassFor('authValue')} font-mono`} + /> +
+
+ )} + + {formData.authType === 'query' && ( +
+
+ + update({ authQueryName: e.target.value })} + readOnly={isFieldReadOnly('authQueryName')} + placeholder="apikey" + className={`${inputClassFor('authQueryName')} font-mono`} + /> +
+
+ + update({ authValue: e.target.value })} + readOnly={isFieldReadOnly('authValue')} + placeholder={t('addMCP.authQueryValuePlaceholder')} + className={`${inputClassFor('authValue')} font-mono`} + /> +
+
+ )} + +
+ +