11from contextlib import AbstractContextManager
22from typing import Any , Callable , Generic , Type , TypeVar
33
4+ from loguru import logger
5+ from sqlalchemy import update
6+ from sqlalchemy .future import select
47from sqlalchemy .orm import Session , joinedload
58
69from app .exceptions .database import EntityNotFound
@@ -18,49 +21,59 @@ def __init__(
1821 self .session = session
1922 self .model = model
2023
21- def create (self , model : T ) -> T :
22- with self .session () as session :
24+ async def create (self , model : T ) -> T :
25+ async with self .session () as session :
2326 session .add (model )
24- session .commit ()
25- session .refresh (model )
27+ await session .commit ()
28+ await session .refresh (model )
2629 return model
2730
28- def read_by_id (self , id : int , eager : bool = False ) -> T :
29- with self .session () as session :
30- query = session . query (self .model )
31+ async def read_by_id (self , id : int , eager : bool = False ) -> T :
32+ async with self .session () as session :
33+ query = select (self .model )
3134 if eager :
3235 for _eager in getattr (self .model , "eagers" ):
3336 query = query .options (joinedload (getattr (self .model , _eager )))
34- result = query .filter (self .model .id == id ).first ()
37+ query = query .filter (self .model .id == id )
38+ result = await session .execute (query )
39+ result = result .scalar_one_or_none ()
3540 if not result :
3641 raise EntityNotFound
3742 return result
3843
39- def update_by_id (self , id : int , model : dict ) -> T :
40- with self .session () as session :
41- session . query (self .model ).filter (self .model .id == id ). update ( model )
42- session .commit ( )
43- result = session . query ( self . model ). filter ( self . model . id == id ). first ()
44+ async def update_by_id (self , id : int , model : dict ) -> T :
45+ async with self .session () as session :
46+ query = select (self .model ).filter (self .model .id == id )
47+ result = await session .execute ( query )
48+ result = result . scalar_one_or_none ()
4449 if not result :
4550 raise EntityNotFound
51+ logger .warning (result .updated_at )
52+ for key , value in model .items ():
53+ setattr (result , key , value )
54+ await session .commit ()
55+ await session .refresh (result )
56+ logger .warning (result .updated_at )
4657 return result
4758
48- def update_attr_by_id (self , id : int , column : str , value : Any ) -> T :
49- with self .session () as session :
50- session .query (self .model ).filter (self .model .id == id ).update (
51- {column : value }
52- )
53- session .commit ()
54- result = session .query (self .model ).filter (self .model .id == id ).first ()
59+ async def update_attr_by_id (self , id : int , column : str , value : Any ) -> T :
60+ async with self .session () as session :
61+ query = select (self .model ).filter (self .model .id == id )
62+ result = await session .execute (query )
63+ result = result .scalar_one_or_none ()
5564 if not result :
5665 raise EntityNotFound
66+ setattr (result , column , value )
67+ await session .commit ()
5768 return result
5869
59- def delete_by_id (self , id : int ) -> T :
60- with self .session () as session :
61- query = session .query (self .model ).filter (self .model .id == id ).first ()
62- if not query :
70+ async def delete_by_id (self , id : int ) -> T :
71+ async with self .session () as session :
72+ query = select (self .model ).filter (self .model .id == id )
73+ result = await session .execute (query )
74+ result = result .scalar_one_or_none ()
75+ if not result :
6376 raise EntityNotFound
64- session .delete (query )
65- session .commit ()
66- return query
77+ await session .delete (result )
78+ await session .commit ()
79+ return result
0 commit comments