diff --git a/src/agents/models/chatcmpl_helpers.py b/src/agents/models/chatcmpl_helpers.py
index 335e3f521..01ced356b 100644
--- a/src/agents/models/chatcmpl_helpers.py
+++ b/src/agents/models/chatcmpl_helpers.py
@@ -3,6 +3,12 @@
 from contextvars import ContextVar
 
 from openai import AsyncOpenAI
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
+from openai.types.responses.response_text_delta_event import (
+    Logprob as DeltaLogprob,
+    LogprobTopLogprob as DeltaTopLogprob,
+)
 
 from ..model_settings import ModelSettings
 from ..version import __version__
@@ -41,3 +47,54 @@ def get_stream_options_param(
         )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
+
+    @classmethod
+    def convert_logprobs_for_output_text(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[Logprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[Logprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                Logprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    bytes=token_logprob.bytes or [],
+                    top_logprobs=[
+                        LogprobTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                            bytes=top_logprob.bytes or [],
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ],
+                )
+            )
+        return converted
+
+    @classmethod
+    def convert_logprobs_for_text_delta(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[DeltaLogprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[DeltaLogprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                DeltaLogprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    top_logprobs=[
+                        DeltaTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ]
+                    or None,
+                )
+            )
+        return converted
diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py
index f1c504977..b018b38a9 100644
--- a/src/agents/models/chatcmpl_stream_handler.py
+++ b/src/agents/models/chatcmpl_stream_handler.py
@@ -42,6 +42,7 @@
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from ..items import TResponseStreamEvent
+from .chatcmpl_helpers import ChatCmplHelpers
 from .fake_id import FAKE_RESPONSES_ID
 
 
@@ -105,6 +106,7 @@ async def handle_stream(
                 continue
 
             delta = chunk.choices[0].delta
+            choice_logprobs = chunk.choices[0].logprobs
 
             # Handle thinking blocks from Anthropic (for preserving signatures)
             if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
@@ -266,6 +268,15 @@ async def handle_stream(
                         type="response.content_part.added",
                         sequence_number=sequence_number.get_and_increment(),
                     )
+                delta_logprobs = (
+                    ChatCmplHelpers.convert_logprobs_for_text_delta(
+                        choice_logprobs.content if choice_logprobs else None
+                    )
+                    or []
+                )
+                output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    choice_logprobs.content if choice_logprobs else None
+                )
                 # Emit the delta for this segment of content
                 yield ResponseTextDeltaEvent(
                     content_index=state.text_content_index_and_output[0],
@@ -275,10 +286,15 @@ async def handle_stream(
                     is not None,  # fixed 0 -> 0 or 1
                     type="response.output_text.delta",
                     sequence_number=sequence_number.get_and_increment(),
-                    logprobs=[],
+                    logprobs=delta_logprobs,
                 )
                 # Accumulate the text into the response part
                 state.text_content_index_and_output[1].text += delta.content
+                if output_logprobs:
+                    existing_logprobs = state.text_content_index_and_output[1].logprobs or []
+                    state.text_content_index_and_output[1].logprobs = (
+                        existing_logprobs + output_logprobs
+                    )
 
             # Handle refusals (model declines to answer)
             # This is always set by the OpenAI API, but not by others e.g. LiteLLM
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 76f36d86b..ea8ba98cd 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -9,7 +9,13 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
-from openai.types.responses import Response
+from openai.types.responses import (
+    Response,
+    ResponseOutputItem,
+    ResponseOutputMessage,
+    ResponseOutputText,
+)
+from openai.types.responses.response_output_text import Logprob
 from openai.types.responses.response_prompt_param import ResponsePromptParam
 
 from .. import _debug
@@ -119,12 +125,33 @@ async def get_response(
 
             items = Converter.message_to_output_items(message) if message is not None else []
 
+            logprob_models = None
+            if first_choice and first_choice.logprobs and first_choice.logprobs.content:
+                logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    first_choice.logprobs.content
+                )
+
+            if logprob_models:
+                self._attach_logprobs_to_output(items, logprob_models)
+
             return ModelResponse(
                 output=items,
                 usage=usage,
                 response_id=None,
             )
 
+    def _attach_logprobs_to_output(
+        self, output_items: list[ResponseOutputItem], logprobs: list[Logprob]
+    ) -> None:
+        for output_item in output_items:
+            if not isinstance(output_item, ResponseOutputMessage):
+                continue
+
+            for content in output_item.content:
+                if isinstance(content, ResponseOutputText):
+                    content.logprobs = logprobs
+                    return
+
     async def stream_response(
         self,
         system_instructions: str | None,
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index 3a0f75364..7e88242c7 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -6,13 +6,17 @@
 import httpx
 import pytest
 from openai import AsyncOpenAI, omit
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (  # type: ignore[attr-defined]
     ChatCompletionMessageFunctionToolCall,
     Function,
 )
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
+)
 from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
@@ -98,6 +102,65 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.response_id is None
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_get_response_attaches_logprobs(monkeypatch) -> None:
+    msg = ChatCompletionMessage(role="assistant", content="Hi!")
+    choice = Choice(
+        index=0,
+        finish_reason="stop",
+        message=msg,
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token="Hi",
+                    logprob=-0.5,
+                    bytes=[1],
+                    top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                ),
+                ChatCompletionTokenLogprob(
+                    token="!",
+                    logprob=-0.1,
+                    bytes=[2],
+                    top_logprobs=[TopLogprob(token="!", logprob=-0.1, bytes=[2])],
+                ),
+            ]
+        ),
+    )
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[choice],
+        usage=None,
+    )
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        return chat
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    resp: ModelResponse = await model.get_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    )
+    assert len(resp.output) == 1
+    assert isinstance(resp.output[0], ResponseOutputMessage)
+    text_part = resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", "!"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_get_response_with_refusal(monkeypatch) -> None:
diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py
index 947816f01..847aef8da 100644
--- a/tests/test_openai_chatcompletions_stream.py
+++ b/tests/test_openai_chatcompletions_stream.py
@@ -7,6 +7,11 @@
     ChoiceDelta,
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
+    ChoiceLogprobs,
+)
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
 )
 from openai.types.completion_usage import (
     CompletionTokensDetails,
@@ -15,6 +20,7 @@
 )
 from openai.types.responses import (
     Response,
+    ResponseCompletedEvent,
     ResponseFunctionToolCall,
     ResponseOutputMessage,
     ResponseOutputRefusal,
@@ -128,6 +134,113 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_includes_logprobs(monkeypatch) -> None:
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content="Hi"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token="Hi",
+                            logprob=-0.5,
+                            bytes=[1],
+                            top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                        )
+                    ]
+                ),
+            )
+        ],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content=" there"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token=" there",
+                            logprob=-0.25,
+                            bytes=[2],
+                            top_logprobs=[TopLogprob(token=" there", logprob=-0.25, bytes=[2])],
+                        )
+                    ]
+                ),
+            )
+        ],
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+        ),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    text_delta_events = [
+        event for event in output_events if event.type == "response.output_text.delta"
+    ]
+    assert len(text_delta_events) == 2
+    assert [lp.token for lp in text_delta_events[0].logprobs] == ["Hi"]
+    assert [lp.token for lp in text_delta_events[1].logprobs] == [" there"]
+
+    completed_event = next(event for event in output_events if event.type == "response.completed")
+    assert isinstance(completed_event, ResponseCompletedEvent)
+    completed_resp = completed_event.response
+    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
+    text_part = completed_resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.text == "Hi there"
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", " there"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None:
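
Usage note (not part of the patch): a minimal sketch of how a caller might read the logprobs this change surfaces, mirroring the streaming test above. It assumes the underlying Chat Completions request was made with logprobs enabled (how to enable that is outside this diff), and the import paths for ModelSettings, ModelTracing, and OpenAIProvider are assumptions based on the SDK layout referenced here, not shown in the diff itself.

    # Hypothetical consumer of the logprobs attached by this patch; a sketch, not a
    # definitive implementation. Assumes an API key is configured and the provider
    # actually returns logprobs for the request.
    import asyncio

    from agents.model_settings import ModelSettings
    from agents.models.interface import ModelTracing
    from agents.models.openai_provider import OpenAIProvider


    async def main() -> None:
        # Same model lookup as in the tests; the model name is illustrative.
        model = OpenAIProvider(use_responses=False).get_model("gpt-4")
        async for event in model.stream_response(
            system_instructions=None,
            input="Say hi",
            model_settings=ModelSettings(),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
            conversation_id=None,
            prompt=None,
        ):
            if event.type == "response.output_text.delta":
                # Per-chunk logprobs converted by ChatCmplHelpers.convert_logprobs_for_text_delta
                for lp in event.logprobs:
                    print(lp.token, lp.logprob)
            elif event.type == "response.completed":
                # Accumulated logprobs attached to the final ResponseOutputText part
                message = event.response.output[0]
                print(message.content[0].logprobs)


    asyncio.run(main())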