kiwi-vpn/api/kiwi_vpn_api/db/connection.py

75 lines
1.6 KiB
Python

"""
Utilities for handling SQLAlchemy database connections.
"""
from typing import AsyncGenerator, Generator

from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker

from .models import ORMBaseModel
class SessionManager:
    """
    Context manager wrapping a single ORM session.

    ``__enter__`` hands back the wrapped session unchanged. ``__exit__``
    closes it when the ``with`` block ends and never suppresses exceptions.
    """

    __session: Session

    def __init__(self, session: Session) -> None:
        # Keep the session name-mangled; callers get it via ``with``.
        self.__session = session

    def __enter__(self) -> Session:
        wrapped = self.__session
        return wrapped

    def __exit__(self, *exc_info) -> None:
        # Returning None (falsy) lets any in-flight exception propagate.
        self.__session.close()
class Connection:
    """
    Namespace for the database connection.

    All state lives in class attributes and all methods are classmethods,
    so the class itself is used as the namespace. ``connect`` must run
    before ``use``/``get`` return real sessions.
    """

    # Engine passed to the most recent successful ``connect`` call.
    engine: Engine | None = None
    # Session factory bound to ``engine``; None until ``connect`` succeeds.
    session_local: sessionmaker | None = None

    @classmethod
    def connect(cls, engine: Engine) -> None:
        """
        Connect ORM to a database engine.

        Runs ``create_all`` for ``ORMBaseModel.metadata`` on the engine, then
        stores the engine and a session factory bound to it.

        :param engine: SQLAlchemy engine that sessions will be bound to.
        """
        # Create the schema first: if this raises, ``engine`` and
        # ``session_local`` keep their previous values instead of advertising
        # a half-initialised connection.
        ORMBaseModel.metadata.create_all(bind=engine)

        cls.engine = engine
        cls.session_local = sessionmaker(
            autocommit=False, autoflush=False, bind=engine,
        )

    @classmethod
    def use(cls) -> SessionManager | None:
        """
        Create an ORM session using a context manager.

        :returns: A ``SessionManager`` wrapping a fresh session, or None if
            ``connect`` has not been called yet.
        """
        if cls.session_local is None:
            return None

        return SessionManager(cls.session_local())

    @classmethod
    async def get(cls) -> AsyncGenerator[Session | None, None]:
        """
        Create an ORM session using a FastAPI compatible async generator.

        Yields None if ``connect`` has not been called yet. Otherwise yields
        a fresh session and closes it when the generator is finalised.
        """
        if cls.session_local is None:
            yield None
            return

        db = cls.session_local()
        try:
            yield db
        finally:
            db.close()