runtime database define
This commit is contained in:
parent
c8ede06c26
commit
e9785f0076
4 changed files with 29 additions and 19 deletions
|
@ -8,7 +8,6 @@ from passlib.context import CryptContext
|
||||||
from peewee import Database, SqliteDatabase
|
from peewee import Database, SqliteDatabase
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
DB = SqliteDatabase("tmp/vpn.db")
|
|
||||||
PRODUCTION_MODE = False
|
PRODUCTION_MODE = False
|
||||||
|
|
||||||
# to get a string like this run:
|
# to get a string like this run:
|
||||||
|
|
|
@ -3,10 +3,12 @@ from __future__ import annotations
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from peewee import (BooleanField, CharField, DateTimeField, ForeignKeyField,
|
from peewee import (BooleanField, CharField, DatabaseProxy, DateTimeField,
|
||||||
Model)
|
ForeignKeyField, Model)
|
||||||
|
|
||||||
from .config import CRYPT_CONTEXT, DB
|
from .config import CRYPT_CONTEXT
|
||||||
|
|
||||||
|
DB = DatabaseProxy()
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(Model):
|
class BaseModel(Model):
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
|
from peewee import Database
|
||||||
from kiwi_vpn_api.routers import install
|
|
||||||
|
|
||||||
from .config import PRODUCTION_MODE
|
from .config import PRODUCTION_MODE
|
||||||
from .routers import auth, user
|
from .db import DB
|
||||||
|
from .routers import auth, install, user
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="kiwi-vpn API",
|
title="kiwi-vpn API",
|
||||||
|
@ -27,6 +27,13 @@ api.include_router(install.router)
|
||||||
api.include_router(auth.router)
|
api.include_router(auth.router)
|
||||||
api.include_router(user.router)
|
api.include_router(user.router)
|
||||||
|
|
||||||
|
|
||||||
|
@api.on_event("startup")
|
||||||
|
async def api_startup(
|
||||||
|
db: Database = Depends(install.connect_db)
|
||||||
|
):
|
||||||
|
DB.initialize(db)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.mount("/api", api)
|
app.mount("/api", api)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from peewee import Database
|
from peewee import Database
|
||||||
|
|
||||||
from ..config import BaseConfig
|
from ..config import BaseConfig
|
||||||
from ..db import Certificate, DistinguishedName, User, UserCapability
|
from ..db import DB, Certificate, DistinguishedName, User, UserCapability
|
||||||
|
|
||||||
router = APIRouter(prefix="/install")
|
router = APIRouter(prefix="/install")
|
||||||
|
|
||||||
|
@ -27,6 +27,16 @@ async def load_config() -> BaseConfig:
|
||||||
return BaseConfig()
|
return BaseConfig()
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
|
||||||
|
db = await config.db.database
|
||||||
|
db.connect()
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
async def has_tables(db: Database = Depends(connect_db)) -> bool:
|
||||||
|
return db.table_exists(User)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/config",
|
"/config",
|
||||||
response_model=BaseConfig,
|
response_model=BaseConfig,
|
||||||
|
@ -69,20 +79,12 @@ async def set_config(
|
||||||
if config.jwt.secret is None:
|
if config.jwt.secret is None:
|
||||||
config.jwt.secret = token_hex(32)
|
config.jwt.secret = token_hex(32)
|
||||||
|
|
||||||
|
DB.initialize(await connect_db(config))
|
||||||
|
|
||||||
with open(CONFIG_FILE, "w") as kv:
|
with open(CONFIG_FILE, "w") as kv:
|
||||||
kv.write(config.json(indent=2))
|
kv.write(config.json(indent=2))
|
||||||
|
|
||||||
|
|
||||||
async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
|
|
||||||
db = await config.db.database
|
|
||||||
db.connect()
|
|
||||||
return db
|
|
||||||
|
|
||||||
|
|
||||||
async def has_tables(db: Database = Depends(connect_db)):
|
|
||||||
return db.table_exists(User)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/db", responses={
|
@router.get("/db", responses={
|
||||||
status.HTTP_200_OK: {
|
status.HTTP_200_OK: {
|
||||||
"model": bool,
|
"model": bool,
|
||||||
|
|
Loading…
Reference in a new issue