diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 5a09b96..ca39caf 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -1,30 +1,14 @@ -from __future__ import annotations - -import json from enum import Enum -from pathlib import Path -from typing import Generator +import json -from fastapi import Depends from jose.constants import ALGORITHMS from passlib.context import CryptContext -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, Field + +from sqlalchemy.engine import Engine from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker -from .db.models import ORMBaseModel - -PRODUCTION_MODE = False - -# to get a string like this run: -# openssl rand -hex 32 -SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 - -CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") - -SESSION_LOCAL = None +CONFIG_FILE = "tmp/config.json" class DBType(Enum): @@ -35,6 +19,13 @@ class DBType(Enum): class DBConfig(BaseModel): db_type: DBType = DBType.sqlite + @property + def db_engine(self) -> Engine: + return create_engine( + "sqlite:///./tmp/vpn.db", + connect_args={"check_same_thread": False}, + ) + class JWTConfig(BaseModel): secret: str | None = None @@ -45,6 +36,13 @@ class JWTConfig(BaseModel): class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] + @property + def crypt_context(self) -> CryptContext: + return CryptContext( + schemes=self.schemes, + deprecated="auto", + ) + class BaseConfig(BaseModel): db: DBConfig = Field(default_factory=DBConfig) @@ -53,53 +51,22 @@ class BaseConfig(BaseModel): @property def crypt_context(self) -> CryptContext: - return CryptContext( - schemes=self.crypto.schemes, - deprecated="auto", - ) - - __session_local: sessionmaker = PrivateAttr() - - async def connect_db(self) -> None: - engine = create_engine( - "sqlite:///./tmp/vpn.db", - connect_args={"check_same_thread": False}, - ) - self.__session_local = sessionmaker( - autocommit=False, autoflush=False, bind=engine, - ) - ORMBaseModel.metadata.create_all(bind=engine) + return self.crypto.crypt_context @property - def database(self) -> Session | None: - if self.__session_local is not None: - return self.__session_local() + def db_engine(self) -> Engine: + return self.db.db_engine -CONFIG_FILE = "tmp/config.json" - - -async def has_config() -> bool: - return Path(CONFIG_FILE).is_file() - - -async def load_config() -> BaseConfig: +async def get() -> BaseConfig | None: try: - with open(CONFIG_FILE, "r") as kv: - return BaseConfig.parse_obj(json.load(kv)) + with open(CONFIG_FILE, "r") as config_file: + return BaseConfig.parse_obj(json.load(config_file)) except FileNotFoundError: - return BaseConfig() + return None -async def get_db( - config: BaseConfig = Depends(load_config) -) -> Generator[Session | None, None, None]: - if db := config.database is None: - yield None - - else: - try: - yield db - finally: - db.close() +def set(config: BaseConfig) -> None: + with open(CONFIG_FILE, "w") as config_file: + config_file.write(config.json(indent=2)) diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py new file mode 100644 index 0000000..991896e --- /dev/null +++ b/api/kiwi_vpn_api/db/connection.py @@ -0,0 +1,30 @@ +from typing import Generator +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker, Session + +from .models import ORMBaseModel + +ENGINE: Engine | None = None +SESSION_LOCAL: sessionmaker | None = None + + +def reconnect(engine: Engine) -> None: + global ENGINE, SESSION_LOCAL + + ENGINE = engine + SESSION_LOCAL = sessionmaker( + autocommit=False, autoflush=False, bind=engine, + ) + ORMBaseModel.metadata.create_all(bind=engine) + + +async def get() -> Generator[Session | None, None, None]: + if SESSION_LOCAL is None: + yield None + + else: + db = SESSION_LOCAL() + try: + yield db + finally: + db.close() diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index e3641aa..ea4a53e 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -1,11 +1,14 @@ #!/usr/bin/env python3 import uvicorn -from fastapi import Depends, FastAPI +from fastapi import FastAPI + +from . import config +from .routers import admin +from .db import connection + +PRODUCTION_MODE = False -from .config import (PRODUCTION_MODE, BaseConfig, has_config, - load_config) -from .routers import install api = FastAPI( title="kiwi-vpn API", @@ -22,19 +25,20 @@ api = FastAPI( redoc_url="/redoc" if not PRODUCTION_MODE else None, ) -api.include_router(install.router) -# api.include_router(auth.router) -# api.include_router(user.router) - app = FastAPI() app.mount("/api", api) @app.on_event("startup") async def on_startup(): - if await has_config(): - config = await load_config() - await config.connect_db() + # always include admin router + api.include_router(admin.router) + + if (current_config := await config.get()) is not None: + connection.reconnect(current_config.db_engine) + + # include other routers + # api.include_router(auth.router) def main(): diff --git a/api/kiwi_vpn_api/plan.md b/api/kiwi_vpn_api/plan.md index 4d14309..a3e6c23 100644 --- a/api/kiwi_vpn_api/plan.md +++ b/api/kiwi_vpn_api/plan.md @@ -26,7 +26,11 @@ else: # POST admin/user -if user table is empty: +if no config file present: + +- error + +elif user table is empty: - create new user - give "admin" cap to new user diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py new file mode 100644 index 0000000..a9c72a0 --- /dev/null +++ b/api/kiwi_vpn_api/routers/admin.py @@ -0,0 +1,69 @@ +from secrets import token_hex + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from .. import config + +from ..db import crud, schemas, connection + +router = APIRouter(prefix="/admin") + + +@router.put( + "/config", + responses={ + status.HTTP_200_OK: { + "content": None, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Must be admin", + "content": None, + }, + }, +) +async def set_config( + new_config: config.BaseConfig, + current_config: config.BaseConfig | None = Depends(config.get), +): + if current_config is not None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + if new_config.jwt.secret is None: + new_config.jwt.secret = token_hex(32) + + connection.reconnect(new_config.db_engine) + + config.set(new_config) + + +@router.post( + "/user", + responses={ + status.HTTP_200_OK: { + "content": None, + }, + status.HTTP_400_BAD_REQUEST: { + "description": "Database doesn't exist", + "content": None, + }, + }, +) +async def add_user( + user_name: str, + user_password: str, + current_config: config.BaseConfig | None = Depends(config.get), + db: Session | None = Depends(connection.get), +): + if current_config is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + crud.create_user( + db=db, + user=schemas.UserCreate( + name=user_name, + password=user_password, + ), + crypt_context=current_config.crypt_context, + ) + crud.add_user_capability(db, user_name=user_name, capability="admin") diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py deleted file mode 100644 index 7a347ed..0000000 --- a/api/kiwi_vpn_api/routers/install.py +++ /dev/null @@ -1,101 +0,0 @@ -from secrets import token_hex - -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session - -from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config, - load_config) -from ..db import crud, schemas - -router = APIRouter(prefix="/install") - - -@router.get( - "/config", - response_model=BaseConfig, - responses={ - status.HTTP_403_FORBIDDEN: { - "description": "Must be admin", - "content": None, - }, - }, -) -async def get_config( - config: BaseConfig = Depends(load_config), - has_config: bool = Depends(has_config), -): - if has_config: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - - return config - - -@router.put( - "/config", - responses={ - status.HTTP_200_OK: { - "content": None, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Must be admin", - "content": None, - }, - }, -) -async def set_config( - config: BaseConfig, - has_config: bool = Depends(has_config), -): - if has_config: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - - if config.jwt.secret is None: - config.jwt.secret = token_hex(32) - - await config.connect_db() - - with open(CONFIG_FILE, "w") as kv: - kv.write(config.json(indent=2)) - - -@router.get("/db", responses={ - status.HTTP_200_OK: { - "model": bool, - }, -}) -async def check_db(): - return True - - -@router.put( - "/db", - responses={ - status.HTTP_200_OK: { - "content": None, - }, - status.HTTP_400_BAD_REQUEST: { - "description": "Database exists", - "content": None, - }, - }, -) -async def create_db( - admin_name: str, - admin_password: str, - config: BaseConfig = Depends(load_config), - db: Session = Depends(get_db), -): - # if await has_tables(db): - # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - - # cryptContext = await config.crypto.cryptContext - - if db is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - - crud.create_user(db, schemas.UserCreate( - name=admin_name, - password=admin_password, - crypt_context=config.crypt_context, - )) - crud.add_user_capability(db, user_name=admin_name, capability="admin") diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py deleted file mode 100644 index 3487e4e..0000000 --- a/api/kiwi_vpn_api/routers/user.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - - -from fastapi import APIRouter, Depends, HTTPException, status -from jose import JWTError, jwt -from pydantic import BaseModel - -from ..config import ALGORITHM, SECRET_KEY -from ..db import User as db_User -from .auth import SCHEME - -router = APIRouter(prefix="/user") - - -class User(BaseModel): - name: str - capabilities: list[str] - - @classmethod - def from_db(cls, username: str) -> User | None: - user = db_User.get_by_name(username) - - if not user: - return None - - return cls( - name=user.name, - capabilities=[cap.capability for cap in user.capabilities], - ) - - -@router.get("/current", response_model=User) -async def get_current_user(token: str = Depends(SCHEME)): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: - raise credentials_exception - user = User.from_db(username) - if user is None: - raise credentials_exception - return user - - -async def is_admin(current_user: User = Depends(get_current_user)): - return ("admin" in current_user.capabilities) - - -@router.get("/current/is_admin") -async def current_user_is_admin(is_admin: bool = Depends(is_admin)): - return {"is_admin": is_admin}