|
2 | 2 | from collections.abc import AsyncGenerator |
3 | 3 | from datetime import datetime, timedelta, timezone |
4 | 4 |
|
| 5 | +import aioboto3 |
5 | 6 | import pytest |
6 | 7 | import pytest_asyncio |
| 8 | +from moto import mock_aws |
7 | 9 | from pydantic import UUID4 |
8 | | -from sqlalchemy import exc |
9 | | -from sqlalchemy.ext.asyncio import ( |
10 | | - AsyncEngine, |
11 | | - AsyncSession, |
12 | | - async_sessionmaker, |
13 | | - create_async_engine, |
14 | | -) |
15 | | -from sqlalchemy.orm import DeclarativeBase |
16 | 10 |
|
17 | | -from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID |
18 | | -from fastapi_users_db_sqlalchemy.access_token import ( |
19 | | - SQLAlchemyAccessTokenDatabase, |
20 | | - SQLAlchemyBaseAccessTokenTableUUID, |
| 11 | +from fastapi_users_db_dynamodb import DynamoDBBaseUserTableUUID, DynamoDBUserDatabase |
| 12 | +from fastapi_users_db_dynamodb.access_token import ( |
| 13 | + DynamoDBAccessTokenDatabase, |
| 14 | + DynamoDBBaseAccessTokenTableUUID, |
21 | 15 | ) |
22 | | -from tests.conftest import DATABASE_URL |
23 | 16 |
|
24 | 17 |
|
25 | | -class Base(DeclarativeBase): |
| 18 | +class Base: |
26 | 19 | pass |
27 | 20 |
|
28 | 21 |
|
29 | | -class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): |
| 22 | +class AccessToken(DynamoDBBaseAccessTokenTableUUID, Base): |
30 | 23 | pass |
31 | 24 |
|
32 | 25 |
|
33 | | -class User(SQLAlchemyBaseUserTableUUID, Base): |
| 26 | +class User(DynamoDBBaseUserTableUUID, Base): |
34 | 27 | pass |
35 | 28 |
|
36 | 29 |
|
37 | | -def create_async_session_maker(engine: AsyncEngine): |
38 | | - return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) |
39 | | - |
40 | | - |
41 | 30 | @pytest.fixture |
42 | 31 | def user_id() -> UUID4: |
43 | 32 | return uuid.uuid4() |
44 | 33 |
|
45 | 34 |
|
46 | 35 | @pytest_asyncio.fixture |
47 | | -async def sqlalchemy_access_token_db( |
| 36 | +async def dynamodb_access_token_db( |
48 | 37 | user_id: UUID4, |
49 | | -) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]: |
50 | | - engine = create_async_engine(DATABASE_URL) |
51 | | - sessionmaker = create_async_session_maker(engine) |
| 38 | +) -> AsyncGenerator[DynamoDBAccessTokenDatabase[AccessToken]]: |
| 39 | + with mock_aws(): |
| 40 | + session = aioboto3.Session() |
| 41 | + user_table_name = "users_test" |
| 42 | + token_table_name = "access_tokens_test" |
| 43 | + |
| 44 | + user_db = DynamoDBUserDatabase( |
| 45 | + session, DynamoDBBaseUserTableUUID, user_table_name |
| 46 | + ) |
| 47 | + user = await user_db.create( |
| 48 | + { |
| 49 | + "id": user_id, |
| 50 | + "email": "lancelot@camelot.bt", |
| 51 | + "hashed_password": "guinevere", |
| 52 | + } |
| 53 | + ) |
52 | 54 |
|
53 | | - async with engine.begin() as connection: |
54 | | - await connection.run_sync(Base.metadata.create_all) |
| 55 | + token_db = DynamoDBAccessTokenDatabase(session, AccessToken, token_table_name) |
55 | 56 |
|
56 | | - async with sessionmaker() as session: |
57 | | - user = User( |
58 | | - id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" |
59 | | - ) |
60 | | - session.add(user) |
61 | | - await session.commit() |
| 57 | + # Vorherigen Token löschen, falls er existiert |
| 58 | + token_obj = await token_db.get_by_token("TOKEN") |
| 59 | + if token_obj: |
| 60 | + await token_db.delete(token_obj) |
62 | 61 |
|
63 | | - yield SQLAlchemyAccessTokenDatabase(session, AccessToken) |
| 62 | + yield token_db |
64 | 63 |
|
65 | | - async with engine.begin() as connection: |
66 | | - await connection.run_sync(Base.metadata.drop_all) |
| 64 | + token_obj = await token_db.get_by_token("TOKEN") |
| 65 | + if token_obj: |
| 66 | + await token_db.delete(token_obj) |
| 67 | + |
| 68 | + await user_db.delete(user) |
67 | 69 |
|
68 | 70 |
|
69 | 71 | @pytest.mark.asyncio |
70 | 72 | async def test_queries( |
71 | | - sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], |
| 73 | + dynamodb_access_token_db: DynamoDBAccessTokenDatabase[AccessToken], |
72 | 74 | user_id: UUID4, |
73 | 75 | ): |
74 | 76 | access_token_create = {"token": "TOKEN", "user_id": user_id} |
75 | 77 |
|
76 | | - # Create |
77 | | - access_token = await sqlalchemy_access_token_db.create(access_token_create) |
| 78 | + access_token = await dynamodb_access_token_db.create(access_token_create) |
78 | 79 | assert access_token.token == "TOKEN" |
79 | 80 | assert access_token.user_id == user_id |
80 | 81 |
|
81 | | - # Update |
82 | | - update_dict = {"created_at": datetime.now(timezone.utc)} |
83 | | - updated_access_token = await sqlalchemy_access_token_db.update( |
84 | | - access_token, update_dict |
| 82 | + new_time = datetime.now(timezone.utc) |
| 83 | + updated_access_token = await dynamodb_access_token_db.update( |
| 84 | + access_token, {"created_at": new_time} |
85 | 85 | ) |
86 | | - assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ |
87 | | - "created_at" |
88 | | - ].replace(microsecond=0) |
89 | | - |
90 | | - # Get by token |
91 | | - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( |
92 | | - access_token.token |
| 86 | + assert updated_access_token.created_at.replace(microsecond=0) == new_time.replace( |
| 87 | + microsecond=0 |
93 | 88 | ) |
94 | | - assert access_token_by_token is not None |
95 | 89 |
|
96 | | - # Get by token expired |
97 | | - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( |
| 90 | + token_obj = await dynamodb_access_token_db.get_by_token(access_token.token) |
| 91 | + assert token_obj is not None |
| 92 | + |
| 93 | + token_obj = await dynamodb_access_token_db.get_by_token( |
98 | 94 | access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) |
99 | 95 | ) |
100 | | - assert access_token_by_token is None |
| 96 | + assert token_obj is None |
101 | 97 |
|
102 | | - # Get by token not expired |
103 | | - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( |
| 98 | + token_obj = await dynamodb_access_token_db.get_by_token( |
104 | 99 | access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) |
105 | 100 | ) |
106 | | - assert access_token_by_token is not None |
| 101 | + assert token_obj is not None |
107 | 102 |
|
108 | | - # Get by token unknown |
109 | | - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( |
110 | | - "NOT_EXISTING_TOKEN" |
111 | | - ) |
112 | | - assert access_token_by_token is None |
| 103 | + token_obj = await dynamodb_access_token_db.get_by_token("NOT_EXISTING_TOKEN") |
| 104 | + assert token_obj is None |
113 | 105 |
|
114 | | - # Delete token |
115 | | - await sqlalchemy_access_token_db.delete(access_token) |
116 | | - deleted_access_token = await sqlalchemy_access_token_db.get_by_token( |
117 | | - access_token.token |
118 | | - ) |
119 | | - assert deleted_access_token is None |
| 106 | + await dynamodb_access_token_db.delete(access_token) |
| 107 | + deleted_token = await dynamodb_access_token_db.get_by_token(access_token.token) |
| 108 | + assert deleted_token is None |
120 | 109 |
|
121 | 110 |
|
122 | 111 | @pytest.mark.asyncio |
123 | 112 | async def test_insert_existing_token( |
124 | | - sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], |
| 113 | + dynamodb_access_token_db: DynamoDBAccessTokenDatabase[AccessToken], |
125 | 114 | user_id: UUID4, |
126 | 115 | ): |
127 | 116 | access_token_create = {"token": "TOKEN", "user_id": user_id} |
128 | | - await sqlalchemy_access_token_db.create(access_token_create) |
129 | 117 |
|
130 | | - with pytest.raises(exc.IntegrityError): |
131 | | - await sqlalchemy_access_token_db.create(access_token_create) |
| 118 | + await dynamodb_access_token_db.create(access_token_create) |
| 119 | + |
| 120 | + with pytest.raises(Exception): |
| 121 | + await dynamodb_access_token_db.create(access_token_create) |
0 commit comments