runtime database define

This commit is contained in:
Jörn-Michael Miehe 2022-03-16 14:54:42 +00:00
parent c8ede06c26
commit e9785f0076
4 changed files with 29 additions and 19 deletions

View file

@ -8,7 +8,6 @@ from passlib.context import CryptContext
from peewee import Database, SqliteDatabase from peewee import Database, SqliteDatabase
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
DB = SqliteDatabase("tmp/vpn.db")
PRODUCTION_MODE = False PRODUCTION_MODE = False
# to get a string like this run: # to get a string like this run:

View file

@ -3,10 +3,12 @@ from __future__ import annotations
import datetime import datetime
from typing import Optional from typing import Optional
from peewee import (BooleanField, CharField, DateTimeField, ForeignKeyField, from peewee import (BooleanField, CharField, DatabaseProxy, DateTimeField,
Model) ForeignKeyField, Model)
from .config import CRYPT_CONTEXT, DB from .config import CRYPT_CONTEXT
DB = DatabaseProxy()
class BaseModel(Model): class BaseModel(Model):

View file

@ -1,12 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import Depends, FastAPI
from peewee import Database
from kiwi_vpn_api.routers import install
from .config import PRODUCTION_MODE from .config import PRODUCTION_MODE
from .routers import auth, user from .db import DB
from .routers import auth, install, user
api = FastAPI( api = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
@ -27,6 +27,13 @@ api.include_router(install.router)
api.include_router(auth.router) api.include_router(auth.router)
api.include_router(user.router) api.include_router(user.router)
@api.on_event("startup")
async def api_startup(
db: Database = Depends(install.connect_db)
):
DB.initialize(db)
app = FastAPI() app = FastAPI()
app.mount("/api", api) app.mount("/api", api)

View file

@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from peewee import Database from peewee import Database
from ..config import BaseConfig from ..config import BaseConfig
from ..db import Certificate, DistinguishedName, User, UserCapability from ..db import DB, Certificate, DistinguishedName, User, UserCapability
router = APIRouter(prefix="/install") router = APIRouter(prefix="/install")
@ -27,6 +27,16 @@ async def load_config() -> BaseConfig:
return BaseConfig() return BaseConfig()
async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
db = await config.db.database
db.connect()
return db
async def has_tables(db: Database = Depends(connect_db)) -> bool:
return db.table_exists(User)
@router.get( @router.get(
"/config", "/config",
response_model=BaseConfig, response_model=BaseConfig,
@ -69,20 +79,12 @@ async def set_config(
if config.jwt.secret is None: if config.jwt.secret is None:
config.jwt.secret = token_hex(32) config.jwt.secret = token_hex(32)
DB.initialize(await connect_db(config))
with open(CONFIG_FILE, "w") as kv: with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2)) kv.write(config.json(indent=2))
async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
db = await config.db.database
db.connect()
return db
async def has_tables(db: Database = Depends(connect_db)):
return db.table_exists(User)
@router.get("/db", responses={ @router.get("/db", responses={
status.HTTP_200_OK: { status.HTTP_200_OK: {
"model": bool, "model": bool,