Skip to content

Commit e81aa20

Browse files
fix typing
1 parent 5a3e109 commit e81aa20

3 files changed

Lines changed: 51 additions & 23 deletions

File tree

sentry_sdk/integrations/openai_agents/patches/agent_run.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from typing import TYPE_CHECKING
1515

1616
if TYPE_CHECKING:
17-
from typing import Any, Optional
17+
from typing import Any, Optional, Callable, Awaitable
1818

1919
from sentry_sdk.tracing import Span
2020

21+
from agents.run_internal.run_steps import SingleStepResult
22+
2123
try:
2224
import agents
2325
except ImportError:
@@ -88,7 +90,11 @@ def _maybe_start_agent_span(
8890
return span
8991

9092

91-
async def _single_turn(original_run_single_turn, *args, **kwargs):
93+
async def _single_turn(
94+
original_run_single_turn: "Callable[..., Awaitable[SingleStepResult]]",
95+
*args: "Any",
96+
**kwargs: "Any",
97+
) -> "SingleStepResult":
9298
"""Patched _run_single_turn that creates agent invocation spans"""
9399
agent = kwargs.get("agent")
94100
context_wrapper = kwargs.get("context_wrapper")
@@ -109,7 +115,7 @@ async def _single_turn(original_run_single_turn, *args, **kwargs):
109115
return result
110116

111117

112-
def _patch_agent_runner_run_single_turn():
118+
def _patch_agent_runner_run_single_turn() -> None:
113119
original_run_single_turn = agents.run.AgentRunner._run_single_turn
114120

115121
@wraps(
@@ -123,10 +129,10 @@ async def patched_run_single_turn(
123129
"""Patched _run_single_turn that creates agent invocation spans"""
124130
return await _single_turn(original_run_single_turn, *args, **kwargs)
125131

126-
agents.run.AgentRunner._run_single_turn = patched_run_single_turn
132+
agents.run.AgentRunner._run_single_turn = classmethod(patched_run_single_turn)
127133

128134

129-
def _patch_run_loop_run_single_turn():
135+
def _patch_run_loop_run_single_turn() -> None:
130136
original_run_single_turn = run_loop.run_single_turn
131137

132138
@wraps(original_run_single_turn)
@@ -137,7 +143,11 @@ async def patched_run_single_turn(*args: "Any", **kwargs: "Any") -> "Any":
137143
agents.run.run_single_turn = patched_run_single_turn
138144

139145

140-
async def _single_turn_streamed(original_run_single_turn_streamed, *args, **kwargs):
146+
async def _single_turn_streamed(
147+
original_run_single_turn_streamed: "Callable[..., Awaitable[SingleStepResult]]",
148+
*args: "Any",
149+
**kwargs: "Any",
150+
) -> "SingleStepResult":
141151
"""Patched _run_single_turn_streamed that creates agent invocation spans for streaming.
142152
143153
Note: Unlike _run_single_turn which uses keyword-only arguments (*,),
@@ -187,7 +197,7 @@ async def _single_turn_streamed(original_run_single_turn_streamed, *args, **kwar
187197
return result
188198

189199

190-
def _patch_agent_runner_run_single_turn_streamed():
200+
def _patch_agent_runner_run_single_turn_streamed() -> None:
191201
original_run_single_turn_streamed = agents.run.AgentRunner._run_single_turn_streamed
192202

193203
@wraps(
@@ -202,10 +212,12 @@ async def patched_run_single_turn_streamed(
202212
original_run_single_turn_streamed, *args, **kwargs
203213
)
204214

205-
agents.run.AgentRunner._run_single_turn_streamed = patched_run_single_turn_streamed
215+
agents.run.AgentRunner._run_single_turn_streamed = classmethod(
216+
patched_run_single_turn_streamed
217+
)
206218

207219

208-
def _patch_run_loop_run_single_turn_streamed():
220+
def _patch_run_loop_run_single_turn_streamed() -> None:
209221
original_run_single_turn_streamed = run_loop.run_single_turn_streamed
210222

211223
@wraps(original_run_single_turn_streamed)
@@ -219,7 +231,11 @@ async def patched_run_single_turn_streamed(*args: "Any", **kwargs: "Any") -> "An
219231
)
220232

221233

222-
async def execute_handoffs(original_execute_handoffs, *args, **kwargs):
234+
async def execute_handoffs(
235+
original_execute_handoffs: "Callable[..., Awaitable[SingleStepResult]]",
236+
*args: "Any",
237+
**kwargs: "Any",
238+
) -> "SingleStepResult":
223239
"""Patched execute_handoffs that creates handoff spans and ends agent span for handoffs"""
224240
context_wrapper = kwargs.get("context_wrapper")
225241
run_handoffs = kwargs.get("run_handoffs")
@@ -247,7 +263,7 @@ async def execute_handoffs(original_execute_handoffs, *args, **kwargs):
247263
return result
248264

249265

250-
def _patch_run_impl_execute_handoffs():
266+
def _patch_run_impl_execute_handoffs() -> None:
251267
original_execute_handoffs = agents._run_impl.RunImpl.execute_handoffs
252268

253269
@wraps(
@@ -260,10 +276,10 @@ async def patched_execute_handoffs(
260276
) -> "Any":
261277
return await execute_handoffs(original_execute_handoffs, *args, **kwargs)
262278

263-
agents._run_impl.RunImpl.execute_handoffs = patched_execute_handoffs
279+
agents._run_impl.RunImpl.execute_handoffs = classmethod(patched_execute_handoffs)
264280

265281

266-
def _patch_turn_resolution_execute_handoffs():
282+
def _patch_turn_resolution_execute_handoffs() -> None:
267283
original_execute_handoffs = turn_resolution.execute_handoffs
268284

269285
@wraps(original_execute_handoffs)
@@ -273,7 +289,11 @@ async def patched_execute_handoffs(*args: "Any", **kwargs: "Any") -> "Any":
273289
agents.run_internal.turn_resolution.execute_handoffs = patched_execute_handoffs
274290

275291

276-
async def execute_final_output(original_execute_final_output, *args, **kwargs):
292+
async def execute_final_output(
293+
original_execute_final_output: "Callable[..., Awaitable[SingleStepResult]]",
294+
*args: "Any",
295+
**kwargs: "Any",
296+
) -> "SingleStepResult":
277297
"""Patched execute_final_output that ends agent span for final outputs"""
278298

279299
agent = kwargs.get("agent")
@@ -292,7 +312,7 @@ async def execute_final_output(original_execute_final_output, *args, **kwargs):
292312
return result
293313

294314

295-
def _patch_run_impl_execute_final_output():
315+
def _patch_run_impl_execute_final_output() -> None:
296316
original_execute_final_output = agents._run_impl.RunImpl.execute_final_output
297317

298318
@wraps(
@@ -307,10 +327,12 @@ async def patched_execute_final_output(
307327
original_execute_final_output, *args, **kwargs
308328
)
309329

310-
agents._run_impl.RunImpl = patched_execute_final_output
330+
agents._run_impl.RunImpl.execute_final_output = classmethod(
331+
patched_execute_final_output
332+
)
311333

312334

313-
def _patch_turn_resolution_execute_final_output():
335+
def _patch_turn_resolution_execute_final_output() -> None:
314336
original_execute_final_output = turn_resolution.execute_final_output
315337

316338
@wraps(original_execute_final_output)

sentry_sdk/integrations/openai_agents/patches/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import time
3-
from functools import wraps
3+
from functools import wraps, partial
44

55
from sentry_sdk.integrations import DidNotEnable
66

@@ -67,8 +67,10 @@ def _inject_trace_propagation_headers(
6767

6868

6969
def _get_model(
70-
original_get_model, agent: "agents.Agent", run_config: "agents.RunConfig"
71-
):
70+
original_get_model: "Callable[..., agents.Model]",
71+
agent: "agents.Agent",
72+
run_config: "agents.RunConfig",
73+
) -> "agents.Model":
7274
# copy the model to double patching its methods. We use copy on purpose here (instead of deepcopy)
7375
# because we only patch its direct methods, all underlying data can remain unchanged.
7476
model = copy.copy(original_get_model(agent, run_config))

sentry_sdk/integrations/openai_agents/patches/tools.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from functools import wraps
1+
from functools import wraps, partial
22

33
from sentry_sdk.integrations import DidNotEnable
44

@@ -7,15 +7,19 @@
77
from typing import TYPE_CHECKING
88

99
if TYPE_CHECKING:
10-
from typing import Any, Callable
10+
from typing import Any, Callable, Awaitable
1111

1212
try:
1313
import agents
1414
except ImportError:
1515
raise DidNotEnable("OpenAI Agents not installed")
1616

1717

18-
async def _get_all_tools(original_get_all_tools, agent, context_wrapper):
18+
async def _get_all_tools(
19+
original_get_all_tools: "Callable[..., Awaitable[list[agents.Tool]]]",
20+
agent: "agents.Agent",
21+
context_wrapper: "agents.RunContextWrapper",
22+
) -> "list[agents.Tool]":
1923
# Get the original tools
2024
tools = await original_get_all_tools(agent, context_wrapper)
2125

0 commit comments

Comments
 (0)