-
Notifications
You must be signed in to change notification settings - Fork 477
feat: add ContextCompressor for context overflow handling #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c633b1c
aec1543
0280964
0702ce8
89d7e4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||
| from .code_condenser import CodeCondenser | ||
| from .context_compressor import ContextCompressor | ||
| from .refine_condenser import RefineCondenser | ||
|
|
||
| __all__ = ['CodeCondenser', 'RefineCondenser', 'ContextCompressor'] |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,219 @@ | ||||||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||||||
|
|
||||||||||
| from typing import List, Optional | ||||||||||
|
|
||||||||||
| import json | ||||||||||
| from ms_agent.llm import LLM, Message | ||||||||||
| from ms_agent.memory import Memory | ||||||||||
| from ms_agent.utils.logger import logger | ||||||||||
|
|
||||||||||
| # Default summary prompt template (from opencode) | ||||||||||
| SUMMARY_PROMPT = """Summarize this conversation to help continue the work. | ||||||||||
|
|
||||||||||
| Focus on: | ||||||||||
| - Goal: What is the user trying to accomplish? | ||||||||||
| - Instructions: Important user requirements or constraints | ||||||||||
| - Discoveries: Notable findings during the conversation | ||||||||||
| - Accomplished: What's done, in progress, and remaining | ||||||||||
| - Relevant files: Files read, edited, or created | ||||||||||
|
|
||||||||||
| Keep it concise but comprehensive enough for another agent to continue.""" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ContextCompressor(Memory): | ||||||||||
| """Context Compressor - Inspired by opencode's context compaction mechanism. | ||||||||||
|
|
||||||||||
| Core concepts: | ||||||||||
| 1. Token overflow detection - Monitor token usage against context limits | ||||||||||
| 2. Tool output pruning - Compress old tool call outputs to save context | ||||||||||
| 3. Summary compaction - Use LLM to generate conversation summary | ||||||||||
|
|
||||||||||
| Reference: opencode/packages/opencode/src/session/compaction.ts | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| def __init__(self, config): | ||||||||||
| super().__init__(config) | ||||||||||
| mem_config = getattr(config.memory, 'context_compressor', None) | ||||||||||
| if mem_config is None: | ||||||||||
| mem_config = config.memory | ||||||||||
|
|
||||||||||
| # Token thresholds (inspired by opencode's PRUNE constants) | ||||||||||
| self.context_limit = getattr(mem_config, 'context_limit', 128000) | ||||||||||
| self.prune_protect = getattr(mem_config, 'prune_protect', 40000) | ||||||||||
| self.prune_minimum = getattr(mem_config, 'prune_minimum', 20000) | ||||||||||
| self.reserved_buffer = getattr(mem_config, 'reserved_buffer', 20000) | ||||||||||
|
|
||||||||||
| # Summary prompt | ||||||||||
| self.summary_prompt = getattr(mem_config, 'summary_prompt', | ||||||||||
| SUMMARY_PROMPT) | ||||||||||
|
|
||||||||||
| # LLM for summarization | ||||||||||
| self.llm: Optional[LLM] = None | ||||||||||
| if getattr(mem_config, 'enable_summary', True): | ||||||||||
| try: | ||||||||||
| self.llm = LLM.from_config(config) | ||||||||||
| except Exception as e: | ||||||||||
| logger.warning(f'Failed to init LLM for summary: {e}') | ||||||||||
|
|
||||||||||
| def estimate_tokens(self, text: str) -> int: | ||||||||||
| """Estimate token count from text. | ||||||||||
| Simple heuristic: ~4 chars per token for mixed content. | ||||||||||
| """ | ||||||||||
| if not text: | ||||||||||
| return 0 | ||||||||||
| return len(text) // 4 | ||||||||||
|
|
||||||||||
| def _estimate_message_tokens_from_content(self, msg: Message) -> int: | ||||||||||
| """Heuristic token count from message body (no API usage fields).""" | ||||||||||
| total = 0 | ||||||||||
| if msg.content: | ||||||||||
| content = msg.content if isinstance( | ||||||||||
| msg.content, str) else json.dumps( | ||||||||||
| msg.content, ensure_ascii=False) | ||||||||||
| total += self.estimate_tokens(content) | ||||||||||
| if msg.tool_calls: | ||||||||||
| total += self.estimate_tokens(json.dumps(msg.tool_calls)) | ||||||||||
| if msg.reasoning_content: | ||||||||||
| total += self.estimate_tokens(msg.reasoning_content) | ||||||||||
| return total | ||||||||||
|
|
||||||||||
| def estimate_message_tokens(self, msg: Message) -> int: | ||||||||||
| """Tokens for one message: prefer ``Message`` usage, else content heuristic.""" | ||||||||||
| pt = int(getattr(msg, 'prompt_tokens', 0) or 0) | ||||||||||
| ct = int(getattr(msg, 'completion_tokens', 0) or 0) | ||||||||||
| if pt or ct: | ||||||||||
| return pt + ct | ||||||||||
| return self._estimate_message_tokens_from_content(msg) | ||||||||||
|
|
||||||||||
| def estimate_total_tokens(self, messages: List[Message]) -> int: | ||||||||||
| """Total tokens for the conversation.""" | ||||||||||
| last_usage_idx = -1 | ||||||||||
| for i in range(len(messages) - 1, -1, -1): | ||||||||||
| m = messages[i] | ||||||||||
| if m.role != 'assistant': | ||||||||||
| continue | ||||||||||
| pt = int(getattr(m, 'prompt_tokens', 0) or 0) | ||||||||||
| ct = int(getattr(m, 'completion_tokens', 0) or 0) | ||||||||||
| if pt or ct: | ||||||||||
| last_usage_idx = i | ||||||||||
| break | ||||||||||
| if last_usage_idx >= 0: | ||||||||||
| m = messages[last_usage_idx] | ||||||||||
| base = int(getattr(m, 'prompt_tokens', 0) or 0) + int( | ||||||||||
| getattr(m, 'completion_tokens', 0) or 0) | ||||||||||
| tail = sum( | ||||||||||
| self._estimate_message_tokens_from_content(x) | ||||||||||
| for x in messages[last_usage_idx + 1:]) | ||||||||||
| return base + tail | ||||||||||
| return sum(self.estimate_message_tokens(m) for m in messages) | ||||||||||
|
|
||||||||||
| def is_overflow(self, messages: List[Message]) -> bool: | ||||||||||
| """Check if messages exceed context limit.""" | ||||||||||
| total = self.estimate_total_tokens(messages) | ||||||||||
| usable = self.context_limit - self.reserved_buffer | ||||||||||
| return total >= usable | ||||||||||
|
|
||||||||||
| def prune_tool_outputs(self, messages: List[Message]) -> List[Message]: | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数的prune操作过程感觉统一一下风格会不会更好?比如按照LLMAgent内的风格统一改成in-place操作修改msg.content=xxx。 |
||||||||||
| """Prune old tool outputs to reduce context size. | ||||||||||
|
|
||||||||||
| Strategy (from opencode): | ||||||||||
| - Scan backwards through messages | ||||||||||
| - Protect the most recent tool outputs (prune_protect tokens) | ||||||||||
| - Truncate older tool outputs | ||||||||||
| """ | ||||||||||
| total_tool_tokens = 0 | ||||||||||
| pruned_count = 0 | ||||||||||
|
|
||||||||||
| for idx in range(len(messages) - 1, -1, -1): | ||||||||||
| msg = messages[idx] | ||||||||||
| if msg.role != 'tool' or not msg.content: | ||||||||||
| continue | ||||||||||
| content_str = msg.content if isinstance( | ||||||||||
| msg.content, str) else json.dumps( | ||||||||||
| msg.content, ensure_ascii=False) | ||||||||||
| tokens = self.estimate_tokens(content_str) | ||||||||||
| total_tool_tokens += tokens | ||||||||||
|
|
||||||||||
| if total_tool_tokens > self.prune_protect: | ||||||||||
| msg.content = '[Output truncated to save context]' | ||||||||||
| pruned_count += 1 | ||||||||||
|
|
||||||||||
| if pruned_count > 0: | ||||||||||
| logger.info(f'Pruned {pruned_count} tool outputs') | ||||||||||
|
|
||||||||||
| return messages | ||||||||||
|
|
||||||||||
| def summarize(self, messages: List[Message]) -> Optional[str]: | ||||||||||
| """Generate conversation summary using LLM.""" | ||||||||||
| if not self.llm: | ||||||||||
| return None | ||||||||||
|
|
||||||||||
| # Build conversation text for summarization | ||||||||||
| conv_parts = [] | ||||||||||
| for msg in messages: | ||||||||||
| role = msg.role.upper() | ||||||||||
| content = msg.content if isinstance(msg.content, str) else str( | ||||||||||
| msg.content) | ||||||||||
| if content: | ||||||||||
| conv_parts.append(f'{role}: {content[:2000]}') | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The value |
||||||||||
|
|
||||||||||
| conversation = '\n'.join(conv_parts) | ||||||||||
| query = f'{self.summary_prompt}\n\n---\n{conversation}' | ||||||||||
|
|
||||||||||
| try: | ||||||||||
| response = self.llm.generate([Message(role='user', content=query)], | ||||||||||
| stream=False) | ||||||||||
| return response.content | ||||||||||
| except Exception as e: | ||||||||||
| logger.error(f'Summary generation failed: {e}') | ||||||||||
|
Comment on lines
+167
to
+168
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This |
||||||||||
| return None | ||||||||||
|
|
||||||||||
| def compress(self, messages: List[Message]) -> List[Message]: | ||||||||||
| """Compress messages when context overflows. | ||||||||||
|
|
||||||||||
| Steps: | ||||||||||
| 1. Try pruning tool outputs first | ||||||||||
| 2. If still overflow, generate summary and replace history | ||||||||||
| """ | ||||||||||
| if not self.is_overflow(messages): | ||||||||||
| return messages | ||||||||||
|
|
||||||||||
| logger.info('Context overflow detected, starting compression') | ||||||||||
|
|
||||||||||
| # Step 1: Prune tool outputs | ||||||||||
| pruned = self.prune_tool_outputs(messages) | ||||||||||
| if not self.is_overflow(pruned): | ||||||||||
| return pruned | ||||||||||
|
|
||||||||||
| # Step 2: Generate summary | ||||||||||
| summary = self.summarize(messages) | ||||||||||
| if not summary: | ||||||||||
| logger.warning('Summary failed, returning pruned messages') | ||||||||||
| return pruned | ||||||||||
|
|
||||||||||
| # Keep system prompt and replace history with summary | ||||||||||
| result = [] | ||||||||||
| for msg in messages: | ||||||||||
| if msg.role == 'system': | ||||||||||
| result.append(msg) | ||||||||||
| break | ||||||||||
|
|
||||||||||
| result.append( | ||||||||||
| Message( | ||||||||||
| role='user', | ||||||||||
| content=f'[Conversation Summary]\n{summary}\n\n' | ||||||||||
| 'Please continue based on this summary.')) | ||||||||||
|
|
||||||||||
| # Keep the most recent user message if different | ||||||||||
| if messages and messages[-1].role == 'user': | ||||||||||
| last_user = messages[-1] | ||||||||||
| if last_user.content and last_user.content != result[-1].content: | ||||||||||
| result.append(last_user) | ||||||||||
|
Comment on lines
+210
to
+211
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition
Suggested change
|
||||||||||
|
|
||||||||||
| logger.info( | ||||||||||
| f'Compressed {len(messages)} messages to {len(result)} messages') | ||||||||||
| return result | ||||||||||
|
|
||||||||||
| async def run(self, messages: List[Message]) -> List[Message]: | ||||||||||
| """Main entry point for context compression.""" | ||||||||||
| return self.compress(messages) | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a generic
Exceptionis generally discouraged as it can mask unexpected errors during LLM initialization and make debugging difficult. It's better to catch more specific exceptions that you anticipateLLM.from_config()might raise, such as configuration or import errors. This makes the error handling more robust and intentional.