diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index d49733e8b..b70bf19ed 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -12,7 +12,11 @@ StatefulMCPServerProvider, StatelessMCPServerProvider, ) -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + ModelSummaryProvider, + StreamingOptions, +) from temporalio.contrib.openai_agents._temporal_openai_agents import ( OpenAIAgentsPlugin, OpenAIPayloadConverter, @@ -27,10 +31,12 @@ __all__ = [ "AgentsWorkflowError", "ModelActivityParameters", + "ModelSummaryProvider", "OpenAIAgentsPlugin", "OpenAIPayloadConverter", "StatelessMCPServerProvider", "StatefulMCPServerProvider", + "StreamingOptions", "testing", "workflow", ] diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index f03458c32..743902a22 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -3,11 +3,11 @@ Implements mapping of OpenAI datastructures to Pydantic friendly types. """ +import asyncio import enum -import json from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, NoReturn, Union from agents import ( AgentOutputSchemaBase, @@ -28,16 +28,19 @@ UserError, WebSearchTool, ) +from agents.items import TResponseStreamEvent from openai import ( APIStatusError, AsyncOpenAI, ) +from openai.types.responses import ResponseErrorEvent from openai.types.responses.tool_param import Mcp from pydantic_core import to_json from typing_extensions import Required, TypedDict from temporalio import activity, workflow from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater +from temporalio.contrib.openai_agents._model_parameters import StreamingOptions from temporalio.exceptions import ApplicationError @@ -148,16 +151,25 @@ class ActivityModelInput(TypedDict, total=False): prompt: Any | None +class ActivityModelInputWithSignal(ActivityModelInput): + """Input for the stream_model activity.""" + + signal: str + + class ModelActivity: """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled. Disabling retries in your model of choice is recommended to allow activity retries to define the retry model. """ - def __init__(self, model_provider: ModelProvider | None = None): + def __init__( + self, model_provider: ModelProvider | None, streaming_options: StreamingOptions + ): """Initialize the activity with a model provider.""" self._model_provider = model_provider or OpenAIProvider( openai_client=AsyncOpenAI(max_retries=0) ) + self._streaming_options = streaming_options @activity.defn @_auto_heartbeater @@ -165,52 +177,8 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons """Activity that invokes a model with the given input.""" model = self._model_provider.get_model(input.get("model_name")) - async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: - return "" - - async def empty_on_invoke_handoff( - ctx: RunContextWrapper[Any], input: str - ) -> Any: - return None - - def make_tool(tool: ToolInput) -> Tool: - if isinstance( - tool, - ( - FileSearchTool, - WebSearchTool, - ImageGenerationTool, - CodeInterpreterTool, - ), - ): - return tool - elif isinstance(tool, HostedMCPToolInput): - return HostedMCPTool( - tool_config=tool.tool_config, - ) - elif isinstance(tool, FunctionToolInput): - return FunctionTool( - name=tool.name, - description=tool.description, - params_json_schema=tool.params_json_schema, - on_invoke_tool=empty_on_invoke_tool, - strict_json_schema=tool.strict_json_schema, - ) - else: - raise UserError(f"Unknown tool type: {tool.name}") - - tools = [make_tool(x) for x in input.get("tools", [])] - handoffs: list[Handoff[Any, Any]] = [ - Handoff( - tool_name=x.tool_name, - tool_description=x.tool_description, - input_json_schema=x.input_json_schema, - agent_name=x.agent_name, - strict_json_schema=x.strict_json_schema, - on_invoke_handoff=empty_on_invoke_handoff, - ) - for x in input.get("handoffs", []) - ] + tools = _make_tools(input) + handoffs = _make_handoffs(input) try: return await model.get_response( @@ -226,40 +194,209 @@ def make_tool(tool: ToolInput) -> Tool: prompt=input.get("prompt"), ) except APIStatusError as e: - # Listen to server hints - retry_after = None - retry_after_ms_header = e.response.headers.get("retry-after-ms") - if retry_after_ms_header is not None: - retry_after = timedelta(milliseconds=float(retry_after_ms_header)) - - if retry_after is None: - retry_after_header = e.response.headers.get("retry-after") - if retry_after_header is not None: - retry_after = timedelta(seconds=float(retry_after_header)) - - should_retry_header = e.response.headers.get("x-should-retry") - if should_retry_header == "true": - raise e - if should_retry_header == "false": - raise ApplicationError( - "Non retryable OpenAI error", - non_retryable=True, - next_retry_delay=retry_after, - ) from e - - # Specifically retryable status codes - if ( - e.response.status_code in [408, 409, 429] - or e.response.status_code >= 500 - ): - raise ApplicationError( - f"Retryable OpenAI status code: {e.response.status_code}", - non_retryable=False, - next_retry_delay=retry_after, - ) from e - - raise ApplicationError( - f"Non retryable OpenAI status code: {e.response.status_code}", - non_retryable=True, - next_retry_delay=retry_after, - ) from e + _handle_error(e) + + @activity.defn + async def stream_model(self, input: ActivityModelInputWithSignal) -> None: + """Activity that streams a model with the given input.""" + model = self._model_provider.get_model(input.get("model_name")) + + tools = _make_tools(input) + handoffs = _make_handoffs(input) + + handle = activity.client().get_workflow_handle( + workflow_id=activity.info().workflow_id + ) + + batch: list[TResponseStreamEvent] = [] + + # If the activity previously failed, notify the stream + if activity.info().attempt > 1: + batch.append( + ResponseErrorEvent( + message="Activity Failed", + sequence_number=0, + type="error", + ) + ) + try: + events = model.stream_response( + system_instructions=input.get("system_instructions"), + input=input["input"], + model_settings=input["model_settings"], + tools=tools, + output_schema=input.get("output_schema"), + handoffs=handoffs, + tracing=ModelTracing(input["tracing"]), + previous_response_id=input.get("previous_response_id"), + conversation_id=input.get("conversation_id"), + prompt=input.get("prompt"), + ) + + async def send_batch(): + if batch: + await handle.signal(input["signal"], batch) + batch.clear() + + async def send_batches(): + while True: + await asyncio.sleep( + self._streaming_options.signal_batch_latency_seconds + ) + await send_batch() + + async def read_events(): + async for event in events: + event.model_rebuild() + batch.append(event) + if self._streaming_options.callback is not None: + await self._streaming_options.callback( + input["model_settings"], event + ) + + try: + completed, pending = await asyncio.wait( + [ + asyncio.create_task(read_events()), + asyncio.create_task(send_batches()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + for task in completed: + await task + + except StopAsyncIteration as e: + pass + # Send any remaining events in the batch + if batch: + await send_batch() + + except APIStatusError as e: + _handle_error(e) + + @activity.defn + async def batch_stream_model( + self, input: ActivityModelInput + ) -> list[TResponseStreamEvent]: + """Activity that streams a model with the given input.""" + model = self._model_provider.get_model(input.get("model_name")) + + tools = _make_tools(input) + handoffs = _make_handoffs(input) + + events = model.stream_response( + system_instructions=input.get("system_instructions"), + input=input["input"], + model_settings=input["model_settings"], + tools=tools, + output_schema=input.get("output_schema"), + handoffs=handoffs, + tracing=ModelTracing(input["tracing"]), + previous_response_id=input.get("previous_response_id"), + conversation_id=input.get("conversation_id"), + prompt=input.get("prompt"), + ) + result = [] + async for event in events: + event.model_rebuild() + result.append(event) + if self._streaming_options.callback is not None: + await self._streaming_options.callback(input["model_settings"], event) + + return result + + +async def _empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: + return "" + + +async def _empty_on_invoke_handoff(ctx: RunContextWrapper[Any], input: str) -> Any: + return None + + +def _make_tool(tool: ToolInput) -> Tool: + if isinstance( + tool, + ( + FileSearchTool, + WebSearchTool, + ImageGenerationTool, + CodeInterpreterTool, + ), + ): + return tool + elif isinstance(tool, HostedMCPToolInput): + return HostedMCPTool( + tool_config=tool.tool_config, + ) + elif isinstance(tool, FunctionToolInput): + return FunctionTool( + name=tool.name, + description=tool.description, + params_json_schema=tool.params_json_schema, + on_invoke_tool=_empty_on_invoke_tool, + strict_json_schema=tool.strict_json_schema, + ) + else: + raise UserError(f"Unknown tool type: {tool.name}") + + +def _make_tools(input: ActivityModelInput) -> list[Tool]: + return [_make_tool(x) for x in input.get("tools", [])] + + +def _make_handoffs(input: ActivityModelInput) -> list[Handoff[Any, Any]]: + return [ + Handoff( + tool_name=x.tool_name, + tool_description=x.tool_description, + input_json_schema=x.input_json_schema, + agent_name=x.agent_name, + strict_json_schema=x.strict_json_schema, + on_invoke_handoff=_empty_on_invoke_handoff, + ) + for x in input.get("handoffs", []) + ] + + +def _handle_error(e: APIStatusError) -> NoReturn: + # Listen to server hints + retry_after = None + retry_after_ms_header = e.response.headers.get("retry-after-ms") + if retry_after_ms_header is not None: + retry_after = timedelta(milliseconds=float(retry_after_ms_header)) + + if retry_after is None: + retry_after_header = e.response.headers.get("retry-after") + if retry_after_header is not None: + retry_after = timedelta(seconds=float(retry_after_header)) + + should_retry_header = e.response.headers.get("x-should-retry") + if should_retry_header == "true": + raise e + if should_retry_header == "false": + raise ApplicationError( + "Non retryable OpenAI error", + non_retryable=True, + next_retry_delay=retry_after, + ) from e + + # Specifically retryable status codes + if e.response.status_code in [408, 409, 429] or e.response.status_code >= 500: + raise ApplicationError( + f"Retryable OpenAI status code: {e.response.status_code}", + non_retryable=False, + next_retry_delay=retry_after, + ) from e + + raise ApplicationError( + f"Non retryable OpenAI status code: {e.response.status_code}", + non_retryable=True, + next_retry_delay=retry_after, + ) from e diff --git a/temporalio/contrib/openai_agents/_model_parameters.py b/temporalio/contrib/openai_agents/_model_parameters.py index 3cab91c27..376f4d2db 100644 --- a/temporalio/contrib/openai_agents/_model_parameters.py +++ b/temporalio/contrib/openai_agents/_model_parameters.py @@ -4,9 +4,10 @@ from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Awaitable, Optional, Union -from agents import Agent, TResponseInputItem +from agents import Agent, ModelSettings, TResponseInputItem +from agents.items import TResponseStreamEvent from temporalio.common import Priority, RetryPolicy from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -69,3 +70,22 @@ class ModelActivityParameters: use_local_activity: bool = False """Whether to use a local activity. If changed during a workflow execution, that would break determinism.""" + + +@dataclass +class StreamingOptions: + """Options applicable for use of run_streamed""" + + callback: ( + Callable[[ModelSettings, TResponseStreamEvent], Awaitable[None]] | None + ) = None + """A callback function that will be invoked inside the activity on every stream event which occurs. + ModelSettings are provided so that the callback can distinguish what to do based on extra_args if desired.""" + + use_signals: bool = False + """If true, the activity will use signals to provide events to the workflow as they occur. Ensure that the workflow + appropriately handles those signals during replay. If false, all the stream events will be delivered when the activity completes.""" + + signal_batch_latency_seconds: float = 1.0 + """Batch latency for sending signals. Lower values will result in lower stream event latency but higher + signal volume, and therefore cost.""" diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index a8065d207..7bac44ddb 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -15,10 +15,13 @@ Tool, TResponseInputItem, ) -from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner +from agents.run import DEFAULT_AGENT_RUNNER, AgentRunner from temporalio import workflow -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError @@ -26,6 +29,7 @@ # Recursively replace models in all agents def _convert_agent( model_params: ModelActivityParameters, + streaming_options: StreamingOptions, agent: Agent[Any], seen: dict[int, Agent] | None, ) -> Agent[Any]: @@ -49,13 +53,17 @@ def _convert_agent( new_handoffs: list[Agent | Handoff] = [] for handoff in agent.handoffs: if isinstance(handoff, Agent): - new_handoffs.append(_convert_agent(model_params, handoff, seen)) + new_handoffs.append( + _convert_agent(model_params, streaming_options, handoff, seen) + ) elif isinstance(handoff, Handoff): original_invoke = handoff.on_invoke_handoff async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: handoff_agent = await original_invoke(context, args) - return _convert_agent(model_params, handoff_agent, seen) + return _convert_agent( + model_params, streaming_options, handoff_agent, seen + ) new_handoffs.append( dataclasses.replace(handoff, on_invoke_handoff=on_invoke) @@ -67,6 +75,7 @@ async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: model_name=name, model_params=model_params, agent=agent, + streaming_options=streaming_options, ) new_agent.handoffs = new_handoffs return new_agent @@ -79,10 +88,13 @@ class TemporalOpenAIRunner(AgentRunner): """ - def __init__(self, model_params: ModelActivityParameters) -> None: + def __init__( + self, model_params: ModelActivityParameters, streaming_options: StreamingOptions + ) -> None: """Initialize the Temporal OpenAI Runner.""" self._runner = DEFAULT_AGENT_RUNNER or AgentRunner() self.model_params = model_params + self.streaming_options = streaming_options async def run( self, @@ -98,66 +110,17 @@ async def run( **kwargs, ) - tool_types = typing.get_args(Tool) - for t in starting_agent.tools: - if not isinstance(t, tool_types): - raise ValueError( - "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." - ) - - if starting_agent.mcp_servers: - from temporalio.contrib.openai_agents._mcp import ( - _StatefulMCPServerReference, - _StatelessMCPServerReference, - ) - - for s in starting_agent.mcp_servers: - if not isinstance( - s, - ( - _StatelessMCPServerReference, - _StatefulMCPServerReference, - ), - ): - raise ValueError( - f"Unknown mcp_server type {type(s)} may not work durably." - ) - - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - session = kwargs.get("session") - - if isinstance(session, SQLiteSession): - raise ValueError("Temporal workflows don't support SQLite sessions.") + _check_preconditions(starting_agent, **kwargs) - if run_config is None: - run_config = RunConfig() - - if run_config.model: - if not isinstance(run_config.model, str): - raise ValueError( - "Temporal workflows require a model name to be a string in the run config." - ) - run_config = dataclasses.replace( - run_config, - model=_TemporalModelStub( - run_config.model, model_params=self.model_params, agent=None - ), - ) + kwargs["run_config"] = self._process_run_config(kwargs.get("run_config")) try: return await self._runner.run( - starting_agent=_convert_agent(self.model_params, starting_agent, None), + starting_agent=_convert_agent( + self.model_params, self.streaming_options, starting_agent, None + ), input=input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - session=session, + **kwargs, ) except AgentsException as e: # In order for workflow failures to properly fail the workflow, we need to rewrap them in @@ -199,7 +162,50 @@ def run_streamed( input, **kwargs, ) - raise RuntimeError("Temporal workflows do not support streaming.") + + _check_preconditions(starting_agent, **kwargs) + + kwargs["run_config"] = self._process_run_config(kwargs.get("run_config")) + + try: + return self._runner.run_streamed( + starting_agent=_convert_agent( + self.model_params, self.streaming_options, starting_agent, None + ), + input=input, + **kwargs, + ) + except AgentsException as e: + # In order for workflow failures to properly fail the workflow, we need to rewrap them in + # a Temporal error + if e.__cause__ and workflow.is_failure_exception(e.__cause__): + reraise = AgentsWorkflowError( + f"Workflow failure exception in Agents Framework: {e}" + ) + reraise.__traceback__ = e.__traceback__ + raise reraise from e.__cause__ + else: + raise e + + def _process_run_config(self, run_config: RunConfig | None) -> RunConfig: + if run_config is None: + run_config = RunConfig() + + if run_config.model: + if not isinstance(run_config.model, str): + raise ValueError( + "Temporal workflows require a model name to be a string in the run config." + ) + run_config = dataclasses.replace( + run_config, + model=_TemporalModelStub( + run_config.model, + model_params=self.model_params, + streaming_options=self.streaming_options, + agent=None, + ), + ) + return run_config def _model_name(agent: Agent[Any]) -> str | None: @@ -209,3 +215,34 @@ def _model_name(agent: Agent[Any]) -> str | None: "Temporal workflows require a model name to be a string in the agent." ) return name + + +def _check_preconditions(starting_agent: Agent[TContext], **kwargs: Any) -> None: + tool_types = typing.get_args(Tool) + for t in starting_agent.tools: + if not isinstance(t, tool_types): + raise ValueError( + "Provided tool is not a tool type. If using an activity, make sure to wrap it with openai_agents.workflow.activity_as_tool." + ) + + if starting_agent.mcp_servers: + from temporalio.contrib.openai_agents._mcp import ( + _StatefulMCPServerReference, + _StatelessMCPServerReference, + ) + + for s in starting_agent.mcp_servers: + if not isinstance( + s, + ( + _StatelessMCPServerReference, + _StatefulMCPServerReference, + ), + ): + raise ValueError( + f"Unknown mcp_server type {type(s)} may not work durably." + ) + + session = kwargs.get("session") + if isinstance(session, SQLiteSession): + raise ValueError("Temporal workflows don't support SQLite sessions.") diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py index f84488541..44887db15 100644 --- a/temporalio/contrib/openai_agents/_temporal_model_stub.py +++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py @@ -1,15 +1,18 @@ from __future__ import annotations +import asyncio import logging -from typing import Optional from temporalio import workflow -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) logger = logging.getLogger(__name__) from collections.abc import AsyncIterator -from typing import Any, Union, cast +from typing import Any from agents import ( Agent, @@ -34,6 +37,7 @@ from temporalio.contrib.openai_agents._invoke_model_activity import ( ActivityModelInput, + ActivityModelInputWithSignal, AgentOutputSchemaInput, FunctionToolInput, HandoffInput, @@ -53,10 +57,13 @@ def __init__( *, model_params: ModelActivityParameters, agent: Agent[Any] | None, + streaming_options: StreamingOptions, ) -> None: self.model_name = model_name self.model_params = model_params self.agent = agent + self.stream_events: list[TResponseStreamEvent] = [] + self.streaming_options = streaming_options async def get_response( self, @@ -72,88 +79,25 @@ async def get_response( conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> ModelResponse: - def make_tool_info(tool: Tool) -> ToolInput: - if isinstance( - tool, - ( - FileSearchTool, - WebSearchTool, - ImageGenerationTool, - CodeInterpreterTool, - ), - ): - return tool - elif isinstance(tool, HostedMCPTool): - return HostedMCPToolInput(tool_config=tool.tool_config) - elif isinstance(tool, FunctionTool): - return FunctionToolInput( - name=tool.name, - description=tool.description, - params_json_schema=tool.params_json_schema, - strict_json_schema=tool.strict_json_schema, - ) - else: - raise ValueError(f"Unsupported tool type: {tool.name}") - - tool_infos = [make_tool_info(x) for x in tools] - handoff_infos = [ - HandoffInput( - tool_name=x.tool_name, - tool_description=x.tool_description, - input_json_schema=x.input_json_schema, - agent_name=x.agent_name, - strict_json_schema=x.strict_json_schema, - ) - for x in handoffs - ] - if output_schema is not None and not isinstance( - output_schema, AgentOutputSchema - ): - raise TypeError( - f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}" - ) - agent_output_schema = output_schema - output_schema_input = ( - None - if agent_output_schema is None - else AgentOutputSchemaInput( - output_type_name=agent_output_schema.name(), - is_wrapped=agent_output_schema._is_wrapped, - output_schema=agent_output_schema.json_schema() - if not agent_output_schema.is_plain_text() - else None, - strict_json_schema=agent_output_schema.is_strict_json_schema(), - ) - ) + tool_inputs = _make_tool_inputs(tools) + handoff_inputs = _make_handoff_inputs(handoffs) + output_schema_input = _make_output_schema_input(output_schema) activity_input = ActivityModelInput( model_name=self.model_name, system_instructions=system_instructions, input=input, model_settings=model_settings, - tools=tool_infos, + tools=tool_inputs, output_schema=output_schema_input, - handoffs=handoff_infos, + handoffs=handoff_inputs, tracing=ModelTracingInput(tracing.value), previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt, ) - if self.model_params.summary_override: - summary = ( - self.model_params.summary_override - if isinstance(self.model_params.summary_override, str) - else ( - self.model_params.summary_override.provide( - self.agent, system_instructions, input - ) - ) - ) - elif self.agent: - summary = self.agent.name - else: - summary = None + summary = self._make_summary(system_instructions, input) if self.model_params.use_local_activity: return await workflow.execute_local_activity_method( @@ -196,7 +140,111 @@ def stream_response( conversation_id: str | None, prompt: ResponsePromptParam | None, ) -> AsyncIterator[TResponseStreamEvent]: - raise NotImplementedError("Temporal model doesn't support streams yet") + if self.model_params.use_local_activity: + raise ValueError("Streaming is not available with local activities.") + + tool_inputs = _make_tool_inputs(tools) + handoff_inputs = _make_handoff_inputs(handoffs) + output_schema_input = _make_output_schema_input(output_schema) + + summary = self._make_summary(system_instructions, input) + + stream_queue: asyncio.Queue[TResponseStreamEvent | None] = asyncio.Queue() + + async def handle_stream_event(events: list[TResponseStreamEvent]): + for event in events: + await stream_queue.put(event) + + signal_name = "model_stream_signal" + workflow.set_signal_handler(signal_name, handle_stream_event) + + activity_input = ActivityModelInput( + model_name=self.model_name, + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tool_inputs, + output_schema=output_schema_input, + handoffs=handoff_inputs, + tracing=ModelTracingInput(tracing.value), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + if self.streaming_options.use_signals: + handle = workflow.start_activity_method( + ModelActivity.stream_model, + args=[ + ActivityModelInputWithSignal(**activity_input, signal=signal_name) + ], + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) + + async def monitor_activity(): + try: + await handle + finally: + await stream_queue.put(None) # Signal end of stream + + monitor_task = asyncio.create_task(monitor_activity()) + + async def generator() -> AsyncIterator[TResponseStreamEvent]: + while True: + item = await stream_queue.get() + if item is None: + await monitor_task + return + yield item + + return generator() + else: + + async def generator() -> AsyncIterator[TResponseStreamEvent]: + results = await workflow.execute_activity_method( + ModelActivity.batch_stream_model, + args=[activity_input], + summary=summary, + task_queue=self.model_params.task_queue, + schedule_to_close_timeout=self.model_params.schedule_to_close_timeout, + schedule_to_start_timeout=self.model_params.schedule_to_start_timeout, + start_to_close_timeout=self.model_params.start_to_close_timeout, + heartbeat_timeout=self.model_params.heartbeat_timeout, + retry_policy=self.model_params.retry_policy, + cancellation_type=self.model_params.cancellation_type, + versioning_intent=self.model_params.versioning_intent, + priority=self.model_params.priority, + ) + for event in results: + yield event + + return generator() + + def _make_summary( + self, system_instructions: str | None, input: str | list[TResponseInputItem] + ) -> str | None: + if self.model_params.summary_override: + return ( + self.model_params.summary_override + if isinstance(self.model_params.summary_override, str) + else ( + self.model_params.summary_override.provide( + self.agent, system_instructions, input + ) + ) + ) + elif self.agent: + return self.agent.name + else: + return None def _extract_summary(input: str | list[TResponseInputItem]) -> str: @@ -228,3 +276,67 @@ def _extract_summary(input: str | list[TResponseInputItem]) -> str: except Exception as e: logger.error(f"Error getting summary: {e}") return "" + + +def _make_tool_input(tool: Tool) -> ToolInput: + if isinstance( + tool, + ( + FileSearchTool, + WebSearchTool, + ImageGenerationTool, + CodeInterpreterTool, + ), + ): + return tool + elif isinstance(tool, HostedMCPTool): + return HostedMCPToolInput(tool_config=tool.tool_config) + elif isinstance(tool, FunctionTool): + return FunctionToolInput( + name=tool.name, + description=tool.description, + params_json_schema=tool.params_json_schema, + strict_json_schema=tool.strict_json_schema, + ) + else: + raise ValueError(f"Unsupported tool type: {tool.name}") + + +def _make_tool_inputs(tools: list[Tool]) -> list[ToolInput]: + return [_make_tool_input(x) for x in tools] + + +def _make_handoff_inputs(handoffs: list[Handoff]) -> list[HandoffInput]: + return [ + HandoffInput( + tool_name=x.tool_name, + tool_description=x.tool_description, + input_json_schema=x.input_json_schema, + agent_name=x.agent_name, + strict_json_schema=x.strict_json_schema, + ) + for x in handoffs + ] + + +def _make_output_schema_input( + output_schema: AgentOutputSchemaBase | None, +) -> AgentOutputSchemaInput | None: + if output_schema is not None and not isinstance(output_schema, AgentOutputSchema): + raise TypeError( + f"Only AgentOutputSchema is supported by Temporal Model, got {type(output_schema).__name__}" + ) + + agent_output_schema = output_schema + return ( + None + if agent_output_schema is None + else AgentOutputSchemaInput( + output_type_name=agent_output_schema.name(), + is_wrapped=agent_output_schema._is_wrapped, + output_schema=agent_output_schema.json_schema() + if not agent_output_schema.is_plain_text() + else None, + strict_json_schema=agent_output_schema.is_strict_json_schema(), + ) + ) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 41ae419f7..6fb3d8f3b 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Callable, Sequence from contextlib import asynccontextmanager, contextmanager from datetime import timedelta -from typing import Optional, Union +from typing import Union from agents import ModelProvider, set_trace_provider from agents.run import get_default_agent_runner, set_default_agent_runner @@ -13,7 +13,10 @@ from agents.tracing.provider import DefaultTraceProvider from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._openai_runner import ( TemporalOpenAIRunner, ) @@ -47,6 +50,7 @@ def set_open_ai_agent_temporal_overrides( model_params: ModelActivityParameters, auto_close_tracing_in_workflows: bool = False, + streaming_options: StreamingOptions = StreamingOptions(), ): """Configure Temporal-specific overrides for OpenAI agents. @@ -67,6 +71,7 @@ def set_open_ai_agent_temporal_overrides( Args: model_params: Configuration parameters for Temporal activity execution of model calls. auto_close_tracing_in_workflows: If set to true, close tracing spans immediately. + streaming_options: Options applicable for use of run_streamed. Returns: A context manager that yields the configured TemporalTraceProvider. @@ -78,7 +83,7 @@ def set_open_ai_agent_temporal_overrides( ) try: - set_default_agent_runner(TemporalOpenAIRunner(model_params)) + set_default_agent_runner(TemporalOpenAIRunner(model_params, streaming_options)) set_trace_provider(provider) yield provider finally: @@ -136,6 +141,7 @@ class OpenAIAgentsPlugin(SimplePlugin): The plugin will wrap each server in a TemporalMCPServer if needed and manage their connection lifecycles tied to the worker lifetime. This is the recommended way to use MCP servers with Temporal workflows. + streaming_options: Options applicable for use of run_streamed. Example: >>> from temporalio.client import Client @@ -182,6 +188,7 @@ def __init__( Union["StatelessMCPServerProvider", "StatefulMCPServerProvider"] ] = (), register_activities: bool = True, + streaming_options: StreamingOptions = StreamingOptions(), ) -> None: """Initialize the OpenAI agents plugin. @@ -197,6 +204,7 @@ def __init__( register_activities: Whether to register activities during the worker execution. This can be disabled on some workers to allow a separation of workflows and activities but should not be disabled on all workers, or agents will not be able to progress. + streaming_options: Options applicable for use of run_streamed. """ if model_params is None: model_params = ModelActivityParameters() @@ -221,7 +229,12 @@ def add_activities( if not register_activities: return activities or [] - new_activities = [ModelActivity(model_provider).invoke_model_activity] + model_activity = ModelActivity(model_provider, streaming_options) + new_activities = [ + model_activity.invoke_model_activity, + model_activity.stream_model, + model_activity.batch_stream_model, + ] server_names = [server.name for server in mcp_server_providers] if len(server_names) != len(set(server_names)): @@ -247,7 +260,9 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: @asynccontextmanager async def run_context() -> AsyncIterator[None]: - with set_open_ai_agent_temporal_overrides(model_params): + with set_open_ai_agent_temporal_overrides( + model_params, streaming_options=streaming_options + ): yield super().__init__( diff --git a/temporalio/contrib/openai_agents/testing.py b/temporalio/contrib/openai_agents/testing.py index 4acab196a..03ac23d58 100644 --- a/temporalio/contrib/openai_agents/testing.py +++ b/temporalio/contrib/openai_agents/testing.py @@ -17,9 +17,14 @@ ) from agents.items import TResponseOutputItem, TResponseStreamEvent from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartDoneEvent, ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText, + ResponseTextDeltaEvent, ) from temporalio.client import Client @@ -27,7 +32,10 @@ StatefulMCPServerProvider, StatelessMCPServerProvider, ) -from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._model_parameters import ( + ModelActivityParameters, + StreamingOptions, +) from temporalio.contrib.openai_agents._temporal_openai_agents import OpenAIAgentsPlugin __all__ = [ @@ -109,6 +117,100 @@ def output_message(text: str) -> ModelResponse: ) +class EventBuilders: + """Builders for creating stream events for testing. + + .. warning:: + This API is experimental and may change in the future. + """ + + @staticmethod + def text_delta(text: str) -> ResponseTextDeltaEvent: + """Create a TResponseStreamEvent with an text delta. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseTextDeltaEvent( + content_index=0, + delta=text, + item_id="", + logprobs=[], + output_index=0, + sequence_number=0, + type="response.output_text.delta", + ) + + @staticmethod + def content_part_done(text: str) -> TResponseStreamEvent: + """Create a TResponseStreamEvent for content part completion. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseContentPartDoneEvent( + content_index=0, + item_id="", + output_index=0, + sequence_number=0, + type="response.content_part.done", + part=ResponseOutputText( + text=text, + annotations=[], + type="output_text", + ), + ) + + @staticmethod + def output_item_done(text: str) -> TResponseStreamEvent: + """Create a TResponseStreamEvent for output item completion. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseOutputItemDoneEvent( + output_index=0, + sequence_number=0, + type="response.output_item.done", + item=ResponseBuilders.response_output_message(text), + ) + + @staticmethod + def response_completion(text: str) -> TResponseStreamEvent: + """Create a TResponseStreamEvent for response completion. + + .. warning:: + This API is experimental and may change in the future. + """ + return ResponseCompletedEvent( + response=Response( + id="", + created_at=0.0, + object="response", + model="", + parallel_tool_calls=False, + tool_choice="none", + tools=[], + output=[ResponseBuilders.response_output_message(text)], + ), + sequence_number=0, + type="response.completed", + ) + + @staticmethod + def ending(text: str) -> list[TResponseStreamEvent]: + """Create a list of TResponseStreamEvent for the end of a stream. + + .. warning:: + This API is experimental and may change in the future. + """ + return [ + EventBuilders.content_part_done(text), + EventBuilders.output_item_done(text), + EventBuilders.response_completion(text), + ] + + class TestModelProvider(ModelProvider): """Test model provider which simply returns the given module. @@ -144,13 +246,19 @@ class TestModel(Model): __test__ = False - def __init__(self, fn: Callable[[], ModelResponse]) -> None: + def __init__( + self, + fn: Callable[[], ModelResponse] | None, + *, + streaming_fn: Callable[[], AsyncIterator[TResponseStreamEvent]] | None = None, + ) -> None: """Initialize a test model with a callable. .. warning:: This API is experimental and may change in the future. """ self.fn = fn + self.streaming_fn = streaming_fn async def get_response( self, @@ -164,6 +272,8 @@ async def get_response( **kwargs, ) -> ModelResponse: """Get a response from the mocked model, by calling the callable passed to the constructor.""" + if self.fn is None: + raise ValueError("No non-streaming function provided") return self.fn() def stream_response( @@ -177,8 +287,10 @@ def stream_response( tracing: ModelTracing, **kwargs, ) -> AsyncIterator[TResponseStreamEvent]: - """Get a streamed response from the model. Unimplemented.""" - raise NotImplementedError() + """Get a streamed response from the model.""" + if self.streaming_fn is None: + raise ValueError("No streaming function provided") + return self.streaming_fn() @staticmethod def returning_responses(responses: list[ModelResponse]) -> "TestModel": @@ -190,6 +302,35 @@ def returning_responses(responses: list[ModelResponse]) -> "TestModel": i = iter(responses) return TestModel(lambda: next(i)) + @staticmethod + def streaming_events(events: list[TResponseStreamEvent]) -> "TestModel": + """Create a mock model which sequentially returns responses from a list. + + .. warning:: + This API is experimental and may change in the future. + """ + + async def generator(): + for event in events: + yield event + + return TestModel(None, streaming_fn=lambda: generator()) + + @staticmethod + def streaming_events_with_ending( + events: list[ResponseTextDeltaEvent], + ) -> "TestModel": + """Create a mock model which sequentially returns responses from a list. Appends ending markers + + .. warning:: + This API is experimental and may change in the future. + """ + content = "" + for event in events: + content += event.delta + + return TestModel.streaming_events(events + EventBuilders.ending(content)) + class AgentEnvironment: """Testing environment for OpenAI agents with Temporal integration. @@ -226,6 +367,7 @@ def __init__( StatelessMCPServerProvider | StatefulMCPServerProvider ] = (), register_activities: bool = True, + streaming_options: StreamingOptions = StreamingOptions(), ) -> None: """Initialize the AgentEnvironment. @@ -242,6 +384,7 @@ def __init__( If both are provided, model_provider will be used. mcp_server_providers: Sequence of MCP servers to automatically register with the worker. register_activities: Whether to register activities during worker execution. + streaming_options: Options applicable for use of run_streamed. .. warning:: This API is experimental and may change in the future. @@ -255,6 +398,7 @@ def __init__( self._mcp_server_providers = mcp_server_providers self._register_activities = register_activities self._plugin: OpenAIAgentsPlugin | None = None + self.streaming_options = streaming_options async def __aenter__(self) -> "AgentEnvironment": """Enter the async context manager.""" @@ -264,6 +408,7 @@ async def __aenter__(self) -> "AgentEnvironment": model_provider=self._model_provider, mcp_server_providers=self._mcp_server_providers, register_activities=self._register_activities, + streaming_options=self.streaming_options, ) return self diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 381a213cd..b63aad427 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -8,6 +8,7 @@ from datetime import timedelta from typing import ( Any, + Awaitable, Optional, Union, cast, @@ -65,6 +66,7 @@ from openai.types.responses import ( EasyInputMessageParam, ResponseCodeInterpreterToolCall, + ResponseErrorEvent, ResponseFileSearchToolCall, ResponseFunctionToolCall, ResponseFunctionToolCallParam, @@ -72,6 +74,7 @@ ResponseInputTextParam, ResponseOutputMessage, ResponseOutputText, + ResponseTextDeltaEvent, ) from openai.types.responses.response_file_search_tool_call import Result from openai.types.responses.response_function_web_search import ActionSearch @@ -90,10 +93,11 @@ from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( ModelActivityParameters, + ModelSummaryProvider, StatefulMCPServerProvider, StatelessMCPServerProvider, + StreamingOptions, ) -from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider from temporalio.contrib.openai_agents._openai_runner import _convert_agent from temporalio.contrib.openai_agents._temporal_model_stub import ( _extract_summary, @@ -101,6 +105,7 @@ ) from temporalio.contrib.openai_agents.testing import ( AgentEnvironment, + EventBuilders, ResponseBuilders, TestModel, TestModelProvider, @@ -2551,7 +2556,9 @@ def override_get_activities() -> Sequence[Callable]: async def test_model_conversion_loops(): agent = init_agents() - converted = _convert_agent(ModelActivityParameters(), agent, None) + converted = _convert_agent( + ModelActivityParameters(), StreamingOptions(), agent, None + ) seat_booking_handoff = converted.handoffs[1] assert isinstance(seat_booking_handoff, Handoff) context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper( @@ -2635,3 +2642,162 @@ async def test_split_workers(client: Client): execution_timeout=timedelta(seconds=120), ) assert result == "test" + + +@workflow.defn +class StreamingHelloWorldAgent: + def __init__(self): + self.events = [] + self._has_failure = False + + @workflow.run + async def run(self, prompt: str) -> str | None: + agent = Agent[None]( + name="Assistant", + instructions="You are a helpful assistant.", + ) + + result = Runner.run_streamed(starting_agent=agent, input=prompt) + async for event in result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + self.events.append(event.data.delta) + if event.type == "raw_response_event" and isinstance( + event.data, ResponseErrorEvent + ): + self._has_failure = True + + return result.final_output if result else None + + @workflow.query + def get_events(self) -> list[str]: + return self.events + + @workflow.query + def has_failure(self) -> bool: + return self._has_failure + + +def streaming_hello_model(): + return TestModel.streaming_events_with_ending( + [ + EventBuilders.text_delta("Hello"), + EventBuilders.text_delta(" there"), + EventBuilders.text_delta("!"), + ] + ) + + +async def test_signal_streaming(client: Client): + async with AgentEnvironment( + model=streaming_hello_model(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + streaming_options=StreamingOptions( + use_signals=True, + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, StreamingHelloWorldAgent, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingHelloWorldAgent.run, + args=["Say hello."], + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!" + assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 3 + + +failed = False + + +def streaming_failure_model(): + async def generator() -> AsyncIterator[TResponseStreamEvent]: + global failed + for event in [ + EventBuilders.text_delta("Hello"), + EventBuilders.text_delta(" there"), + EventBuilders.text_delta("!"), + ]: + yield event + await asyncio.sleep(0.25) + if not failed: + failed = True + raise ValueError("Intentional failure") + + for end_event in EventBuilders.ending("Hello there!"): + yield end_event + + return TestModel(None, streaming_fn=lambda: generator()) + + +async def test_signal_streaming_failure(client: Client): + async with AgentEnvironment( + model=streaming_failure_model(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + streaming_options=StreamingOptions( + use_signals=True, + signal_batch_latency_seconds=0.1, + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, StreamingHelloWorldAgent, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingHelloWorldAgent.run, + args=["Say hello."], + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!" + assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 6 + assert await handle.query(StreamingHelloWorldAgent.has_failure) + + +async def test_callback_streaming(client: Client): + events = [] + + async def callback(_: ModelSettings, event: TResponseStreamEvent) -> None: + events.append(event) + + async with AgentEnvironment( + model=streaming_hello_model(), + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30), + ), + streaming_options=StreamingOptions( + callback=callback, + ), + ) as env: + client = env.applied_on_client(client) + + async with new_worker( + client, StreamingHelloWorldAgent, max_cached_workflows=0 + ) as worker: + handle = await client.start_workflow( + StreamingHelloWorldAgent.run, + args=["Say hello."], + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=50), + ) + result = await handle.result() + assert result == "Hello there!" + assert len(await handle.query(StreamingHelloWorldAgent.get_events)) == 3 + + # The results include the ending markers because it wasn't filtered like the workflow + assert len(events) == 6