pydantic settings, improved scoping
This commit is contained in:
parent
d5b39db400
commit
98c89991b4
3 changed files with 49 additions and 28 deletions
|
@ -1,13 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
from jose.constants import ALGORITHMS
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, BaseSettings, Field
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
CONFIG_FILE = "tmp/config.json"
|
||||
|
||||
class Settings(BaseSettings):
|
||||
production_mode: bool = False
|
||||
config_file: str = "tmp/config.json"
|
||||
openapi_url: str = "/openapi.json"
|
||||
docs_url: str | None = "/docs"
|
||||
redoc_url: str | None = None
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache
|
||||
def get() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
class DBType(Enum):
|
||||
|
@ -43,11 +57,25 @@ class CryptoConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
class Config(BaseModel):
|
||||
db: DBConfig = Field(default_factory=DBConfig)
|
||||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||
|
||||
@staticmethod
|
||||
async def get() -> Config | None:
|
||||
try:
|
||||
with open(Settings.get().config_file, "r") as config_file:
|
||||
return Config.parse_obj(json.load(config_file))
|
||||
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set(config: Config) -> None:
|
||||
with open(Settings.get().config_file, "w") as config_file:
|
||||
config_file.write(config.json(indent=2))
|
||||
|
||||
@property
|
||||
def crypt_context(self) -> CryptContext:
|
||||
return self.crypto.crypt_context
|
||||
|
@ -55,17 +83,3 @@ class BaseConfig(BaseModel):
|
|||
@property
|
||||
def db_engine(self) -> Engine:
|
||||
return self.db.db_engine
|
||||
|
||||
|
||||
async def get() -> BaseConfig | None:
|
||||
try:
|
||||
with open(CONFIG_FILE, "r") as config_file:
|
||||
return BaseConfig.parse_obj(json.load(config_file))
|
||||
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def set(config: BaseConfig) -> None:
|
||||
with open(CONFIG_FILE, "w") as config_file:
|
||||
config_file.write(config.json(indent=2))
|
||||
|
|
|
@ -3,11 +3,11 @@
|
|||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from . import config
|
||||
from .db import connection
|
||||
from .config import Config, Settings
|
||||
from .db import connection, crud
|
||||
from .routers import admin
|
||||
|
||||
PRODUCTION_MODE = False
|
||||
settings = Settings.get()
|
||||
|
||||
|
||||
api = FastAPI(
|
||||
|
@ -21,8 +21,9 @@ api = FastAPI(
|
|||
"name": "MIT License",
|
||||
"url": "https://opensource.org/licenses/mit-license.php",
|
||||
},
|
||||
docs_url="/docs" if not PRODUCTION_MODE else None,
|
||||
redoc_url="/redoc" if not PRODUCTION_MODE else None,
|
||||
openapi_url=settings.openapi_url,
|
||||
docs_url=settings.docs_url if not settings.production_mode else None,
|
||||
redoc_url=settings.redoc_url if not settings.production_mode else None,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -34,9 +35,15 @@ async def on_startup():
|
|||
# always include admin router
|
||||
api.include_router(admin.router)
|
||||
|
||||
if (current_config := await config.get()) is not None:
|
||||
if (current_config := await Config.get()) is not None:
|
||||
connection.reconnect(current_config.db_engine)
|
||||
|
||||
async for db in connection.get():
|
||||
user = crud.get_user(db, "admin")
|
||||
print(user.name)
|
||||
for cap in user.capabilities:
|
||||
print(cap.capability)
|
||||
|
||||
# include other routers
|
||||
# api.include_router(auth.router)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from secrets import token_hex
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config
|
||||
from ..config import Config
|
||||
from ..db import connection, crud, schemas
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
|
@ -22,8 +22,8 @@ router = APIRouter(prefix="/admin")
|
|||
},
|
||||
)
|
||||
async def set_config(
|
||||
new_config: config.BaseConfig,
|
||||
current_config: config.BaseConfig | None = Depends(config.get),
|
||||
new_config: Config,
|
||||
current_config: Config | None = Depends(Config.get),
|
||||
):
|
||||
if current_config is not None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
@ -33,7 +33,7 @@ async def set_config(
|
|||
|
||||
connection.reconnect(new_config.db_engine)
|
||||
|
||||
config.set(new_config)
|
||||
Config.set(new_config)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
@ -51,7 +51,7 @@ async def set_config(
|
|||
async def add_user(
|
||||
user_name: str,
|
||||
user_password: str,
|
||||
current_config: config.BaseConfig | None = Depends(config.get),
|
||||
current_config: Config | None = Depends(Config.get),
|
||||
db: Session | None = Depends(connection.get),
|
||||
):
|
||||
if current_config is None:
|
||||
|
|
Loading…
Reference in a new issue