diff --git a/pyproject.toml b/pyproject.toml index 8e74bd6e1..6ea339047 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dev = [ "pytest-pretty>=1.3.0", "openai-agents>=0.3,<0.7; python_version >= '3.14'", "openai-agents[litellm]>=0.3,<0.7; python_version < '3.14'", + "litellm>=1.83.0", "openinference-instrumentation-google-adk>=0.1.8", "googleapis-common-protos==1.70.0", "pytest-rerunfailures>=16.1", diff --git a/tests/contrib/google_adk_agents/test_google_adk_agents.py b/tests/contrib/google_adk_agents/test_google_adk_agents.py index d7ccd4699..a02e98b3f 100644 --- a/tests/contrib/google_adk_agents/test_google_adk_agents.py +++ b/tests/contrib/google_adk_agents/test_google_adk_agents.py @@ -855,3 +855,111 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client): ), "annotate_trip": "SFO->LAX:3", } + + +def litellm_agent(model_name: str) -> Agent: + return Agent( + name="litellm_test_agent", + model=TemporalModel(model_name), + ) + + +@workflow.defn +class LiteLlmWorkflow: + @workflow.run + async def run(self, prompt: str, model_name: str) -> Event | None: + agent = litellm_agent(model_name) + + runner = InMemoryRunner( + agent=agent, + app_name="litellm_test_app", + ) + + session = await runner.session_service.create_session( + app_name="litellm_test_app", user_id="test" + ) + + last_event = None + async with Aclosing( + runner.run_async( + user_id="test", + session_id=session.id, + new_message=types.Content(role="user", parts=[types.Part(text=prompt)]), + ) + ) as agen: + async for event in agen: + last_event = event + + return last_event + + +@pytest.mark.asyncio +async def test_litellm_model(client: Client): + """Test that a litellm-backed model works with TemporalModel through a full Temporal workflow.""" + import litellm as litellm_module + from google.adk.models.lite_llm import LiteLlm + from litellm import ModelResponse + from litellm.llms.custom_llm import CustomLLM + + class FakeLiteLlmProvider(CustomLLM): + """A fake litellm provider that returns canned responses locally.""" + + def _make_response(self, model: str) -> ModelResponse: + return ModelResponse( + choices=[ + { + "message": { + "content": "hello from litellm", + "role": "assistant", + }, + "index": 0, + "finish_reason": "stop", + } + ], + model=model, + ) + + def completion(self, *args: Any, **kwargs: Any) -> ModelResponse: + model = args[0] if args else kwargs.get("model", "unknown") + return self._make_response(model) + + async def acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse: + return self.completion(*args, **kwargs) + + class FakeLiteLlm(LiteLlm): + """LiteLlm subclass that supports the fake/test-model name for testing.""" + + @classmethod + def supported_models(cls) -> list[str]: + return ["fake/test-model"] + + # Register our fake provider with litellm + litellm_module.custom_provider_map = [ + {"provider": "fake", "custom_handler": FakeLiteLlmProvider()} + ] + + LLMRegistry.register(FakeLiteLlm) + + new_config = client.config() + new_config["plugins"] = [GoogleAdkPlugin()] + client = Client(**new_config) + + async with Worker( + client, + task_queue="adk-task-queue-litellm", + workflows=[LiteLlmWorkflow], + max_cached_workflows=0, + ): + handle = await client.start_workflow( + LiteLlmWorkflow.run, + args=["Say hello", "fake/test-model"], + id=f"litellm-agent-workflow-{uuid.uuid4()}", + task_queue="adk-task-queue-litellm", + execution_timeout=timedelta(seconds=60), + ) + result = await handle.result() + + assert result is not None + assert result.content is not None + assert result.content.parts is not None + assert result.content.parts[0].text == "hello from litellm" diff --git a/uv.lock b/uv.lock index c45409833..6d824cf92 100644 --- a/uv.lock +++ b/uv.lock @@ -2128,15 +2128,15 @@ name = "huggingface-hub" version = "1.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "python_full_version < '3.14'" }, - { name = "fsspec", marker = "python_full_version < '3.14'" }, - { name = "hf-xet", marker = "(python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'arm64') or (python_full_version < '3.14' and platform_machine == 'x86_64')" }, - { name = "httpx", marker = "python_full_version < '3.14'" }, - { name = "packaging", marker = "python_full_version < '3.14'" }, - { name = "pyyaml", marker = "python_full_version < '3.14'" }, - { name = "tqdm", marker = "python_full_version < '3.14'" }, - { name = "typer", marker = "python_full_version < '3.14'" }, - { name = "typing-extensions", marker = "python_full_version < '3.14'" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tqdm" }, + { name = "typer" }, + { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/44/40/68d9b286b125d9318ae95c8f8b206e8672e7244b0eea61ebb4a88037638c/huggingface_hub-1.9.1.tar.gz", hash = "sha256:442af372207cc24dcb089caf507fcd7dbc1217c11d6059a06f6b90afe64e8bd2", size = 750355, upload-time = "2026-04-07T13:47:59.167Z" } wheels = [ @@ -2541,18 +2541,18 @@ name = "litellm" version = "1.83.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "aiohttp", marker = "python_full_version < '3.14'" }, - { name = "click", marker = "python_full_version < '3.14'" }, - { name = "fastuuid", marker = "python_full_version < '3.14'" }, - { name = "httpx", marker = "python_full_version < '3.14'" }, - { name = "importlib-metadata", marker = "python_full_version < '3.14'" }, - { name = "jinja2", marker = "python_full_version < '3.14'" }, - { name = "jsonschema", marker = "python_full_version < '3.14'" }, - { name = "openai", marker = "python_full_version < '3.14'" }, - { name = "pydantic", marker = "python_full_version < '3.14'" }, - { name = "python-dotenv", marker = "python_full_version < '3.14'" }, - { name = "tiktoken", marker = "python_full_version < '3.14'" }, - { name = "tokenizers", marker = "python_full_version < '3.14'" }, + { name = "aiohttp" }, + { name = "click" }, + { name = "fastuuid" }, + { name = "httpx" }, + { name = "importlib-metadata" }, + { name = "jinja2" }, + { name = "jsonschema" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "tiktoken" }, + { name = "tokenizers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/22/92/6ce9737554994ca8e536e5f4f6a87cc7c4774b656c9eb9add071caf7d54b/litellm-1.83.0.tar.gz", hash = "sha256:860bebc76c4bb27b4cf90b4a77acd66dba25aced37e3db98750de8a1766bfb7a", size = 17333062, upload-time = "2026-03-31T05:08:25.331Z" } wheels = [ @@ -5058,6 +5058,7 @@ dev = [ { name = "grpcio-tools" }, { name = "httpx" }, { name = "langsmith" }, + { name = "litellm" }, { name = "maturin" }, { name = "moto", extra = ["s3", "server"] }, { name = "mypy" }, @@ -5119,6 +5120,7 @@ dev = [ { name = "grpcio-tools", specifier = ">=1.48.2,<2" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "langsmith", specifier = ">=0.7.0,<0.8" }, + { name = "litellm", specifier = ">=1.83.0" }, { name = "maturin", specifier = ">=1.8.2" }, { name = "moto", extras = ["s3", "server"], specifier = ">=5" }, { name = "mypy", specifier = "==1.18.2" }, @@ -5161,8 +5163,8 @@ name = "tiktoken" version = "0.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "regex", marker = "python_full_version < '3.14'" }, - { name = "requests", marker = "python_full_version < '3.14'" }, + { name = "regex" }, + { name = "requests" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } wheels = [ @@ -5222,7 +5224,7 @@ name = "tokenizers" version = "0.22.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "huggingface-hub", marker = "python_full_version < '3.14'" }, + { name = "huggingface-hub" }, ] sdist = { url = "https://files.pythonhosted.org/packages/73/6f/f80cfef4a312e1fb34baf7d85c72d4411afde10978d4657f8cdd811d3ccc/tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917", size = 372115, upload-time = "2026-01-05T10:45:15.988Z" } wheels = [ @@ -5365,10 +5367,10 @@ name = "typer" version = "0.24.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "annotated-doc", marker = "python_full_version < '3.14'" }, - { name = "click", marker = "python_full_version < '3.14'" }, - { name = "rich", marker = "python_full_version < '3.14'" }, - { name = "shellingham", marker = "python_full_version < '3.14'" }, + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } wheels = [