dirty commit

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 17:36:44 +00:00
parent 641dfd7ba0
commit e34e669f79
6 changed files with 101 additions and 35 deletions

View file

@ -3,11 +3,12 @@ from __future__ import annotations
import json import json
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Generator
from fastapi import Depends from fastapi import Depends
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, PrivateAttr
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
@ -50,6 +51,30 @@ class BaseConfig(BaseModel):
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@property
def crypt_context(self) -> CryptContext:
return CryptContext(
schemes=self.crypto.schemes,
deprecated="auto",
)
__session_local: sessionmaker = PrivateAttr()
async def connect_db(self) -> None:
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
self.__session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@property
def database(self) -> Session | None:
if self.__session_local is not None:
return self.__session_local()
CONFIG_FILE = "tmp/config.json" CONFIG_FILE = "tmp/config.json"
@ -67,24 +92,14 @@ async def load_config() -> BaseConfig:
return BaseConfig() return BaseConfig()
async def connect_db(config: BaseConfig = Depends(load_config)) -> None: async def get_db(
global SESSION_LOCAL config: BaseConfig = Depends(load_config)
) -> Generator[Session | None, None, None]:
if db := config.database is None:
yield None
engine = create_engine( else:
"sqlite:///./tmp/vpn.db", try:
connect_args={"check_same_thread": False}, yield db
) finally:
SESSION_LOCAL = sessionmaker( db.close()
autocommit=False, autoflush=False, bind=engine)
ORMBaseModel.metadata.create_all(bind=engine)
async def get_db() -> Session:
global SESSION_LOCAL
db = SESSION_LOCAL()
try:
yield db
finally:
db.close()

View file

@ -1,4 +1,5 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from passlib.context import CryptContext
from . import models, schemas from . import models, schemas
@ -9,10 +10,14 @@ def get_user(db: Session, name: str):
.filter(models.User.name == name).first()) .filter(models.User.name == name).first())
def create_user(db: Session, user: schemas.UserCreate): def create_user(
db: Session,
user: schemas.UserCreate,
crypt_context: CryptContext
):
db_user = models.User( db_user = models.User(
name=user.name, name=user.name,
password=user.password + "notreallyhashed", password=crypt_context.hash(user.password),
) )
db.add(db_user) db.add(db_user)
db.commit() db.commit()

View file

@ -1,9 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import Depends, FastAPI
from .config import PRODUCTION_MODE from .config import (PRODUCTION_MODE, BaseConfig, has_config,
load_config)
from .routers import install from .routers import install
api = FastAPI( api = FastAPI(
@ -29,6 +30,13 @@ app = FastAPI()
app.mount("/api", api) app.mount("/api", api)
@app.on_event("startup")
async def on_startup():
if await has_config():
config = await load_config()
await config.connect_db()
def main(): def main():
uvicorn.run( uvicorn.run(
"kiwi_vpn_api.main:app", "kiwi_vpn_api.main:app",

39
api/kiwi_vpn_api/plan.md Normal file
View file

@ -0,0 +1,39 @@
# Startup
if config file present:
- load config file
- connect to DB
- mount all routers
else:
- mount admin router
# PUT admin/config
if config file present:
- if user is admin:
- overwrite config
- reload config, reconnect to DB
else:
- overwrite config
- reload config, connect to DB
- mount all routers
# POST admin/user
if user table is empty:
- create new user
- give "admin" cap to new user
else:
- if user is admin:
- create new user
...

View file

@ -3,7 +3,7 @@ from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config, from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
load_config) load_config)
from ..db import crud, schemas from ..db import crud, schemas
@ -52,7 +52,7 @@ async def set_config(
if config.jwt.secret is None: if config.jwt.secret is None:
config.jwt.secret = token_hex(32) config.jwt.secret = token_hex(32)
await connect_db(config) await config.connect_db()
with open(CONFIG_FILE, "w") as kv: with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2)) kv.write(config.json(indent=2))
@ -82,17 +82,20 @@ async def check_db():
async def create_db( async def create_db(
admin_name: str, admin_name: str,
admin_password: str, admin_password: str,
config: BaseConfig = Depends(load_config),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
# if await has_tables(db): # if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
# cryptContext = await config.crypto.cryptContext # cryptContext = await config.crypto.cryptContext
if db is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(db, schemas.UserCreate( crud.create_user(db, schemas.UserCreate(
name=admin_name, name=admin_name,
password=admin_password, password=admin_password,
crypt_context=config.crypt_context,
)) ))
crud.add_user_capability(db, user_name=admin_name, capability="admin") crud.add_user_capability(db, user_name=admin_name, capability="admin")

View file

@ -29,6 +29,7 @@ class User(BaseModel):
) )
@router.get("/current", response_model=User)
async def get_current_user(token: str = Depends(SCHEME)): async def get_current_user(token: str = Depends(SCHEME)):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -48,15 +49,10 @@ async def get_current_user(token: str = Depends(SCHEME)):
return user return user
@router.get("/current_user/get", response_model=User)
async def get_current_user(current_user: User = Depends(get_current_user)):
return current_user
async def is_admin(current_user: User = Depends(get_current_user)): async def is_admin(current_user: User = Depends(get_current_user)):
return ("admin" in current_user.capabilities) return ("admin" in current_user.capabilities)
@router.get("/current_user/is_admin") @router.get("/current/is_admin")
async def current_user_is_admin(is_admin: bool = Depends(is_admin)): async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
return {"is_admin": is_admin} return {"is_admin": is_admin}