kiwi-vpn/api/kiwi_vpn_api/db/connection.py

52 lines
1.2 KiB
Python
Raw Normal View History

2022-03-18 18:22:17 +00:00
from typing import AsyncGenerator, Generator
2022-03-18 18:24:09 +00:00
2022-03-18 18:22:17 +00:00
from sqlalchemy.engine import Engine
2022-03-18 18:24:09 +00:00
from sqlalchemy.orm import Session, sessionmaker
2022-03-18 18:22:17 +00:00
from .models import ORMBaseModel
2022-03-18 23:04:28 +00:00
2022-03-20 00:12:56 +00:00
class SessionManager:
    """
    Context manager around an existing session.

    Entering the manager hands the wrapped session to the `with` body;
    leaving it closes that session.
    """

    __session: Session

    def __init__(self, session: Session) -> None:
        """Store the already-created `session` to be managed."""
        self.__session = session

    def __enter__(self) -> Session:
        """Return the wrapped session for use inside the `with` body."""
        return self.__session

    def __exit__(self, *exc_info) -> None:
        """
        Close the wrapped session.

        Returns None, so an exception raised in the `with` body is not
        suppressed.
        """
        managed = self.__session
        managed.close()
2022-03-18 23:04:28 +00:00
class Connection:
    """
    Class-level holder for the database engine and its session factory.

    All state is stored on the class itself, so every caller sees whatever
    the most recent `connect` call configured.
    """

    # Both remain None until `connect` has been called.
    engine: Engine | None = None
    session_local: sessionmaker | None = None

    @classmethod
    def connect(cls, engine: Engine) -> None:
        """
        Bind the class to `engine` and set up the session factory.

        Also asks the ORM metadata to create its tables on that engine.
        """
        cls.engine = engine
        cls.session_local = sessionmaker(
            autocommit=False, autoflush=False, bind=engine,
        )
        # Emit table creation for everything registered on ORMBaseModel.
        ORMBaseModel.metadata.create_all(bind=engine)

    @classmethod
    def use(cls) -> SessionManager | None:
        """
        Return a context manager wrapping a freshly created session.

        Returns None when `connect` has not been called yet, so callers
        must check the result before using it in a `with` statement.
        """
        if cls.session_local is None:
            return None
        return SessionManager(cls.session_local())

    @classmethod
    async def get(cls) -> AsyncGenerator[Session | None, None]:
        """
        Asynchronously yield one session and close it afterwards.

        Yields None instead when `connect` has not been called yet.
        This is an async generator (``async def`` + ``yield``), hence the
        `AsyncGenerator` return annotation.
        """
        if cls.session_local is None:
            yield None
        else:
            db = cls.session_local()
            try:
                yield db
            finally:
                # Runs when the consumer finishes or the generator is
                # closed, so the session is always released.
                db.close()