From 34b527ccc16c98d8d90940320c76d855a2221ee5 Mon Sep 17 00:00:00 2001 From: qer Date: Wed, 20 May 2026 17:21:38 +0800 Subject: [PATCH 1/5] fix(kimi): clamp completion budget dynamically --- docs/en/configuration/env-vars.md | 12 ++- docs/zh/configuration/env-vars.md | 13 ++- .../kosong/src/kosong/chat_provider/kimi.py | 20 ++-- .../tests/api_snapshot_tests/test_kimi.py | 16 ++- src/kimi_cli/llm.py | 30 +++++- src/kimi_cli/soul/compaction.py | 4 +- src/kimi_cli/soul/kimisoul.py | 39 +++++++- tests/core/test_create_llm.py | 50 +++++++++- tests/core/test_kimisoul_completion_budget.py | 98 +++++++++++++++++++ 9 files changed, 254 insertions(+), 28 deletions(-) create mode 100644 tests/core/test_kimisoul_completion_budget.py diff --git a/docs/en/configuration/env-vars.md b/docs/en/configuration/env-vars.md index 43707c76d..3fa107d82 100644 --- a/docs/en/configuration/env-vars.md +++ b/docs/en/configuration/env-vars.md @@ -17,7 +17,8 @@ The following environment variables take effect when using `kimi` type providers | `KIMI_MODEL_CAPABILITIES` | Model capabilities, comma-separated (e.g., `thinking,image_in`) | | `KIMI_MODEL_TEMPERATURE` | Generation parameter `temperature` | | `KIMI_MODEL_TOP_P` | Generation parameter `top_p` | -| `KIMI_MODEL_MAX_TOKENS` | Generation parameter `max_tokens` | +| `KIMI_MODEL_MAX_COMPLETION_TOKENS` | Generation parameter `max_completion_tokens` | +| `KIMI_MODEL_MAX_TOKENS` | Compatibility alias for `KIMI_MODEL_MAX_COMPLETION_TOKENS` | | `KIMI_MODEL_THINKING_KEEP` | Moonshot `thinking.keep` switch for preserved thinking (only applied when thinking mode is active) | ### `KIMI_BASE_URL` @@ -76,14 +77,17 @@ Sets the generation parameter `top_p` (nucleus sampling), controlling output div export KIMI_MODEL_TOP_P="0.9" ``` -### `KIMI_MODEL_MAX_TOKENS` +### `KIMI_MODEL_MAX_COMPLETION_TOKENS` -Sets the generation parameter `max_tokens`, limiting the maximum tokens per response. +Sets the generation parameter `max_completion_tokens`, limiting the maximum tokens per response. ```sh -export KIMI_MODEL_MAX_TOKENS="4096" +export KIMI_MODEL_MAX_COMPLETION_TOKENS="4096" ``` +`KIMI_MODEL_MAX_TOKENS` is still accepted. If both variables are set, +`KIMI_MODEL_MAX_COMPLETION_TOKENS` takes precedence. + ### `KIMI_MODEL_THINKING_KEEP` Forwards the value verbatim to the Moonshot API as `thinking.keep`, enabling Preserved Thinking (see the [Moonshot docs](https://platform.kimi.com/docs/guide/use-kimi-k2-thinking-model#preserved-thinking)). Setting it to `all` causes the provider to preserve the reasoning content of previous assistant turns across requests. The value is passed through unchanged, no validation or case normalization is performed. diff --git a/docs/zh/configuration/env-vars.md b/docs/zh/configuration/env-vars.md index 09083c64b..0c03e0d74 100644 --- a/docs/zh/configuration/env-vars.md +++ b/docs/zh/configuration/env-vars.md @@ -17,7 +17,8 @@ Kimi Code CLI 支持通过环境变量覆盖配置或控制运行行为。本页 | `KIMI_MODEL_CAPABILITIES` | 模型能力,逗号分隔(如 `thinking,image_in`) | | `KIMI_MODEL_TEMPERATURE` | 生成参数 `temperature` | | `KIMI_MODEL_TOP_P` | 生成参数 `top_p` | -| `KIMI_MODEL_MAX_TOKENS` | 生成参数 `max_tokens` | +| `KIMI_MODEL_MAX_COMPLETION_TOKENS` | 生成参数 `max_completion_tokens` | +| `KIMI_MODEL_MAX_TOKENS` | `KIMI_MODEL_MAX_COMPLETION_TOKENS` 的兼容别名 | | `KIMI_MODEL_THINKING_KEEP` | Moonshot `thinking.keep` 开关(Preserved Thinking),仅在 Thinking 模式下生效 | ### `KIMI_BASE_URL` @@ -76,14 +77,17 @@ export KIMI_MODEL_TEMPERATURE="0.7" export KIMI_MODEL_TOP_P="0.9" ``` -### `KIMI_MODEL_MAX_TOKENS` +### `KIMI_MODEL_MAX_COMPLETION_TOKENS` -设置生成参数 `max_tokens`,限制单次回复的最大 token 数。 +设置生成参数 `max_completion_tokens`,限制单次回复的最大 token 数。 ```sh -export KIMI_MODEL_MAX_TOKENS="4096" +export KIMI_MODEL_MAX_COMPLETION_TOKENS="4096" ``` +`KIMI_MODEL_MAX_TOKENS` 仍可使用;如果两个环境变量都设置,优先使用 +`KIMI_MODEL_MAX_COMPLETION_TOKENS`。 + ### `KIMI_MODEL_THINKING_KEEP` 将 env 值原样作为 `thinking.keep` 字段发送给 Moonshot API,用于开启 Preserved Thinking(参考 [Moonshot 官方文档](https://platform.kimi.com/docs/guide/use-kimi-k2-thinking-model#preserved-thinking))。设为 `all` 可让模型在多轮之间保留历史 reasoning_content。值不做任何校验、不做大小写归一化,透传给 API 自己判断。 @@ -189,4 +193,3 @@ export KIMI_CLI_PASTE_LINE_THRESHOLD="2" 注意:两个阈值的判断逻辑是"满足任一即折叠"(字符数 **或** 行数),因此只需调低行数阈值即可。不建议将字符数阈值设为很小的值(如 `1`),否则所有非空粘贴(包括单行短文本)都会被折叠。 ::: - diff --git a/packages/kosong/src/kosong/chat_provider/kimi.py b/packages/kosong/src/kosong/chat_provider/kimi.py index 82b6e5ac3..cbaf4152b 100644 --- a/packages/kosong/src/kosong/chat_provider/kimi.py +++ b/packages/kosong/src/kosong/chat_provider/kimi.py @@ -85,7 +85,9 @@ class GenerationKwargs(TypedDict, total=False): See https://platform.moonshot.ai/docs/api/chat#request-body. """ + max_completion_tokens: int | None max_tokens: int | None + """Deprecated alias. Normalized to ``max_completion_tokens`` before requests.""" temperature: float | None top_p: float | None n: int | None @@ -160,11 +162,7 @@ async def generate( messages.append({"role": "system", "content": system_prompt}) messages.extend(_convert_message(message) for message in history) - generation_kwargs: dict[str, Any] = { - # default kimi generation kwargs - "max_tokens": 32000, - } - generation_kwargs.update(self._generation_kwargs) + generation_kwargs: dict[str, Any] = dict(self._generation_kwargs) try: response = await self.client.chat.completions.create( @@ -221,7 +219,7 @@ def with_generation_kwargs(self, **kwargs: Unpack[GenerationKwargs]) -> Self: """ new_self = copy.copy(self) new_self._generation_kwargs = copy.deepcopy(self._generation_kwargs) - new_self._generation_kwargs.update(kwargs) + new_self._generation_kwargs.update(_normalize_generation_kwargs(kwargs)) return new_self def with_extra_body(self, extra_body: ExtraBody) -> Self: @@ -304,6 +302,14 @@ def _guess_filename(mime_type: str) -> str: return f"upload{extension}" +def _normalize_generation_kwargs(kwargs: Kimi.GenerationKwargs) -> Kimi.GenerationKwargs: + normalized: dict[str, Any] = dict(kwargs) + max_tokens = normalized.pop("max_tokens", None) + if max_tokens is not None and "max_completion_tokens" not in normalized: + normalized["max_completion_tokens"] = max_tokens + return cast(Kimi.GenerationKwargs, normalized) + + def _convert_message(message: Message) -> ChatCompletionMessageParam: message = message.model_copy(deep=True) reasoning_content: str = "" @@ -507,7 +513,7 @@ async def _dev_main(): ] stream = await chat.with_generation_kwargs( temperature=0, - max_tokens=1000, + max_completion_tokens=1000, ).generate(system_prompt, [], history) async for part in stream: print(part.model_dump(exclude_none=True)) diff --git a/packages/kosong/tests/api_snapshot_tests/test_kimi.py b/packages/kosong/tests/api_snapshot_tests/test_kimi.py index 6358f985d..7081e57de 100644 --- a/packages/kosong/tests/api_snapshot_tests/test_kimi.py +++ b/packages/kosong/tests/api_snapshot_tests/test_kimi.py @@ -368,7 +368,21 @@ async def test_kimi_generation_kwargs(): async for _ in stream: pass body = json.loads(mock.calls.last.request.content.decode()) - assert (body["temperature"], body["max_tokens"]) == snapshot((0.7, 2048)) + assert (body["temperature"], body["max_completion_tokens"]) == snapshot((0.7, 2048)) + + +async def test_kimi_default_omits_completion_cap(): + with respx.mock(base_url="https://api.moonshot.ai") as mock: + mock.post("/v1/chat/completions").mock( + return_value=Response(200, json=make_chat_completion_response()) + ) + provider = Kimi(model="kimi-k2-turbo-preview", api_key="test-key", stream=False) + stream = await provider.generate("", [], [Message(role="user", content="Hi")]) + async for _ in stream: + pass + body = json.loads(mock.calls.last.request.content.decode()) + assert "max_tokens" not in body + assert "max_completion_tokens" not in body async def test_kimi_with_thinking(): diff --git a/src/kimi_cli/llm.py b/src/kimi_cli/llm.py index 8a713a9e6..fbdfdf8e3 100644 --- a/src/kimi_cli/llm.py +++ b/src/kimi_cli/llm.py @@ -31,6 +31,7 @@ type ModelCapability = Literal["image_in", "video_in", "thinking", "always_thinking"] ALL_MODEL_CAPABILITIES: set[ModelCapability] = set(get_args(ModelCapability.__value__)) +MAX_COMPLETION_TOKENS_SAFETY_MARGIN = 1024 @dataclass(slots=True) @@ -46,6 +47,28 @@ def model_name(self) -> str: return self.chat_provider.model_name +def compute_max_completion_tokens( + *, + max_context_size: int, + input_tokens: int, + response_budget: int, + safety_margin: int = MAX_COMPLETION_TOKENS_SAFETY_MARGIN, +) -> int: + """Compute a per-request completion cap that fits inside the context window.""" + configured_budget = max(1, response_budget) + if max_context_size <= 0: + return configured_budget + + input_tokens = max(0, input_tokens) + remaining = max_context_size - input_tokens + if remaining <= 0: + return 1 + + safe_remaining = remaining - max(0, safety_margin) + available = safe_remaining if safe_remaining > 0 else remaining + return max(1, min(configured_budget, available)) + + def model_display_name(model_name: str | None, model: LLMModel | None = None) -> str: if model is not None and model.display_name: return model.display_name @@ -147,8 +170,11 @@ def create_llm( gen_kwargs["temperature"] = float(temperature) if top_p := os.getenv("KIMI_MODEL_TOP_P"): gen_kwargs["top_p"] = float(top_p) - if max_tokens := os.getenv("KIMI_MODEL_MAX_TOKENS"): - gen_kwargs["max_tokens"] = int(max_tokens) + max_completion_tokens = os.getenv("KIMI_MODEL_MAX_COMPLETION_TOKENS") + if max_completion_tokens is None: + max_completion_tokens = os.getenv("KIMI_MODEL_MAX_TOKENS") + if max_completion_tokens: + gen_kwargs["max_completion_tokens"] = int(max_completion_tokens) if gen_kwargs: chat_provider = chat_provider.with_generation_kwargs(**gen_kwargs) diff --git a/src/kimi_cli/soul/compaction.py b/src/kimi_cli/soul/compaction.py index 7db37d5d8..4a61e3148 100644 --- a/src/kimi_cli/soul/compaction.py +++ b/src/kimi_cli/soul/compaction.py @@ -111,8 +111,8 @@ async def compact( if compact_message is None: return CompactionResult(messages=to_preserve, usage=None) - # Call kosong.step to get the compacted context - # TODO: set max completion tokens + # Call kosong.step to get the compacted context. + # KimiSoul configures provider-specific completion caps before calling this. logger.debug("Compacting context...") result = await kosong.step( chat_provider=llm.chat_provider, diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 877c565f2..60da08115 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -4,7 +4,7 @@ import time import uuid from collections.abc import Awaitable, Callable, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast @@ -16,6 +16,7 @@ APIEmptyResponseError, APIStatusError, APITimeoutError, + ChatProvider, RetryableChatProvider, ) from kosong.message import Message @@ -29,7 +30,7 @@ ) from kimi_cli.background import build_active_task_snapshot from kimi_cli.hooks.engine import HookEngine -from kimi_cli.llm import ModelCapability +from kimi_cli.llm import ModelCapability, compute_max_completion_tokens from kimi_cli.notifications import ( NotificationView, build_notification_message, @@ -1003,6 +1004,7 @@ async def _append_notification(view: NotificationView) -> None: # Normalize: merge adjacent user messages for clean API input effective_history = normalize_history(self._context.history) + chat_provider = self._with_dynamic_completion_budget(chat_provider) async def _run_step_once() -> StepResult: # run an LLM step (may be interrupted) @@ -1123,6 +1125,27 @@ async def _kosong_step_with_retry() -> StepResult: return None return StepOutcome(stop_reason="no_tool_calls", assistant_message=result.message) + def _with_dynamic_completion_budget(self, chat_provider: ChatProvider) -> ChatProvider: + from kosong.chat_provider.kimi import Kimi + + if not isinstance(chat_provider, Kimi): + return chat_provider + + parameters = chat_provider.model_parameters + configured_budget = parameters.get("max_completion_tokens") + if configured_budget is None: + configured_budget = parameters.get("max_tokens") + if type(configured_budget) is not int: + configured_budget = self._loop_control.reserved_context_size + + assert self._runtime.llm is not None + max_completion_tokens = compute_max_completion_tokens( + max_context_size=self._runtime.llm.max_context_size, + input_tokens=self._context.token_count_with_pending, + response_budget=configured_budget, + ) + return chat_provider.with_generation_kwargs(max_completion_tokens=max_completion_tokens) + async def _grow_context(self, result: StepResult, tool_results: list[ToolResult]): logger.debug("Growing context with result: {result}", result=result) @@ -1166,13 +1189,19 @@ async def compact_context( ChatProviderError: When the chat provider returns an error. """ - chat_provider = self._runtime.llm.chat_provider if self._runtime.llm is not None else None + compaction_llm = None + if self._runtime.llm is not None: + compaction_llm = replace( + self._runtime.llm, + chat_provider=self._with_dynamic_completion_budget(self._runtime.llm.chat_provider), + ) + chat_provider = compaction_llm.chat_provider if compaction_llm is not None else None async def _run_compaction_once() -> CompactionResult: - if self._runtime.llm is None: + if compaction_llm is None: raise LLMNotSet() return await self._compaction.compact( - self._context.history, self._runtime.llm, custom_instruction=custom_instruction + self._context.history, compaction_llm, custom_instruction=custom_instruction ) start_time = time.monotonic() diff --git a/tests/core/test_create_llm.py b/tests/core/test_create_llm.py index 037ab3697..a2c3f38c4 100644 --- a/tests/core/test_create_llm.py +++ b/tests/core/test_create_llm.py @@ -7,7 +7,7 @@ from pydantic import SecretStr from kimi_cli.config import LLMModel, LLMProvider -from kimi_cli.llm import augment_provider_with_env_vars, create_llm +from kimi_cli.llm import augment_provider_with_env_vars, compute_max_completion_tokens, create_llm def test_augment_provider_with_env_vars_kimi(monkeypatch): @@ -74,11 +74,57 @@ def test_create_llm_kimi_model_parameters(monkeypatch): "base_url": "https://api.test/v1/", "temperature": 0.2, "top_p": 0.8, - "max_tokens": 1234, + "max_completion_tokens": 1234, } ) +def test_create_llm_kimi_prefers_max_completion_tokens_env(monkeypatch): + provider = LLMProvider( + type="kimi", + base_url="https://api.test/v1", + api_key=SecretStr("test-key"), + ) + model = LLMModel( + provider="kimi", + model="kimi-base", + max_context_size=4096, + capabilities=None, + ) + + monkeypatch.setenv("KIMI_MODEL_MAX_TOKENS", "1234") + monkeypatch.setenv("KIMI_MODEL_MAX_COMPLETION_TOKENS", "5678") + + llm = create_llm(provider, model) + assert llm is not None + assert isinstance(llm.chat_provider, Kimi) + + assert llm.chat_provider.model_parameters["max_completion_tokens"] == 5678 + + +def test_compute_max_completion_tokens_uses_response_budget_when_it_fits(): + assert ( + compute_max_completion_tokens( + max_context_size=262_144, + input_tokens=20_000, + response_budget=50_000, + ) + == 50_000 + ) + + +def test_compute_max_completion_tokens_clamps_to_remaining_context(): + assert ( + compute_max_completion_tokens( + max_context_size=8_192, + input_tokens=7_000, + response_budget=50_000, + safety_margin=512, + ) + == 680 + ) + + def test_create_llm_echo_provider(): provider = LLMProvider(type="_echo", base_url="", api_key=SecretStr("")) model = LLMModel(provider="_echo", model="echo", max_context_size=1234) diff --git a/tests/core/test_kimisoul_completion_budget.py b/tests/core/test_kimisoul_completion_budget.py new file mode 100644 index 000000000..096ec176c --- /dev/null +++ b/tests/core/test_kimisoul_completion_budget.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from kosong.chat_provider.kimi import Kimi +from kosong.tooling.empty import EmptyToolset +from pydantic import SecretStr + +from kimi_cli.config import LLMModel, LLMProvider +from kimi_cli.llm import LLM +from kimi_cli.soul.agent import Agent, Runtime +from kimi_cli.soul.context import Context +from kimi_cli.soul.kimisoul import KimiSoul + + +def _make_soul(runtime: Runtime, tmp_path: Path) -> KimiSoul: + agent = Agent( + name="Completion Budget Test Agent", + system_prompt="Test prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + return KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) + + +def _make_kimi_llm(chat_provider: Kimi, *, max_context_size: int = 100_000) -> LLM: + return LLM( + chat_provider=chat_provider, + max_context_size=max_context_size, + capabilities=set(), + model_config=LLMModel( + provider="kimi", + model="kimi-k2", + max_context_size=max_context_size, + ), + provider_config=LLMProvider( + type="kimi", + base_url="https://api.test/v1", + api_key=SecretStr("test-key"), + ), + ) + + +@pytest.mark.asyncio +async def test_dynamic_completion_budget_clamps_kimi_request( + runtime: Runtime, tmp_path: Path +) -> None: + chat_provider = Kimi( + model="kimi-k2", + base_url="https://api.test/v1", + api_key="test-key", + stream=False, + ) + runtime.llm = _make_kimi_llm(chat_provider) + soul = _make_soul(runtime, tmp_path) + await soul.context.update_token_count(60_000) + + budgeted = soul._with_dynamic_completion_budget(chat_provider) + + assert isinstance(budgeted, Kimi) + assert budgeted.model_parameters["max_completion_tokens"] == 38_976 + + +def test_dynamic_completion_budget_preserves_explicit_kimi_cap( + runtime: Runtime, tmp_path: Path +) -> None: + chat_provider = Kimi( + model="kimi-k2", + base_url="https://api.test/v1", + api_key="test-key", + stream=False, + ).with_generation_kwargs(max_completion_tokens=1234) + runtime.llm = _make_kimi_llm(chat_provider) + soul = _make_soul(runtime, tmp_path) + + budgeted = soul._with_dynamic_completion_budget(chat_provider) + + assert budgeted.model_parameters["max_completion_tokens"] == 1234 + + +@pytest.mark.asyncio +async def test_dynamic_completion_budget_clamps_explicit_kimi_cap( + runtime: Runtime, tmp_path: Path +) -> None: + chat_provider = Kimi( + model="kimi-k2", + base_url="https://api.test/v1", + api_key="test-key", + stream=False, + ).with_generation_kwargs(max_completion_tokens=50_000) + runtime.llm = _make_kimi_llm(chat_provider, max_context_size=8_192) + soul = _make_soul(runtime, tmp_path) + await soul.context.update_token_count(7_000) + + budgeted = soul._with_dynamic_completion_budget(chat_provider) + + assert budgeted.model_parameters["max_completion_tokens"] == 168 From f26b1f784dce118614833629bb3c0e42d3e40ba6 Mon Sep 17 00:00:00 2001 From: qer Date: Wed, 20 May 2026 17:27:01 +0800 Subject: [PATCH 2/5] test(kimi): narrow budgeted provider type --- tests/core/test_kimisoul_completion_budget.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/core/test_kimisoul_completion_budget.py b/tests/core/test_kimisoul_completion_budget.py index 096ec176c..d9dd963f4 100644 --- a/tests/core/test_kimisoul_completion_budget.py +++ b/tests/core/test_kimisoul_completion_budget.py @@ -76,6 +76,7 @@ def test_dynamic_completion_budget_preserves_explicit_kimi_cap( budgeted = soul._with_dynamic_completion_budget(chat_provider) + assert isinstance(budgeted, Kimi) assert budgeted.model_parameters["max_completion_tokens"] == 1234 @@ -95,4 +96,5 @@ async def test_dynamic_completion_budget_clamps_explicit_kimi_cap( budgeted = soul._with_dynamic_completion_budget(chat_provider) + assert isinstance(budgeted, Kimi) assert budgeted.model_parameters["max_completion_tokens"] == 168 From a3ed772300318789bfa6c50ff14ba3578cc1c39f Mon Sep 17 00:00:00 2001 From: qer Date: Wed, 20 May 2026 17:39:31 +0800 Subject: [PATCH 3/5] fix(kimi): avoid cloning unchanged compaction llm --- src/kimi_cli/soul/kimisoul.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 60da08115..f94bc3860 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -1191,10 +1191,10 @@ async def compact_context( compaction_llm = None if self._runtime.llm is not None: - compaction_llm = replace( - self._runtime.llm, - chat_provider=self._with_dynamic_completion_budget(self._runtime.llm.chat_provider), - ) + compaction_llm = self._runtime.llm + chat_provider = self._with_dynamic_completion_budget(compaction_llm.chat_provider) + if chat_provider is not compaction_llm.chat_provider: + compaction_llm = replace(compaction_llm, chat_provider=chat_provider) chat_provider = compaction_llm.chat_provider if compaction_llm is not None else None async def _run_compaction_once() -> CompactionResult: From b0e4851c89e3428961eadc38c4d5e281add5b85f Mon Sep 17 00:00:00 2001 From: 7Sageer <7sageer@djwcb.cn> Date: Wed, 20 May 2026 20:38:44 +0800 Subject: [PATCH 4/5] fix(kimi): pass per-call completion budget via generation_overrides The previous approach in this PR derived a budget-clamped Kimi instance via with_generation_kwargs() in KimiSoul._with_dynamic_completion_budget, but that shallow copy shared self.client with runtime.llm.chat_provider. A transient APIConnectionError would invoke on_retryable_error on the copy, rebinding only the copy's client while the original (now closed) client stayed attached to runtime.llm.chat_provider. Every subsequent step then had to walk through a fresh connection-recovery cycle before making any real progress. Switch to a per-call generation_overrides mapping plumbed through kosong.step / kosong.generate / ChatProvider.generate, so providers can merge per-request kwargs without producing a parallel instance. The single Kimi instance owned by Runtime keeps its live client and OAuth token state, and on_retryable_error mutations are visible on the next step. The Compaction protocol gains the same parameter so callers no longer need to mutate the chat provider before invoking compact(). - Extend ChatProvider.generate Protocol with generation_overrides - Forward through kosong.generate and kosong.step - Kimi.generate normalizes max_tokens -> max_completion_tokens in overrides - Other providers (Anthropic, OpenAILegacy, OpenAIResponses, GoogleGenAI, Chaos, Mock, Echo, ScriptedEcho) accept and merge the parameter - KimiSoul replaces _with_dynamic_completion_budget with _compute_completion_overrides returning a dict - compact_context no longer constructs a replacement LLM - SimpleCompaction.compact threads generation_overrides to kosong.step - Add regression test ensuring runtime.llm.chat_provider is not replaced - Add kosong-level tests for per-call override merge/normalization --- packages/kosong/src/kosong/__init__.py | 14 ++- packages/kosong/src/kosong/_generate.py | 9 +- .../src/kosong/chat_provider/__init__.py | 15 ++- .../kosong/src/kosong/chat_provider/chaos.py | 8 +- .../src/kosong/chat_provider/echo/echo.py | 7 +- .../chat_provider/echo/scripted_echo.py | 7 +- .../kosong/src/kosong/chat_provider/kimi.py | 10 +- .../kosong/src/kosong/chat_provider/mock.py | 7 +- .../kosong/contrib/chat_provider/anthropic.py | 4 + .../contrib/chat_provider/google_genai.py | 9 +- .../contrib/chat_provider/openai_legacy.py | 6 +- .../contrib/chat_provider/openai_responses.py | 6 +- .../tests/api_snapshot_tests/test_kimi.py | 66 +++++++++++++ src/kimi_cli/soul/compaction.py | 27 ++++-- src/kimi_cli/soul/kimisoul.py | 40 ++++---- tests/core/test_auth_error_handling.py | 8 +- tests/core/test_kimisoul_completion_budget.py | 93 +++++++++++++++++-- tests/core/test_kimisoul_ralph_loop.py | 3 +- tests/core/test_kimisoul_retry_recovery.py | 8 +- tests/core/test_kimisoul_steer.py | 2 +- tests/core/test_notifications.py | 3 +- 21 files changed, 298 insertions(+), 54 deletions(-) diff --git a/packages/kosong/src/kosong/__init__.py b/packages/kosong/src/kosong/__init__.py index 7dfdd9414..c4785d48d 100644 --- a/packages/kosong/src/kosong/__init__.py +++ b/packages/kosong/src/kosong/__init__.py @@ -70,8 +70,9 @@ async def main() -> None: """ import asyncio -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass +from typing import Any from loguru import logger @@ -109,6 +110,7 @@ async def step( *, on_message_part: Callback[[StreamedMessagePart], None] | None = None, on_tool_result: Callable[[ToolResult], None] | None = None, + generation_overrides: Mapping[str, Any] | None = None, ) -> "StepResult": """ Run one agent "step". In one step, the function generates LLM response based on the given @@ -121,6 +123,15 @@ async def step( The token usage will be returned in the `StepResult` if available. + Args: + chat_provider: The chat provider to use for generation. + system_prompt: The system prompt forwarded to the chat provider. + toolset: The toolset that handles tool calls and exposes available tools. + history: The message history forwarded to the chat provider. + on_message_part: Optional callback fired for each streamed message part. + on_tool_result: Optional callback fired when an individual tool result resolves. + generation_overrides: Optional per-call overrides forwarded to ``chat_provider.generate``. + Raises: APIConnectionError: If the API connection fails. APITimeoutError: If the API request times out. @@ -162,6 +173,7 @@ async def on_tool_call(tool_call: ToolCall): history, on_message_part=on_message_part, on_tool_call=on_tool_call, + generation_overrides=generation_overrides, ) except (ChatProviderError, asyncio.CancelledError): # cancel all the futures to avoid hanging tasks diff --git a/packages/kosong/src/kosong/_generate.py b/packages/kosong/src/kosong/_generate.py index 1eb45013e..406a42b3b 100644 --- a/packages/kosong/src/kosong/_generate.py +++ b/packages/kosong/src/kosong/_generate.py @@ -1,5 +1,6 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass +from typing import Any from loguru import logger @@ -22,6 +23,7 @@ async def generate( *, on_message_part: Callback[[StreamedMessagePart], None] | None = None, on_tool_call: Callback[[ToolCall], None] | None = None, + generation_overrides: Mapping[str, Any] | None = None, ) -> "GenerateResult": """ Generate one message based on the given context. @@ -34,6 +36,7 @@ async def generate( history: The message history to use for generation. on_message_part: An optional callback to be called for each raw message part. on_tool_call: An optional callback to be called for each complete tool call. + generation_overrides: Optional per-call overrides forwarded to ``chat_provider.generate``. Returns: A tuple of the generated message and the token usage (if available). @@ -50,7 +53,9 @@ async def generate( pending_part: StreamedMessagePart | None = None # message part that is currently incomplete logger.trace("Generating with history: {history}", history=history) - stream = await chat_provider.generate(system_prompt, tools, history) + stream = await chat_provider.generate( + system_prompt, tools, history, generation_overrides=generation_overrides + ) async for part in stream: logger.trace("Received part: {part}", part=part) if on_message_part: diff --git a/packages/kosong/src/kosong/chat_provider/__init__.py b/packages/kosong/src/kosong/chat_provider/__init__.py index 4cc910f79..7e8770451 100644 --- a/packages/kosong/src/kosong/chat_provider/__init__.py +++ b/packages/kosong/src/kosong/chat_provider/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Sequence -from typing import TYPE_CHECKING, Literal, Protocol, Self, runtime_checkable +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, Protocol, Self, runtime_checkable from pydantic import BaseModel @@ -40,10 +40,21 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> StreamedMessage: """ Generate a new message based on the given system prompt, tools, and history. + Args: + system_prompt: The system prompt to use for generation. + tools: The tools available for the model to call. + history: The message history to use for generation. + generation_overrides: Optional per-call overrides merged on top of the provider's + generation kwargs without mutating provider state. Treat the mapping as + read-only and request-scoped — implementations must not alias it into any + long-lived state. + Raises: APIConnectionError: If the API connection fails. APITimeoutError: If the API request times out. diff --git a/packages/kosong/src/kosong/chat_provider/chaos.py b/packages/kosong/src/kosong/chat_provider/chaos.py index ece2f76c6..e001a2180 100644 --- a/packages/kosong/src/kosong/chat_provider/chaos.py +++ b/packages/kosong/src/kosong/chat_provider/chaos.py @@ -1,7 +1,7 @@ import json import os import random -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from typing import TYPE_CHECKING, Any import httpx @@ -115,8 +115,12 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "ChaosStreamedMessage": - base_stream = await self._provider.generate(system_prompt, tools, history) + base_stream = await self._provider.generate( + system_prompt, tools, history, generation_overrides=generation_overrides + ) return ChaosStreamedMessage(base_stream, self._chaos_config) def _monkey_patch_client(self): diff --git a/packages/kosong/src/kosong/chat_provider/echo/echo.py b/packages/kosong/src/kosong/chat_provider/echo/echo.py index b1b99e0fd..0831c2e87 100644 --- a/packages/kosong/src/kosong/chat_provider/echo/echo.py +++ b/packages/kosong/src/kosong/chat_provider/echo/echo.py @@ -1,8 +1,8 @@ from __future__ import annotations import copy -from collections.abc import AsyncIterator, Sequence -from typing import TYPE_CHECKING, Self +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Self from kosong.chat_provider import ( ChatProvider, @@ -72,7 +72,10 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> EchoStreamedMessage: + del generation_overrides # echo provider has no API to forward overrides to if not history: raise ChatProviderError("EchoChatProvider requires at least one message in history.") if history[-1].role != "user": diff --git a/packages/kosong/src/kosong/chat_provider/echo/scripted_echo.py b/packages/kosong/src/kosong/chat_provider/echo/scripted_echo.py index b044d5978..dd00e007a 100644 --- a/packages/kosong/src/kosong/chat_provider/echo/scripted_echo.py +++ b/packages/kosong/src/kosong/chat_provider/echo/scripted_echo.py @@ -3,8 +3,8 @@ import copy import json from collections import deque -from collections.abc import AsyncIterator, Iterable, Sequence -from typing import TYPE_CHECKING, Self +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Self from kosong.chat_provider import ( ChatProvider, @@ -49,7 +49,10 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> ScriptedEchoStreamedMessage: + del generation_overrides # scripted echo replays canned scripts; overrides not applicable if not self._scripts: raise ChatProviderError(f"ScriptedEchoChatProvider exhausted at turn {self._turn + 1}.") script_text = self._scripts.popleft() diff --git a/packages/kosong/src/kosong/chat_provider/kimi.py b/packages/kosong/src/kosong/chat_provider/kimi.py index cbaf4152b..7341c3391 100644 --- a/packages/kosong/src/kosong/chat_provider/kimi.py +++ b/packages/kosong/src/kosong/chat_provider/kimi.py @@ -2,7 +2,7 @@ import mimetypes import os import uuid -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, Self, Unpack, cast import httpx @@ -156,6 +156,8 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "KimiStreamedMessage": messages: list[ChatCompletionMessageParam] = [] if system_prompt: @@ -163,6 +165,12 @@ async def generate( messages.extend(_convert_message(message) for message in history) generation_kwargs: dict[str, Any] = dict(self._generation_kwargs) + if generation_overrides: + generation_kwargs.update( + _normalize_generation_kwargs( + cast(Kimi.GenerationKwargs, dict(generation_overrides)) + ) + ) try: response = await self.client.chat.completions.create( diff --git a/packages/kosong/src/kosong/chat_provider/mock.py b/packages/kosong/src/kosong/chat_provider/mock.py index b75828707..f13679efc 100644 --- a/packages/kosong/src/kosong/chat_provider/mock.py +++ b/packages/kosong/src/kosong/chat_provider/mock.py @@ -1,6 +1,6 @@ import copy -from collections.abc import AsyncIterator, Sequence -from typing import TYPE_CHECKING, Self +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Self from kosong.chat_provider import ( ChatProvider, @@ -45,8 +45,11 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "MockStreamedMessage": """Always return the predefined message parts.""" + del generation_overrides # mock provider ignores per-call overrides return MockStreamedMessage(self._message_parts) def with_thinking(self, effort: ThinkingEffort) -> Self: diff --git a/packages/kosong/src/kosong/contrib/chat_provider/anthropic.py b/packages/kosong/src/kosong/contrib/chat_provider/anthropic.py index 5c7495cc8..b46970842 100644 --- a/packages/kosong/src/kosong/contrib/chat_provider/anthropic.py +++ b/packages/kosong/src/kosong/contrib/chat_provider/anthropic.py @@ -281,6 +281,8 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "AnthropicStreamedMessage": # https://docs.claude.com/en/api/messages#body-messages # Anthropic API does not support system roles, but just a system prompt. @@ -341,6 +343,8 @@ async def generate( pass generation_kwargs: dict[str, Any] = {} generation_kwargs.update(self._generation_kwargs) + if generation_overrides: + generation_kwargs.update(generation_overrides) betas = generation_kwargs.pop("beta_features", []) extra_headers = { **{"anthropic-beta": ",".join(str(e) for e in betas)}, diff --git a/packages/kosong/src/kosong/contrib/chat_provider/google_genai.py b/packages/kosong/src/kosong/contrib/chat_provider/google_genai.py index 0a662bd79..2b6898a1d 100644 --- a/packages/kosong/src/kosong/contrib/chat_provider/google_genai.py +++ b/packages/kosong/src/kosong/contrib/chat_provider/google_genai.py @@ -10,7 +10,7 @@ import copy import json import mimetypes -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Self, TypedDict, Unpack, cast import httpx @@ -144,10 +144,15 @@ async def generate( system_prompt: str, tools: Sequence[KosongTool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "GoogleGenAIStreamedMessage": contents = messages_to_google_genai_contents(history) - config = GenerateContentConfig(**self._generation_kwargs) + merged_kwargs: dict[str, Any] = dict(self._generation_kwargs) + if generation_overrides: + merged_kwargs.update(generation_overrides) + config = GenerateContentConfig(**merged_kwargs) config.system_instruction = system_prompt config.tools = [tool_to_google_genai(tool) for tool in tools] diff --git a/packages/kosong/src/kosong/contrib/chat_provider/openai_legacy.py b/packages/kosong/src/kosong/contrib/chat_provider/openai_legacy.py index ef736360b..2a9058757 100644 --- a/packages/kosong/src/kosong/contrib/chat_provider/openai_legacy.py +++ b/packages/kosong/src/kosong/contrib/chat_provider/openai_legacy.py @@ -1,6 +1,6 @@ import copy import uuid -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Self, Unpack, cast import httpx @@ -116,6 +116,8 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "OpenAILegacyStreamedMessage": messages: list[ChatCompletionMessageParam] = [] if system_prompt: @@ -125,6 +127,8 @@ async def generate( generation_kwargs: dict[str, Any] = {} generation_kwargs.update(self._generation_kwargs) + if generation_overrides: + generation_kwargs.update(generation_overrides) reasoning_effort = self._reasoning_effort # Auto-enable reasoning_effort when the history contains ThinkPart but reasoning diff --git a/packages/kosong/src/kosong/contrib/chat_provider/openai_responses.py b/packages/kosong/src/kosong/contrib/chat_provider/openai_responses.py index d675faf4e..79858f0da 100644 --- a/packages/kosong/src/kosong/contrib/chat_provider/openai_responses.py +++ b/packages/kosong/src/kosong/contrib/chat_provider/openai_responses.py @@ -1,6 +1,6 @@ import copy import uuid -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Self, TypedDict, Unpack, cast, get_args import httpx @@ -154,6 +154,8 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + *, + generation_overrides: Mapping[str, Any] | None = None, ) -> "OpenAIResponsesStreamedMessage": inputs: ResponseInputParam = [] if system_prompt: @@ -168,6 +170,8 @@ async def generate( generation_kwargs: dict[str, Any] = {} generation_kwargs.update(self._generation_kwargs) + if generation_overrides: + generation_kwargs.update(generation_overrides) reasoning_effort = generation_kwargs.pop("reasoning_effort", None) if reasoning_effort is not None: generation_kwargs["reasoning"] = Reasoning( diff --git a/packages/kosong/tests/api_snapshot_tests/test_kimi.py b/packages/kosong/tests/api_snapshot_tests/test_kimi.py index 7081e57de..a74523333 100644 --- a/packages/kosong/tests/api_snapshot_tests/test_kimi.py +++ b/packages/kosong/tests/api_snapshot_tests/test_kimi.py @@ -385,6 +385,72 @@ async def test_kimi_default_omits_completion_cap(): assert "max_completion_tokens" not in body +async def test_kimi_generation_overrides_per_call(): + """Per-call ``generation_overrides`` reach the request body without mutating the provider.""" + with respx.mock(base_url="https://api.moonshot.ai") as mock: + mock.post("/v1/chat/completions").mock( + return_value=Response(200, json=make_chat_completion_response()) + ) + provider = Kimi( + model="kimi-k2-turbo-preview", api_key="test-key", stream=False + ).with_generation_kwargs(temperature=0.7) + stream = await provider.generate( + "", + [], + [Message(role="user", content="Hi")], + generation_overrides={"max_completion_tokens": 4096}, + ) + async for _ in stream: + pass + body = json.loads(mock.calls.last.request.content.decode()) + assert (body["temperature"], body["max_completion_tokens"]) == snapshot((0.7, 4096)) + # The override must not have leaked into the provider's persistent kwargs. + assert "max_completion_tokens" not in provider.model_parameters + + +async def test_kimi_generation_overrides_normalize_max_tokens_alias(): + """An override key ``max_tokens`` is normalized to ``max_completion_tokens``.""" + with respx.mock(base_url="https://api.moonshot.ai") as mock: + mock.post("/v1/chat/completions").mock( + return_value=Response(200, json=make_chat_completion_response()) + ) + provider = Kimi(model="kimi-k2-turbo-preview", api_key="test-key", stream=False) + stream = await provider.generate( + "", + [], + [Message(role="user", content="Hi")], + generation_overrides={"max_tokens": 2048}, + ) + async for _ in stream: + pass + body = json.loads(mock.calls.last.request.content.decode()) + assert body["max_completion_tokens"] == 2048 + assert "max_tokens" not in body + + +async def test_kimi_generation_overrides_take_precedence_over_provider_kwargs(): + """Per-call override beats the provider-level value for the same key.""" + with respx.mock(base_url="https://api.moonshot.ai") as mock: + mock.post("/v1/chat/completions").mock( + return_value=Response(200, json=make_chat_completion_response()) + ) + provider = Kimi( + model="kimi-k2-turbo-preview", api_key="test-key", stream=False + ).with_generation_kwargs(max_completion_tokens=8000) + stream = await provider.generate( + "", + [], + [Message(role="user", content="Hi")], + generation_overrides={"max_completion_tokens": 1024}, + ) + async for _ in stream: + pass + body = json.loads(mock.calls.last.request.content.decode()) + assert body["max_completion_tokens"] == 1024 + # Provider-level kwargs are unchanged after the call. + assert provider.model_parameters["max_completion_tokens"] == 8000 + + async def test_kimi_with_thinking(): with respx.mock(base_url="https://api.moonshot.ai") as mock: mock.post("/v1/chat/completions").mock( diff --git a/src/kimi_cli/soul/compaction.py b/src/kimi_cli/soul/compaction.py index 4a61e3148..45669f723 100644 --- a/src/kimi_cli/soul/compaction.py +++ b/src/kimi_cli/soul/compaction.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable import kosong from kosong.chat_provider import TokenUsage @@ -75,7 +75,12 @@ def should_auto_compact( @runtime_checkable class Compaction(Protocol): async def compact( - self, messages: Sequence[Message], llm: LLM, *, custom_instruction: str = "" + self, + messages: Sequence[Message], + llm: LLM, + *, + custom_instruction: str = "", + generation_overrides: Mapping[str, Any] | None = None, ) -> CompactionResult: """ Compact a sequence of messages into a new sequence of messages. @@ -84,6 +89,9 @@ async def compact( messages (Sequence[Message]): The messages to compact. llm (LLM): The LLM to use for compaction. custom_instruction: Optional user instruction to guide compaction focus. + generation_overrides: Optional per-call overrides forwarded to the chat provider via + ``kosong.step``. Used by callers that need to cap the compaction response size to + fit the remaining context window. Returns: CompactionResult: The compacted messages and token usage from the compaction LLM call. @@ -105,20 +113,27 @@ def __init__(self, max_preserved_messages: int = 2) -> None: self.max_preserved_messages = max_preserved_messages async def compact( - self, messages: Sequence[Message], llm: LLM, *, custom_instruction: str = "" + self, + messages: Sequence[Message], + llm: LLM, + *, + custom_instruction: str = "", + generation_overrides: Mapping[str, Any] | None = None, ) -> CompactionResult: compact_message, to_preserve = self.prepare(messages, custom_instruction=custom_instruction) if compact_message is None: return CompactionResult(messages=to_preserve, usage=None) - # Call kosong.step to get the compacted context. - # KimiSoul configures provider-specific completion caps before calling this. + # Call kosong.step to get the compacted context. The caller is responsible for + # computing provider-specific generation overrides (e.g. a completion budget + # that fits the remaining context window); they are forwarded verbatim. logger.debug("Compacting context...") result = await kosong.step( chat_provider=llm.chat_provider, system_prompt="You are a helpful assistant that compacts conversation context.", toolset=EmptyToolset(), history=[compact_message], + generation_overrides=generation_overrides, ) if result.usage: logger.debug( diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index a66f0b7a9..e3b3271f2 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -4,7 +4,7 @@ import time import uuid from collections.abc import Awaitable, Callable, Sequence -from dataclasses import dataclass, replace +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, cast @@ -1071,7 +1071,7 @@ async def _append_notification(view: NotificationView) -> None: # 2e.3. HISTORY NORMALIZATION # ═══════════════════════════════════════════════════════════════════════ effective_history = normalize_history(self._context.history) - chat_provider = self._with_dynamic_completion_budget(chat_provider) + generation_overrides = self._compute_completion_overrides(chat_provider) # ═══════════════════════════════════════════════════════════════════════ # 2e.4. LLM CALL WITH RETRY @@ -1094,6 +1094,7 @@ async def _run_step_once() -> StepResult: effective_history, on_message_part=wire_send, on_tool_result=wire_send, + generation_overrides=generation_overrides, ) max_attempts = self._loop_control.max_retries_per_step @@ -1221,16 +1222,23 @@ async def _kosong_step_with_retry() -> StepResult: return None return StepOutcome(stop_reason="no_tool_calls", assistant_message=result.message) - def _with_dynamic_completion_budget(self, chat_provider: ChatProvider) -> ChatProvider: + def _compute_completion_overrides(self, chat_provider: ChatProvider) -> dict[str, Any] | None: + """Compute per-call generation overrides for the given chat provider. + + Returns a dict of per-call generation overrides forwarded to ``kosong.step``, + or ``None`` if no overrides apply for the given provider. The chat provider + instance is not modified, so transport-level state (the live OpenAI client and + any in-flight OAuth token) stays attached to the single instance owned by + ``Runtime.llm`` — retry recovery in ``Kimi.on_retryable_error`` therefore + affects subsequent steps as intended. + """ from kosong.chat_provider.kimi import Kimi if not isinstance(chat_provider, Kimi): - return chat_provider + return None parameters = chat_provider.model_parameters configured_budget = parameters.get("max_completion_tokens") - if configured_budget is None: - configured_budget = parameters.get("max_tokens") if type(configured_budget) is not int: configured_budget = self._loop_control.reserved_context_size @@ -1240,7 +1248,7 @@ def _with_dynamic_completion_budget(self, chat_provider: ChatProvider) -> ChatPr input_tokens=self._context.token_count_with_pending, response_budget=configured_budget, ) - return chat_provider.with_generation_kwargs(max_completion_tokens=max_completion_tokens) + return {"max_completion_tokens": max_completion_tokens} async def _grow_context(self, result: StepResult, tool_results: list[ToolResult]): logger.debug("Growing context with result: {result}", result=result) @@ -1285,19 +1293,19 @@ async def compact_context( ChatProviderError: When the chat provider returns an error. """ - compaction_llm = None - if self._runtime.llm is not None: - compaction_llm = self._runtime.llm - chat_provider = self._with_dynamic_completion_budget(compaction_llm.chat_provider) - if chat_provider is not compaction_llm.chat_provider: - compaction_llm = replace(compaction_llm, chat_provider=chat_provider) - chat_provider = compaction_llm.chat_provider if compaction_llm is not None else None + chat_provider = self._runtime.llm.chat_provider if self._runtime.llm is not None else None + compaction_overrides = ( + self._compute_completion_overrides(chat_provider) if chat_provider is not None else None + ) async def _run_compaction_once() -> CompactionResult: - if compaction_llm is None: + if self._runtime.llm is None: raise LLMNotSet() return await self._compaction.compact( - self._context.history, compaction_llm, custom_instruction=custom_instruction + self._context.history, + self._runtime.llm, + custom_instruction=custom_instruction, + generation_overrides=compaction_overrides, ) start_time = time.monotonic() diff --git a/tests/core/test_auth_error_handling.py b/tests/core/test_auth_error_handling.py index 028c4f054..c190ffbad 100644 --- a/tests/core/test_auth_error_handling.py +++ b/tests/core/test_auth_error_handling.py @@ -10,7 +10,7 @@ import asyncio from collections.abc import AsyncIterator, Sequence from pathlib import Path -from typing import Self +from typing import Any, Self from unittest.mock import AsyncMock, MagicMock import acp @@ -93,6 +93,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: raise APIStatusError(401, "incorrect API KEY") @@ -125,6 +126,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: raise APIStatusError(self._status_code, self._message) @@ -153,6 +155,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: raise ChatProviderError("something went wrong") @@ -181,6 +184,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: return StaticStreamedMessage([MessageTextPart(text="hello")]) @@ -209,6 +213,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> None: import ssl @@ -239,6 +244,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> None: raise ConnectionError("Connection reset by peer") diff --git a/tests/core/test_kimisoul_completion_budget.py b/tests/core/test_kimisoul_completion_budget.py index d9dd963f4..4c93fc9d4 100644 --- a/tests/core/test_kimisoul_completion_budget.py +++ b/tests/core/test_kimisoul_completion_budget.py @@ -1,8 +1,10 @@ from __future__ import annotations from pathlib import Path +from typing import Any import pytest +from kosong.chat_provider import APIConnectionError from kosong.chat_provider.kimi import Kimi from kosong.tooling.empty import EmptyToolset from pydantic import SecretStr @@ -56,10 +58,9 @@ async def test_dynamic_completion_budget_clamps_kimi_request( soul = _make_soul(runtime, tmp_path) await soul.context.update_token_count(60_000) - budgeted = soul._with_dynamic_completion_budget(chat_provider) + overrides = soul._compute_completion_overrides(chat_provider) - assert isinstance(budgeted, Kimi) - assert budgeted.model_parameters["max_completion_tokens"] == 38_976 + assert overrides == {"max_completion_tokens": 38_976} def test_dynamic_completion_budget_preserves_explicit_kimi_cap( @@ -74,10 +75,9 @@ def test_dynamic_completion_budget_preserves_explicit_kimi_cap( runtime.llm = _make_kimi_llm(chat_provider) soul = _make_soul(runtime, tmp_path) - budgeted = soul._with_dynamic_completion_budget(chat_provider) + overrides = soul._compute_completion_overrides(chat_provider) - assert isinstance(budgeted, Kimi) - assert budgeted.model_parameters["max_completion_tokens"] == 1234 + assert overrides == {"max_completion_tokens": 1234} @pytest.mark.asyncio @@ -94,7 +94,82 @@ async def test_dynamic_completion_budget_clamps_explicit_kimi_cap( soul = _make_soul(runtime, tmp_path) await soul.context.update_token_count(7_000) - budgeted = soul._with_dynamic_completion_budget(chat_provider) + overrides = soul._compute_completion_overrides(chat_provider) - assert isinstance(budgeted, Kimi) - assert budgeted.model_parameters["max_completion_tokens"] == 168 + assert overrides == {"max_completion_tokens": 168} + + +def test_compute_completion_overrides_returns_none_for_non_kimi_provider( + runtime: Runtime, tmp_path: Path +) -> None: + """Non-Kimi providers receive no overrides and run with their built-in defaults.""" + + class _NotKimi: + name = "not-kimi" + + @property + def model_name(self) -> str: + return "stub" + + @property + def thinking_effort(self) -> None: + return None + + async def generate(self, *args: Any, **kwargs: Any) -> Any: # pragma: no cover - unused + raise NotImplementedError + + def with_thinking(self, effort: Any) -> _NotKimi: # pragma: no cover - unused + return self + + soul = _make_soul(runtime, tmp_path) + + assert soul._compute_completion_overrides(_NotKimi()) is None # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_compute_overrides_does_not_copy_chat_provider( + runtime: Runtime, tmp_path: Path +) -> None: + """Regression for F3: the dynamic budget must not produce a shallow copy of the + chat provider that shadows ``runtime.llm.chat_provider``. + + Before the fix, ``_with_dynamic_completion_budget`` returned a fresh ``Kimi`` instance + via ``with_generation_kwargs``. That copy shared ``client``/``_api_key`` with the + original, but ``on_retryable_error`` rebound ``self.client`` only on the copy — so the + runtime's ``chat_provider`` was left pointing at the (now-closed) old client and every + subsequent step had to recover from a dead connection first. + + With the new design ``_compute_completion_overrides`` returns a plain dict and the + runtime keeps owning the single live provider instance, so recovery on it is the + visible state for the next step. + """ + chat_provider = Kimi( + model="kimi-k2", + base_url="https://api.test/v1", + api_key="test-key", + stream=False, + ) + runtime.llm = _make_kimi_llm(chat_provider) + soul = _make_soul(runtime, tmp_path) + await soul.context.update_token_count(1_000) + + overrides = soul._compute_completion_overrides(runtime.llm.chat_provider) + + # The override path returns data, not a substitute provider. + assert isinstance(overrides, dict) + assert runtime.llm.chat_provider is chat_provider + + # When a transient error triggers recovery on the live provider, the next call to + # ``_compute_completion_overrides`` still sees the same instance — proof that + # the budget calculation has not forked a parallel provider that would mask + # the client refresh. + original_client = chat_provider.client + chat_provider.on_retryable_error(APIConnectionError("simulated")) + assert chat_provider.client is not original_client + runtime_provider = runtime.llm.chat_provider + assert isinstance(runtime_provider, Kimi) + assert runtime_provider.client is chat_provider.client + + overrides_after_recovery = soul._compute_completion_overrides(runtime_provider) + assert isinstance(overrides_after_recovery, dict) + assert runtime_provider is chat_provider diff --git a/tests/core/test_kimisoul_ralph_loop.py b/tests/core/test_kimisoul_ralph_loop.py index 0ec7c27ab..496169d62 100644 --- a/tests/core/test_kimisoul_ralph_loop.py +++ b/tests/core/test_kimisoul_ralph_loop.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import AsyncIterator, Sequence from pathlib import Path -from typing import Self, TypeVar +from typing import Any, Self, TypeVar import pytest from inline_snapshot import Snapshot, snapshot @@ -89,6 +89,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> SequenceStreamedMessage: index = min(self._index, len(self._sequences) - 1) self._index += 1 diff --git a/tests/core/test_kimisoul_retry_recovery.py b/tests/core/test_kimisoul_retry_recovery.py index 6f3fc5b67..d1c0844e5 100644 --- a/tests/core/test_kimisoul_retry_recovery.py +++ b/tests/core/test_kimisoul_retry_recovery.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import AsyncIterator, Sequence from pathlib import Path -from typing import Self +from typing import Any, Self from unittest.mock import AsyncMock import pytest @@ -75,6 +75,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage | PartialThenErrorStreamedMessage: self.generate_attempts += 1 if self.generate_attempts == 1: @@ -109,6 +110,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: self.generate_attempts += 1 raise APIConnectionError("Connection error.") @@ -142,6 +144,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: self.generate_attempts += 1 if self.generate_attempts < 3: @@ -176,6 +179,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage | PartialThenErrorStreamedMessage: self.generate_attempts += 1 if self.generate_attempts == 1: @@ -234,6 +238,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: self.generate_attempts += 1 if self.generate_attempts == 1: @@ -264,6 +269,7 @@ async def generate( system_prompt: str, tools: Sequence[Tool], history: Sequence[Message], + **_kwargs: Any, ) -> StaticStreamedMessage: self.generate_attempts += 1 if self.generate_attempts == 1: diff --git a/tests/core/test_kimisoul_steer.py b/tests/core/test_kimisoul_steer.py index 500b2ce6a..6301b3440 100644 --- a/tests/core/test_kimisoul_steer.py +++ b/tests/core/test_kimisoul_steer.py @@ -400,7 +400,7 @@ def model_name(self) -> str: def thinking_effort(self): return None - async def generate(self, system_prompt, tools, history): + async def generate(self, system_prompt, tools, history, **_kwargs): index = min(self._calls, len(self._sequences) - 1) self._calls += 1 return _SequenceStreamedMessage(self._sequences[index]) diff --git a/tests/core/test_notifications.py b/tests/core/test_notifications.py index 06177b78e..8c5a4ddf3 100644 --- a/tests/core/test_notifications.py +++ b/tests/core/test_notifications.py @@ -4,7 +4,7 @@ import time from collections.abc import Sequence from pathlib import Path -from typing import Self +from typing import Any, Self import pytest from kosong.chat_provider import StreamedMessagePart, ThinkingEffort, TokenUsage @@ -65,6 +65,7 @@ async def generate( system_prompt: str, tools: Sequence[object], history: Sequence[Message], + **_kwargs: Any, ) -> _SequenceStream: return _SequenceStream(self._parts) From 4a0f9ec2eea891a038a697e30db215785f50f8e2 Mon Sep 17 00:00:00 2001 From: 7Sageer <7sageer@djwcb.cn> Date: Wed, 20 May 2026 21:22:59 +0800 Subject: [PATCH 5/5] fix(kimi): cap kosong.step paths that bypass KimiSoul After removing the provider-level ``max_tokens=32000`` default, two call paths still invoked ``kosong.step`` without ``generation_overrides`` and therefore sent Kimi requests with no completion cap at all: - ``KimiSoul`` /btw side-question loop (``execute_side_question`` calls ``kosong.step`` directly with ``soul._runtime.llm.chat_provider``), - the kosong package's developer demo (``kosong.__main__``). Both are flagged by codex as a P1 regression in latency/cost behavior. Reuse ``KimiSoul._compute_completion_overrides`` in the /btw path so it shares the same context-aware budget as the main step loop. Give the kosong demo a conservative ``max_completion_tokens=8192`` default so ``python -m kosong kimi`` does not run unbounded generations. Add a regression test asserting /btw forwards the overrides to ``kosong.step``. --- packages/kosong/src/kosong/__main__.py | 7 +++++- src/kimi_cli/soul/btw.py | 7 ++++++ tests/ui_and_conv/test_btw.py | 30 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/packages/kosong/src/kosong/__main__.py b/packages/kosong/src/kosong/__main__.py index 75cac7447..c01e13da7 100644 --- a/packages/kosong/src/kosong/__main__.py +++ b/packages/kosong/src/kosong/__main__.py @@ -123,7 +123,12 @@ async def main(): assert api_key is not None, "Expect KIMI_API_KEY environment variable" model = model or "kimi-k2-turbo-preview" - chat_provider = Kimi(base_url=base_url, api_key=api_key, model=model) + # ``Kimi.generate`` no longer carries a built-in completion cap, so set a + # conservative default here to keep this developer-facing demo from running + # unbounded generations. + chat_provider = Kimi( + base_url=base_url, api_key=api_key, model=model + ).with_generation_kwargs(max_completion_tokens=8192) case "openai": from kosong.contrib.chat_provider.openai_responses import OpenAIResponses diff --git a/src/kimi_cli/soul/btw.py b/src/kimi_cli/soul/btw.py index f08ab1b04..a8a9c9316 100644 --- a/src/kimi_cli/soul/btw.py +++ b/src/kimi_cli/soul/btw.py @@ -131,6 +131,12 @@ async def execute_side_question( return None, "LLM is not set." chat_provider = soul._runtime.llm.chat_provider # pyright: ignore[reportPrivateUsage] + # Compute per-call completion overrides up front so /btw is subject to the same + # context-aware completion budget as KimiSoul's main step path. Without this the + # bare ``kosong.step`` call below would send a request with no completion cap on + # Kimi providers (the provider-level default was removed), so a side-question + # could run for hundreds of seconds and burn a large completion budget. + generation_overrides = soul._compute_completion_overrides(chat_provider) # pyright: ignore[reportPrivateUsage] system_prompt, history, toolset = _build_btw_context(soul, question) text_chunks: list[str] = [] @@ -149,6 +155,7 @@ def _on_part(part: StreamedMessagePart) -> None: toolset, history, on_message_part=_on_part, + generation_overrides=generation_overrides, ) # Check for text response — but only accept it if the LLM diff --git a/tests/ui_and_conv/test_btw.py b/tests/ui_and_conv/test_btw.py index 7721c2c86..5a2c13e3f 100644 --- a/tests/ui_and_conv/test_btw.py +++ b/tests/ui_and_conv/test_btw.py @@ -443,6 +443,36 @@ async def fake_step(provider, sys_prompt, toolset, history, **kw): assert chunks == ["chunk1", "chunk2"] assert response == "chunk1chunk2" + def test_forwards_generation_overrides_from_soul(self): + """Regression: /btw must forward the per-call completion budget to ``kosong.step``. + + Without this, ``kosong.step`` is called with ``generation_overrides=None`` and the + Kimi provider sends an unbounded request because the provider-level default cap + was removed in this PR. + """ + soul = MagicMock() + soul._runtime.llm.chat_provider = MagicMock() + soul._compute_completion_overrides.return_value = {"max_completion_tokens": 4096} + soul._agent.system_prompt = "sys" + soul._agent.toolset.tools = [] + soul.context.history = [] + + captured_overrides: list[object] = [] + + async def fake_step(provider, sys_prompt, toolset, history, **kw): + captured_overrides.append(kw.get("generation_overrides")) + if kw.get("on_message_part"): + kw["on_message_part"](TextPart(text="ok")) + return _text_result("ok") + + with patch("kimi_cli.soul.btw.kosong.step", side_effect=fake_step): + response, error = asyncio.run(execute_side_question(soul, "hi")) + + assert response == "ok" + assert error is None + assert captured_overrides == [{"max_completion_tokens": 4096}] + soul._compute_completion_overrides.assert_called_once_with(soul._runtime.llm.chat_provider) + # --------------------------------------------------------------------------- # Telemetry tracking for execute_side_question