Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ async def exchange_auth_token(
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
return await exchanger.exchange(
exchanged_credential, _ = await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)
return exchanged_credential

async def parse_and_store_auth_response(self, state: State) -> None:

Expand Down
13 changes: 6 additions & 7 deletions src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,14 @@ async def _exchange_credential(
return credential, False

if isinstance(exchanger, ServiceAccountCredentialExchanger):
exchanged_credential = exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
)
else:
exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
return (
exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
),
True,
)

return exchanged_credential, True
return await exchanger.exchange(credential, self._auth_config.auth_scheme)

async def _refresh_credential(
self, credential: AuthCredential
Expand Down
7 changes: 5 additions & 2 deletions src/google/adk/auth/exchanger/base_credential_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ async def exchange(
self,
auth_credential: AuthCredential,
auth_scheme: Optional[AuthScheme] = None,
) -> AuthCredential:
) -> tuple[AuthCredential, bool]:
"""Exchange credential if needed.

Args:
auth_credential: The credential to exchange.
auth_scheme: The authentication scheme (optional, some exchangers don't need it).

Returns:
The exchanged credential.
A tuple of (credential, exchanged) where:
- credential: The exchanged credential if exchange occurred, otherwise
the original credential.
- exchanged: True if credential was exchanged, False otherwise.

Raises:
CredentialExchangeError: If credential exchange fails.
Expand Down
39 changes: 24 additions & 15 deletions src/google/adk/auth/exchanger/oauth2_credential_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def exchange(
self,
auth_credential: AuthCredential,
auth_scheme: Optional[AuthScheme] = None,
) -> AuthCredential:
) -> tuple[AuthCredential, bool]:
"""Exchange OAuth2 credential from authorization response.

if credential exchange failed, the original credential will be returned.
Expand All @@ -61,7 +61,10 @@ async def exchange(
auth_scheme: The OAuth2 authentication scheme.

Returns:
The exchanged credential with access token.
A tuple of (credential, exchanged) where:
- credential: The exchanged credential with an access token if exchange occurred, otherwise
the original credential.
- exchanged: True if credential was exchanged, False otherwise.

Raises:
CredentialExchangeError: If auth_scheme is missing.
Expand All @@ -79,10 +82,10 @@ async def exchange(
logger.warning(
"authlib is not available, skipping OAuth2 credential exchange."
)
return auth_credential
return (auth_credential, False)

if auth_credential.oauth2 and auth_credential.oauth2.access_token:
return auth_credential
return (auth_credential, False)

# Determine grant type from auth_scheme
grant_type = self._determine_grant_type(auth_scheme)
Expand All @@ -97,7 +100,7 @@ async def exchange(
)
else:
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
return auth_credential
return auth_credential, False

def _determine_grant_type(
self, auth_scheme: AuthScheme
Expand Down Expand Up @@ -129,22 +132,25 @@ async def _exchange_client_credentials(
self,
auth_credential: AuthCredential,
auth_scheme: AuthScheme,
) -> AuthCredential:
) -> tuple[AuthCredential, bool]:
"""Exchange client credentials for access token.

Args:
auth_credential: The OAuth2 credential to exchange.
auth_scheme: The OAuth2 authentication scheme.

Returns:
The credential with access token.
A tuple of (credential, exchanged) where:
- credential: The exchanged credential with an access token if exchange occurred, otherwise
the original credential.
- exchanged: True if credential was exchanged, False otherwise.
"""
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
if not client:
logger.warning(
"Could not create OAuth2 session for client credentials exchange"
)
return auth_credential
return auth_credential, False

try:
tokens = client.fetch_token(
Expand All @@ -155,9 +161,9 @@ async def _exchange_client_credentials(
logger.debug("Successfully exchanged client credentials for access token")
except Exception as e:
logger.error("Failed to exchange client credentials: %s", e)
return auth_credential
return auth_credential, False

return auth_credential
return auth_credential, True

def _normalize_auth_uri(self, auth_uri: str | None) -> str | None:
# Authlib currently used a simplified token check by simply scanning hash existence,
Expand All @@ -171,22 +177,25 @@ async def _exchange_authorization_code(
self,
auth_credential: AuthCredential,
auth_scheme: AuthScheme,
) -> AuthCredential:
) -> tuple[AuthCredential, bool]:
"""Exchange authorization code for access token.

Args:
auth_credential: The OAuth2 credential to exchange.
auth_scheme: The OAuth2 authentication scheme.

Returns:
The credential with access token.
A tuple of (credential, exchanged) where:
- credential: The exchanged credential with an access token if exchange occurred, otherwise
the original credential.
- exchanged: True if credential was exchanged, False otherwise.
"""
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
if not client:
logger.warning(
"Could not create OAuth2 session for authorization code exchange"
)
return auth_credential
return auth_credential, False

try:
tokens = client.fetch_token(
Expand All @@ -201,6 +210,6 @@ async def _exchange_authorization_code(
logger.debug("Successfully exchanged authorization code for access token")
except Exception as e:
logger.error("Failed to exchange authorization code: %s", e)
return auth_credential
return auth_credential, False

return auth_credential
return auth_credential, True
72 changes: 41 additions & 31 deletions tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
class TestOAuth2CredentialExchanger:
"""Test suite for OAuth2CredentialExchanger."""

@pytest.mark.asyncio
async def test_exchange_with_existing_token(self):
"""Test exchange method when access token already exists."""
scheme = OpenIdConnectWithConfig(
Expand All @@ -55,14 +54,16 @@ async def test_exchange_with_existing_token(self):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Should return the same credential since access token already exists
assert result == credential
assert result.oauth2.access_token == "existing_token"
assert exchanged_credential == credential
assert exchanged_credential.oauth2.access_token == "existing_token"
assert not was_exchanged

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_success(self, mock_oauth2_session):
"""Test successful token exchange."""
# Setup mock
Expand Down Expand Up @@ -96,14 +97,16 @@ async def test_exchange_success(self, mock_oauth2_session):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Verify token exchange was successful
assert result.oauth2.access_token == "new_access_token"
assert result.oauth2.refresh_token == "new_refresh_token"
assert exchanged_credential.oauth2.access_token == "new_access_token"
assert exchanged_credential.oauth2.refresh_token == "new_refresh_token"
assert was_exchanged
mock_client.fetch_token.assert_called_once()

@pytest.mark.asyncio
async def test_exchange_missing_auth_scheme(self):
"""Test exchange with missing auth_scheme raises ValueError."""
credential = AuthCredential(
Expand All @@ -122,7 +125,6 @@ async def test_exchange_missing_auth_scheme(self):
assert "auth_scheme is required" in str(e)

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_no_session(self, mock_oauth2_session):
"""Test exchange when OAuth2Session cannot be created."""
# Mock to return None for create_oauth2_session
Expand All @@ -146,14 +148,16 @@ async def test_exchange_no_session(self, mock_oauth2_session):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Should return original credential when session creation fails
assert result == credential
assert result.oauth2.access_token is None
assert exchanged_credential == credential
assert exchanged_credential.oauth2.access_token is None
assert not was_exchanged

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
"""Test exchange when fetch_token fails."""
# Setup mock to raise exception during fetch_token
Expand Down Expand Up @@ -181,14 +185,16 @@ async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Should return original credential when fetch_token fails
assert result == credential
assert result.oauth2.access_token is None
assert exchanged_credential == credential
assert exchanged_credential.oauth2.access_token is None
assert not was_exchanged
mock_client.fetch_token.assert_called_once()

@pytest.mark.asyncio
async def test_exchange_authlib_not_available(self):
"""Test exchange when authlib is not available."""
scheme = OpenIdConnectWithConfig(
Expand Down Expand Up @@ -217,14 +223,16 @@ async def test_exchange_authlib_not_available(self):
"google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE",
False,
):
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Should return original credential when authlib is not available
assert result == credential
assert result.oauth2.access_token is None
assert exchanged_credential == credential
assert exchanged_credential.oauth2.access_token is None
assert not was_exchanged

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
"""Test successful client credentials exchange."""
# Setup mock
Expand Down Expand Up @@ -255,17 +263,19 @@ async def test_exchange_client_credentials_success(self, mock_oauth2_session):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Verify client credentials exchange was successful
assert result.oauth2.access_token == "client_access_token"
assert exchanged_credential.oauth2.access_token == "client_access_token"
assert was_exchanged
mock_client.fetch_token.assert_called_once_with(
"https://example.com/token",
grant_type="client_credentials",
)

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
"""Test client credentials exchange failure."""
# Setup mock to raise exception during fetch_token
Expand All @@ -292,15 +302,17 @@ async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
)

exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchanged_credential, was_exchanged = await exchanger.exchange(
credential, scheme
)

# Should return original credential when client credentials exchange fails
assert result == credential
assert result.oauth2.access_token is None
assert exchanged_credential == credential
assert exchanged_credential.oauth2.access_token is None
assert not was_exchanged
mock_client.fetch_token.assert_called_once()

@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_normalize_uri(self, mock_oauth2_session):
"""Test exchange method normalizes auth_response_uri."""
mock_client = Mock()
Expand Down Expand Up @@ -343,7 +355,6 @@ async def test_exchange_normalize_uri(self, mock_oauth2_session):
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
)

@pytest.mark.asyncio
async def test_determine_grant_type_client_credentials(self):
"""Test grant type determination for client credentials."""
flows = OAuthFlows(
Expand All @@ -360,7 +371,6 @@ async def test_determine_grant_type_client_credentials(self):

assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS

@pytest.mark.asyncio
async def test_determine_grant_type_openid_connect(self):
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
scheme = OpenIdConnectWithConfig(
Expand Down