diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 41ab5ae..5f9fbfc 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -1,6 +1,7 @@ from __future__ import annotations from enum import Enum +import json from pathlib import Path from typing import Optional @@ -20,8 +21,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30 CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") -CONFIG_FILE = "tmp/config.json" - class DBType(Enum): sqlite = "sqlite" @@ -29,7 +28,7 @@ class DBType(Enum): class DBConfig(BaseModel): - db_type: DBType = "sqlite" + db_type: DBType = DBType.sqlite @property def database(self) -> Database: @@ -58,3 +57,19 @@ class BaseConfig(BaseModel): def save(self, filename: Path) -> None: with open(filename, "w") as kv: kv.write(self.json(indent=2)) + + +CONFIG_FILE = "tmp/config.json" + + +async def get_default_config() -> BaseConfig: + return BaseConfig() + + +async def is_configured() -> bool: + return Path(CONFIG_FILE).is_file() + + +async def get_config() -> BaseConfig: + with open(CONFIG_FILE, "r") as kv: + return BaseConfig.parse_obj(json.load(kv)) diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py index 22486c9..1c6be3b 100644 --- a/api/kiwi_vpn_api/routers/install.py +++ b/api/kiwi_vpn_api/routers/install.py @@ -1,35 +1,22 @@ -import json -from pathlib import Path +from secrets import token_hex from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse -from ..config import CONFIG_FILE, CRYPT_CONTEXT, DB, BaseConfig +from ..config import (CONFIG_FILE, CRYPT_CONTEXT, DB, BaseConfig, get_config, + get_default_config, is_configured) from ..db import Certificate, DistinguishedName, User, UserCapability router = APIRouter(prefix="/install") -async def get_default_config() -> BaseConfig: - return BaseConfig() - - -async def get_config() -> BaseConfig: - with open(CONFIG_FILE, "r") as kv: - return BaseConfig.parse_obj(json.load(kv)) - - -async def is_configured() -> bool: - return Path(CONFIG_FILE).is_file() - - -@router.get("/config/get_default", response_model=BaseConfig) -async def config_get_default(config: BaseConfig = Depends(get_default_config)): +@router.get("/config/default", response_model=BaseConfig) +async def get_default_config(config: BaseConfig = Depends(get_default_config)): return config @router.get( - "/config/get", + "/config", response_model=BaseConfig, responses={ status.HTTP_404_NOT_FOUND: { @@ -38,9 +25,9 @@ async def config_get_default(config: BaseConfig = Depends(get_default_config)): }, }, ) -async def config_get( - config: BaseConfig = Depends(get_config), +async def get_config( is_configured: bool = Depends(is_configured), + config: BaseConfig = Depends(get_config), ): if not is_configured: return JSONResponse(status_code=status.HTTP_404_NOT_FOUND) @@ -48,8 +35,8 @@ async def config_get( return config -@router.post( - "/config/set", +@router.put( + "/config", responses={ status.HTTP_200_OK: { "content": None, @@ -60,13 +47,16 @@ async def config_get( }, }, ) -async def config_set( +async def set_config( config: BaseConfig, is_configured: bool = Depends(is_configured), ): if is_configured: return JSONResponse(status_code=status.HTTP_403_FORBIDDEN) + if config.jwt.secret is None: + config.jwt.secret = token_hex(32) + config.save(CONFIG_FILE)