diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index c5fab27..5a09b96 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -3,11 +3,12 @@ from __future__ import annotations import json from enum import Enum from pathlib import Path +from typing import Generator from fastapi import Depends from jose.constants import ALGORITHMS from passlib.context import CryptContext -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker @@ -50,6 +51,30 @@ class BaseConfig(BaseModel): jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) + @property + def crypt_context(self) -> CryptContext: + return CryptContext( + schemes=self.crypto.schemes, + deprecated="auto", + ) + + __session_local: sessionmaker = PrivateAttr() + + async def connect_db(self) -> None: + engine = create_engine( + "sqlite:///./tmp/vpn.db", + connect_args={"check_same_thread": False}, + ) + self.__session_local = sessionmaker( + autocommit=False, autoflush=False, bind=engine, + ) + ORMBaseModel.metadata.create_all(bind=engine) + + @property + def database(self) -> Session | None: + if self.__session_local is not None: + return self.__session_local() + CONFIG_FILE = "tmp/config.json" @@ -67,24 +92,14 @@ async def load_config() -> BaseConfig: return BaseConfig() -async def connect_db(config: BaseConfig = Depends(load_config)) -> None: - global SESSION_LOCAL +async def get_db( + config: BaseConfig = Depends(load_config) +) -> Generator[Session | None, None, None]: + if db := config.database is None: + yield None - engine = create_engine( - "sqlite:///./tmp/vpn.db", - connect_args={"check_same_thread": False}, - ) - SESSION_LOCAL = sessionmaker( - autocommit=False, autoflush=False, bind=engine) - - ORMBaseModel.metadata.create_all(bind=engine) - - -async def get_db() -> Session: - global SESSION_LOCAL - - db = SESSION_LOCAL() - try: - yield db - finally: - db.close() + else: + try: + yield db + finally: + db.close() diff --git a/api/kiwi_vpn_api/db/crud.py b/api/kiwi_vpn_api/db/crud.py index abf1042..3e8c6be 100644 --- a/api/kiwi_vpn_api/db/crud.py +++ b/api/kiwi_vpn_api/db/crud.py @@ -1,4 +1,5 @@ from sqlalchemy.orm import Session +from passlib.context import CryptContext from . import models, schemas @@ -9,10 +10,14 @@ def get_user(db: Session, name: str): .filter(models.User.name == name).first()) -def create_user(db: Session, user: schemas.UserCreate): +def create_user( + db: Session, + user: schemas.UserCreate, + crypt_context: CryptContext +): db_user = models.User( name=user.name, - password=user.password + "notreallyhashed", + password=crypt_context.hash(user.password), ) db.add(db_user) db.commit() diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index d1c3626..e3641aa 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 import uvicorn -from fastapi import FastAPI +from fastapi import Depends, FastAPI -from .config import PRODUCTION_MODE +from .config import (PRODUCTION_MODE, BaseConfig, has_config, + load_config) from .routers import install api = FastAPI( @@ -29,6 +30,13 @@ app = FastAPI() app.mount("/api", api) +@app.on_event("startup") +async def on_startup(): + if await has_config(): + config = await load_config() + await config.connect_db() + + def main(): uvicorn.run( "kiwi_vpn_api.main:app", diff --git a/api/kiwi_vpn_api/plan.md b/api/kiwi_vpn_api/plan.md new file mode 100644 index 0000000..4d14309 --- /dev/null +++ b/api/kiwi_vpn_api/plan.md @@ -0,0 +1,39 @@ +# Startup + +if config file present: + +- load config file +- connect to DB +- mount all routers + +else: + +- mount admin router + +# PUT admin/config + +if config file present: + +- if user is admin: +- overwrite config +- reload config, reconnect to DB + +else: + +- overwrite config +- reload config, connect to DB +- mount all routers + +# POST admin/user + +if user table is empty: + +- create new user +- give "admin" cap to new user + +else: + +- if user is admin: +- create new user + +... \ No newline at end of file diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py index 7f9c22f..7a347ed 100644 --- a/api/kiwi_vpn_api/routers/install.py +++ b/api/kiwi_vpn_api/routers/install.py @@ -3,7 +3,7 @@ from secrets import token_hex from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config, +from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config, load_config) from ..db import crud, schemas @@ -52,7 +52,7 @@ async def set_config( if config.jwt.secret is None: config.jwt.secret = token_hex(32) - await connect_db(config) + await config.connect_db() with open(CONFIG_FILE, "w") as kv: kv.write(config.json(indent=2)) @@ -82,17 +82,20 @@ async def check_db(): async def create_db( admin_name: str, admin_password: str, + config: BaseConfig = Depends(load_config), db: Session = Depends(get_db), ): # if await has_tables(db): # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - # db.create_tables([Certificate, DistinguishedName, User, UserCapability]) - # cryptContext = await config.crypto.cryptContext + if db is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + crud.create_user(db, schemas.UserCreate( name=admin_name, password=admin_password, + crypt_context=config.crypt_context, )) crud.add_user_capability(db, user_name=admin_name, capability="admin") diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index eae9636..3487e4e 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -29,6 +29,7 @@ class User(BaseModel): ) +@router.get("/current", response_model=User) async def get_current_user(token: str = Depends(SCHEME)): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -48,15 +49,10 @@ async def get_current_user(token: str = Depends(SCHEME)): return user -@router.get("/current_user/get", response_model=User) -async def get_current_user(current_user: User = Depends(get_current_user)): - return current_user - - async def is_admin(current_user: User = Depends(get_current_user)): return ("admin" in current_user.capabilities) -@router.get("/current_user/is_admin") +@router.get("/current/is_admin") async def current_user_is_admin(is_admin: bool = Depends(is_admin)): return {"is_admin": is_admin}