refactor CRUD utils

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 23:45:09 +00:00
parent b42a5b44f3
commit 8194027c36
6 changed files with 72 additions and 58 deletions

View file

@ -0,0 +1 @@
from .connection import Connection

View file

@ -1,31 +0,0 @@
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from . import models, schemas
def get_user(db: Session, name: str):
return (db
.query(models.User)
.filter(models.User.name == name).first())
def create_user(
db: Session,
user: schemas.UserCreate,
capabilities: list[str],
crypt_context: CryptContext,
):
db_user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[
models.UserCapability(capability=capability)
for capability in capabilities
]
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user

View file

@ -1,5 +1,12 @@
from __future__ import annotations
from datetime import datetime from datetime import datetime
from pydantic import BaseModel
from passlib.context import CryptContext
from pydantic import BaseModel, validator
from sqlalchemy.orm import Session
from . import models
class CertificateBase(BaseModel): class CertificateBase(BaseModel):
@ -20,6 +27,15 @@ class Certificate(CertificateBase):
class UserBase(BaseModel): class UserBase(BaseModel):
name: str name: str
capabilities: list[str]
@validator("capabilities", pre=True)
@classmethod
def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]:
return [
capability.capability
for capability in value
]
class UserCreate(UserBase): class UserCreate(UserBase):
@ -27,12 +43,46 @@ class UserCreate(UserBase):
class User(UserBase): class User(UserBase):
capabilities: list[str]
certificates: list[Certificate] certificates: list[Certificate]
class Config: class Config:
orm_mode = True orm_mode = True
@classmethod
def get(
cls,
db: Session,
name: str,
) -> User:
user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
return cls.from_orm(user)
@classmethod
def create(
cls,
db: Session,
user: UserCreate,
crypt_context: CryptContext,
) -> User:
user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[
models.UserCapability(capability=capability)
for capability in user.capabilities
]
)
db.add(user)
db.commit()
db.refresh(user)
return cls.from_orm(user)
class DistinguishedNameBase(BaseModel): class DistinguishedNameBase(BaseModel):
cn_only: bool cn_only: bool

View file

@ -4,7 +4,7 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db.connection import Connection from .db import Connection, schemas
from .routers import admin, user from .routers import admin, user
settings = Settings.get() settings = Settings.get()
@ -29,26 +29,21 @@ api = FastAPI(
app = FastAPI() app = FastAPI()
app.mount("/api", api) app.mount("/api", api)
@app.on_event("startup")
async def on_startup():
# always include admin router
api.include_router(admin.router) api.include_router(admin.router)
if (current_config := await Config.get()) is not None:
Connection.connect(current_config.db_engine)
# async for db in connection.get():
# user = crud.get_user(db, "admin")
# print(user.name)
# for cap in user.capabilities:
# print(cap.capability)
# include other routers
api.include_router(user.router) api.include_router(user.router)
def main(): @app.on_event("startup")
async def on_startup() -> None:
if (current_config := await Config.get()) is not None:
Connection.connect(current_config.db_engine)
async for db in Connection.get():
user = schemas.User.get(db, "admin")
print(str(user))
def main() -> None:
uvicorn.run( uvicorn.run(
"kiwi_vpn_api.main:app", "kiwi_vpn_api.main:app",
host="0.0.0.0", host="0.0.0.0",

View file

@ -4,8 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import crud, schemas from ..db import Connection, schemas
from ..db.connection import Connection
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
@ -32,9 +31,8 @@ async def set_config(
if new_config.jwt.secret is None: if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32) new_config.jwt.secret = token_hex(32)
Connection.connect(new_config.db_engine)
Config.set(new_config) Config.set(new_config)
Connection.connect(new_config.db_engine)
@router.post( @router.post(
@ -58,12 +56,12 @@ async def add_user(
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user( schemas.User.create(
db=db, db=db,
user=schemas.UserCreate( user=schemas.UserCreate(
name=user_name, name=user_name,
password=user_password, password=user_password,
),
capabilities=["admin"], capabilities=["admin"],
),
crypt_context=current_config.crypt_context, crypt_context=current_config.crypt_context,
) )

View file

@ -7,11 +7,12 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import crud from ..db import Connection, schemas
from ..db.connection import Connection
router = APIRouter(prefix="/user") router = APIRouter(prefix="/user")
SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token") SCHEME = OAuth2PasswordBearer(
tokenUrl=f".{router.prefix}/token",
)
class Token(BaseModel): class Token(BaseModel):
@ -37,7 +38,7 @@ def create_access_token(
@router.post("/auth", response_model=Token) @router.post("/auth", response_model=Token)
async def login_for_access_token( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
config: Config = Depends(Config.get), config: Config = Depends(Config.get),
db: Session = Depends(Connection.get), db: Session = Depends(Connection.get),
@ -48,7 +49,7 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
user = crud.get_user(db, form_data.username) user = schemas.User.get(db, form_data.username)
if user is None: if user is None:
config.crypt_context.dummy_verify() config.crypt_context.dummy_verify()
raise credentials_exception raise credentials_exception