Skip to content

Commit 177edc6

Browse files
authored
Merge pull request #20 from Zerohertz/issue#18/feat/transaction
[Feat] Transaction ๊ตฌ์„ฑ ๋ฐ ํ…Œ์ŠคํŠธ ๊ณ ๋„ํ™”
2 parents cd75b2c + 88acb56 commit 177edc6

File tree

25 files changed

+473
-174
lines changed

25 files changed

+473
-174
lines changed

โ€ŽMakefileโ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ lint:
1616
test:
1717
uv sync --group test
1818
export DESCRIPTION=$$(cat README.md) && \
19-
uv run pytest \
19+
uv run pytest -vv \
2020
--cov=app --cov-branch --cov-report=xml \
2121
--junitxml=junit.xml -o junit_family=legacy
2222

โ€ŽREADME.mdโ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,6 @@
3434
</h4>
3535

3636
- [fastapi/full-stack-fastapi-template/backend](https://github.com/fastapi/full-stack-fastapi-template/tree/master/backend)
37+
- [teamhide/fastapi-boilerplate](https://github.com/teamhide/fastapi-boilerplate)
3738
- [jujumilk3/fastapi-clean-architecture](https://github.com/jujumilk3/fastapi-clean-architecture)
39+
- [UponTheSky/How to implement a transactional decorator in FastAPI + SQLAlchemy - with reviewing other approaches](https://dev.to/uponthesky/python-post-reviewhow-to-implement-a-transactional-decorator-in-fastapi-sqlalchemy-ein)

โ€Žapp/core/configs.pyโ€Ž

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
1+
from enum import Enum
12
from typing import Optional
23

34
from pydantic import computed_field
45
from pydantic_core import MultiHostUrl
56
from pydantic_settings import BaseSettings
67

78

9+
class ENVIRONMENT(Enum):
10+
TEST = "TEST"
11+
DEV = "DEV"
12+
PROD = "PROD"
13+
14+
815
class Configs(BaseSettings):
16+
ENV: ENVIRONMENT
17+
918
# --------- APP SETTINGS --------- #
1019
PROJECT_NAME: str
1120
DESCRIPTION: str
1221
VERSION: str
1322
PREFIX: str
23+
TZ: str = "Asia/Seoul"
1424

1525
# --------- DATABASE SETTINGS --------- #
1626
DB_TYPE: str

โ€Žapp/core/container.pyโ€Ž

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from dependency_injector.containers import DeclarativeContainer, WiringConfiguration
2-
from dependency_injector.providers import Factory, Singleton
2+
from dependency_injector.providers import Factory
33

4-
from app.core.database import Database
54
from app.repositories.users import UserRepository
65
from app.services.users import UserService
76

87

98
class Container(DeclarativeContainer):
109
wiring_config = WiringConfiguration(modules=["app.api.v1.endpoints.users"])
1110

12-
database = Singleton(Database)
13-
14-
user_repository = Factory(UserRepository, session=database.provided.session)
11+
user_repository = Factory(UserRepository)
1512

1613
user_service = Factory(UserService, user_repository=user_repository)

โ€Žapp/core/database.pyโ€Ž

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,87 @@
1-
from contextlib import asynccontextmanager
2-
from typing import AsyncIterator
1+
from contextvars import ContextVar, Token
2+
from functools import wraps
3+
from typing import Awaitable, Callable, Optional
34

4-
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
5+
from loguru import logger
6+
from sqlalchemy import NullPool
7+
from sqlalchemy.ext.asyncio import (
8+
AsyncSession,
9+
async_scoped_session,
10+
async_sessionmaker,
11+
create_async_engine,
12+
)
513

6-
from app.core.configs import configs
14+
from app.core.configs import ENVIRONMENT, configs
715
from app.models.base import BaseModel
816

917

18+
class Context:
19+
def __init__(self) -> None:
20+
self.context: ContextVar[Optional[int]] = ContextVar(
21+
"session_context", default=None
22+
)
23+
24+
def get(self) -> int:
25+
session_id = self.context.get()
26+
if not session_id:
27+
raise ValueError("Currently no session is available.")
28+
return session_id
29+
30+
def set(self, session_id: int) -> Token:
31+
return self.context.set(session_id)
32+
33+
def reset(self, context: Token) -> None:
34+
self.context.reset(context)
35+
36+
1037
class Database:
1138
def __init__(self) -> None:
12-
self.engine = create_async_engine(configs.DATABASE_URI, echo=configs.DB_ECHO)
39+
self.context = Context()
40+
async_engine_kwargs = {
41+
"url": configs.DATABASE_URI,
42+
"echo": configs.DB_ECHO,
43+
}
44+
if configs.ENV == ENVIRONMENT.TEST and configs.DB_DRIVER != "aiosqlite":
45+
# NOTE: PyTest ์‹œ event loop ์ถฉ๋Œ ๋ฐœ์ƒ (related: #19)
46+
logger.warning("Using NullPool for async engine")
47+
async_engine_kwargs["poolclass"] = NullPool # type: ignore[assignment]
48+
self.engine = create_async_engine(**async_engine_kwargs) # type: ignore[arg-type]
1349
self.sessionmaker = async_sessionmaker(
14-
bind=self.engine, class_=AsyncSession, expire_on_commit=False
50+
bind=self.engine,
51+
class_=AsyncSession,
52+
autoflush=False,
53+
autocommit=False,
54+
expire_on_commit=False,
55+
)
56+
self.scoped_session = async_scoped_session(
57+
session_factory=self.sessionmaker,
58+
scopefunc=self.context.get,
1559
)
1660

1761
async def create_all(self) -> None:
62+
logger.warning("Create database")
1863
async with self.engine.begin() as conn:
64+
if configs.ENV == ENVIRONMENT.TEST:
65+
await conn.run_sync(BaseModel.metadata.drop_all)
1966
await conn.run_sync(BaseModel.metadata.create_all)
2067

21-
@asynccontextmanager
22-
async def session(self) -> AsyncIterator[AsyncSession]:
23-
async with self.sessionmaker() as session:
68+
def transactional(self, func: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
69+
@wraps(func)
70+
async def wrapper(*args, **kwargs):
2471
try:
25-
yield session
26-
except Exception:
27-
await session.rollback()
28-
raise
29-
finally:
30-
await session.close()
72+
session = self.scoped_session()
73+
if session.in_transaction():
74+
logger.trace(
75+
f"[Session in transaction]\tID: {database.context.get()}, {self.context=}"
76+
)
77+
return await func(*args, **kwargs)
78+
async with session.begin():
79+
response = await func(*args, **kwargs)
80+
return response
81+
except Exception as error:
82+
raise error
83+
84+
return wrapper
85+
86+
87+
database = Database()

โ€Žapp/core/lifespan.pyโ€Ž

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from fastapi import FastAPI
77
from loguru import logger
88

9-
from app.core.configs import configs
9+
from app.core.configs import ENVIRONMENT, configs
1010
from app.core.container import Container
11+
from app.core.database import database
1112
from app.utils.logging import remove_handler
1213

1314

@@ -16,9 +17,12 @@ async def lifespan(app: FastAPI): # pylint: disable=unused-argument
1617
remove_handler(logging.getLogger("uvicorn.access"))
1718
remove_handler(logging.getLogger("uvicorn.error"))
1819
logger.remove()
20+
level = 0
21+
if configs.ENV == ENVIRONMENT.PROD:
22+
level = 20
1923
logger.add(
2024
sys.stderr,
21-
level=0,
25+
level=level,
2226
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <bg #800a0a>"
2327
+ time.tzname[0]
2428
+ "</bg #800a0a> | <level>{level: <8}</level> | <fg #800a0a>{name}</fg #800a0a>:<fg #800a0a>{function}</fg #800a0a>:<fg #800a0a>{line}</fg #800a0a> - <level>{message}</level>",
@@ -27,10 +31,9 @@ async def lifespan(app: FastAPI): # pylint: disable=unused-argument
2731
# logging.getLogger("uvicorn.access").addHandler(LoguruHandler())
2832
# logging.getLogger("uvicorn.error").addHandler(LoguruHandler())
2933

30-
container = Container()
34+
logger.info(f"{configs.ENV=}")
3135
if configs.DB_TABLE_CREATE:
32-
logger.warning("Create database")
33-
database = container.database()
3436
await database.create_all()
37+
app.container = Container() # type: ignore[attr-defined]
3538

3639
yield

โ€Žapp/core/middlewares.pyโ€Ž

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from loguru import logger
55
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
66

7+
from app.core.database import database
78
from app.utils.logging import ANSI_BG_COLOR, ANSI_STYLE, ansi_format
89

910

@@ -57,3 +58,18 @@ async def dispatch(
5758
f"[IP: {ip}] [URL: {url}] [Method: {method}] [Status: {status} (Elapsed Time: {elapsed_time})]"
5859
)
5960
return response
61+
62+
63+
class SessionMiddleware(BaseHTTPMiddleware):
64+
async def dispatch(
65+
self, request: Request, call_next: RequestResponseEndpoint
66+
) -> Response:
67+
try:
68+
context = database.context.set(session_id=hash(request))
69+
logger.trace(f"[Session Start]\tID: {database.context.get()}, {context=}")
70+
response = await call_next(request)
71+
finally:
72+
await database.scoped_session.remove()
73+
logger.trace(f"[Session End]\tID: {database.context.get()}, {context=}")
74+
database.context.reset(context=context)
75+
return response

โ€Žapp/exceptions/base.pyโ€Ž

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,11 @@
44
class CoreException(abc.ABC, Exception):
55
status: int
66
message: str
7+
8+
def __str__(self) -> str:
9+
return (
10+
f"[{self.__class__.__name__}] status={self.status}, message={self.message}"
11+
)
12+
13+
def __repr__(self) -> str:
14+
return f"[{self.__class__.__name__}] {self.message}"

โ€Žapp/exceptions/database.pyโ€Ž

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
from app.exceptions.base import CoreException
44

55

6-
class EntityNotFound(CoreException):
6+
class DatabaseException(CoreException):
7+
status: int
8+
message: str
9+
10+
11+
class EntityAlreadyExists(DatabaseException):
12+
status: int = status.HTTP_409_CONFLICT
13+
message: str = "Entity already exists in the database."
14+
15+
16+
class EntityNotFound(DatabaseException):
717
status: int = status.HTTP_404_NOT_FOUND
818
message: str = "Entity not found in the database."

โ€Žapp/exceptions/handlers.pyโ€Ž

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp
1919
)
2020

2121

22-
async def business_exception_handler(
23-
request: Request, exc: CoreException
22+
async def core_exception_handler(
23+
request: Request, exc: CoreException # pylint: disable=unused-argument
2424
) -> JSONResponse:
25-
logger.error(f"{request=}, {exc=}")
26-
name = exc.__class__.__name__
25+
logger.error(exc)
2726
return JSONResponse(
28-
content=APIResponse.error(
29-
status=exc.status, message=f"[{name}] {exc.message}"
30-
).model_dump(mode="json"),
27+
content=APIResponse.error(status=exc.status, message=repr(exc)).model_dump(
28+
mode="json"
29+
),
3130
status_code=exc.status,
3231
)

0 commit comments

Comments
ย (0)