diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 07515ab2e8..3248cbe11f 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -48,9 +48,10 @@ async def exchange_auth_token( self, ) -> AuthCredential: exchanger = OAuth2CredentialExchanger() - return await exchanger.exchange( + exchanged_credential, _ = await exchanger.exchange( self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme ) + return exchanged_credential async def parse_and_store_auth_response(self, state: State) -> None: diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index c022ab694c..292963e151 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -214,15 +214,14 @@ async def _exchange_credential( return credential, False if isinstance(exchanger, ServiceAccountCredentialExchanger): - exchanged_credential = exchanger.exchange_credential( - self._auth_config.auth_scheme, credential - ) - else: - exchanged_credential = await exchanger.exchange( - credential, self._auth_config.auth_scheme + return ( + exchanger.exchange_credential( + self._auth_config.auth_scheme, credential + ), + True, ) - return exchanged_credential, True + return await exchanger.exchange(credential, self._auth_config.auth_scheme) async def _refresh_credential( self, credential: AuthCredential diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py index 31106b55e2..2b686d6681 100644 --- a/src/google/adk/auth/exchanger/base_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -41,7 +41,7 @@ async def exchange( self, auth_credential: AuthCredential, auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: + ) -> tuple[AuthCredential, bool]: """Exchange credential if needed. Args: @@ -49,7 +49,10 @@ async def exchange( auth_scheme: The authentication scheme (optional, some exchangers don't need it). Returns: - The exchanged credential. + A tuple of (credential, exchanged) where: + - credential: The exchanged credential if exchange occurred, otherwise + the original credential. + - exchanged: True if credential was exchanged, False otherwise. Raises: CredentialExchangeError: If credential exchange fails. diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py index 431798cc6c..e4b392abde 100644 --- a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -51,7 +51,7 @@ async def exchange( self, auth_credential: AuthCredential, auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: + ) -> tuple[AuthCredential, bool]: """Exchange OAuth2 credential from authorization response. if credential exchange failed, the original credential will be returned. @@ -61,7 +61,10 @@ async def exchange( auth_scheme: The OAuth2 authentication scheme. Returns: - The exchanged credential with access token. + A tuple of (credential, exchanged) where: + - credential: The exchanged credential with an access token if exchange occurred, otherwise + the original credential. + - exchanged: True if credential was exchanged, False otherwise. Raises: CredentialExchangeError: If auth_scheme is missing. @@ -79,10 +82,10 @@ async def exchange( logger.warning( "authlib is not available, skipping OAuth2 credential exchange." ) - return auth_credential + return (auth_credential, False) if auth_credential.oauth2 and auth_credential.oauth2.access_token: - return auth_credential + return (auth_credential, False) # Determine grant type from auth_scheme grant_type = self._determine_grant_type(auth_scheme) @@ -97,7 +100,7 @@ async def exchange( ) else: logger.warning("Unsupported OAuth2 grant type: %s", grant_type) - return auth_credential + return auth_credential, False def _determine_grant_type( self, auth_scheme: AuthScheme @@ -129,7 +132,7 @@ async def _exchange_client_credentials( self, auth_credential: AuthCredential, auth_scheme: AuthScheme, - ) -> AuthCredential: + ) -> tuple[AuthCredential, bool]: """Exchange client credentials for access token. Args: @@ -137,14 +140,17 @@ async def _exchange_client_credentials( auth_scheme: The OAuth2 authentication scheme. Returns: - The credential with access token. + A tuple of (credential, exchanged) where: + - credential: The exchanged credential with an access token if exchange occurred, otherwise + the original credential. + - exchanged: True if credential was exchanged, False otherwise. """ client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) if not client: logger.warning( "Could not create OAuth2 session for client credentials exchange" ) - return auth_credential + return auth_credential, False try: tokens = client.fetch_token( @@ -155,9 +161,9 @@ async def _exchange_client_credentials( logger.debug("Successfully exchanged client credentials for access token") except Exception as e: logger.error("Failed to exchange client credentials: %s", e) - return auth_credential + return auth_credential, False - return auth_credential + return auth_credential, True def _normalize_auth_uri(self, auth_uri: str | None) -> str | None: # Authlib currently used a simplified token check by simply scanning hash existence, @@ -171,7 +177,7 @@ async def _exchange_authorization_code( self, auth_credential: AuthCredential, auth_scheme: AuthScheme, - ) -> AuthCredential: + ) -> tuple[AuthCredential, bool]: """Exchange authorization code for access token. Args: @@ -179,14 +185,17 @@ async def _exchange_authorization_code( auth_scheme: The OAuth2 authentication scheme. Returns: - The credential with access token. + A tuple of (credential, exchanged) where: + - credential: The exchanged credential with an access token if exchange occurred, otherwise + the original credential. + - exchanged: True if credential was exchanged, False otherwise. """ client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) if not client: logger.warning( "Could not create OAuth2 session for authorization code exchange" ) - return auth_credential + return auth_credential, False try: tokens = client.fetch_token( @@ -201,6 +210,6 @@ async def _exchange_authorization_code( logger.debug("Successfully exchanged authorization code for access token") except Exception as e: logger.error("Failed to exchange authorization code: %s", e) - return auth_credential + return auth_credential, False - return auth_credential + return auth_credential, True diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py index 4fc439ad13..284a911270 100644 --- a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -33,7 +33,6 @@ class TestOAuth2CredentialExchanger: """Test suite for OAuth2CredentialExchanger.""" - @pytest.mark.asyncio async def test_exchange_with_existing_token(self): """Test exchange method when access token already exists.""" scheme = OpenIdConnectWithConfig( @@ -55,14 +54,16 @@ async def test_exchange_with_existing_token(self): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Should return the same credential since access token already exists - assert result == credential - assert result.oauth2.access_token == "existing_token" + assert exchanged_credential == credential + assert exchanged_credential.oauth2.access_token == "existing_token" + assert not was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_success(self, mock_oauth2_session): """Test successful token exchange.""" # Setup mock @@ -96,14 +97,16 @@ async def test_exchange_success(self, mock_oauth2_session): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Verify token exchange was successful - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" + assert exchanged_credential.oauth2.access_token == "new_access_token" + assert exchanged_credential.oauth2.refresh_token == "new_refresh_token" + assert was_exchanged mock_client.fetch_token.assert_called_once() - @pytest.mark.asyncio async def test_exchange_missing_auth_scheme(self): """Test exchange with missing auth_scheme raises ValueError.""" credential = AuthCredential( @@ -122,7 +125,6 @@ async def test_exchange_missing_auth_scheme(self): assert "auth_scheme is required" in str(e) @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_no_session(self, mock_oauth2_session): """Test exchange when OAuth2Session cannot be created.""" # Mock to return None for create_oauth2_session @@ -146,14 +148,16 @@ async def test_exchange_no_session(self, mock_oauth2_session): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Should return original credential when session creation fails - assert result == credential - assert result.oauth2.access_token is None + assert exchanged_credential == credential + assert exchanged_credential.oauth2.access_token is None + assert not was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_fetch_token_failure(self, mock_oauth2_session): """Test exchange when fetch_token fails.""" # Setup mock to raise exception during fetch_token @@ -181,14 +185,16 @@ async def test_exchange_fetch_token_failure(self, mock_oauth2_session): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Should return original credential when fetch_token fails - assert result == credential - assert result.oauth2.access_token is None + assert exchanged_credential == credential + assert exchanged_credential.oauth2.access_token is None + assert not was_exchanged mock_client.fetch_token.assert_called_once() - @pytest.mark.asyncio async def test_exchange_authlib_not_available(self): """Test exchange when authlib is not available.""" scheme = OpenIdConnectWithConfig( @@ -217,14 +223,16 @@ async def test_exchange_authlib_not_available(self): "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE", False, ): - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Should return original credential when authlib is not available - assert result == credential - assert result.oauth2.access_token is None + assert exchanged_credential == credential + assert exchanged_credential.oauth2.access_token is None + assert not was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_client_credentials_success(self, mock_oauth2_session): """Test successful client credentials exchange.""" # Setup mock @@ -255,17 +263,19 @@ async def test_exchange_client_credentials_success(self, mock_oauth2_session): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Verify client credentials exchange was successful - assert result.oauth2.access_token == "client_access_token" + assert exchanged_credential.oauth2.access_token == "client_access_token" + assert was_exchanged mock_client.fetch_token.assert_called_once_with( "https://example.com/token", grant_type="client_credentials", ) @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_client_credentials_failure(self, mock_oauth2_session): """Test client credentials exchange failure.""" # Setup mock to raise exception during fetch_token @@ -292,15 +302,17 @@ async def test_exchange_client_credentials_failure(self, mock_oauth2_session): ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchanged_credential, was_exchanged = await exchanger.exchange( + credential, scheme + ) # Should return original credential when client credentials exchange fails - assert result == credential - assert result.oauth2.access_token is None + assert exchanged_credential == credential + assert exchanged_credential.oauth2.access_token is None + assert not was_exchanged mock_client.fetch_token.assert_called_once() @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_normalize_uri(self, mock_oauth2_session): """Test exchange method normalizes auth_response_uri.""" mock_client = Mock() @@ -343,7 +355,6 @@ async def test_exchange_normalize_uri(self, mock_oauth2_session): grant_type=OAuthGrantType.AUTHORIZATION_CODE, ) - @pytest.mark.asyncio async def test_determine_grant_type_client_credentials(self): """Test grant type determination for client credentials.""" flows = OAuthFlows( @@ -360,7 +371,6 @@ async def test_determine_grant_type_client_credentials(self): assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS - @pytest.mark.asyncio async def test_determine_grant_type_openid_connect(self): """Test grant type determination for OpenID Connect (defaults to auth code).""" scheme = OpenIdConnectWithConfig(