pydantic settings, improved scoping
This commit is contained in:
parent
d5b39db400
commit
98c89991b4
3 changed files with 49 additions and 28 deletions
|
@ -1,13 +1,27 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from jose.constants import ALGORITHMS
|
from jose.constants import ALGORITHMS
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, BaseSettings, Field
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
CONFIG_FILE = "tmp/config.json"
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
production_mode: bool = False
|
||||||
|
config_file: str = "tmp/config.json"
|
||||||
|
openapi_url: str = "/openapi.json"
|
||||||
|
docs_url: str | None = "/docs"
|
||||||
|
redoc_url: str | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@functools.lru_cache
|
||||||
|
def get() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
class DBType(Enum):
|
class DBType(Enum):
|
||||||
|
@ -43,11 +57,25 @@ class CryptoConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class Config(BaseModel):
|
||||||
db: DBConfig = Field(default_factory=DBConfig)
|
db: DBConfig = Field(default_factory=DBConfig)
|
||||||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get() -> Config | None:
|
||||||
|
try:
|
||||||
|
with open(Settings.get().config_file, "r") as config_file:
|
||||||
|
return Config.parse_obj(json.load(config_file))
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set(config: Config) -> None:
|
||||||
|
with open(Settings.get().config_file, "w") as config_file:
|
||||||
|
config_file.write(config.json(indent=2))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def crypt_context(self) -> CryptContext:
|
def crypt_context(self) -> CryptContext:
|
||||||
return self.crypto.crypt_context
|
return self.crypto.crypt_context
|
||||||
|
@ -55,17 +83,3 @@ class BaseConfig(BaseModel):
|
||||||
@property
|
@property
|
||||||
def db_engine(self) -> Engine:
|
def db_engine(self) -> Engine:
|
||||||
return self.db.db_engine
|
return self.db.db_engine
|
||||||
|
|
||||||
|
|
||||||
async def get() -> BaseConfig | None:
|
|
||||||
try:
|
|
||||||
with open(CONFIG_FILE, "r") as config_file:
|
|
||||||
return BaseConfig.parse_obj(json.load(config_file))
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def set(config: BaseConfig) -> None:
|
|
||||||
with open(CONFIG_FILE, "w") as config_file:
|
|
||||||
config_file.write(config.json(indent=2))
|
|
||||||
|
|
|
@ -3,11 +3,11 @@
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from . import config
|
from .config import Config, Settings
|
||||||
from .db import connection
|
from .db import connection, crud
|
||||||
from .routers import admin
|
from .routers import admin
|
||||||
|
|
||||||
PRODUCTION_MODE = False
|
settings = Settings.get()
|
||||||
|
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
|
@ -21,8 +21,9 @@ api = FastAPI(
|
||||||
"name": "MIT License",
|
"name": "MIT License",
|
||||||
"url": "https://opensource.org/licenses/mit-license.php",
|
"url": "https://opensource.org/licenses/mit-license.php",
|
||||||
},
|
},
|
||||||
docs_url="/docs" if not PRODUCTION_MODE else None,
|
openapi_url=settings.openapi_url,
|
||||||
redoc_url="/redoc" if not PRODUCTION_MODE else None,
|
docs_url=settings.docs_url if not settings.production_mode else None,
|
||||||
|
redoc_url=settings.redoc_url if not settings.production_mode else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
@ -34,9 +35,15 @@ async def on_startup():
|
||||||
# always include admin router
|
# always include admin router
|
||||||
api.include_router(admin.router)
|
api.include_router(admin.router)
|
||||||
|
|
||||||
if (current_config := await config.get()) is not None:
|
if (current_config := await Config.get()) is not None:
|
||||||
connection.reconnect(current_config.db_engine)
|
connection.reconnect(current_config.db_engine)
|
||||||
|
|
||||||
|
async for db in connection.get():
|
||||||
|
user = crud.get_user(db, "admin")
|
||||||
|
print(user.name)
|
||||||
|
for cap in user.capabilities:
|
||||||
|
print(cap.capability)
|
||||||
|
|
||||||
# include other routers
|
# include other routers
|
||||||
# api.include_router(auth.router)
|
# api.include_router(auth.router)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from secrets import token_hex
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .. import config
|
from ..config import Config
|
||||||
from ..db import connection, crud, schemas
|
from ..db import connection, crud, schemas
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
|
@ -22,8 +22,8 @@ router = APIRouter(prefix="/admin")
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def set_config(
|
async def set_config(
|
||||||
new_config: config.BaseConfig,
|
new_config: Config,
|
||||||
current_config: config.BaseConfig | None = Depends(config.get),
|
current_config: Config | None = Depends(Config.get),
|
||||||
):
|
):
|
||||||
if current_config is not None:
|
if current_config is not None:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
|
@ -33,7 +33,7 @@ async def set_config(
|
||||||
|
|
||||||
connection.reconnect(new_config.db_engine)
|
connection.reconnect(new_config.db_engine)
|
||||||
|
|
||||||
config.set(new_config)
|
Config.set(new_config)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
@ -51,7 +51,7 @@ async def set_config(
|
||||||
async def add_user(
|
async def add_user(
|
||||||
user_name: str,
|
user_name: str,
|
||||||
user_password: str,
|
user_password: str,
|
||||||
current_config: config.BaseConfig | None = Depends(config.get),
|
current_config: Config | None = Depends(Config.get),
|
||||||
db: Session | None = Depends(connection.get),
|
db: Session | None = Depends(connection.get),
|
||||||
):
|
):
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
|
|
Loading…
Reference in a new issue