dirty commit

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 17:36:44 +00:00
parent 641dfd7ba0
commit e34e669f79
6 changed files with 101 additions and 35 deletions

View file

@ -3,11 +3,12 @@ from __future__ import annotations
import json
from enum import Enum
from pathlib import Path
from typing import Generator
from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
@ -50,6 +51,30 @@ class BaseConfig(BaseModel):
jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@property
def crypt_context(self) -> CryptContext:
return CryptContext(
schemes=self.crypto.schemes,
deprecated="auto",
)
__session_local: sessionmaker = PrivateAttr()
async def connect_db(self) -> None:
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
self.__session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@property
def database(self) -> Session | None:
if self.__session_local is not None:
return self.__session_local()
CONFIG_FILE = "tmp/config.json"
@ -67,23 +92,13 @@ async def load_config() -> BaseConfig:
return BaseConfig()
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
global SESSION_LOCAL
async def get_db(
config: BaseConfig = Depends(load_config)
) -> Generator[Session | None, None, None]:
if db := config.database is None:
yield None
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
SESSION_LOCAL = sessionmaker(
autocommit=False, autoflush=False, bind=engine)
ORMBaseModel.metadata.create_all(bind=engine)
async def get_db() -> Session:
global SESSION_LOCAL
db = SESSION_LOCAL()
else:
try:
yield db
finally:

View file

@ -1,4 +1,5 @@
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from . import models, schemas
@ -9,10 +10,14 @@ def get_user(db: Session, name: str):
.filter(models.User.name == name).first())
def create_user(db: Session, user: schemas.UserCreate):
def create_user(
db: Session,
user: schemas.UserCreate,
crypt_context: CryptContext
):
db_user = models.User(
name=user.name,
password=user.password + "notreallyhashed",
password=crypt_context.hash(user.password),
)
db.add(db_user)
db.commit()

View file

@ -1,9 +1,10 @@
#!/usr/bin/env python3
import uvicorn
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from .config import PRODUCTION_MODE
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
load_config)
from .routers import install
api = FastAPI(
@ -29,6 +30,13 @@ app = FastAPI()
app.mount("/api", api)
@app.on_event("startup")
async def on_startup():
if await has_config():
config = await load_config()
await config.connect_db()
def main():
uvicorn.run(
"kiwi_vpn_api.main:app",

39
api/kiwi_vpn_api/plan.md Normal file
View file

@ -0,0 +1,39 @@
# Startup
if config file present:
- load config file
- connect to DB
- mount all routers
else:
- mount admin router
# PUT admin/config
if config file present:
- if user is admin:
- overwrite config
- reload config, reconnect to DB
else:
- overwrite config
- reload config, connect to DB
- mount all routers
# POST admin/user
if user table is empty:
- create new user
- give "admin" cap to new user
else:
- if user is admin:
- create new user
...

View file

@ -3,7 +3,7 @@ from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
load_config)
from ..db import crud, schemas
@ -52,7 +52,7 @@ async def set_config(
if config.jwt.secret is None:
config.jwt.secret = token_hex(32)
await connect_db(config)
await config.connect_db()
with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2))
@ -82,17 +82,20 @@ async def check_db():
async def create_db(
admin_name: str,
admin_password: str,
config: BaseConfig = Depends(load_config),
db: Session = Depends(get_db),
):
# if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
# cryptContext = await config.crypto.cryptContext
if db is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(db, schemas.UserCreate(
name=admin_name,
password=admin_password,
crypt_context=config.crypt_context,
))
crud.add_user_capability(db, user_name=admin_name, capability="admin")

View file

@ -29,6 +29,7 @@ class User(BaseModel):
)
@router.get("/current", response_model=User)
async def get_current_user(token: str = Depends(SCHEME)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -48,15 +49,10 @@ async def get_current_user(token: str = Depends(SCHEME)):
return user
@router.get("/current_user/get", response_model=User)
async def get_current_user(current_user: User = Depends(get_current_user)):
return current_user
async def is_admin(current_user: User = Depends(get_current_user)):
return ("admin" in current_user.capabilities)
@router.get("/current_user/is_admin")
@router.get("/current/is_admin")
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
return {"is_admin": is_admin}