pydantic settings, improved scoping

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 22:43:02 +00:00
parent d5b39db400
commit 98c89991b4
3 changed files with 49 additions and 28 deletions

View file

@ -1,13 +1,27 @@
from __future__ import annotations
import functools
import json
from enum import Enum
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from pydantic import BaseModel, BaseSettings, Field
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
CONFIG_FILE = "tmp/config.json"
class Settings(BaseSettings):
production_mode: bool = False
config_file: str = "tmp/config.json"
openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs"
redoc_url: str | None = None
@staticmethod
@functools.lru_cache
def get() -> Settings:
return Settings()
class DBType(Enum):
@ -43,11 +57,25 @@ class CryptoConfig(BaseModel):
)
class BaseConfig(BaseModel):
class Config(BaseModel):
db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@staticmethod
async def get() -> Config | None:
try:
with open(Settings.get().config_file, "r") as config_file:
return Config.parse_obj(json.load(config_file))
except FileNotFoundError:
return None
@staticmethod
def set(config: Config) -> None:
with open(Settings.get().config_file, "w") as config_file:
config_file.write(config.json(indent=2))
@property
def crypt_context(self) -> CryptContext:
return self.crypto.crypt_context
@ -55,17 +83,3 @@ class BaseConfig(BaseModel):
@property
def db_engine(self) -> Engine:
return self.db.db_engine
async def get() -> BaseConfig | None:
try:
with open(CONFIG_FILE, "r") as config_file:
return BaseConfig.parse_obj(json.load(config_file))
except FileNotFoundError:
return None
def set(config: BaseConfig) -> None:
with open(CONFIG_FILE, "w") as config_file:
config_file.write(config.json(indent=2))

View file

@ -3,11 +3,11 @@
import uvicorn
from fastapi import FastAPI
from . import config
from .db import connection
from .config import Config, Settings
from .db import connection, crud
from .routers import admin
PRODUCTION_MODE = False
settings = Settings.get()
api = FastAPI(
@ -21,8 +21,9 @@ api = FastAPI(
"name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php",
},
docs_url="/docs" if not PRODUCTION_MODE else None,
redoc_url="/redoc" if not PRODUCTION_MODE else None,
openapi_url=settings.openapi_url,
docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None,
)
app = FastAPI()
@ -34,9 +35,15 @@ async def on_startup():
# always include admin router
api.include_router(admin.router)
if (current_config := await config.get()) is not None:
if (current_config := await Config.get()) is not None:
connection.reconnect(current_config.db_engine)
async for db in connection.get():
user = crud.get_user(db, "admin")
print(user.name)
for cap in user.capabilities:
print(cap.capability)
# include other routers
# api.include_router(auth.router)

View file

@ -3,7 +3,7 @@ from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from .. import config
from ..config import Config
from ..db import connection, crud, schemas
router = APIRouter(prefix="/admin")
@ -22,8 +22,8 @@ router = APIRouter(prefix="/admin")
},
)
async def set_config(
new_config: config.BaseConfig,
current_config: config.BaseConfig | None = Depends(config.get),
new_config: Config,
current_config: Config | None = Depends(Config.get),
):
if current_config is not None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@ -33,7 +33,7 @@ async def set_config(
connection.reconnect(new_config.db_engine)
config.set(new_config)
Config.set(new_config)
@router.post(
@ -51,7 +51,7 @@ async def set_config(
async def add_user(
user_name: str,
user_password: str,
current_config: config.BaseConfig | None = Depends(config.get),
current_config: Config | None = Depends(Config.get),
db: Session | None = Depends(connection.get),
):
if current_config is None: