
Commit 3ebf3a3

refactor: attach logprobs to ResponseOutputText for Responses API consistency
Instead of adding a separate logprobs field to ModelResponse, attach logprobs directly to ResponseOutputText content parts. This makes the chat completions API behavior consistent with the Responses API.

- Add conversion helpers in chatcmpl_helpers.py
- Update streaming to include logprobs in delta events and accumulate them on the output text part
- Attach logprobs to text parts in non-streaming responses
- Add tests for both streaming and non-streaming logprobs
1 parent 251c3ff commit 3ebf3a3
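
For callers, the practical effect is that token logprobs from a chat-completions-backed model now appear on the ResponseOutputText parts inside ModelResponse.output, the same place the Responses API puts them, rather than on a separate ModelResponse.logprobs field. A minimal sketch of the new access pattern (this helper is illustrative, written for this page rather than taken from the repo):

from openai.types.responses import ResponseOutputMessage, ResponseOutputText


def iter_text_logprobs(output_items):
    """Yield (token, logprob) pairs from the text parts of Responses-API-shaped output."""
    for item in output_items:
        if not isinstance(item, ResponseOutputMessage):
            continue
        for part in item.content:
            if isinstance(part, ResponseOutputText) and part.logprobs:
                for lp in part.logprobs:
                    yield lp.token, lp.logprob


# e.g. list(iter_text_logprobs(resp.output)) -> [("Hi", -0.5), ("!", -0.1)]
# where resp is a ModelResponse like the one built in the non-streaming test below.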

6 files changed: +276 additions, −13 deletions


src/agents/items.py

Lines changed: 0 additions & 7 deletions
@@ -356,13 +356,6 @@ class ModelResponse:
     be passed to `Runner.run`.
     """
 
-    logprobs: list[Any] | None = None
-    """Token log probabilities from the model response.
-    Only populated when using the chat completions API with `top_logprobs` set in ModelSettings.
-    Each element corresponds to a token and contains the token string, log probability, and
-    optionally the top alternative tokens with their log probabilities.
-    """
-
     def to_input_items(self) -> list[TResponseInputItem]:
         """Convert the output into a list of input items suitable for passing to the model."""
         # We happen to know that the shape of the Pydantic output items are the same as the

src/agents/models/chatcmpl_helpers.py

Lines changed: 57 additions & 0 deletions
@@ -3,6 +3,12 @@
 from contextvars import ContextVar
 
 from openai import AsyncOpenAI
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
+from openai.types.responses.response_text_delta_event import (
+    Logprob as DeltaLogprob,
+    LogprobTopLogprob as DeltaTopLogprob,
+)
 
 from ..model_settings import ModelSettings
 from ..version import __version__
@@ -41,3 +47,54 @@ def get_stream_options_param(
         )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
+
+    @classmethod
+    def convert_logprobs_for_output_text(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[Logprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[Logprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                Logprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    bytes=token_logprob.bytes or [],
+                    top_logprobs=[
+                        LogprobTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                            bytes=top_logprob.bytes or [],
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ],
+                )
+            )
+        return converted
+
+    @classmethod
+    def convert_logprobs_for_text_delta(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[DeltaLogprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[DeltaLogprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                DeltaLogprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    top_logprobs=[
+                        DeltaTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ]
+                    or None,
+                )
+            )
+        return converted
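
For orientation, a small usage sketch of the two helpers added above (the input values are invented; the real call sites are the stream handler and openai_chatcompletions.py shown below):

from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)

from agents.models.chatcmpl_helpers import ChatCmplHelpers

raw = [
    ChatCompletionTokenLogprob(
        token="Hi",
        logprob=-0.5,
        bytes=[72, 105],
        top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[72, 105])],
    )
]

# Responses-API shape, used for ResponseOutputText.logprobs (keeps bytes).
output_form = ChatCmplHelpers.convert_logprobs_for_output_text(raw)
# Delta-event shape, used for response.output_text.delta events (no bytes field).
delta_form = ChatCmplHelpers.convert_logprobs_for_text_delta(raw)

print(output_form[0].token, delta_form[0].logprob)  # Hi -0.5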

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 14 additions & 1 deletion
@@ -42,6 +42,7 @@
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from ..items import TResponseStreamEvent
+from .chatcmpl_helpers import ChatCmplHelpers
 from .fake_id import FAKE_RESPONSES_ID
 
 
@@ -103,6 +104,7 @@ async def handle_stream(
                 continue
 
             delta = chunk.choices[0].delta
+            choice_logprobs = chunk.choices[0].logprobs
 
             # Handle thinking blocks from Anthropic (for preserving signatures)
             if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
@@ -264,6 +266,12 @@ async def handle_stream(
                     type="response.content_part.added",
                     sequence_number=sequence_number.get_and_increment(),
                 )
+                delta_logprobs = ChatCmplHelpers.convert_logprobs_for_text_delta(
+                    choice_logprobs.content if choice_logprobs else None
+                ) or []
+                output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    choice_logprobs.content if choice_logprobs else None
+                )
                 # Emit the delta for this segment of content
                 yield ResponseTextDeltaEvent(
                     content_index=state.text_content_index_and_output[0],
@@ -273,10 +281,15 @@ async def handle_stream(
                     is not None,  # fixed 0 -> 0 or 1
                     type="response.output_text.delta",
                     sequence_number=sequence_number.get_and_increment(),
-                    logprobs=[],
+                    logprobs=delta_logprobs,
                 )
                 # Accumulate the text into the response part
                 state.text_content_index_and_output[1].text += delta.content
+                if output_logprobs:
+                    existing_logprobs = state.text_content_index_and_output[1].logprobs or []
+                    state.text_content_index_and_output[1].logprobs = (
+                        existing_logprobs + output_logprobs
+                    )
 
                 # Handle refusals (model declines to answer)
                 # This is always set by the OpenAI API, but not by others e.g. LiteLLM
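
The accumulation step concatenates each chunk's converted logprobs onto the text part being built, so the finished ResponseOutputText carries logprobs for the whole message. A standalone toy version of that merge, outside the handler and with invented chunk data:

from openai.types.responses import ResponseOutputText
from openai.types.responses.response_output_text import Logprob

# Toy stand-ins for two streamed chunks: (text delta, converted logprobs).
chunks = [
    ("Hel", [Logprob(token="Hel", logprob=-0.2, bytes=[], top_logprobs=[])]),
    ("lo", [Logprob(token="lo", logprob=-0.1, bytes=[], top_logprobs=[])]),
]

part = ResponseOutputText(text="", type="output_text", annotations=[])
for text_delta, new_logprobs in chunks:
    part.text += text_delta
    # Same "existing-or-empty plus new" merge the handler applies per delta.
    part.logprobs = (part.logprobs or []) + new_logprobs

print(part.text, [lp.token for lp in part.logprobs])  # Hello ['Hel', 'lo']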

src/agents/models/openai_chatcompletions.py

Lines changed: 26 additions & 4 deletions
@@ -9,7 +9,13 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
-from openai.types.responses import Response
+from openai.types.responses import (
+    Response,
+    ResponseOutputItem,
+    ResponseOutputMessage,
+    ResponseOutputText,
+)
+from openai.types.responses.response_output_text import Logprob
 from openai.types.responses.response_prompt_param import ResponsePromptParam
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
@@ -129,17 +135,33 @@ async def get_response(
 
         items = Converter.message_to_output_items(message) if message is not None else []
 
-        logprobs_data = None
+        logprob_models = None
         if first_choice and first_choice.logprobs and first_choice.logprobs.content:
-            logprobs_data = [lp.model_dump() for lp in first_choice.logprobs.content]
+            logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text(
+                first_choice.logprobs.content
+            )
+
+        if logprob_models:
+            self._attach_logprobs_to_output(items, logprob_models)
 
         return ModelResponse(
             output=items,
             usage=usage,
             response_id=None,
-            logprobs=logprobs_data,
         )
 
+    def _attach_logprobs_to_output(
+        self, output_items: list[ResponseOutputItem], logprobs: list[Logprob]
+    ) -> None:
+        for output_item in output_items:
+            if not isinstance(output_item, ResponseOutputMessage):
+                continue
+
+            for content in output_item.content:
+                if isinstance(content, ResponseOutputText):
+                    content.logprobs = logprobs
+                    return
+
     async def stream_response(
         self,
         system_instructions: str | None,
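
Worth noting: _attach_logprobs_to_output assigns the converted logprobs to the first ResponseOutputText it finds and then returns, which matches the single text part produced from one chat completion choice. A toy illustration of that walk, using hand-built objects rather than the model class:

from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from openai.types.responses.response_output_text import Logprob

message = ResponseOutputMessage(
    id="msg_1",
    type="message",
    role="assistant",
    status="completed",
    content=[ResponseOutputText(text="Hi!", type="output_text", annotations=[])],
)
logprobs = [Logprob(token="Hi", logprob=-0.5, bytes=[72, 105], top_logprobs=[])]

# Mirror of the attach step: the first text part gets the logprobs, then we stop.
for part in message.content:
    if isinstance(part, ResponseOutputText):
        part.logprobs = logprobs
        break

assert message.content[0].logprobs is not None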

tests/test_openai_chatcompletions.py

Lines changed: 64 additions & 1 deletion
@@ -6,13 +6,17 @@
 import httpx
 import pytest
 from openai import AsyncOpenAI, omit
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (  # type: ignore[attr-defined]
     ChatCompletionMessageFunctionToolCall,
     Function,
 )
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
+)
 from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
@@ -98,6 +102,65 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.response_id is None
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_get_response_attaches_logprobs(monkeypatch) -> None:
+    msg = ChatCompletionMessage(role="assistant", content="Hi!")
+    choice = Choice(
+        index=0,
+        finish_reason="stop",
+        message=msg,
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token="Hi",
+                    logprob=-0.5,
+                    bytes=[1],
+                    top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                ),
+                ChatCompletionTokenLogprob(
+                    token="!",
+                    logprob=-0.1,
+                    bytes=[2],
+                    top_logprobs=[TopLogprob(token="!", logprob=-0.1, bytes=[2])],
+                ),
+            ]
+        ),
+    )
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[choice],
+        usage=None,
+    )
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        return chat
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    resp: ModelResponse = await model.get_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    )
+    assert len(resp.output) == 1
+    assert isinstance(resp.output[0], ResponseOutputMessage)
+    text_part = resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", "!"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_get_response_with_refusal(monkeypatch) -> None:
