From 3d9175e38d539a2e1ff7070205c1d316f5c44593 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 11 Dec 2025 21:14:19 +0000 Subject: [PATCH 1/3] feat(amgi-aiokafka): add message send manager so it can be used by other servers --- .../src/amgi_aiokafka/__init__.py | 100 ++++++++++++------ 1 file changed, 69 insertions(+), 31 deletions(-) diff --git a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py index b419ef5..6a7984a 100644 --- a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py +++ b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py @@ -1,10 +1,13 @@ import logging +import sys from asyncio import Lock from collections import deque from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Iterable +from types import TracebackType from typing import Any +from typing import AsyncContextManager from aiokafka import AIOKafkaConsumer from aiokafka import AIOKafkaProducer @@ -19,6 +22,14 @@ from amgi_types import MessageScope from amgi_types import MessageSendEvent +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] + logger = logging.getLogger("amgi-aiokafka.error") @@ -28,9 +39,14 @@ def run( *topics: Iterable[str], bootstrap_servers: str | list[str] = "localhost", group_id: str | None = None, + message_send: _MessageSendManagerT | None = None, ) -> None: server = Server( - app, *topics, bootstrap_servers=bootstrap_servers, group_id=group_id + app, + *topics, + bootstrap_servers=bootstrap_servers, + group_id=group_id, + message_send=message_send, ) server_serve(server) @@ -75,8 +91,8 @@ def __init__( self, consumer: AIOKafkaConsumer, message_receive_ids: dict[str, dict[TopicPartition, int]], - message_send: Callable[[MessageSendEvent], Awaitable[None]], ackable_consumer: bool, + message_send: _MessageSendT, ) -> None: self._consumer = consumer self._message_send = message_send @@ -91,6 +107,48 @@ async def __call__(self, event: AMGISendEvent) -> None: await self._message_send(event) +class MessageSend: + def __init__(self, bootstrap_servers: str | list[str]) -> None: + self._bootstrap_servers = bootstrap_servers + self._producer = None + self._producer_lock = Lock() + + async def __aenter__(self) -> Self: + return self + + async def __call__(self, event: MessageSendEvent) -> None: + producer = await self._get_producer() + encoded_headers = [(key.decode(), value) for key, value in event["headers"]] + + key = event.get("bindings", {}).get("kafka", {}).get("key") + await producer.send( + event["address"], + headers=encoded_headers, + value=event.get("payload"), + key=key, + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._producer is not None: + await self._producer.stop() + + async def _get_producer(self) -> AIOKafkaProducer: + if self._producer is None: + async with self._producer_lock: + if self._producer is None: + producer = AIOKafkaProducer( + bootstrap_servers=self._bootstrap_servers + ) + await producer.start() + self._producer = producer + return self._producer + + class Server: _consumer: AIOKafkaConsumer @@ -100,14 +158,15 @@ def __init__( *topics: Iterable[str], bootstrap_servers: str | list[str], group_id: str | None, + message_send: _MessageSendManagerT | None = None, ) -> None: self._app = app self._topics = topics self._bootstrap_servers = bootstrap_servers self._group_id = group_id + self._message_send = message_send or MessageSend(bootstrap_servers) + self._ackable_consumer = self._group_id is not None - self._producer: AIOKafkaProducer | None = None - self._producer_lock = Lock() self._stoppable = Stoppable() async def serve(self) -> None: @@ -117,14 +176,13 @@ async def serve(self) -> None: group_id=self._group_id, enable_auto_commit=False, ) - async with self._consumer: + async with self._consumer, self._message_send as message_send: async with Lifespan(self._app) as state: - await self._main_loop(state) + await self._main_loop(state, message_send) - if self._producer is not None: - await self._producer.stop() - - async def _main_loop(self, state: dict[str, Any]) -> None: + async def _main_loop( + self, state: dict[str, Any], message_send: _MessageSendT + ) -> None: async for messages in self._stoppable.call( self._consumer.getmany, timeout_ms=1000 ): @@ -153,30 +211,10 @@ async def _main_loop(self, state: dict[str, Any]) -> None: _Send( self._consumer, message_receive_ids, - self._message_send, self._ackable_consumer, + message_send, ), ) - async def _get_producer(self) -> AIOKafkaProducer: - async with self._producer_lock: - if self._producer is None: - producer = AIOKafkaProducer(bootstrap_servers=self._bootstrap_servers) - await producer.start() - self._producer = producer - return self._producer - - async def _message_send(self, event: MessageSendEvent) -> None: - producer = await self._get_producer() - encoded_headers = [(key.decode(), value) for key, value in event["headers"]] - - key = event.get("bindings", {}).get("kafka", {}).get("key") - await producer.send( - event["address"], - headers=encoded_headers, - value=event.get("payload"), - key=key, - ) - def stop(self) -> None: self._stoppable.stop() From 5272ec64b8dec15122a54f43d22d537ae3dd25fa Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 11 Dec 2025 21:16:15 +0000 Subject: [PATCH 2/3] feat(amgi-redis): add message send manager so it can be used by other servers --- .../amgi-redis/src/amgi_redis/__init__.py | 79 +++++++++++++++---- 1 file changed, 64 insertions(+), 15 deletions(-) diff --git a/packages/amgi-redis/src/amgi_redis/__init__.py b/packages/amgi-redis/src/amgi_redis/__init__.py index 8b96949..37268a9 100644 --- a/packages/amgi-redis/src/amgi_redis/__init__.py +++ b/packages/amgi-redis/src/amgi_redis/__init__.py @@ -1,6 +1,11 @@ import asyncio +import sys from asyncio import Task +from collections.abc import Awaitable +from collections.abc import Callable +from types import TracebackType from typing import Any +from typing import AsyncContextManager from amgi_common import Lifespan from amgi_common import server_serve @@ -9,13 +14,26 @@ from amgi_types import AMGISendEvent from amgi_types import MessageReceiveEvent from amgi_types import MessageScope +from amgi_types import MessageSendEvent from redis.asyncio import from_url from redis.asyncio.client import PubSub -from redis.asyncio.client import Redis +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self -def run(app: AMGIApplication, *channels: str, url: str = "redis://localhost") -> None: - server = Server(app, *channels, url=url) +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] + + +def run( + app: AMGIApplication, + *channels: str, + url: str = "redis://localhost", + message_send: _MessageSendManagerT | None = None, +) -> None: + server = Server(app, *channels, url=url, message_send=message_send) server_serve(server) @@ -40,45 +58,76 @@ async def __call__(self) -> MessageReceiveEvent: class _Send: - def __init__(self, redis: Redis) -> None: - self._redis = redis + def __init__(self, message_send: _MessageSendT) -> None: + self._message_send = message_send + + async def __call__(self, event: AMGISendEvent) -> None: + if event["type"] == "message.send": + await self._message_send(event) + - async def __call__(self, message: AMGISendEvent) -> None: - if message["type"] == "message.send": - await self._redis.publish(message["address"], message["payload"]) +class MessageSend: + def __init__(self, url: str) -> None: + self._redis = from_url(url) + + async def __aenter__(self) -> Self: + return self + + async def __call__(self, event: MessageSendEvent) -> None: + await self._redis.publish(event["address"], event["payload"]) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._redis.aclose() class Server: - def __init__(self, app: AMGIApplication, *channels: str, url: str): + def __init__( + self, + app: AMGIApplication, + *channels: str, + url: str, + message_send: _MessageSendManagerT | None = None, + ) -> None: self._app = app self._channels = channels self._url = url + self._message_send = message_send or MessageSend(url) self._stoppable = Stoppable() self._tasks = set[Task[None]]() async def serve(self) -> None: redis = from_url(self._url) - async with redis.pubsub() as pubsub: + async with redis.pubsub() as pubsub, self._message_send as message_send: await pubsub.subscribe(*self._channels) async with Lifespan(self._app) as state: - await self._main_loop(redis, pubsub, state) + await self._main_loop(message_send, pubsub, state) await asyncio.gather(*self._tasks, return_exceptions=True) async def _main_loop( - self, redis: Redis, pubsub: PubSub, state: dict[str, Any] + self, message_send: _MessageSendT, pubsub: PubSub, state: dict[str, Any] ) -> None: loop = asyncio.get_event_loop() async for message in self._stoppable.call( pubsub.get_message, ignore_subscribe_messages=True, timeout=None ): if message is not None: - task = loop.create_task(self._handle_message(message, redis, state)) + task = loop.create_task( + self._handle_message(message, message_send, state) + ) self._tasks.add(task) task.add_done_callback(self._tasks.discard) async def _handle_message( - self, message: dict[str, Any], redis: Redis, state: dict[str, Any] + self, + message: dict[str, Any], + message_send: _MessageSendT, + state: dict[str, Any], ) -> None: scope: MessageScope = { "type": "message", @@ -86,7 +135,7 @@ async def _handle_message( "address": message["channel"].decode(), "state": state.copy(), } - await self._app(scope, _Receive(message), _Send(redis)) + await self._app(scope, _Receive(message), _Send(message_send)) def stop(self) -> None: self._stoppable.stop() From 2e6f3aa0224b55dfd58e3e059b2eacd4b90612c9 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Mon, 15 Dec 2025 18:55:25 +0000 Subject: [PATCH 3/3] feat(amgi-aiobotocore): add message send manager so it can be used by other servers --- .../src/amgi_aiobotocore/sqs.py | 119 +++++++++++++----- 1 file changed, 89 insertions(+), 30 deletions(-) diff --git a/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py b/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py index 2d47a9d..2ab92e8 100644 --- a/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py +++ b/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py @@ -1,9 +1,14 @@ import asyncio +import sys from collections import deque +from collections.abc import Awaitable +from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterable from collections.abc import Sequence +from types import TracebackType from typing import Any +from typing import AsyncContextManager from aiobotocore.session import get_session from amgi_common import Lifespan @@ -15,6 +20,16 @@ from amgi_types import AMGISendEvent from amgi_types import MessageReceiveEvent from amgi_types import MessageScope +from amgi_types import MessageSendEvent + + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] def run( @@ -24,6 +39,7 @@ def run( endpoint_url: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, + message_send: _MessageSendManagerT | None = None, ) -> None: server = Server( app, @@ -32,6 +48,7 @@ def run( endpoint_url=endpoint_url, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + message_send=message_send, ) server_serve(server) @@ -84,20 +101,23 @@ async def __call__(self) -> MessageReceiveEvent: } +async def _get_queue_url(client: Any, queue_name: str) -> str: + queue_url_response = await client.get_queue_url(QueueName=queue_name) + queue_url = queue_url_response["QueueUrl"] + assert isinstance(queue_url, str) + return queue_url + + class _QueueUrlCache: def __init__(self, client: Any) -> None: self._client = client - self._operation_cacher = OperationCacher(self._get_queue_url) + self._operation_cacher = OperationCacher[str, str]( + lambda queue_name: _get_queue_url(client, queue_name) + ) async def get_queue_url(self, queue_name: str) -> str: return await self._operation_cacher.get(queue_name) - async def _get_queue_url(self, queue_name: str) -> str: - queue_url_response = await self._client.get_queue_url(QueueName=queue_name) - queue_url = queue_url_response["QueueUrl"] - assert isinstance(queue_url, str) - return queue_url - class SqsBatchFailureError(IOError): def __init__(self, sender_fault: bool, code: str, message: str): @@ -202,18 +222,56 @@ async def send_message( await self._operation_batcher.enqueue((queue_url, payload, headers)) +class MessageSend: + def __init__( + self, + region_name: str | None = None, + endpoint_url: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + ) -> None: + session = get_session() + + self._client_context = session.create_client( + "sqs", + region_name=region_name, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + + async def __aenter__(self) -> Self: + self._client = await self._client_context.__aenter__() + self._send_batcher = _SendBatcher(self._client) + self._queue_url_cache = _QueueUrlCache(self._client) + + return self + + async def __call__(self, event: MessageSendEvent) -> None: + queue_url = await self._queue_url_cache.get_queue_url(event["address"]) + await self._send_batcher.send_message( + queue_url, event["payload"], event["headers"] + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._client_context.__aexit__(exc_type, exc_val, exc_tb) + + class _Send: def __init__( self, - send_batcher: _SendBatcher, delete_batcher: _DeleteBatcher, - queue_url_cache: _QueueUrlCache, queue_url: str, + message_send: _MessageSendT, ) -> None: self._queue_url = queue_url - self._queue_url_cache = queue_url_cache self._delete_batcher = delete_batcher - self._send_batcher = send_batcher + self._message_send = message_send async def __call__(self, event: AMGISendEvent) -> None: if event["type"] == "message.ack": @@ -222,10 +280,7 @@ async def __call__(self, event: AMGISendEvent) -> None: event["id"], ) if event["type"] == "message.send": - queue_url = await self._queue_url_cache.get_queue_url(event["address"]) - await self._send_batcher.send_message( - queue_url, event["payload"], event["headers"] - ) + await self._message_send(event) class Server: @@ -237,6 +292,7 @@ def __init__( endpoint_url: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, + message_send: _MessageSendManagerT | None = None, ) -> None: self._app = app self._queues = queues @@ -244,25 +300,30 @@ def __init__( self._endpoint_url = endpoint_url self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key + self._message_send = message_send or MessageSend( + region_name, endpoint_url, aws_access_key_id, aws_secret_access_key + ) + self._stoppable = Stoppable() async def serve(self) -> None: session = get_session() - async with session.create_client( - "sqs", - region_name=self._region_name, - endpoint_url=self._endpoint_url, - aws_access_key_id=self._aws_access_key_id, - aws_secret_access_key=self._aws_secret_access_key, - ) as client: - queue_url_cache = _QueueUrlCache(client) + async with ( + session.create_client( + "sqs", + region_name=self._region_name, + endpoint_url=self._endpoint_url, + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + ) as client, + self._message_send as message_send, + ): delete_batcher = _DeleteBatcher(client) - send_batcher = _SendBatcher(client) queue_urls = zip( await asyncio.gather( - *(queue_url_cache.get_queue_url(queue) for queue in self._queues) + *(_get_queue_url(client, queue) for queue in self._queues) ), self._queues, ) @@ -273,9 +334,8 @@ async def serve(self) -> None: client, queue_url, queue, - queue_url_cache, delete_batcher, - send_batcher, + message_send, state, ) for queue_url, queue in queue_urls @@ -287,9 +347,8 @@ async def _queue_loop( client: Any, queue_url: str, queue_name: str, - queue_url_cache: _QueueUrlCache, delete_batcher: _DeleteBatcher, - send_batcher: _SendBatcher, + message_send: _MessageSendT, state: dict[str, Any], ) -> None: async for messages_response in self._stoppable.call( @@ -309,7 +368,7 @@ async def _queue_loop( await self._app( scope, _Receive(messages), - _Send(send_batcher, delete_batcher, queue_url_cache, queue_url), + _Send(delete_batcher, queue_url, message_send), ) def stop(self) -> None: