From 1d3043173feb1d1455f006449d360c1feca777df Mon Sep 17 00:00:00 2001 From: secprog Date: Fri, 5 Dec 2025 16:05:53 +0000 Subject: [PATCH] feat: Add support for custom task store in to_a2a function --- src/google/adk/a2a/utils/agent_to_a2a.py | 14 ++++++- .../unittests/a2a/utils/test_agent_to_a2a.py | 42 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 1a1ba35618..6f1c3516f8 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -21,6 +21,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import TaskStore from a2a.types import AgentCard from starlette.applications import Starlette @@ -79,6 +80,7 @@ def to_a2a( protocol: str = "http", agent_card: Optional[Union[AgentCard, str]] = None, runner: Optional[Runner] = None, + task_store: Optional[TaskStore] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -92,7 +94,9 @@ def to_a2a( agent. runner: Optional pre-built Runner object. If not provided, a default runner will be created using in-memory services. - + task_store: Optional task store instance. If not provided, an + InMemoryTaskStore will be created. Must be compatible with + DefaultRequestHandler's task_store parameter. Returns: A Starlette application that can be run with uvicorn @@ -103,6 +107,11 @@ def to_a2a( # Or with custom agent card: app = to_a2a(agent, agent_card=my_custom_agent_card) + + # Or with custom task store: + from a2a.server.tasks import TaskStore + class MyCustomTaskStore(TaskStore): ... # A user-defined TaskStore; abstract methods must be implemented + app = to_a2a(agent, task_store=MyCustomTaskStore()) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly adk_logger = logging.getLogger("google_adk") @@ -121,7 +130,8 @@ async def create_runner() -> Runner: ) # Create A2A components - task_store = InMemoryTaskStore() + if task_store is None: + task_store = InMemoryTaskStore() agent_executor = A2aAgentExecutor( runner=runner or create_runner, diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index 503e572f2f..c0fd7eda62 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -19,6 +19,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import TaskStore from a2a.types import AgentCard from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder @@ -131,6 +132,47 @@ def test_to_a2a_with_custom_runner( "startup", mock_app.add_event_handler.call_args[0][1] ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_with_custom_task_store( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a with a custom task store.""" + # Arrange + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + custom_task_store = Mock(spec=TaskStore) + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + + # Act + result = to_a2a(self.mock_agent, task_store=custom_task_store) + + # Assert + assert result == mock_app + mock_starlette_class.assert_called_once() + # Verify InMemoryTaskStore was NOT created since we provided a custom one + mock_task_store_class.assert_not_called() + mock_agent_executor_class.assert_called_once() + # Verify the custom task store was used + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, task_store=custom_task_store + ) + mock_card_builder_class.assert_called_once_with( + agent=self.mock_agent, rpc_url="http://localhost:8000/" + ) + mock_app.add_event_handler.assert_called_once_with( + "startup", mock_app.add_event_handler.call_args[0][1] + ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")