Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dev = [
"pytest-pretty>=1.3.0",
"openai-agents>=0.3,<0.7; python_version >= '3.14'",
"openai-agents[litellm]>=0.3,<0.7; python_version < '3.14'",
"litellm>=1.83.0",
"openinference-instrumentation-google-adk>=0.1.8",
"googleapis-common-protos==1.70.0",
"pytest-rerunfailures>=16.1",
Expand Down
108 changes: 108 additions & 0 deletions tests/contrib/google_adk_agents/test_google_adk_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,3 +855,111 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
),
"annotate_trip": "SFO->LAX:3",
}


def litellm_agent(model_name: str) -> Agent:
    """Build a minimal ADK agent whose model is wrapped by ``TemporalModel``.

    Args:
        model_name: litellm model identifier to hand to ``TemporalModel``.

    Returns:
        An ``Agent`` named ``litellm_test_agent`` backed by the wrapped model.
    """
    wrapped_model = TemporalModel(model_name)
    return Agent(name="litellm_test_agent", model=wrapped_model)


@workflow.defn
class LiteLlmWorkflow:
    """Temporal workflow that drives a litellm-backed ADK agent to completion."""

    @workflow.run
    async def run(self, prompt: str, model_name: str) -> Event | None:
        """Run the agent on *prompt* and return the last event it emitted.

        Args:
            prompt: User message text sent to the agent.
            model_name: litellm model identifier used to build the agent.

        Returns:
            The final ``Event`` from the agent run, or ``None`` if the run
            produced no events.
        """
        runner = InMemoryRunner(
            agent=litellm_agent(model_name),
            app_name="litellm_test_app",
        )

        session = await runner.session_service.create_session(
            app_name="litellm_test_app", user_id="test"
        )

        message = types.Content(role="user", parts=[types.Part(text=prompt)])
        final_event: Event | None = None

        # Aclosing guarantees the async generator is closed even if
        # iteration is interrupted.
        event_stream = runner.run_async(
            user_id="test",
            session_id=session.id,
            new_message=message,
        )
        async with Aclosing(event_stream) as events:
            async for evt in events:
                final_event = evt

        return final_event


@pytest.mark.asyncio
async def test_litellm_model(client: Client):
    """Test that a litellm-backed model works with TemporalModel through a full Temporal workflow.

    A fake litellm provider is registered so the test never makes a network
    call; the workflow should surface the provider's canned response.
    """
    import litellm as litellm_module
    from google.adk.models.lite_llm import LiteLlm
    from litellm import ModelResponse
    from litellm.llms.custom_llm import CustomLLM

    class FakeLiteLlmProvider(CustomLLM):
        """A fake litellm provider that returns canned responses locally."""

        def _make_response(self, model: str) -> ModelResponse:
            # Single-choice canned completion shaped like a real provider reply.
            return ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "hello from litellm",
                            "role": "assistant",
                        },
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
                model=model,
            )

        def completion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            # litellm may pass the model positionally or by keyword.
            model = args[0] if args else kwargs.get("model", "unknown")
            return self._make_response(model)

        async def acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            return self.completion(*args, **kwargs)

    class FakeLiteLlm(LiteLlm):
        """LiteLlm subclass that supports the fake/test-model name for testing."""

        @classmethod
        def supported_models(cls) -> list[str]:
            return ["fake/test-model"]

    # Register our fake provider with litellm. ``custom_provider_map`` is
    # process-global state; remember the prior value and restore it on exit
    # so this test does not leak the fake provider into other tests.
    previous_provider_map = getattr(litellm_module, "custom_provider_map", [])
    litellm_module.custom_provider_map = [
        {"provider": "fake", "custom_handler": FakeLiteLlmProvider()}
    ]

    # NOTE(review): LLMRegistry exposes no unregister API, so this
    # registration persists for the process; the "fake/test-model" name is
    # unique to this test, so the residue should be harmless — confirm.
    LLMRegistry.register(FakeLiteLlm)

    try:
        new_config = client.config()
        new_config["plugins"] = [GoogleAdkPlugin()]
        client = Client(**new_config)

        async with Worker(
            client,
            task_queue="adk-task-queue-litellm",
            workflows=[LiteLlmWorkflow],
            max_cached_workflows=0,
        ):
            handle = await client.start_workflow(
                LiteLlmWorkflow.run,
                args=["Say hello", "fake/test-model"],
                id=f"litellm-agent-workflow-{uuid.uuid4()}",
                task_queue="adk-task-queue-litellm",
                execution_timeout=timedelta(seconds=60),
            )
            result = await handle.result()

        assert result is not None
        assert result.content is not None
        assert result.content.parts is not None
        assert result.content.parts[0].text == "hello from litellm"
    finally:
        # Undo the global provider-map mutation regardless of test outcome.
        litellm_module.custom_provider_map = previous_provider_map
58 changes: 30 additions & 28 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading