From e9785f00762e5e3ddbd96f908041f4dfac8d36cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn-Michael=20Miehe?= <40151420+ldericher@users.noreply.github.com> Date: Wed, 16 Mar 2022 14:54:42 +0000 Subject: [PATCH] runtime database define --- api/kiwi_vpn_api/config.py | 1 - api/kiwi_vpn_api/db.py | 8 +++++--- api/kiwi_vpn_api/main.py | 15 +++++++++++---- api/kiwi_vpn_api/routers/install.py | 24 +++++++++++++----------- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 1f2b2cc..8fc4da6 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -8,7 +8,6 @@ from passlib.context import CryptContext from peewee import Database, SqliteDatabase from pydantic import BaseModel, Field -DB = SqliteDatabase("tmp/vpn.db") PRODUCTION_MODE = False # to get a string like this run: diff --git a/api/kiwi_vpn_api/db.py b/api/kiwi_vpn_api/db.py index cb912cd..d44743e 100644 --- a/api/kiwi_vpn_api/db.py +++ b/api/kiwi_vpn_api/db.py @@ -3,10 +3,12 @@ from __future__ import annotations import datetime from typing import Optional -from peewee import (BooleanField, CharField, DateTimeField, ForeignKeyField, - Model) +from peewee import (BooleanField, CharField, DatabaseProxy, DateTimeField, + ForeignKeyField, Model) -from .config import CRYPT_CONTEXT, DB +from .config import CRYPT_CONTEXT + +DB = DatabaseProxy() class BaseModel(Model): diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index cd481a2..6effa63 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 import uvicorn -from fastapi import FastAPI - -from kiwi_vpn_api.routers import install +from fastapi import Depends, FastAPI +from peewee import Database from .config import PRODUCTION_MODE -from .routers import auth, user +from .db import DB +from .routers import auth, install, user api = FastAPI( title="kiwi-vpn API", @@ -27,6 +27,13 @@ api.include_router(install.router) api.include_router(auth.router) api.include_router(user.router) + +@api.on_event("startup") +async def api_startup( + db: Database = Depends(install.connect_db) +): + DB.initialize(db) + app = FastAPI() app.mount("/api", api) diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py index f6b9678..004c7b1 100644 --- a/api/kiwi_vpn_api/routers/install.py +++ b/api/kiwi_vpn_api/routers/install.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from peewee import Database from ..config import BaseConfig -from ..db import Certificate, DistinguishedName, User, UserCapability +from ..db import DB, Certificate, DistinguishedName, User, UserCapability router = APIRouter(prefix="/install") @@ -27,6 +27,16 @@ async def load_config() -> BaseConfig: return BaseConfig() +async def connect_db(config: BaseConfig = Depends(load_config)) -> Database: + db = await config.db.database + db.connect() + return db + + +async def has_tables(db: Database = Depends(connect_db)) -> bool: + return db.table_exists(User) + + @router.get( "/config", response_model=BaseConfig, @@ -69,20 +79,12 @@ async def set_config( if config.jwt.secret is None: config.jwt.secret = token_hex(32) + DB.initialize(await connect_db(config)) + with open(CONFIG_FILE, "w") as kv: kv.write(config.json(indent=2)) -async def connect_db(config: BaseConfig = Depends(load_config)) -> Database: - db = await config.db.database - db.connect() - return db - - -async def has_tables(db: Database = Depends(connect_db)): - return db.table_exists(User) - - @router.get("/db", responses={ status.HTTP_200_OK: { "model": bool,