adhere to a plan for a change :)

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 18:22:17 +00:00
parent e34e669f79
commit 746ca51bdd
7 changed files with 148 additions and 233 deletions

View file

@ -1,30 +1,14 @@
from __future__ import annotations
import json
from enum import Enum from enum import Enum
from pathlib import Path import json
from typing import Generator
from fastapi import Depends
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from .db.models import ORMBaseModel CONFIG_FILE = "tmp/config.json"
PRODUCTION_MODE = False
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
SESSION_LOCAL = None
class DBType(Enum): class DBType(Enum):
@ -35,6 +19,13 @@ class DBType(Enum):
class DBConfig(BaseModel): class DBConfig(BaseModel):
db_type: DBType = DBType.sqlite db_type: DBType = DBType.sqlite
@property
def db_engine(self) -> Engine:
return create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
secret: str | None = None secret: str | None = None
@ -45,6 +36,13 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel): class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]
@property
def crypt_context(self) -> CryptContext:
return CryptContext(
schemes=self.schemes,
deprecated="auto",
)
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
db: DBConfig = Field(default_factory=DBConfig) db: DBConfig = Field(default_factory=DBConfig)
@ -53,53 +51,22 @@ class BaseConfig(BaseModel):
@property @property
def crypt_context(self) -> CryptContext: def crypt_context(self) -> CryptContext:
return CryptContext( return self.crypto.crypt_context
schemes=self.crypto.schemes,
deprecated="auto",
)
__session_local: sessionmaker = PrivateAttr()
async def connect_db(self) -> None:
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
self.__session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@property @property
def database(self) -> Session | None: def db_engine(self) -> Engine:
if self.__session_local is not None: return self.db.db_engine
return self.__session_local()
CONFIG_FILE = "tmp/config.json" async def get() -> BaseConfig | None:
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
try: try:
with open(CONFIG_FILE, "r") as kv: with open(CONFIG_FILE, "r") as config_file:
return BaseConfig.parse_obj(json.load(kv)) return BaseConfig.parse_obj(json.load(config_file))
except FileNotFoundError: except FileNotFoundError:
return BaseConfig() return None
async def get_db( def set(config: BaseConfig) -> None:
config: BaseConfig = Depends(load_config) with open(CONFIG_FILE, "w") as config_file:
) -> Generator[Session | None, None, None]: config_file.write(config.json(indent=2))
if db := config.database is None:
yield None
else:
try:
yield db
finally:
db.close()

View file

@ -0,0 +1,30 @@
from typing import Generator
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from .models import ORMBaseModel
ENGINE: Engine | None = None
SESSION_LOCAL: sessionmaker | None = None
def reconnect(engine: Engine) -> None:
global ENGINE, SESSION_LOCAL
ENGINE = engine
SESSION_LOCAL = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
async def get() -> Generator[Session | None, None, None]:
if SESSION_LOCAL is None:
yield None
else:
db = SESSION_LOCAL()
try:
yield db
finally:
db.close()

View file

@ -1,11 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import uvicorn import uvicorn
from fastapi import Depends, FastAPI from fastapi import FastAPI
from . import config
from .routers import admin
from .db import connection
PRODUCTION_MODE = False
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
load_config)
from .routers import install
api = FastAPI( api = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
@ -22,19 +25,20 @@ api = FastAPI(
redoc_url="/redoc" if not PRODUCTION_MODE else None, redoc_url="/redoc" if not PRODUCTION_MODE else None,
) )
api.include_router(install.router)
# api.include_router(auth.router)
# api.include_router(user.router)
app = FastAPI() app = FastAPI()
app.mount("/api", api) app.mount("/api", api)
@app.on_event("startup") @app.on_event("startup")
async def on_startup(): async def on_startup():
if await has_config(): # always include admin router
config = await load_config() api.include_router(admin.router)
await config.connect_db()
if (current_config := await config.get()) is not None:
connection.reconnect(current_config.db_engine)
# include other routers
# api.include_router(auth.router)
def main(): def main():

View file

@ -26,7 +26,11 @@ else:
# POST admin/user # POST admin/user
if user table is empty: if no config file present:
- error
elif user table is empty:
- create new user - create new user
- give "admin" cap to new user - give "admin" cap to new user

View file

@ -0,0 +1,69 @@
from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from .. import config
from ..db import crud, schemas, connection
router = APIRouter(prefix="/admin")
@router.put(
"/config",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def set_config(
new_config: config.BaseConfig,
current_config: config.BaseConfig | None = Depends(config.get),
):
if current_config is not None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32)
connection.reconnect(new_config.db_engine)
config.set(new_config)
@router.post(
"/user",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_400_BAD_REQUEST: {
"description": "Database doesn't exist",
"content": None,
},
},
)
async def add_user(
user_name: str,
user_password: str,
current_config: config.BaseConfig | None = Depends(config.get),
db: Session | None = Depends(connection.get),
):
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(
db=db,
user=schemas.UserCreate(
name=user_name,
password=user_password,
),
crypt_context=current_config.crypt_context,
)
crud.add_user_capability(db, user_name=user_name, capability="admin")

View file

@ -1,101 +0,0 @@
from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
load_config)
from ..db import crud, schemas
router = APIRouter(prefix="/install")
@router.get(
"/config",
response_model=BaseConfig,
responses={
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def get_config(
config: BaseConfig = Depends(load_config),
has_config: bool = Depends(has_config),
):
if has_config:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return config
@router.put(
"/config",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_403_FORBIDDEN: {
"description": "Must be admin",
"content": None,
},
},
)
async def set_config(
config: BaseConfig,
has_config: bool = Depends(has_config),
):
if has_config:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if config.jwt.secret is None:
config.jwt.secret = token_hex(32)
await config.connect_db()
with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2))
@router.get("/db", responses={
status.HTTP_200_OK: {
"model": bool,
},
})
async def check_db():
return True
@router.put(
"/db",
responses={
status.HTTP_200_OK: {
"content": None,
},
status.HTTP_400_BAD_REQUEST: {
"description": "Database exists",
"content": None,
},
},
)
async def create_db(
admin_name: str,
admin_password: str,
config: BaseConfig = Depends(load_config),
db: Session = Depends(get_db),
):
# if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# cryptContext = await config.crypto.cryptContext
if db is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(db, schemas.UserCreate(
name=admin_name,
password=admin_password,
crypt_context=config.crypt_context,
))
crud.add_user_capability(db, user_name=admin_name, capability="admin")

View file

@ -1,58 +0,0 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from jose import JWTError, jwt
from pydantic import BaseModel
from ..config import ALGORITHM, SECRET_KEY
from ..db import User as db_User
from .auth import SCHEME
router = APIRouter(prefix="/user")
class User(BaseModel):
name: str
capabilities: list[str]
@classmethod
def from_db(cls, username: str) -> User | None:
user = db_User.get_by_name(username)
if not user:
return None
return cls(
name=user.name,
capabilities=[cap.capability for cap in user.capabilities],
)
@router.get("/current", response_model=User)
async def get_current_user(token: str = Depends(SCHEME)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = User.from_db(username)
if user is None:
raise credentials_exception
return user
async def is_admin(current_user: User = Depends(get_current_user)):
return ("admin" in current_user.capabilities)
@router.get("/current/is_admin")
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
return {"is_admin": is_admin}