Skip to content
67 changes: 47 additions & 20 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,24 @@ async def load_memory(self):

shared_memory = await SharedMemoryManager.get_shared_memory(
self.config, mem_instance_type)

ignore_roles = getattr(_memory, 'ignore_roles', [])
shared_memory.should_early_add_after_task = (
'assistant' in ignore_roles and 'tool' in ignore_roles)
shared_memory.early_add_after_task_done = False

self.memory_tools.append(shared_memory)

def _schedule_add_memory_after_task(self, messages, timestamp=None):
    """Write ``messages`` to memory in the background without blocking the caller.

    ``self.add_memory`` is called with ``add_type='add_after_task'`` and is
    driven to completion by ``asyncio.run`` on a worker thread of the running
    loop's default executor. The caller does not wait for the result.

    Args:
        messages: Messages forwarded unchanged to ``self.add_memory``.
        timestamp: Forwarded to ``self.add_memory`` as the ``timestamp``
            keyword argument.

    Raises:
        RuntimeError: If the calling thread has no running event loop
            (raised by ``asyncio.get_running_loop``).
    """

    def _add_memory():
        # This function executes on an executor worker thread. That thread has
        # no running event loop, so starting a fresh one with asyncio.run()
        # is allowed here.
        try:
            asyncio.run(
                self.add_memory(
                    messages, add_type='add_after_task', timestamp=timestamp))
        except Exception:
            # Nobody awaits the future returned by run_in_executor, so a
            # failure has to be reported from inside the worker.
            import traceback
            logger.warning(traceback.format_exc())

    loop = asyncio.get_running_loop()
    loop.run_in_executor(None, _add_memory)
Comment on lines +705 to +713
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

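The _schedule_add_memory_after_task function is called from an async context (run_loop). Calling asyncio.run() on a thread whose event loop is already running (which asyncio.get_running_loop() implies for the caller) is an anti-pattern and raises RuntimeError: asyncio.run() cannot be called from a running event loop. Here the call is wrapped in _add_memory and submitted via loop.run_in_executor, so it executes on a worker thread and starts a second event loop there instead of raising; add_memory then runs outside the agent's own event loop.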

To run self.add_memory concurrently without blocking the main event loop, _schedule_add_memory_after_task should be an async function, and asyncio.create_task should be used to schedule the add_memory coroutine. This will allow add_memory to run in the background on the existing event loop.

    async def _schedule_add_memory_after_task(self, messages, timestamp=None):
        """Start ``add_memory`` as a background task on the running event loop.

        Args:
            messages: Messages forwarded unchanged to ``self.add_memory``.
            timestamp: Forwarded to ``self.add_memory`` as the ``timestamp``
                keyword argument.
        """
        task = asyncio.create_task(
            self.add_memory(
                messages, add_type='add_after_task', timestamp=timestamp))
        # The event loop keeps only weak references to tasks, so hold a strong
        # reference until the task finishes; otherwise it may be
        # garbage-collected before it completes.
        pending = getattr(self, '_pending_memory_tasks', None)
        if pending is None:
            pending = self._pending_memory_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)


async def prepare_rag(self):
"""Load and initialize the RAG component from the config."""
if hasattr(self.config, 'rag'):
Expand Down Expand Up @@ -784,7 +800,6 @@ async def step(
"""
messages = deepcopy(messages)
if (not self.load_cache) or messages[-1].role != 'assistant':
messages = await self.condense_memory(messages)
await self.on_generate_response(messages)
tools = await self.tool_manager.get_tools()

Expand Down Expand Up @@ -961,26 +976,39 @@ def _get_run_memory_info(self, memory_config: DictConfig):

async def add_memory(self, messages: List[Message], add_type, **kwargs):
# Forwards `messages` to the memory tools' add(). `add_type` selects the
# identifiers passed along: 'add_after_task' uses _get_run_memory_info,
# anything else uses _get_step_memory_info. Does nothing when the config
# has no truthy `memory` entry.
# NOTE(review): this listing appears to interleave the pre-change loop
# (enumerate / idx / tools_num) with its replacement (zip over
# self.memory_tools) -- confirm against the real file before relying on it.
if hasattr(self.config, 'memory') and self.config.memory:
tools_num = len(self.memory_tools) if self.memory_tools else 0

for idx, (mem_instance_type,
memory_config) in enumerate(self.config.memory.items()):
for tool, (_, memory_config) in zip(self.memory_tools,
self.config.memory.items()):
# Optional 'timestamp' keyword; defaults to '' when the caller passes none.
timestamp = kwargs.get('timestamp', '')
if add_type == 'add_after_task':
user_id, agent_id, run_id, memory_type = self._get_run_memory_info(
memory_config)
should_early = getattr(tool, 'should_early_add_after_task',
False)
early_done = getattr(tool, 'early_add_after_task_done',
False)

# An 'early' call proceeds only for a tool that opted in
# (should_early_add_after_task) and has not done its early add yet;
# the done flag is then set on the tool.
if timestamp == 'early':
if not (should_early and not early_done):
# Not eligible for an early add: skip this memory tool.
continue
tool.early_add_after_task_done = True
else:
if early_done:
# The early add already ran for this tool: skip it on this call.
continue

else:
user_id, agent_id, run_id, memory_type = self._get_step_memory_info(
memory_config)

if idx < tools_num:
if any(v is not None
for v in [user_id, agent_id, run_id, memory_type]):
await self.memory_tools[idx].add(
messages,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
memory_type=memory_type)
# Skip the tool when none of the four identifiers is truthy.
if not any([user_id, agent_id, run_id, memory_type]):
continue
await tool.add(
messages,
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
memory_type=memory_type)

def save_history(self, messages: List[Message], **kwargs):
"""
Expand Down Expand Up @@ -1057,6 +1085,10 @@ async def run_loop(self, messages: Union[List[Message], str],
self.log_output('[' + message.role + ']:')
self.log_output(message.content)
while not self.runtime.should_stop:
messages = await self.condense_memory(messages)
# If assistant and tool content can be ignored, add memory earlier to reduce running time.
self._schedule_add_memory_after_task(
messages, timestamp='early')
Comment on lines +1090 to +1091
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Following the fix for _schedule_add_memory_after_task to be an async function, this call site must be updated to await the function.

                await self._schedule_add_memory_after_task(
                    messages, timestamp='early')

async for messages in self.step(messages):
yield messages
self.runtime.round += 1
Expand All @@ -1082,13 +1114,8 @@ async def run_loop(self, messages: Union[List[Message], str],
await self.cleanup_tools()
yield messages

def _add_memory():
asyncio.run(
self.add_memory(
messages, add_type='add_after_task', **kwargs))
self._schedule_add_memory_after_task(messages)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Following the fix for _schedule_add_memory_after_task to be an async function, this call site must be updated to await the function.

            await self._schedule_add_memory_after_task(messages)


loop = asyncio.get_running_loop()
loop.run_in_executor(None, _add_memory)
except Exception as e:
import traceback
logger.warning(traceback.format_exc())
Expand Down
24 changes: 23 additions & 1 deletion ms_agent/memory/default_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,15 +588,37 @@ def _init_memory_obj(self):
f'Failed to import mem0: {e}. Please install mem0ai package via `pip install mem0ai`.'
)
raise

import mem0.vector_stores.milvus
capture_event_origin = mem0.memory.main.capture_event
update_origin = mem0.vector_stores.milvus.MilvusDB.update

@wraps(update_origin)
def update(self, vector_id=None, vector=None, payload=None):
    """
    Upsert one record, reusing the stored vector when none is supplied.

    Args:
        vector_id (str): ID of the record to write.
        vector (List[float], optional): New vector. When ``None``, the
            record is fetched by ID and its ``'vectors'`` field is reused
            if a row comes back.
        payload (Dict, optional): Value written to the ``'metadata'`` field.
    """
    if vector is None:
        # No new vector given: look the record up so the upsert can keep
        # whatever vector is currently stored for this ID.
        existing = self.client.get(
            collection_name=self.collection_name, ids=[vector_id])
        vector = existing[0]['vectors'] if existing else None

    record = {'id': vector_id, 'vectors': vector, 'metadata': payload}
    self.client.upsert(collection_name=self.collection_name, data=record)
Comment on lines +591 to +613
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Monkey-patching a third-party library's method (mem0.vector_stores.milvus.MilvusDB.update) can be a fragile practice. While it might be necessary for an immediate fix or specific functionality not available upstream, it introduces a dependency on the internal implementation details of mem0.

Consider documenting the specific reason for this monkey-patching (e.g., a bug in mem0 or a missing feature) and explore more robust integration methods in the long term, such as contributing the change upstream to the mem0 library or using a custom wrapper if mem0 provides extension points. This will improve maintainability and reduce the risk of breakage with future mem0 updates.


@wraps(capture_event_origin)
def patched_capture_event(event_name, memory_instance, additional_data=None):
    # Deliberate no-op: accepts the same arguments as the wrapped
    # capture_event and does nothing with them.
    return None

mem0.vector_stores.milvus.MilvusDB.update = update
mem0.memory.main.capture_event = partial(patched_capture_event, )

# emb config
Expand Down
15 changes: 10 additions & 5 deletions ms_agent/memory/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ms_agent.memory import Memory, memory_mapping
from ms_agent.utils import get_logger
from ms_agent.utils.constants import DEFAULT_OUTPUT_DIR, DEFAULT_USER
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

logger = get_logger()

Expand All @@ -17,10 +17,15 @@ class SharedMemoryManager:
async def get_shared_memory(cls, config: DictConfig,
mem_instance_type: str) -> Memory:
"""Get or create a shared memory instance based on configuration."""
user_id: str = getattr(config, 'user_id', DEFAULT_USER)
path: str = getattr(config, 'path', DEFAULT_OUTPUT_DIR)

key = f'{mem_instance_type}_{user_id}_{path}'
user_id: str = getattr(
getattr(config.memory, mem_instance_type, OmegaConf.create({})),
'user_id', DEFAULT_USER)
path: str = getattr(
getattr(config.memory, mem_instance_type, OmegaConf.create({})),
'path', DEFAULT_OUTPUT_DIR)
llm_str: str = getattr(config.llm, 'model', 'default_model')

key = f'{mem_instance_type}_{user_id}_{llm_str}_{path}'

if key not in cls._instances:
logger.info(f'Creating new shared memory instance for key: {key}')
Expand Down
Loading