Skip to content

Commit 91d33eb

Browse files
committed
fix: allow litellm security patch
1 parent 9e8fb83 commit 91d33eb

13 files changed

Lines changed: 596 additions & 282 deletions

File tree

examples/tutorials/run_agent_test.sh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,18 @@ run_test() {
260260

261261

262262
# Run the tests with retry mechanism
263+
local -a pytest_cmd=("uv" "run" "pytest")
264+
if [ "$BUILD_CLI" = true ]; then
265+
local wheel_file
266+
wheel_file=$(ls /home/runner/work/*/*/dist/agentex_sdk-*.whl 2>/dev/null | head -n1)
267+
if [[ -z "$wheel_file" ]]; then
268+
wheel_file=$(ls "${SCRIPT_DIR}/../../dist/agentex_sdk-*.whl" 2>/dev/null | head -n1)
269+
fi
270+
if [[ -n "$wheel_file" ]]; then
271+
pytest_cmd=("uv" "run" "--with" "$wheel_file" "pytest")
272+
fi
273+
fi
274+
263275
local max_retries=5
264276
local retry_count=0
265277
local exit_code=1
@@ -270,7 +282,7 @@ run_test() {
270282
fi
271283

272284
# Stream pytest output directly in real-time
273-
uv run pytest tests/test_agent.py -v -s
285+
"${pytest_cmd[@]}" tests/test_agent.py -v -s
274286
exit_code=$?
275287

276288
if [ $exit_code -eq 0 ]; then

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ authors = [
99
]
1010

1111
dependencies = [
12-
"httpx>=0.27.2,<0.28",
12+
"httpx>=0.28.1,<0.29",
1313
"pydantic>=2.0.0, <3",
1414
"typing-extensions>=4.14, <5",
1515
"anyio>=3.5.0, <5",
@@ -18,7 +18,9 @@ dependencies = [
1818
"typer>=0.16,<0.17",
1919
"questionary>=2.0.1,<3",
2020
"rich>=13.9.2,<14",
21-
"fastapi>=0.115.0,<0.116",
21+
"fastapi>=0.115.0",
22+
"starlette>=0.49.1",
23+
"tornado>=6.5.5",
2224
"uvicorn>=0.31.1",
2325
"watchfiles>=0.24.0,<1.0",
2426
"python-on-whales>=0.73.0,<0.74",
@@ -28,7 +30,7 @@ dependencies = [
2830
"temporalio>=1.26.0,<2",
2931
"aiohttp>=3.10.10,<4",
3032
"redis>=5.2.0,<6",
31-
"litellm>=1.83.0,<2",
33+
"litellm>=1.83.7,<2",
3234
"kubernetes>=25.0.0,<36.0.0",
3335
"jinja2>=3.1.3,<4",
3436
"mcp[cli]>=1.4.1",

src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: To
131131
logger.warning(f"Failed to parse tool arguments: {tool_context.tool_arguments}")
132132
tool_arguments = {}
133133

134-
await workflow.execute_activity_method(
134+
await workflow.execute_activity(
135135
stream_lifecycle_content,
136136
args=[
137137
self.task_id,
@@ -167,7 +167,7 @@ async def on_tool_end(
167167
else f"call_{id(tool)}"
168168
)
169169

170-
await workflow.execute_activity_method(
170+
await workflow.execute_activity(
171171
stream_lifecycle_content,
172172
args=[
173173
self.task_id,
@@ -195,7 +195,7 @@ async def on_handoff(
195195
from_agent: The agent transferring control
196196
to_agent: The agent receiving control
197197
"""
198-
await workflow.execute_activity_method(
198+
await workflow.execute_activity(
199199
stream_lifecycle_content,
200200
args=[
201201
self.task_id,

src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def on_span_start(self, span: Span) -> None:
2626
input=span.input,
2727
output=span.output,
2828
parent_id=span.parent_id,
29+
task_id=span.task_id,
2930
)
3031

3132
@override
@@ -94,6 +95,7 @@ async def on_span_start(self, span: Span) -> None:
9495
input=span.input,
9596
output=span.output,
9697
data=span.data,
98+
task_id=span.task_id,
9799
)
98100

99101
@override

src/agentex/lib/sdk/fastacp/base/base_acp_server.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import uvicorn
1212
from fastapi import FastAPI, Request
1313
from pydantic import TypeAdapter, ValidationError
14+
from starlette.types import Send, Scope, ASGIApp, Receive
1415
from fastapi.responses import StreamingResponse
15-
from starlette.middleware.base import BaseHTTPMiddleware
1616

1717
from agentex.lib.types.acp import (
1818
RPC_SYNC_METHODS,
@@ -44,17 +44,19 @@
4444
task_message_update_adapter = TypeAdapter(TaskMessageUpdate)
4545

4646

47-
class RequestIDMiddleware(BaseHTTPMiddleware):
48-
"""Middleware to extract or generate request IDs and add them to logs and response headers"""
47+
class RequestIDMiddleware:
48+
"""Pure ASGI middleware to set request IDs without buffering streaming responses."""
4949

50-
async def dispatch(self, request: Request, call_next): # type: ignore[override]
51-
# Extract request ID from header or generate a new one if there isn't one
52-
request_id = request.headers.get("x-request-id") or uuid.uuid4().hex
53-
# Store request ID in request state for access in handlers
54-
ctx_var_request_id.set(request_id)
55-
# Process request
56-
response = await call_next(request)
57-
return response
50+
def __init__(self, app: ASGIApp) -> None:
51+
self.app = app
52+
53+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
54+
if scope["type"] == "http":
55+
headers = dict(scope.get("headers", []))
56+
raw_request_id = headers.get(b"x-request-id", b"")
57+
request_id = raw_request_id.decode() if raw_request_id else uuid.uuid4().hex
58+
ctx_var_request_id.set(request_id)
59+
await self.app(scope, receive, send)
5860

5961

6062
class BaseACPServer(FastAPI):

src/agentex/resources/agents.py

Lines changed: 129 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import json
6-
from typing import Union, Optional, Generator, AsyncGenerator
6+
from typing import Any, Union, Optional, Generator, AsyncGenerator
77
from typing_extensions import Literal
88

99
import httpx
@@ -359,7 +359,7 @@ def create_task(
359359
) -> CreateTaskResponse:
360360
if agent_id is not None and agent_name is not None:
361361
raise ValueError("Either agent_id or agent_name must be provided, but not both")
362-
362+
363363
if agent_id is not None:
364364
raw_agent_rpc_response = self.rpc(
365365
agent_id=agent_id,
@@ -386,7 +386,7 @@ def create_task(
386386
)
387387
else:
388388
raise ValueError("Either agent_id or agent_name must be provided")
389-
389+
390390
return CreateTaskResponse.model_validate(raw_agent_rpc_response, from_attributes=True)
391391

392392
def cancel_task(
@@ -453,38 +453,71 @@ def send_message(
453453
) -> SendMessageResponse:
454454
if agent_id is not None and agent_name is not None:
455455
raise ValueError("Either agent_id or agent_name must be provided, but not both")
456-
456+
457457
if "stream" in params and params["stream"] == True:
458458
raise ValueError("If stream is set to True, use send_message_stream() instead")
459+
460+
if agent_id is not None:
461+
raw_agent_rpc_response = self.with_streaming_response.rpc(
462+
agent_id=agent_id,
463+
method="message/send",
464+
params=params,
465+
id=id,
466+
jsonrpc=jsonrpc,
467+
extra_headers=extra_headers,
468+
extra_query=extra_query,
469+
extra_body=extra_body,
470+
timeout=timeout,
471+
)
472+
elif agent_name is not None:
473+
raw_agent_rpc_response = self.with_streaming_response.rpc_by_name(
474+
agent_name=agent_name,
475+
method="message/send",
476+
params=params,
477+
id=id,
478+
jsonrpc=jsonrpc,
479+
extra_headers=extra_headers,
480+
extra_query=extra_query,
481+
extra_body=extra_body,
482+
timeout=timeout,
483+
)
459484
else:
460-
if agent_id is not None:
461-
raw_agent_rpc_response = self.rpc(
462-
agent_id=agent_id,
463-
method="message/send",
464-
params=params,
465-
id=id,
466-
jsonrpc=jsonrpc,
467-
extra_headers=extra_headers,
468-
extra_query=extra_query,
469-
extra_body=extra_body,
470-
timeout=timeout,
471-
)
472-
elif agent_name is not None:
473-
raw_agent_rpc_response = self.rpc_by_name(
474-
agent_name=agent_name,
475-
method="message/send",
476-
params=params,
477-
id=id,
478-
jsonrpc=jsonrpc,
479-
extra_headers=extra_headers,
480-
extra_query=extra_query,
481-
extra_body=extra_body,
482-
timeout=timeout,
483-
)
484-
else:
485-
raise ValueError("Either agent_id or agent_name must be provided")
486-
487-
return SendMessageResponse.model_validate(raw_agent_rpc_response, from_attributes=True)
485+
raise ValueError("Either agent_id or agent_name must be provided")
486+
487+
task_messages: list[Any] = []
488+
response_meta: dict[str, Any] = {}
489+
490+
with raw_agent_rpc_response as response:
491+
for _line in response.iter_lines():
492+
if not _line:
493+
continue
494+
line = _line.strip()
495+
if line.startswith("data:"):
496+
line = line[len("data:"):].strip()
497+
if not line:
498+
continue
499+
try:
500+
chunk = json.loads(line)
501+
if not response_meta:
502+
response_meta = {"id": chunk.get("id"), "jsonrpc": chunk.get("jsonrpc")}
503+
try:
504+
return SendMessageResponse.model_validate(chunk)
505+
except ValidationError:
506+
pass
507+
chunk_stream = SendMessageStreamResponse.model_validate(chunk, from_attributes=True)
508+
result = chunk_stream.result
509+
if result is not None and getattr(result, "type", None) == "full":
510+
parent = getattr(result, "parent_task_message", None)
511+
if parent is not None:
512+
task_messages.append(parent)
513+
except (json.JSONDecodeError, ValidationError):
514+
continue
515+
516+
return SendMessageResponse(
517+
id=response_meta.get("id"),
518+
jsonrpc=response_meta.get("jsonrpc"),
519+
result=task_messages,
520+
)
488521

489522
def send_message_stream(
490523
self,
@@ -552,8 +585,8 @@ def send_message_stream(
552585
from_attributes=True
553586
)
554587
yield chunk_rpc_response
555-
except json.JSONDecodeError:
556-
# Skip invalid JSON lines
588+
except (json.JSONDecodeError, ValidationError):
589+
# Skip invalid JSON lines or lines that cannot be validated
557590
continue
558591

559592
def send_event(
@@ -1021,38 +1054,71 @@ async def send_message(
10211054
) -> SendMessageResponse:
10221055
if agent_id is not None and agent_name is not None:
10231056
raise ValueError("Either agent_id or agent_name must be provided, but not both")
1024-
1057+
10251058
if "stream" in params and params["stream"] == True:
10261059
raise ValueError("If stream is set to True, use send_message_stream() instead")
1060+
1061+
if agent_id is not None:
1062+
raw_agent_rpc_response = self.with_streaming_response.rpc(
1063+
agent_id=agent_id,
1064+
method="message/send",
1065+
params=params,
1066+
id=id,
1067+
jsonrpc=jsonrpc,
1068+
extra_headers=extra_headers,
1069+
extra_query=extra_query,
1070+
extra_body=extra_body,
1071+
timeout=timeout,
1072+
)
1073+
elif agent_name is not None:
1074+
raw_agent_rpc_response = self.with_streaming_response.rpc_by_name(
1075+
agent_name=agent_name,
1076+
method="message/send",
1077+
params=params,
1078+
id=id,
1079+
jsonrpc=jsonrpc,
1080+
extra_headers=extra_headers,
1081+
extra_query=extra_query,
1082+
extra_body=extra_body,
1083+
timeout=timeout,
1084+
)
10271085
else:
1028-
if agent_id is not None:
1029-
raw_agent_rpc_response = await self.rpc(
1030-
agent_id=agent_id,
1031-
method="message/send",
1032-
params=params,
1033-
id=id,
1034-
jsonrpc=jsonrpc,
1035-
extra_headers=extra_headers,
1036-
extra_query=extra_query,
1037-
extra_body=extra_body,
1038-
timeout=timeout,
1039-
)
1040-
elif agent_name is not None:
1041-
raw_agent_rpc_response = await self.rpc_by_name(
1042-
agent_name=agent_name,
1043-
method="message/send",
1044-
params=params,
1045-
id=id,
1046-
jsonrpc=jsonrpc,
1047-
extra_headers=extra_headers,
1048-
extra_query=extra_query,
1049-
extra_body=extra_body,
1050-
timeout=timeout,
1051-
)
1052-
else:
1053-
raise ValueError("Either agent_id or agent_name must be provided")
1054-
1055-
return SendMessageResponse.model_validate(raw_agent_rpc_response, from_attributes=True)
1086+
raise ValueError("Either agent_id or agent_name must be provided")
1087+
1088+
task_messages: list[Any] = []
1089+
response_meta: dict[str, Any] = {}
1090+
1091+
async with raw_agent_rpc_response as response:
1092+
async for _line in response.iter_lines():
1093+
if not _line:
1094+
continue
1095+
line = _line.strip()
1096+
if line.startswith("data:"):
1097+
line = line[len("data:"):].strip()
1098+
if not line:
1099+
continue
1100+
try:
1101+
chunk = json.loads(line)
1102+
if not response_meta:
1103+
response_meta = {"id": chunk.get("id"), "jsonrpc": chunk.get("jsonrpc")}
1104+
try:
1105+
return SendMessageResponse.model_validate(chunk)
1106+
except ValidationError:
1107+
pass
1108+
chunk_stream = SendMessageStreamResponse.model_validate(chunk, from_attributes=True)
1109+
result = chunk_stream.result
1110+
if result is not None and getattr(result, "type", None) == "full":
1111+
parent = getattr(result, "parent_task_message", None)
1112+
if parent is not None:
1113+
task_messages.append(parent)
1114+
except (json.JSONDecodeError, ValidationError):
1115+
continue
1116+
1117+
return SendMessageResponse(
1118+
id=response_meta.get("id"),
1119+
jsonrpc=response_meta.get("jsonrpc"),
1120+
result=task_messages,
1121+
)
10561122

10571123
async def send_message_stream(
10581124
self,

0 commit comments

Comments
 (0)