adhere to a plan for a change :)

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 18:22:17 +00:00
parent e34e669f79
commit 746ca51bdd
7 changed files with 148 additions and 233 deletions

View file

@ -1,30 +1,14 @@
from __future__ import annotations
import json
from enum import Enum
from pathlib import Path
from typing import Generator
import json
from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel, Field
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from .db.models import ORMBaseModel
PRODUCTION_MODE = False
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
SESSION_LOCAL = None
CONFIG_FILE = "tmp/config.json"
class DBType(Enum):
@ -35,6 +19,13 @@ class DBType(Enum):
class DBConfig(BaseModel):
db_type: DBType = DBType.sqlite
@property
def db_engine(self) -> Engine:
return create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
class JWTConfig(BaseModel):
secret: str | None = None
@ -45,6 +36,13 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"]
@property
def crypt_context(self) -> CryptContext:
return CryptContext(
schemes=self.schemes,
deprecated="auto",
)
class BaseConfig(BaseModel):
db: DBConfig = Field(default_factory=DBConfig)
@ -53,53 +51,22 @@ class BaseConfig(BaseModel):
@property
def crypt_context(self) -> CryptContext:
return CryptContext(
schemes=self.crypto.schemes,
deprecated="auto",
)
__session_local: sessionmaker = PrivateAttr()
async def connect_db(self) -> None:
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
self.__session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
return self.crypto.crypt_context
@property
def database(self) -> Session | None:
if self.__session_local is not None:
return self.__session_local()
def db_engine(self) -> Engine:
return self.db.db_engine
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
async def get() -> BaseConfig | None:
try:
with open(CONFIG_FILE, "r") as kv:
return BaseConfig.parse_obj(json.load(kv))
with open(CONFIG_FILE, "r") as config_file:
return BaseConfig.parse_obj(json.load(config_file))
except FileNotFoundError:
return BaseConfig()
return None
async def get_db(
config: BaseConfig = Depends(load_config)
) -> Generator[Session | None, None, None]:
if db := config.database is None:
yield None
else:
try:
yield db
finally:
db.close()
def set(config: BaseConfig) -> None:
with open(CONFIG_FILE, "w") as config_file:
config_file.write(config.json(indent=2))

View file

@ -0,0 +1,30 @@
from typing import Generator
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from .models import ORMBaseModel
ENGINE: Engine | None = None
SESSION_LOCAL: sessionmaker | None = None
def reconnect(engine: Engine) -> None:
global ENGINE, SESSION_LOCAL
ENGINE = engine
SESSION_LOCAL = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
async def get() -> Generator[Session | None, None, None]:
if SESSION_LOCAL is None:
yield None
else:
db = SESSION_LOCAL()
try:
yield db
finally:
db.close()

View file

@ -1,11 +1,14 @@
#!/usr/bin/env python3
import uvicorn
from fastapi import Depends, FastAPI
from fastapi import FastAPI
from . import config
from .routers import admin
from .db import connection
PRODUCTION_MODE = False
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
load_config)
from .routers import install
api = FastAPI(
title="kiwi-vpn API",
@ -22,19 +25,20 @@ api = FastAPI(
redoc_url="/redoc" if not PRODUCTION_MODE else None,
)
api.include_router(install.router)
# api.include_router(auth.router)
# api.include_router(user.router)
app = FastAPI()
app.mount("/api", api)
@app.on_event("startup")
async def on_startup():
if await has_config():
config = await load_config()
await config.connect_db()
# always include admin router
api.include_router(admin.router)
if (current_config := await config.get()) is not None:
connection.reconnect(current_config.db_engine)
# include other routers
# api.include_router(auth.router)
def main():

View file

@ -26,7 +26,11 @@ else:
# POST admin/user
if user table is empty:
if no config file present:
- error
elif user table is empty:
- create new user
- give "admin" cap to new user

View file

@ -0,0 +1,69 @@
from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from .. import config
from ..db import crud, schemas, connection
router = APIRouter(prefix="/admin")
@router.put(
"/config",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def set_config(
new_config: config.BaseConfig,
current_config: config.BaseConfig | None = Depends(config.get),
):
if current_config is not None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32)
connection.reconnect(new_config.db_engine)
config.set(new_config)
@router.post(
"/user",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_400_BAD_REQUEST: {
"description": "Database doesn't exist",
"content": None,
},
},
)
async def add_user(
user_name: str,
user_password: str,
current_config: config.BaseConfig | None = Depends(config.get),
db: Session | None = Depends(connection.get),
):
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(
db=db,
user=schemas.UserCreate(
name=user_name,
password=user_password,
),
crypt_context=current_config.crypt_context,
)
crud.add_user_capability(db, user_name=user_name, capability="admin")

View file

@ -1,101 +0,0 @@
from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
load_config)
from ..db import crud, schemas
router = APIRouter(prefix="/install")
@router.get(
"/config",
response_model=BaseConfig,
responses={
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def get_config(
config: BaseConfig = Depends(load_config),
has_config: bool = Depends(has_config),
):
if has_config:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return config
@router.put(
"/config",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def set_config(
config: BaseConfig,
has_config: bool = Depends(has_config),
):
if has_config:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if config.jwt.secret is None:
config.jwt.secret = token_hex(32)
await config.connect_db()
with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2))
@router.get("/db", responses={
status.HTTP_200_OK: {
"model": bool,
},
})
async def check_db():
return True
@router.put(
"/db",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_400_BAD_REQUEST: {
"description": "Database exists",
"content": None,
},
},
)
async def create_db(
admin_name: str,
admin_password: str,
config: BaseConfig = Depends(load_config),
db: Session = Depends(get_db),
):
# if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# cryptContext = await config.crypto.cryptContext
if db is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(db, schemas.UserCreate(
name=admin_name,
password=admin_password,
crypt_context=config.crypt_context,
))
crud.add_user_capability(db, user_name=admin_name, capability="admin")

View file

@ -1,58 +0,0 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from jose import JWTError, jwt
from pydantic import BaseModel
from ..config import ALGORITHM, SECRET_KEY
from ..db import User as db_User
from .auth import SCHEME
router = APIRouter(prefix="/user")
class User(BaseModel):
name: str
capabilities: list[str]
@classmethod
def from_db(cls, username: str) -> User | None:
user = db_User.get_by_name(username)
if not user:
return None
return cls(
name=user.name,
capabilities=[cap.capability for cap in user.capabilities],
)
@router.get("/current", response_model=User)
async def get_current_user(token: str = Depends(SCHEME)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = User.from_db(username)
if user is None:
raise credentials_exception
return user
async def is_admin(current_user: User = Depends(get_current_user)):
return ("admin" in current_user.capabilities)
@router.get("/current/is_admin")
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
return {"is_admin": is_admin}