diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 42ca7ae..945e367 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -1,13 +1,27 @@ +from __future__ import annotations + +import functools import json from enum import Enum from jose.constants import ALGORITHMS from passlib.context import CryptContext -from pydantic import BaseModel, Field +from pydantic import BaseModel, BaseSettings, Field from sqlalchemy import create_engine from sqlalchemy.engine import Engine -CONFIG_FILE = "tmp/config.json" + +class Settings(BaseSettings): + production_mode: bool = False + config_file: str = "tmp/config.json" + openapi_url: str = "/openapi.json" + docs_url: str | None = "/docs" + redoc_url: str | None = None + + @staticmethod + @functools.lru_cache + def get() -> Settings: + return Settings() class DBType(Enum): @@ -43,11 +57,25 @@ class CryptoConfig(BaseModel): ) -class BaseConfig(BaseModel): +class Config(BaseModel): db: DBConfig = Field(default_factory=DBConfig) jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) + @staticmethod + async def get() -> Config | None: + try: + with open(Settings.get().config_file, "r") as config_file: + return Config.parse_obj(json.load(config_file)) + + except FileNotFoundError: + return None + + @staticmethod + def set(config: Config) -> None: + with open(Settings.get().config_file, "w") as config_file: + config_file.write(config.json(indent=2)) + @property def crypt_context(self) -> CryptContext: return self.crypto.crypt_context @@ -55,17 +83,3 @@ class BaseConfig(BaseModel): @property def db_engine(self) -> Engine: return self.db.db_engine - - -async def get() -> BaseConfig | None: - try: - with open(CONFIG_FILE, "r") as config_file: - return BaseConfig.parse_obj(json.load(config_file)) - - except FileNotFoundError: - return None - - -def set(config: BaseConfig) -> None: - with open(CONFIG_FILE, "w") as config_file: - config_file.write(config.json(indent=2)) diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 2fa873a..b1c87e5 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -3,11 +3,11 @@ import uvicorn from fastapi import FastAPI -from . import config -from .db import connection +from .config import Config, Settings +from .db import connection, crud from .routers import admin -PRODUCTION_MODE = False +settings = Settings.get() api = FastAPI( @@ -21,8 +21,9 @@ api = FastAPI( "name": "MIT License", "url": "https://opensource.org/licenses/mit-license.php", }, - docs_url="/docs" if not PRODUCTION_MODE else None, - redoc_url="/redoc" if not PRODUCTION_MODE else None, + openapi_url=settings.openapi_url, + docs_url=settings.docs_url if not settings.production_mode else None, + redoc_url=settings.redoc_url if not settings.production_mode else None, ) app = FastAPI() @@ -34,9 +35,15 @@ async def on_startup(): # always include admin router api.include_router(admin.router) - if (current_config := await config.get()) is not None: + if (current_config := await Config.get()) is not None: connection.reconnect(current_config.db_engine) + async for db in connection.get(): + user = crud.get_user(db, "admin") + print(user.name) + for cap in user.capabilities: + print(cap.capability) + # include other routers # api.include_router(auth.router) diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index ab75c40..8926418 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -3,7 +3,7 @@ from secrets import token_hex from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from .. import config +from ..config import Config from ..db import connection, crud, schemas router = APIRouter(prefix="/admin") @@ -22,8 +22,8 @@ router = APIRouter(prefix="/admin") }, ) async def set_config( - new_config: config.BaseConfig, - current_config: config.BaseConfig | None = Depends(config.get), + new_config: Config, + current_config: Config | None = Depends(Config.get), ): if current_config is not None: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) @@ -33,7 +33,7 @@ async def set_config( connection.reconnect(new_config.db_engine) - config.set(new_config) + Config.set(new_config) @router.post( @@ -51,7 +51,7 @@ async def set_config( async def add_user( user_name: str, user_password: str, - current_config: config.BaseConfig | None = Depends(config.get), + current_config: Config | None = Depends(Config.get), db: Session | None = Depends(connection.get), ): if current_config is None: