From 6baa994b14f103cb4d6033e756760de9586c3894 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 6 Feb 2026 15:03:04 +0800 Subject: [PATCH 1/8] fix: video gen exclude edit_file --- projects/singularity_cinema/agent.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/singularity_cinema/agent.yaml b/projects/singularity_cinema/agent.yaml index d171fc7fc..dc1756486 100644 --- a/projects/singularity_cinema/agent.yaml +++ b/projects/singularity_cinema/agent.yaml @@ -279,6 +279,7 @@ tools: mcp: false allow_read_all_files: true exclude: + - edit_file - list_files - search_file_content - search_file_name From 693d9e95d5f9d013f8af32789de5ffc381cd10cb Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 15 Mar 2026 17:48:32 +0800 Subject: [PATCH 2/8] feat: support search local paths through sirchmunk --- docs/en/Components/Config.md | 17 + docs/zh/Components/config.md | 80 ++-- examples/knowledge_search/agent.yaml.example | 86 ++++ ms_agent/agent/llm_agent.py | 70 ++- ms_agent/cli/run.py | 55 +++ ms_agent/knowledge_search/README.md | 277 ++++++++++++ ms_agent/knowledge_search/__init__.py | 11 + ms_agent/knowledge_search/sirchmunk_search.py | 401 ++++++++++++++++++ ms_agent/llm/dashscope_llm.py | 2 +- ms_agent/llm/utils.py | 6 + ms_agent/rag/utils.py | 3 + tests/knowledge_search/__init__.py | 2 + tests/knowledge_search/test_sirschmunk.py | 203 +++++++++ 13 files changed, 1179 insertions(+), 34 deletions(-) create mode 100644 examples/knowledge_search/agent.yaml.example create mode 100644 ms_agent/knowledge_search/README.md create mode 100644 ms_agent/knowledge_search/__init__.py create mode 100644 ms_agent/knowledge_search/sirchmunk_search.py create mode 100644 tests/knowledge_search/__init__.py create mode 100644 tests/knowledge_search/test_sirschmunk.py diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index d40f03898..f1253bd75 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -166,3 +166,20 @@ In addition to 
yaml configuration, MS-Agent also supports several additional com ``` > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. + +- knowledge_search_paths: Knowledge search paths, comma-separated multiple paths. When provided, automatically enables SirchmunkSearch for knowledge retrieval, with LLM configuration automatically inherited from the `llm` module. + +### Quick Start for Knowledge Search + +Use the `--knowledge_search_paths` parameter to quickly enable knowledge search based on local documents: + +```bash +# Using default agent.yaml configuration, automatically reuses LLM settings +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" + +# Specify configuration file +ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" +``` + +LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, no need to configure them repeatedly. +If you need to use independent LLM configuration in the `knowledge_search` module, you can explicitly configure `knowledge_search.llm_api_key` and other parameters in the yaml. 
diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 041820b93..12849f2a7 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -1,12 +1,12 @@ --- slug: config title: 配置与参数 -description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM配置、推理配置、system和query、callbacks、工具配置、其他、config_handler、命令行配置 +description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM 配置、推理配置、system 和 query、callbacks、工具配置、其他、config_handler、命令行配置 --- # 配置与参数 -MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名为`agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: +MS-Agent 使用一个 yaml 文件进行配置管理,通常这个文件被命名为 `agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: ## 类型配置 @@ -17,31 +17,31 @@ MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名 type: llmagent ``` -标识本配置对应的agent类型,支持`llmagent`和`codeagent`两类。默认为`llmagent`。如果yaml中包含了code_file字段,则code_file优先生效。 +标识本配置对应的 agent 类型,支持 `llmagent` 和 `codeagent` 两类。默认为 `llmagent`。如果 yaml 中包含了 code_file 字段,则 code_file 优先生效。 ## 自定义代码 -> 可选,在需要自定义LLMAgent时使用 +> 可选,在需要自定义 LLMAgent 时使用 ```yaml code_file: custom_agent ``` -可以使用一个外部agent类,该类需要继承自`LLMAgent`。可以复写其中的若干方法,如果code_file有值,则`type`字段不生效。 +可以使用一个外部 agent 类,该类需要继承自 `LLMAgent`。可以复写其中的若干方法,如果 code_file 有值,则 `type` 字段不生效。 -## LLM配置 +## LLM 配置 > 必须存在 ```yaml llm: - # 大模型服务backend + # 大模型服务 backend service: modelscope - # 模型id + # 模型 id model: Qwen/Qwen3-235B-A22B-Instruct-2507 - # 模型api_key + # 模型 api_key modelscope_api_key: - # 模型base_url + # 模型 base_url modelscope_base_url: https://api-inference.modelscope.cn/v1 ``` @@ -51,7 +51,7 @@ llm: ```yaml generation_config: - # 下面的字段均为OpenAI sdk的标准参数,你也可以配置OpenAI支持的其他参数在这里。 + # 下面的字段均为 OpenAI sdk 的标准参数,你也可以配置 OpenAI 支持的其他参数在这里。 top_p: 0.6 temperature: 0.2 top_k: 20 @@ -60,25 +60,25 @@ generation_config: enable_thinking: false ``` -## system和query +## system 和 query -> 可选,但推荐传入system +> 可选,但推荐传入 system ```yaml prompt: - # LLM system,如果不传递则使用默认的`you are a helpful assistant.` + # LLM system,如果不传递则使用默认的 `you are a helpful assistant.` system: - # LLM初始query,通常来说可以不使用 + # LLM 初始 query,通常来说可以不使用 query: ``` ## 
callbacks -> 可选,推荐自定义callbacks +> 可选,推荐自定义 callbacks ```yaml callbacks: - # 用户输入callback,该callback在assistant回复后自动等待用户输入 + # 用户输入 callback,该 callback 在 assistant 回复后自动等待用户输入 - input_callback ``` @@ -90,9 +90,9 @@ callbacks: tools: # 工具名称 file_system: - # 是否是mcp + # 是否是 mcp mcp: false - # 排除的function,可以为空 + # 排除的 function,可以为空 exclude: - create_directory - write_file @@ -104,20 +104,20 @@ tools: - map_geo ``` -支持的完整工具列表,以及自定义工具请参考[这里](./tools) +支持的完整工具列表,以及自定义工具请参考 [这里](./tools) ## 其他 > 可选,按需配置 ```yaml -# 自动对话轮数,默认为20轮 +# 自动对话轮数,默认为 20 轮 max_chat_round: 9999 # 工具调用超时时间,单位秒 tool_call_timeout: 30000 -# 输出artifact目录 +# 输出 artifact 目录 output_dir: output # 帮助信息,通常在运行错误后出现 @@ -127,13 +127,13 @@ help: | ## config_handler -为了便于在任务开始时对config进行定制化,MS-Agent构建了一个名为`ConfigLifecycleHandler`的机制。这是一个callback类,开发者可以在yaml文件中增加这样一个配置: +为了便于在任务开始时对 config 进行定制化,MS-Agent 构建了一个名为 `ConfigLifecycleHandler` 的机制。这是一个 callback 类,开发者可以在 yaml 文件中增加这样一个配置: ```yaml handler: custom_handler ``` -这代表和yaml文件同级有一个custom_handler.py文件,该文件的类继承自`ConfigLifecycleHandler`,分别有两个方法: +这代表和 yaml 文件同级有一个 custom_handler.py 文件,该文件的类继承自 `ConfigLifecycleHandler`,分别有两个方法: ```python def task_begin(self, config: DictConfig, tag: str) -> DictConfig: @@ -143,18 +143,18 @@ handler: custom_handler return config ``` -`task_begin`在LLMAgent类构造时生效,在该方法中可以对config进行一些修改。如果你的工作流中下游任务会继承上游的yaml配置,这个机制会有帮助。值得注意的是`tag`参数,该参数会传入当前LLMAgent的名字,方便分辨当前工作流的节点。 +`task_begin` 在 LLMAgent 类构造时生效,在该方法中可以对 config 进行一些修改。如果你的工作流中下游任务会继承上游的 yaml 配置,这个机制会有帮助。值得注意的是 `tag` 参数,该参数会传入当前 LLMAgent 的名字,方便分辨当前工作流的节点。 ## 命令行配置 -在yaml配置之外,MS-Agent还支持若干额外的命令行参数。 +在 yaml 配置之外,MS-Agent 还支持若干额外的命令行参数。 -- query: 初始query,这个query的优先级高于yaml中的prompt.query -- config: 配置文件路径,支持modelscope model-id -- trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为true才会生效 -- load_cache: 从历史messages继续对话。cache会被自动存储在`output`配置中。默认为`False` -- mcp_server_file: 可以读取一个外部的mcp工具配置,格式为: +- query: 初始 query,这个 query 的优先级高于 yaml 中的 prompt.query +- config: 配置文件路径,支持 modelscope model-id +- 
trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为 true 才会生效 +- load_cache: 从历史 messages 继续对话。cache 会被自动存储在 `output` 配置中。默认为 `False` +- mcp_server_file: 可以读取一个外部的 mcp 工具配置,格式为: ```json { "mcpServers": { @@ -165,5 +165,21 @@ handler: custom_handler } } ``` +- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 -> agent.yaml中的任意一个配置,都可以使用命令行传入新的值, 也支持从同名(大小写不敏感)环境变量中读取,例如`--llm.modelscope_api_key xxx-xxx`。 +> agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 + +### 知识搜索快速使用 + +通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: + +```bash +# 使用默认 agent.yaml 配置,自动复用 LLM 设置 +ms-agent run --query "如何实现用户认证?" --knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 +如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example new file mode 100644 index 000000000..cc11a8a3d --- /dev/null +++ b/examples/knowledge_search/agent.yaml.example @@ -0,0 +1,86 @@ +# Sirchmunk Knowledge Search 配置示例 +# Sirchmunk Knowledge Search Configuration Example + +# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: + +llm: + service: modelscope + model: Qwen/Qwen3-235B-A22B-Instruct-2507 + modelscope_api_key: + modelscope_base_url: https://api-inference.modelscope.cn/v1 + +generation_config: + temperature: 0.3 + top_k: 20 + stream: true + +# Knowledge Search 配置(可选) +# 用于在本地代码库中搜索相关信息 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录,用于缓存 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则使用上面 llm 的配置) + llm_api_key: + llm_base_url: https://api.openai.com/v1 + llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:聚类相似度阈值 + 
cluster_sim_threshold: 0.85 + + # 可选:聚类 TopK + cluster_sim_top_k: 3 + + # 可选:是否重用之前的知识 + reuse_knowledge: true + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:最大循环次数 + max_loops: 10 + + # 可选:最大 token 预算 + max_token_budget: 128000 + +prompt: + system: | + You are an assistant that helps me complete tasks. + +max_chat_round: 9999 + +# 使用说明: +# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 +# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 +# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM +# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 +# +# Python 使用示例: +# ```python +# from ms_agent import LLMAgent +# from ms_agent.config import Config +# +# config = Config.from_task('path/to/agent.yaml') +# agent = LLMAgent(config=config) +# result = await agent.run('如何实现用户认证功能?') +# +# # 获取搜索详情(用于前端展示) +# for msg in result: +# if msg.role == 'user': +# print(f"Search logs: {msg.searching_detail}") +# print(f"Search results: {msg.search_result}") +# ``` +# +# CLI 测试命令: +# export LLM_API_KEY="your-api-key" +# export LLM_BASE_URL="https://api.openai.com/v1" +# export LLM_MODEL_NAME="gpt-4o-mini" +# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 740eab690..3bc40c1fc 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -19,6 +19,7 @@ from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER @@ -104,6 +105,7 @@ def __init__(self, self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None + self.knowledge_search: Optional[SirschmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: 
Optional[Runtime] = None self.max_chat_round: int = 0 @@ -619,8 +621,52 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): + """Process RAG or knowledge search to enrich the user query with context. + + This method handles both traditional RAG and sirchmunk-based knowledge search. + For knowledge search, it also populates searching_detail and search_result + fields in the message for frontend display and next-turn LLM context. + + Args: + messages (List[Message]): The message list to process. + """ + user_message = messages[1] if len(messages) > 1 else None + if user_message is None or user_message.role != 'user': + return + + query = user_message.content + + # Handle traditional RAG if self.rag is not None: - messages[1].content = await self.rag.query(messages[1].content) + user_message.content = await self.rag.query(query) + + # Handle sirchmunk knowledge search + if self.knowledge_search is not None: + # Perform search and get results + search_result = await self.knowledge_search.retrieve(query) + search_details = self.knowledge_search.get_search_details() + + # Store search details in the message for frontend display + user_message.searching_detail = search_details + user_message.search_result = search_result + + # Build enriched context from search results + if search_result: + context_parts = [] + for i, result in enumerate(search_result, 1): + text = result.get('text', '') + source = result.get('metadata', {}).get('source', 'unknown') + score = result.get('score', 0) + context_parts.append( + f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n" + ) + + # Append search context to user query + context = '\n'.join(context_parts) + user_message.content = ( + f"Relevant context retrieved from codebase search:\n\n{context}\n\n" + f"User question: {query}" + ) async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -706,6 +752,27 @@ async def prepare_rag(self): f'which supports: 
{list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) + async def prepare_knowledge_search(self): + """Load and initialize the knowledge search component from the config.""" + if hasattr(self.config, 'knowledge_search'): + ks_config = self.config.knowledge_search + if ks_config is not None: + # Extract LLM config for sirchmunk + if hasattr(self.config, 'llm'): + llm_config = self.config.llm + # Update knowledge_search config with LLM settings if not specified + if not hasattr(ks_config, 'llm_api_key') and hasattr(llm_config, 'modelscope_api_key'): + OmegaConf.update(self.config, 'knowledge_search.llm_api_key', + getattr(llm_config, 'modelscope_api_key', None), merge=True) + if not hasattr(ks_config, 'llm_base_url') and hasattr(llm_config, 'modelscope_base_url'): + OmegaConf.update(self.config, 'knowledge_search.llm_base_url', + getattr(llm_config, 'modelscope_base_url', None), merge=True) + if not hasattr(ks_config, 'llm_model_name') and hasattr(llm_config, 'model'): + OmegaConf.update(self.config, 'knowledge_search.llm_model_name', + getattr(llm_config, 'model', None), merge=True) + + self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) + async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. 
@@ -1044,6 +1111,7 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() + await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index c2df42eee..cfe387e5a 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,6 +4,8 @@ import os from importlib import resources as importlib_resources +from omegaconf import OmegaConf + from ms_agent.config import Config from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII @@ -46,6 +48,22 @@ class RunCMD(CLICommand): def __init__(self, args): self.args = args + def load_env_file(self): + """Load environment variables from .env file in current directory.""" + env_file = os.path.join(os.getcwd(), '.env') + if os.path.exists(env_file): + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line and not line.startswith('#') and '=' in line: + key, value = line.split('=', 1) + key = key.strip() + value = value.strip() + # Only set if not already set in environment + if key not in os.environ: + os.environ[key] = value + logger.debug(f'Loaded {key} from .env file') + @staticmethod def define_args(parsers: argparse.ArgumentParser): """Define args for run command.""" @@ -120,6 +138,14 @@ def define_args(parsers: argparse.ArgumentParser): help= 'Animation mode for video_generate project: auto (default) or human.' ) + parser.add_argument( + '--knowledge_search_paths', + required=False, + type=str, + default=None, + help= + 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' 
+ ) parser.set_defaults(func=subparser_func) def execute(self): @@ -150,10 +176,18 @@ def execute(self): return self._execute_with_config() def _execute_with_config(self): + # Load environment variables from .env file if exists + self.load_env_file() + if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) + else: + # Use built-in default agent.yaml from package + default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file(default_config_path) as config_file: + self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download self.args.config = snapshot_download(self.args.config) @@ -190,6 +224,27 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) + # If knowledge_search_paths is provided, configure SirchmunkSearch + if getattr(self.args, 'knowledge_search_paths', None): + paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + if paths: + if 'knowledge_search' not in config or not config.knowledge_search: + # No existing knowledge_search config, create minimal config + # LLM settings will be auto-reused from llm module by SirchmunkSearch + knowledge_search_config = { + 'name': 'SirchmunkSearch', + 'paths': paths, + 'work_path': './.sirchmunk', + 'mode': 'FAST', + } + config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + else: + # Existing knowledge_search config found, only update paths + # LLM settings are already handled by SirchmunkSearch internally + existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing['paths'] = paths + config['knowledge_search'] = OmegaConf.create(existing) + if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader engine = WorkflowLoader.build( diff --git 
a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md new file mode 100644 index 000000000..00743601e --- /dev/null +++ b/ms_agent/knowledge_search/README.md @@ -0,0 +1,277 @@ +# Sirchmunk Knowledge Search 集成 + +本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 + +## 功能特性 + +- **智能代码搜索**: 使用 LLM 和 embedding 模型对代码库进行语义搜索 +- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 +- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 +- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 +- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 + +## 安装 + +```bash +pip install sirchmunk +``` + +## 配置 + +在您的 `agent.yaml` 或 `workflow.yaml` 中添加以下配置: + +```yaml +llm: + service: dashscope + model: qwen3.5-plus + dashscope_api_key: + dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + +generation_config: + temperature: 0.3 + stream: true + +# Knowledge Search 配置 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) + # llm_api_key: + # llm_base_url: https://api.openai.com/v1 + # llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:是否重用之前的知识 + reuse_knowledge: true +``` + +**LLM 配置自动复用机制**: + +`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: +- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` +- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` +- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` + +其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 + +通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 + +## 使用方式 + +### 1. 通过 CLI 使用(推荐) + +从命令行直接运行,无需编写代码: + +```bash +# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 +ms-agent run --query "如何实现用户认证功能?" 
--knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +**说明**: +- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 +- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 +- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 + +### 2. 通过 LLMAgent 使用 + +```python +from ms_agent import LLMAgent +from ms_agent.config import Config + +# 从配置文件加载 +config = Config.from_task('path/to/agent.yaml') +agent = LLMAgent(config=config) + +# 运行查询 - 会自动触发知识搜索 +result = await agent.run('如何实现用户认证功能?') + +# 获取搜索结果 +for msg in result: + if msg.role == 'user': + # 搜索详情(用于前端展示) + print(f"Search logs: {msg.searching_detail}") + # 搜索结果(作为 LLM 上下文) + print(f"Search results: {msg.search_result}") +``` + +### 2. 单独使用 SirchmunkSearch + +```python +from ms_agent.knowledge_search import SirchmunkSearch +from omegaconf import DictConfig + +config = DictConfig({ + 'knowledge_search': { + 'paths': ['./src', './docs'], + 'work_path': './.sirchmunk', + 'llm_api_key': 'your-api-key', + 'llm_model_name': 'gpt-4o-mini', + 'mode': 'FAST', + } +}) + +searcher = SirchmunkSearch(config) + +# 查询(返回合成答案) +answer = await searcher.query('如何实现用户认证?') + +# 检索(返回原始搜索结果) +results = await searcher.retrieve( + query='用户认证', + limit=5, + score_threshold=0.7 +) + +# 获取搜索日志 +logs = searcher.get_search_logs() + +# 获取搜索详情 +details = searcher.get_search_details() +``` + +## 环境变量 + +可以通过环境变量配置: + +```bash +# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" + +# Embedding 模型配置 +export EMBEDDING_MODEL_ID="text-embedding-3-small" +export SIRCHMUNK_WORK_PATH="./.sirchmunk" +``` + +**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 + +## 测试 + +### 单元测试 + +```bash +export LLM_API_KEY="your-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o-mini" + 
+python -m unittest tests/knowledge_search/test_sirschmunk.py +``` + +### CLI 测试 + +```bash +# 基本测试 +python tests/knowledge_search/test_cli.py + +# 指定查询 +python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" + +# 仅测试 standalone 模式 +python tests/knowledge_search/test_cli.py -m standalone + +# 仅测试 agent 模式 +python tests/knowledge_search/test_cli.py -m agent +``` + +## 配置参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | +| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | +| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | +| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | +| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | +| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | +| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | +| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | +| reuse_knowledge | bool | true | 是否重用之前的知识 | +| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | +| max_loops | int | 10 | 最大搜索循环次数 | +| max_token_budget | int | 128000 | 最大 token 预算 | + +## 搜索模式 + +- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 +- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 +- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 + +## Message 字段扩展 + +为了支持知识搜索,`Message` 类增加了两个字段: + +- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 + - `logs`: 搜索日志列表 + - `mode`: 使用的搜索模式 + - `paths`: 搜索的路径 + - `work_path`: 工作目录 + - `reuse_knowledge`: 是否重用知识 + +- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 + - `text`: 文档内容 + - `score`: 相关性分数 + - `metadata`: 元数据(如源文件、类型等) + +## 工作原理 + +1. 用户发送查询 +2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch +3. `do_rag()` 方法执行知识搜索: + - 调用 `searcher.retrieve()` 获取相关文档 + - 将搜索结果存入 `message.search_result` + - 将搜索日志存入 `message.searching_detail` + - 将搜索结果格式化为上下文,附加到用户查询 +4. LLM 接收 enriched query 并生成回答 +5. 前端可以通过 `searching_detail` 展示搜索过程 + +## 故障排除 + +### 常见问题 + +1. 
**ImportError: No module named 'sirchmunk'** + ```bash + pip install sirchmunk + ``` + +2. **搜索结果为空** + - 检查 `paths` 配置是否正确 + - 确保路径下有可搜索的文件 + - 尝试降低 `cluster_sim_threshold` 值 + +3. **LLM API 调用失败** + - 检查 API key 是否正确 + - 检查 base URL 是否正确 + - 查看搜索日志了解详细错误 + +### 日志查看 + +```python +# 查看搜索日志 +logs = searcher.get_search_logs() +for log in logs: + print(log) + +# 或在配置中启用 verbose +knowledge_search: + verbose: true +``` + +## 参考资源 + +- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) +- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py new file mode 100644 index 000000000..33362beee --- /dev/null +++ b/ms_agent/knowledge_search/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search module based on sirchmunk. + +This module provides integration between sirchmunk's AgenticSearch +and the ms_agent framework, enabling intelligent codebase search +capabilities similar to RAG. +""" + +from .sirchmunk_search import SirchmunkSearch + +__all__ = ['SirchmunkSearch'] diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py new file mode 100644 index 000000000..4e1e322a5 --- /dev/null +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -0,0 +1,401 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sirchmunk-based knowledge search integration. + +This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, +providing document retrieval capabilities similar to RAG but optimized for +codebase and documentation search. +""" + +import asyncio +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from loguru import logger + +from ms_agent.rag.base import RAG +from omegaconf import DictConfig + + +class SirchmunkSearch(RAG): + """Sirchmunk-based knowledge search class. 
+ + This class wraps the sirchmunk library to provide intelligent codebase search + capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk + uses a combination of keyword search, semantic clustering, and LLM-powered + analysis to find relevant information from codebases. + + The configuration needed in the config yaml: + - name: SirchmunkSearch + - paths: List of paths to search, required + - work_path: Working directory for sirchmunk cache, default './.sirchmunk' + - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' + - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 + - cluster_sim_top_k: Top K clusters to consider, default 3 + - reuse_knowledge: Whether to reuse previous search results, default True + - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' + + Args: + config (DictConfig): Configuration object containing sirchmunk settings. + """ + + def __init__(self, config: DictConfig): + super().__init__(config) + + self._validate_config(config) + + # Extract configuration parameters + rag_config = config.get('knowledge_search', {}) + + # Search paths - required + paths = rag_config.get('paths', []) + if isinstance(paths, str): + paths = [paths] + self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + + # Work path for sirchmunk cache + _work_path = rag_config.get('work_path', './.sirchmunk') + self.work_path: Path = Path(_work_path).expanduser().resolve() + + # Sirchmunk search parameters + self.reuse_knowledge = rag_config.get('reuse_knowledge', True) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) + self.search_mode = rag_config.get('mode', 'FAST') + self.max_loops = rag_config.get('max_loops', 10) + self.max_token_budget = rag_config.get('max_token_budget', 128000) + + # LLM configuration for sirchmunk + # First try knowledge_search.llm_api_key, then 
fall back to main llm config + self.llm_api_key = rag_config.get('llm_api_key', None) + self.llm_base_url = rag_config.get('llm_base_url', None) + self.llm_model_name = rag_config.get('llm_model_name', None) + + # Fall back to main llm config if not specified in knowledge_search + if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + llm_config = config.get('llm', {}) + if llm_config: + service = getattr(llm_config, 'service', 'dashscope') + if self.llm_api_key is None: + self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + if self.llm_base_url is None: + self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + if self.llm_model_name is None: + self.llm_model_name = getattr(llm_config, 'model', None) + + # Embedding model configuration + self.embedding_model_id = rag_config.get('embedding_model', None) + self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + + # Runtime state + self._searcher = None + self._initialized = False + + # Callback for capturing logs + self._log_callback = None + self._search_logs: List[str] = [] + + def _validate_config(self, config: DictConfig): + """Validate configuration parameters.""" + if not hasattr(config, 'knowledge_search') or config.knowledge_search is None: + raise ValueError( + 'Missing knowledge_search configuration. ' + 'Please add knowledge_search section to your config with at least "paths" specified.' 
+ ) + + rag_config = config.knowledge_search + paths = rag_config.get('paths', []) + if not paths: + raise ValueError('knowledge_search.paths must be specified and non-empty') + + def _initialize_searcher(self): + """Initialize the sirchmunk AgenticSearch instance.""" + if self._initialized: + return + + try: + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.utils.embedding_util import EmbeddingUtil + + # Create LLM client + llm = OpenAIChat( + api_key=self.llm_api_key, + base_url=self.llm_base_url, + model=self.llm_model_name, + max_retries=3, + log_callback=self._log_callback_wrapper(), + ) + + # Create embedding util + # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) + embedding_model_id = self.embedding_model_id if self.embedding_model_id else None + embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None + embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + + # Create AgenticSearch instance + self._searcher = AgenticSearch( + llm=llm, + embedding=embedding, + work_path=str(self.work_path), + paths=self.search_paths, + verbose=True, + reuse_knowledge=self.reuse_knowledge, + cluster_sim_threshold=self.cluster_sim_threshold, + cluster_sim_top_k=self.cluster_sim_top_k, + log_callback=self._log_callback_wrapper(), + ) + + self._initialized = True + logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + + except ImportError as e: + raise ImportError( + f'Failed to import sirchmunk: {e}. 
' + 'Please install sirchmunk: pip install sirchmunk' + ) + except Exception as e: + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + + def _log_callback_wrapper(self): + """Create a callback wrapper to capture search logs.""" + def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): + self._search_logs.append(f'[{level}] {message}') + + return log_callback + + async def add_documents(self, documents: List[str]) -> bool: + """Add documents to the search index. + + Note: Sirchmunk works by scanning existing files in the specified paths. + This method is provided for RAG interface compatibility but doesn't + directly add documents. Instead, documents should be saved to files + within the search paths. + + Args: + documents (List[str]): List of document contents to add. + + Returns: + bool: True if successful (for interface compatibility). + """ + logger.warning( + 'SirchmunkSearch does not support direct document addition. ' + 'Documents should be saved to files within the configured search paths.' + ) + # Trigger re-scan of the search paths + if self._searcher and hasattr(self._searcher, 'knowledge_base'): + try: + await self._searcher.knowledge_base.refresh() + return True + except Exception as e: + logger.error(f'Failed to refresh knowledge base: {e}') + return False + return True + + async def add_documents_from_files(self, file_paths: List[str]) -> bool: + """Add documents from file paths. + + Args: + file_paths (List[str]): List of file paths to scan. + + Returns: + bool: True if successful. 
+ """ + self._initialize_searcher() + + if self._searcher and hasattr(self._searcher, 'scan_directory'): + try: + for file_path in file_paths: + if Path(file_path).exists(): + await self._searcher.scan_directory(str(Path(file_path).parent)) + return True + except Exception as e: + logger.error(f'Failed to scan files: {e}') + return False + return True + + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: + """Retrieve relevant documents using sirchmunk. + + Args: + query (str): The search query. + limit (int): Maximum number of results to return. + score_threshold (float): Minimum relevance score threshold. + **filters: Additional filters (mode, max_loops, etc.). + + Returns: + List[Dict[str, Any]]: List of search results with 'text', 'score', + 'metadata' fields. + """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = filters.get('mode', self.search_mode) + max_loops = filters.get('max_loops', self.max_loops) + max_token_budget = filters.get('max_token_budget', self.max_token_budget) + + # Perform search + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=True, + ) + + # Parse results into standard format + return self._parse_search_result(result, score_threshold, limit) + + except Exception as e: + logger.error(f'SirschmunkSearch retrieve failed: {e}') + return [] + + async def query(self, query: str) -> str: + """Query sirchmunk and return a synthesized answer. + + This method performs a search and returns the LLM-synthesized answer + along with search details that can be used for frontend display. + + Args: + query (str): The search query. + + Returns: + str: The synthesized answer from sirchmunk. 
+ """ + self._initialize_searcher() + self._search_logs.clear() + + try: + mode = self.search_mode + max_loops = self.max_loops + max_token_budget = self.max_token_budget + + # Perform search and get answer + result = await self._searcher.search( + query=query, + mode=mode, + max_loops=max_loops, + max_token_budget=max_token_budget, + return_context=False, + ) + + # Result is already a synthesized answer string + if isinstance(result, str): + return result + + # If we got SearchContext or other format, extract the answer + if hasattr(result, 'answer'): + return result.answer + + # Fallback: convert to string + return str(result) + + except Exception as e: + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' + + def _parse_search_result(self, + result: Any, + score_threshold: float, + limit: int) -> List[Dict[str, Any]]: + """Parse sirchmunk search result into standard format. + + Args: + result: The raw search result from sirchmunk. + score_threshold: Minimum score threshold. + limit: Maximum number of results. + + Returns: + List[Dict[str, Any]]: Parsed results. 
+ """ + results = [] + + # Handle SearchContext format (returned when return_context=True) + if hasattr(result, 'cluster') and result.cluster is not None: + cluster = result.cluster + for unit in cluster.evidences: + # Extract score from snippets if available + score = getattr(cluster, 'confidence', 1.0) + if score >= score_threshold: + # Extract text from snippets + text_parts = [] + source = str(getattr(unit, 'file_or_url', 'unknown')) + for snippet in getattr(unit, 'snippets', []): + if isinstance(snippet, dict): + text_parts.append(snippet.get('snippet', '')) + else: + text_parts.append(str(snippet)) + + results.append({ + 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + } + }) + + # Handle format with evidence_units attribute directly + elif hasattr(result, 'evidence_units'): + for unit in result.evidence_units: + score = getattr(unit, 'confidence', 1.0) + if score >= score_threshold: + results.append({ + 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + } + }) + + # Handle list format + elif isinstance(result, list): + for item in result: + if isinstance(item, dict): + score = item.get('score', item.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': item.get('content', item.get('text', str(item))), + 'score': score, + 'metadata': item.get('metadata', {}), + }) + + # Handle dict format + elif isinstance(result, dict): + score = result.get('score', result.get('confidence', 1.0)) + if score >= score_threshold: + results.append({ + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + }) + + # Sort by score and 
limit results + results.sort(key=lambda x: x.get('score', 0), reverse=True) + return results[:limit] + + def get_search_logs(self) -> List[str]: + """Get the captured search logs. + + Returns: + List[str]: List of log messages from the search operation. + """ + return self._search_logs.copy() + + def get_search_details(self) -> Dict[str, Any]: + """Get detailed search information including logs and metadata. + + Returns: + Dict[str, Any]: Search details including logs, mode, and paths. + """ + return { + 'logs': self._search_logs.copy(), + 'mode': self.search_mode, + 'paths': self.search_paths, + 'work_path': str(self.work_path), + 'reuse_knowledge': self.reuse_knowledge, + } diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py index af766f679..b4a6ddaa8 100644 --- a/ms_agent/llm/dashscope_llm.py +++ b/ms_agent/llm/dashscope_llm.py @@ -12,7 +12,7 @@ class DashScope(OpenAI): def __init__(self, config: DictConfig): super().__init__( config, - base_url=config.llm.modelscope_base_url + base_url=config.llm.dashscope_base_url or get_service_config('dashscope').base_url, api_key=config.llm.dashscope_api_key) diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py index 6a336ca6e..410aa12f0 100644 --- a/ms_agent/llm/utils.py +++ b/ms_agent/llm/utils.py @@ -61,6 +61,12 @@ class Message: api_calls: int = 1 + # Knowledge search (sirchmunk) related fields + # searching_detail: Search process logs and metadata for frontend display + searching_detail: Dict[str, Any] = field(default_factory=dict) + # search_result: Raw search results to be used as context for next LLM turn + search_result: List[Dict[str, Any]] = field(default_factory=list) + def to_dict(self): return asdict(self) diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py index 08e9a4db7..e66da954d 100644 --- a/ms_agent/rag/utils.py +++ b/ms_agent/rag/utils.py @@ -4,3 +4,6 @@ rag_mapping = { 'LlamaIndexRAG': LlamaIndexRAG, } + +# Note: SirchmunkSearch is registered in knowledge_search 
module +# and integrated directly in LLMAgent, not through rag_mapping diff --git a/tests/knowledge_search/__init__.py b/tests/knowledge_search/__init__.py new file mode 100644 index 000000000..0cc40e613 --- /dev/null +++ b/tests/knowledge_search/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search tests.""" diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py new file mode 100644 index 000000000..5a4f43213 --- /dev/null +++ b/tests/knowledge_search/test_sirschmunk.py @@ -0,0 +1,203 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Tests for SirchmunkSearch knowledge search integration via LLMAgent. + +These tests verify the sirchmunk-based knowledge search functionality +through the LLMAgent entry point, including verification that +search_result and searching_detail fields are properly populated. + +To run these tests, you need to set the following environment variables: + - TEST_LLM_API_KEY: Your LLM API key + - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI) + - TEST_LLM_MODEL_NAME: Your LLM model name (optional) + - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional) + - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional) + +Example: + export TEST_LLM_API_KEY="your-api-key" + export TEST_LLM_BASE_URL="https://api.openai.com/v1" + export TEST_LLM_MODEL_NAME="gpt-4o" + export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache" + python -m pytest tests/knowledge_search/test_sirschmunk.py +""" +import asyncio +import os +import shutil +import unittest +from pathlib import Path + +from ms_agent.knowledge_search import SirchmunkSearch +from ms_agent.agent import LLMAgent +from ms_agent.config import Config +from omegaconf import DictConfig + +from modelscope.utils.test_utils import test_level + + 
+class SirchmunkLLMAgentIntegrationTest(unittest.TestCase): + """Test cases for SirchmunkSearch integration with LLMAgent. + + These tests verify that when LLMAgent runs a query that triggers + knowledge search, the Message objects have search_result and + searching_detail fields properly populated. + """ + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + # Create test directory with sample files + cls.test_dir = Path('./test_llm_agent_knowledge') + cls.test_dir.mkdir(exist_ok=True) + + # Create sample documentation + (cls.test_dir / 'README.md').write_text(''' +# Test Project Documentation + +## Overview +This is a test project for knowledge search integration. + +## API Reference + +### UserManager +The UserManager class handles user operations: +- create_user: Create a new user account +- delete_user: Delete an existing user +- update_user: Update user information +- get_user: Retrieve user details + +### AuthService +The AuthService class handles authentication: +- login: Authenticate user credentials +- logout: End user session +- refresh_token: Refresh authentication token +- verify_token: Validate authentication token +''') + + (cls.test_dir / 'config.py').write_text(''' +"""Configuration module.""" + +class Config: + """Application configuration.""" + + def __init__(self): + self.database_url = "postgresql://localhost:5432/mydb" + self.secret_key = "your-secret-key" + self.debug_mode = False + + def load_from_env(self): + """Load configuration from environment variables.""" + import os + self.database_url = os.getenv("DATABASE_URL", self.database_url) + self.secret_key = os.getenv("SECRET_KEY", self.secret_key) + return self +''') + + @classmethod + def tearDownClass(cls): + """Clean up test fixtures.""" + if cls.test_dir.exists(): + shutil.rmtree(cls.test_dir, ignore_errors=True) + work_dir = Path('./.sirchmunk') + if work_dir.exists(): + shutil.rmtree(work_dir, ignore_errors=True) + + def _get_agent_config(self): + """Create agent 
configuration with knowledge search.""" + llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key') + llm_base_url = os.getenv('TEST_LLM_BASE_URL', 'https://api.openai.com/v1') + llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini') + # Read from TEST_* env vars (for test-specific config) + # These can be set from .env file which uses TEST_* prefix + embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '') + embedding_model_cache_dir = os.getenv('TEST_EMBEDDING_MODEL_CACHE_DIR', '') + + config = DictConfig({ + 'llm': { + 'service': 'openai', + 'model': llm_model_name, + 'openai_api_key': llm_api_key, + 'openai_base_url': llm_base_url, + }, + 'generation_config': { + 'temperature': 0.3, + 'max_tokens': 500, + }, + 'knowledge_search': { + 'name': 'SirchmunkSearch', + 'paths': [str(self.test_dir)], + 'work_path': './.sirchmunk', + 'llm_api_key': llm_api_key, + 'llm_base_url': llm_base_url, + 'llm_model_name': llm_model_name, + 'embedding_model': embedding_model_id, + 'embedding_model_cache_dir': embedding_model_cache_dir, + 'mode': 'FAST', + } + }) + return config + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_llm_agent_with_knowledge_search(self): + """Test LLMAgent using knowledge search. + + This test verifies that: + 1. LLMAgent can be initialized with SirchmunkSearch configuration + 2. Running a query produces a valid response + 3. User message has searching_detail and search_result populated + 4. searching_detail contains expected keys (logs, mode, paths) + 5. search_result is a list + """ + config = self._get_agent_config() + agent = LLMAgent(config=config, tag='test-knowledge-agent') + + # Test query that should trigger knowledge search + query = 'How do I use UserManager to create a user?' 
+ + async def run_agent(): + result = await agent.run(query) + return result + + result = asyncio.run(run_agent()) + + # Verify result + self.assertIsNotNone(result) + self.assertIsInstance(result, list) + self.assertTrue(len(result) > 0) + + # Check that assistant message exists + assistant_message = [m for m in result if m.role == 'assistant'] + self.assertTrue(len(assistant_message) > 0) + + # Check that user message has search_result and searching_detail populated + user_messages = [m for m in result if m.role == 'user'] + self.assertTrue(len(user_messages) > 0, "Expected at least one user message") + + # The first user message should have search details after do_rag processing + user_msg = user_messages[0] + self.assertTrue( + hasattr(user_msg, 'searching_detail'), + "User message should have searching_detail attribute" + ) + self.assertTrue( + hasattr(user_msg, 'search_result'), + "User message should have search_result attribute" + ) + + # Check that searching_detail is a dict with expected keys + self.assertIsInstance( + user_msg.searching_detail, dict, + "searching_detail should be a dictionary" + ) + self.assertIn('logs', user_msg.searching_detail) + self.assertIn('mode', user_msg.searching_detail) + self.assertIn('paths', user_msg.searching_detail) + + # Check that search_result is a list (may be empty if no relevant docs found) + self.assertIsInstance( + user_msg.search_result, list, + "search_result should be a list" + ) + + +if __name__ == '__main__': + unittest.main() From 2977ba09204169b6fe02492da5089e56ad06573e Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:58:32 +0800 Subject: [PATCH 3/8] Update ms_agent/agent/llm_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/agent/llm_agent.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 
3bc40c1fc..515446381 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -757,20 +757,6 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - # Extract LLM config for sirchmunk - if hasattr(self.config, 'llm'): - llm_config = self.config.llm - # Update knowledge_search config with LLM settings if not specified - if not hasattr(ks_config, 'llm_api_key') and hasattr(llm_config, 'modelscope_api_key'): - OmegaConf.update(self.config, 'knowledge_search.llm_api_key', - getattr(llm_config, 'modelscope_api_key', None), merge=True) - if not hasattr(ks_config, 'llm_base_url') and hasattr(llm_config, 'modelscope_base_url'): - OmegaConf.update(self.config, 'knowledge_search.llm_base_url', - getattr(llm_config, 'modelscope_base_url', None), merge=True) - if not hasattr(ks_config, 'llm_model_name') and hasattr(llm_config, 'model'): - OmegaConf.update(self.config, 'knowledge_search.llm_model_name', - getattr(llm_config, 'model', None), merge=True) - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: From 879dba4190c819ee7b70ce7d4f511aba21e44715 Mon Sep 17 00:00:00 2001 From: suluyana <110878454+suluyana@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:59:07 +0800 Subject: [PATCH 4/8] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/knowledge_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md index 00743601e..ef86df0da 100644 --- a/ms_agent/knowledge_search/README.md +++ b/ms_agent/knowledge_search/README.md @@ -108,7 +108,7 @@ for msg in result: print(f"Search results: {msg.search_result}") ``` -### 2. 单独使用 SirchmunkSearch +### 3. 
单独使用 SirchmunkSearch ```python from ms_agent.knowledge_search import SirchmunkSearch From 520282384734a76a4a1eec047ff1e4f962c0ca59 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 18:27:13 +0800 Subject: [PATCH 5/8] full modify? --- ms_agent/agent/llm_agent.py | 352 +++++++++--------- ms_agent/knowledge_search/sirchmunk_search.py | 208 ++++++++--- 2 files changed, 330 insertions(+), 230 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 515446381..76289a403 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,30 +2,29 @@ import asyncio import importlib import inspect +import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy +from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union -import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping from ms_agent.memory.memory_manager import SharedMemoryManager from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping -from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.tools import ToolManager from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger -from omegaconf import DictConfig, OmegaConf - from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -90,14 +89,17 @@ class LLMAgent(Agent): TOTAL_CACHE_CREATION_INPUT_TOKENS = 0 TOKEN_LOCK = asyncio.Lock() - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = 
False, - **kwargs): + def __init__( + self, + config: DictConfig = DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs, + ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' + ) llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -113,7 +115,8 @@ def __init__(self, self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {})) + kwargs.get('mcp_config', {}) + ) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -161,37 +164,34 @@ def _ensure_auto_skills(self) -> bool: use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: from ms_agent.utils.docker_utils import is_docker_daemon_running + if not is_docker_daemon_running(): - logger.warning( - 'Docker not running, disabling sandbox for skills') + logger.warning('Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container( - skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', - None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, - 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), max_retries=getattr(skills_config, 'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) 
logger.info( - f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' + f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f'Failed to initialize AutoSkills: {e}') + logger.warning(f"Failed to initialize AutoSkills: {e}") self._auto_skills_initialized = True return False @@ -233,7 +233,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f'Skill analysis error: {e}') + logger.error(f"Skill analysis error: {e}") return False async def get_skill_dag(self, query: str): @@ -265,13 +265,15 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = getattr(skills_config, 'stop_on_failure', - True) if skills_config else True + stop_on_failure = ( + getattr(skills_config, 'stop_on_failure', True) if skills_config else True + ) result = await self._auto_skills.run( query=query, execution_input=execution_input, - stop_on_failure=stop_on_failure) + stop_on_failure=stop_on_failure, + ) self._last_skill_result = result return result @@ -289,15 +291,14 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append( - Message(role='assistant', content=dag_result.chat_response)) + messages.append(Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." 
if dag_result.clarification: - content += f'\n\n{dag_result.clarification}' + content += f"\n\n{dag_result.clarification}" messages.append(Message(role='assistant', content=content)) return messages @@ -317,28 +318,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f'**{skill_id} output:**\n{stdout_preview}\n\n' + content += f"**{skill_id} output:**\n{stdout_preview}\n\n" if output.output_files: - content += f'**Generated files:** {list(output.output_files.values())}\n\n' + content += f"**Generated files:** {list(output.output_files.values())}\n\n" - content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + content += ( + f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f'**{skill_id} failed:** {result.error}\n' + content += f"**{skill_id} failed:** {result.error}\n" messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' + content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' 
- content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' - content += f'\nExecution order: {dag_result.execution_order}' + content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" + content += f"\nExecution order: {dag_result.execution_order}" messages.append(Message(role='assistant', content=content)) @@ -364,8 +367,7 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile( - self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -394,26 +396,32 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f'[External Code]A Config Lifecycle handler ' - f'registered in the config: {handler_file}. ' - f'\nThis is external code, if you trust this workflow, ' - f'please specify `--trust_remote_code true`') - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + f"[External Code]A Config Lifecycle handler " + f"registered in the config: {handler_file}. " + f"\nThis is external code, if you trust this workflow, " + f"please specify `--trust_remote_code true`" + ) + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' 
if local_dir not in sys.path: sys.path.insert(0, local_dir) handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, - inspect.isclass) + for name, cls in inspect.getmembers(handler_module, inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if handler_cls.__bases__[ - 0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: + if ( + handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file + ): handler = handler_cls() - assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' + assert ( + handler is not None + ), f"Config Lifecycle handler class cannot be found in {handler_file}" return handler return None @@ -424,13 +432,14 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, - 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' 
if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -451,23 +460,22 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass) + callback_file, inspect.isclass + ) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass( - cls, Callback) and cls.__module__ == _callback: + if issubclass(cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback]( - self.config)) + self.callbacks.append(callbacks_mapping[_callback](self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task beginning.') + self.log_output(f"Agent {self.tag} task beginning.") await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f'Agent {self.tag} task finished.') + self.log_output(f"Agent {self.tag} task finished.") await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -492,8 +500,7 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, - messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -504,17 +511,20 @@ async def parallel_tool_call(self, List[Message]: Updated message list including tool responses. 
""" tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls) + messages[-1].tool_calls + ) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip(tool_call_result, - messages[-1].tool_calls): + for tool_call_result, tool_call_query in zip( + tool_call_result, messages[-1].tool_calls + ): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', content=tool_call_result_format.text, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], - resources=tool_call_result_format.resources) + resources=tool_call_result_format.resources, + ) if _new_message.tool_call_id is None: # If tool call id is None, add a random one @@ -530,7 +540,8 @@ async def prepare_tools(self): self.config, self.mcp_config, self.mcp_client, - trust_remote_code=self.trust_remote_code) + trust_remote_code=self.trust_remote_code, + ) await self.tool_manager.connect() async def cleanup_tools(self): @@ -539,8 +550,7 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -551,8 +561,7 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). 
""" - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -563,8 +572,7 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', - DictConfig({})) + generation_config = getattr(self.config, 'generation_config', DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -580,19 +588,18 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr( - getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr( - getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str]) -> List[Message]: + self, messages: Union[List[Message], str] + ) -> List[Message]: """ Convert input into a standardized list of messages. 
@@ -604,18 +611,19 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if system is not None and messages[ - 0].role == 'system' and system != messages[0].content: + if ( + system is not None + and messages[0].role == 'system' + and system != messages[0].content + ): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f'inputs can be either a list or a string, but current is {type(messages)}' + ), f"inputs can be either a list or a string, but current is {type(messages)}" messages = [ - Message( - role='system', - content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -639,11 +647,10 @@ async def do_rag(self, messages: List[Message]): # Handle traditional RAG if self.rag is not None: user_message.content = await self.rag.query(query) - # Handle sirchmunk knowledge search if self.knowledge_search is not None: # Perform search and get results - search_result = await self.knowledge_search.retrieve(query) + search_result = await self.knowledge_search.query(query) search_details = self.knowledge_search.get_search_details() # Store search details in the message for frontend display @@ -652,24 +659,14 @@ async def do_rag(self, messages: List[Message]): # Build enriched context from search results if search_result: - context_parts = [] - for i, result in enumerate(search_result, 1): - text = result.get('text', '') - source = result.get('metadata', {}).get('source', 'unknown') - score = result.get('score', 0) - context_parts.append( - f"[Source {i}] {source} (relevance: {score:.2f})\n{text}\n" - ) - # Append search context to user query - context = '\n'.join(context_parts) + context = search_result user_message.content = ( f"Relevant context retrieved from codebase search:\n\n{context}\n\n" f"User question: {query}" ) - async def do_skill(self, - 
messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. @@ -686,7 +683,9 @@ async def do_skill(self, # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' else None) + if len(messages) > 1 and messages[1].role == 'user' + else None + ) if not query: return None @@ -700,8 +699,9 @@ async def do_skill(self, try: skills_config = self._get_skills_config() - auto_execute = getattr(skills_config, 'auto_execute', - True) if skills_config else True + auto_execute = ( + getattr(skills_config, 'auto_execute', True) if skills_config else True + ) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,8 +709,7 @@ async def do_skill(self, dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages( - dag_result) + skill_messages = self._format_skill_result_as_messages(dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -721,7 +720,8 @@ async def do_skill(self, except Exception as e: logger.warning( - f'Skill execution failed: {e}, falling back to standard agent') + f"Skill execution failed: {e}, falling back to standard agent" + ) self._skill_mode_active = False return None @@ -735,11 +735,13 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f'{mem_instance_type} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') + f"{mem_instance_type} not in memory_mapping, " + f"which supports: {list(memory_mapping.keys())}" + ) shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type) + self.config, mem_instance_type + ) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ -748,12 +750,17 @@ 
async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f'{rag.name} not in rag_mapping, ' - f'which supports: {list(rag_mapping.keys())}') + f"{rag.name} not in rag_mapping, " + f"which supports: {list(rag_mapping.keys())}" + ) self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): """Load and initialize the knowledge search component from the config.""" + if self.knowledge_search is not None: + # Already initialized (e.g. by caller before run_loop), skip to avoid + # overwriting a configured instance (e.g. one with streaming callbacks set). + return if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: @@ -790,7 +797,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f'[Image: {img_url[:50]}...]') + text_parts.append(f"[Image: {img_url[:50]}...]") content = ' '.join(text_parts) # Ensure content is a string @@ -801,10 +808,9 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f'[{self.tag}] {_line}') + logger.info(f"[{self.tag}] {_line}") - def handle_new_response(self, messages: List[Message], - response_message: Message): + def handle_new_response(self, messages: List[Message], response_message: Message): assert response_message is not None, 'No response message generated from LLM.' 
if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -812,24 +818,23 @@ def handle_new_response(self, messages: List[Message], tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads( - tool_call['arguments']) + tool_call['arguments'] = json.loads(tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output( - json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if messages[-1].role == 'assistant' and not messages[ - -1].content and response_message.tool_calls: + if ( + messages[-1].role == 'assistant' + and not messages[-1].content + and response_message.tool_calls + ): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step( - self, messages: List[Message] - ) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -865,20 +870,20 @@ async def step( is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate( - messages, tools=tools): + for _response_message in self.llm.generate(messages, tools=tools): if is_first: messages.append(_response_message) is_first = False # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) # Some providers may reset / shorten content across chunks. 
if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning):] + new_reasoning = reasoning_text[len(_reasoning) :] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -886,7 +891,7 @@ async def step( self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content):] + new_content = _response_message.content[len(_content) :] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -898,8 +903,9 @@ async def step( else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') or '' + ) if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -927,8 +933,9 @@ async def step( prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = getattr( - _response_message, 'cache_creation_input_tokens', 0) or 0 + cache_creation_input_tokens = ( + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 + ) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -938,20 +945,21 @@ async def step( # tokens in the current step self.log_output( - f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' + f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' + f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" ) # total tokens for the process so far self.log_output( - 
f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' - f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') + f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " + f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" + ) if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' - f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' + f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " + f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" ) yield messages @@ -964,8 +972,9 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history(self, messages: List[Message], - **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history( + self, messages: List[Message], **kwargs + ) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. 
@@ -1009,9 +1018,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step') - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step' + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1020,9 +1029,9 @@ def _get_run_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( memory_config, 'add_after_task', - default_user_id=getattr(memory_config, 'user_id', None)) - if all(value is None - for value in [user_id, agent_id, run_id, memory_type]): + default_user_id=getattr(memory_config, 'user_id', None), + ) + if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1033,24 +1042,29 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for idx, (mem_instance_type, memory_config) in enumerate( + self.config.memory.items() + ): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config) + memory_config + ) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config) + memory_config + ) if idx < tools_num: - if any(v is not None - for v in [user_id, agent_id, run_id, memory_type]): + if any( + v is not None for v in [user_id, agent_id, run_id, memory_type] + ): 
await self.memory_tools[idx].add( messages, user_id=user_id, agent_id=agent_id, run_id=run_id, - memory_type=memory_type) + memory_type=memory_type, + ) def save_history(self, messages: List[Message], **kwargs): """ @@ -1072,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history( - self.output_dir, task=self.tag, config=config, messages=messages) + save_history(self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop(self, messages: Union[List[Message], str], - **kwargs) -> AsyncGenerator[Any, Any]: + async def run_loop( + self, messages: Union[List[Message], str], **kwargs + ) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1089,8 +1103,9 @@ async def run_loop(self, messages: Union[List[Message], str], List[Message]: A list of message objects representing the agent's response or interaction history. 
""" try: - self.max_chat_round = getattr(self.config, 'max_chat_round', - LLMAgent.DEFAULT_MAX_CHAT_ROUND) + self.max_chat_round = getattr( + self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND + ) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1132,8 +1147,7 @@ async def run_loop(self, messages: Union[List[Message], str], yield messages self.runtime.round += 1 # save memory and history - await self.add_memory( - messages, add_type='add_after_step', **kwargs) + await self.add_memory(messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1142,9 +1156,10 @@ async def run_loop(self, messages: Union[List[Message], str], messages.append( Message( role='assistant', - content= - f'Task {messages[1].content} was cutted off, because ' - f'max round({self.max_chat_round}) exceeded.')) + content=f"Task {messages[1].content} was cutted off, because " + f"max round({self.max_chat_round}) exceeded.", + ) + ) self.runtime.should_stop = True yield messages @@ -1155,32 +1170,33 @@ async def run_loop(self, messages: Union[List[Message], str], def _add_memory(): asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + self.add_memory(messages, add_type='add_after_task', **kwargs) + ) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) except Exception as e: import traceback + logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' + f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: 
OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True) + self.config, 'generation_config.stream', True, merge=True + ) async def stream_generator(): - async for _chunk in self.run_loop( - messages=messages, **kwargs): + async for _chunk in self.run_loop(messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index 4e1e322a5..dd80738f9 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from pathlib import Path -from typing import Any, Dict, List, Optional, Union from loguru import logger +from omegaconf import DictConfig +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, Callable from ms_agent.rag.base import RAG -from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -49,7 +49,9 @@ def __init__(self, config: DictConfig): paths = rag_config.get('paths', []) if isinstance(paths, str): paths = [paths] - self.search_paths: List[str] = [str(Path(p).expanduser().resolve()) for p in paths] + self.search_paths: List[str] = [ + str(Path(p).expanduser().resolve()) for p in paths + ] # Work path for sirchmunk cache _work_path = rag_config.get('work_path', './.sirchmunk') @@ -70,28 +72,40 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if self.llm_api_key is None or self.llm_base_url is None or self.llm_model_name is None: + if ( + self.llm_api_key is None + or self.llm_base_url is None + or self.llm_model_name is None + ): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f'{service}_api_key', None) + self.llm_api_key = getattr(llm_config, 
f"{service}_api_key", None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f'{service}_base_url', None) + self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) - self.embedding_model_cache_dir = rag_config.get('embedding_model_cache_dir', None) + self.embedding_model_cache_dir = rag_config.get( + 'embedding_model_cache_dir', None + ) # Runtime state self._searcher = None self._initialized = False + self._cluster_cache_hit = False + self._cluster_cache_hit_time: str | None = None + self._last_search_result: List[Dict[str, Any]] | None = None # Callback for capturing logs self._log_callback = None self._search_logs: List[str] = [] + # Async queue for streaming logs in real-time + self._log_queue: asyncio.Queue | None = None + self._streaming_callback: Callable | None = None def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" @@ -112,8 +126,8 @@ def _initialize_searcher(self): return try: - from sirchmunk.search import AgenticSearch from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch from sirchmunk.utils.embedding_util import EmbeddingUtil # Create LLM client @@ -127,9 +141,17 @@ def _initialize_searcher(self): # Create embedding util # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) - embedding_model_id = self.embedding_model_id if self.embedding_model_id else None - embedding_cache_dir = self.embedding_model_cache_dir if self.embedding_model_cache_dir else None - embedding = EmbeddingUtil(model_id=embedding_model_id, cache_dir=embedding_cache_dir) + embedding_model_id = ( + self.embedding_model_id if self.embedding_model_id else None + ) + embedding_cache_dir = ( + self.embedding_model_cache_dir + if self.embedding_model_cache_dir + else None + 
) + embedding = EmbeddingUtil( + model_id=embedding_model_id, cache_dir=embedding_cache_dir + ) # Create AgenticSearch instance self._searcher = AgenticSearch( @@ -145,20 +167,35 @@ def _initialize_searcher(self): ) self._initialized = True - logger.info(f'SirschmunkSearch initialized with paths: {self.search_paths}') + logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}") except ImportError as e: raise ImportError( - f'Failed to import sirchmunk: {e}. ' + f"Failed to import sirchmunk: {e}. " 'Please install sirchmunk: pip install sirchmunk' ) except Exception as e: - raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}") def _log_callback_wrapper(self): - """Create a callback wrapper to capture search logs.""" - def log_callback(message: str, level: str = 'INFO', logger_name: str = '', is_async: bool = False): - self._search_logs.append(f'[{level}] {message}') + """Create a callback wrapper to capture search logs. + + The sirchmunk LogCallback signature is: + (level: str, message: str, end: str, flush: bool) -> None + See sirchmunk/utils/log_utils.py for reference. 
+ """ + + def log_callback( + level: str, + message: str, + end: str = '\n', + flush: bool = False, + ): + log_entry = f"[{level.upper()}] {message}" + self._search_logs.append(log_entry) + # Stream log in real-time if streaming callback is set + if self._streaming_callback: + asyncio.create_task(self._streaming_callback(log_entry)) return log_callback @@ -186,7 +223,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f'Failed to refresh knowledge base: {e}') + logger.error(f"Failed to refresh knowledge base: {e}") return False return True @@ -208,15 +245,13 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: await self._searcher.scan_directory(str(Path(file_path).parent)) return True except Exception as e: - logger.error(f'Failed to scan files: {e}') + logger.error(f"Failed to scan files: {e}") return False return True - async def retrieve(self, - query: str, - limit: int = 5, - score_threshold: float = 0.7, - **filters) -> List[Dict[str, Any]]: + async def retrieve( + self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters + ) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. 
Args: @@ -246,11 +281,21 @@ async def retrieve(self, return_context=True, ) + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + # If a similar cluster was found and reused, it's a cache hit + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + # Get the cluster cache hit time if available + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f'SirschmunkSearch retrieve failed: {e}') + logger.error(f"SirschmunkSearch retrieve failed: {e}") return [] async def query(self, query: str) -> str: @@ -273,34 +318,45 @@ async def query(self, query: str) -> str: max_loops = self.max_loops max_token_budget = self.max_token_budget - # Perform search and get answer + # Single search with context so we get both the synthesized answer and + # source units in one call, avoiding a redundant second search. 
result = await self._searcher.search( query=query, mode=mode, max_loops=max_loops, max_token_budget=max_token_budget, - return_context=False, + return_context=True, ) - # Result is already a synthesized answer string - if isinstance(result, str): - return result + # Check if cluster cache was hit + self._cluster_cache_hit = False + self._cluster_cache_hit_time = None + if hasattr(result, 'cluster') and result.cluster is not None: + self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + if hasattr(result.cluster, 'updated_at'): + self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + + # Store parsed context for frontend display + self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) - # If we got SearchContext or other format, extract the answer + # Extract the synthesized answer from the context result if hasattr(result, 'answer'): return result.answer + # If result is already a plain string (some modes return str directly) + if isinstance(result, str): + return result + # Fallback: convert to string return str(result) except Exception as e: - logger.error(f'SirschmunkSearch query failed: {e}') - return f'Query failed: {e}' + logger.error(f"SirschmunkSearch query failed: {e}") + return f"Query failed: {e}" - def _parse_search_result(self, - result: Any, - score_threshold: float, - limit: int) -> List[Dict[str, Any]]: + def _parse_search_result( + self, result: Any, score_threshold: float, limit: int + ) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. 
Args: @@ -329,28 +385,38 @@ def _parse_search_result(self, else: text_parts.append(str(snippet)) - results.append({ - 'text': '\n'.join(text_parts) if text_parts else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') if hasattr(unit, 'abstraction_level') else 'text', + results.append( + { + 'text': '\n'.join(text_parts) + if text_parts + else getattr(unit, 'summary', ''), + 'score': score, + 'metadata': { + 'source': source, + 'type': getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') + else 'text', + }, } - }) + ) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append({ - 'text': str(unit.content) if hasattr(unit, 'content') else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), + results.append( + { + 'text': str(unit.content) + if hasattr(unit, 'content') + else str(unit), + 'score': score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 'type': getattr(unit, 'abstraction_level', 'text'), + }, } - }) + ) # Handle list format elif isinstance(result, list): @@ -358,21 +424,27 @@ def _parse_search_result(self, if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 'text': item.get('content', item.get('text', str(item))), - 'score': score, - 'metadata': item.get('metadata', {}), - }) + results.append( + { + 'text': item.get( + 'content', item.get('text', str(item)) + ), + 'score': score, + 'metadata': item.get('metadata', {}), + } + ) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append({ - 
'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - }) + results.append( + { + 'text': result.get('content', result.get('text', str(result))), + 'score': score, + 'metadata': result.get('metadata', {}), + } + ) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) @@ -398,4 +470,16 @@ def get_search_details(self) -> Dict[str, Any]: 'paths': self.search_paths, 'work_path': str(self.work_path), 'reuse_knowledge': self.reuse_knowledge, + 'cluster_cache_hit': self._cluster_cache_hit, + 'cluster_cache_hit_time': self._cluster_cache_hit_time, } + + def enable_streaming_logs(self, callback: Callable): + """Enable streaming mode for search logs. + + Args: + callback: Async callback function to receive log entries in real-time. + Signature: async def callback(log_entry: str) -> None + """ + self._streaming_callback = callback + self._search_logs.clear() From 24eb500b75b79941a77eb6172d3683df848f76e6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 20 Mar 2026 20:03:13 +0800 Subject: [PATCH 6/8] fix lint --- ms_agent/agent/llm_agent.py | 294 +++++++++--------- ms_agent/cli/run.py | 20 +- ms_agent/knowledge_search/sirchmunk_search.py | 181 +++++------ 3 files changed, 253 insertions(+), 242 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 76289a403..5f2ddf2e7 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -2,16 +2,15 @@ import asyncio import importlib import inspect -import json import os.path import sys import threading import uuid from contextlib import contextmanager from copy import deepcopy -from omegaconf import DictConfig, OmegaConf from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.knowledge_search import SirchmunkSearch @@ -25,6 +24,8 
@@ from ms_agent.utils import async_retry, read_history, save_history from ms_agent.utils.constants import DEFAULT_TAG, DEFAULT_USER from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig, OmegaConf + from ..config.config import Config, ConfigLifecycleHandler from .base import Agent @@ -98,8 +99,7 @@ def __init__( ): if not hasattr(config, 'llm'): default_yaml = os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'agent.yaml' - ) + os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') llm_config = Config.from_task(default_yaml) config = OmegaConf.merge(llm_config, config) super().__init__(config, tag, trust_remote_code) @@ -115,8 +115,7 @@ def __init__( self.config.load_cache = self.load_cache self.mcp_server_file = kwargs.get('mcp_server_file', None) self.mcp_config: Dict[str, Any] = self.parse_mcp_servers( - kwargs.get('mcp_config', {}) - ) + kwargs.get('mcp_config', {})) self.mcp_client = kwargs.get('mcp_client', None) self.config_handler = self.register_config_handler() @@ -166,32 +165,36 @@ def _ensure_auto_skills(self) -> bool: from ms_agent.utils.docker_utils import is_docker_daemon_running if not is_docker_daemon_running(): - logger.warning('Docker not running, disabling sandbox for skills') + logger.warning( + 'Docker not running, disabling sandbox for skills') use_sandbox = False # Build retrieve args retrieve_args = {} if hasattr(skills_config, 'retrieve_args'): - retrieve_args = OmegaConf.to_container(skills_config.retrieve_args) + retrieve_args = OmegaConf.to_container( + skills_config.retrieve_args) self._auto_skills = AutoSkills( skills=skills_path, llm=self.llm, - enable_retrieve=getattr(skills_config, 'enable_retrieve', None), + enable_retrieve=getattr(skills_config, 'enable_retrieve', + None), retrieve_args=retrieve_args, - max_candidate_skills=getattr(skills_config, 'max_candidate_skills', 10), + max_candidate_skills=getattr(skills_config, + 'max_candidate_skills', 10), max_retries=getattr(skills_config, 
'max_retries', 3), work_dir=getattr(skills_config, 'work_dir', None), use_sandbox=use_sandbox, ) logger.info( - f"AutoSkills initialized with {len(self._auto_skills.all_skills)} skills" + f'AutoSkills initialized with {len(self._auto_skills.all_skills)} skills' ) self._auto_skills_initialized = True return True except Exception as e: - logger.warning(f"Failed to initialize AutoSkills: {e}") + logger.warning(f'Failed to initialize AutoSkills: {e}') self._auto_skills_initialized = True return False @@ -233,7 +236,7 @@ async def should_use_skills(self, query: str) -> bool: needs_skills, _, _, _ = self._auto_skills._analyze_query(query) return needs_skills except Exception as e: - logger.error(f"Skill analysis error: {e}") + logger.error(f'Skill analysis error: {e}') return False async def get_skill_dag(self, query: str): @@ -266,8 +269,8 @@ async def execute_skills(self, query: str, execution_input=None): skills_config = self._get_skills_config() stop_on_failure = ( - getattr(skills_config, 'stop_on_failure', True) if skills_config else True - ) + getattr(skills_config, 'stop_on_failure', True) + if skills_config else True) result = await self._auto_skills.run( query=query, @@ -291,14 +294,15 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: # Handle chat-only response if dag_result.chat_response: - messages.append(Message(role='assistant', content=dag_result.chat_response)) + messages.append( + Message(role='assistant', content=dag_result.chat_response)) return messages # Handle incomplete skills if not dag_result.is_complete: content = "I couldn't find suitable skills for this task." 
if dag_result.clarification: - content += f"\n\n{dag_result.clarification}" + content += f'\n\n{dag_result.clarification}' messages.append(Message(role='assistant', content=content)) return messages @@ -318,30 +322,30 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: stdout_preview = output.stdout[:1000] if len(output.stdout) > 1000: stdout_preview += '...' - content += f"**{skill_id} output:**\n{stdout_preview}\n\n" + content += f'**{skill_id} output:**\n{stdout_preview}\n\n' if output.output_files: - content += f"**Generated files:** {list(output.output_files.values())}\n\n" + content += f'**Generated files:** {list(output.output_files.values())}\n\n' content += ( - f"Total execution time: {exec_result.total_duration_ms:.2f}ms" + f'Total execution time: {exec_result.total_duration_ms:.2f}ms' ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): if not result.success: - content += f"**{skill_id} failed:** {result.error}\n" + content += f'**{skill_id} failed:** {result.error}\n' messages.append(Message(role='assistant', content=content)) else: # DAG only, no execution skill_names = list(dag_result.selected_skills.keys()) - content = f"Found {len(skill_names)} relevant skill(s) for your task:\n" + content = f'Found {len(skill_names)} relevant skill(s) for your task:\n' for skill_id, skill in dag_result.selected_skills.items(): desc_preview = skill.description[:100] if len(skill.description) > 100: desc_preview += '...' 
- content += f"- **{skill.name}** ({skill_id}): {desc_preview}\n" - content += f"\nExecution order: {dag_result.execution_order}" + content += f'- **{skill.name}** ({skill_id}): {desc_preview}\n' + content += f'\nExecution order: {dag_result.execution_order}' messages.append(Message(role='assistant', content=content)) @@ -367,7 +371,8 @@ def parse_mcp_servers(self, mcp_config: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: Merged configuration including file-based overrides. """ mcp_config = mcp_config or {} - if self.mcp_server_file is not None and os.path.isfile(self.mcp_server_file): + if self.mcp_server_file is not None and os.path.isfile( + self.mcp_server_file): with open(self.mcp_server_file, 'r') as f: config = json.load(f) config.update(mcp_config) @@ -396,11 +401,10 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: if handler_file is not None: local_dir = self.config.local_dir assert self.config.trust_remote_code, ( - f"[External Code]A Config Lifecycle handler " - f"registered in the config: {handler_file}. " - f"\nThis is external code, if you trust this workflow, " - f"please specify `--trust_remote_code true`" - ) + f'[External Code]A Config Lifecycle handler ' + f'registered in the config: {handler_file}. ' + f'\nThis is external code, if you trust this workflow, ' + f'please specify `--trust_remote_code true`') assert ( local_dir is not None ), 'Using external py files, but local_dir cannot be found.' 
@@ -410,18 +414,17 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: handler_module = importlib.import_module(handler_file) module_classes = { name: cls - for name, cls in inspect.getmembers(handler_module, inspect.isclass) + for name, cls in inspect.getmembers(handler_module, + inspect.isclass) } handler = None for name, handler_cls in module_classes.items(): - if ( - handler_cls.__bases__[0] is ConfigLifecycleHandler - and handler_cls.__module__ == handler_file - ): + if (handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file): handler = handler_cls() assert ( handler is not None - ), f"Config Lifecycle handler class cannot be found in {handler_file}" + ), f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -432,7 +435,8 @@ def register_callback_from_config(self): Raises: AssertionError: If untrusted external code is referenced without permission. """ - local_dir = self.config.local_dir if hasattr(self.config, 'local_dir') else None + local_dir = self.config.local_dir if hasattr(self.config, + 'local_dir') else None if hasattr(self.config, 'callbacks'): callbacks = self.config.callbacks or [] for _callback in callbacks: @@ -460,22 +464,23 @@ def register_callback_from_config(self): module_classes = { name: cls for name, cls in inspect.getmembers( - callback_file, inspect.isclass - ) + callback_file, inspect.isclass) } for name, cls in module_classes.items(): # Find cls which base class is `Callback` - if issubclass(cls, Callback) and cls.__module__ == _callback: + if issubclass( + cls, Callback) and cls.__module__ == _callback: self.callbacks.append(cls(self.config)) # noqa else: - self.callbacks.append(callbacks_mapping[_callback](self.config)) + self.callbacks.append(callbacks_mapping[_callback]( + self.config)) async def on_task_begin(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task beginning.") + self.log_output(f'Agent 
{self.tag} task beginning.') await self.loop_callback('on_task_begin', messages) async def on_task_end(self, messages: List[Message]): - self.log_output(f"Agent {self.tag} task finished.") + self.log_output(f'Agent {self.tag} task finished.') await self.loop_callback('on_task_end', messages) async def on_generate_response(self, messages: List[Message]): @@ -500,7 +505,8 @@ async def loop_callback(self, point, messages: List[Message]): for callback in self.callbacks: await getattr(callback, point)(self.runtime, messages) - async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: + async def parallel_tool_call(self, + messages: List[Message]) -> List[Message]: """ Execute multiple tool calls in parallel and append results to the message list. @@ -511,12 +517,10 @@ async def parallel_tool_call(self, messages: List[Message]) -> List[Message]: List[Message]: Updated message list including tool responses. """ tool_call_result = await self.tool_manager.parallel_call_tool( - messages[-1].tool_calls - ) + messages[-1].tool_calls) assert len(tool_call_result) == len(messages[-1].tool_calls) - for tool_call_result, tool_call_query in zip( - tool_call_result, messages[-1].tool_calls - ): + for tool_call_result, tool_call_query in zip(tool_call_result, + messages[-1].tool_calls): tool_call_result_format = ToolResult.from_raw(tool_call_result) _new_message = Message( role='tool', @@ -550,7 +554,8 @@ async def cleanup_tools(self): @property def stream(self): - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return getattr(generation_config, 'stream', False) @property @@ -561,7 +566,8 @@ def show_reasoning(self) -> bool: - This only affects local console output. - Reasoning is carried by `Message.reasoning_content` (if the backend provides it). 
""" - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return bool(getattr(generation_config, 'show_reasoning', False)) @property @@ -572,7 +578,8 @@ def reasoning_output(self) -> str: - "stderr" (default): keep stdout clean for assistant final text - "stdout": interleave reasoning with assistant output on stdout """ - generation_config = getattr(self.config, 'generation_config', DictConfig({})) + generation_config = getattr(self.config, 'generation_config', + DictConfig({})) return str(getattr(generation_config, 'reasoning_output', 'stdout')) def _write_reasoning(self, text: str): @@ -588,18 +595,19 @@ def _write_reasoning(self, text: str): @property def system(self): - return getattr(getattr(self.config, 'prompt', DictConfig({})), 'system', None) + return getattr( + getattr(self.config, 'prompt', DictConfig({})), 'system', None) @property def query(self): - query = getattr(getattr(self.config, 'prompt', DictConfig({})), 'query', None) + query = getattr( + getattr(self.config, 'prompt', DictConfig({})), 'query', None) if not query: query = input('>>>') return query async def create_messages( - self, messages: Union[List[Message], str] - ) -> List[Message]: + self, messages: Union[List[Message], str]) -> List[Message]: """ Convert input into a standardized list of messages. 
@@ -611,19 +619,18 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if ( - system is not None - and messages[0].role == 'system' - and system != messages[0].content - ): + if (system is not None and messages[0].role == 'system' + and system != messages[0].content): # Replace the existing system messages[0].content = system else: assert isinstance( messages, str - ), f"inputs can be either a list or a string, but current is {type(messages)}" + ), f'inputs can be either a list or a string, but current is {type(messages)}' messages = [ - Message(role='system', content=self.system or LLMAgent.DEFAULT_SYSTEM), + Message( + role='system', + content=self.system or LLMAgent.DEFAULT_SYSTEM), Message(role='user', content=messages or self.query), ] return messages @@ -662,11 +669,11 @@ async def do_rag(self, messages: List[Message]): # Append search context to user query context = search_result user_message.content = ( - f"Relevant context retrieved from codebase search:\n\n{context}\n\n" - f"User question: {query}" - ) + f'Relevant context retrieved from codebase search:\n\n{context}\n\n' + f'User question: {query}') - async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: + async def do_skill(self, + messages: List[Message]) -> Optional[List[Message]]: """ Process skill-related query if applicable. 
@@ -683,9 +690,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: # Extract user query from normalized messages query = ( messages[1].content - if len(messages) > 1 and messages[1].role == 'user' - else None - ) + if len(messages) > 1 and messages[1].role == 'user' else None) if not query: return None @@ -700,8 +705,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: try: skills_config = self._get_skills_config() auto_execute = ( - getattr(skills_config, 'auto_execute', True) if skills_config else True - ) + getattr(skills_config, 'auto_execute', True) + if skills_config else True) if auto_execute: dag_result = await self.execute_skills(query) @@ -709,7 +714,8 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: dag_result = await self.get_skill_dag(query) if dag_result: - skill_messages = self._format_skill_result_as_messages(dag_result) + skill_messages = self._format_skill_result_as_messages( + dag_result) for msg in skill_messages: messages.append(msg) return messages @@ -720,8 +726,7 @@ async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: except Exception as e: logger.warning( - f"Skill execution failed: {e}, falling back to standard agent" - ) + f'Skill execution failed: {e}, falling back to standard agent') self._skill_mode_active = False return None @@ -735,13 +740,11 @@ async def load_memory(self): if hasattr(self.config, 'memory'): for mem_instance_type, _memory in self.config.memory.items(): assert mem_instance_type in memory_mapping, ( - f"{mem_instance_type} not in memory_mapping, " - f"which supports: {list(memory_mapping.keys())}" - ) + f'{mem_instance_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') shared_memory = await SharedMemoryManager.get_shared_memory( - self.config, mem_instance_type - ) + self.config, mem_instance_type) self.memory_tools.append(shared_memory) async def prepare_rag(self): @@ 
-750,9 +753,8 @@ async def prepare_rag(self): rag = self.config.rag if rag is not None: assert rag.name in rag_mapping, ( - f"{rag.name} not in rag_mapping, " - f"which supports: {list(rag_mapping.keys())}" - ) + f'{rag.name} not in rag_mapping, ' + f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) async def prepare_knowledge_search(self): @@ -764,7 +766,8 @@ async def prepare_knowledge_search(self): if hasattr(self.config, 'knowledge_search'): ks_config = self.config.knowledge_search if ks_config is not None: - self.knowledge_search: SirchmunkSearch = SirchmunkSearch(self.config) + self.knowledge_search: SirchmunkSearch = SirchmunkSearch( + self.config) async def condense_memory(self, messages: List[Message]) -> List[Message]: """ @@ -797,7 +800,7 @@ def log_output(self, content: Union[str, list]): text_parts.append(item.get('text', '')) elif item.get('type') == 'image_url': img_url = item.get('image_url', {}).get('url', '') - text_parts.append(f"[Image: {img_url[:50]}...]") + text_parts.append(f'[Image: {img_url[:50]}...]') content = ' '.join(text_parts) # Ensure content is a string @@ -808,9 +811,10 @@ def log_output(self, content: Union[str, list]): content = content[:512] + '\n...\n' + content[-512:] for line in content.split('\n'): for _line in line.split('\\n'): - logger.info(f"[{self.tag}] {_line}") + logger.info(f'[{self.tag}] {_line}') - def handle_new_response(self, messages: List[Message], response_message: Message): + def handle_new_response(self, messages: List[Message], + response_message: Message): assert response_message is not None, 'No response message generated from LLM.' 
if response_message.tool_calls: self.log_output('[tool_calling]:') @@ -818,23 +822,24 @@ def handle_new_response(self, messages: List[Message], response_message: Message tool_call = deepcopy(tool_call) if isinstance(tool_call['arguments'], str): try: - tool_call['arguments'] = json.loads(tool_call['arguments']) + tool_call['arguments'] = json.loads( + tool_call['arguments']) except json.decoder.JSONDecodeError: pass - self.log_output(json.dumps(tool_call, ensure_ascii=False, indent=4)) + self.log_output( + json.dumps(tool_call, ensure_ascii=False, indent=4)) if messages[-1] is not response_message: messages.append(response_message) - if ( - messages[-1].role == 'assistant' - and not messages[-1].content - and response_message.tool_calls - ): + if (messages[-1].role == 'assistant' and not messages[-1].content + and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) - async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], Any]: # type: ignore + async def step( + self, messages: List[Message] + ) -> AsyncGenerator[List[Message], Any]: # type: ignore """ Execute a single step in the agent's interaction loop. @@ -870,7 +875,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A is_first = True _response_message = None _printed_reasoning_header = False - for _response_message in self.llm.generate(messages, tools=tools): + for _response_message in self.llm.generate( + messages, tools=tools): if is_first: messages.append(_response_message) is_first = False @@ -878,12 +884,12 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # Optional: stream model "thinking/reasoning" if available. 
if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') # Some providers may reset / shorten content across chunks. if len(reasoning_text) < len(_reasoning): _reasoning = '' - new_reasoning = reasoning_text[len(_reasoning) :] + new_reasoning = reasoning_text[len(_reasoning):] if new_reasoning: if not _printed_reasoning_header: self._write_reasoning('[thinking]:\n') @@ -891,7 +897,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A self._write_reasoning(new_reasoning) _reasoning = reasoning_text - new_content = _response_message.content[len(_content) :] + new_content = _response_message.content[len(_content):] sys.stdout.write(new_content) sys.stdout.flush() _content = _response_message.content @@ -904,8 +910,8 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: reasoning_text = ( - getattr(_response_message, 'reasoning_content', '') or '' - ) + getattr(_response_message, 'reasoning_content', '') + or '') if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -934,8 +940,7 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 cache_creation_input_tokens = ( - getattr(_response_message, 'cache_creation_input_tokens', 0) or 0 - ) + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -945,21 +950,20 @@ async def step(self, messages: List[Message]) -> AsyncGenerator[List[Message], A # tokens in the current step self.log_output( - f"[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + f'[usage] 
prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' ) if cached_tokens or cache_creation_input_tokens: self.log_output( - f"[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}" + f'[usage_cache] cache_hit: {cached_tokens}, cache_created: {cache_creation_input_tokens}' ) # total tokens for the process so far self.log_output( - f"[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, " - f"total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}" - ) + f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, ' + f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}') if LLMAgent.TOTAL_CACHED_TOKENS or LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS: self.log_output( - f"[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, " - f"total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}" + f'[usage_cache_total] total_cache_hit: {LLMAgent.TOTAL_CACHED_TOKENS}, ' + f'total_cache_created: {LLMAgent.TOTAL_CACHE_CREATION_INPUT_TOKENS}' ) yield messages @@ -972,9 +976,8 @@ def prepare_runtime(self): """Initialize the runtime context.""" self.runtime: Runtime = Runtime(llm=self.llm) - def read_history( - self, messages: List[Message], **kwargs - ) -> Tuple[DictConfig, Runtime, List[Message]]: + def read_history(self, messages: List[Message], + **kwargs) -> Tuple[DictConfig, Runtime, List[Message]]: """ Load previous chat history from disk if available. 
@@ -1018,9 +1021,9 @@ def get_user_id(self, default_user_id=DEFAULT_USER) -> Optional[str]: def _get_step_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( - memory_config, 'add_after_step' - ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + memory_config, 'add_after_step') + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) return user_id, agent_id, run_id, memory_type @@ -1031,7 +1034,8 @@ def _get_run_memory_info(self, memory_config: DictConfig): 'add_after_task', default_user_id=getattr(memory_config, 'user_id', None), ) - if all(value is None for value in [user_id, agent_id, run_id, memory_type]): + if all(value is None + for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None user_id = user_id or getattr(memory_config, 'user_id', None) agent_id = agent_id or self.tag @@ -1042,22 +1046,18 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: tools_num = len(self.memory_tools) if self.memory_tools else 0 - for idx, (mem_instance_type, memory_config) in enumerate( - self.config.memory.items() - ): + for idx, (mem_instance_type, + memory_config) in enumerate(self.config.memory.items()): if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( - memory_config - ) + memory_config) else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( - memory_config - ) + memory_config) if idx < tools_num: - if any( - v is not None for v in [user_id, agent_id, run_id, memory_type] - ): + if any(v is not None + for v in [user_id, agent_id, run_id, memory_type]): await self.memory_tools[idx].add( messages, user_id=user_id, @@ -1086,11 +1086,11 @@ def save_history(self, messages: List[Message], **kwargs): 
config: DictConfig = deepcopy(self.config) config.runtime = self.runtime.to_dict() - save_history(self.output_dir, task=self.tag, config=config, messages=messages) + save_history( + self.output_dir, task=self.tag, config=config, messages=messages) - async def run_loop( - self, messages: Union[List[Message], str], **kwargs - ) -> AsyncGenerator[Any, Any]: + async def run_loop(self, messages: Union[List[Message], str], + **kwargs) -> AsyncGenerator[Any, Any]: """ Run the agent, mainly contains a llm calling and tool calling loop. @@ -1103,9 +1103,8 @@ async def run_loop( List[Message]: A list of message objects representing the agent's response or interaction history. """ try: - self.max_chat_round = getattr( - self.config, 'max_chat_round', LLMAgent.DEFAULT_MAX_CHAT_ROUND - ) + self.max_chat_round = getattr(self.config, 'max_chat_round', + LLMAgent.DEFAULT_MAX_CHAT_ROUND) self.register_callback_from_config() self.prepare_llm() self.prepare_runtime() @@ -1147,7 +1146,8 @@ async def run_loop( yield messages self.runtime.round += 1 # save memory and history - await self.add_memory(messages, add_type='add_after_step', **kwargs) + await self.add_memory( + messages, add_type='add_after_step', **kwargs) self.save_history(messages) # +1 means the next round the assistant may give a conclusion @@ -1156,10 +1156,10 @@ async def run_loop( messages.append( Message( role='assistant', - content=f"Task {messages[1].content} was cutted off, because " - f"max round({self.max_chat_round}) exceeded.", - ) - ) + content= + f'Task {messages[1].content} was cutted off, because ' + f'max round({self.max_chat_round}) exceeded.', + )) self.runtime.should_stop = True yield messages @@ -1170,8 +1170,8 @@ async def run_loop( def _add_memory(): asyncio.run( - self.add_memory(messages, add_type='add_after_task', **kwargs) - ) + self.add_memory( + messages, add_type='add_after_task', **kwargs)) loop = asyncio.get_running_loop() loop.run_in_executor(None, _add_memory) @@ -1181,22 +1181,22 @@ def 
_add_memory(): logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( - f"[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}" + f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}' ) raise e async def run( - self, messages: Union[List[Message], str], **kwargs + self, messages: Union[List[Message], str], **kwargs ) -> Union[List[Message], AsyncGenerator[List[Message], Any]]: stream = kwargs.get('stream', False) with self.config_context(): if stream: OmegaConf.update( - self.config, 'generation_config.stream', True, merge=True - ) + self.config, 'generation_config.stream', True, merge=True) async def stream_generator(): - async for _chunk in self.run_loop(messages=messages, **kwargs): + async for _chunk in self.run_loop( + messages=messages, **kwargs): yield _chunk return stream_generator() diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index cfe387e5a..a07397c95 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -4,11 +4,10 @@ import os from importlib import resources as importlib_resources -from omegaconf import OmegaConf - from ms_agent.config import Config from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII +from omegaconf import OmegaConf from .base import CLICommand @@ -185,8 +184,10 @@ def _execute_with_config(self): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) else: # Use built-in default agent.yaml from package - default_config_path = importlib_resources.files('ms_agent').joinpath('agent', AGENT_CONFIG_FILE) - with importlib_resources.as_file(default_config_path) as config_file: + default_config_path = importlib_resources.files( + 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file( + default_config_path) as config_file: self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import 
snapshot_download @@ -226,7 +227,10 @@ def _execute_with_config(self): # If knowledge_search_paths is provided, configure SirchmunkSearch if getattr(self.args, 'knowledge_search_paths', None): - paths = [p.strip() for p in self.args.knowledge_search_paths.split(',') if p.strip()] + paths = [ + p.strip() for p in self.args.knowledge_search_paths.split(',') + if p.strip() + ] if paths: if 'knowledge_search' not in config or not config.knowledge_search: # No existing knowledge_search config, create minimal config @@ -237,11 +241,13 @@ def _execute_with_config(self): 'work_path': './.sirchmunk', 'mode': 'FAST', } - config['knowledge_search'] = OmegaConf.create(knowledge_search_config) + config['knowledge_search'] = OmegaConf.create( + knowledge_search_config) else: # Existing knowledge_search config found, only update paths # LLM settings are already handled by SirchmunkSearch internally - existing = OmegaConf.to_container(config.knowledge_search, resolve=True) + existing = OmegaConf.to_container( + config.knowledge_search, resolve=True) existing['paths'] = paths config['knowledge_search'] = OmegaConf.create(existing) diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py index dd80738f9..e1c76181f 100644 --- a/ms_agent/knowledge_search/sirchmunk_search.py +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -7,12 +7,12 @@ """ import asyncio -from loguru import logger -from omegaconf import DictConfig from pathlib import Path -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union +from loguru import logger from ms_agent.rag.base import RAG +from omegaconf import DictConfig class SirchmunkSearch(RAG): @@ -59,7 +59,8 @@ def __init__(self, config: DictConfig): # Sirchmunk search parameters self.reuse_knowledge = rag_config.get('reuse_knowledge', True) - self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', 0.85) + 
self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', + 0.85) self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) self.search_mode = rag_config.get('mode', 'FAST') self.max_loops = rag_config.get('max_loops', 10) @@ -72,26 +73,24 @@ def __init__(self, config: DictConfig): self.llm_model_name = rag_config.get('llm_model_name', None) # Fall back to main llm config if not specified in knowledge_search - if ( - self.llm_api_key is None - or self.llm_base_url is None - or self.llm_model_name is None - ): + if (self.llm_api_key is None or self.llm_base_url is None + or self.llm_model_name is None): llm_config = config.get('llm', {}) if llm_config: service = getattr(llm_config, 'service', 'dashscope') if self.llm_api_key is None: - self.llm_api_key = getattr(llm_config, f"{service}_api_key", None) + self.llm_api_key = getattr(llm_config, + f'{service}_api_key', None) if self.llm_base_url is None: - self.llm_base_url = getattr(llm_config, f"{service}_base_url", None) + self.llm_base_url = getattr(llm_config, + f'{service}_base_url', None) if self.llm_model_name is None: self.llm_model_name = getattr(llm_config, 'model', None) # Embedding model configuration self.embedding_model_id = rag_config.get('embedding_model', None) self.embedding_model_cache_dir = rag_config.get( - 'embedding_model_cache_dir', None - ) + 'embedding_model_cache_dir', None) # Runtime state self._searcher = None @@ -109,7 +108,8 @@ def __init__(self, config: DictConfig): def _validate_config(self, config: DictConfig): """Validate configuration parameters.""" - if not hasattr(config, 'knowledge_search') or config.knowledge_search is None: + if not hasattr(config, + 'knowledge_search') or config.knowledge_search is None: raise ValueError( 'Missing knowledge_search configuration. ' 'Please add knowledge_search section to your config with at least "paths" specified.' 
@@ -118,7 +118,8 @@ def _validate_config(self, config: DictConfig):
         rag_config = config.knowledge_search
         paths = rag_config.get('paths', [])
         if not paths:
-            raise ValueError('knowledge_search.paths must be specified and non-empty')
+            raise ValueError(
+                'knowledge_search.paths must be specified and non-empty')
 
     def _initialize_searcher(self):
         """Initialize the sirchmunk AgenticSearch instance."""
@@ -142,16 +143,12 @@ def _initialize_searcher(self):
             # Create embedding util
             # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID)
             embedding_model_id = (
-                self.embedding_model_id if self.embedding_model_id else None
-            )
+                self.embedding_model_id if self.embedding_model_id else None)
             embedding_cache_dir = (
                 self.embedding_model_cache_dir
-                if self.embedding_model_cache_dir
-                else None
-            )
+                if self.embedding_model_cache_dir else None)
             embedding = EmbeddingUtil(
-                model_id=embedding_model_id, cache_dir=embedding_cache_dir
-            )
+                model_id=embedding_model_id, cache_dir=embedding_cache_dir)
 
             # Create AgenticSearch instance
             self._searcher = AgenticSearch(
@@ -167,15 +164,16 @@ def _initialize_searcher(self):
             )
 
             self._initialized = True
-            logger.info(f"SirschmunkSearch initialized with paths: {self.search_paths}")
+            logger.info(
+                f'SirchmunkSearch initialized with paths: {self.search_paths}'
+            )
 
         except ImportError as e:
             raise ImportError(
-                f"Failed to import sirchmunk: {e}. "
-                'Please install sirchmunk: pip install sirchmunk'
-            )
+                f'Failed to import sirchmunk: {e}. '
+                'Please install sirchmunk: pip install sirchmunk')
         except Exception as e:
-            raise RuntimeError(f"Failed to initialize SirchmunkSearch: {e}")
+            raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}')
 
     def _log_callback_wrapper(self):
         """Create a callback wrapper to capture search logs.
@@ -191,7 +189,7 @@ def log_callback( end: str = '\n', flush: bool = False, ): - log_entry = f"[{level.upper()}] {message}" + log_entry = f'[{level.upper()}] {message}' self._search_logs.append(log_entry) # Stream log in real-time if streaming callback is set if self._streaming_callback: @@ -223,7 +221,7 @@ async def add_documents(self, documents: List[str]) -> bool: await self._searcher.knowledge_base.refresh() return True except Exception as e: - logger.error(f"Failed to refresh knowledge base: {e}") + logger.error(f'Failed to refresh knowledge base: {e}') return False return True @@ -242,16 +240,19 @@ async def add_documents_from_files(self, file_paths: List[str]) -> bool: try: for file_path in file_paths: if Path(file_path).exists(): - await self._searcher.scan_directory(str(Path(file_path).parent)) + await self._searcher.scan_directory( + str(Path(file_path).parent)) return True except Exception as e: - logger.error(f"Failed to scan files: {e}") + logger.error(f'Failed to scan files: {e}') return False return True - async def retrieve( - self, query: str, limit: int = 5, score_threshold: float = 0.7, **filters - ) -> List[Dict[str, Any]]: + async def retrieve(self, + query: str, + limit: int = 5, + score_threshold: float = 0.7, + **filters) -> List[Dict[str, Any]]: """Retrieve relevant documents using sirchmunk. 
Args: @@ -270,7 +271,8 @@ async def retrieve( try: mode = filters.get('mode', self.search_mode) max_loops = filters.get('max_loops', self.max_loops) - max_token_budget = filters.get('max_token_budget', self.max_token_budget) + max_token_budget = filters.get('max_token_budget', + self.max_token_budget) # Perform search result = await self._searcher.search( @@ -286,16 +288,18 @@ async def retrieve( self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: # If a similar cluster was found and reused, it's a cache hit - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) # Get the cluster cache hit time if available if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Parse results into standard format return self._parse_search_result(result, score_threshold, limit) except Exception as e: - logger.error(f"SirschmunkSearch retrieve failed: {e}") + logger.error(f'SirschmunkSearch retrieve failed: {e}') return [] async def query(self, query: str) -> str: @@ -332,12 +336,15 @@ async def query(self, query: str) -> str: self._cluster_cache_hit = False self._cluster_cache_hit_time = None if hasattr(result, 'cluster') and result.cluster is not None: - self._cluster_cache_hit = getattr(result.cluster, '_reused_from_cache', False) + self._cluster_cache_hit = getattr(result.cluster, + '_reused_from_cache', False) if hasattr(result.cluster, 'updated_at'): - self._cluster_cache_hit_time = getattr(result.cluster, 'updated_at', None) + self._cluster_cache_hit_time = getattr( + result.cluster, 'updated_at', None) # Store parsed context for frontend display - self._last_search_result = self._parse_search_result(result, score_threshold=0.7, limit=5) + self._last_search_result = 
self._parse_search_result( + result, score_threshold=0.7, limit=5) # Extract the synthesized answer from the context result if hasattr(result, 'answer'): @@ -351,12 +358,11 @@ async def query(self, query: str) -> str: return str(result) except Exception as e: - logger.error(f"SirschmunkSearch query failed: {e}") - return f"Query failed: {e}" + logger.error(f'SirschmunkSearch query failed: {e}') + return f'Query failed: {e}' - def _parse_search_result( - self, result: Any, score_threshold: float, limit: int - ) -> List[Dict[str, Any]]: + def _parse_search_result(self, result: Any, score_threshold: float, + limit: int) -> List[Dict[str, Any]]: """Parse sirchmunk search result into standard format. Args: @@ -385,38 +391,37 @@ def _parse_search_result( else: text_parts.append(str(snippet)) - results.append( - { - 'text': '\n'.join(text_parts) - if text_parts - else getattr(unit, 'summary', ''), - 'score': score, - 'metadata': { - 'source': source, - 'type': getattr(unit, 'abstraction_level', 'text') - if hasattr(unit, 'abstraction_level') - else 'text', - }, - } - ) + results.append({ + 'text': + '\n'.join(text_parts) if text_parts else getattr( + unit, 'summary', ''), + 'score': + score, + 'metadata': { + 'source': + source, + 'type': + getattr(unit, 'abstraction_level', 'text') + if hasattr(unit, 'abstraction_level') else 'text', + }, + }) # Handle format with evidence_units attribute directly elif hasattr(result, 'evidence_units'): for unit in result.evidence_units: score = getattr(unit, 'confidence', 1.0) if score >= score_threshold: - results.append( - { - 'text': str(unit.content) - if hasattr(unit, 'content') - else str(unit), - 'score': score, - 'metadata': { - 'source': getattr(unit, 'source_file', 'unknown'), - 'type': getattr(unit, 'abstraction_level', 'text'), - }, - } - ) + results.append({ + 'text': + str(unit.content) + if hasattr(unit, 'content') else str(unit), + 'score': + score, + 'metadata': { + 'source': getattr(unit, 'source_file', 'unknown'), + 
'type': getattr(unit, 'abstraction_level', 'text'), + }, + }) # Handle list format elif isinstance(result, list): @@ -424,27 +429,27 @@ def _parse_search_result( if isinstance(item, dict): score = item.get('score', item.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': item.get( - 'content', item.get('text', str(item)) - ), - 'score': score, - 'metadata': item.get('metadata', {}), - } - ) + results.append({ + 'text': + item.get('content', item.get('text', str(item))), + 'score': + score, + 'metadata': + item.get('metadata', {}), + }) # Handle dict format elif isinstance(result, dict): score = result.get('score', result.get('confidence', 1.0)) if score >= score_threshold: - results.append( - { - 'text': result.get('content', result.get('text', str(result))), - 'score': score, - 'metadata': result.get('metadata', {}), - } - ) + results.append({ + 'text': + result.get('content', result.get('text', str(result))), + 'score': + score, + 'metadata': + result.get('metadata', {}), + }) # Sort by score and limit results results.sort(key=lambda x: x.get('score', 0), reverse=True) From 08914187cecd38d500fe9d8539aea4cd4e3723e4 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 23 Mar 2026 19:04:55 +0800 Subject: [PATCH 7/8] fix comments --- ms_agent/cli/run.py | 29 ++- ms_agent/config/env.py | 32 +++- ms_agent/knowledge_search/README.md | 277 ---------------------------- 3 files changed, 38 insertions(+), 300 deletions(-) delete mode 100644 ms_agent/knowledge_search/README.md diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index a07397c95..67ec8d563 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -5,6 +5,7 @@ from importlib import resources as importlib_resources from ms_agent.config import Config +from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII from omegaconf import OmegaConf @@ -47,28 +48,21 @@ class 
RunCMD(CLICommand): def __init__(self, args): self.args = args - def load_env_file(self): - """Load environment variables from .env file in current directory.""" - env_file = os.path.join(os.getcwd(), '.env') - if os.path.exists(env_file): - with open(env_file, 'r') as f: - for line in f: - line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, value = line.split('=', 1) - key = key.strip() - value = value.strip() - # Only set if not already set in environment - if key not in os.environ: - os.environ[key] = value - logger.debug(f'Loaded {key} from .env file') - @staticmethod def define_args(parsers: argparse.ArgumentParser): """Define args for run command.""" projects = list_builtin_projects() parser: argparse.ArgumentParser = parsers.add_parser(RunCMD.name) + parser.add_argument( + '--env', + required=False, + type=str, + default=None, + metavar='PATH', + help= + 'Path to a .env file. If omitted, loads ./.env from the current ' + 'working directory when present; missing file is ignored.') parser.add_argument( '--query', required=False, @@ -175,8 +169,7 @@ def execute(self): return self._execute_with_config() def _execute_with_config(self): - # Load environment variables from .env file if exists - self.load_env_file() + Env.load_dotenv_into_environ(getattr(self.args, 'env', None)) if not self.args.config: current_dir = os.getcwd() diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 78cb428f7..0d4cd3fb0 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
-import os.path +import os from copy import copy -from typing import Dict +from typing import Dict, Optional from dotenv import load_dotenv @@ -9,9 +9,31 @@ class Env: @staticmethod - def load_env(envs: Dict[str, str] = None) -> Dict[str, str]: - """Load environment variables from .env file and merges with the input envs""" - load_dotenv() + def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: + """Load key=value pairs from a .env file into ``os.environ``. + + Does not override variables already set in the process environment. + + If ``dotenv_path`` is given, loads that file; it must exist. + If ``dotenv_path`` is None, loads ``/.env`` when that file exists; + a missing default file is a no-op. + """ + if dotenv_path is not None: + path = os.path.abspath(os.path.expanduser(dotenv_path)) + if not os.path.isfile(path): + raise FileNotFoundError(f'Env file not found: {path}') + load_dotenv(path, override=False) + else: + default = os.path.join(os.getcwd(), '.env') + if os.path.isfile(default): + load_dotenv(default, override=False) + + @staticmethod + def load_env( + envs: Dict[str, str] = None, + dotenv_path: Optional[str] = None) -> Dict[str, str]: + """Load .env into the process env, then merge with ``envs`` and return.""" + Env.load_dotenv_into_environ(dotenv_path) _envs = copy(os.environ) _envs.update(envs or {}) return _envs diff --git a/ms_agent/knowledge_search/README.md b/ms_agent/knowledge_search/README.md deleted file mode 100644 index ef86df0da..000000000 --- a/ms_agent/knowledge_search/README.md +++ /dev/null @@ -1,277 +0,0 @@ -# Sirchmunk Knowledge Search 集成 - -本模块实现了 [sirchmunk](https://github.com/modelscope/sirchmunk) 与 ms_agent 框架的集成,提供了基于代码库的智能搜索功能。 - -## 功能特性 - -- **智能代码搜索**: 使用 LLM 和 embedding 模型对代码库进行语义搜索 -- **多模式搜索**: 支持 FAST、DEEP、FILENAME_ONLY 三种搜索模式 -- **知识复用**: 自动缓存和复用之前的搜索结果,减少 LLM 调用 -- **前端友好**: 提供详细的搜索日志和结果,方便前端展示 -- **无缝集成**: 与 LLMAgent 无缝集成,像使用 RAG 一样简单 - -## 安装 - -```bash -pip install sirchmunk -``` - -## 配置 - -在您的 
`agent.yaml` 或 `workflow.yaml` 中添加以下配置: - -```yaml -llm: - service: dashscope - model: qwen3.5-plus - dashscope_api_key: - dashscope_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 - -generation_config: - temperature: 0.3 - stream: true - -# Knowledge Search 配置 -knowledge_search: - # 必选:要搜索的路径列表 - paths: - - ./src - - ./docs - - # 可选:sirchmunk 工作目录 - work_path: ./.sirchmunk - - # 可选:LLM 配置(如不配置则自动复用上面 llm 模块的配置) - # llm_api_key: - # llm_base_url: https://api.openai.com/v1 - # llm_model_name: gpt-4o-mini - - # 可选:Embedding 模型 - embedding_model: text-embedding-3-small - - # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) - mode: FAST - - # 可选:是否重用之前的知识 - reuse_knowledge: true -``` - -**LLM 配置自动复用机制**: - -`SirchmunkSearch` 会自动从主配置的 `llm` 模块复用 LLM 相关参数: -- 如果 `knowledge_search.llm_api_key` 未配置,自动使用 `llm.{service}_api_key` -- 如果 `knowledge_search.llm_base_url` 未配置,自动使用 `llm.{service}_base_url` -- 如果 `knowledge_search.llm_model_name` 未配置,自动使用 `llm.model` - -其中 `service` 是 `llm.service` 的值(如 `dashscope`, `modelscope`, `openai` 等)。 - -通过 CLI 使用时,只需传入 `--knowledge_search_paths` 参数,无需额外配置 LLM 参数。 - -## 使用方式 - -### 1. 通过 CLI 使用(推荐) - -从命令行直接运行,无需编写代码: - -```bash -# 基本用法 - LLM 配置自动从 agent.yaml 的 llm 模块复用 -ms-agent run --query "如何实现用户认证功能?" --knowledge_search_paths "./src,./docs" - -# 指定配置文件 -ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" -``` - -**说明**: -- `--knowledge_search_paths` 参数支持逗号分隔的多个路径 -- LLM 相关配置(api_key, base_url, model)会自动从配置文件的 `llm` 模块复用 -- 如果 `knowledge_search` 模块单独配置了 `llm_api_key` 等参数,则优先使用模块自己的配置 - -### 2. 
通过 LLMAgent 使用 - -```python -from ms_agent import LLMAgent -from ms_agent.config import Config - -# 从配置文件加载 -config = Config.from_task('path/to/agent.yaml') -agent = LLMAgent(config=config) - -# 运行查询 - 会自动触发知识搜索 -result = await agent.run('如何实现用户认证功能?') - -# 获取搜索结果 -for msg in result: - if msg.role == 'user': - # 搜索详情(用于前端展示) - print(f"Search logs: {msg.searching_detail}") - # 搜索结果(作为 LLM 上下文) - print(f"Search results: {msg.search_result}") -``` - -### 3. 单独使用 SirchmunkSearch - -```python -from ms_agent.knowledge_search import SirchmunkSearch -from omegaconf import DictConfig - -config = DictConfig({ - 'knowledge_search': { - 'paths': ['./src', './docs'], - 'work_path': './.sirchmunk', - 'llm_api_key': 'your-api-key', - 'llm_model_name': 'gpt-4o-mini', - 'mode': 'FAST', - } -}) - -searcher = SirchmunkSearch(config) - -# 查询(返回合成答案) -answer = await searcher.query('如何实现用户认证?') - -# 检索(返回原始搜索结果) -results = await searcher.retrieve( - query='用户认证', - limit=5, - score_threshold=0.7 -) - -# 获取搜索日志 -logs = searcher.get_search_logs() - -# 获取搜索详情 -details = searcher.get_search_details() -``` - -## 环境变量 - -可以通过环境变量配置: - -```bash -# LLM 配置(如不设置则自动从 agent.yaml 的 llm 模块读取) -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -# Embedding 模型配置 -export EMBEDDING_MODEL_ID="text-embedding-3-small" -export SIRCHMUNK_WORK_PATH="./.sirchmunk" -``` - -**注意**:通过 CLI 使用时,推荐直接在 `.env` 文件或 agent.yaml 中配置 LLM 参数,`SirchmunkSearch` 会自动复用。 - -## 测试 - -### 单元测试 - -```bash -export LLM_API_KEY="your-api-key" -export LLM_BASE_URL="https://api.openai.com/v1" -export LLM_MODEL_NAME="gpt-4o-mini" - -python -m unittest tests/knowledge_search/test_sirschmunk.py -``` - -### CLI 测试 - -```bash -# 基本测试 -python tests/knowledge_search/test_cli.py - -# 指定查询 -python tests/knowledge_search/test_cli.py -q "如何实现用户认证?" 
- -# 仅测试 standalone 模式 -python tests/knowledge_search/test_cli.py -m standalone - -# 仅测试 agent 模式 -python tests/knowledge_search/test_cli.py -m agent -``` - -## 配置参数说明 - -| 参数 | 类型 | 默认值 | 说明 | -|------|------|--------|------| -| paths | List[str] | 必选 | 要搜索的目录/文件路径列表 | -| work_path | str | ./.sirchmunk | sirchmunk 工作目录,用于缓存 | -| llm_api_key | str | 从 llm 配置继承 | LLM API 密钥 | -| llm_base_url | str | 从 llm 配置继承 | LLM API 基础 URL | -| llm_model_name | str | 从 llm 配置继承 | LLM 模型名称 | -| embedding_model | str | text-embedding-3-small | Embedding 模型 ID | -| cluster_sim_threshold | float | 0.85 | 聚类相似度阈值 | -| cluster_sim_top_k | int | 3 | 聚类 TopK 数量 | -| reuse_knowledge | bool | true | 是否重用之前的知识 | -| mode | str | FAST | 搜索模式 (DEEP/FAST/FILENAME_ONLY) | -| max_loops | int | 10 | 最大搜索循环次数 | -| max_token_budget | int | 128000 | 最大 token 预算 | - -## 搜索模式 - -- **FAST**: 快速模式,使用贪婪策略,1-5 秒内返回结果,0-2 次 LLM 调用 -- **DEEP**: 深度模式,并行多路径检索 + ReAct 优化,5-30 秒,4-6 次 LLM 调用 -- **FILENAME_ONLY**: 仅文件名模式,基于模式匹配,无 LLM 调用,非常快 - -## Message 字段扩展 - -为了支持知识搜索,`Message` 类增加了两个字段: - -- **searching_detail** (Dict[str, Any]): 搜索过程日志和元数据,用于前端展示 - - `logs`: 搜索日志列表 - - `mode`: 使用的搜索模式 - - `paths`: 搜索的路径 - - `work_path`: 工作目录 - - `reuse_knowledge`: 是否重用知识 - -- **search_result** (List[Dict[str, Any]]): 搜索结果,作为下一轮 LLM 的上下文 - - `text`: 文档内容 - - `score`: 相关性分数 - - `metadata`: 元数据(如源文件、类型等) - -## 工作原理 - -1. 用户发送查询 -2. LLMAgent 调用 `prepare_knowledge_search()` 初始化 SirchmunkSearch -3. `do_rag()` 方法执行知识搜索: - - 调用 `searcher.retrieve()` 获取相关文档 - - 将搜索结果存入 `message.search_result` - - 将搜索日志存入 `message.searching_detail` - - 将搜索结果格式化为上下文,附加到用户查询 -4. LLM 接收 enriched query 并生成回答 -5. 前端可以通过 `searching_detail` 展示搜索过程 - -## 故障排除 - -### 常见问题 - -1. **ImportError: No module named 'sirchmunk'** - ```bash - pip install sirchmunk - ``` - -2. **搜索结果为空** - - 检查 `paths` 配置是否正确 - - 确保路径下有可搜索的文件 - - 尝试降低 `cluster_sim_threshold` 值 - -3. 
**LLM API 调用失败** - - 检查 API key 是否正确 - - 检查 base URL 是否正确 - - 查看搜索日志了解详细错误 - -### 日志查看 - -```python -# 查看搜索日志 -logs = searcher.get_search_logs() -for log in logs: - print(log) - -# 或在配置中启用 verbose -knowledge_search: - verbose: true -``` - -## 参考资源 - -- [sirchmunk GitHub](https://github.com/modelscope/sirchmunk) -- [ModelScope Agent](https://github.com/modelscope/modelscope-agent) From 7823513b8f578c83623d9dc3589e4f99406a3137 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 23 Mar 2026 19:06:30 +0800 Subject: [PATCH 8/8] fix lint --- ms_agent/config/env.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ms_agent/config/env.py b/ms_agent/config/env.py index 0d4cd3fb0..83553254c 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -29,9 +29,8 @@ def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: load_dotenv(default, override=False) @staticmethod - def load_env( - envs: Dict[str, str] = None, - dotenv_path: Optional[str] = None) -> Dict[str, str]: + def load_env(envs: Dict[str, str] = None, + dotenv_path: Optional[str] = None) -> Dict[str, str]: """Load .env into the process env, then merge with ``envs`` and return.""" Env.load_dotenv_into_environ(dotenv_path) _envs = copy(os.environ)