|
1 | | -from contextlib import asynccontextmanager |
2 | | -from typing import AsyncIterator |
| 1 | +from contextvars import ContextVar, Token |
| 2 | +from functools import wraps |
| 3 | +from typing import Awaitable, Callable, Optional |
3 | 4 |
|
4 | | -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine |
| 5 | +from loguru import logger |
| 6 | +from sqlalchemy import NullPool |
| 7 | +from sqlalchemy.ext.asyncio import ( |
| 8 | + AsyncSession, |
| 9 | + async_scoped_session, |
| 10 | + async_sessionmaker, |
| 11 | + create_async_engine, |
| 12 | +) |
5 | 13 |
|
6 | | -from app.core.configs import configs |
| 14 | +from app.core.configs import ENVIRONMENT, configs |
7 | 15 | from app.models.base import BaseModel |
8 | 16 |
|
9 | 17 |
|
| 18 | +class Context: |
| 19 | + def __init__(self) -> None: |
| 20 | + self.context: ContextVar[Optional[int]] = ContextVar( |
| 21 | + "session_context", default=None |
| 22 | + ) |
| 23 | + |
| 24 | + def get(self) -> int: |
| 25 | + session_id = self.context.get() |
| 26 | + if not session_id: |
| 27 | + raise ValueError("Currently no session is available.") |
| 28 | + return session_id |
| 29 | + |
| 30 | + def set(self, session_id: int) -> Token: |
| 31 | + return self.context.set(session_id) |
| 32 | + |
| 33 | + def reset(self, context: Token) -> None: |
| 34 | + self.context.reset(context) |
| 35 | + |
| 36 | + |
10 | 37 | class Database: |
11 | 38 | def __init__(self) -> None: |
12 | | - self.engine = create_async_engine(configs.DATABASE_URI, echo=configs.DB_ECHO) |
| 39 | + self.context = Context() |
| 40 | + async_engine_kwargs = { |
| 41 | + "url": configs.DATABASE_URI, |
| 42 | + "echo": configs.DB_ECHO, |
| 43 | + } |
| 44 | + if configs.ENV == ENVIRONMENT.TEST and configs.DB_DRIVER != "aiosqlite": |
| 45 | + # NOTE: PyTest ์ event loop ์ถฉ๋ ๋ฐ์ (related: #19) |
| 46 | + logger.warning("Using NullPool for async engine") |
| 47 | + async_engine_kwargs["poolclass"] = NullPool # type: ignore[assignment] |
| 48 | + self.engine = create_async_engine(**async_engine_kwargs) # type: ignore[arg-type] |
13 | 49 | self.sessionmaker = async_sessionmaker( |
14 | | - bind=self.engine, class_=AsyncSession, expire_on_commit=False |
| 50 | + bind=self.engine, |
| 51 | + class_=AsyncSession, |
| 52 | + autoflush=False, |
| 53 | + autocommit=False, |
| 54 | + expire_on_commit=False, |
| 55 | + ) |
| 56 | + self.scoped_session = async_scoped_session( |
| 57 | + session_factory=self.sessionmaker, |
| 58 | + scopefunc=self.context.get, |
15 | 59 | ) |
16 | 60 |
|
17 | 61 | async def create_all(self) -> None: |
| 62 | + logger.warning("Create database") |
18 | 63 | async with self.engine.begin() as conn: |
| 64 | + if configs.ENV == ENVIRONMENT.TEST: |
| 65 | + await conn.run_sync(BaseModel.metadata.drop_all) |
19 | 66 | await conn.run_sync(BaseModel.metadata.create_all) |
20 | 67 |
|
21 | | - @asynccontextmanager |
22 | | - async def session(self) -> AsyncIterator[AsyncSession]: |
23 | | - async with self.sessionmaker() as session: |
| 68 | + def transactional(self, func: Callable[..., Awaitable]) -> Callable[..., Awaitable]: |
| 69 | + @wraps(func) |
| 70 | + async def wrapper(*args, **kwargs): |
24 | 71 | try: |
25 | | - yield session |
26 | | - except Exception: |
27 | | - await session.rollback() |
28 | | - raise |
29 | | - finally: |
30 | | - await session.close() |
| 72 | + session = self.scoped_session() |
| 73 | + if session.in_transaction(): |
| 74 | + logger.trace( |
| 75 | + f"[Session in transaction]\tID: {database.context.get()}, {self.context=}" |
| 76 | + ) |
| 77 | + return await func(*args, **kwargs) |
| 78 | + async with session.begin(): |
| 79 | + response = await func(*args, **kwargs) |
| 80 | + return response |
| 81 | + except Exception as error: |
| 82 | + raise error |
| 83 | + |
| 84 | + return wrapper |
| 85 | + |
| 86 | + |
| 87 | +database = Database() |
0 commit comments