Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 89 additions & 30 deletions packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import asyncio
import sys
from collections import deque
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from types import TracebackType
from typing import Any
from typing import AsyncContextManager

from aiobotocore.session import get_session
from amgi_common import Lifespan
Expand All @@ -15,6 +20,16 @@
from amgi_types import AMGISendEvent
from amgi_types import MessageReceiveEvent
from amgi_types import MessageScope
from amgi_types import MessageSendEvent


if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]]
_MessageSendManagerT = AsyncContextManager[_MessageSendT]


def run(
Expand All @@ -24,6 +39,7 @@ def run(
endpoint_url: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
message_send: _MessageSendManagerT | None = None,
) -> None:
server = Server(
app,
Expand All @@ -32,6 +48,7 @@ def run(
endpoint_url=endpoint_url,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
message_send=message_send,
)
server_serve(server)

Expand Down Expand Up @@ -84,20 +101,23 @@ async def __call__(self) -> MessageReceiveEvent:
}


async def _get_queue_url(client: Any, queue_name: str) -> str:
queue_url_response = await client.get_queue_url(QueueName=queue_name)
queue_url = queue_url_response["QueueUrl"]
assert isinstance(queue_url, str)
return queue_url


class _QueueUrlCache:
def __init__(self, client: Any) -> None:
self._client = client
self._operation_cacher = OperationCacher(self._get_queue_url)
self._operation_cacher = OperationCacher[str, str](
lambda queue_name: _get_queue_url(client, queue_name)
)

async def get_queue_url(self, queue_name: str) -> str:
return await self._operation_cacher.get(queue_name)

async def _get_queue_url(self, queue_name: str) -> str:
queue_url_response = await self._client.get_queue_url(QueueName=queue_name)
queue_url = queue_url_response["QueueUrl"]
assert isinstance(queue_url, str)
return queue_url


class SqsBatchFailureError(IOError):
def __init__(self, sender_fault: bool, code: str, message: str):
Expand Down Expand Up @@ -202,18 +222,56 @@ async def send_message(
await self._operation_batcher.enqueue((queue_url, payload, headers))


class MessageSend:
def __init__(
self,
region_name: str | None = None,
endpoint_url: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
) -> None:
session = get_session()

self._client_context = session.create_client(
"sqs",
region_name=region_name,
endpoint_url=endpoint_url,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)

async def __aenter__(self) -> Self:
self._client = await self._client_context.__aenter__()
self._send_batcher = _SendBatcher(self._client)
self._queue_url_cache = _QueueUrlCache(self._client)

return self

async def __call__(self, event: MessageSendEvent) -> None:
queue_url = await self._queue_url_cache.get_queue_url(event["address"])
await self._send_batcher.send_message(
queue_url, event["payload"], event["headers"]
)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self._client_context.__aexit__(exc_type, exc_val, exc_tb)


class _Send:
def __init__(
self,
send_batcher: _SendBatcher,
delete_batcher: _DeleteBatcher,
queue_url_cache: _QueueUrlCache,
queue_url: str,
message_send: _MessageSendT,
) -> None:
self._queue_url = queue_url
self._queue_url_cache = queue_url_cache
self._delete_batcher = delete_batcher
self._send_batcher = send_batcher
self._message_send = message_send

async def __call__(self, event: AMGISendEvent) -> None:
if event["type"] == "message.ack":
Expand All @@ -222,10 +280,7 @@ async def __call__(self, event: AMGISendEvent) -> None:
event["id"],
)
if event["type"] == "message.send":
queue_url = await self._queue_url_cache.get_queue_url(event["address"])
await self._send_batcher.send_message(
queue_url, event["payload"], event["headers"]
)
await self._message_send(event)


class Server:
Expand All @@ -237,32 +292,38 @@ def __init__(
endpoint_url: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
message_send: _MessageSendManagerT | None = None,
) -> None:
self._app = app
self._queues = queues
self._region_name = region_name
self._endpoint_url = endpoint_url
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key
self._message_send = message_send or MessageSend(
region_name, endpoint_url, aws_access_key_id, aws_secret_access_key
)

self._stoppable = Stoppable()

async def serve(self) -> None:
session = get_session()

async with session.create_client(
"sqs",
region_name=self._region_name,
endpoint_url=self._endpoint_url,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
) as client:
queue_url_cache = _QueueUrlCache(client)
async with (
session.create_client(
"sqs",
region_name=self._region_name,
endpoint_url=self._endpoint_url,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
) as client,
self._message_send as message_send,
):
delete_batcher = _DeleteBatcher(client)
send_batcher = _SendBatcher(client)

queue_urls = zip(
await asyncio.gather(
*(queue_url_cache.get_queue_url(queue) for queue in self._queues)
*(_get_queue_url(client, queue) for queue in self._queues)
),
self._queues,
)
Expand All @@ -273,9 +334,8 @@ async def serve(self) -> None:
client,
queue_url,
queue,
queue_url_cache,
delete_batcher,
send_batcher,
message_send,
state,
)
for queue_url, queue in queue_urls
Expand All @@ -287,9 +347,8 @@ async def _queue_loop(
client: Any,
queue_url: str,
queue_name: str,
queue_url_cache: _QueueUrlCache,
delete_batcher: _DeleteBatcher,
send_batcher: _SendBatcher,
message_send: _MessageSendT,
state: dict[str, Any],
) -> None:
async for messages_response in self._stoppable.call(
Expand All @@ -309,7 +368,7 @@ async def _queue_loop(
await self._app(
scope,
_Receive(messages),
_Send(send_batcher, delete_batcher, queue_url_cache, queue_url),
_Send(delete_batcher, queue_url, message_send),
)

def stop(self) -> None:
Expand Down
Loading