diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 8602d1621..3c705e8ab 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -694,8 +694,24 @@ async def load_memory(self): shared_memory = await SharedMemoryManager.get_shared_memory( self.config, mem_instance_type) + + ignore_roles = getattr(_memory, 'ignore_roles', []) + shared_memory.should_early_add_after_task = ( + 'assistant' in ignore_roles and 'tool' in ignore_roles) + shared_memory.early_add_after_task_done = False + self.memory_tools.append(shared_memory) + def _schedule_add_memory_after_task(self, messages, timestamp=None): + + def _add_memory(): + asyncio.run( + self.add_memory( + messages, add_type='add_after_task', timestamp=timestamp)) + + loop = asyncio.get_running_loop() + loop.run_in_executor(None, _add_memory) + async def prepare_rag(self): """Load and initialize the RAG component from the config.""" if hasattr(self.config, 'rag'): @@ -784,7 +800,6 @@ async def step( """ messages = deepcopy(messages) if (not self.load_cache) or messages[-1].role != 'assistant': - messages = await self.condense_memory(messages) await self.on_generate_response(messages) tools = await self.tool_manager.get_tools() @@ -961,26 +976,39 @@ def _get_run_memory_info(self, memory_config: DictConfig): async def add_memory(self, messages: List[Message], add_type, **kwargs): if hasattr(self.config, 'memory') and self.config.memory: - tools_num = len(self.memory_tools) if self.memory_tools else 0 - - for idx, (mem_instance_type, - memory_config) in enumerate(self.config.memory.items()): + for tool, (_, memory_config) in zip(self.memory_tools, + self.config.memory.items()): + timestamp = kwargs.get('timestamp', '') if add_type == 'add_after_task': user_id, agent_id, run_id, memory_type = self._get_run_memory_info( memory_config) + should_early = getattr(tool, 'should_early_add_after_task', + False) + early_done = getattr(tool, 'early_add_after_task_done', + False) + + if timestamp == 'early': + if not (should_early and not early_done): + # pass memory tool.run + continue + tool.early_add_after_task_done = True + else: + if early_done: + # pass memory tool.run + continue + else: user_id, agent_id, run_id, memory_type = self._get_step_memory_info( memory_config) - if idx < tools_num: - if any(v is not None - for v in [user_id, agent_id, run_id, memory_type]): - await self.memory_tools[idx].add( - messages, - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - memory_type=memory_type) + if not any([user_id, agent_id, run_id, memory_type]): + continue + await tool.add( + messages, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + memory_type=memory_type) def save_history(self, messages: List[Message], **kwargs): """ @@ -1057,6 +1085,10 @@ async def run_loop(self, messages: Union[List[Message], str], self.log_output('[' + message.role + ']:') self.log_output(message.content) while not self.runtime.should_stop: + messages = await self.condense_memory(messages) + # If assistant and tool content can be ignored, add memory earlier to reduce running time. + self._schedule_add_memory_after_task( + messages, timestamp='early') async for messages in self.step(messages): yield messages self.runtime.round += 1 @@ -1082,13 +1114,8 @@ async def run_loop(self, messages: Union[List[Message], str], await self.cleanup_tools() yield messages - def _add_memory(): - asyncio.run( - self.add_memory( - messages, add_type='add_after_task', **kwargs)) + self._schedule_add_memory_after_task(messages) - loop = asyncio.get_running_loop() - loop.run_in_executor(None, _add_memory) except Exception as e: import traceback logger.warning(traceback.format_exc()) diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index b087a2dd9..4246602a0 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -588,8 +588,29 @@ def _init_memory_obj(self): f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.' ) raise - + import mem0.vector_stores.milvus capture_event_origin = mem0.memory.main.capture_event + update_origin = mem0.vector_stores.milvus.MilvusDB.update + + @wraps(update_origin) + def update(self, vector_id=None, vector=None, payload=None): + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + if vector is None: + res = self.client.get( + collection_name=self.collection_name, ids=[vector_id]) + if res: + vector = res[0]['vectors'] + + schema = {'id': vector_id, 'vectors': vector, 'metadata': payload} + self.client.upsert( + collection_name=self.collection_name, data=schema) @wraps(capture_event_origin) def patched_capture_event(event_name, @@ -597,6 +618,7 @@ def patched_capture_event(event_name, additional_data=None): pass + mem0.vector_stores.milvus.MilvusDB.update = update mem0.memory.main.capture_event = partial(patched_capture_event, ) # emb config diff --git a/ms_agent/memory/memory_manager.py b/ms_agent/memory/memory_manager.py index 5a203505d..8a142f777 100644 --- a/ms_agent/memory/memory_manager.py +++ b/ms_agent/memory/memory_manager.py @@ -4,7 +4,7 @@ from ms_agent.memory import Memory, memory_mapping from ms_agent.utils import get_logger from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf logger = get_logger() @@ -17,10 +17,15 @@ class SharedMemoryManager: async def get_shared_memory(cls, config: DictConfig, mem_instance_type: str) -> Memory: """Get or create a shared memory instance based on configuration.""" - user_id: str = getattr(config, 'user_id', DEFAULT_USER) - path: str = getattr(config, 'path', DEFAULT_OUTPUT_DIR) - - key = f'{mem_instance_type}_{user_id}_{path}' + user_id: str = getattr( + getattr(config.memory, mem_instance_type, OmegaConf.create({})), + 'user_id', DEFAULT_USER) + path: str = getattr( + getattr(config.memory, mem_instance_type, OmegaConf.create({})), + 'path', DEFAULT_OUTPUT_DIR) + llm_str: str = getattr(config.llm, 'model', 'default_model') + + key = f'{mem_instance_type}_{user_id}_{llm_str}_{path}' if key not in cls._instances: logger.info(f'Creating new shared memory instance for key: {key}')