diff --git a/api/kiwi_vpn_api/db.py b/api/kiwi_vpn_api/db.py index d44743e..64196d1 100644 --- a/api/kiwi_vpn_api/db.py +++ b/api/kiwi_vpn_api/db.py @@ -8,15 +8,13 @@ from peewee import (BooleanField, CharField, DatabaseProxy, DateTimeField, from .config import CRYPT_CONTEXT -DB = DatabaseProxy() - -class BaseModel(Model): +class ORMBaseModel(Model): class Meta: - database = DB + database = DatabaseProxy() -class User(BaseModel): +class User(ORMBaseModel): name = CharField(unique=True) password = CharField() @@ -36,12 +34,12 @@ class User(BaseModel): return False -class UserCapability(BaseModel): +class UserCapability(ORMBaseModel): user = ForeignKeyField(User, backref="capabilities") capability = CharField() -class DistinguishedName(BaseModel): +class DistinguishedName(ORMBaseModel): cn_only = BooleanField(default=True) common_name = CharField() email = CharField() @@ -59,7 +57,7 @@ class DistinguishedName(BaseModel): ) -class Certificate(BaseModel): +class Certificate(ORMBaseModel): owner = ForeignKeyField(User, backref="certs") distinguished_name = ForeignKeyField(DistinguishedName) expiry = DateTimeField(default=datetime.datetime.now) diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 6effa63..446f558 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -5,7 +5,7 @@ from fastapi import Depends, FastAPI from peewee import Database from .config import PRODUCTION_MODE -from .db import DB +from .db import ORMBaseModel from .routers import auth, install, user api = FastAPI( @@ -32,7 +32,7 @@ api.include_router(user.router) async def api_startup( db: Database = Depends(install.connect_db) ): - DB.initialize(db) + ORMBaseModel._meta.database.initialize(db) app = FastAPI() app.mount("/api", api) diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py index 004c7b1..5556415 100644 --- a/api/kiwi_vpn_api/routers/install.py +++ b/api/kiwi_vpn_api/routers/install.py @@ -6,7 +6,8 @@ from fastapi import APIRouter, Depends, HTTPException, status from peewee import Database from ..config import BaseConfig -from ..db import DB, Certificate, DistinguishedName, User, UserCapability +from ..db import (Certificate, DistinguishedName, ORMBaseModel, User, + UserCapability) router = APIRouter(prefix="/install") @@ -79,7 +80,7 @@ async def set_config( if config.jwt.secret is None: config.jwt.secret = token_hex(32) - DB.initialize(await connect_db(config)) + ORMBaseModel._meta.database.initialize(await connect_db(config)) with open(CONFIG_FILE, "w") as kv: kv.write(config.json(indent=2))