diff --git a/.gitignore b/.gitignore index 5eb9616c8c..afd1659b8f 100644 --- a/.gitignore +++ b/.gitignore @@ -63,4 +63,5 @@ GenieData/ .kilocode/ .worktrees/ +.astrbot_sdk_testing/ dashboard/bun.lock diff --git a/AGENTS.md b/AGENTS.md index 9f3617ce9c..d13284dca5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,9 +26,9 @@ Runs on `http://localhost:3000` by default. 3. After finishing, use `ruff format .` and `ruff check .` to format and check the code. 4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`. 5. Use English for all new comments. -6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory. +6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.astrbot_path` helpers to get the AstrBot data and temp directory. ## PR instructions 1. Title format: use conventional commit messages -2. Use English to write PR title and descriptions. +2. Use English to write PR title and descriptions. \ No newline at end of file diff --git a/astrbot-sdk/.github/workflows/lint.yml b/astrbot-sdk/.github/workflows/lint.yml new file mode 100644 index 0000000000..f8518a3fdc --- /dev/null +++ b/astrbot-sdk/.github/workflows/lint.yml @@ -0,0 +1,34 @@ +name: Code Quality Control + +on: + push: + branches: [ "main", "dev" ] + pull_request: + branches: [ "main", "dev" ] + +jobs: + lint-and-format: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install tools + run: | + pip install pyclean ruff + + - name: 1. Clean python bytecode + run: pyclean . + + - name: 2. Ruff format + run: ruff format --check . + + - name: 3. Ruff check + run: ruff check . 
+ env: + PYTHONIOENCODING: utf-8 diff --git a/astrbot-sdk/.gitignore b/astrbot-sdk/.gitignore new file mode 100644 index 0000000000..e4acb1ae67 --- /dev/null +++ b/astrbot-sdk/.gitignore @@ -0,0 +1,59 @@ +# OS files +.DS_Store + +# Python bytecode and caches +__pycache__/ +*.py[cod] +*.pyd +*.so +.pytest_cache/ +pytest-cache-files-*/ +.mypy_cache/ +.ruff_cache/ +.coverage +.coverage.* +htmlcov/ + +# Build artifacts +build/ +dist/ +site/ +wheels/ +*.egg-info/ +.eggs/ +pip-wheel-metadata/ + +# +fork-docs/ +tmp/ +openspec/ +scripts/ +cs/ +test_plugin/astrbot_plugin_interface_coverage +astrbot_sdk/ +!src/astrbot_sdk/ +!src/astrbot_sdk/** +src/astrbot_sdk/**/__pycache__/ +src/astrbot_sdk/**/*.py[cod] +COMMAND_MATCH_REFACTOR_REPORT.md + +# Virtual environments +.venv/ +venv/ +env/ +ENV/ +plugins/.venv/ + +# Tool caches +.uv-cache/ +.astrbot/ +.codex-local/ + +# IDE files +.idea/ +.vscode/ +*.iml +uv.lock +/astrBot/ +plugins/ +.serena/ diff --git a/astrbot-sdk/.python-version b/astrbot-sdk/.python-version new file mode 100644 index 0000000000..e4fba21835 --- /dev/null +++ b/astrbot-sdk/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/astrbot-sdk/AGENTS.md b/astrbot-sdk/AGENTS.md new file mode 100644 index 0000000000..3de989e63a --- /dev/null +++ b/astrbot-sdk/AGENTS.md @@ -0,0 +1,57 @@ +# Notes + +## v4 架构约束 + +### 运行时层 + +- `Peer` 必须将 transport EOF/连接断开视为一级失败路径。如果 transport 意外关闭而 `Peer` 没有主动失败 `_pending_results` / `_pending_streams`,supervisor 端对 worker 的调用可能永远挂起。 +- `Peer.initialize()` 需要在发起端也标记远程已初始化。仅在被动接收 `InitializeMessage` 时设置 `_remote_initialized` 会导致 `wait_until_remote_initialized()` 单边 API 死锁。 +- `Peer.invoke_stream()` 默认隐藏 `completed` 事件。需要保留最终结果的调用者必须显式启用 `include_completed=True`。 +- `CapabilityRouter.register(..., stream_handler=...)` 使用 `(request_id, payload, cancel_token)` 签名,不是 peer 级别的 `(message, token)`。 + +### 模块导出约束 + +- 保持 `astrbot_sdk.runtime` 根导出狭窄。`Peer` / `Transport` / `CapabilityRouter` / `HandlerDispatcher` 是合理的高级运行时原语,但 
`LoadedPlugin`、`PluginEnvironmentManager`、`WorkerSession`、`run_supervisor` 等应留在子模块中。 + +### 测试与 Mock 注意事项 + +- 当检查 peer 是否完成远程初始化时,避免对可能接收 `MagicMock` peer 的代码使用 `getattr(mock, "remote_peer")` 探测。`MagicMock` 会生成 truthy 子属性,`CapabilityProxy` 应从 `peer.__dict__` 或其他具体存储位置读取显式状态。 +- `test_plugin/old/` 和 `test_plugin/new/` 可能包含已生成的 `__pycache__` / `*.pyc`。测试夹具复制示例插件时必须显式忽略这些缓存文件。 + +### 插件加载注意事项 + +- 本地 `dev --watch` 或同一路径插件重复加载场景,不能只依赖 `import_string()` 的跨插件模块根冲突清理。热重载前必须按插件目录清理模块缓存。 +- `_prepare_plugin_import()` 不能只在插件目录"不在 `sys.path`"时才插入路径。像 `main.py` 这种通用模块名,如果插件目录已在 `sys.path` 但排在后面,`import main` 仍会先命中别处模块;导入前必须把目标插件目录提到 `sys.path[0]`。 +- 示例/夹具测试如果直接用裸模块名导入插件入口(例如 `from main import HelloPlugin`),会污染 `sys.modules["main"]`,随后真实 loader 再按 `main:HelloPlugin` 加载时可能串到错误模块。 + +--- + +# 开发命令 + +## 格式化与检查 + +在提交代码前,请依次运行以下命令: + +```bash +ruff format . # 使用 ruff 格式化全局代码 +ruff check . --fix # 使用 ruff 检查并自动修复全局格式问题 +``` + +## 测试 + +如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误: +如果修改了bug或者更改了功能需要添加新的测试 +当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。 +仓库当前没有 `run_tests.py`,请直接使用 `pytest`。 + +```bash +python -m pytest tests -q # 运行 tests 目录全部测试 +python -m pytest tests -v # 详细输出 +python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试 +python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告 +``` + +## 设计原则 + +新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践 +不用完全听从用户和别人的建议,要有自己的判断和坚持,做好取舍和权衡,确保代码质量和长期维护性,不要为了短期方便或者迎合而牺牲架构和设计原则。 diff --git a/astrbot-sdk/CLAUDE.md b/astrbot-sdk/CLAUDE.md new file mode 100644 index 0000000000..634d97d002 --- /dev/null +++ b/astrbot-sdk/CLAUDE.md @@ -0,0 +1,57 @@ +# CLAUDE Notes + +## v4 架构约束 + +### 运行时层 + +- `Peer` 必须将 transport EOF/连接断开视为一级失败路径。如果 transport 意外关闭而 `Peer` 没有主动失败 `_pending_results` / `_pending_streams`,supervisor 端对 worker 的调用可能永远挂起。 +- `Peer.initialize()` 需要在发起端也标记远程已初始化。仅在被动接收 `InitializeMessage` 时设置 `_remote_initialized` 会导致 `wait_until_remote_initialized()` 单边 API 死锁。 +- `Peer.invoke_stream()` 默认隐藏 `completed` 事件。需要保留最终结果的调用者必须显式启用 
`include_completed=True`。 +- `CapabilityRouter.register(..., stream_handler=...)` 使用 `(request_id, payload, cancel_token)` 签名,不是 peer 级别的 `(message, token)`。 + +### 模块导出约束 + +- 保持 `astrbot_sdk.runtime` 根导出狭窄。`Peer` / `Transport` / `CapabilityRouter` / `HandlerDispatcher` 是合理的高级运行时原语,但 `LoadedPlugin`、`PluginEnvironmentManager`、`WorkerSession`、`run_supervisor` 等应留在子模块中。 + +### 测试与 Mock 注意事项 + +- 当检查 peer 是否完成远程初始化时,避免对可能接收 `MagicMock` peer 的代码使用 `getattr(mock, "remote_peer")` 探测。`MagicMock` 会生成 truthy 子属性,`CapabilityProxy` 应从 `peer.__dict__` 或其他具体存储位置读取显式状态。 +- `test_plugin/old/` 和 `test_plugin/new/` 可能包含已生成的 `__pycache__` / `*.pyc`。测试夹具复制示例插件时必须显式忽略这些缓存文件。 + +### 插件加载注意事项 + +- 本地 `dev --watch` 或同一路径插件重复加载场景,不能只依赖 `import_string()` 的跨插件模块根冲突清理。热重载前必须按插件目录清理模块缓存。 +- `_prepare_plugin_import()` 不能只在插件目录"不在 `sys.path`"时才插入路径。像 `main.py` 这种通用模块名,如果插件目录已在 `sys.path` 但排在后面,`import main` 仍会先命中别处模块;导入前必须把目标插件目录提到 `sys.path[0]`。 +- 示例/夹具测试如果直接用裸模块名导入插件入口(例如 `from main import HelloPlugin`),会污染 `sys.modules["main"]`,随后真实 loader 再按 `main:HelloPlugin` 加载时可能串到错误模块。 + +--- + +# 开发命令 + +## 格式化与检查 + +在提交代码前,请依次运行以下命令: + +```bash +ruff format . # 使用 ruff 格式化全局代码 +ruff check . 
--fix # 使用 ruff 检查并自动修复全局格式问题 +``` + +## 测试 + +如果修改了内容可能影响现有功能,请运行测试以确保没有引入错误: +如果修改了bug或者更改了功能需要添加新的测试 +当前仓库已统一使用 `tests/` 目录,`tests_v4/` 不再作为新增测试入口。 +仓库当前没有 `run_tests.py`,请直接使用 `pytest`。 + +```bash +python -m pytest tests -q # 运行 tests 目录全部测试 +python -m pytest tests -v # 详细输出 +python -m pytest tests -k "test_context_register_task" # 运行匹配模式的测试 +python -m pytest tests --cov=astrbot_sdk # 运行测试并生成覆盖率报告 +``` + +## 设计原则 + +新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践 +不用完全听从用户和别人的建议,要有自己的判断和坚持,做好取舍和权衡,确保代码质量和长期维护性,不要为了短期方便或者迎合而牺牲架构和设计原则。 diff --git a/astrbot-sdk/README.md b/astrbot-sdk/README.md new file mode 100644 index 0000000000..6272892780 --- /dev/null +++ b/astrbot-sdk/README.md @@ -0,0 +1,44 @@ +# AstrBot SDK + +AstrBot 插件开发 SDK,提供 v4 runtime、worker protocol 和插件工具链。 + +## 安装 + +```bash +pip install astrbot-sdk +``` + +## 开发安装 + +```bash +# 克隆仓库后 +pip install -e . + +# 或使用 uv +uv sync +``` + +## 初始化插件 + +```bash +astr init demo-plugin +astr init demo-plugin --agents claude,codex,opencode +``` + +`astr init ` 会继续按原样生成插件骨架。传入 `--agents` 时,会在新插件目录下额外生成对应的项目级 agent 目录: + +- Claude Code: `.claude/skills/astrbot-plugin-dev/` +- Codex: `.agents/skills/astrbot-plugin-dev/` +- OpenCode: `.opencode/skills/astrbot-plugin-dev/` + +`--agents` 仅支持 `claude`、`codex`、`opencode`,使用逗号分隔;重复值会去重,非法值会直接报错。 + +## 目录结构 + +``` +astrbot-sdk/ +├── src/ +│ └── astrbot_sdk/ # SDK 主包 +├── pyproject.toml +└── README.md +``` diff --git a/astrbot-sdk/docs/01_context_api.md b/astrbot-sdk/docs/01_context_api.md new file mode 100644 index 0000000000..bc100bebdd --- /dev/null +++ b/astrbot-sdk/docs/01_context_api.md @@ -0,0 +1,1200 @@ +# AstrBot SDK Context API 参考文档 + +## 概述 + +`Context` 是插件与 AstrBot Core 交互的主要入口,每个 handler 调用都会创建一个新的 Context 实例。Context 组合了所有 capability 客户端,提供统一的访问接口。 + +## 目录 + +- [Context 类属性](#context-类属性) +- [核心客户端](#核心客户端) +- [LLM 客户端 (ctx.llm)](#llm-客户端) +- [Memory 客户端 (ctx.memory)](#memory-客户端) +- [Database 客户端 (ctx.db)](#database-客户端) +- [Files 客户端 (ctx.files)](#files-客户端) +- [Platform 
客户端 (ctx.platform)](#platform-客户端) +- [Permission 客户端 (ctx.permission)](#permission-客户端) +- [Permission 管理客户端 (ctx.permission_manager)](#permission-管理客户端) +- [Provider 客户端 (ctx.providers)](#provider-客户端) +- [Provider 管理客户端 (ctx.provider_manager)](#provider-管理客户端) +- [Personas 客户端 (ctx.personas / ctx.persona_manager)](#personas-客户端) +- [Conversations 客户端 (ctx.conversations / ctx.conversation_manager)](#conversations-客户端) +- [Knowledge Base 客户端 (ctx.kbs / ctx.kb_manager)](#knowledge-base-客户端) +- [Message History 客户端 (ctx.message_history / ctx.message_history_manager)](#message-history-客户端) +- [HTTP 客户端 (ctx.http)](#http-客户端) +- [Metadata 客户端 (ctx.metadata)](#metadata-客户端) +- [Registry 客户端 (ctx.registry)](#registry-客户端) +- [Skills 客户端 (ctx.skills)](#skills-客户端) +- [Session 管理客户端 (ctx.session_plugins / ctx.session_services)](#session-管理客户端) +- [LLM Tool 管理方法](#llm-tool-管理方法) +- [系统工具方法](#系统工具方法) + +--- + +## Context 类属性 + +### 基本属性 + +```python +@dataclass +class Context: + peer: Any # 协议对等端,用于底层通信 + plugin_id: str # 当前插件 ID + logger: PluginLogger # 绑定了插件 ID 的日志器 + cancel_token: CancelToken # 取消令牌,用于处理请求取消 +``` + +### 客户端属性 + +```python +ctx.llm: LLMClient # LLM 能力客户端 +ctx.memory: MemoryClient # 记忆能力客户端 +ctx.db: DBClient # 数据库客户端 +ctx.files: FileServiceClient # 文件服务客户端 +ctx.platform: PlatformClient # 平台客户端 +ctx.permission: PermissionClient # 权限只读客户端 +ctx.providers: ProviderClient # Provider 客户端 +ctx.provider_manager: ProviderManagerClient # Provider 管理客户端 +ctx.permission_manager: PermissionManagerClient # 权限管理客户端 +ctx.personas: PersonaManagerClient # 人格管理客户端 +ctx.conversations: ConversationManagerClient # 对话管理客户端 +ctx.kbs: KnowledgeBaseManagerClient # 知识库管理客户端 +ctx.message_history: MessageHistoryManagerClient # 消息历史管理客户端 +ctx.message_history_manager: MessageHistoryManagerClient # ctx.message_history 的别名 +ctx.http: HTTPClient # HTTP 客户端 +ctx.metadata: MetadataClient # 元数据客户端 +ctx.registry: RegistryClient # 能力注册客户端 +ctx.skills: SkillClient # 技能客户端 +ctx.session_plugins: 
SessionPluginManager # 会话插件管理器 +ctx.session_services: SessionServiceManager # 会话服务管理器 +``` + +--- + +## 核心客户端 + +### logger + +绑定了插件 ID 的日志器,自动添加插件上下文信息。 + +```python +# 不同级别的日志 +ctx.logger.debug("调试信息") +ctx.logger.info("普通信息") +ctx.logger.warning("警告信息") +ctx.logger.error("错误信息") + +# 绑定额外上下文 +logger = ctx.logger.bind(user_id="12345") +logger.info("用户操作") + +# 流式日志监听 +async for entry in ctx.logger.watch(): + print(f"[{entry.level}] {entry.message}") +``` + +### cancel_token + +取消令牌,用于长时间运行的任务中检查是否需要取消。 + +```python +# 检查是否取消 +ctx.cancel_token.raise_if_cancelled() + +# 触发取消 +ctx.cancel_token.cancel() + +# 等待取消信号 +await ctx.cancel_token.wait() +``` + +--- + +## LLM 客户端 + +### chat() + +发送聊天请求并返回文本响应。 + +```python +async def chat( + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, +) -> str +``` + +**使用示例:** + +```python +# 简单对话 +reply = await ctx.llm.chat("你好,介绍一下自己") + +# 带系统提示词 +reply = await ctx.llm.chat( + "用 Python 写一个快速排序", + system="你是一个专业的程序员助手" +) + +# 带历史的对话 +from astrbot_sdk.clients.llm import ChatMessage + +history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), +] +reply = await ctx.llm.chat("你记得我的名字吗?", history=history) +``` + +### chat_raw() + +发送聊天请求并返回完整响应对象。 + +```python +response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) +print(f"生成文本: {response.text}") +print(f"Token 使用: {response.usage}") +print(f"结束原因: {response.finish_reason}") +``` + +### stream_chat() + +流式聊天,逐块返回响应文本。 + +```python +async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="", flush=True) +``` + +--- + +## Memory 客户端 + +### search() + +搜索记忆项。默认在有 embedding provider 时执行 hybrid 检索。 + +```python +results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5) +for item in results: + print(item["key"], item["score"], item["match_type"]) +``` + 
+### save() + +保存记忆项。 + +```python +# 保存用户偏好 +await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"}) + +# 使用关键字参数 +await ctx.memory.save("note", None, content="重要笔记", tags=["work"]) + +# 显式指定检索文本 +await ctx.memory.save( + "profile:alice", + {"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"}, +) +``` + +### get() + +精确获取单个记忆项。 + +```python +pref = await ctx.memory.get("user_pref") +if pref: + print(f"用户偏好主题: {pref.get('theme')}") +``` + +### list_keys() + +列出某个精确 namespace 下的 key。返回顺序按大小写不敏感排序,若大小写折叠后相同则再按原始 key 排序。 + +```python +keys = await ctx.memory.list_keys(namespace="users/alice") +print(keys) # ["Alpha", "apple", "beta"] +``` + +### exists() + +检查某个 key 是否存在于精确 namespace 中。 + +```python +exists = await ctx.memory.exists("user_pref", namespace="users/alice") +print(exists) # True / False +``` + +### save_with_ttl() + +保存带过期时间的记忆项。 + +```python +# 保存临时会话状态,1小时后过期 +await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting"}, + ttl_seconds=3600 +) +``` + +### clear_namespace() + +清理某个 namespace 下的记忆。默认只清理当前 namespace;传 `include_descendants=True` 时会递归清理子 namespace,返回值包含整个作用域内被删除的记录总数。 + +```python +deleted = await ctx.memory.clear_namespace(namespace="users/alice") +deleted_recursive = await ctx.memory.clear_namespace( + namespace="users/alice", + include_descendants=True, +) +print(deleted, deleted_recursive) +``` + +### count() + +统计某个 namespace 下的记忆数量。默认只统计当前 namespace;传 `include_descendants=True` 时会包含子 namespace。 + +```python +count = await ctx.memory.count(namespace="users/alice") +recursive_count = await ctx.memory.count( + namespace="users/alice", + include_descendants=True, +) +print(count, recursive_count) +``` + +### stats() + +查看记忆索引状态。 + +```python +stats = await ctx.memory.stats() +print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items")) +``` + +--- + +## Database 客户端 + +`ctx.db` 是插件作用域的 KV 存储。运行时会自动为 key 加上当前插件命名空间前缀, +因此不同插件即使使用同名 key 也不会互相覆盖;`list()` 和 `watch()` 返回的仍是插件视角的原始 key。 + +### get() + 
+获取指定键的值。 + +```python +data = await ctx.db.get("user_settings") +if data: + print(data["theme"]) +``` + +### set() + +设置键值对。 + +```python +await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"}) +await ctx.db.set("greeted", True) +``` + +### delete() + +删除指定键的数据。 + +```python +await ctx.db.delete("user_settings") +``` + +### list() + +列出匹配前缀的所有键。 + +```python +keys = await ctx.db.list("user_") +# ["user_settings", "user_profile", "user_history"] +``` + +### get_many() + +批量获取多个键的值。 + +```python +values = await ctx.db.get_many(["user:1", "user:2"]) +``` + +### set_many() + +批量写入多个键值对。 + +```python +await ctx.db.set_many({ + "user:1": {"name": "Alice"}, + "user:2": {"name": "Bob"} +}) +``` + +### watch() + +订阅 KV 变更事件(流式)。 + +```python +async for event in ctx.db.watch("user:"): + print(event["op"], event["key"]) +``` + +--- + +## Files 客户端 + +### register_file() + +注册文件并获取令牌。 + +```python +token = await ctx.files.register_file("/path/to/file.jpg", timeout=3600) +``` + +### handle_file() + +通过令牌解析文件路径。 + +```python +path = await ctx.files.handle_file(token) +``` + +--- + +## Platform 客户端 + +### send() + +发送文本消息。 + +```python +await ctx.platform.send(event.session_id, "收到您的消息!") +``` + +### send_image() + +发送图片消息。 + +```python +await ctx.platform.send_image( + event.session_id, + "https://example.com/image.png" +) +``` + +### send_chain() + +发送富消息链。 + +```python +from astrbot_sdk.message_components import Plain, Image + +chain = [Plain("文字"), Image(url="https://example.com/img.jpg")] +await ctx.platform.send_chain(event.session_id, chain) +``` + +### send_by_id() + +主动向指定平台会话发送消息。 + +```python +await ctx.platform.send_by_id( + platform_id="qq", + session_id="user123", + content="Hello", + message_type="private" +) +``` + +### get_members() + +获取群组成员列表。 + +```python +members = await ctx.platform.get_members("qq:group:123456") +for member in members: + print(f"{member['nickname']} ({member['user_id']})") +``` + +--- + +## Permission 客户端 + +`ctx.permission` 
提供与 Core 当前权限模型对齐的只读能力。v1 正式角色只有 `member` 和 `admin` 两级;`session_id` 参数当前仅保留给未来扩展,不会改变判定结果。 + +### check() + +查询某个用户当前会被视为 `admin` 还是 `member`。 + +```python +result = await ctx.permission.check("user-123") +print(result.is_admin, result.role) + +# session_id 在 v1 中只作为保留参数 +same_result = await ctx.permission.check( + "user-123", + session_id=event.session_id, +) +``` + +### get_admins() + +读取当前 `admins_id` 配置中的管理员列表。 + +```python +admins = await ctx.permission.get_admins() +print(admins) +``` + +--- + +## Permission 管理客户端 + +`ctx.permission_manager` 仅 `reserved/system` 插件可用,并且要求当前调用绑定到一个真实消息事件,且该事件发送者本身是 admin。普通插件会收到 `permission.manager.* is restricted to reserved/system plugins` 错误;非管理员事件会收到显式权限错误。 + +### add_admin() / remove_admin() + +返回值表示本次调用是否真的修改了管理员列表: +- 已存在再 `add_admin()` 返回 `False` +- 不存在再 `remove_admin()` 返回 `False` + +```python +changed = await ctx.permission_manager.add_admin("user-456") +removed = await ctx.permission_manager.remove_admin("user-456") +print(changed, removed) +``` + +--- + +## Provider 客户端 + +### list_all() + +列出所有 Provider。 + +```python +providers = await ctx.providers.list_all() +for p in providers: + print(f"{p.id}: {p.model}") +``` + +### get_using_chat() + +获取当前使用的聊天 Provider。 + +```python +provider = await ctx.providers.get_using_chat() +if provider: + print(f"当前使用: {provider.id}") +``` + +--- + +## Provider 管理客户端 + +仅 `reserved/system` 插件可用。普通插件调用 `ctx.provider_manager` 的方法会收到 `provider.manager.* is restricted to reserved/system plugins` 错误;普通插件应使用只读的 `ctx.providers` 查询当前 Provider 状态。 + +### set_provider() + +切换当前全局生效的 Provider。 +`umo` 仅作为变更事件中的来源标识,不会把 Provider 绑定到单个会话。 + +```python +from astrbot_sdk.llm.entities import ProviderType + +await ctx.provider_manager.set_provider( + provider_id="openai_chat", + provider_type=ProviderType.CHAT_COMPLETION, + umo=event.session_id, +) +``` + +### create_provider() / update_provider() / delete_provider() + +动态创建、更新和删除 Provider 实例。 + +```python +record = await 
ctx.provider_manager.create_provider( + { + "id": "custom_chat", + "type": "openai", + "provider_type": "chat_completion", + "model": "gpt-4.1", + } +) + +updated = await ctx.provider_manager.update_provider( + "custom_chat", + {"model": "gpt-4.1-mini"}, +) + +await ctx.provider_manager.delete_provider("custom_chat") +``` + +### watch_changes() + +监听 Provider 变更事件。 + +```python +async for change in ctx.provider_manager.watch_changes(): + print(f"{change.provider_id}: {change.provider_type} @ {change.umo}") +``` + +--- + +## Personas 客户端 + +`ctx.personas` 与 `ctx.persona_manager` 指向同一个人格管理客户端。 + +### get_persona() / get_all_personas() + +查询单个人格或获取所有人格。 + +```python +persona = await ctx.personas.get_persona("assistant") +all_personas = await ctx.personas.get_all_personas() +``` + +### create_persona() / update_persona() / delete_persona() + +创建、更新或删除人格。 + +```python +from astrbot_sdk.clients import PersonaCreateParams, PersonaUpdateParams + +created = await ctx.personas.create_persona( + PersonaCreateParams( + persona_id="assistant", + system_prompt="你是一个有用的助手。", + ) +) + +updated = await ctx.personas.update_persona( + "assistant", + PersonaUpdateParams(system_prompt="你是一个专业的编程助手。"), +) + +await ctx.personas.delete_persona("assistant") +``` + +--- + +## Conversations 客户端 + +`ctx.conversations` 与 `ctx.conversation_manager` 指向同一个对话管理客户端。 + +### new_conversation() + +为指定会话创建新对话。 + +```python +from astrbot_sdk.clients import ConversationCreateParams + +conv_id = await ctx.conversations.new_conversation( + event.session_id, + ConversationCreateParams(title="新对话"), +) +``` + +### get_current_conversation() / get_conversations() + +获取当前对话或会话内的全部对话。 + +```python +current = await ctx.conversations.get_current_conversation( + event.session_id, + create_if_not_exists=True, +) +all_conversations = await ctx.conversations.get_conversations(event.session_id) +``` + +### switch_conversation() / update_conversation() / delete_conversation() + +切换、更新或删除对话。 + +```python +from 
astrbot_sdk.clients import ConversationUpdateParams + +await ctx.conversations.switch_conversation(event.session_id, "conv_123") +await ctx.conversations.update_conversation( + event.session_id, + "conv_123", + ConversationUpdateParams(title="新标题"), +) +await ctx.conversations.delete_conversation(event.session_id, "conv_123") +``` + +--- + +## Knowledge Base 客户端 + +`ctx.kbs` 与 `ctx.kb_manager` 指向同一个知识库管理客户端。 + +### list_kbs() / get_kb() + +列出所有知识库或获取单个知识库。 + +```python +kbs = await ctx.kbs.list_kbs() +kb = await ctx.kbs.get_kb("kb_123") +``` + +### create_kb() / update_kb() / delete_kb() + +创建、更新或删除知识库。 + +```python +from astrbot_sdk import KnowledgeBaseCreateParams, KnowledgeBaseUpdateParams + +kb = await ctx.kbs.create_kb( + KnowledgeBaseCreateParams( + kb_name="tech_docs", + embedding_provider_id="openai_embedding", + description="技术文档", + ) +) + +kb = await ctx.kbs.update_kb( + kb.kb_id, + KnowledgeBaseUpdateParams(description="更新后的描述"), +) +deleted = await ctx.kbs.delete_kb(kb.kb_id) +``` + +### retrieve() + +从知识库中检索相关片段。 + +```python +result = await ctx.kbs.retrieve( + "如何初始化 Context", + kb_names=["tech_docs"], + top_m_final=3, +) +if result: + for item in result.results: + print(item.content) +``` + +--- + +## Message History 客户端 + +`ctx.message_history` 用于按 `MessageSession` 精确保存原始消息组件、发送者和元数据, +`ctx.message_history_manager` 是它的别名。它适合消息审计、分页回看和按时间清理; +如果你要做语义检索,仍应使用 `ctx.memory`。 + +### append() + +追加一条消息历史记录。 + +```python +from astrbot_sdk import MessageHistorySender, MessageSession, Plain + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.append( + session, + parts=[Plain(event.message_content, convert=False)], + sender=MessageHistorySender( + sender_id=event.sender_id, + sender_name=event.sender_name, + ), + metadata={"source": "message_handler"}, + idempotency_key="incoming:demo-user:hello", +) +print(record.id, record.created_at) +``` + 
+### list() + +分页读取某个会话的消息历史。 +分页时建议直接复用上一页返回的 `next_cursor`,不要自行构造游标值。 + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +page = await ctx.message_history.list(session, limit=20) +for record in page.records: + print(record.id, record.sender.sender_name, record.parts) +``` + +### get() / get_by_id() + +按记录 ID 读取单条历史。 + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.get(session, 1) +same_record = await ctx.message_history.get_by_id(session, 1) +``` + +### delete_before() / delete_after() / delete_all() + +按时间或按会话清理消息历史。 +当前实现要求传入带时区的 `datetime`,例如 `timezone.utc`。 + +```python +from datetime import datetime, timezone + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +deleted = await ctx.message_history.delete_before( + session, + before=datetime(2026, 3, 22, tzinfo=timezone.utc), +) +await ctx.message_history.delete_all(session) +``` + +--- + +## HTTP 客户端 + +`ctx.http.register_api()` 当前会拦截包含 `..` 的路径和部分明显非法输入,但校验并非完全严格。 +文档示例建议统一使用以 `/` 开头、没有重复斜杠的规范化路径。`ctx.http.unregister_api(route)` +在不传 `methods` 时会移除当前插件在该路由下注册的全部方法。 + +### register_api() + +注册 Web API 端点。 + +```python +from astrbot_sdk.decorators import provide_capability + +@provide_capability( + name="my_plugin.http_handler", + description="处理 HTTP 请求" +) +async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + +await ctx.http.register_api( + route="/my-api", + handler=handle_http_request, + methods=["GET", "POST"] +) +``` + +### unregister_api() + +注销 Web API 端点。 + +```python +await ctx.http.unregister_api("/my-api") +``` + +### list_apis() + +列出当前插件注册的所有 API。 + +```python +apis = await ctx.http.list_apis() +for api in apis: + 
print(f"{api['route']}: {api['methods']}") +``` + +--- + +## Metadata 客户端 + +### get_plugin() + +获取指定插件信息。 + +```python +plugin = await ctx.metadata.get_plugin("another_plugin") +if plugin: + print(f"插件: {plugin.display_name}") + print(f"版本: {plugin.version}") +``` + +### list_plugins() + +获取所有插件列表。 + +```python +plugins = await ctx.metadata.list_plugins() +for plugin in plugins: + print(f"{plugin.display_name} v{plugin.version}") +``` + +### get_current_plugin() + +获取当前插件信息。 + +```python +current = await ctx.metadata.get_current_plugin() +if current: + print(f"当前插件: {current.name} v{current.version}") +``` + +### get_plugin_config() + +获取插件配置。 + +```python +config = await ctx.metadata.get_plugin_config() +if config: + api_key = config.get("api_key") +``` + +--- + +## Registry 客户端 + +handler 注册表查询与白名单管理客户端,用于查询 handler 信息并管理 handler 白名单。 + +### get_handlers_by_event_type() + +获取指定事件类型的所有 handler。 + +```python +handlers = await ctx.registry.get_handlers_by_event_type("message") +for h in handlers: + print(f"{h.handler_full_name}: {h.description}") +``` + +### get_handler_by_full_name() + +通过完整名称获取 handler 元数据。 + +```python +handler = await ctx.registry.get_handler_by_full_name("my_plugin.on_message") +if handler: + print(f"触发类型: {handler.trigger_type}") + print(f"优先级: {handler.priority}") +``` + +### set_handler_whitelist() / get_handler_whitelist() + +管理 handler 白名单。 + +```python +# 设置白名单 +await ctx.registry.set_handler_whitelist(["plugin_a", "plugin_b"]) + +# 获取当前白名单 +whitelist = await ctx.registry.get_handler_whitelist() + +# 清空白名单 +await ctx.registry.clear_handler_whitelist() +``` + +--- + +## Skills 客户端 + +技能注册客户端,用于注册和管理技能。 + +### register() + +注册一个技能。 + +```python +skill = await ctx.skills.register( + name="my_skill", + path="/path/to/skill", + description="我的技能描述" +) +print(f"技能已注册: {skill.name}") +``` + +### unregister() + +注销技能。 + +```python +removed = await ctx.skills.unregister("my_skill") +if removed: + print("技能已注销") +``` + +### list() + +列出当前已注册的技能。 + 
+```python +skills = await ctx.skills.list() +for skill in skills: + print(f"{skill.name}: {skill.skill_dir}") +``` + +--- + +## Session 管理客户端 + +### SessionPluginManager (ctx.session_plugins) + +会话级别的插件状态管理器。 + +#### is_plugin_enabled_for_session() + +检查插件在指定会话是否启用。 + +```python +enabled = await ctx.session_plugins.is_plugin_enabled_for_session( + session=event, + plugin_name="my_plugin" +) +``` + +#### filter_handlers_by_session() + +根据会话过滤 handler。 + +```python +handlers = await ctx.registry.get_handlers_by_event_type("message") +filtered = await ctx.session_plugins.filter_handlers_by_session( + session=event, + handlers=handlers +) +``` + +### SessionServiceManager (ctx.session_services) + +会话级别的 LLM/TTS 服务状态管理器。 + +#### LLM 状态管理 + +```python +# 检查 LLM 是否启用 +enabled = await ctx.session_services.is_llm_enabled_for_session(event) + +# 设置 LLM 状态 +await ctx.session_services.set_llm_status_for_session(event, enabled=False) + +# 检查是否应处理 LLM 请求 +if await ctx.session_services.should_process_llm_request(event): + reply = await ctx.llm.chat(prompt) +``` + +#### TTS 状态管理 + +```python +# 检查 TTS 是否启用 +enabled = await ctx.session_services.is_tts_enabled_for_session(event) + +# 设置 TTS 状态 +await ctx.session_services.set_tts_status_for_session(event, enabled=True) + +# 检查是否应处理 TTS 请求 +if await ctx.session_services.should_process_tts_request(event): + await handle_tts(text) +``` + +--- + +## LLM Tool 管理方法 + +### register_llm_tool() + +注册可执行的 LLM 工具。 + +```python +async def search_weather(location: str) -> str: + return f"{location} 今天晴天" + +await ctx.register_llm_tool( + name="search_weather", + parameters_schema={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "城市名称"} + }, + "required": ["location"] + }, + desc="搜索天气信息", + func_obj=search_weather, + active=True +) +``` + +### add_llm_tools() + +添加 LLM 工具规范。 + +```python +from astrbot_sdk.llm.tools import LLMToolSpec + +tool_spec = LLMToolSpec( + name="my_tool", + description="我的工具", + 
parameters_schema={...} +) + +await ctx.add_llm_tools(tool_spec) +``` + +### activate_llm_tool() / deactivate_llm_tool() + +激活/停用 LLM 工具。 + +```python +await ctx.activate_llm_tool("my_tool") +await ctx.deactivate_llm_tool("my_tool") +``` + +--- + +## 系统工具方法 + +### get_data_dir() + +获取插件数据目录路径。 + +```python +data_dir = await ctx.get_data_dir() +print(f"数据目录: {data_dir}") +``` + +### text_to_image() + +将文本渲染为图片。 + +```python +url = await ctx.text_to_image("Hello World", return_url=True) +``` + +### html_render() + +渲染 HTML 模板。 + +```python +url = await ctx.html_render( + tmpl="

<h1>{{ title }}</h1>

", + data={"title": "标题"} +) +``` + +### send_message() + +向会话发送消息。 + +```python +await ctx.send_message(event.session_id, "消息内容") +``` + +### send_message_by_id() + +通过 ID 向平台发送消息。 + +```python +await ctx.send_message_by_id( + type="private", + id="user123", + content="Hello", + platform="qq" +) +``` + +### register_task() + +注册后台任务。 + +```python +async def background_work(): + while True: + await asyncio.sleep(60) + ctx.logger.info("每分钟执行一次") + +task = await ctx.register_task(background_work(), "定时任务") +``` + +--- + +## 常见使用模式 + +### 1. 基本对话流程 + +```python +from astrbot_sdk.decorators import on_message + +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + reply = await ctx.llm.chat(event.message_content) + await ctx.platform.send(event.session_id, reply) +``` + +### 2. 带历史的对话 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 从 memory 获取历史 + history_data = await ctx.memory.get(f"history:{event.session_id}") + history = history_data.get("messages", []) if history_data else [] + + # 对话 + reply = await ctx.llm.chat(event.message_content, history=history) + + # 保存新消息到历史 + history.append(ChatMessage(role="user", content=event.message_content)) + history.append(ChatMessage(role="assistant", content=reply)) + await ctx.memory.save(f"history:{event.session_id}", {"messages": history}) + + await ctx.platform.send(event.session_id, reply) +``` + +如果你需要保留原始消息组件、发送者信息、分页读取或按时间清理,请改用 +`ctx.message_history`,不要把消息链序列化后再手工塞进 `ctx.memory`。 + +### 3. 使用数据库持久化 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 获取用户配置 + config = await ctx.db.get(f"user_config:{event.sender_id}") + + if not config: + config = {"theme": "light", "lang": "zh"} + await ctx.db.set(f"user_config:{event.sender_id}", config) + + # 使用配置 + reply = f"你的主题设置是: {config['theme']}" + await ctx.platform.send(event.session_id, reply) +``` + +--- + +## 注意事项 + +1. 
**跨进程通信**:Context 通过 capability 协议与核心通信,所有方法调用都是异步的 + +2. **插件隔离**:每个插件有独立的 Context 实例,数据和配置是隔离的 + +3. **取消处理**:长时间运行的操作应定期检查 `ctx.cancel_token.raise_if_cancelled()` + +4. **错误处理**:所有远程调用都可能失败,建议使用 try-except 处理 + +5. **Memory vs DB**: + - Memory: 语义搜索,适合 AI 上下文 + - DB: 精确匹配,适合结构化数据 + +6. **文件操作**:使用 `ctx.files` 注册文件令牌,不要直接传递本地路径 + +7. **平台标识**:使用 UMO(统一消息来源标识)格式:`"platform:instance:session_id"` diff --git a/astrbot-sdk/docs/02_event_and_components.md b/astrbot-sdk/docs/02_event_and_components.md new file mode 100644 index 0000000000..af663e0ac3 --- /dev/null +++ b/astrbot-sdk/docs/02_event_and_components.md @@ -0,0 +1,593 @@ +# AstrBot SDK 消息事件与组件 API 参考文档 + +## 概述 + +本文档详细介绍 `astrbot_sdk` 中消息事件和消息组件的使用方法,包括 `MessageEvent` 类和所有消息组件类。 + +## 目录 + +- [MessageEvent - 消息事件对象](#messageevent---消息事件对象) +- [消息组件类](#消息组件类) +- [MessageChain - 消息链](#messagechain---消息链) +- [MessageBuilder - 消息构建器](#messagebuilder---消息构建器) + +--- + +## MessageEvent - 消息事件对象 + +**模块路径**: `astrbot_sdk.events.MessageEvent` + +### 核心属性 + +| 属性名 | 类型 | 说明 | +|--------|------|------| +| `text` | `str` | 消息文本内容 | +| `user_id` | `str \| None` | 发送者用户 ID | +| `group_id` | `str \| None` | 群组 ID(私聊时为 None) | +| `platform` | `str \| None` | 平台标识(如 "qq", "wechat") | +| `session_id` | `str` | 会话 ID | +| `self_id` | `str` | 机器人账号 ID | +| `platform_id` | `str` | 平台实例标识 | +| `message_type` | `str` | 消息类型("private" 或 "group") | +| `sender_name` | `str` | 发送者昵称 | + +### 消息组件访问方法 + +#### `get_messages()` + +获取当前事件的所有 SDK 消息组件。 + +```python +components = event.get_messages() +for comp in components: + print(f"组件类型: {comp.type}") +``` + +#### `has_component(type_)` + +检查是否包含特定类型的组件。 + +```python +if event.has_component(Image): + print("消息包含图片") +``` + +#### `get_components(type_)` + +获取特定类型的所有组件。 + +```python +at_comps = event.get_components(At) +for at in at_comps: + print(f"@了用户: {at.qq}") +``` + +#### `get_images()` + +获取所有图片组件。 + +```python +images = event.get_images() +for img in images: + path = await 
img.convert_to_file_path() + print(f"图片路径: {path}") +``` + +#### `get_files()` + +获取所有文件组件。 + +```python +files = event.get_files() +``` + +#### `extract_plain_text()` + +提取所有纯文本内容。 + +```python +text = event.extract_plain_text() +``` + +#### `get_at_users()` + +获取消息中所有被@的用户ID列表。 + +```python +at_users = event.get_at_users() +``` + +### 会话与平台信息方法 + +#### `is_private_chat()` / `is_group_chat()` + +判断消息类型。 + +```python +if event.is_private_chat(): + await event.reply("这是私聊") +elif event.is_group_chat(): + await event.reply("这是群聊") +``` + +#### `is_admin()` + +判断发送者是否有管理员权限。 + +```python +if event.is_admin(): + await event.reply("你是管理员") +``` + +### 回复与发送方法 + +#### `reply(text)` + +回复纯文本消息。 + +```python +await event.reply("Hello World!") +``` + +#### `reply_image(image_url)` + +回复图片消息。 + +```python +await event.reply_image("https://example.com/image.jpg") +``` + +#### `reply_chain(chain)` + +回复消息链。 + +```python +from astrbot_sdk.message_components import Plain, At + +await event.reply_chain([ + Plain("Hello "), + At("123456"), + Plain("!") +]) +``` + +### 事件控制方法 + +#### `stop_event()` + +标记事件为已停止,阻止后续处理器执行。 + +```python +event.stop_event() +``` + +### 结果构建方法 + +#### `plain_result(text)` + +创建纯文本结果。 + +```python +return event.plain_result("回复内容") +``` + +#### `image_result(url_or_path)` + +创建图片结果。 + +```python +return event.image_result("https://example.com/image.jpg") +``` + +#### `chain_result(chain)` + +创建链结果。 + +```python +return event.chain_result([ + Plain("Hello"), + At("123456") +]) +``` + +--- + +## 消息组件类 + +### Plain - 纯文本组件 + +```python +from astrbot_sdk.message_components import Plain + +text = Plain("Hello World") +``` + +### At - @某人组件 + +```python +from astrbot_sdk.message_components import At + +at = At("123456", name="张三") +``` + +### AtAll - @全体成员组件 + +```python +from astrbot_sdk.message_components import AtAll + +at_all = AtAll() +``` + +### Image - 图片组件 + +```python +from astrbot_sdk.message_components import Image + +# URL 图片 +img1 = 
Image.fromURL("https://example.com/image.jpg") + +# 本地文件 +img2 = Image.fromFileSystem("/path/to/image.jpg") + +# Base64 +img3 = Image.fromBase64("iVBORw0KGgo...") +``` + +### Record - 语音组件 + +```python +from astrbot_sdk.message_components import Record + +# URL 音频 +audio = Record.fromURL("https://example.com/audio.mp3") + +# 本地文件 +audio = Record.fromFileSystem("/path/to/audio.mp3") +``` + +### Video - 视频组件 + +```python +from astrbot_sdk.message_components import Video + +video = Video.fromURL("https://example.com/video.mp4") +``` + +### File - 文件组件 + +```python +from astrbot_sdk.message_components import File + +# URL 文件 +file1 = File(name="document.pdf", url="https://example.com/doc.pdf") + +# 本地文件 +file2 = File(name="image.jpg", file="/path/to/image.jpg") +``` + +### Reply - 回复组件 + +```python +from astrbot_sdk.message_components import Reply, Plain + +reply = Reply( + id="msg_123", + sender_id="789", + chain=[Plain("被回复的消息")] +) +``` + +--- + +## MessageChain - 消息链 + +### 构造方法 + +```python +from astrbot_sdk.message_result import MessageChain +from astrbot_sdk.message_components import Plain, At + +# 空消息链 +chain = MessageChain() + +# 带初始组件 +chain = MessageChain([Plain("Hello"), At("123456")]) +``` + +### 实例方法 + +#### `append(component)` + +追加单个组件。 + +```python +chain.append(Plain("More text")) +``` + +#### `extend(components)` + +追加多个组件。 + +```python +chain.extend([Plain("A"), Plain("B")]) +``` + +#### `to_payload()` + +转换为协议 payload。 + +```python +payload = chain.to_payload() +``` + +#### `get_plain_text()` + +提取纯文本内容。 + +```python +text = chain.get_plain_text() +``` + +--- + +## MessageBuilder - 消息构建器 + +### 使用示例 + +```python +from astrbot_sdk.message_result import MessageBuilder + +chain = (MessageBuilder() + .text("Hello ") + .at("123456") + .text("!\n") + .image("https://example.com/img.jpg") + .build()) + +await event.reply_chain(chain) +``` + +### 可用方法 + +- `.text(content)` - 添加文本 +- `.at(user_id)` - 添加@用户 +- `.at_all()` - 添加@全体成员 +- `.image(url)` - 添加图片 +- 
`.record(url)` - 添加语音 +- `.video(url)` - 添加视频 +- `.file(name, url=...)` - 添加文件 +- `.build()` - 构建消息链 + +--- + +## 使用示例 + +### 处理图片消息 + +```python +@on_message() +async def handle_image(event: MessageEvent): + images = event.get_images() + if not images: + await event.reply("消息中没有图片") + return + + for img in images: + path = await img.convert_to_file_path() + await event.reply(f"收到图片: {path}") +``` + +### 检测@和群聊/私聊 + +```python +@on_command("check") +async def check_handler(event: MessageEvent): + if event.is_group_chat(): + await event.reply("这是群聊消息") + elif event.is_private_chat(): + await event.reply("这是私聊消息") + + at_users = event.get_at_users() + if at_users: + await event.reply(f"你@了: {', '.join(at_users)}") +``` + +### 返回富文本结果 + +```python +@on_command("info") +async def info_handler(event: MessageEvent): + return event.chain_result([ + Plain(f"用户: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}\n"), + Plain(f"平台: {event.platform}"), + ]) + ``` + + --- + + ## 媒体辅助类 + + ### MediaHelper + + 媒体辅助类,提供从 URL 检测媒体类型和下载功能。 + + ```python + from astrbot_sdk.message_components import MediaHelper + ``` + + #### from_url - 从 URL 创建组件 + + 自动检测 URL 的媒体类型并创建对应的消息组件。 + + ```python + from astrbot_sdk.message_components import MediaHelper + + # 自动检测媒体类型 + component = await MediaHelper.from_url("https://example.com/video.mp4") + # 返回 Video 组件 + + component = await MediaHelper.from_url("https://example.com/image.jpg") + # 返回 Image 组件 + + component = await MediaHelper.from_url("https://example.com/audio.mp3") + # 返回 Record 组件 + ``` + + **参数**: + - `url`: 媒体文件 URL + - `headers`: 可选的请求头 + + **返回值**: + - `Image` / `Video` / `Record` / `File` 组件实例 + + #### download - 下载媒体文件 + + 下载媒体文件到本地。 + + ```python + from astrbot_sdk.message_components import MediaHelper + from pathlib import Path + + # 下载到指定目录 + path = await MediaHelper.download( + url="https://example.com/video.mp4", + save_dir=Path("/tmp/downloads") + ) + print(f"下载到: {path}") # /tmp/downloads/video.mp4 + + # 下载到当前目录 + 
path = await MediaHelper.download( + url="https://example.com/image.png" + ) + ``` + + **参数**: + - `url`: 文件 URL + - `save_dir`: 保存目录(可选,默认为当前目录) + - `filename`: 指定文件名(可选,自动从 URL 或响应头推断) + - `headers`: 请求头(可选) + + **返回值**: + - `Path`: 下载文件的本地路径 + + **示例:完整媒体处理流程** + + ```python + from astrbot_sdk import Star, Context, MessageEvent + from astrbot_sdk.decorators import on_command + from astrbot_sdk.message_components import MediaHelper, Plain + + class MediaPlugin(Star): + @on_command("download") + async def download_media(self, event: MessageEvent, ctx: Context, url: str): + """下载媒体文件""" + try: + # 发送下载中提示 + await event.reply(f"正在下载: {url}") + + # 下载文件 + path = await MediaHelper.download(url) + + # 创建对应组件并发送 + component = await MediaHelper.from_url(url) + component.file = str(path) # 使用本地文件 + + await event.reply([Plain("下载完成!"), component]) + except Exception as e: + await event.reply(f"下载失败: {e}") + + @on_command("mirror") + async def mirror_media(self, event: MessageEvent, ctx: Context): + """转发收到的媒体""" + images = event.get_images() + if images: + for img in images: + # 下载并重新发送 + if img.url: + local_path = await MediaHelper.download(img.url) + await event.reply(f"已镜像保存: {local_path}") + ``` + + --- + + ## 未知组件 + + ### UnknownComponent + + 未知消息组件,用于表示 SDK 无法识别的平台特定组件。 + + ```python + from astrbot_sdk.message_components import UnknownComponent + ``` + + **说明**: + - 当收到 SDK 不支持的消息类型时,会返回此组件 + - 保留原始数据供插件自行处理 + - 通常出现在新平台或平台新功能中 + + **属性**: + - `raw_data`: 原始组件数据(dict) + - `type`: 组件类型字符串 + + ```python + @on_message() + async def handle_unknown(self, event: MessageEvent, ctx: Context): + components = event.get_messages() + for comp in components: + if isinstance(comp, UnknownComponent): + ctx.logger.warning(f"未知组件类型: {comp.type}") + ctx.logger.debug(f"原始数据: {comp.raw_data}") + # 插件可以尝试自行处理 raw_data + ``` + + --- + + ## 特殊消息组件 + + ### Forward - 合并转发消息 + + 合并转发消息组件(仅部分平台支持,如 QQ)。 + + ```python + from astrbot_sdk.message_components import Forward, ForwardNode + + # 
创建转发消息(需要平台支持) + nodes = [ + ForwardNode( + user_id="123456", + nickname="用户A", + content=[Plain("消息内容1")] + ), + ForwardNode( + user_id="789012", + nickname="用户B", + content=[Plain("消息内容2")] + ), + ] + forward = Forward(nodes=nodes) + ``` + + **注意**:Forward 组件的支持程度取决于具体平台适配器。 + + ### Poke - 戳一戳/拍一拍 + + 戳一戳消息组件(QQ 等平台支持)。 + + ```python + from astrbot_sdk.message_components import Poke + + # 发送戳一戳 + poke = Poke(user_id="123456") + await event.reply(poke) + + # 检测戳一戳 + @on_message() + async def on_poke(self, event: MessageEvent, ctx: Context): + for comp in event.get_messages(): + if isinstance(comp, Poke): + await event.reply(f"{event.sender_name} 戳了你一下!") + ``` + + **属性**: + - `user_id`: 被戳的用户 ID diff --git a/astrbot-sdk/docs/03_decorators.md b/astrbot-sdk/docs/03_decorators.md new file mode 100644 index 0000000000..6a106f98e8 --- /dev/null +++ b/astrbot-sdk/docs/03_decorators.md @@ -0,0 +1,610 @@ +# AstrBot SDK 装饰器使用指南 + +## 概述 + +本文档详细介绍 `astrbot_sdk.decorators` 中所有装饰器的使用方法、参数说明和最佳实践。 + +## 目录 + +- [事件触发装饰器](#事件触发装饰器) +- [修饰器装饰器](#修饰器装饰器) +- [过滤器装饰器](#过滤器装饰器) +- [限制器装饰器](#限制器装饰器) +- [能力暴露装饰器](#能力暴露装饰器) +- [LLM 工具装饰器](#llm-工具装饰器) +- [最佳实践](#最佳实践) + +--- + +## 事件触发装饰器 + +### @on_command + +命令触发装饰器。 + +**签名:** +```python +def on_command( + command: str | Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> Callable +``` + +**参数:** +- `command`: 命令名称(不包含前缀符) +- `aliases`: 命令别名列表 +- `description`: 命令描述 + +**示例:** + +```python +from astrbot_sdk.decorators import on_command + +@on_command("hello") +async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") + +@on_command(["echo", "repeat"], aliases=["say", "speak"]) +async def echo(self, event: MessageEvent, text: str): + await event.reply(text) +``` + +### @on_message + +消息触发装饰器。 + +**签名:** +```python +def on_message( + *, + regex: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + message_types: 
list[str] | None = None, +) -> Callable +``` + +**参数:** +- `regex`: 正则表达式模式 +- `keywords`: 关键词列表(任一匹配即触发) +- `platforms`: 限定平台列表 +- `message_types`: 限定消息类型("group", "private") + +**示例:** + +```python +# 关键词匹配 +@on_message(keywords=["帮助", "help"]) +async def help_handler(self, event: MessageEvent, ctx: Context): + await event.reply("可用命令: /hello") + +# 正则匹配 +@on_message(regex=r"\d{4,}") +async def number_handler(self, event: MessageEvent, ctx: Context): + await event.reply("检测到数字!") + +# 多条件过滤 +@on_message( + keywords=["天气"], + platforms=["qq"], + message_types=["private"] +) +async def weather_query(self, event: MessageEvent, ctx: Context): + await event.reply("请输入城市名称") +``` + +### @on_event + +事件触发装饰器。 + +**签名:** +```python +def on_event(event_type: str) -> Callable +``` + +**示例:** + +```python +@on_event("group_member_join") +async def welcome_new_member(self, event, ctx: Context): + await ctx.platform.send(event.group_id, "欢迎新成员!") +``` + +### @on_schedule + +定时任务装饰器。 + +**签名:** +```python +def on_schedule( + *, + cron: str | None = None, + interval_seconds: int | None = None, +) -> Callable +``` + +**示例:** + +```python +# 固定间隔 +@on_schedule(interval_seconds=3600) +async def hourly_check(self, ctx: Context): + pass + +# cron 表达式 +@on_schedule(cron="0 8 * * *") # 每天 8:00 +async def morning_greeting(self, ctx: Context): + await ctx.platform.send("group_123", "早上好!") +``` + +--- + +## 修饰器装饰器 + +### @require_admin + +管理员权限装饰器。 + +**示例:** + +```python +from astrbot_sdk.decorators import on_command, require_admin + +@on_command("admin") +@require_admin +async def admin_cmd(self, event: MessageEvent, ctx: Context): + await event.reply("管理员命令") +``` + +--- + +## 过滤器装饰器 + +### @platforms + +限定平台装饰器。 + +**签名:** +```python +def platforms(*names: str) -> Callable +``` + +**示例:** + +```python +@on_command("qq_only") +@platforms("qq") +async def qq_only_command(self, event: MessageEvent, ctx: Context): + await event.reply("这是 QQ 专属命令") +``` + +### @message_types + 
+限定消息类型装饰器。 + +**签名:** +```python +def message_types(*types: str) -> Callable +``` + +**示例:** + +```python +@on_command("group_only") +@message_types("group") +async def group_command(self, event: MessageEvent, ctx: Context): + await event.reply("这是群聊命令") +``` + +### @group_only + +仅群聊装饰器。 + +```python +@on_command("group_admin") +@group_only() +async def group_admin_command(self, event: MessageEvent, ctx: Context): + await event.reply("这是群聊管理命令") +``` + +### @private_only + +仅私聊装饰器。 + +```python +@on_command("private_chat") +@private_only() +async def private_command(self, event: MessageEvent, ctx: Context): + await event.reply("这是私聊命令") +``` + +--- + +## 限制器装饰器 + +### @rate_limit + +速率限制装饰器。 + +**签名:** +```python +def rate_limit( + limit: int, + window: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable +``` + +**参数:** +- `limit`: 时间窗口内最大调用次数 +- `window`: 时间窗口大小(秒) +- `scope`: 限制范围("session", "user", "group", "global") +- `behavior`: 触发限制后的行为("hint", "silent", "error") + +**示例:** + +```python +@on_command("search") +@rate_limit(5, 60) # 每分钟最多5次 +async def search_command(self, event: MessageEvent, ctx: Context): + await event.reply("搜索结果...") + +@on_command("draw") +@rate_limit(3, 3600, scope="user") # 每用户每小时3次 +async def draw_command(self, event: MessageEvent, ctx: Context): + await event.reply("绘图结果...") +``` + +### @cooldown + +冷却时间装饰器。 + +**签名:** +```python +def cooldown( + seconds: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable +``` + +**示例:** + +```python +@on_command("cast_skill") +@cooldown(30) # 30秒冷却 +async def cast_skill_command(self, event: MessageEvent, ctx: Context): + await event.reply("技能施放成功!") +``` + +--- + +### @admin_only + +管理员权限装饰器(`@require_admin` 的别名)。 + +**签名:** +```python +def admin_only(func: HandlerCallable) -> HandlerCallable +``` + +**示例:** + +```python +from 
astrbot_sdk.decorators import on_command, admin_only + +@on_command("admin") +@admin_only +async def admin_cmd(self, event: MessageEvent, ctx: Context): + await event.reply("管理员命令") +``` + +**说明:** +- 功能与 `@require_admin` 完全相同 +- 更简洁的语法,无需括号 +- 适合快速标记管理员命令 + +--- + +## 优先级装饰器 + +### @priority + +设置 handler 执行优先级。 + +**签名:** +```python +def priority(value: int) -> Callable[[HandlerCallable], HandlerCallable] +``` + +**参数:** +- `value`: 优先级数值,**越大越先执行** +- 默认优先级为 0 + +**示例:** + +```python +from astrbot_sdk.decorators import on_command, priority + +@on_command("hello") +@priority(10) # 高优先级,先执行 +async def hello_high(self, event: MessageEvent, ctx: Context): + await event.reply("高优先级处理器") + +@on_command("hello") +@priority(5) # 较低优先级,后执行 +async def hello_low(self, event: MessageEvent, ctx: Context): + await event.reply("低优先级处理器") +``` + +**使用场景:** +- 多个插件注册了相同命令时控制执行顺序 +- 确保核心处理器先于扩展处理器执行 +- 实现插件间的协作处理链 + +**注意事项:** +- 相同优先级的 handler 执行顺序不确定 +- 高优先级 handler 不会阻止低优先级 handler 执行(除非显式阻止) + +--- + +## 对话装饰器 + +### @conversation_command + +对话命令装饰器,用于创建交互式对话流程。 + +**签名:** +```python +def conversation_command( + command: str, + *, + timeout: float = 300.0, + description: str | None = None, +) -> Callable +``` + +**参数:** +- `command`: 命令名称 +- `timeout`: 对话超时时间(秒),默认 300 +- `description`: 命令描述 + +**示例:** + +```python +from astrbot_sdk.decorators import conversation_command +from astrbot_sdk.conversation import ConversationSession + +@conversation_command("survey", timeout=600) +async def survey(self, event: MessageEvent, ctx: Context, session: ConversationSession): + """交互式调查问卷""" + # 第一轮对话 + await event.reply("请输入您的姓名:") + + # 等待用户回复(在下一个处理器中处理) + session.state["step"] = "name" + +@conversation_command("survey") +async def survey_step2(self, event: MessageEvent, ctx: Context, session: ConversationSession): + """问卷第二步""" + step = session.state.get("step") + + if step == "name": + session.state["name"] = event.text + session.state["step"] = "age" + await event.reply("请输入您的年龄:") 
+ elif step == "age": + session.state["age"] = event.text + # 完成问卷 + await event.reply(f"感谢您的参与!姓名:{session.state['name']}, 年龄:{event.text}") + session.close() # 关闭对话会话 +``` + +**工作流程:** +1. 用户发送 `/survey` 触发第一个处理器 +2. 处理器使用 `ConversationSession` 维护对话状态 +3. 后续消息在同一会话中路由到相同命令的处理器 +4. 超时或调用 `session.close()` 结束对话 + +**异常处理:** + +```python +from astrbot_sdk.conversation import ConversationClosed, ConversationReplaced + +@conversation_command("demo") +async def demo(self, event: MessageEvent, ctx: Context, session: ConversationSession): + try: + await event.reply("输入 'exit' 结束对话") + if event.text.lower() == "exit": + session.close() + except ConversationClosed: + # 会话被关闭 + await event.reply("对话已结束") + except ConversationReplaced: + # 会话被新会话替换 + await event.reply("开始新的对话") +``` + +--- + +## 能力暴露装饰器 + +### @provide_capability + +暴露能力装饰器。 + +**签名:** +```python +def provide_capability( + name: str, + *, + description: str, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + input_model: type[BaseModel] | None = None, + output_model: type[BaseModel] | None = None, + supports_stream: bool = False, + cancelable: bool = False, +) -> Callable +``` + +**示例:** + +```python +from pydantic import BaseModel, Field +from astrbot_sdk.decorators import provide_capability + +class CalculateInput(BaseModel): + x: int = Field(description="第一个数") + y: int = Field(description="第二个数") + +@provide_capability( + "my_plugin.calculate", + description="执行加法计算", + input_model=CalculateInput +) +async def calculate(self, payload: dict, ctx: Context): + x = payload["x"] + y = payload["y"] + return {"result": x + y} +``` + +--- + +## LLM 工具装饰器 + +### @register_llm_tool + +注册 LLM 工具装饰器。 + +**签名:** +```python +def register_llm_tool( + name: str | None = None, + *, + description: str | None = None, + parameters_schema: dict[str, Any] | None = None, + active: bool = True, +) -> Callable +``` + +**示例:** + +```python +from astrbot_sdk.decorators import 
register_llm_tool + +@register_llm_tool() +async def get_weather(self, city: str, unit: str = "celsius"): + """获取指定城市的天气信息""" + return f"{city} 的天气: 25°C" +``` + +### @register_agent + +注册 Agent 装饰器。 + +**签名:** +```python +def register_agent( + name: str, + *, + description: str = "", + tool_names: list[str] | None = None, +) -> Callable +``` + +**示例:** + +```python +from astrbot_sdk.decorators import register_agent +from astrbot_sdk.llm.agents import BaseAgentRunner + +@register_agent("my_agent", description="我的智能助手") +class MyAgent(BaseAgentRunner): + async def run(self, ctx: Context, request) -> Any: + return "agent result" +``` + +--- + +## 最佳实践 + +### 1. 装饰器顺序 + +正确的装饰器顺序很重要: + +```python +@on_command("command") # 1. 事件触发装饰器 +@platforms("qq") # 2. 过滤器装饰器 +@rate_limit(5, 60) # 3. 限制器装饰器 +@require_admin # 4. 修饰器装饰器 +async def my_handler(self, event: MessageEvent, ctx: Context): + pass +``` + +### 2. 错误处理 + +始终实现错误处理: + +```python +@on_command("risky_command") +async def risky_handler(self, event: MessageEvent, ctx: Context): + try: + result = await some_risky_operation() + await event.reply(f"成功: {result}") + except Exception as e: + ctx.logger.error(f"操作失败: {e}") + await event.reply("操作失败,请稍后重试") +``` + +### 3. 类型注解 + +使用类型注解提高代码可读性: + +```python +from typing import Optional + +@on_command("greet") +async def greet_handler( + self, + event: MessageEvent, + ctx: Context +) -> None: + await event.reply("Hello!") +``` + +### 4. 避免常见陷阱 + +**不要混用冲突的装饰器:** + +```python +# 错误 +@on_message(platforms=["qq"]) +@platforms("wechat") # 冲突! +async def handler(...): pass + +# 正确 +@on_message(platforms=["qq", "wechat"]) +async def handler(...): pass +``` + +**不要在非消息处理器使用限制器:** + +```python +# 错误 +@on_event("ready") +@rate_limit(5, 60) # 不支持! 
+async def handler(...): pass + +# 正确 +@on_command("cmd") +@rate_limit(5, 60) +async def handler(...): pass +``` diff --git a/astrbot-sdk/docs/04_star_lifecycle.md b/astrbot-sdk/docs/04_star_lifecycle.md new file mode 100644 index 0000000000..98041db169 --- /dev/null +++ b/astrbot-sdk/docs/04_star_lifecycle.md @@ -0,0 +1,528 @@ +# AstrBot SDK Star 类与生命周期指南 + +## 概述 + +`Star` 是 AstrBot v4 SDK 的原生插件基类,提供了完整的插件生命周期管理、上下文访问和能力集成。 + +## 目录 + +- [Star 类概述](#star-类概述) +- [生命周期流程](#生命周期流程) +- [生命周期钩子](#生命周期钩子) +- [Context 上下文使用](#context-上下文使用) +- [插件元数据访问](#插件元数据访问) +- [错误处理模式](#错误处理模式) +- [最佳实践](#最佳实践) + +--- + +## Star 类概述 + +### 什么是 Star 类? + +`Star` 是所有 v4 原生插件必须继承的基类,提供插件生命周期管理和能力集成。 + +### 核心特性 + +```python +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command, on_message + +class MyPlugin(Star): + """插件类示例""" + + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") +``` + +--- + +## 生命周期流程 + +### 完整生命周期 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 插件加载阶段 │ +├─────────────────────────────────────────────────────────────────┤ +│ 1. 插件发现 (discover_plugins) │ +│ ├─ 扫描插件目录 │ +│ ├─ 读取 plugin.yaml │ +│ └─ 验证组件类 (main:MyPlugin) │ +│ │ +│ 2. 插件加载 │ +│ ├─ 动态导入插件模块 │ +│ ├─ 实例化 Star 子类 │ +│ ├─ 收集 __handlers__ 元组 │ +│ └─ 注册装饰器元数据 │ +│ │ +│ 3. Worker 启动 (PluginWorkerRuntime.start) │ +│ ├─ 向 Core 注册 handlers/capabilities │ +│ └─ 建立通信对等端 │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 插件运行阶段 │ +├─────────────────────────────────────────────────────────────────┤ +│ 4. on_start() 生命周期钩子 │ +│ ├─ 绑定运行时上下文 │ +│ ├─ 调用 on_start(ctx) │ +│ └─ 内部调用 initialize() │ +│ │ +│ 5. 
Handler 事件循环 │ +│ ├─ 等待事件触发 (命令/消息/事件/定时) │ +│ ├─ HandlerDispatcher.invoke() │ +│ ├─ 创建 Context 和 MessageEvent │ +│ ├─ 执行用户 handler │ +│ └─ 处理返回值/异常 │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 插件卸载阶段 │ +├─────────────────────────────────────────────────────────────────┤ +│ 6. on_stop() 生命周期钩子 │ +│ ├─ 调用 on_stop(ctx) │ +│ ├─ 内部调用 terminate() │ +│ ├─ 清理资源 (数据库连接、文件句柄等) │ +│ └─ 重置运行时上下文 │ +│ │ +│ 7. Worker 关闭 │ +│ ├─ 发送 finalize 消息给 Core │ +│ ├─ 关闭通信传输层 │ +│ └─ 退出子进程 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 生命周期钩子 + +### 1. on_start() - 插件启动钩子 + +**触发时机**:Worker 启动后,在开始处理事件之前调用 + +**参数:** +- `ctx: Any | None` - 运行时上下文(通常为 Context 实例) + +**用途:** +- 初始化数据库连接 +- 加载配置文件 +- 注册 LLM 工具 +- 启动后台任务 + +**最佳实践:** +- `on_start()` 里只做初始化、能力注册和轻量状态恢复 +- 需要长期保存的应是配置值、句柄、任务引用,不要把 `ctx` 实例长期挂到 `self` +- 如果要和 AstrBot 原生 persona / conversation 协作,优先在这里校验或创建所需资源 + +**示例:** + +```python +class MyPlugin(Star): + async def on_start(self, ctx: Any | None = None) -> None: + """插件启动时调用""" + await super().on_start(ctx) + + # 加载配置 + config = await ctx.metadata.get_plugin_config() + self.api_key = config.get("api_key", "") + + # 注册 LLM 工具 + await ctx.register_llm_tool( + name="search", + parameters_schema={...}, + desc="搜索信息", + func_obj=self.search_tool + ) + + # 启动后台任务 + await ctx.register_task( + self.background_sync(), + desc="后台数据同步" + ) +``` + +### 2. 
on_stop() - 插件停止钩子 + +**触发时机**:插件卸载或程序关闭前调用 + +**用途:** +- 关闭数据库连接 +- 清理临时文件 +- 注销 LLM 工具 +- 保存状态数据 + +**最佳实践:** +- 在 `on_stop()` 中释放 `on_start()` 注册的任务、监听器和外部资源 +- 把需要持久化的状态尽量提前落库,不要把关键保存逻辑完全依赖在进程退出瞬间 +- 始终把收到的 `ctx` 继续传给 `super().on_stop(ctx)`,不要手动丢掉它 + +**示例:** + +```python +class MyPlugin(Star): + async def on_stop(self, ctx: Any | None = None) -> None: + """插件停止时调用""" + # 保存状态 + await self.put_kv_data("last_shutdown", time.time()) + + # 确保 terminate 被调用 + await super().on_stop(ctx) +``` + +### 3. initialize() - 初始化钩子 + +**触发时机**:`on_start()` 内部自动调用 + +**用途:** +- 插件级别的初始化逻辑 +- 不依赖 Context 的初始化 + +**示例:** + +```python +class MyPlugin(Star): + async def initialize(self) -> None: + """初始化插件""" + self._cache = {} + self._counter = 0 +``` + +### 4. terminate() - 终止钩子 + +**触发时机**:`on_stop()` 内部自动调用 + +**用途:** +- 插件级别的清理逻辑 +- 不依赖 Context 的清理 + +**示例:** + +```python +class MyPlugin(Star): + async def terminate(self) -> None: + """清理插件资源""" + self._cache.clear() + self.state = "stopped" +``` + +### 5. 
on_error() - 错误处理钩子
+
+**触发时机**:任何 Handler 执行抛出异常时
+
+**参数:**
+- `error: Exception` - 捕获的异常
+- `event` - 事件对象
+- `ctx` - 上下文对象
+
+**示例:**
+
+```python
+class MyPlugin(Star):
+    async def on_error(self, error: Exception, event, ctx) -> None:
+        """自定义错误处理"""
+        from astrbot_sdk.errors import AstrBotError
+
+        if isinstance(error, AstrBotError):
+            await event.reply(error.hint or error.message)
+        elif isinstance(error, ValueError):
+            await event.reply(f"参数错误:{error}")
+        else:
+            await event.reply(f"发生错误: {type(error).__name__}")
+
+        ctx.logger.error(f"Handler error: {error}", exc_info=error)
+```
+
+---
+
+## Context 上下文使用
+
+### 在 Handler 中访问
+
+```python
+class MyPlugin(Star):
+    @on_command("test")
+    async def test_handler(self, event: MessageEvent, ctx: Context):
+        # Context 通过参数注入
+        await ctx.db.set("key", "value")
+        await event.reply("Done")
+```
+
+### 在生命周期钩子中访问
+
+```python
+class MyPlugin(Star):
+    async def on_start(self, ctx):
+        # 生命周期钩子中的 Context
+        config = await ctx.metadata.get_plugin_config()
+```
+
+---
+
+## 插件元数据访问
+
+### plugin.yaml 配置
+
+```yaml
+_schema_version: 2
+name: my_plugin
+author: your_name
+version: 1.0.0
+desc: 我的插件描述
+repo: https://github.com/user/repo
+logo: logo.png
+
+runtime:
+  python: "3.12"
+
+components:
+  - class: main:MyPlugin
+
+support_platforms:
+  - aiocqhttp
+  - telegram
+
+astrbot_version: ">=4.13.0,<5.0.0"
+```
+
+### StarMetadata 类
+
+插件元数据 dataclass,描述插件的基本信息。
+
+```python
+from astrbot_sdk import StarMetadata
+
+@dataclass
+class StarMetadata:
+    name: str # 插件名称(唯一标识)
+    display_name: str # 显示名称
+    description: str # 插件描述
+    author: str # 作者
+    version: str # 版本号
+    enabled: bool = True # 是否启用
+    support_platforms: list[str] = field(default_factory=list) # 支持的平台列表
+    astrbot_version: str | None = None # 兼容的 AstrBot 版本范围
+```
+
+**使用示例:**
+
+```python
+from astrbot_sdk import Star, StarMetadata
+
+class MyPlugin(Star):
+    async def on_start(self, ctx):
+        # 获取当前插件元数据
+        metadata: StarMetadata = await ctx.metadata.get_current_plugin()
+
+        print(f"插件名称: {metadata.name}")
+        
print(f"显示名称: {metadata.display_name}") + print(f"版本: {metadata.version}") + print(f"作者: {metadata.author}") + print(f"支持平台: {', '.join(metadata.support_platforms)}") + + # 检查兼容性 + if metadata.astrbot_version: + print(f"兼容版本: {metadata.astrbot_version}") +``` + +### PluginMetadata 类 + +`StarMetadata` 的别名,功能完全相同。 + +```python +from astrbot_sdk import PluginMetadata + +# PluginMetadata 是 StarMetadata 的别名 +# 两者可以互换使用 +metadata: PluginMetadata = await ctx.metadata.get_current_plugin() +``` + +**建议**:使用 `StarMetadata` 以符合 v4 SDK 的命名规范。 + +### 访问元数据 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 获取当前插件元数据 + my_metadata = await ctx.metadata.get_current_plugin() + print(f"Starting {my_metadata.name} v{my_metadata.version}") + + # 获取其他插件元数据 + other_metadata = await ctx.metadata.get_plugin("other_plugin") + if other_metadata: + print(f"依赖插件版本: {other_metadata.version}") +``` + +--- + +## 错误处理模式 + +### 标准错误类型 + +```python +from astrbot_sdk.errors import AstrBotError + +# 1. 输入无效错误 +raise AstrBotError.invalid_input( + "参数格式错误", + hint="请使用 JSON 格式" +) + +# 2. 能力未找到错误 +raise AstrBotError.capability_not_found("unknown_capability") + +# 3. 网络错误 +raise AstrBotError.network_error( + "连接超时", + hint="请检查网络连接" +) +``` + +### 在 Handler 中捕获错误 + +```python +class MyPlugin(Star): + @on_command("risky_operation") + async def risky(self, event: MessageEvent, ctx: Context): + try: + result = await self.risky_operation() + await event.reply(f"成功: {result}") + except ValueError as e: + await event.reply(f"参数错误: {e}") + except ConnectionError as e: + ctx.logger.error(f"Network error: {e}") + await event.reply("网络连接失败") + except Exception as e: + ctx.logger.exception("Unexpected error") + raise +``` + +--- + +## 最佳实践 + +### 1. 插件结构 + +``` +my_plugin/ +├── plugin.yaml # 插件配置 +├── main.py # 主入口 +├── handlers/ # 处理器模块 +├── utils/ # 工具函数 +├── requirements.txt # 可选的 Python 依赖 +└── README.md # 说明文档 +``` + +### 2. 
插件模板 + +```python +""" +插件说明 +""" + +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command, on_message + +class MyPlugin(Star): + """插件类""" + + async def initialize(self) -> None: + """初始化""" + self._cache = {} + self._counter = 0 + + async def on_start(self, ctx) -> None: + """启动时调用""" + await super().on_start(ctx) + + # 加载配置 + config = await ctx.metadata.get_plugin_config() + self.setting = config.get("setting", "default") + + # 注册工具 + await ctx.register_llm_tool( + name="my_tool", + parameters_schema={...}, + desc="我的工具", + func_obj=self.my_tool + ) + + ctx.logger.info(f"{ctx.plugin_id} started") + + async def on_stop(self, ctx) -> None: + """停止时调用""" + # 保存状态 + await self.put_kv_data("counter", self._counter) + await super().on_stop(ctx) + ctx.logger.info(f"{ctx.plugin_id} stopped") + + @on_command("hello", aliases=["hi"]) + async def hello(self, event: MessageEvent, ctx: Context) -> None: + """打招呼命令""" + await event.reply(f"你好,{event.sender_name}!") + + async def my_tool(self, param: str) -> str: + """LLM 工具实现""" + return f"处理结果: {param}" +``` + +### 3. 配置管理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 获取配置 + config = await ctx.metadata.get_plugin_config() + + # 提供默认值 + self.timeout = config.get("timeout", 30) + self.max_retries = config.get("max_retries", 3) + self.debug = config.get("debug", False) + + # 验证必需配置 + if "api_key" not in config: + raise ValueError("缺少必需配置: api_key") + + self.api_key = config["api_key"] +``` + +### 4. 数据持久化 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 加载状态 + self.last_update = await self.get_kv_data("last_update", 0) + self.user_data = await self.get_kv_data("users", {}) + + async def save_state(self): + # 保存状态 + await self.put_kv_data("last_update", time.time()) + await self.put_kv_data("users", self.user_data) +``` + +### 5. 
资源清理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 创建需要清理的资源 + self._session = aiohttp.ClientSession() + self._task = asyncio.create_task(self.background_task()) + + async def on_stop(self, ctx): + # 清理资源 + if hasattr(self, '_task'): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + if hasattr(self, '_session'): + await self._session.close() +``` diff --git a/astrbot-sdk/docs/05_clients.md b/astrbot-sdk/docs/05_clients.md new file mode 100644 index 0000000000..68d867038b --- /dev/null +++ b/astrbot-sdk/docs/05_clients.md @@ -0,0 +1,484 @@ +# AstrBot SDK 常用客户端速查 + +## 概述 + +本文档聚焦插件开发中最常用的客户端与使用模式,方便快速查阅。完整的方法签名、返回类型和全部客户端/管理器列表请查看 [API 详细参考](./api/clients.md)。 + +## 目录 + +- [LLMClient - AI 对话客户端](#1-llmclient---ai-对话客户端) +- [MemoryClient - 记忆存储客户端](#2-memoryclient---记忆存储客户端) +- [DBClient - KV 数据库客户端](#3-dbclient---kv-数据库客户端) +- [PlatformClient - 平台消息客户端](#4-platformclient---平台消息客户端) +- [FileServiceClient - 文件服务客户端](#5-fileserviceclient---文件服务客户端) +- [HTTPClient - HTTP API 客户端](#6-httpclient---http-api-客户端) +- [MetadataClient - 插件元数据客户端](#7-metadataclient---插件元数据客户端) +- [其他客户端与管理器](#8-其他客户端与管理器) + +--- + +## 1. LLMClient - AI 对话客户端 + +### 导入 + +```python +from astrbot_sdk.clients import LLMClient, ChatMessage, LLMResponse +``` + +### 方法 + +#### chat() + +简单对话。 + +```python +reply = await ctx.llm.chat("你好,介绍一下自己") +``` + +#### chat_raw() + +获取完整响应。 + +```python +response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) +print(f"Token 使用: {response.usage}") +``` + +#### stream_chat() + +流式对话。 + +```python +async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="") +``` + +--- + +## 2. 
MemoryClient - 记忆存储客户端 + +### 导入 + +```python +from astrbot_sdk.clients import MemoryClient +``` + +### 方法 + +#### search() + +搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。 + +```python +results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5) +for item in results: + print(item["key"], item["score"], item["match_type"]) +``` + +#### save() + +保存记忆。 + +```python +await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"}) +await ctx.memory.save( + "profile:alice", + {"name": "Alice", "embedding_text": "Alice 喜欢蓝色和海边"}, +) +``` + +#### get() + +获取记忆。 + +```python +pref = await ctx.memory.get("user_pref") +``` + +#### save_with_ttl() + +保存带过期时间的记忆。 + +```python +await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting"}, + ttl_seconds=3600 +) +``` + +#### delete() + +删除记忆。 + +```python +await ctx.memory.delete("old_note") +``` + +#### stats() + +查看记忆索引状态。 + +```python +stats = await ctx.memory.stats() +print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items")) +``` + +--- + +## 3. DBClient - KV 数据库客户端 + +`ctx.db` 的 key 在运行时会自动按插件做命名空间隔离。`list()` 和 `watch()` 返回给插件的 +仍是原始 key 视图,不会暴露内部前缀。 + +### 导入 + +```python +from astrbot_sdk.clients import DBClient +``` + +### 方法 + +#### get() / set() + +基本读写。 + +```python +data = await ctx.db.get("user_settings") +await ctx.db.set("user_settings", {"theme": "dark"}) +``` + +#### delete() + +删除数据。 + +```python +await ctx.db.delete("user_settings") +``` + +#### list() + +列出键。 + +```python +keys = await ctx.db.list("user_") +``` + +#### get_many() / set_many() + +批量操作。 + +```python +values = await ctx.db.get_many(["user:1", "user:2"]) +await ctx.db.set_many({"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}}) +``` + +#### watch() + +监听变更。 + +```python +async for event in ctx.db.watch("user:"): + print(event["op"], event["key"]) +``` + +--- + +## 4. 
PlatformClient - 平台消息客户端 + +### 导入 + +```python +from astrbot_sdk.clients import PlatformClient +``` + +### 方法 + +#### send() + +发送文本消息。 + +```python +await ctx.platform.send("qq:group:123456", "大家好!") +``` + +#### send_image() + +发送图片。 + +```python +await ctx.platform.send_image(event.session_id, "https://example.com/image.png") +``` + +#### send_chain() + +发送消息链。 + +```python +from astrbot_sdk.message_components import Plain, Image + +chain = [Plain("文字"), Image(url="https://example.com/img.jpg")] +await ctx.platform.send_chain(event.session_id, chain) +``` + +#### send_by_id() + +通过 ID 发送。 + +```python +await ctx.platform.send_by_id( + platform_id="qq", + session_id="user123", + content="Hello", + message_type="private" +) +``` + +#### get_members() + +获取群成员。 + +```python +members = await ctx.platform.get_members("qq:group:123456") +``` + +--- + +## 5. FileServiceClient - 文件服务客户端 + +### 导入 + +```python +from astrbot_sdk.clients import FileServiceClient +``` + +### 方法 + +#### register_file() + +注册文件。 + +```python +token = await ctx.files.register_file("/path/to/file.jpg", timeout=3600) +``` + +#### handle_file() + +解析令牌。 + +```python +path = await ctx.files.handle_file(token) +``` + +--- + +## 6. 
HTTPClient - HTTP API 客户端 + +### 导入 + +```python +from astrbot_sdk.clients import HTTPClient +from astrbot_sdk.decorators import provide_capability +``` + +### 方法 + +当前实现会拦截包含 `..` 的路径和部分明显非法输入,但路由校验并非完全严格。 +文档示例建议统一使用以 `/` 开头、没有重复斜杠的规范化路径。`unregister_api(route)` 在不传 +`methods` 时会移除当前插件在该 route 下注册的全部方法。 + +#### register_api() + +注册 API。 + +```python +@provide_capability( + name="my_plugin.http_handler", + description="处理 HTTP 请求" +) +async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + +await ctx.http.register_api( + route="/my-api", + handler=handle_http_request, + methods=["GET", "POST"] +) +``` + +#### unregister_api() + +注销 API。 + +```python +await ctx.http.unregister_api("/my-api") +``` + +#### list_apis() + +列出 API。 + +```python +apis = await ctx.http.list_apis() +``` + +--- + +## 7. MetadataClient - 插件元数据客户端 + +### 导入 + +```python +from astrbot_sdk.clients import MetadataClient +``` + +### 方法 + +#### get_plugin() + +获取插件信息。 + +```python +plugin = await ctx.metadata.get_plugin("another_plugin") +if plugin: + print(f"插件: {plugin.display_name}") +``` + +#### list_plugins() + +列出所有插件。 + +```python +plugins = await ctx.metadata.list_plugins() +``` + +#### get_current_plugin() + +获取当前插件。 + +```python +current = await ctx.metadata.get_current_plugin() +``` + +#### get_plugin_config() + +获取配置。 + +```python +config = await ctx.metadata.get_plugin_config() +api_key = config.get("api_key") +``` + +--- + +## 8. 
其他客户端与管理器 + +下列客户端也属于 `Context` 的公开能力入口。这里给出用途和详细参考入口,避免常用速查页与完整 API 文档重复维护。 + +- [ProviderClient](./api/clients.md#providerclient---provider-发现客户端): 查询当前可用 Provider,以及当前会话正在使用的 chat / tts / stt Provider。 +- [ProviderManagerClient](./api/clients.md#providermanagerclient---provider-管理客户端): 动态创建、切换、更新、删除 Provider,并监听 Provider 变更。 +- [PersonaManagerClient](./api/clients.md#personamanagerclient---人格管理客户端): 管理人格模板;在 `Context` 中可通过 `ctx.personas` 或 `ctx.persona_manager` 访问。 +- [ConversationManagerClient](./api/clients.md#conversationmanagerclient---对话管理客户端): 管理会话内的多轮对话;在 `Context` 中可通过 `ctx.conversations` 或 `ctx.conversation_manager` 访问。 +- [MessageHistoryManagerClient](./api/clients.md#messagehistorymanagerclient---消息历史管理客户端): 按 `MessageSession` 精确保存消息组件、发送者和元数据;在 `Context` 中可通过 `ctx.message_history` 或 `ctx.message_history_manager` 访问。 +- [KnowledgeBaseManagerClient](./api/clients.md#knowledgebasemanagerclient---知识库管理客户端): 管理知识库、文档和检索;在 `Context` 中可通过 `ctx.kbs` 或 `ctx.kb_manager` 访问。 +- [RegistryClient](./api/clients.md#registryclient---handler-注册表客户端): 查询 handler 元数据,并管理 handler 白名单。 +- [SkillClient](./api/clients.md#skillclient---技能注册客户端): 在运行时注册、注销和列出插件技能目录。 +- [SessionPluginManager](./api/clients.md#sessionpluginmanager---会话插件管理器): 按会话检查插件启用状态并过滤 handler。 +- [SessionServiceManager](./api/clients.md#sessionservicemanager---会话服务管理器): 按会话控制 LLM/TTS 是否启用。 + +--- + +## 客户端使用示例 + +### 1. 基本对话流程 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + reply = await ctx.llm.chat(event.message_content) + await ctx.platform.send(event.session_id, reply) +``` + +### 2. 
带历史的对话 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + history_data = await ctx.memory.get(f"history:{event.session_id}") + history = history_data.get("messages", []) if history_data else [] + + reply = await ctx.llm.chat(event.message_content, history=history) + + history.append(ChatMessage(role="user", content=event.message_content)) + history.append(ChatMessage(role="assistant", content=reply)) + await ctx.memory.save(f"history:{event.session_id}", {"messages": history}) + + await ctx.platform.send(event.session_id, reply) +``` + +如果你要保存原始消息链、发送者信息或需要分页清理,可以改用 `ctx.message_history`: + +```python +from astrbot_sdk import MessageHistorySender, MessageSession, Plain + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +await ctx.message_history.append( + session, + parts=[Plain(event.message_content, convert=False)], + sender=MessageHistorySender( + sender_id=event.sender_id, + sender_name=event.sender_name, + ), +) +``` + +### 3. 使用数据库持久化 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + config = await ctx.db.get(f"user_config:{event.sender_id}") + + if not config: + config = {"theme": "light", "lang": "zh"} + await ctx.db.set(f"user_config:{event.sender_id}", config) + + reply = f"你的主题设置是: {config['theme']}" + await ctx.platform.send(event.session_id, reply) +``` + +### 4. 注册 Web API + +```python +@provide_capability( + name="my_plugin.get_status", + description="获取插件状态", +) +async def get_status(request_id: str, payload: dict, cancel_token): + return {"status": "running", "version": "1.0.0"} + +@on_command("setup_api") +async def setup_api(event: MessageEvent, ctx: Context): + await ctx.http.register_api( + route="/status", + handler=get_status, + methods=["GET"] + ) + await ctx.platform.send(event.session_id, "API 已注册") +``` + +--- + +## 注意事项 + +1. 所有客户端方法都是异步的 +2. 远程调用可能失败,建议使用 try-except +3. 
`Memory` 适合语义检索,`DB` 适合结构化 KV,`MessageHistory` 适合精确保存原始消息记录 +4. `DBClient` 的 key 对插件隔离;`list()` 和 `watch()` 返回的 key 仍是插件本地视图 +5. `HTTPClient.register_api()` 当前会拦截 `..` 等明显非法路径,但仍建议插件自行使用规范化 route;`unregister_api(route)` 默认移除该 route 下全部方法 +6. 文件操作使用 file service 注册令牌 +7. 平台标识使用 UMO 格式:`"platform:instance:session_id"` diff --git a/astrbot-sdk/docs/06_error_handling.md b/astrbot-sdk/docs/06_error_handling.md new file mode 100644 index 0000000000..844b6ed3c1 --- /dev/null +++ b/astrbot-sdk/docs/06_error_handling.md @@ -0,0 +1,625 @@ +# AstrBot SDK 错误处理与调试指南 + +本文档详细介绍 SDK 中的错误处理机制、错误类型、调试技巧和常见问题解决方案。 + +## 目录 + +- [错误处理概述](#错误处理概述) +- [AstrBotError 错误体系](#astrboterror-错误体系) +- [错误码参考](#错误码参考) +- [错误处理模式](#错误处理模式) +- [调试技巧](#调试技巧) +- [常见问题](#常见问题) + +--- + +## 错误处理概述 + +AstrBot SDK 使用统一的错误体系 `AstrBotError`,支持跨进程传递(通过 to_payload/from_payload 序列化)。 + +### 错误处理流程 + +``` +1. 运行时抛出 AstrBotError 子类或实例 +2. 错误被捕获并序列化为 payload +3. 跨进程传输后反序列化 +4. 在 on_error 钩子中统一处理 +``` + +### 基本使用 + +```python +from astrbot_sdk.errors import AstrBotError, ErrorCodes + +# 抛出错误 +raise AstrBotError.invalid_input("参数不能为空") + +# 捕获并处理 +try: + await some_operation() +except AstrBotError as e: + if e.retryable: + # 可重试的错误 + await retry() + else: + # 不可重试的错误 + await event.reply(e.hint or e.message) +``` + +--- + +## AstrBotError 错误体系 + +### AstrBotError 类 + +```python +@dataclass(slots=True) +class AstrBotError(Exception): + code: str # 错误码 + message: str # 错误消息(面向开发者) + hint: str = "" # 用户提示(面向终端用户) + retryable: bool = False # 是否可重试 + docs_url: str = "" # 文档链接 + details: dict[str, Any] | None = None # 详细信息 +``` + +### 工厂方法 + +#### 1. invalid_input - 输入无效错误 + +**场景**:参数格式错误、缺少必需参数等 + +```python +raise AstrBotError.invalid_input( + message="参数格式错误", + hint="请使用 JSON 格式", + docs_url="https://docs.example.com/api" +) +``` + +**属性**: +- `retryable`: False +- 应该在修复输入后重试 + +#### 2. 
capability_not_found - 能力未找到 + +**场景**:调用的 capability 不存在或未注册 + +```python +raise AstrBotError.capability_not_found("unknown_capability") +``` + +**属性**: +- `retryable`: False +- 通常是配置或版本不匹配问题 + +#### 3. network_error - 网络错误 + +**场景**:连接超时、DNS 解析失败等 + +```python +raise AstrBotError.network_error( + message="连接超时", + hint="请检查网络连接后重试" +) +``` + +**属性**: +- `retryable`: True +- 通常可以重试 + +#### 4. internal_error - 内部错误 + +**场景**:SDK 或 Core 内部错误 + +```python +raise AstrBotError.internal_error( + message="数据库连接失败", + hint="请联系插件作者" +) +``` + +**属性**: +- `retryable`: False +- 需要开发者介入 + +#### 5. cancelled - 取消错误 + +**场景**:操作被取消 + +```python +raise AstrBotError.cancelled("用户取消了操作") +``` + +**属性**: +- `retryable`: False + +#### 6. protocol_version_mismatch - 协议版本不匹配 + +**场景**:SDK 和 Core 协议版本不兼容 + +```python +raise AstrBotError.protocol_version_mismatch("协议版本不匹配: v4 vs v5") +``` + +**属性**: +- `retryable`: False +- 需要升级 SDK 或 Core + +--- + +## 错误码参考 + +### 不可立即自动重试错误(retryable=False) + +这些错误不适合框架做“立刻重试”的自动恢复;其中 `RATE_LIMITED` 和 +`COOLDOWN_ACTIVE` 仍然可以在等待窗口结束后由用户或插件重新发起调用。 + +| 错误码 | 说明 | 处理方式 | +|--------|------|----------| +| `LLM_NOT_CONFIGURED` | LLM 未配置 | 配置 LLM Provider | +| `CAPABILITY_NOT_FOUND` | 能力未找到 | 检查 capability 名称 | +| `PERMISSION_DENIED` | 权限不足 | 检查用户权限 | +| `LLM_ERROR` | LLM 错误 | 查看详细错误信息 | +| `INVALID_INPUT` | 输入无效 | 修正输入参数 | +| `CANCELLED` | 操作被取消 | 无需处理 | +| `PROTOCOL_VERSION_MISMATCH` | 协议版本不匹配 | 升级 SDK | +| `PROTOCOL_ERROR` | 协议错误 | 检查实现 | +| `INTERNAL_ERROR` | 内部错误 | 联系开发者 | +| `RATE_LIMITED` | 速率限制 | 等待速率窗口结束后再重试 | +| `COOLDOWN_ACTIVE` | 冷却中 | 等待冷却结束后再重试 | + +### 可重试错误(retryable=True) + +| 错误码 | 说明 | 处理方式 | +|--------|------|----------| +| `CAPABILITY_TIMEOUT` | 能力调用超时 | 重试或增加超时时间 | +| `NETWORK_ERROR` | 网络错误 | 重试 | +| `LLM_TEMPORARY_ERROR` | LLM 临时错误 | 重试 | + +--- + +## 对话相关异常 + +### ConversationClosed + +对话已关闭异常。 + +**场景**:会话被显式关闭或超时时抛出 + +```python +from astrbot_sdk.conversation import ConversationClosed + +@conversation_command("demo") +async def 
demo_handler(self, event, ctx, session): + try: + # 处理对话... + session.close() # 关闭会话 + except ConversationClosed: + await event.reply("对话已结束") +``` + +**属性**: +- 继承自 `RuntimeError` +- 表示对话会话已结束,无法再接收消息 + +### ConversationReplaced + +对话被替换异常。 + +**场景**:用户开始新对话,当前对话被替换时抛出 + +```python +from astrbot_sdk.conversation import ConversationReplaced + +@conversation_command("survey") +async def survey_handler(self, event, ctx, session): + try: + # 处理对话... + pass + except ConversationReplaced: + # 用户开始了新对话 + await event.reply("已切换到新对话") +``` + +**属性**: +- 继承自 `RuntimeError` +- 表示当前对话被新对话替换 + +--- + +## 错误处理模式 + +### 模式 1:基本错误处理 + +```python +@on_command("risky") +async def risky_handler(self, event: MessageEvent, ctx: Context): + try: + result = await risky_operation() + await event.reply(f"成功: {result}") + except AstrBotError as e: + # SDK 错误包含用户友好的提示 + await event.reply(e.hint or e.message) + ctx.logger.error(f"操作失败: {e}") + except Exception as e: + # 未知错误 + ctx.logger.exception("未知错误") + await event.reply("操作失败,请稍后重试") +``` + +### 模式 2:分层错误处理 + +```python +async def fetch_data(ctx: Context, url: str) -> dict: + """获取数据,处理网络错误""" + try: + return await ctx.http.get(url) + except AstrBotError as e: + if e.code == ErrorCodes.NETWORK_ERROR: + # 网络错误可以重试 + ctx.logger.warning(f"网络错误,重试: {e}") + await asyncio.sleep(1) + return await ctx.http.get(url) + raise + +@on_command("data") +async def data_handler(self, event: MessageEvent, ctx: Context): + try: + data = await self.fetch_data(ctx, "https://api.example.com/data") + await event.reply(f"数据: {data}") + except AstrBotError as e: + if e.retryable: + await event.reply(f"暂时无法获取数据,请稍后重试") + else: + await event.reply(f"获取数据失败: {e.hint}") +``` + +### 模式 3:on_error 生命周期钩子 + +```python +class MyPlugin(Star): + async def on_error(self, error: Exception, event, ctx) -> None: + """统一错误处理""" + from astrbot_sdk.errors import AstrBotError + + if isinstance(error, AstrBotError): + # SDK 错误 + if error.code == ErrorCodes.RATE_LIMITED: + await 
event.reply("操作过于频繁,请稍后再试") + elif error.code == ErrorCodes.PERMISSION_DENIED: + await event.reply("你没有权限执行此操作") + else: + await event.reply(error.hint or "操作失败") + elif isinstance(error, ValueError): + # 参数错误 + await event.reply(f"参数错误: {error}") + else: + # 未知错误 + ctx.logger.exception("未处理的错误") + await event.reply("发生未知错误,请联系管理员") +``` + +### 模式 4:重试机制 + +```python +from astrbot_sdk.errors import AstrBotError, ErrorCodes + +async def with_retry( + operation, + max_retries: int = 3, + delay: float = 1.0 +): + """带重试的操作""" + last_error = None + + for attempt in range(max_retries): + try: + return await operation() + except AstrBotError as e: + last_error = e + if not e.retryable: + raise # 不可重试错误直接抛出 + + ctx.logger.warning(f"第 {attempt + 1} 次尝试失败: {e}") + if attempt < max_retries - 1: + await asyncio.sleep(delay * (attempt + 1)) # 指数退避 + + raise last_error + +# 使用 +@on_command("fetch") +async def fetch_handler(self, event: MessageEvent, ctx: Context): + try: + result = await with_retry( + lambda: ctx.llm.chat("生成内容"), + max_retries=3 + ) + await event.reply(result) + except AstrBotError as e: + await event.reply(f"请求失败: {e.hint}") +``` + +### 模式 5:取消处理 + +```python +@on_command("long_task") +async def long_task_handler(self, event: MessageEvent, ctx: Context): + try: + for i in range(100): + # 检查是否取消 + ctx.cancel_token.raise_if_cancelled() + + await do_work(i) + await asyncio.sleep(0.1) + + await event.reply("任务完成") + except asyncio.CancelledError: + await event.reply("任务已取消") + raise # 重新抛出以便框架处理 + except AstrBotError as e: + if e.code == ErrorCodes.CANCELLED: + await event.reply("操作已取消") + else: + raise +``` + +--- + +## 调试技巧 + +### 1. 
启用详细日志 + +```python +# 在插件中记录详细日志 +@on_command("debug") +async def debug_handler(self, event: MessageEvent, ctx: Context): + ctx.logger.debug(f"收到消息: {event.text}") + ctx.logger.debug(f"用户ID: {event.user_id}") + ctx.logger.debug(f"会话ID: {event.session_id}") + ctx.logger.debug(f"平台: {event.platform}") + + # 记录组件信息 + components = event.get_messages() + for comp in components: + ctx.logger.debug(f"组件: {comp.type} - {comp}") +``` + +### 2. 使用测试框架调试 + +```python +from astrbot_sdk.testing import PluginTestHarness + +async def test_with_debug(): + harness = PluginTestHarness() + plugin = harness.load_plugin("my_plugin.main:MyPlugin") + + # 启用详细日志 + harness.enable_debug_logging() + + # 模拟事件 + result = await harness.simulate_command("/hello") + print(f"结果: {result}") + + # 查看调用历史 + for call in harness.get_call_history(): + print(f"调用: {call}") +``` + +### 3. 使用 PDB 调试 + +```python +import pdb + +@on_command("debug") +async def debug_handler(self, event: MessageEvent, ctx: Context): + # 设置断点 + pdb.set_trace() + + result = await ctx.llm.chat("测试") + await event.reply(result) +``` + +### 4. 记录完整错误信息 + +```python +import traceback + +@on_command("risky") +async def risky_handler(self, event: MessageEvent, ctx: Context): + try: + result = await risky_operation() + await event.reply(f"成功: {result}") + except Exception as e: + # 记录完整堆栈 + ctx.logger.error(f"错误: {e}") + ctx.logger.error(f"堆栈: {traceback.format_exc()}") + + # 发送简化信息给用户 + await event.reply("操作失败,请查看日志") +``` + +### 5. 
使用 Context 的 cancel_token 调试 + +```python +@on_command("timeout_test") +async def timeout_test(self, event: MessageEvent, ctx: Context): + ctx.logger.info(f"取消状态: {ctx.cancel_token.cancelled}") + + try: + # 长时间运行的操作 + for i in range(10): + ctx.logger.debug(f"步骤 {i}, 取消状态: {ctx.cancel_token.cancelled}") + ctx.cancel_token.raise_if_cancelled() + await asyncio.sleep(1) + + await event.reply("完成") + except asyncio.CancelledError: + ctx.logger.info("操作被取消") + raise +``` + +--- + +## 常见问题 + +### Q1: 如何处理 "CAPABILITY_NOT_FOUND" 错误? + +**原因**:调用的 capability 不存在或未注册 + +**解决方案**: +```python +# 检查 Core 版本是否支持 +# 确认 capability 名称正确 +# 检查插件是否正确加载 + +try: + result = await ctx._proxy.call("unknown.capability", {}) +except AstrBotError as e: + if e.code == ErrorCodes.CAPABILITY_NOT_FOUND: + ctx.logger.error("当前 AstrBot 版本不支持此功能") + await event.reply("请升级 AstrBot 到最新版本") +``` + +### Q2: 如何处理速率限制? + +**解决方案**: +```python +from astrbot_sdk.errors import ErrorCodes + +@on_command("api_call") +async def api_call_handler(self, event: MessageEvent, ctx: Context): + try: + result = await call_api() + await event.reply(result) + except AstrBotError as e: + if e.code == ErrorCodes.RATE_LIMITED: + # 获取重试时间(如果有) + retry_after = e.details.get("retry_after", 60) + await event.reply(f"操作过于频繁,请 {retry_after} 秒后再试") + else: + raise +``` + +### Q3: 如何区分用户错误和系统错误? + +**解决方案**: +```python +@on_command("process") +async def process_handler(self, event: MessageEvent, ctx: Context): + try: + result = await process(event.text) + await event.reply(result) + except AstrBotError as e: + if e.code in { + ErrorCodes.INVALID_INPUT, + ErrorCodes.PERMISSION_DENIED + }: + # 用户错误,直接提示 + await event.reply(e.hint or e.message) + else: + # 系统错误,记录并提示 + ctx.logger.error(f"系统错误: {e}") + await event.reply("系统错误,请稍后重试") +``` + +### Q4: 如何在 on_error 中避免无限循环? 
+ +**注意**:如果 `on_error` 中抛出异常,会导致递归调用 + +**解决方案**: +```python +class MyPlugin(Star): + async def on_error(self, error: Exception, event, ctx) -> None: + try: + # 错误处理逻辑 + await event.reply("发生错误") + except Exception as e: + # 避免递归,只记录不回复 + ctx.logger.exception("on_error 失败") +``` + +### Q5: 如何调试跨进程通信问题? + +**解决方案**: +```python +# 启用 SDK 调试日志 +import logging +logging.getLogger("astrbot_sdk").setLevel(logging.DEBUG) + +# 在关键位置添加日志 +@on_command("debug_comm") +async def debug_comm_handler(self, event: MessageEvent, ctx: Context): + ctx.logger.debug("开始调用 capability") + + try: + result = await ctx._proxy.call("test.capability", {"key": "value"}) + ctx.logger.debug(f"调用成功: {result}") + except Exception as e: + ctx.logger.error(f"调用失败: {e}") + raise +``` + +--- + +## 最佳实践 + +### 1. 始终处理可重试错误 + +```python +# 好的做法 +async def reliable_operation(ctx): + max_retries = 3 + for i in range(max_retries): + try: + return await ctx.llm.chat("prompt") + except AstrBotError as e: + if e.retryable and i < max_retries - 1: + await asyncio.sleep(2 ** i) # 指数退避 + else: + raise +``` + +### 2. 提供用户友好的错误提示 + +```python +# 好的做法 +try: + result = await operation() +except AstrBotError as e: + # 使用 SDK 提供的 hint + await event.reply(e.hint or "操作失败,请稍后重试") +``` + +### 3. 区分日志级别 + +```python +# 好的做法 +try: + result = await operation() +except AstrBotError as e: + if e.retryable: + ctx.logger.warning(f"临时错误: {e}") + else: + ctx.logger.error(f"严重错误: {e}") +``` + +### 4. 
在 on_stop 中处理清理错误 + +```python +class MyPlugin(Star): + async def on_stop(self, ctx): + try: + await self.cleanup() + except Exception as e: + # 清理错误不应阻止停止流程 + ctx.logger.error(f"清理失败: {e}") +``` + +--- + +## 相关文档 + +- [Context API 参考](./01_context_api.md) +- [Star 类与生命周期](./04_star_lifecycle.md) +- [高级主题](./07_advanced_topics.md) diff --git a/astrbot-sdk/docs/07_advanced_topics.md b/astrbot-sdk/docs/07_advanced_topics.md new file mode 100644 index 0000000000..339f1fba39 --- /dev/null +++ b/astrbot-sdk/docs/07_advanced_topics.md @@ -0,0 +1,585 @@ +# AstrBot SDK 高级主题 + +本文档介绍 AstrBot SDK 的高级用法,包括并发处理、性能优化、安全最佳实践和架构设计。 + +## 目录 + +- [并发处理](#并发处理) +- [性能优化](#性能优化) +- [安全最佳实践](#安全最佳实践) +- [架构设计模式](#架构设计模式) +- [高级客户端用法](#高级客户端用法) + +--- + +## 并发处理 + +### asyncio 基础 + +SDK 完全基于 asyncio 构建,所有操作都是异步的。 + +```python +import asyncio +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("concurrent") + async def concurrent_handler(self, event: MessageEvent, ctx: Context): + # 并发执行多个操作 + tasks = [ + ctx.llm.chat("任务1"), + ctx.llm.chat("任务2"), + ctx.llm.chat("任务3"), + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + await event.reply(f"任务{i+1}失败: {result}") + else: + await event.reply(f"任务{i+1}结果: {result}") +``` + +### 并发限制 + +避免同时发起过多请求: + +```python +import asyncio +from asyncio import Semaphore + +class MyPlugin(Star): + def __init__(self): + # 限制并发数 + self._semaphore = Semaphore(5) + + async def limited_operation(self, ctx, prompt): + async with self._semaphore: + return await ctx.llm.chat(prompt) + + @on_command("batch") + async def batch_handler(self, event: MessageEvent, ctx: Context): + prompts = ["任务1", "任务2", "任务3", "任务4", "任务5"] + + # 使用 semaphore 限制并发 + tasks = [self.limited_operation(ctx, p) for p in prompts] + results = await asyncio.gather(*tasks, return_exceptions=True) + + await 
event.reply(f"完成 {len(results)} 个任务") +``` + +### 取消处理 + +正确处理操作取消: + +```python +@on_command("cancelable") +async def cancelable_handler(self, event: MessageEvent, ctx: Context): + try: + # 长时间运行的操作 + for i in range(100): + # 检查是否被取消 + ctx.cancel_token.raise_if_cancelled() + + await asyncio.sleep(0.1) + + if i % 10 == 0: + await event.reply(f"进度: {i}%") + + await event.reply("完成!") + except asyncio.CancelledError: + await event.reply("操作已取消") + raise # 重新抛出以便框架处理 +``` + +### 锁和同步 + +保护共享资源: + +```python +import asyncio + +class MyPlugin(Star): + def __init__(self): + self._lock = asyncio.Lock() + self._counter = 0 + + async def increment(self): + async with self._lock: + # 临界区 + current = self._counter + await asyncio.sleep(0.1) # 模拟操作 + self._counter = current + 1 + return self._counter + + @on_command("count") + async def count_handler(self, event: MessageEvent, ctx: Context): + count = await self.increment() + await event.reply(f"当前计数: {count}") +``` + +--- + +## 性能优化 + +### 1. 连接池 + +复用 HTTP 连接: + +```python +import aiohttp + +class MyPlugin(Star): + async def on_start(self, ctx): + # 创建连接池 + self._session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=100, limit_per_host=20) + ) + + async def on_stop(self, ctx): + await self._session.close() + + async def fetch_data(self, url): + # 复用连接 + async with self._session.get(url) as response: + return await response.json() +``` + +### 2. 
缓存策略 + +使用内存缓存减少重复计算: + +```python +from functools import lru_cache +import asyncio + +class MyPlugin(Star): + def __init__(self): + self._cache = {} + self._cache_lock = asyncio.Lock() + + async def get_cached_data(self, key, ttl=300): + async with self._cache_lock: + if key in self._cache: + data, timestamp = self._cache[key] + if asyncio.get_event_loop().time() - timestamp < ttl: + return data + + # 从数据库获取 + data = await self.fetch_from_db(key) + + async with self._cache_lock: + self._cache[key] = (data, asyncio.get_event_loop().time()) + + return data + + async def invalidate_cache(self, key): + async with self._cache_lock: + self._cache.pop(key, None) +``` + +### 3. 批处理 + +批量操作减少网络往返: + +```python +@on_command("batch_db") +async def batch_db_handler(self, event: MessageEvent, ctx: Context): + # 批量获取 + keys = [f"user:{i}" for i in range(100)] + values = await ctx.db.get_many(keys) + + # 批量设置 + updates = {f"user:{i}": {"updated": True} for i in range(100)} + await ctx.db.set_many(updates) + + await event.reply(f"更新了 {len(updates)} 条记录") +``` + +### 4. 流式处理 + +使用流式 API 处理大数据: + +```python +@on_command("stream") +async def stream_handler(self, event: MessageEvent, ctx: Context): + # 流式 LLM 响应 + message = await event.reply("正在生成...") + + full_text = "" + async for chunk in ctx.llm.stream_chat("写一个很长的故事"): + full_text += chunk + # 每 100 个字符更新一次 + if len(full_text) % 100 < 10: + await message.edit(full_text + "...") + + await message.edit(full_text) +``` + +### 5. 懒加载 + +延迟初始化资源: + +```python +class MyPlugin(Star): + def __init__(self): + self._expensive_resource = None + self._resource_lock = asyncio.Lock() + + async def get_resource(self): + if self._expensive_resource is None: + async with self._resource_lock: + if self._expensive_resource is None: + # 昂贵的初始化 + self._expensive_resource = await self.init_resource() + return self._expensive_resource +``` + +--- + +## 安全最佳实践 + +### 1. 
输入验证 + +始终验证用户输入: + +```python +import re +from astrbot_sdk.errors import AstrBotError + +@on_command("search") +async def search_handler(self, event: MessageEvent, ctx: Context, query: str): + # 验证输入长度 + if len(query) > 1000: + raise AstrBotError.invalid_input("查询过长,最多 1000 字符") + + # 验证输入内容 + if not re.match(r'^[\w\s\-]+$', query): + raise AstrBotError.invalid_input("查询包含非法字符") + + # 执行搜索 + result = await self.search(query) + await event.reply(result) +``` + +### 2. 防止注入攻击 + +```python +# 危险的代码 +# await ctx.db.set(f"user:{event.user_id}", eval(user_input)) + +# 安全的代码 +import json + +@on_command("save") +async def save_handler(self, event: MessageEvent, ctx: Context, data: str): + try: + # 使用 JSON 解析而不是 eval + parsed = json.loads(data) + await ctx.db.set(f"user:{event.user_id}", parsed) + except json.JSONDecodeError: + raise AstrBotError.invalid_input("无效的 JSON 格式") +``` + +### 3. 敏感信息处理 + +```python +import os + +class MyPlugin(Star): + async def on_start(self, ctx): + config = await ctx.metadata.get_plugin_config() + + # 从配置或环境变量获取敏感信息 + self.api_key = config.get("api_key") or os.getenv("MY_PLUGIN_API_KEY") + + if not self.api_key: + raise ValueError("缺少 API Key") + + # 不要在日志中打印敏感信息 + ctx.logger.info("API Key 已配置") + # 不要: ctx.logger.info(f"API Key: {self.api_key}") +``` + +### 4. 权限检查 + +```python +from astrbot_sdk.decorators import require_admin, require_permission + +class MyPlugin(Star): + @on_command("admin_only") + @require_admin + async def admin_only(self, event: MessageEvent, ctx: Context): + await event.reply("管理员命令执行成功") + + @on_command("panel") + @require_permission("admin") + async def panel(self, event: MessageEvent, ctx: Context): + admins = await ctx.permission.get_admins() + await event.reply(f"当前管理员数量: {len(admins)}") +``` + +`@require_permission(...)` 的 v1 正式角色只支持 `member` 和 `admin`,并与 Core 当前权限模型保持一致。`@require_admin` 仍然可用,内部会归一化为 `required_role="admin"`。 + +### 5. 
速率限制 + +```python +from astrbot_sdk.decorators import rate_limit + +class MyPlugin(Star): + @on_command("expensive") + @rate_limit( + limit=5, + window=3600, + scope="user", + message="每小时只能调用 5 次" + ) + async def expensive_operation(self, event: MessageEvent, ctx: Context): + # 昂贵的操作 + result = await ctx.llm.chat("复杂任务", model="gpt-4") + await event.reply(result) +``` + +--- + +## 架构设计模式 + +### 1. 分层架构 + +``` +my_plugin/ +├── __init__.py +├── main.py # 插件入口 +├── handlers/ # 处理器层 +│ ├── __init__.py +│ ├── commands.py # 命令处理器 +│ └── messages.py # 消息处理器 +├── services/ # 业务逻辑层 +│ ├── __init__.py +│ ├── user_service.py +│ └── data_service.py +├── models/ # 数据模型层 +│ ├── __init__.py +│ └── user.py +└── utils/ # 工具层 + ├── __init__.py + └── helpers.py +``` + +### 2. 依赖注入 + +```python +class UserService: + def __init__(self, ctx: Context): + self._ctx = ctx + + async def get_user(self, user_id: str): + return await self._ctx.db.get(f"user:{user_id}") + +class MyPlugin(Star): + async def on_start(self, ctx): + # 注入依赖 + self._user_service = UserService(ctx) + + @on_command("profile") + async def profile_handler(self, event: MessageEvent, ctx: Context): + user = await self._user_service.get_user(event.user_id) + await event.reply(f"用户信息: {user}") +``` + +### 3. 事件驱动架构 + +```python +class MyPlugin(Star): + def __init__(self): + self._event_handlers = {} + + def register_handler(self, event_type, handler): + if event_type not in self._event_handlers: + self._event_handlers[event_type] = [] + self._event_handlers[event_type].append(handler) + + async def emit_event(self, event_type, data): + handlers = self._event_handlers.get(event_type, []) + for handler in handlers: + try: + await handler(data) + except Exception as e: + self.logger.error(f"事件处理失败: {e}") +``` + +### 4. 
状态机模式 + +```python +from enum import Enum, auto + +class ConversationState(Enum): + IDLE = auto() + WAITING_INPUT = auto() + PROCESSING = auto() + +class MyPlugin(Star): + def __init__(self): + self._states = {} + + async def get_state(self, session_id): + return self._states.get(session_id, ConversationState.IDLE) + + async def set_state(self, session_id, state): + self._states[session_id] = state + + @on_message() + async def handle_message(self, event: MessageEvent, ctx: Context): + state = await self.get_state(event.session_id) + + if state == ConversationState.IDLE: + await self.handle_idle(event, ctx) + elif state == ConversationState.WAITING_INPUT: + await self.handle_waiting(event, ctx) +``` + +--- + +## 高级客户端用法 + +### 1. ProviderManagerClient + +`ctx.provider_manager` 仅适用于 `reserved/system` 插件。普通插件应使用 `ctx.providers` 查询当前 Provider,而不是调用 Provider 管理能力。 +此外,`set_provider()` 修改的是全局生效的 Provider 选择,不是单个会话的局部设置。 + +```python +from astrbot_sdk import MessageEvent, Star, Context +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("switch_provider") + async def switch_provider(self, event: MessageEvent, ctx: Context): + # 列出所有 Provider + providers = await ctx.provider_manager.get_insts() + + # 切换 Provider + await ctx.provider_manager.set_provider( + provider_id="gpt-4", + provider_type="chat_completion", + ) + + # 监听 Provider 变更 + async for change in ctx.provider_manager.watch_changes(): + ctx.logger.info(f"Provider 变更: {change.provider_id}") +``` + +### 2. 平台管理 + +```python +@on_command("platform_info") +async def platform_info(self, event: MessageEvent, ctx: Context): + # 列出所有平台实例 + for platform in await ctx.list_platforms(): + ctx.logger.info(f"平台实例: {platform.id} ({platform.status})") + + # 获取平台实例 + platform = await ctx.get_platform_inst("qq:instance1") + + if platform: + await platform.refresh() + await event.reply( + f"平台: {platform.name}\n" + f"状态: {platform.status}\n" + f"错误数: {len(platform.errors)}" + ) +``` + +### 3. 
高级 LLM 用法 + +```python +from astrbot_sdk.llm.entities import ProviderRequest + +@on_command("advanced_llm") +async def advanced_llm(self, event: MessageEvent, ctx: Context): + # 使用 ProviderRequest 进行精细控制 + request = ProviderRequest( + prompt="生成内容", + system_prompt="你是一个助手", + temperature=0.7, + max_tokens=2000 + ) + + # 使用工具循环 Agent + response = await ctx.tool_loop_agent( + request=request, + tool_names=["search", "calculate"] + ) + + await event.reply(response.text) +``` + +### 4. 会话管理 + +```python +from astrbot_sdk.conversation import ConversationSession + +@on_command("conversation") +async def conversation_handler(self, event: MessageEvent, ctx: Context): + # 创建会话 + session = ConversationSession( + session_id=event.session_id, + conversation_id="conv_123" + ) + + # 使用会话上下文 + async with session: + await session.send("开始对话") + response = await session.receive() + await session.send(f"收到: {response}") +``` + +--- + +## 性能监控 + +### 1. 添加性能指标 + +```python +import time + +class MyPlugin(Star): + async def monitored_operation(self, operation, *args, **kwargs): + start = time.time() + try: + result = await operation(*args, **kwargs) + return result + finally: + duration = time.time() - start + self.logger.info(f"操作耗时: {duration:.2f}s") + + @on_command("slow") + async def slow_handler(self, event: MessageEvent, ctx: Context): + result = await self.monitored_operation( + ctx.llm.chat, + "复杂查询" + ) + await event.reply(result) +``` + +### 2. 
内存监控 + +```python +import sys +import gc + +class MyPlugin(Star): + def log_memory_usage(self): + # 获取内存使用 + gc.collect() + objects = gc.get_objects() + self.logger.debug(f"当前对象数: {len(objects)}") +``` + +--- + +## 相关文档 + +- [错误处理与调试](./06_error_handling.md) +- [测试指南](./08_testing_guide.md) +- [安全检查清单](./11_security_checklist.md) diff --git a/astrbot-sdk/docs/08_testing_guide.md b/astrbot-sdk/docs/08_testing_guide.md new file mode 100644 index 0000000000..0e0942b0ea --- /dev/null +++ b/astrbot-sdk/docs/08_testing_guide.md @@ -0,0 +1,610 @@ +# AstrBot SDK 测试指南 + +本文档介绍如何测试 AstrBot SDK 插件,包括单元测试、集成测试和使用测试框架。 + +## 目录 + +- [测试概述](#测试概述) +- [测试框架](#测试框架) +- [单元测试](#单元测试) +- [集成测试](#集成测试) +- [Mock 使用](#mock-使用) +- [测试最佳实践](#测试最佳实践) + +--- + +## 测试概述 + +### 为什么需要测试? + +1. **确保功能正确性**:验证插件按预期工作 +2. **防止回归**:修改代码时不破坏现有功能 +3. **文档化**:测试用例展示了如何使用代码 +4. **提高信心**:放心地重构和优化代码 + +### 测试类型 + +``` +单元测试 ──→ 集成测试 ──→ 端到端测试 +(最快) (中等) (最慢) +``` + +--- + +## 测试框架 + +### 安装测试依赖 + +```bash +pip install pytest pytest-asyncio pytest-cov +``` + +### 配置 pytest + +```python +# conftest.py +import pytest +from astrbot_sdk.testing import PluginTestHarness + +@pytest.fixture +async def harness(): + """提供测试 harness""" + h = PluginTestHarness() + yield h + await h.cleanup() + +@pytest.fixture +async def plugin(harness): + """加载插件""" + return await harness.load_plugin("my_plugin.main:MyPlugin") +``` + +--- + +## 单元测试 + +### 测试命令处理器 + +```python +import pytest +from astrbot_sdk.testing import PluginTestHarness + +@pytest.mark.asyncio +async def test_hello_command(): + """测试 hello 命令""" + harness = PluginTestHarness() + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + # 模拟命令调用 + result = await harness.simulate_command("/hello") + + # 验证结果 + assert result.text == "Hello, World!" 
+ + await harness.cleanup() +``` + +### 测试消息处理器 + +```python +@pytest.mark.asyncio +async def test_message_handler(): + """测试消息处理器""" + harness = PluginTestHarness() + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + # 模拟消息 + result = await harness.simulate_message( + text="你好", + user_id="12345", + session_id="session_1" + ) + + # 验证响应 + assert "你好" in result.text + + await harness.cleanup() +``` + +### 测试装饰器 + +```python +@pytest.mark.asyncio +async def test_rate_limit(): + """测试速率限制""" + harness = PluginTestHarness() + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + # 第一次调用应该成功 + result1 = await harness.simulate_command("/limited") + assert result1.success + + # 快速第二次调用应该被限制 + result2 = await harness.simulate_command("/limited") + assert result2.error.code == "rate_limited" + + await harness.cleanup() +``` + +--- + +## 集成测试 + +### 测试数据库操作 + +```python +@pytest.mark.asyncio +async def test_database_operations(): + """测试数据库操作""" + harness = PluginTestHarness() + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + # 模拟事件以获取 ctx + event = harness.create_mock_event(text="test") + + # 设置数据 + await plugin.save_user_data( + event, + event.ctx, + user_id="123", + data={"name": "Alice"} + ) + + # 读取数据 + data = await plugin.get_user_data( + event, + event.ctx, + user_id="123" + ) + + assert data["name"] == "Alice" + + await harness.cleanup() +``` + +### 测试 LLM 调用 + +```python +@pytest.mark.asyncio +async def test_llm_integration(): + """测试 LLM 调用""" + harness = PluginTestHarness() + + # 配置 mock LLM 响应 + harness.mock_llm_response("模拟的 LLM 回复") + + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + # 调用需要 LLM 的命令 + result = await harness.simulate_command("/ask 问题") + + assert "模拟的 LLM 回复" in result.text + + await harness.cleanup() +``` + +### 测试平台发送 + +```python +@pytest.mark.asyncio +async def test_platform_send(): + """测试平台消息发送""" + harness = PluginTestHarness() + plugin = await 
harness.load_plugin("my_plugin.main:MyPlugin") + + # 模拟命令 + await harness.simulate_command("/broadcast 大家好") + + # 验证发送记录 + sent_messages = harness.get_sent_messages() + assert len(sent_messages) >= 1 + assert "大家好" in sent_messages[0].text + + await harness.cleanup() +``` + +--- + +## Mock 使用 + +### Mock Context + +```python +from unittest.mock import AsyncMock, MagicMock +from astrbot_sdk import Context + +@pytest.fixture +def mock_ctx(): + """创建 mock Context""" + ctx = MagicMock(spec=Context) + + # Mock LLM 客户端 + ctx.llm = AsyncMock() + ctx.llm.chat.return_value = "Mocked response" + + # Mock DB 客户端 + ctx.db = AsyncMock() + ctx.db.get.return_value = {"key": "value"} + + # Mock Logger + ctx.logger = MagicMock() + + return ctx + +@pytest.mark.asyncio +async def test_with_mock_ctx(mock_ctx): + """使用 mock Context 测试""" + plugin = MyPlugin() + + result = await plugin.some_method(mock_ctx) + + # 验证调用 + mock_ctx.llm.chat.assert_called_once() + assert result == "expected" +``` + +### Mock 事件 + +```python +from astrbot_sdk import MessageEvent + +@pytest.fixture +def mock_event(): + """创建 mock 事件""" + event = MagicMock(spec=MessageEvent) + event.text = "测试消息" + event.user_id = "12345" + event.session_id = "session_1" + event.platform = "qq" + + # Mock 回复方法 + event.reply = AsyncMock() + + return event + +@pytest.mark.asyncio +async def test_with_mock_event(mock_event, mock_ctx): + """使用 mock 事件测试""" + plugin = MyPlugin() + + await plugin.handle_message(mock_event, mock_ctx) + + # 验证回复 + mock_event.reply.assert_called_once() +``` + +### Mock 时间 + +```python +import time +from unittest.mock import patch + +@pytest.mark.asyncio +async def test_with_mock_time(): + """使用 mock 时间测试""" + with patch('time.time', return_value=1234567890): + result = await plugin.time_sensitive_operation() + + assert result.timestamp == 1234567890 +``` + +### Mock 外部 API + +```python +import aiohttp +# 需要额外安装: pip install aioresponses +from aioresponses import aioresponses + +@pytest.mark.asyncio 
+async def test_external_api(): + """测试外部 API 调用""" + with aioresponses() as mocked: + # Mock API 响应 + mocked.get( + 'https://api.example.com/data', + payload={'result': 'success'}, + status=200 + ) + + result = await plugin.fetch_external_data() + + assert result['result'] == 'success' +``` + +--- + +## 测试最佳实践 + +### 1. 测试命名规范 + +```python +# 好的命名 +def test_calculate_sum_with_positive_numbers(): + """测试正数相加""" + pass + +def test_calculate_sum_with_negative_numbers(): + """测试负数相加""" + pass + +# 不好的命名 +def test1(): + pass + +def test_sum(): + pass +``` + +### 2. 一个测试一个概念 + +```python +# 好的做法:每个测试一个断言 +def test_user_creation(): + user = create_user("alice") + assert user.name == "alice" + +def test_user_creation_sets_default_role(): + user = create_user("alice") + assert user.role == "user" + +# 不好的做法:多个概念混在一起 +def test_user(): + user = create_user("alice") + assert user.name == "alice" + assert user.role == "user" + assert user.created_at is not None +``` + +### 3. 使用 Fixtures + +```python +# conftest.py +import pytest + +@pytest.fixture +def sample_user_data(): + """提供测试用户数据""" + return { + "user_id": "123", + "name": "Alice", + "email": "alice@example.com" + } + +@pytest.fixture +async def initialized_plugin(): + """提供已初始化的插件""" + plugin = MyPlugin() + harness = PluginTestHarness() + await plugin.on_start(harness.create_mock_ctx()) + yield plugin + await plugin.on_stop(None) + +# 测试中使用 +def test_with_fixture(sample_user_data, initialized_plugin): + result = initialized_plugin.process_user(sample_user_data) + assert result.success +``` + +### 4. 
参数化测试 + +```python +import pytest + +@pytest.mark.parametrize("input,expected", [ + ("hello", "Hello"), + ("world", "World"), + ("", ""), +]) +def test_capitalize(input, expected): + assert input.capitalize() == expected + +@pytest.mark.asyncio +@pytest.mark.parametrize("command,expected_response", [ + ("/help", "可用命令..."), + ("/about", "关于信息..."), + ("/version", "版本号..."), +]) +async def test_commands(command, expected_response): + harness = PluginTestHarness() + plugin = await harness.load_plugin("my_plugin.main:MyPlugin") + + result = await harness.simulate_command(command) + assert expected_response in result.text +``` + +### 5. 测试隔离 + +```python +# 每个测试使用独立的数据 +@pytest.fixture(autouse=True) +def reset_state(): + """每个测试前重置状态""" + MyPlugin._instance_counter = 0 + yield + # 测试后清理 + MyPlugin._instance_counter = 0 + +@pytest.mark.asyncio +async def test_isolated(): + # 这个测试不会受其他测试影响 + plugin = MyPlugin() + assert plugin.id == 1 +``` + +### 6. 异步测试模式 + +```python +import asyncio +import pytest + +@pytest.mark.asyncio +async def test_async_operation(): + """测试异步操作""" + result = await async_function() + assert result == expected + +@pytest.mark.asyncio +async def test_async_timeout(): + """测试超时""" + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for( + slow_function(), + timeout=0.1 + ) + +@pytest.mark.asyncio +async def test_async_exception(): + """测试异常""" + with pytest.raises(ValueError) as exc_info: + await function_that_raises() + + assert "expected error" in str(exc_info.value) +``` + +### 7. 
覆盖率检查 + +```bash +# 运行测试并生成覆盖率报告 +pytest --cov=my_plugin --cov-report=html + +# 检查覆盖率 +pytest --cov=my_plugin --cov-fail-under=80 +``` + +```ini +# .coveragerc +[run] +source = my_plugin +omit = + */tests/* + */venv/* + */__pycache__/* + +[report] +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError +``` + +--- + +## 测试工具函数 + +### 常用测试辅助函数 + +```python +# test_utils.py +import asyncio +from contextlib import asynccontextmanager + +async def run_with_timeout(coro, timeout=5): + """带超时运行协程""" + return await asyncio.wait_for(coro, timeout=timeout) + +@asynccontextmanager +async def temporary_database(): + """临时数据库上下文""" + db = await create_test_db() + try: + yield db + finally: + await db.cleanup() + +def create_test_event(**kwargs): + """创建测试事件""" + defaults = { + "text": "test", + "user_id": "12345", + "session_id": "test_session", + "platform": "qq", + } + defaults.update(kwargs) + return MockEvent(**defaults) +``` + +--- + +## 持续集成 + +### GitHub Actions 配置 + +```yaml +# .github/workflows/test.yml +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + pip install -r requirements.txt + pip install -r requirements-dev.txt + + - name: Run tests + run: | + pytest --cov=my_plugin --cov-report=xml + + - name: Upload coverage + uses: codecov/codecov-action@v3 +``` + +--- + +## 调试测试 + +### 使用 pdb + +```python +import pytest +import pdb + +def test_with_debug(): + result = some_function() + + # 设置断点 + pdb.set_trace() + + assert result.success +``` + +### 使用 pytest 的 --pdb + +```bash +# 失败时自动进入 pdb +pytest --pdb + +# 在第一个失败时停止 +pytest -x --pdb +``` + +### 详细输出 + +```bash +# 详细输出 +pytest -v + +# 最详细输出 +pytest -vv + +# 显示 print 输出 +pytest -s +``` + +--- + +## 相关文档 + +- [错误处理与调试](./06_error_handling.md) +- [高级主题](./07_advanced_topics.md) diff --git 
a/astrbot-sdk/docs/09_api_reference.md b/astrbot-sdk/docs/09_api_reference.md new file mode 100644 index 0000000000..1f4766148e --- /dev/null +++ b/astrbot-sdk/docs/09_api_reference.md @@ -0,0 +1,34 @@ +# AstrBot SDK 完整 API 参考 + +本文档提供 SDK 所有导出类和函数的完整参考,按模块分类。 + +## 相关文档 + +### 入门文档 +- [README](./README.md) +- [Context API 参考](./01_context_api.md) +- [消息事件与组件](./02_event_and_components.md) +- [装饰器使用指南](./03_decorators.md) + +### API 详细文档 +#### 核心类 +- [Star 类 API](./api/star.md) - 插件基类与生命周期 +- [Context 类 API](./api/context.md) - 运行时上下文与能力客户端 +- [MessageEvent 类 API](./api/message_event.md) - 消息事件对象 + +#### 装饰器与过滤器 +- [装饰器 API](./api/decorators.md) - 事件触发、限制器、过滤器装饰器 + +#### 客户端 +- [客户端 API](./api/clients.md) - LLM、Memory、DB、MessageHistory、Platform 等 17 个客户端与管理器 + +#### 消息处理 +- [消息组件 API](./api/message_components.md) - Plain、Image、At、Record、Video、File 等 +- [消息结果 API](./api/message_result.md) - MessageChain、MessageBuilder、MessageEventResult + +#### 工具与类型 +- [工具与辅助类 API](./api/utils.md) - CancelToken、MessageSession、GreedyStr、CommandGroup 等 +- [类型定义 API](./api/types.md) - 类型别名、泛型变量、Pydantic 模型 + +#### 错误处理 +- [错误处理 API](./api/errors.md) - AstrBotError、ErrorCodes diff --git a/astrbot-sdk/docs/10_migration_guide.md b/astrbot-sdk/docs/10_migration_guide.md new file mode 100644 index 0000000000..9da4360c49 --- /dev/null +++ b/astrbot-sdk/docs/10_migration_guide.md @@ -0,0 +1,494 @@ +# AstrBot SDK 迁移指南 + +本文档帮助开发者从旧版本或其他框架迁移到 AstrBot SDK v4。 + +## 目录 + +- [从 v3 迁移](#从-v3-迁移) +- [从其他框架迁移](#从其他框架迁移) +- [破坏性变更](#破坏性变更) +- [迁移检查清单](#迁移检查清单) + +--- + +## 从 v3 迁移 + +### 插件类定义 + +**v3 (旧版本)**: +```python +from astrbot.api import star + +@star.register("my_plugin") +class MyPlugin(star.Star): + def __init__(self, context): + super().__init__(context) +``` + +**v4 (新版本)**: +```python +from astrbot_sdk import Star + +class MyPlugin(Star): + async def on_start(self, ctx): + pass + + async def on_stop(self, ctx): + pass +``` + +### 装饰器变更 + +**v3**: +```python +from astrbot.api import 
filter + +@filter.command("hello") +async def hello(self, event): + await event.reply("Hello!") +``` + +**v4**: +```python +from astrbot_sdk.decorators import on_command + +@on_command("hello") +async def hello(self, event, ctx): + await event.reply("Hello!") +``` + +### Context 访问 + +**v3**: +```python +# 通过 self.context +config = self.context.get_config() +reply = await self.context.llm_generate("prompt") +``` + +**v4**: +```python +# 通过参数注入 +async def handler(self, event, ctx): + config = await ctx.metadata.get_plugin_config() + reply = await ctx.llm.chat("prompt") +``` + +### 数据存储 + +**v3**: +```python +# 通过 context +await self.context.put_kv_data("key", value) +data = await self.context.get_kv_data("key", default) +``` + +**v4**: +```python +# 通过 db 客户端 +await ctx.db.set("key", value) +data = await ctx.db.get("key") + +# 或使用 Mixin +from astrbot_sdk import PluginKVStoreMixin + +class MyPlugin(Star, PluginKVStoreMixin): + async def save(self): + await self.put_kv_data("key", value) +``` + +### 消息发送 + +**v3**: +```python +# 通过 event +await event.reply("消息") + +# 主动发送 +await self.context.send_message(session, chain) +``` + +**v4**: +```python +# 通过 event +await event.reply("消息") + +# 主动发送 +await ctx.platform.send(session, "消息") +await ctx.platform.send_chain(session, chain) +``` + +### 生命周期 + +**v3**: +```python +class MyPlugin(Star): + async def initialize(self): + # 初始化 + pass + + async def terminate(self): + # 清理 + pass +``` + +**v4**: +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 启动时 + await super().on_start(ctx) + + async def on_stop(self, ctx): + # 停止时 + await super().on_stop(ctx) + + # 仍然支持 + async def initialize(self): + pass + + async def terminate(self): + pass +``` + +### 配置获取 + +**v3**: +```python +config = self.context.get_config() +``` + +**v4**: +```python +config = await ctx.metadata.get_plugin_config() +``` + +### LLM 调用 + +**v3**: +```python +reply = await self.context.llm_generate("prompt") + +# 带历史 +reply = await 
self.context.llm_generate( + "prompt", + contexts=[{"role": "user", "content": "历史"}] +) +``` + +**v4**: +```python +from astrbot_sdk.clients.llm import ChatMessage + +reply = await ctx.llm.chat("prompt") + +# 带历史 +history = [ + ChatMessage(role="user", content="历史"), +] +reply = await ctx.llm.chat("prompt", history=history) +``` + +### 错误处理 + +**v3**: +```python +try: + result = await operation() +except Exception as e: + await event.reply(f"错误: {e}") +``` + +**v4**: +```python +from astrbot_sdk.errors import AstrBotError + +try: + result = await operation() +except AstrBotError as e: + # 使用 SDK 提供的用户友好提示 + await event.reply(e.hint or e.message) +except Exception as e: + ctx.logger.error(f"错误: {e}") + await event.reply("操作失败") +``` + +--- + +## 从其他框架迁移 + +### 从 NoneBot2 迁移 + +**NoneBot2**: +```python +from nonebot import on_command +from nonebot.adapters.onebot.v11 import Bot, Event + +matcher = on_command("hello") + +@matcher.handle() +async def hello(bot: Bot, event: Event): + await matcher.send("Hello!") +``` + +**AstrBot SDK**: +```python +from astrbot_sdk import Star, MessageEvent, Context +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") +``` + +### 从 Koishi 迁移 + +**Koishi**: +```javascript +ctx.command('hello') + .action(() => 'Hello!') +``` + +**AstrBot SDK**: +```python +from astrbot_sdk import Star, MessageEvent, Context +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") +``` + +### 从 python-telegram-bot 迁移 + +**python-telegram-bot**: +```python +from telegram import Update +from telegram.ext import ContextTypes + +async def hello(update: Update, context: ContextTypes.DEFAULT_TYPE): + await update.message.reply_text("Hello!") +``` + +**AstrBot SDK**: +```python +from astrbot_sdk import 
Star, MessageEvent, Context +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("hello") + @platforms("telegram") + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") +``` + +--- + +## 破坏性变更 + +### v3 → v4 主要变更 + +1. **注册方式** + - v3: `@star.register()` + `@filter.command()` + - v4: `@on_command()` 直接在类方法上 + +2. **Context 获取** + - v3: `self.context` + - v4: `ctx` 参数注入 + +3. **数据存储** + - v3: `self.context.put_kv_data()` + - v4: `ctx.db.set()` 或 `PluginKVStoreMixin` + +4. **配置获取** + - v3: `self.context.get_config()` + - v4: `ctx.metadata.get_plugin_config()` + +5. **LLM 调用** + - v3: `self.context.llm_generate()` + - v4: `ctx.llm.chat()` + +6. **生命周期** + - v3: `initialize()` / `terminate()` + - v4: `on_start()` / `on_stop()`(仍然支持旧方法) + +7. **错误类型** + - v3: 标准 Python 异常 + - v4: `AstrBotError` 体系 + +### 已弃用的功能 + +| v3 功能 | v4 替代方案 | 状态 | +|---------|-------------|------| +| `@star.register()` | 继承 `Star` 类 | 已移除 | +| `self.context` | `ctx` 参数 | 已变更 | +| `filter.command()` | `on_command()` | 已更名 | +| `filter.regex()` | `on_message(regex=...)` | 已变更 | +| `llm_generate()` | `ctx.llm.chat()` | 已更名 | +| `send_message()` | `ctx.platform.send()` | 已更名 | + +--- + +## 迁移检查清单 + +### 代码迁移 + +- [ ] 更新导入语句 +- [ ] 移除 `@star.register()` 装饰器 +- [ ] 将 `@filter.command()` 改为 `@on_command()` +- [ ] 添加 `ctx` 参数到所有 handler +- [ ] 更新 Context 访问方式 +- [ ] 更新数据存储调用 +- [ ] 更新 LLM 调用 +- [ ] 更新配置获取 +- [ ] 更新错误处理 + +### 配置迁移 + +- [ ] 更新 `plugin.yaml` 格式 +- [ ] 检查 `support_platforms` 配置 +- [ ] 更新 `runtime` 配置 + +### 测试迁移 + +- [ ] 更新测试导入 +- [ ] 更新测试 mock +- [ ] 运行测试验证 + +### 文档更新 + +- [ ] 更新 README +- [ ] 更新使用文档 +- [ ] 更新 CHANGELOG + +--- + +## 迁移工具 + +### 自动迁移脚本(示例) + +```python +#!/usr/bin/env python3 +"""v3 到 v4 迁移辅助脚本""" + +import re +import sys +from pathlib import Path + +def migrate_file(file_path: Path): + """迁移单个文件""" + content = file_path.read_text(encoding="utf-8") + + # 替换导入 + content = re.sub( + r'from astrbot\.api 
import star', + 'from astrbot_sdk import Star, Context, MessageEvent', + content + ) + + # 替换装饰器 + content = re.sub( + r'@star\.register\([^)]*\)', + '', + content + ) + + content = re.sub( + r'@filter\.command\(([^)]*)\)', + r'@on_command(\1)', + content + ) + + # 替换类定义 + content = re.sub( + r'class (\w+)\(star\.Star\)', + r'class \1(Star)', + content + ) + + # 替换 context 访问 + content = re.sub( + r'self\.context\.get_config\(\)', + 'await ctx.metadata.get_plugin_config()', + content + ) + + content = re.sub( + r'self\.context\.llm_generate\(', + 'ctx.llm.chat(', + content + ) + + # 添加 ctx 参数 + content = re.sub( + r'async def (\w+)\(self, event\)', + r'async def \1(self, event, ctx)', + content + ) + + # 写回文件 + file_path.write_text(content, encoding="utf-8") + print(f"已迁移: {file_path}") + +def main(): + if len(sys.argv) < 2: + print("用法: python migrate.py <插件目录>") + sys.exit(1) + + plugin_dir = Path(sys.argv[1]) + + for py_file in plugin_dir.rglob("*.py"): + migrate_file(py_file) + + print("迁移完成!请手动检查并测试。") + +if __name__ == "__main__": + main() +``` + +--- + +## 常见问题 + +### Q: v3 插件能在 v4 运行吗? + +**A**: 不能,需要进行迁移。但是 SDK 提供了兼容层,可以简化迁移过程。 + +### Q: 可以同时支持 v3 和 v4 吗? + +**A**: 不推荐。建议为 v4 创建新的插件版本。 + +### Q: 迁移后测试失败怎么办? + +**A**: +1. 检查导入是否正确 +2. 确认 `ctx` 参数已添加 +3. 验证异步函数使用 `await` +4. 查看错误日志获取详细信息 + +### Q: 如何逐步迁移? + +**A**: +1. 先迁移插件结构和装饰器 +2. 再迁移业务逻辑 +3. 最后更新测试 +4. 
每个阶段都进行测试 + +--- + +## 获取帮助 + +- 查看完整文档:[docs/](./) +- 提交问题:[GitHub Issues](https://github.com/AstrBotDevs/AstrBot/issues) +- 迁移示例:[examples/migration/](./examples/migration/) + +--- + +## 相关文档 + +- [README](./README.md) +- [Context API 参考](./01_context_api.md) +- [Star 类与生命周期](./04_star_lifecycle.md) +- [错误处理与调试](./06_error_handling.md) diff --git a/astrbot-sdk/docs/11_security_checklist.md b/astrbot-sdk/docs/11_security_checklist.md new file mode 100644 index 0000000000..ac4ee5015b --- /dev/null +++ b/astrbot-sdk/docs/11_security_checklist.md @@ -0,0 +1,382 @@ +# AstrBot SDK 安全检查清单 + +本文档包含 SDK 安全开发检查清单和已知安全问题,帮助开发者编写安全的插件。 + +## 目录 + +- [安全检查清单](#安全检查清单) +- [已知安全问题](#已知安全问题) +- [安全最佳实践](#安全最佳实践) +- [安全审计指南](#安全审计指南) + +--- + +## 安全检查清单 + +### 输入验证 + +- [ ] 所有用户输入都经过验证 +- [ ] 输入长度有限制 +- [ ] 输入内容有白名单过滤 +- [ ] 特殊字符被正确转义 + +```python +# ✅ 好的做法 +import re +from astrbot_sdk.errors import AstrBotError + +def validate_input(text: str) -> str: + if len(text) > 1000: + raise AstrBotError.invalid_input("输入过长") + if not re.match(r'^[\w\s\-]+$', text): + raise AstrBotError.invalid_input("包含非法字符") + return text + +# ❌ 不好的做法 +async def unsafe_handler(event, ctx): + result = eval(event.text) # 危险! +``` + +### 敏感信息处理 + +- [ ] API Key 等敏感信息不硬编码 +- [ ] 敏感信息从配置或环境变量读取 +- [ ] 敏感信息不在日志中打印 +- [ ] 敏感信息不存储在不安全的位置 + +```python +# ✅ 好的做法 +import os + +class MyPlugin(Star): + async def on_start(self, ctx): + config = await ctx.metadata.get_plugin_config() + self.api_key = config.get("api_key") or os.getenv("MY_API_KEY") + ctx.logger.info("API Key 已配置") # 不打印实际值 + +# ❌ 不好的做法 +class UnsafePlugin(Star): + api_key = "sk-1234567890" # 硬编码! + + async def on_start(self, ctx): + ctx.logger.info(f"API Key: {self.api_key}") # 泄露! 
+``` + +### 权限检查 + +- [ ] 管理员命令有权限验证 +- [ ] 敏感操作有二次确认 +- [ ] 资源访问有权限控制 + +```python +# ✅ 好的做法 +from astrbot_sdk.decorators import require_admin + +class MyPlugin(Star): + @on_command("admin_only") + @require_admin + async def admin_cmd(self, event, ctx): + await event.reply("管理员命令") + +# ❌ 不好的做法 +class UnsafePlugin(Star): + @on_command("delete_all") + async def delete_all(self, event, ctx): + # 任何人都可以执行危险操作! + await ctx.db.clear_all() +``` + +### 速率限制 + +- [ ] 昂贵的操作有速率限制 +- [ ] API 调用有配额控制 +- [ ] 资源密集型操作有限制 + +```python +# ✅ 好的做法 +from astrbot_sdk.decorators import rate_limit + +class MyPlugin(Star): + @on_command("generate") + @rate_limit(limit=5, window=3600, scope="user") + async def generate(self, event, ctx): + # 昂贵的 LLM 调用 + result = await ctx.llm.chat("生成内容", model="gpt-4") + await event.reply(result) +``` + +### 资源管理 + +- [ ] 资源正确释放 +- [ ] 连接正确关闭 +- [ ] 任务正确取消 +- [ ] 避免资源泄漏 + +```python +# ✅ 好的做法 +class MyPlugin(Star): + async def on_start(self, ctx): + self._session = aiohttp.ClientSession() + self._task = asyncio.create_task(self.background_task()) + + async def on_stop(self, ctx): + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + if self._session: + await self._session.close() +``` + +### 错误处理 + +- [ ] 错误信息不泄露敏感信息 +- [ ] 异常被正确捕获和处理 +- [ ] 错误日志不包含敏感数据 + +```python +# ✅ 好的做法 +try: + result = await operation() +except Exception as e: + ctx.logger.error(f"操作失败: {type(e).__name__}") + await event.reply("操作失败,请稍后重试") + +# ❌ 不好的做法 +try: + result = await operation() +except Exception as e: + await event.reply(f"错误: {str(e)}") # 可能泄露敏感信息 +``` + +--- + +## 已知安全问题 + +当前版本没有已知的 SDK 框架级高风险未修复项。以下历史回归已经关闭, +保留在这里帮助开发者理解为什么这些约束存在: + +- `ProviderManagerClient.register_provider_change_hook()` 现在必须和 + `unregister_provider_change_hook()` 配对使用,避免残留订阅任务。 +- `PlatformCompatFacade` 内部已经串行化状态刷新,插件侧不需要再额外为 + `refresh()` / `clear_errors()` 套一层锁来规避 SDK 自身竞态。 +- Provider 管理路径会先复制 provider payload,再做 merge,避免污染共享缓存。 + +--- + 
+### 🟡 Medium: 命令参数注入风险 + +**问题描述**: +插件可能直接使用用户输入作为命令参数,存在注入风险。 + +**风险等级**: Medium + +**示例**: +```python +# ❌ 危险 +@on_command("search") +async def search(self, event, ctx, query): + # 如果 query 包含特殊字符,可能引发问题 + os.system(f"grep {query} data.txt") + +# ✅ 安全 +@on_command("search") +async def search(self, event, ctx, query): + # 验证和清理输入 + safe_query = re.sub(r'[^\w\s]', '', query) + subprocess.run(["grep", safe_query, "data.txt"], capture_output=True) +``` + +--- + +### 🟢 Low: 敏感信息可能出现在日志中 + +**问题描述**: +某些错误日志可能包含敏感信息。 + +**风险等级**: Low + +**建议**: +```python +# ✅ 安全的日志记录 +ctx.logger.info(f"用户 {user_id} 执行操作") # 只记录 ID + +# ❌ 不安全的日志记录 +ctx.logger.info(f"用户数据: {user_data}") # 可能包含敏感信息 +``` + +--- + +## 安全最佳实践 + +### 1. 最小权限原则 + +```python +class MyPlugin(Star): + @on_command("public") + async def public_cmd(self, event, ctx): + # 所有人可用 + pass + + @on_command("admin") + @require_admin + async def admin_cmd(self, event, ctx): + # 仅管理员可用 + pass + + @on_command("owner") + async def owner_cmd(self, event, ctx): + # 仅插件所有者可用 + if event.user_id != self.owner_id: + raise AstrBotError.invalid_input("权限不足") +``` + +### 2. 输入验证白名单 + +```python +import re + +ALLOWED_COMMANDS = {"help", "status", "info"} + +def validate_command(cmd: str) -> str: + cmd = cmd.lower().strip() + if cmd not in ALLOWED_COMMANDS: + raise AstrBotError.invalid_input("未知命令") + return cmd +``` + +### 3. 安全的文件操作 + +```python +import os +from pathlib import Path + +BASE_DIR = Path("/safe/directory") + +def safe_read_file(filename: str) -> str: + # 防止目录遍历 + path = (BASE_DIR / filename).resolve() + if not str(path).startswith(str(BASE_DIR)): + raise AstrBotError.invalid_input("非法路径") + + return path.read_text() +``` + +### 4. 安全的正则表达式 + +```python +import re + +# ✅ 使用原始字符串和适当的限制 +pattern = re.compile(r'^[a-zA-Z0-9_]{1,50}$') + +# ❌ 避免复杂的正则,可能导致 ReDoS +# pattern = re.compile(r'(a+)+b') # 危险! +``` + +### 5. 
安全配置 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + config = await ctx.metadata.get_plugin_config() + + # 验证必需配置 + required = ["api_key", "endpoint"] + for key in required: + if key not in config: + raise ValueError(f"缺少必需配置: {key}") + + # 验证配置值 + if not config["api_key"].startswith("sk-"): + raise ValueError("无效的 API Key 格式") + + self.config = config +``` + +--- + +## 安全审计指南 + +### 审计检查清单 + +1. **代码审查** + - [ ] 所有输入都经过验证 + - [ ] 没有使用 eval/exec + - [ ] 没有硬编码的敏感信息 + - [ ] 错误处理不泄露敏感信息 + +2. **依赖审查** + ```bash + # 检查依赖漏洞 + pip install safety + safety check + + # 检查依赖许可证 + pip install pip-licenses + pip-licenses + ``` + +3. **日志审查** + - [ ] 日志不包含密码、token + - [ ] 日志不包含个人隐私信息 + - [ ] 日志有适当的级别 + +4. **权限审查** + - [ ] 敏感操作有权限检查 + - [ ] 没有特权提升漏洞 + - [ ] 资源访问有控制 + +### 安全测试 + +```python +# 测试输入验证 +def test_input_validation(): + # SQL 注入测试 + malicious_input = "' OR '1'='1" + + # XSS 测试 + xss_input = "" + + # 路径遍历测试 + path_input = "../../../etc/passwd" + + # 验证这些输入都被正确拒绝 +``` + +### 安全工具 + +```bash +# 静态分析 +pip install bandit +bandit -r my_plugin/ + +# 类型检查 +pip install mypy +mypy my_plugin/ + +# 代码质量 +pip install pylint +pylint my_plugin/ +``` + +--- + +## 报告安全问题 + +如果您发现 SDK 或插件的安全问题,请通过以下方式报告: + +1. **不要** 在公开 issue 中报告安全问题 +2. 通过项目官方联系渠道私下报告,例如 `community@astrbot.app` +3. 提供详细的复现步骤 +4. 等待修复后再公开 + +--- + +## 相关文档 + +- [错误处理与调试](./06_error_handling.md) +- [高级主题](./07_advanced_topics.md) +- [测试指南](./08_testing_guide.md) diff --git a/astrbot-sdk/docs/12_plugin_capability_registration_flow.md b/astrbot-sdk/docs/12_plugin_capability_registration_flow.md new file mode 100644 index 0000000000..ff954bab12 --- /dev/null +++ b/astrbot-sdk/docs/12_plugin_capability_registration_flow.md @@ -0,0 +1,633 @@ +# 插件注册与能力注册数据流 + +> 作者:whatevertogo +> 生成日期:2026-03-24 + +--- + +## 目录 + +1. [概述](#概述) +2. [核心架构](#核心架构) +3. [插件注册流程](#插件注册流程) +4. [能力注册流程](#能力注册流程) +5. [能力调用流程](#能力调用流程) +6. [CapabilityRouter 机制](#capabilityrouter-机制) +7. [关键数据结构](#关键数据结构) +8. 
[时序图](#时序图) + +--- + +## 概述 + +AstrBot SDK v4 采用**进程隔离**和**能力路由**架构: + +- **进程隔离**: 每个插件运行在独立 Worker 进程,崩溃不影响其他插件 +- **能力路由**: Supervisor 统一管理所有能力的注册、发现和调用 +- **协议通信**: 通过 v4 协议进行跨进程通信(支持 Stdio/WebSocket) + +### 核心组件 + +| 组件 | 位置 | 职责 | +|------|------|------| +| `SupervisorRuntime` | 主进程 | 管理多个 Worker 进程,聚合所有 handler 和 capability | +| `WorkerSession` | 主进程 | 封装单个 Worker 进程的生命周期和通信 | +| `PluginWorkerRuntime` | Worker 进程 | 插件加载与执行 | +| `HandlerDispatcher` | Worker 进程 | Handler 请求转成真实 Python 调用 | +| `CapabilityDispatcher` | Worker 进程 | Capability 调用分发 | +| `CapabilityRouter` | 主进程 | 能力注册、发现和执行路由 | + +--- + +## 核心架构 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ AstrBot Core │ +│ (调用能力/发送消息) │ +└────────────────────────────┬────────────────────────────────────┘ + │ invoke/call + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ SupervisorRuntime (主进程) │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ CapabilityRouter │ │ +│ │ _registrations: {capability_name: registration} │ │ +│ │ handler_to_worker: {handler_id: WorkerSession} │ │ +│ │ capability_to_worker: {capability_name: WorkerSession} │ │ +│ └───────────────────────────────────────────────────────────┘ │ +│ │ +│ WorkerSession A ◄────────────► WorkerSession B ◄───────────► ... 
│ +└──────────────┬──────────────────────────────┬───────────────────┘ + │ stdio/ws │ stdio/ws + ▼ ▼ +┌──────────────────────────┐ ┌──────────────────────────┐ +│ Worker 进程 A │ │ Worker 进程 B │ +│ ┌────────────────────┐ │ │ ┌────────────────────┐ │ +│ │ PluginWorkerRuntime│ │ │ │ PluginWorkerRuntime│ │ +│ │ │ │ │ │ │ │ +│ │ HandlerDispatcher │ │ │ │ HandlerDispatcher │ │ +│ │ CapabilityDispatcher│ │ │ CapabilityDispatcher│ │ +│ │ │ │ │ │ │ │ +│ │ loaded.handlers │ │ │ │ loaded.handlers │ │ +│ │ loaded.capabilities│ │ │ loaded.capabilities│ │ +│ └────────────────────┘ │ │ └────────────────────┘ │ +└──────────┬─────────────────┘ └──────────┬─────────────────┘ + │ │ + ▼ ▼ +┌──────────────────────────┐ ┌──────────────────────────┐ +│ Plugin A (Star) │ │ Plugin B (Star) │ +│ │ │ │ +│ @on_command │ │ @on_command │ +│ @on_message │ │ @on_message │ +│ @provide_capability │ │ @provide_capability │ +│ @llm_tool │ │ @llm_tool │ +│ @on_schedule │ │ @on_schedule │ +│ @http_api │ │ @http_api │ +└──────────────────────────┘ └──────────────────────────┘ +``` + +--- + +## 插件注册流程 + +### 阶段一:插件发现 (Supervisor 侧) + +``` +SupervisorRuntime.start() + │ + ▼ +discover_plugins(plugins_dir) + │ + ├─► 遍历 plugins 目录下的子目录 + │ │ + │ ├─► 检查 plugin.yaml 是否存在 + │ │ + │ ├─► load_plugin_spec(entry) + │ │ ├─ 读取 plugin.yaml + │ │ ├─ 解析 manifest_data + │ │ │ (name, author, version, components, runtime.python) + │ │ └─ 返回 PluginSpec + │ │ + │ ├─► validate_plugin_spec(plugin) + │ │ └─ 验证必要字段 (name, components) + │ │ + │ └─► 添加到 PluginDiscoveryResult.plugins + │ + ▼ +env_manager.plan(discovery.plugins) + │ + ├─► 按依赖兼容性分组 + │ + └─► 生成 EnvironmentGroups +``` + +**关键数据结构**: + +```python +@dataclass +class PluginSpec: + """插件规范""" + name: str # 插件名称 + plugin_dir: Path # 插件目录 + manifest_path: Path # plugin.yaml 路径 + requirements_path: Path # requirements.txt 路径 + python_version: str # Python 版本要求 + manifest_data: dict # 原始 manifest 数据 + +@dataclass +class PluginDiscoveryResult: + """发现结果""" + plugins: 
list[PluginSpec] # 成功发现的插件 + skipped_plugins: list[PluginSpec] # 跳过的插件 + issues: list[str] # 问题列表 +``` + +### 阶段二:插件加载 (Worker 侧) + +``` +PluginWorkerRuntime.__init__(plugin_dir) + │ + ▼ +load_plugin_spec(plugin_dir) + │ + ▼ +load_plugin(plugin) + │ + ├─► 将插件目录添加到 sys.path + │ + ├─► _plugin_component_classes(plugin) + │ │ + │ ├─ 读取 components 列表 (如 ["main:MyPlugin"]) + │ │ + │ └─ import_string(class_path) + │ └─ 动态导入组件类 + │ + ├─► 遍历每个组件类: + │ │ + │ ├─► instance = component_cls() # 无参实例化 + │ │ + │ ├─► _iter_discoverable_names(instance) + │ │ └─ 扫描所有公共方法 + │ │ + │ ├─► _resolve_handler_candidate(method) + │ │ └─ 解析 @on_command, @on_message, @on_event 等装饰器 + │ │ → 生成 LoadedHandler + │ │ + │ ├─► _resolve_capability_candidate(method) + │ │ └─ 解析 @provide_capability 装饰器 + │ │ → 生成 LoadedCapability + │ │ + │ ├─► _resolve_llm_tool_candidate(method) + │ │ └─ 解析 @llm_tool 装饰器 + │ │ → 生成 LoadedLLMTool + │ │ + │ └─► _iter_agent_candidates(method) + │ └─ 解析 @agent 装饰器 + │ → 生成 LoadedAgent + │ + ▼ +返回 LoadedPlugin + │ + ▼ +创建 HandlerDispatcher(handlers) +创建 CapabilityDispatcher(capabilities) +``` + +**关键数据结构**: + +```python +@dataclass +class LoadedPlugin: + """加载后的插件""" + plugin: PluginSpec # 插件规范 + handlers: list[LoadedHandler] # 处理器列表 + capabilities: list[LoadedCapability] # 能力列表 + llm_tools: list[LoadedLLMTool] # LLM 工具列表 + agents: list[LoadedAgent] # Agent 列表 + instances: list[Any] # 组件实例列表 + +@dataclass +class LoadedHandler: + """加载后的处理器""" + descriptor: HandlerDescriptor # 描述符 + callable: Callable # 可调用方法 + owner: Any # 所属实例 + plugin_id: str # 插件 ID + local_filters: list # 过滤器 + limiter: Optional[RateLimiter] # 限流器 + conversation: Optional[ConversationConfig] # 会话配置 + +@dataclass +class LoadedCapability: + """加载后的能力""" + descriptor: CapabilityDescriptor # 描述符 + callable: Callable # 可调用方法 + owner: Any # 所属实例 + plugin_id: str # 插件 ID +``` + +--- + +## 能力注册流程 + +### 插件中声明能力 + +```python +from astrbot_sdk import Star, Context +from astrbot_sdk.decorators import 
provide_capability + +class MyPlugin(Star): + @provide_capability( + name="my_plugin.calculate", + description="执行数学计算", + input_schema={ + "type": "object", + "properties": { + "x": {"type": "number"}, + "y": {"type": "number"} + }, + "required": ["x", "y"] + }, + output_schema={ + "type": "object", + "properties": { + "result": {"type": "number"} + }, + "required": ["result"] + } + ) + async def calculate(self, payload: dict, ctx: Context) -> dict: + x = payload.get("x", 0) + y = payload.get("y", 0) + return {"result": x + y} +``` + +### 握手注册流程 + +``` +Worker 侧 Supervisor 侧 + │ │ + │ PluginWorkerRuntime.start() │ + │ │ │ + │ ▼ │ + │ peer.initialize( │ + │ handlers=[handler.descriptor...], │ + │ provided_capabilities=[cap.desc...], │ + │ metadata={...} │ + │ ) │ + │ │ │ + │ ▼ │ + │ 构建 InitializeMessage │ + │ │ │ + │ ├─────────────────────────────────► │ + │ │ InitializeMessage │ + │ │ │ + │ │ WorkerSession._handle_initialize() + │ │ │ + │ │ ├─ 解析 remote_handlers + │ │ │ └─ handler_to_worker[id] = session + │ │ │ + │ │ ├─ 解析 remote_provided_capabilities + │ │ │ └─ _register_plugin_capability() + │ │ │ │ + │ │ │ ├─ 检查命名冲突 + │ │ │ │ ├─ 保留命名空间 (handler/system/internal) + │ │ │ │ │ → 跳过并警告 + │ │ │ │ └─ 普通冲突 + │ │ │ │ → 添加插件前缀 (如 plugin.echo) + │ │ │ │ + │ │ │ └─ CapabilityRouter.register() + │ │ │ ├─ _registrations[name] = registration + │ │ │ └─ capability_to_worker[name] = session + │ │ │ + │ │ └─ 构建 InitializeOutput + │ │ │ + │ ◄─────────────────────────────────┤ + │ │ ResultMessage(kind="init") │ + │ │ + InitializeOutput │ + │ │ │ + │ ▼ │ + │ 握手完成,插件就绪 │ +``` + +### 冲突处理规则 + +| 场景 | 处理方式 | +|------|---------| +| 保留命名空间冲突 (`handler.*`, `system.*`, `internal.*`) | 跳过注册,输出警告日志 | +| 普通命名冲突 | 自动添加插件名前缀,如 `demo.echo` → `my_plugin.demo.echo` | +| 无冲突 | 直接注册 | + +--- + +## 能力调用流程 + +### 从 Core 到 Plugin + +``` +AstrBot Core + │ + │ 调用能力 (如 llm.chat, platform.send, 或插件能力) + │ + ▼ +SupervisorRuntime._handle_upstream_invoke(message, cancel_token) + │ + ▼ 
+CapabilityRouter.execute(capability, payload, stream, cancel_token, request_id) + │ + ├─► 查找 _registrations[capability] + │ + ├─► 验证 input_schema (JSON Schema) + │ + └─► 调用注册的处理器 + │ + ▼ +_make_plugin_capability_caller(session, capability_name) + │ + ▼ +WorkerSession.invoke_capability(capability_name, payload, request_id) + │ + ▼ +peer.invoke(capability_name, payload, request_id) + │ + │ 构建 InvokeMessage + │ + ▼ +发送到 Worker 进程 +``` + +### Worker 侧执行 + +``` +Worker 进程收到 InvokeMessage + │ + ▼ +PluginWorkerRuntime._handle_invoke(message, cancel_token) + │ + ▼ +CapabilityDispatcher.invoke(message, cancel_token) + │ + ├─► 查找 _capabilities[capability] + │ + ├─► 构建 Context + │ Context( + │ peer=peer, + │ plugin_id=plugin_id, + │ request_id=request_id, + │ cancel_token=cancel_token + │ ) + │ + ├─► 绑定 logger (caller_plugin_scope) + │ + └─► _run_capability(loaded, payload, ctx, cancel_token, stream) + │ + ├─► _build_args() # 参数注入 + │ │ + │ ├─ 按类型注入: Context, CancelToken, dict + │ │ + │ └─ 按参数名注入: ctx, context, payload, ... + │ + ├─► result = loaded.callable(*args) # 执行用户方法 + │ + └─► _normalize_output(result) # 标准化输出 +``` + +### 返回结果 + +``` +Worker 侧 Supervisor 侧 + │ │ + │ 执行完成,返回结果 │ + │ │ │ + │ ▼ │ + │ 构建 ResultMessage │ + │ │ │ + │ ├─────────────────────────────────► │ + │ │ ResultMessage │ + │ │ {success: true, output: {...}} │ + │ │ │ + │ │ CapabilityRouter 处理结果 + │ │ │ + │ │ ├─ 验证 output_schema + │ │ │ + │ │ └─ 返回给调用方 + │ │ │ +``` + +--- + +## CapabilityRouter 机制 + +### 核心职责 + +1. **能力注册表**: 维护所有可用能力的描述符和处理器 +2. **Schema 验证**: 输入/输出的 JSON Schema 验证 +3. **路由转发**: 将调用转发到对应的 Worker 进程 +4. 
**冲突处理**: 能力名称冲突时的自动重命名 + +### 注册表结构 + +```python +@dataclass +class _CapabilityRegistration: + """能力注册项""" + descriptor: CapabilityDescriptor # 能力描述符 + call_handler: Callable # 同步调用处理器 + stream_handler: Optional[Callable] # 流式调用处理器 + finalize: Optional[Callable] # 清理函数 + exposed: bool # 是否对外暴露 + +class CapabilityRouter: + # 能力注册表 + _registrations: dict[str, _CapabilityRegistration] + + # Handler 到 Worker 的映射 + handler_to_worker: dict[str, WorkerSession] + + # Capability 到 Worker 的映射 + capability_to_worker: dict[str, WorkerSession] +``` + +### 内置能力命名空间 + +| 命名空间 | 能力示例 | 说明 | +|---------|---------|------| +| `llm.*` | `llm.chat`, `llm.stream_chat` | LLM 对话 | +| `memory.*` | `memory.search`, `memory.save` | 记忆存储 | +| `db.*` | `db.get`, `db.set`, `db.watch` | KV 存储 | +| `platform.*` | `platform.send`, `platform.send_image` | 消息发送 | +| `provider.*` | `provider.get_using`, `provider.list_all` | Provider 管理 | +| `metadata.*` | `metadata.get_plugin`, `metadata.list_plugins` | 插件元数据 | +| `http.*` | `http.register_api`, `http.list_apis` | HTTP API | +| `system.*` | `system.get_data_dir`, `system.text_to_image` | 系统功能 | +| `message_history.*` | `message_history.list`, `message_history.append` | 消息历史 | + +### Schema 验证流程 + +``` +CapabilityRouter.execute() + │ + ├─► 获取 _registrations[capability] + │ + ├─► 输入验证 + │ │ + │ └─ validate(descriptor.input_schema, payload) + │ ├─ 检查 required 字段 + │ ├─ 检查类型匹配 + │ └─ 失败返回 ErrorPayload + │ + ├─► 执行调用 + │ │ + │ └─ call_handler(payload, cancel_token, request_id) + │ + └─► 输出验证 + │ + └─ validate(descriptor.output_schema, result) + ├─ 检查 required 字段 + ├─ 检查类型匹配 + └─ 失败返回 ErrorPayload +``` + +--- + +## 关键数据结构 + +### 描述符模型 + +#### HandlerDescriptor + +```python +{ + "id": "plugin.module:handler_name", + "trigger": { + "type": "command", + "command": "hello", + "aliases": ["hi"], + "description": "打招呼命令" + }, + "kind": "handler", # handler | hook | tool | session + "contract": "message_event", # message_event | schedule + "priority": 0, + 
"permissions": {"require_admin": false, "level": 0}, + "filters": [], + "param_specs": [] +} +``` + +#### CapabilityDescriptor + +```python +{ + "name": "my_plugin.calculate", + "description": "执行数学计算", + "input_schema": { + "type": "object", + "properties": { + "x": {"type": "number"}, + "y": {"type": "number"} + }, + "required": ["x", "y"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": {"type": "number"} + }, + "required": ["result"] + }, + "streaming": false +} +``` + +### 协议消息模型 + +| 消息类型 | 用途 | 关键字段 | +|---------|------|---------| +| `InitializeMessage` | 握手初始化 | `protocol_version`, `peer`, `handlers`, `provided_capabilities` | +| `InvokeMessage` | 调用能力 | `capability`, `input`, `stream`, `caller_plugin_id` | +| `ResultMessage` | 返回结果 | `success`, `output`, `error`, `kind` | +| `EventMessage` | 流式事件 | `phase` (started/delta/completed/failed), `data` | +| `CancelMessage` | 取消调用 | `reason` | + +--- + +## 时序图 + +### 完整生命周期时序图 + +``` +┌─────────┐ ┌────────────┐ ┌──────────────┐ ┌────────────────┐ +│ Core │ │ Supervisor │ │ WorkerSession│ │ Worker Runtime │ +└────┬────┘ └─────┬──────┘ └──────┬───────┘ └───────┬────────┘ + │ │ │ │ + │ │ start() │ │ + │ ├──────────────────►│ │ + │ │ │ 启动 Worker 进程 │ + │ │ ├────────────────────►│ + │ │ │ │ + │ │ │ load_plugin() │ + │ │ │ ├──────┐ + │ │ │ │ │ 解析装饰器 + │ │ │ │ │ 加载组件 + │ │ │ │◄─────┘ + │ │ │ │ + │ │ │ InitializeMessage │ + │ │ │◄────────────────────┤ + │ │ │ │ + │ │ _handle_initialize() │ + │ ├──────────────────►│ │ + │ │ │ │ + │ │ │ 注册 handlers │ + │ │ │ 注册 capabilities │ + │ │ │ │ + │ │ │ ResultMessage │ + │ │ ├────────────────────►│ + │ │ │ │ + │ │ │ 握手完成 │ + │ │ │ │ + │ 调用能力 │ │ │ + ├───────────────►│ │ │ + │ │ │ │ + │ │ execute() │ │ + │ ├──────────────────►│ │ + │ │ │ │ + │ │ │ InvokeMessage │ + │ │ ├────────────────────►│ + │ │ │ │ + │ │ │ │ 执行用户方法 + │ │ │ ├──────┐ + │ │ │ │ │ + │ │ │ │◄─────┘ + │ │ │ │ + │ │ │ ResultMessage │ + │ │ │◄────────────────────┤ + │ │ │ │ + │ 返回结果 │ │ │ + 
│◄───────────────┤ │ │ + │ │ │ │ +``` + +--- + +## 附录 + +### 相关文件 + +| 文件 | 说明 | +|------|------| +| `astrbot-sdk/src/astrbot_sdk/runtime/loader.py` | 插件发现与加载 | +| `astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py` | Supervisor/Worker 启动 | +| `astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py` | 能力路由 | +| `astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py` | 能力分发 | +| `astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py` | Handler 分发 | +| `astrbot-sdk/src/astrbot_sdk/runtime/peer.py` | 协议对等端 | +| `astrbot-sdk/src/astrbot_sdk/protocol/messages.py` | 协议消息模型 | +| `astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py` | 描述符模型 | +| `astrbot-sdk/src/astrbot_sdk/decorators.py` | 装饰器定义 | + +### 版本信息 + +- **SDK 版本**: v4.0 +- **协议版本**: P0.6 +- **Python 要求**: >=3.12 diff --git a/astrbot-sdk/docs/INDEX.md b/astrbot-sdk/docs/INDEX.md new file mode 100644 index 0000000000..ac2f99012b --- /dev/null +++ b/astrbot-sdk/docs/INDEX.md @@ -0,0 +1,131 @@ +# AstrBot SDK 文档目录 + +本文档目录包含完整的 SDK 开发文档,按难度级别分类。 + +## 📚 文档列表(按学习路径) + +### 🚀 快速开始(初级使用者) + +适合第一次接触 AstrBot SDK 的开发者: + +| 文档 | 描述 | 行数 | +|------|------|------| +| [README.md](./README.md) | 文档首页、快速开始、核心概念 | ~450 | +| [01_context_api.md](./01_context_api.md) | Context 类的核心客户端和系统工具方法 | ~1,000 | +| [02_event_and_components.md](./02_event_and_components.md) | MessageEvent 和消息组件的使用 | ~590 | +| [03_decorators.md](./03_decorators.md) | 所有装饰器的详细说明 | ~610 | +| [04_star_lifecycle.md](./04_star_lifecycle.md) | 插件基类和生命周期钩子 | ~530 | +| [05_clients.md](./05_clients.md) | 常用客户端速查与详细参考入口 | ~450 | + +### 🔧 进阶主题(中级使用者) + +适合已经掌握基础,希望深入了解 SDK 的开发者: + +| 文档 | 描述 | 行数 | +|------|------|------| +| [06_error_handling.md](./06_error_handling.md) | 完整的错误处理指南和调试技巧 | ~530 | +| [07_advanced_topics.md](./07_advanced_topics.md) | 并发处理、性能优化、安全最佳实践 | ~550 | +| [08_testing_guide.md](./08_testing_guide.md) | 如何测试插件和 Mock 使用 | ~450 | + +### 📖 参考资料(高级使用者) + +适合需要深入了解 SDK 架构和完整 API 的开发者: + +| 文档 | 描述 | 行数 | +|------|------|------| +| 
[09_api_reference.md](./09_api_reference.md) | 所有导出类和函数的完整参考入口 | ~30 | +| [10_migration_guide.md](./10_migration_guide.md) | 从旧版本或其他框架迁移 | ~490 | +| [11_security_checklist.md](./11_security_checklist.md) | 安全开发检查清单和已知问题 | ~380 | +| [PROJECT_ARCHITECTURE.md](./PROJECT_ARCHITECTURE.md) | SDK 架构设计文档 | ~560 | + +--- + +## 📊 文档统计 + +- **学习路径文档数**: 13 个 +- **API 子文档数**: 10 个 +- **Markdown 文档总数**: 24 个 +- **总内容行数**: ~15,400 行 +- **客户端与管理器数**: 17 个 +- **API 覆盖率**: 保持与当前公开导出同步(含 `message_history` 新增导出) + +--- + +## 🎯 文档内容覆盖 + +### 已涵盖的主题 + +✅ **基础使用** +- Context API 完整参考 +- 消息事件处理 +- 消息组件使用 +- Message History 精确消息历史管理 +- 装饰器使用 +- 生命周期管理 + +✅ **错误处理** +- AstrBotError 完整文档 +- 错误码参考 +- 错误处理模式 +- 调试技巧 + +✅ **高级主题** +- 并发处理 +- 性能优化 +- 安全最佳实践 +- 架构设计模式 + +✅ **测试** +- 单元测试 +- 集成测试 +- Mock 使用 +- 测试最佳实践 + +✅ **API 参考** +- 所有导出类的完整参考 +- 方法签名 +- 使用示例 +- DB 插件作用域与 HTTP 路由约束说明 + +✅ **迁移指南** +- v3 → v4 迁移 +- 从其他框架迁移 +- 破坏性变更列表 +- 迁移检查清单 + +✅ **安全检查清单** +- 安全开发检查清单 +- 已知安全问题(包含发现的问题) +- 安全最佳实践 +- 安全审计指南 + + +## 📝 文档使用建议 + +### 初级开发者 +1. 从 [README.md](./README.md) 开始 +2. 阅读 01-05 文档了解基础 API +3. 参考示例代码编写第一个插件 + +### 中级开发者 +1. 阅读 [06_error_handling.md](./06_error_handling.md) 建立健壮的错误处理 +2. 学习 [07_advanced_topics.md](./07_advanced_topics.md) 的并发和性能优化 +3. 按照 [08_testing_guide.md](./08_testing_guide.md) 编写测试 + +### 高级开发者 +1. 阅读 [09_api_reference.md](./09_api_reference.md) 了解所有可用功能 +2. 研究 [07_advanced_topics.md](./07_advanced_topics.md) 中的架构设计 +3. 阅读 [PROJECT_ARCHITECTURE.md](./PROJECT_ARCHITECTURE.md) 深入理解实现 + +--- + +## 🔗 相关资源 + +- **项目地址**: https://github.com/AstrBotDevs/AstrBot +- **SDK 版本**: v4.0 +- **协议版本**: P0.6 +- **Python 要求**: >= 3.12 + +--- + +**最后更新**: 2026-03-22 diff --git a/astrbot-sdk/docs/PROJECT_ARCHITECTURE.md b/astrbot-sdk/docs/PROJECT_ARCHITECTURE.md new file mode 100644 index 0000000000..749655af17 --- /dev/null +++ b/astrbot-sdk/docs/PROJECT_ARCHITECTURE.md @@ -0,0 +1,571 @@ +# AstrBot SDK 架构概述文档 + +> 作者:whatevertogo +> 生成日期:2026-03-19 + +--- + +## 目录 + +1. 
[项目概述](#项目概述) +2. [核心架构层次](#核心架构层次) +3. [协议层设计](#协议层设计) +4. [运行时架构](#运行时架构) +5. [客户端层设计](#客户端层设计) +6. [插件开发指南](#插件开发指南) +7. [关键设计模式](#关键设计模式) +8. [文档与资源](#文档与资源) + +--- + +## 项目概述 + +AstrBot SDK 是一个基于 Python 3.12+ 的机器人插件开发框架,采用**进程隔离**和**能力路由**架构,支持插件的动态加载、独立运行和跨进程通信。 + +### 核心特性 + +| 特性 | 描述 | +|------|------| +| **进程隔离** | 每个插件运行在独立 Worker 进程,崩溃不影响其他插件 | +| **环境分组** | 多插件可共享同一 Python 虚拟环境,节省资源 | +| **能力路由** | 显式声明的 Capability 系统,支持 JSON Schema 验证 | +| **流式支持** | 原生支持流式 LLM 调用和增量结果返回 | +| **向后兼容** | 完整的旧版 API 兼容层,支持无修改迁移 | +| **协议优先** | 基于 v4 协议的统一通信模型,支持多种传输方式 | + +### 技术栈 + +- **Python**: 3.12+ +- **异步框架**: asyncio +- **Web 框架**: aiohttp +- **数据验证**: pydantic +- **日志**: loguru +- **配置**: pyyaml +- **LLM**: openai, anthropic, google-genai +- **包管理**: uv (环境分组) + +--- + +## 核心架构层次 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 用户层 (Plugin Developer) │ +├─────────────────────────────────────────────────────────────────┤ +│ v4 入口: astrbot_sdk.{Star, Context, MessageEvent} │ +│ 装饰器: on_command, on_message, on_event, on_schedule │ +│ provide_capability, require_admin │ +│ 过滤器: PlatformFilter, MessageTypeFilter, CustomFilter │ +│ 命令组: CommandGroup, command_group │ +│ 会话: MessageSession, session_waiter │ +└────────────────────┬────────────────────────────────────────────┘ + │ +┌──────────────────▼─────────────────────────────────────────────┐ +│ 高层 API (High-Level API) │ +├─────────────────────────────────────────────────────────────────┤ +│ 能力客户端 (通过 CapabilityProxy 调用): │ +│ - LLMClient (llm.chat, llm.chat_raw, llm.stream_chat)│ +│ - MemoryClient (memory.search, memory.save, memory.stats, │ +│ memory.list_keys, memory.exists, │ +│ memory.clear_namespace, memory.count) │ +│ - DBClient (db.get, db.set, db.watch, db.list) │ +│ - PlatformClient (platform.send, platform.send_image, ...)│ +│ - HTTPClient (http.register_api, http.list_apis) │ +│ - MetadataClient (metadata.get_plugin, metadata.list_plugins)│ 
+└────────────────────┬────────────────────────────────────────────┘ + │ +┌──────────────────▼─────────────────────────────────────────────┐ +│ 执行边界 (Execution Boundary) │ +├─────────────────────────────────────────────────────────────────┤ +│ runtime 主干: │ +│ - loader.py (插件发现、加载、环境管理) │ +│ - bootstrap.py (Supervisor/Worker 启动) │ +│ - handler_dispatcher.py (Handler 执行分发、参数注入) │ +│ - capability_dispatcher.py (Capability 调用分发) │ +│ - capability_router.py (Capability 路由、Schema 验证) │ +│ - peer.py (协议对等端) │ +│ - transport.py (传输抽象) │ +└────────────────────┬────────────────────────────────────────────┘ + │ +┌──────────────────▼─────────────────────────────────────────────┐ +│ 协议与传输 (Protocol & Transport) │ +├─────────────────────────────────────────────────────────────────┤ +│ protocol/ │ +│ - messages.py (协议消息模型) │ +│ - descriptors.py (Handler/Capability 描述符) │ +│ transport 实现: │ +│ - StdioTransport (标准输入输出) │ +│ - WebSocketServerTransport (WebSocket 服务端) │ +│ - WebSocketClientTransport (WebSocket 客户端) │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### 层次职责 + +| 层次 | 职责 | 主要模块 | +|------|------|---------| +| **用户层** | 插件开发者 API | `Star`, `Context`, `MessageEvent`, 装饰器, 过滤器 | +| **高层 API** | 类型化的能力客户端 | `clients/{llm, memory, db, platform, http, metadata}` | +| **执行边界** | 插件加载、路由、分发 | `runtime/loader.py`, `runtime/*_dispatcher.py` | +| **协议层** | 消息模型、描述符、JSON Schema | `protocol/` | +| **传输层** | 底层通信抽象 | `runtime/transport.py` | + +### 核心设计原则 + +1. **延迟加载**:`runtime/__init__.py` 使用 `__getattr__` 避免导入时加载重型依赖 +2. **插件身份透传**:通过 `caller_plugin_scope()` 上下文管理器将 plugin_id 注入协议层 +3. **声明式优先**:所有配置都是数据结构(描述符),便于序列化和跨进程传递 +4. 
**类型安全**:使用 Pydantic 模型和类型注解提供验证和 IDE 支持 + +--- + +## 协议层设计 + +### 消息模型 + +v4 协议定义了 5 种消息类型: + +| 消息类型 | 用途 | 关键字段 | +|---------|------|---------| +| `InitializeMessage` | 握手初始化 | `protocol_version`, `peer`, `handlers`, `provided_capabilities` | +| `InvokeMessage` | 调用能力 | `capability`, `input`, `stream`, `caller_plugin_id` | +| `ResultMessage` | 返回结果 | `success`, `output`, `error`, `kind` | +| `EventMessage` | 流式事件 | `phase` (started/delta/completed/failed), `data` | +| `CancelMessage` | 取消调用 | `reason` | + +### 错误模型 + +`ErrorPayload` 使用字符串 code(而非整数),包含: +- `code`: 错误码(如 "capability_not_found") +- `message`: 开发者信息 +- `hint`: 用户友好提示 +- `retryable`: 是否可重试 + +### 握手流程 + +``` +Worker (Plugin) Supervisor (Core) + | | + | InitializeMessage | + | (handlers, capabilities) | + |----------------------------->| + | | + | ResultMessage(kind="init") | + |<-----------------------------| + | | + | InvokeMessage(handler.invoke) | + |<-----------------------------| + | 执行用户 handler | + | | + | ResultMessage(output) | + |----------------------------->| +``` + +### 描述符模型 + +#### HandlerDescriptor + +```python +{ + "id": "plugin.module:handler_name", + "trigger": { + "type": "command", + "command": "hello", + "aliases": ["hi"], + "description": "打招呼命令" + }, + "kind": "handler", # handler | hook | tool | session + "contract": "message_event", # message_event | schedule + "priority": 0, + "permissions": {"require_admin": False, "level": 0}, + "filters": [], + "param_specs": [] +} +``` + +#### Trigger 类型 + +| 类型 | 关键字段 | 说明 | +|------|---------|------| +| `CommandTrigger` | command, aliases, platforms | 命令触发 | +| `MessageTrigger` | regex, keywords, platforms | 消息触发(正则/关键词) | +| `EventTrigger` | event_type | 事件触发 | +| `ScheduleTrigger` | cron, interval_seconds | 定时触发 | + +### 内置 Capabilities + +#### LLM 命名空间 + +| 能力 | 说明 | +|------|------| +| `llm.chat` | 同步对话,返回文本 | +| `llm.chat_raw` | 同步对话,返回完整响应 | +| `llm.stream_chat` | 流式对话 | + +#### Memory 命名空间 + +| 能力 | 说明 | +|------|------| +| 
`memory.search` | 语义搜索记忆 | +| `memory.save` | 保存记忆 | +| `memory.save_with_ttl` | 保存带过期时间的记忆 | +| `memory.get` / `get_many` | 读取记忆 | +| `memory.list_keys` / `memory.exists` | 枚举与检查记忆键 | +| `memory.delete` / `delete_many` | 删除记忆 | +| `memory.clear_namespace` / `memory.count` | 管理 namespace 中的记忆 | +| `memory.stats` | 获取统计信息 | + +#### DB 命名空间 + +| 能力 | 说明 | +|------|------| +| `db.get` / `get_many` | 读取 KV | +| `db.set` / `set_many` | 写入 KV | +| `db.delete` | 删除 KV | +| `db.list` | 列出当前插件命名空间内的键(支持前缀过滤) | +| `db.watch` | 订阅当前插件命名空间内的变更(流式) | + +#### Message History 命名空间 + +| 能力 | 说明 | +|------|------| +| `message_history.list` | 分页读取会话消息历史 | +| `message_history.get_by_id` | 按 ID 读取单条消息历史 | +| `message_history.append` | 追加消息历史记录 | +| `message_history.delete_before` | 删除某时间点之前的记录 | +| `message_history.delete_after` | 删除某时间点之后的记录 | +| `message_history.delete_all` | 删除会话内全部消息历史 | + +#### Platform 命名空间 + +| 能力 | 说明 | +|------|------| +| `platform.send` | 发送文本消息 | +| `platform.send_image` | 发送图片 | +| `platform.send_chain` | 发送消息链 | +| `platform.get_members` | 获取群成员 | + +#### HTTP 命名空间 + +| 能力 | 说明 | +|------|------| +| `http.register_api` | 注册 HTTP API 端点,并拦截 `..` 等明显非法路径 | +| `http.unregister_api` | 注销 HTTP API 端点;不传 methods 时移除该 route 的全部方法 | +| `http.list_apis` | 列出已注册的 API | + +#### Metadata 命名空间 + +| 能力 | 说明 | +|------|------| +| `metadata.get_plugin` | 获取单个插件元数据 | +| `metadata.list_plugins` | 列出所有插件元数据 | +| `metadata.get_plugin_config` | 获取当前插件配置 | + +#### System 命名空间 + +| 能力 | 说明 | +|------|------| +| `system.get_data_dir` | 获取插件数据目录 | +| `system.text_to_image` | 文本转图片 | +| `system.html_render` | 渲染 HTML 模板 | +| `system.session_waiter.*` | 会话等待器管理 | +| `system.event.*` | 表情回应、输入状态、流式消息 | + +--- + +## 运行时架构 + +### 组件关系图 + +``` + ┌──────────────┐ + │ AstrBot │ + │ Core │ + └──────┬─────┘ + │ + ┌──────▼─────┐ + │ Supervisor │ + │ Runtime │ + └──────┬─────┘ + │ + ┌──────────────────┼──────────────────┐ + │ │ │ + ┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐ + │ Peer │ │ Peer 
│ │ Peer │ + │ (stdio) │ │ (stdio) │ │ (stdio) │ + └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ + │ │ │ + ┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐ + │ Worker │ │ Worker │ │ Worker │ + │ Runtime │ │ Runtime │ │ Runtime │ + └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ + │ │ │ + ┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐ + │ Plugin A │ │ Plugin B │ │ Plugin C │ + └───────────┘ └───────────┘ └───────────┘ +``` + +### 核心运行时组件 + +| 组件 | 职责 | +|------|------| +| **SupervisorRuntime** | 管理多个 Worker 进程,聚合所有 handler | +| **WorkerSession** | 管理单个 Worker 进程的生命周期 | +| **PluginWorkerRuntime** | Worker 进程内的插件加载与执行 | +| **HandlerDispatcher** | 将 handler.invoke 请求转成真实 Python 调用 | +| **CapabilityRouter** | 能力注册、发现和执行路由 | + +### 参数注入优先级 + +HandlerDispatcher 支持参数注入,优先级为: + +1. **按类型注解注入**(`MessageEvent`, `Context`) +2. **按参数名注入**(`event`, `ctx`, `context`) +3. **从 legacy_args 注入**(命令参数等) + +--- + +## 客户端层设计 + +### 客户端架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ User Plugin │ +│ ctx.llm.chat() / ctx.memory.save() / ctx.db.set() │ +└────────────┬──────────────────────────────────────────────┘ + │ +┌────────────▼──────────────────────────────────────────────┐ +│ CapabilityProxy │ +│ - call(name, payload) 普通调用 │ +│ - stream(name, payload) 流式调用 │ +└────────────┬──────────────────────────────────────────────┘ + │ +┌────────────▼──────────────────────────────────────────────┐ +│ Peer │ +│ - invoke(capability, payload) │ +│ - invoke_stream(capability, payload) │ +└────────────┬──────────────────────────────────────────────┘ + │ +┌────────────▼──────────────────────────────────────────────┐ +│ Transport │ +│ - send(json_string) │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 客户端一览 + +| 客户端 | 主要方法 | 对应 Capability | +|--------|---------|-----------------| +| `LLMClient` | `chat()`, `chat_raw()`, `stream_chat()` | `llm.*` | +| `MemoryClient` | `search()`, `save()`, `save_with_ttl()`, `get()`, `list_keys()`, `exists()`, `get_many()`, 
`delete()`, `clear_namespace()`, `delete_many()`, `count()`, `stats()` | `memory.*` | +| `DBClient` | `get()`, `set()`, `get_many()`, `set_many()`, `delete()`, `list()`, `watch()` | `db.*` | +| `MessageHistoryManagerClient` | `list()`, `get()`, `append()`, `delete_before()`, `delete_after()`, `delete_all()` | `message_history.*` | +| `PlatformClient` | `send()`, `send_image()`, `send_chain()`, `get_members()` | `platform.*` | +| `HTTPClient` | `register_api()`, `unregister_api()`, `list_apis()` | `http.*` | +| `MetadataClient` | `get_plugin()`, `list_plugins()`, `get_current_plugin()`, `get_plugin_config()` | `metadata.*` | + +--- + +## 插件开发指南 + +### v4 原生插件示例 + +#### plugin.yaml + +```yaml +_schema_version: 2 +name: my_plugin +author: your_name +version: 1.0.0 +runtime: + python: "3.12" +components: + - class: main:MyPlugin +``` + +#### main.py + +```python +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command, on_message, provide_capability + +class MyPlugin(Star): + # 命令处理器 + @on_command("hello", aliases=["hi"]) + async def hello(self, event: MessageEvent, ctx: Context) -> None: + await event.reply(f"你好,{event.user_id}!") + + # 消息处理器 + @on_message(keywords=["帮助"]) + async def help(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("可用命令:hello, help") + + # 提供能力 + @provide_capability( + "my_plugin.calculate", + description="执行计算", + input_schema={ + "type": "object", + "properties": {"x": {"type": "number"}}, + "required": ["x"] + }, + output_schema={ + "type": "object", + "properties": {"result": {"type": "number"}}, + "required": ["result"] + } + ) + async def calculate_capability(self, payload: dict, ctx: Context) -> dict: + x = payload.get("x", 0) + return {"result": x * 2} +``` + +### 生命周期钩子 + +| 钩子 | 说明 | +|------|------| +| `on_start()` | 插件启动时调用 | +| `on_stop()` | 插件停止时调用 | +| `on_error(exc, event, ctx)` | Handler 执行出错时调用 | + +### 常用功能速查 + +#### 1. 
LLM 对话 + +```python +# 简单对话 +reply = await ctx.llm.chat("你好") + +# 带历史对话 +from astrbot_sdk.clients.llm import ChatMessage +history = [ChatMessage(role="user", content="我叫小明")] +reply = await ctx.llm.chat("你记得我吗?", history=history) + +# 流式对话 +async for chunk in ctx.llm.stream_chat("讲个故事"): + print(chunk, end="") +``` + +#### 2. 数据持久化 + +```python +# DB 客户端(精确匹配) +await ctx.db.set("user:123", {"name": "Alice"}) +data = await ctx.db.get("user:123") + +# Memory 客户端(语义搜索) +await ctx.memory.save("user_pref", {"theme": "dark"}) +results = await ctx.memory.search("用户喜欢什么颜色") +``` + +#### 3. 消息发送 + +```python +# 简单文本 +await ctx.platform.send(event.session_id, "消息内容") + +# 图片 +await ctx.platform.send_image(event.session_id, "https://example.com/img.jpg") + +# 消息链 +from astrbot_sdk.message_components import Plain, Image +chain = [Plain("文字"), Image(url="https://example.com/img.jpg")] +await ctx.platform.send_chain(event.session_id, chain) +``` + +--- + +## 关键设计模式 + +### 1. 协议优先模式 + +- 所有跨进程通信都通过 v4 协议 +- 传输层只处理字符串,协议由 Peer 层处理 +- 支持多种传输方式(Stdio, WebSocket) + +### 2. 能力路由模式 + +- 显式声明 Capability 和输入/输出 Schema +- 通过 CapabilityRouter 统一路由 +- 支持同步和流式两种调用模式 +- 冲突处理:保留命名空间冲突直接跳过,非保留命名空间冲突自动添加插件名前缀 + +### 3. 环境分组模式 + +- 多插件可共享同一 Python 虚拟环境 +- 按版本和依赖兼容性自动分组 +- 节省资源,加快启动速度 + +### 4. 参数注入模式 + +- HandlerDispatcher 支持类型注解注入 +- 优先级:类型注解 > 参数名 > legacy_args +- 支持可选类型 `Optional[Type]` + +### 5. 取消传播模式 + +- CancelToken 统一取消机制 +- 跨进程取消通过 CancelMessage +- 早到取消避免竞态条件 + +### 6. 插件隔离模式 + +- 每个插件运行在独立 Worker 进程 +- 崩溃不影响其他插件 +- 支持 GroupWorkerRuntime 共享环境 + +### 7. 
热重载模式 + +- `dev --watch` 支持文件变更检测 +- 按插件目录清理 `sys.modules` 缓存 +- 确保代码变更后正确重载 + +--- + +## 文档与资源 + +### 完整文档目录 + +SDK 文档按学习路径组织,位于 `src/astrbot_sdk/docs/`: + +| 级别 | 文档 | 内容 | +|------|------|------| +| **初级** | README.md | 快速开始、核心概念 | +| | 01_context_api.md | Context API 完整参考 | +| | 02_event_and_components.md | MessageEvent 和消息组件 | +| | 03_decorators.md | 装饰器详细说明 | +| | 04_star_lifecycle.md | 插件基类和生命周期 | +| | 05_clients.md | 客户端 API 文档 | +| **中级** | 06_error_handling.md | 错误处理与调试 | +| | 07_advanced_topics.md | 并发、性能优化、安全 | +| | 08_testing_guide.md | 测试指南 | +| **高级** | 09_api_reference.md | 完整 API 索引 | +| | 10_migration_guide.md | 迁移指南 | +| | 11_security_checklist.md | 安全检查清单 | +| | PROJECT_ARCHITECTURE.md | 架构设计文档 | + +### 关键文件速查 + +| 文件 | 核心类/函数 | 说明 | +|------|------------|------| +| `astrbot_sdk/__init__.py` | `Star`, `Context`, `MessageEvent` | 顶层入口 | +| `astrbot_sdk/star.py` | `Star` | v4 原生插件基类 | +| `astrbot_sdk/context.py` | `Context` | 运行时上下文 | +| `astrbot_sdk/decorators.py` | `on_command`, `on_message` | v4 装饰器 | +| `astrbot_sdk/errors.py` | `AstrBotError` | 统一错误模型 | +| `astrbot_sdk/runtime/peer.py` | `Peer` | 协议对等端 | +| `astrbot_sdk/runtime/capability_router.py` | `CapabilityRouter` | Capability 路由 | +| `astrbot_sdk/clients/llm.py` | `LLMClient` | LLM 客户端 | + +### 版本信息 + +- **SDK 版本**: v4.0 +- **协议版本**: P0.6 +- **Python 要求**: >=3.12 +- **推荐版本**: 3.12+ + +--- + +> 本文档基于 AstrBot SDK v4 架构文档整理 +> 详细内容请查阅 `src/astrbot_sdk/docs/` 目录下的完整文档 diff --git a/astrbot-sdk/docs/README.md b/astrbot-sdk/docs/README.md new file mode 100644 index 0000000000..eaad85a33a --- /dev/null +++ b/astrbot-sdk/docs/README.md @@ -0,0 +1,466 @@ +# AstrBot SDK 插件开发文档 + +欢迎来到 AstrBot SDK 插件开发文档!本文档面向 SDK 插件开发者,提供从入门到精通的完整指南。 + +## 📚 文档目录 + +### 🚀 快速开始(初级使用者) + +适合第一次接触 AstrBot SDK 的开发者: + +- **[01. Context API 参考](./01_context_api.md)** - Context 类的核心客户端和系统工具方法 +- **[02. 消息事件与组件](./02_event_and_components.md)** - MessageEvent 和消息组件的使用 +- **[03. 
装饰器使用指南](./03_decorators.md)** - 所有装饰器的详细说明 +- **[04. Star 类与生命周期](./04_star_lifecycle.md)** - 插件基类和生命周期钩子 +- **[05. 常用客户端速查](./05_clients.md)** - 常用客户端的快速上手示例与详细参考入口 + +### 🔧 进阶主题(中级使用者) + +适合已经掌握基础,希望深入了解 SDK 的开发者: + +- **[06. 错误处理与调试](./06_error_handling.md)** - 完整的错误处理指南和调试技巧 +- **[07. 高级主题](./07_advanced_topics.md)** - 并发处理、性能优化、安全最佳实践 +- **[08. 测试指南](./08_testing_guide.md)** - 如何测试插件和 Mock 使用 + +### 📖 参考资料(高级使用者) + +适合需要深入了解 SDK 架构和完整 API 的开发者: + +- **[09. 完整 API 索引](./09_api_reference.md)** - 所有导出类和函数的完整参考 +- **[客户端 API 详细参考](./api/clients.md)** - 17 个客户端与管理器的完整签名、返回值和示例 +- **[10. 迁移指南](./10_migration_guide.md)** - 从旧版本或其他框架迁移 +- **[11. 安全检查清单](./11_security_checklist.md)** - 安全开发检查清单和已知问题 + +--- + +## 🎯 学习路径推荐 + +### 初级路径:快速上手 + +``` +1. 阅读本 README 的快速开始部分 +2. 跟随下面的"创建第一个插件"教程 +3. 查阅 01-05 文档了解基础 API +4. 参考文档中的示例代码 +``` + +### 中级路径:进阶开发 + +``` +1. 阅读 06 错误处理指南,建立健壮的错误处理机制 +2. 学习 07 高级主题中的并发和性能优化 +3. 按照 08 测试指南编写测试 +4. 尝试开发复杂的插件功能 +``` + +### 高级路径:精通 SDK + +``` +1. 阅读 09 完整 API 索引,了解所有可用功能 +2. 研究 07 高级主题中的架构设计 +3. 阅读 SDK 源码深入理解实现 +4. 
参与 SDK 贡献和改进 +``` + +--- + +## 🚀 快速上手 + +### 创建第一个插件 + +```python +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command, on_message + +class MyPlugin(Star): + """我的第一个插件""" + + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context): + """打招呼命令""" + await event.reply(f"你好,{event.sender_name}!") + + @on_message(keywords=["帮助", "help"]) + async def help(self, event: MessageEvent, ctx: Context): + """帮助信息""" + await event.reply("可用命令: /hello") +``` + +### 插件配置 (plugin.yaml) + +```yaml +_schema_version: 2 +name: my_plugin +author: your_name +version: 1.0.0 +desc: 我的插件描述 + +runtime: + python: "3.12" + +components: + - class: main:MyPlugin + +support_platforms: + - aiocqhttp + - telegram +``` + +--- + +## 📖 核心概念 + +### Context - 能力访问入口 + +`Context` 是插件与 AstrBot Core 交互的主要入口: + +```python +# LLM 对话 +reply = await ctx.llm.chat("你好") + +# 数据存储 +await ctx.db.set("key", "value") +data = await ctx.db.get("key") + +# 记忆存储 +await ctx.memory.save("pref", {"theme": "dark"}) + +# 发送消息 +await ctx.platform.send(event.session_id, "消息内容") + +# 获取配置 +config = await ctx.metadata.get_plugin_config() +``` + +### MessageEvent - 消息事件 + +`MessageEvent` 表示接收到的消息事件: + +```python +# 回复消息 +await event.reply("回复内容") + +# 获取消息组件 +images = event.get_images() + +# 判断消息类型 +if event.is_group_chat(): + await event.reply("这是群聊消息") + +# 构建返回结果 +return event.plain_result("返回内容") +``` + +### 装饰器 - 事件处理注册 + +```python +from astrbot_sdk.decorators import ( + on_command, # 命令触发 + on_message, # 消息触发 + on_event, # 事件触发 + on_schedule, # 定时任务 + require_admin, # 权限控制 + rate_limit, # 速率限制 +) + +@on_command("test") +@rate_limit(5, 60) +async def test_handler(self, event: MessageEvent, ctx: Context): + await event.reply("测试") +``` + +--- + +## 🔧 常用功能速查 + +### 1. 
LLM 对话 + +```python +# 简单对话 +reply = await ctx.llm.chat("你好") + +# 带历史对话 +from astrbot_sdk.clients.llm import ChatMessage + +history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), +] +reply = await ctx.llm.chat("你记得我吗?", history=history) + +# 流式对话 +async for chunk in ctx.llm.stream_chat("讲个故事"): + print(chunk, end="") +``` + +### 2. 数据持久化 + +```python +# DB 客户端(精确匹配,键空间按插件隔离) +await ctx.db.set("user:123", {"name": "Alice"}) +data = await ctx.db.get("user:123") + +# Memory 客户端(语义搜索) +await ctx.memory.save("user_pref", {"theme": "dark"}) +results = await ctx.memory.search("用户喜欢什么颜色") +keys = await ctx.memory.list_keys() +exists = await ctx.memory.exists("user_pref") +count = await ctx.memory.count() + +# Message History(保存原始消息链和发送者) +from astrbot_sdk import MessageHistorySender, MessageSession, Plain + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +await ctx.message_history.append( + session, + parts=[Plain(event.message_content, convert=False)], + sender=MessageHistorySender( + sender_id=event.sender_id, + sender_name=event.sender_name, + ), +) +``` + +### 3. 消息发送 + +```python +# 简单文本 +await ctx.platform.send(event.session_id, "消息内容") + +# 图片 +await ctx.platform.send_image(event.session_id, "https://example.com/img.jpg") + +# 消息链 +from astrbot_sdk.message_components import Plain, Image + +chain = [Plain("文字"), Image(url="https://example.com/img.jpg")] +await ctx.platform.send_chain(event.session_id, chain) +``` + +### 4. 文件处理 + +```python +from astrbot_sdk.message_components import Image + +# 注册文件到文件服务 +img = Image.fromFileSystem("/path/to/image.jpg") +public_url = await img.register_to_file_service() +``` + +--- + +## 🛠️ 高级功能 + +### 1. 
LLM 工具注册 + +```python +async def search_weather(location: str) -> str: + return f"{location} 今天晴天" + +await ctx.register_llm_tool( + name="search_weather", + parameters_schema={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "城市名称"} + }, + "required": ["location"] + }, + desc="搜索天气信息", + func_obj=search_weather +) +``` + +### 2. Web API 注册 + +```python +from astrbot_sdk.decorators import provide_capability + +@provide_capability( + name="my_plugin.api", + description="处理 HTTP 请求" +) +async def handle_api(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + +await ctx.http.register_api( + route="/my-api", # 建议使用规范化路径,避免 .. 和重复斜杠 + handler=handle_api, + methods=["GET", "POST"] +) +``` + +### 3. 后台任务 + +```python +async def background_work(): + while True: + await asyncio.sleep(60) + ctx.logger.info("每分钟执行一次") + +task = await ctx.register_task(background_work(), "定时任务") +``` + +--- + +## 📋 最佳实践 + +### 1. 错误处理 + +```python +from astrbot_sdk.errors import AstrBotError + +@on_command("risky") +async def risky_handler(self, event: MessageEvent, ctx: Context): + try: + result = await risky_operation() + await event.reply(f"成功: {result}") + except AstrBotError as e: + # SDK 错误包含用户友好的提示 + await event.reply(e.hint or e.message) + except ValueError as e: + await event.reply(f"参数错误: {e}") + except Exception as e: + ctx.logger.error(f"操作失败: {e}", exc_info=e) + raise +``` + +### 2. 日志记录 + +```python +# 不同级别的日志 +ctx.logger.debug("调试信息") +ctx.logger.info("普通信息") +ctx.logger.warning("警告信息") +ctx.logger.error("错误信息") + +# 绑定上下文 +logger = ctx.logger.bind(user_id=event.user_id) +logger.info("用户操作") +``` + +### 3. 
配置管理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + config = await ctx.metadata.get_plugin_config() + + # 提供默认值 + self.timeout = config.get("timeout", 30) + + # 验证必需配置 + if "api_key" not in config: + raise ValueError("缺少必需配置: api_key") + + self.api_key = config["api_key"] +``` + +### 4. 资源清理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + self._session = aiohttp.ClientSession() + self._task = asyncio.create_task(self.background_task()) + + async def on_stop(self, ctx): + if hasattr(self, '_task'): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + if hasattr(self, '_session'): + await self._session.close() +``` + +--- + +## 🔍 注意事项 + +1. **异步操作**:所有客户端方法都是异步的,需要使用 `await` + +2. **插件隔离**:每个插件有独立的 Context 实例 + +3. **错误处理**:所有远程调用都可能失败,建议使用 try-except + +4. **Memory vs DB**: + - Memory: 语义搜索,适合 AI 上下文 + - DB: 精确匹配,适合结构化数据 + +5. **平台标识**:使用 UMO 格式 `"platform:instance:session_id"` + +6. **装饰器顺序**:事件触发 → 过滤器 → 限制器 → 修饰器 + +7. **安全提示**: + - 不要在插件中存储敏感信息(API Key 等应使用配置) + - 验证所有用户输入 + - 注意资源泄漏(任务、连接等需要正确清理) + - 遵循最小权限原则 + +--- + +## 🐛 调试技巧 + +### 启用调试日志 + +```python +# 在插件中获取 logger +logger = ctx.logger + +# 记录详细信息 +logger.debug(f"收到消息: {event.text}") +logger.debug(f"用户ID: {event.user_id}") +``` + +### 使用测试框架 + +```python +from astrbot_sdk.testing import PluginTestHarness + +async def test_my_plugin(): + harness = PluginTestHarness() + plugin = harness.load_plugin("my_plugin.main:MyPlugin") + + # 模拟事件 + result = await harness.simulate_command("/hello") + assert result.text == "Hello!" 
+``` + +--- + +## 📞 获取帮助 + +- **查看详细文档**:[docs/](./) +- **完整 API 索引**:[09_api_reference.md](./09_api_reference.md) +- **错误处理指南**:[06_error_handling.md](./06_error_handling.md) +- **安全检查清单**:[11_security_checklist.md](./11_security_checklist.md) +- **提交问题**:[GitHub Issues](https://github.com/AstrBotDevs/astrbot-sdk/issues) +- **参与讨论**:[GitHub Discussions](https://github.com/AstrBotDevs/astrbot-sdk/discussions) + +--- + +## 📚 版本信息 + +- **SDK 版本**: v4.0 +- **最后更新**: 2026-03-22 +- **Python 要求**: >= 3.12 +- **协议版本**: P0.6 + +--- + +## 📝 文档贡献 + +如果您发现文档中的错误或想改进文档,欢迎提交 PR! + +**文档规范**: +- 使用清晰的代码示例 +- 包含错误处理示例 +- 标注 API 的稳定性和版本要求 +- 提供初级和高级两种使用方式 diff --git a/astrbot-sdk/docs/api/clients.md b/astrbot-sdk/docs/api/clients.md new file mode 100644 index 0000000000..b6c5fba80d --- /dev/null +++ b/astrbot-sdk/docs/api/clients.md @@ -0,0 +1,1807 @@ +# 客户端 API 完整参考 + +## 概述 + +本文档详细介绍 `astrbot_sdk/clients/` 目录下所有客户端的 API。客户端是 Context 中暴露的各种能力接口,每个客户端负责一类特定的功能。 + +**模块路径**: `astrbot_sdk.clients` + +--- + +## 目录 + +- [LLMClient - AI 对话客户端](#llmclient---ai-对话客户端) +- [MemoryClient - 记忆存储客户端](#memoryclient---记忆存储客户端) +- [DBClient - KV 数据库客户端](#dbclient---kv-数据库客户端) +- [PlatformClient - 平台消息客户端](#platformclient---平台消息客户端) +- [FileServiceClient - 文件服务客户端](#fileserviceclient---文件服务客户端) +- [HTTPClient - HTTP API 客户端](#httpclient---http-api-客户端) +- [MetadataClient - 插件元数据客户端](#metadataclient---插件元数据客户端) +- [ProviderClient - Provider 发现客户端](#providerclient---provider-发现客户端) +- [ProviderManagerClient - Provider 管理客户端](#providermanagerclient---provider-管理客户端) +- [PersonaManagerClient - 人格管理客户端](#personamanagerclient---人格管理客户端) +- [ConversationManagerClient - 对话管理客户端](#conversationmanagerclient---对话管理客户端) +- [MessageHistoryManagerClient - 消息历史管理客户端](#messagehistorymanagerclient---消息历史管理客户端) +- [KnowledgeBaseManagerClient - 知识库管理客户端](#knowledgebasemanagerclient---知识库管理客户端) +- [RegistryClient - Handler 注册表客户端](#registryclient---handler-注册表客户端) +- [SkillClient - 技能注册客户端](#skillclient---技能注册客户端) 
+- [SessionPluginManager - 会话插件管理器](#sessionpluginmanager---会话插件管理器) +- [SessionServiceManager - 会话服务管理器](#sessionservicemanager---会话服务管理器) + +--- + +## LLMClient - AI 对话客户端 + +提供与大语言模型交互的能力,支持普通聊天、流式聊天和结构化响应。 + +### 导入 + +```python +from astrbot_sdk.clients import LLMClient, ChatMessage, LLMResponse +``` + +### 方法 + +#### `chat(prompt, *, system, history, contexts, provider_id, model, temperature, **kwargs)` + +发送聊天请求并返回文本响应。 + +**参数**: +- `prompt` (`str`): 用户输入的提示文本 +- `system` (`str | None`): 系统提示词 +- `history` / `contexts` (`Sequence[ChatHistoryItem] | None`): 对话历史 +- `provider_id` (`str | None`): 指定使用的 provider +- `model` (`str | None`): 指定模型名称 +- `temperature` (`float | None`): 生成温度(0-1) +- `**kwargs`: 额外透传参数(如 `image_urls`, `tools`) + +**返回**: `str` - 生成的文本内容 + +**示例**: + +```python +# 简单对话 +reply = await ctx.llm.chat("你好,介绍一下自己") + +# 带系统提示词 +reply = await ctx.llm.chat( + "翻译成英文", + system="你是一个专业翻译助手" +) + +# 带对话历史 +history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), +] +reply = await ctx.llm.chat("你记得我吗?", history=history) + +# 使用字典格式的对话历史 +history = [ + {"role": "user", "content": "我叫小明"}, + {"role": "assistant", "content": "你好小明!"}, +] +reply = await ctx.llm.chat("你记得我吗?", history=history) +``` + +--- + +#### `chat_raw(prompt, *, system, history, contexts, provider_id, model, temperature, **kwargs)` + +发送聊天请求并返回完整响应对象。 + +**返回**: `LLMResponse` 对象,包含: +- `text`: 生成的文本内容 +- `usage`: Token 使用统计 +- `finish_reason`: 结束原因 +- `tool_calls`: 工具调用列表 +- `role`: 响应角色 + +**示例**: + +```python +response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) +print(f"生成文本: {response.text}") +print(f"Token 使用: {response.usage}") +print(f"结束原因: {response.finish_reason}") + +# 处理工具调用 +if response.tool_calls: + for tool_call in response.tool_calls: + print(f"工具调用: {tool_call}") +``` + +--- + +#### `stream_chat(prompt, *, system, history, contexts, provider_id, model, temperature, **kwargs)` + +流式聊天,逐块返回响应文本。 + +**返回**: 
异步生成器,逐块生成文本 + +**示例**: + +```python +# 实时显示生成内容 +async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="", flush=True) + +# 收集完整响应 +full_text = "" +async for chunk in ctx.llm.stream_chat("写一篇文章"): + full_text += chunk + # 实时处理每个 chunk +``` + +--- + +## MemoryClient - 记忆存储客户端 + +提供 AI 记忆的存储和检索能力,支持语义搜索。与 DBClient 和 MessageHistoryManagerClient 不同, +MemoryClient 主要用于可检索的 AI 上下文,而不是精确保存原始消息记录。 + +### 导入 + +```python +from astrbot_sdk.clients import MemoryClient +``` + +### 方法 + +#### `search(query, *, mode="auto", limit=None, min_score=None, provider_id=None)` + +搜索记忆项。默认会在存在 embedding provider 时执行 hybrid 检索, +否则退化为关键词检索。 + +**参数**: +- `query` (`str`): 搜索查询文本(自然语言) +- `mode` (`Literal["auto", "keyword", "vector", "hybrid"]`): 搜索模式 +- `limit` (`int | None`): 最大返回条数 +- `min_score` (`float | None`): 最低分数阈值 +- `provider_id` (`str | None`): 指定 embedding provider + +**返回**: `list[dict]` - 匹配的记忆项列表。每项至少包含 `key`、`value`、`score`、`match_type` + +**示例**: + +```python +# 搜索用户偏好 +results = await ctx.memory.search("用户喜欢什么颜色", mode="hybrid", limit=5) +for item in results: + print(item["key"], item["score"], item["match_type"]) + +# 强制使用关键词检索 +keyword_hits = await ctx.memory.search("blue", mode="keyword", min_score=0.9) + +# 使用当前激活的 embedding provider 执行向量检索 +vector_hits = await ctx.memory.search("之前讨论过什么技术话题", mode="vector") +``` + +--- + +#### `save(key, value, **extra)` + +保存记忆项。 + +**参数**: +- `key` (`str`): 记忆项的唯一标识键 +- `value` (`dict | None`): 要存储的数据字典 +- `**extra`: 额外的键值对,会合并到 value 中 + +**示例**: + +```python +# 保存用户偏好 +await ctx.memory.save("user_pref", { + "theme": "dark", + "lang": "zh", + "favorite_color": "blue" +}) + +# 使用关键字参数 +await ctx.memory.save( + "note", + None, + content="重要笔记", + tags=["work"], + timestamp="2024-01-01" +) + +# 显式指定检索文本 +await ctx.memory.save( + "profile:alice", + { + "name": "Alice", + "city": "Shanghai", + "embedding_text": "Alice 喜欢蓝色、海边和摄影", + }, +) +``` + +--- + +#### `get(key)` + +精确获取单个记忆项。 + +**参数**: +- `key` (`str`): 记忆项的唯一键 
+ +**返回**: `dict | None` - 记忆项内容字典,不存在则返回 None + +**示例**: + +```python +pref = await ctx.memory.get("user_pref") +if pref: + print(f"用户偏好主题: {pref.get('theme')}") +``` + +--- + +#### `delete(key)` + +删除记忆项。 + +**参数**: +- `key` (`str`): 要删除的记忆项键名 + +**示例**: + +```python +await ctx.memory.delete("old_note") +``` + +--- + +#### `save_with_ttl(key, value, ttl_seconds)` + +保存带过期时间的记忆项。 + +**参数**: +- `key` (`str`): 记忆项的唯一标识键 +- `value` (`dict`): 要存储的数据字典 +- `ttl_seconds` (`int`): 存活时间(秒),必须大于 0 + +**异常**: +- `TypeError`: value 不是 dict 类型 +- `ValueError`: ttl_seconds 小于 1 + +**示例**: + +```python +# 保存临时会话状态,1小时后过期 +await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting", "step": 1}, + ttl_seconds=3600 +) + +# 保存验证码,5分钟后过期 +await ctx.memory.save_with_ttl( + "verification_code", + {"code": "123456", "user_id": "user123"}, + ttl_seconds=300 +) +``` + +--- + +#### `get_many(keys)` + +批量获取多个记忆项。 + +**参数**: +- `keys` (`list[str]`): 记忆项键名列表 + +**返回**: `list[dict]` - 记忆项列表 + +**示例**: + +```python +items = await ctx.memory.get_many(["pref1", "pref2", "pref3"]) +for item in items: + if item["value"]: + print(f"{item['key']}: {item['value']}") +``` + +--- + +#### `delete_many(keys)` + +批量删除多个记忆项。 + +**参数**: +- `keys` (`list[str]`): 要删除的记忆项键名列表 + +**返回**: `int` - 实际删除的记忆项数量 + +**示例**: + +```python +deleted = await ctx.memory.delete_many(["old1", "old2", "old3"]) +print(f"删除了 {deleted} 条记忆") +``` + +--- + +#### `stats()` + +获取记忆系统统计信息。 + +**返回**: `dict` - 统计信息字典 + +**示例**: + +```python +stats = await ctx.memory.stats() +print(f"记忆库共有 {stats['total_items']} 条记录") +if 'ttl_entries' in stats: + print(f"其中 {stats['ttl_entries']} 条有过期时间") +if 'indexed_items' in stats: + print(f"已建立索引: {stats['indexed_items']}") +if 'embedded_items' in stats: + print(f"已向量化: {stats['embedded_items']}") +if 'dirty_items' in stats: + print(f"待重建索引: {stats['dirty_items']}") +``` + +--- + +## DBClient - KV 数据库客户端 + +提供键值存储能力,用于持久化插件数据。数据永久保存直到显式删除,且运行时会自动对 key 做插件级命名空间隔离。 + +### 导入 + +```python 
+from astrbot_sdk.clients import DBClient +``` + +### 方法 + +#### `get(key)` + +获取指定键的值。 + +**参数**: +- `key` (`str`): 数据键名 + +**返回**: `Any | None` - 存储的值,键不存在则返回 None + +**示例**: + +```python +data = await ctx.db.get("user_settings") +if data: + print(data["theme"]) +``` + +--- + +#### `set(key, value)` + +设置键值对。 + +**参数**: +- `key` (`str`): 数据键名 +- `value` (`Any`): 要存储的 JSON 值 + +**示例**: + +```python +# 存储字典 +await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"}) + +# 存储列表 +await ctx.db.set("recent_commands", ["help", "status", "info"]) + +# 存储基本类型 +await ctx.db.set("greeted", True) +await ctx.db.set("counter", 42) +await ctx.db.set("last_seen", "2024-01-01T00:00:00Z") +``` + +--- + +#### `delete(key)` + +删除指定键的数据。 + +**参数**: +- `key` (`str`): 要删除的数据键名 + +**示例**: + +```python +await ctx.db.delete("user_settings") +``` + +--- + +#### `list(prefix=None)` + +列出匹配前缀的所有键。 + +**参数**: +- `prefix` (`str | None`): 键前缀过滤,None 表示列出所有键 + +**返回**: `list[str]` - 匹配的键名列表 + +返回的 key 是当前插件视角的原始 key,不包含运行时内部命名空间前缀。 + +**示例**: + +```python +# 列出所有用户设置相关的键 +keys = await ctx.db.list("user_") +# ["user_settings", "user_profile", "user_history"] + +# 列出所有键 +all_keys = await ctx.db.list() +``` + +--- + +#### `get_many(keys)` + +批量获取多个键的值。 + +**参数**: +- `keys` (`Sequence[str]`): 要读取的键列表 + +**返回**: `dict[str, Any | None]` - 字典,value 为对应值(不存在则为 None) + +**示例**: + +```python +values = await ctx.db.get_many(["user:1", "user:2", "user:3"]) +if values["user:1"] is None: + print("user:1 不存在") + +# 遍历结果 +for key, value in values.items(): + print(f"{key}: {value}") +``` + +--- + +#### `set_many(items)` + +批量写入多个键值对。 + +**参数**: +- `items` (`Mapping[str, Any] | Sequence[tuple[str, Any]]`): 键值对集合 + +**示例**: + +```python +# 使用字典 +await ctx.db.set_many({ + "user:1": {"name": "Alice"}, + "user:2": {"name": "Bob"}, + "user:3": {"name": "Charlie"} +}) + +# 使用元组列表 +await ctx.db.set_many([ + ("counter:1", 10), + ("counter:2", 20), + ("counter:3", 30) +]) +``` + +--- + +#### `watch(prefix=None)` + 
+订阅 KV 变更事件(流式)。 + +**参数**: +- `prefix` (`str | None`): 键前缀过滤 + +**返回**: 异步迭代器,产生变更事件 + +**事件格式**: `{"op": "set"|"delete", "key": str, "value": Any|None}` + +事件中的 `key` 也是当前插件视角的原始 key。 + +**示例**: + +```python +# 监听所有变更 +async for event in ctx.db.watch(): + print(f"{event['op']}: {event['key']}") + +# 监听特定前缀的变更 +async for event in ctx.db.watch("user:"): + if event["op"] == "set": + print(f"用户 {event['key']} 更新: {event['value']}") + else: + print(f"用户 {event['key']} 删除") +``` + +--- + +## PlatformClient - 平台消息客户端 + +提供向聊天平台发送消息和获取信息的能力。 + +### 导入 + +```python +from astrbot_sdk.clients import PlatformClient +``` + +### 方法 + +#### `send(session, text)` + +发送文本消息。 + +**参数**: +- `session` (`str | SessionRef | MessageSession`): 统一消息来源标识 +- `text` (`str`): 要发送的文本内容 + +**返回**: `dict[str, Any]` - 发送结果 + +**示例**: + +```python +# 使用字符串 UMO +await ctx.platform.send( + "qq:group:123456", + "大家好!" +) + +# 使用 MessageSession +from astrbot_sdk.message_session import MessageSession + +session = MessageSession( + platform_id="qq", + message_type="group", + session_id="123456" +) +await ctx.platform.send(session, "你好!") + +# 使用事件中的 session_id +await ctx.platform.send(event.session_id, "收到您的消息!") +``` + +--- + +#### `send_image(session, image_url)` + +发送图片消息。 + +**参数**: +- `session`: 会话标识 +- `image_url` (`str`): 图片 URL 或本地文件路径 + +**返回**: `dict[str, Any]` - 发送结果 + +**示例**: + +```python +# 使用 URL +await ctx.platform.send_image( + event.session_id, + "https://example.com/image.png" +) + +# 使用本地路径 +await ctx.platform.send_image( + "qq:private:789", + "/path/to/local/image.jpg" +) +``` + +--- + +#### `send_chain(session, chain)` + +发送富消息链。 + +**参数**: +- `session`: 会话标识 +- `chain` (`MessageChain | list[BaseMessageComponent] | list[dict]`): 消息链 + +**返回**: `dict[str, Any]` - 发送结果 + +**示例**: + +```python +from astrbot_sdk.message_components import Plain, Image + +# 使用 MessageChain +chain = MessageChain([ + Plain("你好 "), + At("123456"), + Plain("!"), +]) +await 
ctx.platform.send_chain(event.session_id, chain) + +# 使用组件列表 +await ctx.platform.send_chain( + event.session_id, + [Plain("文本"), Image(url="https://example.com/img.jpg")] +) + +# 使用序列化的 payload +await ctx.platform.send_chain( + event.session_id, + [ + {"type": "text", "data": {"text": "文本"}}, + {"type": "image", "data": {"url": "https://example.com/a.png"}} + ] +) +``` + +--- + +#### `send_by_session(session, content)` + +主动向指定会话发送消息。 + +**参数**: +- `session`: 会话标识 +- `content`: 消息内容(支持多种格式) + +**示例**: + +```python +# 发送文本 +await ctx.platform.send_by_session("qq:group:123456", "公告:...") + +# 发送消息链 +chain = MessageChain([Plain("重要通知"), Image.fromURL(...)]) +await ctx.platform.send_by_session("qq:group:123456", chain) +``` + +--- + +#### `send_by_id(platform_id, session_id, content, *, message_type)` + +主动向指定平台会话发送消息。 + +**参数**: +- `platform_id` (`str`): 平台 ID +- `session_id` (`str`): 会话 ID +- `content`: 消息内容 +- `message_type` (`str`): 消息类型(`"private"` 或 `"group"`) + +**示例**: + +```python +# 发送私聊消息 +await ctx.platform.send_by_id( + platform_id="qq", + session_id="123456", + content="Hello", + message_type="private" +) + +# 发送群消息 +await ctx.platform.send_by_id( + platform_id="qq", + session_id="789", + content="群公告", + message_type="group" +) +``` + +--- + +#### `get_members(session)` + +获取群组成员列表。 + +**参数**: +- `session`: 群组会话标识 + +**返回**: `list[dict]` - 成员信息列表 + +**示例**: + +```python +members = await ctx.platform.get_members("qq:group:123456") +for member in members: + print(f"{member['nickname']} ({member['user_id']})") +``` + +--- + +## FileServiceClient - 文件服务客户端 + +提供文件令牌注册与解析能力,用于跨进程文件传递。 + +### 导入 + +```python +from astrbot_sdk.clients import FileServiceClient, FileRegistration +``` + +### 方法 + +#### `register_file(path, timeout=None)` + +注册文件到文件服务,获取访问令牌。 + +**参数**: +- `path` (`str`): 文件路径 +- `timeout` (`float | None`): 超时时间(秒) + +**返回**: `str` - 文件访问令牌 + +**示例**: + +```python +token = await ctx.files.register_file("/path/to/file.jpg", timeout=3600) +``` + +--- 
+ +#### `handle_file(token)` + +通过令牌解析文件路径。 + +**参数**: +- `token` (`str`): 文件访问令牌 + +**返回**: `str` - 文件路径 + +**示例**: + +```python +path = await ctx.files.handle_file(token) +with open(path, 'rb') as f: + data = f.read() +``` + +--- + +## HTTPClient - HTTP API 客户端 + +提供 Web API 注册能力,允许插件暴露自定义 HTTP 端点。 + +### 导入 + +```python +from astrbot_sdk.clients import HTTPClient +``` + +### 方法 + +#### `register_api(route, handler_capability=None, *, handler=None, methods=None, description="")` + +注册 Web API 端点。 + +**参数**: +- `route` (`str`): API 路由路径。当前实现会拦截包含 `..` 的路径和部分明显非法输入;建议使用以 `/` 开头、没有重复斜杠的规范化路径 +- `handler_capability` (`str | None`): 处理此路由的 capability 名称 +- `handler` (`Any | None`): 使用 `@provide_capability` 标记的方法引用 +- `methods` (`list[str] | None`): HTTP 方法列表 +- `description` (`str`): API 描述 + +**示例**: + +```python +from astrbot_sdk.decorators import provide_capability + +# 1. 声明处理 HTTP 请求的 capability +@provide_capability( + name="my_plugin.http_handler", + description="处理 /my-api 的 HTTP 请求" +) +async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + +# 2. 
注册路由 +await ctx.http.register_api( + route="/my-api", + handler_capability="my_plugin.http_handler", + methods=["GET", "POST"], + description="我的 API" +) + +# 或使用 handler 参数 +await ctx.http.register_api( + route="/my-api", + handler=handle_http_request, + methods=["GET"] +) +``` + +--- + +#### `unregister_api(route, methods=None)` + +注销 Web API 端点。 + +**参数**: +- `route` (`str`): API 路由路径 +- `methods` (`list[str] | None`): HTTP 方法列表,`None` 表示移除当前插件在该 route 下注册的全部方法 + +**示例**: + +```python +await ctx.http.unregister_api("/my-api") +``` + +--- + +#### `list_apis()` + +列出当前插件注册的所有 API。 + +**返回**: `list[dict]` - API 列表 + +**示例**: + +```python +apis = await ctx.http.list_apis() +for api in apis: + print(f"{api['route']}: {api['methods']}") +``` + +--- + +## MetadataClient - 插件元数据客户端 + +提供插件元数据查询能力。 + +### 导入 + +```python +from astrbot_sdk.clients import MetadataClient, PluginMetadata +``` + +### 方法 + +#### `get_plugin(name)` + +获取指定插件的元数据。 + +**参数**: +- `name` (`str`): 插件名称 + +**返回**: `PluginMetadata | None` - 插件元数据 + +**示例**: + +```python +plugin = await ctx.metadata.get_plugin("another_plugin") +if plugin: + print(f"插件: {plugin.display_name}") + print(f"版本: {plugin.version}") +``` + +--- + +#### `list_plugins()` + +获取所有插件的元数据列表。 + +**返回**: `list[PluginMetadata]` + +**示例**: + +```python +plugins = await ctx.metadata.list_plugins() +for plugin in plugins: + print(f"{plugin.display_name} v{plugin.version} - {plugin.author}") +``` + +--- + +#### `get_current_plugin()` + +获取当前插件的元数据。 + +**返回**: `PluginMetadata | None` + +**示例**: + +```python +current = await ctx.metadata.get_current_plugin() +if current: + print(f"当前插件: {current.name} v{current.version}") +``` + +--- + +#### `get_plugin_config(name=None)` + +获取插件配置。 + +**参数**: +- `name` (`str | None`): 插件名称,None 表示当前插件 + +**返回**: `dict | None` - 插件配置字典 + +**注意**: 只能查询当前插件自己的配置 + +**示例**: + +```python +# 获取当前插件配置 +config = await ctx.metadata.get_plugin_config() +if config: + api_key = config.get("api_key") + +# 获取其他插件配置会抛 
PermissionError +await ctx.metadata.get_plugin_config("other_plugin") +``` + +--- + +## ProviderClient - Provider 发现客户端 + +提供 Provider 发现和查询能力。 + +### 导入 + +```python +from astrbot_sdk.clients import ProviderClient +``` + +### 方法 + +#### `list_all()` + +列出所有聊天 Provider。 + +**返回**: `list[ProviderMeta]` + +**示例**: + +```python +providers = await ctx.providers.list_all() +for p in providers: + print(f"{p.id}: {p.model}") +``` + +--- + +#### `list_tts()` + +列出所有 TTS Provider。 + +**返回**: `list[ProviderMeta]` + +--- + +#### `list_stt()` + +列出所有 STT Provider。 + +**返回**: `list[ProviderMeta]` + +--- + +#### `list_embedding()` + +列出所有 Embedding Provider。 + +**返回**: `list[ProviderMeta]` + +--- + +#### `list_rerank()` + +列出所有 Rerank Provider。 + +**返回**: `list[ProviderMeta]` + +--- + +#### `get(provider_id)` + +获取指定 Provider 的代理。 + +**参数**: +- `provider_id` (`str`): Provider ID + +**返回**: `ProviderProxy | None` + +--- + +#### `get_using_chat(umo=None)` + +获取当前使用的聊天 Provider。 + +**参数**: +- `umo` (`str | None`): 统一消息来源标识 + +**返回**: `ProviderMeta | None` + +--- + +#### `get_using_tts(umo=None)` + +获取当前使用的 TTS Provider。 + +--- + +#### `get_using_stt(umo=None)` + +获取当前使用的 STT Provider。 + +--- + +## ProviderManagerClient - Provider 管理客户端 + +提供 Provider 的动态管理能力。 +仅 `reserved/system` 插件可用。普通插件调用这些方法会收到 `provider.manager.* is restricted to reserved/system plugins` 错误;普通插件应优先使用 `ProviderClient` 进行只读查询。 + +### 导入 + +```python +from astrbot_sdk.clients import ProviderManagerClient +``` + +### 方法 + +#### `set_provider(provider_id, provider_type, umo=None)` + +设置当前全局生效的 Provider。 +`umo` 只会出现在变更事件中,不会让 Provider 选择按会话隔离。 + +**参数**: +- `provider_id` (`str`): Provider ID +- `provider_type` (`ProviderType | str`): Provider 类型 +- `umo` (`str | None`): 统一消息来源标识 + +**示例**: + +```python +from astrbot_sdk.llm.entities import ProviderType + +await ctx.provider_manager.set_provider( + "my_provider", + ProviderType.CHAT_COMPLETION, + umo=event.session_id, +) +``` + +--- + +#### 
`get_provider_by_id(provider_id)`
+
+通过 ID 获取 Provider 记录。
+
+---
+
+#### `load_provider(provider_config)`
+
+加载 Provider。
+
+---
+
+#### `create_provider(provider_config)`
+
+创建新 Provider。
+
+`provider_config` 至少应包含 `id`、`type` 和 `provider_type`。例如:
+
+```python
+record = await ctx.provider_manager.create_provider(
+    {
+        "id": "my_provider",
+        "type": "openai",
+        "provider_type": "chat_completion",
+        "model": "gpt-4",
+    }
+)
+```
+
+---
+
+#### `update_provider(origin_provider_id, new_config)`
+
+更新 Provider 配置。
+
+---
+
+#### `delete_provider(provider_id=None, provider_source_id=None)`
+
+删除 Provider。
+
+---
+
+#### `get_insts()`
+
+获取所有已管理的 Provider 实例。
+
+---
+
+#### `watch_changes()`
+
+订阅 Provider 变更事件(流式)。
+
+---
+
+## PersonaManagerClient - 人格管理客户端
+
+提供人格(Persona)的增删改查能力。
+
+### 导入
+
+```python
+from astrbot_sdk.clients import PersonaManagerClient
+```
+
+### 方法
+
+#### `get_persona(persona_id)`
+
+获取指定人格。
+
+当人格不存在时会抛出 `ValueError`,而不是返回 `None`。
+
+---
+
+#### `get_all_personas()`
+
+获取所有人格列表。
+
+---
+
+#### `create_persona(params)`
+
+创建新人格。
+
+---
+
+#### `update_persona(persona_id, params)`
+
+更新人格。
+
+---
+
+#### `delete_persona(persona_id)`
+
+删除人格。
+
+---
+
+## ConversationManagerClient - 对话管理客户端
+
+提供对话的创建、切换、更新、删除和查询能力。
+
+### 导入
+
+```python
+from astrbot_sdk.clients import ConversationManagerClient
+```
+
+### 方法
+
+#### `new_conversation(session, params=None)`
+
+创建新对话。
+
+---
+
+#### `switch_conversation(session, conversation_id)`
+
+切换当前对话。
+
+---
+
+#### `delete_conversation(session, conversation_id=None)`
+
+删除对话。
+
+---
+
+#### `get_conversation(session, conversation_id, create_if_not_exists=False)`
+
+获取对话。
+
+---
+
+#### `get_current_conversation(session, create_if_not_exists=False)`
+
+获取当前 session 正在使用的对话记录。
+
+这个方法适合“跟随 AstrBot 原生当前会话状态”的插件,例如:
+- 给当前会话切换 persona
+- 判断当前主聊天是否已经在某个 persona 下
+- 在 `waiting_llm_request` / `llm_request` hook 中对当前对话做增强
+
+---
+
+#### `get_conversations(session=None, platform_id=None)`
+
+获取对话列表。
+
+---
+
+#### `update_conversation(session, conversation_id=None, params=None)` + +更新对话。 + +--- + +## MessageHistoryManagerClient - 消息历史管理客户端 + +按 `MessageSession` 精确保存原始消息组件、发送者和元数据。适合审计、回溯、分页读取和按时间清理。 +如果要做语义召回或向量检索,请继续使用 `MemoryClient`。 + +### 导入 + +```python +from astrbot_sdk.clients import ( + MessageHistoryManagerClient, + MessageHistoryPage, + MessageHistoryRecord, + MessageHistorySender, +) +from astrbot_sdk.message.session import MessageSession +from astrbot_sdk.message.components import Plain +``` + +### 方法 + +#### `list(session, *, cursor=None, limit=50)` + +分页列出某个会话的消息历史。 + +**参数**: +- `session` (`MessageSession`): 目标会话,必须是 `MessageSession` +- `cursor` (`str | None`): 分页游标,建议直接使用上一页返回的 `next_cursor` +- `limit` (`int`): 返回条数,默认 `50` + +**返回**: `MessageHistoryPage` - 包含 `records`、`next_cursor`、`total` + +**示例**: + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +page = await ctx.message_history.list(session, limit=20) +for record in page.records: + print(record.id, record.sender.sender_name, record.parts) +``` + +--- + +#### `get(session, record_id)` / `get_by_id(session, record_id)` + +按记录 ID 读取单条消息历史。 + +**参数**: +- `session` (`MessageSession`): 目标会话 +- `record_id` (`int`): 记录 ID + +**返回**: `MessageHistoryRecord | None` + +**示例**: + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.get(session, 1) +same_record = await ctx.message_history.get_by_id(session, 1) +``` + +--- + +#### `append(session, *, parts, sender, metadata=None, idempotency_key=None)` + +追加一条消息历史记录。 + +**参数**: +- `session` (`MessageSession`): 目标会话 +- `parts` (`list[BaseMessageComponent]`): 原始消息组件列表 +- `sender` (`MessageHistorySender`): 发送者信息,也可传可验证为该模型的 `dict` +- `metadata` (`dict[str, Any] | None`): 附加元数据 +- `idempotency_key` (`str | None`): 幂等键;相同 key 会返回现有记录而不是重复写入 + +**返回**: 
`MessageHistoryRecord` + +**示例**: + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.append( + session, + parts=[Plain(event.message_content, convert=False)], + sender=MessageHistorySender( + sender_id=event.sender_id, + sender_name=event.sender_name, + ), + metadata={"source": "handler"}, + idempotency_key="incoming:demo-user:hello", +) +print(record.created_at, record.idempotency_key) +``` + +--- + +#### `delete_before(session, *, before)` / `delete_after(session, *, after)` + +按时间边界删除某个会话内的消息历史。 + +**参数**: +- `session` (`MessageSession`): 目标会话 +- `before` / `after` (`datetime`): 时间边界,建议使用带时区的 `datetime` + +**返回**: `int` - 删除的记录数 + +**示例**: + +```python +from datetime import datetime, timezone + +deleted = await ctx.message_history.delete_before( + session, + before=datetime(2026, 3, 22, tzinfo=timezone.utc), +) +``` + +--- + +#### `delete_all(session)` + +删除某个会话的全部消息历史。 + +**参数**: +- `session` (`MessageSession`): 目标会话 + +**返回**: `int` - 删除的记录数 + +**示例**: + +```python +deleted = await ctx.message_history.delete_all(session) +print(f"deleted={deleted}") +``` + +--- + +## KnowledgeBaseManagerClient - 知识库管理客户端 + +提供知识库的创建、查询和删除能力。 + +### 导入 + +```python +from astrbot_sdk.clients import KnowledgeBaseManagerClient +``` + +### 方法 + +#### `get_kb(kb_id)` + +获取知识库。 + +参数 `kb_id` 是知识库的唯一 ID,不是 `kb_name`。 + +--- + +#### `create_kb(params)` + +创建新知识库。 +返回的 `KnowledgeBaseRecord` 中包含运行时生成的 `kb_id`,后续更新、删除和文档操作都应使用这个 `kb_id`。 + +--- + +#### `delete_kb(kb_id)` + +删除知识库。 + +--- + +## RegistryClient - Handler 注册表客户端 + +handler 注册表查询与白名单管理客户端,用于查询 handler 信息并管理 handler 白名单。 + +### 导入 + +```python +from astrbot_sdk.clients import RegistryClient, HandlerMetadata +``` + +### 方法 + +#### `get_handlers_by_event_type(event_type)` + +获取指定事件类型的所有 handler。 + +**参数**: +- `event_type` (`str`): 事件类型 + +**返回**: `list[HandlerMetadata]` + +**示例**: + +```python +handlers = 
await ctx.registry.get_handlers_by_event_type("message") +for h in handlers: + print(f"{h.handler_full_name}: {h.description}") +``` + +--- + +#### `get_handler_by_full_name(full_name)` + +通过完整名称获取 handler 元数据。 + +**参数**: +- `full_name` (`str`): handler 完整名称(格式:`plugin_name.handler_name`) + +**返回**: `HandlerMetadata | None` + +**示例**: + +```python +handler = await ctx.registry.get_handler_by_full_name("my_plugin.on_message") +if handler: + print(f"触发类型: {handler.trigger_type}") + print(f"优先级: {handler.priority}") + print(f"需要管理员: {handler.require_admin}") +``` + +--- + +#### `set_handler_whitelist(plugin_names)` + +设置 handler 白名单。 + +**参数**: +- `plugin_names` (`list[str] | set[str] | None`): 插件名称列表,None 表示清除白名单 + +**返回**: `list[str] | None` - 实际设置的白名单 + +**示例**: + +```python +# 设置白名单 +await ctx.registry.set_handler_whitelist(["plugin_a", "plugin_b"]) + +# 清空白名单 +await ctx.registry.set_handler_whitelist(None) +``` + +--- + +#### `get_handler_whitelist()` + +获取当前 handler 白名单。 + +**返回**: `list[str] | None` + +**示例**: + +```python +whitelist = await ctx.registry.get_handler_whitelist() +if whitelist: + print(f"当前白名单: {whitelist}") +``` + +--- + +#### `clear_handler_whitelist()` + +清除 handler 白名单。 + +**示例**: + +```python +await ctx.registry.clear_handler_whitelist() +``` + +--- + +## SkillClient - 技能注册客户端 + +技能注册客户端,用于注册和管理技能。 + +### 导入 + +```python +from astrbot_sdk.clients import SkillClient, SkillRegistration +``` + +### 方法 + +#### `register(*, name, path, description="")` + +注册一个技能。 + +**参数**: +- `name` (`str`): 技能名称 +- `path` (`str`): 技能路径 +- `description` (`str`): 技能描述 + +**返回**: `SkillRegistration` + +**示例**: + +```python +skill = await ctx.skills.register( + name="my_skill", + path="/path/to/skill", + description="我的技能描述" +) +print(f"技能已注册: {skill.name}") +``` + +--- + +#### `unregister(name)` + +注销技能。 + +**参数**: +- `name` (`str`): 技能名称 + +**返回**: `bool` - 是否成功注销 + +**示例**: + +```python +removed = await ctx.skills.unregister("my_skill") +if removed: + 
print("技能已注销") +``` + +--- + +#### `list()` + +列出当前已注册的技能。 + +**返回**: `list[SkillRegistration]` + +**示例**: + +```python +skills = await ctx.skills.list() +for skill in skills: + print(f"{skill.name}: {skill.skill_dir}") +``` + +--- + +## SessionPluginManager - 会话插件管理器 + +会话级别的插件状态管理器,用于检查和过滤会话相关的插件状态。 + +### 导入 + +```python +from astrbot_sdk.clients import SessionPluginManager +``` + +### 方法 + +#### `is_plugin_enabled_for_session(session, plugin_name)` + +检查插件在指定会话是否启用。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 +- `plugin_name` (`str`): 插件名称 + +**返回**: `bool` + +**示例**: + +```python +enabled = await ctx.session_plugins.is_plugin_enabled_for_session( + session=event, + plugin_name="my_plugin" +) +if not enabled: + await event.reply("该插件在此会话已禁用") +``` + +--- + +#### `filter_handlers_by_session(session, handlers)` + +根据会话过滤 handler 列表。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 +- `handlers` (`list[HandlerMetadata]`): handler 列表 + +**返回**: `list[HandlerMetadata]` - 过滤后的 handler 列表 + +**示例**: + +```python +handlers = await ctx.registry.get_handlers_by_event_type("message") +filtered = await ctx.session_plugins.filter_handlers_by_session( + session=event, + handlers=handlers +) +print(f"可用 handler 数量: {len(filtered)}") +``` + +--- + +## SessionServiceManager - 会话服务管理器 + +会话级别的 LLM/TTS 服务状态管理器。 + +### 导入 + +```python +from astrbot_sdk.clients import SessionServiceManager +``` + +### 方法 + +#### `is_llm_enabled_for_session(session)` + +检查 LLM 服务在指定会话是否启用。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 + +**返回**: `bool` + +**示例**: + +```python +if await ctx.session_services.is_llm_enabled_for_session(event): + reply = await ctx.llm.chat(prompt) +``` + +--- + +#### `set_llm_status_for_session(session, enabled)` + +设置会话的 LLM 服务状态。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 +- `enabled` (`bool`): 是否启用 + +**示例**: + +```python +await 
ctx.session_services.set_llm_status_for_session(event, enabled=False) +await event.reply("LLM 已在此会话禁用") +``` + +--- + +#### `should_process_llm_request(event_or_session)` + +检查是否应处理 LLM 请求(等同于 `is_llm_enabled_for_session`)。 + +**参数**: +- `event_or_session` (`str | MessageSession | MessageEvent`): 会话标识 + +**返回**: `bool` + +**示例**: + +```python +if await ctx.session_services.should_process_llm_request(event): + reply = await ctx.llm.chat(prompt) +``` + +--- + +#### `is_tts_enabled_for_session(session)` + +检查 TTS 服务在指定会话是否启用。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 + +**返回**: `bool` + +--- + +#### `set_tts_status_for_session(session, enabled)` + +设置会话的 TTS 服务状态。 + +**参数**: +- `session` (`str | MessageSession | MessageEvent`): 会话标识 +- `enabled` (`bool`): 是否启用 + +**示例**: + +```python +await ctx.session_services.set_tts_status_for_session(event, enabled=True) +await event.reply("TTS 已在此会话启用") +``` + +--- + +#### `should_process_tts_request(event_or_session)` + +检查是否应处理 TTS 请求(等同于 `is_tts_enabled_for_session`)。 + +**参数**: +- `event_or_session` (`str | MessageSession | MessageEvent`): 会话标识 + +**返回**: `bool` + +--- + +## 使用示例 + +### 基本对话流程 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + reply = await ctx.llm.chat(event.message_content) + await ctx.platform.send(event.session_id, reply) +``` + +### 带历史的对话 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 从 memory 获取历史 + history_data = await ctx.memory.get(f"history:{event.session_id}") + history = history_data.get("messages", []) if history_data else [] + + # 对话 + reply = await ctx.llm.chat(event.message_content, history=history) + + # 保存新消息到历史 + history.append(ChatMessage(role="user", content=event.message_content)) + history.append(ChatMessage(role="assistant", content=reply)) + await ctx.memory.save(f"history:{event.session_id}", {"messages": history}) + + await ctx.platform.send(event.session_id, reply) 
+``` + +如果你需要保留原始消息组件、发送者和按时间清理能力,应优先使用 `ctx.message_history`。 + +### 使用数据库持久化 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 获取用户配置 + config = await ctx.db.get(f"user_config:{event.sender_id}") + + if not config: + config = {"theme": "light", "lang": "zh"} + await ctx.db.set(f"user_config:{event.sender_id}", config) + + # 使用配置 + reply = f"你的主题设置是: {config['theme']}" + await ctx.platform.send(event.session_id, reply) +``` + +--- + +## 注意事项 + +1. 所有客户端方法都是异步的,需要使用 `await` +2. 远程调用可能失败,建议使用 try-except 处理 +3. Memory 适合语义搜索,DB 适合结构化 KV,MessageHistory 适合精确保存原始消息记录 +4. DB key 在运行时按插件隔离;`list()` 和 `watch()` 返回插件本地 key 视图 +5. `HTTPClient.register_api()` 当前会拦截 `..` 等明显非法路径,但仍建议插件自行使用规范化 route;`unregister_api(route)` 默认移除该 route 下全部方法 +6. 文件操作使用 file service 注册令牌 +7. 平台标识使用 UMO 格式:`"platform:instance:session_id"` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.clients` +**最后更新**: 2026-03-22 diff --git a/astrbot-sdk/docs/api/context.md b/astrbot-sdk/docs/api/context.md new file mode 100644 index 0000000000..8f9242e4d0 --- /dev/null +++ b/astrbot-sdk/docs/api/context.md @@ -0,0 +1,1660 @@ +# Context 类 - 插件运行时上下文完整参考 + +## 概述 + +`Context` 是插件运行时的核心上下文对象,每个 handler 调用都会创建一个新的 Context 实例。Context 组合了所有 capability 客户端,提供统一的访问接口。 + +**模块路径**: `astrbot_sdk.context.Context` + +--- + +## 类定义 + +```python +@dataclass(slots=True) +class Context: + # 基本属性 + peer: Any # 协议对等端 + plugin_id: str # 插件 ID + logger: PluginLogger # 日志器 + cancel_token: CancelToken # 取消令牌 + + # 能力客户端 + llm: LLMClient # LLM 客户端 + memory: MemoryClient # 记忆客户端 + db: DBClient # 数据库客户端 + files: FileServiceClient # 文件服务客户端 + platform: PlatformClient # 平台客户端 + providers: ProviderClient # Provider 客户端 + provider_manager: ProviderManagerClient # Provider 管理客户端 + personas: PersonaManagerClient # 人格管理客户端 + conversations: ConversationManagerClient # 对话管理客户端 + kbs: KnowledgeBaseManagerClient # 知识库管理客户端 + message_history: MessageHistoryManagerClient # 消息历史管理客户端 + http: HTTPClient # 
HTTP 客户端 + metadata: MetadataClient # 元数据客户端 + registry: RegistryClient # handler 注册表客户端 + skills: SkillClient # 技能注册客户端 + session_plugins: SessionPluginManager # 会话插件管理器 + session_services: SessionServiceManager # 会话服务管理器 + + # 系统工具 + _llm_tool_manager: LLMToolManager + _source_event_payload: dict[str, Any] + + # 别名 + persona_manager = personas + conversation_manager = conversations + kb_manager = kbs + message_history_manager = message_history +``` + +--- + +## 导入方式 + +```python +# 从主模块导入(推荐) +from astrbot_sdk import Context + +# 从子模块导入 +from astrbot_sdk.context import Context + +# 常用配套导入 +from astrbot_sdk import MessageEvent # 消息事件 +from astrbot_sdk.decorators import on_command, on_message # 装饰器 +from astrbot_sdk.clients.llm import ChatMessage # 聊天消息(用于历史记录) +``` + +--- + +## 基本属性 + +### `peer` + +协议对等端,用于底层通信。 + +```python +# 类型: Any +# 说明: 内部使用,用于与 Core 通信 +``` + +### `plugin_id` + +当前插件的唯一标识符。 + +```python +# 类型: str +# 说明: 插件的名称,对应 plugin.yaml 中的 name 字段 + +ctx.logger.info(f"当前插件: {ctx.plugin_id}") +``` + +### `logger` + +绑定了插件 ID 的日志器。 + +```python +# 类型: PluginLogger +# 说明: 自动添加 plugin_id 上下文 + +# 不同级别的日志 +ctx.logger.debug("调试信息") +ctx.logger.info("普通信息") +ctx.logger.warning("警告信息") +ctx.logger.error("错误信息") + +# 绑定额外上下文 +logger = ctx.logger.bind(user_id="12345") +logger.info("用户操作") + +# 流式日志监听 +async for entry in ctx.logger.watch(): + print(f"[{entry.level}] {entry.message}") +``` + +### `cancel_token` + +取消令牌,用于长时间运行的任务中检查是否需要取消。 + +```python +# 类型: CancelToken + +# 检查是否取消 +ctx.cancel_token.raise_if_cancelled() + +# 触发取消 +ctx.cancel_token.cancel() + +# 等待取消信号 +await ctx.cancel_token.wait() + +# 检查状态 +if ctx.cancel_token.cancelled: + print("操作已取消") +``` + +**使用场景**: + +```python +async def long_operation(ctx: Context): + for item in large_list: + # 检查是否取消 + ctx.cancel_token.raise_if_cancelled() + + await process(item) +``` + +--- + +## 能力客户端 + +### 1. 
LLM 客户端 (ctx.llm) + +提供 AI 对话能力。 + +```python +# 类型: LLMClient +``` + +#### 方法 + +##### `chat()` + +简单对话。 + +```python +reply = await ctx.llm.chat("你好,介绍一下自己") + +# 带系统提示 +reply = await ctx.llm.chat( + "翻译成英文", + system="你是一个专业翻译助手" +) + +# 带对话历史 +from astrbot_sdk.clients.llm import ChatMessage + +history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), +] +reply = await ctx.llm.chat("你记得我吗?", history=history) +``` + +##### `chat_raw()` + +获取完整响应对象。 + +```python +response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) +print(f"生成文本: {response.text}") +print(f"Token 使用: {response.usage}") +print(f"结束原因: {response.finish_reason}") +``` + +##### `stream_chat()` + +流式对话。 + +```python +async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="", flush=True) +``` + +--- + +### 2. Memory 客户端 (ctx.memory) + +提供语义搜索的记忆存储能力。 + +```python +# 类型: MemoryClient +``` + +#### 方法 + +##### `search()` + +搜索记忆。默认在有 embedding provider 时执行 hybrid 检索。 + +```python +results = await ctx.memory.search( + "用户喜欢什么颜色", + mode="hybrid", + limit=5, +) +for item in results: + print(item["key"], item["score"], item["match_type"]) +``` + +##### `save()` + +保存记忆。 + +```python +# 保存用户偏好 +await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"}) + +# 使用关键字参数 +await ctx.memory.save("note", None, content="重要笔记", tags=["work"]) + +# 显式指定检索文本 +await ctx.memory.save( + "profile:alice", + { + "name": "Alice", + "embedding_text": "Alice 喜欢蓝色和海边", + }, +) +``` + +##### `get()` + +获取记忆。 + +```python +pref = await ctx.memory.get("user_pref") +if pref: + print(f"用户偏好主题: {pref.get('theme')}") +``` + +##### `save_with_ttl()` + +保存带过期时间的记忆。 + +```python +# 保存临时会话状态,1小时后过期 +await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting"}, + ttl_seconds=3600 +) +``` + +##### `delete()` + +删除记忆。 + +```python +await ctx.memory.delete("old_note") +``` + +##### `stats()` + +查看记忆索引状态。 + +```python +stats = await ctx.memory.stats() 
+print(stats["total_items"], stats.get("embedded_items"), stats.get("dirty_items")) +``` + +--- + +### 3. DB 客户端 (ctx.db) + +提供键值存储能力,数据永久保存。运行时会自动把 key 限定在当前插件命名空间中; +`list()` 与 `watch()` 返回的仍是插件视角的原始 key。 + +```python +# 类型: DBClient +``` + +#### 方法 + +##### `get() / set()` + +基本读写。 + +```python +# 读取 +data = await ctx.db.get("user_settings") +if data: + print(data["theme"]) + +# 写入 +await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"}) +await ctx.db.set("greeted", True) +``` + +##### `delete()` + +删除数据。 + +```python +await ctx.db.delete("user_settings") +``` + +##### `list()` + +列出键。 + +```python +keys = await ctx.db.list("user_") +# ["user_settings", "user_profile", "user_history"] +``` + +##### `get_many() / set_many()` + +批量操作。 + +```python +# 批量读取 +values = await ctx.db.get_many(["user:1", "user:2"]) + +# 批量写入 +await ctx.db.set_many({ + "user:1": {"name": "Alice"}, + "user:2": {"name": "Bob"} +}) +``` + +##### `watch()` + +监听变更事件。 + +```python +async for event in ctx.db.watch("user:"): + print(f"{event['op']}: {event['key']}") +``` + +--- + +### 4. Files 客户端 (ctx.files) + +提供文件令牌注册与解析能力。 + +```python +# 类型: FileServiceClient +``` + +#### 方法 + +##### `register_file()` + +注册文件并获取令牌。 + +```python +token = await ctx.files.register_file("/path/to/file.jpg", timeout=3600) +``` + +##### `handle_file()` + +通过令牌解析文件路径。 + +```python +path = await ctx.files.handle_file(token) +``` + +--- + +### 5. 
Platform 客户端 (ctx.platform) + +提供向聊天平台发送消息和获取信息的能力。 + +```python +# 类型: PlatformClient +``` + +#### 方法 + +##### `send()` + +发送文本消息。 + +```python +await ctx.platform.send("qq:group:123456", "大家好!") + +# 使用 MessageSession +from astrbot_sdk.message_session import MessageSession + +session = MessageSession( + platform_id="qq", + message_type="group", + session_id="123456" +) +await ctx.platform.send(session, "你好!") +``` + +##### `send_image()` + +发送图片。 + +```python +await ctx.platform.send_image( + event.session_id, + "https://example.com/image.png" +) +``` + +##### `send_chain()` + +发送消息链。 + +```python +from astrbot_sdk.message_components import Plain, Image + +chain = [Plain("文字"), Image(url="https://example.com/img.jpg")] +await ctx.platform.send_chain(event.session_id, chain) +``` + +##### `send_by_id()` + +通过 ID 发送。 + +```python +await ctx.platform.send_by_id( + platform_id="qq", + session_id="user123", + content="Hello", + message_type="private" +) +``` + +##### `get_members()` + +获取群组成员。 + +```python +members = await ctx.platform.get_members("qq:group:123456") +for member in members: + print(f"{member['nickname']} ({member['user_id']})") +``` + +--- + +### 6. Providers 客户端 (ctx.providers) + +提供 Provider 发现和查询能力。 + +```python +# 类型: ProviderClient +``` + +#### 方法 + +##### `list_all()` + +列出所有 Provider。 + +```python +providers = await ctx.providers.list_all() +for p in providers: + print(f"{p.id}: {p.model}") +``` + +##### `get_using_chat()` + +获取当前使用的聊天 Provider。 + +```python +provider = await ctx.providers.get_using_chat() +if provider: + print(f"当前使用: {provider.id}") +``` + +##### `list_tts() / list_stt() / list_embedding() / list_rerank()` + +列出特定类型的 Provider。 + +```python +tts_providers = await ctx.providers.list_tts() +stt_providers = await ctx.providers.list_stt() +``` + +--- + +### 7. 
Provider Manager 客户端 (ctx.provider_manager) + +提供 Provider 的动态管理能力。 +仅 `reserved/system` 插件可用。普通插件调用会收到 `provider.manager.* is restricted to reserved/system plugins` 错误;普通插件通常应使用 `ctx.providers` 做只读查询。 + +```python +# 类型: ProviderManagerClient +``` + +#### 方法 + +##### `set_provider()` + +设置当前全局生效的 Provider。 +`umo` 只会作为变更事件里的来源标识,不会把 Provider 选择限定到单个会话。 + +```python +from astrbot_sdk.llm.entities import ProviderType + +await ctx.provider_manager.set_provider( + "my_provider", + ProviderType.CHAT_COMPLETION, + umo=event.session_id +) +``` + +##### `get_provider_by_id()` + +获取 Provider 记录。 + +```python +record = await ctx.provider_manager.get_provider_by_id("my_provider") +``` + +##### `create_provider() / update_provider() / delete_provider()` + +管理 Provider。 + +```python +# 创建 +record = await ctx.provider_manager.create_provider({ + "id": "my_provider", + "type": "openai", + "provider_type": "chat_completion", + "model": "gpt-4" +}) + +# 更新 +record = await ctx.provider_manager.update_provider( + "my_provider", + {"model": "gpt-4-turbo"} +) + +# 删除 +await ctx.provider_manager.delete_provider(provider_id="my_provider") +``` + +##### `watch_changes()` + +监听 Provider 变更事件。 + +```python +async for event in ctx.provider_manager.watch_changes(): + print(f"Provider {event.provider_id} 变更") +``` + +--- + +### 8. 
Personas 客户端 (ctx.personas / ctx.persona_manager) + +提供人格管理能力。 + +```python +# 类型: PersonaManagerClient +``` + +#### 方法 + +##### `get_persona() / get_all_personas()` + +获取人格。 + +```python +# 获取单个人格 +persona = await ctx.personas.get_persona("assistant") + +# 获取所有人格 +personas = await ctx.personas.get_all_personas() +``` + +##### `create_persona() / update_persona() / delete_persona()` + +管理人格。 + +```python +from astrbot_sdk.clients import PersonaCreateParams + +# 创建 +persona = await ctx.personas.create_persona(PersonaCreateParams( + persona_id="assistant", + system_prompt="你是一个有用的助手。", + begin_dialogs=["你好,有什么可以帮助你的?"] +)) + +# 更新 +updated = await ctx.personas.update_persona( + "assistant", + PersonaUpdateParams(system_prompt="你是一个专业的编程助手。") +) + +# 删除 +await ctx.personas.delete_persona("old_persona") +``` + +--- + +### 9. Conversations 客户端 (ctx.conversations / ctx.conversation_manager) + +提供对话管理能力。 + +```python +# 类型: ConversationManagerClient +``` + +#### 方法 + +##### `new_conversation()` + +创建新对话。 + +```python +from astrbot_sdk.clients import ConversationCreateParams + +conv_id = await ctx.conversations.new_conversation( + event.session_id, + ConversationCreateParams( + title="新对话", + persona_id="assistant" + ) +) +``` + +##### `switch_conversation()` + +切换当前对话。 + +```python +await ctx.conversations.switch_conversation( + event.session_id, + "conv_123" +) +``` + +##### `delete_conversation()` + +删除对话。 + +```python +# 删除指定对话 +await ctx.conversations.delete_conversation( + event.session_id, + "conv_123" +) + +# 删除当前对话 +await ctx.conversations.delete_conversation(event.session_id) +``` + +##### `get_conversation() / get_current_conversation() / get_conversations()` + +获取对话。 + +```python +# 获取单个对话 +conv = await ctx.conversations.get_conversation( + event.session_id, + "conv_123", + create_if_not_exists=True +) + +# 获取当前选中的对话 +current = await ctx.conversations.get_current_conversation( + event.session_id, + create_if_not_exists=True, +) + +# 获取对话列表 +convs = await 
ctx.conversations.get_conversations(event.session_id) +``` + +##### `update_conversation()` + +更新对话。 + +```python +from astrbot_sdk.clients import ConversationUpdateParams + +await ctx.conversations.update_conversation( + event.session_id, + "conv_123", + ConversationUpdateParams(title="新标题") +) +``` + +--- + +### 10. Knowledge Bases 客户端 (ctx.kbs / ctx.kb_manager) + +提供知识库管理能力。 + +```python +# 类型: KnowledgeBaseManagerClient +``` + +#### 方法 + +##### `get_kb()` + +获取知识库。 + +```python +kb = await ctx.kbs.get_kb("kb_123") +if kb: + print(f"知识库: {kb.kb_name}") + print(f"文档数: {kb.doc_count}") +``` + +##### `create_kb()` + +创建知识库。 + +```python +from astrbot_sdk.clients import KnowledgeBaseCreateParams + +kb = await ctx.kbs.create_kb(KnowledgeBaseCreateParams( + kb_name="技术文档", + embedding_provider_id="openai_embedding", + description="存储技术文档", + emoji="📚" +)) +``` + +##### `delete_kb()` + +删除知识库。 + +```python +deleted = await ctx.kbs.delete_kb("kb_123") +if deleted: + print("知识库已删除") +``` + +--- + +### 10.5 Message History 客户端 (ctx.message_history / ctx.message_history_manager) + +提供精确消息历史存储能力,按 `MessageSession` 保存原始消息组件、发送者和元数据。 + +```python +# 类型: MessageHistoryManagerClient +``` + +#### 方法 + +##### `append()` + +```python +from astrbot_sdk import MessageHistorySender, MessageSession, Plain + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.append( + session, + parts=[Plain(event.message_content, convert=False)], + sender=MessageHistorySender( + sender_id=event.sender_id, + sender_name=event.sender_name, + ), + metadata={"source": "message_handler"}, +) +``` + +##### `list()` + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +page = await ctx.message_history.list(session, limit=20) +for record in page.records: + print(record.id, record.sender.sender_name) +``` + 
+分页时建议直接复用上一页返回的 `next_cursor`,不要自行构造游标值。 + +##### `get() / get_by_id()` + +```python +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +record = await ctx.message_history.get(session, 1) +same_record = await ctx.message_history.get_by_id(session, 1) +``` + +##### `delete_before() / delete_after() / delete_all()` + +```python +from datetime import datetime, timezone + +session = MessageSession( + platform_id=event.platform_id, + message_type=event.message_type, + session_id=event.session_id, +) +await ctx.message_history.delete_before( + session, + before=datetime(2026, 3, 22, tzinfo=timezone.utc), +) +await ctx.message_history.delete_all(session) +``` + +当前实现要求传入带时区的 `datetime`,例如 `timezone.utc`。 + +--- + +### 11. HTTP 客户端 (ctx.http) + +提供 Web API 注册能力。 + +```python +# 类型: HTTPClient +``` + +当前实现会拦截包含 `..` 的路径和部分明显非法输入,但路由校验并非完全严格。 +文档示例建议统一使用以 `/` 开头、没有重复斜杠的规范化路径。`unregister_api(route)` 在不传 +`methods` 时会移除当前插件在该 route 下注册的全部方法。 + +#### 方法 + +##### `register_api()` + +注册 API 端点。 + +```python +from astrbot_sdk.decorators import provide_capability + +@provide_capability( + name="my_plugin.http_handler", + description="处理 HTTP 请求" +) +async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + +await ctx.http.register_api( + route="/my-api", + handler=handle_http_request, + methods=["GET", "POST"], + description="我的 API" +) +``` + +##### `unregister_api()` + +注销 API。 + +```python +await ctx.http.unregister_api("/my-api") +``` + +##### `list_apis()` + +列出当前插件注册的所有 API。 + +```python +apis = await ctx.http.list_apis() +for api in apis: + print(f"{api['route']}: {api['methods']}") +``` + +--- + +### 12. 
Metadata 客户端 (ctx.metadata) + +提供插件元数据查询能力。 + +```python +# 类型: MetadataClient +``` + +#### 方法 + +##### `get_plugin()` + +获取指定插件信息。 + +```python +plugin = await ctx.metadata.get_plugin("another_plugin") +if plugin: + print(f"插件: {plugin.display_name}") + print(f"版本: {plugin.version}") + print(f"作者: {plugin.author}") +``` + +##### `list_plugins()` + +列出所有插件。 + +```python +plugins = await ctx.metadata.list_plugins() +for plugin in plugins: + print(f"{plugin.display_name} v{plugin.version} - {plugin.author}") +``` + +##### `get_current_plugin()` + +获取当前插件信息。 + +```python +current = await ctx.metadata.get_current_plugin() +if current: + print(f"当前插件: {current.name} v{current.version}") +``` + +##### `get_plugin_config()` + +获取插件配置。 + +```python +config = await ctx.metadata.get_plugin_config() +if config: + api_key = config.get("api_key") +``` + +**注意**: 只能查询当前插件自己的配置 + +--- + +### 13. Registry 客户端 (ctx.registry) + +提供 handler 注册表查询与白名单管理能力。 + +```python +# 类型: RegistryClient +``` + +#### 方法 + +##### `get_handlers_by_event_type()` + +获取指定事件类型下的全部 handler 元数据。 + +```python +handlers = await ctx.registry.get_handlers_by_event_type("message") +for handler in handlers: + print(handler.handler_full_name, handler.priority) +``` + +##### `get_handler_by_full_name()` + +按完整名称查询单个 handler。 + +```python +handler = await ctx.registry.get_handler_by_full_name("my_plugin.on_message") +if handler: + print(handler.trigger_type, handler.require_admin) +``` + +##### `set_handler_whitelist() / get_handler_whitelist() / clear_handler_whitelist()` + +管理 handler 白名单。 + +```python +await ctx.registry.set_handler_whitelist(["plugin_a", "plugin_b"]) +whitelist = await ctx.registry.get_handler_whitelist() +await ctx.registry.clear_handler_whitelist() +``` + +--- + +### 14. 
Skills 客户端 (ctx.skills) + +提供运行时技能注册与查询能力。 + +```python +# 类型: SkillClient +``` + +#### 方法 + +##### `register()` + +注册一个技能目录。 + +```python +skill = await ctx.skills.register( + name="my_skill", + path="/path/to/skill", + description="我的技能描述", +) +print(skill.skill_dir) +``` + +##### `unregister()` + +注销技能。 + +```python +removed = await ctx.skills.unregister("my_skill") +print(removed) +``` + +##### `list()` + +列出当前已注册的技能。 + +```python +skills = await ctx.skills.list() +for skill in skills: + print(skill.name, skill.path) +``` + +--- + +### 15. Session Plugins 客户端 (ctx.session_plugins) + +提供会话级别的插件状态管理能力。 + +```python +# 类型: SessionPluginManager +``` + +#### 方法 + +##### `is_plugin_enabled_for_session()` + +检查插件是否对指定会话启用。 + +```python +enabled = await ctx.session_plugins.is_plugin_enabled_for_session( + event, # 可以是 event、session 字符串或 MessageSession + "my_plugin", +) +``` + +##### `filter_handlers_by_session()` + +过滤会话启用的处理器。 + +```python +from astrbot_sdk.clients import HandlerMetadata + +enabled_handlers = await ctx.session_plugins.filter_handlers_by_session( + event, + all_handlers, +) +``` + +--- + +### 16. 
Session Services 客户端 (ctx.session_services) + +提供会话级别的 LLM/TTS 服务状态管理能力。 + +```python +# 类型: SessionServiceManager +``` + +#### 方法 + +##### `is_llm_enabled_for_session()` + +检查 LLM 是否对指定会话启用。 + +```python +enabled = await ctx.session_services.is_llm_enabled_for_session(event) +if not enabled: + await event.reply("LLM 服务已禁用") +``` + +##### `set_llm_status_for_session()` + +设置 LLM 服务状态。 + +```python +# 启用 LLM +await ctx.session_services.set_llm_status_for_session(event, True) + +# 禁用 LLM +await ctx.session_services.set_llm_status_for_session(event, False) +``` + +##### `should_process_llm_request()` + +判断是否应该处理 LLM 请求。 + +```python +if await ctx.session_services.should_process_llm_request(event): + response = await ctx.llm.chat("...") +``` + +##### `is_tts_enabled_for_session()` + +检查 TTS 是否对指定会话启用。 + +```python +enabled = await ctx.session_services.is_tts_enabled_for_session(event) +if enabled: + await event.reply("TTS 服务可用") +``` + +##### `set_tts_status_for_session()` + +设置 TTS 服务状态。 + +```python +# 启用 TTS +await ctx.session_services.set_tts_status_for_session(event, True) + +# 禁用 TTS +await ctx.session_services.set_tts_status_for_session(event, False) +``` + +##### `should_process_tts_request()` + +判断是否应该处理 TTS 请求。 + +```python +if await ctx.session_services.should_process_tts_request(event): + await handle_tts(text) +``` + +--- + +## 系统工具方法 + +### `get_data_dir()` + +获取插件数据目录路径。 + +```python +data_dir = await ctx.get_data_dir() +print(f"数据目录: {data_dir}") +``` + +**返回**: `Path` - 数据目录的 Path 对象 + +--- + +### `text_to_image()` + +将文本渲染为图片。 + +```python +url = await ctx.text_to_image("Hello World", return_url=True) +``` + +**参数**: +- `text`: 要渲染的文本 +- `return_url`: 是否返回 URL(False 则返回本地路径) + +**返回**: `str` - 图片 URL 或路径 + +--- + +### `html_render()` + +渲染 HTML 模板。 + +```python +url = await ctx.html_render( + tmpl="

<div><h1>{{ title }}</h1></div>

", + data={"title": "标题"} +) +``` + +**参数**: +- `tmpl`: HTML 模板内容 +- `data`: 模板数据 +- `return_url`: 是否返回 URL +- `options`: 渲染选项 + +**返回**: `str` - 渲染结果 URL 或路径 + +--- + +### `send_message()` + +向会话发送消息。 + +```python +await ctx.send_message(event.session_id, "消息内容") +``` + +**参数**: +- `session`: 会话标识 +- `content`: 消息内容(支持多种格式) + +--- + +### `send_message_by_id()` + +通过 ID 向平台发送消息。 + +```python +await ctx.send_message_by_id( + type="private", + id="user123", + content="Hello", + platform="qq" +) +``` + +--- + +### `register_task()` + +注册后台任务。 + +```python +async def background_work(): + while True: + await asyncio.sleep(60) + ctx.logger.info("每分钟执行一次") + +task = await ctx.register_task(background_work(), "定时任务") +``` + +**参数**: +- `task`: 可等待对象 +- `desc`: 任务描述 + +**返回**: `asyncio.Task` - 任务对象 + +**注意**: 任务失败会自动记录日志,不会影响插件主流程 + +--- + +## LLM Tool 管理方法 + +### `get_llm_tool_manager()` + +获取 LLM Tool 管理器。 + +```python +manager = ctx.get_llm_tool_manager() +``` + +--- + +### `add_llm_tools()` + +添加 LLM 工具规范。 + +```python +from astrbot_sdk.llm.tools import LLMToolSpec + +tool_spec = LLMToolSpec( + name="my_tool", + description="我的工具", + parameters_schema={...} +) + +await ctx.add_llm_tools(tool_spec) +``` + +--- + +### `activate_llm_tool() / deactivate_llm_tool()` + +激活/停用 LLM 工具。 + +```python +await ctx.activate_llm_tool("my_tool") +await ctx.deactivate_llm_tool("my_tool") +``` + +--- + +### `register_llm_tool()` + +注册可执行的 LLM 工具。 + +```python +async def search_weather(location: str) -> str: + return f"{location} 今天晴天" + +await ctx.register_llm_tool( + name="search_weather", + parameters_schema={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "城市名称"} + }, + "required": ["location"] + }, + desc="搜索天气信息", + func_obj=search_weather, + active=True +) +``` + +--- + +### `unregister_llm_tool()` + +取消注册 LLM 工具。 + +```python +await ctx.unregister_llm_tool("my_tool") +``` + +--- + +## 高级方法 + +### `tool_loop_agent()` + +执行 Agent 工具循环。 + 
+**签名**: +```python +async def tool_loop_agent( + self, + request: ProviderRequest | None = None, + **kwargs: Any +) -> LLMResponse +``` + +**参数**: +- `request`: ProviderRequest 对象,包含请求配置 +- `**kwargs`: 额外的请求参数,会自动合并到 request + +**返回**: `LLMResponse` - 包含工具调用结果的完整响应 + +**示例**: + +```python +from astrbot_sdk.llm.entities import ProviderRequest + +response = await ctx.tool_loop_agent( + request=ProviderRequest( + prompt="搜索天气", + system_prompt="你是一个助手" + ) +) +print(response.text) +``` + +--- + +### `register_commands()` + +注册命令(仅在 `astrbot_loaded` 或 `platform_loaded` 事件中可用)。 + +**签名**: +```python +async def register_commands( + self, + command_name: str, + handler_full_name: str, + *, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ignore_prefix: bool = False, +) -> None +``` + +**参数**: +- `command_name`: 命令名称 +- `handler_full_name`: 处理函数的完整名称(如 `module.handler_name`) +- `desc`: 命令描述 +- `priority`: 优先级 +- `use_regex`: 是否使用正则匹配 +- `ignore_prefix`: 是否忽略前缀(SDK 中不支持) + +**异常**: +- `AstrBotError`: 如果在非加载事件中调用或设置 `ignore_prefix=True` + +**示例**: + +```python +@on_event("astrbot_loaded") +async def on_load(self, event, ctx: Context): + await ctx.register_commands( + command_name="my_cmd", + handler_full_name="my_module.handle_cmd", + desc="我的命令", + priority=10 + ) +``` + +--- + +### `list_platforms()` + +列出所有可见的平台兼容层实例。 + +**签名**: +```python +async def list_platforms(self) -> list[PlatformCompatFacade] +``` + +**返回**: `list[PlatformCompatFacade]` - 平台兼容层实例列表 + +**示例**: + +```python +for platform in await ctx.list_platforms(): + print(platform.id, platform.status) +``` + +--- + +### `get_platform()` + +获取指定类型的平台兼容层实例。 + +**签名**: +```python +async def get_platform(self, platform_type: str) -> PlatformCompatFacade | None +``` + +**参数**: +- `platform_type`: 平台类型(如 "qq", "telegram") + +**返回**: `PlatformCompatFacade | None` - 平台兼容层实例 + +**示例**: + +```python +platform = await ctx.get_platform("qq") +if platform: + await platform.send_by_session("session_id", 
"消息") +``` + +--- + +### `get_platform_inst()` + +获取指定 ID 的平台兼容层实例。 + +**签名**: +```python +async def get_platform_inst(self, platform_id: str) -> PlatformCompatFacade | None +``` + +**参数**: +- `platform_id`: 平台实例 ID + +**返回**: `PlatformCompatFacade | None` - 平台兼容层实例 + +--- + +## PlatformCompatFacade + +平台兼容层类,提供安全的平台元信息和主动发送能力。 + +### 属性 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `id` | `str` | 平台实例 ID | +| `name` | `str` | 平台名称 | +| `type` | `str` | 平台类型 | +| `status` | `PlatformStatus` | 平台状态 | +| `errors` | `list[PlatformError]` | 错误列表 | +| `last_error` | `PlatformError \| None` | 最近错误 | +| `unified_webhook` | `bool` | 是否统一 webhook | + +### 方法 + +#### `send()` + +发送消息。 + +```python +await platform.send("session_id", "消息内容") +``` + +#### `send_by_session()` + +通过会话发送消息。 + +```python +await platform.send_by_session("platform:private:123", "消息") +``` + +#### `send_by_id()` + +通过 ID 发送消息。 + +```python +await platform.send_by_id("user123", "消息", message_type="private") +``` + +#### `refresh()` + +刷新平台状态。 + +```python +await platform.refresh() +``` + +#### `clear_errors()` + +清除平台错误。 + +```python +await platform.clear_errors() +``` + +#### `get_stats()` + +获取平台统计信息。 + +```python +stats = await platform.get_stats() +``` + +--- + +## 使用示例 + +### 1. 基本对话流程 + +```python +from astrbot_sdk.decorators import on_message + +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + reply = await ctx.llm.chat(event.message_content) + await ctx.platform.send(event.session_id, reply) +``` + +--- + +### 2. 
带历史的对话 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 从 memory 获取历史 + history_data = await ctx.memory.get(f"history:{event.session_id}") + history = history_data.get("messages", []) if history_data else [] + + # 对话 + reply = await ctx.llm.chat(event.message_content, history=history) + + # 保存新消息到历史 + history.append(ChatMessage(role="user", content=event.message_content)) + history.append(ChatMessage(role="assistant", content=reply)) + await ctx.memory.save(f"history:{event.session_id}", {"messages": history}) + + await ctx.platform.send(event.session_id, reply) +``` + +如果你需要保留原始消息组件、发送者和分页删除能力,应优先使用 `ctx.message_history`。 + +--- + +### 3. 使用数据库持久化 + +```python +@on_message() +async def handle_message(event: MessageEvent, ctx: Context): + # 获取用户配置 + config = await ctx.db.get(f"user_config:{event.sender_id}") + + if not config: + config = {"theme": "light", "lang": "zh"} + await ctx.db.set(f"user_config:{event.sender_id}", config) + + # 使用配置 + reply = f"你的主题设置是: {config['theme']}" + await ctx.platform.send(event.session_id, reply) +``` + +--- + +### 4. 注册 Web API + +```python +from astrbot_sdk.decorators import provide_capability + +@provide_capability( + name="my_plugin.get_status", + description="获取插件状态", +) +async def get_status(request_id: str, payload: dict, cancel_token): + return {"status": "running", "version": "1.0.0"} + +@on_command("setup_api") +async def setup_api(event: MessageEvent, ctx: Context): + await ctx.http.register_api( + route="/status", + handler=get_status, + methods=["GET"] + ) + await ctx.platform.send(event.session_id, "API 已注册") +``` + +--- + +## 注意事项 + +1. **跨进程通信**: Context 通过 capability 协议与核心通信,所有方法调用都是异步的 + +2. **插件隔离**: 每个插件有独立的 Context 实例,数据和配置是隔离的 + +3. **取消处理**: 长时间运行的操作应定期检查 `ctx.cancel_token.raise_if_cancelled()` + +4. **错误处理**: 所有远程调用都可能失败,建议使用 try-except 处理 + +5. 
**Memory vs DB vs MessageHistory**: + - Memory: 语义搜索,适合 AI 上下文 + - DB: 精确匹配,适合结构化数据 + - MessageHistory: 精确保存消息组件、发送者和元数据 + +6. **DB 作用域**: `ctx.db` 的 key 会自动限制在当前插件命名空间中 + +7. **HTTP 路由**: `ctx.http.register_api()` 当前会拦截 `..` 等明显非法路径,但仍建议插件自行使用规范化 route + +8. **文件操作**: 使用 `ctx.files` 注册文件令牌,不要直接传递本地路径 + +9. **平台标识**: 使用 UMO(统一消息来源标识)格式:`"platform:instance:session_id"` + +10. **配置访问**: `get_plugin_config()` 只支持查询当前插件自己的配置 + +--- + +## 相关模块 + +- **LLM 客户端**: `astrbot_sdk.clients.llm.LLMClient` +- **Memory 客户端**: `astrbot_sdk.clients.memory.MemoryClient` +- **DB 客户端**: `astrbot_sdk.clients.db.DBClient` +- **Message History 客户端**: `astrbot_sdk.clients.managers.MessageHistoryManagerClient` +- **Platform 客户端**: `astrbot_sdk.clients.platform.PlatformClient` +- **日志器**: `astrbot_sdk._internal.plugin_logger.PluginLogger` +- **取消令牌**: `astrbot_sdk.context.CancelToken` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.context.Context` +**最后更新**: 2026-03-22 diff --git a/astrbot-sdk/docs/api/decorators.md b/astrbot-sdk/docs/api/decorators.md new file mode 100644 index 0000000000..1149329b49 --- /dev/null +++ b/astrbot-sdk/docs/api/decorators.md @@ -0,0 +1,1218 @@ +# 装饰器 - 事件处理注册完整参考 + +## 概述 + +装饰器是 AstrBot SDK 中用于注册事件处理器的核心机制。通过装饰器标记方法,SDK 会自动收集这些方法并在适当时机调用它们。 + +**模块路径**: `astrbot_sdk.decorators` + +--- + +## 目录 + +- [事件触发装饰器](#事件触发装饰器) +- [修饰器装饰器](#修饰器装饰器) +- [过滤器装饰器](#过滤器装饰器) +- [限制器装饰器](#限制器装饰器) +- [能力暴露装饰器](#能力暴露装饰器) +- [LLM 工具装饰器](#llm-工具装饰器) +- [使用示例](#使用示例) + +--- + +## 导入方式 + +```python +# 从主模块导入(推荐) +from astrbot_sdk.decorators import ( + # 事件触发 + on_command, + on_message, + on_event, + on_schedule, + # 修饰器 + require_admin, + # 过滤器 + platforms, + message_types, + group_only, + private_only, + # 限制器 + rate_limit, + cooldown, + # 能力暴露 + provide_capability, + # LLM 工具 + register_llm_tool, + register_agent, +) + +# 或者按需导入 +from astrbot_sdk.decorators import on_command, on_message +``` + +--- + +## 事件触发装饰器 + +### @on_command + +命令触发装饰器,当用户输入指定命令时触发。 + +#### 签名 + +```python 
+def on_command( + command: str | Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 说明 | +|------|------|------|------| +| `command` | `str \| Sequence[str]` | 是 | 命令名称(不包含前缀符),可传入单个命令或命令列表 | +| `aliases` | `list[str] \| None` | 否 | 命令别名列表 | +| `description` | `str \| None` | 否 | 命令描述,用于帮助信息生成 | + +#### 示例 + +```python +# 简单命令 +@on_command("hello") +async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello, World!") + +# 带别名 +@on_command("echo", aliases=["repeat", "say"]) +async def echo(self, event: MessageEvent, text: str): + await event.reply(f"你说: {text}") + +# 带描述 +@on_command("help", description="显示帮助信息") +async def help(self, event: MessageEvent, ctx: Context): + await event.reply("可用命令: /hello") + +# 批量命令 +@on_command(["start", "begin"]) +async def start(self, event: MessageEvent, ctx: Context): + await event.reply("开始执行...") +``` + +#### 注意事项 + +1. 命令名称不应包含前缀符(如 `/`),框架会自动处理 +2. 传入命令列表时,第一个命令作为主命令名,其余作为别名 +3. `aliases` 参数中的别名会与命令列表合并,重复项会自动去重 +4. 
命令名不能为空字符串 + +--- + +### @on_message + +消息触发装饰器,当消息匹配指定条件时触发。 + +#### 签名 + +```python +def on_message( + *, + regex: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + message_types: list[str] | None = None, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需* | 说明 | +|------|------|--------|------| +| `regex` | `str \| None` | 否* | 正则表达式模式 | +| `keywords` | `list[str] \| None` | 否* | 关键词列表(任一匹配即触发) | +| `platforms` | `list[str] \| None` | 否 | 限定平台列表 | +| `message_types` | `list[str] \| None` | 否 | 限定消息类型(`"group"`, `"private"`) | + +*注: `regex` 和 `keywords` 至少需要提供一个 + +#### 示例 + +```python +# 关键词匹配 +@on_message(keywords=["帮助", "help"]) +async def help(self, event: MessageEvent, ctx: Context): + await event.reply("可用命令: /hello") + +# 正则匹配 +@on_message(regex=r"\d{4,}") +async def number(self, event: MessageEvent, ctx: Context): + await event.reply("检测到数字!") + +# 多条件过滤 +@on_message( + keywords=["天气"], + platforms=["qq"], + message_types=["private"] +) +async def weather(self, event: MessageEvent, ctx: Context): + await event.reply("请输入城市名称查询天气") + +# 组合使用 +@on_message(regex=r"^打卡") +async def check_in(self, event: MessageEvent, ctx: Context): + await event.reply(f"{event.sender_name} 打卡成功!") +``` + +#### 注意事项 + +1. 正则表达式使用 Python `re` 模块语法 +2. 关键词匹配是包含匹配,不是精确匹配 +3. 不能与 `@platforms()` 装饰器混用(会有 `ValueError`) +4. 
不能与 `@group_only()` / `@private_only()` / `@message_types()` 混用 + +--- + +### @on_event + +事件触发装饰器,用于处理非消息类型的系统事件。 + +#### 签名 + +```python +def on_event(event_type: str) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 说明 | +|------|------|------|------| +| `event_type` | `str` | 是 | 事件类型标识 | + +#### 示例 + +```python +# 群成员加入事件 +@on_event("group_member_join") +async def welcome(self, event, ctx: Context): + await ctx.platform.send(event.group_id, f"欢迎 {event.user_id}!") + +# 群成员离开事件 +@on_event("group_member_decrease") +async def goodbye(self, event, ctx: Context): + await ctx.platform.send(event.group_id, f"再见 {event.user_id}") + +# 好友请求事件 +@on_event("friend_request") +async def handle_request(self, event, ctx: Context): + await ctx.platform.send(event.user_id, "已自动通过好友请求") +``` + +#### LLM Pipeline Hooks + +`@on_event` 也用于挂接 AstrBot 原生消息处理链路中的系统事件。 + +常见事件及可注入对象: + +| 事件名 | 常见可注入参数 | 是否可修改主链路 | +|------|------|------| +| `waiting_llm_request` | `MessageEvent`, `Context` | 间接可修改,例如切换当前对话 persona | +| `agent_begin` | `MessageEvent`, `Context` | 否,适合在 Agent 真正开始执行时做准备工作 | +| `llm_request` | `MessageEvent`, `Context`, `ProviderRequest` | 是,可直接修改 `ProviderRequest` | +| `agent_done` | `MessageEvent`, `Context`, `LLMResponse` | 否,适合观察和提取 Agent 最终回复 | +| `decorating_result` | `MessageEvent`, `Context`, `MessageEventResult` | 是,可直接修改结果消息链 | +| `after_message_sent` | `MessageEvent`, `Context` | 否,适合落库、记忆、统计 | +| `calling_func_tool` | `MessageEvent`, `Context` | 否,可读取 `event.raw["tool_name"]` / `event.raw["tool_args"]` | +| `llm_tool_start` | `MessageEvent`, `Context` | 否,可读取 `event.raw["tool_name"]` / `event.raw["tool_args"]` | +| `llm_tool_end` | `MessageEvent`, `Context` | 否,可读取 `event.raw["tool_name"]` / `event.raw["tool_result"]` | +| `plugin_error` | `MessageEvent`, `Context` | 否,可读取 `event.raw["plugin_name"]` / `event.raw["error"]` | + +最小示例: + +```python +from astrbot_sdk import Context, MessageEvent +from astrbot_sdk.decorators 
import on_event +from astrbot_sdk.llm.entities import ProviderRequest + +@on_event("llm_request") +async def add_memory(self, event: MessageEvent, ctx: Context, request: ProviderRequest): + del event, ctx + request.system_prompt = (request.system_prompt or "") + "\n\nmemory: user likes tea" +``` + +完整示例: + +```python +from astrbot_sdk import Context, MessageEvent, Star +from astrbot_sdk.clients.llm import LLMResponse +from astrbot_sdk.clients.managers import ConversationUpdateParams +from astrbot_sdk.decorators import on_event +from astrbot_sdk.llm.entities import ProviderRequest +from astrbot_sdk.message_result import MessageEventResult +from astrbot_sdk.message_components import Plain + +class PersonaSample(Star): + @on_event("waiting_llm_request") + async def ensure_persona(self, event: MessageEvent, ctx: Context) -> None: + conversation = await ctx.conversations.get_current_conversation( + event.session_id, + create_if_not_exists=True, + ) + if conversation is None or conversation.persona_id == "girlfriend": + return + await ctx.conversations.update_conversation( + event.session_id, + conversation.conversation_id, + ConversationUpdateParams(persona_id="girlfriend"), + ) + + @on_event("llm_request") + async def inject_context( + self, + event: MessageEvent, + ctx: Context, + request: ProviderRequest, + ) -> None: + memories = await ctx.memory.search(event.text, limit=3) + facts = [] + for item in memories: + value = item.get("value") + if isinstance(value, dict) and value.get("content"): + facts.append(f"- {value['content']}") + if facts: + request.system_prompt = (request.system_prompt or "") + "\n\n" + "\n".join(facts) + + @on_event("agent_done") + async def capture_reply( + self, + event: MessageEvent, + ctx: Context, + response: LLMResponse, + ) -> None: + del ctx + if response.text: + event.set_extra("last_reply", response.text) + + @on_event("decorating_result") + async def decorate( + self, + event: MessageEvent, + ctx: Context, + result: 
MessageEventResult, + ) -> None: + del event, ctx + result.chain.append(Plain("\n[persona active]", convert=False)) + + @on_event("after_message_sent") + async def persist(self, event: MessageEvent, ctx: Context) -> None: + reply = str(event.get_extra("last_reply", "") or "").strip() + if not reply: + reply = str(event.get_sent_message_outline() or "").strip() + if reply: + await ctx.db.set("sample:last_reply", reply) +``` + +#### 注意事项 + +1. 可用于处理平台事件,也可用于处理 AstrBot 原生消息链路中的系统事件(如 `llm_request`) +2. 不能与 `@rate_limit` 或 `@cooldown` 一起使用 +3. 不同平台的事件类型可能不同,需要查阅平台文档 +4. `llm_request` 和 `decorating_result` 注入的是可变对象,修改会回写到 AstrBot 主链路 +5. `agent_done` 主要用于观测和提取结果,不应用来替代主回复流程 +6. 请求范围内的 JSON-safe `event.set_extra()` 数据会在同一次请求的 SDK hooks 之间保留;非 JSON-safe 值只在当前 handler 内可见 +7. `after_message_sent` 会保留 `event.text` 作为原始用户输入;读取机器人实际发送内容时,优先使用 `event.get_sent_message_outline()` 和 `event.get_sent_messages()` + +#### 系统事件附加字段 + +- `calling_func_tool` / `llm_tool_start`: `event.raw["tool_name"]`, `event.raw["tool_args"]` +- `llm_tool_end`: `event.raw["tool_name"]`, `event.raw["tool_args"]`, `event.raw["tool_result"]` +- `plugin_error`: `event.raw["plugin_name"]`, `event.raw["handler_name"]`, `event.raw["error"]`, `event.raw["traceback"]` +- `after_message_sent`: `event.get_sent_message_outline()`, `event.get_sent_messages()` + +--- + +### @on_schedule + +定时任务装饰器,按指定时间间隔或 cron 表达式触发。 + +#### 签名 + +```python +def on_schedule( + *, + cron: str | None = None, + interval_seconds: int | None = None, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需* | 说明 | +|------|------|--------|------| +| `cron` | `str \| None` | 否* | cron 表达式(如 `"0 8 * * *"` 表示每天 8:00) | +| `interval_seconds` | `int \| None` | 否* | 执行间隔(秒) | + +*注: `cron` 和 `interval_seconds` 必须且只能提供一个 + +#### 示例 + +```python +# 固定间隔(每小时执行) +@on_schedule(interval_seconds=3600) +async def hourly_check(self, ctx: Context): + ctx.logger.info("每小时执行一次") + +# cron 表达式(每天 8:00) +@on_schedule(cron="0 8 * 
* *") +async def morning_greeting(self, ctx: Context): + await ctx.platform.send("group_123", "早上好!") + +# 每2小时 +@on_schedule(cron="0 */2 * * *") +async def bi_hourly_task(self, ctx: Context): + pass + +# 工作日 9:00-17:00 每小时 +@on_schedule(cron="0 9-17 * * 1-5") +async def work_hours_check(self, ctx: Context): + pass +``` + +#### cron 表达式格式 + +``` +分钟 小时 日 月 星期 +* * * * * + +示例: +0 8 * * * # 每天 8:00 +0 */2 * * * # 每2小时 +0 9-17 * * 1-5 # 工作日 9:00-17:00 每小时 +*/10 * * * * # 每10分钟 +``` + +#### 注意事项 + +1. cron 表达式格式: `分钟 小时 日 月 星期` +2. 不能与 `@rate_limit` 或 `@cooldown` 一起使用 +3. 定时任务的 handler 不接收 `MessageEvent` 参数 +4. `interval_seconds` 最小值为 60(1分钟) + +--- + +## 修饰器装饰器 + +### @require_admin + +管理员权限装饰器,限制只有管理员才能调用。 + +#### 签名 + +```python +def require_admin(func: HandlerCallable) -> HandlerCallable +``` + +#### 示例 + +```python +from astrbot_sdk.decorators import on_command, require_admin + +@on_command("shutdown") +@require_admin +async def shutdown(self, event: MessageEvent, ctx: Context): + await event.reply("正在关闭系统...") +``` + +#### 注意事项 + +1. 必须放在事件触发装饰器(如 `@on_command`)之后 +2. 非管理员用户触发时,handler 不会被调用 +3. 
别名: `@admin_only()` 功能完全相同 + +--- + +## 过滤器装饰器 + +### @platforms + +限定平台装饰器,只在指定平台上触发。 + +#### 签名 + +```python +def platforms(*names: str) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 说明 | +|------|------|------|------| +| `*names` | `str` | 是 | 平台名称(可变参数) | + +#### 示例 + +```python +@on_command("qq_only") +@platforms("qq") +async def qq_only(self, event: MessageEvent, ctx: Context): + await event.reply("这是 QQ 专属命令") + +@on_command("multi") +@platforms("qq", "telegram", "discord") +async def multi(self, event: MessageEvent, ctx: Context): + await event.reply("支持多平台") +``` + +--- + +### @message_types + +限定消息类型装饰器。 + +#### 签名 + +```python +def message_types(*types: str) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 示例 + +```python +@on_command("group_only") +@message_types("group") +async def group_only(self, event: MessageEvent, ctx: Context): + await event.reply("这是群聊命令") +``` + +--- + +### @group_only + +仅群聊装饰器。 + +#### 签名 + +```python +def group_only() -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 示例 + +```python +@on_command("group_admin") +@group_only() +async def group_admin(self, event: MessageEvent, ctx: Context): + await event.reply("这是群聊管理命令") +``` + +#### 注意事项 + +功能等同于 `@message_types("group")` + +--- + +### @private_only + +仅私聊装饰器。 + +#### 签名 + +```python +def private_only() -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 示例 + +```python +@on_command("private_chat") +@private_only() +async def private_only(self, event: MessageEvent, ctx: Context): + await event.reply("这是私聊命令") +``` + +--- + +## 限制器装饰器 + +### @rate_limit + +速率限制装饰器,限制时间窗口内的调用次数。 + +#### 签名 + +```python +def rate_limit( + limit: int, + window: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `limit` | 
`int` | 是 | - | 时间窗口内最大调用次数 | +| `window` | `float` | 是 | - | 时间窗口大小(秒) | +| `scope` | `LimiterScope` | 否 | `"session"` | 限制范围 | +| `behavior` | `LimiterBehavior` | 否 | `"hint"` | 触发限制后的行为 | +| `message` | `str \| None` | 否 | `None` | 自定义提示消息 | + +**scope 可选值**: +- `"session"` - 会话级别 +- `"user"` - 用户级别 +- `"group"` - 群组级别 +- `"global"` - 全局级别 + +**behavior 可选值**: +- `"hint"` - 返回提示消息 +- `"silent"` - 静默忽略 +- `"error"` - 抛出异常 + +#### 示例 + +```python +# 每分钟最多5次 +@on_command("search") +@rate_limit(5, 60) +async def search(self, event: MessageEvent, ctx: Context): + await event.reply("搜索结果...") + +# 每用户每小时3次 +@on_command("draw") +@rate_limit(3, 3600, scope="user") +async def draw(self, event: MessageEvent, ctx: Context): + await event.reply("绘图结果...") + +# 全局限制,自定义消息 +@on_command("global") +@rate_limit( + 10, 60, + scope="global", + message="操作过于频繁,请稍后再试" +) +async def global_action(self, event: MessageEvent, ctx: Context): + await event.reply("执行全局操作") +``` + +--- + +### @cooldown + +冷却时间装饰器,限制连续调用的间隔。 + +#### 签名 + +```python +def cooldown( + seconds: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `seconds` | `float` | 是 | - | 冷却时间(秒) | +| `scope` | `LimiterScope` | 否 | `"session"` | 限制范围 | +| `behavior` | `LimiterBehavior` | 否 | `"hint"` | 触发限制后的行为 | +| `message` | `str \| None` | 否 | `None` | 自定义提示消息 | + +#### 示例 + +```python +# 30秒冷却 +@on_command("cast_skill") +@cooldown(30) +async def cast_skill(self, event: MessageEvent, ctx: Context): + await event.reply("技能施放成功!") + +# 每用户24小时冷却 +@on_command("daily_reward") +@cooldown(86400, scope="user") +async def daily_reward(self, event: MessageEvent, ctx: Context): + await event.reply("领取每日奖励!") + +# 群组5分钟冷却 +@on_command("group_activity") +@cooldown(300, scope="group") +async def group_activity(self, event: 
MessageEvent, ctx: Context): + await event.reply("群活动已开始") +``` + +#### 注意事项 + +1. 只适用于 `@on_command` 和 `@on_message` +2. 不能与 `@rate_limit` 叠加使用 +3. `cooldown` 本质上是 `limit=1` 的 `rate_limit` + +--- + +## 能力暴露装饰器 + +### @provide_capability + +暴露插件能力给其他插件调用的装饰器。 + +#### 签名 + +```python +def provide_capability( + name: str, + *, + description: str, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + input_model: type[BaseModel] | None = None, + output_model: type[BaseModel] | None = None, + supports_stream: bool = False, + cancelable: bool = False, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 说明 | +|------|------|------|------| +| `name` | `str` | 是 | 能力名称(不能使用保留命名空间) | +| `description` | `str` | 是 | 能力描述 | +| `input_schema` | `dict \| None` | 否* | 输入 JSON Schema | +| `output_schema` | `dict \| None` | 否* | 输出 JSON Schema | +| `input_model` | `type[BaseModel] \| None` | 否* | 输入 pydantic 模型 | +| `output_model` | `type[BaseModel] \| None` | 否* | 输出 pydantic 模型 | +| `supports_stream` | `bool` | 否 | 是否支持流式输出 | +| `cancelable` | `bool` | 否 | 是否可取消 | + +*注: `input_schema` 与 `input_model` 二选一,`output_schema` 与 `output_model` 二选一 + +#### 示例 + +```python +from pydantic import BaseModel, Field + +class CalculateInput(BaseModel): + x: int = Field(description="第一个数") + y: int = Field(description="第二个数") + +class CalculateOutput(BaseModel): + result: int = Field(description="计算结果") + +@provide_capability( + "my_plugin.calculate", + description="执行加法计算", + input_model=CalculateInput, + output_model=CalculateOutput +) +async def calculate(self, payload: dict, ctx: Context): + x = payload["x"] + y = payload["y"] + return {"result": x + y} +``` + +#### 注意事项 + +1. 保留命名空间(`handler.`, `system.`, `internal.`)不能用于插件能力 +2. `input_schema` 和 `input_model` 不能同时提供 +3. `output_schema` 和 `output_model` 不能同时提供 +4. 
能力名称格式建议: `插件名.功能名` + +--- + +## LLM 工具装饰器 + +### @register_llm_tool + +注册 LLM 工具装饰器,使插件函数可被 LLM 调用。 + +#### 签名 + +```python +def register_llm_tool( + name: str | None = None, + *, + description: str | None = None, + parameters_schema: dict[str, Any] | None = None, + active: bool = True, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `name` | `str \| None` | 否 | 函数名 | 工具名称 | +| `description` | `str \| None` | 否 | 函数文档字符串首行 | 工具描述 | +| `parameters_schema` | `dict \| None` | 否 | 自动从函数签名推断 | 参数 JSON Schema | +| `active` | `bool` | 否 | `True` | 是否激活 | + +#### 示例 + +```python +# 自动推断参数 +@register_llm_tool() +async def get_weather(self, city: str, unit: str = "celsius"): + """获取指定城市的天气信息""" + return f"{city} 的天气: 25°C" + +# 自定义 schema +@register_llm_tool( + name="search_database", + description="搜索数据库中的记录", + parameters_schema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "搜索关键词"}, + "limit": {"type": "integer", "description": "返回结果数量", "default": 10} + }, + "required": ["query"] + }, + active=True +) +async def search_database(self, query: str, limit: int = 10): + # 实现数据库搜索逻辑 + return {"results": [...]} +``` + +#### 注意事项 + +1. 如果不提供 `name`,将使用函数名作为工具名 +2. 如果不提供 `description`,将使用函数文档字符串的第一行 +3. 如果不提供 `parameters_schema`,会自动从函数签名推断 +4. 
参数推断时会跳过 `self`, `event`, `ctx`, `context` 等特殊参数 + +--- + +### @register_agent + +注册 Agent 装饰器,将类注册为 LLM Agent。 + +#### 签名 + +```python +def register_agent( + name: str, + *, + description: str = "", + tool_names: list[str] | None = None, +) -> Callable[[type[BaseAgentRunner]], type[BaseAgentRunner]] +``` + +#### 参数 + +| 参数 | 类型 | 必需 | 默认值 | 说明 | +|------|------|------|--------|------| +| `name` | `str` | 是 | - | Agent 名称 | +| `description` | `str` | 否 | `""` | Agent 描述 | +| `tool_names` | `list[str] \| None` | 否 | `None` | 可用工具名称列表 | + +#### 示例 + +```python +from astrbot_sdk.llm.agents import BaseAgentRunner +from astrbot_sdk.llm.entities import ProviderRequest + +class WeatherAgent(BaseAgentRunner): + async def run(self, ctx: Context, request: ProviderRequest) -> Any: + # 实现 agent 运行逻辑 + return "天气信息" + +class MyPlugin(Star): + @register_agent("my_agent", description="我的智能助手") + class MyAgentRunner(BaseAgentRunner): + async def run(self, ctx: Context, request: ProviderRequest) -> Any: + return "多工具处理结果" +``` + +#### 注意事项 + +1. 必须应用于 `BaseAgentRunner` 的子类 +2. `tool_names` 指定该 agent 可以使用的 LLM 工具 +3. 
Agent 的实际执行由 core tool loop 管理 + +--- + +## 其他装饰器 + +### @admin_only + +`@require_admin` 的别名,功能完全相同。 + +**签名**: +```python +def admin_only(func: HandlerCallable) -> HandlerCallable +``` + +--- + +### @priority + +设置 handler 执行优先级。 + +**签名**: +```python +def priority(value: int) -> Callable[[HandlerCallable], HandlerCallable] +``` + +**参数**: +- `value`: 优先级数值,越大越先执行 + +**示例**: + +```python +@on_command("high") +@priority(100) +async def high_priority(self, event: MessageEvent): + await event.reply("我优先执行") + +@on_command("low") +@priority(1) +async def low_priority(self, event: MessageEvent): + await event.reply("我后执行") +``` + +--- + +### @conversation_command + +会话命令装饰器,支持会话超时和模式控制。 + +**签名**: +```python +def conversation_command( + command: str | Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, + timeout: int = 60, + mode: ConversationMode = "replace", + busy_message: str | None = None, + grace_period: float = 1.0, +) -> Callable[[HandlerCallable], HandlerCallable] +``` + +**参数**: +- `command`: 命令名称 +- `aliases`: 命令别名列表 +- `description`: 命令描述 +- `timeout`: 会话超时时间(秒) +- `mode`: 会话模式(`"replace"` 或 `"reject"`) +- `busy_message`: 会话忙时的提示消息 +- `grace_period`: 宽限期(秒) + +**示例**: + +```python +@conversation_command( + "survey", + description="问卷调查", + timeout=300, + mode="replace", + busy_message="当前有进行中的问卷" +) +async def survey(self, event: MessageEvent, ctx: Context): + await event.reply("请输入您的姓名:") +``` + +--- + +## 元数据辅助函数 + +### `get_handler_meta(func)` + +获取方法的 handler 元数据。 + +**签名**: +```python +def get_handler_meta(func: HandlerCallable) -> HandlerMeta | None +``` + +**参数**: +- `func`: 要检查的方法 + +**返回**: `HandlerMeta | None` - 元数据对象,如果没有则返回 None + +**示例**: + +```python +from astrbot_sdk.decorators import get_handler_meta + +@on_command("test") +async def test_handler(self, event: MessageEvent): + pass + +meta = get_handler_meta(test_handler) +if meta: + print(f"命令: {meta.trigger.command}") +``` + +--- + +### 
`get_capability_meta(func)` + +获取方法的 capability 元数据。 + +**签名**: +```python +def get_capability_meta(func: HandlerCallable) -> CapabilityMeta | None +``` + +**参数**: +- `func`: 要检查的方法 + +**返回**: `CapabilityMeta | None` - 元数据对象 + +--- + +### `get_llm_tool_meta(func)` + +获取方法的 LLM 工具元数据。 + +**签名**: +```python +def get_llm_tool_meta(func: HandlerCallable) -> LLMToolMeta | None +``` + +**参数**: +- `func`: 要检查的方法 + +**返回**: `LLMToolMeta | None` - 元数据对象 + +--- + +### `get_agent_meta(obj)` + +获取 Agent 类的元数据。 + +**签名**: +```python +def get_agent_meta(obj: Any) -> AgentMeta | None +``` + +**参数**: +- `obj`: 要检查的类或对象 + +**返回**: `AgentMeta | None` - 元数据对象 + +--- + +### `append_filter_meta(func, *, specs, local_bindings)` + +追加过滤器元数据到方法。 + +**签名**: +```python +def append_filter_meta( + func: HandlerCallable, + *, + specs: list[FilterSpec] | None = None, + local_bindings: list[Any] | None = None +) -> HandlerCallable +``` + +--- + +### `set_command_route_meta(func, route)` + +设置命令路由元数据。 + +**签名**: +```python +def set_command_route_meta( + func: HandlerCallable, + route: CommandRouteSpec +) -> HandlerCallable +``` + +--- + +## 使用示例 + +### 示例 1: 基础命令 + +```python +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command + +class MyPlugin(Star): + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply(f"你好,{event.sender_name}!") + + @on_command("echo", aliases=["repeat", "say"]) + async def echo(self, event: MessageEvent, text: str): + await event.reply(f"你说: {text}") +``` + +--- + +### 示例 2: 消息匹配 + +```python +from astrbot_sdk.decorators import on_message + +class MyPlugin(Star): + @on_message(keywords=["帮助", "help"]) + async def help(self, event: MessageEvent, ctx: Context): + await event.reply("可用命令: /hello, /echo") + + @on_message(regex=r"\d{4,}") + async def number(self, event: MessageEvent, ctx: Context): + await event.reply("检测到数字!") +``` + +--- + +### 示例 3: 装饰器组合 + +```python +from 
astrbot_sdk.decorators import ( + on_command, require_admin, group_only, rate_limit +) + +class MyPlugin(Star): + @on_command("admin") + @require_admin + @group_only() + @rate_limit(5, 60) + async def admin_cmd(self, event: MessageEvent, ctx: Context): + await event.reply("管理员群聊命令(每分钟最多5次)") +``` + +--- + +### 示例 4: 定时任务 + +```python +from astrbot_sdk.decorators import on_schedule + +class MyPlugin(Star): + @on_schedule(interval_seconds=3600) + async def hourly_task(self, ctx: Context): + # 每小时执行 + pass + + @on_schedule(cron="0 8 * * *") + async def morning_task(self, ctx: Context): + # 每天8点执行 + await ctx.platform.send("group_123", "早上好!") +``` + +--- + +### 示例 5: LLM 工具注册 + +```python +from astrbot_sdk import Star +from astrbot_sdk.decorators import register_llm_tool + +class MyPlugin(Star): + @register_llm_tool() + async def get_time(self) -> str: + """获取当前时间""" + import time + return f"当前时间: {time.strftime('%Y-%m-%d %H:%M:%S')}" + + @register_llm_tool( + name="calculate", + description="执行计算", + parameters_schema={ + "type": "object", + "properties": { + "expression": {"type": "string", "description": "数学表达式"} + }, + "required": ["expression"] + } + ) + async def calculate(self, expression: str) -> str: + try: + result = eval(expression) + return f"结果: {result}" + except Exception as e: + return f"计算错误: {e}" +``` + +--- + +## 注意事项 + +### 1. 装饰器顺序 + +正确的装饰器顺序很重要: + +```python +@on_command("command") # 1. 事件触发装饰器 +@platforms("qq") # 2. 过滤器装饰器 +@rate_limit(5, 60) # 3. 限制器装饰器 +@require_admin # 4. 修饰器装饰器 +async def my_handler(self, event: MessageEvent, ctx: Context): + pass +``` + +### 2. 避免常见陷阱 + +**不要混用冲突的装饰器**: + +```python +# 错误示例 +@on_message(platforms=["qq"]) +@platforms("wechat") # 冲突! +async def handler(...): pass + +# 正确示例 +@on_message(platforms=["qq", "wechat"]) +async def handler(...): pass +``` + +**不要在非消息处理器使用限制器**: + +```python +# 错误示例 +@on_event("ready") +@rate_limit(5, 60) # 不支持! 
+async def handler(...): pass + +# 正确示例 +@on_command("cmd") +@rate_limit(5, 60) +async def handler(...): pass +``` + +### 3. 类型注解建议 + +使用类型注解提高代码可读性: + +```python +@on_command("greet") +async def greet_handler( + self, + event: MessageEvent, + ctx: Context +) -> None: + await event.reply("Hello!") +``` + +--- + +## 相关模块 + +- **装饰器实现**: `astrbot_sdk.decorators` +- **协议描述符**: `astrbot_sdk.protocol.descriptors` +- **事件定义**: `astrbot_sdk.events` +- **LLM 实体**: `astrbot_sdk.llm.entities` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.decorators` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/errors.md b/astrbot-sdk/docs/api/errors.md new file mode 100644 index 0000000000..b8ecff9a6f --- /dev/null +++ b/astrbot-sdk/docs/api/errors.md @@ -0,0 +1,651 @@ +# 错误处理 API 完整参考 + +## 概述 + +AstrBot SDK 提供了统一的错误处理机制,支持跨进程传递错误信息。所有可预期的错误都应使用 `AstrBotError` 类或其工厂方法创建。 + +**模块路径**: `astrbot_sdk.errors` + +--- + +## 目录 + +- [错误处理流程](#错误处理流程) +- [导入方式](#导入方式) +- [ErrorCodes - 错误码常量](#errorcodes---错误码常量) +- [AstrBotError - 错误类](#astrboterror---错误类) +- [使用示例](#使用示例) +- [最佳实践](#最佳实践) + +--- + +## 导入方式 + +```python +# 从主模块导入 +from astrbot_sdk import AstrBotError + +# 从 errors 模块导入 +from astrbot_sdk.errors import AstrBotError, ErrorCodes +``` + +--- + +## 错误处理流程 + +```python +# 1. 抛出错误 +raise AstrBotError.invalid_input("参数不能为空") + +# 2. 错误被捕获并序列化为 payload +# 3. 跨进程传输后反序列化 +# 4.
在 on_error 钩子中统一处理 +``` + +```python +class MyPlugin(Star): + async def on_error(self, error: AstrBotError) -> None: + if error.retryable: + # 可重试的错误 + self.logger.warning(f"可重试错误: {error.message}") + else: + # 不可重试的错误 + self.logger.error(f"错误: {error.hint or error.message}") +``` + +--- + +## ErrorCodes - 错误码常量 + +稳定的错误码常量,用于标识不同类型的错误。 + +### 定义 + +```python +class ErrorCodes: + """AstrBot v4 的稳定错误码常量。""" +``` + +### 错误码列表 + +#### 不可重试错误(retryable=False) + +| 错误码 | 说明 | 默认提示 | +|--------|------|----------| +| `UNKNOWN_ERROR` | 未知错误 | - | +| `LLM_NOT_CONFIGURED` | LLM 未配置 | - | +| `CAPABILITY_NOT_FOUND` | 能力未找到 | 请确认 AstrBot Core 是否已注册该 capability | +| `PERMISSION_DENIED` | 权限被拒绝 | - | +| `LLM_ERROR` | LLM 错误 | - | +| `INVALID_INPUT` | 输入无效 | 请检查调用参数 | +| `CANCELLED` | 调用被取消 | - | +| `PROTOCOL_VERSION_MISMATCH` | 协议版本不匹配 | 请升级 astrbot_sdk 至最新版本 | +| `PROTOCOL_ERROR` | 协议错误 | 请检查通信双方的协议实现 | +| `INTERNAL_ERROR` | 内部错误 | 请联系插件作者 | +| `RATE_LIMITED` | 速率限制 | 操作过于频繁,请稍后再试 | +| `COOLDOWN_ACTIVE` | 冷却中 | - | + +#### 可重试错误(retryable=True) + +| 错误码 | 说明 | 默认提示 | +|--------|------|----------| +| `CAPABILITY_TIMEOUT` | 能力调用超时 | - | +| `NETWORK_ERROR` | 网络错误 | 网络请求失败,请稍后重试 | +| `LLM_TEMPORARY_ERROR` | LLM 临时错误 | - | + +--- + +## AstrBotError - 错误类 + +AstrBot SDK 的标准错误类型,支持跨进程传递。 + +### 类定义 + +```python +@dataclass(slots=True) +class AstrBotError(Exception): + code: str + message: str + hint: str = "" + retryable: bool = False + docs_url: str = "" + details: dict[str, Any] | None = None +``` + +### 属性说明 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `code` | `str` | 错误码,来自 ErrorCodes 常量 | +| `message` | `str` | 错误消息,面向开发者 | +| `hint` | `str` | 用户提示,面向终端用户 | +| `retryable` | `bool` | 是否可重试 | +| `docs_url` | `str` | 文档链接 | +| `details` | `dict[str, Any] \| None` | 详细信息 | + +--- + +## 工厂方法 + +### `cancelled(message)` + +创建取消错误。 + +```python +@classmethod +def cancelled(cls, message: str = "调用被取消") -> AstrBotError +``` + +**参数**: +- `message` (`str`): 错误消息 + +**返回**: `AstrBotError`
实例 + +**示例**: + +```python +raise AstrBotError.cancelled("用户取消操作") +``` + +--- + +### `capability_not_found(name)` + +创建能力未找到错误。 + +```python +@classmethod +def capability_not_found(cls, name: str) -> AstrBotError +``` + +**参数**: +- `name` (`str`): 未找到的能力名称 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +raise AstrBotError.capability_not_found("my_plugin.custom_capability") +``` + +--- + +### `invalid_input(message, *, hint, docs_url, details)` + +创建输入无效错误。 + +```python +@classmethod +def invalid_input( + cls, + message: str, + *, + hint: str = "请检查调用参数", + docs_url: str = "", + details: dict[str, Any] | None = None, +) -> AstrBotError +``` + +**参数**: +- `message` (`str`): 详细错误消息 +- `hint` (`str`): 用户提示,默认 "请检查调用参数" +- `docs_url` (`str`): 文档链接 +- `details` (`dict[str, Any] | None`): 详细信息 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +raise AstrBotError.invalid_input( + "参数格式错误", + hint="请使用 JSON 格式", + details={"expected": "json", "received": "text"} +) +``` + +--- + +### `protocol_version_mismatch(message)` + +创建协议版本不匹配错误。 + +```python +@classmethod +def protocol_version_mismatch(cls, message: str) -> AstrBotError +``` + +**参数**: +- `message` (`str`): 详细错误消息 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +raise AstrBotError.protocol_version_mismatch("SDK 版本 4.0 与 Core 版本 3.9 不兼容") +``` + +--- + +### `protocol_error(message)` + +创建协议错误。 + +```python +@classmethod +def protocol_error(cls, message: str) -> AstrBotError +``` + +**参数**: +- `message` (`str`): 详细错误消息 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +raise AstrBotError.protocol_error("无效的 payload 格式") +``` + +--- + +### `internal_error(message, *, hint, docs_url, details)` + +创建内部错误。 + +```python +@classmethod +def internal_error( + cls, + message: str, + *, + hint: str = "请联系插件作者", + docs_url: str = "", + details: dict[str, Any] | None = None, +) -> AstrBotError +``` + +**参数**: +- `message` (`str`): 详细错误消息 +- `hint` (`str`): 用户提示,默认 "请联系插件作者" +- `docs_url` (`str`): 文档链接 +- 
`details` (`dict[str, Any] | None`): 详细信息 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +raise AstrBotError.internal_error( + "处理逻辑异常", + hint="请检查日志并联系插件作者", + details={"traceback": "..."} +) +``` + +--- + +### `network_error(message, *, hint, docs_url, details)` + +创建网络错误。 + +```python +@classmethod +def network_error( + cls, + message: str, + *, + hint: str = "网络请求失败,请稍后重试", + docs_url: str = "", + details: dict[str, Any] | None = None, +) -> AstrBotError +``` + +**参数**: +- `message` (`str`): 详细错误消息 +- `hint` (`str`): 用户提示,默认 "网络请求失败,请稍后重试" +- `docs_url` (`str`): 文档链接 +- `details` (`dict[str, Any] | None`): 详细信息 + +**返回**: `AstrBotError` 实例 + +**特性**: `retryable=True` + +**示例**: + +```python +raise AstrBotError.network_error( + "连接超时", + hint="网络不稳定,请稍后重试", + details={"url": "...", "timeout": 30} +) +``` + +--- + +### `rate_limited(*, hint, details)` + +创建速率限制错误。 + +```python +@classmethod +def rate_limited( + cls, + *, + hint: str = "操作过于频繁,请稍后再试。", + details: dict[str, Any] | None = None, +) -> AstrBotError +``` + +**参数**: +- `hint` (`str`): 用户提示,默认 "操作过于频繁,请稍后再试。" +- `details` (`dict[str, Any] | None`): 详细信息 + +**返回**: `AstrBotError` 实例 + +**特性**: `retryable=False` + +**示例**: + +```python +raise AstrBotError.rate_limited( + hint="每分钟最多调用 5 次", + details={"limit": 5, "window": 60, "remaining": 0} +) +``` + +--- + +### `cooldown_active(*, hint, details)` + +创建冷却中错误。 + +```python +@classmethod +def cooldown_active( + cls, + *, + hint: str, + details: dict[str, Any] | None = None, +) -> AstrBotError +``` + +**参数**: +- `hint` (`str`): 用户提示 +- `details` (`dict[str, Any] | None`): 详细信息 + +**返回**: `AstrBotError` 实例 + +**特性**: `retryable=False` + +**示例**: + +```python +raise AstrBotError.cooldown_active( + hint="技能冷却中,还需等待 25 秒", + details={"cooldown": 30, "remaining": 25} +) +``` + +--- + +## 实例方法 + +### `to_payload()` + +序列化为可传输的字典格式,用于跨进程传递错误信息。 + +```python +def to_payload(self) -> dict[str, object] +``` + +**返回**: `dict[str, object]` - 包含错误信息的字典 + 
+**返回格式**: + +```python +{ + "code": "invalid_input", + "message": "参数格式错误", + "hint": "请使用 JSON 格式", + "retryable": False, + "docs_url": "", + "details": {"expected": "json", "received": "text"} +} +``` + +--- + +### `from_payload(payload)` + +从字典反序列化错误实例。 + +```python +@classmethod +def from_payload(cls, payload: dict[str, object]) -> AstrBotError +``` + +**参数**: +- `payload` (`dict[str, object]`): 包含错误信息的字典 + +**返回**: `AstrBotError` 实例 + +**示例**: + +```python +payload = error.to_payload() +restored_error = AstrBotError.from_payload(payload) +``` + +--- + +### `__str__()` + +返回错误消息。 + +```python +def __str__(self) -> str +``` + +**返回**: `str` - `message` 属性的值 + +--- + +## 使用示例 + +### 基本错误处理 + +```python +from astrbot_sdk import AstrBotError +from astrbot_sdk.errors import ErrorCodes + +@on_command("divide") +async def divide(self, event: MessageEvent, a: int, b: int): + if b == 0: + raise AstrBotError.invalid_input( + "除数不能为零", + hint="请输入非零的除数" + ) + return event.plain_result(f"{a} / {b} = {a / b}") +``` + +### 带详细信息的错误 + +```python +@on_command("search") +async def search(self, event: MessageEvent, keyword: str): + if not keyword or len(keyword.strip()) == 0: + raise AstrBotError.invalid_input( + "搜索关键词不能为空", + hint="请输入要搜索的关键词", + details={ + "field": "keyword", + "constraint": "non_empty", + "provided": keyword + } + ) + # 执行搜索... 
+``` + +### 捕获和处理错误 + +```python +@on_command("risky") +async def risky_operation(self, event: MessageEvent, ctx: Context): + try: + result = await some_network_request() + return event.plain_result(f"成功: {result}") + except AstrBotError as e: + ctx.logger.error(f"操作失败: {e.message}") + if e.retryable: + await event.reply(f"操作失败(可重试): {e.hint or e.message}") + else: + await event.reply(f"操作失败: {e.hint or e.message}") +``` + +### 在插件中处理错误 + +```python +class MyPlugin(Star): + async def on_error(self, error: AstrBotError) -> None: + """统一处理插件中的所有错误""" + if error.code == ErrorCodes.CAPABILITY_NOT_FOUND: + self.logger.error(f"能力未找到: {error.message}") + elif error.code == ErrorCodes.NETWORK_ERROR: + self.logger.warning(f"网络错误: {error.message}") + elif error.retryable: + self.logger.warning(f"可重试错误: {error.code} - {error.message}") + else: + self.logger.error(f"错误: {error.code} - {error.message}") +``` + +### 检查特定错误码 + +```python +try: + await some_capability_call() +except AstrBotError as e: + if e.code == ErrorCodes.RATE_LIMITED: + remaining = (e.details or {}).get("remaining", 0) + await event.reply(f"请求过多,请稍后再试。剩余次数: {remaining}") + elif e.code == ErrorCodes.CAPABILITY_TIMEOUT: + await event.reply("请求超时,请稍后重试") + else: + await event.reply(f"错误: {e.hint or e.message}") +``` + +### 自定义错误(使用通用构造方法) + +```python +# 使用通用构造方法创建自定义错误 +error = AstrBotError( + code="custom_error_code", + message="自定义错误消息", + hint="这是给用户的提示", + retryable=False, + details={"custom_field": "custom_value"} +) +raise error +``` + +--- + +## 最佳实践 + +### 1. 使用工厂方法而非直接构造 + +```python +# 推荐 +raise AstrBotError.invalid_input("参数错误") + +# 不推荐(除非需要自定义错误码) +raise AstrBotError( + code=ErrorCodes.INVALID_INPUT, + message="参数错误", + hint="请检查调用参数" +) +``` + +### 2. 提供用户友好的提示 + +```python +# 推荐 +raise AstrBotError.invalid_input( + "参数 'count' 必须为正整数", + hint="请输入大于 0 的数字" +) + +# 不推荐 +raise AstrBotError.invalid_input("参数错误") +``` + +### 3.
使用 details 提供调试信息 + +```python +raise AstrBotError.invalid_input( + "参数验证失败", + hint="请检查输入格式", + details={ + "field": "email", + "pattern": "^[\\w\\.-]+@[\\w\\.-]+\\.\\w+$", + "provided": "invalid-email" + } +) +``` + +### 4. 区分可重试和不可重试错误 + +```python +# 网络错误 - 可重试 +raise AstrBotError.network_error("连接失败") + +# 参数错误 - 不可重试 +raise AstrBotError.invalid_input("参数类型错误") +``` + +### 5. 在 on_error 中集中处理 + +```python +class MyPlugin(Star): + async def on_error(self, error: AstrBotError) -> None: + # 记录所有错误 + self.logger.error(f"错误: [{error.code}] {error.message}") + + # 可重试错误记录为警告级别 + if error.retryable: + self.logger.warning(f"可重试错误,考虑实现重试逻辑") + + # 特定错误码的特殊处理 + if error.code == ErrorCodes.CAPABILITY_NOT_FOUND: + self.logger.critical("请检查 AstrBot Core 配置") +``` + +### 6. 向用户展示适当的错误信息 + +```python +try: + result = await operation() +except AstrBotError as e: + # 优先使用 hint(面向用户) + user_message = e.hint or e.message + await event.reply(user_message) + + # 记录完整的错误信息(面向开发者) + ctx.logger.error(f"操作失败: {e.code} - {e.message}", extra=e.details) +``` + +--- + +## 相关模块 + +- **事件处理**: `astrbot_sdk.events.MessageEvent` +- **上下文**: `astrbot_sdk.context.Context` +- **插件基类**: `astrbot_sdk.star.Star` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.errors` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/message_components.md b/astrbot-sdk/docs/api/message_components.md new file mode 100644 index 0000000000..3068e6989b --- /dev/null +++ b/astrbot-sdk/docs/api/message_components.md @@ -0,0 +1,948 @@ +# 消息组件 API 完整参考 + +## 概述 + +消息组件是用于构建聊天消息的各种元素。每个组件代表消息中的一种特定内容类型,可以单独使用或组合成消息链。 + +**模块路径**: `astrbot_sdk.message_components` + +--- + +## 目录 + +- [BaseMessageComponent - 基类](#basemessagecomponent---基类) +- [Plain - 纯文本组件](#plain---纯文本组件) +- [At / AtAll - @组件](#at--atall---组件) +- [Image - 图片组件](#image---图片组件) +- [Record - 语音组件](#record---语音组件) +- [Video - 视频组件](#video---视频组件) +- [File - 文件组件](#file---文件组件) +- [Reply - 回复组件](#reply---回复组件) +- [Poke - 戳一戳组件](#poke---戳一戳组件) +- [Forward - 
转发组件](#forward---转发组件) +- [MessageChain - 消息链](#messagechain---消息链) +- [辅助函数](#辅助函数) + +--- + +## 导入方式 + +```python +# 从主模块导入(推荐) +from astrbot_sdk import ( + Plain, At, AtAll, Image, Record, Video, File, Reply, Poke, Forward, + MessageChain, MessageBuilder +) + +# 从子模块导入 +from astrbot_sdk.message_components import ( + Plain, At, AtAll, Image, Record, Video, File, Reply, Poke, Forward +) +from astrbot_sdk.message_result import MessageChain, MessageBuilder + +# 辅助函数 +from astrbot_sdk.message_components import ( + payload_to_component, + component_to_payload_sync, + component_to_payload, +) +``` + +--- + +## BaseMessageComponent - 基类 + +所有消息组件的基类。 + +### 类定义 + +```python +class BaseMessageComponent: + type: str = "unknown" + + def toDict(self) -> dict[str, Any]: + """同步转换为字典 payload""" + + async def to_dict(self) -> dict[str, Any]: + """异步转换为字典 payload""" +``` + +--- + +## Plain - 纯文本组件 + +最简单的消息组件,只包含文本内容。 + +### 类定义 + +```python +class Plain(BaseMessageComponent): + type = "plain" # 序列化时为 "text" + + def __init__(self, text: str, convert: bool = True, **_: Any) -> None: + self.text = text + self.convert = convert +``` + +### 构造方法 + +```python +from astrbot_sdk import Plain + +# 基本用法 +text = Plain("Hello World") + +# 不自动 strip(保留首尾空格) +text = Plain(" Hello ", convert=False) +``` + +### 序列化格式 + +```python +# toDict() 会自动 strip 文本 +{ + "type": "text", + "data": {"text": "Hello World"} +} + +# to_dict() 保留原始文本 +{ + "type": "text", + "data": {"text": " Hello "} +} +``` + +### 使用示例 + +```python +@on_command("echo") +async def echo(self, event: MessageEvent, text: str): + await event.reply_chain([Plain(f"你说: {text}")]) +``` + +--- + +## At / AtAll - @组件 + +用于在消息中提及用户。 + +### At - @某人 + +#### 类定义 + +```python +class At(BaseMessageComponent): + type = "at" + + def __init__(self, qq: int | str, name: str | None = "", **_: Any) -> None: + self.qq = qq + self.name = name or "" +``` + +#### 构造方法 + +```python +from astrbot_sdk import At + +# @ 单个用户 +at = At(123456) +at = 
At("123456", name="张三") +``` + +#### 序列化格式 + +```python +{ + "type": "at", + "data": {"qq": "123456"} +} +``` + +--- + +### AtAll - @全体成员 + +#### 类定义 + +```python +class AtAll(At): + def __init__(self, **_: Any) -> None: + super().__init__(qq="all") +``` + +#### 构造方法 + +```python +from astrbot_sdk import AtAll + +at_all = AtAll() +``` + +#### 序列化格式 + +```python +{ + "type": "at", + "data": {"qq": "all"} +} +``` + +--- + +### 使用示例 + +```python +from astrbot_sdk import At, AtAll, Plain + +@on_command("at_test") +async def at_test(self, event: MessageEvent): + await event.reply_chain([ + Plain("你好 "), + At(event.user_id or "123456"), + Plain("!"), + AtAll(), + Plain("所有人请注意!") + ]) +``` + +--- + +## Image - 图片组件 + +用于在消息中发送图片。 + +### 类定义 + +```python +class Image(BaseMessageComponent): + type = "image" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self._type = kwargs.get("_type", "") + self.subType = kwargs.get("subType", 0) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.id = kwargs.get("id", 40000) + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") + self.file_unique = kwargs.get("file_unique", "") +``` + +### 静态构造方法 + +#### `fromURL(url, **kwargs)` + +从 URL 创建图片。 + +```python +from astrbot_sdk import Image + +img = Image.fromURL("https://example.com/image.jpg") +``` + +#### `fromFileSystem(path, **kwargs)` + +从本地文件系统创建图片。 + +```python +img = Image.fromFileSystem("/path/to/image.jpg") +``` + +#### `fromBase64(base64_data, **kwargs)` + +从 Base64 数据创建图片。 + +```python +img = Image.fromBase64("iVBORw0KGgo...") +``` + +#### `fromBytes(data, **kwargs)` + +从字节数据创建图片。 + +```python +img = Image.fromBytes(b"...") +``` + +### 实例方法 + +#### `convert_to_file_path()` + +将图片转换为本地文件路径(下载或解码)。 + +```python +path = await img.convert_to_file_path() +``` + +#### `register_to_file_service()` + +将图片注册到文件服务,返回可访问 URL。 + +```python +public_url = await img.register_to_file_service() +``` + +### 
支持的格式 + +```python +# URL: "https://example.com/image.jpg" +# 本地文件: "file:///absolute/path/to/image.jpg" +# Base64: "base64://iVBORw0KGgo..." +``` + +### 使用示例 + +```python +from astrbot_sdk import Image + +@on_command("cat") +async def cat(self, event: MessageEvent): + await event.reply_image("https://example.com/cat.jpg") + +@on_command("local_img") +async def local_img(self, event: MessageEvent): + await event.reply_image("file:///path/to/image.jpg") +``` + +--- + +## Record - 语音组件 + +用于在消息中发送语音/音频。 + +### 类定义 + +```python +class Record(BaseMessageComponent): + type = "record" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self.magic = kwargs.get("magic", False) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.proxy = kwargs.get("proxy", True) + self.timeout = kwargs.get("timeout", 0) + self.text = kwargs.get("text") + self.path = kwargs.get("path") +``` + +### 静态构造方法 + +#### `fromFileSystem(path, **kwargs)` + +```python +from astrbot_sdk import Record + +audio = Record.fromFileSystem("/path/to/audio.mp3") +``` + +#### `fromURL(url, **kwargs)` + +```python +audio = Record.fromURL("https://example.com/audio.mp3") +``` + +### 实例方法 + +#### `convert_to_file_path()` + +```python +path = await audio.convert_to_file_path() +``` + +#### `register_to_file_service()` + +```python +public_url = await audio.register_to_file_service() +``` + +--- + +## Video - 视频组件 + +用于在消息中发送视频。 + +### 类定义 + +```python +class Video(BaseMessageComponent): + type = "video" + + def __init__(self, file: str, **kwargs: Any) -> None: + self.file = file + self.cover = kwargs.get("cover", "") + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") +``` + +### 静态构造方法 + +#### `fromFileSystem(path, **kwargs)` + +```python +from astrbot_sdk import Video + +video = Video.fromFileSystem("/path/to/video.mp4") +``` + +#### `fromURL(url, **kwargs)` + +```python +video = Video.fromURL("https://example.com/video.mp4") 
+``` + +--- + +## File - 文件组件 + +用于在消息中发送文件附件。 + +### 类定义 + +```python +class File(BaseMessageComponent): + type = "file" + + def __init__(self, name: str, file: str = "", url: str = "") -> None: + self.name = name + self.file_ = file + self.url = url +``` + +### 属性 + +- `name` (`str`): 文件名 +- `file_` (`str`): 本地文件路径(内部使用) +- `url` (`str`): 文件 URL + +### file 属性 (getter/setter) + +```python +@property +def file(self) -> str: + return self.file_ + +@file.setter +def file(self, value: str) -> None: + if value.startswith(("http://", "https://")): + self.url = value + else: + self.file_ = value +``` + +### 构造方法 + +```python +from astrbot_sdk import File + +# URL 文件 +file1 = File(name="document.pdf", url="https://example.com/doc.pdf") + +# 本地文件 +file2 = File(name="image.jpg", file="/path/to/image.jpg") +``` + +### 实例方法 + +#### `get_file(allow_return_url=False)` + +获取文件路径或 URL。 + +```python +path = await file.get_file() + +# 优先返回 URL +path = await file.get_file(allow_return_url=True) +``` + +#### `register_to_file_service()` + +```python +public_url = await file.register_to_file_service() +``` + +### 序列化格式 + +```python +# toDict() +{ + "type": "file", + "data": { + "name": "文件名.pdf", + "file": "本地路径或URL" + } +} + +# to_dict() +{ + "type": "file", + "data": { + "name": "文件名.pdf", + "file": "优先返回URL,否则本地路径" + } +} +``` + +--- + +## Reply - 回复组件 + +用于回复某条消息。 + +### 类定义 + +```python +class Reply(BaseMessageComponent): + type = "reply" + + def __init__(self, **kwargs: Any) -> None: + self.id = kwargs.get("id", "") + self.chain = _coerce_reply_chain(kwargs.get("chain", [])) + self.sender_id = kwargs.get("sender_id", 0) + self.sender_nickname = kwargs.get("sender_nickname", "") + self.time = kwargs.get("time", 0) + self.message_str = kwargs.get("message_str", "") + self.text = kwargs.get("text", "") + self.qq = kwargs.get("qq", 0) + self.seq = kwargs.get("seq", 0) +``` + +### 构造方法 + +```python +from astrbot_sdk import Reply, Plain + +reply = Reply( + id="msg_123", + 
sender_id="789", + sender_nickname="张三", + chain=[Plain("被回复的消息")] +) +``` + +### 实例方法 + +#### `toDict()` / `to_dict()` + +序列化为字典。 + +--- + +## Poke - 戳一戳组件 + +用于发送戳一戳操作。 + +### 类定义 + +```python +class Poke(BaseMessageComponent): + type = "poke" + + def __init__(self, poke_type: str | int | None = None, **kwargs: Any) -> None: + self._type = str(poke_type) + self.id = kwargs.get("id") + self.qq = kwargs.get("qq", 0) +``` + +### 构造方法 + +```python +from astrbot_sdk import Poke + +poke = Poke(poke_type="126", qq="123456") +``` + +--- + +## Forward - 转发组件 + +用于转发消息。 + +### 类定义 + +```python +class Forward(BaseMessageComponent): + type = "forward" + + def __init__(self, id: str, **_: Any) -> None: + self.id = id +``` + +### 构造方法 + +```python +from astrbot_sdk import Forward + +forward = Forward(id="forward_msg_123") +``` + +--- + +## UnknownComponent - 未知组件 + +用于表示无法识别的组件类型。 + +### 类定义 + +```python +class UnknownComponent(BaseMessageComponent): + type = "unknown" + + def __init__( + self, + *, + raw_type: str = "unknown", + raw_data: dict[str, Any] | None = None, + ) -> None: + self.raw_type = raw_type + self.raw_data = raw_data or {} +``` + +### 构造方法 + +```python +from astrbot_sdk import UnknownComponent + +unknown = UnknownComponent( + raw_type="custom_type", + raw_data={"field": "value"} +) +``` + +### 说明 + +当 `payload_to_component()` 遇到无法识别的组件类型时,会返回 `UnknownComponent` 实例,保留原始数据以便调试。 + +--- + +## MessageChain - 消息链 + +用于组合多个消息组件。 + +### 类定义 + +```python +@dataclass(slots=True) +class MessageChain: + components: list[BaseMessageComponent] = field(default_factory=list) +``` + +### 构造方法 + +```python +from astrbot_sdk.message_result import MessageChain +from astrbot_sdk.message_components import Plain, At + +# 空消息链 +chain = MessageChain() + +# 带初始组件 +chain = MessageChain([Plain("Hello"), At("123456")]) +``` + +### 实例方法 + +#### `append(component)` + +追加单个组件,返回 self 支持链式调用。 + +```python +chain.append(Plain("More text")) +``` + +#### `extend(components)` + +追加多个组件。 + 
+```python +chain.extend([Plain("A"), Plain("B")]) +``` + +#### `to_payload()` + +转换为协议 payload。 + +```python +payload = chain.to_payload() +``` + +#### `get_plain_text(with_other_comps_mark=False)` + +提取纯文本内容。 + +```python +text = chain.get_plain_text() +``` + +--- + +## MessageBuilder - 消息构建器 + +流式构建消息链的工具类。 + +### 使用示例 + +```python +from astrbot_sdk.message_result import MessageBuilder + +chain = (MessageBuilder() + .text("Hello ") + .at("123456") + .text("!\n") + .image("https://example.com/img.jpg") + .build()) + +await event.reply_chain(chain) +``` + +### 可用方法 + +- `.text(content)` - 添加文本 +- `.at(user_id)` - 添加@用户 +- `.at_all()` - 添加@全体成员 +- `.image(url)` - 添加图片 +- `.record(url)` - 添加语音 +- `.video(url)` - 添加视频 +- `.file(name, url=...)` - 添加文件 +- `.build()` - 构建消息链 + +--- + +## 辅助函数 + +### `payload_to_component(payload)` + +将协议 payload 转换为消息组件。 + +```python +from astrbot_sdk.message_components import payload_to_component + +component = payload_to_component(payload) +``` + +### `component_to_payload_sync(component)` + +将组件同步转换为 payload。 + +```python +from astrbot_sdk.message_components import component_to_payload_sync + +payload = component_to_payload_sync(component) +``` + +### `component_to_payload(component)` + +将组件异步转换为 payload。 + +```python +from astrbot_sdk.message_components import component_to_payload + +payload = await component_to_payload(component) +``` + +--- + +### `is_message_component(value)` + +检查值是否为消息组件。 + +```python +from astrbot_sdk.message_components import is_message_component + +if is_message_component(value): + print("是消息组件") +``` + +--- + +### `payloads_to_components(payloads)` + +批量将 payload 列表转换为组件列表。 + +```python +from astrbot_sdk.message_components import payloads_to_components + +components = payloads_to_components(payload_list) +``` + +--- + +### `build_media_component_from_url(url, *, kind)` + +从 URL 构建媒体组件。 + +```python +from astrbot_sdk.message_components import build_media_component_from_url + +# 自动识别类型 +component = 
build_media_component_from_url("https://example.com/image.jpg") + +# 指定类型 +component = build_media_component_from_url("https://example.com/file", kind="image") +``` + +--- + +## MediaHelper - 媒体辅助类 + +提供媒体处理的静态方法。 + +### `from_url(url, *, kind)` + +从 URL 创建媒体组件。 + +**签名**: +```python +@staticmethod +async def from_url( + url: str, + *, + kind: str = "auto" +) -> BaseMessageComponent +``` + +**参数**: +- `url`: 媒体 URL +- `kind`: 媒体类型(`"auto"`, `"image"`, `"record"`, `"video"`, `"file"`) + +**返回**: 对应的媒体组件 + +**示例**: + +```python +from astrbot_sdk.message_components import MediaHelper + +# 自动识别 +img = await MediaHelper.from_url("https://example.com/photo.jpg") + +# 指定类型 +video = await MediaHelper.from_url("https://example.com/video.mp4", kind="video") +``` + +--- + +### `download(url, save_dir)` + +下载媒体文件到指定目录。 + +**签名**: +```python +@staticmethod +async def download(url: str, save_dir: Path) -> Path +``` + +**参数**: +- `url`: 媒体 URL(仅支持 http/https) +- `save_dir`: 保存目录路径 + +**返回**: `Path` - 下载后的文件路径 + +**异常**: +- `AstrBotError`: 下载失败时抛出 + +**示例**: + +```python +from pathlib import Path +from astrbot_sdk.message_components import MediaHelper + +try: + path = await MediaHelper.download( + "https://example.com/image.jpg", + Path("./downloads") + ) + print(f"下载到: {path}") +except AstrBotError as e: + print(f"下载失败: {e.message}") +``` + +--- + +## 使用示例 + +### 处理图片消息 + +```python +@on_message() +async def save_image(self, event: MessageEvent): + images = event.get_images() + if not images: + await event.reply("消息中没有图片") + return + + for img in images: + try: + path = await img.convert_to_file_path() + # 保存图片... 
+ await event.reply(f"已保存: {path}") + except Exception as e: + await event.reply(f"保存失败: {e}") +``` + +### 检测@和群聊/私聊 + +```python +@on_command("check") +async def check(self, event: MessageEvent): + # 检查是否群聊 + if event.is_group_chat(): + await event.reply("这是群聊消息") + elif event.is_private_chat(): + await event.reply("这是私聊消息") + + # 检查@的用户 + at_users = event.get_at_users() + if at_users: + await event.reply(f"你@了: {', '.join(at_users)}") +``` + +### 返回富文本结果 + +```python +@on_command("info") +async def info(self, event: MessageEvent): + return event.chain_result([ + Plain(f"用户: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}\n"), + Plain(f"平台: {event.platform}"), + ]) +``` + +--- + +## 注意事项 + +1. **序列化差异**: + - `Plain.toDict()` 会 strip 文本 + - `Plain.to_dict()` 保留原始文本 + - `File.toDict()` 和 `to_dict()` 对 file 字段处理不同 + +2. **路径格式**: + - 本地文件: `file:///absolute/path` (Windows 下特殊处理) + - URL: `http://` 或 `https://` + - Base64: `base64://` + +3. **文件下载**: + - `convert_to_file_path()` 会下载网络文件到临时目录 + - `register_to_file_service()` 需要运行时上下文 + +4. 
**兼容性**: + - `At` 和 `AtAll` 序列化后的 type 都是 "at" + - `Reply` 的 chain 字段在序列化时递归处理 + +--- + +## 相关模块 + +- **消息组件**: `astrbot_sdk.message_components` +- **消息链**: `astrbot_sdk.message_result.MessageChain` +- **消息构建器**: `astrbot_sdk.message_result.MessageBuilder` +- **协议描述符**: `astrbot_sdk.protocol.descriptors` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.message_components` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/message_event.md b/astrbot-sdk/docs/api/message_event.md new file mode 100644 index 0000000000..a71b564f4a --- /dev/null +++ b/astrbot-sdk/docs/api/message_event.md @@ -0,0 +1,1171 @@ +# MessageEvent 类 - 消息事件对象完整参考 + +## 概述 + +`MessageEvent` 表示接收到的聊天消息事件,包含消息的所有信息(发送者、内容、组件等)和响应方法。当用户发送消息时,AstrBot 会创建一个 `MessageEvent` 实例并传递给插件的事件处理器。 + +**模块路径**: `astrbot_sdk.events.MessageEvent` + +--- + +## 类定义 + +```python +class MessageEvent: + # 基本属性 + text: str # 消息文本内容 + user_id: str | None # 发送者用户 ID + group_id: str | None # 群组 ID(私聊时为 None) + platform: str | None # 平台标识(如 "qq", "wechat") + session_id: str # 会话 ID + self_id: str # 机器人账号 ID + platform_id: str # 平台实例标识 + message_type: str # 消息类型("private" 或 "group") + sender_name: str # 发送者昵称 + raw: dict[str, Any] # 原始消息数据(协议层 payload) + context: Context | None # 运行时上下文 +``` + +--- + +## 导入方式 + +```python +# 从主模块导入(推荐) +from astrbot_sdk import MessageEvent + +# 从子模块导入 +from astrbot_sdk.events import MessageEvent + +# 常用配套导入 +from astrbot_sdk import Context # 上下文对象 +from astrbot_sdk.decorators import on_command, on_message # 装饰器 +``` + +--- + +## 基本属性 + +### 消息内容属性 + +#### `text` + +消息的纯文本内容。 + +```python +# 类型: str +# 说明: 提取消息中的纯文本部分 + +@on_message() +async def handler(self, event: MessageEvent): + print(f"收到消息: {event.text}") +``` + +**注意**: 此属性只包含文本部分,不包含图片、@等其他组件的内容。 + +--- + +### 发送者属性 + +#### `user_id` + +发送者的用户 ID。 + +```python +# 类型: str | None +# 说明: 发送者的唯一标识符 + +@on_command("whoami") +async def whoami(self, event: MessageEvent): + await event.reply(f"你的 ID 是: {event.user_id}") +``` + +#### 
`sender_name` + +发送者的昵称。 + +```python +# 类型: str +# 说明: 发送者的显示名称 + +@on_command("greet") +async def greet(self, event: MessageEvent): + await event.reply(f"你好,{event.sender_name}!") +``` + +--- + +### 会话属性 + +#### `session_id` + +当前会话的唯一标识符。 + +```python +# 类型: str +# 说明: 群聊时为 group_id,私聊时为 user_id + +@on_command("session") +async def session(self, event: MessageEvent): + await event.reply(f"当前会话: {event.session_id}") +``` + +#### `group_id` + +群组 ID(仅在群聊消息中有值)。 + +```python +# 类型: str | None +# 说明: 私聊时为 None + +@on_command("check_group") +async def check_group(self, event: MessageEvent): + if event.group_id: + await event.reply(f"群组 ID: {event.group_id}") + else: + await event.reply("这是私聊消息") +``` + +#### `message_type` + +消息类型。 + +```python +# 类型: str +# 说明: "private"(私聊)或 "group"(群聊) + +@on_command("type") +async def msg_type(self, event: MessageEvent): + await event.reply(f"消息类型: {event.message_type}") +``` + +--- + +### 平台属性 + +#### `platform` + +平台标识。 + +```python +# 类型: str | None +# 说明: 如 "qq", "wechat", "telegram" 等 + +@on_command("platform") +async def platform(self, event: MessageEvent): + await event.reply(f"来自平台: {event.platform}") +``` + +#### `platform_id` + +平台实例标识。 + +```python +# 类型: str +# 说明: 同一平台可能有多个实例(如多个 QQ 账号) + +@on_command("platform_id") +async def platform_id(self, event: MessageEvent): + await event.reply(f"平台实例: {event.platform_id}") +``` + +#### `self_id` + +机器人自己的 ID。 + +```python +# 类型: str +# 说明: 当前机器人账号在平台上的 ID + +@on_command("bot_id") +async def bot_id(self, event: MessageEvent): + await event.reply(f"机器人 ID: {event.self_id}") +``` + +--- + +### 原始数据属性 + +#### `raw` + +原始消息数据(协议层 payload)。 + +```python +# 类型: dict[str, Any] +# 说明: 包含完整的原始消息数据 + +@on_command("raw") +async def raw(self, event: MessageEvent): + # 访问原始数据 + raw_data = event.raw + print(f"原始数据: {raw_data}") +``` + +**注意**: 此属性包含完整的协议层数据,格式可能因平台而异。 + +--- + +## 消息组件访问方法 + +### `get_messages()` + +获取当前事件的所有 SDK 消息组件。 + +```python +def get_messages(self) -> 
list[BaseMessageComponent]: + """Return SDK message components for the current event.""" +``` + +**返回**: 消息组件列表 + +**示例**: + +```python +@on_command("analyze") +async def analyze(self, event: MessageEvent): + components = event.get_messages() + for comp in components: + print(f"组件类型: {comp.type}") +``` + +--- + +### `has_component(type_)` + +检查是否包含特定类型的组件。 + +```python +def has_component(self, type_: type[BaseMessageComponent]) -> bool +``` + +**参数**: +- `type_`: 组件类型(如 `Image`, `At`, `File`) + +**返回**: `bool` - 是否包含该类型组件 + +**示例**: + +```python +@on_command("has_img") +async def has_img(self, event: MessageEvent): + if event.has_component(Image): + await event.reply("消息包含图片") + else: + await event.reply("消息不包含图片") +``` + +--- + +### `get_components(type_)` + +获取特定类型的所有组件。 + +```python +def get_components(self, type_: type[BaseMessageComponent]) -> list[BaseMessageComponent] +``` + +**参数**: +- `type_`: 组件类型 + +**返回**: 匹配的组件列表 + +**示例**: + +```python +@on_command("list_at") +async def list_at(self, event: MessageEvent): + at_comps = event.get_components(At) + for at in at_comps: + await event.reply(f"@了用户: {at.qq}") +``` + +--- + +### `get_images()` + +获取所有图片组件的便捷方法。 + +```python +def get_images(self) -> list[Image] +``` + +**返回**: 图片组件列表 + +**示例**: + +```python +@on_message(keywords=["保存图片"]) +async def save_images(self, event: MessageEvent): + images = event.get_images() + if not images: + await event.reply("消息中没有图片") + return + + saved_paths = [] + for img in images: + try: + local_path = await img.convert_to_file_path() + saved_paths.append(local_path) + except Exception as e: + await event.reply(f"保存失败: {e}") + return + + await event.reply(f"已保存 {len(saved_paths)} 张图片") +``` + +--- + +### `get_files()` + +获取所有文件组件的便捷方法。 + +```python +def get_files(self) -> list[File] +``` + +**返回**: 文件组件列表 + +**示例**: + +```python +@on_message(keywords=["文件"]) +async def handle_files(self, event: MessageEvent): + files = event.get_files() + for file in files: + await 
event.reply(f"收到文件: {file.name}") +``` + +--- + +### `extract_plain_text()` + +提取所有 Plain 组件的文本内容。 + +```python +def extract_plain_text(self) -> str +``` + +**返回**: 纯文本内容(拼接所有 Plain 组件) + +**注意**: 这会移除所有非文本组件(图片、@等),仅拼接纯文本。 + +**示例**: + +```python +@on_command("gettext") +async def get_text(self, event: MessageEvent): + text = event.extract_plain_text() + await event.reply(f"纯文本内容: {text}") +``` + +--- + +### `get_at_users()` + +获取消息中所有被@的用户ID列表(不包括 @全体成员)。 + +```python +def get_at_users(self) -> list[str] +``` + +**返回**: 被@的用户 ID 列表 + +**示例**: + +```python +@on_command("who_at") +async def who_at(self, event: MessageEvent): + at_users = event.get_at_users() + if at_users: + await event.reply(f"你@了这些用户: {', '.join(at_users)}") + else: + await event.reply("你没有@任何人") +``` + +--- + +## 会话与平台信息方法 + +### `is_private_chat()` / `is_group_chat()` + +判断消息类型。 + +```python +def is_private_chat(self) -> bool +def is_group_chat(self) -> bool +``` + +**返回**: `bool` - 是否为对应类型 + +**示例**: + +```python +@on_command("check") +async def check(self, event: MessageEvent): + if event.is_group_chat(): + await event.reply("这是群聊消息") + # 获取群组信息 + group_info = await event.get_group() + if group_info: + await event.reply(f"群名: {group_info.get('name')}") + elif event.is_private_chat(): + await event.reply("这是私聊消息") +``` + +--- + +### `is_admin()` + +判断发送者是否有管理员权限。 + +```python +def is_admin(self) -> bool +``` + +**返回**: `bool` - 是否为管理员 + +**示例**: + +```python +@on_command("admin_check") +async def admin_check(self, event: MessageEvent): + if event.is_admin(): + await event.reply("你是管理员") + else: + await event.reply("你不是管理员") +``` + +--- + +### `get_group()` + +获取当前群组元数据(仅群聊有效)。 + +```python +async def get_group(self) -> dict[str, Any] | None +``` + +**返回**: 群组信息字典,失败返回 None + +**示例**: + +```python +@on_command("group_info") +async def group_info(self, event: MessageEvent): + if not event.is_group_chat(): + await event.reply("这不是群聊消息") + return + + group_info = await event.get_group() + if 
group_info:
+        await event.reply(f"群名: {group_info.get('name')}")
+```
+
+---
+
+## 回复与发送方法
+
+### `reply(text)`
+
+回复纯文本消息。
+
+```python
+async def reply(self, text: str) -> None
+```
+
+**参数**:
+- `text`: 要回复的文本内容
+
+**异常**:
+- `RuntimeError`: 如果未绑定 reply handler
+
+**示例**:
+
+```python
+@on_command("hello")
+async def hello(self, event: MessageEvent):
+    await event.reply("Hello, World!")
+```
+
+---
+
+### `reply_image(image_url)`
+
+回复图片消息。
+
+```python
+async def reply_image(self, image_url: str) -> None
+```
+
+**参数**:
+- `image_url`: 图片 URL
+
+**支持格式**:
+- URL: `https://example.com/image.jpg`
+- 本地文件: `file:///absolute/path/to/image.jpg`
+- Base64: `base64://iVBORw0KGgo...`
+
+**示例**:
+
+```python
+@on_command("cat")
+async def cat(self, event: MessageEvent):
+    await event.reply_image("https://example.com/cat.jpg")
+
+@on_command("local_img")
+async def local_img(self, event: MessageEvent):
+    await event.reply_image("file:///path/to/local/image.jpg")
+```
+
+---
+
+### `reply_chain(chain)`
+
+回复消息链(多类型消息组合)。
+
+```python
+async def reply_chain(
+    self,
+    chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]]
+) -> None
+```
+
+**参数**:
+- `chain`: 消息链组件列表
+
+**示例**:
+
+```python
+from astrbot_sdk.message_components import Plain, At, Image
+from astrbot_sdk.message_result import MessageChain
+
+@on_command("rich")
+async def rich(self, event: MessageEvent):
+    # 方式1: 使用 MessageChain
+    chain = MessageChain([
+        Plain("Hello "),
+        At("123456"),
+        Plain("!"),
+        Image.fromURL("https://example.com/img.jpg")
+    ])
+    await event.reply_chain(chain)
+
+    # 方式2: 直接传递组件列表
+    await event.reply_chain([
+        Plain("文本"),
+        Image.fromURL("url")
+    ])
+```
+
+---
+
+### `react(emoji)`
+
+发送表情反应(如果平台支持)。
+
+```python
+async def react(self, emoji: str) -> bool
+```
+
+**参数**:
+- `emoji`: emoji 表情
+
+**返回**: `bool` - 是否平台支持并成功发送
+
+**示例**:
+
+```python
+@on_command("react")
+async def react_cmd(self, event: MessageEvent):
+    supported = await event.react("👍")
+    if not supported:
+        await event.reply("该平台不支持表情反应")
+```
+
+---
+ +### `send_typing()` + +发送正在输入状态(如果平台支持)。 + +```python +async def send_typing(self) -> bool +``` + +**返回**: `bool` - 是否平台支持并成功发送 + +--- + +### `send_streaming(generator, use_fallback=False)` + +发送流式消息。 + +```python +async def send_streaming( + self, + generator, + use_fallback: bool = False +) -> bool +``` + +**参数**: +- `generator`: 异步生成器 +- `use_fallback`: 是否使用降级模式 + +**示例**: + +```python +@on_command("stream") +async def stream_cmd(self, event: MessageEvent): + async def text_gen(): + parts = ["正在", "处理", "你的", "请求", "..."] + for part in parts: + yield part + await asyncio.sleep(0.5) + + success = await event.send_streaming(text_gen()) + if not success: + await event.reply("不支持流式消息") +``` + +--- + +## 事件控制方法 + +### `stop_event()` + +标记事件为已停止,阻止后续处理器执行。 + +```python +def stop_event(self) -> None +``` + +**示例**: + +```python +@on_command("admin") +@require_admin +async def admin_cmd(self, event: MessageEvent): + await event.reply("管理员操作已执行") + event.stop_event() # 阻止后续处理器 + +@on_command("public") +async def public_cmd(self, event: MessageEvent): + # 如果事件被停止,不会执行 + await event.reply("这是公共命令") +``` + +--- + +### `continue_event()` + +清除停止标记。 + +```python +def continue_event(self) -> None +``` + +--- + +### `is_stopped()` + +检查事件是否已停止。 + +```python +def is_stopped(self) -> bool +``` + +--- + +## Extra 数据管理 + +### `set_extra(key, value)` + +存储 SDK 本地的临时事件数据。 + +```python +def set_extra(self, key: str, value: Any) -> None +``` + +**参数**: +- `key`: 键名 +- `value`: 值 + +**示例**: + +```python +# 存储数据 +event.set_extra("custom_flag", True) +event.set_extra("temp_data", {"count": 5}) +``` + +> 请求范围内的 SDK hooks 会保留 JSON-safe 的本地 extras。不可 JSON 序列化的值只在当前 handler 内可见,不会自动带到后续 hook。 + +--- + +### `get_extra(key, default)` + +读取 SDK 本地临时事件数据。 + +```python +def get_extra(self, key: str | None = None, default: Any = None) -> Any +``` + +**参数**: +- `key`: 键名,None 时返回全部 extras +- `default`: 默认值 + +**示例**: + +```python +# 读取单个值 +flag = event.get_extra("custom_flag", False) + +# 读取全部 
+all_extras = event.get_extra() +``` + +--- + +### `clear_extra()` + +清除所有 extra 数据。 + +```python +def clear_extra(self) -> None +``` + +--- + +## 结果构建方法 + +### `plain_result(text)` + +创建纯文本结果对象。 + +```python +def plain_result(self, text: str) -> PlainTextResult +``` + +**示例**: + +```python +@on_command("test") +async def test(self, event: MessageEvent): + return event.plain_result("返回内容") +``` + +--- + +### `image_result(url_or_path)` + +创建包含单个图片的链结果。 + +```python +def image_result(self, url_or_path: str) -> MessageEventResult +``` + +**参数**: +- `url_or_path`: URL 或本地路径 + +**支持格式**: +- URL: `https://example.com/image.jpg` +- 本地路径: `/path/to/image.jpg` +- Base64: `base64://iVBORw0KGgo...` + +**示例**: + +```python +@on_command("avatar") +async def avatar(self, event: MessageEvent): + return event.image_result("https://example.com/avatar.jpg") +``` + +--- + +### `chain_result(chain)` + +从 SDK 组件创建链结果。 + +```python +def chain_result( + self, + chain: MessageChain | list[BaseMessageComponent] +) -> MessageEventResult +``` + +**示例**: + +```python +@on_command("info") +async def info(self, event: MessageEvent): + return event.chain_result([ + Plain(f"用户: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}") + ]) +``` + +--- + +### `make_result()` + +创建空的 SDK 结果包装器。 + +```python +def make_result(self) -> MessageEventResult +``` + +--- + +## 序列化与反序列化 + +### `from_payload()` + +从协议载荷创建事件实例(类方法)。 + +**签名**: +```python +@classmethod +def from_payload( + cls, + payload: dict[str, Any], + *, + context: Context | None = None, + reply_handler: ReplyHandler | None = None +) -> MessageEvent +``` + +**参数**: +- `payload`: 协议层传递的消息数据字典 +- `context`: 运行时上下文 +- `reply_handler`: 自定义回复处理器 + +**返回**: `MessageEvent` 实例 + +--- + +### `to_payload()` + +转换为协议载荷格式。 + +**签名**: +```python +def to_payload(self) -> dict[str, Any] +``` + +**返回**: 可序列化的字典 + +--- + +## 会话引用属性 + +### `session_ref` + +获取会话引用对象。 + +**类型**: `SessionRef | None` + +**说明**: 包含会话的完整信息,用于跨平台通信。 + +--- + +### `target` + 
+`session_ref` 的别名。 + +**类型**: `SessionRef | None` + +--- + +### `unified_msg_origin` + +统一消息来源标识符。 + +**类型**: `str` + +**说明**: 等同于 `session_id`。 + +--- + +## LLM 相关方法 + +### `request_llm()` + +请求触发默认 LLM 链处理当前消息。 + +**签名**: +```python +async def request_llm(self) -> bool +``` + +**返回**: `bool` - 是否应该调用 LLM + +**示例**: + +```python +@on_command("ask") +async def ask(self, event: MessageEvent): + should_call = await event.request_llm() + if should_call: + await event.reply("已触发 LLM 处理") +``` + +--- + +### `should_call_llm()` + +读取当前默认 LLM 决策状态。 + +**签名**: +```python +async def should_call_llm(self) -> bool +``` + +**返回**: `bool` - 是否应该调用 LLM + +**示例**: + +```python +@on_message() +async def handle(self, event: MessageEvent): + if await event.should_call_llm(): + response = await ctx.llm.chat(event.text) + await event.reply(response) +``` + +--- + +## 结果管理方法 + +### `set_result()` + +存储请求范围的 SDK 结果到主机桥。 + +**签名**: +```python +async def set_result(self, result: MessageEventResult) -> MessageEventResult +``` + +**参数**: +- `result`: 消息事件结果对象 + +**返回**: 传入的 `result` 对象 + +**示例**: + +```python +result = event.chain_result([Plain("处理结果")]) +await event.set_result(result) +``` + +--- + +### `get_result()` + +从主机桥读取当前请求范围的 SDK 结果。 + +**签名**: +```python +async def get_result(self) -> MessageEventResult | None +``` + +**返回**: `MessageEventResult | None` - 结果对象,不存在则返回 None + +--- + +### `clear_result()` + +清除当前请求范围的 SDK 结果。 + +**签名**: +```python +async def clear_result(self) -> None +``` + +--- + +## 其他方法 + +### `get_message_outline()` + +获取规范化的消息摘要。 + +**签名**: +```python +def get_message_outline(self) -> str +``` + +**返回**: 消息摘要文本 + +--- + +### `get_sent_message_outline()` + +获取 `after_message_sent` 事件里的实际发送摘要文本。 + +**签名**: +```python +def get_sent_message_outline(self) -> str +``` + +**返回**: 机器人实际发送的摘要文本;非发送后事件通常返回空字符串 + +--- + +### `get_sent_messages()` + +获取 `after_message_sent` 事件里的实际发送消息组件。 + +**签名**: +```python +def get_sent_messages(self) -> list[BaseMessageComponent] 
+``` + +**返回**: 机器人实际发送的消息组件列表;非发送后事件通常返回空列表 + +--- + +### `bind_reply_handler()` + +绑定自定义回复处理器。 + +**签名**: +```python +def bind_reply_handler(self, reply_handler: ReplyHandler) -> None +``` + +**参数**: +- `reply_handler`: 回复处理函数,接收文本参数 + +**示例**: + +```python +def custom_reply(text: str): + print(f"回复: {text}") + +event.bind_reply_handler(custom_reply) +await event.reply("测试") # 会调用 custom_reply +``` + +--- + +## 完整使用示例 + +### 示例 1: 基础消息处理 + +```python +from astrbot_sdk.decorators import on_command, on_message + +@on_command("hello") +async def hello(self, event: MessageEvent, ctx: Context): + await event.reply(f"你好,{event.sender_name}!") + +@on_message(keywords=["帮助"]) +async def help(self, event: MessageEvent, ctx: Context): + await event.reply("可用命令: /hello") +``` + +--- + +### 示例 2: 处理图片消息 + +```python +@on_message(regex="^保存图片$") +async def save_image(self, event: MessageEvent): + images = event.get_images() + if not images: + await event.reply("消息中没有图片") + return + + for img in images: + try: + local_path = await img.convert_to_file_path() + # 保存图片... 
+ await event.reply(f"已保存: {local_path}") + except Exception as e: + await event.reply(f"保存失败: {e}") +``` + +--- + +### 示例 3: 检测@和群聊/私聊 + +```python +@on_command("check") +async def check(self, event: MessageEvent): + # 检查是否群聊 + if event.is_group_chat(): + await event.reply("这是群聊消息") + elif event.is_private_chat(): + await event.reply("这是私聊消息") + + # 检查@的用户 + at_users = event.get_at_users() + if at_users: + await event.reply(f"你@了: {', '.join(at_users)}") + + # 检查是否包含图片 + if event.has_component(Image): + await event.reply("消息包含图片") +``` + +--- + +### 示例 4: 返回富文本结果 + +```python +@on_command("info") +async def info(self, event: MessageEvent): + return event.chain_result([ + Plain(f"用户: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}\n"), + Plain(f"平台: {event.platform}"), + ]) +``` + +--- + +### 示例 5: 事件控制 + +```python +@on_command("admin") +@require_admin +async def admin(self, event: MessageEvent): + await event.reply("管理员操作已执行") + event.stop_event() # 阻止后续处理器 + +@on_command("public") +async def public(self, event: MessageEvent): + # 如果事件被停止,不会执行 + await event.reply("这是公共命令") +``` + +--- + +## 注意事项 + +1. **必须绑定上下文**: 某些方法(如 `reply_image`, `reply_chain`, `get_group`)需要运行时上下文,未绑定时会抛出 `RuntimeError` + +2. **私有/群聊判断**: + - `is_private_chat()` 和 `is_group_chat()` 优先使用 `message_type` 字段 + - 其次通过 `group_id` 是否为 None 判断 + +3. **Extra 数据**: `_extras` 是 SDK 本地的,不会传递到核心,适合存储插件级别的临时状态 + +4. **事件停止**: `stop_event()` 只在 SDK 层面标记,不同处理器可能有不同的行为 + +5. 
**消息组件解析**: `get_messages()` 返回 SDK 组件列表,`extract_plain_text()` 只提取 Plain 组件 + +--- + +## 相关模块 + +- **消息组件**: `astrbot_sdk.message_components` - 所有消息组件类 +- **消息链**: `astrbot_sdk.message_result.MessageChain` - 消息链类 +- **消息构建器**: `astrbot_sdk.message_result.MessageBuilder` - 流式消息构建器 +- **会话引用**: `astrbot_sdk.protocol.descriptors.SessionRef` - 会话引用对象 + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.events.MessageEvent` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/message_result.md b/astrbot-sdk/docs/api/message_result.md new file mode 100644 index 0000000000..fa3c1cb0bd --- /dev/null +++ b/astrbot-sdk/docs/api/message_result.md @@ -0,0 +1,728 @@ +# 消息结果 API 完整参考 + +## 概述 + +消息结果是用于构建和返回消息结果的类,包括消息链容器、流式构建器和事件结果包装器。 + +**模块路径**: `astrbot_sdk.message_result` + +--- + +## 目录 + +- [EventResultType - 事件结果类型枚举](#eventresulttype---事件结果类型枚举) +- [MessageChain - 消息链](#messagechain---消息链) +- [MessageBuilder - 消息构建器](#messagebuilder---消息构建器) +- [MessageEventResult - 消息事件结果](#messageeventresult---消息事件结果) + +--- + +## 导入方式 + +```python +# 从主模块导入 +from astrbot_sdk import MessageChain, MessageBuilder, MessageEventResult + +# 从子模块导入 +from astrbot_sdk.message_result import ( + MessageChain, + MessageBuilder, + MessageEventResult, + EventResultType, +) + +# 消息组件(用于构建消息链) +from astrbot_sdk.message_components import Plain, At, Image, File +``` + +--- + +## EventResultType - 事件结果类型枚举 + +事件结果的类型枚举,定义消息结果的类型。 + +### 定义 + +```python +class EventResultType(str, Enum): + EMPTY = "empty" # 空结果 + CHAIN = "chain" # 消息链结果 + PLAIN = "plain" # 纯文本结果 +``` + +### 值说明 + +| 值 | 说明 | +|------|------| +| `EventResultType.EMPTY` | 空结果,不返回任何内容 | +| `EventResultType.CHAIN` | 消息链结果,返回一个或多个消息组件 | +| `EventResultType.PLAIN` | 纯文本结果,返回文本内容 | + +--- + +## MessageChain - 消息链 + +消息链是消息组件的容器,用于组合多个组件形成复杂的消息。 + +### 类定义 + +```python +@dataclass(slots=True) +class MessageChain: + components: list[BaseMessageComponent] = field(default_factory=list) +``` + +### 构造方法 + +#### 空消息链 + +```python +from 
astrbot_sdk.message_result import MessageChain + +chain = MessageChain() +``` + +#### 带初始组件 + +```python +from astrbot_sdk.message_result import MessageChain +from astrbot_sdk.message_components import Plain, At + +chain = MessageChain([ + Plain("Hello"), + At("123456") +]) +``` + +### 实例方法 + +#### `append(component)` + +追加单个组件,返回 self 支持链式调用。 + +```python +def append(self, component: BaseMessageComponent) -> MessageChain: + """追加单个组件,返回 self""" + self.components.append(component) + return self +``` + +**参数**: +- `component` (`BaseMessageComponent`): 要追加的组件 + +**返回**: `MessageChain` - self + +**示例**: + +```python +chain = MessageChain() +chain.append(Plain("Hello ")) + .append(At("123456")) + .append(Plain("!")) +``` + +--- + +#### `extend(components)` + +追加多个组件,返回 self。 + +```python +def extend(self, components: list[BaseMessageComponent]) -> MessageChain: + """追加多个组件,返回 self""" + self.components.extend(components) + return self +``` + +**参数**: +- `components` (`list[BaseMessageComponent]`): 组件列表 + +**示例**: + +```python +chain = MessageChain() +chain.extend([ + Plain("A"), + Plain("B"), + Plain("C") +]) +``` + +--- + +#### `to_payload()` + +同步转换为协议 payload。 + +```python +def to_payload(self) -> list[dict[str, Any]]: + """转换为协议 payload""" + return [component_to_payload_sync(c) for c in self.components] +``` + +**返回**: `list[dict]` - 可序列化的字典列表 + +--- + +#### `to_payload_async()` + +异步转换为协议 payload。 + +```python +async def to_payload_async(self) -> list[dict[str, Any]]: + """异步转换为协议 payload""" + return [await component_to_payload(c) for c in self.components] +``` + +**注意**: 某些组件(如 Reply)的异步序列化可能包含额外逻辑 + +--- + +#### `get_plain_text(with_other_comps_mark=False)` + +提取纯文本内容。 + +```python +def get_plain_text(self, with_other_comps_mark: bool = False) -> str: + """提取纯文本内容""" + texts: list[str] = [] + for component in self.components: + if isinstance(component, Plain): + texts.append(component.text) + elif with_other_comps_mark: + 
texts.append(f"[{component.__class__.__name__}]") + return " ".join(texts) +``` + +**参数**: +- `with_other_comps_mark`: 是否为非文本组件显示类型标记 + +**返回**: `str` - 纯文本内容 + +**示例**: + +```python +chain = MessageChain([ + Plain("Hello "), + At("123456"), + Plain("!") +]) + +chain.get_plain_text() # "Hello !" +chain.get_plain_text(True) # "Hello [At] !" +``` + +--- + +#### `plain_text(with_other_comps_mark=False)` + +`get_plain_text()` 的别名。 + +```python +def plain_text(self, with_other_comps_mark: bool = False) -> str: + return self.get_plain_text(with_other_comps_mark=with_other_comps_mark) +``` + +--- + +### 迭代与长度 + +```python +# 迭代 +for component in chain: + print(f"组件: {component.__class__.__name__}") + +# 长度 +len(chain) # 组件数量 +``` + +--- + +### 使用示例 + +```python +from astrbot_sdk.message_result import MessageChain +from astrbot_sdk.message_components import Plain, At, Image + +# 创建并使用 +chain = MessageChain([ + Plain("Hello "), + At("123456"), + Plain("!"), + Image.fromURL("https://example.com/img.jpg") +]) + +# 转换为 payload +payload = chain.to_payload() + +# 提取文本 +text = chain.get_plain_text() + +# 链式追加 +chain.append(Plain("More text")) +``` + +--- + +## MessageBuilder - 消息构建器 + +流式构建消息链的工具类,提供流畅的 API。 + +### 类定义 + +```python +@dataclass(slots=True) +class MessageBuilder: + components: list[BaseMessageComponent] = field(default_factory=list) +``` + +### 链式方法 + +所有方法都返回 `self`,支持链式调用。 + +#### `text(content)` + +添加文本组件。 + +```python +def text(self, content: str) -> MessageBuilder: + """添加文本组件""" + self.components.append(Plain(content, convert=False)) + return self +``` + +**示例**: + +```python +builder = MessageBuilder() +builder.text("Hello ") +``` + +--- + +#### `at(user_id)` + +添加@组件。 + +```python +def at(self, user_id: str) -> MessageBuilder: + """添加@用户""" + self.components.append(At(user_id)) + return self +``` + +--- + +#### `at_all()` + +添加@全体成员。 + +```python +def at_all(self) -> MessageBuilder: + """添加@全体成员""" + self.components.append(AtAll()) + return self +``` + +--- 
+ +#### `image(url)` + +添加图片。 + +```python +def image(self, url: str) -> MessageBuilder: + """添加图片""" + self.components.append(Image.fromURL(url)) + return self +``` + +--- + +#### `record(url)` + +添加语音。 + +```python +def record(self, url: str) -> MessageBuilder: + """添加语音""" + self.components.append(Record.fromURL(url)) + return self +``` + +--- + +#### `video(url)` + +添加视频。 + +```python +def video(self, url: str) -> MessageBuilder: + """添加视频""" + self.components.append(Video.fromURL(url)) + return self +``` + +--- + +#### `file(name, *, file="", url="")` + +添加文件。 + +```python +def file(self, name: str, *, file: str = "", url: str = "") -> MessageBuilder: + """添加文件""" + self.components.append(File(name=name, file=file, url=url)) + return self +``` + +--- + +#### `reply(**kwargs)` + +添加回复组件。 + +```python +def reply(self, **kwargs: Any) -> MessageBuilder: + """添加回复组件""" + self.components.append(Reply(**kwargs)) + return self +``` + +--- + +#### `append(component)` + +添加任意组件。 + +```python +def append(self, component: BaseMessageComponent) -> MessageBuilder: + """添加任意组件""" + self.components.append(component) + return self +``` + +--- + +#### `extend(components)` + +添加多个组件。 + +```python +def extend(self, components: list[BaseMessageComponent]) -> MessageBuilder: + """添加多个组件""" + self.components.extend(components) + return self +``` + +--- + +#### `build()` + +构建 MessageChain。 + +```python +def build(self) -> MessageChain: + """构建消息链""" + return MessageChain(list(self.components)) +``` + +**返回**: `MessageChain` - 包含所有组件的消息链对象 + +--- + +### 完整使用示例 + +```python +from astrbot_sdk.message_result import MessageBuilder +from astrbot_sdk.message_components import Plain, At, Image + +# 链式构建 +chain = (MessageBuilder() + .text("Hello ") + .at("123456") + .text("!\n") + .image("https://example.com/img.jpg") + .build()) + +# 使用 MessageChain +chain = MessageChain([ + Plain("Hello "), + At("123456"), + Plain("!\n"), + Image.fromURL("https://example.com/img.jpg") +]) + +# 两种方式结果相同 
+``` + +--- + +## MessageEventResult - 消息事件结果 + +消息事件结果的包装类,用于 handler 返回值。 + +### 类定义 + +```python +@dataclass(slots=True) +class MessageEventResult: + type: EventResultType = EventResultType.EMPTY + chain: MessageChain = field(default_factory=MessageChain) +``` + +### 构造方法 + +#### 空结果 + +```python +from astrbot_sdk.message_result import MessageEventResult, EventResultType + +result = MessageEventResult() +# 或 +result = MessageEventResult(type=EventResultType.EMPTY) +``` + +--- + +#### 纯文本结果 + +```python +result = MessageEventResult( + type=EventResultType.PLAIN, + chain=MessageChain([Plain("返回内容")]) +) +``` + +--- + +#### 消息链结果 + +```python +from astrbot_sdk.message_result import MessageEventResult, EventResultType, MessageChain +from astrbot_sdk.message_components import Plain, Image + +result = MessageEventResult( + type=EventResultType.CHAIN, + chain=MessageChain([ + Plain("文本"), + Image(url="https://example.com/a.png") + ]) +) +``` + +--- + +### 实例方法 + +#### `to_payload()` + +转换为协议 payload。 + +```python +def to_payload(self) -> dict[str, Any]: + """转换为协议 payload""" + return { + "type": self.type.value, + "chain": self.chain.to_payload(), + } +``` + +**返回格式**: + +```python +# EMPTY +{"type": "empty", "chain": []} + +# CHAIN +{ + "type": "chain", + "chain": [ + {"type": "text", "data": {"text": "内容"}}, + {"type": "image", "data": {"url": "..."}} + ] +} + +# PLAIN +{ + "type": "plain", + "chain": [{"type": "text", "data": {"text": "内容"}}] +} +``` + +--- + +#### `from_payload(payload)` + +从协议 payload 创建实例。 + +```python +@classmethod +def from_payload(cls, payload: dict[str, Any]) -> MessageEventResult: + result_type_raw = str(payload.get("type", EventResultType.EMPTY.value)) + try: + result_type = EventResultType(result_type_raw) + except ValueError: + result_type = EventResultType.EMPTY + chain_payload = payload.get("chain") + components = ( + payloads_to_components(chain_payload) + if isinstance(chain_payload, list) + else [] + ) + return cls(type=result_type, 
chain=MessageChain(components)) +``` + +--- + +### 使用示例 + +```python +@on_command("return_text") +async def return_text(self, event: MessageEvent): + # 返回纯文本结果 + return event.plain_result("返回内容") + +@on_command("return_image") +async def return_image(self, event: MessageEvent): + # 返回图片结果 + return event.image_result("https://example.com/image.jpg") + +@on_command("return_chain") +async def return_chain(self, event: MessageEvent): + # 返回消息链结果 + return event.chain_result([ + Plain(f"用户: {event.sender_name}"), + Plain(f"ID: {event.user_id}"), + Plain(f"平台: {event.platform}"), + ]) +``` + +--- + +## 使用场景示例 + +### 场景1: 使用 MessageBuilder 构建复杂消息 + +```python +@on_command("rich") +async def rich_message(self, event: MessageEvent): + chain = (MessageBuilder() + .text("你好 ") + .at(event.user_id or "123456") + .text("!\n\n") + .image("https://example.com/welcome.jpg") + .text("这是欢迎图片") + .build()) + + await event.reply_chain(chain) +``` + +--- + +### 场景2: 使用 MessageChain 组合组件 + +```python +@on_command("multi") +async def multi_component(self, event: MessageEvent, count: int): + components = [Plain(f"发送 {count} 条消息:\n")] + + for i in range(count): + components.append(Plain(f"{i+1}. 
")) + if i < count - 1: + components.append(Plain("\n")) + + await event.reply_chain(components) +``` + +--- + +### 场景3: 返回结构化结果 + +```python +@on_command("user_info") +async def user_info(self, event: MessageEvent): + return event.chain_result([ + Plain(f"用户: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}\n"), + Plain(f"平台: {event.platform}\n"), + Plain(f"消息类型: {event.message_type}\n"), + ]) +``` + +--- + +## 辅助函数 + +### `coerce_message_chain(value)` + +将多种输入格式统一转换为 MessageChain。 + +**签名**: +```python +def coerce_message_chain(value: Any) -> MessageChain | None +``` + +**参数**: +- `value`: 要转换的值,支持以下类型: + - `MessageEventResult`: 提取其中的 chain + - `MessageChain`: 直接返回 + - `BaseMessageComponent`: 包装为单元素链 + - `list[BaseMessageComponent]`: 包装为链 + +**返回**: `MessageChain | None` - 转换后的消息链,无法转换则返回 None + +**示例**: + +```python +from astrbot_sdk.message_result import coerce_message_chain, MessageChain +from astrbot_sdk.message_components import Plain, Image + +# 从 MessageEventResult 提取 +chain = coerce_message_chain(result) + +# 从 MessageChain 返回 +chain = coerce_message_chain(existing_chain) + +# 从单个组件创建 +chain = coerce_message_chain(Plain("文本")) + +# 从组件列表创建 +chain = coerce_message_chain([Plain("文本"), Image.fromURL("url")]) +``` + +--- + +## 注意事项 + +1. **MessageChain 可变性**: + - `append()` 和 `extend()` 修改原链并返回 self + - 支持链式调用 + - 注意:链式操作会修改原链 + +2. **异步序列化**: + - 大多数情况用 `to_payload()` 即可 + - 包含 `Reply` 组件时建议用 `to_payload_async()` + +3. **纯文本提取**: + - `get_plain_text()` 默认忽略非文本组件 + - 设置 `with_other_comps_mark=True` 显示类型标记 + +4. 
**结果类型**: + - `EMPTY`: 不返回任何内容 + - `CHAIN`: 返回一个或多个消息组件 + - `PLAIN`: 返回文本内容 + +--- + +## 相关模块 + +- **消息组件**: `astrbot_sdk.message_components` +- **事件结果**: `astrbot_sdk.events.MessageEventResult` +- **事件类型**: `astrbot_sdk.events.EventResultType` + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.message_result` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/star.md b/astrbot-sdk/docs/api/star.md new file mode 100644 index 0000000000..30a7899fb2 --- /dev/null +++ b/astrbot-sdk/docs/api/star.md @@ -0,0 +1,740 @@ +# Star 类 - 插件基类完整参考 + +## 概述 + +`Star` 是 AstrBot SDK 的插件基类,所有 v4 原生插件都必须继承此类。它提供了完整的插件生命周期管理、上下文访问和能力集成。 + +**模块路径**: `astrbot_sdk.star.Star` + +--- + +## 类定义 + +```python +class Star(PluginKVStoreMixin): + """v4 原生插件基类""" + + __handlers__: tuple[str, ...] # 自动收集的处理器列表 + + # 生命周期钩子 + async def on_start(self, ctx: Any | None = None) -> None + async def on_stop(self, ctx: Any | None = None) -> None + async def initialize(self) -> None + async def terminate(self) -> None + async def on_error(self, error: Exception, event, ctx) -> None + + # 便捷属性 + @property + def context(self) -> Context | None + + # 便捷方法 + async def text_to_image(self, text: str, *, return_url: bool = True) -> str + async def html_render(self, tmpl: str, data: dict, *, return_url: bool = True) -> str + + # KV 存储方法(继承自 PluginKVStoreMixin) + async def put_kv_data(self, key: str, value: Any) -> None + async def get_kv_data(self, key: str, default: _VT) -> _VT + async def delete_kv_data(self, key: str) -> None +``` + +--- + +## 导入方式 + +```python +# 从主模块导入(推荐) +from astrbot_sdk import Star + +# 从子模块导入 +from astrbot_sdk.star import Star + +# 常用配套导入 +from astrbot_sdk import Context, MessageEvent # 上下文和事件 +from astrbot_sdk.decorators import on_command, on_message # 装饰器 +from astrbot_sdk.errors import AstrBotError # 错误处理 +``` + +--- + +## 核心属性 + +### `__handlers__` + +自动收集的事件处理器元组。 + +```python +class MyPlugin(Star): + @on_command("cmd1") + async def cmd1_handler(self, event, ctx): + pass + +# 
MyPlugin.__handlers__ == ("cmd1_handler",) +``` + +**说明**: 在子类创建时,`__init_subclass__()` 会自动扫描所有装饰了 `@on_command`、`@on_message` 等装饰器的方法,并将处理器名称收集到此元组中。 + +### `context` + +获取当前运行时上下文的属性。 + +```python +class MyPlugin(Star): + async def some_method(self): + ctx = self.context + if ctx: + await ctx.db.set("key", "value") +``` + +**返回**: `Context | None` - 仅在生命周期钩子和 Handler 执行期间可用 + +**注意**: 不要存储此引用,它在插件停止后会被清除 + +--- + +## 生命周期钩子 + +### 1. `on_start(ctx)` - 插件启动钩子 + +**签名**: +```python +async def on_start(self, ctx: Any | None = None) -> None +``` + +**参数**: +- `ctx`: 运行时上下文(通常为 `Context` 实例) + +**触发时机**: Worker 启动后,在开始处理事件之前调用 + +**用途**: +- 初始化数据库连接 +- 加载配置文件 +- 注册 LLM 工具 +- 启动后台任务 +- 验证外部依赖 + +**示例**: + +```python +class MyPlugin(Star): + async def on_start(self, ctx) -> None: + # 确保 initialize 被调用 + await super().on_start(ctx) + + # 获取插件数据目录 + data_dir = await ctx.get_data_dir() + + # 加载配置 + config = await ctx.metadata.get_plugin_config() + self.api_key = config.get("api_key", "") + + # 注册 LLM 工具 + await ctx.register_llm_tool( + name="search", + parameters_schema={...}, + desc="搜索信息", + func_obj=self.search_tool + ) + + # 启动后台任务 + await ctx.register_task( + self.background_sync(), + desc="后台数据同步" + ) + + ctx.logger.info(f"{ctx.plugin_id} 启动成功") +``` + +**注意事项**: +1. 始终调用 `await super().on_start(ctx)` 确保 `initialize()` 被调用 +2. 在此方法中抛出的异常会导致插件加载失败 +3. 此方法中 `ctx` 参数保证不为 `None` + +--- + +### 2. `on_stop(ctx)` - 插件停止钩子 + +**签名**: +```python +async def on_stop(self, ctx: Any | None = None) -> None +``` + +**参数**: +- `ctx`: 运行时上下文 + +**触发时机**: 插件卸载或程序关闭前调用 + +**用途**: +- 关闭数据库连接 +- 清理临时文件 +- 注销 LLM 工具 +- 保存状态数据 + +**示例**: + +```python +class MyPlugin(Star): + async def on_stop(self, ctx) -> None: + # 保存状态 + await self.put_kv_data("last_shutdown", time.time()) + + # 注销工具 + if hasattr(self, '_tool_name'): + await ctx.unregister_llm_tool(self._tool_name) + + # 确保 terminate 被调用 + await super().on_stop(ctx) + + ctx.logger.info(f"{ctx.plugin_id} 已停止") +``` + +**注意事项**: +1. 
始终调用 `await super().on_stop(ctx)` 确保 `terminate()` 被调用 +2. 此方法中的异常会被捕获并记录,不会阻止插件关闭 +3. 此时可能没有活跃的事件处理,避免发送消息 + +--- + +### 3. `initialize()` - 初始化钩子 + +**签名**: +```python +async def initialize(self) -> None +``` + +**触发时机**: `on_start()` 内部自动调用 + +**用途**: +- 插件级别的初始化逻辑 +- 不依赖 Context 的初始化 + +**示例**: + +```python +class MyPlugin(Star): + async def initialize(self) -> None: + """初始化插件""" + self._cache = {} + self._counter = 0 + self.state = "ready" +``` + +**与 `on_start` 的区别**: +- `initialize()` 无 `Context` 参数,用于不依赖外部资源的初始化 +- `on_start(ctx)` 有 `Context` 参数,用于需要访问 Core 的初始化 + +**调用顺序**: +``` +插件实例化 + ↓ +initialize() ← 先调用(无 Context) + ↓ +on_start(ctx) ← 后调用(有 Context) +``` + +--- + +### 4. `terminate()` - 终止钩子 + +**签名**: +```python +async def terminate(self) -> None +``` + +**触发时机**: `on_stop()` 内部自动调用 + +**用途**: +- 插件级别的清理逻辑 +- 不依赖 Context 的清理 + +**示例**: + +```python +class MyPlugin(Star): + async def terminate(self) -> None: + """清理插件资源""" + self._cache.clear() + self.state = "stopped" +``` + +**与 `on_stop` 的区别**: +- `terminate()` 无 `Context` 参数,用于清理插件内部资源 +- `on_stop(ctx)` 有 `Context` 参数,用于清理需要与 Core 交互的资源 + +**调用顺序**: +``` +on_stop(ctx) ← 先调用(有 Context) + ↓ +terminate() ← 后调用(无 Context) + ↓ +插件卸载 +``` + +--- + +### 5. 
`on_error(error, event, ctx)` - 错误处理钩子 + +**签名**: +```python + async def on_error(self, error: Exception, event, ctx) -> None + + # 类方法 + @classmethod + def __astrbot_is_new_star__(cls) -> bool +``` + +**参数**: +- `error`: 捕获的异常 +- `event`: 事件对象(可能是 `MessageEvent` 或其他类型) +- `ctx`: 上下文对象 + +**触发时机**: 任何 Handler 执行抛出异常时 + +**默认行为**: +- `AstrBotError`:根据错误类型发送友好提示 +- 其他异常:发送通用错误消息 +- 记录错误日志 + +**示例**: + +```python +from astrbot_sdk.errors import AstrBotError + +class MyPlugin(Star): + async def on_error(self, error: Exception, event, ctx) -> None: + """自定义错误处理""" + + # SDK 标准错误 + if isinstance(error, AstrBotError): + lines = [] + if error.retryable: + lines.append("请求失败,请稍后重试") + elif error.hint: + lines.append(error.hint) + else: + lines.append(error.message) + + if error.docs_url: + lines.append(f"文档:{error.docs_url}") + + await event.reply("\n".join(lines)) + + # 业务逻辑错误 + elif isinstance(error, ValueError): + await event.reply(f"参数错误:{error}") + + # 网络错误 + elif isinstance(error, ConnectionError): + await event.reply("网络连接失败,请检查网络设置") + + # 未知错误 + else: + await event.reply(f"出错了:{type(error).__name__}") + + # 记录详细错误 + ctx.logger.error(f"Handler failed: {error}", exc_info=error) +``` + +**覆盖建议**: +1. 始终记录错误日志 +2. 向用户提供友好的错误提示 +3. 
调用 `await super().on_error(...)` 作为后备 + +--- + +## 类方法 + +### `__astrbot_is_new_star__()` + +标识类为 v4 原生插件。 + +**签名**: +```python +@classmethod +def __astrbot_is_new_star__(cls) -> bool +``` + +**返回**: `bool` - 始终返回 `True` + +**说明**: 此方法用于运行时识别插件类型,v4 原生插件返回 `True`,旧版插件无此方法。 + +--- + +## 便捷方法 + +### `text_to_image()` + +将文本渲染为图片。 + +**签名**: +```python +async def text_to_image( + self, + text: str, + *, + return_url: bool = True +) -> str +``` + +**参数**: +- `text`: 要渲染的文本 +- `return_url`: 是否返回 URL(False 则返回本地路径) + +**返回**: 图片 URL 或路径 + +**示例**: + +```python +class MyPlugin(Star): + @on_command("text_img") + async def text_to_image_cmd(self, event: MessageEvent): + url = await self.text_to_image("Hello World") + await event.reply_image(url) +``` + +**等价于**: +```python +url = await ctx.text_to_image("Hello World") +``` + +--- + +### `html_render()` + +渲染 HTML 模板。 + +**签名**: +```python +async def html_render( + self, + tmpl: str, + data: dict, + *, + return_url: bool = True, + options: dict[str, Any] | None = None +) -> str +``` + +**参数**: +- `tmpl`: HTML 模板内容 +- `data`: 模板数据 +- `return_url`: 是否返回 URL +- `options`: 渲染选项 + +**返回**: 渲染结果 URL 或路径 + +**示例**: + +```python +class MyPlugin(Star): + @on_command("card") + async def card_cmd(self, event: MessageEvent): + url = await self.html_render( + tmpl="
+<div class="card">
+  <h1>{{ title }}</h1>
+  <p>{{ content }}</p>
+</div>
", + data={"title": "标题", "content": "内容"} + ) + await event.reply_image(url) +``` + +**等价于**: +```python +url = await ctx.html_render(tmpl, data) +``` + +--- + +## KV 存储方法 + +这些方法继承自 `PluginKVStoreMixin`,提供简单的键值存储能力。 + +### `put_kv_data()` + +存储数据。 + +**签名**: +```python +async def put_kv_data(self, key: str, value: Any) -> None +``` + +**示例**: + +```python +await self.put_kv_data("last_run", time.time()) +``` + +### `get_kv_data()` + +获取数据。 + +**签名**: +```python +async def get_kv_data(self, key: str, default: _VT) -> _VT +``` + +**示例**: + +```python +last_run = await self.get_kv_data("last_run", 0) +``` + +### `delete_kv_data()` + +删除数据。 + +**签名**: +```python +async def delete_kv_data(self, key: str) -> None +``` + +**示例**: + +```python +await self.delete_kv_data("temp_data") +``` + +--- + +## 完整插件示例 + +```python +""" +完整的插件示例 +""" + +from astrbot_sdk import Star, Context, MessageEvent +from astrbot_sdk.decorators import on_command, on_message, provide_capability +from astrbot_sdk.errors import AstrBotError +import asyncio +import time + +class CompletePlugin(Star): + """完整功能插件""" + + async def initialize(self) -> None: + """初始化""" + self._stats = { + "start_time": time.time(), + "command_count": 0 + } + + async def on_start(self, ctx) -> None: + """启动""" + await super().on_start(ctx) + + # 加载配置 + config = await ctx.metadata.get_plugin_config() + self.greeting = config.get("greeting", "你好") + + # 注册 LLM 工具 + await ctx.register_llm_tool( + name="get_time", + parameters_schema={ + "type": "object", + "properties": {}, + "required": [] + }, + desc="获取当前时间", + func_obj=self.get_time_tool + ) + + # 启动后台任务 + await ctx.register_task( + self.background_sync(), + desc="后台数据同步" + ) + + ctx.logger.info("Plugin started") + + async def on_stop(self, ctx) -> None: + """停止""" + # 保存统计 + await self.put_kv_data("stats", self._stats) + await super().on_stop(ctx) + ctx.logger.info("Plugin stopped") + + @on_command("hello", aliases=["hi", "greet"]) + async def hello(self, event: 
MessageEvent, ctx: Context) -> None: + """打招呼命令""" + self._stats["command_count"] += 1 + await event.reply(f"{self.greeting},{event.sender_name}!") + + @on_command("stats") + async def stats(self, event: MessageEvent, ctx: Context) -> None: + """统计信息""" + uptime = time.time() - self._stats["start_time"] + await event.reply(f""" + 运行时间: {uptime:.0f}秒 + 命令次数: {self._stats['command_count']} + """) + + @on_message(keywords=["帮助"]) + async def help(self, event: MessageEvent, ctx: Context) -> None: + """帮助信息""" + await event.reply(""" + 可用命令: + /hello - 打招呼 + /stats - 统计信息 + /time - 当前时间 + """) + + @on_command("time") + async def time_cmd(self, event: MessageEvent, ctx: Context) -> None: + """获取时间""" + result = await self.get_time_tool() + await event.reply(result) + + async def get_time_tool(self) -> str: + """LLM 工具实现""" + return f"当前时间: {time.strftime('%Y-%m-%d %H:%M:%S')}" + + async def background_sync(self): + """后台任务""" + while True: + await asyncio.sleep(3600) + # 执行同步逻辑 + pass + + async def on_error(self, error: Exception, event, ctx) -> None: + """错误处理""" + if isinstance(error, AstrBotError): + await event.reply(error.hint or error.message) + else: + await event.reply(f"发生错误: {type(error).__name__}") + ctx.logger.error(f"Error: {error}", exc_info=error) +``` + +--- + +## plugin.yaml 配置 + +```yaml +_schema_version: 2 +name: my_plugin +author: Your Name +version: 1.0.0 +desc: 我的插件描述 +repo: https://github.com/user/repo +logo: assets/logo.png + +runtime: + python: "3.12" + +components: + - class: main:MyPlugin + +support_platforms: + - aiocqhttp + - telegram + - discord + +astrbot_version: ">=4.13.0,<5.0.0" + +config: + timeout: 30 + max_retries: 3 + api_key: "" +``` + +--- + +## 最佳实践 + +### 1. 
资源初始化与清理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 创建资源 + self._session = aiohttp.ClientSession() + self._task = asyncio.create_task(self.background_task()) + + async def on_stop(self, ctx): + # 清理资源 + if hasattr(self, '_task'): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + if hasattr(self, '_session'): + await self._session.close() +``` + +### 2. 配置管理 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + config = await ctx.metadata.get_plugin_config() + + # 提供默认值 + self.timeout = config.get("timeout", 30) + + # 验证必需配置 + if "api_key" not in config: + raise ValueError("缺少必需配置: api_key") + + self.api_key = config["api_key"] +``` + +### 3. 状态持久化 + +```python +class MyPlugin(Star): + async def on_start(self, ctx): + # 加载状态 + self.last_update = await self.get_kv_data("last_update", 0) + self.user_data = await self.get_kv_data("users", {}) + + async def on_stop(self, ctx): + # 保存状态 + await self.put_kv_data("last_update", time.time()) + await self.put_kv_data("users", self.user_data) +``` + +### 4. 错误处理 + +```python +class MyPlugin(Star): + async def on_error(self, error, event, ctx): + # 根据错误类型发送不同的提示 + if isinstance(error, ValueError): + await event.reply("参数错误") + elif isinstance(error, ConnectionError): + await event.reply("网络连接失败") + else: + # 使用默认处理 + await super().on_error(error, event, ctx) + + # 记录日志 + ctx.logger.error(f"Handler error: {error}", exc_info=error) +``` + +--- + +## 注意事项 + +1. **异步方法**: 所有生命周期钩子都是异步方法,必须使用 `async def` 声明 + +2. **super() 调用**: 在 `on_start` 和 `on_stop` 中始终调用 `await super().xxx(ctx)` 确保 `initialize`/`terminate` 被调用 + +3. **context 属性**: 仅在生命周期钩子和 Handler 执行期间可用,不要存储此引用 + +4. **异常处理**: `on_start` 中的异常会导致插件加载失败,`on_stop` 中的异常会被捕获并记录 + +5. 
**资源清理**: 确保在 `on_stop` 或 `terminate` 中清理所有资源(连接、任务、文件等) + +--- + +## 相关模块 + +- **装饰器**: `astrbot_sdk.decorators` - 事件处理装饰器 +- **上下文**: `astrbot_sdk.context.Context` - 运行时上下文 +- **事件**: `astrbot_sdk.events.MessageEvent` - 消息事件 +- **错误**: `astrbot_sdk.errors.AstrBotError` - SDK 错误类 + +--- + +**版本**: v4.0 +**模块**: `astrbot_sdk.star.Star` +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/types.md b/astrbot-sdk/docs/api/types.md new file mode 100644 index 0000000000..9526541701 --- /dev/null +++ b/astrbot-sdk/docs/api/types.md @@ -0,0 +1,497 @@ +# 类型定义 API 完整参考 + +## 概述 + +本文档介绍 AstrBot SDK 中常用的类型定义,包括类型别名、泛型变量和类型注解。 + +**模块路径**: 分布在各个 SDK 模块中 + +--- + +## 目录 + +- [类型别名](#类型别名) +- [泛型变量](#泛型变量) +- [特殊类型](#特殊类型) +- [使用示例](#使用示例) + +--- + +## 导入方式 + +```python +# 类型别名 +from astrbot_sdk.context import PlatformCompatContent +from astrbot_sdk.clients.llm import ChatMessage, ChatHistoryItem, LLMResponse + +# 泛型变量(通常不需要直接导入) +from astrbot_sdk.session_waiter import _P, _ResultT, _OwnerT +from astrbot_sdk.plugin_kv import _VT + +# 通用类型 +from typing import Callable, Awaitable, Any, Sequence, Mapping + +HandlerType = Callable[..., Awaitable[Any]] +FilterType = Callable[..., Awaitable[bool]] +``` + +--- + +## 类型别名 + +### PlatformCompatContent + +平台兼容的内容类型,用于表示可以发送到平台的各种消息格式。 + +**定义位置**: `astrbot_sdk.context` + +**定义**: + +```python +from collections.abc import Sequence +from typing import Any + +PlatformCompatContent = ( + str | MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]] +) +``` + +**说明**: + +此类型别名表示可以用于平台发送方法的内容类型,支持以下四种格式: + +| 格式 | 说明 | 示例 | +|------|------|------| +| `str` | 纯文本字符串 | `"Hello World"` | +| `MessageChain` | 消息链对象 | `MessageChain([Plain("Hi")])` | +| `Sequence[BaseMessageComponent]` | 消息组件列表 | `[Plain("Hi"), At("123")]` | +| `Sequence[dict[str, Any]]` | 序列化后的字典列表 | `[{"type": "text", "data": {"text": "Hi"}}]` | + +**使用位置**: + +- `Context.send_message()` +- `Context.send_message_by_id()` +- `PlatformClient.send_by_session()` +- 
`StarTools.send_message()` + +**示例**: + +```python +from astrbot_sdk import Plain, Image, MessageChain + +# 纯文本 +await ctx.platform.send_by_session("session_id", "Hello") + +# 消息链 +chain = MessageChain([Plain("Hello"), Image.fromURL("...")]) +await ctx.platform.send_by_session("session_id", chain) + +# 组件列表 +await ctx.platform.send_by_session("session_id", [ + Plain("Hello"), + At("123456") +]) + +# 字典列表 +await ctx.platform.send_by_session("session_id", [ + {"type": "text", "data": {"text": "Hello"}} +]) +``` + +--- + +### ChatHistoryItem + +聊天历史项类型,用于构建对话历史。 + +**定义位置**: `astrbot_sdk.clients.llm` + +**定义**: + +```python +from collections.abc import Mapping +from typing import Any +from pydantic import BaseModel + +class ChatMessage(BaseModel): + role: str + content: str + +ChatHistoryItem = ChatMessage | Mapping[str, Any] +``` + +**说明**: + +此类型别名表示对话历史中的一项,可以是 `ChatMessage` 对象或任何字典类型的映射。 + +**支持格式**: + +| 格式 | 说明 | 示例 | +|------|------|------| +| `ChatMessage` | Pydantic 模型对象 | `ChatMessage(role="user", content="Hi")` | +| `Mapping[str, Any]` | 字典类型 | `{"role": "user", "content": "Hi"}` | + +**使用位置**: + +- `LLMClient.chat()` - `history` 参数 +- `LLMClient.chat_raw()` - `history` 参数 +- `LLMClient.stream_chat()` - `history` 参数 + +**示例**: + +```python +from astrbot_sdk.clients.llm import ChatMessage + +# 使用 ChatMessage 对象 +history = [ + ChatMessage(role="user", content="你好"), + ChatMessage(role="assistant", content="你好!"), +] + +# 使用字典 +history = [ + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好!"}, +] + +# 混合使用 +history = [ + ChatMessage(role="user", content="你好"), + {"role": "assistant", "content": "你好!"}, + {"role": "user", "content": "今天天气怎么样?"}, +] +``` + +--- + +## 泛型变量 + +SDK 内部使用的泛型类型变量,用于类型注解。 + +### `_P` - 参数规范 + +**定义位置**: `astrbot_sdk.session_waiter` + +**定义**: + +```python +from typing import ParamSpec + +_P = ParamSpec("_P") +``` + +**说明**: + +用于捕获可调用对象的参数签名,主要在装饰器中使用。 + +--- + +### `_ResultT` - 结果类型 + +**定义位置**: 
`astrbot_sdk.session_waiter` + +**定义**: + +```python +from typing import TypeVar + +_ResultT = TypeVar("_ResultT") +``` + +**说明**: + +表示异步函数的返回结果类型。 + +--- + +### `_OwnerT` - 所有者类型 + +**定义位置**: `astrbot_sdk.session_waiter` + +**定义**: + +```python +_OwnerT = TypeVar("_OwnerT") +``` + +**说明**: + +表示类的所有者类型(通常是 `Star` 子类)。 + +--- + +### `_VT` - 值类型 + +**定义位置**: `astrbot_sdk.plugin_kv` + +**定义**: + +```python +_VT = TypeVar("_VT") +``` + +**说明**: + +用于 KV 存储中默认值的类型。 + +**使用位置**: + +- `PluginKVStoreMixin.get_kv_data()` - `default` 参数的类型注解 + +**示例**: + +```python +# default 参数的类型会根据传入的值自动推断 +value = await self.get_kv_data("key", default="default") # _VT 推断为 str +count = await self.get_kv_data("count", default=0) # _VT 推断为 int +``` + +--- + +## 特殊类型 + +### HandlerType + +事件处理器函数类型。 + +**定义**: + +```python +from typing import Callable, Awaitable, Any + +HandlerType = Callable[..., Awaitable[Any]] +``` + +**说明**: + +表示事件处理器的函数签名,接受任意参数并返回异步结果。 + +**特征**: +- 可变参数 (`...`) +- 异步返回 (`Awaitable[Any]`) + +**示例**: + +```python +async def my_handler(event: MessageEvent, ctx: Context) -> None: + pass + +# 符合 HandlerType 类型 +``` + +--- + +### FilterType + +过滤器函数类型。 + +**定义**: + +```python +FilterType = Callable[..., Awaitable[bool]] +``` + +**说明**: + +表示过滤器函数的类型,返回布尔值。 + +**特征**: +- 可变参数 (`...`) +- 异步返回布尔值 (`Awaitable[bool]`) + +**示例**: + +```python +async def my_filter(event: MessageEvent, ctx: Context) -> bool: + return event.platform == "qq" + +# 符合 FilterType 类型 +``` + +--- + +## Pydantic 模型类型 + +### ChatMessage + +聊天消息模型,用于构建对话历史。 + +**定义位置**: `astrbot_sdk.clients.llm` + +**定义**: + +```python +from pydantic import BaseModel + +class ChatMessage(BaseModel): + """聊天消息模型。""" + role: str + content: str +``` + +**属性**: + +| 属性 | 类型 | 说明 | +|------|------|------| +| `role` | `str` | 消息角色,如 `"user"`, `"assistant"`, `"system"` | +| `content` | `str` | 消息内容 | + +**示例**: + +```python +from astrbot_sdk.clients.llm import ChatMessage + +# 系统提示 +system_msg = ChatMessage( + role="system", + 
content="你是一个友好的助手" +) + +# 用户消息 +user_msg = ChatMessage( + role="user", + content="你好" +) + +# 助手回复 +assistant_msg = ChatMessage( + role="assistant", + content="你好!有什么可以帮助你的?" +) +``` + +--- + +### LLMResponse + +LLM 响应模型,包含完整的响应信息。 + +**定义位置**: `astrbot_sdk.clients.llm` + +**定义**: + +```python +from pydantic import BaseModel, Field + +class LLMResponse(BaseModel): + """LLM 响应模型。""" + text: str + usage: dict[str, Any] | None = None + finish_reason: str | None = None + tool_calls: list[dict[str, Any]] = Field(default_factory=list) + role: str | None = None + reasoning_content: str | None = None + reasoning_signature: str | None = None +``` + +**属性**: + +| 属性 | 类型 | 说明 | +|------|------|------| +| `text` | `str` | 生成的文本内容 | +| `usage` | `dict[str, Any] \| None` | Token 使用统计 | +| `finish_reason` | `str \| None` | 结束原因(`"stop"`, `"length"`, `"tool_calls"`) | +| `tool_calls` | `list[dict[str, Any]]` | 工具调用列表 | +| `role` | `str \| None` | 响应角色 | +| `reasoning_content` | `str \| None` | 推理内容(用于推理模型) | +| `reasoning_signature` | `str \| None` | 推理签名 | + +**示例**: + +```python +from astrbot_sdk.clients.llm import LLMResponse + +response = await ctx.llm.chat_raw("写一首诗") + +print(f"生成内容: {response.text}") +print(f"Token 使用: {response.usage}") +print(f"结束原因: {response.finish_reason}") + +if response.usage: + print(f"提示词 Token: {response.usage.get('prompt_tokens')}") + print(f"完成 Token: {response.usage.get('completion_tokens')}") +``` + +--- + +## 使用示例 + +### 类型注解在函数签名中的使用 + +```python +from typing import Sequence, Mapping, Any +from astrbot_sdk.clients.llm import ChatMessage, ChatHistoryItem +from astrbot_sdk import MessageChain, BaseMessageComponent, PlatformCompatContent + +# 使用 ChatHistoryItem +async def chat_with_history( + prompt: str, + history: Sequence[ChatHistoryItem] | None = None +) -> str: + """与 LLM 聊天的函数。""" + pass + +# 使用 PlatformCompatContent +async def send_content( + session: str, + content: PlatformCompatContent +) -> dict[str, Any]: + """发送内容的函数。""" + pass 
+``` + +### 类型检查和类型守卫 + +```python +from collections.abc import Mapping, Sequence +from astrbot_sdk.clients.llm import ChatMessage, ChatHistoryItem + +def normalize_history_item(item: ChatHistoryItem) -> dict[str, Any]: + """将聊天历史项规范化为字典。""" + if isinstance(item, ChatMessage): + return item.model_dump() + if isinstance(item, Mapping): + return dict(item) + raise TypeError("无效的聊天历史项类型") + +# 使用 +history: Sequence[ChatHistoryItem] = [ + ChatMessage(role="user", content="Hi"), + {"role": "assistant", "content": "Hello"}, +] + +normalized = [normalize_history_item(item) for item in history] +``` + +### 泛型函数 + +```python +from typing import TypeVar, Generic + +T = TypeVar("T") + +class Container(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + + def get(self) -> T: + return self.value + +# 使用 +int_container: Container[int] = Container(42) +str_container: Container[str] = Container("hello") +``` + +--- + +## 相关模块 + +- **LLM 客户端**: `astrbot_sdk.clients.LLMClient` +- **消息组件**: `astrbot_sdk.message_components` +- **消息链**: `astrbot_sdk.message_result.MessageChain` +- **上下文**: `astrbot_sdk.context.Context` + +--- + +**版本**: v4.0 +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/docs/api/utils.md b/astrbot-sdk/docs/api/utils.md new file mode 100644 index 0000000000..62a5bbc089 --- /dev/null +++ b/astrbot-sdk/docs/api/utils.md @@ -0,0 +1,1087 @@ +# 工具与辅助类 API 完整参考 + +## 概述 + +本文档介绍 AstrBot SDK 中常用的工具类和辅助类型,包括取消令牌、会话管理、命令组织、参数解析等功能。 + +**模块路径**: +- `astrbot_sdk.context.CancelToken` +- `astrbot_sdk.message_session.MessageSession` +- `astrbot_sdk.types.GreedyStr` +- `astrbot_sdk.commands` +- `astrbot_sdk.schedule.ScheduleContext` +- `astrbot_sdk.session_waiter` +- `astrbot_sdk.star_tools.StarTools` +- `astrbot_sdk.plugin_kv.PluginKVStoreMixin` + +--- + +## 目录 + +- [CancelToken - 取消令牌](#canceltoken---取消令牌) +- [MessageSession - 消息会话](#messagesession---消息会话) +- [GreedyStr - 贪婪字符串](#greedystr---贪婪字符串) +- [CommandGroup - 命令组](#commandgroup---命令组) +- 
[ScheduleContext - 调度上下文](#schedulecontext---调度上下文) +- [SessionController - 会话控制器](#sessioncontroller---会话控制器) +- [session_waiter - 会话等待装饰器](#session_waiter---会话等待装饰器) +- [StarTools - Star 工具类](#startools---star-工具类) +- [PluginKVStoreMixin - KV 存储混入](#pluginkvstoremixin---kv-存储混入) + +--- + +## 导入方式 + +```python +# 从主模块导入 +from astrbot_sdk import ( + CancelToken, + MessageSession, + GreedyStr, + ScheduleContext, + SessionController, + session_waiter, + StarTools, + PluginKVStoreMixin, +) + +# 从子模块导入 +from astrbot_sdk.context import CancelToken +from astrbot_sdk.message_session import MessageSession +from astrbot_sdk.types import GreedyStr +from astrbot_sdk.commands import CommandGroup, command_group, print_cmd_tree +from astrbot_sdk.schedule import ScheduleContext +from astrbot_sdk.session_waiter import SessionController, session_waiter +from astrbot_sdk.star_tools import StarTools +from astrbot_sdk.plugin_kv import PluginKVStoreMixin +``` + +--- + +## CancelToken - 取消令牌 + +请求取消令牌,用于协调长时间运行操作的取消。 + +### 类定义 + +```python +@dataclass(slots=True) +class CancelToken: + _cancelled: asyncio.Event +``` + +### 构造方法 + +```python +from astrbot_sdk import CancelToken + +token = CancelToken() +``` + +### 实例方法 + +#### `cancel()` + +触发取消信号。 + +```python +def cancel(self) -> None: + """触发取消信号。""" +``` + +**示例**: + +```python +token.cancel() +``` + +--- + +#### `cancelled` 属性 + +检查是否已被取消。 + +```python +@property +def cancelled(self) -> bool: + """检查是否已被取消。""" +``` + +**示例**: + +```python +if token.cancelled: + print("操作已取消") +``` + +--- + +#### `wait()` + +等待取消信号。 + +```python +async def wait(self) -> None: + """等待取消信号。""" +``` + +**示例**: + +```python +await token.wait() +``` + +--- + +#### `raise_if_cancelled()` + +如果已取消则抛出 `CancelledError`。 + +```python +def raise_if_cancelled(self) -> None: + """如果已取消则抛出 CancelledError。""" +``` + +**异常**: +- `asyncio.CancelledError`: 如果令牌已被取消 + +**示例**: + +```python +async def long_operation(ctx: Context): + for item in large_list: + 
ctx.cancel_token.raise_if_cancelled() + await process(item) +``` + +--- + +## MessageSession - 消息会话 + +统一表示消息会话标识符,格式为 `platform_id:message_type:session_id`。 + +### 类定义 + +```python +@dataclass(slots=True) +class MessageSession: + platform_id: str + message_type: str + session_id: str +``` + +### 属性 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `platform_id` | `str` | 平台实例 ID | +| `message_type` | `str` | 消息类型(`group` 或 `private`) | +| `session_id` | `str` | 会话 ID | + +### 类方法 + +#### `from_str(session)` + +从字符串解析会话。 + +```python +@classmethod +def from_str(cls, session: str) -> MessageSession: + platform_id, message_type, session_id = str(session).split(":", 2) + return cls( + platform_id=platform_id, + message_type=message_type, + session_id=session_id, + ) +``` + +**参数**: +- `session` (`str`): 会话字符串,格式为 `platform_id:message_type:session_id` + +**返回**: `MessageSession` 实例 + +**示例**: + +```python +from astrbot_sdk import MessageSession + +# 从字符串创建 +session = MessageSession.from_str("qq:group:123456") + +# 直接创建 +session = MessageSession( + platform_id="qq", + message_type="group", + session_id="123456" +) + +# 转换为字符串 +str(session) # "qq:group:123456" +``` + +--- + +## GreedyStr - 贪婪字符串 + +用于标记"贪婪字符串"参数,在命令解析时将剩余所有文本作为一个整体参数。 + +### 类定义 + +```python +class GreedyStr(str): + """Consume the remaining command text as one argument.""" +``` + +### 使用场景 + +当命令参数包含空格时,普通解析会将空格后的内容作为下一个参数,而 `GreedyStr` 会捕获剩余所有文本。 + +**示例**: + +```python +from astrbot_sdk import GreedyStr +from astrbot_sdk.decorators import on_command + +@on_command("echo") +async def echo(self, event: MessageEvent, text: GreedyStr): + # 用户输入: /echo hello world this is a test + # text = "hello world this is a test" + await event.reply(text) + +@on_command("say") +async def say(self, event: MessageEvent, name: str, message: GreedyStr): + # 用户输入: /say Alice Hello World + # name = "Alice" + # message = "Hello World" + await event.reply(f"{name} 说: {message}") +``` + +--- + +## CommandGroup - 命令组 + 
+用于组织具有层级关系的命令,支持命令别名和自动展开。 + +### 类定义 + +```python +class CommandGroup: + def __init__( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + parent: CommandGroup | None = None, + ) -> None: +``` + +### 构造方法 + +```python +from astrbot_sdk import CommandGroup, command_group + +# 使用函数创建 +admin = command_group("admin", description="管理命令") + +# 使用类创建 +config = CommandGroup("config", description="配置命令") +``` + +**参数**: +- `name` (`str`): 组名称 +- `aliases` (`list[str] | None`): 别名列表 +- `description` (`str | None`): 描述信息 +- `parent` (`CommandGroup | None`): 父组 + +### 实例方法 + +#### `group(name, *, aliases, description)` + +创建子命令组。 + +```python +def group( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> CommandGroup: +``` + +**示例**: + +```python +admin = command_group("admin") +user = admin.group("user", description="用户管理") +config = admin.group("config", description="配置管理") +``` + +--- + +#### `command(name, *, aliases, description)` + +创建命令装饰器。 + +```python +def command( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, +): +``` + +**返回**: 装饰器函数 + +**示例**: + +```python +admin = command_group("admin") + +@admin.command("add", description="添加用户") +async def admin_add_user(self, event: MessageEvent, user_id: str): + await event.reply(f"添加用户: {user_id}") + +@admin.command("remove", aliases=["del"], description="删除用户") +async def admin_remove_user(self, event: MessageEvent, user_id: str): + await event.reply(f"删除用户: {user_id}") +``` + +--- + +#### `path` 属性 + +获取命令组的完整路径。 + +```python +@property +def path(self) -> list[str]: + if self.parent is None: + return [self.name] + return [*self.parent.path, self.name] +``` + +**示例**: + +```python +admin = command_group("admin") +user = admin.group("user") + +user.path # ["admin", "user"] +``` + +--- + +#### `print_cmd_tree()` + +打印命令树结构。 + +```python +def print_cmd_tree(self) -> str: + lines: 
list[str] = [] + self._append_tree_lines(lines, indent=0) + return "\n".join(lines) +``` + +**返回**: `str` - 命令树字符串 + +**示例**: + +```python +admin = command_group("admin") + +@admin.command("add") +async def admin_add(...): pass + +@admin.command("remove") +async def admin_remove(...): pass + +print(admin.print_cmd_tree()) +# 输出: +# admin +# - add +# - remove +``` + +--- + +### 函数 + +#### `command_group(name, *, aliases, description)` + +创建命令组实例。 + +```python +def command_group( + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> CommandGroup: + return CommandGroup( + name, + aliases=aliases, + description=description, + ) +``` + +--- + +#### `print_cmd_tree(group)` + +获取命令树字符串。 + +```python +def print_cmd_tree(group: CommandGroup) -> str: + return group.print_cmd_tree() +``` + +**示例**: + +```python +from astrbot_sdk import command_group, print_cmd_tree + +admin = command_group("admin", description="管理命令") + +@admin.command("user") +async def admin_user(...): pass + +@admin.command("setting") +async def admin_setting(...): pass + +# 获取命令树 +tree = print_cmd_tree(admin) +await event.reply(f"```\n{tree}\n```") +``` + +--- + +### 使用示例 + +#### 基本命令组 + +```python +from astrbot_sdk import Star, command_group +from astrbot_sdk.decorators import on_command +from astrbot_sdk.events import MessageEvent + +class MyPlugin(Star): + # 创建命令组 + admin = command_group("admin", description="管理命令") + + @admin.command("add", description="添加用户") + async def admin_add(self, event: MessageEvent, user_id: str): + await event.reply(f"添加用户: {user_id}") + + @admin.command("remove", aliases=["del"], description="删除用户") + async def admin_remove(self, event: MessageEvent, user_id: str): + await event.reply(f"删除用户: {user_id}") +``` + +#### 嵌套命令组 + +```python +# 创建嵌套结构 +admin = command_group("admin") +user = admin.group("user", description="用户管理") +config = admin.group("config", description="配置管理") + +@user.command("add") +async def admin_user_add(self, 
event: MessageEvent, user_id: str): + await event.reply(f"添加用户: {user_id}") + +@user.command("remove") +async def admin_user_remove(self, event: MessageEvent, user_id: str): + await event.reply(f"删除用户: {user_id}") + +@config.command("get") +async def admin_config_get(self, event: MessageEvent, key: str): + await event.reply(f"获取配置: {key}") + +@config.command("set") +async def admin_config_set(self, event: MessageEvent, key: str, value: str): + await event.reply(f"设置配置: {key} = {value}") +``` + +#### 使用类组织命令 + +```python +from astrbot_sdk import Star, CommandGroup + +class AdminCommands: + group = CommandGroup("admin", description="管理命令") + + @group.command("add", description="添加用户") + async def add_user(self, event, user_id: str): + await event.reply(f"添加用户: {user_id}") + + @group.command("remove", description="删除用户") + async def remove_user(self, event, user_id: str): + await event.reply(f"删除用户: {user_id}") +``` + +--- + +## ScheduleContext - 调度上下文 + +定时任务的上下文信息,包含调度任务的详细信息。 + +### 类定义 + +```python +@dataclass(slots=True) +class ScheduleContext: + schedule_id: str + plugin_id: str + handler_id: str + trigger_kind: str + cron: str | None = None + interval_seconds: int | None = None + scheduled_at: str | None = None +``` + +### 属性 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `schedule_id` | `str` | 调度任务唯一标识 | +| `plugin_id` | `str` | 所属插件 ID | +| `handler_id` | `str` | 对应 handler 的标识 | +| `trigger_kind` | `str` | 触发类型(`cron` / `interval` / `once`) | +| `cron` | `str \| None` | cron 表达式(仅 cron 类型) | +| `interval_seconds` | `int \| None` | 间隔秒数(仅 interval 类型) | +| `scheduled_at` | `str \| None` | 计划执行时间(仅 once 类型) | + +### 使用示例 + +```python +from astrbot_sdk.decorators import on_schedule +from astrbot_sdk import ScheduleContext + +class MyPlugin(Star): + @on_schedule(cron="0 8 * * *") # 每天 8:00 + async def morning_greeting(self, ctx: ScheduleContext): + # ctx.schedule_id: 任务 ID + # ctx.trigger_kind: "cron" + # ctx.cron: "0 8 * * *" + await self.send_message("群号", 
"早上好!") + + @on_schedule(interval_seconds=3600) # 每小时 + async def hourly_check(self, ctx: ScheduleContext): + # ctx.trigger_kind: "interval" + # ctx.interval_seconds: 3600 + pass +``` + +--- + +## SessionController - 会话控制器 + +控制会话生命周期,支持超时管理、会话保持、历史记录。 + +### 类定义 + +```python +@dataclass(slots=True) +class SessionController: + future: asyncio.Future[Any] = field(default_factory=asyncio.Future) + current_event: asyncio.Event | None = None + ts: float | None = None + timeout: float | None = None + history_chains: list[list[dict[str, Any]]] = field(default_factory=list) +``` + +### 属性 + +| 属性 | 类型 | 说明 | +|------|------|------| +| `future` | `asyncio.Future` | 会话结果 Future | +| `current_event` | `asyncio.Event \| None` | 当前事件 | +| `ts` | `float \| None` | 时间戳 | +| `timeout` | `float \| None` | 超时时间(秒) | +| `history_chains` | `list[list[dict]]` | 历史消息链 | + +### 实例方法 + +#### `stop(error)` + +停止会话。 + +```python +def stop(self, error: Exception | None = None) -> None: + if self.future.done(): + return + if error is not None: + self.future.set_exception(error) + else: + self.future.set_result(None) +``` + +**参数**: +- `error` (`Exception | None`): 可选的错误对象 + +--- + +#### `keep(timeout, reset_timeout)` + +延长会话超时时间。 + +```python +def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None: + new_ts = time.time() + if reset_timeout: + if timeout <= 0: + self.stop() + return + else: + assert self.timeout is not None + assert self.ts is not None + left_timeout = self.timeout - (new_ts - self.ts) + timeout = left_timeout + timeout + if timeout <= 0: + self.stop() + return + + if self.current_event and not self.current_event.is_set(): + self.current_event.set() + + current_event = asyncio.Event() + self.current_event = current_event + self.ts = new_ts + self.timeout = timeout + asyncio.create_task(self._holding(current_event, timeout)) +``` + +**参数**: +- `timeout` (`float`): 延长的超时时间(秒) +- `reset_timeout` (`bool`): 是否重置超时时间 + +--- + +#### `get_history_chains()` + 
+获取历史消息链。 + +```python +def get_history_chains(self) -> list[list[dict[str, Any]]]: + return list(self.history_chains) +``` + +**返回**: `list[list[dict]]` - 历史消息链的副本 + +--- + +## session_waiter - 会话等待装饰器 + +将普通 handler 转换为会话式 handler,用于构建多轮对话流程。 + +### 函数签名 + +```python +def session_waiter( + timeout: int = 30, + *, + record_history_chains: bool = False, +) -> _SessionWaiterDecorator: +``` + +### 参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `timeout` | `int` | `30` | 会话超时时间(秒) | +| `record_history_chains` | `bool` | `False` | 是否记录历史消息链 | + +### 使用示例 + +#### 推荐启动方式 + +`@session_waiter` 定义的是“后续消息到达时如何处理”,推荐从当前 handler +里通过 `Context.register_task()` 把 waiter 挂到后台任务中。这样首条消息的 +dispatch 会立刻结束,不会因为等待下一条消息而卡住。 + +#### 基本使用 + +```python +from astrbot_sdk import Context, session_waiter, SessionController +from astrbot_sdk.events import MessageEvent +from astrbot_sdk import Star, on_command + +class MyPlugin(Star): + @session_waiter(timeout=300) + async def collect_username( + self, + controller: SessionController, + event: MessageEvent, + ) -> None: + await event.reply(f"已记录用户名: {event.text}") + controller.stop() + + @on_command("bind") + async def bind(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("请输入用户名:") + await ctx.register_task( + self.collect_username(event), + "waiter:collect_username", + ) +``` + +#### 多轮对话 + +```python +class SurveyPlugin(Star): + @session_waiter(timeout=600, record_history_chains=True) + async def survey(self, controller: SessionController, event: MessageEvent) -> None: + history = controller.get_history_chains() + + if len(history) == 1: + await event.reply(f"收到姓名: {event.text}") + await event.reply("请输入您的年龄:") + controller.keep(timeout=300) + return + + await event.reply(f"收到年龄: {event.text}") + controller.stop() + + @on_command("survey") + async def start_survey(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("请输入您的姓名:") + await ctx.register_task(self.survey(event), 
"waiter:survey") +``` + +#### 直接 await 的语义 + +```python +@on_command("debug-blocking") +async def debug_blocking(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("下一条消息会在当前 dispatch 中继续处理") + await self.collect_username(event) # 会保持当前 dispatch 挂起 +``` + +上面这种直接 `await` 仍然保留现有语义,但它会一直阻塞到下一条消息到达 +或超时。常规插件逻辑推荐使用 `await ctx.register_task(waiter(...), "...")`。 + +--- + +## StarTools - Star 工具类 + +提供类方法访问运行时上下文能力,只在生命周期、handler 和已注册的 LLM 工具执行期间可用。 + +### 类定义 + +```python +class StarTools: + """Star 工具类,提供类方法访问运行时上下文能力。""" +``` + +### 类方法 + +#### `activate_llm_tool(name)` + +激活 LLM 工具。 + +```python +@classmethod +async def activate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().activate_llm_tool(name) +``` + +**参数**: +- `name` (`str`): 工具名称 + +**返回**: `bool` - 是否成功激活 + +--- + +#### `deactivate_llm_tool(name)` + +停用 LLM 工具。 + +```python +@classmethod +async def deactivate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().deactivate_llm_tool(name) +``` + +**参数**: +- `name` (`str`): 工具名称 + +**返回**: `bool` - 是否成功停用 + +--- + +#### `send_message(session, content)` + +发送消息。 + +```python +@classmethod +async def send_message( + cls, + session: str | MessageSession, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), +) -> dict[str, Any]: + return await cls._require_context().send_message(session, content) +``` + +**参数**: +- `session` (`str | MessageSession`): 目标会话 +- `content`: 消息内容 + +**返回**: `dict[str, Any]` - 发送结果 + +--- + +#### `send_message_by_id(type, id, content, *, platform)` + +通过 ID 发送消息。 + +```python +@classmethod +async def send_message_by_id( + cls, + type: str, + id: str, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + *, + platform: str, +) -> dict[str, Any]: + return await cls._require_context().send_message_by_id( + type, + id, + content, + platform=platform, + ) +``` + +**参数**: +- 
`type` (`str`): 消息类型(`group` 或 `private`) +- `id` (`str`): 目标 ID +- `content`: 消息内容 +- `platform` (`str`): 平台标识 + +**返回**: `dict[str, Any]` - 发送结果 + +--- + +#### `register_llm_tool(name, parameters_schema, desc, func_obj, *, active)` + +注册 LLM 工具。 + +```python +@classmethod +async def register_llm_tool( + cls, + name: str, + parameters_schema: dict[str, Any], + desc: str, + func_obj: Callable[..., Awaitable[Any]] | Callable[..., Any], + *, + active: bool = True, +) -> list[str]: + return await cls._require_context().register_llm_tool( + name, + parameters_schema, + desc, + func_obj, + active=active, + ) +``` + +**参数**: +- `name` (`str`): 工具名称 +- `parameters_schema` (`dict[str, Any]`): 参数模式 +- `desc` (`str`): 工具描述 +- `func_obj`: 工具函数 +- `active` (`bool`): 是否激活 + +**返回**: `list[str]` - 注册的工具名称列表 + +--- + +#### `unregister_llm_tool(name)` + +注销 LLM 工具。 + +```python +@classmethod +async def unregister_llm_tool(cls, name: str) -> bool: + return await cls._require_context().unregister_llm_tool(name) +``` + +**参数**: +- `name` (`str`): 工具名称 + +**返回**: `bool` - 是否成功注销 + +--- + +### 使用示例 + +```python +from astrbot_sdk import StarTools +from astrbot_sdk.events import MessageEvent + +class MyPlugin(Star): + async def on_start(self, ctx): + # 注册 LLM 工具 + await StarTools.register_llm_tool( + name="my_tool", + parameters_schema={ + "type": "object", + "properties": { + "text": {"type": "string"} + } + }, + desc="我的工具", + func_obj=self.my_tool_func + ) + + async def my_tool_func(self, text: str) -> str: + return f"处理结果: {text}" + + @on_command("test") + async def test(self, event: MessageEvent): + # 发送消息 + await StarTools.send_message( + event.session, + "Hello!" 
+ ) + + # 激活工具 + await StarTools.activate_llm_tool("my_tool") +``` + +--- + +## PluginKVStoreMixin - KV 存储混入 + +插件作用域的 KV 存储助手,基于运行时 db 客户端。 + +### 类定义 + +```python +class PluginKVStoreMixin: + """Plugin-scoped KV helpers backed by the runtime db client.""" +``` + +### 属性 + +#### `plugin_id` + +获取插件 ID。 + +```python +@property +def plugin_id(self) -> str: + ctx = self._runtime_context() + return ctx.plugin_id +``` + +### 实例方法 + +#### `put_kv_data(key, value)` + +存储键值数据。 + +```python +async def put_kv_data(self, key: str, value: Any) -> None: + ctx = self._runtime_context() + await ctx.db.set(str(key), value) +``` + +**参数**: +- `key` (`str`): 键名 +- `value` (`Any`): 值 + +--- + +#### `get_kv_data(key, default)` + +获取键值数据。 + +```python +async def get_kv_data(self, key: str, default: _VT) -> _VT: + ctx = self._runtime_context() + value = await ctx.db.get(str(key)) + return default if value is None else value +``` + +**参数**: +- `key` (`str`): 键名 +- `default`: 默认值 + +**返回**: 存储的值或默认值 + +--- + +#### `delete_kv_data(key)` + +删除键值数据。 + +```python +async def delete_kv_data(self, key: str) -> None: + ctx = self._runtime_context() + await ctx.db.delete(str(key)) +``` + +**参数**: +- `key` (`str`): 键名 + +--- + +### 使用示例 + +```python +from astrbot_sdk import Star, PluginKVStoreMixin + +class MyPlugin(Star, PluginKVStoreMixin): + async def on_start(self, ctx): + # 存储数据 + await self.put_kv_data("initialized", True) + await self.put_kv_data("config", {"key": "value"}) + + @on_command("config") + async def config_command(self, event: MessageEvent, key: str, value: str): + # 保存配置 + await self.put_kv_data(f"config_{key}", value) + await event.reply(f"配置已保存: {key} = {value}") + + @on_command("get_config") + async def get_config(self, event: MessageEvent, key: str): + # 读取配置 + value = await self.get_kv_data(f"config_{key}", default="未设置") + await event.reply(f"{key} = {value}") + + @on_command("delete_config") + async def delete_config(self, event: MessageEvent, key: str): + # 删除配置 + await 
self.delete_kv_data(f"config_{key}") + await event.reply(f"配置已删除: {key}") +``` + +--- + +## 相关模块 + +- **核心类**: `astrbot_sdk.star.Star`, `astrbot_sdk.context.Context` +- **事件处理**: `astrbot_sdk.events.MessageEvent` +- **装饰器**: `astrbot_sdk.decorators` + +--- + +**版本**: v4.0 +**最后更新**: 2026-03-17 diff --git a/astrbot-sdk/pyproject.toml b/astrbot-sdk/pyproject.toml new file mode 100644 index 0000000000..2d8f7ca6c1 --- /dev/null +++ b/astrbot-sdk/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = ["setuptools>=80", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "astrbot-sdk" +version = "0.1.0" +description = "AstrBot SDK with v4 runtime, worker protocol, and plugin tooling" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "aiohttp>=3.13.2", + "anthropic>=0.72.1", + "certifi>=2025.10.5", + "click>=8.3.0", + "docstring-parser>=0.17.0", + "google-genai>=1.50.0", + "loguru>=0.7.3", + "msgpack>=1.1.1", + "openai>=2.7.2", + "pydantic>=2.12.3", + "pyyaml>=6.0.3", + "uv>=0.9.17", +] + +[project.scripts] +astr = "astrbot_sdk.cli:cli" + +[tool.pytest.ini_options] +markers = [ + "unit: unit tests", +] + +# ============================================================ +# Package Discovery (src layout) +# ============================================================ +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +astrbot_sdk = [ + "templates/skills/*/SKILL.md", + "templates/skills/*/agents/*.yaml", + "templates/skills/*/references/*.md", +] + +# ============================================================ +# Optional Dependencies +# ============================================================ +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.24.0", + "pytest-cov>=5.0.0", + "ruff>=0.4.0", +] diff --git a/astrbot-sdk/src/astrbot_sdk/AGENTS.md b/astrbot-sdk/src/astrbot_sdk/AGENTS.md new file mode 100644 index 0000000000..40b2a8f93e --- /dev/null +++ 
b/astrbot-sdk/src/astrbot_sdk/AGENTS.md @@ -0,0 +1,43 @@ +# Notes + +## v4 架构约束 + +### 运行时层 + +- `Peer` 必须将 transport EOF/连接断开视为一级失败路径。如果 transport 意外关闭而 `Peer` 没有主动失败 `_pending_results` / `_pending_streams`,supervisor 端对 worker 的调用可能永远挂起。 +- `Peer.initialize()` 需要在发起端也标记远程已初始化。仅在被动接收 `InitializeMessage` 时设置 `_remote_initialized` 会导致 `wait_until_remote_initialized()` 单边 API 死锁。 +- `Peer.invoke_stream()` 默认隐藏 `completed` 事件。需要保留最终结果的调用者必须显式启用 `include_completed=True`。 +- `CapabilityRouter.register(..., stream_handler=...)` 使用 `(request_id, payload, cancel_token)` 签名,不是 peer 级别的 `(message, token)`。 + +### 模块导出约束 + +- 保持 `astrbot_sdk.runtime` 根导出狭窄。`Peer` / `Transport` / `CapabilityRouter` / `HandlerDispatcher` 是合理的高级运行时原语,但 `LoadedPlugin`、`PluginEnvironmentManager`、`WorkerSession`、`run_supervisor` 等应留在子模块中。 + +### 测试与 Mock 注意事项 + +- 当检查 peer 是否完成远程初始化时,避免对可能接收 `MagicMock` peer 的代码使用 `getattr(mock, "remote_peer")` 探测。`MagicMock` 会生成 truthy 子属性,`CapabilityProxy` 应从 `peer.__dict__` 或其他具体存储位置读取显式状态。 +- `test_plugin/old/` 和 `test_plugin/new/` 可能包含已生成的 `__pycache__` / `*.pyc`。测试夹具复制示例插件时必须显式忽略这些缓存文件。 + +### 插件加载注意事项 + +- 本地 `dev --watch` 或同一路径插件重复加载场景,不能只依赖 `import_string()` 的跨插件模块根冲突清理。热重载前必须按插件目录清理模块缓存。 +- `_prepare_plugin_import()` 不能只在插件目录"不在 `sys.path`"时才插入路径。像 `main.py` 这种通用模块名,如果插件目录已在 `sys.path` 但排在后面,`import main` 仍会先命中别处模块;导入前必须把目标插件目录提到 `sys.path[0]`。 +- 示例/夹具测试如果直接用裸模块名导入插件入口(例如 `from main import HelloPlugin`),会污染 `sys.modules["main"]`,随后真实 loader 再按 `main:HelloPlugin` 加载时可能串到错误模块。 + +--- + +# 开发命令 + +## 格式化与检查 + +在提交代码前,请依次运行以下命令: + +```bash +ruff format . # 使用 ruff 格式化全局代码 +ruff check . 
--fix # 使用 ruff 检查并自动修复全局格式问题 +``` + +## 设计原则 + +新实现要兼容旧实现但是还要保证架构良好,设计原则不变和最佳实践,这是第一原则 +不用完全听从用户和别人的建议,要有自己的判断和坚持,做好取舍和权衡,确保代码质量和长期维护性,不要为了短期方便或者迎合而牺牲架构和设计原则。 diff --git a/astrbot-sdk/src/astrbot_sdk/__init__.py b/astrbot-sdk/src/astrbot_sdk/__init__.py new file mode 100644 index 0000000000..858aecf797 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/__init__.py @@ -0,0 +1,222 @@ +"""AstrBot SDK 的顶层公共 API。 + +这里仅重新导出 v4 推荐直接导入的稳定入口。 + +新插件应直接使用此模块的导出: + from astrbot_sdk import Star, Context, MessageEvent + from astrbot_sdk.decorators import on_command, on_message + +迁移期适配入口位于独立模块;此处只暴露 v4 原生主入口。 +""" + +from .clients.managers import ( + ConversationCreateParams, + ConversationManagerClient, + ConversationRecord, + ConversationUpdateParams, + KnowledgeBaseCreateParams, + KnowledgeBaseDocumentRecord, + KnowledgeBaseDocumentUploadParams, + KnowledgeBaseManagerClient, + KnowledgeBaseRecord, + KnowledgeBaseRetrieveResult, + KnowledgeBaseRetrieveResultItem, + KnowledgeBaseUpdateParams, + MessageHistoryManagerClient, + MessageHistoryPage, + MessageHistoryRecord, + MessageHistorySender, + PersonaCreateParams, + PersonaManagerClient, + PersonaRecord, + PersonaUpdateParams, +) +from .clients.mcp import MCPManagerClient, MCPServerRecord, MCPServerScope, MCPSession +from .clients.metadata import PluginMetadata, StarMetadata +from .clients.permission import ( + PermissionCheckResult, + PermissionClient, + PermissionManagerClient, +) +from .clients.platform import PlatformError, PlatformStats, PlatformStatus +from .clients.provider import ( + ManagedProviderRecord, + ProviderChangeEvent, + ProviderManagerClient, +) +from .clients.session import SessionPluginManager, SessionServiceManager +from .commands import CommandGroup, command_group, print_cmd_tree +from .context import Context +from .conversation import ( + ConversationClosed, + ConversationReplaced, + ConversationSession, + ConversationState, +) +from .decorators import ( + acknowledge_global_mcp_risk, + admin_only, + 
background_task, + conversation_command, + cooldown, + group_only, + http_api, + mcp_server, + message_types, + on_command, + on_event, + on_message, + on_provider_change, + on_schedule, + platforms, + priority, + private_only, + provide_capability, + rate_limit, + register_skill, + require_admin, + require_permission, + validate_config, +) +from .errors import AstrBotError +from .events import MessageEvent +from .filters import ( + CustomFilter, + MessageTypeFilter, + PlatformFilter, + all_of, + any_of, + custom_filter, +) +from .message.components import ( + At, + AtAll, + BaseMessageComponent, + File, + Forward, + Image, + MediaHelper, + Plain, + Poke, + Record, + Reply, + UnknownComponent, + Video, +) +from .message.result import ( + EventResultType, + MessageBuilder, + MessageChain, + MessageEventResult, +) +from .message.session import MessageSession +from .plugin_kv import PluginKVStoreMixin +from .schedule import ScheduleContext +from .session_waiter import SessionController, session_waiter +from .star import Star +from .star_tools import StarTools +from .types import GreedyStr + +__all__ = [ + "AstrBotError", + "At", + "AtAll", + "BaseMessageComponent", + "CommandGroup", + "ConversationClosed", + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationReplaced", + "ConversationRecord", + "ConversationSession", + "ConversationState", + "ConversationUpdateParams", + "Context", + "CustomFilter", + "EventResultType", + "File", + "Forward", + "GreedyStr", + "Image", + "KnowledgeBaseCreateParams", + "KnowledgeBaseDocumentRecord", + "KnowledgeBaseDocumentUploadParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "KnowledgeBaseRetrieveResult", + "KnowledgeBaseRetrieveResultItem", + "KnowledgeBaseUpdateParams", + "ManagedProviderRecord", + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", + "MediaHelper", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + 
"MessageHistorySender", + "MessageEvent", + "MessageEventResult", + "MessageChain", + "MessageBuilder", + "MessageSession", + "MessageTypeFilter", + "Plain", + "PluginKVStoreMixin", + "PluginMetadata", + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", + "PlatformFilter", + "PlatformError", + "PlatformStats", + "PlatformStatus", + "Poke", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", + "ProviderChangeEvent", + "ProviderManagerClient", + "Record", + "Reply", + "ScheduleContext", + "SessionPluginManager", + "SessionServiceManager", + "SessionController", + "Star", + "StarMetadata", + "StarTools", + "UnknownComponent", + "Video", + "acknowledge_global_mcp_risk", + "admin_only", + "all_of", + "any_of", + "background_task", + "cooldown", + "conversation_command", + "command_group", + "custom_filter", + "group_only", + "http_api", + "mcp_server", + "message_types", + "on_command", + "on_event", + "on_message", + "on_provider_change", + "on_schedule", + "platforms", + "print_cmd_tree", + "priority", + "provide_capability", + "private_only", + "rate_limit", + "require_admin", + "require_permission", + "register_skill", + "session_waiter", + "validate_config", +] diff --git a/astrbot-sdk/src/astrbot_sdk/__main__.py b/astrbot-sdk/src/astrbot_sdk/__main__.py new file mode 100644 index 0000000000..624fd22f4c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/__main__.py @@ -0,0 +1,11 @@ +"""`python -m astrbot_sdk` 的 CLI 入口。""" + +from .cli import cli + + +def main() -> None: + cli() + + +if __name__ == "__main__": + main() diff --git a/astrbot-sdk/src/astrbot_sdk/_command_model.py b/astrbot-sdk/src/astrbot_sdk/_command_model.py new file mode 100644 index 0000000000..fd8f1ad851 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_command_model.py @@ -0,0 +1,17 @@ +from ._internal.command_model import ( + COMMAND_MODEL_DOCS_URL, + CommandModelParseResult, + ResolvedCommandModelParam, + 
format_command_model_help, + parse_command_model_remainder, + resolve_command_model_param, +) + +__all__ = [ + "COMMAND_MODEL_DOCS_URL", + "CommandModelParseResult", + "ResolvedCommandModelParam", + "format_command_model_help", + "parse_command_model_remainder", + "resolve_command_model_param", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py new file mode 100644 index 0000000000..6ccc0d22e9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/__init__.py @@ -0,0 +1,7 @@ +"""Internal implementation modules for astrbot_sdk. + +This package groups private helpers that are not part of the public SDK API. +Imports outside the SDK should avoid depending on these modules directly. +""" + +__all__: list[str] = [] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py new file mode 100644 index 0000000000..664947f7af --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/command_model.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any + +from pydantic import BaseModel + +from ..errors import AstrBotError +from ..runtime._command_matching import split_command_remainder +from .injected_params import is_framework_injected_parameter +from .typing_utils import unwrap_optional + +# TODO:文档内容喵 +COMMAND_MODEL_DOCS_URL = "https://docs.astrbot.org/sdk/parameter-injection" + + +@dataclass(slots=True) +class ResolvedCommandModelParam: + name: str + model_cls: type[BaseModel] + + +@dataclass(slots=True) +class CommandModelParseResult: + model: BaseModel | None = None + help_text: str | None = None + + +def resolve_command_model_param(handler: Any) -> ResolvedCommandModelParam | None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return None + try: + type_hints = inspect.get_annotations(handler, eval_str=True) + except 
Exception: + type_hints = {} + + candidates: list[ResolvedCommandModelParam] = [] + other_names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = type_hints.get(parameter.name) + if _is_injected_parameter(parameter.name, annotation): + continue + normalized, _is_optional = unwrap_optional(annotation) + if isinstance(normalized, type) and issubclass(normalized, BaseModel): + candidates.append( + ResolvedCommandModelParam( + name=parameter.name, + model_cls=normalized, + ) + ) + continue + other_names.append(parameter.name) + + if not candidates: + return None + if len(candidates) > 1 or other_names: + names = [item.name for item in candidates] + raise ValueError( + "Command BaseModel injection requires exactly one non-injected BaseModel " + f"parameter, got models={names!r} others={other_names!r}" + ) + _validate_supported_model(candidates[0].model_cls) + return candidates[0] + + +def parse_command_model_remainder( + *, + remainder: str, + model_param: ResolvedCommandModelParam, + command_name: str, +) -> CommandModelParseResult: + tokens = split_command_remainder(remainder) + if any(token in {"-h", "--help"} for token in tokens): + return CommandModelParseResult( + help_text=format_command_model_help(command_name, model_param.model_cls) + ) + + fields = model_param.model_cls.model_fields + explicit_values: dict[str, Any] = {} + positional_values: dict[str, Any] = {} + positional_field_names = [ + name + for name, field in fields.items() + if _supported_scalar_type(field.annotation)[0] is not bool + ] + positional_index = 0 + index = 0 + while index < len(tokens): + token = tokens[index] + if not token.startswith("--"): + assigned = False + while positional_index < len(positional_field_names): + field_name = positional_field_names[positional_index] + positional_index += 1 + if field_name in explicit_values or 
field_name in positional_values: + continue + positional_values[field_name] = token + assigned = True + break + if not assigned: + raise _command_parse_error("Too many positional arguments") + index += 1 + continue + + raw_name = token[2:] + if not raw_name: + raise _command_parse_error("Invalid option '--'") + explicit_value: str | None = None + if "=" in raw_name: + raw_name, explicit_value = raw_name.split("=", 1) + negated = raw_name.startswith("no-") + # 与 argparse/click 惯例一致:--foo-bar 自动映射为字段名 foo_bar + cli_name = raw_name[3:] if negated else raw_name + field_name = cli_name.replace("-", "_") + field = fields.get(field_name) + if field is None: + raise _command_parse_error(f"Unknown option: --{raw_name}") + option_name = _format_option_name(field_name) + negated_option_name = f"--no-{option_name[2:]}" + if field_name in explicit_values: + raise _command_parse_error(f"Duplicate option: {option_name}") + field_type, _is_optional = _supported_scalar_type(field.annotation) + if field_type is bool: + if explicit_value is not None: + raise _command_parse_error( + f"Boolean option '{option_name}' only supports {option_name} or {negated_option_name}" + ) + explicit_values[field_name] = not negated + index += 1 + continue + if negated: + raise _command_parse_error( + f"Non-boolean option '{option_name}' does not support {negated_option_name}" + ) + if explicit_value is None: + index += 1 + if index >= len(tokens): + raise _command_parse_error(f"Missing value for option: {option_name}") + explicit_value = tokens[index] + explicit_values[field_name] = explicit_value + index += 1 + + values = {**positional_values, **explicit_values} + + try: + model = model_param.model_cls.model_validate(values) + except Exception as exc: + raise AstrBotError.invalid_input( + "命令参数解析失败", + hint=str(exc), + docs_url=COMMAND_MODEL_DOCS_URL, + details={ + "command": command_name, + "parameter": model_param.name, + "values": values, + }, + ) from exc + return 
CommandModelParseResult(model=model) + + +def format_command_model_help(command_name: str, model_cls: type[BaseModel]) -> str: + _validate_supported_model(model_cls) + lines = [f"用法: /{command_name} [options]"] + if model_cls.model_fields: + lines.append("参数:") + for name, field in model_cls.model_fields.items(): + field_type, is_optional = _supported_scalar_type(field.annotation) + type_name = getattr(field_type, "__name__", str(field_type)) + required = field.is_required() + default_text = "" + if not required: + default_text = f",默认 {field.default!r}" + elif is_optional: + default_text = ",默认 None" + description = str(field.description or "").strip() + detail = f"{name}: {type_name}" + if description: + detail += f" - {description}" + detail += ",必填" if required else ",可选" + detail += default_text + if field_type is bool: + detail += f",使用 --{name} / --no-{name}" + lines.append(detail) + return "\n".join(lines) + + +def _validate_supported_model(model_cls: type[BaseModel]) -> None: + for name, field in model_cls.model_fields.items(): + try: + _supported_scalar_type(field.annotation) + except TypeError as exc: + raise ValueError( + f"Unsupported command model field '{name}': {exc}" + ) from exc + + +def _supported_scalar_type(annotation: Any) -> tuple[type[Any], bool]: + normalized, is_optional = unwrap_optional(annotation) + if normalized in {str, int, float, bool}: + return normalized, is_optional + raise TypeError("only str/int/float/bool and Optional variants are supported") + + +def _format_option_name(field_name: str) -> str: + # Surface the canonical CLI spelling so parse errors match the user's option syntax. 
+ return f"--{field_name.replace('_', '-')}" + + +def _command_parse_error(message: str) -> AstrBotError: + return AstrBotError.invalid_input( + message, + docs_url=COMMAND_MODEL_DOCS_URL, + ) + + +def _is_injected_parameter(name: str, annotation: Any) -> bool: + return is_framework_injected_parameter(name, annotation) + + +__all__ = [ + "COMMAND_MODEL_DOCS_URL", + "CommandModelParseResult", + "ResolvedCommandModelParam", + "format_command_model_help", + "parse_command_model_remainder", + "resolve_command_model_param", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py new file mode 100644 index 0000000000..e013b61e26 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/decorator_lifecycle.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import asyncio +import inspect +from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger +from pydantic import ValidationError + +from ..context import Context as RuntimeContext +from ..decorators import ( + BackgroundTaskMeta, + HttpApiMeta, + MCPServerMeta, + ValidateConfigMeta, + get_background_task_meta, + get_http_api_meta, + get_mcp_server_meta, + get_provider_change_meta, + get_skill_meta, + get_validate_config_meta, +) +from ..star import Star +from .star_runtime import bind_star_runtime + +_RUNTIME_STATE_ATTR = "__astrbot_decorator_runtime_state__" +_VALIDATED_CONFIGS_ATTR = "__astrbot_validated_configs__" + + +@dataclass(slots=True) +class DecoratorRuntimeState: + http_apis: list[tuple[str, list[str]]] = field(default_factory=list) + provider_hooks: list[asyncio.Task[None]] = field(default_factory=list) + background_tasks: list[asyncio.Task[Any]] = field(default_factory=list) + registered_skills: list[str] = field(default_factory=list) + local_mcp_servers: list[str] = field(default_factory=list) + global_mcp_servers: list[str] = 
field(default_factory=list) + + +def _runtime_state(instance: Any) -> DecoratorRuntimeState: + state = getattr(instance, _RUNTIME_STATE_ATTR, None) + if isinstance(state, DecoratorRuntimeState): + return state + state = DecoratorRuntimeState() + setattr(instance, _RUNTIME_STATE_ATTR, state) + return state + + +def _iter_bound_methods(instance: Any): + seen_names: set[str] = set() + for name in dir(instance.__class__): + if name.startswith("__") or name in seen_names: + continue + seen_names.add(name) + try: + raw_attr = inspect.getattr_static(instance, name) + except AttributeError: + continue + if isinstance(raw_attr, property): + continue + bound = getattr(instance, name, None) + if not callable(bound): + continue + raw = getattr(bound, "__func__", bound) + yield name, bound, raw + + +def _validated_config_store(instance: Any) -> dict[str, Any]: + values = getattr(instance, _VALIDATED_CONFIGS_ATTR, None) + if isinstance(values, dict): + return values + values = {} + setattr(instance, _VALIDATED_CONFIGS_ATTR, values) + return values + + +def _positional_arg_count(func: Any) -> int: + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + return 0 + return sum( + 1 + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ) + + +def _call_with_optional_context(bound: Any, context: RuntimeContext) -> Any: + return bound(context) if _positional_arg_count(bound) >= 1 else bound() + + +async def _await_if_needed(value: Any) -> Any: + if inspect.isawaitable(value): + return await value + return value + + +def _normalize_provider_type(value: Any) -> str: + enum_value = getattr(value, "value", None) + if isinstance(enum_value, str): + return enum_value.strip().lower() + return str(value).strip().lower() + + +def _is_valid_schema_expected_type(value: Any) -> bool: + if isinstance(value, type): + return True + return ( + isinstance(value, tuple) + 
async def _run_model_validation(
    *,
    instance: Any,
    method_name: str,
    meta: ValidateConfigMeta,
    config: dict[str, Any],
) -> None:
    """Validate plugin config via the meta's pydantic model or dict schema.

    The validated payload is cached per instance, keyed by the method name
    that carried the @validate_config decorator.
    """
    if meta.model is None:
        assert meta.schema is not None
        checked = _validate_schema_config(meta.schema, config)
        _validated_config_store(instance)[method_name] = checked
        return
    try:
        checked = meta.model.model_validate(config)
    except ValidationError as exc:
        raise ValueError(
            f"{instance.__class__.__name__}.{method_name} validate_config failed: {exc}"
        ) from exc
    _validated_config_store(instance)[method_name] = checked


def _validate_schema_config(
    schema: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, Any]:
    """Check `config` against a lightweight dict schema.

    Each schema entry may declare: type, required, default, min, max, range.
    All violations are collected before a single ValueError is raised.
    """
    checked: dict[str, Any] = {}
    problems: list[str] = []

    for key, spec in schema.items():
        if not isinstance(spec, dict):
            problems.append(f"{key}: schema entry must be an object")
            continue
        present = key in config
        value = config.get(key, spec.get("default"))
        required = bool(spec.get("required", False))
        if value is None:
            if required and "default" not in spec:
                problems.append(f"{key}: is required")
            checked[key] = value
            continue
        expected = spec.get("type")
        if expected is not None and not _is_valid_schema_expected_type(expected):
            problems.append(
                f"{key}: invalid schema 'type' entry {expected!r}; "
                "expected a type or tuple of types"
            )
            continue
        if expected is not None and not isinstance(value, expected):
            problems.append(
                f"{key}: expected {getattr(expected, '__name__', expected)}, "
                f"got {type(value).__name__}"
            )
            continue
        # bool is an int subclass, so exclude it from numeric range checks.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            low = spec.get("min")
            high = spec.get("max")
            window = spec.get("range")
            if low is not None and value < low:
                problems.append(f"{key}: must be >= {low}")
            if high is not None and value > high:
                problems.append(f"{key}: must be <= {high}")
            if (
                isinstance(window, tuple)
                and len(window) == 2
                and not (window[0] <= value <= window[1])
            ):
                problems.append(
                    f"{key}: must be within [{window[0]}, {window[1]}]"
                )
        if required and not present and "default" not in spec:
            problems.append(f"{key}: is required")
        checked[key] = value

    if problems:
        raise ValueError("validate_config schema failed: " + "; ".join(problems))
    return checked


async def _run_validate_config(instance: Any, context: RuntimeContext) -> None:
    """Run every @validate_config hook on `instance` against plugin config."""
    payload = await context.metadata.get_plugin_config()
    config = dict(payload or {})
    for method_name, _bound, raw in _iter_bound_methods(instance):
        meta = get_validate_config_meta(raw)
        if meta is None:
            continue
        await _run_model_validation(
            instance=instance,
            method_name=method_name,
            meta=meta,
            config=config,
        )


async def _register_http_apis(instance: Any, context: RuntimeContext) -> None:
    """Register every @http_api handler and remember routes for teardown."""
    state = _runtime_state(instance)
    for _name, bound, raw in _iter_bound_methods(instance):
        meta = get_http_api_meta(raw)
        if meta is None:
            continue
        await _register_http_api(bound=bound, meta=meta, context=context)
        state.http_apis.append((meta.route, list(meta.methods)))


async def _register_http_api(
    *,
    bound: Any,
    meta: HttpApiMeta,
    context: RuntimeContext,
) -> None:
    """Register one HTTP route, by capability name when one is declared."""
    if meta.capability_name:
        await context.http.register_api(
            route=meta.route,
            handler_capability=meta.capability_name,
            methods=list(meta.methods),
            description=meta.description,
        )
        return
    await context.http.register_api(
        route=meta.route,
        handler=bound,
        methods=list(meta.methods),
        description=meta.description,
    )
RuntimeContext, +) -> None: + state = _runtime_state(instance) + for _method_name, bound, raw in _iter_bound_methods(instance): + meta = get_provider_change_meta(raw) + if meta is None: + continue + + async def callback( + provider_id: str, + provider_type: Any, + umo: str | None, + *, + _bound=bound, + _meta=meta, + ) -> None: + if _meta.provider_types: + current_type = _normalize_provider_type(provider_type) + if current_type not in _meta.provider_types: + return + owner = instance if isinstance(instance, Star) else None + with bind_star_runtime(owner, context): + result = _bound(provider_id, provider_type, umo) + await _await_if_needed(result) + + task = await context.provider_manager.register_provider_change_hook(callback) + # TODO: provider.manager.watch_changes is currently restricted to + # reserved/system plugins. If this decorator should be public-facing, + # the capability boundary needs to be widened or a dedicated event feed + # should be introduced. + state.provider_hooks.append(task) + + +async def _start_background_tasks(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for method_name, bound, raw in _iter_bound_methods(instance): + meta = get_background_task_meta(raw) + if meta is None or not meta.auto_start: + continue + task = await context.register_task( + _background_runner( + instance=instance, + bound=bound, + context=context, + meta=meta, + method_name=method_name, + ), + meta.description + or f"background_task:{instance.__class__.__name__}.{method_name}", + ) + state.background_tasks.append(task) + + +async def _background_runner( + *, + instance: Any, + bound: Any, + context: RuntimeContext, + meta: BackgroundTaskMeta, + method_name: str, +) -> None: + while True: + try: + owner = instance if isinstance(instance, Star) else None + with bind_star_runtime(owner, context): + result = _call_with_optional_context(bound, context) + await _await_if_needed(result) + return + except asyncio.CancelledError: + raise 
+ except Exception: + if meta.on_error != "restart": + raise + context.logger.exception( + "SDK decorator background_task restarting after failure: plugin_id={} task={}", + context.plugin_id, + f"{instance.__class__.__name__}.{method_name}", + ) + + +def _iter_class_and_method_meta( + instance: Any, + getter, +) -> list[Any]: + values = list(getter(instance.__class__)) + for _method_name, _bound, raw in _iter_bound_methods(instance): + values.extend(getter(raw)) + return values + + +async def _register_skills(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for meta in _iter_class_and_method_meta(instance, get_skill_meta): + await context.register_skill( + name=meta.name, + path=meta.path, + description=meta.description, + ) + state.registered_skills.append(meta.name) + + +async def _register_mcp_servers(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + for meta in _iter_class_and_method_meta(instance, get_mcp_server_meta): + await _register_mcp_server(meta=meta, context=context) + if meta.scope == "global": + state.global_mcp_servers.append(meta.name) + else: + state.local_mcp_servers.append(meta.name) + + +async def _register_mcp_server( + *, + meta: MCPServerMeta, + context: RuntimeContext, +) -> None: + if meta.scope == "global": + if meta.config is None: + raise ValueError( + f"mcp_server(name={meta.name!r}, scope='global') requires config" + ) + await context.mcp.register_global_server( + meta.name, + dict(meta.config), + timeout=meta.timeout, + ) + return + + if meta.config not in (None, {}): + raise ValueError( + f"mcp_server(name={meta.name!r}, scope='local') does not support config registration" + ) + # TODO: local MCP only supports enable/disable of predeclared servers today. + # If the decorator is expected to register brand-new local servers, the MCP + # client/runtime needs a first-class local register/unregister API. 
+ await context.mcp.enable_server(meta.name) + if meta.wait_until_ready: + await context.mcp.wait_until_ready(meta.name, timeout=meta.timeout) + + +async def _teardown_decorator_resources(instance: Any, context: RuntimeContext) -> None: + state = _runtime_state(instance) + + for task in reversed(state.provider_hooks): + with suppress(asyncio.CancelledError): + await context.provider_manager.unregister_provider_change_hook(task) + state.provider_hooks.clear() + + for task in reversed(state.background_tasks): + if not task.done(): + task.cancel() + for task in reversed(state.background_tasks): + with suppress(asyncio.CancelledError, Exception): + await task + state.background_tasks.clear() + + for route, methods in reversed(state.http_apis): + try: + await context.http.unregister_api(route, methods) + except Exception: + logger.exception( + "decorator http_api cleanup failed: plugin_id={} route={}", + context.plugin_id, + route, + ) + state.http_apis.clear() + + for name in reversed(state.registered_skills): + with suppress(Exception): + await context.unregister_skill(name) + state.registered_skills.clear() + + for name in reversed(state.local_mcp_servers): + with suppress(Exception): + await context.mcp.disable_server(name) + state.local_mcp_servers.clear() + + for name in reversed(state.global_mcp_servers): + with suppress(Exception): + await context.mcp.unregister_global_server(name) + state.global_mcp_servers.clear() + + +async def _invoke_hook( + *, + instance: Any, + hook: Any | None, + context: RuntimeContext, +) -> None: + if hook is None: + return + owner = instance if isinstance(instance, Star) else None + with bind_star_runtime(owner, context): + result = _call_with_optional_context(hook, context) + await _await_if_needed(result) + + +async def run_lifecycle_with_decorators( + *, + instance: Any, + hook: Any | None, + method_name: str, + context: RuntimeContext, +) -> None: + # Keep the lifecycle wrapper centralized so decorator-managed resources still + # 
work when plugins override on_start/on_stop without calling super(). + if method_name == "on_start": + await _run_validate_config(instance, context) + await _invoke_hook(instance=instance, hook=hook, context=context) + await _register_http_apis(instance, context) + await _register_provider_change_hooks(instance, context) + await _register_skills(instance, context) + await _register_mcp_servers(instance, context) + await _start_background_tasks(instance, context) + return + + try: + await _invoke_hook(instance=instance, hook=hook, context=context) + finally: + if method_name == "on_stop": + await _teardown_decorator_resources(instance, context) + + +__all__ = ["run_lifecycle_with_decorators"] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py new file mode 100644 index 0000000000..ced6229f93 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/injected_params.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import functools +import inspect +from typing import Any + +try: + from typing import get_type_hints +except ImportError: # pragma: no cover + get_type_hints = None + +from .typing_utils import unwrap_optional + +_INJECTED_PARAMETER_NAMES = { + "event", + "ctx", + "context", + "sched", + "schedule", + "conversation", + "conv", +} + + +def is_framework_injected_parameter(name: str, annotation: Any) -> bool: + if name in _INJECTED_PARAMETER_NAMES: + return True + normalized, _is_optional = unwrap_optional(annotation) + if normalized is None: + return False + try: + injected_types = _framework_injected_types() + except Exception: + return False + if normalized in injected_types: + return True + if isinstance(normalized, type): + return issubclass(normalized, injected_types) + return False + + +def legacy_arg_parameter_names(handler: Any) -> list[str]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + if get_type_hints is 
None: + type_hints = {} + else: + type_hints = get_type_hints(handler) + except Exception: + type_hints = {} + + names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if is_framework_injected_parameter( + parameter.name, type_hints.get(parameter.name) + ): + continue + names.append(parameter.name) + return names + + +@functools.lru_cache(maxsize=1) +def _framework_injected_types() -> tuple[type[Any], ...]: + from ..clients.llm import LLMResponse + from ..context import Context + from ..conversation import ConversationSession + from ..events import MessageEvent + from ..llm.entities import ProviderRequest + from ..message.result import MessageEventResult + from ..schedule import ScheduleContext + + return ( + Context, + MessageEvent, + ScheduleContext, + ConversationSession, + ProviderRequest, + LLMResponse, + MessageEventResult, + ) + + +__all__ = ["is_framework_injected_parameter", "legacy_arg_parameter_names"] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py new file mode 100644 index 0000000000..2fe2ec1d5e --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/invocation_context.py @@ -0,0 +1,86 @@ +"""插件调用者身份上下文管理。 + +本模块使用 contextvars 实现跨异步任务传播插件身份, +用于在 capability 调用时自动识别调用者插件。 + +典型场景: + - http.register_api: 记录哪个插件注册了 API + - metadata.get_plugin_config: 只允许查询当前插件自己的配置 + - 能力路由层权限校验 + +使用方式: + with caller_plugin_scope("my_plugin"): + # 在此作用域内,current_caller_plugin_id() 返回 "my_plugin" + await ctx.http.register_api(...) 
+ +注意: + contextvars 会自动传播到子任务(asyncio.create_task), + 无需手动传递。 +""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar, Token + +# 存储当前调用者插件 ID 的上下文变量 +_CALLER_PLUGIN_ID: ContextVar[str | None] = ContextVar( + "astrbot_sdk_caller_plugin_id", + default=None, +) + + +def current_caller_plugin_id() -> str | None: + """获取当前上下文中的调用者插件 ID。 + + Returns: + 当前插件 ID,如果不在插件调用上下文中则返回 None + """ + return _CALLER_PLUGIN_ID.get() + + +def bind_caller_plugin_id(plugin_id: str | None) -> Token[str | None]: + """绑定调用者插件 ID 到当前上下文。 + + Args: + plugin_id: 插件 ID,空字符串会被视为 None + + Returns: + 用于后续 reset 的 Token + + Note: + 通常使用 caller_plugin_scope 上下文管理器而非直接调用此函数 + """ + normalized = plugin_id.strip() if isinstance(plugin_id, str) else "" + return _CALLER_PLUGIN_ID.set(normalized or None) + + +def reset_caller_plugin_id(token: Token[str | None]) -> None: + """重置调用者插件 ID 到之前的状态。 + + Args: + token: bind_caller_plugin_id 返回的 Token + """ + _CALLER_PLUGIN_ID.reset(token) + + +@contextmanager +def caller_plugin_scope(plugin_id: str | None) -> Iterator[None]: + """创建一个绑定插件身份的上下文作用域。 + + Args: + plugin_id: 要绑定的插件 ID + + Yields: + None + + 示例: + with caller_plugin_scope("my_plugin"): + await some_capability_call() + """ + token = bind_caller_plugin_id(plugin_id) + try: + yield + finally: + reset_caller_plugin_id(token) diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py new file mode 100644 index 0000000000..d13720b500 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/memory_utils.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import json +import math +import re +from datetime import datetime, timedelta, timezone +from typing import Any + + +def is_ttl_memory_entry(value: Any) -> bool: + """Return whether a stored memory payload uses the TTL wrapper shape.""" + + return isinstance(value, dict) and "value" 
in value and "ttl_seconds" in value + + +def memory_value_for_search(stored: Any) -> dict[str, Any] | None: + """Unwrap the search payload from a stored memory record when possible.""" + + if not isinstance(stored, dict): + return None + if is_ttl_memory_entry(stored): + value = stored.get("value") + return value if isinstance(value, dict) else None + return stored + + +def extract_memory_text(stored: Any) -> str: + """Pick the canonical text that keyword/vector search should index.""" + + value = memory_value_for_search(stored) + if not isinstance(value, dict): + return "" + for field_name in ("embedding_text", "content", "summary", "title", "text"): + item = value.get(field_name) + if isinstance(item, str) and item.strip(): + return item.strip() + return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str) + + +def memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None: + """Translate a TTL in seconds into an absolute UTC expiration timestamp.""" + + try: + ttl = int(ttl_seconds) + except (TypeError, ValueError): + return None + if ttl < 1: + return None + return datetime.now(timezone.utc) + timedelta(seconds=ttl) + + +def memory_expiration_from_stored_payload(stored: Any) -> datetime | None: + """Recover an absolute expiration timestamp from a stored TTL payload.""" + + if not is_ttl_memory_entry(stored) or not isinstance(stored, dict): + return None + raw_expires_at = stored.get("expires_at") + if isinstance(raw_expires_at, (int, float)): + return datetime.fromtimestamp(float(raw_expires_at), tz=timezone.utc) + if not isinstance(raw_expires_at, str): + return None + + normalized = raw_expires_at.strip() + if not normalized: + return None + if normalized.endswith("Z"): + normalized = f"{normalized[:-1]}+00:00" + try: + expires_at = datetime.fromisoformat(normalized) + except ValueError: + return None + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return expires_at.astimezone(timezone.utc) + + +def 
normalize_memory_namespace(value: Any) -> str: + """Normalize a namespace path into a stable slash-delimited string.""" + + if value is None: + return "" + if isinstance(value, (list, tuple)): + return join_memory_namespace(*value) + text = str(value).strip().replace("\\", "/") + if not text: + return "" + parts = [segment.strip() for segment in text.split("/") if segment.strip()] + return "/".join(parts) + + +def join_memory_namespace(*parts: Any) -> str: + """Join namespace segments while preserving the root namespace as empty.""" + + normalized_parts: list[str] = [] + for part in parts: + normalized = normalize_memory_namespace(part) + if not normalized: + continue + normalized_parts.extend( + segment for segment in normalized.split("/") if segment.strip() + ) + return "/".join(normalized_parts) + + +def memory_namespace_matches( + candidate: str, + namespace: str | None, + *, + include_descendants: bool, +) -> bool: + """Check whether a stored namespace belongs to the requested scope.""" + + if namespace is None: + return True + normalized_candidate = normalize_memory_namespace(candidate) + normalized_namespace = normalize_memory_namespace(namespace) + if not normalized_namespace: + return include_descendants or normalized_candidate == "" + if normalized_candidate == normalized_namespace: + return True + return include_descendants and normalized_candidate.startswith( + f"{normalized_namespace}/" + ) + + +def display_memory_namespace(value: Any) -> str | None: + """Return a user-facing namespace value.""" + + normalized = normalize_memory_namespace(value) + return normalized or None + + +def _memory_query_terms(value: str) -> list[str]: + normalized = re.sub(r"\s+", " ", str(value).strip().casefold()) + if not normalized: + return [] + terms = [item for item in re.findall(r"\w+", normalized, flags=re.UNICODE) if item] + if terms: + return terms + compact = normalized.replace(" ", "") + return [compact] if compact else [] + + +def memory_keyword_score(query: str, 
key: str, text: str) -> float: + """Score a keyword hit the same way across runtime and core bridge.""" + + normalized_query = str(query).casefold() + if not normalized_query: + return 1.0 + normalized_key = str(key).casefold() + normalized_text = str(text).casefold() + best = 0.0 + if normalized_query in normalized_key: + best = 1.0 + if normalized_query in normalized_text: + best = max(best, 0.92) + + terms = _memory_query_terms(normalized_query) + if not terms: + return best + + key_hits = sum(1 for term in terms if term in normalized_key) + text_hits = sum(1 for term in terms if term in normalized_text) + if key_hits: + best = max(best, 0.5 + 0.5 * (key_hits / len(terms))) + if text_hits: + best = max(best, 0.35 + 0.55 * (text_hits / len(terms))) + return min(best, 1.0) + + +def cosine_similarity(left: list[float], right: list[float]) -> float: + """Compute cosine similarity defensively for embedding vectors.""" + + if not left or not right or len(left) != len(right): + return 0.0 + left_norm = math.sqrt(sum(value * value for value in left)) + right_norm = math.sqrt(sum(value * value for value in right)) + if left_norm <= 0 or right_norm <= 0: + return 0.0 + return sum(a * b for a, b in zip(left, right, strict=False)) / ( + left_norm * right_norm + ) + + +def normalize_embedding(vector: list[float]) -> list[float]: + """Normalize an embedding for cosine/inner-product search.""" + + if not vector: + return [] + norm = math.sqrt(sum(value * value for value in vector)) + if norm <= 0: + return [0.0 for _ in vector] + return [float(value) / norm for value in vector] + + +def memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]: + """Normalize cached sidecar data into a stable memory index record.""" + + if isinstance(entry, dict): + return { + "text": str(entry.get("text", text)), + "embedding": ( + [float(item) for item in entry.get("embedding", [])] + if isinstance(entry.get("embedding"), list) + else None + ), + "provider_id": ( + 
str(entry.get("provider_id")).strip() + if entry.get("provider_id") is not None + else None + ), + } + return {"text": text, "embedding": None, "provider_id": None} diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py new file mode 100644 index 0000000000..00564e7868 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_ids.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import re +from pathlib import Path + +PLUGIN_ID_PATTERN = re.compile(r"^[A-Za-z0-9_](?:[A-Za-z0-9._-]{0,126}[A-Za-z0-9_])?$") +_WINDOWS_RESERVED_PLUGIN_IDS = { + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", +} + + +def validate_plugin_id(plugin_id: str) -> str: + normalized = str(plugin_id).strip() + if not normalized: + raise ValueError("plugin_id must not be empty") + if not PLUGIN_ID_PATTERN.fullmatch(normalized): + raise ValueError( + "plugin_id must use only letters, digits, dots, underscores, or hyphens" + ) + upper_normalized = normalized.upper() + base_name = upper_normalized.split(".", 1)[0] + if ( + upper_normalized in _WINDOWS_RESERVED_PLUGIN_IDS + or base_name in _WINDOWS_RESERVED_PLUGIN_IDS + ): + raise ValueError("plugin_id must not use a reserved Windows device name") + return normalized + + +def resolve_plugin_data_dir(root: Path, plugin_id: str) -> Path: + normalized = validate_plugin_id(plugin_id) + resolved_root = root.resolve() + candidate = (resolved_root / normalized).resolve() + try: + candidate.relative_to(resolved_root) + except ValueError as exc: + raise ValueError("plugin_id escapes the plugin data root") from exc + return candidate diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_internal/plugin_logger.py new file mode 100644 index 0000000000..b89fb8dc18 --- /dev/null 
from __future__ import annotations

import asyncio
import inspect
import os
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

try:
    from astrbot.core.config.default import VERSION as _ASTRBOT_VERSION
except Exception:  # noqa: BLE001 - the version tag is cosmetic; never fail import
    _ASTRBOT_VERSION = ""

__all__ = ["PluginLogEntry", "PluginLogger"]


@dataclass(slots=True)
class PluginLogEntry:
    """One structured log record emitted through a plugin logger."""

    level: str
    time: float
    message: str
    plugin_id: str
    context: dict[str, Any] = field(default_factory=dict)


class _PluginLogBroker:
    """Fan-out hub delivering a plugin's log entries to async watchers."""

    def __init__(self, plugin_id: str) -> None:
        self.plugin_id = plugin_id
        self._subscribers: set[asyncio.Queue[PluginLogEntry]] = set()

    def publish(self, entry: PluginLogEntry) -> None:
        """Push *entry* to every subscriber, skipping any full queue."""
        # Iterate a snapshot so a watcher may detach during delivery.
        for subscriber in list(self._subscribers):
            try:
                subscriber.put_nowait(entry)
            except asyncio.QueueFull:
                continue

    async def watch(self) -> AsyncIterator[PluginLogEntry]:
        """Yield entries forever; unregisters the queue when the consumer stops."""
        inbox: asyncio.Queue[PluginLogEntry] = asyncio.Queue()
        self._subscribers.add(inbox)
        try:
            while True:
                yield await inbox.get()
        finally:
            self._subscribers.discard(inbox)


# One broker per plugin id, shared by all PluginLogger instances.
_BROKERS: dict[str, _PluginLogBroker] = {}

_SHORT_LEVEL_NAMES = {
    "DEBUG": "DBUG",
    "INFO": "INFO",
    "WARNING": "WARN",
    "ERROR": "ERRO",
    "CRITICAL": "CRIT",
}

_ANSI_RESET = "\u001b[0m"
_ANSI_GREEN = "\u001b[32m"
_ANSI_LEVEL_COLORS = {
    "DEBUG": "\u001b[1;34m",
    "INFO": "\u001b[1;36m",
    "WARNING": "\u001b[1;33m",
    "ERROR": "\u001b[31m",
    "CRITICAL": "\u001b[1;31m",
}


def _get_short_level_name(level_name: str) -> str:
    """Map a level name to its fixed-width 4-character console tag."""
    key = level_name.upper()
    fallback = level_name[:4].upper()
    return _SHORT_LEVEL_NAMES.get(key, fallback)


def _build_source_file(pathname: str | None) -> str:
    """Render ``parentdir.module`` from a source path for console prefixes."""
    if not pathname:
        return "unknown"
    parent = os.path.basename(os.path.dirname(pathname))
    module = os.path.basename(pathname).replace(".py", "")
    return f"{parent}.{module}"
def _plugin_tag_from_path(pathname: str | None) -> str:
    """Classify a source path as plugin code ("[Plug]") or core code ("[Core]").

    Unknown/missing paths default to "[Plug]".
    """
    if not pathname:
        return "[Plug]"
    norm_path = os.path.normpath(pathname)
    # Paths under the plugin/star directories are tagged as plugin output.
    if any(
        marker in norm_path
        for marker in (
            os.path.normpath("data/plugins"),
            os.path.normpath("data/sdk_plugins"),
            os.path.normpath("astrbot/builtin_stars"),
        )
    ):
        return "[Plug]"
    return "[Core]"


def _level_color(level: str) -> str:
    """Return the ANSI color sequence for *level* (reset for unknown levels)."""
    return _ANSI_LEVEL_COLORS.get(level.upper(), _ANSI_RESET)


def _get_broker(plugin_id: str) -> _PluginLogBroker:
    """Return the per-plugin broker, creating and caching it on first use."""
    broker = _BROKERS.get(plugin_id)
    if broker is None:
        broker = _PluginLogBroker(plugin_id)
        _BROKERS[plugin_id] = broker
    return broker


class PluginLogger:
    """Plugin-facing logger that mirrors records to an async broker.

    Wraps an arbitrary underlying logger (duck-typed against a loguru-style
    ``opt``/``bind`` API — TODO confirm the intended backend) and, for every
    record, both emits a formatted console line and publishes a
    ``PluginLogEntry`` so dashboards can ``watch()`` live logs.
    """

    def __init__(
        self,
        *,
        plugin_id: str,
        logger: Any,
        bound_context: dict[str, Any] | None = None,
    ) -> None:
        self._plugin_id = plugin_id
        self._logger = logger
        self._broker = _get_broker(plugin_id)
        self._bound_context = dict(bound_context or {})

    @property
    def plugin_id(self) -> str:
        """Identifier of the plugin this logger belongs to."""
        return self._plugin_id

    def bind(self, **kwargs: Any) -> PluginLogger:
        """Return a new logger with *kwargs* merged into the bound context.

        Delegates to the wrapped logger's ``bind`` when available; any failure
        there is swallowed and only the SDK-side context is extended.
        """
        bind = getattr(self._logger, "bind", None)
        next_logger = self._logger
        if callable(bind):
            try:
                next_logger = bind(**kwargs)
            except Exception:
                next_logger = self._logger
        return PluginLogger(
            plugin_id=self._plugin_id,
            logger=next_logger,
            bound_context={**self._bound_context, **kwargs},
        )

    def opt(self, *args: Any, **kwargs: Any) -> PluginLogger:
        """Return a new logger wrapping ``logger.opt(...)`` when supported.

        The bound context is carried over unchanged; failures fall back to
        the current underlying logger.
        """
        opt = getattr(self._logger, "opt", None)
        next_logger = self._logger
        if callable(opt):
            try:
                next_logger = opt(*args, **kwargs)
            except Exception:
                next_logger = self._logger
        return PluginLogger(
            plugin_id=self._plugin_id,
            logger=next_logger,
            bound_context=self._bound_context,
        )

    async def watch(self) -> AsyncIterator[PluginLogEntry]:
        """Yield this plugin's published log entries as they arrive."""
        async for entry in self._broker.watch():
            yield entry

    def log(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None:
        """Emit at an arbitrary *level* (normalized to upper case)."""
        normalized_level = str(level).upper()
        self._emit_console(normalized_level, message, *args, **kwargs)
        self._publish(normalized_level, message, *args, **kwargs)

    def debug(self, message: Any, *args: Any, **kwargs: Any) -> None:
        self._emit_console("DEBUG", message, *args, **kwargs)
        self._publish("DEBUG", message, *args, **kwargs)

    def info(self, message: Any, *args: Any, **kwargs: Any) -> None:
        self._emit_console("INFO", message, *args, **kwargs)
        self._publish("INFO", message, *args, **kwargs)

    def warning(self, message: Any, *args: Any, **kwargs: Any) -> None:
        self._emit_console("WARNING", message, *args, **kwargs)
        self._publish("WARNING", message, *args, **kwargs)

    def error(self, message: Any, *args: Any, **kwargs: Any) -> None:
        self._emit_console("ERROR", message, *args, **kwargs)
        self._publish("ERROR", message, *args, **kwargs)

    def exception(self, message: Any, *args: Any, **kwargs: Any) -> None:
        """Log at ERROR and ask the console path to attach the active exception."""
        self._emit_console("ERROR", message, *args, exception=True, **kwargs)
        self._publish("ERROR", message, *args, **kwargs)

    def _emit_console(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> None:
        """Write to the console via ``opt(raw=...)``; fall back to plain methods."""
        if self._emit_console_with_opt(
            level,
            message,
            *args,
            exception=exception,
            **kwargs,
        ):
            return
        self._emit_console_fallback(
            level,
            message,
            *args,
            exception=exception,
            **kwargs,
        )

    def _emit_console_with_opt(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> bool:
        """Emit a fully pre-rendered ANSI line through ``logger.opt(raw=True)``.

        Returns:
            bool: True when the line was handed off; False when the wrapped
            logger lacks a usable ``opt``/``log`` or emission raised.
        """
        opt = getattr(self._logger, "opt", None)
        if not callable(opt):
            return False
        formatted_message = self._format_message(message, *args, **kwargs)
        pathname, source_line = self._caller_info()
        plugin_tag = _plugin_tag_from_path(pathname)
        source_file = _build_source_file(pathname)
        # Version tag is only shown for warning-or-worse records.
        version_tag = (
            f" [v{_ASTRBOT_VERSION}]"
            if _ASTRBOT_VERSION and level in {"WARNING", "ERROR", "CRITICAL"}
            else ""
        )
        # Millisecond precision: strip the last three microsecond digits.
        timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
        level_text = _get_short_level_name(level)
        level_color = _level_color(level)
        line = (
            f"{_ANSI_GREEN}[{timestamp}]{_ANSI_RESET} {plugin_tag} "
            f"{level_color}[{level_text}]{_ANSI_RESET}{version_tag} "
            f"[{source_file}:{source_line}]: {level_color}{formatted_message}{_ANSI_RESET}"
        )
        try:
            # exception=True asks a loguru-style logger to append the active
            # traceback; raw=True suppresses its own formatting.
            emitter = opt(raw=True, exception=True) if exception else opt(raw=True)
            log = getattr(emitter, "log", None)
            if not callable(log):
                return False
            log(level, line + "\n")
            return True
        except Exception:
            return False

    def _emit_console_fallback(
        self,
        level: str,
        message: Any,
        *args: Any,
        exception: bool = False,
        **kwargs: Any,
    ) -> None:
        """Best-effort emission via conventional logger methods.

        Tries ``exception`` (when requested), then the level-named method,
        then ``error``; as a last resort uses a generic ``log(level, msg)``.
        Errors are swallowed — console output is never allowed to crash the
        caller.
        """
        method_names = []
        if exception:
            method_names.append("exception")
        method_names.append(str(level).lower())
        if exception:
            method_names.append("error")
        for method_name in method_names:
            method = getattr(self._logger, method_name, None)
            if not callable(method):
                continue
            try:
                method(message, *args, **kwargs)
            except Exception:
                continue
            return
        log = getattr(self._logger, "log", None)
        if callable(log):
            try:
                log(level, self._format_message(message, *args, **kwargs))
            except Exception:
                return

    def _caller_info(self) -> tuple[str | None, int]:
        """Walk the stack to the first frame outside this module.

        Returns:
            (filename, lineno) of the real call site, or (None, 0) when the
            stack cannot be inspected.
        """
        frame = inspect.currentframe()
        if frame is None:
            return None, 0
        frame = frame.f_back
        # Skip internal wrapper frames belonging to this module.
        while frame is not None and frame.f_globals.get("__name__") == __name__:
            frame = frame.f_back
        if frame is None:
            return None, 0
        return str(frame.f_code.co_filename), int(frame.f_lineno)

    def _publish(self, level: str, message: Any, *args: Any, **kwargs: Any) -> None:
        """Publish a structured entry (with bound context) to the broker."""
        entry = PluginLogEntry(
            level=level,
            time=time.time(),
            message=self._format_message(message, *args, **kwargs),
            plugin_id=self._plugin_id,
            context=dict(self._bound_context),
        )
        self._broker.publish(entry)

    @staticmethod
    def _format_message(message: Any, *args: Any, **kwargs: Any) -> str:
        """Render *message* with ``str.format``; fall back to the raw text.

        Non-str messages are stringified; a failing format (mismatched
        placeholders) returns the unformatted text rather than raising.
        """
        if not isinstance(message, str):
            return str(message)
        text = message
        if not args and not kwargs:
            return text
        try:
            return text.format(*args, **kwargs)
        except Exception:
            return text

    def __getattr__(self, name: str) -> Any:
        # Anything not wrapped here is delegated to the underlying logger.
        return getattr(self._logger, name)


from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..context import Context
    from ..star import Star


# Context-local view of the star/context currently executing, so deeply
# nested SDK code can recover them without threading parameters through.
_CURRENT_STAR_CONTEXT: ContextVar[Context | None] = ContextVar(
    "astrbot_sdk_current_star_context",
    default=None,
)
_CURRENT_STAR_INSTANCE: ContextVar[Star | None] = ContextVar(
    "astrbot_sdk_current_star_instance",
    default=None,
)


def current_star_context() -> Context | None:
    """Return the Context bound to the current execution, if any."""
    return _CURRENT_STAR_CONTEXT.get()


def current_runtime_context() -> Context | None:
    """Alias of :func:`current_star_context` kept for API symmetry."""
    return _CURRENT_STAR_CONTEXT.get()


def current_star_instance() -> Star | None:
    """Return the Star instance bound to the current execution, if any."""
    return _CURRENT_STAR_INSTANCE.get()


@contextmanager
def bind_star_runtime(star: Star | None, ctx: Context | None) -> Iterator[None]:
    """Bind *star*/*ctx* as the current runtime for the enclosed block.

    Also notifies the star itself via ``_bind_runtime_context`` when a star
    is given; all three bindings are reset in reverse order on exit.
    """
    context_token = _CURRENT_STAR_CONTEXT.set(ctx)
    star_token = _CURRENT_STAR_INSTANCE.set(star)
    instance_token = star._bind_runtime_context(ctx) if star is not None else None
    try:
        yield
    finally:
        if star is not None and instance_token is not None:
            star._reset_runtime_context(instance_token)
        _CURRENT_STAR_INSTANCE.reset(star_token)
        _CURRENT_STAR_CONTEXT.reset(context_token)
"""Shared support primitives for local SDK testing."""

from __future__ import annotations

import asyncio
import typing
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, TextIO

from ..context import CancelToken, Context as RuntimeContext
from ..events import MessageEvent
from ..protocol.messages import EventMessage, PeerInfo
from ..runtime._streaming import StreamExecution
from ..runtime.capability_router import CapabilityRouter


def _clone_payload_mapping(value: Any) -> dict[str, Any] | None:
    """Shallow-copy a dict payload with stringified keys; None for non-dicts."""
    if isinstance(value, dict):
        return {str(key): item for key, item in value.items()}
    return None


@dataclass(slots=True)
class RecordedSend:
    """A platform send captured by the test sink, normalized for assertions."""

    kind: str
    message_id: str
    session_id: str
    text: str | None = None
    image_url: str | None = None
    chain: list[dict[str, Any]] | None = None
    target: dict[str, Any] | None = None
    raw: dict[str, Any] = field(default_factory=dict)

    @property
    def session(self) -> str:
        """Alias for :attr:`session_id`."""
        return self.session_id

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> RecordedSend:
        """Build a record from a raw platform send payload.

        The kind is derived from the first matching marker key; anything
        without a recognized marker is recorded as ``"unknown"``.
        """
        for marker, kind in (("text", "text"), ("image_url", "image"), ("chain", "chain")):
            if marker in payload:
                break
        else:
            kind = "unknown"
        text_value = payload.get("text")
        image_value = payload.get("image_url")
        chain_value = payload.get("chain")
        return cls(
            kind=kind,
            message_id=str(payload.get("message_id", "")),
            session_id=str(payload.get("session", "")),
            text=text_value if isinstance(text_value, str) else None,
            image_url=image_value if isinstance(image_value, str) else None,
            chain=(
                [dict(item) for item in chain_value]
                if isinstance(chain_value, list)
                else None
            ),
            target=_clone_payload_mapping(payload.get("target")),
            raw=dict(payload),
        )
class StdoutPlatformSink:
    """Collects RecordedSend items, optionally mirroring them to a text stream."""

    def __init__(self, stream: TextIO | None = None) -> None:
        self._stream = stream
        self.records: list[RecordedSend] = []

    def record(self, item: RecordedSend) -> None:
        """Store *item*; echo a one-line summary when a stream is attached."""
        self.records.append(item)
        if self._stream is not None:
            self._stream.write(self._format(item) + "\n")
            self._stream.flush()

    def clear(self) -> None:
        """Forget every captured record."""
        self.records.clear()

    def _format(self, item: RecordedSend) -> str:
        """Render a compact single-line summary of *item* by kind."""
        session = item.session_id
        if item.kind == "text":
            return f"[text][{session}] {item.text or ''}"
        if item.kind == "image":
            return f"[image][{session}] {item.image_url or ''}"
        if item.kind == "chain":
            return f"[chain][{session}] {len(item.chain or [])} components"
        return f"[send][{session}] {item.raw}"


class InMemoryDB:
    """Dict-backed stand-in for the plugin key-value store."""

    def __init__(self, store: dict[str, Any]) -> None:
        self._store = store

    def get(self, key: str, default: Any = None) -> Any:
        return self._store.get(key, default)

    def set(self, key: str, value: Any) -> None:
        self._store[key] = value

    def delete(self, key: str) -> None:
        self._store.pop(key, None)

    def list(self, prefix: str | None = None) -> list[str]:
        """Return sorted keys, optionally filtered to a prefix."""
        ordered = sorted(self._store)
        if prefix is None:
            return ordered
        return [name for name in ordered if name.startswith(prefix)]

    def get_many(self, keys: list[str]) -> list[dict[str, Any]]:
        """Return one ``{"key", "value"}`` dict per requested key (missing → None)."""
        return [{"key": name, "value": self._store.get(name)} for name in keys]

    def set_many(self, items: list[dict[str, Any]]) -> None:
        """Apply a batch of ``{"key", "value"}`` writes."""
        for entry in items:
            self.set(str(entry.get("key", "")), entry.get("value"))
class InMemoryMemory:
    """Dict-backed stand-in for the plugin memory store with TTL support.

    ``store`` holds the raw values. ``expires_at`` is shared with the owner
    (the mock router's TTL save path writes into it) and maps keys to
    absolute UTC expiry times; expired entries are lazily purged on access.
    """

    def __init__(
        self,
        store: dict[str, dict[str, Any]],
        *,
        expires_at: dict[str, datetime | None] | None = None,
    ) -> None:
        self._store = store
        self._expires_at = expires_at if expires_at is not None else {}

    @staticmethod
    def _is_ttl_entry(value: Any) -> bool:
        """Return True when *value* uses the TTL wrapper shape.

        A TTL wrapper is a dict carrying both ``value`` and ``ttl_seconds``.
        """
        return isinstance(value, dict) and "value" in value and "ttl_seconds" in value

    @classmethod
    def _search_text(cls, value: Any) -> str:
        """Extract the text used by the local ``memory.search`` matcher.

        TTL wrappers are unwrapped first; the first non-blank string among
        the well-known fields wins, otherwise the stringified value is used.
        """
        if cls._is_ttl_entry(value):
            value = value.get("value")
        if not isinstance(value, dict):
            return ""
        for field_name in ("embedding_text", "content", "summary", "title", "text"):
            item = value.get(field_name)
            if isinstance(item, str) and item.strip():
                return item.strip()
        return str(value)

    def _is_expired(self, key: str) -> bool:
        """Return True when *key* has an expiry time that has passed."""
        expires_at = self._expires_at.get(key)
        return expires_at is not None and expires_at <= datetime.now(timezone.utc)

    def _purge_if_expired(self, key: str) -> bool:
        """Drop *key* when expired; return True if it was removed."""
        if not self._is_expired(key):
            return False
        self._store.pop(key, None)
        self._expires_at.pop(key, None)
        return True

    def get(self, key: str, default: Any = None) -> Any:
        if self._purge_if_expired(key):
            return default
        return self._store.get(key, default)

    def save(self, key: str, value: dict[str, Any]) -> None:
        """Store a copy of *value* under *key* with no expiry.

        Bug fix: a plain save replaces any earlier TTL-bound entry, so the
        stale expiry must be cleared too — otherwise the fresh value would
        silently vanish once the old deadline passed (``delete`` already
        clears the expiry; ``save`` must match).
        """
        self._store[key] = dict(value)
        self._expires_at.pop(key, None)

    def delete(self, key: str) -> None:
        self._store.pop(key, None)
        self._expires_at.pop(key, None)

    def search(self, query: str) -> list[dict[str, Any]]:
        """Substring search over keys and extracted text, skipping expired rows."""
        results: list[dict[str, Any]] = []
        for key, value in list(self._store.items()):
            if self._purge_if_expired(key):
                continue
            if query in key or query in self._search_text(value):
                results.append({"key": key, "value": value})
        return results
class MockLLMClient:
    """Wraps the real LLM client, adding response-queue helpers for tests."""

    def __init__(self, client: Any, router: MockCapabilityRouter) -> None:
        self._client = client
        self._router = router

    def mock_response(self, text: str) -> None:
        """Queue *text* as the next non-streaming LLM reply."""
        self._router.enqueue_llm_response(text)

    def mock_stream_response(self, text: str) -> None:
        """Queue *text* as the next streaming LLM reply."""
        self._router.enqueue_llm_stream_response(text)

    def clear_mock_responses(self) -> None:
        """Drop every queued mock reply."""
        self._router.clear_llm_responses()

    def __getattr__(self, name: str) -> Any:
        # Everything else is delegated to the wrapped client.
        return getattr(self._client, name)


class MockPlatformClient:
    """Wraps the real platform client, exposing sent-record assertions."""

    def __init__(self, client: Any, sink: StdoutPlatformSink) -> None:
        self._client = client
        self._sink = sink

    @property
    def records(self) -> list[RecordedSend]:
        """Snapshot of every send captured by the sink."""
        return list(self._sink.records)

    def assert_sent(
        self,
        expected_text: str | None = None,
        *,
        kind: str = "text",
        count: int | None = None,
    ) -> None:
        """Assert matching records were sent.

        Filters by *kind* and (when given) exact *expected_text*. With
        *count*, the number of matches must equal it exactly; otherwise at
        least one match is required.

        Raises:
            AssertionError: When the expectation is not met.
        """
        hits = [rec for rec in self._sink.records if rec.kind == kind]
        if expected_text is not None:
            hits = [rec for rec in hits if rec.text == expected_text]
        if count is not None:
            if len(hits) == count:
                return
            raise AssertionError(
                f"expected {count} sent records, got {len(hits)}: {hits}"
            )
        if hits:
            return
        raise AssertionError(
            f"expected sent record kind={kind!r} text={expected_text!r}, got {self._sink.records}"
        )

    def __getattr__(self, name: str) -> Any:
        # Everything else is delegated to the wrapped client.
        return getattr(self._client, name)
class MockCapabilityRouter(CapabilityRouter):
    """In-process CapabilityRouter with canned LLM replies and a send sink.

    LLM capabilities are short-circuited locally; every other capability is
    delegated to the base router, after which newly sent platform messages
    are mirrored into ``platform_sink``. ``db_store`` / ``memory_store`` /
    ``_memory_expires_at`` / ``sent_messages`` / ``event_actions`` are
    presumably provided by ``CapabilityRouter`` — TODO confirm against the
    base class.
    """

    def __init__(self, *, platform_sink: StdoutPlatformSink | None = None) -> None:
        self.platform_sink = platform_sink or StdoutPlatformSink()
        self._llm_responses: list[str] = []
        self._llm_stream_responses: list[str] = []
        super().__init__()
        # Expose convenient typed wrappers over the base router's stores.
        self.db = InMemoryDB(self.db_store)
        self.memory = InMemoryMemory(
            self.memory_store,
            expires_at=self._memory_expires_at,
        )

    # The following overrides only re-export base-router behavior so that
    # tests have a stable, documented surface on the mock itself.
    def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
        return super().list_dynamic_command_routes(plugin_id)

    def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None:
        super().remove_dynamic_command_routes_for_plugin(plugin_id)

    def emit_provider_change(
        self,
        provider_id: str,
        provider_type: str,
        umo: str | None = None,
    ) -> None:
        super().emit_provider_change(provider_id, provider_type, umo)

    def record_platform_error(
        self,
        platform_id: str,
        message: str,
        *,
        traceback: str | None = None,
    ) -> None:
        super().record_platform_error(platform_id, message, traceback=traceback)

    def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None:
        super().set_platform_stats(platform_id, stats)

    def enqueue_llm_response(self, text: str) -> None:
        """Queue *text* as the next canned non-streaming LLM reply."""
        self._llm_responses.append(text)

    def enqueue_llm_stream_response(self, text: str) -> None:
        """Queue *text* as the next canned streaming LLM reply."""
        self._llm_stream_responses.append(text)

    def clear_llm_responses(self) -> None:
        """Drop all queued canned replies (both modes)."""
        self._llm_responses.clear()
        self._llm_stream_responses.clear()

    async def execute(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool,
        cancel_token,
        request_id: str,
    ) -> dict[str, Any] | StreamExecution:
        """Execute a capability, short-circuiting the LLM family locally.

        ``llm.chat`` and ``llm.chat_raw`` return canned (or echoed) replies;
        ``llm.stream_chat`` yields the reply character by character.
        Everything else is delegated to the base router, then any platform
        sends produced by the call are flushed to the sink.
        """
        if capability == "llm.chat":
            return {"text": self._take_llm_response(str(payload.get("prompt", "")))}
        if capability == "llm.chat_raw":
            text = self._take_llm_response(str(payload.get("prompt", "")))
            # Token counts are approximated by character length for tests.
            return {
                "text": text,
                "usage": {
                    "input_tokens": len(str(payload.get("prompt", ""))),
                    "output_tokens": len(text),
                },
                "finish_reason": "stop",
                "tool_calls": [],
                "role": "assistant",
                "reasoning_content": None,
                "reasoning_signature": None,
            }
        if capability == "llm.stream_chat":
            text = self._take_llm_stream_response(str(payload.get("prompt", "")))

            async def iterator() -> typing.AsyncIterator[dict[str, Any]]:
                # One chunk per character; cooperative sleep keeps the loop
                # responsive and honors cancellation between chunks.
                for char in text:
                    cancel_token.raise_if_cancelled()
                    await asyncio.sleep(0)
                    yield {"text": char}

            return StreamExecution(
                iterator=iterator(),
                finalize=lambda chunks: {
                    "text": "".join(item.get("text", "") for item in chunks)
                },
            )
        # Remember how many sends existed before the call so only new ones
        # are mirrored to the sink afterwards.
        before = len(self.sent_messages)
        result = await super().execute(
            capability,
            payload,
            stream=stream,
            cancel_token=cancel_token,
            request_id=request_id,
        )
        self._flush_platform_records(before)
        return result

    def _flush_platform_records(self, start_index: int) -> None:
        """Mirror sends appended at or after *start_index* into the sink."""
        for payload in self.sent_messages[start_index:]:
            self.platform_sink.record(RecordedSend.from_payload(payload))

    def _take_llm_response(self, prompt: str) -> str:
        """Pop the next canned reply, or echo the prompt when none is queued."""
        if self._llm_responses:
            return self._llm_responses.pop(0)
        return f"Echo: {prompt}"

    def _take_llm_stream_response(self, prompt: str) -> str:
        """Pop the next canned stream reply, falling back to the non-stream
        queue, then to echoing the prompt."""
        if self._llm_stream_responses:
            return self._llm_stream_responses.pop(0)
        if self._llm_responses:
            return self._llm_responses.pop(0)
        return f"Echo: {prompt}"


class MockPeer:
    """Local stand-in for a remote Peer, executing against a mock router.

    Mimics the remote-metadata surface (``remote_peer``, capability lists)
    so SDK code that introspects the peer works unchanged in tests.
    """

    def __init__(self, router: MockCapabilityRouter) -> None:
        self._router = router
        self._counter = 0
        self.remote_peer = PeerInfo(
            name="astrbot-local-core",
            role="core",
            version="local",
        )
        self.remote_capabilities = list(router.descriptors())
        self.remote_capability_map = {
            item.name: item for item in self.remote_capabilities
        }
        self.remote_handlers: list[Any] = []
        self.remote_provided_capabilities: list[Any] = []
        self.remote_metadata = {"mode": "local"}

    async def invoke(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool = False,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Invoke a capability synchronously (non-streaming only).

        Raises:
            ValueError: When ``stream=True`` is requested here instead of
                :meth:`invoke_stream`.
        """
        if stream:
            raise ValueError("stream=True 请使用 invoke_stream()")
        return typing.cast(
            dict[str, Any],
            await self._router.execute(
                capability,
                payload,
                stream=False,
                cancel_token=CancelToken(),
                request_id=request_id or self._next_id(),
            ),
        )

    async def invoke_stream(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        request_id: str | None = None,
        include_completed: bool = False,
    ):
        """Invoke a streaming capability, yielding EventMessage phases.

        Emits ``started``, one ``delta`` per chunk, and — only when
        *include_completed* is set — a final ``completed`` carrying the
        finalized output (matching the real Peer's default of hiding it).
        """
        request_id = request_id or self._next_id()
        execution = typing.cast(
            StreamExecution,
            await self._router.execute(
                capability,
                payload,
                stream=True,
                cancel_token=CancelToken(),
                request_id=request_id,
            ),
        )

        async def iterator():
            yield EventMessage.model_validate({"id": request_id, "phase": "started"})
            chunks: list[dict[str, Any]] = []
            async for chunk in execution.iterator:
                # Chunks are retained for finalize() only when the execution
                # asks for it.
                if execution.collect_chunks:
                    chunks.append(chunk)
                yield EventMessage.model_validate(
                    {"id": request_id, "phase": "delta", "data": chunk}
                )
            output = execution.finalize(chunks)
            if include_completed:
                yield EventMessage.model_validate(
                    {"id": request_id, "phase": "completed", "output": output}
                )

        return iterator()

    def _next_id(self) -> str:
        """Return a monotonically increasing local request id."""
        self._counter += 1
        return f"local_{self._counter:04d}"
def _normalize_plugin_metadata(
    plugin_id: str,
    plugin_metadata: Mapping[str, Any] | None,
) -> dict[str, Any]:
    """Normalize user-supplied plugin metadata into the router's full shape.

    Fills defaults for every field, coerces types, and validates that a
    declared ``name`` matches *plugin_id*.

    Raises:
        ValueError: When ``plugin_metadata["name"]`` disagrees with
            *plugin_id*.
    """
    if plugin_metadata is None:
        plugin_metadata = {}
    declared_name = plugin_metadata.get("name")
    if declared_name is not None and str(declared_name) != plugin_id:
        raise ValueError(
            "MockContext.plugin_metadata['name'] 必须与 plugin_id 一致,"
            f"当前收到 {declared_name!r} != {plugin_id!r}"
        )
    # "desc" is accepted as a legacy alias for "description".
    description = plugin_metadata.get("description")
    if description is None:
        description = plugin_metadata.get("desc", "")
    return {
        "name": plugin_id,
        "display_name": str(plugin_metadata.get("display_name") or plugin_id),
        "description": str(description or ""),
        "author": str(plugin_metadata.get("author") or ""),
        "version": str(plugin_metadata.get("version") or "0.0.0"),
        "enabled": bool(plugin_metadata.get("enabled", True)),
        "reserved": bool(plugin_metadata.get("reserved", False)),
        "acknowledge_global_mcp_risk": bool(
            plugin_metadata.get("acknowledge_global_mcp_risk", False)
        ),
        # Keep only dict-valued servers with non-blank names.
        "local_mcp_servers": (
            {
                str(server_name): dict(server_payload)
                for server_name, server_payload in plugin_metadata.get(
                    "local_mcp_servers",
                    {},
                ).items()
                if str(server_name).strip() and isinstance(server_payload, dict)
            }
            if isinstance(plugin_metadata.get("local_mcp_servers"), dict)
            else {}
        ),
        # Non-string entries are silently dropped; a non-list becomes [].
        "support_platforms": [
            str(item)
            for item in plugin_metadata.get("support_platforms", [])
            if isinstance(item, str)
        ]
        if isinstance(plugin_metadata.get("support_platforms"), list)
        else [],
        "astrbot_version": (
            str(plugin_metadata.get("astrbot_version"))
            if plugin_metadata.get("astrbot_version") is not None
            else None
        ),
    }


class MockContext(RuntimeContext):
    """RuntimeContext wired to an in-process mock router and peer.

    Replaces the LLM and platform clients with mock wrappers so tests can
    queue canned replies and assert on sent messages.
    """

    def __init__(
        self,
        *,
        plugin_id: str = "test-plugin",
        logger: Any | None = None,
        cancel_token: CancelToken | None = None,
        platform_sink: StdoutPlatformSink | None = None,
        plugin_metadata: Mapping[str, Any] | None = None,
    ) -> None:
        self.platform_sink = platform_sink or StdoutPlatformSink()
        self.router = MockCapabilityRouter(platform_sink=self.platform_sink)
        self.mock_peer = MockPeer(self.router)
        super().__init__(
            peer=self.mock_peer,
            plugin_id=plugin_id,
            cancel_token=cancel_token,
            logger=logger,
        )
        # Register the plugin with normalized metadata and an empty config.
        self.router.upsert_plugin(
            metadata=_normalize_plugin_metadata(plugin_id, plugin_metadata),
            config={},
        )
        # Swap in the mock-aware client wrappers after the base init.
        self.llm = MockLLMClient(self.llm, self.router)
        self.platform = MockPlatformClient(self.platform, self.platform_sink)

    @property
    def sent_messages(self) -> list[RecordedSend]:
        """Every platform send captured so far."""
        return list(self.platform_sink.records)

    @property
    def event_actions(self) -> list[dict[str, Any]]:
        """Event actions recorded by the router (snapshot copy)."""
        return list(self.router.event_actions)


class MockMessageEvent(MessageEvent):
    """MessageEvent with a reply capture channel for assertions.

    With a context, replies go through the mock platform client (and thus
    the sink); without one, replies are only captured in :attr:`replies`.
    """

    def __init__(
        self,
        *,
        text: str = "",
        user_id: str | None = "test-user",
        group_id: str | None = None,
        platform: str | None = "test",
        session_id: str | None = "test-session",
        raw: dict[str, Any] | None = None,
        context: MockContext | None = None,
    ) -> None:
        # All replies (both paths) are appended here for assertions.
        self.replies: list[str] = []
        super().__init__(
            text=text,
            user_id=user_id,
            group_id=group_id,
            platform=platform,
            session_id=session_id,
            raw=raw,
            context=context,
        )
        if context is not None:
            self.bind_runtime_reply(context)
        elif self._reply_handler is None:
            # No context: only record the reply text locally.
            self.bind_reply_handler(self._capture_reply)

    @property
    def is_private(self) -> bool:
        """True when the event has no group (i.e. a direct message)."""
        return self.group_id is None

    def bind_runtime_reply(self, context: MockContext) -> None:
        """Route replies through *context*'s platform client and record them."""
        self._context = context

        async def reply(text: str) -> None:
            self.replies.append(text)
            # session_ref presumably set by the base event — TODO confirm.
            await context.platform.send(self.session_ref or self.session_id, text)

        self.bind_reply_handler(reply)

    async def _capture_reply(self, text: str) -> None:
        """Fallback reply handler: record only, no platform send."""
        self.replies.append(text)
bind_runtime_reply(self, context: MockContext) -> None: + self._context = context + + async def reply(text: str) -> None: + self.replies.append(text) + await context.platform.send(self.session_ref or self.session_id, text) + + self.bind_reply_handler(reply) + + async def _capture_reply(self, text: str) -> None: + self.replies.append(text) + + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py new file mode 100644 index 0000000000..7cac7421ba --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_internal/typing_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import typing +from types import UnionType +from typing import Any + + +def unwrap_optional(annotation: Any) -> tuple[Any, bool]: + origin = typing.get_origin(annotation) + if origin in {typing.Union, UnionType}: + args = [item for item in typing.get_args(annotation) if item is not type(None)] + if len(args) == 1: + return args[0], True + return annotation, False + + +__all__ = ["unwrap_optional"] diff --git a/astrbot-sdk/src/astrbot_sdk/_memory_backend.py b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py new file mode 100644 index 0000000000..50f94cbced --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_memory_backend.py @@ -0,0 +1,1515 @@ +from __future__ import annotations + +import asyncio +import json +import re +import sqlite3 +import threading +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, cast + +from ._internal.memory_utils import ( + cosine_similarity, + display_memory_namespace, + extract_memory_text, + join_memory_namespace, + memory_keyword_score, + memory_namespace_matches, + 
memory_value_for_search, + normalize_embedding, + normalize_memory_namespace, +) + + +def _utcnow() -> datetime: + # Centralize time access so expiry tests can advance time without mutating SQLite internals. + return datetime.now(timezone.utc) + + +def _sql_placeholders(count: int) -> str: + if count <= 0: + raise ValueError("count must be positive") + return ", ".join("?" for _ in range(count)) + + +def _normalize_scope_namespace(namespace: str | None) -> str | None: + if namespace is None: + return None + return normalize_memory_namespace(namespace) + + +def _escape_like_value(value: str) -> str: + return str(value).replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + +EmbedMany = Callable[[list[str]], Awaitable[list[list[float]]] | list[list[float]]] +EmbedOne = Callable[[str], Awaitable[list[float]] | list[float]] + + +@dataclass(slots=True) +class MemorySearchResult: + key: str + namespace: str + value: dict[str, Any] | None + score: float + match_type: str + + def to_payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "key": self.key, + "value": self.value, + "score": self.score, + "match_type": self.match_type, + } + namespace = display_memory_namespace(self.namespace) + if namespace is not None: + payload["namespace"] = namespace + return payload + + +@dataclass(slots=True) +class _StoredRecord: + namespace: str + key: str + stored: dict[str, Any] + search_text: str + updated_at: str + + +@dataclass(slots=True) +class _VectorCandidate: + namespace: str + key: str + stored: dict[str, Any] + search_text: str + score: float + + +class PluginMemoryBackend: + """Persistent plugin-scoped memory backend with namespace-aware search.""" + + def __init__(self, data_dir: Path) -> None: + self._base_dir = Path(data_dir) / "memory" + self._db_path = self._base_dir / "memory.sqlite3" + self._vector_dir = self._base_dir / "vectors" + self._lock = threading.RLock() + self._initialized = False + self._fts_enabled = False + self._vector_indexes: 
dict[str, Any | None] = {} + self._vector_fallbacks: dict[str, list[tuple[int, list[float]]]] = {} + + async def save( + self, + key: str, + value: dict[str, Any], + *, + namespace: str | None = None, + ) -> None: + await asyncio.to_thread( + self._save_sync, + str(key), + dict(value), + normalize_memory_namespace(namespace), + None, + ) + + async def save_with_ttl( + self, + key: str, + value: dict[str, Any], + ttl_seconds: int, + *, + namespace: str | None = None, + ) -> None: + expires_at = _utcnow().timestamp() + max(int(ttl_seconds), 0) + await asyncio.to_thread( + self._save_sync, + str(key), + dict(value), + normalize_memory_namespace(namespace), + { + "ttl_seconds": int(ttl_seconds), + "expires_at": datetime.fromtimestamp( + expires_at, + tz=timezone.utc, + ).isoformat(), + }, + ) + + async def get( + self, + key: str, + *, + namespace: str | None = None, + ) -> dict[str, Any] | None: + return await asyncio.to_thread( + self._get_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def list_keys( + self, + *, + namespace: str | None = None, + ) -> list[str]: + return await asyncio.to_thread( + self._list_keys_sync, + normalize_memory_namespace(namespace), + ) + + async def exists( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + return await asyncio.to_thread( + self._exists_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def get_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> list[dict[str, Any]]: + normalized_namespace = normalize_memory_namespace(namespace) + return await asyncio.to_thread( + self._get_many_sync, + [str(item) for item in keys], + normalized_namespace, + ) + + async def delete( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + return await asyncio.to_thread( + self._delete_sync, + str(key), + normalize_memory_namespace(namespace), + ) + + async def clear_namespace( + self, + *, + namespace: str | None = None, + 
    async def clear_namespace(
        self,
        *,
        namespace: str | None = None,
        include_descendants: bool = False,
    ) -> int:
        """Delete every record in *namespace*; returns the delete count."""
        normalized_namespace = _normalize_scope_namespace(namespace)
        return await asyncio.to_thread(
            self._clear_namespace_sync,
            normalized_namespace,
            bool(include_descendants),
        )

    async def delete_many(
        self,
        keys: list[str],
        *,
        namespace: str | None = None,
    ) -> int:
        """Delete multiple keys; returns how many records were removed."""
        normalized_namespace = normalize_memory_namespace(namespace)
        return await asyncio.to_thread(
            self._delete_many_sync,
            [str(item) for item in keys],
            normalized_namespace,
        )

    async def count(
        self,
        *,
        namespace: str | None = None,
        include_descendants: bool = False,
    ) -> int:
        """Count non-expired records in the namespace scope."""
        normalized_namespace = _normalize_scope_namespace(namespace)
        return await asyncio.to_thread(
            self._count_sync,
            normalized_namespace,
            bool(include_descendants),
        )

    async def stats(
        self,
        *,
        namespace: str | None = None,
        include_descendants: bool = True,
    ) -> dict[str, Any]:
        """Return storage/index statistics for the namespace scope."""
        normalized_namespace = _normalize_scope_namespace(namespace)
        return await asyncio.to_thread(
            self._stats_sync,
            normalized_namespace,
            bool(include_descendants),
        )

    async def search(
        self,
        query: str,
        *,
        namespace: str | None = None,
        include_descendants: bool = True,
        mode: str,
        limit: int | None,
        min_score: float | None,
        provider_id: str | None = None,
        embed_one: EmbedOne | None = None,
        embed_many: EmbedMany | None = None,
    ) -> list[dict[str, Any]]:
        """Search stored memories by keyword, vector, or hybrid scoring.

        Keyword candidates are always collected; vector candidates are
        collected only for ``vector``/``hybrid`` mode when *provider_id*
        and an embedding callback are supplied. Candidates from both paths
        are merged per (namespace, key), scored via ``_combined_score``,
        filtered by *min_score*, sorted, and truncated to *limit*.
        """
        normalized_namespace = _normalize_scope_namespace(namespace)
        # Unknown/blank modes fall back to keyword search.
        normalized_mode = str(mode).strip().lower() or "keyword"
        query_text = str(query)

        # Drop expired rows first so they never appear in results.
        await asyncio.to_thread(self._purge_expired_sync)

        keyword_candidates = await asyncio.to_thread(
            self._keyword_candidates_sync,
            query_text,
            normalized_namespace,
            bool(include_descendants),
            limit,
        )

        vector_candidates: list[_VectorCandidate] = []
        if normalized_mode in {"vector", "hybrid"} and provider_id:
            # Lazily embed any records missing vectors for this provider.
            await self._ensure_embeddings(
                provider_id=provider_id,
                namespace=normalized_namespace,
                include_descendants=bool(include_descendants),
                embed_one=embed_one,
                embed_many=embed_many,
            )
            # Query embedding requires embed_one; without it the vector
            # path contributes nothing.
            if embed_one is not None:
                raw_query_embedding = await _maybe_await(embed_one(query_text))
                query_embedding = normalize_embedding(
                    [float(item) for item in raw_query_embedding]
                )
                vector_candidates = await asyncio.to_thread(
                    self._vector_candidates_sync,
                    provider_id,
                    query_embedding,
                    normalized_namespace,
                    bool(include_descendants),
                    limit,
                )

        # Merge both candidate sets keyed by (namespace, key); keep the
        # best vector score seen for each identity.
        merged: dict[tuple[str, str], dict[str, Any]] = {}
        for record in keyword_candidates:
            identity = (record.namespace, record.key)
            merged[identity] = {
                "namespace": record.namespace,
                "key": record.key,
                "stored": record.stored,
                "keyword_score": memory_keyword_score(
                    query_text,
                    record.key,
                    record.search_text,
                ),
                "vector_score": 0.0,
            }
        for record in vector_candidates:
            identity = (record.namespace, record.key)
            current = merged.setdefault(
                identity,
                {
                    "namespace": record.namespace,
                    "key": record.key,
                    "stored": record.stored,
                    "keyword_score": memory_keyword_score(
                        query_text,
                        record.key,
                        record.search_text,
                    ),
                    "vector_score": 0.0,
                },
            )
            current["vector_score"] = max(
                float(current["vector_score"]),
                float(record.score),
            )

        results: list[MemorySearchResult] = []
        for item in merged.values():
            keyword_score = max(0.0, float(item["keyword_score"]))
            vector_score = max(0.0, float(item["vector_score"]))
            score = self._combined_score(
                mode=normalized_mode,
                keyword_score=keyword_score,
                vector_score=vector_score,
            )
            if score <= 0:
                continue
            if min_score is not None and score < float(min_score):
                continue

            # Label which retrieval path produced the hit.
            if normalized_mode == "keyword" or (
                keyword_score > 0 and vector_score <= 0
            ):
                match_type = "keyword"
            elif normalized_mode == "vector" or keyword_score <= 0:
                match_type = "vector"
            else:
                match_type = "hybrid"

            results.append(
                MemorySearchResult(
                    key=str(item["key"]),
                    namespace=str(item["namespace"]),
                    value=memory_value_for_search(item["stored"]),
                    score=score,
                    match_type=match_type,
                )
            )

        # Deterministic ordering: score descending, then namespace/key.
        results.sort(key=lambda item: (-item.score, item.namespace, item.key))
        if limit is not None and limit >= 0:
            results = results[:limit]
        return [item.to_payload() for item in results]

    async def _ensure_embeddings(
        self,
        *,
        provider_id: str,
        namespace: str | None,
        include_descendants: bool,
        embed_one: EmbedOne | None,
        embed_many: EmbedMany | None,
    ) -> None:
        """Embed records that lack a vector for *provider_id*.

        Prefers the batch callback ``embed_many``; falls back to per-text
        ``embed_one``. With neither, records are upserted with empty
        vectors (they will be skipped by the index builder).
        """
        missing = await asyncio.to_thread(
            self._missing_embeddings_sync,
            provider_id,
            namespace,
            include_descendants,
        )
        if missing:
            texts = [record.search_text for record in missing]
            embeddings: list[list[float]]
            if embed_many is not None:
                raw_embeddings = await _maybe_await(embed_many(texts))
                embeddings = [
                    normalize_embedding([float(value) for value in item])
                    for item in raw_embeddings
                ]
            elif embed_one is not None:
                embeddings = []
                for text in texts:
                    raw_vector = await _maybe_await(embed_one(text))
                    embeddings.append(
                        normalize_embedding([float(value) for value in raw_vector])
                    )
            else:
                embeddings = []
            await asyncio.to_thread(
                self._upsert_embeddings_sync,
                provider_id,
                missing,
                embeddings,
            )
        # Rebuild (or load) the vector index whenever it is marked dirty.
        await asyncio.to_thread(self._ensure_vector_index_sync, provider_id)

    def _save_sync(
        self,
        key: str,
        value: dict[str, Any],
        namespace: str,
        ttl_metadata: dict[str, Any] | None,
    ) -> None:
        """Blocking upsert of one record.

        With *ttl_metadata*, the value is wrapped in an envelope carrying
        ``ttl_seconds``/``expires_at``. Saving invalidates any embeddings
        previously computed for this record and marks the affected
        providers' vector indexes dirty.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._purge_expired_locked(conn)
                stored = dict(value)
                expires_at: str | None = None
                if ttl_metadata is not None:
                    expires_at = str(ttl_metadata.get("expires_at", "")).strip() or None
                    stored = {
                        "value": dict(value),
                        "ttl_seconds": int(ttl_metadata.get("ttl_seconds", 0)),
                    }
                    if expires_at is not None:
                        stored["expires_at"] = expires_at
                search_text = extract_memory_text(stored)
                # sort_keys keeps the JSON stable for byte-size stats and
                # change detection.
                stored_json = json.dumps(
                    stored,
                    ensure_ascii=False,
                    sort_keys=True,
                    default=str,
                )
                updated_at = _utcnow().isoformat()
                conn.execute(
                    """
                    INSERT INTO memory_records(namespace, key, stored_json, search_text, expires_at, updated_at)
                    VALUES(?, ?, ?, ?, ?, ?)
                    ON CONFLICT(namespace, key) DO UPDATE SET
                        stored_json = excluded.stored_json,
                        search_text = excluded.search_text,
                        expires_at = excluded.expires_at,
                        updated_at = excluded.updated_at
                    """,
                    (namespace, key, stored_json, search_text, expires_at, updated_at),
                )
                self._sync_fts_row_locked(
                    conn,
                    namespace=namespace,
                    key=key,
                    search_text=search_text,
                )
                # Content changed: stale embeddings must be dropped and
                # their providers' indexes flagged for rebuild.
                provider_rows = conn.execute(
                    """
                    SELECT DISTINCT provider_id
                    FROM memory_embeddings
                    WHERE namespace = ? AND key = ?
                    """,
                    (namespace, key),
                ).fetchall()
                conn.execute(
                    "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?",
                    (namespace, key),
                )
                for row in provider_rows:
                    provider_id = str(row[0]).strip()
                    if provider_id:
                        self._mark_vector_dirty_locked(conn, provider_id)
                conn.commit()
            finally:
                conn.close()
updated_at = _utcnow().isoformat() + conn.execute( + """ + INSERT INTO memory_records(namespace, key, stored_json, search_text, expires_at, updated_at) + VALUES(?, ?, ?, ?, ?, ?) + ON CONFLICT(namespace, key) DO UPDATE SET + stored_json = excluded.stored_json, + search_text = excluded.search_text, + expires_at = excluded.expires_at, + updated_at = excluded.updated_at + """, + (namespace, key, stored_json, search_text, expires_at, updated_at), + ) + self._sync_fts_row_locked( + conn, + namespace=namespace, + key=key, + search_text=search_text, + ) + provider_rows = conn.execute( + """ + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE namespace = ? AND key = ? + """, + (namespace, key), + ).fetchall() + conn.execute( + "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?", + (namespace, key), + ) + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + conn.commit() + finally: + conn.close() + + def _get_sync(self, key: str, namespace: str) -> dict[str, Any] | None: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + row = conn.execute( + """ + SELECT stored_json + FROM memory_records + WHERE namespace = ? AND key = ? + """, + (namespace, key), + ).fetchone() + if row is None: + return None + stored = self._load_stored_json(row[0]) + return memory_value_for_search(stored) + finally: + conn.close() + + def _list_keys_sync(self, namespace: str) -> list[str]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + rows = conn.execute( + """ + SELECT key + FROM memory_records + WHERE namespace = ? 
+ ORDER BY key COLLATE NOCASE ASC, key ASC + """, + (namespace,), + ).fetchall() + return [str(row[0]) for row in rows] + finally: + conn.close() + + def _exists_sync(self, key: str, namespace: str) -> bool: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + row = conn.execute( + """ + SELECT 1 + FROM memory_records + WHERE namespace = ? AND key = ? + LIMIT 1 + """, + (namespace, key), + ).fetchone() + return row is not None + finally: + conn.close() + + def _get_many_sync(self, keys: list[str], namespace: str) -> list[dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + if not keys: + return [] + lookup_keys = list(dict.fromkeys(keys)) + placeholders = _sql_placeholders(len(lookup_keys)) + rows = conn.execute( + f""" + SELECT key, stored_json + FROM memory_records + WHERE namespace = ? AND key IN ({placeholders}) + """, + (namespace, *lookup_keys), + ).fetchall() + stored_by_key = { + str(row[0]): self._load_stored_json(row[1]) for row in rows + } + return [ + { + "key": key, + "value": memory_value_for_search(stored_by_key.get(key)), + } + for key in keys + ] + finally: + conn.close() + + def _delete_sync(self, key: str, namespace: str) -> bool: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + deleted = self._delete_record_locked(conn, namespace=namespace, key=key) + conn.commit() + return deleted + finally: + conn.close() + + def _clear_namespace_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> int: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + deleted = self._delete_scope_locked( + conn, + namespace=namespace, + include_descendants=include_descendants, + ) + conn.commit() + return deleted + finally: + conn.close() + + def _delete_many_sync(self, keys: list[str], namespace: str) -> int: + with self._lock: + conn = self._connect() + try: + 
self._purge_expired_locked(conn) + unique_keys = list(dict.fromkeys(keys)) + if not unique_keys: + conn.commit() + return 0 + placeholders = _sql_placeholders(len(unique_keys)) + provider_rows = conn.execute( + f""" + SELECT DISTINCT provider_id + FROM memory_embeddings + WHERE namespace = ? AND key IN ({placeholders}) + """, + (namespace, *unique_keys), + ).fetchall() + conn.execute( + f"DELETE FROM memory_embeddings WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ) + deleted = conn.execute( + f"DELETE FROM memory_records WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ).rowcount + if self._fts_enabled: + conn.execute( + f"DELETE FROM memory_records_fts WHERE namespace = ? AND key IN ({placeholders})", + (namespace, *unique_keys), + ) + for row in provider_rows: + provider_id = str(row[0]).strip() + if provider_id: + self._mark_vector_dirty_locked(conn, provider_id) + conn.commit() + return deleted + finally: + conn.close() + + def _count_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> int: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + return int( + conn.execute( + f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}", + params, + ).fetchone()[0] + ) + finally: + conn.close() + + def _stats_sync( + self, + namespace: str | None, + include_descendants: bool, + ) -> dict[str, Any]: + with self._lock: + conn = self._connect() + try: + self._purge_expired_locked(conn) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + total_items = int( + conn.execute( + f"SELECT COUNT(*) FROM memory_records WHERE {where_sql}", + params, + ).fetchone()[0] + ) + ttl_entries = int( + conn.execute( + f""" + SELECT COUNT(*) + FROM memory_records + WHERE {where_sql} AND expires_at IS NOT NULL + 
""", + params, + ).fetchone()[0] + ) + total_bytes = int( + conn.execute( + f""" + SELECT COALESCE(SUM(LENGTH(key) + LENGTH(stored_json)), 0) + FROM memory_records + WHERE {where_sql} + """, + params, + ).fetchone()[0] + ) + namespace_count = int( + conn.execute( + f""" + SELECT COUNT(DISTINCT namespace) + FROM memory_records + WHERE {where_sql} + """, + params, + ).fetchone()[0] + ) + embedding_where_sql, embedding_params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="e", + ) + embedded_items = int( + conn.execute( + f""" + SELECT COUNT(*) + FROM ( + SELECT DISTINCT e.namespace, e.key + FROM memory_embeddings e + WHERE {embedding_where_sql} + ) + """, + embedding_params, + ).fetchone()[0] + ) + indexed_items = total_items + dirty_items = max(indexed_items - embedded_items, 0) + provider_rows = conn.execute( + """ + SELECT provider_id, dirty + FROM memory_vector_state + ORDER BY provider_id + """ + ).fetchall() + return { + "total_items": total_items, + "total_bytes": total_bytes, + "ttl_entries": ttl_entries, + "namespace": ( + None + if namespace is None + else normalize_memory_namespace(namespace) + ), + "namespace_count": namespace_count, + "indexed_items": indexed_items, + "embedded_items": embedded_items, + "dirty_items": dirty_items, + "fts_enabled": self._fts_enabled, + "vector_backend": self._vector_backend_label(), + "vector_indexes": [ + { + "provider_id": str(provider_id), + "dirty": bool(dirty), + } + for provider_id, dirty in provider_rows + ], + } + finally: + conn.close() + + def _keyword_candidates_sync( + self, + query: str, + namespace: str | None, + include_descendants: bool, + limit: int | None, + ) -> list[_StoredRecord]: + with self._lock: + conn = self._connect() + try: + fetch_limit = max((int(limit) if limit is not None else 10) * 8, 50) + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + ) + seen: set[tuple[str, str]] = set() + records: 
list[_StoredRecord] = [] + fts_query = self._fts_query(query) + if self._fts_enabled and fts_query is not None: + fts_where_sql, fts_params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="r", + ) + rows = conn.execute( + f""" + SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at + FROM memory_records_fts f + JOIN memory_records r + ON r.namespace = f.namespace AND r.key = f.key + WHERE {fts_where_sql} AND memory_records_fts MATCH ? + ORDER BY bm25(memory_records_fts), r.updated_at DESC + LIMIT ? + """, + (*fts_params, fts_query, fetch_limit), + ).fetchall() + for row in rows: + record = self._stored_record_from_row(row) + identity = (record.namespace, record.key) + if identity not in seen: + seen.add(identity) + records.append(record) + + like_query = f"%{str(query).strip()}%" + if not records or len(records) < fetch_limit: + rows = conn.execute( + f""" + SELECT namespace, key, stored_json, search_text, updated_at + FROM memory_records + WHERE {where_sql} + AND (? = '%%' OR key LIKE ? COLLATE NOCASE OR search_text LIKE ? COLLATE NOCASE) + ORDER BY updated_at DESC + LIMIT ? 
+ """, + (*params, like_query, like_query, like_query, fetch_limit), + ).fetchall() + for row in rows: + record = self._stored_record_from_row(row) + identity = (record.namespace, record.key) + if identity not in seen: + seen.add(identity) + records.append(record) + return records + finally: + conn.close() + + def _missing_embeddings_sync( + self, + provider_id: str, + namespace: str | None, + include_descendants: bool, + ) -> list[_StoredRecord]: + with self._lock: + conn = self._connect() + try: + where_sql, params = self._namespace_where( + namespace, + include_descendants=include_descendants, + alias="r", + ) + rows = conn.execute( + f""" + SELECT r.namespace, r.key, r.stored_json, r.search_text, r.updated_at + FROM memory_records r + LEFT JOIN memory_embeddings e + ON e.namespace = r.namespace + AND e.key = r.key + AND e.provider_id = ? + WHERE {where_sql} AND e.id IS NULL + ORDER BY r.updated_at DESC + """, + (provider_id, *params), + ).fetchall() + return [self._stored_record_from_row(row) for row in rows] + finally: + conn.close() + + def _upsert_embeddings_sync( + self, + provider_id: str, + records: list[_StoredRecord], + embeddings: list[list[float]], + ) -> None: + if not records: + return + with self._lock: + conn = self._connect() + try: + for index, record in enumerate(records): + vector = embeddings[index] if index < len(embeddings) else [] + conn.execute( + """ + INSERT INTO memory_embeddings(namespace, key, provider_id, embedding_json, updated_at) + VALUES(?, ?, ?, ?, ?) 
    def _vector_candidates_sync(
        self,
        provider_id: str,
        query_embedding: list[float],
        namespace: str | None,
        include_descendants: bool,
        limit: int | None,
    ) -> list[_VectorCandidate]:
        """Blocking vector candidate retrieval.

        Uses the in-memory FAISS index when one is loaded and FAISS is
        importable; otherwise falls back to an exact cosine scan over
        the stored embedding JSON. Over-fetches (10x limit, minimum 50)
        to survive post-filtering by namespace.
        """
        if not query_embedding:
            return []
        with self._lock:
            conn = self._connect()
            try:
                index = self._vector_indexes.get(provider_id)
                fetch_limit = max((int(limit) if limit is not None else 10) * 10, 50)
                if index is not None and self._faiss_available():
                    return self._faiss_vector_candidates_locked(
                        conn=conn,
                        provider_id=provider_id,
                        query_embedding=query_embedding,
                        namespace=namespace,
                        include_descendants=include_descendants,
                        fetch_limit=fetch_limit,
                    )
                return self._fallback_vector_candidates_locked(
                    conn=conn,
                    provider_id=provider_id,
                    query_embedding=query_embedding,
                    namespace=namespace,
                    include_descendants=include_descendants,
                    fetch_limit=fetch_limit,
                )
            finally:
                conn.close()

    def _ensure_vector_index_sync(self, provider_id: str) -> None:
        """Blocking (re)build or load of the provider's vector index.

        Fast paths: a clean index already in memory, or a clean index
        file on disk that FAISS can read. Otherwise rebuilds from the
        embedding rows: with FAISS available, an inner-product
        IndexIDMap2 keyed by embedding row id, persisted to disk;
        without FAISS, an in-memory (id, vector) fallback list. Finally
        clears the provider's dirty flag.
        """
        with self._lock:
            conn = self._connect()
            try:
                self._init_storage_locked(conn)
                row = conn.execute(
                    """
                    SELECT dirty
                    FROM memory_vector_state
                    WHERE provider_id = ?
                    """,
                    (provider_id,),
                ).fetchone()
                # Unknown providers are treated as dirty (need a build).
                dirty = True if row is None else bool(row[0])
                if not dirty and provider_id in self._vector_indexes:
                    return

                index_path = (
                    self._vector_dir / f"{self._safe_filename(provider_id)}.faiss"
                )
                if not dirty and index_path.exists() and self._faiss_available():
                    try:
                        faiss = self._import_faiss()
                        self._vector_indexes[provider_id] = faiss.read_index(
                            str(index_path)
                        )
                        self._vector_fallbacks.pop(provider_id, None)
                        return
                    except Exception:
                        # Unreadable/corrupt index file: fall through to a
                        # full rebuild from the database rows.
                        pass

                rows = conn.execute(
                    """
                    SELECT id, embedding_json
                    FROM memory_embeddings
                    WHERE provider_id = ?
                    ORDER BY id
                    """,
                    (provider_id,),
                ).fetchall()
                ids: list[int] = []
                vectors: list[list[float]] = []
                for raw_id, raw_vector in rows:
                    vector = self._load_embedding_json(raw_vector)
                    # Empty vectors (failed/absent embeddings) are skipped.
                    if not vector:
                        continue
                    ids.append(int(raw_id))
                    vectors.append(vector)

                if self._faiss_available() and vectors:
                    faiss = self._import_faiss()
                    np = self._import_numpy()
                    dimension = len(vectors[0])
                    base_index = faiss.IndexFlatIP(dimension)
                    index = faiss.IndexIDMap2(base_index)
                    index.add_with_ids(
                        np.array(vectors, dtype="float32"),
                        np.array(ids, dtype="int64"),
                    )
                    self._vector_indexes[provider_id] = index
                    self._vector_fallbacks.pop(provider_id, None)
                    self._vector_dir.mkdir(parents=True, exist_ok=True)
                    faiss.write_index(index, str(index_path))
                else:
                    # None marks "no FAISS index"; searches will use the
                    # exact-scan fallback list instead.
                    self._vector_indexes[provider_id] = None
                    self._vector_fallbacks[provider_id] = list(
                        zip(ids, vectors, strict=False)
                    )
                conn.execute(
                    """
                    INSERT INTO memory_vector_state(provider_id, dirty, updated_at)
                    VALUES(?, 0, ?)
                    ON CONFLICT(provider_id) DO UPDATE SET
                        dirty = 0,
                        updated_at = excluded.updated_at
                    """,
                    (provider_id, _utcnow().isoformat()),
                )
                conn.commit()
            finally:
                conn.close()

    def _faiss_vector_candidates_locked(
        self,
        *,
        conn: sqlite3.Connection,
        provider_id: str,
        query_embedding: list[float],
        namespace: str | None,
        include_descendants: bool,
        fetch_limit: int,
    ) -> list[_VectorCandidate]:
        """ANN lookup against the loaded FAISS index (caller holds lock).

        Because namespace filtering happens after the ANN search, the
        search width is progressively doubled (up to the index size)
        until enough in-scope candidates are collected.
        """
        index = self._vector_indexes.get(provider_id)
        if index is None:
            return []
        np = self._import_numpy()
        total_count = int(getattr(index, "ntotal", 0) or 0)
        if total_count <= 0:
            return []

        collected: list[_VectorCandidate] = []
        seen: set[tuple[str, str]] = set()
        current_limit = min(fetch_limit, total_count)
        while current_limit > 0:
            scores, ids = index.search(
                np.array([query_embedding], dtype="float32"),
                current_limit,
            )
            # FAISS pads missing results with id -1; drop them.
            raw_ids = [int(item) for item in ids[0] if int(item) >= 0]
            score_map = {
                int(item_id): max(0.0, float(score))
                for item_id, score in zip(raw_ids, scores[0], strict=False)
            }
            if not score_map:
                break
            placeholders = ",".join("?" for _ in score_map)
            rows = conn.execute(
                f"""
                SELECT e.id, r.namespace, r.key, r.stored_json, r.search_text
                FROM memory_embeddings e
                JOIN memory_records r
                    ON r.namespace = e.namespace AND r.key = e.key
                WHERE e.provider_id = ?
                  AND e.id IN ({placeholders})
                """,
                (provider_id, *score_map.keys()),
            ).fetchall()
            row_map = {int(row[0]): row for row in rows}
            for item_id in raw_ids:
                row = row_map.get(item_id)
                if row is None:
                    continue
                record_namespace = normalize_memory_namespace(row[1])
                if not memory_namespace_matches(
                    record_namespace,
                    namespace,
                    include_descendants=include_descendants,
                ):
                    continue
                identity = (record_namespace, str(row[2]))
                if identity in seen:
                    continue
                seen.add(identity)
                collected.append(
                    _VectorCandidate(
                        namespace=record_namespace,
                        key=str(row[2]),
                        stored=self._load_stored_json(row[3]),
                        search_text=str(row[4]),
                        score=max(0.0, score_map.get(item_id, 0.0)),
                    )
                )
            if len(collected) >= fetch_limit or current_limit >= total_count:
                break
            # Widen the ANN search and retry to replace filtered-out hits.
            next_limit = min(total_count, current_limit * 2)
            if next_limit == current_limit:
                break
            current_limit = next_limit
        return collected
    def _fallback_vector_candidates_locked(
        self,
        *,
        conn: sqlite3.Connection,
        provider_id: str,
        query_embedding: list[float],
        namespace: str | None,
        include_descendants: bool,
        fetch_limit: int,
    ) -> list[_VectorCandidate]:
        """Exact cosine scan over stored embeddings (caller holds lock).

        Used when no FAISS index is available; O(n) in the number of
        embedding rows for the provider.
        """
        rows = conn.execute(
            """
            SELECT e.namespace, e.key, e.embedding_json, r.stored_json, r.search_text
            FROM memory_embeddings e
            JOIN memory_records r
                ON r.namespace = e.namespace AND r.key = e.key
            WHERE e.provider_id = ?
            """,
            (provider_id,),
        ).fetchall()
        candidates: list[_VectorCandidate] = []
        for raw_namespace, raw_key, raw_embedding, raw_stored, raw_search_text in rows:
            record_namespace = normalize_memory_namespace(raw_namespace)
            if not memory_namespace_matches(
                record_namespace,
                namespace,
                include_descendants=include_descendants,
            ):
                continue
            embedding = self._load_embedding_json(raw_embedding)
            score = max(0.0, cosine_similarity(query_embedding, embedding))
            if score <= 0:
                continue
            candidates.append(
                _VectorCandidate(
                    namespace=record_namespace,
                    key=str(raw_key),
                    stored=self._load_stored_json(raw_stored),
                    search_text=str(raw_search_text),
                    score=score,
                )
            )
        # Deterministic order: best score first, ties by namespace/key.
        candidates.sort(key=lambda item: (-item.score, item.namespace, item.key))
        return candidates[:fetch_limit]

    def _purge_expired_sync(self) -> None:
        """Blocking purge of expired records (own connection + commit)."""
        with self._lock:
            conn = self._connect()
            try:
                self._purge_expired_locked(conn)
                conn.commit()
            finally:
                conn.close()

    def _purge_expired_locked(self, conn: sqlite3.Connection) -> None:
        """Delete every record whose ``expires_at`` is in the past.

        Caller holds the lock and owns the transaction/commit.
        """
        self._init_storage_locked(conn)
        now_iso = _utcnow().isoformat()
        # ISO-8601 UTC timestamps compare correctly as strings.
        rows = conn.execute(
            """
            SELECT namespace, key
            FROM memory_records
            WHERE expires_at IS NOT NULL AND expires_at <= ?
            """,
            (now_iso,),
        ).fetchall()
        for namespace, key in rows:
            self._delete_record_locked(
                conn,
                namespace=normalize_memory_namespace(namespace),
                key=str(key),
            )

    def _delete_record_locked(
        self,
        conn: sqlite3.Connection,
        *,
        namespace: str,
        key: str,
    ) -> bool:
        """Delete one record plus its embeddings and FTS row.

        Caller holds the lock and commits. Providers whose embeddings
        were removed get their vector index flagged dirty. Returns
        ``True`` when a memory_records row was actually deleted.
        """
        provider_rows = conn.execute(
            """
            SELECT DISTINCT provider_id
            FROM memory_embeddings
            WHERE namespace = ? AND key = ?
            """,
            (namespace, key),
        ).fetchall()
        conn.execute(
            "DELETE FROM memory_embeddings WHERE namespace = ? AND key = ?",
            (namespace, key),
        )
        deleted = (
            conn.execute(
                "DELETE FROM memory_records WHERE namespace = ? AND key = ?",
                (namespace, key),
            ).rowcount
            > 0
        )
        if self._fts_enabled:
            conn.execute(
                "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?",
                (namespace, key),
            )
        for row in provider_rows:
            provider_id = str(row[0]).strip()
            if provider_id:
                self._mark_vector_dirty_locked(conn, provider_id)
        return deleted

    def _delete_scope_locked(
        self,
        conn: sqlite3.Connection,
        *,
        namespace: str | None,
        include_descendants: bool,
    ) -> int:
        """Bulk delete of every record in a namespace scope.

        Collects the affected (namespace, key) pairs first, then removes
        embeddings, FTS rows and records by row-value IN lists, and
        flags affected providers' indexes dirty. Returns the number of
        memory_records rows deleted.
        """
        where_sql, params = self._namespace_where(
            namespace,
            include_descendants=include_descendants,
        )
        affected_rows = conn.execute(
            f"""
            SELECT namespace, key
            FROM memory_records
            WHERE {where_sql}
            """,
            params,
        ).fetchall()
        if not affected_rows:
            return 0

        # Row-value comparison: WHERE (namespace, key) IN ((?,?), ...).
        pair_placeholders = ", ".join("(?, ?)" for _ in affected_rows)
        pair_params = tuple(
            value
            for raw_namespace, raw_key in affected_rows
            for value in (normalize_memory_namespace(raw_namespace), str(raw_key))
        )

        provider_rows = conn.execute(
            f"""
            SELECT DISTINCT provider_id
            FROM memory_embeddings
            WHERE (namespace, key) IN ({pair_placeholders})
            """,
            pair_params,
        ).fetchall()
        conn.execute(
            f"""
            DELETE FROM memory_embeddings
            WHERE (namespace, key) IN ({pair_placeholders})
            """,
            pair_params,
        )
        if self._fts_enabled:
            conn.execute(
                f"""
                DELETE FROM memory_records_fts
                WHERE (namespace, key) IN ({pair_placeholders})
                """,
                pair_params,
            )
        deleted = conn.execute(
            f"""
            DELETE FROM memory_records
            WHERE (namespace, key) IN ({pair_placeholders})
            """,
            pair_params,
        ).rowcount
        for row in provider_rows:
            provider_id = str(row[0]).strip()
            if provider_id:
                self._mark_vector_dirty_locked(conn, provider_id)
        return deleted

    def _connect(self) -> sqlite3.Connection:
        """Open a fresh connection, ensuring the data dir and schema exist."""
        self._base_dir.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        self._init_storage_locked(conn)
        return conn

    def _init_storage_locked(self, conn: sqlite3.Connection) -> None:
        """Create tables/indexes on first use; no-op afterwards.

        FTS5 availability is probed by attempting to create the virtual
        table; on failure keyword search degrades to LIKE scans.

        NOTE(review): ``PRAGMA synchronous`` is per-connection, but this
        method short-circuits once ``self._initialized`` is set, so later
        connections keep SQLite's default — confirm intended.
        """
        if self._initialized:
            return
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS memory_records (
                namespace TEXT NOT NULL,
                key TEXT NOT NULL,
                stored_json TEXT NOT NULL,
                search_text TEXT NOT NULL,
                expires_at TEXT,
                updated_at TEXT NOT NULL,
                PRIMARY KEY(namespace, key)
            )
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_memory_records_namespace
            ON memory_records(namespace)
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_memory_records_expires_at
            ON memory_records(expires_at)
            """
        )
        try:
            conn.execute(
                """
                CREATE VIRTUAL TABLE IF NOT EXISTS memory_records_fts
                USING fts5(namespace UNINDEXED, key, search_text, tokenize='unicode61')
                """
            )
            self._fts_enabled = True
        except sqlite3.OperationalError:
            # SQLite built without FTS5: fall back to LIKE-based search.
            self._fts_enabled = False
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS memory_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                key TEXT NOT NULL,
                provider_id TEXT NOT NULL,
                embedding_json TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                UNIQUE(namespace, key, provider_id)
            )
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_memory_embeddings_provider
            ON memory_embeddings(provider_id, namespace)
            """
        )
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS memory_vector_state (
                provider_id TEXT PRIMARY KEY,
                dirty INTEGER NOT NULL DEFAULT 1,
                updated_at TEXT NOT NULL
            )
            """
        )
        conn.commit()
        self._initialized = True
conn: sqlite3.Connection) -> None: + if self._initialized: + return + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_records ( + namespace TEXT NOT NULL, + key TEXT NOT NULL, + stored_json TEXT NOT NULL, + search_text TEXT NOT NULL, + expires_at TEXT, + updated_at TEXT NOT NULL, + PRIMARY KEY(namespace, key) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_records_namespace + ON memory_records(namespace) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_records_expires_at + ON memory_records(expires_at) + """ + ) + try: + conn.execute( + """ + CREATE VIRTUAL TABLE IF NOT EXISTS memory_records_fts + USING fts5(namespace UNINDEXED, key, search_text, tokenize='unicode61') + """ + ) + self._fts_enabled = True + except sqlite3.OperationalError: + self._fts_enabled = False + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + namespace TEXT NOT NULL, + key TEXT NOT NULL, + provider_id TEXT NOT NULL, + embedding_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(namespace, key, provider_id) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_memory_embeddings_provider + ON memory_embeddings(provider_id, namespace) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_vector_state ( + provider_id TEXT PRIMARY KEY, + dirty INTEGER NOT NULL DEFAULT 1, + updated_at TEXT NOT NULL + ) + """ + ) + conn.commit() + self._initialized = True + + def _sync_fts_row_locked( + self, + conn: sqlite3.Connection, + *, + namespace: str, + key: str, + search_text: str, + ) -> None: + if not self._fts_enabled: + return + conn.execute( + "DELETE FROM memory_records_fts WHERE namespace = ? AND key = ?", + (namespace, key), + ) + conn.execute( + """ + INSERT INTO memory_records_fts(namespace, key, search_text) + VALUES(?, ?, ?) 
+ """, + (namespace, key, search_text), + ) + + def _mark_vector_dirty_locked( + self, + conn: sqlite3.Connection, + provider_id: str, + ) -> None: + conn.execute( + """ + INSERT INTO memory_vector_state(provider_id, dirty, updated_at) + VALUES(?, 1, ?) + ON CONFLICT(provider_id) DO UPDATE SET + dirty = 1, + updated_at = excluded.updated_at + """, + (provider_id, _utcnow().isoformat()), + ) + self._vector_indexes.pop(provider_id, None) + self._vector_fallbacks.pop(provider_id, None) + + @staticmethod + def _combined_score( + *, + mode: str, + keyword_score: float, + vector_score: float, + ) -> float: + if mode == "keyword": + return keyword_score + if mode == "vector": + return vector_score + if keyword_score > 0 and vector_score > 0: + return min(1.0, 0.65 * vector_score + 0.35 * keyword_score + 0.05) + if vector_score > 0: + return min(1.0, vector_score) + return min(1.0, keyword_score) + + @staticmethod + def _load_stored_json(raw_value: Any) -> dict[str, Any]: + if isinstance(raw_value, dict): + return dict(raw_value) + if isinstance(raw_value, str): + decoded = json.loads(raw_value) + return dict(decoded) if isinstance(decoded, dict) else {} + return {} + + @staticmethod + def _load_embedding_json(raw_value: Any) -> list[float]: + if isinstance(raw_value, list): + return [float(item) for item in raw_value] + if isinstance(raw_value, str): + decoded = json.loads(raw_value) + if isinstance(decoded, list): + return [float(item) for item in decoded] + return [] + + @staticmethod + def _stored_record_from_row(row: Any) -> _StoredRecord: + return _StoredRecord( + namespace=normalize_memory_namespace(row[0]), + key=str(row[1]), + stored=PluginMemoryBackend._load_stored_json(row[2]), + search_text=str(row[3]), + updated_at=str(row[4]), + ) + + @staticmethod + def _namespace_where( + namespace: str | None, + *, + include_descendants: bool, + alias: str | None = None, + ) -> tuple[str, tuple[Any, ...]]: + column = f"{alias}.namespace" if alias else "namespace" + if 
namespace is None: + return "1 = 1", () + normalized_namespace = normalize_memory_namespace(namespace) + if not normalized_namespace: + if include_descendants: + return "1 = 1", () + return f"{column} = ''", () + if include_descendants: + escaped_namespace = _escape_like_value(normalized_namespace) + return ( + f"({column} = ? OR {column} LIKE ? ESCAPE '\\')", + (normalized_namespace, f"{escaped_namespace}/%"), + ) + return f"{column} = ?", (normalized_namespace,) + + @staticmethod + def _fts_query(query: str) -> str | None: + stripped = str(query).strip() + if not stripped: + return None + terms = [ + item for item in re.findall(r"\w+", stripped, flags=re.UNICODE) if item + ] + if not terms: + return None + escaped_terms = [term.replace('"', '""') for term in terms[:8]] + return " OR ".join(f'"{term}"' for term in escaped_terms) + + @staticmethod + def _safe_filename(value: str) -> str: + return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(value)).strip("._") or "default" + + @staticmethod + def _import_faiss() -> Any: + # FAISS often ships without stable type stubs, so keep the lazy import + # boundary explicitly dynamic to avoid false-positive Pylance errors. 
+ import faiss + + return cast(Any, faiss) + + @staticmethod + def _import_numpy(): + import numpy + + return numpy + + @classmethod + def _faiss_available(cls) -> bool: + try: + faiss = cls._import_faiss() + cls._import_numpy() + except Exception: + return False + required_attrs = ( + "IndexFlatIP", + "IndexIDMap2", + "read_index", + "write_index", + ) + return all(hasattr(faiss, attr) for attr in required_attrs) + + def _vector_backend_label(self) -> str: + return "faiss" if self._faiss_available() else "exact" + + +async def _maybe_await(value: Any) -> Any: + if asyncio.iscoroutine(value) or isinstance(value, asyncio.Future): + return await value + return value + + +def extend_memory_namespace( + base_namespace: str | None, + extra_namespace: str | None, +) -> str: + """Join a base namespace with a relative namespace override.""" + + return join_memory_namespace(base_namespace, extra_namespace) diff --git a/astrbot-sdk/src/astrbot_sdk/_message_types.py b/astrbot-sdk/src/astrbot_sdk/_message_types.py new file mode 100644 index 0000000000..1d2df56040 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_message_types.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +_GROUP_MESSAGE_TYPES = {"group", "groupmessage", "group_message"} +_PRIVATE_MESSAGE_TYPES = { + "private", + "privatemessage", + "private_message", + "friend", + "friendmessage", + "friend_message", +} +_OTHER_MESSAGE_TYPES = {"other", "othermessage", "other_message"} + + +def normalize_message_type( + value: Any, + *, + group_id: str | None = None, + user_id: str | None = None, + empty_default: str = "", +) -> str: + """Collapse SDK-visible message types to canonical values.""" + + normalized = str(getattr(value, "value", value) or "").strip().lower() + if normalized in _GROUP_MESSAGE_TYPES: + return "group" + if normalized in _PRIVATE_MESSAGE_TYPES: + return "private" + if normalized in _OTHER_MESSAGE_TYPES: + return "other" + if group_id: + return "group" + if user_id: 
+ return "private" + if not normalized: + return empty_default + return "other" diff --git a/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py new file mode 100644 index 0000000000..5d2a3d9b17 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_plugin_logger.py @@ -0,0 +1,3 @@ +from ._internal.plugin_logger import PluginLogEntry, PluginLogger + +__all__ = ["PluginLogEntry", "PluginLogger"] diff --git a/astrbot-sdk/src/astrbot_sdk/_star_runtime.py b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py new file mode 100644 index 0000000000..d6d9fe215d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_star_runtime.py @@ -0,0 +1,13 @@ +from ._internal.star_runtime import ( + bind_star_runtime, + current_runtime_context, + current_star_context, + current_star_instance, +) + +__all__ = [ + "bind_star_runtime", + "current_runtime_context", + "current_star_context", + "current_star_instance", +] diff --git a/astrbot-sdk/src/astrbot_sdk/_testing_support.py b/astrbot-sdk/src/astrbot_sdk/_testing_support.py new file mode 100644 index 0000000000..1e945e8e06 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/_testing_support.py @@ -0,0 +1,25 @@ +from ._internal.testing_support import ( + InMemoryDB, + InMemoryMemory, + MockCapabilityRouter, + MockContext, + MockLLMClient, + MockMessageEvent, + MockPeer, + MockPlatformClient, + RecordedSend, + StdoutPlatformSink, +) + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/cli.py b/astrbot-sdk/src/astrbot_sdk/cli.py new file mode 100644 index 0000000000..3ae1cc86c1 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/cli.py @@ -0,0 +1,1267 @@ +"""AstrBot SDK 的命令行入口。 + +本模块提供 astrbot-sdk 命令行工具的所有子命令,包括: +- init: 创建新插件骨架,生成 plugin.yaml、main.py、README.md 等模板文件 +- validate: 校验插件清单、导入路径和 handler 发现是否正常 +- build: 
"""Command-line entry points for the AstrBot SDK.

Subcommands provided by the ``astrbot-sdk`` tool:

- ``init``: scaffold a new plugin (plugin.yaml, main.py, README.md, tests).
- ``validate``: check the plugin manifest, import paths and handler discovery.
- ``build``: package a plugin into a distributable ``.zip`` artifact.
- ``dev``: local development mode with ``--local/--watch/--interactive``.
- ``run``: start the plugin supervisor talking to AstrBot core over stdio.
- ``worker``: internal command used by the supervisor to spawn one worker.

Error handling: every CLI exception is classified into a standardized exit
code and hint so CI/CD integrations and users can locate failures quickly.
"""

from __future__ import annotations

import asyncio
import importlib.resources as resources
import os
import re
import sys
import typing
import zipfile
from collections.abc import Coroutine
from dataclasses import dataclass, field
from importlib.resources.abc import Traversable
from pathlib import Path
from textwrap import dedent
from typing import Any

import click
from loguru import logger

from .errors import AstrBotError
from .runtime.bootstrap import run_plugin_worker, run_supervisor, run_websocket_server
from .runtime.loader import load_plugin, load_plugin_spec, validate_plugin_spec

# Standardized process exit codes (stable contract for CI integrations).
EXIT_OK = 0
EXIT_UNEXPECTED = 1
EXIT_USAGE = 2
EXIT_PLUGIN_LOAD = 3
EXIT_RUNTIME = 4
EXIT_PLUGIN_EXECUTION = 5
# Directories/files excluded from build artifacts and watch snapshots.
BUILD_EXCLUDED_DIRS = {
    ".agents",
    ".claude",
    ".git",
    ".idea",
    ".mypy_cache",
    ".opencode",
    ".pytest_cache",
    ".ruff_cache",
    ".venv",
    "__pycache__",
    "dist",
}
BUILD_EXCLUDED_FILES = {
    ".astrbot-worker-state.json",
}
WATCH_POLL_INTERVAL_SECONDS = 0.5
SUPPORTED_INIT_AGENTS = ("claude", "codex", "opencode")
# Matches "{{ identifier }}" placeholders inside agent skill templates.
_TEMPLATE_VARIABLE_PATTERN = re.compile(r"{{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*}}")
INIT_AGENT_SKILL_ROOTS = {
    "claude": Path(".claude") / "skills",
    "codex": Path(".agents") / "skills",
    "opencode": Path(".opencode") / "skills",
}
INIT_AGENT_DISPLAY_NAMES = {
    "claude": "Claude Code",
    "codex": "Codex",
    "opencode": "OpenCode",
}
INIT_SKILL_TEMPLATE_NAME = "astrbot-plugin-dev"


class _CliPluginValidationError(RuntimeError):
    """Plugin structure or packaging validation failed on the CLI side."""


class _CliPluginLoadError(RuntimeError):
    """A local-development plugin failed to load on the CLI side."""


class _CliPluginExecutionError(RuntimeError):
    """A local-development plugin failed while executing on the CLI side."""


@dataclass(slots=True)
class _PluginTreeWatcher:
    """Polling watcher that detects file changes under a plugin directory."""

    plugin_dir: Path
    # Relative POSIX path -> (mtime_ns, size) captured at the last poll.
    snapshot: dict[str, tuple[int, int]] = field(init=False, default_factory=dict)

    def __post_init__(self) -> None:
        self.snapshot = _snapshot_watch_files(self.plugin_dir)

    def poll_changes(self) -> list[str]:
        """Return sorted relative paths whose (mtime, size) changed since last poll."""
        current = _snapshot_watch_files(self.plugin_dir)
        candidates = set(self.snapshot) | set(current)
        changed = [
            path
            for path in candidates
            if self.snapshot.get(path) != current.get(path)
        ]
        self.snapshot = current
        return sorted(changed)


def setup_logger(verbose: bool = False) -> None:
    """Configure loguru for CLI output on stderr."""
    logger.remove()
    logger.add(
        sys.stderr,
        format="{time:HH:mm:ss} | {level: <8} | {message}",
        level="DEBUG" if verbose else "INFO",
        colorize=True,
    )


def _resolve_protocol_stdout(
    protocol_stdout: str | None,
) -> tuple[typing.TextIO, typing.TextIO | None]:
    """Pick the stream used for the runtime-protocol stdout.

    Returns ``(stream, opened_handle)``; the second element is non-None when
    this function opened a file the caller must close. ``console`` keeps
    stdout, ``silent`` discards output, anything else is treated as a file
    path. With no configuration, an interactive stdout is redirected to the
    null device so protocol frames do not clutter the terminal.
    """
    configured = "" if protocol_stdout is None else str(protocol_stdout).strip()
    if not configured:
        stdout = sys.stdout
        isatty = getattr(stdout, "isatty", None)
        if callable(isatty) and stdout.isatty():
            opened_stdout = open(os.devnull, "w", encoding="utf-8")
            return opened_stdout, opened_stdout
        return stdout, None
    lowered = configured.lower()
    if lowered == "console":
        return sys.stdout, None
    output_path = os.devnull if lowered == "silent" else configured
    opened_stdout = open(output_path, "w", encoding="utf-8")
    return opened_stdout, opened_stdout
def _run_sync_entrypoint(
    entrypoint: typing.Callable[[], object],
    *,
    log_message: str,
    log_level: str = "info",
    context: dict[str, Any] | None = None,
) -> None:
    """Run a synchronous CLI entrypoint with unified logging and error rendering.

    Exits with 130 on user interruption; otherwise classifies the failure,
    renders a standardized error block and exits with the mapped code.
    """
    getattr(logger, log_level)(log_message)
    try:
        entrypoint()
    except (click.Abort, KeyboardInterrupt):
        click.echo("\n创建插件已优雅地中断。", err=True)
        raise SystemExit(130)
    except Exception as exc:
        exit_code, error_code, hint = _classify_cli_exception(exc)
        is_sdk_error = isinstance(exc, AstrBotError)
        _render_cli_error(
            error_code=error_code,
            message=str(exc),
            hint=hint,
            docs_url=exc.docs_url if is_sdk_error else "",
            details=exc.details if is_sdk_error else None,
            context=context,
        )
        if exit_code == EXIT_UNEXPECTED:
            # Unclassified failures get a full traceback in the log.
            logger.exception("CLI 异常退出")
        raise SystemExit(exit_code) from exc


def _classify_cli_exception(exc: Exception) -> tuple[int, str, str]:
    """Map an exception to ``(exit_code, error_code, user-facing hint)``.

    Order matters: SDK errors first, then load-time failures, then lookup
    (dispatch) problems, then execution errors; anything else is unexpected.
    """
    if isinstance(exc, AstrBotError):
        hint = exc.hint or "请检查本地 mock core 与插件调用参数"
        return EXIT_RUNTIME, exc.code, hint
    load_error_types = (
        _CliPluginValidationError,
        _CliPluginLoadError,
        FileNotFoundError,
        ImportError,
        ModuleNotFoundError,
    )
    if isinstance(exc, load_error_types):
        return (
            EXIT_PLUGIN_LOAD,
            "plugin_load_error",
            "请检查插件目录、plugin.yaml、requirements.txt(如有)和导入路径",
        )
    if isinstance(exc, LookupError):
        return (
            EXIT_RUNTIME,
            "dispatch_error",
            "请检查 handler 或 capability 是否已正确注册",
        )
    if isinstance(exc, _CliPluginExecutionError):
        return (
            EXIT_PLUGIN_EXECUTION,
            "plugin_execution_error",
            "请检查插件生命周期、handler 或 capability 的实现",
        )
    return (
        EXIT_UNEXPECTED,
        "unexpected_error",
        "请查看详细日志,必要时使用 --verbose 重试",
    )
| None = None, + context: dict[str, Any] | None = None, +) -> None: + click.echo(f"Error[{error_code}]: {message}", err=True) + if hint: + click.echo(f"Suggestion: {hint}", err=True) + if docs_url: + click.echo(f"Docs: {docs_url}", err=True) + if details: + click.echo(f"Details: {details}", err=True) + if not context: + return + for key, value in context.items(): + click.echo(f"{key}: {value}", err=True) + + +def _render_nonfatal_dev_error( + exc: Exception, + *, + context: dict[str, Any] | None = None, +) -> None: + exit_code, error_code, hint = _classify_cli_exception(exc) + _render_cli_error( + error_code=error_code, + message=str(exc), + hint=hint, + context=context, + ) + if exit_code == EXIT_UNEXPECTED: + logger.exception("watch 模式收到未分类异常") + + +def _iter_watch_files(plugin_dir: Path) -> typing.Iterator[Path]: + root = plugin_dir.resolve() + for path in sorted(root.rglob("*")): + if path.is_dir(): + continue + relative = path.relative_to(root) + if any(part in BUILD_EXCLUDED_DIRS for part in relative.parts[:-1]): + continue + if relative.name in BUILD_EXCLUDED_FILES: + continue + if path.suffix in {".pyc", ".pyo"}: + continue + yield path + + +def _snapshot_watch_files(plugin_dir: Path) -> dict[str, tuple[int, int]]: + root = plugin_dir.resolve() + snapshot: dict[str, tuple[int, int]] = {} + for path in _iter_watch_files(root): + try: + stat = path.stat() + except FileNotFoundError: + continue + snapshot[path.relative_to(root).as_posix()] = ( + stat.st_mtime_ns, + stat.st_size, + ) + return snapshot + + +def _format_watch_changes(changes: list[str], *, limit: int = 5) -> str: + if not changes: + return "未知文件" + preview = changes[:limit] + text = ", ".join(preview) + if len(changes) > limit: + text += f" 等 {len(changes)} 个文件" + return text + + +class _ReloadableLocalDevRunner: + def __init__( + self, + *, + plugin_dir: Path, + state: dict[str, Any], + plugin_load_error: type[Exception], + plugin_execution_error: type[Exception], + plugin_harness, + 
stdout_platform_sink, + ) -> None: + self.plugin_dir = plugin_dir + self.state = state + self._plugin_load_error = plugin_load_error + self._plugin_execution_error = plugin_execution_error + self._plugin_harness = plugin_harness + self._stdout_platform_sink = stdout_platform_sink + self._harness = None + self._lock = asyncio.Lock() + + async def close(self) -> None: + async with self._lock: + await self._stop_harness() + + async def reload(self) -> bool: + async with self._lock: + await self._stop_harness() + harness = self._plugin_harness.from_plugin_dir( + self.plugin_dir, + session_id=str(self.state["session_id"]), + user_id=str(self.state["user_id"]), + platform=str(self.state["platform"]), + group_id=typing.cast(str | None, self.state["group_id"]), + event_type=str(self.state["event_type"]), + platform_sink=self._stdout_platform_sink(stream=sys.stdout), + ) + try: + await harness.start() + except self._plugin_load_error as exc: + _render_nonfatal_dev_error( + _CliPluginLoadError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + except self._plugin_execution_error as exc: + _render_nonfatal_dev_error( + _CliPluginExecutionError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + self._harness = harness + return True + + async def dispatch_text(self, text: str) -> bool: + async with self._lock: + if self._harness is None: + click.echo("当前插件未成功加载,等待下一次文件变更后重试。") + return False + try: + await self._harness.dispatch_text( + text, + session_id=str(self.state["session_id"]), + user_id=str(self.state["user_id"]), + platform=str(self.state["platform"]), + group_id=typing.cast(str | None, self.state["group_id"]), + event_type=str(self.state["event_type"]), + ) + except (self._plugin_load_error, self._plugin_execution_error) as exc: + _render_nonfatal_dev_error( + _CliPluginExecutionError(str(exc)), + context={"plugin_dir": self.plugin_dir}, + ) + return False + except Exception as exc: + _render_nonfatal_dev_error( + exc, 
+ context={"plugin_dir": self.plugin_dir}, + ) + return False + return True + + async def _stop_harness(self) -> None: + if self._harness is None: + return + try: + await self._harness.stop() + finally: + self._harness = None + + +async def _run_local_dev_watch( + *, + runner: _ReloadableLocalDevRunner, + event_text: str | None, + interactive: bool, + watch_poll_interval: float, + max_watch_reloads: int | None = None, +) -> None: + watcher = _PluginTreeWatcher(runner.plugin_dir) + reload_count = 0 + + async def reload_and_maybe_rerun(*, announce: str | None) -> None: + if announce: + click.echo(announce) + if not await runner.reload(): + return + if event_text is not None: + await runner.dispatch_text(event_text) + + async def watch_loop(stop_event: asyncio.Event) -> None: + nonlocal reload_count + while not stop_event.is_set(): + await asyncio.sleep(watch_poll_interval) + changes = watcher.poll_changes() + if not changes: + continue + await reload_and_maybe_rerun( + announce=( + f"检测到文件变更,重新加载插件:{_format_watch_changes(changes)}" + ) + ) + reload_count += 1 + if max_watch_reloads is not None and reload_count >= max_watch_reloads: + stop_event.set() + return + + stop_event = asyncio.Event() + watch_task: asyncio.Task[None] | None = None + try: + await reload_and_maybe_rerun( + announce=( + "watch 模式已启动,监听插件目录变更。" + if event_text is not None + else "watch 模式已启动,监听插件目录变更并按需热重载。" + ) + ) + if max_watch_reloads == 0: + return + watch_task = asyncio.create_task(watch_loop(stop_event)) + if interactive: + click.echo( + "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit" + ) + while not stop_event.is_set(): + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + break + text = line.strip() + if not text: + continue + if _handle_dev_meta_command(text, runner.state): + if text in {"/exit", "/quit"}: + break + continue + await runner.dispatch_text(text) + stop_event.set() + return + await stop_event.wait() + finally: + stop_event.set() + if 
watch_task is not None: + watch_task.cancel() + try: + await watch_task + except asyncio.CancelledError: + pass + await runner.close() + + +async def _run_local_dev( + *, + plugin_dir: Path, + event_text: str | None, + interactive: bool, + watch: bool, + session_id: str, + user_id: str, + platform: str, + group_id: str | None, + event_type: str, + watch_poll_interval: float = WATCH_POLL_INTERVAL_SECONDS, + max_watch_reloads: int | None = None, +) -> None: + from .testing import ( + PluginHarness, + StdoutPlatformSink, + _PluginExecutionError, + _PluginLoadError, + ) + + state = { + "session_id": session_id, + "user_id": user_id, + "platform": platform, + "group_id": group_id, + "event_type": event_type, + } + if watch: + runner = _ReloadableLocalDevRunner( + plugin_dir=plugin_dir, + state=state, + plugin_load_error=_PluginLoadError, + plugin_execution_error=_PluginExecutionError, + plugin_harness=PluginHarness, + stdout_platform_sink=StdoutPlatformSink, + ) + await _run_local_dev_watch( + runner=runner, + event_text=event_text, + interactive=interactive, + watch_poll_interval=watch_poll_interval, + max_watch_reloads=max_watch_reloads, + ) + return + + sink = StdoutPlatformSink(stream=sys.stdout) + harness = PluginHarness.from_plugin_dir( + plugin_dir, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + platform_sink=sink, + ) + try: + async with harness: + if interactive: + click.echo( + "本地交互模式已启动。可用命令:/session /user /platform /group /private /event /exit" + ) + while True: + line = await asyncio.to_thread(sys.stdin.readline) + if not line: + break + text = line.strip() + if not text: + continue + if _handle_dev_meta_command(text, state): + if text in {"/exit", "/quit"}: + break + continue + await harness.dispatch_text( + text, + session_id=str(state["session_id"]), + user_id=str(state["user_id"]), + platform=str(state["platform"]), + group_id=typing.cast(str | None, state["group_id"]), + 
event_type=str(state["event_type"]), + ) + return + assert event_text is not None + await harness.dispatch_text( + event_text, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + ) + except _PluginLoadError as exc: + raise _CliPluginLoadError(str(exc)) from exc + except _PluginExecutionError as exc: + raise _CliPluginExecutionError(str(exc)) from exc + + +def _handle_dev_meta_command(command: str, state: dict[str, Any]) -> bool: + if command in {"/exit", "/quit"}: + return True + if command.startswith("/session "): + state["session_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 session_id -> {state['session_id']}") + return True + if command.startswith("/user "): + state["user_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 user_id -> {state['user_id']}") + return True + if command.startswith("/platform "): + state["platform"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 platform -> {state['platform']}") + return True + if command.startswith("/group "): + state["group_id"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 group_id -> {state['group_id']}") + return True + if command == "/private": + state["group_id"] = None + click.echo("已切换为私聊上下文") + return True + if command.startswith("/event "): + state["event_type"] = command.split(" ", 1)[1].strip() + click.echo(f"切换 event_type -> {state['event_type']}") + return True + return False + + +def _slugify_plugin_name(value: str) -> str: + slug = re.sub(r"[^a-zA-Z0-9]+", "_", value).strip("_").lower() + return slug or "my_plugin" + + +def _normalize_plugin_name(value: str) -> str: + normalized = _slugify_plugin_name(value) + if normalized.startswith("astrbot_plugin_"): + return normalized + normalized = normalized.removeprefix("astrbot_plugin") + normalized = normalized.strip("_") + suffix = normalized or "my_plugin" + return f"astrbot_plugin_{suffix}" + + +def _class_name_for_plugin(value: str) -> str: + parts = [part for 
part in re.split(r"[^a-zA-Z0-9]+", value) if part] + if not parts: + return "MyPlugin" + return "".join(part[:1].upper() + part[1:] for part in parts) + + +def _sanitize_build_part(value: str) -> str: + sanitized = re.sub(r"[^a-zA-Z0-9._-]+", "_", value).strip("._-") + return sanitized or "artifact" + + +def _parse_init_agents( + _ctx: click.Context, + _param: click.Parameter, + value: str | None, +) -> tuple[str, ...]: + if value is None: + return () + + normalized_agents: list[str] = [] + seen: set[str] = set() + invalid_agents: list[str] = [] + for raw_agent in value.split(","): + candidate = raw_agent.strip().lower() + if not candidate: + invalid_agents.append("") + continue + if candidate not in SUPPORTED_INIT_AGENTS: + invalid_agents.append(raw_agent.strip()) + continue + if candidate in seen: + continue + seen.add(candidate) + normalized_agents.append(candidate) + + if invalid_agents: + supported = ", ".join(SUPPORTED_INIT_AGENTS) + invalid = ", ".join(invalid_agents) + raise click.BadParameter(f"仅支持以下 agent: {supported};非法值: {invalid}") + return tuple(normalized_agents) + + +def _render_init_plugin_yaml( + *, + plugin_name: str, + display_name: str, + desc: str, + author: str, + version: str, +) -> str: + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + name: {plugin_name} + display_name: {display_name} + desc: {desc} + author: {author} + version: {version} + runtime: + python: "{python_version}" + components: + - class: main:{class_name} + """ + ) + + +def _render_init_main_py(*, plugin_name: str) -> str: + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + from astrbot_sdk import Context, MessageEvent, Star, on_command + + + class {class_name}(Star): + @on_command("hello") + async def hello(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("Hello, World!") + """ + ) + + +def _render_init_readme(*, plugin_name: 
str) -> str: + return dedent( + f"""\ + # {plugin_name} + + 一个最小可运行的 AstrBot SDK v4 插件。 + + ## 目录结构 + + ``` + . + ├── plugin.yaml + ├── requirements.txt + ├── main.py + └── tests + └── test_plugin.py + ``` + + ## 本地开发 + + ```bash + astrbot-sdk validate + astrbot-sdk dev --local --event-text hello + astrbot-sdk dev --local --watch --event-text hello + ``` + + ## 运行测试 + + ```bash + python -m pytest tests/test_plugin.py -v + ``` + """ + ) + + +def _render_init_test_py(*, plugin_name: str) -> str: + class_name = _class_name_for_plugin(plugin_name) + return dedent( + f"""\ + from pathlib import Path + + import pytest + + from astrbot_sdk.testing import MockContext, MockMessageEvent, PluginHarness + from main import {class_name} + + + @pytest.mark.asyncio + async def test_hello_handler(): + plugin = {class_name}() + ctx = MockContext( + plugin_id="{plugin_name}", + plugin_metadata={{"display_name": "{class_name}"}}, + ) + event = MockMessageEvent(text="/hello", context=ctx) + + await plugin.hello(event, ctx) + + assert event.replies == ["Hello, World!"] + ctx.platform.assert_sent("Hello, World!") + + + @pytest.mark.asyncio + async def test_hello_dispatch(): + plugin_dir = Path(__file__).resolve().parents[1] + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + records = await harness.dispatch_text("hello") + + assert any(record.text == "Hello, World!" for record in records) + """ + ) + + +def _plugin_root_hint_for_agent(agent: str) -> str: + skill_dir = INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME + return "/".join(".." for _ in skill_dir.parts) or "." 
def _build_agent_template_context(
    *,
    plugin_name: str,
    display_name: str,
    agent: str,
) -> dict[str, str]:
    """Assemble the substitution variables for one agent's skill templates."""
    return {
        "plugin_name": plugin_name,
        "display_name": display_name,
        "class_name": _class_name_for_plugin(plugin_name),
        "skill_name": f"{plugin_name}_project",
        "plugin_root": _plugin_root_hint_for_agent(agent),
        "agent_name": agent,
        "agent_display_name": INIT_AGENT_DISPLAY_NAMES[agent],
        "skill_dir_name": INIT_SKILL_TEMPLATE_NAME,
    }


def _render_template_text(template_text: str, context: dict[str, str]) -> str:
    """Substitute ``{{ var }}`` placeholders; unknown variables are an error."""

    def replace(match: re.Match[str]) -> str:
        key = match.group(1)
        if key not in context:
            raise _CliPluginValidationError(f"agent 模板变量未定义:{key}")
        return context[key]

    return _TEMPLATE_VARIABLE_PATTERN.sub(replace, template_text)


def _copy_rendered_template_tree(
    source_dir: Traversable,
    target_dir: Path,
    *,
    context: dict[str, str],
) -> None:
    """Recursively copy a packaged template tree, rendering each text file."""
    target_dir.mkdir(parents=True, exist_ok=True)
    # Sort for deterministic generation order across filesystems.
    for entry in sorted(source_dir.iterdir(), key=lambda item: item.name):
        destination = target_dir / entry.name
        if entry.is_dir():
            _copy_rendered_template_tree(entry, destination, context=context)
            continue
        rendered = _render_template_text(
            entry.read_text(encoding="utf-8"),
            context,
        )
        destination.write_text(rendered, encoding="utf-8")
+ ) + + +def _ensure_plugin_dir_exists(plugin_dir: Path) -> Path: + resolved = plugin_dir.resolve() + if not resolved.exists() or not resolved.is_dir(): + raise _CliPluginValidationError(f"插件目录不存在:{plugin_dir}") + return resolved + + +def _resolve_dev_plugin_dir(plugin_dir: Path | None) -> Path: + if plugin_dir is not None: + return plugin_dir + current_dir = Path.cwd() + if (current_dir / "plugin.yaml").exists(): + return Path(".") + raise click.BadParameter( + "未提供 --plugin-dir,且当前目录未找到 plugin.yaml", + param_hint="--plugin-dir", + ) + + +def _load_validated_plugin(plugin_dir: Path) -> tuple[Any, Any]: + resolved_dir = _ensure_plugin_dir_exists(plugin_dir) + plugin = load_plugin_spec(resolved_dir) + try: + validate_plugin_spec(plugin) + except ValueError as exc: + raise _CliPluginValidationError(str(exc)) from exc + + loaded = load_plugin(plugin) + if not loaded.instances: + raise _CliPluginValidationError( + "未找到可加载的组件,请检查 plugin.yaml 中的 components" + ) + return plugin, loaded + + +def _build_kind(plugin: Any) -> str: + return ( + "legacy-main" + if bool(plugin.manifest_data.get("__legacy_main__")) + else "plugin-yaml" + ) + + +def _path_is_within(path: Path, root: Path) -> bool: + try: + path.resolve().relative_to(root.resolve()) + except ValueError: + return False + return True + + +def _iter_build_files(plugin_dir: Path, output_dir: Path) -> list[Path]: + files: list[Path] = [] + for path in sorted(plugin_dir.rglob("*")): + if path.is_dir(): + continue + if _path_is_within(path, output_dir): + continue + relative = path.relative_to(plugin_dir) + if any(part in BUILD_EXCLUDED_DIRS for part in relative.parts[:-1]): + continue + if relative.name in BUILD_EXCLUDED_FILES: + continue + if path.suffix in {".pyc", ".pyo"}: + continue + files.append(path) + return files + + +def _prompt_nonempty_text(prompt: str) -> str: + while True: + value = click.prompt(prompt, type=str, default="", show_default=False).strip() + if value: + return value + 
def _collect_init_metadata(name: str | None) -> tuple[str, str, str, str]:
    """Gather ``(name, author, desc, version)``, prompting only when needed.

    A name passed on the command line short-circuits the prompts and uses
    empty author/description with version 1.0.0.
    """
    if name is not None:
        return name, "", "", "1.0.0"

    plugin_name = _prompt_nonempty_text("插件名字")
    author = click.prompt("作者", type=str, default="", show_default=False).strip()
    desc = click.prompt("描述", type=str, default="", show_default=False).strip()
    version = click.prompt("版本", type=str, default="1.0.0", show_default=True).strip()
    return plugin_name, author, desc, version or "1.0.0"


def _init_plugin(name: str | None, agents: tuple[str, ...] = ()) -> None:
    """Create a plugin skeleton directory with manifest, code, docs and tests.

    Raises:
        _CliPluginValidationError: when the target directory already exists.
    """
    raw_name, author, desc, version = _collect_init_metadata(name)
    plugin_name = _normalize_plugin_name(raw_name)
    target_dir = Path(plugin_name)
    if target_dir.exists():
        raise _CliPluginValidationError(f"目标目录已存在:{target_dir}")

    display_name = raw_name.strip() or plugin_name
    target_dir.mkdir(parents=True, exist_ok=False)
    (target_dir / "tests").mkdir()
    (target_dir / "plugin.yaml").write_text(
        _render_init_plugin_yaml(
            plugin_name=plugin_name,
            display_name=display_name,
            desc=desc,
            author=author,
            version=version,
        ),
        encoding="utf-8",
    )
    (target_dir / "requirements.txt").write_text("", encoding="utf-8")
    (target_dir / "main.py").write_text(
        _render_init_main_py(plugin_name=plugin_name),
        encoding="utf-8",
    )
    (target_dir / "README.md").write_text(
        _render_init_readme(plugin_name=plugin_name),
        encoding="utf-8",
    )
    (target_dir / "tests" / "test_plugin.py").write_text(
        _render_init_test_py(plugin_name=plugin_name),
        encoding="utf-8",
    )
    _render_init_agent_templates(
        target_dir=target_dir,
        plugin_name=plugin_name,
        display_name=display_name,
        agents=agents,
    )
    click.echo(f"已创建插件:{target_dir}")
    if agents:
        generated_paths = ", ".join(
            str(INIT_AGENT_SKILL_ROOTS[agent] / INIT_SKILL_TEMPLATE_NAME)
            for agent in agents
        )
        click.echo(f"已生成项目级 skill:{generated_paths}")
    click.echo("后续命令:")
    click.echo(f"  astrbot-sdk validate --plugin-dir {target_dir}")
    click.echo(
        f"  astrbot-sdk dev --local --plugin-dir {target_dir} --event-text hello"
    )


def _validate_plugin(plugin_dir: Path) -> None:
    """Validate a plugin directory and print a discovery summary."""
    plugin, loaded = _load_validated_plugin(plugin_dir)
    click.echo(f"校验通过:{plugin.name}")
    click.echo(f"kind: {_build_kind(plugin)}")
    click.echo(f"plugin_dir: {plugin.plugin_dir}")
    click.echo(f"handlers: {len(loaded.handlers)}")
    click.echo(f"capabilities: {len(loaded.capabilities)}")
    click.echo(f"instances: {len(loaded.instances)}")


def _build_plugin(plugin_dir: Path, output_dir: Path | None) -> None:
    """Validate a plugin and package it into ``<name>-<version>.zip``."""
    plugin, _ = _load_validated_plugin(plugin_dir)
    build_dir = (output_dir or (plugin.plugin_dir / "dist")).resolve()
    build_dir.mkdir(parents=True, exist_ok=True)

    version = _sanitize_build_part(str(plugin.manifest_data.get("version") or "0.0.0"))
    archive_name = f"{_sanitize_build_part(plugin.name)}-{version}.zip"
    archive_path = build_dir / archive_name

    with zipfile.ZipFile(
        archive_path,
        mode="w",
        compression=zipfile.ZIP_DEFLATED,
    ) as archive:
        for path in _iter_build_files(plugin.plugin_dir, build_dir):
            archive.write(path, arcname=path.relative_to(plugin.plugin_dir))

    click.echo(f"构建完成:{archive_path}")
    click.echo(f"artifact: {archive_path}")


@click.group()
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose output")
@click.pass_context
def cli(ctx, verbose: bool) -> None:
    """AstrBot SDK CLI。"""
    ctx.ensure_object(dict)
    ctx.obj["verbose"] = verbose
    setup_logger(verbose)
transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout) + try: + _run_async_entrypoint( + run_supervisor(plugins_dir=plugins_dir, stdout=transport_stdout), + log_message=f"启动插件主管进程,插件目录:{plugins_dir}", + context={"plugins_dir": plugins_dir}, + ) + finally: + if opened_stdout is not None: + opened_stdout.close() + + +@cli.command() +@click.argument("name", type=str, required=False) +@click.option( + "--agents", + callback=_parse_init_agents, + metavar="claude,codex,opencode", + help="Generate per-agent project templates, comma-separated: claude,codex,opencode", +) +def init(name: str | None, agents: tuple[str, ...]) -> None: + """Create a new plugin skeleton in the target directory.""" + _run_sync_entrypoint( + lambda: _init_plugin(name, agents), + log_message=f"创建插件:{name or ''}", + context={"target": name or ""}, + ) + + +@cli.command() +@click.option( + "--plugin-dir", + default=".", + show_default=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to validate", +) +def validate(plugin_dir: Path) -> None: + """Validate plugin manifest, imports and handler discovery.""" + _run_sync_entrypoint( + lambda: _validate_plugin(plugin_dir), + log_message=f"校验插件目录:{plugin_dir}", + context={"plugin_dir": plugin_dir}, + ) + + +@cli.command() +@click.option( + "--plugin-dir", + default=".", + show_default=True, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to package", +) +@click.option( + "--output-dir", + default=None, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Directory for the build artifact, defaults to /dist", +) +def build(plugin_dir: Path, output_dir: Path | None) -> None: + """Validate and package a plugin into a zip artifact.""" + _run_sync_entrypoint( + lambda: _build_plugin(plugin_dir, output_dir), + log_message=f"构建插件包:{plugin_dir}", + context={"plugin_dir": plugin_dir, "output_dir": output_dir}, + ) + + +@cli.command() 
+@click.option( + "--plugin-dir", + required=False, + default=None, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), + help="Plugin directory to run locally, defaults to current directory when plugin.yaml exists", +) +@click.option("--local", "local_mode", is_flag=True, help="Run against local mock core") +@click.option( + "--standalone", + "standalone_mode", + is_flag=True, + help="Deprecated alias of --local", +) +@click.option("--event-text", type=str, help="Single message text to dispatch") +@click.option("--interactive", is_flag=True, help="Read follow-up messages from stdin") +@click.option( + "--watch", + is_flag=True, + help="Reload the local harness when plugin files change", +) +@click.option("--session-id", default="local-session", show_default=True) +@click.option("--user-id", default="local-user", show_default=True) +@click.option("--platform", "platform_name", default="test", show_default=True) +@click.option("--group-id", default=None) +@click.option("--event-type", default="message", show_default=True) +def dev( + plugin_dir: Path | None, + local_mode: bool, + standalone_mode: bool, + event_text: str | None, + interactive: bool, + watch: bool, + session_id: str, + user_id: str, + platform_name: str, + group_id: str | None, + event_type: str, +) -> None: + """Run a plugin against the local mock core for development.""" + if not (local_mode or standalone_mode): + raise click.BadParameter("当前 dev 只支持 --local/--standalone 模式") + if interactive and event_text: + raise click.BadParameter("--interactive 与 --event-text 不能同时使用") + if not interactive and not event_text: + raise click.BadParameter("请提供 --event-text,或改用 --interactive") + resolved_plugin_dir = _resolve_dev_plugin_dir(plugin_dir) + _run_async_entrypoint( + _run_local_dev( + plugin_dir=resolved_plugin_dir, + event_text=event_text, + interactive=interactive, + watch=watch, + session_id=session_id, + user_id=user_id, + platform=platform_name, + group_id=group_id, + 
event_type=event_type, + ), + log_message=f"启动本地开发模式:{resolved_plugin_dir}", + context={ + "plugin_dir": resolved_plugin_dir, + "session_id": session_id, + "platform": platform_name, + "event_type": event_type, + }, + ) + + +@cli.command(hidden=True) +@click.option( + "--plugin-dir", + required=False, + type=click.Path(file_okay=False, dir_okay=True, path_type=Path), +) +@click.option( + "--group-metadata", + required=False, + type=click.Path(file_okay=True, dir_okay=False, path_type=Path), +) +@click.option( + "--protocol-stdout", + default=None, + type=str, + help="Redirect runtime protocol stdout to console, silent, or a file path", +) +def worker( + plugin_dir: Path | None, + group_metadata: Path | None, + protocol_stdout: str | None, +) -> None: + """Internal command used by the supervisor to start a worker.""" + if plugin_dir is None and group_metadata is None: + raise click.UsageError("Either --plugin-dir or --group-metadata is required") + if plugin_dir is not None and group_metadata is not None: + raise click.UsageError( + "--plugin-dir and --group-metadata are mutually exclusive" + ) + + target = str(group_metadata or plugin_dir) + transport_stdout, opened_stdout = _resolve_protocol_stdout(protocol_stdout) + if group_metadata is not None: + entrypoint = run_plugin_worker( + group_metadata=group_metadata, + stdout=transport_stdout, + ) + else: + entrypoint = run_plugin_worker( + plugin_dir=plugin_dir, + stdout=transport_stdout, + ) + try: + _run_async_entrypoint( + entrypoint, + log_message=f"启动插件工作进程:{target}", + log_level="debug", + context={"plugin_dir": plugin_dir}, + ) + finally: + if opened_stdout is not None: + opened_stdout.close() + + +@cli.command(hidden=True) +@click.option("--port", default=8765, type=int, help="WebSocket server port") +def websocket(port: int) -> None: + """WebSocket runtime entrypoint kept for standalone bridge scenarios.""" + _run_async_entrypoint( + run_websocket_server(port=port), + log_message=f"启动 WebSocket 
服务器,端口:{port}", + context={"port": port}, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/__init__.py b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py new file mode 100644 index 0000000000..8b1d5d4b0a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/__init__.py @@ -0,0 +1,109 @@ +"""Native v4 capability clients. + +These clients provide the narrow, typed surface exposed by `Context` for +calling remote capabilities. They handle capability names, payload shaping, +and result decoding, without exposing protocol or transport details. + +Migration shims and higher-level orchestration stay outside these native +capability clients so `Context` keeps a narrow, stable surface. + +当前公开客户端: + - LLMClient: 文本/结构化/流式 LLM 调用 + - MemoryClient: 记忆搜索、保存、读取、删除 + - DBClient: 键值存储 get/set/delete/list + - FileServiceClient: 文件令牌注册与解析 + - PlatformClient: 平台消息发送与成员查询 + - ProviderClient: Provider 元信息与专用 provider proxy + - PersonaManagerClient: 人格管理 + - ConversationManagerClient: 对话管理 + - KnowledgeBaseManagerClient: 知识库管理 + - HTTPClient: Web API 注册 + - MetadataClient: 插件元数据查询 + - SkillClient: 运行时注册插件 skill +""" + +from .db import DBClient +from .files import FileRegistration, FileServiceClient +from .http import HTTPClient +from .llm import ChatMessage, LLMClient, LLMResponse +from .managers import ( + ConversationCreateParams, + ConversationManagerClient, + ConversationRecord, + ConversationUpdateParams, + KnowledgeBaseCreateParams, + KnowledgeBaseManagerClient, + KnowledgeBaseRecord, + MessageHistoryManagerClient, + MessageHistoryPage, + MessageHistoryRecord, + MessageHistorySender, + PersonaCreateParams, + PersonaManagerClient, + PersonaRecord, + PersonaUpdateParams, +) +from .mcp import MCPManagerClient, MCPServerRecord, MCPServerScope, MCPSession +from .memory import MemoryClient +from .metadata import MetadataClient, PluginMetadata, StarMetadata +from .permission import PermissionCheckResult, PermissionClient, PermissionManagerClient +from .platform import 
PlatformClient, PlatformError, PlatformStats, PlatformStatus +from .provider import ( + ManagedProviderRecord, + ProviderChangeEvent, + ProviderClient, + ProviderManagerClient, +) +from .registry import HandlerMetadata, RegistryClient +from .session import SessionPluginManager, SessionServiceManager +from .skills import SkillClient, SkillRegistration + +__all__ = [ + "ChatMessage", + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationRecord", + "ConversationUpdateParams", + "DBClient", + "FileRegistration", + "FileServiceClient", + "HTTPClient", + "KnowledgeBaseCreateParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + "MessageHistorySender", + "LLMClient", + "LLMResponse", + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", + "MemoryClient", + "ManagedProviderRecord", + "MetadataClient", + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", + "PlatformClient", + "PlatformError", + "PlatformStats", + "PlatformStatus", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", + "ProviderChangeEvent", + "ProviderClient", + "ProviderManagerClient", + "PluginMetadata", + "StarMetadata", + "HandlerMetadata", + "RegistryClient", + "SessionPluginManager", + "SessionServiceManager", + "SkillClient", + "SkillRegistration", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py new file mode 100644 index 0000000000..da2bca6dad --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/_proxy.py @@ -0,0 +1,188 @@ +"""能力代理模块。 + +提供 CapabilityProxy 类,作为客户端与 Peer 之间的中间层,负责: +- 检查远程能力是否可用 +- 验证流式调用支持 +- 统一封装 invoke 和 invoke_stream 调用 + +设计说明: + CapabilityProxy 是新版架构的核心组件。每个专用客户端 (LLMClient, DBClient 等) + 都通过 CapabilityProxy 与远程通信,并在发起调用时绑定当前插件身份, + 让运行时把调用者信息放进协议层而不是业务 payload。 + +使用示例: + proxy = CapabilityProxy(peer) + + # 
普通调用 + result = await proxy.call("llm.chat", {"prompt": "hello"}) + + # 流式调用 + async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}): + print(delta["text"]) +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from typing import Any, Protocol + +from .._internal.invocation_context import caller_plugin_scope +from ..errors import AstrBotError + + +class _CapabilityDescriptorLike(Protocol): + supports_stream: bool | None + + +class _CapabilityPeerLike(Protocol): + remote_capability_map: Mapping[str, _CapabilityDescriptorLike] + remote_peer: Any | None + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: ... + + async def invoke_stream( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> AsyncIterator[Any]: ... + + +class CapabilityProxy: + """能力代理类,封装 Peer 的能力调用接口。 + + 负责在调用前验证能力可用性和流式支持,提供统一的 call/stream 接口。 + + Attributes: + _peer: 底层 Peer 实例,负责实际的 RPC 通信 + """ + + def __init__( + self, + peer: _CapabilityPeerLike, + caller_plugin_id: str | None = None, + request_scope_id: str | None = None, + ) -> None: + """初始化能力代理。 + + Args: + peer: Peer 实例,提供 remote_capability_map 和 invoke/invoke_stream 方法 + """ + self._peer = peer + self._caller_plugin_id = caller_plugin_id + self._request_scope_id = request_scope_id + + def _get_descriptor(self, name: str): + """获取能力描述符。 + + Args: + name: 能力名称,如 "llm.chat" + + Returns: + 能力描述符,若不存在则返回 None + """ + capability_map = getattr(self._peer, "remote_capability_map", {}) + if not isinstance(capability_map, Mapping): + return None + return capability_map.get(name) + + def _remote_initialized(self) -> bool: + peer_attrs = getattr(self._peer, "__dict__", None) + if not isinstance(peer_attrs, dict): + return False + + # Avoid getattr() here: MagicMock synthesizes truthy child attributes and + # makes an uninitialized peer 
look ready. + remote_peer = peer_attrs.get("remote_peer") + capability_map = peer_attrs.get("remote_capability_map") + return bool(remote_peer) or ( + isinstance(capability_map, Mapping) and bool(capability_map) + ) + + def _ensure_available(self, name: str, *, stream: bool) -> None: + """确保能力可用且支持指定的调用模式。 + + Args: + name: 能力名称 + stream: 是否需要流式支持 + + Raises: + AstrBotError: 能力不存在或流式不支持 + """ + descriptor = self._get_descriptor(name) + if descriptor is None: + if self._remote_initialized(): + raise AstrBotError.capability_not_found(name) + return + if stream and not descriptor.supports_stream: + raise AstrBotError.invalid_input(f"{name} 不支持 stream=true") + + def _prepare_payload(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + if ( + not isinstance(self._request_scope_id, str) + or not self._request_scope_id + or not name.startswith("system.event.") + ): + return payload + scoped_payload = dict(payload) + scoped_payload.setdefault("_request_scope_id", self._request_scope_id) + return scoped_payload + + async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + """执行普通能力调用(非流式)。 + + Args: + name: 能力名称,如 "llm.chat", "db.get" + payload: 调用参数字典 + + Returns: + 调用结果字典 + + Raises: + AstrBotError: 能力不存在或调用失败 + + 示例: + result = await proxy.call("llm.chat", {"prompt": "hello"}) + print(result["text"]) + """ + self._ensure_available(name, stream=False) + prepared_payload = self._prepare_payload(name, payload) + with caller_plugin_scope(self._caller_plugin_id): + return await self._peer.invoke(name, prepared_payload, stream=False) + + async def stream( + self, + name: str, + payload: dict[str, Any], + ) -> AsyncIterator[dict[str, Any]]: + """执行流式能力调用。 + + Args: + name: 能力名称,如 "llm.stream_chat" + payload: 调用参数字典 + + Yields: + 每个增量数据块(phase="delta" 时的 data 字段) + + Raises: + AstrBotError: 能力不存在或不支持流式 + + 示例: + async for delta in proxy.stream("llm.stream_chat", {"prompt": "hello"}): + print(delta["text"], end="") + """ + 
self._ensure_available(name, stream=True) + prepared_payload = self._prepare_payload(name, payload) + with caller_plugin_scope(self._caller_plugin_id): + event_stream = await self._peer.invoke_stream(name, prepared_payload) + async for event in event_stream: + if event.phase == "delta": + yield event.data diff --git a/astrbot-sdk/src/astrbot_sdk/clients/db.py b/astrbot-sdk/src/astrbot_sdk/clients/db.py new file mode 100644 index 0000000000..bf2783490d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/db.py @@ -0,0 +1,161 @@ +"""数据库客户端模块。 + +提供键值存储能力,用于持久化插件数据。 + +功能说明: + - 数据永久存储,除非用户显式删除 + - 值类型支持任意 JSON 数据 + - 支持前缀查询键列表 + - 支持批量读写 + - 支持订阅变更事件 +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any + +from ._proxy import CapabilityProxy + + +class DBClient: + """键值数据库客户端。 + + 提供插件数据的持久化存储能力,数据永久保存直到显式删除。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化数据库客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def get(self, key: str) -> Any | None: + """获取指定键的值。 + + Args: + key: 数据键名 + + Returns: + 存储的值,若键不存在则返回 None + + 示例: + data = await ctx.db.get("user_settings") + if data: + print(data["theme"]) + """ + output = await self._proxy.call("db.get", {"key": key}) + return output.get("value") + + async def set(self, key: str, value: Any) -> None: + """设置键值对。 + + Args: + key: 数据键名 + value: 要存储的 JSON 值 + + 示例: + await ctx.db.set("user_settings", {"theme": "dark", "lang": "zh"}) + await ctx.db.set("greeted", True) + """ + await self._proxy.call("db.set", {"key": key, "value": value}) + + async def delete(self, key: str) -> None: + """删除指定键的数据。 + + Args: + key: 要删除的数据键名 + + 示例: + await ctx.db.delete("user_settings") + """ + await self._proxy.call("db.delete", {"key": key}) + + async def list(self, prefix: str | None = None) -> list[str]: + """列出匹配前缀的所有键。 + + Args: + prefix: 键前缀过滤,None 
表示列出所有键 + + Returns: + 匹配的键名列表 + + 示例: + # 列出所有用户设置相关的键 + keys = await ctx.db.list("user_") + # ["user_settings", "user_profile", "user_history"] + """ + output = await self._proxy.call("db.list", {"prefix": prefix}) + keys = output.get("keys") + if not isinstance(keys, (list, tuple)): + return [] + return [str(item) for item in keys] + + async def get_many(self, keys: Sequence[str]) -> dict[str, Any | None]: + """批量获取多个键的值。 + + Args: + keys: 要读取的键列表 + + Returns: + 一个 dict,key 为键名,value 为对应值(不存在则为 None) + + 示例: + values = await ctx.db.get_many(["user:1", "user:2"]) + if values["user:1"] is None: + print("user:1 missing") + """ + output = await self._proxy.call("db.get_many", {"keys": list(keys)}) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return {} + result: dict[str, Any | None] = {} + for item in items: + if not isinstance(item, dict): + continue + key = item.get("key") + if not isinstance(key, str): + continue + result[key] = item.get("value") + return result + + async def set_many( + self, items: Mapping[str, Any] | Sequence[tuple[str, Any]] + ) -> None: + """批量写入多个键值对。 + + Args: + items: 键值对集合(dict 或二元组序列) + + 示例: + await ctx.db.set_many({"user:1": {"name": "a"}, "user:2": {"name": "b"}}) + """ + if isinstance(items, Mapping): + pairs = list(items.items()) + else: + pairs = list(items) + + payload_items: list[dict[str, Any]] = [ + {"key": str(key), "value": value} for key, value in pairs + ] + await self._proxy.call("db.set_many", {"items": payload_items}) + + def watch(self, prefix: str | None = None) -> AsyncIterator[dict[str, Any]]: + """订阅 KV 变更事件(流式)。 + + Args: + prefix: 键前缀过滤;None 表示订阅所有键 + + Yields: + 变更事件 dict:{"op": "set"|"delete", "key": str, "value": Any|None} + + 示例: + async for event in ctx.db.watch("user:"): + print(event["op"], event["key"]) + """ + return self._proxy.stream("db.watch", {"prefix": prefix}) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/files.py b/astrbot-sdk/src/astrbot_sdk/clients/files.py new 
file mode 100644 index 0000000000..3a1dd6f6f3 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/files.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from ._proxy import CapabilityProxy + + +@dataclass(slots=True) +class FileRegistration: + token: str + url: str + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> FileRegistration: + return cls( + token=str(payload.get("token", "")), + url=str(payload.get("url", "")), + ) + + +class FileServiceClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def register_file( + self, + path: str, + timeout: float | None = None, + ) -> str: + output = await self._proxy.call( + "system.file.register", + {"path": str(path), "timeout": timeout}, + ) + return FileRegistration.from_payload(output).token + + async def handle_file(self, token: str) -> str: + output = await self._proxy.call( + "system.file.handle", + {"token": str(token)}, + ) + return str(output.get("path", "")) + + async def _register_file_url( + self, + path: str, + timeout: float | None = None, + ) -> str: + output = await self._proxy.call( + "system.file.register", + {"path": str(path), "timeout": timeout}, + ) + return FileRegistration.from_payload(output).url diff --git a/astrbot-sdk/src/astrbot_sdk/clients/http.py b/astrbot-sdk/src/astrbot_sdk/clients/http.py new file mode 100644 index 0000000000..efec135e8c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/http.py @@ -0,0 +1,165 @@ +"""HTTP 客户端模块。 + +提供 HTTP API 注册能力。 + +功能说明: + - 注册自定义 Web API 端点 + - 支持异步请求处理 + - 与宿主 Web 服务器集成 + +设计说明: + 由于跨进程架构,handler 函数无法直接序列化传递。 + 插件需要先声明处理 HTTP 请求的 capability,然后注册路由到 capability 的映射。 + 当前插件身份由运行时在协议层透传,客户端 payload 不暴露 `plugin_id`。 + + 调用流程: + HTTP 请求 → 宿主 Web 服务器 → 查找 route 映射 → invoke capability → Worker 执行 handler → 返回响应 + +示例: + # 插件声明处理 HTTP 请求的 capability + @provide_capability( + name="my_plugin.http_handler", + description="处理 
/my-api 的 HTTP 请求", + input_schema={...}, + output_schema={...} + ) + async def handle_http_request(request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"result": "ok"}} + + # 注册路由 → capability 映射 + await ctx.http.register_api( + route="/my-api", + methods=["GET", "POST"], + handler_capability="my_plugin.http_handler", + description="我的 API" + ) +""" + +from __future__ import annotations + +from typing import Any + +from ..decorators import get_capability_meta +from ..errors import AstrBotError +from ._proxy import CapabilityProxy + + +def _resolve_handler_capability( + handler_capability: str | None, + handler: Any | None, +) -> str: + if handler_capability and handler is not None: + raise AstrBotError.invalid_input( + "register_api 不能同时提供 handler_capability 和 handler", + hint="请二选一:传 capability 名称字符串,或传 @provide_capability 标记的方法", + ) + if handler_capability: + return handler_capability + if handler is None: + raise AstrBotError.invalid_input( + "register_api 需要提供 handler_capability 或 handler", + hint="示例:handler_capability='demo.http_handler' 或 handler=self.http_handler_capability", + ) + target = getattr(handler, "__func__", handler) + meta = get_capability_meta(target) + if meta is None: + raise AstrBotError.invalid_input( + "register_api(handler=...) 需要传入使用 @provide_capability 声明的方法", + hint="请先用 @provide_capability(name='demo.http_handler', ...) 
标记该方法", + ) + return meta.descriptor.name + + +class HTTPClient: + """HTTP 能力客户端。 + + 提供 Web API 注册能力,允许插件暴露自定义 HTTP 端点。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化 HTTP 客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def register_api( + self, + route: str, + handler_capability: str | None = None, + *, + handler: Any | None = None, + methods: list[str] | None = None, + description: str = "", + ) -> None: + """注册 Web API 端点。 + + Args: + route: API 路由路径(如 "/my-api") + handler_capability: 处理此路由的 capability 名称 + handler: 使用 @provide_capability 标记的方法引用 + methods: HTTP 方法列表,默认 ["GET"] + description: API 描述 + + 示例: + await ctx.http.register_api( + route="/my-api", + handler_capability="my_plugin.http_handler", + methods=["GET", "POST"], + description="我的 API" + ) + """ + if methods is None: + methods = ["GET"] + resolved_handler = _resolve_handler_capability(handler_capability, handler) + + await self._proxy.call( + "http.register_api", + { + "route": route, + "methods": methods, + "handler_capability": resolved_handler, + "description": description, + }, + ) + + async def unregister_api( + self, route: str, methods: list[str] | None = None + ) -> None: + """注销 Web API 端点。 + + Args: + route: API 路由路径 + methods: HTTP 方法列表,None 表示所有方法 + + 示例: + await ctx.http.unregister_api("/my-api") + """ + if methods is None: + methods = [] + + await self._proxy.call( + "http.unregister_api", + {"route": route, "methods": methods}, + ) + + async def list_apis(self) -> list[dict[str, Any]]: + """列出当前插件注册的所有 API。 + + Returns: + API 列表,每项包含 route, methods, description + + 示例: + apis = await ctx.http.list_apis() + for api in apis: + print(f"{api['route']}: {api['methods']}") + """ + output = await self._proxy.call( + "http.list_apis", + {}, + ) + return output.get("apis", []) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/llm.py b/astrbot-sdk/src/astrbot_sdk/clients/llm.py new 
file mode 100644 index 0000000000..14d7393fd0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/llm.py @@ -0,0 +1,293 @@ +"""大语言模型客户端模块。 + +提供 v4 原生的 LLM 能力调用接口。 + +设计边界: + - `chat()` 是便捷文本接口,返回最终文本 + - `chat_raw()` 返回完整结构化响应 + - `stream_chat()` 返回文本增量 + - Agent 循环、动态工具注册等更高层 orchestration 不放在客户端内, + 由上层运行时或独立迁移入口承接 +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Mapping, Sequence +from typing import Any + +from pydantic import BaseModel, Field + +from ._proxy import CapabilityProxy + + +class ChatMessage(BaseModel): + """聊天消息模型。 + + 用于构建对话历史,传递给 LLM。 + + Attributes: + role: 消息角色,如 "user", "assistant", "system" + content: 消息内容 + + 示例: + history = [ + ChatMessage(role="user", content="你好"), + ChatMessage(role="assistant", content="你好!有什么可以帮助你的?"), + ChatMessage(role="user", content="今天天气怎么样?"), + ] + """ + + role: str + content: str + + +ChatHistoryItem = ChatMessage | Mapping[str, Any] + + +def _serialize_history( + history: Sequence[ChatHistoryItem] | None, +) -> list[dict[str, Any]]: + if history is None: + return [] + + serialized: list[dict[str, Any]] = [] + for item in history: + if isinstance(item, ChatMessage): + serialized.append(item.model_dump()) + continue + if isinstance(item, Mapping): + serialized.append(dict(item)) + continue + raise TypeError("history 项必须是 ChatMessage 或 mapping") + return serialized + + +def _normalize_chat_context_payload( + *, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, +) -> dict[str, list[dict[str, Any]]]: + if contexts is not None: + return {"contexts": _serialize_history(contexts)} + if history is not None: + return {"contexts": _serialize_history(history)} + return {} + + +def _build_chat_payload( + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, 
Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + extra: dict[str, Any] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = {"prompt": prompt} + if system is not None: + payload["system"] = system + payload.update(_normalize_chat_context_payload(history=history, contexts=contexts)) + if provider_id is not None: + payload["provider_id"] = provider_id + if tool_calls_result is not None: + payload["tool_calls_result"] = [dict(item) for item in tool_calls_result] + if model is not None: + payload["model"] = model + if temperature is not None: + payload["temperature"] = temperature + if extra: + payload.update(extra) + return payload + + +class LLMResponse(BaseModel): + """LLM 响应模型。 + + 包含完整的 LLM 响应信息,用于 chat_raw() 方法返回。 + + Attributes: + text: 生成的文本内容 + usage: Token 使用统计,如 {"prompt_tokens": 10, "completion_tokens": 20} + finish_reason: 结束原因,如 "stop", "length", "tool_calls" + tool_calls: 工具调用列表(如果 LLM 决定调用工具) + """ + + text: str + usage: dict[str, Any] | None = None + finish_reason: str | None = None + tool_calls: list[dict[str, Any]] = Field(default_factory=list) + role: str | None = None + reasoning_content: str | None = None + reasoning_signature: str | None = None + + +class LLMClient: + """大语言模型客户端。 + + 提供与 LLM 交互的能力,支持普通聊天和流式聊天。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化 LLM 客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + async def chat( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> str: + """发送聊天请求并返回文本响应。 + + 这是简化的聊天接口,仅返回生成的文本内容。 + 如需完整响应信息(包括 usage、tool_calls),请使用 chat_raw()。 + + Args: + prompt: 用户输入的提示文本 + system: 
系统提示词,用于指导 LLM 行为 + history: 对话历史,用于保持上下文连续性 + model: 指定使用的模型名称(可选,由核心自动选择) + temperature: 生成温度,控制随机性(0-1) + **kwargs: 额外透传参数,如 `image_urls`、`tools` + + Returns: + LLM 生成的文本内容 + + 示例: + # 简单对话 + reply = await ctx.llm.chat("你好,介绍一下自己") + + # 带历史的对话 + history = [ + ChatMessage(role="user", content="我叫小明"), + ChatMessage(role="assistant", content="你好小明!"), + ] + reply = await ctx.llm.chat("你记得我的名字吗?", history=history) + """ + output = await self._proxy.call( + "llm.chat", + _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ), + ) + return str(output.get("text", "")) + + async def chat_raw( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> LLMResponse: + """发送聊天请求并返回完整响应。 + + 与 chat() 不同,此方法返回完整的 LLMResponse 对象, + 包含 usage、finish_reason、tool_calls 等信息。 + + Args: + prompt: 用户输入的提示文本 + **kwargs: 额外参数,如 system, history, model, temperature 等 + + Returns: + LLMResponse 对象,包含完整响应信息 + + 示例: + response = await ctx.llm.chat_raw("写一首诗", temperature=0.8) + print(f"生成文本: {response.text}") + print(f"Token 使用: {response.usage}") + """ + payload = _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ) + output = await self._proxy.call( + "llm.chat_raw", + payload, + ) + return LLMResponse.model_validate(output) + + async def stream_chat( + self, + prompt: str, + *, + system: str | None = None, + history: Sequence[ChatHistoryItem] | None = None, + contexts: 
Sequence[ChatHistoryItem] | None = None, + provider_id: str | None = None, + tool_calls_result: list[dict[str, Any]] | None = None, + model: str | None = None, + temperature: float | None = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: + """流式聊天,逐块返回响应文本。 + + 适用于需要实时显示生成内容的场景,如聊天界面。 + + Args: + prompt: 用户输入的提示文本 + system: 系统提示词 + history: 对话历史 + model: 指定模型 + temperature: 采样温度 + **kwargs: 额外透传参数,如 `image_urls`、`tools` + + Yields: + 每个生成的文本块 + + 示例: + async for chunk in ctx.llm.stream_chat("讲一个故事"): + print(chunk, end="", flush=True) + """ + async for data in self._proxy.stream( + "llm.stream_chat", + _build_chat_payload( + prompt, + system=system, + history=history, + contexts=contexts, + provider_id=provider_id, + tool_calls_result=tool_calls_result, + model=model, + temperature=temperature, + extra=kwargs, + ), + ): + yield str(data.get("text", "")) diff --git a/astrbot-sdk/src/astrbot_sdk/clients/managers.py b/astrbot-sdk/src/astrbot_sdk/clients/managers.py new file mode 100644 index 0000000000..e95115ea70 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/managers.py @@ -0,0 +1,886 @@ +"""Typed SDK manager clients for persona, conversation, and knowledge base.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ..errors import AstrBotError, ErrorCodes +from ..message.components import ( + BaseMessageComponent, + component_to_payload_sync, + payload_to_component, +) +from ..message.session import MessageSession +from ._proxy import CapabilityProxy + + +class _ManagerModel(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + def to_update_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_unset=True) + + +def _normalize_session(session: str | MessageSession) -> str: + if 
isinstance(session, MessageSession): + return str(session) + return str(session) + + +def _require_message_history_session( + session: MessageSession, +) -> dict[str, str]: + if not isinstance(session, MessageSession): + raise TypeError( + "message_history requires astrbot_sdk.message.session.MessageSession" + ) + return { + "platform_id": str(session.platform_id), + "message_type": str(session.message_type), + "session_id": str(session.session_id), + } + + +def _normalize_message_history_parts( + parts: list[BaseMessageComponent], +) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for part in parts: + if not isinstance(part, BaseMessageComponent): + raise TypeError( + "message_history.append requires BaseMessageComponent items in parts" + ) + normalized.append(component_to_payload_sync(part)) + return normalized + + +def _normalize_message_history_boundary(value: datetime) -> str: + if not isinstance(value, datetime): + raise TypeError("message_history boundary requires datetime") + normalized = value + if normalized.tzinfo is None: + normalized = normalized.replace(tzinfo=timezone.utc) + else: + normalized = normalized.astimezone(timezone.utc) + return normalized.isoformat() + + +class PersonaRecord(_ManagerModel): + persona_id: str + system_prompt: str + begin_dialogs: list[str] = Field(default_factory=list) + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: str | None = None + folder_id: str | None = None + sort_order: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PersonaRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PersonaCreateParams(_ManagerModel): + persona_id: str + system_prompt: str + begin_dialogs: list[str] = Field(default_factory=list) + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: 
str | None = None + folder_id: str | None = None + sort_order: int = 0 + + +class PersonaUpdateParams(_ManagerModel): + system_prompt: str | None = None + begin_dialogs: list[str] | None = None + tools: list[str] | None = None + skills: list[str] | None = None + custom_error_message: str | None = None + + +class ConversationRecord(_ManagerModel): + conversation_id: str + session: str + platform_id: str + history: list[dict[str, Any]] = Field(default_factory=list) + title: str | None = None + persona_id: str | None = None + created_at: str | None = None + updated_at: str | None = None + token_usage: int | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> ConversationRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ConversationCreateParams(_ManagerModel): + platform_id: str | None = None + history: list[dict[str, Any]] | None = None + title: str | None = None + persona_id: str | None = None + + +class ConversationUpdateParams(_ManagerModel): + history: list[dict[str, Any]] | None = None + title: str | None = None + persona_id: str | None = None + token_usage: int | None = None + + +class MessageHistorySender(_ManagerModel): + sender_id: str | None = None + sender_name: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistorySender | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class MessageHistoryRecord(_ManagerModel): + id: int + session: MessageSession + sender: MessageHistorySender = Field(default_factory=MessageHistorySender) + parts: list[BaseMessageComponent] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + updated_at: datetime | None = None + idempotency_key: str | None = None + + @model_validator(mode="before") + @classmethod + def _normalize_payload(cls, value: Any) -> 
Any: + if not isinstance(value, dict): + return value + normalized = dict(value) + + session_payload = normalized.get("session") + if isinstance(session_payload, dict): + normalized["session"] = MessageSession( + platform_id=str(session_payload.get("platform_id", "")), + message_type=str(session_payload.get("message_type", "")), + session_id=str(session_payload.get("session_id", "")), + ) + + sender_payload = normalized.get("sender") + if isinstance(sender_payload, dict): + normalized["sender"] = MessageHistorySender.model_validate(sender_payload) + elif sender_payload is None: + normalized["sender"] = MessageHistorySender() + + parts_payload = normalized.get("parts") + if isinstance(parts_payload, list): + normalized["parts"] = [ + payload_to_component(item) + for item in parts_payload + if isinstance(item, dict) + ] + + metadata_payload = normalized.get("metadata") + if not isinstance(metadata_payload, dict): + normalized["metadata"] = {} + + return normalized + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistoryRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class MessageHistoryPage(_ManagerModel): + records: list[MessageHistoryRecord] = Field(default_factory=list) + next_cursor: str | None = None + total: int | None = None + + @model_validator(mode="before") + @classmethod + def _normalize_payload(cls, value: Any) -> Any: + if not isinstance(value, dict): + return value + normalized = dict(value) + records_payload = normalized.get("records") + if isinstance(records_payload, list): + normalized["records"] = [ + record + for record in ( + MessageHistoryRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in records_payload + ) + if record is not None + ] + return normalized + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MessageHistoryPage | None: + if not isinstance(payload, dict): + return None + 
return cls.model_validate(payload) + + +class KnowledgeBaseRecord(_ManagerModel): + kb_id: str + kb_name: str + description: str | None = None + emoji: str | None = None + embedding_provider_id: str + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + doc_count: int = 0 + chunk_count: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> KnowledgeBaseRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseCreateParams(_ManagerModel): + kb_name: str + embedding_provider_id: str + description: str | None = None + emoji: str | None = None + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + +class KnowledgeBaseUpdateParams(_ManagerModel): + kb_name: str | None = None + embedding_provider_id: str | None = None + description: str | None = None + emoji: str | None = None + rerank_provider_id: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + top_k_dense: int | None = None + top_k_sparse: int | None = None + top_m_final: int | None = None + + +class KnowledgeBaseDocumentRecord(_ManagerModel): + doc_id: str + kb_id: str + doc_name: str + file_type: str + file_size: int + file_path: str = "" + chunk_count: int = 0 + media_count: int = 0 + created_at: str | None = None + updated_at: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseDocumentRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseRetrieveResultItem(_ManagerModel): + 
chunk_id: str + doc_id: str + kb_id: str + kb_name: str + doc_name: str + chunk_index: int + content: str + score: float + char_count: int + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseRetrieveResultItem | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class KnowledgeBaseRetrieveResult(_ManagerModel): + context_text: str + results: list[KnowledgeBaseRetrieveResultItem] = Field(default_factory=list) + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> KnowledgeBaseRetrieveResult | None: + if not isinstance(payload, dict): + return None + items = payload.get("results") + normalized_items = ( + [ + item.model_dump() + for item in ( + KnowledgeBaseRetrieveResultItem.from_payload(candidate) + if isinstance(candidate, dict) + else None + for candidate in items + ) + if item is not None + ] + if isinstance(items, list) + else [] + ) + return cls.model_validate( + { + "context_text": str(payload.get("context_text", "")), + "results": normalized_items, + } + ) + + +class KnowledgeBaseDocumentUploadParams(_ManagerModel): + file_token: str | None = None + url: str | None = None + text: str | None = None + file_name: str | None = None + file_type: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + batch_size: int | None = None + tasks_limit: int | None = None + max_retries: int | None = None + enable_cleaning: bool | None = None + cleaning_provider_id: str | None = None + + @model_validator(mode="after") + def _validate_source(self) -> KnowledgeBaseDocumentUploadParams: + if any( + isinstance(value, str) and value.strip() + for value in (self.file_token, self.url, self.text) + ): + return self + raise ValueError( + "knowledge base document upload requires file_token, url, or text" + ) + + +class PersonaManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def 
get_persona(self, persona_id: str) -> PersonaRecord: + try: + output = await self._proxy.call( + "persona.get", + {"persona_id": str(persona_id)}, + ) + except AstrBotError as exc: + if exc.code == ErrorCodes.INVALID_INPUT: + raise ValueError(f"persona not found: {persona_id}") from exc + raise + persona = PersonaRecord.from_payload(output.get("persona")) + if persona is None: + raise ValueError(f"persona not found: {persona_id}") + return persona + + async def get_all_personas(self) -> list[PersonaRecord]: + output = await self._proxy.call("persona.list", {}) + items = output.get("personas") + if not isinstance(items, list): + return [] + return [ + persona + for persona in ( + PersonaRecord.from_payload(item) if isinstance(item, dict) else None + for item in items + ) + if persona is not None + ] + + async def create_persona(self, params: PersonaCreateParams) -> PersonaRecord: + output = await self._proxy.call( + "persona.create", + {"persona": params.to_payload()}, + ) + persona = PersonaRecord.from_payload(output.get("persona")) + if persona is None: + raise ValueError("persona.create returned no persona") + return persona + + async def update_persona( + self, + persona_id: str, + params: PersonaUpdateParams, + ) -> PersonaRecord | None: + output = await self._proxy.call( + "persona.update", + {"persona_id": str(persona_id), "persona": params.to_update_payload()}, + ) + return PersonaRecord.from_payload(output.get("persona")) + + async def delete_persona(self, persona_id: str) -> None: + await self._proxy.call("persona.delete", {"persona_id": str(persona_id)}) + + +class ConversationManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def new_conversation( + self, + session: str | MessageSession, + params: ConversationCreateParams | None = None, + ) -> str: + output = await self._proxy.call( + "conversation.new", + { + "session": _normalize_session(session), + "conversation": (params.to_payload() if params is not 
None else {}), + }, + ) + return str(output.get("conversation_id", "")) + + async def switch_conversation( + self, + session: str | MessageSession, + conversation_id: str, + ) -> None: + await self._proxy.call( + "conversation.switch", + { + "session": _normalize_session(session), + "conversation_id": str(conversation_id), + }, + ) + + async def delete_conversation( + self, + session: str | MessageSession, + conversation_id: str | None = None, + ) -> None: + """Delete one conversation for the session. + + When ``conversation_id`` is ``None``, this deletes the current selected + conversation for the session only. It does not delete all conversations + under the session. + """ + + await self._proxy.call( + "conversation.delete", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + }, + ) + + async def get_conversation( + self, + session: str | MessageSession, + conversation_id: str, + *, + create_if_not_exists: bool = False, + ) -> ConversationRecord | None: + output = await self._proxy.call( + "conversation.get", + { + "session": _normalize_session(session), + "conversation_id": str(conversation_id), + "create_if_not_exists": bool(create_if_not_exists), + }, + ) + return ConversationRecord.from_payload(output.get("conversation")) + + async def get_current_conversation( + self, + session: str | MessageSession, + *, + create_if_not_exists: bool = False, + ) -> ConversationRecord | None: + output = await self._proxy.call( + "conversation.get_current", + { + "session": _normalize_session(session), + "create_if_not_exists": bool(create_if_not_exists), + }, + ) + return ConversationRecord.from_payload(output.get("conversation")) + + async def get_conversations( + self, + session: str | MessageSession | None = None, + *, + platform_id: str | None = None, + ) -> list[ConversationRecord]: + output = await self._proxy.call( + "conversation.list", + { + "session": ( + _normalize_session(session) if session is not None else None + ), + 
"platform_id": platform_id, + }, + ) + items = output.get("conversations") + if not isinstance(items, list): + return [] + return [ + conversation + for conversation in ( + ConversationRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if conversation is not None + ] + + async def update_conversation( + self, + session: str | MessageSession, + conversation_id: str | None = None, + params: ConversationUpdateParams | None = None, + ) -> None: + await self._proxy.call( + "conversation.update", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + "conversation": ( + params.to_update_payload() if params is not None else {} + ), + }, + ) + + async def unset_persona( + self, + session: str | MessageSession, + conversation_id: str | None = None, + ) -> None: + await self._proxy.call( + "conversation.unset_persona", + { + "session": _normalize_session(session), + "conversation_id": conversation_id, + }, + ) + + +class MessageHistoryManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list( + self, + session: MessageSession, + *, + cursor: str | None = None, + limit: int = 50, + ) -> MessageHistoryPage: + output = await self._proxy.call( + "message_history.list", + { + "session": _require_message_history_session(session), + "cursor": str(cursor) if cursor is not None else None, + "limit": int(limit), + }, + ) + page = MessageHistoryPage.from_payload(output.get("page")) + if page is None: + raise ValueError("message_history.list returned no page") + return page + + async def get( + self, + session: MessageSession, + record_id: int, + ) -> MessageHistoryRecord | None: + output = await self._proxy.call( + "message_history.get_by_id", + { + "session": _require_message_history_session(session), + "record_id": int(record_id), + }, + ) + return MessageHistoryRecord.from_payload(output.get("record")) + + async def get_by_id( + self, + session: MessageSession, 
+ record_id: int, + ) -> MessageHistoryRecord | None: + return await self.get(session, record_id) + + async def append( + self, + session: MessageSession, + *, + parts: list[BaseMessageComponent], + sender: MessageHistorySender, + metadata: dict[str, Any] | None = None, + idempotency_key: str | None = None, + ) -> MessageHistoryRecord: + if isinstance(sender, MessageHistorySender): + sender_payload = sender.to_payload() + elif isinstance(sender, dict): + sender_payload = MessageHistorySender.model_validate(sender).to_payload() + else: + raise TypeError( + "message_history.append requires MessageHistorySender for sender" + ) + output = await self._proxy.call( + "message_history.append", + { + "session": _require_message_history_session(session), + "sender": sender_payload, + "parts": _normalize_message_history_parts(parts), + "metadata": dict(metadata or {}), + "idempotency_key": ( + str(idempotency_key) if idempotency_key is not None else None + ), + }, + ) + record = MessageHistoryRecord.from_payload(output.get("record")) + if record is None: + raise ValueError("message_history.append returned no record") + return record + + async def delete_before( + self, + session: MessageSession, + *, + before: datetime, + ) -> int: + output = await self._proxy.call( + "message_history.delete_before", + { + "session": _require_message_history_session(session), + "before": _normalize_message_history_boundary(before), + }, + ) + return int(output.get("deleted_count", 0) or 0) + + async def delete_after( + self, + session: MessageSession, + *, + after: datetime, + ) -> int: + output = await self._proxy.call( + "message_history.delete_after", + { + "session": _require_message_history_session(session), + "after": _normalize_message_history_boundary(after), + }, + ) + return int(output.get("deleted_count", 0) or 0) + + async def delete_all(self, session: MessageSession) -> int: + output = await self._proxy.call( + "message_history.delete_all", + {"session": 
_require_message_history_session(session)}, + ) + return int(output.get("deleted_count", 0) or 0) + + +class KnowledgeBaseManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list_kbs(self) -> list[KnowledgeBaseRecord]: + output = await self._proxy.call("kb.list", {}) + items = output.get("kbs") + if not isinstance(items, list): + return [] + return [ + kb + for kb in ( + KnowledgeBaseRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if kb is not None + ] + + async def get_kb(self, kb_id: str) -> KnowledgeBaseRecord | None: + output = await self._proxy.call("kb.get", {"kb_id": str(kb_id)}) + return KnowledgeBaseRecord.from_payload(output.get("kb")) + + async def create_kb( + self, + params: KnowledgeBaseCreateParams, + ) -> KnowledgeBaseRecord: + output = await self._proxy.call("kb.create", {"kb": params.to_payload()}) + kb = KnowledgeBaseRecord.from_payload(output.get("kb")) + if kb is None: + raise ValueError("kb.create returned no knowledge base") + return kb + + async def update_kb( + self, + kb_id: str, + params: KnowledgeBaseUpdateParams, + ) -> KnowledgeBaseRecord | None: + output = await self._proxy.call( + "kb.update", + {"kb_id": str(kb_id), "kb": params.to_update_payload()}, + ) + return KnowledgeBaseRecord.from_payload(output.get("kb")) + + async def delete_kb(self, kb_id: str) -> bool: + output = await self._proxy.call("kb.delete", {"kb_id": str(kb_id)}) + return bool(output.get("deleted", False)) + + async def retrieve( + self, + query: str, + *, + kb_ids: list[str] | None = None, + kb_names: list[str] | None = None, + top_k_fusion: int | None = None, + top_m_final: int | None = None, + ) -> KnowledgeBaseRetrieveResult | None: + request_payload: dict[str, Any] = { + "query": str(query), + "kb_ids": [str(item) for item in (kb_ids or [])], + "kb_names": [str(item) for item in (kb_names or [])], + } + if top_k_fusion is not None: + 
request_payload["top_k_fusion"] = int(top_k_fusion) + if top_m_final is not None: + request_payload["top_m_final"] = int(top_m_final) + output = await self._proxy.call( + "kb.retrieve", + request_payload, + ) + return KnowledgeBaseRetrieveResult.from_payload(output.get("result")) + + async def upload_document( + self, + kb_id: str, + params: KnowledgeBaseDocumentUploadParams, + ) -> KnowledgeBaseDocumentRecord: + output = await self._proxy.call( + "kb.document.upload", + {"kb_id": str(kb_id), "document": params.to_payload()}, + ) + document = KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + if document is None: + raise ValueError("kb.document.upload returned no document") + return document + + async def list_documents( + self, + kb_id: str, + *, + offset: int = 0, + limit: int = 100, + ) -> list[KnowledgeBaseDocumentRecord]: + output = await self._proxy.call( + "kb.document.list", + {"kb_id": str(kb_id), "offset": int(offset), "limit": int(limit)}, + ) + items = output.get("documents") + if not isinstance(items, list): + return [] + return [ + document + for document in ( + KnowledgeBaseDocumentRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if document is not None + ] + + async def get_document( + self, + kb_id: str, + doc_id: str, + ) -> KnowledgeBaseDocumentRecord | None: + output = await self._proxy.call( + "kb.document.get", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) + return KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + + async def delete_document( + self, + kb_id: str, + doc_id: str, + ) -> bool: + output = await self._proxy.call( + "kb.document.delete", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) + return bool(output.get("deleted", False)) + + async def refresh_document( + self, + kb_id: str, + doc_id: str, + ) -> KnowledgeBaseDocumentRecord | None: + output = await self._proxy.call( + "kb.document.refresh", + {"kb_id": str(kb_id), "doc_id": str(doc_id)}, + ) 
+ return KnowledgeBaseDocumentRecord.from_payload(output.get("document")) + + +__all__ = [ + "ConversationCreateParams", + "ConversationManagerClient", + "ConversationRecord", + "ConversationUpdateParams", + "KnowledgeBaseCreateParams", + "KnowledgeBaseDocumentRecord", + "KnowledgeBaseDocumentUploadParams", + "KnowledgeBaseManagerClient", + "KnowledgeBaseRecord", + "KnowledgeBaseRetrieveResult", + "KnowledgeBaseRetrieveResultItem", + "KnowledgeBaseUpdateParams", + "MessageHistoryManagerClient", + "MessageHistoryPage", + "MessageHistoryRecord", + "MessageHistorySender", + "PersonaCreateParams", + "PersonaManagerClient", + "PersonaRecord", + "PersonaUpdateParams", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/mcp.py b/astrbot-sdk/src/astrbot_sdk/clients/mcp.py new file mode 100644 index 0000000000..9e486d5231 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/mcp.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +from contextlib import AbstractAsyncContextManager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from ..errors import AstrBotError +from ._proxy import CapabilityProxy + + +class MCPServerScope(str, Enum): + local = "local" + global_ = "global" + + +@dataclass(slots=True) +class MCPServerRecord: + name: str + scope: MCPServerScope + active: bool + running: bool + config: dict[str, Any] = field(default_factory=dict) + tools: list[str] = field(default_factory=list) + errlogs: list[str] = field(default_factory=list) + last_error: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> MCPServerRecord | None: + if not isinstance(payload, dict): + return None + scope_value = str(payload.get("scope") or MCPServerScope.local.value).strip() + try: + scope = MCPServerScope(scope_value) + except ValueError: + scope = MCPServerScope.local + return cls( + name=str(payload.get("name", "")), + scope=scope, + active=bool(payload.get("active", False)), + 
running=bool(payload.get("running", False)), + config=( + dict(payload.get("config")) + if isinstance(payload.get("config"), dict) + else {} + ), + tools=[ + str(item) + for item in payload.get("tools", []) + if isinstance(item, str) and item + ] + if isinstance(payload.get("tools"), list) + else [], + errlogs=[ + str(item) + for item in payload.get("errlogs", []) + if isinstance(item, str) + ] + if isinstance(payload.get("errlogs"), list) + else [], + last_error=( + str(payload.get("last_error")) + if payload.get("last_error") is not None + else None + ), + ) + + +class MCPSession(AbstractAsyncContextManager["MCPSession"]): + def __init__( + self, + proxy: CapabilityProxy, + *, + name: str, + config: dict[str, Any], + timeout: float, + ) -> None: + self._proxy = proxy + self._name = str(name) + self._config = dict(config) + self._timeout = float(timeout) + self._session_id: str | None = None + self._tools: list[str] = [] + + async def __aenter__(self) -> MCPSession: + output = await self._proxy.call( + "mcp.session.open", + { + "name": self._name, + "config": dict(self._config), + "timeout": self._timeout, + }, + ) + session_id = str(output.get("session_id", "")).strip() + if not session_id: + raise ValueError("mcp.session.open returned no session_id") + self._session_id = session_id + tools = output.get("tools") + self._tools = ( + [str(item) for item in tools if isinstance(item, str)] + if isinstance(tools, list) + else [] + ) + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + session_id = self._session_id + self._session_id = None + self._tools = [] + if not session_id: + return + try: + await self._proxy.call("mcp.session.close", {"session_id": session_id}) + except AstrBotError: + raise + except Exception: + # Session cleanup should not mask the original error raised inside the + # managed block. 
+ if exc_type is None: + raise + + async def call_tool( + self, + tool_name: str, + args: dict[str, Any] | None = None, + ) -> dict[str, Any]: + session_id = self._require_session_id() + output = await self._proxy.call( + "mcp.session.call_tool", + { + "session_id": session_id, + "tool_name": str(tool_name), + "args": dict(args or {}), + }, + ) + result = output.get("result") + if not isinstance(result, dict): + raise ValueError("mcp.session.call_tool returned no result object") + return dict(result) + + async def list_tools(self) -> list[str]: + session_id = self._require_session_id() + output = await self._proxy.call( + "mcp.session.list_tools", + {"session_id": session_id}, + ) + tools = output.get("tools") + self._tools = ( + [str(item) for item in tools if isinstance(item, str)] + if isinstance(tools, list) + else [] + ) + return list(self._tools) + + def _require_session_id(self) -> str: + if self._session_id is None: + raise RuntimeError("MCP session is not active; use 'async with'") + return self._session_id + + +class MCPManagerClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def get_server(self, name: str) -> MCPServerRecord | None: + output = await self._proxy.call("mcp.local.get", {"name": str(name)}) + return MCPServerRecord.from_payload(output.get("server")) + + async def list_servers(self) -> list[MCPServerRecord]: + output = await self._proxy.call("mcp.local.list", {}) + items = output.get("servers") + if not isinstance(items, list): + return [] + return [ + record + for record in ( + MCPServerRecord.from_payload(item) if isinstance(item, dict) else None + for item in items + ) + if record is not None + ] + + async def enable_server(self, name: str) -> MCPServerRecord: + output = await self._proxy.call("mcp.local.enable", {"name": str(name)}) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.local.enable returned no server") + return record + + 
async def disable_server(self, name: str) -> MCPServerRecord: + output = await self._proxy.call("mcp.local.disable", {"name": str(name)}) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.local.disable returned no server") + return record + + async def wait_until_ready( + self, + name: str, + *, + timeout: float = 30.0, + ) -> MCPServerRecord: + output = await self._proxy.call( + "mcp.local.wait_until_ready", + {"name": str(name), "timeout": float(timeout)}, + ) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.local.wait_until_ready returned no server") + return record + + def session( + self, + name: str, + config: dict[str, Any], + *, + timeout: float = 30.0, + ) -> MCPSession: + return MCPSession( + self._proxy, + name=str(name), + config=dict(config), + timeout=float(timeout), + ) + + async def register_global_server( + self, + name: str, + config: dict[str, Any], + *, + timeout: float = 30.0, + ) -> MCPServerRecord: + output = await self._proxy.call( + "mcp.global.register", + { + "name": str(name), + "config": dict(config), + "timeout": float(timeout), + }, + ) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.global.register returned no server") + return record + + async def get_global_server(self, name: str) -> MCPServerRecord | None: + output = await self._proxy.call("mcp.global.get", {"name": str(name)}) + return MCPServerRecord.from_payload(output.get("server")) + + async def list_global_servers(self) -> list[MCPServerRecord]: + output = await self._proxy.call("mcp.global.list", {}) + items = output.get("servers") + if not isinstance(items, list): + return [] + return [ + record + for record in ( + MCPServerRecord.from_payload(item) if isinstance(item, dict) else None + for item in items + ) + if record is not None + ] + + async def enable_global_server( + self, + name: str, + *, + 
timeout: float = 30.0, + ) -> MCPServerRecord: + output = await self._proxy.call( + "mcp.global.enable", + {"name": str(name), "timeout": float(timeout)}, + ) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.global.enable returned no server") + return record + + async def disable_global_server(self, name: str) -> MCPServerRecord: + output = await self._proxy.call("mcp.global.disable", {"name": str(name)}) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.global.disable returned no server") + return record + + async def unregister_global_server(self, name: str) -> MCPServerRecord: + output = await self._proxy.call("mcp.global.unregister", {"name": str(name)}) + record = MCPServerRecord.from_payload(output.get("server")) + if record is None: + raise ValueError("mcp.global.unregister returned no server") + return record + + +__all__ = [ + "MCPManagerClient", + "MCPSession", + "MCPServerRecord", + "MCPServerScope", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/memory.py b/astrbot-sdk/src/astrbot_sdk/clients/memory.py new file mode 100644 index 0000000000..1ba91f1447 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/memory.py @@ -0,0 +1,432 @@ +"""记忆客户端模块。 + +提供 AI 记忆存储能力,用于存储和检索对话记忆、用户偏好等上下文数据。 + +设计说明: + MemoryClient 与 DBClient 的区别: + - DBClient: 简单的键值存储,精确匹配 + - MemoryClient: 支持基于当前 bridge 行为的记忆检索,适合 AI 上下文管理 + + 记忆系统可用于: + - 存储用户偏好和设置 + - 记录对话摘要 + - 缓存 AI 推理结果 +""" + +from __future__ import annotations + +from typing import Any, Literal + +from .._internal.memory_utils import join_memory_namespace +from ._proxy import CapabilityProxy + + +def _normalize_search_item(item: Any) -> dict[str, Any] | None: + if not isinstance(item, dict): + return None + normalized = dict(item) + value = normalized.get("value") + if isinstance(value, dict): + for key, payload_value in value.items(): + normalized.setdefault(str(key), payload_value) + return 
normalized + + +class MemoryClient: + """记忆客户端。 + + 提供 AI 记忆的存储和检索能力。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__( + self, + proxy: CapabilityProxy, + *, + namespace: str | None = None, + ) -> None: + """初始化记忆客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + self._namespace = join_memory_namespace(namespace) + + def namespace(self, *parts: Any) -> MemoryClient: + """Create a derived client that operates inside a child namespace.""" + + return MemoryClient( + self._proxy, + namespace=join_memory_namespace(self._namespace, *parts), + ) + + def _resolve_exact_namespace(self, namespace: str | None) -> str: + if namespace is None: + return self._namespace + return join_memory_namespace(self._namespace, namespace) + + def _resolve_scope_namespace(self, namespace: str | None) -> tuple[bool, str]: + if namespace is None: + if self._namespace: + return True, self._namespace + return False, "" + return True, join_memory_namespace(self._namespace, namespace) + + async def search( + self, + query: str, + *, + mode: Literal["auto", "keyword", "vector", "hybrid"] = "auto", + limit: int | None = None, + min_score: float | None = None, + provider_id: str | None = None, + namespace: str | None = None, + include_descendants: bool = True, + ) -> list[dict[str, Any]]: + """搜索记忆项。 + + 默认会在有 embedding provider 时执行 hybrid 检索, + 否则退化为关键词检索。返回结果包含 `score` 与 `match_type` 字段。 + + Args: + query: 搜索查询文本 + mode: 搜索模式,支持 auto/keyword/vector/hybrid + limit: 最大返回条数 + min_score: 最低分数阈值 + provider_id: 指定 embedding provider,默认使用当前激活的 provider + + Returns: + 匹配的记忆项列表,按相关度排序 + + 示例: + results = await ctx.memory.search( + "用户喜欢什么颜色", + mode="hybrid", + limit=5, + ) + for item in results: + print(item["key"], item["score"], item["match_type"]) + """ + payload: dict[str, Any] = {"query": query, "mode": mode} + if limit is not None: + payload["limit"] = limit + if min_score is not None: + payload["min_score"] = min_score + if provider_id is not None: 
+ payload["provider_id"] = provider_id + has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace) + if has_namespace: + payload["namespace"] = resolved_namespace + payload["include_descendants"] = bool(include_descendants) + output = await self._proxy.call("memory.search", payload) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return [] + normalized_items: list[dict[str, Any]] = [] + for item in items: + normalized = _normalize_search_item(item) + if normalized is not None: + normalized_items.append(normalized) + return normalized_items + + async def save( + self, + key: str, + value: dict[str, Any] | None = None, + namespace: str | None = None, + **extra: Any, + ) -> None: + """保存记忆项。 + + 将数据存储到记忆系统,可通过 search() 检索或 get() 精确获取。 + + Args: + key: 记忆项的唯一标识键 + value: 要存储的数据字典 + **extra: 额外的键值对,会合并到 value 中 + Raises: + TypeError: 如果 value 不是 dict 类型 + 示例: + 保存用户偏好 + await ctx.memory.save("user_pref", {"theme": "dark", "lang": "zh"}) + + 使用关键字参数 + await ctx.memory.save("note", None, content="重要笔记", tags=["work"]) + + 使用 embedding_text 显式指定检索文本 + await ctx.memory.save( + "profile", + {"name": "alice", "embedding_text": "Alice 喜欢蓝色和海边"}, + ) + """ + if value is not None and not isinstance(value, dict): + raise TypeError("memory.save 的 value 必须是 dict") + payload = dict(value or {}) + if extra: + payload.update(extra) + request: dict[str, Any] = {"key": key, "value": payload} + request["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.save", request) + + async def get( + self, + key: str, + *, + namespace: str | None = None, + ) -> dict[str, Any] | None: + """精确获取单个记忆项。 + + 通过唯一键精确获取记忆内容,不经过搜索匹配。 + + Args: + key: 记忆项的唯一键 + + Returns: + 记忆项内容字典,若不存在则返回 None + + 示例: + pref = await ctx.memory.get("user_pref") + if pref: + print(f"用户偏好主题: {pref.get('theme')}") + """ + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await 
self._proxy.call("memory.get", payload) + value = output.get("value") + return value if isinstance(value, dict) else None + + async def list_keys( + self, + *, + namespace: str | None = None, + ) -> list[str]: + """List keys in the exact namespace using case-insensitive ordering.""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace) + } + output = await self._proxy.call("memory.list_keys", payload) + keys = output.get("keys") + if not isinstance(keys, (list, tuple)): + return [] + return [str(item) for item in keys] + + async def exists( + self, + key: str, + *, + namespace: str | None = None, + ) -> bool: + """Check whether a key exists in the exact namespace.""" + + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.exists", payload) + return bool(output.get("exists", False)) + + async def delete( + self, + key: str, + *, + namespace: str | None = None, + ) -> None: + """删除记忆项。 + + Args: + key: 要删除的记忆项键名 + + 示例: + await ctx.memory.delete("old_note") + """ + payload: dict[str, Any] = {"key": key} + payload["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.delete", payload) + + async def clear_namespace( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: + """Delete memories in a namespace and optionally its descendants.""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace), + "include_descendants": bool(include_descendants), + } + output = await self._proxy.call("memory.clear_namespace", payload) + return int(output.get("deleted_count", 0)) + + async def save_with_ttl( + self, + key: str, + value: dict[str, Any], + ttl_seconds: int, + *, + namespace: str | None = None, + ) -> None: + """保存带过期时间的记忆项。 + + 与 save() 不同,此方法允许设置记忆项的存活时间(TTL), + 过期后记忆项将自动删除。 + + Args: + key: 记忆项的唯一标识键 + value: 要存储的数据字典 + ttl_seconds: 
存活时间(秒),必须大于 0 + + Raises: + TypeError: 如果 value 不是 dict 类型 + ValueError: 如果 ttl_seconds 小于 1 + + 示例: + # 保存临时会话状态,1小时后过期 + await ctx.memory.save_with_ttl( + "session_temp", + {"state": "waiting"}, + ttl_seconds=3600, + ) + """ + if not isinstance(value, dict): + raise TypeError("memory.save_with_ttl 的 value 必须是 dict") + if ttl_seconds < 1: + raise ValueError("ttl_seconds 必须大于 0") + payload: dict[str, Any] = { + "key": key, + "value": value, + "ttl_seconds": ttl_seconds, + } + payload["namespace"] = self._resolve_exact_namespace(namespace) + await self._proxy.call("memory.save_with_ttl", payload) + + async def get_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> list[dict[str, Any]]: + """批量获取多个记忆项。 + + 一次性获取多个键对应的记忆内容,比多次调用 get() 更高效。 + + Args: + keys: 记忆项键名列表 + + Returns: + 记忆项列表,每项包含 key 和 value 字段, + 不存在的键返回 value 为 None + + 示例: + items = await ctx.memory.get_many(["pref1", "pref2", "pref3"]) + for item in items: + if item["value"]: + print(f"{item['key']}: {item['value']}") + """ + payload: dict[str, Any] = {"keys": keys} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.get_many", payload) + items = output.get("items") + if not isinstance(items, (list, tuple)): + return [] + return [dict(item) for item in items if isinstance(item, dict)] + + async def delete_many( + self, + keys: list[str], + *, + namespace: str | None = None, + ) -> int: + """批量删除多个记忆项。 + + 一次性删除多个键对应的记忆项,返回实际删除的数量。 + + Args: + keys: 要删除的记忆项键名列表 + + Returns: + 实际删除的记忆项数量 + + 示例: + deleted = await ctx.memory.delete_many(["old1", "old2", "old3"]) + print(f"删除了 {deleted} 条记忆") + """ + payload: dict[str, Any] = {"keys": keys} + payload["namespace"] = self._resolve_exact_namespace(namespace) + output = await self._proxy.call("memory.delete_many", payload) + return int(output.get("deleted_count", 0)) + + async def count( + self, + *, + namespace: str | None = None, + include_descendants: bool = False, + ) -> int: 
+ """Count memories in a namespace and optionally its descendants.""" + + payload: dict[str, Any] = { + "namespace": self._resolve_exact_namespace(namespace), + "include_descendants": bool(include_descendants), + } + output = await self._proxy.call("memory.count", payload) + return int(output.get("count", 0)) + + async def stats( + self, + *, + namespace: str | None = None, + include_descendants: bool = True, + ) -> dict[str, Any]: + """获取记忆系统统计信息。 + + 返回记忆系统的当前状态,包括条目数、索引状态和脏索引数量。 + + Returns: + 统计信息字典,包含: + - total_items: 总记忆条目数 + - total_bytes: 总占用字节数(可选) + - ttl_entries: 带过期时间的条目数(可选) + - indexed_items: 已建立检索索引的条目数(可选) + - embedded_items: 已生成向量的条目数(可选) + - dirty_items: 等待重建索引的条目数(可选) + + 示例: + stats = await ctx.memory.stats() + print(f"记忆库共有 {stats['total_items']} 条记录") + if "embedded_items" in stats: + print(f"其中 {stats['embedded_items']} 条已经向量化") + """ + payload: dict[str, Any] = { + "include_descendants": bool(include_descendants), + } + has_namespace, resolved_namespace = self._resolve_scope_namespace(namespace) + if has_namespace: + payload["namespace"] = resolved_namespace + output = await self._proxy.call("memory.stats", payload) + stats = { + "total_items": output.get("total_items", 0), + "total_bytes": output.get("total_bytes"), + } + if "namespace" in output: + stats["namespace"] = output.get("namespace") + if "namespace_count" in output: + stats["namespace_count"] = output.get("namespace_count") + if "fts_enabled" in output: + stats["fts_enabled"] = output.get("fts_enabled") + if "vector_backend" in output: + stats["vector_backend"] = output.get("vector_backend") + if "vector_indexes" in output: + stats["vector_indexes"] = output.get("vector_indexes") + if "plugin_id" in output: + stats["plugin_id"] = output.get("plugin_id") + if "ttl_entries" in output: + stats["ttl_entries"] = output.get("ttl_entries") + if "indexed_items" in output: + stats["indexed_items"] = output.get("indexed_items") + if "embedded_items" in output: + stats["embedded_items"] = 
output.get("embedded_items") + if "dirty_items" in output: + stats["dirty_items"] = output.get("dirty_items") + return stats diff --git a/astrbot-sdk/src/astrbot_sdk/clients/metadata.py b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py new file mode 100644 index 0000000000..92185b64e2 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/metadata.py @@ -0,0 +1,111 @@ +"""元数据客户端模块。 + +提供插件元数据查询能力。 + +功能说明: + - 查询已加载插件信息 + - 获取插件列表 + - 访问当前插件配置 + +安全边界: + 插件身份由运行时透传到协议层;客户端只暴露业务参数,不接受外部指定调用者。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from ._proxy import CapabilityProxy + + +@dataclass +class StarMetadata: + """插件元数据。""" + + name: str + display_name: str + description: str + author: str + version: str + enabled: bool = True + support_platforms: list[str] = field(default_factory=list) + astrbot_version: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> StarMetadata: + raw_support_platforms = data.get("support_platforms") + support_platforms = ( + [str(item) for item in raw_support_platforms if isinstance(item, str)] + if isinstance(raw_support_platforms, list) + else [] + ) + return cls( + name=str(data.get("name", "")), + display_name=str(data.get("display_name", data.get("name", ""))), + description=str(data.get("desc", data.get("description", ""))), + author=str(data.get("author", "")), + version=str(data.get("version", "0.0.0")), + enabled=bool(data.get("enabled", True)), + support_platforms=support_platforms, + astrbot_version=( + str(data.get("astrbot_version")) + if data.get("astrbot_version") is not None + else None + ), + ) + + +PluginMetadata = StarMetadata + + +class MetadataClient: + """元数据能力客户端。""" + + def __init__(self, proxy: CapabilityProxy, plugin_id: str) -> None: + self._proxy = proxy + self._plugin_id = plugin_id + + async def get_plugin(self, name: str) -> StarMetadata | None: + output = await self._proxy.call( + "metadata.get_plugin", + 
{"name": name}, + ) + data = output.get("plugin") + if data is None: + return None + return StarMetadata.from_dict(data) + + async def list_plugins(self) -> list[StarMetadata]: + output = await self._proxy.call("metadata.list_plugins", {}) + items = output.get("plugins", []) + return [ + StarMetadata.from_dict(item) for item in items if isinstance(item, dict) + ] + + async def get_current_plugin(self) -> StarMetadata | None: + return await self.get_plugin(self._plugin_id) + + async def get_plugin_config(self, name: str | None = None) -> dict[str, Any] | None: + target = name or self._plugin_id + if target != self._plugin_id: + raise PermissionError( + "get_plugin_config 只允许访问当前插件自己的配置," + f"请求的插件 '{target}' 不是当前插件 '{self._plugin_id}'" + ) + output = await self._proxy.call( + "metadata.get_plugin_config", + {"name": target}, + ) + config = output.get("config") + return dict(config) if isinstance(config, dict) else None + + async def save_plugin_config(self, config: dict[str, Any]) -> dict[str, Any]: + if not isinstance(config, dict): + raise TypeError("save_plugin_config requires a dict payload") + output = await self._proxy.call( + "metadata.save_plugin_config", + {"config": dict(config)}, + ) + saved = output.get("config") + return dict(saved) if isinstance(saved, dict) else {} diff --git a/astrbot-sdk/src/astrbot_sdk/clients/permission.py b/astrbot-sdk/src/astrbot_sdk/clients/permission.py new file mode 100644 index 0000000000..a5170d35e0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/permission.py @@ -0,0 +1,94 @@ +"""Permission capability clients.""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict + +from ._proxy import CapabilityProxy + + +class PermissionCheckResult(BaseModel): + model_config = ConfigDict(extra="forbid") + + is_admin: bool + role: Literal["member", "admin"] + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> PermissionCheckResult | 
None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PermissionClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def check( + self, + user_id: str, + session_id: str | None = None, + ) -> PermissionCheckResult: + payload: dict[str, Any] = {"user_id": str(user_id)} + if session_id is not None: + payload["session_id"] = str(session_id) + output = await self._proxy.call("permission.check", payload) + result = PermissionCheckResult.from_payload(output) + if result is None: + return PermissionCheckResult(is_admin=False, role="member") + return result + + async def get_admins(self) -> list[str]: + output = await self._proxy.call("permission.get_admins", {}) + admins = output.get("admins") + if not isinstance(admins, list): + return [] + return [str(item) for item in admins] + + +class PermissionManagerClient: + def __init__( + self, + proxy: CapabilityProxy, + *, + source_event_payload: dict[str, Any] | None = None, + ) -> None: + self._proxy = proxy + self._source_event_payload = ( + dict(source_event_payload) if isinstance(source_event_payload, dict) else {} + ) + + def _caller_is_admin(self) -> bool: + return bool(self._source_event_payload.get("is_admin", False)) + + async def add_admin(self, user_id: str) -> bool: + output = await self._proxy.call( + "permission.manager.add_admin", + { + "user_id": str(user_id), + "_caller_is_admin": self._caller_is_admin(), + }, + ) + return bool(output.get("changed", False)) + + async def remove_admin(self, user_id: str) -> bool: + output = await self._proxy.call( + "permission.manager.remove_admin", + { + "user_id": str(user_id), + "_caller_is_admin": self._caller_is_admin(), + }, + ) + return bool(output.get("changed", False)) + + +__all__ = [ + "PermissionCheckResult", + "PermissionClient", + "PermissionManagerClient", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/platform.py b/astrbot-sdk/src/astrbot_sdk/clients/platform.py new 
file mode 100644 index 0000000000..10142a7350 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/platform.py @@ -0,0 +1,300 @@ +"""平台客户端模块。 + +提供 v4 原生的平台能力调用。 + +设计边界: + - `PlatformClient` 只负责直接的平台 capability + - 迁移期消息桥接由独立迁移入口承接,不放进原生客户端 + - 富消息链通过 `platform.send_chain` 发送,链构建能力位于专门的消息模块 +""" + +from __future__ import annotations + +from collections.abc import Sequence +from enum import Enum +from typing import Any, cast + +from pydantic import BaseModel, ConfigDict, Field + +from ..message.components import BaseMessageComponent, Plain +from ..message.result import MessageChain +from ..message.session import MessageSession +from ..protocol.descriptors import SessionRef +from ._proxy import CapabilityProxy + + +class _PlatformModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class PlatformStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + ERROR = "error" + STOPPED = "stopped" + + @classmethod + def from_value(cls, value: Any) -> PlatformStatus: + if isinstance(value, cls): + return value + try: + return cls(str(value).strip().lower()) + except ValueError: + return cls.PENDING + + +class PlatformError(_PlatformModel): + message: str + timestamp: str + traceback: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PlatformError | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class PlatformStats(_PlatformModel): + id: str + type: str + display_name: str + status: PlatformStatus + started_at: str | None = None + error_count: int + last_error: PlatformError | None = None + unified_webhook: bool + meta: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> PlatformStats | None: + if not isinstance(payload, dict): + return None + normalized = dict(payload) + normalized["status"] = PlatformStatus.from_value(payload.get("status")) + normalized["last_error"] = 
PlatformError.from_payload( + payload.get("last_error") if isinstance(payload, dict) else None + ) + meta = payload.get("meta") + normalized["meta"] = dict(meta) if isinstance(meta, dict) else {} + return cls.model_validate(normalized) + + +class PlatformClient: + """平台消息客户端。 + + 提供向聊天平台发送消息和获取信息的能力。 + + Attributes: + _proxy: CapabilityProxy 实例,用于远程能力调用 + """ + + def __init__(self, proxy: CapabilityProxy) -> None: + """初始化平台客户端。 + + Args: + proxy: CapabilityProxy 实例 + """ + self._proxy = proxy + + def _build_target_payload( + self, + session: str | SessionRef | MessageSession, + ) -> tuple[str, dict[str, Any]]: + if isinstance(session, SessionRef): + return session.session, {"target": session.to_payload()} + if isinstance(session, MessageSession): + return str(session), {} + return str(session), {} + + async def _coerce_chain_payload( + self, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> list[dict[str, Any]]: + if isinstance(content, str): + return await MessageChain( + [Plain(content, convert=False)] + ).to_payload_async() + if isinstance(content, MessageChain): + return await content.to_payload_async() + if ( + isinstance(content, Sequence) + and not isinstance(content, (str, bytes)) + and all(isinstance(item, BaseMessageComponent) for item in content) + ): + components = cast(Sequence[BaseMessageComponent], content) + return await MessageChain(list(components)).to_payload_async() + if ( + isinstance(content, Sequence) + and not isinstance(content, (str, bytes)) + and all(isinstance(item, dict) for item in content) + ): + payload_items = cast(Sequence[dict[str, Any]], content) + return [dict(item) for item in payload_items] + raise TypeError( + "content must be str, MessageChain, sequence of message components, " + "or sequence of platform.send_chain payload dicts" + ) + + async def send( + self, + session: str | SessionRef | MessageSession, + text: str, + ) -> dict[str, Any]: + """发送文本消息。 + + 
向指定的会话(用户或群组)发送文本消息。 + + Args: + session: 统一消息来源标识 (UMO),格式如 "platform:instance:user_id" + text: 要发送的文本内容 + + Returns: + 发送结果,可能包含消息 ID 等信息 + + 示例: + # 发送消息到当前会话 + await ctx.platform.send(event.session_id, "收到您的消息!") + """ + session_id, extra = self._build_target_payload(session) + return await self._proxy.call( + "platform.send", + {"session": session_id, "text": text, **extra}, + ) + + async def send_image( + self, + session: str | SessionRef | MessageSession, + image_url: str, + ) -> dict[str, Any]: + """发送图片消息。 + + 向指定的会话发送图片,支持 URL 或本地路径。 + + Args: + session: 统一消息来源标识 (UMO) + image_url: 图片 URL 或本地文件路径 + + Returns: + 发送结果 + + 示例: + await ctx.platform.send_image( + event.session_id, + "https://example.com/image.png" + ) + """ + session_id, extra = self._build_target_payload(session) + return await self._proxy.call( + "platform.send_image", + {"session": session_id, "image_url": image_url, **extra}, + ) + + async def send_chain( + self, + session: str | SessionRef | MessageSession, + chain: MessageChain | Sequence[BaseMessageComponent] | Sequence[dict[str, Any]], + ) -> dict[str, Any]: + """发送富消息链。 + + Args: + session: 统一消息来源标识 (UMO) + chain: 序列化后的消息组件数组 + + Returns: + 发送结果 + """ + session_id, extra = self._build_target_payload(session) + chain_payload = await self._coerce_chain_payload(chain) + return await self._proxy.call( + "platform.send_chain", + {"session": session_id, "chain": chain_payload, **extra}, + ) + + async def send_by_session( + self, + session: str | MessageSession, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> dict[str, Any]: + """主动向指定会话发送消息链。 + + `Sequence[dict]` 的结构与 `platform.send_chain` 完全一致: + 每一项都应是 `{"type": "...", "data": {...}}`。 + """ + chain_payload = await self._coerce_chain_payload(content) + session_id = str(session) + return await self._proxy.call( + "platform.send_by_session", + {"session": session_id, "chain": chain_payload}, + ) + + async def send_by_id( + 
self, + platform_id: str, + session_id: str, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + *, + message_type: str = "private", + ) -> dict[str, Any]: + """主动向指定平台会话发送消息。""" + session = MessageSession( + platform_id=str(platform_id), + message_type=str(message_type), + session_id=str(session_id), + ) + return await self.send_by_session(session, content) + + async def get_members( + self, + session: str | SessionRef | MessageSession, + ) -> list[dict[str, Any]]: + """获取群组成员列表。 + + 获取指定群组的成员信息列表。注意仅对群组会话有效。 + + Args: + session: 群组会话的统一消息来源标识 (UMO) + + Returns: + 成员信息列表,每个成员是一个字典,可能包含: + - user_id: 用户 ID + - nickname: 昵称 + - role: 角色 (owner, admin, member) + + 示例: + members = await ctx.platform.get_members(event.session_id) + for member in members: + print(f"{member['nickname']} ({member['user_id']})") + """ + session_id, extra = self._build_target_payload(session) + output = await self._proxy.call( + "platform.get_members", + {"session": session_id, **extra}, + ) + members = output.get("members") + if not isinstance(members, (list, tuple)): + return [] + return list(members) + + +__all__ = [ + "PlatformClient", + "PlatformError", + "PlatformStats", + "PlatformStatus", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/provider.py b/astrbot-sdk/src/astrbot_sdk/clients/provider.py new file mode 100644 index 0000000000..20bf274c29 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/provider.py @@ -0,0 +1,349 @@ +"""Provider discovery and provider-management clients.""" + +from __future__ import annotations + +import asyncio +import contextlib +import inspect +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from ..llm.entities import ProviderMeta, ProviderType +from ..llm.providers import ( + ProviderProxy, + STTProvider, + TTSProvider, + provider_proxy_from_meta, +) +from ._proxy import CapabilityProxy + + +class 
_ProviderModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class ManagedProviderRecord(_ProviderModel): + id: str + model: str | None = None + type: str + provider_type: ProviderType + loaded: bool + enabled: bool + provider_source_id: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> ManagedProviderRecord | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ProviderChangeEvent(_ProviderModel): + provider_id: str + provider_type: ProviderType + umo: str | None = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any] | None, + ) -> ProviderChangeEvent | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ProviderClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + @staticmethod + def _provider_meta_list(items: Any) -> list[ProviderMeta]: + if not isinstance(items, list): + return [] + providers: list[ProviderMeta] = [] + for item in items: + if not isinstance(item, dict): + continue + provider = ProviderMeta.from_payload(item) + if provider is not None: + providers.append(provider) + return providers + + async def list_all(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_tts(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_tts", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_stt(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_stt", {}) + return self._provider_meta_list(output.get("providers")) + + async def list_embedding(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_embedding", {}) + return 
self._provider_meta_list(output.get("providers")) + + async def list_rerank(self) -> list[ProviderMeta]: + output = await self._proxy.call("provider.list_all_rerank", {}) + return self._provider_meta_list(output.get("providers")) + + async def _get_tts_support_stream(self, provider_id: str) -> bool: + output = await self._proxy.call( + "provider.tts.support_stream", + {"provider_id": str(provider_id)}, + ) + return bool(output.get("supported", False)) + + async def _build_proxy(self, meta: ProviderMeta | None) -> ProviderProxy | None: + if meta is None: + return None + tts_supports_stream = None + if meta.provider_type == ProviderType.TEXT_TO_SPEECH: + tts_supports_stream = await self._get_tts_support_stream(meta.id) + return provider_proxy_from_meta( + self._proxy, + meta, + tts_supports_stream=tts_supports_stream, + ) + + async def get(self, provider_id: str) -> ProviderProxy | None: + output = await self._proxy.call( + "provider.get_by_id", + {"provider_id": str(provider_id)}, + ) + return await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + + async def get_using_chat(self, umo: str | None = None) -> ProviderMeta | None: + output = await self._proxy.call("provider.get_using", {"umo": umo}) + return ProviderMeta.from_payload(output.get("provider")) + + async def get_using_tts(self, umo: str | None = None) -> TTSProvider | None: + output = await self._proxy.call("provider.get_using_tts", {"umo": umo}) + provider = await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + return provider if isinstance(provider, TTSProvider) else None + + async def get_using_stt(self, umo: str | None = None) -> STTProvider | None: + output = await self._proxy.call("provider.get_using_stt", {"umo": umo}) + provider = await self._build_proxy( + ProviderMeta.from_payload(output.get("provider")) + ) + return provider if isinstance(provider, STTProvider) else None + + +class ProviderManagerClient: + def __init__( + self, + proxy: 
CapabilityProxy, + *, + plugin_id: str | None = None, + logger: Any | None = None, + ) -> None: + self._proxy = proxy + self._plugin_id = plugin_id + self._logger = logger + self._change_hook_tasks: set[asyncio.Task[None]] = set() + + @staticmethod + def _provider_type_value(provider_type: ProviderType | str) -> str: + if isinstance(provider_type, ProviderType): + return provider_type.value + return str(provider_type).strip() + + @staticmethod + def _record_from_output(output: dict[str, Any]) -> ManagedProviderRecord | None: + return ManagedProviderRecord.from_payload(output.get("provider")) + + async def set_provider( + self, + provider_id: str, + provider_type: ProviderType | str, + umo: str | None = None, + ) -> None: + await self._proxy.call( + "provider.manager.set", + { + "provider_id": str(provider_id), + "provider_type": self._provider_type_value(provider_type), + "umo": umo, + }, + ) + + async def get_provider_by_id( + self, + provider_id: str, + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.get_by_id", + {"provider_id": str(provider_id)}, + ) + return self._record_from_output(output) + + async def get_merged_provider_config( + self, + provider_id: str, + ) -> dict[str, Any] | None: + output = await self._proxy.call( + "provider.manager.get_merged_provider_config", + {"provider_id": str(provider_id).strip()}, + ) + config = output.get("config") + return dict(config) if isinstance(config, dict) else None + + async def load_provider( + self, + provider_config: dict[str, Any], + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.load", + {"provider_config": dict(provider_config)}, + ) + return self._record_from_output(output) + + async def terminate_provider(self, provider_id: str) -> None: + await self._proxy.call( + "provider.manager.terminate", + {"provider_id": str(provider_id)}, + ) + + async def create_provider( + self, + provider_config: dict[str, Any], + ) -> 
ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.create", + {"provider_config": dict(provider_config)}, + ) + return self._record_from_output(output) + + async def update_provider( + self, + origin_provider_id: str, + new_config: dict[str, Any], + ) -> ManagedProviderRecord | None: + output = await self._proxy.call( + "provider.manager.update", + { + "origin_provider_id": str(origin_provider_id), + "new_config": dict(new_config), + }, + ) + return self._record_from_output(output) + + async def delete_provider( + self, + provider_id: str | None = None, + provider_source_id: str | None = None, + ) -> None: + await self._proxy.call( + "provider.manager.delete", + { + "provider_id": provider_id, + "provider_source_id": provider_source_id, + }, + ) + + async def get_insts(self) -> list[ManagedProviderRecord]: + output = await self._proxy.call("provider.manager.get_insts", {}) + items = output.get("providers") + if not isinstance(items, list): + return [] + return [ + record + for record in ( + ManagedProviderRecord.from_payload(item) + if isinstance(item, dict) + else None + for item in items + ) + if record is not None + ] + + async def watch_changes(self) -> AsyncIterator[ProviderChangeEvent]: + async for chunk in self._proxy.stream("provider.manager.watch_changes", {}): + event = ProviderChangeEvent.from_payload(chunk) + if event is not None: + yield event + + async def register_provider_change_hook( + self, + callback: Callable[ + [str, ProviderType, str | None], + Awaitable[None] | None, + ], + ) -> asyncio.Task[None]: + async def runner() -> None: + async for event in self.watch_changes(): + result = callback( + event.provider_id, + event.provider_type, + event.umo, + ) + if inspect.isawaitable(result): + await result + + task = asyncio.create_task(runner()) + self._change_hook_tasks.add(task) + task.add_done_callback(self._log_change_hook_result) + return task + + async def unregister_provider_change_hook( + self, + task: 
asyncio.Task[None], + ) -> None: + if task not in self._change_hook_tasks: + return + self._change_hook_tasks.discard(task) + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + def _log_change_hook_result(self, task: asyncio.Task[None]) -> None: + self._change_hook_tasks.discard(task) + if task.cancelled(): + debug_logger = getattr(self._logger, "debug", None) + if callable(debug_logger): + debug_logger( + "Provider change hook cancelled: plugin_id={}", + self._plugin_id, + ) + return + try: + task.result() + except asyncio.CancelledError: + debug_logger = getattr(self._logger, "debug", None) + if callable(debug_logger): + debug_logger( + "Provider change hook cancelled: plugin_id={}", + self._plugin_id, + ) + except Exception: + exception_logger = getattr(self._logger, "exception", None) + if callable(exception_logger): + exception_logger( + "Provider change hook failed: plugin_id={}", + self._plugin_id, + ) + + +__all__ = [ + "ManagedProviderRecord", + "ProviderChangeEvent", + "ProviderClient", + "ProviderManagerClient", +] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/registry.py b/astrbot-sdk/src/astrbot_sdk/clients/registry.py new file mode 100644 index 0000000000..5a468b0983 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/registry.py @@ -0,0 +1,126 @@ +"""只读 handler 注册表客户端。""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from ._proxy import CapabilityProxy + + +def _coerce_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +@dataclass(slots=True) +class HandlerMetadata: + plugin_name: str + handler_full_name: str + trigger_type: str + description: str | None = None + event_types: list[str] = field(default_factory=list) + enabled: bool = True + group_path: list[str] = field(default_factory=list) + priority: int = 0 + kind: str = "handler" + require_admin: bool = 
False + required_role: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> HandlerMetadata: + return cls( + plugin_name=str(data.get("plugin_name", "")), + handler_full_name=str(data.get("handler_full_name", "")), + trigger_type=str(data.get("trigger_type", "")), + description=( + None + if data.get("description") is None + else str(data.get("description", "")).strip() or None + ), + event_types=[ + str(item) + for item in data.get("event_types", []) + if isinstance(item, str) + ], + enabled=bool(data.get("enabled", True)), + group_path=[ + str(item) + for item in data.get("group_path", []) + if isinstance(item, str) + ], + priority=_coerce_int(data.get("priority", 0), 0), + kind=str(data.get("kind", "handler") or "handler"), + require_admin=bool(data.get("require_admin", False)), + required_role=( + None + if data.get("required_role") is None + else str(data.get("required_role", "")).strip() or None + ), + ) + + +class RegistryClient: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def get_handlers_by_event_type( + self, + event_type: str, + ) -> list[HandlerMetadata]: + output = await self._proxy.call( + "registry.get_handlers_by_event_type", + {"event_type": event_type}, + ) + return [ + HandlerMetadata.from_dict(item) + for item in output.get("handlers", []) + if isinstance(item, dict) + ] + + async def get_handler_by_full_name( + self, + full_name: str, + ) -> HandlerMetadata | None: + output = await self._proxy.call( + "registry.get_handler_by_full_name", + {"full_name": full_name}, + ) + handler = output.get("handler") + if not isinstance(handler, dict): + return None + return HandlerMetadata.from_dict(handler) + + async def set_handler_whitelist( + self, + plugin_names: list[str] | set[str] | None, + ) -> list[str] | None: + names = None + if plugin_names is not None: + names = sorted({str(item) for item in plugin_names if str(item).strip()}) + output = await self._proxy.call( + 
"system.event.handler_whitelist.set", + {"plugin_names": names}, + ) + result = output.get("plugin_names") + if not isinstance(result, list): + return None + return [str(item) for item in result] + + async def get_handler_whitelist(self) -> list[str] | None: + output = await self._proxy.call("system.event.handler_whitelist.get", {}) + result = output.get("plugin_names") + if not isinstance(result, list): + return None + return [str(item) for item in result] + + async def clear_handler_whitelist(self) -> None: + await self._proxy.call( + "system.event.handler_whitelist.set", + {"plugin_names": None}, + ) + + +__all__ = ["HandlerMetadata", "RegistryClient"] diff --git a/astrbot-sdk/src/astrbot_sdk/clients/session.py b/astrbot-sdk/src/astrbot_sdk/clients/session.py new file mode 100644 index 0000000000..c2901708cd --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/clients/session.py @@ -0,0 +1,135 @@ +"""Session-scoped SDK managers.""" + +from __future__ import annotations + +from typing import Any + +from ..events import MessageEvent +from ..message.session import MessageSession +from ._proxy import CapabilityProxy +from .registry import HandlerMetadata + + +def _normalize_session(session: str | MessageSession | MessageEvent) -> str: + if isinstance(session, MessageEvent): + return str(session.unified_msg_origin) + if isinstance(session, MessageSession): + return str(session) + return str(session) + + +def _handler_to_payload(handler: HandlerMetadata) -> dict[str, Any]: + return { + "plugin_name": handler.plugin_name, + "handler_full_name": handler.handler_full_name, + "trigger_type": handler.trigger_type, + "description": handler.description, + "event_types": list(handler.event_types), + "enabled": handler.enabled, + "group_path": list(handler.group_path), + "priority": handler.priority, + "kind": handler.kind, + "require_admin": handler.require_admin, + } + + +class SessionPluginManager: + """Session-scoped plugin status manager.""" + + def __init__(self, proxy: 
class SessionServiceManager:
    """Session-scoped LLM/TTS service status manager."""

    def __init__(self, proxy: CapabilityProxy) -> None:
        self._proxy = proxy

    async def _read_flag(
        self,
        capability: str,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        # All "is enabled" queries share this shape: normalize the session
        # reference, call the capability, coerce the answer to bool.
        output = await self._proxy.call(
            capability,
            {"session": _normalize_session(session)},
        )
        return bool(output.get("enabled", False))

    async def _write_flag(
        self,
        capability: str,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        # All "set status" updates share this shape as well.
        await self._proxy.call(
            capability,
            {"session": _normalize_session(session), "enabled": bool(enabled)},
        )

    async def is_llm_enabled_for_session(
        self,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Whether LLM processing is enabled for *session*."""
        return await self._read_flag("session.service.is_llm_enabled", session)

    async def set_llm_status_for_session(
        self,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        """Enable or disable LLM processing for *session*."""
        await self._write_flag("session.service.set_llm_status", session, enabled)

    async def should_process_llm_request(
        self,
        event_or_session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Alias of :meth:`is_llm_enabled_for_session`."""
        return await self.is_llm_enabled_for_session(event_or_session)

    async def is_tts_enabled_for_session(
        self,
        session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Whether TTS output is enabled for *session*."""
        return await self._read_flag("session.service.is_tts_enabled", session)

    async def set_tts_status_for_session(
        self,
        session: str | MessageSession | MessageEvent,
        enabled: bool,
    ) -> None:
        """Enable or disable TTS output for *session*."""
        await self._write_flag("session.service.set_tts_status", session, enabled)

    async def should_process_tts_request(
        self,
        event_or_session: str | MessageSession | MessageEvent,
    ) -> bool:
        """Alias of :meth:`is_tts_enabled_for_session`."""
        return await self.is_tts_enabled_for_session(event_or_session)
self._proxy.call("skill.list", {}) + return [ + SkillRegistration.from_dict(item) + for item in output.get("skills", []) + if isinstance(item, dict) + ] + + +__all__ = ["SkillClient", "SkillRegistration"] diff --git a/astrbot-sdk/src/astrbot_sdk/commands.py b/astrbot-sdk/src/astrbot_sdk/commands.py new file mode 100644 index 0000000000..0e90ab8302 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/commands.py @@ -0,0 +1,159 @@ +"""SDK-native command group helpers. + +本模块提供命令分组工具,用于组织具有层级关系的命令。 + +CommandGroup 允许以嵌套方式定义命令树,例如: + admin + ├── user + │ ├── add + │ └── remove + └── config + ├── get + └── set + +特性: +- 支持命令别名,自动展开父级路径的所有别名组合 +- 自动生成命令树的可视化输出 (print_cmd_tree) +- 与 @on_command 装饰器无缝集成 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from itertools import product + +from .decorators import on_command, set_command_route_meta +from .protocol.descriptors import CommandRouteSpec + + +@dataclass(slots=True) +class _CommandNode: + name: str + aliases: list[str] = field(default_factory=list) + description: str | None = None + subgroups: list[CommandGroup] = field(default_factory=list) + commands: list[tuple[str, str | None]] = field(default_factory=list) + + +class CommandGroup: + def __init__( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + parent: CommandGroup | None = None, + ) -> None: + self.name = name + self.aliases = list(aliases or []) + self.description = description + self.parent = parent + self._tree = _CommandNode( + name=name, aliases=self.aliases, description=description + ) + + def group( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + ) -> CommandGroup: + child = CommandGroup( + name, + aliases=aliases, + description=description, + parent=self, + ) + self._tree.subgroups.append(child) + return child + + def command( + self, + name: str, + *, + aliases: list[str] | None = None, + description: str | None = None, + 
): + full_command = " ".join([*self.path, name]) + full_aliases = self._expand_aliases(name=name, aliases=aliases or []) + display_command = full_command + route = CommandRouteSpec( + group_path=self.path, + display_command=display_command, + group_help=self.description, + ) + + def decorator(func): + decorated = on_command( + full_command, + aliases=full_aliases, + description=description, + )(func) + self._tree.commands.append((name, description)) + set_command_route_meta(decorated, route) + return decorated + + return decorator + + @property + def path(self) -> list[str]: + if self.parent is None: + return [self.name] + return [*self.parent.path, self.name] + + def print_cmd_tree(self) -> str: + lines: list[str] = [] + self._append_tree_lines(lines, indent=0) + return "\n".join(lines) + + def _append_tree_lines(self, lines: list[str], *, indent: int) -> None: + prefix = " " * indent + label = self.name + if self.aliases: + label += f" ({', '.join(self.aliases)})" + lines.append(f"{prefix}{label}") + for command_name, description in self._tree.commands: + command_label = f"{prefix} - {command_name}" + if description: + command_label += f": {description}" + lines.append(command_label) + for subgroup in self._tree.subgroups: + subgroup._append_tree_lines(lines, indent=indent + 1) + + def _expand_aliases(self, *, name: str, aliases: list[str]) -> list[str]: + group_segments: list[list[str]] = [] + cursor: CommandGroup | None = self + ancestry: list[CommandGroup] = [] + while cursor is not None: + ancestry.append(cursor) + cursor = cursor.parent + for group in reversed(ancestry): + group_segments.append([group.name, *group.aliases]) + leaf_segments = [name, *aliases] + expanded: set[str] = set() + for parts in product(*group_segments, leaf_segments): + route = " ".join(parts) + if route != " ".join([*self.path, name]): + expanded.add(route) + return sorted(expanded) + + +def command_group( + name: str, + *, + aliases: list[str] | None = None, + description: str | 
None = None, +) -> CommandGroup: + return CommandGroup( + name, + aliases=aliases, + description=description, + ) + + +def print_cmd_tree(group: CommandGroup) -> str: + return group.print_cmd_tree() + + +__all__ = ["CommandGroup", "command_group", "print_cmd_tree"] diff --git a/astrbot-sdk/src/astrbot_sdk/context.py b/astrbot-sdk/src/astrbot_sdk/context.py new file mode 100644 index 0000000000..b2f5d4cd95 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/context.py @@ -0,0 +1,750 @@ +"""v4 原生运行时上下文。 + +`Context` 是插件与 AstrBot Core 交互的主要入口, +负责组合所有 capability 客户端并提供统一的访问接口。 + +每个 handler 调用都会创建一个新的 Context 实例, +绑定到当前的 Peer、插件 ID 和取消令牌。 + +Attributes: + llm: LLM 能力客户端,用于 AI 对话 + memory: 记忆能力客户端,用于语义存储 + db: 数据库客户端,用于 KV 持久化 + files: 文件服务客户端,用于文件令牌注册与解析 + platform: 平台客户端,用于发送消息 + permission: 权限客户端,用于查询用户角色 + providers: Provider 客户端,用于查询和调用专用 Provider + provider_manager: Provider 管理客户端,用于 reserved/system 级操作 + permission_manager: 权限管理客户端,用于 reserved/system 级管理员维护 + personas: 人格管理客户端 + conversations: 对话管理客户端 + kbs: 知识库管理客户端 + message_history: 消息历史管理客户端 + http: HTTP 客户端,用于注册 API 端点 + metadata: 元数据客户端,用于查询插件信息 + mcp: MCP 管理客户端,用于本地/全局 MCP 服务管理 + skills: Skill 客户端,用于向 AstrBot 注册插件技能 + plugin_id: 当前插件的唯一标识 + logger: 绑定了插件 ID 的日志器 + cancel_token: 取消令牌,用于处理请求取消 +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from loguru import logger as base_logger + +from ._internal.plugin_logger import PluginLogger +from ._internal.star_runtime import current_star_instance +from ._message_types import normalize_message_type +from .clients import ( + DBClient, + HTTPClient, + LLMClient, + MCPManagerClient, + MemoryClient, + MetadataClient, + PermissionClient, + PermissionManagerClient, + PlatformClient, + PlatformError, + PlatformStats, + PlatformStatus, + RegistryClient, + SkillClient, +) +from .clients._proxy import 
@dataclass(slots=True)
class PlatformCompatFacade:
    """Compatibility-layer platform entry point.

    Exposes only safe platform metadata plus active message sending;
    all real work is delegated to the owning :class:`Context`'s proxy.
    """

    # Owning context; provides the capability proxy and platform client.
    _ctx: Context
    # Platform instance id (stable identifier used by capability calls).
    id: str
    # Display name of the platform instance.
    name: str
    # Platform adapter type (e.g. used for type-based lookup).
    type: str
    # Last known lifecycle status, refreshed via `refresh()`.
    status: PlatformStatus = PlatformStatus.PENDING
    # Recent error snapshots reported by the platform manager.
    errors: list[PlatformError] = field(default_factory=list)
    last_error: PlatformError | None = None
    unified_webhook: bool = False
    # Serializes refresh/clear_errors so snapshot updates don't interleave.
    _state_lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)

    async def send_by_session(
        self,
        session: str | MessageSession,
        content: PlatformCompatContent,
    ) -> dict[str, Any]:
        """Send *content* addressed by a unified session reference."""
        return await self._ctx.platform.send_by_session(session, content)

    async def send_by_id(
        self,
        session_id: str,
        content: PlatformCompatContent,
        *,
        message_type: str = "private",
    ) -> dict[str, Any]:
        """Send *content* to a raw session id on this platform instance."""
        return await self._ctx.platform.send_by_id(
            self.id,
            session_id,
            content,
            message_type=message_type,
        )

    async def send(
        self,
        session: str | MessageSession,
        content: PlatformCompatContent,
        *,
        message_type: str = "private",
    ) -> dict[str, Any]:
        """Dispatch to session- or id-based sending based on the argument.

        A ``MessageSession`` or a string containing ``:`` is treated as a
        unified session reference; anything else as a bare session id.
        """
        if isinstance(session, MessageSession):
            return await self.send_by_session(session, content)
        session_text = str(session).strip()
        if ":" in session_text:
            return await self.send_by_session(session_text, content)
        return await self.send_by_id(
            session_text,
            content,
            message_type=message_type,
        )

    async def refresh(self) -> None:
        """Re-fetch this platform's snapshot from the host."""
        async with self._state_lock:
            await self._refresh_locked()

    async def clear_errors(self) -> None:
        """Clear host-side error records, then refresh local state."""
        async with self._state_lock:
            await self._ctx._proxy.call(
                "platform.manager.clear_errors",
                {"platform_id": self.id},
            )
            await self._refresh_locked()

    async def get_stats(self) -> PlatformStats | None:
        """Fetch runtime statistics; ``None`` when unavailable."""
        output = await self._ctx._proxy.call(
            "platform.manager.get_stats",
            {"platform_id": self.id},
        )
        return PlatformStats.from_payload(output.get("stats"))

    def _apply_snapshot(self, payload: Any) -> None:
        # Merge a host snapshot into local fields; non-dict payloads are
        # ignored, and missing keys keep the current values.
        if not isinstance(payload, dict):
            return
        self.name = str(payload.get("name", self.name))
        self.type = str(payload.get("type", self.type))
        self.status = PlatformStatus.from_value(payload.get("status"))
        errors_payload = payload.get("errors")
        if isinstance(errors_payload, list):
            # Drop entries that fail to parse rather than failing the refresh.
            self.errors = [
                error
                for error in (
                    PlatformError.from_payload(item) if isinstance(item, dict) else None
                    for item in errors_payload
                )
                if error is not None
            ]
        self.last_error = PlatformError.from_payload(payload.get("last_error"))
        self.unified_webhook = bool(payload.get("unified_webhook", False))

    async def _refresh_locked(self) -> None:
        # Caller must hold _state_lock.
        output = await self._ctx._proxy.call(
            "platform.manager.get_by_id",
            {"platform_id": self.id},
        )
        self._apply_snapshot(output.get("platform"))
self._cancelled = asyncio.Event() + + def cancel(self) -> None: + """触发取消信号。""" + self._cancelled.set() + + @property + def cancelled(self) -> bool: + """检查是否已被取消。""" + return self._cancelled.is_set() + + async def wait(self) -> None: + """等待取消信号。""" + await self._cancelled.wait() + + def raise_if_cancelled(self) -> None: + """如果已取消则抛出 CancelledError。 + + Raises: + asyncio.CancelledError: 如果令牌已被取消 + """ + if self.cancelled: + raise asyncio.CancelledError + + +class Context: + """插件运行时上下文。 + + 组合所有 capability 客户端,提供统一的访问接口。 + 每个 handler 调用都会创建新的 Context 实例。 + + Attributes: + peer: 协议对等端,用于底层通信 + llm: LLM 客户端 + memory: 记忆客户端 + db: 数据库客户端 + files: 文件服务客户端 + platform: 平台客户端 + permission: 权限客户端 + providers: Provider 客户端 + provider_manager: Provider 管理客户端 + permission_manager: 权限管理客户端 + personas: 人格管理客户端 + conversations: 对话管理客户端 + kbs: 知识库管理客户端 + message_history: 消息历史管理客户端 + http: HTTP 客户端 + metadata: 元数据客户端 + registry: 能力注册客户端 + skills: 技能客户端 + session_plugins: 会话插件管理器 + session_services: 会话服务管理器 + mcp: MCP 管理客户端 + plugin_id: 当前插件 ID + logger: 日志器 + cancel_token: 取消令牌 + """ + + def __init__( + self, + *, + peer, + plugin_id: str, + request_id: str | None = None, + cancel_token: CancelToken | None = None, + logger: Any | None = None, + source_event_payload: dict[str, Any] | None = None, + ) -> None: + """初始化上下文。 + + Args: + peer: 协议对等端实例 + plugin_id: 当前插件 ID + cancel_token: 取消令牌,None 时创建新令牌 + logger: 日志器,None 时使用默认 logger 并绑定 plugin_id + """ + proxy = CapabilityProxy( + peer, + caller_plugin_id=plugin_id, + request_scope_id=request_id, + ) + if isinstance(logger, PluginLogger): + bound_logger = logger + else: + bound_logger = logger or base_logger.bind(plugin_id=plugin_id) + self._proxy = proxy + self.peer = peer + self.llm = LLMClient(proxy) + self.memory = MemoryClient(proxy) + self.db = DBClient(proxy) + self.files = FileServiceClient(proxy) + self.platform = PlatformClient(proxy) + self.permission = PermissionClient(proxy) + self.providers = ProviderClient(proxy) + 
self.provider_manager = ProviderManagerClient( + proxy, + plugin_id=plugin_id, + logger=bound_logger, + ) + self.permission_manager = PermissionManagerClient( + proxy, + source_event_payload=source_event_payload, + ) + self.personas = PersonaManagerClient(proxy) + self.conversations = ConversationManagerClient(proxy) + self.kbs = KnowledgeBaseManagerClient(proxy) + self.message_history = MessageHistoryManagerClient(proxy) + self.http = HTTPClient(proxy) + self.metadata = MetadataClient(proxy, plugin_id) + self.mcp = MCPManagerClient(proxy) + self.registry = RegistryClient(proxy) + self.skills = SkillClient(proxy) + self.session_plugins = SessionPluginManager(proxy) + self.session_services = SessionServiceManager(proxy) + self.persona_manager = self.personas + self.conversation_manager = self.conversations + self.kb_manager = self.kbs + self.message_history_manager = self.message_history + self.mcp_manager = self.mcp + self._llm_tool_manager = LLMToolManager(proxy) + self.plugin_id = plugin_id + self.logger: PluginLogger = ( + bound_logger + if isinstance(bound_logger, PluginLogger) + else PluginLogger(plugin_id=plugin_id, logger=bound_logger) + ) + self.cancel_token = cancel_token or CancelToken() + self.request_id = request_id + self._source_event_payload = ( + dict(source_event_payload) if isinstance(source_event_payload, dict) else {} + ) + + async def get_data_dir(self) -> Path: + """Return the plugin-scoped data directory path.""" + output = await self._proxy.call("system.get_data_dir", {}) + return Path(str(output.get("path", ""))) + + async def _register_file_url( + self, + path: str, + timeout: float | None = None, + ) -> str: + return await self.files._register_file_url(path, timeout=timeout) + + async def text_to_image( + self, + text: str, + *, + return_url: bool = True, + ) -> str: + """Render plain text into an image using the host renderer.""" + output = await self._proxy.call( + "system.text_to_image", + {"text": text, "return_url": return_url}, + ) 
+ return str(output.get("result", "")) + + async def html_render( + self, + tmpl: str, + data: dict[str, Any], + *, + return_url: bool = True, + options: dict[str, Any] | None = None, + ) -> str: + """Render an HTML template using the host renderer.""" + output = await self._proxy.call( + "system.html_render", + { + "tmpl": tmpl, + "data": dict(data), + "return_url": return_url, + "options": options, + }, + ) + return str(output.get("result", "")) + + async def get_using_provider(self, umo: str | None = None) -> ProviderMeta | None: + return await self.providers.get_using_chat(umo) + + async def get_current_chat_provider_id(self, umo: str | None = None) -> str | None: + output = await self._proxy.call( + "provider.get_current_chat_provider_id", + {"umo": umo}, + ) + value = output.get("provider_id") + return str(value) if value else None + + async def get_all_providers(self) -> list[ProviderMeta]: + return await self.providers.list_all() + + async def get_all_tts_providers(self) -> list[ProviderMeta]: + return await self.providers.list_tts() + + async def get_all_stt_providers(self) -> list[ProviderMeta]: + return await self.providers.list_stt() + + async def get_all_embedding_providers(self) -> list[ProviderMeta]: + return await self.providers.list_embedding() + + async def get_all_rerank_providers(self) -> list[ProviderMeta]: + return await self.providers.list_rerank() + + async def get_using_tts_provider( + self, umo: str | None = None + ) -> ProviderMeta | None: + provider = await self.providers.get_using_tts(umo) + return provider.meta() if provider is not None else None + + async def get_using_stt_provider( + self, umo: str | None = None + ) -> ProviderMeta | None: + provider = await self.providers.get_using_stt(umo) + return provider.meta() if provider is not None else None + + async def send_message( + self, + session: str | MessageSession, + content: PlatformCompatContent, + ) -> dict[str, Any]: + return await self.platform.send_by_session(session, 
content) + + async def send_message_by_id( + self, + type: str, + id: str, + content: PlatformCompatContent, + *, + platform: str, + ) -> dict[str, Any]: + platform_payload = await self._resolve_platform_target(platform) + return await self.platform.send_by_id( + str(platform_payload.get("id", "")), + str(id), + content, + message_type=self._normalize_compat_message_type(type), + ) + + @staticmethod + def _normalize_compat_message_type(value: str) -> str: + normalized = normalize_message_type(value) + if not normalized: + raise AstrBotError.invalid_input("send_message_by_id requires type") + return normalized + + async def _resolve_platform_target(self, platform: str) -> dict[str, Any]: + target = str(platform).strip() + if not target: + raise AstrBotError.invalid_input( + "send_message_by_id requires explicit platform" + ) + instances = await self._list_platform_instances() + id_matches = [ + item for item in instances if str(item.get("id", "")).strip() == target + ] + if len(id_matches) == 1: + return id_matches[0] + normalized_target = target.lower() + alias_matches = [ + item + for item in instances + if str(item.get("type", "")).strip().lower() == normalized_target + or str(item.get("name", "")).strip().lower() == normalized_target + ] + if len(alias_matches) == 1: + return alias_matches[0] + if len(alias_matches) > 1: + raise AstrBotError.invalid_input( + f"send_message_by_id platform '{target}' is ambiguous" + ) + raise AstrBotError.invalid_input( + f"send_message_by_id cannot resolve platform '{target}'" + ) + + def get_llm_tool_manager(self) -> LLMToolManager: + return self._llm_tool_manager + + async def activate_llm_tool(self, name: str) -> bool: + return await self._llm_tool_manager.activate(name) + + async def deactivate_llm_tool(self, name: str) -> bool: + return await self._llm_tool_manager.deactivate(name) + + async def add_llm_tools(self, *tools: LLMToolSpec) -> list[str]: + return await self._llm_tool_manager.add(*tools) + + async def 
register_llm_tool( + self, + name: str, + parameters_schema: dict[str, Any], + desc: str, + func_obj: Callable[..., Any] | Callable[..., Awaitable[Any]], + *, + active: bool = True, + ) -> list[str]: + if not callable(func_obj): + raise TypeError("register_llm_tool requires a callable func_obj") + tool_name = str(name).strip() + if not tool_name: + raise AstrBotError.invalid_input("register_llm_tool requires name") + if not isinstance(parameters_schema, dict): + raise TypeError("register_llm_tool requires parameters_schema dict") + + handler_ref = f"__dynamic_llm_tool__:{tool_name}" + tool_spec = LLMToolSpec.create( + name=tool_name, + description=str(desc), + parameters_schema=dict(parameters_schema), + handler_ref=handler_ref, + active=bool(active), + ) + owner = getattr(func_obj, "__self__", None) or current_star_instance() + dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None) + if dispatcher is not None and hasattr(dispatcher, "add_dynamic_llm_tool"): + dispatcher.add_dynamic_llm_tool( + plugin_id=self.plugin_id, + spec=tool_spec, + callable_obj=func_obj, + owner=owner, + ) + try: + return await self._llm_tool_manager.add(tool_spec) + except Exception: + if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"): + dispatcher.remove_llm_tool(self.plugin_id, tool_name) + raise + + async def unregister_llm_tool(self, name: str) -> bool: + removed = await self._llm_tool_manager.remove(str(name)) + dispatcher = getattr(self.peer, "_sdk_capability_dispatcher", None) + if dispatcher is not None and hasattr(dispatcher, "remove_llm_tool"): + dispatcher.remove_llm_tool(self.plugin_id, str(name)) + return removed + + async def register_skill( + self, + *, + name: str, + path: str | Path, + description: str = "", + ) -> SkillRegistration: + return await self.skills.register( + name=name, + path=str(path), + description=description, + ) + + async def unregister_skill(self, name: str) -> bool: + return await self.skills.unregister(name) + + async 
def tool_loop_agent( + self, + request: ProviderRequest | None = None, + **kwargs: Any, + ) -> LLMResponse: + provider_request = request or ProviderRequest() + if kwargs: + merged = provider_request.model_dump() + merged.update(kwargs) + provider_request = ProviderRequest.model_validate(merged) + payload = provider_request.to_payload() + target_payload = self._source_event_payload.get("target") + if isinstance(target_payload, dict): + # Preserve the original message target so core can recover the + # dispatch token for message-bound tool loop execution. + payload["target"] = dict(target_payload) + output = await self._proxy.call("agent.tool_loop.run", payload) + return LLMResponse.model_validate(output) + + def _source_event_type(self) -> str: + event_type = self._source_event_payload.get("event_type") + if isinstance(event_type, str) and event_type.strip(): + return event_type.strip() + fallback_type = self._source_event_payload.get("type") + if isinstance(fallback_type, str) and fallback_type.strip(): + return fallback_type.strip() + raw_payload = self._source_event_payload.get("raw") + if isinstance(raw_payload, dict): + raw_event_type = raw_payload.get("event_type") + if isinstance(raw_event_type, str) and raw_event_type.strip(): + return raw_event_type.strip() + return "" + + async def register_commands( + self, + command_name: str, + handler_full_name: str, + *, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ignore_prefix: bool = False, + ) -> None: + source_event_type = self._source_event_type() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if ignore_prefix: + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + if isinstance(priority, bool) or not isinstance(priority, int): + raise AstrBotError.invalid_input( + "register_commands 
priority must be an integer" + ) + await self._proxy.call( + "registry.command.register", + { + "command_name": str(command_name), + "handler_full_name": str(handler_full_name), + "source_event_type": source_event_type, + "desc": str(desc), + "priority": priority, + "use_regex": bool(use_regex), + "ignore_prefix": False, + }, + ) + + async def register_task( + self, + task: Awaitable[Any], + desc: str, + ) -> asyncio.Task[Any]: + """Register a background task owned by the current SDK context. + + This is the recommended way to launch follow-up work that should outlive + the current handler dispatch, including `session_waiter(...)` flows. + Directly awaiting a waiter inside the current handler keeps the original + dispatch open until the next message arrives. + + Example: + await event.reply("请输入用户名:") + await ctx.register_task( + self.collect_username(event), + "waiter:collect_username", + ) + """ + task_desc = str(desc) + + async def _wrap_future(future: asyncio.Future[Any]) -> Any: + return await future + + if isinstance(task, asyncio.Task): + background_task = task + elif asyncio.isfuture(task): + background_task = asyncio.create_task(_wrap_future(task)) + elif asyncio.iscoroutine(task): + background_task = asyncio.create_task(task) + else: + raise TypeError("register_task requires an awaitable task object") + + _mark_session_waiter_background_task(background_task) + + def _on_done(done_task: asyncio.Task[Any]) -> None: + _unmark_session_waiter_background_task(done_task) + if done_task.cancelled(): + debug_logger = getattr(self.logger, "debug", None) + if callable(debug_logger): + debug_logger( + "SDK background task cancelled: plugin_id={} desc={}", + self.plugin_id, + task_desc, + ) + return + try: + done_task.result() + except Exception: + exception_logger = getattr(self.logger, "exception", None) + if callable(exception_logger): + exception_logger( + "SDK background task failed: plugin_id={} desc={}", + self.plugin_id, + task_desc, + ) + + 
background_task.add_done_callback(_on_done) + return background_task + + async def _list_platform_instances(self) -> list[dict[str, Any]]: + output = await self._proxy.call("platform.list_instances", {}) + items = output.get("platforms") + if not isinstance(items, list): + return [] + normalized: list[dict[str, Any]] = [] + for item in items: + if not isinstance(item, dict): + continue + platform_id = str(item.get("id", "")).strip() + platform_type = str(item.get("type", "")).strip() + if not platform_id or not platform_type: + continue + normalized.append( + { + "id": platform_id, + "name": str(item.get("name", platform_id)), + "type": platform_type, + "status": PlatformStatus.from_value(item.get("status")), + } + ) + return normalized + + def _build_platform_facade( + self, + platform_payload: dict[str, Any], + ) -> PlatformCompatFacade: + return PlatformCompatFacade( + _ctx=self, + id=str(platform_payload.get("id", "")), + name=str(platform_payload.get("name", "")), + type=str(platform_payload.get("type", "")), + status=PlatformStatus.from_value(platform_payload.get("status")), + ) + + async def list_platforms(self) -> list[PlatformCompatFacade]: + """获取所有平台实例的兼容层列表。 + + Returns: + 所有可见平台实例的兼容层对象列表 + + Example: + for platform in await ctx.list_platforms(): + print(platform.id, platform.status) + """ + return [ + self._build_platform_facade(item) + for item in await self._list_platform_instances() + ] + + async def get_platform(self, platform_type: str) -> PlatformCompatFacade | None: + target_type = str(platform_type).strip().lower() + if not target_type: + return None + for item in await self._list_platform_instances(): + if str(item.get("type", "")).strip().lower() == target_type: + return self._build_platform_facade(item) + return None + + async def get_platform_inst(self, platform_id: str) -> PlatformCompatFacade | None: + target_id = str(platform_id).strip() + if not target_id: + return None + for item in await self._list_platform_instances(): + if 
@dataclass(slots=True)
class ConversationSession:
    """One interactive conversation bound to a single message session.

    Wraps a waiter-based prompt/response loop with explicit lifecycle
    states (see :class:`ConversationState`) and owner-task enforcement:
    the session may only be used from the task it was bound to.
    """

    # Runtime context used for sending messages.
    ctx: Context
    # The triggering event; replies and waits are keyed off it.
    event: MessageEvent
    # Waiter used to await the user's next message.
    waiter_manager: SessionWaiterManager
    # Default wait timeout in seconds for `ask`.
    timeout: int
    state: ConversationState = ConversationState.ACTIVE
    # Task allowed to drive this session; set via bind_owner_task.
    _owner_task: asyncio.Task[Any] | None = None

    def __post_init__(self) -> None:
        # A new session always starts ACTIVE regardless of the value
        # passed to the constructor.
        if self.state != ConversationState.ACTIVE:
            self.state = ConversationState.ACTIVE

    def bind_owner_task(self, task: asyncio.Task[Any]) -> None:
        """Restrict further use of this session to *task*."""
        self._owner_task = task

    @property
    def session_key(self) -> str:
        """Unified message origin identifying this session."""
        return self.event.unified_msg_origin

    @property
    def active(self) -> bool:
        """Whether the session is still usable."""
        return self.state == ConversationState.ACTIVE

    async def ask(self, prompt: str, timeout: int | None = None) -> MessageEvent:
        """Send *prompt* (if non-empty) and wait for the next user message.

        Raises:
            asyncio.TimeoutError: no reply arrived; session is closed TIMEOUT.
            ConversationReplaced: a newer session replaced this one.
            asyncio.CancelledError: cancelled for any other reason; session
                is closed CANCELLED.
        """
        self._ensure_usable("ask")
        if prompt:
            await self.reply(prompt)
        try:
            return await self.waiter_manager.wait_for_event(
                event=self.event,
                timeout=timeout or self.timeout,
                record_history_chains=False,
            )
        except asyncio.TimeoutError:
            self.close(ConversationState.TIMEOUT)
            raise
        except asyncio.CancelledError as exc:
            # REPLACED is set externally (mark_replaced) before cancelling;
            # translate that cancellation into a dedicated exception.
            if self.state == ConversationState.REPLACED:
                raise ConversationReplaced(
                    "conversation replaced by a newer session"
                ) from exc
            self.close(ConversationState.CANCELLED)
            raise

    async def reply(self, text: str) -> None:
        """Reply to the originating event with plain text."""
        self._ensure_usable("reply")
        await self.event.reply(text)

    async def reply_chain(
        self,
        chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]],
    ) -> None:
        """Reply to the originating event with a message chain."""
        self._ensure_usable("reply_chain")
        await self.event.reply_chain(chain)

    async def send_message(
        self,
        content: str | MessageChain | list[BaseMessageComponent] | list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Actively send *content* to the event's session."""
        self._ensure_usable("send_message")
        return await self.ctx.platform.send_by_session(self.event.session_id, content)

    def end(self) -> None:
        """Close the session normally."""
        self.close(ConversationState.COMPLETED)

    def mark_replaced(self) -> None:
        """Close the session because a newer one took over."""
        self.close(ConversationState.REPLACED)

    def close(self, state: ConversationState) -> None:
        """Transition to a terminal *state*.

        Once closed, the only permitted overwrite of the terminal state
        is a later REPLACED; all other transitions are ignored.
        """
        if self.state != ConversationState.ACTIVE and state == self.state:
            return
        if (
            self.state != ConversationState.ACTIVE
            and state != ConversationState.REPLACED
        ):
            return
        self.state = state

    def _ensure_usable(self, action: str) -> None:
        # Guard every public operation: reject use from a foreign task
        # and use after the session has been closed.
        if (
            self._owner_task is not None
            and asyncio.current_task() is not self._owner_task
        ):
            raise ConversationClosed(
                f"ConversationSession cannot be used outside its owner task during {action}"
            )
        if not self.active:
            raise ConversationClosed(
                f"ConversationSession is already closed ({self.state.value}) during {action}"
            )
+1,1246 @@ +"""v4 原生装饰器。 + +提供声明式的方法来注册 handler 和 capability。 +装饰器会在方法上附加元数据,由 Star.__init_subclass__ 自动收集。 + +触发器装饰器: + - @on_command: 命令触发器 + - @on_message: 消息触发器(关键词/正则) + - @on_event: 事件触发器 + - @on_schedule: 定时任务触发器 + - @conversation_command: 带会话生命周期的命令触发器 + +权限与过滤装饰器: + - @require_admin / @admin_only: 管理员权限标记 + - @require_permission: 通用角色权限标记 + - @platforms: 限定平台 + - @group_only / @private_only: 群聊/私聊限定 + - @message_types: 消息类型过滤 + +限流装饰器: + - @rate_limit: 滑动窗口限流 + - @cooldown: 冷却时间 + +优先级装饰器: + - @priority: 设置执行优先级 + +能力导出装饰器: + - @provide_capability: 声明对外暴露的能力 + - @register_llm_tool: 注册 LLM 工具 + - @register_agent: 注册 Agent + +Example: + class MyPlugin(Star): + @on_command("hello", aliases=["hi"]) + async def hello(self, event: MessageEvent, ctx: Context): + await event.reply("Hello!") + + @on_message(keywords=["help"]) + async def help(self, event: MessageEvent, ctx: Context): + await event.reply("Help info...") + + @provide_capability("my_plugin.calculate", description="计算") + async def calculate(self, payload: dict, ctx: Context): + return {"result": payload["x"] * 2} +""" + +from __future__ import annotations + +import inspect +import typing +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal, cast + +from pydantic import BaseModel + +from ._internal.typing_utils import unwrap_optional +from .llm.agents import AgentSpec, BaseAgentRunner +from .llm.entities import LLMToolSpec +from .protocol.descriptors import ( + RESERVED_CAPABILITY_PREFIXES, + CapabilityDescriptor, + CommandRouteSpec, + CommandTrigger, + EventTrigger, + FilterSpec, + MessageTrigger, + MessageTypeFilterSpec, + Permissions, + PlatformFilterSpec, + ScheduleTrigger, +) + +HandlerCallable = Callable[..., Any] +HANDLER_META_ATTR = "__astrbot_handler_meta__" +CAPABILITY_META_ATTR = "__astrbot_capability_meta__" +LLM_TOOL_META_ATTR = "__astrbot_llm_tool_meta__" +AGENT_META_ATTR = "__astrbot_agent_meta__" +HTTP_API_META_ATTR = 
"__astrbot_http_api_meta__" +VALIDATE_CONFIG_META_ATTR = "__astrbot_validate_config_meta__" +PROVIDER_CHANGE_META_ATTR = "__astrbot_provider_change_meta__" +BACKGROUND_TASK_META_ATTR = "__astrbot_background_task_meta__" +MCP_SERVER_META_ATTR = "__astrbot_mcp_server_meta__" +SKILL_META_ATTR = "__astrbot_skill_meta__" + +LimiterScope = Literal["session", "user", "group", "global"] +LimiterBehavior = Literal["hint", "silent", "error"] +ConversationMode = Literal["replace", "reject"] + + +@dataclass(slots=True) +class LimiterMeta: + kind: Literal["rate_limit", "cooldown"] + limit: int + window: float + scope: LimiterScope = "session" + behavior: LimiterBehavior = "hint" + message: str | None = None + + +@dataclass(slots=True) +class ConversationMeta: + timeout: int = 60 + mode: ConversationMode = "replace" + busy_message: str | None = None + grace_period: float = 1.0 + + +@dataclass(slots=True) +class HandlerMeta: + """Handler 元数据。 + + 存储在方法上的 __astrbot_handler_meta__ 属性中。 + + Attributes: + trigger: 触发器(命令/消息/事件/定时) + kind: handler 类型标识 + contract: 契约类型(可选) + priority: 执行优先级(数值越大越先执行) + permissions: 权限要求 + """ + + trigger: CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger | None = ( + None + ) + kind: str = "handler" + contract: str | None = None + description: str | None = None + priority: int = 0 + permissions: Permissions = field(default_factory=Permissions) + filters: list[FilterSpec] = field(default_factory=list) + local_filters: list[Any] = field(default_factory=list) + command_route: CommandRouteSpec | None = None + limiter: LimiterMeta | None = None + conversation: ConversationMeta | None = None + decorator_sources: dict[str, str] = field(default_factory=dict) + + +@dataclass(slots=True) +class CapabilityMeta: + """Capability 元数据。 + + 存储在方法上的 __astrbot_capability_meta__ 属性中。 + + Attributes: + descriptor: 能力描述符 + """ + + descriptor: CapabilityDescriptor + + +@dataclass(slots=True) +class LLMToolMeta: + spec: LLMToolSpec + + +@dataclass(slots=True) 
+class AgentMeta: + spec: AgentSpec + + +@dataclass(slots=True) +class HttpApiMeta: + route: str + methods: list[str] = field(default_factory=lambda: ["GET"]) + description: str = "" + capability_name: str | None = None + + +@dataclass(slots=True) +class ValidateConfigMeta: + model: type[BaseModel] | None = None + schema: dict[str, Any] | None = None + + +def _is_valid_validate_config_expected_type(value: Any) -> bool: + if isinstance(value, type): + return True + return ( + isinstance(value, tuple) + and len(value) > 0 + and all(isinstance(item, type) for item in value) + ) + + +def _validate_validate_config_schema(schema: dict[str, Any]) -> None: + for field_name, field_schema in schema.items(): + if not isinstance(field_schema, dict): + raise TypeError( + f"validate_config schema field {field_name!r} must be a dict" + ) + expected_type = field_schema.get("type") + if expected_type is not None and not _is_valid_validate_config_expected_type( + expected_type + ): + raise TypeError( + "validate_config schema field " + f"{field_name!r} has invalid 'type' entry {expected_type!r}; " + "expected a type or tuple of types" + ) + + +@dataclass(slots=True) +class ProviderChangeMeta: + provider_types: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class BackgroundTaskMeta: + description: str = "" + auto_start: bool = True + on_error: Literal["log", "restart"] = "log" + + +@dataclass(slots=True) +class MCPServerMeta: + name: str + scope: Literal["local", "global"] = "global" + config: dict[str, Any] | None = None + timeout: float = 30.0 + wait_until_ready: bool = True + + +@dataclass(slots=True) +class SkillMeta: + name: str + path: str + description: str = "" + + +def _get_or_create_meta(func: HandlerCallable) -> HandlerMeta: + """获取或创建 handler 元数据。""" + meta = getattr(func, HANDLER_META_ATTR, None) + if meta is None: + meta = HandlerMeta() + setattr(func, HANDLER_META_ATTR, meta) + return meta + + +def get_handler_meta(func: HandlerCallable) -> 
HandlerMeta | None: + """获取方法的 handler 元数据。 + + Args: + func: 要检查的方法 + + Returns: + HandlerMeta 实例,如果没有则返回 None + """ + return getattr(func, HANDLER_META_ATTR, None) + + +def get_capability_meta(func: HandlerCallable) -> CapabilityMeta | None: + """获取方法的 capability 元数据。 + + Args: + func: 要检查的方法 + + Returns: + CapabilityMeta 实例,如果没有则返回 None + """ + return getattr(func, CAPABILITY_META_ATTR, None) + + +def get_llm_tool_meta(func: HandlerCallable) -> LLMToolMeta | None: + return getattr(func, LLM_TOOL_META_ATTR, None) + + +def get_agent_meta(obj: Any) -> AgentMeta | None: + return getattr(obj, AGENT_META_ATTR, None) + + +def get_http_api_meta(func: HandlerCallable) -> HttpApiMeta | None: + return getattr(func, HTTP_API_META_ATTR, None) + + +def get_validate_config_meta(func: HandlerCallable) -> ValidateConfigMeta | None: + return getattr(func, VALIDATE_CONFIG_META_ATTR, None) + + +def get_provider_change_meta(func: HandlerCallable) -> ProviderChangeMeta | None: + return getattr(func, PROVIDER_CHANGE_META_ATTR, None) + + +def get_background_task_meta(func: HandlerCallable) -> BackgroundTaskMeta | None: + return getattr(func, BACKGROUND_TASK_META_ATTR, None) + + +def get_mcp_server_meta(obj: Any) -> list[MCPServerMeta]: + values = getattr(obj, MCP_SERVER_META_ATTR, None) + if not isinstance(values, list): + return [] + return [item for item in values if isinstance(item, MCPServerMeta)] + + +def get_skill_meta(obj: Any) -> list[SkillMeta]: + values = getattr(obj, SKILL_META_ATTR, None) + if not isinstance(values, list): + return [] + return [item for item in values if isinstance(item, SkillMeta)] + + +def _append_list_meta(obj: Any, attr_name: str, value: Any) -> None: + values = getattr(obj, attr_name, None) + if not isinstance(values, list): + values = [] + setattr(obj, attr_name, values) + values.append(value) + + +def _replace_filter(meta: HandlerMeta, spec: FilterSpec) -> None: + kind = getattr(spec, "kind", None) + meta.filters = [ + item for item in meta.filters 
if getattr(item, "kind", None) != kind + ] + meta.filters.append(spec) + + +def _has_filter_kind(meta: HandlerMeta, kind: str) -> bool: + return any(getattr(item, "kind", None) == kind for item in meta.filters) + + +def _set_platform_filter( + meta: HandlerMeta, + values: list[str], + *, + source: str, +) -> None: + normalized = [ + value for value in dict.fromkeys(str(item).strip() for item in values) if value + ] + if not normalized: + return + existing = meta.decorator_sources.get("platforms") + if existing is not None and existing != source: + raise ValueError("platforms(...) 不能与 on_message(platforms=...) 混用") + if existing is None and _has_filter_kind(meta, "platform"): + raise ValueError("platforms(...) 不能与已有平台过滤器混用") + meta.decorator_sources["platforms"] = source + _replace_filter(meta, PlatformFilterSpec(platforms=normalized)) + + +def _set_message_type_filter( + meta: HandlerMeta, + values: list[str], + *, + source: str, +) -> None: + normalized = [ + value + for value in dict.fromkeys(str(item).strip().lower() for item in values) + if value + ] + if not normalized: + return + existing = meta.decorator_sources.get("message_types") + if existing is not None and existing != source: + raise ValueError( + "group_only()/private_only()/message_types(...) 不能与已有消息类型约束混用" + ) + if existing is None and _has_filter_kind(meta, "message_type"): + raise ValueError( + "group_only()/private_only()/message_types(...) 不能与已有消息类型过滤器混用" + ) + meta.decorator_sources["message_types"] = source + _replace_filter(meta, MessageTypeFilterSpec(message_types=normalized)) + + +def _validate_message_trigger_compatibility(meta: HandlerMeta) -> None: + if meta.limiter is None or meta.trigger is None: + return + trigger_type = getattr(meta.trigger, "type", None) + if trigger_type not in {"command", "message"}: + raise ValueError( + "rate_limit(...) 和 cooldown(...) 
只适用于 on_command/on_message" + ) + + +def _set_required_role( + meta: HandlerMeta, + role: Literal["member", "admin"], +) -> None: + current = meta.permissions.required_role + if current is not None and current != role: + raise ValueError( + f"require_permission({role!r}) 与已有权限要求 {current!r} 冲突" + ) + meta.permissions.required_role = role + meta.permissions.require_admin = role == "admin" + + +def _normalize_description(description: str | None) -> str | None: + if description is None: + return None + text = str(description).strip() + return text or None + + +def _validate_limiter_args( + *, + kind: str, + limit: int, + window: float, + scope: LimiterScope, + behavior: LimiterBehavior, +) -> None: + if isinstance(limit, bool) or int(limit) <= 0: + raise ValueError(f"{kind} requires a positive limit") + if float(window) <= 0: + raise ValueError(f"{kind} requires a positive window") + if scope not in {"session", "user", "group", "global"}: + raise ValueError(f"unsupported limiter scope: {scope}") + if behavior not in {"hint", "silent", "error"}: + raise ValueError(f"unsupported limiter behavior: {behavior}") + + +def _set_limiter( + func: HandlerCallable, + limiter: LimiterMeta, +) -> HandlerCallable: + meta = _get_or_create_meta(func) + if meta.limiter is not None: + raise ValueError("rate_limit(...) 和 cooldown(...) 
不能叠加在同一个 handler 上") + meta.limiter = limiter + _validate_message_trigger_compatibility(meta) + return func + + +def _model_to_schema( + model: type[BaseModel] | None, + *, + label: str, +) -> dict[str, Any] | None: + """将 pydantic 模型转换为 JSON Schema。 + + Args: + model: pydantic BaseModel 子类 + label: 错误消息中的字段名 + + Returns: + JSON Schema 字典,如果 model 为 None 则返回 None + + Raises: + TypeError: 如果 model 不是 BaseModel 子类 + """ + if model is None: + return None + if not isinstance(model, type) or not issubclass(model, BaseModel): + raise TypeError(f"{label} 必须是 pydantic BaseModel 子类") + return cast(dict[str, Any], model.model_json_schema()) + + +def on_command( + command: str | typing.Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + """注册命令处理方法。 + + 当用户发送指定命令时触发。命令格式为 `/{command}` 或直接 `{command}`, + 取决于平台配置。 + + Args: + command: 命令名称(不包含前缀符) + aliases: 命令别名列表 + description: 命令描述,用于帮助信息 + + Returns: + 装饰器函数 + + Example: + @on_command("echo", aliases=["repeat"], description="重复消息") + async def echo(self, event: MessageEvent, ctx: Context): + await event.reply(event.text) + """ + + commands = ( + [str(command).strip()] + if isinstance(command, str) + else [str(item).strip() for item in command] + ) + commands = [item for item in commands if item] + if not commands: + raise ValueError("on_command requires at least one non-empty command name") + canonical = commands[0] + merged_aliases: list[str] = [ + item + for item in dict.fromkeys([*commands[1:], *(aliases or [])]) + if isinstance(item, str) and item and item != canonical + ] + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + normalized_description = _normalize_description(description) + meta.trigger = CommandTrigger( + command=canonical, + aliases=merged_aliases, + description=normalized_description, + ) + meta.description = normalized_description + 
_validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def on_message( + *, + regex: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + message_types: list[str] | None = None, + description: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + """注册消息处理方法。 + + 当消息匹配指定条件时触发。支持正则表达式或关键词匹配。 + + Args: + regex: 正则表达式模式 + keywords: 关键词列表(任一匹配即可) + platforms: 限定平台列表(如 ["qq", "wechat"]) + + Returns: + 装饰器函数 + + Note: + regex 和 keywords 至少提供一个 + + Example: + @on_message(keywords=["help", "帮助"]) + async def help(self, event: MessageEvent, ctx: Context): + await event.reply("帮助信息") + + @on_message(regex=r"\\d+") # 匹配数字 + async def number_handler(self, event: MessageEvent, ctx: Context): + await event.reply("收到了数字") + """ + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + meta.trigger = MessageTrigger( + regex=regex, + keywords=keywords or [], + platforms=platforms or [], + message_types=message_types or [], + ) + meta.description = _normalize_description(description) + if platforms: + _set_platform_filter(meta, list(platforms), source="trigger.platforms") + if message_types: + _set_message_type_filter( + meta, + list(message_types), + source="trigger.message_types", + ) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def append_filter_meta( + func: HandlerCallable, + *, + specs: list[FilterSpec] | None = None, + local_bindings: list[Any] | None = None, +) -> HandlerCallable: + """追加过滤器元数据。""" + meta = _get_or_create_meta(func) + if specs: + meta.filters.extend(specs) + if local_bindings: + meta.local_filters.extend(local_bindings) + return func + + +def set_command_route_meta( + func: HandlerCallable, + route: CommandRouteSpec, +) -> HandlerCallable: + """设置命令路由元数据。""" + meta = _get_or_create_meta(func) + meta.command_route = route + return func + + +def on_event( + event_type: str, + *, + 
description: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + """注册事件处理方法。 + + 当特定类型的事件发生时触发。用于处理非消息类型的事件, + 如群成员变动、好友请求等。 + + Args: + event_type: 事件类型标识 + + Returns: + 装饰器函数 + + Example: + @on_event("group_member_join") + async def on_join(self, event, ctx): + await ctx.platform.send(event.group_id, "欢迎新人!") + """ + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + meta.trigger = EventTrigger(event_type=event_type) + meta.description = _normalize_description(description) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def on_schedule( + *, + cron: str | None = None, + interval_seconds: int | None = None, + description: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + """注册定时任务方法。 + + 按指定的时间计划定期执行。 + + Args: + cron: cron 表达式(如 "0 8 * * *" 表示每天 8:00) + interval_seconds: 执行间隔(秒) + + Returns: + 装饰器函数 + + Note: + cron 和 interval_seconds 至少提供一个 + + Example: + @on_schedule(cron="0 8 * * *") # 每天 8:00 + async def morning_greeting(self, ctx): + await ctx.platform.send("group_123", "早上好!") + + @on_schedule(interval_seconds=3600) # 每小时 + async def hourly_check(self, ctx): + pass + """ + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + meta.trigger = ScheduleTrigger(cron=cron, interval_seconds=interval_seconds) + meta.description = _normalize_description(description) + _validate_message_trigger_compatibility(meta) + return func + + return decorator + + +def http_api( + route: str, + *, + methods: list[str] | None = None, + description: str = "", + capability_name: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + normalized_route = str(route).strip() + if not normalized_route: + raise ValueError("http_api(...) 
requires a non-empty route") + normalized_methods = methods or ["GET"] + normalized_methods = [ + str(item).strip().upper() for item in normalized_methods if str(item).strip() + ] + if not normalized_methods: + raise ValueError("http_api(...) requires at least one HTTP method") + + def decorator(func: HandlerCallable) -> HandlerCallable: + setattr( + func, + HTTP_API_META_ATTR, + HttpApiMeta( + route=normalized_route, + methods=normalized_methods, + description=str(description), + capability_name=( + str(capability_name).strip() + if capability_name is not None + else None + ), + ), + ) + return func + + return decorator + + +def validate_config( + *, + model: type[BaseModel] | None = None, + schema: dict[str, Any] | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + if model is None and schema is None: + raise ValueError("validate_config(...) requires model or schema") + if model is not None and schema is not None: + raise ValueError("validate_config(...) cannot accept model and schema together") + if model is not None and ( + not isinstance(model, type) or not issubclass(model, BaseModel) + ): + raise TypeError("validate_config model must be a pydantic BaseModel subclass") + if schema is not None and not isinstance(schema, dict): + raise TypeError("validate_config schema must be a dict") + if isinstance(schema, dict): + _validate_validate_config_schema(schema) + + def decorator(func: HandlerCallable) -> HandlerCallable: + setattr( + func, + VALIDATE_CONFIG_META_ATTR, + ValidateConfigMeta( + model=model, + schema=dict(schema) if isinstance(schema, dict) else None, + ), + ) + return func + + return decorator + + +def on_provider_change( + *, + provider_types: list[str] | tuple[str, ...] 
| None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + normalized = [ + str(item).strip().lower() + for item in (provider_types or []) + if str(item).strip() + ] + + def decorator(func: HandlerCallable) -> HandlerCallable: + setattr( + func, + PROVIDER_CHANGE_META_ATTR, + ProviderChangeMeta(provider_types=normalized), + ) + return func + + return decorator + + +def background_task( + *, + description: str = "", + auto_start: bool = True, + on_error: Literal["log", "restart"] = "log", +) -> Callable[[HandlerCallable], HandlerCallable]: + if on_error not in {"log", "restart"}: + raise ValueError("background_task on_error must be 'log' or 'restart'") + + def decorator(func: HandlerCallable) -> HandlerCallable: + setattr( + func, + BACKGROUND_TASK_META_ATTR, + BackgroundTaskMeta( + description=str(description), + auto_start=bool(auto_start), + on_error=on_error, + ), + ) + return func + + return decorator + + +def mcp_server( + *, + name: str, + scope: Literal["local", "global"] = "global", + config: dict[str, Any] | None = None, + timeout: float = 30.0, + wait_until_ready: bool = True, +): + normalized_name = str(name).strip() + if not normalized_name: + raise ValueError("mcp_server(...) 
requires a non-empty name") + if scope not in {"local", "global"}: + raise ValueError("mcp_server scope must be 'local' or 'global'") + if config is not None and not isinstance(config, dict): + raise TypeError("mcp_server config must be a dict") + if float(timeout) <= 0: + raise ValueError("mcp_server timeout must be positive") + + meta = MCPServerMeta( + name=normalized_name, + scope=scope, + config=dict(config) if isinstance(config, dict) else None, + timeout=float(timeout), + wait_until_ready=bool(wait_until_ready), + ) + + def decorator(target): + _append_list_meta(target, MCP_SERVER_META_ATTR, meta) + return target + + return decorator + + +def register_skill( + *, + name: str, + path: str, + description: str = "", +): + normalized_name = str(name).strip() + normalized_path = str(path).strip() + if not normalized_name: + raise ValueError("register_skill(...) requires a non-empty name") + if not normalized_path: + raise ValueError("register_skill(...) requires a non-empty path") + + meta = SkillMeta( + name=normalized_name, + path=normalized_path, + description=str(description), + ) + + def decorator(target): + _append_list_meta(target, SKILL_META_ATTR, meta) + return target + + return decorator + + +def require_admin(func: HandlerCallable) -> HandlerCallable: + """标记 handler 需要管理员权限。 + + 当用户不是管理员时,handler 将不会被调用。 + + Args: + func: 要标记的方法 + + Returns: + 标记后的方法 + + Example: + @on_command("admin") + @require_admin + async def admin_only(self, event: MessageEvent, ctx: Context): + await event.reply("管理员命令执行成功") + """ + meta = _get_or_create_meta(func) + _set_required_role(meta, "admin") + return func + + +def admin_only(func: HandlerCallable) -> HandlerCallable: + return require_admin(func) + + +def require_permission( + role: Literal["member", "admin"], +) -> Callable[[HandlerCallable], HandlerCallable]: + normalized_role = str(role).strip().lower() + if normalized_role not in {"member", "admin"}: + raise ValueError("require_permission(...) 
只支持 'member' 或 'admin'") + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + _set_required_role( + meta, + cast(Literal["member", "admin"], normalized_role), + ) + return func + + return decorator + + +def platforms(*names: str) -> Callable[[HandlerCallable], HandlerCallable]: + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + _set_platform_filter(meta, list(names), source="decorator.platforms") + return func + + return decorator + + +def message_types(*types: str) -> Callable[[HandlerCallable], HandlerCallable]: + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + _set_message_type_filter( + meta, + list(types), + source="decorator.message_types", + ) + return func + + return decorator + + +def group_only() -> Callable[[HandlerCallable], HandlerCallable]: + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + _set_message_type_filter(meta, ["group"], source="decorator.group_only") + return func + + return decorator + + +def private_only() -> Callable[[HandlerCallable], HandlerCallable]: + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + _set_message_type_filter(meta, ["private"], source="decorator.private_only") + return func + + return decorator + + +def priority(value: int) -> Callable[[HandlerCallable], HandlerCallable]: + if isinstance(value, bool) or not isinstance(value, int): + raise ValueError("priority(...) 
requires an integer") + + def decorator(func: HandlerCallable) -> HandlerCallable: + meta = _get_or_create_meta(func) + meta.priority = value + return func + + return decorator + + +def rate_limit( + limit: int, + window: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + _validate_limiter_args( + kind="rate_limit", + limit=limit, + window=window, + scope=scope, + behavior=behavior, + ) + + def decorator(func: HandlerCallable) -> HandlerCallable: + return _set_limiter( + func, + LimiterMeta( + kind="rate_limit", + limit=int(limit), + window=float(window), + scope=scope, + behavior=behavior, + message=message, + ), + ) + + return decorator + + +def cooldown( + seconds: float, + *, + scope: LimiterScope = "session", + behavior: LimiterBehavior = "hint", + message: str | None = None, +) -> Callable[[HandlerCallable], HandlerCallable]: + _validate_limiter_args( + kind="cooldown", + limit=1, + window=seconds, + scope=scope, + behavior=behavior, + ) + + def decorator(func: HandlerCallable) -> HandlerCallable: + return _set_limiter( + func, + LimiterMeta( + kind="cooldown", + limit=1, + window=float(seconds), + scope=scope, + behavior=behavior, + message=message, + ), + ) + + return decorator + + +def conversation_command( + command: str | typing.Sequence[str], + *, + aliases: list[str] | None = None, + description: str | None = None, + timeout: int = 60, + mode: ConversationMode = "replace", + busy_message: str | None = None, + grace_period: float = 1.0, +) -> Callable[[HandlerCallable], HandlerCallable]: + """注册带会话生命周期的命令处理方法。 + + 在 ``on_command`` 基础上附加会话元数据,支持超时、并发策略和宽限期控制。 + + Args: + command: 命令名称或序列(首项为正式名,其余视为别名) + aliases: 额外别名列表 + description: 命令描述 + timeout: 会话超时时间(秒),必须为正整数 + mode: 会话冲突时的行为: + - ``"replace"``: 替换当前会话 + - ``"reject"``: 拒绝新请求 + busy_message: 拒绝新请求时的提示消息 + grace_period: 宽限期(秒),用于会话生命周期处理 + + Returns: + 装饰器函数 + + Raises: + 
ValueError: mode 不合法、timeout 非正整数或 grace_period 非正数 + + Example: + @conversation_command("chat", timeout=120, mode="reject", busy_message="请稍后再试") + async def chat(self, event: MessageEvent, ctx: Context): + await event.reply("开始对话...") + """ + if mode not in {"replace", "reject"}: + raise ValueError("conversation_command mode must be 'replace' or 'reject'") + # bool 是 int 子类,需单独排除 + if isinstance(timeout, bool) or int(timeout) <= 0: + raise ValueError("conversation_command timeout must be a positive integer") + if float(grace_period) <= 0: + raise ValueError("conversation_command grace_period must be positive") + + command_decorator = on_command( + command, + aliases=aliases, + description=description, + ) + + def decorator(func: HandlerCallable) -> HandlerCallable: + decorated = command_decorator(func) + meta = _get_or_create_meta(decorated) + meta.conversation = ConversationMeta( + timeout=int(timeout), + mode=mode, + busy_message=busy_message, + grace_period=float(grace_period), + ) + return decorated + + return decorator + + +def provide_capability( + name: str, + *, + description: str, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + input_model: type[BaseModel] | None = None, + output_model: type[BaseModel] | None = None, + supports_stream: bool = False, + cancelable: bool = False, +) -> Callable[[HandlerCallable], HandlerCallable]: + """声明插件对外暴露的 capability。 + + 允许其他插件或 Core 通过 capability 名称调用此方法。 + 支持使用 JSON Schema 或 pydantic 模型定义输入输出。 + + Args: + name: capability 名称(不能使用保留命名空间) + description: 能力描述 + input_schema: 输入 JSON Schema + output_schema: 输出 JSON Schema + input_model: 输入 pydantic 模型(与 input_schema 二选一) + output_model: 输出 pydantic 模型(与 output_schema 二选一) + supports_stream: 是否支持流式输出 + cancelable: 是否可取消 + + Returns: + 装饰器函数 + + Raises: + ValueError: 如果使用保留命名空间,或同时提供 schema 和 model + + Example: + @provide_capability( + "my_plugin.calculate", + description="执行计算", + input_model=CalculateInput, + 
output_model=CalculateOutput, + ) + async def calculate(self, payload: dict, ctx: Context): + return {"result": payload["x"] * 2} + """ + + def decorator(func: HandlerCallable) -> HandlerCallable: + if name.startswith(RESERVED_CAPABILITY_PREFIXES): + raise ValueError(f"保留 capability 命名空间不能用于插件导出:{name}") + if input_schema is not None and input_model is not None: + raise ValueError("input_schema 和 input_model 不能同时提供") + if output_schema is not None and output_model is not None: + raise ValueError("output_schema 和 output_model 不能同时提供") + descriptor = CapabilityDescriptor( + name=name, + description=description, + input_schema=( + input_schema + if input_schema is not None + else _model_to_schema(input_model, label="input_model") + ), + output_schema=( + output_schema + if output_schema is not None + else _model_to_schema(output_model, label="output_model") + ), + supports_stream=supports_stream, + cancelable=cancelable, + ) + setattr(func, CAPABILITY_META_ATTR, CapabilityMeta(descriptor=descriptor)) + return func + + return decorator + + +def _annotation_to_schema(annotation: Any) -> dict[str, Any]: + normalized, _is_optional = unwrap_optional(annotation) + origin = typing.get_origin(normalized) + if normalized is str: + return {"type": "string"} + if normalized is int: + return {"type": "integer"} + if normalized is float: + return {"type": "number"} + if normalized is bool: + return {"type": "boolean"} + if normalized is dict or origin is dict: + return {"type": "object"} + if normalized is list or origin is list: + args = typing.get_args(normalized) + item_schema = _annotation_to_schema(args[0]) if args else {} + return {"type": "array", "items": item_schema} + return {"type": "string"} + + +def _callable_parameters_schema(func: HandlerCallable) -> dict[str, Any]: + signature = inspect.signature(func) + type_hints: dict[str, Any] = {} + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {} + + properties: dict[str, Any] = {} + 
required: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if parameter.name == "self": + continue + annotation = type_hints.get(parameter.name) + normalized, _is_optional = unwrap_optional(annotation) + if parameter.name in {"event", "ctx", "context"}: + continue + properties[parameter.name] = _annotation_to_schema(normalized) + if parameter.default is inspect.Parameter.empty and not _is_optional: + required.append(parameter.name) + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required + return schema + + +def register_llm_tool( + name: str | None = None, + *, + description: str | None = None, + parameters_schema: dict[str, Any] | None = None, + active: bool = True, +) -> Callable[[HandlerCallable], HandlerCallable]: + def decorator(func: HandlerCallable) -> HandlerCallable: + tool_name = str(name or func.__name__).strip() + if not tool_name: + raise ValueError("LLM tool name must not be empty") + setattr( + func, + LLM_TOOL_META_ATTR, + LLMToolMeta( + spec=LLMToolSpec.create( + name=tool_name, + description=( + description + or ((inspect.getdoc(func) or "").splitlines() or [""])[0] + ), + parameters_schema=parameters_schema + or _callable_parameters_schema(func), + handler_ref=tool_name, + active=active, + ) + ), + ) + return func + + return decorator + + +def register_agent( + name: str, + *, + description: str = "", + tool_names: list[str] | None = None, +) -> Callable[[type[BaseAgentRunner]], type[BaseAgentRunner]]: + def decorator(cls: type[BaseAgentRunner]) -> type[BaseAgentRunner]: + if not inspect.isclass(cls) or not issubclass(cls, BaseAgentRunner): + raise TypeError("@register_agent() 只接受 BaseAgentRunner 子类") + setattr( + cls, + AGENT_META_ATTR, + AgentMeta( + spec=AgentSpec( + name=name, + description=description, + 
tool_names=list(tool_names or []), + runner_class=f"{cls.__module__}.{cls.__qualname__}", + ) + ), + ) + return cls + + return decorator + + +def acknowledge_global_mcp_risk(cls: type[Any]) -> type[Any]: + """Mark an SDK plugin class as eligible to mutate global MCP state. + + This is intentionally a coarse, class-level marker. Runtime enforcement lives + in the Core MCP capability bridge. + """ + + setattr(cls, "__astrbot_acknowledge_global_mcp_risk__", True) + return cls diff --git a/astrbot-sdk/src/astrbot_sdk/errors.py b/astrbot-sdk/src/astrbot_sdk/errors.py new file mode 100644 index 0000000000..ffe267a0c1 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/errors.py @@ -0,0 +1,311 @@ +"""跨运行时边界传递的统一错误模型。 + +AstrBotError 是 SDK 中所有可预期错误的标准格式, +支持跨进程传递(通过 to_payload/from_payload 序列化)。 + +错误处理流程: + 1. 运行时抛出 AstrBotError 子类或实例 + 2. 错误被捕获并序列化为 payload + 3. 跨进程传输后反序列化 + 4. 在 on_error 钩子中统一处理 + +Example: + # 抛出错误 + raise AstrBotError.invalid_input("参数不能为空") + + # 捕获并处理 + try: + await some_operation() + except AstrBotError as e: + if e.retryable: + # 可重试的错误 + await retry() + else: + # 不可重试的错误 + await event.reply(e.hint or e.message) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +class ErrorCodes: + """AstrBot v4 的稳定错误码常量。 + + 这些错误码在协议层稳定,不应随意更改。 + 新增错误码应放在对应分类的末尾。 + + 分类: + - 不可重试错误(retryable=False):配置错误、权限错误等 + - 可重试错误(retryable=True):网络超时、临时故障等 + """ + + UNKNOWN_ERROR = "unknown_error" + + # 不可重试错误 - 配置或使用问题 + LLM_NOT_CONFIGURED = "llm_not_configured" + CAPABILITY_NOT_FOUND = "capability_not_found" + PERMISSION_DENIED = "permission_denied" + LLM_ERROR = "llm_error" + INVALID_INPUT = "invalid_input" + CANCELLED = "cancelled" + PROTOCOL_VERSION_MISMATCH = "protocol_version_mismatch" + PROTOCOL_ERROR = "protocol_error" + INTERNAL_ERROR = "internal_error" + RATE_LIMITED = "rate_limited" + COOLDOWN_ACTIVE = "cooldown_active" + + # 可重试错误 - 临时故障 + CAPABILITY_TIMEOUT = "capability_timeout" + NETWORK_ERROR = 
"network_error" + LLM_TEMPORARY_ERROR = "llm_temporary_error" + + +@dataclass(slots=True) +class AstrBotError(Exception): + """AstrBot SDK 的标准错误类型。 + + 所有可预期的错误都应使用此类或其工厂方法创建。 + 支持跨进程传递,包含用户友好的提示信息。 + + Attributes: + code: 错误码,来自 ErrorCodes 常量 + message: 错误消息,面向开发者 + hint: 用户提示,面向终端用户 + retryable: 是否可重试 + + Example: + # 使用工厂方法创建错误 + raise AstrBotError.invalid_input("参数格式错误", hint="请使用 JSON 格式") + + # 检查错误类型 + try: + await operation() + except AstrBotError as e: + if e.code == ErrorCodes.CAPABILITY_NOT_FOUND: + logger.error(f"能力不存在: {e.message}") + """ + + code: str + message: str + hint: str = "" + retryable: bool = False + docs_url: str = "" + details: dict[str, Any] | None = None + + def __str__(self) -> str: + return self.message + + @classmethod + def cancelled(cls, message: str = "调用被取消") -> AstrBotError: + """创建取消错误。 + + Args: + message: 错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.CANCELLED, + message=message, + hint="", + retryable=False, + ) + + @classmethod + def capability_not_found(cls, name: str) -> AstrBotError: + """创建能力未找到错误。 + + Args: + name: 未找到的能力名称 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.CAPABILITY_NOT_FOUND, + message=f"未找到能力:{name}", + hint="请确认 AstrBot Core 是否已注册该 capability", + retryable=False, + ) + + @classmethod + def invalid_input( + cls, + message: str, + *, + hint: str = "请检查调用参数", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + """创建输入无效错误。 + + Args: + message: 详细错误消息 + hint: 用户提示 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.INVALID_INPUT, + message=message, + hint=hint, + retryable=False, + docs_url=docs_url, + details=details, + ) + + @classmethod + def protocol_version_mismatch(cls, message: str) -> AstrBotError: + """创建协议版本不匹配错误。 + + Args: + message: 详细错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.PROTOCOL_VERSION_MISMATCH, + message=message, + hint="请升级 astrbot_sdk 至最新版本", + 
retryable=False, + ) + + @classmethod + def protocol_error(cls, message: str) -> AstrBotError: + """创建协议错误。 + + Args: + message: 详细错误消息 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.PROTOCOL_ERROR, + message=message, + hint="请检查通信双方的协议实现", + retryable=False, + ) + + @classmethod + def internal_error( + cls, + message: str, + *, + hint: str = "请联系插件作者", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + """创建内部错误。 + + Args: + message: 详细错误消息 + hint: 用户提示 + + Returns: + AstrBotError 实例 + """ + return cls( + code=ErrorCodes.INTERNAL_ERROR, + message=message, + hint=hint, + retryable=False, + docs_url=docs_url, + details=details, + ) + + @classmethod + def network_error( + cls, + message: str, + *, + hint: str = "网络请求失败,请稍后重试", + docs_url: str = "", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.NETWORK_ERROR, + message=message, + hint=hint, + retryable=True, + docs_url=docs_url, + details=details, + ) + + @classmethod + def rate_limited( + cls, + *, + hint: str = "操作过于频繁,请稍后再试。", + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.RATE_LIMITED, + message="handler invocation is rate limited", + hint=hint, + retryable=False, + details=details, + ) + + @classmethod + def cooldown_active( + cls, + *, + hint: str, + details: dict[str, Any] | None = None, + ) -> AstrBotError: + return cls( + code=ErrorCodes.COOLDOWN_ACTIVE, + message="handler cooldown is active", + hint=hint, + retryable=False, + details=details, + ) + + def to_payload(self) -> dict[str, object]: + """序列化为可传输的字典格式。 + + 用于跨进程传递错误信息。 + + Returns: + 包含错误信息的字典 + """ + return { + "code": self.code, + "message": self.message, + "hint": self.hint, + "retryable": self.retryable, + "docs_url": self.docs_url, + "details": dict(self.details) if isinstance(self.details, dict) else None, + } + + @classmethod + def from_payload(cls, payload: dict[str, object]) -> AstrBotError: 
+ """从字典反序列化错误实例。 + + Args: + payload: 包含错误信息的字典 + + Returns: + AstrBotError 实例 + """ + details_payload = payload.get("details") + details = ( + {str(key): value for key, value in details_payload.items()} + if isinstance(details_payload, dict) + else None + ) + return cls( + code=str(payload.get("code", ErrorCodes.UNKNOWN_ERROR)), + message=str(payload.get("message", "未知错误")), + hint=str(payload.get("hint", "")), + retryable=bool(payload.get("retryable", False)), + docs_url=str(payload.get("docs_url", "")), + details=details, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/events.py b/astrbot-sdk/src/astrbot_sdk/events.py new file mode 100644 index 0000000000..9d07b3cffd --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/events.py @@ -0,0 +1,752 @@ +"""v4 原生事件对象。 + +顶层 ``MessageEvent`` 保持精简,只承载 v4 运行时真正需要的基础能力。 +迁移期扩展事件能力放在独立模块中,而不是继续塞回顶层事件类型。 + +MessageEvent 是 handler 接收的主要事件类型,封装了: + - 消息文本内容 + - 发送者信息(user_id, group_id) + - 平台标识 + - 回复能力(reply, reply_image, reply_chain) +""" + +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from ._message_types import normalize_message_type +from .message.components import ( + At, + BaseMessageComponent, + File, + Image, + Plain, + component_to_payload_sync, + payloads_to_components, +) +from .message.result import EventResultType, MessageChain, MessageEventResult +from .protocol.descriptors import SessionRef + +if TYPE_CHECKING: + from .context import Context + + +@dataclass(slots=True) +class PlainTextResult: + """纯文本结果。 + + 用于 handler 返回简单的文本结果。 + """ + + text: str + + +ReplyHandler = Callable[[str], Awaitable[None]] + +_JSON_DROP = object() + + +def _coerce_str(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return str(value) + + +def _coerce_optional_str(value: Any) -> str | None: + if value is None: + return None + text = value if isinstance(value, str) 
else str(value) + return text or None + + +def _json_safe_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + items = [] + for item in value: + normalized = _json_safe_value(item) + if normalized is not _JSON_DROP: + items.append(normalized) + return items + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized = _json_safe_value(item) + if normalized is not _JSON_DROP: + normalized_dict[str(key)] = normalized + return normalized_dict + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + return _json_safe_value(model_dump()) + except Exception: + return _JSON_DROP + try: + json.dumps(value) + except (TypeError, ValueError): + return _JSON_DROP + return value + + +def _json_safe_mapping(value: Any) -> dict[str, Any]: + if not isinstance(value, dict): + return {} + normalized: dict[str, Any] = {} + for key, item in value.items(): + safe_item = _json_safe_value(item) + if safe_item is not _JSON_DROP: + normalized[str(key)] = safe_item + return normalized + + +class MessageEvent: + """消息事件对象。 + + 封装收到的消息,提供便捷的回复方法。 + 每个 handler 调用都会创建新的 MessageEvent 实例。 + + Attributes: + text: 消息文本内容 + user_id: 发送者用户 ID,缺失时为空字符串 + group_id: 群组 ID(私聊时为 None) + platform: 平台标识(如 "qq", "wechat"),缺失时为空字符串 + session_id: 会话 ID(通常是 group_id 或 user_id,缺失时为空字符串) + raw: 原始消息数据 + + Example: + @on_command("echo") + async def echo(self, event: MessageEvent, ctx: Context): + await event.reply(f"你说: {event.text}") + """ + + text: str + user_id: str + group_id: str | None + platform: str + session_id: str + self_id: str + platform_id: str + message_type: str + sender_name: str + + def __init__( + self, + *, + text: str = "", + user_id: str | None = None, + group_id: str | None = None, + platform: str | None = None, + session_id: str | None = None, + self_id: str | None = None, + platform_id: str | None = None, + 
message_type: str | None = None, + sender_name: str | None = None, + is_admin: bool = False, + raw: dict[str, Any] | None = None, + context: Context | None = None, + reply_handler: ReplyHandler | None = None, + ) -> None: + """初始化消息事件。 + + Args: + text: 消息文本 + user_id: 用户 ID + group_id: 群组 ID + platform: 平台标识 + session_id: 会话 ID,None 时自动从 group_id/user_id 推断 + raw: 原始消息数据 + context: 运行时上下文 + reply_handler: 自定义回复处理器 + """ + normalized_user_id = _coerce_str(user_id) + normalized_group_id = _coerce_optional_str(group_id) + normalized_platform = _coerce_str(platform) + normalized_session_id = _coerce_str(session_id) + + self.text = text + self.user_id = normalized_user_id + self.group_id = normalized_group_id + self.platform = normalized_platform + self.session_id = ( + normalized_session_id or normalized_group_id or normalized_user_id or "" + ) + self.self_id = _coerce_str(self_id) + self.platform_id = _coerce_str(platform_id) or normalized_platform + self.message_type = normalize_message_type( + message_type, + group_id=normalized_group_id, + user_id=normalized_user_id, + ) + self.sender_name = _coerce_str(sender_name) + self._is_admin = bool(is_admin) + self.raw = raw or {} + self._stopped = False + host_extras = self.raw.get("host_extras") + raw_extras = self.raw.get("extras") + self._host_extras = _json_safe_mapping( + host_extras if isinstance(host_extras, dict) else raw_extras + ) + self._host_extras_present = "host_extras" in self.raw or "extras" in self.raw + sdk_local_extras = self.raw.get("sdk_local_extras") + self._sdk_local_extras = _json_safe_mapping(sdk_local_extras) + self._sdk_local_extras_present = "sdk_local_extras" in self.raw + self._sdk_local_extras_dirty = False + messages_payload = self.raw.get("messages") + self._messages = ( + payloads_to_components(messages_payload) + if isinstance(messages_payload, list) + else [] + ) + self._messages_present = "messages" in self.raw + self._message_outline = str(self.raw.get("message_outline", self.text)) + 
sent_messages_payload = self.raw.get("sent_messages") + self._sent_messages = ( + payloads_to_components(sent_messages_payload) + if isinstance(sent_messages_payload, list) + else [] + ) + self._sent_messages_present = "sent_messages" in self.raw + self._sent_message_outline = str(self.raw.get("sent_message_outline", "")) + self._sent_message_outline_present = "sent_message_outline" in self.raw + self._context = context + self._reply_handler = reply_handler + if self._reply_handler is None and context is not None: + self._reply_handler = lambda text: context.platform.send( + self.session_ref or self.session_id, + text, + ) + + def _require_runtime_context(self, action: str) -> Context: + """获取运行时上下文,不存在则抛出异常。""" + if self._context is None: + raise RuntimeError(f"MessageEvent 未绑定运行时上下文,无法 {action}") + return self._context + + def _reply_target(self) -> SessionRef | str: + """获取回复目标。""" + return self.session_ref or self.session_id + + @classmethod + def from_payload( + cls, + payload: dict[str, Any], + *, + context: Context | None = None, + reply_handler: ReplyHandler | None = None, + ) -> MessageEvent: + """从协议载荷创建事件实例。 + + Args: + payload: 协议层传递的消息数据 + context: 运行时上下文 + reply_handler: 自定义回复处理器 + + Returns: + 新的 MessageEvent 实例 + """ + target_payload = payload.get("target") + session_id = payload.get("session_id") + platform = payload.get("platform") + if isinstance(target_payload, dict): + target = SessionRef.model_validate(target_payload) + session_id = session_id or target.session + platform = platform or target.platform + return cls( + text=str(payload.get("text", "")), + user_id=payload.get("user_id"), + group_id=payload.get("group_id"), + platform=platform, + session_id=session_id, + self_id=payload.get("self_id"), + platform_id=payload.get("platform_id"), + message_type=payload.get("message_type"), + sender_name=payload.get("sender_name"), + is_admin=bool(payload.get("is_admin", False)), + raw=payload, + context=context, + reply_handler=reply_handler, + ) + + 
def to_payload(self) -> dict[str, Any]: + """转换为协议载荷格式。 + + Returns: + 可序列化的字典 + """ + payload = dict(self.raw) + payload.update( + { + "text": self.text, + "user_id": self.user_id, + "group_id": self.group_id, + "platform": self.platform, + "session_id": self.session_id, + "self_id": self.self_id, + "platform_id": self.platform_id, + "message_type": self.message_type, + "sender_name": self.sender_name, + "is_admin": self._is_admin, + } + ) + if self.session_ref is not None: + payload["target"] = self.session_ref.to_payload() + merged_extras = dict(self._host_extras) + merged_extras.update(self._sdk_local_extras_payload()) + if merged_extras: + payload["extras"] = merged_extras + elif self._host_extras_present: + payload["extras"] = {} + else: + payload.pop("extras", None) + if self._host_extras or self._host_extras_present: + payload["host_extras"] = dict(self._host_extras) + else: + payload.pop("host_extras", None) + sdk_local_extras = self._sdk_local_extras_payload() + if sdk_local_extras or self._should_serialize_sdk_local_extras(): + payload["sdk_local_extras"] = sdk_local_extras + else: + payload.pop("sdk_local_extras", None) + if self._messages or self._messages_present: + payload["messages"] = [ + component_to_payload_sync(component) for component in self._messages + ] + else: + payload.pop("messages", None) + payload["message_outline"] = self._message_outline + if self._sent_messages or self._sent_messages_present: + payload["sent_messages"] = [ + component_to_payload_sync(component) + for component in self._sent_messages + ] + else: + payload.pop("sent_messages", None) + if self._sent_message_outline or self._sent_message_outline_present: + payload["sent_message_outline"] = self._sent_message_outline + else: + payload.pop("sent_message_outline", None) + return payload + + @property + def session_ref(self) -> SessionRef | None: + """获取会话引用对象。 + + Returns: + SessionRef 实例,如果没有有效的 session_id 则返回 None + """ + if not self.session_id: + return None + return 
SessionRef( + conversation_id=self.session_id, + platform=self.platform, + raw=self.raw or None, + ) + + @property + def target(self) -> SessionRef | None: + """session_ref 的别名。""" + return self.session_ref + + @property + def unified_msg_origin(self) -> str: + """Unified message origin string.""" + return self.session_id + + def is_private_chat(self) -> bool: + """Whether the current event belongs to a private chat.""" + if self.message_type: + return self.message_type == "private" + return not bool(self.group_id) + + def is_group_chat(self) -> bool: + if self.message_type: + return self.message_type == "group" + return bool(self.group_id) + + def get_platform_id(self) -> str: + """Get the platform instance identifier.""" + return self.platform_id + + def get_message_type(self) -> str: + """Get the normalized message type.""" + return self.message_type + + def get_session_id(self) -> str: + """Get the current session identifier.""" + return self.session_id + + def is_admin(self) -> bool: + """Whether the sender has admin permission.""" + return self._is_admin + + def get_messages(self) -> list[BaseMessageComponent]: + """Return SDK message components for the current event.""" + return list(self._messages) + + def get_sent_messages(self) -> list[BaseMessageComponent]: + """Return outbound SDK message components for after-send events.""" + return list(self._sent_messages) + + def has_component(self, type_: type[BaseMessageComponent]) -> bool: + return any(isinstance(component, type_) for component in self._messages) + + def get_components( + self, + type_: type[BaseMessageComponent], + ) -> list[BaseMessageComponent]: + return [ + component for component in self._messages if isinstance(component, type_) + ] + + def get_images(self) -> list[Image]: + return [ + component for component in self._messages if isinstance(component, Image) + ] + + def get_files(self) -> list[File]: + return [ + component for component in self._messages if isinstance(component, File) + ] + 
+ def extract_plain_text(self) -> str: + return " ".join( + component.text + for component in self._messages + if isinstance(component, Plain) + ) + + def get_at_users(self) -> list[str]: + return [ + str(component.qq) + for component in self._messages + if isinstance(component, At) and str(component.qq).lower() != "all" + ] + + def get_message_outline(self) -> str: + """Return the normalized message outline.""" + return self._message_outline + + def get_sent_message_outline(self) -> str: + """Return the outbound message outline for after-send events.""" + return self._sent_message_outline + + async def get_group(self) -> dict[str, Any] | None: + """Get current-group metadata for the bound message request.""" + context = self._require_runtime_context("get_group") + output = await context._proxy.call( # noqa: SLF001 + "platform.get_group", + { + "session": self.session_id, + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + payload = output.get("group") + if not isinstance(payload, dict): + return None + return dict(payload) + + def set_extra(self, key: str, value: Any) -> None: + """Store SDK-local transient event data.""" + self._sdk_local_extras[key] = value + self._sdk_local_extras_dirty = True + + def get_extra(self, key: str | None = None, default: Any = None) -> Any: + """Read SDK-local transient event data.""" + extras = dict(self._host_extras) + extras.update(self._sdk_local_extras) + if key is None: + return extras + return extras.get(key, default) + + def clear_extra(self) -> None: + """Clear SDK-local transient event data.""" + self._sdk_local_extras.clear() + self._sdk_local_extras_dirty = True + + def _sdk_local_extras_payload(self) -> dict[str, Any]: + return _json_safe_mapping(self._sdk_local_extras) + + def _should_serialize_sdk_local_extras(self) -> bool: + return ( + self._sdk_local_extras_present + or self._sdk_local_extras_dirty + or bool(self._sdk_local_extras) + ) + + async def 
request_llm(self) -> bool: + """Request the default LLM chain for the current message request.""" + context = self._require_runtime_context("request_llm") + output = await context._proxy.call( # noqa: SLF001 + "system.event.llm.request", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("should_call_llm", False)) + + async def should_call_llm(self) -> bool: + """Read the current default-LLM decision from the host bridge.""" + context = self._require_runtime_context("should_call_llm") + output = await context._proxy.call( # noqa: SLF001 + "system.event.llm.get_state", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("should_call_llm", False)) + + async def set_result(self, result: MessageEventResult) -> MessageEventResult: + """Store a request-scoped SDK result in the host bridge.""" + context = self._require_runtime_context("set_result") + await context._proxy.call( # noqa: SLF001 + "system.event.result.set", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "result": result.to_payload(), + }, + ) + return result + + async def get_result(self) -> MessageEventResult | None: + """Read the current request-scoped SDK result from the host bridge.""" + context = self._require_runtime_context("get_result") + output = await context._proxy.call( # noqa: SLF001 + "system.event.result.get", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + payload = output.get("result") + if not isinstance(payload, dict): + return None + return MessageEventResult.from_payload(payload) + + async def clear_result(self) -> None: + """Clear the current request-scoped SDK result.""" + context = self._require_runtime_context("clear_result") + await context._proxy.call( # noqa: SLF001 + "system.event.result.clear", + { 
+ "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + }, + ) + + def stop_event(self) -> None: + """Mark the SDK-local event as stopped.""" + self._stopped = True + + def continue_event(self) -> None: + """Clear the SDK-local stop flag.""" + self._stopped = False + + def is_stopped(self) -> bool: + """Return whether the SDK-local event is stopped.""" + return self._stopped + + async def reply(self, text: str) -> None: + """回复文本消息。 + + Args: + text: 要回复的文本内容 + + Raises: + RuntimeError: 如果未绑定 reply handler + """ + if self._reply_handler is None: + raise RuntimeError("MessageEvent 未绑定 reply handler,无法 reply") + await self._reply_handler(text) + + async def reply_image(self, image_url: str) -> None: + """回复图片消息。 + + Args: + image_url: 图片 URL + + Raises: + RuntimeError: 如果未绑定运行时上下文 + """ + context = self._require_runtime_context("reply_image") + await context.platform.send_image(self._reply_target(), image_url) + + async def reply_chain( + self, + chain: MessageChain | list[BaseMessageComponent] | list[dict[str, Any]], + ) -> None: + """回复消息链(多类型消息组合)。 + + Args: + chain: 消息链组件列表 + + Raises: + RuntimeError: 如果未绑定运行时上下文 + """ + context = self._require_runtime_context("reply_chain") + await context.platform.send_chain(self._reply_target(), chain) + + async def react(self, emoji: str) -> bool: + """Send a platform reaction when supported.""" + context = self._require_runtime_context("react") + output = await context._proxy.call( # noqa: SLF001 + "system.event.react", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "emoji": emoji, + }, + ) + return bool(output.get("supported", False)) + + async def send_typing(self) -> bool: + """Emit typing state when the host platform supports it.""" + context = self._require_runtime_context("send_typing") + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_typing", + { + "target": ( + self.session_ref.to_payload() + if 
self.session_ref is not None + else None + ), + }, + ) + return bool(output.get("supported", False)) + + async def send_streaming( + self, + generator, + use_fallback: bool = False, + ) -> bool: + """Replay normalized chunks through the host streaming pathway.""" + context = self._require_runtime_context("send_streaming") + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming", + { + "target": ( + self.session_ref.to_payload() + if self.session_ref is not None + else None + ), + "use_fallback": use_fallback, + }, + ) + if not bool(output.get("supported", False)): + return False + + stream_id = str(output.get("stream_id", "")) + if not stream_id: + return False + + try: + async for item in generator: + if isinstance(item, str): + chain = MessageChain([Plain(item, convert=False)]) + else: + chain = self._coerce_chain_or_raise(item) + await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming_chunk", + { + "stream_id": stream_id, + "chain": await chain.to_payload_async(), + }, + ) + finally: + output = await context._proxy.call( # noqa: SLF001 + "system.event.send_streaming_close", + {"stream_id": stream_id}, + ) + return bool(output.get("supported", False)) + + def bind_reply_handler(self, reply_handler: ReplyHandler) -> None: + """绑定自定义回复处理器。 + + Args: + reply_handler: 回复处理函数 + """ + self._reply_handler = reply_handler + + def plain_result(self, text: str) -> PlainTextResult: + """创建纯文本结果。 + + Args: + text: 结果文本 + + Returns: + PlainTextResult 实例 + """ + return PlainTextResult(text=text) + + def make_result(self) -> MessageEventResult: + """Create an empty SDK-local result wrapper.""" + return MessageEventResult(type=EventResultType.EMPTY) + + def image_result(self, url_or_path: str) -> MessageEventResult: + """Create a chain result that contains one image component.""" + if url_or_path.startswith(("http://", "https://")): + image = Image.fromURL(url_or_path) + elif url_or_path.startswith("base64://"): + image = 
Image.fromBase64(url_or_path.removeprefix("base64://")) + else: + image = Image.fromFileSystem(url_or_path) + return MessageEventResult( + type=EventResultType.CHAIN, + chain=MessageChain([image]), + ) + + def chain_result( + self, + chain: MessageChain | list[BaseMessageComponent], + ) -> MessageEventResult: + """Create a chain result from SDK components.""" + normalized = ( + chain if isinstance(chain, MessageChain) else MessageChain(list(chain)) + ) + return MessageEventResult(type=EventResultType.CHAIN, chain=normalized) + + @staticmethod + def _coerce_chain_or_raise(item: Any) -> MessageChain: + if isinstance(item, MessageEventResult): + return item.chain + if isinstance(item, MessageChain): + return item + if isinstance(item, BaseMessageComponent): + return MessageChain([item]) + if isinstance(item, list) and all( + isinstance(component, BaseMessageComponent) for component in item + ): + return MessageChain(list(item)) + raise TypeError( + "send_streaming only accepts str, MessageChain, MessageEventResult or SDK message components" + ) diff --git a/astrbot-sdk/src/astrbot_sdk/filters.py b/astrbot-sdk/src/astrbot_sdk/filters.py new file mode 100644 index 0000000000..7951cb0c0c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/filters.py @@ -0,0 +1,218 @@ +"""SDK-native filter declarations. 
+
+This module provides the declarative API for event filters, used to gate
+handler execution before the handler runs.
+
+Built-in filter types:
+- PlatformFilter: filter by platform name (e.g. qq, wechat)
+- MessageTypeFilter: filter by message type (e.g. group, private)
+- CustomFilter: user-defined synchronous boolean function
+
+Composition:
+- all_of(*filters): run only when every filter passes (AND logic)
+- any_of(*filters): run when any single filter passes (OR logic)
+- The & and | operators are supported for chained composition.
+
+Filters are evaluated locally (inside the SDK worker process) to avoid
+unnecessary cross-process calls.
+"""
+
+from __future__ import annotations
+
+import inspect
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any, Literal, TypeAlias
+
+from .decorators import append_filter_meta
+from .protocol.descriptors import (
+    CompositeFilterSpec,
+    FilterSpec,
+    LocalFilterRefSpec,
+    MessageTypeFilterSpec,
+    PlatformFilterSpec,
+)
+
+FilterOperator: TypeAlias = Literal["and", "or"]
+
+
+@dataclass(slots=True)
+class LocalFilterBinding:
+    """Worker-local pairing of a filter id with its Python callable."""
+
+    filter_id: str
+    callable: Callable[..., bool]
+    args: dict[str, Any] = field(default_factory=dict)
+
+    def evaluate(self, *, event=None, ctx=None) -> bool:
+        """Invoke the bound callable, passing only the parameters it declares.
+
+        Raises:
+            TypeError: if the callable is asynchronous or returns a non-bool.
+        """
+        signature = inspect.signature(self.callable)
+        kwargs: dict[str, Any] = {}
+        if "event" in signature.parameters:
+            kwargs["event"] = event
+        if "ctx" in signature.parameters:
+            kwargs["ctx"] = ctx
+        result = self.callable(**kwargs)
+        if inspect.isawaitable(result):
+            raise TypeError("CustomFilter must return a synchronous bool")
+        if not isinstance(result, bool):
+            raise TypeError("CustomFilter must return bool")
+        return result
+
+
+class FilterBinding:
+    """Base class for composable filter declarations (& = AND, | = OR)."""
+
+    def __and__(self, other: FilterBinding) -> CompositeFilter:
+        return CompositeFilter("and", [self, other])
+
+    def __or__(self, other: FilterBinding) -> CompositeFilter:
+        return CompositeFilter("or", [self, other])
+
+    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
+        """Return the serializable spec plus any worker-local bindings."""
+        raise NotImplementedError
+
+
+@dataclass(slots=True)
+class PlatformFilter(FilterBinding):
+    """Pass only events originating from one of the given platforms."""
+
+    platforms: list[str]
+
+    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
+        return PlatformFilterSpec(platforms=list(self.platforms)), []
+
+
+@dataclass(slots=True)
+class MessageTypeFilter(FilterBinding):
+    """Pass only events whose message type is in the given list."""
+
+    message_types: list[str]
+
+    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
+        return MessageTypeFilterSpec(message_types=list(self.message_types)), []
+
+
+@dataclass(slots=True)
+class CustomFilter(FilterBinding):
+    """Wrap a user-supplied synchronous boolean predicate."""
+
+    callable: Callable[..., bool]
+    filter_id: str | None = None
+
+    def __post_init__(self) -> None:
+        # Derive a stable identifier from the callable's qualified name so the
+        # host can reference this worker-local predicate by id.
+        if self.filter_id is None:
+            self.filter_id = f"{self.callable.__module__}.{getattr(self.callable, '__qualname__', self.callable.__name__)}"
+
+    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
+        assert self.filter_id is not None
+        return LocalFilterRefSpec(filter_id=self.filter_id), [
+            LocalFilterBinding(filter_id=self.filter_id, callable=self.callable),
+        ]
+
+
+@dataclass(slots=True)
+class CompositeFilter(FilterBinding):
+    """AND/OR combination of child filter declarations."""
+
+    operator: FilterOperator
+    children: list[FilterBinding]
+
+    def compile(self) -> tuple[FilterSpec, list[LocalFilterBinding]]:
+        compiled_children: list[FilterSpec] = []
+        local_bindings: list[LocalFilterBinding] = []
+        for child in self.children:
+            spec, locals_for_child = child.compile()
+            compiled_children.append(spec)
+            local_bindings.extend(locals_for_child)
+
+        if local_bindings:
+            # Any local child forces the whole composite to evaluate locally,
+            # so the composite itself is exported as one local filter ref.
+            # NOTE(review): the derived id includes only local child ids and
+            # the operator, not the pure (platform/message-type) children, so
+            # two composites differing only in pure children share an id —
+            # confirm this collision is acceptable to the host.
+            filter_id = (
+                "composite:"
+                + ":".join(binding.filter_id for binding in local_bindings)
+                + f":{self.operator}"
+            )
+
+            def _evaluate(*, event=None, ctx=None) -> bool:
+                results = [
+                    _evaluate_filter_spec_locally(
+                        spec, local_bindings, event=event, ctx=ctx
+                    )
+                    for spec in compiled_children
+                ]
+                if self.operator == "and":
+                    return all(results)
+                return any(results)
+
+            return (
+                LocalFilterRefSpec(filter_id=filter_id),
+                [LocalFilterBinding(filter_id=filter_id, callable=_evaluate)],
+            )
+
+        return CompositeFilterSpec(kind=self.operator, children=compiled_children), []
+
+
+def _evaluate_filter_spec_locally(
+    spec: FilterSpec,
+    local_bindings: list[LocalFilterBinding],
+    *,
+    event=None,
+    ctx=None,
+) -> bool:
+    """Evaluate a filter spec inside the worker process.
+
+    Unknown specs and specs without a resolvable local binding fail open
+    (return True).
+    """
+    if isinstance(spec, PlatformFilterSpec):
+        if event is None:
+            return True
+        platform = getattr(event, "platform", "") or ""
+        return platform in spec.platforms
+    if isinstance(spec, MessageTypeFilterSpec):
+        if event is None:
+            return True
+        message_type = getattr(event, "message_type", "") or ""
+        return message_type in spec.message_types
+    if isinstance(spec, LocalFilterRefSpec):
+        binding = next(
+            (item for item in local_bindings if item.filter_id == spec.filter_id),
+            None,
+        )
+        if binding is None:
+            # A LocalFilterRefSpec is only truly executable when the current
+            # worker holds a local binding with the same id. A missing binding
+            # usually means the descriptor came from a remote peer or a test
+            # snapshot; fail open here so a handler that could otherwise run
+            # is not filtered out just because the in-process callable cannot
+            # be invoked.
+            return True
+        return binding.evaluate(event=event, ctx=ctx)
+    if isinstance(spec, CompositeFilterSpec):
+        results = [
+            _evaluate_filter_spec_locally(
+                child,
+                local_bindings,
+                event=event,
+                ctx=ctx,
+            )
+            for child in spec.children
+        ]
+        if spec.kind == "and":
+            return all(results)
+        return any(results)
+    return True
+
+
+def custom_filter(
+    binding: FilterBinding,
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+    """Attach a filter declaration to a handler."""
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        spec, local_bindings = binding.compile()
+        append_filter_meta(
+            func,
+            specs=[spec],
+            local_bindings=local_bindings,
+        )
+        return func
+
+    return decorator
+
+
+def all_of(*bindings: FilterBinding) -> CompositeFilter:
+    """Combine filters with AND logic."""
+    return CompositeFilter("and", list(bindings))
+
+
+def any_of(*bindings: FilterBinding) -> CompositeFilter:
+    """Combine filters with OR logic."""
+    return CompositeFilter("or", list(bindings))
+
+
+__all__ = [
+    "CustomFilter",
+    "FilterBinding",
+    "LocalFilterBinding",
+    "MessageTypeFilter",
+    "PlatformFilter",
+    "all_of",
+    "any_of",
+    "custom_filter",
+]
diff --git a/astrbot-sdk/src/astrbot_sdk/llm/__init__.py b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py
new file mode 100644
index 0000000000..02e15b9d2f
--- /dev/null
+++ b/astrbot-sdk/src/astrbot_sdk/llm/__init__.py
@@ -0,0 +1,105 @@
+"""Canonical SDK LLM/tool/provider entrypoints 
def __getattr__(name: str) -> Any:
    """Lazily resolve public names from their defining submodules (PEP 562).

    Keeps ``import astrbot_sdk.llm`` cheap: submodules are imported only when
    one of their exported names is first touched.
    """
    agent_names = ("AgentSpec", "BaseAgentRunner")
    entity_names = (
        "LLMToolSpec",
        "ProviderMeta",
        "ProviderRequest",
        "ProviderType",
        "RerankResult",
        "ToolCallsResult",
    )
    provider_names = (
        "EmbeddingProvider",
        "ProviderProxy",
        "RerankProvider",
        "STTProvider",
        "TTSAudioChunk",
        "TTSProvider",
    )
    if name in agent_names:
        from . import agents

        return getattr(agents, name)
    if name in entity_names:
        from . import entities

        return getattr(entities, name)
    if name in provider_names:
        from . import providers

        return getattr(providers, name)
    if name == "LLMToolManager":
        from . import tools

        return tools.LLMToolManager
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+ """ + + @abstractmethod + async def run(self, ctx: Context, request: ProviderRequest) -> Any: + raise NotImplementedError diff --git a/astrbot-sdk/src/astrbot_sdk/llm/entities.py b/astrbot-sdk/src/astrbot_sdk/llm/entities.py new file mode 100644 index 0000000000..ba252db24b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/entities.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import enum +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _EntityModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class ProviderType(str, enum.Enum): + CHAT_COMPLETION = "chat_completion" + SPEECH_TO_TEXT = "speech_to_text" + TEXT_TO_SPEECH = "text_to_speech" + EMBEDDING = "embedding" + RERANK = "rerank" + + +class ProviderMeta(_EntityModel): + id: str + model: str | None = None + type: str + provider_type: ProviderType = ProviderType.CHAT_COMPLETION + + @classmethod + def from_payload(cls, payload: dict[str, Any] | None) -> ProviderMeta | None: + if not isinstance(payload, dict): + return None + return cls.model_validate(payload) + + +class ToolCallsResult(_EntityModel): + tool_call_id: str | None = None + tool_name: str + content: str + success: bool = True + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ToolCallsResult: + return cls.model_validate(payload) + + +class RerankResult(_EntityModel): + index: int + score: float + document: str + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> RerankResult: + return cls.model_validate(payload) + + +class LLMToolSpec(_EntityModel): + name: str + description: str = "" + parameters_schema: dict[str, Any] = Field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + handler_ref: str | None = Field( + default=None, + description="Worker-side handler reference used to resolve the tool callable.", + ) + handler_capability: str | None = 
Field( + default=None, + description="Optional capability name override for executing this tool handler.", + ) + active: bool = True + + @classmethod + def create( + cls, + *, + name: str, + description: str = "", + parameters_schema: dict[str, Any] | None = None, + handler_ref: str | None = None, + handler_capability: str | None = None, + active: bool = True, + ) -> LLMToolSpec: + # Keep an explicit factory signature so static analyzers do not depend on + # Pydantic's generated __init__ when SDK call sites construct tool specs. + payload: dict[str, Any] = { + "name": name, + "description": description, + "parameters_schema": parameters_schema + if parameters_schema is not None + else {"type": "object", "properties": {}}, + "active": active, + } + if handler_ref is not None: + payload["handler_ref"] = handler_ref + if handler_capability is not None: + payload["handler_capability"] = handler_capability + return cls.from_payload(payload) + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> LLMToolSpec: + return cls.model_validate(payload) + + +class ProviderRequest(_EntityModel): + prompt: str | None = None + system_prompt: str | None = None + session_id: str | None = None + contexts: list[dict[str, Any]] = Field(default_factory=list) + image_urls: list[str] = Field(default_factory=list) + tool_names: list[str] | None = None + tool_calls_result: list[ToolCallsResult] = Field(default_factory=list) + provider_id: str | None = None + model: str | None = None + temperature: float | None = None + max_steps: int | None = None + tool_call_timeout: int | None = None + + def to_payload(self) -> dict[str, Any]: + payload = super().to_payload() + payload["tool_calls_result"] = [ + item.to_payload() for item in self.tool_calls_result + ] + return payload + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ProviderRequest: + normalized = dict(payload) + raw_results = normalized.get("tool_calls_result") + if isinstance(raw_results, list): + 
normalized["tool_calls_result"] = [ + ToolCallsResult.from_payload(item) + for item in raw_results + if isinstance(item, dict) + ] + return cls.model_validate(normalized) diff --git a/astrbot-sdk/src/astrbot_sdk/llm/providers.py b/astrbot-sdk/src/astrbot_sdk/llm/providers.py new file mode 100644 index 0000000000..591e1d57d5 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/providers.py @@ -0,0 +1,199 @@ +"""Provider-facing SDK entities and typed proxy helpers.""" + +from __future__ import annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from dataclasses import dataclass + +from ..clients._proxy import CapabilityProxy +from .entities import ProviderMeta, ProviderType, RerankResult + + +@dataclass(slots=True) +class TTSAudioChunk: + audio: bytes + text: str | None = None + + +class _BaseProviderProxy: + def __init__(self, proxy: CapabilityProxy, meta: ProviderMeta) -> None: + self._proxy = proxy + self._meta = meta + + @property + def id(self) -> str: + return self._meta.id + + @property + def model(self) -> str | None: + return self._meta.model + + @property + def type(self) -> str: + return self._meta.type + + @property + def provider_type(self) -> ProviderType: + return self._meta.provider_type + + def meta(self) -> ProviderMeta: + return self._meta + + +class STTProvider(_BaseProviderProxy): + async def get_text(self, audio_url: str) -> str: + output = await self._proxy.call( + "provider.stt.get_text", + {"provider_id": self.id, "audio_url": str(audio_url)}, + ) + return str(output.get("text", "")) + + +class TTSProvider(_BaseProviderProxy): + def __init__( + self, + proxy: CapabilityProxy, + meta: ProviderMeta, + *, + supports_stream: bool = False, + ) -> None: + super().__init__(proxy, meta) + self._supports_stream = supports_stream + + async def get_audio(self, text: str) -> str: + output = await self._proxy.call( + "provider.tts.get_audio", + {"provider_id": self.id, "text": str(text)}, + ) + return 
str(output.get("audio_path", "")) + + def support_stream(self) -> bool: + return self._supports_stream + + async def get_audio_stream( + self, + text: str | AsyncIterable[str], + ) -> AsyncIterator[TTSAudioChunk]: + payload = await self._build_stream_payload(text) + async for chunk in self._proxy.stream("provider.tts.get_audio_stream", payload): + audio_base64 = str(chunk.get("audio_base64", "")) + yield TTSAudioChunk( + audio=base64.b64decode(audio_base64) if audio_base64 else b"", + text=( + str(chunk.get("text")) if chunk.get("text") is not None else None + ), + ) + + async def _build_stream_payload( + self, + text: str | AsyncIterable[str], + ) -> dict[str, object]: + payload: dict[str, object] = {"provider_id": self.id} + if isinstance(text, str): + payload["text"] = text + return payload + payload["text_chunks"] = [str(item) async for item in text] + return payload + + +class EmbeddingProvider(_BaseProviderProxy): + async def get_embedding(self, text: str) -> list[float]: + output = await self._proxy.call( + "provider.embedding.get_embedding", + {"provider_id": self.id, "text": str(text)}, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + output = await self._proxy.call( + "provider.embedding.get_embeddings", + { + "provider_id": self.id, + "texts": [str(item) for item in texts], + }, + ) + embeddings = output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def get_dim(self) -> int: + output = await self._proxy.call( + "provider.embedding.get_dim", + {"provider_id": self.id}, + ) + return int(output.get("dim", 0)) + + +class RerankProvider(_BaseProviderProxy): + async def rerank( + self, + query: str, + documents: list[str], + top_n: int | None = None, + ) -> 
list[RerankResult]: + output = await self._proxy.call( + "provider.rerank.rerank", + { + "provider_id": self.id, + "query": str(query), + "documents": [str(item) for item in documents], + "top_n": top_n, + }, + ) + results = output.get("results") + if not isinstance(results, list): + return [] + return [ + RerankResult.from_payload(item) + for item in results + if isinstance(item, dict) + ] + + +ProviderProxy = STTProvider | TTSProvider | EmbeddingProvider | RerankProvider + + +def provider_proxy_from_meta( + proxy: CapabilityProxy, + meta: ProviderMeta | None, + *, + tts_supports_stream: bool | None = None, +) -> ProviderProxy | None: + if meta is None: + return None + if meta.provider_type == ProviderType.SPEECH_TO_TEXT: + return STTProvider(proxy, meta) + if meta.provider_type == ProviderType.TEXT_TO_SPEECH: + return TTSProvider( + proxy, + meta, + supports_stream=bool(tts_supports_stream), + ) + if meta.provider_type == ProviderType.EMBEDDING: + return EmbeddingProvider(proxy, meta) + if meta.provider_type == ProviderType.RERANK: + return RerankProvider(proxy, meta) + return None + + +__all__ = [ + "EmbeddingProvider", + "ProviderMeta", + "ProviderProxy", + "ProviderType", + "RerankProvider", + "RerankResult", + "STTProvider", + "TTSAudioChunk", + "TTSProvider", + "provider_proxy_from_meta", +] diff --git a/astrbot-sdk/src/astrbot_sdk/llm/tools.py b/astrbot-sdk/src/astrbot_sdk/llm/tools.py new file mode 100644 index 0000000000..d1a67b30c7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/llm/tools.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .entities import LLMToolSpec + +if TYPE_CHECKING: + from ..clients._proxy import CapabilityProxy + + +class LLMToolManager: + def __init__(self, proxy: CapabilityProxy) -> None: + self._proxy = proxy + + async def list_registered(self) -> list[LLMToolSpec]: + output = await self._proxy.call("llm_tool.manager.get", {}) + items = output.get("registered") + if not 
isinstance(items, list): + return [] + return [ + LLMToolSpec.from_payload(item) for item in items if isinstance(item, dict) + ] + + async def list_active(self) -> list[LLMToolSpec]: + output = await self._proxy.call("llm_tool.manager.get", {}) + items = output.get("active") + if not isinstance(items, list): + return [] + return [ + LLMToolSpec.from_payload(item) for item in items if isinstance(item, dict) + ] + + async def activate(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.activate", {"name": name}) + return bool(output.get("activated", False)) + + async def deactivate(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.deactivate", {"name": name}) + return bool(output.get("deactivated", False)) + + async def add(self, *tools: LLMToolSpec) -> list[str]: + output = await self._proxy.call( + "llm_tool.manager.add", + {"tools": [tool.to_payload() for tool in tools]}, + ) + result = output.get("names") + if not isinstance(result, list): + return [] + return [str(item) for item in result] + + async def remove(self, name: str) -> bool: + output = await self._proxy.call("llm_tool.manager.remove", {"name": name}) + return bool(output.get("removed", False)) + + async def get(self, name: str) -> LLMToolSpec | None: + for tool in await self.list_registered(): + if tool.name == name: + return tool + return None diff --git a/astrbot-sdk/src/astrbot_sdk/message/__init__.py b/astrbot-sdk/src/astrbot_sdk/message/__init__.py new file mode 100644 index 0000000000..4125a0db12 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/__init__.py @@ -0,0 +1,103 @@ +"""Message component, result, and session subpackage.""" + +from .components import ( + At as At, +) +from .components import ( + AtAll as AtAll, +) +from .components import ( + BaseMessageComponent as BaseMessageComponent, +) +from .components import ( + File as File, +) +from .components import ( + Forward as Forward, +) +from .components import ( + Image as 
Image, +) +from .components import ( + MediaHelper as MediaHelper, +) +from .components import ( + Plain as Plain, +) +from .components import ( + Poke as Poke, +) +from .components import ( + Record as Record, +) +from .components import ( + Reply as Reply, +) +from .components import ( + UnknownComponent as UnknownComponent, +) +from .components import ( + Video as Video, +) +from .components import ( + build_media_component_from_url as build_media_component_from_url, +) +from .components import ( + component_to_payload as component_to_payload, +) +from .components import ( + component_to_payload_sync as component_to_payload_sync, +) +from .components import ( + is_message_component as is_message_component, +) +from .components import ( + payload_to_component as payload_to_component, +) +from .components import ( + payloads_to_components as payloads_to_components, +) +from .result import ( + EventResultType as EventResultType, +) +from .result import ( + MessageBuilder as MessageBuilder, +) +from .result import ( + MessageChain as MessageChain, +) +from .result import ( + MessageEventResult as MessageEventResult, +) +from .result import ( + coerce_message_chain as coerce_message_chain, +) +from .session import MessageSession as MessageSession + +__all__ = [ + "At", + "AtAll", + "BaseMessageComponent", + "EventResultType", + "File", + "Forward", + "Image", + "MediaHelper", + "MessageBuilder", + "MessageChain", + "MessageEventResult", + "MessageSession", + "Plain", + "Poke", + "Record", + "Reply", + "UnknownComponent", + "Video", + "build_media_component_from_url", + "coerce_message_chain", + "component_to_payload", + "component_to_payload_sync", + "is_message_component", + "payload_to_component", + "payloads_to_components", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/components.py b/astrbot-sdk/src/astrbot_sdk/message/components.py new file mode 100644 index 0000000000..5c5423499d --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/components.py @@ 
-0,0 +1,625 @@ +"""SDK message component compatibility layer. + +该模块有意避免在导入时导入遗留核心组件模块。 +SDK工作线程应该保持轻量级并且不能依赖于主机核心引导程序 +仅用于构造消息对象的路径。 +""" + +from __future__ import annotations + +import asyncio +import base64 +import inspect +import os +import tempfile +import uuid +from collections.abc import Mapping +from pathlib import Path +from typing import Any +from urllib.parse import urlparse +from urllib.request import urlretrieve + +from .._internal.star_runtime import current_runtime_context +from ..errors import AstrBotError + +_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"} +_RECORD_SUFFIXES = {".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a"} +_VIDEO_SUFFIXES = {".mp4", ".webm", ".mov", ".mkv", ".avi"} + + +def _temp_path(prefix: str, suffix: str = "") -> Path: + return Path(tempfile.gettempdir()) / f"{prefix}_{uuid.uuid4().hex}{suffix}" + + +def _guess_suffix_from_url(url: str, fallback: str = "") -> str: + suffix = Path(urlparse(url).path).suffix + return suffix or fallback + + +def _download_to_temp(url: str, prefix: str, fallback_suffix: str = "") -> str: + target = _temp_path(prefix, _guess_suffix_from_url(url, fallback_suffix)) + urlretrieve(url, target) + return str(target.resolve()) + + +async def _download_to_temp_async( + url: str, + prefix: str, + fallback_suffix: str = "", +) -> str: + return await asyncio.to_thread( + _download_to_temp, + url, + prefix, + fallback_suffix, + ) + + +def _stringify_mapping(mapping: Mapping[Any, Any]) -> dict[str, Any]: + return {str(key): value for key, value in mapping.items()} + + +async def _register_file_to_service(path: str) -> str: + context = current_runtime_context() + if context is None: + raise RuntimeError("message component file service requires runtime context") + return await context._register_file_url(path) + + +def _reply_chain_payloads_sync(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [component_to_payload_sync(item) for item in value] + + 
+async def _reply_chain_payloads(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [await component_to_payload(item) for item in value] + + +def _coerce_reply_chain(value: Any) -> list[BaseMessageComponent]: + if not isinstance(value, list): + return [] + if value and all(isinstance(item, BaseMessageComponent) for item in value): + return list(value) + return payloads_to_components(value) + + +def _component_type_name(component: Any) -> str: + raw_type = getattr(component, "type", "unknown") + normalized = getattr(raw_type, "value", raw_type) + return str(normalized or "unknown").lower() + + +def _plain_payload(text: Any) -> dict[str, Any]: + return {"type": "text", "data": {"text": str(text)}} + + +def _reply_payload_data( + component: Any, + *, + chain_payloads: list[dict[str, Any]], +) -> dict[str, Any]: + return { + "id": getattr(component, "id", ""), + "chain": chain_payloads, + "sender_id": getattr(component, "sender_id", 0), + "sender_nickname": getattr(component, "sender_nickname", ""), + "time": getattr(component, "time", 0), + "message_str": getattr(component, "message_str", ""), + "text": getattr(component, "text", ""), + "qq": getattr(component, "qq", 0), + "seq": getattr(component, "seq", 0), + } + + +def _resolve_media_kind(url: str, kind: str = "auto") -> str: + normalized_kind = str(kind).strip().lower() or "auto" + if normalized_kind != "auto": + return normalized_kind + suffix = Path(urlparse(url).path).suffix.lower() + if suffix in _IMAGE_SUFFIXES: + return "image" + if suffix in _RECORD_SUFFIXES: + return "record" + if suffix in _VIDEO_SUFFIXES: + return "video" + return "file" + + +def build_media_component_from_url( + url: str, + *, + kind: str = "auto", +) -> BaseMessageComponent: + url_text = str(url).strip() + if not url_text: + raise AstrBotError.invalid_input( + "MediaHelper.from_url requires a non-empty url" + ) + resolved_kind = _resolve_media_kind(url_text, kind=kind) + if resolved_kind == 
"image": + return Image.fromURL(url_text) + if resolved_kind in {"record", "audio"}: + return Record.fromURL(url_text) + if resolved_kind == "video": + return Video.fromURL(url_text) + if resolved_kind == "file": + return File(name=_filename_from_url(url_text), url=url_text) + raise AstrBotError.invalid_input( + f"Unsupported media kind: {kind}", + details={"kind": kind, "url": url_text}, + ) + + +def _filename_from_url(url: str) -> str: + name = Path(urlparse(url).path).name + return name or "download" + + +class BaseMessageComponent: + type: str = "unknown" + + def toDict(self) -> dict[str, Any]: + data: dict[str, Any] = {} + for key, value in self.__dict__.items(): + if key == "type" or value is None: + continue + data["type" if key == "_type" else key] = value + return {"type": str(self.type).lower(), "data": data} + + async def to_dict(self) -> dict[str, Any]: + return self.toDict() + + +class Plain(BaseMessageComponent): + type = "plain" + + def __init__(self, text: str, convert: bool = True, **_: Any) -> None: + self.text = text + self.convert = convert + + def toDict(self) -> dict[str, Any]: + return _plain_payload(self.text) + + async def to_dict(self) -> dict[str, Any]: + return _plain_payload(self.text) + + +class At(BaseMessageComponent): + type = "at" + + def __init__(self, qq: int | str, name: str | None = "", **_: Any) -> None: + self.qq = qq + self.name = name or "" + + def toDict(self) -> dict[str, Any]: + return {"type": "at", "data": {"qq": str(self.qq)}} + + +class AtAll(At): + def __init__(self, **_: Any) -> None: + super().__init__(qq="all") + + +class Reply(BaseMessageComponent): + type = "reply" + + def __init__(self, **kwargs: Any) -> None: + self.id = kwargs.get("id", "") + self.chain = _coerce_reply_chain(kwargs.get("chain", [])) + self.sender_id = kwargs.get("sender_id", 0) + self.sender_nickname = kwargs.get("sender_nickname", "") + self.time = kwargs.get("time", 0) + self.message_str = kwargs.get("message_str", "") + self.text = 
kwargs.get("text", "") + self.qq = kwargs.get("qq", 0) + self.seq = kwargs.get("seq", 0) + + def toDict(self) -> dict[str, Any]: + return { + "type": "reply", + "data": _reply_payload_data( + self, + chain_payloads=_reply_chain_payloads_sync(self.chain), + ), + } + + async def to_dict(self) -> dict[str, Any]: + return { + "type": "reply", + "data": _reply_payload_data( + self, + chain_payloads=await _reply_chain_payloads(self.chain), + ), + } + + +class Image(BaseMessageComponent): + type = "image" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self._type = kwargs.get("_type", "") + self.subType = kwargs.get("subType", 0) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.id = kwargs.get("id", 40000) + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") + self.file_unique = kwargs.get("file_unique", "") + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Image: + return Image(url, **kwargs) + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Image: + return Image(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromBase64(base64_data: str, **kwargs: Any) -> Image: + return Image(f"base64://{base64_data}", **kwargs) + + async def convert_to_file_path(self) -> str: + url = self.url or self.file + if not url: + raise ValueError("No valid file or URL provided") + if url.startswith("file:///"): + return os.path.abspath(url[8:]) + if url.startswith(("http://", "https://")): + return await _download_to_temp_async(url, "imgseg", ".jpg") + if url.startswith("base64://"): + file_path = _temp_path("imgseg", ".jpg") + file_path.write_bytes(base64.b64decode(url.removeprefix("base64://"))) + return str(file_path.resolve()) + if os.path.exists(url): + return os.path.abspath(url) + raise ValueError(f"not a valid file: {url}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await 
self.convert_to_file_path()) + + +class Record(BaseMessageComponent): + type = "record" + + def __init__(self, file: str | None, **kwargs: Any) -> None: + self.file = file or "" + self.magic = kwargs.get("magic", False) + self.url = kwargs.get("url", "") + self.cache = kwargs.get("cache", True) + self.proxy = kwargs.get("proxy", True) + self.timeout = kwargs.get("timeout", 0) + self.text = kwargs.get("text") + self.path = kwargs.get("path") + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Record: + return Record(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Record: + return Record(url, **kwargs) + + async def convert_to_file_path(self) -> str: + if self.file.startswith("file:///"): + return os.path.abspath(self.file[8:]) + if self.file.startswith(("http://", "https://")): + return await _download_to_temp_async(self.file, "recordseg", ".dat") + if self.file.startswith("base64://"): + file_path = _temp_path("recordseg", ".dat") + file_path.write_bytes(base64.b64decode(self.file.removeprefix("base64://"))) + return str(file_path.resolve()) + if os.path.exists(self.file): + return os.path.abspath(self.file) + raise ValueError(f"not a valid file: {self.file}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.convert_to_file_path()) + + +class Video(BaseMessageComponent): + type = "video" + + def __init__(self, file: str, **kwargs: Any) -> None: + self.file = file + self.cover = kwargs.get("cover", "") + self.c = kwargs.get("c", 2) + self.path = kwargs.get("path", "") + + @staticmethod + def fromFileSystem(path: str, **kwargs: Any) -> Video: + return Video(f"file:///{os.path.abspath(path)}", path=path, **kwargs) + + @staticmethod + def fromURL(url: str, **kwargs: Any) -> Video: + return Video(url, **kwargs) + + async def convert_to_file_path(self) -> str: + if self.file.startswith("file:///"): + return 
os.path.abspath(self.file[8:]) + if self.file.startswith(("http://", "https://")): + return await _download_to_temp_async(self.file, "videoseg") + if os.path.exists(self.file): + return os.path.abspath(self.file) + raise ValueError(f"not a valid file: {self.file}") + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.convert_to_file_path()) + + +class File(BaseMessageComponent): + type = "file" + + def __init__(self, name: str, file: str = "", url: str = "") -> None: + self.name = name + self.file_ = file + self.url = url + + @property + def file(self) -> str: + return self.file_ + + @file.setter + def file(self, value: str) -> None: + if value.startswith(("http://", "https://")): + self.url = value + else: + self.file_ = value + + async def get_file(self, allow_return_url: bool = False) -> str: + if allow_return_url and self.url: + return self.url + if self.file_: + path = self.file_ + if path.startswith("file://"): + path = path[7:] + if ( + os.name == "nt" + and len(path) > 2 + and path[0] == "/" + and path[2] == ":" + ): + path = path[1:] + if os.path.exists(path): + return os.path.abspath(path) + if self.url: + suffix = Path(urlparse(self.url).path).suffix + target = await _download_to_temp_async(self.url, "fileseg", suffix) + self.file_ = target + return target + return "" + + async def register_to_file_service(self) -> str: + return await _register_file_to_service(await self.get_file()) + + def toDict(self) -> dict[str, Any]: + payload_file = self.url or self.file_ + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + + async def to_dict(self) -> dict[str, Any]: + payload_file = await self.get_file(allow_return_url=True) + return { + "type": "file", + "data": { + "name": self.name, + "file": payload_file, + }, + } + + +class Poke(BaseMessageComponent): + type = "poke" + + def __init__(self, poke_type: str | int | None = None, **kwargs: Any) -> None: + legacy_type = 
kwargs.pop("type", None) + if poke_type is None: + poke_type = legacy_type + if poke_type in (None, "", "poke", "Poke"): + poke_type = "126" + self._type = str(poke_type) + self.id = kwargs.get("id") + self.qq = kwargs.get("qq", 0) + + def target_id(self) -> str | None: + for value in (self.id, self.qq): + if value is None: + continue + text = str(value).strip() + if text and text != "0": + return text + return None + + def toDict(self) -> dict[str, Any]: + data = {"type": str(self._type or "126")} + target_id = self.target_id() + if target_id: + data["id"] = target_id + return {"type": "poke", "data": data} + + +class Forward(BaseMessageComponent): + type = "forward" + + def __init__(self, id: str, **_: Any) -> None: + self.id = id + + +class UnknownComponent(BaseMessageComponent): + type = "unknown" + + def __init__( + self, + *, + raw_type: str = "unknown", + raw_data: dict[str, Any] | None = None, + ) -> None: + self.raw_type = raw_type + self.raw_data = raw_data or {} + + def toDict(self) -> dict[str, Any]: + return { + "type": self.raw_type or "unknown", + "data": dict(self.raw_data), + } + + +def is_message_component(value: Any) -> bool: + return isinstance(value, BaseMessageComponent) + + +def payload_to_component(payload: Any) -> BaseMessageComponent: + if not isinstance(payload, dict): + return UnknownComponent(raw_data={"value": payload}) + + raw_type = str(payload.get("type", "unknown") or "unknown").lower() + data = payload.get("data") + if not isinstance(data, dict): + data = {} + + if raw_type in {"text", "plain"}: + return Plain(str(data.get("text", "")), convert=False) + if raw_type == "image": + return Image(str(data.get("file") or data.get("url") or "")) + if raw_type == "at": + qq_value = data.get("qq") + if str(qq_value).lower() == "all": + return AtAll() + qq = "" if qq_value is None else str(qq_value) + return At(qq=qq, name=str(data.get("name", ""))) + if raw_type == "reply": + return Reply(**data) + if raw_type == "record": + return 
Record(str(data.get("file") or data.get("url") or ""), **data) + if raw_type == "video": + return Video(str(data.get("file") or ""), **data) + if raw_type == "file": + file_value = str(data.get("file") or data.get("file_") or "") + if not file_value: + file_value = str(data.get("url") or "") + return File( + str(data.get("name", "")), + file="" if file_value.startswith(("http://", "https://")) else file_value, + url=file_value if file_value.startswith(("http://", "https://")) else "", + ) + if raw_type == "poke": + return Poke( + poke_type=data.get("type"), + id=data.get("id"), + qq=data.get("qq"), + ) + if raw_type == "forward": + return Forward(id=str(data.get("id", ""))) + + return UnknownComponent(raw_type=raw_type, raw_data=_stringify_mapping(data)) + + +def payloads_to_components(payloads: list[Any]) -> list[BaseMessageComponent]: + return [payload_to_component(item) for item in payloads] + + +def component_to_payload_sync(component: Any) -> dict[str, Any]: + if isinstance(component, UnknownComponent): + return component.toDict() + if isinstance(component, Plain): + return _plain_payload(component.text) + if _component_type_name(component) == "reply": + return { + "type": "reply", + "data": _reply_payload_data( + component, + chain_payloads=_reply_chain_payloads_sync( + getattr(component, "chain", []) + ), + ), + } + to_dict = getattr(component, "toDict", None) + if callable(to_dict): + result = to_dict() + if isinstance(result, Mapping): + return _stringify_mapping(result) + return {"type": "unknown", "data": {"value": str(component)}} + + +async def component_to_payload(component: Any) -> dict[str, Any]: + if isinstance(component, (UnknownComponent, Plain)): + return component_to_payload_sync(component) + async_method = getattr(component, "to_dict", None) + if callable(async_method): + payload = async_method() + if inspect.isawaitable(payload): + result = await payload + if isinstance(result, dict): + return result + return 
component_to_payload_sync(component) + + +class MediaHelper: + @staticmethod + async def from_url( + url: str, + *, + kind: str = "auto", + ) -> BaseMessageComponent: + return build_media_component_from_url(url, kind=kind) + + @staticmethod + async def download(url: str, save_dir: Path) -> Path: + url_text = str(url).strip() + if not url_text: + raise AstrBotError.invalid_input( + "MediaHelper.download requires a non-empty url" + ) + parsed = urlparse(url_text) + if parsed.scheme not in {"http", "https"}: + raise AstrBotError.invalid_input( + "MediaHelper.download only supports http/https urls", + details={"url": url_text}, + ) + target_dir = Path(save_dir) + try: + target_dir.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise AstrBotError.internal_error( + f"Failed to prepare download directory: {target_dir}", + details={"save_dir": str(target_dir)}, + ) from exc + target_path = target_dir / _filename_from_url(url_text) + try: + await asyncio.to_thread(urlretrieve, url_text, target_path) + except Exception as exc: + raise AstrBotError.network_error( + f"Failed to download media from '{url_text}'", + details={"url": url_text}, + ) from exc + return target_path.resolve() + + +__all__ = [ + "At", + "AtAll", + "BaseMessageComponent", + "File", + "Forward", + "Image", + "MediaHelper", + "Plain", + "Poke", + "Record", + "Reply", + "UnknownComponent", + "Video", + "component_to_payload", + "component_to_payload_sync", + "is_message_component", + "payload_to_component", + "payloads_to_components", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/result.py b/astrbot-sdk/src/astrbot_sdk/message/result.py new file mode 100644 index 0000000000..3b32bac010 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/result.py @@ -0,0 +1,173 @@ +"""SDK-local rich message result objects. 
+ +本模块定义消息事件的结果对象,用于构建和返回富文本/多媒体消息。 + +核心类: +- MessageChain: 消息组件列表,支持同步/异步序列化为协议 payload +- MessageEventResult: 事件处理结果,包含类型标记和消息链 +- EventResultType: 结果类型枚举(EMPTY / CHAIN) + +辅助函数: +- coerce_message_chain: 将多种输入格式统一转换为 MessageChain, + 支持 MessageEventResult、MessageChain、单个组件或组件列表 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from .components import ( + At, + AtAll, + BaseMessageComponent, + File, + Plain, + Reply, + build_media_component_from_url, + component_to_payload, + component_to_payload_sync, + is_message_component, + payloads_to_components, +) + + +class EventResultType(str, Enum): + EMPTY = "empty" + CHAIN = "chain" + + +@dataclass(slots=True) +class MessageChain: + components: list[BaseMessageComponent] = field(default_factory=list) + + def append(self, component: BaseMessageComponent) -> MessageChain: + self.components.append(component) + return self + + def extend(self, components: list[BaseMessageComponent]) -> MessageChain: + self.components.extend(components) + return self + + def __iter__(self): + return iter(self.components) + + def __len__(self) -> int: + return len(self.components) + + def to_payload(self) -> list[dict[str, Any]]: + return [component_to_payload_sync(component) for component in self.components] + + async def to_payload_async(self) -> list[dict[str, Any]]: + return [await component_to_payload(component) for component in self.components] + + def get_plain_text(self, with_other_comps_mark: bool = False) -> str: + texts: list[str] = [] + for component in self.components: + if isinstance(component, Plain): + texts.append(component.text) + elif with_other_comps_mark: + texts.append(f"[{component.__class__.__name__}]") + return " ".join(texts) + + def plain_text(self, with_other_comps_mark: bool = False) -> str: + return self.get_plain_text(with_other_comps_mark=with_other_comps_mark) + + +@dataclass(slots=True) +class MessageEventResult: + type: 
EventResultType = EventResultType.EMPTY + chain: MessageChain = field(default_factory=MessageChain) + + def to_payload(self) -> dict[str, Any]: + return { + "type": self.type.value, + "chain": self.chain.to_payload(), + } + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> MessageEventResult: + result_type_raw = str(payload.get("type", EventResultType.EMPTY.value)) + try: + result_type = EventResultType(result_type_raw) + except ValueError: + result_type = EventResultType.EMPTY + chain_payload = payload.get("chain") + components = ( + payloads_to_components(chain_payload) + if isinstance(chain_payload, list) + else [] + ) + return cls(type=result_type, chain=MessageChain(components)) + + +@dataclass(slots=True) +class MessageBuilder: + components: list[BaseMessageComponent] = field(default_factory=list) + + def text(self, content: str) -> MessageBuilder: + self.components.append(Plain(content, convert=False)) + return self + + def at(self, user_id: str) -> MessageBuilder: + self.components.append(At(user_id)) + return self + + def at_all(self) -> MessageBuilder: + self.components.append(AtAll()) + return self + + def image(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="image")) + return self + + def record(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="record")) + return self + + def video(self, url: str) -> MessageBuilder: + self.components.append(build_media_component_from_url(url, kind="video")) + return self + + def file(self, name: str, *, file: str = "", url: str = "") -> MessageBuilder: + self.components.append(File(name=name, file=file, url=url)) + return self + + def reply(self, **kwargs: Any) -> MessageBuilder: + self.components.append(Reply(**kwargs)) + return self + + def append(self, component: BaseMessageComponent) -> MessageBuilder: + self.components.append(component) + return self + + def extend(self, components: 
list[BaseMessageComponent]) -> MessageBuilder: + self.components.extend(components) + return self + + def build(self) -> MessageChain: + return MessageChain(list(self.components)) + + +def coerce_message_chain(value: Any) -> MessageChain | None: + if isinstance(value, MessageEventResult): + return value.chain + if isinstance(value, MessageChain): + return value + if is_message_component(value): + return MessageChain([value]) + if isinstance(value, (list, tuple)) and all( + is_message_component(item) for item in value + ): + return MessageChain(list(value)) + return None + + +__all__ = [ + "EventResultType", + "MessageChain", + "MessageBuilder", + "MessageEventResult", + "coerce_message_chain", +] diff --git a/astrbot-sdk/src/astrbot_sdk/message/session.py b/astrbot-sdk/src/astrbot_sdk/message/session.py new file mode 100644 index 0000000000..96bc1ae068 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message/session.py @@ -0,0 +1,48 @@ +"""SDK-visible message session identifier. + +本模块定义 MessageSession 类,用于统一表示消息会话标识符。 +会话标识符格式为:platform_id:message_type:session_id + +例如: +- qq:group:123456 表示 QQ 群 123456 +- wechat:private:user789 表示微信私聊用户 user789 + +该格式与 AstrBot 核心的 unified_msg_origin 保持兼容, +确保 SDK 与核心之间的会话信息能够正确传递。 +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .._message_types import normalize_message_type + + +@dataclass(slots=True) +class MessageSession: + """SDK-visible message session identifier. + + The string form stays compatible with AstrBot's unified message origin: + ``platform_id:message_type:session_id``. 
+ """ + + platform_id: str + message_type: str + session_id: str + + def __post_init__(self) -> None: + self.platform_id = str(self.platform_id) + self.message_type = normalize_message_type(self.message_type) + self.session_id = str(self.session_id) + + def __str__(self) -> str: + return f"{self.platform_id}:{self.message_type}:{self.session_id}" + + @classmethod + def from_str(cls, session: str) -> MessageSession: + platform_id, message_type, session_id = str(session).split(":", 2) + return cls( + platform_id=platform_id, + message_type=message_type, + session_id=session_id, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/message_components.py b/astrbot-sdk/src/astrbot_sdk/message_components.py new file mode 100644 index 0000000000..372bd54a67 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_components.py @@ -0,0 +1,13 @@ +"""Backward-compatible alias for ``astrbot_sdk.message.components``. + +This module intentionally aliases the implementation module instead of re-exporting +names one by one so private helpers keep working with existing monkeypatch sites. +""" + +from __future__ import annotations + +import sys + +from .message import components as _components_module + +sys.modules[__name__] = _components_module diff --git a/astrbot-sdk/src/astrbot_sdk/message_result.py b/astrbot-sdk/src/astrbot_sdk/message_result.py new file mode 100644 index 0000000000..0b575aad5c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_result.py @@ -0,0 +1,13 @@ +"""Backward-compatible alias for ``astrbot_sdk.message.result``. + +Use a module alias so callers patching helper functions on the legacy module path +still affect ``MessageBuilder`` and other implementation globals. 
+""" + +from __future__ import annotations + +import sys + +from .message import result as _result_module + +sys.modules[__name__] = _result_module diff --git a/astrbot-sdk/src/astrbot_sdk/message_session.py b/astrbot-sdk/src/astrbot_sdk/message_session.py new file mode 100644 index 0000000000..ec87255555 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/message_session.py @@ -0,0 +1,9 @@ +"""Backward-compatible message session exports. + +The canonical implementation moved to ``astrbot_sdk.message.session``. Preserve the +legacy import path to avoid breaking existing plugins. +""" + +from .message.session import MessageSession + +__all__ = ["MessageSession"] diff --git a/astrbot-sdk/src/astrbot_sdk/plugin_kv.py b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py new file mode 100644 index 0000000000..de1922b60b --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/plugin_kv.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +if TYPE_CHECKING: + from .context import Context + +_VT = TypeVar("_VT") + + +class _HasRuntimeContext(Protocol): + def _require_runtime_context(self) -> Context: ... 
+ + +class PluginKVStoreMixin: + """Plugin-scoped KV helpers backed by the runtime db client.""" + + def _runtime_context(self) -> Context: + owner = cast(_HasRuntimeContext, self) + return owner._require_runtime_context() + + @property + def plugin_id(self) -> str: + ctx = self._runtime_context() + return ctx.plugin_id + + async def put_kv_data(self, key: str, value: Any) -> None: + ctx = self._runtime_context() + await ctx.db.set(str(key), value) + + async def get_kv_data(self, key: str, default: _VT) -> _VT: + ctx = self._runtime_context() + value = await ctx.db.get(str(key)) + return default if value is None else value + + async def delete_kv_data(self, key: str) -> None: + ctx = self._runtime_context() + await ctx.db.delete(str(key)) diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py new file mode 100644 index 0000000000..6684d30705 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/__init__.py @@ -0,0 +1,160 @@ +"""AstrBot v4 协议公共入口。 + +这里暴露 v4 原生协议的消息模型、描述符和解析函数。 + +握手阶段由 `InitializeMessage` 发起,返回值不是另一条 initialize 消息,而是 +`ResultMessage(kind="initialize_result")`,其 `output` 负载可解析为 +`InitializeOutput`。 + +## 插件作者指南:什么时候用什么? 
+ +### CapabilityDescriptor vs BUILTIN_CAPABILITY_SCHEMAS + +**CapabilityDescriptor** 用于**声明**能力: +- 当你的插件想**暴露**一个可被其他插件或核心调用的能力时 +- 例如:你的插件提供了一个翻译功能,想让其他插件调用 + + ```python + from astrbot_sdk.protocol import CapabilityDescriptor + + descriptor = CapabilityDescriptor( + name="my_plugin.translate", # 格式: 插件名.能力名 + description="翻译文本到指定语言", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "要翻译的文本"}, + "target_lang": {"type": "string", "description": "目标语言"}, + }, + "required": ["text", "target_lang"], + }, + output_schema={ + "type": "object", + "properties": { + "translated": {"type": "string"}, + }, + }, + ) + ``` + +**BUILTIN_CAPABILITY_SCHEMAS** 用于**查询**内置能力的参数格式: +- 当你想**调用**核心提供的内置能力时,用它了解参数结构 +- 例如:你想调用 `llm.chat`,但不确定参数格式 + + ```python + from astrbot_sdk.protocol import BUILTIN_CAPABILITY_SCHEMAS + + # 查看 llm.chat 的输入参数格式 + schema = BUILTIN_CAPABILITY_SCHEMAS["llm.chat"] + print(schema["input"]) # 输入参数的 JSON Schema + print(schema["output"]) # 输出结果的 JSON Schema + ``` + +### 命名规范 + +能力名称必须遵循 `{namespace}.{action}` 或 `{namespace}.{sub_namespace}.{action}` 格式: +- `llm.chat` - LLM 对话 +- `db.set` - 数据库写入 +- `llm_tool.manager.activate` - LLM 工具管理 + +**保留命名空间**(插件不可使用): +- `handler.` - 处理器相关 +- `system.` - 系统内部能力 +- `internal.` - 内部实现细节 + +### 常用内置能力速查 + +| 能力名 | 用途 | +|-------|------| +| `llm.chat` | 同步 LLM 对话 | +| `llm.stream_chat` | 流式 LLM 对话 | +| `memory.save` / `memory.get` | 短期记忆存储 | +| `db.set` / `db.get` | 持久化键值存储 | +| `platform.send` | 发送消息 | +| `provider.get_using` | 获取当前 Provider | +""" + +from __future__ import annotations + +from typing import Any + +from . 
import _builtin_schemas as builtin_schemas +from .descriptors import ( # noqa: F401 + BUILTIN_CAPABILITY_SCHEMAS, + CapabilityDescriptor, + CommandRouteSpec, + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + FilterSpec, + HandlerDescriptor, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + ParamSpec, + Permissions, + PlatformFilterSpec, + ScheduleTrigger, + SessionRef, + Trigger, +) +from .messages import ( # noqa: F401 + CancelMessage, + ErrorPayload, + EventMessage, + InitializeMessage, + InitializeOutput, + InvokeMessage, + PeerInfo, + ProtocolMessage, + ResultMessage, + parse_message, +) + +_DIRECT_EXPORTS = [ + "BUILTIN_CAPABILITY_SCHEMAS", + "CapabilityDescriptor", + "CommandRouteSpec", + "CommandTrigger", + "CancelMessage", + "builtin_schemas", + "CompositeFilterSpec", + "ErrorPayload", + "EventTrigger", + "EventMessage", + "FilterSpec", + "HandlerDescriptor", + "InitializeMessage", + "InitializeOutput", + "InvokeMessage", + "LocalFilterRefSpec", + "MessageTrigger", + "MessageTypeFilterSpec", + "ParamSpec", + "PeerInfo", + "PlatformFilterSpec", + "Permissions", + "ProtocolMessage", + "ResultMessage", + "ScheduleTrigger", + "SessionRef", + "Trigger", + "parse_message", +] + +_BUILTIN_SCHEMA_EXPORTS = tuple( + name for name in builtin_schemas.__all__ if name != "BUILTIN_CAPABILITY_SCHEMAS" +) + + +def __getattr__(name: str) -> Any: + if name in _BUILTIN_SCHEMA_EXPORTS: + return getattr(builtin_schemas, name) + raise AttributeError(name) + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(_BUILTIN_SCHEMA_EXPORTS)) + + +__all__ = list(dict.fromkeys([*_DIRECT_EXPORTS, *_BUILTIN_SCHEMA_EXPORTS])) diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py new file mode 100644 index 0000000000..432d7f5ee2 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/_builtin_schemas.py @@ -0,0 +1,2470 @@ +"""Builtin protocol schema constants. 
+ +本模块定义了 AstrBot SDK v4 协议中所有内置能力的 JSON Schema。 +这些 Schema 用于: +1. 验证能力调用的输入参数是否符合预期格式 +2. 生成能力描述文档,供插件开发者参考 +3. 确保跨进程/跨语言调用时的类型安全 + +所有 Schema 遵循 JSON Schema 规范,支持基本类型检查、必填字段、数组元素约束等。 +""" + +from __future__ import annotations + +from typing import Any + +JSONSchema = dict[str, Any] + + +def _object_schema( + *, + required: tuple[str, ...] = (), + **properties: Any, +) -> JSONSchema: + return { + "type": "object", + "properties": properties, + "required": list(required), + } + + +def _nullable(schema: JSONSchema) -> JSONSchema: + return {"anyOf": [schema, {"type": "null"}]} + + +_OPTIONAL_CHAT_PROPERTIES: dict[str, Any] = { + "system": {"type": "string"}, + "history": {"type": "array", "items": {"type": "object"}}, + "contexts": {"type": "array", "items": {"type": "object"}}, + "provider_id": {"type": "string"}, + "tool_calls_result": {"type": "array", "items": {"type": "object"}}, + "model": {"type": "string"}, + "temperature": {"type": "number"}, + "image_urls": {"type": "array", "items": {"type": "string"}}, + "tools": {"type": "array"}, + "max_steps": {"type": "integer"}, +} + +LLM_CHAT_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_CHAT_OUTPUT_SCHEMA = _object_schema(required=("text",), text={"type": "string"}) +LLM_CHAT_RAW_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_CHAT_RAW_OUTPUT_SCHEMA = _object_schema( + required=("text",), + text={"type": "string"}, + usage=_nullable({"type": "object"}), + finish_reason=_nullable({"type": "string"}), + tool_calls={"type": "array", "items": {"type": "object"}}, + role=_nullable({"type": "string"}), + reasoning_content=_nullable({"type": "string"}), + reasoning_signature=_nullable({"type": "string"}), +) +LLM_STREAM_CHAT_INPUT_SCHEMA = _object_schema( + required=("prompt",), + prompt={"type": "string"}, + **_OPTIONAL_CHAT_PROPERTIES, +) +LLM_STREAM_CHAT_OUTPUT_SCHEMA 
= _object_schema( + required=("text",), text={"type": "string"} +) +MEMORY_SEARCH_INPUT_SCHEMA = _object_schema( + required=("query",), + query={"type": "string"}, + mode={"type": "string", "enum": ["auto", "keyword", "vector", "hybrid"]}, + limit={"type": "integer", "minimum": 1}, + min_score={"type": "number"}, + provider_id={"type": "string"}, + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_SEARCH_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value", "score", "match_type"), + key={"type": "string"}, + namespace=_nullable({"type": "string"}), + value=_nullable({"type": "object"}), + score={"type": "number"}, + match_type={ + "type": "string", + "enum": ["keyword", "vector", "hybrid"], + }, + ), + }, +) +MEMORY_SAVE_INPUT_SCHEMA = _object_schema( + required=("key", "value"), + key={"type": "string"}, + value={"type": "object"}, + namespace={"type": "string"}, +) +MEMORY_SAVE_OUTPUT_SCHEMA = _object_schema() +MEMORY_GET_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_GET_OUTPUT_SCHEMA = _object_schema( + required=("value",), + value=_nullable({"type": "object"}), +) +MEMORY_LIST_KEYS_INPUT_SCHEMA = _object_schema(namespace={"type": "string"}) +MEMORY_LIST_KEYS_OUTPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +MEMORY_EXISTS_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_EXISTS_OUTPUT_SCHEMA = _object_schema( + required=("exists",), + exists={"type": "boolean"}, +) +MEMORY_DELETE_INPUT_SCHEMA = _object_schema( + required=("key",), + key={"type": "string"}, + namespace={"type": "string"}, +) +MEMORY_DELETE_OUTPUT_SCHEMA = _object_schema() +MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + 
include_descendants={"type": "boolean"}, +) +MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA = _object_schema( + required=("key", "value", "ttl_seconds"), + key={"type": "string"}, + value={"type": "object"}, + ttl_seconds={"type": "integer", "minimum": 1}, + namespace={"type": "string"}, +) +MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA = _object_schema() +MEMORY_GET_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, + namespace={"type": "string"}, +) +MEMORY_GET_MANY_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value=_nullable({"type": "object"}), + ), + }, +) +MEMORY_DELETE_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, + namespace={"type": "string"}, +) +MEMORY_DELETE_MANY_OUTPUT_SCHEMA = _object_schema( + required=("deleted_count",), + deleted_count={"type": "integer"}, +) +MEMORY_COUNT_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_COUNT_OUTPUT_SCHEMA = _object_schema( + required=("count",), + count={"type": "integer"}, +) +MEMORY_STATS_INPUT_SCHEMA = _object_schema( + namespace={"type": "string"}, + include_descendants={"type": "boolean"}, +) +MEMORY_STATS_OUTPUT_SCHEMA = _object_schema( + total_items={"type": "integer"}, + total_bytes=_nullable({"type": "integer"}), + plugin_id=_nullable({"type": "string"}), + ttl_entries=_nullable({"type": "integer"}), + namespace=_nullable({"type": "string"}), + namespace_count=_nullable({"type": "integer"}), + indexed_items=_nullable({"type": "integer"}), + embedded_items=_nullable({"type": "integer"}), + dirty_items=_nullable({"type": "integer"}), + fts_enabled={"type": "boolean"}, + 
vector_backend=_nullable({"type": "string"}), + vector_indexes={"type": "array", "items": {"type": "object"}}, +) +SYSTEM_GET_DATA_DIR_INPUT_SCHEMA = _object_schema() +SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, +) +SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA = _object_schema( + required=("text",), + text={"type": "string"}, + return_url={"type": "boolean"}, +) +SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "string"}, +) +SYSTEM_HTML_RENDER_INPUT_SCHEMA = _object_schema( + required=("tmpl", "data"), + tmpl={"type": "string"}, + data={"type": "object"}, + return_url={"type": "boolean"}, + options=_nullable({"type": "object"}), +) +SYSTEM_HTML_RENDER_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "string"}, +) +SYSTEM_FILE_REGISTER_INPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, + timeout=_nullable({"type": "number"}), +) +SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA = _object_schema( + required=("token", "url"), + token={"type": "string"}, + url={"type": "string"}, +) +SYSTEM_FILE_HANDLE_INPUT_SCHEMA = _object_schema( + required=("token",), + token={"type": "string"}, +) +SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA = _object_schema( + required=("path",), + path={"type": "string"}, +) +SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA = _object_schema( + required=("session_key",), + session_key={"type": "string"}, +) +SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA = _object_schema() +SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA = _object_schema( + required=("session_key",), + session_key={"type": "string"}, +) +SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA = _object_schema() +DB_GET_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"}) +DB_GET_OUTPUT_SCHEMA = _object_schema( + required=("value",), + value=_nullable({}), +) +DB_SET_INPUT_SCHEMA = _object_schema( + required=("key", "value"), + key={"type": "string"}, + 
value={}, +) +DB_SET_OUTPUT_SCHEMA = _object_schema() +DB_DELETE_INPUT_SCHEMA = _object_schema(required=("key",), key={"type": "string"}) +DB_DELETE_OUTPUT_SCHEMA = _object_schema() +DB_LIST_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"})) +DB_LIST_OUTPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +DB_GET_MANY_INPUT_SCHEMA = _object_schema( + required=("keys",), + keys={"type": "array", "items": {"type": "string"}}, +) +DB_GET_MANY_OUTPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value=_nullable({}), + ), + }, +) +DB_SET_MANY_INPUT_SCHEMA = _object_schema( + required=("items",), + items={ + "type": "array", + "items": _object_schema( + required=("key", "value"), + key={"type": "string"}, + value={}, + ), + }, +) +DB_SET_MANY_OUTPUT_SCHEMA = _object_schema() +DB_WATCH_INPUT_SCHEMA = _object_schema(prefix=_nullable({"type": "string"})) +DB_WATCH_OUTPUT_SCHEMA = _object_schema() +SESSION_REF_SCHEMA = _object_schema( + required=("conversation_id",), + conversation_id={"type": "string"}, + platform=_nullable({"type": "string"}), + raw=_nullable({"type": "object"}), +) +SYSTEM_EVENT_REACT_INPUT_SCHEMA = _object_schema( + required=("emoji",), + target=_nullable(SESSION_REF_SCHEMA), + emoji={"type": "string"}, +) +SYSTEM_EVENT_REACT_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), + use_fallback={"type": "boolean"}, +) +SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + 
supported={"type": "boolean"}, + stream_id=_nullable({"type": "string"}), +) +SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA = _object_schema( + required=("stream_id", "chain"), + stream_id={"type": "string"}, + chain={"type": "array", "items": {"type": "object"}}, +) +SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA = _object_schema() +SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA = _object_schema( + required=("stream_id",), + stream_id={"type": "string"}, +) +SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA = _object_schema( + required=("supported",), + supported={"type": "boolean"}, +) +SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA = _object_schema( + required=("should_call_llm", "requested_llm"), + should_call_llm={"type": "boolean"}, + requested_llm={"type": "boolean"}, +) +SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA = _object_schema( + required=("should_call_llm", "requested_llm"), + should_call_llm={"type": "boolean"}, + requested_llm={"type": "boolean"}, +) +SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result=_nullable({"type": "object"}), +) +SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA = _object_schema( + required=("result",), + target=_nullable(SESSION_REF_SCHEMA), + result={"type": "object"}, +) +SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA = _object_schema( + required=("result",), + result={"type": "object"}, +) +SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA = _object_schema() +SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), +) +SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA = _object_schema( + 
required=("plugin_names",), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA = _object_schema( + target=_nullable(SESSION_REF_SCHEMA), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA = _object_schema( + required=("plugin_names",), + plugin_names=_nullable({"type": "array", "items": {"type": "string"}}), +) +PLATFORM_SEND_INPUT_SCHEMA = _object_schema( + required=("session", "text"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + text={"type": "string"}, +) +PLATFORM_SEND_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_IMAGE_INPUT_SCHEMA = _object_schema( + required=("session", "image_url"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + image_url={"type": "string"}, +) +PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_CHAIN_INPUT_SCHEMA = _object_schema( + required=("session", "chain"), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), + chain={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA = _object_schema( + required=("session", "chain"), + session={"type": "string"}, + chain={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA = _object_schema( + required=("message_id",), + message_id={"type": "string"}, +) +PLATFORM_GET_GROUP_INPUT_SCHEMA = _object_schema( + required=("session",), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), +) +PLATFORM_GET_GROUP_OUTPUT_SCHEMA = _object_schema( + required=("group",), + group=_nullable({"type": "object"}), +) +PLATFORM_GET_MEMBERS_INPUT_SCHEMA = 
_object_schema( + required=("session",), + session={"type": "string"}, + target=_nullable(SESSION_REF_SCHEMA), +) +PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA = _object_schema( + required=("members",), + members={"type": "array", "items": {"type": "object"}}, +) +PLATFORM_INSTANCE_SCHEMA = _object_schema( + required=("id", "name", "type", "status"), + id={"type": "string"}, + name={"type": "string"}, + type={"type": "string"}, + status={"type": "string"}, +) +PLATFORM_LIST_INSTANCES_INPUT_SCHEMA = _object_schema() +PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA = _object_schema( + required=("platforms",), + platforms={"type": "array", "items": PLATFORM_INSTANCE_SCHEMA}, +) +PLATFORM_ERROR_SCHEMA = _object_schema( + required=("message", "timestamp"), + message={"type": "string"}, + timestamp={"type": "string"}, + traceback=_nullable({"type": "string"}), +) +PLATFORM_MANAGER_STATE_SCHEMA = _object_schema( + required=("id", "name", "type", "status", "errors", "unified_webhook"), + id={"type": "string"}, + name={"type": "string"}, + type={"type": "string"}, + status={"type": "string"}, + errors={"type": "array", "items": PLATFORM_ERROR_SCHEMA}, + last_error=_nullable(PLATFORM_ERROR_SCHEMA), + unified_webhook={"type": "boolean"}, +) +PLATFORM_STATS_SCHEMA = _object_schema( + required=( + "id", + "type", + "display_name", + "status", + "error_count", + "unified_webhook", + ), + id={"type": "string"}, + type={"type": "string"}, + display_name={"type": "string"}, + status={"type": "string"}, + started_at=_nullable({"type": "string"}), + error_count={"type": "integer"}, + last_error=_nullable(PLATFORM_ERROR_SCHEMA), + unified_webhook={"type": "boolean"}, + meta={"type": "object"}, +) +PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema( + required=("platform",), + platform=_nullable(PLATFORM_MANAGER_STATE_SCHEMA), +) +PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA = 
_object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA = _object_schema() +PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA = _object_schema( + required=("platform_id",), + platform_id={"type": "string"}, +) +PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA = _object_schema( + required=("stats",), + stats=_nullable(PLATFORM_STATS_SCHEMA), +) +PERMISSION_ROLE_SCHEMA = {"type": "string", "enum": ["member", "admin"]} +PERMISSION_CHECK_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, + session_id=_nullable({"type": "string"}), +) +PERMISSION_CHECK_RESULT_SCHEMA = _object_schema( + required=("is_admin", "role"), + is_admin={"type": "boolean"}, + role=PERMISSION_ROLE_SCHEMA, +) +PERMISSION_CHECK_OUTPUT_SCHEMA = PERMISSION_CHECK_RESULT_SCHEMA +PERMISSION_GET_ADMINS_INPUT_SCHEMA = _object_schema() +PERMISSION_GET_ADMINS_OUTPUT_SCHEMA = _object_schema( + required=("admins",), + admins={"type": "array", "items": {"type": "string"}}, +) +PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, +) +PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA = _object_schema( + required=("changed",), + changed={"type": "boolean"}, +) +PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA = _object_schema( + required=("user_id",), + user_id={"type": "string"}, +) +PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA = _object_schema( + required=("changed",), + changed={"type": "boolean"}, +) +SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA = _object_schema( + required=("session", "plugin_name"), + session={"type": "string"}, + plugin_name={"type": "string"}, +) +SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA = _object_schema( + required=("enabled",), + enabled={"type": "boolean"}, +) +SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA = _object_schema( + required=("session", "handlers"), + session={"type": "string"}, + handlers={"type": "array", "items": {"type": "object"}}, +) 
# ---------------------------------------------------------------------------
# Capability payload schemas: session services, personas, conversations,
# message history, MCP servers, and knowledge bases.
#
# Each *_INPUT_SCHEMA / *_OUTPUT_SCHEMA pair describes the JSON payload of one
# builtin capability.  `_object_schema` / `_nullable` are helpers defined
# earlier in this module — presumably they build a JSON-Schema "object" node
# with the given properties and a type-or-null wrapper respectively
# (NOTE(review): confirm against their definitions above this chunk).
# An empty `_object_schema()` denotes a capability with no payload fields.
# ---------------------------------------------------------------------------
SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA = _object_schema(
    required=("handlers",),
    handlers={"type": "array", "items": {"type": "object"}},
)
# Per-session LLM / TTS feature toggles.
SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
)
SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA = _object_schema(
    required=("enabled",),
    enabled={"type": "boolean"},
)
SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA = _object_schema(
    required=("session", "enabled"),
    session={"type": "string"},
    enabled={"type": "boolean"},
)
SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA = _object_schema()
SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
)
SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA = _object_schema(
    required=("enabled",),
    enabled={"type": "boolean"},
)
SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA = _object_schema(
    required=("session", "enabled"),
    session={"type": "string"},
    enabled={"type": "boolean"},
)
SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA = _object_schema()
# Persona records: RECORD is the full stored shape, CREATE/UPDATE are the
# write payloads (UPDATE has no `required` — all fields optional, patch-style).
PERSONA_RECORD_SCHEMA = _object_schema(
    required=("persona_id", "system_prompt", "begin_dialogs", "sort_order"),
    persona_id={"type": "string"},
    system_prompt={"type": "string"},
    begin_dialogs={"type": "array", "items": {"type": "string"}},
    tools=_nullable({"type": "array", "items": {"type": "string"}}),
    skills=_nullable({"type": "array", "items": {"type": "string"}}),
    custom_error_message=_nullable({"type": "string"}),
    folder_id=_nullable({"type": "string"}),
    sort_order={"type": "integer"},
    created_at=_nullable({"type": "string"}),
    updated_at=_nullable({"type": "string"}),
)
PERSONA_CREATE_SCHEMA = _object_schema(
    required=("persona_id", "system_prompt"),
    persona_id={"type": "string"},
    system_prompt={"type": "string"},
    begin_dialogs={"type": "array", "items": {"type": "string"}},
    tools=_nullable({"type": "array", "items": {"type": "string"}}),
    skills=_nullable({"type": "array", "items": {"type": "string"}}),
    custom_error_message=_nullable({"type": "string"}),
    folder_id=_nullable({"type": "string"}),
    sort_order={"type": "integer"},
)
PERSONA_UPDATE_SCHEMA = _object_schema(
    system_prompt=_nullable({"type": "string"}),
    begin_dialogs=_nullable({"type": "array", "items": {"type": "string"}}),
    tools=_nullable({"type": "array", "items": {"type": "string"}}),
    skills=_nullable({"type": "array", "items": {"type": "string"}}),
    custom_error_message=_nullable({"type": "string"}),
)
PERSONA_GET_INPUT_SCHEMA = _object_schema(
    required=("persona_id",),
    persona_id={"type": "string"},
)
PERSONA_GET_OUTPUT_SCHEMA = _object_schema(
    required=("persona",),
    persona=PERSONA_RECORD_SCHEMA,
)
PERSONA_LIST_INPUT_SCHEMA = _object_schema()
PERSONA_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("personas",),
    personas={"type": "array", "items": PERSONA_RECORD_SCHEMA},
)
PERSONA_CREATE_INPUT_SCHEMA = _object_schema(
    required=("persona",),
    persona=PERSONA_CREATE_SCHEMA,
)
PERSONA_CREATE_OUTPUT_SCHEMA = _object_schema(
    required=("persona",),
    persona=PERSONA_RECORD_SCHEMA,
)
PERSONA_UPDATE_INPUT_SCHEMA = _object_schema(
    required=("persona_id", "persona"),
    persona_id={"type": "string"},
    persona=PERSONA_UPDATE_SCHEMA,
)
# update returns null when no persona matched `persona_id`
# (inferred from the _nullable wrapper; confirm with the handler).
PERSONA_UPDATE_OUTPUT_SCHEMA = _object_schema(
    required=("persona",),
    persona=_nullable(PERSONA_RECORD_SCHEMA),
)
PERSONA_DELETE_INPUT_SCHEMA = _object_schema(
    required=("persona_id",),
    persona_id={"type": "string"},
)
PERSONA_DELETE_OUTPUT_SCHEMA = _object_schema()
# Conversation records and CRUD payloads.  Most mutating capabilities accept
# an optional `conversation_id`; when omitted the handler presumably targets
# the session's current conversation (TODO confirm against handler code).
CONVERSATION_RECORD_SCHEMA = _object_schema(
    required=("conversation_id", "session", "platform_id", "history"),
    conversation_id={"type": "string"},
    session={"type": "string"},
    platform_id={"type": "string"},
    history={"type": "array", "items": {"type": "object"}},
    title=_nullable({"type": "string"}),
    persona_id=_nullable({"type": "string"}),
    created_at=_nullable({"type": "string"}),
    updated_at=_nullable({"type": "string"}),
    token_usage=_nullable({"type": "integer"}),
)
CONVERSATION_CREATE_SCHEMA = _object_schema(
    platform_id=_nullable({"type": "string"}),
    history=_nullable({"type": "array", "items": {"type": "object"}}),
    title=_nullable({"type": "string"}),
    persona_id=_nullable({"type": "string"}),
)
CONVERSATION_UPDATE_SCHEMA = _object_schema(
    history=_nullable({"type": "array", "items": {"type": "object"}}),
    title=_nullable({"type": "string"}),
    persona_id=_nullable({"type": "string"}),
    token_usage=_nullable({"type": "integer"}),
)
CONVERSATION_NEW_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
    conversation=_nullable(CONVERSATION_CREATE_SCHEMA),
)
CONVERSATION_NEW_OUTPUT_SCHEMA = _object_schema(
    required=("conversation_id",),
    conversation_id={"type": "string"},
)
CONVERSATION_SWITCH_INPUT_SCHEMA = _object_schema(
    required=("session", "conversation_id"),
    session={"type": "string"},
    conversation_id={"type": "string"},
)
CONVERSATION_SWITCH_OUTPUT_SCHEMA = _object_schema()
CONVERSATION_DELETE_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
    conversation_id=_nullable({"type": "string"}),
)
CONVERSATION_DELETE_OUTPUT_SCHEMA = _object_schema()
CONVERSATION_GET_INPUT_SCHEMA = _object_schema(
    required=("session", "conversation_id"),
    session={"type": "string"},
    conversation_id={"type": "string"},
    create_if_not_exists={"type": "boolean"},
)
CONVERSATION_GET_OUTPUT_SCHEMA = _object_schema(
    required=("conversation",),
    conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_GET_CURRENT_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
    create_if_not_exists={"type": "boolean"},
)
CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA = _object_schema(
    required=("conversation",),
    conversation=_nullable(CONVERSATION_RECORD_SCHEMA),
)
CONVERSATION_LIST_INPUT_SCHEMA = _object_schema(
    session=_nullable({"type": "string"}),
    platform_id=_nullable({"type": "string"}),
)
CONVERSATION_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("conversations",),
    conversations={"type": "array", "items": CONVERSATION_RECORD_SCHEMA},
)
CONVERSATION_UPDATE_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
    conversation_id=_nullable({"type": "string"}),
    conversation=_nullable(CONVERSATION_UPDATE_SCHEMA),
)
CONVERSATION_UPDATE_OUTPUT_SCHEMA = _object_schema()
CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session={"type": "string"},
    conversation_id=_nullable({"type": "string"}),
)
CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA = _object_schema()
# Message-history storage.  A "session" here is a structured triple, not the
# string used by the session.* capabilities above.
MESSAGE_HISTORY_SESSION_SCHEMA = _object_schema(
    required=("platform_id", "message_type", "session_id"),
    platform_id={"type": "string"},
    message_type={"type": "string", "enum": ["group", "private", "other"]},
    session_id={"type": "string"},
)
MESSAGE_HISTORY_SENDER_SCHEMA = _object_schema(
    sender_id=_nullable({"type": "string"}),
    sender_name=_nullable({"type": "string"}),
)
MESSAGE_HISTORY_RECORD_SCHEMA = _object_schema(
    required=("id", "session", "sender", "parts", "metadata"),
    id={"type": "integer"},
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    sender=MESSAGE_HISTORY_SENDER_SCHEMA,
    parts={"type": "array", "items": {"type": "object"}},
    metadata={"type": "object"},
    created_at=_nullable({"type": "string"}),
    updated_at=_nullable({"type": "string"}),
    idempotency_key=_nullable({"type": "string"}),
)
MESSAGE_HISTORY_PAGE_SCHEMA = _object_schema(
    required=("records",),
    records={"type": "array", "items": MESSAGE_HISTORY_RECORD_SCHEMA},
    next_cursor=_nullable({"type": "string"}),
    total=_nullable({"type": "integer"}),
)
MESSAGE_HISTORY_LIST_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    # cursor is either empty (first page) or a positive decimal integer
    # with no leading zeros, enforced by the pattern below.
    cursor=_nullable({"type": "string", "pattern": "^(|[1-9][0-9]*)$"}),
    limit={"type": "integer", "minimum": 1},
)
MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("page",),
    page=MESSAGE_HISTORY_PAGE_SCHEMA,
)
MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA = _object_schema(
    required=("session", "record_id"),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    record_id={"type": "integer", "minimum": 1},
)
MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
    required=("record",),
    record=_nullable(MESSAGE_HISTORY_RECORD_SCHEMA),
)
MESSAGE_HISTORY_APPEND_INPUT_SCHEMA = _object_schema(
    required=("session", "sender", "parts"),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    sender=MESSAGE_HISTORY_SENDER_SCHEMA,
    parts={"type": "array", "items": {"type": "object"}},
    metadata=_nullable({"type": "object"}),
    idempotency_key=_nullable({"type": "string"}),
)
MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA = _object_schema(
    required=("record",),
    record=MESSAGE_HISTORY_RECORD_SCHEMA,
)
# Bulk deletion; `before` / `after` are strings — presumably timestamps,
# format not fixed by the schema (TODO confirm expected format with handler).
MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA = _object_schema(
    required=("session", "before"),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    before={"type": "string"},
)
MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA = _object_schema(
    required=("deleted_count",),
    deleted_count={"type": "integer"},
)
MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA = _object_schema(
    required=("session", "after"),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
    after={"type": "string"},
)
MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA = _object_schema(
    required=("deleted_count",),
    deleted_count={"type": "integer"},
)
MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA = _object_schema(
    required=("session",),
    session=MESSAGE_HISTORY_SESSION_SCHEMA,
)
MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA = _object_schema(
    required=("deleted_count",),
    deleted_count={"type": "integer"},
)
# MCP (Model Context Protocol) server management, in three flavors:
# mcp.local.* (plugin-scoped servers), mcp.session.* (ad-hoc sessions),
# mcp.global.* (instance-wide servers).
MCP_SERVER_SCOPE_SCHEMA = {"type": "string", "enum": ["local", "global"]}
MCP_SERVER_RECORD_SCHEMA = _object_schema(
    required=("name", "scope", "active", "running", "config", "tools", "errlogs"),
    name={"type": "string"},
    scope=MCP_SERVER_SCOPE_SCHEMA,
    active={"type": "boolean"},
    running={"type": "boolean"},
    config={"type": "object"},
    tools={"type": "array", "items": {"type": "string"}},
    errlogs={"type": "array", "items": {"type": "string"}},
    last_error=_nullable({"type": "string"}),
)
MCP_LOCAL_GET_INPUT_SCHEMA = _object_schema(required=("name",), name={"type": "string"})
# get returns null when the named server is unknown.
MCP_LOCAL_GET_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=_nullable(MCP_SERVER_RECORD_SCHEMA),
)
MCP_LOCAL_LIST_INPUT_SCHEMA = _object_schema()
MCP_LOCAL_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("servers",),
    servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA},
)
MCP_LOCAL_ENABLE_INPUT_SCHEMA = _object_schema(
    required=("name",), name={"type": "string"}
)
MCP_LOCAL_ENABLE_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_LOCAL_DISABLE_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
MCP_LOCAL_DISABLE_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
    timeout={"type": "number"},
)
MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_SESSION_OPEN_INPUT_SCHEMA = _object_schema(
    required=("name", "config"),
    name={"type": "string"},
    config={"type": "object"},
    timeout={"type": "number"},
)
MCP_SESSION_OPEN_OUTPUT_SCHEMA = _object_schema(
    required=("session_id", "tools"),
    session_id={"type": "string"},
    tools={"type": "array", "items": {"type": "string"}},
)
MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA = _object_schema(
    required=("session_id",),
    session_id={"type": "string"},
)
MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA = _object_schema(
    required=("tools",),
    tools={"type": "array", "items": {"type": "string"}},
)
MCP_SESSION_CALL_TOOL_INPUT_SCHEMA = _object_schema(
    required=("session_id", "tool_name", "args"),
    session_id={"type": "string"},
    tool_name={"type": "string"},
    args={"type": "object"},
)
MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA = _object_schema(
    required=("result",),
    result={"type": "object"},
)
MCP_SESSION_CLOSE_INPUT_SCHEMA = _object_schema(
    required=("session_id",),
    session_id={"type": "string"},
)
MCP_SESSION_CLOSE_OUTPUT_SCHEMA = _object_schema()
MCP_GLOBAL_REGISTER_INPUT_SCHEMA = _object_schema(
    required=("name", "config"),
    name={"type": "string"},
    config={"type": "object"},
    timeout={"type": "number"},
)
MCP_GLOBAL_REGISTER_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_GLOBAL_GET_INPUT_SCHEMA = _object_schema(
    required=("name",), name={"type": "string"}
)
MCP_GLOBAL_GET_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=_nullable(MCP_SERVER_RECORD_SCHEMA),
)
MCP_GLOBAL_LIST_INPUT_SCHEMA = _object_schema()
MCP_GLOBAL_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("servers",),
    servers={"type": "array", "items": MCP_SERVER_RECORD_SCHEMA},
)
MCP_GLOBAL_ENABLE_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
    timeout={"type": "number"},
)
MCP_GLOBAL_ENABLE_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_GLOBAL_DISABLE_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
MCP_GLOBAL_DISABLE_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
MCP_GLOBAL_UNREGISTER_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
MCP_GLOBAL_UNREGISTER_OUTPUT_SCHEMA = _object_schema(
    required=("server",),
    server=MCP_SERVER_RECORD_SCHEMA,
)
# Supervisor-internal: execute a tool on a plugin-local MCP server on behalf
# of another component (note the extra `plugin_id` routing field).
INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA = _object_schema(
    required=("plugin_id", "server_name", "tool_name", "tool_args"),
    plugin_id={"type": "string"},
    server_name={"type": "string"},
    tool_name={"type": "string"},
    tool_args={"type": "object"},
)
INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA = _object_schema(
    required=("content", "success"),
    content=_nullable({"type": "string"}),
    success={"type": "boolean"},
)
# Knowledge-base records and write payloads (same RECORD/CREATE/UPDATE split
# as personas above).
KNOWLEDGE_BASE_RECORD_SCHEMA = _object_schema(
    required=("kb_id", "kb_name", "embedding_provider_id", "doc_count", "chunk_count"),
    kb_id={"type": "string"},
    kb_name={"type": "string"},
    description=_nullable({"type": "string"}),
    emoji=_nullable({"type": "string"}),
    embedding_provider_id={"type": "string"},
    rerank_provider_id=_nullable({"type": "string"}),
    chunk_size=_nullable({"type": "integer"}),
    chunk_overlap=_nullable({"type": "integer"}),
    top_k_dense=_nullable({"type": "integer"}),
    top_k_sparse=_nullable({"type": "integer"}),
    top_m_final=_nullable({"type": "integer"}),
    doc_count={"type": "integer"},
    chunk_count={"type": "integer"},
    created_at=_nullable({"type": "string"}),
    updated_at=_nullable({"type": "string"}),
)
KNOWLEDGE_BASE_CREATE_SCHEMA = _object_schema(
    required=("kb_name", "embedding_provider_id"),
    kb_name={"type": "string"},
    embedding_provider_id={"type": "string"},
    description=_nullable({"type": "string"}),
    emoji=_nullable({"type": "string"}),
    rerank_provider_id=_nullable({"type": "string"}),
    chunk_size=_nullable({"type": "integer"}),
    chunk_overlap=_nullable({"type": "integer"}),
    top_k_dense=_nullable({"type": "integer"}),
    top_k_sparse=_nullable({"type": "integer"}),
    top_m_final=_nullable({"type": "integer"}),
)
KNOWLEDGE_BASE_UPDATE_SCHEMA = _object_schema(
    kb_name=_nullable({"type": "string"}),
    description=_nullable({"type": "string"}),
    emoji=_nullable({"type": "string"}),
    embedding_provider_id=_nullable({"type": "string"}),
    rerank_provider_id=_nullable({"type": "string"}),
    chunk_size=_nullable({"type": "integer"}),
    chunk_overlap=_nullable({"type": "integer"}),
    top_k_dense=_nullable({"type": "integer"}),
    top_k_sparse=_nullable({"type": "integer"}),
    top_m_final=_nullable({"type": "integer"}),
)
# NOTE(review): `file_path` is neither required nor nullable here, unlike the
# other optional fields — verify whether it should be wrapped in _nullable.
KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA = _object_schema(
    required=(
        "doc_id",
        "kb_id",
        "doc_name",
        "file_type",
        "file_size",
        "chunk_count",
        "media_count",
    ),
    doc_id={"type": "string"},
    kb_id={"type": "string"},
    doc_name={"type": "string"},
    file_type={"type": "string"},
    file_size={"type": "integer"},
    file_path={"type": "string"},
    chunk_count={"type": "integer"},
    media_count={"type": "integer"},
    created_at=_nullable({"type": "string"}),
    updated_at=_nullable({"type": "string"}),
)
KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA = _object_schema(
    required=(
        "chunk_id",
        "doc_id",
        "kb_id",
        "kb_name",
        "doc_name",
        "chunk_index",
        "content",
        "score",
        "char_count",
    ),
    chunk_id={"type": "string"},
    doc_id={"type": "string"},
    kb_id={"type": "string"},
    kb_name={"type": "string"},
    doc_name={"type": "string"},
    chunk_index={"type": "integer"},
    content={"type": "string"},
    score={"type": "number"},
    char_count={"type": "integer"},
)
# Document upload: source is one of file_token / url / text — all optional at
# the schema level; presumably the handler enforces that exactly one is set
# (TODO confirm).
KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA = _object_schema(
    file_token=_nullable({"type": "string"}),
    url=_nullable({"type": "string"}),
    text=_nullable({"type": "string"}),
    file_name=_nullable({"type": "string"}),
    file_type=_nullable({"type": "string"}),
    chunk_size=_nullable({"type": "integer"}),
    chunk_overlap=_nullable({"type": "integer"}),
    batch_size=_nullable({"type": "integer"}),
    tasks_limit=_nullable({"type": "integer"}),
    max_retries=_nullable({"type": "integer"}),
    enable_cleaning=_nullable({"type": "boolean"}),
    cleaning_provider_id=_nullable({"type": "string"}),
)
KB_LIST_INPUT_SCHEMA = _object_schema()
KB_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("kbs",),
    kbs={"type": "array", "items": KNOWLEDGE_BASE_RECORD_SCHEMA},
)
KB_GET_INPUT_SCHEMA = _object_schema(
    required=("kb_id",),
    kb_id={"type": "string"},
)
# ---------------------------------------------------------------------------
# Knowledge-base CRUD / retrieval / document capabilities, plus registry
# command registration and skill management payloads.
# _nullable-wrapped outputs mean "null when the target id was not found".
# ---------------------------------------------------------------------------
KB_GET_OUTPUT_SCHEMA = _object_schema(
    required=("kb",),
    kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA),
)
KB_CREATE_INPUT_SCHEMA = _object_schema(
    required=("kb",),
    kb=KNOWLEDGE_BASE_CREATE_SCHEMA,
)
KB_CREATE_OUTPUT_SCHEMA = _object_schema(
    required=("kb",),
    kb=KNOWLEDGE_BASE_RECORD_SCHEMA,
)
KB_UPDATE_INPUT_SCHEMA = _object_schema(
    required=("kb_id", "kb"),
    kb_id={"type": "string"},
    kb=KNOWLEDGE_BASE_UPDATE_SCHEMA,
)
KB_UPDATE_OUTPUT_SCHEMA = _object_schema(
    required=("kb",),
    kb=_nullable(KNOWLEDGE_BASE_RECORD_SCHEMA),
)
KB_DELETE_INPUT_SCHEMA = _object_schema(
    required=("kb_id",),
    kb_id={"type": "string"},
)
KB_DELETE_OUTPUT_SCHEMA = _object_schema(
    required=("deleted",),
    deleted={"type": "boolean"},
)
# Retrieval across one or more KBs, addressable by id or by name.
KB_RETRIEVE_INPUT_SCHEMA = _object_schema(
    required=("query",),
    query={"type": "string"},
    kb_ids={"type": "array", "items": {"type": "string"}},
    kb_names={"type": "array", "items": {"type": "string"}},
    top_k_fusion={"type": "integer"},
    top_m_final={"type": "integer"},
)
KB_RETRIEVE_OUTPUT_SCHEMA = _object_schema(
    required=("result",),
    result=_nullable(
        _object_schema(
            required=("context_text", "results"),
            context_text={"type": "string"},
            results={
                "type": "array",
                "items": KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA,
            },
        )
    ),
)
KB_DOCUMENT_UPLOAD_INPUT_SCHEMA = _object_schema(
    required=("kb_id", "document"),
    kb_id={"type": "string"},
    document=KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA,
)
KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA = _object_schema(
    required=("document",),
    document=KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA,
)
# offset/limit pagination (contrast with message_history's cursor scheme).
KB_DOCUMENT_LIST_INPUT_SCHEMA = _object_schema(
    required=("kb_id",),
    kb_id={"type": "string"},
    offset={"type": "integer"},
    limit={"type": "integer"},
)
KB_DOCUMENT_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("documents",),
    documents={"type": "array", "items": KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA},
)
KB_DOCUMENT_GET_INPUT_SCHEMA = _object_schema(
    required=("kb_id", "doc_id"),
    kb_id={"type": "string"},
    doc_id={"type": "string"},
)
KB_DOCUMENT_GET_OUTPUT_SCHEMA = _object_schema(
    required=("document",),
    document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA),
)
KB_DOCUMENT_DELETE_INPUT_SCHEMA = _object_schema(
    required=("kb_id", "doc_id"),
    kb_id={"type": "string"},
    doc_id={"type": "string"},
)
KB_DOCUMENT_DELETE_OUTPUT_SCHEMA = _object_schema(
    required=("deleted",),
    deleted={"type": "boolean"},
)
KB_DOCUMENT_REFRESH_INPUT_SCHEMA = _object_schema(
    required=("kb_id", "doc_id"),
    kb_id={"type": "string"},
    doc_id={"type": "string"},
)
KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA = _object_schema(
    required=("document",),
    document=_nullable(KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA),
)
# Dynamic command registration: binds a command name to a handler identified
# by its fully-qualified name.
REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA = _object_schema(
    required=("command_name", "handler_full_name"),
    command_name={"type": "string"},
    handler_full_name={"type": "string"},
    source_event_type={"type": "string"},
    desc={"type": "string"},
    priority={"type": "integer"},
    use_regex={"type": "boolean"},
    ignore_prefix={"type": "boolean"},
)
REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA = _object_schema()
SKILL_REGISTER_INPUT_SCHEMA = _object_schema(
    required=("name", "path"),
    name={"type": "string"},
    path={"type": "string"},
    description={"type": "string"},
)
# Also reused as the element schema of skill.list output below.
SKILL_REGISTER_OUTPUT_SCHEMA = _object_schema(
    required=("name", "description", "path", "skill_dir"),
    name={"type": "string"},
    description={"type": "string"},
    path={"type": "string"},
    skill_dir={"type": "string"},
)
SKILL_UNREGISTER_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
SKILL_UNREGISTER_OUTPUT_SCHEMA = _object_schema(
    required=("removed",),
    removed={"type": "boolean"},
)
SKILL_LIST_INPUT_SCHEMA = _object_schema()
SKILL_LIST_OUTPUT_SCHEMA = _object_schema(
    required=("skills",),
    skills={
        "type": "array",
        "items": SKILL_REGISTER_OUTPUT_SCHEMA,
    },
)
# ---------------------------------------------------------------------------
# HTTP route registration, plugin metadata, handler-registry lookups, and
# provider (LLM / STT / TTS / embedding / rerank / manager) payload schemas.
# ---------------------------------------------------------------------------
HTTP_REGISTER_API_INPUT_SCHEMA = _object_schema(
    required=("route", "methods", "handler_capability"),
    route={"type": "string"},
    methods={"type": "array", "items": {"type": "string"}},
    handler_capability={"type": "string"},
    description={"type": "string"},
)
HTTP_REGISTER_API_OUTPUT_SCHEMA = _object_schema()
HTTP_UNREGISTER_API_INPUT_SCHEMA = _object_schema(
    required=("route", "methods"),
    route={"type": "string"},
    methods={"type": "array", "items": {"type": "string"}},
)
HTTP_UNREGISTER_API_OUTPUT_SCHEMA = _object_schema()
HTTP_LIST_APIS_INPUT_SCHEMA = _object_schema()
HTTP_LIST_APIS_OUTPUT_SCHEMA = _object_schema(
    required=("apis",),
    apis={"type": "array", "items": {"type": "object"}},
)
# Plugin metadata and config: shapes are free-form objects, so the contract
# is intentionally loose here.
METADATA_GET_PLUGIN_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
METADATA_GET_PLUGIN_OUTPUT_SCHEMA = _object_schema(
    required=("plugin",),
    plugin=_nullable({"type": "object"}),
)
METADATA_LIST_PLUGINS_INPUT_SCHEMA = _object_schema()
METADATA_LIST_PLUGINS_OUTPUT_SCHEMA = _object_schema(
    required=("plugins",),
    plugins={"type": "array", "items": {"type": "object"}},
)
METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema(
    required=("name",),
    name={"type": "string"},
)
METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema(
    required=("config",),
    config=_nullable({"type": "object"}),
)
METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA = _object_schema(
    required=("config",),
    config={"type": "object"},
)
METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA = _object_schema(
    required=("config",),
    config=_nullable({"type": "object"}),
)
REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA = _object_schema(
    required=("event_type",),
    event_type={"type": "string"},
)
REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA = _object_schema(
    required=("handlers",),
    handlers={"type": "array", "items": {"type": "object"}},
)
REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA = _object_schema(
    required=("full_name",),
    full_name={"type": "string"},
)
REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA = _object_schema(
    required=("handler",),
    handler=_nullable({"type": "object"}),
)
# Provider descriptors: META is the lightweight view; MANAGED extends it with
# lifecycle state (loaded/enabled) for provider-manager capabilities.
PROVIDER_META_SCHEMA = _object_schema(
    required=("id", "type", "provider_type"),
    id={"type": "string"},
    model=_nullable({"type": "string"}),
    type={"type": "string"},
    provider_type={"type": "string"},
)
MANAGED_PROVIDER_RECORD_SCHEMA = _object_schema(
    required=("id", "type", "provider_type", "loaded", "enabled"),
    id={"type": "string"},
    model=_nullable({"type": "string"}),
    type={"type": "string"},
    provider_type={"type": "string"},
    loaded={"type": "boolean"},
    enabled={"type": "boolean"},
    provider_source_id=_nullable({"type": "string"}),
)
# `umo` — presumably a unified message origin / session selector used to
# scope provider selection per conversation (TODO confirm terminology).
PROVIDER_CHANGE_EVENT_SCHEMA = _object_schema(
    required=("provider_id", "provider_type"),
    provider_id={"type": "string"},
    provider_type={"type": "string"},
    umo=_nullable({"type": "string"}),
)
LLM_TOOL_SPEC_SCHEMA = _object_schema(
    required=("name", "description", "parameters_schema", "active"),
    name={"type": "string"},
    description={"type": "string"},
    parameters_schema={"type": "object"},
    handler_ref=_nullable({"type": "string"}),
    handler_capability=_nullable({"type": "string"}),
    active={"type": "boolean"},
)
AGENT_SPEC_SCHEMA = _object_schema(
    required=("name", "description", "tool_names", "runner_class"),
    name={"type": "string"},
    description={"type": "string"},
    tool_names={"type": "array", "items": {"type": "string"}},
    runner_class={"type": "string"},
)
PROVIDER_GET_USING_INPUT_SCHEMA = _object_schema(umo=_nullable({"type": "string"}))
PROVIDER_GET_USING_OUTPUT_SCHEMA = _object_schema(
    required=("provider",),
    provider=_nullable(PROVIDER_META_SCHEMA),
)
PROVIDER_GET_BY_ID_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
)
PROVIDER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
    required=("provider",),
    provider=_nullable(PROVIDER_META_SCHEMA),
)
PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA = _object_schema(
    umo=_nullable({"type": "string"}),
)
PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id=_nullable({"type": "string"}),
)
PROVIDER_LIST_ALL_INPUT_SCHEMA = _object_schema()
PROVIDER_LIST_ALL_OUTPUT_SCHEMA = _object_schema(
    required=("providers",),
    providers={"type": "array", "items": PROVIDER_META_SCHEMA},
)
# Speech-to-text / text-to-speech round trips.
PROVIDER_STT_GET_TEXT_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "audio_url"),
    provider_id={"type": "string"},
    audio_url={"type": "string"},
)
PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA = _object_schema(
    required=("text",),
    text={"type": "string"},
)
PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "text"),
    provider_id={"type": "string"},
    text={"type": "string"},
)
PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA = _object_schema(
    required=("audio_path",),
    audio_path={"type": "string"},
)
PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
)
PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA = _object_schema(
    required=("supported",),
    supported={"type": "boolean"},
)
PROVIDER_TTS_AUDIO_CHUNK_SCHEMA = _object_schema(
    required=("audio_base64",),
    audio_base64={"type": "string"},
    text=_nullable({"type": "string"}),
)
PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
    text=_nullable({"type": "string"}),
    text_chunks={"type": "array", "items": {"type": "string"}},
)
# Streaming output: each streamed item is a single audio chunk.
PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA = PROVIDER_TTS_AUDIO_CHUNK_SCHEMA
PROVIDER_EMBEDDING_GET_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "text"),
    provider_id={"type": "string"},
    text={"type": "string"},
)
PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA = _object_schema(
    required=("embedding",),
    embedding={"type": "array", "items": {"type": "number"}},
)
PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "texts"),
    provider_id={"type": "string"},
    texts={"type": "array", "items": {"type": "string"}},
)
PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA = _object_schema(
    required=("embeddings",),
    embeddings={
        "type": "array",
        "items": {"type": "array", "items": {"type": "number"}},
    },
)
PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
)
PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA = _object_schema(
    required=("dim",),
    dim={"type": "integer"},
)
PROVIDER_RERANK_RESULT_SCHEMA = _object_schema(
    required=("index", "score", "document"),
    index={"type": "integer"},
    score={"type": "number"},
    document={"type": "string"},
)
PROVIDER_RERANK_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "query", "documents"),
    provider_id={"type": "string"},
    query={"type": "string"},
    documents={"type": "array", "items": {"type": "string"}},
    top_n=_nullable({"type": "integer"}),
)
PROVIDER_RERANK_OUTPUT_SCHEMA = _object_schema(
    required=("results",),
    results={"type": "array", "items": PROVIDER_RERANK_RESULT_SCHEMA},
)
# provider_manager.*: lifecycle operations on provider instances.
PROVIDER_MANAGER_SET_INPUT_SCHEMA = _object_schema(
    required=("provider_id", "provider_type"),
    provider_id={"type": "string"},
    provider_type={"type": "string"},
    umo=_nullable({"type": "string"}),
)
PROVIDER_MANAGER_SET_OUTPUT_SCHEMA = _object_schema()
PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
)
PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA = _object_schema(
    required=("provider",),
    provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA),
)
PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA = _object_schema(
    required=("provider_id",),
    provider_id={"type": "string"},
)
+PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA = _object_schema( + required=("config",), + config=_nullable({"type": "object"}), +) +PROVIDER_MANAGER_LOAD_INPUT_SCHEMA = _object_schema( + required=("provider_config",), + provider_config={"type": "object"}, +) +PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA = _object_schema( + required=("provider_id",), + provider_id={"type": "string"}, +) +PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_CREATE_INPUT_SCHEMA = _object_schema( + required=("provider_config",), + provider_config={"type": "object"}, +) +PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA = _object_schema( + required=("origin_provider_id", "new_config"), + origin_provider_id={"type": "string"}, + new_config={"type": "object"}, +) +PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA = _object_schema( + required=("provider",), + provider=_nullable(MANAGED_PROVIDER_RECORD_SCHEMA), +) +PROVIDER_MANAGER_DELETE_INPUT_SCHEMA = _object_schema( + provider_id=_nullable({"type": "string"}), + provider_source_id=_nullable({"type": "string"}), +) +PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA = _object_schema( + required=("providers",), + providers={"type": "array", "items": MANAGED_PROVIDER_RECORD_SCHEMA}, +) +PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA = _object_schema() +PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA = _object_schema( + required=("provider_id", "provider_type"), + provider_id={"type": "string"}, + provider_type={"type": "string"}, + umo=_nullable({"type": "string"}), +) +LLM_TOOL_MANAGER_GET_INPUT_SCHEMA = _object_schema() +LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA = 
_object_schema( + required=("registered", "active"), + registered={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, + active={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, +) +LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA = _object_schema( + required=("activated",), + activated={"type": "boolean"}, +) +LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA = _object_schema( + required=("deactivated",), + deactivated={"type": "boolean"}, +) +LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA = _object_schema( + required=("tools",), + tools={"type": "array", "items": LLM_TOOL_SPEC_SCHEMA}, +) +LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA = _object_schema( + required=("names",), + names={"type": "array", "items": {"type": "string"}}, +) +LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA = _object_schema( + required=("removed",), + removed={"type": "boolean"}, +) +AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA = _object_schema( + prompt=_nullable({"type": "string"}), + system_prompt=_nullable({"type": "string"}), + session_id=_nullable({"type": "string"}), + contexts={"type": "array", "items": {"type": "object"}}, + image_urls={"type": "array", "items": {"type": "string"}}, + tool_names=_nullable({"type": "array", "items": {"type": "string"}}), + tool_calls_result={"type": "array", "items": {"type": "object"}}, + provider_id=_nullable({"type": "string"}), + model=_nullable({"type": "string"}), + temperature={"type": "number"}, + max_steps={"type": "integer"}, + tool_call_timeout={"type": "integer"}, +) +AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA = LLM_CHAT_RAW_OUTPUT_SCHEMA +AGENT_REGISTRY_LIST_INPUT_SCHEMA = _object_schema() +AGENT_REGISTRY_LIST_OUTPUT_SCHEMA = _object_schema( + required=("agents",), + agents={"type": 
"array", "items": AGENT_SPEC_SCHEMA}, +) +AGENT_REGISTRY_GET_INPUT_SCHEMA = _object_schema( + required=("name",), + name={"type": "string"}, +) +AGENT_REGISTRY_GET_OUTPUT_SCHEMA = _object_schema( + required=("agent",), + agent=_nullable(AGENT_SPEC_SCHEMA), +) + +BUILTIN_CAPABILITY_SCHEMAS: dict[str, dict[str, JSONSchema]] = { + "llm.chat": {"input": LLM_CHAT_INPUT_SCHEMA, "output": LLM_CHAT_OUTPUT_SCHEMA}, + "llm.chat_raw": { + "input": LLM_CHAT_RAW_INPUT_SCHEMA, + "output": LLM_CHAT_RAW_OUTPUT_SCHEMA, + }, + "llm.stream_chat": { + "input": LLM_STREAM_CHAT_INPUT_SCHEMA, + "output": LLM_STREAM_CHAT_OUTPUT_SCHEMA, + }, + "memory.search": { + "input": MEMORY_SEARCH_INPUT_SCHEMA, + "output": MEMORY_SEARCH_OUTPUT_SCHEMA, + }, + "memory.save": { + "input": MEMORY_SAVE_INPUT_SCHEMA, + "output": MEMORY_SAVE_OUTPUT_SCHEMA, + }, + "memory.get": { + "input": MEMORY_GET_INPUT_SCHEMA, + "output": MEMORY_GET_OUTPUT_SCHEMA, + }, + "memory.list_keys": { + "input": MEMORY_LIST_KEYS_INPUT_SCHEMA, + "output": MEMORY_LIST_KEYS_OUTPUT_SCHEMA, + }, + "memory.exists": { + "input": MEMORY_EXISTS_INPUT_SCHEMA, + "output": MEMORY_EXISTS_OUTPUT_SCHEMA, + }, + "memory.delete": { + "input": MEMORY_DELETE_INPUT_SCHEMA, + "output": MEMORY_DELETE_OUTPUT_SCHEMA, + }, + "memory.clear_namespace": { + "input": MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA, + "output": MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA, + }, + "memory.save_with_ttl": { + "input": MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA, + "output": MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA, + }, + "memory.get_many": { + "input": MEMORY_GET_MANY_INPUT_SCHEMA, + "output": MEMORY_GET_MANY_OUTPUT_SCHEMA, + }, + "memory.delete_many": { + "input": MEMORY_DELETE_MANY_INPUT_SCHEMA, + "output": MEMORY_DELETE_MANY_OUTPUT_SCHEMA, + }, + "memory.count": { + "input": MEMORY_COUNT_INPUT_SCHEMA, + "output": MEMORY_COUNT_OUTPUT_SCHEMA, + }, + "memory.stats": { + "input": MEMORY_STATS_INPUT_SCHEMA, + "output": MEMORY_STATS_OUTPUT_SCHEMA, + }, + "db.get": {"input": DB_GET_INPUT_SCHEMA, 
"output": DB_GET_OUTPUT_SCHEMA}, + "db.set": {"input": DB_SET_INPUT_SCHEMA, "output": DB_SET_OUTPUT_SCHEMA}, + "db.delete": {"input": DB_DELETE_INPUT_SCHEMA, "output": DB_DELETE_OUTPUT_SCHEMA}, + "db.list": {"input": DB_LIST_INPUT_SCHEMA, "output": DB_LIST_OUTPUT_SCHEMA}, + "db.get_many": { + "input": DB_GET_MANY_INPUT_SCHEMA, + "output": DB_GET_MANY_OUTPUT_SCHEMA, + }, + "db.set_many": { + "input": DB_SET_MANY_INPUT_SCHEMA, + "output": DB_SET_MANY_OUTPUT_SCHEMA, + }, + "db.watch": {"input": DB_WATCH_INPUT_SCHEMA, "output": DB_WATCH_OUTPUT_SCHEMA}, + "platform.send": { + "input": PLATFORM_SEND_INPUT_SCHEMA, + "output": PLATFORM_SEND_OUTPUT_SCHEMA, + }, + "platform.send_image": { + "input": PLATFORM_SEND_IMAGE_INPUT_SCHEMA, + "output": PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA, + }, + "platform.send_chain": { + "input": PLATFORM_SEND_CHAIN_INPUT_SCHEMA, + "output": PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA, + }, + "platform.send_by_session": { + "input": PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA, + "output": PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA, + }, + "platform.get_group": { + "input": PLATFORM_GET_GROUP_INPUT_SCHEMA, + "output": PLATFORM_GET_GROUP_OUTPUT_SCHEMA, + }, + "platform.get_members": { + "input": PLATFORM_GET_MEMBERS_INPUT_SCHEMA, + "output": PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA, + }, + "platform.list_instances": { + "input": PLATFORM_LIST_INSTANCES_INPUT_SCHEMA, + "output": PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA, + }, + "session.plugin.is_enabled": { + "input": SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA, + "output": SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA, + }, + "session.plugin.filter_handlers": { + "input": SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA, + "output": SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA, + }, + "session.service.is_llm_enabled": { + "input": SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA, + "output": SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA, + }, + "session.service.set_llm_status": { + "input": SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA, + "output": 
SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA, + }, + "session.service.is_tts_enabled": { + "input": SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA, + "output": SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA, + }, + "session.service.set_tts_status": { + "input": SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA, + "output": SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA, + }, + "persona.get": { + "input": PERSONA_GET_INPUT_SCHEMA, + "output": PERSONA_GET_OUTPUT_SCHEMA, + }, + "persona.list": { + "input": PERSONA_LIST_INPUT_SCHEMA, + "output": PERSONA_LIST_OUTPUT_SCHEMA, + }, + "persona.create": { + "input": PERSONA_CREATE_INPUT_SCHEMA, + "output": PERSONA_CREATE_OUTPUT_SCHEMA, + }, + "persona.update": { + "input": PERSONA_UPDATE_INPUT_SCHEMA, + "output": PERSONA_UPDATE_OUTPUT_SCHEMA, + }, + "persona.delete": { + "input": PERSONA_DELETE_INPUT_SCHEMA, + "output": PERSONA_DELETE_OUTPUT_SCHEMA, + }, + "conversation.new": { + "input": CONVERSATION_NEW_INPUT_SCHEMA, + "output": CONVERSATION_NEW_OUTPUT_SCHEMA, + }, + "conversation.switch": { + "input": CONVERSATION_SWITCH_INPUT_SCHEMA, + "output": CONVERSATION_SWITCH_OUTPUT_SCHEMA, + }, + "conversation.delete": { + "input": CONVERSATION_DELETE_INPUT_SCHEMA, + "output": CONVERSATION_DELETE_OUTPUT_SCHEMA, + }, + "conversation.get": { + "input": CONVERSATION_GET_INPUT_SCHEMA, + "output": CONVERSATION_GET_OUTPUT_SCHEMA, + }, + "conversation.get_current": { + "input": CONVERSATION_GET_CURRENT_INPUT_SCHEMA, + "output": CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA, + }, + "conversation.list": { + "input": CONVERSATION_LIST_INPUT_SCHEMA, + "output": CONVERSATION_LIST_OUTPUT_SCHEMA, + }, + "conversation.update": { + "input": CONVERSATION_UPDATE_INPUT_SCHEMA, + "output": CONVERSATION_UPDATE_OUTPUT_SCHEMA, + }, + "conversation.unset_persona": { + "input": CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA, + "output": CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA, + }, + "message_history.list": { + "input": MESSAGE_HISTORY_LIST_INPUT_SCHEMA, + "output": 
MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA, + }, + "message_history.get_by_id": { + "input": MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA, + }, + "message_history.append": { + "input": MESSAGE_HISTORY_APPEND_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA, + }, + "message_history.delete_before": { + "input": MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA, + }, + "message_history.delete_after": { + "input": MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA, + }, + "message_history.delete_all": { + "input": MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA, + "output": MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA, + }, + "mcp.local.get": { + "input": MCP_LOCAL_GET_INPUT_SCHEMA, + "output": MCP_LOCAL_GET_OUTPUT_SCHEMA, + }, + "mcp.local.list": { + "input": MCP_LOCAL_LIST_INPUT_SCHEMA, + "output": MCP_LOCAL_LIST_OUTPUT_SCHEMA, + }, + "mcp.local.enable": { + "input": MCP_LOCAL_ENABLE_INPUT_SCHEMA, + "output": MCP_LOCAL_ENABLE_OUTPUT_SCHEMA, + }, + "mcp.local.disable": { + "input": MCP_LOCAL_DISABLE_INPUT_SCHEMA, + "output": MCP_LOCAL_DISABLE_OUTPUT_SCHEMA, + }, + "mcp.local.wait_until_ready": { + "input": MCP_LOCAL_WAIT_UNTIL_READY_INPUT_SCHEMA, + "output": MCP_LOCAL_WAIT_UNTIL_READY_OUTPUT_SCHEMA, + }, + "mcp.session.open": { + "input": MCP_SESSION_OPEN_INPUT_SCHEMA, + "output": MCP_SESSION_OPEN_OUTPUT_SCHEMA, + }, + "mcp.session.list_tools": { + "input": MCP_SESSION_LIST_TOOLS_INPUT_SCHEMA, + "output": MCP_SESSION_LIST_TOOLS_OUTPUT_SCHEMA, + }, + "mcp.session.call_tool": { + "input": MCP_SESSION_CALL_TOOL_INPUT_SCHEMA, + "output": MCP_SESSION_CALL_TOOL_OUTPUT_SCHEMA, + }, + "mcp.session.close": { + "input": MCP_SESSION_CLOSE_INPUT_SCHEMA, + "output": MCP_SESSION_CLOSE_OUTPUT_SCHEMA, + }, + "mcp.global.register": { + "input": MCP_GLOBAL_REGISTER_INPUT_SCHEMA, + "output": MCP_GLOBAL_REGISTER_OUTPUT_SCHEMA, + }, + "mcp.global.get": 
{ + "input": MCP_GLOBAL_GET_INPUT_SCHEMA, + "output": MCP_GLOBAL_GET_OUTPUT_SCHEMA, + }, + "mcp.global.list": { + "input": MCP_GLOBAL_LIST_INPUT_SCHEMA, + "output": MCP_GLOBAL_LIST_OUTPUT_SCHEMA, + }, + "mcp.global.enable": { + "input": MCP_GLOBAL_ENABLE_INPUT_SCHEMA, + "output": MCP_GLOBAL_ENABLE_OUTPUT_SCHEMA, + }, + "mcp.global.disable": { + "input": MCP_GLOBAL_DISABLE_INPUT_SCHEMA, + "output": MCP_GLOBAL_DISABLE_OUTPUT_SCHEMA, + }, + "mcp.global.unregister": { + "input": MCP_GLOBAL_UNREGISTER_INPUT_SCHEMA, + "output": MCP_GLOBAL_UNREGISTER_OUTPUT_SCHEMA, + }, + "internal.mcp.local.execute": { + "input": INTERNAL_MCP_LOCAL_EXECUTE_INPUT_SCHEMA, + "output": INTERNAL_MCP_LOCAL_EXECUTE_OUTPUT_SCHEMA, + }, + "kb.list": {"input": KB_LIST_INPUT_SCHEMA, "output": KB_LIST_OUTPUT_SCHEMA}, + "kb.get": {"input": KB_GET_INPUT_SCHEMA, "output": KB_GET_OUTPUT_SCHEMA}, + "kb.create": { + "input": KB_CREATE_INPUT_SCHEMA, + "output": KB_CREATE_OUTPUT_SCHEMA, + }, + "kb.update": { + "input": KB_UPDATE_INPUT_SCHEMA, + "output": KB_UPDATE_OUTPUT_SCHEMA, + }, + "kb.delete": { + "input": KB_DELETE_INPUT_SCHEMA, + "output": KB_DELETE_OUTPUT_SCHEMA, + }, + "kb.retrieve": { + "input": KB_RETRIEVE_INPUT_SCHEMA, + "output": KB_RETRIEVE_OUTPUT_SCHEMA, + }, + "kb.document.upload": { + "input": KB_DOCUMENT_UPLOAD_INPUT_SCHEMA, + "output": KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA, + }, + "kb.document.list": { + "input": KB_DOCUMENT_LIST_INPUT_SCHEMA, + "output": KB_DOCUMENT_LIST_OUTPUT_SCHEMA, + }, + "kb.document.get": { + "input": KB_DOCUMENT_GET_INPUT_SCHEMA, + "output": KB_DOCUMENT_GET_OUTPUT_SCHEMA, + }, + "kb.document.delete": { + "input": KB_DOCUMENT_DELETE_INPUT_SCHEMA, + "output": KB_DOCUMENT_DELETE_OUTPUT_SCHEMA, + }, + "kb.document.refresh": { + "input": KB_DOCUMENT_REFRESH_INPUT_SCHEMA, + "output": KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA, + }, + "registry.command.register": { + "input": REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA, + "output": REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA, + }, + 
"skill.register": { + "input": SKILL_REGISTER_INPUT_SCHEMA, + "output": SKILL_REGISTER_OUTPUT_SCHEMA, + }, + "skill.unregister": { + "input": SKILL_UNREGISTER_INPUT_SCHEMA, + "output": SKILL_UNREGISTER_OUTPUT_SCHEMA, + }, + "skill.list": { + "input": SKILL_LIST_INPUT_SCHEMA, + "output": SKILL_LIST_OUTPUT_SCHEMA, + }, + "http.register_api": { + "input": HTTP_REGISTER_API_INPUT_SCHEMA, + "output": HTTP_REGISTER_API_OUTPUT_SCHEMA, + }, + "http.unregister_api": { + "input": HTTP_UNREGISTER_API_INPUT_SCHEMA, + "output": HTTP_UNREGISTER_API_OUTPUT_SCHEMA, + }, + "http.list_apis": { + "input": HTTP_LIST_APIS_INPUT_SCHEMA, + "output": HTTP_LIST_APIS_OUTPUT_SCHEMA, + }, + "metadata.get_plugin": { + "input": METADATA_GET_PLUGIN_INPUT_SCHEMA, + "output": METADATA_GET_PLUGIN_OUTPUT_SCHEMA, + }, + "metadata.list_plugins": { + "input": METADATA_LIST_PLUGINS_INPUT_SCHEMA, + "output": METADATA_LIST_PLUGINS_OUTPUT_SCHEMA, + }, + "metadata.get_plugin_config": { + "input": METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA, + "output": METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA, + }, + "metadata.save_plugin_config": { + "input": METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA, + "output": METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA, + }, + "registry.get_handlers_by_event_type": { + "input": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA, + "output": REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA, + }, + "registry.get_handler_by_full_name": { + "input": REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA, + "output": REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA, + }, + "provider.get_using": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.get_by_id": { + "input": PROVIDER_GET_BY_ID_INPUT_SCHEMA, + "output": PROVIDER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "provider.get_current_chat_provider_id": { + "input": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA, + "output": PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA, + }, + "provider.list_all": { 
+ "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_tts": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_stt": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_embedding": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.list_all_rerank": { + "input": PROVIDER_LIST_ALL_INPUT_SCHEMA, + "output": PROVIDER_LIST_ALL_OUTPUT_SCHEMA, + }, + "provider.get_using_tts": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.get_using_stt": { + "input": PROVIDER_GET_USING_INPUT_SCHEMA, + "output": PROVIDER_GET_USING_OUTPUT_SCHEMA, + }, + "provider.stt.get_text": { + "input": PROVIDER_STT_GET_TEXT_INPUT_SCHEMA, + "output": PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA, + }, + "provider.tts.get_audio": { + "input": PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA, + "output": PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA, + }, + "provider.tts.support_stream": { + "input": PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA, + "output": PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA, + }, + "provider.tts.get_audio_stream": { + "input": PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA, + "output": PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA, + }, + "provider.embedding.get_embedding": { + "input": PROVIDER_EMBEDDING_GET_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA, + }, + "provider.embedding.get_embeddings": { + "input": PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA, + }, + "provider.embedding.get_dim": { + "input": PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA, + "output": PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA, + }, + "provider.rerank.rerank": { + "input": PROVIDER_RERANK_INPUT_SCHEMA, + "output": PROVIDER_RERANK_OUTPUT_SCHEMA, + }, + "provider.manager.set": { + 
"input": PROVIDER_MANAGER_SET_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_SET_OUTPUT_SCHEMA, + }, + "provider.manager.get_by_id": { + "input": PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "provider.manager.get_merged_provider_config": { + "input": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA, + }, + "provider.manager.load": { + "input": PROVIDER_MANAGER_LOAD_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA, + }, + "provider.manager.terminate": { + "input": PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA, + }, + "provider.manager.create": { + "input": PROVIDER_MANAGER_CREATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA, + }, + "provider.manager.update": { + "input": PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA, + }, + "provider.manager.delete": { + "input": PROVIDER_MANAGER_DELETE_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA, + }, + "provider.manager.get_insts": { + "input": PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA, + }, + "provider.manager.watch_changes": { + "input": PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA, + "output": PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA, + }, + "platform.manager.get_by_id": { + "input": PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA, + }, + "platform.manager.clear_errors": { + "input": PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA, + }, + "platform.manager.get_stats": { + "input": PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA, + "output": PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA, + }, + "permission.check": { + "input": PERMISSION_CHECK_INPUT_SCHEMA, + "output": PERMISSION_CHECK_OUTPUT_SCHEMA, + }, + 
"permission.get_admins": { + "input": PERMISSION_GET_ADMINS_INPUT_SCHEMA, + "output": PERMISSION_GET_ADMINS_OUTPUT_SCHEMA, + }, + "permission.manager.add_admin": { + "input": PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA, + "output": PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA, + }, + "permission.manager.remove_admin": { + "input": PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA, + "output": PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA, + }, + "llm_tool.manager.get": { + "input": LLM_TOOL_MANAGER_GET_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA, + }, + "llm_tool.manager.activate": { + "input": LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA, + }, + "llm_tool.manager.deactivate": { + "input": LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA, + }, + "llm_tool.manager.add": { + "input": LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA, + }, + "llm_tool.manager.remove": { + "input": LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA, + "output": LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA, + }, + "agent.tool_loop.run": { + "input": AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA, + "output": AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA, + }, + "agent.registry.list": { + "input": AGENT_REGISTRY_LIST_INPUT_SCHEMA, + "output": AGENT_REGISTRY_LIST_OUTPUT_SCHEMA, + }, + "agent.registry.get": { + "input": AGENT_REGISTRY_GET_INPUT_SCHEMA, + "output": AGENT_REGISTRY_GET_OUTPUT_SCHEMA, + }, + "system.get_data_dir": { + "input": SYSTEM_GET_DATA_DIR_INPUT_SCHEMA, + "output": SYSTEM_GET_DATA_DIR_OUTPUT_SCHEMA, + }, + "system.text_to_image": { + "input": SYSTEM_TEXT_TO_IMAGE_INPUT_SCHEMA, + "output": SYSTEM_TEXT_TO_IMAGE_OUTPUT_SCHEMA, + }, + "system.html_render": { + "input": SYSTEM_HTML_RENDER_INPUT_SCHEMA, + "output": SYSTEM_HTML_RENDER_OUTPUT_SCHEMA, + }, + "system.file.register": { + "input": SYSTEM_FILE_REGISTER_INPUT_SCHEMA, + "output": SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA, + }, + 
"system.file.handle": { + "input": SYSTEM_FILE_HANDLE_INPUT_SCHEMA, + "output": SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA, + }, + "system.session_waiter.register": { + "input": SYSTEM_SESSION_WAITER_REGISTER_INPUT_SCHEMA, + "output": SYSTEM_SESSION_WAITER_REGISTER_OUTPUT_SCHEMA, + }, + "system.session_waiter.unregister": { + "input": SYSTEM_SESSION_WAITER_UNREGISTER_INPUT_SCHEMA, + "output": SYSTEM_SESSION_WAITER_UNREGISTER_OUTPUT_SCHEMA, + }, + "system.event.react": { + "input": SYSTEM_EVENT_REACT_INPUT_SCHEMA, + "output": SYSTEM_EVENT_REACT_OUTPUT_SCHEMA, + }, + "system.event.send_typing": { + "input": SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA, + }, + "system.event.send_streaming": { + "input": SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA, + }, + "system.event.send_streaming_chunk": { + "input": SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA, + }, + "system.event.send_streaming_close": { + "input": SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA, + "output": SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA, + }, + "system.event.llm.get_state": { + "input": SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA, + "output": SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA, + }, + "system.event.llm.request": { + "input": SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA, + "output": SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA, + }, + "system.event.result.get": { + "input": SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA, + }, + "system.event.result.set": { + "input": SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA, + }, + "system.event.result.clear": { + "input": SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA, + "output": SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA, + }, + "system.event.handler_whitelist.get": { + "input": SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA, + "output": 
SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA, + }, + "system.event.handler_whitelist.set": { + "input": SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA, + "output": SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA, + }, +} + + +__all__ = [ + "BUILTIN_CAPABILITY_SCHEMAS", + "DB_DELETE_INPUT_SCHEMA", + "DB_DELETE_OUTPUT_SCHEMA", + "DB_GET_INPUT_SCHEMA", + "DB_GET_MANY_INPUT_SCHEMA", + "DB_GET_MANY_OUTPUT_SCHEMA", + "DB_GET_OUTPUT_SCHEMA", + "DB_LIST_INPUT_SCHEMA", + "DB_LIST_OUTPUT_SCHEMA", + "DB_SET_INPUT_SCHEMA", + "DB_SET_MANY_INPUT_SCHEMA", + "DB_SET_MANY_OUTPUT_SCHEMA", + "DB_SET_OUTPUT_SCHEMA", + "DB_WATCH_INPUT_SCHEMA", + "DB_WATCH_OUTPUT_SCHEMA", + "HTTP_LIST_APIS_INPUT_SCHEMA", + "HTTP_LIST_APIS_OUTPUT_SCHEMA", + "HTTP_REGISTER_API_INPUT_SCHEMA", + "HTTP_REGISTER_API_OUTPUT_SCHEMA", + "HTTP_UNREGISTER_API_INPUT_SCHEMA", + "HTTP_UNREGISTER_API_OUTPUT_SCHEMA", + "JSONSchema", + "LLM_CHAT_INPUT_SCHEMA", + "LLM_CHAT_OUTPUT_SCHEMA", + "LLM_CHAT_RAW_INPUT_SCHEMA", + "LLM_CHAT_RAW_OUTPUT_SCHEMA", + "LLM_STREAM_CHAT_INPUT_SCHEMA", + "LLM_STREAM_CHAT_OUTPUT_SCHEMA", + "MEMORY_CLEAR_NAMESPACE_INPUT_SCHEMA", + "MEMORY_CLEAR_NAMESPACE_OUTPUT_SCHEMA", + "MEMORY_COUNT_INPUT_SCHEMA", + "MEMORY_COUNT_OUTPUT_SCHEMA", + "MEMORY_DELETE_INPUT_SCHEMA", + "MEMORY_DELETE_MANY_INPUT_SCHEMA", + "MEMORY_DELETE_MANY_OUTPUT_SCHEMA", + "MEMORY_DELETE_OUTPUT_SCHEMA", + "MEMORY_EXISTS_INPUT_SCHEMA", + "MEMORY_EXISTS_OUTPUT_SCHEMA", + "MEMORY_GET_INPUT_SCHEMA", + "MEMORY_GET_MANY_INPUT_SCHEMA", + "MEMORY_GET_MANY_OUTPUT_SCHEMA", + "MEMORY_GET_OUTPUT_SCHEMA", + "MEMORY_LIST_KEYS_INPUT_SCHEMA", + "MEMORY_LIST_KEYS_OUTPUT_SCHEMA", + "MEMORY_SAVE_INPUT_SCHEMA", + "MEMORY_SAVE_OUTPUT_SCHEMA", + "MEMORY_SAVE_WITH_TTL_INPUT_SCHEMA", + "MEMORY_SAVE_WITH_TTL_OUTPUT_SCHEMA", + "MEMORY_SEARCH_INPUT_SCHEMA", + "MEMORY_SEARCH_OUTPUT_SCHEMA", + "MEMORY_STATS_INPUT_SCHEMA", + "MEMORY_STATS_OUTPUT_SCHEMA", + "METADATA_GET_PLUGIN_CONFIG_INPUT_SCHEMA", + "METADATA_GET_PLUGIN_CONFIG_OUTPUT_SCHEMA", + 
"METADATA_SAVE_PLUGIN_CONFIG_INPUT_SCHEMA", + "METADATA_SAVE_PLUGIN_CONFIG_OUTPUT_SCHEMA", + "METADATA_GET_PLUGIN_INPUT_SCHEMA", + "METADATA_GET_PLUGIN_OUTPUT_SCHEMA", + "METADATA_LIST_PLUGINS_INPUT_SCHEMA", + "METADATA_LIST_PLUGINS_OUTPUT_SCHEMA", + "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_INPUT_SCHEMA", + "PROVIDER_GET_CURRENT_CHAT_PROVIDER_ID_OUTPUT_SCHEMA", + "PROVIDER_GET_BY_ID_INPUT_SCHEMA", + "PROVIDER_GET_BY_ID_OUTPUT_SCHEMA", + "PROVIDER_GET_USING_INPUT_SCHEMA", + "PROVIDER_GET_USING_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_DIM_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_DIM_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_MANY_INPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_MANY_OUTPUT_SCHEMA", + "PROVIDER_EMBEDDING_GET_OUTPUT_SCHEMA", + "PROVIDER_CHANGE_EVENT_SCHEMA", + "PROVIDER_LIST_ALL_INPUT_SCHEMA", + "PROVIDER_LIST_ALL_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_CREATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_CREATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_DELETE_INPUT_SCHEMA", + "PROVIDER_MANAGER_DELETE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_BY_ID_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_BY_ID_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_MERGED_PROVIDER_CONFIG_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_GET_INSTS_INPUT_SCHEMA", + "PROVIDER_MANAGER_GET_INSTS_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_LOAD_INPUT_SCHEMA", + "PROVIDER_MANAGER_LOAD_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_SET_INPUT_SCHEMA", + "PROVIDER_MANAGER_SET_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_TERMINATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_TERMINATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_UPDATE_INPUT_SCHEMA", + "PROVIDER_MANAGER_UPDATE_OUTPUT_SCHEMA", + "PROVIDER_MANAGER_WATCH_CHANGES_INPUT_SCHEMA", + "PROVIDER_MANAGER_WATCH_CHANGES_OUTPUT_SCHEMA", + "PROVIDER_META_SCHEMA", + "PROVIDER_RERANK_INPUT_SCHEMA", + "PROVIDER_RERANK_OUTPUT_SCHEMA", + "PROVIDER_RERANK_RESULT_SCHEMA", + "PROVIDER_STT_GET_TEXT_INPUT_SCHEMA", + 
"PROVIDER_STT_GET_TEXT_OUTPUT_SCHEMA", + "PROVIDER_TTS_AUDIO_CHUNK_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_INPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_OUTPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_STREAM_INPUT_SCHEMA", + "PROVIDER_TTS_GET_AUDIO_STREAM_OUTPUT_SCHEMA", + "PROVIDER_TTS_SUPPORT_STREAM_INPUT_SCHEMA", + "PROVIDER_TTS_SUPPORT_STREAM_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_ACTIVATE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_ACTIVATE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_ADD_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_ADD_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_REMOVE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_REMOVE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_DEACTIVATE_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_DEACTIVATE_OUTPUT_SCHEMA", + "LLM_TOOL_MANAGER_GET_INPUT_SCHEMA", + "LLM_TOOL_MANAGER_GET_OUTPUT_SCHEMA", + "LLM_TOOL_SPEC_SCHEMA", + "AGENT_REGISTRY_GET_INPUT_SCHEMA", + "AGENT_REGISTRY_GET_OUTPUT_SCHEMA", + "AGENT_REGISTRY_LIST_INPUT_SCHEMA", + "AGENT_REGISTRY_LIST_OUTPUT_SCHEMA", + "AGENT_SPEC_SCHEMA", + "AGENT_TOOL_LOOP_RUN_INPUT_SCHEMA", + "AGENT_TOOL_LOOP_RUN_OUTPUT_SCHEMA", + "MANAGED_PROVIDER_RECORD_SCHEMA", + "PLATFORM_ERROR_SCHEMA", + "PLATFORM_GET_MEMBERS_INPUT_SCHEMA", + "PLATFORM_GET_MEMBERS_OUTPUT_SCHEMA", + "PLATFORM_GET_GROUP_INPUT_SCHEMA", + "PLATFORM_GET_GROUP_OUTPUT_SCHEMA", + "PLATFORM_INSTANCE_SCHEMA", + "PLATFORM_LIST_INSTANCES_INPUT_SCHEMA", + "PLATFORM_LIST_INSTANCES_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_CLEAR_ERRORS_INPUT_SCHEMA", + "PLATFORM_MANAGER_CLEAR_ERRORS_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_GET_BY_ID_INPUT_SCHEMA", + "PLATFORM_MANAGER_GET_BY_ID_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_GET_STATS_INPUT_SCHEMA", + "PLATFORM_MANAGER_GET_STATS_OUTPUT_SCHEMA", + "PLATFORM_MANAGER_STATE_SCHEMA", + "PERMISSION_CHECK_INPUT_SCHEMA", + "PERMISSION_CHECK_OUTPUT_SCHEMA", + "PERMISSION_CHECK_RESULT_SCHEMA", + "PERMISSION_GET_ADMINS_INPUT_SCHEMA", + "PERMISSION_GET_ADMINS_OUTPUT_SCHEMA", + "PERMISSION_MANAGER_ADD_ADMIN_INPUT_SCHEMA", + "PERMISSION_MANAGER_ADD_ADMIN_OUTPUT_SCHEMA", + 
"PERMISSION_MANAGER_REMOVE_ADMIN_INPUT_SCHEMA", + "PERMISSION_MANAGER_REMOVE_ADMIN_OUTPUT_SCHEMA", + "PERMISSION_ROLE_SCHEMA", + "PLATFORM_SEND_CHAIN_INPUT_SCHEMA", + "PLATFORM_SEND_CHAIN_OUTPUT_SCHEMA", + "PLATFORM_SEND_BY_SESSION_INPUT_SCHEMA", + "PLATFORM_SEND_BY_SESSION_OUTPUT_SCHEMA", + "PLATFORM_SEND_IMAGE_INPUT_SCHEMA", + "PLATFORM_SEND_IMAGE_OUTPUT_SCHEMA", + "PLATFORM_SEND_INPUT_SCHEMA", + "PLATFORM_SEND_OUTPUT_SCHEMA", + "PLATFORM_STATS_SCHEMA", + "PERSONA_CREATE_INPUT_SCHEMA", + "PERSONA_CREATE_OUTPUT_SCHEMA", + "PERSONA_CREATE_SCHEMA", + "PERSONA_DELETE_INPUT_SCHEMA", + "PERSONA_DELETE_OUTPUT_SCHEMA", + "PERSONA_GET_INPUT_SCHEMA", + "PERSONA_GET_OUTPUT_SCHEMA", + "PERSONA_LIST_INPUT_SCHEMA", + "PERSONA_LIST_OUTPUT_SCHEMA", + "PERSONA_RECORD_SCHEMA", + "PERSONA_UPDATE_INPUT_SCHEMA", + "PERSONA_UPDATE_OUTPUT_SCHEMA", + "PERSONA_UPDATE_SCHEMA", + "CONVERSATION_CREATE_SCHEMA", + "CONVERSATION_DELETE_INPUT_SCHEMA", + "CONVERSATION_DELETE_OUTPUT_SCHEMA", + "CONVERSATION_GET_CURRENT_INPUT_SCHEMA", + "CONVERSATION_GET_CURRENT_OUTPUT_SCHEMA", + "CONVERSATION_GET_INPUT_SCHEMA", + "CONVERSATION_GET_OUTPUT_SCHEMA", + "CONVERSATION_LIST_INPUT_SCHEMA", + "CONVERSATION_LIST_OUTPUT_SCHEMA", + "CONVERSATION_NEW_INPUT_SCHEMA", + "CONVERSATION_NEW_OUTPUT_SCHEMA", + "CONVERSATION_RECORD_SCHEMA", + "CONVERSATION_SWITCH_INPUT_SCHEMA", + "CONVERSATION_SWITCH_OUTPUT_SCHEMA", + "CONVERSATION_UNSET_PERSONA_INPUT_SCHEMA", + "CONVERSATION_UNSET_PERSONA_OUTPUT_SCHEMA", + "CONVERSATION_UPDATE_INPUT_SCHEMA", + "CONVERSATION_UPDATE_OUTPUT_SCHEMA", + "CONVERSATION_UPDATE_SCHEMA", + "MESSAGE_HISTORY_APPEND_INPUT_SCHEMA", + "MESSAGE_HISTORY_APPEND_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_AFTER_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_AFTER_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_ALL_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_ALL_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_BEFORE_INPUT_SCHEMA", + "MESSAGE_HISTORY_DELETE_BEFORE_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_GET_BY_ID_INPUT_SCHEMA", + 
"MESSAGE_HISTORY_GET_BY_ID_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_LIST_INPUT_SCHEMA", + "MESSAGE_HISTORY_LIST_OUTPUT_SCHEMA", + "MESSAGE_HISTORY_PAGE_SCHEMA", + "MESSAGE_HISTORY_RECORD_SCHEMA", + "MESSAGE_HISTORY_SENDER_SCHEMA", + "MESSAGE_HISTORY_SESSION_SCHEMA", + "KB_CREATE_INPUT_SCHEMA", + "KB_CREATE_OUTPUT_SCHEMA", + "KB_DOCUMENT_DELETE_INPUT_SCHEMA", + "KB_DOCUMENT_DELETE_OUTPUT_SCHEMA", + "KB_DOCUMENT_GET_INPUT_SCHEMA", + "KB_DOCUMENT_GET_OUTPUT_SCHEMA", + "KB_DOCUMENT_LIST_INPUT_SCHEMA", + "KB_DOCUMENT_LIST_OUTPUT_SCHEMA", + "KB_DOCUMENT_REFRESH_INPUT_SCHEMA", + "KB_DOCUMENT_REFRESH_OUTPUT_SCHEMA", + "KB_DOCUMENT_UPLOAD_INPUT_SCHEMA", + "KB_DOCUMENT_UPLOAD_OUTPUT_SCHEMA", + "KB_DELETE_INPUT_SCHEMA", + "KB_DELETE_OUTPUT_SCHEMA", + "KB_GET_INPUT_SCHEMA", + "KB_GET_OUTPUT_SCHEMA", + "KB_LIST_INPUT_SCHEMA", + "KB_LIST_OUTPUT_SCHEMA", + "KB_RETRIEVE_INPUT_SCHEMA", + "KB_RETRIEVE_OUTPUT_SCHEMA", + "KB_UPDATE_INPUT_SCHEMA", + "KB_UPDATE_OUTPUT_SCHEMA", + "KNOWLEDGE_BASE_CREATE_SCHEMA", + "KNOWLEDGE_BASE_DOCUMENT_RECORD_SCHEMA", + "KNOWLEDGE_BASE_DOCUMENT_UPLOAD_SCHEMA", + "KNOWLEDGE_BASE_RECORD_SCHEMA", + "KNOWLEDGE_BASE_RETRIEVE_RESULT_SCHEMA", + "KNOWLEDGE_BASE_UPDATE_SCHEMA", + "REGISTRY_COMMAND_REGISTER_INPUT_SCHEMA", + "REGISTRY_COMMAND_REGISTER_OUTPUT_SCHEMA", + "SKILL_REGISTER_INPUT_SCHEMA", + "SKILL_REGISTER_OUTPUT_SCHEMA", + "SKILL_UNREGISTER_INPUT_SCHEMA", + "SKILL_UNREGISTER_OUTPUT_SCHEMA", + "SKILL_LIST_INPUT_SCHEMA", + "SKILL_LIST_OUTPUT_SCHEMA", + "REGISTRY_GET_HANDLER_BY_FULL_NAME_INPUT_SCHEMA", + "REGISTRY_GET_HANDLER_BY_FULL_NAME_OUTPUT_SCHEMA", + "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_INPUT_SCHEMA", + "REGISTRY_GET_HANDLERS_BY_EVENT_TYPE_OUTPUT_SCHEMA", + "SESSION_PLUGIN_FILTER_HANDLERS_INPUT_SCHEMA", + "SESSION_PLUGIN_FILTER_HANDLERS_OUTPUT_SCHEMA", + "SESSION_PLUGIN_IS_ENABLED_INPUT_SCHEMA", + "SESSION_PLUGIN_IS_ENABLED_OUTPUT_SCHEMA", + "SESSION_REF_SCHEMA", + "SESSION_SERVICE_IS_LLM_ENABLED_INPUT_SCHEMA", + 
"SESSION_SERVICE_IS_LLM_ENABLED_OUTPUT_SCHEMA", + "SESSION_SERVICE_IS_TTS_ENABLED_INPUT_SCHEMA", + "SESSION_SERVICE_IS_TTS_ENABLED_OUTPUT_SCHEMA", + "SESSION_SERVICE_SET_LLM_STATUS_INPUT_SCHEMA", + "SESSION_SERVICE_SET_LLM_STATUS_OUTPUT_SCHEMA", + "SESSION_SERVICE_SET_TTS_STATUS_INPUT_SCHEMA", + "SESSION_SERVICE_SET_TTS_STATUS_OUTPUT_SCHEMA", + "SYSTEM_EVENT_REACT_INPUT_SCHEMA", + "SYSTEM_EVENT_REACT_OUTPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_GET_INPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_GET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_SET_INPUT_SCHEMA", + "SYSTEM_EVENT_HANDLER_WHITELIST_SET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_LLM_GET_STATE_INPUT_SCHEMA", + "SYSTEM_EVENT_LLM_GET_STATE_OUTPUT_SCHEMA", + "SYSTEM_EVENT_LLM_REQUEST_INPUT_SCHEMA", + "SYSTEM_EVENT_LLM_REQUEST_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_CLEAR_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_CLEAR_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_GET_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_GET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_SET_INPUT_SCHEMA", + "SYSTEM_EVENT_RESULT_SET_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CHUNK_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CHUNK_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CLOSE_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_CLOSE_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_STREAMING_OUTPUT_SCHEMA", + "SYSTEM_EVENT_SEND_TYPING_INPUT_SCHEMA", + "SYSTEM_EVENT_SEND_TYPING_OUTPUT_SCHEMA", + "SYSTEM_FILE_HANDLE_INPUT_SCHEMA", + "SYSTEM_FILE_HANDLE_OUTPUT_SCHEMA", + "SYSTEM_FILE_REGISTER_INPUT_SCHEMA", + "SYSTEM_FILE_REGISTER_OUTPUT_SCHEMA", +] diff --git a/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py new file mode 100644 index 0000000000..fac0e9a5c5 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/protocol/descriptors.py @@ -0,0 +1,409 @@ +"""v4 协议描述符模型。 + +`protocol` 是 v4 新引入的协议层抽象,不对应旧树(圣诞树)中的一个同名目录。这里 +定义的是跨进程握手和调度时使用的声明式元数据,而不是运行时的具体处理器/ +能力实现。 +""" + 
+from __future__ import annotations + +from typing import Annotated, Any, Literal + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from . import _builtin_schemas +from ._builtin_schemas import * # noqa: F403 + +JSONSchema = _builtin_schemas.JSONSchema +RESERVED_CAPABILITY_NAMESPACES = ("handler", "system", "internal") +RESERVED_CAPABILITY_PREFIXES = tuple( + f"{namespace}." for namespace in RESERVED_CAPABILITY_NAMESPACES +) +BUILTIN_CAPABILITY_SCHEMAS = _builtin_schemas.BUILTIN_CAPABILITY_SCHEMAS +_BUILTIN_SCHEMA_EXPORTS = frozenset(_builtin_schemas.__all__) + + +def __getattr__(name: str) -> Any: + if name in _BUILTIN_SCHEMA_EXPORTS: + return getattr(_builtin_schemas, name) + raise AttributeError(name) + + +def __dir__() -> list[str]: + return sorted(set(globals()) | _BUILTIN_SCHEMA_EXPORTS) + + +class _DescriptorBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class Permissions(_DescriptorBase): + """权限配置,控制处理器的访问权限。 + + Attributes: + require_admin: 是否需要管理员权限 + required_role: 处理器要求的最小角色,v1 支持 member/admin + level: 权限等级,数值越高权限越大 + """ + + require_admin: bool = False + required_role: Literal["member", "admin"] | None = None + level: int = 0 + + @model_validator(mode="after") + def normalize_required_role(self) -> Permissions: + if self.require_admin: + if self.required_role not in {None, "admin"}: + raise ValueError( + "permissions.require_admin=True conflicts with required_role=" + f"{self.required_role!r}" + ) + self.required_role = "admin" + return self + if self.required_role == "admin": + self.require_admin = True + return self + + +class SessionRef(_DescriptorBase): + """结构化会话目标。 + + v4 运行时内部仍然保留 legacy `session` 字符串作为最低兼容层, + 但对外模型允许同时携带平台与原始寻址信息,避免平台发送接口长期 + 只依赖一个不透明字符串。 + """ + + conversation_id: str = Field( + validation_alias=AliasChoices("conversation_id", "session"), + ) + platform: str | None = None + raw: dict[str, Any] | None = None + + @property + def session(self) -> str: + return 
self.conversation_id + + def to_payload(self) -> dict[str, Any]: + return self.model_dump(exclude_none=True) + + +class CommandTrigger(_DescriptorBase): + """命令触发器,响应特定命令。 + + Attributes: + type: 触发器类型,固定为 "command" + command: 命令名称(不含前缀,如 "help") + aliases: 命令别名列表 + description: 命令描述,用于帮助文档 + platforms: 允许的平台列表,为空表示所有平台 + message_types: 限定的消息类型列表,为空表示不限 + """ + + type: Literal["command"] = "command" + command: str + aliases: list[str] = Field(default_factory=list) + description: str | None = None + platforms: list[str] = Field(default_factory=list) + message_types: list[str] = Field(default_factory=list) + + +class MessageTrigger(_DescriptorBase): + """消息触发器,描述消息类处理器的订阅条件。 + + Attributes: + type: 触发器类型,固定为 "message" + regex: 正则表达式模式,匹配消息文本 + keywords: 关键词列表,消息包含任一关键词即触发 + platforms: 目标平台列表,为空表示所有平台 + message_types: 限定的消息类型列表,为空表示不限 + + Note: + `regex` 和 `keywords` 可以同时为空,此时表示 "任意消息均可触发", + 仅由平台过滤或上层运行时进一步筛选。 + """ + + type: Literal["message"] = "message" + regex: str | None = None + keywords: list[str] = Field(default_factory=list) + platforms: list[str] = Field(default_factory=list) + message_types: list[str] = Field(default_factory=list) + + +class EventTrigger(_DescriptorBase): + """事件触发器,响应特定类型的事件。 + + Attributes: + type: 触发器类型,固定为 "event" + event_type: 事件类型,字符串形式(如 "message"、"notice") + """ + + type: Literal["event"] = "event" + event_type: str + + +class ScheduleTrigger(_DescriptorBase): + """定时触发器,按 cron 表达式或固定间隔执行。 + + Attributes: + type: 触发器类型,固定为 "schedule" + cron: cron 表达式(如 "0 9 * * *" 表示每天 9 点) + interval_seconds: 执行间隔(秒) + + Note: + cron 和 interval_seconds 必须且只能有一个非空。 + """ + + type: Literal["schedule"] = "schedule" + cron: str | None = Field( + default=None, + validation_alias=AliasChoices("cron", "schedule"), + ) + interval_seconds: int | None = None + + @property + def schedule(self) -> str | None: + return self.cron + + @model_validator(mode="after") + def validate_schedule(self) -> ScheduleTrigger: + has_cron = self.cron is not None + 
has_interval = self.interval_seconds is not None + if has_cron == has_interval: + raise ValueError("cron 和 interval_seconds 必须且只能有一个非 null") + return self + + +class PlatformFilterSpec(_DescriptorBase): + kind: Literal["platform"] = "platform" + platforms: list[str] = Field(default_factory=list) + + +class MessageTypeFilterSpec(_DescriptorBase): + kind: Literal["message_type"] = "message_type" + message_types: list[str] = Field(default_factory=list) + + +class LocalFilterRefSpec(_DescriptorBase): + kind: Literal["local"] = "local" + filter_id: str + args: dict[str, Any] = Field(default_factory=dict) + + +class CompositeFilterSpec(_DescriptorBase): + kind: Literal["and", "or"] + children: list[FilterSpec] = Field(default_factory=list) + + +FilterSpec = Annotated[ + PlatformFilterSpec + | MessageTypeFilterSpec + | LocalFilterRefSpec + | CompositeFilterSpec, + Field(discriminator="kind"), +] + + +class ParamSpec(_DescriptorBase): + name: str + type: Literal["str", "int", "float", "bool", "optional", "greedy_str"] + required: bool = True + inner_type: Literal["str", "int", "float", "bool"] | None = None + + +class CommandRouteSpec(_DescriptorBase): + group_path: list[str] = Field(default_factory=list) + display_command: str + group_help: str | None = None + + +CompositeFilterSpec.model_rebuild() + + +Trigger = Annotated[ + CommandTrigger | MessageTrigger | EventTrigger | ScheduleTrigger, + Field(discriminator="type"), +] +"""触发器联合类型,使用 type 字段作为判别器自动解析具体类型。""" + + +class HandlerDescriptor(_DescriptorBase): + """处理器描述符,描述一个事件处理函数的元信息。 + + Attributes: + id: 处理器唯一标识,通常是 "模块.函数名" 格式 + trigger: 触发器配置,决定何时执行该处理器 + kind: 处理器类别,默认普通 handler + contract: 运行时契约名,描述入参/执行语义 + priority: 优先级,数值越大越先执行 + permissions: 权限配置,控制谁可以触发该处理器 + + 使用场景: + HandlerDescriptor 通常由 `@on_command`、`@on_message` 等装饰器自动创建, + 插件作者一般不需要手动实例化。但了解其结构有助于理解插件注册机制。 + + 触发器类型: + - CommandTrigger: 响应特定命令,如 `/help` + - MessageTrigger: 响应消息(正则/关键词匹配) + - EventTrigger: 响应特定事件类型 + - ScheduleTrigger: 定时触发 + + 示例: + 
class HandlerDescriptor(_DescriptorBase):
    """Handler descriptor: metadata for one event-handling function.

    Attributes:
        id: unique handler id, usually "module.function_name".
        trigger: trigger configuration deciding when the handler runs.
        kind: handler category; defaults to a plain handler.
        contract: runtime contract name describing input/execution semantics.
        priority: higher values run earlier.
        permissions: permission settings controlling who may trigger it.

    Usage:
        HandlerDescriptor is normally produced by decorators such as
        ``@on_command`` / ``@on_message``; plugin authors rarely instantiate it
        directly, but knowing its shape helps understand registration.

    Trigger kinds:
        - CommandTrigger: reacts to a command such as ``/help``.
        - MessageTrigger: reacts to messages (regex/keyword match).
        - EventTrigger: reacts to a specific event type.
        - ScheduleTrigger: fires on a schedule.

    See Also:
        Trigger: the trigger union type.
        Permissions: the permission settings model.
    """

    id: str
    trigger: Trigger
    kind: Literal["handler", "hook", "tool", "session"] = "handler"
    contract: str | None = None
    description: str | None = None
    priority: int = 0
    permissions: Permissions = Field(default_factory=Permissions)
    filters: list[FilterSpec] = Field(default_factory=list)
    param_specs: list[ParamSpec] = Field(default_factory=list)
    command_route: CommandRouteSpec | None = None

    @model_validator(mode="after")
    def validate_contract_defaults(self) -> HandlerDescriptor:
        # Default the contract from the trigger kind when unspecified.
        if self.contract is None:
            self.contract = (
                "schedule"
                if isinstance(self.trigger, ScheduleTrigger)
                else "message_event"
            )
        return self


class CapabilityDescriptor(_DescriptorBase):
    """Capability descriptor: one remotely invokable capability.

    Naming convention:
        - "namespace.action" form, e.g. "llm.chat", "db.set".
        - Multi-level namespaces are allowed, e.g. "llm_tool.manager.activate".
        - Builtin capabilities start with "internal.".

    Reserved namespaces (not usable by plugins): ``handler.``, ``system.``,
    ``internal.``.

    Attributes:
        name: capability name in "namespace.action" form.
        description: human-readable description for docs and debugging.
        input_schema: JSON Schema validating the input payload.
        output_schema: JSON Schema validating the output payload.
        supports_stream: whether streaming responses are supported.
        cancelable: whether the call can be cancelled.

    Usage:
        Declare one of these when your plugin *exposes* a capability that other
        plugins can call. To *call* a builtin capability (e.g. ``llm.chat``,
        ``db.set``) you do not construct a descriptor; invoke it through
        ``Context.invoke()`` and consult ``BUILTIN_CAPABILITY_SCHEMAS`` for the
        expected payload shape.

    See Also:
        BUILTIN_CAPABILITY_SCHEMAS: schema registry for builtin capabilities.
    """

    name: str
    description: str
    input_schema: JSONSchema | None = None
    output_schema: JSONSchema | None = None
    supports_stream: bool = False
    cancelable: bool = False

    @model_validator(mode="after")
    def validate_builtin_schema_governance(self) -> CapabilityDescriptor:
        # Builtin capabilities must carry schemas identical to the registry so
        # the wire contract cannot silently drift.
        registry_schema = BUILTIN_CAPABILITY_SCHEMAS.get(self.name)
        if registry_schema is None:
            return self
        if self.input_schema is None or self.output_schema is None:
            raise ValueError(
                f"内建 capability {self.name} 必须同时提供 input_schema 和 output_schema"
            )
        matches = (
            self.input_schema == registry_schema["input"]
            and self.output_schema == registry_schema["output"]
        )
        if not matches:
            raise ValueError(
                f"内建 capability {self.name} 的 schema 必须与协议注册表保持一致"
            )
        return self


__all__ = [
    "Trigger",
    "BUILTIN_CAPABILITY_SCHEMAS",
    "CapabilityDescriptor",
    "CommandRouteSpec",
    "CommandTrigger",
    "CompositeFilterSpec",
    "EventTrigger",
    "FilterSpec",
    "HandlerDescriptor",
    "JSONSchema",
    "LocalFilterRefSpec",
    "MessageTrigger",
    "MessageTypeFilterSpec",
    "ParamSpec",
    "Permissions",
    "PlatformFilterSpec",
    "RESERVED_CAPABILITY_NAMESPACES",
    "RESERVED_CAPABILITY_PREFIXES",
    "ScheduleTrigger",
    "SessionRef",
]
__all__ += list(_BUILTIN_SCHEMA_EXPORTS)
或核心标识 + role: 节点角色,"plugin" 或 "core" + version: 节点版本号,可选 + """ + + name: str + role: Literal["plugin", "core"] + version: str | None = None + + +class InitializeMessage(_MessageBase): + """初始化消息,用于建立连接时交换信息。 + + Attributes: + type: 消息类型,固定为 "initialize" + id: 消息 ID,用于关联响应 + protocol_version: 协议版本号 + peer: 发送方节点信息 + handlers: 注册的处理器描述符列表 + provided_capabilities: 发送方对外暴露的能力描述符列表 + metadata: 扩展元数据,可存储插件配置等信息 + """ + + type: Literal["initialize"] = "initialize" + id: str + protocol_version: str + peer: PeerInfo + handlers: list[HandlerDescriptor] = Field(default_factory=list) + provided_capabilities: list[CapabilityDescriptor] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class InitializeOutput(_MessageBase): + """初始化输出,作为 InitializeMessage 的响应数据。 + + Attributes: + peer: 接收方(核心)节点信息 + protocol_version: 协商后的协议版本;未协商时可为空 + capabilities: 核心提供的能力描述符列表 + metadata: 扩展元数据 + """ + + peer: PeerInfo + protocol_version: str | None = None + capabilities: list[CapabilityDescriptor] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ResultMessage(_MessageBase): + """结果消息,用于返回能力调用的结果。 + + Attributes: + type: 消息类型,固定为 "result" + id: 关联的请求 ID + kind: 结果类型,可选,如 "initialize_result" 标识初始化结果 + success: 是否成功 + output: 成功时的输出数据 + error: 失败时的错误信息 + """ + + type: Literal["result"] = "result" + id: str + kind: str | None = None + success: bool + output: dict[str, Any] = Field(default_factory=dict) + error: ErrorPayload | None = None + + @model_validator(mode="after") + def validate_result_state(self) -> ResultMessage: + """约束 success / output / error 的组合状态。""" + if self.success: + if self.error is not None: + raise ValueError("success=true 时 error 必须为空") + return self + if self.error is None: + raise ValueError("success=false 时必须提供 error") + if self.output: + raise ValueError("success=false 时 output 必须为空") + return self + + +class InvokeMessage(_MessageBase): + """调用消息,用于请求执行远程能力。 + + Attributes: + 
type: 消息类型,固定为 "invoke" + id: 请求 ID,用于关联响应 + capability: 目标能力名称,格式为 "namespace.action" + input: 调用输入参数 + stream: 是否期望流式响应,若为 True 将收到 EventMessage 序列 + caller_plugin_id: 运行时透传的调用方插件 ID,不属于业务 payload + """ + + type: Literal["invoke"] = "invoke" + id: str + capability: str + input: dict[str, Any] = Field(default_factory=dict) + stream: bool = False + caller_plugin_id: str | None = None + + +class EventMessage(_MessageBase): + """事件消息,用于流式调用的状态通知。 + + 流式调用生命周期: + 1. started: 调用开始,所有字段为空 + 2. delta: 数据增量更新,包含 data 字段 + 3. completed: 调用完成,包含 output 字段 + 4. failed: 调用失败,包含 error 字段 + + Attributes: + type: 消息类型,固定为 "event" + id: 关联的请求 ID + phase: 事件阶段,started/delta/completed/failed + data: 增量数据,仅 delta 阶段有效 + output: 最终输出,仅 completed 阶段有效 + error: 错误信息,仅 failed 阶段有效 + """ + + type: Literal["event"] = "event" + id: str + phase: Literal["started", "delta", "completed", "failed"] + data: dict[str, Any] = Field(default_factory=dict) + output: dict[str, Any] = Field(default_factory=dict) + error: ErrorPayload | None = None + + @model_validator(mode="after") + def validate_phase_constraints(self) -> EventMessage: + """验证各 phase 的字段约束。 + + - started: 所有字段必须为空 + - delta: 必须有 data,output/error 必须为空 + - completed: 必须有 output,data/error 必须为空 + - failed: 必须有 error,data/output 必须为空 + """ + phase = self.phase + if phase == "started": + if self.data or self.output or self.error: + raise ValueError("started phase 必须所有字段为空") + elif phase == "delta": + if not self.data: + raise ValueError("delta phase 需要 data") + if self.output or self.error: + raise ValueError("delta phase 的 output/error 必须为空") + elif phase == "completed": + if not self.output: + raise ValueError("completed phase 需要 output") + if self.data or self.error: + raise ValueError("completed phase 的 data/error 必须为空") + elif phase == "failed": + if self.error is None: + raise ValueError("failed phase 需要 error") + if self.data or self.output: + raise ValueError("failed phase 的 data/output 必须为空") + return self + + +class 
class CancelMessage(_MessageBase):
    """Cancel message: request cancellation of an in-flight invocation.

    Attributes:
        type: fixed discriminator value "cancel".
        id: the request id to cancel.
        reason: cancellation reason, defaults to "user_cancelled".
    """

    type: Literal["cancel"] = "cancel"
    id: str
    reason: str = "user_cancelled"


ProtocolMessage = (
    InitializeMessage | ResultMessage | InvokeMessage | EventMessage | CancelMessage
)
"""Union of every valid protocol message type."""

# Maps the wire "type" discriminator to its pydantic model.
_PROTOCOL_MESSAGE_MODELS = {
    "initialize": InitializeMessage,
    "result": ResultMessage,
    "invoke": InvokeMessage,
    "event": EventMessage,
    "cancel": CancelMessage,
}


def parse_message(
    payload: ProtocolMessage | str | bytes | dict[str, Any],
) -> ProtocolMessage:
    """Parse a raw payload into a ProtocolMessage.

    Accepts an already-parsed model, a JSON string, bytes, or a dict, and
    dispatches on the "type" field to validate the concrete message model.

    Args:
        payload: the raw message payload.

    Returns:
        The parsed protocol message.

    Raises:
        ValueError: when the payload is not a JSON object or the message type
            is unknown.

    Example:
        >>> msg = parse_message('{"type": "invoke", "id": "1", "capability": "test"}')
        >>> isinstance(msg, InvokeMessage)
        True
    """
    if isinstance(payload, tuple(_PROTOCOL_MESSAGE_MODELS.values())):
        return payload
    if isinstance(payload, bytes):
        payload = payload.decode("utf-8")
    if isinstance(payload, str):
        payload = json.loads(payload)
    if not isinstance(payload, dict):
        raise ValueError("协议消息必须是 JSON object")
    message_type = payload.get("type")
    model = _PROTOCOL_MESSAGE_MODELS.get(str(message_type))
    if model is None:
        raise ValueError(f"未知消息类型:{message_type}")
    return model.model_validate(payload)


__all__ = [
    "CancelMessage",
    "ErrorPayload",
    "EventMessage",
    "InitializeMessage",
    "InitializeOutput",
    "InvokeMessage",
    "PeerInfo",
    "ProtocolMessage",
    "ResultMessage",
    "parse_message",
]
@@ -0,0 +1,63 @@ +"""AstrBot SDK runtime public exports. + +本模块提供运行时核心组件的公共导出,包括: +- CapabilityRouter: 能力路由器,处理能力调用的分发和路由 +- HandlerDispatcher: 事件处理器分发器,将事件分发到注册的 handler +- Peer: 与 AstrBot 核心通信的对等端抽象 +- Transport 系列: 进程间通信传输层实现(stdio/websocket) + +延迟加载策略: +为避免导入时触发 websocket/aiohttp 等重型依赖,采用 __getattr__ 实现按需加载。 +这样轻量级导入(如仅使用类型提示)不会产生不必要的依赖开销。 +""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .capability_router import CapabilityRouter, StreamExecution + from .handler_dispatcher import HandlerDispatcher + from .peer import Peer + from .transport import ( + MessageHandler, + StdioTransport, + Transport, + WebSocketClientTransport, + WebSocketServerTransport, + ) + +__all__ = [ + "CapabilityRouter", + "HandlerDispatcher", + "MessageHandler", + "Peer", + "StdioTransport", + "StreamExecution", + "Transport", + "WebSocketClientTransport", + "WebSocketServerTransport", +] + + +def __getattr__(name: str) -> Any: + if name in {"CapabilityRouter", "StreamExecution"}: + module = import_module(".capability_router", __name__) + return getattr(module, name) + if name == "HandlerDispatcher": + module = import_module(".handler_dispatcher", __name__) + return getattr(module, name) + if name == "Peer": + module = import_module(".peer", __name__) + return getattr(module, name) + if name in { + "MessageHandler", + "StdioTransport", + "Transport", + "WebSocketClientTransport", + "WebSocketServerTransport", + }: + module = import_module(".transport", __name__) + return getattr(module, name) + raise AttributeError(name) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py new file mode 100644 index 0000000000..b0af66d417 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/__init__.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from 
from __future__ import annotations

from .bridge_base import CapabilityRouterBridgeBase
from .capabilities import (
    ConversationCapabilityMixin,
    DBCapabilityMixin,
    HttpCapabilityMixin,
    KnowledgeBaseCapabilityMixin,
    LLMCapabilityMixin,
    McpCapabilityMixin,
    MemoryCapabilityMixin,
    MessageHistoryCapabilityMixin,
    MetadataCapabilityMixin,
    PermissionCapabilityMixin,
    PersonaCapabilityMixin,
    PlatformCapabilityMixin,
    ProviderCapabilityMixin,
    SessionCapabilityMixin,
    SkillCapabilityMixin,
    SystemCapabilityMixin,
)


class BuiltinCapabilityRouterMixin(
    # NOTE: base order is part of the MRO contract — the bridge base must stay
    # last so every capability mixin can override/extend its helpers.
    LLMCapabilityMixin,
    MemoryCapabilityMixin,
    DBCapabilityMixin,
    PlatformCapabilityMixin,
    HttpCapabilityMixin,
    MetadataCapabilityMixin,
    PermissionCapabilityMixin,
    ProviderCapabilityMixin,
    McpCapabilityMixin,
    SessionCapabilityMixin,
    SkillCapabilityMixin,
    PersonaCapabilityMixin,
    ConversationCapabilityMixin,
    MessageHistoryCapabilityMixin,
    KnowledgeBaseCapabilityMixin,
    SystemCapabilityMixin,
    CapabilityRouterBridgeBase,
):
    """Aggregates every builtin capability mixin into one registration entry."""

    def _register_builtin_capabilities(self) -> None:
        """Register all builtin capabilities on the router.

        The call order mirrors the mixin list above; a few registrations
        (agent tools, provider/platform managers) come from methods declared
        on mixins whose implementations live outside this module.
        """
        self._register_llm_capabilities()
        self._register_memory_capabilities()
        self._register_db_capabilities()
        self._register_platform_capabilities()
        self._register_http_capabilities()
        self._register_metadata_capabilities()
        self._register_permission_capabilities()
        self._register_provider_capabilities()
        self._register_agent_tool_capabilities()
        self._register_mcp_capabilities()
        self._register_session_capabilities()
        self._register_skill_capabilities()
        self._register_persona_capabilities()
        self._register_conversation_capabilities()
        self._register_message_history_capabilities()
        self._register_kb_capabilities()
        self._register_provider_manager_capabilities()
        self._register_platform_manager_capabilities()
        self._register_system_capabilities()


__all__ = ["BuiltinCapabilityRouterMixin"]
b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py new file mode 100644 index 0000000000..6d31ba6f2c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/_host.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Any + +from ...protocol.descriptors import CapabilityDescriptor + + +class CapabilityRouterHost: + memory_store: dict[str, dict[str, Any]] + _memory_backends: dict[str, Any] + _memory_index: dict[str, dict[str, Any]] + _memory_dirty_keys: set[str] + _memory_expires_at: dict[str, datetime | None] + db_store: dict[str, Any] + sent_messages: list[dict[str, Any]] + event_actions: list[dict[str, Any]] + http_api_store: list[dict[str, Any]] + _event_streams: dict[str, dict[str, Any]] + _plugins: dict[str, Any] + _request_overlays: dict[str, dict[str, Any]] + _provider_catalog: dict[str, list[dict[str, Any]]] + _provider_configs: dict[str, dict[str, Any]] + _active_provider_ids: dict[str, str | None] + _provider_change_subscriptions: dict[str, asyncio.Queue[dict[str, Any]]] + _system_data_root: Path + _session_waiters: dict[str, set[str]] + _session_plugin_configs: dict[str, dict[str, Any]] + _session_service_configs: dict[str, dict[str, Any]] + _db_watch_subscriptions: dict[str, tuple[str | None, asyncio.Queue[dict[str, Any]]]] + _dynamic_command_routes: dict[str, list[dict[str, Any]]] + _file_token_store: dict[str, str] + _platform_instances: list[dict[str, Any]] + _persona_store: dict[str, dict[str, Any]] + _conversation_store: dict[str, dict[str, Any]] + _session_current_conversation_ids: dict[str, str] + _kb_store: dict[str, dict[str, Any]] + _kb_document_store: dict[str, dict[str, dict[str, Any]]] + _kb_document_content_store: dict[str, str] + + def register( + self, + descriptor: CapabilityDescriptor, + *, + call_handler=None, + stream_handler=None, + finalize=None, + exposed: bool = True, + ) -> None: + 
raise NotImplementedError + + def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None: + raise NotImplementedError + + @staticmethod + def _require_caller_plugin_id(capability_name: str) -> str: + raise NotImplementedError + + @staticmethod + def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str: + raise NotImplementedError + + def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path: + raise NotImplementedError + + def register_dynamic_command_route( + self, + *, + plugin_id: str, + command_name: str, + handler_full_name: str, + desc: str = "", + priority: int = 0, + use_regex: bool = False, + ) -> None: + raise NotImplementedError + + def get_platform_instances(self) -> list[dict[str, Any]]: + raise NotImplementedError + + @staticmethod + def _normalize_platform_name(value: Any) -> str: + raise NotImplementedError + + @classmethod + def _normalized_platform_names(cls, values: Any) -> set[str]: + raise NotImplementedError + + def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + raise NotImplementedError + + def _platform_name_from_id(self, platform_id: str) -> str: + raise NotImplementedError + + def _session_platform_name(self, session: str) -> str: + raise NotImplementedError + + def _require_platform_support_for_session( + self, + capability_name: str, + session: str, + ) -> str: + raise NotImplementedError + + def _register_agent_tool_capabilities(self) -> None: + raise NotImplementedError + + def _provider_entry( + self, + payload: dict[str, Any], + capability_name: str, + expected_kind: str | None = None, + ) -> dict[str, Any]: + raise NotImplementedError + + async def _provider_embedding_get_embedding( + self, request_id: str, payload: dict[str, Any], token + ) -> dict[str, Any]: + raise NotImplementedError + + async def _provider_embedding_get_embeddings( + self, request_id: str, payload: dict[str, Any], token + ) -> dict[str, Any]: + raise NotImplementedError diff --git 
a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py new file mode 100644 index 0000000000..f1e36516fe --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/bridge_base.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import copy +import hashlib +import math +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from ..._internal.plugin_ids import resolve_plugin_data_dir, validate_plugin_id +from ...errors import AstrBotError +from ...protocol.descriptors import ( + BUILTIN_CAPABILITY_SCHEMAS, + CapabilityDescriptor, + SessionRef, +) +from ._host import CapabilityRouterHost + + +def _clone_target_payload(value: Any) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + return {str(key): item for key, item in value.items()} + + +def _clone_chain_payload(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [ + {str(key): item for key, item in chunk.items()} + for chunk in value + if isinstance(chunk, dict) + ] + + +_MOCK_EMBEDDING_DIM = 24 + + +def _embedding_terms(text: str) -> list[str]: + """Build stable tokens for the mock embedding implementation.""" + normalized = re.sub(r"\s+", " ", str(text).strip().casefold()) + compact = normalized.replace(" ", "") + if not normalized: + return [] + + terms = [word for word in re.findall(r"\w+", normalized, flags=re.UNICODE) if word] + if compact: + if len(compact) == 1: + terms.append(compact) + else: + terms.extend( + compact[index : index + 2] for index in range(len(compact) - 1) + ) + terms.append(compact) + return terms or [normalized] + + +def _mock_embedding_vector(text: str, *, provider_id: str) -> list[float]: + """Generate a deterministic normalized mock embedding vector.""" + values = [0.0] * _MOCK_EMBEDDING_DIM + for term in _embedding_terms(text): + 
class CapabilityRouterBridgeBase(CapabilityRouterHost):
    """Shared helpers bridging the capability mixins onto the router host."""

    _memory_backends: dict[str, Any]

    @staticmethod
    def _normalize_platform_name(value: Any) -> str:
        """Canonicalize a platform name: stripped, lowercase, "" for falsy."""
        return str(value or "").strip().lower()

    @classmethod
    def _normalized_platform_names(cls, values: Any) -> set[str]:
        """Normalize a list of platform names, dropping empties; {} otherwise."""
        if not isinstance(values, list):
            return set()
        normalized = (cls._normalize_platform_name(item) for item in values)
        return {name for name in normalized if name}

    @staticmethod
    def _validated_plugin_id(plugin_id: str, *, capability_name: str) -> str:
        """Validate *plugin_id*, rewrapping failures as invalid-input errors."""
        try:
            return validate_plugin_id(plugin_id)
        except ValueError as exc:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a safe plugin_id: {exc}"
            ) from exc

    def _plugin_data_dir(self, plugin_id: str, *, capability_name: str) -> Path:
        """Resolve the plugin's data directory under the system data root."""
        try:
            return resolve_plugin_data_dir(self._system_data_root, plugin_id)
        except ValueError as exc:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a safe plugin_id: {exc}"
            ) from exc

    def _builtin_descriptor(
        self,
        name: str,
        description: str,
        *,
        supports_stream: bool = False,
        cancelable: bool = False,
    ) -> CapabilityDescriptor:
        """Build a descriptor whose schemas come from the builtin registry."""
        schema = BUILTIN_CAPABILITY_SCHEMAS[name]
        # Deep-copy so later mutation of a descriptor cannot corrupt the registry.
        return CapabilityDescriptor(
            name=name,
            description=description,
            input_schema=copy.deepcopy(schema["input"]),
            output_schema=copy.deepcopy(schema["output"]),
            supports_stream=supports_stream,
            cancelable=cancelable,
        )

    def _resolve_target(
        self, payload: dict[str, Any]
    ) -> tuple[str, dict[str, Any] | None]:
        """Extract (session, structured-target) from a capability payload."""
        raw_target = payload.get("target")
        if isinstance(raw_target, dict):
            ref = SessionRef.model_validate(raw_target)
            return ref.session, ref.to_payload()
        # Fall back to the legacy opaque session string.
        return str(payload.get("session", "")), None

    @staticmethod
    def _is_group_session(session: str) -> bool:
        """Heuristic: the session string embeds a group marker segment."""
        lowered = str(session).lower()
        return ":group:" in lowered or ":groupmessage:" in lowered

    @staticmethod
    def _mock_group_payload(session: str) -> dict[str, Any] | None:
        """Fabricate deterministic group info for a group session; else None."""
        if not CapabilityRouterBridgeBase._is_group_session(session):
            return None
        members = [
            {
                "user_id": f"{session}:member-1",
                "nickname": "Member 1",
                "role": "member",
            },
            {
                "user_id": f"{session}:member-2",
                "nickname": "Member 2",
                "role": "admin",
            },
        ]
        tail = session.rsplit(":", maxsplit=1)[-1]
        return {
            "group_id": tail,
            "group_name": f"Mock Group {tail}",
            "group_avatar": "",
            "group_owner": members[0]["user_id"],
            "group_admins": [members[1]["user_id"]],
            "members": members,
        }

    def _session_plugin_config(self, session: str) -> dict[str, Any]:
        """Return a defensive copy of the session's plugin config."""
        stored = self._session_plugin_configs.get(str(session), {})
        return dict(stored) if isinstance(stored, dict) else {}

    def _session_service_config(self, session: str) -> dict[str, Any]:
        """Return a defensive copy of the session's service config."""
        stored = self._session_service_configs.get(str(session), {})
        return dict(stored) if isinstance(stored, dict) else {}

    @staticmethod
    def _now_iso() -> str:
        """Current UTC time in ISO-8601 form."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _session_platform_id(session: str) -> str:
        """The leading ':'-separated segment of a session, or "unknown"."""
        head = str(session).split(":", maxsplit=1)
        if head and head[0].strip():
            return head[0].strip()
        return "unknown"

    def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool:
        """Whether the plugin declares support for *platform_name*.

        Permissive by design: missing plugin, missing metadata, or an empty
        support list all count as "supported".
        """
        wanted = self._normalize_platform_name(platform_name)
        if not wanted:
            return True
        plugin = self._plugins.get(str(plugin_id))
        if plugin is None:
            return True
        metadata = getattr(plugin, "metadata", None)
        if not isinstance(metadata, dict):
            return True
        declared = self._normalized_platform_names(metadata.get("support_platforms"))
        if not declared:
            return True
        return wanted in declared

    def _platform_name_from_id(self, platform_id: str) -> str:
        """Look up the platform type for a platform instance id; "" on miss."""
        wanted = str(platform_id).strip()
        if not wanted:
            return ""
        for instance in self.get_platform_instances():
            if not isinstance(instance, dict):
                continue
            if str(instance.get("id", "")).strip() != wanted:
                continue
            return self._normalize_platform_name(instance.get("type"))
        return ""

    def _session_platform_name(self, session: str) -> str:
        """Platform type backing the session's platform-id prefix."""
        return self._platform_name_from_id(self._session_platform_id(session))

    def _require_platform_support_for_session(
        self,
        capability_name: str,
        session: str,
    ) -> str:
        """Return the caller plugin id, enforcing platform support for *session*."""
        plugin_id = self._require_caller_plugin_id(capability_name)
        platform_name = self._session_platform_name(session)
        if not platform_name or self._plugin_supports_platform(
            plugin_id, platform_name
        ):
            return plugin_id
        raise AstrBotError.invalid_input(
            f"{capability_name} does not support platform '{platform_name}' for plugin '{plugin_id}'"
        )

    @staticmethod
    def _normalize_history_payload(value: Any) -> list[dict[str, Any]]:
        """Keep only dict entries, shallow-copied; [] for non-list input."""
        if not isinstance(value, list):
            return []
        return [dict(item) for item in value if isinstance(item, dict)]

    @staticmethod
    def _normalize_persona_dialogs_payload(value: Any) -> list[str]:
        """Keep only str entries; [] for non-list input."""
        if not isinstance(value, list):
            return []
        return [str(item) for item in value if isinstance(item, str)]

    @staticmethod
    def _optional_int(value: Any) -> int | None:
        """Best-effort int conversion; None for None or unconvertible values."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None
b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/__init__.py @@ -0,0 +1,35 @@ +from .conversation import ConversationCapabilityMixin +from .db import DBCapabilityMixin +from .http import HttpCapabilityMixin +from .kb import KnowledgeBaseCapabilityMixin +from .llm import LLMCapabilityMixin +from .mcp import McpCapabilityMixin +from .memory import MemoryCapabilityMixin +from .message_history import MessageHistoryCapabilityMixin +from .metadata import MetadataCapabilityMixin +from .permission import PermissionCapabilityMixin +from .persona import PersonaCapabilityMixin +from .platform import PlatformCapabilityMixin +from .provider import ProviderCapabilityMixin +from .session import SessionCapabilityMixin +from .skill import SkillCapabilityMixin +from .system import SystemCapabilityMixin + +__all__ = [ + "ConversationCapabilityMixin", + "DBCapabilityMixin", + "HttpCapabilityMixin", + "KnowledgeBaseCapabilityMixin", + "LLMCapabilityMixin", + "McpCapabilityMixin", + "MemoryCapabilityMixin", + "MessageHistoryCapabilityMixin", + "MetadataCapabilityMixin", + "PermissionCapabilityMixin", + "PersonaCapabilityMixin", + "PlatformCapabilityMixin", + "ProviderCapabilityMixin", + "SessionCapabilityMixin", + "SkillCapabilityMixin", + "SystemCapabilityMixin", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py new file mode 100644 index 0000000000..a250f43e5a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/conversation.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class ConversationCapabilityMixin(CapabilityRouterBridgeBase): + async def _conversation_new( + self, _request_id: str, payload: dict[str, Any], 
_token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.new requires session") + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.new requires conversation object" + ) + conversation_id = uuid.uuid4().hex + now = self._now_iso() + record = { + "conversation_id": conversation_id, + "session": session, + "platform_id": ( + str(raw_conversation.get("platform_id")) + if raw_conversation.get("platform_id") is not None + else self._session_platform_id(session) + ), + "history": self._normalize_history_payload(raw_conversation.get("history")), + "title": ( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + "persona_id": ( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + "created_at": now, + "updated_at": now, + "token_usage": None, + } + self._conversation_store[conversation_id] = record + self._session_current_conversation_ids[session] = conversation_id + return {"conversation_id": conversation_id} + + async def _conversation_switch( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + record = self._conversation_store.get(conversation_id) + if record is None or str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.switch requires a conversation in the same session" + ) + self._session_current_conversation_ids[session] = conversation_id + return {} + + async def _conversation_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id 
= payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.delete requires a conversation in the same session" + ) + del self._conversation_store[normalized_conversation_id] + current_conversation_id = self._session_current_conversation_ids.get(session) + if current_conversation_id == normalized_conversation_id: + replacement = next( + ( + conversation_id + for conversation_id, item in self._conversation_store.items() + if str(item.get("session", "")) == session + ), + None, + ) + if replacement is None: + self._session_current_conversation_ids.pop(session, None) + else: + self._session_current_conversation_ids[session] = replacement + return {} + + async def _conversation_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + record = self._conversation_store.get(conversation_id) + if record is None and bool(payload.get("create_if_not_exists", False)): + created = await self._conversation_new( + _request_id, + {"session": session, "conversation": {}}, + _token, + ) + record = self._conversation_store.get( + str(created.get("conversation_id", "")).strip() + ) + if record is None: + return {"conversation": None} + if str(record.get("session", "")) != session: + return {"conversation": None} + return {"conversation": dict(record)} + + async def _conversation_get_current( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = 
str(payload.get("session", "")).strip() + conversation_id = self._session_current_conversation_ids.get(session, "") + if not conversation_id and bool(payload.get("create_if_not_exists", False)): + created = await self._conversation_new( + _request_id, + {"session": session, "conversation": {}}, + _token, + ) + conversation_id = str(created.get("conversation_id", "")).strip() + if not conversation_id: + return {"conversation": None} + record = self._conversation_store.get(conversation_id) + if record is None or str(record.get("session", "")) != session: + return {"conversation": None} + return {"conversation": dict(record)} + + async def _conversation_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = payload.get("session") + platform_id = payload.get("platform_id") + conversations = [] + for conversation_id in sorted(self._conversation_store.keys()): + item = self._conversation_store[conversation_id] + if session is not None and str(item.get("session", "")) != str(session): + continue + if platform_id is not None and str(item.get("platform_id", "")) != str( + platform_id + ): + continue + conversations.append(dict(item)) + return {"conversations": conversations} + + async def _conversation_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.update requires a conversation in the same session" + ) + 
raw_conversation = payload.get("conversation") + if not isinstance(raw_conversation, dict): + raw_conversation = {} + if "history" in raw_conversation: + history = raw_conversation.get("history") + record["history"] = ( + self._normalize_history_payload(history) if history is not None else [] + ) + if "title" in raw_conversation: + title = raw_conversation.get("title") + record["title"] = str(title) if title is not None else None + if "persona_id" in raw_conversation: + persona_id = raw_conversation.get("persona_id") + record["persona_id"] = str(persona_id) if persona_id is not None else None + if "token_usage" in raw_conversation: + token_usage = raw_conversation.get("token_usage") + record["token_usage"] = ( + int(token_usage) if token_usage is not None else None + ) + record["updated_at"] = self._now_iso() + return {} + + async def _conversation_unset_persona( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")).strip() + conversation_id = payload.get("conversation_id") + normalized_conversation_id = ( + str(conversation_id).strip() if conversation_id is not None else "" + ) + if not normalized_conversation_id: + normalized_conversation_id = self._session_current_conversation_ids.get( + session, "" + ) + if not normalized_conversation_id: + return {} + record = self._conversation_store.get(normalized_conversation_id) + if record is None: + return {} + if str(record.get("session", "")) != session: + raise AstrBotError.invalid_input( + "conversation.unset_persona requires a conversation in the same session" + ) + record["persona_id"] = None + record["updated_at"] = self._now_iso() + return {} + + def _register_conversation_capabilities(self) -> None: + self.register( + self._builtin_descriptor("conversation.new", "新建对话"), + call_handler=self._conversation_new, + ) + self.register( + self._builtin_descriptor("conversation.switch", "切换对话"), + call_handler=self._conversation_switch, + ) + 
self.register( + self._builtin_descriptor("conversation.delete", "删除对话"), + call_handler=self._conversation_delete, + ) + self.register( + self._builtin_descriptor("conversation.get", "获取对话"), + call_handler=self._conversation_get, + ) + self.register( + self._builtin_descriptor("conversation.get_current", "获取当前对话"), + call_handler=self._conversation_get_current, + ) + self.register( + self._builtin_descriptor("conversation.list", "列出对话"), + call_handler=self._conversation_list, + ) + self.register( + self._builtin_descriptor("conversation.update", "更新对话"), + call_handler=self._conversation_update, + ) + self.register( + self._builtin_descriptor("conversation.unset_persona", "清空对话人格"), + call_handler=self._conversation_unset_persona, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py new file mode 100644 index 0000000000..f8bdfedf9a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/db.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from ....errors import AstrBotError +from ..._streaming import StreamExecution +from ..bridge_base import CapabilityRouterBridgeBase + + +class DBCapabilityMixin(CapabilityRouterBridgeBase): + def _db_scoped_key(self, plugin_id: str, key: str) -> str: + """将用户提供的 key 加上插件命名空间前缀,防止跨插件越权访问。""" + return f"{plugin_id}:{key}" + + def _db_strip_scope(self, plugin_id: str, scoped_key: str) -> str: + """去掉命名空间前缀,返回插件视角的原始 key。""" + prefix = f"{plugin_id}:" + return ( + scoped_key[len(prefix) :] if scoped_key.startswith(prefix) else scoped_key + ) + + def _db_public_event( + self, plugin_id: str, raw_event: dict[str, Any] + ) -> dict[str, Any]: + """将内部事件转换回插件可见的 key 视图。""" + event = dict(raw_event) + key = event.get("key") + if isinstance(key, str): + event["key"] = 
self._db_strip_scope(plugin_id, key) + return event + + async def _db_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.get") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + return {"value": self.db_store.get(key)} + + async def _db_set( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.set") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + value = payload.get("value") + self.db_store[key] = value + self._emit_db_change(op="set", key=key, value=value) + return {} + + async def _db_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.delete") + key = self._db_scoped_key(plugin_id, str(payload.get("key", ""))) + self.db_store.pop(key, None) + self._emit_db_change(op="delete", key=key, value=None) + return {} + + async def _db_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.list") + ns_prefix = f"{plugin_id}:" + # 只列出属于当前插件命名空间的 key,并去掉命名空间前缀返回给插件 + user_prefix = payload.get("prefix") + all_keys = sorted( + key for key in self.db_store.keys() if key.startswith(ns_prefix) + ) + stripped = [self._db_strip_scope(plugin_id, k) for k in all_keys] + if isinstance(user_prefix, str): + stripped = [k for k in stripped if k.startswith(user_prefix)] + return {"keys": stripped} + + async def _db_get_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.get_many") + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("db.get_many 的 keys 必须是数组") + items = [ + { + "key": str(k), + "value": self.db_store.get(self._db_scoped_key(plugin_id, str(k))), + } + for k in 
keys_payload + ] + return {"items": items} + + async def _db_set_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("db.set_many") + items_payload = payload.get("items") + if not isinstance(items_payload, (list, tuple)): + raise AstrBotError.invalid_input("db.set_many 的 items 必须是数组") + for entry in items_payload: + if not isinstance(entry, dict): + raise AstrBotError.invalid_input( + "db.set_many 的 items 必须是 object 数组" + ) + key = self._db_scoped_key(plugin_id, str(entry.get("key", ""))) + value = entry.get("value") + self.db_store[key] = value + self._emit_db_change(op="set", key=key, value=value) + return {} + + async def _db_watch( + self, request_id: str, payload: dict[str, Any], _token + ) -> StreamExecution: + plugin_id = self._require_caller_plugin_id("db.watch") + prefix = payload.get("prefix") + prefix_value: str | None + if isinstance(prefix, str): + # 将用户传入的前缀也加上命名空间,只监听本插件的 key 变更 + prefix_value = self._db_scoped_key(plugin_id, prefix) + elif prefix is None: + # 无前缀时默认监听整个命名空间 + prefix_value = f"{plugin_id}:" + else: + raise AstrBotError.invalid_input("db.watch 的 prefix 必须是 string 或 null") + + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + self._db_watch_subscriptions[request_id] = (prefix_value, queue) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + yield self._db_public_event(plugin_id, await queue.get()) + finally: + self._db_watch_subscriptions.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + def _register_db_capabilities(self) -> None: + self.register( + self._builtin_descriptor("db.get", "读取 KV"), call_handler=self._db_get + ) + self.register( + self._builtin_descriptor("db.set", "写入 KV"), call_handler=self._db_set + ) + self.register( + self._builtin_descriptor("db.delete", "删除 KV"), + call_handler=self._db_delete, + ) + self.register( + 
self._builtin_descriptor("db.list", "列出 KV"), call_handler=self._db_list + ) + self.register( + self._builtin_descriptor("db.get_many", "批量读取 KV"), + call_handler=self._db_get_many, + ) + self.register( + self._builtin_descriptor("db.set_many", "批量写入 KV"), + call_handler=self._db_set_many, + ) + self.register( + self._builtin_descriptor( + "db.watch", + "订阅 KV 变更", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._db_watch, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py new file mode 100644 index 0000000000..eaefdc2780 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/http.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import re +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + +# 路由只允许字母、数字、/, -, _, . 以及路径参数 {param},且必须以 / 开头。 +# 参数段必须完整地形如 {param},同时禁止空段(例如连续斜杠)。 +_ROUTE_SEGMENT_RE = re.compile(r"^(?:[\w\-._]+|\{[\w\-._]+\})$") + + +def _validate_route(route: str, capability_name: str) -> None: + """校验 HTTP 路由路径格式,阻止路径遍历和非法字符。""" + if ".." 
in route: + raise AstrBotError.invalid_input(f"{capability_name}: 路由路径不允许包含 '..'") + if not route.startswith("/"): + raise AstrBotError.invalid_input( + f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段," + "且必须以 / 开头,如 /foo/bar" + ) + if route == "/": + return + segments = route.split("/")[1:] + if any( + not segment or not _ROUTE_SEGMENT_RE.fullmatch(segment) for segment in segments + ): + raise AstrBotError.invalid_input( + f"{capability_name}: 路由路径格式非法,只允许字母/数字/-/_/./{{param}} 段," + "禁止连续斜杠,且必须以 / 开头,如 /foo/bar" + ) + + +class HttpCapabilityMixin(CapabilityRouterBridgeBase): + async def _http_register_api( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + methods_payload = payload.get("methods") + if not isinstance(methods_payload, list) or not all( + isinstance(item, str) for item in methods_payload + ): + raise AstrBotError.invalid_input( + "http.register_api 的 methods 必须是 string 数组" + ) + route = str(payload.get("route", "")).strip() + handler_capability = str(payload.get("handler_capability", "")).strip() + if not route or not handler_capability: + raise AstrBotError.invalid_input( + "http.register_api 需要 route 和 handler_capability" + ) + _validate_route(route, "http.register_api") + plugin_name = self._require_caller_plugin_id("http.register_api") + methods = sorted({method.upper() for method in methods_payload if method}) + entry: dict[str, Any] = { + "route": route, + "methods": methods, + "handler_capability": handler_capability, + "description": str(payload.get("description", "")), + "plugin_id": plugin_name, + } + self.http_api_store = [ + item + for item in self.http_api_store + if not ( + item.get("route") == route + and item.get("plugin_id") == entry["plugin_id"] + and item.get("methods") == methods + ) + ] + self.http_api_store.append(entry) + return {} + + async def _http_unregister_api( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + route = str(payload.get("route", 
"")).strip() + methods_payload = payload.get("methods") + if not isinstance(methods_payload, list) or not all( + isinstance(item, str) for item in methods_payload + ): + raise AstrBotError.invalid_input( + "http.unregister_api 的 methods 必须是 string 数组" + ) + plugin_name = self._require_caller_plugin_id("http.unregister_api") + methods = {method.upper() for method in methods_payload if method} + updated: list[dict[str, Any]] = [] + for entry in self.http_api_store: + if entry.get("route") != route: + updated.append(entry) + continue + if entry.get("plugin_id") != plugin_name: + updated.append(entry) + continue + if not methods: + # `HTTPClient.unregister_api(methods=None)` 会归一化为空列表, + # 公开语义就是“移除当前插件在该 route 下注册的全部方法”。 + continue + remaining_methods = [ + method for method in entry.get("methods", []) if method not in methods + ] + if remaining_methods: + updated.append({**entry, "methods": remaining_methods}) + self.http_api_store = updated + return {} + + async def _http_list_apis( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_name = self._require_caller_plugin_id("http.list_apis") + apis = [ + dict(entry) + for entry in self.http_api_store + if entry.get("plugin_id") == plugin_name + ] + return {"apis": apis} + + def _register_http_capabilities(self) -> None: + self.register( + self._builtin_descriptor("http.register_api", "注册 HTTP 路由"), + call_handler=self._http_register_api, + ) + self.register( + self._builtin_descriptor("http.unregister_api", "注销 HTTP 路由"), + call_handler=self._http_unregister_api, + ) + self.register( + self._builtin_descriptor("http.list_apis", "列出 HTTP 路由"), + call_handler=self._http_list_apis, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py new file mode 100644 index 0000000000..77a03d86c7 --- /dev/null +++ 
b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/kb.py @@ -0,0 +1,427 @@ +from __future__ import annotations + +import math +import uuid +from pathlib import Path +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +def _term_set(text: str) -> set[str]: + normalized = " ".join(str(text).strip().casefold().split()) + compact = normalized.replace(" ", "") + if not normalized: + return set() + terms = {item for item in normalized.split(" ") if item} + if compact: + terms.add(compact) + if len(compact) > 1: + terms.update( + compact[index : index + 2] for index in range(len(compact) - 1) + ) + return terms + + +class KnowledgeBaseCapabilityMixin(CapabilityRouterBridgeBase): + def _kb_documents(self, kb_id: str) -> dict[str, dict[str, Any]]: + return self._kb_document_store.setdefault(kb_id, {}) + + def _refresh_mock_kb_stats(self, kb_id: str) -> None: + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + return + documents = self._kb_documents(kb_id) + kb["doc_count"] = len(documents) + kb["chunk_count"] = sum( + int(document.get("chunk_count", 0) or 0) for document in documents.values() + ) + kb["updated_at"] = self._now_iso() + + def _resolve_mock_kb_ids(self, payload: dict[str, Any]) -> list[str]: + kb_ids = [ + str(item).strip() for item in payload.get("kb_ids", []) if str(item).strip() + ] + if kb_ids: + return [kb_id for kb_id in kb_ids if kb_id in self._kb_store] + + kb_names = [ + str(item).strip() + for item in payload.get("kb_names", []) + if str(item).strip() + ] + if not kb_names: + return [] + name_set = set(kb_names) + return [ + kb_id + for kb_id, kb in self._kb_store.items() + if str(kb.get("kb_name", "")).strip() in name_set + ] + + @staticmethod + def _score_mock_document(query: str, content: str) -> float: + query_terms = _term_set(query) + content_terms = _term_set(content) + if not query_terms or not content_terms: + return 0.0 + overlap = 
len(query_terms & content_terms) + if overlap <= 0: + return 0.0 + score = overlap / len(query_terms) + if query.strip().casefold() in str(content).casefold(): + score += 0.25 + return min(score, 1.0) + + @staticmethod + def _build_mock_context_text(results: list[dict[str, Any]]) -> str: + lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"] + for index, item in enumerate(results, start=1): + lines.append(f"【知识 {index}】") + lines.append(f"来源: {item['kb_name']} / {item['doc_name']}") + lines.append(f"内容: {item['content']}") + lines.append(f"相关度: {float(item['score']):.2f}") + lines.append("") + return "\n".join(lines) + + async def _kb_list( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return { + "kbs": [ + dict(record) + for record in sorted( + self._kb_store.values(), + key=lambda item: str(item.get("created_at", "")), + ) + ] + } + + async def _kb_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + record = self._kb_store.get(kb_id) + return {"kb": dict(record) if isinstance(record, dict) else None} + + async def _kb_create( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.create requires kb object") + embedding_provider_id = str(raw_kb.get("embedding_provider_id", "")).strip() + if not embedding_provider_id: + raise AstrBotError.invalid_input("kb.create requires embedding_provider_id") + kb_id = uuid.uuid4().hex + now = self._now_iso() + record = { + "kb_id": kb_id, + "kb_name": str(raw_kb.get("kb_name", "")), + "description": ( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ), + "emoji": ( + str(raw_kb.get("emoji")) if raw_kb.get("emoji") is not None else None + ), + "embedding_provider_id": embedding_provider_id, + "rerank_provider_id": ( + 
str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ), + "chunk_size": self._optional_int(raw_kb.get("chunk_size")), + "chunk_overlap": self._optional_int(raw_kb.get("chunk_overlap")), + "top_k_dense": self._optional_int(raw_kb.get("top_k_dense")), + "top_k_sparse": self._optional_int(raw_kb.get("top_k_sparse")), + "top_m_final": self._optional_int(raw_kb.get("top_m_final")), + "doc_count": 0, + "chunk_count": 0, + "created_at": now, + "updated_at": now, + } + self._kb_store[kb_id] = record + self._kb_document_store[kb_id] = {} + return {"kb": dict(record)} + + async def _kb_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.update requires kb object") + record = self._kb_store.get(kb_id) + if not isinstance(record, dict): + return {"kb": None} + + for field_name in ( + "kb_name", + "description", + "emoji", + "embedding_provider_id", + "rerank_provider_id", + ): + if field_name in raw_kb: + value = raw_kb.get(field_name) + record[field_name] = str(value) if value is not None else None + for field_name in ( + "chunk_size", + "chunk_overlap", + "top_k_dense", + "top_k_sparse", + "top_m_final", + ): + if field_name in raw_kb: + record[field_name] = self._optional_int(raw_kb.get(field_name)) + record["updated_at"] = self._now_iso() + return {"kb": dict(record)} + + async def _kb_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + documents = self._kb_document_store.pop(kb_id, {}) + for document in documents.values(): + doc_id = str(document.get("doc_id", "")).strip() + if doc_id: + self._kb_document_content_store.pop(doc_id, None) + deleted = self._kb_store.pop(kb_id, None) is not None + return {"deleted": deleted} + + async def _kb_retrieve( + 
self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + query = str(payload.get("query", "")).strip() + if not query: + raise AstrBotError.invalid_input("kb.retrieve requires query") + kb_ids = self._resolve_mock_kb_ids(payload) + if not kb_ids: + raise AstrBotError.invalid_input("kb.retrieve requires kb_ids or kb_names") + + top_m_final = self._optional_int(payload.get("top_m_final")) or 5 + results: list[dict[str, Any]] = [] + for kb_id in kb_ids: + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + continue + for document in self._kb_documents(kb_id).values(): + doc_id = str(document.get("doc_id", "")).strip() + if not doc_id: + continue + content = self._kb_document_content_store.get(doc_id, "") + score = self._score_mock_document(query, content) + if score <= 0: + continue + results.append( + { + "chunk_id": f"{doc_id}:0", + "doc_id": doc_id, + "kb_id": kb_id, + "kb_name": str(kb.get("kb_name", "")), + "doc_name": str(document.get("doc_name", "")), + "chunk_index": 0, + "content": content, + "score": score, + "char_count": len(content), + } + ) + results.sort(key=lambda item: float(item["score"]), reverse=True) + results = results[:top_m_final] + if not results: + return {"result": None} + return { + "result": { + "context_text": self._build_mock_context_text(results), + "results": results, + } + } + + async def _kb_document_upload( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + kb = self._kb_store.get(kb_id) + if not isinstance(kb, dict): + raise AstrBotError.invalid_input(f"Unknown knowledge base: {kb_id}") + raw_document = payload.get("document") + if not isinstance(raw_document, dict): + raise AstrBotError.invalid_input( + "kb.document.upload requires document object" + ) + + file_name = str(raw_document.get("file_name", "")).strip() + file_type = str(raw_document.get("file_type", "")).strip() + file_path = "" + content_text = "" + 
file_size = 0 + + text_value = raw_document.get("text") + url_value = raw_document.get("url") + file_token = str(raw_document.get("file_token", "")).strip() + + if isinstance(text_value, str) and text_value.strip(): + content_text = text_value + if not file_name: + file_name = "document.txt" + if not file_type: + file_type = "txt" + file_size = len(content_text.encode("utf-8")) + elif isinstance(url_value, str) and url_value.strip(): + url_text = url_value.strip() + content_text = f"Imported from {url_text}" + if not file_name: + file_name = ( + Path(url_text.split("?", maxsplit=1)[0]).name or "document.url" + ) + if not file_type: + suffix = Path(file_name).suffix.lstrip(".") + file_type = suffix or "url" + file_path = url_text + file_size = len(content_text.encode("utf-8")) + elif file_token: + file_path = self._file_token_store.pop(file_token, "") + if not file_path: + raise AstrBotError.invalid_input(f"Unknown file token: {file_token}") + path = Path(file_path) + if not path.exists(): + raise AstrBotError.invalid_input(f"File does not exist: {file_path}") + raw_bytes = path.read_bytes() + content_text = raw_bytes.decode("utf-8", errors="ignore") + if not file_name: + file_name = path.name + if not file_type: + file_type = path.suffix.lstrip(".") + if not file_type: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_type when the file has no suffix" + ) + file_size = len(raw_bytes) + else: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_token, url, or text" + ) + + chunk_size = self._optional_int(raw_document.get("chunk_size")) + if chunk_size is None or chunk_size <= 0: + chunk_size = self._optional_int(kb.get("chunk_size")) or 512 + chunk_count = max(1, math.ceil(max(len(content_text), 1) / chunk_size)) + doc_id = uuid.uuid4().hex + now = self._now_iso() + document = { + "doc_id": doc_id, + "kb_id": kb_id, + "doc_name": file_name, + "file_type": file_type, + "file_size": file_size, + "file_path": file_path, + 
"chunk_count": chunk_count, + "media_count": 0, + "created_at": now, + "updated_at": now, + } + self._kb_documents(kb_id)[doc_id] = document + self._kb_document_content_store[doc_id] = content_text + self._refresh_mock_kb_stats(kb_id) + return {"document": dict(document)} + + async def _kb_document_list( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + offset = max(self._optional_int(payload.get("offset")) or 0, 0) + limit = max(self._optional_int(payload.get("limit")) or 100, 0) + documents = list(self._kb_documents(kb_id).values()) + documents.sort(key=lambda item: str(item.get("created_at", ""))) + return { + "documents": [dict(item) for item in documents[offset : offset + limit]] + } + + async def _kb_document_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + document = self._kb_documents(kb_id).get(doc_id) + return {"document": dict(document) if isinstance(document, dict) else None} + + async def _kb_document_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + deleted = self._kb_documents(kb_id).pop(doc_id, None) is not None + if deleted: + self._kb_document_content_store.pop(doc_id, None) + self._refresh_mock_kb_stats(kb_id) + return {"deleted": deleted} + + async def _kb_document_refresh( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + kb_id = str(payload.get("kb_id", "")).strip() + doc_id = str(payload.get("doc_id", "")).strip() + document = self._kb_documents(kb_id).get(doc_id) + if not isinstance(document, dict): + return {"document": None} + kb = self._kb_store.get(kb_id, {}) + chunk_size = self._optional_int(kb.get("chunk_size")) or 512 + content_text = 
self._kb_document_content_store.get(doc_id, "") + document["chunk_count"] = max( + 1, math.ceil(max(len(content_text), 1) / chunk_size) + ) + document["updated_at"] = self._now_iso() + self._refresh_mock_kb_stats(kb_id) + return {"document": dict(document)} + + def _register_kb_capabilities(self) -> None: + self.register( + self._builtin_descriptor("kb.list", "列出知识库"), + call_handler=self._kb_list, + ) + self.register( + self._builtin_descriptor("kb.get", "获取知识库"), + call_handler=self._kb_get, + ) + self.register( + self._builtin_descriptor("kb.create", "创建知识库"), + call_handler=self._kb_create, + ) + self.register( + self._builtin_descriptor("kb.update", "更新知识库"), + call_handler=self._kb_update, + ) + self.register( + self._builtin_descriptor("kb.delete", "删除知识库"), + call_handler=self._kb_delete, + ) + self.register( + self._builtin_descriptor("kb.retrieve", "检索知识库"), + call_handler=self._kb_retrieve, + ) + self.register( + self._builtin_descriptor("kb.document.upload", "上传知识库文档"), + call_handler=self._kb_document_upload, + ) + self.register( + self._builtin_descriptor("kb.document.list", "列出知识库文档"), + call_handler=self._kb_document_list, + ) + self.register( + self._builtin_descriptor("kb.document.get", "获取知识库文档"), + call_handler=self._kb_document_get, + ) + self.register( + self._builtin_descriptor("kb.document.delete", "删除知识库文档"), + call_handler=self._kb_document_delete, + ) + self.register( + self._builtin_descriptor("kb.document.refresh", "刷新知识库文档"), + call_handler=self._kb_document_refresh, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py new file mode 100644 index 0000000000..daf1621128 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/llm.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import 
Any + +from ..bridge_base import CapabilityRouterBridgeBase + + +class LLMCapabilityMixin(CapabilityRouterBridgeBase): + async def _llm_chat( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + prompt = str(payload.get("prompt", "")) + return {"text": f"Echo: {prompt}"} + + async def _llm_chat_raw( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + prompt = str(payload.get("prompt", "")) + text = f"Echo: {prompt}" + return { + "text": text, + "usage": { + "input_tokens": len(prompt), + "output_tokens": len(text), + }, + "finish_reason": "stop", + "tool_calls": [], + } + + async def _llm_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> AsyncIterator[dict[str, Any]]: + text = f"Echo: {str(payload.get('prompt', ''))}" + for char in text: + token.raise_if_cancelled() + await asyncio.sleep(0) + yield {"text": char} + + def _register_llm_capabilities(self) -> None: + self.register( + self._builtin_descriptor("llm.chat", "发送对话请求,返回文本"), + call_handler=self._llm_chat, + ) + self.register( + self._builtin_descriptor("llm.chat_raw", "发送对话请求,返回完整响应"), + call_handler=self._llm_chat_raw, + ) + self.register( + self._builtin_descriptor( + "llm.stream_chat", + "流式对话", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._llm_stream, + finalize=lambda chunks: { + "text": "".join(item.get("text", "") for item in chunks) + }, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py new file mode 100644 index 0000000000..33582f5b44 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/mcp.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import asyncio +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +def 
_mock_tools_from_config(name: str, config: dict[str, Any]) -> list[str]: + configured = config.get("mock_tools") + if isinstance(configured, list): + tools = [str(item) for item in configured if str(item).strip()] + if tools: + return tools + return [f"{name}_tool"] + + +def _mock_server_record( + *, + name: str, + scope: str, + active: bool, + running: bool, + config: dict[str, Any], + tools: list[str], + errlogs: list[str] | None = None, + last_error: str | None = None, +) -> dict[str, Any]: + return { + "name": name, + "scope": scope, + "active": bool(active), + "running": bool(running), + "config": dict(config), + "tools": list(tools), + "errlogs": list(errlogs or []), + "last_error": last_error, + } + + +class McpCapabilityMixin(CapabilityRouterBridgeBase): + def _plugin_local_mcp_servers(self, plugin_id: str) -> dict[str, dict[str, Any]]: + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + return plugin.local_mcp_servers + + @staticmethod + def _require_server_name(payload: dict[str, Any], capability_name: str) -> str: + name = str(payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input(f"{capability_name} requires name") + return name + + @staticmethod + def _normalized_timeout(payload: dict[str, Any], default: float = 30.0) -> float: + raw_value = payload.get("timeout", default) + try: + timeout = float(raw_value) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input("timeout must be numeric") from exc + if timeout <= 0: + raise AstrBotError.invalid_input("timeout must be greater than 0") + return timeout + + def _mock_connect_outcome( + self, + *, + name: str, + config: dict[str, Any], + scope: str, + ) -> dict[str, Any]: + if bool(config.get("mock_fail", False)): + last_error = str(config.get("mock_error") or f"{name} failed") + return _mock_server_record( + name=name, + scope=scope, + active=bool(config.get("active", True)), + 
running=False, + config=config, + tools=[], + errlogs=[last_error], + last_error=last_error, + ) + return _mock_server_record( + name=name, + scope=scope, + active=bool(config.get("active", True)), + running=True, + config=config, + tools=_mock_tools_from_config(name, config), + errlogs=[], + last_error=None, + ) + + async def _mcp_local_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.get") + name = self._require_server_name(payload, "mcp.local.get") + return { + "server": self._plugin_local_mcp_servers(plugin_id).get(name), + } + + async def _mcp_local_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.list") + servers = sorted( + self._plugin_local_mcp_servers(plugin_id).values(), + key=lambda item: str(item.get("name", "")), + ) + return {"servers": [dict(item) for item in servers]} + + async def _mcp_local_enable( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.enable") + name = self._require_server_name(payload, "mcp.local.enable") + servers = self._plugin_local_mcp_servers(plugin_id) + server = servers.get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if bool(server.get("active", False)) and bool(server.get("running", False)): + return {"server": dict(server)} + updated = self._mock_connect_outcome( + name=name, + config=dict(server.get("config", {})), + scope="local", + ) + updated["active"] = True + servers[name] = updated + return {"server": dict(updated)} + + async def _mcp_local_disable( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.disable") + name = self._require_server_name(payload, "mcp.local.disable") + servers = 
self._plugin_local_mcp_servers(plugin_id) + server = servers.get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if not bool(server.get("active", False)) and not bool( + server.get("running", False) + ): + return {"server": dict(server)} + updated = dict(server) + updated["active"] = False + updated["running"] = False + servers[name] = updated + return {"server": updated} + + async def _mcp_local_wait_until_ready( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.local.wait_until_ready") + name = self._require_server_name(payload, "mcp.local.wait_until_ready") + timeout = self._normalized_timeout(payload) + server = self._plugin_local_mcp_servers(plugin_id).get(name) + if server is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if bool(server.get("running", False)): + return {"server": dict(server)} + delay = float(server.get("config", {}).get("mock_connect_delay", 0.0) or 0.0) + if delay > timeout: + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + if delay > 0: + await asyncio.sleep(delay) + if bool(server.get("active", False)) and not bool( + server.get("config", {}).get("mock_fail", False) + ): + refreshed = self._mock_connect_outcome( + name=name, + config=dict(server.get("config", {})), + scope="local", + ) + refreshed["active"] = bool(server.get("active", False)) + self._plugin_local_mcp_servers(plugin_id)[name] = refreshed + refreshed = self._plugin_local_mcp_servers(plugin_id).get(name) + if refreshed is None or not bool(refreshed.get("running", False)): + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + return {"server": dict(refreshed)} + + async def _mcp_session_open( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.session.open") + 
name = self._require_server_name(payload, "mcp.session.open") + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input("mcp.session.open requires config object") + timeout = self._normalized_timeout(payload) + delay = float(config.get("mock_connect_delay", 0.0) or 0.0) + if bool(config.get("mock_fail", False)) or delay > timeout: + raise TimeoutError(f"MCP session '{name}' failed to connect in time") + if delay > 0: + await asyncio.sleep(delay) + session_id = f"{plugin_id}:{uuid.uuid4().hex}" + tools = _mock_tools_from_config(name, dict(config)) + self._mcp_session_store[session_id] = { + "plugin_id": plugin_id, + "name": name, + "config": dict(config), + "tools": tools, + "tool_results": dict(config.get("mock_tool_results", {})) + if isinstance(config.get("mock_tool_results"), dict) + else {}, + } + return {"session_id": session_id, "tools": list(tools)} + + async def _mcp_session_list_tools( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + session = self._mcp_session_store.get(session_id) + if session is None: + raise AstrBotError.invalid_input("Unknown MCP session") + return {"tools": list(session.get("tools", []))} + + async def _mcp_session_call_tool( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + session = self._mcp_session_store.get(session_id) + if session is None: + raise AstrBotError.invalid_input("Unknown MCP session") + tool_name = str(payload.get("tool_name", "")).strip() + if not tool_name: + raise AstrBotError.invalid_input("mcp.session.call_tool requires tool_name") + args = payload.get("args") + if not isinstance(args, dict): + raise AstrBotError.invalid_input( + "mcp.session.call_tool requires args object" + ) + tool_results = session.get("tool_results", {}) + if isinstance(tool_results, dict) and tool_name in 
tool_results: + result = tool_results[tool_name] + return { + "result": dict(result) + if isinstance(result, dict) + else {"value": result} + } + return { + "result": { + "tool_name": tool_name, + "arguments": dict(args), + "content": f"mock:{session['name']}:{tool_name}", + } + } + + async def _mcp_session_close( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session_id = str(payload.get("session_id", "")).strip() + self._mcp_session_store.pop(session_id, None) + return {} + + def _require_global_mcp_risk_ack( + self, + plugin_id: str, + capability_name: str, + ) -> None: + plugin = self._plugins.get(plugin_id) + metadata = plugin.metadata if plugin is not None else {} + if bool(metadata.get("acknowledge_global_mcp_risk", False)): + return + raise PermissionError( + f"{capability_name} requires @acknowledge_global_mcp_risk" + ) + + def _audit_global_mcp_mutation( + self, + *, + plugin_id: str, + action: str, + server_name: str, + request_id: str, + ) -> None: + self._mcp_audit_logs.append( + { + "plugin_id": plugin_id, + "action": action, + "server_name": server_name, + "request_id": request_id, + } + ) + + async def _mcp_global_register( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.register") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.register") + name = self._require_server_name(payload, "mcp.global.register") + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + "mcp.global.register requires config object" + ) + if name in self._mcp_global_servers: + raise AstrBotError.invalid_input( + f"Global MCP server already exists: {name}" + ) + record = self._mock_connect_outcome( + name=name, + config=dict(config), + scope="global", + ) + self._mcp_global_servers[name] = record + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="register", + server_name=name, + 
request_id=request_id, + ) + return {"server": dict(record)} + + async def _mcp_global_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.get") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.get") + name = self._require_server_name(payload, "mcp.global.get") + return {"server": self._mcp_global_servers.get(name)} + + async def _mcp_global_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.list") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.list") + servers = sorted( + self._mcp_global_servers.values(), + key=lambda item: str(item.get("name", "")), + ) + return {"servers": [dict(item) for item in servers]} + + async def _mcp_global_enable( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.enable") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.enable") + name = self._require_server_name(payload, "mcp.global.enable") + record = self._mcp_global_servers.get(name) + if record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + updated = self._mock_connect_outcome( + name=name, + config=dict(record.get("config", {})), + scope="global", + ) + updated["active"] = True + self._mcp_global_servers[name] = updated + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="enable", + server_name=name, + request_id=request_id, + ) + return {"server": dict(updated)} + + async def _mcp_global_disable( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.disable") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.disable") + name = self._require_server_name(payload, "mcp.global.disable") + record = self._mcp_global_servers.get(name) + if record 
is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + updated = dict(record) + updated["active"] = False + updated["running"] = False + self._mcp_global_servers[name] = updated + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="disable", + server_name=name, + request_id=request_id, + ) + return {"server": dict(updated)} + + async def _mcp_global_unregister( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("mcp.global.unregister") + self._require_global_mcp_risk_ack(plugin_id, "mcp.global.unregister") + name = self._require_server_name(payload, "mcp.global.unregister") + record = self._mcp_global_servers.pop(name, None) + if record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="unregister", + server_name=name, + request_id=request_id, + ) + return {"server": dict(record)} + + async def _internal_mcp_local_execute( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = str(payload.get("plugin_id", "")).strip() + server_name = str(payload.get("server_name", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + tool_args = payload.get("tool_args") + if not plugin_id or not server_name or not tool_name: + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires plugin_id, server_name, and tool_name" + ) + if not isinstance(tool_args, dict): + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires tool_args object" + ) + plugin = self._plugins.get(plugin_id) + server = ( + plugin.local_mcp_servers.get(server_name) if plugin is not None else None + ) + if server is None or not bool(server.get("running", False)): + return { + "content": f"Local MCP server unavailable: {server_name}", + "success": False, + } + if tool_name not in server.get("tools", []): + return 
{ + "content": f"Local MCP tool not found: {server_name}.{tool_name}", + "success": False, + } + return { + "content": f"mock:{server_name}:{tool_name}:{tool_args}", + "success": True, + } + + def _register_mcp_capabilities(self) -> None: + self.register( + self._builtin_descriptor("mcp.local.get", "Get local MCP server"), + call_handler=self._mcp_local_get, + ) + self.register( + self._builtin_descriptor("mcp.local.list", "List local MCP servers"), + call_handler=self._mcp_local_list, + ) + self.register( + self._builtin_descriptor("mcp.local.enable", "Enable local MCP server"), + call_handler=self._mcp_local_enable, + ) + self.register( + self._builtin_descriptor("mcp.local.disable", "Disable local MCP server"), + call_handler=self._mcp_local_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.local.wait_until_ready", + "Wait until local MCP server is ready", + ), + call_handler=self._mcp_local_wait_until_ready, + ) + self.register( + self._builtin_descriptor("mcp.session.open", "Open temporary MCP session"), + call_handler=self._mcp_session_open, + ) + self.register( + self._builtin_descriptor( + "mcp.session.list_tools", + "List tools in temporary MCP session", + ), + call_handler=self._mcp_session_list_tools, + ) + self.register( + self._builtin_descriptor( + "mcp.session.call_tool", + "Call tool in temporary MCP session", + ), + call_handler=self._mcp_session_call_tool, + ) + self.register( + self._builtin_descriptor( + "mcp.session.close", "Close temporary MCP session" + ), + call_handler=self._mcp_session_close, + ) + self.register( + self._builtin_descriptor( + "mcp.global.register", + "Register global MCP server", + ), + call_handler=self._mcp_global_register, + ) + self.register( + self._builtin_descriptor("mcp.global.get", "Get global MCP server"), + call_handler=self._mcp_global_get, + ) + self.register( + self._builtin_descriptor("mcp.global.list", "List global MCP servers"), + call_handler=self._mcp_global_list, + ) + self.register( + 
self._builtin_descriptor("mcp.global.enable", "Enable global MCP server"), + call_handler=self._mcp_global_enable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.disable", + "Disable global MCP server", + ), + call_handler=self._mcp_global_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.unregister", + "Unregister global MCP server", + ), + call_handler=self._mcp_global_unregister, + ) + self.register( + self._builtin_descriptor( + "internal.mcp.local.execute", + "Execute local MCP tool", + ), + call_handler=self._internal_mcp_local_execute, + exposed=False, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py new file mode 100644 index 0000000000..f55ef7ccf0 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/memory.py @@ -0,0 +1,655 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ...._internal.invocation_context import current_caller_plugin_id +from ...._internal.memory_utils import ( + cosine_similarity, + extract_memory_text, + is_ttl_memory_entry, + memory_expiration_from_ttl, + memory_index_entry, + memory_keyword_score, + memory_value_for_search, +) +from ...._memory_backend import PluginMemoryBackend +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class MemoryCapabilityMixin(CapabilityRouterBridgeBase): + def _memory_plugin_id(self) -> str: + plugin_id = current_caller_plugin_id() + return self._validated_plugin_id( + str(plugin_id).strip() or "__anonymous__", + capability_name="memory.*", + ) + + def _memory_backend_for_plugin(self, plugin_id: str) -> PluginMemoryBackend: + backend = self._memory_backends.get(plugin_id) + if backend is None: + backend = PluginMemoryBackend( + self._plugin_data_dir(plugin_id, 
capability_name="memory.*") + ) + self._memory_backends[plugin_id] = backend + return backend + + @staticmethod + def _is_ttl_memory_entry(value: Any) -> bool: + """判断存储值是否使用了 TTL 包装结构。 + + Args: + value: 待检查的存储值。 + + Returns: + bool: 如果值包含 ``value`` 和 ``ttl_seconds`` 字段则返回 ``True``。 + """ + return is_ttl_memory_entry(value) + + @classmethod + def _memory_value_for_search(cls, stored: Any) -> dict[str, Any] | None: + """提取用于检索的原始 memory payload。 + + Args: + stored: memory_store 中保存的原始值。 + + Returns: + dict[str, Any] | None: 解开 TTL 包装后的字典,无法解析时返回 ``None``。 + """ + return memory_value_for_search(stored) + + @classmethod + def _extract_memory_text(cls, stored: Any) -> str: + """提取用于检索索引的首选文本。 + + Args: + stored: memory_store 中保存的原始值。 + + Returns: + str: 优先使用 ``embedding_text`` / ``content`` 等字段,兜底为 JSON 文本。 + """ + return extract_memory_text(stored) + + @staticmethod + def _memory_expiration_from_ttl(ttl_seconds: Any) -> datetime | None: + """将 TTL 秒数转换为 UTC 过期时间。 + + Args: + ttl_seconds: TTL 秒数。 + + Returns: + datetime | None: 绝对过期时间;当输入无效时返回 ``None``。 + """ + return memory_expiration_from_ttl(ttl_seconds) + + @staticmethod + def _memory_keyword_score(query: str, key: str, text: str) -> float: + """计算关键词匹配分数。 + + Args: + query: 查询文本。 + key: memory 条目的键。 + text: 已索引的检索文本。 + + Returns: + float: 基于键名和文本命中的粗粒度关键词分数。 + """ + return memory_keyword_score(query, key, text) + + @staticmethod + def _cosine_similarity(left: list[float], right: list[float]) -> float: + """计算两个向量之间的余弦相似度。 + + Args: + left: 左侧向量。 + right: 右侧向量。 + + Returns: + float: 余弦相似度;输入不合法时返回 ``0.0``。 + """ + return cosine_similarity(left, right) + + def _resolve_memory_embedding_provider_id( + self, + provider_id: Any, + *, + required: bool, + ) -> str | None: + """解析 memory.search 要使用的 embedding provider。 + + Args: + provider_id: 调用方显式传入的 provider 标识。 + required: 当前检索模式是否强制要求 embedding provider。 + + Returns: + str | None: 最终选中的 provider 标识;在非强制场景下允许返回 ``None``。 + """ + normalized = str(provider_id).strip() 
if provider_id is not None else "" + if normalized: + self._provider_entry( + {"provider_id": normalized}, + "memory.search", + "embedding", + ) + return normalized + active_id = self._active_provider_ids.get("embedding") + if active_id is not None: + normalized_active = str(active_id).strip() + if normalized_active: + self._provider_entry( + {"provider_id": normalized_active}, + "memory.search", + "embedding", + ) + return normalized_active + if required: + raise AstrBotError.invalid_input( + "memory.search requires an embedding provider", + ) + return None + + @staticmethod + def _memory_index_entry(entry: Any, *, text: str) -> dict[str, Any]: + """将原始索引项规范化为内部统一结构。 + + Args: + entry: 当前索引表中的原始项。 + text: 当前条目的索引文本。 + + Returns: + dict[str, Any]: 统一后的索引项,包含 ``text``、``embedding``、``provider_id``。 + """ + return memory_index_entry(entry, text=text) + + def _clear_memory_sidecars(self, key: str) -> None: + """清理指定 memory 键对应的所有 sidecar 状态。 + + Args: + key: memory 条目的键。 + + Returns: + None + """ + self._memory_index.pop(key, None) + self._memory_expires_at.pop(key, None) + self._memory_dirty_keys.discard(key) + + def _delete_memory_entry(self, key: str) -> bool: + """删除 memory 条目并同步清理 sidecar 状态。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 条目存在并删除成功时返回 ``True``。 + """ + deleted = self.memory_store.pop(key, None) is not None + self._clear_memory_sidecars(key) + return deleted + + def _upsert_memory_sidecars( + self, + key: str, + stored: dict[str, Any], + *, + expires_at: datetime | None = None, + ) -> None: + """创建或更新单条 memory 的 sidecar 索引状态。 + + Args: + key: memory 条目的键。 + stored: 需要建立索引的原始存储值。 + expires_at: 可选的绝对过期时间。 + + Returns: + None + """ + self._memory_index[key] = { + "text": self._extract_memory_text(stored), + "embedding": None, + "provider_id": None, + } + if expires_at is None: + self._memory_expires_at.pop(key, None) + else: + self._memory_expires_at[key] = expires_at + self._memory_dirty_keys.add(key) + + def _ensure_memory_sidecars(self, key: 
str, stored: Any) -> None: + """确保 sidecar 状态与当前存储值保持一致。 + + Args: + key: memory 条目的键。 + stored: memory_store 中的当前存储值。 + + Returns: + None + """ + if not isinstance(stored, dict): + return + text = self._extract_memory_text(stored) + existed = key in self._memory_index + entry = self._memory_index_entry(self._memory_index.get(key), text=text) + if entry["text"] != text: + entry["text"] = text + entry["embedding"] = None + entry["provider_id"] = None + self._memory_dirty_keys.add(key) + self._memory_index[key] = entry + if not existed: + self._memory_dirty_keys.add(key) + + def _is_memory_expired(self, key: str) -> bool: + """判断 memory 条目是否已过期。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果当前时间已超过记录的过期时间则返回 ``True``。 + """ + expires_at = self._memory_expires_at.get(key) + return expires_at is not None and expires_at <= datetime.now(timezone.utc) + + def _purge_expired_memory_entry(self, key: str) -> bool: + """在单条 memory 已过期时立即清理它。 + + Args: + key: memory 条目的键。 + + Returns: + bool: 如果条目已过期并被成功清理则返回 ``True``。 + """ + if not self._is_memory_expired(key): + return False + self._delete_memory_entry(key) + return True + + def _purge_expired_memory_entries(self) -> None: + """批量清理所有已跟踪的过期 TTL 条目。 + + Returns: + None + """ + for key in list(self._memory_expires_at): + self._purge_expired_memory_entry(key) + + async def _embedding_for_text( + self, + *, + provider_id: str, + text: str, + ) -> list[float]: + """通过 embedding capability 获取单条文本向量。 + + Args: + provider_id: 使用的 embedding provider 标识。 + text: 待向量化的文本。 + + Returns: + list[float]: provider 返回的向量;异常场景下返回空列表。 + """ + output = await self._provider_embedding_get_embedding( + "", + {"provider_id": provider_id, "text": text}, + None, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def _embeddings_for_texts( + self, + *, + provider_id: str, + texts: list[str], + ) -> list[list[float]]: + """批量获取多条文本的 embedding 向量。 + + 
Args: + provider_id: 使用的 embedding provider 标识。 + texts: 待向量化的文本列表。 + + Returns: + list[list[float]]: 与输入顺序对应的向量列表。 + """ + if not texts: + return [] + output = await self._provider_embedding_get_embeddings( + "", + {"provider_id": provider_id, "texts": texts}, + None, + ) + embeddings = output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def _refresh_memory_embeddings(self, *, provider_id: str) -> None: + """刷新当前 provider 下脏或过期的 memory 向量索引。 + + Args: + provider_id: 当前使用的 embedding provider 标识。 + + Returns: + None + """ + keys_to_refresh: list[str] = [] + texts_to_refresh: list[str] = [] + for key, stored in self.memory_store.items(): + self._ensure_memory_sidecars(key, stored) + entry = self._memory_index_entry( + self._memory_index.get(key), + text=self._extract_memory_text(stored), + ) + should_refresh = ( + key in self._memory_dirty_keys + or entry["embedding"] is None + or entry["provider_id"] != provider_id + ) + self._memory_index[key] = entry + if should_refresh: + keys_to_refresh.append(key) + texts_to_refresh.append(str(entry["text"])) + # 分批请求,避免单次 payload 过大导致 OOM 或 413 + _BATCH_SIZE = 64 + embeddings: list[list[float]] = [] + for batch_start in range(0, len(texts_to_refresh), _BATCH_SIZE): + batch = texts_to_refresh[batch_start : batch_start + _BATCH_SIZE] + embeddings.extend( + await self._embeddings_for_texts( + provider_id=provider_id, + texts=batch, + ) + ) + for index, key in enumerate(keys_to_refresh): + entry = self._memory_index_entry( + self._memory_index.get(key), + text=str(texts_to_refresh[index]), + ) + entry["embedding"] = embeddings[index] if index < len(embeddings) else [] + entry["provider_id"] = provider_id + self._memory_index[key] = entry + self._memory_dirty_keys.discard(key) + + async def _memory_search( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id 
= self._memory_plugin_id() + query = str(payload.get("query", "")) + mode = str(payload.get("mode", "auto")).strip().lower() or "auto" + limit = self._optional_int(payload.get("limit")) + raw_min_score = payload.get("min_score") + min_score = float(raw_min_score) if raw_min_score is not None else None + namespace = payload.get("namespace") + include_descendants = bool(payload.get("include_descendants", True)) + provider_id = self._resolve_memory_embedding_provider_id( + payload.get("provider_id"), + required=mode in {"vector", "hybrid"}, + ) + effective_mode = mode + if effective_mode == "auto": + effective_mode = "hybrid" if provider_id is not None else "keyword" + backend = self._memory_backend_for_plugin(plugin_id) + items = await backend.search( + query, + namespace=str(namespace) if namespace is not None else None, + include_descendants=include_descendants, + mode=effective_mode, + limit=limit, + min_score=min_score, + provider_id=provider_id, + embed_one=( + ( + lambda text: self._embedding_for_text( + provider_id=provider_id, text=text + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + embed_many=( + ( + lambda texts: self._embeddings_for_texts( + provider_id=provider_id, + texts=texts, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + ) + return {"items": items} + + async def _memory_save( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + key = str(payload.get("key", "")) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input("memory.save 的 value 必须是 object") + await self._memory_backend_for_plugin(plugin_id).save( + key, + value, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> 
dict[str, Any]: + plugin_id = self._memory_plugin_id() + key = str(payload.get("key", "")) + value = await self._memory_backend_for_plugin(plugin_id).get( + key, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"value": value} + + async def _memory_list_keys( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys = await self._memory_backend_for_plugin(plugin_id).list_keys( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"keys": keys} + + async def _memory_exists( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + exists = await self._memory_backend_for_plugin(plugin_id).exists( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"exists": exists} + + async def _memory_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + await self._memory_backend_for_plugin(plugin_id).delete( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_clear_namespace( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + deleted_count = await self._memory_backend_for_plugin( + plugin_id + ).clear_namespace( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"deleted_count": deleted_count} + + async def _memory_save_with_ttl( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id 
= self._memory_plugin_id() + key = str(payload.get("key", "")) + value = payload.get("value") + ttl_seconds = payload.get("ttl_seconds", 0) + if not isinstance(value, dict): + raise AstrBotError.invalid_input( + "memory.save_with_ttl 的 value 必须是 object" + ) + await self._memory_backend_for_plugin(plugin_id).save_with_ttl( + key, + value, + int(ttl_seconds), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("memory.get_many 的 keys 必须是数组") + items = await self._memory_backend_for_plugin(plugin_id).get_many( + [str(item) for item in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"items": items} + + async def _memory_delete_many( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + keys_payload = payload.get("keys") + if not isinstance(keys_payload, (list, tuple)): + raise AstrBotError.invalid_input("memory.delete_many 的 keys 必须是数组") + deleted_count = await self._memory_backend_for_plugin(plugin_id).delete_many( + [str(item) for item in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"deleted_count": deleted_count} + + async def _memory_count( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + count = await self._memory_backend_for_plugin(plugin_id).count( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", 
False)), + ) + return {"count": count} + + async def _memory_stats( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._memory_plugin_id() + stats = await self._memory_backend_for_plugin(plugin_id).stats( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", True)), + ) + stats["plugin_id"] = plugin_id + return stats + + def _register_memory_capabilities(self) -> None: + self.register( + self._builtin_descriptor("memory.search", "搜索记忆"), + call_handler=self._memory_search, + ) + self.register( + self._builtin_descriptor("memory.save", "保存记忆"), + call_handler=self._memory_save, + ) + self.register( + self._builtin_descriptor("memory.get", "读取单条记忆"), + call_handler=self._memory_get, + ) + self.register( + self._builtin_descriptor("memory.list_keys", "列出命名空间内的记忆键"), + call_handler=self._memory_list_keys, + ) + self.register( + self._builtin_descriptor("memory.exists", "检查记忆键是否存在"), + call_handler=self._memory_exists, + ) + self.register( + self._builtin_descriptor("memory.delete", "删除记忆"), + call_handler=self._memory_delete, + ) + self.register( + self._builtin_descriptor("memory.clear_namespace", "清理记忆命名空间"), + call_handler=self._memory_clear_namespace, + ) + self.register( + self._builtin_descriptor("memory.save_with_ttl", "保存带过期时间的记忆"), + call_handler=self._memory_save_with_ttl, + ) + self.register( + self._builtin_descriptor("memory.get_many", "批量获取记忆"), + call_handler=self._memory_get_many, + ) + self.register( + self._builtin_descriptor("memory.delete_many", "批量删除记忆"), + call_handler=self._memory_delete_many, + ) + self.register( + self._builtin_descriptor("memory.count", "统计命名空间内的记忆数量"), + call_handler=self._memory_count, + ) + self.register( + self._builtin_descriptor("memory.stats", "获取记忆统计信息"), + call_handler=self._memory_stats, + ) diff --git 
a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py new file mode 100644 index 0000000000..3e2b6666bc --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/message_history.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ....errors import AstrBotError +from ....message.session import MessageSession +from ..bridge_base import CapabilityRouterBridgeBase + + +def _session_payload(session: MessageSession) -> dict[str, str]: + return { + "platform_id": str(session.platform_id), + "message_type": str(session.message_type), + "session_id": str(session.session_id), + } + + +class MessageHistoryCapabilityMixin(CapabilityRouterBridgeBase): + @staticmethod + def _normalize_timestamp(raw_value: Any) -> datetime: + normalized = str(raw_value or "").strip() + if normalized.endswith("Z"): + normalized = f"{normalized[:-1]}+00:00" + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + @staticmethod + def _typed_session_from_payload(payload: Any) -> MessageSession: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + "message_history capabilities require a session object" + ) + platform_id = str(payload.get("platform_id", "")).strip() + message_type = str(payload.get("message_type", "")).strip() + session_id = str(payload.get("session_id", "")).strip() + if not platform_id or not message_type or not session_id: + raise AstrBotError.invalid_input( + "message_history session requires platform_id, message_type, and session_id" + ) + return MessageSession( + platform_id=platform_id, + message_type=message_type, + session_id=session_id, + ) + + @staticmethod + def _typed_key(session: 
    def _message_history_records(self, session: MessageSession) -> list[dict[str, Any]]:
        """Return (lazily creating) the mutable record list for *session*."""
        key = self._typed_key(session)
        records = self._message_history_store.get(key)
        if records is None:
            records = []
            self._message_history_store[key] = records
        return records

    def _next_message_history_id(self) -> int:
        """Allocate the next monotonically increasing record id."""
        next_id = int(self._message_history_next_id)
        self._message_history_next_id += 1
        return next_id

    def _create_message_history_record(
        self,
        *,
        session: MessageSession,
        sender_payload: dict[str, Any],
        parts_payload: list[dict[str, Any]],
        metadata: dict[str, Any],
        idempotency_key: str | None,
    ) -> dict[str, Any]:
        """Build a fresh record dict; created_at/updated_at share one timestamp."""
        now = self._now_iso()
        return {
            "id": self._next_message_history_id(),
            "session": _session_payload(session),
            "sender": {
                "sender_id": (
                    str(sender_payload.get("sender_id"))
                    if sender_payload.get("sender_id") is not None
                    else None
                ),
                "sender_name": (
                    str(sender_payload.get("sender_name"))
                    if sender_payload.get("sender_name") is not None
                    else None
                ),
            },
            # Non-dict parts are silently dropped here; append() validates first.
            "parts": [dict(item) for item in parts_payload if isinstance(item, dict)],
            "metadata": dict(metadata),
            "created_at": now,
            "updated_at": now,
            "idempotency_key": idempotency_key,
        }

    @staticmethod
    def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
        """Defensively copy a stored record into a response-safe dict."""
        return {
            "id": int(record.get("id", 0) or 0),
            "session": (
                dict(record.get("session"))
                if isinstance(record.get("session"), dict)
                else {}
            ),
            "sender": (
                dict(record.get("sender"))
                if isinstance(record.get("sender"), dict)
                else {}
            ),
            "parts": (
                [
                    dict(item)
                    for item in record.get("parts", [])
                    if isinstance(item, dict)
                ]
                if isinstance(record.get("parts"), list)
                else []
            ),
            "metadata": (
                dict(record.get("metadata"))
                if isinstance(record.get("metadata"), dict)
                else {}
            ),
            "created_at": record.get("created_at"),
            "updated_at": record.get("updated_at"),
            "idempotency_key": (
                str(record.get("idempotency_key"))
                if record.get("idempotency_key") is not None
                else None
            ),
        }

    @staticmethod
    def _parse_boundary(raw_value: Any, field_name: str) -> datetime:
        """Parse an ISO datetime boundary, raising invalid_input on bad input."""
        text = str(raw_value or "").strip()
        if not text:
            raise AstrBotError.invalid_input(
                f"message_history.{field_name} requires {field_name}"
            )
        try:
            return MessageHistoryCapabilityMixin._normalize_timestamp(text)
        except ValueError as exc:
            raise AstrBotError.invalid_input(
                f"message_history.{field_name} requires an ISO datetime string"
            ) from exc

    async def _message_history_list(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Cursor-paginate history, newest first.

        The cursor is a record id as a string; a page contains records with
        ids strictly below it. ``total`` counts the session's records before
        cursor filtering.
        """
        session = self._typed_session_from_payload(payload.get("session"))
        raw_limit = self._optional_int(payload.get("limit"))
        limit = 50 if raw_limit is None else raw_limit
        if limit < 1:
            raise AstrBotError.invalid_input("message_history.list requires limit >= 1")
        raw_cursor = payload.get("cursor")
        cursor_id = (
            self._optional_int(raw_cursor) if raw_cursor not in (None, "") else None
        )
        if raw_cursor not in (None, "") and (cursor_id is None or cursor_id < 1):
            raise AstrBotError.invalid_input(
                "message_history.list requires cursor to be a positive integer string"
            )
        # Newest-first ordering: ids are monotonically increasing.
        records = list(reversed(self._message_history_records(session)))
        total = len(records)
        if cursor_id is not None:
            records = [
                record for record in records if int(record.get("id", 0)) < cursor_id
            ]
        page_records = records[:limit]
        next_cursor = (
            str(page_records[-1]["id"])
            if len(records) > limit and page_records
            else None
        )
        return {
            "page": {
                "records": [self._serialize_record(record) for record in page_records],
                "next_cursor": next_cursor,
                "total": total,
            }
        }

    async def _message_history_get_by_id(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Fetch one record by id; ``record`` is None when not found."""
        session = self._typed_session_from_payload(payload.get("session"))
        record_id = self._optional_int(payload.get("record_id"))
        if record_id is None or record_id < 1:
            raise AstrBotError.invalid_input(
                "message_history.get_by_id requires record_id >= 1"
            )
        record = next(
            (
                item
                for item in self._message_history_records(session)
                if int(item.get("id", 0) or 0) == record_id
            ),
            None,
        )
        return {
            "record": self._serialize_record(record) if record is not None else None
        }

    async def _message_history_append(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Append a record; an idempotency_key replays the existing record."""
        session = self._typed_session_from_payload(payload.get("session"))
        sender_payload = payload.get("sender")
        if not isinstance(sender_payload, dict):
            raise AstrBotError.invalid_input(
                "message_history.append requires sender object"
            )
        parts_payload = payload.get("parts")
        if not isinstance(parts_payload, list) or any(
            not isinstance(item, dict) for item in parts_payload
        ):
            raise AstrBotError.invalid_input(
                "message_history.append requires parts array"
            )
        metadata = payload.get("metadata")
        if metadata is not None and not isinstance(metadata, dict):
            raise AstrBotError.invalid_input(
                "message_history.append requires metadata object when provided"
            )
        idempotency_key = (
            str(payload.get("idempotency_key"))
            if payload.get("idempotency_key") is not None
            else None
        )
        records = self._message_history_records(session)
        if idempotency_key:
            # Dedupe: return the previously appended record unchanged.
            existing = next(
                (
                    record
                    for record in records
                    if str(record.get("idempotency_key") or "") == idempotency_key
                ),
                None,
            )
            if existing is not None:
                return {"record": self._serialize_record(existing)}
        record = self._create_message_history_record(
            session=session,
            sender_payload=sender_payload,
            parts_payload=parts_payload,
            metadata=dict(metadata or {}),
            idempotency_key=idempotency_key,
        )
        records.append(record)
        return {"record": self._serialize_record(record)}

    async def _message_history_delete_before(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete records created strictly before ``payload["before"]``."""
        session = self._typed_session_from_payload(payload.get("session"))
        before = self._parse_boundary(payload.get("before"), "delete_before")
        records = self._message_history_records(session)
        retained: list[dict[str, Any]] = []
        deleted_count = 0
        for record in records:
            created_at = self._normalize_timestamp(record.get("created_at"))
            if created_at < before:
                deleted_count += 1
                continue
            retained.append(record)
        self._message_history_store[self._typed_key(session)] = retained
        return {"deleted_count": deleted_count}

    async def _message_history_delete_after(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Delete records created strictly after ``payload["after"]``."""
        session = self._typed_session_from_payload(payload.get("session"))
        after = self._parse_boundary(payload.get("after"), "delete_after")
        records = self._message_history_records(session)
        retained: list[dict[str, Any]] = []
        deleted_count = 0
        for record in records:
            created_at = self._normalize_timestamp(record.get("created_at"))
            if created_at > after:
                deleted_count += 1
                continue
            retained.append(record)
        self._message_history_store[self._typed_key(session)] = retained
        return {"deleted_count": deleted_count}

    async def _message_history_delete_all(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Drop every record in the session, reporting how many existed."""
        session = self._typed_session_from_payload(payload.get("session"))
        key = self._typed_key(session)
        deleted_count = len(self._message_history_store.get(key, []))
        self._message_history_store[key] = []
        return {"deleted_count": deleted_count}
class MetadataCapabilityMixin(CapabilityRouterBridgeBase):
    """Bridge capabilities exposing plugin metadata and per-plugin config."""

    async def _metadata_get_plugin(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Look up one plugin's metadata snapshot by name."""
        plugin_name = str(payload.get("name", "")).strip()
        entry = self._plugins.get(plugin_name)
        return {"plugin": None if entry is None else dict(entry.metadata)}

    async def _metadata_list_plugins(
        self, _request_id: str, _payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return metadata snapshots for every plugin, sorted by name."""
        return {
            "plugins": [
                dict(self._plugins[plugin_name].metadata)
                for plugin_name in sorted(self._plugins)
            ]
        }

    async def _metadata_get_plugin_config(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the caller's own config; other plugins' configs are hidden."""
        requested = str(payload.get("name", "")).strip()
        caller = self._require_caller_plugin_id("metadata.get_plugin_config")
        if requested != caller:
            return {"config": None}
        entry = self._plugins.get(requested)
        return {"config": None if entry is None else dict(entry.config)}

    async def _metadata_save_plugin_config(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Replace the caller's config when a dict is supplied; echo it back."""
        caller = self._require_caller_plugin_id("metadata.save_plugin_config")
        entry = self._plugins.get(caller)
        if entry is None:
            return {"config": None}
        incoming = payload.get("config")
        if isinstance(incoming, dict):
            entry.config = dict(incoming)
        # A non-dict payload leaves the stored config untouched.
        return {"config": dict(entry.config)}

    def _register_metadata_capabilities(self) -> None:
        """Register the metadata.* capabilities on the router."""
        registrations = (
            ("metadata.get_plugin", "获取单个插件元数据", self._metadata_get_plugin),
            ("metadata.list_plugins", "列出插件元数据", self._metadata_list_plugins),
            ("metadata.get_plugin_config", "获取插件配置", self._metadata_get_plugin_config),
            ("metadata.save_plugin_config", "保存当前插件配置", self._metadata_save_plugin_config),
        )
        for capability, description, handler in registrations:
            self.register(
                self._builtin_descriptor(capability, description),
                call_handler=handler,
            )
    def _register_permission_capabilities(self) -> None:
        """Register permission query and admin-management capabilities."""
        self.register(
            self._builtin_descriptor("permission.check", "查询用户权限角色"),
            call_handler=self._permission_check,
        )
        self.register(
            self._builtin_descriptor("permission.get_admins", "列出管理员 ID"),
            call_handler=self._permission_get_admins,
        )
        self.register(
            self._builtin_descriptor(
                "permission.manager.add_admin",
                "添加管理员 ID",
            ),
            call_handler=self._permission_manager_add_admin,
        )
        self.register(
            self._builtin_descriptor(
                "permission.manager.remove_admin",
                "移除管理员 ID",
            ),
            call_handler=self._permission_manager_remove_admin,
        )

    @staticmethod
    def _normalize_admin_ids(values: Any) -> list[str]:
        """Coerce a raw admin list to stripped, non-empty string ids."""
        if not isinstance(values, list):
            return []
        normalized: list[str] = []
        for item in values:
            user_id = str(item).strip()
            if user_id:
                normalized.append(user_id)
        return normalized

    def _admin_ids_snapshot(self) -> list[str]:
        """Normalize the stored admin list in place and return a copy."""
        normalized = self._normalize_admin_ids(
            getattr(self, "_permission_admin_ids", [])
        )
        self._permission_admin_ids = list(normalized)
        return normalized

    @staticmethod
    def _required_user_id(payload: dict[str, Any], capability_name: str) -> str:
        """Extract a non-empty user_id or raise invalid_input."""
        user_id = str(payload.get("user_id", "")).strip()
        if not user_id:
            raise AstrBotError.invalid_input(f"{capability_name} requires user_id")
        return user_id

    def _require_reserved_plugin(self, capability_name: str) -> str:
        """Allow only reserved/system plugins to use management capabilities.

        NOTE(review): an identically named helper exists on
        ProviderCapabilityMixin — confirm the duplication is intentional.
        """
        plugin_id = self._require_caller_plugin_id(capability_name)
        plugin = self._plugins.get(plugin_id)
        if plugin is not None and bool(plugin.metadata.get("reserved", False)):
            return plugin_id
        if plugin_id in {"system", "__system__"}:
            return plugin_id
        raise AstrBotError.invalid_input(
            f"{capability_name} is restricted to reserved/system plugins"
        )

    @staticmethod
    def _require_admin_event_context(
        payload: dict[str, Any],
        capability_name: str,
    ) -> None:
        """Reject calls not issued from an admin event context."""
        if bool(payload.get("_caller_is_admin", False)):
            return
        raise AstrBotError.invalid_input(
            f"{capability_name} requires an active admin event context"
        )

    async def _permission_check(
        self,
        _request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        """Report whether user_id is an admin plus its derived role."""
        user_id = self._required_user_id(payload, "permission.check")
        admins = self._admin_ids_snapshot()
        is_admin = user_id in admins
        return {
            "is_admin": is_admin,
            "role": "admin" if is_admin else "member",
        }

    async def _permission_get_admins(
        self,
        _request_id: str,
        _payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        """Return the normalized admin id list."""
        return {"admins": self._admin_ids_snapshot()}

    async def _permission_manager_add_admin(
        self,
        _request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        """Add an admin id; idempotent, gated on reserved plugin + admin context."""
        self._require_reserved_plugin("permission.manager.add_admin")
        self._require_admin_event_context(payload, "permission.manager.add_admin")
        user_id = self._required_user_id(payload, "permission.manager.add_admin")
        admins = self._admin_ids_snapshot()
        if user_id in admins:
            return {"changed": False}
        admins.append(user_id)
        self._permission_admin_ids = admins
        return {"changed": True}

    async def _permission_manager_remove_admin(
        self,
        _request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        """Remove an admin id; idempotent, same gating as add_admin."""
        self._require_reserved_plugin("permission.manager.remove_admin")
        self._require_admin_event_context(payload, "permission.manager.remove_admin")
        user_id = self._required_user_id(payload, "permission.manager.remove_admin")
        admins = self._admin_ids_snapshot()
        if user_id not in admins:
            return {"changed": False}
        admins.remove(user_id)
        self._permission_admin_ids = admins
        return {"changed": True}
    async def _persona_get(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return one persona record; raises invalid_input when missing."""
        persona_id = str(payload.get("persona_id", "")).strip()
        record = self._persona_store.get(persona_id)
        if record is None:
            raise AstrBotError.invalid_input(f"persona not found: {persona_id}")
        return {"persona": dict(record)}

    async def _persona_list(
        self, _request_id: str, _payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List all persona records sorted by persona_id."""
        personas = [
            dict(self._persona_store[persona_id])
            for persona_id in sorted(self._persona_store.keys())
        ]
        return {"personas": personas}

    async def _persona_create(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Create a persona; persona_id must be present and unused."""
        raw_persona = payload.get("persona")
        if not isinstance(raw_persona, dict):
            raise AstrBotError.invalid_input("persona.create requires persona object")
        persona_id = str(raw_persona.get("persona_id", "")).strip()
        if not persona_id:
            raise AstrBotError.invalid_input("persona.create requires persona_id")
        if persona_id in self._persona_store:
            raise AstrBotError.invalid_input(f"persona already exists: {persona_id}")
        now = self._now_iso()
        record = {
            "persona_id": persona_id,
            "system_prompt": str(raw_persona.get("system_prompt", "")),
            "begin_dialogs": self._normalize_persona_dialogs_payload(
                raw_persona.get("begin_dialogs")
            ),
            # tools/skills stay None (meaning "unset") unless given as lists.
            "tools": (
                [str(item) for item in raw_persona.get("tools", [])]
                if isinstance(raw_persona.get("tools"), list)
                else None
            ),
            "skills": (
                [str(item) for item in raw_persona.get("skills", [])]
                if isinstance(raw_persona.get("skills"), list)
                else None
            ),
            "custom_error_message": (
                str(raw_persona.get("custom_error_message"))
                if raw_persona.get("custom_error_message") is not None
                else None
            ),
            "folder_id": (
                str(raw_persona.get("folder_id"))
                if raw_persona.get("folder_id") is not None
                else None
            ),
            "sort_order": int(raw_persona.get("sort_order", 0)),
            "created_at": now,
            "updated_at": now,
        }
        self._persona_store[persona_id] = record
        return {"persona": dict(record)}

    async def _persona_update(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Partially update a persona; only keys present in the payload change.

        Returns ``{"persona": None}`` for an unknown persona_id (unlike get,
        which raises). NOTE(review): folder_id/sort_order are not updatable
        here although create sets them — confirm intended.
        """
        persona_id = str(payload.get("persona_id", "")).strip()
        record = self._persona_store.get(persona_id)
        if record is None:
            return {"persona": None}
        raw_persona = payload.get("persona")
        if not isinstance(raw_persona, dict):
            raise AstrBotError.invalid_input("persona.update requires persona object")
        if (
            "system_prompt" in raw_persona
            and raw_persona.get("system_prompt") is not None
        ):
            record["system_prompt"] = str(raw_persona.get("system_prompt", ""))
        if "begin_dialogs" in raw_persona:
            begin_dialogs = raw_persona.get("begin_dialogs")
            record["begin_dialogs"] = (
                self._normalize_persona_dialogs_payload(begin_dialogs)
                if begin_dialogs is not None
                else []
            )
        if "tools" in raw_persona:
            tools = raw_persona.get("tools")
            record["tools"] = (
                [str(item) for item in tools] if isinstance(tools, list) else None
            )
        if "skills" in raw_persona:
            skills = raw_persona.get("skills")
            record["skills"] = (
                [str(item) for item in skills] if isinstance(skills, list) else None
            )
        if "custom_error_message" in raw_persona:
            custom_error_message = raw_persona.get("custom_error_message")
            record["custom_error_message"] = (
                str(custom_error_message) if custom_error_message is not None else None
            )
        record["updated_at"] = self._now_iso()
        return {"persona": dict(record)}
    async def _platform_send(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record a text send against the mock outbox and return its id."""
        session, target = self._resolve_target(payload)
        self._require_platform_support_for_session("platform.send", session)
        text = str(payload.get("text", ""))
        # Message ids are synthesized from the outbox length (mock bridge).
        message_id = f"msg_{len(self.sent_messages) + 1}"
        sent: dict[str, Any] = {
            "message_id": message_id,
            "session": session,
            "text": text,
        }
        if target is not None:
            sent["target"] = target
        self.sent_messages.append(sent)
        return {"message_id": message_id}

    async def _platform_send_image(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record an image send; mirrors _platform_send with an img_ prefix."""
        session, target = self._resolve_target(payload)
        self._require_platform_support_for_session("platform.send_image", session)
        image_url = str(payload.get("image_url", ""))
        message_id = f"img_{len(self.sent_messages) + 1}"
        sent: dict[str, Any] = {
            "message_id": message_id,
            "session": session,
            "image_url": image_url,
        }
        if target is not None:
            sent["target"] = target
        self.sent_messages.append(sent)
        return {"message_id": message_id}

    async def _platform_send_chain(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Record a message-chain send; chain must be a list of objects."""
        session, target = self._resolve_target(payload)
        self._require_platform_support_for_session("platform.send_chain", session)
        chain = payload.get("chain")
        if not isinstance(chain, list) or not all(
            isinstance(item, dict) for item in chain
        ):
            raise AstrBotError.invalid_input(
                "platform.send_chain 的 chain 必须是 object 数组"
            )
        message_id = f"chain_{len(self.sent_messages) + 1}"
        sent: dict[str, Any] = {
            "message_id": message_id,
            "session": session,
            "chain": [dict(item) for item in chain],
        }
        if target is not None:
            sent["target"] = target
        self.sent_messages.append(sent)
        return {"message_id": message_id}

    async def _platform_send_by_session(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Proactively send a chain to an explicit session string."""
        chain = payload.get("chain")
        if not isinstance(chain, list) or not all(
            isinstance(item, dict) for item in chain
        ):
            raise AstrBotError.invalid_input(
                "platform.send_by_session 的 chain 必须是 object 数组"
            )
        session = str(payload.get("session", ""))
        self._require_platform_support_for_session("platform.send_by_session", session)
        message_id = f"proactive_{len(self.sent_messages) + 1}"
        self.sent_messages.append(
            {
                "message_id": message_id,
                "session": session,
                "chain": [dict(item) for item in chain],
            }
        )
        return {"message_id": message_id}

    async def _platform_get_group(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the mock group payload for the resolved session."""
        session, _target = self._resolve_target(payload)
        return {"group": self._mock_group_payload(session)}

    async def _platform_get_members(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the group's member list, or [] when there is no group."""
        session, _target = self._resolve_target(payload)
        group = self._mock_group_payload(session)
        if group is None:
            return {"members": []}
        return {"members": list(group.get("members", []))}

    async def _platform_list_instances(
        self, _request_id: str, _payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """List platform instances filtered to those the caller supports."""
        plugin_id = self._require_caller_plugin_id("platform.list_instances")
        return {
            "platforms": [
                {
                    "id": str(item.get("id", "")),
                    "name": str(item.get("name", "")),
                    "type": str(item.get("type", "")),
                    "status": str(item.get("status", "unknown")),
                }
                for item in self.get_platform_instances()
                if isinstance(item, dict)
                and self._plugin_supports_platform(plugin_id, str(item.get("type", "")))
            ]
        }

    def _register_platform_capabilities(self) -> None:
        """Register the user-facing platform.* capabilities."""
        self.register(
            self._builtin_descriptor("platform.send", "发送消息"),
            call_handler=self._platform_send,
        )
        self.register(
            self._builtin_descriptor("platform.send_image", "发送图片"),
            call_handler=self._platform_send_image,
        )
        self.register(
            self._builtin_descriptor("platform.send_chain", "发送消息链"),
            call_handler=self._platform_send_chain,
        )
        self.register(
            self._builtin_descriptor(
                "platform.send_by_session", "按会话主动发送消息链"
            ),
            call_handler=self._platform_send_by_session,
        )
        self.register(
            self._builtin_descriptor("platform.get_group", "获取当前群信息"),
            call_handler=self._platform_get_group,
        )
        self.register(
            self._builtin_descriptor("platform.get_members", "获取群成员"),
            call_handler=self._platform_get_members,
        )
        self.register(
            self._builtin_descriptor("platform.list_instances", "列出平台实例元信息"),
            call_handler=self._platform_list_instances,
        )
    async def _platform_manager_get_by_id(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Reserved-plugin-only: fetch one platform snapshot by id (or None)."""
        self._require_reserved_plugin("platform.manager.get_by_id")
        platform_id = str(payload.get("platform_id", "")).strip()
        platform = next(
            (
                dict(item)
                for item in self._platform_instances
                if str(item.get("id", "")) == platform_id
            ),
            None,
        )
        return {"platform": platform}

    async def _platform_manager_clear_errors(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Reserved-plugin-only: wipe a platform's errors, reviving it if errored."""
        self._require_reserved_plugin("platform.manager.clear_errors")
        platform_id = str(payload.get("platform_id", "")).strip()
        for item in self._platform_instances:
            if str(item.get("id", "")) != platform_id:
                continue
            item["errors"] = []
            item["last_error"] = None
            # Only flip an errored instance back; other statuses are kept.
            if str(item.get("status", "")) == "error":
                item["status"] = "running"
            break
        return {}

    async def _platform_manager_get_stats(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Reserved-plugin-only: return stored stats or a synthesized snapshot.

        ``stats`` is None when the platform_id is unknown.
        """
        self._require_reserved_plugin("platform.manager.get_stats")
        platform_id = str(payload.get("platform_id", "")).strip()
        for item in self._platform_instances:
            if str(item.get("id", "")) != platform_id:
                continue
            stats = item.get("stats")
            if isinstance(stats, dict):
                return {"stats": dict(stats)}
            errors = item.get("errors")
            last_error = item.get("last_error")
            meta = item.get("meta")
            return {
                "stats": {
                    "id": platform_id,
                    "type": str(item.get("type", "")),
                    "display_name": str(item.get("name", platform_id)),
                    "status": str(item.get("status", "pending")),
                    "started_at": item.get("started_at"),
                    "error_count": len(errors) if isinstance(errors, list) else 0,
                    "last_error": dict(last_error)
                    if isinstance(last_error, dict)
                    else None,
                    "unified_webhook": bool(item.get("unified_webhook", False)),
                    "meta": dict(meta) if isinstance(meta, dict) else {},
                }
            }
        return {"stats": None}

    def _register_platform_manager_capabilities(self) -> None:
        """Register the reserved platform.manager.* capabilities."""
        self.register(
            self._builtin_descriptor(
                "platform.manager.get_by_id",
                "按 ID 获取平台管理快照",
            ),
            call_handler=self._platform_manager_get_by_id,
        )
        self.register(
            self._builtin_descriptor(
                "platform.manager.clear_errors",
                "清除平台错误",
            ),
            call_handler=self._platform_manager_clear_errors,
        )
        self.register(
            self._builtin_descriptor(
                "platform.manager.get_stats",
                "获取平台统计信息",
            ),
            call_handler=self._platform_manager_get_stats,
        )
    def _provider_payload_by_id(self, provider_id: str) -> dict[str, Any] | None:
        """Search every provider kind for an entry matching *provider_id*."""
        normalized = str(provider_id).strip()
        if not normalized:
            return None
        for items in self._provider_catalog.values():
            for item in items:
                if str(item.get("id", "")) == normalized:
                    return dict(item)
        return None

    @staticmethod
    def _provider_kind_from_type(provider_type: str) -> str:
        """Map a public provider_type string onto an internal catalog kind."""
        mapping = {
            "chat_completion": "chat",
            "text_to_speech": "tts",
            "speech_to_text": "stt",
            "embedding": "embedding",
            "rerank": "rerank",
        }
        normalized = str(provider_type).strip().lower()
        if normalized not in mapping:
            raise AstrBotError.invalid_input(f"unknown provider_type: {provider_type}")
        return mapping[normalized]

    def _provider_config_by_id(self, provider_id: str) -> dict[str, Any] | None:
        """Return a copy of the stored provider config, or None when absent."""
        record = self._provider_configs.get(str(provider_id).strip())
        return dict(record) if isinstance(record, dict) else None

    @staticmethod
    def _managed_provider_record(
        payload: dict[str, Any],
        *,
        loaded: bool,
    ) -> dict[str, Any]:
        """Project a catalog/config payload into the managed-provider shape."""
        return {
            "id": str(payload.get("id", "")),
            "model": (
                str(payload.get("model")) if payload.get("model") is not None else None
            ),
            "type": str(payload.get("type", "")),
            "provider_type": str(payload.get("provider_type", "chat_completion")),
            "loaded": bool(loaded),
            "enabled": bool(payload.get("enable", True)),
            "provider_source_id": (
                str(payload.get("provider_source_id"))
                if payload.get("provider_source_id") is not None
                else None
            ),
        }

    def _managed_provider_record_by_id(self, provider_id: str) -> dict[str, Any] | None:
        """Build a managed record: catalog entry (loaded) merged with config,
        else config alone (not loaded), else None."""
        provider = self._provider_payload_by_id(provider_id)
        if provider is not None:
            config = self._provider_config_by_id(provider_id) or provider
            merged = dict(provider)
            merged.update(
                {
                    "enable": config.get("enable", True),
                    "provider_source_id": config.get("provider_source_id"),
                }
            )
            return self._managed_provider_record(merged, loaded=True)
        config = self._provider_config_by_id(provider_id)
        if config is None:
            return None
        return self._managed_provider_record(config, loaded=False)

    def _emit_provider_change(
        self,
        provider_id: str,
        provider_type: str,
        umo: str | None,
    ) -> None:
        """Fan a provider-change event out to every subscriber queue."""
        event = {
            "provider_id": str(provider_id),
            "provider_type": str(provider_type),
            "umo": str(umo) if umo is not None else None,
        }
        for queue in list(self._provider_change_subscriptions.values()):
            queue.put_nowait(dict(event))

    def _require_reserved_plugin(self, capability_name: str) -> str:
        """Allow only reserved/system plugins.

        NOTE(review): duplicated on PermissionCapabilityMixin — harmless if
        both mixins are combined (identical bodies), but confirm.
        """
        plugin_id = self._require_caller_plugin_id(capability_name)
        plugin = self._plugins.get(plugin_id)
        if plugin is not None and bool(plugin.metadata.get("reserved", False)):
            return plugin_id
        if plugin_id in {"system", "__system__"}:
            return plugin_id
        raise AstrBotError.invalid_input(
            f"{capability_name} is restricted to reserved/system plugins"
        )

    def _provider_entry(
        self,
        payload: dict[str, Any],
        capability_name: str,
        expected_kind: str | None = None,
    ) -> dict[str, Any]:
        """Resolve and validate payload["provider_id"], optionally by kind."""
        provider_id = str(payload.get("provider_id", "")).strip()
        if not provider_id:
            raise AstrBotError.invalid_input(
                f"{capability_name} requires provider_id",
            )
        provider = self._provider_payload_by_id(provider_id)
        if provider is None:
            raise AstrBotError.invalid_input(
                f"{capability_name} unknown provider_id: {provider_id}",
            )
        if (
            expected_kind is not None
            and str(provider.get("provider_type")) != expected_kind
        ):
            raise AstrBotError.invalid_input(
                f"{capability_name} requires a {expected_kind} provider",
            )
        return provider

    async def _provider_get_using(
        self, _request_id: str, payload: dict[str, Any], _token
    ) -> dict[str, Any]:
        """Return the currently active chat provider (payload is unused)."""
        provider_id = self._active_provider_ids.get("chat")
        return {"provider": self._provider_payload("chat", provider_id)}
Any], _token + ) -> dict[str, Any]: + return { + "provider": self._provider_payload_by_id( + str(payload.get("provider_id", "")) + ) + } + + async def _provider_get_current_chat_provider_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + return {"provider_id": self._active_provider_ids.get("chat")} + + def _provider_list_payload(self, kind: str) -> dict[str, Any]: + return { + "providers": [dict(item) for item in self._provider_catalog.get(kind, [])] + } + + async def _provider_list_all( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("chat") + + async def _provider_list_all_tts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("tts") + + async def _provider_list_all_stt( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("stt") + + async def _provider_list_all_embedding( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("embedding") + + async def _provider_list_all_rerank( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + return self._provider_list_payload("rerank") + + async def _provider_get_using_tts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider_id = self._active_provider_ids.get("tts") + return {"provider": self._provider_payload("tts", provider_id)} + + async def _provider_get_using_stt( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider_id = self._active_provider_ids.get("stt") + return {"provider": self._provider_payload("stt", provider_id)} + + async def _provider_stt_get_text( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.stt.get_text", + 
"speech_to_text", + ) + return {"text": f"Mock transcript: {str(payload.get('audio_url', ''))}"} + + async def _provider_tts_get_audio( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.tts.get_audio", + "text_to_speech", + ) + return { + "audio_path": ( + f"mock://tts/{provider.get('id', '')}/{str(payload.get('text', ''))}" + ) + } + + async def _provider_tts_support_stream( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.tts.support_stream", + "text_to_speech", + ) + return {"supported": bool(provider.get("support_stream", True))} + + async def _provider_tts_get_audio_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> StreamExecution: + self._provider_entry( + payload, + "provider.tts.get_audio_stream", + "text_to_speech", + ) + text = payload.get("text") + text_chunks = payload.get("text_chunks") + if isinstance(text, str): + chunks = [text] + elif isinstance(text_chunks, list) and text_chunks: + chunks = [str(item) for item in text_chunks] + else: + raise AstrBotError.invalid_input( + "provider.tts.get_audio_stream requires text or text_chunks" + ) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + for chunk in chunks: + token.raise_if_cancelled() + await asyncio.sleep(0) + yield { + "audio_base64": base64.b64encode( + f"mock-audio:{chunk}".encode() + ).decode("ascii"), + "text": chunk, + } + + return StreamExecution( + iterator=iterator(), + finalize=lambda items: ( + items[-1] if items else {"audio_base64": "", "text": None} + ), + ) + + async def _provider_embedding_get_embedding( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.embedding.get_embedding", + "embedding", + ) + return { + "embedding": _mock_embedding_vector( + str(payload.get("text", "")), + 
provider_id=str(provider.get("id", "")), + ) + } + + async def _provider_embedding_get_embeddings( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + provider = self._provider_entry( + payload, + "provider.embedding.get_embeddings", + "embedding", + ) + texts = payload.get("texts") + if not isinstance(texts, list): + raise AstrBotError.invalid_input( + "provider.embedding.get_embeddings requires texts", + ) + return { + "embeddings": [ + _mock_embedding_vector( + str(text), + provider_id=str(provider.get("id", "")), + ) + for text in texts + ], + } + + async def _provider_embedding_get_dim( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.embedding.get_dim", + "embedding", + ) + return {"dim": _MOCK_EMBEDDING_DIM} + + async def _provider_rerank_rerank( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._provider_entry( + payload, + "provider.rerank.rerank", + "rerank", + ) + documents = payload.get("documents") + if not isinstance(documents, list): + raise AstrBotError.invalid_input( + "provider.rerank.rerank requires documents", + ) + scored = [ + { + "index": index, + "score": 1.0, + "document": str(raw_document), + } + for index, raw_document in enumerate(documents) + ] + top_n = payload.get("top_n") + if top_n is not None: + scored = scored[: max(int(top_n), 0)] + return {"results": scored} + + async def _provider_manager_set( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.set") + provider_id = str(payload.get("provider_id", "")).strip() + provider_type = str(payload.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.set requires provider_id" + ) + if self._provider_payload(kind, provider_id) is None: + raise 
AstrBotError.invalid_input( + f"provider.manager.set unknown provider_id: {provider_id}" + ) + self._active_provider_ids[kind] = provider_id + self._emit_provider_change( + provider_id, + provider_type, + str(payload.get("umo")) if payload.get("umo") is not None else None, + ) + return {} + + async def _provider_manager_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_by_id") + return { + "provider": self._managed_provider_record_by_id( + str(payload.get("provider_id", "")) + ) + } + + async def _provider_manager_get_merged_provider_config( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_merged_provider_config") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config requires provider_id" + ) + provider = self._provider_payload_by_id(provider_id) + config = self._provider_config_by_id(provider_id) + if provider is None and config is None: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config " + f"unknown provider_id: {provider_id}" + ) + if provider is None: + return {"config": dict(config) if isinstance(config, dict) else config} + if config is None: + return {"config": dict(provider)} + merged_config = dict(provider) + merged_config.update(config) + return {"config": merged_config} + + @staticmethod + def _normalize_provider_config_object( + payload: Any, + capability_name: str, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires {field_name} object" + ) + return dict(payload) + + async def _provider_manager_load( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.load") + 
provider_config = self._normalize_provider_config_object( + payload.get("provider_config"), + "provider.manager.load", + "provider_config", + ) + provider_id = str(provider_config.get("id", "")).strip() + provider_type = str(provider_config.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.load requires provider id" + ) + if bool(provider_config.get("enable", True)): + record = { + "id": provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": str(provider_config.get("type", "")), + "provider_type": provider_type, + } + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + self._provider_catalog[kind].append(record) + self._emit_provider_change(provider_id, provider_type, None) + return { + "provider": self._managed_provider_record( + provider_config, + loaded=bool(provider_config.get("enable", True)), + ) + } + + async def _provider_manager_terminate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.terminate") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.terminate requires provider_id" + ) + managed = self._managed_provider_record_by_id(provider_id) + if managed is None: + raise AstrBotError.invalid_input( + f"provider.manager.terminate unknown provider_id: {provider_id}" + ) + kind = self._provider_kind_from_type(str(managed.get("provider_type", ""))) + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + if self._active_provider_ids.get(kind) == provider_id: + catalog = self._provider_catalog.get(kind, []) + 
self._active_provider_ids[kind] = ( + str(catalog[0].get("id")) if catalog else None + ) + self._emit_provider_change( + provider_id, str(managed.get("provider_type", "")), None + ) + return {} + + async def _provider_manager_create( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.create") + provider_config = self._normalize_provider_config_object( + payload.get("provider_config"), + "provider.manager.create", + "provider_config", + ) + provider_id = str(provider_config.get("id", "")).strip() + provider_type = str(provider_config.get("provider_type", "")).strip() + kind = self._provider_kind_from_type(provider_type) + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.create requires provider id" + ) + self._provider_configs[provider_id] = dict(provider_config) + if bool(provider_config.get("enable", True)): + self._provider_catalog[kind] = [ + item + for item in self._provider_catalog.get(kind, []) + if str(item.get("id", "")) != provider_id + ] + self._provider_catalog[kind].append( + { + "id": provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": str(provider_config.get("type", "")), + "provider_type": provider_type, + } + ) + self._emit_provider_change(provider_id, provider_type, None) + return {"provider": self._managed_provider_record_by_id(provider_id)} + + async def _provider_manager_update( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.update") + origin_provider_id = str(payload.get("origin_provider_id", "")).strip() + new_config = self._normalize_provider_config_object( + payload.get("new_config"), + "provider.manager.update", + "new_config", + ) + if not origin_provider_id: + raise AstrBotError.invalid_input( + "provider.manager.update requires origin_provider_id" + ) + current = 
self._provider_config_by_id(origin_provider_id) + if current is None: + current = self._managed_provider_record_by_id(origin_provider_id) + if current is None: + raise AstrBotError.invalid_input( + f"provider.manager.update unknown provider_id: {origin_provider_id}" + ) + target_provider_id = str(new_config.get("id") or origin_provider_id).strip() + provider_type = str( + new_config.get("provider_type") or current.get("provider_type", "") + ).strip() + kind = self._provider_kind_from_type(provider_type) + self._provider_configs.pop(origin_provider_id, None) + merged = dict(current) + merged.update(new_config) + merged["id"] = target_provider_id + merged["provider_type"] = provider_type + self._provider_configs[target_provider_id] = merged + for catalog_kind, items in list(self._provider_catalog.items()): + self._provider_catalog[catalog_kind] = [ + item for item in items if str(item.get("id", "")) != origin_provider_id + ] + if bool(merged.get("enable", True)): + self._provider_catalog[kind].append( + { + "id": target_provider_id, + "model": ( + str(merged.get("model")) + if merged.get("model") is not None + else None + ), + "type": str(merged.get("type", "")), + "provider_type": provider_type, + } + ) + for active_kind, active_id in list(self._active_provider_ids.items()): + if active_id == origin_provider_id: + self._active_provider_ids[active_kind] = ( + target_provider_id if active_kind == kind else None + ) + self._emit_provider_change(target_provider_id, provider_type, None) + return {"provider": self._managed_provider_record_by_id(target_provider_id)} + + async def _provider_manager_delete( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.delete") + provider_id = ( + str(payload.get("provider_id")).strip() + if payload.get("provider_id") is not None + else None + ) + provider_source_id = ( + str(payload.get("provider_source_id")).strip() + if payload.get("provider_source_id") 
is not None + else None + ) + if not provider_id and not provider_source_id: + raise AstrBotError.invalid_input( + "provider.manager.delete requires provider_id or provider_source_id" + ) + deleted: list[dict[str, Any]] = [] + if provider_id: + record = self._managed_provider_record_by_id(provider_id) + if record is not None: + deleted.append(record) + self._provider_configs.pop(provider_id, None) + else: + for record_id, record in list(self._provider_configs.items()): + if ( + str(record.get("provider_source_id", "")).strip() + != provider_source_id + ): + continue + deleted_record = self._managed_provider_record_by_id(record_id) + if deleted_record is not None: + deleted.append(deleted_record) + self._provider_configs.pop(record_id, None) + deleted_ids = {str(item.get("id", "")) for item in deleted} + for kind, items in list(self._provider_catalog.items()): + self._provider_catalog[kind] = [ + item for item in items if str(item.get("id", "")) not in deleted_ids + ] + if self._active_provider_ids.get(kind) in deleted_ids: + catalog = self._provider_catalog.get(kind, []) + self._active_provider_ids[kind] = ( + str(catalog[0].get("id")) if catalog else None + ) + for record in deleted: + self._emit_provider_change( + str(record.get("id", "")), + str(record.get("provider_type", "")), + None, + ) + return {} + + async def _provider_manager_get_insts( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("provider.manager.get_insts") + return { + "providers": [ + self._managed_provider_record(item, loaded=True) + for item in self._provider_catalog.get("chat", []) + ] + } + + async def _provider_manager_watch_changes( + self, request_id: str, _payload: dict[str, Any], _token + ) -> StreamExecution: + self._require_reserved_plugin("provider.manager.watch_changes") + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + self._provider_change_subscriptions[request_id] = queue + + async def iterator() -> 
AsyncIterator[dict[str, Any]]: + try: + while True: + yield await queue.get() + finally: + self._provider_change_subscriptions.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + async def _platform_manager_get_by_id( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_by_id") + platform_id = str(payload.get("platform_id", "")).strip() + platform = next( + ( + dict(item) + for item in self._platform_instances + if str(item.get("id", "")) == platform_id + ), + None, + ) + return {"platform": platform} + + async def _platform_manager_clear_errors( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.clear_errors") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + item["errors"] = [] + item["last_error"] = None + if str(item.get("status", "")) == "error": + item["status"] = "running" + break + return {} + + async def _platform_manager_get_stats( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self._require_reserved_plugin("platform.manager.get_stats") + platform_id = str(payload.get("platform_id", "")).strip() + for item in self._platform_instances: + if str(item.get("id", "")) != platform_id: + continue + stats = item.get("stats") + if isinstance(stats, dict): + return {"stats": dict(stats)} + errors = item.get("errors") + last_error = item.get("last_error") + meta = item.get("meta") + return { + "stats": { + "id": platform_id, + "type": str(item.get("type", "")), + "display_name": str(item.get("name", platform_id)), + "status": str(item.get("status", "pending")), + "started_at": item.get("started_at"), + "error_count": len(errors) if isinstance(errors, list) else 0, + "last_error": 
dict(last_error) + if isinstance(last_error, dict) + else None, + "unified_webhook": bool(item.get("unified_webhook", False)), + "meta": dict(meta) if isinstance(meta, dict) else {}, + } + } + return {"stats": None} + + async def _llm_tool_manager_get( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.get") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"registered": [], "active": []} + registered = [dict(item) for item in plugin.llm_tools.values()] + active = [ + dict(item) + for name, item in plugin.llm_tools.items() + if name in plugin.active_llm_tools + ] + return {"registered": registered, "active": active} + + async def _llm_tool_manager_activate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.activate") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"activated": False} + name = str(payload.get("name", "")) + spec = plugin.llm_tools.get(name) + if spec is None: + return {"activated": False} + spec["active"] = True + plugin.active_llm_tools.add(name) + return {"activated": True} + + async def _llm_tool_manager_deactivate( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.deactivate") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"deactivated": False} + name = str(payload.get("name", "")) + spec = plugin.llm_tools.get(name) + if spec is None: + return {"deactivated": False} + spec["active"] = False + plugin.active_llm_tools.discard(name) + return {"deactivated": True} + + async def _llm_tool_manager_add( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.add") + plugin = self._plugins.get(plugin_id) + if plugin is 
None: + return {"names": []} + tools_payload = payload.get("tools") + if not isinstance(tools_payload, list): + raise AstrBotError.invalid_input("llm_tool.manager.add 的 tools 必须是数组") + names: list[str] = [] + for item in tools_payload: + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if not name: + continue + plugin.llm_tools[name] = dict(item) + if bool(item.get("active", True)): + plugin.active_llm_tools.add(name) + else: + plugin.active_llm_tools.discard(name) + names.append(name) + return {"names": names} + + async def _llm_tool_manager_remove( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("llm_tool.manager.remove") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"removed": False} + name = str(payload.get("name", "")).strip() + removed = plugin.llm_tools.pop(name, None) is not None + plugin.active_llm_tools.discard(name) + return {"removed": removed} + + async def _agent_registry_list( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.registry.list") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"agents": []} + return {"agents": [dict(item) for item in plugin.agents.values()]} + + async def _agent_registry_get( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.registry.get") + plugin = self._plugins.get(plugin_id) + if plugin is None: + return {"agent": None} + agent = plugin.agents.get(str(payload.get("name", ""))) + return {"agent": dict(agent) if isinstance(agent, dict) else None} + + async def _agent_tool_loop_run( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("agent.tool_loop.run") + plugin = self._plugins.get(plugin_id) + requested_tools = 
payload.get("tool_names") + active_tools: list[str] = [] + if plugin is not None: + local_tools = self._active_local_mcp_tool_names(plugin) + if isinstance(requested_tools, list) and requested_tools: + active_tools = [ + name + for name in (str(item) for item in requested_tools) + if name in plugin.active_llm_tools or name in local_tools + ] + else: + active_tools = sorted([*plugin.active_llm_tools, *local_tools]) + prompt = str(payload.get("prompt", "") or "") + suffix = "" + if active_tools: + suffix = f" tools={','.join(active_tools)}" + return { + "text": f"Mock tool loop: {prompt}{suffix}".strip(), + "usage": { + "input_tokens": len(prompt), + "output_tokens": len(prompt) + len(suffix), + }, + "finish_reason": "stop", + "tool_calls": [], + "role": "assistant", + "reasoning_content": None, + "reasoning_signature": None, + } + + def _register_provider_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.get_using", "获取当前聊天 Provider"), + call_handler=self._provider_get_using, + ) + self.register( + self._builtin_descriptor("provider.get_by_id", "按 ID 获取 Provider"), + call_handler=self._provider_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.get_current_chat_provider_id", + "获取当前聊天 Provider ID", + ), + call_handler=self._provider_get_current_chat_provider_id, + ) + self.register( + self._builtin_descriptor("provider.list_all", "列出聊天 Providers"), + call_handler=self._provider_list_all, + ) + self.register( + self._builtin_descriptor("provider.list_all_tts", "列出 TTS Providers"), + call_handler=self._provider_list_all_tts, + ) + self.register( + self._builtin_descriptor("provider.list_all_stt", "列出 STT Providers"), + call_handler=self._provider_list_all_stt, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_embedding", + "列出 Embedding Providers", + ), + call_handler=self._provider_list_all_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_rerank", + "列出 Rerank 
Providers", + ), + call_handler=self._provider_list_all_rerank, + ) + self.register( + self._builtin_descriptor("provider.get_using_tts", "获取当前 TTS Provider"), + call_handler=self._provider_get_using_tts, + ) + self.register( + self._builtin_descriptor("provider.get_using_stt", "获取当前 STT Provider"), + call_handler=self._provider_get_using_stt, + ) + self.register( + self._builtin_descriptor("provider.stt.get_text", "STT 转写"), + call_handler=self._provider_stt_get_text, + ) + self.register( + self._builtin_descriptor("provider.tts.get_audio", "TTS 合成音频"), + call_handler=self._provider_tts_get_audio, + ) + self.register( + self._builtin_descriptor( + "provider.tts.support_stream", + "检查 TTS 流式支持", + ), + call_handler=self._provider_tts_support_stream, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio_stream", + "流式 TTS 音频输出", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_tts_get_audio_stream, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embedding", + "获取单条向量", + ), + call_handler=self._provider_embedding_get_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embeddings", + "批量获取向量", + ), + call_handler=self._provider_embedding_get_embeddings, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_dim", + "获取向量维度", + ), + call_handler=self._provider_embedding_get_dim, + ) + self.register( + self._builtin_descriptor("provider.rerank.rerank", "文档重排序"), + call_handler=self._provider_rerank_rerank, + ) + + def _register_provider_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.manager.set", "设置当前 Provider"), + call_handler=self._provider_manager_set, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_by_id", + "按 ID 获取 Provider 管理记录", + ), + call_handler=self._provider_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + 
"provider.manager.get_merged_provider_config", + "获取 Provider 合并配置", + ), + call_handler=self._provider_manager_get_merged_provider_config, + ) + self.register( + self._builtin_descriptor("provider.manager.load", "运行时加载 Provider"), + call_handler=self._provider_manager_load, + ) + self.register( + self._builtin_descriptor( + "provider.manager.terminate", + "终止已加载的 Provider", + ), + call_handler=self._provider_manager_terminate, + ) + self.register( + self._builtin_descriptor("provider.manager.create", "创建 Provider"), + call_handler=self._provider_manager_create, + ) + self.register( + self._builtin_descriptor("provider.manager.update", "更新 Provider"), + call_handler=self._provider_manager_update, + ) + self.register( + self._builtin_descriptor("provider.manager.delete", "删除 Provider"), + call_handler=self._provider_manager_delete, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_insts", + "列出已加载聊天 Provider", + ), + call_handler=self._provider_manager_get_insts, + ) + self.register( + self._builtin_descriptor( + "provider.manager.watch_changes", + "订阅 Provider 变更", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_manager_watch_changes, + ) + + def _register_agent_tool_capabilities(self) -> None: + self.register( + self._builtin_descriptor("llm_tool.manager.get", "获取 LLM 工具状态"), + call_handler=self._llm_tool_manager_get, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.activate", "激活 LLM 工具"), + call_handler=self._llm_tool_manager_activate, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.deactivate", "停用 LLM 工具"), + call_handler=self._llm_tool_manager_deactivate, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.add", "动态添加 LLM 工具"), + call_handler=self._llm_tool_manager_add, + ) + self.register( + self._builtin_descriptor("llm_tool.manager.remove", "动态移除 LLM 工具"), + call_handler=self._llm_tool_manager_remove, + ) + self.register( + 
self._builtin_descriptor("agent.tool_loop.run", "运行 mock tool loop"), + call_handler=self._agent_tool_loop_run, + ) + self.register( + self._builtin_descriptor("agent.registry.list", "列出 Agent 元数据"), + call_handler=self._agent_registry_list, + ) + self.register( + self._builtin_descriptor("agent.registry.get", "获取 Agent 元数据"), + call_handler=self._agent_registry_get, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py new file mode 100644 index 0000000000..e56f979e9e --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/session.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class SessionCapabilityMixin(CapabilityRouterBridgeBase): + async def _session_plugin_is_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + plugin_name = str(payload.get("plugin_name", "")) + config = self._session_plugin_config(session) + enabled_plugins = { + str(item) for item in config.get("enabled_plugins", []) if str(item).strip() + } + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + if plugin_name in enabled_plugins: + return {"enabled": True} + return {"enabled": plugin_name not in disabled_plugins} + + async def _session_plugin_filter_handlers( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + handlers = payload.get("handlers") + if not isinstance(handlers, list): + raise AstrBotError.invalid_input( + "session.plugin.filter_handlers 的 handlers 必须是 object 数组" + ) + disabled_plugins = { + str(item) + for item in 
self._session_plugin_config(session).get("disabled_plugins", []) + if str(item).strip() + } + reserved_plugins = { + str(plugin.metadata.get("name", "")) + for plugin in self._plugins.values() + if bool(plugin.metadata.get("reserved", False)) + } + filtered = [] + for item in handlers: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_name", "")) + if ( + plugin_name + and plugin_name in disabled_plugins + and plugin_name not in reserved_plugins + ): + continue + filtered.append(dict(item)) + return {"handlers": filtered} + + async def _session_service_is_llm_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + return {"enabled": bool(config.get("llm_enabled", True))} + + async def _session_service_set_llm_status( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + config["llm_enabled"] = bool(payload.get("enabled", False)) + self._session_service_configs[session] = config + return {} + + async def _session_service_is_tts_enabled( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + return {"enabled": bool(config.get("tts_enabled", True))} + + async def _session_service_set_tts_status( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + session = str(payload.get("session", "")) + config = self._session_service_config(session) + config["tts_enabled"] = bool(payload.get("enabled", False)) + self._session_service_configs[session] = config + return {} + + def _register_session_capabilities(self) -> None: + self.register( + self._builtin_descriptor("session.plugin.is_enabled", "获取会话级插件开关"), + call_handler=self._session_plugin_is_enabled, 
+ ) + self.register( + self._builtin_descriptor( + "session.plugin.filter_handlers", + "按会话过滤 handler 元数据", + ), + call_handler=self._session_plugin_filter_handlers, + ) + self.register( + self._builtin_descriptor( + "session.service.is_llm_enabled", + "获取会话级 LLM 开关", + ), + call_handler=self._session_service_is_llm_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_llm_status", + "写入会话级 LLM 开关", + ), + call_handler=self._session_service_set_llm_status, + ) + self.register( + self._builtin_descriptor( + "session.service.is_tts_enabled", + "获取会话级 TTS 开关", + ), + call_handler=self._session_service_is_tts_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_tts_status", + "写入会话级 TTS 开关", + ), + call_handler=self._session_service_set_tts_status, + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py new file mode 100644 index 0000000000..942f696989 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/skill.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import CapabilityRouterBridgeBase + + +class SkillCapabilityMixin(CapabilityRouterBridgeBase): + def _register_skill_capabilities(self) -> None: + self.register( + self._builtin_descriptor("skill.register", "注册插件 skill"), + call_handler=self._skill_register, + ) + self.register( + self._builtin_descriptor("skill.unregister", "注销插件 skill"), + call_handler=self._skill_unregister, + ) + self.register( + self._builtin_descriptor("skill.list", "列出插件 skill"), + call_handler=self._skill_list, + ) + + async def _skill_register( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, str]: + plugin_id = self._require_caller_plugin_id("skill.register") + 
plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + + skill_name = str(payload.get("name", "")).strip() + if not skill_name: + raise AstrBotError.invalid_input("skill.register requires name") + skill_path = str(payload.get("path", "")).strip() + if not skill_path: + raise AstrBotError.invalid_input("skill.register requires path") + + path_obj = Path(skill_path) + skill_dir = path_obj.parent if path_obj.name == "SKILL.md" else path_obj + + entry = { + "name": skill_name, + "description": str(payload.get("description", "") or ""), + "path": skill_path, + "skill_dir": str(skill_dir), + } + plugin.skills[skill_name] = entry + return dict(entry) + + async def _skill_unregister( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, bool]: + plugin_id = self._require_caller_plugin_id("skill.unregister") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + removed = ( + plugin.skills.pop(str(payload.get("name", "")).strip(), None) is not None + ) + return {"removed": removed} + + async def _skill_list( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, list[dict[str, str]]]: + plugin_id = self._require_caller_plugin_id("skill.list") + plugin = self._plugins.get(plugin_id) + if plugin is None: + raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}") + return { + "skills": [ + dict(plugin.skills[name]) for name in sorted(plugin.skills.keys()) + ] + } diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py new file mode 100644 index 0000000000..12012e5699 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_capability_router_builtins/capabilities/system.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import 
json +import uuid +from typing import Any + +from ....errors import AstrBotError +from ..bridge_base import ( + CapabilityRouterBridgeBase, + _clone_chain_payload, + _clone_target_payload, +) + + +class SystemCapabilityMixin(CapabilityRouterBridgeBase): + @staticmethod + def _overlay_request_id(request_id: str, payload: dict[str, Any]) -> str: + scope_request_id = payload.get("_request_scope_id") + if isinstance(scope_request_id, str) and scope_request_id.strip(): + return scope_request_id + return request_id + + def _register_system_capabilities(self) -> None: + self.register( + self._builtin_descriptor("system.get_data_dir", "获取插件数据目录"), + call_handler=self._system_get_data_dir, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.text_to_image", "文本转图片"), + call_handler=self._system_text_to_image, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.html_render", "渲染 HTML 模板"), + call_handler=self._system_html_render, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.register", "注册文件令牌"), + call_handler=self._system_file_register, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.handle", "解析文件令牌"), + call_handler=self._system_file_handle, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.register", + "注册会话等待器", + ), + call_handler=self._system_session_waiter_register, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.unregister", + "注销会话等待器", + ), + call_handler=self._system_session_waiter_unregister, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.react", "发送事件表情回应"), + call_handler=self._system_event_react, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.send_typing", "发送输入中状态"), + call_handler=self._system_event_send_typing, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming", 
+ "发送事件流式消息", + ), + call_handler=self._system_event_send_streaming, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_chunk", + "推送事件流式消息分片", + ), + call_handler=self._system_event_send_streaming_chunk, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_close", + "关闭事件流式消息会话", + ), + call_handler=self._system_event_send_streaming_close, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.get_state", + "读取当前请求的默认 LLM 状态", + ), + call_handler=self._system_event_llm_get_state, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.request", + "请求当前事件继续进入默认 LLM 链路", + ), + call_handler=self._system_event_llm_request, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.get", "读取当前请求结果"), + call_handler=self._system_event_result_get, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.set", "写入当前请求结果"), + call_handler=self._system_event_result_set, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.result.clear", "清理当前请求结果"), + call_handler=self._system_event_result_clear, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.get", + "读取当前请求 handler 白名单", + ), + call_handler=self._system_event_handler_whitelist_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.set", + "写入当前请求 handler 白名单", + ), + call_handler=self._system_event_handler_whitelist_set, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "registry.get_handlers_by_event_type", + "按事件类型列出 handler 元数据", + ), + call_handler=self._registry_get_handlers_by_event_type, + ) + self.register( + self._builtin_descriptor( + "registry.get_handler_by_full_name", + "按 full name 查询 handler 元数据", + ), + call_handler=self._registry_get_handler_by_full_name, + 
) + self.register( + self._builtin_descriptor( + "registry.command.register", + "注册动态命令路由", + ), + call_handler=self._registry_command_register, + ) + + def _ensure_request_overlay(self, request_id: str) -> dict[str, Any]: + overlay = self._request_overlays.get(request_id) + if overlay is None: + overlay = { + "should_call_llm": False, + "requested_llm": False, + "result": None, + "handler_whitelist": None, + } + self._request_overlays[request_id] = overlay + return overlay + + async def _system_get_data_dir( + self, _request_id: str, _payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.get_data_dir") + data_dir = self._plugin_data_dir( + plugin_id, + capability_name="system.get_data_dir", + ) + data_dir.mkdir(parents=True, exist_ok=True) + return {"path": str(data_dir)} + + async def _system_text_to_image( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + text = str(payload.get("text", "")) + if bool(payload.get("return_url", True)): + return {"result": f"mock://text_to_image/{text}"} + return {"result": f"{text}"} + + async def _system_html_render( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + tmpl = str(payload.get("tmpl", "")) + data = payload.get("data") + if not isinstance(data, dict): + raise AstrBotError.invalid_input("system.html_render requires object data") + if bool(payload.get("return_url", True)): + return {"result": f"mock://html_render/{tmpl}"} + return {"result": json.dumps({"tmpl": tmpl, "data": data}, ensure_ascii=False)} + + async def _system_file_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + path = str(payload.get("path", "")).strip() + if not path: + raise AstrBotError.invalid_input("system.file.register requires path") + file_token = uuid.uuid4().hex + self._file_token_store[file_token] = path + return {"token": file_token, "url": f"mock://file/{file_token}"} + + async def 
_system_file_handle( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + file_token = str(payload.get("token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input("system.file.handle requires token") + path = self._file_token_store.pop(file_token, None) + if path is None: + raise AstrBotError.invalid_input(f"Unknown file token: {file_token}") + return {"path": path} + + async def _system_event_llm_get_state( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + return { + "should_call_llm": bool(overlay["should_call_llm"]), + "requested_llm": bool(overlay["requested_llm"]), + } + + async def _system_event_llm_request( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._ensure_request_overlay(overlay_request_id) + overlay["requested_llm"] = True + overlay["should_call_llm"] = True + return await self._system_event_llm_get_state( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _system_event_result_get( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + result = overlay.get("result") + return {"result": dict(result) if isinstance(result, dict) else None} + + async def _system_event_result_set( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + result = payload.get("result") + if not isinstance(result, dict): + raise AstrBotError.invalid_input( + "system.event.result.set 的 result 必须是 object" + ) + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + overlay["result"] = dict(result) + return {"result": dict(result)} + + async def _system_event_result_clear( + 
self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + overlay["result"] = None + return {} + + async def _system_event_handler_whitelist_get( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay = self._ensure_request_overlay( + self._overlay_request_id(request_id, payload) + ) + whitelist = overlay.get("handler_whitelist") + if whitelist is None: + return {"plugin_names": None} + return {"plugin_names": sorted(str(item) for item in whitelist)} + + async def _system_event_handler_whitelist_set( + self, request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._ensure_request_overlay(overlay_request_id) + plugin_names_payload = payload.get("plugin_names") + if plugin_names_payload is None: + overlay["handler_whitelist"] = None + elif isinstance(plugin_names_payload, list): + overlay["handler_whitelist"] = { + str(item) for item in plugin_names_payload if str(item).strip() + } + else: + raise AstrBotError.invalid_input( + "system.event.handler_whitelist.set 的 plugin_names 必须是数组或 null" + ) + return await self._system_event_handler_whitelist_get( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _registry_get_handlers_by_event_type( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + event_type = str(payload.get("event_type", "")).strip() + handlers: list[dict[str, Any]] = [] + for plugin in self._plugins.values(): + handlers.extend( + [ + dict(handler) + for handler in plugin.handlers + if event_type in handler.get("event_types", []) + ] + ) + if event_type == "message": + for plugin_name, routes in self._dynamic_command_routes.items(): + for route in routes: + if not isinstance(route, dict): + continue + handlers.append( + { + "plugin_name": 
str(route.get("plugin_name", plugin_name)), + "handler_full_name": str( + route.get("handler_full_name", "") + ), + "trigger_type": ( + "message" + if bool(route.get("use_regex", False)) + else "command" + ), + "description": ( + None + if route.get("desc") is None + else str(route.get("desc", "")).strip() or None + ), + "event_types": ["message"], + "enabled": True, + "group_path": [], + "priority": int(route.get("priority", 0) or 0), + "kind": "handler", + "require_admin": False, + "required_role": None, + } + ) + return {"handlers": handlers} + + async def _registry_get_handler_by_full_name( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + full_name = str(payload.get("full_name", "")).strip() + for plugin in self._plugins.values(): + for handler in plugin.handlers: + if handler.get("handler_full_name") == full_name: + return {"handler": dict(handler)} + return {"handler": None} + + async def _registry_command_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + source_event_type = str(payload.get("source_event_type", "")).strip() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if bool(payload.get("ignore_prefix", False)): + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + priority_value = payload.get("priority", 0) + if isinstance(priority_value, bool) or not isinstance(priority_value, int): + raise AstrBotError.invalid_input( + "registry.command.register 的 priority 必须是 integer" + ) + plugin_id = self._require_caller_plugin_id("registry.command.register") + self.register_dynamic_command_route( + plugin_id=plugin_id, + command_name=str(payload.get("command_name", "")), + handler_full_name=str(payload.get("handler_full_name", "")), + desc=str(payload.get("desc", "")), + 
priority=priority_value, + use_regex=bool(payload.get("use_regex", False)), + ) + return {} + + async def _system_session_waiter_register( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.session_waiter.register") + session_key = str(payload.get("session_key", "")).strip() + if not session_key: + raise AstrBotError.invalid_input( + "system.session_waiter.register requires session_key" + ) + self._session_waiters.setdefault(plugin_id, set()).add(session_key) + return {} + + async def _system_session_waiter_unregister( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + plugin_id = self._require_caller_plugin_id("system.session_waiter.unregister") + session_key = str(payload.get("session_key", "")).strip() + plugin_waiters = self._session_waiters.get(plugin_id) + if plugin_waiters is None: + return {} + plugin_waiters.discard(session_key) + if not plugin_waiters: + self._session_waiters.pop(plugin_id, None) + return {} + + async def _system_event_react( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self.event_actions.append( + { + "action": "react", + "emoji": str(payload.get("emoji", "")), + "target": _clone_target_payload(payload.get("target")), + } + ) + return {"supported": True} + + async def _system_event_send_typing( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + self.event_actions.append( + { + "action": "send_typing", + "target": _clone_target_payload(payload.get("target")), + } + ) + return {"supported": True} + + async def _system_event_send_streaming( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream_id = f"mock-stream-{len(self._event_streams) + 1}" + stream_state: dict[str, Any] = { + "target": _clone_target_payload(payload.get("target")), + "chunks": [], + "use_fallback": bool(payload.get("use_fallback", False)), + } + 
self._event_streams[stream_id] = stream_state + return {"supported": True, "stream_id": stream_id} + + async def _system_event_send_streaming_chunk( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream = self._event_streams.get(str(payload.get("stream_id", ""))) + if stream is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + chain = payload.get("chain") + if not isinstance(chain, list): + raise AstrBotError.invalid_input( + "system.event.send_streaming_chunk requires a chain array" + ) + stream["chunks"].append({"chain": _clone_chain_payload(chain)}) + return {} + + async def _system_event_send_streaming_close( + self, _request_id: str, payload: dict[str, Any], _token + ) -> dict[str, Any]: + stream_id = str(payload.get("stream_id", "")) + stream = self._event_streams.pop(stream_id, None) + if stream is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + self.event_actions.append( + { + "action": "send_streaming", + "target": stream["target"], + "chunks": list(stream["chunks"]), + "use_fallback": bool(stream["use_fallback"]), + } + ) + return {"supported": True} diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py new file mode 100644 index 0000000000..66dfa44f91 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_command_matching.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import re +import shlex +from collections.abc import Sequence +from typing import Any + +from ..protocol.descriptors import ParamSpec + + +def match_command_name(text: str, command_name: str) -> str | None: + normalized = text.strip() + if normalized == command_name: + return "" + if normalized.startswith(f"{command_name} "): + return normalized[len(command_name) :].strip() + return None + + +def build_command_args( + param_specs: Sequence[ParamSpec], remainder: str +) -> dict[str, Any]: + if not 
param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = split_command_remainder(remainder) + values: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + values[spec.name] = " ".join(parts[index:]) + break + values[spec.name] = parts[index] + return values + + +def build_regex_args( + param_specs: Sequence[ParamSpec], match: re.Match[str] +) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + names = [spec.name for spec in param_specs if spec.name not in named] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + +def split_command_remainder(remainder: str) -> list[str]: + if not remainder: + return [] + try: + return shlex.split(remainder) + except ValueError: + return remainder.split() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py new file mode 100644 index 0000000000..40d162d355 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_loader_support.py @@ -0,0 +1,156 @@ +"""Support helpers for runtime loader reflection and signature validation. + +本模块提供运行时加载器所需的反射和签名验证工具函数,主要用于: +1. 解析 handler/capability 函数签名,提取参数类型信息 +2. 识别需要注入的框架对象(如 Context、MessageEvent、ScheduleContext) +3. 构建参数规格 (ParamSpec) 供协议层使用 +4. 
验证 schedule handler 的签名合法性 + +关键函数: +- build_param_specs: 从 handler 签名构建参数规格列表 +- is_injected_parameter: 判断参数是否应由框架注入而非从命令行解析 +- validate_schedule_signature: 确保 schedule handler 只接受允许的注入参数 +""" + +from __future__ import annotations + +import inspect +import typing +from typing import Any, Literal, TypeAlias, cast + +from .._internal.injected_params import is_framework_injected_parameter +from .._internal.typing_utils import unwrap_optional +from ..decorators import get_capability_meta, get_handler_meta +from ..protocol.descriptors import ParamSpec +from ..types import GreedyStr + +ParamTypeName: TypeAlias = Literal[ + "str", "int", "float", "bool", "optional", "greedy_str" +] +OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None + + +def is_injected_parameter(annotation: Any, parameter_name: str) -> bool: + return is_framework_injected_parameter(parameter_name, annotation) + + +def param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]: + normalized, is_optional = unwrap_optional(annotation) + if normalized is GreedyStr: + return "greedy_str", None, False + if normalized in {int, float, bool, str}: + normalized_name = cast( + Literal["str", "int", "float", "bool"], normalized.__name__ + ) + if is_optional: + return "optional", normalized_name, False + return normalized_name, None, True + if is_optional: + return "optional", "str", False + return "str", None, True + + +def build_param_specs(handler: Any) -> list[ParamSpec]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = typing.get_type_hints(handler) + except Exception: + type_hints = {} + + specs: list[ParamSpec] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = type_hints.get(parameter.name) + if is_injected_parameter(annotation, parameter.name): + 
continue + param_type, inner_type, required = param_type_name(annotation) + if parameter.default is not inspect.Parameter.empty: + required = False + specs.append( + ParamSpec( + name=parameter.name, + type=param_type, + required=required, + inner_type=inner_type, + ) + ) + + greedy_indexes = [ + index for index, spec in enumerate(specs) if spec.type == "greedy_str" + ] + if greedy_indexes and greedy_indexes[-1] != len(specs) - 1: + greedy_spec = specs[greedy_indexes[-1]] + raise ValueError(f"参数 '{greedy_spec.name}' (GreedyStr) 必须是最后一个参数。") + return specs + + +def validate_schedule_signature(handler: Any) -> None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return + allowed_names = {"ctx", "context", "sched", "schedule"} + invalid = [ + parameter.name + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.name not in allowed_names + ] + if invalid: + raise ValueError( + "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。" + ) + + +def resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + for candidate in candidates: + meta = get_handler_meta(candidate) + if meta is not None and meta.trigger is not None: + return getattr(instance, name), meta + return None + + +def resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + for candidate in candidates: + meta = get_capability_meta(candidate) + if meta is not None: + return 
getattr(instance, name), meta + return None + + +__all__ = [ + "build_param_specs", + "is_injected_parameter", + "param_type_name", + "resolve_capability_candidate", + "resolve_handler_candidate", + "unwrap_optional", + "validate_schedule_signature", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py new file mode 100644 index 0000000000..29d2671caa --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/_streaming.py @@ -0,0 +1,28 @@ +"""Shared stream execution primitives for runtime internals. + +本模块定义流式执行的通用数据结构 StreamExecution,用于: +1. 封装异步生成器迭代器,支持逐块返回数据 +2. 提供收集完成后的聚合回调 (finalize) +3. 控制是否需要在内存中累积所有分块 + +使用场景: +- LLM 流式对话返回逐字输出 +- DB watch 监听键值变更流 +- 任何需要分块返回而非一次性返回的能力调用 +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class StreamExecution: + iterator: AsyncIterator[dict[str, Any]] + finalize: Callable[[list[dict[str, Any]]], dict[str, Any]] + collect_chunks: bool = True + + +__all__ = ["StreamExecution"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py new file mode 100644 index 0000000000..a08208f912 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/bootstrap.py @@ -0,0 +1,133 @@ +"""启动引导入口。 + +对外提供三个顶层启动函数: + +- ``run_supervisor``: 启动 Supervisor 进程 +- ``run_plugin_worker``: 启动单插件或组 Worker 进程 +- ``run_websocket_server``: 以 WebSocket 方式启动 Worker + +运行时核心类分布在同目录的子模块: + +- ``runtime.supervisor``: ``SupervisorRuntime`` / ``WorkerSession`` +- ``runtime.worker``: ``PluginWorkerRuntime`` / ``GroupWorkerRuntime`` +""" + +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path +from typing import IO + +from .loader import PluginEnvironmentManager +from .supervisor import ( + SupervisorRuntime, + WorkerSession, + _install_signal_handlers, + 
_prepare_stdio_transport, + _sdk_source_dir, + _wait_for_shutdown, +) +from .transport import StdioTransport, WebSocketServerTransport +from .worker import GroupWorkerRuntime, PluginWorkerRuntime + +__all__ = [ + "GroupWorkerRuntime", + "PluginWorkerRuntime", + "SupervisorRuntime", + "WorkerSession", + "_install_signal_handlers", + "_prepare_stdio_transport", + "_sdk_source_dir", + "_wait_for_shutdown", + "run_supervisor", + "run_plugin_worker", + "run_websocket_server", +] + + +async def run_supervisor( + *, + plugins_dir: Path = Path("plugins"), + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, + env_manager: PluginEnvironmentManager | None = None, +) -> None: + transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport( + stdin, + stdout, + ) + transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout) + runtime = SupervisorRuntime( + transport=transport, + plugins_dir=plugins_dir, + env_manager=env_manager, + ) + + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() + if original_stdout is not None: + sys.stdout = original_stdout + + +async def run_plugin_worker( + *, + plugin_dir: Path | None = None, + group_metadata: Path | None = None, + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, +) -> None: + if plugin_dir is None and group_metadata is None: + raise ValueError("plugin_dir or group_metadata is required") + if plugin_dir is not None and group_metadata is not None: + raise ValueError("plugin_dir and group_metadata are mutually exclusive") + + transport_stdin, transport_stdout, original_stdout = _prepare_stdio_transport( + stdin, + stdout, + ) + transport = StdioTransport(stdin=transport_stdin, stdout=transport_stdout) + if group_metadata is not None: + runtime = GroupWorkerRuntime( + group_metadata_path=group_metadata, + transport=transport, + ) + else: + # 
前置互斥校验已保证单插件模式下 plugin_dir 一定存在;这里显式收窄, + # 避免把入口层的 Optional 继续传播到单插件运行时。 + assert plugin_dir is not None + runtime = PluginWorkerRuntime(plugin_dir=plugin_dir, transport=transport) + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() + if original_stdout is not None: + sys.stdout = original_stdout + + +async def run_websocket_server( + *, + host: str = "127.0.0.1", + port: int = 8765, + path: str = "/", + plugin_dir: Path | None = None, +) -> None: + runtime = PluginWorkerRuntime( + plugin_dir=plugin_dir or Path.cwd(), + transport=WebSocketServerTransport(host=host, port=port, path=path), + ) + try: + await runtime.start() + stop_event = asyncio.Event() + _install_signal_handlers(stop_event) + await _wait_for_shutdown(runtime.peer, stop_event) + finally: + await runtime.stop() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py new file mode 100644 index 0000000000..c1d503490a --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_dispatcher.py @@ -0,0 +1,511 @@ +"""Capability invocation dispatcher. + +本模块实现能力调用的分发器,负责: +1. 接收能力调用请求,定位对应的已注册能力 +2. 构建调用上下文 (Context),注入必要的依赖 +3. 支持同步和流式两种调用模式 +4. 
lifecycle management and cancellation of active invocation tasks.

Parameter-injection policy:
inject Context / CancelToken / dict by type, or inject by parameter name
(ctx / context / payload / input / data / cancel_token / token).
When no match is possible a detailed error is raised to help developers
locate the problem.
"""

from __future__ import annotations

import asyncio
import inspect
import json
import typing
from collections.abc import AsyncIterator, Sequence
from typing import Any, cast, get_type_hints

from loguru import logger

from .._internal.invocation_context import caller_plugin_scope
from .._internal.plugin_logger import PluginLogger
from .._internal.star_runtime import bind_star_runtime
from .._internal.typing_utils import unwrap_optional
from ..context import CancelToken, Context
from ..errors import AstrBotError
from ..events import MessageEvent
from ..star import Star
from ._streaming import StreamExecution
from .loader import LoadedCapability, LoadedLLMTool


class CapabilityDispatcher:
    """Route incoming capability/LLM-tool invocations to loaded handlers.

    Holds the plugin's registered capabilities and LLM tools, builds a
    per-request ``Context``, runs each invocation as an ``asyncio.Task`` so
    it can be cancelled via ``cancel()``, and normalizes results into plain
    dicts (or ``StreamExecution`` for streaming capabilities).
    """

    def __init__(
        self,
        *,
        plugin_id: str,
        peer,
        capabilities: Sequence[LoadedCapability],
        llm_tools: Sequence[LoadedLLMTool] | None = None,
    ) -> None:
        self._plugin_id = plugin_id
        self._peer = peer
        # Capability lookup is by descriptor name; last one wins on duplicates.
        self._capabilities = {item.descriptor.name: item for item in capabilities}
        # Keyed by (owner_plugin_id, tool_name_or_handler_ref).
        self._llm_tools: dict[tuple[str, str], LoadedLLMTool] = {}
        try:
            # Back-reference used by the peer for dynamic LLM tool registration.
            setattr(peer, "_sdk_capability_dispatcher", self)
        except AttributeError:
            logger.warning(
                f"Failed to attach _sdk_capability_dispatcher to peer {peer}, "
                "dynamic LLM tool registration may not work"
            )
        for item in llm_tools or []:
            self._register_llm_tool(item, item.plugin_id or plugin_id)
        # request_id -> (running task, its cancel token); used by cancel().
        self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {}

    def _register_llm_tool(
        self,
        loaded: LoadedLLMTool,
        owner_plugin: str,
    ) -> None:
        """Index a loaded LLM tool under its name and (if distinct) handler_ref."""
        self._llm_tools[(owner_plugin, loaded.spec.name)] = loaded
        if loaded.spec.handler_ref and loaded.spec.handler_ref != loaded.spec.name:
            self._llm_tools[(owner_plugin, loaded.spec.handler_ref)] = loaded

    def add_dynamic_llm_tool(
        self,
        *,
        plugin_id: str,
        spec,
        callable_obj,
        owner: Any | None = None,
    ) -> None:
        """Register (or replace) an LLM tool at runtime.

        Any previous registration with the same name is removed first; the
        spec is deep-copied so later caller mutation cannot leak in.
        """
        self.remove_llm_tool(plugin_id, spec.name)
        loaded = LoadedLLMTool(
            spec=spec.model_copy(deep=True),
            callable=callable_obj,
            owner=owner,
            plugin_id=plugin_id,
        )
        self._register_llm_tool(loaded, plugin_id)

    def remove_llm_tool(self, plugin_id: str, name: str) -> bool:
        """Drop every index entry (name and handler_ref alias) for a tool.

        Returns True if at least one entry was removed.
        """
        removed = False
        for key, value in list(self._llm_tools.items()):
            if key[0] != plugin_id:
                continue
            spec_name = str(getattr(value.spec, "name", "")).strip()
            handler_ref = str(getattr(value.spec, "handler_ref", "") or "").strip()
            if name not in {spec_name, handler_ref}:
                continue
            self._llm_tools.pop(key, None)
            removed = True
        return removed

    async def invoke(
        self,
        message,
        cancel_token: CancelToken,
    ) -> dict[str, Any] | StreamExecution:
        """Dispatch one invocation message to its capability handler.

        Raises LookupError when the capability is unknown. The handler runs
        in its own task (created inside the caller-plugin contextvar scope)
        and is tracked in ``_active`` until completion so ``cancel()`` works.
        """
        if message.capability == "internal.llm_tool.execute":
            return await self._invoke_registered_llm_tool(message, cancel_token)

        loaded = self._capabilities.get(message.capability)
        if loaded is None:
            raise LookupError(f"capability not found: {message.capability}")

        plugin_id = self._resolve_plugin_id(loaded)
        ctx = Context(
            peer=self._peer,
            plugin_id=plugin_id,
            request_id=message.id,
            cancel_token=cancel_token,
        )
        # Enrich the logger with per-request correlation fields.
        bound_logger = cast(PluginLogger, ctx.logger).bind(
            plugin_id=plugin_id,
            request_id=message.id,
            capability=message.capability,
            session_id=self._logger_session_id(dict(message.input)),
            event_type=self._logger_event_type(dict(message.input)),
        )
        ctx.logger = bound_logger

        with caller_plugin_scope(plugin_id):
            # create_task copies the current context, so the scope is
            # captured by the task even after the `with` exits.
            task = asyncio.create_task(
                self._run_capability(
                    loaded,
                    payload=dict(message.input),
                    ctx=ctx,
                    cancel_token=cancel_token,
                    stream=bool(message.stream),
                )
            )
        self._active[message.id] = (task, cancel_token)
        try:
            return await task
        finally:
            self._active.pop(message.id, None)

    async def _invoke_registered_llm_tool(
        self,
        message,
        cancel_token: CancelToken,
    ) -> dict[str, Any]:
        """Handle the internal ``internal.llm_tool.execute`` capability.

        Resolves the tool by (plugin_id, handler_ref) first, falling back to
        (plugin_id, tool_name); reconstructs the triggering MessageEvent from
        the payload and binds a reply handler before execution.
        """
        payload = dict(message.input)
        plugin_id = str(payload.get("plugin_id") or self._plugin_id)
        tool_name = str(payload.get("tool_name", ""))
        handler_ref = str(payload.get("handler_ref") or tool_name)
        loaded = self._llm_tools.get((plugin_id, handler_ref))
        if loaded is None:
            loaded = self._llm_tools.get((plugin_id, tool_name))
        if loaded is None:
            raise LookupError(f"llm tool not found: {plugin_id}:{tool_name}")

        event_payload = payload.get("event")
        ctx = Context(
            peer=self._peer,
            plugin_id=plugin_id,
            request_id=message.id,
            cancel_token=cancel_token,
            source_event_payload=event_payload
            if isinstance(event_payload, dict)
            else None,
        )
        bound_logger = cast(PluginLogger, ctx.logger).bind(
            plugin_id=plugin_id,
            request_id=message.id,
            capability="internal.llm_tool.execute",
            session_id=self._logger_session_id(payload),
            event_type=self._logger_event_type(payload),
        )
        ctx.logger = bound_logger
        event = MessageEvent.from_payload(
            event_payload if isinstance(event_payload, dict) else {},
            context=ctx,
        )
        self._bind_event_reply_handler(ctx, event)
        tool_args = payload.get("tool_args")
        normalized_args = dict(tool_args) if isinstance(tool_args, dict) else {}

        with caller_plugin_scope(plugin_id):
            task = asyncio.create_task(
                self._run_registered_llm_tool(loaded, event, ctx, normalized_args)
            )
        self._active[message.id] = (task, cancel_token)
        try:
            return await task
        finally:
            self._active.pop(message.id, None)

    def _bind_event_reply_handler(self, ctx: Context, event: MessageEvent) -> None:
        """Attach a reply coroutine so the tool can answer in-session."""

        async def reply(text: str) -> None:
            try:
                await ctx.platform.send(event.session_ref or event.session_id, text)
            except TypeError:
                # Fallback for peers exposing a bare `send(session_id, text)`
                # (sync or async); re-raise when no such method exists.
                send = getattr(self._peer, "send", None)
                if not callable(send):
                    raise
                result = send(event.session_id, text)
                if inspect.isawaitable(result):
                    await result

        event.bind_reply_handler(reply)

    async def _run_registered_llm_tool(
        self,
        loaded: LoadedLLMTool,
        event: MessageEvent,
        ctx: Context,
        tool_args: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute an LLM tool and normalize its result into a bridge dict.

        Returns ``{"content": ..., "success": True}`` where content is None
        (no textual payload), a JSON string (dict results), or str(result).
        """
        owner = loaded.owner if isinstance(loaded.owner, Star) else None
        with bind_star_runtime(owner, ctx):
            result = loaded.callable(
                *self._build_tool_args(
                    loaded.callable,
                    event,
                    ctx,
                    tool_args,
                )
            )
            if inspect.isasyncgen(result):
                raise AstrBotError.protocol_error(
                    "SDK LLM tool must return awaitable result, async generator is unsupported"
                )
            if inspect.isawaitable(result):
                result = await result
        if result is None:
            # content=None means the tool completed successfully but produced no
            # textual payload. The core bridge preserves this as a real None.
            return {"content": None, "success": True}
        if isinstance(result, dict):
            return {
                "content": json.dumps(result, ensure_ascii=False, default=str),
                "success": True,
            }
        return {"content": str(result), "success": True}

    def _build_tool_args(
        self,
        handler,
        event: MessageEvent,
        ctx: Context,
        tool_args: dict[str, Any],
    ) -> list[Any]:
        """Build positional args for a tool handler via type/name injection.

        NOTE(review): a tool_args value that is literally None is treated as
        "not injected" here (the `injected is None` checks), so it falls back
        to the parameter default or raises — confirm this convention with the
        core bridge before passing None-valued arguments.
        """
        signature = inspect.signature(handler)
        args: list[Any] = []
        type_hints: dict[str, Any] = {}
        try:
            type_hints = get_type_hints(handler)
        except Exception:
            # Unresolvable annotations (e.g. missing forward refs) fall back
            # to name-based injection only.
            type_hints = {}

        for parameter in signature.parameters.values():
            if parameter.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                continue

            injected = None
            param_type = type_hints.get(parameter.name)
            if param_type is not None:
                injected = self._inject_tool_by_type(param_type, event, ctx)
            if injected is None:
                if parameter.name == "event":
                    injected = event
                elif parameter.name in {"ctx", "context"}:
                    injected = ctx
                elif parameter.name in tool_args:
                    injected = tool_args[parameter.name]
            if injected is None:
                if parameter.default is not parameter.empty:
                    continue
                raise TypeError(
                    f"SDK LLM tool '{getattr(handler, '__name__', repr(handler))}' missing required argument '{parameter.name}'"
                )
            args.append(injected)
        return args

    def _inject_tool_by_type(
        self,
        param_type: Any,
        event: MessageEvent,
        ctx: Context,
    ) -> Any:
        """Type-based injection for tool handlers: Context or MessageEvent."""
        param_type, _is_optional = unwrap_optional(param_type)

        if param_type is Context or (
            isinstance(param_type, type) and issubclass(param_type, Context)
        ):
            return ctx
        if param_type is MessageEvent or (
            isinstance(param_type, type) and issubclass(param_type, MessageEvent)
        ):
            return event
        return None

    def _resolve_plugin_id(self, loaded: LoadedCapability) -> str:
        """Prefer the capability's own plugin_id, else this worker's."""
        if loaded.plugin_id:
            return loaded.plugin_id
        return self._plugin_id

    @staticmethod
    def _logger_session_id(payload: dict[str, Any]) -> str:
        """Extract a session id for log correlation from the payload."""
        if isinstance(payload.get("event"), dict):
            return str(payload["event"].get("session_id", ""))
        return str(payload.get("session", ""))

    @staticmethod
    def _logger_event_type(payload: dict[str, Any]) -> str:
        """Extract an event type for log correlation from the payload."""
        if isinstance(payload.get("event"), dict):
            event_payload = payload["event"]
            return str(
                event_payload.get("event_type")
                or event_payload.get("type")
                or event_payload.get("message_type")
                or "message"
            )
        if payload.get("session") is not None:
            return "capability"
        return "capability"

    async def cancel(self, request_id: str) -> None:
        """Cancel an in-flight invocation: flag its token and cancel the task."""
        active = self._active.get(request_id)
        if active is None:
            return
        task, cancel_token = active
        cancel_token.cancel()
        task.cancel()

    async def _run_capability(
        self,
        loaded: LoadedCapability,
        *,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
        stream: bool,
    ) -> dict[str, Any] | StreamExecution:
        """Call the capability handler and enforce the stream/non-stream contract.

        stream=True handlers must yield via async generator or return a
        StreamExecution; stream=False handlers must not return an async
        generator. Non-stream results are normalized to dicts.
        """
        result = loaded.callable(
            *self._build_args(
                loaded.callable,
                payload,
                ctx,
                cancel_token,
                plugin_id=self._resolve_plugin_id(loaded),
                capability_name=loaded.descriptor.name,
            )
        )
        if stream:
            if inspect.isasyncgen(result):
                return StreamExecution(
                    iterator=self._iterate_generator(result),
                    finalize=lambda chunks: {"items": chunks},
                )
            if inspect.isawaitable(result):
                result = await result
            if isinstance(result, StreamExecution):
                return result
            raise AstrBotError.protocol_error(
                "stream=true 的插件 capability 必须返回 async generator 或 StreamExecution"
            )

        if inspect.isasyncgen(result):
            raise AstrBotError.protocol_error(
                "stream=false 的插件 capability 不能返回 async generator"
            )
        if inspect.isawaitable(result):
            result = await result
        return self._normalize_output(result)

    def _build_args(
        self,
        handler,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
        *,
        plugin_id: str | None = None,
        capability_name: str | None = None,
    ) -> list[Any]:
        """Build positional args for a capability handler.

        Injection order: by annotation type (Context / CancelToken / dict),
        then by parameter name (ctx/context, payload/input/data,
        cancel_token/token), then the parameter default; otherwise a detailed
        TypeError is raised.
        """
        signature = inspect.signature(handler)
        args: list[Any] = []

        type_hints: dict[str, Any] = {}
        try:
            type_hints = get_type_hints(handler)
        except Exception:
            pass

        for parameter in signature.parameters.values():
            if parameter.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                continue

            injected = None
            param_type = type_hints.get(parameter.name)
            if param_type is not None:
                injected = self._inject_by_type(param_type, payload, ctx, cancel_token)

            if injected is None:
                if parameter.name in {"ctx", "context"}:
                    injected = ctx
                elif parameter.name in {"payload", "input", "data"}:
                    injected = payload
                elif parameter.name in {"cancel_token", "token"}:
                    injected = cancel_token

            if injected is None:
                if parameter.default is not parameter.empty:
                    continue
                raise TypeError(
                    self._format_capability_injection_error(
                        handler=handler,
                        parameter_name=parameter.name,
                        plugin_id=plugin_id,
                        capability_name=capability_name,
                        payload=payload,
                    )
                )
            args.append(injected)

        return args

    def _inject_by_type(
        self,
        param_type: Any,
        payload: dict[str, Any],
        ctx: Context,
        cancel_token: CancelToken,
    ) -> Any:
        """Type-based injection for capability handlers.

        ``dict`` (including ``dict[...]`` generics) receives the raw payload.
        """
        param_type, _is_optional = unwrap_optional(param_type)
        origin = typing.get_origin(param_type)

        if param_type is Context or (
            isinstance(param_type, type) and issubclass(param_type, Context)
        ):
            return ctx
        if param_type is CancelToken or (
            isinstance(param_type, type) and issubclass(param_type, CancelToken)
        ):
            return cancel_token
        if param_type is dict or origin is dict:
            return payload
        return None

    def _format_capability_injection_error(
        self,
        *,
        handler,
        parameter_name: str,
        plugin_id: str | None,
        capability_name: str | None,
        payload: dict[str, Any],
    ) -> str:
        """Compose the detailed (Chinese) injection-failure message."""
        plugin_text = plugin_id or self._plugin_id
        target = capability_name or getattr(handler, "__name__", "")
        payload_keys = sorted(str(key) for key in payload.keys())
        payload_keys_text = ", ".join(payload_keys) if payload_keys else ""
        return (
            f"插件 '{plugin_text}' 的 capability '{target}' 参数注入失败:"
            f"必填参数 '{parameter_name}' 无法注入。"
            f"签名: {getattr(handler, '__name__', '')}"
            f"{self._callable_signature(handler)}。"
            "当前支持按类型注入 Context / CancelToken / dict,"
            "按参数名注入 ctx / context / payload / input / data / cancel_token / token,"
            f"以及 payload 中现有键:{payload_keys_text}。"
        )

    async def _iterate_generator(
        self,
        generator: AsyncIterator[Any],
    ) -> AsyncIterator[dict[str, Any]]:
        """Wrap a handler's async generator, normalizing every chunk."""
        async for item in generator:
            yield self._normalize_chunk(item)

    def _normalize_chunk(self, item: Any) -> dict[str, Any]:
        """Normalize one stream chunk; empty outputs become {"ok": True}."""
        output = self._normalize_output(item)
        if output:
            return output
        return {"ok": True}

    def _normalize_output(self, result: Any) -> dict[str, Any]:
        """Coerce a handler result to a dict (None -> {}, pydantic -> dump)."""
        if result is None:
            return {}
        if isinstance(result, dict):
            return result
        model_dump = getattr(result, "model_dump", None)
        if callable(model_dump):
            dumped = model_dump()
            if isinstance(dumped, dict):
                return dumped
        raise AstrBotError.invalid_input("插件 capability 必须返回 dict 或可序列化对象")

    @staticmethod
    def _callable_signature(handler) -> str:
        """Best-effort stringified signature for error messages."""
        try:
            return str(inspect.signature(handler))
        except (TypeError, ValueError):
            return "(?)"


__all__ = ["CapabilityDispatcher"]
b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py new file mode 100644 index 0000000000..dbe0058977 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/capability_router.py @@ -0,0 +1,975 @@ +"""能力路由模块。 + +定义 CapabilityRouter 类,负责能力的注册、发现和执行路由。 +能力是核心侧提供给插件侧调用的功能,如 LLM 聊天、存储、消息发送等。 + +核心概念: + CapabilityDescriptor: 能力描述符,声明能力名称、输入输出 Schema 等 + CallHandler: 同步调用处理器,签名 (request_id, payload, cancel_token) -> dict + StreamHandler: 流式调用处理器,签名 (request_id, payload, cancel_token) -> AsyncIterator + FinalizeHandler: 流式结果聚合器,签名 (chunks) -> dict + +内置能力: + LLM: + llm.chat: 同步 LLM 聊天 + llm.chat_raw: 同步 LLM 聊天(完整响应) + llm.stream_chat: 流式 LLM 聊天 + Memory: + memory.search: 搜索记忆 + memory.save: 保存记忆 + memory.save_with_ttl: 保存带过期时间的记忆 + memory.get: 读取单条记忆 + memory.list_keys: 列出命名空间中的记忆键 + memory.exists: 检查记忆键是否存在 + memory.get_many: 批量获取多条记忆 + memory.delete: 删除记忆 + memory.clear_namespace: 清理命名空间中的记忆 + memory.delete_many: 批量删除多条记忆 + memory.count: 统计命名空间中的记忆数量 + memory.stats: 获取记忆统计信息 + DB: + db.get: 读取 KV 存储 + db.set: 写入 KV 存储 + db.delete: 删除 KV 存储 + db.list: 列出 KV 键 + db.get_many: 批量读取多个 KV 键 + db.set_many: 批量写入多个 KV 键 + db.watch: 订阅 KV 变更事件 + Platform: + platform.send: 发送消息 + platform.send_image: 发送图片 + platform.send_chain: 发送消息链 + platform.send_by_session: 主动按会话发送消息链 + platform.get_group: 获取当前群信息 + platform.get_members: 获取群成员 + Permission: + permission.check: 查询用户权限角色 + permission.get_admins: 列出管理员 ID + permission.manager.add_admin: 添加管理员 ID + permission.manager.remove_admin: 移除管理员 ID + HTTP: + http.register_api: 注册 HTTP 路由到插件 capability + http.unregister_api: 注销 HTTP 路由 + http.list_apis: 查询已注册的 HTTP 路由 + Metadata: + metadata.get_plugin: 获取单个插件元数据 + metadata.list_plugins: 列出所有插件元数据 + metadata.get_plugin_config: 获取当前调用插件自己的配置 + Provider: + provider.get_using: 获取当前聊天 Provider + provider.get_current_chat_provider_id: 获取当前聊天 Provider ID + provider.list_all: 列出聊天 Providers + provider.list_all_tts: 列出 TTS Providers + provider.list_all_stt: 列出 STT Providers + 
provider.list_all_embedding: 列出 Embedding Providers + provider.list_all_rerank: 列出 Rerank Providers + provider.get_using_tts: 获取当前 TTS Provider + provider.get_using_stt: 获取当前 STT Provider + provider.get_by_id: 按 ID 获取 Provider + provider.stt.get_text: STT 转写 + provider.tts.get_audio: TTS 合成音频 + provider.tts.support_stream: 检查 TTS 原生流式支持 + provider.tts.get_audio_stream: 流式 TTS 音频输出 + provider.embedding.get_embedding: 获取单条向量 + provider.embedding.get_embeddings: 批量获取向量 + provider.embedding.get_dim: 获取向量维度 + provider.rerank.rerank: 文档重排序 + provider.manager.set: 设置当前 Provider + provider.manager.get_by_id: 按 ID 获取 Provider 管理记录 + provider.manager.get_merged_provider_config: 获取 Provider 合并配置 + provider.manager.load: 运行时加载 Provider + provider.manager.terminate: 终止已加载的 Provider + provider.manager.create: 创建 Provider + provider.manager.update: 更新 Provider + provider.manager.delete: 删除 Provider + provider.manager.get_insts: 列出已加载聊天 Provider + provider.manager.watch_changes: 订阅 Provider 变更(流式) + Platform Manager: + platform.manager.get_by_id: 按 ID 获取平台管理快照 + platform.manager.clear_errors: 清除平台错误 + platform.manager.get_stats: 获取平台统计信息 + LLM Tool: + llm_tool.manager.get: 获取 LLM 工具状态 + llm_tool.manager.activate: 激活 LLM 工具 + llm_tool.manager.deactivate: 停用 LLM 工具 + llm_tool.manager.add: 动态添加 LLM 工具 + llm_tool.manager.remove: 动态移除 LLM 工具 + Agent: + agent.tool_loop.run: 运行 tool loop + agent.registry.list: 列出 Agent 元数据 + agent.registry.get: 获取 Agent 元数据 + Registry: + registry.get_handlers_by_event_type: 按事件类型列出 handler 元数据 + registry.get_handler_by_full_name: 按 full name 查询 handler 元数据 + Session: + session.plugin.is_enabled: 获取会话级插件开关 + session.plugin.filter_handlers: 按会话过滤 handler 元数据 + session.service.is_llm_enabled: 获取会话级 LLM 开关 + session.service.set_llm_status: 写入会话级 LLM 开关 + session.service.is_tts_enabled: 获取会话级 TTS 开关 + session.service.set_tts_status: 写入会话级 TTS 开关 + Managers: + persona.get / persona.list / persona.create / persona.update / persona.delete + conversation.new / 
conversation.switch / conversation.delete + conversation.get / conversation.list / conversation.update + kb.list / kb.get / kb.create / kb.update / kb.delete / kb.retrieve + kb.document.upload / kb.document.list / kb.document.get + kb.document.delete / kb.document.refresh + System (内部使用): + system.get_data_dir: 获取插件数据目录 + system.text_to_image: 文本转图片 + system.html_render: 渲染 HTML 模板 + system.file.register: 注册文件令牌 + system.file.handle: 解析文件令牌 + system.session_waiter.register: 注册会话等待器 + system.session_waiter.unregister: 注销会话等待器 + system.event.react: 发送事件表情回应 + system.event.send_typing: 发送输入中状态 + system.event.send_streaming: 发送事件流式消息 + system.event.send_streaming_chunk: 推送事件流式消息分片 + system.dynamic_command.register: 注册动态命令路由 + system.dynamic_command.list: 列出动态命令路由 + system.dynamic_command.remove: 移除动态命令路由 + +能力命名规范: + - 格式: {namespace}.{action} 或 {namespace}.{sub_namespace}.{action} + - 内置能力命名空间: llm, memory, db, platform, permission, http, metadata, provider, llm_tool, agent, registry + - 保留命名空间前缀: handler., system., internal. + +使用示例: + router = CapabilityRouter() + + # 注册同步能力 + router.register( + CapabilityDescriptor( + name="my_plugin.calculate", + description="执行计算", + input_schema={"type": "object", "properties": {"x": {"type": "number"}}}, + output_schema={"type": "object", "properties": {"result": {"type": "number"}}}, + ), + call_handler=my_calculate, + ) + + # 注册流式能力 + async def stream_data(request_id, payload, token): + for i in range(10): + yield {"index": i} + + router.register( + CapabilityDescriptor( + name="my_plugin.stream", + description="流式数据", + supports_stream=True, + cancelable=True, + ), + stream_handler=stream_data, + finalize=lambda chunks: {"count": len(chunks)}, + ) + + # 执行能力 + result = await router.execute("my_plugin.calculate", {"x": 42}, stream=False, ...) + stream_result = await router.execute("my_plugin.stream", {}, stream=True, ...) 
"""

from __future__ import annotations

import asyncio
import inspect
import re
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from .._internal.invocation_context import current_caller_plugin_id
from ..errors import AstrBotError
from ..protocol.descriptors import (
    RESERVED_CAPABILITY_PREFIXES,
    CapabilityDescriptor,
)
from ._capability_router_builtins import BuiltinCapabilityRouterMixin
from ._streaming import StreamExecution

# Handler type aliases (see module docstring for semantics).
CallHandler = Callable[[str, dict[str, Any], object], Awaitable[dict[str, Any]]]
FinalizeHandler = Callable[[list[dict[str, Any]]], dict[str, Any]]
# Capability names are dot-separated lowercase identifiers with >= 2 segments.
CAPABILITY_NAME_PATTERN = re.compile(r"^[a-z][a-z0-9_]*(?:\.[a-z][a-z0-9_]*)+$")


StreamHandler = Callable[
    [str, dict[str, Any], object],
    AsyncIterator[dict[str, Any]]
    | StreamExecution
    | Awaitable[AsyncIterator[dict[str, Any]] | StreamExecution],
]


@dataclass(slots=True)
class _CapabilityRegistration:
    """One registered capability: descriptor plus its handlers."""

    descriptor: CapabilityDescriptor
    call_handler: CallHandler | None = None
    stream_handler: StreamHandler | None = None
    finalize: FinalizeHandler | None = None
    # exposed=False marks framework-internal capabilities that may use
    # reserved name prefixes.
    exposed: bool = True


@dataclass(slots=True)
class _RegisteredPlugin:
    """Per-plugin registration state tracked by the router."""

    metadata: dict[str, Any]
    config: dict[str, Any]
    handlers: list[dict[str, Any]]
    llm_tools: dict[str, dict[str, Any]] = field(default_factory=dict)
    active_llm_tools: set[str] = field(default_factory=set)
    local_mcp_servers: dict[str, dict[str, Any]] = field(default_factory=dict)
    agents: dict[str, dict[str, Any]] = field(default_factory=dict)
    skills: dict[str, dict[str, str]] = field(default_factory=dict)


class CapabilityRouter(BuiltinCapabilityRouterMixin):
    """Registry and execution router for core-side capabilities.

    Validates names on registration, dispatches sync/stream calls with
    input/output schema validation, and keeps the in-memory mock state
    (providers, platforms, sessions, KV, memory, ...) used by the builtin
    capabilities from ``BuiltinCapabilityRouterMixin``.
    """

    def __init__(self) -> None:
        self._registrations: dict[str, _CapabilityRegistration] = {}
        # Public stores inspected by tests / builtin capabilities.
        self.db_store: dict[str, Any] = {}
        self.memory_store: dict[str, dict[str, Any]] = {}
        self._memory_backends: dict[str, Any] = {}
        self._memory_index: dict[str, dict[str, Any]] = {}
        self._memory_dirty_keys: set[str] = set()
        self._memory_expires_at: dict[str, datetime | None] = {}
        self.sent_messages: list[dict[str, Any]] = []
        self.event_actions: list[dict[str, Any]] = []
        self._event_streams: dict[str, dict[str, Any]] = {}
        self.http_api_store: list[dict[str, Any]] = []
        self._plugins: dict[str, _RegisteredPlugin] = {}
        self._request_overlays: dict[str, dict[str, Any]] = {}
        # Mock provider catalog, one default provider per kind.
        self._provider_catalog: dict[str, list[dict[str, Any]]] = {
            "chat": [
                {
                    "id": "mock-chat-provider",
                    "model": "mock-chat-model",
                    "type": "mock",
                    "provider_type": "chat_completion",
                }
            ],
            "tts": [
                {
                    "id": "mock-tts-provider",
                    "model": "mock-tts-model",
                    "type": "mock",
                    "provider_type": "text_to_speech",
                }
            ],
            "stt": [
                {
                    "id": "mock-stt-provider",
                    "model": "mock-stt-model",
                    "type": "mock",
                    "provider_type": "speech_to_text",
                }
            ],
            "embedding": [
                {
                    "id": "mock-embedding-provider",
                    "model": "mock-embedding-model",
                    "type": "mock",
                    "provider_type": "embedding",
                }
            ],
            "rerank": [
                {
                    "id": "mock-rerank-provider",
                    "model": "mock-rerank-model",
                    "type": "mock",
                    "provider_type": "rerank",
                }
            ],
        }
        self._provider_configs: dict[str, dict[str, Any]] = {
            str(item["id"]): {**item, "enable": True}
            for providers in self._provider_catalog.values()
            for item in providers
        }
        # First catalog entry of each kind becomes the active provider.
        self._active_provider_ids: dict[str, str | None] = {
            kind: providers[0]["id"] if providers else None
            for kind, providers in self._provider_catalog.items()
        }
        self._provider_change_subscriptions: dict[
            str, asyncio.Queue[dict[str, Any]]
        ] = {}
        self._system_data_root = Path.cwd() / ".astrbot_sdk_testing" / "plugin_data"
        self._session_waiters: dict[str, set[str]] = {}
        # subscription_id -> (key prefix filter or None, event queue).
        self._db_watch_subscriptions: dict[
            str, tuple[str | None, asyncio.Queue[dict[str, Any]]]
        ] = {}
        self._session_plugin_configs: dict[str, dict[str, Any]] = {}
        self._session_service_configs: dict[str, dict[str, Any]] = {}
        self._dynamic_command_routes: dict[str, list[dict[str, Any]]] = {}
        self._file_token_store: dict[str, str] = {}
        self._persona_store: dict[str, dict[str, Any]] = {}
        self._conversation_store: dict[str, dict[str, Any]] = {}
        self._session_current_conversation_ids: dict[str, str] = {}
        self._message_history_store: dict[str, list[dict[str, Any]]] = {}
        self._message_history_next_id = 1
        self._mcp_session_store: dict[str, dict[str, Any]] = {}
        self._mcp_global_servers: dict[str, dict[str, Any]] = {}
        self._mcp_audit_logs: list[dict[str, str]] = []
        self._kb_store: dict[str, dict[str, Any]] = {}
        self._kb_document_store: dict[str, dict[str, dict[str, Any]]] = {}
        self._kb_document_content_store: dict[str, str] = {}
        self._platform_instances: list[dict[str, Any]] = [
            {
                "id": "mock-platform",
                "name": "Mock Platform",
                "type": "mock",
                "status": "running",
            }
        ]
        self._permission_admin_ids: list[str] = ["astrbot"]
        # Provided by BuiltinCapabilityRouterMixin.
        self._register_builtin_capabilities()

    def upsert_plugin(
        self,
        *,
        metadata: dict[str, Any],
        config: dict[str, Any] | None = None,
    ) -> None:
        """Create or replace a plugin registration from its metadata.

        Normalizes optional metadata fields with defaults; ``local_mcp_servers``
        is popped out of the metadata and stored separately. Raises ValueError
        when the metadata carries no non-empty name.
        """
        name = str(metadata.get("name", "")).strip()
        if not name:
            raise ValueError("plugin metadata must include a non-empty name")
        normalized_metadata = dict(metadata)
        normalized_metadata.setdefault("display_name", name)
        normalized_metadata.setdefault("description", "")
        normalized_metadata.setdefault("author", "")
        normalized_metadata.setdefault("version", "0.0.0")
        normalized_metadata.setdefault("enabled", True)
        normalized_metadata.setdefault("reserved", False)
        normalized_metadata.setdefault("acknowledge_global_mcp_risk", False)
        normalized_metadata.setdefault("support_platforms", [])
        normalized_metadata.setdefault("astrbot_version", None)
        local_mcp_servers = normalized_metadata.pop("local_mcp_servers", {})
        self._plugins[name] = _RegisteredPlugin(
            metadata=normalized_metadata,
            config=dict(config or {}),
            handlers=[],
            local_mcp_servers={
                str(server_name): dict(server_payload)
                for server_name, server_payload in local_mcp_servers.items()
                if str(server_name).strip() and isinstance(server_payload, dict)
            }
            if isinstance(local_mcp_servers, dict)
            else {},
        )

    def set_plugin_handlers(
        self,
        name: str,
        handlers: list[dict[str, Any]],
    ) -> None:
        """Replace a plugin's handler metadata and prune stale command routes.

        Dynamic command routes whose handler no longer exists are dropped; if
        the plugin ends up with no valid handlers, all its routes are removed.
        Unknown plugin names are silently ignored.
        """
        plugin = self._plugins.get(name)
        if plugin is None:
            return
        plugin.handlers = [dict(item) for item in handlers]
        valid_handlers = {
            str(item.get("handler_full_name", "")).strip()
            for item in plugin.handlers
            if isinstance(item, dict)
        }
        if not valid_handlers:
            self._dynamic_command_routes.pop(name, None)
            return
        routes = self._dynamic_command_routes.get(name)
        if routes is None:
            return
        self._dynamic_command_routes[name] = [
            dict(item)
            for item in routes
            if str(item.get("handler_full_name", "")).strip() in valid_handlers
        ]
        if not self._dynamic_command_routes[name]:
            self._dynamic_command_routes.pop(name, None)

    def set_plugin_enabled(self, name: str, enabled: bool) -> None:
        """Toggle a plugin's enabled flag; no-op for unknown plugins."""
        plugin = self._plugins.get(name)
        if plugin is None:
            return
        plugin.metadata["enabled"] = enabled

    def register_dynamic_command_route(
        self,
        *,
        plugin_id: str,
        command_name: str,
        handler_full_name: str,
        desc: str = "",
        priority: int = 0,
        use_regex: bool = False,
    ) -> None:
        """Register a runtime command route owned by ``plugin_id``.

        The handler must belong to the plugin. An existing route with the
        same (command_name, use_regex) pair is replaced.
        """
        command_text = str(command_name).strip()
        if not command_text:
            raise AstrBotError.invalid_input("command_name must not be empty")
        handler_text = str(handler_full_name).strip()
        if not handler_text:
            raise AstrBotError.invalid_input("handler_full_name must not be empty")
        plugin = self._plugins.get(plugin_id)
        if plugin is None:
            raise AstrBotError.invalid_input(f"Unknown plugin: {plugin_id}")
        if not self._plugin_has_handler(plugin_id, handler_text):
            raise AstrBotError.invalid_input(
                "handler_full_name must belong to the caller plugin and exist"
            )
        route = {
            "plugin_name": plugin_id,
            "command_name": command_text,
            "handler_full_name": handler_text,
            "desc": str(desc),
            "priority": int(priority),
            "use_regex": bool(use_regex),
        }
        # Keep routes that differ in either command name or regex flag.
        routes = [
            item
            for item in self._dynamic_command_routes.get(plugin_id, [])
            if str(item.get("command_name", "")).strip() != command_text
            or bool(item.get("use_regex", False)) != bool(use_regex)
        ]
        routes.append(route)
        self._dynamic_command_routes[plugin_id] = routes

    def list_dynamic_command_routes(self, plugin_id: str) -> list[dict[str, Any]]:
        """Return copies of the plugin's dynamic command routes."""
        return [dict(item) for item in self._dynamic_command_routes.get(plugin_id, [])]

    def remove_dynamic_command_routes_for_plugin(self, plugin_id: str) -> None:
        """Drop every dynamic command route owned by the plugin."""
        self._dynamic_command_routes.pop(plugin_id, None)

    def set_platform_instances(self, instances: list[dict[str, Any]]) -> None:
        """Replace the platform instance snapshot list.

        Entries without a non-empty id and type are skipped; optional fields
        (errors / last_error / stats / meta) are defensively copied.
        """
        normalized: list[dict[str, Any]] = []
        for item in instances:
            if not isinstance(item, dict):
                continue
            platform_id = str(item.get("id", "")).strip()
            platform_type = str(item.get("type", "")).strip()
            if not platform_id or not platform_type:
                continue
            errors = item.get("errors")
            last_error = item.get("last_error")
            stats = item.get("stats")
            meta = item.get("meta")
            normalized.append(
                {
                    "id": platform_id,
                    "name": str(item.get("name", platform_id)),
                    "type": platform_type,
                    "status": str(item.get("status", "unknown")),
                    "errors": [
                        dict(error) for error in errors if isinstance(error, dict)
                    ]
                    if isinstance(errors, list)
                    else [],
                    "last_error": (
                        dict(last_error) if isinstance(last_error, dict) else None
                    ),
                    "unified_webhook": bool(item.get("unified_webhook", False)),
                    "stats": dict(stats) if isinstance(stats, dict) else None,
                    "meta": dict(meta) if isinstance(meta, dict) else {},
                    "started_at": item.get("started_at"),
                }
            )
        self._platform_instances = normalized

    def get_platform_instances(self) -> list[dict[str, Any]]:
        """Return shallow copies of the platform instance snapshots."""
        return [dict(item) for item in self._platform_instances]

    def set_admin_ids(self, admin_ids: list[str]) -> None:
        """Replace the admin id list, dropping empty/whitespace-only entries."""
        self._permission_admin_ids = [
            user_id for user_id in (str(item).strip() for item in admin_ids) if user_id
        ]

    def _plugin_has_handler(self, plugin_id: str, handler_full_name: str) -> bool:
        """True when the plugin declares a handler with that full name."""
        plugin = self._plugins.get(plugin_id)
        if plugin is None:
            return False
        handler_name = str(handler_full_name).strip()
        if not handler_name:
            return False
        for handler in plugin.handlers:
            if not isinstance(handler, dict):
                continue
            if str(handler.get("handler_full_name", "")).strip() == handler_name:
                return True
        return False

    def set_plugin_llm_tools(
        self,
        name: str,
        tools: list[dict[str, Any]],
    ) -> None:
        """Replace the plugin's LLM tool specs, keyed and filtered by name.

        ``active_llm_tools`` tracks tools whose spec has active != False.
        """
        plugin = self._plugins.get(name)
        if plugin is None:
            return
        plugin.llm_tools = {
            str(item.get("name", "")): dict(item)
            for item in tools
            if isinstance(item, dict) and str(item.get("name", "")).strip()
        }
        plugin.active_llm_tools = {
            tool_name
            for tool_name, item in plugin.llm_tools.items()
            if bool(item.get("active", True))
        }

    def set_plugin_agents(
        self,
        name: str,
        agents: list[dict[str, Any]],
    ) -> None:
        """Replace the plugin's agent metadata, keyed and filtered by name."""
        plugin = self._plugins.get(name)
        if plugin is None:
            return
        plugin.agents = {
            str(item.get("name", "")): dict(item)
            for item in agents
            if isinstance(item, dict) and str(item.get("name", "")).strip()
        }

    def set_provider_catalog(
        self,
        kind: str,
        providers: list[dict[str, Any]],
        *,
        active_id: str | None = None,
    ) -> None:
        """Replace one provider kind's catalog and choose its active provider.

        Without an explicit ``active_id`` the first catalog entry becomes
        active (or None when the catalog is empty).
        """
        self._provider_catalog[kind] = [
            dict(item)
            for item in providers
            if isinstance(item, dict) and str(item.get("id", "")).strip()
        ]
        for item in self._provider_catalog[kind]:
            provider_id = str(item.get("id", "")).strip()
            if not provider_id:
                continue
            self._provider_configs[provider_id] = {**item, "enable": True}
        if active_id is not None:
            self._active_provider_ids[kind] = active_id
        else:
            catalog = self._provider_catalog[kind]
            self._active_provider_ids[kind] = catalog[0]["id"] if catalog else None

    def emit_provider_change(
        self,
        provider_id: str,
        provider_type: str,
        umo: str | None = None,
    ) -> None:
        """Fan a provider-change event out to every active subscription."""
        event = {
            "provider_id": str(provider_id),
            "provider_type": str(provider_type),
            "umo": str(umo) if umo is not None else None,
        }
        for queue in list(self._provider_change_subscriptions.values()):
            queue.put_nowait(dict(event))

    def record_platform_error(
        self,
        platform_id: str,
        message: str,
        *,
        traceback: str | None = None,
    ) -> None:
        """Append an error to the matching platform and mark it errored.

        No-op when the platform id is unknown.
        """
        for item in self._platform_instances:
            if str(item.get("id", "")) != str(platform_id):
                continue
            error = {
                "message": str(message),
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "traceback": str(traceback) if traceback is not None else None,
            }
            errors = item.setdefault("errors", [])
            if isinstance(errors, list):
                errors.append(error)
            item["last_error"] = error
            item["status"] = "error"
            return

    def set_platform_stats(self, platform_id: str, stats: dict[str, Any]) -> None:
        """Set the stats snapshot of one platform; unknown ids are ignored."""
        for item in self._platform_instances:
            if str(item.get("id", "")) != str(platform_id):
                continue
            item["stats"] = dict(stats)
            return

    def set_session_plugin_config(
        self,
        session_id: str,
        *,
        enabled_plugins: list[str] | None = None,
        disabled_plugins: list[str] | None = None,
    ) -> None:
        """Store per-session plugin enable/disable lists (None = omit key)."""
        config: dict[str, Any] = {}
        if enabled_plugins is not None:
            config["enabled_plugins"] = [str(item) for item in enabled_plugins]
        if disabled_plugins is not None:
            config["disabled_plugins"] = [str(item) for item in disabled_plugins]
        self._session_plugin_configs[str(session_id)] = config

    def set_session_service_config(
        self,
        session_id: str,
        *,
        llm_enabled: bool | None = None,
        tts_enabled: bool | None = None,
    ) -> None:
        """Store per-session LLM/TTS switches (None = omit key)."""
        config: dict[str, Any] = {}
        if llm_enabled is not None:
            config["llm_enabled"] = bool(llm_enabled)
        if tts_enabled is not None:
            config["tts_enabled"] = bool(tts_enabled)
        self._session_service_configs[str(session_id)] = config

    def remove_http_apis_for_plugin(self, plugin_id: str) -> None:
        """Drop every registered HTTP route owned by the plugin."""
        self.http_api_store = [
            entry
            for entry in self.http_api_store
            if entry.get("plugin_id") != plugin_id
        ]

    @staticmethod
    def _require_caller_plugin_id(capability_name: str) -> str:
        """Return the calling plugin's id or raise when outside plugin scope."""
        caller_plugin_id = current_caller_plugin_id()
        if caller_plugin_id:
            return caller_plugin_id
        raise AstrBotError.invalid_input(
            f"{capability_name} 只能在插件运行时上下文中调用"
        )

    def _emit_db_change(self, *, op: str, key: str, value: Any | None) -> None:
        """Notify db.watch subscribers whose key-prefix filter matches."""
        event = {"op": op, "key": key, "value": value}
        for prefix, queue in list(self._db_watch_subscriptions.values()):
            if prefix is not None and not key.startswith(prefix):
                continue
            queue.put_nowait(event)

    def descriptors(self) -> list[CapabilityDescriptor]:
        """List every registered capability's descriptor."""
        return [entry.descriptor for entry in self._registrations.values()]

    def contains(self, name: str) -> bool:
        """True when a capability with that name is registered."""
        return name in self._registrations

    def unregister(self, name: str) -> None:
        """Remove a capability registration; unknown names are ignored."""
        self._registrations.pop(name, None)

    def register(
        self,
        descriptor: CapabilityDescriptor,
        *,
        call_handler: CallHandler | None = None,
        stream_handler: StreamHandler | None = None,
        finalize: FinalizeHandler | None = None,
        exposed: bool = True,
    ) -> None:
        """Register a capability after validating its name.

        Internal (exposed=False) capabilities may use reserved prefixes and
        then skip the naming pattern; exposed ones must match the pattern
        and must not use reserved prefixes.
        """
        is_internal_reserved = not exposed and descriptor.name.startswith(
            RESERVED_CAPABILITY_PREFIXES
        )
        if (
            not CAPABILITY_NAME_PATTERN.fullmatch(descriptor.name)
            and not is_internal_reserved
        ):
            raise ValueError(
                f"capability 名称必须匹配 {{namespace}}.{{method}}:{descriptor.name}"
            )
        if exposed and descriptor.name.startswith(RESERVED_CAPABILITY_PREFIXES):
            raise ValueError(
                f"保留 capability 命名空间仅供框架内部使用:{descriptor.name}"
            )
        self._registrations[descriptor.name] = _CapabilityRegistration(
            descriptor=descriptor,
            call_handler=call_handler,
            stream_handler=stream_handler,
            finalize=finalize,
            exposed=exposed,
        )

    async def execute(
        self,
        capability: str,
        payload: dict[str, Any],
        *,
        stream: bool,
        cancel_token,
        request_id: str,
    ) -> dict[str, Any] | StreamExecution:
        """Execute a capability with input (and, for sync calls, output)
        schema validation.

        Stream calls return a StreamExecution whose finalize step validates
        the aggregated output; sync calls validate the handler's dict result
        directly.
        """
        registration = self._registrations.get(capability)
        if registration is None:
            raise AstrBotError.capability_not_found(capability)

        self._validate_schema_with_context(
            capability=capability,
            phase="输入",
            schema=registration.descriptor.input_schema,
            payload=payload,
        )
        if stream:
            if registration.stream_handler is None:
                raise AstrBotError.invalid_input(f"{capability} 不支持 stream=true")
            raw_execution = registration.stream_handler(
                request_id, payload, cancel_token
            )
            # Stream handlers may be coroutines returning the real iterator.
            if inspect.isawaitable(raw_execution):
                raw_execution = await raw_execution
            if isinstance(raw_execution, StreamExecution):
                return self._wrap_stream_execution(
                    registration.descriptor,
                    raw_execution,
                )
            finalize = registration.finalize or (lambda chunks: {"items": chunks})
            return self._wrap_stream_execution(
                registration.descriptor,
                StreamExecution(
                    iterator=raw_execution,
                    finalize=finalize,
                ),
            )

        if registration.call_handler is None:
            raise AstrBotError.invalid_input(
                f"{capability} 只能以 stream=true 调用,registration.call_handler 为 None"
            )
        output = await registration.call_handler(request_id, payload, cancel_token)
        self._validate_schema_with_context(
            capability=capability,
            phase="输出",
            schema=registration.descriptor.output_schema,
            payload=output,
        )
        return output

    def _wrap_stream_execution(
        self,
        descriptor: CapabilityDescriptor,
        execution: StreamExecution,
    ) -> StreamExecution:
        """Wrap a StreamExecution so its final output is schema-validated."""

        def validated_finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]:
            output = execution.finalize(chunks)
            self._validate_schema_with_context(
                capability=descriptor.name,
                phase="输出",
                schema=descriptor.output_schema,
                payload=output,
            )
            return output

        return StreamExecution(
            iterator=execution.iterator,
            finalize=validated_finalize,
            collect_chunks=execution.collect_chunks,
        )

    # ------------------------------------------------------------------
    #
Schema validation + # ------------------------------------------------------------------ + + def _validate_schema( + self, + schema: dict[str, Any] | None, + payload: Any, + ) -> None: + if not isinstance(schema, dict) or not schema: + return + self._validate_value(schema, payload, path="") + + def _validate_schema_with_context( + self, + *, + capability: str, + phase: str, + schema: dict[str, Any] | None, + payload: Any, + ) -> None: + try: + self._validate_schema(schema, payload) + except AstrBotError as exc: + if exc.code != "invalid_input": + raise + raise AstrBotError.invalid_input( + f"capability '{capability}' 的{phase}校验失败:{exc.message}", + hint=( + f"请检查 capability '{capability}' 的{phase.lower()}是否符合声明的 schema" + ), + ) from exc + + def _validate_value( + self, + schema: dict[str, Any], + value: Any, + *, + path: str, + ) -> None: + any_of = schema.get("anyOf") + if isinstance(any_of, list): + for candidate in any_of: + if not isinstance(candidate, dict): + continue + try: + self._validate_value(candidate, value, path=path) + return + except AstrBotError: + continue + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 不符合允许的 schema 约束," + f"实际收到 {self._value_type_name(value)}" + ) + + enum = schema.get("enum") + if isinstance(enum, list) and value not in enum: + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 {enum},实际收到 {value!r}" + ) + + schema_type = schema.get("type") + if schema_type == "object": + if not isinstance(value, dict): + if not path: + raise AstrBotError.invalid_input( + f"输入必须是 object,实际收到 {self._value_type_name(value)}" + ) + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 object," + f"实际收到 {self._value_type_name(value)}" + ) + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + for field_name in required_fields: + field_path = self._join_path(path, str(field_name)) + if field_name not in value: + raise 
AstrBotError.invalid_input(f"缺少必填字段:{field_path}") + field_schema = self._property_schema(properties, field_name) + if value[field_name] is None and not self._schema_allows_null( + field_schema + ): + raise AstrBotError.invalid_input(f"缺少必填字段:{field_path}") + self._validate_value( + field_schema, + value[field_name], + path=field_path, + ) + for field_name, field_value in value.items(): + field_schema = properties.get(field_name) + if isinstance(field_schema, dict): + self._validate_value( + field_schema, + field_value, + path=self._join_path(path, str(field_name)), + ) + return + + if schema_type == "array": + if not isinstance(value, list): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 array," + f"实际收到 {self._value_type_name(value)}" + ) + item_schema = schema.get("items") + if isinstance(item_schema, dict): + for index, item in enumerate(value): + self._validate_value( + item_schema, + item, + path=self._index_path(path, index), + ) + return + + if schema_type == "string": + if not isinstance(value, str): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 string," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "integer": + if not isinstance(value, int) or isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 integer," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "number": + if not isinstance(value, (int, float)) or isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 number," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "boolean": + if not isinstance(value, bool): + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 boolean," + f"实际收到 {self._value_type_name(value)}" + ) + return + + if schema_type == "null": + if value is not None: + raise AstrBotError.invalid_input( + f"{self._field_label(path)} 必须是 null," + f"实际收到 
{self._value_type_name(value)}" + ) + return + + @staticmethod + def _field_label(path: str) -> str: + if not path: + return "输入" + return f"字段 {path}" + + @staticmethod + def _join_path(path: str, field_name: str) -> str: + if not path: + return field_name + return f"{path}.{field_name}" + + @staticmethod + def _index_path(path: str, index: int) -> str: + return f"{path}[{index}]" if path else f"[{index}]" + + @staticmethod + def _property_schema( + properties: Any, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(properties, dict): + return {} + field_schema = properties.get(field_name) + if isinstance(field_schema, dict): + return field_schema + return {} + + @staticmethod + def _schema_allows_null(field_schema: Any) -> bool: + if not isinstance(field_schema, dict): + return False + if field_schema.get("type") == "null": + return True + any_of = field_schema.get("anyOf") + if not isinstance(any_of, list): + return False + return any( + isinstance(candidate, dict) and candidate.get("type") == "null" + for candidate in any_of + ) + + @staticmethod + def _value_type_name(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "boolean" + if isinstance(value, int): + return "integer" + if isinstance(value, float): + return "number" + if isinstance(value, str): + return "string" + if isinstance(value, list): + return "array" + if isinstance(value, dict): + return "object" + return type(value).__name__ diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py new file mode 100644 index 0000000000..982aaa2975 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/environment_groups.py @@ -0,0 +1,675 @@ +"""v4 runtime 的插件共享环境规划模块。 + +这个模块负责“多个插件,共享较少数量 Python 环境”的策略。核心约束是: + +- 插件仍然独立发现、独立加载 +- Worker 进程仍然保持一插件一进程 +- 只有在依赖兼容时才共享 Python 环境 + +整体流程如下: + +1. 先按插件声明的 `runtime.python` 分桶 +2. 再按依赖兼容性构建候选分组 +3. 
为每个分组在 `.astrbot/` 下落地 source、lock、metadata 和 venv 路径 +4. 在 worker 启动前准备或同步该分组的共享环境 + +当前阶段优先保证兼容性,因此仍保留 `--system-site-packages`,也不改变 +现有插件 manifest 语义。 +""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .loader import PluginSpec + +GROUP_STATE_FILE_NAME = ".group-venv-state.json" + +_EXACT_PIN_PATTERN = re.compile(r"^([A-Za-z0-9_.-]+)==([^\s;]+)$") +_NORMALIZE_PATTERN = re.compile(r"[-_.]+") +_PYVENV_VERSION_PATTERN = re.compile( + r"^(?:version|version_info)\s*=\s*(\d+\.\d+)(?:\.\d+)?\s*$", + re.IGNORECASE | re.MULTILINE, +) + + +def _require_uv_binary(uv_binary: str | None) -> str: + if not uv_binary: + raise RuntimeError("uv executable not found") + return uv_binary + + +def _venv_python_path(venv_path: Path) -> Path: + if os.name == "nt": + return venv_path / "Scripts" / "python.exe" + return venv_path / "bin" / "python" + + +def _normalize_package_name(name: str) -> str: + return _NORMALIZE_PATTERN.sub("-", name).lower() + + +def _read_pyvenv_major_minor(pyvenv_cfg: Path) -> str | None: + if not pyvenv_cfg.exists(): + return None + try: + content = pyvenv_cfg.read_text(encoding="utf-8") + except OSError: + return None + match = _PYVENV_VERSION_PATTERN.search(content) + if match is None: + return None + return match.group(1) + + +def _requirement_lines(plugin: PluginSpec) -> list[str]: + if not plugin.requirements_path.exists(): + return [] + + lines: list[str] = [] + for raw_line in plugin.requirements_path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + lines.append(line) + return lines + + +@dataclass(slots=True) +class EnvironmentGroup: + """一个或多个兼容插件最终共享的环境描述。 + + 分组是环境复用的最小单位。`plugins` 中的所有插件都会使用同一个 + `python_path`、lockfile 和 venv 目录,但运行时仍然各自启动独立的 + worker 进程。 
+ """ + + id: str + python_version: str + plugins: list[PluginSpec] + source_path: Path + lockfile_path: Path + metadata_path: Path + venv_path: Path + python_path: Path + environment_fingerprint: str + + +@dataclass(slots=True) +class EnvironmentPlanResult: + """一次完整规划得到的结果。 + + `plugins` 只包含成功完成规划的插件。 + `skipped_plugins` 记录规划失败的插件及原因,这类插件即使单独成组也没 + 有得到可用的共享环境。 + """ + + groups: list[EnvironmentGroup] = field(default_factory=list) + plugins: list[PluginSpec] = field(default_factory=list) + plugin_to_group: dict[str, EnvironmentGroup] = field(default_factory=dict) + skipped_plugins: dict[str, str] = field(default_factory=dict) + + +class EnvironmentPlanner: + """负责共享环境规划和分组工件落地。 + + 对 supervisor 启动来说,这个类主要回答两个问题: + + - 哪些插件可以共享一个环境 + - 这个共享环境应该对应哪份 lockfile 和哪个 venv 路径 + + 它本身不负责真正创建或同步 venv,这部分在规划结束后交给 + `GroupEnvironmentManager` 处理。 + """ + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary or shutil.which("uv") + self.cache_dir = self.repo_root / ".uv-cache" + self.artifacts_dir = self.repo_root / ".astrbot" + self.group_dir = self.artifacts_dir / "groups" + self.lock_dir = self.artifacts_dir / "locks" + self.env_dir = self.artifacts_dir / "envs" + self._compatibility_cache: dict[str, bool] = {} + + def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult: + """为当前插件集合生成稳定的共享环境规划。 + + 之所以在 worker 启动前完成规划,是为了让 supervisor 能够: + + - 只跳过依赖无法满足的那部分插件 + - 在兼容插件之间复用同一个环境 + - 清理旧规划遗留的 `.astrbot` 工件 + """ + if not plugins: + self.cleanup_artifacts([]) + return EnvironmentPlanResult() + _require_uv_binary(self.uv_binary) + + candidate_groups = self._build_candidate_groups(plugins) + planned_groups: list[EnvironmentGroup] = [] + skipped_plugins: dict[str, str] = {} + for group_plugins in candidate_groups: + materialized, skipped = self._materialize_candidate_group(group_plugins) + planned_groups.extend(materialized) + skipped_plugins.update(skipped) + + 
planned_groups.sort(key=lambda group: (group.python_version, group.id)) + self.cleanup_artifacts(planned_groups) + + plugin_to_group = { + plugin.name: group for group in planned_groups for plugin in group.plugins + } + planned_plugins = [ + plugin for plugin in plugins if plugin.name in plugin_to_group + ] + return EnvironmentPlanResult( + groups=planned_groups, + plugins=planned_plugins, + plugin_to_group=plugin_to_group, + skipped_plugins=skipped_plugins, + ) + + def _build_candidate_groups( + self, plugins: list[PluginSpec] + ) -> list[list[PluginSpec]]: + """用贪心方式把插件装入兼容性候选组。 + + 分组过程保持确定性,规则是: + + - Python 版本是第一层硬边界 + - `requirements.txt` 约束更多的插件优先落位 + - 若仍相同,则按插件名排序 + """ + buckets: dict[str, list[PluginSpec]] = {} + for plugin in plugins: + buckets.setdefault(plugin.python_version, []).append(plugin) + + planned_groups: list[list[PluginSpec]] = [] + for python_version in sorted(buckets): + python_groups: list[list[PluginSpec]] = [] + for plugin in self._sort_plugins(buckets[python_version]): + placed = False + for group_plugins in python_groups: + if self._is_compatible([*group_plugins, plugin]): + group_plugins.append(plugin) + placed = True + break + if not placed: + python_groups.append([plugin]) + planned_groups.extend(python_groups) + return planned_groups + + @staticmethod + def _sort_plugins(plugins: list[PluginSpec]) -> list[PluginSpec]: + return sorted( + plugins, + key=lambda plugin: (-len(_requirement_lines(plugin)), plugin.name), + ) + + def _is_compatible(self, plugins: list[PluginSpec]) -> bool: + """判断一组插件是否可以共享一个环境。 + + 兼容性判断先走一个便宜的快速路径: + + - 如果每条 requirement 都是 `pkg==1.2.3` 这种精确版本锁定 + - 且归一化后的包名之间没有解析出冲突版本 + - 那么无需调用求解器,直接认为这一组兼容 + + 更复杂的情况则回退到 `uv pip compile`,以它的求解结果作为最终依 + 赖兼容性的判断依据。 + """ + cache_key = self._compatibility_cache_key(plugins) + cached = self._compatibility_cache.get(cache_key) + if cached is not None: + return cached + + requirement_lines = self._collect_requirement_lines(plugins) + if not requirement_lines: + 
self._compatibility_cache[cache_key] = True + return True + + if self._merge_exact_requirements(requirement_lines) is not None: + self._compatibility_cache[cache_key] = True + return True + + with tempfile.TemporaryDirectory( + prefix="astrbot-env-plan-", + dir=self.repo_root, + ) as temp_dir: + source_path = Path(temp_dir) / "compat.in" + output_path = Path(temp_dir) / "compat.txt" + self._write_source_file(source_path, plugins) + try: + self._compile_lockfile( + source_path=source_path, + output_path=output_path, + python_version=plugins[0].python_version, + ) + except RuntimeError: + self._compatibility_cache[cache_key] = False + return False + + self._compatibility_cache[cache_key] = True + return True + + def _materialize_candidate_group( + self, + plugins: list[PluginSpec], + ) -> tuple[list[EnvironmentGroup], dict[str, str]]: + """为一个候选组创建工件,失败时自动拆分。 + + 如果整组插件无法生成 lockfile,规划器会退回到“一插件一组”继续尝 + 试,避免单个坏插件阻塞整批插件启动。 + """ + try: + return [self._materialize_group(plugins)], {} + except RuntimeError as exc: + if len(plugins) == 1: + return [], {plugins[0].name: str(exc)} + + materialized: list[EnvironmentGroup] = [] + skipped: dict[str, str] = {} + for plugin in plugins: + groups, child_skipped = self._materialize_candidate_group([plugin]) + materialized.extend(groups) + skipped.update(child_skipped) + return materialized, skipped + + def _materialize_group(self, plugins: list[PluginSpec]) -> EnvironmentGroup: + """落地定义一个共享环境所需的全部文件。 + + 分组身份由 Python 版本和插件集合共同决定。 + 环境指纹则会进一步包含编译后的 lockfile 内容,这样当依赖解析结果 + 变化时,已有环境就可以走增量同步而不是盲目重建。 + """ + group_id = self._group_identity(plugins)[:16] + python_version = plugins[0].python_version + source_path = self.group_dir / f"{group_id}.in" + lockfile_path = self.lock_dir / f"{group_id}.txt" + metadata_path = self.group_dir / f"{group_id}.json" + venv_path = self.env_dir / group_id + python_path = _venv_python_path(venv_path) + + source_path.parent.mkdir(parents=True, exist_ok=True) + lockfile_path.parent.mkdir(parents=True, 
exist_ok=True) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + venv_path.parent.mkdir(parents=True, exist_ok=True) + + self._write_source_file(source_path, plugins) + self._write_lockfile( + lockfile_path=lockfile_path, + source_path=source_path, + plugins=plugins, + python_version=python_version, + ) + environment_fingerprint = self._environment_fingerprint( + plugins=plugins, + python_version=python_version, + lockfile_path=lockfile_path, + ) + metadata_path.write_text( + json.dumps( + { + "group_id": group_id, + "python_version": python_version, + "plugins": [plugin.name for plugin in plugins], + "plugin_entries": [ + { + "name": plugin.name, + "plugin_dir": str(plugin.plugin_dir), + } + for plugin in plugins + ], + "source_path": str(source_path), + "lockfile_path": str(lockfile_path), + "venv_path": str(venv_path), + "environment_fingerprint": environment_fingerprint, + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + return EnvironmentGroup( + id=group_id, + python_version=python_version, + plugins=list(plugins), + source_path=source_path, + lockfile_path=lockfile_path, + metadata_path=metadata_path, + venv_path=venv_path, + python_path=python_path, + environment_fingerprint=environment_fingerprint, + ) + + def _write_source_file(self, source_path: Path, plugins: list[PluginSpec]) -> None: + """写入供 lockfile 生成使用的分组 requirements 输入文件。""" + lines: list[str] = [] + for plugin in sorted(plugins, key=lambda item: item.name): + requirements = _requirement_lines(plugin) + if not requirements: + continue + lines.append(f"# {plugin.name}") + lines.extend(requirements) + lines.append("") + + content = "\n".join(lines).rstrip() + if content: + content += "\n" + source_path.write_text(content, encoding="utf-8") + + def _write_lockfile( + self, + *, + lockfile_path: Path, + source_path: Path, + plugins: list[PluginSpec], + python_version: str, + ) -> None: + """为一个分组生成 lockfile。 + + 即使依赖集合为空,也会故意生成空 lockfile,这样整个共享环境流水 + 
线的处理方式可以保持一致。 + """ + if not self._collect_requirement_lines(plugins): + lockfile_path.write_text("", encoding="utf-8") + return + + self._compile_lockfile( + source_path=source_path, + output_path=lockfile_path, + python_version=python_version, + ) + + def _compile_lockfile( + self, + *, + source_path: Path, + output_path: Path, + python_version: str, + ) -> None: + """把依赖求解委托给 `uv pip compile`。""" + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "pip", + "compile", + "--python-version", + python_version, + "--no-managed-python", + "--no-python-downloads", + "--quiet", + str(source_path), + "-o", + str(output_path), + ], + cwd=self.repo_root, + command_name=f"compile lockfile for {source_path.name}", + ) + + def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None: + process = subprocess.run( + command, + cwd=str(cwd), + env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)}, + capture_output=True, + text=True, + check=False, + ) + if process.returncode != 0: + raise RuntimeError( + f"{command_name} failed with exit code {process.returncode}: " + f"{process.stderr.strip() or process.stdout.strip()}" + ) + + def cleanup_artifacts(self, groups: list[EnvironmentGroup]) -> None: + """清理不再被当前规划引用的 `.astrbot` 工件。 + + 清理范围只覆盖规划器自己维护的共享环境工件,不会碰旧式插件目录下 + 的本地 `.venv`。 + """ + active_group_ids = {group.id for group in groups} + self._cleanup_group_artifacts(active_group_ids) + self._cleanup_lockfiles(active_group_ids) + self._cleanup_envs(active_group_ids) + + def _cleanup_group_artifacts(self, active_group_ids: set[str]) -> None: + if not self.group_dir.exists(): + return + for entry in self.group_dir.iterdir(): + if entry.suffix not in {".in", ".json"}: + continue + if entry.stem in active_group_ids: + continue + entry.unlink(missing_ok=True) + + def _cleanup_lockfiles(self, active_group_ids: set[str]) -> None: + if not self.lock_dir.exists(): + return + for entry in self.lock_dir.iterdir(): + if 
entry.suffix != ".txt": + continue + if entry.stem in active_group_ids: + continue + entry.unlink(missing_ok=True) + + def _cleanup_envs(self, active_group_ids: set[str]) -> None: + if not self.env_dir.exists(): + return + for entry in self.env_dir.iterdir(): + if entry.name in active_group_ids: + continue + if entry.is_dir(): + shutil.rmtree(entry) + else: + entry.unlink(missing_ok=True) + + def _compatibility_cache_key(self, plugins: list[PluginSpec]) -> str: + payload = { + "python_version": plugins[0].python_version if plugins else "", + "plugins": [ + { + "name": plugin.name, + "requirements": _requirement_lines(plugin), + } + for plugin in sorted(plugins, key=lambda item: item.name) + ], + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _group_identity(plugins: list[PluginSpec]) -> str: + payload = { + "python_version": plugins[0].python_version if plugins else "", + "plugins": sorted(plugin.name for plugin in plugins), + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _environment_fingerprint( + *, + plugins: list[PluginSpec], + python_version: str, + lockfile_path: Path, + ) -> str: + payload = { + "python_version": python_version, + "plugins": sorted(plugin.name for plugin in plugins), + "lockfile": lockfile_path.read_text(encoding="utf-8"), + } + encoded = json.dumps(payload, ensure_ascii=True, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _collect_requirement_lines(plugins: list[PluginSpec]) -> list[str]: + lines: list[str] = [] + for plugin in plugins: + lines.extend(_requirement_lines(plugin)) + return lines + + @staticmethod + def _merge_exact_requirements(requirement_lines: list[str]) -> list[str] | None: + merged: dict[str, str] = {} + for line in requirement_lines: + match = 
_EXACT_PIN_PATTERN.fullmatch(line) + if match is None: + return None + package_name = _normalize_package_name(match.group(1)) + existing = merged.get(package_name) + if existing is not None and existing != line: + return None + merged[package_name] = line + return [merged[name] for name in sorted(merged)] + + +class GroupEnvironmentManager: + """负责创建、校验和同步一个已经规划好的共享环境。""" + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary or shutil.which("uv") + self.cache_dir = self.repo_root / ".uv-cache" + + def prepare(self, group: EnvironmentGroup) -> Path: + """确保分组对应的解释器路径已经可以用于 worker 启动。 + + 行为概括如下: + + - 环境缺失、Python 版本不对、lockfile 丢失:重建 + - 环境结构还在但指纹变化:执行 `uv pip sync` + - 否则:直接复用现有解释器路径 + """ + _require_uv_binary(self.uv_binary) + + state_path = group.venv_path / GROUP_STATE_FILE_NAME + state = self._load_state(state_path) + if ( + not group.python_path.exists() + or not self._matches_python_version(group.venv_path, group.python_version) + or not group.lockfile_path.exists() + ): + self._rebuild(group) + self._write_state(state_path, group) + elif not self._state_matches_group(state, group): + self._sync_existing(group) + self._write_state(state_path, group) + return group.python_path + + def _rebuild(self, group: EnvironmentGroup) -> None: + if group.venv_path.exists(): + shutil.rmtree(group.venv_path) + self._create_venv(group) + self._sync_lockfile(group) + + def _sync_existing(self, group: EnvironmentGroup) -> None: + self._sync_lockfile(group) + + def _sync_lockfile(self, group: EnvironmentGroup) -> None: + """让已安装包与该分组的 lockfile 精确对齐。""" + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "pip", + "sync", + "--python", + str(group.python_path), + "--allow-empty-requirements", + str(group.lockfile_path), + ], + cwd=self.repo_root, + command_name=f"sync group env {group.id}", + ) + + def _create_venv(self, group: EnvironmentGroup) -> 
None: + """为一个分组创建共享 venv。 + + 当前迁移阶段仍保留 `--system-site-packages`,以兼容那些仍然隐式依 + 赖宿主环境包的旧插件。 + """ + uv_binary = _require_uv_binary(self.uv_binary) + self._run_command( + [ + uv_binary, + "venv", + "--python", + group.python_version, + "--system-site-packages", + "--no-python-downloads", + "--no-managed-python", + str(group.venv_path), + ], + cwd=self.repo_root, + command_name=f"create group venv {group.id}", + ) + + def _run_command(self, command: list[str], *, cwd: Path, command_name: str) -> None: + process = subprocess.run( + command, + cwd=str(cwd), + env={**os.environ, "UV_CACHE_DIR": str(self.cache_dir)}, + capture_output=True, + text=True, + check=False, + ) + if process.returncode != 0: + raise RuntimeError( + f"{command_name} failed with exit code {process.returncode}: " + f"{process.stderr.strip() or process.stdout.strip()}" + ) + + @staticmethod + def _matches_python_version(venv_path: Path, version: str) -> bool: + return _read_pyvenv_major_minor(venv_path / "pyvenv.cfg") == version + + @staticmethod + def _load_state(state_path: Path) -> dict[str, object]: + if not state_path.exists(): + return {} + try: + data = json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + @staticmethod + def _write_state(state_path: Path, group: EnvironmentGroup) -> None: + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text( + json.dumps( + { + "group_id": group.id, + "python_version": group.python_version, + "environment_fingerprint": group.environment_fingerprint, + "plugins": [plugin.name for plugin in group.plugins], + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + @staticmethod + def _state_matches_group(state: dict[str, object], group: EnvironmentGroup) -> bool: + return ( + state.get("group_id") == group.id + and state.get("python_version") == group.python_version + and state.get("environment_fingerprint") == 
group.environment_fingerprint + ) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py new file mode 100644 index 0000000000..a825395302 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/handler_dispatcher.py @@ -0,0 +1,991 @@ +"""处理器分发模块。 + +定义 HandlerDispatcher 类,负责将能力调用分发到具体的处理器函数。 +支持参数注入、流式执行、错误处理。 + +核心职责: + - 根据处理器 ID 查找处理器 + - 构建处理器参数(支持类型注解注入) + - 执行处理器并处理结果 + - 处理异步生成器流式结果 + - 统一的错误处理 + +参数注入优先级: + 1. 按类型注解注入(支持 Optional[Type]) + 2. 按参数名注入(兼容无类型注解) + 3. 从 args 注入(命令参数等) + +支持的注入类型: + - MessageEvent: 消息事件 + - Context: 运行时上下文 +""" + +from __future__ import annotations + +import asyncio +import inspect +import re +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast, get_type_hints + +from loguru import logger + +from .._internal.command_model import ( + parse_command_model_remainder, + resolve_command_model_param, +) +from .._internal.injected_params import legacy_arg_parameter_names +from .._internal.invocation_context import caller_plugin_scope +from .._internal.plugin_logger import PluginLogger +from .._internal.star_runtime import bind_star_runtime +from .._internal.typing_utils import unwrap_optional +from ..clients.llm import LLMResponse +from ..context import CancelToken, Context +from ..conversation import ( + DEFAULT_BUSY_MESSAGE, + ConversationClosed, + ConversationReplaced, + ConversationSession, + ConversationState, +) +from ..events import MessageEvent +from ..filters import LocalFilterBinding +from ..llm.entities import ProviderRequest +from ..message.components import BaseMessageComponent +from ..message.result import ( + MessageChain, + MessageEventResult, + coerce_message_chain, +) +from ..protocol.descriptors import ( + CommandTrigger, + MessageTrigger, + ParamSpec, + ScheduleTrigger, +) +from ..schedule import ScheduleContext +from ..session_waiter import ( + SessionWaiterManager, + 
_mark_session_waiter_handler_task, + _unmark_session_waiter_handler_task, +) +from ..star import Star +from ._command_matching import ( + build_command_args, + build_regex_args, + match_command_name, +) +from .capability_dispatcher import CapabilityDispatcher +from .limiter import LimiterEngine +from .loader import LoadedHandler + + +@dataclass(slots=True) +class _ActiveConversation: + session: ConversationSession + task: asyncio.Task[Any] + + +@dataclass(slots=True) +class _InjectedEventPayloads: + provider_request: ProviderRequest | None = None + llm_response: LLMResponse | None = None + event_result: MessageEventResult | None = None + + +class HandlerDispatcher: + def __init__( + self, *, plugin_id: str, peer, handlers: Sequence[LoadedHandler] + ) -> None: + self._plugin_id = plugin_id + self._peer = peer + self._handlers = {item.descriptor.id: item for item in handlers} + self._active: dict[str, tuple[asyncio.Task[Any], CancelToken]] = {} + self._session_waiters = SessionWaiterManager(plugin_id=plugin_id, peer=peer) + self._limiter = LimiterEngine() + self._conversations: dict[str, _ActiveConversation] = {} + try: + setattr(peer, "_session_waiter_manager", self._session_waiters) + except AttributeError: + logger.warning( + f"Failed to attach _session_waiter_manager to peer {peer}, " + "some features may not work as expected" + ) + + def has_active_waiter(self, event: MessageEvent) -> bool: + return self._session_waiters.has_active_waiter(event) + + async def invoke(self, message, cancel_token: CancelToken) -> dict[str, Any]: + handler_id = str(message.input.get("handler_id", "")) + if handler_id == "__sdk_session_waiter__": + event_payload = message.input.get("event", {}) + requested_plugin_id = str(message.input.get("plugin_id") or "").strip() + ctx = Context( + peer=self._peer, + plugin_id=requested_plugin_id or self._plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + 
else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + session_key = event.unified_msg_origin + if requested_plugin_id: + plugin_id = requested_plugin_id + else: + plugin_ids = self._session_waiters.get_waiter_plugin_ids(session_key) + if len(plugin_ids) > 1: + raise LookupError( + "multiple active session_waiters found for session; " + "dispatch requires explicit plugin identity" + ) + plugin_id = plugin_ids[0] if plugin_ids else self._plugin_id + if plugin_id != ctx.plugin_id: + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + event.bind_reply_handler(self._create_reply_handler(ctx, event)) + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._session_waiters.dispatch(event, plugin_id=plugin_id) + ) + _mark_session_waiter_handler_task(task) + task.add_done_callback(_unmark_session_waiter_handler_task) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + loaded = self._handlers.get(handler_id) + if loaded is None: + raise LookupError(f"handler not found: {handler_id}") + + plugin_id = self._resolve_plugin_id(loaded) + event_payload = message.input.get("event", {}) + ctx = Context( + peer=self._peer, + plugin_id=plugin_id, + request_id=message.id, + cancel_token=cancel_token, + source_event_payload=event_payload + if isinstance(event_payload, dict) + else None, + ) + event = MessageEvent.from_payload(event_payload, context=ctx) + bound_logger = cast(PluginLogger, ctx.logger).bind( + plugin_id=plugin_id, + request_id=message.id, + handler_ref=handler_id, + session_id=event.session_id, + event_type=str( + event_payload.get("event_type") + or event_payload.get("type") + or event.message_type + ), + ) + ctx.logger = bound_logger + 
event.bind_reply_handler(self._create_reply_handler(ctx, event)) + schedule_context = self._build_schedule_context(loaded, event_payload) + + # 提取 args 用于兼容 handler 签名 + raw_args = message.input.get("args") or {} + args = dict(raw_args) if isinstance(raw_args, dict) else {} + if not args: + args = self._derive_args(loaded, event) + + with caller_plugin_scope(plugin_id): + task = asyncio.create_task( + self._run_handler( + loaded, + event, + ctx, + args, + schedule_context=schedule_context, + ) + ) + _mark_session_waiter_handler_task(task) + task.add_done_callback(_unmark_session_waiter_handler_task) + self._active[message.id] = (task, cancel_token) + try: + return await task + finally: + self._active.pop(message.id, None) + + def _resolve_plugin_id(self, loaded: LoadedHandler) -> str: + if loaded.plugin_id: + return loaded.plugin_id + handler_id = getattr(loaded.descriptor, "id", "") + if isinstance(handler_id, str) and ":" in handler_id: + return handler_id.split(":", 1)[0] + return self._plugin_id + + def _create_reply_handler(self, ctx: Context, event: MessageEvent): + async def reply(text: str) -> None: + try: + await ctx.platform.send(event.session_ref or event.session_id, text) + except TypeError: + send = getattr(self._peer, "send", None) + if not callable(send): + raise + result = send(event.session_id, text) + if inspect.isawaitable(result): + await result + + return reply + + async def cancel(self, request_id: str) -> None: + active = self._active.get(request_id) + if active is None: + return + task, cancel_token = active + cancel_token.cancel() + task.cancel() + + async def _run_handler( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + args: dict[str, Any] | None = None, + *, + schedule_context: ScheduleContext | None = None, + ) -> dict[str, Any]: + summary = {"sent_message": False, "stop": False, "call_llm": False} + injected_payloads = _InjectedEventPayloads() + event_type = self._event_type_name(event) + try: + limiter = 
loaded.limiter + if limiter is not None: + decision = self._limiter.evaluate( + plugin_id=self._resolve_plugin_id(loaded), + handler_id=loaded.descriptor.id, + limiter=limiter, + event=event, + ) + if not decision.allowed: + if decision.error is not None: + raise decision.error + if decision.hint: + await event.reply(decision.hint) + summary["sent_message"] = True + return summary + if not self._run_local_filters( + loaded.local_filters, + event=event, + ctx=ctx, + ): + return summary + parsed_args, help_text = self._prepare_handler_args( + loaded, + args or {}, + ) + if help_text is not None: + await event.reply(help_text) + summary["sent_message"] = True + return summary + if loaded.conversation is not None: + return await self._start_conversation( + loaded, + event, + ctx, + parsed_args, + schedule_context=schedule_context, + ) + owner = loaded.owner if isinstance(loaded.owner, Star) else None + with bind_star_runtime(owner, ctx): + result = loaded.callable( + *self._build_args( + loaded.callable, + event, + ctx, + parsed_args, + plugin_id=self._resolve_plugin_id(loaded), + handler_ref=loaded.descriptor.id, + schedule_context=schedule_context, + injected_payloads=injected_payloads, + ) + ) + if inspect.isasyncgen(result): + async for item in result: + self._merge_handler_summary( + summary, + await self._handle_result_item(item, event, ctx), + ) + summary["stop"] = bool(summary.get("stop")) or event.is_stopped() + self._append_injected_payloads( + summary, + injected_payloads, + event=event, + event_type=event_type, + ) + return summary + if inspect.isawaitable(result): + result = await result + if result is not None: + self._merge_handler_summary( + summary, + await self._handle_result_item(result, event, ctx), + ) + summary["stop"] = bool(summary.get("stop")) or event.is_stopped() + self._append_injected_payloads( + summary, + injected_payloads, + event=event, + event_type=event_type, + ) + return summary + except Exception as exc: + await self._handle_error( 
+ loaded.owner, + exc, + event, + ctx, + handler_name=loaded.callable.__name__, + plugin_id=self._resolve_plugin_id(loaded), + ) + raise + + def _derive_args( + self, + loaded: LoadedHandler, + event: MessageEvent, + ) -> dict[str, Any]: + trigger = loaded.descriptor.trigger + if isinstance(trigger, CommandTrigger): + param_specs = loaded.descriptor.param_specs + for command_name in [trigger.command, *trigger.aliases]: + remainder = match_command_name(event.text, command_name) + if remainder is not None: + model_param = resolve_command_model_param(loaded.callable) + if model_param is not None: + return { + "__command_model_remainder__": remainder, + "__command_name__": command_name, + } + if param_specs: + return build_command_args(param_specs, remainder) + return build_command_args( + [ + ParamSpec(name=name, type="str") + for name in legacy_arg_parameter_names(loaded.callable) + ], + remainder, + ) + return {} + if isinstance(trigger, MessageTrigger) and trigger.regex: + match = re.search(trigger.regex, event.text) + if match is None: + return {} + if loaded.descriptor.param_specs: + return build_regex_args(loaded.descriptor.param_specs, match) + return build_regex_args( + [ + ParamSpec(name=name, type="str") + for name in legacy_arg_parameter_names(loaded.callable) + ], + match, + ) + return {} + + def _build_args( + self, + handler, + event: MessageEvent, + ctx: Context, + args: dict[str, Any] | None = None, + *, + plugin_id: str | None = None, + handler_ref: str | None = None, + schedule_context: ScheduleContext | None = None, + conversation_session: ConversationSession | None = None, + injected_payloads: _InjectedEventPayloads | None = None, + ) -> list[Any]: + """构建 handler 参数列表。""" + from loguru import logger + + signature = inspect.signature(handler) + injected_args: list[Any] = [] + args = args or {} + + type_hints: dict[str, Any] = {} + try: + type_hints = get_type_hints(handler) + except Exception: + pass + + for parameter in 
signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + + injected = None + + # 1. Prefer injection by type annotation + param_type = type_hints.get(parameter.name) + if param_type is not None: + injected = self._inject_by_type( + param_type, + event, + ctx, + schedule_context, + conversation_session, + injected_payloads=injected_payloads, + ) + + # 2. Fall back to injection by parameter name + if injected is None: + if parameter.name == "event": + injected = event + elif parameter.name in {"ctx", "context"}: + injected = ctx + elif parameter.name in {"sched", "schedule"}: + injected = schedule_context + elif parameter.name in {"conversation", "conv"}: + injected = conversation_session + elif parameter.name in args: + injected = args[parameter.name] + + # 3. Check whether a default value exists + if injected is None: + if parameter.default is not parameter.empty: + continue + logger.error( + "Handler '{}' 的必填参数 '{}' 无法注入", + handler.__name__, + parameter.name, + ) + raise TypeError( + self._format_handler_injection_error( + handler=handler, + parameter_name=parameter.name, + plugin_id=plugin_id, + handler_ref=handler_ref, + args=args, + ) + ) + else: + injected_args.append(injected) + + return injected_args + + def _prepare_handler_args( + self, + loaded: LoadedHandler, + args: dict[str, Any], + ) -> tuple[dict[str, Any], str | None]: + parsed_args = ( + self._parse_handler_args(loaded.descriptor.param_specs, args) + if loaded.descriptor.param_specs + else { + key: value + for key, value in dict(args).items() + if not str(key).startswith("__command_") + } + ) + if not isinstance(loaded.descriptor.trigger, CommandTrigger): + return parsed_args, None + model_param = resolve_command_model_param(loaded.callable) + if model_param is None: + return parsed_args, None + if "__command_model_remainder__" not in args: + return parsed_args, None + trigger = loaded.descriptor.trigger + command_name = str(args.get("__command_name__", "")) or ( + 
trigger.command + if isinstance(trigger, CommandTrigger) + else loaded.descriptor.id.rsplit(".", 1)[-1] + ) + result = parse_command_model_remainder( + remainder=str(args.get("__command_model_remainder__", "")), + model_param=model_param, + command_name=command_name, + ) + if result.help_text is not None: + return parsed_args, result.help_text + if result.model is not None: + parsed_args[model_param.name] = result.model + return parsed_args, None + + async def _start_conversation( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + parsed_args: dict[str, Any], + *, + schedule_context: ScheduleContext | None, + ) -> dict[str, Any]: + assert loaded.conversation is not None + conversation_meta = loaded.conversation + summary = {"sent_message": False, "stop": True, "call_llm": False} + key = f"{self._resolve_plugin_id(loaded)}:{event.session_id}" + active = self._conversations.get(key) + if active is not None and not active.task.done(): + if conversation_meta.mode == "reject": + await event.reply( + conversation_meta.busy_message or DEFAULT_BUSY_MESSAGE + ) + summary["sent_message"] = True + return summary + active.session.mark_replaced() + await self._session_waiters.fail( + active.session.session_key, + ConversationReplaced("conversation replaced by a newer session"), + ) + await asyncio.sleep(0) + active.task.cancel() + try: + await asyncio.wait_for( + asyncio.shield(active.task), + timeout=conversation_meta.grace_period, + ) + except asyncio.TimeoutError: + cast(PluginLogger, ctx.logger).warning( + "Conversation replacement grace period exceeded for handler {}", + loaded.descriptor.id, + ) + except asyncio.CancelledError: + pass + except Exception: + pass + finally: + if self._conversations.get(key) is active: + self._conversations.pop(key, None) + + conversation = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=self._session_waiters, + timeout=conversation_meta.timeout, + ) + + async def _runner() -> None: + try: + await 
self._run_conversation_task( + loaded, + event, + ctx, + parsed_args, + conversation, + schedule_context=schedule_context, + ) + finally: + if conversation.state == ConversationState.ACTIVE: + conversation.close(ConversationState.COMPLETED) + current = self._conversations.get(key) + if current is not None and current.session is conversation: + self._conversations.pop(key, None) + + task = await ctx.register_task( + _runner(), + f"conversation:{loaded.descriptor.id}", + ) + conversation.bind_owner_task(task) + self._conversations[key] = _ActiveConversation( + session=conversation, + task=task, + ) + return summary + + async def _run_conversation_task( + self, + loaded: LoadedHandler, + event: MessageEvent, + ctx: Context, + parsed_args: dict[str, Any], + conversation: ConversationSession, + *, + schedule_context: ScheduleContext | None, + ) -> None: + owner = loaded.owner if isinstance(loaded.owner, Star) else None + args_with_conversation = dict(parsed_args) + args_with_conversation.setdefault("conversation", conversation) + try: + with bind_star_runtime(owner, ctx): + result = loaded.callable( + *self._build_args( + loaded.callable, + event, + ctx, + args_with_conversation, + plugin_id=self._resolve_plugin_id(loaded), + handler_ref=loaded.descriptor.id, + schedule_context=schedule_context, + conversation_session=conversation, + ) + ) + if inspect.isasyncgen(result): + async for item in result: + await self._handle_result_item(item, event, ctx) + return + if inspect.isawaitable(result): + result = await result + if result is not None: + await self._handle_result_item(result, event, ctx) + except asyncio.CancelledError: + if conversation.state == ConversationState.ACTIVE: + conversation.close(ConversationState.CANCELLED) + raise + except (ConversationReplaced, ConversationClosed): + return + except Exception as exc: + await self._handle_error( + loaded.owner, + exc, + event, + ctx, + handler_name=loaded.callable.__name__, + plugin_id=self._resolve_plugin_id(loaded), 
+ ) + + def _inject_by_type( + self, + param_type: Any, + event: MessageEvent, + ctx: Context, + schedule_context: ScheduleContext | None, + conversation_session: ConversationSession | None, + *, + injected_payloads: _InjectedEventPayloads | None = None, + ) -> Any: + """Inject a parameter value based on its type annotation.""" + param_type, _is_optional = unwrap_optional(param_type) + + # Inject MessageEvent and its subclasses + if param_type is MessageEvent: + return event + if isinstance(param_type, type) and issubclass(param_type, MessageEvent): + if isinstance(event, param_type): + return event + factory = getattr(param_type, "from_message_event", None) + if callable(factory): + return factory(event) + return event + + # Inject Context and its subclasses + if param_type is Context or ( + isinstance(param_type, type) and issubclass(param_type, Context) + ): + return ctx + if param_type is ScheduleContext or ( + isinstance(param_type, type) and issubclass(param_type, ScheduleContext) + ): + return schedule_context + if param_type is ConversationSession or ( + isinstance(param_type, type) and issubclass(param_type, ConversationSession) + ): + return conversation_session + if param_type is ProviderRequest or ( + isinstance(param_type, type) and issubclass(param_type, ProviderRequest) + ): + return self._inject_provider_request(event, injected_payloads) + if param_type is LLMResponse or ( + isinstance(param_type, type) and issubclass(param_type, LLMResponse) + ): + return self._inject_llm_response(event, injected_payloads) + if param_type is MessageEventResult or ( + isinstance(param_type, type) and issubclass(param_type, MessageEventResult) + ): + return self._inject_event_result(event, injected_payloads) + + return None + + @staticmethod + def _event_type_name(event: MessageEvent) -> str: + raw = event.raw if isinstance(event.raw, dict) else {} + value = raw.get("event_type") or raw.get("type") + return str(value or "") + + @staticmethod + def _payload_from_event(event: MessageEvent, key: str) -> dict[str, Any] | None: + raw = event.raw if 
isinstance(event.raw, dict) else {} + payload = raw.get(key) + if isinstance(payload, dict): + return payload + nested_raw = raw.get("raw") + if isinstance(nested_raw, dict): + nested_payload = nested_raw.get(key) + if isinstance(nested_payload, dict): + return nested_payload + return None + + def _inject_provider_request( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> ProviderRequest | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "provider_request") + return ( + ProviderRequest.from_payload(payload) if payload is not None else None + ) + if injected_payloads.provider_request is None: + payload = self._payload_from_event(event, "provider_request") + if payload is None: + return None + injected_payloads.provider_request = ProviderRequest.from_payload(payload) + return injected_payloads.provider_request + + def _inject_llm_response( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> LLMResponse | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "llm_response") + return LLMResponse.model_validate(payload) if payload is not None else None + if injected_payloads.llm_response is None: + payload = self._payload_from_event(event, "llm_response") + if payload is None: + return None + injected_payloads.llm_response = LLMResponse.model_validate(payload) + return injected_payloads.llm_response + + def _inject_event_result( + self, + event: MessageEvent, + injected_payloads: _InjectedEventPayloads | None, + ) -> MessageEventResult | None: + if injected_payloads is None: + payload = self._payload_from_event(event, "event_result") + return ( + MessageEventResult.from_payload(payload) + if payload is not None + else None + ) + if injected_payloads.event_result is None: + payload = self._payload_from_event(event, "event_result") + if payload is None: + return None + injected_payloads.event_result = 
MessageEventResult.from_payload(payload) + return injected_payloads.event_result + + @staticmethod + def _append_injected_payloads( + summary: dict[str, Any], + injected_payloads: _InjectedEventPayloads, + *, + event: MessageEvent, + event_type: str, + ) -> None: + if ( + event_type == "llm_request" + and injected_payloads.provider_request is not None + ): + summary["provider_request"] = ( + injected_payloads.provider_request.to_payload() + ) + elif ( + event_type in {"llm_response", "agent_done"} + and injected_payloads.llm_response is not None + ): + summary["llm_response"] = injected_payloads.llm_response.model_dump( + exclude_none=True + ) + elif ( + event_type == "decorating_result" + and injected_payloads.event_result is not None + ): + summary["event_result"] = injected_payloads.event_result.to_payload() + if event._should_serialize_sdk_local_extras(): # noqa: SLF001 + summary["sdk_local_extras"] = event._sdk_local_extras_payload() # noqa: SLF001 + + def _format_handler_injection_error( + self, + *, + handler, + parameter_name: str, + plugin_id: str | None, + handler_ref: str | None, + args: dict[str, Any], + ) -> str: + plugin_text = plugin_id or self._plugin_id + target = handler_ref or getattr(handler, "__name__", "") + arg_keys = sorted(str(key) for key in args.keys()) + arg_keys_text = ", ".join(arg_keys) if arg_keys else "" + return ( + f"插件 '{plugin_text}' 的 handler '{target}' 参数注入失败:" + f"必填参数 '{parameter_name}' 无法注入。" + f"签名: {getattr(handler, '__name__', '')}" + f"{self._callable_signature(handler)}。" + "当前支持按类型注入 MessageEvent / Context," + "按参数名注入 event / ctx / context," + f"以及 args 中现有键:{arg_keys_text}。" + ) + + @staticmethod + def _callable_signature(handler) -> str: + try: + return str(inspect.signature(handler)) + except (TypeError, ValueError): + return "(...)" + + async def _handle_result_item( + self, + item: Any, + event: MessageEvent, + ctx: Context | None = None, + ) -> dict[str, Any]: + sent_message = await self._send_result(item, 
event, ctx) + if isinstance(item, dict): + return { + "sent_message": sent_message, + "stop": bool(item.get("stop", False)), + "call_llm": bool(item.get("call_llm", False)), + } + return { + "sent_message": sent_message, + "stop": False, + "call_llm": False, + } + + @staticmethod + def _merge_handler_summary( + target: dict[str, Any], + source: dict[str, Any], + ) -> None: + target["sent_message"] = bool(target.get("sent_message")) or bool( + source.get("sent_message") + ) + target["stop"] = bool(target.get("stop")) or bool(source.get("stop")) + target["call_llm"] = bool(target.get("call_llm")) or bool( + source.get("call_llm") + ) + + async def _send_result( + self, + item: Any, + event: MessageEvent, + ctx: Context | None = None, + ) -> bool: + """Send the handler result to the event's reply channel.""" + if isinstance(item, str): + await event.reply(item) + return True + if isinstance(item, dict) and "text" in item: + await event.reply(str(item["text"])) + return True + if isinstance(item, MessageEventResult): + chain = item.chain + if chain.components: + await event.reply_chain(chain) + return True + return False + chain = coerce_message_chain(item) + if chain is not None: + if chain.components: + await event.reply_chain(chain) + return True + return False + if isinstance(item, list) and all( + isinstance(component, BaseMessageComponent) for component in item + ): + await event.reply_chain(MessageChain(list(item))) + return True + # Support objects carrying a text attribute + text = getattr(item, "text", None) + if isinstance(text, str): + await event.reply(text) + return True + return False + + @staticmethod + def _parse_handler_args( + param_specs: Sequence[ParamSpec], + args: dict[str, Any], + ) -> dict[str, Any]: + parsed: dict[str, Any] = {} + for spec in param_specs: + if spec.name not in args: + if spec.type == "optional": + parsed[spec.name] = None + continue + if spec.required: + raise TypeError(f"缺少参数: {spec.name}") + continue + parsed[spec.name] = HandlerDispatcher._convert_param(spec, args[spec.name]) + return parsed + 
+ @staticmethod + def _convert_param(spec: ParamSpec, value: Any) -> Any: + if spec.type in {"str", "greedy_str"}: + return str(value) + if spec.type == "int": + return int(str(value)) + if spec.type == "float": + return float(str(value)) + if spec.type == "bool": + normalized = str(value).strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + raise TypeError(f"无法解析布尔参数 {spec.name}: {value!r}") + if spec.type == "optional": + if value is None: + return None + inner = ParamSpec( + name=spec.name, + type=spec.inner_type or "str", + required=False, + ) + return HandlerDispatcher._convert_param(inner, value) + return value + + @staticmethod + def _run_local_filters( + bindings: list[LocalFilterBinding], + *, + event: MessageEvent, + ctx: Context, + ) -> bool: + for binding in bindings: + if not binding.evaluate(event=event, ctx=ctx): + return False + return True + + @staticmethod + def _build_schedule_context( + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> ScheduleContext | None: + if not isinstance(loaded.descriptor.trigger, ScheduleTrigger): + return None + try: + return ScheduleContext.from_payload(event_payload) + except Exception: + return None + + async def _handle_error( + self, + owner: Any, + exc: Exception, + event: MessageEvent, + ctx: Context, + *, + handler_name: str = "", + plugin_id: str | None = None, + ) -> None: + if hasattr(owner, "on_error") and callable(owner.on_error): + bound_owner = owner if isinstance(owner, Star) else None + with bind_star_runtime(bound_owner, ctx): + result = owner.on_error(exc, event, ctx) + if inspect.isawaitable(result): + await result + return + await Star.default_on_error(exc, event, ctx) + + +__all__ = ["CapabilityDispatcher", "HandlerDispatcher"] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py new file mode 100644 index 0000000000..b32fe6e2da --- /dev/null 
+++ b/astrbot-sdk/src/astrbot_sdk/runtime/limiter.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import time +from collections import deque +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from ..decorators import LimiterMeta +from ..errors import AstrBotError + +DEFAULT_RATE_LIMIT_MESSAGE = "操作过于频繁,请稍后再试。" +DEFAULT_COOLDOWN_MESSAGE = "冷却中,请在 {remaining_seconds}s 后重试。" + + +@dataclass(slots=True) +class LimiterDecision: + allowed: bool + error: AstrBotError | None = None + hint: str | None = None + + +class LimiterEngine: + def __init__(self, *, clock: Callable[[], float] | None = None) -> None: + self._clock = clock or time.monotonic + self._windows: dict[str, deque[float]] = {} + + def evaluate( + self, + *, + plugin_id: str, + handler_id: str, + limiter: LimiterMeta, + event: Any, + ) -> LimiterDecision: + now = float(self._clock()) + key = self._make_key( + plugin_id=plugin_id, + handler_id=handler_id, + scope=limiter.scope, + event=event, + ) + bucket = self._windows.setdefault(key, deque()) + threshold = now - limiter.window + while bucket and bucket[0] <= threshold: + bucket.popleft() + + if len(bucket) < limiter.limit: + bucket.append(now) + return LimiterDecision(allowed=True) + + remaining = 0.0 + if bucket: + remaining = max(0.0, limiter.window - (now - bucket[0])) + hint = self._hint_text(limiter, remaining) + details = { + "scope": limiter.scope, + "handler_id": handler_id, + "remaining_seconds": round(remaining, 3), + } + if limiter.behavior == "silent": + return LimiterDecision(allowed=False) + if limiter.behavior == "error": + if limiter.kind == "cooldown": + return LimiterDecision( + allowed=False, + error=AstrBotError.cooldown_active(hint=hint, details=details), + ) + return LimiterDecision( + allowed=False, + error=AstrBotError.rate_limited(hint=hint, details=details), + ) + return LimiterDecision(allowed=False, hint=hint) + + @staticmethod + def _make_key( + *, + plugin_id: str, + 
handler_id: str, + scope: str, + event: Any, + ) -> str: + prefix = f"{plugin_id}:{handler_id}" + if scope == "global": + return prefix + if scope == "session": + return f"{prefix}:{getattr(event, 'session_id', '')}" + if scope == "user": + return ( + f"{prefix}:{getattr(event, 'platform_id', '')}" + f":{getattr(event, 'user_id', '')}" + ) + if scope == "group": + return ( + f"{prefix}:{getattr(event, 'platform_id', '')}" + f":{getattr(event, 'group_id', '')}" + ) + return prefix + + @staticmethod + def _hint_text(limiter: LimiterMeta, remaining: float) -> str: + if limiter.message: + return limiter.message.format( + remaining_seconds=max(1, int(remaining + 0.999)) + ) + if limiter.kind == "cooldown": + return DEFAULT_COOLDOWN_MESSAGE.format( + remaining_seconds=max(1, int(remaining + 0.999)) + ) + return DEFAULT_RATE_LIMIT_MESSAGE + + +__all__ = [ + "DEFAULT_COOLDOWN_MESSAGE", + "DEFAULT_RATE_LIMIT_MESSAGE", + "LimiterDecision", + "LimiterEngine", +] diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/loader.py b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py new file mode 100644 index 0000000000..822a6c13d7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/loader.py @@ -0,0 +1,1106 @@ +"""插件加载模块。 + +定义插件发现、环境管理和加载的核心逻辑。 +仅支持 v4 新版 Star 组件。 + +核心概念: + PluginSpec: 插件规范,描述插件的基本信息 + PluginDiscoveryResult: 插件发现结果,包含成功和跳过的插件 + PluginEnvironmentManager: 插件虚拟环境管理器 + LoadedHandler: 加载后的处理器,包含描述符和可调用对象 + LoadedPlugin: 加载后的插件,包含处理器和实例 + +插件发现流程: + 1. 扫描 plugins_dir 下的子目录 + 2. 检查 plugin.yaml 和 requirements.txt + 3. 解析 manifest_data 获取插件信息 + 4. 验证必要字段(name, components, runtime.python) + 5. 返回 PluginDiscoveryResult + +环境管理流程: + 1. 对插件集合做共享环境规划 + 2. 按 Python 版本和依赖兼容性构建环境分组 + 3. 为每个分组生成 lock/source/metadata 工件 + 4. 必要时重建或同步分组虚拟环境 + 5. 将单个插件映射到所属分组环境 + +插件加载流程: + 1. 将插件目录添加到 sys.path + 2. 遍历 components 列表 + 3. 动态导入组件类 + 4. 直接实例化(无参构造函数) + 5. 扫描处理器方法 + 6. 
构建 HandlerDescriptor + +plugin.yaml 格式: + name: my_plugin + author: author_name + desc: Plugin description + version: 1.0.0 + runtime: + python: "3.11" + components: + - class: my_plugin.main:MyComponent + +`loader` 是 runtime 与插件代码之间的边界层,负责三件事: + +- 从 `plugin.yaml` 解析出可运行的 `PluginSpec` +- 用 `uv` 为插件准备独立环境 +- 把组件实例和 handler 元数据整理成 `LoadedPlugin` +""" + +from __future__ import annotations + +import copy +import importlib +import inspect +import json +import logging +import os +import re +import shutil +import sys +import typing +from dataclasses import dataclass, field +from importlib import import_module +from pathlib import Path +from typing import Any, Literal, TypeAlias, cast + +import yaml + +from .._internal.command_model import resolve_command_model_param +from .._internal.injected_params import is_framework_injected_parameter +from .._internal.plugin_ids import validate_plugin_id +from .._internal.typing_utils import unwrap_optional +from ..decorators import ( + ConversationMeta, + LimiterMeta, + get_agent_meta, + get_capability_meta, + get_handler_meta, + get_llm_tool_meta, +) +from ..llm.agents import AgentSpec +from ..llm.entities import LLMToolSpec +from ..protocol.descriptors import ( + CapabilityDescriptor, + HandlerDescriptor, + ParamSpec, + ScheduleTrigger, +) +from ..types import GreedyStr +from .environment_groups import ( + EnvironmentGroup, + EnvironmentPlanner, + EnvironmentPlanResult, + GroupEnvironmentManager, +) + +PLUGIN_MANIFEST_FILE = "plugin.yaml" +STATE_FILE_NAME = ".astrbot-worker-state.json" +CONFIG_SCHEMA_FILE = "_conf_schema.json" +PLUGIN_METADATA_ATTR = "__astrbot_plugin_metadata__" +ParamTypeName: TypeAlias = Literal[ + "str", "int", "float", "bool", "optional", "greedy_str" +] +OptionalInnerType: TypeAlias = Literal["str", "int", "float", "bool"] | None +HandlerKind: TypeAlias = Literal["handler", "hook", "tool", "session"] +DiscoverySeverity: TypeAlias = Literal["warning", "error"] +DiscoveryPhase: TypeAlias = Literal["discovery", 
"load", "lifecycle", "reload"] +_LOGGER = logging.getLogger(__name__) + + +def _default_python_version() -> str: + return f"{sys.version_info.major}.{sys.version_info.minor}" + + +def _venv_python_path(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +@dataclass(slots=True) +class PluginSpec: + name: str + plugin_dir: Path + manifest_path: Path + requirements_path: Path + python_version: str + manifest_data: dict[str, Any] + + +@dataclass(slots=True) +class PluginDiscoveryResult: + plugins: list[PluginSpec] + skipped_plugins: dict[str, str] + issues: list[PluginDiscoveryIssue] = field(default_factory=list) + + +@dataclass(slots=True) +class PluginDiscoveryIssue: + severity: DiscoverySeverity + phase: DiscoveryPhase + plugin_id: str + message: str + details: str = "" + hint: str = "" + + def to_payload(self) -> dict[str, str]: + return { + "severity": self.severity, + "phase": self.phase, + "plugin_id": self.plugin_id, + "message": self.message, + "details": self.details, + "hint": self.hint, + } + + +@dataclass(slots=True) +class LoadedHandler: + descriptor: HandlerDescriptor + callable: Any + owner: Any + plugin_id: str = "" + local_filters: list[Any] = field(default_factory=list) + limiter: LimiterMeta | None = None + conversation: ConversationMeta | None = None + + +@dataclass(slots=True) +class LoadedCapability: + descriptor: CapabilityDescriptor + callable: Any + owner: Any + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedLLMTool: + spec: LLMToolSpec + callable: Any + owner: Any + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedAgent: + spec: AgentSpec + runner_class: type[Any] + owner: Any | None = None + plugin_id: str = "" + + +@dataclass(slots=True) +class LoadedPlugin: + plugin: PluginSpec + handlers: list[LoadedHandler] + capabilities: list[LoadedCapability] = field(default_factory=list) + llm_tools: list[LoadedLLMTool] = 
field(default_factory=list) + agents: list[LoadedAgent] = field(default_factory=list) + instances: list[Any] = field(default_factory=list) + + +@dataclass(slots=True) +class _ResolvedComponent: + cls: type[Any] + class_path: str + index: int + + +def _iter_handler_names(instance: Any) -> list[str]: + handler_names = getattr(instance.__class__, "__handlers__", ()) + if handler_names: + return list(handler_names) + return list(dir(instance)) + + +def _iter_discoverable_names(instance: Any) -> list[str]: + handler_names = list(dict.fromkeys(_iter_handler_names(instance))) + known_names = set(handler_names) + extra_names = sorted(name for name in dir(instance) if name not in known_names) + return [*handler_names, *extra_names] + + +def _is_injected_parameter(annotation: Any, parameter_name: str) -> bool: + return is_framework_injected_parameter(parameter_name, annotation) + + +def _param_type_name(annotation: Any) -> tuple[ParamTypeName, OptionalInnerType, bool]: + normalized, is_optional = unwrap_optional(annotation) + if normalized is GreedyStr: + return "greedy_str", None, False + if normalized in {int, float, bool, str}: + normalized_name = cast( + Literal["str", "int", "float", "bool"], normalized.__name__ + ) + if is_optional: + return "optional", normalized_name, False + return normalized_name, None, True + if is_optional: + return "optional", "str", False + return "str", None, True + + +def _build_param_specs(handler: Any) -> list[ParamSpec]: + model_param = resolve_command_model_param(handler) + if model_param is not None: + return [] + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = typing.get_type_hints(handler) + except Exception: + type_hints = {} + + specs: list[ParamSpec] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + annotation = 
type_hints.get(parameter.name) + if _is_injected_parameter(annotation, parameter.name): + continue + param_type, inner_type, required = _param_type_name(annotation) + if parameter.default is not inspect.Parameter.empty: + required = False + specs.append( + ParamSpec( + name=parameter.name, + type=param_type, + required=required, + inner_type=inner_type, + ) + ) + + greedy_indexes = [ + index for index, spec in enumerate(specs) if spec.type == "greedy_str" + ] + if greedy_indexes and greedy_indexes[-1] != len(specs) - 1: + greedy_spec = specs[greedy_indexes[-1]] + raise ValueError(f"参数 '{greedy_spec.name}' (GreedyStr) 必须是最后一个参数。") + return specs + + +def _validate_schedule_signature(handler: Any) -> None: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return + allowed_names = {"ctx", "context", "sched", "schedule"} + invalid = [ + parameter.name + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.name not in allowed_names + ] + if invalid: + raise ValueError( + "Schedule handler 只允许注入 ctx/context 和 sched/schedule 参数。" + ) + + +def _plugin_context(plugin: PluginSpec) -> str: + return f"插件 '{plugin.name}'({plugin.manifest_path})" + + +def _component_context(plugin: PluginSpec, *, class_path: str, index: int) -> str: + return f"{_plugin_context(plugin)} 的 components[{index}].class='{class_path}'" + + +def _resolve_handler_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + """解析 handler 名称,避免在扫描阶段触发无关 descriptor 副作用。""" + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + + for candidate in candidates: + meta = get_handler_meta(candidate) + if meta is not None and meta.trigger is not None: + return getattr(instance, name), meta + return 
None + + +def _resolve_capability_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + + for candidate in candidates: + meta = get_capability_meta(candidate) + if meta is not None: + return getattr(instance, name), meta + return None + + +def _resolve_llm_tool_candidate(instance: Any, name: str) -> tuple[Any, Any] | None: + try: + raw = inspect.getattr_static(instance, name) + except AttributeError: + return None + + candidates = [raw] + wrapped = getattr(raw, "__func__", None) + if wrapped is not None: + candidates.append(wrapped) + + for candidate in candidates: + meta = get_llm_tool_meta(candidate) + if meta is not None: + return getattr(instance, name), meta + return None + + +def _iter_agent_candidates(component_cls: type[Any]) -> list[tuple[type[Any], Any]]: + module = import_module(component_cls.__module__) + seen: set[str] = set() + resolved: list[tuple[type[Any], Any]] = [] + + def _collect(candidate: Any) -> None: + if not inspect.isclass(candidate): + return + meta = get_agent_meta(candidate) + if meta is None: + return + key = f"{candidate.__module__}.{candidate.__qualname__}" + if key in seen: + return + seen.add(key) + resolved.append((candidate, meta)) + + for candidate in vars(module).values(): + _collect(candidate) + for candidate in vars(component_cls).values(): + _collect(candidate) + return resolved + + +def _read_yaml(path: Path) -> dict[str, Any]: + data = yaml.safe_load(path.read_text(encoding="utf-8")) or {} + return data if isinstance(data, dict) else {} + + +def _read_requirements_text(path: Path) -> str: + if not path.exists(): + return "" + return path.read_text(encoding="utf-8") + + +def _plugin_config_dir(plugin_dir: Path) -> Path: + if plugin_dir.parent.name == "plugins" and plugin_dir.parent.parent.exists(): + 
return plugin_dir.parent.parent / "config" + return plugin_dir / "data" / "config" + + +def _plugin_config_path(plugin_dir: Path, plugin_name: str) -> Path: + return _plugin_config_dir(plugin_dir) / f"{plugin_name}_config.json" + + +def _schema_default(field_schema: dict[str, Any]) -> Any: + if "default" in field_schema: + return copy.deepcopy(field_schema["default"]) + + field_type = str(field_schema.get("type") or "string") + if field_type == "object": + items = field_schema.get("items") + if isinstance(items, dict): + return { + key: _normalize_config_value(child_schema, None) + for key, child_schema in items.items() + if isinstance(child_schema, dict) + } + return {} + if field_type in {"list", "template_list", "file"}: + return [] + if field_type == "dict": + return {} + if field_type == "int": + return 0 + if field_type == "float": + return 0.0 + if field_type == "bool": + return False + return "" + + +def _normalize_config_value(field_schema: dict[str, Any], value: Any) -> Any: + field_type = str(field_schema.get("type") or "string") + default_value = _schema_default(field_schema) + + if field_type == "object": + items = field_schema.get("items") + if not isinstance(items, dict): + return default_value + current = value if isinstance(value, dict) else {} + return { + key: _normalize_config_value(child_schema, current.get(key)) + for key, child_schema in items.items() + if isinstance(child_schema, dict) + } + if field_type in {"list", "template_list", "file"}: + return copy.deepcopy(value) if isinstance(value, list) else default_value + if field_type == "dict": + return copy.deepcopy(value) if isinstance(value, dict) else default_value + if field_type == "int": + return ( + value + if isinstance(value, int) and not isinstance(value, bool) + else default_value + ) + if field_type == "float": + return ( + value + if isinstance(value, (int, float)) and not isinstance(value, bool) + else default_value + ) + if field_type == "bool": + return value if 
isinstance(value, bool) else default_value + if field_type in {"string", "text"}: + return value if isinstance(value, str) else default_value + return copy.deepcopy(value) if value is not None else default_value + + +def load_plugin_config_schema(plugin: PluginSpec) -> dict[str, Any]: + """加载插件配置 schema,解析失败时记录日志并返回空对象。""" + schema_path = plugin.plugin_dir / CONFIG_SCHEMA_FILE + if not schema_path.exists(): + return {} + + try: + schema_payload = json.loads(schema_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + _LOGGER.warning( + "Failed to parse SDK plugin config schema %s: %s", + schema_path, + exc, + ) + return {} + except OSError as exc: + _LOGGER.warning( + "Failed to read SDK plugin config schema %s: %s", + schema_path, + exc, + ) + return {} + if not isinstance(schema_payload, dict): + _LOGGER.warning( + "SDK plugin config schema %s must be a JSON object, got %s", + schema_path, + type(schema_payload).__name__, + ) + return {} + return schema_payload + + +def save_plugin_config( + plugin: PluginSpec, + payload: dict[str, Any], + *, + schema: dict[str, Any] | None = None, +) -> dict[str, Any]: + """按 schema 归一化并写回插件配置。""" + active_schema = ( + load_plugin_config_schema(plugin) if schema is None else dict(schema) + ) + normalized = { + key: _normalize_config_value(field_schema, payload.get(key)) + for key, field_schema in active_schema.items() + if isinstance(field_schema, dict) + } + + config_path = _plugin_config_path(plugin.plugin_dir, plugin.name) + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text( + json.dumps(normalized, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + return normalized + + +def load_plugin_config( + plugin: PluginSpec, + *, + schema: dict[str, Any] | None = None, +) -> dict[str, Any]: + """加载插件配置,返回普通字典。""" + active_schema = ( + load_plugin_config_schema(plugin) if schema is None else dict(schema) + ) + if not active_schema: + return {} + + config_path = 
_plugin_config_path(plugin.plugin_dir, plugin.name) + try: + existing_payload = ( + json.loads(config_path.read_text(encoding="utf-8")) + if config_path.exists() + else {} + ) + except json.JSONDecodeError as exc: + _LOGGER.warning( + "Failed to parse SDK plugin config %s: %s", + config_path, + exc, + ) + existing_payload = {} + except OSError as exc: + _LOGGER.warning( + "Failed to read SDK plugin config %s: %s", + config_path, + exc, + ) + existing_payload = {} + existing = existing_payload if isinstance(existing_payload, dict) else {} + normalized = { + key: _normalize_config_value(field_schema, existing.get(key)) + for key, field_schema in active_schema.items() + if isinstance(field_schema, dict) + } + + if not config_path.exists() or normalized != existing: + save_plugin_config(plugin, normalized, schema=active_schema) + return normalized + + +def _is_new_star_component(cls: type[Any]) -> bool: + """检查组件类是否为 v4 新版 Star。""" + return bool(getattr(cls, "__astrbot_is_new_star__", False)) + + +def _plugin_component_classes(plugin: PluginSpec) -> list[_ResolvedComponent]: + """解析插件组件类列表。""" + components = plugin.manifest_data.get("components") or [] + if not isinstance(components, list): + return [] + + classes: list[_ResolvedComponent] = [] + for index, component in enumerate(components): + if not isinstance(component, dict): + raise ValueError( + f"{_plugin_context(plugin)} 的 components[{index}] 必须是 object。" + ) + class_path = component.get("class") + if not isinstance(class_path, str) or ":" not in class_path: + raise ValueError( + f"{_plugin_context(plugin)} 的 components[{index}].class " + "必须是 ':'。" + ) + try: + cls = import_string(class_path, plugin.plugin_dir) + except Exception as exc: + raise ValueError( + f"{_component_context(plugin, class_path=class_path, index=index)} " + f"加载失败:{exc}" + ) from exc + if not isinstance(cls, type): + raise ValueError( + f"{_component_context(plugin, class_path=class_path, index=index)} " + "解析结果不是类,请检查导出名称。" + ) + 
classes.append( + _ResolvedComponent( + cls=cls, + class_path=class_path, + index=index, + ) + ) + if not classes: + raise ValueError( + f"{_plugin_context(plugin)} 未声明任何可加载组件。" + "请检查 plugin.yaml 中的 components 配置。" + ) + return classes + + +def load_plugin_spec(plugin_dir: Path) -> PluginSpec: + """从插件目录加载插件规范。""" + plugin_dir = plugin_dir.resolve() + manifest_path = plugin_dir / PLUGIN_MANIFEST_FILE + requirements_path = plugin_dir / "requirements.txt" + + if not manifest_path.exists(): + raise ValueError(f"插件目录 '{plugin_dir}' 缺少 {PLUGIN_MANIFEST_FILE}。") + + manifest_data = _read_yaml(manifest_path) + runtime = manifest_data.get("runtime") or {} + python_version = runtime.get("python") or _default_python_version() + + return PluginSpec( + name=str(manifest_data.get("name") or plugin_dir.name), + plugin_dir=plugin_dir, + manifest_path=manifest_path, + requirements_path=requirements_path, + python_version=str(python_version), + manifest_data=manifest_data, + ) + + +def validate_plugin_spec(plugin: PluginSpec) -> None: + """校验单个插件规范,供 CLI 和发现流程复用。""" + manifest_data = plugin.manifest_data + manifest_label = f"插件 '{plugin.name}'({plugin.manifest_path})" + + raw_name = manifest_data.get("name") + if not isinstance(raw_name, str) or not raw_name: + raise ValueError(f"{manifest_label} 缺少 name。") + try: + validate_plugin_id(raw_name) + except ValueError as exc: + raise ValueError(f"{manifest_label} 的 name 不合法:{exc}") from exc + + raw_runtime = manifest_data.get("runtime") or {} + raw_python = raw_runtime.get("python") + if not isinstance(raw_python, str) or not raw_python: + raise ValueError(f"{manifest_label} 缺少 runtime.python。") + + components = manifest_data.get("components") + if not isinstance(components, list): + raise ValueError(f"{manifest_label} 的 components 必须是数组。") + + for index, component in enumerate(components): + if not isinstance(component, dict): + raise ValueError(f"{manifest_label} 的 components[{index}] 必须是 object。") + class_path = 
component.get("class") + if not isinstance(class_path, str) or ":" not in class_path: + raise ValueError( + f"{manifest_label} 的 components[{index}].class " + "必须是 ':'。" + ) + + +def discover_plugins(plugins_dir: Path) -> PluginDiscoveryResult: + """扫描目录发现所有插件。""" + plugins_root = plugins_dir.resolve() + skipped_plugins: dict[str, str] = {} + issues: list[PluginDiscoveryIssue] = [] + plugins: list[PluginSpec] = [] + seen_names: set[str] = set() + + if not plugins_root.exists(): + return PluginDiscoveryResult([], {}, []) + + for entry in sorted(plugins_root.iterdir()): + if not entry.is_dir() or entry.name.startswith("."): + continue + manifest_path = entry / PLUGIN_MANIFEST_FILE + if not manifest_path.exists(): + continue + + plugin: PluginSpec | None = None + try: + plugin = load_plugin_spec(entry) + validate_plugin_spec(plugin) + except Exception as exc: + skip_key = entry.name + if plugin is not None: + raw_name = plugin.manifest_data.get("name") + if isinstance(raw_name, str) and raw_name: + skip_key = raw_name + details = str(exc) + skipped_plugins[skip_key] = f"failed to parse plugin manifest: {details}" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=skip_key, + message="插件发现失败", + details=details, + ) + ) + continue + + plugin_name = plugin.name + if not isinstance(plugin_name, str) or not plugin_name: + skipped_plugins[entry.name] = "plugin name is required" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=entry.name, + message="插件缺少名称", + details="plugin name is required", + ) + ) + continue + if plugin_name in seen_names: + skipped_plugins[plugin_name] = "duplicate plugin name" + issues.append( + PluginDiscoveryIssue( + severity="error", + phase="discovery", + plugin_id=plugin_name, + message="插件名称重复", + details="duplicate plugin name", + ) + ) + continue + seen_names.add(plugin_name) + plugins.append(plugin) + + return PluginDiscoveryResult( + plugins=plugins, + 
skipped_plugins=skipped_plugins, + issues=issues, + ) + + +class PluginEnvironmentManager: + """运行时访问分组环境管理的门面层。 + + 运行时仍然保留历史上的 `prepare_environment(plugin)` 调用入口,但底层 + 实现已经变成两阶段模型: + + 1. `plan()` 负责解析跨插件分组和共享工件 + 2. `prepare_environment()` 负责把单个插件映射到它所属的分组环境 + """ + + def __init__(self, repo_root: Path, uv_binary: str | None = None) -> None: + self.repo_root = repo_root.resolve() + self.uv_binary = uv_binary + self.cache_dir = self.repo_root / ".uv-cache" + self._planner = EnvironmentPlanner(self.repo_root, uv_binary=uv_binary) + self._group_manager = GroupEnvironmentManager( + self.repo_root, uv_binary=uv_binary + ) + self.uv_binary = self._planner.uv_binary + self._plan_result: EnvironmentPlanResult | None = None + + def plan(self, plugins: list[PluginSpec]) -> EnvironmentPlanResult: + """为当前插件集合生成共享环境规划。""" + plan_result = self._planner.plan(plugins) + self._plan_result = plan_result + return plan_result + + def prepare_group_environment(self, group: EnvironmentGroup) -> Path: + """返回指定分组的解释器路径。""" + if self._plan_result is None: + self._plan_result = EnvironmentPlanResult(groups=[group]) + return self._group_manager.prepare(group) + + def prepare_environment(self, plugin: PluginSpec) -> Path: + """返回该插件所属分组环境的解释器路径。 + + 如果调用方还没有先对整批插件做规划,这里会自动创建一个至少包含当 + 前插件的最小规划,以保证旧的"单插件直接调用"模式仍然可用。 + """ + if ( + self._plan_result is None + or plugin.name not in self._plan_result.plugin_to_group + ): + planned_plugins = ( + list(self._plan_result.plugins) if self._plan_result else [] + ) + if plugin.name not in {item.name for item in planned_plugins}: + planned_plugins.append(plugin) + self.plan(planned_plugins) + + assert self._plan_result is not None + group = self._plan_result.plugin_to_group.get(plugin.name) + if group is None: + reason = self._plan_result.skipped_plugins.get(plugin.name) + if reason is not None: + raise RuntimeError(reason) + raise RuntimeError(f"environment plan missing plugin: {plugin.name}") + + return self.prepare_group_environment(group) + + 
@staticmethod + def _fingerprint(plugin: PluginSpec) -> str: + requirements = _read_requirements_text(plugin.requirements_path) + payload = { + "python_version": plugin.python_version, + "requirements": requirements, + } + return json.dumps(payload, ensure_ascii=True, sort_keys=True) + + @staticmethod + def _load_state(state_path: Path) -> dict[str, Any]: + if not state_path.exists(): + return {} + try: + data = json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return {} + return data if isinstance(data, dict) else {} + + @staticmethod + def _write_state(state_path: Path, plugin: PluginSpec, fingerprint: str) -> None: + state_path.write_text( + json.dumps( + { + "plugin": plugin.name, + "python_version": plugin.python_version, + "fingerprint": fingerprint, + }, + ensure_ascii=True, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + + @staticmethod + def _matches_python_version(venv_dir: Path, version: str) -> bool: + pyvenv_cfg = venv_dir / "pyvenv.cfg" + if not pyvenv_cfg.exists(): + return False + try: + content = pyvenv_cfg.read_text(encoding="utf-8") + except OSError: + return False + match = re.search(r"version\s*=\s*(\d+\.\d+)\.\d+", content, re.IGNORECASE) + return match is not None and match.group(1) == version + + +def load_plugin(plugin: PluginSpec) -> LoadedPlugin: + """加载插件,返回处理器和能力列表。 + + 仅支持 v4 新版 Star 组件(无参构造函数)。 + """ + plugin_path = str(plugin.plugin_dir) + if plugin_path not in sys.path: + sys.path.insert(0, plugin_path) + _purge_plugin_bytecode(plugin.plugin_dir) + _purge_plugin_modules(plugin.plugin_dir) + + instances: list[Any] = [] + handlers: list[LoadedHandler] = [] + capabilities: list[LoadedCapability] = [] + llm_tools: list[LoadedLLMTool] = [] + agents: list[LoadedAgent] = [] + seen_agents: set[str] = set() + + for resolved_component in _plugin_component_classes(plugin): + component_cls = resolved_component.cls + if not _is_new_star_component(component_cls): + raise ValueError( + 
f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"解析到的类 {component_cls.__module__}.{component_cls.__qualname__} " + "不是 v4 Star 组件。请继承 astrbot_sdk.Star。" + ) + try: + instance = component_cls() + except Exception as exc: + raise ValueError( + f"{_component_context(plugin, class_path=resolved_component.class_path, index=resolved_component.index)} " + f"实例化失败:{exc}" + ) from exc + instances.append(instance) + + for runner_class, meta in _iter_agent_candidates(component_cls): + runner_key = f"{runner_class.__module__}.{runner_class.__qualname__}" + if runner_key in seen_agents: + continue + seen_agents.add(runner_key) + agents.append( + LoadedAgent( + spec=meta.spec.model_copy(deep=True), + runner_class=runner_class, + owner=None, + plugin_id=plugin.name, + ) + ) + + for name in _iter_discoverable_names(instance): + resolved = _resolve_handler_candidate(instance, name) + capability = _resolve_capability_candidate(instance, name) + llm_tool = _resolve_llm_tool_candidate(instance, name) + if resolved is None and capability is None and llm_tool is None: + continue + if capability is not None: + bound, meta = capability + capabilities.append( + LoadedCapability( + descriptor=meta.descriptor.model_copy(deep=True), + callable=bound, + owner=instance, + plugin_id=plugin.name, + ) + ) + if llm_tool is not None: + bound_tool, tool_meta = llm_tool + llm_tools.append( + LoadedLLMTool( + spec=tool_meta.spec.model_copy(deep=True), + callable=bound_tool, + owner=instance, + plugin_id=plugin.name, + ), + ) + if resolved is not None: + bound, meta = resolved + handler_id = f"{plugin.name}:{instance.__class__.__module__}.{instance.__class__.__name__}.{name}" + if isinstance(meta.trigger, ScheduleTrigger): + _validate_schedule_signature(bound) + param_specs = _build_param_specs(bound) + handlers.append( + LoadedHandler( + descriptor=HandlerDescriptor( + id=handler_id, + trigger=meta.trigger, + kind=cast(HandlerKind, 
meta.kind), + contract=meta.contract, + description=meta.description, + priority=meta.priority, + permissions=meta.permissions.model_copy(deep=True), + filters=[ + item.model_copy(deep=True) for item in meta.filters + ], + param_specs=[ + item.model_copy(deep=True) for item in param_specs + ], + command_route=( + meta.command_route.model_copy(deep=True) + if meta.command_route is not None + else None + ), + ), + callable=bound, + owner=instance, + plugin_id=plugin.name, + local_filters=list(meta.local_filters), + limiter=( + None + if meta.limiter is None + else LimiterMeta( + kind=meta.limiter.kind, + limit=meta.limiter.limit, + window=meta.limiter.window, + scope=meta.limiter.scope, + behavior=meta.limiter.behavior, + message=meta.limiter.message, + ) + ), + conversation=( + None + if meta.conversation is None + else ConversationMeta( + timeout=meta.conversation.timeout, + mode=meta.conversation.mode, + busy_message=meta.conversation.busy_message, + grace_period=meta.conversation.grace_period, + ) + ), + ) + ) + + return LoadedPlugin( + plugin=plugin, + handlers=handlers, + capabilities=capabilities, + llm_tools=llm_tools, + agents=agents, + instances=instances, + ) + + +def _path_within_root(path: Path, root: Path) -> bool: + try: + path.resolve().relative_to(root.resolve()) + except ValueError: + return False + return True + + +def _plugin_defines_module_root(plugin_dir: Path, root_name: str) -> bool: + return (plugin_dir / f"{root_name}.py").exists() or ( + plugin_dir / root_name + ).exists() + + +def _module_belongs_to_plugin(module: Any, plugin_dir: Path) -> bool: + file_path = getattr(module, "__file__", None) + if isinstance(file_path, str) and _path_within_root(Path(file_path), plugin_dir): + return True + + package_paths = getattr(module, "__path__", None) + if package_paths is None: + return False + return any( + isinstance(candidate, str) and _path_within_root(Path(candidate), plugin_dir) + for candidate in package_paths + ) + + +def 
_purge_plugin_modules(plugin_dir: Path) -> None: + plugin_root = plugin_dir.resolve() + for module_name, module in list(sys.modules.items()): + if module is None: + continue + if _module_belongs_to_plugin(module, plugin_root): + sys.modules.pop(module_name, None) + + +def _purge_plugin_bytecode(plugin_dir: Path) -> None: + plugin_root = plugin_dir.resolve() + for path in plugin_root.rglob("*"): + try: + if path.is_dir() and path.name == "__pycache__": + shutil.rmtree(path, ignore_errors=True) + continue + if path.is_file() and path.suffix in {".pyc", ".pyo"}: + path.unlink(missing_ok=True) + except OSError: + continue + + +def _purge_module_root(root_name: str) -> None: + for module_name in list(sys.modules): + if module_name == root_name or module_name.startswith(f"{root_name}."): + sys.modules.pop(module_name, None) + + +def _prepare_plugin_import(module_name: str, plugin_dir: Path | None) -> None: + if plugin_dir is None: + return + + plugin_root = plugin_dir.resolve() + plugin_path = str(plugin_root) + sys.path[:] = [entry for entry in sys.path if entry != plugin_path] + sys.path.insert(0, plugin_path) + + root_name = module_name.split(".", 1)[0] + if not _plugin_defines_module_root(plugin_root, root_name): + return + + cached_root = sys.modules.get(root_name) + cached_module = sys.modules.get(module_name) + if cached_root is not None and not _module_belongs_to_plugin( + cached_root, plugin_root + ): + _purge_module_root(root_name) + elif cached_module is not None and not _module_belongs_to_plugin( + cached_module, plugin_root + ): + _purge_module_root(root_name) + + importlib.invalidate_caches() + + +def import_string(path: str, plugin_dir: Path | None = None) -> Any: + """通过字符串路径导入对象。""" + module_name, attr = path.split(":", 1) + _prepare_plugin_import(module_name, plugin_dir) + module = import_module(module_name) + return getattr(module, attr) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/peer.py b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py new file mode 
100644 index 0000000000..6259c50abb --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/peer.py @@ -0,0 +1,775 @@ +"""协议对等端模块。 + +定义 Peer 类,封装双向传输通道上的消息收发、初始化握手、能力调用、 +流式事件转发与取消处理。这里的 peer 指"通信对端/本端"这一网络协议概念, +而不是业务上的用户、群聊或会话对象。 + +核心职责: + - 消息序列化/反序列化 + - 初始化握手协议 + - 能力调用(同步/流式) + - 取消处理 + - 连接生命周期管理 +消息处理: + 入站: + ResultMessage -> 唤醒等待的 Future + EventMessage -> 投递到流式队列 + InitializeMessage -> 调用 initialize_handler + InvokeMessage -> 创建任务调用 invoke_handler + CancelMessage -> 取消对应的任务 + + 出站: + initialize() -> InitializeMessage + invoke() -> InvokeMessage(stream=False) + invoke_stream() -> InvokeMessage(stream=True) + cancel() -> CancelMessage + +使用示例: + # 作为客户端发起调用 + peer = Peer(transport=transport, peer_info=PeerInfo(...)) + await peer.start() + output = await peer.initialize(handlers) + result = await peer.invoke("llm.chat", {"prompt": "hello"}) + + # 作为服务端处理调用 + peer.set_invoke_handler(my_handler) + await peer.start() + +消息处理流程: + 入站消息: + ResultMessage -> 唤醒等待的 Future + EventMessage -> 投递到流式队列 + InitializeMessage -> 调用 _initialize_handler + InvokeMessage -> 创建任务调用 _invoke_handler + CancelMessage -> 取消对应的任务 + + 出站消息: + initialize() -> InitializeMessage + invoke() -> InvokeMessage(stream=False) + invoke_stream() -> InvokeMessage(stream=True) + cancel() -> CancelMessage + +取消机制: + - CancelToken 用于检查取消状态 + - 入站任务在收到 CancelMessage 时被取消 + - 早到取消:在任务执行前检查 cancel_token,避免竞态条件 + +`Peer` 把 `Transport` 和 v4 协议消息模型接起来,负责: + +- 握手与远端元数据缓存 +- 请求 ID 关联 +- 非流式 / 流式调用分发 +- 取消传播 +- 连接异常时的统一收口 + +它本身不做业务路由,真正的执行逻辑交给 `CapabilityRouter` 或 +`HandlerDispatcher`。 +""" + +from __future__ import annotations + +import asyncio +import inspect +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from typing import Any + +from loguru import logger + +from .._internal.invocation_context import ( + caller_plugin_scope, + current_caller_plugin_id, +) +from ..context import CancelToken +from ..errors import AstrBotError, ErrorCodes +from ..protocol.messages import ( + 
CancelMessage, + ErrorPayload, + EventMessage, + InitializeMessage, + InitializeOutput, + InvokeMessage, + PeerInfo, + ResultMessage, + parse_message, +) +from .capability_router import StreamExecution + +InitializeHandler = Callable[[InitializeMessage], Awaitable[InitializeOutput]] +InvokeHandler = Callable[ + [InvokeMessage, CancelToken], Awaitable[dict[str, Any] | StreamExecution] +] +CancelHandler = Callable[[str], Awaitable[None]] + +SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY = "supported_protocol_versions" +NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY = "negotiated_protocol_version" + + +def _dedupe_protocol_versions( + versions: Sequence[str] | None, *, preferred_version: str +) -> list[str]: + ordered_versions: list[str] = [preferred_version] + if versions is not None: + ordered_versions.extend(versions) + deduped: list[str] = [] + for version in ordered_versions: + if not isinstance(version, str) or not version: + continue + if version not in deduped: + deduped.append(version) + return deduped + + +def _parse_protocol_version(version: str) -> tuple[int, int] | None: + major, dot, minor = version.partition(".") + if not dot or not major.isdigit() or not minor.isdigit(): + return None + return int(major), int(minor) + + +def _select_negotiated_protocol_version( + requested_version: str, + remote_metadata: dict[str, Any], + local_supported_versions: Sequence[str], +) -> str | None: + if requested_version in local_supported_versions: + return requested_version + requested_key = _parse_protocol_version(requested_version) + if requested_key is None: + return None + remote_supported = remote_metadata.get(SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY) + if not isinstance(remote_supported, (list, tuple)): + return None + local_supported_set = set(local_supported_versions) + compatible_versions: list[tuple[tuple[int, int], str]] = [] + for version in remote_supported: + if not isinstance(version, str) or version not in local_supported_set: + continue + parsed_version = 
_parse_protocol_version(version) + if parsed_version is None: + continue + if parsed_version[0] != requested_key[0] or parsed_version > requested_key: + continue + compatible_versions.append((parsed_version, version)) + if not compatible_versions: + return None + compatible_versions.sort(reverse=True) + return compatible_versions[0][1] + + +class Peer: + """表示协议连接中的一个对等端。 + + `Peer` 封装一条双向传输通道上的消息收发、初始化握手、能力调用、 + 流式事件转发与取消处理。这里的 `peer` 指“通信对端/本端”这一网络 + 协议概念,而不是业务上的用户、群聊或会话对象。 + """ + + def __init__( + self, + *, + transport, + peer_info: PeerInfo, + protocol_version: str = "1.0", + supported_protocol_versions: Sequence[str] | None = None, + ) -> None: + """创建一个协议对等端实例。 + + Args: + transport: 底层传输实现,负责发送字符串消息并回调入站消息。 + peer_info: 当前端点对外声明的身份信息。 + protocol_version: 当前端点首选的协议版本,用于初始化握手。 + supported_protocol_versions: 当前端点可接受的协议版本列表。 + """ + self.transport = transport + self.peer_info = peer_info + self.protocol_version = protocol_version + self.supported_protocol_versions = _dedupe_protocol_versions( + supported_protocol_versions, + preferred_version=protocol_version, + ) + self.negotiated_protocol_version: str | None = None + self.remote_peer: PeerInfo | None = None + self.remote_handlers = [] + self.remote_provided_capabilities = [] + self.remote_capabilities = [] + self.remote_capability_map: dict[str, Any] = {} + self.remote_provided_capability_map: dict[str, Any] = {} + self.remote_metadata: dict[str, Any] = {} + + self._initialize_handler: InitializeHandler | None = None + self._invoke_handler: InvokeHandler | None = None + self._cancel_handler: CancelHandler | None = None + self._counter = 0 + self._closed = asyncio.Event() + self._unusable = False + self._stopping = False + self._pending_results: dict[str, asyncio.Future[ResultMessage]] = {} + self._pending_streams: dict[str, asyncio.Queue[Any]] = {} + self._inbound_tasks: dict[ + str, tuple[asyncio.Task[None], CancelToken, asyncio.Event] + ] = {} + self._remote_initialized = asyncio.Event() + 
self._remote_initialized_successfully = False + self._transport_watch_task: asyncio.Task[None] | None = None + + def set_initialize_handler(self, handler: InitializeHandler) -> None: + """注册处理远端 `initialize` 请求的握手处理器。""" + self._initialize_handler = handler + + def set_invoke_handler(self, handler: InvokeHandler) -> None: + """注册处理远端 `invoke` 请求的能力调用处理器。""" + self._invoke_handler = handler + + def set_cancel_handler(self, handler: CancelHandler) -> None: + """注册处理远端 `cancel` 请求的取消回调。""" + self._cancel_handler = handler + + async def start(self) -> None: + """启动传输层并将原始入站消息绑定到当前 `Peer`。""" + self._closed.clear() + self._unusable = False + self._stopping = False + self.negotiated_protocol_version = None + self._remote_initialized.clear() + self._remote_initialized_successfully = False + self.transport.set_message_handler(self._handle_raw_message) + await self.transport.start() + self._transport_watch_task = asyncio.create_task(self._watch_transport_closed()) + + async def stop(self) -> None: + """关闭 `Peer` 并清理所有挂起中的请求、流和入站任务。""" + if self._closed.is_set(): + return + self._stopping = True + # 终止所有挂起的 RPC,避免调用方永久挂起 + for future in list(self._pending_results.values()): + if not future.done(): + future.set_exception(AstrBotError.internal_error("连接已关闭")) + self._pending_results.clear() + + for queue in list(self._pending_streams.values()): + await queue.put(AstrBotError.internal_error("连接已关闭")) + self._pending_streams.clear() + + # 取消所有入站任务 + for task, token, _started in list(self._inbound_tasks.values()): + token.cancel() + task.cancel() + self._inbound_tasks.clear() + + await self.transport.stop() + self._closed.set() + + async def wait_closed(self) -> None: + """等待底层传输彻底关闭。""" + await self.transport.wait_closed() + + async def _watch_transport_closed(self) -> None: + """监视底层传输的意外关闭,并主动失败挂起调用。""" + try: + await self.transport.wait_closed() + if self._closed.is_set() or self._stopping: + return + await self._fail_connection( + AstrBotError( + 
code=ErrorCodes.NETWORK_ERROR, + message="连接已关闭", + hint="请检查对端进程或传输连接", + retryable=True, + ) + ) + finally: + current_task = asyncio.current_task() + if self._transport_watch_task is current_task: + self._transport_watch_task = None + + async def wait_until_remote_initialized(self, timeout: float | None = 30.0) -> None: + """等待远端完成初始化握手。 + + Args: + timeout: 等待秒数。传入 `None` 表示无限等待。 + """ + init_waiter = asyncio.create_task(self._remote_initialized.wait()) + closed_waiter = asyncio.create_task(self.wait_closed()) + try: + done, pending = await asyncio.wait( + {init_waiter, closed_waiter}, + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + if not done: + raise TimeoutError() + if init_waiter in done and self._remote_initialized_successfully: + return + raise AstrBotError.protocol_error("连接在初始化完成前关闭") + finally: + for task in (init_waiter, closed_waiter): + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def initialize( + self, + handlers, + *, + provided_capabilities=None, + metadata: dict[str, Any] | None = None, + ) -> InitializeOutput: + """向远端发送初始化请求并缓存远端声明的能力信息。 + + Args: + handlers: 当前端点声明可接收的处理器列表。 + metadata: 附带给远端的握手元数据。 + + Returns: + 远端返回的初始化结果。 + """ + self._ensure_usable() + request_id = self._next_id() + handshake_metadata = dict(metadata or {}) + handshake_metadata[SUPPORTED_PROTOCOL_VERSIONS_METADATA_KEY] = list( + self.supported_protocol_versions + ) + future: asyncio.Future[ResultMessage] = ( + asyncio.get_running_loop().create_future() + ) + self._pending_results[request_id] = future + await self._send( + InitializeMessage( + id=request_id, + protocol_version=self.protocol_version, + peer=self.peer_info, + handlers=list(handlers), + provided_capabilities=list(provided_capabilities or []), + metadata=handshake_metadata, + ) + ) + result = await future + if result.kind != "initialize_result": + raise AstrBotError.protocol_error("initialize 必须收到 initialize_result") + if not 
result.success: + self._unusable = True + await self.stop() + raise AstrBotError.from_payload( + result.error.model_dump() if result.error else {} + ) + output = InitializeOutput.model_validate(result.output) + negotiated_protocol_version = ( + output.protocol_version + or output.metadata.get(NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY) + or self.protocol_version + ) + if ( + not isinstance(negotiated_protocol_version, str) + or negotiated_protocol_version not in self.supported_protocol_versions + ): + self._unusable = True + await self.stop() + raise AstrBotError.protocol_version_mismatch( + f"对端返回了当前端点不支持的协商协议版本:{negotiated_protocol_version}" + ) + self.remote_peer = output.peer + self.remote_capabilities = output.capabilities + self.remote_capability_map = {item.name: item for item in output.capabilities} + self.remote_metadata = output.metadata + self.negotiated_protocol_version = negotiated_protocol_version + self._remote_initialized_successfully = True + self._remote_initialized.set() + return output + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: + """发起一次非流式能力调用并等待最终结果。 + + Args: + capability: 远端能力名。 + payload: 调用输入。 + stream: 必须为 `False`;流式场景应改用 `invoke_stream()`。 + request_id: 可选的请求 ID;未提供时自动生成。 + """ + self._ensure_usable() + if stream: + raise ValueError("stream=True 请使用 invoke_stream()") + request_id = request_id or self._next_id() + future: asyncio.Future[ResultMessage] = ( + asyncio.get_running_loop().create_future() + ) + self._pending_results[request_id] = future + await self._send( + InvokeMessage( + id=request_id, + capability=capability, + input=payload, + stream=False, + caller_plugin_id=current_caller_plugin_id(), + ) + ) + result = await future + if not result.success: + raise AstrBotError.from_payload( + result.error.model_dump() if result.error else {} + ) + return result.output + + async def invoke_stream( + self, + capability: 
str, + payload: dict[str, Any], + *, + request_id: str | None = None, + include_completed: bool = False, + ) -> AsyncIterator[EventMessage]: + """发起一次流式能力调用并返回事件迭代器。 + + 调用方会收到 `delta` 事件,`started` 会被内部吞掉, + 默认情况下 `completed` 用于结束迭代,`failed` 会转换为异常抛出。 + + Args: + capability: 远端能力名。 + payload: 调用输入。 + request_id: 可选的请求 ID;未提供时自动生成。 + include_completed: 是否把 `completed` 事件也返回给调用方。 + """ + self._ensure_usable() + request_id = request_id or self._next_id() + queue: asyncio.Queue[Any] = asyncio.Queue() + self._pending_streams[request_id] = queue + await self._send( + InvokeMessage( + id=request_id, + capability=capability, + input=payload, + stream=True, + caller_plugin_id=current_caller_plugin_id(), + ) + ) + + async def iterator() -> AsyncIterator[EventMessage]: + try: + while True: + item = await queue.get() + if isinstance(item, Exception): + raise item + if not isinstance(item, EventMessage): + raise AstrBotError.protocol_error("流式调用收到非法事件") + if item.phase == "started": + continue + if item.phase == "delta": + yield item + continue + if item.phase == "completed": + if include_completed: + yield item + break + if item.phase == "failed": + raise AstrBotError.from_payload( + item.error.model_dump() if item.error else {} + ) + finally: + self._pending_streams.pop(request_id, None) + + return iterator() + + async def cancel(self, request_id: str, reason: str = "user_cancelled") -> None: + """向远端发送取消请求,尝试中止指定 ID 的在途调用。""" + await self._send(CancelMessage(id=request_id, reason=reason)) + + def _next_id(self) -> str: + """生成当前连接内递增的消息 ID。""" + self._counter += 1 + return f"msg_{self._counter:04d}" + + def _ensure_usable(self) -> None: + """确保连接仍处于可用状态,否则立即抛出协议错误。""" + if self._unusable: + raise AstrBotError.protocol_error("连接已进入不可用状态") + + async def _handle_raw_message(self, payload: str) -> None: + """解析原始消息并分发到对应的消息处理分支。""" + try: + message = parse_message(payload) + if isinstance(message, ResultMessage): + await self._handle_result(message) + return + if 
isinstance(message, EventMessage): + await self._handle_event(message) + return + if isinstance(message, InitializeMessage): + await self._handle_initialize(message) + return + if isinstance(message, InvokeMessage): + token = CancelToken() + started = asyncio.Event() + task = asyncio.create_task(self._handle_invoke(message, token, started)) + self._inbound_tasks[message.id] = (task, token, started) + + def _on_invoke_done( + _task: asyncio.Task[None], request_id: str = message.id + ) -> None: + self._inbound_tasks.pop(request_id, None) + if _task.cancelled(): + return + exc = _task.exception() + if exc is None: + return + # 后台 invoke 理论上应把错误编码成协议消息;若异常仍逃逸,通常说明 + # 回复发送失败或连接状态异常,必须立刻标记连接失效,避免对端永久等待。 + logger.error( + "Peer inbound invoke task crashed unexpectedly: " + "request_id={} error={!r}", + request_id, + exc, + ) + error = ( + exc + if isinstance(exc, AstrBotError) + else AstrBotError( + code=ErrorCodes.NETWORK_ERROR, + message="处理入站调用响应时连接已失效", + hint=str(exc), + retryable=True, + ) + ) + asyncio.create_task(self._fail_connection(error)) + + task.add_done_callback(_on_invoke_done) + return + if isinstance(message, CancelMessage): + await self._handle_cancel(message) + return + except Exception as exc: + if isinstance(exc, AstrBotError): + error = exc + else: + error = AstrBotError.protocol_error(f"无法解析协议消息: {exc}") + await self._fail_connection(error) + raise error from exc + + async def _handle_initialize(self, message: InitializeMessage) -> None: + """处理远端发起的初始化握手并返回握手结果。""" + self.remote_peer = message.peer + self.remote_handlers = message.handlers + self.remote_provided_capabilities = message.provided_capabilities + self.remote_provided_capability_map = { + item.name: item for item in message.provided_capabilities + } + self.remote_metadata = dict(message.metadata) + if self._initialize_handler is None: + await self._reject_initialize( + message, + AstrBotError.protocol_error("对端不接受 initialize"), + ) + return + + negotiated_protocol_version = 
_select_negotiated_protocol_version( + message.protocol_version, + self.remote_metadata, + self.supported_protocol_versions, + ) + if negotiated_protocol_version is None: + supported_versions = ", ".join(self.supported_protocol_versions) + await self._reject_initialize( + message, + AstrBotError.protocol_version_mismatch( + "服务端支持协议版本 " + f"{supported_versions},客户端请求版本 {message.protocol_version}" + ), + ) + return + + self.negotiated_protocol_version = negotiated_protocol_version + self.remote_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = ( + negotiated_protocol_version + ) + output = await self._initialize_handler(message) + response_metadata = dict(output.metadata) + response_metadata[NEGOTIATED_PROTOCOL_VERSION_METADATA_KEY] = ( + negotiated_protocol_version + ) + output = output.model_copy( + update={ + "protocol_version": negotiated_protocol_version, + "metadata": response_metadata, + } + ) + await self._send( + ResultMessage( + id=message.id, + kind="initialize_result", + success=True, + output=output.model_dump(), + ) + ) + self._remote_initialized_successfully = True + self._remote_initialized.set() + + async def _handle_invoke( + self, + message: InvokeMessage, + token: CancelToken, + started: asyncio.Event, + ) -> None: + """处理远端发起的能力调用,并按流式或非流式协议返回结果。""" + try: + started.set() + token.raise_if_cancelled() + if self._invoke_handler is None: + raise AstrBotError.capability_not_found(message.capability) + with caller_plugin_scope(message.caller_plugin_id): + execution = await self._invoke_handler(message, token) + if inspect.isawaitable(execution): + execution = await execution + if message.stream: + if not isinstance(execution, StreamExecution): + raise AstrBotError.protocol_error( + "stream=true 必须返回 StreamExecution" + ) + await self._send(EventMessage(id=message.id, phase="started")) + collect_chunks = execution.collect_chunks + chunks: list[dict[str, Any]] = [] + async for chunk in execution.iterator: + if collect_chunks: + chunks.append(chunk) + 
await self._send( + EventMessage(id=message.id, phase="delta", data=chunk) + ) + await self._send( + EventMessage( + id=message.id, + phase="completed", + output=execution.finalize(chunks), + ) + ) + return + if isinstance(execution, StreamExecution): + raise AstrBotError.protocol_error("stream=false 不能返回流式执行对象") + await self._send( + ResultMessage(id=message.id, success=True, output=execution) + ) + except asyncio.CancelledError: + await self._send_cancelled_termination(message) + except LookupError as exc: + error = AstrBotError.invalid_input(str(exc)) + await self._send_error_result(message, error) + except AstrBotError as exc: + await self._send_error_result(message, exc) + except Exception as exc: + await self._send_error_result( + message, AstrBotError.internal_error(str(exc)) + ) + + async def _handle_cancel(self, message: CancelMessage) -> None: + """处理远端取消请求并终止对应的入站任务。""" + inbound = self._inbound_tasks.get(message.id) + if inbound is None: + return + task, token, started = inbound + token.cancel() + if self._cancel_handler is not None: + await self._cancel_handler(message.id) + if started.is_set(): + task.cancel() + + async def _handle_result(self, message: ResultMessage) -> None: + """处理非流式结果消息并唤醒等待中的调用方。""" + future = self._pending_results.pop(message.id, None) + if future is None: + queue = self._pending_streams.get(message.id) + if queue is not None: + await queue.put( + AstrBotError.protocol_error("stream=true 调用不应收到 result") + ) + return + # 检查 future 是否已完成(可能被调用方取消) + if not future.done(): + future.set_result(message) + + async def _handle_event(self, message: EventMessage) -> None: + """处理流式事件消息并投递到对应请求的事件队列。""" + queue = self._pending_streams.get(message.id) + if queue is None: + future = self._pending_results.get(message.id) + if future is not None and not future.done(): + future.set_exception( + AstrBotError.protocol_error("stream=false 调用不应收到 event") + ) + return + await queue.put(message) + + async def _send_error_result( + self, message: 
InvokeMessage, error: AstrBotError + ) -> None: + """根据调用模式,将错误编码为 `result` 或失败事件发回远端。""" + if message.stream: + await self._send( + EventMessage( + id=message.id, + phase="failed", + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + return + await self._send( + ResultMessage( + id=message.id, + success=False, + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + + async def _reject_initialize( + self, message: InitializeMessage, error: AstrBotError + ) -> None: + """拒绝一次初始化握手,并把连接标记为不可继续使用。""" + await self._send( + ResultMessage( + id=message.id, + kind="initialize_result", + success=False, + error=ErrorPayload.model_validate(error.to_payload()), + ) + ) + self._unusable = True + self._remote_initialized.set() + await self.stop() + + async def _send_cancelled_termination(self, message: InvokeMessage) -> None: + """把本端取消执行转换为标准化的取消错误响应。""" + error = AstrBotError.cancelled() + await self._send_error_result(message, error) + + async def _fail_connection(self, error: AstrBotError) -> None: + """把连接标记为不可用,并让所有等待中的调用尽快失败。""" + if self._unusable: + return + self._unusable = True + self._remote_initialized.set() + + for future in list(self._pending_results.values()): + if not future.done(): + future.set_exception(error) + self._pending_results.clear() + + for queue in list(self._pending_streams.values()): + await queue.put(error) + self._pending_streams.clear() + + for task, token, _started in list(self._inbound_tasks.values()): + token.cancel() + task.cancel() + self._inbound_tasks.clear() + + asyncio.create_task(self.stop()) + + async def _send(self, message) -> None: + """序列化协议消息并通过底层传输发送出去。""" + await self.transport.send(message.model_dump_json(exclude_none=True)) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py new file mode 100644 index 0000000000..1727f218f1 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/supervisor.py @@ -0,0 +1,871 @@ +"""Supervisor 
端运行时:SupervisorRuntime 管理多个 Worker 进程,WorkerSession 封装与单个 Worker 的通信。 + +架构层次: + AstrBot Core (Python) + | + v + SupervisorRuntime (管理多插件) + | + +-- WorkerSession (插件 A) -- StdioTransport -- PluginWorkerRuntime (子进程) + | + +-- WorkerSession (插件 B) -- StdioTransport -- PluginWorkerRuntime (子进程) + | + +-- WorkerSession (插件 C) -- StdioTransport -- PluginWorkerRuntime (子进程) + +核心类: + SupervisorRuntime: 监管者运行时 + - 发现并加载所有插件 + - 为每个插件启动 Worker 进程 + - 聚合所有 handler 并向 Core 注册 + - 路由 Core 的调用请求到对应 Worker + - 处理 Worker 进程崩溃和重连 + - handler ID 冲突检测和警告 + + WorkerSession: Worker 会话 + - 管理单个插件 Worker 进程 + - 通过 Peer 与 Worker 通信 + - 提供 invoke_handler 和 cancel 方法 + - 处理连接关闭回调 + - 自动清理已注册的 handlers + +信号处理: + - SIGTERM: 设置 stop_event,触发优雅关闭 + - SIGINT: 设置 stop_event,触发优雅关闭 +""" + +from __future__ import annotations + +import asyncio +import os +import signal +import sys +from collections.abc import Callable +from pathlib import Path +from typing import IO, Any, cast + +from loguru import logger + +from ..errors import AstrBotError +from ..protocol.descriptors import CapabilityDescriptor +from ..protocol.messages import EventMessage, InitializeOutput, PeerInfo +from .capability_router import CapabilityRouter, StreamExecution +from .environment_groups import EnvironmentGroup +from .loader import ( + PluginDiscoveryIssue, + PluginEnvironmentManager, + PluginSpec, + discover_plugins, + load_plugin_config, +) +from .peer import Peer +from .transport import StdioTransport + +__all__ = [ + "SupervisorRuntime", + "WorkerSession", + "_install_signal_handlers", + "_prepare_stdio_transport", + "_sdk_source_dir", + "_wait_for_shutdown", +] + + +def _install_signal_handlers(stop_event: asyncio.Event) -> None: + loop = asyncio.get_running_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, stop_event.set) + except NotImplementedError: + logger.debug("Signal handlers are not supported for {}", sig) + + +def _prepare_stdio_transport( + stdin: IO[str] | None, + 
stdout: IO[str] | None, +) -> tuple[IO[str], IO[str], IO[str] | None]: + if stdin is not None and stdout is not None: + return stdin, stdout, None + transport_stdin = stdin or sys.stdin + transport_stdout = stdout or sys.stdout + original_stdout = sys.stdout + sys.stdout = sys.stderr + return transport_stdin, transport_stdout, original_stdout + + +def _sdk_source_dir(repo_root: Path) -> Path: + candidate = repo_root.resolve() / "src" + if (candidate / "astrbot_sdk").exists(): + return candidate + return Path(__file__).resolve().parents[2] + + +async def _wait_for_shutdown(peer: Peer, stop_event: asyncio.Event) -> None: + stop_waiter = asyncio.create_task(stop_event.wait()) + transport_waiter = asyncio.create_task(peer.wait_closed()) + done, pending = await asyncio.wait( + {stop_waiter, transport_waiter}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + for task in done: + if not task.cancelled(): + task.result() + + +def _plugin_name_from_handler_id(handler_id: str) -> str: + if ":" in handler_id: + return handler_id.split(":", 1)[0] + return handler_id + + +class WorkerSession: + def __init__( + self, + *, + plugin: PluginSpec | None = None, + group: EnvironmentGroup | None = None, + repo_root: Path, + env_manager: PluginEnvironmentManager, + capability_router: CapabilityRouter, + on_closed: Callable[[], None] | None = None, + ) -> None: + if plugin is None and group is None: + raise ValueError("WorkerSession requires either plugin or group") + group_ref = group + if group_ref is not None: + primary_plugin = group_ref.plugins[0] + else: + assert plugin is not None + primary_plugin = plugin + self.group = group + self.plugins = ( + list(group_ref.plugins) if group_ref is not None else [primary_plugin] + ) + self.plugin = primary_plugin + self.group_id = group_ref.id if group_ref is not None else primary_plugin.name + self.repo_root = repo_root.resolve() + self.env_manager = env_manager + self.capability_router = capability_router 
+ self.on_closed = on_closed + self.peer: Peer | None = None + self.handlers = [] + self.provided_capabilities: list[CapabilityDescriptor] = [] + self.loaded_plugins: list[str] = [] + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self.capability_sources: dict[str, str] = {} + self.llm_tools: list[dict[str, Any]] = [] + self.agents: list[dict[str, Any]] = [] + self._connection_watch_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + python_path, command, cwd = self._worker_command() + repo_src_dir = str(_sdk_source_dir(self.repo_root)) + env = os.environ.copy() + existing_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + f"{repo_src_dir}{os.pathsep}{existing_pythonpath}" + if existing_pythonpath + else repo_src_dir + ) + env.setdefault("PYTHONIOENCODING", "utf-8") + env.setdefault("PYTHONUTF8", "1") + + transport = StdioTransport( + command=command, + cwd=cwd, + env=env, + ) + self.peer = Peer( + transport=transport, + peer_info=PeerInfo(name="astrbot-core", role="core", version="v4"), + ) + self.peer.set_initialize_handler(self._handle_initialize) + self.peer.set_invoke_handler(self._handle_capability_invoke) + try: + await self.peer.start() + # 同时监听初始化完成和连接关闭,避免 worker 崩溃时等满超时 + init_task = asyncio.create_task( + self.peer.wait_until_remote_initialized(timeout=None) + ) + closed_task = asyncio.create_task(self.peer.wait_closed()) + done, pending = await asyncio.wait( + {init_task, closed_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if init_task in done: + await init_task + + if closed_task in done: + raise RuntimeError(f"worker 组 {self.group_id} 在初始化阶段退出") + + self.handlers = list(self.peer.remote_handlers) + self.provided_capabilities = list(self.peer.remote_provided_capabilities) + metadata = dict(self.peer.remote_metadata) + remote_loaded_plugins = 
metadata.get("loaded_plugins") + if isinstance(remote_loaded_plugins, list): + self.loaded_plugins = [ + plugin_name + for plugin_name in remote_loaded_plugins + if isinstance(plugin_name, str) + ] + else: + self.loaded_plugins = [plugin.name for plugin in self.plugins] + remote_skipped_plugins = metadata.get("skipped_plugins") + if isinstance(remote_skipped_plugins, dict): + self.skipped_plugins = { + str(plugin_name): str(reason) + for plugin_name, reason in remote_skipped_plugins.items() + } + remote_capability_sources = metadata.get("capability_sources") + if isinstance(remote_capability_sources, dict): + self.capability_sources = { + str(capability_name): str(plugin_name) + for capability_name, plugin_name in remote_capability_sources.items() + } + remote_issues = metadata.get("issues") + if isinstance(remote_issues, list): + self.issues = [ + PluginDiscoveryIssue( + severity=str(item.get("severity", "error")), # type: ignore[arg-type] + phase=str(item.get("phase", "load")), # type: ignore[arg-type] + plugin_id=str(item.get("plugin_id", self.plugin.name)), + message=str(item.get("message", "")), + details=str(item.get("details", "")), + hint=str(item.get("hint", "")), + ) + for item in remote_issues + if isinstance(item, dict) + ] + remote_llm_tools = metadata.get("llm_tools") + if isinstance(remote_llm_tools, list): + self.llm_tools = [ + dict(item) for item in remote_llm_tools if isinstance(item, dict) + ] + remote_agents = metadata.get("agents") + if isinstance(remote_agents, list): + self.agents = [ + dict(item) for item in remote_agents if isinstance(item, dict) + ] + + except Exception: + await self.stop() + raise + + def _worker_command(self) -> tuple[Path, list[str], str]: + if self.group is not None: + prepare_group = getattr(self.env_manager, "prepare_group_environment", None) + if callable(prepare_group): + python_path = cast(Path, prepare_group(self.group)) + else: + python_path = self.env_manager.prepare_environment(self.plugins[0]) + return ( + 
python_path, + [ + str(python_path), + "-m", + "astrbot_sdk", + "worker", + "--group-metadata", + str(self.group.metadata_path), + ], + str(self.repo_root), + ) + + plugin = self.plugin + python_path = self.env_manager.prepare_environment(plugin) + return ( + python_path, + [ + str(python_path), + "-m", + "astrbot_sdk", + "worker", + "--plugin-dir", + str(plugin.plugin_dir), + ], + str(plugin.plugin_dir), + ) + + def start_close_watch(self) -> None: + if ( + self.on_closed is None + or self.peer is None + or self._connection_watch_task is not None + ): + return + self._connection_watch_task = asyncio.create_task(self._watch_connection()) + + async def _watch_connection(self) -> None: + """监听 Worker 连接关闭,触发清理回调""" + try: + if self.peer is not None: + await self.peer.wait_closed() + if self.on_closed is not None: + try: + self.on_closed() + except Exception: + logger.exception( + "on_closed callback failed for worker group {}", self.group_id + ) + finally: + current_task = asyncio.current_task() + if self._connection_watch_task is current_task: + self._connection_watch_task = None + + async def stop(self) -> None: + if self.peer is not None: + await self.peer.stop() + + async def invoke_handler( + self, + handler_id: str, + event_payload: dict[str, Any], + *, + request_id: str, + args: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if self.peer is None: + raise RuntimeError("worker session is not running") + return await self.peer.invoke( + "handler.invoke", + { + "handler_id": handler_id, + "event": event_payload, + "args": dict(args or {}), + }, + request_id=request_id, + ) + + async def invoke_capability( + self, + capability_name: str, + payload: dict[str, Any], + *, + request_id: str, + ) -> dict[str, Any]: + if self.peer is None: + raise RuntimeError("worker session is not running") + return await self.peer.invoke( + capability_name, + payload, + request_id=request_id, + ) + + async def invoke_capability_stream( + self, + capability_name: str, + 
payload: dict[str, Any], + *, + request_id: str, + ): + if self.peer is None: + raise RuntimeError("worker session is not running") + event_stream = await self.peer.invoke_stream( + capability_name, + payload, + request_id=request_id, + include_completed=True, + ) + async for event in event_stream: + yield event + + async def cancel(self, request_id: str) -> None: + if self.peer is None: + return + await self.peer.cancel(request_id) + + async def _handle_initialize(self, _message) -> InitializeOutput: + return InitializeOutput( + peer=PeerInfo(name="astrbot-supervisor", role="core", version="v4"), + capabilities=self.capability_router.descriptors(), + metadata={ + "group_id": self.group_id, + "plugins": [plugin.name for plugin in self.plugins], + }, + ) + + async def _handle_capability_invoke(self, message, cancel_token): + return await self.capability_router.execute( + message.capability, + message.input, + stream=message.stream, + cancel_token=cancel_token, + request_id=message.id, + ) + + def describe(self) -> dict[str, Any]: + return { + "group_id": self.group_id, + "plugins": [plugin.name for plugin in self.plugins], + "loaded_plugins": list(self.loaded_plugins), + "skipped_plugins": dict(self.skipped_plugins), + "issues": [issue.to_payload() for issue in self.issues], + } + + +class SupervisorRuntime: + def __init__( + self, + *, + transport, + plugins_dir: Path, + env_manager: PluginEnvironmentManager | None = None, + ) -> None: + self.transport = transport + self.plugins_dir = plugins_dir.resolve() + self.repo_root = Path(__file__).resolve().parents[3] + self.env_manager = env_manager or PluginEnvironmentManager(self.repo_root) + self.capability_router = CapabilityRouter() + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name="astrbot-supervisor", role="plugin", version="v4"), + ) + self.peer.set_invoke_handler(self._handle_upstream_invoke) + self.peer.set_cancel_handler(self._handle_upstream_cancel) + self.worker_sessions: dict[str, 
WorkerSession] = {} + self.handler_to_worker: dict[str, WorkerSession] = {} + self.capability_to_worker: dict[str, WorkerSession] = {} + self.plugin_to_worker_session: dict[str, WorkerSession] = {} + self._handler_sources: dict[str, str] = {} # handler_id -> plugin_name + self._capability_sources: dict[str, str] = {} # capability_name -> plugin_name + self.active_requests: dict[str, WorkerSession] = {} + self.loaded_plugins: list[str] = [] + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self._register_internal_capabilities() + + def _publish_plugin_registry_snapshot( + self, + plugins: list[PluginSpec], + *, + enabled_plugins: set[str], + ) -> None: + for plugin in plugins: + manifest = plugin.manifest_data + self.capability_router.upsert_plugin( + metadata={ + "name": plugin.name, + "display_name": str(manifest.get("display_name") or plugin.name), + "description": str( + manifest.get("desc") or manifest.get("description") or "" + ), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": plugin.name in enabled_plugins, + }, + config=load_plugin_config(plugin), + ) + + def _publish_discovered_plugin_registry(self, plugins: list[PluginSpec]) -> None: + """发布已发现插件的静态元数据。 + + 这一阶段发生在 worker 真正启动前。此时 supervisor 已经知道有哪些插件、 + 它们的 manifest/config 是什么,但尚未确认哪些插件实际完成加载,因此统一 + 以 `enabled=False` 暴露给 metadata 能力。 + """ + self._publish_plugin_registry_snapshot(plugins, enabled_plugins=set()) + + def _publish_loaded_plugin_registry(self, plugins: list[PluginSpec]) -> None: + """在 worker 启动完成后刷新插件启用状态。""" + self._publish_plugin_registry_snapshot( + plugins, + enabled_plugins=set(self.loaded_plugins), + ) + + def _register_internal_capabilities(self) -> None: + self.capability_router.register( + CapabilityDescriptor( + name="handler.invoke", + description="框架内部:转发到插件 handler", + input_schema={ + "type": "object", + "properties": { + "handler_id": {"type": "string"}, + "event": 
{"type": "object"}, + }, + "required": ["handler_id", "event"], + }, + output_schema={ + "type": "object", + "properties": {}, + "required": [], + }, + cancelable=True, + ), + call_handler=self._route_handler_invoke, + exposed=False, + ) + + def _register_handler( + self, handler, session: WorkerSession, plugin_name: str + ) -> None: + """注册 handler,处理冲突时输出警告。 + + Args: + handler: Handler 描述符 + session: Worker 会话 + plugin_name: 插件名称 + """ + handler_id = handler.id + existing_plugin = self._handler_sources.get(handler_id) + + if existing_plugin is not None: + logger.warning( + f"Handler ID 冲突:'{handler_id}' 已被插件 '{existing_plugin}' 注册," + f"现在被插件 '{plugin_name}' 覆盖。" + ) + + self.handler_to_worker[handler_id] = session + self._handler_sources[handler_id] = plugin_name + + def _register_plugin_capability( + self, + descriptor: CapabilityDescriptor, + session: WorkerSession, + plugin_name: str, + ) -> None: + """注册插件 capability,处理命名冲突。 + + 当 capability 名称冲突时: + - 如果是保留命名空间(handler/system/internal),跳过并警告 + - 否则,使用插件名作为前缀重新命名,例如: + - 插件 'my_plugin' 注册 'demo.echo' 冲突 + - 自动重命名为 'my_plugin.demo.echo' + """ + capability_name = descriptor.name + + if not self.capability_router.contains(capability_name): + # 无冲突,直接注册 + self._do_register_capability( + descriptor, session, capability_name, plugin_name + ) + return + + # 检查是否在保留命名空间内 + if capability_name.startswith(("handler.", "system.", "internal.")): + logger.warning( + "Capability '{}' 在保留命名空间内,跳过插件 '{}' 的注册。" + "保留命名空间不允许插件覆盖。", + capability_name, + plugin_name, + ) + return + + # 尝试添加插件前缀解决冲突 + prefixed_name = f"{plugin_name}.{capability_name}" + if self.capability_router.contains(prefixed_name): + logger.warning( + "Capability '{}' 和 '{}.{}' 均已存在," + "跳过插件 '{}' 的注册。请考虑使用更唯一的命名。", + capability_name, + plugin_name, + capability_name, + plugin_name, + ) + return + + # 使用前缀名称注册 + prefixed_descriptor = descriptor.model_copy(deep=True) + prefixed_descriptor.name = prefixed_name + logger.info( + "Capability '{}' 与已注册能力冲突,自动重命名为 
'{}' (插件: {})。", + capability_name, + prefixed_name, + plugin_name, + ) + self._do_register_capability( + prefixed_descriptor, session, prefixed_name, plugin_name + ) + # 记录原始名称到前缀名称的映射,便于调试 + self._capability_sources[f"_original:{prefixed_name}"] = capability_name + + def _do_register_capability( + self, + descriptor: CapabilityDescriptor, + session: WorkerSession, + capability_name: str, + plugin_name: str, + ) -> None: + """实际执行 capability 注册。""" + self.capability_router.register( + descriptor, + call_handler=self._make_plugin_capability_caller(session, capability_name), + stream_handler=( + self._make_plugin_capability_streamer(session, capability_name) + if descriptor.supports_stream + else None + ), + ) + self.capability_to_worker[capability_name] = session + self._capability_sources[capability_name] = plugin_name + + def _make_plugin_capability_caller( + self, + session: WorkerSession, + capability_name: str, + ): + async def call_handler( + request_id: str, + payload: dict[str, Any], + _cancel_token, + ) -> dict[str, Any]: + self.active_requests[request_id] = session + try: + return await session.invoke_capability( + capability_name, + payload, + request_id=request_id, + ) + finally: + self.active_requests.pop(request_id, None) + + return call_handler + + def _make_plugin_capability_streamer( + self, + session: WorkerSession, + capability_name: str, + ): + async def stream_handler( + request_id: str, + payload: dict[str, Any], + _cancel_token, + ): + completed_output: dict[str, Any] = {} + + async def iterator(): + self.active_requests[request_id] = session + try: + async for event in session.invoke_capability_stream( + capability_name, + payload, + request_id=request_id, + ): + if not isinstance(event, EventMessage): + raise AstrBotError.protocol_error( + "插件 worker 返回了非法的流式事件" + ) + if event.phase == "delta": + yield event.data or {} + continue + if event.phase == "completed": + completed_output.clear() + completed_output.update(event.output or {}) + 
finally: + self.active_requests.pop(request_id, None) + + return StreamExecution( + iterator=iterator(), + finalize=lambda chunks: completed_output or {"items": chunks}, + ) + + return stream_handler + + async def start(self) -> None: + discovery = discover_plugins(self.plugins_dir) + self.skipped_plugins = dict(discovery.skipped_plugins) + self.issues = list(discovery.issues) + plan_result = self.env_manager.plan(discovery.plugins) + self.skipped_plugins.update(plan_result.skipped_plugins) + self.issues.extend( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin_name, + message="插件环境规划失败", + details=str(reason), + ) + for plugin_name, reason in plan_result.skipped_plugins.items() + ) + # 先发布静态插件元数据,允许 supervisor 侧在 worker 启动阶段就读取配置/清单。 + self._publish_discovered_plugin_registry(discovery.plugins) + try: + planned_sessions: list[WorkerSession] = [] + if plan_result.groups: + for group in plan_result.groups: + planned_sessions.append( + WorkerSession( + group=group, + repo_root=self.repo_root, + env_manager=self.env_manager, + capability_router=self.capability_router, + on_closed=lambda group_id=group.id: ( + self._handle_worker_closed(group_id) + ), + ) + ) + else: + for plugin in plan_result.plugins: + planned_sessions.append( + WorkerSession( + plugin=plugin, + repo_root=self.repo_root, + env_manager=self.env_manager, + capability_router=self.capability_router, + on_closed=lambda plugin_name=plugin.name: ( + self._handle_worker_closed(plugin_name) + ), + ) + ) + + for session in planned_sessions: + try: + await session.start() + except Exception as exc: + for plugin in session.plugins: + self.skipped_plugins[plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件 worker 启动失败", + details=str(exc), + ) + ) + await session.stop() + continue + self.worker_sessions[session.group_id] = session + self.skipped_plugins.update(session.skipped_plugins) + 
self.issues.extend(session.issues) + for plugin_name in session.loaded_plugins: + self.plugin_to_worker_session[plugin_name] = session + if plugin_name not in self.loaded_plugins: + self.loaded_plugins.append(plugin_name) + for handler in session.handlers: + self._register_handler( + handler, + session, + _plugin_name_from_handler_id(handler.id), + ) + for descriptor in session.provided_capabilities: + plugin_name = session.capability_sources.get(descriptor.name) + if plugin_name is None and len(session.loaded_plugins) == 1: + plugin_name = session.loaded_plugins[0] + if plugin_name is None: + plugin_name = session.group_id + self._register_plugin_capability(descriptor, session, plugin_name) + session.start_close_watch() + + # worker 启动后再用实际加载结果刷新 enabled 状态,形成显式两阶段发布。 + self._publish_loaded_plugin_registry(discovery.plugins) + + aggregated_handlers = list(self.handler_to_worker.keys()) + logger.info( + "Loaded plugins: {}", ", ".join(sorted(self.loaded_plugins)) or "none" + ) + + await self.peer.start() + await self.peer.initialize( + [ + handler + for session in self.worker_sessions.values() + for handler in session.handlers + ], + provided_capabilities=self.capability_router.descriptors(), + metadata={ + "plugins": sorted(self.loaded_plugins), + "skipped_plugins": self.skipped_plugins, + "issues": [issue.to_payload() for issue in self.issues], + "aggregated_handler_ids": aggregated_handlers, + "worker_groups": [ + session.describe() for session in self.worker_sessions.values() + ], + "worker_group_count": len(self.worker_sessions), + }, + ) + except Exception: + await self.stop() + raise + + def _handle_worker_closed(self, group_id: str) -> None: + """Worker 连接关闭时的清理回调""" + session = self.worker_sessions.pop(group_id, None) + if session is None: + return + # 从 handler_to_worker 中移除该插件注册的 handlers(仅当来源仍为此插件时) + for handler in session.handlers: + source_plugin = self._handler_sources.get(handler.id) + if source_plugin == _plugin_name_from_handler_id(handler.id) or 
( + source_plugin == group_id + ): + self.handler_to_worker.pop(handler.id, None) + self._handler_sources.pop(handler.id, None) + for descriptor in session.provided_capabilities: + source_plugin = self._capability_sources.get(descriptor.name) + capability_plugin = session.capability_sources.get(descriptor.name) + if source_plugin == capability_plugin or ( + capability_plugin is None + and ( + source_plugin == group_id or source_plugin in session.loaded_plugins + ) + ): + self.capability_to_worker.pop(descriptor.name, None) + self._capability_sources.pop(descriptor.name, None) + self.capability_router.unregister(descriptor.name) + session_loaded_plugins = getattr(session, "loaded_plugins", None) + if not isinstance(session_loaded_plugins, list): + session_loaded_plugins = [group_id] + for plugin_name in session_loaded_plugins: + if plugin_name in self.loaded_plugins: + self.loaded_plugins.remove(plugin_name) + self.plugin_to_worker_session.pop(plugin_name, None) + self.capability_router.set_plugin_enabled(plugin_name, False) + self.capability_router.remove_http_apis_for_plugin(plugin_name) + stale_requests = [ + request_id + for request_id, active_session in self.active_requests.items() + if active_session is session + ] + for request_id in stale_requests: + self.active_requests.pop(request_id, None) + logger.warning("worker 组 {} 连接已关闭,已清理相关 handlers", group_id) + + async def stop(self) -> None: + for session in list(self.worker_sessions.values()): + await session.stop() + await self.peer.stop() + + async def _handle_upstream_invoke(self, message, cancel_token): + return await self.capability_router.execute( + message.capability, + message.input, + stream=message.stream, + cancel_token=cancel_token, + request_id=message.id, + ) + + async def _route_handler_invoke( + self, + request_id: str, + payload: dict[str, Any], + _cancel_token, + ) -> dict[str, Any]: + handler_id = str(payload.get("handler_id", "")) + session = self.handler_to_worker.get(handler_id) + if 
session is None: + raise AstrBotError.invalid_input(f"handler not found: {handler_id}") + self.active_requests[request_id] = session + try: + return await session.invoke_handler( + handler_id, + payload.get("event", {}), + request_id=request_id, + args=payload.get("args", {}), + ) + finally: + self.active_requests.pop(request_id, None) + + async def _handle_upstream_cancel(self, request_id: str) -> None: + session = self.active_requests.get(request_id) + if session is not None: + await session.cancel(request_id) diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/transport.py b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py new file mode 100644 index 0000000000..d4c55cdca6 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/transport.py @@ -0,0 +1,403 @@ +"""传输层抽象模块。 + +定义 Transport 抽象基类及其实现,负责底层的消息传输。 +传输层只关心"发送字符串"和"接收字符串",不处理协议细节。 +传输实现: + Transport: 抽象基类,定义 start/stop/send/wait_closed 接口 + StdioTransport: 标准输入输出传输 + - 进程模式: 通过 command 参数启动子进程 + - 文件模式: 通过 stdin/stdout 参数指定文件描述符 + +传输类型: + Transport: 抽象基类,定义 start/stop/send 接口 + StdioTransport: 标准输入输出传输,支持进程模式和文件模式 + WebSocketServerTransport: WebSocket 服务端传输 + - 单连接限制,支持心跳配置 + - 通过 port 属性获取实际监听端口 + - 自动重连需要外部实现 + +使用示例: + # 子进程模式 + transport = StdioTransport( + command=["python", "-m", "my_plugin"], + cwd="/path/to/plugin", + ) + + # 标准输入输出模式 + transport = StdioTransport(stdin=sys.stdin, stdout=sys.stdout) + + # WebSocket 服务端 + transport = WebSocketServerTransport(host="0.0.0.0", port=8765) + + # WebSocket 客户端 + transport = WebSocketClientTransport(url="ws://localhost:8765") + + # 统一接口 + transport.set_message_handler(my_handler) + await transport.start() + await transport.send(json_string) + await transport.stop() + +`Transport` 只处理“字符串发出去 / 字符串收进来”这件事,不做协议解析,也不关心 +能力、handler 或迁移适配策略。当前实现包括: + +- `StdioTransport`: 子进程或文件对象上的按行文本传输 +- `WebSocketServerTransport`: 单连接 WebSocket 服务端 +- `WebSocketClientTransport`: WebSocket 客户端 + +自动重连、消息重放等策略不在这里实现,统一留给更上层编排。 +""" + +from __future__ import annotations + +import 
asyncio +import sys +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable, Sequence +from typing import IO, Any + +from loguru import logger + +MessageHandler = Callable[[str], Awaitable[None]] + + +def _get_aiohttp(): + import aiohttp + + return aiohttp + + +def _get_web(): + from aiohttp import web + + return web + + +def _frame_stdio_payload(payload: str) -> str: + body = payload + if body.endswith("\r\n"): + body = body[:-2] + elif body.endswith(("\n", "\r")): + body = body[:-1] + if "\n" in body or "\r" in body: + raise ValueError("STDIO payload 不允许包含原始换行符") + return f"{body}\n" + + +# TODO 一个更好的解决方案? +def _is_windows_access_denied(error: BaseException) -> bool: + return ( + sys.platform == "win32" + and isinstance(error, PermissionError) + and getattr(error, "winerror", None) == 5 + ) + + +class Transport(ABC): + def __init__(self) -> None: + self._handler: MessageHandler | None = None + self._closed = asyncio.Event() + + def set_message_handler(self, handler: MessageHandler) -> None: + """注册收到原始字符串消息后的回调。""" + self._handler = handler + + @abstractmethod + async def start(self) -> None: + raise NotImplementedError + + @abstractmethod + async def stop(self) -> None: + raise NotImplementedError + + @abstractmethod + async def send(self, payload: str) -> None: + raise NotImplementedError + + async def wait_closed(self) -> None: + """等待传输层进入关闭状态。""" + await self._closed.wait() + + async def _dispatch(self, payload: str) -> None: + """把收到的原始载荷转交给上层处理器。""" + if self._handler is not None: + await self._handler(payload) + + +class StdioTransport(Transport): + def __init__( + self, + *, + stdin: IO[str] | None = None, + stdout: IO[str] | None = None, + command: Sequence[str] | None = None, + cwd: str | None = None, + env: dict[str, str] | None = None, + ) -> None: + super().__init__() + self._stdin = stdin + self._stdout = stdout + self._command = list(command) if command is not None else None + self._cwd = cwd + self._env = env + 
self._process: asyncio.subprocess.Process | None = None + self._reader_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + self._closed.clear() + if self._command is not None: + self._process = await self._start_subprocess_with_retry() + self._reader_task = asyncio.create_task(self._read_process_loop()) + return + + self._stdin = self._stdin or sys.stdin + self._stdout = self._stdout or sys.stdout + self._reader_task = asyncio.create_task(self._read_file_loop()) + + async def _start_subprocess_with_retry(self) -> asyncio.subprocess.Process: + assert self._command is not None # 类型收窄:start() 已确保非空 + delays = [0.15, 0.35, 0.75] + last_error: BaseException | None = None + for attempt, delay in enumerate([0.0, *delays], start=1): + if delay: + await asyncio.sleep(delay) + try: + return await asyncio.create_subprocess_exec( + *self._command, + cwd=self._cwd, + env=self._env, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=sys.stderr, + ) + except Exception as exc: + last_error = exc + if not _is_windows_access_denied(exc) or attempt == len(delays) + 1: + raise + logger.warning( + "Windows denied access while starting freshly prepared worker " + "interpreter, retrying attempt {}/{}: {}", + attempt, + len(delays) + 1, + exc, + ) + assert last_error is not None + raise last_error + + async def stop(self) -> None: + if self._reader_task is not None: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + + if self._process is not None: + if self._process.returncode is None: + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=5) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + self._process = None + self._closed.set() + + async def send(self, payload: str) -> None: + line = _frame_stdio_payload(payload) + if self._process is not None: + if self._process.stdin is None: + raise 
RuntimeError("STDIO subprocess stdin 不可用") + self._process.stdin.write(line.encode("utf-8")) + await self._process.stdin.drain() + return + + if self._stdout is None: + raise RuntimeError("STDIO stdout 不可用") + + def _write() -> None: + assert self._stdout is not None + self._stdout.write(line) + self._stdout.flush() + + await asyncio.to_thread(_write) + + async def _read_process_loop(self) -> None: + assert self._process is not None + assert self._process.stdout is not None + try: + while True: + raw = await self._process.stdout.readline() + if not raw: + break + await self._dispatch(raw.decode("utf-8").rstrip("\r\n")) + finally: + self._closed.set() + + async def _read_file_loop(self) -> None: + assert self._stdin is not None + try: + while True: + raw = await asyncio.to_thread(self._stdin.readline) + if not raw: + break + await self._dispatch(raw.rstrip("\r\n")) + finally: + self._closed.set() + + +class WebSocketServerTransport(Transport): + def __init__( + self, + *, + host: str = "127.0.0.1", + port: int = 8765, + path: str = "/", + heartbeat: float = 30.0, + ) -> None: + super().__init__() + self._host = host + self._port = port + self._actual_port: int | None = None + self._path = path + self._heartbeat = heartbeat + self._app: Any | None = None + self._runner: Any | None = None + self._site: Any | None = None + self._ws: Any | None = None + self._write_lock = asyncio.Lock() + self._connected = asyncio.Event() + + async def start(self) -> None: + web = _get_web() + self._closed.clear() + self._connected.clear() + self._app = web.Application() + self._app.router.add_get(self._path, self._handle_socket) + self._runner = web.AppRunner(self._app) + await self._runner.setup() + self._site = web.TCPSite(self._runner, self._host, self._port) + await self._site.start() + if self._site._server and getattr(self._site._server, "sockets", None): + socket = self._site._server.sockets[0] + self._actual_port = socket.getsockname()[1] + + async def stop(self) -> None: + 
self._connected.clear() + if self._ws is not None and not self._ws.closed: + await self._ws.close() + if self._site is not None: + await self._site.stop() + self._site = None + if self._runner is not None: + await self._runner.cleanup() + self._runner = None + self._closed.set() + + async def send(self, payload: str) -> None: + if self._ws is None or self._ws.closed: + await asyncio.wait_for(self._connected.wait(), timeout=30.0) + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket 尚未连接") + async with self._write_lock: + await self._ws.send_str(payload) + + async def _handle_socket(self, request) -> Any: + web = _get_web() + aiohttp = _get_aiohttp() + if self._ws is not None and not self._ws.closed: + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws.close(code=1008, message=b"only one websocket connection allowed") + return ws + + ws = web.WebSocketResponse( + heartbeat=self._heartbeat if self._heartbeat > 0 else None + ) + await ws.prepare(request) + self._ws = ws + self._connected.set() + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await self._dispatch(msg.data) + elif msg.type == aiohttp.WSMsgType.BINARY: + await self._dispatch(msg.data.decode("utf-8")) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error("websocket server error: {}", ws.exception()) + break + finally: + self._connected.clear() + self._closed.set() + self._ws = None + return ws + + @property + def port(self) -> int: + return self._actual_port or self._port + + @property + def url(self) -> str: + return f"ws://{self._host}:{self.port}{self._path}" + + +class WebSocketClientTransport(Transport): + def __init__( + self, + *, + url: str, + heartbeat: float = 30.0, + ) -> None: + super().__init__() + self._url = url + self._heartbeat = heartbeat + self._session: Any | None = None + self._ws: Any | None = None + self._reader_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + aiohttp = _get_aiohttp() + 
self._closed.clear() + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect( + self._url, + heartbeat=self._heartbeat if self._heartbeat > 0 else None, + ) + self._reader_task = asyncio.create_task(self._read_loop()) + + async def stop(self) -> None: + if self._reader_task is not None: + self._reader_task.cancel() + try: + await self._reader_task + except asyncio.CancelledError: + pass + self._reader_task = None + if self._ws is not None and not self._ws.closed: + await self._ws.close() + if self._session is not None: + await self._session.close() + self._ws = None + self._session = None + self._closed.set() + + async def send(self, payload: str) -> None: + if self._ws is None or self._ws.closed: + raise RuntimeError("WebSocket client 尚未连接") + await self._ws.send_str(payload) + + async def _read_loop(self) -> None: + assert self._ws is not None + aiohttp = _get_aiohttp() + try: + async for msg in self._ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await self._dispatch(msg.data) + elif msg.type == aiohttp.WSMsgType.BINARY: + await self._dispatch(msg.data.decode("utf-8")) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error("websocket client error: {}", self._ws.exception()) + break + finally: + self._closed.set() diff --git a/astrbot-sdk/src/astrbot_sdk/runtime/worker.py b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py new file mode 100644 index 0000000000..3b593fe253 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/runtime/worker.py @@ -0,0 +1,460 @@ +"""Worker 端运行时:PluginWorkerRuntime 运行单个插件,GroupWorkerRuntime 在同一进程中运行多个插件。 + +核心类: + GroupWorkerRuntime: 组 Worker 运行时 + - 在同一进程中加载并运行多个插件 + - 聚合所有插件的 handlers 和 capabilities + - 统一处理 invoke 和 cancel 请求 + - 管理每个插件的生命周期回调 + + PluginWorkerRuntime: 单插件 Worker 运行时 + - 加载单个插件 + - 通过 Peer 与 Supervisor 通信 + - 分发 handler 调用 + - 处理生命周期回调 (on_start, on_stop) + +启动流程: + Worker 启动: + 1. load_plugin_spec() 加载插件规范 + 2. load_plugin() 加载插件组件 + 3. 创建 Peer 并设置处理器 + 4. 向 Supervisor 发送 initialize + 5. 
等待 Supervisor 的 initialize_result + 6. 执行 on_start 生命周期回调 +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from loguru import logger + +from .._internal.decorator_lifecycle import run_lifecycle_with_decorators +from .._internal.invocation_context import caller_plugin_scope +from ..context import Context as RuntimeContext +from ..errors import AstrBotError +from ..protocol.messages import PeerInfo +from .handler_dispatcher import CapabilityDispatcher, HandlerDispatcher +from .loader import ( + LoadedPlugin, + PluginDiscoveryIssue, + PluginSpec, + load_plugin, + load_plugin_spec, +) +from .peer import Peer + +__all__ = [ + "GroupPluginRuntimeState", + "GroupWorkerRuntime", + "PluginWorkerRuntime", + "_load_group_plugin_specs", +] + +GLOBAL_MCP_RISK_ATTR = "__astrbot_acknowledge_global_mcp_risk__" + + +@dataclass(slots=True) +class GroupPluginRuntimeState: + plugin: PluginSpec + loaded_plugin: LoadedPlugin + lifecycle_context: RuntimeContext + + +def _plugin_acknowledges_global_mcp_risk(instances: list[Any]) -> bool: + return any( + bool(getattr(instance.__class__, GLOBAL_MCP_RISK_ATTR, False)) + for instance in instances + ) + + +def _metadata_plugin_instances(loaded_plugin: Any) -> list[Any]: + """Return plugin instances for metadata-only inspection. + + Metadata serialization is also exercised by lightweight tests that stub + ``loaded_plugin`` with only the fields relevant to the payload. Missing + ``instances`` means the plugin cannot acknowledge the global MCP risk, but + it should not break issue/metadata reporting. 
+ """ + instances = getattr(loaded_plugin, "instances", []) + if isinstance(instances, list): + return instances + if isinstance(instances, tuple): + return list(instances) + return [] + + +def _load_group_plugin_specs(group_metadata_path: Path) -> tuple[str, list[PluginSpec]]: + try: + payload = json.loads(group_metadata_path.read_text(encoding="utf-8")) + except Exception as exc: + raise RuntimeError( + f"failed to read worker group metadata: {group_metadata_path}" + ) from exc + + if not isinstance(payload, dict): + raise RuntimeError(f"invalid worker group metadata: {group_metadata_path}") + + entries = payload.get("plugin_entries") + if not isinstance(entries, list) or not entries: + raise RuntimeError( + f"worker group metadata missing plugin_entries: {group_metadata_path}" + ) + + plugins: list[PluginSpec] = [] + for entry in entries: + if not isinstance(entry, dict): + raise RuntimeError( + f"worker group metadata contains invalid plugin entry: {group_metadata_path}" + ) + plugin_dir = entry.get("plugin_dir") + if not isinstance(plugin_dir, str) or not plugin_dir: + raise RuntimeError( + f"worker group metadata contains invalid plugin_dir: {group_metadata_path}" + ) + plugins.append(load_plugin_spec(Path(plugin_dir))) + + group_id = payload.get("group_id") + if not isinstance(group_id, str) or not group_id: + group_id = group_metadata_path.stem + return group_id, plugins + + +async def run_plugin_lifecycle( + instances: list[Any], + method_name: str, + context: RuntimeContext, +) -> None: + """运行插件生命周期方法。""" + for instance in instances: + method = getattr(instance, method_name, None) + with caller_plugin_scope(context.plugin_id): + await run_lifecycle_with_decorators( + instance=instance, + hook=method if callable(method) else None, + method_name=method_name, + context=context, + ) + + +class GroupWorkerRuntime: + def __init__(self, *, group_metadata_path: Path, transport) -> None: + self.group_metadata_path = group_metadata_path.resolve() + self.group_id, 
self.plugins = _load_group_plugin_specs(self.group_metadata_path) + self.transport = transport + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name=self.group_id, role="plugin", version="v4"), + ) + self.skipped_plugins: dict[str, str] = {} + self.issues: list[PluginDiscoveryIssue] = [] + self._plugin_states: list[GroupPluginRuntimeState] = [] + self._active_plugin_states: list[GroupPluginRuntimeState] = [] + self._load_plugins() + self._refresh_dispatchers() + self.peer.set_invoke_handler(self._handle_invoke) + self.peer.set_cancel_handler(self._handle_cancel) + + def _load_plugins(self) -> None: + for plugin in self.plugins: + try: + loaded_plugin = load_plugin(plugin) + except Exception as exc: + self.skipped_plugins[plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件加载失败", + details=str(exc), + ) + ) + logger.exception( + "组 {} 中插件 {} 加载失败,启动时将跳过", + self.group_id, + plugin.name, + ) + continue + + lifecycle_context = RuntimeContext(peer=self.peer, plugin_id=plugin.name) + self._plugin_states.append( + GroupPluginRuntimeState( + plugin=plugin, + loaded_plugin=loaded_plugin, + lifecycle_context=lifecycle_context, + ) + ) + self._active_plugin_states = list(self._plugin_states) + + def _refresh_dispatchers(self) -> None: + handlers = [ + handler + for state in self._active_plugin_states + for handler in state.loaded_plugin.handlers + ] + capabilities = [ + capability + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + ] + self.dispatcher = HandlerDispatcher( + plugin_id=self.group_id, + peer=self.peer, + handlers=handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.group_id, + peer=self.peer, + capabilities=capabilities, + llm_tools=[ + tool + for state in self._active_plugin_states + for tool in state.loaded_plugin.llm_tools + ], + ) + + async def start(self) -> None: + await 
self.peer.start() + started_states: list[GroupPluginRuntimeState] = [] + try: + active_states: list[GroupPluginRuntimeState] = [] + for state in self._plugin_states: + try: + await self._run_lifecycle(state, "on_start") + except Exception as exc: + self.skipped_plugins[state.plugin.name] = str(exc) + self.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="lifecycle", + plugin_id=state.plugin.name, + message="插件 on_start 失败", + details=str(exc), + ) + ) + logger.exception( + "组 {} 中插件 {} on_start 失败,启动时将跳过", + self.group_id, + state.plugin.name, + ) + continue + active_states.append(state) + started_states.append(state) + + self._active_plugin_states = active_states + self._refresh_dispatchers() + if not self._active_plugin_states: + raise RuntimeError( + f"worker group {self.group_id} has no active plugins" + ) + + await self.peer.initialize( + [ + handler.descriptor + for state in self._active_plugin_states + for handler in state.loaded_plugin.handlers + ], + provided_capabilities=[ + capability.descriptor + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + ], + metadata=self._initialize_metadata(), + ) + except Exception: + for state in reversed(started_states): + try: + await self._run_lifecycle(state, "on_stop") + except Exception: + logger.exception( + "组 {} 在启动失败清理插件 {} on_stop 时发生异常", + self.group_id, + state.plugin.name, + ) + await self.peer.stop() + raise + + async def stop(self) -> None: + first_error: Exception | None = None + try: + for state in reversed(self._active_plugin_states): + try: + await self._run_lifecycle(state, "on_stop") + except Exception as exc: + if first_error is None: + first_error = exc + logger.exception( + "组 {} 停止插件 {} 时发生异常", + self.group_id, + state.plugin.name, + ) + finally: + await self.peer.stop() + if first_error is not None: + raise first_error + + async def _handle_invoke(self, message, cancel_token): + if message.capability == "handler.invoke": + return await 
self.dispatcher.invoke(message, cancel_token) + try: + return await self.capability_dispatcher.invoke(message, cancel_token) + except LookupError as exc: + raise AstrBotError.capability_not_found(message.capability) from exc + + async def _handle_cancel(self, request_id: str) -> None: + await self.dispatcher.cancel(request_id) + await self.capability_dispatcher.cancel(request_id) + + def _initialize_metadata(self) -> dict[str, Any]: + return { + "group_id": self.group_id, + "plugins": [plugin.name for plugin in self.plugins], + "loaded_plugins": [ + state.plugin.name for state in self._active_plugin_states + ], + "skipped_plugins": dict(self.skipped_plugins), + "capability_sources": { + capability.descriptor.name: state.plugin.name + for state in self._active_plugin_states + for capability in state.loaded_plugin.capabilities + }, + "issues": [issue.to_payload() for issue in self.issues], + "llm_tools": [ + { + **tool.spec.to_payload(), + "plugin_id": state.plugin.name, + } + for state in self._active_plugin_states + for tool in state.loaded_plugin.llm_tools + ], + "agents": [ + { + **agent.spec.to_payload(), + "plugin_id": state.plugin.name, + } + for state in self._active_plugin_states + for agent in state.loaded_plugin.agents + ], + "acknowledge_global_mcp_risk": any( + _plugin_acknowledges_global_mcp_risk( + _metadata_plugin_instances(state.loaded_plugin) + ) + for state in self._active_plugin_states + ), + } + + async def _run_lifecycle( + self, + state: GroupPluginRuntimeState, + method_name: str, + ) -> None: + await run_plugin_lifecycle( + state.loaded_plugin.instances, method_name, state.lifecycle_context + ) + + +class PluginWorkerRuntime: + def __init__(self, *, plugin_dir: Path, transport) -> None: + self.plugin = load_plugin_spec(plugin_dir) + self.transport = transport + self.loaded_plugin = load_plugin(self.plugin) + self.peer = Peer( + transport=self.transport, + peer_info=PeerInfo(name=self.plugin.name, role="plugin", version="v4"), + ) + 
self.dispatcher = HandlerDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + handlers=self.loaded_plugin.handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + capabilities=self.loaded_plugin.capabilities, + llm_tools=self.loaded_plugin.llm_tools, + ) + self._lifecycle_context = RuntimeContext( + peer=self.peer, plugin_id=self.plugin.name + ) + self.issues: list[PluginDiscoveryIssue] = [] + self.peer.set_invoke_handler(self._handle_invoke) + self.peer.set_cancel_handler(self._handle_cancel) + + async def start(self) -> None: + await self.peer.start() + lifecycle_started = False + try: + await self._run_lifecycle("on_start") + lifecycle_started = True + await self.peer.initialize( + [item.descriptor for item in self.loaded_plugin.handlers], + provided_capabilities=[ + item.descriptor for item in self.loaded_plugin.capabilities + ], + metadata={ + "plugin_id": self.plugin.name, + "plugins": [self.plugin.name], + "loaded_plugins": [self.plugin.name], + "skipped_plugins": {}, + "issues": [issue.to_payload() for issue in self.issues], + "capability_sources": { + item.descriptor.name: self.plugin.name + for item in self.loaded_plugin.capabilities + }, + "llm_tools": [ + { + **item.spec.to_payload(), + "plugin_id": self.plugin.name, + } + for item in self.loaded_plugin.llm_tools + ], + "agents": [ + { + **item.spec.to_payload(), + "plugin_id": self.plugin.name, + } + for item in self.loaded_plugin.agents + ], + "acknowledge_global_mcp_risk": _plugin_acknowledges_global_mcp_risk( + _metadata_plugin_instances(self.loaded_plugin) + ), + }, + ) + except Exception: + if lifecycle_started: + try: + await self._run_lifecycle("on_stop") + except Exception: + logger.exception( + "插件 {} 在启动失败清理 on_stop 时发生异常", + self.plugin.name, + ) + await self.peer.stop() + raise + + async def stop(self) -> None: + try: + await self._run_lifecycle("on_stop") + finally: + await self.peer.stop() + + async def 
_handle_invoke(self, message, cancel_token): + if message.capability == "handler.invoke": + return await self.dispatcher.invoke(message, cancel_token) + try: + return await self.capability_dispatcher.invoke(message, cancel_token) + except LookupError as exc: + raise AstrBotError.capability_not_found(message.capability) from exc + + async def _handle_cancel(self, request_id: str) -> None: + await self.dispatcher.cancel(request_id) + await self.capability_dispatcher.cancel(request_id) + + async def _run_lifecycle(self, method_name: str) -> None: + await run_plugin_lifecycle( + self.loaded_plugin.instances, method_name, self._lifecycle_context + ) diff --git a/astrbot-sdk/src/astrbot_sdk/schedule.py b/astrbot-sdk/src/astrbot_sdk/schedule.py new file mode 100644 index 0000000000..e0aa20c7a4 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/schedule.py @@ -0,0 +1,60 @@ +"""Schedule-specific SDK types. + +本模块定义定时任务相关的 SDK 类型,主要为 ScheduleContext 提供数据结构。 + +ScheduleContext 包含: +- schedule_id: 调度任务唯一标识 +- plugin_id: 所属插件 ID +- handler_id: 对应 handler 的标识 +- trigger_kind: 触发类型(cron / interval / once) +- cron: cron 表达式(仅 cron 类型) +- interval_seconds: 间隔秒数(仅 interval 类型) +- scheduled_at: 计划执行时间(仅 once 类型) + +使用方式: +通过 @on_schedule 装饰器注册的 handler 可通过参数注入获取 ScheduleContext。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class ScheduleContext: + schedule_id: str + plugin_id: str + handler_id: str + trigger_kind: str + cron: str | None = None + interval_seconds: int | None = None + scheduled_at: str | None = None + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> ScheduleContext: + schedule = payload.get("schedule") + if not isinstance(schedule, dict): + raise ValueError("schedule payload is required") + return cls( + schedule_id=str(schedule.get("schedule_id", "")), + plugin_id=str(schedule.get("plugin_id", "")), + handler_id=str(schedule.get("handler_id", "")), + 
trigger_kind=str(schedule.get("trigger_kind", "")), + cron=( + str(schedule["cron"]) if isinstance(schedule.get("cron"), str) else None + ), + interval_seconds=( + int(schedule["interval_seconds"]) + if isinstance(schedule.get("interval_seconds"), int) + else None + ), + scheduled_at=( + str(schedule["scheduled_at"]) + if isinstance(schedule.get("scheduled_at"), str) + else None + ), + ) + + +__all__ = ["ScheduleContext"] diff --git a/astrbot-sdk/src/astrbot_sdk/session_waiter.py b/astrbot-sdk/src/astrbot_sdk/session_waiter.py new file mode 100644 index 0000000000..2ecc6e0cca --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/session_waiter.py @@ -0,0 +1,664 @@ +"""Session-based conversational flow management. + +本模块实现会话等待器 (session_waiter),用于构建多轮对话流程。 + +核心组件: +- SessionController: 控制会话生命周期,支持超时管理、会话保持、历史记录 +- SessionWaiterManager: 管理活跃的会话等待器,处理事件分发和注册/注销 +- @session_waiter 装饰器: 将普通 handler 转换为会话式 handler + +使用场景: +当需要在用户首次触发后继续监听后续消息(如分步表单、问答游戏), +可使用 @session_waiter 装饰器自动管理会话状态和超时。 + +注意事项: +在当前桥接设计中,不应在普通 SDK handler 内直接 await session_waiter, +这会导致首次 dispatch 保持打开直到下一条消息到达。 +推荐写法是 `await ctx.register_task(waiter(...), "...")`,让 waiter 在后台任务中 +承接后续消息;直接 await 仅适用于你明确需要保持当前 dispatch 挂起的场景。 +""" + +from __future__ import annotations + +import asyncio +import time +import weakref +from collections.abc import Awaitable, Callable, Coroutine +from contextvars import ContextVar +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload + +from loguru import logger + +from ._internal.invocation_context import current_caller_plugin_id +from .events import MessageEvent + +_OwnerT = TypeVar("_OwnerT") +_P = ParamSpec("_P") +_ResultT = TypeVar("_ResultT") +_WaiterKey = tuple[str, str] + +_HANDLER_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_REGISTERED_BACKGROUND_TASKS: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_WARNED_DIRECT_WAIT_TASKS: 
weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet() +_ACTIVE_WAITER_KEY: ContextVar[_WaiterKey | None] = ContextVar( + "astrbot_sdk_active_waiter_key", + default=None, +) + + +class _TaskReentrantLock: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._owner: asyncio.Task[Any] | None = None + self._depth = 0 + + async def acquire(self) -> None: + current_task = asyncio.current_task() + if current_task is None: + raise RuntimeError("session waiter lock requires an active asyncio task") + if self._owner is current_task: + self._depth += 1 + return + await self._lock.acquire() + self._owner = current_task + self._depth = 1 + + def release(self) -> None: + current_task = asyncio.current_task() + if current_task is None or self._owner is not current_task: + raise RuntimeError("session waiter lock released by a non-owner task") + self._depth -= 1 + if self._depth > 0: + return + self._owner = None + self._lock.release() + + async def __aenter__(self) -> _TaskReentrantLock: + await self.acquire() + return self + + async def __aexit__(self, *_exc_info: object) -> None: + self.release() + + +def _mark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None: + _HANDLER_TASKS.add(task) + + +def _unmark_session_waiter_handler_task(task: asyncio.Task[Any]) -> None: + _HANDLER_TASKS.discard(task) + + +def _mark_session_waiter_background_task(task: asyncio.Task[Any]) -> None: + _REGISTERED_BACKGROUND_TASKS.add(task) + + +def _unmark_session_waiter_background_task(task: asyncio.Task[Any]) -> None: + _REGISTERED_BACKGROUND_TASKS.discard(task) + + +class _SessionWaiterDecorator(Protocol): + @overload + def __call__( + self, + func: Callable[ + Concatenate[SessionController, MessageEvent, _P], + Awaitable[_ResultT], + ], + /, + ) -> Callable[Concatenate[MessageEvent, _P], Coroutine[Any, Any, _ResultT]]: ... 
+ + @overload + def __call__( + self, + func: Callable[ + Concatenate[_OwnerT, SessionController, MessageEvent, _P], + Awaitable[_ResultT], + ], + /, + ) -> Callable[ + Concatenate[_OwnerT, MessageEvent, _P], + Coroutine[Any, Any, _ResultT], + ]: ... + + +@dataclass(slots=True) +class SessionController: + future: asyncio.Future[Any] = field(default_factory=asyncio.Future) + current_event: asyncio.Event | None = None + ts: float | None = None + timeout: float | None = None + history_chains: list[list[dict[str, Any]]] = field(default_factory=list) + + def stop(self, error: Exception | None = None) -> None: + if self.future.done(): + return + if error is not None: + self.future.set_exception(error) + else: + self.future.set_result(None) + + def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None: + new_ts = time.time() + if reset_timeout: + if timeout <= 0: + self.stop() + return + else: + assert self.timeout is not None + assert self.ts is not None + left_timeout = self.timeout - (new_ts - self.ts) + timeout = left_timeout + timeout + if timeout <= 0: + self.stop() + return + + if self.current_event and not self.current_event.is_set(): + self.current_event.set() + + current_event = asyncio.Event() + self.current_event = current_event + self.ts = new_ts + self.timeout = timeout + asyncio.create_task(self._holding(current_event, timeout)) + + async def _holding(self, event: asyncio.Event, timeout: float) -> None: + try: + await asyncio.wait_for(event.wait(), timeout) + except asyncio.TimeoutError as exc: + self.stop(exc) + except asyncio.CancelledError: + return + + def get_history_chains(self) -> list[list[dict[str, Any]]]: + return list(self.history_chains) + + +@dataclass(slots=True) +class _WaiterEntry: + session_key: str + plugin_id: str + handler: Callable[[SessionController, MessageEvent], Awaitable[Any]] + controller: SessionController + record_history_chains: bool + unregister_enabled: bool = True + + +class SessionWaiterManager: + def 
__init__(self, *, plugin_id: str, peer) -> None: + self._plugin_id = plugin_id + self._peer = peer + self._entries: dict[str, dict[str, _WaiterEntry]] = {} + self._locks: dict[_WaiterKey, _TaskReentrantLock] = {} + + @staticmethod + def _make_key(*, plugin_id: str, session_key: str) -> _WaiterKey: + return (plugin_id, session_key) + + async def register( + self, + *, + event: MessageEvent, + handler: Callable[[SessionController, MessageEvent], Awaitable[Any]], + timeout: int, + record_history_chains: bool, + ) -> Any: + if event._context is None: + raise RuntimeError("session_waiter requires runtime context") + self._warn_if_direct_wait_in_handler(event) + session_key = event.unified_msg_origin + plugin_id = self._resolve_plugin_id(event) + entry = _WaiterEntry( + session_key=session_key, + plugin_id=plugin_id, + handler=handler, + controller=SessionController(), + record_history_chains=record_history_chains, + ) + previous = self._entries.setdefault(session_key, {}).get(plugin_id) + restorable_previous: _WaiterEntry | None = None + self._entries[session_key][plugin_id] = entry + self._lock_for(session_key, plugin_id) + if previous is not None: + previous.unregister_enabled = False + if _ACTIVE_WAITER_KEY.get() == self._make_key( + plugin_id=plugin_id, + session_key=session_key, + ): + restorable_previous = previous + else: + self._finish_entry( + previous, + RuntimeError("session waiter replaced by a newer waiter"), + ) + logger.warning( + "Session waiter replaced: plugin_id={} session_key={}", + plugin_id, + session_key, + ) + try: + await self._invoke_system_waiter( + "system.session_waiter.register", + session_key=session_key, + plugin_id=plugin_id, + ) + entry.controller.keep(timeout, reset_timeout=True) + except Exception: + entry.unregister_enabled = False + await self._remove_entry(entry) + if restorable_previous is not None: + self._entries.setdefault(session_key, {})[plugin_id] = ( + restorable_previous + ) + restorable_previous.unregister_enabled = True 
+ self._lock_for(session_key, plugin_id) + raise + try: + return await entry.controller.future + finally: + if entry.unregister_enabled: + await self.unregister(session_key, plugin_id=plugin_id) + + def _warn_if_direct_wait_in_handler(self, event: MessageEvent) -> None: + current_task = asyncio.current_task() + if current_task is None: + return + if current_task not in _HANDLER_TASKS: + return + if current_task in _REGISTERED_BACKGROUND_TASKS: + return + if current_task in _WARNED_DIRECT_WAIT_TASKS: + return + _WARNED_DIRECT_WAIT_TASKS.add(current_task) + logger.warning( + "Direct await on session_waiter blocks the current handler dispatch; " + 'prefer `await ctx.register_task(waiter(...), "...")`: ' + "plugin_id={} session_key={}", + event._context.plugin_id, + event.unified_msg_origin, + ) + + async def wait_for_event( + self, + *, + event: MessageEvent, + timeout: int, + record_history_chains: bool = False, + ) -> MessageEvent: + future: asyncio.Future[MessageEvent] = ( + asyncio.get_running_loop().create_future() + ) + + async def _handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + if not future.done(): + future.set_result(waiter_event) + controller.stop() + + await self.register( + event=event, + handler=_handler, + timeout=timeout, + record_history_chains=record_history_chains, + ) + return future.result() + + async def unregister( + self, + session_key: str, + *, + plugin_id: str | None = None, + ) -> None: + target_plugin_id = self._resolve_unregister_plugin_id( + session_key, + plugin_id=plugin_id, + ) + if target_plugin_id is None: + return + lock_key = (session_key, target_plugin_id) + lock = self._lock_for(session_key, target_plugin_id) + removed = False + async with lock: + session_entries = self._entries.get(session_key) + if session_entries is None: + return + removed = session_entries.pop(target_plugin_id, None) is not None + if not session_entries: + self._entries.pop(session_key, None) + if 
self._locks.get(lock_key) is lock: + self._locks.pop(lock_key, None) + if not removed: + return + try: + await self._invoke_system_waiter( + "system.session_waiter.unregister", + session_key=session_key, + plugin_id=target_plugin_id, + ) + except Exception: + logger.debug( + "Failed to unregister session waiter: plugin_id={} session_key={}", + target_plugin_id, + session_key, + ) + + async def fail( + self, + session_key: str, + error: Exception, + *, + plugin_id: str | None = None, + ) -> bool: + resolved_plugin_id = plugin_id + if resolved_plugin_id is None: + caller_plugin_id = current_caller_plugin_id() + if caller_plugin_id: + resolved_plugin_id = caller_plugin_id + entry = self._select_entry( + session_key, + plugin_id=resolved_plugin_id, + allow_ambiguous=False, + missing_result=None, + ) + if entry is None: + return False + lock = self._lock_for(session_key, entry.plugin_id) + async with lock: + current = self._get_entry(session_key, entry.plugin_id) + if current is None or current.controller.future.done(): + return False + self._finish_entry(current, error) + return True + + def has_active_waiter(self, event: MessageEvent) -> bool: + session_key = event.unified_msg_origin + event_plugin_id = self._event_plugin_id(event) + if event_plugin_id is not None: + entry = self._get_entry(session_key, event_plugin_id) + return entry is not None and not entry.controller.future.done() + return bool(self.get_waiter_plugin_ids(session_key)) + + def has_waiter(self, event: MessageEvent) -> bool: + return self.has_active_waiter(event) + + def get_waiter_plugin_ids(self, session_key: str) -> list[str]: + return sorted( + plugin_id + for plugin_id, entry in self._entries.get(session_key, {}).items() + if not entry.controller.future.done() + ) + + async def dispatch( + self, + event: MessageEvent, + *, + plugin_id: str | None = None, + ) -> dict[str, Any]: + if event._context is None: + raise RuntimeError("session_waiter dispatch requires runtime context") + session_key = 
event.unified_msg_origin + entry = self._select_entry( + session_key, + plugin_id=plugin_id, + allow_ambiguous=False, + missing_result=None, + ambiguous_error=LookupError( + f"session waiter dispatch for session '{session_key}' requires explicit plugin identity" + ), + ) + if entry is None: + return {"sent_message": False, "stop": False, "call_llm": False} + lock = self._lock_for(session_key, entry.plugin_id) + async with lock: + current = self._get_entry(session_key, entry.plugin_id) + if current is None or current.controller.future.done(): + return {"sent_message": False, "stop": False, "call_llm": False} + waiter_event = self._build_waiter_event(current, event) + if current.record_history_chains: + chain = [] + raw_chain = ( + waiter_event.raw.get("chain") + if isinstance(waiter_event.raw, dict) + else None + ) + if isinstance(raw_chain, list): + chain = [dict(item) for item in raw_chain if isinstance(item, dict)] + current.controller.history_chains.append(chain) + active_key_token = _ACTIVE_WAITER_KEY.set( + self._make_key( + plugin_id=current.plugin_id, + session_key=current.session_key, + ) + ) + try: + # Keep follow-up handler execution serialized per waiter while still + # allowing nested waiter cleanup in the same task to re-enter safely. 
+ await current.handler(current.controller, waiter_event) + finally: + _ACTIVE_WAITER_KEY.reset(active_key_token) + return { + "sent_message": False, + "stop": waiter_event.is_stopped(), + "call_llm": False, + } + + def _resolve_plugin_id(self, event: MessageEvent) -> str: + caller_plugin_id = current_caller_plugin_id() + if caller_plugin_id: + return caller_plugin_id + context = event._context + if context is not None and context.plugin_id.strip(): + return context.plugin_id + return self._plugin_id + + @staticmethod + def _event_plugin_id(event: MessageEvent) -> str | None: + context = event._context + if context is None: + return None + plugin_id = context.plugin_id.strip() + return plugin_id or None + + def _resolve_unregister_plugin_id( + self, + session_key: str, + *, + plugin_id: str | None, + ) -> str | None: + if plugin_id is not None: + normalized = str(plugin_id).strip() + return normalized or None + session_entries = self._entries.get(session_key, {}) + if len(session_entries) != 1: + return None + return next(iter(session_entries)) + + def _select_entry( + self, + session_key: str, + *, + plugin_id: str | None, + allow_ambiguous: bool, + missing_result: _WaiterEntry | None, + ambiguous_error: Exception | None = None, + ) -> _WaiterEntry | None: + if plugin_id is not None: + return self._get_entry(session_key, plugin_id) + active_entries = [ + entry + for entry in self._entries.get(session_key, {}).values() + if not entry.controller.future.done() + ] + if not active_entries: + return missing_result + if len(active_entries) > 1 and not allow_ambiguous: + if ambiguous_error is not None: + raise ambiguous_error + return missing_result + return active_entries[0] + + def _get_entry(self, session_key: str, plugin_id: str) -> _WaiterEntry | None: + return self._entries.get(session_key, {}).get(plugin_id) + + def _lock_for(self, session_key: str, plugin_id: str) -> _TaskReentrantLock: + return self._locks.setdefault((session_key, plugin_id), 
_TaskReentrantLock()) + + async def _remove_entry(self, entry: _WaiterEntry) -> None: + lock_key = (entry.session_key, entry.plugin_id) + lock = self._lock_for(entry.session_key, entry.plugin_id) + async with lock: + session_entries = self._entries.get(entry.session_key) + if session_entries is None: + return + current = session_entries.get(entry.plugin_id) + if current is not entry: + return + session_entries.pop(entry.plugin_id, None) + if not session_entries: + self._entries.pop(entry.session_key, None) + if self._locks.get(lock_key) is lock: + self._locks.pop(lock_key, None) + + @staticmethod + def _finish_entry(entry: _WaiterEntry, error: Exception | None = None) -> None: + entry.controller.stop(error) + if ( + entry.controller.current_event is not None + and not entry.controller.current_event.is_set() + ): + entry.controller.current_event.set() + + async def _invoke_system_waiter( + self, + capability: str, + *, + session_key: str, + plugin_id: str, + ) -> None: + from ._internal.invocation_context import caller_plugin_scope + + with caller_plugin_scope(plugin_id): + await self._peer.invoke( + capability, + {"session_key": session_key}, + ) + + def _build_waiter_event( + self, + entry: _WaiterEntry, + event: MessageEvent, + ) -> MessageEvent: + from .context import Context + + source_payload = self._source_payload_from_event(event) + cancel_token = ( + event._context.cancel_token if event._context is not None else None + ) + waiter_context = Context( + peer=self._peer, + plugin_id=entry.plugin_id, + request_id=( + event._context.request_id if event._context is not None else None + ), + cancel_token=cancel_token, + source_event_payload=source_payload, + ) + # Rebuild the event so the waiter always sees the registering plugin identity + # and the exact source payload that triggered the follow-up dispatch. 
+ return MessageEvent.from_payload( + source_payload, + context=waiter_context, + ) + + @staticmethod + def _source_payload_from_event(event: MessageEvent) -> dict[str, Any]: + raw_payload = event.raw if isinstance(event.raw, dict) else None + if raw_payload is not None and { + "text", + "session_id", + "platform", + }.issubset(raw_payload): + return dict(raw_payload) + return event.to_payload() + + +def session_waiter( + timeout: int = 30, + *, + record_history_chains: bool = False, +) -> _SessionWaiterDecorator: + def decorator( + func: Callable[..., Awaitable[Any]], + ) -> Callable[..., Coroutine[Any, Any, Any]]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + owner = None + event: MessageEvent | None = None + trailing_args: tuple[Any, ...] = () + if args and isinstance(args[0], MessageEvent): + event = args[0] + trailing_args = args[1:] + elif len(args) >= 2 and isinstance(args[1], MessageEvent): + owner = args[0] + event = args[1] + trailing_args = args[2:] + if event is None: + raise RuntimeError("session_waiter requires a MessageEvent argument") + if event._context is None: + raise RuntimeError("session_waiter requires runtime context") + manager = getattr(event._context.peer, "_session_waiter_manager", None) + if manager is None: + raise RuntimeError("session_waiter manager is unavailable") + + if owner is None: + free_func = cast(Callable[..., Awaitable[Any]], func) + + async def bound_handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> Any: + return await free_func( + controller, + waiter_event, + *trailing_args, + **kwargs, + ) + else: + method_func = cast(Callable[..., Awaitable[Any]], func) + + async def bound_handler( + controller: SessionController, + waiter_event: MessageEvent, + ) -> Any: + return await method_func( + owner, + controller, + waiter_event, + *trailing_args, + **kwargs, + ) + + return await manager.register( + event=event, + handler=bound_handler, + timeout=timeout, + 
record_history_chains=record_history_chains, + ) + + return wrapper + + return cast(_SessionWaiterDecorator, decorator) + + +__all__ = [ + "_OwnerT", + "_P", + "_ResultT", + "SessionController", + "SessionWaiterManager", + "session_waiter", +] diff --git a/astrbot-sdk/src/astrbot_sdk/star.py b/astrbot-sdk/src/astrbot_sdk/star.py new file mode 100644 index 0000000000..ef774b4e78 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/star.py @@ -0,0 +1,132 @@ +"""v4 原生插件基类。""" + +from __future__ import annotations + +import json +import traceback +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Any, cast + +from loguru import logger + +from .errors import AstrBotError +from .plugin_kv import PluginKVStoreMixin + +if TYPE_CHECKING: + from .context import Context + + +class Star(PluginKVStoreMixin): + """v4 原生插件基类。""" + + __handlers__: tuple[str, ...] = () + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + from .decorators import get_handler_meta + + handlers: dict[str, None] = {} + for base in reversed(cls.__mro__): + for name, attr in getattr(base, "__dict__", {}).items(): + func = getattr(attr, "__func__", attr) + meta = get_handler_meta(func) + if meta is not None and meta.trigger is not None: + handlers[name] = None + cls.__handlers__ = tuple(handlers.keys()) + + @property + def context(self) -> Context | None: + return self._context_var().get() + + def _require_runtime_context(self) -> Context: + ctx = self.context + if ctx is None: + raise RuntimeError( + "Star runtime context is only available during lifecycle, " + "handler, and registered LLM tool execution" + ) + return ctx + + def _context_var(self) -> ContextVar[Context | None]: + existing_context_var = getattr(self, "__astrbot_context_var__", None) + if isinstance(existing_context_var, ContextVar): + return cast("ContextVar[Context | None]", existing_context_var) + created_context_var: ContextVar[Context | None] = ContextVar( + 
f"astrbot_sdk_star_context_{id(self)}", + default=None, + ) + setattr(self, "__astrbot_context_var__", created_context_var) + return created_context_var + + def _bind_runtime_context(self, ctx: Context | None) -> Token[Context | None]: + return self._context_var().set(ctx) + + def _reset_runtime_context(self, token: Token[Context | None]) -> None: + self._context_var().reset(token) + + async def on_start(self, ctx: Any | None = None) -> None: + await self.initialize() + + async def on_stop(self, ctx: Any | None = None) -> None: + await self.terminate() + + async def initialize(self) -> None: + return None + + async def terminate(self) -> None: + return None + + async def text_to_image( + self, + text: str, + *, + return_url: bool = True, + ) -> str: + return await self._require_runtime_context().text_to_image( + text, + return_url=return_url, + ) + + async def html_render( + self, + tmpl: str, + data: dict[str, Any], + *, + return_url: bool = True, + options: dict[str, Any] | None = None, + ) -> str: + return await self._require_runtime_context().html_render( + tmpl, + data, + return_url=return_url, + options=options, + ) + + @staticmethod + async def default_on_error(error: Exception, event, ctx) -> None: + del ctx + if isinstance(error, AstrBotError): + lines: list[str] = [] + if error.retryable: + lines.append("请求失败,请稍后重试") + elif error.hint: + lines.append(error.hint) + else: + lines.append(error.message) + if error.docs_url: + lines.append(f"文档:{error.docs_url}") + if error.details: + lines.append( + f"详情:{json.dumps(error.details, ensure_ascii=False, sort_keys=True)}" + ) + await event.reply("\n".join(lines)) + else: + await event.reply("出了点问题,请联系插件作者") + logger.error("handler 执行失败\n{}", traceback.format_exc()) + + async def on_error(self, error: Exception, event, ctx) -> None: + await Star.default_on_error(error, event, ctx) + + @classmethod + def __astrbot_is_new_star__(cls) -> bool: + return True diff --git a/astrbot-sdk/src/astrbot_sdk/star_tools.py 
b/astrbot-sdk/src/astrbot_sdk/star_tools.py new file mode 100644 index 0000000000..62657ae318 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/star_tools.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from typing import Any + +from ._internal.star_runtime import current_star_context +from .context import Context +from .message.components import BaseMessageComponent +from .message.result import MessageChain +from .message.session import MessageSession + + +class _StarToolsContextDescriptor: + def __get__(self, _instance: object, _owner: type[object]) -> Context | None: + return current_star_context() + + +class StarTools: + """Star 工具类,提供类方法访问运行时上下文能力。 + + 所有方法都通过当前上下文动态路由到对应的能力接口。 + 只在 lifecycle、handler 和已注册的 LLM tool 执行期间可用。 + """ + + _context = _StarToolsContextDescriptor() + + @classmethod + def _get_context(cls) -> Context | None: + """获取当前 Star 运行时上下文。""" + return cls._context + + @classmethod + def _require_context(cls) -> Context: + """获取当前运行时上下文,如果不存在则抛出 RuntimeError。""" + ctx = current_star_context() + if ctx is None: + raise RuntimeError( + "StarTools context is only available during lifecycle, " + "handler, and registered LLM tool execution" + ) + return ctx + + @classmethod + def get_llm_tool_manager(cls): + return cls._require_context().get_llm_tool_manager() + + @classmethod + async def activate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().activate_llm_tool(name) + + @classmethod + async def deactivate_llm_tool(cls, name: str) -> bool: + return await cls._require_context().deactivate_llm_tool(name) + + @classmethod + async def send_message( + cls, + session: str | MessageSession, + content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + ) -> dict[str, Any]: + return await cls._require_context().send_message(session, content) + + @classmethod + async def send_message_by_id( + cls, + type: str, + id: str, + 
content: ( + str + | MessageChain + | Sequence[BaseMessageComponent] + | Sequence[dict[str, Any]] + ), + *, + platform: str, + ) -> dict[str, Any]: + return await cls._require_context().send_message_by_id( + type, + id, + content, + platform=platform, + ) + + @classmethod + async def register_llm_tool( + cls, + name: str, + parameters_schema: dict[str, Any], + desc: str, + func_obj: Callable[..., Awaitable[Any]] | Callable[..., Any], + *, + active: bool = True, + ) -> list[str]: + return await cls._require_context().register_llm_tool( + name, + parameters_schema, + desc, + func_obj, + active=active, + ) + + @classmethod + async def unregister_llm_tool(cls, name: str) -> bool: + return await cls._require_context().unregister_llm_tool(name) + + @classmethod + async def register_skill( + cls, + *, + name: str, + path: str, + description: str = "", + ): + return await cls._require_context().skills.register( + name=name, + path=path, + description=description, + ) + + @classmethod + async def unregister_skill(cls, name: str) -> bool: + return await cls._require_context().skills.unregister(name) diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md new file mode 100644 index 0000000000..62f37b485c --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/SKILL.md @@ -0,0 +1,168 @@ +--- +name: {{ skill_dir_name }} +description: Design, implement, test, and package AstrBot SDK v4 plugins. Activate when the request involves AstrBot plugins, plugin.yaml, main.py, Star base class, Context, MessageEvent, SDK decorators (on_command, on_message, on_event, on_schedule, conversation_command, provide_capability, register_llm_tool, http_api, background_task), PluginHarness, or astrbot-sdk CLI commands (init, validate, dev, build). +--- + +# AstrBot Plugin Dev + +Turn plugin requirements into working AstrBot SDK v4 plugins. 
Default to the stable public SDK surface and CLI workflow. + +## Project Context + +- Plugin name: `{{ plugin_name }}` +- Display name: `{{ display_name }}` +- Plugin root: `{{ plugin_root }}` +- Target agent: `{{ agent_display_name }}` + +## Step 1 — Classify the task + +- **New plugin**: scaffold with `astrbot-sdk init <plugin-name>` (fallback: `python -m astrbot_sdk init <plugin-name>`). +- **Existing plugin**: read `plugin.yaml` + component module first; never run `init` inside it. +- **Migration**: v3→v4 conversion — see migration notes at the end. + +## Step 2 — Map intent to decorators + +| User wants … | Decorator | +|---|---| +| Slash-style command (`/hello`) | `@on_command` | +| Keyword or regex reaction | `@on_message` | +| Non-message platform event (join, load …) | `@on_event` | +| Periodic / cron task | `@on_schedule` | +| Multi-turn dialogue / form flow | `@conversation_command` | +| Expose HTTP endpoint | `@http_api` | +| Inter-plugin callable capability | `@provide_capability` | +| Give the LLM a callable tool | `@register_llm_tool` | +| Continuous background loop | `@background_task` | +| MCP server exposure | `@mcp_server` | + +Add guards as needed: `@require_admin`, `@platforms(...)`, `@group_only()`, `@private_only()`, `@rate_limit(...)`, `@cooldown(...)`, `@priority(...)`. + +## Step 3 — Implement + +### Handler signatures (must match trigger type) + +```python +# Command / message / event handlers +async def handler(self, event: MessageEvent, ctx: Context) -> None: ... + +# Command with typed parameters (GreedyStr MUST be last) +async def handler(self, event: MessageEvent, ctx: Context, name: str, content: GreedyStr) -> None: ... + +# Schedule handler — NO event parameter +async def handler(self, ctx: Context) -> None: ... + +# Conversation command — receives ConversationSession +async def handler(self, event: MessageEvent, ctx: Context, session: ConversationSession) -> None: ... 
+ +# Capability handler +async def handler(self, payload: dict, ctx: Context) -> dict: ... + +# LLM tool — keyword arguments matching schema +async def handler(self, city: str, unit: str = "celsius") -> dict: ... + +# Background task +async def handler(self, ctx: Context) -> None: ... +``` + +### Imports + +```python +# Core — always needed +from astrbot_sdk import Star, Context, MessageEvent + +# Decorators — import only what you use +from astrbot_sdk.decorators import on_command, on_message, on_schedule # etc. + +# Typed command params +from astrbot_sdk import GreedyStr + +# Rich messages +from astrbot_sdk import MessageBuilder, MessageChain, Plain, Image, At, AtAll, File + +# Conversation +from astrbot_sdk import ConversationSession + +# Config validation +from pydantic import BaseModel +from astrbot_sdk.decorators import validate_config +``` + +### Public API boundaries + +**Use freely:** `astrbot_sdk.*`, `astrbot_sdk.decorators.*`, `astrbot_sdk.clients.*`, `astrbot_sdk.testing.*` + +**Never use in plugin code:** `astrbot_sdk.runtime.*`, Worker/Supervisor, Loader, Peer/Transport internals. + +## Step 4 — Validate and test + +```bash +# Validate structure + imports + handler discovery +astrbot-sdk validate --plugin-dir <plugin-dir> + +# Single-shot local test +astrbot-sdk dev --local --plugin-dir <plugin-dir> --event-text "<message>" + +# Run tests +python -m pytest tests -q + +# Package (optional) +astrbot-sdk build --plugin-dir <plugin-dir> +``` + +If `astrbot-sdk` is not on PATH, use `python -m astrbot_sdk <command>` instead. + +## Guardrails — MUST follow + +### Context handling +- **NEVER** store `ctx` on `self` outside the active handler or lifecycle call. +- In `on_start` / `on_stop`, always call `await super().on_start(ctx)` / `await super().on_stop(ctx)`. + +### Client API semantics (prevents common bugs) +- `ctx.db.delete(key)` returns **None**, not bool. Check existence with `ctx.db.get()` first if you need to know. +- `ctx.db.get(key)` returns **None** for missing keys, does not raise. 
+- `ctx.db.list(prefix)` returns `list[str]` of key names, not values. +- `ctx.memory.save(key, value)` — value **must** be `dict`, not `str`/`int`. Raises `TypeError` otherwise. +- `ctx.memory.delete_many(keys)` returns `int` (count deleted), not list. +- `ctx.llm.chat(prompt)` returns `str`; use `chat_raw()` for `LLMResponse` with usage/tool_calls. +- `ctx.metadata.get_plugin_config()` raises `PermissionError` if accessing another plugin's config. + +### Parameter injection +- `GreedyStr` must be the **last** parameter in the handler signature. +- Typed parameters (`str`, `int`, `float`, `bool`) are parsed positionally from command text. +- `@on_schedule` handlers have **no** `event` parameter. +- `@conversation_command` handlers receive `ConversationSession` via injection. + +### Decorator stacking order +Place trigger decorator **first** (topmost), then guards below: +```python +@on_command("admin-cmd") # trigger first +@require_admin # then guard +@rate_limit(5, 60.0) # then throttle +async def admin_cmd(self, event: MessageEvent, ctx: Context) -> None: ... +``` + +### Testing +- **NEVER** use `from main import MyPlugin` in tests — pollutes `sys.modules["main"]`. +- Use `PluginHarness.from_plugin_dir(plugin_dir)` exclusively. +- Ignore `__pycache__` / `*.pyc` when copying fixtures. +- `dispatch_text()` returns `list[RecordedSend]`; check `record.text` for reply content. + +### Message components +- `Plain` serializes as `type: "text"`, not `"plain"`. +- Use `Image.fromURL(url)` or `Image.fromFileSystem(path)` factory methods; the constructor takes `file` param directly. +- `MessageBuilder` is fluent: `.text("hi").at("123").image("url").build()` → `MessageChain`. 
+ +## v3 → v4 Migration + +- `astrbot.api.star.Star` → `astrbot_sdk.Star` +- Old filter decorators → `astrbot_sdk.decorators.*` +- `self.context` in handlers → injected `ctx` parameter +- Direct KV helpers → `ctx.db.*` or `PluginKVStoreMixin` + +## References + +Read only files needed for the task: + +- `references/api-quick-ref.md` — complete decorator parameters, client methods with return types +- `references/plugin-patterns.md` — full working plugin examples by pattern +- `references/project-structure.md` — plugin.yaml schema, testing patterns, CLI commands diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml new file mode 100644 index 0000000000..07d93bce17 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "AstrBot Plugin Dev ({{ agent_display_name }})" + short_description: "Use AstrBot SDK to design, build, test, and package plugins." + default_prompt: "Use ${{ skill_dir_name }} to design and implement an AstrBot SDK v4 plugin from the user request." 
diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/api-quick-ref.md b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/api-quick-ref.md new file mode 100644 index 0000000000..2c8990c5f9 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/api-quick-ref.md @@ -0,0 +1,482 @@ +# API Quick Reference + +## Trigger Decorators + +### @on_command + +```python +@on_command(command, *, aliases=None, description=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| command | `str \| Sequence[str]` | required | First item is canonical name | +| aliases | `list[str] \| None` | None | Alternative names | +| description | `str \| None` | None | Help text | + +**Handler signature:** `async def handler(self, event: MessageEvent, ctx: Context) -> None` + +With typed params: `async def handler(self, event: MessageEvent, ctx: Context, name: str, text: GreedyStr) -> None` + +- Parameters after `event`/`ctx` are parsed positionally from command text. +- `GreedyStr` must be the **last** parameter — captures all remaining text. +- Supported types: `str`, `int`, `float`, `bool`, `GreedyStr`. + +### @on_message + +```python +@on_message(*, regex=None, keywords=None, platforms=None, message_types=None, description=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| regex | `str \| None` | None | Python `re` pattern | +| keywords | `list[str] \| None` | None | Any keyword match triggers | +| platforms | `list[str] \| None` | None | Platform filter | +| message_types | `list[str] \| None` | None | "group", "private" | +| description | `str \| None` | None | — | + +Must provide at least `regex` or `keywords`. + +**Handler signature:** `async def handler(self, event: MessageEvent, ctx: Context) -> None` + +**PITFALL:** Do not combine `@on_message(platforms=...)` with a separate `@platforms()` decorator. 
+ +### @on_event + +```python +@on_event(event_type, *, description=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| event_type | `str` | required | e.g., "group_member_join", "astrbot_loaded" | +| description | `str \| None` | None | — | + +**Handler signature:** `async def handler(self, event, ctx: Context) -> None` + +Note: `event` may not be a `MessageEvent` — it can be a raw dict depending on event type. + +### @on_schedule + +```python +@on_schedule(*, cron=None, interval_seconds=None, description=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| cron | `str \| None` | None | Cron expression (e.g., "0 8 * * *") | +| interval_seconds | `int \| None` | None | Seconds between invocations | +| description | `str \| None` | None | — | + +Must provide exactly one of `cron` or `interval_seconds`. + +**Handler signature:** `async def handler(self, ctx: Context) -> None` + +Optional: `async def handler(self, ctx: Context, schedule: ScheduleContext) -> None` + +**PITFALL:** No `event` parameter — you cannot call `event.reply()`. Use `ctx.platform.send()` for proactive messages. 
+ +### @conversation_command + +```python +@conversation_command(command, *, aliases=None, description=None, timeout=60, mode="replace", busy_message=None, grace_period=1.0) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| command | `str \| Sequence[str]` | required | Command name(s) | +| aliases | `list[str] \| None` | None | — | +| description | `str \| None` | None | — | +| timeout | `int` | 60 | Session timeout in seconds (must be positive) | +| mode | `"replace" \| "reject"` | "replace" | "replace" cancels old session; "reject" denies new | +| busy_message | `str \| None` | None | Reply when rejecting (mode="reject") | +| grace_period | `float` | 1.0 | Cleanup grace period in seconds (must be positive) | + +**Handler signature:** `async def handler(self, event: MessageEvent, ctx: Context, session: ConversationSession) -> None` + +`ConversationSession` key methods: +- `await session.ask(prompt, timeout=None)` → `MessageEvent` (waits for user reply) +- `await session.reply(text)` → `None` (sends without waiting) +- `await session.reply_chain(chain)` → `None` +- `session.end()` → `None` (marks session completed) + +Raises `TimeoutError`, `CancelledError`, or `ConversationReplaced` on session loss. 
+ +### @provide_capability + +```python +@provide_capability(name, *, description, input_schema=None, output_schema=None, input_model=None, output_model=None, supports_stream=False, cancelable=False) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| name | `str` | required | e.g., "my_plugin.calculate" — no reserved prefixes | +| description | `str` | required | — | +| input_schema | `dict \| None` | None | JSON Schema (mutually exclusive with input_model) | +| output_schema | `dict \| None` | None | JSON Schema (mutually exclusive with output_model) | +| input_model | `type[BaseModel] \| None` | None | Pydantic model (mutually exclusive with input_schema) | +| output_model | `type[BaseModel] \| None` | None | Pydantic model (mutually exclusive with output_schema) | +| supports_stream | `bool` | False | — | +| cancelable | `bool` | False | — | + +**Handler signature:** `async def handler(self, payload: dict, ctx: Context) -> dict` + +Reserved name prefixes (cannot use): `"handler."`, `"system."`, `"internal."` + +### @register_llm_tool + +```python +@register_llm_tool(name=None, *, description=None, parameters_schema=None, active=True) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| name | `str \| None` | None | Defaults to function name | +| description | `str \| None` | None | — | +| parameters_schema | `dict \| None` | None | JSON Schema; auto-generated from signature if omitted | +| active | `bool` | True | Whether tool is active by default | + +**Handler signature:** `async def handler(self, **kwargs) -> Any` + +Parameters in the function signature are used for auto-schema generation. 
+ +### @http_api + +```python +@http_api(route, *, methods=None, description="", capability_name=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| route | `str` | required | e.g., "/api/status" | +| methods | `list[str] \| None` | ["GET"] | HTTP methods | +| description | `str` | "" | — | +| capability_name | `str \| None` | None | Optional capability name override | + +**Handler signature:** `async def handler(self, payload: dict, ctx: Context) -> dict` + +### @background_task + +```python +@background_task(*, description="", auto_start=True, on_error="log") +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| description | `str` | "" | — | +| auto_start | `bool` | True | Start when plugin starts | +| on_error | `"log" \| "restart"` | "log" | Error handling strategy | + +**Handler signature:** `async def handler(self, ctx: Context) -> None` + +Typically contains a `while True:` loop with `await asyncio.sleep(...)`. + +### @validate_config + +```python +@validate_config(*, model=None, schema=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| model | `type[BaseModel] \| None` | None | Pydantic model (mutually exclusive with schema) | +| schema | `dict \| None` | None | JSON Schema (mutually exclusive with model) | + +Must provide exactly one of `model` or `schema`. + +### @on_provider_change + +```python +@on_provider_change(*, provider_types=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| provider_types | `list[str] \| tuple[str, ...] 
\| None` | None | e.g., ["llm", "embedding", "tts"] | + +**Handler signature:** `async def handler(self, ctx: Context) -> None` + +### @mcp_server + +```python +@mcp_server(*, name, scope="global", config=None, timeout=30.0, wait_until_ready=True) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| name | `str` | required | Non-empty | +| scope | `"local" \| "global"` | "global" | — | +| config | `dict \| None` | None | — | +| timeout | `float` | 30.0 | Must be positive | +| wait_until_ready | `bool` | True | — | + +### @register_skill + +```python +@register_skill(*, name, path, description="") +``` + +Class-level decorator. Can stack multiple. + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| name | `str` | required | Non-empty | +| path | `str` | required | Skill file path, non-empty | +| description | `str` | "" | — | + +### @register_agent + +```python +@register_agent(name, *, description="", tool_names=None) +``` + +Must decorate a `BaseAgentRunner` subclass. + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| name | `str` | required | — | +| description | `str` | "" | — | +| tool_names | `list[str] \| None` | None | Available tool names | + +--- + +## Filter / Guard Decorators + +### @require_admin / @admin_only + +No arguments. Restricts to admin users. + +```python +@on_command("admin-cmd") +@require_admin +async def handler(self, event: MessageEvent, ctx: Context) -> None: ... +``` + +### @platforms + +```python +@platforms(*names: str) +``` + +Restrict to specific platforms (e.g., "qq", "wechat"). + +**PITFALL:** Cannot combine with `@on_message(platforms=...)`. + +### @group_only() / @private_only() + +Called with parentheses. Restrict to group or private messages. + +**PITFALL:** Cannot combine `@group_only()` with `@private_only()`. 
+ +### @message_types + +```python +@message_types(*types: str) +``` + +e.g., `@message_types("group", "private")` + +**PITFALL:** Cannot combine with `@group_only()` / `@private_only()`. + +### @rate_limit + +```python +@rate_limit(limit, window, *, scope="session", behavior="hint", message=None) +``` + +| Param | Type | Default | Notes | +|-------|------|---------|-------| +| limit | `int` | required | Max invocations per window (positive) | +| window | `float` | required | Seconds (positive) | +| scope | `"session" \| "user" \| "group" \| "global"` | "session" | Limiter scope | +| behavior | `"hint" \| "silent" \| "error"` | "hint" | hint=reply, silent=drop, error=raise | +| message | `str \| None` | None | Custom message for "hint" | + +### @cooldown + +```python +@cooldown(seconds, *, scope="session", behavior="hint", message=None) +``` + +Shorthand for `@rate_limit(1, seconds, ...)`. + +### @priority + +```python +@priority(value: int) +``` + +Higher value = executed first. + +### @custom_filter + +```python +@custom_filter(CustomFilter(callable)) +``` + +`callable` must be a sync function: `(event: MessageEvent) -> bool`. + +Composition: `all_of(*filters)`, `any_of(*filters)`. 
+ +--- + +## Client APIs + +### ctx.db — KV Store (DBClient) + +| Method | Signature | Returns | Notes | +|--------|-----------|---------|-------| +| get | `(key: str)` | `Any \| None` | None if key missing | +| set | `(key: str, value: Any)` | `None` | value must be JSON-serializable | +| delete | `(key: str)` | `None` | **Always None** — does not return bool | +| list | `(prefix: str \| None = None)` | `list[str]` | Key names only; empty list if no matches | +| get_many | `(keys: Sequence[str])` | `dict[str, Any \| None]` | Missing keys → None | +| set_many | `(items: Mapping \| Sequence[tuple])` | `None` | Accepts dict or list of tuples | +| watch | `(prefix: str \| None = None)` | `AsyncIterator[dict]` | Yields `{"op": "set"\|"delete", "key": str, "value": Any\|None}` | + +### ctx.llm — AI Chat (LLMClient) + +| Method | Signature | Returns | Notes | +|--------|-----------|---------|-------| +| chat | `(prompt, *, system=None, history=None, contexts=None, provider_id=None, model=None, temperature=None, **kw)` | `str` | Text only | +| chat_raw | same as chat | `LLMResponse` | Has `.text`, `.usage`, `.finish_reason`, `.tool_calls` | +| stream_chat | same as chat | `AsyncGenerator[str]` | Yields text chunks | + +`history`: `Sequence[ChatMessage \| dict]` — chat history for context. +`contexts`: takes precedence over `history` if both provided. 
+ +### ctx.memory — Semantic Storage (MemoryClient) + +| Method | Signature | Returns | Notes | +|--------|-----------|---------|-------| +| save | `(key, value=None, namespace=None, **extra)` | `None` | **value must be dict** (TypeError otherwise) | +| get | `(key, *, namespace=None)` | `dict \| None` | — | +| delete | `(key, *, namespace=None)` | `None` | — | +| search | `(query, *, mode="auto", limit=None, min_score=None, provider_id=None, namespace=None, include_descendants=True)` | `list[dict]` | Items have key, score, match_type | +| save_with_ttl | `(key, value, ttl_seconds, *, namespace=None)` | `None` | value must be dict; ttl_seconds >= 1 | +| get_many | `(keys, *, namespace=None)` | `list[dict]` | — | +| delete_many | `(keys, *, namespace=None)` | `int` | Count of deleted items | +| stats | `(*, namespace=None, include_descendants=True)` | `dict` | — | +| namespace | `(*parts)` | `MemoryClient` | **Not async**; returns derived client in child namespace | + +### ctx.platform — Messaging (PlatformClient) + +| Method | Signature | Returns | Notes | +|--------|-----------|---------|-------| +| send | `(session, text)` | `dict` | session: str / SessionRef / MessageSession | +| send_image | `(session, image_url)` | `dict` | — | +| send_chain | `(session, chain)` | `dict` | chain: MessageChain / list[component] / list[dict] | +| send_by_session | `(session, content)` | `dict` | content: str / MessageChain / list | +| send_by_id | `(platform_id, session_id, content, *, message_type="private")` | `dict` | — | +| get_members | `(session)` | `list[dict]` | Items may have user_id, nickname, role | + +### ctx.metadata — Plugin Metadata (MetadataClient) + +| Method | Signature | Returns | Notes | +|--------|-----------|---------|-------| +| get_plugin_config | `(name=None)` | `dict \| None` | **PermissionError** if accessing another plugin; None = current | +| save_plugin_config | `(config: dict)` | `dict` | TypeError if not dict | +| get_plugin | `(name: str)` | 
`StarMetadata \| None` | — | +| list_plugins | `()` | `list[StarMetadata]` | — | +| get_current_plugin | `()` | `StarMetadata \| None` | Current plugin's metadata | + +### ctx.files — File Service (FileServiceClient) + +File token registration and management. + +### ctx.http — HTTP (HTTPClient) + +HTTP API registration and listing. + +### ctx.mcp — MCP Manager (MCPManagerClient) + +MCP server lifecycle management. + +### ctx.providers — Provider Query (ProviderClient) + +Provider metadata queries and specialized provider proxy. + +### ctx.message_history — Message History (MessageHistoryManagerClient) + +Message history queries with pagination. + +--- + +## Core Types + +### MessageEvent + +| Property | Type | Notes | +|----------|------|-------| +| text | `str` | Message content | +| user_id | `str` | Sender ID | +| group_id | `str \| None` | None for private | +| platform | `str` | e.g., "qq", "wechat" | +| platform_id | `str` | Platform instance ID | +| self_id | `str` | Bot's own ID | +| session_id | `str` | For reply routing | +| message_type | `str` | "group" / "private" | +| sender_name | `str` | Display name | +| is_admin | `bool` | — | +| raw | `dict` | Raw message data | + +| Method | Returns | Notes | +|--------|---------|-------| +| `await reply(text)` | `dict` | Plain text reply | +| `await reply_image(url)` | `dict` | Image reply | +| `await reply_chain(chain)` | `dict` | Rich message reply | +| `stop_event()` | `None` | Prevent further handler processing | + +### MessageBuilder (fluent API) + +```python +chain = ( + MessageBuilder() + .text("Hello ") + .at("12345") + .text(" check this: ") + .image("https://example.com/img.png") + .build() +) +await event.reply_chain(chain.components) +``` + +Methods: `.text(str)`, `.at(user_id)`, `.at_all()`, `.image(url)`, `.record(url)`, `.video(url)`, `.file(name, *, file="", url="")`, `.reply(**kw)`, `.append(component)`, `.extend(components)`, `.build()` → `MessageChain` + +### Message Components + +| Class | 
Constructor | Serialized type | +|-------|-----------|----------------| +| `Plain(text)` | `Plain("hello")` | `"text"` (not "plain") | +| `Image(file)` | `Image.fromURL(url)` / `Image.fromFileSystem(path)` / `Image.fromBase64(data)` | `"image"` | +| `At(qq)` | `At(qq="12345")` | `"at"` | +| `AtAll()` | `AtAll()` | `"at"` with qq="all" | +| `Record(file)` | `Record.fromURL(url)` / `Record.fromFileSystem(path)` | `"record"` | +| `Video(file)` | `Video.fromURL(url)` / `Video.fromFileSystem(path)` | `"video"` | +| `File(name, file, url)` | `File("doc.pdf", url="https://...")` | `"file"` | +| `Reply(**kw)` | `Reply(id="msg_id")` | `"reply"` | +| `Forward(id)` | `Forward(id="msg_id")` | `"forward"` | +| `Poke(poke_type)` | `Poke()` | `"poke"` | + +### Other Key Types + +- **GreedyStr**: Annotate the last command parameter to capture all remaining text. +- **ScheduleContext**: Injected into `@on_schedule` handlers. Has `schedule_id`, `plugin_id`, `handler_id`, `trigger_kind` ("cron"/"interval"/"once"), `cron`, `interval_seconds`. +- **ConversationSession**: Injected into `@conversation_command` handlers. Key methods: `ask()`, `reply()`, `reply_chain()`, `send_message()`, `end()`. States: ACTIVE, REJECTED_BUSY, REPLACED, TIMEOUT, COMPLETED, CANCELLED. +- **ChatMessage(role, content)**: For LLM history. +- **LLMResponse**: Has `.text`, `.usage`, `.finish_reason`, `.tool_calls`, `.role`, `.reasoning_content`. +- **CommandGroup**: For hierarchical command trees. Create with `command_group(name)`, add subgroups with `.group()`, add commands with `.command()`. 
+ +--- + +## Source basis + +Derived from: +- `src/astrbot_sdk/decorators.py` +- `src/astrbot_sdk/clients/*.py` +- `src/astrbot_sdk/events.py` +- `src/astrbot_sdk/message/components.py` +- `src/astrbot_sdk/message/result.py` +- `src/astrbot_sdk/conversation.py` +- `src/astrbot_sdk/context.py` diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/plugin-patterns.md b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/plugin-patterns.md new file mode 100644 index 0000000000..e794313d71 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/plugin-patterns.md @@ -0,0 +1,505 @@ +# Plugin Patterns + +Complete, working plugin examples. Each pattern includes `main.py`, `plugin.yaml`, and a test snippet. + +--- + +## Pattern 1: Simple Command Plugin + +A basic command with aliases and typed parameters. + +**main.py:** +```python +from astrbot_sdk import Context, MessageEvent, Star +from astrbot_sdk.decorators import on_command + + +class GreetPlugin(Star): + @on_command("greet", aliases=["hi", "hello"], description="Greet the user") + async def greet(self, event: MessageEvent, ctx: Context) -> None: + await event.reply(f"Hello, {event.sender_name}!") +``` + +**plugin.yaml:** +```yaml +name: astrbot_plugin_greet +display_name: Greet +desc: Simple greeting plugin +author: dev +version: 1.0.0 +runtime: + python: "3.12" +components: + - class: main:GreetPlugin +``` + +**test:** +```python +@pytest.mark.asyncio +async def test_greet(): + async with PluginHarness.from_plugin_dir(plugin_dir) as h: + records = await h.dispatch_text("greet") + assert any("Hello" in r.text for r in records) +``` + +--- + +## Pattern 2: CRUD KV Store Plugin + +Full create/read/update/delete with `ctx.db`. 
+
+**main.py:**
+```python
+from astrbot_sdk import Context, GreedyStr, MessageEvent, Star
+from astrbot_sdk.decorators import on_command
+
+
+class NotesPlugin(Star):
+    @on_command("note-save", description="Save a note: note-save <key> <content>")
+    async def save(self, event: MessageEvent, ctx: Context, key: str, content: GreedyStr) -> None:
+        await ctx.db.set(f"notes:{key}", {"content": str(content).strip()})
+        await event.reply(f"Saved note '{key}'.")
+
+    @on_command("note-get", description="Read a note: note-get <key>")
+    async def get(self, event: MessageEvent, ctx: Context, key: str) -> None:
+        note = await ctx.db.get(f"notes:{key}")
+        if not isinstance(note, dict) or not note.get("content"):
+            await event.reply(f"No note found for '{key}'.")
+            return
+        await event.reply(f"{key}: {note['content']}")
+
+    @on_command("note-delete", description="Delete a note: note-delete <key>")
+    async def delete(self, event: MessageEvent, ctx: Context, key: str) -> None:
+        # IMPORTANT: ctx.db.delete() returns None, not bool.
+        # Check existence first if you need to inform the user.
+ existing = await ctx.db.get(f"notes:{key}") + if not existing: + await event.reply(f"No note found for '{key}'.") + return + await ctx.db.delete(f"notes:{key}") + await event.reply(f"Deleted note '{key}'.") + + @on_command("note-list", description="List all notes") + async def list_notes(self, event: MessageEvent, ctx: Context) -> None: + keys = await ctx.db.list("notes:") + if not keys: + await event.reply("No notes saved.") + return + names = [k.removeprefix("notes:") for k in keys] + await event.reply("Notes: " + ", ".join(names)) +``` + +**test:** +```python +@pytest.mark.asyncio +async def test_note_crud(): + async with PluginHarness.from_plugin_dir(plugin_dir) as h: + await h.dispatch_text("note-save todo buy milk") + records = await h.dispatch_text("note-get todo") + assert any(r.text == "todo: buy milk" for r in records) +``` + +--- + +## Pattern 3: Keyword and Regex Message Handler + +React to keywords and regex patterns with platform/group filters. + +**main.py:** +```python +from astrbot_sdk import Context, MessageEvent, Star +from astrbot_sdk.decorators import on_message, group_only, platforms + + +class KeywordPlugin(Star): + @on_message(keywords=["help", "帮助"]) + async def help_handler(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("Available commands: /greet, /note-save, /note-get") + + @on_message(regex=r"^\d{4}-\d{2}-\d{2}$") + async def date_handler(self, event: MessageEvent, ctx: Context) -> None: + await event.reply(f"Detected date: {event.text}") + + @on_message(keywords=["notify"]) + @group_only() + @platforms("qq") + async def qq_group_only(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("QQ group notification received!") +``` + +--- + +## Pattern 4: Scheduled Task Plugin + +Periodic tasks using cron and interval. 
+ +**main.py:** +```python +import asyncio +from astrbot_sdk import Context, Star +from astrbot_sdk.decorators import on_schedule +from astrbot_sdk.schedule import ScheduleContext + + +class ScheduledPlugin(Star): + @on_schedule(cron="0 8 * * *", description="Daily morning greeting") + async def morning(self, ctx: Context) -> None: + # No event available — use ctx.platform.send() for proactive messages + await ctx.platform.send("target-session-id", "Good morning!") + + @on_schedule(interval_seconds=3600, description="Hourly health check") + async def health_check(self, ctx: Context, schedule: ScheduleContext) -> None: + ctx.logger.info(f"Health check #{schedule.schedule_id}, trigger: {schedule.trigger_kind}") +``` + +**Note:** Scheduled handlers have no `event` parameter. You cannot call `event.reply()`. + +--- + +## Pattern 5: Multi-Turn Conversation Plugin + +Interactive dialogue with `ConversationSession`. + +**main.py:** +```python +from astrbot_sdk import Context, ConversationSession, MessageEvent, Star +from astrbot_sdk.decorators import conversation_command + + +class SurveyPlugin(Star): + @conversation_command( + "survey", + description="Interactive survey", + timeout=120, + mode="reject", + busy_message="A survey is already in progress. Please complete it first.", + ) + async def survey(self, event: MessageEvent, ctx: Context, session: ConversationSession) -> None: + await session.reply("Welcome to the survey! What is your name?") + + name_event = await session.ask("Please enter your name:") + name = name_event.text.strip() + + rating_event = await session.ask(f"Hi {name}! Rate us 1-5:") + rating = rating_event.text.strip() + + await ctx.db.set(f"survey:{event.user_id}", {"name": name, "rating": rating}) + await session.reply(f"Thanks {name}! Your rating of {rating} has been saved.") + session.end() +``` + +**Key points:** +- `session.ask()` sends a prompt and waits for the user's next message. +- `session.reply()` sends a message without waiting. 
+- `session.end()` marks the session complete. +- Handle `TimeoutError` if the user doesn't respond within `timeout`. + +--- + +## Pattern 6: LLM-Powered Plugin + +Use AI for intelligent responses. + +**main.py:** +```python +from astrbot_sdk import Context, GreedyStr, MessageEvent, Star +from astrbot_sdk.decorators import on_command +from astrbot_sdk.clients import ChatMessage + + +class AIPlugin(Star): + @on_command("ask", description="Ask the AI a question") + async def ask_ai(self, event: MessageEvent, ctx: Context, question: GreedyStr) -> None: + answer = await ctx.llm.chat( + str(question), + system="You are a helpful assistant. Be concise.", + ) + await event.reply(answer) + + @on_command("chat", description="Chat with history") + async def chat_with_history(self, event: MessageEvent, ctx: Context, message: GreedyStr) -> None: + # Load history from DB + history_data = await ctx.db.get(f"chat_history:{event.user_id}") or {"messages": []} + history = [ChatMessage(**m) for m in history_data["messages"][-10:]] + + response = await ctx.llm.chat(str(message), history=history) + + # Save updated history + history_data["messages"].append({"role": "user", "content": str(message)}) + history_data["messages"].append({"role": "assistant", "content": response}) + await ctx.db.set(f"chat_history:{event.user_id}", history_data) + + await event.reply(response) + + @on_command("stream-ask", description="Stream AI response") + async def stream_ask(self, event: MessageEvent, ctx: Context, question: GreedyStr) -> None: + chunks = [] + async for chunk in ctx.llm.stream_chat(str(question)): + chunks.append(chunk) + await event.reply("".join(chunks)) +``` + +--- + +## Pattern 7: Capability Provider Plugin + +Expose capabilities for other plugins to call. 
+ +**main.py:** +```python +from pydantic import BaseModel +from astrbot_sdk import Context, Star +from astrbot_sdk.decorators import provide_capability + + +class CalcInput(BaseModel): + x: float + y: float + op: str = "add" + + +class CalcOutput(BaseModel): + result: float + + +class CalcPlugin(Star): + @provide_capability( + "calc.compute", + description="Perform arithmetic", + input_model=CalcInput, + output_model=CalcOutput, + ) + async def compute(self, payload: dict, ctx: Context) -> dict: + data = CalcInput.model_validate(payload) + ops = {"add": data.x + data.y, "sub": data.x - data.y, "mul": data.x * data.y} + result = ops.get(data.op, data.x + data.y) + return {"result": result} +``` + +**test:** +```python +@pytest.mark.asyncio +async def test_capability(): + async with PluginHarness.from_plugin_dir(plugin_dir) as h: + result = await h.invoke_capability("calc.compute", {"x": 3, "y": 4, "op": "mul"}) + assert result["result"] == 12.0 +``` + +--- + +## Pattern 8: HTTP API Plugin + +Expose REST endpoints. + +**main.py:** +```python +from astrbot_sdk import Context, Star +from astrbot_sdk.decorators import http_api + + +class WebhookPlugin(Star): + @http_api("/api/status", methods=["GET"], description="Health check") + async def status(self, payload: dict, ctx: Context) -> dict: + return {"status": "ok", "plugin": ctx.plugin_id} + + @http_api("/api/notify", methods=["POST"], description="Send notification") + async def notify(self, payload: dict, ctx: Context) -> dict: + target = payload.get("session_id", "") + message = payload.get("message", "") + if not target or not message: + return {"error": "session_id and message required"} + await ctx.platform.send(target, message) + return {"sent": True} +``` + +--- + +## Pattern 9: Rich Messages Plugin + +Build rich messages with components. 
+ +**main.py:** +```python +from astrbot_sdk import ( + Context, MessageBuilder, MessageEvent, Plain, At, Image, Star, +) +from astrbot_sdk.decorators import on_command + + +class RichPlugin(Star): + @on_command("welcome", description="Welcome with rich message") + async def welcome(self, event: MessageEvent, ctx: Context) -> None: + chain = ( + MessageBuilder() + .text("Welcome ") + .at(event.user_id) + .text("!\nHere's a guide image:") + .image("https://example.com/guide.png") + .build() + ) + await event.reply_chain(chain.components) + + @on_command("card", description="Send info card") + async def card(self, event: MessageEvent, ctx: Context) -> None: + # Manual component list + components = [ + Plain("📋 User Info\n"), + Plain(f"Name: {event.sender_name}\n"), + Plain(f"ID: {event.user_id}\n"), + Plain(f"Platform: {event.platform}"), + ] + await event.reply_chain(components) +``` + +--- + +## Pattern 10: Lifecycle Hooks with Config Validation + +Initialize resources on start, clean up on stop. 
+ +**main.py:** +```python +from pydantic import BaseModel +from astrbot_sdk import Context, MessageEvent, Star +from astrbot_sdk.decorators import on_command, validate_config + + +class PluginConfig(BaseModel): + api_key: str + max_retries: int = 3 + timeout: float = 30.0 + + +class ConfigPlugin(Star): + def __init__(self) -> None: + super().__init__() + self._api_key: str = "" + self._max_retries: int = 3 + + async def on_start(self, ctx) -> None: + await super().on_start(ctx) + config = await ctx.metadata.get_plugin_config() + if config: + validated = PluginConfig.model_validate(config) + self._api_key = validated.api_key + self._max_retries = validated.max_retries + + async def on_stop(self, ctx) -> None: + # Clean up resources here + await super().on_stop(ctx) + + @on_command("status", description="Show config status") + async def status(self, event: MessageEvent, ctx: Context) -> None: + masked = self._api_key[:4] + "****" if self._api_key else "not set" + await event.reply(f"API key: {masked}, retries: {self._max_retries}") +``` + +**Key rules:** +- Always call `await super().on_start(ctx)` and `await super().on_stop(ctx)`. +- Do not store `ctx` on `self` — use it only within the call. +- Store extracted config values on `self`, not the ctx object. + +--- + +## Pattern 11: Command Groups + +Hierarchical command organization. 
+
+**main.py:**
+```python
+from astrbot_sdk import Context, GreedyStr, MessageEvent, Star
+from astrbot_sdk.commands import command_group
+from astrbot_sdk.decorators import require_admin
+
+admin = command_group("admin", description="Admin commands")
+user_grp = admin.group("user", description="User management")
+
+
+class AdminPlugin(Star):
+    @user_grp.command("add", description="Add a user")
+    @require_admin
+    async def user_add(self, event: MessageEvent, ctx: Context, username: str) -> None:
+        await ctx.db.set(f"users:{username}", {"added_by": event.user_id})
+        await event.reply(f"User '{username}' added.")
+
+    @user_grp.command("remove", description="Remove a user")
+    @require_admin
+    async def user_remove(self, event: MessageEvent, ctx: Context, username: str) -> None:
+        existing = await ctx.db.get(f"users:{username}")
+        if not existing:
+            await event.reply(f"User '{username}' not found.")
+            return
+        await ctx.db.delete(f"users:{username}")
+        await event.reply(f"User '{username}' removed.")
+
+    @admin.command("help", description="Show admin help")
+    async def admin_help(self, event: MessageEvent, ctx: Context) -> None:
+        await event.reply("Admin commands: admin user add <username>, admin user remove <username>")
+```
+
+---
+
+## Pattern 12: Background Task
+
+Long-running background loop.
+ +**main.py:** +```python +import asyncio +from astrbot_sdk import Context, Star +from astrbot_sdk.decorators import background_task, on_command + + +class MonitorPlugin(Star): + def __init__(self) -> None: + super().__init__() + self._check_count: int = 0 + + @background_task(description="Periodic monitor", auto_start=True, on_error="restart") + async def monitor(self, ctx: Context) -> None: + while True: + self._check_count += 1 + ctx.logger.info(f"Monitor check #{self._check_count}") + # Perform monitoring logic here + await asyncio.sleep(60) + + @on_command("monitor-status", description="Show monitor status") + async def status(self, event, ctx: Context) -> None: + await event.reply(f"Monitor has run {self._check_count} checks.") +``` + +--- + +## Pattern 13: Rate-Limited Admin Command + +Stack multiple decorators for access control and throttling. + +**main.py:** +```python +from astrbot_sdk import Context, GreedyStr, MessageEvent, Star +from astrbot_sdk.decorators import on_command, require_admin, rate_limit, group_only + + +class ModerationPlugin(Star): + @on_command("announce", description="Send group announcement") + @require_admin + @group_only() + @rate_limit(3, 300.0, scope="group", behavior="hint", message="Max 3 announcements per 5 minutes.") + async def announce(self, event: MessageEvent, ctx: Context, text: GreedyStr) -> None: + announcement = f"📢 Announcement from {event.sender_name}:\n{text}" + await event.reply(announcement) +``` + +**Decorator stacking order:** trigger first (topmost), then guards, then throttle. 
+ +--- + +## Source basis + +Derived from: +- `src/astrbot_sdk/decorators.py` +- `src/astrbot_sdk/clients/*.py` +- `src/astrbot_sdk/conversation.py` +- `src/astrbot_sdk/commands.py` +- `src/astrbot_sdk/message/result.py` +- `forward-tests/astrbot-plugin-dev/` diff --git a/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/project-structure.md b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/project-structure.md new file mode 100644 index 0000000000..c95d7ae575 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/templates/skills/astrbot-plugin-dev/references/project-structure.md @@ -0,0 +1,197 @@ +# Project Structure & Testing + +## plugin.yaml Schema + +```yaml +# Required fields +name: astrbot_plugin_my_plugin # Must start with astrbot_plugin_ +display_name: My Plugin # Human-readable name +desc: What the plugin does # Short description +author: Your Name # Author name +version: 1.0.0 # Semver +runtime: + python: "3.12" # Python version +components: + - class: main:MyPluginClass # module:ClassName + +# Optional fields +astrbot_version: "0.11.0" # Minimum AstrBot version +support_platforms: # Platform compatibility + - qq + - wechat +reserved: false # true only for core plugins +``` + +Notes: +- `name` must start with `astrbot_plugin_`. The `init` command adds this prefix automatically. +- `components[].class` uses `module:ClassName` format. The module is relative to the plugin directory. +- Do not invent parallel entrypoints unless the task requires multiple components. + +## File Layout + +``` +astrbot_plugin_my_plugin/ +├── plugin.yaml # Manifest (required) +├── main.py # Plugin class (required) +├── requirements.txt # Dependencies (optional) +└── tests/ + └── test_plugin.py # Tests (recommended) +``` + +## CLI Commands + +All commands support two entrypoints. 
If `astrbot-sdk` is not on PATH, use `python -m astrbot_sdk`:
+
+| Command | Purpose |
+|---------|---------|
+| `astrbot-sdk init <name>` | Scaffold new plugin |
+| `astrbot-sdk validate --plugin-dir <dir>` | Validate structure, imports, handler discovery |
+| `astrbot-sdk dev --local --plugin-dir <dir> --event-text "..."` | Single-shot local test |
+| `astrbot-sdk dev --local --plugin-dir <dir> --interactive` | Interactive local test |
+| `astrbot-sdk dev --local --plugin-dir <dir> --watch --event-text "..."` | Watch mode with auto-reload |
+| `astrbot-sdk build --plugin-dir <dir>` | Package into distributable zip |
+
+### init
+
+- Run from the **parent** directory where the plugin folder should be created.
+- Normalizes name: `quick-notes` → `astrbot_plugin_quick_notes/`.
+- Do not run inside an existing plugin directory.
+- Replace scaffold code with actual behavior; do not keep dead defaults.
+
+### validate
+
+Reports: handler count, capability count, component instances.
+
+### dev --local
+
+- Uses SDK's local mock core (no real AstrBot instance needed).
+- `--watch` has known reload pitfalls; prefer fresh runs for subtle behavior changes.
+
+### build
+
+Produces `dist/<name>-<version>.zip`.
+
+## Testing
+
+### Black-box test with PluginHarness (preferred)
+
+```python
+from pathlib import Path
+
+import pytest
+
+from astrbot_sdk.testing import PluginHarness
+
+
+@pytest.mark.asyncio
+async def test_hello():
+    plugin_dir = Path(__file__).resolve().parents[1]
+
+    async with PluginHarness.from_plugin_dir(plugin_dir) as h:
+        records = await h.dispatch_text("hello")
+
+        assert any(r.text == "Hello!"
for r in records) +``` + +### Custom session/user/platform + +```python +async with PluginHarness.from_plugin_dir( + plugin_dir, + session_id="test-session", + user_id="user-42", + platform="qq", + group_id="group-1", +) as h: + records = await h.dispatch_text("hello") +``` + +### Override per dispatch + +```python +records = await h.dispatch_text("hello", user_id="other-user", group_id="other-group") +``` + +### Testing capabilities + +```python +result = await h.invoke_capability("my_plugin.compute", {"x": 5, "y": 3}) +assert result["result"] == 8 +``` + +### Accessing sent messages + +```python +# All messages sent during harness lifetime +all_messages = h.sent_messages + +# Clear between test steps +h.clear_sent_messages() +``` + +### RecordedSend properties + +Each item in the `records` list has: +- `.text` — plain text content of the reply + +### Lifecycle + +`PluginHarness` as async context manager automatically calls `on_start()` on enter and `on_stop()` on exit. + +## Testing Pitfalls + +### NEVER use `from main import ...` + +```python +# BAD — pollutes sys.modules["main"] +from main import MyPlugin + +# GOOD — use PluginHarness +async with PluginHarness.from_plugin_dir(plugin_dir) as h: + records = await h.dispatch_text("hello") +``` + +### Ignore cached files + +When copying plugin fixtures, exclude: +- `__pycache__/` +- `*.pyc` +- `*.pyo` + +### dispatch_text behavior + +- Returns `list[RecordedSend]` — the messages the plugin sent in response. +- If no handler matches, behavior depends on the event type. + +### Watch mode caveats + +- Reload correctness depends on loader cache cleanup. +- Prefer fresh `dev --local` runs over `--watch` for subtle behavior changes. + +## Validation Loop + +Run after every meaningful change: + +```bash +# 1. Validate structure +astrbot-sdk validate --plugin-dir + +# 2. Smoke test +astrbot-sdk dev --local --plugin-dir --event-text "" + +# 3. Run tests +python -m pytest tests -q + +# 4. 
Package (if needed) +astrbot-sdk build --plugin-dir +``` + +If any step fails, fix before proceeding. + +## Source basis + +Derived from: +- `src/astrbot_sdk/cli.py` +- `src/astrbot_sdk/testing.py` +- `README.md` +- `docs/08_testing_guide.md` diff --git a/astrbot-sdk/src/astrbot_sdk/testing.py b/astrbot-sdk/src/astrbot_sdk/testing.py new file mode 100644 index 0000000000..41193b883f --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/testing.py @@ -0,0 +1,780 @@ +"""本地开发与插件测试辅助。 + +`astrbot_sdk.testing` 是面向插件作者的稳定开发入口: + +- `PluginHarness` 负责复用现有 loader / dispatcher 执行链 +- `MockCapabilityRouter` 提供进程内 mock core 能力 +- `MockPeer` 让 `Context` 客户端继续走真实的 capability 调用路径 +- `StdoutPlatformSink` / `RecordedSend` 提供可观测的发送记录 + +这个模块刻意不暴露 runtime 内部编排数据结构,只封装本地开发/测试真正 +需要的最小稳定面。 +""" + +from __future__ import annotations + +import asyncio +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from ._internal.decorator_lifecycle import run_lifecycle_with_decorators +from ._internal.testing_support import ( + InMemoryDB, + InMemoryMemory, + MockCapabilityRouter, + MockContext, + MockLLMClient, + MockMessageEvent, + MockPeer, + MockPlatformClient, + RecordedSend, + StdoutPlatformSink, +) +from ._message_types import normalize_message_type +from .context import CancelToken +from .context import Context as RuntimeContext +from .errors import AstrBotError +from .events import MessageEvent +from .protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + PlatformFilterSpec, + ScheduleTrigger, +) +from .protocol.messages import InvokeMessage +from .runtime._command_matching import ( + build_command_args, + build_regex_args, + match_command_name, +) +from .runtime._streaming import StreamExecution +from .runtime.handler_dispatcher import CapabilityDispatcher, HandlerDispatcher +from .runtime.loader import ( + LoadedHandler, + LoadedPlugin, + 
PluginSpec, + load_plugin, + load_plugin_config, + load_plugin_spec, + validate_plugin_spec, +) +from .star import Star + + +class _PluginLoadError(RuntimeError): + """本地 harness 初始化阶段的已知插件加载失败。""" + + +class _PluginExecutionError(RuntimeError): + """本地 harness 执行插件代码时的已知插件异常。""" + + +def _plugin_metadata_from_spec( + plugin: PluginSpec, + *, + enabled: bool, +) -> dict[str, Any]: + manifest = plugin.manifest_data + support_platforms = manifest.get("support_platforms") + return { + "name": plugin.name, + "display_name": str(manifest.get("display_name") or plugin.name), + "description": str(manifest.get("desc") or manifest.get("description") or ""), + "author": str(manifest.get("author") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "enabled": enabled, + "reserved": bool(manifest.get("reserved", False)), + "support_platforms": [ + str(item) for item in support_platforms if isinstance(item, str) + ] + if isinstance(support_platforms, list) + else [], + "astrbot_version": ( + str(manifest.get("astrbot_version")) + if manifest.get("astrbot_version") is not None + else None + ), + } + + +def _handler_metadata_from_loaded( + plugin_id: str, loaded: LoadedHandler +) -> dict[str, Any]: + event_types: list[str] = [] + trigger = loaded.descriptor.trigger + if isinstance(trigger, EventTrigger): + event_types.append(trigger.type) + return { + "plugin_name": plugin_id, + "handler_full_name": loaded.descriptor.id, + "trigger_type": trigger.type + if isinstance(trigger, EventTrigger) + else str(getattr(trigger, "kind", trigger.type)), + "event_types": event_types, + "enabled": True, + "group_path": list( + loaded.descriptor.command_route.group_path + if loaded.descriptor.command_route is not None + else [] + ), + "require_admin": loaded.descriptor.permissions.require_admin, + "required_role": loaded.descriptor.permissions.required_role, + } + + +@dataclass(slots=True) +class LocalRuntimeConfig: + """本地 harness 的稳定配置对象。""" + + plugin_dir: Path + session_id: str = 
"local-session" + user_id: str = "local-user" + platform: str = "test" + group_id: str | None = None + event_type: str = "message" + + +@dataclass(slots=True) +class MockClock: + now: float = 0.0 + + def time(self) -> float: + return self.now + + def advance(self, seconds: float) -> float: + self.now += float(seconds) + return self.now + + +@dataclass(slots=True) +class SDKTestEnvironment: + root: Path + + @property + def plugins_dir(self) -> Path: + path = self.root / "plugins" + path.mkdir(parents=True, exist_ok=True) + return path + + def plugin_dir(self, name: str) -> Path: + path = self.plugins_dir / name + path.mkdir(parents=True, exist_ok=True) + return path + + +class PluginHarness: + """本地插件消息泵。 + + 这里复用真实的 loader / dispatcher 执行链,只负责: + - 在同一个事件循环里装配单插件运行时 + - 维持本地 mock core 与发送记录 + - 把后续消息持续送入同一个 dispatcher + """ + + def __init__( + self, + config: LocalRuntimeConfig, + *, + platform_sink: StdoutPlatformSink | None = None, + ) -> None: + self.config = config + self.platform_sink = platform_sink or StdoutPlatformSink() + self.router = MockCapabilityRouter(platform_sink=self.platform_sink) + self.peer = MockPeer(self.router) + self.plugin: PluginSpec | None = None + self.loaded_plugin: LoadedPlugin | None = None + self.dispatcher: HandlerDispatcher | None = None + self.capability_dispatcher: CapabilityDispatcher | None = None + self.lifecycle_context: RuntimeContext | None = None + self._request_counter = 0 + self._started = False + + @classmethod + def from_plugin_dir( + cls, + plugin_dir: str | Path, + *, + session_id: str = "local-session", + user_id: str = "local-user", + platform: str = "test", + group_id: str | None = None, + event_type: str = "message", + platform_sink: StdoutPlatformSink | None = None, + ) -> PluginHarness: + return cls( + LocalRuntimeConfig( + plugin_dir=Path(plugin_dir), + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + ), + platform_sink=platform_sink, + ) + + async 
def __aenter__(self) -> PluginHarness: + await self.start() + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.stop() + + @property + def sent_messages(self) -> list[RecordedSend]: + return list(self.platform_sink.records) + + def clear_sent_messages(self) -> None: + self.platform_sink.clear() + + async def start(self) -> None: + if self._started: + return + try: + self.plugin = load_plugin_spec(self.config.plugin_dir) + validate_plugin_spec(self.plugin) + self.loaded_plugin = load_plugin(self.plugin) + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginLoadError(str(exc)) from exc + self.dispatcher = HandlerDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + handlers=self.loaded_plugin.handlers, + ) + self.capability_dispatcher = CapabilityDispatcher( + plugin_id=self.plugin.name, + peer=self.peer, + capabilities=self.loaded_plugin.capabilities, + llm_tools=self.loaded_plugin.llm_tools, + ) + self.lifecycle_context = RuntimeContext( + peer=self.peer, + plugin_id=self.plugin.name, + ) + plugin_metadata = _plugin_metadata_from_spec(self.plugin, enabled=True) + plugin_metadata["acknowledge_global_mcp_risk"] = any( + bool( + getattr( + instance.__class__, + "__astrbot_acknowledge_global_mcp_risk__", + False, + ) + ) + for instance in self.loaded_plugin.instances + ) + self.router.upsert_plugin( + metadata=plugin_metadata, + config=load_plugin_config(self.plugin), + ) + self.router.set_plugin_handlers( + self.plugin.name, + [ + _handler_metadata_from_loaded(self.plugin.name, handler) + for handler in self.loaded_plugin.handlers + ], + ) + self.router.set_plugin_llm_tools( + self.plugin.name, + [tool.spec.to_payload() for tool in self.loaded_plugin.llm_tools], + ) + self.router.set_plugin_agents( + self.plugin.name, + [agent.spec.to_payload() for agent in self.loaded_plugin.agents], + ) + try: + await self._run_lifecycle("on_start") + except AstrBotError: + raise + except Exception as exc: # pragma: no 
cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + self._started = True + + async def stop(self) -> None: + if ( + not self._started + or self.loaded_plugin is None + or self.lifecycle_context is None + ): + return + try: + await self._run_lifecycle("on_stop") + finally: + if self.plugin is not None: + self.router.set_plugin_enabled(self.plugin.name, False) + self.router.set_plugin_handlers(self.plugin.name, []) + self.router.remove_dynamic_command_routes_for_plugin(self.plugin.name) + self.router.remove_http_apis_for_plugin(self.plugin.name) + self._started = False + + async def dispatch_text( + self, + text: str, + *, + session_id: str | None = None, + user_id: str | None = None, + platform: str | None = None, + group_id: str | None = None, + event_type: str | None = None, + request_id: str | None = None, + ) -> list[RecordedSend]: + payload = self.build_event_payload( + text=text, + session_id=session_id, + user_id=user_id, + platform=platform, + group_id=group_id, + event_type=event_type, + request_id=request_id, + ) + return await self.dispatch_event(payload, request_id=request_id) + + async def dispatch_event( + self, + event_payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> list[RecordedSend]: + await self.start() + assert self.loaded_plugin is not None + assert self.dispatcher is not None + + start_index = len(self.platform_sink.records) + if self._has_waiter_for_event(event_payload): + await self._invoke_session_waiter( + event_payload, + request_id=request_id, + ) + await self._wait_for_followup_side_effects( + start_index=start_index, + event_payload=event_payload, + ) + return self.platform_sink.records[start_index:] + + matches = self._match_handlers(event_payload) + if not matches: + raise AstrBotError.invalid_input("未找到匹配的 handler") + for loaded, args in matches: + await self._invoke_handler( + loaded, + event_payload, + args=args, + request_id=request_id, + ) + return self.platform_sink.records[start_index:] 
+ + async def invoke_capability( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + stream: bool = False, + ) -> dict[str, Any] | StreamExecution: + await self.start() + assert self.capability_dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("cap"), + capability=capability, + input=dict(payload), + stream=stream, + ) + try: + return await self.capability_dispatcher.invoke(message, CancelToken()) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + def build_event_payload( + self, + *, + text: str, + session_id: str | None = None, + user_id: str | None = None, + platform: str | None = None, + group_id: str | None = None, + event_type: str | None = None, + request_id: str | None = None, + ) -> dict[str, Any]: + session_value = session_id or self.config.session_id + group_value = group_id if group_id is not None else self.config.group_id + event_type_value = event_type or self.config.event_type + payload = { + "type": event_type_value, + "event_type": event_type_value, + "text": text, + "session_id": session_value, + "user_id": user_id or self.config.user_id, + "platform": platform or self.config.platform, + "platform_id": platform or self.config.platform, + "group_id": group_value, + "self_id": f"{platform or self.config.platform}-bot", + "sender_name": str(user_id or self.config.user_id or ""), + "is_admin": False, + "raw": { + "trace_id": request_id or self._next_request_id("trace"), + "event_type": event_type_value, + }, + } + if group_value: + payload["message_type"] = "group" + elif payload["user_id"]: + payload["message_type"] = "private" + else: + payload["message_type"] = "other" + return payload + + async def _invoke_handler( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + *, + args: dict[str, Any], + request_id: str | None = None, + ) -> None: + assert 
self.dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("msg"), + capability="handler.invoke", + input={ + "handler_id": loaded.descriptor.id, + "event": dict(event_payload), + "args": dict(args), + }, + ) + try: + await self.dispatcher.invoke(message, CancelToken()) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + async def _invoke_session_waiter( + self, + event_payload: dict[str, Any], + *, + request_id: str | None = None, + ) -> None: + assert self.dispatcher is not None + message = InvokeMessage( + id=request_id or self._next_request_id("msg"), + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(event_payload), + "args": {}, + }, + ) + try: + await self.dispatcher.invoke(message, CancelToken()) + except AstrBotError: + raise + except Exception as exc: # pragma: no cover - 由 CLI/集成测试覆盖 + raise _PluginExecutionError(str(exc)) from exc + + async def _wait_for_followup_side_effects( + self, + *, + start_index: int, + event_payload: dict[str, Any], + ) -> None: + settled_rounds = 0 + for _ in range(20): + if len(self.platform_sink.records) > start_index: + return + await asyncio.sleep(0) + if self._has_waiter_for_event(event_payload): + settled_rounds = 0 + continue + settled_rounds += 1 + if settled_rounds >= 3: + return + + async def _run_lifecycle(self, method_name: str) -> None: + assert self.loaded_plugin is not None + assert self.lifecycle_context is not None + + for instance in self.loaded_plugin.instances: + hook = self._resolve_lifecycle_hook(instance, method_name) + await run_lifecycle_with_decorators( + instance=instance, + hook=hook, + method_name=method_name, + context=self.lifecycle_context, + ) + + def _match_handlers( + self, + event_payload: dict[str, Any], + ) -> list[tuple[LoadedHandler, dict[str, Any]]]: + assert self.loaded_plugin is not None + ranked: 
list[tuple[int, int, LoadedHandler, dict[str, Any]]] = [] + for index, loaded in enumerate(self.loaded_plugin.handlers): + args = self._match_handler(loaded, event_payload) + if args is None: + continue + ranked.append((loaded.descriptor.priority, index, loaded, args)) + for dynamic in self._match_dynamic_handlers(event_payload): + ranked.append(dynamic) + ranked.sort(key=lambda item: (-item[0], item[1])) + return [(loaded, args) for _priority, _index, loaded, args in ranked] + + def _match_dynamic_handlers( + self, + event_payload: dict[str, Any], + ) -> list[tuple[int, int, LoadedHandler, dict[str, Any]]]: + assert self.loaded_plugin is not None + assert self.plugin is not None + ranked: list[tuple[int, int, LoadedHandler, dict[str, Any]]] = [] + routes = self.router.list_dynamic_command_routes(self.plugin.name) + handler_map = { + loaded.descriptor.id: loaded for loaded in self.loaded_plugin.handlers + } + base_order = len(self.loaded_plugin.handlers) + for index, route in enumerate(routes): + if not isinstance(route, dict): + continue + handler_full_name = str(route.get("handler_full_name", "")).strip() + loaded = handler_map.get(handler_full_name) + if loaded is None: + continue + args = self._match_dynamic_route(loaded, route, event_payload) + if args is None: + continue + priority = route.get("priority", loaded.descriptor.priority) + if not isinstance(priority, int) or isinstance(priority, bool): + priority = loaded.descriptor.priority + ranked.append((priority, base_order + index, loaded, args)) + return ranked + + def _match_dynamic_route( + self, + loaded: LoadedHandler, + route: dict[str, Any], + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + command_name = str(route.get("command_name", "")).strip() + if not command_name: + return None + text = str(event_payload.get("text", "")) + if bool(route.get("use_regex", False)): + match = re.search(command_name, text) + if match 
is None: + return None + return build_regex_args(loaded.descriptor.param_specs, match) + remainder = match_command_name(text, command_name) + if remainder is None: + return None + return build_command_args(loaded.descriptor.param_specs, remainder) + + def _match_handler( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_permissions(loaded, event_payload): + return None + trigger = loaded.descriptor.trigger + if isinstance(trigger, CommandTrigger): + return self._match_command_trigger(loaded, trigger, event_payload) + if isinstance(trigger, MessageTrigger): + return self._match_message_trigger(loaded, trigger, event_payload) + if isinstance(trigger, EventTrigger): + current_type = str( + event_payload.get("event_type") + or event_payload.get("type") + or "message" + ) + if current_type != trigger.event_type: + return None + return {} + if isinstance(trigger, ScheduleTrigger): + if ( + str(event_payload.get("event_type") or event_payload.get("type")) + == "schedule" + ): + return {} + return None + return None + + def _match_command_trigger( + self, + loaded: LoadedHandler, + trigger: CommandTrigger, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + text = str(event_payload.get("text", "")).strip() + for command_name in [trigger.command, *trigger.aliases]: + if not command_name: + continue + match = match_command_name(text, command_name) + if match is None: + continue + return build_command_args(loaded.descriptor.param_specs, match) + return None + + def _match_message_trigger( + self, + loaded: LoadedHandler, + trigger: MessageTrigger, + event_payload: dict[str, Any], + ) -> dict[str, Any] | None: + if not self._passes_filters(loaded, event_payload): + return None + text = str(event_payload.get("text", "")) + if trigger.regex: + match = re.search(trigger.regex, text) + if match is None: + return None + return 
build_regex_args(loaded.descriptor.param_specs, match) + if trigger.keywords and not any( + keyword in text for keyword in trigger.keywords + ): + return None + return {} + + @staticmethod + def _passes_permissions( + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> bool: + permissions = loaded.descriptor.permissions + required_role = permissions.required_role + if required_role is None and permissions.require_admin: + required_role = "admin" + if required_role == "admin": + return bool(event_payload.get("is_admin", False)) + return True + + def _passes_filters( + self, + loaded: LoadedHandler, + event_payload: dict[str, Any], + ) -> bool: + for filter_spec in loaded.descriptor.filters: + if isinstance(filter_spec, PlatformFilterSpec): + if str(event_payload.get("platform", "")) not in filter_spec.platforms: + return False + elif isinstance(filter_spec, MessageTypeFilterSpec): + if ( + self._message_type_name(event_payload) + not in filter_spec.message_types + ): + return False + elif isinstance(filter_spec, CompositeFilterSpec): + if not self._passes_composite_filter(filter_spec, event_payload): + return False + elif isinstance(filter_spec, LocalFilterRefSpec): + continue + return True + + def _passes_composite_filter( + self, + filter_spec: CompositeFilterSpec, + event_payload: dict[str, Any], + ) -> bool: + results: list[bool] = [] + for child in filter_spec.children: + if isinstance(child, PlatformFilterSpec): + results.append( + str(event_payload.get("platform", "")) in child.platforms + ) + elif isinstance(child, MessageTypeFilterSpec): + results.append( + self._message_type_name(event_payload) in child.message_types + ) + elif isinstance(child, LocalFilterRefSpec): + results.append(True) + elif isinstance(child, CompositeFilterSpec): + results.append(self._passes_composite_filter(child, event_payload)) + if filter_spec.kind == "and": + return all(results) + return any(results) + + def _has_waiter_for_event(self, event_payload: dict[str, Any]) 
-> bool: + assert self.dispatcher is not None + probe_event = MessageEvent.from_payload( + event_payload, + context=self.lifecycle_context, + ) + public_probe = getattr(self.dispatcher, "has_active_waiter", None) + if callable(public_probe): + return bool(public_probe(probe_event)) + session_waiters = getattr(self.dispatcher, "_session_waiters", None) + if session_waiters is None: + return False + if hasattr(session_waiters, "has_waiter"): + return session_waiters.has_waiter(probe_event) + if isinstance(session_waiters, dict): + return any( + manager.has_waiter(probe_event) + for manager in session_waiters.values() + if hasattr(manager, "has_waiter") + ) + return False + + @staticmethod + def _message_type_name(event_payload: dict[str, Any]) -> str: + return normalize_message_type( + event_payload.get("message_type", ""), + group_id=str(event_payload.get("group_id", "")).strip() or None, + user_id=str(event_payload.get("user_id", "")).strip() or None, + empty_default="other", + ) + + @staticmethod + def _resolve_lifecycle_hook(instance: Any, method_name: str): + hook = getattr(instance, method_name, None) + marker = getattr(instance.__class__, "__astrbot_is_new_star__", None) + is_new_star = True + if callable(marker): + is_new_star = bool(marker()) + + if hook is not None and callable(hook): + bound_func = getattr(hook, "__func__", hook) + star_default = getattr(Star, method_name, None) + if star_default is None or bound_func is not star_default: + return hook + + if not is_new_star: + alias = {"on_start": "initialize", "on_stop": "terminate"}.get(method_name) + if alias is not None: + legacy_hook = getattr(instance, alias, None) + if legacy_hook is not None and callable(legacy_hook): + return legacy_hook + + if hook is not None and callable(hook): + return hook + return None + + def _next_request_id(self, prefix: str) -> str: + self._request_counter += 1 + return f"{prefix}_{self._request_counter:04d}" + + +__all__ = [ + "InMemoryDB", + "InMemoryMemory", + 
"LocalRuntimeConfig", + "MockClock", + "MockCapabilityRouter", + "MockContext", + "MockLLMClient", + "MockMessageEvent", + "MockPeer", + "MockPlatformClient", + "SDKTestEnvironment", + "PluginHarness", + "RecordedSend", + "StdoutPlatformSink", +] diff --git a/astrbot-sdk/src/astrbot_sdk/types.py b/astrbot-sdk/src/astrbot_sdk/types.py new file mode 100644 index 0000000000..c2bc911ec7 --- /dev/null +++ b/astrbot-sdk/src/astrbot_sdk/types.py @@ -0,0 +1,22 @@ +"""SDK parameter helper types. + +本模块提供 SDK 参数类型助手,用于增强命令参数解析能力。 + +GreedyStr: +用于标记"贪婪字符串"参数,在命令解析时将剩余所有文本作为一个整体参数。 +例如:/echo hello world this is a test +如果最后一个参数类型为 GreedyStr,将获取 "hello world this is a test" 而非仅 "hello" + +使用方式: +在 handler 签名中将最后一个参数标注为 GreedyStr 类型, +_loader_support 会识别此类型并调整参数解析逻辑。 +""" + +from __future__ import annotations + + +class GreedyStr(str): + """Consume the remaining command text as one argument.""" + + +__all__ = ["GreedyStr"] diff --git a/astrbot-sdk/tests/__init__.py b/astrbot-sdk/tests/__init__.py new file mode 100644 index 0000000000..c9045bcfd9 --- /dev/null +++ b/astrbot-sdk/tests/__init__.py @@ -0,0 +1,6 @@ +"""Stabilize cross-test imports under pytest collection. + +The MCP runtime tests reuse helpers from sibling test modules, so `tests` +needs to behave as an explicit package instead of depending on environment- +specific namespace-package discovery. 
+""" diff --git a/astrbot-sdk/tests/conftest.py b/astrbot-sdk/tests/conftest.py new file mode 100644 index 0000000000..5741589657 --- /dev/null +++ b/astrbot-sdk/tests/conftest.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + +while str(SRC) in sys.path: + sys.path.remove(str(SRC)) + +sys.path.insert(0, str(SRC)) diff --git a/astrbot-sdk/tests/test_cli_init.py b/astrbot-sdk/tests/test_cli_init.py new file mode 100644 index 0000000000..c58732cccd --- /dev/null +++ b/astrbot-sdk/tests/test_cli_init.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import zipfile +from pathlib import Path + +from click.testing import CliRunner + +from astrbot_sdk.cli import cli + + +def test_init_normalizes_plugin_name_and_adds_prefix() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke(cli, ["init", "demo-plugin"]) + + assert result.exit_code == 0 + plugin_dir = Path("astrbot_plugin_demo_plugin") + assert plugin_dir.exists() + manifest = (plugin_dir / "plugin.yaml").read_text(encoding="utf-8") + assert "name: astrbot_plugin_demo_plugin" in manifest + assert "display_name: demo-plugin" in manifest + assert "version: 1.0.0" in manifest + assert not (plugin_dir / ".claude").exists() + assert not (plugin_dir / ".agents").exists() + assert not (plugin_dir / ".opencode").exists() + + +def test_init_interactive_prompts_and_sanitizes_name() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke( + cli, + ["init"], + input="\nMy Plugin,Name;Beta\nAlice\nExample plugin\n\n", + ) + + assert result.exit_code == 0 + assert "该字段不能为空,请重新输入。" in result.output + plugin_dir = Path("astrbot_plugin_my_plugin_name_beta") + assert plugin_dir.exists() + manifest = (plugin_dir / "plugin.yaml").read_text(encoding="utf-8") + assert "name: astrbot_plugin_my_plugin_name_beta" in manifest + assert 
"display_name: My Plugin,Name;Beta" in manifest + assert "author: Alice" in manifest + assert "desc: Example plugin" in manifest + assert "version: 1.0.0" in manifest + + +def test_init_generates_claude_agent_directory() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke(cli, ["init", "demo-plugin", "--agents", "claude"]) + + assert result.exit_code == 0 + plugin_dir = Path("astrbot_plugin_demo_plugin") + claude_file = ( + plugin_dir / ".claude" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ) + assert claude_file.exists() + content = claude_file.read_text(encoding="utf-8") + assert "astrbot_plugin_demo_plugin" in content + assert "name: astrbot-plugin-dev" in content + assert "Plugin root: `../../..`" in content + assert ( + plugin_dir + / ".claude" + / "skills" + / "astrbot-plugin-dev" + / "references" + / "api-quick-ref.md" + ).exists() + assert not (plugin_dir / ".agents").exists() + assert not (plugin_dir / ".opencode").exists() + + +def test_init_generates_multiple_agent_directories() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke( + cli, + ["init", "demo-plugin", "--agents", "claude,codex"], + ) + + assert result.exit_code == 0 + plugin_dir = Path("astrbot_plugin_demo_plugin") + assert ( + plugin_dir / ".claude" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ).exists() + assert ( + plugin_dir / ".agents" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ).exists() + codex_meta = ( + plugin_dir + / ".agents" + / "skills" + / "astrbot-plugin-dev" + / "agents" + / "openai.yaml" + ).read_text(encoding="utf-8") + assert "AstrBot Plugin Dev (Codex)" in codex_meta + assert not (plugin_dir / ".opencode").exists() + + +def test_init_deduplicates_agents_case_insensitively() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke( + cli, + ["init", "demo-plugin", "--agents", "Claude,codex,CLAUDE"], + ) + + assert result.exit_code == 0 
+ plugin_dir = Path("astrbot_plugin_demo_plugin") + assert ( + plugin_dir / ".claude" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ).exists() + assert ( + plugin_dir / ".agents" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ).exists() + assert not (plugin_dir / ".opencode").exists() + + +def test_init_rejects_invalid_agents() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke( + cli, + ["init", "demo-plugin", "--agents", "claude,unknown"], + ) + + assert result.exit_code == 2 + assert "仅支持以下 agent" in result.output + assert "unknown" in result.output + assert not Path("astrbot_plugin_demo_plugin").exists() + + +def test_init_generates_opencode_agent_directory() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + result = runner.invoke(cli, ["init", "demo-plugin", "--agents", "opencode"]) + + assert result.exit_code == 0 + plugin_dir = Path("astrbot_plugin_demo_plugin") + opencode_file = ( + plugin_dir / ".opencode" / "skills" / "astrbot-plugin-dev" / "SKILL.md" + ) + assert opencode_file.exists() + content = opencode_file.read_text(encoding="utf-8") + assert "astrbot_plugin_demo_plugin" in content + assert "Plugin root: `../../..`" in content + + +def test_build_excludes_generated_agent_skill_directories() -> None: + runner = CliRunner() + + with runner.isolated_filesystem(): + init_result = runner.invoke( + cli, + ["init", "demo-plugin", "--agents", "claude,codex,opencode"], + ) + assert init_result.exit_code == 0 + + plugin_dir = Path("astrbot_plugin_demo_plugin") + build_result = runner.invoke( + cli, + ["build", "--plugin-dir", str(plugin_dir)], + ) + + assert build_result.exit_code == 0 + archive_path = plugin_dir / "dist" / "astrbot_plugin_demo_plugin-1.0.0.zip" + assert archive_path.exists() + + with zipfile.ZipFile(archive_path) as archive: + names = archive.namelist() + + assert "plugin.yaml" in names + assert "main.py" in names + assert all(not name.startswith(".claude/") for name in 
names) + assert all(not name.startswith(".agents/") for name in names) + assert all(not name.startswith(".opencode/") for name in names) diff --git a/astrbot-sdk/tests/test_client_regressions.py b/astrbot-sdk/tests/test_client_regressions.py new file mode 100644 index 0000000000..45372c3a6c --- /dev/null +++ b/astrbot-sdk/tests/test_client_regressions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot_sdk.clients._proxy import CapabilityProxy +from astrbot_sdk.clients.memory import MemoryClient +from astrbot_sdk.clients.metadata import MetadataClient + + +class _FakeProxy: + def __init__(self, responses: dict[str, dict[str, Any]] | None = None) -> None: + self.responses = responses or {} + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + self.calls.append((name, dict(payload))) + return dict(self.responses.get(name, {})) + + +@pytest.mark.asyncio +async def test_memory_get_many_skips_non_dict_items() -> None: + proxy = _FakeProxy( + { + "memory.get_many": { + "items": [ + {"key": "pref1", "value": {"theme": "dark"}}, + ["unexpected"], + None, + {"key": "pref2", "value": None}, + ] + } + } + ) + client = MemoryClient(proxy) # type: ignore[arg-type] + + items = await client.get_many(["pref1", "pref2"]) + + assert items == [ + {"key": "pref1", "value": {"theme": "dark"}}, + {"key": "pref2", "value": None}, + ] + + +@pytest.mark.asyncio +async def test_capability_proxy_ignores_magicmock_placeholder_attributes() -> None: + peer = MagicMock() + peer.invoke = AsyncMock(return_value={}) + proxy = CapabilityProxy(peer) + + result = await proxy.call("metadata.get_plugin", {"name": "demo"}) + + assert result == {} + peer.invoke.assert_awaited_once_with( + "metadata.get_plugin", + {"name": "demo"}, + stream=False, + ) + + +@pytest.mark.asyncio +async def 
test_metadata_client_rejects_cross_plugin_config_access() -> None: + proxy = _FakeProxy( + { + "metadata.get_plugin_config": { + "config": {"api_key": "hidden"}, + } + } + ) + client = MetadataClient(proxy, plugin_id="current-plugin") + + with pytest.raises(PermissionError, match="只允许访问当前插件自己的配置"): + await client.get_plugin_config("other-plugin") + + assert proxy.calls == [] diff --git a/astrbot-sdk/tests/test_command_matching.py b/astrbot-sdk/tests/test_command_matching.py new file mode 100644 index 0000000000..13dd1eb809 --- /dev/null +++ b/astrbot-sdk/tests/test_command_matching.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from astrbot_sdk.protocol.descriptors import ParamSpec +from astrbot_sdk.runtime._command_matching import ( + build_command_args, + build_regex_args, + match_command_name, + split_command_remainder, +) + + +def test_match_command_name_trims_input_consistently() -> None: + assert match_command_name(" ping ", "ping") == "" + assert match_command_name(" ping hello world ", "ping") == "hello world" + assert match_command_name("pingpong", "ping") is None + + +def test_build_command_args_supports_quotes_and_greedy_tail() -> None: + param_specs = [ + ParamSpec(name="name", type="str"), + ParamSpec(name="message", type="greedy_str"), + ] + + args = build_command_args(param_specs, '"alpha beta" "hello world" tail') + + assert args == {"name": "alpha beta", "message": "hello world tail"} + + +def test_split_command_remainder_falls_back_on_invalid_quotes() -> None: + assert split_command_remainder('"unterminated quote test') == [ + '"unterminated', + "quote", + "test", + ] + + +def test_build_regex_args_preserves_named_and_positional_mapping() -> None: + param_specs = [ + ParamSpec(name="first", type="str"), + ParamSpec(name="second", type="str"), + ParamSpec(name="third", type="str"), + ] + match = 
re.search(r"(?P\w+)-(\w+)-(\w+)", "named-positional-tail") + + assert match is not None + assert build_regex_args(param_specs, match) == { + "second": "named", + "first": "named", + "third": "positional", + } diff --git a/astrbot-sdk/tests/test_command_model_parsing.py b/astrbot-sdk/tests/test_command_model_parsing.py new file mode 100644 index 0000000000..3e06dc8495 --- /dev/null +++ b/astrbot-sdk/tests/test_command_model_parsing.py @@ -0,0 +1,200 @@ +# ruff: noqa: I001 +""" +针对 command_model 解析逻辑的边界场景单元测试。 +覆盖:--help 生成、位置参数超限、重复 flag、连字符映射。 +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from pydantic import BaseModel + +from astrbot_sdk._internal.command_model import ( + ResolvedCommandModelParam, + format_command_model_help, + parse_command_model_remainder, + resolve_command_model_param, +) +from astrbot_sdk.errors import AstrBotError + + +# ── 测试用模型 ──────────────────────────────────────────────────── + + +class SimpleModel(BaseModel): + name: str + count: int = 1 + verbose: bool = False + + +class HyphenModel(BaseModel): + output_dir: str + max_retries: int = 3 + + +# ── 辅助:构建 ResolvedCommandModelParam ────────────────────────── + + +def _param(model_cls: type[BaseModel]) -> ResolvedCommandModelParam: + return ResolvedCommandModelParam(name="args", model_cls=model_cls) + + +# ── --help 测试 ─────────────────────────────────────────────────── + + +def test_help_flag_short() -> None: + result = parse_command_model_remainder( + remainder="-h", model_param=_param(SimpleModel), command_name="test" + ) + assert result.model is None + assert result.help_text is not None + assert "test" in result.help_text + + +def test_help_flag_long() -> None: + result = parse_command_model_remainder( + remainder="--help", model_param=_param(SimpleModel), command_name="greet" + ) + assert result.model is None + assert result.help_text is not None + assert 
"name" in result.help_text + assert "count" in result.help_text + assert "verbose" in result.help_text + + +def test_format_command_model_help_contains_bool_hint() -> None: + text = format_command_model_help("myCmd", SimpleModel) + assert "--verbose" in text + assert "--no-verbose" in text + + +# ── 位置参数超限 ────────────────────────────────────────────────── + + +def test_too_many_positional_args_raises() -> None: + with pytest.raises(AstrBotError) as exc_info: + parse_command_model_remainder( + remainder="alice 10 extra", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert "Too many positional arguments" in str(exc_info.value) + + +def test_exactly_right_positional_args_succeeds() -> None: + result = parse_command_model_remainder( + remainder="alice 5", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert result.model is not None + assert result.model.name == "alice" # type: ignore[attr-defined] + assert result.model.count == 5 # type: ignore[attr-defined] + + +# ── 重复 flag ───────────────────────────────────────────────────── + + +def test_duplicate_named_flag_raises() -> None: + with pytest.raises(AstrBotError) as exc_info: + parse_command_model_remainder( + remainder="--name alice --name bob", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert "Duplicate option" in str(exc_info.value) + + +def test_duplicate_bool_flag_raises() -> None: + with pytest.raises(AstrBotError) as exc_info: + parse_command_model_remainder( + remainder="--verbose --verbose", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert "Duplicate option" in str(exc_info.value) + + +# ── 连字符映射下划线 ─────────────────────────────────────────────── + + +def test_hyphen_flag_maps_to_underscore_field() -> None: + result = parse_command_model_remainder( + remainder="--output-dir /tmp --max-retries 5", + model_param=_param(HyphenModel), + command_name="build", + ) + assert result.model is not None + assert result.model.output_dir == 
"/tmp" # type: ignore[attr-defined] + assert result.model.max_retries == 5 # type: ignore[attr-defined] + + +def test_underscore_flag_still_works() -> None: + """直接使用下划线形式也应正常解析(向后兼容)。""" + result = parse_command_model_remainder( + remainder="--output_dir /out", + model_param=_param(HyphenModel), + command_name="build", + ) + assert result.model is not None + assert result.model.output_dir == "/out" # type: ignore[attr-defined] + + +# ── bool 标志 --no- 前缀 ────────────────────────────────────────── + + +def test_bool_negation_flag() -> None: + result = parse_command_model_remainder( + remainder="alice --no-verbose", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert result.model is not None + assert result.model.verbose is False # type: ignore[attr-defined] + + +def test_bool_positive_flag() -> None: + result = parse_command_model_remainder( + remainder="alice --verbose", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert result.model is not None + assert result.model.verbose is True # type: ignore[attr-defined] + + +# ── 未知字段 ────────────────────────────────────────────────────── + + +def test_unknown_flag_raises() -> None: + with pytest.raises(AstrBotError) as exc_info: + parse_command_model_remainder( + remainder="--nonexistent foo", + model_param=_param(SimpleModel), + command_name="cmd", + ) + assert "Unknown option" in str(exc_info.value) + + +# ── resolve_command_model_param ─────────────────────────────────── + + +def test_resolve_finds_model_param() -> None: + def handler(event: object, args: SimpleModel) -> None: ... + + resolved = resolve_command_model_param(handler) + assert resolved is not None + assert resolved.model_cls is SimpleModel + + +def test_resolve_returns_none_for_plain_handler() -> None: + def handler(event: object, name: str) -> None: ... 
+ + resolved = resolve_command_model_param(handler) + assert resolved is None diff --git a/astrbot-sdk/tests/test_context_llm_tool_registration.py b/astrbot-sdk/tests/test_context_llm_tool_registration.py new file mode 100644 index 0000000000..7a697a8511 --- /dev/null +++ b/astrbot-sdk/tests/test_context_llm_tool_registration.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from astrbot_sdk._internal.testing_support import MockContext +from astrbot_sdk.llm.entities import LLMToolSpec + + +class RecordingDispatcher: + def __init__(self) -> None: + self.added: list[dict[str, Any]] = [] + self.removed: list[tuple[str, str]] = [] + + def add_dynamic_llm_tool( + self, + *, + plugin_id: str, + spec: LLMToolSpec, + callable_obj, + owner: Any | None = None, + ) -> None: + self.added.append( + { + "plugin_id": plugin_id, + "spec": spec, + "callable_obj": callable_obj, + "owner": owner, + } + ) + + def remove_llm_tool(self, plugin_id: str, name: str) -> bool: + self.removed.append((plugin_id, name)) + return True + + +@pytest.mark.asyncio +async def test_register_llm_tool_keeps_manager_and_dispatcher_specs_aligned() -> None: + ctx = MockContext() + dispatcher = RecordingDispatcher() + ctx.peer._sdk_capability_dispatcher = dispatcher + + async def echo_tool(text: str) -> str: + return text + + names = await ctx.register_llm_tool( + "echo", + {"type": "object", "properties": {"text": {"type": "string"}}}, + "Echo the provided text", + echo_tool, + active=False, + ) + + assert names == ["echo"] + registered = await ctx.get_llm_tool_manager().get("echo") + assert registered is not None + assert registered.name == "echo" + assert registered.description == "Echo the provided text" + assert registered.parameters_schema == { + "type": "object", + "properties": {"text": {"type": "string"}}, + } + assert registered.handler_ref == "__dynamic_llm_tool__:echo" + assert registered.active is False + + assert len(dispatcher.added) == 1 + 
added = dispatcher.added[0] + assert added["plugin_id"] == "test-plugin" + assert added["callable_obj"] is echo_tool + assert added["owner"] is None + assert added["spec"].model_dump() == registered.model_dump() + + removed = await ctx.unregister_llm_tool("echo") + assert removed is True + assert dispatcher.removed == [("test-plugin", "echo")] + assert await ctx.get_llm_tool_manager().get("echo") is None diff --git a/astrbot-sdk/tests/test_context_register_task.py b/astrbot-sdk/tests/test_context_register_task.py new file mode 100644 index 0000000000..3bb53da0ed --- /dev/null +++ b/astrbot-sdk/tests/test_context_register_task.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from astrbot_sdk._internal.testing_support import MockContext + + +class RecordingLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[str, str, str]] = [] + self.exception_calls: list[tuple[str, str, str]] = [] + + def debug(self, message: str, plugin_id: str, desc: str) -> None: + self.debug_calls.append((message, plugin_id, desc)) + + def exception(self, message: str, plugin_id: str, desc: str) -> None: + self.exception_calls.append((message, plugin_id, desc)) + + +@pytest.mark.asyncio +async def test_register_task_accepts_coroutine() -> None: + ctx = MockContext() + + async def background() -> str: + await asyncio.sleep(0) + return "done" + + task = await ctx.register_task(background(), "coroutine") + + assert isinstance(task, asyncio.Task) + assert await task == "done" + + +@pytest.mark.asyncio +async def test_register_task_wraps_future_inputs() -> None: + ctx = MockContext() + loop = asyncio.get_running_loop() + future: asyncio.Future[str] = loop.create_future() + + task = await ctx.register_task(future, "future") + future.set_result("done") + + assert isinstance(task, asyncio.Task) + assert task is not future + assert await task == "done" + + +@pytest.mark.asyncio +async def test_register_task_logs_cancel_once() -> None: + 
logger = RecordingLogger() + ctx = MockContext(logger=logger) + started = asyncio.Event() + + async def background() -> None: + started.set() + await asyncio.Future() + + task = await ctx.register_task(background(), "cancelled") + await started.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert logger.debug_calls == [ + ( + "SDK background task cancelled: plugin_id={} desc={}", + "test-plugin", + "cancelled", + ) + ] + assert logger.exception_calls == [] + + +@pytest.mark.asyncio +async def test_register_task_logs_failures() -> None: + logger = RecordingLogger() + ctx = MockContext(logger=logger) + + async def background() -> None: + raise RuntimeError("boom") + + task = await ctx.register_task(background(), "failing") + + with pytest.raises(RuntimeError, match="boom"): + await task + + assert logger.debug_calls == [] + assert logger.exception_calls == [ + ( + "SDK background task failed: plugin_id={} desc={}", + "test-plugin", + "failing", + ) + ] diff --git a/astrbot-sdk/tests/test_db_runtime.py b/astrbot-sdk/tests/test_db_runtime.py new file mode 100644 index 0000000000..3576aba954 --- /dev/null +++ b/astrbot-sdk/tests/test_db_runtime.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from astrbot_sdk._internal.invocation_context import caller_plugin_scope +from astrbot_sdk.runtime.capability_router import CapabilityRouter + + +async def _call( + router: CapabilityRouter, + capability: str, + payload: dict[str, object], + *, + plugin_id: str = "test-plugin", +) -> dict[str, object]: + with caller_plugin_scope(plugin_id): + result = await router.execute( + capability, + payload, + stream=False, + cancel_token=object(), + request_id=f"{plugin_id}:{capability}", + ) + assert isinstance(result, dict) + return result + + +async def _stream( + router: CapabilityRouter, + capability: str, + payload: dict[str, object], + *, + plugin_id: str = "test-plugin", +): + with 
caller_plugin_scope(plugin_id): + result = await router.execute( + capability, + payload, + stream=True, + cancel_token=object(), + request_id=f"{plugin_id}:{capability}:stream", + ) + return result + + +@pytest.mark.asyncio +async def test_db_watch_returns_plugin_local_key_view( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + stream = await _stream(router, "db.watch", {"prefix": None}, plugin_id="plugin-a") + await _call( + router, + "db.set", + {"key": "user:1", "value": {"name": "Alice"}}, + plugin_id="plugin-a", + ) + + event = await anext(stream.iterator) + + assert event == { + "op": "set", + "key": "user:1", + "value": {"name": "Alice"}, + } + + +@pytest.mark.asyncio +async def test_db_watch_prefix_filters_within_plugin_scope( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + stream = await _stream( + router, "db.watch", {"prefix": "user:"}, plugin_id="plugin-a" + ) + await _call(router, "db.set", {"key": "config:1", "value": 1}, plugin_id="plugin-a") + await _call(router, "db.set", {"key": "user:2", "value": 2}, plugin_id="plugin-b") + await _call(router, "db.set", {"key": "user:1", "value": 3}, plugin_id="plugin-a") + + event = await anext(stream.iterator) + + assert event == {"op": "set", "key": "user:1", "value": 3} diff --git a/astrbot-sdk/tests/test_decorator_runtime_lifecycle.py b/astrbot-sdk/tests/test_decorator_runtime_lifecycle.py new file mode 100644 index 0000000000..a2a86651f5 --- /dev/null +++ b/astrbot-sdk/tests/test_decorator_runtime_lifecycle.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from textwrap import dedent + +import pytest + +from astrbot_sdk.testing import PluginHarness + + +async def _wait_until(predicate, *, timeout: float = 0.2) -> None: + deadline = asyncio.get_running_loop().time() + timeout + while 
asyncio.get_running_loop().time() < deadline: + if predicate(): + return + await asyncio.sleep(0) + raise AssertionError("timed out waiting for condition") + + +def _write_plugin( + plugin_dir: Path, + *, + name: str, + class_name: str, + source: str, + reserved: bool = False, +) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "plugin.yaml").write_text( + dedent( + f""" + _schema_version: 2 + name: {name} + author: tests + version: 1.0.0 + desc: decorator runtime tests + reserved: {"true" if reserved else "false"} + + runtime: + python: "3.12" + + components: + - class: main:{class_name} + """ + ).strip() + + "\n", + encoding="utf-8", + ) + (plugin_dir / "requirements.txt").write_text("", encoding="utf-8") + (plugin_dir / "main.py").write_text(dedent(source).lstrip(), encoding="utf-8") + + +@pytest.mark.asyncio +async def test_http_api_decorator_registers_and_unregisters_route( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "http_api_plugin" + _write_plugin( + plugin_dir, + name="http_api_plugin", + class_name="HttpApiPlugin", + source=""" + from astrbot_sdk import Star, http_api, provide_capability + + + class HttpApiPlugin(Star): + @http_api(route="/decorated", methods=["GET", "POST"], description="Decorated API") + @provide_capability("http_api_plugin.handle_http", description="Handle decorated HTTP route") + async def handle_http(self, request_id: str, payload: dict, cancel_token): + return {"status": 200, "body": {"request_id": request_id, "payload": payload}} + """, + ) + + harness = PluginHarness.from_plugin_dir(plugin_dir) + await harness.start() + try: + assert harness.lifecycle_context is not None + apis = await harness.lifecycle_context.http.list_apis() + assert apis == [ + { + "route": "/decorated", + "methods": ["GET", "POST"], + "handler_capability": "http_api_plugin.handle_http", + "description": "Decorated API", + "plugin_id": "http_api_plugin", + } + ] + finally: + await harness.stop() + + assert 
harness.router.http_api_store == [] + + +@pytest.mark.asyncio +async def test_validate_config_decorator_rejects_invalid_config( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "validate_config_plugin" + _write_plugin( + plugin_dir, + name="validate_config_plugin", + class_name="ValidateConfigPlugin", + source=""" + from pydantic import BaseModel + + from astrbot_sdk import Context, Star, validate_config + + + class PluginConfig(BaseModel): + api_key: str + + + class ValidateConfigPlugin(Star): + @validate_config(model=PluginConfig) + async def on_start(self, ctx: Context) -> None: + del ctx + """, + ) + + harness = PluginHarness.from_plugin_dir(plugin_dir) + with pytest.raises(Exception, match="api_key"): + await harness.start() + + +@pytest.mark.asyncio +async def test_on_provider_change_decorator_registers_and_unsubscribes( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "provider_change_plugin" + _write_plugin( + plugin_dir, + name="provider_change_plugin", + class_name="ProviderChangePlugin", + reserved=True, + source=""" + from astrbot_sdk import Star, on_provider_change + + + class ProviderChangePlugin(Star): + def __init__(self) -> None: + self.events = [] + + @on_provider_change(provider_types=["embedding"]) + async def handle_change(self, provider_id: str, provider_type, umo: str | None) -> None: + self.events.append((provider_id, getattr(provider_type, "value", str(provider_type)), umo)) + """, + ) + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + assert harness.loaded_plugin is not None + plugin = harness.loaded_plugin.instances[0] + await _wait_until( + lambda: len(harness.router._provider_change_subscriptions) == 1 + ) + harness.router.emit_provider_change("embed-a", "embedding", "session:1") + harness.router.emit_provider_change("rerank-a", "rerank", "session:2") + await asyncio.sleep(0.05) + + assert plugin.events == [("embed-a", "embedding", "session:1")] + assert harness.router._provider_change_subscriptions + + 
assert not harness.router._provider_change_subscriptions + + +@pytest.mark.asyncio +async def test_background_task_decorator_auto_starts_and_cancels( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "background_task_plugin" + _write_plugin( + plugin_dir, + name="background_task_plugin", + class_name="BackgroundTaskPlugin", + source=""" + import asyncio + + from astrbot_sdk import Context, Star, background_task + + + class BackgroundTaskPlugin(Star): + def __init__(self) -> None: + self.started = asyncio.Event() + self.cancelled = False + + @background_task(description="decorated background task") + async def sync_data(self, ctx: Context) -> None: + del ctx + self.started.set() + try: + await asyncio.Future() + except asyncio.CancelledError: + self.cancelled = True + raise + """, + ) + + harness = PluginHarness.from_plugin_dir(plugin_dir) + await harness.start() + assert harness.loaded_plugin is not None + plugin = harness.loaded_plugin.instances[0] + await asyncio.wait_for(plugin.started.wait(), timeout=0.2) + await harness.stop() + + assert plugin.cancelled is True + + +@pytest.mark.asyncio +async def test_register_skill_decorator_registers_and_unregisters( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "skill_plugin" + _write_plugin( + plugin_dir, + name="skill_plugin", + class_name="SkillPlugin", + source=""" + from astrbot_sdk import Star, register_skill + + + @register_skill(name="demo_skill", path="skills/demo.py", description="Demo skill") + class SkillPlugin(Star): + pass + """, + ) + + harness = PluginHarness.from_plugin_dir(plugin_dir) + await harness.start() + try: + assert harness.lifecycle_context is not None + skills = await harness.lifecycle_context.skills.list() + assert len(skills) == 1 + assert skills[0].name == "demo_skill" + assert skills[0].path == "skills/demo.py" + assert skills[0].description == "Demo skill" + finally: + await harness.stop() + + plugin = harness.router._plugins["skill_plugin"] + assert plugin.skills == {} + + 
+@pytest.mark.asyncio +async def test_mcp_server_decorator_registers_global_server_with_ack( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "mcp_server_plugin" + _write_plugin( + plugin_dir, + name="mcp_server_plugin", + class_name="MCPServerPlugin", + source=""" + from astrbot_sdk import Star, acknowledge_global_mcp_risk, mcp_server + + + @acknowledge_global_mcp_risk + @mcp_server( + name="decorated-global", + scope="global", + config={"mock_tools": ["inspect"]}, + timeout=0.1, + ) + class MCPServerPlugin(Star): + pass + """, + ) + + harness = PluginHarness.from_plugin_dir(plugin_dir) + await harness.start() + try: + assert harness.lifecycle_context is not None + servers = await harness.lifecycle_context.mcp.list_global_servers() + assert [server.name for server in servers] == ["decorated-global"] + finally: + await harness.stop() + + assert harness.router._mcp_global_servers == {} diff --git a/astrbot-sdk/tests/test_decorators_filter_guards.py b/astrbot-sdk/tests/test_decorators_filter_guards.py new file mode 100644 index 0000000000..8379d86e5f --- /dev/null +++ b/astrbot-sdk/tests/test_decorators_filter_guards.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import pytest + +from astrbot_sdk.decorators import ( + append_filter_meta, + get_handler_meta, + message_types, + platforms, + require_admin, + require_permission, +) +from astrbot_sdk.protocol.descriptors import ( + MessageTypeFilterSpec, + Permissions, + PlatformFilterSpec, +) + + +def test_platforms_rejects_existing_manual_platform_filter() -> None: + def handler() -> None: + return None + + append_filter_meta( + handler, + specs=[PlatformFilterSpec(platforms=["qq"])], + ) + + meta = get_handler_meta(handler) + assert meta is not None + assert meta.decorator_sources == {} + + with pytest.raises(ValueError, match="已有平台过滤器"): + platforms("wechat")(handler) + + +def test_message_types_rejects_existing_manual_message_type_filter() -> None: + def handler() -> None: + return None + + 
append_filter_meta( + handler, + specs=[MessageTypeFilterSpec(message_types=["group"])], + ) + + meta = get_handler_meta(handler) + assert meta is not None + assert meta.decorator_sources == {} + + with pytest.raises(ValueError, match="已有消息类型过滤器"): + message_types("private")(handler) + + +def test_require_permission_sets_normalized_permissions() -> None: + def handler() -> None: + return None + + require_permission("admin")(handler) + + meta = get_handler_meta(handler) + assert meta is not None + assert meta.permissions == Permissions(require_admin=True) + + +def test_require_permission_rejects_invalid_role() -> None: + with pytest.raises(ValueError, match="只支持"): + require_permission("owner") # type: ignore[arg-type] + + +def test_require_permission_rejects_conflicting_markers() -> None: + def handler() -> None: + return None + + require_permission("member")(handler) + + with pytest.raises(ValueError, match="冲突"): + require_admin(handler) diff --git a/astrbot-sdk/tests/test_http_runtime.py b/astrbot-sdk/tests/test_http_runtime.py new file mode 100644 index 0000000000..e9178a69b3 --- /dev/null +++ b/astrbot-sdk/tests/test_http_runtime.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from astrbot_sdk._internal.invocation_context import caller_plugin_scope +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import CapabilityRouter + + +async def _call( + router: CapabilityRouter, + capability: str, + payload: dict[str, object], + *, + plugin_id: str = "test-plugin", +) -> dict[str, object]: + with caller_plugin_scope(plugin_id): + result = await router.execute( + capability, + payload, + stream=False, + cancel_token=object(), + request_id=f"{plugin_id}:{capability}", + ) + assert isinstance(result, dict) + return result + + +@pytest.mark.asyncio +async def test_http_unregister_empty_methods_removes_all_for_route( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> 
None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "http.register_api", + { + "route": "/demo", + "methods": ["GET", "POST"], + "handler_capability": "demo.handler", + "description": "demo", + }, + ) + await _call( + router, + "http.unregister_api", + {"route": "/demo", "methods": []}, + ) + + listed = await _call(router, "http.list_apis", {}) + + assert listed == {"apis": []} + + +@pytest.mark.asyncio +async def test_http_unregister_subset_preserves_other_methods( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "http.register_api", + { + "route": "/demo", + "methods": ["GET", "POST"], + "handler_capability": "demo.handler", + "description": "demo", + }, + ) + await _call( + router, + "http.unregister_api", + {"route": "/demo", "methods": ["POST"]}, + ) + + listed = await _call(router, "http.list_apis", {}) + + assert listed == { + "apis": [ + { + "route": "/demo", + "methods": ["GET"], + "handler_capability": "demo.handler", + "description": "demo", + "plugin_id": "test-plugin", + } + ] + } + + +@pytest.mark.asyncio +async def test_http_register_rejects_routes_with_empty_segments( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + with pytest.raises(AstrBotError) as exc_info: + await _call( + router, + "http.register_api", + { + "route": "/foo//bar", + "methods": ["GET"], + "handler_capability": "demo.handler", + "description": "demo", + }, + ) + + assert exc_info.value.code == "invalid_input" diff --git a/astrbot-sdk/tests/test_injected_params.py b/astrbot-sdk/tests/test_injected_params.py new file mode 100644 index 0000000000..caa92e14b0 --- /dev/null +++ b/astrbot-sdk/tests/test_injected_params.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from types import SimpleNamespace + 
+sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from pydantic import BaseModel + +from astrbot_sdk._internal.command_model import resolve_command_model_param +from astrbot_sdk._internal.injected_params import ( + is_framework_injected_parameter, + legacy_arg_parameter_names, +) +from astrbot_sdk.conversation import ConversationSession +from astrbot_sdk.schedule import ScheduleContext +from astrbot_sdk.protocol.descriptors import CommandTrigger, HandlerDescriptor +from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher +from astrbot_sdk.runtime.loader import LoadedHandler, _build_param_specs + + +class _Payload(BaseModel): + name: str + + +def test_legacy_arg_parameter_names_excludes_injected_aliases() -> None: + def handler( + ctx, + conversation, + conv, + sched, + schedule, + name, + extra="fallback", + ) -> None: ... + + assert legacy_arg_parameter_names(handler) == ["name", "extra"] + + +def test_resolve_command_model_param_ignores_injected_aliases() -> None: + def handler(conversation, sched, payload: _Payload) -> None: ... + + resolved = resolve_command_model_param(handler) + + assert resolved is not None + assert resolved.name == "payload" + assert resolved.model_cls is _Payload + + +def test_is_framework_injected_parameter_supports_type_based_injection() -> None: + assert is_framework_injected_parameter("custom_conv", ConversationSession) + assert is_framework_injected_parameter("custom_schedule", ScheduleContext) + + +def test_loader_build_param_specs_excludes_injected_aliases() -> None: + def handler(conversation, schedule, name: str, count: int = 0) -> None: ... + + specs = _build_param_specs(handler) + + assert [spec.name for spec in specs] == ["name", "count"] + + +def test_handler_dispatcher_derive_args_skips_injected_aliases() -> None: + def handler(conversation, name, sched) -> None: ... 
+ + loaded = LoadedHandler( + descriptor=HandlerDescriptor( + id="plugin.handler", + trigger=CommandTrigger(command="ping"), + ), + callable=handler, + owner=object(), + ) + dispatcher = HandlerDispatcher( + plugin_id="plugin", + peer=SimpleNamespace(), + handlers=[loaded], + ) + + args = dispatcher._derive_args(loaded, SimpleNamespace(text="ping alice")) + + assert args == {"name": "alice"} diff --git a/astrbot-sdk/tests/test_invocation_context_isolation.py b/astrbot-sdk/tests/test_invocation_context_isolation.py new file mode 100644 index 0000000000..66b7c8c056 --- /dev/null +++ b/astrbot-sdk/tests/test_invocation_context_isolation.py @@ -0,0 +1,154 @@ +""" +针对 caller_plugin_scope / invocation_context 的并发隔离单元测试。 +覆盖: + - 基本作用域绑定与清理 + - 嵌套作用域 + - 并发 Task 下 ContextVar 互不干扰 + - 作用域结束后正确 reset +""" + +from __future__ import annotations + +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from astrbot_sdk._internal.invocation_context import ( + bind_caller_plugin_id, + caller_plugin_scope, + current_caller_plugin_id, + reset_caller_plugin_id, +) + + +# ── 基本绑定与清理 ───────────────────────────────────────────────── + + +def test_default_is_none() -> None: + assert current_caller_plugin_id() is None + + +def test_scope_sets_and_resets() -> None: + assert current_caller_plugin_id() is None + with caller_plugin_scope("plugin_a"): + assert current_caller_plugin_id() == "plugin_a" + # 作用域结束后必须恢复为 None + assert current_caller_plugin_id() is None + + +def test_scope_resets_to_previous_value() -> None: + token = bind_caller_plugin_id("outer") + try: + assert current_caller_plugin_id() == "outer" + with caller_plugin_scope("inner"): + assert current_caller_plugin_id() == "inner" + assert current_caller_plugin_id() == "outer" + finally: + reset_caller_plugin_id(token) + + +def test_scope_with_none_clears_id() -> None: + token = bind_caller_plugin_id("plugin_x") + try: + with caller_plugin_scope(None): + 
assert current_caller_plugin_id() is None + assert current_caller_plugin_id() == "plugin_x" + finally: + reset_caller_plugin_id(token) + + +def test_empty_string_normalized_to_none() -> None: + token = bind_caller_plugin_id(" ") # 空白字符串 → None + try: + assert current_caller_plugin_id() is None + finally: + reset_caller_plugin_id(token) + + +# ── 嵌套作用域 ──────────────────────────────────────────────────── + + +def test_nested_scopes_restore_correctly() -> None: + with caller_plugin_scope("a"): + assert current_caller_plugin_id() == "a" + with caller_plugin_scope("b"): + assert current_caller_plugin_id() == "b" + with caller_plugin_scope("c"): + assert current_caller_plugin_id() == "c" + assert current_caller_plugin_id() == "b" + assert current_caller_plugin_id() == "a" + assert current_caller_plugin_id() is None + + +# ── 并发 Task 隔离 ──────────────────────────────────────────────── + + +def test_concurrent_tasks_do_not_share_context() -> None: + """不同 asyncio Task 中的 ContextVar 互不干扰。""" + + results: dict[str, str | None] = {} + + async def task_fn(plugin_id: str, delay: float) -> None: + with caller_plugin_scope(plugin_id): + await asyncio.sleep(delay) + results[plugin_id] = current_caller_plugin_id() + + async def run() -> None: + # 两个 Task 并发执行,delay 设置使它们交叉运行 + await asyncio.gather( + task_fn("plugin_alpha", 0.01), + task_fn("plugin_beta", 0.001), + ) + + asyncio.run(run()) + + assert results["plugin_alpha"] == "plugin_alpha" + assert results["plugin_beta"] == "plugin_beta" + + +def test_child_task_inherits_parent_context_but_isolated() -> None: + """子 Task 继承父 Task 的 ContextVar 快照,但修改不会影响父 Task。""" + + parent_values: list[str | None] = [] + child_values: list[str | None] = [] + + async def child_task() -> None: + # 子 Task 在父 Task 的 scope 内创建,继承 "parent_plugin" 快照 + child_values.append(current_caller_plugin_id()) + # 子 Task 内修改不应该影响父 Task + with caller_plugin_scope("child_plugin"): + child_values.append(current_caller_plugin_id()) + 
child_values.append(current_caller_plugin_id()) + + async def parent_task() -> None: + with caller_plugin_scope("parent_plugin"): + parent_values.append(current_caller_plugin_id()) + task = asyncio.create_task(child_task()) + await asyncio.sleep(0.01) + parent_values.append(current_caller_plugin_id()) + await task + parent_values.append(current_caller_plugin_id()) + + asyncio.run(parent_task()) + + # 子 Task 继承了父 Task 的初始值 + assert child_values[0] == "parent_plugin" + assert child_values[1] == "child_plugin" + assert child_values[2] == "parent_plugin" + + # 父 Task 全程不受子 Task 影响 + assert all(v == "parent_plugin" for v in parent_values) + + +def test_scope_exception_still_resets() -> None: + """作用域内抛出异常时,ContextVar 依然被正确 reset。""" + assert current_caller_plugin_id() is None + try: + with caller_plugin_scope("error_plugin"): + assert current_caller_plugin_id() == "error_plugin" + raise RuntimeError("intentional error") + except RuntimeError: + pass + assert current_caller_plugin_id() is None diff --git a/astrbot-sdk/tests/test_mcp_runtime.py b/astrbot-sdk/tests/test_mcp_runtime.py new file mode 100644 index 0000000000..9cfef9cda5 --- /dev/null +++ b/astrbot-sdk/tests/test_mcp_runtime.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import pytest +from astrbot_sdk.testing import MockContext + +from tests.test_sdk.unit._mcp_contract import exercise_local_mcp_contract + + +class _MockMCPBackend: + def __init__(self, ctx: MockContext) -> None: + self._ctx = ctx + + async def get_server(self, name: str): + return await self._ctx.mcp.get_server(name) + + async def list_servers(self): + return await self._ctx.mcp.list_servers() + + async def enable_server(self, name: str): + return await self._ctx.mcp.enable_server(name) + + async def disable_server(self, name: str): + return await self._ctx.mcp.disable_server(name) + + async def wait_until_ready(self, name: str, *, timeout: float): + return await self._ctx.mcp.wait_until_ready(name, timeout=timeout) + + +def 
_local_server_payload(name: str, *, running: bool, delay: float = 0.0) -> dict: + return { + "name": name, + "scope": "local", + "active": True, + "running": running, + "config": { + "mock_tools": ["lookup"], + "mock_connect_delay": delay, + }, + "tools": ["lookup"] if running else [], + "errlogs": [], + "last_error": None, + } + + +@pytest.mark.asyncio +async def test_mock_context_mcp_local_contract_and_alias() -> None: + ctx = MockContext( + plugin_id="sdk-demo", + plugin_metadata={ + "local_mcp_servers": { + "demo": _local_server_payload("demo", running=True), + } + }, + ) + + assert ctx.mcp is ctx.mcp_manager + + await exercise_local_mcp_contract(_MockMCPBackend(ctx)) + + +@pytest.mark.asyncio +async def test_mock_context_mcp_wait_until_ready_success_and_timeout() -> None: + ctx = MockContext( + plugin_id="sdk-demo", + plugin_metadata={ + "local_mcp_servers": { + "demo": _local_server_payload("demo", running=False, delay=0.01), + "slow": _local_server_payload("slow", running=False, delay=0.2), + } + }, + ) + + ready = await ctx.mcp.wait_until_ready("demo", timeout=0.1) + assert ready.running is True + assert ready.tools == ["lookup"] + + with pytest.raises(TimeoutError): + await ctx.mcp.wait_until_ready("slow", timeout=0.01) + + +@pytest.mark.asyncio +async def test_mock_context_mcp_session_round_trip_and_tool_loop_isolation() -> None: + ctx = MockContext(plugin_id="sdk-demo") + + async with ctx.mcp.session( + "adhoc", + { + "mock_tools": ["inspect"], + "mock_tool_results": {"inspect": {"ok": True}}, + }, + timeout=0.1, + ) as session: + assert await session.list_tools() == ["inspect"] + assert await session.call_tool("inspect", {"x": 1}) == {"ok": True} + tool_loop = await ctx.tool_loop_agent(prompt="hello mcp") + assert "inspect" not in tool_loop.text + + assert ctx.router._mcp_session_store == {} + + +@pytest.mark.asyncio +async def test_mock_context_local_mcp_tools_are_plugin_scoped() -> None: + ctx_a = MockContext( + plugin_id="plugin-a", + 
plugin_metadata={ + "local_mcp_servers": { + "alpha": { + "name": "alpha", + "scope": "local", + "active": True, + "running": True, + "config": {"mock_tools": ["lookup"]}, + "tools": ["lookup"], + "errlogs": [], + "last_error": None, + } + } + }, + ) + ctx_b = MockContext( + plugin_id="plugin-b", + plugin_metadata={ + "local_mcp_servers": { + "beta": { + "name": "beta", + "scope": "local", + "active": True, + "running": True, + "config": {"mock_tools": ["lookup"]}, + "tools": ["lookup"], + "errlogs": [], + "last_error": None, + } + } + }, + ) + + resp_a = await ctx_a.tool_loop_agent(prompt="hello") + resp_b = await ctx_b.tool_loop_agent(prompt="hello") + + assert "mcp.alpha.lookup" in resp_a.text + assert "mcp.beta.lookup" not in resp_a.text + assert "mcp.beta.lookup" in resp_b.text + assert "mcp.alpha.lookup" not in resp_b.text + assert ctx_a.router._mcp_global_servers == {} + assert ctx_b.router._mcp_global_servers == {} + + +@pytest.mark.asyncio +async def test_mock_context_global_mcp_requires_ack_and_audits() -> None: + plain_ctx = MockContext(plugin_id="plain-plugin") + with pytest.raises(PermissionError): + await plain_ctx.mcp.register_global_server( + "global-demo", + {"mock_tools": ["inspect"]}, + ) + with pytest.raises(PermissionError): + await plain_ctx.mcp.list_global_servers() + with pytest.raises(PermissionError): + await plain_ctx.mcp.get_global_server("global-demo") + + ctx = MockContext( + plugin_id="privileged-plugin", + plugin_metadata={"acknowledge_global_mcp_risk": True}, + ) + record = await ctx.mcp.register_global_server( + "global-demo", + {"mock_tools": ["inspect"]}, + ) + + assert record.scope.value == "global" + assert record.running is True + assert [item.name for item in await ctx.mcp.list_global_servers()] == [ + "global-demo" + ] + assert (await ctx.mcp.get_global_server("global-demo")).name == "global-demo" + assert ctx.router._mcp_audit_logs == [ + { + "plugin_id": "privileged-plugin", + "action": "register", + "server_name": 
"global-demo", + "request_id": "local_0001", + } + ] diff --git a/astrbot-sdk/tests/test_memory_client.py b/astrbot-sdk/tests/test_memory_client.py new file mode 100644 index 0000000000..dd13d9b7a9 --- /dev/null +++ b/astrbot-sdk/tests/test_memory_client.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from astrbot_sdk.clients.memory import MemoryClient + + +class _FakeProxy: + def __init__(self, responses: dict[str, dict[str, Any]] | None = None) -> None: + self.responses = responses or {} + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + self.calls.append((name, dict(payload))) + return dict(self.responses.get(name, {})) + + +@pytest.mark.asyncio +async def test_root_client_search_preserves_explicit_root_namespace() -> None: + proxy = _FakeProxy({"memory.search": {"items": []}}) + client = MemoryClient(proxy) # type: ignore[arg-type] + + await client.search("shared", namespace="", include_descendants=False) + + assert proxy.calls == [ + ( + "memory.search", + { + "query": "shared", + "mode": "auto", + "namespace": "", + "include_descendants": False, + }, + ) + ] + + +@pytest.mark.asyncio +async def test_root_client_search_omits_namespace_when_scope_is_unspecified() -> None: + proxy = _FakeProxy({"memory.search": {"items": []}}) + client = MemoryClient(proxy) # type: ignore[arg-type] + + await client.search("shared") + + assert proxy.calls == [ + ( + "memory.search", + { + "query": "shared", + "mode": "auto", + "include_descendants": True, + }, + ) + ] + + +@pytest.mark.asyncio +async def test_stats_returns_namespace_backend_fields() -> None: + proxy = _FakeProxy( + { + "memory.stats": { + "total_items": 3, + "total_bytes": 128, + "namespace": "users/alice", + "namespace_count": 2, + "fts_enabled": True, + "vector_backend": "faiss", + "vector_indexes": [{"provider_id": "embedding-1", "dirty": False}], + "plugin_id": "test-plugin", 
+ "ttl_entries": 1, + } + } + ) + client = MemoryClient(proxy, namespace="users") # type: ignore[arg-type] + + stats = await client.stats(namespace="alice", include_descendants=False) + + assert proxy.calls == [ + ( + "memory.stats", + { + "include_descendants": False, + "namespace": "users/alice", + }, + ) + ] + assert stats == { + "total_items": 3, + "total_bytes": 128, + "namespace": "users/alice", + "namespace_count": 2, + "fts_enabled": True, + "vector_backend": "faiss", + "vector_indexes": [{"provider_id": "embedding-1", "dirty": False}], + "plugin_id": "test-plugin", + "ttl_entries": 1, + } + + +@pytest.mark.asyncio +async def test_list_keys_resolves_exact_namespace_and_returns_keys() -> None: + proxy = _FakeProxy({"memory.list_keys": {"keys": ["Alpha", "beta"]}}) + client = MemoryClient(proxy, namespace="users") # type: ignore[arg-type] + + keys = await client.list_keys(namespace="alice") + + assert proxy.calls == [ + ( + "memory.list_keys", + { + "namespace": "users/alice", + }, + ) + ] + assert keys == ["Alpha", "beta"] + + +@pytest.mark.asyncio +async def test_exists_uses_exact_namespace_and_returns_boolean() -> None: + proxy = _FakeProxy({"memory.exists": {"exists": True}}) + client = MemoryClient(proxy, namespace="users") # type: ignore[arg-type] + + exists = await client.exists("profile", namespace="alice") + + assert proxy.calls == [ + ( + "memory.exists", + { + "key": "profile", + "namespace": "users/alice", + }, + ) + ] + assert exists is True + + +@pytest.mark.asyncio +async def test_clear_namespace_returns_deleted_count() -> None: + proxy = _FakeProxy({"memory.clear_namespace": {"deleted_count": 3}}) + client = MemoryClient(proxy, namespace="users") # type: ignore[arg-type] + + deleted_count = await client.clear_namespace( + namespace="alice", + include_descendants=True, + ) + + assert proxy.calls == [ + ( + "memory.clear_namespace", + { + "namespace": "users/alice", + "include_descendants": True, + }, + ) + ] + assert deleted_count == 3 + + 
+@pytest.mark.asyncio +async def test_count_uses_exact_namespace_and_returns_integer() -> None: + proxy = _FakeProxy({"memory.count": {"count": 2}}) + client = MemoryClient(proxy, namespace="users") # type: ignore[arg-type] + + count = await client.count(namespace="alice", include_descendants=False) + + assert proxy.calls == [ + ( + "memory.count", + { + "namespace": "users/alice", + "include_descendants": False, + }, + ) + ] + assert count == 2 diff --git a/astrbot-sdk/tests/test_memory_runtime.py b/astrbot-sdk/tests/test_memory_runtime.py new file mode 100644 index 0000000000..a1cc2dc8ca --- /dev/null +++ b/astrbot-sdk/tests/test_memory_runtime.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import astrbot_sdk._memory_backend as memory_backend_module +import pytest +from astrbot_sdk._internal.invocation_context import caller_plugin_scope +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import CapabilityRouter + + +async def _call( + router: CapabilityRouter, + capability: str, + payload: dict[str, object], + *, + plugin_id: str = "test-plugin", +) -> dict[str, object]: + with caller_plugin_scope(plugin_id): + result = await router.execute( + capability, + payload, + stream=False, + cancel_token=object(), + request_id=f"{plugin_id}:{capability}", + ) + assert isinstance(result, dict) + return result + + +def _memory_db_path(tmp_path: Path, plugin_id: str) -> Path: + return ( + tmp_path + / ".astrbot_sdk_testing" + / "plugin_data" + / plugin_id + / "memory" + / "memory.sqlite3" + ) + + +@pytest.mark.asyncio +async def test_memory_is_plugin_scoped_and_persistent( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + {"key": "profile", "value": {"content": "alice likes blue"}}, + plugin_id="plugin-a", + ) + 
await _call( + router, + "memory.save", + {"key": "profile", "value": {"content": "bob likes green"}}, + plugin_id="plugin-b", + ) + + profile_a = await _call( + router, + "memory.get", + {"key": "profile"}, + plugin_id="plugin-a", + ) + profile_b = await _call( + router, + "memory.get", + {"key": "profile"}, + plugin_id="plugin-b", + ) + + assert profile_a == {"value": {"content": "alice likes blue"}} + assert profile_b == {"value": {"content": "bob likes green"}} + assert _memory_db_path(tmp_path, "plugin-a").exists() + + restarted = CapabilityRouter() + persisted = await _call( + restarted, + "memory.get", + {"key": "profile"}, + plugin_id="plugin-a", + ) + assert persisted == {"value": {"content": "alice likes blue"}} + + +@pytest.mark.asyncio +async def test_memory_namespace_search_respects_descendants( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + { + "key": "profile", + "namespace": "users/alice", + "value": {"content": "alice likes blue"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "session-note", + "namespace": "users/alice/sessions/1", + "value": {"content": "alice asked about the sea"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "profile", + "namespace": "users/bob", + "value": {"content": "bob likes green"}, + }, + ) + + exact = await _call( + router, + "memory.search", + { + "query": "alice", + "namespace": "users/alice", + "include_descendants": False, + "mode": "keyword", + }, + ) + scoped = await _call( + router, + "memory.search", + { + "query": "alice", + "namespace": "users/alice", + "include_descendants": True, + "mode": "keyword", + }, + ) + + assert [(item["namespace"], item["key"]) for item in exact["items"]] == [ + ("users/alice", "profile") + ] + assert {(item["namespace"], item["key"]) for item in scoped["items"]} == { + ("users/alice", "profile"), + 
("users/alice/sessions/1", "session-note"), + } + + +@pytest.mark.asyncio +async def test_memory_search_auto_falls_back_to_keyword_without_embedding_provider( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + router._active_provider_ids["embedding"] = None + + await _call( + router, + "memory.save", + {"key": "alpha-key", "value": {"content": "blue ocean memory"}}, + ) + + result = await _call(router, "memory.search", {"query": "alpha", "mode": "auto"}) + + assert [item["key"] for item in result["items"]] == ["alpha-key"] + assert result["items"][0]["match_type"] == "keyword" + + +@pytest.mark.asyncio +async def test_memory_vector_search_and_stats_report_vector_backend( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + {"key": "fruit-note", "value": {"content": "banana smoothie with mango"}}, + ) + await _call( + router, + "memory.save", + {"key": "ocean-note", "value": {"content": "waves on the blue ocean"}}, + ) + + result = await _call( + router, + "memory.search", + {"query": "banana smoothie", "mode": "vector", "limit": 1}, + ) + stats = await _call(router, "memory.stats", {}) + + assert len(result["items"]) == 1 + assert result["items"][0]["key"] == "fruit-note" + assert result["items"][0]["match_type"] == "vector" + assert stats["plugin_id"] == "test-plugin" + assert stats["total_items"] == 2 + assert stats["vector_backend"] in {"faiss", "exact"} + + +@pytest.mark.asyncio +async def test_memory_save_with_ttl_expires_across_restart( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + base_now = datetime(2026, 1, 1, tzinfo=timezone.utc) + monkeypatch.setattr(memory_backend_module, "_utcnow", lambda: base_now) + router = CapabilityRouter() + + await _call( + router, + "memory.save_with_ttl", + { + "key": "session", + 
"namespace": "users/alice/sessions/1", + "value": {"content": "active session"}, + "ttl_seconds": 60, + }, + ) + + result = await _call( + router, + "memory.get_many", + {"keys": ["session"], "namespace": "users/alice/sessions/1"}, + ) + assert result == { + "items": [{"key": "session", "value": {"content": "active session"}}] + } + + monkeypatch.setattr( + memory_backend_module, + "_utcnow", + lambda: base_now + timedelta(seconds=61), + ) + restarted = CapabilityRouter() + expired = await _call( + restarted, + "memory.get", + {"key": "session", "namespace": "users/alice/sessions/1"}, + ) + assert expired == {"value": None} + + +@pytest.mark.asyncio +async def test_memory_rejects_unsafe_plugin_id( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + with pytest.raises(AstrBotError) as exc_info: + await _call( + router, + "memory.save", + {"key": "profile", "value": {"content": "alice likes blue"}}, + plugin_id="../escape", + ) + + assert exc_info.value.code == "invalid_input" + assert "safe plugin_id" in exc_info.value.message + + +@pytest.mark.asyncio +async def test_memory_stats_can_scope_by_namespace( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + { + "key": "root-note", + "value": {"content": "top level"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "user-note", + "namespace": "users/alice", + "value": {"content": "alice memory"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "session-note", + "namespace": "users/alice/sessions/1", + "value": {"content": "session memory"}, + }, + ) + + scoped = await _call( + router, + "memory.stats", + {"namespace": "users/alice", "include_descendants": True}, + ) + + assert scoped["namespace"] == "users/alice" + assert scoped["total_items"] == 2 + assert scoped["namespace_count"] == 2 + 
assert scoped["fts_enabled"] in {True, False} + + +@pytest.mark.asyncio +async def test_memory_search_and_stats_can_target_root_namespace_exactly( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + {"key": "root-note", "value": {"content": "shared note at root"}}, + ) + await _call( + router, + "memory.save", + { + "key": "child-note", + "namespace": "users/alice", + "value": {"content": "shared note in child namespace"}, + }, + ) + + result = await _call( + router, + "memory.search", + { + "query": "shared note", + "namespace": "", + "include_descendants": False, + "mode": "keyword", + }, + ) + stats = await _call( + router, + "memory.stats", + {"namespace": "", "include_descendants": False}, + ) + + assert [(item.get("namespace"), item["key"]) for item in result["items"]] == [ + (None, "root-note") + ] + assert stats["namespace"] == "" + assert stats["total_items"] == 1 + + +@pytest.mark.asyncio +async def test_memory_namespace_scope_escapes_like_wildcards( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + { + "key": "safe", + "namespace": "team_1/room", + "value": {"content": "team scoped note"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "leak", + "namespace": "teamA1/room", + "value": {"content": "team scoped note"}, + }, + ) + + result = await _call( + router, + "memory.search", + { + "query": "team scoped", + "namespace": "team_1", + "include_descendants": True, + "mode": "keyword", + }, + ) + stats = await _call( + router, + "memory.stats", + {"namespace": "team_1", "include_descendants": True}, + ) + + assert [(item["namespace"], item["key"]) for item in result["items"]] == [ + ("team_1/room", "safe") + ] + assert stats["total_items"] == 1 + + +@pytest.mark.asyncio +async def 
test_memory_management_capabilities_cover_exact_and_recursive_scope( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + { + "key": "beta", + "namespace": "users/alice", + "value": {"content": "beta note"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "Alpha", + "namespace": "users/alice", + "value": {"content": "alpha note"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "apple", + "namespace": "users/alice", + "value": {"content": "apple note"}, + }, + ) + await _call( + router, + "memory.save", + { + "key": "child-note", + "namespace": "users/alice/sessions/1", + "value": {"content": "child note"}, + }, + ) + + keys = await _call( + router, + "memory.list_keys", + {"namespace": "users/alice"}, + ) + exact_count = await _call( + router, + "memory.count", + {"namespace": "users/alice"}, + ) + recursive_count = await _call( + router, + "memory.count", + {"namespace": "users/alice", "include_descendants": True}, + ) + exists = await _call( + router, + "memory.exists", + {"key": "child-note", "namespace": "users/alice/sessions/1"}, + ) + missing = await _call( + router, + "memory.exists", + {"key": "child-note", "namespace": "users/alice"}, + ) + cleared_exact = await _call( + router, + "memory.clear_namespace", + {"namespace": "users/alice"}, + ) + remaining_recursive = await _call( + router, + "memory.count", + {"namespace": "users/alice", "include_descendants": True}, + ) + cleared_recursive = await _call( + router, + "memory.clear_namespace", + {"namespace": "users/alice", "include_descendants": True}, + ) + final_count = await _call( + router, + "memory.count", + {"namespace": "users/alice", "include_descendants": True}, + ) + + assert keys == {"keys": ["Alpha", "apple", "beta"]} + assert exact_count == {"count": 3} + assert recursive_count == {"count": 4} + assert exists == {"exists": True} + assert missing 
== {"exists": False} + assert cleared_exact == {"deleted_count": 3} + assert remaining_recursive == {"count": 1} + assert cleared_recursive == {"deleted_count": 1} + assert final_count == {"count": 0} + + +@pytest.mark.asyncio +async def test_memory_management_capabilities_ignore_expired_ttl_entries( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + base_now = datetime(2026, 1, 1, tzinfo=timezone.utc) + monkeypatch.setattr(memory_backend_module, "_utcnow", lambda: base_now) + router = CapabilityRouter() + + await _call( + router, + "memory.save_with_ttl", + { + "key": "temp", + "namespace": "users/alice", + "value": {"content": "temporary"}, + "ttl_seconds": 60, + }, + ) + + monkeypatch.setattr( + memory_backend_module, + "_utcnow", + lambda: base_now + timedelta(seconds=61), + ) + restarted = CapabilityRouter() + + keys = await _call( + restarted, + "memory.list_keys", + {"namespace": "users/alice"}, + ) + count = await _call( + restarted, + "memory.count", + {"namespace": "users/alice"}, + ) + exists = await _call( + restarted, + "memory.exists", + {"key": "temp", "namespace": "users/alice"}, + ) + + assert keys == {"keys": []} + assert count == {"count": 0} + assert exists == {"exists": False} + + +@pytest.mark.asyncio +async def test_memory_management_capabilities_remain_plugin_scoped_under_overlap( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.chdir(tmp_path) + router = CapabilityRouter() + + await _call( + router, + "memory.save", + { + "key": "profile", + "namespace": "users/alice", + "value": {"content": "plugin a"}, + }, + plugin_id="plugin-a", + ) + await _call( + router, + "memory.save", + { + "key": "session", + "namespace": "users/alice/sessions/1", + "value": {"content": "plugin a child"}, + }, + plugin_id="plugin-a", + ) + await _call( + router, + "memory.save", + { + "key": "profile", + "namespace": "users/alice", + "value": {"content": "plugin b"}, + }, + 
plugin_id="plugin-b", + ) + + clear_task = _call( + router, + "memory.clear_namespace", + {"namespace": "users/alice", "include_descendants": True}, + plugin_id="plugin-a", + ) + count_task = _call( + router, + "memory.count", + {"namespace": "users/alice", "include_descendants": True}, + plugin_id="plugin-b", + ) + exists_task = _call( + router, + "memory.exists", + {"key": "profile", "namespace": "users/alice"}, + plugin_id="plugin-b", + ) + cleared, plugin_b_count, plugin_b_exists = await asyncio.gather( + clear_task, + count_task, + exists_task, + ) + + plugin_a_after = await _call( + router, + "memory.count", + {"namespace": "users/alice", "include_descendants": True}, + plugin_id="plugin-a", + ) + + assert cleared == {"deleted_count": 2} + assert plugin_b_count == {"count": 1} + assert plugin_b_exists == {"exists": True} + assert plugin_a_after == {"count": 0} diff --git a/astrbot-sdk/tests/test_message_components.py b/astrbot-sdk/tests/test_message_components.py new file mode 100644 index 0000000000..8c640bffc2 --- /dev/null +++ b/astrbot-sdk/tests/test_message_components.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from astrbot_sdk.message.components import File, Image, Record, Video + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("factory", "url", "prefix", "suffix"), + [ + (Image.fromURL, "https://example.com/test.jpg", "imgseg", ".jpg"), + (Record.fromURL, "https://example.com/test.dat", "recordseg", ".dat"), + (Video.fromURL, "https://example.com/test.mp4", "videoseg", ""), + ], +) +async def test_remote_media_download_uses_async_to_thread( + monkeypatch: pytest.MonkeyPatch, + factory, + url: str, + prefix: str, + suffix: str, +) -> None: + calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, *args): + calls.append((func, args)) + return str(Path("C:/tmp/downloaded.bin")) + + monkeypatch.setattr( + "astrbot_sdk.message.components.asyncio.to_thread", 
fake_to_thread + ) + + component = factory(url) + path = await component.convert_to_file_path() + + assert Path(path) == Path("C:/tmp/downloaded.bin") + assert len(calls) == 1 + _, args = calls[0] + assert args == (url, prefix, suffix) + + +@pytest.mark.asyncio +async def test_file_get_file_uses_async_to_thread_for_remote_download( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, *args): + calls.append((func, args)) + return str(Path("C:/tmp/file-download.bin")) + + monkeypatch.setattr( + "astrbot_sdk.message.components.asyncio.to_thread", fake_to_thread + ) + + component = File(name="demo.bin", url="https://example.com/demo.bin") + path = await component.get_file() + + assert Path(path) == Path("C:/tmp/file-download.bin") + assert Path(component.file) == Path("C:/tmp/file-download.bin") + assert len(calls) == 1 + _, args = calls[0] + assert args == ("https://example.com/demo.bin", "fileseg", ".bin") diff --git a/astrbot-sdk/tests/test_message_history_runtime.py b/astrbot-sdk/tests/test_message_history_runtime.py new file mode 100644 index 0000000000..1e3ca40e63 --- /dev/null +++ b/astrbot-sdk/tests/test_message_history_runtime.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from astrbot_sdk._internal.testing_support import MockContext +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.message.components import Plain +from astrbot_sdk.message.session import MessageSession + + +def _session_store_key(session: MessageSession) -> str: + return f"{session.platform_id}:{session.message_type}:{session.session_id}" + + +@pytest.mark.asyncio +async def test_mock_context_message_history_round_trip_and_aliases() -> None: + ctx = MockContext(plugin_id="sdk-demo") + private_session = MessageSession( + platform_id="demo-platform", + message_type="private", + session_id="user-1", + ) + group_session = 
MessageSession( + platform_id="demo-platform", + message_type="group", + session_id="user-1", + ) + + assert ctx.message_history_manager is ctx.message_history + + first = await ctx.message_history.append( + private_session, + parts=[Plain("first", convert=False)], + sender={"sender_id": "sender-1", "sender_name": "Tester"}, + metadata={"source": "test"}, + idempotency_key="idem-1", + ) + repeated = await ctx.message_history.append( + private_session, + parts=[Plain("first", convert=False)], + sender={"sender_id": "sender-1", "sender_name": "Tester"}, + metadata={"source": "test"}, + idempotency_key="idem-1", + ) + second = await ctx.message_history.append( + private_session, + parts=[Plain("second", convert=False)], + sender={"sender_id": "sender-2", "sender_name": "Tester 2"}, + ) + third = await ctx.message_history.append( + private_session, + parts=[Plain("third", convert=False)], + sender={"sender_id": "sender-3", "sender_name": "Tester 3"}, + ) + group_record = await ctx.message_history.append( + group_session, + parts=[Plain("group only", convert=False)], + sender={"sender_id": "group-sender", "sender_name": "Group Tester"}, + ) + + assert repeated.id == first.id + assert group_record.session.message_type == "group" + + first_page = await ctx.message_history.list(private_session, limit=2) + assert [record.id for record in first_page.records] == [third.id, second.id] + assert first_page.next_cursor == str(second.id) + assert first_page.total == 3 + assert first_page.records[0].parts[0].text == "third" + + second_page = await ctx.message_history.list( + private_session, + cursor=first_page.next_cursor, + limit=2, + ) + assert [record.id for record in second_page.records] == [first.id] + assert second_page.next_cursor is None + + fetched = await ctx.message_history.get(private_session, second.id) + assert fetched is not None + assert fetched.sender.sender_id == "sender-2" + assert fetched.metadata == {} + assert fetched.parts[0].text == "second" + + group_page 
= await ctx.message_history.list(group_session, limit=10) + assert [record.id for record in group_page.records] == [group_record.id] + + store = ctx.router._message_history_store[_session_store_key(private_session)] + timestamps = { + first.id: "2026-03-20T00:00:00+00:00", + second.id: "2026-03-21T00:00:00+00:00", + third.id: "2026-03-22T00:00:00+00:00", + } + for record in store: + stamped = timestamps.get(int(record["id"])) + if stamped is None: + continue + record["created_at"] = stamped + record["updated_at"] = stamped + + deleted_before = await ctx.message_history.delete_before( + private_session, + before=datetime(2026, 3, 21, 0, 0, tzinfo=timezone.utc), + ) + assert deleted_before == 1 + remaining_after_before = await ctx.message_history.list(private_session, limit=10) + assert [record.id for record in remaining_after_before.records] == [ + third.id, + second.id, + ] + + deleted_after = await ctx.message_history.delete_after( + private_session, + after=datetime(2026, 3, 21, 12, 0, tzinfo=timezone.utc), + ) + assert deleted_after == 1 + remaining_after_after = await ctx.message_history.list(private_session, limit=10) + assert [record.id for record in remaining_after_after.records] == [second.id] + + deleted_all = await ctx.message_history.delete_all(private_session) + assert deleted_all == 1 + assert (await ctx.message_history.list(private_session, limit=10)).records == [] + assert [ + record.id + for record in (await ctx.message_history.list(group_session, limit=10)).records + ] == [group_record.id] + + +@pytest.mark.asyncio +async def test_message_history_delete_boundaries_normalize_naive_datetime_to_utc() -> ( + None +): + ctx = MockContext(plugin_id="sdk-demo") + session = MessageSession( + platform_id="demo-platform", + message_type="private", + session_id="user-1", + ) + + first = await ctx.message_history.append( + session, + parts=[Plain("first", convert=False)], + sender={"sender_id": "sender-1", "sender_name": "Tester"}, + ) + second = await 
ctx.message_history.append( + session, + parts=[Plain("second", convert=False)], + sender={"sender_id": "sender-2", "sender_name": "Tester 2"}, + ) + third = await ctx.message_history.append( + session, + parts=[Plain("third", convert=False)], + sender={"sender_id": "sender-3", "sender_name": "Tester 3"}, + ) + + store = ctx.router._message_history_store[_session_store_key(session)] + timestamps = { + first.id: "2026-03-20T00:00:00+00:00", + second.id: "2026-03-21T00:00:00+00:00", + third.id: "2026-03-22T00:00:00+00:00", + } + for record in store: + stamped = timestamps.get(int(record["id"])) + if stamped is None: + continue + record["created_at"] = stamped + record["updated_at"] = stamped + + deleted_before = await ctx.message_history.delete_before( + session, + before=datetime(2026, 3, 21, 0, 0), + ) + assert deleted_before == 1 + + deleted_after = await ctx.message_history.delete_after( + session, + after=datetime(2026, 3, 21, 12, 0), + ) + assert deleted_after == 1 + + remaining = await ctx.message_history.list(session, limit=10) + assert [record.id for record in remaining.records] == [second.id] + + +@pytest.mark.asyncio +async def test_message_history_list_invalid_cursor_returns_invalid_input() -> None: + ctx = MockContext(plugin_id="sdk-demo") + session = MessageSession( + platform_id="demo-platform", + message_type="private", + session_id="user-1", + ) + await ctx.message_history.append( + session, + parts=[Plain("first", convert=False)], + sender={"sender_id": "sender-1", "sender_name": "Tester"}, + ) + + with pytest.raises(AstrBotError) as exc_info: + await ctx.message_history.list(session, cursor="abc") + + assert exc_info.value.code == "invalid_input" diff --git a/astrbot-sdk/tests/test_permission_client.py b/astrbot-sdk/tests/test_permission_client.py new file mode 100644 index 0000000000..43a7c34f4d --- /dev/null +++ b/astrbot-sdk/tests/test_permission_client.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any + +import 
pytest + +from astrbot_sdk.clients.permission import PermissionClient, PermissionManagerClient + + +class _FakeProxy: + def __init__(self, responses: dict[str, dict[str, Any]] | None = None) -> None: + self.responses = responses or {} + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def call(self, name: str, payload: dict[str, Any]) -> dict[str, Any]: + self.calls.append((name, dict(payload))) + return dict(self.responses.get(name, {})) + + +@pytest.mark.asyncio +async def test_permission_client_check_preserves_optional_session_id() -> None: + proxy = _FakeProxy({"permission.check": {"is_admin": True, "role": "admin"}}) + client = PermissionClient(proxy) # type: ignore[arg-type] + + result = await client.check("user-1", session_id="demo:group:42") + + assert proxy.calls == [ + ( + "permission.check", + {"user_id": "user-1", "session_id": "demo:group:42"}, + ) + ] + assert result.is_admin is True + assert result.role == "admin" + + +@pytest.mark.asyncio +async def test_permission_client_get_admins_returns_strings() -> None: + proxy = _FakeProxy({"permission.get_admins": {"admins": ["alpha", 42]}}) + client = PermissionClient(proxy) # type: ignore[arg-type] + + admins = await client.get_admins() + + assert proxy.calls == [("permission.get_admins", {})] + assert admins == ["alpha", "42"] + + +@pytest.mark.asyncio +async def test_permission_manager_client_forwards_admin_event_flag() -> None: + proxy = _FakeProxy({"permission.manager.add_admin": {"changed": True}}) + client = PermissionManagerClient( + proxy, # type: ignore[arg-type] + source_event_payload={"is_admin": True}, + ) + + changed = await client.add_admin("user-2") + + assert changed is True + assert proxy.calls == [ + ( + "permission.manager.add_admin", + {"user_id": "user-2", "_caller_is_admin": True}, + ) + ] + + +@pytest.mark.asyncio +async def test_permission_manager_client_remove_admin_defaults_to_non_admin_context() -> ( + None +): + proxy = _FakeProxy({"permission.manager.remove_admin": 
{"changed": False}}) + client = PermissionManagerClient(proxy) # type: ignore[arg-type] + + changed = await client.remove_admin("user-2") + + assert changed is False + assert proxy.calls == [ + ( + "permission.manager.remove_admin", + {"user_id": "user-2", "_caller_is_admin": False}, + ) + ] diff --git a/astrbot-sdk/tests/test_permission_runtime.py b/astrbot-sdk/tests/test_permission_runtime.py new file mode 100644 index 0000000000..9c22a0c317 --- /dev/null +++ b/astrbot-sdk/tests/test_permission_runtime.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from astrbot_sdk.context import Context +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.testing import MockContext, PluginHarness + + +def _write_permission_plugin(plugin_dir: Path) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "plugin.yaml").write_text( + """ +_schema_version: 2 +name: permission_runtime_plugin +author: tests +version: 1.0.0 +desc: permission runtime tests + +runtime: + python: "3.12" + +components: + - class: main:PermissionRuntimePlugin +""".strip() + + "\n", + encoding="utf-8", + ) + (plugin_dir / "requirements.txt").write_text("", encoding="utf-8") + (plugin_dir / "main.py").write_text( + """ +from astrbot_sdk import Context, MessageEvent, Star +from astrbot_sdk.decorators import on_command, require_permission + + +class PermissionRuntimePlugin(Star): + @on_command("panel") + @require_permission("admin") + async def panel(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("admin-only") + + @on_command("ping") + @require_permission("member") + async def ping(self, event: MessageEvent, ctx: Context) -> None: + await event.reply("member-ok") +""".lstrip(), + encoding="utf-8", + ) + + +@pytest.mark.asyncio +async def test_mock_context_permission_clients_and_manager_gates() -> None: + ctx = MockContext(plugin_id="plain-plugin") + ctx.router.set_admin_ids(["root", "maintainer"]) + + 
check = await ctx.permission.check("root", session_id="demo:group:42") + + assert check.is_admin is True + assert check.role == "admin" + assert await ctx.permission.get_admins() == ["root", "maintainer"] + + elevated_plain = Context( + peer=ctx.mock_peer, + plugin_id="plain-plugin", + source_event_payload={"is_admin": True}, + ) + with pytest.raises(AstrBotError, match="reserved/system"): + await elevated_plain.permission_manager.add_admin("alice") + + reserved_ctx = MockContext( + plugin_id="reserved-plugin", + plugin_metadata={"reserved": True}, + ) + reserved_ctx.router.set_admin_ids(["root"]) + + admin_ctx = Context( + peer=reserved_ctx.mock_peer, + plugin_id="reserved-plugin", + source_event_payload={"is_admin": True}, + ) + viewer_ctx = Context( + peer=reserved_ctx.mock_peer, + plugin_id="reserved-plugin", + source_event_payload={"is_admin": False}, + ) + + assert await admin_ctx.permission_manager.add_admin("alice") is True + assert await admin_ctx.permission_manager.add_admin("alice") is False + assert await admin_ctx.permission.get_admins() == ["root", "alice"] + assert await admin_ctx.permission_manager.remove_admin("alice") is True + assert await admin_ctx.permission_manager.remove_admin("alice") is False + + with pytest.raises(AstrBotError, match="active admin event context"): + await viewer_ctx.permission_manager.add_admin("bob") + + +@pytest.mark.asyncio +async def test_plugin_harness_respects_require_permission_roles( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "permission_runtime_plugin" + _write_permission_plugin(plugin_dir) + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + panel_payload = harness.build_event_payload( + text="panel", + request_id="req-panel-member", + ) + with pytest.raises(AstrBotError, match="未找到匹配的 handler"): + await harness.dispatch_event(panel_payload, request_id="req-panel-member") + + admin_payload = harness.build_event_payload( + text="panel", + request_id="req-panel-admin", + ) + 
admin_payload["is_admin"] = True + admin_records = await harness.dispatch_event( + admin_payload, + request_id="req-panel-admin", + ) + + member_payload = harness.build_event_payload( + text="ping", + request_id="req-ping-member", + ) + member_records = await harness.dispatch_event( + member_payload, + request_id="req-ping-member", + ) + + assert len(admin_records) == 1 + assert admin_records[0].text == "admin-only" + assert len(member_records) == 1 + assert member_records[0].text == "member-ok" diff --git a/astrbot-sdk/tests/test_plugin_ids.py b/astrbot-sdk/tests/test_plugin_ids.py new file mode 100644 index 0000000000..f23e56788a --- /dev/null +++ b/astrbot-sdk/tests/test_plugin_ids.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from astrbot_sdk._internal.plugin_ids import resolve_plugin_data_dir, validate_plugin_id + + +def test_validate_plugin_id_accepts_safe_identifiers() -> None: + assert validate_plugin_id("plugin-1.alpha_beta") == "plugin-1.alpha_beta" + + +@pytest.mark.parametrize( + "plugin_id", + ["", "../escape", "bad/name", r"bad\\name", "bad.", "CON"], +) +def test_validate_plugin_id_rejects_unsafe_values(plugin_id: str) -> None: + with pytest.raises(ValueError): + validate_plugin_id(plugin_id) + + +def test_resolve_plugin_data_dir_stays_within_root(tmp_path: Path) -> None: + resolved = resolve_plugin_data_dir(tmp_path, "plugin-a") + + assert resolved == tmp_path.resolve() / "plugin-a" diff --git a/astrbot-sdk/tests/test_plugin_logger.py b/astrbot-sdk/tests/test_plugin_logger.py new file mode 100644 index 0000000000..71c76dd1ff --- /dev/null +++ b/astrbot-sdk/tests/test_plugin_logger.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import re + +from astrbot_sdk._internal.plugin_logger import PluginLogger + +_ANSI_RE = re.compile(r"\x1b\[[0-9;]*m") + + +class _CapturingLogger: + def __init__(self) -> None: + self.calls: list[dict[str, object]] = [] + self._current_opt: dict[str, object] = 
{} + + def bind(self, **_kwargs): + return self + + def opt(self, *args, **kwargs): + self._current_opt = dict(kwargs) + return self + + def log(self, level, message, *args, **kwargs) -> None: + self.calls.append( + { + "level": level, + "message": message, + "args": args, + "kwargs": kwargs, + "opt": dict(self._current_opt), + } + ) + self._current_opt = {} + + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub("", text) + + +def test_plugin_logger_formats_like_core_console(monkeypatch) -> None: + logger = _CapturingLogger() + plugin_logger = PluginLogger(plugin_id="ai_girlfriend", logger=logger) + monkeypatch.setattr( + plugin_logger, + "_caller_info", + lambda: ("D:/repo/data/sdk_plugins/ai_girlfriend/gf_plugin.py", 321), + ) + + plugin_logger.info("hello {}", "world") + + assert len(logger.calls) == 1 + call = logger.calls[0] + assert call["level"] == "INFO" + assert call["opt"] == {"raw": True} + assert re.match( + r"^\[\d{2}:\d{2}:\d{2}\.\d{3}\] \[Plug\] \[INFO\] " + r"\[ai_girlfriend\.gf_plugin:321\]: hello world\n$", + _strip_ansi(str(call["message"])), + ) + + +def test_plugin_logger_uses_core_tag_for_sdk_internal_paths(monkeypatch) -> None: + logger = _CapturingLogger() + plugin_logger = PluginLogger(plugin_id="ai_girlfriend", logger=logger) + monkeypatch.setattr( + plugin_logger, + "_caller_info", + lambda: ("D:/repo/astrbot-sdk/src/astrbot_sdk/context.py", 88), + ) + + plugin_logger.warning("watch {}", "out") + + assert len(logger.calls) == 1 + call = logger.calls[0] + assert call["level"] == "WARNING" + assert call["opt"] == {"raw": True} + rendered = _strip_ansi(str(call["message"])) + assert "[Core] [WARN]" in rendered + assert "[astrbot_sdk.context:88]: watch out\n" in rendered diff --git a/astrbot-sdk/tests/test_protocol_stdout.py b/astrbot-sdk/tests/test_protocol_stdout.py new file mode 100644 index 0000000000..ac793f518d --- /dev/null +++ b/astrbot-sdk/tests/test_protocol_stdout.py @@ -0,0 +1,131 @@ +from __future__ import annotations + 
+import asyncio +import io +import os +from pathlib import Path + +from click.testing import CliRunner + +from astrbot_sdk import cli + + +class _FakeStream(io.StringIO): + def __init__(self, *, is_tty: bool) -> None: + super().__init__() + self._is_tty = is_tty + + def isatty(self) -> bool: + return self._is_tty + + +def test_resolve_protocol_stdout_defaults_to_silent_on_tty(monkeypatch) -> None: + fake_stdout = _FakeStream(is_tty=True) + monkeypatch.setattr("sys.stdout", fake_stdout) + + transport_stdout, opened_stdout = cli._resolve_protocol_stdout(None) + + assert opened_stdout is not None + assert transport_stdout is opened_stdout + assert getattr(transport_stdout, "name", None) == os.devnull + opened_stdout.close() + + +def test_resolve_protocol_stdout_defaults_to_console_when_stdout_is_piped( + monkeypatch, +) -> None: + fake_stdout = _FakeStream(is_tty=False) + monkeypatch.setattr("sys.stdout", fake_stdout) + + transport_stdout, opened_stdout = cli._resolve_protocol_stdout(None) + + assert transport_stdout is fake_stdout + assert opened_stdout is None + + +def test_resolve_protocol_stdout_supports_file_path( + monkeypatch, tmp_path: Path +) -> None: + fake_stdout = _FakeStream(is_tty=True) + output_path = tmp_path / "protocol.log" + monkeypatch.setattr("sys.stdout", fake_stdout) + + transport_stdout, opened_stdout = cli._resolve_protocol_stdout(str(output_path)) + + assert opened_stdout is not None + assert transport_stdout is opened_stdout + assert getattr(transport_stdout, "name", None) == str(output_path) + opened_stdout.close() + + +def test_run_command_resolves_protocol_stdout_to_stream( + monkeypatch, tmp_path: Path +) -> None: + captured: dict[str, object] = {} + + async def fake_run_supervisor(*, plugins_dir: Path, stdout=None, **_) -> None: + captured["plugins_dir"] = plugins_dir + captured["stdout_name"] = getattr(stdout, "name", None) + + def fake_run_async_entrypoint(entrypoint, **_) -> None: + asyncio.run(entrypoint) + + 
monkeypatch.setattr(cli, "run_supervisor", fake_run_supervisor) + monkeypatch.setattr(cli, "_run_async_entrypoint", fake_run_async_entrypoint) + + runner = CliRunner() + result = runner.invoke( + cli.cli, + ["run", "--plugins-dir", str(tmp_path), "--protocol-stdout", "silent"], + ) + + assert result.exit_code == 0 + assert captured == { + "plugins_dir": tmp_path, + "stdout_name": os.devnull, + } + + +def test_worker_command_resolves_protocol_stdout_to_stream( + monkeypatch, tmp_path: Path +) -> None: + captured: dict[str, object] = {} + plugin_dir = tmp_path / "plugin" + plugin_dir.mkdir() + output_path = tmp_path / "worker-protocol.log" + + async def fake_run_plugin_worker( + *, + plugin_dir: Path | None = None, + group_metadata: Path | None = None, + stdout=None, + **_, + ) -> None: + captured["plugin_dir"] = plugin_dir + captured["group_metadata"] = group_metadata + captured["stdout_name"] = getattr(stdout, "name", None) + + def fake_run_async_entrypoint(entrypoint, **_) -> None: + asyncio.run(entrypoint) + + monkeypatch.setattr(cli, "run_plugin_worker", fake_run_plugin_worker) + monkeypatch.setattr(cli, "_run_async_entrypoint", fake_run_async_entrypoint) + + runner = CliRunner() + result = runner.invoke( + cli.cli, + [ + "worker", + "--plugin-dir", + str(plugin_dir), + "--protocol-stdout", + str(output_path), + ], + ) + + assert result.exit_code == 0 + assert captured == { + "plugin_dir": plugin_dir, + "group_metadata": None, + "stdout_name": str(output_path), + } diff --git a/astrbot-sdk/tests/test_provider_client_context_regressions.py b/astrbot-sdk/tests/test_provider_client_context_regressions.py new file mode 100644 index 0000000000..c84b845f1a --- /dev/null +++ b/astrbot-sdk/tests/test_provider_client_context_regressions.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +import asyncio +from copy import deepcopy +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import pytest + +from 
astrbot_sdk._internal.testing_support import MockContext +from astrbot_sdk.clients._proxy import CapabilityProxy +from astrbot_sdk.clients.platform import PlatformStatus +from astrbot_sdk.clients.provider import ProviderManagerClient +from astrbot_sdk.context import PlatformCompatFacade +from astrbot_sdk.llm.entities import ProviderType + + +async def _wait_until( + predicate, + *, + timeout: float = 0.2, +) -> None: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + if predicate(): + return + await asyncio.sleep(0) + raise AssertionError("condition was not satisfied before timeout") + + +class _HookLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[str, str]] = [] + self.exception_calls: list[tuple[str, str]] = [] + + def debug(self, message: str, plugin_id: str) -> None: + self.debug_calls.append((message, plugin_id)) + + def exception(self, message: str, plugin_id: str) -> None: + self.exception_calls.append((message, plugin_id)) + + +@dataclass(slots=True) +class _CapabilityDescriptor: + supports_stream: bool | None = False + + +class _ProviderMutationPeer: + def __init__(self) -> None: + self.remote_peer = object() + self.remote_capability_map = { + "provider.manager.create": _CapabilityDescriptor(), + "provider.manager.load": _CapabilityDescriptor(), + "provider.manager.update": _CapabilityDescriptor(), + "provider.manager.get_merged_provider_config": _CapabilityDescriptor(), + } + self.stored_config = {"id": "provider-1", "model": "original-model"} + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: + assert not stream + if capability in {"provider.manager.create", "provider.manager.load"}: + provider_config = payload["provider_config"] + assert isinstance(provider_config, dict) + provider_config["id"] = "mutated-by-peer" + provider_config["model"] = "mutated-model" + return { + 
"provider": { + "id": "provider-1", + "model": "created-model", + "type": "mock", + "provider_type": "chat_completion", + "loaded": True, + "enabled": True, + "provider_source_id": None, + } + } + if capability == "provider.manager.update": + new_config = payload["new_config"] + assert isinstance(new_config, dict) + new_config["id"] = "mutated-by-peer" + new_config["model"] = "mutated-model" + return { + "provider": { + "id": "provider-1", + "model": "updated-model", + "type": "mock", + "provider_type": "chat_completion", + "loaded": True, + "enabled": True, + "provider_source_id": None, + } + } + if capability == "provider.manager.get_merged_provider_config": + return {"config": self.stored_config} + raise AssertionError(f"unexpected capability: {capability}") + + async def invoke_stream( + self, + capability: str, + payload: dict[str, Any], + *, + request_id: str | None = None, + include_completed: bool = False, + ): + raise AssertionError(f"unexpected stream capability: {capability}") + + +class _ControlledPlatformProxy: + def __init__( + self, + *, + snapshots: list[dict[str, Any]], + cleared_snapshot: dict[str, Any] | None = None, + ) -> None: + self._snapshots = [dict(item) for item in snapshots] + self._cleared_snapshot = ( + dict(cleared_snapshot) if isinstance(cleared_snapshot, dict) else None + ) + self.call_order: list[str] = [] + self.get_by_id_calls = 0 + self.clear_errors_calls = 0 + self.first_get_started = asyncio.Event() + self.release_first_get = asyncio.Event() + self._cleared = False + + async def call(self, capability: str, payload: dict[str, Any]) -> dict[str, Any]: + self.call_order.append(capability) + if capability == "platform.manager.get_by_id": + call_index = self.get_by_id_calls + self.get_by_id_calls += 1 + if call_index == 0: + self.first_get_started.set() + await self.release_first_get.wait() + if self._cleared and self._cleared_snapshot is not None: + snapshot = self._cleared_snapshot + else: + snapshot = 
self._snapshots[min(call_index, len(self._snapshots) - 1)] + return {"platform": dict(snapshot)} + if capability == "platform.manager.clear_errors": + self.clear_errors_calls += 1 + self._cleared = True + return {} + raise AssertionError(f"unexpected capability: {capability}") + + +@pytest.mark.asyncio +async def test_provider_change_hook_receives_events_and_unregisters_cleanly() -> None: + ctx = MockContext(plugin_metadata={"reserved": True}) + received: list[tuple[str, ProviderType, str | None]] = [] + event_received = asyncio.Event() + + async def callback( + provider_id: str, + provider_type: ProviderType, + umo: str | None, + ) -> None: + received.append((provider_id, provider_type, umo)) + event_received.set() + + task = await ctx.provider_manager.register_provider_change_hook(callback) + await _wait_until(lambda: len(ctx.router._provider_change_subscriptions) == 1) + + ctx.router.emit_provider_change( + "mock-embedding-provider", + ProviderType.EMBEDDING.value, + "mock:session:user", + ) + await asyncio.wait_for(event_received.wait(), timeout=0.2) + + assert received == [ + ( + "mock-embedding-provider", + ProviderType.EMBEDDING, + "mock:session:user", + ) + ] + + await ctx.provider_manager.unregister_provider_change_hook(task) + await _wait_until(lambda: not ctx.router._provider_change_subscriptions) + assert task.cancelled() + assert not ctx.provider_manager._change_hook_tasks + + ctx.router.emit_provider_change( + "mock-rerank-provider", + ProviderType.RERANK.value, + None, + ) + await asyncio.sleep(0) + assert received == [ + ( + "mock-embedding-provider", + ProviderType.EMBEDDING, + "mock:session:user", + ) + ] + + +@pytest.mark.asyncio +async def test_provider_change_hook_task_cancellation_cleans_up_and_logs_once() -> None: + logger = _HookLogger() + ctx = MockContext(plugin_metadata={"reserved": True}, logger=logger) + + task = await ctx.provider_manager.register_provider_change_hook(lambda *_args: None) + await _wait_until(lambda: 
len(ctx.router._provider_change_subscriptions) == 1) + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + await _wait_until(lambda: not ctx.router._provider_change_subscriptions) + assert not ctx.provider_manager._change_hook_tasks + assert logger.debug_calls == [ + ("Provider change hook cancelled: plugin_id={}", "test-plugin") + ] + assert logger.exception_calls == [] + + +@pytest.mark.asyncio +async def test_platform_compat_refresh_serializes_concurrent_state_updates() -> None: + proxy = _ControlledPlatformProxy( + snapshots=[ + { + "id": "mock-platform", + "name": "First Snapshot", + "type": "mock", + "status": "error", + "errors": [ + { + "message": "first error", + "timestamp": "2026-03-20T00:00:00+00:00", + "traceback": None, + } + ], + "last_error": { + "message": "first error", + "timestamp": "2026-03-20T00:00:00+00:00", + "traceback": None, + }, + "unified_webhook": False, + }, + { + "id": "mock-platform", + "name": "Second Snapshot", + "type": "mock-updated", + "status": "running", + "errors": [], + "last_error": None, + "unified_webhook": True, + }, + ] + ) + facade = PlatformCompatFacade( + _ctx=SimpleNamespace(_proxy=proxy), + id="mock-platform", + name="Initial Snapshot", + type="mock", + ) + + first = asyncio.create_task(facade.refresh()) + await asyncio.wait_for(proxy.first_get_started.wait(), timeout=0.2) + + second = asyncio.create_task(facade.refresh()) + await asyncio.sleep(0) + assert proxy.get_by_id_calls == 1 + + proxy.release_first_get.set() + await asyncio.gather(first, second) + + assert proxy.call_order == [ + "platform.manager.get_by_id", + "platform.manager.get_by_id", + ] + assert facade.name == "Second Snapshot" + assert facade.type == "mock-updated" + assert facade.status == PlatformStatus.RUNNING + assert facade.errors == [] + assert facade.last_error is None + assert facade.unified_webhook is True + + +@pytest.mark.asyncio +async def test_platform_compat_clear_errors_waits_for_inflight_refresh() -> 
None: + proxy = _ControlledPlatformProxy( + snapshots=[ + { + "id": "mock-platform", + "name": "Errored Platform", + "type": "mock", + "status": "error", + "errors": [ + { + "message": "boom", + "timestamp": "2026-03-20T00:00:00+00:00", + "traceback": "trace", + } + ], + "last_error": { + "message": "boom", + "timestamp": "2026-03-20T00:00:00+00:00", + "traceback": "trace", + }, + "unified_webhook": False, + } + ], + cleared_snapshot={ + "id": "mock-platform", + "name": "Recovered Platform", + "type": "mock", + "status": "running", + "errors": [], + "last_error": None, + "unified_webhook": False, + }, + ) + facade = PlatformCompatFacade( + _ctx=SimpleNamespace(_proxy=proxy), + id="mock-platform", + name="Initial Snapshot", + type="mock", + ) + + refresh_task = asyncio.create_task(facade.refresh()) + await asyncio.wait_for(proxy.first_get_started.wait(), timeout=0.2) + + clear_task = asyncio.create_task(facade.clear_errors()) + await asyncio.sleep(0) + assert proxy.clear_errors_calls == 0 + + proxy.release_first_get.set() + await asyncio.gather(refresh_task, clear_task) + + assert proxy.call_order == [ + "platform.manager.get_by_id", + "platform.manager.clear_errors", + "platform.manager.get_by_id", + ] + assert facade.name == "Recovered Platform" + assert facade.status == PlatformStatus.RUNNING + assert facade.errors == [] + assert facade.last_error is None + + +@pytest.mark.asyncio +async def test_mock_context_list_platforms_returns_facades_for_valid_instances() -> ( + None +): + ctx = MockContext(plugin_id="plain-plugin") + ctx.router.set_platform_instances( + [ + { + "id": "mock-platform", + "name": "Mock Platform", + "type": "mock", + "status": "running", + }, + { + "id": "mock-platform-2", + "name": "Mock Platform 2", + "type": "mock", + "status": "stopped", + }, + { + "id": "", + "name": "Broken Platform", + "type": "broken", + "status": "running", + }, + ] + ) + + platforms = await ctx.list_platforms() + + assert [platform.id for platform in platforms] == [ 
+ "mock-platform", + "mock-platform-2", + ] + assert all(isinstance(platform, PlatformCompatFacade) for platform in platforms) + assert [platform.status for platform in platforms] == [ + PlatformStatus.RUNNING, + PlatformStatus.STOPPED, + ] + + +@pytest.mark.asyncio +async def test_provider_manager_methods_copy_caller_supplied_config_dicts() -> None: + peer = _ProviderMutationPeer() + manager = ProviderManagerClient( + CapabilityProxy(peer), + plugin_id="test-plugin", + logger=None, + ) + + create_config = { + "id": "provider-create", + "model": "create-model", + "type": "mock", + "provider_type": ProviderType.CHAT_COMPLETION.value, + } + load_config = { + "id": "provider-load", + "model": "load-model", + "type": "mock", + "provider_type": ProviderType.CHAT_COMPLETION.value, + } + update_config = { + "id": "provider-update", + "model": "update-model", + "type": "mock", + "provider_type": ProviderType.CHAT_COMPLETION.value, + } + + create_snapshot = deepcopy(create_config) + load_snapshot = deepcopy(load_config) + update_snapshot = deepcopy(update_config) + + await manager.create_provider(create_config) + await manager.load_provider(load_config) + await manager.update_provider("provider-origin", update_config) + + assert create_config == create_snapshot + assert load_config == load_snapshot + assert update_config == update_snapshot + + +@pytest.mark.asyncio +async def test_provider_manager_get_merged_provider_config_returns_detached_dict() -> ( + None +): + peer = _ProviderMutationPeer() + manager = ProviderManagerClient( + CapabilityProxy(peer), + plugin_id="test-plugin", + logger=None, + ) + + config = await manager.get_merged_provider_config("provider-1") + assert config == {"id": "provider-1", "model": "original-model"} + + assert config is not peer.stored_config + config["model"] = "changed-by-caller" + assert peer.stored_config == {"id": "provider-1", "model": "original-model"} diff --git a/astrbot-sdk/tests/test_request_id_overlay_mapping.py 
b/astrbot-sdk/tests/test_request_id_overlay_mapping.py new file mode 100644 index 0000000000..f1bbd1e5d9 --- /dev/null +++ b/astrbot-sdk/tests/test_request_id_overlay_mapping.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import pytest + +from astrbot_sdk.clients._proxy import CapabilityProxy +from astrbot_sdk.testing import PluginHarness + + +def _write_overlay_test_plugin(plugin_dir: Path) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "plugin.yaml").write_text( + """ +_schema_version: 2 +name: overlay_test_plugin +author: tests +version: 1.0.0 +desc: request overlay regression tests + +runtime: + python: "3.12" + +components: + - class: main:OverlayPlugin +""".strip() + + "\n", + encoding="utf-8", + ) + (plugin_dir / "requirements.txt").write_text("", encoding="utf-8") + (plugin_dir / "main.py").write_text( + """ +from astrbot_sdk import Context, MessageEvent, ScheduleContext, Star +from astrbot_sdk.decorators import on_event, on_schedule + + +class OverlayPlugin(Star): + @on_schedule(interval_seconds=60) + async def scheduled(self, ctx: Context, schedule: ScheduleContext) -> None: + applied = await ctx.registry.set_handler_whitelist(["alpha", "beta"]) + current = await ctx.registry.get_handler_whitelist() + await ctx.platform.send_by_id( + "test", + "schedule-target", + f"{','.join(applied or [])}|{','.join(current or []) if current else 'none'}", + ) + + @on_event("llm_request") + async def llm_overlay(self, event: MessageEvent) -> None: + requested = await event.request_llm() + current = await event.should_call_llm() + await event.reply(f"{requested}:{current}") +""".lstrip(), + encoding="utf-8", + ) + + +@pytest.mark.asyncio +async def test_schedule_handler_preserves_request_overlay_state(tmp_path: Path) -> None: + plugin_dir = tmp_path / "overlay_test_plugin" + _write_overlay_test_plugin(plugin_dir) + + async with 
PluginHarness.from_plugin_dir(plugin_dir) as harness: + payload = harness.build_event_payload( + text="", + event_type="schedule", + request_id="req-schedule-1", + ) + payload["schedule"] = { + "schedule_id": "schedule-1", + "plugin_id": "overlay_test_plugin", + "handler_id": "overlay_test_plugin:scheduled", + "trigger_kind": "interval", + "interval_seconds": 60, + } + + records = await harness.dispatch_event(payload, request_id="req-schedule-1") + + assert len(records) == 1 + assert records[0].kind == "chain" + assert records[0].session == "test:private:schedule-target" + assert records[0].chain is not None + assert records[0].chain[0]["data"]["text"] == "alpha,beta|alpha,beta" + + +@pytest.mark.asyncio +async def test_non_message_event_preserves_request_overlay_state( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "overlay_test_plugin" + _write_overlay_test_plugin(plugin_dir) + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + payload = harness.build_event_payload( + text="trigger llm overlay", + event_type="llm_request", + request_id="req-llm-1", + ) + + records = await harness.dispatch_event(payload, request_id="req-llm-1") + + assert len(records) == 1 + assert records[0].kind == "text" + assert records[0].text == "True:True" + + +class _RecordingPeer: + def __init__(self) -> None: + self.remote_peer = None + self.remote_capability_map: dict[str, Any] = {} + self.calls: list[tuple[str, dict[str, Any], str | None]] = [] + + async def invoke( + self, + capability: str, + payload: dict[str, Any], + *, + stream: bool = False, + request_id: str | None = None, + ) -> dict[str, Any]: + self.calls.append((capability, dict(payload), request_id)) + return {"ok": True, "stream": stream} + + +@pytest.mark.asyncio +async def test_capability_proxy_keeps_transport_ids_unique_while_forwarding_request_scope() -> ( + None +): + peer = _RecordingPeer() + proxy = CapabilityProxy( + peer, + caller_plugin_id="overlay_test_plugin", + 
request_scope_id="req-parent-1", + ) + + await asyncio.gather( + proxy.call("system.event.llm.get_state", {}), + proxy.call("system.event.result.get", {}), + proxy.call("platform.send", {"session": "test:private:user-1", "text": "hi"}), + ) + + assert peer.calls == [ + ( + "system.event.llm.get_state", + {"_request_scope_id": "req-parent-1"}, + None, + ), + ( + "system.event.result.get", + {"_request_scope_id": "req-parent-1"}, + None, + ), + ( + "platform.send", + {"session": "test:private:user-1", "text": "hi"}, + None, + ), + ] diff --git a/astrbot-sdk/tests/test_runtime_bootstrap.py b/astrbot-sdk/tests/test_runtime_bootstrap.py new file mode 100644 index 0000000000..76ef8a5a70 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_bootstrap.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from astrbot_sdk.runtime import bootstrap as bootstrap_module + + +class _RecordingRuntime: + def __init__(self, *, peer_name: str = "runtime-peer") -> None: + self.peer = SimpleNamespace(name=peer_name) + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def stop(self) -> None: + self.stopped = True + + +@pytest.mark.asyncio +async def test_run_plugin_worker_requires_exactly_one_target() -> None: + with pytest.raises(ValueError, match="plugin_dir or group_metadata is required"): + await bootstrap_module.run_plugin_worker(plugin_dir=None, group_metadata=None) + + with pytest.raises(ValueError, match="mutually exclusive"): + await bootstrap_module.run_plugin_worker( + plugin_dir=Path("plugin"), + group_metadata=Path("group.json"), + ) + + +@pytest.mark.asyncio +async def test_run_plugin_worker_uses_single_plugin_runtime_and_restores_stdout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_RecordingRuntime] = [] + original_stdout = sys.stdout + + def fake_prepare_stdio_transport(stdin, stdout): + assert 
stdin == "stdin" + assert stdout == "stdout" + sys.stdout = sys.stderr + return "transport-stdin", "transport-stdout", original_stdout + + class _FakeTransport: + def __init__(self, *, stdin, stdout) -> None: + self.stdin = stdin + self.stdout = stdout + + def fake_runtime(*, plugin_dir: Path, transport) -> _RecordingRuntime: + assert plugin_dir == Path("plugin-dir") + assert transport.stdin == "transport-stdin" + assert transport.stdout == "transport-stdout" + runtime = _RecordingRuntime() + created.append(runtime) + return runtime + + async def fake_wait_for_shutdown(peer, stop_event) -> None: + assert peer is created[0].peer + assert isinstance(stop_event, bootstrap_module.asyncio.Event) + + monkeypatch.setattr( + bootstrap_module, + "_prepare_stdio_transport", + fake_prepare_stdio_transport, + ) + monkeypatch.setattr(bootstrap_module, "StdioTransport", _FakeTransport) + monkeypatch.setattr(bootstrap_module, "PluginWorkerRuntime", fake_runtime) + monkeypatch.setattr( + bootstrap_module, + "_install_signal_handlers", + lambda stop_event: stop_event.set(), + ) + monkeypatch.setattr(bootstrap_module, "_wait_for_shutdown", fake_wait_for_shutdown) + + await bootstrap_module.run_plugin_worker( + plugin_dir=Path("plugin-dir"), + stdin="stdin", + stdout="stdout", + ) + + assert len(created) == 1 + assert created[0].started is True + assert created[0].stopped is True + assert sys.stdout is original_stdout + + +@pytest.mark.asyncio +async def test_run_plugin_worker_uses_group_runtime_when_group_metadata_given( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_RecordingRuntime] = [] + + monkeypatch.setattr( + bootstrap_module, + "_prepare_stdio_transport", + lambda stdin, stdout: ("stdin", "stdout", None), + ) + monkeypatch.setattr( + bootstrap_module, + "StdioTransport", + lambda *, stdin, stdout: SimpleNamespace(stdin=stdin, stdout=stdout), + ) + + def fake_group_runtime( + *, group_metadata_path: Path, transport + ) -> _RecordingRuntime: + assert 
group_metadata_path == Path("group.json") + assert transport.stdin == "stdin" + assert transport.stdout == "stdout" + runtime = _RecordingRuntime(peer_name="group-peer") + created.append(runtime) + return runtime + + monkeypatch.setattr(bootstrap_module, "GroupWorkerRuntime", fake_group_runtime) + monkeypatch.setattr( + bootstrap_module, + "_install_signal_handlers", + lambda stop_event: stop_event.set(), + ) + monkeypatch.setattr( + bootstrap_module, + "_wait_for_shutdown", + lambda peer, stop_event: ( + created[0].start() if False else bootstrap_module.asyncio.sleep(0) + ), + ) + + await bootstrap_module.run_plugin_worker(group_metadata=Path("group.json")) + + assert len(created) == 1 + assert created[0].started is True + assert created[0].stopped is True + + +@pytest.mark.asyncio +async def test_run_supervisor_passes_env_manager_and_restores_stdout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_RecordingRuntime] = [] + env_manager = object() + original_stdout = sys.stdout + + monkeypatch.setattr( + bootstrap_module, + "_prepare_stdio_transport", + lambda stdin, stdout: ("stdin", "stdout", original_stdout), + ) + monkeypatch.setattr( + bootstrap_module, + "StdioTransport", + lambda *, stdin, stdout: SimpleNamespace(stdin=stdin, stdout=stdout), + ) + + def fake_runtime(*, transport, plugins_dir: Path, env_manager) -> _RecordingRuntime: + assert plugins_dir == Path("plugins-under-test") + assert env_manager is not None + assert transport.stdin == "stdin" + assert transport.stdout == "stdout" + runtime = _RecordingRuntime(peer_name="supervisor-peer") + created.append(runtime) + return runtime + + monkeypatch.setattr(bootstrap_module, "SupervisorRuntime", fake_runtime) + monkeypatch.setattr( + bootstrap_module, + "_install_signal_handlers", + lambda stop_event: stop_event.set(), + ) + monkeypatch.setattr( + bootstrap_module, + "_wait_for_shutdown", + lambda peer, stop_event: bootstrap_module.asyncio.sleep(0), + ) + + await 
bootstrap_module.run_supervisor( + plugins_dir=Path("plugins-under-test"), + env_manager=env_manager, + ) + + assert len(created) == 1 + assert created[0].started is True + assert created[0].stopped is True + assert sys.stdout is original_stdout + + +@pytest.mark.asyncio +async def test_run_websocket_server_uses_websocket_transport_and_default_cwd( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_RecordingRuntime] = [] + websocket_transports: list[SimpleNamespace] = [] + + monkeypatch.setattr(bootstrap_module.Path, "cwd", lambda: Path("cwd-plugin")) + + def fake_transport(*, host: str, port: int, path: str): + transport = SimpleNamespace(host=host, port=port, path=path) + websocket_transports.append(transport) + return transport + + def fake_runtime(*, plugin_dir: Path, transport) -> _RecordingRuntime: + assert plugin_dir == Path("cwd-plugin") + assert transport is websocket_transports[0] + runtime = _RecordingRuntime(peer_name="ws-peer") + created.append(runtime) + return runtime + + monkeypatch.setattr(bootstrap_module, "WebSocketServerTransport", fake_transport) + monkeypatch.setattr(bootstrap_module, "PluginWorkerRuntime", fake_runtime) + monkeypatch.setattr( + bootstrap_module, + "_install_signal_handlers", + lambda stop_event: stop_event.set(), + ) + monkeypatch.setattr( + bootstrap_module, + "_wait_for_shutdown", + lambda peer, stop_event: bootstrap_module.asyncio.sleep(0), + ) + + await bootstrap_module.run_websocket_server( + host="0.0.0.0", + port=9000, + path="/ws", + plugin_dir=None, + ) + + assert websocket_transports == [ + SimpleNamespace(host="0.0.0.0", port=9000, path="/ws") + ] + assert created[0].started is True + assert created[0].stopped is True diff --git a/astrbot-sdk/tests/test_runtime_capability_dispatcher.py b/astrbot-sdk/tests/test_runtime_capability_dispatcher.py new file mode 100644 index 0000000000..e6b18ff1bc --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_capability_dispatcher.py @@ -0,0 +1,274 @@ +from __future__ 
import annotations + +import json +import asyncio +from typing import Any + +import pytest +from pydantic import BaseModel + +from astrbot_sdk._internal.testing_support import MockCapabilityRouter, MockPeer +from astrbot_sdk.context import CancelToken, Context +from astrbot_sdk.events import MessageEvent +from astrbot_sdk.llm.entities import LLMToolSpec +from astrbot_sdk.protocol.descriptors import CapabilityDescriptor +from astrbot_sdk.protocol.messages import InvokeMessage +from astrbot_sdk.runtime._streaming import StreamExecution +from astrbot_sdk.runtime.capability_dispatcher import CapabilityDispatcher +from astrbot_sdk.runtime.loader import LoadedCapability, LoadedLLMTool + + +class _SerializableChunk(BaseModel): + value: str + + +def _build_loaded_capability( + handler, + *, + name: str = "test.echo", + plugin_id: str = "test-plugin", +) -> LoadedCapability: + return LoadedCapability( + descriptor=CapabilityDescriptor( + name=name, + description="test capability", + input_schema={"type": "object"}, + output_schema={"type": "object"}, + supports_stream=True, + cancelable=True, + ), + callable=handler, + owner=object(), + plugin_id=plugin_id, + ) + + +@pytest.mark.asyncio +async def test_capability_dispatcher_returns_stream_execution_for_async_generator() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + + async def stream_capability(payload: dict[str, Any]): + yield {"value": str(payload["name"]).upper()} + yield _SerializableChunk(value="done") + + dispatcher = CapabilityDispatcher( + plugin_id="test-plugin", + peer=peer, + capabilities=[_build_loaded_capability(stream_capability)], + ) + + execution = await dispatcher.invoke( + InvokeMessage( + id="req-stream", + capability="test.echo", + input={"name": "alice"}, + stream=True, + ), + CancelToken(), + ) + + assert isinstance(execution, StreamExecution) + chunks = [chunk async for chunk in execution.iterator] + + assert chunks == [{"value": "ALICE"}, {"value": "done"}] + assert 
execution.finalize(chunks) == {"items": chunks} + + +@pytest.mark.asyncio +async def test_capability_dispatcher_injection_error_mentions_supported_sources() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + + def broken(required_name: str) -> dict[str, Any]: + return {"ok": True} + + dispatcher = CapabilityDispatcher( + plugin_id="plugin-alpha", + peer=peer, + capabilities=[ + _build_loaded_capability( + broken, + name="plugin-alpha.broken", + plugin_id="plugin-alpha", + ) + ], + ) + + with pytest.raises(TypeError) as exc_info: + await dispatcher.invoke( + InvokeMessage( + id="req-broken", + capability="plugin-alpha.broken", + input={"available": "value"}, + ), + CancelToken(), + ) + + message = str(exc_info.value) + assert ( + "插件 'plugin-alpha' 的 capability 'plugin-alpha.broken' 参数注入失败" + in message + ) + assert "必填参数 'required_name' 无法注入" in message + assert "payload 中现有键:available" in message + + +@pytest.mark.asyncio +async def test_registered_llm_tool_injects_event_and_normalizes_dict_result() -> None: + peer = MockPeer(MockCapabilityRouter()) + + async def tool_handler( + event: MessageEvent, + ctx: Context, + text: str, + ) -> dict[str, str]: + return { + "echo": text, + "session": event.session_id, + "plugin": ctx.plugin_id, + } + + loaded_tool = LoadedLLMTool( + spec=LLMToolSpec.create( + name="echo", + description="Echo", + parameters_schema={ + "type": "object", + "properties": {"text": {"type": "string"}}, + }, + handler_ref="echo.ref", + ), + callable=tool_handler, + owner=object(), + plugin_id="tool-plugin", + ) + dispatcher = CapabilityDispatcher( + plugin_id="worker-group", + peer=peer, + capabilities=[], + llm_tools=[loaded_tool], + ) + + result = await dispatcher.invoke( + InvokeMessage( + id="req-tool", + capability="internal.llm_tool.execute", + input={ + "plugin_id": "tool-plugin", + "tool_name": "echo", + "handler_ref": "echo.ref", + "tool_args": {"text": "hello"}, + "event": { + "type": "message", + "event_type": "message", + 
"text": "trigger", + "session_id": "session-42", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + }, + }, + ), + CancelToken(), + ) + + assert result["success"] is True + assert json.loads(str(result["content"])) == { + "echo": "hello", + "session": "session-42", + "plugin": "tool-plugin", + } + + +def test_dynamic_llm_tool_registration_replaces_aliases_and_remove_cleans_all_keys() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + dispatcher = CapabilityDispatcher( + plugin_id="worker-group", + peer=peer, + capabilities=[], + ) + + async def first_tool() -> str: + return "first" + + async def second_tool() -> str: + return "second" + + dispatcher.add_dynamic_llm_tool( + plugin_id="plugin.alpha", + spec=LLMToolSpec.create( + name="echo", + description="Echo", + handler_ref="echo.ref", + ), + callable_obj=first_tool, + ) + dispatcher.add_dynamic_llm_tool( + plugin_id="plugin.alpha", + spec=LLMToolSpec.create( + name="echo", + description="Echo updated", + handler_ref="echo.ref", + ), + callable_obj=second_tool, + ) + + loaded_by_name = dispatcher._llm_tools[("plugin.alpha", "echo")] + loaded_by_ref = dispatcher._llm_tools[("plugin.alpha", "echo.ref")] + + assert loaded_by_name.callable is second_tool + assert loaded_by_ref.callable is second_tool + assert dispatcher.remove_llm_tool("plugin.alpha", "echo.ref") is True + assert ("plugin.alpha", "echo") not in dispatcher._llm_tools + assert ("plugin.alpha", "echo.ref") not in dispatcher._llm_tools + + +@pytest.mark.asyncio +async def test_capability_dispatcher_cancel_propagates_to_task_and_token() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = CapabilityDispatcher( + plugin_id="worker-group", + peer=peer, + capabilities=[], + ) + cancel_token = CancelToken() + task = asyncio.create_task(asyncio.sleep(30)) + dispatcher._active["req-cancel"] = (task, cancel_token) + + await dispatcher.cancel("req-cancel") + 
+ assert cancel_token.cancelled is True + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_capability_dispatcher_stream_mode_rejects_non_stream_result() -> None: + peer = MockPeer(MockCapabilityRouter()) + + async def non_stream_capability(payload: dict[str, Any]) -> dict[str, Any]: + return {"payload": payload} + + dispatcher = CapabilityDispatcher( + plugin_id="test-plugin", + peer=peer, + capabilities=[_build_loaded_capability(non_stream_capability)], + ) + + with pytest.raises(Exception, match="stream=true"): + await dispatcher.invoke( + InvokeMessage( + id="req-stream-invalid", + capability="test.echo", + input={"name": "alice"}, + stream=True, + ), + CancelToken(), + ) diff --git a/astrbot-sdk/tests/test_runtime_environment_groups.py b/astrbot-sdk/tests/test_runtime_environment_groups.py new file mode 100644 index 0000000000..da7572e646 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_environment_groups.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from astrbot_sdk.runtime.environment_groups import ( + EnvironmentGroup, + EnvironmentPlanner, + GROUP_STATE_FILE_NAME, + GroupEnvironmentManager, +) +from astrbot_sdk.runtime.loader import PluginSpec + + +def _plugin_spec( + tmp_path: Path, + name: str, + *, + python_version: str = "3.12", + requirements: list[str] | None = None, +) -> PluginSpec: + plugin_dir = tmp_path / name + plugin_dir.mkdir(parents=True, exist_ok=True) + manifest_path = plugin_dir / "plugin.yaml" + requirements_path = plugin_dir / "requirements.txt" + manifest_path.write_text(f"name: {name}\n", encoding="utf-8") + requirements_path.write_text( + "\n".join(requirements or []) + ("\n" if requirements else ""), + encoding="utf-8", + ) + return PluginSpec( + name=name, + plugin_dir=plugin_dir, + manifest_path=manifest_path, + requirements_path=requirements_path, + python_version=python_version, + manifest_data={"name": 
name}, + ) + + +def _group( + repo_root: Path, plugin: PluginSpec, *, fingerprint: str = "fingerprint" +) -> EnvironmentGroup: + venv_path = repo_root / ".astrbot" / "envs" / plugin.name + lockfile_path = repo_root / ".astrbot" / "locks" / f"{plugin.name}.txt" + metadata_path = repo_root / ".astrbot" / "groups" / f"{plugin.name}.json" + source_path = repo_root / ".astrbot" / "groups" / f"{plugin.name}.in" + python_path = venv_path / ("Scripts/python.exe") + return EnvironmentGroup( + id=plugin.name, + python_version=plugin.python_version, + plugins=[plugin], + source_path=source_path, + lockfile_path=lockfile_path, + metadata_path=metadata_path, + venv_path=venv_path, + python_path=python_path, + environment_fingerprint=fingerprint, + ) + + +def test_environment_planner_plan_groups_compatible_plugins_and_splits_conflicts( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + alpha = _plugin_spec(tmp_path, "alpha", requirements=["pkg==1.0", "shared==2.0"]) + beta = _plugin_spec(tmp_path, "beta", requirements=["pkg==1.0"]) + gamma = _plugin_spec(tmp_path, "gamma", requirements=["pkg==2.0"]) + planner = EnvironmentPlanner(tmp_path, uv_binary="uv") + + def fake_compile_lockfile( + *, source_path: Path, output_path: Path, python_version: str + ) -> None: + content = source_path.read_text(encoding="utf-8") + if "pkg==1.0" in content and "pkg==2.0" in content: + raise RuntimeError("dependency conflict") + output_path.write_text(f"# lock for {python_version}\n", encoding="utf-8") + + monkeypatch.setattr(planner, "_compile_lockfile", fake_compile_lockfile) + + plan = planner.plan([alpha, beta, gamma]) + + assert len(plan.groups) == 2 + grouped_plugins = sorted( + sorted(plugin.name for plugin in group.plugins) for group in plan.groups + ) + assert grouped_plugins == [["alpha", "beta"], ["gamma"]] + assert sorted(plugin.name for plugin in plan.plugins) == ["alpha", "beta", "gamma"] + assert plan.skipped_plugins == {} + assert plan.plugin_to_group["alpha"] is 
plan.plugin_to_group["beta"] + assert plan.plugin_to_group["gamma"] is not plan.plugin_to_group["alpha"] + + +def test_environment_planner_cleanup_artifacts_removes_stale_entries( + tmp_path: Path, +) -> None: + plugin = _plugin_spec(tmp_path, "active") + planner = EnvironmentPlanner(tmp_path, uv_binary="uv") + active_group = _group(tmp_path, plugin) + active_group.source_path.parent.mkdir(parents=True, exist_ok=True) + active_group.lockfile_path.parent.mkdir(parents=True, exist_ok=True) + active_group.venv_path.mkdir(parents=True, exist_ok=True) + active_group.source_path.write_text("", encoding="utf-8") + active_group.metadata_path.write_text("{}", encoding="utf-8") + active_group.lockfile_path.write_text("", encoding="utf-8") + + stale_source = planner.group_dir / "stale.in" + stale_metadata = planner.group_dir / "stale.json" + stale_lockfile = planner.lock_dir / "stale.txt" + stale_env = planner.env_dir / "stale" + stale_source.parent.mkdir(parents=True, exist_ok=True) + stale_lockfile.parent.mkdir(parents=True, exist_ok=True) + stale_env.mkdir(parents=True, exist_ok=True) + stale_source.write_text("", encoding="utf-8") + stale_metadata.write_text("{}", encoding="utf-8") + stale_lockfile.write_text("", encoding="utf-8") + (stale_env / "pyvenv.cfg").write_text("version = 3.12\n", encoding="utf-8") + + planner.cleanup_artifacts([active_group]) + + assert active_group.source_path.exists() is True + assert active_group.metadata_path.exists() is True + assert active_group.lockfile_path.exists() is True + assert active_group.venv_path.exists() is True + assert stale_source.exists() is False + assert stale_metadata.exists() is False + assert stale_lockfile.exists() is False + assert stale_env.exists() is False + + +def test_group_environment_manager_prepare_rebuilds_when_runtime_is_missing( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + plugin = _plugin_spec(tmp_path, "alpha") + group = _group(tmp_path, plugin) + 
group.lockfile_path.parent.mkdir(parents=True, exist_ok=True) + group.lockfile_path.write_text("# lock\n", encoding="utf-8") + manager = GroupEnvironmentManager(tmp_path, uv_binary="uv") + calls: list[str] = [] + + monkeypatch.setattr( + manager, + "_rebuild", + lambda current_group: calls.append(f"rebuild:{current_group.id}"), + ) + monkeypatch.setattr( + manager, + "_sync_existing", + lambda current_group: calls.append(f"sync:{current_group.id}"), + ) + + python_path = manager.prepare(group) + + assert python_path == group.python_path + assert calls == ["rebuild:alpha"] + state = json.loads( + (group.venv_path / GROUP_STATE_FILE_NAME).read_text(encoding="utf-8") + ) + assert state["group_id"] == "alpha" + assert state["environment_fingerprint"] == "fingerprint" + + +def test_group_environment_manager_prepare_syncs_existing_env_when_state_changed( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + plugin = _plugin_spec(tmp_path, "alpha") + group = _group(tmp_path, plugin, fingerprint="new-fingerprint") + group.venv_path.mkdir(parents=True, exist_ok=True) + group.lockfile_path.parent.mkdir(parents=True, exist_ok=True) + group.lockfile_path.write_text("# lock\n", encoding="utf-8") + group.python_path.parent.mkdir(parents=True, exist_ok=True) + group.python_path.write_text("", encoding="utf-8") + state_path = group.venv_path / GROUP_STATE_FILE_NAME + state_path.write_text( + json.dumps( + { + "group_id": group.id, + "python_version": group.python_version, + "environment_fingerprint": "old-fingerprint", + } + ), + encoding="utf-8", + ) + manager = GroupEnvironmentManager(tmp_path, uv_binary="uv") + calls: list[str] = [] + + monkeypatch.setattr( + manager, "_matches_python_version", lambda venv_path, version: True + ) + monkeypatch.setattr( + manager, + "_rebuild", + lambda current_group: calls.append(f"rebuild:{current_group.id}"), + ) + monkeypatch.setattr( + manager, + "_sync_existing", + lambda current_group: 
calls.append(f"sync:{current_group.id}"), + ) + + python_path = manager.prepare(group) + + assert python_path == group.python_path + assert calls == ["sync:alpha"] + updated_state = json.loads(state_path.read_text(encoding="utf-8")) + assert updated_state["environment_fingerprint"] == "new-fingerprint" diff --git a/astrbot-sdk/tests/test_runtime_handler_dispatcher_core.py b/astrbot-sdk/tests/test_runtime_handler_dispatcher_core.py new file mode 100644 index 0000000000..024b70c951 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_handler_dispatcher_core.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest +from unittest.mock import AsyncMock + +from astrbot_sdk._internal.testing_support import MockCapabilityRouter, MockPeer +from astrbot_sdk.clients.llm import LLMResponse +from astrbot_sdk.context import Context +from astrbot_sdk.conversation import ( + ConversationReplaced, + ConversationSession, + ConversationState, +) +from astrbot_sdk.decorators import ConversationMeta +from astrbot_sdk.events import MessageEvent +from astrbot_sdk.llm.entities import ProviderRequest +from astrbot_sdk.protocol.descriptors import HandlerDescriptor, MessageTrigger +from astrbot_sdk.runtime.handler_dispatcher import ( + _ActiveConversation, + _InjectedEventPayloads, + HandlerDispatcher, +) +from astrbot_sdk.runtime.loader import LoadedHandler +from astrbot_sdk.star import Star + + +def _build_event( + *, + peer: MockPeer, + plugin_id: str = "test-plugin", + text: str = "hello", + session_id: str = "session-1", + event_type: str = "message", + payload_extra: dict[str, object] | None = None, + raw_extra: dict[str, object] | None = None, +) -> tuple[MessageEvent, Context]: + payload: dict[str, object] = { + "type": event_type, + "event_type": event_type, + "text": text, + "session_id": session_id, + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": 
{"event_type": event_type}, + } + if payload_extra: + payload.update(payload_extra) + if raw_extra: + payload["raw"] = {**payload["raw"], **raw_extra} + ctx = Context(peer=peer, plugin_id=plugin_id) + event = MessageEvent.from_payload(payload, context=ctx) + return event, ctx + + +def _build_loaded_handler( + handler, + *, + plugin_id: str = "test-plugin", + owner=None, + conversation: ConversationMeta | None = None, +) -> LoadedHandler: + return LoadedHandler( + descriptor=HandlerDescriptor( + id=f"{plugin_id}.handler", + trigger=MessageTrigger(), + ), + callable=handler, + owner=owner if owner is not None else object(), + plugin_id=plugin_id, + conversation=conversation, + ) + + +@pytest.mark.asyncio +async def test_inject_provider_request_reads_nested_payload_without_cache() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + request_payload = { + "prompt": "hello", + "session_id": "session-1", + "model": "gpt-test", + } + event, _ctx = _build_event( + peer=peer, + event_type="llm_request", + raw_extra={"provider_request": request_payload}, + ) + + injected = dispatcher._inject_provider_request(event, None) + + assert injected == ProviderRequest.from_payload(request_payload) + + +@pytest.mark.asyncio +async def test_run_handler_reuses_llm_response_and_serializes_summary() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + response_payload = { + "text": "hello back", + "finish_reason": "stop", + "tool_calls": [], + } + event, ctx = _build_event( + peer=peer, + event_type="llm_response", + raw_extra={"llm_response": response_payload}, + ) + injected_payloads = _InjectedEventPayloads() + + first = dispatcher._inject_llm_response(event, injected_payloads) + second = dispatcher._inject_llm_response(event, injected_payloads) + + assert isinstance(first, LLMResponse) + assert first is second + assert 
first.model_dump(exclude_none=True) == response_payload + + +@pytest.mark.asyncio +async def test_run_handler_merges_dict_result_flags() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + + async def handler(event: MessageEvent) -> dict[str, object]: + assert event.session_id == "session-1" + return {"stop": True, "call_llm": True} + + loaded = _build_loaded_handler(handler) + + summary = await dispatcher._run_handler(loaded, event, ctx, {}) + + assert summary == {"sent_message": False, "stop": True, "call_llm": True} + + +@pytest.mark.asyncio +async def test_handle_error_prefers_owner_hook() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + calls: list[tuple[Exception, str, str]] = [] + + class Owner: + async def on_error( + self, + exc: Exception, + error_event: MessageEvent, + error_ctx: Context, + ) -> None: + calls.append((exc, error_event.session_id, error_ctx.plugin_id)) + + boom = RuntimeError("boom") + + await dispatcher._handle_error(Owner(), boom, event, ctx) + + assert calls == [(boom, "session-1", "test-plugin")] + + +@pytest.mark.asyncio +async def test_handle_error_falls_back_to_default_star_handler() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + seen: list[tuple[Exception, str, str]] = [] + original = Star.default_on_error + + async def fake_default_on_error( + exc: Exception, + error_event: MessageEvent, + error_ctx: Context, + ) -> None: + seen.append((exc, error_event.session_id, error_ctx.plugin_id)) + + Star.default_on_error = fake_default_on_error + try: + boom = ValueError("fallback") + await dispatcher._handle_error(object(), boom, event, ctx) + finally: + 
Star.default_on_error = original + + assert seen == [(boom, "session-1", "test-plugin")] + + +@pytest.mark.asyncio +async def test_start_conversation_rejects_when_existing_session_is_busy() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + blocker = asyncio.Event() + + async def pending_task() -> None: + await blocker.wait() + + existing = asyncio.create_task(pending_task()) + conversation = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=dispatcher._session_waiters, + timeout=30, + ) + dispatcher._conversations["test-plugin:session-1"] = _ActiveConversation( + session=conversation, + task=existing, + ) + loaded = _build_loaded_handler( + lambda conversation: None, + conversation=ConversationMeta(mode="reject", busy_message="still busy"), + ) + + try: + summary = await dispatcher._start_conversation( + loaded, + event, + ctx, + {}, + schedule_context=None, + ) + finally: + existing.cancel() + with pytest.raises(asyncio.CancelledError): + await existing + + assert summary == {"sent_message": True, "stop": True, "call_llm": False} + assert peer._router.platform_sink.records[-1].text == "still busy" + + +@pytest.mark.asyncio +async def test_start_conversation_replaces_existing_session_and_registers_new_one() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + replacement_seen = asyncio.Event() + + async def previous_runner() -> None: + try: + await asyncio.Future() + except asyncio.CancelledError: + replacement_seen.set() + raise + + previous_task = asyncio.create_task(previous_runner()) + previous_session = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=dispatcher._session_waiters, + timeout=30, + ) + previous_session.bind_owner_task(previous_task) + 
dispatcher._conversations["test-plugin:session-1"] = _ActiveConversation( + session=previous_session, + task=previous_task, + ) + dispatcher._session_waiters.fail = AsyncMock(return_value=True) # type: ignore[method-assign] + + async def handler(conversation: ConversationSession) -> None: + conversation.close(ConversationState.COMPLETED) + + loaded = _build_loaded_handler( + handler, + conversation=ConversationMeta(mode="replace", grace_period=0.1), + ) + + summary = await dispatcher._start_conversation( + loaded, + event, + ctx, + {}, + schedule_context=None, + ) + + assert summary == {"sent_message": False, "stop": True, "call_llm": False} + await asyncio.wait_for(replacement_seen.wait(), timeout=1) + assert previous_session.state == ConversationState.REPLACED + dispatcher._session_waiters.fail.assert_awaited_once() + fail_call = dispatcher._session_waiters.fail.await_args + assert fail_call.args[0] == previous_session.session_key + assert isinstance(fail_call.args[1], ConversationReplaced) + + active = dispatcher._conversations["test-plugin:session-1"] + assert active.session is not previous_session + await active.task + assert "test-plugin:session-1" not in dispatcher._conversations + + +@pytest.mark.asyncio +async def test_run_conversation_task_marks_conversation_cancelled_on_task_cancel() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + conversation = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=dispatcher._session_waiters, + timeout=30, + ) + entered = asyncio.Event() + + async def handler(conversation: ConversationSession) -> None: + entered.set() + await asyncio.Future() + + loaded = _build_loaded_handler( + handler, + conversation=ConversationMeta(), + ) + + task = asyncio.create_task( + dispatcher._run_conversation_task( + loaded, + event, + ctx, + {}, + conversation, + schedule_context=None, + ) + ) + 
conversation.bind_owner_task(task) + await entered.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert conversation.state == ConversationState.CANCELLED + + +@pytest.mark.asyncio +async def test_run_conversation_task_reports_handler_errors_without_reraising() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event, ctx = _build_event(peer=peer) + conversation = ConversationSession( + ctx=ctx, + event=event, + waiter_manager=dispatcher._session_waiters, + timeout=30, + ) + dispatcher._handle_error = AsyncMock() # type: ignore[method-assign] + + async def handler(conversation: ConversationSession) -> None: + raise RuntimeError("conversation exploded") + + loaded = _build_loaded_handler( + handler, + conversation=ConversationMeta(), + owner=SimpleNamespace(), + ) + + await dispatcher._run_conversation_task( + loaded, + event, + ctx, + {}, + conversation, + schedule_context=None, + ) + + dispatcher._handle_error.assert_awaited_once() diff --git a/astrbot-sdk/tests/test_runtime_limiter.py b/astrbot-sdk/tests/test_runtime_limiter.py new file mode 100644 index 0000000000..73a330481a --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_limiter.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from astrbot_sdk.decorators import LimiterMeta +from astrbot_sdk.errors import ErrorCodes +from astrbot_sdk.runtime.limiter import ( + DEFAULT_COOLDOWN_MESSAGE, + DEFAULT_RATE_LIMIT_MESSAGE, + LimiterEngine, +) + + +def test_limiter_engine_scopes_rate_limit_per_session() -> None: + engine = LimiterEngine(clock=lambda: 10.0) + limiter = LimiterMeta(kind="rate_limit", limit=1, window=30.0, scope="session") + session_a = SimpleNamespace(session_id="session-a") + session_b = SimpleNamespace(session_id="session-b") + + first = engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + 
event=session_a, + ) + blocked = engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=session_a, + ) + other_session = engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=session_b, + ) + + assert first.allowed is True + assert blocked.allowed is False + assert blocked.hint == DEFAULT_RATE_LIMIT_MESSAGE + assert other_session.allowed is True + + +def test_limiter_engine_error_behavior_returns_cooldown_error_details() -> None: + engine = LimiterEngine(clock=lambda: 20.0) + limiter = LimiterMeta( + kind="cooldown", + limit=1, + window=8.2, + scope="user", + behavior="error", + ) + event = SimpleNamespace(platform_id="test", user_id="user-1") + + engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=event, + ) + blocked = engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=event, + ) + + assert blocked.allowed is False + assert blocked.error is not None + assert blocked.error.code == ErrorCodes.COOLDOWN_ACTIVE + assert blocked.error.details == { + "scope": "user", + "handler_id": "plugin.handler", + "remaining_seconds": 8.2, + } + assert blocked.error.hint == DEFAULT_COOLDOWN_MESSAGE.format(remaining_seconds=9) + + +def test_limiter_engine_silent_behavior_returns_no_hint_or_error() -> None: + engine = LimiterEngine(clock=lambda: 5.0) + limiter = LimiterMeta( + kind="rate_limit", + limit=1, + window=10.0, + scope="global", + behavior="silent", + message="custom {remaining_seconds}", + ) + event = SimpleNamespace() + + engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=event, + ) + blocked = engine.evaluate( + plugin_id="plugin", + handler_id="plugin.handler", + limiter=limiter, + event=event, + ) + + assert blocked.allowed is False + assert blocked.hint is None + assert blocked.error is None diff --git 
a/astrbot-sdk/tests/test_runtime_loader_regressions.py b/astrbot-sdk/tests/test_runtime_loader_regressions.py new file mode 100644 index 0000000000..7485214217 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_loader_regressions.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import importlib +import shutil +import sys +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + +from astrbot_sdk.runtime.loader import ( + discover_plugins, + load_plugin, + load_plugin_spec, + validate_plugin_spec, +) + + +def _write_plugin( + plugin_dir: Path, + *, + plugin_name: str, + class_name: str, + main_source: str, + extra_files: dict[str, str] | None = None, + write_requirements: bool = True, +) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + (plugin_dir / "plugin.yaml").write_text( + f""" +_schema_version: 2 +name: {plugin_name} +author: tests +version: 1.0.0 +desc: loader regression tests + +runtime: + python: "{python_version}" + +components: + - class: main:{class_name} +""".strip() + + "\n", + encoding="utf-8", + ) + if write_requirements: + (plugin_dir / "requirements.txt").write_text("", encoding="utf-8") + (plugin_dir / "main.py").write_text(main_source.lstrip(), encoding="utf-8") + + for relative_path, content in (extra_files or {}).items(): + target = plugin_dir / relative_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + + +def _load_first_instance(plugin_dir: Path): + plugin = load_plugin_spec(plugin_dir) + validate_plugin_spec(plugin) + loaded = load_plugin(plugin) + assert loaded.instances + return loaded.instances[0] + + +def _purge_module_roots(*roots: str) -> None: + for root in {item for item in roots if item}: + for module_name in list(sys.modules): + if module_name == root or module_name.startswith(f"{root}."): + sys.modules.pop(module_name, None) + + +@contextmanager 
+def _preserve_import_state(*module_roots: str) -> Iterator[None]: + original_path = list(sys.path) + original_modules = { + name: module + for name, module in sys.modules.items() + if any(name == root or name.startswith(f"{root}.") for root in module_roots) + } + try: + yield + finally: + sys.path[:] = original_path + _purge_module_roots(*module_roots) + sys.modules.update(original_modules) + importlib.invalidate_caches() + + +def test_load_plugin_reloads_same_path_after_source_change(tmp_path: Path) -> None: + plugin_dir = tmp_path / "reload_plugin" + _write_plugin( + plugin_dir, + plugin_name="reload_plugin", + class_name="ReloadPlugin", + main_source=""" +from astrbot_sdk import Star +from support.value import CURRENT_VALUE + + +class ReloadPlugin(Star): + value = CURRENT_VALUE +""", + extra_files={ + "support/__init__.py": "", + "support/value.py": 'CURRENT_VALUE = "v1"\n', + }, + ) + + with _preserve_import_state("main", "support"): + first = _load_first_instance(plugin_dir) + assert first.value == "v1" + + (plugin_dir / "support" / "value.py").write_text( + 'CURRENT_VALUE = "v2"\n', + encoding="utf-8", + ) + + second = _load_first_instance(plugin_dir) + assert second.value == "v2" + assert second.__class__ is not first.__class__ + assert ( + Path(sys.modules["main"].__file__).resolve() + == (plugin_dir / "main.py").resolve() + ) + assert ( + Path(sys.modules["support.value"].__file__).resolve() + == (plugin_dir / "support" / "value.py").resolve() + ) + + +def test_load_plugin_prefers_target_plugin_dir_for_generic_main_module( + tmp_path: Path, +) -> None: + foreign_dir = tmp_path / "foreign_main" + foreign_dir.mkdir(parents=True, exist_ok=True) + (foreign_dir / "main.py").write_text( + """ +from astrbot_sdk import Star + + +class SharedPlugin(Star): + source = "foreign" +""".lstrip(), + encoding="utf-8", + ) + + plugin_dir = tmp_path / "generic_main_plugin" + _write_plugin( + plugin_dir, + plugin_name="generic_main_plugin", + class_name="SharedPlugin", + 
main_source=""" +from astrbot_sdk import Star + + +class SharedPlugin(Star): + source = "plugin" +""", + ) + + with _preserve_import_state("main"): + sys.path.insert(0, str(foreign_dir.resolve())) + sys.path.append(str(plugin_dir.resolve())) + + _purge_module_roots("main") + __import__("main") + assert ( + Path(sys.modules["main"].__file__).resolve() + == (foreign_dir / "main.py").resolve() + ) + + instance = _load_first_instance(plugin_dir) + + assert instance.source == "plugin" + assert sys.path[0] == str(plugin_dir.resolve()) + assert ( + Path(sys.modules["main"].__file__).resolve() + == (plugin_dir / "main.py").resolve() + ) + + +def test_load_plugin_cleans_stale_bytecode_from_copied_fixture(tmp_path: Path) -> None: + fixture_source = tmp_path / "fixture_source" + _write_plugin( + fixture_source, + plugin_name="copied_fixture_plugin", + class_name="FixturePlugin", + main_source=""" +from astrbot_sdk import Star + + +class FixturePlugin(Star): + value = "fresh" +""", + ) + + cache_tag = sys.implementation.cache_tag or "cpython" + stale_main_pyc = fixture_source / "__pycache__" / f"main.{cache_tag}.pyc" + stale_main_pyc.parent.mkdir(parents=True, exist_ok=True) + stale_main_pyc.write_bytes(b"stale main bytecode") + + stale_nested_pyc = ( + fixture_source / "nested" / "__pycache__" / f"helper.{cache_tag}.pyc" + ) + stale_nested_pyc.parent.mkdir(parents=True, exist_ok=True) + stale_nested_pyc.write_bytes(b"stale nested bytecode") + + stale_orphan_pyc = fixture_source / "orphan.pyc" + stale_orphan_pyc.write_bytes(b"stale orphan bytecode") + + copied_fixture = tmp_path / "copied_fixture" + shutil.copytree(fixture_source, copied_fixture) + + with _preserve_import_state("main"): + instance = _load_first_instance(copied_fixture) + + assert instance.value == "fresh" + assert not (copied_fixture / "nested" / "__pycache__").exists() + assert not (copied_fixture / "orphan.pyc").exists() + if (copied_fixture / "__pycache__" / f"main.{cache_tag}.pyc").exists(): + assert ( + 
copied_fixture / "__pycache__" / f"main.{cache_tag}.pyc" + ).read_bytes() != b"stale main bytecode" + + +def test_discover_plugins_allows_plugins_without_requirements_file( + tmp_path: Path, +) -> None: + plugins_dir = tmp_path / "plugins" + plugin_dir = plugins_dir / "no_requirements" + _write_plugin( + plugin_dir, + plugin_name="no_requirements", + class_name="NoRequirementsPlugin", + main_source=""" +from astrbot_sdk import Star + + +class NoRequirementsPlugin(Star): + value = "no-deps" +""", + write_requirements=False, + ) + + discovered = discover_plugins(plugins_dir) + + assert [plugin.name for plugin in discovered.plugins] == ["no_requirements"] + assert discovered.skipped_plugins == {} + assert discovered.issues == [] + + with _preserve_import_state("main"): + instance = _load_first_instance(plugin_dir) + + assert instance.value == "no-deps" diff --git a/astrbot-sdk/tests/test_runtime_loader_support.py b/astrbot-sdk/tests/test_runtime_loader_support.py new file mode 100644 index 0000000000..0a29e53243 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_loader_support.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Optional + +import pytest + +from astrbot_sdk import Context, MessageEvent, on_schedule, provide_capability +from astrbot_sdk.runtime._loader_support import ( + build_param_specs, + resolve_capability_candidate, + resolve_handler_candidate, + validate_schedule_signature, +) +from astrbot_sdk.schedule import ScheduleContext +from astrbot_sdk.types import GreedyStr + + +def test_build_param_specs_skips_injected_params_and_preserves_optional_and_greedy() -> ( + None +): + def handler( + event: MessageEvent, + ctx: Context, + count: int, + maybe_name: Optional[str], + enabled: bool = False, + remainder: GreedyStr = "", + ) -> None: + return None + + specs = build_param_specs(handler) + + assert [spec.name for spec in specs] == [ + "count", + "maybe_name", + "enabled", + "remainder", + ] + assert [spec.type for spec in 
specs] == ["int", "optional", "bool", "greedy_str"] + assert specs[1].inner_type == "str" + assert specs[1].required is False + assert specs[2].required is False + assert specs[3].required is False + + +def test_build_param_specs_rejects_non_terminal_greedy_string() -> None: + def handler(remainder: GreedyStr, count: int) -> None: + return None + + with pytest.raises(ValueError, match="GreedyStr"): + build_param_specs(handler) + + +def test_validate_schedule_signature_rejects_non_injected_names() -> None: + def valid(ctx: Context, schedule: ScheduleContext) -> None: + return None + + def invalid(ctx: Context, event: MessageEvent) -> None: + return None + + validate_schedule_signature(valid) + + with pytest.raises(ValueError, match="Schedule handler"): + validate_schedule_signature(invalid) + + +def test_resolve_handler_candidate_finds_schedule_decorated_method() -> None: + class Plugin: + @on_schedule(interval_seconds=60, description="heartbeat") + async def tick(self, ctx: Context) -> None: + return None + + instance = Plugin() + + resolved = resolve_handler_candidate(instance, "tick") + + assert resolved is not None + bound, meta = resolved + assert bound.__name__ == "tick" + assert meta.trigger is not None + assert meta.trigger.type == "schedule" + + +def test_resolve_capability_candidate_finds_capability_decorated_method() -> None: + class Plugin: + @provide_capability( + "plugin.echo", + description="Echo capability", + input_schema={"type": "object"}, + output_schema={"type": "object"}, + ) + async def echo(self, payload: dict) -> dict: + return payload + + instance = Plugin() + + resolved = resolve_capability_candidate(instance, "echo") + + assert resolved is not None + bound, meta = resolved + assert bound.__name__ == "echo" + assert meta.descriptor.name == "plugin.echo" diff --git a/astrbot-sdk/tests/test_runtime_peer.py b/astrbot-sdk/tests/test_runtime_peer.py new file mode 100644 index 0000000000..bcdfa50a7f --- /dev/null +++ 
b/astrbot-sdk/tests/test_runtime_peer.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +import pytest + +from astrbot_sdk.errors import AstrBotError, ErrorCodes +from astrbot_sdk.protocol.messages import ( + EventMessage, + InvokeMessage, + PeerInfo, + ResultMessage, + parse_message, +) +from astrbot_sdk.runtime.peer import Peer +from astrbot_sdk.runtime.transport import Transport + + +class _ControlledTransport(Transport): + def __init__(self) -> None: + super().__init__() + self.sent_payloads: list[str] = [] + self.on_send: Callable[[str], Awaitable[None]] | None = None + + async def start(self) -> None: + self._closed.clear() + + async def stop(self) -> None: + self._closed.set() + + async def send(self, payload: str) -> None: + self.sent_payloads.append(payload) + if self.on_send is not None: + await self.on_send(payload) + + async def push_message(self, message: Any) -> None: + if isinstance(message, str): + payload = message + else: + payload = message.model_dump_json(exclude_none=True) + await self._dispatch(payload) + + def close_unexpected(self) -> None: + self._closed.set() + + +class _FailingSendTransport(_ControlledTransport): + async def send(self, payload: str) -> None: + self.sent_payloads.append(payload) + raise RuntimeError("send failed") + + +def _make_peer(transport: _ControlledTransport, *, name: str = "test-plugin") -> Peer: + return Peer( + transport=transport, + peer_info=PeerInfo(name=name, role="plugin", version="v4"), + ) + + +async def _stop_peer(peer: Peer) -> None: + await peer.stop() + if peer._transport_watch_task is not None: + await peer._transport_watch_task + + +@pytest.mark.asyncio +async def test_initialize_marks_remote_initialized_on_active_side() -> None: + transport = _ControlledTransport() + peer = _make_peer(transport) + + async def respond_to_initialize(payload: str) -> None: + message = parse_message(payload) + assert 
message.type == "initialize" + await transport.push_message( + ResultMessage( + id=message.id, + kind="initialize_result", + success=True, + output={ + "peer": { + "name": "astrbot-core", + "role": "core", + "version": "v4", + }, + "protocol_version": "1.0", + "capabilities": [], + "metadata": {"mode": "test"}, + }, + ) + ) + + transport.on_send = respond_to_initialize + await peer.start() + try: + waiter = asyncio.create_task(peer.wait_until_remote_initialized(timeout=0.2)) + await asyncio.sleep(0) + assert not waiter.done() + + output = await peer.initialize([]) + await waiter + + assert output.peer.name == "astrbot-core" + assert peer.remote_peer is not None + assert peer.remote_peer.name == "astrbot-core" + assert peer.remote_metadata["mode"] == "test" + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_wait_until_remote_initialized_fails_when_transport_closes_pre_init() -> ( + None +): + transport = _ControlledTransport() + peer = _make_peer(transport) + await peer.start() + try: + waiter = asyncio.create_task(peer.wait_until_remote_initialized(timeout=None)) + await asyncio.sleep(0) + + transport.close_unexpected() + + with pytest.raises(AstrBotError, match="连接在初始化完成前关闭") as exc_info: + await asyncio.wait_for(waiter, timeout=0.2) + + assert exc_info.value.code == ErrorCodes.PROTOCOL_ERROR + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_invoke_fails_pending_call_on_unexpected_transport_close() -> None: + transport = _ControlledTransport() + peer = _make_peer(transport) + await peer.start() + try: + invoke_task = asyncio.create_task(peer.invoke("llm.chat", {"prompt": "hello"})) + await asyncio.sleep(0) + + assert len(transport.sent_payloads) == 1 + transport.close_unexpected() + + with pytest.raises(AstrBotError, match="连接已关闭") as exc_info: + await asyncio.wait_for(invoke_task, timeout=0.2) + + assert exc_info.value.code == ErrorCodes.NETWORK_ERROR + finally: + await _stop_peer(peer) + + 
+@pytest.mark.asyncio +async def test_invoke_stream_fails_pending_iterator_on_unexpected_transport_close() -> ( + None +): + transport = _ControlledTransport() + peer = _make_peer(transport) + await peer.start() + try: + iterator = await peer.invoke_stream("llm.stream", {"prompt": "hello"}) + consume_task = asyncio.create_task(anext(iterator)) + await asyncio.sleep(0) + + assert len(transport.sent_payloads) == 1 + transport.close_unexpected() + + with pytest.raises(AstrBotError, match="连接已关闭") as exc_info: + await asyncio.wait_for(consume_task, timeout=0.2) + + assert exc_info.value.code == ErrorCodes.NETWORK_ERROR + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_invoke_stream_hides_completed_event_by_default() -> None: + transport = _ControlledTransport() + peer = _make_peer(transport) + + async def emit_stream(payload: str) -> None: + message = parse_message(payload) + assert message.type == "invoke" + await transport.push_message(EventMessage(id=message.id, phase="started")) + await transport.push_message( + EventMessage(id=message.id, phase="delta", data={"text": "hello"}) + ) + await transport.push_message( + EventMessage(id=message.id, phase="completed", output={"text": "hello"}) + ) + + transport.on_send = emit_stream + await peer.start() + try: + iterator = await peer.invoke_stream("llm.stream", {"prompt": "hello"}) + events = [event async for event in iterator] + + assert [(event.phase, event.data, event.output) for event in events] == [ + ("delta", {"text": "hello"}, {}) + ] + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_invoke_stream_can_include_completed_event() -> None: + transport = _ControlledTransport() + peer = _make_peer(transport) + + async def emit_stream(payload: str) -> None: + message = parse_message(payload) + assert message.type == "invoke" + await transport.push_message(EventMessage(id=message.id, phase="started")) + await transport.push_message( + EventMessage(id=message.id, 
phase="delta", data={"text": "hello"}) + ) + await transport.push_message( + EventMessage(id=message.id, phase="completed", output={"text": "hello"}) + ) + + transport.on_send = emit_stream + await peer.start() + try: + iterator = await peer.invoke_stream( + "llm.stream", + {"prompt": "hello"}, + include_completed=True, + ) + events = [event async for event in iterator] + + assert [(event.phase, event.data, event.output) for event in events] == [ + ("delta", {"text": "hello"}, {}), + ("completed", {}, {"text": "hello"}), + ] + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_invoke_stream_failed_event_becomes_exception() -> None: + transport = _ControlledTransport() + peer = _make_peer(transport) + + async def emit_failed_event(payload: str) -> None: + message = parse_message(payload) + assert message.type == "invoke" + await transport.push_message(EventMessage(id=message.id, phase="started")) + await transport.push_message( + EventMessage( + id=message.id, + phase="failed", + error={ + "code": ErrorCodes.INTERNAL_ERROR, + "message": "boom", + "hint": "", + "retryable": False, + "docs_url": "", + }, + ) + ) + + transport.on_send = emit_failed_event + await peer.start() + try: + iterator = await peer.invoke_stream("llm.stream", {"prompt": "hello"}) + + with pytest.raises(AstrBotError, match="boom") as exc_info: + async for _event in iterator: + pass + + assert exc_info.value.code == ErrorCodes.INTERNAL_ERROR + finally: + await _stop_peer(peer) + + +@pytest.mark.asyncio +async def test_inbound_invoke_send_failure_marks_peer_unusable() -> None: + transport = _FailingSendTransport() + peer = _make_peer(transport) + + async def handle_invoke(_message: Any, _token: Any) -> dict[str, Any]: + return {"ok": True} + + peer.set_invoke_handler(handle_invoke) + await peer.start() + try: + await transport.push_message( + InvokeMessage( + id="msg_0001", + capability="demo.echo", + input={}, + stream=False, + ) + ) + + await 
asyncio.wait_for(peer.wait_closed(), timeout=0.2) + + assert peer._unusable is True + assert len(transport.sent_payloads) == 2 + finally: + await _stop_peer(peer) diff --git a/astrbot-sdk/tests/test_runtime_supervisor_registry_sync.py b/astrbot-sdk/tests/test_runtime_supervisor_registry_sync.py new file mode 100644 index 0000000000..af8f4bb176 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_supervisor_registry_sync.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from astrbot_sdk.errors import AstrBotError, ErrorCodes +from astrbot_sdk.runtime.capability_router import CapabilityRouter +import astrbot_sdk.runtime.supervisor as supervisor_module +from astrbot_sdk.runtime.environment_groups import EnvironmentPlanResult +from astrbot_sdk.runtime.loader import PluginDiscoveryResult, PluginSpec +from astrbot_sdk.runtime.supervisor import SupervisorRuntime, WorkerSession +from astrbot_sdk.runtime.transport import Transport + + +class _DummyTransport(Transport): + async def start(self) -> None: + self._closed.clear() + + async def stop(self) -> None: + self._closed.set() + + async def send(self, payload: str) -> None: + return None + + +class _RecordingPeer: + def __init__(self) -> None: + self.initialize_calls: list[dict[str, object]] = [] + self.started = False + self.stopped = False + + def set_invoke_handler(self, _handler) -> None: + return None + + def set_cancel_handler(self, _handler) -> None: + return None + + async def start(self) -> None: + self.started = True + + async def initialize( + self, + handlers, + *, + provided_capabilities, + metadata, + ) -> None: + self.initialize_calls.append( + { + "handlers": list(handlers), + "provided_capabilities": list(provided_capabilities), + "metadata": dict(metadata), + } + ) + + async def stop(self) -> None: + self.stopped = True + + +class _StaticEnvManager: + def __init__(self, plugins: list[PluginSpec]) -> None: + self._plugins = list(plugins) + + def 
plan(self, _plugins: list[PluginSpec]) -> EnvironmentPlanResult: + return EnvironmentPlanResult(plugins=list(self._plugins)) + + +def _write_plugin_spec(tmp_path: Path, plugin_name: str) -> PluginSpec: + plugin_dir = tmp_path / plugin_name + plugin_dir.mkdir(parents=True, exist_ok=True) + manifest_path = plugin_dir / "plugin.yaml" + manifest_path.write_text( + f""" +_schema_version: 2 +name: {plugin_name} +author: tests +version: 1.0.0 +desc: supervisor registry sync tests + +runtime: + python: "3.12" + +components: + - class: main:TestPlugin +""".strip() + + "\n", + encoding="utf-8", + ) + requirements_path = plugin_dir / "requirements.txt" + requirements_path.write_text("", encoding="utf-8") + (plugin_dir / "main.py").write_text( + "from astrbot_sdk import Star\n\n\nclass TestPlugin(Star):\n pass\n", + encoding="utf-8", + ) + return PluginSpec( + name=plugin_name, + plugin_dir=plugin_dir, + manifest_path=manifest_path, + requirements_path=requirements_path, + python_version="3.12", + manifest_data={ + "name": plugin_name, + "author": "tests", + "version": "1.0.0", + "desc": "supervisor registry sync tests", + "components": [{"class": "main:TestPlugin"}], + "runtime": {"python": "3.12"}, + }, + ) + + +@pytest.mark.asyncio +async def test_supervisor_publishes_plugin_registry_in_two_phases( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + alpha = _write_plugin_spec(tmp_path, "alpha") + beta = _write_plugin_spec(tmp_path, "beta") + plugins = [alpha, beta] + runtime = SupervisorRuntime( + transport=_DummyTransport(), + plugins_dir=tmp_path, + env_manager=_StaticEnvManager(plugins), + ) + peer = _RecordingPeer() + runtime.peer = peer # type: ignore[assignment] + + monkeypatch.setattr( + supervisor_module, + "discover_plugins", + lambda _plugins_dir: PluginDiscoveryResult( + plugins=list(plugins), + skipped_plugins={}, + issues=[], + ), + ) + + phase_snapshots: list[tuple[str, dict[str, bool]]] = [] + + class _FakeWorkerSession: + def __init__( + self, 
+ *, + plugin=None, + group=None, + repo_root, + env_manager, + capability_router, + on_closed=None, + ) -> None: + del group, repo_root, env_manager, capability_router, on_closed + assert plugin is not None + self.plugin = plugin + self.plugins = [plugin] + self.group_id = plugin.name + self.handlers = [] + self.provided_capabilities = [] + self.loaded_plugins: list[str] = [] + self.skipped_plugins: dict[str, str] = {} + self.issues = [] + self.capability_sources: dict[str, str] = {} + + async def start(self) -> None: + phase_snapshots.append( + ( + self.plugin.name, + { + name: bool(entry.metadata.get("enabled", False)) + for name, entry in runtime.capability_router._plugins.items() + }, + ) + ) + if self.plugin.name == "beta": + raise RuntimeError("beta worker failed") + self.loaded_plugins = [self.plugin.name] + + async def stop(self) -> None: + return None + + def start_close_watch(self) -> None: + return None + + def describe(self) -> dict[str, object]: + return { + "group_id": self.group_id, + "plugins": [plugin.name for plugin in self.plugins], + "loaded_plugins": list(self.loaded_plugins), + "skipped_plugins": dict(self.skipped_plugins), + "issues": list(self.issues), + } + + monkeypatch.setattr(supervisor_module, "WorkerSession", _FakeWorkerSession) + + await runtime.start() + + assert phase_snapshots == [ + ("alpha", {"alpha": False, "beta": False}), + ("beta", {"alpha": False, "beta": False}), + ] + assert runtime.loaded_plugins == ["alpha"] + assert runtime.skipped_plugins["beta"] == "beta worker failed" + assert runtime.capability_router._plugins["alpha"].metadata["enabled"] is True + assert runtime.capability_router._plugins["beta"].metadata["enabled"] is False + assert peer.started is True + assert len(peer.initialize_calls) == 1 + assert peer.initialize_calls[0]["metadata"] == { + "plugins": ["alpha"], + "skipped_plugins": {"beta": "beta worker failed"}, + "issues": [ + { + "severity": "error", + "phase": "load", + "plugin_id": "beta", + "message": 
"插件 worker 启动失败", + "details": "beta worker failed", + "hint": "", + } + ], + "aggregated_handler_ids": [], + "worker_groups": [ + { + "group_id": "alpha", + "plugins": ["alpha"], + "loaded_plugins": ["alpha"], + "skipped_plugins": {}, + "issues": [], + } + ], + "worker_group_count": 1, + } + + +@pytest.mark.asyncio +async def test_worker_session_start_surfaces_init_waiter_failure( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + plugin = _write_plugin_spec(tmp_path, "alpha") + session = WorkerSession( + plugin=plugin, + repo_root=tmp_path, + env_manager=_StaticEnvManager([plugin]), + capability_router=CapabilityRouter(), + ) + session._worker_command = lambda: ( + Path("/usr/bin/python3"), + ["/usr/bin/python3", "-m", "astrbot_sdk", "worker"], + str(tmp_path), + ) + + class _StubStdioTransport: + def __init__(self, *, command, cwd, env) -> None: + self.command = command + self.cwd = cwd + self.env = env + + created_peers: list[object] = [] + + class _FailingInitPeer: + def __init__(self, *, transport, peer_info) -> None: + del transport, peer_info + self.remote_handlers = [] + self.remote_provided_capabilities = [] + self.remote_metadata = {} + self.stopped = False + created_peers.append(self) + + def set_initialize_handler(self, _handler) -> None: + return None + + def set_invoke_handler(self, _handler) -> None: + return None + + async def start(self) -> None: + return None + + async def wait_until_remote_initialized( + self, timeout: float | None = None + ) -> None: + del timeout + raise AstrBotError.protocol_error("连接在初始化完成前关闭") + + async def wait_closed(self) -> None: + return None + + async def stop(self) -> None: + self.stopped = True + + monkeypatch.setattr(supervisor_module, "StdioTransport", _StubStdioTransport) + monkeypatch.setattr(supervisor_module, "Peer", _FailingInitPeer) + + with pytest.raises(AstrBotError, match="连接在初始化完成前关闭") as exc_info: + await session.start() + + assert exc_info.value.code == ErrorCodes.PROTOCOL_ERROR + 
assert len(created_peers) == 1 + assert created_peers[0].stopped is True diff --git a/astrbot-sdk/tests/test_runtime_transport.py b/astrbot-sdk/tests/test_runtime_transport.py new file mode 100644 index 0000000000..c6981fbba7 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_transport.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import asyncio +import io +from types import SimpleNamespace + +import pytest + +from astrbot_sdk.runtime import transport as transport_module +from astrbot_sdk.runtime.transport import ( + StdioTransport, + WebSocketServerTransport, + WebSocketClientTransport, + _frame_stdio_payload, +) + + +@pytest.mark.unit +def test_frame_stdio_payload_rejects_embedded_newlines() -> None: + with pytest.raises(ValueError, match="原始换行符"): + _frame_stdio_payload("hello\nworld") + + +@pytest.mark.asyncio +async def test_stdio_read_process_loop_dispatches_messages_and_sets_closed() -> None: + received: list[str] = [] + + class _FakeStdout: + def __init__(self) -> None: + self._items = [b"first\r\n", b"second\n", b""] + + async def readline(self) -> bytes: + return self._items.pop(0) + + transport = StdioTransport(command=["python", "-V"]) + transport._process = SimpleNamespace(stdout=_FakeStdout()) + transport.set_message_handler(lambda payload: _capture(received, payload)) + + await transport._read_process_loop() + + assert received == ["first", "second"] + assert transport._closed.is_set() is True + + +@pytest.mark.asyncio +async def test_stdio_wait_closed_unblocks_after_process_eof() -> None: + class _FakeStdout: + async def readline(self) -> bytes: + return b"" + + transport = StdioTransport(command=["python", "-V"]) + transport._process = SimpleNamespace(stdout=_FakeStdout()) + + waiter = asyncio.create_task(transport.wait_closed()) + await transport._read_process_loop() + await asyncio.wait_for(waiter, timeout=1) + + assert waiter.done() is True + + +@pytest.mark.asyncio +async def 
test_stdio_read_file_loop_dispatches_messages_and_sets_closed() -> None: + received: list[str] = [] + transport = StdioTransport(stdin=io.StringIO("line-1\nline-2\r\n")) + transport.set_message_handler(lambda payload: _capture(received, payload)) + + await transport._read_file_loop() + + assert received == ["line-1", "line-2"] + assert transport._closed.is_set() is True + + +@pytest.mark.asyncio +async def test_stdio_stop_kills_process_when_terminate_times_out( + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[str] = [] + + class _FakeProcess: + returncode = None + stdin = None + + def terminate(self) -> None: + calls.append("terminate") + + def kill(self) -> None: + calls.append("kill") + + async def wait(self) -> None: + calls.append("wait") + + async def fake_wait_for(awaitable, timeout: float): + awaitable.close() + del timeout + raise asyncio.TimeoutError + + transport = StdioTransport(command=["python", "-V"]) + transport._process = _FakeProcess() + monkeypatch.setattr(transport_module.asyncio, "wait_for", fake_wait_for) + + await transport.stop() + + assert calls == ["terminate", "kill", "wait"] + assert transport._process is None + assert transport._closed.is_set() is True + + +@pytest.mark.asyncio +async def test_websocket_client_read_loop_dispatches_text_and_binary_then_closes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + received: list[str] = [] + + class _FakeWebSocket: + closed = False + + def __init__(self) -> None: + self._messages = iter( + [ + SimpleNamespace(type="text", data="hello"), + SimpleNamespace(type="binary", data=b"world"), + ] + ) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._messages) + except StopIteration as exc: + raise StopAsyncIteration from exc + + def exception(self): + return None + + fake_aiohttp = SimpleNamespace( + WSMsgType=SimpleNamespace(TEXT="text", BINARY="binary", ERROR="error") + ) + monkeypatch.setattr(transport_module, "_get_aiohttp", lambda: 
fake_aiohttp) + + transport = WebSocketClientTransport(url="ws://test") + transport._ws = _FakeWebSocket() + transport.set_message_handler(lambda payload: _capture(received, payload)) + + await transport._read_loop() + + assert received == ["hello", "world"] + assert transport._closed.is_set() is True + + +@pytest.mark.asyncio +async def test_websocket_server_send_raises_when_connection_is_gone_after_wait( + monkeypatch: pytest.MonkeyPatch, +) -> None: + transport = WebSocketServerTransport() + transport._connected.set() + transport._ws = SimpleNamespace(closed=True) + + async def fake_wait_for(awaitable, timeout: float): + del timeout + return await awaitable + + monkeypatch.setattr(transport_module.asyncio, "wait_for", fake_wait_for) + + with pytest.raises(RuntimeError, match="尚未连接"): + await transport.send("payload") + + +async def _capture(received: list[str], payload: str) -> None: + received.append(payload) diff --git a/astrbot-sdk/tests/test_runtime_worker.py b/astrbot-sdk/tests/test_runtime_worker.py new file mode 100644 index 0000000000..567a1e87b3 --- /dev/null +++ b/astrbot-sdk/tests/test_runtime_worker.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest +from unittest.mock import AsyncMock + +from astrbot_sdk.context import CancelToken, Context +from astrbot_sdk.errors import AstrBotError, ErrorCodes +from astrbot_sdk.llm.agents import AgentSpec +from astrbot_sdk.llm.entities import LLMToolSpec +from astrbot_sdk.protocol.descriptors import CapabilityDescriptor +from astrbot_sdk.protocol.messages import InvokeMessage +from astrbot_sdk.runtime.loader import ( + LoadedAgent, + LoadedCapability, + LoadedPlugin, + PluginDiscoveryIssue, + PluginSpec, +) +from astrbot_sdk.runtime.worker import ( + GLOBAL_MCP_RISK_ATTR, + GroupPluginRuntimeState, + GroupWorkerRuntime, + PluginWorkerRuntime, +) + + +def _plugin_spec(name: str) -> PluginSpec: + plugin_dir = Path(f"/tmp/{name}") + 
return PluginSpec( + name=name, + plugin_dir=plugin_dir, + manifest_path=plugin_dir / "plugin.yaml", + requirements_path=plugin_dir / "requirements.txt", + python_version="3.12", + manifest_data={"name": name}, + ) + + +@pytest.mark.asyncio +async def test_plugin_worker_handle_invoke_maps_lookup_error_to_astrbot_error() -> None: + runtime = object.__new__(PluginWorkerRuntime) + runtime.dispatcher = SimpleNamespace(invoke=AsyncMock()) + runtime.capability_dispatcher = SimpleNamespace( + invoke=AsyncMock(side_effect=LookupError("missing")), + ) + + with pytest.raises(AstrBotError) as exc_info: + await PluginWorkerRuntime._handle_invoke( + runtime, + InvokeMessage(id="req-cap", capability="missing.capability", input={}), + CancelToken(), + ) + + assert exc_info.value.code == ErrorCodes.CAPABILITY_NOT_FOUND + assert "missing.capability" in exc_info.value.message + + +@pytest.mark.asyncio +async def test_plugin_worker_handle_cancel_fans_out_to_both_dispatchers() -> None: + runtime = object.__new__(PluginWorkerRuntime) + runtime.dispatcher = SimpleNamespace(cancel=AsyncMock()) + runtime.capability_dispatcher = SimpleNamespace(cancel=AsyncMock()) + + await PluginWorkerRuntime._handle_cancel(runtime, "req-123") + + runtime.dispatcher.cancel.assert_awaited_once_with("req-123") + runtime.capability_dispatcher.cancel.assert_awaited_once_with("req-123") + + +@pytest.mark.asyncio +async def test_plugin_worker_start_initializes_metadata_and_handlers() -> None: + runtime = object.__new__(PluginWorkerRuntime) + runtime.plugin = _plugin_spec("alpha") + runtime.loaded_plugin = LoadedPlugin( + plugin=runtime.plugin, + handlers=[], + capabilities=[], + llm_tools=[], + agents=[], + instances=[], + ) + runtime.issues = [] + lifecycle_calls: list[str] = [] + + class _Peer: + def __init__(self) -> None: + self.started = False + self.stopped = False + self.initialize_calls: list[dict[str, object]] = [] + + async def start(self) -> None: + self.started = True + + async def initialize( + 
self, handlers, *, provided_capabilities, metadata + ) -> None: + self.initialize_calls.append( + { + "handlers": list(handlers), + "provided_capabilities": list(provided_capabilities), + "metadata": dict(metadata), + } + ) + + async def stop(self) -> None: + self.stopped = True + + runtime.peer = _Peer() + + async def fake_run_lifecycle(method_name: str) -> None: + lifecycle_calls.append(method_name) + + runtime._run_lifecycle = fake_run_lifecycle # type: ignore[method-assign] + + await PluginWorkerRuntime.start(runtime) + + assert runtime.peer.started is True + assert lifecycle_calls == ["on_start"] + assert runtime.peer.initialize_calls[0]["metadata"]["plugin_id"] == "alpha" + assert runtime.peer.initialize_calls[0]["metadata"]["loaded_plugins"] == ["alpha"] + + +@pytest.mark.asyncio +async def test_plugin_worker_start_runs_on_stop_when_initialize_fails() -> None: + runtime = object.__new__(PluginWorkerRuntime) + runtime.plugin = _plugin_spec("alpha") + runtime.loaded_plugin = LoadedPlugin( + plugin=runtime.plugin, + handlers=[], + capabilities=[], + llm_tools=[], + agents=[], + instances=[], + ) + runtime.issues = [] + lifecycle_calls: list[str] = [] + + class _Peer: + def __init__(self) -> None: + self.stopped = False + + async def start(self) -> None: + return None + + async def initialize( + self, handlers, *, provided_capabilities, metadata + ) -> None: + del handlers, provided_capabilities, metadata + raise RuntimeError("initialize failed") + + async def stop(self) -> None: + self.stopped = True + + runtime.peer = _Peer() + + async def fake_run_lifecycle(method_name: str) -> None: + lifecycle_calls.append(method_name) + + runtime._run_lifecycle = fake_run_lifecycle # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="initialize failed"): + await PluginWorkerRuntime.start(runtime) + + assert lifecycle_calls == ["on_start", "on_stop"] + assert runtime.peer.stopped is True + + +@pytest.mark.asyncio +async def 
test_group_worker_start_raises_when_all_plugins_become_inactive() -> None: + runtime = object.__new__(GroupWorkerRuntime) + alpha = _plugin_spec("alpha") + runtime.group_id = "worker-group" + runtime._plugin_states = [ + GroupPluginRuntimeState( + plugin=alpha, + loaded_plugin=LoadedPlugin(plugin=alpha, handlers=[], instances=[]), + lifecycle_context=Context(peer=SimpleNamespace(), plugin_id="alpha"), + ) + ] + runtime._active_plugin_states = list(runtime._plugin_states) + runtime.skipped_plugins = {} + runtime.issues = [] + refresh_snapshots: list[list[str]] = [] + + class _Peer: + def __init__(self) -> None: + self.started = False + self.stopped = False + + async def start(self) -> None: + self.started = True + + async def initialize( + self, handlers, *, provided_capabilities, metadata + ) -> None: + del handlers, provided_capabilities, metadata + raise AssertionError("initialize should not run without active plugins") + + async def stop(self) -> None: + self.stopped = True + + runtime.peer = _Peer() + + def fake_refresh_dispatchers() -> None: + refresh_snapshots.append( + [state.plugin.name for state in runtime._active_plugin_states] + ) + + async def fake_run_lifecycle(state, method_name: str) -> None: + del state, method_name + raise RuntimeError("on_start failed") + + runtime._refresh_dispatchers = fake_refresh_dispatchers # type: ignore[method-assign] + runtime._run_lifecycle = fake_run_lifecycle # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="has no active plugins"): + await GroupWorkerRuntime.start(runtime) + + assert runtime.peer.started is True + assert runtime.peer.stopped is True + assert runtime.skipped_plugins == {"alpha": "on_start failed"} + assert runtime.issues[0].phase == "lifecycle" + assert refresh_snapshots[-1] == [] + + +def test_group_worker_initialize_metadata_aggregates_runtime_state() -> None: + class _RiskyPlugin: + pass + + setattr(_RiskyPlugin, GLOBAL_MCP_RISK_ATTR, True) + + alpha = _plugin_spec("alpha") + 
beta = _plugin_spec("beta") + alpha_capability = LoadedCapability( + descriptor=CapabilityDescriptor( + name="alpha.echo", + description="echo", + input_schema={"type": "object"}, + output_schema={"type": "object"}, + ), + callable=lambda: None, + owner=object(), + plugin_id="alpha", + ) + alpha_tool = LoadedAgent( + spec=AgentSpec( + name="alpha-agent", + description="agent", + runner_class="alpha.runner:Runner", + ), + runner_class=type("Runner", (), {}), + plugin_id="alpha", + ) + alpha_llm_tool = LoadedPlugin( + plugin=alpha, + handlers=[], + capabilities=[alpha_capability], + llm_tools=[ + SimpleNamespace( + spec=LLMToolSpec.create(name="alpha-tool", description="tool") + ) + ], + agents=[alpha_tool], + instances=[_RiskyPlugin()], + ) + beta_plugin = LoadedPlugin( + plugin=beta, + handlers=[], + capabilities=[], + llm_tools=[], + agents=[], + instances=[object()], + ) + runtime = object.__new__(GroupWorkerRuntime) + runtime.group_id = "worker-group" + runtime.plugins = [alpha, beta] + runtime.skipped_plugins = {"beta": "start failed"} + runtime.issues = [ + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id="beta", + message="插件加载失败", + details="start failed", + ) + ] + runtime._active_plugin_states = [ + GroupPluginRuntimeState( + plugin=alpha, + loaded_plugin=alpha_llm_tool, + lifecycle_context=Context(peer=SimpleNamespace(), plugin_id="alpha"), + ), + GroupPluginRuntimeState( + plugin=beta, + loaded_plugin=beta_plugin, + lifecycle_context=Context(peer=SimpleNamespace(), plugin_id="beta"), + ), + ] + + metadata = GroupWorkerRuntime._initialize_metadata(runtime) + + assert metadata["group_id"] == "worker-group" + assert metadata["plugins"] == ["alpha", "beta"] + assert metadata["loaded_plugins"] == ["alpha", "beta"] + assert metadata["skipped_plugins"] == {"beta": "start failed"} + assert metadata["capability_sources"] == {"alpha.echo": "alpha"} + assert metadata["llm_tools"] == [ + { + "name": "alpha-tool", + "description": "tool", + 
"parameters_schema": {"type": "object", "properties": {}}, + "active": True, + "plugin_id": "alpha", + } + ] + assert metadata["agents"] == [ + { + "name": "alpha-agent", + "description": "agent", + "tool_names": [], + "runner_class": "alpha.runner:Runner", + "plugin_id": "alpha", + } + ] + assert metadata["acknowledge_global_mcp_risk"] is True diff --git a/astrbot-sdk/tests/test_sdk/__init__.py b/astrbot-sdk/tests/test_sdk/__init__.py new file mode 100644 index 0000000000..aca4faad60 --- /dev/null +++ b/astrbot-sdk/tests/test_sdk/__init__.py @@ -0,0 +1 @@ +"""Package marker for shared SDK test helpers.""" diff --git a/astrbot-sdk/tests/test_sdk/unit/__init__.py b/astrbot-sdk/tests/test_sdk/unit/__init__.py new file mode 100644 index 0000000000..c9f3075a05 --- /dev/null +++ b/astrbot-sdk/tests/test_sdk/unit/__init__.py @@ -0,0 +1 @@ +"""Package marker for unit-level shared test helpers.""" diff --git a/astrbot-sdk/tests/test_sdk/unit/_mcp_contract.py b/astrbot-sdk/tests/test_sdk/unit/_mcp_contract.py new file mode 100644 index 0000000000..1ae371146a --- /dev/null +++ b/astrbot-sdk/tests/test_sdk/unit/_mcp_contract.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Protocol + + +class LocalMCPBackendContract(Protocol): + async def get_server(self, name: str): ... + + async def list_servers(self): ... + + async def enable_server(self, name: str): ... + + async def disable_server(self, name: str): ... + + async def wait_until_ready(self, name: str, *, timeout: float): ... + + +async def exercise_local_mcp_contract( + backend: LocalMCPBackendContract, +) -> None: + """Exercise the minimum local MCP behavior expected by SDK tests. + + The caller is expected to provision a local server named ``demo`` before + invoking this helper. Keeping the contract in-repo prevents the SDK test + suite from depending on AstrBot's external test tree. 
+ """ + + server = await backend.get_server("demo") + assert server is not None + assert server.name == "demo" + assert server.scope.value == "local" + assert server.active is True + assert server.running is True + + servers = await backend.list_servers() + assert [item.name for item in servers] == ["demo"] + + disabled = await backend.disable_server("demo") + assert disabled.name == "demo" + assert disabled.scope.value == "local" + assert disabled.active is False + assert disabled.running is False + + disabled_snapshot = await backend.get_server("demo") + assert disabled_snapshot is not None + assert disabled_snapshot.active is False + assert disabled_snapshot.running is False + + enabled = await backend.enable_server("demo") + assert enabled.name == "demo" + assert enabled.scope.value == "local" + assert enabled.active is True + assert enabled.running is True + assert enabled.tools == ["lookup"] + + ready = await backend.wait_until_ready("demo", timeout=0.1) + assert ready.name == "demo" + assert ready.scope.value == "local" + assert ready.active is True + assert ready.running is True + assert ready.tools == ["lookup"] diff --git a/astrbot-sdk/tests/test_sdk_environment_groups.py b/astrbot-sdk/tests/test_sdk_environment_groups.py new file mode 100644 index 0000000000..91095de16e --- /dev/null +++ b/astrbot-sdk/tests/test_sdk_environment_groups.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from astrbot_sdk.runtime.environment_groups import GroupEnvironmentManager + + +@pytest.mark.unit +def test_matches_python_version_accepts_uv_version_info_format(tmp_path: Path) -> None: + venv_path = tmp_path / "venv" + venv_path.mkdir() + (venv_path / "pyvenv.cfg").write_text( + "\n".join( + [ + "home = C:\\Users\\tester\\AppData\\Local\\Programs\\Python\\Python313", + "implementation = CPython", + "uv = 0.9.17", + "version_info = 3.13.12", + "include-system-site-packages = true", + ] + ), + encoding="utf-8", + ) + + 
assert GroupEnvironmentManager._matches_python_version(venv_path, "3.13") is True + assert GroupEnvironmentManager._matches_python_version(venv_path, "3.11") is False diff --git a/astrbot-sdk/tests/test_sdk_peer_errors.py b/astrbot-sdk/tests/test_sdk_peer_errors.py new file mode 100644 index 0000000000..20845b2fc4 --- /dev/null +++ b/astrbot-sdk/tests/test_sdk_peer_errors.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.protocol.messages import ErrorPayload, ResultMessage + +pytestmark = pytest.mark.unit + + +def test_error_payload_accepts_docs_url_and_details() -> None: + payload = ErrorPayload.model_validate( + AstrBotError.invalid_input( + "bad input", + docs_url="https://docs.astrbot.org/sdk/errors#invalid-input", + details={"field": "name"}, + ).to_payload() + ) + + assert payload.docs_url == "https://docs.astrbot.org/sdk/errors#invalid-input" + assert payload.details == {"field": "name"} + + +def test_failed_result_round_trip_preserves_error_metadata() -> None: + error = AstrBotError.internal_error( + "boom", + hint="try again later", + docs_url="https://docs.astrbot.org/sdk/errors#internal-error", + details={"phase": "invoke"}, + ) + message = ResultMessage( + id="req-1", + success=False, + error=ErrorPayload.model_validate(error.to_payload()), + ) + + restored = AstrBotError.from_payload( + message.error.model_dump() if message.error else {} + ) + + assert restored.code == error.code + assert restored.message == error.message + assert restored.hint == error.hint + assert restored.docs_url == error.docs_url + assert restored.details == error.details diff --git a/astrbot-sdk/tests/test_sdk_transport.py b/astrbot-sdk/tests/test_sdk_transport.py new file mode 100644 index 0000000000..76cd2a79d4 --- /dev/null +++ b/astrbot-sdk/tests/test_sdk_transport.py @@ -0,0 +1,44 @@ +# ruff: noqa: SLF001 +from __future__ import annotations + +from types import SimpleNamespace + +import 
pytest +from astrbot_sdk.runtime import transport as transport_module +from astrbot_sdk.runtime.transport import StdioTransport + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_stdio_transport_retries_transient_windows_access_denied( + monkeypatch: pytest.MonkeyPatch, +) -> None: + attempts = 0 + fake_process = SimpleNamespace() + + async def fake_create_subprocess_exec(*args, **kwargs): + nonlocal attempts + attempts += 1 + if attempts == 1: + error = PermissionError(13, "Access is denied") + error.winerror = 5 + raise error + return fake_process + + async def fake_sleep(_delay: float) -> None: + return None + + monkeypatch.setattr( + transport_module.asyncio, + "create_subprocess_exec", + fake_create_subprocess_exec, + ) + monkeypatch.setattr(transport_module.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(transport_module.sys, "platform", "win32") + + transport = StdioTransport(command=["python", "--version"]) + + process = await transport._start_subprocess_with_retry() + + assert process is fake_process + assert attempts == 2 diff --git a/astrbot-sdk/tests/test_session_waiter_usage.py b/astrbot-sdk/tests/test_session_waiter_usage.py new file mode 100644 index 0000000000..bcbe6cfdf3 --- /dev/null +++ b/astrbot-sdk/tests/test_session_waiter_usage.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import asyncio +import importlib + +import pytest + +from astrbot_sdk._internal.testing_support import ( + MockCapabilityRouter, + MockContext, + MockMessageEvent, + MockPeer, +) +from astrbot_sdk._internal.invocation_context import caller_plugin_scope +from astrbot_sdk.context import CancelToken, Context +from astrbot_sdk.events import MessageEvent +from astrbot_sdk.protocol.messages import InvokeMessage +from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher +from astrbot_sdk.session_waiter import ( + SessionController, + SessionWaiterManager, + _mark_session_waiter_handler_task, + _unmark_session_waiter_handler_task, + 
session_waiter, +) + +session_waiter_module = importlib.import_module("astrbot_sdk.session_waiter") + + +def _attach_waiter_manager(ctx: MockContext) -> SessionWaiterManager: + manager = SessionWaiterManager(plugin_id=ctx.plugin_id, peer=ctx.peer) + setattr(ctx.peer, "_session_waiter_manager", manager) + return manager + + +@pytest.mark.asyncio +async def test_session_waiter_register_task_pattern_is_non_blocking( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ctx = MockContext() + manager = _attach_waiter_manager(ctx) + warnings: list[tuple[object, ...]] = [] + received: list[str] = [] + + monkeypatch.setattr( + session_waiter_module.logger, + "warning", + lambda *args: warnings.append(args), + ) + + @session_waiter(timeout=30) + async def waiter( + controller: SessionController, + event: MessageEvent, + ) -> None: + received.append(event.text) + controller.stop() + + initial = MockMessageEvent(text="/bind", session_id="session-1", context=ctx) + progress = ["before"] + with caller_plugin_scope(ctx.plugin_id): + background_task = await ctx.register_task(waiter(initial), "waiter:collect") + progress.append("after") + + assert progress == ["before", "after"] + assert not background_task.done() + + for _ in range(5): + if manager.has_waiter(initial): + break + await asyncio.sleep(0) + + assert manager.has_waiter(initial) + + followup = MockMessageEvent(text="alice", session_id="session-1", context=ctx) + await manager.dispatch(followup) + await background_task + + assert received == ["alice"] + assert not manager.has_waiter(initial) + assert warnings == [] + + +@pytest.mark.asyncio +async def test_session_waiter_warns_on_direct_await_in_handler_task( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ctx = MockContext() + manager = _attach_waiter_manager(ctx) + warnings: list[tuple[object, ...]] = [] + received: list[str] = [] + + monkeypatch.setattr( + session_waiter_module.logger, + "warning", + lambda *args: warnings.append(args), + ) + + 
@session_waiter(timeout=30) + async def waiter( + controller: SessionController, + event: MessageEvent, + ) -> None: + received.append(event.text) + controller.stop() + + initial = MockMessageEvent(text="/bind", session_id="session-2", context=ctx) + + async def direct_wait() -> None: + current_task = asyncio.current_task() + assert current_task is not None + _mark_session_waiter_handler_task(current_task) + try: + await waiter(initial) + finally: + _unmark_session_waiter_handler_task(current_task) + + with caller_plugin_scope(ctx.plugin_id): + wait_task = asyncio.create_task(direct_wait()) + + for _ in range(5): + if manager.has_waiter(initial): + break + await asyncio.sleep(0) + + assert manager.has_waiter(initial) + + followup = MockMessageEvent(text="bob", session_id="session-2", context=ctx) + await manager.dispatch(followup) + await wait_task + + assert received == ["bob"] + assert warnings == [ + ( + "Direct await on session_waiter blocks the current handler dispatch; " + 'prefer `await ctx.register_task(waiter(...), "...")`: ' + "plugin_id={} session_key={}", + "test-plugin", + "session-2", + ) + ] + + +@pytest.mark.asyncio +async def test_session_waiter_warns_on_direct_await_in_redispatched_waiter_task( + monkeypatch: pytest.MonkeyPatch, +) -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + ctx = Context(peer=peer, plugin_id="test-plugin") + warnings: list[tuple[object, ...]] = [] + received: list[tuple[str, str]] = [] + + monkeypatch.setattr( + session_waiter_module.logger, + "warning", + lambda *args: warnings.append(args), + ) + + @session_waiter(timeout=30) + async def second_wait( + controller: SessionController, + event: MessageEvent, + ) -> None: + received.append(("second", event.text)) + controller.stop() + + @session_waiter(timeout=30) + async def first_wait( + controller: SessionController, + event: MessageEvent, + ) -> None: + received.append(("first", 
event.text)) + await second_wait(event) + controller.stop() + + initial_event = MessageEvent.from_payload( + { + "type": "message", + "event_type": "message", + "text": "/bind", + "session_id": "session-redispatch", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + }, + context=ctx, + ) + with caller_plugin_scope(ctx.plugin_id): + waiter_task = await ctx.register_task(first_wait(initial_event), "waiter:first") + + for _ in range(10): + if dispatcher.has_active_waiter(initial_event): + break + await asyncio.sleep(0) + + redispatch_task = asyncio.create_task( + dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-1", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": { + **initial_event.to_payload(), + "text": "alice", + }, + "args": {}, + }, + ), + CancelToken(), + ) + ) + + for _ in range(10): + if any( + args and str(args[0]).startswith("Direct await on session_waiter") + for args in warnings + ): + break + await asyncio.sleep(0) + + assert redispatch_task.done() is False + redispatch_task.cancel() + waiter_task.cancel() + await asyncio.gather(redispatch_task, waiter_task, return_exceptions=True) + + assert received == [("first", "alice")] + assert any( + args + == ( + "Direct await on session_waiter blocks the current handler dispatch; " + 'prefer `await ctx.register_task(waiter(...), "...")`: ' + "plugin_id={} session_key={}", + "test-plugin", + "session-redispatch", + ) + for args in warnings + ) diff --git a/astrbot-sdk/tests/test_star_on_error_fallback.py b/astrbot-sdk/tests/test_star_on_error_fallback.py new file mode 100644 index 0000000000..987fb503ec --- /dev/null +++ b/astrbot-sdk/tests/test_star_on_error_fallback.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest + +sys.path.insert(0, 
str(Path(__file__).resolve().parents[1] / "src")) + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher +from astrbot_sdk.star import Star + + +class _DummyEvent: + def __init__(self) -> None: + self.replies: list[str] = [] + + async def reply(self, message: str) -> None: + self.replies.append(message) + + +@pytest.mark.asyncio +async def test_handle_error_fallback_does_not_instantiate_star( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _fake_default_on_error(error: Exception, event, ctx) -> None: + del ctx + await event.reply(str(error)) + + def _fail_init(self) -> None: + raise AssertionError("Star should not be instantiated for fallback on_error") + + monkeypatch.setattr(Star, "default_on_error", staticmethod(_fake_default_on_error)) + monkeypatch.setattr(Star, "__init__", _fail_init) + + dispatcher = HandlerDispatcher( + plugin_id="plugin", peer=SimpleNamespace(), handlers=[] + ) + event = _DummyEvent() + + await dispatcher._handle_error( + object(), + RuntimeError("boom"), + event, + SimpleNamespace(), + ) + + assert event.replies == ["boom"] + + +@pytest.mark.asyncio +async def test_default_on_error_formats_astrbot_error_reply() -> None: + event = _DummyEvent() + error = AstrBotError.invalid_input( + "bad payload", + hint="check payload", + docs_url="https://example.com/docs", + details={"b": 2, "a": 1}, + ) + + await Star.default_on_error(error, event, SimpleNamespace()) + + assert len(event.replies) == 1 + assert "check payload" in event.replies[0] + assert "https://example.com/docs" in event.replies[0] + assert '"a": 1' in event.replies[0] + assert '"b": 2' in event.replies[0] + + +@pytest.mark.asyncio +async def test_default_on_error_replies_generic_message_for_unknown_errors() -> None: + event = _DummyEvent() + + await Star.default_on_error(RuntimeError("boom"), event, SimpleNamespace()) + + assert len(event.replies) == 1 + assert event.replies[0] + + +@pytest.mark.asyncio +async 
def test_on_error_does_not_dispatch_via_subclass_default_on_error() -> None: + class PluginWithShadowedDefault(Star): + async def default_on_error(self, error: Exception, event, ctx) -> None: + del error, event, ctx + raise AssertionError( + "Star.on_error should not virtual-dispatch default_on_error" + ) + + expected_event = _DummyEvent() + actual_event = _DummyEvent() + + await Star.default_on_error(RuntimeError("boom"), expected_event, SimpleNamespace()) + await PluginWithShadowedDefault().on_error( + RuntimeError("boom"), + actual_event, + SimpleNamespace(), + ) + + assert actual_event.replies == expected_event.replies diff --git a/astrbot-sdk/tests/test_testing_session_waiter.py b/astrbot-sdk/tests/test_testing_session_waiter.py new file mode 100644 index 0000000000..49ebc80d0c --- /dev/null +++ b/astrbot-sdk/tests/test_testing_session_waiter.py @@ -0,0 +1,825 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from astrbot_sdk._internal.invocation_context import caller_plugin_scope +from astrbot_sdk.context import CancelToken, Context +from astrbot_sdk.events import MessageEvent +from astrbot_sdk.protocol.messages import InvokeMessage +from astrbot_sdk.runtime.handler_dispatcher import HandlerDispatcher +from astrbot_sdk.session_waiter import SessionController +from astrbot_sdk.testing import LocalRuntimeConfig, PluginHarness +from astrbot_sdk._internal.testing_support import MockCapabilityRouter, MockPeer + + +def _write_session_waiter_plugin(plugin_dir: Path) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "plugin.yaml").write_text( + "\n".join( + [ + "name: session_waiter_plugin", + "display_name: Session Waiter Plugin", + "desc: test plugin", + "author: tests", + "version: 0.1.0", + "runtime:", + ' python: "3.11"', + "components:", + " - class: main:SessionWaiterPlugin", + "", + ] + ), + encoding="utf-8", + ) + (plugin_dir / 
"main.py").write_text( + "\n".join( + [ + "from astrbot_sdk import Context, MessageEvent, SessionController, Star, on_command, session_waiter", + "", + "", + "class SessionWaiterPlugin(Star):", + ' @on_command("start")', + " async def start(self, event: MessageEvent, ctx: Context) -> None:", + ' await event.reply("ready")', + ' await ctx.register_task(self.wait_for_followup(event), "wait for followup")', + "", + " @session_waiter(timeout=30)", + " async def wait_for_followup(", + " self,", + " controller: SessionController,", + " event: MessageEvent,", + " ) -> None:", + " del controller", + ' await event.reply(f"followup:{event.text}")', + "", + ] + ), + encoding="utf-8", + ) + (plugin_dir / "requirements.txt").write_text("", encoding="utf-8") + + +def _build_event(*, text: str, session_id: str, peer: MockPeer) -> MessageEvent: + return MessageEvent.from_payload( + { + "type": "message", + "event_type": "message", + "text": text, + "session_id": session_id, + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + }, + context=Context(peer=peer, plugin_id="test-plugin"), + ) + + +def test_plugin_harness_waiter_probe_uses_dispatcher_public_api(tmp_path: Path) -> None: + plugin_dir = tmp_path / "session_waiter_plugin" + _write_session_waiter_plugin(plugin_dir) + harness = PluginHarness(LocalRuntimeConfig(plugin_dir=plugin_dir)) + peer = MockPeer(MockCapabilityRouter()) + probe_event = _build_event(text="hello", session_id="session-1", peer=peer) + harness.lifecycle_context = probe_event._context + + calls: list[MessageEvent] = [] + + def has_active_waiter(event: MessageEvent) -> bool: + calls.append(event) + return True + + harness.dispatcher = SimpleNamespace(has_active_waiter=has_active_waiter) + + assert harness._has_waiter_for_event(probe_event.to_payload()) is True + assert len(calls) == 1 + assert calls[0].unified_msg_origin == "session-1" + + +@pytest.mark.asyncio +async def 
test_plugin_harness_dispatches_followup_to_session_waiter( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "session_waiter_plugin" + _write_session_waiter_plugin(plugin_dir) + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + first_records = await harness.dispatch_text("start", session_id="session-1") + await asyncio.sleep(0) + second_records = await harness.dispatch_text("next", session_id="session-1") + + assert [record.text for record in first_records] == ["ready"] + assert [record.text for record in second_records] == ["followup:next"] + + +@pytest.mark.asyncio +async def test_handler_dispatcher_exposes_active_waiter_probe() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event = _build_event(text="hello", session_id="session-1", peer=peer) + + assert dispatcher.has_active_waiter(event) is False + + async def waiter_task() -> None: + with caller_plugin_scope("test-plugin"): + await dispatcher._session_waiters.register( + event=event, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + assert dispatcher.has_active_waiter(event) is True + + await dispatcher._session_waiters.fail( + event.unified_msg_origin, RuntimeError("stop waiter") + ) + with pytest.raises(RuntimeError, match="stop waiter"): + await task + assert dispatcher.has_active_waiter(event) is False + + +@pytest.mark.asyncio +async def test_session_waiter_fail_defaults_to_current_caller_plugin_scope() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="worker-group", peer=peer, handlers=[]) + event_alpha = _build_event(text="hello", session_id="session-1", peer=peer) + event_alpha._context = Context(peer=peer, plugin_id="plugin.alpha") + event_beta = _build_event(text="hello", session_id="session-1", peer=peer) + event_beta._context = Context(peer=peer, 
plugin_id="plugin.beta") + + async def waiter_alpha() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=event_alpha, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + async def waiter_beta() -> None: + with caller_plugin_scope("plugin.beta"): + await dispatcher._session_waiters.register( + event=event_beta, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + task_alpha = asyncio.create_task(waiter_alpha()) + task_beta = asyncio.create_task(waiter_beta()) + await asyncio.sleep(0) + + with caller_plugin_scope("plugin.alpha"): + assert ( + await dispatcher._session_waiters.fail( + "session-1", + RuntimeError("stop alpha"), + ) + is True + ) + + with pytest.raises(RuntimeError, match="stop alpha"): + await task_alpha + assert dispatcher.has_active_waiter(event_beta) is True + + await dispatcher._session_waiters.fail( + "session-1", + RuntimeError("stop beta"), + plugin_id="plugin.beta", + ) + with pytest.raises(RuntimeError, match="stop beta"): + await task_beta + + +@pytest.mark.asyncio +async def test_session_waiter_dispatch_preserves_source_event_payload() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "target": {"conversation_id": "session-1", "platform": "test"}, + "raw": {"event_type": "message"}, + } + event = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="test-plugin"), + ) + seen_payloads: list[dict[str, object]] = [] + + async def capture_waiter( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + source_payload = waiter_event._context._source_event_payload + 
seen_payloads.append(dict(source_payload)) + controller.stop() + + async def waiter_task() -> None: + with caller_plugin_scope("test-plugin"): + await dispatcher._session_waiters.register( + event=event, + handler=capture_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(event_payload), + "args": {}, + }, + ), + CancelToken(), + ) + await task + + assert seen_payloads == [event_payload] + + +@pytest.mark.asyncio +async def test_session_waiter_dispatch_serializes_followups_per_waiter() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event = _build_event(text="hello", session_id="session-serial", peer=peer) + handler_entered = asyncio.Event() + release_handler = asyncio.Event() + invocations: list[str] = [] + + async def slow_waiter( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + invocations.append(waiter_event.text) + handler_entered.set() + await release_handler.wait() + controller.stop() + + async def waiter_task() -> None: + with caller_plugin_scope("test-plugin"): + await dispatcher._session_waiters.register( + event=event, + handler=slow_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + first_followup = _build_event(text="first", session_id="session-serial", peer=peer) + second_followup = _build_event( + text="second", + session_id="session-serial", + peer=peer, + ) + + first_dispatch = asyncio.create_task( + dispatcher._session_waiters.dispatch(first_followup) + ) + await handler_entered.wait() + + second_dispatch = asyncio.create_task( + dispatcher._session_waiters.dispatch(second_followup) + ) + await 
asyncio.sleep(0) + + assert invocations == ["first"] + assert second_dispatch.done() is False + + release_handler.set() + + await first_dispatch + await second_dispatch + await task + + assert invocations == ["first"] + + +@pytest.mark.asyncio +async def test_has_active_waiter_ignores_completed_waiter_before_unregister() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher(plugin_id="test-plugin", peer=peer, handlers=[]) + event = _build_event(text="hello", session_id="session-1", peer=peer) + release_unregister = asyncio.Event() + manager = dispatcher._session_waiters + original_unregister = manager.unregister + + async def delayed_unregister( + session_key: str, + *, + plugin_id: str | None = None, + ) -> None: + await release_unregister.wait() + await original_unregister(session_key, plugin_id=plugin_id) + + manager.unregister = delayed_unregister # type: ignore[method-assign] + + async def waiter_task() -> None: + with caller_plugin_scope("test-plugin"): + await manager.register( + event=event, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + assert dispatcher.has_active_waiter(event) is True + + await manager.fail(event.unified_msg_origin, RuntimeError("stop waiter")) + await asyncio.sleep(0) + + assert dispatcher.has_active_waiter(event) is False + + release_unregister.set() + with pytest.raises(RuntimeError, match="stop waiter"): + await task + + +@pytest.mark.asyncio +async def test_session_waiter_dispatch_uses_registering_plugin_id() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher( + plugin_id="worker-group", + peer=peer, + handlers=[], + ) + event_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + } 
+ register_event = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ) + seen_plugin_ids: list[str] = [] + + async def capture_waiter( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def waiter_task() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=register_event, + handler=capture_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-plugin-id", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(event_payload), + "args": {}, + }, + ), + CancelToken(), + ) + await task + + assert seen_plugin_ids == ["plugin.alpha"] + + +@pytest.mark.asyncio +async def test_session_waiter_dispatch_resolves_session_from_target_payload() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher( + plugin_id="worker-group", + peer=peer, + handlers=[], + ) + register_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + } + target_only_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "user_id": "tester", + "platform_id": "test", + "message_type": "private", + "target": {"conversation_id": "session-1", "platform": "test"}, + "raw": {"event_type": "message"}, + } + register_event = MessageEvent.from_payload( + register_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ) + seen_plugin_ids: list[str] = [] + + async def capture_waiter( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + 
seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def waiter_task() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=register_event, + handler=capture_waiter, + timeout=30, + record_history_chains=False, + ) + + task = asyncio.create_task(waiter_task()) + await asyncio.sleep(0) + + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-target-only", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(target_only_payload), + "args": {}, + }, + ), + CancelToken(), + ) + await task + + assert seen_plugin_ids == ["plugin.alpha"] + + +@pytest.mark.asyncio +async def test_session_waiters_do_not_replace_across_plugins_same_session() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher( + plugin_id="worker-group", + peer=peer, + handlers=[], + ) + event_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + } + event_a = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ) + event_b = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.beta"), + ) + seen_plugin_ids: list[str] = [] + + async def waiter_alpha( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def waiter_beta( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def task_alpha() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=event_a, + handler=waiter_alpha, + timeout=30, 
+ record_history_chains=False, + ) + + async def task_beta() -> None: + with caller_plugin_scope("plugin.beta"): + await dispatcher._session_waiters.register( + event=event_b, + handler=waiter_beta, + timeout=30, + record_history_chains=False, + ) + + waiter_task_alpha = asyncio.create_task(task_alpha()) + waiter_task_beta = asyncio.create_task(task_beta()) + await asyncio.sleep(0) + + assert sorted(dispatcher._session_waiters.get_waiter_plugin_ids("session-1")) == [ + "plugin.alpha", + "plugin.beta", + ] + + await dispatcher._session_waiters.dispatch( + MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ), + plugin_id="plugin.alpha", + ) + await dispatcher._session_waiters.dispatch( + MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.beta"), + ), + plugin_id="plugin.beta", + ) + await waiter_task_alpha + await waiter_task_beta + + assert sorted(seen_plugin_ids) == ["plugin.alpha", "plugin.beta"] + + +@pytest.mark.asyncio +async def test_session_waiter_dispatch_accepts_explicit_plugin_id_for_ambiguous_session() -> ( + None +): + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher( + plugin_id="worker-group", + peer=peer, + handlers=[], + ) + event_payload = { + "type": "message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + } + event_a = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ) + event_b = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.beta"), + ) + seen_plugin_ids: list[str] = [] + + async def waiter_alpha( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def 
waiter_beta( + controller: SessionController, + waiter_event: MessageEvent, + ) -> None: + seen_plugin_ids.append(waiter_event._context.plugin_id) + controller.stop() + + async def task_alpha() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=event_a, + handler=waiter_alpha, + timeout=30, + record_history_chains=False, + ) + + async def task_beta() -> None: + with caller_plugin_scope("plugin.beta"): + await dispatcher._session_waiters.register( + event=event_b, + handler=waiter_beta, + timeout=30, + record_history_chains=False, + ) + + waiter_task_alpha = asyncio.create_task(task_alpha()) + waiter_task_beta = asyncio.create_task(task_beta()) + await asyncio.sleep(0) + + with pytest.raises(LookupError, match="explicit plugin identity"): + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-ambiguous", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "event": dict(event_payload), + "args": {}, + }, + ), + CancelToken(), + ) + + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-explicit-alpha", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "plugin_id": "plugin.alpha", + "event": dict(event_payload), + "args": {}, + }, + ), + CancelToken(), + ) + await dispatcher.invoke( + InvokeMessage( + id="req-session-waiter-explicit-beta", + capability="handler.invoke", + input={ + "handler_id": "__sdk_session_waiter__", + "plugin_id": "plugin.beta", + "event": dict(event_payload), + "args": {}, + }, + ), + CancelToken(), + ) + + await waiter_task_alpha + await waiter_task_beta + + assert sorted(seen_plugin_ids) == ["plugin.alpha", "plugin.beta"] + + +@pytest.mark.asyncio +async def test_fail_without_plugin_id_does_not_broadcast_across_plugins() -> None: + peer = MockPeer(MockCapabilityRouter()) + dispatcher = HandlerDispatcher( + plugin_id="worker-group", + peer=peer, + handlers=[], + ) + event_payload = { + "type": 
"message", + "event_type": "message", + "text": "followup", + "session_id": "session-1", + "user_id": "tester", + "platform": "test", + "platform_id": "test", + "message_type": "private", + "raw": {"event_type": "message"}, + } + event_a = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.alpha"), + ) + event_b = MessageEvent.from_payload( + event_payload, + context=Context(peer=peer, plugin_id="plugin.beta"), + ) + + async def waiter_alpha() -> None: + with caller_plugin_scope("plugin.alpha"): + await dispatcher._session_waiters.register( + event=event_a, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + async def waiter_beta() -> None: + with caller_plugin_scope("plugin.beta"): + await dispatcher._session_waiters.register( + event=event_b, + handler=_noop_waiter, + timeout=30, + record_history_chains=False, + ) + + task_a = asyncio.create_task(waiter_alpha()) + task_b = asyncio.create_task(waiter_beta()) + await asyncio.sleep(0) + + assert ( + await dispatcher._session_waiters.fail( + "session-1", + RuntimeError("stop waiter"), + ) + is False + ) + assert dispatcher.has_active_waiter(event_a) is True + assert dispatcher.has_active_waiter(event_b) is True + + await dispatcher._session_waiters.fail( + "session-1", + RuntimeError("stop alpha"), + plugin_id="plugin.alpha", + ) + with pytest.raises(RuntimeError, match="stop alpha"): + await task_a + + assert dispatcher.has_active_waiter(event_b) is True + await dispatcher._session_waiters.fail( + "session-1", + RuntimeError("stop beta"), + plugin_id="plugin.beta", + ) + with pytest.raises(RuntimeError, match="stop beta"): + await task_b + + +@pytest.mark.asyncio +async def test_plugin_harness_waits_for_waiter_side_effects_after_stop( + tmp_path: Path, +) -> None: + plugin_dir = tmp_path / "session_waiter_stop_after_side_effects" + _write_session_waiter_plugin(plugin_dir) + (plugin_dir / "main.py").write_text( + "\n".join( + [ + "import asyncio", + 
"from astrbot_sdk import Context, MessageEvent, SessionController, Star, on_command, session_waiter", + "", + "", + "class SessionWaiterPlugin(Star):", + ' @on_command("start")', + " async def start(self, event: MessageEvent, ctx: Context) -> None:", + ' await event.reply("ready")', + ' await ctx.register_task(self.wait_for_followup(event), "wait for followup")', + "", + " @session_waiter(timeout=30)", + " async def wait_for_followup(", + " self,", + " controller: SessionController,", + " event: MessageEvent,", + " ) -> None:", + " controller.stop()", + " await asyncio.sleep(0)", + ' await event.reply(f"followup:{event.text}")', + "", + ] + ), + encoding="utf-8", + ) + + async with PluginHarness.from_plugin_dir(plugin_dir) as harness: + first_records = await harness.dispatch_text("start", session_id="session-1") + second_records = await harness.dispatch_text("next", session_id="session-1") + + assert [record.text for record in first_records] == ["ready"] + assert [record.text for record in second_records] == ["followup:next"] + + +async def _noop_waiter( + controller: SessionController, + waiter_event: MessageEvent, +) -> None: + del waiter_event + controller.stop() diff --git a/astrbot/__init__.py b/astrbot/__init__.py index 73d64f303f..f7604c5b15 100644 --- a/astrbot/__init__.py +++ b/astrbot/__init__.py @@ -1,3 +1,16 @@ -from .core.log import LogManager +from __future__ import annotations -logger = LogManager.GetLogger(log_name="astrbot") +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .core import logger as logger + +__all__ = ["logger"] + + +def __getattr__(name: str) -> Any: + if name == "logger": + from .core import logger + + return logger + raise AttributeError(name) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 51690ede27..a11435a84b 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,47 +1,185 @@ +from __future__ import annotations + import os +from importlib import import_module +from 
typing import TYPE_CHECKING, Any -from astrbot.core.config import AstrBotConfig -from astrbot.core.config.default import DB_PATH -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.file_token_service import FileTokenService -from astrbot.core.utils.pip_installer import ( - DependencyConflictError as DependencyConflictError, -) -from astrbot.core.utils.pip_installer import ( - PipInstaller, -) -from astrbot.core.utils.requirements_utils import ( - RequirementsPrecheckFailed as RequirementsPrecheckFailed, -) -from astrbot.core.utils.requirements_utils import ( - find_missing_requirements as find_missing_requirements, -) -from astrbot.core.utils.requirements_utils import ( - find_missing_requirements_or_raise as find_missing_requirements_or_raise, -) -from astrbot.core.utils.shared_preferences import SharedPreferences -from astrbot.core.utils.t2i.renderer import HtmlRenderer - -from .log import LogBroker, LogManager # noqa from .utils.astrbot_path import get_astrbot_data_path -# 初始化数据存储文件夹 +if TYPE_CHECKING: + from .config import AstrBotConfig + from .db.sqlite import SQLiteDatabase + from .file_token_service import FileTokenService + from .log import LogBroker, LogManager + from .utils.pip_installer import DependencyConflictError, PipInstaller + from .utils.requirements_utils import ( + RequirementsPrecheckFailed, + find_missing_requirements, + find_missing_requirements_or_raise, + ) +else: + AstrBotConfig: Any + SQLiteDatabase: Any + FileTokenService: Any + LogBroker: Any + LogManager: Any + DependencyConflictError: Any + PipInstaller: Any + RequirementsPrecheckFailed: Any + find_missing_requirements: Any + find_missing_requirements_or_raise: Any + astrbot_config: Any + db_helper: Any + file_token_service: Any + html_renderer: Any + logger: Any + pip_installer: Any + sp: Any + os.makedirs(get_astrbot_data_path(), exist_ok=True) DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t") -astrbot_config = AstrBotConfig() 
-t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") -html_renderer = HtmlRenderer(t2i_base_url) -logger = LogManager.GetLogger(log_name="astrbot") -LogManager.configure_logger(logger, astrbot_config) -LogManager.configure_trace_logger(astrbot_config) -db_helper = SQLiteDatabase(DB_PATH) -# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 -sp = SharedPreferences(db_helper=db_helper) -# 文件令牌服务 -file_token_service = FileTokenService() -pip_installer = PipInstaller( - astrbot_config.get("pip_install_arg", ""), - astrbot_config.get("pypi_index_url", None), -) +__all__ = [ + "AstrBotConfig", + "DEMO_MODE", + "DependencyConflictError", + "FileTokenService", + "LogBroker", + "LogManager", + "PipInstaller", + "RequirementsPrecheckFailed", + "SQLiteDatabase", + "astrbot_config", + "db_helper", + "file_token_service", + "find_missing_requirements", + "find_missing_requirements_or_raise", + "html_renderer", + "logger", + "pip_installer", + "sp", +] + +_SINGLETON_CACHE: dict[str, Any] = {} + + +def _get_astrbot_config(): + config_module = import_module(".config", __name__) + cached = _SINGLETON_CACHE.get("astrbot_config") + if cached is None: + cached = config_module.AstrBotConfig() + _SINGLETON_CACHE["astrbot_config"] = cached + return cached + + +def _get_log_manager(): + return import_module(".log", __name__).LogManager + + +def _get_logger(): + cached = _SINGLETON_CACHE.get("logger") + if cached is None: + logger_obj = _get_log_manager().GetLogger(log_name="astrbot") + config = _get_astrbot_config() + log_manager = _get_log_manager() + log_manager.configure_logger(logger_obj, config) + log_manager.configure_trace_logger(config) + _SINGLETON_CACHE["logger"] = logger_obj + cached = logger_obj + return cached + + +def _get_db_helper(): + cached = _SINGLETON_CACHE.get("db_helper") + if cached is None: + sqlite_module = import_module(".db.sqlite", __name__) + default_module = import_module(".config.default", __name__) + cached = 
sqlite_module.SQLiteDatabase(default_module.DB_PATH) + _SINGLETON_CACHE["db_helper"] = cached + return cached + + +def _get_shared_preferences(): + cached = _SINGLETON_CACHE.get("sp") + if cached is None: + shared_preferences_module = import_module(".utils.shared_preferences", __name__) + cached = shared_preferences_module.SharedPreferences(db_helper=_get_db_helper()) + _SINGLETON_CACHE["sp"] = cached + return cached + + +def _get_file_token_service(): + cached = _SINGLETON_CACHE.get("file_token_service") + if cached is None: + service_module = import_module(".file_token_service", __name__) + cached = service_module.FileTokenService() + _SINGLETON_CACHE["file_token_service"] = cached + return cached + + +def _get_html_renderer(): + cached = _SINGLETON_CACHE.get("html_renderer") + if cached is None: + renderer_module = import_module(".utils.t2i.renderer", __name__) + config = _get_astrbot_config() + endpoint = config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") + cached = renderer_module.HtmlRenderer(endpoint) + _SINGLETON_CACHE["html_renderer"] = cached + return cached + + +def _get_pip_installer(): + cached = _SINGLETON_CACHE.get("pip_installer") + if cached is None: + installer_module = import_module(".utils.pip_installer", __name__) + config = _get_astrbot_config() + cached = installer_module.PipInstaller( + config.get("pip_install_arg", ""), + config.get("pypi_index_url", None), + ) + _SINGLETON_CACHE["pip_installer"] = cached + return cached + + +def __getattr__(name: str) -> Any: + if name == "AstrBotConfig": + return import_module(".config", __name__).AstrBotConfig + if name in {"LogBroker", "LogManager"}: + module = import_module(".log", __name__) + return getattr(module, name) + if name == "DependencyConflictError": + return import_module(".utils.pip_installer", __name__).DependencyConflictError + if name == "FileTokenService": + return import_module(".file_token_service", __name__).FileTokenService + if name == "PipInstaller": + return 
import_module(".utils.pip_installer", __name__).PipInstaller + if name == "RequirementsPrecheckFailed": + return import_module( + ".utils.requirements_utils", __name__ + ).RequirementsPrecheckFailed + if name == "SQLiteDatabase": + return import_module(".db.sqlite", __name__).SQLiteDatabase + if name == "find_missing_requirements": + return import_module( + ".utils.requirements_utils", __name__ + ).find_missing_requirements + if name == "find_missing_requirements_or_raise": + return import_module( + ".utils.requirements_utils", __name__ + ).find_missing_requirements_or_raise + if name == "astrbot_config": + return _get_astrbot_config() + if name == "logger": + return _get_logger() + if name == "db_helper": + return _get_db_helper() + if name == "sp": + return _get_shared_preferences() + if name == "file_token_service": + return _get_file_token_service() + if name == "html_renderer": + return _get_html_renderer() + if name == "pip_installer": + return _get_pip_installer() + raise AttributeError(name) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index af969a3fac..aceb2261ba 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -137,6 +137,7 @@ def __init__(self) -> None: self.tools: list[mcp.Tool] = [] self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() + self.process_pid: int | None = None # Store connection config for reconnection self._mcp_server_config: dict | None = None @@ -144,6 +145,24 @@ def __init__(self) -> None: self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging + @staticmethod + def _extract_stdio_process_pid(streams_context: object) -> int | None: + """Best-effort extraction for stdio subprocess PID used by lease cleanup. + + TODO(refactor): replace this async-generator frame introspection with a + stable MCP library hook once the upstream transport exposes process PID. 
+ """ + generator = getattr(streams_context, "gen", None) + frame = getattr(generator, "ag_frame", None) + if frame is None: + return None + process = frame.f_locals.get("process") + pid = getattr(process, "pid", None) + try: + return int(pid) if pid is not None else None + except (TypeError, ValueError): + return None + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server @@ -159,6 +178,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: # Store config for reconnection self._mcp_server_config = mcp_server_config self._server_name = name + self.process_pid = None cfg = _prepare_config(mcp_server_config.copy()) @@ -261,6 +281,7 @@ def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None: ), # type: ignore ), ) + self.process_pid = self._extract_stdio_process_pid(self._streams_context) # Create a new client session self.session = await self.exit_stack.enter_async_context( @@ -390,6 +411,7 @@ async def cleanup(self) -> None: # Set running_event first to unblock any waiting tasks self.running_event.set() + self.process_pid = None class MCPTool(FunctionTool, Generic[TContext]): diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index 09bf32deb4..86f76a89db 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -11,7 +11,42 @@ from astrbot.core.star.star_handler import EventType +def _sdk_safe_payload(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [_sdk_safe_payload(item) for item in value] + if isinstance(value, dict): + return {str(key): _sdk_safe_payload(item) for key, item in value.items()} + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + dumped = model_dump() + except Exception: + return str(value) + return _sdk_safe_payload(dumped) + return str(value) + + class 
MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): + async def on_agent_begin( + self, + run_context: ContextWrapper[AstrAgentContext], + ) -> None: + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "agent_begin", + run_context.context.event, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK agent_begin dispatch failed: %s", exc) + async def on_agent_done(self, run_context, llm_response) -> None: # 执行事件钩子 if llm_response and llm_response.reasoning_content: @@ -25,6 +60,30 @@ async def on_agent_done(self, run_context, llm_response) -> None: EventType.OnLLMResponseEvent, llm_response, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "agent_done", + run_context.context.event, + { + "completion_text": ( + llm_response.completion_text if llm_response else "" + ), + "tool_call_names": ( + list(llm_response.tools_call_name) + if llm_response and llm_response.tools_call_name + else [] + ), + }, + llm_response=llm_response, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK agent_done dispatch failed: %s", exc) async def on_tool_start( self, @@ -38,6 +97,23 @@ async def on_tool_start( tool, tool_args, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_tool_start", + run_context.context.event, + { + "tool_name": tool.name, + "tool_args": _sdk_safe_payload(tool_args), + }, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK llm_tool_start dispatch failed: %s", exc) async def on_tool_end( self, @@ -54,6 +130,24 @@ async def on_tool_end( tool_args, 
tool_result, ) + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_tool_end", + run_context.context.event, + { + "tool_name": tool.name, + "tool_args": _sdk_safe_payload(tool_args), + "tool_result": _sdk_safe_payload(tool_result), + }, + ) + except Exception as exc: + from astrbot.core import logger + + logger.warning("SDK llm_tool_end dispatch failed: %s", exc) # special handle web_search_tavily platform_name = run_context.context.event.get_platform_name() diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 1fb4b03368..a3154a38af 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -586,6 +586,24 @@ async def _execute_local( if awaitable is None: raise ValueError("Tool must have a valid handler or override 'run' method.") + sdk_plugin_bridge = getattr( + run_context.context.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "calling_func_tool", + event, + { + "tool_name": tool.name, + "tool_args": json.loads( + json.dumps(tool_args, ensure_ascii=False, default=str) + ), + }, + ) + except Exception as exc: + logger.warning("SDK calling_func_tool dispatch failed: %s", exc) + wrapper = call_local_llm_tool( context=run_context, handler=awaitable, diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 715f938679..579d80a97c 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -20,17 +20,6 @@ _MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json" -def _list_local_skill_dirs(skills_root: Path) -> list[Path]: - skills: list[Path] = [] - for entry in sorted(skills_root.iterdir()): - if not entry.is_dir(): - continue - skill_md = entry / "SKILL.md" - if skill_md.exists(): - 
skills.append(entry) - return skills - - def _discover_bay_credentials(endpoint: str) -> str: """Try to auto-discover Bay API key from credentials.json. @@ -383,20 +372,25 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: splitting into `apply` and `scan` phases. """ skills_root = Path(get_astrbot_skills_path()) - if not skills_root.is_dir(): - return - local_skill_dirs = _list_local_skill_dirs(skills_root) + skill_manager: SkillManager | None = None + local_skill_sources = [] + if skills_root.exists(): + skill_manager = SkillManager(skills_root=str(skills_root)) + local_skill_sources = skill_manager.list_local_skill_sources() temp_dir = Path(get_astrbot_temp_path()) temp_dir.mkdir(parents=True, exist_ok=True) zip_base = temp_dir / "skills_bundle" zip_path = zip_base.with_suffix(".zip") + bundle_dir = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" try: - if local_skill_dirs: + if local_skill_sources: + assert skill_manager is not None if zip_path.exists(): zip_path.unlink() - shutil.make_archive(str(zip_base), "zip", str(skills_root)) + skill_manager.materialize_local_skill_bundle(bundle_dir) + shutil.make_archive(str(zip_base), "zip", root_dir=str(bundle_dir)) remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip" logger.info("Uploading skills bundle to sandbox...") await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") @@ -420,6 +414,8 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: len(managed), ) finally: + if bundle_dir.exists(): + shutil.rmtree(bundle_dir, ignore_errors=True) if zip_path.exists(): try: zip_path.unlink() diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 77c298cac8..6a38311c67 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -2,6 +2,7 @@ import json import logging import os +from pathlib import Path from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -46,6 +47,7 @@ def __init__( if not 
self.check_exist(): """不存在时载入默认配置""" + Path(config_path).parent.mkdir(parents=True, exist_ok=True) with open(config_path, "w", encoding="utf-8-sig") as f: json.dump(default_config, f, indent=4, ensure_ascii=False) object.__setattr__(self, "first_deploy", True) # 标记第一次部署 @@ -158,6 +160,8 @@ def save_config(self, replace_config: dict | None = None) -> None: """ if replace_config: self.update(replace_config) + # Alternate config files may be created under data/config on first write. + Path(self.config_path).parent.mkdir(parents=True, exist_ok=True) with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2c282867f9..48e44dcd8c 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -262,6 +262,7 @@ async def update_conversation( history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, + clear_persona: bool = False, token_usage: int | None = None, ) -> None: """更新会话的对话. 
@@ -281,6 +282,7 @@ async def update_conversation( cid=conversation_id, title=title, persona_id=persona_id, + clear_persona=clear_persona, content=history, token_usage=token_usage, ) @@ -329,6 +331,19 @@ async def update_conversation_persona_id( persona_id=persona_id, ) + async def unset_conversation_persona( + self, + unified_msg_origin: str, + conversation_id: str | None = None, + ) -> None: + """Clear the conversation-specific persona override and fall back to default.""" + + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + clear_persona=True, + ) + async def add_message_pair( self, cid: str, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351d..fc6a95e29e 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -16,8 +16,7 @@ import traceback from asyncio import Queue -from astrbot.api import logger, sp -from astrbot.core import LogBroker, LogManager +from astrbot.core import LogBroker, LogManager, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager @@ -29,6 +28,7 @@ from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.provider.manager import ProviderManager +from astrbot.core.sdk_bridge import SdkPluginBridge from astrbot.core.star.context import Context from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.star.star_manager import PluginManager @@ -200,6 +200,11 @@ async def initialize(self) -> None: # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() + self.sdk_plugin_bridge = SdkPluginBridge(self.star_context) + self.star_context.sdk_plugin_bridge = self.sdk_plugin_bridge + self.platform_manager.sdk_plugin_bridge = self.sdk_plugin_bridge + 
await self.sdk_plugin_bridge.start() + # 根据配置实例化各个 Provider await self.provider_manager.initialize() @@ -309,6 +314,12 @@ async def start(self) -> None: except BaseException: logger.error(traceback.format_exc()) + if getattr(self, "sdk_plugin_bridge", None) is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event("astrbot_loaded") + except Exception as exc: + logger.warning(f"SDK astrbot_loaded event dispatch failed: {exc}") + # 同时运行curr_tasks中的所有任务 await asyncio.gather(*self.curr_tasks, return_exceptions=True) @@ -324,6 +335,9 @@ async def stop(self) -> None: if self.cron_manager: await self.cron_manager.shutdown() + if getattr(self, "sdk_plugin_bridge", None) is not None: + await self.sdk_plugin_bridge.stop() + for plugin in self.plugin_manager.context.get_all_stars(): try: await self.plugin_manager._terminate_plugin(plugin) @@ -349,6 +363,8 @@ async def stop(self) -> None: async def restart(self) -> None: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" + if getattr(self, "sdk_plugin_bridge", None) is not None: + await self.sdk_plugin_bridge.stop() await self.provider_manager.terminate() await self.platform_manager.terminate() await self.kb_manager.terminate() diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index ff7facd247..24c8ab3872 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -8,6 +8,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger +from apscheduler.triggers.interval import IntervalTrigger from astrbot import logger from astrbot.core.agent.tool import ToolSet @@ -65,7 +66,8 @@ async def add_basic_job( self, *, name: str, - cron_expression: str, + cron_expression: str | None = None, + interval_seconds: int | None = None, handler: Callable[..., Any | Awaitable[Any]], description: str | None = None, timezone: str | None = None, @@ -73,12 +75,19 @@ async def add_basic_job( 
enabled: bool = True, persistent: bool = False, ) -> CronJob: + if (cron_expression is None) == (interval_seconds is None): + raise ValueError( + "cron_expression and interval_seconds must have exactly one value" + ) + payload_data = dict(payload or {}) + if interval_seconds is not None: + payload_data["interval_seconds"] = interval_seconds job = await self.db.create_cron_job( name=name, job_type="basic", cron_expression=cron_expression, timezone=timezone, - payload=payload or {}, + payload=payload_data, description=description, enabled=enabled, persistent=persistent, @@ -167,7 +176,21 @@ def _schedule_job(self, job: CronJob) -> None: run_at = run_at.replace(tzinfo=tzinfo) trigger = DateTrigger(run_date=run_at, timezone=tzinfo) else: - trigger = CronTrigger.from_crontab(job.cron_expression, timezone=tzinfo) + interval_seconds = None + if isinstance(job.payload, dict): + payload_interval = job.payload.get("interval_seconds") + if isinstance(payload_interval, int): + interval_seconds = payload_interval + if interval_seconds is not None: + trigger = IntervalTrigger( + seconds=interval_seconds, + timezone=tzinfo, + ) + else: + trigger = CronTrigger.from_crontab( + job.cron_expression, + timezone=tzinfo, + ) self.scheduler.add_job( self._run_job, id=job.job_id, diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index a18c127ebf..fbded9f212 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -164,6 +164,7 @@ async def update_conversation( cid: str, title: str | None = None, persona_id: str | None = None, + clear_persona: bool = False, content: list[dict] | None = None, token_usage: int | None = None, ) -> None: @@ -213,6 +214,57 @@ async def get_platform_message_history( """Get platform message history for a specific user.""" ... 
+ @abc.abstractmethod + async def list_sdk_platform_message_history( + self, + platform_id: str, + user_id: str, + cursor_id: int | None = None, + limit: int = 50, + include_total: bool = False, + ) -> tuple[list[PlatformMessageHistory], int | None]: + """List SDK message history records ordered by descending id.""" + ... + + @abc.abstractmethod + async def delete_platform_message_before( + self, + platform_id: str, + user_id: str, + before: datetime.datetime, + ) -> int: + """Delete platform message history records strictly older than ``before``.""" + ... + + @abc.abstractmethod + async def delete_platform_message_after( + self, + platform_id: str, + user_id: str, + after: datetime.datetime, + ) -> int: + """Delete platform message history records strictly newer than ``after``.""" + ... + + @abc.abstractmethod + async def delete_all_platform_message_history( + self, + platform_id: str, + user_id: str, + ) -> int: + """Delete all platform message history records for a specific user.""" + ... + + @abc.abstractmethod + async def find_platform_message_history_by_idempotency_key( + self, + platform_id: str, + user_id: str, + idempotency_key: str, + ) -> PlatformMessageHistory | None: + """Find one message history record by the SDK idempotency key.""" + ... 
+ @abc.abstractmethod async def get_platform_message_history_by_id( self, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c8e50909d5..52cf12b5e0 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -294,7 +294,13 @@ async def create_conversation( return new_conversation async def update_conversation( - self, cid, title=None, persona_id=None, content=None, token_usage=None + self, + cid, + title=None, + persona_id=None, + clear_persona: bool = False, + content=None, + token_usage=None, ): async with self.get_db() as session: session: AsyncSession @@ -305,7 +311,9 @@ async def update_conversation( values = {} if title is not None: values["title"] = title - if persona_id is not None: + if clear_persona: + values["persona_id"] = None + elif persona_id is not None: values["persona_id"] = persona_id if content is not None: values["content"] = content @@ -510,6 +518,121 @@ async def get_platform_message_history( result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + async def list_sdk_platform_message_history( + self, + platform_id, + user_id, + cursor_id=None, + limit=50, + include_total=False, + ): + """List SDK message history records ordered by descending id.""" + async with self.get_db() as session: + session: AsyncSession + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + .order_by(desc(PlatformMessageHistory.id)) + ) + if cursor_id is not None: + query = query.where(PlatformMessageHistory.id < cursor_id) + result = await session.execute(query.limit(limit)) + total: int | None = None + if include_total: + total_query = ( + select(func.count()) + .select_from(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + ) + total_result = await session.execute(total_query) + total = 
int(total_result.scalar() or 0) + return list(result.scalars().all()), total + + async def delete_platform_message_before( + self, + platform_id, + user_id, + before, + ) -> int: + """Delete platform message history records strictly older than the boundary.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) < before, + ), + ) + return int(result.rowcount or 0) + + async def delete_platform_message_after( + self, + platform_id, + user_id, + after, + ) -> int: + """Delete platform message history records strictly newer than the boundary.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) > after, + ), + ) + return int(result.rowcount or 0) + + async def delete_all_platform_message_history( + self, + platform_id, + user_id, + ) -> int: + """Delete all platform message history records for a specific user.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + result = await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + ), + ) + return int(result.rowcount or 0) + + async def find_platform_message_history_by_idempotency_key( + self, + platform_id, + user_id, + idempotency_key, + ) -> PlatformMessageHistory | None: + """Find a SDK message history record by its idempotency key.""" + async with self.get_db() as session: + session: AsyncSession + query = ( + select(PlatformMessageHistory) 
+ .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + func.json_extract( + PlatformMessageHistory.content, "$.idempotency_key" + ) + == str(idempotency_key), + ) + .order_by(desc(PlatformMessageHistory.id)) + ) + result = await session.execute(query.limit(1)) + return result.scalar_one_or_none() + async def get_platform_message_history_by_id( self, message_id: int ) -> PlatformMessageHistory | None: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index f26409e56e..43a7987980 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import traceback from pathlib import Path +from typing import TYPE_CHECKING from astrbot.core import logger from astrbot.core.provider.manager import ProviderManager @@ -10,9 +13,9 @@ from .kb_db_sqlite import KBSQLiteDatabase from .kb_helper import KBHelper from .models import KBDocument, KnowledgeBase -from .retrieval.manager import RetrievalManager, RetrievalResult -from .retrieval.rank_fusion import RankFusion -from .retrieval.sparse_retriever import SparseRetriever + +if TYPE_CHECKING: + from .retrieval.manager import RetrievalManager, RetrievalResult FILES_PATH = get_astrbot_knowledge_base_path() DB_PATH = Path(FILES_PATH) / "kb.db" @@ -37,6 +40,10 @@ def __init__( async def initialize(self) -> None: """初始化知识库模块""" try: + from .retrieval.manager import RetrievalManager + from .retrieval.rank_fusion import RankFusion + from .retrieval.sparse_retriever import SparseRetriever + logger.info("正在初始化知识库模块...") # 初始化数据库 diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..1978b30cc8 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ 
b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -180,6 +180,20 @@ async def process( await event.send_typing() await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "waiting_llm_request", + event, + ) + except Exception as exc: + logger.warning( + "SDK waiting_llm_request dispatch failed: %s", + exc, + ) async with session_lock_manager.acquire_lock(event.unified_msg_origin): logger.debug("acquired session lock for llm request") @@ -225,6 +239,19 @@ async def process( if reset_coro: reset_coro.close() return + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_request", + event, + { + "prompt": req.prompt, + "provider_id": provider.meta().id, + }, + provider_request=req, + ) + except Exception as exc: + logger.warning("SDK llm_request dispatch failed: %s", exc) # apply reset if reset_coro: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 070ad7bdee..181ce5d58a 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -4,18 +4,10 @@ from typing import TYPE_CHECKING from astrbot.core import astrbot_config, logger -from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner -from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( - DashscopeAgentRunner, -) from astrbot.core.agent.runners.deerflow.constants import ( DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, DEERFLOW_PROVIDER_TYPE, ) -from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( - DeerFlowAgentRunner, -) -from 
astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( @@ -327,14 +319,46 @@ async def process( # call event hook if await call_event_hook(event, EventType.OnLLMRequestEvent, req): return + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "llm_request", + event, + { + "prompt": req.prompt, + "provider_id": self.prov_id, + }, + provider_request=req, + ) + except Exception as exc: + logger.warning("SDK llm_request dispatch failed: %s", exc) if self.runner_type == "dify": + from astrbot.core.agent.runners.dify.dify_agent_runner import ( + DifyAgentRunner, + ) + runner = DifyAgentRunner[AstrAgentContext]() elif self.runner_type == "coze": + from astrbot.core.agent.runners.coze.coze_agent_runner import ( + CozeAgentRunner, + ) + runner = CozeAgentRunner[AstrAgentContext]() elif self.runner_type == "dashscope": + from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( + DashscopeAgentRunner, + ) + runner = DashscopeAgentRunner[AstrAgentContext]() elif self.runner_type == DEERFLOW_PROVIDER_TYPE: + from astrbot.core.agent.runners.deerflow.deerflow_agent_runner import ( + DeerFlowAgentRunner, + ) + runner = DeerFlowAgentRunner[AstrAgentContext]() else: raise ValueError( diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 9422d6317a..a353832b0b 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -60,6 +60,23 @@ async def process( e, traceback_text, ) + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge 
is not None: + try: + await sdk_plugin_bridge.dispatch_message_event( + "plugin_error", + event, + { + "plugin_name": md.name, + "handler_name": handler.handler_name, + "error": str(e), + "traceback": traceback_text, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_error dispatch failed: %s", exc) if not event.is_stopped() and event.is_at_or_wake_command: ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}" diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 076f7f12ac..68be5d3f25 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -16,6 +16,9 @@ async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) # initialize agent sub stage self.agent_sub_stage = AgentRequestSubStage() @@ -49,18 +52,29 @@ async def process( else: yield + if self.sdk_plugin_bridge is not None and not event.is_stopped(): + sdk_result = await self.sdk_plugin_bridge.dispatch_message(event) + if sdk_result.sent_message or sdk_result.stopped: + yield + # 调用 LLM 相关请求 if not self.ctx.astrbot_config["provider_settings"].get("enable", True): return - if ( - not event._has_send_oper - and event.is_at_or_wake_command - and not event.call_llm - ): + should_call_llm = ( + self.sdk_plugin_bridge.get_effective_should_call_llm(event) + if self.sdk_plugin_bridge is not None + and hasattr(self.sdk_plugin_bridge, "get_effective_should_call_llm") + else not event.call_llm + ) + effective_result = ( + self.sdk_plugin_bridge.get_effective_result(event) + if self.sdk_plugin_bridge is not None + and hasattr(self.sdk_plugin_bridge, "get_effective_result") + else event.get_result() + ) + if not event._has_send_oper and event.is_at_or_wake_command and should_call_llm: # 是否有过发送操作 and 是否是被 
@ 或者通过唤醒前缀 - if ( - event.get_result() and not event.is_stopped() - ) or not event.get_result(): + if (effective_result and not event.is_stopped()) or not effective_result: async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 6a884a5181..4672bfd9d3 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -53,6 +53,9 @@ class RespondStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) self.platform_settings: dict = self.config.get("platform_settings", {}) self.reply_with_mention = ctx.astrbot_config["platform_settings"][ @@ -86,7 +89,12 @@ async def initialize(self, ctx: PipelineContext) -> None: self.interval = [float(t) for t in interval_str_ls] except BaseException as e: logger.error(f"解析分段回复的间隔时间失败。{e}") - logger.info(f"分段回复间隔时间:{self.interval}") + logger.info(f"分段回复间隔时间:{self.interval}") + + def _get_effective_result(self, event: AstrMessageEvent): + if self.sdk_plugin_bridge is not None: + return self.sdk_plugin_bridge.get_effective_result(event) + return event.get_result() async def _word_cnt(self, text: str) -> int: """分段回复 统计字数""" @@ -128,12 +136,36 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo # 如果所有组件都为空 return True + @staticmethod + def _message_outline_for_sdk_event( + chain: MessageChain | list[BaseMessageComponent] | None, + ) -> str: + if isinstance(chain, MessageChain): + return chain.get_plain_text(with_other_comps_mark=True) + if isinstance(chain, list): + return MessageChain(chain).get_plain_text(with_other_comps_mark=True) + return "" + + @staticmethod + def _message_payloads_for_sdk_event( + chain: MessageChain | list[BaseMessageComponent] | None, + ) -> list[dict]: + from astrbot_sdk.message.components 
import component_to_payload_sync + + if isinstance(chain, MessageChain): + components = chain.chain + elif isinstance(chain, list): + components = chain + else: + components = [] + return [component_to_payload_sync(component) for component in components] + def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: """检查是否需要分段回复""" if not self.enable_seg: return False - if (result := event.get_result()) is None: + if (result := self._get_effective_result(event)) is None: return False if self.only_llm_result and not result.is_model_result(): return False @@ -171,7 +203,7 @@ async def process( self, event: AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - result = event.get_result() + result = self._get_effective_result(event) if result is None: return if event.get_extra("_streaming_finished", False): @@ -293,4 +325,34 @@ async def process( if await call_event_hook(event, EventType.OnAfterMessageSentEvent): return + if self.sdk_plugin_bridge is not None: + try: + from astrbot.core.sdk_bridge.event_converter import EventConverter + + await self.sdk_plugin_bridge.dispatch_message_event( + "after_message_sent", + event, + { + "session_id": event.unified_msg_origin, + "platform": event.get_platform_name(), + "platform_id": event.get_platform_id(), + "message_type": EventConverter._sdk_message_type( + event.get_message_type() + ), + "sender_name": event.get_sender_name(), + "self_id": event.get_self_id(), + "message_outline": self._message_outline_for_sdk_event( + result.chain + ), + "sent_message_outline": self._message_outline_for_sdk_event( + result.chain + ), + "sent_messages": self._message_payloads_for_sdk_event( + result.chain + ), + }, + ) + except Exception as exc: + logger.warning(f"SDK after_message_sent dispatch failed: {exc}") + event.clear_result() diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 4ee7461305..33e4e6043f 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py 
+++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,8 +5,8 @@ from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger -from astrbot.core.message.components import At, Image, Json, Node, Plain, Record, Reply -from astrbot.core.message.message_event_result import ResultContentType +from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply +from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType @@ -20,8 +20,19 @@ @register_stage class ResultDecorateStage(Stage): + @staticmethod + def _message_outline_for_sdk_event(chain: MessageChain | list | None) -> str: + if isinstance(chain, MessageChain): + return chain.get_plain_text(with_other_comps_mark=True) + if isinstance(chain, list): + return MessageChain(chain).get_plain_text(with_other_comps_mark=True) + return "" + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx + self.sdk_plugin_bridge = getattr( + ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ "reply_with_mention" @@ -101,6 +112,11 @@ async def initialize(self, ctx: PipelineContext) -> None: provider_cfg = ctx.astrbot_config.get("provider_settings", {}) self.show_reasoning = provider_cfg.get("display_reasoning_text", False) + def _get_effective_result(self, event: AstrMessageEvent): + if self.sdk_plugin_bridge is not None: + return self.sdk_plugin_bridge.get_effective_result(event) + return event.get_result() + def _split_text_by_words(self, text: str) -> list[str]: """使用分段词列表分段文本""" if not self.split_words_pattern: @@ -127,7 +143,7 @@ async def process( self, event: 
AstrMessageEvent, ) -> None | AsyncGenerator[None, None]: - result = event.get_result() + result = self._get_effective_result(event) if result is None or not result.chain: return @@ -184,13 +200,32 @@ async def process( ) return + result = self._get_effective_result(event) + if result is None or not result.chain: + return + + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_message_event( + "decorating_result", + event, + { + "message_outline": self._message_outline_for_sdk_event( + result.chain + ) + }, + event_result=result, + ) + except Exception as exc: + logger.warning(f"SDK decorating_result dispatch failed: {exc}") + # 流式输出不执行下面的逻辑 if is_stream: logger.info("流式输出已启用,跳过结果装饰阶段") return # 需要再获取一次。插件可能直接对 chain 进行了替换。 - result = event.get_result() + result = self._get_effective_result(event) if result is None: return @@ -275,21 +310,8 @@ async def process( and event.get_extra("_llm_reasoning_content") ): # inject reasoning content to chain - reasoning_content = str(event.get_extra("_llm_reasoning_content")) - if event.get_platform_name() == "lark": - result.chain.insert( - 0, - Json( - data={ - "type": "lark_collapsible_panel_reasoning", - "title": "💭 Thinking", - "expanded": False, - "content": reasoning_content, - }, - ), - ) - else: - result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) + reasoning_content = event.get_extra("_llm_reasoning_content") + result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) if should_tts and tts_provider: new_chain = [] diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 243d03378c..e78db8660d 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -92,5 +92,14 @@ async def execute(self, event: AstrMessageEvent) -> None: logger.debug("pipeline 执行完毕。") finally: - event.cleanup_temporary_local_files() - active_event_registry.unregister(event) + try: + event.cleanup_temporary_local_files() + finally: + 
try: + sdk_plugin_bridge = getattr( + self.ctx.plugin_manager.context, "sdk_plugin_bridge", None + ) + if sdk_plugin_bridge is not None: + sdk_plugin_bridge.close_request_overlay_for_event(event) + finally: + active_event_registry.unregister(event) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 82c03dbb0d..8ceed97ca8 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import asyncio import hashlib @@ -6,11 +8,9 @@ import uuid from collections.abc import AsyncGenerator from time import time -from typing import Any +from typing import TYPE_CHECKING, Any from astrbot import logger -from astrbot.core.agent.tool import ToolSet -from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( At, AtAll, @@ -23,7 +23,6 @@ ) from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from astrbot.core.utils.trace import TraceSpan @@ -31,6 +30,11 @@ from .message_session import MessageSesion, MessageSession # noqa from .platform_metadata import PlatformMetadata +if TYPE_CHECKING: + from astrbot.core.agent.tool import ToolSet + from astrbot.core.db.po import Conversation + from astrbot.core.provider.entities import ProviderRequest + class AstrMessageEvent(abc.ABC): def __init__( @@ -440,6 +444,8 @@ def request_llm( if len(contexts) > 0 and conversation: conversation = None + from astrbot.core.provider.entities import ProviderRequest + return ProviderRequest( prompt=prompt, session_id=session_id, diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 15c04166dc..1a26ebd58d 100644 --- a/astrbot/core/platform/manager.py +++ 
b/astrbot/core/platform/manager.py @@ -2,6 +2,7 @@ import traceback from asyncio import Queue from dataclasses import dataclass +from typing import TYPE_CHECKING from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig @@ -12,6 +13,9 @@ from .register import platform_cls_map from .sources.webchat.webchat_adapter import WebChatAdapter +if TYPE_CHECKING: + from astrbot.core.sdk_bridge.plugin_bridge import SdkPluginBridge + @dataclass class PlatformTasks: @@ -34,6 +38,7 @@ def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: 这个配置中的 unique_session 需要特殊处理, 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue + self.sdk_plugin_bridge: SdkPluginBridge | None = None def _is_valid_platform_id(self, platform_id: str | None) -> bool: if not platform_id: @@ -202,6 +207,7 @@ async def load_platform(self, platform_config: dict) -> None: return cls_type = platform_cls_map[platform_config["type"]] inst: Platform = cls_type(platform_config, self.settings, self.event_queue) + setattr(inst, "sdk_plugin_bridge", self.sdk_plugin_bridge) self._inst_map[platform_config["id"]] = { "inst": inst, "client_id": inst.client_self_id, @@ -222,6 +228,17 @@ async def load_platform(self, platform_config: dict) -> None: await handler.handler() except Exception: logger.error(traceback.format_exc()) + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event( + "platform_loaded", + { + "platform": inst.meta().name, + "platform_id": inst.meta().id, + }, + ) + except Exception as exc: + logger.warning(f"SDK platform_loaded event dispatch failed: {exc}") async def _task_wrapper( self, task: asyncio.Task, platform: Platform | None = None @@ -300,6 +317,48 @@ async def terminate(self) -> None: def get_insts(self): return self.platform_insts + async def refresh_native_commands( + self, *, platforms: set[str] | None = None + ) -> None: + """Refresh native command menus for running platform 
adapters. + + Native command registration is platform-specific. Today Telegram owns its + own command sync path, so plugin hot reloads need an explicit follow-up + refresh to make newly loaded SDK commands visible without waiting for the + periodic registration job or a full restart. + """ + requested_platforms = ( + {item.strip().lower() for item in platforms if item and item.strip()} + if platforms + else None + ) + for inst in list(self.platform_insts): + platform_name = "" + try: + platform_name = str(inst.meta().name).strip().lower() + except Exception: + logger.debug("Failed to read platform metadata during command refresh.") + continue + + if ( + requested_platforms is not None + and platform_name not in requested_platforms + ): + continue + + register_commands = getattr(inst, "register_commands", None) + if not callable(register_commands): + continue + + try: + await register_commands() + except Exception as exc: + logger.warning( + "刷新 %s 平台原生命令失败: %s", + platform_name or "unknown", + exc, + ) + def get_all_stats(self) -> dict: """获取所有平台的统计信息 diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7657962a11..50215ca44f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -48,6 +48,7 @@ def __init__( self.settings = platform_settings self.client_self_id: str | None = None self.registered_handlers = [] + self.sdk_plugin_bridge = None # 指令注册相关 self.enable_command_register = self.config.get("discord_command_register", True) self.guild_id = self.config.get("discord_guild_id_for_debug", None) @@ -366,42 +367,25 @@ async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] - - for handler_md in star_handlers_registry: - if not star_map[handler_md.handler_module_path].activated: - 
continue - if not handler_md.enabled: - continue - for event_filter in handler_md.event_filters: - cmd_info = self._extract_command_info(event_filter, handler_md) - if not cmd_info: - continue - - cmd_name, description, cmd_filter_instance = cmd_info - - # 创建动态回调 - callback = self._create_dynamic_callback(cmd_name) - - # 创建一个通用的参数选项来接收所有文本输入 - options = [ - discord.Option( - name="params", - description="指令的所有参数", - type=discord.SlashCommandOptionType.string, - required=False, - ), - ] - - # 创建SlashCommand - slash_command = discord.SlashCommand( - name=cmd_name, - description=description, - func=callback, - options=options, - guild_ids=[self.guild_id] if self.guild_id else None, - ) - self.client.add_application_command(slash_command) - registered_commands.append(cmd_name) + for cmd_name, description in self.collect_commands(): + callback = self._create_dynamic_callback(cmd_name) + options = [ + discord.Option( + name="params", + description="指令的所有参数", + type=discord.SlashCommandOptionType.string, + required=False, + ), + ] + slash_command = discord.SlashCommand( + name=cmd_name, + description=description, + func=callback, + options=options, + guild_ids=[self.guild_id] if self.guild_id else None, + ) + self.client.add_application_command(slash_command) + registered_commands.append(cmd_name) if registered_commands: logger.info( @@ -415,6 +399,53 @@ async def _collect_and_register_commands(self) -> None: await self.client.sync_commands() logger.info("[Discord] 指令同步完成。") + def collect_commands(self) -> list[tuple[str, str]]: + """收集 legacy 与 SDK 的顶层原生命令。""" + command_dict: dict[str, str] = {} + + for handler_md in star_handlers_registry: + if not star_map[handler_md.handler_module_path].activated: + continue + if not handler_md.enabled: + continue + for event_filter in handler_md.event_filters: + cmd_info = self._extract_command_info(event_filter, handler_md) + if not cmd_info: + continue + cmd_name, description, _cmd_filter_instance = cmd_info + if cmd_name in 
command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + + sdk_bridge = getattr(self, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + for item in sdk_bridge.list_native_command_candidates("discord"): + cmd_name = str(item.get("name", "")).strip() + if not cmd_name: + continue + if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): + logger.debug(f"[Discord] 跳过不符合规范的 SDK 指令: {cmd_name}") + continue + description = str(item.get("description") or "").strip() + if not description: + if item.get("is_group"): + description = f"Command group: {cmd_name}" + else: + description = f"Command: {cmd_name}" + if len(description) > 100: + description = f"{description[:97]}..." + if cmd_name in command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + + return sorted(command_dict.items(), key=lambda item: item[0].lower()) + def _create_dynamic_callback(self, cmd_name: str): """为每个指令动态创建一个异步回调函数""" @@ -481,7 +512,6 @@ def _extract_command_info( ) -> tuple[str, str, CommandFilter | None] | None: """从事件过滤器中提取指令信息""" cmd_name = None - # is_group = False cmd_filter_instance = None if isinstance(event_filter, CommandFilter): @@ -501,7 +531,6 @@ def _extract_command_info( if not cmd_name: return None - # Discord 斜杠指令名称规范 if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}") return None diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 5f44913573..85fd9c9e9f 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -51,6 +51,7 @@ def __init__( super().__init__(platform_config, event_queue) self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] + self.sdk_plugin_bridge = 
None base_url = self.config.get( "telegram_api_base_url", @@ -248,6 +249,31 @@ def collect_commands(self) -> list[BotCommand]: ) command_dict.setdefault(cmd_name, description) + sdk_bridge = getattr(self, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + for item in sdk_bridge.list_native_command_candidates("telegram"): + cmd_name = str(item.get("name", "")).strip() + if not cmd_name or cmd_name in skip_commands: + continue + if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32: + continue + + description = str(item.get("description") or "").strip() + if not description: + if item.get("is_group"): + description = f"Command group: {cmd_name}" + else: + description = f"Command: {cmd_name}" + if len(description) > 30: + description = description[:30] + "..." + + if cmd_name in command_dict: + logger.warning( + f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " + f"'{command_dict[cmd_name]}'" + ) + command_dict.setdefault(cmd_name, description) + commands_a = sorted(command_dict.keys()) return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a] diff --git a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py index 2f87b88b90..6034b5e371 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py @@ -1,10 +1,22 @@ """企业微信智能机器人平台适配器包""" -from .wecomai_adapter import WecomAIBotAdapter -from .wecomai_api import WecomAIBotAPIClient -from .wecomai_event import WecomAIBotMessageEvent -from .wecomai_server import WecomAIBotServer -from .wecomai_utils import WecomAIBotConstants +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .wecomai_adapter import WecomAIBotAdapter + from .wecomai_api import WecomAIBotAPIClient + from .wecomai_event import WecomAIBotMessageEvent + from .wecomai_server import WecomAIBotServer + from .wecomai_utils import 
WecomAIBotConstants +else: + WecomAIBotAdapter: Any + WecomAIBotAPIClient: Any + WecomAIBotMessageEvent: Any + WecomAIBotServer: Any + WecomAIBotConstants: Any __all__ = [ "WecomAIBotAPIClient", @@ -13,3 +25,17 @@ "WecomAIBotMessageEvent", "WecomAIBotServer", ] + + +def __getattr__(name: str) -> Any: + if name == "WecomAIBotAdapter": + return import_module(".wecomai_adapter", __name__).WecomAIBotAdapter + if name == "WecomAIBotAPIClient": + return import_module(".wecomai_api", __name__).WecomAIBotAPIClient + if name == "WecomAIBotMessageEvent": + return import_module(".wecomai_event", __name__).WecomAIBotMessageEvent + if name == "WecomAIBotServer": + return import_module(".wecomai_server", __name__).WecomAIBotServer + if name == "WecomAIBotConstants": + return import_module(".wecomai_utils", __name__).WecomAIBotConstants + raise AttributeError(name) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index f27d4671e5..86931c2c43 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -1,15 +1,19 @@ """企业微信智能机器人事件处理模块,处理消息事件的发送和接收""" +from __future__ import annotations + import asyncio from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import At, Image, Plain -from .wecomai_api import WecomAIBotAPIClient -from .wecomai_queue_mgr import WecomAIQueueMgr -from .wecomai_webhook import WecomAIBotWebhookClient +if TYPE_CHECKING: + from .wecomai_api import WecomAIBotAPIClient + from .wecomai_queue_mgr import WecomAIQueueMgr + from .wecomai_webhook import WecomAIBotWebhookClient class WecomAIBotMessageEvent(AstrMessageEvent): diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 
ad8bb44f6d..c674cd8195 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -1,8 +1,232 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from astrbot_sdk.message.components import component_to_payload_sync + from astrbot.core.db import BaseDatabase from astrbot.core.db.po import PlatformMessageHistory +from astrbot.core.message.components import ( + At, + AtAll, + BaseMessageComponent, + File, + Forward, + Image, + Plain, + Poke, + Record, + Reply, + Unknown, + Video, +) +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.platform.message_type import MessageType + + +@dataclass(frozen=True, slots=True) +class MessageHistorySender: + sender_id: str | None = None + sender_name: str | None = None + + +@dataclass(slots=True) +class MessageHistoryRecord: + id: int + session: MessageSession + sender: MessageHistorySender + parts: list[BaseMessageComponent] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + created_at: datetime | None = None + updated_at: datetime | None = None + idempotency_key: str | None = None + + +@dataclass(frozen=True, slots=True) +class MessageHistoryPage: + records: list[MessageHistoryRecord] + next_cursor: str | None + total: int | None + + +def _message_type_key(value: MessageType | str) -> str: + if isinstance(value, MessageType): + if value == MessageType.GROUP_MESSAGE: + return "group" + if value == MessageType.FRIEND_MESSAGE: + return "private" + return "other" + normalized = str(value).strip().lower() + if normalized in {"group", "groupmessage", "group_message"}: + return "group" + if normalized in { + "private", + "friend", + "friendmessage", + "privatemessage", + "friend_message", + "private_message", + }: + return "private" + if normalized in {"other", "othermessage", "other_message"}: + return "other" + raise 
ValueError(f"Unsupported message type: {value}") + + +def _message_type_enum(value: str) -> MessageType: + normalized = _message_type_key(value) + if normalized == "group": + return MessageType.GROUP_MESSAGE + if normalized == "private": + return MessageType.FRIEND_MESSAGE + return MessageType.OTHER_MESSAGE + + +def _session_storage_key(session: MessageSession) -> str: + # TODO(refactor): persist message_type as a first-class column once the + # legacy message history model can be migrated without impacting old plugins. + return f"{_message_type_key(session.message_type)}:{session.session_id}" + + +def _optional_int_cursor(cursor: str | None) -> int | None: + if cursor is None: + return None + text = str(cursor).strip() + if not text: + return None + return int(text) + + +def _payload_to_component(payload: Any) -> BaseMessageComponent: + if not isinstance(payload, dict): + return Unknown(text=str(payload)) + + raw_type = str(payload.get("type", "unknown") or "unknown").lower() + data = payload.get("data") + if not isinstance(data, dict): + data = {} + + if raw_type in {"text", "plain"}: + return Plain(str(data.get("text", "")), convert=False) + if raw_type == "image": + image_data = dict(data) + image_file = str(image_data.pop("file", "") or image_data.get("url") or "") + return Image(image_file, **image_data) + if raw_type == "at": + qq_value = data.get("qq") + if str(qq_value).lower() == "all": + return AtAll() + return At(qq=str(qq_value or ""), name=str(data.get("name", ""))) + if raw_type == "reply": + reply_data = dict(data) + chain_payload = reply_data.get("chain") + reply_data["chain"] = ( + [_payload_to_component(item) for item in chain_payload] + if isinstance(chain_payload, list) + else [] + ) + return Reply(**reply_data) + if raw_type == "record": + record_data = dict(data) + record_file = str(record_data.pop("file", "") or record_data.get("url") or "") + return Record(record_file, **record_data) + if raw_type == "video": + video_data = dict(data) + 
video_file = str(video_data.pop("file", "") or "") + return Video(video_file, **video_data) + if raw_type == "file": + file_value = str(data.get("file") or data.get("file_") or data.get("url") or "") + return File( + str(data.get("name", "") or "file"), + file="" if file_value.startswith(("http://", "https://")) else file_value, + url=file_value if file_value.startswith(("http://", "https://")) else "", + ) + if raw_type == "poke": + return Poke( + poke_type=data.get("type"), + id=data.get("id"), + qq=data.get("qq"), + ) + if raw_type == "forward": + return Forward(id=str(data.get("id", ""))) + return Unknown(text=str(payload)) + + +def _legacy_content_to_payloads( + content: dict[str, Any], +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + message_parts = content.get("message") + if not isinstance(message_parts, list): + return [], {} + payloads: list[dict[str, Any]] = [] + for part in message_parts: + if not isinstance(part, dict): + continue + part_type = str(part.get("type", "")).strip().lower() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + payloads.append({"type": "text", "data": {"text": text}}) + continue + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + continue + payloads.append( + { + "type": "reply", + "data": { + "id": str(message_id), + "message_str": str(part.get("selected_text", "")), + "chain": [], + }, + } + ) + continue + if part_type not in {"image", "record", "file", "video"}: + continue + payload_data: dict[str, Any] = {} + attachment_id = part.get("attachment_id") + if attachment_id is not None: + payload_data["attachment_id"] = str(attachment_id) + filename = part.get("filename") + if filename is not None: + payload_data["filename"] = str(filename) + if part_type == "file": + payload_data["name"] = str(filename) + path_value = part.get("path") + if path_value not in (None, ""): + payload_data["path"] = str(path_value) + payload_data["file"] = str(path_value) + 
payloads.append({"type": part_type, "data": payload_data}) + metadata = {key: value for key, value in content.items() if key != "message"} + return payloads, metadata + + +def _content_to_parts_and_metadata( + content: Any, +) -> tuple[list[dict[str, Any]], dict[str, Any], str | None]: + if not isinstance(content, dict): + return [], {}, None + if isinstance(content.get("parts"), list): + metadata = content.get("metadata") + idempotency_key = content.get("idempotency_key") + return ( + [dict(item) for item in content["parts"] if isinstance(item, dict)], + dict(metadata) if isinstance(metadata, dict) else {}, + str(idempotency_key) if idempotency_key is not None else None, + ) + payloads, metadata = _legacy_content_to_payloads(content) + return payloads, metadata, None class PlatformMessageHistoryManager: + MessageHistorySender = MessageHistorySender + MessageHistoryRecord = MessageHistoryRecord + MessageHistoryPage = MessageHistoryPage + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper @@ -10,7 +234,7 @@ async def insert( self, platform_id: str, user_id: str, - content: dict, # TODO: parse from message chain + content: dict, sender_id: str | None = None, sender_name: str | None = None, ) -> PlatformMessageHistory: @@ -49,3 +273,146 @@ async def delete( user_id=user_id, offset_sec=offset_sec, ) + + async def append( + self, + session: MessageSession, + *, + parts: list[BaseMessageComponent], + sender: MessageHistorySender, + metadata: dict[str, Any] | None = None, + idempotency_key: str | None = None, + ) -> MessageHistoryRecord: + storage_user_id = _session_storage_key(session) + if idempotency_key: + # TODO(refactor): move idempotency_key into a dedicated indexed column + # after the legacy history table is migrated for the new SDK path. 
+ existing = await self.db.find_platform_message_history_by_idempotency_key( + platform_id=session.platform_id, + user_id=storage_user_id, + idempotency_key=idempotency_key, + ) + if existing is not None: + return self._record_from_model(existing) + + content = { + "parts": [component_to_payload_sync(part) for part in parts], + "metadata": dict(metadata or {}), + } + if idempotency_key is not None: + content["idempotency_key"] = str(idempotency_key) + + record = await self.db.insert_platform_message_history( + platform_id=session.platform_id, + user_id=storage_user_id, + content=content, + sender_id=sender.sender_id, + sender_name=sender.sender_name, + ) + return self._record_from_model(record) + + async def list( + self, + session: MessageSession, + *, + cursor: str | None = None, + limit: int = 50, + ) -> MessageHistoryPage: + normalized_limit = max(1, int(limit)) + rows, total = await self.db.list_sdk_platform_message_history( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + cursor_id=_optional_int_cursor(cursor), + limit=normalized_limit + 1, + include_total=True, + ) + has_more = len(rows) > normalized_limit + page_rows = rows[:normalized_limit] + records = [self._record_from_model(row) for row in page_rows] + next_cursor = str(page_rows[-1].id) if has_more and page_rows else None + return MessageHistoryPage(records=records, next_cursor=next_cursor, total=total) + + async def get_by_id( + self, + session: MessageSession, + record_id: int, + ) -> MessageHistoryRecord | None: + record = await self.db.get_platform_message_history_by_id(int(record_id)) + if record is None: + return None + if record.platform_id != session.platform_id: + return None + if record.user_id != _session_storage_key(session): + return None + return self._record_from_model(record) + + async def delete_before( + self, + session: MessageSession, + *, + before: datetime, + ) -> int: + return await self.db.delete_platform_message_before( + 
platform_id=session.platform_id, + user_id=_session_storage_key(session), + before=before, + ) + + async def delete_after( + self, + session: MessageSession, + *, + after: datetime, + ) -> int: + return await self.db.delete_platform_message_after( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + after=after, + ) + + async def delete_all(self, session: MessageSession) -> int: + return await self.db.delete_all_platform_message_history( + platform_id=session.platform_id, + user_id=_session_storage_key(session), + ) + + def _record_from_model( + self, record: PlatformMessageHistory + ) -> MessageHistoryRecord: + parts_payload, metadata, idempotency_key = _content_to_parts_and_metadata( + record.content + ) + return MessageHistoryRecord( + id=int(record.id or 0), + session=self._session_from_storage_record(record), + sender=MessageHistorySender( + sender_id=str(record.sender_id) + if record.sender_id is not None + else None, + sender_name=( + str(record.sender_name) if record.sender_name is not None else None + ), + ), + parts=[_payload_to_component(item) for item in parts_payload], + metadata=metadata, + created_at=record.created_at, + updated_at=record.updated_at, + idempotency_key=idempotency_key, + ) + + def _session_from_storage_record( + self, record: PlatformMessageHistory + ) -> MessageSession: + raw_user_id = str(record.user_id or "") + message_type = "private" + session_id = raw_user_id + if ":" in raw_user_id: + maybe_message_type, maybe_session_id = raw_user_id.split(":", 1) + if maybe_message_type in {"group", "private", "other"} and maybe_session_id: + message_type = maybe_message_type + session_id = maybe_session_id + return MessageSession( + platform_name=str(record.platform_id), + message_type=_message_type_enum(message_type), + session_id=session_id, + ) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7a3e1543a7..c1815d2e0d 100644 --- a/astrbot/core/provider/manager.py +++ 
b/astrbot/core/provider/manager.py @@ -96,6 +96,13 @@ def register_provider_change_hook( if hook not in self._provider_change_hooks: self._provider_change_hooks.append(hook) + def unregister_provider_change_hook( + self, + hook: Callable[[str, ProviderType, str | None], None], + ) -> None: + if hook in self._provider_change_hooks: + self._provider_change_hooks.remove(hook) + def _notify_provider_changed( self, provider_id: str, diff --git a/astrbot/core/sdk_bridge/__init__.py b/astrbot/core/sdk_bridge/__init__.py new file mode 100644 index 0000000000..32b79b2384 --- /dev/null +++ b/astrbot/core/sdk_bridge/__init__.py @@ -0,0 +1,36 @@ +"""SDK bridge package public exports.""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .capability_bridge import CoreCapabilityBridge + from .event_converter import EventConverter + from .plugin_bridge import SdkPluginBridge + from .trigger_converter import TriggerConverter +else: + CoreCapabilityBridge: Any + EventConverter: Any + SdkPluginBridge: Any + TriggerConverter: Any + +__all__ = [ + "CoreCapabilityBridge", + "EventConverter", + "SdkPluginBridge", + "TriggerConverter", +] + + +def __getattr__(name: str) -> Any: + if name == "CoreCapabilityBridge": + return import_module(".capability_bridge", __name__).CoreCapabilityBridge + if name == "EventConverter": + return import_module(".event_converter", __name__).EventConverter + if name == "SdkPluginBridge": + return import_module(".plugin_bridge", __name__).SdkPluginBridge + if name == "TriggerConverter": + return import_module(".trigger_converter", __name__).TriggerConverter + raise AttributeError(name) diff --git a/astrbot/core/sdk_bridge/bridge_base.py b/astrbot/core/sdk_bridge/bridge_base.py new file mode 100644 index 0000000000..771525a510 --- /dev/null +++ b/astrbot/core/sdk_bridge/bridge_base.py @@ -0,0 +1,619 @@ +from __future__ import annotations + +import asyncio +import 
contextlib +import json +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, cast + +from astrbot_sdk._internal.invocation_context import current_caller_plugin_id +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import CapabilityRouter + +from astrbot.core.file_token_service import FileTokenService +from astrbot.core.message.components import ComponentTypes, Image, Plain +from astrbot.core.message.message_event_result import MessageChain + +if TYPE_CHECKING: + from astrbot.core.star.context import Context as StarContext + + +def _get_runtime_sp(): + from astrbot.core import sp + + return sp + + +def _get_runtime_html_renderer(): + from astrbot.core import html_renderer + + return html_renderer + + +def _get_runtime_astrbot_config(): + from astrbot.core import astrbot_config + + return astrbot_config + + +def _get_runtime_file_token_service() -> FileTokenService: + from astrbot.core import file_token_service + + return cast(FileTokenService, file_token_service) + + +def _get_runtime_tool_types(): + from astrbot.core.agent.tool import FunctionTool, ToolSet + + return FunctionTool, ToolSet + + +def _get_runtime_provider_types(): + from astrbot.core.provider.provider import ( + EmbeddingProvider, + RerankProvider, + STTProvider, + TTSProvider, + ) + + return STTProvider, TTSProvider, EmbeddingProvider, RerankProvider + + +@dataclass(slots=True) +class _EventStreamState: + request_context: Any + queue: asyncio.Queue[MessageChain | None] + task: asyncio.Task[None] + + +def _build_message_chain_from_payload( + chain_payload: list[dict[str, Any]], +) -> MessageChain: + components = [] + for item in chain_payload: + if not isinstance(item, dict): + continue + comp_type = str(item.get("type", "")).lower() + data = item.get("data", {}) + if comp_type in {"text", "plain"} and isinstance(data, dict): + 
components.append(Plain(str(data.get("text", "")), convert=False)) + continue + if comp_type == "image" and isinstance(data, dict): + file_value = str(data.get("file") or data.get("url") or "") + if file_value.startswith(("http://", "https://")): + components.append(Image.fromURL(file_value)) + elif file_value: + file_path = ( + file_value[8:] if file_value.startswith("file:///") else file_value + ) + components.append(Image.fromFileSystem(file_path)) + continue + component_cls = ComponentTypes.get(comp_type) + if component_cls is None: + components.append( + Plain(json.dumps(item, ensure_ascii=False), convert=False) + ) + continue + try: + if isinstance(data, dict): + components.append(component_cls(**data)) + else: + components.append(Plain(str(item), convert=False)) + except Exception: + components.append( + Plain(json.dumps(item, ensure_ascii=False), convert=False) + ) + return MessageChain(components) + + +class CapabilityBridgeBase(CapabilityRouter): + MEMORY_SCOPE = "sdk_memory" + + _star_context: StarContext + _plugin_bridge: Any + + @staticmethod + def _to_iso_datetime(value: Any) -> str | None: + if value is None: + return None + isoformat = getattr(value, "isoformat", None) + if callable(isoformat): + return str(isoformat()) + if isinstance(value, (int, float)) and value > 0: + return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat() + return None + + @staticmethod + def _optional_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _normalize_history_items(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [dict(item) for item in value if isinstance(item, dict)] + if isinstance(value, str): + with contextlib.suppress(json.JSONDecodeError, TypeError, ValueError): + decoded = json.loads(value) + if isinstance(decoded, list): + return [dict(item) for item in decoded if isinstance(item, dict)] + return [] + 
+ @staticmethod + def _normalize_persona_dialogs(value: Any) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value if isinstance(item, str)] + if isinstance(value, str): + with contextlib.suppress(json.JSONDecodeError, TypeError, ValueError): + decoded = json.loads(value) + if isinstance(decoded, list): + return [str(item) for item in decoded if isinstance(item, str)] + return [] + + @staticmethod + def _normalize_session_scoped_config( + raw_config: Any, + session_id: str, + ) -> dict[str, Any]: + if not isinstance(raw_config, dict): + return {} + nested = raw_config.get(session_id) + if isinstance(nested, dict): + return dict(nested) + # Session plugin config is stored as {session_id: {...}}, but session + # service config already lives directly under the per-session storage key. + # Accept both shapes so the bridge stays compatible with existing data. + return dict(raw_config) + + def _serialize_persona(self, persona: Any) -> dict[str, Any] | None: + if persona is None: + return None + return { + "persona_id": str(getattr(persona, "persona_id", "") or ""), + "system_prompt": str(getattr(persona, "system_prompt", "") or ""), + "begin_dialogs": self._normalize_persona_dialogs( + getattr(persona, "begin_dialogs", None) + ), + "tools": ( + [str(item) for item in getattr(persona, "tools", [])] + if isinstance(getattr(persona, "tools", None), list) + else None + ), + "skills": ( + [str(item) for item in getattr(persona, "skills", [])] + if isinstance(getattr(persona, "skills", None), list) + else None + ), + "custom_error_message": ( + str(getattr(persona, "custom_error_message", "")) + if getattr(persona, "custom_error_message", None) is not None + else None + ), + "folder_id": ( + str(getattr(persona, "folder_id", "")) + if getattr(persona, "folder_id", None) is not None + else None + ), + "sort_order": int(getattr(persona, "sort_order", 0) or 0), + "created_at": self._to_iso_datetime(getattr(persona, "created_at", None)), + "updated_at": 
self._to_iso_datetime(getattr(persona, "updated_at", None)), + } + + def _serialize_conversation(self, conversation: Any) -> dict[str, Any] | None: + if conversation is None: + return None + return { + "conversation_id": str(getattr(conversation, "cid", "") or ""), + "session": str(getattr(conversation, "user_id", "") or ""), + "platform_id": str(getattr(conversation, "platform_id", "") or ""), + "history": self._normalize_history_items( + getattr(conversation, "history", None) + ), + "title": ( + str(getattr(conversation, "title", "")) + if getattr(conversation, "title", None) is not None + else None + ), + "persona_id": ( + str(getattr(conversation, "persona_id", "")) + if getattr(conversation, "persona_id", None) is not None + else None + ), + "created_at": self._to_iso_datetime( + getattr(conversation, "created_at", None) + ), + "updated_at": self._to_iso_datetime( + getattr(conversation, "updated_at", None) + ), + "token_usage": ( + int(getattr(conversation, "token_usage")) + if getattr(conversation, "token_usage", None) is not None + else None + ), + } + + def _serialize_kb(self, kb_helper_or_record: Any) -> dict[str, Any] | None: + kb = getattr(kb_helper_or_record, "kb", kb_helper_or_record) + if kb is None: + return None + return { + "kb_id": str(getattr(kb, "kb_id", "") or ""), + "kb_name": str(getattr(kb, "kb_name", "") or ""), + "description": ( + str(getattr(kb, "description", "")) + if getattr(kb, "description", None) is not None + else None + ), + "emoji": ( + str(getattr(kb, "emoji", "")) + if getattr(kb, "emoji", None) is not None + else None + ), + "embedding_provider_id": str( + getattr(kb, "embedding_provider_id", "") or "" + ), + "rerank_provider_id": ( + str(getattr(kb, "rerank_provider_id", "")) + if getattr(kb, "rerank_provider_id", None) is not None + else None + ), + "chunk_size": ( + int(getattr(kb, "chunk_size")) + if getattr(kb, "chunk_size", None) is not None + else None + ), + "chunk_overlap": ( + int(getattr(kb, "chunk_overlap")) + if 
getattr(kb, "chunk_overlap", None) is not None + else None + ), + "top_k_dense": ( + int(getattr(kb, "top_k_dense")) + if getattr(kb, "top_k_dense", None) is not None + else None + ), + "top_k_sparse": ( + int(getattr(kb, "top_k_sparse")) + if getattr(kb, "top_k_sparse", None) is not None + else None + ), + "top_m_final": ( + int(getattr(kb, "top_m_final")) + if getattr(kb, "top_m_final", None) is not None + else None + ), + "doc_count": int(getattr(kb, "doc_count", 0) or 0), + "chunk_count": int(getattr(kb, "chunk_count", 0) or 0), + "created_at": self._to_iso_datetime(getattr(kb, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(kb, "updated_at", None)), + } + + def _serialize_kb_document(self, document: Any) -> dict[str, Any] | None: + if document is None: + return None + return { + "doc_id": str(getattr(document, "doc_id", "") or ""), + "kb_id": str(getattr(document, "kb_id", "") or ""), + "doc_name": str(getattr(document, "doc_name", "") or ""), + "file_type": str(getattr(document, "file_type", "") or ""), + "file_size": int(getattr(document, "file_size", 0) or 0), + "file_path": str(getattr(document, "file_path", "") or ""), + "chunk_count": int(getattr(document, "chunk_count", 0) or 0), + "media_count": int(getattr(document, "media_count", 0) or 0), + "created_at": self._to_iso_datetime(getattr(document, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(document, "updated_at", None)), + } + + @staticmethod + def _serialize_member(member: Any) -> dict[str, Any] | None: + if member is None: + return None + user_id = getattr(member, "user_id", None) + if user_id is None and isinstance(member, dict): + user_id = member.get("user_id") + if user_id is None: + return None + nickname = getattr(member, "nickname", None) + if nickname is None and isinstance(member, dict): + nickname = member.get("nickname") + role = getattr(member, "role", None) + if role is None and isinstance(member, dict): + role = member.get("role") + return { + 
"user_id": str(user_id), + "nickname": str(nickname or ""), + "role": str(role or ""), + } + + @classmethod + def _serialize_group(cls, group: Any) -> dict[str, Any] | None: + if group is None: + return None + members_payload = [] + raw_members = getattr(group, "members", None) + if raw_members is None: + raw_members = getattr(group, "member_list", None) + if raw_members is None and isinstance(group, dict): + raw_members = group.get("members") or group.get("member_list") + if isinstance(raw_members, list): + for member in raw_members: + serialized_member = cls._serialize_member(member) + if serialized_member is not None: + members_payload.append(serialized_member) + group_id = getattr(group, "group_id", None) + if group_id is None and isinstance(group, dict): + group_id = group.get("group_id") + group_name = getattr(group, "group_name", None) + if group_name is None and isinstance(group, dict): + group_name = group.get("group_name") + group_avatar = getattr(group, "group_avatar", None) + if group_avatar is None and isinstance(group, dict): + group_avatar = group.get("group_avatar") + group_owner = getattr(group, "group_owner", None) + if group_owner is None and isinstance(group, dict): + group_owner = group.get("group_owner") + group_admins = getattr(group, "group_admins", None) + if group_admins is None and isinstance(group, dict): + group_admins = group.get("group_admins") + return { + "group_id": str(group_id or ""), + "group_name": str(group_name or ""), + "group_avatar": str(group_avatar or ""), + "group_owner": str(group_owner or ""), + "group_admins": ( + [str(item) for item in group_admins] + if isinstance(group_admins, list) + else [] + ), + "members": members_payload, + } + + @staticmethod + def _serialize_platform_error(error: Any) -> dict[str, Any] | None: + if error is None: + return None + message = getattr(error, "message", None) + timestamp = getattr(error, "timestamp", None) + traceback_value = getattr(error, "traceback", None) + if 
isinstance(error, dict): + message = error.get("message", message) + timestamp = error.get("timestamp", timestamp) + traceback_value = error.get("traceback", traceback_value) + if not message: + return None + return { + "message": str(message), + "timestamp": CapabilityBridgeBase._to_iso_datetime(timestamp) + or str(timestamp or ""), + "traceback": ( + str(traceback_value) if traceback_value is not None else None + ), + } + + @classmethod + def _serialize_platform_snapshot(cls, platform: Any) -> dict[str, Any] | None: + if platform is None: + return None + meta = None + try: + meta = platform.meta() + except Exception: + meta = None + platform_id = str( + getattr(meta, "id", None) or getattr(platform, "config", {}).get("id", "") + ).strip() + platform_type = str(getattr(meta, "name", "") or "").strip() + if not platform_id or not platform_type: + return None + status = getattr(platform, "status", None) + errors = getattr(platform, "errors", []) + status_value = getattr(status, "value", status) + return { + "id": platform_id, + "name": str(getattr(meta, "adapter_display_name", None) or platform_type), + "type": platform_type, + "status": str(status_value or "pending"), + "errors": [ + payload + for payload in ( + cls._serialize_platform_error(item) + for item in (errors if isinstance(errors, list) else []) + ) + if payload is not None + ], + "last_error": cls._serialize_platform_error( + getattr(platform, "last_error", None) + ), + "unified_webhook": bool( + platform.unified_webhook() + if hasattr(platform, "unified_webhook") + else False + ), + } + + @classmethod + def _serialize_platform_stats(cls, stats: Any) -> dict[str, Any] | None: + if not isinstance(stats, dict): + return None + payload = dict(stats) + payload["last_error"] = cls._serialize_platform_error(stats.get("last_error")) + meta = stats.get("meta") + payload["meta"] = dict(meta) if isinstance(meta, dict) else {} + return payload + + def _get_platform_inst_by_id(self, platform_id: str) -> Any | None: 
+ platform_manager = getattr(self._star_context, "platform_manager", None) + if platform_manager is None or not hasattr(platform_manager, "get_insts"): + return None + normalized_platform_id = str(platform_id).strip() + if not normalized_platform_id: + return None + for platform in list(platform_manager.get_insts()): + meta = None + try: + meta = platform.meta() + except Exception: + continue + if str(getattr(meta, "id", "")).strip() == normalized_platform_id: + return platform + return None + + def _resolve_plugin_id(self, request_id: str) -> str: + plugin_id = current_caller_plugin_id() + if plugin_id: + return plugin_id + return self._plugin_bridge.resolve_request_plugin_id(request_id) + + def _reserved_plugin_names(self) -> set[str]: + reserved: set[str] = set() + get_all_stars = getattr(self._star_context, "get_all_stars", None) + if not callable(get_all_stars): + return reserved + stars = get_all_stars() + if not isinstance(stars, Iterable): + return reserved + for star in stars: + name = getattr(star, "name", None) + if name and bool(getattr(star, "reserved", False)): + reserved.add(str(name)) + return reserved + + def _require_reserved_plugin( + self, + request_id: str, + capability_name: str, + ) -> str: + plugin_id = self._resolve_plugin_id(request_id) + if plugin_id in {"system", "__system__"}: + return plugin_id + if plugin_id in self._reserved_plugin_names(): + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} is restricted to reserved/system plugins" + ) + + def _plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + checker = getattr(self._plugin_bridge, "plugin_supports_platform", None) + if not callable(checker): + return True + return bool(checker(plugin_id, platform_name)) + + def _platform_name_from_id(self, platform_id: str) -> str: + platform = self._get_platform_inst_by_id(platform_id) + if platform is None: + return "" + meta = getattr(platform, "meta", None) + if not callable(meta): + return 
"" + try: + payload = meta() + except Exception: + return "" + return str(getattr(payload, "name", "") or "").strip().lower() + + def _session_platform_name(self, session: str) -> str: + platform_id = str(session).split(":", maxsplit=1)[0].strip() + if not platform_id: + return "" + return self._platform_name_from_id(platform_id) + + def _require_platform_support_for_session( + self, + request_id: str, + session: str, + capability_name: str, + ) -> str: + plugin_id = self._resolve_plugin_id(request_id) + platform_name = self._session_platform_name(session) + if not platform_name or self._plugin_supports_platform( + plugin_id, platform_name + ): + return plugin_id + raise AstrBotError.invalid_input( + f"{capability_name} does not support platform '{platform_name}' for plugin '{plugin_id}'" + ) + + def _resolve_dispatch_target( + self, + request_id: str, + payload: dict[str, Any], + ) -> tuple[str, str]: + target_payload = payload.get("target") + dispatch_token = "" + if isinstance(target_payload, dict): + raw_payload = target_payload.get("raw") + if isinstance(raw_payload, dict): + dispatch_token = str(raw_payload.get("dispatch_token", "")) + if not dispatch_token: + nested_raw_payload = raw_payload.get("raw") + if isinstance(nested_raw_payload, dict): + dispatch_token = str( + nested_raw_payload.get("dispatch_token", "") + ) + if not dispatch_token: + request_context = self._plugin_bridge.resolve_request_session(request_id) + if request_context is None: + raise AstrBotError.invalid_input( + "Missing dispatch token for platform send" + ) + dispatch_token = request_context.dispatch_token + session = str(payload.get("session", "")) + return session, dispatch_token + + def _resolve_event_request_context( + self, + request_id: str, + payload: dict[str, Any], + ): + def _has_event(request_context: Any | None) -> bool: + if request_context is None: + return False + has_event = getattr(request_context, "has_event", None) + if has_event is not None: + return bool(has_event) 
+ return hasattr(request_context, "event") + + target_payload = payload.get("target") + dispatch_token = "" + if isinstance(target_payload, dict): + raw_payload = target_payload.get("raw") + if isinstance(raw_payload, dict): + dispatch_token = str(raw_payload.get("dispatch_token", "")) + if not dispatch_token: + nested_raw = raw_payload.get("raw") + if isinstance(nested_raw, dict): + dispatch_token = str(nested_raw.get("dispatch_token", "")) + if dispatch_token: + request_context = self._plugin_bridge.get_request_context_by_token( + dispatch_token + ) + return request_context if _has_event(request_context) else None + request_context = self._plugin_bridge.resolve_request_session(request_id) + return request_context if _has_event(request_context) else None + + def _resolve_current_group_request_context( + self, + request_id: str, + payload: dict[str, Any], + ): + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None: + return None + payload_session = str(payload.get("session", "")).strip() + if payload_session and payload_session != str( + request_context.event.unified_msg_origin + ): + raise AstrBotError.invalid_input( + "platform.get_group/get_members only support the current event session" + ) + return request_context + + @staticmethod + def _build_core_message_chain(chain_payload: list[dict[str, Any]]) -> MessageChain: + return _build_message_chain_from_payload(chain_payload) diff --git a/astrbot/core/sdk_bridge/capabilities/__init__.py b/astrbot/core/sdk_bridge/capabilities/__init__.py new file mode 100644 index 0000000000..4ba44e5e9c --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/__init__.py @@ -0,0 +1,29 @@ +from .basic import BasicCapabilityMixin +from .conversation import ConversationCapabilityMixin +from .kb import KnowledgeBaseCapabilityMixin +from .llm import LLMCapabilityMixin +from .mcp import MCPCapabilityMixin +from .message_history import MessageHistoryCapabilityMixin +from .permission 
import PermissionCapabilityMixin +from .persona import PersonaCapabilityMixin +from .platform import PlatformCapabilityMixin +from .provider import ProviderCapabilityMixin +from .session import SessionCapabilityMixin +from .skill import SkillCapabilityMixin +from .system import SystemCapabilityMixin + +__all__ = [ + "BasicCapabilityMixin", + "ConversationCapabilityMixin", + "KnowledgeBaseCapabilityMixin", + "LLMCapabilityMixin", + "MCPCapabilityMixin", + "MessageHistoryCapabilityMixin", + "PermissionCapabilityMixin", + "PersonaCapabilityMixin", + "PlatformCapabilityMixin", + "ProviderCapabilityMixin", + "SessionCapabilityMixin", + "SkillCapabilityMixin", + "SystemCapabilityMixin", +] diff --git a/astrbot/core/sdk_bridge/capabilities/_host.py b/astrbot/core/sdk_bridge/capabilities/_host.py new file mode 100644 index 0000000000..c3bda8de05 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/_host.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + + class CapabilityMixinHost: + MEMORY_SCOPE: str + _event_streams: dict[str, Any] + _plugin_bridge: Any + _star_context: Any + _memory_backends_by_plugin: dict[str, Any] + _memory_index_by_plugin: dict[str, dict[str, dict[str, Any]]] + _memory_dirty_keys_by_plugin: dict[str, set[str]] + _memory_expires_at_by_plugin: dict[str, dict[str, Any]] + + def register( + self, + descriptor: Any, + *, + call_handler: Any = None, + stream_handler: Any = None, + finalize: Any = None, + exposed: bool = True, + ) -> None: ... + + def _builtin_descriptor( + self, + name: str, + description: str, + *, + supports_stream: bool = False, + cancelable: bool = False, + ) -> Any: ... + + def _resolve_plugin_id(self, request_id: str) -> str: ... + + def _resolve_dispatch_target( + self, + request_id: str, + payload: dict[str, Any], + ) -> tuple[str, str]: ... 
+ + def _resolve_event_request_context( + self, + request_id: str, + payload: dict[str, Any], + ) -> Any: ... + + def _resolve_current_group_request_context( + self, + request_id: str, + payload: dict[str, Any], + ) -> Any: ... + + def _build_core_message_chain( + self, chain_payload: list[dict[str, Any]] + ) -> Any: ... + + def _serialize_group(self, group: Any) -> dict[str, Any] | None: ... + + def _require_reserved_plugin( + self, + request_id: str, + capability_name: str, + ) -> str: ... + + def _plugin_supports_platform( + self, + plugin_id: str, + platform_name: str, + ) -> bool: ... + + def _platform_name_from_id(self, platform_id: str) -> str: ... + + def _session_platform_name(self, session: str) -> str: ... + + def _require_platform_support_for_session( + self, + request_id: str, + session: str, + capability_name: str, + ) -> str: ... + + def _get_platform_inst_by_id(self, platform_id: str) -> Any | None: ... + + def _serialize_platform_snapshot( + self, platform: Any + ) -> dict[str, Any] | None: ... + + def _serialize_platform_stats(self, stats: Any) -> dict[str, Any] | None: ... + + def _normalize_session_scoped_config( + self, + raw_config: Any, + session_id: str, + ) -> dict[str, Any]: ... + + def _get_typed_provider( + self, + payload: dict[str, Any], + capability_name: str, + provider_label: str, + expected_type: type[Any], + ) -> Any: ... + + def _provider_embedding_get_embedding( + self, + request_id: str, + payload: dict[str, Any], + token: Any, + ) -> Awaitable[dict[str, Any]]: ... + + def _provider_embedding_get_embeddings( + self, + request_id: str, + payload: dict[str, Any], + token: Any, + ) -> Awaitable[dict[str, Any]]: ... + + def _reserved_plugin_names(self) -> set[str]: ... + + def _serialize_persona(self, persona: Any) -> dict[str, Any] | None: ... + + def _normalize_persona_dialogs(self, value: Any) -> list[str]: ... + + def _serialize_conversation( + self, conversation: Any + ) -> dict[str, Any] | None: ... 
+ + def _normalize_history_items(self, value: Any) -> list[dict[str, Any]]: ... + + def _optional_int(self, value: Any) -> int | None: ... + + def _serialize_kb(self, kb_helper_or_record: Any) -> dict[str, Any] | None: ... + + def _serialize_kb_document(self, document: Any) -> dict[str, Any] | None: ... + +else: + + class CapabilityMixinHost: + # Keep the runtime host empty so it cannot shadow CapabilityRouter methods in + # CoreCapabilityBridge's MRO. The typed method declarations above are only for + # static analysis. + pass diff --git a/astrbot/core/sdk_bridge/capabilities/basic.py b/astrbot/core/sdk_bridge/capabilities/basic.py new file mode 100644 index 0000000000..8a4bc765d1 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/basic.py @@ -0,0 +1,698 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from astrbot_sdk._memory_backend import PluginMemoryBackend +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.runtime.capability_router import StreamExecution + +from astrbot.core.utils.astrbot_path import get_astrbot_plugin_data_path + +from ..bridge_base import _get_runtime_provider_types, _get_runtime_sp +from ._host import CapabilityMixinHost + + +class BasicCapabilityMixin(CapabilityMixinHost): + def _memory_backend_for_plugin(self, plugin_id: str) -> PluginMemoryBackend: + backend = self._memory_backends_by_plugin.get(plugin_id) + if backend is None: + backend = PluginMemoryBackend( + Path(get_astrbot_plugin_data_path()) / plugin_id + ) + self._memory_backends_by_plugin[plugin_id] = backend + return backend + + def _resolve_memory_embedding_provider_id( + self, + payload: dict[str, Any], + *, + required: bool, + ) -> str | None: + provider_id = str(payload.get("provider_id", "")).strip() + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + if provider_id: + provider = self._star_context.get_provider_by_id(provider_id) + if provider is None or not isinstance(provider, 
embedding_provider_cls): + raise AstrBotError.invalid_input( + f"memory.search unknown embedding provider: {provider_id}" + ) + return provider_id + providers = self._star_context.get_all_embedding_providers() + if providers: + provider = providers[0] + provider_id = str(getattr(provider.meta(), "id", "") or "").strip() + if provider_id: + return provider_id + if required: + raise AstrBotError.invalid_input( + "memory.search requires an embedding provider", + ) + return None + + def _register_db_capabilities(self) -> None: + self.register( + self._builtin_descriptor("db.get", "Read plugin kv"), + call_handler=self._db_get, + ) + self.register( + self._builtin_descriptor("db.set", "Write plugin kv"), + call_handler=self._db_set, + ) + self.register( + self._builtin_descriptor("db.delete", "Delete plugin kv"), + call_handler=self._db_delete, + ) + self.register( + self._builtin_descriptor("db.list", "List plugin kv"), + call_handler=self._db_list, + ) + self.register( + self._builtin_descriptor("db.get_many", "Read plugin kv in batch"), + call_handler=self._db_get_many, + ) + self.register( + self._builtin_descriptor("db.set_many", "Write plugin kv in batch"), + call_handler=self._db_set_many, + ) + self.register( + self._builtin_descriptor( + "db.watch", + "Watch plugin kv", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._db_watch, + ) + + async def _db_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "value": await _get_runtime_sp().get_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + None, + ) + } + + async def _db_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + await _get_runtime_sp().put_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + payload.get("value"), + ) + return {} + + async def _db_delete( + self, + 
request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + await _get_runtime_sp().remove_async( + "plugin", + plugin_id, + str(payload.get("key", "")), + ) + return {} + + async def _db_list( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + prefix = payload.get("prefix") + prefix_value = str(prefix) if isinstance(prefix, str) else None + items = await _get_runtime_sp().range_get_async("plugin", plugin_id, None) + keys = sorted( + item.key + for item in items + if prefix_value is None or item.key.startswith(prefix_value) + ) + return {"keys": keys} + + async def _db_get_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("db.get_many requires a keys array") + items = [] + for key in keys_payload: + key_text = str(key) + items.append( + { + "key": key_text, + "value": await _get_runtime_sp().get_async( + "plugin", + plugin_id, + key_text, + None, + ), + } + ) + return {"items": items} + + async def _db_set_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + items_payload = payload.get("items") + if not isinstance(items_payload, list): + raise AstrBotError.invalid_input("db.set_many requires an items array") + for item in items_payload: + if not isinstance(item, dict): + raise AstrBotError.invalid_input("db.set_many items must be objects") + await _get_runtime_sp().put_async( + "plugin", + plugin_id, + str(item.get("key", "")), + item.get("value"), + ) + return {} + + async def _db_watch( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> StreamExecution: + raise AstrBotError.invalid_input( + 
"db.watch is unsupported in AstrBot SDK MVP", + hint="Use db.get/list polling in MVP", + ) + + def _register_memory_capabilities(self) -> None: + self.register( + self._builtin_descriptor("memory.search", "Search plugin memory"), + call_handler=self._memory_search, + ) + self.register( + self._builtin_descriptor("memory.save", "Save plugin memory"), + call_handler=self._memory_save, + ) + self.register( + self._builtin_descriptor("memory.get", "Get plugin memory"), + call_handler=self._memory_get, + ) + self.register( + self._builtin_descriptor("memory.list_keys", "List plugin memory keys"), + call_handler=self._memory_list_keys, + ) + self.register( + self._builtin_descriptor("memory.exists", "Check plugin memory key"), + call_handler=self._memory_exists, + ) + self.register( + self._builtin_descriptor("memory.delete", "Delete plugin memory"), + call_handler=self._memory_delete, + ) + self.register( + self._builtin_descriptor( + "memory.clear_namespace", + "Delete plugin memory in a namespace", + ), + call_handler=self._memory_clear_namespace, + ) + self.register( + self._builtin_descriptor( + "memory.save_with_ttl", + "Save plugin memory with ttl metadata", + ), + call_handler=self._memory_save_with_ttl, + ) + self.register( + self._builtin_descriptor("memory.get_many", "Get plugin memories"), + call_handler=self._memory_get_many, + ) + self.register( + self._builtin_descriptor("memory.delete_many", "Delete plugin memories"), + call_handler=self._memory_delete_many, + ) + self.register( + self._builtin_descriptor("memory.count", "Count plugin memories"), + call_handler=self._memory_count, + ) + self.register( + self._builtin_descriptor("memory.stats", "Get plugin memory stats"), + call_handler=self._memory_stats, + ) + + async def _memory_search( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + query = str(payload.get("query", "")) + mode = str(payload.get("mode", 
"auto")).strip().lower() or "auto" + limit = self._optional_int(payload.get("limit")) + raw_min_score = payload.get("min_score") + min_score = float(raw_min_score) if raw_min_score is not None else None + namespace = str(payload.get("namespace")) if payload.get("namespace") else None + include_descendants = bool(payload.get("include_descendants", True)) + provider_id = self._resolve_memory_embedding_provider_id( + payload, + required=mode in {"vector", "hybrid"}, + ) + effective_mode = mode + if effective_mode == "auto": + effective_mode = "hybrid" if provider_id is not None else "keyword" + backend = self._memory_backend_for_plugin(plugin_id) + items = await backend.search( + query, + namespace=namespace, + include_descendants=include_descendants, + mode=effective_mode, + limit=limit, + min_score=min_score, + provider_id=provider_id, + embed_one=( + ( + lambda text: self._memory_embedding_for_text( + request_id, + provider_id, + text, + _token, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + embed_many=( + ( + lambda texts: self._memory_embeddings_for_texts( + request_id, + provider_id, + texts, + _token, + ) + ) + if provider_id is not None and effective_mode in {"vector", "hybrid"} + else None + ), + ) + return {"items": items} + + async def _memory_embedding_for_text( + self, + request_id: str, + provider_id: str, + text: str, + token, + ) -> list[float]: + output = await self._provider_embedding_get_embedding( + request_id, + {"provider_id": provider_id, "text": text}, + token, + ) + embedding = output.get("embedding") + if not isinstance(embedding, list): + return [] + return [float(item) for item in embedding] + + async def _memory_embeddings_for_texts( + self, + request_id: str, + provider_id: str, + texts: list[str], + token, + ) -> list[list[float]]: + output = await self._provider_embedding_get_embeddings( + request_id, + {"provider_id": provider_id, "texts": texts}, + token, + ) + embeddings = 
output.get("embeddings") + if not isinstance(embeddings, list): + return [] + return [ + [float(value) for value in item] + for item in embeddings + if isinstance(item, list) + ] + + async def _memory_save( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input("memory.save requires an object value") + await self._memory_backend_for_plugin(plugin_id).save( + str(payload.get("key", "")), + value, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = await self._memory_backend_for_plugin(plugin_id).get( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"value": value} + + async def _memory_list_keys( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys = await self._memory_backend_for_plugin(plugin_id).list_keys( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"keys": keys} + + async def _memory_exists( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + exists = await self._memory_backend_for_plugin(plugin_id).exists( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"exists": exists} + + async def _memory_delete( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id 
= self._resolve_plugin_id(request_id) + await self._memory_backend_for_plugin(plugin_id).delete( + str(payload.get("key", "")), + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_clear_namespace( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + deleted_count = await self._memory_backend_for_plugin( + plugin_id + ).clear_namespace( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"deleted_count": deleted_count} + + async def _memory_save_with_ttl( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + value = payload.get("value") + if not isinstance(value, dict): + raise AstrBotError.invalid_input( + "memory.save_with_ttl requires an object value" + ) + ttl_seconds = int(payload.get("ttl_seconds", 0)) + await self._memory_backend_for_plugin(plugin_id).save_with_ttl( + str(payload.get("key", "")), + value, + ttl_seconds, + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {} + + async def _memory_get_many( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("memory.get_many requires a keys array") + items = await self._memory_backend_for_plugin(plugin_id).get_many( + [str(key) for key in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"items": items} + + async def _memory_delete_many( + self, + request_id: str, + payload: 
dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + keys_payload = payload.get("keys") + if not isinstance(keys_payload, list): + raise AstrBotError.invalid_input("memory.delete_many requires a keys array") + deleted_count = await self._memory_backend_for_plugin(plugin_id).delete_many( + [str(key) for key in keys_payload], + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + ) + return {"deleted_count": deleted_count} + + async def _memory_count( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + count = await self._memory_backend_for_plugin(plugin_id).count( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", False)), + ) + return {"count": count} + + async def _memory_stats( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + stats = await self._memory_backend_for_plugin(plugin_id).stats( + namespace=( + str(payload.get("namespace")) + if payload.get("namespace") is not None + else None + ), + include_descendants=bool(payload.get("include_descendants", True)), + ) + stats["plugin_id"] = plugin_id + return stats + + def _register_http_capabilities(self) -> None: + self.register( + self._builtin_descriptor("http.register_api", "Register http route"), + call_handler=self._http_register_api, + ) + self.register( + self._builtin_descriptor("http.unregister_api", "Unregister http route"), + call_handler=self._http_unregister_api, + ) + self.register( + self._builtin_descriptor("http.list_apis", "List http routes"), + call_handler=self._http_list_apis, + ) + + async def _http_register_api( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + 
plugin_id = self._resolve_plugin_id(request_id) + methods = payload.get("methods") + if not isinstance(methods, list) or not all( + isinstance(item, str) for item in methods + ): + raise AstrBotError.invalid_input( + "http.register_api requires a string methods array" + ) + self._plugin_bridge.register_http_api( + plugin_id=plugin_id, + route=str(payload.get("route", "")), + methods=methods, + handler_capability=str(payload.get("handler_capability", "")), + description=str(payload.get("description", "")), + ) + return {} + + async def _http_unregister_api( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + methods = payload.get("methods") + if not isinstance(methods, list) or not all( + isinstance(item, str) for item in methods + ): + raise AstrBotError.invalid_input( + "http.unregister_api requires a string methods array" + ) + self._plugin_bridge.unregister_http_api( + plugin_id=plugin_id, + route=str(payload.get("route", "")), + methods=methods, + ) + return {} + + async def _http_list_apis( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return {"apis": self._plugin_bridge.list_http_apis(plugin_id)} + + def _register_metadata_capabilities(self) -> None: + self.register( + self._builtin_descriptor("metadata.get_plugin", "Get plugin metadata"), + call_handler=self._metadata_get_plugin, + ) + self.register( + self._builtin_descriptor("metadata.list_plugins", "List plugins metadata"), + call_handler=self._metadata_list_plugins, + ) + self.register( + self._builtin_descriptor( + "metadata.get_plugin_config", + "Get current plugin config", + ), + call_handler=self._metadata_get_plugin_config, + ) + self.register( + self._builtin_descriptor( + "metadata.save_plugin_config", + "Save current plugin config", + ), + call_handler=self._metadata_save_plugin_config, + ) + + async def 
_metadata_get_plugin( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin = self._plugin_bridge.get_plugin_metadata(str(payload.get("name", ""))) + return {"plugin": plugin} + + async def _metadata_list_plugins( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return {"plugins": self._plugin_bridge.list_plugin_metadata()} + + async def _metadata_get_plugin_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + requested = str(payload.get("name", "")) + if requested != plugin_id: + return {"config": None} + return {"config": self._plugin_bridge.get_plugin_config(plugin_id)} + + async def _metadata_save_plugin_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + "metadata.save_plugin_config requires config object" + ) + return {"config": self._plugin_bridge.save_plugin_config(plugin_id, config)} diff --git a/astrbot/core/sdk_bridge/capabilities/conversation.py b/astrbot/core/sdk_bridge/capabilities/conversation.py new file mode 100644 index 0000000000..90ba6a15fa --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/conversation.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class ConversationCapabilityMixin(CapabilityMixinHost): + def _register_conversation_capabilities(self) -> None: + self.register( + self._builtin_descriptor("conversation.new", "Create conversation"), + call_handler=self._conversation_new, + ) + self.register( + self._builtin_descriptor("conversation.switch", "Switch conversation"), + call_handler=self._conversation_switch, + ) + self.register( + 
self._builtin_descriptor("conversation.delete", "Delete conversation"), + call_handler=self._conversation_delete, + ) + self.register( + self._builtin_descriptor("conversation.get", "Get conversation"), + call_handler=self._conversation_get, + ) + self.register( + self._builtin_descriptor( + "conversation.get_current", + "Get current conversation", + ), + call_handler=self._conversation_get_current, + ) + self.register( + self._builtin_descriptor("conversation.list", "List conversations"), + call_handler=self._conversation_list, + ) + self.register( + self._builtin_descriptor("conversation.update", "Update conversation"), + call_handler=self._conversation_update, + ) + self.register( + self._builtin_descriptor( + "conversation.unset_persona", + "Unset conversation persona override", + ), + call_handler=self._conversation_unset_persona, + ) + + async def _conversation_new( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.new requires session") + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.new requires conversation object" + ) + conversation_id = ( + await self._star_context.conversation_manager.new_conversation( + unified_msg_origin=session, + platform_id=( + str(raw_conversation.get("platform_id")) + if raw_conversation.get("platform_id") is not None + else None + ), + content=self._normalize_history_items(raw_conversation.get("history")), + title=( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + persona_id=( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + ) + ) + return {"conversation_id": conversation_id} + + async def 
_conversation_switch( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")).strip() + conversation_id = str(payload.get("conversation_id", "")).strip() + if not session: + raise AstrBotError.invalid_input("conversation.switch requires session") + if not conversation_id: + raise AstrBotError.invalid_input( + "conversation.switch requires conversation_id" + ) + await self._star_context.conversation_manager.switch_conversation( + unified_msg_origin=session, + conversation_id=conversation_id, + ) + return {} + + async def _conversation_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + await self._star_context.conversation_manager.delete_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + ) + return {} + + async def _conversation_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + conversation = await self._star_context.conversation_manager.get_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=str(payload.get("conversation_id", "")), + create_if_not_exists=bool(payload.get("create_if_not_exists", False)), + ) + return {"conversation": self._serialize_conversation(conversation)} + + async def _conversation_get_current( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = str(payload.get("session", "")) + conversation_id = ( + await self._star_context.conversation_manager.get_curr_conversation_id( + session + ) + ) + if not conversation_id and bool(payload.get("create_if_not_exists", False)): + conversation_id = ( + await self._star_context.conversation_manager.new_conversation(session) + ) + if not conversation_id: + return {"conversation": None} + conversation 
= await self._star_context.conversation_manager.get_conversation( + unified_msg_origin=session, + conversation_id=conversation_id, + create_if_not_exists=bool(payload.get("create_if_not_exists", False)), + ) + return {"conversation": self._serialize_conversation(conversation)} + + async def _conversation_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = payload.get("session") + platform_id = payload.get("platform_id") + conversations = await self._star_context.conversation_manager.get_conversations( + unified_msg_origin=( + str(session) if session is not None and str(session).strip() else None + ), + platform_id=( + str(platform_id) + if platform_id is not None and str(platform_id).strip() + else None + ), + ) + return { + "conversations": [ + item + for item in ( + self._serialize_conversation(conversation) + for conversation in conversations + ) + if item is not None + ] + } + + async def _conversation_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_conversation = payload.get("conversation") + if raw_conversation is None: + raw_conversation = {} + if not isinstance(raw_conversation, dict): + raise AstrBotError.invalid_input( + "conversation.update requires conversation object" + ) + await self._star_context.conversation_manager.update_conversation( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + history=( + self._normalize_history_items(raw_conversation.get("history")) + if "history" in raw_conversation + else None + ), + title=( + str(raw_conversation.get("title")) + if raw_conversation.get("title") is not None + else None + ), + persona_id=( + str(raw_conversation.get("persona_id")) + if raw_conversation.get("persona_id") is not None + else None + ), + 
token_usage=self._optional_int(raw_conversation.get("token_usage")), + ) + return {} + + async def _conversation_unset_persona( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + await self._star_context.conversation_manager.unset_conversation_persona( + unified_msg_origin=str(payload.get("session", "")), + conversation_id=( + str(payload.get("conversation_id")) + if payload.get("conversation_id") is not None + else None + ), + ) + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/kb.py b/astrbot/core/sdk_bridge/capabilities/kb.py new file mode 100644 index 0000000000..fe252d414f --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/kb.py @@ -0,0 +1,456 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.sdk_bridge.bridge_base import _get_runtime_file_token_service + +from ._host import CapabilityMixinHost + + +class KnowledgeBaseCapabilityMixin(CapabilityMixinHost): + def _register_kb_capabilities(self) -> None: + self.register( + self._builtin_descriptor("kb.list", "List knowledge bases"), + call_handler=self._kb_list, + ) + self.register( + self._builtin_descriptor("kb.get", "Get knowledge base"), + call_handler=self._kb_get, + ) + self.register( + self._builtin_descriptor("kb.create", "Create knowledge base"), + call_handler=self._kb_create, + ) + self.register( + self._builtin_descriptor("kb.update", "Update knowledge base"), + call_handler=self._kb_update, + ) + self.register( + self._builtin_descriptor("kb.delete", "Delete knowledge base"), + call_handler=self._kb_delete, + ) + self.register( + self._builtin_descriptor("kb.retrieve", "Retrieve from knowledge bases"), + call_handler=self._kb_retrieve, + ) + self.register( + self._builtin_descriptor( + "kb.document.upload", "Upload knowledge base document" + ), + call_handler=self._kb_document_upload, + ) + self.register( + 
self._builtin_descriptor( + "kb.document.list", "List knowledge base documents" + ), + call_handler=self._kb_document_list, + ) + self.register( + self._builtin_descriptor("kb.document.get", "Get knowledge base document"), + call_handler=self._kb_document_get, + ) + self.register( + self._builtin_descriptor( + "kb.document.delete", + "Delete knowledge base document", + ), + call_handler=self._kb_document_delete, + ) + self.register( + self._builtin_descriptor( + "kb.document.refresh", + "Refresh knowledge base document", + ), + call_handler=self._kb_document_refresh, + ) + + async def _get_kb_helper(self, kb_id: str): + return await self._star_context.kb_manager.get_kb(kb_id) + + async def _require_kb_helper(self, kb_id: str): + kb_id_text = str(kb_id).strip() + if not kb_id_text: + raise AstrBotError.invalid_input("kb capability requires kb_id") + kb_helper = await self._get_kb_helper(kb_id_text) + if kb_helper is None: + raise AstrBotError.invalid_input(f"Unknown knowledge base: {kb_id_text}") + return kb_helper + + @staticmethod + def _normalize_kb_names(payload: dict[str, Any]) -> list[str]: + raw_names = payload.get("kb_names") + if not isinstance(raw_names, list): + return [] + return [str(item).strip() for item in raw_names if str(item).strip()] + + @staticmethod + def _normalize_kb_ids(payload: dict[str, Any]) -> list[str]: + raw_ids = payload.get("kb_ids") + if not isinstance(raw_ids, list): + return [] + return [str(item).strip() for item in raw_ids if str(item).strip()] + + async def _resolve_retrieve_kb_names( + self, + payload: dict[str, Any], + ) -> list[str]: + kb_names = self._normalize_kb_names(payload) + if kb_names: + return kb_names + resolved_names: list[str] = [] + for kb_id in self._normalize_kb_ids(payload): + kb_helper = await self._get_kb_helper(kb_id) + if kb_helper is not None and getattr(kb_helper, "kb", None) is not None: + kb_name = str(getattr(kb_helper.kb, "kb_name", "")).strip() + if kb_name: + resolved_names.append(kb_name) + 
return resolved_names + + async def _kb_list( + self, + _request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, object]: + kbs = await self._star_context.kb_manager.list_kbs() + return { + "kbs": [ + payload + for payload in (self._serialize_kb(kb) for kb in kbs) + if payload is not None + ] + } + + async def _kb_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._get_kb_helper(str(payload.get("kb_id", ""))) + return {"kb": self._serialize_kb(kb_helper)} + + async def _kb_create( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.create requires kb object") + try: + kb_helper = await self._star_context.kb_manager.create_kb( + kb_name=str(raw_kb.get("kb_name", "")), + description=( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ), + emoji=( + str(raw_kb.get("emoji")) + if raw_kb.get("emoji") is not None + else None + ), + embedding_provider_id=( + str(raw_kb.get("embedding_provider_id")) + if raw_kb.get("embedding_provider_id") is not None + else None + ), + rerank_provider_id=( + str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ), + chunk_size=self._optional_int(raw_kb.get("chunk_size")), + chunk_overlap=self._optional_int(raw_kb.get("chunk_overlap")), + top_k_dense=self._optional_int(raw_kb.get("top_k_dense")), + top_k_sparse=self._optional_int(raw_kb.get("top_k_sparse")), + top_m_final=self._optional_int(raw_kb.get("top_m_final")), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"kb": self._serialize_kb(kb_helper)} + + async def _kb_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_id = str(payload.get("kb_id", "")).strip() + 
raw_kb = payload.get("kb") + if not isinstance(raw_kb, dict): + raise AstrBotError.invalid_input("kb.update requires kb object") + kb_helper = await self._get_kb_helper(kb_id) + if kb_helper is None: + return {"kb": None} + current_kb = getattr(kb_helper, "kb", None) + kb_name = raw_kb.get("kb_name") + try: + updated_helper = await self._star_context.kb_manager.update_kb( + kb_id=kb_id, + kb_name=( + str(kb_name) + if kb_name is not None + else str(getattr(current_kb, "kb_name", "")) + ), + description=( + str(raw_kb.get("description")) + if raw_kb.get("description") is not None + else None + ) + if "description" in raw_kb + else None, + emoji=( + str(raw_kb.get("emoji")) + if raw_kb.get("emoji") is not None + else None + ) + if "emoji" in raw_kb + else None, + embedding_provider_id=( + str(raw_kb.get("embedding_provider_id")) + if raw_kb.get("embedding_provider_id") is not None + else None + ) + if "embedding_provider_id" in raw_kb + else None, + rerank_provider_id=( + str(raw_kb.get("rerank_provider_id")) + if raw_kb.get("rerank_provider_id") is not None + else None + ) + if "rerank_provider_id" in raw_kb + else None, + chunk_size=( + self._optional_int(raw_kb.get("chunk_size")) + if "chunk_size" in raw_kb + else None + ), + chunk_overlap=( + self._optional_int(raw_kb.get("chunk_overlap")) + if "chunk_overlap" in raw_kb + else None + ), + top_k_dense=( + self._optional_int(raw_kb.get("top_k_dense")) + if "top_k_dense" in raw_kb + else None + ), + top_k_sparse=( + self._optional_int(raw_kb.get("top_k_sparse")) + if "top_k_sparse" in raw_kb + else None + ), + top_m_final=( + self._optional_int(raw_kb.get("top_m_final")) + if "top_m_final" in raw_kb + else None + ), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"kb": self._serialize_kb(updated_helper)} + + async def _kb_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + deleted = await 
self._star_context.kb_manager.delete_kb( + str(payload.get("kb_id", "")) + ) + return {"deleted": bool(deleted)} + + async def _kb_retrieve( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + query = str(payload.get("query", "")).strip() + if not query: + raise AstrBotError.invalid_input("kb.retrieve requires query") + kb_names = await self._resolve_retrieve_kb_names(payload) + if not kb_names: + raise AstrBotError.invalid_input("kb.retrieve requires kb_ids or kb_names") + result = await self._star_context.kb_manager.retrieve( + query=query, + kb_names=kb_names, + top_k_fusion=self._optional_int(payload.get("top_k_fusion")) or 20, + top_m_final=self._optional_int(payload.get("top_m_final")) or 5, + ) + if result is None: + return {"result": None} + return {"result": dict(result)} + + async def _kb_document_upload( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_id = str(payload.get("kb_id", "")).strip() + kb_helper = await self._require_kb_helper(kb_id) + raw_document = payload.get("document") + if not isinstance(raw_document, dict): + raise AstrBotError.invalid_input( + "kb.document.upload requires document object" + ) + + text_value = raw_document.get("text") + if isinstance(text_value, str) and text_value.strip(): + file_name = str(raw_document.get("file_name", "")).strip() or "document.txt" + file_type = ( + str(raw_document.get("file_type", "")).strip() + or Path(file_name).suffix.lstrip(".") + or "txt" + ) + document = await kb_helper.upload_document( + file_name=file_name, + file_content=None, + file_type=file_type, + chunk_size=self._optional_int(raw_document.get("chunk_size")) or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) or 3, + 
max_retries=self._optional_int(raw_document.get("max_retries")) or 3, + pre_chunked_text=[text_value], + ) + return {"document": self._serialize_kb_document(document)} + + url_value = raw_document.get("url") + if isinstance(url_value, str) and url_value.strip(): + try: + document = await self._star_context.kb_manager.upload_from_url( + kb_id=kb_id, + url=url_value.strip(), + chunk_size=self._optional_int(raw_document.get("chunk_size")) + or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) + or 3, + max_retries=self._optional_int(raw_document.get("max_retries")) + or 3, + enable_cleaning=bool(raw_document.get("enable_cleaning", False)), + cleaning_provider_id=( + str(raw_document.get("cleaning_provider_id")) + if raw_document.get("cleaning_provider_id") is not None + else None + ), + ) + except (OSError, ValueError) as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"document": self._serialize_kb_document(document)} + + file_token = str(raw_document.get("file_token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_token, url, or text" + ) + try: + file_path = await _get_runtime_file_token_service().handle_file(file_token) + except KeyError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + path = Path(file_path) + if not path.exists(): + raise AstrBotError.invalid_input(f"File does not exist: {file_path}") + file_name = str(raw_document.get("file_name", "")).strip() or path.name + file_type = str( + raw_document.get("file_type", "") + ).strip() or path.suffix.lstrip(".") + if not file_type: + raise AstrBotError.invalid_input( + "kb.document.upload requires file_type when the file has no suffix" + ) + file_content = await asyncio.to_thread(path.read_bytes) + try: + document = await 
kb_helper.upload_document( + file_name=file_name, + file_content=file_content, + file_type=file_type, + chunk_size=self._optional_int(raw_document.get("chunk_size")) or 512, + chunk_overlap=( + self._optional_int(raw_document.get("chunk_overlap")) or 50 + ), + batch_size=self._optional_int(raw_document.get("batch_size")) or 32, + tasks_limit=self._optional_int(raw_document.get("tasks_limit")) or 3, + max_retries=self._optional_int(raw_document.get("max_retries")) or 3, + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"document": self._serialize_kb_document(document)} + + async def _kb_document_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + documents = await kb_helper.list_documents( + offset=self._optional_int(payload.get("offset")) or 0, + limit=self._optional_int(payload.get("limit")) or 100, + ) + return { + "documents": [ + item + for item in ( + self._serialize_kb_document(document) for document in documents + ) + if item is not None + ] + } + + async def _kb_document_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + document = await kb_helper.get_document(str(payload.get("doc_id", ""))) + return {"document": self._serialize_kb_document(document)} + + async def _kb_document_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + kb_helper = await self._require_kb_helper(str(payload.get("kb_id", ""))) + doc_id = str(payload.get("doc_id", "")).strip() + existing_document = await kb_helper.get_document(doc_id) + if existing_document is None: + return {"deleted": False} + await kb_helper.delete_document(doc_id) + return {"deleted": True} + + async def _kb_document_refresh( + self, + _request_id: str, + payload: 
    async def _kb_document_refresh(
        self,
        _request_id: str,
        payload: dict[str, object],
        _token,
    ) -> dict[str, object]:
        # Re-index an existing document. Returns {"document": None} for an
        # unknown doc_id; otherwise the freshly re-fetched record.
        kb_helper = await self._require_kb_helper(str(payload.get("kb_id", "")))
        doc_id = str(payload.get("doc_id", "")).strip()
        document = await kb_helper.get_document(doc_id)
        if document is None:
            return {"document": None}
        try:
            await kb_helper.refresh_document(doc_id)
        except ValueError as exc:
            raise AstrBotError.invalid_input(str(exc)) from exc
        refreshed_document = await kb_helper.get_document(doc_id)
        return {"document": self._serialize_kb_document(refreshed_document)}


# --- new file: astrbot/core/sdk_bridge/capabilities/llm.py ---
from __future__ import annotations

import asyncio
import time
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, Protocol, TypeGuard

from astrbot_sdk.errors import AstrBotError
from astrbot_sdk.runtime.capability_router import StreamExecution

from astrbot import logger

from ..bridge_base import _get_runtime_tool_types
from ._host import CapabilityMixinHost

if TYPE_CHECKING:
    from astrbot.core.agent.tool import ToolSet
    from astrbot.core.provider.entities import LLMResponse


class _ChatProvider(Protocol):
    # Structural type: any provider offering both sync-style and streaming chat.
    async def text_chat(self, **kwargs: Any) -> LLMResponse: ...

    async def text_chat_stream(self, **kwargs: Any) -> AsyncIterator[LLMResponse]: ...


class _ProviderMetaLike(Protocol):
    # Minimal metadata surface used to build log labels.
    id: str
    model: str | None


class LLMCapabilityMixin(CapabilityMixinHost):
    """Bridges llm.* SDK capabilities onto the host's chat providers."""

    def _register_llm_capabilities(self) -> None:
        # llm.chat / llm.chat_raw are plain calls; llm.stream_chat is the only
        # streaming, cancelable capability in this mixin.
        self.register(
            self._builtin_descriptor("llm.chat", "Send chat request"),
            call_handler=self._llm_chat,
        )
        self.register(
            self._builtin_descriptor(
                "llm.chat_raw",
                "Send chat request and return raw response",
            ),
            call_handler=self._llm_chat_raw,
        )
        self.register(
            self._builtin_descriptor(
                "llm.stream_chat",
                "Stream chat response",
                supports_stream=True,
                cancelable=True,
            ),
            stream_handler=self._llm_stream_chat,
        )

    async def _llm_chat(
        self,
        request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        # Thin wrapper: only the completion text is exposed.
        response = await self._call_llm(payload, request_id=request_id)
        return {"text": response.completion_text}

    async def _llm_chat_raw(
        self,
        request_id: str,
        payload: dict[str, Any],
        _token,
    ) -> dict[str, Any]:
        # Raw variant: also exposes usage, tool calls and reasoning fields.
        response = await self._call_llm(payload, request_id=request_id)
        usage = None
        if response.usage is not None:
            usage = {
                "input_tokens": response.usage.input,
                "output_tokens": response.usage.output,
                "total_tokens": response.usage.total,
            }
        return {
            "text": response.completion_text,
            # finish_reason mirrors OpenAI semantics: tool calls pending vs done.
            "finish_reason": "tool_calls" if response.tools_call_ids else "stop",
            "usage": usage,
            "tool_calls": response.to_openai_tool_calls(),
            "role": response.role,
            "reasoning_content": response.reasoning_content or None,
            "reasoning_signature": response.reasoning_signature,
        }

    async def _llm_stream_chat(
        self,
        request_id: str,
        payload: dict[str, Any],
        token,
    ) -> StreamExecution:
        # Streams provider output as {"text": ...} chunks. A final non-chunk
        # response carries "_final_text", which finalize() prefers over the
        # concatenation of streamed chunks. Providers without streaming
        # support (NotImplementedError) fall back to char-by-char emission of
        # a non-streaming text_chat result.
        provider, request_kwargs = self._resolve_llm_request(
            payload,
            request_id=request_id,
        )
        started_at = time.perf_counter()
        provider_label = self._describe_provider(provider)

        async def fallback_iterator() -> AsyncIterator[dict[str, Any]]:
            logger.warning(
                f"SDK llm.stream_chat fell back to non-streaming provider.text_chat for {provider_label}"
            )
            response = await provider.text_chat(**request_kwargs)
            logger.info(
                f"SDK llm.stream_chat fallback first output for {provider_label} after {time.perf_counter() - started_at:.3f}s"
            )
            for char in response.completion_text:
                token.raise_if_cancelled()
                # Yield control per char so cancellation stays responsive.
                await asyncio.sleep(0)
                yield {"text": char}

        async def iterator() -> AsyncIterator[dict[str, Any]]:
            try:
                stream = provider.text_chat_stream(**request_kwargs)
                yielded_text = False
                first_text_logged = False
                async for response in stream:
                    token.raise_if_cancelled()
                    text = response.completion_text
                    if response.is_chunk:
                        if text:
                            if not first_text_logged:
                                first_text_logged = True
                                logger.info(
                                    f"SDK llm.stream_chat first streamed chunk for {provider_label} after {time.perf_counter() - started_at:.3f}s"
                                )
                            yielded_text = True
                            yield {"text": text}
                        continue
                    # Non-chunk response: carries the authoritative final text.
                    if text:
                        if not first_text_logged:
                            first_text_logged = True
                            logger.info(
                                f"SDK llm.stream_chat first final chunk for {provider_label} after {time.perf_counter() - started_at:.3f}s"
                            )
                        if yielded_text:
                            # Chunks already delivered: only record the final.
                            yield {"_final_text": text}
                        else:
                            yielded_text = True
                            yield {"text": text, "_final_text": text}
                    else:
                        yield {"_final_text": text}
            except NotImplementedError:
                # Provider has no streaming API; degrade gracefully.
                async for item in fallback_iterator():
                    yield item

        def finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]:
            # Prefer the last explicit "_final_text"; otherwise join chunks.
            final_text = None
            for item in reversed(chunks):
                if "_final_text" in item:
                    final_text = str(item.get("_final_text", ""))
                    break
            if final_text is None:
                final_text = "".join(str(item.get("text", "")) for item in chunks)
            return {"text": final_text}

        return StreamExecution(
            iterator=iterator(),
            finalize=finalize,
        )

    async def _call_llm(
        self,
        payload: dict[str, Any],
        *,
        request_id: str,
    ) -> LLMResponse:
        # Shared non-streaming path for llm.chat and llm.chat_raw.
        provider, request_kwargs = self._resolve_llm_request(
            payload,
            request_id=request_id,
        )
        return await provider.text_chat(**request_kwargs)
_resolve_llm_request( + self, + payload: dict[str, Any], + *, + request_id: str, + ) -> tuple[_ChatProvider, dict[str, Any]]: + request_context = self._plugin_bridge.resolve_request_session(request_id) + provider_id = payload.get("provider_id") + if provider_id: + provider = self._star_context.get_provider_by_id(str(provider_id)) + else: + request_context_has_event = False + if request_context is not None: + has_event = getattr(request_context, "has_event", None) + request_context_has_event = ( + bool(has_event) + if has_event is not None + else hasattr(request_context, "event") + ) + provider = self._star_context.get_using_provider( + request_context.event.unified_msg_origin + if request_context is not None and request_context_has_event + else None, + ) + if provider is None: + raise AstrBotError.internal_error( + "No active chat provider is available", + hint="Please configure a chat provider in AstrBot first", + ) + if not self._is_chat_provider(provider): + raise AstrBotError.invalid_input( + f"Provider '{provider_id}' is not a chat provider", + hint="Please choose a configured chat provider for llm.chat requests", + ) + return provider, self._normalize_llm_payload(payload) + + @staticmethod + def _describe_provider(provider: _ChatProvider) -> str: + provider_meta_getter = getattr(provider, "meta", None) + if not callable(provider_meta_getter): + return provider.__class__.__name__ + provider_meta = provider_meta_getter() + if not LLMCapabilityMixin._is_provider_meta(provider_meta): + return provider.__class__.__name__ + return f"{provider_meta.id}/{provider_meta.model}" + + @staticmethod + def _is_chat_provider(provider: object) -> TypeGuard[_ChatProvider]: + return callable(getattr(provider, "text_chat", None)) and callable( + getattr(provider, "text_chat_stream", None) + ) + + @staticmethod + def _is_provider_meta(value: object) -> TypeGuard[_ProviderMetaLike]: + return hasattr(value, "id") and hasattr(value, "model") + + @staticmethod + def 
_normalize_llm_payload(payload: dict[str, Any]) -> dict[str, Any]: + contexts_payload = payload.get("contexts") + if contexts_payload is None: + contexts_payload = payload.get("history") + contexts = ( + [dict(item) for item in contexts_payload] + if isinstance(contexts_payload, list) + else None + ) + image_urls = payload.get("image_urls") + tool_calls_result = payload.get("tool_calls_result") + tools_payload = payload.get("tools") + request_kwargs: dict[str, Any] = { + "prompt": str(payload.get("prompt", "")), + "image_urls": ( + [str(item) for item in image_urls] + if isinstance(image_urls, list) + else None + ), + "func_tool": ( + LLMCapabilityMixin._build_toolset(tools_payload) + if isinstance(tools_payload, list) + else None + ), + "contexts": contexts, + "tool_calls_result": ( + [dict(item) for item in tool_calls_result] + if isinstance(tool_calls_result, list) + else None + ), + "system_prompt": str(payload.get("system", "")), + "model": (str(payload["model"]) if payload.get("model") else None), + "temperature": payload.get("temperature"), + } + return request_kwargs + + @staticmethod + def _build_toolset(tools_payload: list[Any]) -> ToolSet: + function_tool_cls, tool_set_cls = _get_runtime_tool_types() + tool_set = tool_set_cls() + for item in tools_payload: + if not isinstance(item, dict): + raise AstrBotError.invalid_input("llm tools items must be objects") + if str(item.get("type", "function")) != "function": + raise AstrBotError.invalid_input( + "Only function tools are supported in AstrBot SDK MVP" + ) + function_payload = item.get("function") + if not isinstance(function_payload, dict): + raise AstrBotError.invalid_input( + "llm tools items must contain a function object" + ) + name = str(function_payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input( + "llm function tool name must not be empty" + ) + description = str(function_payload.get("description", "") or "") + parameters = function_payload.get("parameters") + if 
not isinstance(parameters, dict): + parameters = {"type": "object", "properties": {}} + tool_set.add_tool( + function_tool_cls( + name=name, + description=description, + parameters=parameters, + handler=None, + ) + ) + return tool_set diff --git a/astrbot/core/sdk_bridge/capabilities/mcp.py b/astrbot/core/sdk_bridge/capabilities/mcp.py new file mode 100644 index 0000000000..ff58c83b5f --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/mcp.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core import logger + +from ._host import CapabilityMixinHost + + +class MCPCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _mcp_timeout(payload: dict[str, Any], capability_name: str) -> float: + raw_timeout = payload.get("timeout", 30.0) + try: + timeout = float(raw_timeout) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires numeric timeout" + ) from exc + if timeout <= 0: + raise AstrBotError.invalid_input(f"{capability_name} requires timeout > 0") + return timeout + + @staticmethod + def _mcp_name(payload: dict[str, Any], capability_name: str) -> str: + name = str(payload.get("name", "")).strip() + if not name: + raise AstrBotError.invalid_input(f"{capability_name} requires name") + return name + + @staticmethod + def _mcp_config(payload: dict[str, Any], capability_name: str) -> dict[str, Any]: + config = payload.get("config") + if not isinstance(config, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires config object" + ) + return dict(config) + + def _func_tool_manager(self): + return self._star_context.get_llm_tool_manager() + + @staticmethod + def _global_mcp_record_from_state( + *, + name: str, + config: dict[str, Any], + runtime: Any | None, + ) -> dict[str, Any]: + client = getattr(runtime, "client", None) if runtime is not None else None + return { + "name": name, + "scope": 
"global", + "active": bool(config.get("active", True)), + "running": runtime is not None, + "config": dict(config), + "tools": [ + str(tool.name) + for tool in getattr(client, "tools", []) + if getattr(tool, "name", None) + ] + if client is not None + else [], + "errlogs": list(getattr(client, "server_errlogs", [])) + if client is not None + else [], + "last_error": None, + } + + def _get_global_mcp_record(self, name: str) -> dict[str, Any] | None: + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + return None + config = servers.get(name) + if not isinstance(config, dict): + return None + runtime = func_tool_manager.mcp_server_runtime_view.get(name) + return self._global_mcp_record_from_state( + name=name, + config=dict(config), + runtime=runtime, + ) + + def _list_global_mcp_records(self) -> list[dict[str, Any]]: + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + return [] + return [ + self._global_mcp_record_from_state( + name=str(name), + config=dict(config), + runtime=func_tool_manager.mcp_server_runtime_view.get(str(name)), + ) + for name, config in sorted(servers.items(), key=lambda item: str(item[0])) + if str(name).strip() and isinstance(config, dict) + ] + + def _require_global_mcp_ack(self, request_id: str, capability_name: str) -> str: + plugin_id = self._resolve_plugin_id(request_id) + if self._plugin_bridge.acknowledges_global_mcp_risk(plugin_id): + return plugin_id + raise PermissionError( + f"{capability_name} requires @acknowledge_global_mcp_risk" + ) + + @staticmethod + def _audit_global_mcp_mutation( + *, + plugin_id: str, + action: str, + server_name: str, + request_id: str, + ) -> None: + audit_entry = { + "plugin_id": plugin_id, + "action": action, + "server_name": server_name, + 
"request_id": request_id, + } + logger.info("SDK global MCP mutation: {}", audit_entry) + + async def _mcp_local_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.get") + return {"server": self._plugin_bridge.get_local_mcp_server(plugin_id, name)} + + async def _mcp_local_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return {"servers": self._plugin_bridge.list_local_mcp_servers(plugin_id)} + + async def _mcp_local_enable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.enable") + timeout = self._mcp_timeout(payload, "mcp.local.enable") + return { + "server": await self._plugin_bridge.enable_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + } + + async def _mcp_local_disable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.disable") + return { + "server": await self._plugin_bridge.disable_local_mcp_server( + plugin_id, + name, + ) + } + + async def _mcp_local_wait_until_ready( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = self._mcp_name(payload, "mcp.local.wait_until_ready") + timeout = self._mcp_timeout(payload, "mcp.local.wait_until_ready") + return { + "server": await self._plugin_bridge.wait_for_local_mcp_server( + plugin_id, + name, + timeout=timeout, + ) + } + + async def _mcp_session_open( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + name = 
self._mcp_name(payload, "mcp.session.open") + config = self._mcp_config(payload, "mcp.session.open") + timeout = self._mcp_timeout(payload, "mcp.session.open") + session_id, tools = await self._plugin_bridge.open_temporary_mcp_session( + plugin_id, + name=name, + config=config, + timeout=timeout, + ) + return {"session_id": session_id, "tools": tools} + + async def _mcp_session_list_tools( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + return { + "tools": self._plugin_bridge.get_temporary_mcp_session_tools( + plugin_id, + session_id, + ) + } + + async def _mcp_session_call_tool( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + if not tool_name: + raise AstrBotError.invalid_input("mcp.session.call_tool requires tool_name") + args = payload.get("args") + if not isinstance(args, dict): + raise AstrBotError.invalid_input( + "mcp.session.call_tool requires args object" + ) + result = await self._plugin_bridge.call_temporary_mcp_tool( + plugin_id, + session_id=session_id, + tool_name=tool_name, + arguments=dict(args), + ) + return {"result": result} + + async def _mcp_session_close( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + session_id = str(payload.get("session_id", "")).strip() + await self._plugin_bridge.close_temporary_mcp_session(plugin_id, session_id) + return {} + + async def _mcp_global_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.register") + name = self._mcp_name(payload, "mcp.global.register") + config 
= self._mcp_config(payload, "mcp.global.register") + timeout = self._mcp_timeout(payload, "mcp.global.register") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.setdefault("mcpServers", {}) + if not isinstance(servers, dict): + raise AstrBotError.invalid_input("Invalid global MCP config shape") + if name in servers: + raise AstrBotError.invalid_input( + f"Global MCP server already exists: {name}" + ) + normalized_config = dict(config) + normalized_config.setdefault("active", True) + servers[name] = normalized_config + func_tool_manager.save_mcp_config(config_payload) + if bool(normalized_config.get("active", True)): + await func_tool_manager.enable_mcp_server( + name, normalized_config, timeout=timeout + ) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="register", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_global_mcp_ack(request_id, "mcp.global.get") + name = self._mcp_name(payload, "mcp.global.get") + return {"server": self._get_global_mcp_record(name)} + + async def _mcp_global_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_global_mcp_ack(request_id, "mcp.global.list") + return {"servers": self._list_global_mcp_records()} + + async def _mcp_global_enable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.enable") + name = self._mcp_name(payload, "mcp.global.enable") + timeout = self._mcp_timeout(payload, "mcp.global.enable") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if ( + not 
isinstance(servers, dict) + or name not in servers + or not isinstance(servers[name], dict) + ): + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + servers[name]["active"] = True + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.enable_mcp_server( + name, dict(servers[name]), timeout=timeout + ) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="enable", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_disable( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.disable") + name = self._mcp_name(payload, "mcp.global.disable") + func_tool_manager = self._func_tool_manager() + config_payload = func_tool_manager.load_mcp_config() + servers = config_payload.get("mcpServers") + if ( + not isinstance(servers, dict) + or name not in servers + or not isinstance(servers[name], dict) + ): + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + servers[name]["active"] = False + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.disable_mcp_server(name) + record = self._get_global_mcp_record(name) + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="disable", + server_name=name, + request_id=request_id, + ) + return {"server": record} + + async def _mcp_global_unregister( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._require_global_mcp_ack(request_id, "mcp.global.unregister") + name = self._mcp_name(payload, "mcp.global.unregister") + func_tool_manager = self._func_tool_manager() + existing_record = self._get_global_mcp_record(name) + if existing_record is None: + raise AstrBotError.invalid_input(f"Unknown global MCP server: {name}") + config_payload = func_tool_manager.load_mcp_config() + 
servers = config_payload.get("mcpServers") + if not isinstance(servers, dict): + raise AstrBotError.invalid_input("Invalid global MCP config shape") + servers.pop(name, None) + func_tool_manager.save_mcp_config(config_payload) + await func_tool_manager.disable_mcp_server(name) + existing_record["running"] = False + self._audit_global_mcp_mutation( + plugin_id=plugin_id, + action="unregister", + server_name=name, + request_id=request_id, + ) + return {"server": existing_record} + + async def _internal_mcp_local_execute( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = str(payload.get("plugin_id", "")).strip() + server_name = str(payload.get("server_name", "")).strip() + tool_name = str(payload.get("tool_name", "")).strip() + tool_args = payload.get("tool_args") + if not plugin_id or not server_name or not tool_name: + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires plugin_id, server_name, and tool_name" + ) + if not isinstance(tool_args, dict): + raise AstrBotError.invalid_input( + "internal.mcp.local.execute requires tool_args object" + ) + return await self._plugin_bridge.execute_local_mcp_tool( + plugin_id, + server_name=server_name, + tool_name=tool_name, + tool_args=dict(tool_args), + ) + + def _register_mcp_capabilities(self) -> None: + self.register( + self._builtin_descriptor("mcp.local.get", "Get local MCP server"), + call_handler=self._mcp_local_get, + ) + self.register( + self._builtin_descriptor("mcp.local.list", "List local MCP servers"), + call_handler=self._mcp_local_list, + ) + self.register( + self._builtin_descriptor("mcp.local.enable", "Enable local MCP server"), + call_handler=self._mcp_local_enable, + ) + self.register( + self._builtin_descriptor("mcp.local.disable", "Disable local MCP server"), + call_handler=self._mcp_local_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.local.wait_until_ready", + "Wait until local MCP server is ready", + ), + 
call_handler=self._mcp_local_wait_until_ready, + ) + self.register( + self._builtin_descriptor("mcp.session.open", "Open temporary MCP session"), + call_handler=self._mcp_session_open, + ) + self.register( + self._builtin_descriptor( + "mcp.session.list_tools", + "List temporary MCP session tools", + ), + call_handler=self._mcp_session_list_tools, + ) + self.register( + self._builtin_descriptor( + "mcp.session.call_tool", + "Call tool on temporary MCP session", + ), + call_handler=self._mcp_session_call_tool, + ) + self.register( + self._builtin_descriptor( + "mcp.session.close", "Close temporary MCP session" + ), + call_handler=self._mcp_session_close, + ) + self.register( + self._builtin_descriptor( + "mcp.global.register", "Register global MCP server" + ), + call_handler=self._mcp_global_register, + ) + self.register( + self._builtin_descriptor("mcp.global.get", "Get global MCP server"), + call_handler=self._mcp_global_get, + ) + self.register( + self._builtin_descriptor("mcp.global.list", "List global MCP servers"), + call_handler=self._mcp_global_list, + ) + self.register( + self._builtin_descriptor("mcp.global.enable", "Enable global MCP server"), + call_handler=self._mcp_global_enable, + ) + self.register( + self._builtin_descriptor("mcp.global.disable", "Disable global MCP server"), + call_handler=self._mcp_global_disable, + ) + self.register( + self._builtin_descriptor( + "mcp.global.unregister", + "Unregister global MCP server", + ), + call_handler=self._mcp_global_unregister, + ) + self.register( + self._builtin_descriptor( + "internal.mcp.local.execute", + "Execute local MCP tool", + ), + call_handler=self._internal_mcp_local_execute, + exposed=False, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/message_history.py b/astrbot/core/sdk_bridge/capabilities/message_history.py new file mode 100644 index 0000000000..ebcdb74378 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/message_history.py @@ -0,0 +1,302 @@ +from __future__ import 
annotations + +from datetime import datetime +from typing import Any + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.message.components import component_to_payload_sync + +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform_message_history_mgr import MessageHistorySender + +from ._host import CapabilityMixinHost + + +def _core_message_type_from_sdk(value: str) -> MessageType: + normalized = str(value).strip().lower() + if normalized == "group": + return MessageType.GROUP_MESSAGE + if normalized == "private": + return MessageType.FRIEND_MESSAGE + if normalized == "other": + return MessageType.OTHER_MESSAGE + raise AstrBotError.invalid_input( + f"Unsupported message history message_type: {value}" + ) + + +def _sdk_message_type_from_core(value: MessageType | str) -> str: + if isinstance(value, MessageType): + if value == MessageType.GROUP_MESSAGE: + return "group" + if value == MessageType.FRIEND_MESSAGE: + return "private" + return "other" + return str(value).strip().lower() + + +class MessageHistoryCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _typed_message_history_session(payload: Any) -> MessageSession: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + "message_history capabilities require a session object" + ) + platform_id = str(payload.get("platform_id", "")).strip() + message_type = str(payload.get("message_type", "")).strip() + session_id = str(payload.get("session_id", "")).strip() + if not platform_id or not message_type or not session_id: + raise AstrBotError.invalid_input( + "message_history session requires platform_id, message_type, and session_id" + ) + return MessageSession( + platform_name=platform_id, + message_type=_core_message_type_from_sdk(message_type), + session_id=session_id, + ) + + @staticmethod + def _serialize_session(session: MessageSession) -> dict[str, str]: + return { + 
"platform_id": str(session.platform_id), + "message_type": _sdk_message_type_from_core(session.message_type), + "session_id": str(session.session_id), + } + + def _serialize_message_history_record(self, record: Any) -> dict[str, Any] | None: + if record is None: + return None + session = getattr(record, "session", None) + sender = getattr(record, "sender", None) + parts = getattr(record, "parts", None) + return { + "id": int(getattr(record, "id", 0) or 0), + "session": ( + self._serialize_session(session) + if isinstance(session, MessageSession) + else {} + ), + "sender": { + "sender_id": ( + str(getattr(sender, "sender_id", "")) + if getattr(sender, "sender_id", None) is not None + else None + ), + "sender_name": ( + str(getattr(sender, "sender_name", "")) + if getattr(sender, "sender_name", None) is not None + else None + ), + }, + "parts": ( + [component_to_payload_sync(part) for part in parts] + if isinstance(parts, list) + else [] + ), + "metadata": ( + dict(getattr(record, "metadata", {})) + if isinstance(getattr(record, "metadata", None), dict) + else {} + ), + "created_at": self._to_iso_datetime(getattr(record, "created_at", None)), + "updated_at": self._to_iso_datetime(getattr(record, "updated_at", None)), + "idempotency_key": ( + str(getattr(record, "idempotency_key", "")) + if getattr(record, "idempotency_key", None) is not None + else None + ), + } + + @staticmethod + def _parse_boundary(raw_value: Any, field_name: str) -> datetime: + text = str(raw_value or "").strip() + if not text: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires {field_name}" + ) + try: + return datetime.fromisoformat(text) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"message_history.{field_name} requires an ISO datetime string" + ) from exc + + async def _message_history_list( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = 
self._typed_message_history_session(payload.get("session")) + raw_limit = self._optional_int(payload.get("limit")) + limit = 50 if raw_limit is None else raw_limit + if limit < 1: + raise AstrBotError.invalid_input("message_history.list requires limit >= 1") + page = await self._star_context.message_history_manager.list( + session, + cursor=( + str(payload.get("cursor")) + if payload.get("cursor") is not None + else None + ), + limit=limit, + ) + return { + "page": { + "records": [ + item + for item in ( + self._serialize_message_history_record(record) + for record in page.records + ) + if item is not None + ], + "next_cursor": page.next_cursor, + "total": page.total, + } + } + + async def _message_history_get_by_id( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + record_id = self._optional_int(payload.get("record_id")) + if record_id is None or record_id < 1: + raise AstrBotError.invalid_input( + "message_history.get_by_id requires record_id >= 1" + ) + record = await self._star_context.message_history_manager.get_by_id( + session, + record_id, + ) + return {"record": self._serialize_message_history_record(record)} + + async def _message_history_append( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + sender_payload = payload.get("sender") + if not isinstance(sender_payload, dict): + raise AstrBotError.invalid_input( + "message_history.append requires sender object" + ) + parts_payload = payload.get("parts") + if not isinstance(parts_payload, list) or any( + not isinstance(item, dict) for item in parts_payload + ): + raise AstrBotError.invalid_input( + "message_history.append requires parts array" + ) + metadata = payload.get("metadata") + if metadata is not None and not isinstance(metadata, dict): + raise 
AstrBotError.invalid_input( + "message_history.append requires metadata object when provided" + ) + record = await self._star_context.message_history_manager.append( + session, + parts=self._build_core_message_chain(parts_payload).chain, + sender=MessageHistorySender( + sender_id=( + str(sender_payload.get("sender_id")) + if sender_payload.get("sender_id") is not None + else None + ), + sender_name=( + str(sender_payload.get("sender_name")) + if sender_payload.get("sender_name") is not None + else None + ), + ), + metadata=dict(metadata or {}), + idempotency_key=( + str(payload.get("idempotency_key")) + if payload.get("idempotency_key") is not None + else None + ), + ) + return {"record": self._serialize_message_history_record(record)} + + async def _message_history_delete_before( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_before( + session, + before=self._parse_boundary(payload.get("before"), "delete_before"), + ) + return {"deleted_count": int(deleted_count)} + + async def _message_history_delete_after( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_after( + session, + after=self._parse_boundary(payload.get("after"), "delete_after"), + ) + return {"deleted_count": int(deleted_count)} + + async def _message_history_delete_all( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + session = self._typed_message_history_session(payload.get("session")) + deleted_count = await self._star_context.message_history_manager.delete_all( + session + ) + return {"deleted_count": int(deleted_count)} + + def _register_message_history_capabilities(self) -> 
None: + self.register( + self._builtin_descriptor("message_history.list", "List message history"), + call_handler=self._message_history_list, + ) + self.register( + self._builtin_descriptor( + "message_history.get_by_id", + "Get message history by id", + ), + call_handler=self._message_history_get_by_id, + ) + self.register( + self._builtin_descriptor( + "message_history.append", "Append message history" + ), + call_handler=self._message_history_append, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_before", + "Delete message history before timestamp", + ), + call_handler=self._message_history_delete_before, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_after", + "Delete message history after timestamp", + ), + call_handler=self._message_history_delete_after, + ) + self.register( + self._builtin_descriptor( + "message_history.delete_all", + "Delete all message history in session", + ), + call_handler=self._message_history_delete_all, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/permission.py b/astrbot/core/sdk_bridge/capabilities/permission.py new file mode 100644 index 0000000000..e1d1a907aa --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/permission.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class PermissionCapabilityMixin(CapabilityMixinHost): + def _register_permission_capabilities(self) -> None: + self.register( + self._builtin_descriptor("permission.check", "Check user permission role"), + call_handler=self._permission_check, + ) + self.register( + self._builtin_descriptor("permission.get_admins", "List admin ids"), + call_handler=self._permission_get_admins, + ) + self.register( + self._builtin_descriptor( + "permission.manager.add_admin", + "Add admin id", + ), + call_handler=self._permission_manager_add_admin, + ) + self.register( + 
self._builtin_descriptor( + "permission.manager.remove_admin", + "Remove admin id", + ), + call_handler=self._permission_manager_remove_admin, + ) + + @staticmethod + def _normalize_admin_ids(values: Any) -> list[str]: + if not isinstance(values, list): + return [] + normalized: list[str] = [] + for item in values: + user_id = str(item).strip() + if user_id: + normalized.append(user_id) + return normalized + + def _permission_config(self) -> Any: + get_config = getattr(self._star_context, "get_config", None) + if callable(get_config): + return get_config() + config = getattr(self._star_context, "_config", None) + if config is not None: + return config + raise AstrBotError.invalid_input("permission capabilities require core config") + + def _admin_ids_snapshot(self, config: Any) -> list[str]: + admins = self._normalize_admin_ids( + config.get("admins_id", []) if hasattr(config, "get") else [] + ) + config["admins_id"] = list(admins) + return admins + + @staticmethod + def _save_config(config: Any) -> None: + save_config = getattr(config, "save_config", None) + if callable(save_config): + save_config() + + @staticmethod + def _required_user_id(payload: dict[str, Any], capability_name: str) -> str: + user_id = str(payload.get("user_id", "")).strip() + if not user_id: + raise AstrBotError.invalid_input(f"{capability_name} requires user_id") + return user_id + + def _require_admin_event_context( + self, + request_id: str, + payload: dict[str, Any], + capability_name: str, + ) -> None: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or bool( + getattr(request_context, "cancelled", False) + ): + raise AstrBotError.invalid_input( + f"{capability_name} requires an active event context" + ) + event = getattr(request_context, "event", None) + if event is None or not callable(getattr(event, "is_admin", None)): + raise AstrBotError.invalid_input( + f"{capability_name} requires an active event context" + ) + if not 
bool(event.is_admin()): + raise AstrBotError.invalid_input( + f"{capability_name} requires admin privileges" + ) + + async def _permission_check( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + user_id = self._required_user_id(payload, "permission.check") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + is_admin = user_id in admins + return { + "is_admin": is_admin, + "role": "admin" if is_admin else "member", + } + + async def _permission_get_admins( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + config = self._permission_config() + return {"admins": self._admin_ids_snapshot(config)} + + async def _permission_manager_add_admin( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "permission.manager.add_admin") + self._require_admin_event_context( + request_id, + payload, + "permission.manager.add_admin", + ) + user_id = self._required_user_id(payload, "permission.manager.add_admin") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + if user_id in admins: + return {"changed": False} + admins.append(user_id) + config["admins_id"] = admins + self._save_config(config) + return {"changed": True} + + async def _permission_manager_remove_admin( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "permission.manager.remove_admin") + self._require_admin_event_context( + request_id, + payload, + "permission.manager.remove_admin", + ) + user_id = self._required_user_id(payload, "permission.manager.remove_admin") + config = self._permission_config() + admins = self._admin_ids_snapshot(config) + if user_id not in admins: + return {"changed": False} + admins.remove(user_id) + config["admins_id"] = admins + self._save_config(config) + return {"changed": True} diff --git 
a/astrbot/core/sdk_bridge/capabilities/persona.py b/astrbot/core/sdk_bridge/capabilities/persona.py new file mode 100644 index 0000000000..94db89cabb --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/persona.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from astrbot_sdk.errors import AstrBotError + +from ._host import CapabilityMixinHost + + +class PersonaCapabilityMixin(CapabilityMixinHost): + def _register_persona_capabilities(self) -> None: + self.register( + self._builtin_descriptor("persona.get", "Get persona"), + call_handler=self._persona_get, + ) + self.register( + self._builtin_descriptor("persona.list", "List personas"), + call_handler=self._persona_list, + ) + self.register( + self._builtin_descriptor("persona.create", "Create persona"), + call_handler=self._persona_create, + ) + self.register( + self._builtin_descriptor("persona.update", "Update persona"), + call_handler=self._persona_update, + ) + self.register( + self._builtin_descriptor("persona.delete", "Delete persona"), + call_handler=self._persona_delete, + ) + + async def _persona_get( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + persona_id = str(payload.get("persona_id", "")).strip() + try: + persona = await self._star_context.persona_manager.get_persona(persona_id) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"persona": self._serialize_persona(persona)} + + async def _persona_list( + self, + _request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, object]: + personas = await self._star_context.persona_manager.get_all_personas() + return { + "personas": [ + payload + for payload in ( + self._serialize_persona(persona) for persona in personas + ) + if payload is not None + ] + } + + async def _persona_create( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_persona = payload.get("persona") + if not 
isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.create requires persona object") + try: + persona = await self._star_context.persona_manager.create_persona( + persona_id=str(raw_persona.get("persona_id", "")), + system_prompt=str(raw_persona.get("system_prompt", "")), + begin_dialogs=self._normalize_persona_dialogs( + raw_persona.get("begin_dialogs") + ), + tools=( + [str(item) for item in raw_persona.get("tools", [])] + if isinstance(raw_persona.get("tools"), list) + else None + ), + skills=( + [str(item) for item in raw_persona.get("skills", [])] + if isinstance(raw_persona.get("skills"), list) + else None + ), + custom_error_message=( + str(raw_persona.get("custom_error_message")) + if raw_persona.get("custom_error_message") is not None + else None + ), + folder_id=( + str(raw_persona.get("folder_id")) + if raw_persona.get("folder_id") is not None + else None + ), + sort_order=int(raw_persona.get("sort_order", 0)), + ) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {"persona": self._serialize_persona(persona)} + + async def _persona_update( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + raw_persona = payload.get("persona") + if not isinstance(raw_persona, dict): + raise AstrBotError.invalid_input("persona.update requires persona object") + persona = await self._star_context.persona_manager.update_persona( + persona_id=str(payload.get("persona_id", "")), + system_prompt=raw_persona.get("system_prompt"), + begin_dialogs=( + self._normalize_persona_dialogs(raw_persona.get("begin_dialogs")) + if "begin_dialogs" in raw_persona + else None + ), + tools=( + [str(item) for item in raw_persona.get("tools", [])] + if isinstance(raw_persona.get("tools"), list) + else raw_persona.get("tools") + ), + skills=( + [str(item) for item in raw_persona.get("skills", [])] + if isinstance(raw_persona.get("skills"), list) + else raw_persona.get("skills") + ), + 
custom_error_message=raw_persona.get("custom_error_message"), + ) + return {"persona": self._serialize_persona(persona)} + + async def _persona_delete( + self, + _request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, object]: + persona_id = str(payload.get("persona_id", "")).strip() + try: + await self._star_context.persona_manager.delete_persona(persona_id) + except ValueError as exc: + raise AstrBotError.invalid_input(str(exc)) from exc + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/platform.py b/astrbot/core/sdk_bridge/capabilities/platform.py new file mode 100644 index 0000000000..68668ababc --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/platform.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.message.components import Image, Plain +from astrbot.core.message.message_event_result import MessageChain + +from ._host import CapabilityMixinHost + + +class PlatformCapabilityMixin(CapabilityMixinHost): + def _register_platform_capabilities(self) -> None: + self.register( + self._builtin_descriptor("platform.send", "Send plain text"), + call_handler=self._platform_send, + ) + self.register( + self._builtin_descriptor("platform.send_image", "Send image"), + call_handler=self._platform_send_image, + ) + self.register( + self._builtin_descriptor("platform.send_chain", "Send message chain"), + call_handler=self._platform_send_chain, + ) + self.register( + self._builtin_descriptor( + "platform.send_by_session", + "Send message chain to a specific session", + ), + call_handler=self._platform_send_by_session, + ) + self.register( + self._builtin_descriptor("platform.get_group", "Get current group data"), + call_handler=self._platform_get_group, + ) + self.register( + self._builtin_descriptor("platform.get_members", "Get group members"), + call_handler=self._platform_get_members, + ) + self.register( + 
self._builtin_descriptor( + "platform.list_instances", + "List available platform instances", + ), + call_handler=self._platform_list_instances, + ) + + def _register_platform_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "platform.manager.get_by_id", + "Get platform management snapshot by id", + ), + call_handler=self._platform_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "platform.manager.clear_errors", + "Clear platform error records", + ), + call_handler=self._platform_manager_clear_errors, + ) + self.register( + self._builtin_descriptor( + "platform.manager.get_stats", + "Get platform stats by id", + ), + call_handler=self._platform_manager_get_stats, + ) + + async def _platform_send( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + await self._star_context.send_message( + session, + MessageChain([Plain(str(payload.get("text", "")), convert=False)]), + ) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async def _platform_send_image( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_image", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + image_url = str(payload.get("image_url", "")) + component = ( + Image.fromURL(image_url) + if image_url.startswith(("http://", "https://")) + else Image.fromFileSystem(image_url) + ) + await self._star_context.send_message(session, MessageChain([component])) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async 
def _platform_send_chain( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session, dispatch_token = self._resolve_dispatch_target(request_id, payload) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_chain", + ) + self._plugin_bridge.before_platform_send(dispatch_token) + chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "platform.send_chain requires a chain array" + ) + await self._star_context.send_message( + session, + self._build_core_message_chain(chain_payload), + ) + return {"message_id": self._plugin_bridge.mark_platform_send(dispatch_token)} + + async def _platform_send_by_session( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "platform.send_by_session requires a chain array" + ) + session = str(payload.get("session", "")) + if not session: + raise AstrBotError.invalid_input( + "platform.send_by_session requires a session" + ) + self._require_platform_support_for_session( + request_id, + session, + "platform.send_by_session", + ) + request_context = self._resolve_event_request_context(request_id, payload) + dispatch_token = None + if request_context is not None and not request_context.cancelled: + dispatch_token = request_context.dispatch_token + self._plugin_bridge.before_platform_send(dispatch_token) + await self._star_context.send_message( + session, + self._build_core_message_chain(chain_payload), + ) + if dispatch_token is not None: + return { + "message_id": self._plugin_bridge.mark_platform_send(dispatch_token) + } + return {"message_id": f"sdk_proactive_{uuid.uuid4().hex}"} + + async def _platform_get_group( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = 
self._resolve_current_group_request_context( + request_id, payload + ) + if request_context is None: + return {"group": None} + group = await request_context.event.get_group() + return {"group": self._serialize_group(group)} + + async def _platform_get_members( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_current_group_request_context( + request_id, payload + ) + if request_context is None: + return {"members": []} + group = await request_context.event.get_group() + serialized_group = self._serialize_group(group) + if serialized_group is None: + return {"members": []} + members = serialized_group.get("members") + return {"members": list(members) if isinstance(members, list) else []} + + async def _platform_list_instances( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + platform_manager = getattr(self._star_context, "platform_manager", None) + if platform_manager is None or not hasattr(platform_manager, "get_insts"): + return {"platforms": []} + platforms_payload: list[dict[str, Any]] = [] + for platform in list(platform_manager.get_insts()): + meta = None + try: + meta = platform.meta() + except Exception: + continue + platform_id = str(getattr(meta, "id", "")).strip() + platform_type = str(getattr(meta, "name", "")).strip() + if not platform_id or not platform_type: + continue + if not self._plugin_supports_platform(plugin_id, platform_type): + continue + status = getattr(platform, "status", None) + status_value = getattr(status, "value", status) + display_name = str( + getattr(meta, "adapter_display_name", None) or platform_type + ) + platforms_payload.append( + { + "id": platform_id, + "name": display_name, + "type": platform_type, + "status": str(status_value or "unknown"), + } + ) + return {"platforms": platforms_payload} + + async def _platform_manager_get_by_id( + self, + request_id: str, + 
payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.get_by_id", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + return {"platform": self._serialize_platform_snapshot(platform)} + + async def _platform_manager_clear_errors( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.clear_errors", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + if platform is None: + raise AstrBotError.invalid_input("Unknown platform_id") + clear_errors = getattr(platform, "clear_errors", None) + if callable(clear_errors): + clear_errors() + return {} + + async def _platform_manager_get_stats( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "platform.manager.get_stats", + ) + platform = self._get_platform_inst_by_id(str(payload.get("platform_id", ""))) + if platform is None: + return {"stats": None} + get_stats = getattr(platform, "get_stats", None) + if not callable(get_stats): + return {"stats": None} + return {"stats": self._serialize_platform_stats(get_stats())} diff --git a/astrbot/core/sdk_bridge/capabilities/provider.py b/astrbot/core/sdk_bridge/capabilities/provider.py new file mode 100644 index 0000000000..e5823cf6c3 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/provider.py @@ -0,0 +1,1335 @@ +from __future__ import annotations + +import asyncio +import base64 +import contextlib +import json +import uuid +from collections.abc import AsyncIterator +from typing import Any, cast + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.llm.entities import LLMToolSpec, ProviderMeta, ToolCallsResult +from astrbot_sdk.llm.entities import ProviderType as SDKProviderType +from astrbot_sdk.runtime.capability_router import StreamExecution + 
+from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from ..bridge_base import _get_runtime_provider_types, _get_runtime_tool_types +from ..event_converter import EventConverter +from ._host import CapabilityMixinHost + + +class ProviderCapabilityMixin(CapabilityMixinHost): + def _register_provider_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.get_using", "Get active provider"), + call_handler=self._provider_get_using, + ) + self.register( + self._builtin_descriptor("provider.get_by_id", "Get provider by id"), + call_handler=self._provider_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.get_current_chat_provider_id", + "Get active chat provider id", + ), + call_handler=self._provider_get_current_chat_provider_id, + ) + self.register( + self._builtin_descriptor("provider.list_all", "List chat providers"), + call_handler=self._provider_list_all, + ) + self.register( + self._builtin_descriptor("provider.list_all_tts", "List tts providers"), + call_handler=self._provider_list_all_tts, + ) + self.register( + self._builtin_descriptor("provider.list_all_stt", "List stt providers"), + call_handler=self._provider_list_all_stt, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_embedding", + "List embedding providers", + ), + call_handler=self._provider_list_all_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.list_all_rerank", + "List rerank providers", + ), + call_handler=self._provider_list_all_rerank, + ) + self.register( + self._builtin_descriptor( + "provider.get_using_tts", + "Get active tts provider", + ), + call_handler=self._provider_get_using_tts, + ) + self.register( + self._builtin_descriptor( + "provider.get_using_stt", + "Get active stt provider", + ), + call_handler=self._provider_get_using_stt, + ) + self.register( + self._builtin_descriptor( + "provider.stt.get_text", + "Transcribe audio with STT provider", + ), + 
call_handler=self._provider_stt_get_text, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio", + "Synthesize audio with TTS provider", + ), + call_handler=self._provider_tts_get_audio, + ) + self.register( + self._builtin_descriptor( + "provider.tts.support_stream", + "Check whether TTS provider supports native streaming", + ), + call_handler=self._provider_tts_support_stream, + ) + self.register( + self._builtin_descriptor( + "provider.tts.get_audio_stream", + "Stream audio with TTS provider", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_tts_get_audio_stream, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embedding", + "Get embedding vector", + ), + call_handler=self._provider_embedding_get_embedding, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_embeddings", + "Get embedding vectors in batch", + ), + call_handler=self._provider_embedding_get_embeddings, + ) + self.register( + self._builtin_descriptor( + "provider.embedding.get_dim", + "Get embedding dimension", + ), + call_handler=self._provider_embedding_get_dim, + ) + self.register( + self._builtin_descriptor( + "provider.rerank.rerank", + "Rerank documents", + ), + call_handler=self._provider_rerank_rerank, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.get", + "Get registered and active sdk llm tools", + ), + call_handler=self._llm_tool_manager_get, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.activate", + "Activate sdk llm tool", + ), + call_handler=self._llm_tool_manager_activate, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.deactivate", + "Deactivate sdk llm tool", + ), + call_handler=self._llm_tool_manager_deactivate, + ) + self.register( + self._builtin_descriptor( + "llm_tool.manager.add", + "Register sdk llm tool metadata", + ), + call_handler=self._llm_tool_manager_add, + ) + self.register( + self._builtin_descriptor( 
+ "llm_tool.manager.remove", + "Unregister sdk llm tool metadata", + ), + call_handler=self._llm_tool_manager_remove, + ) + self.register( + self._builtin_descriptor("agent.tool_loop.run", "Run sdk tool loop agent"), + call_handler=self._agent_tool_loop_run, + ) + self.register( + self._builtin_descriptor("agent.registry.list", "List sdk agents"), + call_handler=self._agent_registry_list, + ) + self.register( + self._builtin_descriptor("agent.registry.get", "Get sdk agent"), + call_handler=self._agent_registry_get, + ) + + def _register_provider_manager_capabilities(self) -> None: + self.register( + self._builtin_descriptor("provider.manager.set", "Set active provider"), + call_handler=self._provider_manager_set, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_by_id", + "Get managed provider record by id", + ), + call_handler=self._provider_manager_get_by_id, + ) + self.register( + self._builtin_descriptor( + "provider.manager.get_merged_provider_config", + "Get merged managed provider config by id", + ), + call_handler=self._provider_manager_get_merged_provider_config, + ) + self.register( + self._builtin_descriptor( + "provider.manager.load", + "Load a provider instance without persisting config", + ), + call_handler=self._provider_manager_load, + ) + self.register( + self._builtin_descriptor( + "provider.manager.terminate", + "Terminate a loaded provider instance", + ), + call_handler=self._provider_manager_terminate, + ) + self.register( + self._builtin_descriptor( + "provider.manager.create", + "Create and load a provider config", + ), + call_handler=self._provider_manager_create, + ) + self.register( + self._builtin_descriptor( + "provider.manager.update", + "Update and reload a provider config", + ), + call_handler=self._provider_manager_update, + ) + self.register( + self._builtin_descriptor( + "provider.manager.delete", + "Delete a provider config", + ), + call_handler=self._provider_manager_delete, + ) + self.register( + 
self._builtin_descriptor( + "provider.manager.get_insts", + "List loaded chat provider instances", + ), + call_handler=self._provider_manager_get_insts, + ) + self.register( + self._builtin_descriptor( + "provider.manager.watch_changes", + "Stream provider change events", + supports_stream=True, + cancelable=True, + ), + stream_handler=self._provider_manager_watch_changes, + ) + + @staticmethod + def _provider_to_payload(provider: Any | None) -> dict[str, Any] | None: + if provider is None: + return None + meta = provider.meta() + return ProviderCapabilityMixin._provider_meta_to_payload(meta) + + @staticmethod + def _normalize_sdk_provider_type(value: Any) -> SDKProviderType: + if isinstance(value, SDKProviderType): + return value + raw_provider_type = getattr(value, "provider_type", value) + provider_type_value = ( + str(raw_provider_type.value) + if hasattr(raw_provider_type, "value") + else str(raw_provider_type) + ) + try: + return SDKProviderType(provider_type_value) + except ValueError: + return SDKProviderType.CHAT_COMPLETION + + @classmethod + def _provider_meta_to_payload(cls, meta: Any) -> dict[str, Any]: + provider_type = cls._normalize_sdk_provider_type(meta) + return ProviderMeta( + id=str(getattr(meta, "id", "")), + model=( + str(getattr(meta, "model", "")) + if getattr(meta, "model", None) is not None + else None + ), + type=str(getattr(meta, "type", "")), + provider_type=provider_type, + ).to_payload() + + @classmethod + def _managed_provider_from_config( + cls, + provider_config: dict[str, Any] | None, + *, + loaded: bool, + ) -> dict[str, Any] | None: + if not isinstance(provider_config, dict): + return None + provider_id = str(provider_config.get("id", "")).strip() + provider_type_text = str(provider_config.get("type", "")).strip() + if not provider_id or not provider_type_text: + return None + provider_type = cls._normalize_sdk_provider_type( + provider_config.get("provider_type", SDKProviderType.CHAT_COMPLETION.value) + ) + return { + "id": 
provider_id, + "model": ( + str(provider_config.get("model")) + if provider_config.get("model") is not None + else None + ), + "type": provider_type_text, + "provider_type": provider_type.value, + "loaded": bool(loaded), + "enabled": bool(provider_config.get("enable", True)), + "provider_source_id": ( + str(provider_config.get("provider_source_id")) + if provider_config.get("provider_source_id") is not None + else None + ), + } + + @classmethod + def _managed_provider_to_payload( + cls, provider: Any | None + ) -> dict[str, Any] | None: + if provider is None: + return None + meta_payload = cls._provider_to_payload(provider) + if meta_payload is None: + return None + provider_config = getattr(provider, "provider_config", None) + return { + **meta_payload, + "loaded": True, + "enabled": bool( + provider_config.get("enable", True) + if isinstance(provider_config, dict) + else True + ), + "provider_source_id": ( + str(provider_config.get("provider_source_id")) + if isinstance(provider_config, dict) + and provider_config.get("provider_source_id") is not None + else None + ), + } + + def _find_provider_config_by_id(self, provider_id: str) -> dict[str, Any] | None: + provider_manager = getattr(self._star_context, "provider_manager", None) + providers_config = getattr(provider_manager, "providers_config", None) + if not isinstance(providers_config, list): + return None + for item in providers_config: + if not isinstance(item, dict): + continue + if str(item.get("id", "")).strip() == provider_id: + return dict(item) + return None + + def _managed_provider_payload_by_id( + self, + provider_id: str, + *, + fallback_config: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + normalized_provider_id = str(provider_id).strip() + if not normalized_provider_id: + return None + provider = self._star_context.get_provider_by_id(normalized_provider_id) + payload = self._managed_provider_to_payload(provider) + if payload is not None: + return payload + provider_config = 
self._find_provider_config_by_id(normalized_provider_id) + if provider_config is None: + provider_config = ( + dict(fallback_config) if isinstance(fallback_config, dict) else None + ) + return self._managed_provider_from_config(provider_config, loaded=False) + + def _resolve_current_chat_provider_id( + self, + request_context: Any | None, + ) -> str | None: + if request_context is None: + return None + provider = self._star_context.get_using_provider( + request_context.event.unified_msg_origin + ) + if provider is None: + return None + meta = provider.meta() + return str(getattr(meta, "id", "") or "") + + async def _provider_get_using( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + async def _provider_get_current_chat_provider_id( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_provider(payload.get("umo")) + if provider is None: + return {"provider_id": None} + return {"provider_id": str(provider.meta().id)} + + async def _provider_get_by_id( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._get_provider_by_id(payload, "provider.get_by_id") + return {"provider": self._provider_to_payload(provider)} + + def _provider_list_payload(self, providers: list[Any]) -> dict[str, Any]: + return { + "providers": [ + payload + for payload in ( + self._provider_to_payload(provider) for provider in providers + ) + if payload is not None + ] + } + + async def _provider_list_all( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload(self._star_context.get_all_providers()) + + async def _provider_list_all_tts( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + 
return self._provider_list_payload(self._star_context.get_all_tts_providers()) + + async def _provider_list_all_stt( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload(self._star_context.get_all_stt_providers()) + + async def _provider_list_all_embedding( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload( + self._star_context.get_all_embedding_providers() + ) + + async def _provider_list_all_rerank( + self, + _request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + return self._provider_list_payload( + self._star_context.get_all_rerank_providers() + ) + + async def _provider_get_using_tts( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_tts_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + async def _provider_get_using_stt( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + provider = self._star_context.get_using_stt_provider(payload.get("umo")) + return {"provider": self._provider_to_payload(provider)} + + @staticmethod + def _tts_stream_texts_from_payload(payload: dict[str, Any]) -> list[str]: + text = payload.get("text") + if isinstance(text, str): + return [text] + text_chunks = payload.get("text_chunks") + if isinstance(text_chunks, list): + chunks = [str(item) for item in text_chunks] + if chunks: + return chunks + raise AstrBotError.invalid_input( + "provider.tts.get_audio_stream requires text or text_chunks" + ) + + def _get_provider_by_id( + self, + payload: dict[str, Any], + capability_name: str, + ) -> Any: + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + f"{capability_name} requires provider_id", + ) + provider = 
self._star_context.get_provider_by_id(provider_id) + if provider is None: + raise AstrBotError.invalid_input( + f"{capability_name} unknown provider_id: {provider_id}", + ) + return provider + + def _get_typed_provider( + self, + payload: dict[str, Any], + capability_name: str, + provider_label: str, + expected_type: type[Any], + ) -> Any: + provider = self._get_provider_by_id(payload, capability_name) + if not isinstance(provider, expected_type): + raise AstrBotError.invalid_input( + f"{capability_name} requires a {provider_label} provider", + ) + return provider + + async def _provider_stt_get_text( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stt_provider_cls, _, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.stt.get_text", + "speech_to_text", + stt_provider_cls, + ) + return {"text": await provider.get_text(str(payload.get("audio_url", "")))} + + async def _provider_tts_get_audio( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.get_audio", + "text_to_speech", + tts_provider_cls, + ) + return {"audio_path": await provider.get_audio(str(payload.get("text", "")))} + + async def _provider_tts_support_stream( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.support_stream", + "text_to_speech", + tts_provider_cls, + ) + return {"supported": bool(provider.support_stream())} + + async def _provider_tts_get_audio_stream( + self, + _request_id: str, + payload: dict[str, Any], + token, + ) -> StreamExecution: + _, tts_provider_cls, _, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.tts.get_audio_stream", 
+ "text_to_speech", + tts_provider_cls, + ) + texts = self._tts_stream_texts_from_payload(payload) + text_queue: asyncio.Queue[str | None] = asyncio.Queue() + audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue() + for text in texts: + await text_queue.put(text) + await text_queue.put(None) + state: dict[str, BaseException] = {} + + async def producer() -> None: + try: + await provider.get_audio_stream(text_queue, audio_queue) + except Exception as exc: # pragma: no cover - provider-specific failures + state["error"] = exc + finally: + await audio_queue.put(None) + + task = asyncio.create_task(producer()) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + token.raise_if_cancelled() + item = await audio_queue.get() + if item is None: + break + chunk_text: str | None = None + chunk_audio: bytes | bytearray + if isinstance(item, tuple): + chunk_text = str(item[0]) + chunk_audio = item[1] + else: + chunk_audio = item + yield { + "audio_base64": base64.b64encode(bytes(chunk_audio)).decode( + "ascii" + ), + "text": chunk_text, + } + error = state.get("error") + if error is not None: + raise error + finally: + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + else: + with contextlib.suppress(Exception): + await task + + def finalize(chunks: list[dict[str, Any]]) -> dict[str, Any]: + return chunks[-1] if chunks else {"audio_base64": "", "text": None} + + return StreamExecution(iterator=iterator(), finalize=finalize) + + async def _provider_embedding_get_embedding( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_embedding", + "embedding", + embedding_provider_cls, + ) + return {"embedding": await provider.get_embedding(str(payload.get("text", "")))} + + async def 
_provider_embedding_get_embeddings( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_embeddings", + "embedding", + embedding_provider_cls, + ) + texts = payload.get("texts") + if not isinstance(texts, list): + raise AstrBotError.invalid_input( + "provider.embedding.get_embeddings requires texts", + ) + return { + "embeddings": await provider.get_embeddings([str(item) for item in texts]) + } + + async def _provider_embedding_get_dim( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, embedding_provider_cls, _ = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.embedding.get_dim", + "embedding", + embedding_provider_cls, + ) + return {"dim": int(provider.get_dim())} + + async def _provider_rerank_rerank( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + _, _, _, rerank_provider_cls = _get_runtime_provider_types() + provider = self._get_typed_provider( + payload, + "provider.rerank.rerank", + "rerank", + rerank_provider_cls, + ) + documents = payload.get("documents") + if not isinstance(documents, list): + raise AstrBotError.invalid_input( + "provider.rerank.rerank requires documents", + ) + normalized_documents = [str(item) for item in documents] + top_n = payload.get("top_n") + results = await provider.rerank( + str(payload.get("query", "")), + normalized_documents, + int(top_n) if top_n is not None else None, + ) + serialized = [] + for item in results: + index = int(getattr(item, "index", 0)) + serialized.append( + { + "index": index, + "score": float(getattr(item, "relevance_score", 0.0)), + "document": normalized_documents[index] + if 0 <= index < len(normalized_documents) + else "", + } + ) + return {"results": serialized} + + @staticmethod + def 
_normalize_provider_config_payload( + payload: Any, + capability_name: str, + field_name: str, + ) -> dict[str, Any]: + if not isinstance(payload, dict): + raise AstrBotError.invalid_input( + f"{capability_name} requires {field_name} object" + ) + return dict(payload) + + @staticmethod + def _core_provider_type(value: Any, capability_name: str): + from astrbot.core.provider.entities import ProviderType as CoreProviderType + + normalized = str(value).strip() + try: + return CoreProviderType(normalized) + except ValueError as exc: + raise AstrBotError.invalid_input( + f"{capability_name} requires a valid provider_type" + ) from exc + + async def _provider_manager_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.set") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.set requires provider_id" + ) + await self._star_context.provider_manager.set_provider( + provider_id=provider_id, + provider_type=self._core_provider_type( + payload.get("provider_type"), + "provider.manager.set", + ), + umo=( + str(payload.get("umo")) + if payload.get("umo") is not None and str(payload.get("umo")).strip() + else None + ), + ) + return {} + + async def _provider_manager_get_by_id( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.get_by_id") + provider_id = str(payload.get("provider_id", "")).strip() + return {"provider": self._managed_provider_payload_by_id(provider_id)} + + async def _provider_manager_get_merged_provider_config( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin( + request_id, + "provider.manager.get_merged_provider_config", + ) + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + 
raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config requires provider_id" + ) + provider_manager = getattr(self._star_context, "provider_manager", None) + get_merged_provider_config = getattr( + provider_manager, + "get_merged_provider_config", + None, + ) + if provider_manager is None or not callable(get_merged_provider_config): + raise AstrBotError.invalid_input( + "Provider manager does not support merged config lookup" + ) + provider_config = self._find_provider_config_by_id(provider_id) + if provider_config is None: + raise AstrBotError.invalid_input( + "provider.manager.get_merged_provider_config unknown provider_id" + ) + merged_config = cast( + dict[str, Any], get_merged_provider_config(provider_config) + ) + return {"config": dict(merged_config)} + + async def _provider_manager_load( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.load") + provider_config = self._normalize_provider_config_payload( + payload.get("provider_config"), + "provider.manager.load", + "provider_config", + ) + await self._star_context.provider_manager.load_provider(provider_config) + provider_id = str(provider_config.get("id", "")).strip() + return { + "provider": self._managed_provider_payload_by_id( + provider_id, + fallback_config=provider_config, + ) + } + + async def _provider_manager_terminate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.terminate") + provider_id = str(payload.get("provider_id", "")).strip() + if not provider_id: + raise AstrBotError.invalid_input( + "provider.manager.terminate requires provider_id" + ) + await self._star_context.provider_manager.terminate_provider(provider_id) + return {} + + async def _provider_manager_create( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + 
self._require_reserved_plugin(request_id, "provider.manager.create") + provider_config = self._normalize_provider_config_payload( + payload.get("provider_config"), + "provider.manager.create", + "provider_config", + ) + await self._star_context.provider_manager.create_provider(provider_config) + provider_id = str(provider_config.get("id", "")).strip() + return {"provider": self._managed_provider_payload_by_id(provider_id)} + + async def _provider_manager_update( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.update") + origin_provider_id = str(payload.get("origin_provider_id", "")).strip() + if not origin_provider_id: + raise AstrBotError.invalid_input( + "provider.manager.update requires origin_provider_id" + ) + new_config = self._normalize_provider_config_payload( + payload.get("new_config"), + "provider.manager.update", + "new_config", + ) + await self._star_context.provider_manager.update_provider( + origin_provider_id, + new_config, + ) + target_provider_id = str(new_config.get("id") or origin_provider_id).strip() + return {"provider": self._managed_provider_payload_by_id(target_provider_id)} + + async def _provider_manager_delete( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.delete") + provider_id = ( + str(payload.get("provider_id")).strip() + if payload.get("provider_id") is not None + else None + ) + provider_source_id = ( + str(payload.get("provider_source_id")).strip() + if payload.get("provider_source_id") is not None + else None + ) + if not provider_id and not provider_source_id: + raise AstrBotError.invalid_input( + "provider.manager.delete requires provider_id or provider_source_id" + ) + await self._star_context.provider_manager.delete_provider( + provider_id=provider_id or None, + provider_source_id=provider_source_id or None, + ) + return {} + 
+ async def _provider_manager_get_insts( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + self._require_reserved_plugin(request_id, "provider.manager.get_insts") + provider_manager = getattr(self._star_context, "provider_manager", None) + if provider_manager is None or not hasattr(provider_manager, "get_insts"): + return {"providers": []} + return { + "providers": [ + payload + for payload in ( + self._managed_provider_to_payload(provider) + for provider in list(provider_manager.get_insts()) + ) + if payload is not None + ] + } + + async def _provider_manager_watch_changes( + self, + request_id: str, + _payload: dict[str, Any], + token, + ) -> StreamExecution: + self._require_reserved_plugin(request_id, "provider.manager.watch_changes") + provider_manager = getattr(self._star_context, "provider_manager", None) + if provider_manager is None or not hasattr( + provider_manager, "register_provider_change_hook" + ): + raise AstrBotError.invalid_input("Provider manager does not support hooks") + unregister_hook = getattr( + provider_manager, + "unregister_provider_change_hook", + None, + ) + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + loop = asyncio.get_running_loop() + + def hook(provider_id: str, provider_type: Any, umo: str | None) -> None: + event = { + "provider_id": str(provider_id), + "provider_type": self._normalize_sdk_provider_type(provider_type).value, + "umo": str(umo) if umo is not None else None, + } + loop.call_soon_threadsafe(queue.put_nowait, event) + + provider_manager.register_provider_change_hook(hook) + + async def iterator() -> AsyncIterator[dict[str, Any]]: + try: + while True: + token.raise_if_cancelled() + yield await queue.get() + finally: + if callable(unregister_hook): + unregister_hook(hook) + + return StreamExecution( + iterator=iterator(), + finalize=lambda _chunks: {}, + collect_chunks=False, + ) + + async def _llm_tool_manager_get( + self, + request_id: str, + _payload: dict[str, Any], 
+ _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "registered": [ + item.to_payload() + for item in self._plugin_bridge.get_registered_llm_tools(plugin_id) + ], + "active": [ + item.to_payload() + for item in self._plugin_bridge.get_active_llm_tools(plugin_id) + ], + } + + async def _llm_tool_manager_activate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "activated": self._plugin_bridge.activate_llm_tool( + plugin_id, str(payload.get("name", "")) + ) + } + + async def _llm_tool_manager_deactivate( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "deactivated": self._plugin_bridge.deactivate_llm_tool( + plugin_id, str(payload.get("name", "")) + ) + } + + async def _llm_tool_manager_add( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + tools_payload = payload.get("tools") + if not isinstance(tools_payload, list): + raise AstrBotError.invalid_input("llm_tool.manager.add requires tools list") + tools = [ + LLMToolSpec.from_payload(item) + for item in tools_payload + if isinstance(item, dict) + ] + return {"names": self._plugin_bridge.add_llm_tools(plugin_id, tools)} + + async def _llm_tool_manager_remove( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "removed": self._plugin_bridge.remove_llm_tool( + plugin_id, + str(payload.get("name", "")), + ) + } + + async def _agent_registry_list( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + return { + "agents": [ + item.to_payload() + for item in self._plugin_bridge.get_registered_agents(plugin_id) + 
] + } + + async def _agent_registry_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + agent = self._plugin_bridge.get_registered_agent( + plugin_id, str(payload.get("name", "")) + ) + return {"agent": agent.to_payload() if agent is not None else None} + + def _select_llm_tools_for_request( + self, + plugin_id: str, + payload: dict[str, Any], + ) -> list[LLMToolSpec]: + active_specs = { + item.name: item + for item in self._plugin_bridge.get_request_tool_specs(plugin_id) + } + requested = payload.get("tool_names") + if not isinstance(requested, list) or not requested: + return list(active_specs.values()) + names = [str(item) for item in requested if str(item).strip()] + return [active_specs[name] for name in names if name in active_specs] + + def _make_sdk_tool_handler( + self, + *, + plugin_id: str, + tool_spec: LLMToolSpec, + tool_call_timeout: int, + ): + async def _handler(event: AstrMessageEvent, **tool_args: Any) -> str | None: + record = self._plugin_bridge._records.get(plugin_id) + if record is None or record.session is None: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content="SDK plugin worker is unavailable", + success=False, + ).to_payload(), + ensure_ascii=False, + ) + request_id = f"sdk_tool_{plugin_id}_{uuid.uuid4().hex}" + dispatch_token = ( + self._plugin_bridge._get_dispatch_token(event) or uuid.uuid4().hex + ) + event_payload = EventConverter.core_to_sdk( + event, + dispatch_token=dispatch_token, + plugin_id=plugin_id, + request_id=request_id, + ) + call_payload = { + "plugin_id": plugin_id, + "tool_name": tool_spec.name, + "handler_ref": tool_spec.handler_ref, + "tool_args": json.loads( + json.dumps(tool_args, ensure_ascii=False, default=str) + ), + "event": event_payload, + } + try: + if tool_spec.handler_capability == "internal.mcp.local.execute": + handler_ref = json.loads(tool_spec.handler_ref or "{}") + output = await 
asyncio.wait_for( + self.execute( + "internal.mcp.local.execute", + { + "plugin_id": plugin_id, + "server_name": str( + handler_ref.get("server_name", "") + ).strip(), + "tool_name": str( + handler_ref.get("tool_name", "") + ).strip(), + "tool_args": call_payload["tool_args"], + }, + stream=False, + cancel_token=None, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + elif tool_spec.handler_capability: + output = await asyncio.wait_for( + record.session.invoke_capability( + tool_spec.handler_capability, + call_payload, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + else: + output = await asyncio.wait_for( + record.session.invoke_capability( + "internal.llm_tool.execute", + call_payload, + request_id=request_id, + ), + timeout=tool_call_timeout, + ) + except TimeoutError: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=( + f"Tool execution timeout after {tool_call_timeout} seconds" + ), + success=False, + ).to_payload(), + ensure_ascii=False, + ) + except Exception as exc: + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=f"Tool execution failed: {exc}", + success=False, + ).to_payload(), + ensure_ascii=False, + ) + if not isinstance(output, dict): + return str(output) + content = output.get("content") + if output.get("success", True): + # Keep None distinct from an empty string so tools can signal + # "no content" without fabricating a textual result. 
+ return None if content is None else str(content) + return json.dumps( + ToolCallsResult( + tool_name=tool_spec.name, + content=str(content or ""), + success=False, + ).to_payload(), + ensure_ascii=False, + ) + + return _handler + + def _build_sdk_toolset( + self, + *, + plugin_id: str, + payload: dict[str, Any], + tool_call_timeout: int, + ) -> Any | None: + tool_specs = self._select_llm_tools_for_request(plugin_id, payload) + if not tool_specs: + return None + function_tool_cls, tool_set_cls = _get_runtime_tool_types() + tool_set = tool_set_cls() + for tool_spec in tool_specs: + tool_set.add_tool( + function_tool_cls( + name=tool_spec.name, + description=tool_spec.description, + parameters=tool_spec.parameters_schema, + handler=self._make_sdk_tool_handler( + plugin_id=plugin_id, + tool_spec=tool_spec, + tool_call_timeout=tool_call_timeout, + ), + ) + ) + return tool_set + + def _llm_response_to_payload(self, response: Any) -> dict[str, Any]: + usage = None + if response.usage is not None: + usage = { + "input_tokens": response.usage.input, + "output_tokens": response.usage.output, + "total_tokens": response.usage.total, + } + return { + "text": response.completion_text, + "usage": usage, + "finish_reason": "tool_calls" if response.tools_call_ids else "stop", + "tool_calls": response.to_openai_tool_calls(), + "role": response.role, + "reasoning_content": response.reasoning_content or None, + "reasoning_signature": response.reasoning_signature, + } + + async def _agent_tool_loop_run( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None: + raise AstrBotError.invalid_input( + "tool_loop_agent currently requires a message-bound SDK request" + ) + provider_id = str( + payload.get("provider_id") or "" + ).strip() or self._resolve_current_chat_provider_id(request_context) + if not 
provider_id: + raise AstrBotError.invalid_input("No active chat provider is available") + tool_call_timeout = int(payload.get("tool_call_timeout") or 60) + llm_resp = await self._star_context.tool_loop_agent( + event=request_context.event, + chat_provider_id=provider_id, + prompt=( + str(payload.get("prompt")) + if payload.get("prompt") is not None + else None + ), + image_urls=[ + str(item) + for item in payload.get("image_urls", []) + if isinstance(item, str) + ], + tools=self._build_sdk_toolset( + plugin_id=plugin_id, + payload=payload, + tool_call_timeout=tool_call_timeout, + ), + system_prompt=str(payload.get("system_prompt") or ""), + contexts=[ + dict(item) + for item in payload.get("contexts", []) + if isinstance(item, dict) + ], + max_steps=int(payload.get("max_steps") or 30), + tool_call_timeout=tool_call_timeout, + ) + return self._llm_response_to_payload(llm_resp) diff --git a/astrbot/core/sdk_bridge/capabilities/session.py b/astrbot/core/sdk_bridge/capabilities/session.py new file mode 100644 index 0000000000..0f992ff757 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/session.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from ..bridge_base import _get_runtime_sp +from ._host import CapabilityMixinHost + + +class SessionCapabilityMixin(CapabilityMixinHost): + def _register_session_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "session.plugin.is_enabled", + "Get session plugin enabled state", + ), + call_handler=self._session_plugin_is_enabled, + ) + self.register( + self._builtin_descriptor( + "session.plugin.filter_handlers", + "Filter handler metadata by session plugin config", + ), + call_handler=self._session_plugin_filter_handlers, + ) + self.register( + self._builtin_descriptor( + "session.service.is_llm_enabled", + "Get session LLM enabled state", + ), + call_handler=self._session_service_is_llm_enabled, + ) + 
self.register( + self._builtin_descriptor( + "session.service.set_llm_status", + "Set session LLM enabled state", + ), + call_handler=self._session_service_set_llm_status, + ) + self.register( + self._builtin_descriptor( + "session.service.is_tts_enabled", + "Get session TTS enabled state", + ), + call_handler=self._session_service_is_tts_enabled, + ) + self.register( + self._builtin_descriptor( + "session.service.set_tts_status", + "Set session TTS enabled state", + ), + call_handler=self._session_service_set_tts_status, + ) + + async def _load_session_plugin_config(self, session_id: str) -> dict[str, Any]: + raw_config = await _get_runtime_sp().get_async( + scope="umo", + scope_id=session_id, + key="session_plugin_config", + default={}, + ) + return self._normalize_session_scoped_config(raw_config, session_id) + + async def _load_session_service_config(self, session_id: str) -> dict[str, Any]: + raw_config = await _get_runtime_sp().get_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + default={}, + ) + return self._normalize_session_scoped_config(raw_config, session_id) + + async def _session_plugin_is_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + plugin_name = str(payload.get("plugin_name", "")).strip() + config = await self._load_session_plugin_config(session_id) + enabled_plugins = { + str(item) for item in config.get("enabled_plugins", []) if str(item).strip() + } + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + if ( + plugin_name in disabled_plugins + and plugin_name not in self._reserved_plugin_names() + ): + return {"enabled": False} + if plugin_name in enabled_plugins: + return {"enabled": True} + return {"enabled": True} + + async def _session_plugin_filter_handlers( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + 
session_id = str(payload.get("session", "")).strip() + handlers = payload.get("handlers") + if not isinstance(handlers, list): + raise AstrBotError.invalid_input( + "session.plugin.filter_handlers requires a handlers array" + ) + config = await self._load_session_plugin_config(session_id) + disabled_plugins = { + str(item) + for item in config.get("disabled_plugins", []) + if str(item).strip() + } + reserved_plugins = self._reserved_plugin_names() + filtered = [] + for item in handlers: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_name", "")).strip() + if ( + plugin_name + and plugin_name in disabled_plugins + and plugin_name not in reserved_plugins + ): + continue + filtered.append(dict(item)) + return {"handlers": filtered} + + async def _session_service_is_llm_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + return {"enabled": bool(config.get("llm_enabled", True))} + + async def _session_service_set_llm_status( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + config["llm_enabled"] = bool(payload.get("enabled", False)) + await _get_runtime_sp().put_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + value=config, + ) + return {} + + async def _session_service_is_tts_enabled( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id = str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + return {"enabled": bool(config.get("tts_enabled", True))} + + async def _session_service_set_tts_status( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + session_id 
= str(payload.get("session", "")).strip() + config = await self._load_session_service_config(session_id) + config["tts_enabled"] = bool(payload.get("enabled", False)) + await _get_runtime_sp().put_async( + scope="umo", + scope_id=session_id, + key="session_service_config", + value=config, + ) + return {} diff --git a/astrbot/core/sdk_bridge/capabilities/skill.py b/astrbot/core/sdk_bridge/capabilities/skill.py new file mode 100644 index 0000000000..73fcbab614 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/skill.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from astrbot.core import logger + +from ._host import CapabilityMixinHost + + +class SkillCapabilityMixin(CapabilityMixinHost): + def _register_skill_capabilities(self) -> None: + self.register( + self._builtin_descriptor("skill.register", "Register SDK skill"), + call_handler=self._skill_register, + ) + self.register( + self._builtin_descriptor("skill.unregister", "Unregister SDK skill"), + call_handler=self._skill_unregister, + ) + self.register( + self._builtin_descriptor("skill.list", "List SDK skills"), + call_handler=self._skill_list, + ) + + async def _skill_register( + self, + request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, str]: + plugin_id = self._resolve_plugin_id(request_id) + result = self._plugin_bridge.register_skill( + plugin_id=plugin_id, + name=str(payload.get("name", "")), + path=str(payload.get("path", "")), + description=str(payload.get("description", "")), + ) + await self._sync_registered_skills_to_sandboxes() + return result + + async def _skill_unregister( + self, + request_id: str, + payload: dict[str, object], + _token, + ) -> dict[str, bool]: + plugin_id = self._resolve_plugin_id(request_id) + removed = self._plugin_bridge.unregister_skill( + plugin_id=plugin_id, + name=str(payload.get("name", "")), + ) + if removed: + await self._sync_registered_skills_to_sandboxes() + return {"removed": removed} + + async def _skill_list( + self, + 
request_id: str, + _payload: dict[str, object], + _token, + ) -> dict[str, list[dict[str, str]]]: + plugin_id = self._resolve_plugin_id(request_id) + return {"skills": self._plugin_bridge.list_registered_skills(plugin_id)} + + async def _sync_registered_skills_to_sandboxes(self) -> None: + try: + from astrbot.core.computer.computer_client import ( + sync_skills_to_active_sandboxes, + ) + + await sync_skills_to_active_sandboxes() + except Exception as exc: + logger.warning( + "Failed to sync skills to active sandboxes after SDK skill update: %s", + exc, + ) diff --git a/astrbot/core/sdk_bridge/capabilities/system.py b/astrbot/core/sdk_bridge/capabilities/system.py new file mode 100644 index 0000000000..7321e56be4 --- /dev/null +++ b/astrbot/core/sdk_bridge/capabilities/system.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +import asyncio +import uuid +from collections.abc import AsyncIterator +from pathlib import Path +from typing import Any + +from astrbot_sdk.errors import AstrBotError + +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from ..bridge_base import ( + _EventStreamState, + _get_runtime_astrbot_config, + _get_runtime_file_token_service, + _get_runtime_html_renderer, +) +from ._host import CapabilityMixinHost + + +class SystemCapabilityMixin(CapabilityMixinHost): + @staticmethod + def _overlay_request_id(request_id: str, payload: dict[str, Any]) -> str: + scope_request_id = payload.get("_request_scope_id") + if isinstance(scope_request_id, str) and scope_request_id.strip(): + return scope_request_id + return request_id + + def _register_system_capabilities(self) -> None: + self.register( + self._builtin_descriptor("system.get_data_dir", "Get plugin data dir"), + call_handler=self._system_get_data_dir, + exposed=False, + ) + self.register( + 
self._builtin_descriptor("system.text_to_image", "Render text to image"), + call_handler=self._system_text_to_image, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.html_render", "Render html template"), + call_handler=self._system_html_render, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.register", "Register file token"), + call_handler=self._system_file_register, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.file.handle", "Resolve file token"), + call_handler=self._system_file_handle, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.register", + "Register sdk session waiter", + ), + call_handler=self._system_session_waiter_register, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.session_waiter.unregister", + "Unregister sdk session waiter", + ), + call_handler=self._system_session_waiter_unregister, + exposed=False, + ) + self.register( + self._builtin_descriptor("system.event.react", "Send sdk event reaction"), + call_handler=self._system_event_react, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_typing", + "Send sdk event typing state", + ), + call_handler=self._system_event_send_typing, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming", + "Send sdk event streaming chunks", + ), + call_handler=self._system_event_send_streaming, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_chunk", + "Push sdk event streaming chunk", + ), + call_handler=self._system_event_send_streaming_chunk, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.send_streaming_close", + "Close sdk event streaming session", + ), + call_handler=self._system_event_send_streaming_close, + exposed=False, + ) + self.register( + self._builtin_descriptor( + 
"system.event.llm.get_state", + "Read sdk request llm state", + ), + call_handler=self._system_event_llm_get_state, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.llm.request", + "Request default llm for current sdk request", + ), + call_handler=self._system_event_llm_request, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.get", + "Read sdk request result", + ), + call_handler=self._system_event_result_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.set", + "Write sdk request result", + ), + call_handler=self._system_event_result_set, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.result.clear", + "Clear sdk request result", + ), + call_handler=self._system_event_result_clear, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.get", + "Read sdk request handler whitelist", + ), + call_handler=self._system_event_handler_whitelist_get, + exposed=False, + ) + self.register( + self._builtin_descriptor( + "system.event.handler_whitelist.set", + "Write sdk request handler whitelist", + ), + call_handler=self._system_event_handler_whitelist_set, + exposed=False, + ) + + def _register_registry_capabilities(self) -> None: + self.register( + self._builtin_descriptor( + "registry.get_handlers_by_event_type", + "List SDK handlers by event type", + ), + call_handler=self._registry_get_handlers_by_event_type, + ) + self.register( + self._builtin_descriptor( + "registry.get_handler_by_full_name", + "Get SDK handler metadata by full name", + ), + call_handler=self._registry_get_handler_by_full_name, + ) + self.register( + self._builtin_descriptor( + "registry.command.register", + "Register dynamic command route", + ), + call_handler=self._registry_command_register, + ) + + async def _system_get_data_dir( + self, + request_id: str, + _payload: dict[str, Any], + _token, + ) -> 
dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + data_dir = Path(get_astrbot_data_path()) / "plugin_data" / plugin_id + data_dir.mkdir(parents=True, exist_ok=True) + return {"path": str(data_dir.resolve())} + + async def _system_text_to_image( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + config_obj = self._star_context.get_config() + template_name = None + if hasattr(config_obj, "get"): + try: + template_name = config_obj.get("t2i_active_template") + except Exception: + template_name = None + result = await _get_runtime_html_renderer().render_t2i( + str(payload.get("text", "")), + return_url=bool(payload.get("return_url", True)), + template_name=template_name, + ) + return {"result": result} + + async def _system_html_render( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + data = payload.get("data") + if not isinstance(data, dict): + raise AstrBotError.invalid_input("system.html_render requires object data") + options = payload.get("options") + if options is not None and not isinstance(options, dict): + raise AstrBotError.invalid_input( + "system.html_render options must be an object or null" + ) + result = await _get_runtime_html_renderer().render_custom_template( + str(payload.get("tmpl", "")), + data, + return_url=bool(payload.get("return_url", True)), + options=options, + ) + return {"result": result} + + async def _system_file_register( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + path = str(payload.get("path", "")).strip() + if not path: + raise AstrBotError.invalid_input("system.file.register requires path") + raw_timeout = payload.get("timeout") + timeout: float | None + if raw_timeout is None: + timeout = None + else: + try: + timeout = float(raw_timeout) + except (TypeError, ValueError) as exc: + raise AstrBotError.invalid_input( + "system.file.register timeout must be a number or null" + ) from exc + 
file_token = await _get_runtime_file_token_service().register_file( + path, timeout + ) + callback_host = _get_runtime_astrbot_config().get("callback_api_base") + if not callback_host: + raise AstrBotError.invalid_input( + "callback_api_base is required for system.file.register" + ) + base_url = str(callback_host).rstrip("/") + return {"token": file_token, "url": f"{base_url}/api/file/{file_token}"} + + async def _system_file_handle( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + file_token = str(payload.get("token", "")).strip() + if not file_token: + raise AstrBotError.invalid_input("system.file.handle requires token") + path = await _get_runtime_file_token_service().handle_file(file_token) + return {"path": str(path)} + + async def _system_session_waiter_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.register_session_waiter( + plugin_id=plugin_id, + session_key=str(payload.get("session_key", "")), + ) + return {} + + async def _system_session_waiter_unregister( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.unregister_session_waiter( + plugin_id=plugin_id, + session_key=str(payload.get("session_key", "")), + ) + return {} + + async def _system_event_react( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + self._plugin_bridge.before_platform_send(request_context.dispatch_token) + await request_context.event.react(str(payload.get("emoji", ""))) + return { + "supported": bool( + self._plugin_bridge.mark_platform_send(request_context.dispatch_token) + ) + } + + async def 
_system_event_send_typing( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + if type(request_context.event).send_typing is AstrMessageEvent.send_typing: + return {"supported": False} + await request_context.event.send_typing() + return {"supported": True} + + async def _system_event_send_streaming( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + request_context = self._resolve_event_request_context(request_id, payload) + if request_context is None or request_context.cancelled: + return {"supported": False} + if ( + type(request_context.event).send_streaming + is AstrMessageEvent.send_streaming + ): + return {"supported": False} + self._plugin_bridge.before_platform_send(request_context.dispatch_token) + queue: asyncio.Queue[MessageChain | None] = asyncio.Queue() + + async def iterator() -> AsyncIterator[MessageChain]: + while True: + chunk = await queue.get() + if chunk is None or request_context.cancelled: + return + yield chunk + await asyncio.sleep(0) + + stream_id = uuid.uuid4().hex + task = asyncio.create_task( + request_context.event.send_streaming( + iterator(), + use_fallback=bool(payload.get("use_fallback", False)), + ) + ) + self._event_streams[stream_id] = _EventStreamState( + request_context=request_context, + queue=queue, + task=task, + ) + return {"supported": True, "stream_id": stream_id} + + async def _system_event_send_streaming_chunk( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stream_state = self._event_streams.get(str(payload.get("stream_id", ""))) + if stream_state is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + if stream_state.request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") 
+ chain_payload = payload.get("chain") + if not isinstance(chain_payload, list): + raise AstrBotError.invalid_input( + "system.event.send_streaming_chunk requires a chain array" + ) + await stream_state.queue.put(self._build_core_message_chain(chain_payload)) + return {} + + async def _system_event_send_streaming_close( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + stream_id = str(payload.get("stream_id", "")) + stream_state = self._event_streams.pop(stream_id, None) + if stream_state is None: + raise AstrBotError.invalid_input("Unknown sdk event streaming session") + await stream_state.queue.put(None) + try: + await stream_state.task + finally: + self._event_streams.pop(stream_id, None) + return { + "supported": bool( + self._plugin_bridge.mark_platform_send( + stream_state.request_context.dispatch_token + ) + ) + } + + async def _system_event_llm_get_state( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + overlay = self._plugin_bridge.get_request_overlay_by_request_id( + overlay_request_id + ) + should_call_llm = self._plugin_bridge.get_should_call_llm_for_request( + overlay_request_id + ) + return { + "should_call_llm": bool(should_call_llm), + "requested_llm": bool(overlay.requested_llm) + if overlay is not None + else False, + } + + async def _system_event_llm_request( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + self._plugin_bridge.request_llm_for_request(overlay_request_id) + return await self._system_event_llm_get_state( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _system_event_result_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + return { 
+ "result": self._plugin_bridge.get_result_payload_for_request( + overlay_request_id + ) + } + + async def _system_event_result_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + result_payload = payload.get("result") + if not isinstance(result_payload, dict): + raise AstrBotError.invalid_input( + "system.event.result.set requires an object result payload" + ) + overlay_request_id = self._overlay_request_id(request_id, payload) + if not self._plugin_bridge.set_result_for_request( + overlay_request_id, + result_payload, + ): + raise AstrBotError.cancelled("The SDK request overlay has been closed") + return { + "result": self._plugin_bridge.get_result_payload_for_request( + overlay_request_id + ) + } + + async def _system_event_result_clear( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + self._plugin_bridge.clear_result_for_request(overlay_request_id) + return {} + + async def _system_event_handler_whitelist_get( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + overlay_request_id = self._overlay_request_id(request_id, payload) + plugin_names = self._plugin_bridge.get_handler_whitelist_for_request( + overlay_request_id + ) + if plugin_names is None: + return {"plugin_names": None} + return {"plugin_names": sorted(plugin_names)} + + async def _system_event_handler_whitelist_set( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + plugin_names_payload = payload.get("plugin_names") + plugin_names: set[str] | None + if plugin_names_payload is None: + plugin_names = None + elif isinstance(plugin_names_payload, list): + plugin_names = { + str(item) for item in plugin_names_payload if str(item).strip() + } + else: + raise AstrBotError.invalid_input( + "system.event.handler_whitelist.set requires a string array or null" + ) + overlay_request_id = 
self._overlay_request_id(request_id, payload) + if not self._plugin_bridge.set_handler_whitelist_for_request( + overlay_request_id, + plugin_names, + ): + raise AstrBotError.cancelled("The SDK request overlay has been closed") + return await self._system_event_handler_whitelist_get( + request_id, + {"_request_scope_id": overlay_request_id}, + _token, + ) + + async def _registry_get_handlers_by_event_type( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + event_type = str(payload.get("event_type", "")).strip() + return {"handlers": self._plugin_bridge.get_handlers_by_event_type(event_type)} + + async def _registry_get_handler_by_full_name( + self, + _request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + full_name = str(payload.get("full_name", "")).strip() + return {"handler": self._plugin_bridge.get_handler_by_full_name(full_name)} + + async def _registry_command_register( + self, + request_id: str, + payload: dict[str, Any], + _token, + ) -> dict[str, Any]: + source_event_type = str(payload.get("source_event_type", "")).strip() + if source_event_type not in {"astrbot_loaded", "platform_loaded"}: + raise AstrBotError.invalid_input( + "register_commands is only available in astrbot_loaded/platform_loaded events" + ) + if bool(payload.get("ignore_prefix", False)): + raise AstrBotError.invalid_input( + "register_commands(ignore_prefix=True) is unsupported in SDK runtime" + ) + priority_value = payload.get("priority", 0) + if isinstance(priority_value, bool) or not isinstance(priority_value, int): + raise AstrBotError.invalid_input( + "registry.command.register priority must be an integer" + ) + plugin_id = self._resolve_plugin_id(request_id) + self._plugin_bridge.register_dynamic_command_route( + plugin_id=plugin_id, + command_name=str(payload.get("command_name", "")), + handler_full_name=str(payload.get("handler_full_name", "")), + desc=str(payload.get("desc", "")), + priority=priority_value, + 
use_regex=bool(payload.get("use_regex", False)), + ) + return {} diff --git a/astrbot/core/sdk_bridge/capability_bridge.py b/astrbot/core/sdk_bridge/capability_bridge.py new file mode 100644 index 0000000000..b8f90fe4dc --- /dev/null +++ b/astrbot/core/sdk_bridge/capability_bridge.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .bridge_base import CapabilityBridgeBase +from .capabilities import ( + BasicCapabilityMixin, + ConversationCapabilityMixin, + KnowledgeBaseCapabilityMixin, + LLMCapabilityMixin, + MCPCapabilityMixin, + MessageHistoryCapabilityMixin, + PermissionCapabilityMixin, + PersonaCapabilityMixin, + PlatformCapabilityMixin, + ProviderCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + SystemCapabilityMixin, +) +from .event_converter import EventConverter + +if TYPE_CHECKING: + from astrbot.core.star.context import Context as StarContext + +__all__ = ["CoreCapabilityBridge", "EventConverter"] + + +class CoreCapabilityBridge( + SystemCapabilityMixin, + ProviderCapabilityMixin, + MCPCapabilityMixin, + PlatformCapabilityMixin, + PermissionCapabilityMixin, + KnowledgeBaseCapabilityMixin, + MessageHistoryCapabilityMixin, + ConversationCapabilityMixin, + PersonaCapabilityMixin, + SessionCapabilityMixin, + SkillCapabilityMixin, + LLMCapabilityMixin, + BasicCapabilityMixin, + CapabilityBridgeBase, +): + def __init__(self, *, star_context: StarContext, plugin_bridge) -> None: + self._star_context = star_context + self._plugin_bridge = plugin_bridge + self._event_streams: dict[str, Any] = {} + self._memory_backends_by_plugin: dict[str, Any] = {} + self._memory_index_by_plugin: dict[str, dict[str, dict[str, Any]]] = {} + self._memory_dirty_keys_by_plugin: dict[str, set[str]] = {} + self._memory_expires_at_by_plugin: dict[str, dict[str, Any]] = {} + # CapabilityRouter.__init__() registers the built-in capability groups + # declared by this bridge and its mixins before extended groups are 
added. + super().__init__() + self._register_provider_capabilities() + self._register_provider_manager_capabilities() + self._register_mcp_capabilities() + self._register_platform_manager_capabilities() + self._register_permission_capabilities() + self._register_persona_capabilities() + self._register_conversation_capabilities() + self._register_message_history_capabilities() + self._register_kb_capabilities() + self._register_skill_capabilities() + self._register_system_capabilities() + self._register_registry_capabilities() + self._register_db_capabilities() + self._register_memory_capabilities() + self._register_http_capabilities() + self._register_metadata_capabilities() diff --git a/astrbot/core/sdk_bridge/event_converter.py b/astrbot/core/sdk_bridge/event_converter.py new file mode 100644 index 0000000000..2e53c3b1c9 --- /dev/null +++ b/astrbot/core/sdk_bridge/event_converter.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from astrbot_sdk._message_types import normalize_message_type +from astrbot_sdk.message.components import component_to_payload_sync + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +class EventConverter: + """Convert legacy AstrBot events into SDK payloads.""" + + _DROP_VALUE = object() + + @staticmethod + def _sdk_message_type( + value: Any, + *, + group_id: str | None = None, + user_id: str | None = None, + ) -> str: + return normalize_message_type( + value, + group_id=group_id, + user_id=user_id, + ) + + @classmethod + def _sanitize_extra_value(cls, value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + items = [] + for item in value: + sanitized = cls._sanitize_extra_value(item) + if sanitized is not cls._DROP_VALUE: + items.append(sanitized) + return items + if isinstance(value, dict): + sanitized_dict: dict[str, Any] = {} + for key, item 
in value.items(): + sanitized = cls._sanitize_extra_value(item) + if sanitized is not cls._DROP_VALUE: + sanitized_dict[str(key)] = sanitized + return sanitized_dict + try: + json.dumps(value) + except (TypeError, ValueError): + return cls._DROP_VALUE + return value + + @classmethod + def _sanitize_extras(cls, extras: dict[str, Any]) -> dict[str, Any]: + sanitized: dict[str, Any] = {} + for key, value in extras.items(): + normalized = cls._sanitize_extra_value(value) + if normalized is not cls._DROP_VALUE: + sanitized[str(key)] = normalized + return sanitized + + @staticmethod + def core_to_sdk( + event: AstrMessageEvent, + *, + dispatch_token: str, + plugin_id: str, + request_id: str, + ) -> dict[str, Any]: + message_type = EventConverter._sdk_message_type( + event.get_message_type(), + group_id=event.get_group_id() or None, + user_id=event.get_sender_id() or None, + ) + raw = { + "dispatch_token": dispatch_token, + "plugin_id": plugin_id, + "request_id": request_id, + "platform_id": event.get_platform_id(), + } + payload: dict[str, Any] = { + "text": event.get_message_str(), + "user_id": event.get_sender_id(), + "group_id": event.get_group_id() or None, + "platform": event.get_platform_name(), + "platform_id": event.get_platform_id(), + "session_id": event.unified_msg_origin, + "self_id": event.get_self_id(), + "message_type": message_type, + "sender_name": event.get_sender_name(), + "is_admin": event.is_admin(), + "is_wake": event.is_wake, + "is_at_or_wake_command": event.is_at_or_wake_command, + "message_outline": event.get_message_outline(), + "raw": raw, + "target": { + "conversation_id": event.unified_msg_origin, + "platform": event.get_platform_name(), + "raw": raw, + }, + } + extras = event.get_extra() + if isinstance(extras, dict) and extras: + sanitized_extras = EventConverter._sanitize_extras(extras) + if sanitized_extras: + payload["extras"] = sanitized_extras + messages = [] + for component in event.get_messages(): + try: + 
messages.append(component_to_payload_sync(component)) + except Exception: + messages.append( + { + "type": "unknown", + "data": {"value": str(component)}, + } + ) + if messages: + payload["messages"] = messages + return payload + + @staticmethod + def extract_handler_result(sdk_result: dict[str, Any] | None) -> dict[str, Any]: + if not sdk_result: + return {"sent_message": False, "stop": False, "call_llm": False} + return { + "sent_message": bool(sdk_result.get("sent_message", False)), + "stop": bool(sdk_result.get("stop", False)), + "call_llm": bool(sdk_result.get("call_llm", False)), + } diff --git a/astrbot/core/sdk_bridge/plugin_bridge.py b/astrbot/core/sdk_bridge/plugin_bridge.py new file mode 100644 index 0000000000..a5f3b36737 --- /dev/null +++ b/astrbot/core/sdk_bridge/plugin_bridge.py @@ -0,0 +1,3605 @@ +from __future__ import annotations + +import asyncio +import contextlib +import json +import os +import re +import signal +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any + +from astrbot_sdk.errors import AstrBotError +from astrbot_sdk.llm.agents import AgentSpec +from astrbot_sdk.llm.entities import LLMToolSpec +from astrbot_sdk.message.components import component_to_payload_sync +from astrbot_sdk.protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + EventTrigger, + HandlerDescriptor, + MessageTrigger, + PlatformFilterSpec, + ScheduleTrigger, +) +from astrbot_sdk.runtime.loader import ( + PluginDiscoveryIssue, + PluginEnvironmentManager, + PluginSpec, + discover_plugins, + load_plugin_config, + load_plugin_config_schema, + save_plugin_config, +) +from astrbot_sdk.runtime.supervisor import WorkerSession +from quart import request as quart_request + +from astrbot.core import logger +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from 
astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import LLMResponse as CoreLLMResponse +from astrbot.core.provider.entities import ProviderRequest as CoreProviderRequest +from astrbot.core.skills.skill_manager import ( + SkillManager, + _parse_frontmatter_description, +) +from astrbot.core.utils.astrbot_path import ( + get_astrbot_data_path, + get_astrbot_plugin_data_path, +) + +from .bridge_base import _build_message_chain_from_payload +from .capability_bridge import CoreCapabilityBridge +from .event_converter import EventConverter +from .trigger_converter import TriggerConverter, TriggerMatch + +SDK_STATE_ENABLED = "enabled" +SDK_STATE_DISABLED = "disabled" +SDK_STATE_RELOADING = "reloading" +SDK_STATE_FAILED = "failed" +SDK_STATE_UNSUPPORTED_PARTIAL = "unsupported_partial" + +SKIP_LEGACY_STOPPED = "legacy_stopped" +SKIP_LEGACY_REPLIED = "legacy_replied" +SKIP_SDK_RELOADING = "sdk_reloading" +SKIP_NO_MATCH = "no_match" +SKIP_WORKER_FAILED = "worker_failed" +OVERLAY_TIMEOUT_SECONDS = 300 +SDK_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") +SUPPORTED_SYSTEM_EVENTS = { + "astrbot_loaded", + "platform_loaded", + "after_message_sent", + "waiting_llm_request", + "agent_begin", + "llm_request", + "agent_done", + "decorating_result", + "calling_func_tool", + "llm_tool_start", + "llm_tool_end", + "plugin_error", + "plugin_loaded", + "plugin_unloaded", +} + + +@dataclass(slots=True) +class SdkHandlerRef: + descriptor: HandlerDescriptor + declaration_order: int + + @property + def handler_id(self) -> str: + return self.descriptor.id + + @property + def handler_name(self) -> str: + return self.descriptor.id.rsplit(".", 1)[-1] + + +@dataclass(slots=True) +class SdkDispatchResult: + matched_handlers: list[dict[str, str]] = field(default_factory=list) + executed_handlers: list[dict[str, str]] = field(default_factory=list) + sent_message: bool = False + stopped: bool = False + skipped_reason: str | None = None + + 
@dataclass(slots=True)
class _DispatchState:
    """Mutable per-event dispatch flags shared by all handlers of one event."""

    event: AstrMessageEvent
    sent_message: bool = False
    stopped: bool = False


@dataclass(slots=True)
class _RequestContext:
    """Binds a dispatch token to the plugin/request currently being served."""

    plugin_id: str
    request_id: str
    dispatch_token: str
    dispatch_state: _DispatchState | None
    cancelled: bool = False

    @property
    def has_event(self) -> bool:
        return self.dispatch_state is not None

    @property
    def event(self) -> AstrMessageEvent:
        # Raise a typed error rather than AttributeError when a worker asks
        # for the event on a request that never had one (e.g. schedule/HTTP).
        if self.dispatch_state is None:
            raise AstrBotError.invalid_input(
                "The current SDK request is not bound to a message event"
            )
        return self.dispatch_state.event


@dataclass(slots=True)
class _InFlightRequest:
    """Tracks one outstanding worker RPC so it can be cancelled/ignored."""

    request_id: str
    dispatch_token: str
    task: asyncio.Task[dict[str, Any]]
    # True when the request was logically cancelled but the task still ran.
    logical_cancelled: bool = False


@dataclass(slots=True)
class _LocalMCPServerRuntime:
    """Runtime state of a plugin-scoped (local) MCP server."""

    name: str
    config: dict[str, Any]
    active: bool
    running: bool = False
    client: MCPClient | None = None
    tools: list[str] = field(default_factory=list)
    tool_specs: list[LLMToolSpec] = field(default_factory=list)
    errlogs: list[str] = field(default_factory=list)
    last_error: str | None = None
    ready_event: asyncio.Event = field(default_factory=asyncio.Event)
    connect_task: asyncio.Task[None] | None = None
    lease_path: Path | None = None


@dataclass(slots=True)
class _TemporaryMCPSessionRuntime:
    """A short-lived MCP client session opened on behalf of a plugin."""

    plugin_id: str
    name: str
    client: MCPClient
    tools: list[str]


@dataclass(slots=True)
class _RequestOverlayState:
    """Per-dispatch overlay that shadows event-level LLM/result decisions."""

    dispatch_token: str
    should_call_llm: bool
    requested_llm: bool = False
    sdk_local_extras: dict[str, Any] = field(default_factory=dict)
    result_payload: dict[str, Any] | None = None
    result_object: MessageEventResult | None = None
    result_is_set: bool = False
    # NOTE(review): compared against plugin ids in dispatch_message despite the
    # name — confirm whether entries are plugin ids or handler ids.
    handler_whitelist: set[str] | None = None
    request_scope_ids: set[str] = field(default_factory=set)
    closed: bool = False
    cleanup_task: asyncio.Task[None] | None = None


@dataclass(slots=True)
class SdkPluginRecord:
    """Everything the bridge keeps about one loaded SDK plugin."""

    plugin: PluginSpec
    load_order: int
    state: str
    unsupported_features: list[str]
    config_schema: dict[str, Any]
    config: dict[str, Any]
    handlers: list[SdkHandlerRef]
    llm_tools: dict[str, LLMToolSpec] = field(default_factory=dict)
    active_llm_tools: set[str] = field(default_factory=set)
    agents: dict[str, AgentSpec] = field(default_factory=dict)
    skills: dict[str, SdkRegisteredSkill] = field(default_factory=dict)
    dynamic_command_routes: list[SdkDynamicCommandRoute] = field(default_factory=list)
    session: WorkerSession | None = None
    restart_attempted: bool = False
    failure_reason: str = ""
    issues: list[dict[str, Any]] = field(default_factory=list)
    local_mcp_servers: dict[str, _LocalMCPServerRuntime] = field(default_factory=dict)
    acknowledge_global_mcp_risk: bool = False

    @property
    def plugin_id(self) -> str:
        # The plugin's directory/manifest name doubles as its id.
        return self.plugin.name


@dataclass(slots=True)
class SdkHttpRoute:
    """An HTTP route a plugin registered on the dashboard server."""

    plugin_id: str
    route: str
    methods: tuple[str, ...]
    handler_capability: str
    description: str


@dataclass(slots=True)
class SdkRegisteredSkill:
    """A skill (SKILL.md based) registered by an SDK plugin."""

    name: str
    description: str
    skill_dir: Path
    skill_md_path: Path

    def to_registry_payload(self) -> dict[str, str]:
        # Shape consumed by SkillManager.replace_sdk_plugin_skills.
        return {
            "name": self.name,
            "description": self.description,
            "path": str(self.skill_md_path),
            "skill_dir": str(self.skill_dir),
        }


@dataclass(slots=True)
class SdkDynamicCommandRoute:
    """A command route registered at runtime (not via the manifest)."""

    command_name: str
    handler_full_name: str
    desc: str
    priority: int
    use_regex: bool
    declaration_order: int


class SdkPluginBridge:
    """Supervises SDK plugin workers and bridges them into the core runtime."""

    # Sentinel used to mark values that must be dropped from payloads.
    _DROP_VALUE = object()

    def __init__(self, star_context) -> None:
        self.star_context = star_context
        self.plugins_dir = Path(get_astrbot_data_path()) / "sdk_plugins"
        self.state_path = Path(get_astrbot_data_path()) / "sdk_plugins_state.json"
        self.plugins_dir.mkdir(parents=True, exist_ok=True)
        self._started = False
        self._stopping = False
        self._state_overrides = self._load_state_overrides()
        # parents[3] walks from this file up to the repository root — TODO
        # confirm the depth if this module is ever moved.
        self.env_manager = PluginEnvironmentManager(Path(__file__).resolve().parents[3])
        self.capability_bridge = CoreCapabilityBridge(
            star_context=star_context,
            plugin_bridge=self,
        )
        self._records: dict[str, SdkPluginRecord] = {}
        self._request_contexts: dict[str, _RequestContext] = {}
        self._request_id_to_token: dict[str, str] = {}
        self._request_plugin_ids: dict[str, str] = {}
        self._request_overlays: dict[str, _RequestOverlayState] = {}
        self._plugin_requests: dict[str, dict[str, _InFlightRequest]] = {}
        self._http_routes: dict[str, list[SdkHttpRoute]] = {}
        self._session_waiters: dict[str, set[str]] = {}
        self._schedule_job_ids: dict[str, set[str]] = {}
        self._discovery_issues: dict[str, list[dict[str, Any]]] = {}
        self._temporary_mcp_sessions: dict[str, _TemporaryMCPSessionRuntime] = {}

    async def start(self) -> None:
        """Idempotently start the bridge: sweep stale leases, load all plugins."""
        if self._started:
            return
        self._sweep_stale_mcp_leases()
        await self.reload_all(reset_restart_budget=True)
        self._started = True

    async def stop(self) -> None:
        """Tear down all workers, requests, overlays, and registries."""
        if not self._started and not self._records:
            return
        self._stopping = True
        # Cancel outstanding work before stopping worker sessions.
        for plugin_id in list(self._records.keys()):
            await self._cancel_plugin_requests(plugin_id)
            await self._close_temporary_mcp_sessions(plugin_id)
        for record in list(self._records.values()):
            await self._shutdown_local_mcp_servers(record)
            if record.session is not None:
                await record.session.stop()
                record.session = None
        self._records.clear()
        self._request_contexts.clear()
        self._request_id_to_token.clear()
        self._request_plugin_ids.clear()
        for overlay in list(self._request_overlays.values()):
            if overlay.cleanup_task is not None:
                overlay.cleanup_task.cancel()
        self._request_overlays.clear()
        self._plugin_requests.clear()
        self._http_routes.clear()
        self._session_waiters.clear()
        self._schedule_job_ids.clear()
        self._temporary_mcp_sessions.clear()
        self._started = False
        self._stopping = False

    async def reload_all(self, *, reset_restart_budget: bool = False) -> None:
        # Re-discover every plugin on disk, tear down the ones that vanished,
        # then (re)load the rest in discovery order.
        discovered = discover_plugins(self.plugins_dir)
        self._set_discovery_issues(discovered.issues)
        self.env_manager.plan(discovered.plugins)
        known = {plugin.name for plugin in discovered.plugins}
        SkillManager().prune_sdk_plugin_skills(known)
        for plugin_id in list(self._records.keys()):
            if plugin_id not in known:
                await self._teardown_plugin(plugin_id)
                self._records.pop(plugin_id, None)
        for load_order, plugin in enumerate(discovered.plugins):
            await self._load_or_reload_plugin(
                plugin,
                load_order=load_order,
                reset_restart_budget=reset_restart_budget,
            )
        await self._refresh_native_platform_commands({"telegram"})

    async def reload_plugin(self, plugin_id: str) -> None:
        """Reload a single plugin by id; raises ValueError if not on disk."""
        discovered = discover_plugins(self.plugins_dir)
        self._set_discovery_issues(discovered.issues)
        self.env_manager.plan(discovered.plugins)
        for load_order, plugin in enumerate(discovered.plugins):
            if plugin.name != plugin_id:
                continue
            await self._load_or_reload_plugin(
                plugin,
                load_order=load_order,
                reset_restart_budget=True,
            )
            await self._refresh_native_platform_commands({"telegram"})
            return
        raise ValueError(f"SDK plugin not found: {plugin_id}")

    async def turn_off_plugin(self, plugin_id: str) -> None:
        """Disable a plugin, stop its worker, and persist the override."""
        record = self._records.get(plugin_id)
        if record is None:
            raise ValueError(f"SDK plugin not found: {plugin_id}")
        record.state = SDK_STATE_DISABLED
        await self._cancel_plugin_requests(plugin_id)
        await self._teardown_plugin(plugin_id)
        record.failure_reason = ""
        self._set_disabled_override(plugin_id, disabled=True)
        await self._refresh_native_platform_commands({"telegram"})

    async def turn_on_plugin(self, plugin_id: str) -> None:
        """Clear the disabled override and load the plugin from disk."""
        discovered = discover_plugins(self.plugins_dir)
        self._set_discovery_issues(discovered.issues)
        self.env_manager.plan(discovered.plugins)
        for load_order, plugin in enumerate(discovered.plugins):
            if plugin.name != plugin_id:
                continue
            self._set_disabled_override(plugin_id, disabled=False)
            await self._load_or_reload_plugin(
                plugin,
                load_order=load_order,
                reset_restart_budget=True,
            )
            await self._refresh_native_platform_commands({"telegram"})
            return
        raise ValueError(f"SDK plugin not found: {plugin_id}")

    def list_plugins(self) -> list[dict[str, Any]]:
        """Dashboard view: loaded records first, then discovery failures."""
        records = sorted(self._records.values(), key=lambda item: item.load_order)
        items = [self._record_to_dashboard_item(record) for record in records]
        for plugin_id, issues in sorted(self._discovery_issues.items()):
            if plugin_id in self._records:
                continue
            items.append(self._failed_issue_to_dashboard_item(plugin_id, issues))
        return items

    def get_plugin_metadata(self, plugin_id: str) -> dict[str, Any] | None:
        """Resolve metadata for an SDK record, a legacy star, or a failed discovery."""
        record = self._records.get(plugin_id)
        if record is not None:
            manifest = record.plugin.manifest_data
            support_platforms = manifest.get("support_platforms")
            return {
                "name": plugin_id,
                "display_name": str(manifest.get("display_name") or plugin_id),
                "description": str(
                    manifest.get("desc") or manifest.get("description") or ""
                ),
                "author": str(manifest.get("author") or ""),
                "version": str(manifest.get("version") or "0.0.0"),
                "enabled": record.state not in {SDK_STATE_DISABLED, SDK_STATE_FAILED},
                "support_platforms": [
                    str(item) for item in support_platforms if isinstance(item, str)
                ]
                if isinstance(support_platforms, list)
                else [],
                "astrbot_version": (
                    str(manifest.get("astrbot_version"))
                    if manifest.get("astrbot_version") is not None
                    else None
                ),
                "runtime_kind": "sdk",
                "issues": [dict(item) for item in record.issues],
            }
        # Fall back to legacy (in-process) plugins.
        for plugin in self.star_context.get_all_stars():
            if plugin.name == plugin_id:
                return {
                    "name": plugin.name,
                    "display_name": plugin.display_name,
                    "description": plugin.desc,
                    "author": plugin.author,
                    "version": plugin.version,
                    "enabled": plugin.activated,
                    "support_platforms": list(plugin.support_platforms),
                    "astrbot_version": plugin.astrbot_version,
                    "runtime_kind": "legacy",
                }
        # Plugins that failed discovery still surface their first issue.
        if plugin_id in self._discovery_issues:
            issue = self._discovery_issues[plugin_id][0]
            return {
                "name": plugin_id,
                "display_name": plugin_id,
                "description": str(issue.get("message", "")),
                "author": "",
                "version": "0.0.0",
                "enabled": False,
                "support_platforms": [],
                "astrbot_version": None,
                "runtime_kind": "sdk",
                "issues": [dict(item) for item in self._discovery_issues[plugin_id]],
            }
        return None

    def list_plugin_metadata(self) -> list[dict[str, Any]]:
        """Metadata for legacy stars, SDK records, and discovery failures."""
        metadata = []
        for plugin in self.star_context.get_all_stars():
            metadata.append(
                {
                    "name": plugin.name,
                    "display_name": plugin.display_name,
                    "description": plugin.desc,
                    "author": plugin.author,
                    "version": plugin.version,
                    "enabled": plugin.activated,
                    "support_platforms": list(plugin.support_platforms),
                    "astrbot_version": plugin.astrbot_version,
                    "runtime_kind": "legacy",
                }
            )
        for plugin_id in sorted(self._records.keys()):
            plugin_metadata = self.get_plugin_metadata(plugin_id)
            if plugin_metadata is not None:
                metadata.append(plugin_metadata)
        for plugin_id in sorted(self._discovery_issues.keys()):
            if plugin_id in self._records:
                continue
            plugin_metadata = self.get_plugin_metadata(plugin_id)
            if plugin_metadata is not None:
                metadata.append(plugin_metadata)
        return metadata

    def get_plugin_config(self, plugin_id: str) -> dict[str, Any] | None:
        record = self._records.get(plugin_id)
        if record is None:
            return None
        # Shallow copy so callers cannot mutate the cached config.
        return dict(record.config)

    def get_plugin_config_schema(self, plugin_id: str) -> dict[str, Any] | None:
        record = self._records.get(plugin_id)
        if record is None:
            return None
        return dict(record.config_schema)

    def save_plugin_config(
        self,
        plugin_id: str,
        payload: dict[str, Any],
    ) -> dict[str, Any]:
        """Validate and persist config via the SDK loader; returns the stored copy."""
        record = self._records.get(plugin_id)
        if record is None:
            raise ValueError(f"SDK plugin not found: {plugin_id}")
        normalized = save_plugin_config(
            record.plugin,
            payload,
            schema=record.config_schema,
        )
        record.config = dict(normalized)
        return dict(record.config)

    def get_registered_llm_tools(self, plugin_id: str) -> list[LLMToolSpec]:
        record = self._records.get(plugin_id)
        if record is None:
            return []
        # Deep copies: specs are mutable pydantic models shared with the record.
        return [item.model_copy(deep=True) for item in record.llm_tools.values()]

    def get_active_llm_tools(self, plugin_id: str) -> list[LLMToolSpec]:
        record = self._records.get(plugin_id)
        if record is None:
            return []
        return [
            item.model_copy(deep=True)
            for name, item in record.llm_tools.items()
            if name in record.active_llm_tools
        ]

    def get_llm_tool(self, plugin_id: str, name: str) -> LLMToolSpec | None:
        record = self._records.get(plugin_id)
        if record is None:
            return None
        spec = record.llm_tools.get(name)
        if spec is None:
            return None
        return spec.model_copy(deep=True)

    def add_llm_tools(self, plugin_id: str, tools: list[LLMToolSpec]) -> list[str]:
        """Register/replace tool specs; returns the names actually stored."""
        record = self._records.get(plugin_id)
        if record is None:
            return []
        names: list[str] = []
        for spec in tools:
            record.llm_tools[spec.name] = spec.model_copy(deep=True)
            # Keep the active set in sync with the spec's own flag.
            if spec.active:
                record.active_llm_tools.add(spec.name)
            else:
                record.active_llm_tools.discard(spec.name)
            names.append(spec.name)
        return names

    def remove_llm_tool(self, plugin_id: str, name: str) -> bool:
        record = self._records.get(plugin_id)
        if record is None:
            return False
        removed = record.llm_tools.pop(name, None) is not None
        record.active_llm_tools.discard(name)
        return removed

    def activate_llm_tool(self, plugin_id: str, name: str) -> bool:
        record = self._records.get(plugin_id)
        if record is None:
            return False
        spec = record.llm_tools.get(name)
        if spec is None:
            return False
        spec.active = True
        record.active_llm_tools.add(name)
        return True

    def deactivate_llm_tool(self, plugin_id: str, name: str) -> bool:
        record = self._records.get(plugin_id)
        if record is None:
            return False
        spec = record.llm_tools.get(name)
        if spec is None:
            return False
        spec.active = False
        record.active_llm_tools.discard(name)
        return True

    def _local_mcp_record(
        self, plugin_id: str, name: str
    ) -> _LocalMCPServerRuntime | None:
        # Small lookup helper: plugin record -> named local MCP runtime.
        record = self._records.get(plugin_id)
        if record is None:
            return None
        return record.local_mcp_servers.get(name)

    @staticmethod
    def _serialize_local_mcp_server(
        runtime: _LocalMCPServerRuntime,
    ) -> dict[str, Any]:
        """JSON-safe view of a local MCP server, merging client error logs."""
        errlogs = list(runtime.errlogs)
        if runtime.client is not None:
            errlogs.extend(str(item) for item in runtime.client.server_errlogs)
        return {
            "name": runtime.name,
            "scope": "local",
            "active": runtime.active,
            "running": runtime.running,
            "config": dict(runtime.config),
            "tools": list(runtime.tools),
            "errlogs": errlogs,
            "last_error": runtime.last_error,
        }

    def get_local_mcp_server(
        self,
        plugin_id: str,
        name: str,
    ) -> dict[str, Any] | None:
        runtime = self._local_mcp_record(plugin_id, name)
        if runtime is None:
            return None
        return self._serialize_local_mcp_server(runtime)

    def list_local_mcp_servers(self, plugin_id: str) -> list[dict[str, Any]]:
        record = self._records.get(plugin_id)
        if record is None:
            return []
        # Sorted by name for a stable dashboard ordering.
        return [
            self._serialize_local_mcp_server(runtime)
            for runtime in sorted(
                record.local_mcp_servers.values(),
                key=lambda item: item.name,
            )
        ]

    def get_request_tool_specs(self, plugin_id: str) -> list[LLMToolSpec]:
        """All tool specs visible to a request: active plugin tools plus tools
        from running local MCP servers. Plugin tools win on name collisions."""
        record = self._records.get(plugin_id)
        if record is None:
            return []
        specs: dict[str, LLMToolSpec] = {
            item.name: item.model_copy(deep=True)
            for name, item in record.llm_tools.items()
            if name in record.active_llm_tools
        }
        for runtime in record.local_mcp_servers.values():
            if not runtime.active or not runtime.running:
                continue
            for spec in runtime.tool_specs:
                specs.setdefault(spec.name, spec.model_copy(deep=True))
        return list(specs.values())

    def get_registered_agents(self, plugin_id: str) -> list[AgentSpec]:
        record = self._records.get(plugin_id)
        if record is None:
            return []
        return [item.model_copy(deep=True) for item in record.agents.values()]

    def get_registered_agent(self, plugin_id: str, name: str) -> AgentSpec | None:
        record = self._records.get(plugin_id)
        if record is None:
            return None
        spec = record.agents.get(name)
        if spec is None:
            return None
        return spec.model_copy(deep=True)

    def register_dynamic_command_route(
        self,
        *,
        plugin_id: str,
        command_name: str,
        handler_full_name: str,
        desc: str = "",
        priority: int = 0,
        use_regex: bool = False,
    ) -> None:
        """Register (or replace) a runtime command route for a plugin.

        Validates ownership: the handler must belong to the calling plugin and
        exist in its handler table. Re-registration keeps the original
        declaration order so routing stays stable across updates.
        """
        record = self._records.get(plugin_id)
        if record is None:
            raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}")
        # bool is an int subclass — reject it explicitly.
        if isinstance(priority, bool) or not isinstance(priority, int):
            raise AstrBotError.invalid_input("priority must be an integer")
        command_text = str(command_name).strip()
        if not command_text:
            raise AstrBotError.invalid_input("command_name must not be empty")
        handler_text = str(handler_full_name).strip()
        if not handler_text:
            raise AstrBotError.invalid_input("handler_full_name must not be empty")
        if not handler_text.startswith(f"{plugin_id}:"):
            raise AstrBotError.invalid_input(
                "handler_full_name must belong to the caller plugin"
            )
        if self._find_handler_ref(record, handler_text) is None:
            raise AstrBotError.invalid_input(
                f"Unknown handler_full_name for plugin '{plugin_id}': {handler_text}"
            )
        # Preserve the slot of an existing identical route; append otherwise.
        existing_order = next(
            (
                route.declaration_order
                for route in record.dynamic_command_routes
                if route.command_name == command_text
                and route.use_regex is bool(use_regex)
            ),
            len(record.dynamic_command_routes),
        )
        updated = [
            route
            for route in record.dynamic_command_routes
            if not (
                route.command_name == command_text
                and route.use_regex is bool(use_regex)
            )
        ]
        updated.append(
            SdkDynamicCommandRoute(
                command_name=command_text,
                handler_full_name=handler_text,
                desc=str(desc),
                priority=priority,
                use_regex=bool(use_regex),
                declaration_order=existing_order,
            )
        )
        updated.sort(key=lambda item: item.declaration_order)
        record.dynamic_command_routes = updated

    def register_skill(
        self,
        *,
        plugin_id: str,
        name: str,
        path: str,
        description: str = "",
    ) -> dict[str, str]:
        """Register a SKILL.md-backed skill for a plugin.

        `path` may point to a skill directory or directly to SKILL.md; it must
        resolve inside the plugin directory. An empty description falls back to
        the SKILL.md frontmatter. Returns the stored registry payload.
        """
        record = self._records.get(plugin_id)
        if record is None:
            raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}")

        skill_name = str(name).strip()
        if not skill_name or not SDK_SKILL_NAME_RE.fullmatch(skill_name):
            raise AstrBotError.invalid_input(
                "skill.register requires a name matching [A-Za-z0-9._-]+"
            )

        path_text = str(path).strip()
        if not path_text:
            raise AstrBotError.invalid_input("skill.register requires path")

        plugin_root = record.plugin.plugin_dir.resolve()
        requested_path = Path(path_text)
        # Relative paths are anchored at the plugin root.
        resolved_path = (
            requested_path.resolve()
            if requested_path.is_absolute()
            else (plugin_root / requested_path).resolve()
        )

        skill_dir = resolved_path if resolved_path.is_dir() else resolved_path.parent
        skill_md_path = (
            resolved_path / "SKILL.md" if resolved_path.is_dir() else resolved_path
        )
        if skill_md_path.name != "SKILL.md" or not skill_md_path.is_file():
            raise AstrBotError.invalid_input(
                "skill.register path must point to a skill directory containing SKILL.md or to SKILL.md itself"
            )
        if not skill_dir.is_dir():
            raise AstrBotError.invalid_input(
                "skill.register resolved skill_dir is not a directory"
            )
        # Path-traversal guard: the resolved file must stay inside the plugin.
        if not skill_md_path.is_relative_to(plugin_root):
            raise AstrBotError.invalid_input(
                "skill.register path must stay inside the plugin directory"
            )

        normalized_description = str(description).strip()
        if not normalized_description:
            # Best-effort frontmatter fallback; unreadable files get "".
            try:
                normalized_description = _parse_frontmatter_description(
                    skill_md_path.read_text(encoding="utf-8")
                )
            except Exception:
                normalized_description = ""

        record.skills[skill_name] = SdkRegisteredSkill(
            name=skill_name,
            description=normalized_description,
            skill_dir=skill_dir,
            skill_md_path=skill_md_path,
        )
        self._publish_plugin_skills(plugin_id)
        return {
            "name": skill_name,
            "description": normalized_description,
            "path": str(skill_md_path),
            "skill_dir": str(skill_dir),
        }

    def unregister_skill(self, *, plugin_id: str, name: str) -> bool:
        """Remove one skill; returns True if something was actually removed."""
        record = self._records.get(plugin_id)
        if record is None:
            raise AstrBotError.invalid_input(f"Unknown SDK plugin: {plugin_id}")
        removed = record.skills.pop(str(name).strip(), None) is not None
        if removed:
            self._publish_plugin_skills(plugin_id)
        return removed

    def list_registered_skills(self, plugin_id: str) -> list[dict[str, str]]:
        record = self._records.get(plugin_id)
        if record is None:
            return []
        return [
            record.skills[name].to_registry_payload()
            for name in sorted(record.skills.keys())
        ]

    def _publish_plugin_skills(self, plugin_id: str) -> None:
        # Mirror the record's skills into the global SkillManager registry;
        # an empty/missing record clears the plugin's entries instead.
        record = self._records.get(plugin_id)
        manager = SkillManager()
        if record is None or not record.skills:
            manager.remove_sdk_plugin_skills(plugin_id)
            return
        manager.replace_sdk_plugin_skills(
            plugin_id,
            [skill.to_registry_payload() for skill in record.skills.values()],
        )

    async def _clear_plugin_skills(
        self,
        *,
        plugin_id: str,
        record: SdkPluginRecord | Any | None,
        reason: str,
    ) -> None:
        """Drop all skills of a plugin and propagate the removal to sandboxes."""
        if record is None or not getattr(record, "skills", None):
            return
        record.skills.clear()
        self._publish_plugin_skills(plugin_id)
        try:
            from astrbot.core.computer.computer_client import (
                sync_skills_to_active_sandboxes,
            )

            # Keep sandbox-visible skills aligned with the bridge registry so a
            # stopped plugin cannot continue exposing dead skill entries.
            await sync_skills_to_active_sandboxes()
        except Exception as exc:
            # Sandbox sync is best-effort; log and continue teardown.
            logger.warning(
                "Failed to sync skills after SDK plugin %s %s: %s",
                plugin_id,
                reason,
                exc,
            )

    def register_http_api(
        self,
        *,
        plugin_id: str,
        route: str,
        methods: list[str],
        handler_capability: str,
        description: str,
    ) -> None:
        """Register (or replace) an HTTP route backed by a worker capability.

        Route/method normalization and collision checks are delegated to the
        private helpers; re-registering the same route+methods replaces the
        previous entry instead of duplicating it.
        """
        normalized_route = self._normalize_http_route(route)
        normalized_methods = self._normalize_http_methods(methods)
        if not handler_capability:
            raise AstrBotError.invalid_input(
                "http.register_api requires handler_capability"
            )
        self._ensure_http_route_available(
            plugin_id=plugin_id,
            route=normalized_route,
            methods=normalized_methods,
        )
        route_entry = SdkHttpRoute(
            plugin_id=plugin_id,
            route=normalized_route,
            methods=normalized_methods,
            handler_capability=handler_capability,
            description=description,
        )
        plugin_routes = [
            entry
            for entry in self._http_routes.get(plugin_id, [])
            if not (
                entry.route == normalized_route and entry.methods == normalized_methods
            )
        ]
        plugin_routes.append(route_entry)
        self._http_routes[plugin_id] = plugin_routes

    def unregister_http_api(
        self,
        *,
        plugin_id: str,
        route: str,
        methods: list[str],
    ) -> None:
        """Remove some (or all) methods of a registered HTTP route."""
        normalized_route = self._normalize_http_route(route)
        normalized_methods = {method.upper() for method in methods if method}
        updated: list[SdkHttpRoute] = []
        for entry in self._http_routes.get(plugin_id, []):
            if entry.route != normalized_route:
                updated.append(entry)
                continue
            if not normalized_methods:
                # Plugins do not have a separate "delete route" capability, so an
                # empty method list means "remove every method registered on route".
                continue
            # Keep only methods that were not asked to be removed.
            remaining = tuple(
                method for method in entry.methods if method not in normalized_methods
            )
            if remaining:
                updated.append(
                    SdkHttpRoute(
                        plugin_id=entry.plugin_id,
                        route=entry.route,
                        methods=remaining,
                        handler_capability=entry.handler_capability,
                        description=entry.description,
                    )
                )
        if updated:
            self._http_routes[plugin_id] = updated
        else:
            self._http_routes.pop(plugin_id, None)

    def list_http_apis(self, plugin_id: str) -> list[dict[str, Any]]:
        return [
            {
                "route": entry.route,
                "methods": list(entry.methods),
                "handler_capability": entry.handler_capability,
                "description": entry.description,
            }
            for entry in self._http_routes.get(plugin_id, [])
        ]

    async def dispatch_http_request(
        self,
        route: str,
        method: str,
    ) -> dict[str, Any] | None:
        """Forward the current Quart request to the owning plugin's worker.

        Returns None when no registered route matches (caller falls through to
        its own 404 handling); raises on worker unavailability or a non-dict
        handler response.
        """
        resolved = self._resolve_http_route(route, method)
        if resolved is None:
            return None
        record, route_entry = resolved
        if record.session is None:
            raise AstrBotError.invalid_input("SDK HTTP route worker is unavailable")
        text_body = await quart_request.get_data(as_text=True)
        payload = {
            "method": method.upper(),
            "route": route_entry.route,
            "path": quart_request.path,
            "query": quart_request.args.to_dict(flat=False),
            "headers": dict(quart_request.headers),
            "json_body": await quart_request.get_json(silent=True),
            "text_body": text_body,
        }
        output = await record.session.invoke_capability(
            route_entry.handler_capability,
            payload,
            request_id=f"sdk_http_{record.plugin_id}_{uuid.uuid4().hex}",
        )
        if not isinstance(output, dict):
            raise AstrBotError.invalid_input("SDK HTTP handler must return an object")
        return output

    def register_session_waiter(self, *, plugin_id: str, session_key: str) -> None:
        """Mark a session as awaited by a plugin (see dispatch_message)."""
        if not session_key:
            raise AstrBotError.invalid_input(
                "session waiter registration requires session_key"
            )
        self._session_waiters.setdefault(plugin_id, set()).add(session_key)

    def unregister_session_waiter(self, *, plugin_id: str, session_key: str) -> None:
        plugin_waiters = self._session_waiters.get(plugin_id)
        if plugin_waiters is None:
            return
        plugin_waiters.discard(session_key)
        # Drop the plugin entry entirely once its waiter set is empty.
        if not plugin_waiters:
            self._session_waiters.pop(plugin_id, None)

    async def dispatch_message(self, event: AstrMessageEvent) -> SdkDispatchResult:
        """Dispatch one core message event to all matching SDK handlers.

        Skips entirely when legacy plugins already stopped/replied to the
        event, short-circuits to waiter plugins when the session is awaited,
        and otherwise runs matched handlers in order, aggregating their
        sent/stop/call-LLM flags into the shared overlay state.
        """
        result = SdkDispatchResult()
        if event.is_stopped():
            result.skipped_reason = SKIP_LEGACY_STOPPED
            return result
        if self._legacy_has_replied(event):
            result.skipped_reason = SKIP_LEGACY_REPLIED
            return result

        waiter_plugins = self._match_waiter_plugins(event.unified_msg_origin)
        if waiter_plugins:
            return await self._dispatch_waiter_event(event, waiter_plugins)

        # Reuse an existing dispatch token (re-entry) or mint a fresh one.
        dispatch_token = self._get_dispatch_token(event) or uuid.uuid4().hex
        self._bind_dispatch_token(event, dispatch_token)
        overlay = self._ensure_request_overlay(
            dispatch_token,
            # NOTE(review): `call_llm` semantics appear inverted here — confirm
            # against AstrMessageEvent.should_call_llm.
            should_call_llm=not bool(getattr(event, "call_llm", False)),
        )
        matches = self._match_handlers(event)
        if not matches:
            result.skipped_reason = SKIP_NO_MATCH
            return result
        result.matched_handlers = [
            {"plugin_id": match.plugin_id, "handler_id": match.handler_id}
            for match in matches
        ]

        dispatch_state = _DispatchState(event=event)
        request_context = self._request_contexts.get(dispatch_token)
        if request_context is None:
            request_context = _RequestContext(
                plugin_id="",
                request_id="",
                dispatch_token=dispatch_token,
                dispatch_state=dispatch_state,
            )
            self._request_contexts[dispatch_token] = request_context
        else:
            request_context.dispatch_state = dispatch_state
        skipped_reason = None
        for match in matches:
            whitelist = (
                None
                if overlay.handler_whitelist is None
                else set(overlay.handler_whitelist)
            )
            if whitelist is not None and match.plugin_id not in whitelist:
                continue
            record = self._records.get(match.plugin_id)
            if record is None:
                continue
            if record.state == SDK_STATE_RELOADING:
                # Remember the first skip reason only.
                skipped_reason = skipped_reason or SKIP_SDK_RELOADING
                continue
            if (
                record.state in {SDK_STATE_FAILED, SDK_STATE_DISABLED}
                or record.session is None
            ):
                skipped_reason = skipped_reason or SKIP_WORKER_FAILED
                continue

            request_id = f"sdk_{record.plugin_id}_{uuid.uuid4().hex}"
            request_context.plugin_id = record.plugin_id
            request_context.request_id = request_id
            request_context.cancelled = False
            setattr(event, "_sdk_last_request_id", request_id)
            payload = EventConverter.core_to_sdk(
                event,
                dispatch_token=dispatch_token,
                plugin_id=record.plugin_id,
                request_id=request_id,
            )
            self._apply_request_scoped_event_payload(payload, overlay)
            # Run the worker RPC as a task so it can be cancelled from the
            # in-flight registry.
            task = asyncio.create_task(
                record.session.invoke_handler(
                    match.handler_id,
                    payload,
                    request_id=request_id,
                    args=match.args,
                )
            )
            self._track_request_scope(
                dispatch_token=dispatch_token,
                request_id=request_id,
                plugin_id=record.plugin_id,
            )
            self._plugin_requests.setdefault(record.plugin_id, {})[request_id] = (
                _InFlightRequest(
                    request_id=request_id,
                    dispatch_token=dispatch_token,
                    task=task,
                )
            )

            try:
                output = await task
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # Worker failure degrades this handler to a no-op result.
                logger.warning(
                    "SDK handler failed: plugin=%s handler=%s error=%s",
                    record.plugin_id,
                    match.handler_id,
                    exc,
                )
                skipped_reason = skipped_reason or SKIP_WORKER_FAILED
                output = {}
            finally:
                # Always clear the in-flight slot, success or failure.
                inflight = self._plugin_requests.get(record.plugin_id, {}).pop(
                    request_id,
                    None,
                )

            # A logically cancelled request must not influence the outcome.
            if inflight is not None and inflight.logical_cancelled:
                continue

            handler_result = EventConverter.extract_handler_result(
                output if isinstance(output, dict) else {}
            )
            if isinstance(output, dict) and "sdk_local_extras" in output:
                self._persist_sdk_local_extras_from_handler(
                    overlay,
                    output.get("sdk_local_extras"),
                    plugin_id=record.plugin_id,
                    handler_id=match.handler_id,
                )
            result.executed_handlers.append(
                {"plugin_id": record.plugin_id, "handler_id": match.handler_id}
            )
            dispatch_state.sent_message = (
                dispatch_state.sent_message or handler_result["sent_message"]
            )
            dispatch_state.stopped = dispatch_state.stopped or handler_result["stop"]
            if handler_result["call_llm"]:
                overlay.requested_llm = True
                overlay.should_call_llm = True
            # Sending or stopping suppresses the follow-up LLM call.
            if handler_result["sent_message"] or handler_result["stop"]:
                overlay.should_call_llm = False
            if handler_result["stop"]:
                break

        result.sent_message = dispatch_state.sent_message
        result.stopped = dispatch_state.stopped
        if not result.executed_handlers:
            result.skipped_reason = skipped_reason or SKIP_NO_MATCH
        if result.sent_message:
            event._has_send_oper = True
            overlay.should_call_llm = False
            event.should_call_llm(True)
        if result.stopped:
            event.stop_event()
            overlay.should_call_llm = False
            event.should_call_llm(True)
        return result

    def resolve_request_plugin_id(self, request_id: str) -> str:
        """Map a request id to its plugin id; raises for unknown ids."""
        plugin_id = self._request_plugin_ids.get(request_id)
        if plugin_id is not None:
            return plugin_id
        # Fall back through the dispatch-token indirection.
        token = self._request_id_to_token.get(request_id)
        if token is not None and token in self._request_contexts:
            return self._request_contexts[token].plugin_id
        raise AstrBotError.invalid_input(f"Unknown SDK request id: {request_id}")

    def resolve_request_session(self, request_id: str) -> _RequestContext | None:
        token = self._request_id_to_token.get(request_id)
        if token is None:
            return None
        return self._request_contexts.get(token)

    def get_request_context_by_token(
        self, dispatch_token: str
    ) -> _RequestContext | None:
        return self._request_contexts.get(dispatch_token)

    def _bind_dispatch_token(
        self, event: AstrMessageEvent, dispatch_token: str
    ) -> None:
        # The token is stashed on the event object itself so it survives
        # hand-offs through core pipeline stages.
        setattr(event, "_sdk_dispatch_token", dispatch_token)

    def _get_dispatch_token(self, event: AstrMessageEvent) -> str | None:
        token = getattr(event, "_sdk_dispatch_token", None)
        return str(token) if token else None

    def _schedule_overlay_cleanup(
        self, dispatch_token: str
    ) -> asyncio.Task[None] | None:
        """Arm the overlay TTL; returns None when no event loop is running."""

        async def _cleanup_later() -> None:
            try:
                await asyncio.sleep(OVERLAY_TIMEOUT_SECONDS)
            except asyncio.CancelledError:
                return
            self._close_request_overlay(dispatch_token)

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return None
        return loop.create_task(_cleanup_later())

    def _ensure_request_overlay(
        self,
        dispatch_token: str,
        *,
        should_call_llm: bool,
    ) -> _RequestOverlayState:
        """Fetch or create the overlay for a token, (re)arming its TTL."""
        overlay = self._request_overlays.get(dispatch_token)
        if overlay is not None:
            # Reopen a closed overlay on re-dispatch of the same token.
            if overlay.closed:
                overlay.closed = False
            if overlay.cleanup_task is None or overlay.cleanup_task.done():
                overlay.cleanup_task = self._schedule_overlay_cleanup(dispatch_token)
            return overlay
        overlay = _RequestOverlayState(
            dispatch_token=dispatch_token,
            should_call_llm=should_call_llm,
            cleanup_task=self._schedule_overlay_cleanup(dispatch_token),
        )
        self._request_overlays[dispatch_token] = overlay
        return overlay

    def _track_request_scope(
        self,
        *,
        dispatch_token: str,
        request_id: str,
        plugin_id: str,
    ) -> None:
        # request-scoped system.event.* calls may outlive the original handler RPC
        # when plugin code moves follow-up work into background tasks.
+ self._request_id_to_token[request_id] = dispatch_token + self._request_plugin_ids[request_id] = plugin_id + overlay = self._request_overlays.get(dispatch_token) + if overlay is not None: + overlay.request_scope_ids.add(request_id) + + def _close_request_overlay(self, dispatch_token: str) -> None: + overlay = self._request_overlays.pop(dispatch_token, None) + if overlay is not None: + overlay.closed = True + if overlay.cleanup_task is not None: + overlay.cleanup_task.cancel() + for request_id in overlay.request_scope_ids: + self._request_id_to_token.pop(request_id, None) + self._request_plugin_ids.pop(request_id, None) + request_context = self._request_contexts.pop(dispatch_token, None) + if request_context is not None: + request_context.cancelled = True + + def close_request_overlay_for_event(self, event: AstrMessageEvent) -> None: + dispatch_token = self._get_dispatch_token(event) + if not dispatch_token: + return + self._close_request_overlay(dispatch_token) + + def get_request_overlay_by_token( + self, dispatch_token: str + ) -> _RequestOverlayState | None: + overlay = self._request_overlays.get(dispatch_token) + if overlay is None or overlay.closed: + return None + return overlay + + def get_request_overlay_by_request_id( + self, request_id: str + ) -> _RequestOverlayState | None: + token = self._request_id_to_token.get(request_id) + if not token: + return None + return self.get_request_overlay_by_token(token) + + def request_llm_for_request(self, request_id: str) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.requested_llm = True + overlay.should_call_llm = True + return True + + def get_effective_should_call_llm(self, event: AstrMessageEvent) -> bool: + dispatch_token = self._get_dispatch_token(event) + if dispatch_token: + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is not None: + return overlay.should_call_llm + return not bool(getattr(event, "call_llm", 
False)) + + def get_should_call_llm_for_request(self, request_id: str) -> bool | None: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return None + return overlay.should_call_llm + + def set_result_for_request( + self, + request_id: str, + result_payload: dict[str, Any] | None, + ) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + if result_payload is None: + overlay.result_payload = None + overlay.result_object = None + else: + normalized_payload = json.loads(json.dumps(result_payload)) + overlay.result_payload = normalized_payload + chain_payload = normalized_payload.get("chain") + overlay.result_object = ( + self._build_core_result_from_chain_payload(chain_payload) + if isinstance(chain_payload, list) + else None + ) + overlay.result_is_set = True + return True + + def clear_result_for_request(self, request_id: str) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.result_payload = None + overlay.result_object = None + overlay.result_is_set = True + return True + + def get_result_payload_for_request(self, request_id: str) -> dict[str, Any] | None: + overlay = self.get_request_overlay_by_request_id(request_id) + request_context = self.resolve_request_session(request_id) + request_context_has_event = False + if request_context is not None: + has_event = getattr(request_context, "has_event", None) + request_context_has_event = ( + bool(has_event) + if has_event is not None + else hasattr(request_context, "event") + ) + if overlay is not None and overlay.result_is_set: + if overlay.result_object is not None: + overlay.result_payload = self._legacy_result_to_sdk_payload( + overlay.result_object + ) + return ( + json.loads(json.dumps(overlay.result_payload)) + if overlay.result_payload is not None + else None + ) + if request_context is None or not request_context_has_event: + return None + return 
self._legacy_result_to_sdk_payload(request_context.event.get_result()) + + def set_handler_whitelist_for_request( + self, + request_id: str, + plugin_names: set[str] | None, + ) -> bool: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return False + overlay.handler_whitelist = None if plugin_names is None else set(plugin_names) + return True + + def get_handler_whitelist_for_request(self, request_id: str) -> set[str] | None: + overlay = self.get_request_overlay_by_request_id(request_id) + if overlay is None: + return None + return ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + + def _get_handler_whitelist_for_event( + self, event: AstrMessageEvent + ) -> set[str] | None: + dispatch_token = self._get_dispatch_token(event) + if not dispatch_token: + return None + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return None + return ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + + @staticmethod + def _build_core_message_chain_from_payload( + chain_payload: list[dict[str, Any]], + ) -> MessageChain: + return _build_message_chain_from_payload(chain_payload) + + @classmethod + def _build_core_result_from_chain_payload( + cls, + chain_payload: list[dict[str, Any]], + ) -> MessageEventResult: + chain = cls._build_core_message_chain_from_payload(chain_payload) + result = MessageEventResult() + # Core stages currently treat result.chain as a MessageChain-like object and + # call get_plain_text()/mutate nested components on it directly. 
+ setattr(result, "chain", chain) + result.use_t2i_ = chain.use_t2i_ + result.type = chain.type + return result + + @staticmethod + def _legacy_result_to_sdk_payload( + result: MessageEventResult | None, + ) -> dict[str, Any] | None: + if result is None: + return None + chain = ( + result.chain.chain + if isinstance(result.chain, MessageChain) + else result.chain + ) + return { + "type": "chain" if chain else "empty", + "chain": SdkPluginBridge._components_to_sdk_payload(chain), + } + + @staticmethod + def _components_to_sdk_payload( + components: list[Any] | tuple[Any, ...] | None, + ) -> list[dict[str, Any]]: + return [ + component_to_payload_sync(component) for component in (components or []) + ] + + @classmethod + def _sanitize_sdk_extra_value(cls, value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + items = [] + for item in value: + normalized = cls._sanitize_sdk_extra_value(item) + if normalized is not cls._DROP_VALUE: + items.append(normalized) + return items + if isinstance(value, dict): + normalized_dict: dict[str, Any] = {} + for key, item in value.items(): + normalized = cls._sanitize_sdk_extra_value(item) + if normalized is not cls._DROP_VALUE: + normalized_dict[str(key)] = normalized + return normalized_dict + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + try: + return cls._sanitize_sdk_extra_value(model_dump()) + except Exception: + return cls._DROP_VALUE + try: + json.dumps(value) + except (TypeError, ValueError): + return cls._DROP_VALUE + return value + + @classmethod + def _normalize_sdk_local_extras( + cls, + payload: Any, + ) -> tuple[dict[str, Any], list[str]]: + if not isinstance(payload, dict): + return {}, [] + normalized: dict[str, Any] = {} + dropped_keys: list[str] = [] + for key, value in payload.items(): + normalized_value = cls._sanitize_sdk_extra_value(value) + if normalized_value is cls._DROP_VALUE: + 
dropped_keys.append(str(key)) + continue + normalized[str(key)] = normalized_value + return normalized, dropped_keys + + @classmethod + def _apply_request_scoped_event_payload( + cls, + event_payload: dict[str, Any], + overlay: _RequestOverlayState, + ) -> None: + host_extras = ( + dict(event_payload["extras"]) + if isinstance(event_payload.get("extras"), dict) + else {} + ) + sdk_local_extras = dict(overlay.sdk_local_extras) + merged_extras = dict(host_extras) + merged_extras.update(sdk_local_extras) + event_payload["host_extras"] = host_extras + event_payload["sdk_local_extras"] = sdk_local_extras + event_payload["extras"] = merged_extras + + @classmethod + def _persist_sdk_local_extras_from_handler( + cls, + overlay: _RequestOverlayState, + payload: Any, + *, + plugin_id: str, + handler_id: str, + ) -> None: + if payload is None: + overlay.sdk_local_extras = {} + return + if not isinstance(payload, dict): + logger.warning( + "SDK event handler returned invalid sdk_local_extras: plugin=%s handler=%s payload_type=%s", + plugin_id, + handler_id, + type(payload).__name__, + ) + return + normalized, dropped_keys = cls._normalize_sdk_local_extras(payload) + overlay.sdk_local_extras = normalized + for key in dropped_keys: + logger.warning( + "Dropped non-serializable sdk_local_extras entry: plugin=%s handler=%s key=%s", + plugin_id, + handler_id, + key, + ) + + @staticmethod + def _core_provider_request_to_sdk_payload( + request: CoreProviderRequest, + ) -> dict[str, Any]: + tool_calls_result: list[dict[str, Any]] = [] + raw_results = request.tool_calls_result + if raw_results is not None: + if not isinstance(raw_results, list): + raw_results = [raw_results] + for item in raw_results: + if not getattr(item, "tool_calls_result", None): + continue + tool_name_by_id: dict[str, str] = {} + tool_calls_info = getattr(item, "tool_calls_info", None) + raw_tool_calls = getattr(tool_calls_info, "tool_calls", None) + if isinstance(raw_tool_calls, list): + for tool_call in 
raw_tool_calls: + if isinstance(tool_call, dict): + tool_call_id = tool_call.get("id") + function_payload = tool_call.get("function") + if isinstance(function_payload, dict): + tool_name = function_payload.get("name") + else: + tool_name = None + else: + tool_call_id = getattr(tool_call, "id", None) + function_payload = getattr(tool_call, "function", None) + tool_name = getattr(function_payload, "name", None) + if tool_call_id is None or tool_name is None: + continue + tool_name_by_id[str(tool_call_id)] = str(tool_name) + for tool_result in item.tool_calls_result: + tool_name = "" + tool_call_id = getattr(tool_result, "tool_call_id", None) + content = getattr(tool_result, "content", "") + success = True + if tool_call_id is not None: + tool_name = tool_name_by_id.get(str(tool_call_id), "") + tool_calls_result.append( + { + "tool_call_id": str(tool_call_id) + if tool_call_id is not None + else None, + "tool_name": tool_name, + "content": str(content or ""), + "success": bool(success), + } + ) + return { + "prompt": request.prompt, + "system_prompt": request.system_prompt or None, + "session_id": request.session_id or None, + "contexts": json.loads(json.dumps(request.contexts or [])), + "image_urls": list(request.image_urls or []), + "tool_calls_result": tool_calls_result, + "model": request.model, + } + + @staticmethod + def _apply_sdk_provider_request_payload( + request: CoreProviderRequest, + payload: dict[str, Any], + ) -> None: + prompt = payload.get("prompt") + request.prompt = None if prompt is None else str(prompt) + system_prompt = payload.get("system_prompt") + request.system_prompt = "" if system_prompt is None else str(system_prompt) + session_id = payload.get("session_id") + request.session_id = None if session_id is None else str(session_id) + + contexts = payload.get("contexts") + if isinstance(contexts, list): + request.contexts = json.loads(json.dumps(contexts)) + + image_urls = payload.get("image_urls") + if isinstance(image_urls, list): + 
request.image_urls = [str(item) for item in image_urls] + + model = payload.get("model") + request.model = None if model is None else str(model) + + @staticmethod + def _core_llm_response_to_sdk_payload( + response: CoreLLMResponse, + ) -> dict[str, Any]: + usage_payload = None + if response.usage is not None: + usage_payload = { + "input_tokens": response.usage.input, + "output_tokens": response.usage.output, + "total_tokens": response.usage.total, + "input_cached_tokens": response.usage.input_cached, + } + tool_calls: list[dict[str, Any]] = [] + for idx, tool_name in enumerate(response.tools_call_name): + tool_calls.append( + { + "id": ( + response.tools_call_ids[idx] + if idx < len(response.tools_call_ids) + else None + ), + "name": tool_name, + "arguments": ( + response.tools_call_args[idx] + if idx < len(response.tools_call_args) + else {} + ), + "extra_content": ( + response.tools_call_extra_content.get( + response.tools_call_ids[idx] + ) + if idx < len(response.tools_call_ids) + else None + ), + } + ) + return { + "text": response.completion_text or "", + "usage": usage_payload, + "finish_reason": "tool_calls" if tool_calls else "stop", + "tool_calls": tool_calls, + "role": response.role, + "reasoning_content": response.reasoning_content or None, + "reasoning_signature": response.reasoning_signature, + } + + @classmethod + def _apply_sdk_result_payload( + cls, + result: MessageEventResult, + payload: dict[str, Any], + ) -> MessageEventResult: + chain_payload = payload.get("chain") + updated = ( + cls._build_core_result_from_chain_payload(chain_payload) + if isinstance(chain_payload, list) + else MessageEventResult() + ) + result.chain = updated.chain + result.use_t2i_ = updated.use_t2i_ + result.type = updated.type + return result + + def get_effective_result( + self, event: AstrMessageEvent + ) -> MessageEventResult | None: + dispatch_token = self._get_dispatch_token(event) + if dispatch_token: + overlay = self.get_request_overlay_by_token(dispatch_token) + 
if overlay is not None and overlay.result_is_set: + if overlay.result_payload is None: + return None + if overlay.result_object is None: + chain_payload = overlay.result_payload.get("chain") + if not isinstance(chain_payload, list): + return None + overlay.result_object = self._build_core_result_from_chain_payload( + chain_payload + ) + return overlay.result_object + return event.get_result() + + def before_platform_send(self, dispatch_token: str) -> None: + request_context = self._request_contexts.get(dispatch_token) + if request_context is None: + raise AstrBotError.invalid_input( + "Unknown SDK dispatch token for platform send" + ) + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + raise AstrBotError.cancelled("The SDK request overlay has been closed") + if request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") + + def mark_platform_send(self, dispatch_token: str) -> str: + request_context = self._request_contexts.get(dispatch_token) + if request_context is None: + raise AstrBotError.invalid_input( + "Unknown SDK dispatch token for platform send" + ) + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + raise AstrBotError.cancelled("The SDK request overlay has been closed") + if request_context.cancelled: + raise AstrBotError.cancelled("The SDK request has been cancelled") + if request_context.dispatch_state is not None: + request_context.dispatch_state.sent_message = True + overlay.should_call_llm = False + if request_context.has_event: + request_context.event._has_send_oper = True + return f"sdk_{dispatch_token}" + + @staticmethod + def _legacy_has_replied(event: AstrMessageEvent) -> bool: + return getattr(event, "_has_send_oper", False) + + def _match_handlers(self, event: AstrMessageEvent) -> list[TriggerMatch]: + matches: list[TriggerMatch] = [] + normalized_platform = self._normalize_platform_name(event.get_platform_name()) + for record in 
self._records.values(): + if record.state in {SDK_STATE_DISABLED, SDK_STATE_FAILED}: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + for handler in record.handlers: + match = TriggerConverter.match_handler( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + event=event, + load_order=record.load_order, + declaration_order=handler.declaration_order, + ) + if match is not None: + matches.append(match) + dynamic_base_order = len(record.handlers) + for route in getattr(record, "dynamic_command_routes", []): + match = self._match_dynamic_command_route( + record=record, + route=route, + event=event, + declaration_order=dynamic_base_order + route.declaration_order, + ) + if match is not None: + matches.append(match) + matches.sort(key=TriggerConverter.sort_key) + return matches + + def _match_dynamic_command_route( + self, + *, + record: SdkPluginRecord, + route: SdkDynamicCommandRoute, + event: AstrMessageEvent, + declaration_order: int, + ) -> TriggerMatch | None: + handler_ref = self._find_handler_ref(record, route.handler_full_name) + if handler_ref is None: + return None + descriptor = handler_ref.descriptor.model_copy(deep=True) + descriptor.priority = route.priority + if route.use_regex: + descriptor.trigger = MessageTrigger(regex=route.command_name) + else: + descriptor.trigger = CommandTrigger( + command=route.command_name, + description=route.desc or None, + ) + return TriggerConverter.match_handler( + plugin_id=record.plugin_id, + descriptor=descriptor, + event=event, + load_order=record.load_order, + declaration_order=declaration_order, + ) + + @staticmethod + def _find_handler_ref( + record: SdkPluginRecord, + handler_full_name: str, + ) -> SdkHandlerRef | None: + for handler in record.handlers: + if handler.descriptor.id == handler_full_name: + return handler + return None + + async def dispatch_system_event( + self, + event_type: str, + payload: dict[str, Any] | None = None, + ) -> None: + 
normalized_platform = self._normalize_platform_name( + (payload or {}).get("platform") + ) + event_payload = { + "type": event_type, + "event_type": event_type, + "text": str((payload or {}).get("message_outline", "")), + "session_id": str((payload or {}).get("session_id", "")), + "platform": str((payload or {}).get("platform", "")), + "platform_id": str((payload or {}).get("platform_id", "")), + "message_type": EventConverter._sdk_message_type( + (payload or {}).get("message_type", "") + ), + "sender_name": str((payload or {}).get("sender_name", "")), + "self_id": str((payload or {}).get("self_id", "")), + "raw": {"event_type": event_type, **(payload or {})}, + } + for key, value in (payload or {}).items(): + event_payload[key] = value + matches = self._match_event_handlers( + event_type, + platform_name=normalized_platform, + ) + for record, descriptor in matches: + if record.session is None: + continue + try: + await record.session.invoke_handler( + descriptor.id, + event_payload, + request_id=f"sdk_event_{record.plugin_id}_{uuid.uuid4().hex}", + args={}, + ) + except Exception as exc: + logger.warning( + "SDK event handler failed: plugin=%s handler=%s error=%s", + record.plugin_id, + descriptor.id, + exc, + ) + + async def dispatch_message_event( + self, + event_type: str, + event: AstrMessageEvent, + payload: dict[str, Any] | None = None, + *, + provider_request: CoreProviderRequest | None = None, + llm_response: CoreLLMResponse | None = None, + event_result: MessageEventResult | None = None, + ) -> None: + dispatch_token = self._get_dispatch_token(event) + if not dispatch_token: + return + overlay = self.get_request_overlay_by_token(dispatch_token) + if overlay is None: + return + normalized_platform = self._normalize_platform_name(event.get_platform_name()) + matches = self._match_event_handlers( + event_type, + allowed_plugins=overlay.handler_whitelist, + platform_name=normalized_platform, + ) + for record, descriptor in matches: + if record.session is 
None: + continue + request_id = f"sdk_event_{record.plugin_id}_{uuid.uuid4().hex}" + request_context = self._request_contexts.get(dispatch_token) + if request_context is None: + request_context = _RequestContext( + plugin_id=record.plugin_id, + request_id=request_id, + dispatch_token=dispatch_token, + dispatch_state=_DispatchState(event=event), + ) + self._request_contexts[dispatch_token] = request_context + request_context.plugin_id = record.plugin_id + request_context.request_id = request_id + request_context.dispatch_state.event = event + request_context.cancelled = False + self._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=record.plugin_id, + ) + event_payload = EventConverter.core_to_sdk( + event, + dispatch_token=dispatch_token, + plugin_id=record.plugin_id, + request_id=request_id, + ) + event_payload["type"] = event_type + event_payload["event_type"] = event_type + event_payload["raw"] = { + **( + event_payload["raw"] + if isinstance(event_payload.get("raw"), dict) + else {} + ), + "event_type": event_type, + **(payload or {}), + } + for key, value in (payload or {}).items(): + event_payload[key] = value + self._apply_request_scoped_event_payload(event_payload, overlay) + if provider_request is not None: + request_payload = self._core_provider_request_to_sdk_payload( + provider_request + ) + event_payload["provider_request"] = request_payload + if isinstance(event_payload["raw"], dict): + event_payload["raw"]["provider_request"] = request_payload + if llm_response is not None: + response_payload = self._core_llm_response_to_sdk_payload(llm_response) + event_payload["llm_response"] = response_payload + if isinstance(event_payload["raw"], dict): + event_payload["raw"]["llm_response"] = response_payload + if event_result is not None: + result_payload = self._legacy_result_to_sdk_payload(event_result) + if result_payload is not None: + event_payload["event_result"] = result_payload + if 
isinstance(event_payload["raw"], dict): + event_payload["raw"]["event_result"] = result_payload + try: + output = await record.session.invoke_handler( + descriptor.id, + event_payload, + request_id=request_id, + args={}, + ) + if isinstance(output, dict): + if "sdk_local_extras" in output: + self._persist_sdk_local_extras_from_handler( + overlay, + output.get("sdk_local_extras"), + plugin_id=record.plugin_id, + handler_id=descriptor.id, + ) + request_payload = output.get("provider_request") + if provider_request is not None and isinstance( + request_payload, dict + ): + self._apply_sdk_provider_request_payload( + provider_request, + request_payload, + ) + result_payload = output.get("event_result") + if event_result is not None and isinstance(result_payload, dict): + if not self.set_result_for_request(request_id, result_payload): + self._apply_sdk_result_payload(event_result, result_payload) + except Exception as exc: + logger.warning( + "SDK event handler failed: plugin=%s handler=%s error=%s", + record.plugin_id, + descriptor.id, + exc, + ) + + def _match_event_handlers( + self, + event_type: str, + *, + allowed_plugins: set[str] | None = None, + platform_name: str = "", + ) -> list[tuple[SdkPluginRecord, HandlerDescriptor]]: + matches: list[tuple[int, int, int, SdkPluginRecord, HandlerDescriptor]] = [] + for record in self._records.values(): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if allowed_plugins is not None and record.plugin_id not in allowed_plugins: + continue + if not self._record_supports_platform(record, platform_name): + continue + for handler in record.handlers: + trigger = handler.descriptor.trigger + if not isinstance(trigger, EventTrigger): + continue + if trigger.event_type != event_type: + continue + if not self._descriptor_supports_platform( + handler.descriptor, + platform_name, + ): + continue + matches.append( + ( + -handler.descriptor.priority, + record.load_order, + 
handler.declaration_order, + record, + handler.descriptor, + ) + ) + matches.sort(key=lambda item: (item[0], item[1], item[2])) + return [(record, descriptor) for _, _, _, record, descriptor in matches] + + @staticmethod + def _descriptor_event_types(descriptor: HandlerDescriptor) -> list[str]: + trigger = descriptor.trigger + if isinstance(trigger, EventTrigger): + return [trigger.event_type] + return [] + + @staticmethod + def _descriptor_group_path(descriptor: HandlerDescriptor) -> list[str]: + route = getattr(descriptor, "command_route", None) + if route is None: + return [] + return list(route.group_path) + + @staticmethod + def _descriptor_description(descriptor: HandlerDescriptor) -> str | None: + description = str(descriptor.description or "").strip() + if description: + return description + trigger = descriptor.trigger + if isinstance(trigger, CommandTrigger): + command_description = str(trigger.description or "").strip() + if command_description: + return command_description + return None + + def _descriptor_metadata( + self, + *, + plugin_id: str, + descriptor: HandlerDescriptor, + ) -> dict[str, Any]: + return { + "plugin_name": plugin_id, + "handler_full_name": descriptor.id, + "trigger_type": getattr(descriptor.trigger, "type", ""), + "description": self._descriptor_description(descriptor), + "event_types": self._descriptor_event_types(descriptor), + "enabled": True, + "group_path": self._descriptor_group_path(descriptor), + "priority": descriptor.priority, + "kind": descriptor.kind, + "require_admin": descriptor.permissions.require_admin, + "required_role": descriptor.permissions.required_role, + } + + def get_handlers_by_event_type(self, event_type: str) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + for record in sorted(self._records.values(), key=lambda item: item.load_order): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + for handler in record.handlers: + trigger = 
handler.descriptor.trigger + if ( + isinstance(trigger, EventTrigger) + and trigger.event_type == event_type + ): + entries.append( + self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + ) + ) + if event_type == "message": + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + entries.append( + self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=descriptor, + ) + ) + return entries + + def list_native_command_candidates( + self, + platform_name: str, + ) -> list[dict[str, Any]]: + """Expose SDK commands that can be surfaced in native platform menus. + + Native platform command menus are top-level and single-token, so grouped + SDK commands are exported as their root command (for example ``gf`` for + ``gf chat`` / ``gf affection``). + """ + normalized_platform = str(platform_name).strip().lower() + if not normalized_platform: + return [] + + entries: list[dict[str, Any]] = [] + seen_names: set[str] = set() + + for record in sorted(self._records.values(), key=lambda item: item.load_order): + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if not self._record_supports_platform(record, normalized_platform): + continue + + for handler in record.handlers: + for entry in self._descriptor_native_command_candidates( + handler.descriptor, + platform_name=normalized_platform, + ): + name = str(entry.get("name", "")).strip().lower() + if not name or name in seen_names: + continue + seen_names.add(name) + entries.append(entry) + + for route in getattr(record, "dynamic_command_routes", []): + descriptor = self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + for entry in self._descriptor_native_command_candidates( + descriptor, + platform_name=normalized_platform, + ): + name = str(entry.get("name", "")).strip().lower() + 
if not name or name in seen_names: + continue + seen_names.add(name) + entries.append(entry) + + return entries + + def get_handler_by_full_name(self, full_name: str) -> dict[str, Any] | None: + for record in self._records.values(): + for handler in record.handlers: + if handler.descriptor.id == full_name: + return self._descriptor_metadata( + plugin_id=record.plugin_id, + descriptor=handler.descriptor, + ) + return None + + def list_dashboard_commands(self) -> list[dict[str, Any]]: + items: list[dict[str, Any]] = [] + for record in sorted(self._records.values(), key=lambda item: item.load_order): + items.extend(self._build_dashboard_command_items(record)) + items.sort(key=lambda item: str(item.get("effective_command", "")).lower()) + return items + + def list_dashboard_tools(self) -> list[dict[str, Any]]: + tools: list[dict[str, Any]] = [] + for record in sorted(self._records.values(), key=lambda item: item.load_order): + display_name = str( + record.plugin.manifest_data.get("display_name") or record.plugin_id + ) + plugin_enabled = record.state not in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + } + for spec in sorted(record.llm_tools.values(), key=lambda item: item.name): + tools.append( + { + "tool_key": (f"sdk:{record.plugin_id}:{spec.name}"), + "name": spec.name, + "description": spec.description, + "parameters": dict(spec.parameters_schema), + "active": bool(spec.active) and plugin_enabled, + "origin": "sdk_plugin", + "origin_name": display_name, + "runtime_kind": "sdk", + "plugin_id": record.plugin_id, + } + ) + return tools + + def _build_dashboard_command_items( + self, + record: SdkPluginRecord, + ) -> list[dict[str, Any]]: + flat_commands: list[dict[str, Any]] = [] + for handler in record.handlers: + entry = self._build_dashboard_command_entry( + record=record, + descriptor=handler.descriptor, + ) + if entry is not None: + flat_commands.append(entry) + for route in getattr(record, "dynamic_command_routes", []): + descriptor = 
self._build_dynamic_route_descriptor(record, route) + if descriptor is None: + continue + entry = self._build_dashboard_command_entry( + record=record, + descriptor=descriptor, + route=route, + ) + if entry is not None: + flat_commands.append(entry) + + groups: dict[str, dict[str, Any]] = {} + root_items: list[dict[str, Any]] = [] + for entry in flat_commands: + parent_signature = str(entry.get("parent_signature", "")).strip() + if not parent_signature: + root_items.append(entry) + continue + group_key = self._dashboard_group_key(record.plugin_id, parent_signature) + group = groups.get(group_key) + if group is None: + group = { + "command_key": group_key, + "handler_full_name": group_key, + "handler_name": parent_signature.split()[-1] or record.plugin_id, + "plugin": record.plugin_id, + "plugin_display_name": str( + record.plugin.manifest_data.get("display_name") + or record.plugin_id + ), + "module_path": str(record.plugin.plugin_dir), + "description": entry.pop("_group_help", "") or "", + "type": "group", + "parent_signature": "", + "parent_group_handler": "", + "original_command": parent_signature, + "current_fragment": parent_signature.split()[-1] + if parent_signature + else "", + "effective_command": parent_signature, + "aliases": [], + "permission": "everyone", + "enabled": bool(entry.get("enabled", False)), + "is_group": True, + "has_conflict": False, + "reserved": False, + "runtime_kind": "sdk", + "supports_toggle": False, + "supports_rename": False, + "supports_permission": False, + "sub_commands": [], + } + groups[group_key] = group + root_items.append(group) + elif not group.get("description") and entry.get("_group_help"): + group["description"] = entry["_group_help"] + + if entry.get("permission") == "admin": + group["permission"] = "admin" + group["enabled"] = bool(group["enabled"]) or bool( + entry.get("enabled", False) + ) + entry["parent_group_handler"] = group["handler_full_name"] + entry.pop("_group_help", None) + 
group["sub_commands"].append(entry) + + for group in groups.values(): + group["sub_commands"].sort( + key=lambda item: str(item.get("effective_command", "")).lower() + ) + for item in root_items: + item.pop("_group_help", None) + return root_items + + def _build_dashboard_command_entry( + self, + *, + record: SdkPluginRecord, + descriptor: HandlerDescriptor, + route: SdkDynamicCommandRoute | None = None, + ) -> dict[str, Any] | None: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return None + + route_meta = descriptor.command_route + effective_command = ( + str(route_meta.display_command).strip() + if route_meta is not None and str(route_meta.display_command).strip() + else str(trigger.command).strip() + ) + parent_signature = "" + group_help = "" + if route_meta is not None and route_meta.group_path: + parent_signature = " ".join( + str(item).strip() for item in route_meta.group_path if str(item).strip() + ).strip() + group_help = str(route_meta.group_help or "").strip() + + current_fragment = effective_command + if parent_signature and effective_command.startswith(f"{parent_signature} "): + current_fragment = effective_command[len(parent_signature) + 1 :].strip() + + enabled = record.state not in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + } + return { + "command_key": self._dashboard_command_key( + plugin_id=record.plugin_id, + handler_full_name=descriptor.id, + route=route, + ), + "handler_full_name": descriptor.id, + "handler_name": descriptor.id.rsplit(".", 1)[-1], + "plugin": record.plugin_id, + "plugin_display_name": str( + record.plugin.manifest_data.get("display_name") or record.plugin_id + ), + "module_path": descriptor.id.rsplit(".", 1)[0], + "description": self._descriptor_description(descriptor) or "", + "type": "sub_command" if parent_signature else "command", + "parent_signature": parent_signature, + "parent_group_handler": "", + "original_command": effective_command, + "current_fragment": 
current_fragment, + "effective_command": effective_command, + "aliases": list(trigger.aliases), + "permission": ( + "admin" if descriptor.permissions.require_admin else "everyone" + ), + "enabled": enabled, + "is_group": False, + "has_conflict": False, + "reserved": False, + "runtime_kind": "sdk", + "supports_toggle": False, + "supports_rename": False, + "supports_permission": False, + "sub_commands": [], + "_group_help": group_help, + } + + @staticmethod + def _dashboard_command_key( + *, + plugin_id: str, + handler_full_name: str, + route: SdkDynamicCommandRoute | None, + ) -> str: + if route is None: + return f"sdk:command:{plugin_id}:{handler_full_name}" + route_kind = "regex" if route.use_regex else "command" + return f"sdk:route:{plugin_id}:{handler_full_name}:{route_kind}:{route.command_name}" + + @staticmethod + def _dashboard_group_key(plugin_id: str, parent_signature: str) -> str: + return f"sdk:group:{plugin_id}:{parent_signature}" + + def _build_dynamic_route_descriptor( + self, + record: SdkPluginRecord, + route: SdkDynamicCommandRoute, + ) -> HandlerDescriptor | None: + handler_ref = self._find_handler_ref(record, route.handler_full_name) + if handler_ref is None: + return None + descriptor = handler_ref.descriptor.model_copy(deep=True) + descriptor.priority = route.priority + if route.use_regex: + descriptor.trigger = MessageTrigger(regex=route.command_name) + else: + descriptor.trigger = CommandTrigger( + command=route.command_name, + description=route.desc or None, + ) + return descriptor + + @staticmethod + def _normalize_platform_name(value: Any) -> str: + return str(value or "").strip().lower() + + @classmethod + def _normalized_platform_names(cls, values: Any) -> set[str]: + if not isinstance(values, list): + return set() + return { + cls._normalize_platform_name(item) + for item in values + if cls._normalize_platform_name(item) + } + + @classmethod + def _manifest_supported_platforms(cls, manifest_data: Any) -> set[str]: + if not 
isinstance(manifest_data, dict): + return set() + return cls._normalized_platform_names(manifest_data.get("support_platforms")) + + def plugin_supports_platform(self, plugin_id: str, platform_name: str) -> bool: + normalized_platform = self._normalize_platform_name(platform_name) + if not normalized_platform: + return True + record = self._records.get(str(plugin_id)) + if record is None: + return True + return self._record_supports_platform(record, normalized_platform) + + @staticmethod + def _record_supports_platform( + record: SdkPluginRecord, + platform_name: str, + ) -> bool: + normalized_platform = SdkPluginBridge._normalize_platform_name(platform_name) + if not normalized_platform: + return True + plugin = getattr(record, "plugin", None) + manifest_data = getattr(plugin, "manifest_data", None) + normalized = SdkPluginBridge._manifest_supported_platforms(manifest_data) + if not normalized: + return True + return normalized_platform in normalized + + @staticmethod + def _local_mcp_tool_name(server_name: str, tool_name: str) -> str: + return f"mcp.{server_name}.{tool_name}" + + @staticmethod + def _local_mcp_tool_ref(server_name: str, tool_name: str) -> str: + return json.dumps( + {"server_name": server_name, "tool_name": tool_name}, + ensure_ascii=True, + separators=(",", ":"), + ) + + @staticmethod + def _plugin_data_dir(plugin_id: str) -> Path: + return Path(get_astrbot_plugin_data_path()) / plugin_id + + @classmethod + def _plugin_mcp_lease_dir(cls, plugin_id: str) -> Path: + return cls._plugin_data_dir(plugin_id) / ".mcp_leases" + + def acknowledges_global_mcp_risk(self, plugin_id: str) -> bool: + record = self._records.get(plugin_id) + return bool(record and record.acknowledge_global_mcp_risk) + + def _load_local_mcp_configs(self, plugin: PluginSpec) -> dict[str, dict[str, Any]]: + config_path = plugin.plugin_dir / "mcp.json" + if not config_path.exists(): + return {} + try: + payload = json.loads(config_path.read_text(encoding="utf-8")) + except Exception 
as exc:
            logger.warning(
                "Failed to read SDK plugin mcp.json %s: %s", config_path, exc
            )
            return {}
        if not isinstance(payload, dict):
            logger.warning("Ignoring invalid SDK plugin mcp.json root: %s", config_path)
            return {}
        servers = payload.get("mcpServers")
        if not isinstance(servers, dict):
            logger.warning(
                "Ignoring SDK plugin mcp.json without mcpServers: %s", config_path
            )
            return {}
        # Keep only entries with a non-blank name and a dict-shaped config.
        return {
            str(name): dict(config)
            for name, config in servers.items()
            if str(name).strip() and isinstance(config, dict)
        }

    @classmethod
    def _build_local_mcp_tool_specs(
        cls,
        server_name: str,
        client: MCPClient,
    ) -> list[LLMToolSpec]:
        """Convert the tools reported by *client* into ``LLMToolSpec`` entries.

        Each tool is published under the namespaced name
        ``mcp.<server_name>.<tool_name>`` (see ``_local_mcp_tool_name``) with
        its handler reference JSON-encoded via ``_local_mcp_tool_ref`` and
        routed through the ``internal.mcp.local.execute`` capability. Tools
        with a blank name are skipped; a missing or non-dict ``inputSchema``
        falls back to an empty object schema.
        """
        specs: list[LLMToolSpec] = []
        for tool in client.tools:
            raw_tool_name = str(getattr(tool, "name", "")).strip()
            if not raw_tool_name:
                continue
            parameters_schema = getattr(tool, "inputSchema", None)
            if not isinstance(parameters_schema, dict):
                # Defensive default: advertise an empty-argument tool rather
                # than propagating a malformed schema.
                parameters_schema = {"type": "object", "properties": {}}
            specs.append(
                LLMToolSpec.create(
                    name=cls._local_mcp_tool_name(server_name, raw_tool_name),
                    description=str(getattr(tool, "description", "") or ""),
                    parameters_schema=dict(parameters_schema),
                    handler_ref=cls._local_mcp_tool_ref(server_name, raw_tool_name),
                    handler_capability="internal.mcp.local.execute",
                    active=True,
                )
            )
        return specs

    @staticmethod
    def _mcp_call_result_to_text(result: Any) -> str | None:
        """Flatten an MCP tool-call result's content list into plain text.

        Text items contribute their ``text`` verbatim; items exposing
        ``model_dump()`` are JSON-serialized; anything else non-None is
        stringified. Returns ``None`` when *result* has no list-shaped
        ``content`` or the joined text is empty.
        """
        content_items = getattr(result, "content", None)
        if not isinstance(content_items, list):
            return None
        chunks: list[str] = []
        for item in content_items:
            text = getattr(item, "text", None)
            if isinstance(text, str):
                chunks.append(text)
                continue
            model_dump = getattr(item, "model_dump", None)
            if callable(model_dump):
                chunks.append(json.dumps(model_dump(), ensure_ascii=False))
                continue
            if item is not None:
                chunks.append(str(item))
        return "\n".join(part for part in chunks if part).strip() or None

    async def _cleanup_mcp_client(self, client:
MCPClient | None) -> None: + if client is None: + return + with contextlib.suppress(Exception): + await client.cleanup() + + def _write_local_mcp_lease( + self, + *, + plugin_id: str, + server_name: str, + pid: int, + ) -> Path: + lease_dir = self._plugin_mcp_lease_dir(plugin_id) + lease_dir.mkdir(parents=True, exist_ok=True) + lease_path = lease_dir / f"{server_name}.json" + lease_path.write_text( + json.dumps( + { + "pid": int(pid), + "plugin_id": plugin_id, + "server_name": server_name, + }, + ensure_ascii=True, + indent=2, + ), + encoding="utf-8", + ) + return lease_path + + @staticmethod + def _remove_local_mcp_lease(runtime: _LocalMCPServerRuntime) -> None: + lease_path = runtime.lease_path + runtime.lease_path = None + if lease_path is None: + return + with contextlib.suppress(OSError): + lease_path.unlink() + + def _terminate_stale_mcp_pid(self, pid: int) -> None: + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return + except PermissionError: + logger.warning("Permission denied while terminating stale MCP pid %s", pid) + return + except OSError as exc: + logger.warning("Failed to terminate stale MCP pid %s: %s", pid, exc) + + def _sweep_stale_mcp_leases(self) -> None: + plugin_data_root = Path(get_astrbot_plugin_data_path()) + if not plugin_data_root.exists(): + return + for lease_path in plugin_data_root.glob("*/.mcp_leases/*.json"): + try: + payload = json.loads(lease_path.read_text(encoding="utf-8")) + except Exception: + payload = {} + pid = payload.get("pid") + if pid is not None: + with contextlib.suppress(TypeError, ValueError): + self._terminate_stale_mcp_pid(int(pid)) + with contextlib.suppress(OSError): + lease_path.unlink() + + async def _connect_local_mcp_server( + self, + *, + plugin_id: str, + runtime: _LocalMCPServerRuntime, + timeout: float, + ) -> None: + runtime.ready_event.clear() + runtime.running = False + runtime.last_error = None + runtime.errlogs = [] + runtime.tools = [] + runtime.tool_specs = [] + 
self._remove_local_mcp_lease(runtime) + await self._cleanup_mcp_client(runtime.client) + runtime.client = None + + client = MCPClient() + client.name = runtime.name + try: + await asyncio.wait_for( + client.connect_to_server(dict(runtime.config), runtime.name), + timeout=timeout, + ) + await asyncio.wait_for(client.list_tools_and_save(), timeout=timeout) + except asyncio.CancelledError: + await self._cleanup_mcp_client(client) + raise + except TimeoutError: + runtime.last_error = ( + f"Local MCP server '{runtime.name}' did not become ready within " + f"{timeout:g} seconds" + ) + runtime.errlogs = [runtime.last_error] + await self._cleanup_mcp_client(client) + except Exception as exc: + runtime.last_error = str(exc) + runtime.errlogs = [runtime.last_error] + await self._cleanup_mcp_client(client) + else: + runtime.client = client + runtime.running = True + runtime.tools = [ + str(tool.name) for tool in client.tools if getattr(tool, "name", None) + ] + runtime.tool_specs = self._build_local_mcp_tool_specs(runtime.name, client) + runtime.errlogs = list(client.server_errlogs) + if client.process_pid is not None: + runtime.lease_path = self._write_local_mcp_lease( + plugin_id=plugin_id, + server_name=runtime.name, + pid=client.process_pid, + ) + finally: + runtime.ready_event.set() + runtime.connect_task = None + + async def _initialize_local_mcp_servers(self, record: SdkPluginRecord) -> None: + tasks: list[asyncio.Task[None]] = [] + for runtime in record.local_mcp_servers.values(): + if not runtime.active: + runtime.ready_event.set() + continue + task = asyncio.create_task( + self._connect_local_mcp_server( + plugin_id=record.plugin_id, + runtime=runtime, + timeout=30.0, + ) + ) + runtime.connect_task = task + tasks.append(task) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def _shutdown_local_mcp_runtime( + self, + runtime: _LocalMCPServerRuntime, + ) -> None: + connect_task = runtime.connect_task + runtime.connect_task = None + if 
connect_task is not None and not connect_task.done(): + connect_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await connect_task + self._remove_local_mcp_lease(runtime) + await self._cleanup_mcp_client(runtime.client) + runtime.client = None + runtime.running = False + runtime.tools = [] + runtime.tool_specs = [] + runtime.ready_event.clear() + + async def _shutdown_local_mcp_servers(self, record: SdkPluginRecord) -> None: + for runtime in record.local_mcp_servers.values(): + await self._shutdown_local_mcp_runtime(runtime) + + async def enable_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: float = 30.0, + ) -> dict[str, Any]: + runtime = self._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if runtime.active and runtime.running and runtime.connect_task is None: + return self._serialize_local_mcp_server(runtime) + if runtime.connect_task is not None and not runtime.connect_task.done(): + runtime.active = True + await runtime.connect_task + return self._serialize_local_mcp_server(runtime) + runtime.active = True + task = asyncio.create_task( + self._connect_local_mcp_server( + plugin_id=plugin_id, + runtime=runtime, + timeout=timeout, + ) + ) + runtime.connect_task = task + await task + return self._serialize_local_mcp_server(runtime) + + async def disable_local_mcp_server( + self, + plugin_id: str, + name: str, + ) -> dict[str, Any]: + runtime = self._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + if not runtime.active and not runtime.running and runtime.connect_task is None: + return self._serialize_local_mcp_server(runtime) + runtime.active = False + await self._shutdown_local_mcp_runtime(runtime) + return self._serialize_local_mcp_server(runtime) + + async def wait_for_local_mcp_server( + self, + plugin_id: str, + name: str, + *, + timeout: 
float, + ) -> dict[str, Any]: + runtime = self._local_mcp_record(plugin_id, name) + if runtime is None: + raise AstrBotError.invalid_input(f"Unknown local MCP server: {name}") + await asyncio.wait_for(runtime.ready_event.wait(), timeout=timeout) + if not runtime.running: + raise TimeoutError( + f"Local MCP server '{name}' did not become ready in time" + ) + return self._serialize_local_mcp_server(runtime) + + async def open_temporary_mcp_session( + self, + plugin_id: str, + *, + name: str, + config: dict[str, Any], + timeout: float, + ) -> tuple[str, list[str]]: + client = MCPClient() + client.name = name + try: + await asyncio.wait_for( + client.connect_to_server(dict(config), name), + timeout=timeout, + ) + await asyncio.wait_for(client.list_tools_and_save(), timeout=timeout) + except Exception: + await self._cleanup_mcp_client(client) + raise + session_id = f"{plugin_id}:{uuid.uuid4().hex}" + tools = [str(tool.name) for tool in client.tools if getattr(tool, "name", None)] + self._temporary_mcp_sessions[session_id] = _TemporaryMCPSessionRuntime( + plugin_id=plugin_id, + name=name, + client=client, + tools=tools, + ) + return session_id, tools + + async def close_temporary_mcp_session( + self, + plugin_id: str, + session_id: str, + ) -> None: + runtime = self._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id != plugin_id: + return + self._temporary_mcp_sessions.pop(session_id, None) + await self._cleanup_mcp_client(runtime.client) + + async def _close_temporary_mcp_sessions(self, plugin_id: str) -> None: + session_ids = [ + session_id + for session_id, runtime in self._temporary_mcp_sessions.items() + if runtime.plugin_id == plugin_id + ] + for session_id in session_ids: + await self.close_temporary_mcp_session(plugin_id, session_id) + + def get_temporary_mcp_session_tools( + self, + plugin_id: str, + session_id: str, + ) -> list[str]: + runtime = self._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id 
!= plugin_id: + raise AstrBotError.invalid_input("Unknown MCP session") + return list(runtime.tools) + + async def call_temporary_mcp_tool( + self, + plugin_id: str, + *, + session_id: str, + tool_name: str, + arguments: dict[str, Any], + ) -> dict[str, Any]: + runtime = self._temporary_mcp_sessions.get(session_id) + if runtime is None or runtime.plugin_id != plugin_id: + raise AstrBotError.invalid_input("Unknown MCP session") + result = await runtime.client.call_tool_with_reconnect( + tool_name=tool_name, + arguments=arguments, + read_timeout_seconds=timedelta(seconds=60), + ) + text = self._mcp_call_result_to_text(result) + return {"content": text, "is_error": bool(getattr(result, "isError", False))} + + async def execute_local_mcp_tool( + self, + plugin_id: str, + *, + server_name: str, + tool_name: str, + tool_args: dict[str, Any], + timeout_seconds: int = 60, + ) -> dict[str, Any]: + runtime = self._local_mcp_record(plugin_id, server_name) + if ( + runtime is None + or not runtime.active + or not runtime.running + or runtime.client is None + ): + return { + "content": f"Local MCP server unavailable: {server_name}", + "success": False, + } + if tool_name not in runtime.tools: + return { + "content": f"Local MCP tool not found: {server_name}.{tool_name}", + "success": False, + } + try: + result = await runtime.client.call_tool_with_reconnect( + tool_name=tool_name, + arguments=tool_args, + read_timeout_seconds=timedelta(seconds=timeout_seconds), + ) + except Exception as exc: + return {"content": f"Tool execution failed: {exc}", "success": False} + text = self._mcp_call_result_to_text(result) + return { + "content": text, + "success": not bool(getattr(result, "isError", False)), + } + + @classmethod + def _descriptor_native_command_candidates( + cls, + descriptor: HandlerDescriptor, + *, + platform_name: str, + ) -> list[dict[str, Any]]: + trigger = descriptor.trigger + if not isinstance(trigger, CommandTrigger): + return [] + if not 
cls._descriptor_supports_platform(descriptor, platform_name): + return [] + + names = [trigger.command, *trigger.aliases] + route = descriptor.command_route + root_candidates: list[str] = [] + + if route is not None and route.group_path: + root_candidates.append(str(route.group_path[0]).strip()) + + for name in names: + normalized = str(name).strip() + if " " not in normalized: + continue + root_candidates.append(normalized.split()[0].strip()) + + if root_candidates: + description = ( + str(route.group_help).strip() + if route is not None and route.group_help + else str(trigger.description or "").strip() + ) + root_name = next((item for item in root_candidates if item), "") + if not description and root_name: + description = f"Command group: {root_name}" + unique_roots = [ + item + for item in dict.fromkeys(root_candidates) + if isinstance(item, str) and item.strip() + ] + return [ + { + "name": item.strip(), + "description": description, + "is_group": True, + } + for item in unique_roots + ] + + description = str(trigger.description or "").strip() + if not description and trigger.command.strip(): + description = f"Command: {trigger.command.strip()}" + unique_names = [ + item for item in dict.fromkeys(str(name).strip() for name in names) if item + ] + return [ + { + "name": item, + "description": description, + "is_group": False, + } + for item in unique_names + ] + + @classmethod + def _descriptor_supports_platform( + cls, + descriptor: HandlerDescriptor, + platform_name: str, + ) -> bool: + normalized_platform = cls._normalize_platform_name(platform_name) + if not normalized_platform: + return True + trigger_platforms = getattr(descriptor.trigger, "platforms", []) + if isinstance(trigger_platforms, list): + normalized = cls._normalized_platform_names(trigger_platforms) + if normalized and normalized_platform not in normalized: + return False + for filter_spec in descriptor.filters: + if not cls._filter_supports_platform(filter_spec, normalized_platform): + return 
False + return True + + @classmethod + def _filter_supports_platform(cls, filter_spec, platform_name: str) -> bool: + if isinstance(filter_spec, PlatformFilterSpec): + normalized = { + str(item).strip().lower() + for item in filter_spec.platforms + if str(item).strip() + } + return not normalized or platform_name in normalized + if isinstance(filter_spec, CompositeFilterSpec): + platform_children = [ + child + for child in filter_spec.children + if isinstance(child, PlatformFilterSpec | CompositeFilterSpec) + ] + if not platform_children: + return True + results = [ + cls._filter_supports_platform(child, platform_name) + for child in platform_children + ] + if filter_spec.kind == "and": + return all(results) + return any(results) + return True + + async def _load_or_reload_plugin( + self, + plugin: PluginSpec, + *, + load_order: int, + reset_restart_budget: bool, + ) -> None: + current = self._records.get(plugin.name) + if current is not None: + current.state = SDK_STATE_RELOADING + await self._cancel_plugin_requests(plugin.name) + await self._teardown_plugin(plugin.name) + + disabled = bool( + self._state_overrides.get(plugin.name, {}).get("disabled", False) + ) + config_schema = load_plugin_config_schema(plugin) + local_mcp_configs = self._load_local_mcp_configs(plugin) + record = SdkPluginRecord( + plugin=plugin, + load_order=load_order, + state=SDK_STATE_DISABLED if disabled else SDK_STATE_ENABLED, + unsupported_features=[], + config_schema=config_schema, + config=load_plugin_config(plugin, schema=config_schema), + handlers=[], + llm_tools={}, + active_llm_tools=set(), + agents={}, + restart_attempted=False + if reset_restart_budget + else (current.restart_attempted if current is not None else False), + issues=[dict(item) for item in self._discovery_issues.get(plugin.name, [])], + local_mcp_servers={ + name: _LocalMCPServerRuntime( + name=name, + config=dict(config), + active=bool(config.get("active", True)), + ) + for name, config in local_mcp_configs.items() 
+ }, + ) + self._records[plugin.name] = record + self._publish_plugin_skills(plugin.name) + if disabled: + self._persist_state_overrides() + return + + try: + + def _schedule_closed(plugin_id: str = plugin.name) -> None: + asyncio.create_task(self._handle_worker_closed(plugin_id)) + + session = WorkerSession( + plugin=plugin, + repo_root=Path(__file__).resolve().parents[3], + env_manager=self.env_manager, + capability_router=self.capability_bridge, + on_closed=_schedule_closed, + ) + await session.start() + session.start_close_watch() + record.session = session + remote_metadata = ( + dict(session.peer.remote_metadata) + if session.peer is not None + and isinstance(session.peer.remote_metadata, dict) + else {} + ) + record.acknowledge_global_mcp_risk = bool( + remote_metadata.get("acknowledge_global_mcp_risk", False) + ) + unsupported_features: set[str] = set() + for index, descriptor in enumerate(session.handlers): + if ( + isinstance(descriptor.trigger, EventTrigger) + and descriptor.trigger.event_type not in SUPPORTED_SYSTEM_EVENTS + ): + unsupported_features.add("event_trigger") + record.handlers.append( + SdkHandlerRef( + descriptor=descriptor, + declaration_order=index, + ) + ) + for item in session.llm_tools: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_id") or plugin.name) + if plugin_name != plugin.name: + continue + normalized = dict(item) + normalized.pop("plugin_id", None) + spec = LLMToolSpec.from_payload(normalized) + record.llm_tools[spec.name] = spec + if spec.active: + record.active_llm_tools.add(spec.name) + for item in session.agents: + if not isinstance(item, dict): + continue + plugin_name = str(item.get("plugin_id") or plugin.name) + if plugin_name != plugin.name: + continue + normalized = dict(item) + normalized.pop("plugin_id", None) + spec = AgentSpec.from_payload(normalized) + record.agents[spec.name] = spec + await self._register_schedule_handlers(record) + await 
self._initialize_local_mcp_servers(record) + record.issues.extend(issue.to_payload() for issue in session.issues) + record.unsupported_features = sorted(unsupported_features) + record.state = ( + SDK_STATE_UNSUPPORTED_PARTIAL + if record.unsupported_features + else SDK_STATE_ENABLED + ) + record.failure_reason = "" + except Exception as exc: + record.session = None + record.state = SDK_STATE_FAILED + record.failure_reason = str(exc) + record.issues.append( + PluginDiscoveryIssue( + severity="error", + phase="load", + plugin_id=plugin.name, + message="插件 worker 启动失败", + details=str(exc), + ).to_payload() + ) + logger.warning("Failed to start SDK plugin %s: %s", plugin.name, exc) + finally: + self._persist_state_overrides() + + async def _teardown_plugin(self, plugin_id: str) -> None: + record = self._records.get(plugin_id) + self._http_routes.pop(plugin_id, None) + self._session_waiters.pop(plugin_id, None) + await self._unregister_schedule_jobs(plugin_id) + await self._close_temporary_mcp_sessions(plugin_id) + await self._clear_plugin_skills( + plugin_id=plugin_id, + record=record, + reason="teardown", + ) + if record is None or record.session is None: + if record is not None: + await self._shutdown_local_mcp_servers(record) + return + try: + await self._shutdown_local_mcp_servers(record) + await record.session.stop() + finally: + record.session = None + + async def _register_schedule_handlers(self, record: SdkPluginRecord) -> None: + cron_manager = getattr(self.star_context, "cron_manager", None) + if cron_manager is None: + return + for handler in record.handlers: + trigger = handler.descriptor.trigger + if not isinstance(trigger, ScheduleTrigger): + continue + schedule_key = f"{record.plugin_id}:{handler.handler_id}" + job = await cron_manager.add_basic_job( + name=schedule_key, + cron_expression=trigger.cron, + interval_seconds=trigger.interval_seconds, + handler=self._build_schedule_runner( + plugin_id=record.plugin_id, + handler_id=handler.handler_id, + 
trigger=trigger, + ), + description=f"SDK schedule handler {handler.handler_id}", + enabled=True, + persistent=False, + ) + self._schedule_job_ids.setdefault(record.plugin_id, set()).add(job.job_id) + + async def _unregister_schedule_jobs(self, plugin_id: str) -> None: + cron_manager = getattr(self.star_context, "cron_manager", None) + if cron_manager is None: + return + for job_id in list(self._schedule_job_ids.pop(plugin_id, set())): + try: + await cron_manager.delete_job(job_id) + except Exception: + logger.debug("Failed to remove SDK schedule job {}", job_id) + + def _build_schedule_runner( + self, + *, + plugin_id: str, + handler_id: str, + trigger: ScheduleTrigger, + ): + async def _run(**_scheduler_payload: Any) -> None: + # CronJobManager stores scheduler metadata such as interval_seconds in the + # job payload and replays that payload into basic handlers. SDK schedule + # handlers do not consume those transport-level kwargs, so the bridge + # must swallow them here and only forward the synthesized schedule event. 
+ await self._invoke_schedule_handler( + plugin_id=plugin_id, + handler_id=handler_id, + trigger=trigger, + ) + + return _run + + def _set_discovery_issues(self, issues: list[PluginDiscoveryIssue]) -> None: + grouped: dict[str, list[dict[str, Any]]] = {} + for issue in issues: + grouped.setdefault(issue.plugin_id, []).append(issue.to_payload()) + self._discovery_issues = grouped + + async def _refresh_native_platform_commands( + self, platforms: set[str] | None = None + ) -> None: + platform_manager = getattr(self.star_context, "platform_manager", None) + if platform_manager is None: + return + refresh_commands = getattr(platform_manager, "refresh_native_commands", None) + if not callable(refresh_commands): + return + try: + await refresh_commands(platforms=platforms) + except Exception as exc: + logger.warning("Failed to refresh native platform commands: %s", exc) + + async def _invoke_schedule_handler( + self, + *, + plugin_id: str, + handler_id: str, + trigger: ScheduleTrigger, + ) -> None: + record = self._records.get(plugin_id) + if ( + record is None + or record.session is None + or record.state + in {SDK_STATE_DISABLED, SDK_STATE_FAILED, SDK_STATE_RELOADING} + ): + return + dispatch_token = uuid.uuid4().hex + request_id = f"sdk_schedule_{plugin_id}_{uuid.uuid4().hex}" + self._ensure_request_overlay(dispatch_token, should_call_llm=False) + self._request_contexts[dispatch_token] = _RequestContext( + plugin_id=plugin_id, + request_id=request_id, + dispatch_token=dispatch_token, + dispatch_state=None, + ) + self._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=plugin_id, + ) + payload = self._build_schedule_payload( + plugin_id=plugin_id, + handler_id=handler_id, + trigger=trigger, + ) + try: + await record.session.invoke_handler( + handler_id, + payload, + request_id=request_id, + args={}, + ) + except Exception as exc: + logger.warning( + "SDK schedule handler failed: plugin=%s handler=%s error=%s", + plugin_id, + 
                handler_id,
                exc,
            )

    @staticmethod
    def _build_schedule_payload(
        *,
        plugin_id: str,
        handler_id: str,
        trigger: ScheduleTrigger,
    ) -> dict[str, Any]:
        """Synthesize the event payload delivered to an SDK schedule handler.

        The payload mimics a platform event envelope (``type``, ``text``,
        ``session_id``, ...) with empty placeholder values, plus a
        ``schedule`` section describing the trigger. ``scheduled_at`` is a
        timezone-aware UTC ISO-8601 timestamp taken at build time.
        """
        scheduled_at = datetime.now(timezone.utc).isoformat()
        return {
            "type": "schedule",
            "event_type": "schedule",
            "text": "",
            "session_id": "",
            "platform": "",
            "platform_id": "",
            "message_type": "other",
            "sender_name": "",
            "self_id": "",
            "raw": {"event_type": "schedule"},
            "schedule": {
                "schedule_id": f"{plugin_id}:{handler_id}",
                "plugin_id": plugin_id,
                "handler_id": handler_id,
                # A cron expression takes precedence; otherwise the trigger is
                # interval-based.
                "trigger_kind": "cron" if trigger.cron is not None else "interval",
                "cron": trigger.cron,
                "interval_seconds": trigger.interval_seconds,
                "scheduled_at": scheduled_at,
            },
        }

    async def _cancel_plugin_requests(self, plugin_id: str) -> None:
        """Cancel every in-flight request owned by *plugin_id*.

        For each in-flight request: mark its request context cancelled, close
        its overlay, and — when the worker session/peer is still alive and the
        task has not finished — forward a cancel message to the worker before
        cancelling the local task. Without a live session the request is only
        flagged ``logical_cancelled``. Finally the per-plugin request map
        entry is dropped.
        """
        requests = list(self._plugin_requests.get(plugin_id, {}).values())
        for inflight in requests:
            request_context = self._request_contexts.get(inflight.dispatch_token)
            if request_context is not None:
                request_context.cancelled = True
            self._close_request_overlay(inflight.dispatch_token)
            record = self._records.get(plugin_id)
            if (
                record is not None
                and record.session is not None
                and record.session.peer is not None
                and not inflight.task.done()
            ):
                try:
                    await record.session.cancel(inflight.request_id)
                except Exception:
                    # Best-effort: the worker may already be gone.
                    logger.debug(
                        "Failed to forward SDK cancel for %s", inflight.request_id
                    )
                inflight.task.cancel()
            else:
                inflight.logical_cancelled = True
        self._plugin_requests.pop(plugin_id, None)

    async def _handle_worker_closed(self, plugin_id: str) -> None:
        """React to a worker transport closing outside an orderly shutdown.

        Cancels outstanding requests, tears down temporary MCP sessions and
        local MCP servers, then — unless the plugin is disabled or already
        reloading — retries the worker once before marking the plugin failed.
        No-op while the bridge itself is stopping.
        """
        if self._stopping:
            return
        await self._cancel_plugin_requests(plugin_id)
        await self._close_temporary_mcp_sessions(plugin_id)
        record = self._records.get(plugin_id)
        if record is None:
            return
        await self._shutdown_local_mcp_servers(record)
        record.session = None
        if record.state in {SDK_STATE_RELOADING, SDK_STATE_DISABLED}:
            return
        if not
record.restart_attempted: + record.restart_attempted = True + logger.warning( + "SDK plugin worker closed unexpectedly, retrying once: %s", + plugin_id, + ) + await self._load_or_reload_plugin( + record.plugin, + load_order=record.load_order, + reset_restart_budget=False, + ) + return + record.state = SDK_STATE_FAILED + self._http_routes.pop(plugin_id, None) + self._session_waiters.pop(plugin_id, None) + await self._unregister_schedule_jobs(plugin_id) + await self._clear_plugin_skills( + plugin_id=plugin_id, + record=record, + reason="worker failure cleanup", + ) + + def _record_to_dashboard_item(self, record: SdkPluginRecord) -> dict[str, Any]: + manifest = record.plugin.manifest_data + support_platforms = manifest.get("support_platforms") + installed_at = None + try: + installed_at = datetime.fromtimestamp( + record.plugin.plugin_dir.stat().st_mtime, + timezone.utc, + ).isoformat() + except OSError: + installed_at = None + handlers = [ + self._handler_to_dashboard_item(handler) for handler in record.handlers + ] + return { + "name": record.plugin_id, + "repo": "", + "author": str(manifest.get("author") or ""), + "desc": str(manifest.get("desc") or manifest.get("description") or ""), + "version": str(manifest.get("version") or "0.0.0"), + "reserved": False, + "activated": record.state not in {SDK_STATE_DISABLED, SDK_STATE_FAILED}, + "online_vesion": "", + "handlers": handlers, + "display_name": str(manifest.get("display_name") or record.plugin_id), + "logo": None, + "support_platforms": [ + str(item) for item in support_platforms if isinstance(item, str) + ] + if isinstance(support_platforms, list) + else [], + "astrbot_version": ( + str(manifest.get("astrbot_version")) + if manifest.get("astrbot_version") is not None + else "" + ), + "installed_at": installed_at, + "runtime_kind": "sdk", + "source_kind": "local_dir", + "managed_by": "sdk_bridge", + "state": record.state, + "trigger_summary": [item["cmd"] for item in handlers], + "unsupported_features": 
list(record.unsupported_features), + "failure_reason": record.failure_reason, + "issues": [dict(item) for item in record.issues], + } + + def _failed_issue_to_dashboard_item( + self, + plugin_id: str, + issues: list[dict[str, Any]], + ) -> dict[str, Any]: + issue = issues[0] if issues else {} + failure_reason = str(issue.get("details") or issue.get("message") or "") + return { + "name": plugin_id, + "repo": "", + "author": "", + "desc": str(issue.get("message", "")), + "version": "0.0.0", + "reserved": False, + "activated": False, + "online_vesion": "", + "handlers": [], + "display_name": plugin_id, + "logo": None, + "support_platforms": [], + "astrbot_version": "", + "installed_at": None, + "runtime_kind": "sdk", + "source_kind": "local_dir", + "managed_by": "sdk_bridge", + "state": SDK_STATE_FAILED, + "trigger_summary": [], + "unsupported_features": [], + "failure_reason": failure_reason, + "issues": [dict(item) for item in issues], + } + + def _handler_to_dashboard_item(self, handler: SdkHandlerRef) -> dict[str, Any]: + trigger = handler.descriptor.trigger + description = self._descriptor_description(handler.descriptor) + if not description and isinstance(trigger, CommandTrigger): + description = f"Command: {trigger.command}" + if not description: + description = "无描述" + if isinstance(trigger, CommandTrigger): + event_type = "SDKCommandEvent" + event_type_h = "SDK 指令触发" + elif isinstance(trigger, MessageTrigger): + event_type = "SDKMessageEvent" + event_type_h = "SDK 消息触发" + elif isinstance(trigger, EventTrigger): + event_type = "SDKEventTrigger" + event_type_h = "SDK 事件触发" + elif isinstance(trigger, ScheduleTrigger): + event_type = "SDKScheduleEvent" + event_type_h = "SDK 定时触发" + else: + event_type = "SDKHandler" + event_type_h = "SDK 行为触发" + + base = { + "event_type": event_type, + "event_type_h": event_type_h, + "handler_full_name": handler.handler_id, + "desc": description, + "handler_name": handler.handler_name, + "has_admin": 
handler.descriptor.permissions.require_admin, + } + if isinstance(trigger, CommandTrigger): + return {**base, "type": "指令", "cmd": trigger.command} + if isinstance(trigger, MessageTrigger): + if trigger.regex: + return {**base, "type": "正则匹配", "cmd": trigger.regex} + if trigger.keywords: + return {**base, "type": "关键词", "cmd": ", ".join(trigger.keywords)} + return {**base, "type": "消息", "cmd": "任意消息"} + if isinstance(trigger, EventTrigger): + return {**base, "type": "事件", "cmd": trigger.event_type} + if isinstance(trigger, ScheduleTrigger): + return { + **base, + "type": "定时", + "cmd": trigger.cron or str(trigger.interval_seconds), + } + return {**base, "type": "未知", "cmd": "未知"} + + def _load_state_overrides(self) -> dict[str, dict[str, Any]]: + if not self.state_path.exists(): + return {} + try: + data = json.loads(self.state_path.read_text(encoding="utf-8")) + except Exception: + return {} + plugins = data.get("plugins") + return dict(plugins) if isinstance(plugins, dict) else {} + + def _persist_state_overrides(self) -> None: + self.state_path.write_text( + json.dumps( + {"plugins": self._state_overrides}, ensure_ascii=False, indent=2 + ), + encoding="utf-8", + ) + + def _set_disabled_override(self, plugin_id: str, *, disabled: bool) -> None: + plugin_state = dict(self._state_overrides.get(plugin_id, {})) + if disabled: + plugin_state["disabled"] = True + self._state_overrides[plugin_id] = plugin_state + else: + plugin_state.pop("disabled", None) + if plugin_state: + self._state_overrides[plugin_id] = plugin_state + else: + self._state_overrides.pop(plugin_id, None) + self._persist_state_overrides() + + @staticmethod + def _normalize_http_route(route: str) -> str: + route_text = str(route).strip() + if not route_text: + raise AstrBotError.invalid_input("http route must not be empty") + if not route_text.startswith("/"): + route_text = f"/{route_text}" + return route_text + + @staticmethod + def _normalize_http_methods(methods: list[str]) -> tuple[str, ...]: + 
normalized = tuple( + sorted({str(method).upper() for method in methods if method}) + ) + if not normalized: + raise AstrBotError.invalid_input("http methods must not be empty") + return normalized + + def _ensure_http_route_available( + self, + *, + plugin_id: str, + route: str, + methods: tuple[str, ...], + ) -> None: + for legacy_route, _view_handler, legacy_methods, _desc in getattr( + self.star_context, "registered_web_apis", [] + ): + if route != legacy_route: + continue + if set(methods) & {str(method).upper() for method in legacy_methods}: + raise AstrBotError.invalid_input( + f"HTTP route conflict with legacy plugin route: {route}" + ) + for owner, entries in self._http_routes.items(): + for entry in entries: + if ( + owner == plugin_id + and entry.route == route + and entry.methods == methods + ): + continue + if entry.route != route: + continue + if set(entry.methods) & set(methods): + raise AstrBotError.invalid_input( + f"HTTP route conflict with SDK plugin route: {route}" + ) + + def _resolve_http_route( + self, + route: str, + method: str, + ) -> tuple[SdkPluginRecord, SdkHttpRoute] | None: + normalized_route = self._normalize_http_route(route) + normalized_method = str(method).upper() + for record in sorted(self._records.values(), key=lambda item: item.load_order): + for entry in self._http_routes.get(record.plugin_id, []): + if ( + entry.route == normalized_route + and normalized_method in entry.methods + ): + return record, entry + return None + + def _match_waiter_plugins(self, session_key: str) -> list[SdkPluginRecord]: + matches: list[SdkPluginRecord] = [] + for record in sorted(self._records.values(), key=lambda item: item.load_order): + if session_key in self._session_waiters.get(record.plugin_id, set()): + matches.append(record) + return matches + + async def _dispatch_waiter_event( + self, + event: AstrMessageEvent, + records: list[SdkPluginRecord], + ) -> SdkDispatchResult: + result = SdkDispatchResult() + dispatch_state = 
_DispatchState(event=event) + dispatch_token = self._get_dispatch_token(event) or uuid.uuid4().hex + self._bind_dispatch_token(event, dispatch_token) + overlay = self._ensure_request_overlay( + dispatch_token, + should_call_llm=not bool(getattr(event, "call_llm", False)), + ) + request_context = _RequestContext( + plugin_id="", + request_id="", + dispatch_token=dispatch_token, + dispatch_state=dispatch_state, + ) + self._request_contexts[dispatch_token] = request_context + for record in records: + if record.state in { + SDK_STATE_DISABLED, + SDK_STATE_FAILED, + SDK_STATE_RELOADING, + }: + continue + if record.session is None: + continue + whitelist = ( + None + if overlay.handler_whitelist is None + else set(overlay.handler_whitelist) + ) + if whitelist is not None and record.plugin_id not in whitelist: + continue + request_id = f"sdk_waiter_{record.plugin_id}_{uuid.uuid4().hex}" + request_context.plugin_id = record.plugin_id + request_context.request_id = request_id + request_context.cancelled = False + setattr(event, "_sdk_last_request_id", request_id) + payload = EventConverter.core_to_sdk( + event, + dispatch_token=dispatch_token, + plugin_id=record.plugin_id, + request_id=request_id, + ) + self._track_request_scope( + dispatch_token=dispatch_token, + request_id=request_id, + plugin_id=record.plugin_id, + ) + try: + output = await record.session.invoke_handler( + "__sdk_session_waiter__", + payload, + request_id=request_id, + args={}, + ) + except Exception as exc: + logger.warning( + "SDK waiter dispatch failed: plugin=%s error=%s", + record.plugin_id, + exc, + ) + output = {} + handler_result = EventConverter.extract_handler_result( + output if isinstance(output, dict) else {} + ) + result.executed_handlers.append( + {"plugin_id": record.plugin_id, "handler_id": "__sdk_session_waiter__"} + ) + dispatch_state.sent_message = ( + dispatch_state.sent_message or handler_result["sent_message"] + ) + dispatch_state.stopped = dispatch_state.stopped or 
handler_result["stop"] + if handler_result["call_llm"]: + overlay.requested_llm = True + overlay.should_call_llm = True + if handler_result["sent_message"] or handler_result["stop"]: + overlay.should_call_llm = False + if handler_result["stop"]: + break + result.sent_message = dispatch_state.sent_message + result.stopped = dispatch_state.stopped + if not result.executed_handlers: + result.skipped_reason = SKIP_NO_MATCH + if result.sent_message: + event._has_send_oper = True + overlay.should_call_llm = False + event.should_call_llm(True) + if result.stopped: + event.stop_event() + overlay.should_call_llm = False + event.should_call_llm(True) + return result diff --git a/astrbot/core/sdk_bridge/trigger_converter.py b/astrbot/core/sdk_bridge/trigger_converter.py new file mode 100644 index 0000000000..b4c03f86a1 --- /dev/null +++ b/astrbot/core/sdk_bridge/trigger_converter.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import inspect +import re +import shlex +import typing +from dataclasses import dataclass +from typing import Any, get_type_hints + +from astrbot_sdk._message_types import normalize_message_type +from astrbot_sdk.events import MessageEvent as SdkMessageEvent +from astrbot_sdk.protocol.descriptors import ( + CommandTrigger, + CompositeFilterSpec, + HandlerDescriptor, + LocalFilterRefSpec, + MessageTrigger, + MessageTypeFilterSpec, + ParamSpec, + PlatformFilterSpec, +) + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +@dataclass(slots=True) +class TriggerMatch: + plugin_id: str + handler_id: str + args: dict[str, Any] + priority: int + load_order: int + declaration_order: int + + +class TriggerConverter: + @staticmethod + def _message_type_name(event: AstrMessageEvent) -> str: + return normalize_message_type( + event.get_message_type(), + group_id=event.get_group_id() or None, + user_id=event.get_sender_id() or None, + empty_default="other", + ) + + @staticmethod + def _match_command_name(text: str, command_name: 
str) -> str | None: + normalized = text.strip() + if normalized == command_name: + return "" + if normalized.startswith(f"{command_name} "): + return normalized[len(command_name) :].strip() + return None + + @staticmethod + def _split_command_remainder(remainder: str) -> list[str]: + try: + return shlex.split(remainder) + except ValueError: + return remainder.split() + + @classmethod + def _build_command_args(cls, handler, remainder: str) -> dict[str, Any]: + param_specs = getattr(handler, "param_specs", None) + if not isinstance(param_specs, list): + names = cls._legacy_arg_parameter_names(handler) + if not names or not remainder: + return {} + if len(names) == 1: + return {names[0]: remainder} + parts = cls._split_command_remainder(remainder) + return { + name: parts[index] + for index, name in enumerate(names) + if index < len(parts) + } + if not param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = cls._split_command_remainder(remainder) + args: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + args[spec.name] = " ".join(parts[index:]) + break + args[spec.name] = parts[index] + return args + + @classmethod + def _build_regex_args(cls, handler, match: re.Match[str]) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + param_specs = getattr(handler, "param_specs", None) + if isinstance(param_specs, list): + names = [spec.name for spec in param_specs if spec.name not in named] + else: + names = [ + name + for name in cls._legacy_arg_parameter_names(handler) + if name not in named + ] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + @classmethod + def _build_descriptor_command_args( + cls, + param_specs: 
list[ParamSpec], + remainder: str, + ) -> dict[str, Any]: + if not param_specs or not remainder: + return {} + if len(param_specs) == 1: + return {param_specs[0].name: remainder} + parts = cls._split_command_remainder(remainder) + args: dict[str, Any] = {} + for index, spec in enumerate(param_specs): + if index >= len(parts): + break + if spec.type == "greedy_str": + args[spec.name] = " ".join(parts[index:]) + break + args[spec.name] = parts[index] + return args + + @classmethod + def _build_descriptor_regex_args( + cls, + param_specs: list[ParamSpec], + match: re.Match[str], + ) -> dict[str, Any]: + named = { + key: value for key, value in match.groupdict().items() if value is not None + } + names = [spec.name for spec in param_specs if spec.name not in named] + positional = [value for value in match.groups() if value is not None] + for index, value in enumerate(positional): + if index >= len(names): + break + named[names[index]] = value + return named + + @classmethod + def _match_filters( + cls, + descriptor: HandlerDescriptor, + event: AstrMessageEvent, + ) -> bool: + for filter_spec in descriptor.filters: + if not cls._match_filter_spec(filter_spec, event): + return False + return True + + @classmethod + def _match_filter_spec(cls, filter_spec, event: AstrMessageEvent) -> bool: + if isinstance(filter_spec, PlatformFilterSpec): + return event.get_platform_name() in filter_spec.platforms + if isinstance(filter_spec, MessageTypeFilterSpec): + return cls._message_type_name(event) in filter_spec.message_types + if isinstance(filter_spec, LocalFilterRefSpec): + # Local filter refs point at plugin-process callables. The host bridge + # cannot execute them, so trigger matching must stay fail-open here. 
+ return True + if isinstance(filter_spec, CompositeFilterSpec): + results = [ + cls._match_filter_spec(child, event) for child in filter_spec.children + ] + if filter_spec.kind == "and": + return all(results) + return any(results) + return True + + @classmethod + def _legacy_arg_parameter_names(cls, handler) -> list[str]: + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return [] + try: + type_hints = get_type_hints(handler) + except Exception: + type_hints = {} + names: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + continue + if cls._is_injected_parameter( + parameter.name, type_hints.get(parameter.name) + ): + continue + names.append(parameter.name) + return names + + @classmethod + def _is_injected_parameter(cls, name: str, annotation: Any) -> bool: + if name in {"event", "ctx", "context"}: + return True + normalized = cls._unwrap_optional(annotation) + if normalized is None: + return False + if normalized in {AstrMessageEvent, SdkMessageEvent}: + return True + if isinstance(normalized, type) and issubclass( + normalized, + (AstrMessageEvent, SdkMessageEvent), + ): + return True + return False + + @staticmethod + def _unwrap_optional(annotation: Any) -> Any: + if annotation is None: + return None + origin = typing.get_origin(annotation) + if origin is typing.Union: + options = [ + item for item in typing.get_args(annotation) if item is not type(None) + ] + if len(options) == 1: + return options[0] + return annotation + + @classmethod + def match_handler( + cls, + *, + plugin_id: str, + handler=None, + descriptor: HandlerDescriptor, + event: AstrMessageEvent, + load_order: int, + declaration_order: int, + ) -> TriggerMatch | None: + trigger = descriptor.trigger + + required_role = descriptor.permissions.required_role + if required_role is None and descriptor.permissions.require_admin: + 
required_role = "admin" + if required_role == "admin" and not event.is_admin(): + return None + if not cls._match_filters(descriptor, event): + return None + + if isinstance(trigger, CommandTrigger): + text = event.get_message_str().strip() + for command_name in [trigger.command, *trigger.aliases]: + if not command_name: + continue + remainder = cls._match_command_name(text, command_name) + if remainder is None: + continue + return TriggerMatch( + plugin_id=plugin_id, + handler_id=descriptor.id, + args=( + cls._build_command_args(handler, remainder) + if handler is not None + else cls._build_descriptor_command_args( + descriptor.param_specs, + remainder, + ) + ), + priority=descriptor.priority, + load_order=load_order, + declaration_order=declaration_order, + ) + return None + + if isinstance(trigger, MessageTrigger): + text = event.get_message_str() + if trigger.regex: + match = re.search(trigger.regex, text) + if match is None: + return None + args = ( + cls._build_regex_args(handler, match) if handler is not None else {} + ) + if handler is None: + args = cls._build_descriptor_regex_args( + descriptor.param_specs, match + ) + else: + if trigger.keywords and not any( + keyword in text for keyword in trigger.keywords + ): + return None + args = {} + return TriggerMatch( + plugin_id=plugin_id, + handler_id=descriptor.id, + args=args, + priority=descriptor.priority, + load_order=load_order, + declaration_order=declaration_order, + ) + + return None + + @staticmethod + def sort_key(match: TriggerMatch) -> tuple[int, int, int]: + return (-match.priority, match.load_order, match.declaration_order) diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index ec3ba8f034..2ee59bde44 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -22,10 +22,12 @@ SKILLS_CONFIG_FILENAME = "skills.json" SANDBOX_SKILLS_CACHE_FILENAME = "sandbox_skills_cache.json" +SDK_PLUGIN_SKILLS_FILENAME = 
"sdk_plugin_skills.json" DEFAULT_SKILLS_CONFIG: dict[str, dict] = {"skills": {}} SANDBOX_SKILLS_ROOT = "skills" SANDBOX_WORKSPACE_ROOT = "/workspace" _SANDBOX_SKILLS_CACHE_VERSION = 1 +_SDK_PLUGIN_SKILLS_VERSION = 1 _SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") @@ -94,6 +96,16 @@ class SkillInfo: sandbox_exists: bool = False +@dataclass(frozen=True, slots=True) +class LocalSkillSource: + name: str + skill_dir: Path + skill_md_path: Path + owner_type: str = "standalone" + description_override: str = "" + plugin_id: str | None = None + + def _parse_frontmatter_description(text: str) -> str: """Extract the ``description`` value from YAML frontmatter. @@ -274,8 +286,221 @@ def __init__(self, skills_root: str | None = None) -> None: data_path = Path(get_astrbot_data_path()) self.config_path = str(data_path / SKILLS_CONFIG_FILENAME) self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME) + self.sdk_plugin_skills_path = str(data_path / SDK_PLUGIN_SKILLS_FILENAME) os.makedirs(self.skills_root, exist_ok=True) + def _read_skill_description(self, skill_md_path: Path) -> str: + try: + content = skill_md_path.read_text(encoding="utf-8") + except Exception: + return "" + return _parse_frontmatter_description(content) + + def _discover_standalone_skill_sources(self) -> dict[str, LocalSkillSource]: + sources: dict[str, LocalSkillSource] = {} + skills_root = Path(self.skills_root) + if not skills_root.exists(): + return sources + + for entry in sorted(skills_root.iterdir()): + if not entry.is_dir(): + continue + skill_md_path = _normalize_skill_markdown_path(entry) + if skill_md_path is None: + continue + sources[entry.name] = LocalSkillSource( + name=entry.name, + skill_dir=entry, + skill_md_path=skill_md_path, + owner_type="standalone", + ) + return sources + + def _load_sdk_plugin_skills_registry(self) -> dict[str, object]: + if not os.path.exists(self.sdk_plugin_skills_path): + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + try: + 
with open(self.sdk_plugin_skills_path, encoding="utf-8") as f: + data = json.load(f) + except Exception: + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + if not isinstance(data, dict): + return {"version": _SDK_PLUGIN_SKILLS_VERSION, "plugins": {}} + plugins = data.get("plugins", {}) + if not isinstance(plugins, dict): + plugins = {} + return { + "version": int(data.get("version", _SDK_PLUGIN_SKILLS_VERSION)), + "plugins": plugins, + } + + def _save_sdk_plugin_skills_registry(self, registry: dict[str, object]) -> None: + registry["version"] = _SDK_PLUGIN_SKILLS_VERSION + with open(self.sdk_plugin_skills_path, "w", encoding="utf-8") as f: + json.dump(registry, f, ensure_ascii=False, indent=2) + + def replace_sdk_plugin_skills( + self, + plugin_id: str, + skills: list[dict[str, str]], + ) -> None: + plugin_name = str(plugin_id).strip() + if not plugin_name: + raise ValueError("plugin_id must not be empty") + + normalized_skills: list[dict[str, str]] = [] + for item in skills: + if not isinstance(item, dict): + continue + skill_name = str(item.get("name", "")).strip() + skill_dir_text = str(item.get("skill_dir", "")).strip() + if not skill_name or not _SKILL_NAME_RE.fullmatch(skill_name): + continue + if not skill_dir_text: + continue + skill_dir = Path(skill_dir_text).resolve() + skill_md_path = Path( + str(item.get("path", "")).strip() or str(skill_dir / "SKILL.md") + ).resolve() + normalized_skills.append( + { + "name": skill_name, + "description": str(item.get("description", "") or ""), + "path": str(skill_md_path), + "skill_dir": str(skill_dir), + } + ) + + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + plugins = {} + previous_items = plugins.get(plugin_name, []) + previous_names = { + str(item.get("name", "")).strip() + for item in previous_items + if isinstance(item, dict) + } + if normalized_skills: + plugins[plugin_name] = sorted( + normalized_skills, + 
key=lambda item: str(item.get("name", "")), + ) + else: + plugins.pop(plugin_name, None) + registry["plugins"] = plugins + self._save_sdk_plugin_skills_registry(registry) + + current_names = {item["name"] for item in normalized_skills} + for removed_name in sorted(previous_names - current_names): + self._remove_skill_from_sandbox_cache(removed_name) + + def remove_sdk_plugin_skills(self, plugin_id: str) -> None: + self.replace_sdk_plugin_skills(plugin_id, []) + + def prune_sdk_plugin_skills(self, active_plugin_ids: set[str]) -> None: + normalized_ids = { + str(item).strip() for item in active_plugin_ids if str(item).strip() + } + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + return + + removed_skill_names: set[str] = set() + updated_plugins: dict[str, object] = {} + for plugin_id, items in plugins.items(): + plugin_name = str(plugin_id).strip() + if not plugin_name: + continue + if plugin_name in normalized_ids: + updated_plugins[plugin_name] = items + continue + if isinstance(items, list): + removed_skill_names.update( + str(item.get("name", "")).strip() + for item in items + if isinstance(item, dict) + ) + + registry["plugins"] = updated_plugins + self._save_sdk_plugin_skills_registry(registry) + for removed_name in sorted(name for name in removed_skill_names if name): + self._remove_skill_from_sandbox_cache(removed_name) + + def _discover_sdk_plugin_skill_sources(self) -> dict[str, LocalSkillSource]: + sources: dict[str, LocalSkillSource] = {} + registry = self._load_sdk_plugin_skills_registry() + plugins = registry.get("plugins", {}) + if not isinstance(plugins, dict): + return sources + for plugin_id, items in plugins.items(): + if not isinstance(items, list): + continue + for item in items: + if not isinstance(item, dict): + continue + skill_name = str(item.get("name", "")).strip() + skill_dir_text = str(item.get("skill_dir", "")).strip() + path_text = str(item.get("path", 
"")).strip() + if not skill_name or not _SKILL_NAME_RE.fullmatch(skill_name): + continue + if not skill_dir_text: + continue + skill_dir = Path(skill_dir_text) + skill_md_path = Path(path_text or str(skill_dir / "SKILL.md")) + if not skill_dir.is_dir() or not skill_md_path.is_file(): + continue + sources.setdefault( + skill_name, + LocalSkillSource( + name=skill_name, + skill_dir=skill_dir, + skill_md_path=skill_md_path, + owner_type="sdk_registered", + description_override=str(item.get("description", "") or ""), + plugin_id=str(plugin_id), + ), + ) + return sources + + def list_local_skill_sources(self) -> list[LocalSkillSource]: + sources = self._discover_standalone_skill_sources() + for name, source in self._discover_sdk_plugin_skill_sources().items(): + sources.setdefault(name, source) + return [sources[name] for name in sorted(sources)] + + def get_local_skill_source(self, name: str) -> LocalSkillSource | None: + for source in self.list_local_skill_sources(): + if source.name == name: + return source + return None + + def materialize_local_skill_bundle( + self, + bundle_root: Path, + *, + skill_names: list[str] | None = None, + ) -> list[LocalSkillSource]: + selected_names = ( + {name for name in skill_names if name} if skill_names is not None else None + ) + bundle_root.mkdir(parents=True, exist_ok=True) + + copied_sources: list[LocalSkillSource] = [] + for source in self.list_local_skill_sources(): + if selected_names is not None and source.name not in selected_names: + continue + target_dir = bundle_root / source.name + if target_dir.exists(): + shutil.rmtree(target_dir) + # SDK-registered skills may live inside plugin packages, so bundle + # them under the public skill id to give sandbox/runtime a stable + # path that is independent from the plugin's internal layout. 
+ shutil.copytree(source.skill_dir, target_dir) + copied_sources.append(source) + return copied_sources + def _load_config(self) -> dict: if not os.path.exists(self.config_path): self._save_config(DEFAULT_SKILLS_CONFIG.copy()) @@ -383,25 +608,17 @@ def list_skills( sandbox_cached_descriptions[name] = str(item.get("description", "") or "") sandbox_cached_paths[name] = path - for entry in sorted(Path(self.skills_root).iterdir()): - if not entry.is_dir(): - continue - skill_name = entry.name - skill_md = _normalize_skill_markdown_path(entry) - if skill_md is None: - continue + for source in self.list_local_skill_sources(): + skill_name = source.name active = skill_configs.get(skill_name, {}).get("active", True) if skill_name not in skill_configs: skill_configs[skill_name] = {"active": active} modified = True if active_only and not active: continue - description = "" - try: - content = skill_md.read_text(encoding="utf-8") - description = _parse_frontmatter_description(content) - except Exception: - description = "" + description = source.description_override or self._read_skill_description( + source.skill_md_path + ) sandbox_exists = ( runtime == "sandbox" and skill_name in sandbox_cached_descriptions ) @@ -412,7 +629,7 @@ def list_skills( skill_name ) or _default_sandbox_skill_path(skill_name) else: - path_str = str(skill_md) + path_str = str(source.skill_md_path) path_str = path_str.replace("\\", "/") skills_by_name[skill_name] = SkillInfo( name=skill_name, @@ -468,9 +685,7 @@ def list_skills( return [skills_by_name[name] for name in sorted(skills_by_name)] def is_sandbox_only_skill(self, name: str) -> bool: - skill_dir = Path(self.skills_root) / name - skill_md_exists = _normalize_skill_markdown_path(skill_dir) is not None - if skill_md_exists: + if self.get_local_skill_source(name) is not None: return False cache = self._load_sandbox_skills_cache() skills = cache.get("skills", []) @@ -517,9 +732,14 @@ def delete_skill(self, name: str) -> None: "Sandbox preset skill 
cannot be deleted from local skill management." ) - skill_dir = Path(self.skills_root) / name - if skill_dir.exists(): - shutil.rmtree(skill_dir) + source = self.get_local_skill_source(name) + if source is not None and source.owner_type != "standalone": + raise PermissionError( + "SDK-registered skill cannot be deleted here. Disable or update the owning plugin instead." + ) + + if source is not None and source.skill_dir.exists(): + shutil.rmtree(source.skill_dir) # Ensure UI consistency even when there is no active sandbox session # to refresh cache from runtime side. diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 796e0bd683..f9a7417c21 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -1,11 +1,23 @@ -# 兼容导出: Provider 从 provider 模块重新导出 -from astrbot.core.provider import Provider +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any -from .base import Star -from .context import Context from .star import StarMetadata, star_map, star_registry -from .star_manager import PluginManager -from .star_tools import StarTools + +if TYPE_CHECKING: + from astrbot.core.provider import Provider + + from .base import Star + from .context import Context + from .star_manager import PluginManager + from .star_tools import StarTools +else: + Provider: Any + Star: Any + Context: Any + PluginManager: Any + StarTools: Any __all__ = [ "Context", @@ -17,3 +29,17 @@ "star_map", "star_registry", ] + + +def __getattr__(name: str) -> Any: + if name == "Provider": + return import_module("astrbot.core.provider").Provider + if name == "Star": + return import_module(".base", __name__).Star + if name == "Context": + return import_module(".context", __name__).Context + if name == "PluginManager": + return import_module(".star_manager", __name__).PluginManager + if name == "StarTools": + return import_module(".star_tools", __name__).StarTools + raise AttributeError(name) diff 
--git a/astrbot/core/star/command_management.py b/astrbot/core/star/command_management.py index c60af9ea26..f73ed65600 100644 --- a/astrbot/core/star/command_management.py +++ b/astrbot/core/star/command_management.py @@ -4,8 +4,7 @@ from dataclasses import dataclass, field from typing import Any -from astrbot.api import sp -from astrbot.core import db_helper, logger +from astrbot.core import db_helper, logger, sp from astrbot.core.db.po import CommandConfig from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 606f46dd73..ceb9e16694 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -5,25 +5,17 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, Protocol +from astrbot_sdk.message.components import component_to_payload_sync from deprecated import deprecated from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.db import BaseDatabase -from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.persona_mgr import PersonaManager -from astrbot.core.platform import Platform -from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion -from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.provider.entities import LLMResponse, ProviderRequest, 
ProviderType from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager -from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( EmbeddingProvider, Provider, @@ -31,11 +23,11 @@ STTProvider, TTSProvider, ) +from astrbot.core.sdk_bridge.event_converter import EventConverter from astrbot.core.star.filter.platform_adapter_type import ( ADAPTER_NAME_2_TYPE, PlatformAdapterType, ) -from astrbot.core.subagent_orchestrator import SubAgentOrchestrator from ..exceptions import ProviderNotFoundError from .filter.command import CommandFilter @@ -46,7 +38,19 @@ logger = logging.getLogger("astrbot") if TYPE_CHECKING: + from astrbot.core.astrbot_config_mgr import AstrBotConfigManager + from astrbot.core.config.astrbot_config import AstrBotConfig + from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron.manager import CronJobManager + from astrbot.core.db import BaseDatabase + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + from astrbot.core.persona_mgr import PersonaManager + from astrbot.core.platform import Platform + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager + from astrbot.core.provider.manager import ProviderManager + from astrbot.core.sdk_bridge.plugin_bridge import SdkPluginBridge + from astrbot.core.subagent_orchestrator import SubAgentOrchestrator class PlatformManagerProtocol(Protocol): @@ -100,6 +104,8 @@ def __init__( self.cron_manager = cron_manager """Cron job manager, initialized by core lifecycle.""" self.subagent_orchestrator = subagent_orchestrator + self.sdk_plugin_bridge: SdkPluginBridge | None = None + """SDK plugin bridge, initialized by core lifecycle when available.""" async def llm_generate( self, @@ -151,7 +157,7 @@ async def tool_loop_agent( image_urls: list[str] | None = None, tools: ToolSet | None = None, 
system_prompt: str | None = None, - contexts: list[Message] | None = None, + contexts: list[Message | dict[str, Any]] | None = None, max_steps: int = 30, tool_call_timeout: int = 120, **kwargs: Any, @@ -342,6 +348,10 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts + def get_all_rerank_providers(self) -> list[RerankProvider]: + """获取所有用于 Rerank 任务的 Provider。""" + return self.provider_manager.rerank_provider_insts + def get_using_provider(self, umo: str | None = None) -> Provider | None: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 @@ -454,6 +464,34 @@ async def send_message( for platform in self.platform_manager.platform_insts: if platform.meta().id == session.platform_name: await platform.send_by_session(session, message_chain) + if self.sdk_plugin_bridge is not None: + try: + await self.sdk_plugin_bridge.dispatch_system_event( + "after_message_sent", + { + "session_id": str(session), + "platform": platform.meta().name, + "platform_id": platform.meta().id, + "message_type": EventConverter._sdk_message_type( + session.message_type + ), + "message_outline": message_chain.get_plain_text( + with_other_comps_mark=True + ), + "sent_message_outline": message_chain.get_plain_text( + with_other_comps_mark=True + ), + "sent_messages": [ + component_to_payload_sync(component) + for component in message_chain.chain + ], + }, + ) + except Exception as exc: + logger.warning( + "SDK after_message_sent dispatch failed for proactive send: %s", + exc, + ) return True logger.warning( f"cannot find platform for session {str(session)}, message not sent" diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 25df73f642..db30b91eee 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -11,9 +11,11 @@ import sys import tempfile import traceback +from pathlib import Path from types import ModuleType 
import yaml +from astrbot_sdk.runtime.loader import load_plugin_spec, validate_plugin_spec from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import InvalidVersion, Version @@ -30,6 +32,7 @@ from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import ( get_astrbot_config_path, + get_astrbot_data_path, get_astrbot_path, get_astrbot_plugin_path, get_astrbot_temp_path, @@ -459,6 +462,104 @@ def _get_plugin_dir_name_from_metadata(plugin_path: str) -> str: PluginManager._validate_importable_name(plugin_dir_name) return plugin_dir_name + @staticmethod + def _detect_plugin_type(plugin_path: str) -> tuple[str, str]: + """根据插件清单文件识别安装目标。 + + Why: + 旧版插件和 SDK 插件分别由不同加载器管理,安装阶段必须先按 + `metadata.yaml` / `plugin.yaml` 分流,否则 SDK 插件会被误送到 + `data/plugins`,后续无法被 SDK 桥接层发现。 + """ + plugin_dir = Path(plugin_path) + plugin_manifest_path = plugin_dir / "plugin.yaml" + legacy_metadata_path = plugin_dir / "metadata.yaml" + + if plugin_manifest_path.exists(): + plugin_spec = load_plugin_spec(plugin_dir) + validate_plugin_spec(plugin_spec) + return "sdk", plugin_spec.name + + if legacy_metadata_path.exists(): + return "legacy", PluginManager._get_plugin_dir_name_from_metadata( + plugin_path + ) + + raise Exception( + "无法识别插件类型:插件目录中既没有 plugin.yaml,也没有 metadata.yaml。" + ) + + @staticmethod + def _read_plugin_readme(plugin_path: str, plugin_label: str) -> str | None: + plugin_dir = Path(plugin_path) + + for readme_name in ("README.md", "readme.md"): + readme_path = plugin_dir / readme_name + if not readme_path.exists(): + continue + try: + return readme_path.read_text(encoding="utf-8") + except Exception as exc: + logger.warning( + "读取插件 %s 的 %s 文件失败: %s", + plugin_label, + readme_name, + exc, + ) + return None + + return None + + @staticmethod + def _build_plugin_install_result( + *, + name: str, + repo: str | None, + readme: str | None, + plugin_type: str, + ) -> dict[str, str | None]: + return { + "repo": repo, + 
"readme": readme, + "name": name, + "type": plugin_type, + } + + async def _install_sdk_plugin( + self, + *, + temp_plugin_path: str, + plugin_name: str, + repo_url: str | None, + ) -> dict[str, str | None]: + """安装 SDK 插件到 data/sdk_plugins 并触发桥接层重新发现。""" + sdk_plugins_dir = Path(get_astrbot_data_path()) / "sdk_plugins" + target_plugin_path = sdk_plugins_dir / plugin_name + + if target_plugin_path.exists(): + raise Exception(f"安装失败:SDK 插件 {plugin_name} 已存在。") + + sdk_plugins_dir.mkdir(parents=True, exist_ok=True) + Path(temp_plugin_path).rename(target_plugin_path) + + sdk_plugin_bridge = getattr(self.context, "sdk_plugin_bridge", None) + if sdk_plugin_bridge is not None: + await sdk_plugin_bridge.reload_all(reset_restart_budget=True) + else: + logger.warning( + "SDK 插件 %s 已写入 %s,但当前未找到 sdk_plugin_bridge," + "需等待后续生命周期重载。", + plugin_name, + target_plugin_path, + ) + + return self._build_plugin_install_result( + name=plugin_name, + repo=repo_url, + readme=self._read_plugin_readme(str(target_plugin_path), plugin_name), + plugin_type="sdk", + ) + @staticmethod def _validate_astrbot_version_specifier( version_spec: str | None, @@ -1061,6 +1162,19 @@ async def load( await handler.handler(metadata) except Exception: logger.error(traceback.format_exc()) + sdk_plugin_bridge = getattr(self.context, "sdk_plugin_bridge", None) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_system_event( + "plugin_loaded", + { + "plugin_name": metadata.name, + "display_name": metadata.display_name or metadata.name, + "version": metadata.version, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_loaded dispatch failed: %s", exc) except BaseException as e: logger.error(f"----- 插件 {root_dir_name} 载入失败 -----") @@ -1238,6 +1352,7 @@ async def install_plugin( async with self._pm_lock: plugin_path = "" dir_name = "" + should_track_failed_install_dir = True try: _, repo_name, _ = self.updator.parse_github_url(repo_url) repo_name = 
self.updator.format_name(repo_name) @@ -1248,21 +1363,36 @@ async def install_plugin( ) plugin_path = await self.updator.install(repo_url, proxy) - # reload the plugin - dir_name = os.path.basename(plugin_path) - metadata_dir_name = self._get_plugin_dir_name_from_metadata(plugin_path) + plugin_type, plugin_name = self._detect_plugin_type(plugin_path) + logger.info( + "插件安装类型识别完成:repo=%s, type=%s, name=%s", + repo_url, + plugin_type, + plugin_name, + ) + dir_name = plugin_name + if plugin_type == "sdk": + should_track_failed_install_dir = False + return await self._install_sdk_plugin( + temp_plugin_path=plugin_path, + plugin_name=plugin_name, + repo_url=repo_url, + ) + + # Why: + # 旧版插件的导入路径依赖目录名与 metadata.yaml 中的 name 一致, + # 因此在加载前必须完成重命名;SDK 插件则已在前面的分支单独处理。 target_plugin_path = os.path.join( self.plugin_store_path, - metadata_dir_name, + plugin_name, ) if target_plugin_path != plugin_path and os.path.exists( target_plugin_path ): - raise Exception(f"安装失败:目录 {metadata_dir_name} 已存在。") + raise Exception(f"安装失败:目录 {plugin_name} 已存在。") if target_plugin_path != plugin_path: os.rename(plugin_path, target_plugin_path) plugin_path = target_plugin_path - dir_name = metadata_dir_name await self._ensure_plugin_requirements( plugin_path, dir_name, @@ -1286,36 +1416,25 @@ async def install_plugin( plugin = star break - # Extract README.md content if exists - readme_content = None - readme_path = os.path.join(plugin_path, "README.md") - if not os.path.exists(readme_path): - readme_path = os.path.join(plugin_path, "readme.md") - - if os.path.exists(readme_path): - try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() - except Exception as e: - logger.warning( - f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", - ) + readme_content = self._read_plugin_readme(plugin_path, dir_name) plugin_info = None if plugin: - plugin_info = { - "repo": plugin.repo, - "readme": readme_content, - "name": plugin.name, - } + plugin_info = self._build_plugin_install_result( + 
name=plugin.name, + repo=plugin.repo, + readme=readme_content, + plugin_type="legacy", + ) return plugin_info except Exception as e: - self._track_failed_install_dir( - dir_name=dir_name, - plugin_path=plugin_path, - error=e, - ) + if should_track_failed_install_dir: + self._track_failed_install_dir( + dir_name=dir_name, + plugin_path=plugin_path, + error=e, + ) if dir_name and plugin_path: logger.warning( f"安装插件 {dir_name} 失败,插件安装目录:{plugin_path}", @@ -1601,6 +1720,24 @@ def _log_del_exception(fut: asyncio.Future) -> None: await handler.handler(star_metadata) except Exception: logger.error(traceback.format_exc()) + sdk_plugin_bridge = ( + getattr(star_metadata.star_cls.context, "sdk_plugin_bridge", None) + if getattr(star_metadata, "star_cls", None) + else None + ) + if sdk_plugin_bridge is not None: + try: + await sdk_plugin_bridge.dispatch_system_event( + "plugin_unloaded", + { + "plugin_name": star_metadata.name, + "display_name": star_metadata.display_name + or star_metadata.name, + "version": star_metadata.version, + }, + ) + except Exception as exc: + logger.warning("SDK plugin_unloaded dispatch failed: %s", exc) async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) @@ -1636,26 +1773,41 @@ async def install_plugin_from_file( dir=self.plugin_store_path, prefix="plugin_upload_" ) temp_desti_dir = desti_dir + should_track_failed_install_dir = True try: self.updator.unzip_file(zip_file_path, desti_dir) - metadata_dir_name = self._get_plugin_dir_name_from_metadata(desti_dir) + try: + os.remove(zip_file_path) + except BaseException as e: + logger.warning(f"删除插件压缩包失败: {e!s}") + + plugin_type, plugin_name = self._detect_plugin_type(desti_dir) + logger.info( + "上传插件安装类型识别完成:type=%s, name=%s, file=%s", + plugin_type, + plugin_name, + zip_file_path, + ) + dir_name = plugin_name + if plugin_type == "sdk": + should_track_failed_install_dir = False + return await self._install_sdk_plugin( + 
temp_plugin_path=desti_dir, + plugin_name=plugin_name, + repo_url=None, + ) + target_plugin_path = os.path.join( self.plugin_store_path, - metadata_dir_name, + plugin_name, ) if target_plugin_path != desti_dir and os.path.exists(target_plugin_path): - raise Exception(f"安装失败:目录 {metadata_dir_name} 已存在。") + raise Exception(f"安装失败:目录 {plugin_name} 已存在。") if target_plugin_path != desti_dir: os.rename(desti_dir, target_plugin_path) - dir_name = metadata_dir_name desti_dir = target_plugin_path - # remove the zip - try: - os.remove(zip_file_path) - except BaseException as e: - logger.warning(f"删除插件压缩包失败: {e!s}") await self._ensure_plugin_requirements(desti_dir, dir_name) # await self.reload() success, error_message = await self.load( @@ -1677,26 +1829,16 @@ async def install_plugin_from_file( plugin = star break - # Extract README.md content if exists - readme_content = None - readme_path = os.path.join(desti_dir, "README.md") - if not os.path.exists(readme_path): - readme_path = os.path.join(desti_dir, "readme.md") - - if os.path.exists(readme_path): - try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() - except Exception as e: - logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") + readme_content = self._read_plugin_readme(desti_dir, dir_name) plugin_info = None if plugin: - plugin_info = { - "repo": plugin.repo, - "readme": readme_content, - "name": plugin.name, - } + plugin_info = self._build_plugin_install_result( + name=plugin.name, + repo=plugin.repo, + readme=readme_content, + plugin_type="legacy", + ) if plugin.repo: asyncio.create_task( @@ -1708,14 +1850,13 @@ async def install_plugin_from_file( return plugin_info except Exception as e: - self._track_failed_install_dir( - dir_name=dir_name, - plugin_path=desti_dir, - error=e, - ) - logger.warning( - f"安装插件 {dir_name} 失败,插件安装目录:{desti_dir}", - ) + if should_track_failed_install_dir: + self._track_failed_install_dir( + dir_name=dir_name, + plugin_path=desti_dir, + error=e, + ) 
+ logger.warning(f"安装插件 {dir_name} 失败,插件安装目录:{desti_dir}") raise finally: if temp_desti_dir != desti_dir and os.path.isdir(temp_desti_dir): diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 4d85131fc6..94237620d7 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -28,12 +28,6 @@ from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( - AiocqhttpMessageEvent, -) -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( - AiocqhttpAdapter, -) from astrbot.core.star.context import Context from astrbot.core.star.star import star_map from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -103,6 +97,13 @@ async def send_message_by_id( raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + adapter = next( (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None, @@ -183,6 +184,13 @@ async def create_event( raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + adapter = next( (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None, diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..82e4ea0744 100644 --- a/astrbot/core/utils/io.py 
+++ b/astrbot/core/utils/io.py @@ -9,7 +9,6 @@ import zipfile from pathlib import Path -import aiohttp import certifi import psutil from PIL import Image @@ -19,6 +18,12 @@ logger = logging.getLogger("astrbot") +def _get_aiohttp(): + import aiohttp + + return aiohttp + + def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat @@ -70,6 +75,7 @@ async def download_image_by_url( path: str | None = None, ) -> str: """下载图片, 返回 path""" + aiohttp = _get_aiohttp() try: ssl_context = ssl.create_default_context( cafile=certifi.where(), @@ -125,6 +131,7 @@ async def download_image_by_url( async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" + aiohttp = _get_aiohttp() try: ssl_context = ssl.create_default_context( cafile=certifi.where(), diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 8fb1464284..a3ebd40e7e 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -3,12 +3,21 @@ import sys import uuid -import aiohttp - -from astrbot.core import db_helper, logger from astrbot.core.config import VERSION +def _get_aiohttp(): + import aiohttp + + return aiohttp + + +def _get_runtime_dependencies(): + from astrbot.core import db_helper, logger + + return db_helper, logger + + class Metric: _iid_cache = None @@ -45,6 +54,7 @@ async def upload(**kwargs) -> None: Powered by TickStats. 
""" + db_helper, logger = _get_runtime_dependencies() if os.environ.get("ASTRBOT_DISABLE_METRICS", "0") == "1": return base_url = "https://tickstats.soulter.top/api/metric/90a6c2a1" @@ -69,6 +79,7 @@ async def upload(**kwargs) -> None: logger.error(f"保存指标到数据库失败: {e}") try: + aiohttp = _get_aiohttp() async with aiohttp.ClientSession(trust_env=True) as session: async with session.post(base_url, json=payload, timeout=3) as response: if response.status != 200: diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 2fa2351291..c50c3b08a2 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -1,17 +1,23 @@ -import re import os -import aiohttp +import re import ssl -import certifi -from io import BytesIO -from typing import List, Tuple from abc import ABC, abstractmethod +from io import BytesIO + +import certifi +from PIL import Image, ImageDraw, ImageFont + from astrbot.core.config import VERSION +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import save_temp_img from . 
import RenderStrategy -from PIL import ImageFont, Image, ImageDraw -from astrbot.core.utils.io import save_temp_img -from astrbot.core.utils.astrbot_path import get_astrbot_data_path + + +def _get_aiohttp(): + import aiohttp + + return aiohttp class FontManager: @@ -20,7 +26,7 @@ class FontManager: _font_cache = {} @classmethod - def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont: + def get_font(cls, size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: """获取指定大小的字体,优先从缓存获取""" if size in cls._font_cache: return cls._font_cache[size] @@ -66,7 +72,9 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]: + def get_text_size( + text: str, font: ImageFont.FreeTypeFont | ImageFont.ImageFont + ) -> tuple[int, int]: """获取文本的尺寸""" # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0 @@ -75,7 +83,7 @@ def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) - @staticmethod def split_text_to_fit_width( - text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int + text: str, font: ImageFont.FreeTypeFont | ImageFont.ImageFont, max_width: int ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] @@ -293,7 +301,10 @@ def render( # 倾斜变换,使用仿射变换实现斜体效果 # 变换矩阵: [1, 0.2, 0, 0, 1, 0] italic_img = text_img.transform( - text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC + text_img.size, + Image.Transform.AFFINE, + (1, 0.2, 0, 0, 1, 0), + Image.Resampling.BICUBIC, ) # 粘贴到原图像 @@ -629,6 +640,7 @@ def __init__(self, content: str, image_url: str): async def load_image(self): """加载图片""" try: + aiohttp = _get_aiohttp() ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 53d9441fab..828fa597a7 100644 --- a/astrbot/core/utils/t2i/network_strategy.py 
+++ b/astrbot/core/utils/t2i/network_strategy.py @@ -2,8 +2,6 @@ import logging import random -import aiohttp - from astrbot.core.config import VERSION from astrbot.core.utils.http_ssl import build_tls_connector from astrbot.core.utils.io import download_image_by_url @@ -16,6 +14,12 @@ logger = logging.getLogger("astrbot") +def _get_aiohttp(): + import aiohttp + + return aiohttp + + class NetworkRenderStrategy(RenderStrategy): def __init__(self, base_url: str | None = None) -> None: super().__init__() @@ -38,6 +42,7 @@ async def get_template(self, name: str = "base") -> str: async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: + aiohttp = _get_aiohttp() async with aiohttp.ClientSession( trust_env=True, connector=build_tls_connector(), @@ -89,6 +94,7 @@ async def render_custom_template( last_exception = None for endpoint in endpoints: try: + aiohttp = _get_aiohttp() if return_url: async with ( aiohttp.ClientSession( diff --git a/astrbot/dashboard/routes/command.py b/astrbot/dashboard/routes/command.py index cbc565c476..341684448c 100644 --- a/astrbot/dashboard/routes/command.py +++ b/astrbot/dashboard/routes/command.py @@ -1,5 +1,6 @@ from quart import request +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star.command_management import ( list_command_conflicts, list_commands, @@ -18,8 +19,13 @@ class CommandRoute(Route): - def __init__(self, context: RouteContext) -> None: + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: super().__init__(context) + self.core_lifecycle = core_lifecycle self.routes = { "/commands": ("GET", self.get_commands), "/commands/conflicts": ("GET", self.get_conflicts), @@ -30,7 +36,7 @@ def __init__(self, context: RouteContext) -> None: self.register_routes() async def get_commands(self): - commands = await list_commands() + commands = await _list_dashboard_commands(self.core_lifecycle) summary = { "total": len(commands), "disabled": 
len([cmd for cmd in commands if not cmd["enabled"]]), @@ -44,62 +50,153 @@ async def get_conflicts(self): async def toggle_command(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) enabled = data.get("enabled") - if handler_full_name is None or enabled is None: - return Response().error("handler_full_name 与 enabled 均为必填。").__dict__ + if command_key is None or enabled is None: + return Response().error("command_key 与 enabled 均为必填。").__dict__ if isinstance(enabled, str): enabled = enabled.lower() in ("1", "true", "yes", "on") + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": + return ( + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ + ) + try: - await toggle_command_service(handler_full_name, bool(enabled)) + await toggle_command_service(command_key, bool(enabled)) except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ async def rename_command(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) new_name = data.get("new_name") aliases = data.get("aliases") - if not handler_full_name or not new_name: - return Response().error("handler_full_name 与 new_name 均为必填。").__dict__ + if not command_key or not new_name: + return Response().error("command_key 与 new_name 均为必填。").__dict__ + + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": + return ( + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ + ) try: - await rename_command_service(handler_full_name, new_name, aliases=aliases) + await rename_command_service(command_key, new_name, aliases=aliases) 
except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ async def update_permission(self): data = await request.get_json() - handler_full_name = data.get("handler_full_name") + command_key = _resolve_command_key(data) permission = data.get("permission") - if not handler_full_name or not permission: + if not command_key or not permission: + return Response().error("command_key 与 permission 均为必填。").__dict__ + + item = await _get_command_payload(self.core_lifecycle, command_key) + if item.get("runtime_kind") == "sdk": return ( - Response().error("handler_full_name 与 permission 均为必填。").__dict__ + Response() + .error("SDK commands are read-only in the dashboard.") + .__dict__ ) try: - await update_command_permission_service(handler_full_name, permission) + await update_command_permission_service(command_key, permission) except ValueError as exc: return Response().error(str(exc)).__dict__ - payload = await _get_command_payload(handler_full_name) + payload = await _get_command_payload(self.core_lifecycle, command_key) return Response().ok(payload).__dict__ -async def _get_command_payload(handler_full_name: str): - commands = await list_commands() - for cmd in commands: - if cmd["handler_full_name"] == handler_full_name: +def _resolve_command_key(data: dict | None) -> str | None: + if not isinstance(data, dict): + return None + command_key = data.get("command_key") + if command_key: + return str(command_key) + handler_full_name = data.get("handler_full_name") + if handler_full_name: + return str(handler_full_name) + return None + + +async def _list_dashboard_commands( + core_lifecycle: AstrBotCoreLifecycle, +) -> list[dict]: + commands = _decorate_legacy_commands(await list_commands()) + sdk_bridge = getattr(core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + 
commands.extend(sdk_bridge.list_dashboard_commands()) + _apply_conflict_flags(commands) + commands.sort(key=lambda item: str(item.get("effective_command", "")).lower()) + return commands + + +def _decorate_legacy_commands(commands: list[dict]) -> list[dict]: + for item in commands: + _decorate_legacy_command_item(item) + return commands + + +def _decorate_legacy_command_item(item: dict) -> None: + item["command_key"] = str(item.get("handler_full_name", "")) + item["runtime_kind"] = "legacy" + item["supports_toggle"] = True + item["supports_rename"] = True + item["supports_permission"] = True + sub_commands = item.get("sub_commands") + if not isinstance(sub_commands, list): + return + for sub in sub_commands: + if isinstance(sub, dict): + _decorate_legacy_command_item(sub) + + +def _apply_conflict_flags(commands: list[dict]) -> None: + counts: dict[str, int] = {} + for item in _walk_command_items(commands): + command_name = str(item.get("effective_command", "")).strip() + if not command_name or not bool(item.get("enabled", False)): + continue + counts[command_name] = counts.get(command_name, 0) + 1 + + for item in _walk_command_items(commands): + command_name = str(item.get("effective_command", "")).strip() + item["has_conflict"] = bool(command_name and counts.get(command_name, 0) > 1) + + +def _walk_command_items(commands: list[dict]): + for item in commands: + yield item + sub_commands = item.get("sub_commands") + if not isinstance(sub_commands, list): + continue + yield from _walk_command_items(sub_commands) + + +async def _get_command_payload( + core_lifecycle: AstrBotCoreLifecycle, + command_key: str, +): + commands = await _list_dashboard_commands(core_lifecycle) + for cmd in _walk_command_items(commands): + if cmd.get("command_key") == command_key: return cmd return {} diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..72a45d27c6 100644 --- a/astrbot/dashboard/routes/config.py +++ 
b/astrbot/dashboard/routes/config.py @@ -1043,7 +1043,7 @@ async def post_plugin_configs(self): plugin_name = request.args.get("plugin_name", "unknown") try: await self._save_plugin_configs(post_configs, plugin_name) - await self.core_lifecycle.plugin_manager.reload(plugin_name) + await self._reload_plugin_after_config_save(plugin_name) return ( Response() .ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在热重载插件。") @@ -1058,6 +1058,16 @@ def _get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None: return plugin_md return None + def _sdk_bridge(self): + return getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + + async def _reload_plugin_after_config_save(self, plugin_name: str) -> None: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None and sdk_bridge.get_plugin_metadata(plugin_name): + await sdk_bridge.reload_plugin(plugin_name) + return + await self.core_lifecycle.plugin_manager.reload(plugin_name) + def _resolve_config_file_scope( self, ) -> tuple[str, str, str, StarMetadata, AstrBotConfig]: @@ -1516,6 +1526,26 @@ async def _get_plugin_config(self, plugin_name: str): } break + if ret["metadata"] is not None: + return ret + + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return ret + + schema = sdk_bridge.get_plugin_config_schema(plugin_name) + if schema is None or not schema: + return ret + config = sdk_bridge.get_plugin_config(plugin_name) or {} + ret["config"] = config + ret["metadata"] = { + plugin_name: { + "description": f"{plugin_name} 配置", + "type": "object", + "items": schema, + }, + } + return ret async def _save_astrbot_configs( @@ -1542,18 +1572,40 @@ async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> No if plugin_md.name == plugin_name: md = plugin_md - if not md: + if md: + if not md.config: + raise ValueError(f"插件 {plugin_name} 没有注册配置") + assert md.config is not None + + try: + errors, post_configs = validate_config( + post_configs, getattr(md.config, "schema", {}), is_core=False + ) 
+ if errors: + raise ValueError(f"格式校验未通过: {errors}") + md.config.save_config(post_configs) + return + except Exception as e: + raise e + + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + raise ValueError(f"插件 {plugin_name} 不存在") + + schema = sdk_bridge.get_plugin_config_schema(plugin_name) + if schema is None: raise ValueError(f"插件 {plugin_name} 不存在") - if not md.config: + if not schema: raise ValueError(f"插件 {plugin_name} 没有注册配置") - assert md.config is not None try: errors, post_configs = validate_config( - post_configs, getattr(md.config, "schema", {}), is_core=False + post_configs, + schema, + is_core=False, ) if errors: raise ValueError(f"格式校验未通过: {errors}") - md.config.save_config(post_configs) + sdk_bridge.save_plugin_config(plugin_name, post_configs) except Exception as e: raise e diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index d151bbe6f6..ef50b52fe2 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -87,6 +87,17 @@ def __init__( self._logo_cache = {} + def _sdk_bridge(self): + return getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + + def _is_sdk_plugin(self, plugin_name: str) -> bool: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return False + return any( + plugin["name"] == plugin_name for plugin in sdk_bridge.list_plugins() + ) + async def check_plugin_compatibility(self): try: data = await request.get_json() @@ -146,9 +157,19 @@ async def reload_plugins(self): data = await request.get_json() plugin_name = data.get("name", None) try: - success, message = await self.plugin_manager.reload(plugin_name) - if not success: - return Response().error(message or "插件重载失败").__dict__ + if plugin_name and self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__ + await sdk_bridge.reload_plugin(plugin_name) + else: + success, message = await 
self.plugin_manager.reload(plugin_name) + if not success: + return Response().error(message or "插件重载失败").__dict__ + if plugin_name is None: + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None: + await sdk_bridge.reload_all(reset_restart_budget=True) return Response().ok(None, "重载成功。").__dict__ except Exception as e: logger.error(f"/api/plugin/reload: {traceback.format_exc()}") @@ -420,6 +441,12 @@ async def get_plugins(self): ): continue _plugin_resp.append(_t) + sdk_bridge = self._sdk_bridge() + if sdk_bridge is not None: + for plugin in sdk_bridge.list_plugins(): + if plugin_name and plugin["name"] != plugin_name: + continue + _plugin_resp.append(plugin) return ( Response() .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) @@ -515,6 +542,8 @@ async def install_plugin(self): ignore_version_check=ignore_version_check, ) # self.core_lifecycle.restart() + if plugin_info and plugin_info.get("type") == "sdk": + logger.info("SDK 插件 %s 安装成功", plugin_info.get("name")) logger.info(f"安装插件 {repo_url} 成功。") return Response().ok(plugin_info, "安装成功。").__dict__ except PluginVersionIncompatibleError as e: @@ -556,6 +585,8 @@ async def install_plugin_upload(self): ignore_version_check=ignore_version_check, ) # self.core_lifecycle.restart() + if plugin_info and plugin_info.get("type") == "sdk": + logger.info("SDK 插件 %s 上传安装成功", plugin_info.get("name")) logger.info(f"安装插件 {file.filename} 成功") return Response().ok(plugin_info, "安装成功。").__dict__ except PluginVersionIncompatibleError as e: @@ -583,6 +614,10 @@ async def uninstall_plugin(self): plugin_name = post_data["name"] delete_config = post_data.get("delete_config", False) delete_data = post_data.get("delete_data", False) + if self._is_sdk_plugin(plugin_name): + return Response().error( + "SDK 插件在 MVP 中不支持卸载,请手动移除目录" + ).__dict__, 400 try: logger.info(f"正在卸载插件 {plugin_name}") await self.plugin_manager.uninstall_plugin( @@ -635,6 +670,8 @@ async def update_plugin(self): post_data = await request.get_json() 
plugin_name = post_data["name"] proxy: str = post_data.get("proxy", None) + if self._is_sdk_plugin(plugin_name): + return Response().error("SDK 插件在 MVP 中不支持更新").__dict__, 400 try: logger.info(f"正在更新插件 {plugin_name}") await self.plugin_manager.update_plugin(plugin_name, proxy) @@ -709,6 +746,16 @@ async def off_plugin(self): post_data = await request.get_json() plugin_name = post_data["name"] + if self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__, 500 + try: + await sdk_bridge.turn_off_plugin(plugin_name) + except ValueError as exc: + return Response().error(str(exc)).__dict__, 404 + logger.info(f"停用 SDK 插件 {plugin_name} 。") + return Response().ok(None, "停用成功。").__dict__ try: await self.plugin_manager.turn_off_plugin(plugin_name) logger.info(f"停用插件 {plugin_name} 。") @@ -727,6 +774,16 @@ async def on_plugin(self): post_data = await request.get_json() plugin_name = post_data["name"] + if self._is_sdk_plugin(plugin_name): + sdk_bridge = self._sdk_bridge() + if sdk_bridge is None: + return Response().error("SDK bridge 未初始化").__dict__, 500 + try: + await sdk_bridge.turn_on_plugin(plugin_name) + except ValueError as exc: + return Response().error(str(exc)).__dict__, 404 + logger.info(f"启用 SDK 插件 {plugin_name} 。") + return Response().ok(None, "启用成功。").__dict__ try: await self.plugin_manager.turn_on_plugin(plugin_name) logger.info(f"启用插件 {plugin_name} 。") diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index 42ba7fd802..2610f9556e 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -316,24 +316,28 @@ async def download_skill(self): .__dict__ ) - skill_dir = Path(skill_mgr.skills_root) / name - skill_md = skill_dir / "SKILL.md" - if not skill_dir.is_dir() or not skill_md.exists(): + if skill_mgr.get_local_skill_source(name) is None: return Response().error("Local skill not found").__dict__ export_dir = 
Path(get_astrbot_temp_path()) / "skill_exports" export_dir.mkdir(parents=True, exist_ok=True) zip_base = export_dir / name zip_path = zip_base.with_suffix(".zip") + bundle_dir = export_dir / f"{name}_{uuid.uuid4().hex}" if zip_path.exists(): zip_path.unlink() - shutil.make_archive( - str(zip_base), - "zip", - root_dir=str(skill_mgr.skills_root), - base_dir=name, - ) + try: + skill_mgr.materialize_local_skill_bundle(bundle_dir, skill_names=[name]) + shutil.make_archive( + str(zip_base), + "zip", + root_dir=str(bundle_dir), + base_dir=name, + ) + finally: + if bundle_dir.exists(): + shutil.rmtree(bundle_dir, ignore_errors=True) return await send_file( str(zip_path), diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 84f8dcc6d7..825abc005f 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -445,14 +445,20 @@ async def get_tool_list(self): origin_name = "unknown" tool_info = { + "tool_key": _build_legacy_tool_key(tool, origin, origin_name), "name": tool.name, "description": tool.description, "parameters": tool.parameters, "active": tool.active, "origin": origin, "origin_name": origin_name, + "runtime_kind": "legacy", + "plugin_id": None, } tools_dict.append(tool_info) + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + tools_dict.extend(sdk_bridge.list_dashboard_tools()) return Response().ok(data=tools_dict).__dict__ except Exception as e: logger.error(traceback.format_exc()) @@ -463,28 +469,65 @@ async def toggle_tool(self): try: data = await request.json tool_name = data.get("name") + tool_key = data.get("tool_key") action = data.get("activate") # True or False + runtime_kind = str(data.get("runtime_kind", "legacy") or "legacy") + plugin_id = data.get("plugin_id") - if not tool_name or action is None: + if (not tool_name and not tool_key) or action is None: return ( Response() - .error("Missing required parameters: name or activate") + 
.error("Missing required parameters: tool_key/name or activate") .__dict__ ) - if action: - try: - ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) - except ValueError as e: - return Response().error(f"Failed to activate tool: {e!s}").__dict__ + if runtime_kind == "sdk": + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is None: + return Response().error("SDK bridge is unavailable.").__dict__ + if not plugin_id or not tool_name: + return ( + Response() + .error("SDK tool toggle requires plugin_id and name") + .__dict__ + ) + plugin_metadata = sdk_bridge.get_plugin_metadata(str(plugin_id)) + if ( + action + and plugin_metadata is not None + and not plugin_metadata.get("enabled", False) + ): + return ( + Response() + .error( + "The SDK plugin is disabled. Enable the plugin before activating its tool." + ) + .__dict__ + ) + if action: + ok = sdk_bridge.activate_llm_tool(str(plugin_id), str(tool_name)) + else: + ok = sdk_bridge.deactivate_llm_tool(str(plugin_id), str(tool_name)) else: - ok = self.tool_mgr.deactivate_llm_tool(tool_name) + if action: + try: + ok = self.tool_mgr.activate_llm_tool( + str(tool_name), star_map=star_map + ) + except ValueError as e: + return ( + Response().error(f"Failed to activate tool: {e!s}").__dict__ + ) + else: + ok = self.tool_mgr.deactivate_llm_tool(str(tool_name)) if ok: return Response().ok(None, "Operation successful.").__dict__ return ( Response() - .error(f"Tool {tool_name} does not exist or the operation failed.") + .error( + f"Tool {tool_key or tool_name} does not exist or the operation failed." 
+ ) .__dict__ ) @@ -510,3 +553,11 @@ async def sync_provider(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"Sync failed: {e!s}").__dict__ + + +def _build_legacy_tool_key(tool, origin: str, origin_name: str) -> str: + if origin == "mcp" and origin_name: + return f"mcp:{origin_name}:{tool.name}" + if origin == "plugin" and getattr(tool, "handler_module_path", None): + return f"plugin:{tool.handler_module_path}:{tool.name}" + return f"tool:{tool.name}" diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a4742aa672..e265d5076b 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -13,6 +13,7 @@ from hypercorn.asyncio import serve from hypercorn.config import Config as HyperConfig from quart import Quart, g, jsonify, request +from quart import Response as QuartResponse from quart.logging import default_handler from astrbot.core import logger @@ -108,7 +109,7 @@ def __init__( core_lifecycle, core_lifecycle.plugin_manager, ) - self.command_route = CommandRoute(self.context) + self.command_route = CommandRoute(self.context, core_lifecycle) self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) @@ -157,8 +158,46 @@ async def srv_plug_route(self, subpath, *args, **kwargs): route, view_handler, methods, _ = api if route == f"/{subpath}" and request.method in methods: return await view_handler(*args, **kwargs) + sdk_bridge = getattr(self.core_lifecycle, "sdk_plugin_bridge", None) + if sdk_bridge is not None: + output = await sdk_bridge.dispatch_http_request( + f"/{subpath}", request.method + ) + if output is not None: + return self._build_sdk_plugin_response(output) return jsonify(Response().error("未找到该路由").__dict__) + @staticmethod + def _build_sdk_plugin_response(output: dict) -> QuartResponse: + status = int(output.get("status", 200)) + headers = output.get("headers") + if headers is 
None: + headers = {} + if not isinstance(headers, dict): + raise ValueError("SDK HTTP handler headers must be an object") + + body = output.get("body") + if isinstance(body, (dict, list)): + response = jsonify(body) + response.status_code = status + response.headers.setdefault("Content-Type", "application/json") + elif isinstance(body, str): + response = QuartResponse( + body, + status=status, + content_type="text/plain; charset=utf-8", + ) + elif body is None: + response = QuartResponse("", status=status) + else: + raise ValueError( + "SDK HTTP handler body must be object, array, string or null" + ) + + for key, value in headers.items(): + response.headers[str(key)] = str(value) + return response + async def auth_middleware(self): if not request.path.startswith("/api"): return None diff --git a/dashboard/src/components/extension/componentPanel/components/CommandTable.vue b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue index 32eebb746b..d9d281e971 100644 --- a/dashboard/src/components/extension/componentPanel/components/CommandTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/CommandTable.vue @@ -90,6 +90,10 @@ const getRowProps = ({ item }: { item: CommandItem }) => { } return classes.length > 0 ? { class: classes.join(' ') } : {}; }; + +const canToggle = (cmd: CommandItem): boolean => cmd.supports_toggle !== false; +const canRename = (cmd: CommandItem): boolean => cmd.supports_rename !== false; +const canEditPermission = (cmd: CommandItem): boolean => cmd.supports_permission !== false;