refactor CRUD utils
This commit is contained in:
parent
b42a5b44f3
commit
8194027c36
6 changed files with 72 additions and 58 deletions
|
@ -0,0 +1 @@
|
|||
from .connection import Connection
|
|
@ -1,31 +0,0 @@
|
|||
from sqlalchemy.orm import Session
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from . import models, schemas
|
||||
|
||||
|
||||
def get_user(db: Session, name: str):
|
||||
return (db
|
||||
.query(models.User)
|
||||
.filter(models.User.name == name).first())
|
||||
|
||||
|
||||
def create_user(
|
||||
db: Session,
|
||||
user: schemas.UserCreate,
|
||||
capabilities: list[str],
|
||||
crypt_context: CryptContext,
|
||||
):
|
||||
db_user = models.User(
|
||||
name=user.name,
|
||||
password=crypt_context.hash(user.password),
|
||||
capabilities=[
|
||||
models.UserCapability(capability=capability)
|
||||
for capability in capabilities
|
||||
]
|
||||
)
|
||||
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
|
@ -1,5 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel, validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
class CertificateBase(BaseModel):
|
||||
|
@ -20,6 +27,15 @@ class Certificate(CertificateBase):
|
|||
|
||||
class UserBase(BaseModel):
|
||||
name: str
|
||||
capabilities: list[str]
|
||||
|
||||
@validator("capabilities", pre=True)
|
||||
@classmethod
|
||||
def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]:
|
||||
return [
|
||||
capability.capability
|
||||
for capability in value
|
||||
]
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
|
@ -27,12 +43,46 @@ class UserCreate(UserBase):
|
|||
|
||||
|
||||
class User(UserBase):
|
||||
capabilities: list[str]
|
||||
certificates: list[Certificate]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls,
|
||||
db: Session,
|
||||
name: str,
|
||||
) -> User:
|
||||
user = (db
|
||||
.query(models.User)
|
||||
.filter(models.User.name == name)
|
||||
.first())
|
||||
|
||||
return cls.from_orm(user)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
db: Session,
|
||||
user: UserCreate,
|
||||
crypt_context: CryptContext,
|
||||
) -> User:
|
||||
user = models.User(
|
||||
name=user.name,
|
||||
password=crypt_context.hash(user.password),
|
||||
capabilities=[
|
||||
models.UserCapability(capability=capability)
|
||||
for capability in user.capabilities
|
||||
]
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return cls.from_orm(user)
|
||||
|
||||
|
||||
class DistinguishedNameBase(BaseModel):
|
||||
cn_only: bool
|
||||
|
|
|
@ -4,7 +4,7 @@ import uvicorn
|
|||
from fastapi import FastAPI
|
||||
|
||||
from .config import Config, Settings
|
||||
from .db.connection import Connection
|
||||
from .db import Connection, schemas
|
||||
from .routers import admin, user
|
||||
|
||||
settings = Settings.get()
|
||||
|
@ -29,26 +29,21 @@ api = FastAPI(
|
|||
app = FastAPI()
|
||||
app.mount("/api", api)
|
||||
|
||||
api.include_router(admin.router)
|
||||
api.include_router(user.router)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
# always include admin router
|
||||
api.include_router(admin.router)
|
||||
|
||||
async def on_startup() -> None:
|
||||
if (current_config := await Config.get()) is not None:
|
||||
Connection.connect(current_config.db_engine)
|
||||
|
||||
# async for db in connection.get():
|
||||
# user = crud.get_user(db, "admin")
|
||||
# print(user.name)
|
||||
# for cap in user.capabilities:
|
||||
# print(cap.capability)
|
||||
|
||||
# include other routers
|
||||
api.include_router(user.router)
|
||||
async for db in Connection.get():
|
||||
user = schemas.User.get(db, "admin")
|
||||
print(str(user))
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
uvicorn.run(
|
||||
"kiwi_vpn_api.main:app",
|
||||
host="0.0.0.0",
|
||||
|
|
|
@ -4,8 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..config import Config
|
||||
from ..db import crud, schemas
|
||||
from ..db.connection import Connection
|
||||
from ..db import Connection, schemas
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
|
||||
|
@ -32,9 +31,8 @@ async def set_config(
|
|||
if new_config.jwt.secret is None:
|
||||
new_config.jwt.secret = token_hex(32)
|
||||
|
||||
Connection.connect(new_config.db_engine)
|
||||
|
||||
Config.set(new_config)
|
||||
Connection.connect(new_config.db_engine)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
@ -58,12 +56,12 @@ async def add_user(
|
|||
if current_config is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
crud.create_user(
|
||||
schemas.User.create(
|
||||
db=db,
|
||||
user=schemas.UserCreate(
|
||||
name=user_name,
|
||||
password=user_password,
|
||||
capabilities=["admin"],
|
||||
),
|
||||
capabilities=["admin"],
|
||||
crypt_context=current_config.crypt_context,
|
||||
)
|
||||
|
|
|
@ -7,11 +7,12 @@ from pydantic import BaseModel
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..config import Config
|
||||
from ..db import crud
|
||||
from ..db.connection import Connection
|
||||
from ..db import Connection, schemas
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token")
|
||||
SCHEME = OAuth2PasswordBearer(
|
||||
tokenUrl=f".{router.prefix}/token",
|
||||
)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
|
@ -37,7 +38,7 @@ def create_access_token(
|
|||
|
||||
|
||||
@router.post("/auth", response_model=Token)
|
||||
async def login_for_access_token(
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
config: Config = Depends(Config.get),
|
||||
db: Session = Depends(Connection.get),
|
||||
|
@ -48,7 +49,7 @@ async def login_for_access_token(
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user = crud.get_user(db, form_data.username)
|
||||
user = schemas.User.get(db, form_data.username)
|
||||
if user is None:
|
||||
config.crypt_context.dummy_verify()
|
||||
raise credentials_exception
|
||||
|
|
Loading…
Reference in a new issue