57 changes: 57 additions & 0 deletions src/agents/models/chatcmpl_helpers.py
@@ -3,6 +3,12 @@
from contextvars import ContextVar

from openai import AsyncOpenAI
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
from openai.types.responses.response_text_delta_event import (
    Logprob as DeltaLogprob,
    LogprobTopLogprob as DeltaTopLogprob,
)

from ..model_settings import ModelSettings
from ..version import __version__
@@ -41,3 +47,54 @@ def get_stream_options_param(
        )
        stream_options = {"include_usage": include_usage} if include_usage is not None else None
        return stream_options

    @classmethod
    def convert_logprobs_for_output_text(
        cls, logprobs: list[ChatCompletionTokenLogprob] | None
    ) -> list[Logprob] | None:
        if not logprobs:
            return None

        converted: list[Logprob] = []
        for token_logprob in logprobs:
            converted.append(
                Logprob(
                    token=token_logprob.token,
                    logprob=token_logprob.logprob,
                    bytes=token_logprob.bytes or [],
                    top_logprobs=[
                        LogprobTopLogprob(
                            token=top_logprob.token,
                            logprob=top_logprob.logprob,
                            bytes=top_logprob.bytes or [],
                        )
                        for top_logprob in token_logprob.top_logprobs
                    ],
                )
            )
        return converted

    @classmethod
    def convert_logprobs_for_text_delta(
        cls, logprobs: list[ChatCompletionTokenLogprob] | None
    ) -> list[DeltaLogprob] | None:
        if not logprobs:
            return None

        converted: list[DeltaLogprob] = []
        for token_logprob in logprobs:
            converted.append(
                DeltaLogprob(
                    token=token_logprob.token,
                    logprob=token_logprob.logprob,
                    top_logprobs=[
                        DeltaTopLogprob(
                            token=top_logprob.token,
                            logprob=top_logprob.logprob,
                        )
                        for top_logprob in token_logprob.top_logprobs
                    ]
                    or None,
                )
            )
        return converted
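
The two converters are pure functions over openai SDK models, so they can be exercised on their own. A minimal sketch, assuming the package is importable as agents (the repo's src layout); the tokens, byte values and logprobs below are made-up sample data:

from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)

from agents.models.chatcmpl_helpers import ChatCmplHelpers

# One Chat Completions-style token logprob with a single alternative.
chat_logprobs = [
    ChatCompletionTokenLogprob(
        token="Hi",
        logprob=-0.5,
        bytes=[72, 105],
        top_logprobs=[TopLogprob(token="Hi", logprob=-0.4, bytes=[72, 105])],
    )
]

# Responses-style objects for ResponseOutputText.logprobs (byte payloads kept).
output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(chat_logprobs)
# Objects for ResponseTextDeltaEvent.logprobs (byte payloads not copied over).
delta_logprobs = ChatCmplHelpers.convert_logprobs_for_text_delta(chat_logprobs)

print([(lp.token, lp.logprob) for lp in output_logprobs or []])  # [('Hi', -0.5)]
print([(lp.token, lp.logprob) for lp in delta_logprobs or []])  # [('Hi', -0.5)]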
18 changes: 17 additions & 1 deletion src/agents/models/chatcmpl_stream_handler.py
@@ -42,6 +42,7 @@
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from ..items import TResponseStreamEvent
from .chatcmpl_helpers import ChatCmplHelpers
from .fake_id import FAKE_RESPONSES_ID


@@ -105,6 +106,7 @@ async def handle_stream(
                continue

            delta = chunk.choices[0].delta
            choice_logprobs = chunk.choices[0].logprobs

            # Handle thinking blocks from Anthropic (for preserving signatures)
            if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
@@ -266,6 +268,15 @@ async def handle_stream(
type="response.content_part.added",
sequence_number=sequence_number.get_and_increment(),
)
delta_logprobs = (
ChatCmplHelpers.convert_logprobs_for_text_delta(
choice_logprobs.content if choice_logprobs else None
)
or []
)
output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(
choice_logprobs.content if choice_logprobs else None
)
# Emit the delta for this segment of content
yield ResponseTextDeltaEvent(
content_index=state.text_content_index_and_output[0],
Expand All @@ -275,10 +286,15 @@ async def handle_stream(
is not None, # fixed 0 -> 0 or 1
type="response.output_text.delta",
sequence_number=sequence_number.get_and_increment(),
logprobs=[],
logprobs=delta_logprobs,
)
# Accumulate the text into the response part
state.text_content_index_and_output[1].text += delta.content
if output_logprobs:
existing_logprobs = state.text_content_index_and_output[1].logprobs or []
state.text_content_index_and_output[1].logprobs = (
existing_logprobs + output_logprobs
)

# Handle refusals (model declines to answer)
# This is always set by the OpenAI API, but not by others e.g. LiteLLM
29 changes: 28 additions & 1 deletion src/agents/models/openai_chatcompletions.py
@@ -9,7 +9,13 @@
from openai.types import ChatModel
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.responses import Response
from openai.types.responses import (
    Response,
    ResponseOutputItem,
    ResponseOutputMessage,
    ResponseOutputText,
)
from openai.types.responses.response_output_text import Logprob
from openai.types.responses.response_prompt_param import ResponsePromptParam

from .. import _debug
@@ -119,12 +125,33 @@ async def get_response(

        items = Converter.message_to_output_items(message) if message is not None else []

        logprob_models = None
        if first_choice and first_choice.logprobs and first_choice.logprobs.content:
            logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text(
                first_choice.logprobs.content
            )

        if logprob_models:
            self._attach_logprobs_to_output(items, logprob_models)

        return ModelResponse(
            output=items,
            usage=usage,
            response_id=None,
        )

    def _attach_logprobs_to_output(
        self, output_items: list[ResponseOutputItem], logprobs: list[Logprob]
    ) -> None:
        for output_item in output_items:
            if not isinstance(output_item, ResponseOutputMessage):
                continue

            for content in output_item.content:
                if isinstance(content, ResponseOutputText):
                    content.logprobs = logprobs
                    return

    async def stream_response(
        self,
        system_instructions: str | None,
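
With the change above, get_response attaches the converted logprobs to the first ResponseOutputText part of the output message. A small sketch of reading them back from a ModelResponse.output list; extract_token_logprobs is a hypothetical helper, not part of this diff:

from openai.types.responses import ResponseOutputItem, ResponseOutputMessage, ResponseOutputText


def extract_token_logprobs(output_items: list[ResponseOutputItem]) -> list[tuple[str, float]]:
    """Return (token, logprob) pairs from the first text part that carries logprobs."""
    for item in output_items:
        if not isinstance(item, ResponseOutputMessage):
            continue
        for part in item.content:
            if isinstance(part, ResponseOutputText) and part.logprobs:
                return [(lp.token, lp.logprob) for lp in part.logprobs]
    return []


# Usage (assuming model_response came from OpenAIChatCompletionsModel.get_response):
# pairs = extract_token_logprobs(model_response.output)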
65 changes: 64 additions & 1 deletion tests/test_openai_chatcompletions.py
@@ -6,13 +6,17 @@
import httpx
import pytest
from openai import AsyncOpenAI, omit
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (  # type: ignore[attr-defined]
    ChatCompletionMessageFunctionToolCall,
    Function,
)
from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)
from openai.types.completion_usage import (
    CompletionUsage,
    PromptTokensDetails,
@@ -98,6 +102,65 @@ async def patched_fetch_response(self, *args, **kwargs):
    assert resp.response_id is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_get_response_attaches_logprobs(monkeypatch) -> None:
    msg = ChatCompletionMessage(role="assistant", content="Hi!")
    choice = Choice(
        index=0,
        finish_reason="stop",
        message=msg,
        logprobs=ChoiceLogprobs(
            content=[
                ChatCompletionTokenLogprob(
                    token="Hi",
                    logprob=-0.5,
                    bytes=[1],
                    top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
                ),
                ChatCompletionTokenLogprob(
                    token="!",
                    logprob=-0.1,
                    bytes=[2],
                    top_logprobs=[TopLogprob(token="!", logprob=-0.1, bytes=[2])],
                ),
            ]
        ),
    )
    chat = ChatCompletion(
        id="resp-id",
        created=0,
        model="fake",
        object="chat.completion",
        choices=[choice],
        usage=None,
    )

    async def patched_fetch_response(self, *args, **kwargs):
        return chat

    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
    resp: ModelResponse = await model.get_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        conversation_id=None,
        prompt=None,
    )
    assert len(resp.output) == 1
    assert isinstance(resp.output[0], ResponseOutputMessage)
    text_part = resp.output[0].content[0]
    assert isinstance(text_part, ResponseOutputText)
    assert text_part.logprobs is not None
    assert [lp.token for lp in text_part.logprobs] == ["Hi", "!"]


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_get_response_with_refusal(monkeypatch) -> None:
113 changes: 113 additions & 0 deletions tests/test_openai_chatcompletions_stream.py
@@ -7,6 +7,11 @@
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
    ChoiceLogprobs,
)
from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)
from openai.types.completion_usage import (
    CompletionTokensDetails,
@@ -15,6 +20,7 @@
)
from openai.types.responses import (
    Response,
    ResponseCompletedEvent,
    ResponseFunctionToolCall,
    ResponseOutputMessage,
    ResponseOutputRefusal,
@@ -128,6 +134,113 @@
    assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_includes_logprobs(monkeypatch) -> None:
    chunk1 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[
            Choice(
                index=0,
                delta=ChoiceDelta(content="Hi"),
                logprobs=ChoiceLogprobs(
                    content=[
                        ChatCompletionTokenLogprob(
                            token="Hi",
                            logprob=-0.5,
                            bytes=[1],
                            top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
                        )
                    ]
                ),
            )
        ],
    )
    chunk2 = ChatCompletionChunk(
        id="chunk-id",
        created=1,
        model="fake",
        object="chat.completion.chunk",
        choices=[
            Choice(
                index=0,
                delta=ChoiceDelta(content=" there"),
                logprobs=ChoiceLogprobs(
                    content=[
                        ChatCompletionTokenLogprob(
                            token=" there",
                            logprob=-0.25,
                            bytes=[2],
                            top_logprobs=[TopLogprob(token=" there", logprob=-0.25, bytes=[2])],
                        )
                    ]
                ),
            )
        ],
        usage=CompletionUsage(
            completion_tokens=5,
            prompt_tokens=7,
            total_tokens=12,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
        ),
    )

    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
        for c in (chunk1, chunk2):
            yield c

    async def patched_fetch_response(self, *args, **kwargs):
        resp = Response(
            id="resp-id",
            created_at=0,
            model="fake-model",
            object="response",
            output=[],
            tool_choice="none",
            tools=[],
            parallel_tool_calls=False,
        )
        return resp, fake_stream()

    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
    output_events = []
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        conversation_id=None,
        prompt=None,
    ):
        output_events.append(event)

    text_delta_events = [
        event for event in output_events if event.type == "response.output_text.delta"
    ]
    assert len(text_delta_events) == 2
    assert [lp.token for lp in text_delta_events[0].logprobs] == ["Hi"]
    assert [lp.token for lp in text_delta_events[1].logprobs] == [" there"]

    completed_event = next(event for event in output_events if event.type == "response.completed")
    assert isinstance(completed_event, ResponseCompletedEvent)
    completed_resp = completed_event.response
    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
    text_part = completed_resp.output[0].content[0]
    assert isinstance(text_part, ResponseOutputText)
    assert text_part.text == "Hi there"
    assert text_part.logprobs is not None
    assert [lp.token for lp in text_part.logprobs] == ["Hi", " there"]


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None: