diff --git a/.flocks/plugins/skills/browser-use/SKILL.md b/.flocks/plugins/skills/browser-use/SKILL.md index 8edb3333..1f41fefe 100644 --- a/.flocks/plugins/skills/browser-use/SKILL.md +++ b/.flocks/plugins/skills/browser-use/SKILL.md @@ -68,11 +68,7 @@ browser: not connected — 请确保 Chrome / Chromium / Edge 已打开,然后 说明当前环境不适合 `CDP 直连`。此时要: 1. 明确告诉用户是哪一项不满足,提示需要做什么操作才能达到要求 -2. 切换到 `agent-browser` 模式 -3. 立即阅读: - - `references/agent-browser.md` - -不要继续尝试 CDP。 +2. 提示用户切换到 `agent-browser` 模式 ## 执行规则 diff --git a/.github/workflows/windows-packaging-publish.yml b/.github/workflows/windows-packaging-publish.yml index f491883a..b2817d5d 100644 --- a/.github/workflows/windows-packaging-publish.yml +++ b/.github/workflows/windows-packaging-publish.yml @@ -47,6 +47,7 @@ jobs: -RepoRoot "${{ github.workspace }}" ` -AppVersion $appVersion $uvExe = Join-Path $out "tools/uv/uv.exe" + $pythonExe = Join-Path $out "tools/python/python.exe" $nodeExe = Join-Path $out "tools/node/node.exe" $chromeExe = Get-ChildItem -Path (Join-Path $out "tools/chrome") -Recurse -Filter "chrome.exe" -File -ErrorAction SilentlyContinue | Where-Object { $_.FullName -match "chrome-win" } | @@ -54,6 +55,9 @@ jobs: if (-not (Test-Path $uvExe)) { throw "uv executable not found in staging: $uvExe" } + if (-not (Test-Path $pythonExe)) { + throw "python executable not found in staging: $pythonExe" + } if (-not (Test-Path $nodeExe)) { throw "node executable not found in staging: $nodeExe" } @@ -61,6 +65,7 @@ jobs: throw "chrome executable not found in staging under tools/chrome" } $uvVersion = (& $uvExe --version).Trim() + $pythonVersion = (& $pythonExe --version).Trim() $nodeVersion = (& $nodeExe --version).Trim() $chromeVersion = (Get-Item -LiteralPath $chromeExe.FullName).VersionInfo.ProductVersion if ([string]::IsNullOrWhiteSpace($chromeVersion)) { @@ -71,6 +76,8 @@ jobs: } Write-Host "[runtime] pinned uv version: $($manifest.uv.version)" Write-Host "[runtime] bundled uv version: $uvVersion" + Write-Host 
"[runtime] pinned python version: $($manifest.python.version)" + Write-Host "[runtime] bundled python version: $pythonVersion" Write-Host "[runtime] pinned node version: $($manifest.nodejs.version)" Write-Host "[runtime] bundled node version: $nodeVersion" Write-Host "[runtime] pinned chrome version: $($manifest.chrome_for_testing.version)" @@ -78,6 +85,9 @@ jobs: if ($uvVersion -notmatch ("^uv\s+" + [regex]::Escape($manifest.uv.version) + "(\s|$)")) { throw "Bundled uv version does not match pinned version in manifest" } + if ($pythonVersion -ne ("Python " + $manifest.python.version)) { + throw "Bundled python version does not match pinned version in manifest" + } if ($nodeVersion -ne ("v" + $manifest.nodejs.version)) { throw "Bundled node version does not match pinned version in manifest" } diff --git a/.github/workflows/windows-packaging.yml b/.github/workflows/windows-packaging.yml index 1e595d9f..ea37b51c 100644 --- a/.github/workflows/windows-packaging.yml +++ b/.github/workflows/windows-packaging.yml @@ -42,6 +42,7 @@ jobs: $manifest = Get-Content -Path $manifestPath -Raw -Encoding utf8 | ConvertFrom-Json & "${{ github.workspace }}/packaging/windows/build-installer.ps1" -OutputDir $out -RepoRoot "${{ github.workspace }}" $uvExe = Join-Path $out "tools/uv/uv.exe" + $pythonExe = Join-Path $out "tools/python/python.exe" $nodeExe = Join-Path $out "tools/node/node.exe" $chromeExe = Get-ChildItem -Path (Join-Path $out "tools/chrome") -Recurse -Filter "chrome.exe" -File -ErrorAction SilentlyContinue | Where-Object { $_.FullName -match "chrome-win" } | @@ -49,6 +50,9 @@ jobs: if (-not (Test-Path $uvExe)) { throw "uv executable not found in staging: $uvExe" } + if (-not (Test-Path $pythonExe)) { + throw "python executable not found in staging: $pythonExe" + } if (-not (Test-Path $nodeExe)) { throw "node executable not found in staging: $nodeExe" } @@ -56,6 +60,7 @@ jobs: throw "chrome executable not found in staging under tools/chrome" } $uvVersion = (& $uvExe 
--version).Trim() + $pythonVersion = (& $pythonExe --version).Trim() $nodeVersion = (& $nodeExe --version).Trim() $chromeVersion = (Get-Item -LiteralPath $chromeExe.FullName).VersionInfo.ProductVersion if ([string]::IsNullOrWhiteSpace($chromeVersion)) { @@ -66,6 +71,8 @@ jobs: } Write-Host "[runtime] pinned uv version: $($manifest.uv.version)" Write-Host "[runtime] bundled uv version: $uvVersion" + Write-Host "[runtime] pinned python version: $($manifest.python.version)" + Write-Host "[runtime] bundled python version: $pythonVersion" Write-Host "[runtime] pinned node version: $($manifest.nodejs.version)" Write-Host "[runtime] bundled node version: $nodeVersion" Write-Host "[runtime] pinned chrome version: $($manifest.chrome_for_testing.version)" @@ -73,6 +80,9 @@ jobs: if ($uvVersion -notmatch ("^uv\s+" + [regex]::Escape($manifest.uv.version) + "(\s|$)")) { throw "Bundled uv version does not match pinned version in manifest" } + if ($pythonVersion -ne ("Python " + $manifest.python.version)) { + throw "Bundled python version does not match pinned version in manifest" + } if ($nodeVersion -ne ("v" + $manifest.nodejs.version)) { throw "Bundled node version does not match pinned version in manifest" } diff --git a/.gitignore b/.gitignore index 45739c7f..4d720f9c 100644 --- a/.gitignore +++ b/.gitignore @@ -100,7 +100,6 @@ tmp/ # Documentation docs/_build/ -docs/*.md site/ # Node.js (TUI) @@ -108,7 +107,6 @@ node_modules/ tui/node_modules/ bun.lockb .bun/ -docs/* !docs/CHANGELOG.md # TUI build diff --git a/README.md b/README.md index 94e25de0..bfa4c058 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,10 @@ Scan the QR code with **WeChat** to join our official discussion group. ![WeCom official community QR code](assets/community-wecom-qr.png) -## 6. License +## 6. Contributing + +See [`docs/CONTRIBUTING.md`](docs/CONTRIBUTING.md) for development setup, coding standards, testing expectations, and Pull Request guidelines. + +## 7. 
License Apache License 2.0 diff --git a/README_zh.md b/README_zh.md index a0d88a8d..2e28c145 100644 --- a/README_zh.md +++ b/README_zh.md @@ -271,6 +271,10 @@ flocks start --server-host 127.0.0.1 --webui-host 0.0.0.0 ![企业微信官方交流群二维码](assets/community-wecom-qr.png) -## 6. 开源协议 +## 6. 参与贡献 + +开发环境、代码规范、测试要求和 Pull Request 流程请参考 [`docs/CONTRIBUTING.md`](docs/CONTRIBUTING.md)。 + +## 7. 开源协议 Apache License 2.0 diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 00000000..67b9a581 --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1,197 @@ +# Contributing Guide + +Thank you for contributing to `flocks`. We welcome bug fixes, documentation improvements, tests, UX polish, new features, and other well-scoped changes that make the project better. + +This guide explains how to contribute in a way that is easy to review, maintain, and merge. + +## Ways to Contribute + +You can contribute by: + +- reporting bugs with clear reproduction steps +- proposing features or design improvements +- improving documentation, examples, and developer experience +- fixing issues and adding regression coverage +- improving the WebUI, CLI, workflows, plugins, tools, or platform integrations + +If your change is large or affects architecture, public behavior, or user workflows, please open an Issue first so the direction can be discussed before implementation starts. + +## Before You Start + +Before writing code, please: + +1. Search existing Issues and Pull Requests to avoid duplicate work. +2. Confirm the scope for larger features, refactors, or behavior changes. +3. Keep each contribution focused on one topic whenever possible. 
+ +## Development Environment + +The main development stack for `flocks` currently includes: + +- Python `3.12` +- `uv` for Python environment and dependency management +- Node.js `22+` +- `npm` for frontend dependencies + +Recommended setup: + +```bash +uv sync --group dev +cd webui && npm ci +``` + +If you work on browser-related features, you may also need the browser runtime dependencies described in the project README. + +## Common Commands + +Use `uv run` for Python-related commands whenever possible. + +### Backend / Python + +```bash +uv run ruff check . +uv run pytest +``` + +If your change is scoped to a smaller area, run the most relevant tests first: + +```bash +uv run pytest tests/session +uv run pytest tests/cli/test_service_manager.py +``` + +### Frontend / WebUI + +```bash +cd webui +npm run lint +npm run build +``` + +If your change touches both Python and frontend code, please run checks for both parts. + +## Coding Standards + +Please make sure your changes follow the repository conventions: + +- Follow the Google Python Style Guide for Python code. +- Use `ruff` for linting and formatting-related checks. +- New features and bug fixes must include or update tests. +- Keep all test code under `tests/`. +- Except for the repository root `README.md`, feature guides, usage docs, and summary markdown files should go under `docs/`. +- Run Python commands with `uv run`, or from the project's active virtual environment. +- Any `.ps1` file in scripts must use **UTF-8 with BOM** encoding and **CRLF** line endings. + +Please also follow these general principles: + +- Keep changes focused and avoid unrelated refactors. +- Add type hints, error handling, and regression coverage where they meaningfully improve maintainability. +- Introduce new dependencies only when necessary, and explain why they are needed. +- Add brief comments for non-obvious logic, but avoid low-value commentary. 
+ +## Branching and Commits + +Create your working branch from the latest `dev` branch. Do not develop directly on `main`, and do not open contribution PRs against `main` unless a maintainer explicitly asks for it. + +Suggested branch naming examples: + +- `feat/add-session-export` +- `fix/webui-login-redirect` +- `docs/contributing-guide` +- `refactor/mcp-client-cache` +- `test/add-workflow-route-cases` + +Write commit messages in clear English. A Conventional Commits style is recommended: + +```text +feat(cli): add service restart timeout option +fix(auth): preserve session after password reset +docs: add contributing guide +test(session): cover runner retry path +``` + +A good commit should: + +- focus on one main change +- describe intent clearly in the title +- include extra context in the body when behavior, compatibility, or motivation needs explanation + +## Testing Expectations + +Please validate your change according to its scope: + +- Documentation changes: verify links, commands, filenames, and paths. +- Python changes: run the relevant tests; for shared infrastructure changes, run broader coverage. +- Frontend changes: run at least `npm run lint` and `npm run build`. +- Cross-cutting changes: include enough automated or manual verification to show that the change works as intended. + +If you are fixing a bug, prefer adding a regression test that reproduces the issue before or alongside the fix. + +## Pull Request Guidelines + +All contribution PRs for `flocks` should target the `dev` branch. + +When opening a PR, make it easy for reviewers to understand: + +1. What problem the change solves. +2. What the scope of the change is. +3. Why the chosen approach is appropriate. +4. How you validated the change. +5. Whether there are compatibility, migration, or configuration impacts. + +If the PR changes UI or interaction behavior, include screenshots, recordings, or a clear before/after explanation. 
+ +Recommended PR description template: + +```markdown +## Summary +- ... + +## Why +- ... + +## Test Plan +- [x] uv run pytest ... +- [x] npm run lint +- [ ] Manual verification +``` + +Please keep PRs as small and focused as practical. Multiple reviewable PRs are usually easier to merge than one large mixed change. + +## Issue Reporting + +This repository already provides GitHub Issue templates. Please choose the most appropriate template and include enough detail to make triage efficient: + +- Bug reports: reproduction steps, expected behavior, actual behavior, logs, and version information +- Feature requests: motivation, proposed solution, alternatives considered, and expected impact +- Plugin / tool requests: target use case, inputs, outputs, and relevant constraints + +High-quality Issues significantly improve response time and implementation quality. + +## Security Issues + +If you discover a security vulnerability or any issue that could expose users or deployments to risk, please do not disclose sensitive details in a public Issue. Contact the maintainers through an approved private channel first, then coordinate on disclosure after a fix is available. + +## Communication + +Please keep communication respectful, specific, and constructive: + +- discuss the problem, not the person +- provide evidence and context, not just conclusions +- stay open to review feedback, and split changes if needed + +We strongly prefer incremental, testable, reviewable contributions over large rewrites. 
+ +## Pre-PR Checklist + +Before opening a PR, please confirm: + +- [ ] the change is focused and does not include unrelated edits +- [ ] code, naming, and documentation style match the repository +- [ ] new features or bug fixes include appropriate tests +- [ ] relevant local checks have passed +- [ ] the PR clearly explains background, approach, and validation +- [ ] the PR targets `dev` +- [ ] any new markdown documentation has been added under `docs/` when applicable + +Thank you for helping improve `flocks`. diff --git a/flocks/auth/service.py b/flocks/auth/service.py index 4ae03a6d..7a9eb618 100644 --- a/flocks/auth/service.py +++ b/flocks/auth/service.py @@ -82,7 +82,7 @@ async def init(cls) -> None: db_path = Storage.get_db_path() if cls._initialized and cls._initialized_db_path == str(db_path) and db_path.exists(): return - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: await db.executescript( """ CREATE TABLE IF NOT EXISTS users ( @@ -162,7 +162,7 @@ async def has_users(cls) -> bool: return True await cls.init() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: async with db.execute("SELECT COUNT(1) FROM users") as cursor: row = await cursor.fetchone() result = bool(row and row[0] > 0) @@ -211,7 +211,7 @@ async def _create_user_internal( now = _iso_now() password_hash = cls._hash_password(password) db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: await db.execute( """ INSERT INTO users ( @@ -239,7 +239,7 @@ async def _create_user_internal( async def get_user_by_id(cls, user_id: str) -> Optional[LocalUser]: await cls.init() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: async with db.execute( """ SELECT id, username, role, status, must_reset_password, @@ -266,7 +266,7 @@ async def 
get_user_by_id(cls, user_id: str) -> Optional[LocalUser]: async def get_user_by_username(cls, username: str) -> Optional[Tuple[LocalUser, str, Optional[str]]]: await cls.init() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: async with db.execute( """ SELECT id, username, role, status, must_reset_password, created_at, updated_at, last_login_at, @@ -295,7 +295,7 @@ async def list_users(cls) -> List[LocalUser]: await cls.init() db_path = Storage.get_db_path() users: List[LocalUser] = [] - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: async with db.execute( """ SELECT id, username, role, status, must_reset_password, created_at, updated_at, last_login_at @@ -326,7 +326,7 @@ async def _create_session(cls, user_id: str) -> str: now = _iso_now() expires_at = (_utc_now() + timedelta(days=cls._session_ttl_days)).isoformat() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: await db.execute( """ INSERT INTO user_sessions(session_id, user_id, expires_at, created_at, updated_at) @@ -341,7 +341,7 @@ async def _create_session(cls, user_id: str) -> str: async def get_user_by_session_id(cls, session_id: str) -> Optional[LocalUser]: await cls.init() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: async with db.execute( """ SELECT u.id, u.username, u.role, u.status, u.must_reset_password, u.created_at, u.updated_at, u.last_login_at, @@ -377,7 +377,7 @@ async def get_user_by_session_id(cls, session_id: str) -> Optional[LocalUser]: async def revoke_session(cls, session_id: str) -> None: await cls.init() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: await db.execute("DELETE FROM user_sessions WHERE session_id = ?", (session_id,)) await 
db.commit() @@ -407,7 +407,7 @@ async def login( session_id = await cls._create_session(user.id) now = _iso_now() db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: await db.execute("UPDATE users SET last_login_at = ?, updated_at = ? WHERE id = ?", (now, now, user.id)) await db.commit() @@ -453,7 +453,7 @@ async def set_password( now = _iso_now() pwd_hash = cls._hash_password(new_password) db_path = Storage.get_db_path() - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: cursor = await db.execute( """ UPDATE users diff --git a/flocks/channel/base.py b/flocks/channel/base.py index 7bac9f2e..5d58315d 100644 --- a/flocks/channel/base.py +++ b/flocks/channel/base.py @@ -53,6 +53,7 @@ class InboundMessage: chat_type: ChatType = ChatType.DIRECT text: str = "" media_url: Optional[str] = None + media_mime: Optional[str] = None reply_to_id: Optional[str] = None thread_id: Optional[str] = None mentioned: bool = False diff --git a/flocks/channel/builtin/weixin/__init__.py b/flocks/channel/builtin/weixin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flocks/channel/builtin/weixin/cdn.py b/flocks/channel/builtin/weixin/cdn.py new file mode 100644 index 00000000..60980e28 --- /dev/null +++ b/flocks/channel/builtin/weixin/cdn.py @@ -0,0 +1,162 @@ +""" +WeChat CDN (novac2c.cdn.weixin.qq.com) URL builders, SSRF protection, +and raw download / upload helpers for AES-encrypted media payloads. + +The CDN protocol: +- Inbound media is fetched from ``/c2c/download?encrypted_query_param=...`` + and decrypted client-side with the AES key embedded in the iLink frame. +- Outbound media is encrypted client-side and uploaded with POST to either + ``/c2c/upload?encrypted_query_param=&filekey=`` + or directly to ``upload_full_url`` returned by ``getuploadurl``. 
+""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Optional +from urllib.parse import quote, urlparse + +if TYPE_CHECKING: + import aiohttp + + +# Hosts the channel is allowed to fetch media from. SSRF guard. +_WEIXIN_CDN_ALLOWLIST: frozenset[str] = frozenset( + { + "novac2c.cdn.weixin.qq.com", + "ilinkai.weixin.qq.com", + "wx.qlogo.cn", + "thirdwx.qlogo.cn", + "res.wx.qq.com", + "mmbiz.qpic.cn", + "mmbiz.qlogo.cn", + } +) + + +def cdn_download_url(cdn_base_url: str, encrypted_query_param: str) -> str: + return ( + f"{cdn_base_url.rstrip('/')}/download" + f"?encrypted_query_param={quote(encrypted_query_param, safe='')}" + ) + + +def cdn_upload_url(cdn_base_url: str, upload_param: str, filekey: str) -> str: + return ( + f"{cdn_base_url.rstrip('/')}/upload" + f"?encrypted_query_param={quote(upload_param, safe='')}" + f"&filekey={quote(filekey, safe='')}" + ) + + +def assert_weixin_cdn_url(url: str) -> None: + """Raise ``ValueError`` if *url* is not on a known WeChat CDN host. + + Used as an SSRF guard before fetching ``full_url`` (which the iLink + server controls) — without this, a malicious frame could redirect + downloads to arbitrary internal hosts. + """ + try: + parsed = urlparse(url) + scheme = parsed.scheme.lower() + host = parsed.hostname or "" + except Exception as exc: # noqa: BLE001 + raise ValueError(f"Unparseable media URL: {url!r}") from exc + + if scheme not in ("http", "https"): + raise ValueError( + f"Media URL has disallowed scheme {scheme!r}; only http/https permitted." + ) + if host not in _WEIXIN_CDN_ALLOWLIST: + raise ValueError( + f"Media URL host {host!r} is not in the WeChat CDN allowlist. " + "Refusing to fetch to prevent SSRF." + ) + + +async def download_bytes( + session: "aiohttp.ClientSession", + *, + url: str, + timeout_seconds: float = 60.0, +) -> bytes: + """GET *url* and return the response body bytes. 
+ + Uses ``asyncio.wait_for`` rather than ``aiohttp.ClientTimeout`` so the + coroutine can be safely scheduled via ``run_coroutine_threadsafe`` from + callers running outside the aiohttp event loop. + """ + async def _do() -> bytes: + async with session.get(url) as resp: + resp.raise_for_status() + return await resp.read() + return await asyncio.wait_for(_do(), timeout=timeout_seconds) + + +async def upload_ciphertext( + session: "aiohttp.ClientSession", + *, + ciphertext: bytes, + upload_url: str, + timeout_seconds: float = 120.0, +) -> str: + """POST encrypted bytes to the WeChat CDN, return ``x-encrypted-param`` echo. + + Both the constructed CDN URL (from ``upload_param``) and the direct + ``upload_full_url`` use POST with the raw ciphertext as the body. + """ + async def _do() -> str: + async with session.post( + upload_url, + data=ciphertext, + headers={"Content-Type": "application/octet-stream"}, + ) as resp: + if resp.status == 200: + encrypted_param = resp.headers.get("x-encrypted-param") + if encrypted_param: + await resp.read() + return encrypted_param + raw = await resp.text() + raise RuntimeError(f"CDN upload missing x-encrypted-param header: {raw[:200]}") + raw = await resp.text() + raise RuntimeError(f"CDN upload HTTP {resp.status}: {raw[:200]}") + return await asyncio.wait_for(_do(), timeout=timeout_seconds) + + +def media_reference(item: dict, key: str) -> dict: + """Pull the ``.media`` sub-dict out of an item like ``image_item``/``file_item``.""" + return (item.get(key) or {}).get("media") or {} + + +async def download_and_decrypt_media( + session: "aiohttp.ClientSession", + *, + cdn_base_url: str, + encrypted_query_param: Optional[str], + aes_key_b64: Optional[str], + full_url: Optional[str], + timeout_seconds: float, +) -> bytes: + """Fetch + AES-decrypt a single media payload. + + Caller supplies whichever of ``encrypted_query_param`` / ``full_url`` is + present in the iLink frame. ``aes_key_b64`` is decoded by ``crypto.parse_aes_key``. 
+ """ + # Local import to avoid a circular dependency between cdn and crypto. + from .crypto import aes128_ecb_decrypt, parse_aes_key + + if encrypted_query_param: + raw = await download_bytes( + session, + url=cdn_download_url(cdn_base_url, encrypted_query_param), + timeout_seconds=timeout_seconds, + ) + elif full_url: + assert_weixin_cdn_url(full_url) + raw = await download_bytes(session, url=full_url, timeout_seconds=timeout_seconds) + else: + raise RuntimeError("media item had neither encrypt_query_param nor full_url") + + if aes_key_b64: + raw = aes128_ecb_decrypt(raw, parse_aes_key(aes_key_b64)) + return raw diff --git a/flocks/channel/builtin/weixin/channel.py b/flocks/channel/builtin/weixin/channel.py new file mode 100644 index 00000000..cb1e30d5 --- /dev/null +++ b/flocks/channel/builtin/weixin/channel.py @@ -0,0 +1,696 @@ +""" +Weixin (微信) ChannelPlugin implementation. + +Connects Flocks to WeChat personal accounts via Tencent's iLink Bot API. +Only accounts registered as iLink bots (via QR scan) are supported. + +Design notes: +- Long-poll ``getupdates`` drives inbound delivery. +- Every outbound reply should echo the latest ``context_token`` for the peer. +- Media files move through an AES-128-ECB encrypted CDN protocol — see + ``media.py`` and ``cdn.py``. +- Token / credentials are obtained via QR login on the iLink Bot developer + portal, then configured under the ``weixin`` channel. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import os +import uuid +from typing import Awaitable, Callable, Optional +from urllib.parse import urlparse + +from flocks.channel.base import ( + ChannelCapabilities, + ChannelMeta, + ChannelPlugin, + ChatType, + DeliveryResult, + InboundMessage, + OutboundContext, +) +from flocks.utils.log import Log + +from . 
import client as ilink +from .config import ( + AIOHTTP_AVAILABLE, + BACKOFF_DELAY_SECONDS, + CRYPTO_AVAILABLE, + ILINK_BASE_URL, + LONG_POLL_TIMEOUT_MS, + MAX_CONSECUTIVE_FAILURES, + MAX_MESSAGE_LENGTH, + RATE_LIMIT_ERRCODE, + RETRY_DELAY_SECONDS, + SESSION_EXPIRED_ERRCODE, + WEIXIN_CDN_BASE_URL, +) +from .format import format_for_weixin, split_chunks +from .inbound import extract_text, guess_chat_type, safe_id +from .media import ( + MediaCache, + download_inbound_item, + fetch_remote_to_temp, + is_downloadable_media_item, + send_outbound_file, +) +from .store import ( + ContextTokenStore, + MessageDedup, + load_sync_buf, + save_sync_buf, +) + +log = Log.create(service="channel.weixin") + +# Local alias to keep type hints readable when aiohttp is missing at import time +if AIOHTTP_AVAILABLE: + import aiohttp # type: ignore[import-untyped] + + +class WeixinChannel(ChannelPlugin): + """WeChat (微信) personal account channel via Tencent iLink Bot API. + + Prerequisites: + - A WeChat account registered as an iLink bot (QR scan on the iLink portal). + - ``aiohttp`` and ``cryptography`` Python packages installed. + + Required config keys: + - ``token`` — iLink bot token (``WEIXIN_TOKEN`` env var as fallback) + - ``accountId`` — iLink bot account ID (``WEIXIN_ACCOUNT_ID`` env var as fallback) + + Optional config keys: + - ``baseUrl`` — iLink API base URL (defaults to ilinkai.weixin.qq.com) + - ``cdnBaseUrl`` — iLink CDN base URL (defaults to novac2c.cdn.weixin.qq.com) + - ``dmPolicy`` — ``"open"`` (default) | ``"disabled"`` | ``"allowlist"`` + - ``allowFrom`` — list of allowed sender user IDs for DM allowlist mode + - ``groupPolicy`` — ``"all"`` (default) | ``"disabled"`` | ``"allowlist"`` + Controls whether group chat messages are processed. + Note: iLink Bot accounts may not receive group events in + ordinary WeChat groups depending on account type. 
    def __init__(self) -> None:
        """Initialize with inert defaults; live values arrive in ``start(config)``."""
        super().__init__()
        # Credentials / endpoints — populated from config or env in start().
        self._token: str = ""
        self._account_id: str = ""
        self._base_url: str = ILINK_BASE_URL
        self._cdn_base_url: str = WEIXIN_CDN_BASE_URL
        # Access policies for direct messages and group chats.
        self._dm_policy: str = "open"
        self._allow_from: list[str] = []
        self._group_policy: str = "all"
        self._group_allow_from: list[str] = []
        # Outbound pacing / retry knobs.
        self._send_chunk_delay: float = 1.5
        self._send_chunk_retries: int = 4
        self._data_dir: Optional[str] = None

        # Persistent per-peer context tokens, inbound dedup, media cache.
        self._token_store: ContextTokenStore = ContextTokenStore()
        self._dedup: MessageDedup = MessageDedup()
        self._media_cache: Optional[MediaCache] = None

        # Two sessions: one dedicated to the long-poll loop, one for sends.
        self._poll_session: "Optional[aiohttp.ClientSession]" = None
        self._send_session: "Optional[aiohttp.ClientSession]" = None

    # ------------------------------------------------------------------
    # ChannelPlugin interface
    # ------------------------------------------------------------------

    def meta(self) -> ChannelMeta:
        """Static channel identity used by the plugin registry."""
        return ChannelMeta(
            id="weixin",
            label="微信",
            aliases=["wechat", "wx"],
            order=30,
        )

    def capabilities(self) -> ChannelCapabilities:
        """Feature flags: DMs + groups with media and rich text; no threads/edits."""
        return ChannelCapabilities(
            chat_types=[ChatType.DIRECT, ChatType.GROUP],
            media=True,
            threads=False,
            reactions=False,
            edit=False,
            rich_text=True,
        )

    def validate_config(self, config: dict) -> Optional[str]:
        """Return an error message if required config is missing, else ``None``.

        ``token`` / ``accountId`` may fall back to the ``WEIXIN_TOKEN`` /
        ``WEIXIN_ACCOUNT_ID`` environment variables.
        """
        token = config.get("token") or os.getenv("WEIXIN_TOKEN", "")
        account_id = config.get("accountId") or os.getenv("WEIXIN_ACCOUNT_ID", "")
        if not str(token).strip():
            return "Missing required config: token (or WEIXIN_TOKEN env var)"
        if not str(account_id).strip():
            return "Missing required config: accountId (or WEIXIN_ACCOUNT_ID env var)"
        return None
-> Optional[dict]: + return { + "type": "object", + "properties": { + "token": {"type": "string", "description": "iLink bot token (从 QR 登录获取)"}, + "accountId": {"type": "string", "description": "iLink bot account ID (从 QR 登录获取)"}, + "baseUrl": {"type": "string", "description": "iLink API 地址", "default": ILINK_BASE_URL}, + "cdnBaseUrl": {"type": "string", "description": "iLink CDN 地址", "default": WEIXIN_CDN_BASE_URL}, + "dmPolicy": { + "type": "string", + "enum": ["open", "disabled", "allowlist"], + "description": "私信策略", + "default": "open", + }, + "allowFrom": {"type": "string", "description": "allowlist 模式下允许的发送者 user_id,逗号分隔"}, + "groupPolicy": { + "type": "string", + "enum": ["all", "disabled", "allowlist"], + "description": "群聊策略", + "default": "all", + }, + "groupAllowFrom": {"type": "string", "description": "群聊 allowlist 模式下允许的群 / 房间 ID,逗号分隔"}, + "sendChunkDelay": {"type": "number", "description": "多段消息发送间隔(秒)", "default": 1.5}, + "dataDir": {"type": "string", "description": "状态文件 / 媒体缓存存储目录(默认 ~/.flocks/workspace/channels/weixin)"}, + }, + "required": ["token", "accountId"], + } + + def target_hint(self) -> str: + return "" + + @property + def text_chunk_limit(self) -> int: + return MAX_MESSAGE_LENGTH + + def format_message(self, text: str, format_hint: str = "markdown") -> str: + return format_for_weixin(text) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def start( + self, + config: dict, + on_message: Callable[[InboundMessage], Awaitable[None]], + abort_event: Optional[asyncio.Event] = None, + ) -> None: + if not (AIOHTTP_AVAILABLE and CRYPTO_AVAILABLE): + raise RuntimeError( + "Weixin channel requires ``aiohttp`` and ``cryptography``. 
" + "Run: pip install aiohttp cryptography" + ) + + self._token = str(config.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip() + self._account_id = str(config.get("accountId") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip() + self._base_url = str(config.get("baseUrl") or ILINK_BASE_URL).rstrip("/") + self._cdn_base_url = str(config.get("cdnBaseUrl") or WEIXIN_CDN_BASE_URL).rstrip("/") + self._dm_policy = str(config.get("dmPolicy") or "open").lower() + raw_allow = config.get("allowFrom") or "" + self._allow_from = [s.strip() for s in str(raw_allow).split(",") if s.strip()] + self._group_policy = str(config.get("groupPolicy") or "all").lower() + raw_group_allow = config.get("groupAllowFrom") or "" + self._group_allow_from = [s.strip() for s in str(raw_group_allow).split(",") if s.strip()] + self._send_chunk_delay = float(config.get("sendChunkDelay") or 1.5) + self._send_chunk_retries = int(config.get("sendChunkRetries") or 4) + self._data_dir = config.get("dataDir") + + self._token_store = ContextTokenStore(self._data_dir) + self._token_store.restore(self._account_id) + self._dedup = MessageDedup() + self._media_cache = MediaCache(self._data_dir) + + no_timeout = aiohttp.ClientTimeout( + total=None, connect=None, sock_connect=None, sock_read=None, + ) + self._poll_session = aiohttp.ClientSession( + trust_env=True, connector=ilink.make_ssl_connector(), + ) + self._send_session = aiohttp.ClientSession( + trust_env=True, connector=ilink.make_ssl_connector(), timeout=no_timeout, + ) + + self.mark_connected() + log.info("weixin.connected", { + "account_id": safe_id(self._account_id), + "base_url": self._base_url, + }) + if self._group_policy != "disabled": + log.warning("weixin.group_policy.note", { + "group_policy": self._group_policy, + "note": ( + "QR-login connects an iLink bot identity (e.g. ...@im.bot), not a " + "normal personal WeChat account. Ordinary WeChat group messages are " + "typically NOT delivered by iLink for this account type. 
" + "groupPolicy only takes effect if iLink actually delivers group events." + ), + }) + + try: + await self._poll_loop(on_message, abort_event) + finally: + await self._close_sessions() + self.mark_disconnected() + + async def stop(self) -> None: + await self._close_sessions() + + async def _close_sessions(self) -> None: + for attr in ("_poll_session", "_send_session"): + session = getattr(self, attr, None) + if session and not session.closed: + try: + await session.close() + except Exception: + pass + setattr(self, attr, None) + + # ------------------------------------------------------------------ + # Inbound long-poll loop + # ------------------------------------------------------------------ + + async def _poll_loop( + self, + on_message: Callable[[InboundMessage], Awaitable[None]], + abort_event: Optional[asyncio.Event], + ) -> None: + assert self._poll_session is not None + sync_buf = load_sync_buf(self._account_id, self._data_dir) + timeout_ms = LONG_POLL_TIMEOUT_MS + consecutive_failures = 0 + + while abort_event is None or not abort_event.is_set(): + try: + response = await ilink.get_updates( + self._poll_session, + base_url=self._base_url, + token=self._token, + sync_buf=sync_buf, + timeout_ms=timeout_ms, + ) + + suggested = response.get("longpolling_timeout_ms") + if isinstance(suggested, int) and suggested > 0: + timeout_ms = suggested + + ret = response.get("ret", 0) + errcode = response.get("errcode", 0) + + if ret not in (0, None) or errcode not in (0, None): + if ( + ret == SESSION_EXPIRED_ERRCODE + or errcode == SESSION_EXPIRED_ERRCODE + or ilink.is_stale_session(ret, errcode, response.get("errmsg")) + ): + log.error("weixin.session_expired", { + "account_id": safe_id(self._account_id), + }) + await asyncio.sleep(600) + consecutive_failures = 0 + continue + + consecutive_failures += 1 + log.warning("weixin.getupdates_error", { + "ret": ret, "errcode": errcode, + "errmsg": response.get("errmsg", ""), + "attempt": consecutive_failures, + }) + await 
asyncio.sleep( + BACKOFF_DELAY_SECONDS + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES + else RETRY_DELAY_SECONDS + ) + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: + consecutive_failures = 0 + continue + + consecutive_failures = 0 + new_sync_buf = str(response.get("get_updates_buf") or "") + if new_sync_buf: + sync_buf = new_sync_buf + save_sync_buf(self._account_id, sync_buf, self._data_dir) + + for message in response.get("msgs") or []: + asyncio.create_task(self._process_message_safe(message, on_message)) + + except asyncio.CancelledError: + break + except Exception as exc: + consecutive_failures += 1 + log.error("weixin.poll_error", { + "error": str(exc), "attempt": consecutive_failures, + }) + await asyncio.sleep( + BACKOFF_DELAY_SECONDS + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES + else RETRY_DELAY_SECONDS + ) + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: + consecutive_failures = 0 + + async def _process_message_safe( + self, + message: dict, + on_message: Callable[[InboundMessage], Awaitable[None]], + ) -> None: + try: + await self._process_message(message, on_message) + except Exception as exc: + log.error("weixin.process_error", { + "from": safe_id(message.get("from_user_id")), + "error": str(exc), + }) + + async def _process_message( + self, + message: dict, + on_message: Callable[[InboundMessage], Awaitable[None]], + ) -> None: + sender_id = str(message.get("from_user_id") or "").strip() + if not sender_id or sender_id == self._account_id: + return + + message_id = str(message.get("message_id") or "").strip() + if message_id and self._dedup.is_duplicate(message_id): + return + + item_list = message.get("item_list") or [] + text = extract_text(item_list) + + if text: + content_key = f"content:{sender_id}:{hashlib.md5(text.encode()).hexdigest()}" + if self._dedup.is_duplicate(content_key): + log.debug("weixin.dedup_content", {"sender": safe_id(sender_id)}) + return + + chat_type_str, effective_chat_id = 
guess_chat_type(message, self._account_id) + + if chat_type_str == "group": + if not self._is_group_allowed(effective_chat_id): + return + elif not self._is_dm_allowed(sender_id): + return + + # Download the first inbound media item (image / video / voice / file). + # InboundMessage.media_url is single-valued, so any extras are dropped. + media_url, media_mime = await self._collect_inbound_media(item_list, sender_id) + + if not text and not media_url: + return + + context_token = str(message.get("context_token") or "").strip() + if context_token: + self._token_store.set(self._account_id, sender_id, context_token) + + chat_type = ChatType.GROUP if chat_type_str == "group" else ChatType.DIRECT + inbound = InboundMessage( + channel_id="weixin", + account_id=self._account_id, + message_id=message_id or str(uuid.uuid4()), + sender_id=sender_id, + sender_name=sender_id, + chat_id=effective_chat_id, + chat_type=chat_type, + text=text, + media_url=media_url, + media_mime=media_mime, + mentioned=False, + raw=message, + ) + log.info("weixin.inbound", { + "from": safe_id(sender_id), + "chat_type": chat_type_str, + "text_preview": text[:50], + "media_mime": media_mime, + }) + await on_message(inbound) + + async def _collect_inbound_media( + self, item_list: list, sender_id: str, + ) -> tuple[Optional[str], Optional[str]]: + """Download the first downloadable media item and return ``(uri, mime)``. + + ``InboundMessage.media_url`` is single-valued, so we deliberately do NOT + download items beyond the first — only count them so a warning is logged. 
+ """ + if not self._poll_session or not self._media_cache: + return None, None + media_items = [item for item in item_list if is_downloadable_media_item(item)] + if not media_items: + return None, None + + sender_log = safe_id(sender_id) + if len(media_items) > 1: + log.warning("weixin.media.extra_dropped", { + "from": sender_log, + "dropped": len(media_items) - 1, + }) + + result = await download_inbound_item( + self._poll_session, + item=media_items[0], + cdn_base_url=self._cdn_base_url, + cache=self._media_cache, + sender_log_id=sender_log, + ) + return (result[0], result[1]) if result else (None, None) + + def _is_dm_allowed(self, sender_id: str) -> bool: + if self._dm_policy == "disabled": + return False + if self._dm_policy == "allowlist": + return sender_id in self._allow_from + return True + + def _is_group_allowed(self, chat_id: str) -> bool: + if self._group_policy == "disabled": + return False + if self._group_policy == "allowlist": + return chat_id in self._group_allow_from + return True + + # ------------------------------------------------------------------ + # Outbound: text + # ------------------------------------------------------------------ + + async def send_text(self, ctx: OutboundContext) -> DeliveryResult: + if not self._send_session or not self._token: + return DeliveryResult( + channel_id="weixin", message_id="", + success=False, error="Not connected", + ) + + formatted = format_for_weixin(ctx.text) + chunks = split_chunks(formatted, MAX_MESSAGE_LENGTH) + if not chunks: + return DeliveryResult(channel_id="weixin", message_id="") + + context_token = self._token_store.get(self._account_id, ctx.to) + last_message_id = "" + try: + for idx, chunk in enumerate(chunks): + client_id = f"flocks-weixin-{uuid.uuid4().hex}" + await self._send_chunk_with_retry( + to=ctx.to, chunk=chunk, + context_token=context_token, client_id=client_id, + ) + last_message_id = client_id + if idx < len(chunks) - 1 and self._send_chunk_delay > 0: + await 
    async def _send_chunk_with_retry(
        self,
        *,
        to: str,
        chunk: str,
        context_token: Optional[str],
        client_id: str,
    ) -> None:
        """Send a single text chunk with per-chunk retry and backoff.

        - On session-expired (errcode -14): retry once *without* ``context_token``
          and drop it from the local store.
        - On rate-limit (errcode -2): back off 3× and retry.

        Raises the last captured error once all attempts are exhausted.
        """
        last_error: Optional[Exception] = None
        retried_without_token = False
        retry_delay = 1.0

        # self._send_chunk_retries retries on top of the initial attempt.
        for attempt in range(self._send_chunk_retries + 1):
            try:
                resp = await ilink.send_text_message(
                    self._send_session,
                    base_url=self._base_url,
                    token=self._token,
                    to=to, text=chunk,
                    context_token=context_token, client_id=client_id,
                )

                if isinstance(resp, dict):
                    ret = resp.get("ret")
                    errcode = resp.get("errcode")
                    # Always log the iLink response so we can confirm whether
                    # the message was actually accepted (vs silently dropped).
                    log.info("weixin.send.response", {
                        "to": safe_id(to),
                        "client_id": client_id[:24],
                        "ret": ret,
                        "errcode": errcode,
                        "errmsg": resp.get("errmsg"),
                        "msg_id": resp.get("msg_id") or resp.get("message_id"),
                        "has_context_token": bool(context_token),
                        "chunk_len": len(chunk),
                    })
                    if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
                        is_session_expired = (
                            ret == SESSION_EXPIRED_ERRCODE
                            or errcode == SESSION_EXPIRED_ERRCODE
                            or ilink.is_stale_session(ret, errcode, resp.get("errmsg"))
                        )
                        if is_session_expired and not retried_without_token and context_token:
                            # One-shot fallback: resend without the (stale)
                            # context token. This consumes an attempt.
                            retried_without_token = True
                            context_token = None
                            self._token_store.clear(self._account_id, to)
                            log.warning("weixin.send.session_expired_retry", {
                                "to": safe_id(to),
                            })
                            continue

                        is_rate_limited = (
                            ret == RATE_LIMIT_ERRCODE or errcode == RATE_LIMIT_ERRCODE
                        )
                        if is_rate_limited:
                            errmsg = resp.get("errmsg") or resp.get("msg") or "rate limited"
                            last_error = RuntimeError(
                                f"iLink sendmessage rate limited: "
                                f"ret={ret} errcode={errcode} errmsg={errmsg}"
                            )
                            if attempt >= self._send_chunk_retries:
                                break
                            # Rate limits back off harder than generic errors.
                            wait = retry_delay * 3
                            log.warning("weixin.send.rate_limited", {
                                "to": safe_id(to), "wait": wait,
                            })
                            await asyncio.sleep(wait)
                            continue

                        # Generic API error: raised here, then caught by the
                        # outer except below and treated as retryable.
                        errmsg = resp.get("errmsg") or resp.get("msg") or "unknown error"
                        raise RuntimeError(
                            f"iLink sendmessage error: ret={ret} errcode={errcode} errmsg={errmsg}"
                        )
                return

            except Exception as exc:
                last_error = exc
                if attempt >= self._send_chunk_retries:
                    break
                # Linear backoff: 1s, 2s, 3s, ...
                wait = retry_delay * (attempt + 1)
                log.warning("weixin.send.retry", {
                    "to": safe_id(to),
                    "attempt": attempt + 1,
                    "wait": wait,
                    "error": str(exc),
                })
                await asyncio.sleep(wait)

        if last_error is not None:
            raise last_error
send_media(self, ctx: OutboundContext) -> DeliveryResult: + """Send a media file (image / video / voice / document). + + ``ctx.media_url`` may be: + - a local path (``/abs/path/to/file.png``) + - a ``file://`` URI + - a remote ``http(s)://`` URL on the WeChat CDN allowlist + """ + if not self._send_session or not self._token: + return DeliveryResult( + channel_id="weixin", message_id="", + success=False, error="Not connected", + ) + if not ctx.media_url: + # No media to send — fall back to plain text via send_text. + return await self.send_text(ctx) + + local_path, cleanup = await self._resolve_media_to_path(ctx.media_url) + if not local_path: + return DeliveryResult( + channel_id="weixin", message_id="", + success=False, error=f"Could not resolve media URL: {ctx.media_url}", + ) + + context_token = self._token_store.get(self._account_id, ctx.to) + try: + # Caption first (if any) so the file appears under it in chat order. + if ctx.text and ctx.text.strip(): + caption_result = await self.send_text(ctx) + if not caption_result.success: + return caption_result + + client_id = await send_outbound_file( + self._send_session, + base_url=self._base_url, + cdn_base_url=self._cdn_base_url, + token=self._token, + chat_id=ctx.to, + path=local_path, + context_token=context_token, + ) + return DeliveryResult(channel_id="weixin", message_id=client_id) + + except Exception as exc: + log.error("weixin.send_media.error", { + "to": safe_id(ctx.to), "error": str(exc), + }) + return DeliveryResult( + channel_id="weixin", message_id="", + success=False, error=str(exc), + ) + finally: + if cleanup and local_path: + try: + os.unlink(local_path) + except OSError: + pass + + async def _resolve_media_to_path(self, media_url: str) -> tuple[Optional[str], bool]: + """Resolve *media_url* to an on-disk path. 
Returns ``(path, should_cleanup)``.""" + parsed = urlparse(media_url) + scheme = parsed.scheme.lower() + + if scheme in ("", "file"): + path = parsed.path if scheme == "file" else media_url + if not os.path.isabs(path): + path = os.path.abspath(path) + return (path, False) if os.path.exists(path) else (None, False) + + if scheme in ("http", "https"): + try: + # Validate host before downloading to prevent SSRF. + from .cdn import assert_weixin_cdn_url + assert_weixin_cdn_url(media_url) + path = await fetch_remote_to_temp(self._send_session, url=media_url) + return path, True + except Exception as exc: + log.warning("weixin.media.fetch_failed", { + "url": media_url, "error": str(exc), + }) + return None, False + + log.warning("weixin.media.unsupported_scheme", {"scheme": scheme}) + return None, False diff --git a/flocks/channel/builtin/weixin/client.py b/flocks/channel/builtin/weixin/client.py new file mode 100644 index 00000000..1069fabd --- /dev/null +++ b/flocks/channel/builtin/weixin/client.py @@ -0,0 +1,229 @@ +""" +Low-level iLink Bot HTTP API helpers. + +Each function maps 1:1 to an iLink endpoint and returns the parsed JSON dict. +Higher-level retry/backoff is handled by the channel itself. +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import secrets +import ssl +import struct +from typing import TYPE_CHECKING, Optional + +from .config import ( + API_TIMEOUT_MS, + CHANNEL_VERSION, + EP_GET_UPDATES, + EP_GET_UPLOAD_URL, + EP_SEND_MESSAGE, + ILINK_APP_CLIENT_VERSION, + ILINK_APP_ID, + ITEM_TEXT, + MSG_STATE_FINISH, + MSG_TYPE_BOT, + RATE_LIMIT_ERRCODE, +) + +if TYPE_CHECKING: + import aiohttp + + +def make_ssl_connector() -> "Optional[aiohttp.TCPConnector]": + """Return a TCPConnector with certifi CA bundle for iLink TLS verification. + + Tencent's ``ilinkai.weixin.qq.com`` is not always verifiable against + Homebrew OpenSSL on macOS; certifi's Mozilla bundle is the reliable choice. 
def random_wechat_uin() -> str:
    """Fake UIN header value: base64 of a random unsigned 32-bit decimal string."""
    value = int.from_bytes(secrets.token_bytes(4), "big")
    return base64.b64encode(str(value).encode("utf-8")).decode("ascii")


def base_info() -> dict:
    """Common ``base_info`` payload attached to every iLink API request."""
    return {"channel_version": CHANNEL_VERSION}


def make_headers(token: Optional[str], body: str) -> dict:
    """Build the iLink HTTP headers for *body*; adds Bearer auth when *token* is set."""
    headers = {
        "Content-Type": "application/json",
        "AuthorizationType": "ilink_bot_token",
    }
    headers["Content-Length"] = str(len(body.encode("utf-8")))
    headers["X-WECHAT-UIN"] = random_wechat_uin()
    headers["iLink-App-Id"] = ILINK_APP_ID
    headers["iLink-App-ClientVersion"] = str(ILINK_APP_CLIENT_VERSION)
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers
+ """ + if ret != RATE_LIMIT_ERRCODE and errcode != RATE_LIMIT_ERRCODE: + return False + return (errmsg or "").lower() == "unknown error" + + +def _json_dumps(payload: dict) -> str: + return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + + +async def api_post( + session: "aiohttp.ClientSession", + *, + base_url: str, + endpoint: str, + payload: dict, + token: Optional[str], + timeout_ms: int, +) -> dict: + """POST *payload* + ``base_info`` to ``{base_url}/{endpoint}``.""" + import aiohttp + + body = _json_dumps({**payload, "base_info": base_info()}) + url = f"{base_url.rstrip('/')}/{endpoint}" + timeout = aiohttp.ClientTimeout(total=timeout_ms / 1000) + async with session.post(url, data=body, headers=make_headers(token, body), timeout=timeout) as resp: + raw = await resp.text() + if not resp.ok: + raise RuntimeError(f"iLink POST {endpoint} HTTP {resp.status}: {raw[:200]}") + return json.loads(raw) + + +async def get_updates( + session: "aiohttp.ClientSession", + *, + base_url: str, + token: str, + sync_buf: str, + timeout_ms: int, +) -> dict: + try: + return await api_post( + session, + base_url=base_url, + endpoint=EP_GET_UPDATES, + payload={"get_updates_buf": sync_buf}, + token=token, + timeout_ms=timeout_ms, + ) + except asyncio.TimeoutError: + return {"ret": 0, "msgs": [], "get_updates_buf": sync_buf} + + +async def send_text_message( + session: "aiohttp.ClientSession", + *, + base_url: str, + token: str, + to: str, + text: str, + context_token: Optional[str], + client_id: str, +) -> dict: + if not text or not text.strip(): + raise ValueError("send_text_message: text must not be empty") + message: dict = { + "from_user_id": "", + "to_user_id": to, + "client_id": client_id, + "message_type": MSG_TYPE_BOT, + "message_state": MSG_STATE_FINISH, + "item_list": [{"type": ITEM_TEXT, "text_item": {"text": text}}], + } + if context_token: + message["context_token"] = context_token + return await api_post( + session, + base_url=base_url, + 
endpoint=EP_SEND_MESSAGE, + payload={"msg": message}, + token=token, + timeout_ms=API_TIMEOUT_MS, + ) + + +async def send_media_message( + session: "aiohttp.ClientSession", + *, + base_url: str, + token: str, + to: str, + item: dict, + context_token: Optional[str], + client_id: str, +) -> dict: + """Send a single pre-built media item (image/video/voice/file).""" + message: dict = { + "from_user_id": "", + "to_user_id": to, + "client_id": client_id, + "message_type": MSG_TYPE_BOT, + "message_state": MSG_STATE_FINISH, + "item_list": [item], + } + if context_token: + message["context_token"] = context_token + return await api_post( + session, + base_url=base_url, + endpoint=EP_SEND_MESSAGE, + payload={"msg": message}, + token=token, + timeout_ms=API_TIMEOUT_MS, + ) + + +async def get_upload_url( + session: "aiohttp.ClientSession", + *, + base_url: str, + token: str, + to_user_id: str, + media_type: int, + filekey: str, + rawsize: int, + rawfilemd5: str, + filesize: int, + aeskey_hex: str, +) -> dict: + """Request a CDN upload slot for an outbound media file.""" + return await api_post( + session, + base_url=base_url, + endpoint=EP_GET_UPLOAD_URL, + payload={ + "filekey": filekey, + "media_type": media_type, + "to_user_id": to_user_id, + "rawsize": rawsize, + "rawfilemd5": rawfilemd5, + "filesize": filesize, + "no_need_thumb": True, + "aeskey": aeskey_hex, + }, + token=token, + timeout_ms=API_TIMEOUT_MS, + ) diff --git a/flocks/channel/builtin/weixin/config.py b/flocks/channel/builtin/weixin/config.py new file mode 100644 index 00000000..50ed3150 --- /dev/null +++ b/flocks/channel/builtin/weixin/config.py @@ -0,0 +1,99 @@ +""" +Constants, regex patterns, and dependency guards for the Weixin channel. + +All public constants for the iLink Bot API live here so that other modules +in this package import a single source of truth. 
+""" + +from __future__ import annotations + +import re + +# --------------------------------------------------------------------------- +# iLink Bot API constants +# --------------------------------------------------------------------------- +ILINK_BASE_URL = "https://ilinkai.weixin.qq.com" +WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c" +ILINK_APP_ID = "bot" +CHANNEL_VERSION = "2.2.0" +ILINK_APP_CLIENT_VERSION = (2 << 16) | (2 << 8) | 0 + +EP_GET_UPDATES = "ilink/bot/getupdates" +EP_SEND_MESSAGE = "ilink/bot/sendmessage" +EP_GET_CONFIG = "ilink/bot/getconfig" +EP_GET_UPLOAD_URL = "ilink/bot/getuploadurl" +EP_GET_BOT_QR = "ilink/bot/get_bot_qrcode" +EP_GET_QR_STATUS = "ilink/bot/get_qrcode_status" +QR_TIMEOUT_MS = 35_000 + +# --------------------------------------------------------------------------- +# Timeouts (milliseconds for API calls, seconds for media transfers) +# --------------------------------------------------------------------------- +LONG_POLL_TIMEOUT_MS = 35_000 +API_TIMEOUT_MS = 15_000 + +MEDIA_DOWNLOAD_IMAGE_TIMEOUT_S = 30.0 +MEDIA_DOWNLOAD_VIDEO_TIMEOUT_S = 120.0 +MEDIA_DOWNLOAD_FILE_TIMEOUT_S = 60.0 +MEDIA_DOWNLOAD_VOICE_TIMEOUT_S = 60.0 +MEDIA_UPLOAD_TIMEOUT_S = 120.0 +MEDIA_REMOTE_FETCH_TIMEOUT_S = 30.0 + +# --------------------------------------------------------------------------- +# Retry / backoff tuning +# --------------------------------------------------------------------------- +MAX_CONSECUTIVE_FAILURES = 3 +RETRY_DELAY_SECONDS = 2.0 +BACKOFF_DELAY_SECONDS = 30.0 +SESSION_EXPIRED_ERRCODE = -14 +RATE_LIMIT_ERRCODE = -2 +MESSAGE_DEDUP_TTL_SECONDS = 300 +MAX_MESSAGE_LENGTH = 2000 + +# --------------------------------------------------------------------------- +# iLink message / item type constants +# --------------------------------------------------------------------------- +ITEM_TEXT = 1 +ITEM_IMAGE = 2 +ITEM_VOICE = 3 +ITEM_FILE = 4 +ITEM_VIDEO = 5 +MSG_TYPE_BOT = 2 +MSG_STATE_FINISH = 2 + +MEDIA_IMAGE = 1 +MEDIA_VIDEO = 
2 +MEDIA_FILE = 3 +MEDIA_VOICE = 4 + +# --------------------------------------------------------------------------- +# Markdown / format regex helpers (shared with format.py) +# --------------------------------------------------------------------------- +HEADER_RE = re.compile(r"^(#{1,6})\s+(.+?)\s*$") +TABLE_RULE_RE = re.compile(r"^\s*\|?(?:\s*:?-{3,}:?\s*\|)+\s*:?-{3,}:?\s*\|?\s*$") +FENCE_RE = re.compile(r"^```([^\n`]*)\s*$") + +# --------------------------------------------------------------------------- +# Dependency guards (importable feature flags) +# --------------------------------------------------------------------------- +try: + import aiohttp # type: ignore[import-untyped] # noqa: F401 + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + +try: + from cryptography.hazmat.backends import default_backend # noqa: F401 + from cryptography.hazmat.primitives.ciphers import ( # noqa: F401 + Cipher, + algorithms, + modes, + ) + CRYPTO_AVAILABLE = True +except ImportError: + CRYPTO_AVAILABLE = False + + +def check_requirements() -> bool: + """Return True when both runtime dependencies are installed.""" + return AIOHTTP_AVAILABLE and CRYPTO_AVAILABLE diff --git a/flocks/channel/builtin/weixin/crypto.py b/flocks/channel/builtin/weixin/crypto.py new file mode 100644 index 00000000..c40729f0 --- /dev/null +++ b/flocks/channel/builtin/weixin/crypto.py @@ -0,0 +1,59 @@ +""" +AES-128-ECB encryption helpers used by the WeChat iLink CDN protocol. + +iLink encrypts/decrypts media payloads with a per-file 16-byte AES key +in ECB mode with PKCS7 padding. Key wire format is base64 of either the +raw 16 bytes or the 32-character hex string of the same key. 
+""" + +from __future__ import annotations + +import base64 + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + +def pkcs7_pad(data: bytes, block_size: int = 16) -> bytes: + pad_len = block_size - (len(data) % block_size) + return data + bytes([pad_len] * pad_len) + + +def aes128_ecb_encrypt(plaintext: bytes, key: bytes) -> bytes: + cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) + encryptor = cipher.encryptor() + return encryptor.update(pkcs7_pad(plaintext)) + encryptor.finalize() + + +def aes128_ecb_decrypt(ciphertext: bytes, key: bytes) -> bytes: + cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) + decryptor = cipher.decryptor() + padded = decryptor.update(ciphertext) + decryptor.finalize() + if not padded: + return padded + pad_len = padded[-1] + if 1 <= pad_len <= 16 and padded.endswith(bytes([pad_len]) * pad_len): + return padded[:-pad_len] + return padded + + +def aes_padded_size(size: int) -> int: + """PKCS7-padded output size for *size* plaintext bytes (block=16).""" + return ((size + 1 + 15) // 16) * 16 + + +def parse_aes_key(aes_key_b64: str) -> bytes: + """Parse an iLink-style AES key. + + Accepts either: + - base64 of raw 16 bytes (decoded length 16), or + - base64 of the 32-char ASCII hex of the same key (decoded length 32). 
+ """ + decoded = base64.b64decode(aes_key_b64) + if len(decoded) == 16: + return decoded + if len(decoded) == 32: + text = decoded.decode("ascii", errors="ignore") + if text and all(ch in "0123456789abcdefABCDEF" for ch in text): + return bytes.fromhex(text) + raise ValueError(f"unexpected aes_key format ({len(decoded)} decoded bytes)") diff --git a/flocks/channel/builtin/weixin/format.py b/flocks/channel/builtin/weixin/format.py new file mode 100644 index 00000000..f36fb873 --- /dev/null +++ b/flocks/channel/builtin/weixin/format.py @@ -0,0 +1,170 @@ +""" +Markdown normalization and message-chunk splitting for WeChat delivery. + +WeChat clients render most Markdown but truncate very long lines awkwardly. +We: +- Collapse runs of blank lines to at most one. +- Hard-wrap non-code, non-table lines longer than ``LINE_WRAP_WIDTH``. +- Pack content into messages under ``MAX_MESSAGE_LENGTH`` while keeping + fenced code blocks (``` ``` ```) intact. +""" + +from __future__ import annotations + +import textwrap +from typing import Optional + +from .config import FENCE_RE, TABLE_RULE_RE + +LINE_WRAP_WIDTH = 120 + + +def normalize_markdown(content: str) -> str: + """Collapse multi-blank-line runs (outside code blocks) to a single blank.""" + lines = content.splitlines() + result: list[str] = [] + in_code_block = False + blank_run = 0 + + for raw_line in lines: + line = raw_line.rstrip() + if FENCE_RE.match(line.strip()): + in_code_block = not in_code_block + result.append(line) + blank_run = 0 + continue + if in_code_block: + result.append(line) + continue + if not line.strip(): + blank_run += 1 + if blank_run <= 1: + result.append("") + continue + blank_run = 0 + result.append(line) + + return "\n".join(result).strip() + + +def wrap_long_lines(content: str, width: int = LINE_WRAP_WIDTH) -> str: + """Soft-wrap copy-unfriendly long lines while preserving code/tables.""" + wrapped: list[str] = [] + in_code_block = False + + for raw_line in content.splitlines(): + line = 
raw_line.rstrip() + stripped = line.strip() + + if FENCE_RE.match(stripped): + in_code_block = not in_code_block + wrapped.append(line) + continue + + if ( + in_code_block + or len(line) <= width + or not stripped + or stripped.startswith("|") + or TABLE_RULE_RE.match(stripped) + ): + wrapped.append(line) + continue + + wrapped_lines = textwrap.wrap( + line, width=width, + break_long_words=False, break_on_hyphens=False, + replace_whitespace=False, drop_whitespace=True, + ) + wrapped.extend(wrapped_lines or [line]) + + return "\n".join(wrapped).strip() + + +def format_for_weixin(content: Optional[str]) -> str: + """Top-level formatter: normalize whitespace + soft-wrap long lines.""" + if not content: + return "" + return wrap_long_lines(normalize_markdown(content)) + + +def split_markdown_blocks(content: str) -> list[str]: + """Split content into markdown-aware blocks, keeping fenced code intact.""" + if not content: + return [] + blocks: list[str] = [] + current: list[str] = [] + in_code_block = False + + for raw_line in content.splitlines(): + line = raw_line.rstrip() + if FENCE_RE.match(line.strip()): + if not in_code_block and current: + blocks.append("\n".join(current).strip()) + current = [] + current.append(line) + in_code_block = not in_code_block + if not in_code_block: + blocks.append("\n".join(current).strip()) + current = [] + continue + if in_code_block: + current.append(line) + continue + if not line.strip(): + if current: + blocks.append("\n".join(current).strip()) + current = [] + continue + current.append(line) + + if current: + blocks.append("\n".join(current).strip()) + return [b for b in blocks if b] + + +def split_chunks(content: str, max_length: int) -> list[str]: + """Pack markdown blocks into chunks under *max_length*, preserving code fences. + + Long single blocks (e.g. a large code block) are force-split by line then + by character as a last resort. 
+ """ + if not content: + return [] + if len(content) <= max_length: + return [content] + + chunks: list[str] = [] + current = "" + for block in split_markdown_blocks(content): + candidate = block if not current else f"{current}\n\n{block}" + if len(candidate) <= max_length: + current = candidate + continue + if current: + chunks.append(current) + current = "" + if len(block) <= max_length: + current = block + continue + # Block itself oversized — fall back to line-then-char split. + line_buf = "" + for line in block.splitlines(): + if len(line) > max_length: + if line_buf: + chunks.append(line_buf) + line_buf = "" + for i in range(0, len(line), max_length): + chunks.append(line[i:i + max_length]) + continue + if len(line_buf) + len(line) + 1 > max_length: + if line_buf: + chunks.append(line_buf) + line_buf = line + else: + line_buf = f"{line_buf}\n{line}" if line_buf else line + if line_buf: + current = line_buf + if current: + chunks.append(current) + return [c for c in chunks if c] or [content[:max_length]] diff --git a/flocks/channel/builtin/weixin/inbound.py b/flocks/channel/builtin/weixin/inbound.py new file mode 100644 index 00000000..1ac2af99 --- /dev/null +++ b/flocks/channel/builtin/weixin/inbound.py @@ -0,0 +1,65 @@ +""" +Inbound message parsing helpers for the iLink frame format. + +These are pure functions over the raw frame dicts emitted by ``getupdates``, +suitable for unit-testing without an aiohttp session. +""" + +from __future__ import annotations + +from .config import ITEM_FILE, ITEM_IMAGE, ITEM_TEXT, ITEM_VIDEO, ITEM_VOICE + + +def extract_text(item_list: list) -> str: + """Pull a flat text string out of an iLink ``item_list``. + + Handles plain text, replies / quotes (``ref_msg``), and voice-to-text + transcription fallback. 
+ """ + for item in item_list: + if item.get("type") == ITEM_TEXT: + text = str((item.get("text_item") or {}).get("text") or "") + ref = item.get("ref_msg") or {} + ref_item = ref.get("message_item") or {} + ref_type = ref_item.get("type") + if ref_type in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE, ITEM_VOICE): + title = ref.get("title") or "" + prefix = f"[引用媒体: {title}]\n" if title else "[引用媒体]\n" + return f"{prefix}{text}".strip() + if ref_item: + parts: list[str] = [] + if ref.get("title"): + parts.append(str(ref["title"])) + ref_text = extract_text([ref_item]) + if ref_text: + parts.append(ref_text) + if parts: + return f"[引用: {' | '.join(parts)}]\n{text}".strip() + return text + for item in item_list: + if item.get("type") == ITEM_VOICE: + voice_text = str((item.get("voice_item") or {}).get("text") or "") + if voice_text: + return voice_text + return "" + + +def guess_chat_type(message: dict, account_id: str) -> tuple[str, str]: + """Return ``(chat_type, effective_chat_id)`` where chat_type ∈ ``"dm"`` | ``"group"``.""" + room_id = str(message.get("room_id") or message.get("chat_room_id") or "").strip() + to_user_id = str(message.get("to_user_id") or "").strip() + is_group = bool(room_id) or ( + to_user_id and account_id and to_user_id != account_id + and message.get("msg_type") == 1 + ) + if is_group: + return "group", room_id or to_user_id or str(message.get("from_user_id") or "") + return "dm", str(message.get("from_user_id") or "") + + +def safe_id(value: object, keep: int = 8) -> str: + """Truncate IDs for log output while keeping enough to be useful.""" + raw = str(value or "").strip() + if not raw: + return "?" + return raw[:keep] if len(raw) > keep else raw diff --git a/flocks/channel/builtin/weixin/media.py b/flocks/channel/builtin/weixin/media.py new file mode 100644 index 00000000..95b5a48a --- /dev/null +++ b/flocks/channel/builtin/weixin/media.py @@ -0,0 +1,449 @@ +""" +High-level media orchestration for the Weixin channel. 
+ +- ``MediaCache`` writes decrypted inbound bytes to a content-addressed disk + cache and returns local ``file://`` URIs that can travel through the rest + of the Flocks pipeline. +- ``download_inbound_item`` dispatches on iLink item type to fetch + decrypt + + cache an image / video / file / voice payload, returning ``(local_uri, + mime_type)``. +- ``send_outbound_file`` encrypts a local file, requests a CDN upload slot, + uploads the ciphertext, and posts the media item via ``send_media_message``. +- ``fetch_remote_to_temp`` resolves remote URLs to local temp files (used + when ``OutboundContext.media_url`` is an http(s) URL rather than a path). +""" + +from __future__ import annotations + +import base64 +import hashlib +import mimetypes +import secrets +import tempfile +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional + +from flocks.utils.log import Log + +from .cdn import ( + cdn_upload_url, + download_and_decrypt_media, + download_bytes, + media_reference, + upload_ciphertext, +) +from .client import get_upload_url, send_media_message +from .config import ( + ITEM_FILE, + ITEM_IMAGE, + ITEM_VIDEO, + ITEM_VOICE, + MEDIA_DOWNLOAD_FILE_TIMEOUT_S, + MEDIA_DOWNLOAD_IMAGE_TIMEOUT_S, + MEDIA_DOWNLOAD_VIDEO_TIMEOUT_S, + MEDIA_DOWNLOAD_VOICE_TIMEOUT_S, + MEDIA_FILE, + MEDIA_IMAGE, + MEDIA_REMOTE_FETCH_TIMEOUT_S, + MEDIA_UPLOAD_TIMEOUT_S, + MEDIA_VIDEO, + MEDIA_VOICE, +) +from .crypto import aes128_ecb_encrypt, aes_padded_size +from .inbound import safe_id +from .store import ensure_state_dir + +if TYPE_CHECKING: + import aiohttp + +log = Log.create(service="channel.weixin.media") + + +# --------------------------------------------------------------------------- +# Local content-addressed cache for inbound media +# --------------------------------------------------------------------------- + +class MediaCache: + """Write decrypted inbound bytes to ``/media/`` and yield URIs. 
+ + Content-addressed by sha256 of the plaintext to deduplicate re-deliveries + of the same image / file across restarts. + """ + + def __init__(self, data_dir: Optional[str] = None) -> None: + self._root = ensure_state_dir(data_dir) / "media" + self._root.mkdir(parents=True, exist_ok=True) + + def write(self, data: bytes, suffix: str, original_name: Optional[str] = None) -> str: + """Cache *data* under sha256(data) + *suffix* and return a ``file://`` URI.""" + digest = hashlib.sha256(data).hexdigest() + if original_name: + stem = Path(original_name).stem.replace("/", "_") or "media" + name = f"{stem}-{digest[:16]}{suffix}" + else: + name = f"{digest}{suffix}" + path = self._root / name + if not path.exists(): + try: + path.write_bytes(data) + except Exception as exc: + log.warning("weixin.media.cache_write_error", {"error": str(exc)}) + return "" + return path.resolve().as_uri() + + +# --------------------------------------------------------------------------- +# Inbound dispatch +# --------------------------------------------------------------------------- + +def is_downloadable_media_item(item: dict) -> bool: + """Return True iff *item* is a media item that ``download_inbound_item`` + would actually fetch (i.e. would produce bytes, not text-only fallback). + """ + item_type = item.get("type") + if item_type in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE): + return True + if item_type == ITEM_VOICE: + # Voice items already transcribed to text are not downloaded as media. + voice_item = item.get("voice_item") or {} + return not voice_item.get("text") + return False + + +async def download_inbound_item( + session: "aiohttp.ClientSession", + *, + item: dict, + cdn_base_url: str, + cache: MediaCache, + sender_log_id: str = "?", +) -> Optional[tuple[str, str]]: + """Download + decrypt + cache a single ``item_list`` entry. + + Returns ``(local_file_uri, mime_type)`` on success, or ``None`` for non-media + items / failures (logged at WARN). 
+ """ + item_type = item.get("type") + try: + if item_type == ITEM_IMAGE: + return await _download_image(session, item, cdn_base_url, cache) + if item_type == ITEM_VIDEO: + return await _download_video(session, item, cdn_base_url, cache) + if item_type == ITEM_FILE: + return await _download_file(session, item, cdn_base_url, cache) + if item_type == ITEM_VOICE: + return await _download_voice(session, item, cdn_base_url, cache) + except Exception as exc: + log.warning("weixin.media.download_failed", { + "type": item_type, "from": sender_log_id, "error": str(exc), + }) + return None + + +async def _download_image( + session: "aiohttp.ClientSession", + item: dict, + cdn_base_url: str, + cache: MediaCache, +) -> Optional[tuple[str, str]]: + image_item = item.get("image_item") or {} + media = image_item.get("media") or {} + aes_key = _normalize_image_aes_key(image_item, media) + data = await download_and_decrypt_media( + session, + cdn_base_url=cdn_base_url, + encrypted_query_param=media.get("encrypt_query_param"), + aes_key_b64=aes_key, + full_url=media.get("full_url"), + timeout_seconds=MEDIA_DOWNLOAD_IMAGE_TIMEOUT_S, + ) + uri = cache.write(data, ".jpg") + return (uri, "image/jpeg") if uri else None + + +async def _download_video( + session: "aiohttp.ClientSession", + item: dict, + cdn_base_url: str, + cache: MediaCache, +) -> Optional[tuple[str, str]]: + media = media_reference(item, "video_item") + data = await download_and_decrypt_media( + session, + cdn_base_url=cdn_base_url, + encrypted_query_param=media.get("encrypt_query_param"), + aes_key_b64=media.get("aes_key"), + full_url=media.get("full_url"), + timeout_seconds=MEDIA_DOWNLOAD_VIDEO_TIMEOUT_S, + ) + uri = cache.write(data, ".mp4") + return (uri, "video/mp4") if uri else None + + +async def _download_file( + session: "aiohttp.ClientSession", + item: dict, + cdn_base_url: str, + cache: MediaCache, +) -> Optional[tuple[str, str]]: + file_item = item.get("file_item") or {} + media = file_item.get("media") or {} 
+ filename = str(file_item.get("file_name") or "document.bin") + mime = mime_from_filename(filename) + data = await download_and_decrypt_media( + session, + cdn_base_url=cdn_base_url, + encrypted_query_param=media.get("encrypt_query_param"), + aes_key_b64=media.get("aes_key"), + full_url=media.get("full_url"), + timeout_seconds=MEDIA_DOWNLOAD_FILE_TIMEOUT_S, + ) + suffix = Path(filename).suffix or ".bin" + uri = cache.write(data, suffix, original_name=filename) + return (uri, mime) if uri else None + + +async def _download_voice( + session: "aiohttp.ClientSession", + item: dict, + cdn_base_url: str, + cache: MediaCache, +) -> Optional[tuple[str, str]]: + voice_item = item.get("voice_item") or {} + if voice_item.get("text"): + # Voice already transcribed by iLink; treat as text, no media to cache. + return None + media = voice_item.get("media") or {} + data = await download_and_decrypt_media( + session, + cdn_base_url=cdn_base_url, + encrypted_query_param=media.get("encrypt_query_param"), + aes_key_b64=media.get("aes_key"), + full_url=media.get("full_url"), + timeout_seconds=MEDIA_DOWNLOAD_VOICE_TIMEOUT_S, + ) + uri = cache.write(data, ".silk") + return (uri, "audio/silk") if uri else None + + +def _normalize_image_aes_key(image_item: dict, media: dict) -> Optional[str]: + """iLink image frames may stash the AES key under ``image_item.aeskey`` (hex) + instead of ``media.aes_key`` (b64). Reconcile both into a base64 string. 
+ """ + if media.get("aes_key"): + return media["aes_key"] + aeskey_hex = image_item.get("aeskey") + if isinstance(aeskey_hex, str) and aeskey_hex: + try: + return base64.b64encode(bytes.fromhex(aeskey_hex)).decode("ascii") + except Exception: + return None + return None + + +# --------------------------------------------------------------------------- +# Outbound dispatch +# --------------------------------------------------------------------------- + +OutboundItemBuilder = Callable[..., dict] + + +def select_outbound_media( + path: str, force_file_attachment: bool = False +) -> tuple[int, OutboundItemBuilder]: + """Pick the right ``media_type`` + ``item`` constructor for *path*'s mime.""" + mime = mimetypes.guess_type(path)[0] or "application/octet-stream" + + if mime.startswith("image/"): + return MEDIA_IMAGE, _build_image_item + if mime.startswith("video/"): + return MEDIA_VIDEO, _build_video_item + if path.endswith(".silk") and not force_file_attachment: + return MEDIA_VOICE, _build_voice_item + if mime.startswith("audio/"): + # Non-silk audio: send as file attachment (silk is required for native voice bubble). 
+ return MEDIA_FILE, _build_file_item + return MEDIA_FILE, _build_file_item + + +def _build_image_item(**kw) -> dict: + return { + "type": ITEM_IMAGE, + "image_item": { + "media": { + "encrypt_query_param": kw["encrypt_query_param"], + "aes_key": kw["aes_key_for_api"], + "encrypt_type": 1, + }, + "mid_size": kw["ciphertext_size"], + }, + } + + +def _build_video_item(**kw) -> dict: + return { + "type": ITEM_VIDEO, + "video_item": { + "media": { + "encrypt_query_param": kw["encrypt_query_param"], + "aes_key": kw["aes_key_for_api"], + "encrypt_type": 1, + }, + "video_size": kw["ciphertext_size"], + "play_length": kw.get("play_length", 0), + "video_md5": kw.get("rawfilemd5", ""), + }, + } + + +def _build_voice_item(**kw) -> dict: + return { + "type": ITEM_VOICE, + "voice_item": { + "media": { + "encrypt_query_param": kw["encrypt_query_param"], + "aes_key": kw["aes_key_for_api"], + "encrypt_type": 1, + }, + "encode_type": kw.get("encode_type", 6), + "bits_per_sample": kw.get("bits_per_sample", 16), + "sample_rate": kw.get("sample_rate", 24000), + "playtime": kw.get("playtime", 0), + }, + } + + +def _build_file_item(**kw) -> dict: + return { + "type": ITEM_FILE, + "file_item": { + "media": { + "encrypt_query_param": kw["encrypt_query_param"], + "aes_key": kw["aes_key_for_api"], + "encrypt_type": 1, + }, + "file_name": kw["filename"], + "len": str(kw["plaintext_size"]), + }, + } + + +async def send_outbound_file( + session: "aiohttp.ClientSession", + *, + base_url: str, + cdn_base_url: str, + token: str, + chat_id: str, + path: str, + context_token: Optional[str], + context_token_setter: Optional[Callable[[str, Optional[str]], None]] = None, + force_file_attachment: bool = False, +) -> str: + """Encrypt + upload + send a single local file. 
Returns the client_id used.""" + plaintext = Path(path).read_bytes() + media_type, item_builder = select_outbound_media( + path, force_file_attachment=force_file_attachment, + ) + filekey = secrets.token_hex(16) + aes_key = secrets.token_bytes(16) + rawsize = len(plaintext) + rawfilemd5 = hashlib.md5(plaintext).hexdigest() + + upload_resp = await get_upload_url( + session, + base_url=base_url, + token=token, + to_user_id=chat_id, + media_type=media_type, + filekey=filekey, + rawsize=rawsize, + rawfilemd5=rawfilemd5, + filesize=aes_padded_size(rawsize), + aeskey_hex=aes_key.hex(), + ) + + upload_param = str(upload_resp.get("upload_param") or "") + upload_full_url = str(upload_resp.get("upload_full_url") or "") + if upload_full_url: + upload_url = upload_full_url + elif upload_param: + upload_url = cdn_upload_url(cdn_base_url, upload_param, filekey) + else: + raise RuntimeError( + "getUploadUrl returned neither upload_param nor upload_full_url: " + f"{upload_resp}" + ) + + ciphertext = aes128_ecb_encrypt(plaintext, aes_key) + encrypted_query_param = await upload_ciphertext( + session, + ciphertext=ciphertext, + upload_url=upload_url, + timeout_seconds=MEDIA_UPLOAD_TIMEOUT_S, + ) + + # iLink expects aes_key as base64(hex_string), not base64(raw_bytes). 
+ aes_key_for_api = base64.b64encode(aes_key.hex().encode("ascii")).decode("ascii") + + media_item_kwargs = { + "encrypt_query_param": encrypted_query_param, + "aes_key_for_api": aes_key_for_api, + "ciphertext_size": len(ciphertext), + "plaintext_size": rawsize, + "filename": Path(path).name, + "rawfilemd5": rawfilemd5, + } + media_item = item_builder(**media_item_kwargs) + + client_id = f"flocks-weixin-{uuid.uuid4().hex}" + await send_media_message( + session, + base_url=base_url, + token=token, + to=chat_id, + item=media_item, + context_token=context_token, + client_id=client_id, + ) + log.info("weixin.media.sent", { + "to": safe_id(chat_id), + "media_type": media_type, + "size": rawsize, + }) + return client_id + + +async def fetch_remote_to_temp( + session: "aiohttp.ClientSession", + *, + url: str, + timeout_seconds: float = MEDIA_REMOTE_FETCH_TIMEOUT_S, +) -> str: + """Download an http(s) URL into a temp file, return the local path. + + Caller is responsible for unlinking the temp file when done. + Only use after validating the URL belongs to the WeChat CDN. + """ + data = await download_bytes(session, url=url, timeout_seconds=timeout_seconds) + suffix = Path(url.split("?", 1)[0]).suffix or ".bin" + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle: + handle.write(data) + return handle.name + + +def mime_from_filename(filename: str) -> str: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" + + +# Re-exported for the channel to schedule async tasks without a circular import. 
+__all__ = [ + "MediaCache", + "download_inbound_item", + "is_downloadable_media_item", + "send_outbound_file", + "fetch_remote_to_temp", + "mime_from_filename", + "select_outbound_media", +] diff --git a/flocks/channel/builtin/weixin/qr_login.py b/flocks/channel/builtin/weixin/qr_login.py new file mode 100644 index 00000000..e364fca8 --- /dev/null +++ b/flocks/channel/builtin/weixin/qr_login.py @@ -0,0 +1,172 @@ +""" +iLink Bot QR-code login flow for the Weixin channel. + +Two API endpoints are exposed by the Flocks server so the web UI can: + +1. ``POST /api/channel/weixin/qr-login/start`` + → Call ``ilink/bot/get_bot_qrcode`` (no token required — this *is* the + pre-auth step). Returns ``{qrcode_value, qrcode_url}`` so the frontend + can render a QR code with e.g. ``qrcode.react``. + +2. ``GET /api/channel/weixin/qr-login/status?qrcode=`` + → Poll ``ilink/bot/get_qrcode_status``. Returns + ``{status, account_id, token}`` where ``status`` is one of: + "waiting" — waiting for scan + "scaned" — phone scanned, waiting for phone confirmation tap + "confirmed"— login complete; ``account_id`` and ``token`` populated + "expired" — QR code expired; frontend should call /start again + +These helpers are pure async functions that accept an explicit ``base_url`` +so callers can override the iLink endpoint without touching global state. 
+""" + +from __future__ import annotations + +import json +import ssl +from typing import Optional + +from .config import ( + EP_GET_QR_STATUS, + EP_GET_BOT_QR, + ILINK_BASE_URL, + ILINK_APP_ID, + ILINK_APP_CLIENT_VERSION, + CHANNEL_VERSION, + QR_TIMEOUT_MS, +) + + +# --------------------------------------------------------------------------- +# HTTP helpers (no aiohttp session shared with the channel — login creates +# a throwaway session so it doesn't interfere with the poll loop) +# --------------------------------------------------------------------------- + +def _make_ssl_ctx(): + try: + import certifi + return ssl.create_default_context(cafile=certifi.where()) + except ImportError: + return True # aiohttp default + + +def _login_headers() -> dict: + return { + "Content-Type": "application/json", + "AuthorizationType": "ilink_bot_token", + "iLink-App-Id": ILINK_APP_ID, + "iLink-App-ClientVersion": str(ILINK_APP_CLIENT_VERSION), + } + + +async def _api_get(base_url: str, endpoint: str) -> dict: + """Simple GET against the iLink API with a short timeout.""" + import aiohttp + + url = f"{base_url.rstrip('/')}/{endpoint}" + timeout = aiohttp.ClientTimeout(total=QR_TIMEOUT_MS / 1000) + connector = aiohttp.TCPConnector(ssl=_make_ssl_ctx()) + async with aiohttp.ClientSession( + trust_env=True, connector=connector + ) as session: + async with session.get( + url, headers=_login_headers(), timeout=timeout + ) as resp: + raw = await resp.text() + if not resp.ok: + raise RuntimeError( + f"iLink GET {endpoint} HTTP {resp.status}: {raw[:200]}" + ) + return json.loads(raw) + + +# --------------------------------------------------------------------------- +# Public helpers called by the route handlers +# --------------------------------------------------------------------------- + +async def start_qr_login( + base_url: str = ILINK_BASE_URL, + bot_type: str = "3", +) -> dict: + """Request a fresh QR code from iLink. 
+ + Returns ``{"qrcode_value": str, "qrcode_url": str}`` where + - ``qrcode_value`` is the raw hex token used to poll status + - ``qrcode_url`` is the WeChat mini-app URL to encode in the rendered QR + """ + resp = await _api_get( + base_url, + f"{EP_GET_BOT_QR}?bot_type={bot_type}", + ) + qrcode_value: str = str(resp.get("qrcode") or "") + qrcode_url: str = str(resp.get("qrcode_img_content") or "") + if not qrcode_value: + raise RuntimeError( + f"iLink get_bot_qrcode returned no qrcode field: {resp}" + ) + # WeChat must scan the full mini-app URL, not the raw hex token. + scan_data = qrcode_url if qrcode_url else qrcode_value + return { + "qrcode_value": qrcode_value, + "qrcode_url": scan_data, + } + + +async def poll_qr_status( + qrcode_value: str, + base_url: str = ILINK_BASE_URL, +) -> dict: + """Poll the QR code status once. + + Returns one of:: + + {"status": "waiting"} + {"status": "scaned"} + {"status": "expired"} + {"status": "redirect", "redirect_base_url": "https://..."} + {"status": "confirmed", "account_id": "...", "token": "...", + "base_url": "https://..."} + + ``redirect`` is returned when iLink routes the account to a regional node. + The frontend must pass the new ``redirect_base_url`` as ``base_url`` for all + subsequent calls so that the final ``confirmed`` response comes from the + correct node. It must also persist ``base_url`` from ``confirmed`` into the + channel config — otherwise the long-poll loop will connect to the wrong node. + + The caller (route handler) is responsible for looping / error handling. + """ + resp = await _api_get( + base_url, + f"{EP_GET_QR_STATUS}?qrcode={qrcode_value}", + ) + status: str = str(resp.get("status") or "waiting").lower() + + if status == "confirmed": + account_id = str(resp.get("ilink_bot_id") or "") + token = str(resp.get("bot_token") or "") + # iLink returns the canonical base_url for this account on confirmed. + # This may differ from ILINK_BASE_URL for accounts on regional nodes. 
+ confirmed_base_url = str(resp.get("baseurl") or "").rstrip("/") or base_url + if not account_id or not token: + raise RuntimeError( + f"QR confirmed but missing credentials: {resp}" + ) + return { + "status": "confirmed", + "account_id": account_id, + "token": token, + "base_url": confirmed_base_url, + } + + if status == "scaned_but_redirect": + redirect_host = str(resp.get("redirect_host") or "").strip() + redirect_base_url = ( + f"https://{redirect_host}" if redirect_host else base_url + ) + return {"status": "redirect", "redirect_base_url": redirect_base_url} + + if status == "scaned": + return {"status": "scaned"} + if status == "expired": + return {"status": "expired"} + return {"status": "waiting"} diff --git a/flocks/channel/builtin/weixin/store.py b/flocks/channel/builtin/weixin/store.py new file mode 100644 index 00000000..cf08592f --- /dev/null +++ b/flocks/channel/builtin/weixin/store.py @@ -0,0 +1,144 @@ +""" +Disk-backed state stores for the Weixin channel: + +- ``ContextTokenStore`` — per-account, per-peer ``context_token`` cache + required to maintain conversation continuity with the iLink server. +- ``MessageDedup`` — in-memory dedup with TTL-based pruning. +- ``sync_buf`` helpers — long-poll cursor persistence. + +State files default to ``~/.flocks/workspace/channels/weixin/`` but the channel +can override the root via the ``dataDir`` config key (useful for multi-profile +setups). When ``dataDir`` is set it is used as-is (no ``weixin/`` sub-dir is +appended). 
+""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Optional + +from flocks.utils.log import Log + +from .config import MESSAGE_DEDUP_TTL_SECONDS + +log = Log.create(service="channel.weixin.store") + + +# --------------------------------------------------------------------------- +# Filesystem helpers +# --------------------------------------------------------------------------- + +def state_dir(data_dir: Optional[str] = None) -> Path: + if data_dir: + return Path(data_dir) + return Path.home() / ".flocks" / "workspace" / "channels" / "weixin" + + +def ensure_state_dir(data_dir: Optional[str] = None) -> Path: + path = state_dir(data_dir) + path.mkdir(parents=True, exist_ok=True) + return path + + +# --------------------------------------------------------------------------- +# Sync-buf cursor (long-poll position) +# --------------------------------------------------------------------------- + +def load_sync_buf(account_id: str, data_dir: Optional[str] = None) -> str: + path = state_dir(data_dir) / f"{account_id}.sync.json" + if not path.exists(): + return "" + try: + return json.loads(path.read_text(encoding="utf-8")).get("get_updates_buf", "") + except Exception: + return "" + + +def save_sync_buf(account_id: str, sync_buf: str, data_dir: Optional[str] = None) -> None: + try: + path = ensure_state_dir(data_dir) / f"{account_id}.sync.json" + path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8") + except Exception as exc: + log.warning("weixin.sync_buf.save_error", {"error": str(exc)}) + + +# --------------------------------------------------------------------------- +# Per-peer context token cache +# --------------------------------------------------------------------------- + +class ContextTokenStore: + """Disk-backed ``context_token`` cache keyed by ``(account_id, user_id)``.""" + + def __init__(self, data_dir: Optional[str] = None) -> None: + self._root = state_dir(data_dir) 
+ self._cache: dict[str, str] = {} + + def _path(self, account_id: str) -> Path: + return self._root / f"{account_id}.context-tokens.json" + + @staticmethod + def _key(account_id: str, user_id: str) -> str: + return f"{account_id}:{user_id}" + + def restore(self, account_id: str) -> None: + path = self._path(account_id) + if not path.exists(): + return + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return + for user_id, token in data.items(): + if isinstance(token, str) and token: + self._cache[self._key(account_id, user_id)] = token + + def get(self, account_id: str, user_id: str) -> Optional[str]: + return self._cache.get(self._key(account_id, user_id)) + + def set(self, account_id: str, user_id: str, token: str) -> None: + self._cache[self._key(account_id, user_id)] = token + self._persist(account_id) + + def clear(self, account_id: str, user_id: str) -> None: + """Drop a stale token (called on session-expired errors).""" + if self._cache.pop(self._key(account_id, user_id), None) is not None: + self._persist(account_id) + + def _persist(self, account_id: str) -> None: + prefix = f"{account_id}:" + payload = { + key[len(prefix):]: value + for key, value in self._cache.items() + if key.startswith(prefix) + } + try: + self._root.mkdir(parents=True, exist_ok=True) + self._path(account_id).write_text(json.dumps(payload), encoding="utf-8") + except Exception as exc: + log.warning("weixin.context_token.persist_error", {"error": str(exc)}) + + +# --------------------------------------------------------------------------- +# In-memory dedup with TTL pruning +# --------------------------------------------------------------------------- + +class MessageDedup: + """Track recent message ids / content hashes to drop redelivered messages.""" + + def __init__(self, ttl_seconds: float = MESSAGE_DEDUP_TTL_SECONDS) -> None: + self._ttl = ttl_seconds + self._seen: dict[str, float] = {} + + def is_duplicate(self, key: str) -> bool: + now = 
time.time() + cutoff = now - self._ttl + seen_at = self._seen.get(key) + if seen_at is not None and seen_at > cutoff: + return True + # Prune stale entries lazily every ~100 inserts to bound memory growth. + if len(self._seen) >= 100 and len(self._seen) % 100 == 0: + self._seen = {k: v for k, v in self._seen.items() if v > cutoff} + self._seen[key] = now + return False diff --git a/flocks/channel/inbound/dispatcher.py b/flocks/channel/inbound/dispatcher.py index 466ffc50..1512c8cd 100644 --- a/flocks/channel/inbound/dispatcher.py +++ b/flocks/channel/inbound/dispatcher.py @@ -959,6 +959,11 @@ async def _append_user_message( model: Optional[dict] = None, agent: Optional[str] = None, ) -> None: + import mimetypes + import os + from pathlib import Path + from urllib.parse import unquote, urlparse + from flocks.session.message import FilePart, Message, MessageRole create_kwargs: dict = dict( @@ -980,29 +985,70 @@ async def _append_user_message( message = await Message.create(**create_kwargs) - if msg.channel_id != "feishu" or not msg.media_url or channel_config is None: + if not msg.media_url: return try: - from flocks.channel.builtin.feishu.inbound_media import download_inbound_media + parsed = urlparse(msg.media_url) + scheme = parsed.scheme.lower() + + if msg.channel_id == "feishu" and channel_config is not None: + # Feishu: media is still on the remote server, download first. 
+ from flocks.channel.builtin.feishu.inbound_media import download_inbound_media + + raw_cfg = channel_config.model_dump(by_alias=True, exclude_none=True) + media = await download_inbound_media(msg, raw_cfg) + if not media: + return + + await Message.store_part( + session_id, + message.id, + FilePart( + sessionID=session_id, + messageID=message.id, + mime=media.mime, + filename=media.filename, + url=media.url, + source=media.source, + ), + ) - raw_cfg = channel_config.model_dump(by_alias=True, exclude_none=True) - media = await download_inbound_media(msg, raw_cfg) - if not media: - return + elif scheme in ("", "file"): + # Local file already downloaded by the channel plugin (e.g. weixin). + # file:// URIs may have URL-encoded paths (e.g. Chinese filenames). + local_path = unquote(parsed.path) if scheme == "file" else msg.media_url + if not os.path.isfile(local_path): + log.warning("dispatcher.inbound_media_missing", { + "channel_id": msg.channel_id, + "path": local_path, + }) + return + filename = Path(local_path).name + mime = ( + msg.media_mime + or mimetypes.guess_type(local_path)[0] + or "application/octet-stream" + ) + file_uri = Path(local_path).resolve().as_uri() + await Message.store_part( + session_id, + message.id, + FilePart( + sessionID=session_id, + messageID=message.id, + mime=mime, + filename=filename, + url=file_uri, + source=None, + ), + ) + log.info("dispatcher.inbound_media_attached", { + "channel_id": msg.channel_id, + "filename": filename, + "mime": mime, + }) - await Message.store_part( - session_id, - message.id, - FilePart( - sessionID=session_id, - messageID=message.id, - mime=media.mime, - filename=media.filename, - url=media.url, - source=media.source, - ), - ) except Exception as e: log.warning("dispatcher.inbound_media_download_failed", { "channel_id": msg.channel_id, diff --git a/flocks/channel/inbound/session_binding.py b/flocks/channel/inbound/session_binding.py index 3327b575..175ea413 100644 --- 
a/flocks/channel/inbound/session_binding.py +++ b/flocks/channel/inbound/session_binding.py @@ -100,11 +100,12 @@ async def _get_db() -> aiosqlite.Connection: db_path = Storage.get_db_path() db_path.parent.mkdir(parents=True, exist_ok=True) - _db_conn = await aiosqlite.connect(str(db_path)) + _db_conn = await aiosqlite.connect( + str(db_path), + timeout=Storage._sqlite_timeout_s, + ) _db_conn.row_factory = aiosqlite.Row - - await _db_conn.execute("PRAGMA journal_mode=WAL") - await _db_conn.execute("PRAGMA busy_timeout=5000") + await Storage.configure_connection(_db_conn) await _db_conn.executescript(_DDL) await _migrate_legacy_binding_agent_ids(_db_conn) _db_ready = True diff --git a/flocks/channel/registry.py b/flocks/channel/registry.py index aff20c7d..03819050 100644 --- a/flocks/channel/registry.py +++ b/flocks/channel/registry.py @@ -80,10 +80,12 @@ def _register_builtin_channels(self) -> None: from flocks.channel.builtin.feishu.channel import FeishuChannel from flocks.channel.builtin.telegram.channel import TelegramChannel from flocks.channel.builtin.wecom.channel import WeComChannel + from flocks.channel.builtin.weixin.channel import WeixinChannel self.register(FeishuChannel()) self.register(WeComChannel()) self.register(TelegramChannel()) self.register(DingTalkChannel()) + self.register(WeixinChannel()) def _register_plugin_extension_point(self) -> None: from flocks.plugin import PluginLoader, ExtensionPoint diff --git a/flocks/mcp/client.py b/flocks/mcp/client.py index 8b14d9a2..b8e709ed 100644 --- a/flocks/mcp/client.py +++ b/flocks/mcp/client.py @@ -1,21 +1,26 @@ """ MCP Client Wrapper -Client implementation based on official MCP SDK, supporting Streamable HTTP and SSE transports +Client implementation based on official MCP SDK, supporting Streamable HTTP and SSE transports. 
""" import asyncio +import contextlib import os import tempfile +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Dict, Any, List, Literal +from typing import Any, Dict, List, Literal, Optional + +import httpx from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client from mcp.client.sse import sse_client -from mcp.client.stdio import stdio_client, StdioServerParameters +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.streamable_http import streamable_http_client -from flocks.mcp.types import McpToolDef, McpResource -from flocks.mcp.utils import build_mcp_headers, build_mcp_url, resolve_env_var +from flocks.mcp.types import McpResource, McpToolDef +from flocks.mcp.utils import build_mcp_headers, build_mcp_url from flocks.utils.log import Log log = Log.create(service="mcp.client") @@ -43,6 +48,24 @@ def _extract_root_cause(exc: BaseException) -> str: return str(exc) +def _normalize_timeout(timeout: object) -> float: + """Normalize optional timeout inputs to a positive float.""" + try: + value = float(timeout) # type: ignore[arg-type] + except (TypeError, ValueError): + return 30.0 + return value if value > 0 else 30.0 + + +@dataclass(slots=True) +class _ClientCommand: + """A serialized request executed by the MCP owner task.""" + + action: Literal["list_tools", "call_tool", "list_resources", "read_resource", "disconnect"] + payload: Dict[str, Any] = field(default_factory=dict) + response: asyncio.Future[Any] | None = None + + class McpClient: """ MCP Client - Wraps official SDK @@ -88,13 +111,15 @@ def __init__( self.env = env self.auth_config = auth_config self.transport = transport - self.timeout = timeout + self.timeout = _normalize_timeout(timeout) self.session: Optional[ClientSession] = None self._streams = None - self._streams_context = None self._connected = False self._transport_type: Optional[str] = 
None + self._command_queue: asyncio.Queue[_ClientCommand] | None = None + self._owner_task: asyncio.Task[None] | None = None + self._owner_error: BaseException | None = None async def connect(self) -> None: """ @@ -107,18 +132,105 @@ async def connect(self) -> None: if self._connected: log.warn("mcp.client.already_connected", {"server": self.name}) return - - if self.server_type in ("remote", "sse"): - # Both "remote" and "sse" use auto-detection: - # try Streamable HTTP first, fall back to SSE. - # This handles servers that only support one transport. - await self._connect_remote() - elif self.server_type in ("local", "stdio"): - await self._connect_local() - else: - raise ValueError(f"Unknown server type: {self.server_type}") - - async def _connect_remote(self) -> None: + + if self._owner_task and not self._owner_task.done(): + log.warn("mcp.client.connect_in_progress", {"server": self.name}) + return + + loop = asyncio.get_running_loop() + startup_future: asyncio.Future[None] = loop.create_future() + self._owner_error = None + self._command_queue = asyncio.Queue() + + owner_task = asyncio.create_task( + self._run_connection_owner(startup_future), + name=f"mcp-owner:{self.name}", + ) + owner_task.add_done_callback(self._handle_owner_task_done) + self._owner_task = owner_task + + try: + await asyncio.wait_for(startup_future, timeout=self.timeout + 1.0) + except Exception: + await self._cancel_owner_task() + self._reset_runtime_state() + raise + + async def _run_connection_owner(self, startup_future: asyncio.Future[None]) -> None: + """Own the entire MCP session lifecycle inside one asyncio task.""" + try: + if self.server_type in ("remote", "sse"): + await self._connect_remote(startup_future) + elif self.server_type in ("local", "stdio"): + await self._connect_local(startup_future) + else: + raise ValueError(f"Unknown server type: {self.server_type}") + except Exception as exc: + if not startup_future.done(): + startup_future.set_exception(exc) + else: + await 
self._fail_pending_commands( + RuntimeError(f"Connection lost: {self.name}: {_extract_root_cause(exc)}") + ) + raise + finally: + if not startup_future.done(): + startup_future.set_exception( + RuntimeError(f"Connection closed before initialization: {self.name}") + ) + self._connected = False + self.session = None + self._streams = None + self._transport_type = None + await self._fail_pending_commands(RuntimeError(f"Client not connected: {self.name}")) + + def _handle_owner_task_done(self, task: asyncio.Task[None]) -> None: + """Retrieve background task exceptions so asyncio does not emit warnings.""" + try: + owner_error = task.exception() + except asyncio.CancelledError: + owner_error = None + + self._owner_error = owner_error + if self._owner_task is task: + self._owner_task = None + + if owner_error is not None: + log.error("mcp.client.owner_task_error", { + "server": self.name, + "error": _extract_root_cause(owner_error), + }) + + async def _cancel_owner_task(self) -> None: + """Cancel and await the owner task if it is still running.""" + owner_task = self._owner_task + if owner_task is None: + return + if not owner_task.done(): + owner_task.cancel() + try: + await owner_task + except asyncio.CancelledError: + return + except Exception as exc: + # Connection startup may already have failed; preserve the error so + # callers can finish cleanup and still surface the root cause. 
+ if self._owner_error is None: + self._owner_error = exc + + def _reset_runtime_state(self, clear_owner_error: bool = False) -> None: + """Reset local state after disconnects or failed startups.""" + self.session = None + self._streams = None + self._connected = False + self._transport_type = None + self._command_queue = None + if self._owner_task is not None and self._owner_task.done(): + self._owner_task = None + if clear_owner_error: + self._owner_error = None + + async def _connect_remote(self, startup_future: asyncio.Future[None]) -> None: """Connect to a remote server using the configured transport strategy.""" full_url = build_mcp_url(self.url, self.auth_config) request_headers = build_mcp_headers(self.headers, self.auth_config) @@ -129,7 +241,7 @@ async def _connect_remote(self) -> None: "type": "remote", "strategy": "streamable_http_only", }) - await self._connect_streamable_http_only(full_url, request_headers) + await self._connect_streamable_http_only(full_url, request_headers, startup_future) return if self.transport == "sse": @@ -138,7 +250,7 @@ async def _connect_remote(self) -> None: "type": "remote", "strategy": "sse_only", }) - await self._connect_sse_only(full_url, request_headers) + await self._connect_sse_only(full_url, request_headers, startup_future) return log.info("mcp.client.connecting", { @@ -146,163 +258,261 @@ async def _connect_remote(self) -> None: "type": "remote", "strategy": "streamable_http_then_sse", }) - await self._connect_auto(full_url, request_headers) + await self._connect_auto(full_url, request_headers, startup_future) async def _connect_streamable_http_only( - self, full_url: str, headers: Optional[Dict[str, str]] + self, + full_url: str, + headers: Optional[Dict[str, str]], + startup_future: asyncio.Future[None], ) -> None: """Connect using only Streamable HTTP.""" try: - await self._do_connect_streamable_http(full_url, headers) - self._transport_type = "streamable_http" - except asyncio.TimeoutError: - await 
self._cleanup_connection() + await self._run_remote_transport( + transport_name="streamable_http", + full_url=full_url, + headers=headers, + startup_future=startup_future, + transport_factory=self._create_streamable_http_streams, + ) + except asyncio.TimeoutError as exc: log.error("mcp.client.timeout", { "server": self.name, "transport": "streamable_http", }) - raise RuntimeError(f"Connection timeout: {self.name}") - except Exception as e: - root_cause = _extract_root_cause(e) - await self._cleanup_connection() - raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") + raise RuntimeError(f"Connection timeout: {self.name}") from exc + except Exception as exc: + raise RuntimeError(f"Connection failed: {self.name}: {_extract_root_cause(exc)}") from exc async def _connect_sse_only( - self, full_url: str, headers: Optional[Dict[str, str]] + self, + full_url: str, + headers: Optional[Dict[str, str]], + startup_future: asyncio.Future[None], ) -> None: """Connect using only SSE.""" try: - await self._do_connect_sse(full_url, headers) - self._transport_type = "sse" - except asyncio.TimeoutError: - await self._cleanup_connection() + await self._run_remote_transport( + transport_name="sse", + full_url=full_url, + headers=headers, + startup_future=startup_future, + transport_factory=self._create_sse_streams, + ) + except asyncio.TimeoutError as exc: log.error("mcp.client.timeout", { "server": self.name, "transport": "sse", }) - raise RuntimeError(f"Connection timeout: {self.name}") - except Exception as e: - root_cause = _extract_root_cause(e) - await self._cleanup_connection() - raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") + raise RuntimeError(f"Connection timeout: {self.name}") from exc + except Exception as exc: + raise RuntimeError(f"Connection failed: {self.name}: {_extract_root_cause(exc)}") from exc async def _connect_auto( - self, full_url: str, headers: Optional[Dict[str, str]] + self, + full_url: str, + headers: Optional[Dict[str, 
str]], + startup_future: asyncio.Future[None], ) -> None: """Connect using auto-detection: HTTP first, then SSE.""" try: - await self._do_connect_streamable_http(full_url, headers) - self._transport_type = "streamable_http" + await self._run_remote_transport( + transport_name="streamable_http", + full_url=full_url, + headers=headers, + startup_future=startup_future, + transport_factory=self._create_streamable_http_streams, + ) return - except asyncio.TimeoutError: - await self._cleanup_connection() + except asyncio.TimeoutError as exc: log.error("mcp.client.timeout", { "server": self.name, "transport": "streamable_http", }) - raise RuntimeError(f"Connection timeout: {self.name}") - except Exception as e: + raise RuntimeError(f"Connection timeout: {self.name}") from exc + except Exception as exc: + if startup_future.done(): + raise log.info("mcp.client.streamable_http_failed", { "server": self.name, - "error": str(e), + "error": _extract_root_cause(exc), "fallback": "sse", }) - await self._cleanup_connection() try: - await self._do_connect_sse(full_url, headers) - self._transport_type = "sse" - return - except Exception as e: - root_cause = _extract_root_cause(e) + await self._run_remote_transport( + transport_name="sse", + full_url=full_url, + headers=headers, + startup_future=startup_future, + transport_factory=self._create_sse_streams, + ) + except Exception as exc: + root_cause = _extract_root_cause(exc) log.error("mcp.client.all_transports_failed", { "server": self.name, "error": root_cause, }) - await self._cleanup_connection() - raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") - - async def _do_connect_streamable_http( - self, full_url: str, headers: Optional[Dict[str, str]] = None + raise RuntimeError(f"Connection failed: {self.name}: {root_cause}") from exc + + async def _run_remote_transport( + self, + transport_name: Literal["streamable_http", "sse"], + full_url: str, + headers: Optional[Dict[str, str]], + startup_future: 
asyncio.Future[None], + transport_factory, ) -> None: - """Perform Streamable HTTP connection. - - Raises: - asyncio.TimeoutError: If connection or initialization times out - Exception: Any other connection error - """ - self._streams_context = streamablehttp_client( - full_url, - headers=headers, - timeout=self.timeout, - ) - self._streams = await self._streams_context.__aenter__() - read_stream, write_stream, _ = self._streams - - self.session = ClientSession(read_stream, write_stream) - await self.session.__aenter__() - init_result = await asyncio.wait_for( - self.session.initialize(), - timeout=self.timeout - ) - - self._connected = True - log.info("mcp.client.connected", { - "server": self.name, - "transport": "streamable_http", - "protocol_version": getattr(init_result, 'protocolVersion', 'unknown'), - "server_info": getattr(init_result, 'serverInfo', {}) - }) - - async def _do_connect_sse( - self, full_url: str, headers: Optional[Dict[str, str]] = None + """Run one remote transport from startup until disconnect.""" + async with transport_factory(full_url, headers) as streams: + self._streams = streams + await self._run_connected_session(transport_name, streams, startup_future) + + @asynccontextmanager + async def _create_streamable_http_streams( + self, + full_url: str, + headers: Optional[Dict[str, str]], + ): + """Create a modern Streamable HTTP transport context.""" + timeout = httpx.Timeout(self.timeout, read=60 * 5) + async with httpx.AsyncClient(headers=headers, timeout=timeout) as http_client: + async with streamable_http_client(full_url, http_client=http_client) as streams: + yield streams + + @asynccontextmanager + async def _create_sse_streams( + self, + full_url: str, + headers: Optional[Dict[str, str]], + ): + """Create an SSE transport context.""" + async with sse_client(full_url, headers=headers, timeout=self.timeout) as streams: + yield streams + + async def _run_connected_session( + self, + transport_name: Literal["streamable_http", "sse", 
"stdio"], + streams, + startup_future: asyncio.Future[None], ) -> None: - """Perform SSE connection. - - Raises: - asyncio.TimeoutError: If connection or initialization times out - Exception: Any other connection error - """ - self._streams_context = sse_client( - full_url, - headers=headers, - timeout=self.timeout, - ) - self._streams = await self._streams_context.__aenter__() - read_stream, write_stream = self._streams - - self.session = ClientSession(read_stream, write_stream) - await self.session.__aenter__() - init_result = await asyncio.wait_for( - self.session.initialize(), - timeout=self.timeout - ) - - self._connected = True - log.info("mcp.client.connected", { - "server": self.name, - "transport": "sse", - "protocol_version": getattr(init_result, 'protocolVersion', 'unknown'), - "server_info": getattr(init_result, 'serverInfo', {}) - }) - - async def _cleanup_connection(self) -> None: - """Clean up connection resources after a failed attempt""" - if self.session: + """Initialize the MCP session and then serve queued commands.""" + if len(streams) == 3: + read_stream, write_stream, _ = streams + else: + read_stream, write_stream = streams + + async with ClientSession(read_stream, write_stream) as session: + self.session = session + init_result = await asyncio.wait_for(session.initialize(), timeout=self.timeout) + self._connected = True + self._transport_type = transport_name + if not startup_future.done(): + startup_future.set_result(None) + log.info("mcp.client.connected", { + "server": self.name, + "transport": transport_name, + "protocol_version": getattr(init_result, "protocolVersion", "unknown"), + "server_info": getattr(init_result, "serverInfo", {}), + }) + await self._serve_commands(session) + + async def _serve_commands(self, session: ClientSession) -> None: + """Process serialized commands until disconnect is requested.""" + if self._command_queue is None: + raise RuntimeError(f"Command queue not initialized: {self.name}") + + while True: + 
command = await self._command_queue.get() + if command.action == "disconnect": + if command.response is not None and not command.response.done(): + command.response.set_result(None) + return + try: - await self.session.__aexit__(None, None, None) - except Exception: - pass - if self._streams_context: + result = await self._execute_command(session, command) + except Exception as exc: + if command.response is not None and not command.response.done(): + command.response.set_exception(exc) + else: + if command.response is not None and not command.response.done(): + command.response.set_result(result) + + async def _execute_command(self, session: ClientSession, command: _ClientCommand) -> Any: + """Execute one queued MCP command inside the owner task.""" + if command.action == "list_tools": + result = await asyncio.wait_for(session.list_tools(), timeout=self.timeout) + tools = [McpToolDef.from_sdk(tool) for tool in result.tools] + log.debug("mcp.client.tools_listed", { + "server": self.name, + "count": len(tools), + }) + return tools + + if command.action == "call_tool": + tool_name = command.payload["name"] try: - await self._streams_context.__aexit__(None, None, None) - except Exception: - pass - self.session = None - self._streams = None - self._streams_context = None - self._connected = False - self._transport_type = None + result = await asyncio.wait_for( + session.call_tool(name=tool_name, arguments=command.payload["arguments"]), + timeout=self.timeout, + ) + except asyncio.TimeoutError as exc: + log.error("mcp.client.call_timeout", { + "server": self.name, + "tool": tool_name, + }) + from concurrent.futures import TimeoutError as _FuturesTimeoutError + + raise _FuturesTimeoutError(f"MCP工具调用超时 ({self.timeout}s): {tool_name}") from exc + + log.debug("mcp.client.tool_called", { + "server": self.name, + "tool": tool_name, + }) + return result + + if command.action == "list_resources": + result = await asyncio.wait_for(session.list_resources(), timeout=self.timeout) + 
resources = [ + McpResource( + name=resource.name, + uri=resource.uri, + description=getattr(resource, "description", None), + mime_type=getattr(resource, "mimeType", None), + server=self.name, + ) + for resource in result.resources + ] + log.debug("mcp.client.resources_listed", { + "server": self.name, + "count": len(resources), + }) + return resources + + if command.action == "read_resource": + uri = command.payload["uri"] + result = await asyncio.wait_for(session.read_resource(uri=uri), timeout=self.timeout) + log.debug("mcp.client.resource_read", { + "server": self.name, + "uri": uri, + }) + return result + + raise ValueError(f"Unknown MCP command: {command.action}") + + async def _fail_pending_commands(self, error: Exception) -> None: + """Fail any queued commands when the owner task exits.""" + if self._command_queue is None: + return + + while True: + try: + command = self._command_queue.get_nowait() + except asyncio.QueueEmpty: + return + + if command.response is not None and not command.response.done(): + command.response.set_exception(error) @staticmethod def _read_stderr(stderr_file) -> str: @@ -325,7 +535,13 @@ def _flocks_mcp_prefix() -> str: prefix.mkdir(parents=True, exist_ok=True) return str(prefix) - async def _connect_local(self) -> None: + @asynccontextmanager + async def _create_stdio_streams(self, server_params: StdioServerParameters, stderr_file): + """Create stdio transport streams.""" + async with stdio_client(server_params, errlog=stderr_file) as streams: + yield streams + + async def _connect_local(self, startup_future: asyncio.Future[None]) -> None: """Connect to local server via Stdio transport.""" if not self.command: raise ValueError(f"No command specified for local server: {self.name}") @@ -374,90 +590,63 @@ async def _connect_local(self) -> None: "args": args, }) - stderr_file = tempfile.TemporaryFile(mode="w+") - try: - self._streams_context = stdio_client(server_params, errlog=stderr_file) - self._streams = await 
self._streams_context.__aenter__() - read_stream, write_stream = self._streams - - self.session = ClientSession(read_stream, write_stream) - await self.session.__aenter__() - init_result = await asyncio.wait_for( - self.session.initialize(), - timeout=self.timeout, - ) - - self._connected = True - self._transport_type = "stdio" - log.info("mcp.client.connected", { - "server": self.name, - "transport": "stdio", - "protocol_version": getattr(init_result, 'protocolVersion', 'unknown'), - "server_info": getattr(init_result, 'serverInfo', {}), - }) - except asyncio.TimeoutError: - stderr_output = self._read_stderr(stderr_file) - await self._cleanup_connection() - log.error("mcp.client.timeout", { - "server": self.name, - "transport": "stdio", - "stderr": stderr_output, - }) - detail = f"Connection timeout (stdio): {self.name}" - if stderr_output: - detail += f"\nServer stderr:\n{stderr_output}" - raise RuntimeError(detail) - except Exception as e: - stderr_output = self._read_stderr(stderr_file) - root_cause = _extract_root_cause(e) - await self._cleanup_connection() - log.error("mcp.client.stdio_failed", { - "server": self.name, - "error": root_cause, - "stderr": stderr_output, - }) - detail = f"Stdio connection failed: {self.name}: {root_cause}" - if stderr_output: - detail += f"\nServer stderr:\n{stderr_output}" - raise RuntimeError(detail) - finally: - stderr_file.close() + # Keep stderr capture lifetime explicit: the file must outlive the stdio + # transport context, but should close immediately once the attempt ends. 
+ with tempfile.TemporaryFile(mode="w+") as stderr_file: + try: + async with self._create_stdio_streams(server_params, stderr_file) as streams: + self._streams = streams + await self._run_connected_session("stdio", streams, startup_future) + except asyncio.TimeoutError as exc: + stderr_output = self._read_stderr(stderr_file) + log.error("mcp.client.timeout", { + "server": self.name, + "transport": "stdio", + "stderr": stderr_output, + }) + detail = f"Connection timeout (stdio): {self.name}" + if stderr_output: + detail += f"\nServer stderr:\n{stderr_output}" + raise RuntimeError(detail) from exc + except Exception as exc: + stderr_output = self._read_stderr(stderr_file) + root_cause = _extract_root_cause(exc) + log.error("mcp.client.stdio_failed", { + "server": self.name, + "error": root_cause, + "stderr": stderr_output, + }) + detail = f"Stdio connection failed: {self.name}: {root_cause}" + if stderr_output: + detail += f"\nServer stderr:\n{stderr_output}" + raise RuntimeError(detail) async def disconnect(self) -> None: """Disconnect from server""" - if not self._connected: + owner_task = self._owner_task + if owner_task is None and not self._connected: return - + try: - # Close session first - if self.session: - try: - await self.session.__aexit__(None, None, None) - except Exception as e: - log.warn("mcp.client.session_close_error", { - "server": self.name, - "error": str(e) - }) - - # Then close streams - if self._streams_context: - try: - await self._streams_context.__aexit__(None, None, None) - except Exception as e: - log.warn("mcp.client.streams_close_error", { - "server": self.name, - "error": str(e) - }) - except Exception as e: + if owner_task is not None and not owner_task.done() and self._command_queue is not None: + response = asyncio.get_running_loop().create_future() + await self._command_queue.put(_ClientCommand(action="disconnect", response=response)) + await response + elif owner_task is not None and not owner_task.done(): + owner_task.cancel() + 
+ if owner_task is not None: + with contextlib.suppress(asyncio.CancelledError): + await owner_task + except Exception as exc: log.error("mcp.client.disconnect_error", { "server": self.name, - "error": str(e) + "error": str(exc), }) finally: - self._connected = False - self.session = None - self._streams = None - self._streams_context = None + self._reset_runtime_state(clear_owner_error=True) + if self._owner_task is owner_task: + self._owner_task = None log.info("mcp.client.disconnected", {"server": self.name}) async def list_tools(self) -> List[McpToolDef]: @@ -470,24 +659,13 @@ async def list_tools(self) -> List[McpToolDef]: Raises: RuntimeError: If not connected """ - if not self._connected or not self.session: - raise RuntimeError(f"Client not connected: {self.name}") - try: - result = await asyncio.wait_for( - self.session.list_tools(), - timeout=self.timeout - ) - tools = [McpToolDef.from_sdk(tool) for tool in result.tools] - log.debug("mcp.client.tools_listed", { - "server": self.name, - "count": len(tools) - }) - return tools - except Exception as e: + result = await self._submit_command("list_tools") + return result + except Exception as exc: log.error("mcp.client.list_tools_error", { "server": self.name, - "error": str(e) + "error": str(exc), }) raise @@ -505,31 +683,13 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: Raises: RuntimeError: If not connected """ - if not self._connected or not self.session: - raise RuntimeError(f"Client not connected: {self.name}") - try: - result = await asyncio.wait_for( - self.session.call_tool(name=name, arguments=arguments), - timeout=self.timeout - ) - log.debug("mcp.client.tool_called", { - "server": self.name, - "tool": name - }) - return result - except asyncio.TimeoutError: - log.error("mcp.client.call_timeout", { - "server": self.name, - "tool": name - }) - from concurrent.futures import TimeoutError as _FuturesTimeoutError - raise _FuturesTimeoutError(f"MCP工具调用超时 ({self.timeout}s): 
{name}") - except Exception as e: + return await self._submit_command("call_tool", name=name, arguments=arguments) + except Exception as exc: log.error("mcp.client.call_error", { "server": self.name, "tool": name, - "error": str(e) + "error": str(exc), }) raise @@ -543,32 +703,13 @@ async def list_resources(self) -> List[McpResource]: Raises: RuntimeError: If not connected """ - if not self._connected or not self.session: - raise RuntimeError(f"Client not connected: {self.name}") - try: - result = await asyncio.wait_for( - self.session.list_resources(), - timeout=self.timeout - ) - resources = [] - for r in result.resources: - resources.append(McpResource( - name=r.name, - uri=r.uri, - description=getattr(r, 'description', None), - mime_type=getattr(r, 'mimeType', None), - server=self.name - )) - log.debug("mcp.client.resources_listed", { - "server": self.name, - "count": len(resources) - }) - return resources - except Exception as e: + result = await self._submit_command("list_resources") + return result + except Exception as exc: log.error("mcp.client.list_resources_error", { "server": self.name, - "error": str(e) + "error": str(exc), }) raise @@ -585,26 +726,42 @@ async def read_resource(self, uri: str) -> Any: Raises: RuntimeError: If not connected """ - if not self._connected or not self.session: - raise RuntimeError(f"Client not connected: {self.name}") - try: - result = await asyncio.wait_for( - self.session.read_resource(uri=uri), - timeout=self.timeout - ) - log.debug("mcp.client.resource_read", { - "server": self.name, - "uri": uri - }) - return result - except Exception as e: + return await self._submit_command("read_resource", uri=uri) + except Exception as exc: log.error("mcp.client.read_resource_error", { "server": self.name, "uri": uri, - "error": str(e) + "error": str(exc), }) raise + + async def _submit_command(self, action: str, **payload: Any) -> Any: + """Send a serialized command to the owner task.""" + if not self._connected or 
self._command_queue is None: + if self._owner_error is not None: + raise RuntimeError( + f"Client not connected: {self.name}: {_extract_root_cause(self._owner_error)}" + ) from self._owner_error + raise RuntimeError(f"Client not connected: {self.name}") + + owner_task = self._owner_task + if owner_task is None: + if self._owner_error is not None: + raise RuntimeError( + f"Client not connected: {self.name}: {_extract_root_cause(self._owner_error)}" + ) from self._owner_error + raise RuntimeError(f"Client not connected: {self.name}") + + response = asyncio.get_running_loop().create_future() + command = _ClientCommand(action=action, payload=payload, response=response) + await self._command_queue.put(command) + + if owner_task.done() and not response.done(): + owner_error = self._owner_error or RuntimeError(f"Client not connected: {self.name}") + response.set_exception(owner_error) + + return await response @property def is_connected(self) -> bool: diff --git a/flocks/memory/sync/indexer.py b/flocks/memory/sync/indexer.py index 6985c39d..614ba817 100644 --- a/flocks/memory/sync/indexer.py +++ b/flocks/memory/sync/indexer.py @@ -222,7 +222,7 @@ async def _get_indexed_files(self) -> Dict[str, Dict[str, Any]]: indexed = {} try: - async with aiosqlite.connect(Storage.get_db_path()) as db: + async with Storage.connect(Storage.get_db_path()) as db: cursor = await db.execute(""" SELECT path, hash, mtime, size FROM memory_files @@ -449,7 +449,7 @@ async def _delete_file_chunks(self, path: str) -> None: import aiosqlite try: - async with aiosqlite.connect(Storage.get_db_path()) as db: + async with Storage.connect(Storage.get_db_path()) as db: await db.execute( "DELETE FROM memory_chunks WHERE project_id = ? 
AND path = ?", (self.project_id, path), @@ -465,7 +465,7 @@ async def _update_file_entry(self, file_entry: MemoryFileEntry) -> None: now = datetime.now().timestamp() try: - async with aiosqlite.connect(Storage.get_db_path()) as db: + async with Storage.connect(Storage.get_db_path()) as db: await db.execute(""" INSERT OR REPLACE INTO memory_files (path, project_id, source, hash, mtime, size, indexed_at) @@ -496,7 +496,7 @@ async def _clean_deleted_files(self, current_files: List[str]) -> int: import aiosqlite try: - async with aiosqlite.connect(Storage.get_db_path()) as db: + async with Storage.connect(Storage.get_db_path()) as db: cursor = await db.execute(""" SELECT path FROM memory_files WHERE project_id = ? """, (self.project_id,)) diff --git a/flocks/provider/usage_service.py b/flocks/provider/usage_service.py index a04e1b60..6679db64 100644 --- a/flocks/provider/usage_service.py +++ b/flocks/provider/usage_service.py @@ -270,7 +270,7 @@ async def _get_existing_usage_record( async def usage_record_exists(*, session_id: str, message_id: str) -> bool: """Check whether a usage row already exists for a message.""" await Storage._ensure_init() - async with aiosqlite.connect(Storage._db_path) as db: + async with Storage.connect(Storage._db_path) as db: async with db.execute( "SELECT 1 FROM usage_records WHERE session_id = ? AND message_id = ? LIMIT 1", (session_id, message_id), @@ -282,7 +282,7 @@ async def usage_record_exists(*, session_id: str, message_id: str) -> bool: async def _get_recorded_message_ids(session_id: str) -> set[str]: """Return all assistant message ids already present in usage_records.""" await Storage._ensure_init() - async with aiosqlite.connect(Storage._db_path) as db: + async with Storage.connect(Storage._db_path) as db: async with db.execute( "SELECT message_id FROM usage_records WHERE session_id = ? 
AND message_id IS NOT NULL", (session_id,), @@ -334,7 +334,7 @@ async def record_usage(req: RecordUsageRequest) -> UsageRecord: total_cost = 0.0 currency = "USD" - async with aiosqlite.connect(Storage._db_path) as db: + async with Storage.connect(Storage._db_path) as db: db.row_factory = aiosqlite.Row existing = await _get_existing_usage_record( db, @@ -419,7 +419,7 @@ async def get_usage_records( model_id=model_id, session_ids=session_ids, ) - async with aiosqlite.connect(Storage._db_path) as db: + async with Storage.connect(Storage._db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( f"""SELECT id, provider_id, model_id, credential_id, session_id, message_id, @@ -453,7 +453,7 @@ async def get_usage_stats( session_ids=session_ids, ) - async with aiosqlite.connect(Storage._db_path) as db: + async with Storage.connect(Storage._db_path) as db: db.row_factory = aiosqlite.Row async with db.execute( diff --git a/flocks/server/app.py b/flocks/server/app.py index a3edd361..7387c9b2 100644 --- a/flocks/server/app.py +++ b/flocks/server/app.py @@ -5,10 +5,12 @@ """ import asyncio +import inspect import os import time +from types import SimpleNamespace from pathlib import Path -from typing import Optional +from typing import Any, Callable, Optional from contextlib import asynccontextmanager from fastapi import FastAPI, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -43,6 +45,60 @@ # Lifespan context manager for startup/shutdown +async def _maybe_await(result: Any) -> Any: + """Await values that are awaitable and return plain values unchanged.""" + if inspect.isawaitable(result): + return await result + return result + + +async def _run_startup_phase( + log, + phase: str, + fn: Callable[[], Any], +) -> Any: + """Execute one startup phase and emit structured timing logs.""" + started_at = time.perf_counter() + try: + result = await _maybe_await(fn()) + except Exception as exc: + duration_ms = int((time.perf_counter() - 
started_at) * 1000) + log.warning("server.startup.phase", { + "phase": phase, + "status": "failed", + "duration_ms": duration_ms, + "error": str(exc), + }) + raise + + duration_ms = int((time.perf_counter() - started_at) * 1000) + log.info("server.startup.phase", { + "phase": phase, + "status": "completed", + "duration_ms": duration_ms, + }) + return result + + +def _schedule_startup_phase( + app: FastAPI, + log, + phase: str, + fn: Callable[[], Any], +) -> None: + """Run a non-critical startup phase in the background after app is ready.""" + + async def _runner() -> None: + try: + await _run_startup_phase(log, phase, fn) + except Exception: + # _run_startup_phase already logged the failure. + return + + task = asyncio.create_task(_runner(), name=f"startup:{phase}") + app.state.startup_background_tasks.append(task) + + @asynccontextmanager async def lifespan(app: FastAPI): """Handle application lifecycle""" @@ -51,13 +107,21 @@ async def lifespan(app: FastAPI): await Log.init(print=False, dev=False, level=LogLevel.INFO) log = Log.create(service="server") + if not hasattr(app, "state") or app.state is None: + app.state = SimpleNamespace() + app.state.startup_background_tasks = [] + startup_started_at = time.perf_counter() # Startup log.info("server.startup", {"version": "0.2.0"}) try: from flocks.updater.updater import cleanup_replaced_files - await asyncio.to_thread(cleanup_replaced_files) + await _run_startup_phase( + log, + "updater.cleanup_leftovers", + lambda: asyncio.to_thread(cleanup_replaced_files), + ) log.info("updater.leftovers.cleaned") except Exception as e: log.warning("updater.leftovers.cleanup_failed", {"error": str(e)}) @@ -65,13 +129,21 @@ async def lifespan(app: FastAPI): try: from flocks.updater.updater import _get_repo_root, _refresh_global_cli_entry - await asyncio.to_thread(_refresh_global_cli_entry, _get_repo_root()) + await _run_startup_phase( + log, + "cli.refresh_global_entry", + lambda: asyncio.to_thread(_refresh_global_cli_entry, 
_get_repo_root()), + ) log.info("cli.global_entry.refreshed") except Exception as e: log.warning("cli.global_entry.refresh_failed", {"error": str(e)}) try: - init_observability() + await _run_startup_phase( + log, + "observability.init", + init_observability, + ) log.info("observability.initialized") except Exception as e: log.warning("observability.init_failed", {"error": str(e)}) @@ -79,7 +151,11 @@ async def lifespan(app: FastAPI): # Ensure config files exist (copy from examples if needed) try: from flocks.config.config_writer import ensure_config_files - ensure_config_files() + await _run_startup_phase( + log, + "config.ensure_files", + ensure_config_files, + ) log.info("config.files.checked") except Exception as e: log.warning("config.files.check_failed", {"error": str(e)}) @@ -89,7 +165,11 @@ async def lifespan(app: FastAPI): # ``_v`` once the plugin declares a version. try: from flocks.config.api_versioning import migrate_api_services - actions = migrate_api_services() + actions = await _run_startup_phase( + log, + "config.migrate_api_services", + migrate_api_services, + ) copied = [k for k, v in actions.items() if v == "copied"] if copied: log.info("config.api_services.migrated", {"copied": copied}) @@ -97,31 +177,42 @@ async def lifespan(app: FastAPI): log.warning("config.api_services.migrate_failed", {"error": str(e)}) # Initialize storage - await Storage.init() + await _run_startup_phase(log, "storage.init", Storage.init) log.info("storage.initialized") # Initialize local auth/account tables - await AuthService.init() + await _run_startup_phase(log, "auth.init", AuthService.init) log.info("auth.initialized") # Best-effort migration: old sessions default to admin ownership. # The migration itself is idempotent (guarded by a persisted marker), # but we still skip loading users when the marker is already set # to avoid unnecessary DB + session scans on every startup. 
- try: + async def _migrate_legacy_sessions_to_admin() -> None: marker = await Storage.get("auth:migration:legacy_session_owner_to_admin", dict) - if not (marker and marker.get("done")): - if await AuthService.has_users(): - users = await AuthService.list_users() - admin = next((u for u in users if u.role == "admin"), None) - if admin: - await AuthService.migrate_legacy_sessions_to_admin(admin.id) - except Exception as e: - log.warning("auth.legacy_sessions.migration_failed", {"error": str(e)}) + if marker and marker.get("done"): + return + if not await AuthService.has_users(): + return + users = await AuthService.list_users() + admin = next((u for u in users if u.role == "admin"), None) + if admin: + await AuthService.migrate_legacy_sessions_to_admin(admin.id) + + _schedule_startup_phase( + app, + log, + "auth.migrate_legacy_session_owner", + _migrate_legacy_sessions_to_admin, + ) # Setup question handler for real user interaction from flocks.tool.question_handler import setup_api_question_handler - setup_api_question_handler() + await _run_startup_phase( + log, + "question_handler.setup", + setup_api_question_handler, + ) log.info("question_handler.initialized") # Register built-in hooks if memory is enabled @@ -129,7 +220,11 @@ async def lifespan(app: FastAPI): config = await Config.get() if config.memory.enabled: from flocks.hooks.builtin import register_builtin_hooks - register_builtin_hooks() + await _run_startup_phase( + log, + "hooks.register_builtin", + register_builtin_hooks, + ) log.info("hooks.registered") except Exception as e: # Hook registration failure should not stop server startup @@ -138,25 +233,47 @@ async def lifespan(app: FastAPI): # Migrate env-var credentials to .secret.json (idempotent) try: from flocks.provider.credential import migrate_env_credentials - migrated = migrate_env_credentials() - if migrated > 0: - log.info("credential.env_migration.done", {"migrated": migrated}) + + def _migrate_env_credentials_phase() -> None: + migrated = 
migrate_env_credentials() + if migrated > 0: + log.info("credential.env_migration.done", {"migrated": migrated}) + + _schedule_startup_phase( + app, + log, + "credential.migrate_env_credentials", + _migrate_env_credentials_phase, + ) except Exception as e: log.warning("credential.env_migration.failed", {"error": str(e)}) # Sync new catalog models into flocks.json for existing providers (idempotent) try: from flocks.provider.model_catalog import sync_catalog_models_to_config - synced = sync_catalog_models_to_config() - if synced > 0: - log.info("catalog.model_sync.done", {"models_added": synced}) + + def _sync_catalog_models_phase() -> None: + synced = sync_catalog_models_to_config() + if synced > 0: + log.info("catalog.model_sync.done", {"models_added": synced}) + + _schedule_startup_phase( + app, + log, + "catalog.sync_models_to_config", + _sync_catalog_models_phase, + ) except Exception as e: log.warning("catalog.model_sync.failed", {"error": str(e)}) # Load custom providers from flocks.json into runtime try: from flocks.server.routes.custom_provider import load_custom_providers_on_startup - await load_custom_providers_on_startup() + await _run_startup_phase( + log, + "custom_providers.load", + load_custom_providers_on_startup, + ) log.info("custom_providers.loaded") except Exception as e: log.warning("custom_providers.load.failed", {"error": str(e)}) @@ -165,23 +282,31 @@ async def lifespan(app: FastAPI): # after a service restart, without requiring manual UI reconnection. 
try: from flocks.mcp import MCP - await MCP.init() - log.info("mcp.initialized") + + _schedule_startup_phase(app, log, "mcp.init", MCP.init) except Exception as e: log.warning("mcp.init_failed", {"error": str(e)}) # Sync workflows from .flocks/workflow/ filesystem into Storage try: from flocks.server.routes.workflow import sync_workflows_from_filesystem - imported = await sync_workflows_from_filesystem() - log.info("workflow.sync.done", {"imported": imported}) + + async def _sync_workflows_phase() -> None: + imported = await sync_workflows_from_filesystem() + log.info("workflow.sync.done", {"imported": imported}) + + _schedule_startup_phase(app, log, "workflow.sync_filesystem", _sync_workflows_phase) except Exception as e: log.warning("workflow.sync.failed", {"error": str(e)}) # Start Task Center (scheduler + queue executor) try: from flocks.task.manager import TaskManager - await TaskManager.start() + await _run_startup_phase( + log, + "task_manager.start", + TaskManager.start, + ) log.info("task_manager.started") except Exception as e: from flocks.task.manager import TaskManager @@ -191,55 +316,93 @@ async def lifespan(app: FastAPI): # Seed built-in scheduled tasks from .flocks/plugins/tasks/*.json (idempotent) try: from flocks.task.plugin import seed_tasks_from_plugin - seeded = await seed_tasks_from_plugin() - if seeded: - log.info("task.plugin.seeded", {"count": seeded}) + + async def _seed_tasks_phase() -> None: + seeded = await seed_tasks_from_plugin() + if seeded: + log.info("task.plugin.seeded", {"count": seeded}) + + _schedule_startup_phase(app, log, "task.seed_plugin_specs", _seed_tasks_phase) except Exception as e: log.warning("task.plugin.seed_failed", {"error": str(e)}) # Start Skill file watcher (auto-invalidate cache on SKILL.md changes) try: from flocks.skill.skill import Skill - Skill.start_watcher() - log.info("skill.watcher.initialized") + + def _start_skill_watcher() -> None: + Skill.start_watcher() + log.info("skill.watcher.initialized") + + 
_schedule_startup_phase(app, log, "skill.watcher.start", _start_skill_watcher) except Exception as e: log.warning("skill.watcher.init_failed", {"error": str(e)}) # Start Agent file watcher (auto-invalidate cache on plugin agent changes) try: from flocks.agent.registry import Agent - Agent.start_watcher() - log.info("agent.watcher.initialized") + + def _start_agent_watcher() -> None: + Agent.start_watcher() + log.info("agent.watcher.initialized") + + _schedule_startup_phase(app, log, "agent.watcher.start", _start_agent_watcher) except Exception as e: log.warning("agent.watcher.init_failed", {"error": str(e)}) # Start Tool file watcher (auto-reload plugin tools on file changes) try: from flocks.tool.registry import ToolRegistry - ToolRegistry.start_watcher() - log.info("tool.watcher.initialized") + + def _start_tool_watcher() -> None: + ToolRegistry.start_watcher() + log.info("tool.watcher.initialized") + + _schedule_startup_phase(app, log, "tool.watcher.start", _start_tool_watcher) except Exception as e: log.warning("tool.watcher.init_failed", {"error": str(e)}) # Start Channel Gateway (connect enabled IM channels) try: from flocks.channel.gateway.manager import default_manager - await default_manager.start_all() - log.info("channel.gateway.started") + + async def _start_channel_gateway() -> None: + await default_manager.start_all() + log.info("channel.gateway.started") + + _schedule_startup_phase(app, log, "channel.gateway.start", _start_channel_gateway) except Exception as e: log.warning("channel.gateway.start_failed", {"error": str(e)}) try: from flocks.updater.updater import recover_upgrade_state - await asyncio.to_thread(recover_upgrade_state) + await _run_startup_phase( + log, + "updater.recover_upgrade_state", + lambda: asyncio.to_thread(recover_upgrade_state), + ) log.info("updater.recovery.checked") except Exception as e: log.warning("updater.recovery.failed", {"error": str(e)}) + blocking_startup_ms = int((time.perf_counter() - startup_started_at) * 1000) 
+ log.info("server.startup.ready", { + "blocking_duration_ms": blocking_startup_ms, + "background_tasks": len(app.state.startup_background_tasks), + }) + yield - # --- Graceful shutdown: notify SSE clients FIRST --- + background_tasks = list(getattr(app.state, "startup_background_tasks", [])) + for task in background_tasks: + if not task.done(): + task.cancel() + if background_tasks: + await asyncio.gather(*background_tasks, return_exceptions=True) + + # Notify SSE clients before stopping sessions, MCP transports, and other + # long-lived runtime services so browser listeners see the shutdown event. try: from flocks.server.routes.event import EventBroadcaster broadcaster = EventBroadcaster.get() diff --git a/flocks/server/routes/_timing.py b/flocks/server/routes/_timing.py new file mode 100644 index 00000000..f0c8e883 --- /dev/null +++ b/flocks/server/routes/_timing.py @@ -0,0 +1,28 @@ +"""Helpers for route timing logs.""" + +from __future__ import annotations + +import time +from typing import Any + +from flocks.utils.log import Logger + +DEFAULT_SLOW_ROUTE_LOG_THRESHOLD_MS = 300 + + +def log_route_timing( + logger: Logger, + event: str, + *, + started_at: float, + extra: dict[str, Any] | None = None, + slow_threshold_ms: int = DEFAULT_SLOW_ROUTE_LOG_THRESHOLD_MS, +) -> int: + """Log route timings at INFO only when a request is slow enough.""" + duration_ms = int((time.perf_counter() - started_at) * 1000) + payload = {"duration_ms": duration_ms, **(extra or {})} + if duration_ms >= slow_threshold_ms: + logger.info(event, payload) + else: + logger.debug(event, payload) + return duration_ms diff --git a/flocks/server/routes/channel.py b/flocks/server/routes/channel.py index fa456a1d..4ee07d1a 100644 --- a/flocks/server/routes/channel.py +++ b/flocks/server/routes/channel.py @@ -300,6 +300,52 @@ async def _do(): return {"ok": True} +# --------------------------------------------------------------------------- +# Weixin QR login +# 
--------------------------------------------------------------------------- + +class WeixinQrStartRequest(BaseModel): + baseUrl: Optional[str] = None + + +@router.post("/weixin/qr-login/start") +async def weixin_qr_login_start(req: WeixinQrStartRequest): + """Request a fresh iLink Bot QR code for WeChat account login. + + No credentials needed — this is the pre-authentication step. + Returns ``{qrcode_value, qrcode_url}`` for the frontend to render. + """ + from flocks.channel.builtin.weixin.config import ILINK_BASE_URL + from flocks.channel.builtin.weixin.qr_login import start_qr_login + + base_url = (req.baseUrl or "").strip() or ILINK_BASE_URL + try: + result = await start_qr_login(base_url=base_url) + return {"ok": True, **result} + except Exception as exc: + log.error("weixin.qr_login.start_failed", {"error": str(exc)}) + raise HTTPException(status_code=502, detail=str(exc)) + + +@router.get("/weixin/qr-login/status") +async def weixin_qr_login_status(qrcode: str, baseUrl: Optional[str] = None): + """Poll the QR code scan status once. + + Returns ``{status}`` where status ∈ waiting | scaned | expired | confirmed. + On ``confirmed`` also returns ``{account_id, token}``. 
+ """ + from flocks.channel.builtin.weixin.config import ILINK_BASE_URL + from flocks.channel.builtin.weixin.qr_login import poll_qr_status + + base_url = (baseUrl or "").strip() or ILINK_BASE_URL + try: + result = await poll_qr_status(qrcode_value=qrcode, base_url=base_url) + return {"ok": True, **result} + except Exception as exc: + log.error("weixin.qr_login.poll_failed", {"error": str(exc)}) + raise HTTPException(status_code=502, detail=str(exc)) + + # --------------------------------------------------------------------------- # Telegram pairing # --------------------------------------------------------------------------- diff --git a/flocks/server/routes/session.py b/flocks/server/routes/session.py index db40d397..e2d04c20 100644 --- a/flocks/server/routes/session.py +++ b/flocks/server/routes/session.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field, ConfigDict from flocks.auth.context import get_current_auth_user +from flocks.server.routes._timing import log_route_timing from flocks.session.session import Session, SessionInfo as SessionModel from flocks.session.policy import SessionPolicy from flocks.utils.log import Log @@ -206,6 +207,7 @@ async def list_sessions( category: Optional[str] = Query(None, description="Filter by category: user or task"), ) -> List[SessionResponse]: """List all sessions with optional filters""" + started_at = time.perf_counter() _current_user = require_user(request) all_sessions = await Session.list_all() @@ -235,7 +237,15 @@ async def list_sessions( if limit is not None and len(filtered) >= limit: break - return [_session_to_response(s) for s in filtered] + response = [_session_to_response(s) for s in filtered] + log_route_timing(log, "session.list.complete", started_at=started_at, extra={ + "count": len(response), + "roots": roots, + "limit": limit, + "search": bool(search), + "category": category, + }) + return response @router.post( diff --git a/flocks/server/routes/task_entities.py 
b/flocks/server/routes/task_entities.py index 62e01abb..69170c4d 100644 --- a/flocks/server/routes/task_entities.py +++ b/flocks/server/routes/task_entities.py @@ -1,13 +1,17 @@ """Execution-centric task scheduler/execution routes.""" from enum import Enum +import time from typing import List, Optional, Type from fastapi import APIRouter, HTTPException, Query, status from pydantic import BaseModel, ConfigDict, Field +from flocks.server.routes._timing import log_route_timing +from flocks.utils.log import Log router = APIRouter() +log = Log.create(service="task-routes") class SchedulerCreateRequest(BaseModel): @@ -143,21 +147,40 @@ def _parse_task_type(task_type: str) -> str: async def get_task_system_notice(): from flocks.task.manager import TaskManager - return await TaskManager.get_task_page_notice() + started_at = time.perf_counter() + notice = await TaskManager.get_task_page_notice() + log_route_timing(log, "task.notice.complete", started_at=started_at, extra={ + "has_notice": bool(notice), + }) + return notice @router.get("/task-system/dashboard") async def task_dashboard(): from flocks.task.manager import TaskManager - return await TaskManager.dashboard() + started_at = time.perf_counter() + payload = await TaskManager.dashboard() + log_route_timing(log, "task.dashboard.complete", started_at=started_at, extra={ + "running": payload.get("running"), + "queued": payload.get("queued"), + "scheduled_active": payload.get("scheduled_active"), + }) + return payload @router.get("/task-system/queue/status") async def task_queue_status(): from flocks.task.manager import TaskManager - return await TaskManager.queue_status() + started_at = time.perf_counter() + payload = await TaskManager.queue_status() + log_route_timing(log, "task.queue_status.complete", started_at=started_at, extra={ + "queued": payload.get("queued"), + "running": payload.get("running"), + "paused": payload.get("paused"), + }) + return payload @router.post("/task-system/queue/pause") diff --git 
a/flocks/server/routes/tool.py b/flocks/server/routes/tool.py index 6fcc5f30..dc1a0f19 100644 --- a/flocks/server/routes/tool.py +++ b/flocks/server/routes/tool.py @@ -3,11 +3,13 @@ """ import asyncio +import time from typing import List, Optional, Dict, Any from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field from flocks.server.auth import require_admin +from flocks.server.routes._timing import log_route_timing from flocks.utils.log import Log from flocks.config.config_writer import ConfigWriter from flocks.permission.next import DeniedError, PermissionNext @@ -439,6 +441,7 @@ async def list_tools( List of tool information """ # Initialize registry if needed + started_at = time.perf_counter() ToolRegistry.init() # Parse category filter @@ -458,7 +461,12 @@ async def list_tools( # Apply source filter if specified if source: result = [t for t in result if t.source == source] - + + log_route_timing(log, "tools.list.complete", started_at=started_at, extra={ + "count": len(result), + "category": category, + "source": source, + }) return result @@ -764,6 +772,7 @@ async def refresh_tools(_admin: object = Depends(require_admin)): This is the batch counterpart to the single-tool ``/{name}/reload`` endpoint. 
""" + started_at = time.perf_counter() ToolRegistry.init() errors: list[str] = [] @@ -783,7 +792,10 @@ async def refresh_tools(_admin: object = Depends(require_admin)): errors.append(f"plugin: {e}") tool_count = len(ToolRegistry.all_tool_ids()) - log.info("tools.refresh.done", {"tool_count": tool_count, "errors": len(errors)}) + log_route_timing(log, "tools.refresh.done", started_at=started_at, extra={ + "tool_count": tool_count, + "errors": len(errors), + }) if errors: return RefreshResponse( diff --git a/flocks/server/routes/workflow.py b/flocks/server/routes/workflow.py index a20875fe..d35d7cfe 100644 --- a/flocks/server/routes/workflow.py +++ b/flocks/server/routes/workflow.py @@ -364,6 +364,17 @@ def _list_workflows_from_fs() -> List[Dict[str, Any]]: return list(by_id.values()) +async def sync_workflows_from_filesystem() -> int: + """Best-effort startup sync for filesystem-backed workflows. + + The filesystem is the source of truth for workflow definitions. Startup only + needs to migrate any legacy Storage-only records to disk and report how many + workflows are currently discoverable from the configured workflow roots. + """ + await _migrate_storage_to_filesystem() + return len(_list_workflows_from_fs()) + + async def _migrate_storage_to_filesystem() -> None: """One-time migration: move Storage-only workflow definitions to the filesystem. 
diff --git a/flocks/storage/storage.py b/flocks/storage/storage.py index 94bb5c37..6a0662b1 100644 --- a/flocks/storage/storage.py +++ b/flocks/storage/storage.py @@ -4,8 +4,10 @@ Provides SQLite-based storage similar to Flocks's Storage namespace """ +from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar +import sqlite3 +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Type, TypeVar import json import aiosqlite from datetime import datetime @@ -43,6 +45,9 @@ class Storage: _db_path: Optional[Path] = None _initialized = False _extension_ddls: List[str] = [] + _sqlite_timeout_s = 5.0 + _sqlite_busy_timeout_ms = 5000 + _sqlite_journal_mode = "WAL" @classmethod def _invalidate_runtime_caches(cls) -> None: @@ -71,6 +76,48 @@ def get_db_path(cls) -> Path: data_dir = Config.get_data_path() return data_dir / "flocks.db" + @classmethod + async def configure_connection( + cls, conn: aiosqlite.Connection + ) -> aiosqlite.Connection: + """Apply the runtime SQLite contract to an async connection.""" + await conn.execute(f"PRAGMA journal_mode={cls._sqlite_journal_mode}") + await conn.execute(f"PRAGMA busy_timeout={cls._sqlite_busy_timeout_ms}") + await conn.execute("PRAGMA foreign_keys = ON") + return conn + + @classmethod + def configure_sync_connection(cls, conn: sqlite3.Connection) -> sqlite3.Connection: + """Apply the runtime SQLite contract to a sync connection.""" + conn.execute(f"PRAGMA journal_mode={cls._sqlite_journal_mode}") + conn.execute(f"PRAGMA busy_timeout={cls._sqlite_busy_timeout_ms}") + conn.execute("PRAGMA foreign_keys = ON") + return conn + + @classmethod + @asynccontextmanager + async def connect( + cls, db_path: Optional[Path] = None + ) -> AsyncIterator[aiosqlite.Connection]: + """Open a configured async SQLite connection for the active storage DB.""" + target = Path(db_path) if db_path is not None else cls.get_db_path() + target.parent.mkdir(parents=True, 
exist_ok=True) + conn = await aiosqlite.connect(target, timeout=cls._sqlite_timeout_s) + try: + await cls.configure_connection(conn) + yield conn + finally: + await conn.close() + + @classmethod + def connect_sync(cls, db_path: Optional[Path] = None) -> sqlite3.Connection: + """Open a configured sync SQLite connection for the active storage DB.""" + target = Path(db_path) if db_path is not None else cls.get_db_path() + target.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(target, timeout=cls._sqlite_timeout_s) + conn.row_factory = sqlite3.Row + return cls.configure_sync_connection(conn) + @classmethod def register_ddl(cls, ddl: str) -> None: """Register an extension DDL script to be executed during ``init()``. @@ -124,7 +171,7 @@ async def init(cls, db_path: Optional[Path] = None) -> None: cls._invalidate_runtime_caches() # Create tables - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: await db.execute(""" CREATE TABLE IF NOT EXISTS storage ( key TEXT PRIMARY KEY, @@ -150,8 +197,9 @@ async def init(cls, db_path: Optional[Path] = None) -> None: # Run extension DDLs registered before init for ddl in cls._extension_ddls: try: - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: await db.executescript(ddl) + await db.commit() except Exception as e: cls._log.warn("storage.extension_ddl.failed", {"error": str(e)}) @@ -166,7 +214,7 @@ async def _create_model_management_tables(cls) -> None: (credentials, model settings, default models, custom providers) is stored in flocks.json / .secret.json. 
""" - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: await db.executescript(""" -- Usage records (dynamic data — the only model-management table in SQLite) CREATE TABLE IF NOT EXISTS usage_records ( @@ -249,7 +297,7 @@ async def set(cls, key: str, value: Any, value_type: str = "json") -> None: from datetime import UTC now = datetime.now(UTC).isoformat() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: await db.execute(""" INSERT OR REPLACE INTO storage (key, value, type, created_at, updated_at) VALUES (?, ?, ?, @@ -274,7 +322,7 @@ async def get(cls, key: str, model: Optional[Type[T]] = None) -> Optional[T | An """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: async with db.execute( "SELECT value, type FROM storage WHERE key = ?", (key,) ) as cursor: @@ -303,7 +351,7 @@ async def delete(cls, key: str) -> bool: """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: cursor = await db.execute("DELETE FROM storage WHERE key = ?", (key,)) await db.commit() deleted = cursor.rowcount > 0 @@ -326,7 +374,7 @@ async def list_keys(cls, prefix: Optional[str] = None) -> List[str]: """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: if prefix: query = "SELECT key FROM storage WHERE key LIKE ?" params = (f"{prefix}%",) @@ -360,7 +408,7 @@ async def list_entries( """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: if prefix: query = "SELECT key, value FROM storage WHERE key LIKE ?" 
params = (f"{prefix}%",) @@ -393,7 +441,7 @@ async def exists(cls, key: str) -> bool: """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: async with db.execute( "SELECT 1 FROM storage WHERE key = ?", (key,) ) as cursor: @@ -414,7 +462,7 @@ async def clear(cls, prefix: Optional[str] = None) -> int: """ await cls._ensure_init() - async with aiosqlite.connect(cls._db_path) as db: + async with cls.connect(cls._db_path) as db: if prefix: query = "DELETE FROM storage WHERE key LIKE ?" params = (f"{prefix}%",) diff --git a/flocks/storage/vector.py b/flocks/storage/vector.py index 35bb4118..fde2f51a 100644 --- a/flocks/storage/vector.py +++ b/flocks/storage/vector.py @@ -12,6 +12,7 @@ import math from datetime import datetime +from flocks.storage.storage import Storage from flocks.utils.log import Log log = Log.create(service="storage.vector") @@ -99,7 +100,7 @@ async def ensure_vector_tables(db_path: Path) -> Dict[str, Any]: } try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: # Create vector tables await db.executescript(VECTOR_SCHEMA_SQL) await db.commit() @@ -188,7 +189,7 @@ async def vector_search( results = [] try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: # Build query query = """ SELECT id, path, source, start_line, end_line, text, embedding @@ -298,7 +299,7 @@ async def fts_search( results = [] try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: # Build FTS query fts_query = build_fts_query(query) if not fts_query: @@ -370,7 +371,7 @@ async def insert_chunks( Number of chunks inserted """ try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: now = datetime.now().timestamp() # Insert into chunks table @@ -442,7 +443,7 @@ async def get_embedding_from_cache( Tuple of (embedding, dims) or None if not found """ 
try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: cursor = await db.execute(""" SELECT embedding, dims FROM memory_embedding_cache @@ -483,7 +484,7 @@ async def put_embedding_to_cache( Put embedding to cache """ try: - async with aiosqlite.connect(db_path) as db: + async with Storage.connect(db_path) as db: now = datetime.now().timestamp() await db.execute(""" INSERT OR REPLACE INTO memory_embedding_cache diff --git a/flocks/task/manager.py b/flocks/task/manager.py index f5dc7740..dc0d8e4b 100644 --- a/flocks/task/manager.py +++ b/flocks/task/manager.py @@ -813,9 +813,7 @@ def _clear_migration_state(cls) -> None: @staticmethod def _with_db_connection() -> sqlite3.Connection: - conn = sqlite3.connect(Storage.get_db_path()) - conn.row_factory = sqlite3.Row - return conn + return Storage.connect_sync() @classmethod def _legacy_tables_exist(cls) -> bool: diff --git a/flocks/task/store.py b/flocks/task/store.py index 6441c734..32d1a1d4 100644 --- a/flocks/task/store.py +++ b/flocks/task/store.py @@ -32,8 +32,11 @@ async def init(cls) -> None: if cls._initialized: return await Storage._ensure_init() - cls._conn = await aiosqlite.connect(Storage._db_path) - await cls._conn.execute("PRAGMA foreign_keys = ON") + cls._conn = await aiosqlite.connect( + Storage._db_path, + timeout=Storage._sqlite_timeout_s, + ) + await Storage.configure_connection(cls._conn) await cls._conn.executescript(_TASKS_DDL) for stmt in _INDEX_STMTS: await cls._conn.execute(stmt) diff --git a/flocks/tool/channel/channel_message.py b/flocks/tool/channel/channel_message.py index 7b591d8a..30b2ec79 100644 --- a/flocks/tool/channel/channel_message.py +++ b/flocks/tool/channel/channel_message.py @@ -38,6 +38,23 @@ def _normalize_channel_type(channel_type: str | None) -> str | None: return lower +def _get_api_token() -> str | None: + """Read the server API token from the secret manager (non-async, best-effort). 
+ + Reuses ``API_TOKEN_SECRET_ID`` from ``flocks.server.auth`` so that the + secret id stays in lockstep with what the server-side auth middleware + expects; if those drift apart the request will silently start failing + with 401. + """ + try: + from flocks.security import get_secret_manager + from flocks.server.auth import API_TOKEN_SECRET_ID + token = get_secret_manager().get(API_TOKEN_SECRET_ID) + return token.strip() if token and token.strip() else None + except Exception: + return None + + async def _http_session_send( port: int, session_id: str, @@ -60,10 +77,16 @@ async def _http_session_send( if media_url: payload["media_url"] = media_url + headers: dict[str, str] = {} + api_token = _get_api_token() + if api_token: + headers["Authorization"] = f"Bearer {api_token}" + async with httpx.AsyncClient() as client: resp = await client.post( f"http://localhost:{port}/api/channel/session-send", json=payload, + headers=headers, timeout=10.0, ) body = resp.json() @@ -76,6 +99,14 @@ async def _http_session_send( f"ids: {body.get('message_ids', [])}" ), ) + # 401 + we had no token to present: either the secret is unset + # or this process can't read it. Either way, the in-process + # path bypasses HTTP auth and can still deliver the message, + # so we fall back instead of surfacing an error. + # (If we DID send a token and it was rejected, fall through + # and report the server's detail so misconfiguration is visible.) 
+ if resp.status_code == 401 and not api_token: + return None return ToolResult( success=False, error=f"Send failed (HTTP {resp.status_code}): {body.get('detail', body)}", diff --git a/flocks/updater/updater.py b/flocks/updater/updater.py index 4703dab2..eab3a507 100644 --- a/flocks/updater/updater.py +++ b/flocks/updater/updater.py @@ -84,6 +84,15 @@ class UpdateMirrorProfile: pip_index_url: str | None = None +@dataclass(frozen=True) +class _FrontendNpmCandidate: + """A single npm launcher candidate for staged frontend rebuilds.""" + + npm: str + env: dict[str, str] | None + source: str + + # ------------------------------------------------------------------ # # Install root # ------------------------------------------------------------------ # @@ -349,13 +358,16 @@ def _build_uv_sync_env() -> dict[str, str] | None: return {"PATH": os.pathsep.join([current_path] + missing)} -def _build_frontend_subprocess_env(*, npm_registry: str | None = None) -> dict[str, str] | None: +def _build_frontend_subprocess_env_for_node_dir( + node_dir: Path | None, + *, + npm_registry: str | None = None, +) -> dict[str, str] | None: """Build supplemental env vars for frontend npm commands.""" env: dict[str, str] = {} if npm_registry: env["npm_config_registry"] = npm_registry - node_dir = _bundled_node_install_dir() if node_dir is not None: node_bin = str(node_dir if sys.platform == "win32" else node_dir / "bin") current_path = os.environ.get("PATH", "") @@ -366,6 +378,22 @@ def _build_frontend_subprocess_env(*, npm_registry: str | None = None) -> dict[s return env or None +def _build_frontend_subprocess_env(*, npm_registry: str | None = None) -> dict[str, str] | None: + """Build supplemental env vars for frontend npm commands.""" + return _build_frontend_subprocess_env_for_node_dir( + _bundled_node_install_dir(), + npm_registry=npm_registry, + ) + + +def _reset_staged_frontend_workspace(staged_webui_dir: Path) -> None: + """Remove transient frontend build artifacts before retrying 
another npm candidate.""" + for name in ("node_modules", "dist"): + target = staged_webui_dir / name + if target.exists(): + _safe_remove(target) + + def _dependency_sync_timeout_seconds() -> int: """Return the timeout budget for ``uv sync`` during self-update.""" if sys.platform == "win32": @@ -1974,71 +2002,93 @@ async def perform_update( # ------------------------------------------------------------------ # staged_webui_dir = content_root / "webui" if staged_webui_dir.is_dir() and (staged_webui_dir / "package.json").exists(): - npm = _resolve_npm_executable() - if npm: - yield UpdateProgress(stage="building", message="Installing frontend dependencies...") - npm_env = _build_frontend_subprocess_env(npm_registry=profile.npm_registry) + npm_candidates = _resolve_frontend_npm_candidates(npm_registry=profile.npm_registry) + if npm_candidates: install_subcommand = "ci" if (staged_webui_dir / "package-lock.json").exists() else "install" - install_cmd = [npm, install_subcommand] install_label = f"npm {install_subcommand}" - try: - code, _, err = await _run_async( - install_cmd, - cwd=staged_webui_dir, - timeout=_FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS, - env=npm_env, - ) - except subprocess.TimeoutExpired: - shutil.rmtree(tmp_dir, ignore_errors=True) - _fe_dep_timeout = ( - "Frontend dependency install timed out after " - f"{_FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS}s while running {install_label}." 
- ) - _record_update_journal(f"ERROR {_fe_dep_timeout}") - yield UpdateProgress( - stage="error", - message=_fe_dep_timeout, - success=False, - ) - return - if code != 0: - shutil.rmtree(tmp_dir, ignore_errors=True) - _fe_dep = f"Frontend dependency install failed ({install_label}): {err}" - _record_update_journal(f"ERROR {_fe_dep}") - yield UpdateProgress( - stage="error", - message=_fe_dep, - success=False, - ) - return + final_frontend_error: str | None = None - yield UpdateProgress(stage="building", message="Building frontend...") - try: - code, _, err = await _run_async( - [npm, "run", "build"], - cwd=staged_webui_dir, - timeout=_FRONTEND_BUILD_TIMEOUT_SECONDS, - env=npm_env, - ) - except subprocess.TimeoutExpired: - shutil.rmtree(tmp_dir, ignore_errors=True) - _fe_build_timeout = ( - f"Frontend build timed out after {_FRONTEND_BUILD_TIMEOUT_SECONDS}s while running npm run build." - ) - _record_update_journal(f"ERROR {_fe_build_timeout}") - yield UpdateProgress( - stage="error", - message=_fe_build_timeout, - success=False, - ) - return - if code != 0: + for index, candidate in enumerate(npm_candidates): + attempt_source = candidate.source + is_last_attempt = index == len(npm_candidates) - 1 + + yield UpdateProgress(stage="building", message="Installing frontend dependencies...") + install_cmd = [candidate.npm, install_subcommand] + try: + code, _, err = await _run_async( + install_cmd, + cwd=staged_webui_dir, + timeout=_FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS, + env=candidate.env, + ) + except subprocess.TimeoutExpired: + final_frontend_error = ( + "Frontend dependency install timed out after " + f"{_FRONTEND_DEPENDENCY_INSTALL_TIMEOUT_SECONDS}s while running {install_label}." + ) + if is_last_attempt: + break + _reset_staged_frontend_workspace(staged_webui_dir) + _record_update_journal( + "WARN " + f"{final_frontend_error} Cleaned staged frontend workspace and retrying with fallback " + f"npm/node after {attempt_source} attempt." 
+ ) + continue + if code != 0: + final_frontend_error = f"Frontend dependency install failed ({install_label}): {err}" + if is_last_attempt: + break + _reset_staged_frontend_workspace(staged_webui_dir) + _record_update_journal( + "WARN " + f"{final_frontend_error} Cleaned staged frontend workspace and retrying with fallback " + f"npm/node after {attempt_source} attempt." + ) + continue + + yield UpdateProgress(stage="building", message="Building frontend...") + try: + code, _, err = await _run_async( + [candidate.npm, "run", "build"], + cwd=staged_webui_dir, + timeout=_FRONTEND_BUILD_TIMEOUT_SECONDS, + env=candidate.env, + ) + except subprocess.TimeoutExpired: + final_frontend_error = ( + f"Frontend build timed out after {_FRONTEND_BUILD_TIMEOUT_SECONDS}s while running npm run build." + ) + if is_last_attempt: + break + _reset_staged_frontend_workspace(staged_webui_dir) + _record_update_journal( + "WARN " + f"{final_frontend_error} Cleaned staged frontend workspace and retrying with fallback " + f"npm/node after {attempt_source} attempt." + ) + continue + if code != 0: + final_frontend_error = f"Frontend build failed: {err}" + if is_last_attempt: + break + _reset_staged_frontend_workspace(staged_webui_dir) + _record_update_journal( + "WARN " + f"{final_frontend_error} Cleaned staged frontend workspace and retrying with fallback " + f"npm/node after {attempt_source} attempt." 
+ ) + continue + + final_frontend_error = None + break + + if final_frontend_error is not None: shutil.rmtree(tmp_dir, ignore_errors=True) - _fe_build = f"Frontend build failed: {err}" - _record_update_journal(f"ERROR {_fe_build}") + _record_update_journal(f"ERROR {final_frontend_error}") yield UpdateProgress( stage="error", - message=_fe_build, + message=final_frontend_error, success=False, ) return @@ -2566,17 +2616,87 @@ def _find_executable(name: str) -> str | None: def _resolve_npm_executable() -> str | None: """Resolve npm from bundled Node first, then standard executable probing.""" + candidates = _resolve_frontend_npm_candidates() + if candidates: + return candidates[0].npm + return None + + +def _resolve_bundled_npm_executable() -> tuple[str, Path] | None: + """Resolve npm from the bundled Node.js install, if available.""" node_dir = _bundled_node_install_dir() - if node_dir is not None: - candidates = ( - [node_dir / "npm.cmd", node_dir / "npm", node_dir / "bin" / "npm"] - if sys.platform == "win32" - else [node_dir / "bin" / "npm", node_dir / "npm"] - ) - for candidate in candidates: - if candidate.exists(): - return str(candidate) + if node_dir is None: + return None + + candidates = ( + [node_dir / "npm.cmd", node_dir / "npm", node_dir / "bin" / "npm"] + if sys.platform == "win32" + else [node_dir / "bin" / "npm", node_dir / "npm"] + ) + for candidate in candidates: + if candidate.exists(): + return str(candidate), node_dir + return None + +def _resolve_system_npm_executable() -> str | None: + """Resolve npm without relying on the bundled Node.js install.""" if sys.platform == "win32": return _find_executable("npm.cmd") or _find_executable("npm") return _find_executable("npm") or _find_executable("npm.cmd") + + +def _resolve_frontend_npm_candidates(*, npm_registry: str | None = None) -> list[_FrontendNpmCandidate]: + """Resolve npm candidates for staged frontend rebuilds.""" + candidates: list[_FrontendNpmCandidate] = [] + + bundled = 
_resolve_bundled_npm_executable() + if bundled is not None: + bundled_npm, node_dir = bundled + candidates.append( + _FrontendNpmCandidate( + npm=bundled_npm, + env=_build_frontend_subprocess_env_for_node_dir( + node_dir, + npm_registry=npm_registry, + ), + source="bundled", + ) + ) + + system_npm = _resolve_system_npm_executable() + if system_npm is not None and sys.platform == "win32": + if sys.platform == "win32": + normalized_system_npm = system_npm.replace("/", "\\").lower() + duplicate = any( + candidate.npm.replace("/", "\\").lower() == normalized_system_npm + for candidate in candidates + ) + if not duplicate: + candidates.append( + _FrontendNpmCandidate( + npm=system_npm, + env=_build_frontend_subprocess_env_for_node_dir( + None, + npm_registry=npm_registry, + ), + source="system", + ) + ) + + if candidates: + return candidates + + if system_npm is not None: + candidates.append( + _FrontendNpmCandidate( + npm=system_npm, + env=_build_frontend_subprocess_env_for_node_dir( + None, + npm_registry=npm_registry, + ), + source="default", + ) + ) + + return candidates diff --git a/packaging/windows/bootstrap-windows.ps1 b/packaging/windows/bootstrap-windows.ps1 index e8ace3db..268c8907 100644 --- a/packaging/windows/bootstrap-windows.ps1 +++ b/packaging/windows/bootstrap-windows.ps1 @@ -67,6 +67,23 @@ function Add-UserPathEntryIfMissing { } } +function Add-ProcessPathEntryIfMissing { + param([string]$Entry) + + if ([string]::IsNullOrWhiteSpace($Entry)) { return } + + $processPath = $env:Path + if ([string]::IsNullOrWhiteSpace($processPath)) { + $env:Path = $Entry + return + } + + $existing = $processPath -split ';' | Where-Object { ($_.TrimEnd('\','/')).ToLower() -eq $Entry.TrimEnd('\','/').ToLower() } + if (-not $existing) { + $env:Path = "$Entry;$processPath" + } +} + function Resolve-ChromeExecutablePath { param([string]$BrowserRoot) @@ -107,6 +124,20 @@ else { Write-Host "[flocks-bootstrap] warning: bundled node not found at $bundledNode" -ForegroundColor 
Yellow } +$bundledPythonDir = Join-Path $InstallRoot "tools\python" +$bundledPython = Join-Path $bundledPythonDir "python.exe" +if (Test-Path $bundledPython) { + Add-ProcessPathEntryIfMissing -Entry $bundledPythonDir + $env:FLOCKS_BUNDLED_PYTHON = $bundledPython + $env:UV_PYTHON = $bundledPython + $env:UV_PYTHON_DOWNLOADS = "never" + $env:UV_NO_MANAGED_PYTHON = "1" + Write-Host "[flocks-bootstrap] configured bundled Python runtime: $bundledPython" +} +else { + Write-Host "[flocks-bootstrap] warning: bundled Python runtime not found at $bundledPython" -ForegroundColor Yellow +} + # 2) Expose bundled Chrome for Testing under ~/.flocks/browser so install.ps1's # Resolve-ChromeForTestingPath finds it and skips the real download. # Prefer a directory junction (fast, no disk duplication) and fall back to copy. diff --git a/packaging/windows/build-staging.ps1 b/packaging/windows/build-staging.ps1 index 175bb635..ddee15d4 100644 --- a/packaging/windows/build-staging.ps1 +++ b/packaging/windows/build-staging.ps1 @@ -157,11 +157,35 @@ function Get-OrDownloadFileFromCandidates { throw "Failed to download $Label" } +function Expand-TarGzArchive { + param( + [Parameter(Mandatory = $true)][string]$ArchivePath, + [Parameter(Mandatory = $true)][string]$DestinationPath + ) + + $tarExe = Get-Command tar.exe -ErrorAction SilentlyContinue + if (-not $tarExe) { + throw "tar.exe is required to extract $ArchivePath" + } + + Remove-PathWithRetry -Path $DestinationPath + New-Item -ItemType Directory -Path $DestinationPath -Force | Out-Null + + & $tarExe.Source -xzf $ArchivePath -C $DestinationPath + if ($LASTEXITCODE -ne 0) { + throw "tar.exe failed to extract $ArchivePath with exit code $LASTEXITCODE" + } + $global:LASTEXITCODE = 0 +} + Write-Host "[build-staging] RepoRoot: $RepoRoot" Write-Host "[build-staging] OutputDir: $OutputDir" $manifest = Read-Manifest -Path $ManifestPath $uvVersion = $manifest.uv.version +$pythonVersion = $manifest.python.version +$pythonStandaloneRelease = 
$manifest.python.python_build_standalone_release +$pythonArchiveName = $manifest.python.windows_archive_name $nodeVersion = $manifest.nodejs.version $nodeSuffix = $manifest.nodejs.windows_zip_suffix $cacheRoot = Resolve-CacheRoot -RepoRoot $RepoRoot -CacheRootOverride $CacheRoot @@ -170,11 +194,13 @@ Write-Host "[build-staging] CacheRoot: $cacheRoot" Ensure-EmptyDir -Path $OutputDir $toolsUv = Join-Path $OutputDir "tools\uv" +$toolsPython = Join-Path $OutputDir "tools\python" $toolsNode = Join-Path $OutputDir "tools\node" $toolsChrome = Join-Path $OutputDir "tools\chrome" $flocksDest = Join-Path $OutputDir "flocks" New-Item -ItemType Directory -Path $toolsUv -Force | Out-Null +New-Item -ItemType Directory -Path $toolsPython -Force | Out-Null New-Item -ItemType Directory -Path $toolsNode -Force | Out-Null New-Item -ItemType Directory -Path $toolsChrome -Force | Out-Null @@ -185,6 +211,38 @@ $uvZip = Join-Path $cacheRoot "downloads\uv-$uvVersion-$uvZipName" Get-OrDownloadFile -Url $uvUrl -CachePath $uvZip -Label "uv $uvVersion" Expand-Archive -Path $uvZip -DestinationPath $toolsUv -Force +# Python runtime (python-build-standalone install-only archive) +$pythonArchiveNameEscaped = [Uri]::EscapeDataString($pythonArchiveName) +$pythonMirrorBase = $env:FLOCKS_PYTHON_STANDALONE_MIRROR_BASE_URL +$pythonUrls = @() +if (-not [string]::IsNullOrWhiteSpace($pythonMirrorBase)) { + $mirrorBase = $pythonMirrorBase.TrimEnd('/') + $pythonUrls += "$mirrorBase/$pythonStandaloneRelease/$pythonArchiveNameEscaped" + Write-Host "[build-staging] Added python-build-standalone mirror candidate from FLOCKS_PYTHON_STANDALONE_MIRROR_BASE_URL" +} +$pythonUrls += "https://github.com/astral-sh/python-build-standalone/releases/download/$pythonStandaloneRelease/$pythonArchiveNameEscaped" +$pythonArchive = Join-Path $cacheRoot "downloads\python-$pythonVersion-$pythonStandaloneRelease-$pythonArchiveName" +Get-OrDownloadFileFromCandidates -Urls $pythonUrls -CachePath $pythonArchive -Label "Python 
$pythonVersion" + +$pythonExtract = Join-Path $env:TEMP "python-extract-$pythonVersion-$pythonStandaloneRelease" +Expand-TarGzArchive -ArchivePath $pythonArchive -DestinationPath $pythonExtract +$pythonExe = Get-ChildItem -Path $pythonExtract -Recurse -Filter "python.exe" -File -ErrorAction SilentlyContinue | + Where-Object { $_.DirectoryName -notmatch '\\DLLs($|\\)' } | + Select-Object -First 1 +if (-not $pythonExe) { + throw "python.exe not found after extracting bundled Python runtime" +} +$pythonSource = $pythonExe.Directory.FullName +robocopy $pythonSource $toolsPython /E /NFL /NDL /NJH /NJS /nc /ns /np | Out-Null +if ($LASTEXITCODE -ge 8) { + throw "robocopy failed while copying bundled Python with exit code $LASTEXITCODE" +} +$global:LASTEXITCODE = 0 +Remove-PathWithRetry -Path $pythonExtract +if (-not (Test-Path (Join-Path $toolsPython "python.exe"))) { + throw "Bundled Python runtime missing python.exe under tools\python after extraction" +} + # Node.js official zip (portable) $nodeZipName = "node-v$nodeVersion-$nodeSuffix.zip" $nodeUrl = "https://nodejs.org/dist/v$nodeVersion/$nodeZipName" diff --git a/packaging/windows/flocks-setup.iss b/packaging/windows/flocks-setup.iss index bccc5977..aa1643c6 100644 --- a/packaging/windows/flocks-setup.iss +++ b/packaging/windows/flocks-setup.iss @@ -49,13 +49,14 @@ Root: HKCU; Subkey: "Environment"; ValueType: string; ValueName: "FLOCKS_INSTALL Root: HKCU; Subkey: "Environment"; ValueType: string; ValueName: "FLOCKS_REPO_ROOT"; ValueData: "{app}\flocks"; Flags: uninsdeletevalue Root: HKCU; Subkey: "Environment"; ValueType: string; ValueName: "FLOCKS_NODE_HOME"; ValueData: "{app}\tools\node"; Flags: uninsdeletevalue -; Shortcuts intentionally target the same wrapper path that `scripts\install.ps1` -; writes, so the Start menu / desktop icon and `flocks start` typed in a new -; terminal are strictly equivalent across all install flows. 
+; Installer-created launch shortcuts intentionally go through a tiny elevation +; helper first, then invoke the same `%USERPROFILE%\.local\bin\flocks.cmd` +; wrapper that `scripts\install.ps1` writes. This keeps the app entrypoint +; consistent while letting Windows prompt for UAC on shortcut launches. [Icons] -Name: "{autoprograms}\{#MyAppName}\Start Flocks"; Filename: "{%USERPROFILE}\.local\bin\flocks.cmd"; Parameters: "start"; WorkingDir: "{%USERPROFILE}" +Name: "{autoprograms}\{#MyAppName}\Start Flocks"; Filename: "powershell.exe"; Parameters: "-NoProfile -ExecutionPolicy Bypass -WindowStyle Hidden -File ""{app}\flocks\packaging\windows\start-flocks-elevated.ps1"""; WorkingDir: "{%USERPROFILE}" Name: "{autoprograms}\{#MyAppName}\Flocks repository"; Filename: "{app}\flocks"; WorkingDir: "{app}\flocks" -Name: "{userdesktop}\{#MyAppName}"; Filename: "{%USERPROFILE}\.local\bin\flocks.cmd"; Parameters: "start"; WorkingDir: "{%USERPROFILE}"; Tasks: desktopicon +Name: "{userdesktop}\{#MyAppName}"; Filename: "powershell.exe"; Parameters: "-NoProfile -ExecutionPolicy Bypass -WindowStyle Hidden -File ""{app}\flocks\packaging\windows\start-flocks-elevated.ps1"""; WorkingDir: "{%USERPROFILE}"; Tasks: desktopicon [Run] Filename: "powershell.exe"; Parameters: "-NoProfile -ExecutionPolicy Bypass -File ""{app}\flocks\packaging\windows\bootstrap-windows.ps1"" -InstallRoot ""{app}"""; StatusMsg: "Setting up Python and JavaScript dependencies..."; Flags: runascurrentuser waituntilterminated diff --git a/packaging/windows/start-flocks-elevated.ps1 b/packaging/windows/start-flocks-elevated.ps1 new file mode 100644 index 00000000..2a7d5c14 --- /dev/null +++ b/packaging/windows/start-flocks-elevated.ps1 @@ -0,0 +1,16 @@ +[CmdletBinding()] +param() + +$wrapperPath = Join-Path $HOME ".local\bin\flocks.cmd" +if (-not (Test-Path -LiteralPath $wrapperPath)) { + throw "Flocks launcher not found: $wrapperPath" +} + +$cmdPath = $env:ComSpec +if ([string]::IsNullOrWhiteSpace($cmdPath)) { + 
$cmdPath = "cmd.exe" +} + +# Route installer-created shortcuts through UAC, but keep the real app entrypoint +# on the shared flocks.cmd wrapper so shortcut launches match terminal launches. +Start-Process -FilePath $cmdPath -ArgumentList @("/c", "`"$wrapperPath`" start") -WorkingDirectory $HOME -WindowStyle Hidden -Verb RunAs diff --git a/packaging/windows/versions.manifest.json b/packaging/windows/versions.manifest.json index 1d420759..d26eaf5c 100644 --- a/packaging/windows/versions.manifest.json +++ b/packaging/windows/versions.manifest.json @@ -3,6 +3,11 @@ "uv": { "version": "0.9.15" }, + "python": { + "version": "3.12.12", + "python_build_standalone_release": "20251202", + "windows_archive_name": "cpython-3.12.12+20251202-x86_64-pc-windows-msvc-install_only_stripped.tar.gz" + }, "nodejs": { "version": "24.14.0", "windows_zip_suffix": "win-x64" diff --git a/pyproject.toml b/pyproject.toml index 95d63ad9..87476056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flocks" -version = "v2026.5.9" +version = "v2026.5.12" description = "AI-Native SecOps platform with multi-agent collaboration" authors = [ {name = "Flocks Team", email = "team@example.com"} diff --git a/tests/mcp/test_mcp_client.py b/tests/mcp/test_mcp_client.py index 5eff5df2..4bdc162a 100644 --- a/tests/mcp/test_mcp_client.py +++ b/tests/mcp/test_mcp_client.py @@ -1,82 +1,178 @@ +import asyncio +from contextlib import asynccontextmanager +from types import MethodType +from unittest.mock import AsyncMock + import pytest +import flocks.mcp.client as mcp_client_module from flocks.mcp.client import McpClient class TestMcpClientTransportSelection: @pytest.mark.asyncio - async def test_connect_uses_sse_only_when_transport_is_sse(self, monkeypatch: pytest.MonkeyPatch): + async def test_connect_routes_remote_servers_to_remote_owner( + self, + monkeypatch: pytest.MonkeyPatch, + ): calls: list[str] = [] - async def fake_http(*args, **kwargs): - calls.append("http") - - async 
def fake_sse(*args, **kwargs): - calls.append("sse") + async def fake_remote(startup_future): + calls.append("remote") + startup_future.set_result(None) client = McpClient( name="demo", server_type="remote", - url="https://example.com/sse", - transport="sse", + url="https://example.com/mcp", ) - monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) - monkeypatch.setattr(client, "_do_connect_sse", fake_sse) + monkeypatch.setattr(client, "_connect_remote", fake_remote) await client.connect() - assert calls == ["sse"] - assert client._transport_type == "sse" + assert calls == ["remote"] @pytest.mark.asyncio - async def test_connect_uses_http_only_when_transport_is_http(self, monkeypatch: pytest.MonkeyPatch): + async def test_connect_routes_stdio_servers_to_local_owner( + self, + monkeypatch: pytest.MonkeyPatch, + ): calls: list[str] = [] - async def fake_http(*args, **kwargs): - calls.append("http") + async def fake_local(startup_future): + calls.append("local") + startup_future.set_result(None) + + client = McpClient( + name="demo", + server_type="stdio", + command=["python", "-m", "demo"], + ) + monkeypatch.setattr(client, "_connect_local", fake_local) + + await client.connect() + + assert calls == ["local"] - async def fake_sse(*args, **kwargs): - calls.append("sse") + @pytest.mark.asyncio + async def test_timeout_none_defaults_to_safe_float( + self, + monkeypatch: pytest.MonkeyPatch, + ): + observed: list[float] = [] + + async def fake_remote(startup_future): + observed.append(client.timeout) + startup_future.set_result(None) client = McpClient( name="demo", server_type="remote", url="https://example.com/mcp", - transport="http", + timeout=None, ) - monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) - monkeypatch.setattr(client, "_do_connect_sse", fake_sse) + monkeypatch.setattr(client, "_connect_remote", fake_remote) await client.connect() - assert calls == ["http"] - assert client._transport_type == "streamable_http" + assert 
observed == [30.0] @pytest.mark.asyncio - async def test_connect_auto_falls_back_to_sse_after_http_failure(self, monkeypatch: pytest.MonkeyPatch): - calls: list[str] = [] + async def test_unknown_type_raises_value_error(self): + client = McpClient( + name="demo", + server_type="websocket", + url="wss://example.com", + ) - async def fake_http(*args, **kwargs): - calls.append("http") - raise RuntimeError("HTTP 405") + with pytest.raises(ValueError, match="Unknown server type: websocket"): + await client.connect() - async def fake_sse(*args, **kwargs): - calls.append("sse") + @pytest.mark.asyncio + async def test_failed_connect_cleans_up_owner_runtime_state(self): + client = McpClient( + name="demo", + server_type="websocket", + url="wss://example.com", + ) - async def fake_cleanup(): - return None + with pytest.raises(ValueError, match="Unknown server type: websocket"): + await client.connect() + assert client._connected is False + assert client._command_queue is None + assert client._owner_task is None + assert isinstance(client._owner_error, ValueError) + + @pytest.mark.asyncio + async def test_already_connected_skips_new_owner_task( + self, + monkeypatch: pytest.MonkeyPatch, + ): client = McpClient( name="demo", server_type="remote", url="https://example.com/mcp", - transport="auto", ) - monkeypatch.setattr(client, "_do_connect_streamable_http", fake_http) - monkeypatch.setattr(client, "_do_connect_sse", fake_sse) - monkeypatch.setattr(client, "_cleanup_connection", fake_cleanup) + client._connected = True + fake_owner = AsyncMock() + monkeypatch.setattr(client, "_run_connection_owner", fake_owner) await client.connect() - assert calls == ["http", "sse"] - assert client._transport_type == "sse" + fake_owner.assert_not_called() + + @pytest.mark.asyncio + async def test_connect_local_closes_stderr_file_on_failure( + self, + monkeypatch: pytest.MonkeyPatch, + ): + class _FakeTempFile: + def __init__(self) -> None: + self.closed = False + + def __enter__(self): + 
return self + + def __exit__(self, exc_type, exc, tb) -> bool: + self.close() + return False + + def seek(self, _offset: int) -> None: + return None + + def read(self, _size: int = -1) -> str: + return "stdio stderr" + + def close(self) -> None: + self.closed = True + + fake_stderr = _FakeTempFile() + client = McpClient( + name="demo", + server_type="stdio", + command=["python", "-m", "demo"], + ) + + @asynccontextmanager + async def broken_stdio(self, _server_params, stderr_file): + assert stderr_file is fake_stderr + raise RuntimeError("spawn failed") + yield + + monkeypatch.setattr( + mcp_client_module.tempfile, + "TemporaryFile", + lambda mode="w+": fake_stderr, + ) + monkeypatch.setattr( + client, + "_create_stdio_streams", + MethodType(broken_stdio, client), + ) + + startup_future = asyncio.get_running_loop().create_future() + with pytest.raises(RuntimeError, match="Stdio connection failed"): + await client._connect_local(startup_future) + + assert fake_stderr.closed is True diff --git a/tests/mcp/test_mcp_client_sse.py b/tests/mcp/test_mcp_client_sse.py index a0da9fee..479a32a4 100644 --- a/tests/mcp/test_mcp_client_sse.py +++ b/tests/mcp/test_mcp_client_sse.py @@ -1,220 +1,243 @@ -""" -Tests for MCP Client SSE transport support - -Verifies that McpClient correctly handles: -- remote / sse server type (auto-detect: Streamable HTTP -> SSE fallback) -- Timeout does NOT fall back (avoids double wait) -- Unknown server types (raises ValueError) -- Error message extraction from ExceptionGroups -""" +"""Tests for MCP client remote transport lifecycle and fallback behavior.""" import asyncio +from contextlib import asynccontextmanager +from types import MethodType, SimpleNamespace import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from flocks.mcp.client import McpClient, _extract_root_cause - - -class TestMcpClientServerTypes: - """Test McpClient server type routing""" - - def test_init_sse_type(self): - """SSE type should be accepted""" - client 
= McpClient( - name="test-sse", - server_type="sse", - url="https://example.com/sse", - ) - assert client.server_type == "sse" - assert client.url == "https://example.com/sse" - - def test_init_remote_type(self): - """Remote type should be accepted""" - client = McpClient( - name="test-remote", - server_type="remote", - url="https://example.com/mcp", - ) - assert client.server_type == "remote" - - @pytest.mark.asyncio - async def test_unknown_type_raises_value_error(self): - """Unknown server type should raise ValueError""" - client = McpClient( - name="test-bad", - server_type="websocket", - url="wss://example.com", - ) - with pytest.raises(ValueError, match="Unknown server type: websocket"): - await client.connect() - - @pytest.mark.asyncio - async def test_sse_type_uses_auto_detect(self): - """SSE type should use _connect_remote (auto-detect) same as remote""" - client = McpClient( - name="test", - server_type="sse", - url="https://example.com/sse", - ) - client._connect_remote = AsyncMock() - await client.connect() - client._connect_remote.assert_called_once() - @pytest.mark.asyncio - async def test_remote_type_calls_connect_remote(self): - """Remote type should call _connect_remote""" - client = McpClient( - name="test", - server_type="remote", - url="https://example.com/mcp", - ) - client._connect_remote = AsyncMock() - await client.connect() - client._connect_remote.assert_called_once() +import flocks.mcp.client as mcp_client_module +from flocks.mcp.client import McpClient, _extract_root_cause - @pytest.mark.asyncio - async def test_stdio_type_calls_connect_local(self): - """Stdio type attempts connection (raises RuntimeError on failure)""" - client = McpClient( - name="test", - server_type="stdio", - url=None, - command=["python", "-m", "some_server"], - ) - # Stdio connection will fail since 'some_server' doesn't exist - with pytest.raises((NotImplementedError, RuntimeError)): - await client.connect() - @pytest.mark.asyncio - async def 
test_already_connected_skips(self): - """Already connected client should skip reconnection""" - client = McpClient( - name="test", - server_type="sse", - url="https://example.com/sse", - ) - client._connected = True - client._connect_remote = AsyncMock() - await client.connect() - client._connect_remote.assert_not_called() +def _make_session_class( + *, + events: dict[str, object] | None = None, + tool_result: object | None = None, + tools: list[object] | None = None, + resources: list[object] | None = None, +): + class FakeSession: + def __init__(self, read_stream, write_stream): + self.read_stream = read_stream + self.write_stream = write_stream + + async def __aenter__(self): + if events is not None: + events["session_enter_task"] = asyncio.current_task() + return self + + async def __aexit__(self, exc_type, exc, tb): + if events is not None: + events["session_exit_task"] = asyncio.current_task() + return False + + async def initialize(self): + return SimpleNamespace(protocolVersion="2026-05-12", serverInfo={"name": "demo"}) + + async def list_tools(self): + return SimpleNamespace(tools=tools or []) + + async def call_tool(self, name, arguments): + if events is not None: + events["call_tool_task"] = asyncio.current_task() + if isinstance(tool_result, Exception): + raise tool_result + if tool_result is not None: + return tool_result + return {"name": name, "arguments": arguments} + + async def list_resources(self): + return SimpleNamespace(resources=resources or []) + + async def read_resource(self, uri): + return {"uri": uri} + + return FakeSession + + +def _make_remote_transport_factory( + label: str, + *, + streams: tuple[object, ...] 
| None = None, + error: Exception | None = None, + events: dict[str, object] | None = None, + captures: list[tuple[str, str, dict | None]] | None = None, +): + if streams is None: + if label == "http": + streams = ("read", "write", lambda: None) + else: + streams = ("read", "write") + + @asynccontextmanager + async def factory(self, url, headers): + if captures is not None: + captures.append((label, url, headers)) + if error is not None: + raise error + if events is not None: + events[f"{label}_enter_task"] = asyncio.current_task() + try: + yield streams + finally: + if events is not None: + events[f"{label}_exit_task"] = asyncio.current_task() + + return factory + + +def _bind_method(monkeypatch: pytest.MonkeyPatch, client: McpClient, name: str, method) -> None: + monkeypatch.setattr(client, name, MethodType(method, client)) class TestMcpClientRemoteFallback: - """Test remote type fallback from Streamable HTTP to SSE""" - @pytest.mark.asyncio - async def test_remote_falls_back_to_sse(self): - """Remote type should fall back to SSE when Streamable HTTP fails""" + async def test_remote_falls_back_to_sse(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-remote", server_type="remote", url="https://mcp.example.com/mcp", timeout=10.0, ) + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) - # Mock _do_connect_streamable_http to fail - client._do_connect_streamable_http = AsyncMock( - side_effect=RuntimeError("Streamable HTTP not supported") + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", error=RuntimeError("HTTP failed")), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse"), ) - # Mock _do_connect_sse to succeed - async def mark_connected(url, headers=None): - client._connected = True - client._do_connect_sse = AsyncMock(side_effect=mark_connected) await client.connect() - 
client._do_connect_streamable_http.assert_called_once() - client._do_connect_sse.assert_called_once() assert client._transport_type == "sse" + await client.disconnect() @pytest.mark.asyncio - async def test_remote_streamable_http_success_no_sse(self): - """Remote type should not try SSE if Streamable HTTP succeeds""" + async def test_remote_streamable_http_success_no_sse(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-remote", server_type="remote", url="https://mcp.example.com/mcp", timeout=10.0, ) - - async def mark_connected(url, headers=None): - client._connected = True - client._do_connect_streamable_http = AsyncMock(side_effect=mark_connected) - client._do_connect_sse = AsyncMock() + captures: list[tuple[str, str, dict | None]] = [] + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) + + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", captures=captures), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", captures=captures), + ) await client.connect() - client._do_connect_streamable_http.assert_called_once() - client._do_connect_sse.assert_not_called() assert client._transport_type == "streamable_http" + assert [label for label, _, _ in captures] == ["http"] + await client.disconnect() @pytest.mark.asyncio - async def test_remote_both_fail_raises(self): - """Remote type should raise RuntimeError if both transports fail""" + async def test_remote_both_fail_raises(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-remote", server_type="remote", url="https://mcp.example.com/mcp", timeout=10.0, ) + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) - client._do_connect_streamable_http = AsyncMock( - side_effect=RuntimeError("HTTP failed") + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + 
_make_remote_transport_factory("http", error=RuntimeError("HTTP failed")), ) - client._do_connect_sse = AsyncMock( - side_effect=RuntimeError("SSE failed") + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", error=RuntimeError("SSE failed")), ) with pytest.raises(RuntimeError, match="Connection failed.*SSE failed"): await client.connect() @pytest.mark.asyncio - async def test_sse_type_also_tries_streamable_http_first(self): - """SSE type uses same auto-detect strategy as remote (Streamable HTTP first)""" + async def test_sse_type_also_tries_streamable_http_first(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-sse", server_type="sse", url="https://mcp.example.com/mcp", timeout=10.0, ) - - async def mark_connected(url, headers=None): - client._connected = True - client._do_connect_streamable_http = AsyncMock(side_effect=mark_connected) - client._do_connect_sse = AsyncMock() + captures: list[tuple[str, str, dict | None]] = [] + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) + + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", captures=captures), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", captures=captures), + ) await client.connect() - # "sse" and "remote" share the same auto-detect logic - client._do_connect_streamable_http.assert_called_once() - client._do_connect_sse.assert_not_called() assert client._transport_type == "streamable_http" + assert [label for label, _, _ in captures] == ["http"] + await client.disconnect() @pytest.mark.asyncio - async def test_timeout_does_not_fall_back(self): - """Streamable HTTP timeout should NOT fall back to SSE (avoids double wait)""" + async def test_timeout_does_not_fall_back(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-timeout", server_type="remote", 
url="https://mcp.example.com/mcp", timeout=10.0, ) - - client._do_connect_streamable_http = AsyncMock( - side_effect=asyncio.TimeoutError() + captures: list[tuple[str, str, dict | None]] = [] + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) + + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", error=asyncio.TimeoutError()), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", captures=captures), ) - client._do_connect_sse = AsyncMock() with pytest.raises(RuntimeError, match="Connection timeout"): await client.connect() - # SSE should NOT have been attempted - client._do_connect_sse.assert_not_called() + assert captures == [] assert client._transport_type is None @pytest.mark.asyncio - async def test_remote_passes_resolved_headers_to_transports(self): - """Remote connection should pass config and auth headers to SDK transports""" + async def test_remote_passes_resolved_headers_to_transports(self, monkeypatch: pytest.MonkeyPatch): client = McpClient( name="test-headers", server_type="remote", @@ -228,15 +251,21 @@ async def test_remote_passes_resolved_headers_to_transports(self): }, timeout=10.0, ) - - client._do_connect_streamable_http = AsyncMock( - side_effect=RuntimeError("HTTP failed") + captures: list[tuple[str, str, dict | None]] = [] + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class()) + + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", captures=captures, error=RuntimeError("HTTP failed")), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", captures=captures), ) - - async def mark_connected(url, headers): - client._connected = True - - client._do_connect_sse = AsyncMock(side_effect=mark_connected) await client.connect() @@ -244,46 +273,104 @@ async def 
mark_connected(url, headers): "Api-Key": "token123", "Authorization": "Bearer abc", } - client._do_connect_streamable_http.assert_called_once_with( - "https://mcp.example.com/mcp", - expected_headers, + assert captures == [ + ("http", "https://mcp.example.com/mcp", expected_headers), + ("sse", "https://mcp.example.com/mcp", expected_headers), + ] + await client.disconnect() + + @pytest.mark.asyncio + async def test_disconnect_closes_streams_and_session_in_owner_task( + self, + monkeypatch: pytest.MonkeyPatch, + ): + client = McpClient( + name="test-owner", + server_type="remote", + url="https://mcp.example.com/mcp", + ) + events: dict[str, object] = {} + monkeypatch.setattr(mcp_client_module, "ClientSession", _make_session_class(events=events)) + + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http", events=events), ) - client._do_connect_sse.assert_called_once_with( - "https://mcp.example.com/mcp", - expected_headers, + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse", events=events), ) + await client.connect() + await client.disconnect() -class TestExtractRootCause: - """Test _extract_root_cause helper function""" + assert events["http_enter_task"] is events["http_exit_task"] + assert events["session_enter_task"] is events["session_exit_task"] + + @pytest.mark.asyncio + async def test_call_tool_runs_through_owner_task(self, monkeypatch: pytest.MonkeyPatch): + client = McpClient( + name="test-call", + server_type="remote", + url="https://mcp.example.com/mcp", + ) + events: dict[str, object] = {} + monkeypatch.setattr( + mcp_client_module, + "ClientSession", + _make_session_class(events=events), + ) + _bind_method( + monkeypatch, + client, + "_create_streamable_http_streams", + _make_remote_transport_factory("http"), + ) + _bind_method( + monkeypatch, + client, + "_create_sse_streams", + _make_remote_transport_factory("sse"), + ) + + await 
client.connect() + result = await client.call_tool("demo_tool", {"value": 1}) + await client.disconnect() + + assert result == {"name": "demo_tool", "arguments": {"value": 1}} + assert events["call_tool_task"] is events["session_enter_task"] + + +class TestExtractRootCause: def test_simple_exception(self): - """Simple exception returns its message""" assert _extract_root_cause(RuntimeError("simple error")) == "simple error" def test_exception_group(self): - """ExceptionGroup should unwrap to the root cause""" inner = RuntimeError("real error") group = ExceptionGroup("group", [inner]) assert _extract_root_cause(group) == "real error" def test_nested_exception_group(self): - """Nested ExceptionGroups should be fully unwrapped""" inner = ValueError("deep error") group1 = ExceptionGroup("inner group", [inner]) group2 = ExceptionGroup("outer group", [group1]) assert _extract_root_cause(group2) == "deep error" def test_http_status_error(self): - """HTTP status errors should show status code""" - # Simulate httpx.HTTPStatusError class MockResponse: status_code = 401 + class MockRequest: url = "https://example.com/mcp?apikey=secret123" + exc = Exception("HTTP error") exc.response = MockResponse() exc.request = MockRequest() result = _extract_root_cause(exc) assert "401" in result - assert "secret" not in result # URL should be masked + assert "secret" not in result diff --git a/tests/packaging/test_windows_manifest.py b/tests/packaging/test_windows_manifest.py index 78aea3d1..55bff6da 100644 --- a/tests/packaging/test_windows_manifest.py +++ b/tests/packaging/test_windows_manifest.py @@ -14,3 +14,14 @@ def test_windows_bundled_uv_supports_python_downloads_json_url() -> None: manifest = json.loads(WINDOWS_MANIFEST.read_text(encoding="utf-8")) assert _parse_version(manifest["uv"]["version"]) >= (0, 7, 3) + + +def test_windows_manifest_pins_bundled_python_runtime() -> None: + manifest = json.loads(WINDOWS_MANIFEST.read_text(encoding="utf-8")) + + python = manifest["python"] + 
assert _parse_version(python["version"]) >= (3, 12, 0) + assert python["python_build_standalone_release"].isdigit() + assert python["windows_archive_name"].endswith(".tar.gz") + assert "install_only" in python["windows_archive_name"] + assert "windows-msvc" in python["windows_archive_name"] diff --git a/tests/scripts/test_browser_runtime_configuration.py b/tests/scripts/test_browser_runtime_configuration.py index 24fe4756..92cd1738 100644 --- a/tests/scripts/test_browser_runtime_configuration.py +++ b/tests/scripts/test_browser_runtime_configuration.py @@ -68,10 +68,16 @@ def test_powershell_bootstrap_wires_bundled_toolchain() -> None: """packaging/windows/bootstrap-windows.ps1 is the single place that bridges the bundled layout to install.ps1.""" script = (PACKAGING_WINDOWS_DIR / "bootstrap-windows.ps1").read_text(encoding="utf-8-sig") + assert "Add-ProcessPathEntryIfMissing" in script assert "Resolve-ChromeExecutablePath" in script assert "FLOCKS_SKIP_ADMIN_CHECK" in script assert "FLOCKS_BROWSER_EXECUTABLE_OVERRIDE" in script + assert "FLOCKS_BUNDLED_PYTHON" in script + assert "UV_PYTHON" in script + assert "UV_PYTHON_DOWNLOADS" in script + assert "UV_NO_MANAGED_PYTHON" in script assert "tools\\uv" in script + assert "tools\\python" in script assert "tools\\node" in script assert "tools\\chrome" in script assert ".flocks\\browser" in script @@ -79,6 +85,16 @@ def test_powershell_bootstrap_wires_bundled_toolchain() -> None: assert 'scripts\\install_zh.ps1' in script +def test_build_staging_bundles_python_runtime() -> None: + script = (PACKAGING_WINDOWS_DIR / "build-staging.ps1").read_text(encoding="utf-8-sig") + + assert "tools\\python" in script + assert "python-build-standalone" in script + assert "python.exe" in script + assert "tar.exe" in script + assert "FLOCKS_PYTHON_STANDALONE_MIRROR_BASE_URL" in script + + def test_inno_setup_points_to_packaging_bootstrap() -> None: """flocks-setup.iss must invoke the bootstrap from its new packaging location.""" iss = 
(PACKAGING_WINDOWS_DIR / "flocks-setup.iss").read_text(encoding="utf-8") @@ -87,11 +103,9 @@ def test_inno_setup_points_to_packaging_bootstrap() -> None: assert "scripts\\bootstrap-windows.ps1" not in iss -def test_inno_shortcuts_point_to_user_local_bin_wrapper() -> None: - """Start-menu and desktop shortcuts must match the CLI wrapper location that - `scripts/install.ps1` writes, so `flocks start` triggered from the shortcut - and from a freshly opened terminal are strictly equivalent across all - install flows (source, one-liner, bundled installer).""" +def test_inno_shortcuts_point_to_elevated_launcher() -> None: + """Installer-created launch shortcuts should route through the elevated + launcher so clicking them triggers UAC before running the shared wrapper.""" iss = (PACKAGING_WINDOWS_DIR / "flocks-setup.iss").read_text(encoding="utf-8") icons_section_idx = iss.find("[Icons]") @@ -99,7 +113,8 @@ def test_inno_shortcuts_point_to_user_local_bin_wrapper() -> None: assert icons_section_idx != -1 and run_section_idx != -1 icons_block = iss[icons_section_idx:run_section_idx] - expected_target = "{%USERPROFILE}\\.local\\bin\\flocks.cmd" + expected_target = 'Filename: "powershell.exe"' + expected_script = 'start-flocks-elevated.ps1' start_menu_lines = [ line for line in icons_block.splitlines() @@ -108,14 +123,27 @@ def test_inno_shortcuts_point_to_user_local_bin_wrapper() -> None: assert start_menu_lines, "expected Start Flocks + desktop shortcut entries" for line in start_menu_lines: assert expected_target in line, ( - f"shortcut must target the shared wrapper path; got: {line}" + f"shortcut must target PowerShell launcher; got: {line}" ) - assert 'Parameters: "start"' in line + assert expected_script in line + assert "-WindowStyle Hidden" in line + + # Guard against accidentally re-introducing direct shortcut launches that + # bypass the UAC prompt. 
+ assert "{%USERPROFILE}\\.local\\bin\\flocks.cmd" not in icons_block + + +def test_windows_elevated_launcher_runs_shared_wrapper_as_admin() -> None: + """The elevation helper should re-use the shared CLI wrapper and request + Administrator rights via Start-Process.""" + script = (PACKAGING_WINDOWS_DIR / "start-flocks-elevated.ps1").read_text( + encoding="utf-8-sig" + ) - # Guard against accidentally re-introducing a shortcut to {app}\bin, which - # would point to a non-existent file because install.ps1 writes the wrapper - # under %USERPROFILE%\.local\bin. - assert "{app}\\bin\\flocks.cmd" not in icons_block + assert 'Join-Path $HOME ".local\\bin\\flocks.cmd"' in script + assert "Start-Process" in script + assert "-Verb RunAs" in script + assert "`\"$wrapperPath`\" start" in script def test_inno_finish_page_reminds_user_to_reopen_terminal() -> None: diff --git a/tests/server/concurrency/conftest.py b/tests/server/concurrency/conftest.py index f29d6b6a..6ec9323c 100644 --- a/tests/server/concurrency/conftest.py +++ b/tests/server/concurrency/conftest.py @@ -16,10 +16,34 @@ async def client() -> AsyncGenerator[AsyncClient, None]: from flocks.server.app import app transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as ac: + headers = {"Authorization": "Bearer abc123", "User-Agent": "curl/8.0"} + async with AsyncClient( + transport=transport, + base_url="http://test", + headers=headers, + ) as ac: yield ac +@pytest.fixture(autouse=True) +def _concurrency_test_api_token(monkeypatch: pytest.MonkeyPatch) -> None: + """Provide a valid API token for concurrency tests.""" + from flocks.server import auth as auth_module + + class SecretManagerStub: + def __init__(self, values: dict[str, str]): + self._values = values + + def get(self, key: str): + return self._values.get(key) + + monkeypatch.setattr( + auth_module, + "get_secret_manager", + lambda: SecretManagerStub({auth_module.API_TOKEN_SECRET_ID: "abc123"}), + ) + + 
@pytest.fixture async def session_id(client: AsyncClient) -> str: resp = await client.post("/api/session", json={"title": "concurrency-test"}) diff --git a/tests/server/routes/test_route_timing.py b/tests/server/routes/test_route_timing.py new file mode 100644 index 00000000..86fa81a5 --- /dev/null +++ b/tests/server/routes/test_route_timing.py @@ -0,0 +1,53 @@ +import pytest + +from flocks.server.routes import _timing as timing_module + + +class _Recorder: + def __init__(self) -> None: + self.debug_calls: list[tuple[str, dict]] = [] + self.info_calls: list[tuple[str, dict]] = [] + + def debug(self, message, extra=None) -> None: + self.debug_calls.append((message, extra or {})) + + def info(self, message, extra=None) -> None: + self.info_calls.append((message, extra or {})) + + +def test_log_route_timing_uses_debug_below_threshold(monkeypatch: pytest.MonkeyPatch) -> None: + logger = _Recorder() + monkeypatch.setattr(timing_module.time, "perf_counter", lambda: 100.2) + + duration_ms = timing_module.log_route_timing( + logger, + "session.list.complete", + started_at=100.0, + extra={"count": 2}, + slow_threshold_ms=300, + ) + + assert 199 <= duration_ms <= 200 + assert logger.info_calls == [] + assert logger.debug_calls == [ + ("session.list.complete", {"duration_ms": duration_ms, "count": 2}), + ] + + +def test_log_route_timing_uses_info_at_threshold(monkeypatch: pytest.MonkeyPatch) -> None: + logger = _Recorder() + monkeypatch.setattr(timing_module.time, "perf_counter", lambda: 200.3) + + duration_ms = timing_module.log_route_timing( + logger, + "task.dashboard.complete", + started_at=200.0, + extra={"running": 1}, + slow_threshold_ms=300, + ) + + assert 299 <= duration_ms <= 300 + assert logger.debug_calls == [] + assert logger.info_calls == [ + ("task.dashboard.complete", {"duration_ms": duration_ms, "running": 1}), + ] diff --git a/tests/storage/test_sqlite_connection_config.py b/tests/storage/test_sqlite_connection_config.py new file mode 100644 index 
00000000..cdd87b8d --- /dev/null +++ b/tests/storage/test_sqlite_connection_config.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import aiosqlite +import pytest + +from flocks.config.config import Config +from flocks.storage.storage import Storage +from flocks.task.manager import TaskManager + + +@pytest.fixture(autouse=True) +async def isolated_storage(tmp_path: pytest.TempPathFactory, monkeypatch: pytest.MonkeyPatch): + data_dir = tmp_path / "flocks_data" + data_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("FLOCKS_DATA_DIR", str(data_dir)) + + Config._global_config = None + Config._cached_config = None + Storage._initialized = False + Storage._db_path = None + + yield + + Storage._initialized = False + Storage._db_path = None + Config._global_config = None + Config._cached_config = None + + +@pytest.mark.asyncio +async def test_storage_init_enables_wal_for_fresh_database() -> None: + await Storage.init() + + async with aiosqlite.connect(Storage.get_db_path()) as db: + async with db.execute("PRAGMA journal_mode") as cursor: + journal_mode = (await cursor.fetchone())[0] + + assert journal_mode == "wal" + + +@pytest.mark.asyncio +async def test_storage_connect_applies_runtime_sqlite_pragmas() -> None: + await Storage.init() + + async with Storage.connect() as db: + async with db.execute("PRAGMA busy_timeout") as cursor: + busy_timeout = (await cursor.fetchone())[0] + async with db.execute("PRAGMA foreign_keys") as cursor: + foreign_keys = (await cursor.fetchone())[0] + + assert busy_timeout == Storage._sqlite_busy_timeout_ms + assert foreign_keys == 1 + + +def test_storage_connect_sync_applies_runtime_sqlite_pragmas() -> None: + import asyncio + + asyncio.run(Storage.init()) + + with Storage.connect_sync() as db: + busy_timeout = db.execute("PRAGMA busy_timeout").fetchone()[0] + foreign_keys = db.execute("PRAGMA foreign_keys").fetchone()[0] + + assert busy_timeout == Storage._sqlite_busy_timeout_ms + assert foreign_keys == 1 + + 
+@pytest.mark.asyncio +async def test_task_manager_sync_connection_uses_storage_sqlite_contract() -> None: + await Storage.init() + + with TaskManager._with_db_connection() as db: + row = db.execute( + "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'storage'" + ).fetchone() + busy_timeout = db.execute("PRAGMA busy_timeout").fetchone()[0] + foreign_keys = db.execute("PRAGMA foreign_keys").fetchone()[0] + + assert row is not None + assert busy_timeout == Storage._sqlite_busy_timeout_ms + assert foreign_keys == 1 diff --git a/tests/storage/test_sqlite_mixed_access.py b/tests/storage/test_sqlite_mixed_access.py new file mode 100644 index 00000000..c9339908 --- /dev/null +++ b/tests/storage/test_sqlite_mixed_access.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from flocks.config.config import Config +from flocks.provider.usage_service import RecordUsageRequest, get_usage_records, record_usage +from flocks.storage.storage import Storage +from flocks.task.models import TaskExecution, TaskScheduler, TaskStatus +from flocks.task.store import TaskStore + + +@pytest.fixture(autouse=True) +async def isolated_storage_and_task_env( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + data_dir = tmp_path / "flocks_data" + data_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("FLOCKS_DATA_DIR", str(data_dir)) + + Config._global_config = None + Config._cached_config = None + Storage._initialized = False + Storage._db_path = None + TaskStore._initialized = False + TaskStore._conn = None + + await Storage.init() + await TaskStore.init() + + yield + + await TaskStore.close() + Config._global_config = None + Config._cached_config = None + Storage._initialized = False + Storage._db_path = None + TaskStore._initialized = False + TaskStore._conn = None + + +@pytest.mark.asyncio +async def test_mixed_storage_task_and_usage_access_share_consistent_sqlite_config() -> None: + scheduler = 
TaskScheduler(title="sqlite-mixed-access") + await TaskStore.create_scheduler(scheduler) + + async def write_storage(idx: int) -> None: + await Storage.set(f"mixed:key:{idx}", {"value": idx}) + + async def write_usage(idx: int) -> None: + await record_usage( + RecordUsageRequest( + provider_id="test-provider", + model_id="test-model", + session_id=f"session-{idx}", + message_id=f"message-{idx}", + input_tokens=idx + 1, + output_tokens=idx + 2, + ) + ) + + async def write_task_execution(idx: int) -> None: + execution = TaskExecution( + scheduler_id=scheduler.id, + title=f"execution-{idx}", + status=TaskStatus.QUEUED, + ) + await TaskStore.create_execution(execution) + await TaskStore.enqueue_execution_ref(execution.id) + + await asyncio.gather( + *[ + asyncio.gather( + write_storage(idx), + write_usage(idx), + write_task_execution(idx), + ) + for idx in range(5) + ] + ) + + keys = await Storage.list_keys(prefix="mixed:key:") + usage_records = await get_usage_records() + executions, total = await TaskStore.list_executions(scheduler_id=scheduler.id, limit=20) + + assert len(keys) == 5 + assert len(usage_records) == 5 + assert total == 5 + assert len(executions) == 5 diff --git a/tests/updater/test_updater.py b/tests/updater/test_updater.py index 43f7fbf5..a17eb4e5 100644 --- a/tests/updater/test_updater.py +++ b/tests/updater/test_updater.py @@ -215,6 +215,67 @@ def fake_find(name: str) -> str | None: assert updater._resolve_npm_executable() == r"C:\Program Files\nodejs\npm.cmd" +def test_resolve_frontend_npm_candidates_adds_system_fallback_on_windows( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + node_home = tmp_path / "tools" / "node" + node_home.mkdir(parents=True) + (node_home / "node.exe").write_text("", encoding="utf-8") + bundled_npm = node_home / "npm.cmd" + bundled_npm.write_text("", encoding="utf-8") + + monkeypatch.setattr(updater.sys, "platform", "win32") + monkeypatch.setenv("FLOCKS_NODE_HOME", str(node_home)) + 
monkeypatch.delenv("FLOCKS_INSTALL_ROOT", raising=False) + monkeypatch.setenv("PATH", r"C:\Windows\System32") + + def fake_find(name: str) -> str | None: + if name == "npm.cmd": + return r"C:\Program Files\nodejs\npm.cmd" + return None + + monkeypatch.setattr(updater, "_find_executable", fake_find) + + candidates = updater._resolve_frontend_npm_candidates( + npm_registry="https://registry.npmmirror.com/" + ) + + assert [candidate.npm for candidate in candidates] == [ + str(bundled_npm), + r"C:\Program Files\nodejs\npm.cmd", + ] + assert candidates[0].env is not None + assert candidates[0].env["PATH"].split(os.pathsep)[0] == str(node_home) + assert candidates[1].env == { + "npm_config_registry": "https://registry.npmmirror.com/" + } + + +def test_resolve_frontend_npm_candidates_keeps_single_candidate_off_windows( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + node_home = tmp_path / "tools" / "node" + node_bin = node_home / "bin" + node_bin.mkdir(parents=True) + (node_bin / "node").write_text("", encoding="utf-8") + bundled_npm = node_bin / "npm" + bundled_npm.write_text("", encoding="utf-8") + + monkeypatch.setattr(updater.sys, "platform", "linux") + monkeypatch.setenv("FLOCKS_NODE_HOME", str(node_home)) + monkeypatch.delenv("FLOCKS_INSTALL_ROOT", raising=False) + monkeypatch.setattr(updater, "_find_executable", lambda name: f"/usr/bin/{name}") + + candidates = updater._resolve_frontend_npm_candidates( + npm_registry="https://registry.npmmirror.com/" + ) + + assert [candidate.npm for candidate in candidates] == [str(bundled_npm)] + assert candidates[0].source == "bundled" + + def test_find_executable_ignores_wsl_mnt_paths( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -1609,7 +1670,6 @@ async def fake_validate_windows_restart_runtime(*_args, **_kwargs): monkeypatch.setattr(updater, "_backup_current_version", lambda *_args, **_kwargs: tmp_path / "backup.tar.gz") monkeypatch.setattr(updater, "_extract_archive", lambda *_args, **_kwargs: 
staged_root) monkeypatch.setattr(updater, "_run_async", fake_run_async) - monkeypatch.setattr(updater, "_resolve_npm_executable", lambda: str(bundled_npm)) monkeypatch.setattr(updater, "_find_executable", lambda name: r"C:\Users\flocks\AppData\Local\Programs\Flocks\tools\uv\uv.exe" if name == "uv" else None) monkeypatch.setattr(updater, "_build_uv_sync_env", lambda: None) monkeypatch.setattr(updater, "_validate_windows_restart_runtime", fake_validate_windows_restart_runtime) @@ -1637,6 +1697,217 @@ async def fake_validate_windows_restart_runtime(*_args, **_kwargs): assert build_env["PATH"].split(os.pathsep)[0] == str(node_home) +@pytest.mark.asyncio +async def test_perform_update_retries_windows_frontend_with_system_npm_after_bundled_build_failure( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + archive_path = tmp_path / "flocks.zip" + archive_path.write_text("archive", encoding="utf-8") + staged_root = tmp_path / "staged" + staged_webui = staged_root / "webui" + staged_webui.mkdir(parents=True) + (staged_webui / "package.json").write_text("{}", encoding="utf-8") + + node_home = tmp_path / "tools" / "node" + node_home.mkdir(parents=True) + (node_home / "node.exe").write_text("", encoding="utf-8") + bundled_npm = node_home / "npm.cmd" + bundled_npm.write_text("", encoding="utf-8") + system_npm = r"C:\Program Files\nodejs\npm.cmd" + + run_calls: list[tuple[list[str], int | None, dict[str, str] | None]] = [] + + async def fake_get_updater_config(): + return SimpleNamespace( + archive_format="zip", + sources=["github"], + repo="AgentFlocks/Flocks", + token=None, + gitee_token=None, + backup_retain_count=3, + base_url=None, + gitee_repo=None, + ) + + async def fake_download_with_fallback(**_kwargs): + return archive_path + + async def fake_run_async(cmd, cwd=None, timeout=None, env=None): + run_calls.append((list(cmd), timeout, env)) + if cmd == [str(bundled_npm), "install"]: + bundled_modules = staged_webui / "node_modules" / "@esbuild" + 
bundled_modules.mkdir(parents=True, exist_ok=True) + (bundled_modules / "bundled.txt").write_text("bundled", encoding="utf-8") + return 0, "", "" + if cmd == [str(bundled_npm), "run", "build"]: + stale_dist = staged_webui / "dist" + stale_dist.mkdir(exist_ok=True) + (stale_dist / "stale.txt").write_text("stale", encoding="utf-8") + return 1, "", "bundled build failed" + if cmd == [system_npm, "install"]: + assert not (staged_webui / "node_modules").exists() + assert not (staged_webui / "dist").exists() + return 0, "", "" + if cmd == [system_npm, "run", "build"]: + dist_dir = staged_webui / "dist" + dist_dir.mkdir(exist_ok=True) + (dist_dir / "index.html").write_text("", encoding="utf-8") + return 0, "", "" + if "sync" in cmd: + return 0, "", "" + raise AssertionError(f"unexpected command: {cmd}") + + async def fake_validate_windows_restart_runtime(*_args, **_kwargs): + return None + + def fake_find(name: str) -> str | None: + if name == "npm.cmd": + return system_npm + if name == "uv": + return r"C:\Users\flocks\AppData\Local\Programs\Flocks\tools\uv\uv.exe" + return None + + monkeypatch.setattr(updater, "_get_updater_config", fake_get_updater_config) + monkeypatch.setattr(updater, "_get_repo_root", lambda: tmp_path / "install-root") + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.3.31") + monkeypatch.setattr(updater, "_download_with_fallback", fake_download_with_fallback) + monkeypatch.setattr(updater, "_backup_current_version", lambda *_args, **_kwargs: tmp_path / "backup.tar.gz") + monkeypatch.setattr(updater, "_extract_archive", lambda *_args, **_kwargs: staged_root) + monkeypatch.setattr(updater, "_run_async", fake_run_async) + monkeypatch.setattr(updater, "_find_executable", fake_find) + monkeypatch.setattr(updater, "_build_uv_sync_env", lambda: None) + monkeypatch.setattr(updater, "_validate_windows_restart_runtime", fake_validate_windows_restart_runtime) + monkeypatch.setattr(updater, "_replace_install_dir", lambda *_args, **_kwargs: 
None) + monkeypatch.setattr(updater, "_write_version_marker", lambda _v: None) + monkeypatch.setattr(updater.sys, "platform", "win32") + monkeypatch.setenv("FLOCKS_NODE_HOME", str(node_home)) + monkeypatch.delenv("FLOCKS_INSTALL_ROOT", raising=False) + monkeypatch.setenv("PATH", "/usr/bin:/bin") + + progresses = [step async for step in updater.perform_update("2026.4.1", restart=False, locale="zh-CN")] + + assert progresses[-1].stage == "done" + frontend_calls = [ + call for call in run_calls if call[0][0] in {str(bundled_npm), system_npm} + ] + assert [call[0] for call in frontend_calls] == [ + [str(bundled_npm), "install"], + [str(bundled_npm), "run", "build"], + [system_npm, "install"], + [system_npm, "run", "build"], + ] + assert [call[1] for call in frontend_calls] == [300, 300, 300, 300] + bundled_install_env = frontend_calls[0][2] + bundled_build_env = frontend_calls[1][2] + system_install_env = frontend_calls[2][2] + system_build_env = frontend_calls[3][2] + assert bundled_install_env is not None + assert bundled_build_env is not None + assert bundled_install_env["PATH"].split(os.pathsep)[0] == str(node_home) + assert bundled_build_env["PATH"].split(os.pathsep)[0] == str(node_home) + assert system_install_env == {"npm_config_registry": "https://registry.npmmirror.com/"} + assert system_build_env == {"npm_config_registry": "https://registry.npmmirror.com/"} + + +@pytest.mark.asyncio +async def test_perform_update_retries_windows_frontend_with_full_timeout_after_bundled_install_timeout( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + archive_path = tmp_path / "flocks.zip" + archive_path.write_text("archive", encoding="utf-8") + staged_root = tmp_path / "staged" + staged_webui = staged_root / "webui" + staged_webui.mkdir(parents=True) + (staged_webui / "package.json").write_text("{}", encoding="utf-8") + (staged_webui / "package-lock.json").write_text("{}", encoding="utf-8") + + node_home = tmp_path / "tools" / "node" + 
node_home.mkdir(parents=True) + (node_home / "node.exe").write_text("", encoding="utf-8") + bundled_npm = node_home / "npm.cmd" + bundled_npm.write_text("", encoding="utf-8") + system_npm = r"C:\Program Files\nodejs\npm.cmd" + + run_calls: list[tuple[list[str], int | None, dict[str, str] | None]] = [] + + async def fake_get_updater_config(): + return SimpleNamespace( + archive_format="zip", + sources=["github"], + repo="AgentFlocks/Flocks", + token=None, + gitee_token=None, + backup_retain_count=3, + base_url=None, + gitee_repo=None, + ) + + async def fake_download_with_fallback(**_kwargs): + return archive_path + + async def fake_run_async(cmd, cwd=None, timeout=None, env=None): + run_calls.append((list(cmd), timeout, env)) + if cmd == [str(bundled_npm), "ci"]: + bundled_modules = staged_webui / "node_modules" / "@esbuild" + bundled_modules.mkdir(parents=True, exist_ok=True) + (bundled_modules / "bundled.txt").write_text("bundled", encoding="utf-8") + raise subprocess.TimeoutExpired(cmd=cmd, timeout=timeout) + if cmd == [system_npm, "ci"]: + assert not (staged_webui / "node_modules").exists() + assert not (staged_webui / "dist").exists() + return 0, "", "" + if cmd == [system_npm, "run", "build"]: + dist_dir = staged_webui / "dist" + dist_dir.mkdir(exist_ok=True) + (dist_dir / "index.html").write_text("", encoding="utf-8") + return 0, "", "" + if "sync" in cmd: + return 0, "", "" + raise AssertionError(f"unexpected command: {cmd}") + + async def fake_validate_windows_restart_runtime(*_args, **_kwargs): + return None + + def fake_find(name: str) -> str | None: + if name == "npm.cmd": + return system_npm + if name == "uv": + return r"C:\Users\flocks\AppData\Local\Programs\Flocks\tools\uv\uv.exe" + return None + + monkeypatch.setattr(updater, "_get_updater_config", fake_get_updater_config) + monkeypatch.setattr(updater, "_get_repo_root", lambda: tmp_path / "install-root") + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.3.31") + 
monkeypatch.setattr(updater, "_download_with_fallback", fake_download_with_fallback) + monkeypatch.setattr(updater, "_backup_current_version", lambda *_args, **_kwargs: tmp_path / "backup.tar.gz") + monkeypatch.setattr(updater, "_extract_archive", lambda *_args, **_kwargs: staged_root) + monkeypatch.setattr(updater, "_run_async", fake_run_async) + monkeypatch.setattr(updater, "_find_executable", fake_find) + monkeypatch.setattr(updater, "_build_uv_sync_env", lambda: None) + monkeypatch.setattr(updater, "_validate_windows_restart_runtime", fake_validate_windows_restart_runtime) + monkeypatch.setattr(updater, "_replace_install_dir", lambda *_args, **_kwargs: None) + monkeypatch.setattr(updater, "_write_version_marker", lambda _v: None) + monkeypatch.setattr(updater.sys, "platform", "win32") + monkeypatch.setenv("FLOCKS_NODE_HOME", str(node_home)) + monkeypatch.delenv("FLOCKS_INSTALL_ROOT", raising=False) + monkeypatch.setenv("PATH", "/usr/bin:/bin") + + progresses = [step async for step in updater.perform_update("2026.4.1", restart=False, locale="zh-CN")] + + assert progresses[-1].stage == "done" + frontend_calls = [ + call for call in run_calls if call[0][0] in {str(bundled_npm), system_npm} + ] + assert [call[0] for call in frontend_calls] == [ + [str(bundled_npm), "ci"], + [system_npm, "ci"], + [system_npm, "run", "build"], + ] + assert [call[1] for call in frontend_calls] == [300, 300, 300] + + @pytest.mark.asyncio async def test_perform_update_errors_when_uv_not_found( monkeypatch: pytest.MonkeyPatch, diff --git a/tui/tsconfig.json b/tui/tsconfig.json index 77a00812..746eed59 100644 --- a/tui/tsconfig.json +++ b/tui/tsconfig.json @@ -1,7 +1,12 @@ { "$schema": "https://json.schemastore.org/tsconfig", - "extends": "@tsconfig/bun/tsconfig.json", "compilerOptions": { + "target": "ESNext", + "module": "ESNext", + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "noEmit": true, "jsx": "preserve", "jsxImportSource": 
"@opentui/solid", "lib": ["ESNext", "DOM", "DOM.Iterable"], diff --git a/uv.lock b/uv.lock index e5ad83bf..5c8d1fd6 100644 --- a/uv.lock +++ b/uv.lock @@ -496,7 +496,7 @@ wheels = [ [[package]] name = "flocks" -version = "2026.5.9" +version = "2026.5.12" source = { editable = "." } dependencies = [ { name = "aiofiles" }, diff --git a/webui/package-lock.json b/webui/package-lock.json index 628f6e97..cb87d933 100644 --- a/webui/package-lock.json +++ b/webui/package-lock.json @@ -16,6 +16,7 @@ "i18next": "^25.8.14", "i18next-browser-languagedetector": "^8.2.1", "lucide-react": "^0.562.0", + "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", "react-i18next": "^16.5.6", @@ -6351,6 +6352,15 @@ "node": ">=6" } }, + "node_modules/qrcode.react": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/qrcode.react/-/qrcode.react-4.2.0.tgz", + "integrity": "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA==", + "license": "ISC", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/querystringify": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/querystringify/-/querystringify-2.2.0.tgz", diff --git a/webui/package.json b/webui/package.json index 44c05b6b..ee2f6e88 100644 --- a/webui/package.json +++ b/webui/package.json @@ -22,6 +22,7 @@ "i18next": "^25.8.14", "i18next-browser-languagedetector": "^8.2.1", "lucide-react": "^0.562.0", + "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", "react-i18next": "^16.5.6", diff --git a/webui/public/channel-weixin.png b/webui/public/channel-weixin.png new file mode 100644 index 00000000..916104af Binary files /dev/null and b/webui/public/channel-weixin.png differ diff --git a/webui/src/components/common/SessionChat.tsx b/webui/src/components/common/SessionChat.tsx index 517beb82..2d1adc78 100644 --- a/webui/src/components/common/SessionChat.tsx +++ 
b/webui/src/components/common/SessionChat.tsx @@ -521,7 +521,7 @@ export default function SessionChat({ const hasUserMessage = useMemo(() => messages.some((m) => m.role === 'user'), [messages]); - const sseEnabled = live || isStreaming || !hideInput; + const sseEnabled = Boolean(sessionId) && (live || isStreaming || !hideInput); const handleSSEEvent = useCallback( (event: SSEChatEvent) => { diff --git a/webui/src/hooks/useTasks.ts b/webui/src/hooks/useTasks.ts index 5f94e5e4..d8189ce4 100644 --- a/webui/src/hooks/useTasks.ts +++ b/webui/src/hooks/useTasks.ts @@ -22,10 +22,11 @@ export function useTaskSchedulers( const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const tasksRef = useRef([]); + const initializedRef = useRef(false); const fetchTasks = useCallback(async () => { try { - setLoading(true); + if (!initializedRef.current) setLoading(true); setError(null); const response = await taskAPI.listSchedulers(filters); const data = response.data; @@ -39,6 +40,7 @@ export function useTaskSchedulers( setTotal(0); } finally { setLoading(false); + initializedRef.current = true; } }, [ filters?.status, @@ -87,10 +89,11 @@ export function useTaskExecutions( const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const tasksRef = useRef([]); + const initializedRef = useRef(false); const fetchTasks = useCallback(async () => { try { - setLoading(true); + if (!initializedRef.current) setLoading(true); setError(null); const response = await taskAPI.listExecutions(filters); const data = response.data; @@ -104,6 +107,7 @@ export function useTaskExecutions( setTotal(0); } finally { setLoading(false); + initializedRef.current = true; } }, [ filters?.status, @@ -173,10 +177,11 @@ export function useTaskDashboard(options?: { pollInterval?: number }) { const [counts, setCounts] = useState(null); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); + const initializedRef = 
useRef(false); const fetchDashboard = useCallback(async () => { try { - setLoading(true); + if (!initializedRef.current) setLoading(true); setError(null); const response = await taskAPI.dashboard(); setCounts(response.data); @@ -184,6 +189,7 @@ export function useTaskDashboard(options?: { pollInterval?: number }) { setError(err.message || 'Failed to fetch dashboard'); } finally { setLoading(false); + initializedRef.current = true; } }, []); @@ -232,10 +238,11 @@ export function useQueueStatus(options?: { pollInterval?: number }) { const [queueStatus, setQueueStatus] = useState(null); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); + const initializedRef = useRef(false); const fetchQueueStatus = useCallback(async () => { try { - setLoading(true); + if (!initializedRef.current) setLoading(true); setError(null); const response = await taskAPI.queueStatus(); setQueueStatus(response.data); @@ -243,6 +250,7 @@ export function useQueueStatus(options?: { pollInterval?: number }) { setError(err.message || 'Failed to fetch queue status'); } finally { setLoading(false); + initializedRef.current = true; } }, []); @@ -263,10 +271,11 @@ export function useTaskSystemNotice() { const [notice, setNotice] = useState(null); const [loading, setLoading] = useState(false); const [error, setError] = useState(null); + const initializedRef = useRef(false); const fetchNotice = useCallback(async () => { try { - setLoading(true); + if (!initializedRef.current) setLoading(true); setError(null); const response = await taskAPI.getSystemNotice(); setNotice(response.data ?? 
null); @@ -274,6 +283,7 @@ export function useTaskSystemNotice() { setError(err.message || 'Failed to fetch system notice'); } finally { setLoading(false); + initializedRef.current = true; } }, []); diff --git a/webui/src/hooks/useTools.test.tsx b/webui/src/hooks/useTools.test.tsx new file mode 100644 index 00000000..ac6fd9e7 --- /dev/null +++ b/webui/src/hooks/useTools.test.tsx @@ -0,0 +1,64 @@ +import { renderHook, waitFor } from '@testing-library/react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { useTools } from './useTools'; + +const { listMock, refreshMock } = vi.hoisted(() => ({ + listMock: vi.fn(), + refreshMock: vi.fn(), +})); + +vi.mock('@/api/tool', () => ({ + toolAPI: { + list: listMock, + refresh: refreshMock, + }, +})); + +function deferred() { + let resolve!: (value: T | PromiseLike) => void; + const promise = new Promise((res) => { + resolve = res; + }); + return { promise, resolve }; +} + +describe('useTools', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('renders the tool list before the background refresh completes', async () => { + const refreshDeferred = deferred<{ data: { status: string } }>(); + + listMock.mockResolvedValue({ + data: [ + { + name: 'tool-alpha', + description: 'alpha tool', + category: 'custom', + source: 'custom', + enabled: true, + }, + ], + }); + refreshMock.mockReturnValue(refreshDeferred.promise); + + const { result } = renderHook(() => useTools()); + + await waitFor(() => { + expect(result.current.loading).toBe(false); + }); + + expect(result.current.tools).toHaveLength(1); + expect(result.current.tools[0].name).toBe('tool-alpha'); + expect(listMock).toHaveBeenCalledTimes(1); + expect(refreshMock).toHaveBeenCalledTimes(1); + + refreshDeferred.resolve({ data: { status: 'success' } }); + + await waitFor(() => { + expect(listMock).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/webui/src/hooks/useTools.ts b/webui/src/hooks/useTools.ts index 8cebfc84..f4119b2f 100644 
--- a/webui/src/hooks/useTools.ts +++ b/webui/src/hooks/useTools.ts @@ -6,17 +6,19 @@ export function useTools() { const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const lastRefreshRef = useRef(0); + const initializedRef = useRef(false); const fetchTools = useCallback(async (showLoading = false) => { try { - if (showLoading) setLoading(true); + if (showLoading && !initializedRef.current) setLoading(true); setError(null); const response = await toolAPI.list(); setTools(Array.isArray(response.data) ? response.data : []); } catch (err: any) { setError(err.message || 'Failed to fetch tools'); } finally { - if (showLoading) setLoading(false); + if (showLoading && !initializedRef.current) setLoading(false); + initializedRef.current = true; } }, []); @@ -31,18 +33,34 @@ export function useTools() { }, [fetchTools]); useEffect(() => { + let cancelled = false; + const init = async () => { - try { await toolAPI.refresh(); } catch { /* ignore */ } - lastRefreshRef.current = Date.now(); await fetchTools(true); + if (cancelled) return; + + try { + await toolAPI.refresh(); + if (cancelled) return; + lastRefreshRef.current = Date.now(); + await fetchTools(false); + } catch { + /* ignore */ + } }; - init(); + + void init(); const onVisible = () => { - if (document.visibilityState === 'visible') refreshAndFetch(); + if (document.visibilityState === 'visible') { + void refreshAndFetch(); + } }; document.addEventListener('visibilitychange', onVisible); - return () => document.removeEventListener('visibilitychange', onVisible); + return () => { + cancelled = true; + document.removeEventListener('visibilitychange', onVisible); + }; }, [fetchTools, refreshAndFetch]); return { diff --git a/webui/src/locales/en-US/channel.json b/webui/src/locales/en-US/channel.json index 5a02ae3a..fa058f8f 100644 --- a/webui/src/locales/en-US/channel.json +++ b/webui/src/locales/en-US/channel.json @@ -11,7 +11,8 @@ "feishu": "Feishu", "wecom": "WeCom", "dingtalk": 
"DingTalk", - "telegram": "Telegram" + "telegram": "Telegram", + "weixin": "WeChat" }, "status": { "enabled": "Enabled", @@ -232,5 +233,57 @@ "dedupTtlSeconds": "Dedup TTL (seconds)", "dedupTtlSecondsHint": "Retention time for message deduplication records (default 86400s = 24h)", "optional": "Optional" + }, + "weixin": { + "enableTitle": "Enable WeChat Channel", + "enableDesc": "When enabled, the bot connects to WeChat personal accounts via the iLink Bot long-poll API.", + "credentials": "Account Credentials", + "credentialsDesc": "Obtain Token and Account ID after scanning the iLink Bot QR code.", + "tokenHint": "iLink Bot Token (obtained after QR login)", + "accountIdHint": "iLink Bot Account ID (obtained after QR login, format: xxx@im.bot)", + "baseUrl": "iLink API Base URL", + "baseUrlHint": "Custom iLink API base URL (leave empty to use default https://ilinkai.weixin.qq.com)", + "optional": "Optional", + "behavior": "Message Behavior", + "behaviorDesc": "Configure message routing and DM policy.", + "defaultAgent": "Default Agent", + "defaultAgentHint": "Default Agent ID used when no Agent is specified", + "dmPolicy": "DM Policy", + "dmPolicyHint": "Controls which users' direct messages can trigger the Agent", + "dmPolicyOpen": "Open (everyone allowed)", + "dmPolicyAllowlist": "Allowlist (only listed users)", + "dmPolicyDisabled": "Disabled (no DMs accepted)", + "allowFrom": "Allowed User IDs", + "allowFromHint": "In allowlist mode, only these WeChat user_ids can trigger the Agent", + "allowFromPlaceholder": "Enter WeChat user_id and press Enter", + "groupPolicy": "Group Policy", + "groupPolicyHint": "Controls whether group chat messages trigger the Agent", + "groupPolicyAll": "All (process all group messages)", + "groupPolicyAllowlist": "Allowlist (allowed groups only)", + "groupPolicyDisabled": "Disabled (ignore group messages)", + "groupAllowFrom": "Allowed Group IDs", + "groupAllowFromHint": "In allowlist mode, only messages from these group / room IDs 
trigger the Agent", + "groupAllowFromPlaceholder": "Enter group ID and press Enter", + "advanced": "Advanced Settings", + "advancedDesc": "Message chunk delay and state storage directory.", + "sendChunkDelay": "Chunk Send Delay (s)", + "sendChunkDelayHint": "Delay in seconds between multi-chunk messages (default 1.5s)", + "dataDir": "State Storage Directory", + "dataDirHint": "Directory for sync_buf and context-token state files (leave empty to use ~/.flocks/workspace/channels/weixin)", + "qrLoginButton": "Connect via QR Code", + "qrLoading": "Fetching QR code…", + "qrAlreadyLinked": "Account already linked — scan again to replace credentials", + "qrModalTitle": "WeChat QR Login", + "qrHintScanning": "Scan the QR code above with WeChat", + "qrHintScaned": "Scanned — please tap Confirm on your phone", + "qrHintConfirmed": "Login successful — credentials auto-filled", + "qrScaned": "Scanned, confirm on phone", + "qrConfirmed": "Login successful!", + "qrExpired": "QR code expired — please refresh", + "qrRefresh": "Refresh QR Code", + "qrRetry": "Retry", + "qrError": "Failed to fetch QR code — check your network connection", + "qrDone": "Done", + "qrSuccess": "WeChat account connected — credentials auto-filled" } } diff --git a/webui/src/locales/zh-CN/channel.json b/webui/src/locales/zh-CN/channel.json index a7e25b18..fbd4c2cb 100644 --- a/webui/src/locales/zh-CN/channel.json +++ b/webui/src/locales/zh-CN/channel.json @@ -11,7 +11,8 @@ "feishu": "飞书", "wecom": "企业微信", "dingtalk": "钉钉", - "telegram": "Telegram" + "telegram": "Telegram", + "weixin": "微信" }, "status": { "enabled": "已启用", @@ -234,5 +235,57 @@ "dedupTtlSeconds": "去重 TTL (秒)", "dedupTtlSecondsHint": "消息去重记录的保留时间(默认 86400 秒 = 24 小时)", "optional": "选填" + }, + "weixin": { + "enableTitle": "启用微信通道", + "enableDesc": "启用后,机器人将通过 iLink Bot 长轮询接入微信,并开始接收和回复消息。", + "credentials": "账号凭证", + "credentialsDesc": "通过微信 iLink Bot QR 扫码登录后获取 Token 和 Account ID。", + "tokenHint": "iLink Bot Token(扫码登录后获取)", + "accountIdHint": 
"iLink Bot Account ID(扫码登录后获取,格式如 xxx@im.bot)", + "baseUrl": "iLink API 地址", + "baseUrlHint": "自定义 iLink API 地址(留空使用默认 https://ilinkai.weixin.qq.com)", + "optional": "选填", + "behavior": "消息行为", + "behaviorDesc": "配置消息路由和私信接收策略。", + "defaultAgent": "默认 Agent", + "defaultAgentHint": "未指定 Agent 时使用的默认 Agent ID", + "dmPolicy": "私信策略", + "dmPolicyHint": "控制哪些用户的私信可以触发 Agent", + "dmPolicyOpen": "开放(所有人均可)", + "dmPolicyAllowlist": "白名单(仅允许名单内用户)", + "dmPolicyDisabled": "关闭(不接受私信)", + "allowFrom": "允许的用户 ID", + "allowFromHint": "白名单模式下,仅允许这些微信 user_id 触发 Agent", + "allowFromPlaceholder": "输入微信 user_id 后按回车添加", + "groupPolicy": "群聊策略", + "groupPolicyHint": "控制群聊消息是否触发 Agent", + "groupPolicyAll": "全部(处理所有群消息)", + "groupPolicyAllowlist": "白名单(仅允许名单内群组)", + "groupPolicyDisabled": "关闭(不处理群消息)", + "groupAllowFrom": "允许的群 ID", + "groupAllowFromHint": "白名单模式下,仅允许这些群 / 房间 ID 触发 Agent", + "groupAllowFromPlaceholder": "输入群 ID 后按回车添加", + "advanced": "高级设置", + "advancedDesc": "消息分块发送间隔及状态存储目录。", + "sendChunkDelay": "分块发送间隔 (秒)", + "sendChunkDelayHint": "多段消息之间的发送间隔秒数(默认 1.5 秒)", + "dataDir": "状态存储目录", + "dataDirHint": "sync_buf 与 context-token 等状态文件存储目录(留空使用默认 ~/.flocks/workspace/channels/weixin)", + "qrLoginButton": "扫码登录微信", + "qrLoading": "正在获取二维码…", + "qrAlreadyLinked": "已连接账号,可重新扫码替换", + "qrModalTitle": "微信扫码登录", + "qrHintScanning": "请用微信扫描上方二维码", + "qrHintScaned": "扫码成功,请在手机上点击「确认登录」", + "qrHintConfirmed": "登录成功,凭证已自动填入", + "qrScaned": "已扫码,请确认", + "qrConfirmed": "登录成功!", + "qrExpired": "二维码已过期,请点击刷新", + "qrRefresh": "刷新二维码", + "qrRetry": "重试", + "qrError": "获取二维码失败,请检查网络连接", + "qrDone": "完成", + "qrSuccess": "微信账号连接成功,凭证已自动填入" } } diff --git a/webui/src/pages/Channel/index.tsx b/webui/src/pages/Channel/index.tsx index ef246969..09ca07a6 100644 --- a/webui/src/pages/Channel/index.tsx +++ b/webui/src/pages/Channel/index.tsx @@ -1,4 +1,5 @@ import { useState, useEffect, useCallback, useRef } from 'react'; +import { QRCodeSVG } from 'qrcode.react'; import { Radio, Save, @@ -128,7 
+129,22 @@ interface TelegramChannelConfig { streamingCoalesceMs?: number; } -type ChannelConfig = FeishuChannelConfig | WeComChannelConfig | DingTalkChannelConfig | TelegramChannelConfig; +interface WeixinChannelConfig { + enabled: boolean; + token?: string; + accountId?: string; + baseUrl?: string; + cdnBaseUrl?: string; + defaultAgent?: string; + dmPolicy?: string; + allowFrom?: string[]; + groupPolicy?: string; + groupAllowFrom?: string[]; + sendChunkDelay?: number; + dataDir?: string; +} + +type ChannelConfig = FeishuChannelConfig | WeComChannelConfig | DingTalkChannelConfig | TelegramChannelConfig | WeixinChannelConfig; function defaultFeishuConfig(): FeishuChannelConfig { return { @@ -175,6 +191,15 @@ function defaultTelegramConfig(): TelegramChannelConfig { }; } +function defaultWeixinConfig(): WeixinChannelConfig { + return { + enabled: false, + dmPolicy: 'open', + groupPolicy: 'all', + sendChunkDelay: 1.5, + }; +} + // ============================================================================ // Form primitives // ============================================================================ @@ -430,6 +455,7 @@ const CHANNEL_ICON_SRC: Record = { wecom: '/channel-wecom.png', dingtalk: '/channel-dingtalk.png', telegram: '/channel-telegram.png', + weixin: '/channel-weixin.png', }; const FEISHU_GUIDE_PDF_URL = '/feishu-bot-guide.pdf'; @@ -616,6 +642,7 @@ function ConnectionStatusPanel({ status, config, channelId }: ConnectionStatusPa {channelId === 'feishu' && 'WebSocket'} {channelId === 'wecom' && 'WebSocket'} {channelId === 'dingtalk' && 'Stream'} + {channelId === 'weixin' && 'Long-Poll'} {channelId === 'telegram' && ((config as TelegramChannelConfig).mode === 'webhook' ? 
'Webhook' : 'Polling')} @@ -1369,6 +1396,354 @@ function TelegramPanel({ config, onChange, onRefresh }: TelegramPanelProps) { ); } +// ============================================================================ +// Weixin Config Panel +// ============================================================================ + +interface WeixinPanelProps { + config: WeixinChannelConfig; + onChange: (c: WeixinChannelConfig) => void; + /** Persist QR-obtained credentials to flocks.json + restart the channel. + * Called automatically when the QR login flow completes. */ + onQrLoginSuccess?: (creds: { token: string; accountId: string; baseUrl?: string }) => Promise | void; +} + +type QrPhase = + | 'idle' // initial / closed + | 'loading' // fetching QR from backend + | 'scanning' // QR shown, waiting for phone scan + | 'scaned' // phone scanned, waiting for confirmation tap + | 'confirmed' // login complete — credentials filled + | 'expired' // QR expired, allow restart + | 'error'; // network / API error + +function WeixinPanel({ config, onChange, onQrLoginSuccess }: WeixinPanelProps) { + const { t } = useTranslation('channel'); + const toast = useToast(); + const set = useCallback( + (key: K, value: WeixinChannelConfig[K]) => + onChange({ ...config, [key]: value }), + [config, onChange] + ); + + // ── QR login state ────────────────────────────────────────────────────── + const [qrPhase, setQrPhase] = useState('idle'); + const [qrUrl, setQrUrl] = useState(''); // URL to encode into QR SVG + const [qrValue, setQrValue] = useState(''); // hex token used for polling + const [qrError, setQrError] = useState(''); + const pollRef = useRef | null>(null); + // Guard: multiple in-flight requests may all resolve with "confirmed". + // Only the first one should act; the rest are no-ops. + const confirmedRef = useRef(false); + // Tracks the current polling base_url; may change on scaned_but_redirect. 
+ const currentBaseUrlRef = useRef(undefined); + + const stopPolling = () => { + if (pollRef.current) { + clearInterval(pollRef.current); + pollRef.current = null; + } + }; + + // Cleanup on unmount + useEffect(() => () => stopPolling(), []); + + const startQrLogin = async () => { + stopPolling(); + confirmedRef.current = false; + currentBaseUrlRef.current = config.baseUrl?.trim() || undefined; + setQrError(''); + setQrPhase('loading'); + try { + const baseUrl = config.baseUrl?.trim() || undefined; + const res = await client.post('/api/channel/weixin/qr-login/start', { baseUrl: baseUrl ?? null }); + const { qrcode_value, qrcode_url } = res.data; + setQrValue(qrcode_value); + setQrUrl(qrcode_url); + setQrPhase('scanning'); + + // Poll status every 2 s. + // NOTE: each tick is an async call; multiple ticks can be in-flight + // simultaneously. confirmedRef prevents duplicate side-effects. + // currentBaseUrlRef tracks regional redirects (scaned_but_redirect). + pollRef.current = setInterval(async () => { + try { + const statusRes = await client.get('/api/channel/weixin/qr-login/status', { + params: { qrcode: qrcode_value, baseUrl: currentBaseUrlRef.current ?? undefined }, + }); + const { status, account_id, token, base_url, redirect_base_url } = statusRes.data; + if (status === 'scaned') { + setQrPhase('scaned'); + } else if (status === 'redirect') { + // iLink is routing this account to a different regional node. + // Update base_url so subsequent polls hit the correct host. + if (redirect_base_url) currentBaseUrlRef.current = redirect_base_url; + setQrPhase('scaned'); + } else if (status === 'confirmed') { + if (confirmedRef.current) return; // already handled + confirmedRef.current = true; + stopPolling(); + setQrPhase('confirmed'); + // Auto-fill credentials including the canonical base_url for this + // account — it may differ from the default when iLink redirected. 
+ const newConfig: WeixinChannelConfig = { + ...config, + accountId: account_id, + token, + ...(base_url ? { baseUrl: base_url } : {}), + }; + onChange(newConfig); + // Persist immediately — without this the gateway keeps trying to + // start with the (still empty) on-disk config and the channel never + // actually connects to WeChat. + if (onQrLoginSuccess) { + try { + await onQrLoginSuccess({ + token, + accountId: account_id, + ...(base_url ? { baseUrl: base_url } : {}), + }); + } catch (err: any) { + toast.error(t('weixin.qrError'), err?.message ?? ''); + } + } + toast.success(t('weixin.qrSuccess')); + } else if (status === 'expired') { + stopPolling(); + setQrPhase('expired'); + } + // 'waiting' → keep polling + } catch { + // transient network error — keep polling + } + }, 2000); + } catch (err: any) { + const detail = err?.response?.data?.detail ?? err?.message ?? ''; + setQrError(detail); + setQrPhase('error'); + } + }; + + const closeQrModal = () => { + stopPolling(); + setQrPhase('idle'); + setQrUrl(''); + setQrValue(''); + setQrError(''); + }; + + const showModal = qrPhase !== 'idle'; + + return ( + <> +
+ {/* QR login launcher */} +
+ + {config.token && config.accountId && ( +

+ + {t('weixin.qrAlreadyLinked')} +

+ )} +
+ + {/* QR modal overlay */} + {showModal && ( +
+
+ {/* Close button */} + + +

{t('weixin.qrModalTitle')}

+ + {/* QR code display area */} + {qrPhase === 'loading' && ( +
+ +
+ )} + {(qrPhase === 'scanning' || qrPhase === 'scaned') && qrUrl && ( +
+
+ +
+ {qrPhase === 'scaned' && ( +
+
+ +

{t('weixin.qrScaned')}

+
+
+ )} +
+ )} + {qrPhase === 'confirmed' && ( +
+ +

{t('weixin.qrConfirmed')}

+
+ )} + {qrPhase === 'expired' && ( +
+ +

{t('weixin.qrExpired')}

+ +
+ )} + {qrPhase === 'error' && ( +
+ +

{qrError || t('weixin.qrError')}

+ +
+ )} + + {/* Status hint */} +

+ {qrPhase === 'scanning' && t('weixin.qrHintScanning')} + {qrPhase === 'scaned' && t('weixin.qrHintScaned')} + {qrPhase === 'confirmed' && t('weixin.qrHintConfirmed')} + {qrPhase === 'expired' && ''} + {qrPhase === 'error' && ''} +

+ + {qrPhase === 'confirmed' && ( + + )} +
+
+ )} + +
+ + + set('token', v || undefined)} + placeholder="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + /> + + + set('accountId', v || undefined)} + placeholder="xxxxxxxxxxxxxxxxx@im.bot" + /> + + + set('baseUrl', v || undefined)} + placeholder={t('weixin.optional')} + /> + +
+ +
+ + set('defaultAgent', v || undefined)} + placeholder={t('weixin.optional')} + /> + + + set('groupPolicy', v)} + options={[ + { value: 'all', label: t('weixin.groupPolicyAll') }, + { value: 'allowlist', label: t('weixin.groupPolicyAllowlist') }, + { value: 'disabled', label: t('weixin.groupPolicyDisabled') }, + ]} + /> + + {(config.groupPolicy ?? 'all') === 'allowlist' && ( + + set('groupAllowFrom', v.length ? v : undefined)} + placeholder={t('weixin.groupAllowFromPlaceholder')} + /> + + )} +
+ +
+ + set('sendChunkDelay', v)} + min={0} + /> + + + set('dataDir', v || undefined)} + placeholder={t('weixin.optional')} + /> + +
+ + ); +} + // ============================================================================ // Detail Panel Header // ============================================================================ @@ -1581,6 +1956,8 @@ export default function ChannelPage() { configs[ch.id] = { ...defaultDingTalkConfig(), ...saved }; } else if (ch.id === 'telegram') { configs[ch.id] = { ...defaultTelegramConfig(), ...saved }; + } else if (ch.id === 'weixin') { + configs[ch.id] = { ...defaultWeixinConfig(), ...saved }; } else { configs[ch.id] = { enabled: false, ...saved }; } @@ -1669,6 +2046,50 @@ export default function ChannelPage() { } }; + // Persist credentials obtained via WeChat QR login + auto-enable + restart. + // The user explicitly initiated the QR scan, so we treat that as consent to + // enable the channel — no extra "save & enable" click required. + // Mirrors handleToggleEnabled's single-field update pattern so that any + // other unsaved channel edits are not flushed prematurely. + const handleWeixinQrSuccess = async ( + creds: { token: string; accountId: string; baseUrl?: string } + ) => { + const channelId = 'weixin'; + const savedChannelCfg = (fullConfig.channels?.[channelId] ?? {}) as Record; + const updatedChannelCfg: Record = { + ...savedChannelCfg, + enabled: true, + token: creds.token, + accountId: creds.accountId, + }; + if (creds.baseUrl) updatedChannelCfg.baseUrl = creds.baseUrl; + + const updatedChannels = { ...(fullConfig.channels ?? {}), [channelId]: updatedChannelCfg }; + const updated = { ...fullConfig, channels: updatedChannels }; + + await client.patch('/api/config/', updated); + setFullConfig(updated); + + // Sync the in-memory editor state so the UI immediately reflects the + // newly-saved values (token + accountId fields, enabled toggle, baseUrl). 
+ setChannelConfigs((prev) => ({ + ...prev, + [channelId]: { ...prev[channelId], ...updatedChannelCfg } as ChannelConfig, + })); + originalConfigsRef.current = { + ...originalConfigsRef.current, + [channelId]: { ...originalConfigsRef.current[channelId], ...updatedChannelCfg }, + }; + + // Restart the channel so the new credentials take effect immediately. + // Fire-and-forget — server may take time to disconnect WebSocket. + client.post(`/api/channel/${channelId}/restart`, {}, { timeout: 5000 }).catch(() => {}); + + // Sync UI state after the connection has had time to come up. + setTimeout(() => { fetchAll(); fetchStatuses(true); }, 3000); + setTimeout(() => { fetchAll(); fetchStatuses(true); }, 8000); + }; + // Manual restart — useful when connection drops and user wants to reconnect const handleRestart = async (channelId?: string) => { const id = channelId ?? selectedId; @@ -1860,6 +2281,13 @@ export default function ChannelPage() { onRefresh={fetchAll} /> )} + {selectedId === 'weixin' && ( + handleChannelConfigChange('weixin', cfg)} + onQrLoginSuccess={handleWeixinQrSuccess} + /> + )} ) : ( diff --git a/webui/src/pages/Session/index.test.tsx b/webui/src/pages/Session/index.test.tsx index e48265f4..d6e9c72b 100644 --- a/webui/src/pages/Session/index.test.tsx +++ b/webui/src/pages/Session/index.test.tsx @@ -264,6 +264,12 @@ describe('SessionPage session actions menu', () => { expect(global.confirm).toHaveBeenCalledWith('confirmDelete'); }); + it('does not auto-select the first session on initial load', () => { + renderSessionPage(); + + expect(screen.getByTestId('session-chat')).toHaveTextContent('no-session'); + }); + it('syncs selected session when query param changes after mount', async () => { const user = userEvent.setup(); @@ -294,9 +300,7 @@ describe('SessionPage session actions menu', () => { , ); - await waitFor(() => { - expect(screen.getByTestId('session-chat')).toHaveTextContent('session-1'); - }); + 
expect(screen.getByTestId('session-chat')).toHaveTextContent('no-session'); await user.click(screen.getByRole('button', { name: 'go-session-2' })); diff --git a/webui/src/pages/Session/index.tsx b/webui/src/pages/Session/index.tsx index 8f32fbfa..dacc354d 100644 --- a/webui/src/pages/Session/index.tsx +++ b/webui/src/pages/Session/index.tsx @@ -99,13 +99,6 @@ export default function SessionPage() { } }, [searchParams, selectedSessionId, setSearchParams]); - // Auto select first session - useEffect(() => { - if (!selectedSessionId && sessions.length > 0) { - setSelectedSessionId(sessions[0].id); - } - }, [sessions, selectedSessionId]); - // Close agent dropdown on outside click useEffect(() => { if (!showAgentOptions) return; @@ -617,7 +610,7 @@ export default function SessionPage() { { refetchDashboard(); - refetchQueue(); }; const forceRemountSections = () => { diff --git a/webui/src/types/index.ts b/webui/src/types/index.ts index 6e8d3420..dcb6fb35 100644 --- a/webui/src/types/index.ts +++ b/webui/src/types/index.ts @@ -18,6 +18,7 @@ export interface Session { /** Session category: 'user' | 'workflow' | 'task' | 'entity-config' | ... */ category?: string; ownerUserID?: string; + ownerUsername?: string; canDelete?: boolean; }