diff --git a/docs/en/Components/Config.md b/docs/en/Components/Config.md index d40f03898..f1253bd75 100644 --- a/docs/en/Components/Config.md +++ b/docs/en/Components/Config.md @@ -166,3 +166,20 @@ In addition to yaml configuration, MS-Agent also supports several additional com ``` > Any configuration in agent.yaml can be passed in with new values via command line, and also supports reading from environment variables with the same name (case insensitive), for example `--llm.modelscope_api_key xxx-xxx`. + +- knowledge_search_paths: Knowledge search paths; separate multiple paths with commas. When provided, SirchmunkSearch is automatically enabled for knowledge retrieval, with the LLM configuration inherited from the `llm` module. + +### Quick Start for Knowledge Search + +Use the `--knowledge_search_paths` parameter to quickly enable knowledge search over local documents: + +```bash +# Use the default agent.yaml configuration, automatically reusing the LLM settings +ms-agent run --query "How to implement user authentication?" --knowledge_search_paths "./src,./docs" + +# Specify a configuration file +ms-agent run --config /path/to/agent.yaml --query "your question" --knowledge_search_paths "/path/to/docs" +``` + +LLM-related parameters (api_key, base_url, model) are automatically inherited from the `llm` module in the configuration file, so there is no need to configure them again. +If the `knowledge_search` module should use an independent LLM configuration, explicitly set `knowledge_search.llm_api_key` and related parameters in the yaml. 
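The inheritance-and-override behavior described in this docs change can be sketched as a yaml fragment. This is a minimal sketch based on the `examples/knowledge_search/agent.yaml.example` file added in this change; the key names (`paths`, `work_path`, `llm_api_key`, `llm_base_url`, `llm_model_name`) come from that example, while the concrete values here are placeholders:

```yaml
llm:
  service: modelscope
  model: Qwen/Qwen3-235B-A22B-Instruct-2507
  modelscope_api_key:   # picked up by knowledge_search when no override below is set

knowledge_search:
  # Required: paths to search
  paths:
    - ./src
    - ./docs
  work_path: ./.sirchmunk
  # Optional overrides: when set, these take precedence over the `llm` module above
  llm_api_key:
  llm_base_url: https://api.openai.com/v1
  llm_model_name: gpt-4o-mini
```

Omitting the three `llm_*` keys entirely makes the component fall back to the `llm` module, which is the behavior the `--knowledge_search_paths` shortcut relies on.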
diff --git a/docs/zh/Components/config.md b/docs/zh/Components/config.md index 041820b93..12849f2a7 100644 --- a/docs/zh/Components/config.md +++ b/docs/zh/Components/config.md @@ -1,12 +1,12 @@ --- slug: config title: 配置与参数 -description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM配置、推理配置、system和query、callbacks、工具配置、其他、config_handler、命令行配置 +description: Ms-Agent 配置与参数:类型配置、自定义代码、LLM 配置、推理配置、system 和 query、callbacks、工具配置、其他、config_handler、命令行配置 --- # 配置与参数 -MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名为`agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: +MS-Agent 使用一个 yaml 文件进行配置管理,通常这个文件被命名为 `agent.yaml`,这样的设计使不同场景可以读取不同的配置文件。该文件具体包含的字段有: ## 类型配置 @@ -17,31 +17,31 @@ MS-Agent使用一个yaml文件进行配置管理,通常这个文件被命名 type: llmagent ``` -标识本配置对应的agent类型,支持`llmagent`和`codeagent`两类。默认为`llmagent`。如果yaml中包含了code_file字段,则code_file优先生效。 +标识本配置对应的 agent 类型,支持 `llmagent` 和 `codeagent` 两类。默认为 `llmagent`。如果 yaml 中包含了 code_file 字段,则 code_file 优先生效。 ## 自定义代码 -> 可选,在需要自定义LLMAgent时使用 +> 可选,在需要自定义 LLMAgent 时使用 ```yaml code_file: custom_agent ``` -可以使用一个外部agent类,该类需要继承自`LLMAgent`。可以复写其中的若干方法,如果code_file有值,则`type`字段不生效。 +可以使用一个外部 agent 类,该类需要继承自 `LLMAgent`。可以复写其中的若干方法,如果 code_file 有值,则 `type` 字段不生效。 -## LLM配置 +## LLM 配置 > 必须存在 ```yaml llm: - # 大模型服务backend + # 大模型服务 backend service: modelscope - # 模型id + # 模型 id model: Qwen/Qwen3-235B-A22B-Instruct-2507 - # 模型api_key + # 模型 api_key modelscope_api_key: - # 模型base_url + # 模型 base_url modelscope_base_url: https://api-inference.modelscope.cn/v1 ``` @@ -51,7 +51,7 @@ llm: ```yaml generation_config: - # 下面的字段均为OpenAI sdk的标准参数,你也可以配置OpenAI支持的其他参数在这里。 + # 下面的字段均为 OpenAI sdk 的标准参数,你也可以配置 OpenAI 支持的其他参数在这里。 top_p: 0.6 temperature: 0.2 top_k: 20 @@ -60,25 +60,25 @@ generation_config: enable_thinking: false ``` -## system和query +## system 和 query -> 可选,但推荐传入system +> 可选,但推荐传入 system ```yaml prompt: - # LLM system,如果不传递则使用默认的`you are a helpful assistant.` + # LLM system,如果不传递则使用默认的 `you are a helpful assistant.` system: - # LLM初始query,通常来说可以不使用 + # LLM 初始 query,通常来说可以不使用 query: ``` ## 
callbacks -> 可选,推荐自定义callbacks +> 可选,推荐自定义 callbacks ```yaml callbacks: - # 用户输入callback,该callback在assistant回复后自动等待用户输入 + # 用户输入 callback,该 callback 在 assistant 回复后自动等待用户输入 - input_callback ``` @@ -90,9 +90,9 @@ callbacks: tools: # 工具名称 file_system: - # 是否是mcp + # 是否是 mcp mcp: false - # 排除的function,可以为空 + # 排除的 function,可以为空 exclude: - create_directory - write_file @@ -104,20 +104,20 @@ tools: - map_geo ``` -支持的完整工具列表,以及自定义工具请参考[这里](./tools) +支持的完整工具列表,以及自定义工具请参考 [这里](./tools) ## 其他 > 可选,按需配置 ```yaml -# 自动对话轮数,默认为20轮 +# 自动对话轮数,默认为 20 轮 max_chat_round: 9999 # 工具调用超时时间,单位秒 tool_call_timeout: 30000 -# 输出artifact目录 +# 输出 artifact 目录 output_dir: output # 帮助信息,通常在运行错误后出现 @@ -127,13 +127,13 @@ help: | ## config_handler -为了便于在任务开始时对config进行定制化,MS-Agent构建了一个名为`ConfigLifecycleHandler`的机制。这是一个callback类,开发者可以在yaml文件中增加这样一个配置: +为了便于在任务开始时对 config 进行定制化,MS-Agent 构建了一个名为 `ConfigLifecycleHandler` 的机制。这是一个 callback 类,开发者可以在 yaml 文件中增加这样一个配置: ```yaml handler: custom_handler ``` -这代表和yaml文件同级有一个custom_handler.py文件,该文件的类继承自`ConfigLifecycleHandler`,分别有两个方法: +这代表和 yaml 文件同级有一个 custom_handler.py 文件,该文件的类继承自 `ConfigLifecycleHandler`,分别有两个方法: ```python def task_begin(self, config: DictConfig, tag: str) -> DictConfig: @@ -143,18 +143,18 @@ handler: custom_handler return config ``` -`task_begin`在LLMAgent类构造时生效,在该方法中可以对config进行一些修改。如果你的工作流中下游任务会继承上游的yaml配置,这个机制会有帮助。值得注意的是`tag`参数,该参数会传入当前LLMAgent的名字,方便分辨当前工作流的节点。 +`task_begin` 在 LLMAgent 类构造时生效,在该方法中可以对 config 进行一些修改。如果你的工作流中下游任务会继承上游的 yaml 配置,这个机制会有帮助。值得注意的是 `tag` 参数,该参数会传入当前 LLMAgent 的名字,方便分辨当前工作流的节点。 ## 命令行配置 -在yaml配置之外,MS-Agent还支持若干额外的命令行参数。 +在 yaml 配置之外,MS-Agent 还支持若干额外的命令行参数。 -- query: 初始query,这个query的优先级高于yaml中的prompt.query -- config: 配置文件路径,支持modelscope model-id -- trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为true才会生效 -- load_cache: 从历史messages继续对话。cache会被自动存储在`output`配置中。默认为`False` -- mcp_server_file: 可以读取一个外部的mcp工具配置,格式为: +- query: 初始 query,这个 query 的优先级高于 yaml 中的 prompt.query +- config: 配置文件路径,支持 modelscope model-id +- 
trust_remote_code: 是否信任外部代码。如果某个配置包含了一些外部代码,需要将这个参数置为 true 才会生效 +- load_cache: 从历史 messages 继续对话。cache 会被自动存储在 `output` 配置中。默认为 `False` +- mcp_server_file: 可以读取一个外部的 mcp 工具配置,格式为: ```json { "mcpServers": { @@ -165,5 +165,21 @@ handler: custom_handler } } ``` +- knowledge_search_paths: 知识搜索路径,逗号分隔的多个路径。传入后会自动启用 SirchmunkSearch 进行知识检索,LLM 配置自动从 `llm` 模块复用 -> agent.yaml中的任意一个配置,都可以使用命令行传入新的值, 也支持从同名(大小写不敏感)环境变量中读取,例如`--llm.modelscope_api_key xxx-xxx`。 +> agent.yaml 中的任意一个配置,都可以使用命令行传入新的值,也支持从同名(大小写不敏感)环境变量中读取,例如 `--llm.modelscope_api_key xxx-xxx`。 + +### 知识搜索快速使用 + +通过 `--knowledge_search_paths` 参数,可以快速启用基于本地文档的知识搜索: + +```bash +# 使用默认 agent.yaml 配置,自动复用 LLM 设置 +ms-agent run --query "如何实现用户认证?" --knowledge_search_paths "./src,./docs" + +# 指定配置文件 +ms-agent run --config /path/to/agent.yaml --query "你的问题" --knowledge_search_paths "/path/to/docs" +``` + +LLM 相关参数(api_key, base_url, model)会自动从配置文件的 `llm` 模块继承,无需重复配置。 +如果需要在 `knowledge_search` 模块中使用独立的 LLM 配置,可以在 yaml 中显式配置 `knowledge_search.llm_api_key` 等参数。 diff --git a/examples/knowledge_search/agent.yaml.example b/examples/knowledge_search/agent.yaml.example new file mode 100644 index 000000000..cc11a8a3d --- /dev/null +++ b/examples/knowledge_search/agent.yaml.example @@ -0,0 +1,86 @@ +# Sirchmunk Knowledge Search 配置示例 +# Sirchmunk Knowledge Search Configuration Example + +# 在您的 agent.yaml 或 workflow.yaml 中添加以下配置: + +llm: + service: modelscope + model: Qwen/Qwen3-235B-A22B-Instruct-2507 + modelscope_api_key: + modelscope_base_url: https://api-inference.modelscope.cn/v1 + +generation_config: + temperature: 0.3 + top_k: 20 + stream: true + +# Knowledge Search 配置(可选) +# 用于在本地代码库中搜索相关信息 +knowledge_search: + # 必选:要搜索的路径列表 + paths: + - ./src + - ./docs + + # 可选:sirchmunk 工作目录,用于缓存 + work_path: ./.sirchmunk + + # 可选:LLM 配置(如不配置则使用上面 llm 的配置) + llm_api_key: + llm_base_url: https://api.openai.com/v1 + llm_model_name: gpt-4o-mini + + # 可选:Embedding 模型 + embedding_model: text-embedding-3-small + + # 可选:聚类相似度阈值 + 
cluster_sim_threshold: 0.85 + + # 可选:聚类 TopK + cluster_sim_top_k: 3 + + # 可选:是否重用之前的知识 + reuse_knowledge: true + + # 可选:搜索模式 (DEEP, FAST, FILENAME_ONLY) + mode: FAST + + # 可选:最大循环次数 + max_loops: 10 + + # 可选:最大 token 预算 + max_token_budget: 128000 + +prompt: + system: | + You are an assistant that helps me complete tasks. + +max_chat_round: 9999 + +# 使用说明: +# 1. 配置 knowledge_search 后,LLMAgent 会在处理用户请求时自动搜索本地代码库 +# 2. 搜索结果会自动添加到 user message 的 search_result 和 searching_detail 字段 +# 3. search_result 包含搜索到的相关文档,会作为上下文提供给 LLM +# 4. searching_detail 包含搜索日志和元数据,可用于前端展示 +# +# Python 使用示例: +# ```python +# from ms_agent import LLMAgent +# from ms_agent.config import Config +# +# config = Config.from_task('path/to/agent.yaml') +# agent = LLMAgent(config=config) +# result = await agent.run('如何实现用户认证功能?') +# +# # 获取搜索详情(用于前端展示) +# for msg in result: +# if msg.role == 'user': +# print(f"Search logs: {msg.searching_detail}") +# print(f"Search results: {msg.search_result}") +# ``` +# +# CLI 测试命令: +# export LLM_API_KEY="your-api-key" +# export LLM_BASE_URL="https://api.openai.com/v1" +# export LLM_MODEL_NAME="gpt-4o-mini" +# python tests/knowledge_search/test_cli.py --query "你的问题" diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 740eab690..5f2ddf2e7 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -13,6 +13,7 @@ import json from ms_agent.agent.runtime import Runtime from ms_agent.callbacks import Callback, callbacks_mapping +from ms_agent.knowledge_search import SirchmunkSearch from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, ToolResult from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping @@ -89,11 +90,13 @@ class LLMAgent(Agent): TOTAL_CACHE_CREATION_INPUT_TOKENS = 0 TOKEN_LOCK = asyncio.Lock() - def __init__(self, - config: DictConfig = DictConfig({}), - tag: str = DEFAULT_TAG, - trust_remote_code: bool = False, - **kwargs): + def __init__( + self, + config: DictConfig = 
DictConfig({}), + tag: str = DEFAULT_TAG, + trust_remote_code: bool = False, + **kwargs, + ): if not hasattr(config, 'llm'): default_yaml = os.path.join( os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') @@ -104,6 +107,7 @@ def __init__(self, self.tool_manager: Optional[ToolManager] = None self.memory_tools: List[Memory] = [] self.rag: Optional[RAG] = None + self.knowledge_search: Optional[SirchmunkSearch] = None self.llm: Optional[LLM] = None self.runtime: Optional[Runtime] = None self.max_chat_round: int = 0 @@ -159,6 +163,7 @@ def _ensure_auto_skills(self) -> bool: use_sandbox = getattr(skills_config, 'use_sandbox', True) if use_sandbox: from ms_agent.utils.docker_utils import is_docker_daemon_running + if not is_docker_daemon_running(): logger.warning( 'Docker not running, disabling sandbox for skills') @@ -263,13 +268,15 @@ async def execute_skills(self, query: str, execution_input=None): return None skills_config = self._get_skills_config() - stop_on_failure = getattr(skills_config, 'stop_on_failure', - True) if skills_config else True + stop_on_failure = ( + getattr(skills_config, 'stop_on_failure', True) + if skills_config else True) result = await self._auto_skills.run( query=query, execution_input=execution_input, - stop_on_failure=stop_on_failure) + stop_on_failure=stop_on_failure, + ) self._last_skill_result = result return result @@ -319,7 +326,9 @@ def _format_skill_result_as_messages(self, dag_result) -> List[Message]: if output.output_files: content += f'**Generated files:** {list(output.output_files.values())}\n\n' - content += f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + content += ( + f'Total execution time: {exec_result.total_duration_ms:.2f}ms' + ) else: content = 'Skill execution completed with errors.\n\n' for skill_id, result in exec_result.results.items(): @@ -396,7 +405,9 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: f'registered in the config: {handler_file}. 
' f'\nThis is external code, if you trust this workflow, ' f'please specify `--trust_remote_code true`') - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' if local_dir not in sys.path: sys.path.insert(0, local_dir) @@ -408,10 +419,12 @@ def register_config_handler(self) -> Optional[ConfigLifecycleHandler]: } handler = None for name, handler_cls in module_classes.items(): - if handler_cls.__bases__[ - 0] is ConfigLifecycleHandler and handler_cls.__module__ == handler_file: + if (handler_cls.__bases__[0] is ConfigLifecycleHandler + and handler_cls.__module__ == handler_file): handler = handler_cls() - assert handler is not None, f'Config Lifecycle handler class cannot be found in {handler_file}' + assert ( + handler is not None + ), f'Config Lifecycle handler class cannot be found in {handler_file}' return handler return None @@ -428,7 +441,9 @@ def register_callback_from_config(self): callbacks = self.config.callbacks or [] for _callback in callbacks: subdir = os.path.dirname(_callback) - assert local_dir is not None, 'Using external py files, but local_dir cannot be found.' + assert ( + local_dir is not None + ), 'Using external py files, but local_dir cannot be found.' 
if subdir: subdir = os.path.join(local_dir, str(subdir)) _callback = os.path.basename(_callback) @@ -512,7 +527,8 @@ async def parallel_tool_call(self, content=tool_call_result_format.text, tool_call_id=tool_call_query['id'], name=tool_call_query['tool_name'], - resources=tool_call_result_format.resources) + resources=tool_call_result_format.resources, + ) if _new_message.tool_call_id is None: # If tool call id is None, add a random one @@ -528,7 +544,8 @@ async def prepare_tools(self): self.config, self.mcp_config, self.mcp_client, - trust_remote_code=self.trust_remote_code) + trust_remote_code=self.trust_remote_code, + ) await self.tool_manager.connect() async def cleanup_tools(self): @@ -602,8 +619,8 @@ async def create_messages( """ if isinstance(messages, list): system = self.system - if system is not None and messages[ - 0].role == 'system' and system != messages[0].content: + if (system is not None and messages[0].role == 'system' + and system != messages[0].content): # Replace the existing system messages[0].content = system else: @@ -619,8 +636,41 @@ async def create_messages( return messages async def do_rag(self, messages: List[Message]): + """Process RAG or knowledge search to enrich the user query with context. + + This method handles both traditional RAG and sirchmunk-based knowledge search. + For knowledge search, it also populates searching_detail and search_result + fields in the message for frontend display and next-turn LLM context. + + Args: + messages (List[Message]): The message list to process. 
+ """ + user_message = messages[1] if len(messages) > 1 else None + if user_message is None or user_message.role != 'user': + return + + query = user_message.content + + # Handle traditional RAG if self.rag is not None: - messages[1].content = await self.rag.query(messages[1].content) + user_message.content = await self.rag.query(query) + # Handle sirchmunk knowledge search + if self.knowledge_search is not None: + # Perform search and get results + search_result = await self.knowledge_search.query(query) + search_details = self.knowledge_search.get_search_details() + + # Store search details in the message for frontend display + user_message.searching_detail = search_details + user_message.search_result = search_result + + # Build enriched context from search results + if search_result: + # Append search context to user query + context = search_result + user_message.content = ( + f'Relevant context retrieved from codebase search:\n\n{context}\n\n' + f'User question: {query}') async def do_skill(self, messages: List[Message]) -> Optional[List[Message]]: @@ -654,8 +704,9 @@ async def do_skill(self, try: skills_config = self._get_skills_config() - auto_execute = getattr(skills_config, 'auto_execute', - True) if skills_config else True + auto_execute = ( + getattr(skills_config, 'auto_execute', True) + if skills_config else True) if auto_execute: dag_result = await self.execute_skills(query) @@ -706,6 +757,18 @@ async def prepare_rag(self): f'which supports: {list(rag_mapping.keys())}') self.rag: RAG = rag_mapping(rag.name)(self.config) + async def prepare_knowledge_search(self): + """Load and initialize the knowledge search component from the config.""" + if self.knowledge_search is not None: + # Already initialized (e.g. by caller before run_loop), skip to avoid + # overwriting a configured instance (e.g. one with streaming callbacks set). 
+ return + if hasattr(self.config, 'knowledge_search'): + ks_config = self.config.knowledge_search + if ks_config is not None: + self.knowledge_search: SirchmunkSearch = SirchmunkSearch( + self.config) + async def condense_memory(self, messages: List[Message]) -> List[Message]: """ Update memory using the current conversation history. @@ -769,8 +832,8 @@ def handle_new_response(self, messages: List[Message], if messages[-1] is not response_message: messages.append(response_message) - if messages[-1].role == 'assistant' and not messages[ - -1].content and response_message.tool_calls: + if (messages[-1].role == 'assistant' and not messages[-1].content + and response_message.tool_calls): messages[-1].content = 'Let me do a tool calling.' @async_retry(max_attempts=Agent.retry_count, delay=1.0) @@ -820,8 +883,9 @@ async def step( # Optional: stream model "thinking/reasoning" if available. if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') + or '') # Some providers may reset / shorten content across chunks. 
if len(reasoning_text) < len(_reasoning): _reasoning = '' @@ -845,8 +909,9 @@ async def step( else: _response_message = self.llm.generate(messages, tools=tools) if self.show_reasoning: - reasoning_text = getattr(_response_message, - 'reasoning_content', '') or '' + reasoning_text = ( + getattr(_response_message, 'reasoning_content', '') + or '') if reasoning_text: self._write_reasoning('[thinking]:\n') self._write_reasoning(reasoning_text) @@ -874,8 +939,8 @@ async def step( prompt_tokens = _response_message.prompt_tokens completion_tokens = _response_message.completion_tokens cached_tokens = getattr(_response_message, 'cached_tokens', 0) or 0 - cache_creation_input_tokens = getattr( - _response_message, 'cache_creation_input_tokens', 0) or 0 + cache_creation_input_tokens = ( + getattr(_response_message, 'cache_creation_input_tokens', 0) or 0) async with LLMAgent.TOKEN_LOCK: LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens @@ -967,7 +1032,8 @@ def _get_run_memory_info(self, memory_config: DictConfig): user_id, agent_id, run_id, memory_type = get_memory_meta_safe( memory_config, 'add_after_task', - default_user_id=getattr(memory_config, 'user_id', None)) + default_user_id=getattr(memory_config, 'user_id', None), + ) if all(value is None for value in [user_id, agent_id, run_id, memory_type]): return None, None, None, None @@ -997,7 +1063,8 @@ async def add_memory(self, messages: List[Message], add_type, **kwargs): user_id=user_id, agent_id=agent_id, run_id=run_id, - memory_type=memory_type) + memory_type=memory_type, + ) def save_history(self, messages: List[Message], **kwargs): """ @@ -1044,6 +1111,7 @@ async def run_loop(self, messages: Union[List[Message], str], await self.prepare_tools() await self.load_memory() await self.prepare_rag() + await self.prepare_knowledge_search() self.runtime.tag = self.tag if messages is None: @@ -1090,7 +1158,8 @@ async def run_loop(self, messages: Union[List[Message], str], role='assistant', content= f'Task {messages[1].content} was 
cutted off, because ' - f'max round({self.max_chat_round}) exceeded.')) + f'max round({self.max_chat_round}) exceeded.', + )) self.runtime.should_stop = True yield messages @@ -1108,6 +1177,7 @@ def _add_memory(): loop.run_in_executor(None, _add_memory) except Exception as e: import traceback + logger.warning(traceback.format_exc()) if hasattr(self.config, 'help'): logger.error( diff --git a/ms_agent/cli/run.py b/ms_agent/cli/run.py index c2df42eee..67ec8d563 100644 --- a/ms_agent/cli/run.py +++ b/ms_agent/cli/run.py @@ -5,8 +5,10 @@ from importlib import resources as importlib_resources from ms_agent.config import Config +from ms_agent.config.env import Env from ms_agent.utils import get_logger, strtobool from ms_agent.utils.constants import AGENT_CONFIG_FILE, MS_AGENT_ASCII +from omegaconf import OmegaConf from .base import CLICommand @@ -52,6 +54,15 @@ def define_args(parsers: argparse.ArgumentParser): projects = list_builtin_projects() parser: argparse.ArgumentParser = parsers.add_parser(RunCMD.name) + parser.add_argument( + '--env', + required=False, + type=str, + default=None, + metavar='PATH', + help= + 'Path to a .env file. If omitted, loads ./.env from the current ' + 'working directory when present; missing file is ignored.') parser.add_argument( '--query', required=False, @@ -120,6 +131,14 @@ def define_args(parsers: argparse.ArgumentParser): help= 'Animation mode for video_generate project: auto (default) or human.' ) + parser.add_argument( + '--knowledge_search_paths', + required=False, + type=str, + default=None, + help= + 'Comma-separated list of paths for knowledge search. When provided, enables SirchmunkSearch using LLM config from llm module.' 
+ ) parser.set_defaults(func=subparser_func) def execute(self): @@ -150,10 +169,19 @@ def execute(self): return self._execute_with_config() def _execute_with_config(self): + Env.load_dotenv_into_environ(getattr(self.args, 'env', None)) + if not self.args.config: current_dir = os.getcwd() if os.path.exists(os.path.join(current_dir, AGENT_CONFIG_FILE)): self.args.config = os.path.join(current_dir, AGENT_CONFIG_FILE) + else: + # Use built-in default agent.yaml from package + default_config_path = importlib_resources.files( + 'ms_agent').joinpath('agent', AGENT_CONFIG_FILE) + with importlib_resources.as_file( + default_config_path) as config_file: + self.args.config = str(config_file) elif not os.path.exists(self.args.config): from modelscope import snapshot_download self.args.config = snapshot_download(self.args.config) @@ -190,6 +218,32 @@ def _execute_with_config(self): config = Config.from_task(self.args.config) + # If knowledge_search_paths is provided, configure SirchmunkSearch + if getattr(self.args, 'knowledge_search_paths', None): + paths = [ + p.strip() for p in self.args.knowledge_search_paths.split(',') + if p.strip() + ] + if paths: + if 'knowledge_search' not in config or not config.knowledge_search: + # No existing knowledge_search config, create minimal config + # LLM settings will be auto-reused from llm module by SirchmunkSearch + knowledge_search_config = { + 'name': 'SirchmunkSearch', + 'paths': paths, + 'work_path': './.sirchmunk', + 'mode': 'FAST', + } + config['knowledge_search'] = OmegaConf.create( + knowledge_search_config) + else: + # Existing knowledge_search config found, only update paths + # LLM settings are already handled by SirchmunkSearch internally + existing = OmegaConf.to_container( + config.knowledge_search, resolve=True) + existing['paths'] = paths + config['knowledge_search'] = OmegaConf.create(existing) + if Config.is_workflow(config): from ms_agent.workflow.loader import WorkflowLoader engine = WorkflowLoader.build( diff --git 
a/ms_agent/config/env.py b/ms_agent/config/env.py index 78cb428f7..83553254c 100644 --- a/ms_agent/config/env.py +++ b/ms_agent/config/env.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import os.path +import os from copy import copy -from typing import Dict +from typing import Dict, Optional from dotenv import load_dotenv @@ -9,9 +9,30 @@ class Env: @staticmethod - def load_env(envs: Dict[str, str] = None) -> Dict[str, str]: - """Load environment variables from .env file and merges with the input envs""" - load_dotenv() + def load_dotenv_into_environ(dotenv_path: Optional[str] = None) -> None: + """Load key=value pairs from a .env file into ``os.environ``. + + Does not override variables already set in the process environment. + + If ``dotenv_path`` is given, loads that file; it must exist. + If ``dotenv_path`` is None, loads ``./.env`` from the current working directory when that file exists; + a missing default file is a no-op. + """ + if dotenv_path is not None: + path = os.path.abspath(os.path.expanduser(dotenv_path)) + if not os.path.isfile(path): + raise FileNotFoundError(f'Env file not found: {path}') + load_dotenv(path, override=False) + else: + default = os.path.join(os.getcwd(), '.env') + if os.path.isfile(default): + load_dotenv(default, override=False) + + @staticmethod + def load_env(envs: Dict[str, str] = None, + dotenv_path: Optional[str] = None) -> Dict[str, str]: + """Load .env into the process env, then merge with ``envs`` and return.""" + Env.load_dotenv_into_environ(dotenv_path) _envs = copy(os.environ) _envs.update(envs or {}) return _envs diff --git a/ms_agent/knowledge_search/__init__.py b/ms_agent/knowledge_search/__init__.py new file mode 100644 index 000000000..33362beee --- /dev/null +++ b/ms_agent/knowledge_search/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Knowledge search module based on sirchmunk. 
+ +This module provides integration between sirchmunk's AgenticSearch +and the ms_agent framework, enabling intelligent codebase search +capabilities similar to RAG. +""" + +from .sirchmunk_search import SirchmunkSearch + +__all__ = ['SirchmunkSearch'] diff --git a/ms_agent/knowledge_search/sirchmunk_search.py b/ms_agent/knowledge_search/sirchmunk_search.py new file mode 100644 index 000000000..e1c76181f --- /dev/null +++ b/ms_agent/knowledge_search/sirchmunk_search.py @@ -0,0 +1,490 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Sirchmunk-based knowledge search integration. + +This module wraps sirchmunk's AgenticSearch to work with the ms_agent framework, +providing document retrieval capabilities similar to RAG but optimized for +codebase and documentation search. +""" + +import asyncio +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from loguru import logger +from ms_agent.rag.base import RAG +from omegaconf import DictConfig + + +class SirchmunkSearch(RAG): + """Sirchmunk-based knowledge search class. + + This class wraps the sirchmunk library to provide intelligent codebase search + capabilities. Unlike traditional RAG that uses vector embeddings, Sirchmunk + uses a combination of keyword search, semantic clustering, and LLM-powered + analysis to find relevant information from codebases. 
+ + The configuration needed in the config yaml: + - name: SirchmunkSearch + - paths: List of paths to search, required + - work_path: Working directory for sirchmunk cache, default './.sirchmunk' + - embedding_model: Embedding model for clustering, default 'text-embedding-3-small' + - cluster_sim_threshold: Threshold for cluster similarity, default 0.85 + - cluster_sim_top_k: Top K clusters to consider, default 3 + - reuse_knowledge: Whether to reuse previous search results, default True + - mode: Search mode (DEEP, FAST, FILENAME_ONLY), default 'FAST' + + Args: + config (DictConfig): Configuration object containing sirchmunk settings. + """ + + def __init__(self, config: DictConfig): + super().__init__(config) + + self._validate_config(config) + + # Extract configuration parameters + rag_config = config.get('knowledge_search', {}) + + # Search paths - required + paths = rag_config.get('paths', []) + if isinstance(paths, str): + paths = [paths] + self.search_paths: List[str] = [ + str(Path(p).expanduser().resolve()) for p in paths + ] + + # Work path for sirchmunk cache + _work_path = rag_config.get('work_path', './.sirchmunk') + self.work_path: Path = Path(_work_path).expanduser().resolve() + + # Sirchmunk search parameters + self.reuse_knowledge = rag_config.get('reuse_knowledge', True) + self.cluster_sim_threshold = rag_config.get('cluster_sim_threshold', + 0.85) + self.cluster_sim_top_k = rag_config.get('cluster_sim_top_k', 3) + self.search_mode = rag_config.get('mode', 'FAST') + self.max_loops = rag_config.get('max_loops', 10) + self.max_token_budget = rag_config.get('max_token_budget', 128000) + + # LLM configuration for sirchmunk + # First try knowledge_search.llm_api_key, then fall back to main llm config + self.llm_api_key = rag_config.get('llm_api_key', None) + self.llm_base_url = rag_config.get('llm_base_url', None) + self.llm_model_name = rag_config.get('llm_model_name', None) + + # Fall back to main llm config if not specified in knowledge_search + if 
(self.llm_api_key is None or self.llm_base_url is None + or self.llm_model_name is None): + llm_config = config.get('llm', {}) + if llm_config: + service = getattr(llm_config, 'service', 'dashscope') + if self.llm_api_key is None: + self.llm_api_key = getattr(llm_config, + f'{service}_api_key', None) + if self.llm_base_url is None: + self.llm_base_url = getattr(llm_config, + f'{service}_base_url', None) + if self.llm_model_name is None: + self.llm_model_name = getattr(llm_config, 'model', None) + + # Embedding model configuration + self.embedding_model_id = rag_config.get('embedding_model', None) + self.embedding_model_cache_dir = rag_config.get( + 'embedding_model_cache_dir', None) + + # Runtime state + self._searcher = None + self._initialized = False + self._cluster_cache_hit = False + self._cluster_cache_hit_time: str | None = None + self._last_search_result: List[Dict[str, Any]] | None = None + + # Callback for capturing logs + self._log_callback = None + self._search_logs: List[str] = [] + # Async queue for streaming logs in real-time + self._log_queue: asyncio.Queue | None = None + self._streaming_callback: Callable | None = None + + def _validate_config(self, config: DictConfig): + """Validate configuration parameters.""" + if not hasattr(config, + 'knowledge_search') or config.knowledge_search is None: + raise ValueError( + 'Missing knowledge_search configuration. ' + 'Please add knowledge_search section to your config with at least "paths" specified.' 
+ ) + + rag_config = config.knowledge_search + paths = rag_config.get('paths', []) + if not paths: + raise ValueError( + 'knowledge_search.paths must be specified and non-empty') + + def _initialize_searcher(self): + """Initialize the sirchmunk AgenticSearch instance.""" + if self._initialized: + return + + try: + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch + from sirchmunk.utils.embedding_util import EmbeddingUtil + + # Create LLM client + llm = OpenAIChat( + api_key=self.llm_api_key, + base_url=self.llm_base_url, + model=self.llm_model_name, + max_retries=3, + log_callback=self._log_callback_wrapper(), + ) + + # Create embedding util + # Handle empty strings by using None (which triggers DEFAULT_MODEL_ID) + embedding_model_id = ( + self.embedding_model_id if self.embedding_model_id else None) + embedding_cache_dir = ( + self.embedding_model_cache_dir + if self.embedding_model_cache_dir else None) + embedding = EmbeddingUtil( + model_id=embedding_model_id, cache_dir=embedding_cache_dir) + + # Create AgenticSearch instance + self._searcher = AgenticSearch( + llm=llm, + embedding=embedding, + work_path=str(self.work_path), + paths=self.search_paths, + verbose=True, + reuse_knowledge=self.reuse_knowledge, + cluster_sim_threshold=self.cluster_sim_threshold, + cluster_sim_top_k=self.cluster_sim_top_k, + log_callback=self._log_callback_wrapper(), + ) + + self._initialized = True + logger.info( + f'SirchmunkSearch initialized with paths: {self.search_paths}' + ) + + except ImportError as e: + raise ImportError( + f'Failed to import sirchmunk: {e}. ' + 'Please install sirchmunk: pip install sirchmunk') + except Exception as e: + raise RuntimeError(f'Failed to initialize SirchmunkSearch: {e}') + + def _log_callback_wrapper(self): + """Create a callback wrapper to capture search logs. 
+ + The sirchmunk LogCallback signature is: + (level: str, message: str, end: str, flush: bool) -> None + See sirchmunk/utils/log_utils.py for reference. + """ + + def log_callback( + level: str, + message: str, + end: str = '\n', + flush: bool = False, + ): + log_entry = f'[{level.upper()}] {message}' + self._search_logs.append(log_entry) + # Stream log in real-time if streaming callback is set + if self._streaming_callback: + asyncio.create_task(self._streaming_callback(log_entry)) + + return log_callback + + async def add_documents(self, documents: List[str]) -> bool: + """Add documents to the search index. + + Note: Sirchmunk works by scanning existing files in the specified paths. + This method is provided for RAG interface compatibility but doesn't + directly add documents. Instead, documents should be saved to files + within the search paths. + + Args: + documents (List[str]): List of document contents to add. + + Returns: + bool: True if successful (for interface compatibility). + """ + logger.warning( + 'SirchmunkSearch does not support direct document addition. ' + 'Documents should be saved to files within the configured search paths.' + ) + # Trigger re-scan of the search paths + if self._searcher and hasattr(self._searcher, 'knowledge_base'): + try: + await self._searcher.knowledge_base.refresh() + return True + except Exception as e: + logger.error(f'Failed to refresh knowledge base: {e}') + return False + return True + + async def add_documents_from_files(self, file_paths: List[str]) -> bool: + """Add documents from file paths. + + Args: + file_paths (List[str]): List of file paths to scan. + + Returns: + bool: True if successful. 
+        """
+        self._initialize_searcher()
+
+        if self._searcher and hasattr(self._searcher, 'scan_directory'):
+            try:
+                for file_path in file_paths:
+                    if Path(file_path).exists():
+                        await self._searcher.scan_directory(
+                            str(Path(file_path).parent))
+                return True
+            except Exception as e:
+                logger.error(f'Failed to scan files: {e}')
+                return False
+        return True
+
+    async def retrieve(self,
+                       query: str,
+                       limit: int = 5,
+                       score_threshold: float = 0.7,
+                       **filters) -> List[Dict[str, Any]]:
+        """Retrieve relevant documents using sirchmunk.
+
+        Args:
+            query (str): The search query.
+            limit (int): Maximum number of results to return.
+            score_threshold (float): Minimum relevance score threshold.
+            **filters: Additional filters (mode, max_loops, etc.).
+
+        Returns:
+            List[Dict[str, Any]]: List of search results with 'text', 'score',
+                'metadata' fields.
+        """
+        self._initialize_searcher()
+        self._search_logs.clear()
+
+        try:
+            mode = filters.get('mode', self.search_mode)
+            max_loops = filters.get('max_loops', self.max_loops)
+            max_token_budget = filters.get('max_token_budget',
+                                           self.max_token_budget)
+
+            # Perform search
+            result = await self._searcher.search(
+                query=query,
+                mode=mode,
+                max_loops=max_loops,
+                max_token_budget=max_token_budget,
+                return_context=True,
+            )
+
+            # Check if cluster cache was hit
+            self._cluster_cache_hit = False
+            self._cluster_cache_hit_time = None
+            if hasattr(result, 'cluster') and result.cluster is not None:
+                # If a similar cluster was found and reused, it's a cache hit
+                self._cluster_cache_hit = getattr(result.cluster,
+                                                  '_reused_from_cache', False)
+                # Get the cluster cache hit time if available
+                if hasattr(result.cluster, 'updated_at'):
+                    self._cluster_cache_hit_time = getattr(
+                        result.cluster, 'updated_at', None)
+
+            # Parse results into standard format
+            return self._parse_search_result(result, score_threshold, limit)
+
+        except Exception as e:
+            logger.error(f'SirchmunkSearch retrieve failed: {e}')
+            return []
+
+    async def query(self, query: str) -> str:
+        """Query sirchmunk and return a synthesized answer.
+
+        This method performs a search and returns the LLM-synthesized answer
+        along with search details that can be used for frontend display.
+
+        Args:
+            query (str): The search query.
+
+        Returns:
+            str: The synthesized answer from sirchmunk.
+        """
+        self._initialize_searcher()
+        self._search_logs.clear()
+
+        try:
+            mode = self.search_mode
+            max_loops = self.max_loops
+            max_token_budget = self.max_token_budget
+
+            # Single search with context so we get both the synthesized answer
+            # and source units in one call, avoiding a redundant second search.
+            result = await self._searcher.search(
+                query=query,
+                mode=mode,
+                max_loops=max_loops,
+                max_token_budget=max_token_budget,
+                return_context=True,
+            )
+
+            # Check if cluster cache was hit
+            self._cluster_cache_hit = False
+            self._cluster_cache_hit_time = None
+            if hasattr(result, 'cluster') and result.cluster is not None:
+                self._cluster_cache_hit = getattr(result.cluster,
+                                                  '_reused_from_cache', False)
+                if hasattr(result.cluster, 'updated_at'):
+                    self._cluster_cache_hit_time = getattr(
+                        result.cluster, 'updated_at', None)
+
+            # Store parsed context for frontend display
+            self._last_search_result = self._parse_search_result(
+                result, score_threshold=0.7, limit=5)
+
+            # Extract the synthesized answer from the context result
+            if hasattr(result, 'answer'):
+                return result.answer
+
+            # If result is already a plain string (some modes return str directly)
+            if isinstance(result, str):
+                return result
+
+            # Fallback: convert to string
+            return str(result)
+
+        except Exception as e:
+            logger.error(f'SirchmunkSearch query failed: {e}')
+            return f'Query failed: {e}'
+
+    def _parse_search_result(self, result: Any, score_threshold: float,
+                             limit: int) -> List[Dict[str, Any]]:
+        """Parse sirchmunk search result into standard format.
+
+        Args:
+            result: The raw search result from sirchmunk.
+            score_threshold: Minimum score threshold.
+            limit: Maximum number of results.
+
+        Returns:
+            List[Dict[str, Any]]: Parsed results.
+        """
+        results = []
+
+        # Handle SearchContext format (returned when return_context=True)
+        if hasattr(result, 'cluster') and result.cluster is not None:
+            cluster = result.cluster
+            for unit in cluster.evidences:
+                # Use the cluster confidence as the relevance score
+                score = getattr(cluster, 'confidence', 1.0)
+                if score >= score_threshold:
+                    # Extract text from snippets
+                    text_parts = []
+                    source = str(getattr(unit, 'file_or_url', 'unknown'))
+                    for snippet in getattr(unit, 'snippets', []):
+                        if isinstance(snippet, dict):
+                            text_parts.append(snippet.get('snippet', ''))
+                        else:
+                            text_parts.append(str(snippet))
+
+                    results.append({
+                        'text':
+                        '\n'.join(text_parts) if text_parts else getattr(
+                            unit, 'summary', ''),
+                        'score':
+                        score,
+                        'metadata': {
+                            'source':
+                            source,
+                            'type':
+                            getattr(unit, 'abstraction_level', 'text')
+                            if hasattr(unit, 'abstraction_level') else 'text',
+                        },
+                    })
+
+        # Handle format with evidence_units attribute directly
+        elif hasattr(result, 'evidence_units'):
+            for unit in result.evidence_units:
+                score = getattr(unit, 'confidence', 1.0)
+                if score >= score_threshold:
+                    results.append({
+                        'text':
+                        str(unit.content)
+                        if hasattr(unit, 'content') else str(unit),
+                        'score':
+                        score,
+                        'metadata': {
+                            'source': getattr(unit, 'source_file', 'unknown'),
+                            'type': getattr(unit, 'abstraction_level', 'text'),
+                        },
+                    })
+
+        # Handle list format
+        elif isinstance(result, list):
+            for item in result:
+                if isinstance(item, dict):
+                    score = item.get('score', item.get('confidence', 1.0))
+                    if score >= score_threshold:
+                        results.append({
+                            'text':
+                            item.get('content', item.get('text', str(item))),
+                            'score':
+                            score,
+                            'metadata':
+                            item.get('metadata', {}),
+                        })
+
+        # Handle dict format
+        elif isinstance(result, dict):
+            score = result.get('score', result.get('confidence', 1.0))
+            if score >= score_threshold:
+                results.append({
+                    'text':
+                    result.get('content', result.get('text', str(result))),
+                    'score':
+                    score,
+                    'metadata':
+                    result.get('metadata', {}),
+                })
+
+        # Sort by score and limit results
+        results.sort(key=lambda x: x.get('score', 0), reverse=True)
+        return results[:limit]
+
+    def get_search_logs(self) -> List[str]:
+        """Get the captured search logs.
+
+        Returns:
+            List[str]: List of log messages from the search operation.
+        """
+        return self._search_logs.copy()
+
+    def get_search_details(self) -> Dict[str, Any]:
+        """Get detailed search information including logs and metadata.
+
+        Returns:
+            Dict[str, Any]: Search details including logs, mode, and paths.
+        """
+        return {
+            'logs': self._search_logs.copy(),
+            'mode': self.search_mode,
+            'paths': self.search_paths,
+            'work_path': str(self.work_path),
+            'reuse_knowledge': self.reuse_knowledge,
+            'cluster_cache_hit': self._cluster_cache_hit,
+            'cluster_cache_hit_time': self._cluster_cache_hit_time,
+        }
+
+    def enable_streaming_logs(self, callback: Callable):
+        """Enable streaming mode for search logs.
+
+        Args:
+            callback: Async callback function to receive log entries in real-time.
+                Signature: async def callback(log_entry: str) -> None
+        """
+        self._streaming_callback = callback
+        self._search_logs.clear()
diff --git a/ms_agent/llm/dashscope_llm.py b/ms_agent/llm/dashscope_llm.py
index af766f679..b4a6ddaa8 100644
--- a/ms_agent/llm/dashscope_llm.py
+++ b/ms_agent/llm/dashscope_llm.py
@@ -12,7 +12,7 @@ class DashScope(OpenAI):
     def __init__(self, config: DictConfig):
         super().__init__(
             config,
-            base_url=config.llm.modelscope_base_url
+            base_url=config.llm.dashscope_base_url
            or get_service_config('dashscope').base_url,
             api_key=config.llm.dashscope_api_key)
diff --git a/ms_agent/llm/utils.py b/ms_agent/llm/utils.py
index 6a336ca6e..410aa12f0 100644
--- a/ms_agent/llm/utils.py
+++ b/ms_agent/llm/utils.py
@@ -61,6 +61,12 @@ class Message:
     api_calls: int = 1
 
+    # Knowledge search (sirchmunk) related fields
+    # searching_detail: Search process logs and metadata for frontend display
+    searching_detail: Dict[str, Any] = field(default_factory=dict)
+    # search_result: Raw search results to be used as context for next LLM turn
+    search_result: List[Dict[str, Any]] = field(default_factory=list)
+
     def to_dict(self):
         return asdict(self)
diff --git a/ms_agent/rag/utils.py b/ms_agent/rag/utils.py
index 08e9a4db7..e66da954d 100644
--- a/ms_agent/rag/utils.py
+++ b/ms_agent/rag/utils.py
@@ -4,3 +4,6 @@ rag_mapping = {
     'LlamaIndexRAG': LlamaIndexRAG,
 }
+
+# Note: SirchmunkSearch is registered in knowledge_search module
+# and integrated directly in LLMAgent, not through rag_mapping
diff --git a/tests/knowledge_search/__init__.py b/tests/knowledge_search/__init__.py
new file mode 100644
index 000000000..0cc40e613
--- /dev/null
+++ b/tests/knowledge_search/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Knowledge search tests."""
diff --git a/tests/knowledge_search/test_sirschmunk.py b/tests/knowledge_search/test_sirschmunk.py
new file mode 100644
index 000000000..5a4f43213
--- /dev/null
+++ b/tests/knowledge_search/test_sirschmunk.py
@@ -0,0 +1,203 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for SirchmunkSearch knowledge search integration via LLMAgent.
+
+These tests verify the sirchmunk-based knowledge search functionality
+through the LLMAgent entry point, including verification that
+search_result and searching_detail fields are properly populated.
+
+To run these tests, you need to set the following environment variables:
+    - TEST_LLM_API_KEY: Your LLM API key
+    - TEST_LLM_BASE_URL: Your LLM API base URL (optional, default: OpenAI)
+    - TEST_LLM_MODEL_NAME: Your LLM model name (optional)
+    - TEST_EMBEDDING_MODEL_ID: Embedding model ID (optional)
+    - TEST_EMBEDDING_MODEL_CACHE_DIR: Embedding model cache directory (optional)
+
+Example:
+    export TEST_LLM_API_KEY="your-api-key"
+    export TEST_LLM_BASE_URL="https://api.openai.com/v1"
+    export TEST_LLM_MODEL_NAME="gpt-4o"
+    export TEST_EMBEDDING_MODEL_ID="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+    export TEST_EMBEDDING_MODEL_CACHE_DIR="/tmp/embedding_cache"
+    python -m pytest tests/knowledge_search/test_sirschmunk.py
+"""
+import asyncio
+import os
+import shutil
+import unittest
+from pathlib import Path
+
+from ms_agent.knowledge_search import SirchmunkSearch
+from ms_agent.agent import LLMAgent
+from ms_agent.config import Config
+from omegaconf import DictConfig
+
+from modelscope.utils.test_utils import test_level
+
+
+class SirchmunkLLMAgentIntegrationTest(unittest.TestCase):
+    """Test cases for SirchmunkSearch integration with LLMAgent.
+
+    These tests verify that when LLMAgent runs a query that triggers
+    knowledge search, the Message objects have search_result and
+    searching_detail fields properly populated.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        """Set up test fixtures."""
+        # Create test directory with sample files
+        cls.test_dir = Path('./test_llm_agent_knowledge')
+        cls.test_dir.mkdir(exist_ok=True)
+
+        # Create sample documentation
+        (cls.test_dir / 'README.md').write_text('''
+# Test Project Documentation
+
+## Overview
+This is a test project for knowledge search integration.
+
+## API Reference
+
+### UserManager
+The UserManager class handles user operations:
+- create_user: Create a new user account
+- delete_user: Delete an existing user
+- update_user: Update user information
+- get_user: Retrieve user details
+
+### AuthService
+The AuthService class handles authentication:
+- login: Authenticate user credentials
+- logout: End user session
+- refresh_token: Refresh authentication token
+- verify_token: Validate authentication token
+''')
+
+        (cls.test_dir / 'config.py').write_text('''
+"""Configuration module."""
+
+class Config:
+    """Application configuration."""
+
+    def __init__(self):
+        self.database_url = "postgresql://localhost:5432/mydb"
+        self.secret_key = "your-secret-key"
+        self.debug_mode = False
+
+    def load_from_env(self):
+        """Load configuration from environment variables."""
+        import os
+        self.database_url = os.getenv("DATABASE_URL", self.database_url)
+        self.secret_key = os.getenv("SECRET_KEY", self.secret_key)
+        return self
+''')
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up test fixtures."""
+        if cls.test_dir.exists():
+            shutil.rmtree(cls.test_dir, ignore_errors=True)
+        work_dir = Path('./.sirchmunk')
+        if work_dir.exists():
+            shutil.rmtree(work_dir, ignore_errors=True)
+
+    def _get_agent_config(self):
+        """Create agent configuration with knowledge search."""
+        llm_api_key = os.getenv('TEST_LLM_API_KEY', 'test-api-key')
+        llm_base_url = os.getenv('TEST_LLM_BASE_URL',
+                                 'https://api.openai.com/v1')
+        llm_model_name = os.getenv('TEST_LLM_MODEL_NAME', 'gpt-4o-mini')
+        # Read from TEST_* env vars (for test-specific config)
+        # These can be set from .env file which uses TEST_* prefix
+        embedding_model_id = os.getenv('TEST_EMBEDDING_MODEL_ID', '')
+        embedding_model_cache_dir = os.getenv(
+            'TEST_EMBEDDING_MODEL_CACHE_DIR', '')
+
+        config = DictConfig({
+            'llm': {
+                'service': 'openai',
+                'model': llm_model_name,
+                'openai_api_key': llm_api_key,
+                'openai_base_url': llm_base_url,
+            },
+            'generation_config': {
+                'temperature': 0.3,
+                'max_tokens': 500,
+            },
+            'knowledge_search': {
+                'name': 'SirchmunkSearch',
+                'paths': [str(self.test_dir)],
+                'work_path': './.sirchmunk',
+                'llm_api_key': llm_api_key,
+                'llm_base_url': llm_base_url,
+                'llm_model_name': llm_model_name,
+                'embedding_model': embedding_model_id,
+                'embedding_model_cache_dir': embedding_model_cache_dir,
+                'mode': 'FAST',
+            }
+        })
+        return config
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_llm_agent_with_knowledge_search(self):
+        """Test LLMAgent using knowledge search.
+
+        This test verifies that:
+        1. LLMAgent can be initialized with SirchmunkSearch configuration
+        2. Running a query produces a valid response
+        3. User message has searching_detail and search_result populated
+        4. searching_detail contains expected keys (logs, mode, paths)
+        5. search_result is a list
+        """
+        config = self._get_agent_config()
+        agent = LLMAgent(config=config, tag='test-knowledge-agent')
+
+        # Test query that should trigger knowledge search
+        query = 'How do I use UserManager to create a user?'
+
+        async def run_agent():
+            result = await agent.run(query)
+            return result
+
+        result = asyncio.run(run_agent())
+
+        # Verify result
+        self.assertIsNotNone(result)
+        self.assertIsInstance(result, list)
+        self.assertTrue(len(result) > 0)
+
+        # Check that assistant message exists
+        assistant_message = [m for m in result if m.role == 'assistant']
+        self.assertTrue(len(assistant_message) > 0)
+
+        # Check that user message has search_result and searching_detail populated
+        user_messages = [m for m in result if m.role == 'user']
+        self.assertTrue(
+            len(user_messages) > 0, "Expected at least one user message")
+
+        # The first user message should have search details after do_rag processing
+        user_msg = user_messages[0]
+        self.assertTrue(
+            hasattr(user_msg, 'searching_detail'),
+            "User message should have searching_detail attribute")
+        self.assertTrue(
+            hasattr(user_msg, 'search_result'),
+            "User message should have search_result attribute")
+
+        # Check that searching_detail is a dict with expected keys
+        self.assertIsInstance(user_msg.searching_detail, dict,
+                              "searching_detail should be a dictionary")
+        self.assertIn('logs', user_msg.searching_detail)
+        self.assertIn('mode', user_msg.searching_detail)
+        self.assertIn('paths', user_msg.searching_detail)
+
+        # Check that search_result is a list (may be empty if no relevant docs found)
+        self.assertIsInstance(user_msg.search_result, list,
+                              "search_result should be a list")
+
+
+if __name__ == '__main__':
+    unittest.main()