Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 78 additions & 32 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import asyncio
import logging
import sys
import time
from typing import Callable
from typing import Dict
from typing import List
Expand Down Expand Up @@ -96,6 +97,8 @@ def __init__(
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
cache: bool = False,
cache_ttl_seconds: Optional[int] = None,
):
"""Initializes the McpToolset.

Expand All @@ -121,6 +124,10 @@ def __init__(
tools.
header_provider: A callable that takes a ReadonlyContext and returns a
dictionary of headers to be used for the MCP session.
cache: If True, the toolset will cache the response from the
first `list_tools` call and reuse it for subsequent calls.
cache_ttl_seconds: If set, the in-memory cache will expire
after this many seconds.
"""
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)

Expand All @@ -139,6 +146,11 @@ def __init__(
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
self._cache = cache
self._cache_ttl_seconds = cache_ttl_seconds
self._cached_tool_response: Optional[ListToolsResult] = None
self._cache_creation_time: Optional[float] = None
self._cache_lock = asyncio.Lock()

@retry_on_errors
async def get_tools(
Expand All @@ -154,41 +166,74 @@ async def get_tools(
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
headers = (
self._header_provider(readonly_context)
if self._header_provider and readonly_context
else None
)
# Get session from session manager
session = await self._mcp_session_manager.create_session(headers=headers)

# Fetch available tools from the MCP server
timeout_in_seconds = (
self._connection_params.timeout
if hasattr(self._connection_params, "timeout")
else None
)
try:
tools_response: ListToolsResult = await asyncio.wait_for(
session.list_tools(), timeout=timeout_in_seconds
)
except Exception as e:
raise ConnectionError("Failed to get tools from MCP server.") from e

def _is_cache_valid() -> bool:
if not self._cache or not self._cached_tool_response:
return False

if self._cache_ttl_seconds is None:
return True # No TTL set, consider cache always valid

if self._cache_creation_time is None:
# This should not happen in a well-initialized system
return False

elapsed = time.monotonic() - self._cache_creation_time
return elapsed <= self._cache_ttl_seconds

# First check without a lock for performance.
if _is_cache_valid():
tools_response = self._cached_tool_response
else:
# If cache is invalid, acquire lock to prevent stampede.
async with self._cache_lock:
# Double-check if cache was populated while waiting for the lock.
if _is_cache_valid():
tools_response = self._cached_tool_response
else:
# Cache is still invalid, so we are the one to fetch it.
headers = (
self._header_provider(readonly_context)
if self._header_provider and readonly_context
else None
)
session = await self._mcp_session_manager.create_session(
headers=headers
)

timeout_in_seconds = (
self._connection_params.timeout
if hasattr(self._connection_params, "timeout")
else None
)
try:
fetched_tools = await asyncio.wait_for(
session.list_tools(), timeout=timeout_in_seconds
)
if self._cache:
self._cached_tool_response = fetched_tools
self._cache_creation_time = time.monotonic()
tools_response = fetched_tools
except Exception as e:
raise ConnectionError(
"Failed to get tools from MCP server."
) from e

# Apply filtering based on context and tool_filter
tools = []
for tool in tools_response.tools:
mcp_tool = MCPTool(
mcp_tool=tool,
mcp_session_manager=self._mcp_session_manager,
auth_scheme=self._auth_scheme,
auth_credential=self._auth_credential,
require_confirmation=self._require_confirmation,
header_provider=self._header_provider,
)

if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
if tools_response:
for tool in tools_response.tools:
mcp_tool = MCPTool(
mcp_tool=tool,
mcp_session_manager=self._mcp_session_manager,
auth_scheme=self._auth_scheme,
auth_credential=self._auth_credential,
require_confirmation=self._require_confirmation,
header_provider=self._header_provider,
)

if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools

async def close(self) -> None:
Expand Down Expand Up @@ -230,6 +275,7 @@ def from_config(
auth_scheme=mcp_toolset_config.auth_scheme,
auth_credential=mcp_toolset_config.auth_credential,
)



class MCPToolset(McpToolset):
Expand Down
141 changes: 141 additions & 0 deletions tests/unittests/tools/test_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Unit tests for McpToolset."""

import asyncio
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

Expand Down Expand Up @@ -69,3 +70,143 @@ async def test_mcp_toolset_with_prefix():
# Assert that the original tools are not modified
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"


def _create_mock_session_manager():
"""Helper to create a mock MCPSessionManager."""
mock_session_manager = MagicMock()
mock_session = MagicMock()

mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool1.description = "tool 1 desc"
mock_tool2 = MagicMock()
mock_tool2.name = "tool2"
mock_tool2.description = "tool 2 desc"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool1, mock_tool2]

mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_session_manager.create_session = AsyncMock(return_value=mock_session)
return mock_session_manager, mock_session


@pytest.mark.asyncio
async def test_mcp_toolset_cache_disabled():
"""Test that list_tools is called every time when cache is disabled."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(connection_params=mock_connection_params, cache=False)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
await toolset.get_tools()

assert mock_session.list_tools.call_count == 2


@pytest.mark.asyncio
async def test_mcp_toolset_cache_enabled():
"""Test that list_tools is called only once when cache is enabled."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(connection_params=mock_connection_params, cache=True)
toolset._mcp_session_manager = mock_session_manager

tools1 = await toolset.get_tools()
tools2 = await toolset.get_tools()

mock_session.list_tools.assert_called_once()
assert len(tools1) == 2
assert len(tools2) == 2
assert tools1[0].name == tools2[0].name


@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_not_expired():
"""Test that cache is used when TTL has not expired."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(
connection_params=mock_connection_params, cache=True, cache_ttl_seconds=10
)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
await toolset.get_tools()

mock_session.list_tools.assert_called_once()


@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_expired():
"""Test that list_tools is called again after TTL expires."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(
connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1
)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
mock_session.list_tools.assert_called_once()

await asyncio.sleep(1.1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using asyncio.sleep() to test time-dependent logic can make tests slow and potentially flaky. A more robust approach is to mock the time source, time.monotonic. This gives you precise control over time in your test, making it faster and more reliable.

You can use mocker.patch from pytest-mock to do this. Here's an example of how you could rewrite this test:

@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_expired(mocker):
    """Test that list_tools is called again after TTL expires."""
    mock_connection_params = MagicMock()
    mock_connection_params.timeout = None
    mock_session_manager, mock_session = _create_mock_session_manager()

    toolset = McpToolset(
        connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1
    )
    toolset._mcp_session_manager = mock_session_manager

    # Patch time.monotonic
    mock_time = mocker.patch('time.monotonic')

    # First call, populates cache
    mock_time.return_value = 1000.0
    await toolset.get_tools()
    mock_session.list_tools.assert_called_once()

    # Second call, after TTL expired
    mock_time.return_value = 1001.1  # More than 1 second later
    await toolset.get_tools()
    assert mock_session.list_tools.call_count == 2


await toolset.get_tools()
assert mock_session.list_tools.call_count == 2


@pytest.mark.asyncio
async def test_mcp_toolset_cache_concurrency():
"""Test that list_tools is called only once during concurrent requests."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None

# Create a mock session manager. Add a small delay to the mock call
# to simulate network latency and increase the chance of a race condition.
mock_session_manager = MagicMock()
mock_session = MagicMock()

mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool1.description = "tool 1 desc"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool1]

async def delayed_list_tools():
await asyncio.sleep(0.1)
return list_tools_result

mock_session.list_tools = AsyncMock(side_effect=delayed_list_tools)
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Initialize the toolset with caching enabled
toolset = McpToolset(connection_params=mock_connection_params, cache=True)
toolset._mcp_session_manager = mock_session_manager

# Create multiple concurrent tasks to call get_tools
tasks = [asyncio.create_task(toolset.get_tools()) for _ in range(5)]

# Run all tasks concurrently
results = await asyncio.gather(*tasks)

# Assert that list_tools was only called once, thanks to the lock
mock_session.list_tools.assert_called_once()

# Assert that all results are the same and correct
assert len(results) == 5
for result in results:
assert len(result) == 1
assert result[0].name == "tool1"

# Check that the first result is the same as the others
assert all(results[0][0].name == r[0].name for r in results[1:])