Skip to content

Commit 2299b20

Browse files
authored
feat: add ContextCompressor for context overflow handling (#885)
1 parent 827bc07 commit 2299b20

3 files changed

Lines changed: 227 additions & 0 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
from .code_condenser import CodeCondenser
3+
from .context_compressor import ContextCompressor
4+
from .refine_condenser import RefineCondenser
5+
6+
__all__ = ['CodeCondenser', 'RefineCondenser', 'ContextCompressor']
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
3+
from typing import List, Optional
4+
5+
import json
6+
from ms_agent.llm import LLM, Message
7+
from ms_agent.memory import Memory
8+
from ms_agent.utils.logger import logger
9+
10+
# Default summary prompt template (from opencode).
# Used by ContextCompressor when compacting an overflowing conversation;
# callers can override it via the `summary_prompt` field of the memory config.
SUMMARY_PROMPT = """Summarize this conversation to help continue the work.

Focus on:
- Goal: What is the user trying to accomplish?
- Instructions: Important user requirements or constraints
- Discoveries: Notable findings during the conversation
- Accomplished: What's done, in progress, and remaining
- Relevant files: Files read, edited, or created

Keep it concise but comprehensive enough for another agent to continue."""
21+
22+
23+
class ContextCompressor(Memory):
    """Context Compressor - Inspired by opencode's context compaction mechanism.

    Core concepts:
        1. Token overflow detection - Monitor token usage against context limits
        2. Tool output pruning - Compress old tool call outputs to save context
        3. Summary compaction - Use LLM to generate conversation summary

    Reference: opencode/packages/opencode/src/session/compaction.ts
    """

    def __init__(self, config):
        """Read thresholds from ``config.memory.context_compressor`` (falling
        back to ``config.memory``) and optionally build an LLM for summaries.
        """
        super().__init__(config)
        mem_config = getattr(config.memory, 'context_compressor', None)
        if mem_config is None:
            mem_config = config.memory

        # Token thresholds (inspired by opencode's PRUNE constants)
        self.context_limit = getattr(mem_config, 'context_limit', 128000)
        self.prune_protect = getattr(mem_config, 'prune_protect', 40000)
        self.prune_minimum = getattr(mem_config, 'prune_minimum', 20000)
        self.reserved_buffer = getattr(mem_config, 'reserved_buffer', 20000)

        # Summary prompt (overridable per-config)
        self.summary_prompt = getattr(mem_config, 'summary_prompt',
                                      SUMMARY_PROMPT)

        # LLM for summarization. If construction fails, compression degrades
        # gracefully to pruning-only (summarize() returns None).
        self.llm: Optional[LLM] = None
        if getattr(mem_config, 'enable_summary', True):
            try:
                self.llm = LLM.from_config(config)
            except Exception as e:
                logger.warning(f'Failed to init LLM for summary: {e}')

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count from text.

        Simple heuristic: ~4 chars per token for mixed content.
        """
        if not text:
            return 0
        return len(text) // 4

    def _estimate_message_tokens_from_content(self, msg: Message) -> int:
        """Heuristic token count from message body (no API usage fields)."""
        total = 0
        if msg.content:
            content = msg.content if isinstance(
                msg.content, str) else json.dumps(
                    msg.content, ensure_ascii=False)
            total += self.estimate_tokens(content)
        if msg.tool_calls:
            # Tool calls may be arbitrary objects rather than plain dicts;
            # `default=str` keeps this size estimate from raising TypeError.
            total += self.estimate_tokens(
                json.dumps(msg.tool_calls, ensure_ascii=False, default=str))
        if msg.reasoning_content:
            total += self.estimate_tokens(msg.reasoning_content)
        return total

    def estimate_message_tokens(self, msg: Message) -> int:
        """Tokens for one message: prefer ``Message`` usage, else content heuristic."""
        pt = int(getattr(msg, 'prompt_tokens', 0) or 0)
        ct = int(getattr(msg, 'completion_tokens', 0) or 0)
        if pt or ct:
            return pt + ct
        return self._estimate_message_tokens_from_content(msg)

    def estimate_total_tokens(self, messages: List[Message]) -> int:
        """Total tokens for the conversation.

        Anchors on the most recent assistant message that carries real API
        usage (its prompt_tokens already cover the whole prompt up to that
        point), then adds a content heuristic for the messages after it.
        Falls back to a per-message heuristic when no usage is available.
        """
        last_usage_idx = -1
        for i in range(len(messages) - 1, -1, -1):
            m = messages[i]
            if m.role != 'assistant':
                continue
            pt = int(getattr(m, 'prompt_tokens', 0) or 0)
            ct = int(getattr(m, 'completion_tokens', 0) or 0)
            if pt or ct:
                last_usage_idx = i
                break
        if last_usage_idx >= 0:
            m = messages[last_usage_idx]
            base = int(getattr(m, 'prompt_tokens', 0) or 0) + int(
                getattr(m, 'completion_tokens', 0) or 0)
            tail = sum(
                self._estimate_message_tokens_from_content(x)
                for x in messages[last_usage_idx + 1:])
            return base + tail
        return sum(self.estimate_message_tokens(m) for m in messages)

    def is_overflow(self, messages: List[Message]) -> bool:
        """Check if messages exceed context limit (minus reserved buffer)."""
        total = self.estimate_total_tokens(messages)
        usable = self.context_limit - self.reserved_buffer
        return total >= usable

    def prune_tool_outputs(self, messages: List[Message]) -> List[Message]:
        """Prune old tool outputs to reduce context size.

        Strategy (from opencode):
        - Scan backwards through messages
        - Protect the most recent tool outputs (``prune_protect`` tokens)
        - Only prune when at least ``prune_minimum`` tokens would be
          reclaimed; otherwise the mutation isn't worth it
        - Truncate older tool outputs in place

        Note: mutates ``messages`` in place and returns the same list.
        """
        # Pass 1 (backwards): find tool outputs beyond the protected window
        # and measure how many tokens truncating them would reclaim.
        seen_tool_tokens = 0
        prunable_idx: List[int] = []
        reclaimable = 0
        for idx in range(len(messages) - 1, -1, -1):
            msg = messages[idx]
            if msg.role != 'tool' or not msg.content:
                continue
            content_str = msg.content if isinstance(
                msg.content, str) else json.dumps(
                    msg.content, ensure_ascii=False)
            tokens = self.estimate_tokens(content_str)
            seen_tool_tokens += tokens
            if seen_tool_tokens > self.prune_protect:
                prunable_idx.append(idx)
                reclaimable += tokens

        # Honor opencode's PRUNE_MINIMUM: skip when savings are too small.
        if reclaimable < self.prune_minimum:
            return messages

        # Pass 2: truncate the selected outputs.
        for idx in prunable_idx:
            messages[idx].content = '[Output truncated to save context]'
        if prunable_idx:
            logger.info(f'Pruned {len(prunable_idx)} tool outputs')

        return messages

    def summarize(self, messages: List[Message]) -> Optional[str]:
        """Generate conversation summary using LLM.

        Returns None when no LLM is configured or generation fails.
        """
        if not self.llm:
            return None

        # Build conversation text for summarization; cap each message at
        # 2000 chars so the summary request itself stays bounded.
        conv_parts = []
        for msg in messages:
            role = msg.role.upper()
            content = msg.content if isinstance(msg.content, str) else str(
                msg.content)
            if content:
                conv_parts.append(f'{role}: {content[:2000]}')

        conversation = '\n'.join(conv_parts)
        query = f'{self.summary_prompt}\n\n---\n{conversation}'

        try:
            response = self.llm.generate([Message(role='user', content=query)],
                                         stream=False)
            return response.content
        except Exception as e:
            logger.error(f'Summary generation failed: {e}')
            return None

    def compress(self, messages: List[Message]) -> List[Message]:
        """Compress messages when context overflows.

        Steps:
        1. Try pruning tool outputs first
        2. If still overflow, generate summary and replace history
        """
        if not self.is_overflow(messages):
            return messages

        logger.info('Context overflow detected, starting compression')

        # Step 1: Prune tool outputs (mutates `messages` in place).
        pruned = self.prune_tool_outputs(messages)
        if not self.is_overflow(pruned):
            return pruned

        # Step 2: Generate summary (of the already-pruned history).
        summary = self.summarize(messages)
        if not summary:
            logger.warning('Summary failed, returning pruned messages')
            return pruned

        # Keep the first system prompt and replace history with the summary.
        result = []
        for msg in messages:
            if msg.role == 'system':
                result.append(msg)
                break

        result.append(
            Message(
                role='user',
                content=f'[Conversation Summary]\n{summary}\n\n'
                'Please continue based on this summary.'))

        # Keep the most recent user message if different from the summary.
        if messages and messages[-1].role == 'user':
            last_user = messages[-1]
            if last_user.content and last_user.content != result[-1].content:
                result.append(last_user)

        logger.info(
            f'Compressed {len(messages)} messages to {len(result)} messages')
        return result

    async def run(self, messages: List[Message]) -> List[Message]:
        """Main entry point for context compression."""
        return self.compress(messages)

ms_agent/memory/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from omegaconf import DictConfig, OmegaConf
33

44
from .condenser.code_condenser import CodeCondenser
5+
from .condenser.context_compressor import ContextCompressor
56
from .condenser.refine_condenser import RefineCondenser
67
from .default_memory import DefaultMemory
78
from .diversity import Diversity
@@ -11,6 +12,7 @@
1112
'diversity': Diversity,
1213
'code_condenser': CodeCondenser,
1314
'refine_condenser': RefineCondenser,
15+
'context_compressor': ContextCompressor,
1416
}
1517

1618

0 commit comments

Comments
 (0)