
Commit df020d1

feat: preserve logprobs from chat completions API in ModelResponse (#2134)

Authored by JRMeyer and seratch
Co-authored-by: Kazuhiro Sera <seratch@openai.com>
1 parent: 9fcc68f

5 files changed: +279 −3 lines
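
What the change enables, as a hedged end-to-end sketch: after this commit, token logprobs returned by the Chat Completions API survive into the ModelResponse instead of being dropped. Import paths and the get_response call signature mirror the tests further down; requesting logprobs on the outgoing call (here via ModelSettings.extra_args) is an assumption, since the request side is not part of this diff.

import asyncio

from agents import ModelSettings
from agents.models.interface import ModelTracing
from agents.models.openai_provider import OpenAIProvider
from openai.types.responses import ResponseOutputMessage, ResponseOutputText


async def main() -> None:
    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
    resp = await model.get_response(
        system_instructions=None,
        input="Say hi",
        # Assumption: extra_args is one way to ask the API for logprobs;
        # the request-side flag itself is not introduced by this commit.
        model_settings=ModelSettings(extra_args={"logprobs": True, "top_logprobs": 2}),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        conversation_id=None,
        prompt=None,
    )
    # With this commit, text parts carry the converted logprobs.
    for item in resp.output:
        if isinstance(item, ResponseOutputMessage):
            for part in item.content:
                if isinstance(part, ResponseOutputText) and part.logprobs:
                    print([(lp.token, lp.logprob) for lp in part.logprobs])


asyncio.run(main())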

src/agents/models/chatcmpl_helpers.py
Lines changed: 57 additions & 0 deletions

@@ -3,6 +3,12 @@
 from contextvars import ContextVar

 from openai import AsyncOpenAI
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
+from openai.types.responses.response_text_delta_event import (
+    Logprob as DeltaLogprob,
+    LogprobTopLogprob as DeltaTopLogprob,
+)

 from ..model_settings import ModelSettings
 from ..version import __version__
@@ -41,3 +47,54 @@ def get_stream_options_param(
         )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
+
+    @classmethod
+    def convert_logprobs_for_output_text(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[Logprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[Logprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                Logprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    bytes=token_logprob.bytes or [],
+                    top_logprobs=[
+                        LogprobTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                            bytes=top_logprob.bytes or [],
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ],
+                )
+            )
+        return converted
+
+    @classmethod
+    def convert_logprobs_for_text_delta(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[DeltaLogprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[DeltaLogprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                DeltaLogprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    top_logprobs=[
+                        DeltaTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ]
+                    or None,
+                )
+            )
+        return converted
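
A minimal sketch of the two new helpers in isolation, assuming the openai package types imported above; the token values are illustrative only:

from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)

from agents.models.chatcmpl_helpers import ChatCmplHelpers

# A hand-built Chat Completions logprob entry (illustrative values).
chat_logprobs = [
    ChatCompletionTokenLogprob(
        token="Hi",
        logprob=-0.5,
        bytes=[72, 105],
        top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[72, 105])],
    )
]

# Responses-style objects destined for ResponseOutputText.logprobs ...
output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(chat_logprobs)
# ... and for ResponseTextDeltaEvent.logprobs.
delta_logprobs = ChatCmplHelpers.convert_logprobs_for_text_delta(chat_logprobs)

assert output_logprobs is not None and delta_logprobs is not None
print(output_logprobs[0].token, output_logprobs[0].logprob)  # Hi -0.5
print(delta_logprobs[0].top_logprobs[0].token)  # Hi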

src/agents/models/chatcmpl_stream_handler.py
Lines changed: 17 additions & 1 deletion

@@ -42,6 +42,7 @@
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from ..items import TResponseStreamEvent
+from .chatcmpl_helpers import ChatCmplHelpers
 from .fake_id import FAKE_RESPONSES_ID


@@ -105,6 +106,7 @@ async def handle_stream(
                 continue

             delta = chunk.choices[0].delta
+            choice_logprobs = chunk.choices[0].logprobs

             # Handle thinking blocks from Anthropic (for preserving signatures)
             if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
@@ -266,6 +268,15 @@ async def handle_stream(
                         type="response.content_part.added",
                         sequence_number=sequence_number.get_and_increment(),
                     )
+                delta_logprobs = (
+                    ChatCmplHelpers.convert_logprobs_for_text_delta(
+                        choice_logprobs.content if choice_logprobs else None
+                    )
+                    or []
+                )
+                output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    choice_logprobs.content if choice_logprobs else None
+                )
                 # Emit the delta for this segment of content
                 yield ResponseTextDeltaEvent(
                     content_index=state.text_content_index_and_output[0],
@@ -275,10 +286,15 @@ async def handle_stream(
                     is not None,  # fixed 0 -> 0 or 1
                     type="response.output_text.delta",
                     sequence_number=sequence_number.get_and_increment(),
-                    logprobs=[],
+                    logprobs=delta_logprobs,
                 )
                 # Accumulate the text into the response part
                 state.text_content_index_and_output[1].text += delta.content
+                if output_logprobs:
+                    existing_logprobs = state.text_content_index_and_output[1].logprobs or []
+                    state.text_content_index_and_output[1].logprobs = (
+                        existing_logprobs + output_logprobs
+                    )

                 # Handle refusals (model declines to answer)
                 # This is always set by the OpenAI API, but not by others e.g. LiteLLM
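
On the streaming side, the practical effect is that each response.output_text.delta event now carries that chunk's logprobs instead of an empty list, and the completed response's text part accumulates them. A hedged consumer sketch, reusing the stream_response call signature from the tests below; it again assumes logprobs were requested on the underlying Chat Completions call:

# Runs inside an async context; `model`, ModelSettings, and ModelTracing are
# obtained exactly as in the earlier end-to-end sketch.
collected: list[tuple[str, float]] = []
async for event in model.stream_response(
    system_instructions=None,
    input="Say hi",
    model_settings=ModelSettings(),
    tools=[],
    output_schema=None,
    handoffs=[],
    tracing=ModelTracing.DISABLED,
    previous_response_id=None,
    conversation_id=None,
    prompt=None,
):
    if event.type == "response.output_text.delta":
        # Previously always [], now populated when the chunk included logprobs.
        collected.extend((lp.token, lp.logprob) for lp in event.logprobs)

print(collected)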

src/agents/models/openai_chatcompletions.py
Lines changed: 28 additions & 1 deletion

@@ -9,7 +9,13 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
-from openai.types.responses import Response
+from openai.types.responses import (
+    Response,
+    ResponseOutputItem,
+    ResponseOutputMessage,
+    ResponseOutputText,
+)
+from openai.types.responses.response_output_text import Logprob
 from openai.types.responses.response_prompt_param import ResponsePromptParam

 from .. import _debug
@@ -119,12 +125,33 @@ async def get_response(

             items = Converter.message_to_output_items(message) if message is not None else []

+            logprob_models = None
+            if first_choice and first_choice.logprobs and first_choice.logprobs.content:
+                logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    first_choice.logprobs.content
+                )
+
+            if logprob_models:
+                self._attach_logprobs_to_output(items, logprob_models)
+
             return ModelResponse(
                 output=items,
                 usage=usage,
                 response_id=None,
             )

+    def _attach_logprobs_to_output(
+        self, output_items: list[ResponseOutputItem], logprobs: list[Logprob]
+    ) -> None:
+        for output_item in output_items:
+            if not isinstance(output_item, ResponseOutputMessage):
+                continue
+
+            for content in output_item.content:
+                if isinstance(content, ResponseOutputText):
+                    content.logprobs = logprobs
+                    return
+
     async def stream_response(
         self,
         system_instructions: str | None,

tests/test_openai_chatcompletions.py
Lines changed: 64 additions & 1 deletion

@@ -6,13 +6,17 @@
 import httpx
 import pytest
 from openai import AsyncOpenAI, omit
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (  # type: ignore[attr-defined]
     ChatCompletionMessageFunctionToolCall,
     Function,
 )
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
+)
 from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
@@ -98,6 +102,65 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.response_id is None


+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_get_response_attaches_logprobs(monkeypatch) -> None:
+    msg = ChatCompletionMessage(role="assistant", content="Hi!")
+    choice = Choice(
+        index=0,
+        finish_reason="stop",
+        message=msg,
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token="Hi",
+                    logprob=-0.5,
+                    bytes=[1],
+                    top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                ),
+                ChatCompletionTokenLogprob(
+                    token="!",
+                    logprob=-0.1,
+                    bytes=[2],
+                    top_logprobs=[TopLogprob(token="!", logprob=-0.1, bytes=[2])],
+                ),
+            ]
+        ),
+    )
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[choice],
+        usage=None,
+    )
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        return chat
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    resp: ModelResponse = await model.get_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    )
+    assert len(resp.output) == 1
+    assert isinstance(resp.output[0], ResponseOutputMessage)
+    text_part = resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", "!"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_get_response_with_refusal(monkeypatch) -> None:

tests/test_openai_chatcompletions_stream.py
Lines changed: 113 additions & 0 deletions

@@ -7,6 +7,11 @@
     ChoiceDelta,
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
+    ChoiceLogprobs,
+)
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
 )
 from openai.types.completion_usage import (
     CompletionTokensDetails,
@@ -15,6 +20,7 @@
 )
 from openai.types.responses import (
     Response,
+    ResponseCompletedEvent,
     ResponseFunctionToolCall,
     ResponseOutputMessage,
     ResponseOutputRefusal,
@@ -128,6 +134,113 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3


+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_includes_logprobs(monkeypatch) -> None:
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content="Hi"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token="Hi",
+                            logprob=-0.5,
+                            bytes=[1],
+                            top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                        )
+                    ]
+                ),
+            )
+        ],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content=" there"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token=" there",
+                            logprob=-0.25,
+                            bytes=[2],
+                            top_logprobs=[TopLogprob(token=" there", logprob=-0.25, bytes=[2])],
+                        )
+                    ]
+                ),
+            )
+        ],
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+        ),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    text_delta_events = [
+        event for event in output_events if event.type == "response.output_text.delta"
+    ]
+    assert len(text_delta_events) == 2
+    assert [lp.token for lp in text_delta_events[0].logprobs] == ["Hi"]
+    assert [lp.token for lp in text_delta_events[1].logprobs] == [" there"]
+
+    completed_event = next(event for event in output_events if event.type == "response.completed")
+    assert isinstance(completed_event, ResponseCompletedEvent)
+    completed_resp = completed_event.response
+    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
+    text_part = completed_resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.text == "Hi there"
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", " there"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None:
