From 8417688bdce1bb37779d25c0d92a6cbefe250056 Mon Sep 17 00:00:00 2001
From: Margubur Rahman
Date: Mon, 15 Dec 2025 15:31:31 +0000
Subject: [PATCH] feat(storage): add AsyncConnection along with unit tests

---
 .../_experimental/asyncio/async_connection.py | 333 ++++++++++++++++++
 google/cloud/storage/_http.py                 |   7 +-
 tests/unit/asyncio/test_async_connection.py   | 276 +++++++++++++++
 3 files changed, 613 insertions(+), 3 deletions(-)
 create mode 100644 google/cloud/storage/_experimental/asyncio/async_connection.py
 create mode 100644 tests/unit/asyncio/test_async_connection.py

diff --git a/google/cloud/storage/_experimental/asyncio/async_connection.py b/google/cloud/storage/_experimental/asyncio/async_connection.py
new file mode 100644
index 000000000..8d297ef8a
--- /dev/null
+++ b/google/cloud/storage/_experimental/asyncio/async_connection.py
@@ -0,0 +1,333 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Create/interact with Google Cloud Storage connections in an asynchronous manner."""
+
+import json
+import collections
+import functools
+from urllib.parse import urlencode
+
+import google.api_core.exceptions
+from google.cloud import _http
+from google.cloud.storage import _http as storage_http
+from google.cloud.storage import _helpers
+from google.api_core.client_info import ClientInfo
+from google.cloud.storage._opentelemetry_tracing import create_trace_span
+from google.cloud.storage import __version__
+from google.cloud.storage._http import AGENT_VERSION
+
+
+class AsyncConnection:
+    """Class for asynchronous connection using google.auth.aio.
+
+    This class handles the creation of API requests, header management,
+    user agent configuration, and error handling for the Async Storage Client.
+
+    Args:
+        client: The client that owns this connection.
+        client_info: Information about the client library. Its ``user_agent``
+            and ``client_library_version`` are updated in place.
+        api_endpoint: The API endpoint to use.
+    """
+
+    def __init__(self, client, client_info=None, api_endpoint=None):
+        self._client = client
+
+        if client_info is None:
+            client_info = ClientInfo()
+
+        self._client_info = client_info
+        # Append the library agent only once: a ClientInfo may be shared by
+        # several connections, and the synchronous Connection guards its own
+        # append with the same "not already present" check.
+        if self._client_info.user_agent is None:
+            self._client_info.user_agent = AGENT_VERSION
+        elif AGENT_VERSION not in self._client_info.user_agent:
+            self._client_info.user_agent = (
+                f"{self._client_info.user_agent} {AGENT_VERSION}"
+            )
+        self._client_info.client_library_version = __version__
+        self._extra_headers = {}
+
+        self.API_BASE_URL = api_endpoint or storage_http.Connection.DEFAULT_API_ENDPOINT
+        self.API_VERSION = storage_http.Connection.API_VERSION
+        self.API_URL_TEMPLATE = storage_http.Connection.API_URL_TEMPLATE
+
+    @property
+    def extra_headers(self):
+        """Returns extra headers to send with every request."""
+        return self._extra_headers
+
+    @extra_headers.setter
+    def extra_headers(self, value):
+        """Set the extra header property."""
+        self._extra_headers = value
+
+    @property
+    def async_http(self):
+        """Returns the AsyncAuthorizedSession from the client.
+
+        Returns:
+            google.auth.aio.transport.sessions.AsyncAuthorizedSession: The async session.
+        """
+        return self._client.async_http
+
+    @property
+    def user_agent(self):
+        """Returns user_agent for async HTTP transport.
+
+        Returns:
+            str: The user agent string.
+        """
+        return self._client_info.to_user_agent()
+
+    @user_agent.setter
+    def user_agent(self, value):
+        """Setter for user_agent in connection."""
+        self._client_info.user_agent = value
+
+    async def _make_request(
+        self,
+        method,
+        url,
+        data=None,
+        content_type=None,
+        headers=None,
+        target_object=None,
+        timeout=_http._DEFAULT_TIMEOUT,
+        extra_api_info=None,
+    ):
+        """A low level method to send a request to the API.
+
+        Args:
+            method (str): The HTTP method (e.g., 'GET', 'POST').
+            url (str): The specific API URL.
+            data (Optional[Union[str, bytes, dict]]): The body of the request.
+            content_type (Optional[str]): The Content-Type header.
+            headers (Optional[dict]): Additional headers for the request.
+            target_object (Optional[object]): (Unused in async impl) Reference to the target object.
+            timeout (Optional[float]): The timeout in seconds.
+            extra_api_info (Optional[str]): Extra info for the User-Agent / Client-Info.
+
+        Returns:
+            google.auth.aio.transport.Response: The HTTP response object.
+        """
+        headers = headers.copy() if headers else {}
+        # Connection-wide extra headers win over per-request headers.
+        headers.update(self.extra_headers)
+        headers["Accept-Encoding"] = "gzip"
+
+        if content_type:
+            headers["Content-Type"] = content_type
+
+        if extra_api_info:
+            headers[_http.CLIENT_INFO_HEADER] = f"{self.user_agent} {extra_api_info}"
+        else:
+            headers[_http.CLIENT_INFO_HEADER] = self.user_agent
+        headers["User-Agent"] = self.user_agent
+
+        return await self._do_request(
+            method, url, headers, data, target_object, timeout=timeout
+        )
+
+    async def _do_request(
+        self, method, url, headers, data, target_object, timeout=_http._DEFAULT_TIMEOUT
+    ):
+        """Low-level helper: perform the actual API request.
+
+        Args:
+            method (str): HTTP method.
+            url (str): API URL.
+            headers (dict): HTTP headers.
+            data (Optional[bytes]): Request body.
+            target_object: Unused in this implementation, kept for compatibility.
+            timeout (float): Request timeout.
+
+        Returns:
+            google.auth.aio.transport.Response: The response object.
+        """
+        return await self.async_http.request(
+            method=method,
+            url=url,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+        )
+
+    async def api_request(self, *args, **kwargs):
+        """Perform an API request with retry and tracing support.
+
+        Args:
+            *args: Positional arguments passed to _perform_api_request.
+            **kwargs: Keyword arguments passed to _perform_api_request.
+                Can include 'retry' (an AsyncRetry object).
+
+        Returns:
+            Union[dict, bytes]: The parsed JSON response or raw bytes.
+        """
+        retry = kwargs.pop("retry", None)
+        invocation_id = _helpers._get_invocation_id()
+        kwargs["extra_api_info"] = invocation_id
+        span_attributes = {
+            "gccl-invocation-id": invocation_id,
+        }
+
+        call = functools.partial(self._perform_api_request, *args, **kwargs)
+
+        with create_trace_span(
+            name="Storage.AsyncConnection.api_request",
+            attributes=span_attributes,
+            client=self._client,
+            api_request=kwargs,
+            retry=retry,
+        ):
+            if retry:
+                # Let a conditional policy decide from the request whether to retry.
+                try:
+                    retry = retry.get_retry_policy_if_conditions_met(**kwargs)
+                except AttributeError:  # no conditional hook: use the retry as given
+                    pass
+                if retry:
+                    call = retry(call)
+            return await call()
+
+    def build_api_url(
+        self, path, query_params=None, api_base_url=None, api_version=None
+    ):
+        """Construct an API URL.
+
+        Args:
+            path (str): The API path (e.g. '/b/bucket-name').
+            query_params (Optional[Union[dict, list]]): Query parameters.
+            api_base_url (Optional[str]): Base URL override.
+            api_version (Optional[str]): API version override.
+
+        Returns:
+            str: The fully constructed URL.
+        """
+        url = self.API_URL_TEMPLATE.format(
+            api_base_url=(api_base_url or self.API_BASE_URL),
+            api_version=(api_version or self.API_VERSION),
+            path=path,
+        )
+
+        query_params = query_params or {}
+
+        if isinstance(query_params, collections.abc.Mapping):
+            # Copy into a plain dict: a read-only Mapping provides neither
+            # copy() nor setdefault(), and the caller's object stays untouched.
+            query_params = dict(query_params)
+        else:
+            query_params_dict = collections.defaultdict(list)
+            for key, value in query_params:
+                query_params_dict[key].append(value)
+            query_params = query_params_dict
+
+        query_params.setdefault("prettyPrint", "false")
+
+        url += "?" + urlencode(query_params, doseq=True)
+
+        return url
+
+    async def _perform_api_request(
+        self,
+        method,
+        path,
+        query_params=None,
+        data=None,
+        content_type=None,
+        headers=None,
+        api_base_url=None,
+        api_version=None,
+        expect_json=True,
+        _target_object=None,
+        timeout=_http._DEFAULT_TIMEOUT,
+        extra_api_info=None,
+    ):
+        """Internal helper to prepare the URL/Body and execute the request.
+
+        This method handles JSON serialization of the body, URL construction,
+        and converts HTTP errors into google.api_core.exceptions.
+
+        Args:
+            method (str): HTTP method.
+            path (str): URL path.
+            query_params (Optional[dict]): Query params.
+            data (Optional[Union[dict, bytes]]): Request body.
+            content_type (Optional[str]): Content-Type header.
+            headers (Optional[dict]): HTTP headers.
+            api_base_url (Optional[str]): Base URL override.
+            api_version (Optional[str]): API version override.
+            expect_json (bool): If True, parses response as JSON. Defaults to True.
+            _target_object: Internal use (unused here).
+            timeout (float): Request timeout.
+            extra_api_info (Optional[str]): Extra client info.
+
+        Returns:
+            Union[dict, bytes]: Parsed JSON or raw bytes.
+
+        Raises:
+            google.api_core.exceptions.GoogleAPICallError: If the API returns an error.
+        """
+        url = self.build_api_url(
+            path=path,
+            query_params=query_params,
+            api_base_url=api_base_url,
+            api_version=api_version,
+        )
+
+        if data and isinstance(data, dict):
+            data = json.dumps(data)
+            content_type = "application/json"
+
+        response = await self._make_request(
+            method=method,
+            url=url,
+            data=data,
+            content_type=content_type,
+            headers=headers,
+            target_object=_target_object,
+            timeout=timeout,
+            extra_api_info=extra_api_info,
+        )
+
+        # Handle API Errors
+        if not (200 <= response.status_code < 300):
+            content = await response.read()
+            payload = {}
+            if content:
+                try:
+                    payload = json.loads(content.decode("utf-8"))
+                except (ValueError, UnicodeDecodeError):
+                    payload = None
+                # The body may be plain text, or JSON that is not an object;
+                # wrap anything that is not a dict so the error formatter
+                # always receives a mapping.
+                if not isinstance(payload, dict):
+                    payload = {
+                        "error": {"message": content.decode("utf-8", errors="replace")}
+                    }
+            raise google.api_core.exceptions.format_http_response_error(
+                response, method, url, payload
+            )
+
+        # Handle Success
+        payload = await response.read()
+        if expect_json:
+            if not payload:
+                return {}
+            return json.loads(payload)
+        else:
+            return payload
diff --git a/google/cloud/storage/_http.py b/google/cloud/storage/_http.py
index aea13cc57..d3b11cb20 100644
--- a/google/cloud/storage/_http.py
+++ b/google/cloud/storage/_http.py
@@ -20,6 +20,8 @@
 from google.cloud.storage import _helpers
 from google.cloud.storage._opentelemetry_tracing import create_trace_span
 
+AGENT_VERSION = f"gcloud-python/{__version__}"
+
 
 class Connection(_http.JSONConnection):
     """A connection to Google Cloud Storage via the JSON REST API.
@@ -54,9 +56,8 @@ def __init__(self, client, client_info=None, api_endpoint=None):
         # TODO: When metrics all use gccl, this should be removed #9552
         if self._client_info.user_agent is None:  # pragma: no branch
             self._client_info.user_agent = ""
-        agent_version = f"gcloud-python/{__version__}"
-        if agent_version not in self._client_info.user_agent:
-            self._client_info.user_agent += f" {agent_version} "
+        if AGENT_VERSION not in self._client_info.user_agent:
+            self._client_info.user_agent += f" {AGENT_VERSION} "
 
     API_VERSION = _helpers._API_VERSION
     """The version of the API, used in building the API call's URL."""
diff --git a/tests/unit/asyncio/test_async_connection.py b/tests/unit/asyncio/test_async_connection.py
new file mode 100644
index 000000000..5a4dde8c2
--- /dev/null
+++ b/tests/unit/asyncio/test_async_connection.py
@@ -0,0 +1,276 @@
+import json
+import pytest
+from unittest import mock
+
+from google.cloud.storage import _http as storage_http
+from google.api_core import exceptions
+from google.api_core.client_info import ClientInfo
+from google.cloud.storage._experimental.asyncio.async_connection import AsyncConnection
+
+
+class MockAuthResponse:
+    """Simulates google.auth.aio.transport.aiohttp.Response."""
+
+    def __init__(self, status_code=200, data=b"{}", headers=None):
+        self.status_code = status_code
+        self._data = data
+        self._headers = headers or {}
+
+    @property
+    def headers(self):
+        return self._headers
+
+    async def read(self):
+        return self._data
+
+
+@pytest.fixture
+def mock_client():
+    """Mocks the Google Cloud Storage Client."""
+    client = mock.Mock()
+    client.async_http = mock.AsyncMock()  # request() is awaited by the connection
+    return client
+
+
+@pytest.fixture
+def async_connection(mock_client):
+    """Creates an instance of AsyncConnection with a mocked client."""
+    return AsyncConnection(mock_client)
+
+
+@pytest.fixture
+def mock_trace_span():
+    """Mocks the OpenTelemetry trace span context manager."""
+    target = (
+        "google.cloud.storage._experimental.asyncio.async_connection.create_trace_span"
+    )
+    with mock.patch(target) as mock_span:
+        mock_span.return_value.__enter__.return_value = None  # span value is unused
+        yield mock_span
+
+
+def test_init_defaults(async_connection):
+    """Test initialization with default values."""
+    assert isinstance(async_connection._client_info, ClientInfo)
+    assert async_connection.API_BASE_URL == storage_http.Connection.DEFAULT_API_ENDPOINT
+    assert async_connection.API_VERSION == storage_http.Connection.API_VERSION
+    assert async_connection.API_URL_TEMPLATE == storage_http.Connection.API_URL_TEMPLATE
+    assert "gcloud-python" in async_connection.user_agent
+
+
+def test_init_custom_endpoint(mock_client):
+    """Test initialization with a custom API endpoint."""
+    custom_endpoint = "https://custom.storage.googleapis.com"
+    conn = AsyncConnection(mock_client, api_endpoint=custom_endpoint)
+    assert conn.API_BASE_URL == custom_endpoint
+
+
+def test_extra_headers_property(async_connection):
+    """Test getter and setter for extra_headers."""
+    headers = {"X-Custom-Header": "value"}
+    async_connection.extra_headers = headers
+    assert async_connection.extra_headers == headers
+
+
+def test_build_api_url_simple(async_connection):
+    """Test building a simple API URL."""
+    url = async_connection.build_api_url(path="/b/bucket-name")
+    expected = (
+        f"{async_connection.API_BASE_URL}/storage/v1/b/bucket-name?prettyPrint=false"
+    )
+    assert url == expected
+
+
+def test_build_api_url_with_params(async_connection):
+    """Test building an API URL with query parameters."""
+    params = {"projection": "full", "versions": True}
+    url = async_connection.build_api_url(path="/b/bucket", query_params=params)
+
+    assert "projection=full" in url
+    assert "versions=True" in url  # urlencode renders the bool with str()
+    assert "prettyPrint=false" in url
+
+
+@pytest.mark.asyncio
+async def test_make_request_headers(async_connection, mock_client):
+    """Test that _make_request adds the correct headers."""
+    mock_response = MockAuthResponse(status_code=200)
+    mock_client.async_http.request.return_value = mock_response
+
+    async_connection.user_agent = "test-agent/1.0"
+    async_connection.extra_headers = {"X-Test": "True"}
+
+    await async_connection._make_request(
+        method="GET", url="http://example.com", content_type="application/json"
+    )
+
+    call_args = mock_client.async_http.request.call_args
+    _, kwargs = call_args
+    headers = kwargs["headers"]
+
+    assert headers["Content-Type"] == "application/json"
+    assert headers["Accept-Encoding"] == "gzip"
+
+    assert "test-agent/1.0" in headers["User-Agent"]
+
+    assert headers["X-Test"] == "True"
+
+
+@pytest.mark.asyncio
+async def test_api_request_success(async_connection, mock_client, mock_trace_span):
+    """Test the high-level api_request method wraps the call correctly."""
+    expected_data = {"items": []}
+    mock_response = MockAuthResponse(
+        status_code=200, data=json.dumps(expected_data).encode("utf-8")
+    )
+    mock_client.async_http.request.return_value = mock_response
+
+    response = await async_connection.api_request(method="GET", path="/b/bucket")
+
+    assert response == expected_data
+    mock_trace_span.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_perform_api_request_json_serialization(
+    async_connection, mock_client, mock_trace_span
+):
+    """Test that dictionary data is serialized to JSON."""
+    mock_response = MockAuthResponse(status_code=200)
+    mock_client.async_http.request.return_value = mock_response
+
+    data = {"key": "value"}
+    await async_connection.api_request(method="POST", path="/b", data=data)
+
+    call_args = mock_client.async_http.request.call_args
+    _, kwargs = call_args
+
+    assert kwargs["data"] == json.dumps(data)
+    assert kwargs["headers"]["Content-Type"] == "application/json"
+
+
+@pytest.mark.asyncio
+async def test_perform_api_request_error_handling(
+    async_connection, mock_client, mock_trace_span
+):
+    """Test that non-2xx responses raise GoogleAPICallError."""
+    error_json = {"error": {"message": "Not Found"}}
+    mock_response = MockAuthResponse(
+        status_code=404, data=json.dumps(error_json).encode("utf-8")
+    )
+    mock_client.async_http.request.return_value = mock_response
+
+    with pytest.raises(exceptions.GoogleAPICallError) as excinfo:
+        await async_connection.api_request(method="GET", path="/b/nonexistent")
+
+    assert "Not Found" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_perform_api_request_no_json_response(
+    async_connection, mock_client, mock_trace_span
+):
+    """Test response handling when expect_json is False."""
+    raw_bytes = b"binary_data"
+    mock_response = MockAuthResponse(status_code=200, data=raw_bytes)
+    mock_client.async_http.request.return_value = mock_response
+
+    response = await async_connection.api_request(
+        method="GET", path="/b/obj", expect_json=False
+    )
+
+    assert response == raw_bytes
+
+
+@pytest.mark.asyncio
+async def test_api_request_with_retry(async_connection, mock_client, mock_trace_span):
+    """Test that the retry policy is applied if conditions are met."""
+    mock_response = MockAuthResponse(status_code=200, data=b"{}")
+    mock_client.async_http.request.return_value = mock_response
+
+    mock_retry = mock.Mock()
+    mock_policy = mock.Mock(side_effect=lambda call: call)  # returns call unchanged
+    mock_retry.get_retry_policy_if_conditions_met.return_value = mock_policy
+
+    await async_connection.api_request(method="GET", path="/b/bucket", retry=mock_retry)
+
+    mock_retry.get_retry_policy_if_conditions_met.assert_called_once()
+    mock_policy.assert_called_once()
+
+
+def test_build_api_url_repeated_params(async_connection):
+    """Test building URL with a list of tuples (repeated keys)."""
+    params = [("field", "name"), ("field", "size")]
+    url = async_connection.build_api_url(path="/b/bucket", query_params=params)
+
+    assert "field=name" in url
+    assert "field=size" in url
+    assert url.count("field=") == 2
+
+
+def test_build_api_url_overrides(async_connection):
+    """Test building URL with explicit base URL and version overrides."""
+    url = async_connection.build_api_url(
+        path="/b/bucket", api_base_url="https://example.com", api_version="v2"
+    )
+    assert "https://example.com/storage/v2/b/bucket" in url
+
+
+@pytest.mark.asyncio
+async def test_perform_api_request_empty_response(
+    async_connection, mock_client, mock_trace_span
+):
+    """Test handling of empty 2xx response when expecting JSON."""
+    mock_response = MockAuthResponse(status_code=204, data=b"")
+    mock_client.async_http.request.return_value = mock_response
+
+    response = await async_connection.api_request(
+        method="DELETE", path="/b/bucket/o/object"
+    )
+
+    assert response == {}
+
+
+@pytest.mark.asyncio
+async def test_perform_api_request_non_json_error(
+    async_connection, mock_client, mock_trace_span
+):
+    """Test error handling when the error response is plain text (not JSON)."""
+    error_text = "Bad Gateway"
+    mock_response = MockAuthResponse(status_code=502, data=error_text.encode("utf-8"))
+    mock_client.async_http.request.return_value = mock_response
+
+    with pytest.raises(exceptions.GoogleAPICallError) as excinfo:
+        await async_connection.api_request(method="GET", path="/b/bucket")
+
+    assert error_text in str(excinfo.value)
+    assert excinfo.value.code == 502
+
+
+@pytest.mark.asyncio
+async def test_make_request_extra_api_info(async_connection, mock_client):
+    """Test logic for constructing x-goog-api-client header with extra info."""
+    mock_response = MockAuthResponse(status_code=200)
+    mock_client.async_http.request.return_value = mock_response
+
+    invocation_id = "test-id-123"
+
+    await async_connection._make_request(
+        method="GET", url="http://example.com", extra_api_info=invocation_id
+    )
+
+    call_args = mock_client.async_http.request.call_args
+    _, kwargs = call_args
+    headers = kwargs["headers"]
+
+    client_header = headers.get("X-Goog-API-Client")
+    assert async_connection.user_agent in client_header
+    assert invocation_id in client_header
+
+
+def test_user_agent_setter(async_connection):
+    """Test explicit setter for user_agent."""
+    new_ua = "my-custom-app/1.0"
+    async_connection.user_agent = new_ua
+    assert new_ua in async_connection.user_agent
+    assert async_connection._client_info.user_agent == new_ua