Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,10 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:

conversation = self._prepare_conversation(payloads)
temperature = payloads.get("temperature", 0.7)
# Keep the retry knobs local to this narrow Gemini workaround.
EMPTY_PARTS_RETRY_DELAY_SECONDS = 0.2
MAX_EMPTY_PARTS_RETRIES = 3
empty_parts_retry_count = 0

result: types.GenerateContentResponse | None = None
while True:
Expand All @@ -550,7 +554,9 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
raise Exception("请求失败, 返回的 candidates 为空。")

if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
candidate = result.candidates[0]

if candidate.finish_reason == types.FinishReason.RECITATION:
if temperature > 2:
raise Exception("温度参数已超过最大值2,仍然发生recitation")
temperature += 0.2
Expand All @@ -559,6 +565,26 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
)
continue

if (
candidate.finish_reason == types.FinishReason.STOP
and candidate.content
and not candidate.content.parts
):
if empty_parts_retry_count < MAX_EMPTY_PARTS_RETRIES:
empty_parts_retry_count += 1
logger.warning(
"Gemini 返回 STOP 但 candidate.content.parts 为空,正在重试(%s/%s): %s",
empty_parts_retry_count,
MAX_EMPTY_PARTS_RETRIES,
candidate,
)
await asyncio.sleep(EMPTY_PARTS_RETRY_DELAY_SECONDS)
continue
logger.warning(
"Gemini 在 %s 次重试后仍返回空的 candidate.content.parts,将沿用现有失败逻辑。",
MAX_EMPTY_PARTS_RETRIES,
)

break

except APIError as e:
Expand Down
151 changes: 151 additions & 0 deletions tests/test_gemini_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest
from google.genai import types

from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI


def _make_provider(overrides: dict | None = None) -> ProviderGoogleGenAI:
provider_config = {
"id": "test-gemini",
"type": "googlegenai_chat_completion",
"model": "gemini-3-flash-preview",
"key": ["test-key"],
}
if overrides:
provider_config.update(overrides)
return ProviderGoogleGenAI(
provider_config=provider_config,
provider_settings={},
)


def _make_result(parts: list | None, response_id: str) -> SimpleNamespace:
return SimpleNamespace(
candidates=[
SimpleNamespace(
content=SimpleNamespace(parts=parts),
finish_reason=types.FinishReason.STOP,
)
],
response_id=response_id,
usage_metadata=None,
)


def _make_text_part(text: str) -> SimpleNamespace:
return SimpleNamespace(
text=text,
thought=False,
function_call=None,
inline_data=None,
thought_signature=None,
)


def _make_function_call_part(
name: str,
args: dict,
*,
tool_call_id: str | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
text=None,
thought=False,
function_call=SimpleNamespace(
name=name,
args=args,
id=tool_call_id,
),
inline_data=None,
thought_signature=None,
)


@pytest.mark.asyncio
async def test_gemini_query_retries_empty_parts_response(
monkeypatch: pytest.MonkeyPatch,
):
provider = _make_provider()
try:
sleep_mock = AsyncMock()
monkeypatch.setattr(
"astrbot.core.provider.sources.gemini_source.asyncio.sleep",
sleep_mock,
)

empty_result = _make_result([], "empty-response")
success_result = _make_result(
[_make_text_part("Recovered response")],
"success-response",
)
generate_content = AsyncMock(side_effect=[empty_result, success_result])
provider.client = SimpleNamespace(
models=SimpleNamespace(generate_content=generate_content),
aclose=AsyncMock(),
)

response = await provider._query(
payloads={
"messages": [{"role": "user", "content": "hello"}],
"model": "gemini-3-flash-preview",
},
tools=None,
)

assert generate_content.await_count == 2
sleep_mock.assert_awaited_once()
assert response.completion_text == "Recovered response"
assert response.id == "success-response"
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_gemini_query_retries_empty_parts_before_function_call(
monkeypatch: pytest.MonkeyPatch,
):
provider = _make_provider()
try:
sleep_mock = AsyncMock()
monkeypatch.setattr(
"astrbot.core.provider.sources.gemini_source.asyncio.sleep",
sleep_mock,
)

empty_result = _make_result([], "empty-response")
tool_call_result = _make_result(
[
_make_function_call_part(
"read_file",
{"path": "README.md"},
tool_call_id="call-readme",
)
],
"tool-call-response",
)
generate_content = AsyncMock(side_effect=[empty_result, tool_call_result])
provider.client = SimpleNamespace(
models=SimpleNamespace(generate_content=generate_content),
aclose=AsyncMock(),
)

response = await provider._query(
payloads={
"messages": [{"role": "user", "content": "summarize the file"}],
"model": "gemini-3-flash-preview",
},
tools=None,
)

assert generate_content.await_count == 2
sleep_mock.assert_awaited_once()
assert response.role == "tool"
assert response.tools_call_name == ["read_file"]
assert response.tools_call_args == [{"path": "README.md"}]
assert response.tools_call_ids == ["call-readme"]
assert response.id == "tool-call-response"
finally:
await provider.terminate()
Loading