From 00bdf88b6e234fad2fac38896843013b0e008d06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn-Michael=20Miehe?= <40151420+ldericher@users.noreply.github.com> Date: Sun, 20 Mar 2022 00:12:56 +0000 Subject: [PATCH] Connection.use() --- api/kiwi_vpn_api/db/connection.py | 20 ++++++++++++++++++++ api/kiwi_vpn_api/main.py | 2 +- api/kiwi_vpn_api/routers/admin.py | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py index 9f65c7c..cbb5ab5 100644 --- a/api/kiwi_vpn_api/db/connection.py +++ b/api/kiwi_vpn_api/db/connection.py @@ -6,6 +6,19 @@ from sqlalchemy.orm import Session, sessionmaker from .models import ORMBaseModel +class SessionManager: + __session: Session + + def __init__(self, session: Session) -> None: + self.__session = session + + def __enter__(self) -> Session: + return self.__session + + def __exit__(self, *args) -> None: + self.__session.close() + + class Connection: engine: Engine | None = None session_local: sessionmaker | None = None @@ -18,6 +31,13 @@ class Connection: ) ORMBaseModel.metadata.create_all(bind=engine) + @classmethod + def use(cls) -> SessionManager | None: + if cls.session_local is None: + return None + + return SessionManager(cls.session_local()) + @classmethod async def get(cls) -> Generator[Session | None, None, None]: if cls.session_local is None: diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index dc699d1..4d35e9d 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -39,7 +39,7 @@ async def on_startup() -> None: Connection.connect(await current_config.db.db_engine) # some testing - async for db in Connection.get(): + with Connection.use() as db: print(schemas.User.from_db(db, "admin")) print(schemas.User.from_db(db, "nonexistent")) diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index eef10ec..062e7b5 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -25,7 +25,7 @@ async def install( await config.save() Connection.connect(await config.db.db_engine) - async for db in Connection.get(): + with Connection.use() as db: admin_user = schemas.User.create( db=db, user=admin_user,