# diff --git a/sdks/python/apache_beam/examples/rate_limiter_simple.py b/sdks/python/apache_beam/examples/rate_limiter_simple.py
# new file mode 100644
# index 000000000000..d757140d8686
# --- /dev/null
# +++ b/sdks/python/apache_beam/examples/rate_limiter_simple.py
# @@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A simple example of using the EnvoyRateLimiter in a Beam pipeline."""

import argparse
import logging
import time

import apache_beam as beam
from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import shared


class SampleApiDoFn(beam.DoFn):
  """A DoFn that simulates calling an external API with rate limiting.

  Args:
    rls_address: Address of the Envoy Rate Limit Service.
    domain: Rate limit domain forwarded to the EnvoyRateLimiter.
    descriptors: List of descriptor dicts forwarded to the EnvoyRateLimiter.
  """
  def __init__(self, rls_address, domain, descriptors):
    self.rls_address = rls_address
    self.domain = domain
    self.descriptors = descriptors
    self._shared = shared.Shared()
    # The limiter itself is created lazily in setup().
    self.rate_limiter = None

  def setup(self):
    # Initialize the rate limiter in setup().
    # We use shared.Shared() to ensure only one RateLimiter instance is
    # created per worker and shared across threads.
    def init_limiter():
      logging.info("Connecting to Envoy RLS at %s", self.rls_address)
      return EnvoyRateLimiter(
          service_address=self.rls_address,
          domain=self.domain,
          descriptors=self.descriptors,
          namespace='example_pipeline')

    self.rate_limiter = self._shared.acquire(init_limiter)

  def process(self, element):
    # The limiter is built with its default block_until_allowed=True, so this
    # call waits for a token; the boolean result is not checked in this
    # example.
    self.rate_limiter.throttle()

    # Process the element: mock API call.
    logging.info("Processing element: %s", element)
    time.sleep(0.1)
    yield element


def parse_known_args(argv):
  """Parses args for the workflow.

  Returns:
    A (known_args, pipeline_args) tuple as produced by
    ArgumentParser.parse_known_args.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--rls_address',
      default='localhost:8081',
      help='Address of the Envoy Rate Limit Service')
  return parser.parse_known_args(argv)


def run(argv=None):
  """Builds and runs a pipeline that rate limits 100 mock API calls."""
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)

  with beam.Pipeline(options=pipeline_options) as p:
    _ = (
        p
        | 'Create' >> beam.Create(range(100))
        | 'RateLimit' >> beam.ParDo(
            SampleApiDoFn(
                rls_address=known_args.rls_address,
                domain="mongo_cps",
                descriptors=[{
                    "database": "users"
                }])))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()

# diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py b/sdks/python/apache_beam/io/components/rate_limiter.py
# new file mode 100644
# index 000000000000..8c382e7c6b60
# --- /dev/null
# +++ b/sdks/python/apache_beam/io/components/rate_limiter.py
# @@ -0,0 +1,225 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Rate Limiter classes for controlling access to external resources.
"""

import abc
import logging
import random
import threading
import time
from typing import Dict
from typing import List

import grpc
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptor
from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import RateLimitDescriptorEntry
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitRequest
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode

from apache_beam.io.components import adaptive_throttler
from apache_beam.metrics import Metrics

_LOGGER = logging.getLogger(__name__)

# Number of times a single ShouldRateLimit RPC is attempted before giving up.
_MAX_CONNECTION_RETRIES = 5
# Pause between failed RPC attempts.
_RETRY_DELAY_SECONDS = 10
# Back-off used when an OVER_LIMIT response carries no usable reset duration,
# so a blocking caller does not re-issue RPCs in a tight loop.
_DEFAULT_THROTTLE_DELAY_SECONDS = 1.0


class RateLimiter(abc.ABC):
  """Abstract base class for RateLimiters.

  Args:
    namespace: Namespace used for the metrics and the throttling signal
      created by this limiter.
  """
  def __init__(self, namespace: str = ""):
    # Metrics collected from the RateLimiter
    # Metric updates are thread safe
    self.throttling_signaler = adaptive_throttler.ThrottlingSignaler(
        namespace=namespace)
    self.requests_counter = Metrics.counter(
        namespace, 'envoyRatelimitRequestsTotal')
    self.requests_allowed = Metrics.counter(
        namespace, 'envoyRatelimitRequestsAllowed')
    self.requests_throttled = Metrics.counter(
        namespace, 'envoyRatelimitRequestsThrottled')
    self.rpc_errors = Metrics.counter(namespace, 'envoyRatelimitRpcErrors')
    self.rpc_retries = Metrics.counter(namespace, 'envoyRatelimitRpcRetries')
    self.rpc_latency = Metrics.distribution(
        namespace, 'envoyRatelimitRpcLatencyMs')

  @abc.abstractmethod
  def throttle(self, **kwargs) -> bool:
    """Check if request should be throttled.

    Args:
      **kwargs: Keyword arguments specific to the RateLimiter implementation.

    Returns:
      bool: True if the request is allowed, False if retries exceeded.

    Raises:
      Exception: If an underlying infrastructure error occurs (e.g. RPC
        failure).
    """
    pass


class EnvoyRateLimiter(RateLimiter):
  """
  Rate limiter implementation that uses an external Envoy Rate Limit Service.
  """
  def __init__(
      self,
      service_address: str,
      domain: str,
      descriptors: List[Dict[str, str]],
      timeout: float = 5.0,
      block_until_allowed: bool = True,
      retries: int = 3,
      namespace: str = ""):
    """
    Args:
      service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
      domain: The rate limit domain.
      descriptors: List of descriptors (key-value pairs).
      timeout: gRPC timeout in seconds.
      block_until_allowed: If enabled blocks until RateLimiter gets
        the token.
      retries: Number of retries to attempt if rate limited, respected only if
        block_until_allowed is False.
      namespace: the namespace to use for logging and signaling
        throttling is occurring.
    """
    super().__init__(namespace=namespace)

    self.service_address = service_address
    self.domain = domain
    self.descriptors = descriptors
    self.retries = retries
    self.timeout = timeout
    self.block_until_allowed = block_until_allowed
    # Created on first use by init_connection().
    self._stub = None
    self._lock = threading.Lock()

  class RateLimitServiceStub:
    """
    Wrapper for gRPC stub to be compatible with envoy_data_plane messages.

    The envoy-data-plane package uses 'betterproto' which generates async stubs
    for 'grpclib'. As Beam uses standard synchronous 'grpcio',
    RateLimitServiceStub is a bridge class to use the betterproto Message types
    (RateLimitRequest) with a standard grpcio Channel.
    """
    def __init__(self, channel):
      self.ShouldRateLimit = channel.unary_unary(
          '/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit',
          request_serializer=RateLimitRequest.SerializeToString,
          response_deserializer=RateLimitResponse.FromString,
      )

  def init_connection(self):
    """Creates the gRPC channel and stub once, on first use."""
    if self._stub is None:
      # Acquire lock to safeguard against multiple DoFn threads sharing the
      # same RateLimiter instance, which is the case when using Shared().
      with self._lock:
        if self._stub is None:
          channel = grpc.insecure_channel(self.service_address)
          self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)

  def _should_rate_limit_with_retries(
      self, request: RateLimitRequest) -> RateLimitResponse:
    """Issues one ShouldRateLimit RPC, retrying on grpc.RpcError.

    Raises:
      grpc.RpcError: If all _MAX_CONNECTION_RETRIES attempts fail.
    """
    for conn_attempt in range(_MAX_CONNECTION_RETRIES):
      # Monotonic clock: latency must not be affected by wall-clock changes.
      start_time = time.monotonic()
      try:
        response = self._stub.ShouldRateLimit(request, timeout=self.timeout)
      except grpc.RpcError as e:
        if conn_attempt == _MAX_CONNECTION_RETRIES - 1:
          _LOGGER.error("[EnvoyRateLimiter] Connection Failed: %s", e)
          self.rpc_errors.inc()
          raise
        self.rpc_retries.inc()
        _LOGGER.warning(
            "[EnvoyRateLimiter] Connection Failed, retrying: %s", e)
        time.sleep(_RETRY_DELAY_SECONDS)
      else:
        # Latency is recorded for successful calls only.
        self.rpc_latency.update(int((time.monotonic() - start_time) * 1000))
        return response
    # Only reachable if _MAX_CONNECTION_RETRIES is configured as <= 0.
    raise RuntimeError("_MAX_CONNECTION_RETRIES must be positive")

  @staticmethod
  def _max_reset_seconds(response: RateLimitResponse) -> float:
    """Returns the longest reset duration among the OVER_LIMIT statuses.

    Multiple rules can be set in the RLS config, so the maximum is used.
    Returns 0.0 when no OVER_LIMIT status carries a duration.
    """
    longest = 0.0
    for status in response.statuses or ():
      if status.code != RateLimitResponseCode.OVER_LIMIT:
        continue
      # duration_until_reset is converted to timedelta by betterproto and has
      # microsecond precision.
      dur = status.duration_until_reset
      # NOTE(review): some betterproto versions may leave unset message fields
      # as None - guard defensively.
      if dur is None:
        continue
      longest = max(longest, dur.total_seconds())
    return longest

  def throttle(self, hits_added: int = 1) -> bool:
    """Calls the Envoy RLS to check for rate limits.

    Args:
      hits_added: Number of hits to add to the rate limit.

    Returns:
      bool: True if the request is allowed. False if retries were exceeded
        (only when block_until_allowed is False) or the RLS returned an
        unknown response code.

    Raises:
      grpc.RpcError: If the RLS cannot be reached after
        _MAX_CONNECTION_RETRIES attempts.
    """
    self.init_connection()

    # Convert descriptors to proto format.
    proto_descriptors = [
        RateLimitDescriptor(
            entries=[
                RateLimitDescriptorEntry(key=k, value=v) for k, v in d.items()
            ]) for d in self.descriptors
    ]
    request = RateLimitRequest(
        domain=self.domain,
        descriptors=proto_descriptors,
        hits_addend=hits_added)

    self.requests_counter.inc()
    attempt = 0
    allowed = False
    # In blocking mode loop until a decision is reached; otherwise allow one
    # initial attempt plus `retries` retries.
    while self.block_until_allowed or attempt <= self.retries:
      response = self._should_rate_limit_with_retries(request)

      if response.overall_code == RateLimitResponseCode.OK:
        self.requests_allowed.inc()
        allowed = True
        break

      if response.overall_code != RateLimitResponseCode.OVER_LIMIT:
        _LOGGER.error(
            "[EnvoyRateLimiter] Unknown code from RLS: %s",
            response.overall_code)
        break

      # Ratelimit exceeded, sleep for duration until reset and retry.
      self.requests_throttled.inc()
      sleep_s = self._max_reset_seconds(response)
      if sleep_s <= 0:
        # No usable reset hint from the RLS: back off instead of spinning.
        sleep_s = _DEFAULT_THROTTLE_DELAY_SECONDS

      # Add 1% additive jitter to prevent thundering herd.
      sleep_s += random.uniform(0, 0.01 * sleep_s)

      _LOGGER.warning("[EnvoyRateLimiter] Throttled for %s seconds", sleep_s)
      # signal throttled time to backend
      self.throttling_signaler.signal_throttled(int(sleep_s))
      time.sleep(sleep_s)
      attempt += 1
    return allowed

# diff --git a/sdks/python/apache_beam/io/components/rate_limiter_test.py b/sdks/python/apache_beam/io/components/rate_limiter_test.py
# new file mode 100644
# index 000000000000..e5c0ef248318
# --- /dev/null
# +++ b/sdks/python/apache_beam/io/components/rate_limiter_test.py
# @@ -0,0 +1,141 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from datetime import timedelta
from unittest import mock

import grpc
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseDescriptorStatus

from apache_beam.io.components import rate_limiter


class EnvoyRateLimiterTest(unittest.TestCase):
  """Unit tests for EnvoyRateLimiter.throttle() using a mocked gRPC stub."""
  def setUp(self):
    self.service_address = 'localhost:8081'
    self.domain = 'test_domain'
    self.descriptors = [{'key': 'value'}]
    # Non-blocking limiter with a short timeout so the tests stay fast.
    self.limiter = rate_limiter.EnvoyRateLimiter(
        self.service_address,
        self.domain,
        self.descriptors,
        timeout=0.1,
        block_until_allowed=False,
        retries=2,
        namespace='test_namespace')

  def _inject_stub(self, **should_rate_limit_config):
    """Installs a mock stub whose ShouldRateLimit is configured as given."""
    stub = mock.Mock()
    stub.ShouldRateLimit.configure_mock(**should_rate_limit_config)
    self.limiter._stub = stub
    return stub

  @mock.patch('grpc.insecure_channel')
  def test_throttle_allowed(self, unused_channel):
    # The RLS answers OK on the first call.
    ok_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
    stub = self._inject_stub(return_value=ok_response)

    result = self.limiter.throttle()

    self.assertTrue(result)
    stub.ShouldRateLimit.assert_called_once()

  @mock.patch('grpc.insecure_channel')
  def test_throttle_over_limit_retries_exceeded(self, unused_channel):
    # The RLS always answers OVER_LIMIT.
    over_limit = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT)
    stub = self._inject_stub(return_value=over_limit)

    # block_until_allowed is False, so throttle() eventually gives up and
    # returns False. time.sleep is patched so the test runs fast.
    with mock.patch('time.sleep'):
      result = self.limiter.throttle()

    self.assertFalse(result)
    # attempt starts at 0 and retries is 2:
    #   attempt 0 -> OVER_LIMIT, sleep, attempt becomes 1
    #   attempt 1 -> OVER_LIMIT, sleep, attempt becomes 2
    #   attempt 2 -> OVER_LIMIT, sleep, attempt becomes 3
    #   attempt 3 > retries -> loop exits without another call
    # i.e. one initial call plus two retries.
    self.assertEqual(stub.ShouldRateLimit.call_count, 3)

  @mock.patch('grpc.insecure_channel')
  def test_throttle_rpc_error_retry(self, unused_channel):
    # Two RpcErrors followed by a successful OK response.
    ok_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
    rpc_error = grpc.RpcError()
    stub = self._inject_stub(side_effect=[rpc_error, rpc_error, ok_response])

    with mock.patch('time.sleep'):
      result = self.limiter.throttle()

    self.assertTrue(result)
    self.assertEqual(stub.ShouldRateLimit.call_count, 3)

  @mock.patch('grpc.insecure_channel')
  def test_throttle_rpc_error_fail(self, unused_channel):
    # Every call fails with an RpcError.
    stub = self._inject_stub(side_effect=grpc.RpcError())

    with mock.patch('time.sleep'), self.assertRaises(grpc.RpcError):
      self.limiter.throttle()

    # The connection retry loop makes 5 attempts before re-raising.
    self.assertEqual(stub.ShouldRateLimit.call_count, 5)

  @mock.patch('grpc.insecure_channel')
  @mock.patch('random.uniform', return_value=0.0)
  def test_extract_duration_from_response(self, unused_random, unused_channel):
    # OVER_LIMIT response whose only status resets in 5 seconds.
    five_second_status = RateLimitResponseDescriptorStatus(
        code=RateLimitResponseCode.OVER_LIMIT,
        duration_until_reset=timedelta(seconds=5))
    over_limit = RateLimitResponse(
        overall_code=RateLimitResponseCode.OVER_LIMIT,
        statuses=[five_second_status])

    self._inject_stub(return_value=over_limit)
    self.limiter.retries = 0  # Single attempt

    with mock.patch('time.sleep') as mock_sleep:
      self.limiter.throttle()
      # Jitter is patched to 0.0, so the sleep is exactly the reset duration.
      mock_sleep.assert_called_with(5.0)


if __name__ == '__main__':
  unittest.main()