diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9557f3dbcd..186ba81148 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -528,6 +528,10 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: conversation = self._prepare_conversation(payloads) temperature = payloads.get("temperature", 0.7) + # Keep the retry knobs local to this narrow Gemini workaround. + EMPTY_PARTS_RETRY_DELAY_SECONDS = 0.2 + MAX_EMPTY_PARTS_RETRIES = 3 + empty_parts_retry_count = 0 result: types.GenerateContentResponse | None = None while True: @@ -550,7 +554,9 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: logger.error(f"请求失败, 返回的 candidates 为空: {result}") raise Exception("请求失败, 返回的 candidates 为空。") - if result.candidates[0].finish_reason == types.FinishReason.RECITATION: + candidate = result.candidates[0] + + if candidate.finish_reason == types.FinishReason.RECITATION: if temperature > 2: raise Exception("温度参数已超过最大值2,仍然发生recitation") temperature += 0.2 @@ -559,6 +565,26 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: ) continue + if ( + candidate.finish_reason == types.FinishReason.STOP + and candidate.content + and not candidate.content.parts + ): + if empty_parts_retry_count < MAX_EMPTY_PARTS_RETRIES: + empty_parts_retry_count += 1 + logger.warning( + "Gemini 返回 STOP 但 candidate.content.parts 为空,正在重试(%s/%s): %s", + empty_parts_retry_count, + MAX_EMPTY_PARTS_RETRIES, + candidate, + ) + await asyncio.sleep(EMPTY_PARTS_RETRY_DELAY_SECONDS) + continue + logger.warning( + "Gemini 在 %s 次重试后仍返回空的 candidate.content.parts,将沿用现有失败逻辑。", + MAX_EMPTY_PARTS_RETRIES, + ) + break except APIError as e: diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py new file mode 100644 index 0000000000..2af59c3bef --- /dev/null +++ b/tests/test_gemini_source.py @@ -0,0 +1,151 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from google.genai import types + +from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI + + +def _make_provider(overrides: dict | None = None) -> ProviderGoogleGenAI: + provider_config = { + "id": "test-gemini", + "type": "googlegenai_chat_completion", + "model": "gemini-3-flash-preview", + "key": ["test-key"], + } + if overrides: + provider_config.update(overrides) + return ProviderGoogleGenAI( + provider_config=provider_config, + provider_settings={}, + ) + + +def _make_result(parts: list | None, response_id: str) -> SimpleNamespace: + return SimpleNamespace( + candidates=[ + SimpleNamespace( + content=SimpleNamespace(parts=parts), + finish_reason=types.FinishReason.STOP, + ) + ], + response_id=response_id, + usage_metadata=None, + ) + + +def _make_text_part(text: str) -> SimpleNamespace: + return SimpleNamespace( + text=text, + thought=False, + function_call=None, + inline_data=None, + thought_signature=None, + ) + + +def _make_function_call_part( + name: str, + args: dict, + *, + tool_call_id: str | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + text=None, + thought=False, + function_call=SimpleNamespace( + name=name, + args=args, + id=tool_call_id, + ), + inline_data=None, + thought_signature=None, + ) + + +@pytest.mark.asyncio +async def test_gemini_query_retries_empty_parts_response( + monkeypatch: pytest.MonkeyPatch, +): + provider = _make_provider() + try: + sleep_mock = AsyncMock() + monkeypatch.setattr( + "astrbot.core.provider.sources.gemini_source.asyncio.sleep", + sleep_mock, + ) + + empty_result = _make_result([], "empty-response") + success_result = _make_result( + [_make_text_part("Recovered response")], + "success-response", + ) + generate_content = AsyncMock(side_effect=[empty_result, success_result]) + provider.client = SimpleNamespace( + models=SimpleNamespace(generate_content=generate_content), + aclose=AsyncMock(), + ) + + response = await provider._query( + payloads={ + "messages": [{"role": "user", "content": "hello"}], + "model": "gemini-3-flash-preview", + }, + tools=None, + ) + + assert generate_content.await_count == 2 + sleep_mock.assert_awaited_once() + assert response.completion_text == "Recovered response" + assert response.id == "success-response" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_gemini_query_retries_empty_parts_before_function_call( + monkeypatch: pytest.MonkeyPatch, +): + provider = _make_provider() + try: + sleep_mock = AsyncMock() + monkeypatch.setattr( + "astrbot.core.provider.sources.gemini_source.asyncio.sleep", + sleep_mock, + ) + + empty_result = _make_result([], "empty-response") + tool_call_result = _make_result( + [ + _make_function_call_part( + "read_file", + {"path": "README.md"}, + tool_call_id="call-readme", + ) + ], + "tool-call-response", + ) + generate_content = AsyncMock(side_effect=[empty_result, tool_call_result]) + provider.client = SimpleNamespace( + models=SimpleNamespace(generate_content=generate_content), + aclose=AsyncMock(), + ) + + response = await provider._query( + payloads={ + "messages": [{"role": "user", "content": "summarize the file"}], + "model": "gemini-3-flash-preview", + }, + tools=None, + ) + + assert generate_content.await_count == 2 + sleep_mock.assert_awaited_once() + assert response.role == "tool" + assert response.tools_call_name == ["read_file"] + assert response.tools_call_args == [{"path": "README.md"}] + assert response.tools_call_ids == ["call-readme"] + assert response.id == "tool-call-response" + finally: + await provider.terminate()