pydantic settings, improved scoping

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 22:43:02 +00:00
parent d5b39db400
commit 98c89991b4
3 changed files with 49 additions and 28 deletions

View file

@ -1,13 +1,27 @@
from __future__ import annotations
import functools
import json import json
from enum import Enum from enum import Enum
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, Field from pydantic import BaseModel, BaseSettings, Field
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
CONFIG_FILE = "tmp/config.json"
class Settings(BaseSettings):
production_mode: bool = False
config_file: str = "tmp/config.json"
openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs"
redoc_url: str | None = None
@staticmethod
@functools.lru_cache
def get() -> Settings:
return Settings()
class DBType(Enum): class DBType(Enum):
@ -43,11 +57,25 @@ class CryptoConfig(BaseModel):
) )
class BaseConfig(BaseModel): class Config(BaseModel):
db: DBConfig = Field(default_factory=DBConfig) db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@staticmethod
async def get() -> Config | None:
try:
with open(Settings.get().config_file, "r") as config_file:
return Config.parse_obj(json.load(config_file))
except FileNotFoundError:
return None
@staticmethod
def set(config: Config) -> None:
with open(Settings.get().config_file, "w") as config_file:
config_file.write(config.json(indent=2))
@property @property
def crypt_context(self) -> CryptContext: def crypt_context(self) -> CryptContext:
return self.crypto.crypt_context return self.crypto.crypt_context
@ -55,17 +83,3 @@ class BaseConfig(BaseModel):
@property @property
def db_engine(self) -> Engine: def db_engine(self) -> Engine:
return self.db.db_engine return self.db.db_engine
async def get() -> BaseConfig | None:
try:
with open(CONFIG_FILE, "r") as config_file:
return BaseConfig.parse_obj(json.load(config_file))
except FileNotFoundError:
return None
def set(config: BaseConfig) -> None:
with open(CONFIG_FILE, "w") as config_file:
config_file.write(config.json(indent=2))

View file

@ -3,11 +3,11 @@
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from . import config from .config import Config, Settings
from .db import connection from .db import connection, crud
from .routers import admin from .routers import admin
PRODUCTION_MODE = False settings = Settings.get()
api = FastAPI( api = FastAPI(
@ -21,8 +21,9 @@ api = FastAPI(
"name": "MIT License", "name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php", "url": "https://opensource.org/licenses/mit-license.php",
}, },
docs_url="/docs" if not PRODUCTION_MODE else None, openapi_url=settings.openapi_url,
redoc_url="/redoc" if not PRODUCTION_MODE else None, docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None,
) )
app = FastAPI() app = FastAPI()
@ -34,9 +35,15 @@ async def on_startup():
# always include admin router # always include admin router
api.include_router(admin.router) api.include_router(admin.router)
if (current_config := await config.get()) is not None: if (current_config := await Config.get()) is not None:
connection.reconnect(current_config.db_engine) connection.reconnect(current_config.db_engine)
async for db in connection.get():
user = crud.get_user(db, "admin")
print(user.name)
for cap in user.capabilities:
print(cap.capability)
# include other routers # include other routers
# api.include_router(auth.router) # api.include_router(auth.router)

View file

@ -3,7 +3,7 @@ from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from .. import config from ..config import Config
from ..db import connection, crud, schemas from ..db import connection, crud, schemas
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
@ -22,8 +22,8 @@ router = APIRouter(prefix="/admin")
}, },
) )
async def set_config( async def set_config(
new_config: config.BaseConfig, new_config: Config,
current_config: config.BaseConfig | None = Depends(config.get), current_config: Config | None = Depends(Config.get),
): ):
if current_config is not None: if current_config is not None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@ -33,7 +33,7 @@ async def set_config(
connection.reconnect(new_config.db_engine) connection.reconnect(new_config.db_engine)
config.set(new_config) Config.set(new_config)
@router.post( @router.post(
@ -51,7 +51,7 @@ async def set_config(
async def add_user( async def add_user(
user_name: str, user_name: str,
user_password: str, user_password: str,
current_config: config.BaseConfig | None = Depends(config.get), current_config: Config | None = Depends(Config.get),
db: Session | None = Depends(connection.get), db: Session | None = Depends(connection.get),
): ):
if current_config is None: if current_config is None: