@@ -7,6 +7,11 @@
     ChoiceDelta,
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
+    ChoiceLogprobs,
+)
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
 )
 from openai.types.completion_usage import (
     CompletionTokensDetails,
@@ -15,6 +20,7 @@
 )
 from openai.types.responses import (
     Response,
+    ResponseCompletedEvent,
     ResponseFunctionToolCall,
     ResponseOutputMessage,
     ResponseOutputRefusal,
@@ -128,6 +134,113 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_includes_logprobs(monkeypatch) -> None:
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content="Hi"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token="Hi",
+                            logprob=-0.5,
+                            bytes=[1],
+                            top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                        )
+                    ]
+                ),
+            )
+        ],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[
+            Choice(
+                index=0,
+                delta=ChoiceDelta(content=" there"),
+                logprobs=ChoiceLogprobs(
+                    content=[
+                        ChatCompletionTokenLogprob(
+                            token=" there",
+                            logprob=-0.25,
+                            bytes=[2],
+                            top_logprobs=[TopLogprob(token=" there", logprob=-0.25, bytes=[2])],
+                        )
+                    ]
+                ),
+            )
+        ],
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+        ),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    text_delta_events = [
+        event for event in output_events if event.type == "response.output_text.delta"
+    ]
+    assert len(text_delta_events) == 2
+    assert [lp.token for lp in text_delta_events[0].logprobs] == ["Hi"]
+    assert [lp.token for lp in text_delta_events[1].logprobs] == [" there"]
+
+    completed_event = next(event for event in output_events if event.type == "response.completed")
+    assert isinstance(completed_event, ResponseCompletedEvent)
+    completed_resp = completed_event.response
+    assert isinstance(completed_resp.output[0], ResponseOutputMessage)
+    text_part = completed_resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.text == "Hi there"
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", " there"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None: