Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ venv.bak/

.vscode
.idea
.cursor

# custom
*.pkl
Expand Down
247 changes: 233 additions & 14 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from ms_agent.memory import Memory, get_memory_meta_safe, memory_mapping
from ms_agent.memory.memory_manager import SharedMemoryManager
from ms_agent.rag.base import RAG
from ms_agent.session import ContextAssembler, SessionLog
from ms_agent.session.strategies import SummaryCompactor, ToolOutputPruner
from ms_agent.rag.utils import rag_mapping
from ms_agent.tools import ToolManager
from ms_agent.utils import async_retry, read_history, save_history
Expand Down Expand Up @@ -107,9 +109,11 @@ def __init__(
self.tool_manager: Optional[ToolManager] = None
self.memory_tools: List[Memory] = []
self.rag: Optional[RAG] = None
self.knowledge_search: Optional[SirschmunkSearch] = None
self.knowledge_search: Optional[SirchmunkSearch] = None
self.llm: Optional[LLM] = None
self.runtime: Optional[Runtime] = None
self.session_log: Optional[SessionLog] = None
self.context_assembler: Optional[ContextAssembler] = None
self.max_chat_round: int = 0
self.load_cache = kwargs.get('load_cache', False)
self.config.load_cache = self.load_cache
Expand Down Expand Up @@ -733,6 +737,11 @@ async def do_skill(self,
async def load_memory(self):
"""Initialize and append memory tool instances based on the configuration provided in the global config.

For ``unified_memory``, this also:
- Passes the agent's LLM instance to the orchestrator
- Registers the ``memory`` / ``memory_read`` tools into ToolManager
- Injects memory-usage guidance into the system prompt

Raises:
AssertionError: If a specified memory type in the config does not exist in memory_mapping.
"""
Expand All @@ -747,6 +756,40 @@ async def load_memory(self):
self.config, mem_instance_type)
self.memory_tools.append(shared_memory)

# Wire unified_memory into the tool system
if mem_instance_type == 'unified_memory':
await self._register_memory_tool(shared_memory)

async def _register_memory_tool(self, orchestrator):
"""Register the memory tool into ToolManager and inject prompt guidance.

Returns immediately when ``orchestrator`` has no ``get_tool_schemas``
attribute. Otherwise it hands the agent's LLM (and, when one exists, the
session log) to the orchestrator, registers a ``MemoryTool`` with
``self.tool_manager`` and appends ``MEMORY_USAGE_PROMPT`` to
``config.prompt.system``.

Args:
orchestrator: The memory instance that ``load_memory`` appended to
``self.memory_tools`` for the ``unified_memory`` type.
"""
# Imported lazily inside the method rather than at module import time.
from ms_agent.memory.unified.memory_tool import MemoryTool, MEMORY_USAGE_PROMPT

# Only orchestrators exposing ``get_tool_schemas`` are wired in.
if not hasattr(orchestrator, 'get_tool_schemas'):
return

# Pass LLM and session_log to orchestrator for consolidation / extraction
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether init_update_queue() is guarded by the LLM check — confirm.
if self.llm is not None:
orchestrator.set_llm(self.llm)
orchestrator.init_update_queue()
# NOTE(review): in the visible run_loop hunk, load_memory() is awaited before
# _init_session_log(), so self.session_log may still be None here on a
# first run — confirm the intended ordering.
if self.session_log is not None and hasattr(orchestrator, '_session_log'):
orchestrator._session_log = self.session_log

# Register memory tool into the agent's tool system
if self.tool_manager is not None:
mem_tool = MemoryTool(self.config, orchestrator)
self.tool_manager.register_tool(mem_tool)
await self.tool_manager.reindex_tool()
logger.info('[unified_memory] Memory tool registered')

# Inject usage guidance into system prompt
if hasattr(self.config, 'prompt') and hasattr(self.config.prompt, 'system'):
current_prompt = self.config.prompt.system or ''
# Skip when the prompt already contains the 'Long-term Memory' marker
# (presumably part of MEMORY_USAGE_PROMPT, defined elsewhere — verify),
# so the guidance is not appended twice.
if 'Long-term Memory' not in current_prompt:
OmegaConf.update(
self.config, 'prompt.system',
current_prompt + '\n\n' + MEMORY_USAGE_PROMPT,
merge=True)

async def prepare_rag(self):
"""Load and initialize the RAG component from the config."""
if hasattr(self.config, 'rag'):
Expand All @@ -770,19 +813,135 @@ async def prepare_knowledge_search(self):
self.config)

async def condense_memory(self, messages: List[Message]) -> List[Message]:
    """Inject long-term memory context into messages.

    .. deprecated::
        Historically this also ran context compressors. Compression is
        now handled by :class:`ContextAssembler` before this method is
        called. This method only performs memory *injection* (adding
        ``<long-term-memory>`` blocks, etc.).

    Every entry of ``self.memory_tools`` is awaited in order; the list
    returned by one tool is the input of the next.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: The message list returned by the last memory tool,
            or ``messages`` unchanged when no memory tools are configured.
    """
    for memory_tool in self.memory_tools:
        # Each tool may return a new list, so rebind rather than mutate.
        messages = await memory_tool.run(messages)
    return messages

Args:
messages (List[Message]): Current message history.
async def inject_memory(self, messages: List[Message]) -> List[Message]:
    """Inject long-term memory context into the message list.

    Intended for ``unified_memory`` style tools that *inject* context
    (MEMORY.md snapshot, facts, etc.) rather than trim or compress it.

    NOTE(review): the loop below awaits ``run`` on *every* entry of
    ``self.memory_tools`` and is currently identical to
    ``condense_memory``; it does not filter by memory type. Whether a
    tool trims messages depends on that tool's own ``run`` implementation.

    Args:
        messages (List[Message]): Current message history.

    Returns:
        List[Message]: The message list returned by the last memory tool,
            or ``messages`` unchanged when no memory tools are configured.
    """
    for memory_tool in self.memory_tools:
        # Chain the tools: each one receives the previous tool's output.
        messages = await memory_tool.run(messages)
    return messages

def _init_session_log(self) -> None:
"""Create SessionLog and ContextAssembler if session logging is enabled."""
session_cfg = getattr(self.config, 'session_log', None)
enabled = getattr(session_cfg, 'enabled', True) if session_cfg else True
if not enabled:
return

session_dir = getattr(
session_cfg, 'dir', None
) if session_cfg else None
if session_dir is None:
session_dir = os.path.join(
getattr(self.config, 'output_dir', 'output'),
'sessions',
)

session_key = getattr(session_cfg, 'session_key', None) if session_cfg else None
self.session_log = SessionLog(session_dir, session_key=session_key)

compaction_cfg = getattr(self.config, 'compaction', None)
compaction_enabled = (
getattr(compaction_cfg, 'enabled', True) if compaction_cfg else True
)

if not compaction_enabled:
self.context_assembler = ContextAssembler(
session_log=self.session_log, strategies=[], config={},
)
return

strategies = self._build_compaction_strategies(compaction_cfg)
assembler_config = self._build_assembler_config(compaction_cfg, session_cfg)
flush_callback = self._make_memory_flush_callback()

self.context_assembler = ContextAssembler(
session_log=self.session_log,
strategies=strategies,
config=assembler_config,
memory_flush_callback=flush_callback,
)

def _build_compaction_strategies(self, compaction_cfg):
"""Build the strategy list from YAML ``compaction.strategies``."""
if compaction_cfg and hasattr(compaction_cfg, 'strategies'):
strategies = []
for s_cfg in compaction_cfg.strategies:
name = getattr(s_cfg, 'name', '')
if not getattr(s_cfg, 'enabled', True):
continue
if name == 'tool_output_pruner':
strategies.append(ToolOutputPruner())
elif name == 'summary_compactor':
strategies.append(SummaryCompactor(llm=self.llm))
else:
logger.warning(f"Unknown compaction strategy: {name}")
return strategies

return [ToolOutputPruner(), SummaryCompactor(llm=self.llm)]

def _build_assembler_config(self, compaction_cfg, session_cfg):
"""Merge compaction params from ``compaction`` and ``session_log``."""
config: Dict[str, Any] = {}

if session_cfg:
for key in ('context_limit', 'reserved_buffer', 'prune_protect'):
val = getattr(session_cfg, key, None)
if val is not None:
config[key] = val

if compaction_cfg:
for key in ('context_limit', 'reserved_buffer'):
val = getattr(compaction_cfg, key, None)
if val is not None:
config[key] = val
if hasattr(compaction_cfg, 'strategies'):
for s_cfg in compaction_cfg.strategies:
if getattr(s_cfg, 'name', '') == 'tool_output_pruner':
pp = getattr(s_cfg, 'prune_protect', None)
if pp is not None:
config['prune_protect'] = pp

config.setdefault('context_limit', 128000)
config.setdefault('reserved_buffer', 20000)
config.setdefault('prune_protect', 40000)
return config

def _make_memory_flush_callback(self):
"""Create a callback that flushes memory before context compaction."""
def _flush(discarded_messages):
for memory_tool in self.memory_tools:
orchestrator = memory_tool
if hasattr(orchestrator, 'flush'):
import asyncio
from ms_agent.llm.utils import Message as _Msg
msgs = [_Msg(
role=m.get('role', 'user'),
content=m.get('content', ''),
tool_calls=m.get('tool_calls'),
) for m in discarded_messages]
try:
loop = asyncio.get_running_loop()
loop.create_task(orchestrator.flush(msgs))
except RuntimeError:
asyncio.run(orchestrator.flush(msgs))
return _flush

def log_output(self, content: Union[str, list]):
"""
Log formatted output with a tag prefix.
Expand Down Expand Up @@ -1089,6 +1248,31 @@ def save_history(self, messages: List[Message], **kwargs):
save_history(
self.output_dir, task=self.tag, config=config, messages=messages)

@staticmethod
def _msg_to_dict(msg: Message) -> Dict[str, Any]:
"""Convert a Message to a plain dict for SessionLog.

Preserves ``prompt_tokens`` and ``completion_tokens`` individually
so that :class:`ContextAssembler` strategies can leverage API-reported
usage data for accurate overflow detection.
"""
d: Dict[str, Any] = {'role': msg.role, 'content': msg.content or ''}
if msg.tool_calls:
d['tool_calls'] = msg.tool_calls
if hasattr(msg, 'tool_call_id') and msg.tool_call_id:
d['tool_call_id'] = msg.tool_call_id
if hasattr(msg, 'name') and msg.name:
d['name'] = msg.name
prompt_tokens = int(getattr(msg, 'prompt_tokens', 0) or 0)
completion_tokens = int(getattr(msg, 'completion_tokens', 0) or 0)
if prompt_tokens:
d['prompt_tokens'] = prompt_tokens
if completion_tokens:
d['completion_tokens'] = completion_tokens
if prompt_tokens or completion_tokens:
d['tokens'] = prompt_tokens + completion_tokens
return d

async def run_loop(self, messages: Union[List[Message], str],
**kwargs) -> AsyncGenerator[Any, Any]:
"""
Expand All @@ -1112,13 +1296,28 @@ async def run_loop(self, messages: Union[List[Message], str],
await self.load_memory()
await self.prepare_rag()
await self.prepare_knowledge_search()
self._init_session_log()
self.runtime.tag = self.tag

if messages is None:
messages = self.query

# Load history and restore state
self.config, self.runtime, messages = self.read_history(messages)
if self.session_log is not None:
restored = self.session_log.get_all_messages()
if restored and self.load_cache:
from ms_agent.llm.utils import Message as _Msg
messages = [_Msg(
role=m.get('role', 'user'),
content=m.get('content', ''),
tool_calls=m.get('tool_calls'),
) for m in restored]
else:
self.config, self.runtime, messages = self.read_history(
messages)
else:
self.config, self.runtime, messages = self.read_history(
messages)

if self.runtime.round == 0:
# New task: create standardized messages first
Expand All @@ -1137,14 +1336,30 @@ async def run_loop(self, messages: Union[List[Message], str],
await self.do_rag(messages)
await self.on_task_begin(messages)

# Seed SessionLog with initial messages
if self.session_log is not None:
for msg in messages:
self.session_log.append(self._msg_to_dict(msg))

for message in messages:
if message.role != 'system':
self.log_output('[' + message.role + ']:')
self.log_output(message.content)
while not self.runtime.should_stop:
# Rebuild context view from SessionLog (non-destructive compression)
if self.context_assembler is not None and self.runtime.round > 0:
messages = self.context_assembler.assemble()

pre_step_len = len(messages)
async for messages in self.step(messages):
yield messages
self.runtime.round += 1

# Append new messages to SessionLog
if self.session_log is not None:
for msg in messages[pre_step_len:]:
self.session_log.append(self._msg_to_dict(msg))

# save memory and history
await self.add_memory(
messages, add_type='add_after_step', **kwargs)
Expand All @@ -1153,13 +1368,17 @@ async def run_loop(self, messages: Union[List[Message], str],
# +1 means the next round the assistant may give a conclusion
if self.runtime.round >= self.max_chat_round + 1:
if not self.runtime.should_stop:
messages.append(
Message(
role='assistant',
content=
f'Task {messages[1].content} was cutted off, because '
f'max round({self.max_chat_round}) exceeded.',
))
cutoff_msg = Message(
role='assistant',
content=
f'Task {messages[1].content} was cutted off, because '
f'max round({self.max_chat_round}) exceeded.',
)
messages.append(cutoff_msg)
if self.session_log is not None:
self.session_log.append(
self._msg_to_dict(cutoff_msg))
self.save_history(messages)
self.runtime.should_stop = True
yield messages

Expand Down
2 changes: 2 additions & 0 deletions ms_agent/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse

from ms_agent.cli.app import AppCMD
from ms_agent.cli.cron import CronCMD
from ms_agent.cli.run import RunCMD
from ms_agent.cli.ui import UICMD

Expand All @@ -20,6 +21,7 @@ def run_cmd():
RunCMD.define_args(subparsers)
AppCMD.define_args(subparsers)
UICMD.define_args(subparsers)
CronCMD.define_args(subparsers)

# unknown args will be handled in config.py
args, _ = parser.parse_known_args()
Expand Down
Loading
Loading