Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions astrbot/core/agent/context/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ def should_compress(
return usage_rate > self.compression_threshold

async def __call__(self, messages: list[Message]) -> list[Message]:
    """Shrink the context by dropping the oldest conversation turns.

    Args:
        messages: The full message history.

    Returns:
        The history with the oldest ``self.truncate_turns`` turns removed.
    """
    return ContextTruncator().truncate_by_dropping_oldest_turns(
        messages,
        drop_turns=self.truncate_turns,
    )


Expand Down Expand Up @@ -143,9 +146,30 @@ def split_history(
return system_messages, messages_to_summarize, recent_messages


def _generate_summary_cache_key(messages: list[Message]) -> str:
    """Generate a collision-resistant cache key for a summary.

    Hashes the role and the FULL content of every message (plus the
    message count), so two different histories practically never map to
    the same key. The previous scheme used only the first 50 characters
    of each message, so histories sharing a common prefix collided and
    reused a stale summary.

    Args:
        messages: The messages whose summary is being cached.

    Returns:
        A sha256 hex-digest key, or "" for an empty message list.
    """
    import hashlib

    if not messages:
        return ""

    hasher = hashlib.sha256()
    # Fold in the message count to further separate histories.
    hasher.update(str(len(messages)).encode("utf-8"))
    for msg in messages:
        content = msg.content if isinstance(msg.content, str) else str(msg.content)
        # Separator bytes keep (role, content) pairs unambiguous.
        hasher.update(b"\x00")
        hasher.update(msg.role.encode("utf-8"))
        hasher.update(b"\x01")
        hasher.update(content.encode("utf-8"))
    return hasher.hexdigest()


class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.

Optimizations:
- 支持增量摘要,只摘要超出的部分
- 添加摘要缓存避免重复摘要
- 支持自定义摘要提示词
"""

def __init__(
Expand Down Expand Up @@ -175,6 +199,10 @@ def __init__(
"4. Write the summary in the user's language.\n"
)

# 新增: 摘要缓存
self._summary_cache: dict[str, str] = {}
self._max_cache_size = 50

def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
Expand All @@ -200,6 +228,10 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].

Optimizations:
- 添加摘要缓存
- 检查是否已有摘要,避免重复生成
"""
if len(messages) <= self.keep_recent + 1:
return messages
Expand All @@ -211,17 +243,37 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
if not messages_to_summarize:
return messages

# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# 生成缓存键
cache_key = _generate_summary_cache_key(messages_to_summarize)

# 尝试从缓存获取摘要
summary_content = None
if cache_key in self._summary_cache:
summary_content = self._summary_cache[cache_key]
logger.debug("Using cached summary")

# 如果缓存没有,生成新摘要
if summary_content is None:
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text

# 缓存摘要
if len(self._summary_cache) < self._max_cache_size:
self._summary_cache[cache_key] = summary_content
else:
# 简单的缓存淘汰
self._summary_cache.pop(next(iter(self._summary_cache)))
self._summary_cache[cache_key] = summary_content

except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages

# build result
result = []
Expand All @@ -243,3 +295,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
result.extend(recent_messages)

return result

def clear_cache(self) -> None:
    """Clear the summary cache."""
    self._summary_cache.clear()
99 changes: 88 additions & 11 deletions astrbot/core/agent/context/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@


class ContextManager:
"""Context compression manager."""
"""Context compression manager.

Optimizations:
- 减少重复 token 计算
- 添加增量压缩支持
- 优化日志输出
"""

def __init__(
self,
Expand Down Expand Up @@ -41,13 +47,27 @@
truncate_turns=config.truncate_turns
)

# 缓存上一次计算的消息指纹和 token 数
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider removing the new caching/fingerprinting logic from ContextManager and delegating all token-count caching to EstimateTokenCounter while preserving compression behavior and stats.

You can simplify the new logic by dropping the duplicated caching/fingerprinting in ContextManager and relying solely on EstimateTokenCounter’s internal cache, while keeping compression behavior and stats.

1. Remove redundant cache state and private‑method coupling

Remove the extra fields and _get_messages_fingerprint, and their usages:

class ContextManager:
    def __init__(self, config: ContextConfig) -> None:
        ...
        # Remove these:
        # self._last_messages_fingerprint: int | None = None
        # self._last_token_count: int | None = None
        self._compression_count = 0

    # Remove this method entirely
    # def _get_messages_fingerprint(self, messages: list[Message]) -> int:
    #     ...

2. Simplify process to always use token_counter.count_tokens

This keeps all behavior (including trusted_token_usage) but delegates caching to EstimateTokenCounter:

async def process(
    self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
    try:
        result = messages

        if self.config.enforce_max_turns != -1:
            result = self.truncator.truncate_by_turns(
                result,
                keep_most_recent_turns=self.config.enforce_max_turns,
                drop_turns=self.config.truncate_turns,
            )

        if self.config.max_context_tokens > 0:
            if trusted_token_usage > 0:
                total_tokens = trusted_token_usage
            else:
                total_tokens = self.token_counter.count_tokens(result)

            if self.compressor.should_compress(
                result, total_tokens, self.config.max_context_tokens
            ):
                result = await self._run_compression(result, total_tokens)

        return result
    except Exception as e:
        logger.error(f"Error during context processing: {e}", exc_info=True)
        return messages

3. Simplify _run_compression and avoid manual cache updates

Just count tokens via the counter; its cache will make repeated calls cheap:

async def _run_compression(
    self, messages: list[Message], prev_tokens: int
) -> list[Message]:
    logger.debug("Compress triggered, starting compression...")
    self._compression_count += 1

    messages = await self.compressor(messages)

    tokens_after_compression = self.token_counter.count_tokens(messages)

    compress_rate = (
        tokens_after_compression / self.config.max_context_tokens
    ) * 100

    logger.info(
        f"Compress #{self._compression_count} completed."
        f" {prev_tokens} -> {tokens_after_compression} tokens,"
        f" compression rate: {compress_rate:.2f}%.",
    )

    if self.compressor.should_compress(
        messages, tokens_after_compression, self.config.max_context_tokens
    ):
        logger.info(
            "Context still exceeds max tokens after compression, "
            "applying halving truncation..."
        )
        messages = self.truncator.truncate_by_halving(messages)

    return messages

4. Keep stats, but avoid parallel cache tracking

You can still expose compression_count and use the token counter’s own stats, without tracking a parallel cache in ContextManager:

def get_stats(self) -> dict:
    stats = {
        "compression_count": self._compression_count,
    }
    if hasattr(self.token_counter, "get_cache_stats"):
        stats["token_counter_cache"] = self.token_counter.get_cache_stats()
    return stats

def reset_stats(self) -> None:
    self._compression_count = 0
    if hasattr(self.token_counter, "clear_cache"):
        self.token_counter.clear_cache()

This retains all functional behavior (compression flow, logging, trusted_token_usage, observability) but removes duplicated caching logic, state sync branches, and reliance on a private _get_cache_key method.

self._last_messages_fingerprint: int | None = None
self._last_token_count: int | None = None
self._compression_count = 0

def _get_messages_fingerprint(self, messages: list[Message]) -> int:
    """Compute a fingerprint of the message list to detect content changes.

    NOTE(review): this reaches into the token counter's private
    ``_get_cache_key`` — a public accessor would keep the abstraction
    boundary clean; confirm before relying on it long-term.
    """
    # An empty history always maps to the sentinel fingerprint 0.
    return self.token_counter._get_cache_key(messages) if messages else 0
Comment on lines +60 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Using a private method for the fingerprint couples ContextManager tightly to token_counter internals.

Since this calls self.token_counter._get_cache_key(messages), it depends on a private API that may change when the token counter is refactored (e.g., cache key structure or implementation). Please either add and use a public get_cache_key/fingerprint on the token counter, or compute the fingerprint here using only the public Message interface to keep the abstraction boundary clean.

Suggested implementation:

        # 使用 token counter 的公开缓存键方法作为指纹
        return self.token_counter.get_cache_key(messages)

You will also need to:

  1. Add a public method (e.g. get_cache_key) on the token counter class, most likely in astrbot/core/agent/context/token_counter.py, with a signature like:
    def get_cache_key(self, messages: list[Message]) -> int:
        return self._get_cache_key(messages)
  2. Ensure that this new public method is part of the token counter's stable API and replace any other direct usages of _get_cache_key (if present) with get_cache_key for consistency and to keep the abstraction boundary clean.


async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.

Args:
messages: The original message list.
trusted_token_usage: The total token usage that LLM API returned.

Returns:
The processed message list.
Expand All @@ -65,14 +85,34 @@

# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
# 优化: 使用缓存的 token 计数或计算新值
current_fingerprint = self._get_messages_fingerprint(messages)

if trusted_token_usage > 0:
total_tokens = trusted_token_usage
elif (
self._last_messages_fingerprint is not None
and self._last_messages_fingerprint == current_fingerprint
and self._last_token_count is not None
):
# 消息内容没变化,使用缓存的 token 计数
total_tokens = self._last_token_count
else:
# 消息内容变了,需要重新计算
total_tokens = self.token_counter.count_tokens(result)
self._last_messages_fingerprint = current_fingerprint

# 更新缓存
self._last_token_count = total_tokens

if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
# 压缩后更新指纹
self._last_messages_fingerprint = self._get_messages_fingerprint(
result
)

return result
except Exception as e:
Expand All @@ -94,27 +134,64 @@
"""
logger.debug("Compress triggered, starting compression...")

self._compression_count += 1

messages = await self.compressor(messages)

# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# 优化: 压缩后只计算一次 token
tokens_after_compression = self.token_counter.count_tokens(messages)

# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
compress_rate = (
Comment on lines 143 to +145

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

This expression logs
sensitive data (secret)
as clear text.
This expression logs
sensitive data (secret)
as clear text.
tokens_after_compression / self.config.max_context_tokens
) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f"Compress #{self._compression_count} completed."
f" {prev_tokens} -> {tokens_after_compression} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)

# last check
# 更新缓存
self._last_token_count = tokens_after_compression
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

# last check - 优化: 减少不必要的递归调用
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
messages, tokens_after_compression, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
# 更新缓存
self._last_token_count = self.token_counter.count_tokens(messages)
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

return messages

def get_stats(self) -> dict:
    """Return statistics about the context manager.

    Returns:
        A dict with the compression count, the last cached token count
        and message fingerprint, plus the token counter's own cache
        stats when it exposes them.
    """
    stats: dict = {
        "compression_count": self._compression_count,
        "last_token_count": self._last_token_count,
        "last_messages_fingerprint": self._last_messages_fingerprint,
    }

    # Only counters that implement get_cache_stats contribute this entry.
    counter = self.token_counter
    if hasattr(counter, "get_cache_stats"):
        stats["token_counter_cache"] = counter.get_cache_stats()

    return stats

def reset_stats(self) -> None:
    """Reset all statistics and cached token-count state."""
    self._compression_count = 0
    self._last_token_count = None
    self._last_messages_fingerprint = None

    # Clear the token counter's cache too, when it supports that.
    counter = self.token_counter
    if hasattr(counter, "clear_cache"):
        counter.clear_cache()
Loading
Loading