-
Notifications
You must be signed in to change notification settings - Fork 477
Fix/memory #887
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix/memory #887
Changes from all commits
b8f842d
257c557
8eb5973
33281b8
292b130
996ddde
03372f5
86f1cb0
6a6a2dd
eb5eae6
6baa994
9c8205c
1384a16
9a4cac8
65da529
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -694,8 +694,24 @@ async def load_memory(self): | |
|
|
||
| shared_memory = await SharedMemoryManager.get_shared_memory( | ||
| self.config, mem_instance_type) | ||
|
|
||
| ignore_roles = getattr(_memory, 'ignore_roles', []) | ||
| shared_memory.should_early_add_after_task = ( | ||
| 'assistant' in ignore_roles and 'tool' in ignore_roles) | ||
| shared_memory.early_add_after_task_done = False | ||
|
|
||
| self.memory_tools.append(shared_memory) | ||
|
|
||
| def _schedule_add_memory_after_task(self, messages, timestamp=None): | ||
|
|
||
| def _add_memory(): | ||
| asyncio.run( | ||
| self.add_memory( | ||
| messages, add_type='add_after_task', timestamp=timestamp)) | ||
|
|
||
| loop = asyncio.get_running_loop() | ||
| loop.run_in_executor(None, _add_memory) | ||
|
|
||
| async def prepare_rag(self): | ||
| """Load and initialize the RAG component from the config.""" | ||
| if hasattr(self.config, 'rag'): | ||
|
|
@@ -784,7 +800,6 @@ async def step( | |
| """ | ||
| messages = deepcopy(messages) | ||
| if (not self.load_cache) or messages[-1].role != 'assistant': | ||
| messages = await self.condense_memory(messages) | ||
| await self.on_generate_response(messages) | ||
| tools = await self.tool_manager.get_tools() | ||
|
|
||
|
|
@@ -961,26 +976,39 @@ def _get_run_memory_info(self, memory_config: DictConfig): | |
|
|
||
| async def add_memory(self, messages: List[Message], add_type, **kwargs): | ||
| if hasattr(self.config, 'memory') and self.config.memory: | ||
| tools_num = len(self.memory_tools) if self.memory_tools else 0 | ||
|
|
||
| for idx, (mem_instance_type, | ||
| memory_config) in enumerate(self.config.memory.items()): | ||
| for tool, (_, memory_config) in zip(self.memory_tools, | ||
| self.config.memory.items()): | ||
| timestamp = kwargs.get('timestamp', '') | ||
| if add_type == 'add_after_task': | ||
| user_id, agent_id, run_id, memory_type = self._get_run_memory_info( | ||
| memory_config) | ||
| should_early = getattr(tool, 'should_early_add_after_task', | ||
| False) | ||
| early_done = getattr(tool, 'early_add_after_task_done', | ||
| False) | ||
|
|
||
| if timestamp == 'early': | ||
| if not (should_early and not early_done): | ||
| # pass memory tool.run | ||
| continue | ||
| tool.early_add_after_task_done = True | ||
| else: | ||
| if early_done: | ||
| # pass memory tool.run | ||
| continue | ||
|
|
||
| else: | ||
| user_id, agent_id, run_id, memory_type = self._get_step_memory_info( | ||
| memory_config) | ||
|
|
||
| if idx < tools_num: | ||
| if any(v is not None | ||
| for v in [user_id, agent_id, run_id, memory_type]): | ||
| await self.memory_tools[idx].add( | ||
| messages, | ||
| user_id=user_id, | ||
| agent_id=agent_id, | ||
| run_id=run_id, | ||
| memory_type=memory_type) | ||
| if not any([user_id, agent_id, run_id, memory_type]): | ||
| continue | ||
| await tool.add( | ||
| messages, | ||
| user_id=user_id, | ||
| agent_id=agent_id, | ||
| run_id=run_id, | ||
| memory_type=memory_type) | ||
|
|
||
| def save_history(self, messages: List[Message], **kwargs): | ||
| """ | ||
|
|
@@ -1057,6 +1085,10 @@ async def run_loop(self, messages: Union[List[Message], str], | |
| self.log_output('[' + message.role + ']:') | ||
| self.log_output(message.content) | ||
| while not self.runtime.should_stop: | ||
| messages = await self.condense_memory(messages) | ||
| # If assistant and tool content can be ignored, add memory earlier to reduce running time. | ||
| self._schedule_add_memory_after_task( | ||
| messages, timestamp='early') | ||
|
Comment on lines
+1090
to
+1091
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| async for messages in self.step(messages): | ||
| yield messages | ||
| self.runtime.round += 1 | ||
|
|
@@ -1082,13 +1114,8 @@ async def run_loop(self, messages: Union[List[Message], str], | |
| await self.cleanup_tools() | ||
| yield messages | ||
|
|
||
| def _add_memory(): | ||
| asyncio.run( | ||
| self.add_memory( | ||
| messages, add_type='add_after_task', **kwargs)) | ||
| self._schedule_add_memory_after_task(messages) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| loop = asyncio.get_running_loop() | ||
| loop.run_in_executor(None, _add_memory) | ||
| except Exception as e: | ||
| import traceback | ||
| logger.warning(traceback.format_exc()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -588,15 +588,37 @@ def _init_memory_obj(self): | |
| f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.' | ||
| ) | ||
| raise | ||
|
|
||
| import mem0.vector_stores.milvus | ||
| capture_event_origin = mem0.memory.main.capture_event | ||
| update_origin = mem0.vector_stores.milvus.MilvusDB.update | ||
|
|
||
| @wraps(update_origin) | ||
| def update(self, vector_id=None, vector=None, payload=None): | ||
| """ | ||
| Update a vector and its payload. | ||
|
|
||
| Args: | ||
| vector_id (str): ID of the vector to update. | ||
| vector (List[float], optional): Updated vector. | ||
| payload (Dict, optional): Updated payload. | ||
| """ | ||
| if vector is None: | ||
| res = self.client.get( | ||
| collection_name=self.collection_name, ids=[vector_id]) | ||
| if res: | ||
| vector = res[0]['vectors'] | ||
|
|
||
| schema = {'id': vector_id, 'vectors': vector, 'metadata': payload} | ||
| self.client.upsert( | ||
| collection_name=self.collection_name, data=schema) | ||
|
Comment on lines
+591
to
+613
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Monkey-patching a third-party library's method (`mem0.vector_stores.milvus.MilvusDB.update`) […] Consider documenting the specific reason for this monkey-patching (e.g., a bug in `mem0`) […] |
||
|
|
||
| @wraps(capture_event_origin) | ||
| def patched_capture_event(event_name, | ||
| memory_instance, | ||
| additional_data=None): | ||
| pass | ||
|
|
||
| mem0.vector_stores.milvus.MilvusDB.update = update | ||
| mem0.memory.main.capture_event = partial(patched_capture_event, ) | ||
|
|
||
| # emb config | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `_schedule_add_memory_after_task` function is called from an `async` context (`run_loop`). Using `asyncio.run()` inside a function that is already part of a running event loop (which `asyncio.get_running_loop()` implies) is an anti-pattern and will raise a `RuntimeError: Cannot run asyncio.run() from a running event loop`.

To run `self.add_memory` concurrently without blocking the main event loop, `_schedule_add_memory_after_task` should be an `async` function, and `asyncio.create_task` should be used to schedule the `add_memory` coroutine. This will allow `add_memory` to run in the background on the existing event loop.