diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 035b75878b..590d3a0f3f 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -17,6 +17,7 @@ import asyncio import logging import sys +import time from typing import Callable from typing import Dict from typing import List @@ -96,6 +97,8 @@ def __init__( header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, + cache: bool = False, + cache_ttl_seconds: Optional[int] = None, ): """Initializes the McpToolset. @@ -121,6 +124,10 @@ def __init__( tools. header_provider: A callable that takes a ReadonlyContext and returns a dictionary of headers to be used for the MCP session. + cache: If True, the toolset will cache the response from the + first `list_tools` call and reuse it for subsequent calls. + cache_ttl_seconds: If set, the in-memory cache will expire + after this many seconds. """ super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) @@ -139,6 +146,11 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential self._require_confirmation = require_confirmation + self._cache = cache + self._cache_ttl_seconds = cache_ttl_seconds + self._cached_tool_response: Optional[ListToolsResult] = None + self._cache_creation_time: Optional[float] = None + self._cache_lock = asyncio.Lock() @retry_on_errors async def get_tools( @@ -154,41 +166,74 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - headers = ( - self._header_provider(readonly_context) - if self._header_provider and readonly_context - else None - ) - # Get session from session manager - session = await self._mcp_session_manager.create_session(headers=headers) - - # Fetch available tools from the MCP server - timeout_in_seconds = ( - self._connection_params.timeout - if hasattr(self._connection_params, "timeout") - else None - ) - try: - tools_response: ListToolsResult = await asyncio.wait_for( - session.list_tools(), timeout=timeout_in_seconds - ) - except Exception as e: - raise ConnectionError("Failed to get tools from MCP server.") from e + + def _is_cache_valid() -> bool: + if not self._cache or not self._cached_tool_response: + return False + + if self._cache_ttl_seconds is None: + return True # No TTL set, consider cache always valid + + if self._cache_creation_time is None: + # This should not happen in a well-initialized system + return False + + elapsed = time.monotonic() - self._cache_creation_time + return elapsed <= self._cache_ttl_seconds + + # First check without a lock for performance. + if _is_cache_valid(): + tools_response = self._cached_tool_response + else: + # If cache is invalid, acquire lock to prevent stampede. + async with self._cache_lock: + # Double-check if cache was populated while waiting for the lock. + if _is_cache_valid(): + tools_response = self._cached_tool_response + else: + # Cache is still invalid, so we are the one to fetch it. + headers = ( + self._header_provider(readonly_context) + if self._header_provider and readonly_context + else None + ) + session = await self._mcp_session_manager.create_session( + headers=headers + ) + + timeout_in_seconds = ( + self._connection_params.timeout + if hasattr(self._connection_params, "timeout") + else None + ) + try: + fetched_tools = await asyncio.wait_for( + session.list_tools(), timeout=timeout_in_seconds + ) + if self._cache: + self._cached_tool_response = fetched_tools + self._cache_creation_time = time.monotonic() + tools_response = fetched_tools + except Exception as e: + raise ConnectionError( + "Failed to get tools from MCP server." + ) from e # Apply filtering based on context and tool_filter tools = [] - for tool in tools_response.tools: - mcp_tool = MCPTool( - mcp_tool=tool, - mcp_session_manager=self._mcp_session_manager, - auth_scheme=self._auth_scheme, - auth_credential=self._auth_credential, - require_confirmation=self._require_confirmation, - header_provider=self._header_provider, - ) - - if self._is_tool_selected(mcp_tool, readonly_context): - tools.append(mcp_tool) + if tools_response: + for tool in tools_response.tools: + mcp_tool = MCPTool( + mcp_tool=tool, + mcp_session_manager=self._mcp_session_manager, + auth_scheme=self._auth_scheme, + auth_credential=self._auth_credential, + require_confirmation=self._require_confirmation, + header_provider=self._header_provider, + ) + + if self._is_tool_selected(mcp_tool, readonly_context): + tools.append(mcp_tool) return tools async def close(self) -> None: @@ -230,6 +275,7 @@ def from_config( auth_scheme=mcp_toolset_config.auth_scheme, auth_credential=mcp_toolset_config.auth_credential, ) + class MCPToolset(McpToolset): diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py index 7bfd912669..14cedc216b 100644 --- a/tests/unittests/tools/test_mcp_toolset.py +++ b/tests/unittests/tools/test_mcp_toolset.py @@ -14,6 +14,7 @@ """Unit tests for McpToolset.""" +import asyncio from unittest.mock import AsyncMock from unittest.mock import MagicMock @@ -69,3 +70,143 @@ async def test_mcp_toolset_with_prefix(): # Assert that the original tools are not modified assert tools[0].name == "tool1" assert tools[1].name == "tool2" + + +def _create_mock_session_manager(): + """Helper to create a mock MCPSessionManager.""" + mock_session_manager = MagicMock() + mock_session = MagicMock() + + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "tool 2 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1, mock_tool2] + + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + return mock_session_manager, mock_session + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_disabled(): + """Test that list_tools is called every time when cache is disabled.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset(connection_params=mock_connection_params, cache=False) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + await toolset.get_tools() + + assert mock_session.list_tools.call_count == 2 + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_enabled(): + """Test that list_tools is called only once when cache is enabled.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset(connection_params=mock_connection_params, cache=True) + toolset._mcp_session_manager = mock_session_manager + + tools1 = await toolset.get_tools() + tools2 = await toolset.get_tools() + + mock_session.list_tools.assert_called_once() + assert len(tools1) == 2 + assert len(tools2) == 2 + assert tools1[0].name == tools2[0].name + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_with_ttl_not_expired(): + """Test that cache is used when TTL has not expired.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset( + connection_params=mock_connection_params, cache=True, cache_ttl_seconds=10 + ) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + await toolset.get_tools() + + mock_session.list_tools.assert_called_once() + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_with_ttl_expired(): + """Test that list_tools is called again after TTL expires.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset( + connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1 + ) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + mock_session.list_tools.assert_called_once() + + await asyncio.sleep(1.1) + + await toolset.get_tools() + assert mock_session.list_tools.call_count == 2 + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_concurrency(): + """Test that list_tools is called only once during concurrent requests.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + + # Create a mock session manager. Add a small delay to the mock call + # to simulate network latency and increase the chance of a race condition. + mock_session_manager = MagicMock() + mock_session = MagicMock() + + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1] + + async def delayed_list_tools(): + await asyncio.sleep(0.1) + return list_tools_result + + mock_session.list_tools = AsyncMock(side_effect=delayed_list_tools) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Initialize the toolset with caching enabled + toolset = McpToolset(connection_params=mock_connection_params, cache=True) + toolset._mcp_session_manager = mock_session_manager + + # Create multiple concurrent tasks to call get_tools + tasks = [asyncio.create_task(toolset.get_tools()) for _ in range(5)] + + # Run all tasks concurrently + results = await asyncio.gather(*tasks) + + # Assert that list_tools was only called once, thanks to the lock + mock_session.list_tools.assert_called_once() + + # Assert that all results are the same and correct + assert len(results) == 5 + for result in results: + assert len(result) == 1 + assert result[0].name == "tool1" + + # Check that the first result is the same as the others + assert all(results[0][0].name == r[0].name for r in results[1:]) \ No newline at end of file