refactor CRUD utils

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 23:45:09 +00:00
parent b42a5b44f3
commit 8194027c36
6 changed files with 72 additions and 58 deletions

View file

@ -0,0 +1 @@
from .connection import Connection

View file

@ -1,31 +0,0 @@
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from . import models, schemas
def get_user(db: Session, name: str):
return (db
.query(models.User)
.filter(models.User.name == name).first())
def create_user(
db: Session,
user: schemas.UserCreate,
capabilities: list[str],
crypt_context: CryptContext,
):
db_user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[
models.UserCapability(capability=capability)
for capability in capabilities
]
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user

View file

@ -1,5 +1,12 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel
from passlib.context import CryptContext
from pydantic import BaseModel, validator
from sqlalchemy.orm import Session
from . import models
class CertificateBase(BaseModel):
@ -20,6 +27,15 @@ class Certificate(CertificateBase):
class UserBase(BaseModel):
name: str
capabilities: list[str]
@validator("capabilities", pre=True)
@classmethod
def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]:
return [
capability.capability
for capability in value
]
class UserCreate(UserBase):
@ -27,12 +43,46 @@ class UserCreate(UserBase):
class User(UserBase):
capabilities: list[str]
certificates: list[Certificate]
class Config:
orm_mode = True
@classmethod
def get(
cls,
db: Session,
name: str,
) -> User:
user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
return cls.from_orm(user)
@classmethod
def create(
cls,
db: Session,
user: UserCreate,
crypt_context: CryptContext,
) -> User:
user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[
models.UserCapability(capability=capability)
for capability in user.capabilities
]
)
db.add(user)
db.commit()
db.refresh(user)
return cls.from_orm(user)
class DistinguishedNameBase(BaseModel):
cn_only: bool

View file

@ -4,7 +4,7 @@ import uvicorn
from fastapi import FastAPI
from .config import Config, Settings
from .db.connection import Connection
from .db import Connection, schemas
from .routers import admin, user
settings = Settings.get()
@ -29,26 +29,21 @@ api = FastAPI(
app = FastAPI()
app.mount("/api", api)
@app.on_event("startup")
async def on_startup():
# always include admin router
api.include_router(admin.router)
if (current_config := await Config.get()) is not None:
Connection.connect(current_config.db_engine)
# async for db in connection.get():
# user = crud.get_user(db, "admin")
# print(user.name)
# for cap in user.capabilities:
# print(cap.capability)
# include other routers
api.include_router(user.router)
def main():
@app.on_event("startup")
async def on_startup() -> None:
if (current_config := await Config.get()) is not None:
Connection.connect(current_config.db_engine)
async for db in Connection.get():
user = schemas.User.get(db, "admin")
print(str(user))
def main() -> None:
uvicorn.run(
"kiwi_vpn_api.main:app",
host="0.0.0.0",

View file

@ -4,8 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import Config
from ..db import crud, schemas
from ..db.connection import Connection
from ..db import Connection, schemas
router = APIRouter(prefix="/admin")
@ -32,9 +31,8 @@ async def set_config(
if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32)
Connection.connect(new_config.db_engine)
Config.set(new_config)
Connection.connect(new_config.db_engine)
@router.post(
@ -58,12 +56,12 @@ async def add_user(
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
crud.create_user(
schemas.User.create(
db=db,
user=schemas.UserCreate(
name=user_name,
password=user_password,
),
capabilities=["admin"],
),
crypt_context=current_config.crypt_context,
)

View file

@ -7,11 +7,12 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import Config
from ..db import crud
from ..db.connection import Connection
from ..db import Connection, schemas
router = APIRouter(prefix="/user")
SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token")
SCHEME = OAuth2PasswordBearer(
tokenUrl=f".{router.prefix}/token",
)
class Token(BaseModel):
@ -37,7 +38,7 @@ def create_access_token(
@router.post("/auth", response_model=Token)
async def login_for_access_token(
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
config: Config = Depends(Config.get),
db: Session = Depends(Connection.get),
@ -48,7 +49,7 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"},
)
user = crud.get_user(db, form_data.username)
user = schemas.User.get(db, form_data.username)
if user is None:
config.crypt_context.dummy_verify()
raise credentials_exception